├── checkpoint
│   └── ckptspiking_model.t7
├── __pycache__
│   └── spiking_model.cpython-35.pyc
├── README.md
├── main.py
└── spiking_model.py
--------------------------------------------------------------------------------
/checkpoint/ckptspiking_model.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjwu17/STBP-for-training-SpikingNN/HEAD/checkpoint/ckptspiking_model.t7
--------------------------------------------------------------------------------
/__pycache__/spiking_model.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yjwu17/STBP-for-training-SpikingNN/HEAD/__pycache__/spiking_model.cpython-35.pyc
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Spatio-temporal BP for spiking neural networks.
PyTorch implementation of a convolutional SNN on MNIST [2].

Please find another branch for the PyTorch version on CIFAR10 [1].

For neuromorphic datasets (N-MNIST and DVS-Gesture), please refer to the examples in another project of ours [3]:

https://github.com/hewh16/SNNs-RNNs

### Requirements
- Python 3.6
- PyTorch (with torchvision)
- MNIST dataset
- CIFAR10 dataset
- N-MNIST dataset

### Results
After 100 epochs, the model reaches ~99.4% accuracy on MNIST.

### References
1. Wu, Yujie, Lei Deng, Guoqi Li, Jun Zhu, and Luping Shi. "Direct Training for Spiking Neural Networks: Faster, Larger, Better." arXiv preprint arXiv:1809.05793 (2018).
2. Wu, Yujie, Lei Deng, Guoqi Li, Jun Zhu, and Luping Shi. "Spatio-temporal backpropagation for training high-performance spiking neural networks." Frontiers in Neuroscience 12 (2018).
3. He, Weihua, Yujie Wu, Lei Deng, et al. "Comparing SNNs and RNNs on neuromorphic vision datasets: Similarities and differences." Neural Networks 132 (2020): 108-120.
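### Usage
A minimal sketch of training and then reloading the saved checkpoint. The checkpoint path and the `'net'`/`'acc'`/`'epoch'` keys follow `main.py` below; the snippet assumes training has already produced the checkpoint file:

```python
# Train (downloads MNIST to ./raw/ via torchvision):
#   python main.py

import torch
from spiking_model import SCNN

snn = SCNN()
state = torch.load('./checkpoint/ckptspiking_model.t7', map_location='cpu')
snn.load_state_dict(state['net'])
print('saved at epoch %d with accuracy %.2f%%' % (state['epoch'], state['acc']))
```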
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Sat Jun 9 09:46:25 2018

@author: yjwu

Python 3.5.2

"""

from __future__ import print_function
import torchvision
import torchvision.transforms as transforms
import os
import time
from spiking_model import *
# os.environ['CUDA_VISIBLE_DEVICES'] = "3"

names = 'spiking_model'
data_path = './raw/'  # todo: input your data path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = torchvision.datasets.MNIST(root=data_path, train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)

test_set = torchvision.datasets.MNIST(root=data_path, train=False, download=True, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

best_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
acc_record = []
loss_train_record = []
loss_test_record = []

snn = SCNN()
snn.to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(snn.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    running_loss = 0
    start_time = time.time()
    for i, (images, labels) in enumerate(train_loader):
        snn.zero_grad()
        optimizer.zero_grad()

        images = images.float().to(device)
        outputs = snn(images)
        # one-hot targets: MSE is taken between output firing rates and the label vector
        labels_ = torch.zeros(batch_size, 10).scatter_(1, labels.view(-1, 1), 1)
        loss = criterion(outputs.cpu(), labels_)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            # running_loss is accumulated over the last 100 steps
            print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f'
                  % (epoch + 1, num_epochs, i + 1, len(train_dataset) // batch_size, running_loss))
            running_loss = 0
    print('Time elapsed:', time.time() - start_time)

    correct = 0
    total = 0
    optimizer = lr_scheduler(optimizer, epoch, learning_rate, 40)

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(device)
            outputs = snn(inputs)
            labels_ = torch.zeros(batch_size, 10).scatter_(1, targets.view(-1, 1), 1)
            loss = criterion(outputs.cpu(), labels_)
            _, predicted = outputs.cpu().max(1)
            total += float(targets.size(0))
            correct += float(predicted.eq(targets).sum().item())
            if batch_idx % 100 == 0:
                acc = 100. * float(correct) / float(total)
                print(batch_idx, len(test_loader), ' Acc: %.5f' % acc)

    print('Iters:', epoch, '\n\n\n')
    print('Test Accuracy of the model on the 10000 test images: %.3f' % (100 * correct / total))
    acc = 100. * float(correct) / float(total)
    acc_record.append(acc)
    if epoch % 5 == 0:
        print(acc)
        print('Saving..')
        state = {
            'net': snn.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'acc_record': acc_record,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/ckpt' + names + '.t7')
        best_acc = acc
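The `scatter_` calls above build one-hot targets so that `nn.MSELoss` compares the network's output firing rates against a 0/1 vector per class. A minimal illustration with made-up labels:

```python
import torch

labels = torch.tensor([3, 0])  # class indices for a batch of two samples
one_hot = torch.zeros(2, 10).scatter_(1, labels.view(-1, 1), 1)
# one_hot[0] has a 1 at index 3, one_hot[1] has a 1 at index 0, zeros elsewhere
```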
--------------------------------------------------------------------------------
/spiking_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
thresh = 0.5  # neuronal firing threshold
lens = 0.5  # half-width of the rectangular surrogate-gradient window
decay = 0.2  # membrane potential decay constant
num_classes = 10
batch_size = 100
learning_rate = 1e-3
num_epochs = 100  # max epoch

# define approximate firing function
class ActFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.gt(thresh).float()  # spike if the membrane potential exceeds the threshold

    @staticmethod
    def backward(ctx, grad_output):
        # rectangular surrogate gradient: pass gradients only near the threshold
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        temp = abs(input - thresh) < lens
        return grad_input * temp.float()

act_fun = ActFun.apply

# membrane potential update: decay, reset to zero after a spike, then integrate new input
def mem_update(ops, x, mem, spike):
    mem = mem * decay * (1. - spike) + ops(x)
    spike = act_fun(mem)  # act_fun: approximate firing function
    return mem, spike

# cnn_layer(in_planes, out_planes, stride, padding, kernel_size)
cfg_cnn = [(1, 32, 1, 1, 3),
           (32, 32, 1, 1, 3), ]
# feature-map side lengths: 28 (input), 14 (after pool1), 7 (after pool2)
cfg_kernel = [28, 14, 7]
# fc layer widths
cfg_fc = [128, 10]

# decay the learning rate
def lr_scheduler(optimizer, epoch, init_lr=0.1, lr_decay_epoch=50):
    """Decay the learning rate by a factor of 0.1 every lr_decay_epoch epochs."""
    if epoch % lr_decay_epoch == 0 and epoch > 1:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * 0.1
    return optimizer
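# Example of the surrogate gradient: with thresh = 0.5 and lens = 0.5,
# potentials 0.3 and 0.7 fall inside the gradient window (|v - 0.5| < 0.5)
# while 1.2 does not, so
#
#     v = torch.tensor([0.3, 0.7, 1.2], requires_grad=True)
#     act_fun(v).sum().backward()
#     print(v.grad)   # tensor([1., 1., 0.])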
class SCNN(nn.Module):
    def __init__(self):
        super(SCNN, self).__init__()
        in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[0]
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding)
        in_planes, out_planes, stride, padding, kernel_size = cfg_cnn[1]
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding)

        self.fc1 = nn.Linear(cfg_kernel[-1] * cfg_kernel[-1] * cfg_cnn[-1][1], cfg_fc[0])
        self.fc2 = nn.Linear(cfg_fc[0], cfg_fc[1])

    def forward(self, input, time_window=20):
        batch = input.size(0)  # use the actual batch size so partial batches also work
        c1_mem = c1_spike = torch.zeros(batch, cfg_cnn[0][1], cfg_kernel[0], cfg_kernel[0], device=device)
        c2_mem = c2_spike = torch.zeros(batch, cfg_cnn[1][1], cfg_kernel[1], cfg_kernel[1], device=device)

        h1_mem = h1_spike = h1_sumspike = torch.zeros(batch, cfg_fc[0], device=device)
        h2_mem = h2_spike = h2_sumspike = torch.zeros(batch, cfg_fc[1], device=device)

        for step in range(time_window):  # simulation time steps
            # probabilistic (Bernoulli) firing: each pixel spikes with probability equal to its intensity
            x = input > torch.rand(input.size(), device=device)

            c1_mem, c1_spike = mem_update(self.conv1, x.float(), c1_mem, c1_spike)

            x = F.avg_pool2d(c1_spike, 2)

            c2_mem, c2_spike = mem_update(self.conv2, x, c2_mem, c2_spike)

            x = F.avg_pool2d(c2_spike, 2)
            x = x.view(batch, -1)

            h1_mem, h1_spike = mem_update(self.fc1, x, h1_mem, h1_spike)
            h1_sumspike += h1_spike
            h2_mem, h2_spike = mem_update(self.fc2, h1_spike, h2_mem, h2_spike)
            h2_sumspike += h2_spike

        outputs = h2_sumspike / time_window  # mean firing rate per output neuron
        return outputs
--------------------------------------------------------------------------------
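A quick shape check of the model (illustrative only; the MNIST-like input shape and the `SCNN`, `batch_size`, and `device` globals come from spiking_model.py above):

```python
import torch
from spiking_model import SCNN, batch_size, device

snn = SCNN().to(device)
dummy = torch.rand(batch_size, 1, 28, 28, device=device)  # MNIST-shaped batch in [0, 1]
rates = snn(dummy)   # mean firing rate per class over the time window
print(rates.shape)   # torch.Size([100, 10])
```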