├── .gitignore ├── README.md ├── dvs ├── Net.py ├── bnn.py ├── conf.py ├── dataloader.py ├── earlystopping.py ├── functions.py ├── run.py ├── test_acc.py ├── tha.py └── train.py ├── figs └── temporal_code.png ├── fmnist ├── Net.py ├── bnn.py ├── conf.py ├── dataloader.py ├── earlystopping.py ├── functions.py ├── run.py ├── test_acc.py ├── tha.py └── train.py ├── mnist ├── Net.py ├── bnn.py ├── conf.py ├── dataloader.py ├── earlystopping.py ├── functions.py ├── run.py ├── test_acc.py ├── tha.py └── train.py ├── requirements.txt ├── shd ├── Net.py ├── bnn.py ├── conf.py ├── dataloader.py ├── earlystopping.py ├── functions.py ├── run.py ├── test_acc.py ├── tha.py └── train.py └── temporal └── bounded_homeostasis.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | instructions.txt 2 | __pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bounded Homeostasis in Binarized Spiking Neural Networks 2 | 3 | 18 | 19 | ## Requirements 20 | A working `Python` (≥3.6) interpreter and the `pip` package manager. All required libraries and packages can be installed using `pip install -r requirements.txt`. To avoid potential package conflicts, the use of a `conda` environment is recommended. The following commands can be used to create and activate a separate `conda` environment, clone this repository, and to install all dependencies: 21 | 22 | ``` 23 | conda create -n snn-tha python=3.8 24 | conda activate snn-tha 25 | git clone https://github.com/jeshraghian/snn-tha.git 26 | cd snn-tha 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ## Code Execution 31 | To execute code, `cd` into one of four dataset directories, and then run `python run.py`. 32 | 33 | ## Hyperparameter Tuning 34 | * In each directory, `conf.py` defines all configuration parameters and hyperparameters for each dataset. 
The default parameters in this repo are identical to those reported in the corresponding paper for the binarized case with bounded homeostasis. 35 | * Binarization is controlled by the `'binarize'` flag in `conf.py`. It defaults to `'binarize' : True`, and setting `'binarize' : False` trains the full-precision network instead. For optimized parameters, follow the values reported in the paper. 36 | 37 | 38 | # Temporal Coding 39 | Section 4 of the paper demonstrates bounded homeostasis (with threshold annealing as the warm-up technique) on a spike-timing task. A fully connected network with structure 100-1000-1 receives a Poisson spike train at its input. The output neuron is trained to spike at a target time by linearly ramping up its membrane potential over time, using a mean square error loss at each time step: 40 | 41 | 42 | 43 | The animated versions of the above figures are provided below and can be reproduced with the corresponding notebook. 44 | 45 | ## Animations 46 | 47 | ### High Precision Weights, Normalized Threshold 48 | 49 | This is the optimal baseline. It shows that the task is reasonably straightforward to learn with high-precision weights. 50 | 51 | https://user-images.githubusercontent.com/40262130/150855093-4cdaa55b-7cad-4d5a-b5fa-9e482c6fe07e.mp4 52 | 53 | ### Binarized Weights, Normalized Threshold 54 | The results become significantly unstable when the weights are binarized. 55 | 56 | https://user-images.githubusercontent.com/40262130/150855727-9ccfcca2-8b48-48cc-b5df-0d17f367968c.mp4 57 | 58 | A moving average over training iterations is applied in an attempt to clean up the above plot, but the results remain erratic: 59 | 60 | https://user-images.githubusercontent.com/40262130/150855822-02d9177c-e08f-48f4-8753-d5c937e49c00.mp4 61 | 62 | ### Binarized Weights, Large Threshold 63 | Increasing the threshold of all neurons gives the state-space a higher dynamic range. However, a threshold that is too high leads to the dead neuron problem: the membrane potential never reaches threshold, so the neuron stops spiking. The animation below shows how spiking activity is suppressed. The flat membrane potential is purely a result of the bias.
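To see why a large threshold silences neurons, consider a leaky neuron driven by a constant input current $I$. Without spikes, its membrane potential settles at $U^* = I/(1-\beta)$, so any threshold above $U^*$ is never crossed. The sketch below is a simplified pure-Python model, not code from this repo. It assumes the reset-by-subtraction update $U[t] = \beta U[t-1] + I - S[t-1]\,\theta$, which is the default reset mechanism of `snn.Leaky`, and `simulate_lif` is a hypothetical helper:

```python
def simulate_lif(current, beta, threshold, num_steps):
    """Simplified leaky integrate-and-fire neuron with reset-by-subtraction.

    Hypothetical helper for illustration. Returns the final membrane
    potential and the list of emitted spikes (0 or 1 per step).
    """
    mem, spk = 0.0, 0
    spikes = []
    for _ in range(num_steps):
        # Leak, integrate the input, and subtract the threshold if we spiked last step
        mem = beta * mem + current - spk * threshold
        # Fire when the membrane potential exceeds the threshold
        spk = 1 if mem > threshold else 0
        spikes.append(spk)
    return mem, spikes


beta, current = 0.9, 0.5
steady_state = current / (1 - beta)  # U* = I / (1 - beta), about 5.0 here

for threshold in (1.0, 6.0):
    mem, spikes = simulate_lif(current, beta, threshold, num_steps=100)
    print(f"threshold={threshold}: {sum(spikes)} spikes, final mem={mem:.3f}")
```

With `threshold=1.0` the neuron fires regularly. With `threshold=6.0` it never fires, and its membrane potential flattens just below $U^* = 5$. This mirrors the dead-neuron behaviour described above.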
64 | 65 | https://user-images.githubusercontent.com/40262130/150856229-0a3ae7ce-5670-4545-b13c-06dd3ca992f3.mp4 66 | 67 | ### Binarized Weights, Threshold Annealing 68 | Now threshold annealing is applied. The threshold starts from a lower value and is raised gradually during training, so the neuronal state-space evolves while spiking activity is maintained. This avoids both the dead neuron problem of the large threshold case and the instability/memory leakage of the normalized threshold case. 69 | 70 | https://user-images.githubusercontent.com/40262130/150856483-f53f2156-4348-46da-9c0f-5f05f31cf677.mp4 71 | 72 | This looks far more functional than all of the previous binarized cases. 73 | A moving average smooths out the impact of sudden reset dynamics. Although not as precise as the high precision case, the binarized SNN continues to learn despite the excessively high final threshold. 74 | 75 | https://user-images.githubusercontent.com/40262130/150856726-aedb1d08-fe61-4b32-a3aa-6dcc9c76311a.mp4 76 | 77 | -------------------------------------------------------------------------------- /dvs/Net.py: -------------------------------------------------------------------------------- 1 | # snntorch 2 | import snntorch as snn 3 | from snntorch import surrogate 4 | 5 | # torch 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # local 11 | from bnn import * 12 | 13 | 14 | class Net(nn.Module): 15 | def __init__(self, config): 16 | super().__init__() 17 | 18 | self.thr1 = config['threshold1'] 19 | self.thr2 = config['threshold2'] 20 | self.thr3 = config['threshold3'] 21 | slope = config['slope'] 22 | beta = config['beta'] 23 | self.num_steps = config['num_steps'] 24 | self.batch_norm = config['batch_norm'] 25 | p1 = config['dropout1'] 26 | self.binarize = config['binarize'] 27 | 28 | 29 | spike_grad = surrogate.fast_sigmoid(slope) 30 | # Initialize layers with spike operator 31 | self.bconv1 = BinaryConv2d(2, 16, 5, bias=False) 32 | self.conv1 = nn.Conv2d(2, 16, 5, bias=False) 33 | self.conv1_bn =
nn.BatchNorm2d(16) 34 | self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad) 35 | self.bconv2 = BinaryConv2d(16, 32, 5, bias=False) 36 | self.conv2 = nn.Conv2d(16, 32, 5, bias=False) 37 | self.conv2_bn = nn.BatchNorm2d(32) 38 | self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad) 39 | self.bfc1 = BinaryLinear(32 * 5 * 5, 11) 40 | self.fc1 = nn.Linear(32 * 5 * 5, 11) 41 | self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad) 42 | self.dropout = nn.Dropout(p1) 43 | 44 | 45 | def forward(self, x): 46 | 47 | # Initialize hidden states and outputs at t=0 48 | mem1 = self.lif1.init_leaky() 49 | mem2 = self.lif2.init_leaky() 50 | mem3 = self.lif3.init_leaky() 51 | 52 | # Record the final layer 53 | spk3_rec = [] 54 | mem3_rec = [] 55 | 56 | # Binarization 57 | 58 | if self.binarize: 59 | 60 | for step in range(x.size(0)): 61 | 62 | # fc1weight = self.fc1.weight.data 63 | cur1 = F.avg_pool2d(self.bconv1(x[step]), 2) 64 | if self.batch_norm: 65 | cur1 = self.conv1_bn(cur1) 66 | spk1, mem1 = self.lif1(cur1, mem1) 67 | cur2 = F.avg_pool2d(self.bconv2(spk1), 2) 68 | if self.batch_norm: 69 | cur2 = self.conv2_bn(cur2) 70 | spk2, mem2 = self.lif2(cur2, mem2) 71 | 72 | cur3 = self.dropout(self.bfc1(spk2.flatten(1))) 73 | spk3, mem3 = self.lif3(cur3, mem3) 74 | 75 | spk3_rec.append(spk3) 76 | mem3_rec.append(mem3) 77 | 78 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0) 79 | 80 | # Full Precision 81 | 82 | else: 83 | 84 | for step in range(x.size(0)): 85 | # fc1weight = self.fc1.weight.data 86 | cur1 = F.avg_pool2d(self.conv1(x[step]), 2) 87 | if self.batch_norm: 88 | cur1 = self.conv1_bn(cur1) 89 | spk1, mem1 = self.lif1(cur1, mem1) 90 | cur2 = F.avg_pool2d(self.conv2(spk1), 2) 91 | if self.batch_norm: 92 | cur2 = self.conv2_bn(cur2) 93 | spk2, mem2 = self.lif2(cur2, mem2) 94 | 95 | cur3 = self.dropout(self.fc1(spk2.flatten(1))) 96 | spk3, mem3 = self.lif3(cur3, mem3) 97 | 98 | 99 | spk3_rec.append(spk3) 
100 | mem3_rec.append(mem3) 101 | 102 | 103 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0) 104 | -------------------------------------------------------------------------------- /dvs/bnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from functions import * 8 | 9 | 10 | class BinaryTanh(nn.Module): 11 | def __init__(self): 12 | super(BinaryTanh, self).__init__() 13 | self.hardtanh = nn.Hardtanh() 14 | 15 | def forward(self, input): 16 | output = self.hardtanh(input) 17 | output = binarize(output) 18 | return output 19 | 20 | 21 | class BinaryLinear(nn.Linear): 22 | 23 | def forward(self, input): 24 | binary_weight = binarize(self.weight) 25 | if self.bias is None: 26 | return F.linear(input, binary_weight) 27 | else: 28 | return F.linear(input, binary_weight, self.bias) 29 | 30 | def reset_parameters(self): 31 | # Glorot initialization 32 | out_features, in_features = self.weight.size()  # nn.Linear stores weight as (out_features, in_features) 33 | stdv = math.sqrt(1.5 / (in_features + out_features)) 34 | self.weight.data.uniform_(-stdv, stdv) 35 | if self.bias is not None: 36 | self.bias.data.zero_() 37 | 38 | self.weight.lr_scale = 1. / stdv 39 | 40 | 41 | 42 | class BinaryConv2d(nn.Conv2d): 43 | 44 | def forward(self, input): 45 | bw = binarize(self.weight) 46 | return F.conv2d(input, bw, self.bias, self.stride, 47 | self.padding, self.dilation, self.groups) 48 | 49 | def reset_parameters(self): 50 | # Glorot initialization 51 | in_features = self.in_channels 52 | out_features = self.out_channels 53 | for k in self.kernel_size: 54 | in_features *= k 55 | out_features *= k 56 | stdv = math.sqrt(1.5 / (in_features + out_features)) 57 | self.weight.data.uniform_(-stdv, stdv) 58 | if self.bias is not None: 59 | self.bias.data.zero_() 60 | 61 | self.weight.lr_scale = 1.
/ stdv -------------------------------------------------------------------------------- /dvs/conf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | config = { 4 | 'exp_name' : 'dvs_tha', 5 | 'num_trials' : 5, 6 | 'num_epochs' : 500, 7 | 'binarize' : True, 8 | 'data_dir' : "/home/dvs", 9 | 'batch_size' : 8, 10 | 'seed' : 0, 11 | 'num_workers' : 0, 12 | 13 | # final run sweeps 14 | 'save_csv' : True, 15 | 'save_model' : True, 16 | 'early_stopping': True, 17 | 'patience': 100, 18 | 19 | # final params 20 | 'grad_clip' : True, 21 | 'weight_clip' : False, 22 | 'batch_norm' : False, 23 | 'dropout1' : 0.43, 24 | 'beta' : 0.9297, 25 | 'lr' : 1.765e-3, 26 | 'slope': 0.24, 27 | 28 | # threshold annealing. note: thr_final = threshold + thr_final 29 | 'threshold1' : 10.4, 30 | 'alpha_thr1' : 0.00333, 31 | 'thr_final1' : 1.7565, 32 | 33 | 'threshold2' : 16.62, 34 | 'alpha_thr2' : 0.0061, 35 | 'thr_final2' : 2.457, 36 | 37 | 'threshold3' : 6.81, 38 | 'alpha_thr3' : 0.173, 39 | 'thr_final3' : 9.655, 40 | 41 | # fixed params 42 | 'num_steps' : 100, 43 | 'correct_rate': 0.8, 44 | 'incorrect_rate' : 0.2, 45 | 'betas' : (0.9, 0.999), 46 | 't_0' : 735, 47 | 'eta_min' : 0, 48 | 'df_lr' : True, # return learning rate. 
Useful for scheduling 49 | 50 | 51 | } 52 | 53 | def optim_func(net, config): 54 | optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config['betas']) 55 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['t_0'], eta_min=config['eta_min'], last_epoch=-1) 56 | return optimizer, scheduler -------------------------------------------------------------------------------- /dvs/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | from torchvision import datasets, transforms 5 | 6 | from snntorch.spikevision import spikedata 7 | 8 | def load_data(config): 9 | data_dir = config['data_dir'] 10 | 11 | trainset = spikedata.DVSGesture(data_dir, train=True, num_steps=100, dt=5000, ds=4) 12 | testset = spikedata.DVSGesture(data_dir, train=False, num_steps=360, dt=5000, ds=4) 13 | 14 | return trainset, testset -------------------------------------------------------------------------------- /dvs/earlystopping.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EarlyStopping_acc: 5 | """Early stops the training if test acc doesn't improve after a given patience.""" 6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 7 | """ 8 | Args: 9 | patience (int): How many epochs to wait after the last time test accuracy improved. 10 | Default: 7 11 | verbose (bool): If True, prints a message for each test accuracy improvement. 12 | Default: False 13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 14 | Default: 0 15 | path (str): Path for the checkpoint to be saved to. 16 | Default: 'checkpoint.pt' 17 | trace_func (function): trace print function.
18 | Default: print 19 | """ 20 | self.patience = patience 21 | self.verbose = verbose 22 | self.counter = 0 23 | self.best_score = None 24 | self.early_stop = False 25 | self.test_loss_min = 0 26 | self.delta = delta 27 | self.path = path 28 | self.trace_func = trace_func 29 | def __call__(self, test_loss, model): 30 | 31 | score = test_loss 32 | 33 | if self.best_score is None: 34 | self.best_score = score 35 | self.save_checkpoint(test_loss, model) 36 | elif score <= self.best_score + self.delta: 37 | self.counter += 1 38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') 39 | if self.counter >= self.patience: 40 | self.early_stop = True 41 | self.counter = 0 42 | else: 43 | self.best_score = score 44 | self.save_checkpoint(test_loss, model) 45 | self.counter = 0 46 | 47 | def save_checkpoint(self, test_loss, model): 48 | '''Saves model when test acc increases.''' 49 | if self.verbose: 50 | self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). 
Saving model ...') 51 | torch.save(model.state_dict(), self.path) 52 | self.test_loss_min = test_loss -------------------------------------------------------------------------------- /dvs/functions.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Function 5 | 6 | 7 | class BinarizeF(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, input): 11 | output = input.new(input.size()) 12 | output[input >= 0] = 1 13 | output[input < 0] = -1 14 | return output 15 | 16 | @staticmethod 17 | def backward(ctx, grad_output): 18 | grad_input = grad_output.clone() 19 | return grad_input 20 | 21 | # aliases 22 | binarize = BinarizeF.apply -------------------------------------------------------------------------------- /dvs/run.py: -------------------------------------------------------------------------------- 1 | # snntorch 2 | import snntorch as snn 3 | from snntorch import spikegen 4 | from snntorch import surrogate 5 | 6 | # torch 7 | import torch 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 10 | 11 | 12 | # misc 13 | import numpy as np 14 | import pandas as pd 15 | import time 16 | import logging 17 | 18 | # local imports 19 | from dataloader import * 20 | from Net import * 21 | from test_acc import * 22 | from train import * 23 | from earlystopping import * 24 | from conf import * 25 | 26 | #################################################### 27 | ## Notes: modify config in conf to reparameterize ## 28 | #################################################### 29 | 30 | file_name = config['exp_name'] 31 | 32 | ### to address conditional parameters, s.t. 
thr_final > threshold 33 | config['thr_final1'] = config['thr_final1'] + config['threshold1'] 34 | config['thr_final2'] = config['thr_final2'] + config['threshold2'] 35 | config['thr_final3'] = config['thr_final3'] + config['threshold3'] 36 | 37 | threshold1 = config['threshold1'] 38 | threshold2 = config['threshold2'] 39 | threshold3 = config['threshold3'] 40 | 41 | for trial in range(config['num_trials']): 42 | 43 | # file names 44 | SAVE_CSV = config['save_csv'] 45 | SAVE_MODEL = config['save_model'] 46 | csv_name = file_name + '_t' + str(trial) + '.csv' 47 | log_name = file_name + '_t' + str(trial) + '.log' 48 | model_name = file_name + '_t' + str(trial) + '.pt' 49 | num_epochs = config['num_epochs'] 50 | torch.manual_seed(config['seed']) 51 | 52 | config['threshold1'] = threshold1 53 | config['threshold2'] = threshold2 54 | config['threshold3'] = threshold3 55 | 56 | # dataframes 57 | df_train_loss = pd.DataFrame() 58 | df_test_acc = pd.DataFrame(columns=['epoch', 'test_acc', 'train_time']) 59 | df_lr = pd.DataFrame() 60 | 61 | # initialize network 62 | net = Net(config) 63 | device = "cpu" 64 | if torch.cuda.is_available(): 65 | device = "cuda:0" 66 | if torch.cuda.device_count() > 1: 67 | net = nn.DataParallel(net) 68 | net.to(device) 69 | 70 | # net params 71 | criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate']) 72 | optimizer, scheduler = optim_func(net, config) 73 | 74 | # early stopping condition 75 | if config['early_stopping']: 76 | early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name) 77 | early_stopping.early_stop = False 78 | early_stopping.best_score = None 79 | 80 | # load data 81 | trainset, testset = load_data(config) 82 | config['dataset_length'] = len(trainset) 83 | trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True) 84 | testloader = DataLoader(testset, batch_size=int(config["batch_size"]), shuffle=False) 85 | 86 
| print(f"=======Trial: {trial}=======") 87 | 88 | for epoch in range(num_epochs): 89 | 90 | # train 91 | start_time = time.time() 92 | loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device) 93 | epoch_time = time.time() - start_time 94 | 95 | # test 96 | test_acc = test_accuracy(config, net, testloader, device) 97 | print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}') 98 | 99 | if config['df_lr']: 100 | df_lr = pd.concat([df_lr, pd.DataFrame(lr_list)]) 101 | df_train_loss = pd.concat([df_train_loss, pd.DataFrame(loss_list)]) 102 | test_data = pd.DataFrame([[epoch, test_acc, epoch_time]], columns = ['epoch', 'test_acc', 'train_time']) 103 | df_test_acc = pd.concat([df_test_acc, test_data]) 104 | 105 | if SAVE_CSV: 106 | df_train_loss.to_csv('loss_' + csv_name, index=False) 107 | df_test_acc.to_csv('acc_' + csv_name, index=False) 108 | if config['df_lr']: 109 | df_lr.to_csv('lr_' + csv_name, index=False) 110 | 111 | if config['early_stopping']: 112 | early_stopping(test_acc, net) 113 | 114 | if early_stopping.early_stop: 115 | print("Early stopping") 116 | early_stopping.early_stop = False 117 | early_stopping.best_score = None 118 | break 119 | 120 | if SAVE_MODEL and not config['early_stopping']: 121 | torch.save(net.state_dict(), model_name) 122 | 123 | # net.load_state_dict(torch.load(model_name)) 124 | -------------------------------------------------------------------------------- /dvs/test_acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import snntorch as snn 3 | from snntorch import functional as SF 4 | 5 | 6 | def test_accuracy(config, net, testloader, device="cpu"): 7 | 8 | 9 | correct = 0 10 | total = 0 11 | with torch.no_grad(): 12 | net.eval() 13 | for data in testloader: 14 | images, labels = data 15 | images, labels = images.to(device), labels.to(device) # .permute(1, 0, 2, 3, 4) 16 | 17 | outputs, _ = net(images.permute(1, 0, 2, 3, 
4)) 18 | accuracy = SF.accuracy_rate(outputs, labels) 19 | 20 | total += labels.size(0) 21 | correct += accuracy * labels.size(0) 22 | 23 | return 100 * correct / total -------------------------------------------------------------------------------- /dvs/tha.py: -------------------------------------------------------------------------------- 1 | # exp relaxation implementation of THA based on Eq (4) 2 | 3 | def thr_annealing(config, network): 4 | alpha_thr1 = config['alpha_thr1'] 5 | alpha_thr2 = config['alpha_thr2'] 6 | alpha_thr3 = config['alpha_thr3'] 7 | 8 | thr_final1 = config['thr_final1'] 9 | thr_final2 = config['thr_final2'] 10 | thr_final3 = config['thr_final3'] 11 | network = getattr(network, 'module', network)  # unwrap nn.DataParallel, which exposes lif1..lif3 only via .module 12 | network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1 13 | network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2 14 | network.lif3.threshold += (thr_final3 - network.lif3.threshold) * alpha_thr3 15 | 16 | return -------------------------------------------------------------------------------- /dvs/train.py: -------------------------------------------------------------------------------- 1 | # snntorch 2 | import snntorch as snn 3 | from snntorch import spikegen 4 | from snntorch import surrogate 5 | from snntorch import functional as SF 6 | 7 | # torch 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | from torchvision import datasets, transforms 12 | import torch.nn.functional as F 13 | from torch.optim.lr_scheduler import StepLR 14 | 15 | # misc 16 | import os 17 | import numpy as np 18 | import math 19 | import itertools 20 | import matplotlib.pyplot as plt 21 | import pandas as pd 22 | import shutil 23 | import time 24 | 25 | from dataloader import * 26 | from test_acc import * 27 | from tha import * 28 | 29 | 30 | def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device): 31 | 32 | net.train() 33 | loss_accum = [] 34 | lr_accum = [] 35 | 36 |
# TRAIN 37 | for data, labels in trainloader: 38 | data, labels = data.to(device), labels.to(device) 39 | spk_rec2, _ = net(data.permute(1, 0, 2, 3, 4)) 40 | loss = criterion(spk_rec2, labels.long()) 41 | optimizer.zero_grad() 42 | loss.backward() 43 | 44 | if config['grad_clip']: 45 | nn.utils.clip_grad_norm_(net.parameters(), 1.0) 46 | if config['weight_clip']: 47 | with torch.no_grad(): 48 | for param in net.parameters(): 49 | param.clamp_(-1, 1) 50 | 51 | optimizer.step() 52 | scheduler.step() 53 | thr_annealing(config, net) 54 | 55 | 56 | loss_accum.append(loss.item()/config['num_steps']) 57 | lr_accum.append(optimizer.param_groups[0]["lr"]) 58 | 59 | 60 | return loss_accum, lr_accum 61 | -------------------------------------------------------------------------------- /figs/temporal_code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeshraghian/snn-tha/f9c0b516a67a4be508b908176992b30894a18af9/figs/temporal_code.png -------------------------------------------------------------------------------- /fmnist/Net.py: -------------------------------------------------------------------------------- 1 | # snntorch 2 | import snntorch as snn 3 | from snntorch import surrogate 4 | 5 | # torch 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # local 11 | from bnn import * 12 | 13 | class Net(nn.Module): 14 | def __init__(self, config): 15 | super().__init__() 16 | 17 | self.thr1 = config['threshold1'] 18 | self.thr2 = config['threshold2'] 19 | self.thr3 = config['threshold3'] 20 | slope = config['slope'] 21 | beta = config['beta'] 22 | self.num_steps = config['num_steps'] 23 | self.batch_norm = config['batch_norm'] 24 | p1 = config['dropout1'] 25 | self.binarize = config['binarize'] 26 | 27 | spike_grad = surrogate.fast_sigmoid(slope) 28 | self.bconv1 = BinaryConv2d(1, 16, 5, bias=False) 29 | self.conv1 = nn.Conv2d(1, 16, 5, bias=False) 30 | self.conv1_bn = 
nn.BatchNorm2d(16) 31 | self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad) 32 | self.bconv2 = BinaryConv2d(16, 64, 5, bias=False) 33 | self.conv2 = nn.Conv2d(16, 64, 5, bias=False) 34 | self.conv2_bn = nn.BatchNorm2d(64) 35 | self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad) 36 | self.bfc1 = BinaryLinear(64 * 4 * 4, 10) 37 | self.fc1 = nn.Linear(64 * 4 * 4, 10) 38 | self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad) 39 | self.dropout = nn.Dropout(p1) 40 | 41 | def forward(self, x): 42 | 43 | # Initialize hidden states and outputs at t=0 44 | mem1 = self.lif1.init_leaky() 45 | mem2 = self.lif2.init_leaky() 46 | mem3 = self.lif3.init_leaky() 47 | 48 | # Record the final layer 49 | spk3_rec = [] 50 | mem3_rec = [] 51 | 52 | # Binarization 53 | if self.binarize: 54 | 55 | for step in range(self.num_steps): 56 | cur1 = F.avg_pool2d(self.bconv1(x), 2) 57 | if self.batch_norm: 58 | cur1 = self.conv1_bn(cur1) 59 | spk1, mem1 = self.lif1(cur1, mem1) 60 | cur2 = F.avg_pool2d(self.bconv2(spk1), 2) 61 | if self.batch_norm: 62 | cur2 = self.conv2_bn(cur2) 63 | spk2, mem2 = self.lif2(cur2, mem2) 64 | cur3 = self.dropout(self.bfc1(spk2.flatten(1))) 65 | spk3, mem3 = self.lif3(cur3, mem3) 66 | 67 | spk3_rec.append(spk3) 68 | mem3_rec.append(mem3) 69 | 70 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0) 71 | 72 | # Full Precision 73 | else: 74 | 75 | for step in range(self.num_steps): 76 | 77 | cur1 = F.avg_pool2d(self.conv1(x), 2) 78 | if self.batch_norm: 79 | cur1 = self.conv1_bn(cur1) 80 | spk1, mem1 = self.lif1(cur1, mem1) 81 | cur2 = F.avg_pool2d(self.conv2(spk1), 2) 82 | if self.batch_norm: 83 | cur2 = self.conv2_bn(cur2) 84 | spk2, mem2 = self.lif2(cur2, mem2) 85 | cur3 = self.dropout(self.fc1(spk2.flatten(1))) 86 | spk3, mem3 = self.lif3(cur3, mem3) 87 | 88 | spk3_rec.append(spk3) 89 | mem3_rec.append(mem3) 90 | 91 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0) 92 | 
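Both `Net` classes above construct each `snn.Leaky` layer with an initial threshold taken from `threshold1` to `threshold3`. `dvs/tha.py` then moves every threshold toward its final value after each optimizer step, using `thr += (thr_final - thr) * alpha`. Unrolling that recursion gives a closed form: the gap to `thr_final` shrinks by a factor of $(1-\alpha)$ per update, so $\theta_n = \theta_{\text{final}} - (\theta_{\text{final}} - \theta_0)(1-\alpha)^n$.

The sketch below checks this numerically with the layer-1 values from `dvs/conf.py`. Note that `run.py` adds `threshold1` to `thr_final1` before training, so the config value is an offset. The helper names `anneal_threshold`, `closed_form` and `updates_to_close` are hypothetical and are not part of the repo:

```python
import math


def anneal_threshold(thr0, thr_final, alpha, num_updates):
    """Repeat the tha.py update thr += (thr_final - thr) * alpha; return every value."""
    thr = thr0
    history = [thr]
    for _ in range(num_updates):
        thr += (thr_final - thr) * alpha
        history.append(thr)
    return history


def closed_form(thr0, thr_final, alpha, n):
    """Threshold after n updates: the remaining gap decays geometrically by (1 - alpha)."""
    return thr_final - (thr_final - thr0) * (1 - alpha) ** n


def updates_to_close(fraction, alpha):
    """Smallest number of updates that covers `fraction` of the initial gap."""
    return math.ceil(math.log(1 - fraction) / math.log(1 - alpha))


# Layer-1 values from dvs/conf.py; run.py turns the offset into an absolute target.
thr0 = 10.4
thr_final = 10.4 + 1.7565
alpha = 0.00333

history = anneal_threshold(thr0, thr_final, alpha, num_updates=2000)
for n in (0, 100, 1000, 2000):
    print(n, round(history[n], 4), round(closed_form(thr0, thr_final, alpha, n), 4))
print("updates to close 95% of the gap:", updates_to_close(0.95, alpha))
```

`train.py` calls `thr_annealing` once per batch, so `alpha` sets the warm-up length in batches rather than in epochs. A smaller `alpha` stretches the annealing over more training iterations.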
-------------------------------------------------------------------------------- /fmnist/bnn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from functions import * 8 | 9 | 10 | class BinaryTanh(nn.Module): 11 | def __init__(self): 12 | super(BinaryTanh, self).__init__() 13 | self.hardtanh = nn.Hardtanh() 14 | 15 | def forward(self, input): 16 | output = self.hardtanh(input) 17 | output = binarize(output) 18 | return output 19 | 20 | 21 | class BinaryLinear(nn.Linear): 22 | 23 | def forward(self, input): 24 | binary_weight = binarize(self.weight) 25 | if self.bias is None: 26 | return F.linear(input, binary_weight) 27 | else: 28 | return F.linear(input, binary_weight, self.bias) 29 | 30 | def reset_parameters(self): 31 | # Glorot initialization 32 | out_features, in_features = self.weight.size()  # nn.Linear stores weight as (out_features, in_features) 33 | stdv = math.sqrt(1.5 / (in_features + out_features)) 34 | self.weight.data.uniform_(-stdv, stdv) 35 | if self.bias is not None: 36 | self.bias.data.zero_() 37 | 38 | self.weight.lr_scale = 1. / stdv 39 | 40 | 41 | 42 | class BinaryConv2d(nn.Conv2d): 43 | 44 | def forward(self, input): 45 | bw = binarize(self.weight) 46 | return F.conv2d(input, bw, self.bias, self.stride, 47 | self.padding, self.dilation, self.groups) 48 | 49 | def reset_parameters(self): 50 | # Glorot initialization 51 | in_features = self.in_channels 52 | out_features = self.out_channels 53 | for k in self.kernel_size: 54 | in_features *= k 55 | out_features *= k 56 | stdv = math.sqrt(1.5 / (in_features + out_features)) 57 | self.weight.data.uniform_(-stdv, stdv) 58 | if self.bias is not None: 59 | self.bias.data.zero_() 60 | 61 | self.weight.lr_scale = 1.
/ stdv -------------------------------------------------------------------------------- /fmnist/conf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | config = { 4 | 'exp_name' : 'fmnist_tha', 5 | 'num_trials' : 5, 6 | 'num_epochs' : 500, 7 | 'binarize' : True, 8 | 'data_dir' : "~/data/fmnist", 9 | 'batch_size' : 128, 10 | 'seed' : 0, 11 | 'num_workers' : 0, 12 | 13 | # final run sweeps 14 | 'save_csv' : True, 15 | 'save_model' : True, 16 | 'early_stopping': True, 17 | 'patience': 100, 18 | 19 | # final params 20 | 'grad_clip' : False, 21 | 'weight_clip' : False, 22 | 'batch_norm' : True, 23 | 'dropout1' : 0.648, 24 | 'beta' : 0.868, 25 | 'lr' : 8.4e-4, 26 | 'slope': 0.1557, 27 | 'momentum' : 0.855, 28 | 29 | 30 | # threshold annealing. note: thr_final = threshold + thr_final 31 | 'threshold1' : 6.9, 32 | 'alpha_thr1' : 0.0368, 33 | 'thr_final1' : 7.1456, 34 | 35 | 'threshold2' : 10.25, 36 | 'alpha_thr2' : 0.29687, 37 | 'thr_final2' : 12.826, 38 | 39 | 'threshold3' : 17.95, 40 | 'alpha_thr3' : 0.1048, 41 | 'thr_final3' : 9.936668, 42 | 43 | # fixed params 44 | 'num_steps' : 100, 45 | 'correct_rate': 0.8, 46 | 'incorrect_rate' : 0.2, 47 | 't_0' : 4688, 48 | 'eta_min' : 0, 49 | 'df_lr' : True, # return learning rate. 
Useful for scheduling 50 | 51 | 52 | 53 | } 54 | 55 | def optim_func(net, config): 56 | optimizer = torch.optim.SGD(net.parameters(), lr=config["lr"], momentum=config['momentum']) 57 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['t_0'], eta_min=config['eta_min'], last_epoch=-1) 58 | return optimizer, scheduler 59 | -------------------------------------------------------------------------------- /fmnist/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data import DataLoader 4 | from torchvision import datasets, transforms 5 | 6 | def load_data(config): 7 | data_dir = config['data_dir'] 8 | 9 | transform = transforms.Compose([ 10 | transforms.Resize((28, 28)), 11 | transforms.Grayscale(), 12 | transforms.ToTensor(), 13 | transforms.Normalize((0,), (1,))]) 14 | 15 | trainset = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform) 16 | testset = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform) 17 | 18 | return trainset, testset -------------------------------------------------------------------------------- /fmnist/earlystopping.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EarlyStopping_acc: 5 | """Early stops the training if test acc doesn't improve after a given patience.""" 6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): 7 | """ 8 | Args: 9 | patience (int): How many epochs to wait after the last time test accuracy improved. 10 | Default: 7 11 | verbose (bool): If True, prints a message for each test accuracy improvement. 12 | Default: False 13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement. 14 | Default: 0 15 | path (str): Path for the checkpoint to be saved to.
16 |                             Default: 'checkpoint.pt'
17 |             trace_func (function): trace print function.
18 |                             Default: print
19 |         """
20 |         self.patience = patience
21 |         self.verbose = verbose
22 |         self.counter = 0
23 |         self.best_score = None
24 |         self.early_stop = False
25 |         self.test_loss_min = 0
26 |         self.delta = delta
27 |         self.path = path
28 |         self.trace_func = trace_func
29 |     def __call__(self, test_loss, model):
30 |
31 |         score = test_loss
32 |
33 |         if self.best_score is None:
34 |             self.best_score = score
35 |             self.save_checkpoint(test_loss, model)
36 |         elif score <= self.best_score + self.delta:
37 |             self.counter += 1
38 |             self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
39 |             if self.counter >= self.patience:
40 |                 self.early_stop = True
41 |                 self.counter = 0
42 |         else:
43 |             self.best_score = score
44 |             self.save_checkpoint(test_loss, model)
45 |             self.counter = 0
46 |
47 |     def save_checkpoint(self, test_loss, model):
48 |         '''Saves model when test acc increases.'''
49 |         if self.verbose:
50 |             self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). Saving model ...')
51 |         torch.save(model.state_dict(), self.path)
52 |         self.test_loss_min = test_loss
--------------------------------------------------------------------------------
/fmnist/functions.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 |
6 |
7 | class BinarizeF(Function):
8 |
9 |     @staticmethod
10 |     def forward(ctx, input):
11 |         output = input.new(input.size())
12 |         output[input >= 0] = 1
13 |         output[input < 0] = -1
14 |         return output
15 |
16 |     @staticmethod
17 |     def backward(ctx, grad_output):
18 |         grad_input = grad_output.clone()
19 |         return grad_input
20 |
21 | # aliases
22 | binarize = BinarizeF.apply
--------------------------------------------------------------------------------
/fmnist/run.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 |
6 | # torch
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 |
12 | # misc
13 | import numpy as np
14 | import pandas as pd
15 | import time
16 | import logging
17 |
18 | # local imports
19 | from dataloader import *
20 | from Net import *
21 | from test_acc import *
22 | from train import *
23 | from earlystopping import *
24 | from conf import *
25 |
26 | ####################################################
27 | ## Notes: modify config in conf to reparameterize ##
28 | ####################################################
29 |
30 |
31 | file_name = config['exp_name']
32 |
33 | ### to address conditional parameters, s.t. thr_final > threshold
34 | config['thr_final1'] = config['thr_final1'] + config['threshold1']
35 | config['thr_final2'] = config['thr_final2'] + config['threshold2']
36 | config['thr_final3'] = config['thr_final3'] + config['threshold3']
37 |
38 | threshold1 = config['threshold1']
39 | threshold2 = config['threshold2']
40 | threshold3 = config['threshold3']
41 |
42 | for trial in range(config['num_trials']):
43 |
44 |     # file names
45 |     SAVE_CSV = config['save_csv']
46 |     SAVE_MODEL = config['save_model']
47 |     csv_name = file_name + '_t' + str(trial) + '.csv'
48 |     log_name = file_name + '_t' + str(trial) + '.log'
49 |     model_name = file_name + '_t' + str(trial) + '.pt'
50 |     num_epochs = config['num_epochs']
51 |     torch.manual_seed(config['seed'])
52 |
53 |     config['threshold1'] = threshold1
54 |     config['threshold2'] = threshold2
55 |     config['threshold3'] = threshold3
56 |
57 |
58 |     # dataframes
59 |     df_train_loss = pd.DataFrame()
60 |     df_test_acc = pd.DataFrame(columns=['epoch', 'test_acc', 'train_time'])
61 |     df_lr = pd.DataFrame()
62 |
63 |
64 |     # initialize network
65 |     net = Net(config)
66 |     device = "cpu"
67 |     if torch.cuda.is_available():
68 |         device = "cuda:0"
69 |         if torch.cuda.device_count() > 1:
70 |             net = nn.DataParallel(net)
71 |     net.to(device)
72 |
73 |     # net params
74 |     criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
75 |     optimizer, scheduler = optim_func(net, config)
76 |
77 |     # early stopping condition
78 |     if config['early_stopping']:
79 |         early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)
80 |         early_stopping.early_stop = False
81 |         early_stopping.best_score = None
82 |
83 |     # load data
84 |     trainset, testset = load_data(config)
85 |     config['dataset_length'] = len(trainset)
86 |     trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True)
87 |     testloader = DataLoader(testset, batch_size=int(config["batch_size"]), shuffle=False)
88 |
89 |     print(f"=======Trial: {trial}=======")
90 |
91 |     for epoch in range(num_epochs):
92 |
93 |         # train
94 |         start_time = time.time()
95 |         loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device)
96 |         epoch_time = time.time() - start_time
97 |
98 |         # test
99 |         test_acc = test_accuracy(config, net, testloader, device)
100 |         print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}')
101 |
102 |         if config['df_lr']:
103 |             df_lr = pd.concat([df_lr, pd.DataFrame(lr_list)])
104 |         df_train_loss = pd.concat([df_train_loss, pd.DataFrame(loss_list)])
105 |         test_data = pd.DataFrame([[epoch, test_acc, epoch_time]], columns = ['epoch', 'test_acc', 'train_time'])
106 |         df_test_acc = pd.concat([df_test_acc, test_data])
107 |
108 |         if SAVE_CSV:
109 |             df_train_loss.to_csv('loss_' + csv_name, index=False)
110 |             df_test_acc.to_csv('acc_' + csv_name, index=False)
111 |             if config['df_lr']:
112 |                 df_lr.to_csv('lr_' + csv_name, index=False)
113 |
114 |         if config['early_stopping']:
115 |             early_stopping(test_acc, net)
116 |
117 |             if early_stopping.early_stop:
118 |                 print("Early stopping")
119 |                 early_stopping.early_stop = False
120 |                 early_stopping.best_score = None
121 |                 break
122 |
123 |     if SAVE_MODEL and not config['early_stopping']:
124 |         torch.save(net.state_dict(), model_name)
125 |
126 |     # net.load_state_dict(torch.load(model_name))
127 |
--------------------------------------------------------------------------------
/fmnist/test_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import snntorch as snn
3 | from snntorch import functional as SF
4 |
5 |
6 | def test_accuracy(config, net, testloader, device="cpu"):
7 |
8 |
9 |     correct = 0
10 |     total = 0
11 |     with torch.no_grad():
12 |         net.eval()
13 |         for data in testloader:
14 |             images, labels = data
15 |             images, labels = images.to(device), labels.to(device)
16 |
17 |             outputs, _ = net(images)
18 |             accuracy = SF.accuracy_rate(outputs, labels)
19 |
20 |             total += labels.size(0)
21 |             correct += accuracy * labels.size(0)
22 |
23 |     return 100 * correct / total
--------------------------------------------------------------------------------
/fmnist/tha.py:
--------------------------------------------------------------------------------
1 | # exp relaxation implementation of THA based on Eq (4)
2 |
3 | def thr_annealing(config, network):
4 |     alpha_thr1 = config['alpha_thr1']
5 |     alpha_thr2 = config['alpha_thr2']
6 |     alpha_thr3 = config['alpha_thr3']
7 |
8 |     thr_final1 = config['thr_final1']
9 |     thr_final2 = config['thr_final2']
10 |     thr_final3 = config['thr_final3']
11 |
12 |     network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1
13 |     network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2
14 |     network.lif3.threshold += (thr_final3 - network.lif3.threshold) * alpha_thr3
15 |
16 |     return
--------------------------------------------------------------------------------
/fmnist/train.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 | from snntorch import functional as SF
6 |
7 | # torch
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | from torchvision import datasets, transforms
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import StepLR
14 |
15 | # misc
16 | import os
17 | import numpy as np
18 | import math
19 | import itertools
20 | import matplotlib.pyplot as plt
21 | import pandas as pd
22 | import shutil
23 | import time
24 |
25 | from dataloader import *
26 | from test import *
27 | from test_acc import *
28 | from tha import *
29 |
30 |
31 | def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device):
32 |
33 |     net.train()
34 |     loss_accum = []
35 |     lr_accum = []
36 |
37 |     # TRAIN
38 |     for data, labels in trainloader:
39 |         data, labels = data.to(device), labels.to(device)
40 |
41 |         spk_rec2, _ = net(data)
42 |         loss = criterion(spk_rec2, labels)
43 |         optimizer.zero_grad()
44 |         loss.backward()
45 |
46 |         if config['grad_clip']:
47 |             nn.utils.clip_grad_norm_(net.parameters(), 1.0)
48 |         if config['weight_clip']:
49 |             with torch.no_grad():
50 |                 for param in net.parameters():
51 |                     param.clamp_(-1, 1)
52 |
53 |         optimizer.step()
54 |         scheduler.step()
55 |         thr_annealing(config, net)
56 |
57 |
58 |         loss_accum.append(loss.item()/config['num_steps'])
59 |         lr_accum.append(optimizer.param_groups[0]["lr"])
60 |
61 |
62 |     return loss_accum, lr_accum
63 |
--------------------------------------------------------------------------------
/mnist/Net.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import surrogate
4 |
5 | # torch
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # local
11 | from bnn import *
12 |
13 | class Net(nn.Module):
14 |     def __init__(self, config):
15 |         super().__init__()
16 |
17 |         self.thr1 = config['threshold1']
18 |         self.thr2 = config['threshold2']
19 |         self.thr3 = config['threshold3']
20 |         slope = config['slope']
21 |         beta = config['beta']
22 |         self.num_steps = config['num_steps']
23 |         self.batch_norm = config['batch_norm']
24 |         p1 = config['dropout1']
25 |         self.binarize = config['binarize']
26 |
27 |         spike_grad = surrogate.fast_sigmoid(slope)
28 |         # Initialize layers with spike operator
29 |         self.bconv1 = BinaryConv2d(1, 16, 5, bias=False)
30 |         self.conv1 = nn.Conv2d(1, 16, 5, bias=False)
31 |         self.conv1_bn = nn.BatchNorm2d(16)
32 |         self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)
33 |         self.bconv2 = BinaryConv2d(16, 64, 5, bias=False)
34 |         self.conv2 = nn.Conv2d(16, 64, 5, bias=False)
35 |         self.conv2_bn = nn.BatchNorm2d(64)
36 |         self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)
37 |         self.bfc1 = BinaryLinear(64 * 4 * 4, 10)
38 |         self.fc1 = nn.Linear(64 * 4 * 4, 10)
39 |         self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad)
40 |         self.dropout = nn.Dropout(p1)
41 |
42 |     def forward(self, x):
43 |
44 |         # Initialize hidden states and outputs at t=0
45 |         mem1 = self.lif1.init_leaky()
46 |         mem2 = self.lif2.init_leaky()
47 |         mem3 = self.lif3.init_leaky()
48 |
49 |         # Record the final layer
50 |         spk3_rec = []
51 |         mem3_rec = []
52 |
53 |         # Binarized
54 |         if self.binarize:
55 |
56 |             for step in range(self.num_steps):
57 |
58 |                 cur1 = F.avg_pool2d(self.bconv1(x), 2)
59 |                 if self.batch_norm:
60 |                     cur1 = self.conv1_bn(cur1)
61 |                 spk1, mem1 = self.lif1(cur1, mem1)
62 |                 cur2 = F.avg_pool2d(self.bconv2(spk1), 2)
63 |                 if self.batch_norm:
64 |                     cur2 = self.conv2_bn(cur2)
65 |                 spk2, mem2 = self.lif2(cur2, mem2)
66 |                 cur3 = self.dropout(self.bfc1(spk2.flatten(1)))
67 |                 spk3, mem3 = self.lif3(cur3, mem3)
68 |
69 |                 spk3_rec.append(spk3)
70 |                 mem3_rec.append(mem3)
71 |
72 |             return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
73 |
74 |         # Full Precision
75 |         else:
76 |
77 |             for step in range(self.num_steps):
78 |
79 |                 cur1 = F.avg_pool2d(self.conv1(x), 2)
80 |                 if self.batch_norm:
81 |                     cur1 = self.conv1_bn(cur1)
82 |                 spk1, mem1 = self.lif1(cur1, mem1)
83 |                 cur2 = F.avg_pool2d(self.conv2(spk1), 2)
84 |                 if self.batch_norm:
85 |                     cur2 = self.conv2_bn(cur2)
86 |                 spk2, mem2 = self.lif2(cur2, mem2)
87 |                 cur3 = self.dropout(self.fc1(spk2.flatten(1)))
88 |                 spk3, mem3 = self.lif3(cur3, mem3)
89 |
90 |                 spk3_rec.append(spk3)
91 |                 mem3_rec.append(mem3)
92 |
93 |             return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
94 |
--------------------------------------------------------------------------------
/mnist/bnn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from functions import *
8 |
9 |
10 | class BinaryTanh(nn.Module):
11 |     def __init__(self):
12 |         super(BinaryTanh, self).__init__()
13 |         self.hardtanh = nn.Hardtanh()
14 |
15 |     def forward(self, input):
16 |         output = self.hardtanh(input)
17 |         output = binarize(output)
18 |         return output
19 |
20 |
21 | class BinaryLinear(nn.Linear):
22 |
23 |     def forward(self, input):
24 |         binary_weight = binarize(self.weight)
25 |         if self.bias is None:
26 |             return F.linear(input, binary_weight)
27 |         else:
28 |             return F.linear(input, binary_weight, self.bias)
29 |
30 |     def reset_parameters(self):
31 |         # Glorot initialization
32 |         in_features, out_features = self.weight.size()
33 |         stdv = math.sqrt(1.5 / (in_features + out_features))
34 |         self.weight.data.uniform_(-stdv, stdv)
35 |         if self.bias is not None:
36 |             self.bias.data.zero_()
37 |
38 |         self.weight.lr_scale = 1. / stdv
39 |
40 |
41 |
42 | class BinaryConv2d(nn.Conv2d):
43 |
44 |     def forward(self, input):
45 |         bw = binarize(self.weight)
46 |         return F.conv2d(input, bw, self.bias, self.stride,
47 |                         self.padding, self.dilation, self.groups)
48 |
49 |     def reset_parameters(self):
50 |         # Glorot initialization
51 |         in_features = self.in_channels
52 |         out_features = self.out_channels
53 |         for k in self.kernel_size:
54 |             in_features *= k
55 |             out_features *= k
56 |         stdv = math.sqrt(1.5 / (in_features + out_features))
57 |         self.weight.data.uniform_(-stdv, stdv)
58 |         if self.bias is not None:
59 |             self.bias.data.zero_()
60 |
61 |         self.weight.lr_scale = 1. / stdv
--------------------------------------------------------------------------------
/mnist/conf.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | config = {
4 |     'exp_name' : 'mnist_tha',
5 |     'num_trials' : 5,
6 |     'num_epochs' : 500,
7 |     'binarize' : True,
8 |     'data_dir' : "~/data/mnist",
9 |     'batch_size' : 128,
10 |     'seed' : 0,
11 |     'num_workers' : 0,
12 |
13 |     # final run sweeps
14 |     'save_csv' : True,
15 |     'save_model' : True,
16 |     'early_stopping': True,
17 |     'patience': 100,
18 |
19 |     # final params
20 |     'grad_clip' : False,
21 |     'weight_clip' : False,
22 |     'batch_norm' : True,
23 |     'dropout1' : 0.02856,
24 |     'beta' : 0.99,
25 |     'lr' : 9.97e-3,
26 |     'slope': 10.22,
27 |
28 |     # threshold annealing. note: thr_final = threshold + thr_final
29 |     'threshold1' : 11.666,
30 |     'alpha_thr1' : 0.024,
31 |     'thr_final1' : 4.317,
32 |
33 |     'threshold2' : 14.105,
34 |     'alpha_thr2' : 0.119,
35 |     'thr_final2' : 16.29,
36 |
37 |     'threshold3' : 0.6656,
38 |     'alpha_thr3' : 0.0011,
39 |     'thr_final3' : 3.496,
40 |
41 |     # fixed params
42 |     'num_steps' : 100,
43 |     'correct_rate': 0.8,
44 |     'incorrect_rate' : 0.2,
45 |     'betas' : (0.9, 0.999),
46 |     't_0' : 4688,
47 |     'eta_min' : 0,
48 |     'df_lr' : True, # return learning rate. Useful for scheduling
49 |
50 |
51 |
52 | }
53 |
54 | def optim_func(net, config):
55 |     optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config['betas'])
56 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['t_0'], eta_min=config['eta_min'], last_epoch=-1)
57 |     return optimizer, scheduler
58 |
--------------------------------------------------------------------------------
/mnist/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import DataLoader
4 | from torchvision import datasets, transforms
5 |
6 | def load_data(config):
7 |     data_dir = config['data_dir']
8 |
9 |     transform = transforms.Compose([
10 |         transforms.Resize((28, 28)),
11 |         transforms.Grayscale(),
12 |         transforms.ToTensor(),
13 |         transforms.Normalize((0,), (1,))])
14 |
15 |     trainset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
16 |     testset = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
17 |
18 |     return trainset, testset
--------------------------------------------------------------------------------
/mnist/earlystopping.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EarlyStopping_acc:
5 |     """Early stops the training if test acc doesn't improve after a given patience."""
6 |     def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
7 |         """
8 |         Args:
9 |             patience (int): How long to wait after last time validation loss improved.
10 |                             Default: 7
11 |             verbose (bool): If True, prints a message for each validation loss improvement.
12 |                             Default: False
13 |             delta (float): Minimum change in the monitored quantity to qualify as an improvement.
14 |                             Default: 0
15 |             path (str): Path for the checkpoint to be saved to.
16 |                             Default: 'checkpoint.pt'
17 |             trace_func (function): trace print function.
18 |                             Default: print
19 |         """
20 |         self.patience = patience
21 |         self.verbose = verbose
22 |         self.counter = 0
23 |         self.best_score = None
24 |         self.early_stop = False
25 |         self.test_loss_min = 0
26 |         self.delta = delta
27 |         self.path = path
28 |         self.trace_func = trace_func
29 |     def __call__(self, test_loss, model):
30 |
31 |         score = test_loss
32 |
33 |         if self.best_score is None:
34 |             self.best_score = score
35 |             self.save_checkpoint(test_loss, model)
36 |         elif score <= self.best_score + self.delta:
37 |             self.counter += 1
38 |             self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
39 |             if self.counter >= self.patience:
40 |                 self.early_stop = True
41 |                 self.counter = 0
42 |         else:
43 |             self.best_score = score
44 |             self.save_checkpoint(test_loss, model)
45 |             self.counter = 0
46 |
47 |     def save_checkpoint(self, test_loss, model):
48 |         '''Saves model when test acc increases.'''
49 |         if self.verbose:
50 |             self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). Saving model ...')
51 |         torch.save(model.state_dict(), self.path)
52 |         self.test_loss_min = test_loss
--------------------------------------------------------------------------------
/mnist/functions.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 |
6 |
7 | class BinarizeF(Function):
8 |
9 |     @staticmethod
10 |     def forward(ctx, input):
11 |         output = input.new(input.size())
12 |         output[input >= 0] = 1
13 |         output[input < 0] = -1
14 |         return output
15 |
16 |     @staticmethod
17 |     def backward(ctx, grad_output):
18 |         grad_input = grad_output.clone()
19 |         return grad_input
20 |
21 | # aliases
22 | binarize = BinarizeF.apply
--------------------------------------------------------------------------------
/mnist/run.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 |
6 | # torch
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 | # misc
12 | import numpy as np
13 | import pandas as pd
14 | import time
15 | import logging
16 |
17 | # local imports
18 | from dataloader import *
19 | from Net import *
20 | from test_acc import *
21 | from train import *
22 | from earlystopping import *
23 | from conf import *
24 |
25 | ####################################################
26 | ## Notes: modify config in conf to reparameterize ##
27 | ####################################################
28 |
29 | file_name = config['exp_name']
30 |
31 | ### to address conditional parameters, s.t. thr_final > threshold
32 | config['thr_final1'] = config['thr_final1'] + config['threshold1']
33 | config['thr_final2'] = config['thr_final2'] + config['threshold2']
34 | config['thr_final3'] = config['thr_final3'] + config['threshold3']
35 |
36 | threshold1 = config['threshold1']
37 | threshold2 = config['threshold2']
38 | threshold3 = config['threshold3']
39 |
40 |
41 |
42 | for trial in range(config['num_trials']):
43 |
44 |
45 |     # file names
46 |     SAVE_CSV = config['save_csv']
47 |     SAVE_MODEL = config['save_model']
48 |     csv_name = file_name + '_t' + str(trial) + '.csv'
49 |     log_name = file_name + '_t' + str(trial) + '.log'
50 |     model_name = file_name + '_t' + str(trial) + '.pt'
51 |     num_epochs = config['num_epochs']
52 |     torch.manual_seed(config['seed'])
53 |
54 |     config['threshold1'] = threshold1
55 |     config['threshold2'] = threshold2
56 |     config['threshold3'] = threshold3
57 |
58 |     # dataframes
59 |     df_train_loss = pd.DataFrame()
60 |     df_test_acc = pd.DataFrame(columns=['epoch', 'test_acc', 'train_time'])
61 |     df_lr = pd.DataFrame()
62 |
63 |
64 |     # initialize network
65 |     net = Net(config)
66 |     device = "cpu"
67 |     if torch.cuda.is_available():
68 |         device = "cuda:0"
69 |         if torch.cuda.device_count() > 1:
70 |             net = nn.DataParallel(net)
71 |     net.to(device)
72 |
73 |     # net params
74 |     criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
75 |     optimizer, scheduler = optim_func(net, config)
76 |
77 |     # early stopping condition
78 |     if config['early_stopping']:
79 |         early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)
80 |         early_stopping.early_stop = False
81 |         early_stopping.best_score = None
82 |
83 |     # load data
84 |     trainset, testset = load_data(config)
85 |     config['dataset_length'] = len(trainset)
86 |     trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True)
87 |     testloader = DataLoader(testset, batch_size=int(config["batch_size"]), shuffle=False)
88 |
89 |     print(f"=======Trial: {trial}=======")
90 |
91 |     for epoch in range(num_epochs):
92 |
93 |         # train
94 |         start_time = time.time()
95 |         loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device)
96 |         epoch_time = time.time() - start_time
97 |
98 |         # test
99 |         test_acc = test_accuracy(config, net, testloader, device)
100 |         print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}')
101 |
102 |         if config['df_lr']:
103 |             df_lr = pd.concat([df_lr, pd.DataFrame(lr_list)])
104 |         df_train_loss = pd.concat([df_train_loss, pd.DataFrame(loss_list)])
105 |         test_data = pd.DataFrame([[epoch, test_acc, epoch_time]], columns = ['epoch', 'test_acc', 'train_time'])
106 |         df_test_acc = pd.concat([df_test_acc, test_data])
107 |
108 |         if SAVE_CSV:
109 |             df_train_loss.to_csv('loss_' + csv_name, index=False)
110 |             df_test_acc.to_csv('acc_' + csv_name, index=False)
111 |             if config['df_lr']:
112 |                 df_lr.to_csv('lr_' + csv_name, index=False)
113 |
114 |         if config['early_stopping']:
115 |             early_stopping(test_acc, net)
116 |
117 |             if early_stopping.early_stop:
118 |                 print("Early stopping")
119 |                 early_stopping.early_stop = False
120 |                 early_stopping.best_score = None
121 |                 break
122 |
123 |     if SAVE_MODEL and not config['early_stopping']:
124 |         torch.save(net.state_dict(), model_name)
125 |
126 |     # net.load_state_dict(torch.load(model_name))
127 |
--------------------------------------------------------------------------------
/mnist/test_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import snntorch as snn
3 | from snntorch import functional as SF
4 |
5 |
6 | def test_accuracy(config, net, testloader, device="cpu"):
7 |
8 |
9 |     correct = 0
10 |     total = 0
11 |     with torch.no_grad():
12 |         net.eval()
13 |         for data in testloader:
14 |             images, labels = data
15 |             images, labels = images.to(device), labels.to(device)
16 |
17 |             outputs, _ = net(images)
18 |             accuracy = SF.accuracy_rate(outputs, labels)
19 |
20 |             total += labels.size(0)
21 |             correct += accuracy * labels.size(0)
22 |
23 |     return 100 * correct / total
--------------------------------------------------------------------------------
/mnist/tha.py:
--------------------------------------------------------------------------------
1 | # exp relaxation implementation of THA based on Eq (4)
2 |
3 | def thr_annealing(config, network):
4 |     alpha_thr1 = config['alpha_thr1']
5 |     alpha_thr2 = config['alpha_thr2']
6 |     alpha_thr3 = config['alpha_thr3']
7 |
8 |     thr_final1 = config['thr_final1']
9 |     thr_final2 = config['thr_final2']
10 |     thr_final3 = config['thr_final3']
11 |
12 |     network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1
13 |     network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2
14 |     network.lif3.threshold += (thr_final3 - network.lif3.threshold) * alpha_thr3
15 |
16 |     return
--------------------------------------------------------------------------------
/mnist/train.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 | from snntorch import functional as SF
6 |
7 | # torch
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | from torchvision import datasets, transforms
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import StepLR
14 |
15 | # misc
16 | import os
17 | import numpy as np
18 | import math
19 | import itertools
20 | import matplotlib.pyplot as plt
21 | import pandas as pd
22 | import shutil
23 | import time
24 |
25 | from dataloader import *
26 | from test import *
27 | from test_acc import *
28 | from tha import *
29 |
30 |
31 | def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device):
32 |
33 |     net.train()
34 |     loss_accum = []
35 |     lr_accum = []
36 |
37 |     # TRAIN
38 |     for data, labels in trainloader:
39 |         data, labels = data.to(device), labels.to(device)
40 |
41 |         spk_rec2, _ = net(data)
42 |         loss = criterion(spk_rec2, labels)
43 |         optimizer.zero_grad()
44 |         loss.backward()
45 |
46 |         if config['grad_clip']:
47 |             nn.utils.clip_grad_norm_(net.parameters(), 1.0)
48 |         if config['weight_clip']:
49 |             with torch.no_grad():
50 |                 for param in net.parameters():
51 |                     param.clamp_(-1, 1)
52 |
53 |         optimizer.step()
54 |         scheduler.step()
55 |         thr_annealing(config, net)
56 |
57 |
58 |         loss_accum.append(loss.item()/config['num_steps'])
59 |         lr_accum.append(optimizer.param_groups[0]["lr"])
60 |
61 |
62 |     return loss_accum, lr_accum
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | snntorch
4 | pandas
5 | matplotlib
6 | numpy
--------------------------------------------------------------------------------
/shd/Net.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import surrogate
4 |
5 | # torch
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # local
11 | from bnn import *
12 |
13 |
14 | class Net(nn.Module):
15 |     def __init__(self, config):
16 |         super().__init__()
17 |
18 |         self.thr1 = config['threshold1']
19 |         self.thr2 = config['threshold2']
20 |         slope = config['slope']
21 |         beta = config['beta']
22 |         self.num_steps = config['num_steps']
23 |         p1 = config['dropout1']
24 |         p2 = config['dropout2']
25 |         self.binarize = config['binarize']
26 |         num_hidden = 3000
27 |         spike_grad = surrogate.fast_sigmoid(slope)
28 |         # Initialize layers with spike operator
29 |
30 |
31 |         self.bfc1 = BinaryLinear(700, num_hidden)
32 |         self.fc1 = nn.Linear(700, num_hidden)
33 |         self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)
34 |         self.dropout1 = nn.Dropout(p1)
35 |
36 |         self.bfc2 = BinaryLinear(num_hidden, 20)
37 |         self.fc2 = nn.Linear(num_hidden, 20)
38 |         self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)
39 |         self.dropout2 = nn.Dropout(p2)
40 |
41 |
42 |     def forward(self, x):
43 |
44 |         # Initialize hidden states and outputs at t=0
45 |         mem1 = self.lif1.init_leaky()
46 |         mem2 = self.lif2.init_leaky()
47 |
48 |         # Record the final layer
49 |         spk2_rec = []
50 |         mem2_rec = []
51 |
52 |         # Binarization
53 |
54 |         if self.binarize:
55 |
56 |             for step in range(x.size(0)):
57 |
58 |                 cur1 = self.dropout1(self.bfc1(x[step].flatten(1)))
59 |                 spk1, mem1 = self.lif1(cur1, mem1)
60 |                 cur2 = self.dropout2(self.bfc2(spk1))
61 |                 spk2, mem2 = self.lif2(cur2, mem2)
62 |
63 |
64 |                 spk2_rec.append(spk2)
65 |                 mem2_rec.append(mem2)
66 |
67 |             return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
68 |
69 |         # Full Precision
70 |
71 |         else:
72 |
73 |             for step in range(x.size(0)):
74 |
75 |                 cur1 = self.dropout1(self.fc1(x[step].flatten(1)))
76 |                 spk1, mem1 = self.lif1(cur1, mem1)
77 |                 cur2 = self.dropout2(self.fc2(spk1))
78 |                 spk2, mem2 = self.lif2(cur2, mem2)
79 |                 spk2_rec.append(spk2)
80 |                 mem2_rec.append(mem2)
81 |
82 |             return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
83 |
84 |
--------------------------------------------------------------------------------
/shd/bnn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from functions import *
8 |
9 |
10 | class BinaryTanh(nn.Module):
11 |     def __init__(self):
12 |         super(BinaryTanh, self).__init__()
13 |         self.hardtanh = nn.Hardtanh()
14 |
15 |     def forward(self, input):
16 |         output = self.hardtanh(input)
17 |         output = binarize(output)
18 |         return output
19 |
20 |
21 | class BinaryLinear(nn.Linear):
22 |
23 |     def forward(self, input):
24 |         binary_weight = binarize(self.weight)
25 |         if self.bias is None:
26 |             return F.linear(input, binary_weight)
27 |         else:
28 |             return F.linear(input, binary_weight, self.bias)
29 |
30 |     def reset_parameters(self):
31 |         # Glorot initialization
32 |         in_features, out_features = self.weight.size()
33 |         stdv = math.sqrt(1.5 / (in_features + out_features))
34 |         self.weight.data.uniform_(-stdv, stdv)
35 |         if self.bias is not None:
36 |             self.bias.data.zero_()
37 |
38 |         self.weight.lr_scale = 1. / stdv
39 |
40 |
41 |
42 | class BinaryConv2d(nn.Conv2d):
43 |
44 |     def forward(self, input):
45 |         bw = binarize(self.weight)
46 |         return F.conv2d(input, bw, self.bias, self.stride,
47 |                         self.padding, self.dilation, self.groups)
48 |
49 |     def reset_parameters(self):
50 |         # Glorot initialization
51 |         in_features = self.in_channels
52 |         out_features = self.out_channels
53 |         for k in self.kernel_size:
54 |             in_features *= k
55 |             out_features *= k
56 |         stdv = math.sqrt(1.5 / (in_features + out_features))
57 |         self.weight.data.uniform_(-stdv, stdv)
58 |         if self.bias is not None:
59 |             self.bias.data.zero_()
60 |
61 |         self.weight.lr_scale = 1. / stdv
--------------------------------------------------------------------------------
/shd/conf.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | config = {
4 |     'exp_name' : 'shd_tha',
5 |     'num_trials' : 5,
6 |     'num_epochs' : 5,
7 |     'binarize' : True,
8 |     'data_dir' : "/home/shd",
9 |     'batch_size' : 32,
10 |     'seed' : 0,
11 |     'num_workers' : 0,
12 |
13 |     # final run sweeps
14 |     'save_csv' : True,
15 |     'save_model' : True,
16 |     'early_stopping': True,
17 |     'patience': 100,
18 |
19 |     # final params
20 |     'grad_clip' : True,
21 |     'weight_clip' : True,
22 |     'batch_norm' : True,
23 |     'dropout2' : 0.0176,
24 |     'dropout1' : 0.186,
25 |     'beta' : 0.950,
26 |     'lr' : 6.54e-4,
27 |     'slope': 0.257,
28 |
29 |
30 |     # threshold annealing. note: thr_final = threshold + thr_final
31 |     'threshold1' : 13.504,
32 |     'alpha_thr1' : 2.78e-5,
33 |     'thr_final1' : 31.767,
34 |
35 |     'threshold2' : 11.20,
36 |     'alpha_thr2' : 1.36e-5,
37 |     'thr_final2' : 39.92,
38 |
39 |     # fixed params
40 |     'num_steps' : 100,
41 |     'correct_rate': 0.8,
42 |     'incorrect_rate' : 0.2,
43 |     'betas1' : 0.9,
44 |     'betas2' : 0.999,
45 |     't_0' : 2604,
46 |     'eta_min' : 0,
47 |     'df_lr' : True, # return learning rate. Useful for scheduling
48 |
49 |
50 | }
51 |
52 | def optim_func(net, config):
53 |     optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=(config['betas1'], config['betas2']))
54 |     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['t_0'], eta_min=config['eta_min'], last_epoch=-1)
55 |     return optimizer, scheduler
--------------------------------------------------------------------------------
/shd/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import DataLoader
4 | from torchvision import datasets, transforms
5 |
6 | from snntorch.spikevision import spikedata
7 |
8 | def load_data(config):
9 |
10 |     data_dir = config['data_dir']
11 |     dt_scalar = 3 # set to 2 for float in our experiments
12 |
13 |
14 |     dt = int(1000*dt_scalar)
15 |     num_steps = int(1000/dt_scalar)
16 |
17 |     trainset = spikedata.SHD(data_dir, train=True, num_steps=num_steps, dt=dt)
18 |     testset = spikedata.SHD(data_dir, train=False, num_steps=num_steps, dt=dt)
19 |
20 |     return trainset, testset
--------------------------------------------------------------------------------
/shd/earlystopping.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EarlyStopping_acc:
5 |     """Early stops the training if test acc doesn't improve after a given patience."""
6 |     def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
7 |         """
8 |         Args:
9 |             patience (int): How long to wait after last time validation loss improved.
10 |                             Default: 7
11 |             verbose (bool): If True, prints a message for each validation loss improvement.
12 |                             Default: False
13 |             delta (float): Minimum change in the monitored quantity to qualify as an improvement.
14 |                             Default: 0
15 |             path (str): Path for the checkpoint to be saved to.
16 |                             Default: 'checkpoint.pt'
17 |             trace_func (function): trace print function.
18 |                             Default: print
19 |         """
20 |         self.patience = patience
21 |         self.verbose = verbose
22 |         self.counter = 0
23 |         self.best_score = None
24 |         self.early_stop = False
25 |         self.test_loss_min = 0
26 |         self.delta = delta
27 |         self.path = path
28 |         self.trace_func = trace_func
29 |     def __call__(self, test_loss, model):
30 |
31 |         score = test_loss
32 |
33 |         if self.best_score is None:
34 |             self.best_score = score
35 |             self.save_checkpoint(test_loss, model)
36 |         elif score <= self.best_score + self.delta:
37 |             self.counter += 1
38 |             self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
39 |             if self.counter >= self.patience:
40 |                 self.early_stop = True
41 |                 self.counter = 0
42 |         else:
43 |             self.best_score = score
44 |             self.save_checkpoint(test_loss, model)
45 |             self.counter = 0
46 |
47 |     def save_checkpoint(self, test_loss, model):
48 |         '''Saves model when test acc increases.'''
49 |         if self.verbose:
50 |             self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). Saving model ...')
51 |         torch.save(model.state_dict(), self.path)
52 |         self.test_loss_min = test_loss
--------------------------------------------------------------------------------
/shd/functions.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 |
6 |
7 | class BinarizeF(Function):
8 |
9 |     @staticmethod
10 |     def forward(ctx, input):
11 |         output = input.new(input.size())
12 |         output[input >= 0] = 1
13 |         output[input < 0] = -1
14 |         return output
15 |
16 |     @staticmethod
17 |     def backward(ctx, grad_output):
18 |         grad_input = grad_output.clone()
19 |         return grad_input
20 |
21 | # aliases
22 | binarize = BinarizeF.apply
--------------------------------------------------------------------------------
/shd/run.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 |
6 | # torch
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 |
12 | # misc
13 | import numpy as np
14 | import pandas as pd
15 | import time
16 | import logging
17 |
18 | # local imports
19 | from dataloader import *
20 | from Net import *
21 | from test_acc import *
22 | from train import *
23 | from earlystopping import *
24 | from conf import *
25 |
26 | ####################################################
27 | ## Notes: modify config in conf to reparameterize ##
28 | ####################################################
29 |
30 | file_name = config['exp_name']
31 |
32 | ### to address conditional parameters, s.t. thr_final > threshold
33 | config['thr_final1'] = config['thr_final1'] + config['threshold1']
34 | config['thr_final2'] = config['thr_final2'] + config['threshold2']
35 |
36 | threshold1 = config['threshold1']
37 | threshold2 = config['threshold2']
38 |
39 | for trial in range(config['num_trials']):
40 |
41 |     # file names
42 |     SAVE_CSV = config['save_csv']
43 |     SAVE_MODEL = config['save_model']
44 |     csv_name = file_name + '_t' + str(trial) + '.csv'
45 |     log_name = file_name + '_t' + str(trial) + '.log'
46 |     model_name = file_name + '_t' + str(trial) + '.pt'
47 |     num_epochs = config['num_epochs']
48 |     torch.manual_seed(config['seed'])
49 |
50 |     config['threshold1'] = threshold1
51 |     config['threshold2'] = threshold2
52 |
53 |     # dataframes
54 |     df_train_loss = pd.DataFrame()
55 |     df_test_acc = pd.DataFrame(columns=['epoch', 'test_acc', 'train_time'])
56 |     df_lr = pd.DataFrame()
57 |
58 |     # initialize network
59 |     net = Net(config)
60 |     device = "cpu"
61 |     if torch.cuda.is_available():
62 |         device = "cuda:0"
63 |         if torch.cuda.device_count() > 1:
64 |             net = nn.DataParallel(net)
65 |     net.to(device)
66 |
67 |     # net params
68 |     criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
69 |     optimizer, scheduler = optim_func(net, config)
70 |
71 |     # early stopping condition
72 |     if config['early_stopping']:
73 |         early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)
74 |         early_stopping.early_stop = False
75 |         early_stopping.best_score = None
76 |
77 |     # load data
78 |     trainset, testset = load_data(config)
79 |     trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True)
80 |     testloader = DataLoader(testset, batch_size=int(config["batch_size"]), shuffle=False)
81 |
82 |     print(f"=======Trial: {trial}=======")
83 |
84 |     for epoch in range(num_epochs):
85 |
86 |         # train
87 |         start_time = time.time()
88 |         loss_list, lr_list = train(config, net, epoch, trainloader,
testloader, criterion, optimizer, scheduler, device) 89 | epoch_time = time.time() - start_time 90 | 91 | # test 92 | test_acc = test_accuracy(config, net, testloader, device) 93 | print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}') 94 | 95 | if config['df_lr']: 96 | df_lr = pd.concat([df_lr, pd.DataFrame(lr_list)]) 97 | df_train_loss = pd.concat([df_train_loss, pd.DataFrame(loss_list)]) 98 | test_data = pd.DataFrame([[epoch, test_acc, epoch_time]], columns = ['epoch', 'test_acc', 'train_time']) 99 | df_test_acc = pd.concat([df_test_acc, test_data]) 100 | 101 | if SAVE_CSV: 102 | df_train_loss.to_csv('loss_' + csv_name, index=False) 103 | df_test_acc.to_csv('acc_' + csv_name, index=False) 104 | if config['df_lr']: 105 | df_lr.to_csv('lr_' + csv_name, index=False) 106 | 107 | if config['early_stopping']: 108 | early_stopping(test_acc, net) 109 | 110 | if early_stopping.early_stop: 111 | print("Early stopping") 112 | early_stopping.early_stop = False 113 | early_stopping.best_score = None 114 | break 115 | 116 | if SAVE_MODEL and not config['early_stopping']: 117 | torch.save(net.state_dict(), model_name) 118 | 119 | # net.load_state_dict(torch.load(model_name)) -------------------------------------------------------------------------------- /shd/test_acc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import snntorch as snn 3 | from snntorch import functional as SF 4 | 5 | 6 | def test_accuracy(config, net, testloader, device="cpu"): 7 | 8 | 9 | correct = 0 10 | total = 0 11 | with torch.no_grad(): 12 | net.eval() 13 | for data in testloader: 14 | images, labels = data 15 | images, labels = images.to(device), labels.to(device) 16 | 17 | outputs, _ = net(images.permute(1, 0, 2)) 18 | accuracy = SF.accuracy_rate(outputs, labels) 19 | 20 | total += labels.size(0) 21 | correct += accuracy * labels.size(0) 22 | 23 | return 100 * correct / total 
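A minimal, stdlib-only sketch of the threshold annealing schedule, assuming the update in `tha.py` (`threshold += (thr_final - threshold) * alpha_thr`) is applied once per training iteration, as `train.py` does after each `optimizer.step()`. Because `run.py` adds the initial threshold to `thr_final`, the configured `thr_final1` acts as an offset rather than the absolute target. The update is an exponential relaxation: after `n` iterations the remaining gap to the target has shrunk by a factor of `(1 - alpha)**n`, so the iterative and closed-form functions below should agree. The function names and the iteration count `n` are hypothetical, chosen only for illustration.

```python
import math


def anneal_iterative(threshold: float, thr_final: float, alpha: float, num_iters: int) -> float:
    """Apply thr += (thr_final - thr) * alpha once per iteration, mirroring tha.py."""
    for _ in range(num_iters):
        threshold += (thr_final - threshold) * alpha
    return threshold


def anneal_closed_form(threshold: float, thr_final: float, alpha: float, num_iters: int) -> float:
    """Closed form of the same recursion: the gap to thr_final shrinks by (1 - alpha) per step."""
    return thr_final - (thr_final - threshold) * (1.0 - alpha) ** num_iters


# Layer-1 values from the SHD config above. thr_final1 is stored as an offset,
# so (as in run.py) the initial threshold is added to obtain the absolute target.
threshold1 = 13.504
alpha_thr1 = 2.78e-5
thr_final1 = 31.767 + threshold1

n = 10_000  # hypothetical iteration count, for illustration only
iterative = anneal_iterative(threshold1, thr_final1, alpha_thr1, n)
closed = anneal_closed_form(threshold1, thr_final1, alpha_thr1, n)

# Iterations needed to close half of the initial gap: ln(2) / -ln(1 - alpha).
half_life = math.log(2) / -math.log1p(-alpha_thr1)

print(f"after {n} iterations: {iterative:.4f} (closed form {closed:.4f})")
print(f"half-life: about {half_life:.0f} iterations")
```

With these values the half-life is on the order of 25,000 iterations, so the threshold drifts upward gradually over training rather than jumping to its final value.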
-------------------------------------------------------------------------------- /shd/tha.py: -------------------------------------------------------------------------------- 1 | # exp relaxation implementation of THA based on Eq (4) 2 | 3 | def thr_annealing(config, network): 4 | alpha_thr1 = config['alpha_thr1'] 5 | thr_final1 = config['thr_final1'] 6 | 7 | alpha_thr2 = config['alpha_thr2'] 8 | thr_final2 = config['thr_final2'] 9 | 10 | network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1 11 | network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2 12 | 13 | return -------------------------------------------------------------------------------- /shd/train.py: -------------------------------------------------------------------------------- 1 | # snntorch 2 | import snntorch as snn 3 | from snntorch import spikegen 4 | from snntorch import surrogate 5 | from snntorch import functional as SF 6 | 7 | # torch 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | from torchvision import datasets, transforms 12 | import torch.nn.functional as F 13 | from torch.optim.lr_scheduler import StepLR 14 | 15 | # misc 16 | import os 17 | import numpy as np 18 | import math 19 | import itertools 20 | import matplotlib.pyplot as plt 21 | import pandas as pd 22 | import shutil 23 | import time 24 | 25 | # raytune 26 | # from functools import partial 27 | # from ray import tune 28 | # from ray.tune import CLIReporter 29 | # # from ray.tune import JupyterNotebookReporter 30 | # from ray.tune.schedulers import ASHAScheduler 31 | 32 | from dataloader import * 33 | from test_acc import * 34 | from tha import * 35 | 36 | 37 | def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device): 38 | 39 | net.train() 40 | loss_accum = [] 41 | lr_accum = [] 42 | 43 | # TRAIN 44 | for data, labels in trainloader: 45 | data, labels = data.to(device),
labels.to(device) 46 | spk_rec2, _ = net(data.permute(1, 0, 2)) 47 | loss = criterion(spk_rec2, labels.long()) 48 | optimizer.zero_grad() 49 | loss.backward() 50 | 51 | if config['grad_clip']: 52 | nn.utils.clip_grad_norm_(net.parameters(), 1.0) 53 | if config['weight_clip']: 54 | with torch.no_grad(): 55 | for param in net.parameters(): 56 | param.clamp_(-1, 1) 57 | 58 | optimizer.step() 59 | scheduler.step() 60 | thr_annealing(config, net) 61 | 62 | 63 | loss_accum.append(loss.item()/config['num_steps']) 64 | lr_accum.append(optimizer.param_groups[0]["lr"]) 65 | 66 | return loss_accum, lr_accum 67 | -------------------------------------------------------------------------------- /temporal/bounded_homeostasis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "accelerator": "GPU", 6 | "colab": { 7 | "name": "icml_spike_time_exp.ipynb", 8 | "provenance": [], 9 | "collapsed_sections": [] 10 | }, 11 | "kernelspec": { 12 | "display_name": "Python 3", 13 | "name": "python3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | } 18 | }, 19 | "cells": [ 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "id": "GapeQDQsl-sx" 24 | }, 25 | "source": [ 26 | "# Bounded Homeostasis to Learn Temporal Targets\n", 27 | "\n", 28 | "This notebook replicates the temporal coding experiments in the paper *The fine line between dead neurons and sparsity in binarized spiking neural networks*."
29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "metadata": { 34 | "id": "U_R27gZyULBI" 35 | }, 36 | "source": [ 37 | "!pip install snntorch --quiet" 38 | ], 39 | "execution_count": null, 40 | "outputs": [] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": { 45 | "id": "kKI4l8OXQxXk" 46 | }, 47 | "source": [ 48 | "## Imports" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "metadata": { 54 | "id": "JyB2DosmUNO0" 55 | }, 56 | "source": [ 57 | "import snntorch as snn\n", 58 | "from snntorch import surrogate\n", 59 | "from snntorch import spikegen\n", 60 | "import snntorch.functional as SF\n", 61 | "from snntorch import spikeplot as splt\n", 62 | "from snntorch import utils\n", 63 | "\n", 64 | "import torch\n", 65 | "import torch.nn as nn\n", 66 | "import torch.nn.functional as F\n", 67 | "from torch.autograd import Function\n", 68 | "\n", 69 | "import matplotlib.pyplot as plt\n", 70 | "from matplotlib.animation import FuncAnimation\n", 71 | "import matplotlib.gridspec as gridspec\n", 72 | "import seaborn as sns \n", 73 | "\n", 74 | "from IPython import display\n", 75 | "import numpy as np\n", 76 | "from tqdm import tqdm\n", 77 | "import math\n", 78 | "import random\n", 79 | "from scipy.ndimage import uniform_filter1d\n", 80 | "import os" 81 | ], 82 | "execution_count": null, 83 | "outputs": [] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "source": [ 88 | "## Plotting Utility Functions" 89 | ], 90 | "metadata": { 91 | "id": "34Oz6bzcuCne" 92 | } 93 | }, 94 | { 95 | "cell_type": "code", 96 | "source": [ 97 | "#@title\n", 98 | "sns.set_theme()\n", 99 | "\n", 100 | "def prep_for_plot(mem):\n", 101 | " return mem.cpu().detach().squeeze(-1).squeeze(-1)\n", 102 | "\n", 103 | "def plot_quadrant(mem, spk_out, target_mem, spk_target, y1, y2, threshold=1, save=False, epoch1 = 1, epoch2=25, epoch3=100, fill=True):\n", 104 | " # Generate Plots\n", 105 | " gs = gridspec.GridSpec(2, 4, height_ratios=[1, 0.07])\n", 106 | " fig =
plt.figure(figsize=(12,4.5),)\n", 107 | " ax1 = plt.subplot(gs[0,0])\n", 108 | " ax2 = plt.subplot(gs[1,0])\n", 109 | " ax3 = plt.subplot(gs[0,1])\n", 110 | " ax4 = plt.subplot(gs[1,1])\n", 111 | " ax5 = plt.subplot(gs[0,2])\n", 112 | " ax6 = plt.subplot(gs[1,2])\n", 113 | " ax7 = plt.subplot(gs[0,3])\n", 114 | " ax8 = plt.subplot(gs[1,3])\n", 115 | "\n", 116 | " mem = prep_for_plot(mem)\n", 117 | " spk_out = prep_for_plot(spk_out)\n", 118 | " target_mem = prep_for_plot(target_mem)\n", 119 | " epoch1_str = str(epoch1)\n", 120 | " epoch2_str = str(epoch2)\n", 121 | " epoch3_str = str(epoch3)\n", 122 | "\n", 123 | " fontsize = 25\n", 124 | "\n", 125 | " ########### TARGET ########\n", 126 | " # Plot membrane potential\n", 127 | " ax1.plot(target_mem)\n", 128 | " ax1.set_ylim([y1, y2]) # 0.1, 1.3\n", 129 | " ax1.set_ylabel(\"$u$\", fontsize=fontsize, fontweight='bold')\n", 130 | " ax1.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 131 | " ax1.set_yticks([])\n", 132 | " ax1.set_xticks([])\n", 133 | " ax1.set_title(\"Target\",fontsize=fontsize, fontweight='bold')\n", 134 | " # plt.xlabel(\"Time\") \n", 135 | "\n", 136 | " # Plot output spike using spikeplot\n", 137 | " splt.raster(spk_target, ax2, s=250, c=\"black\", marker=\".\")\n", 138 | " ax2.set_ylabel(\"$z$\", fontsize=fontsize, fontweight='bold')\n", 139 | " ax2.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n", 140 | " ax2.set_yticks([]) \n", 141 | " ax2.set_xticks([])\n", 142 | " ax2.set_xlim(0, 100)\n", 143 | "\n", 144 | " ############## EPOCH 1 ########\n", 145 | "\n", 146 | " # Plot membrane potential\n", 147 | " ax3.plot(mem[epoch1])\n", 148 | " ax3.set_ylim([y1, y2]) # 0.1, 1.3\n", 149 | " ax3.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 150 | " ax3.set_yticks([])\n", 151 | " ax3.set_xticks([])\n", 152 | " ax3.set_title(\"$\\gamma =$\" + epoch1_str ,fontsize=fontsize, fontweight='bold')\n", 153 | " # 
plt.xlabel(\"Time\") \n", 154 | "\n", 155 | " # Plot output spike using spikeplot\n", 156 | " splt.raster(spk_out[epoch1], ax4, s=250, c=\"black\", marker=\".\")\n", 157 | " # ax4.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n", 158 | " ax4.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n", 159 | " ax4.set_yticks([]) \n", 160 | " ax4.set_xticks([])\n", 161 | " ax4.set_xlim(0, 100)\n", 162 | "\n", 163 | " ############## EPOCH 100 ########\n", 164 | "\n", 165 | " # Plot membrane potential\n", 166 | " ax5.plot(mem[epoch2])\n", 167 | " ax5.set_ylim([y1, y2]) # 0.1, 1.3\n", 168 | " # ax5.set_ylabel(\"$u^{~j}_t$\", fontsize=fontsize, fontweight='bold')\n", 169 | " ax5.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 170 | " ax5.set_yticks([])\n", 171 | " ax5.set_xticks([]) \n", 172 | " ax5.set_title(\"$\\gamma =$\" + epoch2_str,fontsize=fontsize, fontweight='bold')\n", 173 | " # plt.xlabel(\"Time\") \n", 174 | "\n", 175 | " # Plot output spike using spikeplot\n", 176 | " splt.raster(spk_out[epoch2], ax6, s=250, c=\"black\", marker=\".\")\n", 177 | " # ax6.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n", 178 | " ax6.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n", 179 | " ax6.set_yticks([]) \n", 180 | " ax6.set_xticks([])\n", 181 | " ax6.set_xlim(0, 100)\n", 182 | "\n", 183 | " ########## EPOCH 100 ##############\n", 184 | " # Plot membrane potential\n", 185 | " ax7.plot(mem[epoch3])\n", 186 | " ax7.set_ylim([y1, y2]) # 0.1, 1.3\n", 187 | " # ax7.set_ylabel(\"$u^{~j}_t$\", fontsize=fontsize, fontweight='bold')\n", 188 | " ax7.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 189 | " ax7.set_yticks([])\n", 190 | " ax7.set_xticks([])\n", 191 | " ax7.set_title(\"$\\gamma =$\" + epoch3_str,fontsize=fontsize, fontweight='bold')\n", 192 | " # plt.xlabel(\"Time\") \n", 193 | "\n", 194 | " # Plot output spike using spikeplot\n", 195 | " 
splt.raster(spk_out[epoch3], ax8, s=250, c=\"black\", marker=\".\")\n", 196 | " # ax8.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n", 197 | " ax8.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n", 198 | " ax8.set_yticks([]) \n", 199 | " ax8.set_xticks([])\n", 200 | " ax8.set_xlim(0, 100)\n", 201 | " \n", 202 | " fig.tight_layout()\n", 203 | " plt.subplots_adjust(\n", 204 | " # left=0.125,\n", 205 | " # bottom=0.1, \n", 206 | " # right=0.9, \n", 207 | " # top=0.9, \n", 208 | " wspace=0.01, \n", 209 | " hspace=0.)\n", 210 | " \n", 211 | " if fill:\n", 212 | " ax1.fill_between(x, target_mem, step=\"pre\", alpha=0.4, color='tab:blue')\n", 213 | " ax3.fill_between(x, mem[epoch1], step=\"pre\", alpha=0.4, color='tab:blue')\n", 214 | " ax5.fill_between(x, mem[epoch2], step=\"pre\", alpha=0.4, color='tab:blue')\n", 215 | " ax7.fill_between(x, mem[epoch3], step=\"pre\", alpha=0.4, color='tab:blue')\n", 216 | "\n", 217 | " fig1 = plt.gcf()\n", 218 | " if save:\n", 219 | " fig1.savefig(save, dpi=600)\n", 220 | "\n", 221 | " plt.show()\n", 222 | "\n", 223 | "\n", 224 | "def plot_quadrant_tha(mem, spk_out, target_mem, spk_target, y1, y2, \n", 225 | " threshold=[1, 1, 1], save=False, epoch1 = 1, epoch2=25, \n", 226 | " epoch3=100, fill=True):\n", 227 | " # Generate Plots\n", 228 | " gs = gridspec.GridSpec(2, 4, height_ratios=[1, 0.07])\n", 229 | " fig = plt.figure(figsize=(12,4.5),)\n", 230 | " ax1 = plt.subplot(gs[0,0])\n", 231 | " ax2 = plt.subplot(gs[1,0])\n", 232 | " ax3 = plt.subplot(gs[0,1])\n", 233 | " ax4 = plt.subplot(gs[1,1])\n", 234 | " ax5 = plt.subplot(gs[0,2])\n", 235 | " ax6 = plt.subplot(gs[1,2])\n", 236 | " ax7 = plt.subplot(gs[0,3])\n", 237 | " ax8 = plt.subplot(gs[1,3])\n", 238 | "\n", 239 | " mem = prep_for_plot(mem)\n", 240 | " spk_out = prep_for_plot(spk_out)\n", 241 | " target_mem = prep_for_plot(target_mem)\n", 242 | " epoch1_str = str(epoch1)\n", 243 | " epoch2_str = str(epoch2)\n", 244 | " epoch3_str = str(epoch3)\n", 
245 | "\n", 246 | " fontsize = 25\n", 247 | "\n", 248 | " ########### TARGET ########\n", 249 | " # Plot membrane potential\n", 250 | " ax1.plot(target_mem)\n", 251 | " ax1.set_ylim([y1, y2]) # 0.1, 1.3\n", 252 | " ax1.set_ylabel(\"$u$\", fontsize=fontsize, fontweight='bold')\n", 253 | " ax1.axhline(y=threshold[999], alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 254 | " ax1.set_yticks([])\n", 255 | " ax1.set_xticks([])\n", 256 | " ax1.set_title(\"Target\",fontsize=fontsize, fontweight='bold')\n", 257 | " # plt.xlabel(\"Time\") \n", 258 | "\n", 259 | " # Plot output spike using spikeplot\n", 260 | " splt.raster(spk_target, ax2, s=250, c=\"black\", marker=\".\")\n", 261 | " ax2.set_ylabel(\"$z$\", fontsize=fontsize, fontweight='bold')\n", 262 | " ax2.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n", 263 | " ax2.set_yticks([]) \n", 264 | " ax2.set_xticks([])\n", 265 | " ax2.set_xlim(0, 100)\n", 266 | "\n", 267 | " ############## EPOCH 1 ########\n", 268 | "\n", 269 | " # Plot membrane potential\n", 270 | " ax3.plot(mem[epoch1])\n", 271 | " ax3.set_ylim([y1, y2]) # 0.1, 1.3\n", 272 | " ax3.axhline(y=threshold[epoch1], alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 273 | " ax3.set_yticks([])\n", 274 | " ax3.set_xticks([])\n", 275 | " ax3.set_title(\"$\\gamma =$\" + epoch1_str ,fontsize=fontsize, fontweight='bold')\n", 276 | " # plt.xlabel(\"Time\") \n", 277 | "\n", 278 | " # Plot output spike using spikeplot\n", 279 | " splt.raster(spk_out[epoch1], ax4, s=250, c=\"black\", marker=\".\")\n", 280 | " # ax4.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n", 281 | " ax4.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n", 282 | " ax4.set_yticks([]) \n", 283 | " ax4.set_xticks([])\n", 284 | " ax4.set_xlim(0, 100)\n", 285 | "\n", 286 | " ############## EPOCH 100 ########\n", 287 | "\n", 288 | " # Plot membrane potential\n", 289 | " ax5.plot(mem[epoch2])\n", 290 | " ax5.set_ylim([y1, y2]) # 0.1, 1.3\n", 
291 | " # ax5.set_ylabel(\"$u^{~j}_t$\", fontsize=fontsize, fontweight='bold')\n", 292 | " ax5.axhline(y=threshold[epoch2], alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 293 | " ax5.set_yticks([])\n", 294 | " ax5.set_xticks([]) \n", 295 | " ax5.set_title(\"$\\gamma =$\" + epoch2_str,fontsize=fontsize, fontweight='bold')\n", 296 | " # plt.xlabel(\"Time\") \n", 297 | "\n", 298 | " # Plot output spike using spikeplot\n", 299 | " splt.raster(spk_out[epoch2], ax6, s=250, c=\"black\", marker=\".\")\n", 300 | " # ax6.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n", 301 | " ax6.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n", 302 | " ax6.set_yticks([]) \n", 303 | " ax6.set_xticks([])\n", 304 | " ax6.set_xlim(0, 100)\n", 305 | "\n", 306 | " ########## EPOCH 100 ##############\n", 307 | " # Plot membrane potential\n", 308 | " ax7.plot(mem[epoch3])\n", 309 | " ax7.set_ylim([y1, y2]) # 0.1, 1.3\n", 310 | " # ax7.set_ylabel(\"$u^{~j}_t$\", fontsize=fontsize, fontweight='bold')\n", 311 | " ax7.axhline(y=threshold[epoch3], alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 312 | " ax7.set_yticks([])\n", 313 | " ax7.set_xticks([])\n", 314 | " ax7.set_title(\"$\\gamma =$\" + epoch3_str,fontsize=fontsize, fontweight='bold')\n", 315 | " # plt.xlabel(\"Time\") \n", 316 | "\n", 317 | " # Plot output spike using spikeplot\n", 318 | " splt.raster(spk_out[epoch3], ax8, s=250, c=\"black\", marker=\".\")\n", 319 | " # ax8.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n", 320 | " ax8.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n", 321 | " ax8.set_yticks([]) \n", 322 | " ax8.set_xticks([])\n", 323 | " ax8.set_xlim(0, 100)\n", 324 | " \n", 325 | " fig.tight_layout()\n", 326 | " plt.subplots_adjust(\n", 327 | " # left=0.125,\n", 328 | " # bottom=0.1, \n", 329 | " # right=0.9, \n", 330 | " # top=0.9, \n", 331 | " wspace=0.01, \n", 332 | " hspace=0.)\n", 333 | " \n", 334 | " if fill:\n", 335 | " 
ax1.fill_between(x, target_mem, step=\"pre\", alpha=0.4, color='tab:blue')\n", 336 | " ax3.fill_between(x, mem[epoch1], step=\"pre\", alpha=0.4, color='tab:blue')\n", 337 | " ax5.fill_between(x, mem[epoch2], step=\"pre\", alpha=0.4, color='tab:blue')\n", 338 | " ax7.fill_between(x, mem[epoch3], step=\"pre\", alpha=0.4, color='tab:blue')\n", 339 | "\n", 340 | " fig1 = plt.gcf()\n", 341 | " if save:\n", 342 | " fig1.savefig(save, dpi=600)\n", 343 | "\n", 344 | " plt.show()\n" 345 | ], 346 | "metadata": { 347 | "cellView": "form", 348 | "id": "RERn5ncNBF2Z" 349 | }, 350 | "execution_count": null, 351 | "outputs": [] 352 | }, 353 | { 354 | "cell_type": "markdown", 355 | "metadata": { 356 | "id": "1svHV-viQ_Ll" 357 | }, 358 | "source": [ 359 | "# 1. High Precision Testing\n", 360 | "## 1.1 Choose some random hyperparameters" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "metadata": { 366 | "id": "izqVw9L4UaWx" 367 | }, 368 | "source": [ 369 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", 370 | "dtype = torch.float\n", 371 | "num_steps = 100\n", 372 | "num_inputs = 100\n", 373 | "num_hidden = 1000\n", 374 | "batch_size = 1\n", 375 | "beta=0.6\n", 376 | "spike_time = 75\n", 377 | "\n", 378 | "loss_fn = nn.MSELoss() " 379 | ], 380 | "execution_count": null, 381 | "outputs": [] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "source": [ 386 | "def set_all_seeds(seed=0):\n", 387 | " random.seed(seed)\n", 388 | " os.environ[\"PYTHONHASHSEED\"] = str(seed)\n", 389 | " np.random.seed(seed)\n", 390 | " torch.manual_seed(seed)\n", 391 | " torch.cuda.manual_seed(seed)\n", 392 | " torch.backends.cudnn.deterministic = True\n", 393 | "\n", 394 | "set_all_seeds()" 395 | ], 396 | "metadata": { 397 | "id": "1npF4uSpAbLf" 398 | }, 399 | "execution_count": null, 400 | "outputs": [] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": { 405 | "id": "9rQJ71afRqwo" 406 | }, 407 | "source": [ 408 | "## 1.2 Generate 
Random Inputs and Membrane Trace Target\n", 409 | "* The random inputs will be fed to the network\n", 410 | "* The output neuron will be trained to replicate the evolution of the membrane trace generated below" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "metadata": { 416 | "id": "dGXVx6u-UsuL" 417 | }, 418 | "source": [ 419 | "input_prob = torch.rand(num_steps, batch_size, num_inputs).to(device)\n", 420 | "input_data = spikegen.rate(input_prob, time_var_input=True)\n", 421 | "target_mem = spikegen.targets_latency(torch.zeros(1, dtype=dtype, device=device), num_classes=1, first_spike_time=75, on_target=1.05, num_steps=100, interpolate=True)" 422 | ], 423 | "execution_count": null, 424 | "outputs": [] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "metadata": { 429 | "id": "8MUE72R0bSSU" 430 | }, 431 | "source": [ 432 | "# membrane trace target: Threshold=1\n", 433 | "splt.traces(target_mem, spk=False, dim=(1,1), spk_height=1)" 434 | ], 435 | "execution_count": null, 436 | "outputs": [] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": { 441 | "id": "638dTiPNRzuc" 442 | }, 443 | "source": [ 444 | "## 1.3 Define network\n", 445 | "100-1000-1 Dense Network" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "metadata": { 451 | "id": "k42uNpyxzsCb" 452 | }, 453 | "source": [ 454 | "net = nn.Sequential(\n", 455 | " nn.Linear(num_inputs, num_hidden),\n", 456 | " snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True),\n", 457 | " nn.Linear(num_hidden, 1),\n", 458 | " snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True, output=True)\n", 459 | ").to(device)" 460 | ], 461 | "execution_count": null, 462 | "outputs": [] 463 | }, 464 | { 465 | "cell_type": "markdown", 466 | "metadata": { 467 | "id": "hBLjAshUR1SM" 468 | }, 469 | "source": [ 470 | "## 1.4 High-precision training loop\n", 471 | "Start with high precision weights." 
472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "metadata": { 477 | "id": "-3fkyBYxU_h8" 478 | }, 479 | "source": [ 480 | "# optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999)) \n", 481 | "optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)\n", 482 | "num_epochs = 1000\n", 483 | "\n", 484 | "mem_tot = []\n", 485 | "spk_tot = []\n", 486 | "\n", 487 | "for epoch in tqdm(range(num_epochs)):\n", 488 | " mem_rec = []\n", 489 | " spk_rec = []\n", 490 | "\n", 491 | " utils.reset(net)\n", 492 | "\n", 493 | " for step in range(num_steps):\n", 494 | " spk, mem = net(input_data[step])\n", 495 | " mem_rec.append(mem)\n", 496 | " spk_rec.append(spk)\n", 497 | "\n", 498 | " mem_rec = torch.stack(mem_rec)\n", 499 | " mem_tot.append(mem_rec)\n", 500 | "\n", 501 | " spk_rec = torch.stack(spk_rec)\n", 502 | " spk_tot.append(spk_rec)\n", 503 | "\n", 504 | " # loss = loss_fn(targets_spike, mem_rec) + 2*loss_fn(targets_spike[75], mem_rec[75])+ 5e-1*sum(spk_rec) # full trace \n", 505 | " loss = loss_fn(target_mem, mem_rec) # + 2 * loss_fn(targets_spike[75], mem_rec[75]) # + 0*(torch.exp(sum(spk_rec))-1)\n", 506 | "\n", 507 | " # clear previously stored gradients\n", 508 | " optimizer.zero_grad()\n", 509 | "\n", 510 | " # calculate the gradients\n", 511 | " loss.backward()\n", 512 | "\n", 513 | " # weight update\n", 514 | " optimizer.step()\n", 515 | "\n", 516 | "mem_tot = torch.stack(mem_tot)\n", 517 | "spk_tot = torch.stack(spk_tot)" 518 | ], 519 | "execution_count": null, 520 | "outputs": [] 521 | }, 522 | { 523 | "cell_type": "markdown", 524 | "source": [ 525 | "## 1.5 Plot Membrane Potential\n", 526 | "$\\gamma$ refers to the training iteration." 
527 | ], 528 | "metadata": { 529 | "id": "aa6ioKwHuG9_" 530 | } 531 | }, 532 | { 533 | "cell_type": "code", 534 | "source": [ 535 | "plot_quadrant(mem_tot, spk_tot, target_mem, spk_target, -0.1, 1.2, threshold=1, save=\"spk_time_flt.png\", epoch1=1, epoch2=100, epoch3=500, fill=True) # save=\"spk_time_flt.png\"" 536 | ], 537 | "metadata": { 538 | "id": "ClAhLranCClW" 539 | }, 540 | "execution_count": null, 541 | "outputs": [] 542 | }, 543 | { 544 | "cell_type": "markdown", 545 | "metadata": { 546 | "id": "sL3ywHTumXY7" 547 | }, 548 | "source": [ 549 | "## 1.6 Evolution of membrane potential over training epochs" 550 | ] 551 | }, 552 | { 553 | "cell_type": "code", 554 | "metadata": { 555 | "id": "OEyHUZD5IcUU" 556 | }, 557 | "source": [ 558 | "threshold = 1\n", 559 | "fig, ax = plt.subplots()\n", 560 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n", 561 | "x = np.arange(0, 100, 1) \n", 562 | "\n", 563 | "ax.set_xlim(0, num_steps)\n", 564 | "ax.set_ylim(-0.5, 1.5)\n", 565 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n", 566 | " verticalalignment='top', transform=ax.transAxes, size='large')\n", 567 | "\n", 568 | "ax.set_ylabel('Membrane Potential ($u$)')\n", 569 | "ax.set_xlabel('Time Steps')\n", 570 | "ax.plot(target_mem[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n", 571 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 572 | "\n", 573 | "\n", 574 | "def animate(frame_num):\n", 575 | " line.set_data(x, mem_tot[frame_num, x, 0,0].cpu().detach().numpy())\n", 576 | " time_text.set_text(f'Epoch: {frame_num}')\n", 577 | "\n", 578 | " # ax.plot([], [], ' ', label=str(frame_num))\n", 579 | " # ax.legend(loc='upper right')\n", 580 | " return (line, time_text)\n", 581 | "\n", 582 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30)\n", 583 | "anim.save('spk_time_flt.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n", 584 | "\n", 585 | "video = 
anim.to_html5_video()\n", 586 | "html = display.HTML(video)\n", 587 | "display.display(html)\n", 588 | "plt.close() # avoid plotting a spare static plot" 589 | ], 590 | "execution_count": null, 591 | "outputs": [] 592 | }, 593 | { 594 | "cell_type": "markdown", 595 | "metadata": { 596 | "id": "_4RqawaA4BbU" 597 | }, 598 | "source": [ 599 | "# 2. Binarized Spike Timing: Threshold=1\n", 600 | "The high-precision simulation tracks the target membrane potential well. Some instability appears whenever a spike occurs, because the reset is discontinuous: when the neuron resets, training compensates for the sudden drop in membrane potential by increasing the weights.\n", 601 | "\n", 602 | "Now, let's test binarized spiking neural networks. \n", 603 | "Before introducing threshold annealing, we apply a threshold of $\\theta=1$ to all neurons. Each synaptic weight is binarized, so it can only take the value +1 or -1. \n", 604 | "We can expect the outcome to be extremely unstable.\n", 605 | "\n", 606 | "## 2.1 Binarized Functions\n" 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "metadata": { 612 | "id": "4dZkN83RbDP0" 613 | }, 614 | "source": [ 615 | "class BinaryLinear(nn.Linear):\n", 616 | " def forward(self, input):\n", 617 | " binary_weight = binarize(self.weight)\n", 618 | " if self.bias is None:\n", 619 | " return F.linear(input, binary_weight)\n", 620 | " else:\n", 621 | " return F.linear(input, binary_weight, self.bias)\n", 622 | "\n", 623 | " def reset_parameters(self):\n", 624 | " # Glorot initialization\n", 625 | " out_features, in_features = self.weight.size() # weight shape is (out, in)\n", 626 | " stdv = math.sqrt(1.5 / (in_features + out_features))\n", 627 | " self.weight.data.uniform_(-stdv, stdv)\n", 628 | " if self.bias is not None:\n", 629 | " self.bias.data.zero_()\n", 630 | "\n", 631 | " self.weight.lr_scale = 1.
/ stdv\n", 632 | "\n", 633 | "\n", 634 | "class BinarizeF(Function):\n", 635 | "\n", 636 | " @staticmethod\n", 637 | " def forward(ctx, input):\n", 638 | " output = input.new(input.size())\n", 639 | " output[input >= 0] = 1\n", 640 | " output[input < 0] = -1\n", 641 | " return output\n", 642 | "\n", 643 | " @staticmethod\n", 644 | " def backward(ctx, grad_output):\n", 645 | " grad_input = grad_output.clone()\n", 646 | " return grad_input\n", 647 | "\n", 648 | "# aliases\n", 649 | "binarize = BinarizeF.apply" 650 | ], 651 | "execution_count": null, 652 | "outputs": [] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "metadata": { 657 | "id": "pZ_v2xDcU6Hh" 658 | }, 659 | "source": [ 660 | "## 2.2 Hyperparameters" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "metadata": { 666 | "id": "Q1UYIj62U3Vm" 667 | }, 668 | "source": [ 669 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", 670 | "dtype = torch.float\n", 671 | "num_steps = 100\n", 672 | "num_inputs = 100\n", 673 | "num_hidden = 1000\n", 674 | "batch_size = 1\n", 675 | "beta=0.15 \n", 676 | "\n", 677 | "loss_fn = nn.MSELoss() " 678 | ], 679 | "execution_count": null, 680 | "outputs": [] 681 | }, 682 | { 683 | "cell_type": "markdown", 684 | "metadata": { 685 | "id": "WbfDFqYJm-95" 686 | }, 687 | "source": [ 688 | "## 2.3 Network Definition\n", 689 | "The same 100-1000-1 dense architecture is used throughout."
690 | ] 691 | }, 692 | { 693 | "cell_type": "code", 694 | "metadata": { 695 | "id": "vf1uJDT6dbst" 696 | }, 697 | "source": [ 698 | "b_net = nn.Sequential(\n", 699 | " BinaryLinear(num_inputs, num_hidden),\n", 700 | " snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True),\n", 701 | " BinaryLinear(num_hidden, 1),\n", 702 | " snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True, output=True)\n", 703 | ").to(device)" 704 | ], 705 | "execution_count": null, 706 | "outputs": [] 707 | }, 708 | { 709 | "cell_type": "markdown", 710 | "metadata": { 711 | "id": "bZHfDW88nAx5" 712 | }, 713 | "source": [ 714 | "## 2.4 Binarized Training Loop" 715 | ] 716 | }, 717 | { 718 | "cell_type": "code", 719 | "metadata": { 720 | "id": "jlUf-mt28z0h" 721 | }, 722 | "source": [ 723 | "optimizer = torch.optim.SGD(b_net.parameters(), lr=1e-3, momentum=0.9)\n", 724 | "num_epochs = 1000\n", 725 | "mem_tot_bin = []\n", 726 | "spk_tot_bin = []\n", 727 | "\n", 728 | "for epoch in tqdm(range(num_epochs)):\n", 729 | " mem_rec = []\n", 730 | " spk_rec = []\n", 731 | "\n", 732 | " utils.reset(b_net) # reset the hidden states of the binarized network at the start of each epoch\n", 733 | "\n", 734 | " for step in range(num_steps):\n", 735 | " spk, mem = b_net(input_data[step])\n", 736 | " mem_rec.append(mem)\n", 737 | " spk_rec.append(spk)\n", 738 | "\n", 739 | " spk_rec = torch.stack(spk_rec)\n", 740 | " mem_rec = torch.stack(mem_rec)\n", 741 | " mem_tot_bin.append(mem_rec)\n", 742 | " spk_tot_bin.append(spk_rec)\n", 743 | "\n", 744 | " loss = loss_fn(target_mem, mem_rec)\n", 745 | "\n", 746 | " # clear previously stored gradients\n", 747 | " optimizer.zero_grad()\n", 748 | "\n", 749 | " # calculate the gradients\n", 750 | " loss.backward()\n", 751 | "\n", 752 | " # weight update\n", 753 | " optimizer.step()\n", 754 | "\n", 755 | "mem_tot_bin = torch.stack(mem_tot_bin)\n", 756 | "spk_tot_bin = torch.stack(spk_tot_bin)" 757 | ], 758 | "execution_count": null, 759 | "outputs": [] 760 | }, 761 | { 762 | "cell_type":
"markdown", 763 | "source": [ 764 | "## 2.5 Plot Membrane Potential" 765 | ], 766 | "metadata": { 767 | "id": "iUHLpcIdu4wJ" 768 | } 769 | }, 770 | { 771 | "cell_type": "code", 772 | "source": [ 773 | "plot_quadrant(mem_tot_bin, spk_tot_bin, target_mem, spk_target, -0.1, 1.2, threshold=1, save='spk_time_bin.png', epoch1=0, epoch2=75, epoch3=750, fill=True) # save=\"spk_time_flt.png\"" 774 | ], 775 | "metadata": { 776 | "id": "PYewBqdZgWa0" 777 | }, 778 | "execution_count": null, 779 | "outputs": [] 780 | }, 781 | { 782 | "cell_type": "markdown", 783 | "source": [ 784 | "As expected, this doesn't look great.\n", 785 | "It resembles the pathological case described in Section 2 of the paper, where BSNNs struggle to capture both memory dynamics and spike propagation: no smooth membrane dynamics are visible above." 786 | ], 787 | "metadata": { 788 | "id": "7IB8TGFWu68L" 789 | } 790 | }, 791 | { 792 | "cell_type": "markdown", 793 | "metadata": { 794 | "id": "N8orKC1dntVS" 795 | }, 796 | "source": [ 797 | "## 2.6 Evolution of the membrane trace over training epochs" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "source": [ 803 | "threshold = 1\n", 804 | "fig, ax = plt.subplots()\n", 805 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n", 806 | "x = np.arange(0, 100, 1) \n", 807 | "\n", 808 | "ax.set_xlim(0, num_steps)\n", 809 | "ax.set_ylim(-0.5, 1.5)\n", 810 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n", 811 | " verticalalignment='top', transform=ax.transAxes, size='large')\n", 812 | "\n", 813 | "ax.set_ylabel('Membrane Potential ($u$)')\n", 814 | "ax.set_xlabel('Time Steps')\n", 815 | "ax.plot(target_mem[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n", 816 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 817 | "\n", 818 | "\n", 819 | "def animate(frame_num):\n", 820 | " line.set_data(x, mem_tot_bin[frame_num, x,
0,0].cpu().detach().numpy())\n", 821 | " time_text.set_text(f'Epoch: {frame_num}')\n", 822 | " return (line, time_text)\n", 823 | "\n", 824 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30)\n", 825 | "anim.save('spk_time_bin.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n", 826 | "\n", 827 | "video = anim.to_html5_video()\n", 828 | "html = display.HTML(video)\n", 829 | "display.display(html)\n", 830 | "plt.close() # avoid plotting a spare static plot" 831 | ], 832 | "metadata": { 833 | "id": "5OWG-Vzqhonn" 834 | }, 835 | "execution_count": null, 836 | "outputs": [] 837 | }, 838 | { 839 | "cell_type": "markdown", 840 | "metadata": { 841 | "id": "G8IpqgZvoeYT" 842 | }, 843 | "source": [ 844 | "## 2.7 Moving Average\n", 845 | "Perhaps the trend will be clearer if we smooth the membrane potential with a moving average over a window of `N = 5` time steps." 846 | ] 847 | }, 848 | { 849 | "cell_type": "code", 850 | "source": [ 851 | "threshold = 1\n", 852 | "fig, ax = plt.subplots()\n", 853 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n", 854 | "x = np.arange(0, 100, 1)\n", 855 | "\n", 856 | "N = 5 # moving-average window size (in time steps)\n", 857 | "mem_avg_bin = uniform_filter1d(mem_tot_bin.cpu().detach(), size=N, axis=1) # axis 1 indexes time steps\n", 858 | "\n", 859 | "ax.set_xlim(0, num_steps)\n", 860 | "ax.set_ylim(-0.5, 1.5)\n", 861 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n", 862 | " verticalalignment='top', transform=ax.transAxes, size='large')\n", 863 | "\n", 864 | "ax.set_ylabel('Membrane Potential ($u$)')\n", 865 | "ax.set_xlabel('Time Steps')\n", 866 | "ax.plot(target_mem[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n", 867 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 868 | "\n", 869 | "\n", 870 | "def animate(frame_num):\n", 871 | " line.set_data(x, mem_avg_bin[frame_num, x, 0,0])\n", 872 | " time_text.set_text(f'Epoch: {frame_num}')\n", 873 | "\n", 874 | " # ax.plot([], [], ' ',
label=str(frame_num))\n", 875 | " # ax.legend(loc='upper right')\n", 876 | " return (line, time_text)\n", 877 | "\n", 878 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30)\n", 879 | "anim.save('spk_time_bin_MVA.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n", 880 | "\n", 881 | "video = anim.to_html5_video()\n", 882 | "html = display.HTML(video)\n", 883 | "display.display(html)\n", 884 | "plt.close() # avoid plotting a spare static plot" 885 | ], 886 | "metadata": { 887 | "id": "vMwX-33biqey" 888 | }, 889 | "execution_count": null, 890 | "outputs": [] 891 | }, 892 | { 893 | "cell_type": "markdown", 894 | "source": [ 895 | "Perhaps not." 896 | ], 897 | "metadata": { 898 | "id": "zGS39QO9vf5L" 899 | } 900 | }, 901 | { 902 | "cell_type": "markdown", 903 | "metadata": { 904 | "id": "5lakzDLwG8K5" 905 | }, 906 | "source": [ 907 | "# 3. Wide threshold BNN\n", 908 | "If we use a large threshold, each spiking neuron has a wider dynamic range in its state space, which could enable more precise tuning.\n", 909 | "\n", 910 | "The problem is that if the threshold is too high, downstream spikes probably won't occur, so learning will also fail to take place. Let's set the threshold of all neurons to $\\theta=50$, a significant jump from $\\theta=1$. 
" 911 | ] 912 | }, 913 | { 914 | "cell_type": "markdown", 915 | "metadata": { 916 | "id": "EqAGq-h1pIe5" 917 | }, 918 | "source": [ 919 | "## 3.1 Hyperparameters" 920 | ] 921 | }, 922 | { 923 | "cell_type": "code", 924 | "metadata": { 925 | "id": "eB_7CDHhWqb_" 926 | }, 927 | "source": [ 928 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", 929 | "dtype = torch.float\n", 930 | "num_steps = 100\n", 931 | "num_inputs = 100\n", 932 | "num_hidden = 1000\n", 933 | "batch_size = 1\n", 934 | "beta=0.15\n", 935 | "w_thr1 = 50\n", 936 | "w_thr2 = 50 # 25 works well\n", 937 | "on_target = w_thr2 + w_thr2*0.1\n", 938 | "first_spike_time = 75\n", 939 | "\n", 940 | "loss_fn = nn.MSELoss() \n", 941 | "# loss_fn = nn.CrossEntropyLoss()" 942 | ], 943 | "execution_count": null, 944 | "outputs": [] 945 | }, 946 | { 947 | "cell_type": "markdown", 948 | "metadata": { 949 | "id": "VUkd2Y6MpAg2" 950 | }, 951 | "source": [ 952 | "## 3.2 Define target" 953 | ] 954 | }, 955 | { 956 | "cell_type": "code", 957 | "metadata": { 958 | "id": "YT-ZdMAJbr6K" 959 | }, 960 | "source": [ 961 | "targets_wthr = spikegen.targets_latency(torch.zeros(1, dtype=dtype, device=device), num_classes=1, first_spike_time=first_spike_time, on_target=on_target, num_steps=num_steps, interpolate=True)" 962 | ], 963 | "execution_count": null, 964 | "outputs": [] 965 | }, 966 | { 967 | "cell_type": "markdown", 968 | "metadata": { 969 | "id": "3aTkUWNBpLSs" 970 | }, 971 | "source": [ 972 | "## 3.3 Define network" 973 | ] 974 | }, 975 | { 976 | "cell_type": "code", 977 | "metadata": { 978 | "id": "suVpw-rm9DKa" 979 | }, 980 | "source": [ 981 | "wthr_net = nn.Sequential(\n", 982 | " BinaryLinear(num_inputs, num_hidden),\n", 983 | " snn.Leaky(beta=beta, threshold=w_thr1, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True),\n", 984 | " BinaryLinear(num_hidden, 1),\n", 985 | " snn.Leaky(beta=beta, threshold=w_thr2, spike_grad=surrogate.fast_sigmoid(slope=5), 
init_hidden=True, output=True)\n", 986 | ").to(device)" 987 | ], 988 | "execution_count": null, 989 | "outputs": [] 990 | }, 991 | { 992 | "cell_type": "markdown", 993 | "metadata": { 994 | "id": "d3UEhuiqpM4T" 995 | }, 996 | "source": [ 997 | "## 3.4 Training Loop" 998 | ] 999 | }, 1000 | { 1001 | "cell_type": "code", 1002 | "metadata": { 1003 | "id": "k6LqaEDKHElP" 1004 | }, 1005 | "source": [ 1006 | "optimizer = torch.optim.SGD(wthr_net.parameters(), lr=1e-3, momentum=0.9)\n", 1007 | "num_epochs = 1000\n", 1008 | "mem_tot_wthr = []\n", 1009 | "spk_tot_wthr = []\n", 1010 | "\n", 1011 | "for epoch in tqdm(range(num_epochs)):\n", 1012 | " mem_rec = []\n", 1013 | " spk_rec = []\n", 1014 | "\n", 1015 | " utils.reset(wthr_net) # reset the hidden states of the wide-threshold network at the start of each epoch\n", 1016 | "\n", 1017 | " for step in range(num_steps):\n", 1018 | " spk, mem = wthr_net(input_data[step])\n", 1019 | " mem_rec.append(mem)\n", 1020 | " spk_rec.append(spk)\n", 1021 | "\n", 1022 | " spk_rec = torch.stack(spk_rec)\n", 1023 | " mem_rec = torch.stack(mem_rec)\n", 1024 | " spk_tot_wthr.append(spk_rec)\n", 1025 | " mem_tot_wthr.append(mem_rec)\n", 1026 | "\n", 1027 | " loss = loss_fn(targets_wthr, mem_rec)\n", 1028 | "\n", 1029 | " # clear previously stored gradients\n", 1030 | " optimizer.zero_grad()\n", 1031 | "\n", 1032 | " # calculate the gradients\n", 1033 | " loss.backward()\n", 1034 | "\n", 1035 | " # weight update\n", 1036 | " optimizer.step()\n", 1037 | "\n", 1038 | "mem_tot_wthr = torch.stack(mem_tot_wthr)\n", 1039 | "spk_tot_wthr = torch.stack(spk_tot_wthr)" 1040 | ], 1041 | "execution_count": null, 1042 | "outputs": [] 1043 | }, 1044 | { 1045 | "cell_type": "markdown", 1046 | "source": [ 1047 | "## 3.5 Plot Membrane Potential" 1048 | ], 1049 | "metadata": { 1050 | "id": "k9aHG9Olv1rT" 1051 | } 1052 | }, 1053 | { 1054 | "cell_type": "code", 1055 | "source": [ 1056 | "plot_quadrant(mem_tot_wthr, spk_tot_wthr, targets_wthr, spk_target, -1, 60, threshold=50, save=\"spk_time_wthr.png\", epoch1=0, epoch2=75, epoch3=750, fill=True) # 
save=\"spk_time_flt.png\"" 1057 | ], 1058 | "metadata": { 1059 | "id": "lp4uJVMnkpDH" 1060 | }, 1061 | "execution_count": null, 1062 | "outputs": [] 1063 | }, 1064 | { 1065 | "cell_type": "markdown", 1066 | "metadata": { 1067 | "id": "KGttbe8epcvp" 1068 | }, 1069 | "source": [ 1070 | "## 3.6 Animation of membrane potential\n", 1071 | "\n", 1072 | "This result doesn't fluctuate, but neither does it produce the desired behavior of spiking at the 75th time step: in fact, no spikes are produced at all.\n", 1073 | "\n", 1074 | "The membrane potential stays constant over time, which indicates that the output neuron does not receive any spikes from the previous layer. Rather, it is the bias that drives the second layer. The bias slowly increases until the membrane potential settles at roughly half the threshold, which minimizes the overall loss over time.\n", 1075 | "\n", 1076 | "If the bias were removed, the membrane potential would be stuck at zero. So clearly, this doesn't quite work either.\n", 1077 | "\n", 1078 | "Note that the membrane potential falls just short of 25. This can be explained by the final steps of the target being set to 0, which suppresses the overall steady-state response."
1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "code", 1083 | "source": [ 1084 | "threshold = 50\n", 1085 | "fig, ax = plt.subplots()\n", 1086 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n", 1087 | "x = np.arange(0, 100, 1) \n", 1088 | "\n", 1089 | "ax.set_xlim(0, num_steps)\n", 1090 | "ax.set_ylim(-1, 60)\n", 1091 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n", 1092 | " verticalalignment='top', transform=ax.transAxes, size='large')\n", 1093 | "\n", 1094 | "ax.set_ylabel('Membrane Potential ($u$)')\n", 1095 | "ax.set_xlabel('Time Steps')\n", 1096 | "ax.plot(targets_wthr[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n", 1097 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 1098 | "\n", 1099 | "\n", 1100 | "def animate(frame_num):\n", 1101 | " line.set_data(x, mem_tot_wthr[frame_num, x, 0,0].cpu().detach().numpy())\n", 1102 | " time_text.set_text(f'Epoch: {frame_num}')\n", 1103 | "\n", 1104 | " # ax.plot([], [], ' ', label=str(frame_num))\n", 1105 | " # ax.legend(loc='upper right')\n", 1106 | " return (line, time_text)\n", 1107 | "\n", 1108 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30)\n", 1109 | "anim.save('spk_time_wthr.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n", 1110 | "\n", 1111 | "video = anim.to_html5_video()\n", 1112 | "html = display.HTML(video)\n", 1113 | "display.display(html)\n", 1114 | "plt.close() # avoid plotting a spare static plot" 1115 | ], 1116 | "metadata": { 1117 | "id": "lNN-KZ_Yk620" 1118 | }, 1119 | "execution_count": null, 1120 | "outputs": [] 1121 | }, 1122 | { 1123 | "cell_type": "markdown", 1124 | "metadata": { 1125 | "id": "KoGgy9igi5WE" 1126 | }, 1127 | "source": [ 1128 | "# 4. 
Bounded Homeostasis\n", 1129 | "## 4.1 Define Threshold Annealing Function\n", 1130 | "\n", 1131 | "If we slowly anneal the threshold from a small value to a larger one, the network shows strong spiking activity in early epochs, which avoids the dead neuron problem we saw in the previous case where $\\theta=50$.\n", 1132 | "\n", 1133 | "Below, we implement the most naive form of bounded homeostasis: one that does not depend on the weight update gradient (unlike the other experiments) and can simply be referred to as *threshold annealing*. After every epoch, the threshold relaxes exponentially toward a steady state via $\\theta \\leftarrow \\theta + \\alpha(\\theta_{\\mathrm{final}} - \\theta)$, completely independent of the input data. With the hyperparameters chosen below, both layers follow the same schedule, so the same threshold is applied to all neurons in all layers." 1134 | ] 1135 | }, 1136 | { 1137 | "cell_type": "code", 1138 | "metadata": { 1139 | "id": "bfrH7cqSjEs3" 1140 | }, 1141 | "source": [ 1142 | "def thr_annealing(conf, network):\n", 1143 | " alpha_thr1 = conf['alpha_thr1']\n", 1144 | " alpha_thr2 = conf['alpha_thr2']\n", 1145 | "\n", 1146 | " thr_final1 = conf['thr_final1']\n", 1147 | " thr_final2 = conf['thr_final2']\n", 1148 | "\n", 1149 | " with torch.no_grad():\n", 1150 | " network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1\n", 1151 | " network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2" 1152 | ], 1153 | "execution_count": null, 1154 | "outputs": [] 1155 | }, 1156 | { 1157 | "cell_type": "markdown", 1158 | "metadata": { 1159 | "id": "k7KBgUEFqQEb" 1160 | }, 1161 | "source": [ 1162 | "## 4.2 Define Hyperparameters\n", 1163 | "As before, we set the final threshold to 50, but this time we start at 5.0 and gradually warm it up to 50. `alpha_thr1` and `alpha_thr2` are the inverse time constants of the threshold evolution: after $n$ epochs, $\\theta_n = \\theta_{\\mathrm{final}} - (\\theta_{\\mathrm{final}} - \\theta_{\\mathrm{init}})(1-\\alpha)^n$, so with $\\alpha = 5 \\times 10^{-3}$ the threshold reaches roughly 43.9 after 400 epochs."
1164 | ] 1165 | }, 1166 | { 1167 | "cell_type": "code", 1168 | "metadata": { 1169 | "id": "pV-iBTtCjTAV" 1170 | }, 1171 | "source": [ 1172 | "config = {\n", 1173 | " \n", 1174 | " 'thr_init1' : 5.0,\n", 1175 | " 'thr_init2' : 5.0,\n", 1176 | "\n", 1177 | " 'alpha_thr1' : 5e-3,\n", 1178 | " 'alpha_thr2' : 5e-3,\n", 1179 | "\n", 1180 | " 'thr_final1' : 50.0,\n", 1181 | " 'thr_final2' : 50.0,\n", 1182 | "}" 1183 | ], 1184 | "execution_count": null, 1185 | "outputs": [] 1186 | }, 1187 | { 1188 | "cell_type": "code", 1189 | "metadata": { 1190 | "id": "99WfOLJWXsop" 1191 | }, 1192 | "source": [ 1193 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", 1194 | "dtype = torch.float\n", 1195 | "num_steps = 100\n", 1196 | "num_inputs = 100\n", 1197 | "num_hidden = 1000\n", 1198 | "batch_size = 1\n", 1199 | "beta=0.15\n", 1200 | "on_target = config['thr_final2'] + config['thr_final2']*0.1\n", 1201 | "\n", 1202 | "loss_fn = nn.MSELoss() \n", 1203 | "# loss_fn = nn.CrossEntropyLoss()" 1204 | ], 1205 | "execution_count": null, 1206 | "outputs": [] 1207 | }, 1208 | { 1209 | "cell_type": "markdown", 1210 | "metadata": { 1211 | "id": "0hCzxrlbqmS5" 1212 | }, 1213 | "source": [ 1214 | "## 4.3 Define Target" 1215 | ] 1216 | }, 1217 | { 1218 | "cell_type": "code", 1219 | "metadata": { 1220 | "id": "g8T0SawYc3vI" 1221 | }, 1222 | "source": [ 1223 | "targets_tha = spikegen.targets_latency(torch.zeros(1, dtype=dtype, device=device), num_classes=1, first_spike_time=75, on_target=on_target, num_steps=num_steps, interpolate=True)" 1224 | ], 1225 | "execution_count": null, 1226 | "outputs": [] 1227 | }, 1228 | { 1229 | "cell_type": "markdown", 1230 | "metadata": { 1231 | "id": "5CBSUwJxYjaU" 1232 | }, 1233 | "source": [ 1234 | "## 4.4 Define network" 1235 | ] 1236 | }, 1237 | { 1238 | "cell_type": "code", 1239 | "metadata": { 1240 | "id": "upUaiEzci-Dk" 1241 | }, 1242 | "source": [ 1243 | "class Net(nn.Module):\n", 1244 | " def __init__(self):\n", 
1245 | " super().__init__()\n", 1246 | "\n", 1247 | " beta = 0.15\n", 1248 | " spike_grad = surrogate.fast_sigmoid(slope=5)\n", 1249 | "\n", 1250 | " self.fc1 = BinaryLinear(num_inputs, num_hidden)\n", 1251 | " self.fc2 = BinaryLinear(num_hidden, 1)\n", 1252 | "\n", 1253 | " self.lif1 = snn.Leaky(beta=beta, threshold=config['thr_init1'], spike_grad = spike_grad)\n", 1254 | " self.lif2 = snn.Leaky(beta=beta, threshold=config['thr_init2'], spike_grad=spike_grad)\n", 1255 | "\n", 1256 | " def forward(self, x):\n", 1257 | " mem1 = self.lif1.init_leaky() \n", 1258 | " mem2 = self.lif2.init_leaky() \n", 1259 | "\n", 1260 | " spk2_rec = []\n", 1261 | " mem2_rec = []\n", 1262 | "\n", 1263 | " for step in range(x.size(0)):\n", 1264 | " cur1 = self.fc1(x[step])\n", 1265 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 1266 | " cur2 = self.fc2(spk1)\n", 1267 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 1268 | "\n", 1269 | " spk2_rec.append(spk2)\n", 1270 | " mem2_rec.append(mem2)\n", 1271 | " \n", 1272 | " return torch.stack(spk2_rec), torch.stack(mem2_rec)\n", 1273 | "\n", 1274 | "net_tha = Net().to(device)" 1275 | ], 1276 | "execution_count": null, 1277 | "outputs": [] 1278 | }, 1279 | { 1280 | "cell_type": "markdown", 1281 | "metadata": { 1282 | "id": "iqMrSdqrq0mW" 1283 | }, 1284 | "source": [ 1285 | "## 4.5 Training Loop" 1286 | ] 1287 | }, 1288 | { 1289 | "cell_type": "code", 1290 | "metadata": { 1291 | "id": "AS-4Wn10jA6d" 1292 | }, 1293 | "source": [ 1294 | "optimizer = torch.optim.SGD(net_tha.parameters(), lr=1e-3, momentum=0.9)\n", 1295 | "num_epochs = 1000\n", 1296 | "mem_tot_tha = []\n", 1297 | "spk_tot_tha = []\n", 1298 | "thr_L1 = []\n", 1299 | "thr_L2 = []\n", 1300 | "\n", 1301 | "for epoch in tqdm(range(num_epochs)):\n", 1302 | "\n", 1303 | " spk_rec, mem_rec = net_tha(input_data)\n", 1304 | " spk_tot_tha.append(spk_rec)\n", 1305 | " mem_tot_tha.append(mem_rec)\n", 1306 | " loss = loss_fn(targets_tha, mem_rec)\n", 1307 | "\n", 1308 | " # clear previously stored 
gradients\n", 1309 | " optimizer.zero_grad()\n", 1310 | "\n", 1311 | " # calculate the gradients\n", 1312 | " loss.backward()\n", 1313 | "\n", 1314 | " # weight update\n", 1315 | " optimizer.step()\n", 1316 | "\n", 1317 | " thr_L1.append(net_tha.lif1.threshold.item())\n", 1318 | " thr_L2.append(net_tha.lif2.threshold.item())\n", 1319 | "\n", 1320 | " thr_annealing(config, net_tha)\n", 1321 | " \n", 1322 | "\n", 1323 | "mem_tot_tha = torch.stack(mem_tot_tha)\n", 1324 | "spk_tot_tha = torch.stack(spk_tot_tha)" 1325 | ], 1326 | "execution_count": null, 1327 | "outputs": [] 1328 | }, 1329 | { 1330 | "cell_type": "markdown", 1331 | "metadata": { 1332 | "id": "s8L8DJ7aq4Ov" 1333 | }, 1334 | "source": [ 1335 | "## 4.6 Plot Membrane Potential" 1336 | ] 1337 | }, 1338 | { 1339 | "cell_type": "code", 1340 | "source": [ 1341 | "plot_quadrant_tha(mem_tot_tha, spk_tot_tha, targets_tha, spk_target, -1, 60, threshold=thr_L1, save=\"spk_time_tha.png\", epoch1=0, epoch2=100, epoch3=400, fill=True) # save=\"spk_time_flt.png\"" 1342 | ], 1343 | "metadata": { 1344 | "id": "n166WYwtnBIL" 1345 | }, 1346 | "execution_count": null, 1347 | "outputs": [] 1348 | }, 1349 | { 1350 | "cell_type": "code", 1351 | "source": [ 1352 | "thr_L1[400]" 1353 | ], 1354 | "metadata": { 1355 | "id": "7YQULZmjhbhc" 1356 | }, 1357 | "execution_count": null, 1358 | "outputs": [] 1359 | }, 1360 | { 1361 | "cell_type": "markdown", 1362 | "source": [ 1363 | "This is looking quite nice as training progresses! Let's see the animated version to get better insight." 
1364 | ], 1365 | "metadata": { 1366 | "id": "x0Hr_q2Pxipv" 1367 | } 1368 | }, 1369 | { 1370 | "cell_type": "markdown", 1371 | "source": [ 1372 | "## 4.7 Animation of Membrane Potential" 1373 | ], 1374 | "metadata": { 1375 | "id": "4bAvxJ92xe7-" 1376 | } 1377 | }, 1378 | { 1379 | "cell_type": "code", 1380 | "source": [ 1381 | "threshold = 50\n", 1382 | "fig, ax = plt.subplots()\n", 1383 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n", 1384 | "thr_line, = ax.plot([])\n", 1385 | "thr_text1 = ax.text(0.98, 0.91,'',horizontalalignment='right',verticalalignment='top', transform=ax.transAxes, size='large')\n", 1386 | "x = np.arange(0, 100, 1) \n", 1387 | "\n", 1388 | "ax.set_xlim(0, num_steps)\n", 1389 | "ax.set_ylim(-1, 60)\n", 1390 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n", 1391 | " verticalalignment='top', transform=ax.transAxes, size='large')\n", 1392 | "\n", 1393 | "ax.set_ylabel('Membrane Potential ($u$)')\n", 1394 | "ax.set_xlabel('Time Steps')\n", 1395 | "ax.plot(targets_tha[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n", 1396 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 1397 | "\n", 1398 | "\n", 1399 | "def animate(frame_num):\n", 1400 | " line.set_data(x, mem_tot_tha[frame_num, x, 0,0].cpu().detach().numpy())\n", 1401 | " thr_line.set_data(x, thr_L1[frame_num])\n", 1402 | " thr_text1.set_text(f'Threshold: {thr_L1[frame_num]:.3f}')\n", 1403 | " time_text.set_text(f'Epoch: {frame_num}')\n", 1404 | "\n", 1405 | " # ax.plot([], [], ' ', label=str(frame_num))\n", 1406 | " # ax.legend(loc='upper right')\n", 1407 | " return (line, time_text, thr_text1)\n", 1408 | "\n", 1409 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30) # num_epochs\n", 1410 | "anim.save('spk_time_tha.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n", 1411 | "\n", 1412 | "video = anim.to_html5_video()\n", 1413 | "html = display.HTML(video)\n", 1414 | 
"display.display(html)\n", 1415 | "plt.close() # avoid plotting a spare static plot" 1416 | ], 1417 | "metadata": { 1418 | "id": "xTqE4_2TnV8c" 1419 | }, 1420 | "execution_count": null, 1421 | "outputs": [] 1422 | }, 1423 | { 1424 | "cell_type": "markdown", 1425 | "source": [ 1426 | "Early in training, while the threshold is still low, spikes trigger a sudden explosion in activity as the neuron tries to climb its way to $u=50$: sensory overload.\n", 1427 | "\n", 1428 | "As the threshold warms up further, activity becomes sparser until the neuron finally hits the desired firing time in several epochs." 1429 | ], 1430 | "metadata": { 1431 | "id": "tIZe7dLiyfwW" 1432 | } 1433 | }, 1434 | { 1435 | "cell_type": "markdown", 1436 | "source": [ 1437 | "## 4.8 Moving Average of Membrane" 1438 | ], 1439 | "metadata": { 1440 | "id": "PPckpEvxrn8R" 1441 | } 1442 | }, 1443 | { 1444 | "cell_type": "code", 1445 | "source": [ 1446 | "threshold = 50\n", 1447 | "fig, ax = plt.subplots()\n", 1448 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n", 1449 | "thr_line, = ax.plot([])\n", 1450 | "thr_text1 = ax.text(0.98, 0.91,'',horizontalalignment='right',verticalalignment='top', transform=ax.transAxes, size='large')\n", 1451 | "x = np.arange(0, 100, 1) \n", 1452 | "\n", 1453 | "\n", 1454 | "N = 5 # moving-average window size (in time steps)\n", 1455 | "mem_avg_tha = uniform_filter1d(mem_tot_tha.cpu().detach(), size=N, axis=1) # axis 1 indexes time steps\n", 1456 | "\n", 1457 | "ax.set_xlim(0, num_steps)\n", 1458 | "ax.set_ylim(-1, 60)\n", 1459 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n", 1460 | " verticalalignment='top', transform=ax.transAxes, size='large')\n", 1461 | "\n", 1462 | "ax.set_ylabel('Membrane Potential ($u$)')\n", 1463 | "ax.set_xlabel('Time Steps')\n", 1464 | "ax.plot(targets_tha[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n", 1465 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n", 1466 | "\n", 1467 | "\n", 1468 | "def
animate(frame_num):\n", 1469 | " line.set_data(x, mem_avg_tha[frame_num, x, 0,0])\n", 1470 | " thr_line.set_data(x, thr_L1[frame_num])\n", 1471 | " thr_text1.set_text(f'Threshold: {thr_L1[frame_num]:.3f}')\n", 1472 | " time_text.set_text(f'Epoch: {frame_num}')\n", 1473 | "\n", 1474 | " # ax.plot([], [], ' ', label=str(frame_num))\n", 1475 | " # ax.legend(loc='upper right')\n", 1476 | " return (line, time_text, thr_text1)\n", 1477 | "\n", 1478 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30) # num_epochs\n", 1479 | "anim.save('spk_time_tha_MVA.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n", 1480 | "\n", 1481 | "video = anim.to_html5_video()\n", 1482 | "html = display.HTML(video)\n", 1483 | "display.display(html)\n", 1484 | "plt.close() # avoid plotting a spare static plot" 1485 | ], 1486 | "metadata": { 1487 | "id": "je3u03vKrncP" 1488 | }, 1489 | "execution_count": null, 1490 | "outputs": [] 1491 | }, 1492 | { 1493 | "cell_type": "markdown", 1494 | "metadata": { 1495 | "id": "3H4v8Augr4wO" 1496 | }, 1497 | "source": [ 1498 | "Not only does learning take place, but the hyperparameter values were chosen completely arbitrarily: this was the first result we obtained when writing this notebook. A more precise result could likely be obtained by tuning the thresholds and annealing rates separately for each layer." 1499 | ] 1500 | } 1501 | ] 1502 | } --------------------------------------------------------------------------------