Illustrations of the key concepts of the paper: periodic scheduling can enable SNNs to escape flat surfaces and local minima. When the LR is boosted during training using a cyclic scheduler, the network is given another chance to reduce the loss from a different set of initial conditions. Although the loss appears to have converged, each subsequent LR boost lets training traverse the loss landscape again and reach more optimal solutions.
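The scheduling behaviour can be reproduced with a few lines of PyTorch. The snippet below is a minimal sketch only (the per-dataset `run.py` scripts and the quickstart notebook configure their own optimizers and schedulers): a cosine schedule with warm restarts decays the LR over `T_0` iterations and then boosts it back to its initial value.

```
import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for net.parameters()
optimizer = torch.optim.Adam(params, lr=3.0e-3)

# SGDR-style warm restarts: the LR decays over T_0 steps, then jumps back up
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=4690, eta_min=0
)

for step in range(3 * 4690):  # three full LR cycles
    # forward pass, loss.backward(), and optimizer.step() would go here
    scheduler.step()  # advance the cyclic schedule by one minibatch
```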
If you find this code useful in your work, please cite the following source:

```
@article{eshraghian2022navigating,
  title={{Navigating Local Minima in Quantized Spiking Neural Networks}},
  author={Eshraghian, Jason K and Lammie, Corey and Rahimi Azghadi, Mostafa and Lu, Wei D},
  year={2022},
  eprint={2202.07221},
  archivePrefix={arXiv},
}
```

## Jupyter Notebook
We provide a Jupyter notebook [here](https://github.com/jeshraghian/QSNNs/blob/main/quickstart.ipynb), which documents our scripts and methodologies. It can be run in a Google Colaboratory environment without any prerequisites [here](https://colab.research.google.com/github/jeshraghian/QSNNs/blob/main/quickstart.ipynb).

## Code Execution of Standalone Scripts
For more advanced users, i.e., those proficient with Python, we provide executable code in the form of Python scripts. Simulations can be run by configuring and executing `run.py` in each respective dataset directory.

## Requirements
### Jupyter Notebook
To run the Jupyter notebook, Google Colab can be used. Otherwise, a working `Python` (≥3.6) interpreter and the `pip` package manager are required.

### Standalone Scripts
To run all standalone scripts, a working `Python` (≥3.6) interpreter and the `pip` package manager are required. All required libraries and packages can be installed using `pip install -r requirements.txt`. To avoid potential package conflicts, the use of a `conda` environment is recommended. The following commands can be used to create and activate a separate `conda` environment, clone this repository, and install all dependencies:

```
conda create -n QSNNs python=3.8
conda activate QSNNs
git clone https://github.com/jeshraghian/QSNNs.git
cd QSNNs
pip install -r requirements.txt
```

## Hyperparameter Tuning
* In each dataset directory, the `config` dictionary within `run.py` defines all configuration options and hyperparameters for that dataset (a minimal tuning sketch is given below this section).
* The default parameters in this repo are identical to those used for the 4-bit (Q4) cosine annealing learning rate schedule configurations reported in the corresponding paper.

## Interpreting and Plotting Results
* Results can be gathered and plotted using `extract_test_set_accuracy.py` and `plot_results.py`, respectively.
* `plot_results.py` can be reconfigured to plot different quantities.
* By default, `plot_results.py` plots the loss curve evolution during training for all three datasets.
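As a concrete (hypothetical) example of hyperparameter tuning, the keys below mirror the `config` dictionary used in the quickstart notebook; the per-dataset `run.py` dictionaries may name some fields differently:

```
# Tuning sketch: keys mirror the quickstart notebook's config; run.py may differ
config = {
    "num_epochs": 15,    # epochs per trial
    "batch_size": 128,
    "num_bits": 4,       # weight bit width (Q4)
    "lr": 3.0e-3,        # initial learning rate
    "t_0": 4690,         # T_max passed to the cosine annealing scheduler (iterations)
    "beta": 0.97,        # membrane potential decay rate
    "threshold": 2.5,    # firing threshold
    "slope": 5.6,        # surrogate gradient slope
}

# Example overrides before launching a run:
config["num_bits"] = 8   # e.g., repeat the experiment at 8-bit weight precision
config["lr"] = 1.0e-3    # with a smaller initial learning rate
```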
50 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import snntorch as snn 2 | from snntorch import functional as SF 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | import pandas as pd 7 | import time 8 | from earlystopping import * 9 | from set_all_seeds import set_all_seeds 10 | 11 | 12 | def evaluate(Net, config, load_data, train, test, optim_func): 13 | file_name = config["exp_name"] 14 | for trial in range(config["num_trials_eval"]): 15 | csv_name = file_name + "_t" + str(trial) + ".csv" 16 | model_name = file_name + "_t" + str(trial) + ".pt" 17 | num_epochs = config["num_epochs_eval"] 18 | set_all_seeds(config["seed"] + trial) 19 | df_train_loss = pd.DataFrame() 20 | df_test_acc = pd.DataFrame(columns=["epoch", "test_acc", "train_time"]) 21 | df_lr = pd.DataFrame() 22 | # Initialize the network 23 | net = Net(config) 24 | device = "cpu" 25 | if torch.cuda.is_available(): 26 | device = "cuda" 27 | 28 | net.to(device) 29 | # Initialize the optimizer and scheduler 30 | criterion = SF.mse_count_loss( 31 | correct_rate=config["correct_rate"], incorrect_rate=config["incorrect_rate"] 32 | ) 33 | optimizer, scheduler, loss_dependent = optim_func(net, config) 34 | # Early stopping condition 35 | if config["early_stopping"]: 36 | early_stopping = EarlyStopping_acc( 37 | patience=config["patience"], verbose=True, path=model_name 38 | ) 39 | early_stopping.early_stop = False 40 | early_stopping.best_score = None 41 | 42 | # Load data 43 | trainset, testset = load_data(config) 44 | config["dataset_length"] = len(trainset) 45 | trainloader = DataLoader( 46 | trainset, batch_size=int(config["batch_size"]), shuffle=True 47 | ) 48 | testloader = DataLoader( 49 | testset, batch_size=int(config["batch_size"]), shuffle=False 50 | ) 51 | if loss_dependent: 52 | old_loss_hist = float("inf") 53 | 54 | print( 55 | f"=======Trial: {trial}, Batch: {config['batch_size']}, beta: {config['beta']:.3f}, threshold: {config['threshold']:.2f}, slope: {config['slope']}, lr: {config['lr']:.3e}======" 56 | ) 57 | # Train 58 | for epoch in range(num_epochs): 59 | start_time = time.time() 60 | loss_list, lr_list = train( 61 | config, net, trainloader, criterion, optimizer, device, scheduler 62 | ) 63 | epoch_time = time.time() - start_time 64 | if loss_dependent: 65 | avg_loss_hist = sum(loss_list) / len(loss_list) 66 | if avg_loss_hist > old_loss_hist: 67 | for param_group in optimizer.param_groups: 68 | param_group["lr"] = param_group["lr"] * 0.5 69 | else: 70 | old_loss_hist = avg_loss_hist 71 | 72 | # Test 73 | test_accuracy = test(config, net, testloader, device) 74 | print(f"Epoch: {epoch} \tTest Accuracy: {test_accuracy}") 75 | df_lr = df_lr.append(lr_list, ignore_index=True) 76 | 77 | df_train_loss = df_train_loss.append(loss_list, ignore_index=True) 78 | df_test_acc = df_test_acc.append( 79 | {"epoch": epoch, "test_acc": test_accuracy, "train_time": epoch_time}, 80 | ignore_index=True, 81 | ) 82 | if config["save_csv"]: 83 | df_train_loss.to_csv("loss_" + csv_name, index=False) 84 | df_test_acc.to_csv("acc_" + csv_name, index=False) 85 | df_lr.to_csv("lr_" + csv_name, index=False) 86 | 87 | if config["early_stopping"]: 88 | early_stopping(test_accuracy, net) 89 | if early_stopping.early_stop: 90 | print("Early stopping") 91 | early_stopping.early_stop = False 92 | early_stopping.best_score = None 93 | break 94 | 
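For reference, `evaluate()` above unpacks `optim_func(net, config)` into `(optimizer, scheduler, loss_dependent)`. A minimal sketch of a compatible factory is shown below; it assumes the config keys used in the quickstart notebook, and the factories actually used by the per-dataset `run.py` scripts may differ.

```
import torch

def cosine_optim_func(net, config):
    """Hypothetical factory matching the signature expected by evaluate()."""
    optimizer = torch.optim.Adam(
        net.parameters(), lr=config["lr"], betas=config["betas"]
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config["t_0"], eta_min=config["eta_min"]
    )
    # loss_dependent=False skips the branch in evaluate() that halves the LR
    # whenever the average epoch loss increases
    return optimizer, scheduler, False

# Usage (load_data, train, and test as defined elsewhere in this repo):
# evaluate(Net, config, load_data, train, test, cosine_optim_func)
```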
-------------------------------------------------------------------------------- /plot_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.lib.twodim_base import tri 3 | import pandas as pd 4 | import os 5 | import matplotlib.pyplot as plt 6 | from matplotlib.ticker import FormatStrFormatter 7 | import seaborn as sns 8 | 9 | 10 | plt.rcParams["font.family"] = "sans-serif" 11 | plt.rcParams["font.sans-serif"] = ["Arial"] 12 | plt.rcParams["figure.figsize"] = (32.5, 10) 13 | plt.rcParams.update({"font.size": 18}) 14 | plt.rcParams["axes.linewidth"] = 2 15 | plt.rcParams["axes.formatter.limits"] = [-5, 4] 16 | 17 | fig, ax = plt.subplots(1, 3) 18 | 19 | data = { 20 | 4: { 21 | "MNIST": { 22 | "cosine": [ 23 | "mnist/loss_MNIST_t0.csv", 24 | "mnist/loss_MNIST_t1.csv", 25 | "mnist/loss_MNIST_t2.csv", 26 | ], 27 | }, 28 | "FashionMNIST": { 29 | "cosine": [ 30 | "fmnist/loss_FMNIST_t0.csv", 31 | "fmnist/loss_FMNIST_t1.csv", 32 | "fmnist/loss_FMNIST_t2.csv", 33 | ], 34 | }, 35 | "DVS128 Gesture": { 36 | "cosine": [ 37 | "DVS/loss_DVS_t0.csv", 38 | "DVS/loss_DVS_t1.csv", 39 | "DVS/loss_DVS_t2.csv", 40 | ], 41 | }, 42 | }, 43 | } 44 | 45 | df = pd.DataFrame( 46 | columns=["dataset", "network_precision", "scheduler", "idx", "mean", "std"] 47 | ) 48 | for precision in data.keys(): 49 | for dataset_idx, dataset in enumerate(data[precision].keys()): 50 | for scheduler in data[precision][dataset].keys(): 51 | grouped_trial_df = pd.DataFrame(columns=["idx", "loss"]) 52 | for idx, trial in enumerate(data[precision][dataset][scheduler]): 53 | if trial is not None: 54 | trial_df = pd.read_csv(trial) 55 | trial_data = np.vstack( 56 | (trial_df.index, trial_df.values.flatten()) 57 | ).T 58 | trial_df = pd.DataFrame(trial_data, columns=["idx", "loss"]) 59 | grouped_trial_df = grouped_trial_df.append(trial_df) 60 | else: 61 | grouped_trial_df = grouped_trial_df.append( 62 | {"idx": 0, "loss": 1}, ignore_index=True 63 | ) 64 | 65 | grouped_trial_df["loss"] = pd.to_numeric(grouped_trial_df["loss"]) 66 | grouped_trial_df_ = grouped_trial_df.groupby("idx") 67 | grouped_trial_data = np.hstack( 68 | ( 69 | np.expand_dims(grouped_trial_df["idx"].unique(), 1), 70 | grouped_trial_df_.mean(), 71 | grouped_trial_df_.std(), 72 | ) 73 | ) 74 | grouped_trial_data = np.nan_to_num(grouped_trial_data, nan=0) 75 | trial_df = pd.DataFrame(grouped_trial_data, columns=["idx", "mean", "std"]) 76 | trial_df["network_precision"] = precision 77 | trial_df["scheduler"] = scheduler 78 | trial_df["dataset"] = dataset 79 | df = df.append(trial_df, ignore_index=True) 80 | 81 | df = df[df["idx"] % 250 == 0] 82 | df.to_csv("loss_data.csv") 83 | df_quant = df[df["network_precision"] == 4] 84 | del df 85 | # Separate out dataframes to independently take moving avgs 86 | df_mnist_quant = df_quant[df_quant["dataset"] == "MNIST"] 87 | df_fmnist_quant = df_quant[df_quant["dataset"] == "FashionMNIST"] 88 | df_dvs_quant = df_quant[df_quant["dataset"] == "DVS128 Gesture"] 89 | del df_quant 90 | df_mnist_quant["mean_rolling"] = ( 91 | df_mnist_quant.iloc[:, 4].rolling(window=20, min_periods=1).mean() 92 | ) 93 | df_mnist_quant["std_rolling"] = ( 94 | df_mnist_quant.iloc[:, 5].rolling(window=20, min_periods=1).mean() 95 | ) 96 | df_mnist_quant = df_mnist_quant.dropna() 97 | df_fmnist_quant["mean_rolling"] = ( 98 | df_fmnist_quant.iloc[:, 4].rolling(window=20, min_periods=1).mean() 99 | ) 100 | df_fmnist_quant["std_rolling"] = ( 101 | df_fmnist_quant.iloc[:, 
5].rolling(window=20, min_periods=1).mean() 102 | ) 103 | df_fmnist_quant = df_fmnist_quant.dropna() 104 | df_dvs_quant["mean_rolling"] = ( 105 | df_dvs_quant.iloc[:, 4].rolling(window=5, min_periods=1).mean() 106 | ) 107 | df_dvs_quant["std_rolling"] = ( 108 | df_dvs_quant.iloc[:, 5].rolling(window=5, min_periods=1).mean() 109 | ) 110 | df_dvs_quant = df_dvs_quant.dropna() 111 | # Combine them 112 | frames = [df_mnist_quant, df_fmnist_quant, df_dvs_quant] 113 | df = pd.concat(frames, ignore_index=True) 114 | # Plot rolling avgs or raw mean/std 115 | col_name = "mean_rolling" # or mean 116 | std_name = "std_rolling" # or std 117 | y_axis_limits = [[0.0008, 0.003], [0.004, 0.0075], [0.001, 0.0125]] 118 | palette = sns.color_palette("bright", 4) 119 | for precision_idx, precision in enumerate(data.keys()): 120 | for dataset_idx, dataset in enumerate(data[precision].keys()): 121 | df_tmp = df[df["network_precision"] == precision] 122 | df_ = df_tmp[df_tmp["dataset"] == dataset] 123 | # Plot mean values 124 | sns.lineplot( 125 | data=df_, 126 | x="idx", 127 | y=col_name, 128 | hue="scheduler", 129 | ax=ax[dataset_idx], 130 | linewidth=2.5, 131 | palette=palette, 132 | ) # alpha=0.95 133 | # Manually plot error bounds 134 | for scheduler_idx, scheduler in enumerate(data[precision][dataset].keys()): 135 | df__ = df_[df_["scheduler"] == scheduler] 136 | x = df__["idx"].values 137 | try: 138 | lower = df__[col_name].values - df__[std_name].values 139 | upper = df__[col_name].values + df__[std_name].values 140 | ax[dataset_idx].plot(x, lower, color=palette[scheduler_idx], alpha=0.2) 141 | ax[dataset_idx].plot(x, upper, color=palette[scheduler_idx], alpha=0.2) 142 | ax[dataset_idx].spines["top"].set_visible(False) 143 | ax[dataset_idx].spines["right"].set_visible(False) 144 | ax[dataset_idx].fill_between(x, lower, upper, alpha=0.1) 145 | except: 146 | pass 147 | 148 | ax[dataset_idx].set_title(dataset) 149 | ax[dataset_idx].set_xlim([0, None]) 150 | ax[dataset_idx].set_ylim(y_axis_limits[dataset_idx]) 151 | ax[dataset_idx].yaxis.set_major_formatter(FormatStrFormatter("%.5f")) 152 | ax[dataset_idx].grid() 153 | ax[dataset_idx].set_xlabel("Minibatch") 154 | ax[dataset_idx].set_ylabel("MSE Loss") 155 | print(precision, dataset) 156 | 157 | plt.show() 158 | -------------------------------------------------------------------------------- /quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "colab_type": "text", 7 | "id": "view-in-github" 8 | }, 9 | "source": [ 10 | "Illustrations of the key concepts of the paper: Periodic scheduling can enable SNNs to overcome flat surfaces and local minima. When the LR is boosted during training using a cyclic scheduler, it is given another chance to reduce the loss with different initial conditions. While the loss appears to converge, subsequent LR boosting enables it to traverse more optimal solutions.
\n" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "b68d7bb4", 32 | "metadata": {}, 33 | "source": [ 34 | "## Install All Required Packages and Import Necessary Libraries" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": null, 40 | "id": "hDnIEHOKB8LD", 41 | "metadata": { 42 | "id": "hDnIEHOKB8LD" 43 | }, 44 | "outputs": [], 45 | "source": [ 46 | "import urllib.request\n", 47 | "urllib.request.urlretrieve('https://raw.githubusercontent.com/jeshraghian/QSNNs/main/requirements.txt', 'requirements.txt')\n", 48 | "!pip install -r requirements.txt --quiet\n", 49 | "import torch, torch.nn as nn\n", 50 | "import snntorch as snn\n", 51 | "import brevitas.nn as qnn" 52 | ] 53 | }, 54 | { 55 | "cell_type": "markdown", 56 | "id": "EYf13Gtx1OCj", 57 | "metadata": { 58 | "id": "EYf13Gtx1OCj" 59 | }, 60 | "source": [ 61 | "## Create a Dataloader for the FashionMNIST Dataset" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "id": "17e61945", 67 | "metadata": {}, 68 | "source": [ 69 | "Download and apply transforms to the FashionMNIST dataset." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "eo4T5MC21hgD", 76 | "metadata": { 77 | "id": "eo4T5MC21hgD" 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "from torch.utils.data import DataLoader\n", 82 | "from torchvision import datasets, transforms\n", 83 | "\n", 84 | "\n", 85 | "data_path='/data/fmnist' # Directory where FMNIST dataset is stored\n", 86 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\") # Use GPU if available\n", 87 | "\n", 88 | "# Define a transform to normalize data\n", 89 | "transform = transforms.Compose([\n", 90 | " transforms.Resize((28, 28)),\n", 91 | " transforms.Grayscale(),\n", 92 | " transforms.ToTensor(),\n", 93 | " transforms.Normalize((0,), (1,))])\n", 94 | "\n", 95 | "# Download and load the training and test FashionMNIST datasets\n", 96 | "fmnist_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)\n", 97 | "fmnist_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "CHcNZT-7iCQH", 103 | "metadata": { 104 | "id": "CHcNZT-7iCQH" 105 | }, 106 | "source": [ 107 | "To speed-up simulations for demonstration purposes, the below code cell can be run to reduce the number of samples in the training and test sets by a factor of 10." 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "q5bhKdF_h7qk", 114 | "metadata": { 115 | "id": "q5bhKdF_h7qk" 116 | }, 117 | "outputs": [], 118 | "source": [ 119 | "from snntorch import utils\n", 120 | "\n", 121 | "\n", 122 | "utils.data_subset(fmnist_train, 10)\n", 123 | "utils.data_subset(fmnist_test, 10)" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "bLmrQ5pEiSSJ", 129 | "metadata": { 130 | "id": "bLmrQ5pEiSSJ" 131 | }, 132 | "source": [ 133 | "Create DataLoaders with batches of 128 samples and shuffle the training set." 
134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "xstp4mn_iRxi", 140 | "metadata": { 141 | "id": "xstp4mn_iRxi" 142 | }, 143 | "outputs": [], 144 | "source": [ 145 | "batch_size = 128 # Batches of 128 samples\n", 146 | "trainloader = DataLoader(fmnist_train, batch_size=batch_size, shuffle=True)\n", 147 | "testloader = DataLoader(fmnist_test, batch_size=batch_size, shuffle=False)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "i3A4exp_c0c5", 153 | "metadata": { 154 | "id": "i3A4exp_c0c5" 155 | }, 156 | "source": [ 157 | "## Define Network Parameters" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "vrt2wObbiXSf", 163 | "metadata": { 164 | "id": "vrt2wObbiXSf" 165 | }, 166 | "source": [ 167 | "We have only specified 15 epochs without early stopping as a quick, early demonstration. Feel free to increase this. " 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "ivhGn7Lhc6te", 174 | "metadata": { 175 | "id": "ivhGn7Lhc6te" 176 | }, 177 | "outputs": [], 178 | "source": [ 179 | "config = {\n", 180 | " \"num_epochs\": 15, # Number of epochs to train for (per trial)\n", 181 | " \"batch_size\": 128, # Batch size\n", 182 | " \"seed\": 0, # Random seed\n", 183 | " \n", 184 | " # Quantization\n", 185 | " \"num_bits\": 4, # Bit resolution\n", 186 | " \n", 187 | " # Network parameters\n", 188 | " \"grad_clip\": False, # Whether or not to clip gradients\n", 189 | " \"weight_clip\": False, # Whether or not to clip weights\n", 190 | " \"batch_norm\": True, # Whether or not to use batch normalization\n", 191 | " \"dropout\": 0.07, # Dropout rate\n", 192 | " \"beta\": 0.97, # Decay rate parameter (beta)\n", 193 | " \"threshold\": 2.5, # Threshold parameter (theta)\n", 194 | " \"lr\": 3.0e-3, # Initial learning rate\n", 195 | " \"slope\": 5.6, # Slope value (k)\n", 196 | " \n", 197 | " # Fixed params\n", 198 | " \"num_steps\": 100, # Number of timesteps to encode input for\n", 199 | " \"correct_rate\": 0.8, # Correct rate\n", 200 | " \"incorrect_rate\": 0.2, # Incorrect rate\n", 201 | " \"betas\": (0.9, 0.999), # Adam optimizer beta values\n", 202 | " \"t_0\": 4690, # Initial frequency of the cosine annealing scheduler\n", 203 | " \"eta_min\": 0, # Minimum learning rate\n", 204 | "}" 205 | ] 206 | }, 207 | { 208 | "cell_type": "markdown", 209 | "id": "BtJBOtez11wy", 210 | "metadata": { 211 | "id": "BtJBOtez11wy" 212 | }, 213 | "source": [ 214 | "## Define the Network Architecture\n", 215 | "* 5 $\\times$ Conv Layer w/16 Filters\n", 216 | "* 2 $\\times$ 2 Average Pooling\n", 217 | "* 5 $\\times$ Conv Layer w/64 Filters\n", 218 | "* 2 $\\times$ 2 Average Pooling\n", 219 | "* (64 $\\times$ 4 $\\times$ 4) -- 10 Dense Layer" 220 | ] 221 | }, 222 | { 223 | "cell_type": "code", 224 | "execution_count": null, 225 | "id": "JM2thnrc10rD", 226 | "metadata": { 227 | "id": "JM2thnrc10rD" 228 | }, 229 | "outputs": [], 230 | "source": [ 231 | "from snntorch import surrogate\n", 232 | "import torch.nn.functional as F\n", 233 | "\n", 234 | "\n", 235 | "class Net(nn.Module):\n", 236 | " def __init__(self, config):\n", 237 | " super().__init__()\n", 238 | " self.num_bits = config[\"num_bits\"]\n", 239 | " self.thr = config[\"threshold\"]\n", 240 | " self.slope = config[\"slope\"]\n", 241 | " self.beta = config[\"beta\"]\n", 242 | " self.num_steps = config[\"num_steps\"]\n", 243 | " self.batch_norm = config[\"batch_norm\"]\n", 244 | " self.p1 = config[\"dropout\"]\n", 245 | " self.spike_grad = 
surrogate.fast_sigmoid(self.slope)\n", 246 | " \n", 247 | " # Initialize Layers\n", 248 | " self.conv1 = qnn.QuantConv2d(1, 16, 5, bias=False, weight_bit_width=self.num_bits)\n", 249 | " self.conv1_bn = nn.BatchNorm2d(16)\n", 250 | " self.lif1 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)\n", 251 | " self.conv2 = qnn.QuantConv2d(16, 64, 5, bias=False, weight_bit_width=self.num_bits)\n", 252 | " self.conv2_bn = nn.BatchNorm2d(64)\n", 253 | " self.lif2 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)\n", 254 | " self.fc1 = qnn.QuantLinear(64 * 4 * 4, 10, bias=False, weight_bit_width=self.num_bits)\n", 255 | " self.lif3 = snn.Leaky(self.beta, threshold=self.thr, spike_grad=self.spike_grad)\n", 256 | " self.dropout = nn.Dropout(self.p1)\n", 257 | "\n", 258 | " def forward(self, x):\n", 259 | " # Initialize hidden states and outputs at t=0\n", 260 | " mem1 = self.lif1.init_leaky()\n", 261 | " mem2 = self.lif2.init_leaky()\n", 262 | " mem3 = self.lif3.init_leaky()\n", 263 | "\n", 264 | " # Record the final layer\n", 265 | " spk3_rec = []\n", 266 | " mem3_rec = []\n", 267 | "\n", 268 | " # Forward pass\n", 269 | " for step in range(self.num_steps):\n", 270 | " cur1 = F.avg_pool2d(self.conv1(x), 2)\n", 271 | " if self.batch_norm:\n", 272 | " cur1 = self.conv1_bn(cur1)\n", 273 | "\n", 274 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 275 | " cur2 = F.avg_pool2d(self.conv2(spk1), 2)\n", 276 | " if self.batch_norm:\n", 277 | " cur2 = self.conv2_bn(cur2)\n", 278 | "\n", 279 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 280 | " cur3 = self.dropout(self.fc1(spk2.flatten(1)))\n", 281 | " spk3, mem3 = self.lif3(cur3, mem3)\n", 282 | " spk3_rec.append(spk3)\n", 283 | " mem3_rec.append(mem3)\n", 284 | "\n", 285 | " return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)\n", 286 | "\n", 287 | "net = Net(config).to(device)" 288 | ] 289 | }, 290 | { 291 | "cell_type": "markdown", 292 | "id": "BmtJx_AAeOyP", 293 | "metadata": { 294 | "id": "BmtJx_AAeOyP" 295 | }, 296 | "source": [ 297 | "## Define the Optimizer, Learning Rate Scheduler, and Loss Function\n", 298 | "* Adam optimizer\n", 299 | "* Cosine Annealing Scheduler\n", 300 | "* MSE Spike Count Loss (Target spike count for correct and incorrect classes are specified)" 301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "id": "ky-qAN_YeKmE", 307 | "metadata": { 308 | "id": "ky-qAN_YeKmE" 309 | }, 310 | "outputs": [], 311 | "source": [ 312 | "import snntorch.functional as SF\n", 313 | "\n", 314 | "\n", 315 | "optimizer = torch.optim.Adam(net.parameters(), \n", 316 | " lr=config[\"lr\"], betas=config[\"betas\"]\n", 317 | ")\n", 318 | "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, \n", 319 | " T_max=config[\"t_0\"], \n", 320 | " eta_min=config[\"eta_min\"], \n", 321 | " last_epoch=-1\n", 322 | ")\n", 323 | "criterion = SF.mse_count_loss(correct_rate=config[\"correct_rate\"], \n", 324 | " incorrect_rate=config[\"incorrect_rate\"]\n", 325 | ")" 326 | ] 327 | }, 328 | { 329 | "cell_type": "markdown", 330 | "id": "UGtJwmtVexb4", 331 | "metadata": { 332 | "id": "UGtJwmtVexb4" 333 | }, 334 | "source": [ 335 | "## Train and Evaluate the Network" 336 | ] 337 | }, 338 | { 339 | "cell_type": "markdown", 340 | "id": "2321a02f", 341 | "metadata": {}, 342 | "source": [ 343 | "As the learning rate follows a periodic schedule, the accuracy will oscillate across the training process, but with a general tendency to improve." 
344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": null, 349 | "id": "tbOQgPiEe-lp", 350 | "metadata": { 351 | "id": "tbOQgPiEe-lp" 352 | }, 353 | "outputs": [], 354 | "source": [ 355 | "def train(config, net, trainloader, criterion, optimizer, device=\"cpu\", scheduler=None):\n", 356 | " \"\"\"Complete one epoch of training.\"\"\"\n", 357 | " \n", 358 | " net.train()\n", 359 | " loss_accum = []\n", 360 | " lr_accum = []\n", 361 | " i = 0\n", 362 | " for data, labels in trainloader:\n", 363 | " data, labels = data.to(device), labels.to(device)\n", 364 | " spk_rec, _ = net(data)\n", 365 | " loss = criterion(spk_rec, labels)\n", 366 | " optimizer.zero_grad()\n", 367 | " loss.backward()\n", 368 | "\n", 369 | " ## Enable gradient clipping\n", 370 | " if config[\"grad_clip\"]:\n", 371 | " nn.utils.clip_grad_norm_(net.parameters(), 1.0)\n", 372 | "\n", 373 | " ## Enable weight clipping\n", 374 | " if config[\"weight_clip\"]:\n", 375 | " with torch.no_grad():\n", 376 | " for param in net.parameters():\n", 377 | " param.clamp_(-1, 1)\n", 378 | "\n", 379 | " optimizer.step()\n", 380 | " scheduler.step()\n", 381 | " loss_accum.append(loss.item() / config[\"num_steps\"])\n", 382 | " lr_accum.append(optimizer.param_groups[0][\"lr\"])\n", 383 | "\n", 384 | " return loss_accum, lr_accum\n", 385 | "\n", 386 | "def test(config, net, testloader, device=\"cpu\"):\n", 387 | " \"\"\"Calculate accuracy on full test set.\"\"\"\n", 388 | " correct = 0\n", 389 | " total = 0\n", 390 | " with torch.no_grad():\n", 391 | " net.eval()\n", 392 | " for data in testloader:\n", 393 | " images, labels = data\n", 394 | " images, labels = images.to(device), labels.to(device)\n", 395 | " outputs, _ = net(images)\n", 396 | " accuracy = SF.accuracy_rate(outputs, labels)\n", 397 | " total += labels.size(0)\n", 398 | " correct += accuracy * labels.size(0)\n", 399 | "\n", 400 | " return 100 * correct / total\n", 401 | "\n", 402 | "loss_list = []\n", 403 | "lr_list = []\n", 404 | "\n", 405 | "print(f\"=======Training Network=======\")\n", 406 | "# Train\n", 407 | "for epoch in range(config['num_epochs']):\n", 408 | " loss, lr = train(config, net, trainloader, criterion, optimizer, \n", 409 | " device, scheduler\n", 410 | " )\n", 411 | " loss_list = loss_list + loss\n", 412 | " lr_list = lr_list + lr\n", 413 | " # Test\n", 414 | " test_accuracy = test(config, net, testloader, device)\n", 415 | " print(f\"Epoch: {epoch} \\tTest Accuracy: {test_accuracy}\")" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "id": "14d0bd78", 421 | "metadata": {}, 422 | "source": [ 423 | "## Plot the Training Loss and Learning Rate Over Time" 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": null, 429 | "id": "B22SnaTElOLh", 430 | "metadata": { 431 | "id": "B22SnaTElOLh" 432 | }, 433 | "outputs": [], 434 | "source": [ 435 | "%matplotlib inline\n", 436 | "import matplotlib.pyplot as plt\n", 437 | "import seaborn as sns\n", 438 | "\n", 439 | "\n", 440 | "sns.set_theme()\n", 441 | "fig, ax1 = plt.subplots()\n", 442 | "ax2 = ax1.twinx()\n", 443 | "ax1.plot(loss_list, color='tab:orange')\n", 444 | "ax2.plot(lr_list, color='tab:blue')\n", 445 | "ax1.set_xlabel('Iteration')\n", 446 | "ax1.set_ylabel('Loss', color='tab:orange')\n", 447 | "ax2.set_ylabel('Learning Rate', color='tab:blue')\n", 448 | "plt.show()" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "id": "-iSGTq0Q3Lcm", 454 | "metadata": { 455 | "id": "-iSGTq0Q3Lcm" 456 | }, 457 | "source": [ 458 | "# Conclusion\n", 459 | 
"That's it for the quick intro to quantized SNNs! Results can be further improved by not using the `snntorch.utils.data_subset` method to train with the full FashionMNIST dataset, training for a larger number of epochs, and utilizing early stopping logic.\n", 460 | "\n", 461 | "To run the experiments from the corresponding paper, including those on dynamic datasets, please [refer to the corresponding GitHub repo](https://github.com/jeshraghian/QSNNs/)." 462 | ] 463 | } 464 | ], 465 | "metadata": { 466 | "accelerator": "GPU", 467 | "colab": { 468 | "include_colab_link": true, 469 | "name": "Copy of tutorial_5_neuromorphic_datasets.ipynb", 470 | "provenance": [] 471 | }, 472 | "kernelspec": { 473 | "display_name": "Python 3 (ipykernel)", 474 | "language": "python", 475 | "name": "python3" 476 | }, 477 | "language_info": { 478 | "codemirror_mode": { 479 | "name": "ipython", 480 | "version": 3 481 | }, 482 | "file_extension": ".py", 483 | "mimetype": "text/x-python", 484 | "name": "python", 485 | "nbconvert_exporter": "python", 486 | "pygments_lexer": "ipython3", 487 | "version": "3.8.11" 488 | } 489 | }, 490 | "nbformat": 4, 491 | "nbformat_minor": 5 492 | } 493 | --------------------------------------------------------------------------------