├── .gitignore
├── README.md
├── dvs
│   ├── Net.py
│   ├── bnn.py
│   ├── conf.py
│   ├── dataloader.py
│   ├── earlystopping.py
│   ├── functions.py
│   ├── run.py
│   ├── test_acc.py
│   ├── tha.py
│   └── train.py
├── figs
│   └── temporal_code.png
├── fmnist
│   ├── Net.py
│   ├── bnn.py
│   ├── conf.py
│   ├── dataloader.py
│   ├── earlystopping.py
│   ├── functions.py
│   ├── run.py
│   ├── test_acc.py
│   ├── tha.py
│   └── train.py
├── mnist
│   ├── Net.py
│   ├── bnn.py
│   ├── conf.py
│   ├── dataloader.py
│   ├── earlystopping.py
│   ├── functions.py
│   ├── run.py
│   ├── test_acc.py
│   ├── tha.py
│   └── train.py
├── requirements.txt
├── shd
│   ├── Net.py
│   ├── bnn.py
│   ├── conf.py
│   ├── dataloader.py
│   ├── earlystopping.py
│   ├── functions.py
│   ├── run.py
│   ├── test_acc.py
│   ├── tha.py
│   └── train.py
└── temporal
    └── bounded_homeostasis.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | instructions.txt
2 | __pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bounded Homeostasis in Binarized Spiking Neural Networks
2 |
3 |
18 |
19 | ## Requirements
20 | A working `Python` (≥3.6) interpreter and the `pip` package manager are required. All required libraries and packages can be installed using `pip install -r requirements.txt`. To avoid potential package conflicts, the use of a `conda` environment is recommended. The following commands can be used to create and activate a separate `conda` environment, clone this repository, and install all dependencies:
21 |
22 | ```
23 | conda create -n snn-tha python=3.8
24 | conda activate snn-tha
25 | git clone https://github.com/jeshraghian/snn-tha.git
26 | cd snn-tha
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ## Code Execution
31 | To execute the code, `cd` into one of the four dataset directories (`mnist`, `fmnist`, `dvs`, or `shd`), and then run `python run.py`.
32 |
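For example, to train on MNIST (the same pattern applies to `fmnist`, `dvs`, and `shd`):

```
cd mnist
python run.py
```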
33 | ## Hyperparameter Tuning
34 | * In each directory, `conf.py` defines all configuration parameters and hyperparameters for that dataset. The default parameters in this repo are those used for the binarized case with bounded homeostasis, as reported in the corresponding paper.
35 | * To run binarized networks, set `"binarize" : True` in `conf.py`, as shown in the snippet below. For optimized parameters, follow the values reported in the paper.
36 |
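For reference, the relevant entry in each `conf.py` dictionary looks like the following (a trimmed, illustrative excerpt; the full dictionaries are listed in each dataset directory). Setting the flag to `False` selects the full-precision branch of `Net.py` instead:

```
config = {
    # ...
    'binarize' : True,   # False -> full-precision weights in Net.py
    # ...
}
```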
37 |
38 | # Temporal Coding
39 | Section 4 of the paper demonstrates the use of bounded homeostasis (using threshold annealing as the warm-up technique) in a spike-timing task. A fully connected network of structure 100-1000-1 is used, where a Poisson spike train is passed at the input, and the output neuron is trained to spike at a target time by linearly ramping up the membrane potential over time, using a mean square error loss at each time step:
40 |
41 |
42 |
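A minimal sketch of such a 100-1000-1 network is shown below. It reuses the `snn.Leaky` pattern from the rest of this repository; the `beta` value and tensor shapes are illustrative assumptions rather than the exact settings used in `temporal/bounded_homeostasis.ipynb`:

```
import torch
import torch.nn as nn
import snntorch as snn

class TemporalNet(nn.Module):
    """100-1000-1 fully connected SNN; the output membrane potential is read out at every step."""
    def __init__(self, beta=0.9):
        super().__init__()
        self.fc1 = nn.Linear(100, 1000)
        self.lif1 = snn.Leaky(beta)
        self.fc2 = nn.Linear(1000, 1)
        self.lif2 = snn.Leaky(beta)

    def forward(self, x):  # x: [num_steps, batch, 100] Poisson spike train
        mem1, mem2 = self.lif1.init_leaky(), self.lif2.init_leaky()
        mem2_rec = []
        for step in range(x.size(0)):
            spk1, mem1 = self.lif1(self.fc1(x[step]), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            mem2_rec.append(mem2)
        # an MSE loss against a linearly ramping target can then be applied to this trace at each step
        return torch.stack(mem2_rec, dim=0)
```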
43 | The animated versions of the above figures are provided below, and can be reproduced in the corresponding notebook (`temporal/bounded_homeostasis.ipynb`).
44 |
45 | ## Animations
46 |
47 | ### High Precision Weights, Normalized Threshold
48 |
49 | This is the optimal baseline, and it shows that the task is reasonably straightforward to achieve with high precision weights.
50 |
51 | https://user-images.githubusercontent.com/40262130/150855093-4cdaa55b-7cad-4d5a-b5fa-9e482c6fe07e.mp4
52 |
53 | ### Binarized Weights, Normalized Threshold
54 | The results become significantly more unstable once the weights are binarized.
55 |
56 | https://user-images.githubusercontent.com/40262130/150855727-9ccfcca2-8b48-48cc-b5df-0d17f367968c.mp4
57 |
58 | A moving average over training iterations is applied in an attempt to clean up the above plot, but the results remain largely meaningless:
59 |
60 | https://user-images.githubusercontent.com/40262130/150855822-02d9177c-e08f-48f4-8753-d5c937e49c00.mp4
61 |
62 | ### Binarized Weights, Large Threshold
63 | Increasing the threshold of all neurons provides a state-space with a higher dynamic range, but increasing the threshold too far leads to the dead neuron problem. The animation below shows how spiking activity has been suppressed; the flat membrane potential is purely a result of the bias.
64 |
65 | https://user-images.githubusercontent.com/40262130/150856229-0a3ae7ce-5670-4545-b13c-06dd3ca992f3.mp4
66 |
67 | ### Binarized Weights, Threshold Annealing
68 | Threshold annealing is now applied, using an evolving neuronal state-space to gradually lift spiking activity. This avoids the dead neuron problem of the large-threshold case, as well as the instability and memory leakage of the normalized-threshold case.
69 |
70 | https://user-images.githubusercontent.com/40262130/150856483-f53f2156-4348-46da-9c0f-5f05f31cf677.mp4
71 |
72 | This now looks far more functional than all of the previous binarized cases.
73 | We can take a moving average to smooth out the impact of the sudden reset dynamics. Although not as clean as the high precision case, the binarized SNN continues to learn despite the excessively high final threshold.
74 |
75 | https://user-images.githubusercontent.com/40262130/150856726-aedb1d08-fe61-4b32-a3aa-6dcc9c76311a.mp4
76 |
77 |
--------------------------------------------------------------------------------
/dvs/Net.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import surrogate
4 |
5 | # torch
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # local
11 | from bnn import *
12 |
13 |
14 | class Net(nn.Module):
15 | def __init__(self, config):
16 | super().__init__()
17 |
18 | self.thr1 = config['threshold1']
19 | self.thr2 = config['threshold2']
20 | self.thr3 = config['threshold3']
21 | slope = config['slope']
22 | beta = config['beta']
23 | self.num_steps = config['num_steps']
24 | self.batch_norm = config['batch_norm']
25 | p1 = config['dropout1']
26 | self.binarize = config['binarize']
27 |
28 |
29 | spike_grad = surrogate.fast_sigmoid(slope)
30 | # Initialize layers with spike operator
31 | self.bconv1 = BinaryConv2d(2, 16, 5, bias=False)
32 | self.conv1 = nn.Conv2d(2, 16, 5, bias=False)
33 | self.conv1_bn = nn.BatchNorm2d(16)
34 | self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)
35 | self.bconv2 = BinaryConv2d(16, 32, 5, bias=False)
36 | self.conv2 = nn.Conv2d(16, 32, 5, bias=False)
37 | self.conv2_bn = nn.BatchNorm2d(32)
38 | self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)
39 | self.bfc1 = BinaryLinear(32 * 5 * 5, 11)
40 | self.fc1 = nn.Linear(32 * 5 * 5, 11)
41 | self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad)
42 | self.dropout = nn.Dropout(p1)
43 |
44 |
45 | def forward(self, x):
46 |
47 | # Initialize hidden states and outputs at t=0
48 | mem1 = self.lif1.init_leaky()
49 | mem2 = self.lif2.init_leaky()
50 | mem3 = self.lif3.init_leaky()
51 |
52 | # Record the final layer
53 | spk3_rec = []
54 | mem3_rec = []
55 |
56 | # Binarization
57 |
58 | if self.binarize:
59 |
60 | for step in range(x.size(0)):
61 |
62 | # fc1weight = self.fc1.weight.data
63 | cur1 = F.avg_pool2d(self.bconv1(x[step]), 2)
64 | if self.batch_norm:
65 | cur1 = self.conv1_bn(cur1)
66 | spk1, mem1 = self.lif1(cur1, mem1)
67 | cur2 = F.avg_pool2d(self.bconv2(spk1), 2)
68 | if self.batch_norm:
69 | cur2 = self.conv2_bn(cur2)
70 | spk2, mem2 = self.lif2(cur2, mem2)
71 |
72 | cur3 = self.dropout(self.bfc1(spk2.flatten(1)))
73 | spk3, mem3 = self.lif3(cur3, mem3)
74 |
75 | spk3_rec.append(spk3)
76 | mem3_rec.append(mem3)
77 |
78 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
79 |
80 | # Full Precision
81 |
82 | else:
83 |
84 | for step in range(x.size(0)):
85 | # fc1weight = self.fc1.weight.data
86 | cur1 = F.avg_pool2d(self.conv1(x[step]), 2)
87 | if self.batch_norm:
88 | cur1 = self.conv1_bn(cur1)
89 | spk1, mem1 = self.lif1(cur1, mem1)
90 | cur2 = F.avg_pool2d(self.conv2(spk1), 2)
91 | if self.batch_norm:
92 | cur2 = self.conv2_bn(cur2)
93 | spk2, mem2 = self.lif2(cur2, mem2)
94 |
95 | cur3 = self.dropout(self.fc1(spk2.flatten(1)))
96 | spk3, mem3 = self.lif3(cur3, mem3)
97 |
98 |
99 | spk3_rec.append(spk3)
100 | mem3_rec.append(mem3)
101 |
102 |
103 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
104 |
--------------------------------------------------------------------------------
/dvs/bnn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from functions import *
8 |
9 |
10 | class BinaryTanh(nn.Module):
11 | def __init__(self):
12 | super(BinaryTanh, self).__init__()
13 | self.hardtanh = nn.Hardtanh()
14 |
15 | def forward(self, input):
16 | output = self.hardtanh(input)
17 | output = binarize(output)
18 | return output
19 |
20 |
21 | class BinaryLinear(nn.Linear):
22 |
23 | def forward(self, input):
24 | binary_weight = binarize(self.weight)
25 | if self.bias is None:
26 | return F.linear(input, binary_weight)
27 | else:
28 | return F.linear(input, binary_weight, self.bias)
29 |
30 | def reset_parameters(self):
31 | # Glorot initialization
32 | in_features, out_features = self.weight.size()
33 | stdv = math.sqrt(1.5 / (in_features + out_features))
34 | self.weight.data.uniform_(-stdv, stdv)
35 | if self.bias is not None:
36 | self.bias.data.zero_()
37 |
38 | self.weight.lr_scale = 1. / stdv
39 |
40 |
41 |
42 | class BinaryConv2d(nn.Conv2d):
43 |
44 | def forward(self, input):
45 | bw = binarize(self.weight)
46 | return F.conv2d(input, bw, self.bias, self.stride,
47 | self.padding, self.dilation, self.groups)
48 |
49 | def reset_parameters(self):
50 | # Glorot initialization
51 | in_features = self.in_channels
52 | out_features = self.out_channels
53 | for k in self.kernel_size:
54 | in_features *= k
55 | out_features *= k
56 | stdv = math.sqrt(1.5 / (in_features + out_features))
57 | self.weight.data.uniform_(-stdv, stdv)
58 | if self.bias is not None:
59 | self.bias.data.zero_()
60 |
61 | self.weight.lr_scale = 1. / stdv
--------------------------------------------------------------------------------
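The following is a quick, hedged illustration (not part of the repository) of how these layers behave: binarization happens only inside the forward pass, while the stored weights remain full precision so that small gradient updates can accumulate across iterations.

```
# assumes this is run from one of the dataset directories so that bnn.py is importable
import torch
from bnn import BinaryLinear

layer = BinaryLinear(4, 2, bias=False)
x = torch.randn(1, 4)
out = layer(x)                 # forward pass uses sign(weight), i.e. values in {-1, +1}
print(layer.weight)            # the latent weights themselves stay full precision
out.sum().backward()           # the straight-through estimator routes gradients to the latent weights
print(layer.weight.grad.shape) # torch.Size([2, 4])
```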
/dvs/conf.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | config = {
4 | 'exp_name' : 'dvs_tha',
5 | 'num_trials' : 5,
6 | 'num_epochs' : 500,
7 | 'binarize' : True,
8 | 'data_dir' : "/home/dvs",
9 | 'batch_size' : 8,
10 | 'seed' : 0,
11 | 'num_workers' : 0,
12 |
13 | # final run sweeps
14 | 'save_csv' : True,
15 | 'save_model' : True,
16 | 'early_stopping': True,
17 | 'patience': 100,
18 |
19 | # final params
20 | 'grad_clip' : True,
21 | 'weight_clip' : False,
22 | 'batch_norm' : False,
23 | 'dropout1' : 0.43,
24 | 'beta' : 0.9297,
25 | 'lr' : 1.765e-3,
26 | 'slope': 0.24,
27 |
28 |     # threshold annealing. note: run.py updates thr_final += threshold, so the values below are increments above the initial thresholds
29 | 'threshold1' : 10.4,
30 | 'alpha_thr1' : 0.00333,
31 | 'thr_final1' : 1.7565,
32 |
33 | 'threshold2' : 16.62,
34 | 'alpha_thr2' : 0.0061,
35 | 'thr_final2' : 2.457,
36 |
37 | 'threshold3' : 6.81,
38 | 'alpha_thr3' : 0.173,
39 | 'thr_final3' : 9.655,
40 |
41 | # fixed params
42 | 'num_steps' : 100,
43 | 'correct_rate': 0.8,
44 | 'incorrect_rate' : 0.2,
45 | 'betas' : (0.9, 0.999),
46 | 't_0' : 735,
47 | 'eta_min' : 0,
48 | 'df_lr' : True, # return learning rate. Useful for scheduling
49 |
50 |
51 | }
52 |
53 | def optim_func(net, config):
54 | optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config['betas'])
55 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['t_0'], eta_min=config['eta_min'], last_epoch=-1)
56 | return optimizer, scheduler
--------------------------------------------------------------------------------
/dvs/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import DataLoader
4 | from torchvision import datasets, transforms
5 |
6 | from snntorch.spikevision import spikedata
7 |
8 | def load_data(config):
9 | data_dir = config['data_dir']
10 |
11 | trainset = spikedata.DVSGesture(data_dir, train=True, num_steps=100, dt=5000, ds=4)
12 | testset = spikedata.DVSGesture(data_dir, train=False, num_steps=360, dt=5000, ds=4)
13 |
14 | return trainset, testset
--------------------------------------------------------------------------------
/dvs/earlystopping.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EarlyStopping_acc:
5 | """Early stops the training if test acc doesn't improve after a given patience."""
6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
7 | """
8 | Args:
9 |             patience (int): How long to wait after the last test accuracy improvement.
10 |                             Default: 7
11 |             verbose (bool): If True, prints a message for each test accuracy improvement.
12 | Default: False
13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
14 | Default: 0
15 | path (str): Path for the checkpoint to be saved to.
16 | Default: 'checkpoint.pt'
17 | trace_func (function): trace print function.
18 | Default: print
19 | """
20 | self.patience = patience
21 | self.verbose = verbose
22 | self.counter = 0
23 | self.best_score = None
24 | self.early_stop = False
25 | self.test_loss_min = 0
26 | self.delta = delta
27 | self.path = path
28 | self.trace_func = trace_func
29 | def __call__(self, test_loss, model):
30 |
31 | score = test_loss
32 |
33 | if self.best_score is None:
34 | self.best_score = score
35 | self.save_checkpoint(test_loss, model)
36 | elif score <= self.best_score + self.delta:
37 | self.counter += 1
38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
39 | if self.counter >= self.patience:
40 | self.early_stop = True
41 | self.counter = 0
42 | else:
43 | self.best_score = score
44 | self.save_checkpoint(test_loss, model)
45 | self.counter = 0
46 |
47 | def save_checkpoint(self, test_loss, model):
48 | '''Saves model when test acc increases.'''
49 | if self.verbose:
50 | self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). Saving model ...')
51 | torch.save(model.state_dict(), self.path)
52 | self.test_loss_min = test_loss
--------------------------------------------------------------------------------
/dvs/functions.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 |
6 |
7 | class BinarizeF(Function):
8 |
9 | @staticmethod
10 | def forward(ctx, input):
11 | output = input.new(input.size())
12 | output[input >= 0] = 1
13 | output[input < 0] = -1
14 | return output
15 |
16 | @staticmethod
17 | def backward(ctx, grad_output):
18 | grad_input = grad_output.clone()
19 | return grad_input
20 |
21 | # aliases
22 | binarize = BinarizeF.apply
--------------------------------------------------------------------------------
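A small sanity check of the straight-through estimator defined above (illustrative only): the forward pass maps every element to ±1 (with 0 mapped to +1), while the backward pass lets gradients flow through unchanged.

```
import torch
from functions import binarize  # assumes this is run from one of the dataset directories

w = torch.tensor([0.3, -0.7, 0.0], requires_grad=True)
bw = binarize(w)        # tensor([ 1., -1.,  1.])
bw.sum().backward()     # identity gradient in the backward pass
print(w.grad)           # tensor([1., 1., 1.])
```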
/dvs/run.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 |
6 | # torch
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 |
12 | # misc
13 | import numpy as np
14 | import pandas as pd
15 | import time
16 | import logging
17 |
18 | # local imports
19 | from dataloader import *
20 | from Net import *
21 | from test_acc import *
22 | from train import *
23 | from earlystopping import *
24 | from conf import *
25 |
26 | ####################################################
27 | ## Notes: modify config in conf to reparameterize ##
28 | ####################################################
29 |
30 | file_name = config['exp_name']
31 |
32 | ### to address conditional parameters, s.t. thr_final > threshold
33 | config['thr_final1'] = config['thr_final1'] + config['threshold1']
34 | config['thr_final2'] = config['thr_final2'] + config['threshold2']
35 | config['thr_final3'] = config['thr_final3'] + config['threshold3']
36 |
37 | threshold1 = config['threshold1']
38 | threshold2 = config['threshold2']
39 | threshold3 = config['threshold3']
40 |
41 | for trial in range(config['num_trials']):
42 |
43 | # file names
44 | SAVE_CSV = config['save_csv']
45 | SAVE_MODEL = config['save_model']
46 | csv_name = file_name + '_t' + str(trial) + '.csv'
47 | log_name = file_name + '_t' + str(trial) + '.log'
48 | model_name = file_name + '_t' + str(trial) + '.pt'
49 | num_epochs = config['num_epochs']
50 | torch.manual_seed(config['seed'])
51 |
52 | config['threshold1'] = threshold1
53 | config['threshold2'] = threshold2
54 | config['threshold3'] = threshold3
55 |
56 | # dataframes
57 | df_train_loss = pd.DataFrame()
58 | df_test_acc = pd.DataFrame(columns=['epoch', 'test_acc', 'train_time'])
59 | df_lr = pd.DataFrame()
60 |
61 | # initialize network
62 | net = Net(config)
63 | device = "cpu"
64 | if torch.cuda.is_available():
65 | device = "cuda:0"
66 | if torch.cuda.device_count() > 1:
67 | net = nn.DataParallel(net)
68 | net.to(device)
69 |
70 | # net params
71 | criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
72 | optimizer, scheduler = optim_func(net, config)
73 |
74 | # early stopping condition
75 | if config['early_stopping']:
76 | early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)
77 | early_stopping.early_stop = False
78 | early_stopping.best_score = None
79 |
80 | # load data
81 | trainset, testset = load_data(config)
82 | config['dataset_length'] = len(trainset)
83 | trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True)
84 | testloader = DataLoader(testset, batch_size=int(config["batch_size"]), shuffle=False)
85 |
86 | print(f"=======Trial: {trial}=======")
87 |
88 | for epoch in range(num_epochs):
89 |
90 | # train
91 | start_time = time.time()
92 | loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device)
93 | epoch_time = time.time() - start_time
94 |
95 | # test
96 | test_acc = test_accuracy(config, net, testloader, device)
97 | print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}')
98 |
99 | if config['df_lr']:
100 | df_lr = pd.concat([df_lr, pd.DataFrame(lr_list)])
101 | df_train_loss = pd.concat([df_train_loss, pd.DataFrame(loss_list)])
102 | test_data = pd.DataFrame([[epoch, test_acc, epoch_time]], columns = ['epoch', 'test_acc', 'train_time'])
103 | df_test_acc = pd.concat([df_test_acc, test_data])
104 |
105 | if SAVE_CSV:
106 | df_train_loss.to_csv('loss_' + csv_name, index=False)
107 | df_test_acc.to_csv('acc_' + csv_name, index=False)
108 | if config['df_lr']:
109 | df_lr.to_csv('lr_' + csv_name, index=False)
110 |
111 | if config['early_stopping']:
112 | early_stopping(test_acc, net)
113 |
114 | if early_stopping.early_stop:
115 | print("Early stopping")
116 | early_stopping.early_stop = False
117 | early_stopping.best_score = None
118 | break
119 |
120 | if SAVE_MODEL and not config['early_stopping']:
121 | torch.save(net.state_dict(), model_name)
122 |
123 | # net.load_state_dict(torch.load(model_name))
124 |
--------------------------------------------------------------------------------
/dvs/test_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import snntorch as snn
3 | from snntorch import functional as SF
4 |
5 |
6 | def test_accuracy(config, net, testloader, device="cpu"):
7 |
8 |
9 | correct = 0
10 | total = 0
11 | with torch.no_grad():
12 | net.eval()
13 | for data in testloader:
14 | images, labels = data
15 | images, labels = images.to(device), labels.to(device) # .permute(1, 0, 2, 3, 4)
16 |
17 | outputs, _ = net(images.permute(1, 0, 2, 3, 4))
18 | accuracy = SF.accuracy_rate(outputs, labels)
19 |
20 | total += labels.size(0)
21 | correct += accuracy * labels.size(0)
22 |
23 | return 100 * correct / total
--------------------------------------------------------------------------------
/dvs/tha.py:
--------------------------------------------------------------------------------
1 | # exp relaxation implementation of THA based on Eq (4)
2 |
3 | def thr_annealing(config, network):
4 | alpha_thr1 = config['alpha_thr1']
5 | alpha_thr2 = config['alpha_thr2']
6 | alpha_thr3 = config['alpha_thr3']
7 |
8 | thr_final1 = config['thr_final1']
9 | thr_final2 = config['thr_final2']
10 | thr_final3 = config['thr_final3']
11 |
12 | network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1
13 | network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2
14 | network.lif3.threshold += (thr_final3 - network.lif3.threshold) * alpha_thr3
15 |
16 | return
--------------------------------------------------------------------------------
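Each call therefore relaxes every layer's threshold exponentially toward its `thr_final` value. A small, self-contained check of that behaviour with illustrative numbers (not the configured values):

```
# after k updates: thr_k = thr_final - (thr_final - thr_0) * (1 - alpha)**k
thr_0, thr_final, alpha = 1.0, 10.0, 0.01
thr = thr_0
for k in range(100):
    thr += (thr_final - thr) * alpha
print(thr)                                                   # ~6.71
print(thr_final - (thr_final - thr_0) * (1 - alpha) ** 100)  # same value, from the closed form
```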
/dvs/train.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 | from snntorch import functional as SF
6 |
7 | # torch
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | from torchvision import datasets, transforms
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import StepLR
14 |
15 | # misc
16 | import os
17 | import numpy as np
18 | import math
19 | import itertools
20 | import matplotlib.pyplot as plt
21 | import pandas as pd
22 | import shutil
23 | import time
24 |
25 | from dataloader import *
26 | # from test import *   # unused: there is no test.py in this directory
27 | from test_acc import *
28 | from tha import *
29 |
30 | def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device):
31 |
32 | net.train()
33 | loss_accum = []
34 | lr_accum = []
35 |
36 | # TRAIN
37 | for data, labels in trainloader:
38 | data, labels = data.to(device), labels.to(device)
39 | spk_rec2, _ = net(data.permute(1, 0, 2, 3, 4))
40 | loss = criterion(spk_rec2, labels.long())
41 | optimizer.zero_grad()
42 | loss.backward()
43 |
44 | if config['grad_clip']:
45 | nn.utils.clip_grad_norm_(net.parameters(), 1.0)
46 | if config['weight_clip']:
47 | with torch.no_grad():
48 | for param in net.parameters():
49 | param.clamp_(-1, 1)
50 |
51 | optimizer.step()
52 | scheduler.step()
53 | thr_annealing(config, net)
54 |
55 |
56 | loss_accum.append(loss.item()/config['num_steps'])
57 | lr_accum.append(optimizer.param_groups[0]["lr"])
58 |
59 |
60 | return loss_accum, lr_accum
61 |
--------------------------------------------------------------------------------
/figs/temporal_code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jeshraghian/snn-tha/f9c0b516a67a4be508b908176992b30894a18af9/figs/temporal_code.png
--------------------------------------------------------------------------------
/fmnist/Net.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import surrogate
4 |
5 | # torch
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # local
11 | from bnn import *
12 |
13 | class Net(nn.Module):
14 | def __init__(self, config):
15 | super().__init__()
16 |
17 | self.thr1 = config['threshold1']
18 | self.thr2 = config['threshold2']
19 | self.thr3 = config['threshold3']
20 | slope = config['slope']
21 | beta = config['beta']
22 | self.num_steps = config['num_steps']
23 | self.batch_norm = config['batch_norm']
24 | p1 = config['dropout1']
25 | self.binarize = config['binarize']
26 |
27 | spike_grad = surrogate.fast_sigmoid(slope)
28 | self.bconv1 = BinaryConv2d(1, 16, 5, bias=False)
29 | self.conv1 = nn.Conv2d(1, 16, 5, bias=False)
30 | self.conv1_bn = nn.BatchNorm2d(16)
31 | self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)
32 | self.bconv2 = BinaryConv2d(16, 64, 5, bias=False)
33 | self.conv2 = nn.Conv2d(16, 64, 5, bias=False)
34 | self.conv2_bn = nn.BatchNorm2d(64)
35 | self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)
36 | self.bfc1 = BinaryLinear(64 * 4 * 4, 10)
37 | self.fc1 = nn.Linear(64 * 4 * 4, 10)
38 | self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad)
39 | self.dropout = nn.Dropout(p1)
40 |
41 | def forward(self, x):
42 |
43 | # Initialize hidden states and outputs at t=0
44 | mem1 = self.lif1.init_leaky()
45 | mem2 = self.lif2.init_leaky()
46 | mem3 = self.lif3.init_leaky()
47 |
48 | # Record the final layer
49 | spk3_rec = []
50 | mem3_rec = []
51 |
52 | # Binarization
53 | if self.binarize:
54 |
55 | for step in range(self.num_steps):
56 | cur1 = F.avg_pool2d(self.bconv1(x), 2)
57 | if self.batch_norm:
58 | cur1 = self.conv1_bn(cur1)
59 | spk1, mem1 = self.lif1(cur1, mem1)
60 | cur2 = F.avg_pool2d(self.bconv2(spk1), 2)
61 | if self.batch_norm:
62 | cur2 = self.conv2_bn(cur2)
63 | spk2, mem2 = self.lif2(cur2, mem2)
64 | cur3 = self.dropout(self.bfc1(spk2.flatten(1)))
65 | spk3, mem3 = self.lif3(cur3, mem3)
66 |
67 | spk3_rec.append(spk3)
68 | mem3_rec.append(mem3)
69 |
70 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
71 |
72 | # Full Precision
73 | else:
74 |
75 | for step in range(self.num_steps):
76 |
77 | cur1 = F.avg_pool2d(self.conv1(x), 2)
78 | if self.batch_norm:
79 | cur1 = self.conv1_bn(cur1)
80 | spk1, mem1 = self.lif1(cur1, mem1)
81 | cur2 = F.avg_pool2d(self.conv2(spk1), 2)
82 | if self.batch_norm:
83 | cur2 = self.conv2_bn(cur2)
84 | spk2, mem2 = self.lif2(cur2, mem2)
85 | cur3 = self.dropout(self.fc1(spk2.flatten(1)))
86 | spk3, mem3 = self.lif3(cur3, mem3)
87 |
88 | spk3_rec.append(spk3)
89 | mem3_rec.append(mem3)
90 |
91 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
92 |
--------------------------------------------------------------------------------
/fmnist/bnn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from functions import *
8 |
9 |
10 | class BinaryTanh(nn.Module):
11 | def __init__(self):
12 | super(BinaryTanh, self).__init__()
13 | self.hardtanh = nn.Hardtanh()
14 |
15 | def forward(self, input):
16 | output = self.hardtanh(input)
17 | output = binarize(output)
18 | return output
19 |
20 |
21 | class BinaryLinear(nn.Linear):
22 |
23 | def forward(self, input):
24 | binary_weight = binarize(self.weight)
25 | if self.bias is None:
26 | return F.linear(input, binary_weight)
27 | else:
28 | return F.linear(input, binary_weight, self.bias)
29 |
30 | def reset_parameters(self):
31 | # Glorot initialization
32 | in_features, out_features = self.weight.size()
33 | stdv = math.sqrt(1.5 / (in_features + out_features))
34 | self.weight.data.uniform_(-stdv, stdv)
35 | if self.bias is not None:
36 | self.bias.data.zero_()
37 |
38 | self.weight.lr_scale = 1. / stdv
39 |
40 |
41 |
42 | class BinaryConv2d(nn.Conv2d):
43 |
44 | def forward(self, input):
45 | bw = binarize(self.weight)
46 | return F.conv2d(input, bw, self.bias, self.stride,
47 | self.padding, self.dilation, self.groups)
48 |
49 | def reset_parameters(self):
50 | # Glorot initialization
51 | in_features = self.in_channels
52 | out_features = self.out_channels
53 | for k in self.kernel_size:
54 | in_features *= k
55 | out_features *= k
56 | stdv = math.sqrt(1.5 / (in_features + out_features))
57 | self.weight.data.uniform_(-stdv, stdv)
58 | if self.bias is not None:
59 | self.bias.data.zero_()
60 |
61 | self.weight.lr_scale = 1. / stdv
--------------------------------------------------------------------------------
/fmnist/conf.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | config = {
4 | 'exp_name' : 'fmnist_tha',
5 | 'num_trials' : 5,
6 | 'num_epochs' : 500,
7 | 'binarize' : True,
8 | 'data_dir' : "~/data/fmnist",
9 | 'batch_size' : 128,
10 | 'seed' : 0,
11 | 'num_workers' : 0,
12 |
13 | # final run sweeps
14 | 'save_csv' : True,
15 | 'save_model' : True,
16 | 'early_stopping': True,
17 | 'patience': 100,
18 |
19 | # final params
20 | 'grad_clip' : False,
21 | 'weight_clip' : False,
22 | 'batch_norm' : True,
23 | 'dropout1' : 0.648,
24 | 'beta' : 0.868,
25 | 'lr' : 8.4e-4,
26 | 'slope': 0.1557,
27 | 'momentum' : 0.855,
28 |
29 |
30 |     # threshold annealing. note: run.py updates thr_final += threshold, so the values below are increments above the initial thresholds
31 | 'threshold1' : 6.9,
32 | 'alpha_thr1' : 0.0368,
33 | 'thr_final1' : 7.1456,
34 |
35 | 'threshold2' : 10.25,
36 | 'alpha_thr2' : 0.29687,
37 | 'thr_final2' : 12.826,
38 |
39 | 'threshold3' : 17.95,
40 | 'alpha_thr3' : 0.1048,
41 | 'thr_final3' : 9.936668,
42 |
43 | # fixed params
44 | 'num_steps' : 100,
45 | 'correct_rate': 0.8,
46 | 'incorrect_rate' : 0.2,
47 | 't_0' : 4688,
48 | 'eta_min' : 0,
49 | 'df_lr' : True, # return learning rate. Useful for scheduling
50 |
51 |
52 |
53 | }
54 |
55 | def optim_func(net, config):
56 | optimizer = torch.optim.SGD(net.parameters(), lr=config["lr"], momentum=config['momentum'])
57 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['t_0'], eta_min=config['eta_min'], last_epoch=-1)
58 | return optimizer, scheduler
59 |
--------------------------------------------------------------------------------
/fmnist/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import DataLoader
4 | from torchvision import datasets, transforms
5 |
6 | def load_data(config):
7 | data_dir = config['data_dir']
8 |
9 | transform = transforms.Compose([
10 | transforms.Resize((28, 28)),
11 | transforms.Grayscale(),
12 | transforms.ToTensor(),
13 | transforms.Normalize((0,), (1,))])
14 |
15 | trainset = datasets.FashionMNIST(data_dir, train=True, download=True, transform=transform)
16 | testset = datasets.FashionMNIST(data_dir, train=False, download=True, transform=transform)
17 |
18 | return trainset, testset
--------------------------------------------------------------------------------
/fmnist/earlystopping.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EarlyStopping_acc:
5 | """Early stops the training if test acc doesn't improve after a given patience."""
6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
7 | """
8 | Args:
9 |             patience (int): How long to wait after the last test accuracy improvement.
10 |                             Default: 7
11 |             verbose (bool): If True, prints a message for each test accuracy improvement.
12 | Default: False
13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
14 | Default: 0
15 | path (str): Path for the checkpoint to be saved to.
16 | Default: 'checkpoint.pt'
17 | trace_func (function): trace print function.
18 | Default: print
19 | """
20 | self.patience = patience
21 | self.verbose = verbose
22 | self.counter = 0
23 | self.best_score = None
24 | self.early_stop = False
25 | self.test_loss_min = 0
26 | self.delta = delta
27 | self.path = path
28 | self.trace_func = trace_func
29 | def __call__(self, test_loss, model):
30 |
31 | score = test_loss
32 |
33 | if self.best_score is None:
34 | self.best_score = score
35 | self.save_checkpoint(test_loss, model)
36 | elif score <= self.best_score + self.delta:
37 | self.counter += 1
38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
39 | if self.counter >= self.patience:
40 | self.early_stop = True
41 | self.counter = 0
42 | else:
43 | self.best_score = score
44 | self.save_checkpoint(test_loss, model)
45 | self.counter = 0
46 |
47 | def save_checkpoint(self, test_loss, model):
48 | '''Saves model when test acc increases.'''
49 | if self.verbose:
50 | self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). Saving model ...')
51 | torch.save(model.state_dict(), self.path)
52 | self.test_loss_min = test_loss
--------------------------------------------------------------------------------
/fmnist/functions.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 |
6 |
7 | class BinarizeF(Function):
8 |
9 | @staticmethod
10 | def forward(ctx, input):
11 | output = input.new(input.size())
12 | output[input >= 0] = 1
13 | output[input < 0] = -1
14 | return output
15 |
16 | @staticmethod
17 | def backward(ctx, grad_output):
18 | grad_input = grad_output.clone()
19 | return grad_input
20 |
21 | # aliases
22 | binarize = BinarizeF.apply
--------------------------------------------------------------------------------
/fmnist/run.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 |
6 | # torch
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 |
12 | # misc
13 | import numpy as np
14 | import pandas as pd
15 | import time
16 | import logging
17 |
18 | # local imports
19 | from dataloader import *
20 | from Net import *
21 | from test_acc import *
22 | from train import *
23 | from earlystopping import *
24 | from conf import *
25 |
26 | ####################################################
27 | ## Notes: modify config in conf to reparameterize ##
28 | ####################################################
29 |
30 |
31 | file_name = config['exp_name']
32 |
33 | ### to address conditional parameters, s.t. thr_final > threshold
34 | config['thr_final1'] = config['thr_final1'] + config['threshold1']
35 | config['thr_final2'] = config['thr_final2'] + config['threshold2']
36 | config['thr_final3'] = config['thr_final3'] + config['threshold3']
37 |
38 | threshold1 = config['threshold1']
39 | threshold2 = config['threshold2']
40 | threshold3 = config['threshold3']
41 |
42 | for trial in range(config['num_trials']):
43 |
44 | # file names
45 | SAVE_CSV = config['save_csv']
46 | SAVE_MODEL = config['save_model']
47 | csv_name = file_name + '_t' + str(trial) + '.csv'
48 | log_name = file_name + '_t' + str(trial) + '.log'
49 | model_name = file_name + '_t' + str(trial) + '.pt'
50 | num_epochs = config['num_epochs']
51 | torch.manual_seed(config['seed'])
52 |
53 | config['threshold1'] = threshold1
54 | config['threshold2'] = threshold2
55 | config['threshold3'] = threshold3
56 |
57 |
58 | # dataframes
59 | df_train_loss = pd.DataFrame()
60 | df_test_acc = pd.DataFrame(columns=['epoch', 'test_acc', 'train_time'])
61 | df_lr = pd.DataFrame()
62 |
63 |
64 | # initialize network
65 | net = Net(config)
66 | device = "cpu"
67 | if torch.cuda.is_available():
68 | device = "cuda:0"
69 | if torch.cuda.device_count() > 1:
70 | net = nn.DataParallel(net)
71 | net.to(device)
72 |
73 | # net params
74 | criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
75 | optimizer, scheduler = optim_func(net, config)
76 |
77 | # early stopping condition
78 | if config['early_stopping']:
79 | early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)
80 | early_stopping.early_stop = False
81 | early_stopping.best_score = None
82 |
83 | # load data
84 | trainset, testset = load_data(config)
85 | config['dataset_length'] = len(trainset)
86 | trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True)
87 | testloader = DataLoader(testset, batch_size=int(config["batch_size"]), shuffle=False)
88 |
89 | print(f"=======Trial: {trial}=======")
90 |
91 | for epoch in range(num_epochs):
92 |
93 | # train
94 | start_time = time.time()
95 | loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device)
96 | epoch_time = time.time() - start_time
97 |
98 | # test
99 | test_acc = test_accuracy(config, net, testloader, device)
100 | print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}')
101 |
102 | if config['df_lr']:
103 | df_lr = pd.concat([df_lr, pd.DataFrame(lr_list)])
104 | df_train_loss = pd.concat([df_train_loss, pd.DataFrame(loss_list)])
105 | test_data = pd.DataFrame([[epoch, test_acc, epoch_time]], columns = ['epoch', 'test_acc', 'train_time'])
106 | df_test_acc = pd.concat([df_test_acc, test_data])
107 |
108 | if SAVE_CSV:
109 | df_train_loss.to_csv('loss_' + csv_name, index=False)
110 | df_test_acc.to_csv('acc_' + csv_name, index=False)
111 | if config['df_lr']:
112 | df_lr.to_csv('lr_' + csv_name, index=False)
113 |
114 | if config['early_stopping']:
115 | early_stopping(test_acc, net)
116 |
117 | if early_stopping.early_stop:
118 | print("Early stopping")
119 | early_stopping.early_stop = False
120 | early_stopping.best_score = None
121 | break
122 |
123 | if SAVE_MODEL and not config['early_stopping']:
124 | torch.save(net.state_dict(), model_name)
125 |
126 | # net.load_state_dict(torch.load(model_name))
127 |
--------------------------------------------------------------------------------
/fmnist/test_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import snntorch as snn
3 | from snntorch import functional as SF
4 |
5 |
6 | def test_accuracy(config, net, testloader, device="cpu"):
7 |
8 |
9 | correct = 0
10 | total = 0
11 | with torch.no_grad():
12 | net.eval()
13 | for data in testloader:
14 | images, labels = data
15 | images, labels = images.to(device), labels.to(device)
16 |
17 | outputs, _ = net(images)
18 | accuracy = SF.accuracy_rate(outputs, labels)
19 |
20 | total += labels.size(0)
21 | correct += accuracy * labels.size(0)
22 |
23 | return 100 * correct / total
--------------------------------------------------------------------------------
/fmnist/tha.py:
--------------------------------------------------------------------------------
1 | # exp relaxation implementation of THA based on Eq (4)
2 |
3 | def thr_annealing(config, network):
4 | alpha_thr1 = config['alpha_thr1']
5 | alpha_thr2 = config['alpha_thr2']
6 | alpha_thr3 = config['alpha_thr3']
7 |
8 | thr_final1 = config['thr_final1']
9 | thr_final2 = config['thr_final2']
10 | thr_final3 = config['thr_final3']
11 |
12 | network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1
13 | network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2
14 | network.lif3.threshold += (thr_final3 - network.lif3.threshold) * alpha_thr3
15 |
16 | return
--------------------------------------------------------------------------------
/fmnist/train.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 | from snntorch import functional as SF
6 |
7 | # torch
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | from torchvision import datasets, transforms
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import StepLR
14 |
15 | # misc
16 | import os
17 | import numpy as np
18 | import math
19 | import itertools
20 | import matplotlib.pyplot as plt
21 | import pandas as pd
22 | import shutil
23 | import time
24 |
25 | from dataloader import *
26 | # from test import *   # unused: there is no test.py in this directory
27 | from test_acc import *
28 | from tha import *
29 |
30 |
31 | def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device):
32 |
33 | net.train()
34 | loss_accum = []
35 | lr_accum = []
36 |
37 | # TRAIN
38 | for data, labels in trainloader:
39 | data, labels = data.to(device), labels.to(device)
40 |
41 | spk_rec2, _ = net(data)
42 | loss = criterion(spk_rec2, labels)
43 | optimizer.zero_grad()
44 | loss.backward()
45 |
46 | if config['grad_clip']:
47 | nn.utils.clip_grad_norm_(net.parameters(), 1.0)
48 | if config['weight_clip']:
49 | with torch.no_grad():
50 | for param in net.parameters():
51 | param.clamp_(-1, 1)
52 |
53 | optimizer.step()
54 | scheduler.step()
55 | thr_annealing(config, net)
56 |
57 |
58 | loss_accum.append(loss.item()/config['num_steps'])
59 | lr_accum.append(optimizer.param_groups[0]["lr"])
60 |
61 |
62 | return loss_accum, lr_accum
63 |
--------------------------------------------------------------------------------
/mnist/Net.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import surrogate
4 |
5 | # torch
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # local
11 | from bnn import *
12 |
13 | class Net(nn.Module):
14 | def __init__(self, config):
15 | super().__init__()
16 |
17 | self.thr1 = config['threshold1']
18 | self.thr2 = config['threshold2']
19 | self.thr3 = config['threshold3']
20 | slope = config['slope']
21 | beta = config['beta']
22 | self.num_steps = config['num_steps']
23 | self.batch_norm = config['batch_norm']
24 | p1 = config['dropout1']
25 | self.binarize = config['binarize']
26 |
27 | spike_grad = surrogate.fast_sigmoid(slope)
28 | # Initialize layers with spike operator
29 | self.bconv1 = BinaryConv2d(1, 16, 5, bias=False)
30 | self.conv1 = nn.Conv2d(1, 16, 5, bias=False)
31 | self.conv1_bn = nn.BatchNorm2d(16)
32 | self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)
33 | self.bconv2 = BinaryConv2d(16, 64, 5, bias=False)
34 | self.conv2 = nn.Conv2d(16, 64, 5, bias=False)
35 | self.conv2_bn = nn.BatchNorm2d(64)
36 | self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)
37 | self.bfc1 = BinaryLinear(64 * 4 * 4, 10)
38 | self.fc1 = nn.Linear(64 * 4 * 4, 10)
39 | self.lif3 = snn.Leaky(beta, threshold=self.thr3, spike_grad=spike_grad)
40 | self.dropout = nn.Dropout(p1)
41 |
42 | def forward(self, x):
43 |
44 | # Initialize hidden states and outputs at t=0
45 | mem1 = self.lif1.init_leaky()
46 | mem2 = self.lif2.init_leaky()
47 | mem3 = self.lif3.init_leaky()
48 |
49 | # Record the final layer
50 | spk3_rec = []
51 | mem3_rec = []
52 |
53 | # Binarized
54 | if self.binarize:
55 |
56 | for step in range(self.num_steps):
57 |
58 | cur1 = F.avg_pool2d(self.bconv1(x), 2)
59 | if self.batch_norm:
60 | cur1 = self.conv1_bn(cur1)
61 | spk1, mem1 = self.lif1(cur1, mem1)
62 | cur2 = F.avg_pool2d(self.bconv2(spk1), 2)
63 | if self.batch_norm:
64 | cur2 = self.conv2_bn(cur2)
65 | spk2, mem2 = self.lif2(cur2, mem2)
66 | cur3 = self.dropout(self.bfc1(spk2.flatten(1)))
67 | spk3, mem3 = self.lif3(cur3, mem3)
68 |
69 | spk3_rec.append(spk3)
70 | mem3_rec.append(mem3)
71 |
72 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
73 |
74 | # Full Precision
75 | else:
76 |
77 | for step in range(self.num_steps):
78 |
79 | cur1 = F.avg_pool2d(self.conv1(x), 2)
80 | if self.batch_norm:
81 | cur1 = self.conv1_bn(cur1)
82 | spk1, mem1 = self.lif1(cur1, mem1)
83 | cur2 = F.avg_pool2d(self.conv2(spk1), 2)
84 | if self.batch_norm:
85 | cur2 = self.conv2_bn(cur2)
86 | spk2, mem2 = self.lif2(cur2, mem2)
87 | cur3 = self.dropout(self.fc1(spk2.flatten(1)))
88 | spk3, mem3 = self.lif3(cur3, mem3)
89 |
90 | spk3_rec.append(spk3)
91 | mem3_rec.append(mem3)
92 |
93 | return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)
94 |
--------------------------------------------------------------------------------
/mnist/bnn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from functions import *
8 |
9 |
10 | class BinaryTanh(nn.Module):
11 | def __init__(self):
12 | super(BinaryTanh, self).__init__()
13 | self.hardtanh = nn.Hardtanh()
14 |
15 | def forward(self, input):
16 | output = self.hardtanh(input)
17 | output = binarize(output)
18 | return output
19 |
20 |
21 | class BinaryLinear(nn.Linear):
22 |
23 | def forward(self, input):
24 | binary_weight = binarize(self.weight)
25 | if self.bias is None:
26 | return F.linear(input, binary_weight)
27 | else:
28 | return F.linear(input, binary_weight, self.bias)
29 |
30 | def reset_parameters(self):
31 | # Glorot initialization
32 | in_features, out_features = self.weight.size()
33 | stdv = math.sqrt(1.5 / (in_features + out_features))
34 | self.weight.data.uniform_(-stdv, stdv)
35 | if self.bias is not None:
36 | self.bias.data.zero_()
37 |
38 | self.weight.lr_scale = 1. / stdv
39 |
40 |
41 |
42 | class BinaryConv2d(nn.Conv2d):
43 |
44 | def forward(self, input):
45 | bw = binarize(self.weight)
46 | return F.conv2d(input, bw, self.bias, self.stride,
47 | self.padding, self.dilation, self.groups)
48 |
49 | def reset_parameters(self):
50 | # Glorot initialization
51 | in_features = self.in_channels
52 | out_features = self.out_channels
53 | for k in self.kernel_size:
54 | in_features *= k
55 | out_features *= k
56 | stdv = math.sqrt(1.5 / (in_features + out_features))
57 | self.weight.data.uniform_(-stdv, stdv)
58 | if self.bias is not None:
59 | self.bias.data.zero_()
60 |
61 | self.weight.lr_scale = 1. / stdv
--------------------------------------------------------------------------------
/mnist/conf.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | config = {
4 | 'exp_name' : 'mnist_tha',
5 | 'num_trials' : 5,
6 | 'num_epochs' : 500,
7 | 'binarize' : True,
8 | 'data_dir' : "~/data/mnist",
9 | 'batch_size' : 128,
10 | 'seed' : 0,
11 | 'num_workers' : 0,
12 |
13 | # final run sweeps
14 | 'save_csv' : True,
15 | 'save_model' : True,
16 | 'early_stopping': True,
17 | 'patience': 100,
18 |
19 | # final params
20 | 'grad_clip' : False,
21 | 'weight_clip' : False,
22 | 'batch_norm' : True,
23 | 'dropout1' : 0.02856,
24 | 'beta' : 0.99,
25 | 'lr' : 9.97e-3,
26 | 'slope': 10.22,
27 |
28 |     # threshold annealing. note: run.py updates thr_final += threshold, so the values below are increments above the initial thresholds
29 | 'threshold1' : 11.666,
30 | 'alpha_thr1' : 0.024,
31 | 'thr_final1' : 4.317,
32 |
33 | 'threshold2' : 14.105,
34 | 'alpha_thr2' : 0.119,
35 | 'thr_final2' : 16.29,
36 |
37 | 'threshold3' : 0.6656,
38 | 'alpha_thr3' : 0.0011,
39 | 'thr_final3' : 3.496,
40 |
41 | # fixed params
42 | 'num_steps' : 100,
43 | 'correct_rate': 0.8,
44 | 'incorrect_rate' : 0.2,
45 | 'betas' : (0.9, 0.999),
46 | 't_0' : 4688,
47 | 'eta_min' : 0,
48 | 'df_lr' : True, # return learning rate. Useful for scheduling
49 |
50 |
51 |
52 | }
53 |
54 | def optim_func(net, config):
55 | optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=config['betas'])
56 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['t_0'], eta_min=config['eta_min'], last_epoch=-1)
57 | return optimizer, scheduler
58 |
--------------------------------------------------------------------------------
/mnist/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import DataLoader
4 | from torchvision import datasets, transforms
5 |
6 | def load_data(config):
7 | data_dir = config['data_dir']
8 |
9 | transform = transforms.Compose([
10 | transforms.Resize((28, 28)),
11 | transforms.Grayscale(),
12 | transforms.ToTensor(),
13 | transforms.Normalize((0,), (1,))])
14 |
15 | trainset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
16 | testset = datasets.MNIST(data_dir, train=False, download=True, transform=transform)
17 |
18 | return trainset, testset
--------------------------------------------------------------------------------
/mnist/earlystopping.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EarlyStopping_acc:
5 | """Early stops the training if test acc doesn't improve after a given patience."""
6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
7 | """
8 | Args:
9 |             patience (int): How long to wait after the last test accuracy improvement.
10 |                             Default: 7
11 |             verbose (bool): If True, prints a message for each test accuracy improvement.
12 | Default: False
13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
14 | Default: 0
15 | path (str): Path for the checkpoint to be saved to.
16 | Default: 'checkpoint.pt'
17 | trace_func (function): trace print function.
18 | Default: print
19 | """
20 | self.patience = patience
21 | self.verbose = verbose
22 | self.counter = 0
23 | self.best_score = None
24 | self.early_stop = False
25 | self.test_loss_min = 0
26 | self.delta = delta
27 | self.path = path
28 | self.trace_func = trace_func
29 | def __call__(self, test_loss, model):
30 |
31 | score = test_loss
32 |
33 | if self.best_score is None:
34 | self.best_score = score
35 | self.save_checkpoint(test_loss, model)
36 | elif score <= self.best_score + self.delta:
37 | self.counter += 1
38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
39 | if self.counter >= self.patience:
40 | self.early_stop = True
41 | self.counter = 0
42 | else:
43 | self.best_score = score
44 | self.save_checkpoint(test_loss, model)
45 | self.counter = 0
46 |
47 | def save_checkpoint(self, test_loss, model):
48 | '''Saves model when test acc increases.'''
49 | if self.verbose:
50 | self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). Saving model ...')
51 | torch.save(model.state_dict(), self.path)
52 | self.test_loss_min = test_loss
--------------------------------------------------------------------------------
/mnist/functions.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 |
6 |
7 | class BinarizeF(Function):
8 |
9 | @staticmethod
10 | def forward(ctx, input):
11 | output = input.new(input.size())
12 | output[input >= 0] = 1
13 | output[input < 0] = -1
14 | return output
15 |
16 | @staticmethod
17 | def backward(ctx, grad_output):
18 | grad_input = grad_output.clone()
19 | return grad_input
20 |
21 | # aliases
22 | binarize = BinarizeF.apply
--------------------------------------------------------------------------------
/mnist/run.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 |
6 | # torch
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 | # misc
12 | import numpy as np
13 | import pandas as pd
14 | import time
15 | import logging
16 |
17 | # local imports
18 | from dataloader import *
19 | from Net import *
20 | from test_acc import *
21 | from train import *
22 | from earlystopping import *
23 | from conf import *
24 |
25 | ####################################################
26 | ## Notes: modify config in conf to reparameterize ##
27 | ####################################################
28 |
29 | file_name = config['exp_name']
30 |
31 | ### to address conditional parameters, s.t. thr_final > threshold
32 | config['thr_final1'] = config['thr_final1'] + config['threshold1']
33 | config['thr_final2'] = config['thr_final2'] + config['threshold2']
34 | config['thr_final3'] = config['thr_final3'] + config['threshold3']
35 |
36 | threshold1 = config['threshold1']
37 | threshold2 = config['threshold2']
38 | threshold3 = config['threshold3']
39 |
40 |
41 |
42 | for trial in range(config['num_trials']):
43 |
44 |
45 | # file names
46 | SAVE_CSV = config['save_csv']
47 | SAVE_MODEL = config['save_model']
48 | csv_name = file_name + '_t' + str(trial) + '.csv'
49 | log_name = file_name + '_t' + str(trial) + '.log'
50 | model_name = file_name + '_t' + str(trial) + '.pt'
51 | num_epochs = config['num_epochs']
52 | torch.manual_seed(config['seed'])
53 |
54 | config['threshold1'] = threshold1
55 | config['threshold2'] = threshold2
56 | config['threshold3'] = threshold3
57 |
58 | # dataframes
59 | df_train_loss = pd.DataFrame()
60 | df_test_acc = pd.DataFrame(columns=['epoch', 'test_acc', 'train_time'])
61 | df_lr = pd.DataFrame()
62 |
63 |
64 | # initialize network
65 | net = Net(config)
66 | device = "cpu"
67 | if torch.cuda.is_available():
68 | device = "cuda:0"
69 | if torch.cuda.device_count() > 1:
70 | net = nn.DataParallel(net)
71 | net.to(device)
72 |
73 | # net params
74 | criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
75 | optimizer, scheduler = optim_func(net, config)
76 |
77 | # early stopping condition
78 | if config['early_stopping']:
79 | early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)
80 | early_stopping.early_stop = False
81 | early_stopping.best_score = None
82 |
83 | # load data
84 | trainset, testset = load_data(config)
85 | config['dataset_length'] = len(trainset)
86 | trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True)
87 | testloader = DataLoader(testset, batch_size=int(config["batch_size"]), shuffle=False)
88 |
89 | print(f"=======Trial: {trial}=======")
90 |
91 | for epoch in range(num_epochs):
92 |
93 | # train
94 | start_time = time.time()
95 | loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device)
96 | epoch_time = time.time() - start_time
97 |
98 | # test
99 | test_acc = test_accuracy(config, net, testloader, device)
100 | print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}')
101 |
102 | if config['df_lr']:
103 | df_lr = pd.concat([df_lr, pd.DataFrame(lr_list)])
104 | df_train_loss = pd.concat([df_train_loss, pd.DataFrame(loss_list)])
105 | test_data = pd.DataFrame([[epoch, test_acc, epoch_time]], columns = ['epoch', 'test_acc', 'train_time'])
106 | df_test_acc = pd.concat([df_test_acc, test_data])
107 |
108 | if SAVE_CSV:
109 | df_train_loss.to_csv('loss_' + csv_name, index=False)
110 | df_test_acc.to_csv('acc_' + csv_name, index=False)
111 | if config['df_lr']:
112 | df_lr.to_csv('lr_' + csv_name, index=False)
113 |
114 | if config['early_stopping']:
115 | early_stopping(test_acc, net)
116 |
117 | if early_stopping.early_stop:
118 | print("Early stopping")
119 | early_stopping.early_stop = False
120 | early_stopping.best_score = None
121 | break
122 |
123 | if SAVE_MODEL and not config['early_stopping']:
124 | torch.save(net.state_dict(), model_name)
125 |
126 | # net.load_state_dict(torch.load(model_name))
127 |
--------------------------------------------------------------------------------
/mnist/test_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import snntorch as snn
3 | from snntorch import functional as SF
4 |
5 |
6 | def test_accuracy(config, net, testloader, device="cpu"):
7 |
8 |
9 | correct = 0
10 | total = 0
11 | with torch.no_grad():
12 | net.eval()
13 | for data in testloader:
14 | images, labels = data
15 | images, labels = images.to(device), labels.to(device)
16 |
17 | outputs, _ = net(images)
18 | accuracy = SF.accuracy_rate(outputs, labels)
19 |
20 | total += labels.size(0)
21 | correct += accuracy * labels.size(0)
22 |
23 | return 100 * correct / total
--------------------------------------------------------------------------------
/mnist/tha.py:
--------------------------------------------------------------------------------
1 | # exp relaxation implementation of THA based on Eq (4)
2 |
3 | def thr_annealing(config, network):
4 | alpha_thr1 = config['alpha_thr1']
5 | alpha_thr2 = config['alpha_thr2']
6 | alpha_thr3 = config['alpha_thr3']
7 |
8 | thr_final1 = config['thr_final1']
9 | thr_final2 = config['thr_final2']
10 | thr_final3 = config['thr_final3']
11 |
12 | network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1
13 | network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2
14 | network.lif3.threshold += (thr_final3 - network.lif3.threshold) * alpha_thr3
15 |
16 | return
--------------------------------------------------------------------------------
/mnist/train.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 | from snntorch import functional as SF
6 |
7 | # torch
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | from torchvision import datasets, transforms
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import StepLR
14 |
15 | # misc
16 | import os
17 | import numpy as np
18 | import math
19 | import itertools
20 | import matplotlib.pyplot as plt
21 | import pandas as pd
22 | import shutil
23 | import time
24 |
25 | from dataloader import *
26 | # from test import *   # unused: there is no test.py in this directory
27 | from test_acc import *
28 | from tha import *
29 |
30 |
31 | def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device):
32 |
33 | net.train()
34 | loss_accum = []
35 | lr_accum = []
36 |
37 | # TRAIN
38 | for data, labels in trainloader:
39 | data, labels = data.to(device), labels.to(device)
40 |
41 | spk_rec2, _ = net(data)
42 | loss = criterion(spk_rec2, labels)
43 | optimizer.zero_grad()
44 | loss.backward()
45 |
46 | if config['grad_clip']:
47 | nn.utils.clip_grad_norm_(net.parameters(), 1.0)
48 | if config['weight_clip']:
49 | with torch.no_grad():
50 | for param in net.parameters():
51 | param.clamp_(-1, 1)
52 |
53 | optimizer.step()
54 | scheduler.step()
55 | thr_annealing(config, net)
56 |
57 |
58 | loss_accum.append(loss.item()/config['num_steps'])
59 | lr_accum.append(optimizer.param_groups[0]["lr"])
60 |
61 |
62 | return loss_accum, lr_accum
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | snntorch
4 | pandas
5 | matplotlib
6 | numpy
--------------------------------------------------------------------------------
/shd/Net.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import surrogate
4 |
5 | # torch
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # local
11 | from bnn import *
12 |
13 |
14 | class Net(nn.Module):
15 | def __init__(self, config):
16 | super().__init__()
17 |
18 | self.thr1 = config['threshold1']
19 | self.thr2 = config['threshold2']
20 | slope = config['slope']
21 | beta = config['beta']
22 | self.num_steps = config['num_steps']
23 | p1 = config['dropout1']
24 | p2 = config['dropout2']
25 | self.binarize = config['binarize']
26 | num_hidden = 3000
27 | spike_grad = surrogate.fast_sigmoid(slope)
28 | # Initialize layers with spike operator
29 |
30 |
31 | self.bfc1 = BinaryLinear(700, num_hidden)
32 | self.fc1 = nn.Linear(700, num_hidden)
33 | self.lif1 = snn.Leaky(beta, threshold=self.thr1, spike_grad=spike_grad)
34 | self.dropout1 = nn.Dropout(p1)
35 |
36 | self.bfc2 = BinaryLinear(num_hidden, 20)
37 | self.fc2 = nn.Linear(num_hidden, 20)
38 | self.lif2 = snn.Leaky(beta, threshold=self.thr2, spike_grad=spike_grad)
39 | self.dropout2 = nn.Dropout(p2)
40 |
41 |
42 | def forward(self, x):
43 |
44 | # Initialize hidden states and outputs at t=0
45 | mem1 = self.lif1.init_leaky()
46 | mem2 = self.lif2.init_leaky()
47 |
48 | # Record the final layer
49 | spk2_rec = []
50 | mem2_rec = []
51 |
52 | # Binarization
53 |
54 | if self.binarize:
55 |
56 | for step in range(x.size(0)):
57 |
58 | cur1 = self.dropout1(self.bfc1(x[step].flatten(1)))
59 | spk1, mem1 = self.lif1(cur1, mem1)
60 | cur2 = self.dropout2(self.bfc2(spk1))
61 | spk2, mem2 = self.lif2(cur2, mem2)
62 |
63 |
64 | spk2_rec.append(spk2)
65 | mem2_rec.append(mem2)
66 |
67 | return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
68 |
69 | # Full Precision
70 |
71 | else:
72 |
73 | for step in range(x.size(0)):
74 |
75 | cur1 = self.dropout1(self.fc1(x[step].flatten(1)))
76 | spk1, mem1 = self.lif1(cur1, mem1)
77 | cur2 = self.dropout2(self.fc2(spk1))
78 | spk2, mem2 = self.lif2(cur2, mem2)
79 | spk2_rec.append(spk2)
80 | mem2_rec.append(mem2)
81 |
82 | return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
83 |
84 |
--------------------------------------------------------------------------------
/shd/bnn.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from functions import *
8 |
9 |
10 | class BinaryTanh(nn.Module):
11 | def __init__(self):
12 | super(BinaryTanh, self).__init__()
13 | self.hardtanh = nn.Hardtanh()
14 |
15 | def forward(self, input):
16 | output = self.hardtanh(input)
17 | output = binarize(output)
18 | return output
19 |
20 |
21 | class BinaryLinear(nn.Linear):
22 |
23 | def forward(self, input):
24 | binary_weight = binarize(self.weight)
25 | if self.bias is None:
26 | return F.linear(input, binary_weight)
27 | else:
28 | return F.linear(input, binary_weight, self.bias)
29 |
30 | def reset_parameters(self):
31 | # Glorot initialization
32 | in_features, out_features = self.weight.size()
33 | stdv = math.sqrt(1.5 / (in_features + out_features))
34 | self.weight.data.uniform_(-stdv, stdv)
35 | if self.bias is not None:
36 | self.bias.data.zero_()
37 |
38 | self.weight.lr_scale = 1. / stdv
39 |
40 |
41 |
42 | class BinaryConv2d(nn.Conv2d):
43 |
44 | def forward(self, input):
45 | bw = binarize(self.weight)
46 | return F.conv2d(input, bw, self.bias, self.stride,
47 | self.padding, self.dilation, self.groups)
48 |
49 | def reset_parameters(self):
50 | # Glorot initialization
51 | in_features = self.in_channels
52 | out_features = self.out_channels
53 | for k in self.kernel_size:
54 | in_features *= k
55 | out_features *= k
56 | stdv = math.sqrt(1.5 / (in_features + out_features))
57 | self.weight.data.uniform_(-stdv, stdv)
58 | if self.bias is not None:
59 | self.bias.data.zero_()
60 |
61 | self.weight.lr_scale = 1. / stdv
--------------------------------------------------------------------------------
/shd/conf.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | config = {
4 | 'exp_name' : 'shd_tha',
5 | 'num_trials' : 5,
6 | 'num_epochs' : 5,
7 | 'binarize' : True,
8 | 'data_dir' : "/home/shd",
9 | 'batch_size' : 32,
10 | 'seed' : 0,
11 | 'num_workers' : 0,
12 |
13 | # final run sweeps
14 | 'save_csv' : True,
15 | 'save_model' : True,
16 | 'early_stopping': True,
17 | 'patience': 100,
18 |
19 | # final params
20 | 'grad_clip' : True,
21 | 'weight_clip' : True,
22 | 'batch_norm' : True,
23 | 'dropout2' : 0.0176,
24 | 'dropout1' : 0.186,
25 | 'beta' : 0.950,
26 | 'lr' : 6.54e-4,
27 | 'slope': 0.257,
28 |
29 |
30 |     # threshold annealing. note: run.py adds the initial threshold, so the effective final threshold is threshold + thr_final
31 | 'threshold1' : 13.504,
32 | 'alpha_thr1' : 2.78e-5,
33 | 'thr_final1' : 31.767,
34 |
35 | 'threshold2' : 11.20,
36 | 'alpha_thr2' : 1.36e-5,
37 | 'thr_final2' : 39.92,
38 |
39 | # fixed params
40 | 'num_steps' : 100,
41 | 'correct_rate': 0.8,
42 | 'incorrect_rate' : 0.2,
43 | 'betas1' : 0.9,
44 | 'betas2' : 0.999,
45 | 't_0' : 2604,
46 | 'eta_min' : 0,
47 | 'df_lr' : True, # return learning rate. Useful for scheduling
48 |
49 |
50 | }
51 |
52 | def optim_func(net, config):
53 | optimizer = torch.optim.Adam(net.parameters(), lr=config["lr"], betas=(config['betas1'], config['betas2']))
54 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['t_0'], eta_min=config['eta_min'], last_epoch=-1)
55 | return optimizer, scheduler
--------------------------------------------------------------------------------
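
Note that `thr_final1` and `thr_final2` above are offsets rather than absolute targets: as noted in the comment in `conf.py`, `run.py` adds the initial threshold to each of them before training starts. A quick worked example using the values in this config:

```
# run.py converts the offsets in conf.py into absolute final thresholds before training:
config['thr_final1'] = config['thr_final1'] + config['threshold1']   # 31.767 + 13.504 = 45.271
config['thr_final2'] = config['thr_final2'] + config['threshold2']   # 39.92  + 11.20  = 51.12
```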
/shd/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data import DataLoader
4 | from torchvision import datasets, transforms
5 |
6 | from snntorch.spikevision import spikedata
7 |
8 | def load_data(config):
9 |
10 | data_dir = config['data_dir']
11 |     dt_scalar = 3 # set to 2 for the full-precision (float) experiments
12 |
13 |
14 | dt = int(1000*dt_scalar)
15 | num_steps = int(1000/dt_scalar)
16 |
17 | trainset = spikedata.SHD(data_dir, train=True, num_steps=num_steps, dt=dt)
18 | testset = spikedata.SHD(data_dir, train=False, num_steps=num_steps, dt=dt)
19 |
20 | return trainset, testset
--------------------------------------------------------------------------------
/shd/earlystopping.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class EarlyStopping_acc:
5 | """Early stops the training if test acc doesn't improve after a given patience."""
6 | def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
7 | """
8 | Args:
9 |             patience (int): How long to wait after the last time the test accuracy improved.
10 |                             Default: 7
11 |             verbose (bool): If True, prints a message for each test accuracy improvement.
12 | Default: False
13 | delta (float): Minimum change in the monitored quantity to qualify as an improvement.
14 | Default: 0
15 | path (str): Path for the checkpoint to be saved to.
16 | Default: 'checkpoint.pt'
17 | trace_func (function): trace print function.
18 | Default: print
19 | """
20 | self.patience = patience
21 | self.verbose = verbose
22 | self.counter = 0
23 | self.best_score = None
24 | self.early_stop = False
25 | self.test_loss_min = 0
26 | self.delta = delta
27 | self.path = path
28 | self.trace_func = trace_func
29 | def __call__(self, test_loss, model):
30 |
31 | score = test_loss
32 |
33 | if self.best_score is None:
34 | self.best_score = score
35 | self.save_checkpoint(test_loss, model)
36 | elif score <= self.best_score + self.delta:
37 | self.counter += 1
38 | self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
39 | if self.counter >= self.patience:
40 | self.early_stop = True
41 | self.counter = 0
42 | else:
43 | self.best_score = score
44 | self.save_checkpoint(test_loss, model)
45 | self.counter = 0
46 |
47 | def save_checkpoint(self, test_loss, model):
48 | '''Saves model when test acc increases.'''
49 | if self.verbose:
50 | self.trace_func(f'Test acc increased ({self.test_loss_min:.6f} --> {test_loss:.6f}). Saving model ...')
51 | torch.save(model.state_dict(), self.path)
52 | self.test_loss_min = test_loss
--------------------------------------------------------------------------------
/shd/functions.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 |
6 |
7 | class BinarizeF(Function):
8 |
9 | @staticmethod
10 | def forward(ctx, input):
11 | output = input.new(input.size())
12 | output[input >= 0] = 1
13 | output[input < 0] = -1
14 | return output
15 |
16 | @staticmethod
17 | def backward(ctx, grad_output):
18 | grad_input = grad_output.clone()
19 | return grad_input
20 |
21 | # aliases
22 | binarize = BinarizeF.apply
--------------------------------------------------------------------------------
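
`BinarizeF` is a straight-through estimator: the forward pass snaps every element to ±1, while the backward pass copies the incoming gradient unchanged so the latent full-precision weights can still be updated. A minimal sketch of that behavior (assuming `functions.py` is importable from the working directory) is:

```
import torch

from functions import binarize   # straight-through sign function defined above

w = torch.tensor([-0.3, 0.0, 0.7], requires_grad=True)
w_bin = binarize(w)        # forward: tensor([-1., 1., 1.])  (values >= 0 map to +1)
w_bin.sum().backward()
print(w.grad)              # backward: tensor([1., 1., 1.])  (identity gradient)
```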
/shd/run.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 |
6 | # torch
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 |
11 |
12 | # misc
13 | import numpy as np
14 | import pandas as pd
15 | import time
16 | import logging
17 |
18 | # local imports
19 | from dataloader import *
20 | from Net import *
21 | from test_acc import *
22 | from train import *
23 | from earlystopping import *
24 | from conf import *
25 |
26 | ####################################################
27 | ## Notes: modify config in conf to reparameterize ##
28 | ####################################################
29 |
30 | file_name = config['exp_name']
31 |
32 | ### thr_final is stored as an offset in conf.py; add the initial threshold so that thr_final > threshold
33 | config['thr_final1'] = config['thr_final1'] + config['threshold1']
34 | config['thr_final2'] = config['thr_final2'] + config['threshold2']
35 |
36 | threshold1 = config['threshold1']
37 | threshold2 = config['threshold2']
38 |
39 | for trial in range(config['num_trials']):
40 |
41 | # file names
42 | SAVE_CSV = config['save_csv']
43 | SAVE_MODEL = config['save_model']
44 | csv_name = file_name + '_t' + str(trial) + '.csv'
45 | log_name = file_name + '_t' + str(trial) + '.log'
46 | model_name = file_name + '_t' + str(trial) + '.pt'
47 | num_epochs = config['num_epochs']
48 | torch.manual_seed(config['seed'])
49 |
50 | config['threshold1'] = threshold1
51 | config['threshold2'] = threshold2
52 |
53 | # dataframes
54 | df_train_loss = pd.DataFrame()
55 | df_test_acc = pd.DataFrame(columns=['epoch', 'test_acc', 'train_time'])
56 | df_lr = pd.DataFrame()
57 |
58 | # initialize network
59 | net = Net(config)
60 | device = "cpu"
61 | if torch.cuda.is_available():
62 | device = "cuda:0"
63 | if torch.cuda.device_count() > 1:
64 | net = nn.DataParallel(net)
65 | net.to(device)
66 |
67 | # net params
68 | criterion = SF.mse_count_loss(correct_rate=config['correct_rate'], incorrect_rate=config['incorrect_rate'])
69 | optimizer, scheduler = optim_func(net, config)
70 |
71 | # early stopping condition
72 | if config['early_stopping']:
73 | early_stopping = EarlyStopping_acc(patience=config['patience'], verbose=True, path=model_name)
74 | early_stopping.early_stop = False
75 | early_stopping.best_score = None
76 |
77 | # load data
78 | trainset, testset = load_data(config)
79 | trainloader = DataLoader(trainset, batch_size=int(config["batch_size"]), shuffle=True)
80 | testloader = DataLoader(testset, batch_size=int(config["batch_size"]), shuffle=False)
81 |
82 | print(f"=======Trial: {trial}=======")
83 |
84 | for epoch in range(num_epochs):
85 |
86 | # train
87 | start_time = time.time()
88 | loss_list, lr_list = train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device)
89 | epoch_time = time.time() - start_time
90 |
91 | # test
92 | test_acc = test_accuracy(config, net, testloader, device)
93 | print(f'Epoch: {epoch} \tTest Accuracy: {test_acc}')
94 |
95 | if config['df_lr']:
96 | df_lr = pd.concat([df_lr, pd.DataFrame(lr_list)])
97 | df_train_loss = pd.concat([df_train_loss, pd.DataFrame(loss_list)])
98 | test_data = pd.DataFrame([[epoch, test_acc, epoch_time]], columns = ['epoch', 'test_acc', 'train_time'])
99 | df_test_acc = pd.concat([df_test_acc, test_data])
100 |
101 | if SAVE_CSV:
102 | df_train_loss.to_csv('loss_' + csv_name, index=False)
103 | df_test_acc.to_csv('acc_' + csv_name, index=False)
104 | if config['df_lr']:
105 | df_lr.to_csv('lr_' + csv_name, index=False)
106 |
107 | if config['early_stopping']:
108 | early_stopping(test_acc, net)
109 |
110 | if early_stopping.early_stop:
111 | print("Early stopping")
112 | early_stopping.early_stop = False
113 | early_stopping.best_score = None
114 | break
115 |
116 | if SAVE_MODEL and not config['early_stopping']:
117 | torch.save(net.state_dict(), model_name)
118 |
119 | # net.load_state_dict(torch.load(model_name))
--------------------------------------------------------------------------------
/shd/test_acc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import snntorch as snn
3 | from snntorch import functional as SF
4 |
5 |
6 | def test_accuracy(config, net, testloader, device="cpu"):
7 |
8 |
9 | correct = 0
10 | total = 0
11 | with torch.no_grad():
12 | net.eval()
13 | for data in testloader:
14 | images, labels = data
15 | images, labels = images.to(device), labels.to(device)
16 |
17 | outputs, _ = net(images.permute(1, 0, 2))
18 | accuracy = SF.accuracy_rate(outputs, labels)
19 |
20 | total += labels.size(0)
21 | correct += accuracy * labels.size(0)
22 |
23 | return 100 * correct / total
--------------------------------------------------------------------------------
/shd/tha.py:
--------------------------------------------------------------------------------
1 | # exp relaxation implementation of THA based on Eq (4)
2 |
3 | def thr_annealing(config, network):
4 | alpha_thr1 = config['alpha_thr1']
5 | thr_final1 = config['thr_final1']
6 |
7 | alpha_thr2 = config['alpha_thr2']
8 | thr_final2 = config['thr_final2']
9 |
10 | network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1
11 | network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2
12 |
13 | return
--------------------------------------------------------------------------------
/shd/train.py:
--------------------------------------------------------------------------------
1 | # snntorch
2 | import snntorch as snn
3 | from snntorch import spikegen
4 | from snntorch import surrogate
5 | from snntorch import functional as SF
6 |
7 | # torch
8 | import torch
9 | import torch.nn as nn
10 | from torch.utils.data import DataLoader
11 | from torchvision import datasets, transforms
12 | import torch.nn.functional as F
13 | from torch.optim.lr_scheduler import StepLR
14 |
15 | # misc
16 | import os
17 | import numpy as np
18 | import math
19 | import itertools
20 | import matplotlib.pyplot as plt
21 | import pandas as pd
22 | import shutil
23 | import time
24 |
25 | # raytune
26 | # from functools import partial
27 | # from ray import tune
28 | # from ray.tune import CLIReporter
29 | # # from ray.tune import JupyterNotebookReporter
30 | # from ray.tune.schedulers import ASHAScheduler
31 |
32 | from dataloader import *
33 | from test import *
34 | from test_acc import *
35 | from tha import *
36 |
37 | def train(config, net, epoch, trainloader, testloader, criterion, optimizer, scheduler, device):
38 |
39 | net.train()
40 | loss_accum = []
41 | lr_accum = []
42 |
43 | # TRAIN
44 | for data, labels in trainloader:
45 | data, labels = data.to(device), labels.to(device)
46 | spk_rec2, _ = net(data.permute(1, 0, 2))
47 | loss = criterion(spk_rec2, labels.long())
48 | optimizer.zero_grad()
49 | loss.backward()
50 |
51 | if config['grad_clip']:
52 | nn.utils.clip_grad_norm_(net.parameters(), 1.0)
53 | if config['weight_clip']:
54 | with torch.no_grad():
55 | for param in net.parameters():
56 | param.clamp_(-1, 1)
57 |
58 | optimizer.step()
59 | scheduler.step()
60 | thr_annealing(config, net)
61 |
62 |
63 | loss_accum.append(loss.item()/config['num_steps'])
64 | lr_accum.append(optimizer.param_groups[0]["lr"])
65 |
66 | return loss_accum, lr_accum
67 |
--------------------------------------------------------------------------------
/temporal/bounded_homeostasis.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "accelerator": "GPU",
6 | "colab": {
7 | "name": "icml_spike_time_exp.ipynb",
8 | "provenance": [],
9 | "collapsed_sections": []
10 | },
11 | "kernelspec": {
12 | "display_name": "Python 3",
13 | "name": "python3"
14 | },
15 | "language_info": {
16 | "name": "python"
17 | }
18 | },
19 | "cells": [
20 | {
21 | "cell_type": "markdown",
22 | "metadata": {
23 | "id": "GapeQDQsl-sx"
24 | },
25 | "source": [
26 | "# Bounded Homeostasis to Learn Temporal Targets\n",
27 | "\n",
28 | "This notebook replicates the temporal coding experiments in the paper *`The fine line between dead neurons and sparsity in binarized spiking neural networks'*."
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "metadata": {
34 | "id": "U_R27gZyULBI"
35 | },
36 | "source": [
37 | "!pip install snntorch --quiet"
38 | ],
39 | "execution_count": null,
40 | "outputs": []
41 | },
42 | {
43 | "cell_type": "markdown",
44 | "metadata": {
45 | "id": "kKI4l8OXQxXk"
46 | },
47 | "source": [
48 | "## Imports"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "metadata": {
54 | "id": "JyB2DosmUNO0"
55 | },
56 | "source": [
57 | "import snntorch as snn\n",
58 | "from snntorch import surrogate\n",
59 | "from snntorch import spikegen\n",
60 | "import snntorch.functional as SF\n",
61 | "from snntorch import spikeplot as splt\n",
62 | "from snntorch import utils\n",
63 | "\n",
64 | "import torch\n",
65 | "import torch.nn as nn\n",
66 | "import torch.nn.functional as F\n",
67 | "from torch.autograd import Function\n",
68 | "\n",
69 | "import matplotlib.pyplot as plt\n",
70 | "from matplotlib.animation import FuncAnimation\n",
71 | "import matplotlib.gridspec as gridspec\n",
72 | "import seaborn as sns \n",
73 | "\n",
74 | "from IPython import display\n",
75 | "import numpy as np\n",
76 | "from tqdm import tqdm\n",
77 | "import math\n",
78 | "import random\n",
79 |         "from scipy.ndimage import uniform_filter1d\n",
80 | "import os"
81 | ],
82 | "execution_count": null,
83 | "outputs": []
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "source": [
88 | "## Plotting Utility Functions"
89 | ],
90 | "metadata": {
91 | "id": "34Oz6bzcuCne"
92 | }
93 | },
94 | {
95 | "cell_type": "code",
96 | "source": [
97 | "#@title\n",
98 | "sns.set_theme()\n",
99 | "\n",
100 | "def prep_for_plot(mem):\n",
101 | " return mem.cpu().detach().squeeze(-1).squeeze(-1)\n",
102 | "\n",
103 | "def plot_quadrant(mem, spk_out, target_mem, spk_target, y1, y2, threshold=1, save=False, epoch1 = 1, epoch2=25, epoch3=100, fill=True):\n",
104 | " # Generate Plots\n",
105 | " gs = gridspec.GridSpec(2, 4, height_ratios=[1, 0.07])\n",
106 | " fig = plt.figure(figsize=(12,4.5),)\n",
107 | " ax1 = plt.subplot(gs[0,0])\n",
108 | " ax2 = plt.subplot(gs[1,0])\n",
109 | " ax3 = plt.subplot(gs[0,1])\n",
110 | " ax4 = plt.subplot(gs[1,1])\n",
111 | " ax5 = plt.subplot(gs[0,2])\n",
112 | " ax6 = plt.subplot(gs[1,2])\n",
113 | " ax7 = plt.subplot(gs[0,3])\n",
114 | " ax8 = plt.subplot(gs[1,3])\n",
115 | "\n",
116 | " mem = prep_for_plot(mem)\n",
117 | " spk_out = prep_for_plot(spk_out)\n",
118 | " target_mem = prep_for_plot(target_mem)\n",
119 | " epoch1_str = str(epoch1)\n",
120 | " epoch2_str = str(epoch2)\n",
121 | " epoch3_str = str(epoch3)\n",
122 | "\n",
123 | " fontsize = 25\n",
124 | "\n",
125 | " ########### TARGET ########\n",
126 | " # Plot membrane potential\n",
127 | " ax1.plot(target_mem)\n",
128 | " ax1.set_ylim([y1, y2]) # 0.1, 1.3\n",
129 | " ax1.set_ylabel(\"$u$\", fontsize=fontsize, fontweight='bold')\n",
130 | " ax1.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
131 | " ax1.set_yticks([])\n",
132 | " ax1.set_xticks([])\n",
133 | " ax1.set_title(\"Target\",fontsize=fontsize, fontweight='bold')\n",
134 | " # plt.xlabel(\"Time\") \n",
135 | "\n",
136 | " # Plot output spike using spikeplot\n",
137 | " splt.raster(spk_target, ax2, s=250, c=\"black\", marker=\".\")\n",
138 | " ax2.set_ylabel(\"$z$\", fontsize=fontsize, fontweight='bold')\n",
139 | " ax2.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n",
140 | " ax2.set_yticks([]) \n",
141 | " ax2.set_xticks([])\n",
142 | " ax2.set_xlim(0, 100)\n",
143 | "\n",
144 | " ############## EPOCH 1 ########\n",
145 | "\n",
146 | " # Plot membrane potential\n",
147 | " ax3.plot(mem[epoch1])\n",
148 | " ax3.set_ylim([y1, y2]) # 0.1, 1.3\n",
149 | " ax3.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
150 | " ax3.set_yticks([])\n",
151 | " ax3.set_xticks([])\n",
152 | " ax3.set_title(\"$\\gamma =$\" + epoch1_str ,fontsize=fontsize, fontweight='bold')\n",
153 | " # plt.xlabel(\"Time\") \n",
154 | "\n",
155 | " # Plot output spike using spikeplot\n",
156 | " splt.raster(spk_out[epoch1], ax4, s=250, c=\"black\", marker=\".\")\n",
157 | " # ax4.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n",
158 | " ax4.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n",
159 | " ax4.set_yticks([]) \n",
160 | " ax4.set_xticks([])\n",
161 | " ax4.set_xlim(0, 100)\n",
162 | "\n",
163 | " ############## EPOCH 100 ########\n",
164 | "\n",
165 | " # Plot membrane potential\n",
166 | " ax5.plot(mem[epoch2])\n",
167 | " ax5.set_ylim([y1, y2]) # 0.1, 1.3\n",
168 | " # ax5.set_ylabel(\"$u^{~j}_t$\", fontsize=fontsize, fontweight='bold')\n",
169 | " ax5.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
170 | " ax5.set_yticks([])\n",
171 | " ax5.set_xticks([]) \n",
172 | " ax5.set_title(\"$\\gamma =$\" + epoch2_str,fontsize=fontsize, fontweight='bold')\n",
173 | " # plt.xlabel(\"Time\") \n",
174 | "\n",
175 | " # Plot output spike using spikeplot\n",
176 | " splt.raster(spk_out[epoch2], ax6, s=250, c=\"black\", marker=\".\")\n",
177 | " # ax6.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n",
178 | " ax6.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n",
179 | " ax6.set_yticks([]) \n",
180 | " ax6.set_xticks([])\n",
181 | " ax6.set_xlim(0, 100)\n",
182 | "\n",
183 | " ########## EPOCH 100 ##############\n",
184 | " # Plot membrane potential\n",
185 | " ax7.plot(mem[epoch3])\n",
186 | " ax7.set_ylim([y1, y2]) # 0.1, 1.3\n",
187 | " # ax7.set_ylabel(\"$u^{~j}_t$\", fontsize=fontsize, fontweight='bold')\n",
188 | " ax7.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
189 | " ax7.set_yticks([])\n",
190 | " ax7.set_xticks([])\n",
191 | " ax7.set_title(\"$\\gamma =$\" + epoch3_str,fontsize=fontsize, fontweight='bold')\n",
192 | " # plt.xlabel(\"Time\") \n",
193 | "\n",
194 | " # Plot output spike using spikeplot\n",
195 | " splt.raster(spk_out[epoch3], ax8, s=250, c=\"black\", marker=\".\")\n",
196 | " # ax8.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n",
197 | " ax8.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n",
198 | " ax8.set_yticks([]) \n",
199 | " ax8.set_xticks([])\n",
200 | " ax8.set_xlim(0, 100)\n",
201 | " \n",
202 | " fig.tight_layout()\n",
203 | " plt.subplots_adjust(\n",
204 | " # left=0.125,\n",
205 | " # bottom=0.1, \n",
206 | " # right=0.9, \n",
207 | " # top=0.9, \n",
208 | " wspace=0.01, \n",
209 | " hspace=0.)\n",
210 | " \n",
211 | " if fill:\n",
212 | " ax1.fill_between(x, target_mem, step=\"pre\", alpha=0.4, color='tab:blue')\n",
213 | " ax3.fill_between(x, mem[epoch1], step=\"pre\", alpha=0.4, color='tab:blue')\n",
214 | " ax5.fill_between(x, mem[epoch2], step=\"pre\", alpha=0.4, color='tab:blue')\n",
215 | " ax7.fill_between(x, mem[epoch3], step=\"pre\", alpha=0.4, color='tab:blue')\n",
216 | "\n",
217 | " fig1 = plt.gcf()\n",
218 | " if save:\n",
219 | " fig1.savefig(save, dpi=600)\n",
220 | "\n",
221 | " plt.show()\n",
222 | "\n",
223 | "\n",
224 | "def plot_quadrant_tha(mem, spk_out, target_mem, spk_target, y1, y2, \n",
225 | " threshold=[1, 1, 1], save=False, epoch1 = 1, epoch2=25, \n",
226 | " epoch3=100, fill=True):\n",
227 | " # Generate Plots\n",
228 | " gs = gridspec.GridSpec(2, 4, height_ratios=[1, 0.07])\n",
229 | " fig = plt.figure(figsize=(12,4.5),)\n",
230 | " ax1 = plt.subplot(gs[0,0])\n",
231 | " ax2 = plt.subplot(gs[1,0])\n",
232 | " ax3 = plt.subplot(gs[0,1])\n",
233 | " ax4 = plt.subplot(gs[1,1])\n",
234 | " ax5 = plt.subplot(gs[0,2])\n",
235 | " ax6 = plt.subplot(gs[1,2])\n",
236 | " ax7 = plt.subplot(gs[0,3])\n",
237 | " ax8 = plt.subplot(gs[1,3])\n",
238 | "\n",
239 | " mem = prep_for_plot(mem)\n",
240 | " spk_out = prep_for_plot(spk_out)\n",
241 | " target_mem = prep_for_plot(target_mem)\n",
242 | " epoch1_str = str(epoch1)\n",
243 | " epoch2_str = str(epoch2)\n",
244 | " epoch3_str = str(epoch3)\n",
245 | "\n",
246 | " fontsize = 25\n",
247 | "\n",
248 | " ########### TARGET ########\n",
249 | " # Plot membrane potential\n",
250 | " ax1.plot(target_mem)\n",
251 | " ax1.set_ylim([y1, y2]) # 0.1, 1.3\n",
252 | " ax1.set_ylabel(\"$u$\", fontsize=fontsize, fontweight='bold')\n",
253 | " ax1.axhline(y=threshold[999], alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
254 | " ax1.set_yticks([])\n",
255 | " ax1.set_xticks([])\n",
256 | " ax1.set_title(\"Target\",fontsize=fontsize, fontweight='bold')\n",
257 | " # plt.xlabel(\"Time\") \n",
258 | "\n",
259 | " # Plot output spike using spikeplot\n",
260 | " splt.raster(spk_target, ax2, s=250, c=\"black\", marker=\".\")\n",
261 | " ax2.set_ylabel(\"$z$\", fontsize=fontsize, fontweight='bold')\n",
262 | " ax2.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n",
263 | " ax2.set_yticks([]) \n",
264 | " ax2.set_xticks([])\n",
265 | " ax2.set_xlim(0, 100)\n",
266 | "\n",
267 | " ############## EPOCH 1 ########\n",
268 | "\n",
269 | " # Plot membrane potential\n",
270 | " ax3.plot(mem[epoch1])\n",
271 | " ax3.set_ylim([y1, y2]) # 0.1, 1.3\n",
272 | " ax3.axhline(y=threshold[epoch1], alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
273 | " ax3.set_yticks([])\n",
274 | " ax3.set_xticks([])\n",
275 | " ax3.set_title(\"$\\gamma =$\" + epoch1_str ,fontsize=fontsize, fontweight='bold')\n",
276 | " # plt.xlabel(\"Time\") \n",
277 | "\n",
278 | " # Plot output spike using spikeplot\n",
279 | " splt.raster(spk_out[epoch1], ax4, s=250, c=\"black\", marker=\".\")\n",
280 | " # ax4.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n",
281 | " ax4.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n",
282 | " ax4.set_yticks([]) \n",
283 | " ax4.set_xticks([])\n",
284 | " ax4.set_xlim(0, 100)\n",
285 | "\n",
286 | " ############## EPOCH 100 ########\n",
287 | "\n",
288 | " # Plot membrane potential\n",
289 | " ax5.plot(mem[epoch2])\n",
290 | " ax5.set_ylim([y1, y2]) # 0.1, 1.3\n",
291 | " # ax5.set_ylabel(\"$u^{~j}_t$\", fontsize=fontsize, fontweight='bold')\n",
292 | " ax5.axhline(y=threshold[epoch2], alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
293 | " ax5.set_yticks([])\n",
294 | " ax5.set_xticks([]) \n",
295 | " ax5.set_title(\"$\\gamma =$\" + epoch2_str,fontsize=fontsize, fontweight='bold')\n",
296 | " # plt.xlabel(\"Time\") \n",
297 | "\n",
298 | " # Plot output spike using spikeplot\n",
299 | " splt.raster(spk_out[epoch2], ax6, s=250, c=\"black\", marker=\".\")\n",
300 | " # ax6.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n",
301 | " ax6.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n",
302 | " ax6.set_yticks([]) \n",
303 | " ax6.set_xticks([])\n",
304 | " ax6.set_xlim(0, 100)\n",
305 | "\n",
306 | " ########## EPOCH 100 ##############\n",
307 | " # Plot membrane potential\n",
308 | " ax7.plot(mem[epoch3])\n",
309 | " ax7.set_ylim([y1, y2]) # 0.1, 1.3\n",
310 | " # ax7.set_ylabel(\"$u^{~j}_t$\", fontsize=fontsize, fontweight='bold')\n",
311 | " ax7.axhline(y=threshold[epoch3], alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
312 | " ax7.set_yticks([])\n",
313 | " ax7.set_xticks([])\n",
314 | " ax7.set_title(\"$\\gamma =$\" + epoch3_str,fontsize=fontsize, fontweight='bold')\n",
315 | " # plt.xlabel(\"Time\") \n",
316 | "\n",
317 | " # Plot output spike using spikeplot\n",
318 | " splt.raster(spk_out[epoch3], ax8, s=250, c=\"black\", marker=\".\")\n",
319 | " # ax8.set_ylabel(\"$z^{~j}_t$\", fontsize=fontsize, fontweight='bold') \n",
320 | " ax8.set_xlabel(\"$t$\", fontsize=fontsize, fontweight='bold')\n",
321 | " ax8.set_yticks([]) \n",
322 | " ax8.set_xticks([])\n",
323 | " ax8.set_xlim(0, 100)\n",
324 | " \n",
325 | " fig.tight_layout()\n",
326 | " plt.subplots_adjust(\n",
327 | " # left=0.125,\n",
328 | " # bottom=0.1, \n",
329 | " # right=0.9, \n",
330 | " # top=0.9, \n",
331 | " wspace=0.01, \n",
332 | " hspace=0.)\n",
333 | " \n",
334 | " if fill:\n",
335 | " ax1.fill_between(x, target_mem, step=\"pre\", alpha=0.4, color='tab:blue')\n",
336 | " ax3.fill_between(x, mem[epoch1], step=\"pre\", alpha=0.4, color='tab:blue')\n",
337 | " ax5.fill_between(x, mem[epoch2], step=\"pre\", alpha=0.4, color='tab:blue')\n",
338 | " ax7.fill_between(x, mem[epoch3], step=\"pre\", alpha=0.4, color='tab:blue')\n",
339 | "\n",
340 | " fig1 = plt.gcf()\n",
341 | " if save:\n",
342 | " fig1.savefig(save, dpi=600)\n",
343 | "\n",
344 | " plt.show()\n"
345 | ],
346 | "metadata": {
347 | "cellView": "form",
348 | "id": "RERn5ncNBF2Z"
349 | },
350 | "execution_count": null,
351 | "outputs": []
352 | },
353 | {
354 | "cell_type": "markdown",
355 | "metadata": {
356 | "id": "1svHV-viQ_Ll"
357 | },
358 | "source": [
359 | "# 1. High Precision Testing\n",
360 | "## 1.1 Choose some random hyperparameters"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "metadata": {
366 | "id": "izqVw9L4UaWx"
367 | },
368 | "source": [
369 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
370 | "dtype = torch.float\n",
371 | "num_steps = 100\n",
372 | "num_inputs = 100\n",
373 | "num_hidden = 1000\n",
374 | "batch_size = 1\n",
375 | "beta=0.6\n",
376 | "spike_time = 75\n",
377 | "\n",
378 | "loss_fn = nn.MSELoss() "
379 | ],
380 | "execution_count": null,
381 | "outputs": []
382 | },
383 | {
384 | "cell_type": "code",
385 | "source": [
386 | "def set_all_seeds(seed=0):\n",
387 | " random.seed(seed)\n",
388 | " os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
389 | " np.random.seed(seed)\n",
390 | " torch.manual_seed(seed)\n",
391 | " torch.cuda.manual_seed(seed)\n",
392 | " torch.backends.cudnn.deterministic = True\n",
393 | "\n",
394 | "set_all_seeds()"
395 | ],
396 | "metadata": {
397 | "id": "1npF4uSpAbLf"
398 | },
399 | "execution_count": null,
400 | "outputs": []
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "metadata": {
405 | "id": "9rQJ71afRqwo"
406 | },
407 | "source": [
408 | "## 1.2 Generate Random Inputs and Membrane Trace Target\n",
409 | "* The random inputs will be fed to the network\n",
410 | "* The output neuron will be trained to replicate the evolution of the membrane trace generated below"
411 | ]
412 | },
413 | {
414 | "cell_type": "code",
415 | "metadata": {
416 | "id": "dGXVx6u-UsuL"
417 | },
418 | "source": [
419 | "input_prob = torch.rand(num_steps, batch_size, num_inputs).to(device)\n",
420 | "input_data = spikegen.rate(input_prob, time_var_input=True)\n",
421 | "target_mem = spikegen.targets_latency(torch.zeros(1, dtype=dtype, device=device), num_classes=1, first_spike_time=75, on_target=1.05, num_steps=100, interpolate=True)"
422 | ],
423 | "execution_count": null,
424 | "outputs": []
425 | },
426 | {
427 | "cell_type": "code",
428 | "metadata": {
429 | "id": "8MUE72R0bSSU"
430 | },
431 | "source": [
432 | "# membrane trace target: Threshold=1\n",
433 | "splt.traces(target_mem, spk=False, dim=(1,1), spk_height=1)"
434 | ],
435 | "execution_count": null,
436 | "outputs": []
437 | },
438 | {
439 | "cell_type": "markdown",
440 | "metadata": {
441 | "id": "638dTiPNRzuc"
442 | },
443 | "source": [
444 | "## 1.3 Define network\n",
445 | "100-1000-1 Dense Network"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "metadata": {
451 | "id": "k42uNpyxzsCb"
452 | },
453 | "source": [
454 | "net = nn.Sequential(\n",
455 | " nn.Linear(num_inputs, num_hidden),\n",
456 | " snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True),\n",
457 | " nn.Linear(num_hidden, 1),\n",
458 | " snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True, output=True)\n",
459 | ").to(device)"
460 | ],
461 | "execution_count": null,
462 | "outputs": []
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "metadata": {
467 | "id": "hBLjAshUR1SM"
468 | },
469 | "source": [
470 | "## 1.4 High-precision training loop\n",
471 | "Start with high precision weights."
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "metadata": {
477 | "id": "-3fkyBYxU_h8"
478 | },
479 | "source": [
480 | "# optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.9, 0.999)) \n",
481 | "optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)\n",
482 | "num_epochs = 1000\n",
483 | "\n",
484 | "mem_tot = []\n",
485 | "spk_tot = []\n",
486 | "\n",
487 | "for epoch in tqdm(range(num_epochs)):\n",
488 | " mem_rec = []\n",
489 | " spk_rec = []\n",
490 | "\n",
491 | " utils.reset(net)\n",
492 | "\n",
493 | " for step in range(num_steps):\n",
494 | " spk, mem = net(input_data[step])\n",
495 | " mem_rec.append(mem)\n",
496 | " spk_rec.append(spk)\n",
497 | "\n",
498 | " mem_rec = torch.stack(mem_rec)\n",
499 | " mem_tot.append(mem_rec)\n",
500 | "\n",
501 | " spk_rec = torch.stack(spk_rec)\n",
502 | " spk_tot.append(spk_rec)\n",
503 | "\n",
504 | " # loss = loss_fn(targets_spike, mem_rec) + 2*loss_fn(targets_spike[75], mem_rec[75])+ 5e-1*sum(spk_rec) # full trace \n",
505 | " loss = loss_fn(target_mem, mem_rec) # + 2 * loss_fn(targets_spike[75], mem_rec[75]) # + 0*(torch.exp(sum(spk_rec))-1)\n",
506 | "\n",
507 | " # clear previously stored gradients\n",
508 | " optimizer.zero_grad()\n",
509 | "\n",
510 | " # calculate the gradients\n",
511 | " loss.backward()\n",
512 | "\n",
513 | " # weight update\n",
514 | " optimizer.step()\n",
515 | "\n",
516 | "mem_tot = torch.stack(mem_tot)\n",
517 | "spk_tot = torch.stack(spk_tot)"
518 | ],
519 | "execution_count": null,
520 | "outputs": []
521 | },
522 | {
523 | "cell_type": "markdown",
524 | "source": [
525 | "## 1.5 Plot Membrane Potential\n",
526 | "$\\gamma$ refers to the training iteration."
527 | ],
528 | "metadata": {
529 | "id": "aa6ioKwHuG9_"
530 | }
531 | },
532 | {
533 | "cell_type": "code",
534 | "source": [
535 | "plot_quadrant(mem_tot, spk_tot, target_mem, spk_target, -0.1, 1.2, threshold=1, save=\"spk_time_flt.png\", epoch1=1, epoch2=100, epoch3=500, fill=True) # save=\"spk_time_flt.png\""
536 | ],
537 | "metadata": {
538 | "id": "ClAhLranCClW"
539 | },
540 | "execution_count": null,
541 | "outputs": []
542 | },
543 | {
544 | "cell_type": "markdown",
545 | "metadata": {
546 | "id": "sL3ywHTumXY7"
547 | },
548 | "source": [
549 | "## 1.6 Evolution of membrane potential over training epochs"
550 | ]
551 | },
552 | {
553 | "cell_type": "code",
554 | "metadata": {
555 | "id": "OEyHUZD5IcUU"
556 | },
557 | "source": [
558 | "threshold = 1\n",
559 | "fig, ax = plt.subplots()\n",
560 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n",
561 | "x = np.arange(0, 100, 1) \n",
562 | "\n",
563 | "ax.set_xlim(0, num_steps)\n",
564 | "ax.set_ylim(-0.5, 1.5)\n",
565 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n",
566 | " verticalalignment='top', transform=ax.transAxes, size='large')\n",
567 | "\n",
568 | "ax.set_ylabel('Membrane Potential ($u$)')\n",
569 | "ax.set_xlabel('Time Steps')\n",
570 | "ax.plot(target_mem[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n",
571 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
572 | "\n",
573 | "\n",
574 | "def animate(frame_num):\n",
575 | " line.set_data(x, mem_tot[frame_num, x, 0,0].cpu().detach().numpy())\n",
576 | " time_text.set_text(f'Epoch: {frame_num}')\n",
577 | "\n",
578 | " # ax.plot([], [], ' ', label=str(frame_num))\n",
579 | " # ax.legend(loc='upper right')\n",
580 | " return (line, time_text)\n",
581 | "\n",
582 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30)\n",
583 | "anim.save('spk_time_flt.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n",
584 | "\n",
585 | "video = anim.to_html5_video()\n",
586 | "html = display.HTML(video)\n",
587 | "display.display(html)\n",
588 | "plt.close() # avoid plotting a spare static plot"
589 | ],
590 | "execution_count": null,
591 | "outputs": []
592 | },
593 | {
594 | "cell_type": "markdown",
595 | "metadata": {
596 | "id": "_4RqawaA4BbU"
597 | },
598 | "source": [
599 | "# 2. Binarized Spike Timing: Threshold=1\n",
600 |         "The high-precision simulation does a good job of tracking the desired membrane potential. There is some instability when a spike occurs because of the discontinuous reset: each time the neuron is reset, the optimizer tries to offset the sudden drop by increasing the weights.\n",
601 | "\n",
602 | "Now, let's test out binarized spiking neural nets. \n",
603 |         "Before introducing threshold annealing, we will apply a threshold of $\\theta=1$ to all neurons. Since the weights are binarized, each synaptic connection can only ever contribute +1 or -1 to the input current. \n",
604 | "We can expect the outcome to be extremely unstable.\n",
605 | "\n",
606 | "## 2.1 Binarized Functions\n"
607 | ]
608 | },
609 | {
610 | "cell_type": "code",
611 | "metadata": {
612 | "id": "4dZkN83RbDP0"
613 | },
614 | "source": [
615 | "class BinaryLinear(nn.Linear):\n",
616 | " def forward(self, input):\n",
617 | " binary_weight = binarize(self.weight)\n",
618 | " if self.bias is None:\n",
619 | " return F.linear(input, binary_weight)\n",
620 | " else:\n",
621 | " return F.linear(input, binary_weight, self.bias)\n",
622 | "\n",
623 | " def reset_parameters(self):\n",
624 | " # Glorot initialization\n",
625 | " in_features, out_features = self.weight.size()\n",
626 | " stdv = math.sqrt(1.5 / (in_features + out_features))\n",
627 | " self.weight.data.uniform_(-stdv, stdv)\n",
628 | " if self.bias is not None:\n",
629 | " self.bias.data.zero_()\n",
630 | "\n",
631 | " self.weight.lr_scale = 1. / stdv\n",
632 | "\n",
633 | "\n",
634 | "class BinarizeF(Function):\n",
635 | "\n",
636 | " @staticmethod\n",
637 | " def forward(ctx, input):\n",
638 | " output = input.new(input.size())\n",
639 | " output[input >= 0] = 1\n",
640 | " output[input < 0] = -1\n",
641 | " return output\n",
642 | "\n",
643 | " @staticmethod\n",
644 | " def backward(ctx, grad_output):\n",
645 | " grad_input = grad_output.clone()\n",
646 | " return grad_input\n",
647 | "\n",
648 | "# aliases\n",
649 | "binarize = BinarizeF.apply"
650 | ],
651 | "execution_count": null,
652 | "outputs": []
653 | },
654 | {
655 | "cell_type": "markdown",
656 | "metadata": {
657 | "id": "pZ_v2xDcU6Hh"
658 | },
659 | "source": [
660 | "## 2.2 Hyperparameters"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "metadata": {
666 | "id": "Q1UYIj62U3Vm"
667 | },
668 | "source": [
669 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
670 | "dtype = torch.float\n",
671 | "num_steps = 100\n",
672 | "num_inputs = 100\n",
673 | "num_hidden = 1000\n",
674 | "batch_size = 1\n",
675 | "beta=0.15 \n",
676 | "\n",
677 | "loss_fn = nn.MSELoss() "
678 | ],
679 | "execution_count": null,
680 | "outputs": []
681 | },
682 | {
683 | "cell_type": "markdown",
684 | "metadata": {
685 | "id": "WbfDFqYJm-95"
686 | },
687 | "source": [
688 | "## 2.3 Network Definition\n",
689 | "Same architecture will be used all throughout: 100-1000-1 Dense Layers."
690 | ]
691 | },
692 | {
693 | "cell_type": "code",
694 | "metadata": {
695 | "id": "vf1uJDT6dbst"
696 | },
697 | "source": [
698 | "b_net = nn.Sequential(\n",
699 | " BinaryLinear(num_inputs, num_hidden),\n",
700 | " snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True),\n",
701 | " BinaryLinear(num_hidden, 1),\n",
702 | " snn.Leaky(beta=beta, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True, output=True)\n",
703 | ").to(device)"
704 | ],
705 | "execution_count": null,
706 | "outputs": []
707 | },
708 | {
709 | "cell_type": "markdown",
710 | "metadata": {
711 | "id": "bZHfDW88nAx5"
712 | },
713 | "source": [
714 | "## 2.4 Binarized Training Loop"
715 | ]
716 | },
717 | {
718 | "cell_type": "code",
719 | "metadata": {
720 | "id": "jlUf-mt28z0h"
721 | },
722 | "source": [
723 | "optimizer = torch.optim.SGD(b_net.parameters(), lr=1e-3, momentum=0.9)\n",
724 | "num_epochs = 1000\n",
725 | "mem_tot_bin = []\n",
726 | "spk_tot_bin = []\n",
727 | "\n",
728 | "for epoch in tqdm(range(num_epochs)):\n",
729 | " mem_rec = []\n",
730 | " spk_rec = []\n",
731 | "\n",
732 |     "    utils.reset(b_net)  # reset hidden states of the binarized network\n",
733 | "\n",
734 | " for step in range(num_steps):\n",
735 | " spk, mem = b_net(input_data[step])\n",
736 | " mem_rec.append(mem)\n",
737 | " spk_rec.append(spk)\n",
738 | "\n",
739 | " spk_rec = torch.stack(spk_rec)\n",
740 | " mem_rec = torch.stack(mem_rec)\n",
741 | " mem_tot_bin.append(mem_rec)\n",
742 | " spk_tot_bin.append(spk_rec)\n",
743 | "\n",
744 | " loss = loss_fn(target_mem, mem_rec)\n",
745 | "\n",
746 | " # clear previously stored gradients\n",
747 | " optimizer.zero_grad()\n",
748 | "\n",
749 | " # calculate the gradients\n",
750 | " loss.backward()\n",
751 | "\n",
752 | " # weight update\n",
753 | " optimizer.step()\n",
754 | "\n",
755 | "mem_tot_bin = torch.stack(mem_tot_bin)\n",
756 | "spk_tot_bin = torch.stack(spk_tot_bin)"
757 | ],
758 | "execution_count": null,
759 | "outputs": []
760 | },
761 | {
762 | "cell_type": "markdown",
763 | "source": [
764 | "## 2.5 Plot Membrane Potential"
765 | ],
766 | "metadata": {
767 | "id": "iUHLpcIdu4wJ"
768 | }
769 | },
770 | {
771 | "cell_type": "code",
772 | "source": [
773 | "plot_quadrant(mem_tot_bin, spk_tot_bin, target_mem, spk_target, -0.1, 1.2, threshold=1, save='spk_time_bin.png', epoch1=0, epoch2=75, epoch3=750, fill=True) # save=\"spk_time_flt.png\""
774 | ],
775 | "metadata": {
776 | "id": "PYewBqdZgWa0"
777 | },
778 | "execution_count": null,
779 | "outputs": []
780 | },
781 | {
782 | "cell_type": "markdown",
783 | "source": [
784 | "As expected, this doesn't look great. \n",
785 | "This somewhat resembles the pathological case described in section 2 of the paper, where BSNNs struggle to incorporate both memory dynamics and spike propagation. I.e., no smooth memory dynamics are visible above."
786 | ],
787 | "metadata": {
788 | "id": "7IB8TGFWu68L"
789 | }
790 | },
791 | {
792 | "cell_type": "markdown",
793 | "metadata": {
794 | "id": "N8orKC1dntVS"
795 | },
796 | "source": [
797 | "## 2.6 Evolution of membrane trace over training epochs"
798 | ]
799 | },
800 | {
801 | "cell_type": "code",
802 | "source": [
803 | "threshold = 1\n",
804 | "fig, ax = plt.subplots()\n",
805 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n",
806 | "x = np.arange(0, 100, 1) \n",
807 | "\n",
808 | "ax.set_xlim(0, num_steps)\n",
809 | "ax.set_ylim(-0.5, 1.5)\n",
810 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n",
811 | " verticalalignment='top', transform=ax.transAxes, size='large')\n",
812 | "\n",
813 | "ax.set_ylabel('Membrane Potential ($u$)')\n",
814 | "ax.set_xlabel('Time Steps')\n",
815 | "ax.plot(target_mem[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n",
816 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
817 | "\n",
818 | "\n",
819 | "def animate(frame_num):\n",
820 | " line.set_data(x, mem_tot_bin[frame_num, x, 0,0].cpu().detach().numpy())\n",
821 | " time_text.set_text(f'Epoch: {frame_num}')\n",
822 | " return (line, time_text)\n",
823 | "\n",
824 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30)\n",
825 | "anim.save('spk_time_bin.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n",
826 | "\n",
827 | "video = anim.to_html5_video()\n",
828 | "html = display.HTML(video)\n",
829 | "display.display(html)\n",
830 | "plt.close() # avoid plotting a spare static plot"
831 | ],
832 | "metadata": {
833 | "id": "5OWG-Vzqhonn"
834 | },
835 | "execution_count": null,
836 | "outputs": []
837 | },
838 | {
839 | "cell_type": "markdown",
840 | "metadata": {
841 | "id": "G8IpqgZvoeYT"
842 | },
843 | "source": [
844 | "## 2.7 Moving Average\n",
845 | "Perhaps we will see better results if we take the moving average of the membrane potential (over epochs)."
846 | ]
847 | },
848 | {
849 | "cell_type": "code",
850 | "source": [
851 | "threshold = 1\n",
852 | "fig, ax = plt.subplots()\n",
853 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n",
854 | "x = np.arange(0, 100, 1)\n",
855 | "\n",
856 | "N = 5 # size of filter\n",
857 | "mem_avg_bin = uniform_filter1d(mem_tot_bin.cpu().detach(), size=N, axis=1)\n",
858 | "\n",
859 | "ax.set_xlim(0, num_steps)\n",
860 | "ax.set_ylim(-0.5, 1.5)\n",
861 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n",
862 | " verticalalignment='top', transform=ax.transAxes, size='large')\n",
863 | "\n",
864 | "ax.set_ylabel('Membrane Potential ($u$)')\n",
865 | "ax.set_xlabel('Time Steps')\n",
866 | "ax.plot(target_mem[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n",
867 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
868 | "\n",
869 | "\n",
870 | "def animate(frame_num):\n",
871 | " line.set_data(x, mem_avg_bin[frame_num, x, 0,0])\n",
872 | " time_text.set_text(f'Epoch: {frame_num}')\n",
873 | "\n",
874 | " # ax.plot([], [], ' ', label=str(frame_num))\n",
875 | " # ax.legend(loc='upper right')\n",
876 | " return (line, time_text)\n",
877 | "\n",
878 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30)\n",
879 | "anim.save('spk_time_bin_MVA.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n",
880 | "\n",
881 | "video = anim.to_html5_video()\n",
882 | "html = display.HTML(video)\n",
883 | "display.display(html)\n",
884 | "plt.close() # avoid plotting a spare static plot"
885 | ],
886 | "metadata": {
887 | "id": "vMwX-33biqey"
888 | },
889 | "execution_count": null,
890 | "outputs": []
891 | },
892 | {
893 | "cell_type": "markdown",
894 | "source": [
895 | "Perhaps not."
896 | ],
897 | "metadata": {
898 | "id": "zGS39QO9vf5L"
899 | }
900 | },
901 | {
902 | "cell_type": "markdown",
903 | "metadata": {
904 | "id": "5lakzDLwG8K5"
905 | },
906 | "source": [
907 | "# 3. Wide threshold BNN\n",
908 |         "If we use a large threshold, then each spiking neuron has a wider dynamic range in its state space, and this could enable more precise tuning. \n",
909 |         "\n",
910 |         "The problem we will run into is that, if the threshold is too high, downstream spikes probably won't occur, and so learning will also fail to take place. Let's set the threshold of all neurons to $\\theta=50$. This is a significant jump from $\\theta=1$. "
911 | ]
912 | },
913 | {
914 | "cell_type": "markdown",
915 | "metadata": {
916 | "id": "EqAGq-h1pIe5"
917 | },
918 | "source": [
919 | "## 3.1 Hyperparameters"
920 | ]
921 | },
922 | {
923 | "cell_type": "code",
924 | "metadata": {
925 | "id": "eB_7CDHhWqb_"
926 | },
927 | "source": [
928 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
929 | "dtype = torch.float\n",
930 | "num_steps = 100\n",
931 | "num_inputs = 100\n",
932 | "num_hidden = 1000\n",
933 | "batch_size = 1\n",
934 | "beta=0.15\n",
935 | "w_thr1 = 50\n",
936 | "w_thr2 = 50 # 25 works well\n",
937 | "on_target = w_thr2 + w_thr2*0.1\n",
938 | "first_spike_time = 75\n",
939 | "\n",
940 | "loss_fn = nn.MSELoss() \n",
941 | "# loss_fn = nn.CrossEntropyLoss()"
942 | ],
943 | "execution_count": null,
944 | "outputs": []
945 | },
946 | {
947 | "cell_type": "markdown",
948 | "metadata": {
949 | "id": "VUkd2Y6MpAg2"
950 | },
951 | "source": [
952 | "## 3.2 Define target"
953 | ]
954 | },
955 | {
956 | "cell_type": "code",
957 | "metadata": {
958 | "id": "YT-ZdMAJbr6K"
959 | },
960 | "source": [
961 | "targets_wthr = spikegen.targets_latency(torch.zeros(1, dtype=dtype, device=device), num_classes=1, first_spike_time=first_spike_time, on_target=on_target, num_steps=num_steps, interpolate=True)"
962 | ],
963 | "execution_count": null,
964 | "outputs": []
965 | },
966 | {
967 | "cell_type": "markdown",
968 | "metadata": {
969 | "id": "3aTkUWNBpLSs"
970 | },
971 | "source": [
972 | "## 3.3 Define network"
973 | ]
974 | },
975 | {
976 | "cell_type": "code",
977 | "metadata": {
978 | "id": "suVpw-rm9DKa"
979 | },
980 | "source": [
981 | "wthr_net = nn.Sequential(\n",
982 | " BinaryLinear(num_inputs, num_hidden),\n",
983 | " snn.Leaky(beta=beta, threshold=w_thr1, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True),\n",
984 | " BinaryLinear(num_hidden, 1),\n",
985 | " snn.Leaky(beta=beta, threshold=w_thr2, spike_grad=surrogate.fast_sigmoid(slope=5), init_hidden=True, output=True)\n",
986 | ").to(device)"
987 | ],
988 | "execution_count": null,
989 | "outputs": []
990 | },
991 | {
992 | "cell_type": "markdown",
993 | "metadata": {
994 | "id": "d3UEhuiqpM4T"
995 | },
996 | "source": [
997 | "## 3.4 Training Loop"
998 | ]
999 | },
1000 | {
1001 | "cell_type": "code",
1002 | "metadata": {
1003 | "id": "k6LqaEDKHElP"
1004 | },
1005 | "source": [
1006 | "optimizer = torch.optim.SGD(wthr_net.parameters(), lr=1e-3, momentum=0.9)\n",
1007 | "num_epochs = 1000\n",
1008 | "mem_tot_wthr = []\n",
1009 | "spk_tot_wthr = []\n",
1010 | "\n",
1011 | "for epoch in tqdm(range(num_epochs)):\n",
1012 | " mem_rec = []\n",
1013 | " spk_rec = []\n",
1014 | "\n",
1015 |         "    utils.reset(wthr_net)  # reset hidden states of the wide-threshold network\n",
1016 | "\n",
1017 | " for step in range(num_steps):\n",
1018 | " spk, mem = wthr_net(input_data[step])\n",
1019 | " mem_rec.append(mem)\n",
1020 | " spk_rec.append(spk)\n",
1021 | "\n",
1022 | " spk_rec = torch.stack(spk_rec)\n",
1023 | " mem_rec = torch.stack(mem_rec)\n",
1024 | " spk_tot_wthr.append(spk_rec)\n",
1025 | " mem_tot_wthr.append(mem_rec)\n",
1026 | "\n",
1027 | " loss = loss_fn(targets_wthr, mem_rec)\n",
1028 | "\n",
1029 | " # clear previously stored gradients\n",
1030 | " optimizer.zero_grad()\n",
1031 | "\n",
1032 | " # calculate the gradients\n",
1033 | " loss.backward()\n",
1034 | "\n",
1035 | " # weight update\n",
1036 | " optimizer.step()\n",
1037 | "\n",
1038 | "mem_tot_wthr = torch.stack(mem_tot_wthr)\n",
1039 | "spk_tot_wthr = torch.stack(spk_tot_wthr)"
1040 | ],
1041 | "execution_count": null,
1042 | "outputs": []
1043 | },
1044 | {
1045 | "cell_type": "markdown",
1046 | "source": [
1047 | "## 3.5 Plot Membrane Potential"
1048 | ],
1049 | "metadata": {
1050 | "id": "k9aHG9Olv1rT"
1051 | }
1052 | },
1053 | {
1054 | "cell_type": "code",
1055 | "source": [
1056 | "plot_quadrant(mem_tot_wthr, spk_tot_wthr, targets_wthr, spk_target, -1, 60, threshold=50, save=\"spk_time_wthr.png\", epoch1=0, epoch2=75, epoch3=750, fill=True) # save=\"spk_time_flt.png\""
1057 | ],
1058 | "metadata": {
1059 | "id": "lp4uJVMnkpDH"
1060 | },
1061 | "execution_count": null,
1062 | "outputs": []
1063 | },
1064 | {
1065 | "cell_type": "markdown",
1066 | "metadata": {
1067 | "id": "KGttbe8epcvp"
1068 | },
1069 | "source": [
1070 | "## 3.6 Animation of membrane potential\n",
1071 | "\n",
1072 |         "This result doesn't fluctuate, but neither does it produce the desired behavior of spiking at the 75th time step; in fact, no spikes are produced at all. \n",
1073 |         "\n",
1074 |         "The membrane potential staying constant over time indicates that the output neuron does not receive any spikes from the previous layer; rather, it is the bias that drives the second layer. The bias slowly increases until it reaches roughly the mid-point of the threshold, which minimizes the overall loss over time.\n",
1075 |         "\n",
1076 |         "If the bias were removed, the membrane potential would be stuck at zero. So clearly, this doesn't quite work either.\n",
1077 | "\n",
1078 | "Note that the membrane potential falls just short of 25. This can be explained by the final steps of the target being set to 0, which suppresses the overall steady-state response."
1079 | ]
1080 | },
1081 | {
1082 | "cell_type": "code",
1083 | "source": [
1084 | "threshold = 50\n",
1085 | "fig, ax = plt.subplots()\n",
1086 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n",
1087 | "x = np.arange(0, 100, 1) \n",
1088 | "\n",
1089 | "ax.set_xlim(0, num_steps)\n",
1090 | "ax.set_ylim(-1, 60)\n",
1091 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n",
1092 | " verticalalignment='top', transform=ax.transAxes, size='large')\n",
1093 | "\n",
1094 | "ax.set_ylabel('Membrane Potential ($u$)')\n",
1095 | "ax.set_xlabel('Time Steps')\n",
1096 | "ax.plot(targets_wthr[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n",
1097 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
1098 | "\n",
1099 | "\n",
1100 | "def animate(frame_num):\n",
1101 | " line.set_data(x, mem_tot_wthr[frame_num, x, 0,0].cpu().detach().numpy())\n",
1102 | " time_text.set_text(f'Epoch: {frame_num}')\n",
1103 | "\n",
1104 | " # ax.plot([], [], ' ', label=str(frame_num))\n",
1105 | " # ax.legend(loc='upper right')\n",
1106 | " return (line, time_text)\n",
1107 | "\n",
1108 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30)\n",
1109 | "anim.save('spk_time_wthr.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n",
1110 | "\n",
1111 | "video = anim.to_html5_video()\n",
1112 | "html = display.HTML(video)\n",
1113 | "display.display(html)\n",
1114 | "plt.close() # avoid plotting a spare static plot"
1115 | ],
1116 | "metadata": {
1117 | "id": "lNN-KZ_Yk620"
1118 | },
1119 | "execution_count": null,
1120 | "outputs": []
1121 | },
1122 | {
1123 | "cell_type": "markdown",
1124 | "metadata": {
1125 | "id": "KoGgy9igi5WE"
1126 | },
1127 | "source": [
1128 | "# 4. Bounded Homeostasis\n",
1129 | "## 4.1 Define Threshold Annealing Function\n",
1130 | "\n",
1131 |         "If we slowly anneal the threshold from a small value to a larger one, strong spiking activity is encouraged in the early epochs, which avoids the dead neuron problem we saw in the previous case where $\\theta=50$.\n",
1132 |         "\n",
1133 |         "Below, we implement the most naive form of bounded homeostasis (i.e., one that does not depend on the weight-update gradient as in the other experiments, and which can simply be referred to as `threshold annealing'): an exponential relaxation of the threshold toward a steady state, completely independent of the input data. The same threshold is applied to all neurons in all layers."
1134 | ]
1135 | },
1136 | {
1137 | "cell_type": "code",
1138 | "metadata": {
1139 | "id": "bfrH7cqSjEs3"
1140 | },
1141 | "source": [
1142 | "def thr_annealing(conf, network):\n",
1143 | " alpha_thr1 = conf['alpha_thr1']\n",
1144 | " alpha_thr2 = conf['alpha_thr2']\n",
1145 | "\n",
1146 | " thr_final1 = conf['thr_final1']\n",
1147 | " thr_final2 = conf['thr_final2']\n",
1148 | "\n",
1149 | " with torch.no_grad():\n",
1150 | " network.lif1.threshold += (thr_final1 - network.lif1.threshold) * alpha_thr1\n",
1151 | " network.lif2.threshold += (thr_final2 - network.lif2.threshold) * alpha_thr2"
1152 | ],
1153 | "execution_count": null,
1154 | "outputs": []
1155 | },
1156 | {
1157 | "cell_type": "markdown",
1158 | "metadata": {
1159 | "id": "k7KBgUEFqQEb"
1160 | },
1161 | "source": [
1162 | "## 4.2 Define Hyperparameters\n",
1163 | "As before, we set the final threshold to 50. But let's start with 5.0, and gradually warm it up to 50. `alpha_thr1` and `alpha_thr2` are the inverse time constants of the threshold evolution."
1164 | ]
1165 | },
1166 | {
1167 | "cell_type": "code",
1168 | "metadata": {
1169 | "id": "pV-iBTtCjTAV"
1170 | },
1171 | "source": [
1172 | "config = {\n",
1173 | " \n",
1174 | " 'thr_init1' : 5.0,\n",
1175 | " 'thr_init2' : 5.0,\n",
1176 | "\n",
1177 | " 'alpha_thr1' : 5e-3,\n",
1178 | " 'alpha_thr2' : 5e-3,\n",
1179 | "\n",
1180 | " 'thr_final1' : 50.0,\n",
1181 | " 'thr_final2' : 50.0,\n",
1182 | "}"
1183 | ],
1184 | "execution_count": null,
1185 | "outputs": []
1186 | },
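1187 | {
1188 | "cell_type": "markdown",
1189 | "metadata": {},
1190 | "source": [
1191 | "Each call to `thr_annealing` applies $\\theta \\leftarrow \\theta + \\alpha(\\theta_{final} - \\theta)$ to both layers, so after $e$ epochs the threshold follows the closed form $\\theta_e = \\theta_{final} - (\\theta_{final} - \\theta_{init})(1-\\alpha)^e$. The next cell is a minimal, illustrative sketch of that schedule using the values in `config`; the helper names `epochs_axis` and `thr_schedule` are only used here, and the cell is not required for training."
1192 | ]
1193 | },
1194 | {
1195 | "cell_type": "code",
1196 | "metadata": {},
1197 | "source": [
1198 | "# Illustrative sanity check (not used for training): closed-form schedule implied by thr_annealing,\n",
1199 | "# theta_e = thr_final - (thr_final - thr_init) * (1 - alpha)**e\n",
1200 | "epochs_axis = np.arange(1000)  # 1000 epochs, matching the training loop below\n",
1201 | "thr_schedule = config['thr_final1'] - (config['thr_final1'] - config['thr_init1']) * (1 - config['alpha_thr1'])**epochs_axis\n",
1202 | "\n",
1203 | "for e in [0, 100, 400, 999]:\n",
1204 | "    print(f\"epoch {e:4d}: threshold = {thr_schedule[e]:.2f}\")\n",
1205 | "\n",
1206 | "# epochs needed for the threshold to cover 90% of the gap between thr_init and thr_final\n",
1207 | "print(f\"90% warm-up epoch: {np.log(0.1) / np.log(1 - config['alpha_thr1']):.0f}\")"
1208 | ],
1209 | "execution_count": null,
1210 | "outputs": []
1211 | },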
1187 | {
1188 | "cell_type": "code",
1189 | "metadata": {
1190 | "id": "99WfOLJWXsop"
1191 | },
1192 | "source": [
1193 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
1194 | "dtype = torch.float\n",
1195 | "num_steps = 100\n",
1196 | "num_inputs = 100\n",
1197 | "num_hidden = 1000\n",
1198 | "batch_size = 1\n",
1199 | "beta=0.15\n",
1200 | "on_target = config['thr_final2'] + config['thr_final2']*0.1\n",
1201 | "\n",
1202 | "loss_fn = nn.MSELoss() \n",
1203 | "# loss_fn = nn.CrossEntropyLoss()"
1204 | ],
1205 | "execution_count": null,
1206 | "outputs": []
1207 | },
1208 | {
1209 | "cell_type": "markdown",
1210 | "metadata": {
1211 | "id": "0hCzxrlbqmS5"
1212 | },
1213 | "source": [
1214 | "## 4.3 Define Target"
1215 | ]
1216 | },
1217 | {
1218 | "cell_type": "code",
1219 | "metadata": {
1220 | "id": "g8T0SawYc3vI"
1221 | },
1222 | "source": [
1223 | "targets_tha = spikegen.targets_latency(torch.zeros(1, dtype=dtype, device=device), num_classes=1, first_spike_time=75, on_target=on_target, num_steps=num_steps, interpolate=True)"
1224 | ],
1225 | "execution_count": null,
1226 | "outputs": []
1227 | },
1228 | {
1229 | "cell_type": "markdown",
1230 | "metadata": {
1231 | "id": "5CBSUwJxYjaU"
1232 | },
1233 | "source": [
1234 | "## 4.4 Define network"
1235 | ]
1236 | },
1237 | {
1238 | "cell_type": "code",
1239 | "metadata": {
1240 | "id": "upUaiEzci-Dk"
1241 | },
1242 | "source": [
1243 | "class Net(nn.Module):\n",
1244 | " def __init__(self):\n",
1245 | " super().__init__()\n",
1246 | "\n",
1247 | " beta = 0.15\n",
1248 | " spike_grad = surrogate.fast_sigmoid(slope=5)\n",
1249 | "\n",
1250 | " self.fc1 = BinaryLinear(num_inputs, num_hidden)\n",
1251 | " self.fc2 = BinaryLinear(num_hidden, 1)\n",
1252 | "\n",
1253 | " self.lif1 = snn.Leaky(beta=beta, threshold=config['thr_init1'], spike_grad = spike_grad)\n",
1254 | " self.lif2 = snn.Leaky(beta=beta, threshold=config['thr_init2'], spike_grad=spike_grad)\n",
1255 | "\n",
1256 | " def forward(self, x):\n",
1257 | " mem1 = self.lif1.init_leaky() \n",
1258 | " mem2 = self.lif2.init_leaky() \n",
1259 | "\n",
1260 | " spk2_rec = []\n",
1261 | " mem2_rec = []\n",
1262 | "\n",
1263 | " for step in range(x.size(0)):\n",
1264 | " cur1 = self.fc1(x[step])\n",
1265 | " spk1, mem1 = self.lif1(cur1, mem1)\n",
1266 | " cur2 = self.fc2(spk1)\n",
1267 | " spk2, mem2 = self.lif2(cur2, mem2)\n",
1268 | "\n",
1269 | " spk2_rec.append(spk2)\n",
1270 | " mem2_rec.append(mem2)\n",
1271 | " \n",
1272 | " return torch.stack(spk2_rec), torch.stack(mem2_rec)\n",
1273 | "\n",
1274 | "net_tha = Net().to(device)"
1275 | ],
1276 | "execution_count": null,
1277 | "outputs": []
1278 | },
1279 | {
1280 | "cell_type": "markdown",
1281 | "metadata": {
1282 | "id": "iqMrSdqrq0mW"
1283 | },
1284 | "source": [
1285 | "## 4.5 Training Loop"
1286 | ]
1287 | },
1288 | {
1289 | "cell_type": "code",
1290 | "metadata": {
1291 | "id": "AS-4Wn10jA6d"
1292 | },
1293 | "source": [
1294 | "optimizer = torch.optim.SGD(net_tha.parameters(), lr=1e-3, momentum=0.9)\n",
1295 | "num_epochs = 1000\n",
1296 | "mem_tot_tha = []\n",
1297 | "spk_tot_tha = []\n",
1298 | "thr_L1 = []\n",
1299 | "thr_L2 = []\n",
1300 | "\n",
1301 | "for epoch in tqdm(range(num_epochs)):\n",
1302 | "\n",
1303 | " spk_rec, mem_rec = net_tha(input_data)\n",
1304 | " spk_tot_tha.append(spk_rec)\n",
1305 | " mem_tot_tha.append(mem_rec)\n",
1306 | " loss = loss_fn(targets_tha, mem_rec)\n",
1307 | "\n",
1308 | " # clear previously stored gradients\n",
1309 | " optimizer.zero_grad()\n",
1310 | "\n",
1311 | " # calculate the gradients\n",
1312 | " loss.backward()\n",
1313 | "\n",
1314 | " # weight update\n",
1315 | " optimizer.step()\n",
1316 | "\n",
1317 | " thr_L1.append(net_tha.lif1.threshold.item())\n",
1318 | " thr_L2.append(net_tha.lif2.threshold.item())\n",
1319 | "\n",
1320 | " thr_annealing(config, net_tha)\n",
1321 | " \n",
1322 | "\n",
1323 | "mem_tot_tha = torch.stack(mem_tot_tha)\n",
1324 | "spk_tot_tha = torch.stack(spk_tot_tha)"
1325 | ],
1326 | "execution_count": null,
1327 | "outputs": []
1328 | },
1329 | {
1330 | "cell_type": "markdown",
1331 | "metadata": {
1332 | "id": "s8L8DJ7aq4Ov"
1333 | },
1334 | "source": [
1335 | "## 4.6 Plot Membrane Potential"
1336 | ]
1337 | },
1338 | {
1339 | "cell_type": "code",
1340 | "source": [
1341 | "plot_quadrant_tha(mem_tot_tha, spk_tot_tha, targets_tha, spk_target, -1, 60, threshold=thr_L1, save=\"spk_time_tha.png\", epoch1=0, epoch2=100, epoch3=400, fill=True) # save=\"spk_time_flt.png\""
1342 | ],
1343 | "metadata": {
1344 | "id": "n166WYwtnBIL"
1345 | },
1346 | "execution_count": null,
1347 | "outputs": []
1348 | },
1349 | {
1350 | "cell_type": "code",
1351 | "source": [
1352 | "thr_L1[400]"
1353 | ],
1354 | "metadata": {
1355 | "id": "7YQULZmjhbhc"
1356 | },
1357 | "execution_count": null,
1358 | "outputs": []
1359 | },
1360 | {
1361 | "cell_type": "markdown",
1362 | "source": [
1363 | "This is looking quite nice as training progresses! Let's see the animated version to get better insight."
1364 | ],
1365 | "metadata": {
1366 | "id": "x0Hr_q2Pxipv"
1367 | }
1368 | },
1369 | {
1370 | "cell_type": "markdown",
1371 | "source": [
1372 | "## 4.7 Animation of Membrane Potential"
1373 | ],
1374 | "metadata": {
1375 | "id": "4bAvxJ92xe7-"
1376 | }
1377 | },
1378 | {
1379 | "cell_type": "code",
1380 | "source": [
1381 | "threshold = 50\n",
1382 | "fig, ax = plt.subplots()\n",
1383 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n",
1384 | "thr_line, = ax.plot([])\n",
1385 | "thr_text1 = ax.text(0.98, 0.91,'',horizontalalignment='right',verticalalignment='top', transform=ax.transAxes, size='large')\n",
1386 | "x = np.arange(0, 100, 1) \n",
1387 | "\n",
1388 | "ax.set_xlim(0, num_steps)\n",
1389 | "ax.set_ylim(-1, 60)\n",
1390 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n",
1391 | " verticalalignment='top', transform=ax.transAxes, size='large')\n",
1392 | "\n",
1393 | "ax.set_ylabel('Membrane Potential ($u$)')\n",
1394 | "ax.set_xlabel('Time Steps')\n",
1395 | "ax.plot(targets_tha[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n",
1396 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
1397 | "\n",
1398 | "\n",
1399 | "def animate(frame_num):\n",
1400 | " line.set_data(x, mem_tot_tha[frame_num, x, 0,0].cpu().detach().numpy())\n",
1401 | " thr_line.set_data(x, thr_L1[frame_num])\n",
1402 | " thr_text1.set_text(f'Threshold: {thr_L1[frame_num]:.3f}')\n",
1403 | " time_text.set_text(f'Epoch: {frame_num}')\n",
1404 | "\n",
1405 | " # ax.plot([], [], ' ', label=str(frame_num))\n",
1406 | " # ax.legend(loc='upper right')\n",
1407 | " return (line, time_text, thr_text1)\n",
1408 | "\n",
1409 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30) # num_epochs\n",
1410 | "anim.save('spk_time_tha.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n",
1411 | "\n",
1412 | "video = anim.to_html5_video()\n",
1413 | "html = display.HTML(video)\n",
1414 | "display.display(html)\n",
1415 | "plt.close() # avoid plotting a spare static plot"
1416 | ],
1417 | "metadata": {
1418 | "id": "xTqE4_2TnV8c"
1419 | },
1420 | "execution_count": null,
1421 | "outputs": []
1422 | },
1423 | {
1424 | "cell_type": "markdown",
1425 | "source": [
1426 | "To begin with, the several spikes trigger a sudden explosion in activity as the neuron tries to climb its way to $u=50$. Sensory overload. \n",
1427 | "\n",
1428 | "But as the threshold warms up further, activity becomes sparser until finally, the neuron actually hits the desired firing time at several epochs."
1429 | ],
1430 | "metadata": {
1431 | "id": "tIZe7dLiyfwW"
1432 | }
1433 | },
1434 | {
1435 | "cell_type": "markdown",
1436 | "source": [
1437 | "## 4.8 Moving Average of Membrane"
1438 | ],
1439 | "metadata": {
1440 | "id": "PPckpEvxrn8R"
1441 | }
1442 | },
1443 | {
1444 | "cell_type": "code",
1445 | "source": [
1446 | "threshold = 50\n",
1447 | "fig, ax = plt.subplots()\n",
1448 | "line, = ax.plot([]) # A tuple unpacking to unpack the only plot\n",
1449 | "thr_line, = ax.plot([])\n",
1450 | "thr_text1 = ax.text(0.98, 0.91,'',horizontalalignment='right',verticalalignment='top', transform=ax.transAxes, size='large')\n",
1451 | "x = np.arange(0, 100, 1) \n",
1452 | "\n",
1453 | "\n",
1454 | "N = 5 # size of filter\n",
1455 | "mem_avg_tha = uniform_filter1d(mem_tot_tha.cpu().detach(), size=N, axis=1)\n",
1456 | "\n",
1457 | "ax.set_xlim(0, num_steps)\n",
1458 | "ax.set_ylim(-1, 60)\n",
1459 | "time_text = ax.text(0.02, 0.98,'',horizontalalignment='left',\n",
1460 | " verticalalignment='top', transform=ax.transAxes, size='large')\n",
1461 | "\n",
1462 | "ax.set_ylabel('Membrane Potential ($u$)')\n",
1463 | "ax.set_xlabel('Time Steps')\n",
1464 | "ax.plot(targets_tha[:, 0, 0].cpu().detach().numpy(), label='Target', linestyle='dashed')\n",
1465 | "ax.axhline(y=threshold, alpha=0.25, linestyle=\"dashed\", c=\"black\", linewidth=2)\n",
1466 | "\n",
1467 | "\n",
1468 | "def animate(frame_num):\n",
1469 | " line.set_data(x, mem_avg_tha[frame_num, x, 0,0])\n",
1470 | " thr_line.set_data(x, thr_L1[frame_num])\n",
1471 | " thr_text1.set_text(f'Threshold: {thr_L1[frame_num]:.3f}')\n",
1472 | " time_text.set_text(f'Epoch: {frame_num}')\n",
1473 | "\n",
1474 | " # ax.plot([], [], ' ', label=str(frame_num))\n",
1475 | " # ax.legend(loc='upper right')\n",
1476 | " return (line, time_text, thr_text1)\n",
1477 | "\n",
1478 | "anim = FuncAnimation(fig, animate, frames=num_epochs, interval=30) # num_epochs\n",
1479 | "anim.save('spk_time_tha_MVA.mp4', fps=25, extra_args=['-vcodec', 'libx264'], dpi=300)\n",
1480 | "\n",
1481 | "video = anim.to_html5_video()\n",
1482 | "html = display.HTML(video)\n",
1483 | "display.display(html)\n",
1484 | "plt.close() # avoid plotting a spare static plot"
1485 | ],
1486 | "metadata": {
1487 | "id": "je3u03vKrncP"
1488 | },
1489 | "execution_count": null,
1490 | "outputs": []
1491 | },
1492 | {
1493 | "cell_type": "markdown",
1494 | "metadata": {
1495 | "id": "3H4v8Augr4wO"
1496 | },
1497 | "source": [
1498 | "Not only do we see learning taking place, but the values chosen are completely arbitary. When writing this notebook, this was the first result we obtained. It is likely something more precise could be obtained by choosing layer-independent thresholds & annealing rates."
1499 | ]
1500 | }
1501 | ]
1502 | }
--------------------------------------------------------------------------------