├── .gitignore ├── README.md ├── notebooks ├── .gitignore ├── SpyTorchTutorial1.ipynb ├── SpyTorchTutorial2.ipynb ├── SpyTorchTutorial3.ipynb ├── SpyTorchTutorial4.ipynb ├── figures │ ├── .gitignore │ ├── mlp_sketch │ │ ├── Makefile │ │ ├── mlp_sketch.png │ │ └── mlp_sketch.tex │ ├── snn_graph │ │ ├── Makefile │ │ ├── snn_graph.png │ │ └── snn_graph.tex │ └── surrgrad │ │ ├── Makefile │ │ ├── surrgrad.gnu │ │ └── surrgrad.png └── utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # ---> Python 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | env/ 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *,cover 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # PyBuilder 59 | target/ 60 | 61 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SpyTorch 2 | A tutorial on surrogate gradient learning in spiking neural networks 3 | 4 | Version: 0.4 5 | 6 | [![DOI](https://zenodo.org/badge/170391179.svg)](https://zenodo.org/badge/latestdoi/170391179) 7 | 8 | This repository contains tutorial files to get you started with the basic ideas 9 | of surrogate gradient learning in spiking neural networks using PyTorch. 10 | 11 | Feedback and contributions are welcome. 12 | 13 | For more information on surrogate gradient learning, please refer to: 14 | > Neftci, E.O., Mostafa, H., and Zenke, F. (2019). Surrogate Gradient Learning in Spiking Neural Networks: Bringing the Power of Gradient-Based Optimization to Spiking Neural Networks. IEEE Signal Processing Magazine 36, 51–63. 15 | > https://ieeexplore.ieee.org/document/8891809 16 | > preprint: https://arxiv.org/abs/1901.09948 17 | 18 | 19 | Also see https://github.com/surrogate-gradient-learning 20 | 21 | ## Copyright and license 22 | 23 | Copyright 2019-2020 Friedemann Zenke, https://fzenke.net 24 | 25 | This work is licensed under a Creative Commons Attribution 4.0 International License.
26 | http://creativecommons.org/licenses/by/4.0/ 27 | -------------------------------------------------------------------------------- /notebooks/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | .pytest_cache/ 51 | 52 | # Translations 53 | *.mo 54 | *.pot 55 | 56 | # Django stuff: 57 | *.log 58 | local_settings.py 59 | db.sqlite3 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # IPython 78 | profile_default/ 79 | ipython_config.py 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # SageMath parsed files 88 | *.sage.py 89 | 90 | # Environments 91 | .env 92 | .venv 93 | env/ 94 | venv/ 95 | ENV/ 96 | env.bak/ 97 | venv.bak/ 98 | 99 | # Spyder project settings 100 | .spyderproject 101 | .spyproject 102 | 103 | # Rope project settings 104 | .ropeproject 105 | 106 | # mkdocs documentation 107 | 
/site 108 | 109 | # mypy 110 | .mypy_cache/ 111 | .dmypy.json 112 | dmypy.json 113 | 114 | # Pyre type checker 115 | .pyre/ 116 | 117 | -------------------------------------------------------------------------------- /notebooks/SpyTorchTutorial3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Tutorial 3: Training a spiking neural network on a simple vision dataset\n", 8 | "\n", 9 | "Friedemann Zenke (https://fzenke.net)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "> For more details on surrogate gradient learning, please see: \n", 17 | "> Neftci, E.O., Mostafa, H., and Zenke, F. (2019). Surrogate Gradient Learning in Spiking Neural Networks.\n", 18 | "> https://arxiv.org/abs/1901.09948" 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "In Tutorial 2, we saw how to train a simple multi-layer spiking neural network on the [Fashion MNIST dataset](https://github.com/zalandoresearch/fashion-mnist). However, the spiking activity in the hidden layer was not particularly biologically plausible. Here we modify the network from that tutorial by adding an activity regularizer, which encourages solutions with sparse spiking."
26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 1, 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "import os\n", 35 | "\n", 36 | "import numpy as np\n", 37 | "import matplotlib.pyplot as plt\n", 38 | "from matplotlib.gridspec import GridSpec\n", 39 | "import seaborn as sns\n", 40 | "\n", 41 | "import torch\n", 42 | "import torch.nn as nn\n", 43 | "import torchvision" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "# The coarse network structure is dicated by the Fashion MNIST dataset. \n", 53 | "nb_inputs = 28*28\n", 54 | "nb_hidden = 100\n", 55 | "nb_outputs = 10\n", 56 | "\n", 57 | "time_step = 1e-3\n", 58 | "nb_steps = 100\n", 59 | "\n", 60 | "batch_size = 256" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 3, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "dtype = torch.float\n", 70 | "\n", 71 | "# Check whether a GPU is available\n", 72 | "if torch.cuda.is_available():\n", 73 | " device = torch.device(\"cuda\") \n", 74 | "else:\n", 75 | " device = torch.device(\"cpu\")" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 4, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "# Here we load the Dataset\n", 85 | "root = os.path.expanduser(\"~/data/datasets/torch/fashion-mnist\")\n", 86 | "train_dataset = torchvision.datasets.FashionMNIST(root, train=True, transform=None, target_transform=None, download=True)\n", 87 | "test_dataset = torchvision.datasets.FashionMNIST(root, train=False, transform=None, target_transform=None, download=True)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 5, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# Standardize data\n", 97 | "# x_train = torch.tensor(train_dataset.train_data, device=device, dtype=dtype)\n", 98 | "x_train = np.array(train_dataset.data, dtype=np.float)\n", 99 | "x_train 
= x_train.reshape(x_train.shape[0],-1)/255\n", 100 | "# x_test = torch.tensor(test_dataset.test_data, device=device, dtype=dtype)\n", 101 | "x_test = np.array(test_dataset.data, dtype=np.float)\n", 102 | "x_test = x_test.reshape(x_test.shape[0],-1)/255\n", 103 | "\n", 104 | "# y_train = torch.tensor(train_dataset.train_labels, device=device, dtype=dtype)\n", 105 | "# y_test = torch.tensor(test_dataset.test_labels, device=device, dtype=dtype)\n", 106 | "y_train = np.array(train_dataset.targets, dtype=np.int)\n", 107 | "y_test = np.array(test_dataset.targets, dtype=np.int)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 6, 113 | "metadata": {}, 114 | "outputs": [ 115 | { 116 | "data": { 117 | "text/plain": [ 118 | "(-0.5, 27.5, 27.5, -0.5)" 119 | ] 120 | }, 121 | "execution_count": 6, 122 | "metadata": {}, 123 | "output_type": "execute_result" 124 | }, 125 | { 126 | "data": { 127 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAJqElEQVR4nO3dv0vWaxzG8dvUp57StJ+UWEORENZgg0W0NbjW0ljQ1NAWtPQXBLW0RNAYgUM0tERDREsNWUiJggURpTRoGubv7EznLMfv9RHvvue5ns77tV7c+n3Mqy/44b7vhl+/fiUAfjbU+gEArI5yAqYoJ2CKcgKmKCdgqinIbf+UG/2VuaGh4T96kn8bHh6W+eXLlwuzc+fOybU9PT0yr1QqMm9q0v/kQ0NDhdnDhw/l2gMHDsj86tWrMm9vb5f5H2zVX1benIApygmYopyAKcoJmKKcgCnKCZiinICphmBeWNqcs5Zzyjdv3si8v79f5g8ePJB5Y2OjzGdmZgqzubk5uXZyclLmZerq6pL5hg36//qRkRGZ79mzpzDr6+uTa69cuSLzo0ePyrzGmHMC9YRyAqYoJ2CKcgKmKCdginICpignYKpmc85c379/l/n58+cLs8HBQbk2msG2tLTIvFqtylztqYxmpMvLyzKfnp6W+ebNm2Wuvn/Ze2Tn5+cLs2j+u7i4KPNTp07J/N69ezIvGXNOoJ5QTsAU5QRMUU7AFOUETFFOwFTdjlJOnz4t80+fPhVmO3bskGujkcHPnz9lHo1DcqysrMi8ublZ5tGzK7W89Cp3i+H4+LjMHz9+LPPDhw/LPBOjFKCeUE7AFOUETFFOwBTlBExRTsAU5QRMRVcA1szAwIDM1RwzpZR27txZmEXbriLR9qUvX76se300x4yu8IvmmNHxlUq0LSuasba2tsq8s7OzMIs+dyT63Hfv3pX5zZs3s77/evDmBExRTsAU5QRMUU7AFOUETFFOwBTlBEzZ7ue8ceOGzG/duiVztWczmnlFs8Zo/aVLl2S+d+/ewmzfvn1y7djY2Lq/dkp5+0GjOae62jC
llF6/fi1z9W+6a9cuuXZpaUnm0VGq0Xz448ePMs/Efk6gnlBOwBTlBExRTsAU5QRMUU7AFOUETNnOOU+cOCHzr1+/ynzr1q2FWaVSkWujeV1bW5vMX758KfMnT54UZp8/f5ZrL168KPM7d+7IvLu7W+bqGr5oFrh7926Z9/T0yPzQoUOFWXTtonrulOK9piMjIzJ/9+5dYdbV1SXXrgFzTqCeUE7AFOUETFFOwBTlBExRTsCU7dGYg4ODMo+2Vqk/+y8sLKzrmf42PT2dtb6vr68wi0YGw8PDMo+22p09e1bmjx49KsyiI0WjUUm0ZUwdfzk7OyvXRtv4ojz6fXrx4kVh9htGKavizQmYopyAKcoJmKKcgCnKCZiinIApygmYqtmc8+3btzKPjkJsbGyUuZpzRlufoiv+tm/fLvPI0NBQYbZx40a5dnx8XObXrl2TebBFUG6titaqWeBaqGM9oyNBo9+HhoZVd2X9o1qtyvz58+eF2YULF+Ta9eLNCZiinIApygmYopyAKcoJmKKcgCnKCZiq2Zzz+vXrMo9mjVu2bJF5zt7ATZs2yTw6ZvHVq1cyn5iYKMwmJyfl2uiqu+jI0OjZ1WePrgCcmpqSeX9/v8y/fftWmEVzyOh7R+ujn+vAwIDMy8CbEzBFOQFTlBMwRTkBU5QTMEU5AVOUEzBVsznnyZMnZR7N696/fy9zdbZsNOdUV9GlFJ+Bevz4cZmrvYe556+urKzIPJrnqT2banacUrxPVl3LmJI+//XHjx9ybfS5o72oHR0dMj9z5ozMy8CbEzBFOQFTlBMwRTkBU5QTMEU5AVOUEzDVEMx/9HCohtTev5RSGh0dLcxu374t1z579kzm+/fvl3l0f2d7e3thFu2ZjOZ5ZYpmhdGzRftk1c/tyJEjcu39+/dlbm7VQ3V5cwKmKCdginICpignYIpyAqYoJ2CqZlvGcm3btk3mvb29hVl0zd7Tp09lHl0nt7CwIHO1/Wl5eVmujbaMRaJxiMqj7x197uhYzvn5+cIs2mL4J+LNCZiinIApygmYopyAKcoJmKKcgCnKCZiynXNG87joiMdKpVKYRXPK1tZWmUdHQKqjL9fy/ZXo55LztcuWs91NbbNbi+jfLJrh1uLnypsTMEU5AVOUEzBFOQFTlBMwRTkBU5QTMGU754zmStHeQOXgwYMyj66qi/ZcqhlrJPrcznPO6HNHx34qbW1t616bUjxjjWbTtcCbEzBFOQFTlBMwRTkBU5QTMEU5AVOUEzBlO+eM5MytqtWqXBuda6vOV00pnsGqvai5c8ycc2lTyttzGV3xNzs7K3P1bI5zyLLx5gRMUU7AFOUETFFOwBTlBExRTsAU5QRM1e2cM2ffYnRGae4ZprmzyJyvnTOnTEk/W85zpxT/XNXZsrn3kjqf51uENydginICpignYIpyAqYoJ2CKcgKm6naUUqaxsTGZR9fRRdfNKblbvmoperZoK51aHx1H+ifizQmYopyAKcoJmKKcgCnKCZiinIApygmYqts5Z5lbgHKPYYyuulPbn3LnnGUerRmtjT53dOSo+vq5c062jAH4bSgnYIpyAqYoJ2CKcgKmKCdginICpup2zlmmaB6Xc/1gtD73WM5oHhjtqVRfP9qnGj1bU9P6f92mpqbWvbZe8eYETFFOwBTlBExRTsAU5QRMUU7AFOUETDHnXEXufs5Izp7JSDSLzJk15l5tGK1XM9i5uTm5NsJ+TgC/DeUETFFOwBTlBExRTsAU5QRMMUpZRc4VfmtR5p/1y7wiMHruaCtdtF6NsGZnZ+XaPxFvTsAU5QRMUU7AFOUETFFOwBTlBExRTsBU3c45a7kFKJrnlSl3jpkzw83dMhb93NR2trJnz454cwKmKCdginICpignYIpyAqYoJ2CKcgKm6nbOmXsMo1KpVGSee0yjEl0BWOb1g2v5/kruHFQ9e+6ck6MxAfw2lBMwRTkBU5QTMEU5AVOUEzBFOQFTdTvnrKX
cWaOa90VfOzeP5pg5+0Vzz7VV2M8JwAblBExRTsAU5QRMUU7AFOUETFFOwFTdzjnL3J/X0dEh89HRUZmr81dT0rPGaA65uLi47q+dUvxzU3n0uZaWlmSeg/2cAGxQTsAU5QRMUU7AFOUETFFOwFTdjlLKNDU1JfOZmRmZRyOFiYmJwiwaGUTbrsocZ0SjlOjZOzs7Za6OHP3w4YNcGynzSNCy+D0RgJQS5QRsUU7AFOUETFFOwBTlBExRTsBU3c45y7wC8NixYzLv7u6WeXt7u8xzZpHRvK6lpUXmOdf05WyFSyml5uZmmav5cm9vr1wbcZxjRurviYH/CcoJmKKcgCnKCZiinIApygmYopyAqYacK98AlIc3J2CKcgKmKCdginICpignYIpyAqb+AoBXs/GvnBRPAAAAAElFTkSuQmCC\n", 128 | "text/plain": [ 129 | "
" 130 | ] 131 | }, 132 | "metadata": { 133 | "needs_background": "light" 134 | }, 135 | "output_type": "display_data" 136 | } 137 | ], 138 | "source": [ 139 | "# Here we plot one of the raw data points as an example\n", 140 | "data_id = 1\n", 141 | "plt.imshow(x_train[data_id].reshape(28,28), cmap=plt.cm.gray_r)\n", 142 | "plt.axis(\"off\")" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "metadata": {}, 148 | "source": [ 149 | "Since we are working with spiking neural networks, we ideally want to use a temporal code to make use of spike timing. To that end, we will use a spike latency code to feed spikes to our network." 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 7, 155 | "metadata": {}, 156 | "outputs": [], 157 | "source": [ 158 | "def current2firing_time(x, tau=20, thr=0.2, tmax=1.0, epsilon=1e-7):\n", 159 | " \"\"\" Computes first firing time latency for a current input x assuming the charge time of a current based LIF neuron.\n", 160 | "\n", 161 | " Args:\n", 162 | " x -- The \"current\" values\n", 163 | "\n", 164 | " Keyword args:\n", 165 | " tau -- The membrane time constant of the LIF neuron to be charged\n", 166 | " thr -- The firing threshold value \n", 167 | " tmax -- The maximum time returned \n", 168 | " epsilon -- A generic (small) epsilon > 0\n", 169 | "\n", 170 | " Returns:\n", 171 | " Time to first spike for each \"current\" x\n", 172 | " \"\"\"\n", 173 | " idx = x0.0] = spike_height\n", 282 | " dat = dat.detach().cpu().numpy()\n", 283 | " else:\n", 284 | " dat = mem.detach().cpu().numpy()\n", 285 | " for i in range(np.prod(dim)):\n", 286 | " if i==0: a0=ax=plt.subplot(gs[i])\n", 287 | " else: ax=plt.subplot(gs[i],sharey=a0)\n", 288 | " ax.plot(dat[i])\n", 289 | " ax.axis(\"off\")" 290 | ] 291 | }, 292 | { 293 | "cell_type": "markdown", 294 | "metadata": {}, 295 | "source": [ 296 | "## Training the network" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 11, 302 | 
"metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "class SurrGradSpike(torch.autograd.Function):\n", 306 | " \"\"\"\n", 307 | " Here we implement our spiking nonlinearity which also implements \n", 308 | " the surrogate gradient. By subclassing torch.autograd.Function, \n", 309 | " we will be able to use all of PyTorch's autograd functionality.\n", 310 | " Here we use the normalized negative part of a fast sigmoid \n", 311 | " as this was done in Zenke & Ganguli (2018).\n", 312 | " \"\"\"\n", 313 | " \n", 314 | " scale = 100.0 # controls steepness of surrogate gradient\n", 315 | "\n", 316 | " @staticmethod\n", 317 | " def forward(ctx, input):\n", 318 | " \"\"\"\n", 319 | " In the forward pass we compute a step function of the input Tensor\n", 320 | " and return it. ctx is a context object that we use to stash information which \n", 321 | " we need to later backpropagate our error signals. To achieve this we use the \n", 322 | " ctx.save_for_backward method.\n", 323 | " \"\"\"\n", 324 | " ctx.save_for_backward(input)\n", 325 | " out = torch.zeros_like(input)\n", 326 | " out[input > 0] = 1.0\n", 327 | " return out\n", 328 | "\n", 329 | " @staticmethod\n", 330 | " def backward(ctx, grad_output):\n", 331 | " \"\"\"\n", 332 | " In the backward pass we receive a Tensor we need to compute the \n", 333 | " surrogate gradient of the loss with respect to the input. 
\n", 334 | " Here we use the normalized negative part of a fast sigmoid \n", 335 | " as this was done in Zenke & Ganguli (2018).\n", 336 | " \"\"\"\n", 337 | " input, = ctx.saved_tensors\n", 338 | " grad_input = grad_output.clone()\n", 339 | " grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2\n", 340 | " return grad\n", 341 | " \n", 342 | "# here we overwrite our naive spike function by the \"SurrGradSpike\" nonlinearity which implements a surrogate gradient\n", 343 | "spike_fn = SurrGradSpike.apply" 344 | ] 345 | }, 346 | { 347 | "cell_type": "code", 348 | "execution_count": 12, 349 | "metadata": {}, 350 | "outputs": [], 351 | "source": [ 352 | "def run_snn(inputs):\n", 353 | " h1 = torch.einsum(\"abc,cd->abd\", (inputs, w1))\n", 354 | " syn = torch.zeros((batch_size,nb_hidden), device=device, dtype=dtype)\n", 355 | " mem = torch.zeros((batch_size,nb_hidden), device=device, dtype=dtype)\n", 356 | "\n", 357 | " mem_rec = []\n", 358 | " spk_rec = []\n", 359 | "\n", 360 | " # Compute hidden layer activity\n", 361 | " for t in range(nb_steps):\n", 362 | " mthr = mem-1.0\n", 363 | " out = spike_fn(mthr)\n", 364 | " rst = out.detach() # We do not want to backprop through the reset\n", 365 | "\n", 366 | " new_syn = alpha*syn +h1[:,t]\n", 367 | " new_mem = (beta*mem +syn)*(1.0-rst)\n", 368 | "\n", 369 | " mem_rec.append(mem)\n", 370 | " spk_rec.append(out)\n", 371 | " \n", 372 | " mem = new_mem\n", 373 | " syn = new_syn\n", 374 | "\n", 375 | " mem_rec = torch.stack(mem_rec,dim=1)\n", 376 | " spk_rec = torch.stack(spk_rec,dim=1)\n", 377 | "\n", 378 | " # Readout layer\n", 379 | " h2= torch.einsum(\"abc,cd->abd\", (spk_rec, w2))\n", 380 | " flt = torch.zeros((batch_size,nb_outputs), device=device, dtype=dtype)\n", 381 | " out = torch.zeros((batch_size,nb_outputs), device=device, dtype=dtype)\n", 382 | " out_rec = [out]\n", 383 | " for t in range(nb_steps):\n", 384 | " new_flt = alpha*flt +h2[:,t]\n", 385 | " new_out = beta*out +flt\n", 386 | "\n", 387 | " flt 
= new_flt\n", 388 | " out = new_out\n", 389 | "\n", 390 | " out_rec.append(out)\n", 391 | "\n", 392 | " out_rec = torch.stack(out_rec,dim=1)\n", 393 | " other_recs = [mem_rec, spk_rec]\n", 394 | " return out_rec, other_recs" 395 | ] 396 | }, 397 | { 398 | "cell_type": "code", 399 | "execution_count": 13, 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "def train(x_data, y_data, lr=1e-3, nb_epochs=10):\n", 404 | " params = [w1,w2]\n", 405 | " optimizer = torch.optim.Adamax(params, lr=lr, betas=(0.9,0.999))\n", 406 | "\n", 407 | " log_softmax_fn = nn.LogSoftmax(dim=1)\n", 408 | " loss_fn = nn.NLLLoss()\n", 409 | " \n", 410 | " loss_hist = []\n", 411 | " for e in range(nb_epochs):\n", 412 | " local_loss = []\n", 413 | " for x_local, y_local in sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs):\n", 414 | " output,recs = run_snn(x_local.to_dense())\n", 415 | " _,spks=recs\n", 416 | " m,_=torch.max(output,1)\n", 417 | " log_p_y = log_softmax_fn(m)\n", 418 | " \n", 419 | " # Here we set up our regularizer loss\n", 420 | " # The strength paramters here are merely a guess and there should be ample room for improvement by\n", 421 | " # tuning these paramters.\n", 422 | " reg_loss = 1e-5*torch.sum(spks) # L1 loss on total number of spikes\n", 423 | " reg_loss += 1e-5*torch.mean(torch.sum(torch.sum(spks,dim=0),dim=0)**2) # L2 loss on spikes per neuron\n", 424 | " \n", 425 | " # Here we combine supervised loss and the regularizer\n", 426 | " loss_val = loss_fn(log_p_y, y_local) + reg_loss\n", 427 | "\n", 428 | " optimizer.zero_grad()\n", 429 | " loss_val.backward()\n", 430 | " optimizer.step()\n", 431 | " local_loss.append(loss_val.item())\n", 432 | " mean_loss = np.mean(local_loss)\n", 433 | " print(\"Epoch %i: loss=%.5f\"%(e+1,mean_loss))\n", 434 | " loss_hist.append(mean_loss)\n", 435 | " \n", 436 | " return loss_hist\n", 437 | " \n", 438 | " \n", 439 | "def compute_classification_accuracy(x_data, y_data):\n", 440 | " \"\"\" Computes 
classification accuracy on supplied data in batches. \"\"\"\n", 441 | " accs = []\n", 442 | " for x_local, y_local in sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False):\n", 443 | " output,_ = run_snn(x_local.to_dense())\n", 444 | " m,_= torch.max(output,1) # max over time\n", 445 | " _,am=torch.max(m,1) # argmax over output units\n", 446 | " tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels\n", 447 | " accs.append(tmp)\n", 448 | " return np.mean(accs)" 449 | ] 450 | }, 451 | { 452 | "cell_type": "code", 453 | "execution_count": 14, 454 | "metadata": {}, 455 | "outputs": [ 456 | { 457 | "name": "stdout", 458 | "output_type": "stream", 459 | "text": [ 460 | "Epoch 1: loss=2.10319\n", 461 | "Epoch 2: loss=1.64978\n", 462 | "Epoch 3: loss=1.30702\n", 463 | "Epoch 4: loss=1.08484\n", 464 | "Epoch 5: loss=0.95575\n", 465 | "Epoch 6: loss=0.87375\n", 466 | "Epoch 7: loss=0.81585\n", 467 | "Epoch 8: loss=0.77363\n", 468 | "Epoch 9: loss=0.73897\n", 469 | "Epoch 10: loss=0.70913\n", 470 | "Epoch 11: loss=0.68416\n", 471 | "Epoch 12: loss=0.66362\n", 472 | "Epoch 13: loss=0.64633\n", 473 | "Epoch 14: loss=0.63024\n", 474 | "Epoch 15: loss=0.61729\n", 475 | "Epoch 16: loss=0.60596\n", 476 | "Epoch 17: loss=0.59438\n", 477 | "Epoch 18: loss=0.58355\n", 478 | "Epoch 19: loss=0.57630\n", 479 | "Epoch 20: loss=0.56730\n", 480 | "Epoch 21: loss=0.56177\n", 481 | "Epoch 22: loss=0.55338\n", 482 | "Epoch 23: loss=0.54661\n", 483 | "Epoch 24: loss=0.54120\n", 484 | "Epoch 25: loss=0.53651\n", 485 | "Epoch 26: loss=0.53226\n", 486 | "Epoch 27: loss=0.52634\n", 487 | "Epoch 28: loss=0.52089\n", 488 | "Epoch 29: loss=0.51909\n", 489 | "Epoch 30: loss=0.51457\n" 490 | ] 491 | } 492 | ], 493 | "source": [ 494 | "loss_hist = train(x_train, y_train, lr=2e-4, nb_epochs=30)" 495 | ] 496 | }, 497 | { 498 | "cell_type": "code", 499 | "execution_count": 15, 500 | "metadata": {}, 501 | "outputs": [ 502 | { 503 | "data": { 504 | 
"image/png": "\n", 505 | "text/plain": [ 506 | "
" 507 | ] 508 | }, 509 | "metadata": { 510 | "needs_background": "light" 511 | }, 512 | "output_type": "display_data" 513 | } 514 | ], 515 | "source": [ 516 | "plt.figure(figsize=(3.3,2),dpi=150)\n", 517 | "plt.plot(loss_hist)\n", 518 | "plt.xlabel(\"Epoch\")\n", 519 | "plt.ylabel(\"Loss\")\n", 520 | "sns.despine()" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 16, 526 | "metadata": {}, 527 | "outputs": [ 528 | { 529 | "name": "stdout", 530 | "output_type": "stream", 531 | "text": [ 532 | "Training accuracy: 0.848\n", 533 | "Test accuracy: 0.827\n" 534 | ] 535 | } 536 | ], 537 | "source": [ 538 | "print(\"Training accuracy: %.3f\"%(compute_classification_accuracy(x_train,y_train)))\n", 539 | "print(\"Test accuracy: %.3f\"%(compute_classification_accuracy(x_test,y_test)))" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": 17, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "def get_mini_batch(x_data, y_data, shuffle=False):\n", 549 | " for ret in sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=shuffle):\n", 550 | " return ret " 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 18, 556 | "metadata": {}, 557 | "outputs": [], 558 | "source": [ 559 | "x_batch, y_batch = get_mini_batch(x_test, y_test)\n", 560 | "output, other_recordings = run_snn(x_batch.to_dense())\n", 561 | "mem_rec, spk_rec = other_recordings" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": null, 567 | "metadata": {}, 568 | "outputs": [], 569 | "source": [] 570 | }, 571 | { 572 | "cell_type": "code", 573 | "execution_count": 19, 574 | "metadata": {}, 575 | "outputs": [ 576 | { 577 | "data": { 578 | "image/png": "\n", 579 | "text/plain": [ 580 | "
" 581 | ] 582 | }, 583 | "metadata": { 584 | "needs_background": "light" 585 | }, 586 | "output_type": "display_data" 587 | } 588 | ], 589 | "source": [ 590 | "fig=plt.figure(dpi=100)\n", 591 | "plot_voltage_traces(output)" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": 20, 597 | "metadata": {}, 598 | "outputs": [ 599 | { 600 | "data": { 601 | "image/png": "\n", 602 | "text/plain": [ 603 | "
" 604 | ] 605 | }, 606 | "metadata": { 607 | "needs_background": "light" 608 | }, 609 | "output_type": "display_data" 610 | } 611 | ], 612 | "source": [ 613 | "# Let's plot the hiddden layer spiking activity for some input stimuli\n", 614 | "\n", 615 | "nb_plt = 4\n", 616 | "gs = GridSpec(1,nb_plt)\n", 617 | "fig= plt.figure(figsize=(7,3),dpi=150)\n", 618 | "for i in range(nb_plt):\n", 619 | " plt.subplot(gs[i])\n", 620 | " plt.imshow(spk_rec[i].detach().cpu().numpy().T,cmap=plt.cm.gray_r, origin=\"lower\" )\n", 621 | " if i==0:\n", 622 | " plt.xlabel(\"Time\")\n", 623 | " plt.ylabel(\"Units\")\n", 624 | "\n", 625 | " sns.despine()" 626 | ] 627 | }, 628 | { 629 | "cell_type": "markdown", 630 | "metadata": {}, 631 | "source": [ 632 | "Compared to the hidden layer activity in our previous Tutorial 2, we can now appreciate that spiking in the hidden layer is much sparser.\n", 633 | "\n", 634 | "In the next tutorial notebook, we will apply the same training paradigm to the Heidelberg Digits, a realistic speech dataset generated from spoken digits processed through a plausible cochlea model." 635 | ] 636 | }, 637 | { 638 | "cell_type": "markdown", 639 | "metadata": {}, 640 | "source": [ 641 | "\"Creative
This work is licensed under a Creative Commons Attribution 4.0 International License." 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": null, 647 | "metadata": {}, 648 | "outputs": [], 649 | "source": [] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "metadata": {}, 655 | "outputs": [], 656 | "source": [] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": null, 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [] 664 | } 665 | ], 666 | "metadata": { 667 | "kernelspec": { 668 | "display_name": "Python 3", 669 | "language": "python", 670 | "name": "python3" 671 | }, 672 | "language_info": { 673 | "codemirror_mode": { 674 | "name": "ipython", 675 | "version": 3 676 | }, 677 | "file_extension": ".py", 678 | "mimetype": "text/x-python", 679 | "name": "python", 680 | "nbconvert_exporter": "python", 681 | "pygments_lexer": "ipython3", 682 | "version": "3.6.9" 683 | } 684 | }, 685 | "nbformat": 4, 686 | "nbformat_minor": 2 687 | } 688 | -------------------------------------------------------------------------------- /notebooks/figures/.gitignore: -------------------------------------------------------------------------------- 1 | ## Core latex/pdflatex auxiliary files: 2 | *.aux 3 | *.lof 4 | *.log 5 | *.lot 6 | *.fls 7 | *.out 8 | *.toc 9 | *.fmt 10 | *.fot 11 | *.cb 12 | *.cb2 13 | .*.lb 14 | 15 | ## Intermediate documents: 16 | *.dvi 17 | *.xdv 18 | *-converted-to.* 19 | # these rules might exclude image files for figures etc. 
20 | # *.ps 21 | # *.eps 22 | # *.pdf 23 | 24 | ## Generated if empty string is given at "Please type another file name for output:" 25 | .pdf 26 | 27 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 28 | *.bbl 29 | *.bcf 30 | *.blg 31 | *-blx.aux 32 | *-blx.bib 33 | *.run.xml 34 | 35 | ## Build tool auxiliary files: 36 | *.fdb_latexmk 37 | *.synctex 38 | *.synctex(busy) 39 | *.synctex.gz 40 | *.synctex.gz(busy) 41 | *.pdfsync 42 | 43 | ## Build tool directories for auxiliary files 44 | # latexrun 45 | latex.out/ 46 | 47 | ## Auxiliary and intermediate files from other packages: 48 | # algorithms 49 | *.alg 50 | *.loa 51 | 52 | # achemso 53 | acs-*.bib 54 | 55 | # amsthm 56 | *.thm 57 | 58 | # beamer 59 | *.nav 60 | *.pre 61 | *.snm 62 | *.vrb 63 | 64 | # changes 65 | *.soc 66 | 67 | # comment 68 | *.cut 69 | 70 | # cprotect 71 | *.cpt 72 | 73 | # elsarticle (documentclass of Elsevier journals) 74 | *.spl 75 | 76 | # endnotes 77 | *.ent 78 | 79 | # fixme 80 | *.lox 81 | 82 | # feynmf/feynmp 83 | *.mf 84 | *.mp 85 | *.t[1-9] 86 | *.t[1-9][0-9] 87 | *.tfm 88 | 89 | #(r)(e)ledmac/(r)(e)ledpar 90 | *.end 91 | *.?end 92 | *.[1-9] 93 | *.[1-9][0-9] 94 | *.[1-9][0-9][0-9] 95 | *.[1-9]R 96 | *.[1-9][0-9]R 97 | *.[1-9][0-9][0-9]R 98 | *.eledsec[1-9] 99 | *.eledsec[1-9]R 100 | *.eledsec[1-9][0-9] 101 | *.eledsec[1-9][0-9]R 102 | *.eledsec[1-9][0-9][0-9] 103 | *.eledsec[1-9][0-9][0-9]R 104 | 105 | # glossaries 106 | *.acn 107 | *.acr 108 | *.glg 109 | *.glo 110 | *.gls 111 | *.glsdefs 112 | 113 | # gnuplottex 114 | *-gnuplottex-* 115 | 116 | # gregoriotex 117 | *.gaux 118 | *.gtex 119 | 120 | # htlatex 121 | *.4ct 122 | *.4tc 123 | *.idv 124 | *.lg 125 | *.trc 126 | *.xref 127 | 128 | # hyperref 129 | *.brf 130 | 131 | # knitr 132 | *-concordance.tex 133 | # TODO Comment the next line if you want to keep your tikz graphics files 134 | *.tikz 135 | *-tikzDictionary 136 | 137 | # listings 138 | *.lol 139 | 140 | # makeidx 141 | *.idx 142 | *.ilg 143 | *.ind 144 | 
*.ist 145 | 146 | # minitoc 147 | *.maf 148 | *.mlf 149 | *.mlt 150 | *.mtc[0-9]* 151 | *.slf[0-9]* 152 | *.slt[0-9]* 153 | *.stc[0-9]* 154 | 155 | # minted 156 | _minted* 157 | *.pyg 158 | 159 | # morewrites 160 | *.mw 161 | 162 | # nomencl 163 | *.nlg 164 | *.nlo 165 | *.nls 166 | 167 | # pax 168 | *.pax 169 | 170 | # pdfpcnotes 171 | *.pdfpc 172 | 173 | # sagetex 174 | *.sagetex.sage 175 | *.sagetex.py 176 | *.sagetex.scmd 177 | 178 | # scrwfile 179 | *.wrt 180 | 181 | # sympy 182 | *.sout 183 | *.sympy 184 | sympy-plots-for-*.tex/ 185 | 186 | # pdfcomment 187 | *.upa 188 | *.upb 189 | 190 | # pythontex 191 | *.pytxcode 192 | pythontex-files-*/ 193 | 194 | # tcolorbox 195 | *.listing 196 | 197 | # thmtools 198 | *.loe 199 | 200 | # TikZ & PGF 201 | *.dpth 202 | *.md5 203 | *.auxlock 204 | 205 | # todonotes 206 | *.tdo 207 | 208 | # vhistory 209 | *.hst 210 | *.ver 211 | 212 | # easy-todo 213 | *.lod 214 | 215 | # xcolor 216 | *.xcp 217 | 218 | # xmpincl 219 | *.xmpi 220 | 221 | # xindy 222 | *.xdy 223 | 224 | # xypic precompiled matrices 225 | *.xyc 226 | 227 | # endfloat 228 | *.ttt 229 | *.fff 230 | 231 | # Latexian 232 | TSWLatexianTemp* 233 | 234 | ## Editors: 235 | # WinEdt 236 | *.bak 237 | *.sav 238 | 239 | # Texpad 240 | .texpadtmp 241 | 242 | # LyX 243 | *.lyx~ 244 | 245 | # Kile 246 | *.backup 247 | 248 | # KBibTeX 249 | *~[0-9]* 250 | 251 | # auto folder when using emacs and auctex 252 | ./auto/* 253 | *.el 254 | 255 | # expex forward references with \gathertags 256 | *-tags.tex 257 | 258 | # standalone packages 259 | *.sta 260 | 261 | -------------------------------------------------------------------------------- /notebooks/figures/mlp_sketch/Makefile: -------------------------------------------------------------------------------- 1 | all: mlp_sketch.png 2 | 3 | %.png: %.pdf 4 | convert -density 150 $< $@ 5 | 6 | %.pdf: %.tex 7 | pdflatex $< 8 | 9 | -------------------------------------------------------------------------------- 
/notebooks/figures/mlp_sketch/mlp_sketch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/surrogate-gradient-learning/spytorch/f4dcc662b8fe1d23e88790400e58b5c51b2ebc36/notebooks/figures/mlp_sketch/mlp_sketch.png -------------------------------------------------------------------------------- /notebooks/figures/mlp_sketch/mlp_sketch.tex: -------------------------------------------------------------------------------- 1 | \documentclass{standalone} 2 | 3 | 4 | \usepackage{tikz} 5 | \usepackage{verbatim} 6 | \usepackage{tikz,graphicx} 7 | \usepackage{sfmath} 8 | \usepackage{xifthen} 9 | \usetikzlibrary{positioning,arrows,calc} 10 | % \usetikzlibrary{spy} 11 | \renewcommand\familydefault\sfdefault 12 | 13 | \usepackage{xcolor} 14 | 15 | \definecolor{gnu-violet}{HTML}{9400D3}% 16 | \definecolor{gnu-green}{HTML}{009E73}% 17 | \definecolor{gnu-blue}{HTML}{56b4e9}% 18 | \definecolor{gnu-yellow}{HTML}{E69F00}% 19 | 20 | 21 | \begin{document} 22 | 23 | \pagestyle{empty} 24 | 25 | \def\layersep{1.2cm} 26 | \def\unitsep{0.6cm} 27 | \def\nbhidden{1} 28 | 29 | \begin{tikzpicture}[shorten >=1pt,->,draw=black!50] 30 | 31 | \tikzstyle{every pin edge}=[<-,shorten <=1pt] 32 | \tikzstyle{input neuron}=[circle, fill=black!25, minimum size=15pt, inner sep=0pt] 33 | \tikzstyle{neuron}=[circle,fill=black!25,minimum size=12pt,inner sep=0pt] 34 | \tikzstyle{annot} = [text width=2cm, text centered, node distance=2.2cm] 35 | 36 | 37 | % Input units 38 | \foreach \name / \y in {0,...,5} { 39 | \node[neuron] (I-\name) at (\y*\unitsep, 0) {}; 40 | } 41 | 42 | % Draw the hidden layer nodes 43 | \foreach \layer in {1,...,\nbhidden} { 44 | \foreach \name / \y in {1,...,4} { 45 | \node[neuron] (H\layer-\name) at (\y*\unitsep, \layer*\layersep) {}; 46 | } 47 | } 48 | 49 | % Draw the output layer nodes 50 | \foreach \name / \y in {2,...,3} { 51 | \node[neuron] (O-\name) at (\y*\unitsep, \nbhidden*\layersep + \layersep) {}; 52 | } 
53 | 54 | % Connect every node in the input layer with every node in the 55 | % hidden layer. 56 | \foreach \source in {0,...,5} 57 | \foreach \dest in {1,...,4} { 58 | \path (I-\source) edge [->] (H1-\dest); 59 | } 60 | 61 | % Connect input to hidden 62 | \foreach \source in {1,...,4} 63 | \foreach \dest in {2,...,3} { 64 | \path (H\nbhidden-\source) edge [->] (O-\dest); 65 | } 66 | 67 | % Connect hidden layers 68 | % \foreach \layer in {1,...,\nbhidden} { 69 | % \foreach \source in {1,...,4} 70 | % \foreach \dest in {1,...,4} { 71 | % \path (H\layer-\source) edge [->] (H\nbhidden-\dest); 72 | % } 73 | % } 74 | 75 | % Connect every node in the hidden layer with the output layer 76 | \foreach \source in {1,...,4} 77 | \foreach \dest in {2,...,3} { 78 | \path (H\nbhidden-\source) edge [->] (O-\dest); 79 | } 80 | 81 | 82 | % Annotate the layers 83 | \node[annot,left of=I-0] {Input layer}; 84 | \node[annot,left of=O-2] {Output layer}; 85 | 86 | 87 | \end{tikzpicture} 88 | 89 | % End of code 90 | 91 | \end{document} 92 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /notebooks/figures/snn_graph/Makefile: -------------------------------------------------------------------------------- 1 | all: snn_graph.png 2 | 3 | %.pdf: %.tex 4 | pdflatex $< 5 | 6 | %.png: %.pdf 7 | convert -density 150 $< $@ 8 | -------------------------------------------------------------------------------- /notebooks/figures/snn_graph/snn_graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/surrogate-gradient-learning/spytorch/f4dcc662b8fe1d23e88790400e58b5c51b2ebc36/notebooks/figures/snn_graph/snn_graph.png -------------------------------------------------------------------------------- /notebooks/figures/snn_graph/snn_graph.tex: -------------------------------------------------------------------------------- 1 | \documentclass{standalone} 2 | 3 | \usepackage[latin1]{inputenc} 
4 | \usepackage{tikz} 5 | \usepackage{tikz,graphicx} 6 | \usepackage{sfmath} 7 | \usepackage{bbold} 8 | \usepackage{dsfont} 9 | \usetikzlibrary{shapes,arrows} 10 | \usetikzlibrary{positioning,quotes} 11 | 12 | \renewcommand{\familydefault}{\sfdefault} 13 | 14 | \begin{document} 15 | \pagestyle{empty} 16 | 17 | \definecolor{spikecolor}{RGB}{180,51,76} 18 | \definecolor{pyblue}{HTML}{1f77b4} 19 | \definecolor{pyorange}{HTML}{ff7f0e} 20 | 21 | 22 | % Define block styles 23 | \tikzstyle{block} = [rectangle, draw, text width=3em, text centered, rounded 24 | corners, minimum height=2.2em] 25 | \tikzstyle{sum} = [block] 26 | \tikzstyle{syn} = [block, fill=black!10 ] 27 | \tikzstyle{mem} = [block, fill=pyblue!40 ] 28 | \tikzstyle{spk} = [block, fill=pyorange!40 ] 29 | 30 | 31 | \begin{tikzpicture}[node distance = 1.2cm, auto] 32 | \pgfmathsetmacro{\n}{4} % sets number of units -1 33 | \pgfmathsetmacro{\nm}{3} % set to \n -1 34 | \pgfmathsetmacro{\sep}{2.6} 35 | 36 | % Place nodes 37 | \foreach \t in {0,...,\n} { 38 | \node [sum] (x\t) at (\sep*\t,0) {$S^{(0)}[\t]$}; 39 | \node [syn, above of=x\t] (I\t) {$I^{(1)}[\t]$}; 40 | \node [mem, above of=I\t] (U\t) {$U^{(1)}[\t]$}; 41 | \node [spk, above of=U\t] (S\t) {$S^{(1)}[\t]$}; 42 | \node [block, above of=S\t] (y\t) {}; 43 | \node [above of=y\t] (downstream\t) {}; 44 | } 45 | 46 | 47 | % Draw edges 48 | \foreach \t in {0,...,\n} { 49 | \path [->, draw, thick] (U\t) -- (S\t); 50 | % \path [->, draw, thick] (y\t) -- (downstream\t) node[near start, right] 51 | % {$W^{(1,2)}$}; 52 | % \path [->, draw, thick, spikecolor] (upstream\t) -- (x\t) node[near end, 53 | % right] {$W^{(0,1)}$}; 54 | } 55 | 56 | % Forward in time 57 | \foreach \t in {0,...,\nm} { 58 | \pgfmathtruncatemacro{\next}{\t + 1} 59 | \path [->, draw, thick] (x\t) -- (I\next) node[near start, right] 60 | {~$W^{(1)}$}; 61 | \path [->, draw, thick] (I\t) -- (U\next); 62 | 63 | \path [->, draw, thick] (I\t) -- (I\next) node[near end] {$\alpha$}; 64 | \path [->, draw, 
thick] (U\t) -- (U\next) node[near end] {$\beta$}; 65 | 66 | \path [->, draw, thick, pyorange] (S\t) -- (y\next) node[near end, 67 | left] {$W^{(2)}$}; 68 | \path [->, draw, thick, pyorange, dashed] (S\t) -- (U\next) node[near start] {$-\mathds{1}$}; 69 | \path [->, draw, thick, pyorange] (S\t) -- (I\next) node[near end, 70 | left] {$V^{(1)}$}; 71 | 72 | } 73 | 74 | \node[node distance=1.0cm, below of=x0] (a0) {}; 75 | \node[right of=a0] (a1) {}; 76 | % \path [->, draw, thick, black] (a0) -- (a1) node[near start, below] {Time}; 77 | 78 | \draw (a0) edge[->,"Time"] (a1) ; 79 | 80 | \end{tikzpicture} 81 | 82 | \end{document} 83 | -------------------------------------------------------------------------------- /notebooks/figures/surrgrad/Makefile: -------------------------------------------------------------------------------- 1 | all: surrgrad.png 2 | 3 | %.png: %.pdf 4 | convert -density 150 $< $@ 5 | 6 | %.pdf: %.tex 7 | pdflatex $< 8 | 9 | %.tex: %.gnu 10 | gnuplot $< 11 | -------------------------------------------------------------------------------- /notebooks/figures/surrgrad/surrgrad.gnu: -------------------------------------------------------------------------------- 1 | #!/usr/bin/gnuplot 2 | 3 | set border 3 4 | set xtics nomirror out 5 | set ytics nomirror out 6 | 7 | 8 | theta(x) = x>0?1:0 9 | 10 | sigma(x) = 0.5*x/(abs(x)+1)+0.5 11 | 12 | set xlabel '$x$' 13 | # set ylabel 'Output' 14 | set key bottom right 15 | 16 | set xrange [-1:1] 17 | plot theta(x) lw 3 title '$\Theta(x)$', \ 18 | sigma(x) lw 3 title '$\sigma(x)$', \ 19 | sigma(10*x) lw 3 title '$\sigma(10 x)$', \ 20 | 21 | 22 | set term epslatex standalone dashed color size 3.3, 2.0 font '\sfdefault,8' \ 23 | header '\usepackage{sfmath}' 24 | set out 'surrgrad.tex' 25 | rep 26 | -------------------------------------------------------------------------------- /notebooks/figures/surrgrad/surrgrad.png: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/surrogate-gradient-learning/spytorch/f4dcc662b8fe1d23e88790400e58b5c51b2ebc36/notebooks/figures/surrgrad/surrgrad.png
# --------------------------------------------------------------------------------
# /notebooks/utils.py:
# --------------------------------------------------------------------------------
import os
import urllib.request
import gzip, shutil
import hashlib
import tempfile

# Standard-library urllib is used directly: this module already requires
# Python 3 (``urllib.request``) and ``six`` is not listed in requirements.txt.
from urllib.error import HTTPError
from urllib.error import URLError
from urllib.request import urlretrieve

# The functions used in this file to download the dataset are based on
# code from the keras library. Specifically, from the following file:
# https://github.com/tensorflow/tensorflow/blob/v2.3.1/tensorflow/python/keras/utils/data_utils.py


def get_shd_dataset(cache_dir, cache_subdir):
    """Download and decompress the Spiking Heidelberg Digits (SHD) dataset.

    MD5 checksums are read from ``<base_url>/md5sums.txt`` and used to verify
    cached archives (re-downloading them on mismatch).

    Args:
        cache_dir: Root directory for cached downloads.
        cache_subdir: Sub-directory of ``cache_dir`` receiving the files.
    """
    # The remote directory with the data files
    base_url = "https://compneuro.net/datasets"

    # Retrieve MD5 hashes from remote; each useful line is "<hash> <filename>".
    with urllib.request.urlopen("%s/md5sums.txt" % base_url) as response:
        data = response.read()
    file_hashes = {}
    for line in data.decode('utf-8').split("\n"):
        fields = line.split()
        if len(fields) == 2:
            file_hashes[fields[1]] = fields[0]

    # Download the Spiking Heidelberg Digits (SHD) dataset
    files = ["shd_train.h5.gz",
             "shd_test.h5.gz",
             ]
    for fn in files:
        origin = "%s/%s" % (base_url, fn)
        hdf5_file_path = get_and_gunzip(origin, fn, md5hash=file_hashes[fn],
                                        cache_dir=cache_dir, cache_subdir=cache_subdir)
        print("File %s decompressed to:" % (fn))
        print(hdf5_file_path)


def get_and_gunzip(origin, filename, md5hash=None, cache_dir=None, cache_subdir=None):
    """Fetch a ``.gz`` archive (if needed) and decompress it next to itself.

    The decompressed copy is (re)written only when it is missing or older
    than the cached archive.

    Args:
        origin: URL of the compressed file.
        filename: Name under which the archive is cached; must end in ``.gz``.
        md5hash: Optional expected MD5 of the archive.
        cache_dir, cache_subdir: Forwarded to :func:`get_file`.

    Returns:
        Path of the decompressed file (the archive path without ``.gz``).

    Raises:
        ValueError: if the cached archive name does not end in ``.gz``.
    """
    gz_file_path = get_file(filename, origin, md5_hash=md5hash,
                            cache_dir=cache_dir, cache_subdir=cache_subdir)
    if not gz_file_path.endswith('.gz'):
        # Decompressing onto the archive's own path would truncate it.
        raise ValueError("Expected a '.gz' file, got %s" % gz_file_path)
    hdf5_file_path = gz_file_path[:-len('.gz')]
    if (not os.path.isfile(hdf5_file_path)
            or os.path.getctime(gz_file_path) > os.path.getctime(hdf5_file_path)):
        print("Decompressing %s" % gz_file_path)
        # Decompress to a temporary name and rename at the end, so an
        # interrupted run never leaves a truncated file that looks up to date.
        tmp_path = hdf5_file_path + '.part'
        with gzip.open(gz_file_path, 'rb') as f_in, open(tmp_path, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        os.replace(tmp_path, hdf5_file_path)
    return hdf5_file_path


def validate_file(fpath, file_hash, algorithm='auto', chunk_size=65535):
    """Return True if the hash of the file at ``fpath`` equals ``file_hash``.

    Args:
        fpath: Path of the file to check.
        file_hash: Expected hex digest.
        algorithm: 'sha256', 'md5' or 'auto'; 'auto' picks sha256 for a
            64-character hash and md5 otherwise.
        chunk_size: Bytes read per iteration while hashing.
    """
    if (algorithm == 'sha256') or (algorithm == 'auto' and len(file_hash) == 64):
        hasher = 'sha256'
    else:
        hasher = 'md5'
    return str(_hash_file(fpath, hasher, chunk_size)) == str(file_hash)


def _hash_file(fpath, algorithm='sha256', chunk_size=65535):
    """Return the hex digest of the file at ``fpath``.

    ``algorithm`` is 'sha256' (also used for 'auto', since no reference hash
    is available here to infer from) or anything else for md5.
    """
    if algorithm in ('sha256', 'auto'):
        hasher = hashlib.sha256()
    else:
        hasher = hashlib.md5()

    with open(fpath, 'rb') as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b''):
            hasher.update(chunk)

    return hasher.hexdigest()


def get_file(fname,
             origin,
             md5_hash=None,
             file_hash=None,
             cache_subdir='datasets',
             hash_algorithm='auto',
             extract=False,
             archive_format='auto',
             cache_dir=None):
    """Download ``origin`` into the cache unless a valid copy already exists.

    Args:
        fname: File name inside the cache sub-directory.
        origin: URL to download from.
        md5_hash: Legacy MD5 hash; used (with algorithm 'md5') when
            ``file_hash`` is not given.
        file_hash: Expected hash of the file.
        cache_subdir: Sub-directory of the cache root (None -> 'datasets').
        hash_algorithm: 'md5', 'sha256' or 'auto' (see :func:`validate_file`).
        extract, archive_format: Accepted for keras compatibility; ignored.
        cache_dir: Cache root; defaults to ``~/.data-cache``. Falls back to
            ``<tempdir>/.data-cache`` if it cannot be created or written.

    Returns:
        Path of the cached file.

    Raises:
        Exception: if the download fails (any partial file is removed).
        ValueError: if the downloaded file does not match ``file_hash``.
    """
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser('~'), '.data-cache')
    if cache_subdir is None:
        cache_subdir = 'datasets'
    if md5_hash is not None and file_hash is None:
        file_hash = md5_hash
        hash_algorithm = 'md5'
    datadir_base = os.path.expanduser(cache_dir)
    # os.access() is False for a directory that does not exist yet, so create
    # the cache root first; otherwise a fresh cache_dir silently goes to tmp.
    try:
        os.makedirs(datadir_base, exist_ok=True)
    except OSError:
        pass  # not creatable: the writability check below picks the fallback
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join(tempfile.gettempdir(), '.data-cache')
    datadir = os.path.join(datadir_base, cache_subdir)
    os.makedirs(datadir, exist_ok=True)

    fpath = os.path.join(datadir, fname)

    download = False
    if os.path.exists(fpath):
        # File found; verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(fpath, file_hash, algorithm=hash_algorithm):
                print('A local file was found, but it seems to be '
                      'incomplete or outdated because the ' + hash_algorithm +
                      ' file hash does not match the original value of ' + file_hash +
                      ' so we will re-download the data.')
                download = True
    else:
        download = True

    if download:
        print('Downloading data from', origin)

        error_msg = 'URL fetch failure on {}: {} -- {}'
        try:
            try:
                urlretrieve(origin, fpath)
            except HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg)) from e
            except URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason)) from e
        except (Exception, KeyboardInterrupt):
            # Remove any partially written file, then propagate the failure
            # instead of returning a path to a file that does not exist.
            if os.path.exists(fpath):
                os.remove(fpath)
            raise

        # Verify the fresh download against the expected hash, if any.
        if file_hash is not None and not validate_file(fpath, file_hash,
                                                       algorithm=hash_algorithm):
            raise ValueError('Incomplete or corrupted file detected: the ' +
                             hash_algorithm + ' file hash of ' + fpath +
                             ' does not match the expected value ' + file_hash + '.')

    return fpath
# --------------------------------------------------------------------------------
# /requirements.txt:
# --------------------------------------------------------------------------------
# jupyter==1.0.0
# jupyter-client==6.1.6
# jupyter-console==6.1.0
# jupyter-core==4.6.3
# torch==1.6.0
# torchvision==0.7.0
# seaborn==0.10.1
# h5py==2.10.0
# --------------------------------------------------------------------------------