├── .gitignore
├── static
│   ├── fabrizio.jpeg
│   ├── jason_eshraghian.jpg
│   └── hls4nm-flow-horizontal.png
├── LICENSE
├── hw
│   ├── preprocess_dvsgesture.py
│   └── main.py
├── README.md
└── software
    ├── ISFPGA_SNN.ipynb
    └── ISFPGA_SNN_cheatsheet.ipynb

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store
__pycache__

--------------------------------------------------------------------------------
/static/fabrizio.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-neuromorphic/fpga-snntorch/main/static/fabrizio.jpeg

--------------------------------------------------------------------------------
/static/jason_eshraghian.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-neuromorphic/fpga-snntorch/main/static/jason_eshraghian.jpg

--------------------------------------------------------------------------------
/static/hls4nm-flow-horizontal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/open-neuromorphic/fpga-snntorch/main/static/hls4nm-flow-horizontal.png

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 Open Neuromorphic

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/hw/preprocess_dvsgesture.py:
--------------------------------------------------------------------------------
import tonic
import tonic.transforms as transforms
import numpy as np
import h5py

DATADIR = "./data"
batch_size = 16
sensor_size = (32, 32, 2)
frame_transform_test = transforms.Compose([transforms.Denoise(filter_time=10000),
                                           transforms.Downsample(spatial_factor=0.25),
                                           transforms.ToFrame(sensor_size=sensor_size,
                                                              n_time_bins=150)
                                           ])

test_ds = tonic.datasets.DVSGesture(save_to=DATADIR, transform=frame_transform_test, train=False)

nsamples = len(test_ds)
# NOTE: the original listing is truncated from the dtype below onwards; the
# completion here is a reconstruction (dtype, fill loop and HDF5 file/dataset
# names are educated guesses, not the released source).
arr = np.empty((nsamples, 150, 2, 32, 32), dtype='uint8')
targets = np.empty((nsamples,), dtype='uint8')
for i, (frames, target) in enumerate(test_ds):
    arr[i] = frames
    targets[i] = target

with h5py.File("dvsgesture_test.h5", "w") as f:
    f.create_dataset("data", data=arr)
    f.create_dataset("targets", data=targets)

--------------------------------------------------------------------------------
/hw/main.py:
--------------------------------------------------------------------------------
import numpy as np
import h5py
from pynq import Overlay, allocate

# NOTE: reconstructed header -- the original listing only resumes at the dout
# address writes below. File names, buffer shapes/dtypes and the din register
# offsets (0x10/0x14) are assumptions inferred from the visible code.
TIMESTEPS = 150
CLASSES = 11

with h5py.File("dvsgesture_test.h5", "r") as f:
    data = f["data"][:]

SAMPLES = data.shape[0]
results = np.zeros((SAMPLES,))

# Load the bitstream (design_1.hwh must sit next to it) and allocate
# DMA-visible input/output buffers.
overlay = Overlay("design_1.bit")
din = allocate(shape=(TIMESTEPS, 2, 32, 32), dtype=np.uint8)
dout = allocate(shape=(TIMESTEPS * CLASSES,), dtype=np.int32)

# Buffer physical addresses are 64 bits wide and are handed to the HLS IP as
# low/high 32-bit register pairs.
overlay.snn_0.write(0x10, din.device_address)
overlay.snn_0.write(0x14, din.device_address >> 32)

overlay.snn_0.write(0x1c, dout.device_address)
overlay.snn_0.write(0x20, dout.device_address >> 32)

for sample in range(SAMPLES):
    # Writing the sample to the IP.
    din[:] = data[sample][:]
    din.sync_to_device()

    # Starting inference (ap_start is bit 0 of the control register).
    overlay.snn_0.write(0x00, 1)

    # Waiting for execution to finish (ap_done is bit 1).
    while not overlay.snn_0.read(0x00) & 2:
        pass

    # Reading out data and accumulating the per-timestep class outputs.
    dout.sync_from_device()
    tmp = np.zeros((CLASSES,))
    for t in range(TIMESTEPS):
        for i in range(CLASSES):
            # Layout assumed [timestep][class]; the original indexed
            # dout[t*TIMESTEPS + i], which over-runs a TIMESTEPS*CLASSES buffer.
            tmp[i] += dout[t * CLASSES + i]

    results[sample] = np.argmax(tmp)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ISFPGA: Who needs neuromorphic hardware? Deploying SNNs to FPGAs via HLS

This repository contains the notebooks on hardware-aware training of
spiking neural networks and their deployment onto FPGAs, presented at ISFPGA 2024
(Monterey, CA) in the workshop *Who needs neuromorphic hardware? Deploying SNNs
to FPGAs via HLS*, co-presented by [Jason Eshraghian](https://ncg.ucsc.edu) and Fabrizio Ottati.

![hls4nm flow](/static/hls4nm-flow-horizontal.png)

## Abstract

How can we use natural intelligence to improve artificial intelligence? The human brain is a great place to look to improve modern neural networks and reduce their exorbitant energy costs.
While we may be far from a complete understanding of the brain, we are at a point where a set of design principles has enabled us to build potentially more efficient deep learning tools.
Most of these are linked back to spiking neural networks (SNNs). In a cruel twist of irony, the neuromorphic hardware that is available for research and/or commercial use is considerably more expensive (and often less performant) than a consumer-grade GPU, and harder to obtain.
How can we move towards low-cost hardware that sits on our desk, or fits in a PCIe slot in our desktops, and accelerates SNNs? FPGAs might be the solution.
This tutorial takes a hands-on approach to learning how to train SNNs for hardware deployment on conventional GPUs, and how to run these models on an embedded-class FPGA (AMD Kria KV260) for inference.
FPGA inference is achieved through high-level synthesis with the AMD Vitis HLS compiler, using a dataflow architecture[^1] of a deep SNN, with in-hardware testing.
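## Hardware-aware training at a glance

The notebooks below walk through this recipe in detail. As a taste, hardware-aware training boils down to quantizing the weights (with [Brevitas](https://github.com/Xilinx/brevitas)) and the membrane potentials (with snnTorch's `state_quant`) while training as usual. A minimal sketch follows; the layer sizes and bit widths are illustrative, mirroring the cheat-sheet notebook rather than the deployed accelerator:

```python
import torch
import torch.nn as nn
import snntorch as snn
from snntorch.functional import quant
import brevitas.nn as qnn

class QuantSNN(nn.Module):
    """Two-layer SNN with 8-bit weights and 4-bit membrane potentials."""
    def __init__(self, num_inputs=784, num_hidden=100, num_outputs=10, beta=0.95):
        super().__init__()
        q_lif = quant.state_quant(num_bits=4, uniform=True)  # quantize the membrane state
        self.fc1 = qnn.QuantLinear(num_inputs, num_hidden, weight_bit_width=8, bias=False)
        self.lif1 = snn.Leaky(beta=beta, state_quant=q_lif)
        self.fc2 = qnn.QuantLinear(num_hidden, num_outputs, weight_bit_width=8, bias=False)
        self.lif2 = snn.Leaky(beta=beta, state_quant=q_lif)

    def forward(self, x, num_steps=25):
        mem1, mem2 = self.lif1.init_leaky(), self.lif2.init_leaky()
        spk2_rec = []
        for _ in range(num_steps):  # simulate the network over time
            spk1, mem1 = self.lif1(self.fc1(x.flatten(1)), mem1)
            spk2, mem2 = self.lif2(self.fc2(spk1), mem2)
            spk2_rec.append(spk2)
        return torch.stack(spk2_rec)  # time-steps x batch x num_outputs
```

Spikes are already binary, so once weights and states are quantized the whole network maps naturally onto fixed-point arithmetic in HLS.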
## Tutorial

There are two tutorial components in this repo:

* One on hardware-aware training of spiking neural networks using [snnTorch](https://github.com/jeshraghian/snntorch).
* One on the bitstream, PYNQ and C++ code to run the HLS model on the FPGA.

## Software

| Title | Colab Link |
|-------|------------|
| Hardware-Aware Training of Spiking Neural Networks with [snnTorch](https://github.com/jeshraghian/snntorch) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-neuromorphic/fpga-snntorch/blob/main/software/ISFPGA_SNN.ipynb) |
| **Cheat-Sheet:** Hardware-Aware Training of Spiking Neural Networks with [snnTorch](https://github.com/jeshraghian/snntorch) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/open-neuromorphic/fpga-snntorch/blob/main/software/ISFPGA_SNN_cheatsheet.ipynb) |

## Hardware

In the `hw` folder you will find:

* the bitstream of the accelerator, implemented for the Kria KV260 AI starter kit: `design_1.bit`.
* the PYNQ script to run inference, `main.py`, plus some utilities in `utils.py`.
* the address map used to communicate between the Zynq PS and the HLS IP: `xwrapper_hw.h`.
* the hardware handoff file, `design_1.hwh`.

This material will be provided ASAP.
The HLS C++ code will be released in the upcoming weeks.

## Speakers

### Jason Eshraghian, Assistant Professor, UC Santa Cruz

![Jason Eshraghian](/static/jason_eshraghian.jpg)

Jason K. Eshraghian received the B.Eng. (electrical and electronic), L.L.B., and
Ph.D. degrees from The University of Western Australia, Perth, WA, Australia, in
2016 and 2019, respectively. From 2019 to 2022, he was a Post-Doctoral Research
Fellow at the University of Michigan, Ann Arbor, MI, USA. He is currently an
Assistant Professor with the Department of Electrical and Computer Engineering,
University of California, Santa Cruz, Santa Cruz, CA, USA. His research
interests include neuromorphic computing, resistive random access memory (RRAM)
circuits, and spiking neural networks.

### Fabrizio Ottati, AI Computer Architect, NXP Semiconductors

![Fabrizio Ottati](/static/fabrizio.jpeg)

Fabrizio received his Ph.D. from Politecnico di Torino in 2024, with a thesis on
efficient inference of spiking neural networks on FPGA platforms. He is now an
AI Computer Architect at NXP Semiconductors. He is mainly interested in deep
learning and computer architecture, besides trying to learn how to write decent
code.

[^1]: Work carried out by Fabrizio Ottati during his Ph.D. at Politecnico di Torino.
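## Running inference on the KV260 (sketch)

Until the full `hw` material lands, the host-side flow implemented in `main.py` boils down to the PYNQ snippet below. The output-buffer register offsets come from `main.py`; the buffer shapes and dtypes are assumptions based on `preprocess_dvsgesture.py` (150 time bins, 2x32x32 frames, 11 DVSGesture classes):

```python
import numpy as np
from pynq import Overlay, allocate

TIMESTEPS, CLASSES = 150, 11

# design_1.bit and design_1.hwh must sit next to each other.
overlay = Overlay("design_1.bit")
din = allocate(shape=(TIMESTEPS, 2, 32, 32), dtype=np.uint8)   # one input sample
dout = allocate(shape=(TIMESTEPS * CLASSES,), dtype=np.int32)  # per-timestep outputs

# Hand the 64-bit physical buffer addresses to the HLS IP as low/high pairs
# (the matching din writes use offsets from xwrapper_hw.h, omitted here).
overlay.snn_0.write(0x1c, dout.device_address)
overlay.snn_0.write(0x20, dout.device_address >> 32)

din[:] = 0                                # load a real preprocessed sample here
din.sync_to_device()
overlay.snn_0.write(0x00, 1)              # ap_start
while not overlay.snn_0.read(0x00) & 2:   # poll ap_done
    pass
dout.sync_from_device()

# Sum spikes over time and pick the most active output class.
prediction = np.asarray(dout).reshape(TIMESTEPS, CLASSES).sum(0).argmax()
```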
--------------------------------------------------------------------------------
/software/ISFPGA_SNN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "colab_type": "text",
7 | "id": "view-in-github"
8 | },
9 | "source": [
10 | "<a href=\"https://colab.research.google.com/github/open-neuromorphic/fpga-snntorch/blob/main/software/ISFPGA_SNN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "8w6lhn7H8fW5"
17 | },
18 | "source": [
19 | "# ISFPGA Workshop\n",
20 | "## Who needs neuromorphic hardware? Deploying SNNs to FPGAs via HLS\n",
21 | "### By Jason K. Eshraghian (www.ncg.ucsc.edu)\n",
22 | "\n",
23 | "\n",
24 | "[snnTorch](https://github.com/jeshraghian/snntorch/)"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {
31 | "id": "1BlTqunB73-t"
32 | },
33 | "outputs": [],
34 | "source": [
35 | "!pip install snntorch --quiet # shift + enter"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {
41 | "id": "xLlnINAI9mgJ"
42 | },
43 | "source": [
44 | "*What will I learn?*\n",
45 | "\n",
46 | "1. Train an SNN classifier using snnTorch\n",
47 | "2. Hardware Friendly Training\n",
48 | "  - Weight Quantization with Brevitas\n",
49 | "  - Stateful Quantization\n",
50 | "3. Handling neuromorphic data with Tonic"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {
56 | "id": "-CS50cwuCW6n"
57 | },
58 | "source": [
59 | "# 1. Train an SNN Classifier using snnTorch\n",
60 | "## 1.1 Imports\n"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "metadata": {
67 | "id": "H_TzogsCCcSe"
68 | },
69 | "outputs": [],
70 | "source": [
71 | "# snntorch imports\n",
72 | "import snntorch as snn\n",
73 | "from snntorch import functional as SF\n",
74 | "\n",
75 | "# pytorch imports\n",
76 | "import torch\n",
77 | "import torch.nn as nn\n",
78 | "from torch.utils.data import DataLoader\n",
79 | "from torchvision import datasets, transforms\n",
80 | "\n",
81 | "# data manipulation\n",
82 | "import numpy as np\n",
83 | "import itertools\n",
84 | "\n",
85 | "# plotting\n",
86 | "import matplotlib.pyplot as plt\n",
87 | "from IPython.display import HTML"
88 | ]
89 | },
90 | {
91 | "cell_type": "markdown",
92 | "metadata": {
93 | "id": "nftOdpyAGv7D"
94 | },
95 | "source": [
96 | "## 1.2 Boilerplate: DataLoading the MNIST Dataset"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "metadata": {
103 | "id": "SsM2Z5NXGu5z"
104 | },
105 | "outputs": [],
106 | "source": [
107 | "# dataloader arguments\n",
108 | "batch_size = 128\n",
109 | "data_path='/data/mnist'\n",
110 | "\n",
111 | "dtype = torch.float\n",
112 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
113 | "## if you're on M1 or M2 GPU:\n",
114 | "# device = torch.device(\"mps\")"
115 | ]
116 | },
117 | {
118 | "cell_type": "code",
119 | "execution_count": null,
120 | "metadata": {
121 | "id": "XqbYptgDHUPg"
122 | },
123 | "outputs": [],
124 | "source": [
125 | "# Define a transform\n",
126 | "transform = transforms.Compose([\n",
127 | "    transforms.Resize((28, 28)),\n",
128 | "    transforms.Grayscale(),\n",
129 | "    transforms.ToTensor(),\n",
130 | "    transforms.Normalize((0,), (1,))])\n",
131 | "\n",
132 | "mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n",
133 | "mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)"
134 | ]
135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "id": "jSlS3gWZHXI0" 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "# Create DataLoaders\n", 145 | "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)\n", 146 | "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": { 152 | "id": "SJVhfNukHbsp" 153 | }, 154 | "source": [ 155 | "## 1.3 Construct SNN Model" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": { 162 | "id": "uu324fr_HhxV" 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "# Network Architecture\n", 167 | "num_inputs =\n", 168 | "num_hidden =\n", 169 | "num_outputs =\n", 170 | "\n", 171 | "# Temporal Dynamics\n", 172 | "num_steps =\n", 173 | "beta =" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": { 180 | "id": "CkM1Z1EjHeW8" 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "from snntorch import surrogate\n", 185 | "\n", 186 | "# Define Network\n", 187 | "class Net(nn.Module):\n", 188 | " def __init__(self):\n", 189 | " super().__init__()\n", 190 | "\n", 191 | " # Initialize layers\n", 192 | " self.fc1 =\n", 193 | " self.lif1 =\n", 194 | " self.fc2 =\n", 195 | " self.lif2 =\n", 196 | "\n", 197 | " def forward(self, x):\n", 198 | "\n", 199 | " # Initialize hidden states at t=0\n", 200 | " mem1 = self.lif1.init_leaky()\n", 201 | " mem2 = self.lif2.init_leaky()\n", 202 | "\n", 203 | " # Record the final layer\n", 204 | " spk2_rec = []\n", 205 | " mem2_rec = []\n", 206 | "\n", 207 | " # time-loop\n", 208 | " for step in range(num_steps):\n", 209 | " cur1 = self.fc1(...) 
# batch: 128 x 784\n", 210 | " spk1, mem1 = self.lif1(...)\n", 211 | " cur2 = self.fc2(...)\n", 212 | " spk2, mem2 = self.lif2(...)\n", 213 | "\n", 214 | " # store in list\n", 215 | " spk2_rec.append(spk2)\n", 216 | " mem2_rec.append(mem2)\n", 217 | "\n", 218 | " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0) # time-steps x batch x num_out\n", 219 | "\n", 220 | "# Load the network onto CUDA if available\n", 221 | "net = Net().to(device)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": { 227 | "id": "p8qBw03rHpn3" 228 | }, 229 | "source": [ 230 | "## 1.4 Training the SNN" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": { 237 | "id": "1telBMU-HrIg" 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "def training_loop(model, dataloader, num_epochs=1):\n", 242 | " loss = nn.CrossEntropyLoss()\n", 243 | " optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999))\n", 244 | " counter = 0\n", 245 | "\n", 246 | " # Outer training loop\n", 247 | " for epoch in range(num_epochs):\n", 248 | " train_batch = iter(dataloader)\n", 249 | "\n", 250 | " # Minibatch training loop\n", 251 | " for data, targets in train_batch:\n", 252 | " data = data.to(device)\n", 253 | " targets = targets.to(device)\n", 254 | "\n", 255 | " # forward pass\n", 256 | " model.train()\n", 257 | " spk_rec, _ = model(data)\n", 258 | "\n", 259 | " # initialize the loss & sum over time\n", 260 | " loss_val = torch.zeros((1), dtype=dtype, device=device)\n", 261 | " loss_val = loss(spk_rec.sum(0), targets) # batch x num_out\n", 262 | "\n", 263 | " # Gradient calculation + weight update\n", 264 | " optimizer.zero_grad()\n", 265 | " loss_val.backward()\n", 266 | " optimizer.step()\n", 267 | "\n", 268 | " # Print train/test loss/accuracy\n", 269 | " if counter % 10 == 0:\n", 270 | " print(f\"Iteration: {counter} \\t Train Loss: {loss_val.item()}\")\n", 271 | " counter += 1\n", 272 | "\n", 273 | " if counter == 100:\n", 274 | " break\n", 275 | "\n", 276 | "training_loop(net, train_loader)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "id": "nHsdkuSlIS1E" 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "def measure_accuracy(model, dataloader):\n", 288 | " with torch.no_grad():\n", 289 | " model.eval()\n", 290 | " running_length = 0\n", 291 | " running_accuracy = 0\n", 292 | "\n", 293 | " for data, targets in iter(dataloader):\n", 294 | " data = data.to(device)\n", 295 | " targets = targets.to(device)\n", 296 | "\n", 297 | " # forward-pass\n", 298 | " spk_rec, _ = model(data)\n", 299 | " spike_count = spk_rec.sum(0) # batch x num_outputs\n", 300 | " _, max_spike = spike_count.max(1)\n", 301 | "\n", 302 | " # correct classes for one batch\n", 303 | " num_correct = (max_spike == targets).sum()\n", 304 | "\n", 305 | " # total accuracy\n", 306 | " running_length += len(targets)\n", 307 | " running_accuracy += num_correct\n", 308 | "\n", 309 | " accuracy = (running_accuracy / running_length)\n", 310 | "\n", 311 | " return accuracy.item()\n" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": null, 317 | "metadata": { 318 | "id": "oJHAltRCKGyx" 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "print(f\"Test set accuracy: {measure_accuracy(net, test_loader)}\")" 323 | ] 324 | }, 325 | { 326 | "cell_type": "markdown", 327 | "metadata": { 328 | "id": "D5Nc7TTxZWTp" 329 | }, 330 | "source": [ 331 | "### A Sanity Check" 332 | ] 333 | }, 
334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "metadata": { 338 | "id": "ct-kV1rZZVb3" 339 | }, 340 | "outputs": [], 341 | "source": [ 342 | "def print_sample(model, dataloader, idx=0):\n", 343 | " with torch.no_grad():\n", 344 | " model.eval()\n", 345 | "\n", 346 | " data, targets = next(iter(dataloader))\n", 347 | " data = data.to(device)\n", 348 | " targets = targets.to(device)\n", 349 | "\n", 350 | " # forward-pass\n", 351 | " spk_rec, _ = model(data)\n", 352 | " spike_count = spk_rec.sum(0) # batch x num_outputs\n", 353 | " _, max_spike = spike_count.max(1)\n", 354 | "\n", 355 | " # Plot the sample\n", 356 | " plt.imshow(data[idx].cpu().squeeze(), cmap='gray')\n", 357 | " plt.title(f'Target: {targets[idx].item()}')\n", 358 | " plt.show()\n", 359 | "\n", 360 | "\n", 361 | " return" 362 | ] 363 | }, 364 | { 365 | "cell_type": "code", 366 | "execution_count": null, 367 | "metadata": { 368 | "id": "iPtBj7gtadA2" 369 | }, 370 | "outputs": [], 371 | "source": [ 372 | "print_sample(net, test_loader)" 373 | ] 374 | }, 375 | { 376 | "cell_type": "markdown", 377 | "metadata": { 378 | "id": "Z9vrb2zUD6S-" 379 | }, 380 | "source": [ 381 | "# 2. Hardware Friendly Training\n", 382 | "## 2.1 Weight Quantization" 383 | ] 384 | }, 385 | { 386 | "cell_type": "code", 387 | "execution_count": null, 388 | "metadata": { 389 | "id": "icK4WzuL-2QA" 390 | }, 391 | "outputs": [], 392 | "source": [ 393 | "!pip install brevitas --quiet" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": { 399 | "id": "idiLnVjJGJAL" 400 | }, 401 | "source": [ 402 | "Just replace all `nn.Linear` layers with `qnn.QuantLinear(num_inputs, num_outputs, weight_bit_width, bias)`." 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": { 409 | "id": "PiZYTMA6D-YL" 410 | }, 411 | "outputs": [], 412 | "source": [ 413 | "import brevitas.nn as qnn\n", 414 | "\n", 415 | "# Define Network\n", 416 | "class QuantNet(nn.Module):\n", 417 | " def __init__(self):\n", 418 | " super().__init__()\n", 419 | "\n", 420 | " # Initialize layers\n", 421 | " self.fc1 =\n", 422 | " self.lif1 =\n", 423 | " self.fc2 =\n", 424 | " self.lif2 =\n", 425 | "\n", 426 | " def forward(self, x):\n", 427 | "\n", 428 | " # Initialize hidden states at t=0\n", 429 | " mem1 = self.lif1.init_leaky()\n", 430 | " mem2 = self.lif2.init_leaky()\n", 431 | "\n", 432 | " # Record the final layer\n", 433 | " spk2_rec = []\n", 434 | " mem2_rec = []\n", 435 | "\n", 436 | " for step in range(num_steps):\n", 437 | " cur1 = self.fc1(x.flatten(1))\n", 438 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 439 | " cur2 = self.fc2(spk1)\n", 440 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 441 | " spk2_rec.append(spk2)\n", 442 | " mem2_rec.append(mem2)\n", 443 | "\n", 444 | " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n", 445 | "\n", 446 | "# Load the network onto CUDA if available\n", 447 | "qnet = QuantNet().to(device)" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": { 454 | "id": "PocSa27MOOoK" 455 | }, 456 | "outputs": [], 457 | "source": [ 458 | "training_loop(qnet, train_loader)\n", 459 | "print(f\"Test set accuracy: {measure_accuracy(qnet, test_loader)}\")" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": { 465 | "id": "JvRzLHdjD-6S" 466 | }, 467 | "source": [ 468 | "## 2.2 SQUAT: Stateful Quantization-Aware Training\n", 469 | "\n" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | 
"execution_count": null, 475 | "metadata": { 476 | "id": "tAaafUczEAu_" 477 | }, 478 | "outputs": [], 479 | "source": [ 480 | "from snntorch.functional import quant\n", 481 | "\n", 482 | "# Define Network\n", 483 | "class SquatNet(nn.Module):\n", 484 | " def __init__(self):\n", 485 | " super().__init__()\n", 486 | "\n", 487 | " # Define state quantization parameters\n", 488 | " q_lif = quant.state_quant(num_bits=4, uniform=True)\n", 489 | "\n", 490 | " # Initialize layers\n", 491 | " self.fc1 =\n", 492 | " self.lif1 =\n", 493 | " self.fc2 =\n", 494 | " self.lif2 =\n", 495 | "\n", 496 | " def forward(self, x):\n", 497 | "\n", 498 | " # Initialize hidden states at t=0\n", 499 | " mem1 = self.lif1.init_leaky()\n", 500 | " mem2 = self.lif2.init_leaky()\n", 501 | "\n", 502 | " # Record the final layer\n", 503 | " spk2_rec = []\n", 504 | " mem2_rec = []\n", 505 | "\n", 506 | " for step in range(num_steps):\n", 507 | " cur1 = self.fc1(x.flatten(1))\n", 508 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 509 | " cur2 = self.fc2(spk1)\n", 510 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 511 | " spk2_rec.append(spk2)\n", 512 | " mem2_rec.append(mem2)\n", 513 | "\n", 514 | " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n", 515 | "\n", 516 | "# Load the network onto CUDA if available\n", 517 | "sqnet = SquatNet().to(device)" 518 | ] 519 | }, 520 | { 521 | "cell_type": "code", 522 | "execution_count": null, 523 | "metadata": { 524 | "id": "0zfo0h6AQDzS" 525 | }, 526 | "outputs": [], 527 | "source": [ 528 | "training_loop(sqnet, train_loader)\n", 529 | "print(f\"Test set accuracy: {measure_accuracy(sqnet, test_loader)}\")" 530 | ] 531 | }, 532 | { 533 | "cell_type": "markdown", 534 | "metadata": { 535 | "id": "EZAFVgHW4M2g" 536 | }, 537 | "source": [ 538 | "# 3. Handling Neuromorphic Data with Tonic" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": null, 544 | "metadata": { 545 | "id": "CJvN9EH7b7XL" 546 | }, 547 | "outputs": [], 548 | "source": [ 549 | "!pip install tonic --quiet" 550 | ] 551 | }, 552 | { 553 | "cell_type": "markdown", 554 | "metadata": { 555 | "id": "zB9oK66mcCYt" 556 | }, 557 | "source": [ 558 | "## 3.1 PokerDVS Dataset\n", 559 | "\n", 560 | "The dataset used in this tutorial is POKERDVS by T. Serrano-Gotarredona and B. Linares-Barranco:\n", 561 | "\n", 562 | "```\n", 563 | "Serrano-Gotarredona, Teresa, and Bernabé Linares-Barranco. \"Poker-DVS and MNIST-DVS. Their history, how they were made, and other details.\" Frontiers in neuroscience 9 (2015): 481.\n", 564 | "```\n", 565 | "\n", 566 | "It is comprised of four classes, each being a suite of a playing card deck: clubs, spades, hearts, and diamonds. The data consists of 131 poker pip symbols, and was collected by flipping poker cards in front of a DVS128 camera." 
567 | ]
568 | },
569 | {
570 | "cell_type": "code",
571 | "execution_count": null,
572 | "metadata": {
573 | "id": "s-TSV9h2hGkW"
574 | },
575 | "outputs": [],
576 | "source": [
577 | "import tonic\n",
578 | "\n",
579 | "poker_train = tonic.datasets.POKERDVS(save_to='./data', train=True)\n",
580 | "poker_test = tonic.datasets.POKERDVS(save_to='./data', train=False)\n",
581 | "\n",
582 | "events, target = poker_train[0]\n",
583 | "print(events)\n",
584 | "tonic.utils.plot_event_grid(events)"
585 | ]
586 | },
587 | {
588 | "cell_type": "code",
589 | "execution_count": null,
590 | "metadata": {
591 | "id": "Tcz7CE4wb6Gs"
592 | },
593 | "outputs": [],
594 | "source": [
595 | "import tonic.transforms as transforms\n",
596 | "from tonic import DiskCachedDataset\n",
597 | "\n",
598 | "# time_window\n",
599 | "frame_transform = tonic.transforms.Compose([tonic.transforms.Denoise(filter_time=10000),\n",
600 | "                                            tonic.transforms.ToFrame(\n",
601 | "                                                sensor_size=tonic.datasets.POKERDVS.sensor_size,\n",
602 | "                                                time_window=1000)\n",
603 | "                                            ])\n",
604 | "\n",
605 | "batch_size = 8\n",
606 | "cached_trainset = DiskCachedDataset(poker_train, transform=frame_transform, cache_path='./cache/pokerdvs/train')\n",
607 | "cached_testset = DiskCachedDataset(poker_test, transform=frame_transform, cache_path='./cache/pokerdvs/test')\n",
608 | "\n",
609 | "train_loader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n",
610 | "test_loader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n",
611 | "\n",
612 | "data, labels = next(iter(train_loader))\n",
613 | "print(data.size())\n",
614 | "print(labels)"
615 | ]
616 | },
617 | {
618 | "cell_type": "markdown",
619 | "metadata": {
620 | "id": "p6h_YU0zhrs3"
621 | },
622 | "source": [
623 | "## 3.2 Construct Model"
624 | ]
625 | },
626 | {
627 | "cell_type": "code",
628 | "execution_count": null,
629 | "metadata": {
630 | "id": "-HEhzeA02Ueg"
631 | },
632 | "outputs": [],
633 | "source": [
634 | "import torch.nn.functional as F\n",
635 | "\n",
636 | "# Define Network\n",
637 | "class DVSNet(nn.Module):\n",
638 | "    def __init__(self):\n",
639 | "        super().__init__()\n",
640 | "\n",
641 | "        beta = 0.9\n",
642 | "\n",
643 | "        # Initialize layers\n",
644 | "        self.conv1 =\n",
645 | "        self.mp1 =\n",
646 | "        self.lif1 =\n",
647 | "        self.conv2 =\n",
648 | "        self.mp2 =\n",
649 | "        self.lif2 =\n",
650 | "        self.fc =\n",
651 | "        self.lif3 =\n",
652 | "\n",
653 | "\n",
654 | "    def forward(self, x):\n",
655 | "\n",
656 | "        # Initialize hidden states at t=0\n",
657 | "        mem1 = self.lif1.init_leaky()\n",
658 | "        mem2 = self.lif2.init_leaky()\n",
659 | "        mem3 = self.lif3.init_leaky()\n",
660 | "\n",
661 | "        # Record the final layer\n",
662 | "        spk3_rec = []\n",
663 | "        mem3_rec = []\n",
664 | "\n",
665 | "        for step in range(...):\n",
666 | "            cur1 =\n",
667 | "            spk1, mem1 =\n",
668 | "            cur2 =\n",
669 | "            spk2, mem2 =\n",
670 | "            cur3 =\n",
671 | "            spk3, mem3 =\n",
672 | "\n",
673 | "            spk3_rec.append(spk3)\n",
674 | "            mem3_rec.append(mem3)\n",
675 | "\n",
676 | "        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)\n",
677 | "\n",
678 | "# Load the network onto CUDA if available\n",
679 | "dvsnet = DVSNet().to(device)"
680 | ]
681 | },
682 | {
683 | "cell_type": "code",
684 | "execution_count": null,
685 | "metadata": {
686 | "id": "V5havFWd3c-F"
687 | },
688 | "outputs": [],
689 | "source": [
690 | "training_loop(dvsnet, train_loader, num_epochs=10)\n",
691 | "print(f\"Test set accuracy: {measure_accuracy(dvsnet, test_loader)}\")"
692 | ]
693 | },
694 | {
695 | "cell_type": "markdown",
696 | "metadata": {
697 | "id": "pNrFs3ro-xUm"
698 | },
699 | "source": [
700 | "That's all folks!"
701 | ]
702 | }
703 | ],
704 | "metadata": {
705 | "accelerator": "GPU",
706 | "colab": {
707 | "gpuType": "T4",
708 | "include_colab_link": true,
709 | "provenance": []
710 | },
711 | "kernelspec": {
712 | "display_name": "Python 3",
713 | "name": "python3"
714 | },
715 | "language_info": {
716 | "name": "python"
717 | }
718 | },
719 | "nbformat": 4,
720 | "nbformat_minor": 0
721 | }
722 |
--------------------------------------------------------------------------------
/software/ISFPGA_SNN_cheatsheet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github",
7 | "colab_type": "text"
8 | },
9 | "source": [
10 | "<a href=\"https://colab.research.google.com/github/open-neuromorphic/fpga-snntorch/blob/main/software/ISFPGA_SNN_cheatsheet.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11 | ]
12 | },
13 | {
14 | "cell_type": "markdown",
15 | "metadata": {
16 | "id": "8w6lhn7H8fW5"
17 | },
18 | "source": [
19 | "# ISFPGA Workshop\n",
20 | "## Who needs neuromorphic hardware? Deploying SNNs to FPGAs via HLS\n",
21 | "### By Jason K. Eshraghian (www.ncg.ucsc.edu)\n",
22 | "\n",
23 | "\n",
24 | "[snnTorch](https://github.com/jeshraghian/snntorch/)"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "metadata": {
31 | "id": "1BlTqunB73-t"
32 | },
33 | "outputs": [],
34 | "source": [
35 | "!pip install snntorch --quiet # shift + enter"
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {
41 | "id": "xLlnINAI9mgJ"
42 | },
43 | "source": [
44 | "*What will I learn?*\n",
45 | "\n",
46 | "1. Train an SNN classifier using snnTorch\n",
47 | "2. Hardware Friendly Training\n",
48 | "  - Weight Quantization with Brevitas\n",
49 | "  - Stateful Quantization\n",
50 | "3. Handling neuromorphic data with Tonic"
51 | ]
52 | },
53 | {
54 | "cell_type": "markdown",
55 | "metadata": {
56 | "id": "-CS50cwuCW6n"
57 | },
58 | "source": [
59 | "# 1. 
Train an SNN Classifier using snnTorch\n", 60 | "## 1.1 Imports\n" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": null, 66 | "metadata": { 67 | "id": "H_TzogsCCcSe" 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# snntorch imports\n", 72 | "import snntorch as snn\n", 73 | "from snntorch import functional as SF\n", 74 | "\n", 75 | "# pytorch imports\n", 76 | "import torch\n", 77 | "import torch.nn as nn\n", 78 | "from torch.utils.data import DataLoader\n", 79 | "from torchvision import datasets, transforms\n", 80 | "\n", 81 | "# data manipulation\n", 82 | "import numpy as np\n", 83 | "import itertools\n", 84 | "\n", 85 | "# plotting\n", 86 | "import matplotlib.pyplot as plt\n", 87 | "from IPython.display import HTML" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": { 93 | "id": "nftOdpyAGv7D" 94 | }, 95 | "source": [ 96 | "## 1.2 Boilerplate: DataLoading the MNIST Dataset" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": { 103 | "id": "SsM2Z5NXGu5z" 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "# dataloader arguments\n", 108 | "batch_size = 128\n", 109 | "data_path='/data/mnist'\n", 110 | "\n", 111 | "dtype = torch.float\n", 112 | "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", 113 | "## if you're on M1 or M2 GPU:\n", 114 | "# device = torch.device(\"mps\")" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": null, 120 | "metadata": { 121 | "id": "XqbYptgDHUPg" 122 | }, 123 | "outputs": [], 124 | "source": [ 125 | "# Define a transform\n", 126 | "transform = transforms.Compose([\n", 127 | " transforms.Resize((28, 28)),\n", 128 | " transforms.Grayscale(),\n", 129 | " transforms.ToTensor(),\n", 130 | " transforms.Normalize((0,), (1,))])\n", 131 | "\n", 132 | "mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n", 133 | "mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "metadata": { 140 | "id": "jSlS3gWZHXI0" 141 | }, 142 | "outputs": [], 143 | "source": [ 144 | "# Create DataLoaders\n", 145 | "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)\n", 146 | "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "metadata": { 152 | "id": "SJVhfNukHbsp" 153 | }, 154 | "source": [ 155 | "## 1.3 Construct SNN Model" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": { 162 | "id": "uu324fr_HhxV" 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "# Network Architecture\n", 167 | "num_inputs = 28*28\n", 168 | "num_hidden = 100\n", 169 | "num_outputs = 10\n", 170 | "\n", 171 | "# Temporal Dynamics\n", 172 | "num_steps = 25\n", 173 | "beta = 0.95" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": { 180 | "id": "CkM1Z1EjHeW8" 181 | }, 182 | "outputs": [], 183 | "source": [ 184 | "from snntorch import surrogate\n", 185 | "\n", 186 | "# Define Network\n", 187 | "class Net(nn.Module):\n", 188 | " def __init__(self):\n", 189 | " super().__init__()\n", 190 | "\n", 191 | " # Initialize layers\n", 192 | " self.fc1 = nn.Linear(num_inputs, num_hidden)\n", 193 | " self.lif1 = snn.Leaky(beta=beta)\n", 194 | " self.fc2 = 
nn.Linear(num_hidden, num_outputs)\n", 195 | " self.lif2 = snn.Leaky(beta=beta)\n", 196 | "\n", 197 | " def forward(self, x):\n", 198 | "\n", 199 | " # Initialize hidden states at t=0\n", 200 | " mem1 = self.lif1.init_leaky()\n", 201 | " mem2 = self.lif2.init_leaky()\n", 202 | "\n", 203 | " # Record the final layer\n", 204 | " spk2_rec = []\n", 205 | " mem2_rec = []\n", 206 | "\n", 207 | " # time-loop\n", 208 | " for step in range(num_steps):\n", 209 | " cur1 = self.fc1(x.flatten(1)) # batch: 128 x 784\n", 210 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 211 | " cur2 = self.fc2(spk1)\n", 212 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 213 | "\n", 214 | " # store in list\n", 215 | " spk2_rec.append(spk2)\n", 216 | " mem2_rec.append(mem2)\n", 217 | "\n", 218 | " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0) # time-steps x batch x num_out\n", 219 | "\n", 220 | "# Load the network onto CUDA if available\n", 221 | "net = Net().to(device)" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "metadata": { 227 | "id": "p8qBw03rHpn3" 228 | }, 229 | "source": [ 230 | "## 1.4 Training the SNN" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "metadata": { 237 | "id": "1telBMU-HrIg" 238 | }, 239 | "outputs": [], 240 | "source": [ 241 | "def training_loop(model, dataloader, num_epochs=1):\n", 242 | " loss = nn.CrossEntropyLoss()\n", 243 | " optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999))\n", 244 | " counter = 0\n", 245 | "\n", 246 | " # Outer training loop\n", 247 | " for epoch in range(num_epochs):\n", 248 | " train_batch = iter(dataloader)\n", 249 | "\n", 250 | " # Minibatch training loop\n", 251 | " for data, targets in train_batch:\n", 252 | " data = data.to(device)\n", 253 | " targets = targets.to(device)\n", 254 | "\n", 255 | " # forward pass\n", 256 | " model.train()\n", 257 | " spk_rec, _ = model(data)\n", 258 | "\n", 259 | " # initialize the loss & sum over time\n", 260 | " loss_val = torch.zeros((1), dtype=dtype, device=device)\n", 261 | " loss_val = loss(spk_rec.sum(0), targets) # batch x num_out\n", 262 | "\n", 263 | " # Gradient calculation + weight update\n", 264 | " optimizer.zero_grad()\n", 265 | " loss_val.backward()\n", 266 | " optimizer.step()\n", 267 | "\n", 268 | " # Print train/test loss/accuracy\n", 269 | " if counter % 10 == 0:\n", 270 | " print(f\"Iteration: {counter} \\t Train Loss: {loss_val.item()}\")\n", 271 | " counter += 1\n", 272 | "\n", 273 | " if counter == 100:\n", 274 | " break\n", 275 | "\n", 276 | "training_loop(net, train_loader)" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": { 283 | "id": "nHsdkuSlIS1E" 284 | }, 285 | "outputs": [], 286 | "source": [ 287 | "def measure_accuracy(model, dataloader):\n", 288 | " with torch.no_grad():\n", 289 | " model.eval()\n", 290 | " running_length = 0\n", 291 | " running_accuracy = 0\n", 292 | "\n", 293 | " for data, targets in iter(dataloader):\n", 294 | " data = data.to(device)\n", 295 | " targets = targets.to(device)\n", 296 | "\n", 297 | " # forward-pass\n", 298 | " spk_rec, _ = model(data)\n", 299 | " spike_count = spk_rec.sum(0) # batch x num_outputs\n", 300 | " _, max_spike = spike_count.max(1)\n", 301 | "\n", 302 | " # correct classes for one batch\n", 303 | " num_correct = (max_spike == targets).sum()\n", 304 | "\n", 305 | " # total accuracy\n", 306 | " running_length += len(targets)\n", 307 | " running_accuracy += num_correct\n", 308 | "\n", 309 | " accuracy = 
(running_accuracy / running_length)\n",
310 | "\n",
311 | "    return accuracy.item()\n"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": null,
317 | "metadata": {
318 | "id": "oJHAltRCKGyx"
319 | },
320 | "outputs": [],
321 | "source": [
322 | "print(f\"Test set accuracy: {measure_accuracy(net, test_loader)}\")"
323 | ]
324 | },
325 | {
326 | "cell_type": "markdown",
327 | "metadata": {
328 | "id": "lGMrI2_ga_CA"
329 | },
330 | "source": [
331 | "### A Sanity Check"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": null,
337 | "metadata": {
338 | "id": "zphRrbBQbAjp"
339 | },
340 | "outputs": [],
341 | "source": [
342 | "def print_sample(model, dataloader, idx=0):\n",
343 | "    with torch.no_grad():\n",
344 | "        model.eval()\n",
345 | "\n",
346 | "        data, targets = next(iter(dataloader))\n",
347 | "        data = data.to(device)\n",
348 | "        targets = targets.to(device)\n",
349 | "\n",
350 | "        # forward-pass\n",
351 | "        spk_rec, _ = model(data)\n",
352 | "        spike_count = spk_rec.sum(0) # batch x num_outputs\n",
353 | "        _, max_spike = spike_count.max(1)\n",
354 | "\n",
355 | "        # Plot the sample\n",
356 | "        plt.imshow(data[idx].cpu().squeeze(), cmap='gray')\n",
357 | "        plt.title(f'Target: {targets[idx].item()}')\n",
358 | "        plt.show()\n",
359 | "\n",
360 | "\n",
361 | "    return"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": null,
367 | "metadata": {
368 | "id": "gFjUZ0MebB1A"
369 | },
370 | "outputs": [],
371 | "source": [
372 | "print_sample(net, test_loader)"
373 | ]
374 | },
375 | {
376 | "cell_type": "markdown",
377 | "metadata": {
378 | "id": "Z9vrb2zUD6S-"
379 | },
380 | "source": [
381 | "# 2. Hardware Friendly Training\n",
382 | "## 2.1 Weight Quantization"
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": null,
388 | "metadata": {
389 | "id": "icK4WzuL-2QA"
390 | },
391 | "outputs": [],
392 | "source": [
393 | "!pip install brevitas --quiet"
394 | ]
395 | },
396 | {
397 | "cell_type": "markdown",
398 | "metadata": {
399 | "id": "idiLnVjJGJAL"
400 | },
401 | "source": [
402 | "Just replace all `nn.Linear` layers with `qnn.QuantLinear(num_inputs, num_outputs, weight_bit_width, bias)`."
403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": null, 408 | "metadata": { 409 | "id": "PiZYTMA6D-YL" 410 | }, 411 | "outputs": [], 412 | "source": [ 413 | "import brevitas.nn as qnn\n", 414 | "\n", 415 | "# Define Network\n", 416 | "class QuantNet(nn.Module):\n", 417 | " def __init__(self):\n", 418 | " super().__init__()\n", 419 | "\n", 420 | " # Initialize layers\n", 421 | " self.fc1 = qnn.QuantLinear(num_inputs, num_hidden, weight_bit_width=8, bias=False)\n", 422 | " self.lif1 = snn.Leaky(beta=beta)\n", 423 | " self.fc2 = qnn.QuantLinear(num_hidden, num_outputs, weight_bit_width=8, bias=False)\n", 424 | " self.lif2 = snn.Leaky(beta=beta)\n", 425 | "\n", 426 | " def forward(self, x):\n", 427 | "\n", 428 | " # Initialize hidden states at t=0\n", 429 | " mem1 = self.lif1.init_leaky()\n", 430 | " mem2 = self.lif2.init_leaky()\n", 431 | "\n", 432 | " # Record the final layer\n", 433 | " spk2_rec = []\n", 434 | " mem2_rec = []\n", 435 | "\n", 436 | " for step in range(num_steps):\n", 437 | " cur1 = self.fc1(x.flatten(1))\n", 438 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 439 | " cur2 = self.fc2(spk1)\n", 440 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 441 | " spk2_rec.append(spk2)\n", 442 | " mem2_rec.append(mem2)\n", 443 | "\n", 444 | " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n", 445 | "\n", 446 | "# Load the network onto CUDA if available\n", 447 | "qnet = QuantNet().to(device)" 448 | ] 449 | }, 450 | { 451 | "cell_type": "code", 452 | "execution_count": null, 453 | "metadata": { 454 | "id": "PocSa27MOOoK" 455 | }, 456 | "outputs": [], 457 | "source": [ 458 | "training_loop(qnet, train_loader)\n", 459 | "print(f\"Test set accuracy: {measure_accuracy(qnet, test_loader)}\")" 460 | ] 461 | }, 462 | { 463 | "cell_type": "markdown", 464 | "metadata": { 465 | "id": "JvRzLHdjD-6S" 466 | }, 467 | "source": [ 468 | "## 2.2 SQUAT: Stateful Quantization-Aware Training\n", 469 | "\n" 470 | ] 471 | }, 472 | { 473 | "cell_type": "code", 474 | "execution_count": null, 475 | "metadata": { 476 | "id": "tAaafUczEAu_" 477 | }, 478 | "outputs": [], 479 | "source": [ 480 | "from snntorch.functional import quant\n", 481 | "\n", 482 | "# Define Network\n", 483 | "class SquatNet(nn.Module):\n", 484 | " def __init__(self):\n", 485 | " super().__init__()\n", 486 | "\n", 487 | " q_lif = quant.state_quant(num_bits=4, uniform=True)\n", 488 | "\n", 489 | " # Initialize layers\n", 490 | " self.fc1 = qnn.QuantLinear(num_inputs, num_hidden, weight_bit_width=8, bias=False)\n", 491 | " self.lif1 = snn.Leaky(beta=beta, state_quant=q_lif)\n", 492 | " self.fc2 = qnn.QuantLinear(num_hidden, num_outputs, weight_bit_width=8, bias=False)\n", 493 | " self.lif2 = snn.Leaky(beta=beta, state_quant=q_lif)\n", 494 | "\n", 495 | " def forward(self, x):\n", 496 | "\n", 497 | " # Initialize hidden states at t=0\n", 498 | " mem1 = self.lif1.init_leaky()\n", 499 | " mem2 = self.lif2.init_leaky()\n", 500 | "\n", 501 | " # Record the final layer\n", 502 | " spk2_rec = []\n", 503 | " mem2_rec = []\n", 504 | "\n", 505 | " for step in range(num_steps):\n", 506 | " cur1 = self.fc1(x.flatten(1))\n", 507 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 508 | " cur2 = self.fc2(spk1)\n", 509 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 510 | " spk2_rec.append(spk2)\n", 511 | " mem2_rec.append(mem2)\n", 512 | "\n", 513 | " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n", 514 | "\n", 515 | "# Load the network onto CUDA if available\n", 516 | "sqnet = SquatNet().to(device)" 517 | ] 518 | }, 
519 | {
520 | "cell_type": "code",
521 | "execution_count": null,
522 | "metadata": {
523 | "id": "0zfo0h6AQDzS"
524 | },
525 | "outputs": [],
526 | "source": [
527 | "training_loop(sqnet, train_loader)\n",
528 | "print(f\"Test set accuracy: {measure_accuracy(sqnet, test_loader)}\")"
529 | ]
530 | },
531 | {
532 | "cell_type": "markdown",
533 | "metadata": {
534 | "id": "EZAFVgHW4M2g"
535 | },
536 | "source": [
537 | "# 3. Handling Neuromorphic Data with Tonic"
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": null,
543 | "metadata": {
544 | "id": "g-sCKamvcI73"
545 | },
546 | "outputs": [],
547 | "source": [
548 | "!pip install tonic --quiet"
549 | ]
550 | },
551 | {
552 | "cell_type": "markdown",
553 | "metadata": {
554 | "id": "7ovHbF2DcJtv"
555 | },
556 | "source": [
557 | "## 3.1 PokerDVS Dataset\n",
558 | "\n",
559 | "The dataset used in this tutorial is POKERDVS by T. Serrano-Gotarredona and B. Linares-Barranco:\n",
560 | "\n",
561 | "```\n",
562 | "Serrano-Gotarredona, Teresa, and Bernabé Linares-Barranco. \"Poker-DVS and MNIST-DVS. Their history, how they were made, and other details.\" Frontiers in neuroscience 9 (2015): 481.\n",
563 | "```\n",
564 | "\n",
565 | "It comprises four classes, one per suit of a playing card deck: clubs, spades, hearts, and diamonds. The data consists of 131 poker pip symbols, and was collected by flipping poker cards in front of a DVS128 camera."
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": null,
571 | "metadata": {
572 | "id": "Wya2zvXAhizi"
573 | },
574 | "outputs": [],
575 | "source": [
576 | "import tonic\n",
577 | "\n",
578 | "poker_train = tonic.datasets.POKERDVS(save_to='./data', train=True)\n",
579 | "poker_test = tonic.datasets.POKERDVS(save_to='./data', train=False)\n",
580 | "\n",
581 | "events, target = poker_train[0]\n",
582 | "print(events)\n",
583 | "tonic.utils.plot_event_grid(events)"
584 | ]
585 | },
586 | {
587 | "cell_type": "code",
588 | "execution_count": null,
589 | "metadata": {
590 | "id": "JikbjfRxhkVI"
591 | },
592 | "outputs": [],
593 | "source": [
594 | "import tonic.transforms as transforms\n",
595 | "from tonic import DiskCachedDataset\n",
596 | "\n",
597 | "# time_window\n",
598 | "frame_transform = tonic.transforms.Compose([tonic.transforms.Denoise(filter_time=10000),\n",
599 | "                                            tonic.transforms.ToFrame(\n",
600 | "                                                sensor_size=tonic.datasets.POKERDVS.sensor_size,\n",
601 | "                                                time_window=1000)\n",
602 | "                                            ])\n",
603 | "\n",
604 | "batch_size = 8\n",
605 | "cached_trainset = DiskCachedDataset(poker_train, transform=frame_transform, cache_path='./cache/pokerdvs/train')\n",
606 | "cached_testset = DiskCachedDataset(poker_test, transform=frame_transform, cache_path='./cache/pokerdvs/test')\n",
607 | "\n",
608 | "train_loader = DataLoader(cached_trainset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n",
609 | "test_loader = DataLoader(cached_testset, batch_size=batch_size, collate_fn=tonic.collation.PadTensors(batch_first=False), shuffle=True)\n",
610 | "\n",
611 | "data, labels = next(iter(train_loader))\n",
612 | "print(data.size())\n",
613 | "print(labels)"
614 | ]
615 | },
616 | {
617 | "cell_type": "code",
618 | "execution_count": null,
619 | "metadata": {
620 | "id": "-HEhzeA02Ueg"
621 | },
622 | "outputs": [],
623 | "source": [
624 | "import torch.nn.functional as F\n",
625 | "\n",
626 | "# Define Network\n",
627 | "class DVSNet(nn.Module):\n",
628 | "    def __init__(self):\n",
629 | "        
super().__init__()\n", 630 | "\n", 631 | " beta = 0.9\n", 632 | "\n", 633 | " # Initialize layers\n", 634 | " self.conv1 = nn.Conv2d(2, 12, 5)\n", 635 | " self.mp1 = nn.MaxPool2d(2)\n", 636 | " self.lif1 = snn.Leaky(beta=beta)\n", 637 | " self.conv2 = nn.Conv2d(12, 32, 5)\n", 638 | " self.mp2 = nn.MaxPool2d(2)\n", 639 | " self.lif2 = snn.Leaky(beta=beta)\n", 640 | " self.fc = nn.Linear(32*5*5, 4)\n", 641 | " self.lif3 = snn.Leaky(beta=beta)\n", 642 | "\n", 643 | "\n", 644 | " def forward(self, x):\n", 645 | "\n", 646 | " # Initialize hidden states at t=0\n", 647 | " mem1 = self.lif1.init_leaky()\n", 648 | " mem2 = self.lif2.init_leaky()\n", 649 | " mem3 = self.lif3.init_leaky()\n", 650 | "\n", 651 | " # Record the final layer\n", 652 | " spk3_rec = []\n", 653 | " mem3_rec = []\n", 654 | "\n", 655 | " for step in range(x.size(0)):\n", 656 | " cur1 = self.mp1(self.conv1(x[step]))\n", 657 | " spk1, mem1 = self.lif1(cur1, mem1)\n", 658 | " cur2 = self.mp2(self.conv2(spk1))\n", 659 | " spk2, mem2 = self.lif2(cur2, mem2)\n", 660 | " cur3 = self.fc(spk2.flatten(1))\n", 661 | " spk3, mem3 = self.lif3(cur3, mem3)\n", 662 | "\n", 663 | " spk3_rec.append(spk3)\n", 664 | " mem3_rec.append(mem3)\n", 665 | "\n", 666 | " return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=0)\n", 667 | "\n", 668 | "# Load the network onto CUDA if available\n", 669 | "dvsnet = DVSNet().to(device)" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": null, 675 | "metadata": { 676 | "id": "V5havFWd3c-F" 677 | }, 678 | "outputs": [], 679 | "source": [ 680 | "training_loop(dvsnet, train_loader, num_epochs=10)\n", 681 | "print(f\"Test set accuracy: {measure_accuracy(dvsnet, test_loader)}\")" 682 | ] 683 | }, 684 | { 685 | "cell_type": "markdown", 686 | "metadata": { 687 | "id": "pNrFs3ro-xUm" 688 | }, 689 | "source": [ 690 | "That's all folks!" 691 | ] 692 | } 693 | ], 694 | "metadata": { 695 | "accelerator": "GPU", 696 | "colab": { 697 | "gpuType": "T4", 698 | "provenance": [], 699 | "include_colab_link": true 700 | }, 701 | "kernelspec": { 702 | "display_name": "Python 3", 703 | "name": "python3" 704 | }, 705 | "language_info": { 706 | "name": "python" 707 | } 708 | }, 709 | "nbformat": 4, 710 | "nbformat_minor": 0 711 | } --------------------------------------------------------------------------------