├── .gitignore
├── cnns.py
├── datasets
│   ├── loadCIFAR10.py
│   ├── loadCIFAR100.py
│   ├── loadFashionMNIST.py
│   ├── loadMNIST.py
│   ├── loadNMNIST_Spiking.py
│   ├── loadSMNIST.py
│   ├── loadSpiking.py
│   └── utils.py
├── global_v.py
├── layers
│   ├── .vscode
│   │   └── settings.json
│   ├── conv.py
│   ├── cpp_wrapper.cpp
│   ├── dropout.py
│   ├── functions.py
│   ├── linear.py
│   ├── losses.py
│   ├── neuron_cuda.cpp
│   ├── neuron_cuda_kernel.cu
│   └── pooling.py
├── main.py
├── network_parser.py
├── networks
│   ├── CIFAR10.yaml
│   ├── CIFAR100.yaml
│   ├── FashionMNIST.yaml
│   ├── RESNET.yaml
│   ├── TTFS.yaml
│   ├── mnist.yaml
│   ├── n-mnist.yaml
│   └── resnet.py
├── readme.md
└── utils.py
/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /cnns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import layers.conv as conv 4 | import layers.pooling as pooling 5 | import layers.dropout as dropout 6 | import layers.linear as linear 7 | from math import * 8 | import global_v as glv 9 | 10 | 11 | class Network(nn.Module): 12 | def __init__(self, input_shape=None): 13 | super(Network, self).__init__() 14 | self.layers = [] 15 | network_config, layers_config = glv.network_config, glv.layers_config 16 | print("Network Structure:") 17 | for key in layers_config: 18 | c = layers_config[key] 19 | if c['type'] == 'conv': 20 | self.layers.append(conv.ConvLayer(network_config, c, key)) 21 | elif c['type'] == 'linear': 22 | self.layers.append(linear.LinearLayer(network_config, c, key)) 23 | elif c['type'] == 'pooling': 24 | self.layers.append(pooling.PoolLayer(network_config, c, key)) 25 | elif c['type'] == 'dropout': 26 | self.layers.append(dropout.DropoutLayer(c, key)) 27 | else: 28 | raise Exception('Undefined layer type.
It is: {}'.format(c['type'])) 29 | 30 | self.net = nn.Sequential(*self.layers) 31 | print("-----------------------------------------") 32 | 33 | def forward(self, inputs, labels, epoch, is_train): 34 | assert(is_train or labels==None) 35 | # spikes = f.psp(spike_input, self.network_config) 36 | spikes = inputs 37 | 38 | for i, l in enumerate(self.layers): 39 | if l.type == "dropout": 40 | if is_train: 41 | spikes = l(spikes) 42 | elif i == len(self.layers) - 1: 43 | assert(l.type == 'linear') 44 | spikes = l.forward(spikes, labels) 45 | else: 46 | spikes = l.forward(spikes) 47 | 48 | return spikes 49 | 50 | def weight_clipper(self): 51 | for l in self.layers: 52 | l.weight_clipper() 53 | 54 | def train(self): 55 | for l in self.layers: 56 | l.train() 57 | 58 | def eval(self): 59 | for l in self.layers: 60 | l.eval() 61 | -------------------------------------------------------------------------------- /datasets/loadCIFAR10.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | from datasets.utils import packaging_class 5 | import PIL 6 | 7 | 8 | def get_cifar10(data_path, network_config): 9 | print("loading CIFAR10") 10 | if not os.path.exists(data_path): 11 | os.mkdir(data_path) 12 | 13 | transform_train = transforms.Compose([ 14 | transforms.RandomCrop(32, padding=4), 15 | # transforms.RandomResizedCrop(32, scale=(0.75,1.0), interpolation=PIL.Image.BILINEAR), 16 | transforms.RandomHorizontalFlip(), 17 | transforms.AutoAugment(), 18 | transforms.ToTensor(), 19 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 20 | ]) 21 | 22 | transform_test = transforms.Compose([ 23 | transforms.ToTensor(), 24 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 25 | ]) 26 | 27 | trainset = torchvision.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transform_train) 28 | testset = torchvision.datasets.CIFAR10(root=data_path, train=False, download=True, transform=transform_test) 29 | 30 | return trainset, testset 31 | -------------------------------------------------------------------------------- /datasets/loadCIFAR100.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision 3 | import torchvision.transforms as transforms 4 | import torch 5 | 6 | 7 | def get_cifar100(data_path, network_config): 8 | print("loading CIFAR100") 9 | if not os.path.exists(data_path): 10 | os.mkdir(data_path) 11 | transform_train = transforms.Compose([ 12 | transforms.RandomCrop(32, padding=4), 13 | transforms.RandomHorizontalFlip(), 14 | transforms.ToTensor(), 15 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 16 | ]) 17 | 18 | transform_test = transforms.Compose([ 19 | transforms.ToTensor(), 20 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 21 | ]) 22 | 23 | trainset = torchvision.datasets.CIFAR100(root=data_path, train=True, download=True, transform=transform_train) 24 | testset = torchvision.datasets.CIFAR100(root=data_path, train=False, download=True, transform=transform_test) 25 | 26 | return trainset, testset 27 | -------------------------------------------------------------------------------- /datasets/loadFashionMNIST.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets 3 | import torchvision.transforms as transforms 4 | import torch 5 | 
6 | 7 | def get_fashionmnist(data_path, network_config): 8 | print("loading Fashion MNIST") 9 | if not os.path.exists(data_path): 10 | os.mkdir(data_path) 11 | transform_train = transforms.Compose([ 12 | # transforms.RandomCrop(28, padding=4), 13 | transforms.RandomHorizontalFlip(), 14 | # transforms.RandomRotation(degrees=20), 15 | transforms.ToTensor(), 16 | transforms.Normalize((0.1307,), (0.3081,)) 17 | ]) 18 | 19 | transform_test = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Normalize((0.1307,), (0.3081,)) 22 | ]) 23 | 24 | trainset = torchvision.datasets.FashionMNIST(data_path, train=True, transform=transform_train, download=True) 25 | testset = torchvision.datasets.FashionMNIST(data_path, train=False, transform=transform_test, download=True) 26 | 27 | return trainset, testset 28 | 29 | -------------------------------------------------------------------------------- /datasets/loadMNIST.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets 3 | import torchvision.transforms as transforms 4 | import torch 5 | 6 | 7 | def get_mnist(data_path, network_config): 8 | print("loading MNIST") 9 | if not os.path.exists(data_path): 10 | os.mkdir(data_path) 11 | 12 | transform_train = transforms.Compose([ 13 | # transforms.RandomCrop(28, padding=4), 14 | # transforms.RandomHorizontalFlip(), 15 | # transforms.RandomRotation(30), 16 | transforms.ToTensor(), 17 | transforms.Normalize((0.1307,), (0.3081,)) 18 | ]) 19 | 20 | transform_test = transforms.Compose([ 21 | transforms.ToTensor(), 22 | transforms.Normalize((0.1307,), (0.3081,)) 23 | ]) 24 | trainset = torchvision.datasets.MNIST(data_path, train=True, transform=transform_train, download=True) 25 | testset = torchvision.datasets.MNIST(data_path, train=False, transform=transform_test, download=True) 26 | 27 | return trainset, testset 28 | 29 | -------------------------------------------------------------------------------- /datasets/loadNMNIST_Spiking.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | from torch.utils.data import Dataset 5 | import torchvision.transforms as transforms 6 | import numpy as np 7 | from tqdm import tqdm 8 | from os import listdir 9 | from os.path import isfile, join 10 | 11 | 12 | class NMNIST(Dataset): 13 | def __init__(self, dataset_path, T, transform=None): 14 | self.path = dataset_path 15 | self.samples = [] 16 | self.labels = [] 17 | self.transform = transform 18 | self.T = T 19 | for i in tqdm(range(10)): 20 | sample_dir = dataset_path + '/' + str(i) + '/' 21 | for f in listdir(sample_dir): 22 | filename = join(sample_dir, f) 23 | if isfile(filename): 24 | self.samples.append(filename) 25 | self.labels.append(i) 26 | 27 | def __getitem__(self, index): 28 | filename = self.samples[index] 29 | label = self.labels[index] 30 | 31 | data = np.zeros((2, 34, 34, self.T)) 32 | 33 | f = open(filename, 'r') 34 | lines = f.readlines() 35 | for line in lines: 36 | if line is None: 37 | break 38 | line = line.split() 39 | line = [int(l) for l in line] 40 | pos = line[0] - 1 41 | if pos >= 1156: 42 | channel = 1 43 | pos -= 1156 44 | else: 45 | channel = 0 46 | y = pos % 34 47 | x = int(math.floor(pos/34)) 48 | for i in range(1, len(line)): 49 | if line[i] >= self.T: 50 | break 51 | data[channel, x, y, line[i]-1] = 1 52 | if self.transform: 53 | data = self.transform(data) 54 | data = data.type(torch.float32) 55 | else: 56 | data = 
torch.FloatTensor(data) 57 | 58 | # Input spikes are reshaped to ignore the spatial dimension and the neurons are placed in channel dimension. 59 | # The spatial dimension can be maintained and used as it is. 60 | # It requires different definition of the dense layer. 61 | return data.permute(3,0,1,2), label 62 | 63 | def __len__(self): 64 | return len(self.samples) 65 | 66 | 67 | def get_nmnist(data_path, network_config): 68 | T = network_config['n_steps'] 69 | print("loading NMNIST") 70 | if not os.path.exists(data_path): 71 | os.mkdir(data_path) 72 | train_path = data_path + '/Train' 73 | test_path = data_path + '/Test' 74 | 75 | trainset = NMNIST(train_path, T) 76 | testset = NMNIST(test_path, T) 77 | 78 | return trainset, testset 79 | -------------------------------------------------------------------------------- /datasets/loadSMNIST.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torchvision.datasets 3 | import torchvision.transforms as transforms 4 | 5 | 6 | def trans(data): 7 | trans = transforms.Compose([ 8 | transforms.ToTensor(), 9 | transforms.Normalize((0.1307,), (0.3081,)) 10 | ]) 11 | data = trans(data) 12 | C, H, W = data.shape 13 | data = data.permute(2,0,1).reshape(W,C,H,1) 14 | return data 15 | 16 | 17 | def get_smnist(data_path, network_config): 18 | print("loading S-MNIST") 19 | if not os.path.exists(data_path): 20 | os.mkdir(data_path) 21 | 22 | trainset = torchvision.datasets.MNIST(data_path, train=True, transform=trans, download=True) 23 | testset = torchvision.datasets.MNIST(data_path, train=False, transform=trans, download=True) 24 | 25 | return trainset, testset 26 | -------------------------------------------------------------------------------- /datasets/loadSpiking.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms as transforms 4 | import PIL 5 | from spikingjelly.datasets.n_mnist import NMNIST 6 | from spikingjelly.datasets.dvs128_gesture import DVS128Gesture 7 | from spikingjelly.datasets.cifar10_dvs import CIFAR10DVS 8 | from spikingjelly.datasets import split_to_train_test_set, RandomTemporalDelete 9 | import numpy as np 10 | from datasets.utils import function_nda, packaging_class 11 | import global_v as glv 12 | 13 | 14 | 15 | def get_dataset(dataset_func, data_path, network_config, transform_train=None, transform_test=None): 16 | T = network_config['n_steps'] 17 | if not os.path.exists(data_path): 18 | os.mkdir(data_path) 19 | trainset = dataset_func(data_path, data_type='frame', frames_number=T, split_by='number', train=True) 20 | testset = dataset_func(data_path, data_type='frame', frames_number=T, split_by='number', train=False) 21 | trainset, testset = packaging_class(trainset, transform_train), packaging_class(testset, transform_test) 22 | return trainset, testset 23 | 24 | 25 | def get_nmnist(data_path, network_config): 26 | return get_dataset(NMNIST, data_path, network_config) 27 | 28 | 29 | def trans_t(data): 30 | # print(data.shape) 31 | # exit(0) 32 | data = transforms.RandomResizedCrop(128, scale=(0.7, 1.0), interpolation=PIL.Image.NEAREST)(data) 33 | resize = transforms.Resize(size=(48, 48)) # 48 48 34 | data = resize(data).float() 35 | flip = np.random.random() > 0.5 36 | if flip: 37 | data = torch.flip(data, dims=(3,)) 38 | data = function_nda(data) 39 | return data.float() 40 | 41 | 42 | def trans(data): 43 | resize = transforms.Resize(size=(48, 48)) # 48 48 44 | data = 
resize(data).float() 45 | return data.float() 46 | 47 | 48 | def get_dvs128_gesture(data_path, network_config): 49 | T = network_config['t_train'] 50 | transform_train = transforms.Compose([ 51 | transforms.RandomResizedCrop(128, scale=(0.5, 1.0), interpolation=PIL.Image.NEAREST), 52 | transforms.RandomHorizontalFlip(), 53 | RandomTemporalDelete(T_remain=T, batch_first=False), 54 | ]) 55 | return get_dataset(DVS128Gesture, data_path, network_config, transform_train) 56 | # return get_dataset(DVS128Gesture, data_path, network_config, trans_t, trans) 57 | 58 | 59 | def get_cifar10_dvs(data_path, network_config): 60 | T = network_config['n_steps'] 61 | if not os.path.exists(data_path): 62 | os.mkdir(data_path) 63 | dataset = CIFAR10DVS(data_path, data_type='frame', frames_number=T, split_by='number') 64 | trainset, testset = split_to_train_test_set(train_ratio=0.9, origin_dataset=dataset, num_classes=10) 65 | 66 | trainset, testset = packaging_class(trainset, trans_t), packaging_class(testset, trans) 67 | return trainset, testset 68 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as F 4 | import PIL 5 | import math 6 | import numpy as np 7 | 8 | 9 | class packaging_class(torch.utils.data.Dataset): 10 | def __init__(self, dataset, transform=None): 11 | self.transform = transform 12 | self.dataset = dataset 13 | 14 | def __getitem__(self, index): 15 | data, label = self.dataset[index] 16 | data = torch.FloatTensor(data) 17 | if self.transform: 18 | data = self.transform(data) 19 | 20 | return data, label 21 | 22 | def __len__(self): 23 | return len(self.dataset) 24 | 25 | 26 | class Cutout(object): 27 | """Randomly mask out one or more patches from an image. 28 | Args: 29 | n_holes (int): Number of patches to cut out of each image. 30 | length (int): The length (in pixels) of each square patch. 31 | """ 32 | 33 | def __init__(self, length): 34 | self.length = length 35 | 36 | def __call__(self, img): 37 | h = img.size(2) 38 | w = img.size(3) 39 | mask = np.ones((h, w), np.float32) 40 | y = np.random.randint(h) 41 | x = np.random.randint(w) 42 | y1 = np.clip(y - self.length // 2, 0, h) 43 | y2 = np.clip(y + self.length // 2, 0, h) 44 | x1 = np.clip(x - self.length // 2, 0, w) 45 | x2 = np.clip(x + self.length // 2, 0, w) 46 | mask[y1: y2, x1: x2] = 0. 
47 | mask = torch.from_numpy(mask) 48 | mask = mask.expand_as(img) 49 | img = img * mask 50 | return img 51 | 52 | 53 | def function_nda(data, M=1, N=2): 54 | c = 15 * N 55 | rotate_tf = transforms.RandomRotation(degrees=c) 56 | e = 8 * N 57 | cutout_tf = Cutout(length=e) 58 | 59 | def roll(data, N=1): 60 | a = N * 2 + 1 61 | off1 = np.random.randint(-a, a + 1) 62 | off2 = np.random.randint(-a, a + 1) 63 | return torch.roll(data, shifts=(off1, off2), dims=(2, 3)) 64 | 65 | def rotate(data, N): 66 | return rotate_tf(data) 67 | 68 | def cutout(data, N): 69 | return cutout_tf(data) 70 | 71 | transforms_list = [roll, rotate, cutout] 72 | sampled_ops = np.random.choice(transforms_list, M) 73 | for op in sampled_ops: 74 | data = op(data, N) 75 | return data 76 | 77 | 78 | def TTFS(data, T): 79 | # data: C*H*W 80 | # output: T*C*H*W 81 | C, H, W = data.shape 82 | low, high = torch.min(data), torch.max(data) 83 | data = ((data - low) / (high - low) * T).long() 84 | # T --> T-1 85 | data = torch.clip(data, 0, T - 1) 86 | res = torch.zeros(T, C, H, W, device=data.device) 87 | return res.scatter_(0, data.unsqueeze(0), 1) 88 | -------------------------------------------------------------------------------- /global_v.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def init(config_n, config_l=None): 5 | global T, T_train, syn_a, delta_syn_a, tau_s, tau_m, grad_rec, outputs_raw 6 | global rank, network_config, layers_config, time_use, req_grad, init_flag 7 | init_flag = False 8 | 9 | network_config, layers_config = config_n, config_l 10 | 11 | if 'loss_reverse' not in network_config.keys(): 12 | network_config['loss_reverse'] = True 13 | 14 | if 'encoding' not in network_config.keys(): 15 | network_config['encoding'] = 'None' 16 | if 'amp' not in network_config.keys(): 17 | network_config['amp'] = False 18 | if 'backend' not in network_config.keys(): 19 | network_config['backend'] = 'python' 20 | if 'norm_grad' not in network_config.keys(): 21 | network_config['norm_grad'] = 1 22 | 23 | if 'max_dudt_inv' not in network_config: 24 | network_config['max_dudt_inv'] = 123456789 25 | if 'avg_spike_init' not in network_config: 26 | network_config['avg_spike_init'] = 1 27 | if 'weight_decay' not in network_config: 28 | network_config['weight_decay'] = 0 29 | if 't_train' not in network_config: 30 | network_config['t_train'] = network_config['n_steps'] 31 | 32 | T, tau_s, tau_m, grad_type = (config_n[x] for x in ('n_steps', 'tau_s', 'tau_m', 'gradient_type')) 33 | if 'forward_type' not in network_config: 34 | network_config['forward_type'] = 'leaky' 35 | 36 | assert(network_config['forward_type'] in ['leaky', 'nonleaky']) 37 | assert(grad_type in ['original', 'exponential']) 38 | assert(not (network_config['forward_type'] == 'nonleaky' and grad_type == 'original')) 39 | 40 | syn_a, delta_syn_a = (torch.zeros(T + 1, device=torch.device(rank)) for _ in range(2)) 41 | theta_m, theta_s = 1 / tau_m, 1 / tau_s 42 | if grad_type == 'exponential': 43 | assert('tau_grad' in config_n) 44 | tau_grad = config_n['tau_grad'] 45 | theta_grad = 1 / tau_grad 46 | 47 | for t in range(T): 48 | t1 = t + 1 49 | syn_a[t] = ((1 - theta_m) ** t1 - (1 - theta_s) ** t1) * theta_s / (theta_s - theta_m) 50 | if grad_type == 'exponential': 51 | delta_syn_a[t] = (1 - theta_grad) ** t1 52 | else: 53 | f = lambda t: ((1 - theta_m) ** t - (1 - theta_s) ** t) * theta_s / (theta_s - theta_m) 54 | delta_syn_a[t] = f(t1) - f(t1 - 1) 55 | # print(syn_a, delta_syn_a) 
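For intuition, the PSP kernel that init() precomputes into syn_a can be reproduced in isolation. Below is a minimal standalone sketch using the same constants (n_steps=12, tau_s=7, tau_m=4) as the self-test at the bottom of layers/functions.py; it is an illustration, not part of the repository:

import torch

# Same constants as the __main__ self-test in layers/functions.py.
T, tau_m, tau_s = 12, 4.0, 7.0
theta_m, theta_s = 1 / tau_m, 1 / tau_s
syn_a = torch.zeros(T + 1)
for t in range(T):
    t1 = t + 1
    # Dual-exponential PSP kernel: difference of two decaying exponentials,
    # normalized by theta_s / (theta_s - theta_m) exactly as in global_v.init.
    syn_a[t] = ((1 - theta_m) ** t1 - (1 - theta_s) ** t1) * theta_s / (theta_s - theta_m)
print(syn_a)  # rises over the first few steps, peaks, then decays toward zero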
-------------------------------------------------------------------------------- /layers/.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.associations": { 3 | "atomic": "cpp", 4 | "bit": "cpp", 5 | "cctype": "cpp", 6 | "clocale": "cpp", 7 | "cmath": "cpp", 8 | "compare": "cpp", 9 | "concepts": "cpp", 10 | "cstddef": "cpp", 11 | "cstdint": "cpp", 12 | "cstdio": "cpp", 13 | "cstdlib": "cpp", 14 | "cstring": "cpp", 15 | "ctime": "cpp", 16 | "cwchar": "cpp", 17 | "exception": "cpp", 18 | "initializer_list": "cpp", 19 | "ios": "cpp", 20 | "iosfwd": "cpp", 21 | "iostream": "cpp", 22 | "istream": "cpp", 23 | "iterator": "cpp", 24 | "limits": "cpp", 25 | "memory": "cpp", 26 | "new": "cpp", 27 | "ostream": "cpp", 28 | "stdexcept": "cpp", 29 | "streambuf": "cpp", 30 | "system_error": "cpp", 31 | "tuple": "cpp", 32 | "type_traits": "cpp", 33 | "typeinfo": "cpp", 34 | "utility": "cpp", 35 | "vector": "cpp", 36 | "xfacet": "cpp", 37 | "xiosbase": "cpp", 38 | "xlocale": "cpp", 39 | "xlocinfo": "cpp", 40 | "xlocnum": "cpp", 41 | "xmemory": "cpp", 42 | "xstddef": "cpp", 43 | "xstring": "cpp", 44 | "xtr1common": "cpp", 45 | "xutility": "cpp" 46 | } 47 | } -------------------------------------------------------------------------------- /layers/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | from layers.functions import neuron_forward, neuron_backward, bn_forward, bn_backward, readConfig, initialize 5 | import global_v as glv 6 | import torch.backends.cudnn as cudnn 7 | from torch.utils.cpp_extension import load_inline, load 8 | from torch.cuda.amp import custom_fwd, custom_bwd 9 | from datetime import datetime 10 | 11 | cpp_wrapper = load(name="cpp_wrapper", sources=["layers/cpp_wrapper.cpp"], verbose=True) 12 | 13 | 14 | class ConvLayer(nn.Conv2d): 15 | def __init__(self, network_config, config, name, groups=1): 16 | self.name = name 17 | self.threshold = config['threshold'] if 'threshold' in config else None 18 | self.type = config['type'] 19 | in_features = config['in_channels'] 20 | out_features = config['out_channels'] 21 | kernel_size = config['kernel_size'] 22 | 23 | padding = config['padding'] if 'padding' in config else 0 24 | stride = config['stride'] if 'stride' in config else 1 25 | dilation = config['dilation'] if 'dilation' in config else 1 26 | 27 | self.kernel = readConfig(kernel_size, 'kernelSize') 28 | self.stride = readConfig(stride, 'stride') 29 | self.padding = readConfig(padding, 'padding') 30 | self.dilation = readConfig(dilation, 'dilation') 31 | 32 | super(ConvLayer, self).__init__(in_features, out_features, self.kernel, self.stride, self.padding, 33 | self.dilation, groups, bias=False) 34 | self.weight = torch.nn.Parameter(self.weight.cuda(), requires_grad=True) 35 | self.norm_weight = torch.nn.Parameter(torch.ones(out_features, 1, 1, 1, device='cuda')) 36 | self.norm_bias = torch.nn.Parameter(torch.zeros(out_features, 1, 1, 1, device='cuda')) 37 | 38 | print('conv') 39 | print(f'Shape of weight is {list(self.weight.shape)}') # Cout * Cin * Hk * Wk 40 | print(f'stride = {self.stride}, padding = {self.padding}, dilation = {self.dilation}, groups = {self.groups}') 41 | print("-----------------------------------------") 42 | 43 | def forward(self, x): 44 | if glv.init_flag: # one-off data-based weight initialization pass; see initialize() in layers/functions.py 45 | glv.init_flag = False 46 | x = initialize(self, x) 47 | glv.init_flag = True 48 | return x 49 | 50 | #
self.weight_clipper() 51 | config_n = glv.network_config 52 | theta_m = 1 / config_n['tau_m'] 53 | theta_s = 1 / config_n['tau_s'] 54 | theta_grad = 1 / config_n['tau_grad'] if config_n[ 55 | 'gradient_type'] == 'exponential' else -123456789 # instead of None 56 | y = ConvFunc.apply(x, self.weight, self.norm_weight, self.norm_bias, 57 | (self.bias, self.stride, self.padding, self.dilation, self.groups), 58 | (theta_m, theta_s, theta_grad, self.threshold)) 59 | return y 60 | 61 | def weight_clipper(self): 62 | w = self.weight.data 63 | w = w.clamp(-4, 4) 64 | self.weight.data = w 65 | 66 | 67 | class ConvFunc(torch.autograd.Function): 68 | @staticmethod 69 | @custom_fwd 70 | def forward(ctx, inputs, weight, norm_weight, norm_bias, conv_config, neuron_config): 71 | # input.shape: T * n_batch * C_in * H_in * W_in 72 | bias, stride, padding, dilation, groups = conv_config 73 | T, n_batch, C, H, W = inputs.shape 74 | 75 | inputs, mean, var, weight_ = bn_forward(inputs, weight, norm_weight, norm_bias) 76 | 77 | in_I = f.conv2d(inputs.reshape(T * n_batch, C, H, W), weight_, bias, stride, padding, dilation, groups) 78 | _, C, H, W = in_I.shape 79 | in_I = in_I.reshape(T, n_batch, C, H, W) 80 | 81 | delta_u, delta_u_t, outputs = neuron_forward(in_I, neuron_config) 82 | 83 | ctx.save_for_backward(delta_u, delta_u_t, inputs, outputs, weight, norm_weight, norm_bias, mean, var) 84 | ctx.conv_config = conv_config 85 | 86 | return outputs 87 | 88 | @staticmethod 89 | @custom_bwd 90 | def backward(ctx, grad_delta): 91 | # shape of grad_delta: T * n_batch * C * H * W 92 | (delta_u, delta_u_t, inputs, outputs, weight, norm_weight, norm_bias, mean, var) = ctx.saved_tensors 93 | bias, stride, padding, dilation, groups = ctx.conv_config 94 | grad_delta *= outputs 95 | # sum_next = grad_delta.sum().item() 96 | # print("Max of dLdt: ", abs(grad_delta).max().item()) 97 | 98 | grad_in_, grad_w_ = neuron_backward(grad_delta, outputs, delta_u, delta_u_t) 99 | weight_ = (weight - mean) / torch.sqrt(var + 1e-5) * norm_weight + norm_bias 100 | 101 | T, n_batch, C, H, W = grad_delta.shape 102 | inputs = inputs.reshape(T * n_batch, *inputs.shape[2:]) 103 | grad_in_, grad_w_ = map(lambda x: x.reshape(T * n_batch, C, H, W), [grad_in_, grad_w_]) 104 | grad_input = cpp_wrapper.cudnn_convolution_backward_input(inputs.shape, grad_in_.to(weight_), weight_, padding, 105 | stride, dilation, groups, 106 | cudnn.benchmark, cudnn.deterministic, 107 | cudnn.allow_tf32) * inputs 108 | grad_weight = cpp_wrapper.cudnn_convolution_backward_weight(weight.shape, grad_w_.to(inputs), inputs, padding, 109 | stride, dilation, groups, 110 | cudnn.benchmark, cudnn.deterministic, 111 | cudnn.allow_tf32) 112 | 113 | grad_weight, grad_bn_w, grad_bn_b = bn_backward(grad_weight, weight, norm_weight, norm_bias, mean, var) 114 | 115 | # sum_last = grad_input.sum().item() 116 | # print(f'sum_next = {sum_next}, sum_last = {sum_last}') 117 | # assert(abs(sum_next - sum_last) < 1) 118 | return grad_input.reshape(T, n_batch, *inputs.shape[1:]) * 0.85, grad_weight, grad_bn_w, grad_bn_b, None, None, None 119 | -------------------------------------------------------------------------------- /layers/cpp_wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 4 | m.def("cudnn_convolution_backward", &at::cudnn_convolution_backward, "CUDNN convolution backward"); 5 | m.def("cudnn_convolution_backward_input", &at::cudnn_convolution_backward_input, "CUDNN 
convolution backward for input"); 6 | m.def("cudnn_convolution_backward_weight", &at::cudnn_convolution_backward_weight, "CUDNN convolution backward for weight"); 7 | } -------------------------------------------------------------------------------- /layers/dropout.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as f 3 | import global_v as glv 4 | 5 | 6 | class DropoutLayer(nn.Dropout3d): 7 | def __init__(self, config, name, inplace=False): 8 | self.name = name 9 | self.type = config['type'] 10 | if 'p' in config: 11 | p = config['p'] 12 | else: 13 | p = 0.5 14 | super(DropoutLayer, self).__init__(p, inplace) 15 | print('dropout') 16 | print("p: %.2f" % p) 17 | print("-----------------------------------------") 18 | 19 | def forward(self, x): 20 | if self.p <= 0 or self.p >= 1 or glv.init_flag: 21 | return x 22 | ndim = len(x.shape) 23 | if ndim == 3: 24 | T, n_batch, N = x.shape 25 | result = f.dropout2d(x.permute(1,2,0).reshape((n_batch, N, 1, T)), self.p, self.training, self.inplace) # drops whole neurons, so one mask is shared across all T time steps 26 | return result.reshape((n_batch, N, T)).permute(2,0,1) 27 | elif ndim == 5: 28 | T, n_batch, C, H, W = x.shape 29 | result = f.dropout2d(x.permute(1,2,3,4,0).reshape((n_batch, C, H*W, T)), self.p, self.training, self.inplace) 30 | return result.reshape((n_batch, C, H, W, T)).permute(4,0,1,2,3) 31 | else: 32 | raise Exception("In dropout layer, dimension of input is not 3 or 5!") 33 | 34 | def weight_clipper(self): 35 | return 36 | -------------------------------------------------------------------------------- /layers/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import global_v as glv 3 | from torch.utils.cpp_extension import load 4 | 5 | try: 6 | neuron_cuda = load(name="neuron_cuda", sources=["layers/neuron_cuda.cpp", 'layers/neuron_cuda_kernel.cu'], 7 | verbose=True) 8 | except: 9 | print('Cannot load cuda neuron kernel.') 10 | 11 | 12 | def readConfig(data, name): 13 | if type(data) == int: 14 | res = (data, data) 15 | else: # str 16 | try: 17 | assert(data[0] == '(' and data[-1] == ')') 18 | data = data[1:len(data)-1] 19 | x, y = map(int, data.split(',')) 20 | res = (x, y) 21 | except: 22 | raise Exception(f'The format of {name} is illegal!') 23 | return res 24 | 25 | 26 | def initialize(layer, spikes): 27 | avg_spike_init = glv.network_config['avg_spike_init'] 28 | from math import sqrt 29 | T = spikes.shape[0] 30 | t_start = T * 2 // 3 31 | 32 | low, high = 0.05, 500 33 | while high / low >= 1.01: 34 | mid = sqrt(high * low) 35 | layer.norm_weight.data *= mid 36 | outputs = layer.forward(spikes) 37 | layer.norm_weight.data /= mid 38 | n_neuron = outputs[0].numel() 39 | avg_spike = torch.sum(outputs[t_start:]) / n_neuron 40 | if avg_spike > avg_spike_init / T * (T - t_start) * 1.2: 41 | high = mid 42 | else: 43 | low = mid 44 | layer.norm_weight.data *= mid 45 | return layer.forward(spikes) 46 | 47 | 48 | def norm(inputs): 49 | T = inputs.shape[0] 50 | t_start = T * 2 // 3 51 | if (inputs >= 0).all(): 52 | num_spike = (torch.sum(inputs[t_start:]) + 1e-5) 53 | target_spike = inputs.numel() / T * (T - t_start) / T 54 | inputs = inputs / num_spike * target_spike 55 | return inputs 56 | 57 | 58 | def bn_forward(inputs, weight, norm_weight, norm_bias): 59 | # inputs = norm(inputs) 60 | C = weight.shape[0] 61 | # print(weight.shape) 62 | mean, var = torch.mean(weight.reshape(C, -1), dim=1), torch.std(weight.reshape(C, -1), dim=1) ** 2 63 | shape =
(-1, 1, 1, 1) if len(weight.shape) == 4 else (-1, 1) 64 | mean, var, norm_weight, norm_bias = [x.reshape(*shape) for x in [mean, var, norm_weight, norm_bias]] 65 | weight_ = (weight - mean) / torch.sqrt(var + 1e-5) * norm_weight + norm_bias 66 | return inputs, mean, var, weight_ 67 | 68 | 69 | def bn_backward(grad_weight, weight, norm_weight, norm_bias, mean, var): 70 | C = weight.shape[0] 71 | std_inv = 1 / torch.sqrt(var + 1e-5) 72 | shape = (-1, 1, 1, 1) if len(weight.shape) == 4 else (-1, 1) 73 | weight_ = (weight - mean) * std_inv * norm_weight.reshape(*shape) + norm_bias.reshape(*shape) 74 | grad_bn_b = torch.sum(grad_weight.reshape(C, -1), dim=1).reshape(norm_bias.shape) 75 | grad_bn_w = torch.sum((grad_weight * weight_).reshape(C, -1), dim=1).reshape(norm_weight.shape) 76 | grad_weight *= norm_weight.reshape(*shape) 77 | m = weight.numel() // C 78 | grad_var = grad_weight * (weight - mean) / m * (-0.5) * std_inv ** 3 79 | grad_mean = -grad_weight * std_inv 80 | grad_weight = grad_weight * std_inv + grad_var * 2 * (weight - mean) / m + grad_mean / m 81 | return grad_weight, grad_bn_w, grad_bn_b 82 | 83 | 84 | @torch.jit.script 85 | def neuron_forward_py(in_I, theta_m, theta_s, theta_grad, threshold, is_forward_leaky, is_grad_exp): 86 | # syn_m & syn_s: (1-theta_m)^t & (1-theta_s)^t in eps(t) 87 | # syn_grad: (1-theta_grad)^t in backward 88 | u_last = torch.zeros_like(in_I[0]) 89 | syn_m, syn_s, syn_grad = torch.zeros_like(in_I[0]), torch.zeros_like(in_I[0]), torch.zeros_like(in_I[0]) 90 | delta_u, delta_u_t, outputs = torch.zeros_like(in_I), torch.zeros_like(in_I), torch.zeros_like(in_I) 91 | T = in_I.shape[0] 92 | for t in range(T): 93 | syn_m = (syn_m + in_I[t]) * (1 - theta_m) 94 | syn_s = (syn_s + in_I[t]) * (1 - theta_s) 95 | syn_grad = (syn_grad + in_I[t]) * (1 - theta_grad) 96 | 97 | if not is_forward_leaky: 98 | delta_u_t[t] = syn_grad 99 | u = u_last + delta_u_t[t] 100 | delta_u[t] = delta_u_t[t] 101 | else: 102 | u = (syn_m - syn_s) * theta_s / (theta_s - theta_m) 103 | delta_u[t] = u - u_last 104 | delta_u_t[t] = syn_grad if is_grad_exp else delta_u[t] 105 | 106 | out = (u >= threshold).to(u) 107 | u_last = u * (1 - out) 108 | 109 | syn_m = syn_m * (1 - out) 110 | syn_s = syn_s * (1 - out) 111 | syn_grad = syn_grad * (1 - out) 112 | outputs[t] = out 113 | 114 | return delta_u, delta_u_t, outputs 115 | 116 | 117 | @torch.jit.script 118 | def neuron_backward_py(grad_delta, outputs, delta_u, delta_u_t, syn_a, partial_a, max_dudt_inv): 119 | T = grad_delta.shape[0] 120 | 121 | grad_in_, grad_w_ = torch.zeros_like(outputs), torch.zeros_like(outputs) 122 | partial_u_grad_w, partial_u_grad_t = torch.zeros_like(outputs[0]), torch.zeros_like(outputs[0]) 123 | delta_t = torch.zeros(outputs.shape[1:], device=outputs.device, dtype=torch.long) 124 | spiked = torch.zeros_like(outputs[0]) 125 | 126 | for t in range(T - 1, -1, -1): 127 | out = outputs[t] 128 | spiked += (1 - spiked) * out 129 | 130 | partial_u = torch.clamp(-1 / delta_u[t], -4, 0) 131 | partial_u_t = torch.clamp(-1 / delta_u_t[t], -max_dudt_inv, 0) 132 | # current time is t_m 133 | partial_u_grad_w = partial_u_grad_w * (1 - out) + grad_delta[t] * partial_u * out 134 | partial_u_grad_t = partial_u_grad_t * (1 - out) + grad_delta[t] * partial_u_t * out 135 | 136 | delta_t = (delta_t + 1) * (1 - out).long() 137 | grad_in_[t] = partial_u_grad_t * partial_a[delta_t] * spiked.to(partial_a) 138 | grad_w_[t] = partial_u_grad_w * syn_a[delta_t] * spiked.to(syn_a) 139 | 140 | return grad_in_, grad_w_ 141 | 142 | 143 | def 
neuron_forward(in_I, neuron_config): 144 | theta_m, theta_s, theta_grad, threshold = torch.tensor(neuron_config).to(in_I) 145 | assert (theta_m != theta_s) 146 | is_grad_exp = torch.tensor(glv.network_config['gradient_type'] == 'exponential') 147 | is_forward_leaky = torch.tensor(glv.network_config['forward_type'] == 'leaky') 148 | if glv.network_config['backend'] == 'python': 149 | return neuron_forward_py(in_I, theta_m, theta_s, theta_grad, threshold, is_forward_leaky, is_grad_exp) 150 | elif glv.network_config['backend'] == 'cuda': 151 | # global neuron_cuda 152 | # if neuron_cuda is None: 153 | theta_m, theta_s, theta_grad, threshold = neuron_config 154 | return neuron_cuda.forward(in_I, theta_m, theta_s, theta_grad, threshold, is_forward_leaky, is_grad_exp) 155 | else: 156 | raise Exception('Unrecognized computation backend.') 157 | 158 | 159 | def neuron_backward(grad_delta, outputs, delta_u, delta_u_t): 160 | syn_a, partial_a = glv.syn_a.to(outputs), -glv.delta_syn_a.to(outputs) 161 | max_dudt_inv = torch.tensor(glv.network_config['max_dudt_inv']) 162 | if glv.network_config['backend'] == 'python': 163 | return neuron_backward_py(grad_delta, outputs, delta_u, delta_u_t, syn_a, partial_a, max_dudt_inv) 164 | elif glv.network_config['backend'] == 'cuda': 165 | max_dudt_inv = max_dudt_inv.item() 166 | return neuron_cuda.backward(grad_delta, outputs, delta_u, delta_u_t, syn_a, partial_a, max_dudt_inv) 167 | else: 168 | raise Exception('Unrecognized computation backend.') 169 | 170 | 171 | if __name__ == '__main__': 172 | T = 12 173 | glv.rank = 0 174 | config = dict() 175 | config['gradient_type'] = 'exponential' 176 | config['forward_type'] = 'nonleaky' 177 | for key, val in zip(('n_steps', 'tau_s', 'tau_m', 'tau_grad', 'threshold'), (T, 7, 4, 3.5, 1)): 178 | config[key] = val 179 | glv.init(config, config) 180 | neuron_cuda = load(name="neuron_cuda", sources=["neuron_cuda.cpp", 'neuron_cuda_kernel.cu'], verbose=True) 181 | shape = (T, 50, 3, 32, 32) 182 | 183 | neuron_config = [1 / glv.network_config[key] for key in ('tau_m', 'tau_s', 'tau_grad')] + [ 184 | glv.network_config['threshold']] 185 | in_I = torch.rand(*shape, device=torch.device('cuda')) 186 | glv.network_config['backend'] = 'python' 187 | delta_u_py, delta_u_t_py, outputs_py = neuron_forward(in_I, neuron_config) 188 | glv.network_config['backend'] = 'cuda' 189 | delta_u_cuda, delta_u_t_cuda, outputs_cuda = neuron_forward(in_I, neuron_config) 190 | print(torch.sum(delta_u_py), torch.sum(delta_u_cuda)) 191 | assert (torch.sum(torch.abs(delta_u_py - delta_u_cuda)).item() <= 1e-3) 192 | assert (torch.sum(torch.abs(delta_u_t_py - delta_u_t_cuda)).item() <= 1e-3) 193 | assert (torch.sum(torch.abs(outputs_py - outputs_cuda)) <= 1e-3) 194 | 195 | grad_delta = torch.rand(*shape, device=torch.device('cuda')) 196 | outputs = torch.round(torch.rand_like(grad_delta)) 197 | delta_u = torch.rand_like(grad_delta) * 8 - 4 198 | delta_u_t = torch.rand_like(grad_delta) * 8 - 4 199 | glv.network_config['backend'] = 'python' 200 | grad_in_py, grad_w_py = neuron_backward(grad_delta, outputs, delta_u, delta_u_t) 201 | glv.network_config['backend'] = 'cuda' 202 | grad_in_cuda, grad_w_cuda = neuron_backward(grad_delta, outputs, delta_u, delta_u_t) 203 | print(torch.sum(grad_in_py), torch.sum(grad_in_cuda)) 204 | assert (torch.sum(torch.abs(grad_in_py - grad_in_cuda)) <= 1e-3) 205 | assert (torch.sum(torch.abs(grad_w_py - grad_w_cuda)) <= 1e-3) 206 | -------------------------------------------------------------------------------- 
/layers/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import global_v as glv 4 | from layers.functions import neuron_forward, neuron_backward, bn_forward, bn_backward, initialize 5 | from torch.cuda.amp import custom_fwd, custom_bwd 6 | 7 | 8 | class LinearLayer(nn.Linear): 9 | def __init__(self, network_config, config, name): 10 | # extract information for kernel and inChannels 11 | in_features = config['n_inputs'] 12 | out_features = config['n_outputs'] 13 | self.threshold = config['threshold'] if 'threshold' in config else None 14 | self.name = name 15 | self.type = config['type'] 16 | # self.in_shape = in_shape 17 | # self.out_shape = [out_features, 1, 1] 18 | 19 | if type(in_features) == int: 20 | n_inputs = in_features 21 | else: 22 | raise Exception('inFeatures should not be more than 1 dimension. It was: {}'.format(in_features.shape)) 23 | if type(out_features) == int: 24 | n_outputs = out_features 25 | else: 26 | raise Exception('outFeatures should not be more than 1 dimension. It was: {}'.format(out_features.shape)) 27 | 28 | super(LinearLayer, self).__init__(n_inputs, n_outputs, bias=False) 29 | self.weight = torch.nn.Parameter(self.weight.cuda(), requires_grad=True) 30 | self.norm_weight = torch.nn.Parameter(torch.ones(out_features,1, device='cuda')) 31 | self.norm_bias = torch.nn.Parameter(torch.zeros(out_features,1, device='cuda')) 32 | 33 | print("linear") 34 | print(self.name) 35 | # print(self.in_shape) 36 | # print(self.out_shape) 37 | print(f'Shape of weight is {list(self.weight.shape)}') 38 | print("-----------------------------------------") 39 | 40 | def forward(self, x, labels=None): 41 | if glv.init_flag: 42 | glv.init_flag = False 43 | x = initialize(self, x) 44 | glv.init_flag = True 45 | return x 46 | 47 | # self.weight_clipper() 48 | ndim = len(x.shape) 49 | assert(ndim == 3 or ndim == 5) 50 | if ndim == 5: 51 | T, n_batch, C, H, W = x.shape 52 | x = x.view(T, n_batch, C * H * W) 53 | config_n = glv.network_config 54 | theta_m = 1 / config_n['tau_m'] 55 | theta_s = 1 / config_n['tau_s'] 56 | theta_grad = 1 / config_n['tau_grad'] if config_n['gradient_type'] == 'exponential' else -123456789 #instead of None 57 | y = LinearFunc.apply(x, self.weight, self.norm_weight, self.norm_bias, (theta_m, theta_s, theta_grad, self.threshold), labels) 58 | return y 59 | 60 | def weight_clipper(self): 61 | w = self.weight.data 62 | w = w.clamp(-4, 4) 63 | self.weight.data = w 64 | 65 | 66 | class LinearFunc(torch.autograd.Function): 67 | @staticmethod 68 | @custom_fwd 69 | def forward(ctx, inputs, weight, norm_weight, norm_bias, config, labels): 70 | #input.shape: T * n_batch * N_in 71 | inputs, mean, var, weight_ = bn_forward(inputs, weight, norm_weight, norm_bias) 72 | 73 | in_I = torch.matmul(inputs, weight_.t()) 74 | 75 | T, n_batch, N = in_I.shape 76 | theta_m, theta_s, theta_grad, threshold = torch.tensor(config) 77 | assert (theta_m != theta_s) 78 | delta_u, delta_u_t, outputs = neuron_forward(in_I, config) 79 | 80 | if labels is not None: 81 | glv.outputs_raw = outputs.clone() 82 | i2 = torch.arange(n_batch) 83 | # Add supervisory signal when synaptic potential is increasing: 84 | is_inc = (delta_u[:, i2, labels] > 0.05).float() 85 | _, i1 = torch.max(is_inc * torch.arange(1, T+1, device=is_inc.device).unsqueeze(-1), dim=0) 86 | outputs[i1, i2, labels] = (delta_u[i1, i2, labels] != 0).to(outputs) 87 | 88 | # i1 = (torch.ones(n_batch) * -1).long() 89 | # delta_u[i1, i2, labels] =
torch.maximum(delta_u[i1, i2, labels], theta_s.to(outputs)) 90 | # delta_u_t[i1, i2, labels] = torch.maximum(delta_u_t[i1, i2, labels], theta_s.to(outputs)) 91 | 92 | ctx.save_for_backward(delta_u, delta_u_t, inputs, outputs, weight, norm_weight, norm_bias, mean, var) 93 | ctx.is_out_layer = labels != None 94 | 95 | return outputs 96 | 97 | @staticmethod 98 | @custom_bwd 99 | def backward(ctx, grad_delta): 100 | # shape of grad_delta: T * n_batch * N_out 101 | (delta_u, delta_u_t, inputs, outputs, weight, norm_weight, norm_bias, mean, var) = ctx.saved_tensors 102 | grad_delta *= outputs 103 | # sum_next = grad_delta.sum().item() 104 | # print("Max of dLdt: ", abs(grad_delta).max().item()) 105 | 106 | grad_in_, grad_w_ = neuron_backward(grad_delta, outputs, delta_u, delta_u_t) 107 | weight_ = (weight - mean) / torch.sqrt(var + 1e-5) * norm_weight + norm_bias 108 | 109 | grad_input = torch.matmul(grad_in_, weight_) * inputs 110 | grad_weight = torch.sum(torch.matmul(grad_w_.transpose(1,2), inputs), dim=0) 111 | 112 | grad_weight, grad_bn_w, grad_bn_b = bn_backward(grad_weight, weight, norm_weight, norm_bias, mean, var) 113 | 114 | # sum_last = grad_input.sum().item() 115 | # assert(ctx.is_out_layer or abs(sum_next - sum_last) < 1) 116 | return grad_input * 0.85, grad_weight, grad_bn_w, grad_bn_b, None, None, None 117 | -------------------------------------------------------------------------------- /layers/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as f 3 | import global_v as glv 4 | from torch.cuda.amp import custom_fwd, custom_bwd 5 | from math import sqrt 6 | 7 | 8 | def psp(inputs): 9 | n_steps = glv.network_config['n_steps'] 10 | tau_s = glv.network_config['tau_s'] 11 | syns = torch.zeros_like(inputs).to(glv.rank) 12 | syn = torch.zeros(syns.shape[1:]).to(glv.rank) 13 | 14 | for t in range(n_steps): 15 | syn = syn * (1 - 1 / tau_s) + inputs[t, ...] 16 | syns[t, ...] = syn / tau_s 17 | return syns 18 | 19 | 20 | class SpikeLoss(torch.nn.Module): 21 | """ 22 | This class defines different spike based loss modules that can be used to optimize the SNN. 
23 | """ 24 | 25 | def __init__(self): 26 | super(SpikeLoss, self).__init__() 27 | self.criterion = torch.nn.CrossEntropyLoss() 28 | 29 | def spike_count(self, output, target): 30 | delta = loss_count.apply(output, target) 31 | return 1 / 2 * torch.sum(delta ** 2) 32 | 33 | def spike_kernel(self, output, target): 34 | out = grad_sign.apply(output) 35 | delta = psp(out - target) 36 | return 1 / 2 * torch.sum(delta ** 2) 37 | 38 | def spike_TET(self, output, target): 39 | output = output.permute(1, 2, 0) 40 | out = grad_sign.apply(output) 41 | return f.cross_entropy(out, target.unsqueeze(-1).repeat(1, out.shape[-1])) 42 | 43 | 44 | class loss_count(torch.autograd.Function): 45 | @staticmethod 46 | @custom_fwd 47 | def forward(ctx, output, target): 48 | desired_count = glv.network_config['desired_count'] 49 | undesired_count = glv.network_config['undesired_count'] 50 | T = output.shape[0] 51 | out_count = torch.sum(output, dim=0) 52 | 53 | delta = (out_count - target) / T 54 | delta[(target == desired_count) & (delta > 0) | (target == undesired_count) & (delta < 0)] = 0 55 | delta = delta.unsqueeze_(0).repeat(T, 1, 1) 56 | return delta 57 | 58 | @staticmethod 59 | @custom_bwd 60 | def backward(ctx, grad): 61 | sign = -1 if glv.network_config['loss_reverse'] else 1 62 | return sign * grad, None 63 | 64 | 65 | class grad_sign(torch.autograd.Function): # a and u is the increment of each time steps 66 | @staticmethod 67 | @custom_fwd 68 | def forward(ctx, outputs): 69 | return outputs 70 | 71 | @staticmethod 72 | @custom_bwd 73 | def backward(ctx, grad): 74 | sign = -1 if glv.network_config['loss_reverse'] else 1 75 | return sign * grad 76 | -------------------------------------------------------------------------------- /layers/neuron_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | // CUDA forward declarations 6 | 7 | std::vector neuron_forward_cuda( 8 | const torch::Tensor &in_I, 9 | const float theta_m, 10 | const float theta_s, 11 | const float theta_grad, 12 | const float threshold, 13 | const float is_forward_leaky, 14 | const float is_grad_exp); 15 | 16 | std::vector neuron_backward_cuda( 17 | const torch::Tensor &grad_delta, 18 | const torch::Tensor &outputs, 19 | const torch::Tensor &delta_u, 20 | const torch::Tensor &delta_u_t, 21 | const torch::Tensor &syn_a, 22 | const torch::Tensor &partial_a, 23 | const float max_dudt_inv); 24 | 25 | // C++ interface 26 | 27 | // NOTE: AT_ASSERT has become AT_CHECK on master after 0.4. 
28 | #define CHECK_CUDA(x) AT_ASSERTM((x).type().is_cuda(), #x " must be a CUDA tensor") 29 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM((x).is_contiguous(), #x " must be contiguous") 30 | #define CHECK_INPUT(x) \ 31 | CHECK_CUDA(x); \ 32 | CHECK_CONTIGUOUS(x) 33 | 34 | std::vector neuron_forward( 35 | const torch::Tensor &in_I, 36 | const float theta_m, 37 | const float theta_s, 38 | const float theta_grad, 39 | const float threshold, 40 | const float is_forward_leaky, 41 | const float is_grad_exp) { 42 | 43 | CHECK_INPUT(in_I); 44 | return neuron_forward_cuda(in_I, theta_m, theta_s, theta_grad, threshold, is_forward_leaky, is_grad_exp); 45 | } 46 | 47 | std::vector neuron_backward( 48 | const torch::Tensor &grad_delta, 49 | const torch::Tensor &outputs, 50 | const torch::Tensor &delta_u, 51 | const torch::Tensor &delta_u_t, 52 | const torch::Tensor &syn_a, 53 | const torch::Tensor &partial_a, 54 | const float max_dudt_inv) { 55 | 56 | CHECK_INPUT(grad_delta); 57 | CHECK_INPUT(outputs); 58 | CHECK_INPUT(delta_u); 59 | CHECK_INPUT(delta_u_t); 60 | CHECK_INPUT(syn_a); 61 | CHECK_INPUT(partial_a); 62 | return neuron_backward_cuda(grad_delta, outputs, delta_u, delta_u_t, syn_a, partial_a, max_dudt_inv); 63 | } 64 | 65 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 66 | m.def("forward", &neuron_forward, "Neuron forward (CUDA)"); 67 | m.def("backward", &neuron_backward, "Neuron backward (CUDA)"); 68 | } 69 | -------------------------------------------------------------------------------- /layers/neuron_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | #include 7 | #include 8 | #include 9 | 10 | namespace { 11 | 12 | template 13 | __global__ void neuron_forward_cuda_kernel( 14 | const scalar_t *__restrict__ in_I, 15 | scalar_t *__restrict__ delta_u, 16 | scalar_t *__restrict__ delta_u_t, 17 | scalar_t *__restrict__ outputs, 18 | const float theta_m, const float theta_s, const float theta_grad, const float threshold, 19 | const float is_forward_leaky, const float is_grad_exp, 20 | size_t neuron_num, size_t tot_size) { 21 | 22 | size_t index = blockIdx.x * blockDim.x + threadIdx.x; 23 | float syn_m = 0, syn_s = 0, syn_grad = 0, u_last = 0, u = 0, out = 0; 24 | if (index < neuron_num) { 25 | for (; index < tot_size; index += neuron_num) { 26 | syn_m = (syn_m + in_I[index]) * (1 - theta_m); 27 | syn_s = (syn_s + in_I[index]) * (1 - theta_s); 28 | syn_grad = (syn_grad + in_I[index]) * (1 - theta_grad); 29 | 30 | if (!is_forward_leaky) { 31 | delta_u_t[index] = syn_grad; 32 | u = u_last + delta_u_t[index]; 33 | delta_u[index] = delta_u_t[index]; 34 | } else { 35 | u = (syn_m - syn_s) * theta_s / (theta_s - theta_m); 36 | delta_u[index] = u - u_last; 37 | delta_u_t[index] = is_grad_exp ? syn_grad : delta_u[index]; 38 | } 39 | 40 | out = u >= threshold; 41 | u_last = out ? 0 : u; 42 | syn_m = out ? 0 : syn_m; 43 | syn_s = out ? 0 : syn_s; 44 | syn_grad = out ? 
0 : syn_grad; 45 | 46 | outputs[index] = out; 47 | } 48 | } 49 | } 50 | 51 | template 52 | __global__ void neuron_backward_cuda_kernel( 53 | const scalar_t *__restrict__ grad_delta, 54 | const scalar_t *__restrict__ outputs, 55 | const scalar_t *__restrict__ delta_u, 56 | const scalar_t *__restrict__ delta_u_t, 57 | const scalar_t *__restrict__ syn_a, 58 | const scalar_t *__restrict__ partial_a, 59 | scalar_t *__restrict__ grad_in_, 60 | scalar_t *__restrict__ grad_w_, 61 | const float max_dudt_inv, 62 | size_t neuron_num, size_t tot_size) { 63 | 64 | long long index = blockIdx.x * blockDim.x + threadIdx.x; 65 | float partial_u = 0, partial_u_t = 0, partial_u_grad_w = 0, partial_u_grad_t = 0; 66 | int delta_t = 0; 67 | bool spiked = false, out = false; 68 | if (index < neuron_num) { 69 | for (index = tot_size - neuron_num + index; index >= 0; index -= neuron_num) { 70 | out = outputs[index] > 0; 71 | spiked |= out; 72 | 73 | partial_u = min(max(-1.0f / delta_u[index], -4.0f), 0.0f); 74 | partial_u_t = min(max(-1.0f / delta_u_t[index], -max_dudt_inv), 0.0f); 75 | partial_u_grad_w = out ? grad_delta[index] * partial_u : partial_u_grad_w; 76 | partial_u_grad_t = out ? grad_delta[index] * partial_u_t : partial_u_grad_t; 77 | 78 | delta_t = out ? 0 : delta_t + 1; 79 | grad_in_[index] = spiked ? partial_u_grad_t * partial_a[delta_t] : 0; 80 | grad_w_[index] = spiked ? partial_u_grad_w * syn_a[delta_t] : 0; 81 | } 82 | } 83 | } 84 | 85 | } // namespace 86 | 87 | std::vector neuron_forward_cuda( 88 | const torch::Tensor &in_I, 89 | const float theta_m, 90 | const float theta_s, 91 | const float theta_grad, 92 | const float threshold, 93 | const float is_forward_leaky, 94 | const float is_grad_exp) { 95 | 96 | auto delta_u = torch::zeros_like(in_I); 97 | auto delta_u_t = torch::zeros_like(in_I); 98 | auto outputs = torch::zeros_like(in_I); 99 | 100 | const auto tot_size = in_I.numel(), neuron_num = tot_size / in_I.size(0); 101 | 102 | const int threads = 1024; 103 | const auto blocks = (neuron_num + threads - 1) / threads; 104 | 105 | AT_DISPATCH_FLOATING_TYPES(in_I.type(), "neuron_forward_cuda_kernel", ([&] { 106 | neuron_forward_cuda_kernel<<>>( 107 | in_I.data(), 108 | delta_u.data(), 109 | delta_u_t.data(), 110 | outputs.data(), 111 | theta_m, theta_s, theta_grad, threshold, is_forward_leaky, is_grad_exp, 112 | neuron_num, tot_size); 113 | })); 114 | 115 | return {delta_u, delta_u_t, outputs}; 116 | } 117 | 118 | 119 | std::vector neuron_backward_cuda( 120 | const torch::Tensor &grad_delta, 121 | const torch::Tensor &outputs, 122 | const torch::Tensor &delta_u, 123 | const torch::Tensor &delta_u_t, 124 | const torch::Tensor &syn_a, 125 | const torch::Tensor &partial_a, 126 | const float max_dudt_inv) { 127 | 128 | auto grad_in_ = torch::zeros_like(outputs); 129 | auto grad_w_ = torch::zeros_like(outputs); 130 | 131 | const auto tot_size = outputs.numel(), neuron_num = tot_size / outputs.size(0); 132 | 133 | const int threads = 1024; 134 | const auto blocks = (neuron_num + threads - 1) / threads; 135 | 136 | AT_DISPATCH_FLOATING_TYPES(grad_delta.type(), "neuron_backward_cuda_kernel", ([&] { 137 | neuron_backward_cuda_kernel<<>>( 138 | grad_delta.data(), 139 | outputs.data(), 140 | delta_u.data(), 141 | delta_u_t.data(), 142 | syn_a.data(), 143 | partial_a.data(), 144 | grad_in_.data(), 145 | grad_w_.data(), 146 | max_dudt_inv, 147 | neuron_num, tot_size); 148 | })); 149 | 150 | return {grad_in_, grad_w_}; 151 | } 152 | 
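A note on the kernels above: parallelism is over neurons, not time. Each CUDA thread owns one neuron of the flattened (T, neuron_num) layout and steps through time at stride neuron_num, so the recurrent state (syn_m, syn_s, syn_grad, u_last) stays in registers for the whole sequence. For reference, a scalar Python transcription of the leaky forward recurrence that both backends implement (the decay and threshold constants here are illustrative assumptions, not from a shipped config):

def leaky_neuron_scalar(in_I, theta_m=0.25, theta_s=1.0 / 7, threshold=1.0):
    # in_I: input current of a single neuron over T time steps (list of floats).
    syn_m = syn_s = u_last = 0.0
    spikes, delta_u = [], []
    for I in in_I:
        syn_m = (syn_m + I) * (1 - theta_m)
        syn_s = (syn_s + I) * (1 - theta_s)
        u = (syn_m - syn_s) * theta_s / (theta_s - theta_m)  # membrane potential
        delta_u.append(u - u_last)  # per-step du consumed by the backward pass
        out = 1.0 if u >= threshold else 0.0
        # Hard reset: clear all membrane state whenever the neuron fires.
        u_last = u * (1 - out)
        syn_m, syn_s = syn_m * (1 - out), syn_s * (1 - out)
        spikes.append(out)
    return spikes, delta_u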
-------------------------------------------------------------------------------- /layers/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as f 4 | import global_v as glv 5 | from torch.cuda.amp import custom_fwd, custom_bwd 6 | from layers.functions import readConfig 7 | 8 | 9 | class PoolLayer(nn.Module): 10 | def __init__(self, network_config, config, name): 11 | super(PoolLayer, self).__init__() 12 | self.name = name 13 | self.layer_config = config 14 | self.network_config = network_config 15 | self.type = config['type'] 16 | kernel_size = config['kernel_size'] 17 | 18 | self.kernel = readConfig(kernel_size, 'kernelSize') 19 | # self.in_shape = in_shape 20 | # self.out_shape = [in_shape[0], int(in_shape[1] / kernel[0]), int(in_shape[2] / kernel[1])] 21 | print('pooling') 22 | # print(self.in_shape) 23 | # print(self.out_shape) 24 | print("-----------------------------------------") 25 | 26 | def forward(self, x): 27 | pool_type = glv.network_config['pooling_type'] 28 | assert(pool_type in ['avg', 'max', 'adjusted_avg']) 29 | T, n_batch, C, H, W = x.shape 30 | x = x.reshape(T * n_batch, C, H, W) 31 | if pool_type == 'avg': 32 | x = f.avg_pool2d(x, self.kernel) 33 | elif pool_type == 'max': 34 | x = f.max_pool2d(x, self.kernel) 35 | elif pool_type == 'adjusted_avg': 36 | x = PoolFunc.apply(x, self.kernel) 37 | x = x.reshape(T, n_batch, *x.shape[1:]) 38 | return x 39 | 40 | def get_parameters(self): 41 | return 42 | 43 | def forward_pass(self, x, epoch): 44 | y1 = self.forward(x) 45 | return y1 46 | 47 | def weight_clipper(self): 48 | return 49 | 50 | class PoolFunc(torch.autograd.Function): 51 | @staticmethod 52 | @custom_fwd 53 | def forward(ctx, inputs, kernel): 54 | outputs = f.avg_pool2d(inputs, kernel) 55 | ctx.save_for_backward(outputs, torch.tensor(inputs.shape), torch.tensor(kernel)) 56 | return outputs 57 | 58 | @staticmethod 59 | @custom_bwd 60 | def backward(ctx, grad_delta): 61 | (outputs, input_shape, kernel) = ctx.saved_tensors 62 | kernel = kernel.tolist() 63 | outputs = 1 / outputs 64 | outputs[outputs > kernel[0] * kernel[1] + 1] = 0 65 | outputs /= kernel[0] * kernel[1] 66 | grad = f.interpolate(grad_delta * outputs, size=input_shape.tolist()[2:]) 67 | return grad, None -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | from torch.cuda.amp import autocast, GradScaler 8 | import torch.distributed as dist 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch.optim.lr_scheduler import CosineAnnealingLR 12 | from network_parser import parse 13 | from datasets import loadMNIST, loadCIFAR10, loadCIFAR100, loadFashionMNIST, loadSpiking, loadSMNIST 14 | from datasets.utils import TTFS 15 | import cnns 16 | from utils import learningStats 17 | import layers.losses as losses 18 | import numpy as np 19 | from datetime import datetime 20 | from torch.nn.utils import clip_grad_norm_ 21 | import global_v as glv 22 | 23 | from sklearn.metrics import confusion_matrix 24 | import argparse 25 | 26 | log_interval = 100 27 | multigpu = False 28 | 29 | 30 | def get_loss(network_config, err, outputs, labels): 31 | if network_config['loss'] in ['kernel', 'timing']: 32 
| targets = torch.zeros_like(outputs) 33 | device = torch.device(glv.rank) 34 | if T >= 8: 35 | desired_spikes = torch.tensor([0, 1], device=device).repeat(T // 2) 36 | if T % 2 == 1: 37 | desired_spikes = torch.cat([torch.zeros(1, device=device), desired_spikes]) 38 | else: 39 | desired_spikes = torch.ones(T, device=device) 40 | desired_spikes[0] = 0 41 | for i in range(len(labels)): 42 | targets[..., i, labels[i]] = desired_spikes 43 | 44 | if network_config['loss'] == "count": 45 | # set target signal 46 | desired_count = network_config['desired_count'] 47 | undesired_count = network_config['undesired_count'] 48 | targets = torch.ones_like(outputs[0]) * undesired_count 49 | for i in range(len(labels)): 50 | targets[i, labels[i]] = desired_count 51 | loss = err.spike_count(outputs, targets) 52 | elif network_config['loss'] == "kernel": 53 | loss = err.spike_kernel(outputs, targets) 54 | elif network_config['loss'] == "TET": 55 | # set target signal 56 | loss = err.spike_TET(outputs, labels) 57 | else: 58 | raise Exception('Unrecognized loss function.') 59 | 60 | return loss.to(glv.rank) 61 | 62 | 63 | def readout(output, T): 64 | output *= 1.1 - torch.arange(T, device=torch.device(glv.rank)).reshape(T, 1, 1) / T / 10 65 | return torch.sum(output, dim=0).detach() 66 | 67 | 68 | def preprocess(inputs, network_config): 69 | inputs = inputs.to(glv.rank) 70 | if network_config['encoding'] == 'TTFS': 71 | inputs = torch.stack([TTFS(data, T) for data in inputs], dim=0) 72 | if len(inputs.shape) < 5: 73 | inputs = inputs.unsqueeze_(0).repeat(T, 1, 1, 1, 1) 74 | else: 75 | inputs = inputs.permute(1, 0, 2, 3, 4) 76 | return inputs 77 | 78 | 79 | def train(network, trainloader, opti, epoch, states, err): 80 | train_loss, correct, total = 0, 0, 0 81 | cnt_oneof, cnt_unique = 0, 0 82 | network_config = glv.network_config 83 | batch_size = network_config['batch_size'] 84 | scaler = GradScaler() 85 | start_time = datetime.now() 86 | 87 | forward_time, backward_time, data_time, other_time, glv.time_use = 0, 0, 0, 0, 0 88 | t0 = datetime.now() 89 | num_batch = len(trainloader) 90 | batch_idx = 0 91 | for inputs, labels in trainloader: 92 | torch.cuda.synchronize() 93 | data_time += (datetime.now() - t0).total_seconds() 94 | t0 = datetime.now() 95 | batch_idx += 1 96 | 97 | labels, inputs = (x.to(glv.rank) for x in (labels, inputs)) 98 | inputs = preprocess(inputs, network_config) 99 | # forward pass 100 | if network_config['amp']: 101 | with autocast(): 102 | outputs = network(inputs, labels, epoch, True) 103 | loss = get_loss(network_config, err, outputs, labels) 104 | else: 105 | outputs = network(inputs, labels, epoch, True) 106 | loss = get_loss(network_config, err, outputs, labels) 107 | assert (len(outputs.shape) == 3) 108 | 109 | torch.cuda.synchronize() 110 | forward_time += (datetime.now() - t0).total_seconds() 111 | t0 = datetime.now() 112 | # backward pass 113 | opti.zero_grad() 114 | if network_config['amp']: 115 | scaler.scale(loss).backward() 116 | scaler.unscale_(opti) 117 | clip_grad_norm_(network.parameters(), 1) 118 | scaler.step(opti) 119 | scaler.update() 120 | else: 121 | loss.backward() 122 | clip_grad_norm_(network.parameters(), 1) 123 | opti.step() 124 | # (network.module if multigpu else network).weight_clipper() 125 | torch.cuda.synchronize() 126 | backward_time += (datetime.now() - t0).total_seconds() 127 | t0 = datetime.now() 128 | 129 | spike_counts = readout(glv.outputs_raw, T) 130 | predicted = torch.argmax(spike_counts, axis=1) 131 | train_loss += torch.sum(loss).item() 
132 | total += len(labels) 133 | correct += (predicted == labels).sum().item() 134 | 135 | states.training.correctSamples = correct 136 | states.training.numSamples = total 137 | states.training.lossSum += loss.to('cpu').data.item() 138 | 139 | labels = labels.reshape(-1) 140 | idx = torch.arange(labels.shape[0], device=torch.device(glv.rank)) 141 | nspike_label = spike_counts[idx, labels] 142 | cnt_oneof += torch.sum(nspike_label == torch.max(spike_counts, axis=1).values).item() 143 | spike_counts[idx, labels] -= 1 144 | cnt_unique += torch.sum(nspike_label > torch.max(spike_counts, axis=1).values).item() 145 | spike_counts[idx, labels] += 1 146 | 147 | if (not multigpu or dist.get_rank() == 0) and (batch_idx % log_interval == 0 or batch_idx == num_batch): 148 | # if batch_idx % log_interval == 0: 149 | states.print(epoch, batch_idx, (datetime.now() - start_time).total_seconds()) 150 | print('Time consumed on loading data = %.2f, forward = %.2f, backward = %.2f, other = %.2f' 151 | % (data_time, forward_time, backward_time, other_time)) 152 | data_time, forward_time, backward_time, other_time, glv.time_use = 0, 0, 0, 0, 0 153 | 154 | avg_oneof, avg_unique = cnt_oneof / (batch_size * batch_idx), cnt_unique / (batch_size * batch_idx) 155 | print( 156 | 'Percentage of partially right = %.2f%%, entirely right = %.2f%%' % (avg_oneof * 100, avg_unique * 100)) 157 | print() 158 | torch.cuda.synchronize() 159 | other_time += (datetime.now() - t0).total_seconds() 160 | t0 = datetime.now() 161 | 162 | acc = correct / total 163 | train_loss = train_loss / total 164 | 165 | return acc, train_loss 166 | 167 | 168 | def test(network, testloader, epoch, states, log_dir): 169 | global best_acc 170 | correct = 0 171 | total = 0 172 | network_config = glv.network_config 173 | T = network_config['n_steps'] 174 | n_class = network_config['n_class'] 175 | time = datetime.now() 176 | y_pred = [] 177 | y_true = [] 178 | num_batch = len(testloader) 179 | batch_idx = 0 180 | for inputs, labels in testloader: 181 | batch_idx += 1 182 | inputs = preprocess(inputs, network_config) 183 | # forward pass 184 | labels = labels.to(glv.rank) 185 | inputs = inputs.to(glv.rank) 186 | with torch.no_grad(): 187 | outputs = network(inputs, None, epoch, False) 188 | 189 | spike_counts = readout(outputs, T).cpu().numpy() 190 | predicted = np.argmax(spike_counts, axis=1) 191 | labels = labels.cpu().numpy() 192 | y_pred.append(predicted) 193 | y_true.append(labels) 194 | total += len(labels) 195 | correct += (predicted == labels).sum().item() 196 | 197 | states.testing.correctSamples += (predicted == labels).sum().item() 198 | states.testing.numSamples = total 199 | if batch_idx % log_interval == 0 or batch_idx == num_batch: 200 | states.print(epoch, batch_idx, (datetime.now() - time).total_seconds()) 201 | print() 202 | 203 | y_pred = np.concatenate(y_pred) 204 | y_true = np.concatenate(y_true) 205 | 206 | nums = np.bincount(y_true) 207 | confusion = confusion_matrix(y_true, y_pred, labels=np.arange(n_class)) / nums.reshape(-1, 1) 208 | 209 | test_acc = correct / total 210 | 211 | state = { 212 | 'net': (network.module if multigpu else network).state_dict(), 213 | 'epoch': epoch, 214 | } 215 | torch.save(state, os.path.join(log_dir, 'last.pth')) 216 | 217 | if test_acc > best_acc: 218 | best_acc = test_acc 219 | torch.save(state, os.path.join(log_dir, 'best.pth')) 220 | return test_acc, confusion 221 | 222 | 223 | if __name__ == '__main__': 224 | parser = argparse.ArgumentParser() 225 | parser.add_argument('-config', 
action='store', dest='config', help='The path of config file') 226 | parser.add_argument('-checkpoint', action='store', dest='checkpoint', 227 | help='The path of checkpoint, if use checkpoint') 228 | parser.add_argument('-seed', type=int, default=3, help='random seed (default: 3)') 229 | parser.add_argument('-dist', type=str, default="nccl", help='distributed data parallel backend') 230 | parser.add_argument('--local_rank', type=int, default=-1) 231 | try: 232 | args = parser.parse_args() 233 | except: 234 | parser.print_help() 235 | exit(0) 236 | 237 | if args.config is None: 238 | raise Exception('Unrecognized config file.') 239 | else: 240 | config_path = args.config 241 | 242 | params = parse(config_path) 243 | 244 | # check GPU 245 | if not torch.cuda.is_available(): 246 | print('No GPU device available') 247 | sys.exit(1) 248 | # set GPU 249 | if args.local_rank >= 0: 250 | torch.cuda.set_device(args.local_rank) 251 | torch.distributed.init_process_group(backend=args.dist) 252 | glv.rank = args.local_rank 253 | multigpu = True 254 | else: 255 | glv.rank = 0 256 | cudnn.benchmark = True 257 | cudnn.enabled = True 258 | torch.manual_seed(args.seed) 259 | torch.cuda.manual_seed_all(args.seed) 260 | np.random.seed(args.seed) 261 | 262 | glv.init(params['Network'], params['Layers'] if 'Layers' in params.parameters else None) 263 | 264 | data_path = os.path.expanduser(params['Network']['data_path']) 265 | dataset_func = {"MNIST": loadMNIST.get_mnist, 266 | "NMNIST": loadSpiking.get_nmnist, 267 | "FashionMNIST": loadFashionMNIST.get_fashionmnist, 268 | "CIFAR10": loadCIFAR10.get_cifar10, 269 | "CIFAR100": loadCIFAR100.get_cifar100, 270 | "DVS128Gesture": loadSpiking.get_dvs128_gesture, 271 | "CIFAR10DVS": loadSpiking.get_cifar10_dvs, 272 | "SMNIST": loadSMNIST.get_smnist} 273 | try: 274 | trainset, testset = dataset_func[params['Network']['dataset']](data_path, params['Network']) 275 | except: 276 | raise Exception('Unrecognized dataset name.') 277 | batch_size = params['Network']['batch_size'] 278 | if multigpu: 279 | train_sampler = torch.utils.data.distributed.DistributedSampler(trainset) 280 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=4, 281 | sampler=train_sampler, pin_memory=True) 282 | else: 283 | train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, 284 | pin_memory=True) 285 | test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, 286 | pin_memory=True) 287 | 288 | if 'model_import' not in glv.network_config: 289 | net = cnns.Network(list(train_loader.dataset[0][0].shape[-3:])).to(glv.rank) 290 | else: 291 | exec(f"from {glv.network_config['model_import']} import Network") 292 | net = Network().to(glv.rank) 293 | print(net) 294 | 295 | T = params['Network']['t_train'] 296 | if args.checkpoint is not None: 297 | checkpoint_path = args.checkpoint 298 | checkpoint = torch.load(checkpoint_path, map_location='cuda:0') 299 | net.load_state_dict(checkpoint['net']) 300 | epoch_start = checkpoint['epoch'] + 1 301 | print('Network loaded.') 302 | print(f'Start training from epoch {epoch_start}.') 303 | else: 304 | inputs = torch.stack([train_loader.dataset[i][0] for i in range(batch_size)], dim=0).to(glv.rank) 305 | inputs = preprocess(inputs, glv.network_config) 306 | print("Start to initialize.") 307 | # initialize weights 308 | net.eval() 309 | glv.init_flag = True 310 | net(inputs, None, None, False) 311 | net.train() 312 | 
glv.init_flag = False 313 | epoch_start = 1 314 | 315 | error = losses.SpikeLoss().to(glv.rank) # the loss is not defined here 316 | if multigpu: 317 | net = DDP(net, device_ids=[glv.rank], output_device=glv.rank) 318 | optim_type, weight_decay, lr = (glv.network_config[x] for x in ('optimizer', 'weight_decay', 'lr')) 319 | assert (optim_type in ['SGD', 'Adam', 'AdamW']) 320 | 321 | # norm_param, weight_param = net.get_parameters() 322 | optim_dict = {'SGD': torch.optim.SGD, 323 | 'Adam': torch.optim.Adam, 324 | 'AdamW': torch.optim.AdamW} 325 | norm_param, param = [], [] 326 | for layer in net.modules(): 327 | if layer.type in ['conv', 'linear']: 328 | norm_param.extend([layer.norm_weight, layer.norm_bias]) 329 | param.append(layer.weight) 330 | optimizer = optim_dict[optim_type]([ 331 | {'params': param}, 332 | {'params': norm_param, 'lr': lr * glv.network_config['norm_grad']} 333 | ], lr=lr, weight_decay=weight_decay) 334 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=glv.network_config['epochs']) 335 | 336 | best_acc = 0 337 | 338 | l_states = learningStats() 339 | 340 | log_dir = f"{params['Network']['log_path']}_{datetime.now().strftime('%Y%m%d-%H%M%S')}/" 341 | writer = SummaryWriter(log_dir) 342 | confu_mats = [] 343 | for path in ['logs']: 344 | if not os.path.isdir(path): 345 | os.mkdir(path) 346 | shutil.copyfile(config_path, os.path.join(log_dir, os.path.split(config_path)[-1])) 347 | 348 | (net.module if multigpu else net).train() 349 | for epoch in range(epoch_start, params['Network']['epochs'] + epoch_start): 350 | if multigpu: 351 | train_loader.sampler.set_epoch(epoch) 352 | l_states.training.reset() 353 | train_acc, loss = train(net, train_loader, optimizer, epoch, l_states, error) 354 | l_states.training.update() 355 | l_states.testing.reset() 356 | test_acc, confu_mat = test(net, test_loader, epoch, l_states, log_dir) 357 | l_states.testing.update() 358 | lr_scheduler.step() 359 | 360 | confu_mats.append(confu_mat) 361 | if glv.rank == 0: 362 | writer.add_scalars('Accuracy', {'train': train_acc, 363 | 'test': test_acc}, epoch) 364 | writer.add_scalars('Loss', {'loss': loss}, epoch) 365 | np.save(log_dir + 'confusion_matrix.npy', np.stack(confu_mats)) 366 | -------------------------------------------------------------------------------- /network_parser.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | 4 | class parse(object): 5 | """ 6 | This class reads yaml parameter file and allows dictionary like access to the members. 
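    Usage (editor's example; networks/mnist.yaml ships with batch_size 50):

    >>> params = parse('networks/mnist.yaml')
    >>> params['Network']['batch_size']
    50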
7 | """ 8 | def __init__(self, path): 9 | with open(path, 'r') as file: 10 | self.parameters = yaml.safe_load(file) 11 | 12 | # Allow dictionary like access 13 | def __getitem__(self, key): 14 | return self.parameters[key] 15 | 16 | def save(self, filename): 17 | with open(filename, 'w') as f: 18 | yaml.dump(self.parameters, f) 19 | -------------------------------------------------------------------------------- /networks/CIFAR10.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | epochs: 500 3 | batch_size: 50 4 | n_steps: 12 5 | dataset: "CIFAR10" 6 | data_path: "../../datasets/cifar10" 7 | # data_path: "../../../../datasets/cifar10" 8 | log_path: "logs/cifar10" 9 | backend: "cuda" 10 | optimizer: "SGD" 11 | lr: 0.05 12 | weight_decay: 0.0005 13 | loss: "count" 14 | norm_grad: 0 15 | gradient_type: "exponential" 16 | pooling_type: "adjusted_avg" 17 | n_class: 10 18 | desired_count: 10 19 | undesired_count: 1 20 | tau_m: 7 21 | tau_s: 4 22 | tau_grad: 3.5 23 | 24 | Layers: 25 | conv_1: 26 | type: "conv" # 32*32 27 | in_channels: 3 28 | out_channels: 128 29 | kernel_size: 3 30 | padding: 1 31 | threshold: 1 32 | 33 | conv_2: 34 | type: "conv" # 32*32 35 | in_channels: 128 36 | out_channels: 128 37 | kernel_size: 3 38 | padding: 1 39 | threshold: 1 40 | 41 | pooling_1: 42 | type: "pooling" 43 | kernel_size: 2 44 | 45 | dropout_1: 46 | type: "dropout" 47 | p: 0.1 48 | 49 | conv_3: 50 | type: "conv" # 16*16 51 | in_channels: 128 52 | out_channels: 256 53 | kernel_size: 3 54 | padding: 1 55 | threshold: 1 56 | 57 | conv_4: 58 | type: "conv" # 16*16 59 | in_channels: 256 60 | out_channels: 256 61 | kernel_size: 3 62 | padding: 1 63 | threshold: 1 64 | 65 | conv_5: 66 | type: "conv" # 16*16 67 | in_channels: 256 68 | out_channels: 256 69 | kernel_size: 3 70 | padding: 1 71 | threshold: 1 72 | 73 | pooling_2: 74 | type: "pooling" 75 | kernel_size: 2 76 | 77 | dropout_2: 78 | type: "dropout" 79 | p: 0.1 80 | 81 | conv_6: 82 | type: "conv" # 8*8 83 | in_channels: 256 84 | out_channels: 512 85 | kernel_size: 3 86 | padding: 1 87 | threshold: 1 88 | 89 | conv_7: 90 | type: "conv" # 8*8 91 | in_channels: 512 92 | out_channels: 512 93 | kernel_size: 3 94 | padding: 1 95 | threshold: 1 96 | 97 | conv_8: 98 | type: "conv" # 8*8 99 | in_channels: 512 100 | out_channels: 512 101 | kernel_size: 3 102 | padding: 1 103 | threshold: 1 104 | 105 | pooling_3: 106 | type: "pooling" 107 | kernel_size: 2 108 | 109 | dropout_3: 110 | type: "dropout" 111 | p: 0.1 112 | 113 | FC_1: 114 | type: "linear" 115 | n_inputs: 8192 116 | n_outputs: 2048 117 | threshold: 1 118 | 119 | dropout_4: 120 | type: "dropout" 121 | p: 0.1 122 | 123 | FC_2: 124 | type: "linear" 125 | n_inputs: 2048 126 | n_outputs: 2048 127 | threshold: 1 128 | 129 | dropout_5: 130 | type: "dropout" 131 | p: 0.1 132 | 133 | output: 134 | type: "linear" 135 | n_inputs: 2048 136 | n_outputs: 10 137 | threshold: 1 138 | 139 | -------------------------------------------------------------------------------- /networks/CIFAR100.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | epochs: 500 3 | batch_size: 50 4 | n_steps: 16 5 | dataset: "CIFAR100" 6 | data_path: "../../datasets/cifar100" 7 | # data_path: "../../../../datasets/cifar100" 8 | log_path: "logs/cifar100" 9 | backend: "cuda" 10 | optimizer: "SGD" 11 | lr: 0.06 12 | weight_decay: 0.0005 13 | avg_spike_init: 1.2 14 | loss: "count" 15 | norm_grad: 0 16 | gradient_type: "exponential" 17 | pooling_type: 
"adjusted_avg" 18 | n_class: 100 19 | desired_count: 15 20 | undesired_count: 1 21 | tau_m: 10 22 | tau_s: 6 23 | tau_grad: 5.5 24 | 25 | Layers: 26 | conv_1: 27 | type: "conv" # 32*32 28 | in_channels: 3 29 | out_channels: 128 30 | kernel_size: 3 31 | padding: 1 32 | threshold: 1 33 | 34 | conv_2: 35 | type: "conv" # 32*32 36 | in_channels: 128 37 | out_channels: 128 38 | kernel_size: 3 39 | padding: 1 40 | threshold: 1 41 | 42 | pooling_1: 43 | type: "pooling" 44 | kernel_size: 2 45 | 46 | conv_3: 47 | type: "conv" # 16*16 48 | in_channels: 128 49 | out_channels: 256 50 | kernel_size: 3 51 | padding: 1 52 | threshold: 1 53 | 54 | conv_4: 55 | type: "conv" # 16*16 56 | in_channels: 256 57 | out_channels: 256 58 | kernel_size: 3 59 | padding: 1 60 | threshold: 1 61 | 62 | conv_5: 63 | type: "conv" # 16*16 64 | in_channels: 256 65 | out_channels: 256 66 | kernel_size: 3 67 | padding: 1 68 | threshold: 1 69 | 70 | pooling_2: 71 | type: "pooling" 72 | kernel_size: 2 73 | 74 | conv_6: 75 | type: "conv" # 8*8 76 | in_channels: 256 77 | out_channels: 512 78 | kernel_size: 3 79 | padding: 1 80 | threshold: 1 81 | 82 | conv_7: 83 | type: "conv" # 8*8 84 | in_channels: 512 85 | out_channels: 512 86 | kernel_size: 3 87 | padding: 1 88 | threshold: 1 89 | 90 | conv_8: 91 | type: "conv" # 8*8 92 | in_channels: 512 93 | out_channels: 512 94 | kernel_size: 3 95 | padding: 1 96 | threshold: 1 97 | 98 | pooling_3: 99 | type: "pooling" 100 | kernel_size: 2 101 | 102 | FC_1: 103 | type: "linear" 104 | n_inputs: 8192 105 | n_outputs: 2048 106 | threshold: 1 107 | 108 | FC_2: 109 | type: "linear" 110 | n_inputs: 2048 111 | n_outputs: 2048 112 | threshold: 1 113 | 114 | output: 115 | type: "linear" 116 | n_inputs: 2048 117 | n_outputs: 100 118 | threshold: 1 119 | 120 | -------------------------------------------------------------------------------- /networks/FashionMNIST.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | epochs: 150 3 | batch_size: 50 4 | n_steps: 5 5 | dataset: "FashionMNIST" 6 | # data_path: "../../datasets/FashionMNIST" 7 | data_path: "../../../../datasets/FashionMNIST" 8 | log_path: "logs/FashionMNIST" 9 | backend: "cuda" 10 | optimizer: "AdamW" 11 | lr: 0.0005 12 | weight_decay: 0.0005 13 | avg_spike_init: 0.5 14 | loss: "count" 15 | gradient_type: "exponential" 16 | pooling_type: "adjusted_avg" 17 | n_class: 10 18 | desired_count: 5 19 | undesired_count: 1 20 | tau_m: 5 21 | tau_s: 3 22 | tau_grad: 2.5 23 | 24 | Layers: 25 | conv_1: 26 | type: "conv" 27 | in_channels: 1 28 | out_channels: 32 29 | kernel_size: 5 30 | padding: 2 31 | threshold: 1 32 | 33 | pooling_1: 34 | type: "pooling" 35 | kernel_size: 2 36 | 37 | conv_2: 38 | type: "conv" 39 | in_channels: 32 40 | out_channels: 64 41 | kernel_size: 5 42 | padding: 2 43 | threshold: 1 44 | 45 | pooling_2: 46 | type: "pooling" 47 | kernel_size: 2 48 | 49 | FC_1: 50 | type: "linear" 51 | n_inputs: 3136 52 | n_outputs: 1024 53 | threshold: 1 54 | 55 | output: 56 | type: "linear" 57 | n_inputs: 1024 58 | n_outputs: 10 59 | threshold: 1 60 | -------------------------------------------------------------------------------- /networks/RESNET.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | epochs: 500 3 | batch_size: 50 4 | n_steps: 12 5 | dataset: "CIFAR10" 6 | data_path: "../../datasets/cifar10" 7 | # data_path: "../../../../datasets/cifar10" 8 | log_path: "logs/cifar10" 9 | backend: "cuda" 10 | optimizer: "AdamW" 11 | lr: 0.0002 12 | 
weight_decay: 0.05 13 | loss: "count" 14 | norm_grad: 0 15 | gradient_type: "exponential" 16 | pooling_type: "adjusted_avg" 17 | n_class: 10 18 | desired_count: 10 19 | undesired_count: 1 20 | tau_m: 7 21 | tau_s: 4 22 | tau_grad: 3.5 23 | 24 | model_import: "networks.resnet" 25 | -------------------------------------------------------------------------------- /networks/TTFS.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | epochs: 500 3 | batch_size: 50 4 | n_steps: 25 5 | dataset: "CIFAR10" 6 | data_path: "../../datasets/cifar10" 7 | # data_path: "../../../../datasets/cifar10" 8 | log_path: "logs/TTFS" 9 | backend: "cuda" 10 | optimizer: "AdamW" 11 | encoding: "TTFS" 12 | lr: 0.0001 13 | weight_decay: 0.0002 14 | avg_spike_init: 0.5 15 | loss: "count" 16 | norm_grad: 0 17 | gradient_type: "exponential" 18 | pooling_type: "adjusted_avg" 19 | n_class: 10 20 | desired_count: 12 21 | undesired_count: 1 22 | tau_m: 7 23 | tau_s: 4 24 | tau_grad: 3.5 25 | 26 | model_import: "networks.resnet" -------------------------------------------------------------------------------- /networks/mnist.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | epochs: 100 3 | batch_size: 50 4 | n_steps: 5 5 | dataset: "MNIST" 6 | data_path: "../../datasets/mnist" 7 | # data_path: "../../../../datasets/mnist" 8 | log_path: "logs/mnist" 9 | backend: "cuda" 10 | optimizer: "AdamW" 11 | lr: 0.0005 12 | weight_decay: 0.0005 13 | avg_spike_init: 0.5 14 | loss: "count" 15 | gradient_type: "exponential" 16 | pooling_type: "adjusted_avg" 17 | n_class: 10 18 | desired_count: 5 19 | undesired_count: 1 20 | tau_m: 5 21 | tau_s: 3 22 | tau_grad: 2.5 23 | 24 | Layers: 25 | conv_1: 26 | type: "conv" 27 | in_channels: 1 28 | out_channels: 15 29 | kernel_size: 5 30 | padding: 0 31 | threshold: 1 32 | 33 | pooling_1: 34 | type: "pooling" 35 | kernel_size: 2 36 | 37 | conv_2: 38 | type: "conv" 39 | in_channels: 15 40 | out_channels: 40 41 | kernel_size: 5 42 | padding: 0 43 | threshold: 1 44 | 45 | pooling_2: 46 | type: "pooling" 47 | kernel_size: 2 48 | 49 | FC_1: 50 | type: "linear" 51 | n_inputs: 640 52 | n_outputs: 300 53 | threshold: 1 54 | 55 | output: 56 | type: "linear" 57 | n_inputs: 300 58 | n_outputs: 10 59 | threshold: 1 -------------------------------------------------------------------------------- /networks/n-mnist.yaml: -------------------------------------------------------------------------------- 1 | Network: 2 | epochs: 120 3 | batch_size: 50 4 | n_steps: 30 5 | dataset: "NMNIST" 6 | data_path: "../../datasets/n_mnist" 7 | # data_path: "../../../../datasets/n_mnist" 8 | log_path: "logs/n-mnist" 9 | backend: "cuda" 10 | optimizer: "AdamW" 11 | lr: 0.0005 12 | weight_decay: 0.2 13 | avg_spike_init: 2 14 | loss: "count" 15 | gradient_type: "exponential" 16 | pooling_type: "adjusted_avg" 17 | n_class: 10 18 | desired_count: 15 19 | undesired_count: 2 20 | tau_m: 8 21 | tau_s: 4 22 | tau_grad: 3 23 | 24 | Layers: 25 | conv_1: 26 | type: "conv" # 32 27 | in_channels: 2 28 | out_channels: 12 29 | kernel_size: 5 30 | padding: 1 31 | threshold: 1 32 | 33 | pooling_1: # 16 34 | type: "pooling" 35 | kernel_size: 2 36 | 37 | conv_2: # 12 38 | type: "conv" 39 | in_channels: 12 40 | out_channels: 64 41 | kernel_size: 5 42 | padding: 0 43 | threshold: 1 44 | 45 | pooling_2: # 6 46 | type: "pooling" 47 | kernel_size: 2 48 | 49 | output: 50 | type: "linear" 51 | n_inputs: 2304 52 | n_outputs: 10 53 | threshold: 1 54 | 55 | 
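(Editor's aside: each config's first fully connected `n_inputs` follows from the conv/pool shape arithmetic above. The sketch below is a sanity check written for this document, not part of the repository.)

```
# editor's sketch: verify the FC n_inputs values implied by the conv/pool stacks
def conv_out(size, kernel, padding, stride=1):
    # standard convolution output-size formula
    return (size + 2 * padding - kernel) // stride + 1

def flat_features(size, stack, channels):
    for kind, *args in stack:
        if kind == 'conv':
            size = conv_out(size, *args)  # args = (kernel, padding)
        else:                             # 'pool', args = (kernel,)
            size //= args[0]
    return channels * size * size

# CIFAR10/CIFAR100: 3x3 pad-1 convs preserve 32x32; three 2x2 pools -> 4x4 at 512 channels
assert flat_features(32, [('pool', 2)] * 3, 512) == 8192
# mnist: 28 -> conv5/pad0 -> 24 -> pool -> 12 -> conv5/pad0 -> 8 -> pool -> 4 at 40 channels
assert flat_features(28, [('conv', 5, 0), ('pool', 2), ('conv', 5, 0), ('pool', 2)], 40) == 640
# n-mnist: 34x34 events -> conv5/pad1 -> 32 -> pool -> 16 -> conv5/pad0 -> 12 -> pool -> 6 at 64 channels
assert flat_features(34, [('conv', 5, 1), ('pool', 2), ('conv', 5, 0), ('pool', 2)], 64) == 2304
```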
-------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import layers.conv as conv 4 | import layers.pooling as pooling 5 | import layers.dropout as dropout 6 | import layers.linear as linear 7 | from torch.cuda.amp import custom_fwd, custom_bwd 8 | import global_v as glv 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 12 | """3x3 convolution with padding""" 13 | config = {'in_channels': in_planes, 'out_channels': out_planes, 'type': 'conv', 14 | 'kernel_size': 3, 'padding': 1, 'stride': stride, 'dilation': dilation, 'threshold': 1} 15 | return conv.ConvLayer(network_config=None, config=config, name=None) 16 | 17 | 18 | def conv1x1(in_planes, out_planes, stride=1): 19 | """1x1 convolution""" 20 | config = {'in_channels': in_planes, 'out_channels': out_planes, 'type': 'conv', 21 | 'kernel_size': 1, 'padding': 0, 'stride': stride, 'threshold': 1} 22 | return conv.ConvLayer(network_config=None, config=config, name=None) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 29 | base_width=64, dilation=1, **kwargs): 30 | super(BasicBlock, self).__init__() 31 | if groups != 1 or base_width != 64: 32 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 33 | if dilation > 1: 34 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 35 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 36 | self.conv1 = conv3x3(inplanes, planes, stride) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | identity = x 43 | 44 | out = self.conv1(x) 45 | out = self.conv2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | # may need custom backward 51 | # out = out + identity 52 | out = AddFunc.apply(out, identity) 53 | 54 | return out 55 | 56 | 57 | class AddFunc(torch.autograd.Function): 58 | @staticmethod 59 | @custom_fwd 60 | def forward(ctx, a, b): 61 | ctx.save_for_backward(a, b) 62 | return a + b 63 | 64 | @staticmethod 65 | @custom_bwd 66 | def backward(ctx, grad): 67 | a, b = ctx.saved_tensors 68 | s = a + b 69 | s[s == 0] = 1 70 | return grad * a / s, grad * b / s 71 | 72 | 73 | class SpikingResNet(nn.Module): 74 | def __init__(self, block, layers, num_classes=10, groups=1, width_per_group=64, norm_layer=None, **kwargs): 75 | super(SpikingResNet, self).__init__() 76 | self._norm_layer = norm_layer 77 | 78 | self.inplanes = 128 79 | self.dilation = 1 80 | 81 | self.groups = groups 82 | self.base_width = width_per_group 83 | 84 | config = {'in_channels': 3, 'out_channels': self.inplanes, 'type': 'conv', 85 | 'kernel_size': 5, 'padding': 2, 'stride': 1, 'dilation': 1, 'threshold': 1} 86 | self.conv1 = conv.ConvLayer(network_config=None, config=config, name=None) 87 | 88 | # self.maxpool = layer.MaxPool2d(kernel_size=3, stride=2, padding=1) 89 | self.layer1 = self._make_layer(block, 128, layers[0], stride=2, **kwargs) 90 | self.layer2 = self._make_layer(block, 256, layers[1], stride=2, **kwargs) 91 | self.layer3 = self._make_layer(block, 512, layers[2], stride=2, **kwargs) 92 | # self.layer3 = self._make_layer(block, 256, layers[2], stride=2, **kwargs) 93 | # self.layer4 = self._make_layer(block, 512, layers[3], 
stride=1, **kwargs) 94 | config = {'type': 'pool', 'kernel_size': 32 // 2 ** 3} 95 | self.pool = pooling.PoolLayer(network_config=None, config=config, name=None) 96 | config = {'type': 'linear', 'n_inputs': 512 * block.expansion, 'n_outputs': num_classes, 'threshold': 1} 97 | self.fc = linear.LinearLayer(network_config=None, config=config, name=None) 98 | 99 | def _make_layer(self, block, planes, blocks, stride=1, **kwargs): 100 | norm_layer = self._norm_layer 101 | downsample = None 102 | previous_dilation = self.dilation 103 | if stride != 1 or self.inplanes != planes * block.expansion: 104 | downsample = conv1x1(self.inplanes, planes * block.expansion, stride) 105 | 106 | layers = [] 107 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 108 | self.base_width, previous_dilation, **kwargs)) 109 | self.inplanes = planes * block.expansion 110 | for _ in range(1, blocks): 111 | layers.append(block(self.inplanes, planes, groups=self.groups, 112 | base_width=self.base_width, dilation=self.dilation, **kwargs)) 113 | 114 | return nn.Sequential(*layers) 115 | 116 | def forward(self, x, labels, epoch, is_train): 117 | assert (is_train or labels is None) 118 | # See note [TorchScript super()] 119 | x = self.conv1(x) 120 | # x = self.maxpool(x) 121 | 122 | x = self.layer1(x) 123 | x = self.layer2(x) 124 | x = self.layer3(x) 125 | # x = self.layer4(x) 126 | 127 | x = self.pool(x) 128 | x = self.fc(x, labels) 129 | 130 | return x 131 | 132 | 133 | class Network(SpikingResNet): 134 | def __init__(self, input_shape=None): 135 | super(Network, self).__init__(BasicBlock, [2, 2, 2, 2], glv.network_config['n_class']) 136 | print("-----------------------------------------") 137 | 138 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # **Training Spiking Neural Networks with Event-driven Backpropagation** 2 | 3 | This repository is the official implementation of *Training Spiking Neural Networks with Event-driven Backpropagation* (**NeurIPS 2022**) \[[pdf](https://hal.science/hal-03889062v1/preview/Zhu%20et%20al.%20-%202022%20-%20Training%20Spiking%20Neural%20Networks%20with%20Event-driven%20Backpropagation.pdf)\]. 4 | 5 | # Requirements 6 | - pytorch=1.10.0 7 | - torchvision=0.11.0 8 | - spikingjelly 9 | 10 | # Training 11 | 12 | ## Before running 13 | 14 | Modify the data path and network settings in the .yaml config files (in the networks folder). 15 | 16 | We recommend running the code in a Linux environment, since the backward stage uses PyTorch CUDA functions and the compilation process is inconvenient on Windows. 17 | 18 | In addition, we have implemented two backends for the neuron functions in our algorithm: a Python backend and a CUDA backend. The CUDA backend significantly accelerates the neuron functions. 19 | 20 | The backend option can be configured by setting **backend: "cuda"** or **backend: "python"** in the .yaml config files.
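For example, to try the pure-Python backend (editor's illustration; every bundled config sets it to "cuda"):

```
Network:
  backend: "python"
```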
21 | 22 | ## Run the code 23 | ``` 24 | $ CUDA_VISIBLE_DEVICES=0 python main.py -config networks/config_file.yaml 25 | ``` -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | # task_train = Progress().add_task("[red]Training...", total=100) 8 | # task_test = Progress().add_task("[green]Testing...", total=100) 9 | # 10 | class learningStat(): 11 | ''' 12 | This class collects the learning statistics over the epoch. 13 | 14 | Usage: 15 | 16 | This class is designed to be used with a learningStats instance, although it can be used separately. 17 | 18 | >>> trainingStat = learningStat() 19 | ''' 20 | def __init__(self): 21 | self.lossSum = 0 22 | self.correctSamples = 0 23 | self.numSamples = 0 24 | self.minloss = None 25 | self.maxAccuracy = None 26 | self.lossLog = [] 27 | self.accuracyLog = [] 28 | self.bestLoss = False 29 | self.bestAccuracy = False 30 | 31 | def reset(self): 32 | ''' 33 | Resets the learning statistics. 34 | This should usually be done before the start of an epoch so that new statistics counts can be accumulated. 35 | 36 | Usage: 37 | 38 | >>> trainingStat.reset() 39 | ''' 40 | self.lossSum = 0 41 | self.correctSamples = 0 42 | self.numSamples = 0 43 | 44 | def loss(self): 45 | ''' 46 | Returns the average loss calculated from the point the stats were reset. 47 | 48 | Usage: 49 | 50 | >>> loss = trainingStat.loss() 51 | ''' 52 | if self.numSamples > 0: 53 | return self.lossSum/self.numSamples 54 | else: 55 | return None 56 | 57 | def accuracy(self): 58 | ''' 59 | Returns the average accuracy calculated from the point the stats were reset. 60 | 61 | Usage: 62 | 63 | >>> accuracy = trainingStat.accuracy() 64 | ''' 65 | if self.numSamples > 0: 66 | return self.correctSamples/self.numSamples 67 | else: 68 | return None 69 | 70 | def update(self): 71 | ''' 72 | Appends the current session's average loss and accuracy to the logs and updates the best-value tracking; the accumulators themselves are cleared separately via reset().
73 | 74 | Usage: 75 | 76 | >>> trainingStat.update() 77 | ''' 78 | currentLoss = self.loss() 79 | self.lossLog.append(currentLoss) 80 | if self.minloss is None: 81 | self.minloss = currentLoss 82 | else: 83 | if currentLoss < self.minloss: 84 | self.minloss = currentLoss 85 | self.bestLoss = True 86 | else: 87 | self.bestLoss = False 88 | # self.minloss = self.minloss if self.minloss < currentLoss else currentLoss 89 | 90 | currentAccuracy = self.accuracy() 91 | self.accuracyLog.append(currentAccuracy) 92 | if self.maxAccuracy is None: 93 | self.maxAccuracy = currentAccuracy 94 | else: 95 | if currentAccuracy > self.maxAccuracy: 96 | self.maxAccuracy = currentAccuracy 97 | self.bestAccuracy = True 98 | else: 99 | self.bestAccuracy = False 100 | # self.maxAccuracy = self.maxAccuracy if self.maxAccuracy > currentAccuracy else currentAccuracy 101 | 102 | def displayString(self): 103 | loss = self.loss() 104 | accuracy = self.accuracy() 105 | minloss = self.minloss 106 | maxAccuracy = self.maxAccuracy 107 | 108 | if loss is None: # no stats available 109 | return 'No testing results' 110 | elif accuracy is None: 111 | if minloss is None: # accuracy and minloss stats are not available 112 | return 'loss = %-11.5g'%(loss) 113 | else: # accuracy is not available but minloss is available 114 | return 'loss = %-11.5g (min = %-11.5g)'%(loss, minloss) 115 | else: 116 | if minloss is None and maxAccuracy is None: # minloss and maxAccuracy are not available 117 | return 'loss = %-11.5g %-11s accuracy = %.2f%% %-8s '%(loss, ' ', accuracy*100, ' ') 118 | else: # all stats are available 119 | return 'loss = %-11.5g (min = %-11.5g) accuracy = %.2f%% (max = %.2f%%)'\ 120 | %(loss, minloss, accuracy*100, maxAccuracy*100) 121 | 122 | 123 | class learningStats(): 124 | ''' 125 | This class provides a mechanism to collect learning stats for training and testing, and to display them efficiently. 126 | 127 | Usage: 128 | 129 | .. code-block:: python 130 | 131 | stats = learningStats() 132 | 133 | for epoch in range(100): 134 | tSt = datetime.now() 135 | 136 | stats.training.reset() 137 | for i in trainingLoop: 138 | # other main stuff 139 | stats.training.correctSamples += numberOfCorrectClassification 140 | stats.training.numSamples += numberOfSamplesProcessed 141 | stats.training.lossSum += currentLoss 142 | stats.print(epoch, i, (datetime.now() - tSt).total_seconds()) 143 | stats.training.update() 144 | 145 | stats.testing.reset() 146 | for i in testingLoop: 147 | # other main stuff 148 | stats.testing.correctSamples += numberOfCorrectClassification 149 | stats.testing.numSamples += numberOfSamplesProcessed 150 | stats.testing.lossSum += currentLoss 151 | stats.print(epoch, i) 152 | stats.testing.update() 153 | 154 | ''' 155 | 156 | def __init__(self): 157 | self.linesPrinted = 0 158 | self.training = learningStat() 159 | self.testing = learningStat() 160 | 161 | def update(self): 162 | ''' 163 | Updates the stats for training and testing and resets the measures for the next session. 164 | 165 | Usage: 166 | 167 | >>> stats.update() 168 | ''' 169 | self.training.update() 170 | self.training.reset() 171 | self.testing.update() 172 | self.testing.reset() 173 | 174 | def print(self, epoch, iter=None, timeElapsed=None, header=None, footer=None): 175 | ''' 176 | Prints the available learning statistics from the current session on the console. 177 | For Linux systems, prints the data in the same terminal space (might not work properly on other systems).
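        (Editor's note: the ANSI cursor-rewind escapes in the body are currently commented out, so successive calls print new blocks rather than overwriting the previous one.)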
178 | 179 | Arguments: 180 | * ``epoch``: epoch counter to display (required). 181 | * ``iter``: iteration counter to display (not required). 182 | * ``timeElapsed``: runtime information (not required). 183 | * ``header``: things to be printed before printing learning statistics. Default: ``None``. 184 | * ``footer``: things to be printed after printing learning statistics. Default: ``None``. 185 | 186 | Usage: 187 | 188 | .. code-block:: python 189 | 190 | # prints stats with epoch index provided 191 | stats.print(epoch) 192 | 193 | # prints stats with epoch index and iteration index provided 194 | stats.print(epoch, iter=i) 195 | 196 | # prints stats with epoch index, iteration index and time elapsed information provided 197 | stats.print(epoch, iter=i, timeElapsed=time) 198 | ''' 199 | #print('\033[%dA' % (self.linesPrinted)) 200 | 201 | self.linesPrinted = 1 202 | 203 | epochStr = 'Epoch : %10d' % (epoch) 204 | iterStr = '' if iter is None else '(i = %7d)' % (iter) 205 | profileStr = '' if timeElapsed is None else ', %12.4f s elapsed' % timeElapsed 206 | 207 | if header is not None: 208 | for h in header: 209 | #print('\033[2K' + str(h)) 210 | print(h) 211 | self.linesPrinted += 1 212 | 213 | print(epochStr + iterStr + profileStr) 214 | print(self.training.displayString()) 215 | print(self.testing.displayString()) 216 | self.linesPrinted += 3 217 | 218 | if footer is not None: 219 | for f in footer: 220 | #print('\033[2K' + str(f)) 221 | print(f) 222 | self.linesPrinted += 1 223 | 224 | def plot(self, figures=(1, 2), saveFig=False, path=''): 225 | ''' 226 | Plots the available learning statistics. 227 | 228 | Arguments: 229 | * ``figures``: Index of figure ID to plot on. Default is figure(1) for loss plot and figure(2) for accuracy plot. 230 | * ``saveFig`` (``bool``): flag to save figure into a file. 231 | * ``path``: path to save the file. Default is ``''``. 232 | 233 | Usage: 234 | 235 | .. code-block:: python 236 | 237 | # plot stats 238 | stats.plot() 239 | 240 | # plot stats with figures specified 241 | stats.plot(figures=(10, 11)) 242 | ''' 243 | plt.figure(figures[0]) 244 | plt.cla() 245 | if len(self.training.lossLog) > 0: 246 | plt.semilogy(self.training.lossLog, label='Training') 247 | if len(self.testing.lossLog) > 0: 248 | plt.semilogy(self.testing.lossLog, label='Testing') 249 | plt.xlabel('Epoch') 250 | plt.ylabel('Loss') 251 | plt.legend() 252 | if saveFig is True: 253 | plt.savefig(path + 'loss.png') 254 | # plt.close() 255 | 256 | plt.figure(figures[1]) 257 | plt.cla() 258 | if len(self.training.accuracyLog) > 0: 259 | plt.plot(self.training.accuracyLog, label='Training') 260 | if len(self.testing.accuracyLog) > 0: 261 | plt.plot(self.testing.accuracyLog, label='Testing') 262 | plt.xlabel('Epoch') 263 | plt.ylabel('Accuracy') 264 | plt.legend() 265 | if saveFig is True: 266 | plt.savefig(path + 'accuracy.png') 267 | # plt.close() 268 | 269 | def save(self, filename=''): 270 | ''' 271 | Saves the learning statistics logs. 272 | 273 | Arguments: 274 | * ``filename``: filename to save the logs. ``accuracy.txt`` and ``loss.txt`` will be appended. 275 | 276 | Usage: 277 | 278 | .. code-block:: python
279 | 280 | # save stats 281 | stats.save() 282 | 283 | # save stats with filename specified 284 | stats.save(filename='Run101-0.001-') # Run101-0.001-accuracy.txt and Run101-0.001-loss.txt 285 | ''' 286 | 287 | with open(filename + 'loss.txt', 'wt') as loss: 288 | loss.write('#%11s %11s\r\n' % ('Train', 'Test')) 289 | for i in range(len(self.training.lossLog)): 290 | loss.write('%12.6g %12.6g \r\n' % (self.training.lossLog[i], self.testing.lossLog[i])) 291 | 292 | with open(filename + 'accuracy.txt', 'wt') as accuracy: 293 | accuracy.write('#%11s %11s\r\n' % ('Train', 'Test')) 294 | if self.training.accuracyLog != [None] * len(self.training.accuracyLog): 295 | for i in range(len(self.training.accuracyLog)): 296 | accuracy.write('%12.6g %12.6g \r\n' % ( 297 | self.training.accuracyLog[i], 298 | self.testing.accuracyLog[i] if self.testing.accuracyLog[i] is not None else 0, 299 | )) 300 | 301 | def load(self, filename='', numEpoch=None, modulo=1): 302 | ''' 303 | Loads the learning statistics logs from saved files. 304 | 305 | Arguments: 306 | * ``filename``: filename prefix of the logs to load. ``accuracy.txt`` and ``loss.txt`` will be appended. 307 | * ``numEpoch``: number of epochs of logs to load. Default: None. ``numEpoch`` will be automatically determined from saved files. 308 | * ``modulo``: the epoch interval at which the model was saved. 309 | 310 | Usage: 311 | 312 | .. code-block:: python 313 | 314 | # load stats 315 | stats.load(numEpoch=10) 316 | 317 | # load stats with filename specified 318 | stats.load(filename='Run101-0.001-', numEpoch=50) # reads Run101-0.001-accuracy.txt and Run101-0.001-loss.txt 319 | ''' 320 | saved = {} 321 | saved['accuracy'] = np.loadtxt(filename + 'accuracy.txt') 322 | saved['loss'] = np.loadtxt(filename + 'loss.txt') 323 | if numEpoch is None: 324 | saved['epoch'] = saved['loss'].shape[0] // modulo * modulo + 1 325 | else: 326 | saved['epoch'] = numEpoch 327 | 328 | self.training.lossLog = saved['loss'][:saved['epoch'], 0].tolist() 329 | self.testing.lossLog = saved['loss'][:saved['epoch'], 1].tolist() 330 | self.training.minloss = saved['loss'][:saved['epoch'], 0].min() 331 | self.testing.minloss = saved['loss'][:saved['epoch'], 1].min() 332 | self.training.accuracyLog = saved['accuracy'][:saved['epoch'], 0].tolist() 333 | self.testing.accuracyLog = saved['accuracy'][:saved['epoch'], 1].tolist() 334 | self.training.maxAccuracy = saved['accuracy'][:saved['epoch'], 0].max() 335 | self.testing.maxAccuracy = saved['accuracy'][:saved['epoch'], 1].max() 336 | 337 | return saved['epoch'] --------------------------------------------------------------------------------
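(Editor's closing example: a minimal round trip through the learningStats bookkeeping that main.py drives. This sketch is not part of the repository; the 'demo-' prefix and all numbers are made up, and it assumes utils.py is importable from the working directory.)

```
from utils import learningStats

stats = learningStats()
for epoch in range(1, 4):
    stats.training.reset()
    for i in range(10):                      # pretend training loop
        stats.training.correctSamples += 40  # say 40 of 50 samples were correct
        stats.training.numSamples += 50
        stats.training.lossSum += 0.5
        stats.print(epoch, iter=i)
    stats.training.update()

    stats.testing.reset()
    stats.testing.correctSamples += 8        # pretend one test batch
    stats.testing.numSamples += 10
    stats.testing.lossSum += 0.4
    stats.testing.update()

stats.save(filename='demo-')         # writes demo-loss.txt and demo-accuracy.txt
print(stats.load(filename='demo-'))  # restores the logs and returns the epoch count
```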