├── .gitignore ├── requirements.txt ├── set_all_seeds.py ├── dvs ├── dataloader.py ├── test.py ├── train.py ├── run.py └── Net.py ├── fmnist ├── test.py ├── dataloader.py ├── train.py ├── run.py └── Net.py ├── mnist ├── test.py ├── dataloader.py ├── train.py ├── run.py └── Net.py ├── LICENSE ├── extract_test_set_accuracy.py ├── earlystopping.py ├── README.md ├── evaluate.py ├── plot_results.py └── quickstart.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | ~/ 3 | *.pt 4 | *.csv -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | snntorch 3 | numpy 4 | pandas 5 | optuna 6 | seaborn 7 | joblib 8 | h5py 9 | brevitas 10 | requests -------------------------------------------------------------------------------- /set_all_seeds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def set_all_seeds(seed=0): 8 | random.seed(seed) 9 | os.environ["PYTHONHASHSEED"] = str(seed) 10 | np.random.seed(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | -------------------------------------------------------------------------------- /dvs/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from snntorch.spikevision import spikedata 4 | 5 | 6 | def load_data(config): 7 | data_dir = config["data_dir"] 8 | # Note: the train set and test set are of different durations; we used num_steps=100 due to memory limits. 9 | # You will likely improve on our reported results by increasing num_steps from 100 to 150. 
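# Example usage (illustrative sketch, not part of the original file): batches drawn from these
# datasets are shaped [batch, time, channel, height, width], and dvs/train.py permutes them to
# time-first before the forward pass through the Net defined in dvs/Net.py, e.g.:
#   trainset, testset = load_data({"data_dir": "~/data/"})
#   trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)
#   data, labels = next(iter(trainloader))
#   spk_rec, _ = net(data.permute(1, 0, 2, 3, 4))  # time axis first, as in dvs/train.py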
10 | trainset = spikedata.DVSGesture(data_dir, train=True, num_steps=100, dt=3000, ds=4) 11 | testset = spikedata.DVSGesture(data_dir, train=False, num_steps=600, dt=3000, ds=4) 12 | return trainset, testset 13 | -------------------------------------------------------------------------------- /fmnist/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from snntorch import functional as SF 3 | 4 | 5 | def test(config, net, testloader, device="cpu"): 6 | correct = 0 7 | total = 0 8 | with torch.no_grad(): 9 | net.eval() 10 | for data in testloader: 11 | images, labels = data 12 | images, labels = images.to(device), labels.to(device) 13 | outputs, _ = net(images) 14 | accuracy = SF.accuracy_rate(outputs, labels) 15 | total += labels.size(0) 16 | correct += accuracy * labels.size(0) 17 | 18 | return 100 * correct / total 19 | -------------------------------------------------------------------------------- /mnist/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from snntorch import functional as SF 3 | 4 | 5 | def test(config, net, testloader, device="cpu"): 6 | correct = 0 7 | total = 0 8 | with torch.no_grad(): 9 | net.eval() 10 | for data in testloader: 11 | images, labels = data 12 | images, labels = images.to(device), labels.to(device) 13 | outputs, _ = net(images) 14 | accuracy = SF.accuracy_rate(outputs, labels) 15 | total += labels.size(0) 16 | correct += accuracy * labels.size(0) 17 | 18 | return 100 * correct / total 19 | -------------------------------------------------------------------------------- /dvs/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from snntorch import functional as SF 3 | 4 | 5 | def test(config, net, testloader, device="cpu"): 6 | correct = 0 7 | total = 0 8 | with torch.no_grad(): 9 | net.eval() 10 | for data in testloader: 11 | images, labels = data 12 | images, labels = images.to(device), labels.to(device) 13 | outputs, _ = net(images.permute(1, 0, 2, 3, 4)) 14 | accuracy = SF.accuracy_rate(outputs, labels.long()) 15 | total += labels.size(0) 16 | correct += accuracy * labels.size(0) 17 | 18 | return 100 * correct / total 19 | -------------------------------------------------------------------------------- /mnist/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import datasets, transforms 3 | from snntorch import utils 4 | 5 | 6 | def load_data(config): 7 | data_dir = config["data_dir"] 8 | transform = transforms.Compose( 9 | [ 10 | transforms.Resize((28, 28)), 11 | transforms.Grayscale(), 12 | transforms.ToTensor(), 13 | transforms.Normalize((0,), (1,)), 14 | ] 15 | ) 16 | trainset = datasets.MNIST(data_dir, train=True, download=True, transform=transform) 17 | testset = datasets.MNIST(data_dir, train=False, download=True, transform=transform) 18 | return trainset, testset 19 | -------------------------------------------------------------------------------- /fmnist/dataloader.py: -------------------------------------------------------------------------------- 1 | from torchvision import datasets, transforms 2 | 3 | 4 | def load_data(config): 5 | data_dir = config["data_dir"] 6 | transform = transforms.Compose( 7 | [ 8 | transforms.Resize((28, 28)), 9 | transforms.Grayscale(), 10 | transforms.ToTensor(), 11 | transforms.Normalize((0,), (1,)), 12 | ] 13 | ) 14 | trainset = 
datasets.FashionMNIST( 15 | data_dir, train=True, download=True, transform=transform 16 | ) 17 | testset = datasets.FashionMNIST( 18 | data_dir, train=False, download=True, transform=transform 19 | ) 20 | return trainset, testset 21 | -------------------------------------------------------------------------------- /fmnist/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def train(config, net, trainloader, criterion, optimizer, device="cpu", scheduler=None): 6 | net.train() 7 | loss_accum = [] 8 | lr_accum = [] 9 | i = 0 10 | for data, labels in trainloader: 11 | data, labels = data.to(device), labels.to(device) 12 | spk_rec, _ = net(data) 13 | loss = criterion(spk_rec, labels) 14 | optimizer.zero_grad() 15 | loss.backward() 16 | if config["grad_clip"]: 17 | nn.utils.clip_grad_norm_(net.parameters(), 1.0) 18 | 19 | if config["weight_clip"]: 20 | with torch.no_grad(): 21 | for param in net.parameters(): 22 | param.clamp_(-1, 1) 23 | 24 | optimizer.step() 25 | scheduler.step() 26 | loss_accum.append(loss.item() / config["num_steps"]) 27 | lr_accum.append(optimizer.param_groups[0]["lr"]) 28 | 29 | return loss_accum, lr_accum 30 | -------------------------------------------------------------------------------- /mnist/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def train(config, net, trainloader, criterion, optimizer, device="cpu", scheduler=None): 6 | net.train() 7 | loss_accum = [] 8 | lr_accum = [] 9 | i = 0 10 | for data, labels in trainloader: 11 | data, labels = data.to(device), labels.to(device) 12 | spk_rec, _ = net(data) 13 | loss = criterion(spk_rec, labels) 14 | optimizer.zero_grad() 15 | loss.backward() 16 | if config["grad_clip"]: 17 | nn.utils.clip_grad_norm_(net.parameters(), 1.0) 18 | 19 | if config["weight_clip"]: 20 | with torch.no_grad(): 21 | for param in net.parameters(): 22 | param.clamp_(-1, 1) 23 | 24 | optimizer.step() 25 | scheduler.step() 26 | loss_accum.append(loss.item() / config["num_steps"]) 27 | lr_accum.append(optimizer.param_groups[0]["lr"]) 28 | 29 | return loss_accum, lr_accum 30 | -------------------------------------------------------------------------------- /dvs/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def train(config, net, trainloader, criterion, optimizer, device="cpu", scheduler=None): 6 | net.train() 7 | loss_accum = [] 8 | lr_accum = [] 9 | i = 0 10 | for data, labels in trainloader: 11 | data, labels = data.to(device), labels.to(device) 12 | spk_rec, _ = net(data.permute(1, 0, 2, 3, 4)) 13 | loss = criterion(spk_rec, labels.long()) 14 | optimizer.zero_grad() 15 | loss.backward() 16 | if config["grad_clip"]: 17 | nn.utils.clip_grad_norm_(net.parameters(), 1.0) 18 | 19 | if config["weight_clip"]: 20 | with torch.no_grad(): 21 | for param in net.parameters(): 22 | param.clamp_(-1, 1) 23 | 24 | optimizer.step() 25 | scheduler.step() 26 | loss_accum.append(loss.item() / config["num_steps"]) 27 | lr_accum.append(optimizer.param_groups[0]["lr"]) 28 | 29 | return loss_accum, lr_accum 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Quantized Spiking Neural Networks 5 
| Copyright (C) 2022 Jason K. Eshraghian and Corey Lammie 6 | 7 | This program is free software: you can redistribute it and/or modify 8 | it under the terms of the GNU General Public License as published by 9 | the Free Software Foundation, either version 3 of the License, or 10 | (at your option) any later version. 11 | 12 | This program is distributed in the hope that it will be useful, 13 | but WITHOUT ANY WARRANTY; without even the implied warranty of 14 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 15 | GNU General Public License for more details. 16 | 17 | You should have received a copy of the GNU General Public License 18 | along with this program. If not, see . 19 | 20 | Also add information on how to contact you by electronic and paper mail. 21 | 22 | You should also get your employer (if you work as a programmer) or school, 23 | if any, to sign a "copyright disclaimer" for the program, if necessary. 24 | For more information on this, and how to apply and follow the GNU GPL, see 25 | . 26 | 27 | The GNU General Public License does not permit incorporating your program 28 | into proprietary programs. If your program is a subroutine library, you 29 | may consider it more useful to permit linking proprietary applications with 30 | the library. If this is what you want to do, use the GNU Lesser General 31 | Public License instead of this License. But first, please read 32 | . -------------------------------------------------------------------------------- /extract_test_set_accuracy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import os 4 | 5 | 6 | data = { 7 | 4: { 8 | "MNIST": { 9 | "cosine": [ 10 | "mnist/acc_MNIST_t0.csv", 11 | "mnist/acc_MNIST_t1.csv", 12 | "mnist/acc_MNIST_t2.csv", 13 | ], 14 | }, 15 | "FashionMNIST": { 16 | "cosine": [ 17 | "fmnist/acc_FMNIST_t0.csv", 18 | "fmnist/acc_FMNIST_t1.csv", 19 | "fmnist/acc_FMNIST_t2.csv", 20 | ], 21 | }, 22 | "DVS128 Gesture": { 23 | "cosine": [ 24 | "DVS/acc_DVS_t0.csv", 25 | "DVS/acc_DVS_t1.csv", 26 | "DVS/acc_DVS_t2.csv", 27 | ], 28 | }, 29 | }, 30 | } 31 | 32 | 33 | df = pd.DataFrame( 34 | columns=[ 35 | "dataset", 36 | "network_precision", 37 | "scheduler", 38 | "test_set_accuracy_best", 39 | "test_set_accuracy_mean", 40 | "test_set_accuracy_std", 41 | ] 42 | ) 43 | for precision in data.keys(): 44 | for dataset_idx, dataset in enumerate(data[precision].keys()): 45 | for scheduler in data[precision][dataset].keys(): 46 | test_set_accuracy_values = [] 47 | for idx, trial in enumerate(data[precision][dataset][scheduler]): 48 | trial_df = pd.read_csv(trial) 49 | test_set_accuracy = trial_df["test_acc"].max().item() 50 | test_set_accuracy_values.append(test_set_accuracy) 51 | 52 | df = df.append( 53 | { 54 | "dataset": dataset, 55 | "network_precision": precision, 56 | "scheduler": scheduler, 57 | "test_set_accuracy_best": max(test_set_accuracy_values), 58 | "test_set_accuracy_mean": np.mean(test_set_accuracy_values), 59 | "test_set_accuracy_std": np.std(test_set_accuracy_values), 60 | }, 61 | ignore_index=True, 62 | ) 63 | 64 | df.to_csv("test_set_accuracy.csv", index=False) 65 | -------------------------------------------------------------------------------- /earlystopping.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EarlyStopping_acc: 5 | """Early stops the training if test acc doesn't improve after a given patience.""" 6 | 7 | def __init__( 8 | 
self, patience=7, verbose=False, delta=0, path="checkpoint.pt", trace_func=print 9 | ): 10 | """ 11 | Args: 12 | patience (int): How long to wait after last time validation loss improved. 13 | Default: 7 14 | verbose (bool): If True, prints a message for each validation loss improvement. 15 | Default: False 16 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 17 | Default: 0 18 | path (str): Path for the checkpoint to be saved to. 19 | Default: 'checkpoint.pt' 20 | trace_func (function): trace print function. 21 | Default: print 22 | """ 23 | self.patience = patience 24 | self.verbose = verbose 25 | self.counter = 0 26 | self.best_score = None 27 | self.early_stop = False 28 | self.test_loss_min = 0 29 | self.delta = delta 30 | self.path = path 31 | self.trace_func = trace_func 32 | 33 | def __call__(self, test_loss, model): 34 | score = test_loss 35 | if self.best_score is None: 36 | self.best_score = score 37 | self.save_checkpoint(test_loss, model) 38 | elif score <= self.best_score + self.delta: 39 | self.counter += 1 40 | self.trace_func( 41 | f"EarlyStopping counter: {self.counter} out of {self.patience}" 42 | ) 43 | if self.counter >= self.patience: 44 | self.early_stop = True 45 | self.counter = 0 46 | else: 47 | self.best_score = score 48 | self.save_checkpoint(test_loss, model) 49 | self.counter = 0 50 | 51 | def save_checkpoint(self, test_loss, model): 52 | """Saves model when test acc increases.""" 53 | if self.verbose: 54 | self.trace_func( 55 | f"Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). Saving model ..." 56 | ) 57 | 58 | torch.save(model.state_dict(), self.path) 59 | self.test_loss_min = test_loss 60 | -------------------------------------------------------------------------------- /dvs/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | 5 | current = os.path.dirname(os.path.realpath(__file__)) 6 | parent = os.path.dirname(current) 7 | sys.path.append(parent) 8 | 9 | from Net import Net 10 | from evaluate import evaluate 11 | from dataloader import load_data 12 | from train import train 13 | from test import test 14 | 15 | config = { 16 | "exp_name": "DVS", # Experiment name 17 | "num_trials_eval": 3, # Number of trails to execute (separate training and evaluation instances) 18 | "num_epochs_eval": 500, # Number of epochs to train for (per trial) 19 | "data_dir": "~/data/", # Data directory to download and store data 20 | "batch_size": 16, # Batch size 21 | "seed": 0, # Random seed 22 | "num_workers": 0, # Number of workers for the dataloader 23 | "num_bits": 4, # Bit resolution. 
If None, floating point resolution is used 24 | "save_csv": True, # Whether or not to save loss, lr, and accuracy dataframes 25 | "early_stopping": True, # Whether or not to use early stopping 26 | "patience": 100, # Number of epochs to wait for improvement before stopping 27 | # Network parameters 28 | "grad_clip": True, # Whether or not to clip gradients 29 | "weight_clip": True, # Whether or not to clip weights 30 | "batch_norm": False, # Whether or not to use batch normalization 31 | "dropout": 0.203, # Dropout rate 32 | "beta": 0.614, # Decay rate parameter (beta) 33 | "threshold": 0.427, # Threshold parameter (theta) 34 | "lr": 2.634e-3, # Initial learning rate 35 | "slope": 4.413, # Slope value (k) 36 | # Fixed params 37 | "num_steps": 1, # Number of timesteps to encode input for 100 TODO 38 | "correct_rate": 0.8, # Correct rate 39 | "incorrect_rate": 0.2, # Incorrect rate 40 | "betas": (0.9, 0.999), # Adam optimizer beta values 41 | "t_max": 735, # Frequency of the cosine annealing scheduler (5 epochs) 42 | "t_0": 735, # Initial frequency of the cosine annealing scheduler 43 | "t_mult": 2, # The frequency of cosine is halved after every 4690 iters (10 epochs) 44 | "eta_min": 0, # Minimum learning rate 45 | } 46 | 47 | 48 | def optim_func(net, config): 49 | optimizer = torch.optim.Adam( 50 | net.parameters(), lr=config["lr"], betas=config["betas"] 51 | ) 52 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 53 | optimizer, T_max=config["t_0"], eta_min=config["eta_min"], last_epoch=-1 54 | ) 55 | loss_dependent = False 56 | return optimizer, scheduler, loss_dependent 57 | 58 | 59 | if __name__ == "__main__": 60 | evaluate(Net, config, load_data, train, test, optim_func) 61 | -------------------------------------------------------------------------------- /mnist/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | 5 | current = os.path.dirname(os.path.realpath(__file__)) 6 | parent = os.path.dirname(current) 7 | sys.path.append(parent) 8 | 9 | from Net import Net 10 | from evaluate import evaluate 11 | from dataloader import load_data 12 | from train import train 13 | from test import test 14 | 15 | config = { 16 | "exp_name": "MNIST", # Experiment name 17 | "num_trials_eval": 3, # Number of trails to execute (separate training and evaluation instances) 18 | "num_epochs_eval": 500, # Number of epochs to train for (per trial) 19 | "data_dir": "~/data/", # Data directory to download and store data 20 | "batch_size": 128, # Batch size 21 | "seed": 0, # Random seed 22 | "num_workers": 0, # Number of workers for the dataloader 23 | "num_bits": 4, # Bit resolution. 
If None, floating point resolution is used 24 | "save_csv": True, # Whether or not to save loss, lr, and accuracy dataframes 25 | "early_stopping": True, # Whether or not to use early stopping 26 | "patience": 100, # Number of epochs to wait for improvement before stopping 27 | # Network parameters 28 | "grad_clip": False, # Whether or not to clip gradients 29 | "weight_clip": False, # Whether or not to clip weights 30 | "batch_norm": True, # Whether or not to use batch normalization 31 | "dropout": 0.003168, # Dropout rate 32 | "beta": 0.994, # Decay rate parameter (beta) 33 | "threshold": 2.915, # Threshold parameter (theta) 34 | "lr": 5.4e-3, # Initial learning rate 35 | "slope": 13.84, # Slope value (k) 36 | # Fixed params 37 | "num_steps": 100, # Number of timesteps to encode input for 38 | "correct_rate": 0.8, # Correct rate 39 | "incorrect_rate": 0.2, # Incorrect rate 40 | "betas": (0.9, 0.999), # Adam optimizer beta values 41 | "t_max": 4690, # Frequency of the cosine annealing scheduler (5 epochs) 42 | "t_0": 4690, # Initial frequency of the cosine annealing scheduler 43 | "t_mult": 2, # The frequency of cosine is halved after every 4690 iters (10 epochs) 44 | "eta_min": 0, # Minimum learning rate 45 | } 46 | 47 | 48 | def optim_func(net, config): 49 | optimizer = torch.optim.Adam( 50 | net.parameters(), lr=config["lr"], betas=config["betas"] 51 | ) 52 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 53 | optimizer, T_max=config["t_0"], eta_min=config["eta_min"], last_epoch=-1 54 | ) 55 | loss_dependent = False 56 | return optimizer, scheduler, loss_dependent 57 | 58 | 59 | if __name__ == "__main__": 60 | evaluate(Net, config, load_data, train, test, optim_func) 61 | -------------------------------------------------------------------------------- /fmnist/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | import os 4 | 5 | current = os.path.dirname(os.path.realpath(__file__)) 6 | parent = os.path.dirname(current) 7 | sys.path.append(parent) 8 | 9 | from Net import Net 10 | from evaluate import evaluate 11 | from dataloader import load_data 12 | from train import train 13 | from test import test 14 | 15 | config = { 16 | "exp_name": "FMNIST", # Experiment name 17 | "num_trials_eval": 3, # Number of trails to execute (separate training and evaluation instances) 18 | "num_epochs_eval": 500, # Number of epochs to train for (per trial) 19 | "data_dir": "~/data/", # Data directory to download and store data 20 | "batch_size": 128, # Batch size 21 | "seed": 0, # Random seed 22 | "num_workers": 0, # Number of workers for the dataloader 23 | "num_bits": 4, # Bit resolution. 
If None, floating point resolution is used 24 | "save_csv": True, # Whether or not to save loss, lr, and accuracy dataframes 25 | "early_stopping": True, # Whether or not to use early stopping 26 | "patience": 100, # Number of epochs to wait for improvement before stopping 27 | # Network parameters 28 | "grad_clip": False, # Whether or not to clip gradients 29 | "weight_clip": False, # Whether or not to clip weights 30 | "batch_norm": True, # Whether or not to use batch normalization 31 | "dropout": 0.073556, # Dropout rate 32 | "beta": 0.974, # Decay rate parameter (beta) 33 | "threshold": 2.473, # Threshold parameter (theta) 34 | "lr": 2.908e-3, # Initial learning rate 35 | "slope": 5.5565, # Slope value (k) 36 | # Fixed params 37 | "num_steps": 100, # Number of timesteps to encode input for 38 | "correct_rate": 0.8, # Correct rate 39 | "incorrect_rate": 0.2, # Incorrect rate 40 | "betas": (0.9, 0.999), # Adam optimizer beta values 41 | "t_max": 4690, # Frequency of the cosine annealing scheduler (5 epochs) 42 | "t_0": 4690, # Initial frequency of the cosine annealing scheduler 43 | "t_mult": 2, # The frequency of cosine is halved after every 4690 iters (10 epochs) 44 | "eta_min": 0, # Minimum learning rate 45 | } 46 | 47 | 48 | def optim_func(net, config): 49 | optimizer = torch.optim.Adam( 50 | net.parameters(), lr=config["lr"], betas=config["betas"] 51 | ) 52 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 53 | optimizer, T_max=config["t_0"], eta_min=config["eta_min"], last_epoch=-1 54 | ) 55 | loss_dependent = False 56 | return optimizer, scheduler, loss_dependent 57 | 58 | 59 | if __name__ == "__main__": 60 | evaluate(Net, config, load_data, train, test, optim_func) 61 | -------------------------------------------------------------------------------- /dvs/Net.py: -------------------------------------------------------------------------------- 1 | import snntorch as snn 2 | from snntorch import surrogate 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import brevitas.nn as qnn 7 | 8 | 9 | class Net(nn.Module): 10 | def __init__(self, config): 11 | super().__init__() 12 | self.num_bits = config["num_bits"] 13 | self.thr = config["threshold"] 14 | self.slope = config["slope"] 15 | self.beta = config["beta"] 16 | self.num_steps = config["num_steps"] 17 | self.batch_norm = config["batch_norm"] 18 | self.p1 = config["dropout"] 19 | self.spike_grad = surrogate.fast_sigmoid(self.slope) 20 | if self.num_bits is None: 21 | self.init_net() 22 | else: 23 | self.init_quantized_net() 24 | 25 | def init_net(self): 26 | self.conv1 = nn.Conv2d(2, 16, 5, bias=False) 27 | self.conv1_bn = nn.BatchNorm2d(16) 28 | self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 29 | self.conv2 = nn.Conv2d(16, 32, 5, bias=False) 30 | self.conv2_bn = nn.BatchNorm2d(32) 31 | self.lif2 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 32 | self.fc1 = nn.Linear(32 * 5 * 5, 11) 33 | self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 34 | self.dropout = nn.Dropout(self.p1) 35 | 36 | def init_quantized_net(self): 37 | self.conv1 = qnn.QuantConv2d( 38 | 2, 16, 5, bias=False, weight_bit_width=self.num_bits 39 | ) 40 | self.conv1_bn = nn.BatchNorm2d(16) 41 | self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 42 | self.conv2 = qnn.QuantConv2d( 43 | 16, 32, 5, bias=False, weight_bit_width=self.num_bits 44 | ) 45 | self.conv2_bn = nn.BatchNorm2d(32) 46 | self.lif2 = 
snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 47 | self.fc1 = qnn.QuantLinear( 48 | 32 * 5 * 5, 11, bias=False, weight_bit_width=self.num_bits 49 | ) 50 | self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 51 | self.dropout = nn.Dropout(self.p1) 52 | 53 | def forward(self, x): 54 | # Initialize hidden states and outputs at t=0 55 | mem1 = self.lif1.init_leaky() 56 | mem2 = self.lif2.init_leaky() 57 | mem3 = self.lif3.init_leaky() 58 | # Record the final layer 59 | spk3_rec = [] 60 | mem3_rec = [] 61 | for step in range(x.size(0)): 62 | cur1 = F.avg_pool2d(self.conv1(x[step]), 2) 63 | if self.batch_norm: 64 | cur1 = self.conv1_bn(cur1) 65 | 66 | spk1, mem1 = self.lif1(cur1, mem1) 67 | cur2 = F.avg_pool2d(self.conv2(spk1), 2) 68 | if self.batch_norm: 69 | cur2 = self.conv2_bn(cur2) 70 | 71 | spk2, mem2 = self.lif2(cur2, mem2) 72 | cur3 = self.dropout(self.fc1(spk2.flatten(1))) 73 | spk3, mem3 = self.lif3(cur3, mem3) 74 | spk3_rec.append(spk3) 75 | mem3_rec.append(mem3) 76 | 77 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0) 78 | -------------------------------------------------------------------------------- /fmnist/Net.py: -------------------------------------------------------------------------------- 1 | import snntorch as snn 2 | from snntorch import surrogate 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import brevitas.nn as qnn 7 | 8 | 9 | class Net(nn.Module): 10 | def __init__(self, config): 11 | super().__init__() 12 | self.num_bits = config["num_bits"] 13 | self.thr = config["threshold"] 14 | self.slope = config["slope"] 15 | self.beta = config["beta"] 16 | self.num_steps = config["num_steps"] 17 | self.batch_norm = config["batch_norm"] 18 | self.p1 = config["dropout"] 19 | self.spike_grad = surrogate.fast_sigmoid(self.slope) 20 | if self.num_bits is None: 21 | self.init_net() 22 | else: 23 | self.init_quantized_net() 24 | 25 | def init_net(self): 26 | self.conv1 = nn.Conv2d(1, 16, 5, bias=False) 27 | self.conv1_bn = nn.BatchNorm2d(16) 28 | self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 29 | self.conv2 = nn.Conv2d(16, 64, 5, bias=False) 30 | self.conv2_bn = nn.BatchNorm2d(64) 31 | self.lif2 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 32 | self.fc1 = nn.Linear(64 * 4 * 4, 10) 33 | self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 34 | self.dropout = nn.Dropout(self.p1) 35 | 36 | def init_quantized_net(self): 37 | self.conv1 = qnn.QuantConv2d( 38 | 1, 16, 5, bias=False, weight_bit_width=self.num_bits 39 | ) 40 | self.conv1_bn = nn.BatchNorm2d(16) 41 | self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 42 | self.conv2 = qnn.QuantConv2d( 43 | 16, 64, 5, bias=False, weight_bit_width=self.num_bits 44 | ) 45 | self.conv2_bn = nn.BatchNorm2d(64) 46 | self.lif2 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 47 | self.fc1 = qnn.QuantLinear( 48 | 64 * 4 * 4, 10, bias=False, weight_bit_width=self.num_bits 49 | ) 50 | self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 51 | self.dropout = nn.Dropout(self.p1) 52 | 53 | def forward(self, x): 54 | # Initialize hidden states and outputs at t=0 55 | mem1 = self.lif1.init_leaky() 56 | mem2 = self.lif2.init_leaky() 57 | mem3 = self.lif3.init_leaky() 58 | # Record the final layer 59 | spk3_rec = [] 60 | mem3_rec = [] 61 | for step in range(self.num_steps): 62 | cur1 = 
F.avg_pool2d(self.conv1(x), 2) 63 | if self.batch_norm: 64 | cur1 = self.conv1_bn(cur1) 65 | 66 | spk1, mem1 = self.lif1(cur1, mem1) 67 | cur2 = F.avg_pool2d(self.conv2(spk1), 2) 68 | if self.batch_norm: 69 | cur2 = self.conv2_bn(cur2) 70 | 71 | spk2, mem2 = self.lif2(cur2, mem2) 72 | cur3 = self.dropout(self.fc1(spk2.flatten(1))) 73 | spk3, mem3 = self.lif3(cur3, mem3) 74 | spk3_rec.append(spk3) 75 | mem3_rec.append(mem3) 76 | 77 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0) 78 | -------------------------------------------------------------------------------- /mnist/Net.py: -------------------------------------------------------------------------------- 1 | import snntorch as snn 2 | from snntorch import surrogate 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import brevitas.nn as qnn 7 | 8 | 9 | class Net(nn.Module): 10 | def __init__(self, config): 11 | super().__init__() 12 | self.num_bits = config["num_bits"] 13 | self.thr = config["threshold"] 14 | self.slope = config["slope"] 15 | self.beta = config["beta"] 16 | self.num_steps = config["num_steps"] 17 | self.batch_norm = config["batch_norm"] 18 | self.p1 = config["dropout"] 19 | self.spike_grad = surrogate.fast_sigmoid(self.slope) 20 | if self.num_bits is None: 21 | self.init_net() 22 | else: 23 | self.init_quantized_net() 24 | 25 | def init_net(self): 26 | self.conv1 = nn.Conv2d(1, 16, 5, bias=False) 27 | self.conv1_bn = nn.BatchNorm2d(16) 28 | self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 29 | self.conv2 = nn.Conv2d(16, 64, 5, bias=False) 30 | self.conv2_bn = nn.BatchNorm2d(64) 31 | self.lif2 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 32 | self.fc1 = nn.Linear(64 * 4 * 4, 10) 33 | self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 34 | self.dropout = nn.Dropout(self.p1) 35 | 36 | def init_quantized_net(self): 37 | self.conv1 = qnn.QuantConv2d( 38 | 1, 16, 5, bias=False, weight_bit_width=self.num_bits 39 | ) 40 | self.conv1_bn = nn.BatchNorm2d(16) 41 | self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 42 | self.conv2 = qnn.QuantConv2d( 43 | 16, 64, 5, bias=False, weight_bit_width=self.num_bits 44 | ) 45 | self.conv2_bn = nn.BatchNorm2d(64) 46 | self.lif2 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 47 | self.fc1 = qnn.QuantLinear( 48 | 64 * 4 * 4, 10, bias=False, weight_bit_width=self.num_bits 49 | ) 50 | self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad) 51 | self.dropout = nn.Dropout(self.p1) 52 | 53 | def forward(self, x): 54 | # Initialize hidden states and outputs at t=0 55 | mem1 = self.lif1.init_leaky() 56 | mem2 = self.lif2.init_leaky() 57 | mem3 = self.lif3.init_leaky() 58 | # Record the final layer 59 | spk3_rec = [] 60 | mem3_rec = [] 61 | for step in range(self.num_steps): 62 | cur1 = F.avg_pool2d(self.conv1(x), 2) 63 | if self.batch_norm: 64 | cur1 = self.conv1_bn(cur1) 65 | 66 | spk1, mem1 = self.lif1(cur1, mem1) 67 | cur2 = F.avg_pool2d(self.conv2(spk1), 2) 68 | if self.batch_norm: 69 | cur2 = self.conv2_bn(cur2) 70 | 71 | spk2, mem2 = self.lif2(cur2, mem2) 72 | cur3 = self.dropout(self.fc1(spk2.flatten(1))) 73 | spk3, mem3 = self.lif3(cur3, mem3) 74 | spk3_rec.append(spk3) 75 | mem3_rec.append(mem3) 76 | 77 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0) 78 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Quantized Spiking Neural Networks 2 | This repository contains the corresponding code from the paper [Jason K. Eshraghian, Corey Lammie, Mostafa Rahimi Azghadi, and Wei D. Lu "Navigating Local Minima in Quantized Spiking Neural Networks". https://arxiv.org/abs/2202.07221, February 2022.](https://arxiv.org/abs/2202.07221) 3 | 4 | 5 | ![anim_2](https://user-images.githubusercontent.com/40262130/154583824-fa940d58-3249-40aa-a85b-0c0fbcaf68c4.gif) 6 | 7 |

Illustrations of the key concepts of the paper: Periodic scheduling can enable SNNs to overcome flat surfaces and local minima. When the LR is boosted during training using a cyclic scheduler, it is given another chance to reduce the loss with different initial conditions. While the loss appears to converge, subsequent LR boosting enables it to traverse more optimal solutions.

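The same idea in code: a minimal, self-contained sketch (not one of the training scripts in this repository) of how a cyclic learning-rate schedule periodically boosts the LR during training. The toy model, `T_0`, and placeholder loss below are illustrative assumptions only; the actual experiments use the `CosineAnnealingLR` settings defined in each `run.py`.

```python
import torch

# Toy stand-in for the quantized SNN; illustrative only.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)

# Warm restarts periodically reset the LR back to its initial value,
# giving training another chance to escape flat regions and local minima.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1000, T_mult=2, eta_min=0.0
)

for step in range(5000):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()  # placeholder loss
    loss.backward()
    optimizer.step()
    scheduler.step()  # stepped once per minibatch, as in the repo's train() functions
```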
8 | 9 | If you find this code useful in your work, please cite the following source: 10 | 11 | ``` 12 | @article{eshraghian2022navigating, 13 | title={{Navigating Local Minima in Quantized Spiking Neural Networks}}, 14 | author={Eshraghian, Jason K and Lammie, Corey and Rahimi Azghadi, Mostafa and Lu, Wei D}, 15 | year={2022}, 16 | eprint={2202.07221}, 17 | archivePrefix={arXiv}, 18 | } 19 | ``` 20 | 21 | ## Jupyter Notebook 22 | We provide a Jupyter notebook [here](https://github.com/jeshraghian/QSNNs/blob/main/quickstart.ipynb), which includes documentation and information about our developed scripts and methodologies. It can be run in a Google Colaboratory environment without any prerequisites [here](https://colab.research.google.com/github/jeshraghian/QSNNs/blob/main/quickstart.ipynb). 23 | 24 | ## Code Execution of Standalone Scripts 25 | For more advanced users, i.e., those proficient with Python, we provide executable code in the form of Python scripts. Simulations can be run by configuring and executing `run.py` in each respective dataset directory. 26 | 27 | ## Requirements 28 | ### Jupyter Notebook 29 | To run the Jupyter notebook, Google Colab can be used. Otherwise, a working `Python` (≥3.6) interpreter and the `pip` package manager are required. 30 | 31 | ### Standalone Scripts 32 | To run all standalone scripts, a working `Python` (≥3.6) interpreter and the `pip` package manager are required. All required libraries and packages can be installed using `pip install -r requirements.txt`. To avoid potential package conflicts, the use of a `conda` environment is recommended. The following commands can be used to create and activate a separate `conda` environment, clone this repository, and install all dependencies: 33 | 34 | ``` 35 | conda create -n QSNNs python=3.8 36 | conda activate QSNNs 37 | git clone https://github.com/jeshraghian/QSNNs.git 38 | cd QSNNs 39 | pip install -r requirements.txt 40 | ``` 41 | 42 | ## Hyperparameter Tuning 43 | * In each dataset directory, the `config` dictionary within `run.py` defines all configuration parameters and hyperparameters for that dataset. 44 | * The default parameters in this repo are identical to those for the Q4 cosine annealing learning rate schedule configurations reported in the corresponding paper. 45 | 46 | ## Interpreting and Plotting Results 47 | * Results can be gathered and plotted using `extract_test_set_accuracy.py` and `plot_results.py`, respectively. 48 | * `plot_results.py` can be reconfigured to plot different quantities. 49 | * By default, `plot_results.py` plots the loss curve evolution during training for all three datasets. A minimal example of inspecting a single trial's accuracy CSV directly is sketched below. 
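As referenced above, a quick way to inspect one trial's results without the full plotting pipeline is to read the per-trial CSV directly. The sketch below assumes a completed MNIST trial 0; the path matches the files listed in `extract_test_set_accuracy.py`, and the column names match those written by `evaluate.py`.

```python
import pandas as pd

# Each trial writes acc_<EXP>_t<trial>.csv with columns: epoch, test_acc, train_time.
df = pd.read_csv("mnist/acc_MNIST_t0.csv")  # example path for MNIST, trial 0
best = df.loc[df["test_acc"].idxmax()]
print(f"Best test accuracy: {best['test_acc']:.2f}% (epoch {int(best['epoch'])})")
print(f"Mean epoch time: {df['train_time'].mean():.1f} s")
```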
50 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import snntorch as snn 2 | from snntorch import functional as SF 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | import pandas as pd 7 | import time 8 | from earlystopping import * 9 | from set_all_seeds import set_all_seeds 10 | 11 | 12 | def evaluate(Net, config, load_data, train, test, optim_func): 13 | file_name = config["exp_name"] 14 | for trial in range(config["num_trials_eval"]): 15 | csv_name = file_name + "_t" + str(trial) + ".csv" 16 | model_name = file_name + "_t" + str(trial) + ".pt" 17 | num_epochs = config["num_epochs_eval"] 18 | set_all_seeds(config["seed"] + trial) 19 | df_train_loss = pd.DataFrame() 20 | df_test_acc = pd.DataFrame(columns=["epoch", "test_acc", "train_time"]) 21 | df_lr = pd.DataFrame() 22 | # Initialize the network 23 | net = Net(config) 24 | device = "cpu" 25 | if torch.cuda.is_available(): 26 | device = "cuda" 27 | 28 | net.to(device) 29 | # Initialize the optimizer and scheduler 30 | criterion = SF.mse_count_loss( 31 | correct_rate=config["correct_rate"], incorrect_rate=config["incorrect_rate"] 32 | ) 33 | optimizer, scheduler, loss_dependent = optim_func(net, config) 34 | # Early stopping condition 35 | if config["early_stopping"]: 36 | early_stopping = EarlyStopping_acc( 37 | patience=config["patience"], verbose=True, path=model_name 38 | ) 39 | early_stopping.early_stop = False 40 | early_stopping.best_score = None 41 | 42 | # Load data 43 | trainset, testset = load_data(config) 44 | config["dataset_length"] = len(trainset) 45 | trainloader = DataLoader( 46 | trainset, batch_size=int(config["batch_size"]), shuffle=True 47 | ) 48 | testloader = DataLoader( 49 | testset, batch_size=int(config["batch_size"]), shuffle=False 50 | ) 51 | if loss_dependent: 52 | old_loss_hist = float("inf") 53 | 54 | print( 55 | f"=======Trial: {trial}, Batch: {config['batch_size']}, beta: {config['beta']:.3f}, threshold: {config['threshold']:.2f}, slope: {config['slope']}, lr: {config['lr']:.3e}======" 56 | ) 57 | # Train 58 | for epoch in range(num_epochs): 59 | start_time = time.time() 60 | loss_list, lr_list = train( 61 | config, net, trainloader, criterion, optimizer, device, scheduler 62 | ) 63 | epoch_time = time.time() - start_time 64 | if loss_dependent: 65 | avg_loss_hist = sum(loss_list) / len(loss_list) 66 | if avg_loss_hist > old_loss_hist: 67 | for param_group in optimizer.param_groups: 68 | param_group["lr"] = param_group["lr"] * 0.5 69 | else: 70 | old_loss_hist = avg_loss_hist 71 | 72 | # Test 73 | test_accuracy = test(config, net, testloader, device) 74 | print(f"Epoch: {epoch} \tTest Accuracy: {test_accuracy}") 75 | df_lr = df_lr.append(lr_list, ignore_index=True) 76 | 77 | df_train_loss = df_train_loss.append(loss_list, ignore_index=True) 78 | df_test_acc = df_test_acc.append( 79 | {"epoch": epoch, "test_acc": test_accuracy, "train_time": epoch_time}, 80 | ignore_index=True, 81 | ) 82 | if config["save_csv"]: 83 | df_train_loss.to_csv("loss_" + csv_name, index=False) 84 | df_test_acc.to_csv("acc_" + csv_name, index=False) 85 | df_lr.to_csv("lr_" + csv_name, index=False) 86 | 87 | if config["early_stopping"]: 88 | early_stopping(test_accuracy, net) 89 | if early_stopping.early_stop: 90 | print("Early stopping") 91 | early_stopping.early_stop = False 92 | early_stopping.best_score = None 93 | break 94 | 
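# Illustrative sketch (not part of the original repository): evaluate() also supports a
# loss-driven schedule. When optim_func returns loss_dependent=True, the learning rate is
# halved between epochs whenever an epoch's average training loss worsens. A hypothetical
# optim_func for that mode; the constant LambdaLR is used only so train() still has a
# scheduler to step each minibatch.
def optim_func_loss_dependent(net, config):
    optimizer = torch.optim.Adam(
        net.parameters(), lr=config["lr"], betas=config["betas"]
    )
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    return optimizer, scheduler, True  # loss_dependent=True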
-------------------------------------------------------------------------------- /plot_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.lib.twodim_base import tri 3 | import pandas as pd 4 | import os 5 | import matplotlib.pyplot as plt 6 | from matplotlib.ticker import FormatStrFormatter 7 | import seaborn as sns 8 | 9 | 10 | plt.rcParams["font.family"] = "sans-serif" 11 | plt.rcParams["font.sans-serif"] = ["Arial"] 12 | plt.rcParams["figure.figsize"] = (32.5, 10) 13 | plt.rcParams.update({"font.size": 18}) 14 | plt.rcParams["axes.linewidth"] = 2 15 | plt.rcParams["axes.formatter.limits"] = [-5, 4] 16 | 17 | fig, ax = plt.subplots(1, 3) 18 | 19 | data = { 20 | 4: { 21 | "MNIST": { 22 | "cosine": [ 23 | "mnist/loss_MNIST_t0.csv", 24 | "mnist/loss_MNIST_t1.csv", 25 | "mnist/loss_MNIST_t2.csv", 26 | ], 27 | }, 28 | "FashionMNIST": { 29 | "cosine": [ 30 | "fmnist/loss_FMNIST_t0.csv", 31 | "fmnist/loss_FMNIST_t1.csv", 32 | "fmnist/loss_FMNIST_t2.csv", 33 | ], 34 | }, 35 | "DVS128 Gesture": { 36 | "cosine": [ 37 | "DVS/loss_DVS_t0.csv", 38 | "DVS/loss_DVS_t1.csv", 39 | "DVS/loss_DVS_t2.csv", 40 | ], 41 | }, 42 | }, 43 | } 44 | 45 | df = pd.DataFrame( 46 | columns=["dataset", "network_precision", "scheduler", "idx", "mean", "std"] 47 | ) 48 | for precision in data.keys(): 49 | for dataset_idx, dataset in enumerate(data[precision].keys()): 50 | for scheduler in data[precision][dataset].keys(): 51 | grouped_trial_df = pd.DataFrame(columns=["idx", "loss"]) 52 | for idx, trial in enumerate(data[precision][dataset][scheduler]): 53 | if trial is not None: 54 | trial_df = pd.read_csv(trial) 55 | trial_data = np.vstack( 56 | (trial_df.index, trial_df.values.flatten()) 57 | ).T 58 | trial_df = pd.DataFrame(trial_data, columns=["idx", "loss"]) 59 | grouped_trial_df = grouped_trial_df.append(trial_df) 60 | else: 61 | grouped_trial_df = grouped_trial_df.append( 62 | {"idx": 0, "loss": 1}, ignore_index=True 63 | ) 64 | 65 | grouped_trial_df["loss"] = pd.to_numeric(grouped_trial_df["loss"]) 66 | grouped_trial_df_ = grouped_trial_df.groupby("idx") 67 | grouped_trial_data = np.hstack( 68 | ( 69 | np.expand_dims(grouped_trial_df["idx"].unique(), 1), 70 | grouped_trial_df_.mean(), 71 | grouped_trial_df_.std(), 72 | ) 73 | ) 74 | grouped_trial_data = np.nan_to_num(grouped_trial_data, nan=0) 75 | trial_df = pd.DataFrame(grouped_trial_data, columns=["idx", "mean", "std"]) 76 | trial_df["network_precision"] = precision 77 | trial_df["scheduler"] = scheduler 78 | trial_df["dataset"] = dataset 79 | df = df.append(trial_df, ignore_index=True) 80 | 81 | df = df[df["idx"] % 250 == 0] 82 | df.to_csv("loss_data.csv") 83 | df_quant = df[df["network_precision"] == 4] 84 | del df 85 | # Separate out dataframes to independently take moving avgs 86 | df_mnist_quant = df_quant[df_quant["dataset"] == "MNIST"] 87 | df_fmnist_quant = df_quant[df_quant["dataset"] == "FashionMNIST"] 88 | df_dvs_quant = df_quant[df_quant["dataset"] == "DVS128 Gesture"] 89 | del df_quant 90 | df_mnist_quant["mean_rolling"] = ( 91 | df_mnist_quant.iloc[:, 4].rolling(window=20, min_periods=1).mean() 92 | ) 93 | df_mnist_quant["std_rolling"] = ( 94 | df_mnist_quant.iloc[:, 5].rolling(window=20, min_periods=1).mean() 95 | ) 96 | df_mnist_quant = df_mnist_quant.dropna() 97 | df_fmnist_quant["mean_rolling"] = ( 98 | df_fmnist_quant.iloc[:, 4].rolling(window=20, min_periods=1).mean() 99 | ) 100 | df_fmnist_quant["std_rolling"] = ( 101 | df_fmnist_quant.iloc[:, 
5].rolling(window=20, min_periods=1).mean() 102 | ) 103 | df_fmnist_quant = df_fmnist_quant.dropna() 104 | df_dvs_quant["mean_rolling"] = ( 105 | df_dvs_quant.iloc[:, 4].rolling(window=5, min_periods=1).mean() 106 | ) 107 | df_dvs_quant["std_rolling"] = ( 108 | df_dvs_quant.iloc[:, 5].rolling(window=5, min_periods=1).mean() 109 | ) 110 | df_dvs_quant = df_dvs_quant.dropna() 111 | # Combine them 112 | frames = [df_mnist_quant, df_fmnist_quant, df_dvs_quant] 113 | df = pd.concat(frames, ignore_index=True) 114 | # Plot rolling avgs or raw mean/std 115 | col_name = "mean_rolling" # or mean 116 | std_name = "std_rolling" # or std 117 | y_axis_limits = [[0.0008, 0.003], [0.004, 0.0075], [0.001, 0.0125]] 118 | palette = sns.color_palette("bright", 4) 119 | for precision_idx, precision in enumerate(data.keys()): 120 | for dataset_idx, dataset in enumerate(data[precision].keys()): 121 | df_tmp = df[df["network_precision"] == precision] 122 | df_ = df_tmp[df_tmp["dataset"] == dataset] 123 | # Plot mean values 124 | sns.lineplot( 125 | data=df_, 126 | x="idx", 127 | y=col_name, 128 | hue="scheduler", 129 | ax=ax[dataset_idx], 130 | linewidth=2.5, 131 | palette=palette, 132 | ) # alpha=0.95 133 | # Manually plot error bounds 134 | for scheduler_idx, scheduler in enumerate(data[precision][dataset].keys()): 135 | df__ = df_[df_["scheduler"] == scheduler] 136 | x = df__["idx"].values 137 | try: 138 | lower = df__[col_name].values - df__[std_name].values 139 | upper = df__[col_name].values + df__[std_name].values 140 | ax[dataset_idx].plot(x, lower, color=palette[scheduler_idx], alpha=0.2) 141 | ax[dataset_idx].plot(x, upper, color=palette[scheduler_idx], alpha=0.2) 142 | ax[dataset_idx].spines["top"].set_visible(False) 143 | ax[dataset_idx].spines["right"].set_visible(False) 144 | ax[dataset_idx].fill_between(x, lower, upper, alpha=0.1) 145 | except: 146 | pass 147 | 148 | ax[dataset_idx].set_title(dataset) 149 | ax[dataset_idx].set_xlim([0, None]) 150 | ax[dataset_idx].set_ylim(y_axis_limits[dataset_idx]) 151 | ax[dataset_idx].yaxis.set_major_formatter(FormatStrFormatter("%.5f")) 152 | ax[dataset_idx].grid() 153 | ax[dataset_idx].set_xlabel("Minibatch") 154 | ax[dataset_idx].set_ylabel("MSE Loss") 155 | print(precision, dataset) 156 | 157 | plt.show() 158 | -------------------------------------------------------------------------------- /quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "47d5313e-c29d-4581-a9c7-a45122337069", 16 | "metadata": { 17 | "id": "47d5313e-c29d-4581-a9c7-a45122337069" 18 | }, 19 | "source": [ 20 | "# Quantized Spiking Neural Networks\n", 21 | "This notebook is based on the paper *Navigating Local Minima in Quantized Spiking Neural Networks.* It demonstrates how to train quantized spiking neural networks using cosine annealing on the FashionMNIST dataset. For other datasets, networks, and for the experiments described in the corresponding paper, please [refer to the QSNNs repo](https://github.com/jeshraghian/QSNNs/).\n", 22 | "\n", 23 | "\n", 24 | "![git_path](https://user-images.githubusercontent.com/13549940/154009399-eb6152f7-31db-4f93-9978-ac1e1c4a8c6a.svg)\n", 25 | "\n", 26 | "

Illustrations of the key concepts of the paper: Periodic scheduling can enable SNNs to overcome flat surfaces and local minima. When the LR is boosted during training using a cyclic scheduler, it is given another chance to reduce the loss with different initial conditions. While the loss appears to converge, subsequent LR boosting enables it to traverse more optimal solutions.

\n" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "b68d7bb4", 32 | "metadata": {}, 33 | "source": [ 34 | "## Install All Required Packages and Import Necessary Libraries" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "hDnIEHOKB8LD", 41 | "metadata": { 42 | "id": "hDnIEHOKB8LD" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import urllib.request\n", 47 | "urllib.request.urlretrieve('https://raw.githubusercontent.com/jeshraghian/QSNNs/main/requirements.txt', 'requirements.txt')\n", 48 | "!pip install -r requirements.txt --quiet\n", 49 | "import torch, torch.nn as nn\n", 50 | "import snntorch as snn\n", 51 | "import brevitas.nn as qnn" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "EYf13Gtx1OCj", 57 | "metadata": { 58 | "id": "EYf13Gtx1OCj" 59 | }, 60 | "source": [ 61 | "## Create a Dataloader for the FashionMNIST Dataset" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "17e61945", 67 | "metadata": {}, 68 | "source": [ 69 | "Download and apply transforms to the FashionMNIST dataset." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "eo4T5MC21hgD", 76 | "metadata": { 77 | "id": "eo4T5MC21hgD" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "from torch.utils.data import DataLoader\n", 82 | "from torchvision import datasets, transforms\n", 83 | "\n", 84 | "\n", 85 | "data_path='/data/fmnist' # Directory where FMNIST dataset is stored\n", 86 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\") # Use GPU if available\n", 87 | "\n", 88 | "# Define a transform to normalize data\n", 89 | "transform = transforms.Compose([\n", 90 | " transforms.Resize((28, 28)),\n", 91 | " transforms.Grayscale(),\n", 92 | " transforms.ToTensor(),\n", 93 | " transforms.Normalize((0,), (1,))])\n", 94 | "\n", 95 | "# Download and load the training and test FashionMNIST datasets\n", 96 | "fmnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)\n", 97 | "fmnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "CHcNZT-7iCQH", 103 | "metadata": { 104 | "id": "CHcNZT-7iCQH" 105 | }, 106 | "source": [ 107 | "To speed-up simulations for demonstration purposes, the below code cell can be run to reduce the number of samples in the training and test sets by a factor of 10." 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "q5bhKdF_h7qk", 114 | "metadata": { 115 | "id": "q5bhKdF_h7qk" 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "from snntorch import utils\n", 120 | "\n", 121 | "\n", 122 | "utils.data_subset(fmnist_train, 10)\n", 123 | "utils.data_subset(fmnist_test, 10)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "bLmrQ5pEiSSJ", 129 | "metadata": { 130 | "id": "bLmrQ5pEiSSJ" 131 | }, 132 | "source": [ 133 | "Create DataLoaders with batches of 128 samples and shuffle the training set." 
134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "xstp4mn_iRxi", 140 | "metadata": { 141 | "id": "xstp4mn_iRxi" 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "batch_size = 128 # Batches of 128 samples\n", 146 | "trainloader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True)\n", 147 | "testloader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=False)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "i3A4exp_c0c5", 153 | "metadata": { 154 | "id": "i3A4exp_c0c5" 155 | }, 156 | "source": [ 157 | "## Define Network Parameters" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "vrt2wObbiXSf", 163 | "metadata": { 164 | "id": "vrt2wObbiXSf" 165 | }, 166 | "source": [ 167 | "We have only specified 15 epochs without early stopping as a quick, early demonstration. Feel free to increase this. " 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "ivhGn7Lhc6te", 174 | "metadata": { 175 | "id": "ivhGn7Lhc6te" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "config = {\n", 180 | " \"num_epochs\": 15, # Number of epochs to train for (per trial)\n", 181 | " \"batch_size\": 128, # Batch size\n", 182 | " \"seed\": 0, # Random seed\n", 183 | " \n", 184 | " # Quantization\n", 185 | " \"num_bits\": 4, # Bit resolution\n", 186 | " \n", 187 | " # Network parameters\n", 188 | " \"grad_clip\": False, # Whether or not to clip gradients\n", 189 | " \"weight_clip\": False, # Whether or not to clip weights\n", 190 | " \"batch_norm\": True, # Whether or not to use batch normalization\n", 191 | " \"dropout\": 0.07, # Dropout rate\n", 192 | " \"beta\": 0.97, # Decay rate parameter (beta)\n", 193 | " \"threshold\": 2.5, # Threshold parameter (theta)\n", 194 | " \"lr\": 3.0e-3, # Initial learning rate\n", 195 | " \"slope\": 5.6, # Slope value (k)\n", 196 | " \n", 197 | " # Fixed params\n", 198 | " \"num_steps\": 100, # Number of timesteps to encode input for\n", 199 | " \"correct_rate\": 0.8, # Correct rate\n", 200 | " \"incorrect_rate\": 0.2, # Incorrect rate\n", 201 | " \"betas\": (0.9, 0.999), # Adam optimizer beta values\n", 202 | " \"t_0\": 4690, # Initial frequency of the cosine annealing scheduler\n", 203 | " \"eta_min\": 0, # Minimum learning rate\n", 204 | "}" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "id": "BtJBOtez11wy", 210 | "metadata": { 211 | "id": "BtJBOtez11wy" 212 | }, 213 | "source": [ 214 | "## Define the Network Architecture\n", 215 | "* 5 $\\times$ Conv Layer w/16 Filters\n", 216 | "* 2 $\\times$ 2 Average Pooling\n", 217 | "* 5 $\\times$ Conv Layer w/64 Filters\n", 218 | "* 2 $\\times$ 2 Average Pooling\n", 219 | "* (64 $\\times$ 4 $\\times$ 4) -- 10 Dense Layer" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "JM2thnrc10rD", 226 | "metadata": { 227 | "id": "JM2thnrc10rD" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "from snntorch import surrogate\n", 232 | "import torch.nn.functional as F\n", 233 | "\n", 234 | "\n", 235 | "class Net(nn.Module):\n", 236 | " def __init__(self, config):\n", 237 | " super().__init__()\n", 238 | " self.num_bits = config[\"num_bits\"]\n", 239 | " self.thr = config[\"threshold\"]\n", 240 | " self.slope = config[\"slope\"]\n", 241 | " self.beta = config[\"beta\"]\n", 242 | " self.num_steps = config[\"num_steps\"]\n", 243 | " self.batch_norm = config[\"batch_norm\"]\n", 244 | " self.p1 = config[\"dropout\"]\n", 245 | " self.spike_grad = 
surrogate.fast_sigmoid(self.slope)\n", 246 | " \n", 247 | " # Initialize Layers\n", 248 | " self.conv1 = qnn.QuantConv2d(1, 16, 5, bias=False, weight_bit_width=self.num_bits)\n", 249 | " self.conv1_bn = nn.BatchNorm2d(16)\n", 250 | " self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)\n", 251 | " self.conv2 = qnn.QuantConv2d(16, 64, 5, bias=False, weight_bit_width=self.num_bits)\n", 252 | " self.conv2_bn = nn.BatchNorm2d(64)\n", 253 | " self.lif2 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)\n", 254 | " self.fc1 = qnn.QuantLinear(64 * 4 * 4, 10, bias=False, weight_bit_width=self.num_bits)\n", 255 | " self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)\n", 256 | " self.dropout = nn.Dropout(self.p1)\n", 257 | "\n", 258 | " def forward(self, x):\n", 259 | " # Initialize hidden states and outputs at t=0\n", 260 | " mem1 = self.lif1.init_leaky()\n", 261 | " mem2 = self.lif2.init_leaky()\n", 262 | " mem3 = self.lif3.init_leaky()\n", 263 | "\n", 264 | " # Record the final layer\n", 265 | " spk3_rec = []\n", 266 | " mem3_rec = []\n", 267 | "\n", 268 | " # Forward pass\n", 269 | " for step in range(self.num_steps):\n", 270 | " cur1 = F.avg_pool2d(self.conv1(x), 2)\n", 271 | " if self.batch_norm:\n", 272 | " cur1 = self.conv1_bn(cur1)\n", 273 | "\n", 274 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 275 | " cur2 = F.avg_pool2d(self.conv2(spk1), 2)\n", 276 | " if self.batch_norm:\n", 277 | " cur2 = self.conv2_bn(cur2)\n", 278 | "\n", 279 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 280 | " cur3 = self.dropout(self.fc1(spk2.flatten(1)))\n", 281 | " spk3, mem3 = self.lif3(cur3, mem3)\n", 282 | " spk3_rec.append(spk3)\n", 283 | " mem3_rec.append(mem3)\n", 284 | "\n", 285 | " return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)\n", 286 | "\n", 287 | "net = Net(config).to(device)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "id": "BmtJx_AAeOyP", 293 | "metadata": { 294 | "id": "BmtJx_AAeOyP" 295 | }, 296 | "source": [ 297 | "## Define the Optimizer, Learning Rate Scheduler, and Loss Function\n", 298 | "* Adam optimizer\n", 299 | "* Cosine Annealing Scheduler\n", 300 | "* MSE Spike Count Loss (Target spike count for correct and incorrect classes are specified)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "id": "ky-qAN_YeKmE", 307 | "metadata": { 308 | "id": "ky-qAN_YeKmE" 309 | }, 310 | "outputs": [], 311 | "source": [ 312 | "import snntorch.functional as SF\n", 313 | "\n", 314 | "\n", 315 | "optimizer = torch.optim.Adam(net.parameters(), \n", 316 | " lr=config[\"lr\"], betas=config[\"betas\"]\n", 317 | ")\n", 318 | "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, \n", 319 | " T_max=config[\"t_0\"], \n", 320 | " eta_min=config[\"eta_min\"], \n", 321 | " last_epoch=-1\n", 322 | ")\n", 323 | "criterion = SF.mse_count_loss(correct_rate=config[\"correct_rate\"], \n", 324 | " incorrect_rate=config[\"incorrect_rate\"]\n", 325 | ")" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "id": "UGtJwmtVexb4", 331 | "metadata": { 332 | "id": "UGtJwmtVexb4" 333 | }, 334 | "source": [ 335 | "## Train and Evaluate the Network" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "id": "2321a02f", 341 | "metadata": {}, 342 | "source": [ 343 | "As the learning rate follows a periodic schedule, the accuracy will oscillate across the training process, but with a general tendency to improve." 
344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "id": "tbOQgPiEe-lp", 350 | "metadata": { 351 | "id": "tbOQgPiEe-lp" 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "def train(config, net, trainloader, criterion, optimizer, device=\"cpu\", scheduler=None):\n", 356 | " \"\"\"Complete one epoch of training.\"\"\"\n", 357 | " \n", 358 | " net.train()\n", 359 | " loss_accum = []\n", 360 | " lr_accum = []\n", 361 | " i = 0\n", 362 | " for data, labels in trainloader:\n", 363 | " data, labels = data.to(device), labels.to(device)\n", 364 | " spk_rec, _ = net(data)\n", 365 | " loss = criterion(spk_rec, labels)\n", 366 | " optimizer.zero_grad()\n", 367 | " loss.backward()\n", 368 | "\n", 369 | " ## Enable gradient clipping\n", 370 | " if config[\"grad_clip\"]:\n", 371 | " nn.utils.clip_grad_norm_(net.parameters(), 1.0)\n", 372 | "\n", 373 | " ## Enable weight clipping\n", 374 | " if config[\"weight_clip\"]:\n", 375 | " with torch.no_grad():\n", 376 | " for param in net.parameters():\n", 377 | " param.clamp_(-1, 1)\n", 378 | "\n", 379 | " optimizer.step()\n", 380 | " scheduler.step()\n", 381 | " loss_accum.append(loss.item() / config[\"num_steps\"])\n", 382 | " lr_accum.append(optimizer.param_groups[0][\"lr\"])\n", 383 | "\n", 384 | " return loss_accum, lr_accum\n", 385 | "\n", 386 | "def test(config, net, testloader, device=\"cpu\"):\n", 387 | " \"\"\"Calculate accuracy on full test set.\"\"\"\n", 388 | " correct = 0\n", 389 | " total = 0\n", 390 | " with torch.no_grad():\n", 391 | " net.eval()\n", 392 | " for data in testloader:\n", 393 | " images, labels = data\n", 394 | " images, labels = images.to(device), labels.to(device)\n", 395 | " outputs, _ = net(images)\n", 396 | " accuracy = SF.accuracy_rate(outputs, labels)\n", 397 | " total += labels.size(0)\n", 398 | " correct += accuracy * labels.size(0)\n", 399 | "\n", 400 | " return 100 * correct / total\n", 401 | "\n", 402 | "loss_list = []\n", 403 | "lr_list = []\n", 404 | "\n", 405 | "print(f\"=======Training Network=======\")\n", 406 | "# Train\n", 407 | "for epoch in range(config['num_epochs']):\n", 408 | " loss, lr = train(config, net, trainloader, criterion, optimizer, \n", 409 | " device, scheduler\n", 410 | " )\n", 411 | " loss_list = loss_list + loss\n", 412 | " lr_list = lr_list + lr\n", 413 | " # Test\n", 414 | " test_accuracy = test(config, net, testloader, device)\n", 415 | " print(f\"Epoch: {epoch} \\tTest Accuracy: {test_accuracy}\")" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "id": "14d0bd78", 421 | "metadata": {}, 422 | "source": [ 423 | "## Plot the Training Loss and Learning Rate Over Time" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "id": "B22SnaTElOLh", 430 | "metadata": { 431 | "id": "B22SnaTElOLh" 432 | }, 433 | "outputs": [], 434 | "source": [ 435 | "%matplotlib inline\n", 436 | "import matplotlib.pyplot as plt\n", 437 | "import seaborn as sns\n", 438 | "\n", 439 | "\n", 440 | "sns.set_theme()\n", 441 | "fig, ax1 = plt.subplots()\n", 442 | "ax2 = ax1.twinx()\n", 443 | "ax1.plot(loss_list, color='tab:orange')\n", 444 | "ax2.plot(lr_list, color='tab:blue')\n", 445 | "ax1.set_xlabel('Iteration')\n", 446 | "ax1.set_ylabel('Loss', color='tab:orange')\n", 447 | "ax2.set_ylabel('Learning Rate', color='tab:blue')\n", 448 | "plt.show()" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "id": "-iSGTq0Q3Lcm", 454 | "metadata": { 455 | "id": "-iSGTq0Q3Lcm" 456 | }, 457 | "source": [ 458 | "# Conclusion\n", 459 | 
"That's it for the quick intro to quantized SNNs! Results can be further improved by not using the `snntorch.utils.data_subset` method to train with the full FashionMNIST dataset, training for a larger number of epochs, and utilizing early stopping logic.\n", 460 | "\n", 461 | "To run the experiments from the corresponding paper, including those on dynamic datasets, please [refer to the corresponding GitHub repo](https://github.com/jeshraghian/QSNNs/)." 462 | ] 463 | } 464 | ], 465 | "metadata": { 466 | "accelerator": "GPU", 467 | "colab": { 468 | "include_colab_link": true, 469 | "name": "Copy of tutorial_5_neuromorphic_datasets.ipynb", 470 | "provenance": [] 471 | }, 472 | "kernelspec": { 473 | "display_name": "Python 3 (ipykernel)", 474 | "language": "python", 475 | "name": "python3" 476 | }, 477 | "language_info": { 478 | "codemirror_mode": { 479 | "name": "ipython", 480 | "version": 3 481 | }, 482 | "file_extension": ".py", 483 | "mimetype": "text/x-python", 484 | "name": "python", 485 | "nbconvert_exporter": "python", 486 | "pygments_lexer": "ipython3", 487 | "version": "3.8.11" 488 | } 489 | }, 490 | "nbformat": 4, 491 | "nbformat_minor": 5 492 | } 493 | --------------------------------------------------------------------------------