├── .gitignore ├── LICENSE ├── data ├── N-MNIST │ └── .gitignore ├── SHD │ └── .gitignore └── ephys │ └── .gitignore ├── docs └── README.md ├── environment.yml ├── figures ├── figure2.png ├── figure4.png └── figure5.png ├── notebooks ├── ephys.ipynb ├── getting_started.ipynb ├── library_comparison.ipynb ├── speed_benchmarks.ipynb └── supervised_benchmarks.ipynb ├── scripts ├── ephys │ ├── build_data.py │ ├── run.py │ └── train.py ├── run_benchmarks.py └── supervised │ ├── run_blocks_nmnist.py │ ├── run_blocks_shd.py │ ├── run_detach_spikes.py │ ├── run_standard_nmnist.py │ ├── run_standard_shd.py │ └── train.py ├── setup.py ├── src ├── __init__.py ├── benchmark.py ├── datasets │ ├── __init__.py │ ├── ephys.py │ ├── neuromorphic.py │ ├── synthetic.py │ └── transforms.py ├── metric.py ├── models.py ├── query.py ├── snn │ ├── __init__.py │ ├── block │ │ ├── __init__.py │ │ ├── block.py │ │ ├── blocks.py │ │ └── util.py │ ├── snn.py │ └── surrogate.py └── train.py └── tests ├── __init__.py ├── test_block.py └── test_blocks.py /.gitignore: -------------------------------------------------------------------------------- 1 | # OSX 2 | .DS_Store 3 | 4 | # Python 5 | __pycache__/ 6 | *.pyc 7 | src.egg-info 8 | 9 | # PyCharm 10 | .idea 11 | 12 | # Jupyter Notebook 13 | .ipynb_checkpoints -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Luke Taylor 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /data/N-MNIST/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /data/SHD/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /data/ephys/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Addressing the speed-accuracy simulation trade-off for adaptive spiking neurons 2 | 3 | A new model for quickly training and simulating adaptive leaky integrate-and-fire spiking neural networks. 4 | 5 |

6 | 7 |

8 | 9 | ## Installing dependencies 10 | 11 | Install all required dependencies and activate the blocks environment using conda. 12 | ``` 13 | conda env create -f environment.yml 14 | conda activate blocks 15 | ``` 16 | 17 | ## Getting started tutorial 18 | 19 | See the [notebooks/getting_started.ipynb](../notebooks/getting_started.ipynb) notebook for getting started with our model. 20 | 21 | ## Reproducing paper results 22 | 23 | All the paper results can be reproduced using the scripts available in the `scripts` folder. 24 | 25 | ### Running benchmark experiments 26 | 27 | The `scripts/run_benchmarks.py` script benchmarks the time of the forward and backward passes of the blocks and the standard SNN model for different numbers of neurons and simulation steps. Set the `root` variable in the script to the project folder before running it. 28 | 29 | ### Training models 30 | 31 | Ensure that the computer has a CUDA-capable GPU with CUDA 11.7 installed. 32 | 33 | #### 1. Downloading and processing datasets 34 | 35 | #### Machine learning datasets: 36 | The content of the Neuromorphic-MNIST dataset can be [downloaded](https://www.garrickorchard.com/datasets/n-mnist) and unzipped into the `data/N-MNIST` directory. Thereafter, the `convert_nmnist2h5.py` script (adapted from Perez-Nieves et al., 2021) needs to be run to process the raw dataset. The Spiking Heidelberg Digits (SHD) dataset can be [downloaded](https://compneuro.net/posts/2019-spiking-heidelberg-digits/) and unzipped into the `data/SHD` directory. 37 | 38 | #### E-phys dataset: 39 | 40 | Running the `scripts/ephys/build_data.py` script will download and process the necessary data from the Allen Institute. 41 | 42 | #### 2. Train model 43 | 44 | You can train the blocks and standard SNN models using the `train.py` scripts in the `scripts/ephys` (e-phys dataset) and `scripts/supervised` (machine learning datasets) folders. See the respective folders for the different experiment run scripts.
45 | 46 | ## Building result figures 47 | 48 | Speedup plots can be built using: `notebooks/speed_benchmarks.ipynb` 49 | 50 | Machine learning benchmark plots can be built using: `notebooks/supervised_benchmarks.ipynb` 51 | 52 | Neural-fitting plots can be built using: `notebooks/ephys.ipynb` 53 | 54 | ### Machine learning benchmark results 55 |

56 | 57 |

58 | 59 | ### Neural-fitting results 60 |

61 | 62 |

-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: blocks 2 | dependencies: 3 | - python=3.8 4 | - pytorch::pytorch 5 | - pytorch::torchvision 6 | - nvidia::cudatoolkit=11.7 7 | - matplotlib 8 | - seaborn 9 | - pandas 10 | - conda-forge::h5py 11 | - nb_conda_kernels 12 | - ipywidgets 13 | - pytest 14 | - pip: 15 | - brainbox==0.0.6 16 | - jupyter 17 | - allensdk 18 | - --editable . 19 | -------------------------------------------------------------------------------- /figures/figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/figures/figure2.png -------------------------------------------------------------------------------- /figures/figure4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/figures/figure4.png -------------------------------------------------------------------------------- /figures/figure5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/figures/figure5.png -------------------------------------------------------------------------------- /notebooks/getting_started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 89, 6 | "id": "da7df381", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "The autoreload extension is already loaded. 
To reload it, use:\n", 14 | " %reload_ext autoreload\n" 15 | ] 16 | } 17 | ], 18 | "source": [ 19 | "import time\n", 20 | "\n", 21 | "import torch\n", 22 | "import numpy as np\n", 23 | "import matplotlib.pyplot as plt\n", 24 | "import seaborn as sns\n", 25 | "import pandas as pd\n", 26 | "\n", 27 | "from src.datasets import SyntheticSpikes\n", 28 | "from src.snn.snn import SNN\n", 29 | "from src.snn.block.blocks import Blocks\n", 30 | "\n", 31 | "%load_ext autoreload\n", 32 | "%autoreload 2" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "8702a126", 38 | "metadata": {}, 39 | "source": [ 40 | "## Model equivalence" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "id": "ca0f57e9", 46 | "metadata": {}, 47 | "source": [ 48 | "Let's check if the blocks model and the standard model produce the same output raster" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 86, 54 | "id": "286f166a", 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "n_in = 100\n", 59 | "n_out = 20\n", 60 | "rf_len = 20\n", 61 | "t_len = 1000 \n", 62 | "block_len = 40\n", 63 | "\n", 64 | "# Instantiate recurrently connected ALIF SNNs (blocks and standard version)\n", 65 | "torch.manual_seed(42)\n", 66 | "input_raster = SyntheticSpikes(t_len, n_in, min_r=0, max_r=100, n_samples=1)[0]\n", 67 | "standard_snn = SNN(n_in, n_out, rf_len, t_len, t_latency=block_len, recurrent=True, init_beta=0.99, init_p=0.99)\n", 68 | "blocks_snn = Blocks(n_in, n_out, rf_len, t_len, t_latency=block_len, recurrent=True, init_beta=0.99, init_p=0.99)\n", 69 | "\n", 70 | "# Ensure bocks model has the same weights as the standard model\n", 71 | "blocks_snn._rf_weight = standard_snn._rf_weight\n", 72 | "blocks_snn._rf_bias = standard_snn._rf_bias\n", 73 | "blocks_snn._rec_weight = standard_snn._rec_weight\n", 74 | "\n", 75 | "# Obtain spikes from blocks and standard model\n", 76 | "with torch.no_grad():\n", 77 | " blocks_spikes = blocks_snn(input_raster.unsqueeze(0), 
mode=\"train\")\n", 78 | " standard_spikes = standard_snn(input_raster.unsqueeze(0), mode=\"train\")" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 62, 84 | "id": "3c475ec8", 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "Are model outputs the same? True\n" 92 | ] 93 | } 94 | ], 95 | "source": [ 96 | "# Drum roll, the moment we have all been waiting for...\n", 97 | "print(f\"Are model outputs the same? {torch.allclose(blocks_spikes, standard_spikes)}\")" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": 88, 103 | "id": "e476bd42", 104 | "metadata": { 105 | "scrolled": true 106 | }, 107 | "outputs": [ 108 | { 109 | "data": { 110 | "image/png": "", 111 | "text/plain": [ 112 | "
" 113 | ] 114 | }, 115 | "metadata": {}, 116 | "output_type": "display_data" 117 | } 118 | ], 119 | "source": [ 120 | "# Little util function\n", 121 | "def spike_tensor_to_points(spike_tensor):\n", 122 | " x = np.array([p[1].item() for p in torch.nonzero(spike_tensor.cpu())])\n", 123 | " y = np.array([p[0].item() for p in torch.nonzero(spike_tensor.cpu())])\n", 124 | " \n", 125 | " return x, y\n", 126 | "\n", 127 | "fig, axs = plt.subplots(1, 3, figsize=(6, 3))\n", 128 | "axs[0].scatter(*spike_tensor_to_points(input_raster), color=\"black\", s=0.5)\n", 129 | "axs[1].scatter(*spike_tensor_to_points(blocks_spikes[0]), color=\"black\", s=0.5)\n", 130 | "axs[2].scatter(*spike_tensor_to_points(standard_spikes[0]), color=\"black\", s=0.5)\n", 131 | "axs[0].set_title(\"Input raster\")\n", 132 | "axs[1].set_title(\"Blocks output\")\n", 133 | "axs[2].set_title(\"Standard output\")\n", 134 | "axs[0].set_ylabel(\"NeuronID\")\n", 135 | "axs[0].set_xlabel(\"Sim step\")\n", 136 | "axs[1].set_xlabel(\"Sim step\")\n", 137 | "axs[2].set_xlabel(\"Sim step\")\n", 138 | "fig.tight_layout()" 139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "id": "7def7ba2", 144 | "metadata": {}, 145 | "source": [ 146 | "## Model speed" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "24c34073", 152 | "metadata": {}, 153 | "source": [ 154 | "How much faster is our blocks models model?" 
155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": 108, 160 | "id": "fb472914", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "def get_training_duration(model, device=\"cpu\"):\n", 165 | " model = model.to(device)\n", 166 | " data = input_raster.unsqueeze(0).to(device).repeat(128, 1, 1) # 128 batch size\n", 167 | " torch.cuda.synchronize()\n", 168 | " start_time = time.time()\n", 169 | " output = model(data)\n", 170 | " loss = output.sum() # Arbitraty loss just so we have something to backpropogate\n", 171 | " loss.backward()\n", 172 | " torch.cuda.synchronize()\n", 173 | " training_duration = time.time() - start_time\n", 174 | " return training_duration" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 103, 180 | "id": "91db4382", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "def get_training_durations_for_both_models(device):\n", 185 | " standard_training_times = [get_training_duration(standard_snn, device) for _ in range(11)][1:]\n", 186 | " block_training_times = [get_training_duration(blocks_snn, device) for _ in range(11)][1:]\n", 187 | " return pd.DataFrame({\"standard\": standard_training_times, \"block\": block_training_times})" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 104, 193 | "id": "b45b4bdf", 194 | "metadata": {}, 195 | "outputs": [], 196 | "source": [ 197 | "torch.backends.cudnn.benchmark = True # Make sure we use the best conv algorithm" 198 | ] 199 | }, 200 | { 201 | "cell_type": "code", 202 | "execution_count": 107, 203 | "id": "204b6a24", 204 | "metadata": {}, 205 | "outputs": [ 206 | { 207 | "data": { 208 | "image/png": "", 209 | "text/plain": [ 210 | "
" 211 | ] 212 | }, 213 | "metadata": {}, 214 | "output_type": "display_data" 215 | } 216 | ], 217 | "source": [ 218 | "fig, axs = plt.subplots(1, 2, figsize=(8, 3), sharey=False)\n", 219 | "\n", 220 | "def plot_training_durations(ax, device):\n", 221 | " sns.barplot(get_training_durations_for_both_models(device), ax=ax)\n", 222 | " ax.set(ylabel=\"Training time (sec)\", title=device)\n", 223 | " sns.despine()\n", 224 | " \n", 225 | "plot_training_durations(axs[0], \"cpu\")\n", 226 | "plot_training_durations(axs[1], \"cuda\")" 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "id": "f9e54af9", 232 | "metadata": {}, 233 | "source": [ 234 | "Voila! A condsiderable reduction in training time, on both CPUs and GPUs!" 235 | ] 236 | } 237 | ], 238 | "metadata": { 239 | "kernelspec": { 240 | "display_name": "Python [conda env:blocks] *", 241 | "language": "python", 242 | "name": "conda-env-blocks-py" 243 | }, 244 | "language_info": { 245 | "codemirror_mode": { 246 | "name": "ipython", 247 | "version": 3 248 | }, 249 | "file_extension": ".py", 250 | "mimetype": "text/x-python", 251 | "name": "python", 252 | "nbconvert_exporter": "python", 253 | "pygments_lexer": "ipython3", 254 | "version": "3.8.16" 255 | } 256 | }, 257 | "nbformat": 4, 258 | "nbformat_minor": 5 259 | } 260 | -------------------------------------------------------------------------------- /notebooks/library_comparison.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 53, 6 | "id": "bef22c5e", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import time\n", 11 | "\n", 12 | "import torch\n", 13 | "import torch.nn as nn\n", 14 | "from spikingjelly.activation_based import neuron, layer, surrogate\n", 15 | "from norse.torch.module.lif import LIFCell\n", 16 | "\n", 17 | "from src.snn.block.blocks import Blocks\n", 18 | "from src.snn.snn import SNN" 19 | ] 20 | }, 21 | { 22 | 
"cell_type": "markdown", 23 | "id": "f8051e64", 24 | "metadata": {}, 25 | "source": [ 26 | "## Setting up the different implementations" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "id": "4ca0b48b", 32 | "metadata": {}, 33 | "source": [ 34 | "Network benchmarked: 200 input units -> 100 spiking units over 1000 simulation steps using a batch size of 128." 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 93, 40 | "id": "38193107", 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "def time_jelly():\n", 45 | " input_tensor = torch.rand(128, 200, 1000).cuda()\n", 46 | " \n", 47 | " jelly_layer = nn.Sequential(\n", 48 | " layer.Linear(200, 100, bias=False),\n", 49 | " neuron.LIFNode(tau=100.0, surrogate_function=surrogate.ATan())\n", 50 | " ).cuda()\n", 51 | " \n", 52 | " start_time = time.time()\n", 53 | " \n", 54 | " for t in range(1000):\n", 55 | " out = jelly_layer(input_tensor[:, :, t])\n", 56 | " \n", 57 | " end_time = time.time()\n", 58 | " return end_time - start_time" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 94, 64 | "id": "976ed321", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "def time_norse():\n", 69 | " input_tensor = torch.rand(128, 200, 1000).cuda()\n", 70 | " \n", 71 | " norse_layer = nn.Sequential(\n", 72 | " layer.Linear(200, 100, bias=False),\n", 73 | " LIFCell()\n", 74 | " ).cuda()\n", 75 | " \n", 76 | " start_time = time.time()\n", 77 | " \n", 78 | " for t in range(1000):\n", 79 | " out = norse_layer(input_tensor[:, :, t])\n", 80 | " \n", 81 | " end_time = time.time()\n", 82 | " \n", 83 | " return end_time - start_time" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 95, 89 | "id": "932cb5f0", 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "def time_blocks():\n", 94 | " input_tensor = torch.rand(128, 200, 1000).cuda()\n", 95 | " \n", 96 | " blocks_snn = Blocks(200, 100, 1, 1000, t_latency=50, recurrent=False, 
init_beta=0.99, init_p=0.99).cuda()\n", 97 | " start_time = time.time()\n", 98 | " out = blocks_snn(input_tensor)\n", 99 | " end_time = time.time()\n", 100 | " \n", 101 | " return end_time - start_time\n", 102 | " \n", 103 | "def time_standard():\n", 104 | " input_tensor = torch.rand(128, 200, 1000).cuda()\n", 105 | " \n", 106 | " blocks_snn = SNN(200, 100, 1, 1000, t_latency=1, recurrent=False, init_beta=0.99, init_p=0.99).cuda()\n", 107 | " start_time = time.time()\n", 108 | " out = blocks_snn(input_tensor)\n", 109 | " end_time = time.time()\n", 110 | " \n", 111 | " return end_time - start_time" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "id": "733c9e80", 117 | "metadata": {}, 118 | "source": [ 119 | "## Benchmarking the differnet implementations" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 97, 125 | "id": "0ecf6b39", 126 | "metadata": {}, 127 | "outputs": [ 128 | { 129 | "name": "stdout", 130 | "output_type": "stream", 131 | "text": [ 132 | "Norse=0.3080105781555176\n", 133 | "Jelly=0.1869184970855713\n", 134 | "Standard=0.3317074775695801\n", 135 | "Blocks=0.016329288482666016\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "print(f\"Norse={time_norse()}\")\n", 141 | "print(f\"Jelly={time_jelly()}\")\n", 142 | "print(f\"Standard={time_standard()}\")\n", 143 | "print(f\"Blocks={time_blocks()}\")" 144 | ] 145 | } 146 | ], 147 | "metadata": { 148 | "kernelspec": { 149 | "display_name": "Python [conda env:blocks] *", 150 | "language": "python", 151 | "name": "conda-env-blocks-py" 152 | }, 153 | "language_info": { 154 | "codemirror_mode": { 155 | "name": "ipython", 156 | "version": 3 157 | }, 158 | "file_extension": ".py", 159 | "mimetype": "text/x-python", 160 | "name": "python", 161 | "nbconvert_exporter": "python", 162 | "pygments_lexer": "ipython3", 163 | "version": "3.8.16" 164 | } 165 | }, 166 | "nbformat": 4, 167 | "nbformat_minor": 5 168 | } 169 | 
-------------------------------------------------------------------------------- /scripts/ephys/build_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from src import datasets 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--root", type=str, default=".") 9 | args = parser.parse_args() 10 | 11 | # Builds current and spike tensors with DT=0.1ms 12 | # Note: You might want to set manifest_file if you have already downloaded the data from the Allen Institute 13 | builder = datasets.NoiseBuilder() 14 | builder.build(f"{args.root}/Blocks/data/ephys/train", noise_type="noise1") 15 | builder.build(f"{args.root}/Blocks/data/ephys/test", noise_type="noise2") 16 | 17 | # Builds current tensors with DT=0.05ms 18 | builder = datasets.NoiseBuilder() 19 | builder.build(f"{args.root}/Blocks/data/ephys/train", noise_type="noise1", target_sampling_rate=20000) 20 | builder.build(f"{args.root}/Blocks/data/ephys/test", noise_type="noise2", target_sampling_rate=20000) 21 | 22 | 23 | if __name__ == "__main__": 24 | main() 25 | -------------------------------------------------------------------------------- /scripts/ephys/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = "" # TODO: Change this to the project folder 3 | 4 | from src import datasets 5 | 6 | 7 | def launch(method, abs_refac_ms, downsample): 8 | neuron_idx_query = datasets.ValidNeuronQuery(f"{root}/data/ephys") 9 | 10 | dir_name = f"{method}_{abs_refac_ms}_{downsample}" 11 | os.makedirs(f"{root}/results/ephys/{dir_name}") 12 | 13 | for neuron_idx in neuron_idx_query.idx: 14 | os.system(f"python {root}/scripts/ephys/train.py --root={root} --method={method} --abs_refac_ms={abs_refac_ms} --downsample={downsample} --neuron_idx={neuron_idx} --dir_name={dir_name} --id={neuron_idx}") 15 | 16 | 17 | # Blocks: Different dt 18 | for downsample in [0.5, 1, 5, 10, 
20, 40]: 19 | launch("blocks", abs_refac_ms=2, downsample=downsample) 20 | 21 | # Blocks: Different ARP 22 | for abs_refac_ms in [1, 4, 6, 8, 16]: 23 | launch("blocks", abs_refac_ms=abs_refac_ms, downsample=1) 24 | 25 | # Standard 26 | launch("standard", abs_refac_ms=2, downsample=1) 27 | -------------------------------------------------------------------------------- /scripts/ephys/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ast 3 | import logging 4 | import argparse 5 | 6 | from src import datasets, models, train 7 | 8 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 9 | 10 | 11 | def eval(v): 12 | return ast.literal_eval(v) 13 | 14 | 15 | def get_dataset(args): 16 | if args.downsample == 0.5: 17 | # Used dt=0.05ms for neural fits 18 | return datasets.EphysDataset(f"{args.root}/data/ephys", "train", args.neuron_idx, target_sampling_rate=20000) 19 | else: 20 | # Used dt>=0.1ms for neural fits (used by most experiments) 21 | return datasets.EphysDataset(f"{args.root}/data/ephys", "train", args.neuron_idx) 22 | 23 | 24 | def get_model(args): 25 | if args.downsample == 0.5: 26 | # Used dt=0.05ms for neural fits 27 | return models.Neuron(method=args.method, abs_refac_ms=args.abs_refac_ms, downsample=1, dt01ref=True) 28 | else: 29 | # Used dt>=0.1ms for neural fits (used by most experiments) 30 | return models.Neuron(method=args.method, abs_refac_ms=args.abs_refac_ms, downsample=int(args.downsample)) 31 | 32 | 33 | def get_trainer(args, model, train_dataset): 34 | n_epochs = 200 35 | batch_size = 5 36 | lr = 0.0001 37 | dt = 0.1 * args.downsample 38 | return train.EphysTrainer(f"{args.root}/results/ephys/{args.dir_name}", model, train_dataset, n_epochs, batch_size, lr, gamma=0.1, dt=dt, epoch_scan=5, max_decay=0, val_dataset=None, device="cuda", id=args.id) 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--root", type=str, default=".") 44 | 
45 | # Model 46 | parser.add_argument("--method", type=str, default="standard") 47 | parser.add_argument("--abs_refac_ms", type=int, default=10) 48 | parser.add_argument("--downsample", type=float, default=1) 49 | 50 | # Dataset 51 | parser.add_argument("--neuron_idx", type=int, default="") 52 | 53 | # Trainer 54 | parser.add_argument("--dir_name", type=str, default="") 55 | parser.add_argument("--id", type=str, default="") 56 | 57 | args = parser.parse_args() 58 | 59 | train_dataset = get_dataset(args) 60 | model = get_model(args) 61 | model_trainer = get_trainer(args, model, train_dataset) 62 | model_trainer.train(save=True) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /scripts/run_benchmarks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cudnn.benchmark = True # Important in order to use best conv algorithm 3 | 4 | from src.benchmark import Benchmarker 5 | 6 | 7 | def run_different_sim_lengths(root): 8 | n_in = 1000 9 | n_hidden = 128 10 | 11 | for abs_refac in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]: 12 | for t_len in [2**9, 2**10, 2**11]: 13 | for batch_size in [32, 64, 128]: 14 | for method in ["standard", "blocks"]: 15 | bencher = Benchmarker(method, t_len, abs_refac, n_in, n_hidden, n_layers=1, batch_size=batch_size) 16 | bencher.benchmark() 17 | bencher.save(root) 18 | 19 | 20 | def run_different_layers(root): 21 | n_in = 1000 22 | 23 | for abs_refac in [40]: 24 | for t_len in [2**10]: 25 | for n_hidden in [128, 256, 512]: 26 | for batch_size in [64]: 27 | for n_layers in [2, 3, 4, 5]: 28 | for method in ["standard", "blocks"]: 29 | bencher = Benchmarker(method, t_len, abs_refac, n_in, n_hidden, n_layers=n_layers, batch_size=batch_size) 30 | bencher.benchmark() 31 | bencher.save(root) 32 | 33 | 34 | if __name__ == "__main__": 35 | root = "" # TODO: Change this to the project folder 36 | 
run_different_sim_lengths(f"{root}/benchmarks/sim_lengths") 37 | run_different_layers(f"{root}/benchmarks/layers") 38 | 39 | -------------------------------------------------------------------------------- /scripts/supervised/run_blocks_nmnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = "" # TODO: Change this to the project folder 3 | 4 | 5 | def launch(abs_refac, surr_grad, dt, method): 6 | for i in range(3): 7 | id = f"nmnist_{method}_{surr_grad}_{abs_refac}_{dt}_{i}" 8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method={method} --abs_refac={abs_refac} --surr_grad={surr_grad} --name=nmnist --dt={dt} --id={id}") 9 | 10 | 11 | for abs_refac in [10, 20, 30, 40, 50]: 12 | launch(abs_refac, "mg", 1, "blocks") 13 | -------------------------------------------------------------------------------- /scripts/supervised/run_blocks_shd.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = "" # TODO: Change this to the project folder 3 | 4 | 5 | def launch(abs_refac, surr_grad, dt, method): 6 | for i in range(3): 7 | id = f"shd_{method}_{surr_grad}_{abs_refac}_{dt}_{i}" 8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method={method} --abs_refac={abs_refac} --surr_grad={surr_grad} --name=shd --dt={dt} --id={id}") 9 | 10 | 11 | for abs_refac in [10, 20, 30, 40, 50]: 12 | launch(abs_refac, "mg", 1, "blocks") 13 | -------------------------------------------------------------------------------- /scripts/supervised/run_detach_spikes.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = "" # TODO: Change this to the project folder 3 | 4 | 5 | def launch(abs_refac, surr_grad, dt, detach_spike_grad): 6 | for i in range(3): 7 | id = f"shd_blocks_{surr_grad}_{abs_refac}_{dt}_{detach_spike_grad}_{i}" 8 | os.system(f"python {root}/scripts/supervised/train.py 
--root={root} --method=blocks --abs_refac={abs_refac} --surr_grad={surr_grad} --name=shd --dt={dt} --id={id} --detach_spike_grad={detach_spike_grad}") 9 | 10 | 11 | launch(30, "mg", dt=2, detach_spike_grad=False) 12 | launch(30, "fast_sigmoid", dt=2, detach_spike_grad=False) 13 | launch(30, "box_car", dt=2, detach_spike_grad=False) 14 | -------------------------------------------------------------------------------- /scripts/supervised/run_standard_nmnist.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = "" # TODO: Change this to the project folder 3 | 4 | 5 | def launch(abs_refac, surr_grad, dt, method): 6 | for i in range(3): 7 | id = f"nmnist_{method}_{surr_grad}_{abs_refac}_{dt}_{i}" 8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method={method} --abs_refac={abs_refac} --surr_grad={surr_grad} --name=nmnist --dt={dt} --id={id}") 9 | 10 | 11 | launch(0, "mg", 1, "standard") 12 | -------------------------------------------------------------------------------- /scripts/supervised/run_standard_shd.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = "" # TODO: Change this to the project folder 3 | 4 | 5 | def launch(abs_refac, surr_grad, dt): 6 | for i in range(3): 7 | id = f"shd_standard_{surr_grad}_{abs_refac}_{dt}_{i}" 8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method=standard --abs_refac={abs_refac} --surr_grad={surr_grad} --name=shd --dt={dt} --id={id}") 9 | 10 | 11 | launch(0, "mg", dt=2) 12 | -------------------------------------------------------------------------------- /scripts/supervised/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import ast 3 | import logging 4 | import argparse 5 | 6 | from src import datasets, models, train 7 | 8 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 9 | 10 | 11 | def 
eval(v): 12 | return ast.literal_eval(v) 13 | 14 | 15 | def get_dataset(args, train=True): 16 | if args.name == "shd": 17 | return datasets.SHDDataset(f"{args.root}/data/SHD", train=train, dt=args.dt) 18 | elif args.name == "nmnist": 19 | return datasets.NMNISTDataset(f"{args.root}/data/N-MNIST", train=train, dt=args.dt) 20 | 21 | 22 | def get_model(args, dataset): 23 | abs_refac = int(args.abs_refac / args.dt) 24 | 25 | if args.name == "shd": 26 | n_in = 700 27 | n_out = 20 28 | elif args.name == "nmnist": 29 | n_in = 1156 30 | n_out = 10 31 | 32 | return models.AuditoryModel(args.method, n_in, args.n_hidden, n_out, dataset.t_len, abs_refac, eval(args.recurrent), args.dt, args.surr_grad, detach_spike_grad=eval(args.detach_spike_grad)) 33 | 34 | 35 | def get_trainer(args, model, train_dataset, val_dataset): 36 | gamma = 0.1 37 | if args.name == "shd": 38 | milestones = [15, 15] 39 | epochs = 40 40 | elif args.name == "nmnist": 41 | milestones = [30] 42 | epochs = 20 43 | return train.Trainer(f"{args.root}/results", model, train_dataset, epochs, args.batch_size, args.lr, milestones, gamma, val_dataset, id=args.id) 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument("--root", type=str, default=".") 49 | 50 | # Model 51 | parser.add_argument("--method", type=str, default="standard") 52 | parser.add_argument("--n_hidden", type=int, default=256) 53 | parser.add_argument("--abs_refac", type=float, default=10) 54 | parser.add_argument("--recurrent", type=str, default="True") 55 | parser.add_argument("--detach_spike_grad", type=str, default="True") 56 | parser.add_argument("--surr_grad", type=str, default="fast_sigmoid") 57 | 58 | # Dataset 59 | parser.add_argument("--name", type=str, default="shd") 60 | parser.add_argument("--dt", type=float, default=1) 61 | 62 | # Trainer 63 | parser.add_argument("--batch_size", type=int, default=64) 64 | parser.add_argument("--lr", type=float, default=0.001) 65 | parser.add_argument('--id', 
type=str, default="") 66 | 67 | args = parser.parse_args() 68 | 69 | train_dataset = get_dataset(args, train=True) 70 | val_dataset = get_dataset(args, train=False) 71 | if args.name == "nmnist": 72 | val_dataset = None 73 | model = get_model(args, train_dataset) 74 | model_trainer = get_trainer(args, model, train_dataset, val_dataset) 75 | model_trainer.train(save=True) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="src", 5 | packages=find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/src/__init__.py -------------------------------------------------------------------------------- /src/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch 5 | import pandas as pd 6 | 7 | from src import datasets 8 | from src import models 9 | 10 | 11 | class Benchmarker: 12 | 13 | def __init__(self, method, t_len, abs_refac, n_in, n_hidden, n_layers, batch_size=16, min_r=0, max_r=200, n_samples=11): 14 | self._method = method 15 | self._t_len = t_len 16 | self._abs_refac = abs_refac 17 | self._n_in = n_in 18 | self._n_hidden = n_hidden 19 | self._n_layers = n_layers 20 | self._batch_size = batch_size 21 | self._model = models.ModelBuilder(method, t_len, abs_refac, n_in, n_hidden, n_layers) 22 | 23 | self._data_loader = self._get_data_loader(t_len, n_in, min_r, max_r, batch_size, n_samples*batch_size) 24 | self._benchmark_results = None 25 | 26 | def benchmark(self, device="cuda"): 27 | 
timing_list = [] 28 | 29 | self._model = self._model.to(device) 30 | 31 | for i, data in enumerate(self._data_loader): 32 | # Benchmark forward pass 33 | data = data.to(device) 34 | 35 | start_time = time.time() 36 | output = self._model(data) 37 | torch.cuda.synchronize() 38 | forward_pass_time = time.time() - start_time 39 | 40 | # Benchmark backward pass 41 | start_time = time.time() 42 | loss = output.sum() 43 | loss.backward() 44 | torch.cuda.synchronize() 45 | backward_pass_time = time.time() - start_time 46 | 47 | # Ignore first run (as this usually loads things which slows things down) 48 | # e.g. cudnn finds best conv algorithm 49 | if i > 0: 50 | timing_row = {"forward_time": forward_pass_time, "backward_time": backward_pass_time} 51 | timing_list.append(timing_row) 52 | 53 | self._benchmark_results = timing_list 54 | 55 | def save(self, path): 56 | results_df = self._to_df() 57 | results_df.to_csv(os.path.join(path, f"{self._get_df_name()}.csv"), index=False) 58 | 59 | def _get_description(self): 60 | return {"method": self._method, "t_len": self._t_len, "abs_refac": self._abs_refac, "units": self._n_hidden, "layers": self._n_layers, "batch": self._batch_size} 61 | 62 | def _get_df_name(self): 63 | return f"{self._method}_{self._t_len}_{self._abs_refac}_{self._n_hidden}_{self._n_layers}_{self._batch_size}" 64 | 65 | def _get_data_loader(self, t_len, n_units, min_r, max_r, batch_size, n_samples): 66 | spikes_dataset = datasets.SyntheticSpikes(t_len, n_units, min_r, max_r, n_samples) 67 | return torch.utils.data.DataLoader(spikes_dataset, batch_size, shuffle=False) 68 | 69 | def _to_df(self): 70 | assert self._benchmark_results is not None 71 | results = [] 72 | 73 | for results_row in self._benchmark_results: 74 | results.append({**results_row, **self._get_description()}) 75 | 76 | return pd.DataFrame(results) 77 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .synthetic import SyntheticSpikes 2 | from .neuromorphic import NMNISTDataset, SHDDataset 3 | from .ephys import ValidNeuronQuery, Builder, NoiseBuilder, EphysDataset 4 | -------------------------------------------------------------------------------- /src/datasets/ephys.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import pickle 4 | from collections import defaultdict 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from allensdk.core.cell_types_cache import CellTypesCache 10 | from allensdk.ephys.ephys_extractor import EphysSweepFeatureExtractor 11 | 12 | 13 | class ValidNeuronQuery: 14 | 15 | """ 16 | This class builds an idx set of neurons which meet the requirement of having 4 repeats. 17 | """ 18 | 19 | def __init__(self, root): 20 | self.root = root 21 | self.query_df = self.build_query_df() 22 | 23 | @property 24 | def idx(self): 25 | query = self.query_df["n_trials"] == 4 # Want 4 repeats 26 | return self.query_df[query]["idx"].values 27 | 28 | def build_query_df(self): 29 | train_idxs = [v.split("/")[-1] for v in glob.glob(f"{self.root}/train/*")] 30 | test_idxs = [v.split("/")[-1] for v in glob.glob(f"{self.root}/test/*")] 31 | joint_idxs = list(set(train_idxs) & set(test_idxs)) 32 | 33 | data_list = [] 34 | 35 | for idx in joint_idxs: 36 | try: 37 | v = torch.load(f"{self.root}/test/{idx}/v.pt") 38 | n_trials = v.shape[0] 39 | data_list.append({"idx": idx, "n_trials": n_trials}) 40 | except: 41 | pass 42 | 43 | return pd.DataFrame(data_list) 44 | 45 | 46 | class EphysDataset: 47 | 48 | """ 49 | This is the PyTorch friendly e-phys dataset, with all the current and spike tensors. 
50 | """ 51 | 52 | LENGTH = 10000 # = 1 second (with DT=0.1) 53 | 54 | def __init__(self, root, dataset, neuron_idx, target_sampling_rate=None): 55 | # dataset: str which is train or test 56 | info_df = pd.read_csv(f"{root}/info_df.csv").set_index("idx") 57 | 58 | if target_sampling_rate is None: 59 | self.i = torch.load(f"{root}/{dataset}/{neuron_idx}/i.pt")[:, :, :1*EphysDataset.LENGTH] 60 | elif target_sampling_rate == 20000: 61 | self.i = torch.load(f"{root}/{dataset}/{neuron_idx}/i_20k.pt")[:, :, :2*EphysDataset.LENGTH] 62 | self.v = torch.load(f"{root}/{dataset}/{neuron_idx}/v.pt")[:, :, :EphysDataset.LENGTH] 63 | self.s = torch.load(f"{root}/{dataset}/{neuron_idx}/s.pt")[:, :, :EphysDataset.LENGTH] 64 | 65 | vrest = info_df.loc[neuron_idx]["vrest"] 66 | vthresh = info_df.loc[neuron_idx]["vthresh"] 67 | self.i = self.i / (100 * self.i.max()) 68 | self.v = (self.v - vrest) / (vthresh - vrest) 69 | 70 | @property 71 | def hyperparams(self): 72 | return {} 73 | 74 | def __getitem__(self, item): 75 | x = self.i[0, item].unsqueeze(0) # Input current is the same across all trials 76 | trace = self.v[0, item].unsqueeze(0) 77 | trace = torch.clamp(trace, -1, 1) 78 | spikes = self.s[0, item].unsqueeze(0) # Take first spike trial (all are the same) 79 | 80 | return x, (trace, spikes) 81 | 82 | def __len__(self): 83 | return 3 84 | 85 | 86 | class Builder: 87 | 88 | def __init__(self, manifest_file="data/plot_data/allen-brain-observatory/cell_types/manifest.json"): 89 | self.ctc = CellTypesCache(manifest_file=manifest_file) 90 | ephys_df = self.generate_ephys_df() 91 | self.info_df = self.generate_info_df(ephys_df).set_index("idx") 92 | 93 | def save_info_df(self, path): 94 | self.info_df.to_csv(f"{path}/info_df.csv") 95 | 96 | def build(self, path, **kwargs): 97 | for neuron_idx in self.info_df.index: 98 | print(f"Building {neuron_idx}...") 99 | try: 100 | meta, i, v = self.generate_all_sweep_tensor(neuron_idx, **kwargs) 101 | os.mkdir(f"{path}/{neuron_idx}") 102 | 
torch.save(i, f"{path}/{neuron_idx}/i.pt") 103 | torch.save(v, f"{path}/{neuron_idx}/v.pt") 104 | with open(f"{path}/{neuron_idx}/meta.pkl", "wb") as f: 105 | pickle.dump(meta, f) 106 | except Exception as e: 107 | print(f"Failed {neuron_idx}: {e}") 108 | 109 | def generate_all_sweep_tensor(self, neuron_idx, target_sampling_rate=10000, start_s=1.02, end_s=1.3): 110 | sweep_idxs = self.info_df.loc[neuron_idx]["long_square"] 111 | sweep_dict = {} 112 | 113 | for sweep_idx in sweep_idxs: 114 | i, v, spike_times = self.generate_sweep_tensor(neuron_idx, sweep_idx, target_sampling_rate, start_s, end_s) 115 | sweep_dict[i] = (i, v, spike_times) 116 | 117 | # Sort from lowest to highest current 118 | sweep_dict = {k: v for k, v in sorted(sweep_dict.items(), key=lambda item: item[0])} 119 | 120 | meta = pd.DataFrame([{"i": v[0], "spikes": v[2]} for k, v in sweep_dict.items()]) 121 | v = torch.stack([sweep_dict[key][1] for key in sweep_dict.keys()]) 122 | 123 | return meta, v 124 | 125 | def generate_sweep_tensor(self, neuron_idx, sweep_number, target_sampling_rate=10000, start_s=1.02, end_s=1.3): 126 | data_set = self.ctc.get_ephys_data(neuron_idx) 127 | sweep_data = data_set.get_sweep(sweep_number) 128 | 129 | index_range = sweep_data["index_range"] 130 | i = sweep_data["stimulus"][0:index_range[1]+1] # in A 131 | v = sweep_data["response"][0:index_range[1]+1] # in V 132 | i *= 1e12 # to pA 133 | v *= 1e3 # to mV 134 | 135 | sampling_rate = int(sweep_data["sampling_rate"]) # in Hz 136 | t = np.arange(0, len(v)) * (1.0 / sampling_rate) 137 | downsample_factor = sampling_rate / target_sampling_rate 138 | 139 | sweep_ext = EphysSweepFeatureExtractor(t=t, v=v, i=i, start=start_s, end=end_s) 140 | sweep_ext.process_spikes() 141 | spike_times = sweep_ext.spike_feature("threshold_t") 142 | 143 | start_idx = int(start_s*sampling_rate) 144 | end_idx = int(end_s*sampling_rate) 145 | 146 | downsampled_v, downsampled_i = [], [] 147 | assert v.shape == i.shape 148 | 149 | idx = 
start_idx 150 | while idx < end_idx: 151 | downsampled_v.append(v[int(idx)]) 152 | downsampled_i.append(i[int(idx)]) 153 | idx += downsample_factor 154 | 155 | return torch.Tensor(downsampled_i), torch.Tensor(downsampled_v), spike_times 156 | 157 | def generate_ephys_df(self): 158 | cells = {cell["id"]: cell for cell in self.ctc.get_cells()} 159 | ephys_features = self.ctc.get_ephys_features() 160 | ephys_df = pd.DataFrame(ephys_features) 161 | ephys_df['id'] = pd.Series([idx for idx in ephys_df['specimen_id']], index=ephys_df.index) 162 | ephys_df['species'] = pd.Series([cells[idx]['species'] for idx in ephys_df['specimen_id']], index=ephys_df.index) 163 | ephys_df['dendrite_type'] = pd.Series([cells[idx]['dendrite_type'] for idx in ephys_df['specimen_id']], index=ephys_df.index) 164 | ephys_df['structure_layer_name'] = pd.Series([cells[idx]['structure_layer_name'] for idx in ephys_df['specimen_id']], index=ephys_df.index) 165 | ephys_df['disease_state'] = pd.Series([cells[idx]['disease_state'] for idx in ephys_df['specimen_id']], index=ephys_df.index) 166 | query = ephys_df["structure_layer_name"] == "4" 167 | query &= ephys_df["species"] == "Mus musculus" 168 | query &= ephys_df["disease_state"] == "" 169 | 170 | return ephys_df[query] 171 | 172 | def generate_info_df(self, ephys_df): 173 | info_list = [] 174 | 175 | neuron_idxs = ephys_df["id"].values 176 | 177 | for neuron_idx in neuron_idxs: 178 | # Sweep info 179 | sweeps = self.ctc.get_ephys_sweeps(neuron_idx) 180 | sweep_numbers = defaultdict(list) 181 | for sweep in sweeps: 182 | sweep_numbers[sweep['stimulus_name']].append(sweep['sweep_number']) 183 | 184 | neuron_type = ephys_df[ephys_df["id"] == neuron_idx]["dendrite_type"].values[0] 185 | vrest = ephys_df[ephys_df["id"] == neuron_idx]["vrest"].values[0] 186 | vthresh = ephys_df[ephys_df["id"] == neuron_idx]["threshold_v_long_square"].values[0] 187 | 188 | info_list.append({"idx": neuron_idx, "type": neuron_type, "long_square": sweep_numbers.get("Long 
Square"), "noise1": sweep_numbers.get("Noise 1"), "noise2": sweep_numbers.get("Noise 2"), "test": sweep_numbers.get("Test"), "vrest": vrest, "vthresh": vthresh}) 189 | 190 | return pd.DataFrame(info_list) 191 | 192 | 193 | class NoiseBuilder(Builder): 194 | 195 | def build(self, path, **kwargs): 196 | for neuron_idx in self.info_df.index: 197 | print(f"Building {neuron_idx}...") 198 | try: 199 | if not os.path.exists(f"{path}/{neuron_idx}"): 200 | os.makedirs(f"{path}/{neuron_idx}") 201 | 202 | i, v, s = self.generate_all_sweep_tensor(neuron_idx, **kwargs) 203 | 204 | # Default used for all experiments 205 | if kwargs.get("target_sampling_rate") is None: 206 | torch.save(i, f"{path}/{neuron_idx}/i.pt") 207 | torch.save(v, f"{path}/{neuron_idx}/v.pt") 208 | torch.save(s, f"{path}/{neuron_idx}/s.pt") 209 | # Re-ran some experiments with DT=0.05ms as requested by one reviewer 210 | elif kwargs.get("target_sampling_rate") == 20000: 211 | torch.save(i, f"{path}/{neuron_idx}/i_20k.pt") 212 | torch.save(v, f"{path}/{neuron_idx}/v_20k.pt") 213 | torch.save(s, f"{path}/{neuron_idx}/s_20k.pt") 214 | 215 | except Exception as e: 216 | print(f"Failed {neuron_idx}: {e}") 217 | 218 | def generate_all_sweep_tensor(self, neuron_idx, target_sampling_rate=10000, noise_type="noise1"): 219 | sweep_idxs = self.info_df.loc[neuron_idx][noise_type] 220 | assert len(sweep_idxs) is not None 221 | 222 | i_list = [] 223 | v_list = [] 224 | s_list = [] 225 | 226 | for sweep_idx in sweep_idxs: 227 | i, v, s = self.generate_noise_sweep_tensor(neuron_idx, sweep_idx, target_sampling_rate) 228 | i_list.append(i) 229 | v_list.append(v) 230 | s_list.append(s) 231 | 232 | return torch.stack(i_list), torch.stack(v_list), torch.stack(s_list) 233 | 234 | def generate_noise_sweep_tensor(self, neuron_idx, sweep_number, target_sampling_rate=10000): 235 | i1, v1, t1 = self.generate_sweep_tensor(neuron_idx, sweep_number, target_sampling_rate, start_s=2, end_s=5) 236 | i2, v2, t2 = 
self.generate_sweep_tensor(neuron_idx, sweep_number, target_sampling_rate, start_s=10, end_s=13) 237 | i3, v3, t3 = self.generate_sweep_tensor(neuron_idx, sweep_number, target_sampling_rate, start_s=18, end_s=21) 238 | t1 -= 2 239 | t2 -= 10 240 | t3 -= 18 241 | s1 = NoiseBuilder.to_spike_target(t1, target_sampling_rate) 242 | s2 = NoiseBuilder.to_spike_target(t2, target_sampling_rate) 243 | s3 = NoiseBuilder.to_spike_target(t3, target_sampling_rate) 244 | 245 | return torch.stack([i1, i2, i3]), torch.stack([v1, v2, v3]), torch.stack([s1, s2, s3]) 246 | 247 | @staticmethod 248 | def to_spike_target(spike_times, target_sampling_rate): 249 | dt = 0.0001 250 | spike_target = torch.zeros(target_sampling_rate) 251 | spike_idx = [int(spike_time // dt) for spike_time in spike_times if int(spike_time // dt) < target_sampling_rate] 252 | spike_target[spike_idx] = 1 253 | 254 | return spike_target 255 | -------------------------------------------------------------------------------- /src/datasets/neuromorphic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tables 3 | 4 | import torch 5 | import numpy as np 6 | from brainbox.datasets import BBDataset 7 | 8 | from src.datasets.transforms import SpikeTensorBuilder 9 | 10 | 11 | class H5Dataset(BBDataset): 12 | 13 | def __init__(self, root, train, n_in, n_out, t_len, train_name, test_name, dt): 14 | self._file = None 15 | self._train_name = train_name 16 | self._test_name = test_name 17 | super().__init__(root, train, lambda dataset: H5Dataset.preprocess(dataset), SpikeTensorBuilder(n_in, t_len, dt)) 18 | self._file.close() 19 | 20 | self._n_in = n_in 21 | self._n_out = n_out 22 | self._t_len = t_len 23 | self._dt = dt 24 | 25 | @property 26 | def hyperparams(self): 27 | return {**super().hyperparams, "t_len": self._t_len, "dt": self._dt} 28 | 29 | @property 30 | def n_in(self): 31 | return self._n_in 32 | 33 | @property 34 | def n_out(self): 35 | return self._n_out 
36 | 37 | @property 38 | def t_len(self): 39 | return self._t_len 40 | 41 | @property 42 | def dt(self): 43 | return self._dt 44 | 45 | @staticmethod 46 | def preprocess(dataset): 47 | processed_dataset = [] 48 | units, times = dataset 49 | 50 | for i in range(len(units)): 51 | item_units = torch.Tensor(np.array(units[i], dtype=np.int)) 52 | item_times = torch.Tensor(np.array(times[i], dtype=np.float)) 53 | processed_dataset.append((item_units, item_times)) 54 | 55 | return processed_dataset 56 | 57 | @staticmethod 58 | def _open_file(hdf5_file_path): 59 | fileh = tables.open_file(hdf5_file_path, mode="r") 60 | units = fileh.root.spikes.units 61 | times = fileh.root.spikes.times 62 | labels = fileh.root.labels 63 | 64 | return fileh, units, times, labels 65 | 66 | def _load_dataset(self, train): 67 | name = self._train_name if train else self._test_name 68 | fileh, units, times, labels = H5Dataset._open_file(os.path.join(self._root, name)) 69 | targets = torch.Tensor(labels) 70 | self._file = fileh 71 | 72 | return (units, times), targets 73 | 74 | 75 | class NMNISTDataset(H5Dataset): 76 | 77 | T_LEN = 400 78 | 79 | def __init__(self, root, train=True, dt=1): 80 | t_len = int(NMNISTDataset.T_LEN / dt) 81 | super().__init__(root, train, n_in=1156, n_out=10, t_len=t_len, train_name="train.h5", test_name="test.h5", dt=dt) 82 | 83 | @property 84 | def name(self): 85 | return "nmnist" 86 | 87 | 88 | class SHDDataset(H5Dataset): 89 | 90 | T_LEN = 1200 91 | 92 | def __init__(self, root, train=True, dt=2): 93 | t_len = int(SHDDataset.T_LEN / dt) 94 | super().__init__(root, train, n_in=700, n_out=20, t_len=t_len, train_name="shd_train.h5", test_name="shd_test.h5", dt=dt) 95 | 96 | @property 97 | def name(self): 98 | return "shd" 99 | -------------------------------------------------------------------------------- /src/datasets/synthetic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.distributions.poisson 
import Poisson 3 | from brainbox.datasets import BBDataset 4 | 5 | 6 | class SyntheticSpikes(BBDataset): 7 | 8 | """ 9 | This is the synthetic spike dataset with which the model was benchmarked. 10 | """ 11 | 12 | def __init__(self, t_len, n_units, min_r, max_r, n_samples): 13 | super().__init__(None) 14 | self.t_len = t_len 15 | self.n_units = n_units 16 | self.min_r = min_r 17 | self.max_r = max_r 18 | self.n_samples = n_samples 19 | 20 | def __getitem__(self, i): 21 | rate = torch.FloatTensor(1).uniform_(self.min_r, self.max_r).item() 22 | x = self._create_spikes(rate, self.n_units, self.t_len) 23 | 24 | return x 25 | 26 | def __len__(self): 27 | return self.n_samples 28 | 29 | def _load_dataset(self, train): 30 | return None, None 31 | 32 | def _create_spikes(self, rate, n_units, t_len): 33 | pois_dis = Poisson(rate/t_len) 34 | if type(n_units) == tuple: 35 | samples = pois_dis.sample(sample_shape=(*n_units, t_len)) 36 | else: 37 | samples = pois_dis.sample(sample_shape=(n_units, t_len)) 38 | samples[samples > 1] = 1 39 | 40 | return samples -------------------------------------------------------------------------------- /src/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import brainbox 4 | from brainbox.datasets.transforms import BBTransform 5 | 6 | 7 | class SpikeTensorBuilder(BBTransform): 8 | 9 | def __init__(self, n_units, t_len, dt): 10 | self._n_units = n_units 11 | self._t_len = t_len 12 | self._dt = dt 13 | 14 | def __call__(self, args): 15 | units, times = args[0], args[1] 16 | units = units % self._n_units 17 | times = torch.round(times * 1000. 
/ self._dt).int() 18 | 19 | # Constrain spike length 20 | idxs = (times < self._t_len) 21 | units = units[idxs] 22 | times = times[idxs] 23 | 24 | # Build COO tensor 25 | indices = torch.stack([torch.Tensor(units.tolist()), torch.Tensor(times.tolist())], dim=0).long() 26 | shape = torch.Size([self._n_units, self._t_len, ]) 27 | spikes = torch.FloatTensor(np.ones(len(indices[0]))) 28 | 29 | return torch.sparse.FloatTensor(indices, spikes, shape).to_dense() 30 | 31 | 32 | class List: 33 | 34 | @staticmethod 35 | def get_nmnist_transform(t_len, use_augmentation=False): 36 | if use_augmentation: 37 | raise NotImplementedError 38 | else: 39 | transform_list = [SpikeTensorBuilder(n_units=1156, t_len=t_len, dt=1)] 40 | 41 | return brainbox.datasets.transforms.Compose(transform_list) 42 | 43 | @staticmethod 44 | def get_shd_transform(t_len): 45 | transform_list = [SpikeTensorBuilder(n_units=700, t_len=t_len, dt=2)] 46 | 47 | return brainbox.datasets.transforms.Compose(transform_list) 48 | -------------------------------------------------------------------------------- /src/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from brainbox.physiology.spiking import SpikeToPSTH 4 | 5 | from src import datasets, train 6 | 7 | 8 | class SpikeTrainEV: 9 | 10 | def __init__(self, st_len, sig): 11 | self._spike_smoother = SpikeToPSTH(st_len, sig) 12 | 13 | def __call__(self, input, target): 14 | # input: b x 3 x t 15 | smooth_input_trains = self.smooth_spike_trains(input) 16 | smooth_target_trains = self.smooth_spike_trains(target) 17 | flatten_input_trains = self.flatten_spike_trains(smooth_input_trains) 18 | flatten_target_trains = self.flatten_spike_trains(smooth_target_trains) 19 | 20 | return SpikeTrainEV.ev(flatten_input_trains, flatten_target_trains.mean(0).unsqueeze(0)).mean() 21 | 22 | def smooth_spike_trains(self, spike_trains): 23 | # spike_trains: b x 3 x t 24 | b = 
spike_trains.shape[0] 25 | 26 | return self._spike_smoother(spike_trains.view(b*3, -1).unsqueeze(1))[:, 0].view(b, 3, -1) 27 | 28 | def flatten_spike_trains(self, spike_trains): 29 | # spike_trains: b x 3 x t 30 | return spike_trains.flatten(1, 2) 31 | 32 | @staticmethod 33 | def ev(input, target): 34 | # input: b x t 35 | return (input.var(1) + target.var(1) - (input - target).var(1)) / (input.var(1) + target.var(1)) 36 | 37 | 38 | class EphysAnalysis: 39 | 40 | def __init__(self, root, method, abs_refac_ms, downsample, taus=[100]): 41 | self.root = root 42 | self.method = method 43 | self.abs_refac_ms = abs_refac_ms 44 | self.downsample = downsample 45 | self.taus = taus 46 | 47 | self.metrics = {tau: SpikeTrainEV(10000, tau) for tau in taus} 48 | self.neuron_idxs = datasets.ValidNeuronQuery(f"{root}/data/ephys").idx 49 | 50 | # Init test dataset 51 | self.test_dataset = {} 52 | self.norm_factors = {tau: {} for tau in taus} 53 | 54 | for neuron_idx in self.neuron_idxs: 55 | if downsample == 0.5: 56 | self.test_dataset[neuron_idx] = datasets.EphysDataset(f"{root}/data/ephys", "test", int(neuron_idx), target_sampling_rate=20000) 57 | else: 58 | self.test_dataset[neuron_idx] = datasets.EphysDataset(f"{root}/data/ephys", "test", int(neuron_idx)) 59 | target_spikes = self.test_dataset[neuron_idx].s.cpu() 60 | for tau in taus: 61 | self.norm_factors[tau][neuron_idx] = self.metrics[tau](target_spikes, target_spikes).item() 62 | 63 | self._norm_df = {tau: pd.Series(self.norm_factors[tau]).to_frame().rename(columns={0: "score"}) for tau in taus} 64 | self._ev_df = None 65 | 66 | def ev_df(self, tau, normalise): 67 | if self._ev_df is None: 68 | self._ev_df = self._build_ev_df().set_index("neuron_idx") 69 | 70 | query = self._ev_df["tau"] == tau 71 | ev_df = self._ev_df[query]["score"].to_frame() 72 | 73 | if normalise: 74 | df = ev_df / self._norm_df[tau] 75 | else: 76 | df = ev_df 77 | 78 | df.index = df.index.map(int) 79 | 80 | return df 81 | 82 | def 
get_times_df(self): 83 | dir_name = f"{self.method}_{self.abs_refac_ms}_{self.downsample}" 84 | 85 | times_list = [] 86 | 87 | for neuron_idx in self.neuron_idxs: 88 | times_csv = pd.read_csv(f"{self.root}/results/ephys/{dir_name}/{neuron_idx}/times.csv") 89 | forward_pass, backward_pass = times_csv.sum() 90 | times_list.append({"neuron_idx": neuron_idx, "forward_pass": forward_pass, "backward_pass": backward_pass}) 91 | 92 | return pd.DataFrame(times_list) 93 | 94 | def _build_ev_df(self): 95 | dir_name = f"{self.method}_{self.abs_refac_ms}_{self.downsample}" 96 | 97 | metric_list = [] 98 | 99 | for i, neuron_idx in enumerate(self.neuron_idxs): 100 | print(f"Building {i}/{len(self.neuron_idxs)} {neuron_idx}...") 101 | dt01ref = self.downsample == 0.5 102 | neuron = train.EphysTrainer.load_model(f"{self.root}/results/ephys/{dir_name}", neuron_idx, dt01ref=dt01ref) 103 | 104 | with torch.no_grad(): 105 | test_dataset = self.test_dataset[neuron_idx] 106 | pred_spikes = neuron(test_dataset.i[0].unsqueeze(1).cuda()).permute(1, 0, 2).cpu() 107 | target_spikes = test_dataset.s.cpu() 108 | 109 | for tau in self.taus: 110 | score = self.metrics[tau](target_spikes, pred_spikes).item() # note: model spikes are reference 111 | metric_list.append({"neuron_idx": neuron_idx, "score": score, "tau": tau}) 112 | 113 | return pd.DataFrame(metric_list) 114 | 115 | def load_prediction(self, neuron_idx): # Load prediction for a certain fit and neuron 116 | test_dataset = self.test_dataset[str(neuron_idx)] 117 | 118 | with torch.no_grad(): 119 | dir_name = f"{self.method}_{self.abs_refac_ms}_{self.downsample}" 120 | dt01ref = self.downsample == 0.5 121 | neuron = train.EphysTrainer.load_model(f"{self.root}/results/ephys/{dir_name}", neuron_idx, dt01ref=dt01ref) 122 | 123 | output = neuron(test_dataset.i[0].unsqueeze(1).cuda(), mode="val") 124 | spikes = output[0].permute(1, 0, 2).cpu() 125 | mem = output[1].permute(1, 0, 2).cpu() 126 | 127 | return spikes, mem, test_dataset.s, 
test_dataset.v, test_dataset.i 128 | 129 | def load_params_df(self): 130 | dir_name = f"{self.method}_{self.abs_refac_ms}_{self.downsample}" 131 | 132 | params_list = [] 133 | 134 | for i, neuron_idx in enumerate(self.neuron_idxs): 135 | print(f"Building {i}/{len(self.neuron_idxs)} {neuron_idx}...") 136 | neuron = train.EphysTrainer.load_model(f"{self.root}/results/ephys/{dir_name}", neuron_idx).neuron 137 | 138 | params_list.append({"neuron_idx": neuron_idx, "beta": neuron.beta.item(), "p": neuron.p.item(), "b": neuron.b.item()}) 139 | 140 | return pd.DataFrame(params_list) 141 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from brainbox import models 6 | 7 | import src.snn.block.blocks as blocks 8 | import src.snn.snn as snn 9 | 10 | 11 | class AuditoryModel(models.BBModel): 12 | 13 | # AuditoryModel name comes from the SHD dataset being auditory 14 | 15 | HIDDEN_MEM_TIME = 20 16 | HIDDEN_ADAPT_TIME = 150 17 | READOUT_MEM_TIME = 20 18 | 19 | def __init__(self, method, n_in, n_hidden, n_out, t_len, abs_refac, recurrent=True, dt=1, surr_grad="fast_sigmoid", detach_spike_grad=True): 20 | super().__init__() 21 | self._method = method 22 | self._n_in = n_in 23 | self._n_hidden = n_hidden 24 | self._n_out = n_out 25 | self._t_len = t_len 26 | self._abs_refac = abs_refac 27 | self._recurrent = recurrent 28 | self._dt = dt 29 | self._surr_grad = surr_grad 30 | 31 | init_hidden_beta = np.exp(-dt / AuditoryModel.HIDDEN_MEM_TIME) 32 | init_hidden_p = np.exp(-dt / AuditoryModel.HIDDEN_ADAPT_TIME) 33 | init_readout_beta = np.exp(-dt / AuditoryModel.READOUT_MEM_TIME) 34 | 35 | if method == "standard": 36 | self._thalamic_layer = snn.SNN(n_in, n_hidden, 1, t_len, abs_refac, recurrent, beta_grad=True, adapt=True, init_beta=init_hidden_beta, 
init_p=init_hidden_p, surr_grad=self._surr_grad) 37 | self._cortical_layer = snn.SNN(n_hidden, n_hidden, 1, t_len, abs_refac, recurrent, beta_grad=True, adapt=True, init_beta=init_hidden_beta, init_p=init_hidden_p, surr_grad=self._surr_grad) 38 | self._output = snn.SNNIntegrator(n_hidden, n_out, t_len, init_beta=init_readout_beta) 39 | else: 40 | self._thalamic_layer = blocks.Blocks(n_in, n_hidden, 1, t_len, abs_refac, recurrent, beta_grad=True, adapt=True, init_beta=init_hidden_beta, init_p=init_hidden_p, surr_grad=self._surr_grad, detach_spike_grad=detach_spike_grad) 41 | self._cortical_layer = blocks.Blocks(n_hidden, n_hidden, 1, t_len, abs_refac, recurrent, beta_grad=True, adapt=True, init_beta=init_hidden_beta, init_p=init_hidden_p, surr_grad=self._surr_grad, detach_spike_grad=detach_spike_grad) 42 | self._output = blocks.BlocksIntegrator(n_hidden, n_out, t_len, init_beta=init_readout_beta) 43 | 44 | @property 45 | def hyperparams(self): 46 | return {**super().hyperparams, "method": self._method, "n_in": self._n_in, "n_hidden": self._n_hidden, "n_out": self._n_out, "t_len": self._t_len, "abs_refac": self._abs_refac, "recurrent": self._recurrent, "dt": self._dt, "surr_grad": self._surr_grad} 47 | 48 | def forward(self, x, mode="train"): 49 | # x: b x n x t 50 | thalamic_output = self._thalamic_layer(x, mode) 51 | cortical_output = self._cortical_layer(thalamic_output if mode == "train" else thalamic_output[0], mode) 52 | 53 | if mode == "train": 54 | return self._output(cortical_output, mode).sum(2) 55 | else: 56 | return self._output(cortical_output[0], mode).sum(2), cortical_output, thalamic_output 57 | 58 | 59 | class ModelBuilder(models.BBModel): 60 | 61 | def __init__(self, method, t_len, abs_refac, n_in, n_hidden, n_layers): 62 | super().__init__() 63 | self._layers = nn.ModuleList() 64 | 65 | for i in range(n_layers): 66 | n_in = n_in if i == 0 else n_hidden 67 | if method == "standard": 68 | self._layers.append(snn.SNN(n_in, n_hidden, 1, t_len, 
abs_refac, recurrent=True, beta_grad=True, adapt=True, init_beta=0.9, init_p=0.9, surr_grad="mg")) 69 | else: 70 | self._layers.append(blocks.Blocks(n_in, n_hidden, 1, t_len, abs_refac, recurrent=True, beta_grad=True, adapt=True, init_beta=0.9, init_p=0.9, surr_grad="mg")) 71 | 72 | def forward(self, x): 73 | for layer in self._layers: 74 | x = layer(x) 75 | 76 | return x 77 | 78 | 79 | class Neuron(models.BBModel): 80 | 81 | def __init__(self, method, abs_refac_ms, downsample=1, dt01ref=False): 82 | super().__init__() 83 | self.method = method 84 | self.abs_refac_ms = abs_refac_ms 85 | self.downsample = downsample 86 | self.dt01ref = dt01ref 87 | self.dt_ms = downsample * 0.1 # Larger dt_ms == downsample temporal resolution 88 | 89 | if not dt01ref: 90 | self.neuron = Neuron.get_neuron(method, abs_refac_ms, self.dt_ms, downsample) 91 | else: 92 | self.neuron = Neuron.get_neuron(method, abs_refac_ms, 0.5*self.dt_ms, 0.5*downsample) 93 | upsample_kernel = torch.zeros(downsample) 94 | upsample_kernel[0] = 1 95 | self.upsample_kernel = nn.Parameter(upsample_kernel.view(1, 1, -1), requires_grad=False) 96 | 97 | @property 98 | def hyperparams(self): 99 | return {**super().hyperparams, "method": self.method, "abs_refac_ms": self.abs_refac_ms, "downsample": self.downsample, "dt01ref": self.dt01ref} 100 | 101 | def forward(self, x, mode="train"): 102 | x = F.avg_pool1d(x, self.downsample, self.downsample) # Down sample the signal (if need be) 103 | spikes = self.neuron(x, mode) 104 | 105 | # Return data for plotting 106 | if mode == "val": 107 | spikes, mem = spikes[0], spikes[1] 108 | 109 | return spikes, mem 110 | 111 | # Return predicted spike train for training 112 | if self.dt01ref: # When running in DT=0.05ms (was added on to run additional experiments for a reviewer) 113 | spikes = F.max_pool1d(spikes, 2, 2) 114 | return spikes 115 | 116 | if self.downsample == 1: 117 | return spikes 118 | else: 119 | return F.conv_transpose1d(spikes, self.upsample_kernel, 
stride=self.downsample) 120 | 121 | @staticmethod 122 | def get_neuron(method, abs_refac_ms, dt_ms, downsample): 123 | init_beta = np.exp(-dt_ms / 20) 124 | init_p = np.exp(-dt_ms / 100) 125 | t_len = int(1000 / dt_ms) 126 | abs_refac_ms = int(abs_refac_ms / dt_ms) 127 | 128 | if method == "blocks": 129 | neuron = blocks.Blocks(1, 1, rf_len=1, t_len=t_len, t_latency=abs_refac_ms, recurrent=False, beta_grad=True, adapt=True, init_beta=init_beta, init_p=init_p, detach_spike_grad=True, surr_grad="mg") 130 | else: 131 | neuron = snn.SNN(1, 1, rf_len=1, t_len=t_len, t_latency=abs_refac_ms, recurrent=False, beta_grad=True, adapt=True, init_beta=init_beta, init_p=init_p, detach_spike_grad=True, surr_grad="mg") 132 | 133 | neuron.init_weight(neuron._rf_weight, "constant", c=downsample) 134 | neuron._b = nn.Parameter(data=torch.Tensor([0.1 / downsample]), requires_grad=True) 135 | 136 | return neuron 137 | -------------------------------------------------------------------------------- /src/query.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import pandas as pd 4 | from brainbox import trainer 5 | 6 | from src import models, train 7 | 8 | 9 | models.snn.BaseSNN.MIN_BETA = 0.01 10 | models.snn.BaseSNN.MAX_BETA = 0.99 11 | 12 | 13 | class BenchmarkQuery: 14 | 15 | def __init__(self, root, batches=[32, 64, 128]): 16 | self._root = root 17 | 18 | self._results_df = self._build_df() 19 | self._results_df = pd.concat([self._query_results(batch=b) for b in batches]) 20 | 21 | def _build_df(self): 22 | results_df_list = [] 23 | 24 | for path in self._get_paths(self._root): 25 | results_df_list.append(pd.read_csv(path)) 26 | 27 | results_df = pd.concat(results_df_list) 28 | results_df["total_time"] = results_df["forward_time"] + results_df["backward_time"] 29 | 30 | return results_df 31 | 32 | def _get_paths(self, root): 33 | return [path for path in glob.glob(f"{root}/*")] 34 | 35 | def _query_results(self, **kwargs): 
36 | query = True 37 | for key, value in kwargs.items(): 38 | query &= self._results_df[key] == value 39 | 40 | if len(kwargs) > 0: 41 | return self._results_df[query] 42 | 43 | return self._results_df 44 | 45 | def get_speedup(self): 46 | results_df = self._build_df() 47 | standard_times = results_df[results_df["method"] == "standard"].set_index(["t_len", "units", "batch", "abs_refac", "layers"])[["forward_time", "backward_time", "total_time"]] 48 | blocks_times = results_df[results_df["method"] == "blocks"].set_index(["t_len", "units", "batch", "abs_refac", "layers"])[["forward_time", "backward_time", "total_time"]] 49 | 50 | speedup_df = standard_times / blocks_times 51 | speedup_df.rename(columns={"forward_time": "forward_speedup", "backward_time": "backward_speedup", "total_time": "total_speedup"}, inplace=True) 52 | 53 | return speedup_df 54 | 55 | 56 | class SupervisedQuery: 57 | 58 | def __init__(self, root): 59 | self.root = root 60 | 61 | def get_average_duration_per_batch(self, models_root, model_id): 62 | durations_list = [] 63 | 64 | duration = trainer.load_log(models_root, model_id)["duration"][1:].mean() 65 | durations_list.append({"model_id": model_id, "duration": duration}) 66 | 67 | return pd.DataFrame(durations_list).set_index("model_id").values[0][0] 68 | 69 | def build_results(self, dataset, methods, sgs, abs_refacs, repeats, detach=True, batch_size=500): 70 | results_list = [] 71 | 72 | for method in methods: 73 | for sg in sgs: 74 | for abs_refac in abs_refacs: 75 | for i in range(repeats): 76 | if detach: 77 | name = f"{dataset.name}_{method}_{sg}_{abs_refac}_{dataset.dt}_{i}" 78 | else: 79 | name = f"{dataset.name}_{method}_{sg}_{abs_refac}_{dataset.dt}_{detach}_{i}" 80 | print(f"Loading {name}...") 81 | model = train.Trainer.load_model(f"{self.root}/results/supervised", name) 82 | val_acc = train.Trainer.get_acc(model, dataset, batch_size) 83 | avg_time = self.get_average_duration_per_batch(f"{self.root}/results/supervised", name) 84 | 
results_list.append({"dataset": dataset.name, "method": method, "sg": sg, "abs_refac": abs_refac, "dt": dataset.dt, "i": i, "val_acc": val_acc, "avg_time": avg_time}) 85 | 86 | return pd.DataFrame(results_list) -------------------------------------------------------------------------------- /src/snn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/src/snn/__init__.py -------------------------------------------------------------------------------- /src/snn/block/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/src/snn/block/__init__.py -------------------------------------------------------------------------------- /src/snn/block/block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from src.snn.block.util import bconv1d 6 | from src.snn import surrogate 7 | 8 | 9 | class Block(nn.Module): 10 | 11 | def __init__(self, n_in, t_len, surr_grad): 12 | super().__init__() 13 | self._n_in = n_in 14 | self._t_len = t_len 15 | self._surr_grad = surr_grad 16 | 17 | self._beta_ident_base = nn.Parameter(torch.ones(n_in, t_len), requires_grad=False) 18 | self._beta_exp = nn.Parameter(torch.arange(t_len).flip(0).unsqueeze(0).expand(n_in, t_len).float(), requires_grad=False) 19 | self._phi_kernel = nn.Parameter((torch.arange(t_len) + 1).flip(0).float().view(1, 1, 1, t_len), requires_grad=False) 20 | 21 | @staticmethod 22 | def g(faulty_spikes): 23 | negate_faulty_spikes = faulty_spikes.clone().detach() 24 | negate_faulty_spikes[faulty_spikes == 1.0] = 0 25 | faulty_spikes -= negate_faulty_spikes 26 | 27 | return faulty_spikes 28 | 29 | def forward(self, current, beta, v_init=None, v_th=1, 
mode="train"): 30 | 31 | if v_init is not None: 32 | current[:, :, 0] += beta * v_init 33 | 34 | pad_current = F.pad(current, pad=(self._t_len - 1, 0)).unsqueeze(1) 35 | 36 | # compute membrane potential without reset 37 | beta_kernel = self.build_beta_kernel(beta) 38 | membrane = bconv1d(pad_current, beta_kernel) 39 | 40 | # map no-reset membrane potentials to output spikes 41 | v_th = v_th.unsqueeze(1) 42 | faulty_spikes = surrogate.spike(membrane - v_th, self._surr_grad) 43 | 44 | pad_spikes = F.pad(faulty_spikes, pad=(self._t_len - 1, 0)) 45 | z = F.conv2d(pad_spikes, self._phi_kernel) 46 | z_copy = z.clone().squeeze(1) 47 | 48 | if mode == "train": 49 | return Block.g(z).squeeze(1), z_copy, membrane.squeeze(1) 50 | elif mode == "val": 51 | return Block.g(z).squeeze(1), z_copy, faulty_spikes, membrane.squeeze(1) 52 | 53 | def build_beta_kernel(self, beta): 54 | beta_base = beta.unsqueeze(1).multiply(self._beta_ident_base) 55 | return torch.pow(beta_base, self._beta_exp).unsqueeze(1).unsqueeze(1) 56 | -------------------------------------------------------------------------------- /src/snn/block/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from src.snn.snn import BaseSNN 7 | from src.snn.block.block import Block 8 | from src.snn.block.util import time_cat, bconv1d 9 | 10 | 11 | class Blocks(BaseSNN): 12 | 13 | def __init__(self, n_in, n_out, rf_len, t_len, t_latency, recurrent=True, beta_grad=True, adapt=True, init_beta=1, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid"): 14 | super().__init__(n_in, n_out, rf_len, t_len, t_latency, recurrent, beta_grad, adapt, init_beta, init_p, detach_spike_grad, surr_grad) 15 | 16 | self._t_len_block = t_latency + 1 17 | self._block = Block(n_out, self._t_len_block, surr_grad) 18 | self._n_blocks = math.ceil(t_len / self._t_len_block) 19 | self._t_pad = self._n_blocks 
* self._t_len_block - self._t_len 20 | 21 | self._p_ident_base = nn.Parameter(torch.ones(n_out, self._t_len_block), requires_grad=False) 22 | self._p_exp = nn.Parameter(torch.arange(1, self._t_len_block + 1).float(), requires_grad=False) 23 | 24 | def process(self, x, mode="train"): 25 | x_init = x 26 | if self._t_pad != 0: 27 | x = F.pad(x, pad=(0, self._t_pad)) 28 | 29 | mem_list = [] 30 | spikes_list = [] 31 | z_list = [] 32 | 33 | z = torch.zeros_like(x[:, :, self._t_len_block:]) 34 | v_init = torch.zeros_like(x[:, :, 0]).to(x.device) 35 | int_mem = torch.zeros_like(x[:, :, 0]).to(x.device) 36 | 37 | a_kernel = torch.zeros_like(x).to(x.device)[:, :, :self._t_len_block] 38 | v_th = torch.ones_like(x).to(x.device)[:, :, :self._t_len_block] 39 | v_th_list = [] 40 | 41 | for i in range(self._n_blocks): 42 | x_slice = x[:, :, i * self._t_len_block: (i+1) * self._t_len_block] 43 | 44 | # Recurrent current and refractory mask only included after first block 45 | if i > 0: 46 | # Add recurrent current to input 47 | if self._recurrent: 48 | rec_current = self.get_rec_input(spikes) 49 | x_slice = x_slice + rec_current 50 | 51 | # Apply refractory mask to input 52 | if self._detach_spike_grad: 53 | spike_mask = spikes.detach().amax(dim=2).bool() 54 | else: 55 | spike_mask = spikes.amax(dim=2).bool() 56 | refac_mask = (z < spike_mask.unsqueeze(2)) * x_slice 57 | x_slice -= refac_mask 58 | 59 | # Set initial membrane potentials 60 | v_init = int_mem[:, :, -1] * ~spike_mask # if spiked -> zero initial membrane potential 61 | 62 | # Set initial adaptive params 63 | if self._adapt: 64 | # Get a at time of spike + spike (which is equal to 1/p to account for raising v_th by 1 next step 65 | # do the math or see paper if this is not clear) 66 | if self._detach_spike_grad: 67 | a_at_spike = (a_kernel * spikes.detach()).sum(dim=2) + (1 / self.p) 68 | else: 69 | a_at_spike = (a_kernel * spikes).sum(dim=2) + (1 / self.p) 70 | decay_steps = (z > 1).sum(dim=2) # Compute number of decay 
steps 71 | new_a = a_at_spike * torch.pow(self.p.unsqueeze(0), decay_steps) 72 | a = (a_kernel[:, :, -1] * ~spike_mask) + (new_a * spike_mask) 73 | 74 | # Update a for neurons that spiked 75 | a_kernel = self.compute_a_kernel(a, self.p) 76 | v_th = 1 + self.b.view(1, -1, 1) * a_kernel 77 | 78 | if mode == "train": 79 | spikes, z, int_mem = self._block(x_slice, self.beta, v_init=v_init, v_th=v_th, mode="train") 80 | spikes_list.append(spikes) 81 | elif mode == "val": 82 | spikes, z, _, int_mem = self._block(x_slice, self.beta, v_init=v_init, v_th=v_th, mode="val") 83 | spikes_list.append(spikes) 84 | mem_list.append(int_mem) 85 | z_list.append(z) 86 | v_th_list.append(v_th) 87 | 88 | if mode == "train": 89 | return time_cat(spikes_list, self._t_pad) 90 | elif mode == "val": 91 | return time_cat(spikes_list, self._t_pad), time_cat(mem_list, self._t_pad), x_init, time_cat(z_list, self._t_pad), time_cat(v_th_list, self._t_pad) 92 | 93 | def compute_a_kernel(self, a, p): 94 | # a: b x n 95 | # p: n 96 | # output: b x n x t 97 | 98 | return torch.pow(p.unsqueeze(-1) * self._p_ident_base, self._p_exp).unsqueeze(0) * a.unsqueeze(-1) 99 | 100 | 101 | class BlocksIntegrator(BaseSNN): 102 | 103 | def __init__(self, n_in, n_out, t_len, init_beta=1): 104 | super().__init__(n_in, n_out, 1, t_len, t_latency=0, recurrent=False, beta_grad=True, adapt=False, init_beta=init_beta, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid") 105 | self._block = Block(n_out, t_len, "fast_sigmoid") 106 | 107 | def process(self, x, mode="train"): 108 | pad_current = F.pad(x, pad=(self._t_len - 1, 0)).unsqueeze(1) 109 | 110 | # compute membrane potential without reset 111 | beta_kernel = self._block.build_beta_kernel(self.beta) 112 | membrane = bconv1d(pad_current, beta_kernel) 113 | 114 | return membrane.squeeze(1) 115 | -------------------------------------------------------------------------------- /src/snn/block/util.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def bconv1d(x, weight, stride=1, dilation=1, padding=0): 6 | # Would be useful if PyTorch provided batched 1D convs in their library 7 | b, c, n, h = x.shape 8 | n, out_channels, in_channels, kernel_width_size = weight.shape 9 | 10 | out = x.view(b, c * n, h) 11 | weight = weight.view(n * out_channels, in_channels, kernel_width_size) 12 | 13 | out = F.conv1d(out, weight=weight, bias=None, stride=stride, dilation=dilation, groups=n, padding=padding) 14 | 15 | return out.view(b, c, n, -1) 16 | 17 | 18 | def time_cat(tensor_list, t_pad): 19 | tensor = torch.cat(tensor_list, dim=2) 20 | 21 | if t_pad > 0: 22 | tensor = tensor[:, :, :-t_pad] 23 | 24 | return tensor 25 | -------------------------------------------------------------------------------- /src/snn/snn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from brainbox.models import BBModel 6 | 7 | from src.snn import surrogate 8 | 9 | # SNN control models. 
10 | 11 | 12 | class BaseSNN(BBModel): 13 | 14 | MIN_BETA = 0.001 15 | MAX_BETA = 0.999 16 | 17 | def __init__(self, n_in, n_out, rf_len, t_len, t_latency, recurrent=True, beta_grad=True, adapt=True, init_beta=1, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid"): 18 | super().__init__() 19 | self._n_in = n_in 20 | self._n_out = n_out 21 | self._rf_len = rf_len 22 | self._t_len = t_len 23 | self._t_latency = t_latency 24 | self._recurrent = recurrent 25 | self._beta_grad = beta_grad 26 | self._adapt = adapt 27 | self._detach_spike_grad = detach_spike_grad 28 | self._surr_grad = surr_grad 29 | 30 | self._beta = nn.Parameter(data=torch.Tensor(n_out * [init_beta]), requires_grad=beta_grad) 31 | self._rf_weight = nn.Parameter(torch.rand(n_out, 1, n_in, self._rf_len), requires_grad=True) 32 | self._rf_bias = nn.Parameter(torch.zeros(n_out), requires_grad=True) 33 | 34 | self._rec_weight = nn.Parameter(torch.rand(n_out, n_out), requires_grad=recurrent) 35 | 36 | self._p = nn.Parameter(data=torch.Tensor(n_out * [init_p]), requires_grad=adapt) 37 | self._b = nn.Parameter(data=torch.Tensor(n_out * [1.8]), requires_grad=adapt) 38 | 39 | self.init_weight(self._rf_weight, "uniform", a=-1 / np.sqrt(n_in * rf_len), b=1 / np.sqrt(n_in * rf_len)) 40 | self.init_weight(self._rec_weight, "identity") 41 | 42 | @property 43 | def hyperparams(self): 44 | return {**super().hyperparams, "n_in": self._n_in, "n_out": self._n_out, "rf_len": self._rf_len, "t_len": self._t_len, "t_latency": self._t_latency, "recurrent": self._recurrent, "beta_grad": self._beta_grad, "adapt": self._adapt, "detach_spike_grad": self._detach_spike_grad, "surr_grad": self._surr_grad} 45 | 46 | @property 47 | def p(self): 48 | return torch.clamp(self._p.abs(), min=0, max=0.999) 49 | 50 | @property 51 | def b(self): 52 | return torch.clamp(self._b.abs(), min=0.001, max=1) 53 | 54 | @property 55 | def beta(self): 56 | return torch.clamp(self._beta, min=BaseSNN.MIN_BETA, max=BaseSNN.MAX_BETA) 57 | 58 | 
@property 59 | def rec_weight(self): 60 | return self._rec_weight 61 | 62 | def get_rec_input(self, spikes): 63 | return torch.einsum("ij, bj...->bi...", self.rec_weight, spikes.detach() if self._detach_spike_grad else spikes) 64 | 65 | def forward(self, x, mode="train"): 66 | # x: b x n x t 67 | 68 | x = F.pad(x, (self._rf_len - 1, 0)) 69 | x = x.unsqueeze(1) # Add channel dim 70 | x = F.conv2d(x, self._rf_weight, self._rf_bias)[:, :, 0] # Slice out height dim 71 | 72 | return self.process(x, mode) 73 | 74 | def process(self, x, mode): 75 | raise NotImplementedError 76 | 77 | 78 | class SNN(BaseSNN): 79 | 80 | def __init__(self, n_in, n_out, rf_len, t_len, t_latency, recurrent=False, beta_grad=True, adapt=True, init_beta=1, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid"): 81 | super().__init__(n_in, n_out, rf_len, t_len, t_latency, recurrent, beta_grad, adapt, init_beta, init_p, detach_spike_grad, surr_grad) 82 | 83 | def process(self, x, mode="train"): 84 | # x: b x n x t 85 | 86 | mem_list = [] 87 | spikes_list = [] 88 | spikes = torch.zeros_like(x).to(x.device)[:, :, 0] 89 | rec_current = torch.zeros_like(x) 90 | mem = torch.zeros_like(x).to(x.device)[:, :, 0] 91 | refac_times = torch.zeros_like(x).to(x.device)[:, :, 0] + self._t_latency 92 | 93 | v_th = torch.ones_like(x).to(x.device)[:, :, 0] 94 | a = torch.zeros_like(x).to(x.device)[:, :, 0] 95 | v_th_list = [] 96 | 97 | for t in range(x.shape[2]): 98 | stimulus_current = x[:, :, t] 99 | rec_current[:, :, t] = self.get_rec_input(spikes) 100 | 101 | # Recurrent latency 102 | if t >= self._t_latency and self._recurrent: 103 | input_current = stimulus_current + rec_current[:, :, t-self._t_latency] 104 | else: 105 | input_current = stimulus_current 106 | 107 | # Apply absolute refractory period 108 | refac_times[spikes > 0] = 0 109 | refac_mask = refac_times < self._t_latency 110 | input_current[refac_mask] = 0 111 | refac_times += 1 112 | 113 | new_mem = torch.einsum("bn...,n->bn...", mem, 
self.beta) + input_current 114 | spikes = surrogate.spike(new_mem - v_th, self._surr_grad) 115 | 116 | mem_list.append(new_mem) 117 | if self._detach_spike_grad: 118 | mem = new_mem * (1 - spikes.detach()) 119 | else: 120 | mem = new_mem * (1 - spikes) 121 | # new_mem -= new_mem * spikes (should be same as above?) 122 | spikes_list.append(spikes) 123 | 124 | if self._adapt: 125 | a = self.p * a + spikes 126 | v_th = 1 + self.b * a 127 | v_th_list.append(v_th) 128 | 129 | if mode == "train": 130 | return torch.stack(spikes_list, dim=2) 131 | elif mode == "val": 132 | v_th = torch.stack(v_th_list, dim=2) 133 | v_th = torch.roll(v_th, 1, dims=2) 134 | v_th[:, :, :1] = 1 135 | 136 | return torch.stack(spikes_list, dim=2), torch.stack(mem_list, dim=2), x, v_th 137 | 138 | 139 | class SNNIntegrator(SNN): 140 | 141 | def __init__(self, n_in, n_out, t_len, init_beta=1): 142 | super().__init__(n_in, n_out, 1, t_len, t_latency=0, recurrent=False, beta_grad=True, adapt=False, init_beta=init_beta, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid") 143 | 144 | def process(self, x, mode="train"): 145 | mem_list = [] 146 | mem = torch.zeros_like(x).to(x.device)[:, :, 0] 147 | 148 | for t in range(x.shape[2]): 149 | input_current = x[:, :, t] 150 | 151 | new_mem = torch.einsum("bn...,n->bn...", mem, self.beta) + input_current 152 | mem_list.append(new_mem) 153 | mem = new_mem 154 | 155 | return torch.stack(mem_list, dim=2) 156 | -------------------------------------------------------------------------------- /src/snn/surrogate.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | class FastSigmoid(torch.autograd.Function): 7 | 8 | @staticmethod 9 | def forward(ctx, input, scale=10): 10 | ctx.scale = scale 11 | ctx.save_for_backward(input) 12 | 13 | return input.gt(0).float() 14 | 15 | @staticmethod 16 | def backward(ctx, grad_output): 17 | input, = ctx.saved_tensors 18 | grad_input = 
grad_output.clone() 19 | grad = grad_input / (ctx.scale * torch.abs(input) + 1.0) ** 2 20 | 21 | return grad, None 22 | 23 | 24 | class BoxCar(torch.autograd.Function): 25 | 26 | @staticmethod 27 | def forward(ctx, input): 28 | ctx.save_for_backward(input) 29 | 30 | return input.gt(0).float() 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | input, = ctx.saved_tensors 35 | grad = grad_output.clone() 36 | grad[input <= -0.5] = 0 37 | grad[input > 0.5] = 0 38 | 39 | return grad 40 | 41 | 42 | class MG(torch.autograd.Function): 43 | 44 | @staticmethod 45 | def forward(ctx, input): 46 | ctx.save_for_backward(input) 47 | return input.gt(0).float() 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, = ctx.saved_tensors 52 | grad = grad_output.clone() 53 | lens = 0.5 54 | hight = 0.15 55 | scale = 6 56 | gamma = 0.5 57 | 58 | temp = MG.gaussian(input, mu=0., sigma=lens) * (1. + hight) - MG.gaussian(input, mu=lens, sigma=scale * lens) * hight - MG.gaussian(input, mu=-lens, sigma=scale * lens) * hight 59 | 60 | return gamma * grad * temp.float() 61 | 62 | @staticmethod 63 | def gaussian(x, mu=0., sigma=.5): 64 | return torch.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / torch.sqrt(2 * torch.tensor(math.pi)) / sigma 65 | 66 | 67 | def spike(x, type): 68 | if type == "fast_sigmoid": 69 | return FastSigmoid.apply(x) 70 | elif type == "box_car": 71 | return BoxCar.apply(x) 72 | elif type == "mg": 73 | return MG.apply(x) 74 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import logging 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import pandas as pd 10 | from brainbox import trainer 11 | from brainbox.physiology.spiking import VanRossum 12 | 13 | from src import datasets, models 14 | 15 | 16 | torch.backends.cudnn.benchmark = True 
17 | logger = logging.getLogger("trainer") 18 | logging.basicConfig(stream=sys.stdout, level=logging.INFO) 19 | 20 | 21 | class Trainer(trainer.Trainer): 22 | 23 | def __init__(self, root, model, dataset, n_epochs, batch_size, lr, milestones=[-1], gamma=0.1, val_dataset=None, device="cuda", id=None): 24 | super().__init__(root, model, dataset, n_epochs, batch_size, lr, torch.optim.Adam, device=device, optimizer_kwargs={"eps": 1e-5}, loader_kwargs={"shuffle": True, "pin_memory": True, "num_workers": 16}, id=id) 25 | self._milestones = milestones 26 | self._gamma = gamma 27 | self._val_dataset = val_dataset 28 | 29 | self._times = {"forward_pass": [], "backward_pass": []} 30 | self._train_acc = [] 31 | self._val_acc = [] 32 | self._min_loss = np.inf 33 | self._milestone_idx = 0 34 | 35 | @staticmethod 36 | def accuracy_metric(output, target): 37 | _, predictions = torch.max(output, 1) 38 | return (predictions == target).sum().cpu().item() 39 | 40 | @staticmethod 41 | def spike_count(output, target): 42 | _, cortical_output, thalamic_output = output 43 | 44 | count = cortical_output[0].sum().cpu().item() 45 | count += thalamic_output[0].sum().cpu().item() 46 | 47 | return count 48 | 49 | @property 50 | def times_path(self): 51 | return os.path.join(self.root, self.id, "times.csv") 52 | 53 | @property 54 | def train_acc_path(self): 55 | return os.path.join(self.root, self.id, "train_acc.csv") 56 | 57 | @property 58 | def val_acc_path(self): 59 | return os.path.join(self.root, self.id, "val_acc.csv") 60 | 61 | def save_model_log(self): 62 | super().save_model_log() 63 | 64 | # Save times 65 | times_df = pd.DataFrame(self._times) 66 | times_df.to_csv(self.times_path, index=False) 67 | 68 | # Save acc 69 | train_acc_df = pd.DataFrame(self._train_acc) 70 | train_acc_df.to_csv(self.train_acc_path, index=False) 71 | val_acc_df = pd.DataFrame(self._val_acc) 72 | val_acc_df.to_csv(self.val_acc_path, index=False) 73 | 74 | def loss(self, output, target, model): 75 | target = 
target.long() 76 | loss = F.cross_entropy(output, target, reduction="mean") 77 | 78 | return loss 79 | 80 | def train_for_single_epoch(self): 81 | epoch_loss = 0 82 | n_samples = 0 83 | n_correct = 0 84 | 85 | for batch_id, (data, target) in enumerate(self.train_data_loader): 86 | data = data.to(self.device).type(self.dtype) 87 | target = target.to(self.device).type(self.dtype) 88 | torch.cuda.synchronize() 89 | 90 | # Forward pass 91 | start_time = time.time() 92 | output = self.model(data) 93 | torch.cuda.synchronize() 94 | forward_pass_time = time.time() - start_time 95 | self._times["forward_pass"].append(forward_pass_time) 96 | 97 | # Compute accuracy 98 | _, predictions = torch.max(output, 1) 99 | n_correct += (predictions == target).sum().cpu().item() 100 | 101 | # Compute loss 102 | loss = self.loss(output, target, self.model) 103 | 104 | # Backward pass 105 | start_time = time.time() 106 | loss.backward() 107 | torch.cuda.synchronize() 108 | backward_pass_time = time.time() - start_time 109 | self._times["backward_pass"].append(backward_pass_time) 110 | 111 | self.optimizer.step() 112 | self.optimizer.zero_grad() 113 | 114 | with torch.no_grad(): 115 | epoch_loss += (loss.item() * data.shape[0]) 116 | n_samples += data.shape[0] 117 | 118 | train_acc = n_correct/n_samples 119 | logging.info(f"Train acc: {train_acc}") 120 | self._train_acc.append(train_acc) 121 | 122 | if self._val_dataset is not None and len(self.log["train_loss"]) % 5 == 0: 123 | val_acc = Trainer.get_acc(self.model, self._val_dataset, self.batch_size) 124 | logging.info(f"Val acc: {val_acc}") 125 | self._val_acc.append(val_acc) 126 | 127 | return epoch_loss / n_samples 128 | 129 | @staticmethod 130 | def get_acc(model, dataset, batch_size): 131 | scores = trainer.compute_metric(model, dataset, Trainer.accuracy_metric, batch_size=batch_size) 132 | return np.sum(scores) / len(dataset) 133 | 134 | @staticmethod 135 | def get_spike_count(model, dataset, batch_size): 136 | scores = 
trainer.compute_metric(model, dataset, Trainer.spike_count, batch_size=batch_size) 137 | return np.sum(scores) / len(dataset) 138 | 139 | def on_epoch_complete(self, save): 140 | if save: 141 | self.save_model_log() 142 | 143 | epoch_loss = self.log["train_loss"][-1] 144 | if epoch_loss < self._min_loss: 145 | logging.info(f"Saving model...") 146 | self._min_loss = epoch_loss 147 | self.save_model() 148 | 149 | n_epoch = len(self.log["train_loss"]) 150 | 151 | if n_epoch == self._milestones[self._milestone_idx]: 152 | logging.info(f"Decaying lr...") 153 | self.lr *= self._gamma 154 | # Load best model 155 | self.model = Trainer.load_model(self.root, self.id, self.device, self.dtype) 156 | self.optimizer = self.optimizer_func( 157 | self.model.parameters(), self.lr, **self.optimizer_kwargs 158 | ) 159 | 160 | if self._milestone_idx != len(self._milestones) - 1: 161 | logging.info(f"New milestone target...") 162 | self._milestone_idx += 1 163 | 164 | def on_training_complete(self, save): 165 | pass 166 | 167 | @staticmethod 168 | def hyperparams_loader(hyperparams): 169 | model_params = hyperparams["model"] 170 | del model_params["name"] 171 | del model_params["weight_initializers"] 172 | 173 | return models.AuditoryModel(**model_params) 174 | 175 | @staticmethod 176 | def load_model(root, id, device="cuda", dtype=torch.float): 177 | return trainer.load_model(root, id, Trainer.hyperparams_loader, device, dtype) 178 | 179 | 180 | class EphysTrainer(Trainer): 181 | 182 | def __init__(self, root, model, dataset, n_epochs, batch_size, lr, gamma=0.1, dt=0.1, epoch_scan=5, max_decay=1, val_dataset=None, device="cuda", id=None): 183 | super().__init__(root, model, dataset, n_epochs, batch_size, lr, [-1], gamma, val_dataset, device, id) 184 | self.epoch_scan = epoch_scan 185 | self.van_rossum = VanRossum(datasets.EphysDataset.LENGTH, tau=100, dt=dt).to(device) 186 | self.max_decay = max_decay 187 | 188 | self._decay_count = 0 189 | 190 | def loss(self, spikes_pred, spikes): 
191 | spike_loss = self.van_rossum(spikes_pred, spikes) 192 | 193 | return spike_loss 194 | 195 | def train_for_single_epoch(self): 196 | epoch_loss = 0 197 | n_samples = 0 198 | 199 | for batch_id, (data, target) in enumerate(self.train_data_loader): 200 | data = data.to(self.device).type(self.dtype) 201 | trace = target[0].to(self.device).type(self.dtype) 202 | spikes = target[1].to(self.device).type(self.dtype) 203 | torch.cuda.synchronize() 204 | 205 | # Forward pass 206 | start_time = time.time() 207 | spikes_pred = self.model(data) 208 | torch.cuda.synchronize() 209 | forward_pass_time = time.time() - start_time 210 | self._times["forward_pass"].append(forward_pass_time) 211 | 212 | # Compute loss 213 | loss = self.loss(spikes_pred, spikes) 214 | 215 | # Backward pass 216 | start_time = time.time() 217 | loss.backward() 218 | torch.cuda.synchronize() 219 | backward_pass_time = time.time() - start_time 220 | self._times["backward_pass"].append(backward_pass_time) 221 | 222 | self.optimizer.step() 223 | self.optimizer.zero_grad() 224 | 225 | with torch.no_grad(): 226 | epoch_loss += (loss.item() * data.shape[0]) 227 | n_samples += data.shape[0] 228 | 229 | return epoch_loss / n_samples 230 | 231 | def on_epoch_complete(self, save): 232 | if save: 233 | self.save_model_log() 234 | 235 | epoch_loss = self.log["train_loss"][-1] 236 | if epoch_loss < self._min_loss: 237 | logging.info(f"Saving model...") 238 | self._min_loss = epoch_loss 239 | self.save_model() 240 | 241 | min_lost_over_last_epochs = np.array(self.log["train_loss"][-self.epoch_scan:]).min() 242 | 243 | if min_lost_over_last_epochs > self._min_loss: 244 | if self._decay_count < self.max_decay: 245 | self._min_loss = np.inf 246 | self._last_train_scores = [] 247 | logging.info(f"Decaying lr...") 248 | self._decay_count += 1 249 | self.lr *= self._gamma 250 | # Load best model 251 | self.model = EphysTrainer.load_model(self.root, self.id, self.device, self.dtype) 252 | self.optimizer = 
self.optimizer_func( 253 | self.model.parameters(), self.lr, **self.optimizer_kwargs 254 | ) 255 | else: 256 | self.exit = True 257 | 258 | @staticmethod 259 | def load_model(root, id, device="cuda", dtype=torch.float, dt01ref=False): 260 | 261 | def model_loader(hyperparams): 262 | model_params = hyperparams["model"] 263 | del model_params["name"] 264 | del model_params["weight_initializers"] 265 | 266 | model_params = {**model_params, "dt01ref": dt01ref} 267 | 268 | return models.Neuron(**model_params) 269 | 270 | return trainer.load_model(root, id, model_loader, device, dtype) 271 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from src.snn.block.block import Block 5 | 6 | 7 | @pytest.fixture 8 | def block(): 9 | return Block(2, 4, "fast_sigmoid") 10 | 11 | 12 | def test_beta_kernel(block): 13 | # Use a single beta 14 | assert torch.allclose(block.build_beta_kernel(torch.Tensor([0.1]))[0, 0, 0], torch.Tensor([0.0010, 0.0100, 0.1000, 1.0000])) 15 | assert torch.allclose(block.build_beta_kernel(torch.Tensor([0.1]))[1, 0, 0], torch.Tensor([0.0010, 0.0100, 0.1000, 1.0000])) 16 | 17 | # Use multiple beta 18 | assert torch.allclose(block.build_beta_kernel(torch.Tensor([0.1]))[0, 0, 0], torch.Tensor([0.0010, 0.0100, 0.1000, 1.0000])) 19 | assert torch.allclose(block.build_beta_kernel(torch.Tensor([0.5]))[1, 0, 0], torch.Tensor([0.1250, 0.2500, 0.5000, 1.0000])) 20 | 21 | 22 | def test_phi_kernel(block): 23 | assert torch.allclose(block._phi_kernel, torch.Tensor([[[[4., 3., 2., 1.]]]])) 24 | 25 | 26 | def 
test_g(block): 27 | phi_spikes = torch.zeros(2, 4) 28 | phi_spikes[0, 0] = 1 29 | phi_spikes[0, 2] = 2 30 | phi_spikes[1, 1] = 1 31 | phi_spikes[1, 3] = 3 32 | assert block.g(phi_spikes).sum() == 2 33 | 34 | 35 | def test_differentiable_vars(block): 36 | assert not block._beta_ident_base.requires_grad 37 | assert not block._beta_exp.requires_grad 38 | assert not block._phi_kernel.requires_grad -------------------------------------------------------------------------------- /tests/test_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | 4 | from src.snn.snn import SNN 5 | from src.snn.block.blocks import Blocks 6 | 7 | 8 | @pytest.fixture 9 | def in_spikes(b=4, n=100, t=200): 10 | return torch.rand(b, n, t) 11 | 12 | 13 | def get_models(n_in=100, n_out=100, rf_len=10, t_len=200, t_latency=0, recurrent=True, adapt=False): 14 | blocks = Blocks(n_in, n_out, rf_len, t_len, t_latency, recurrent=recurrent, adapt=adapt) 15 | snn = SNN(n_in, n_out, rf_len, t_len, t_latency, recurrent=recurrent, adapt=adapt) 16 | 17 | blocks._rf_weight = snn._rf_weight 18 | blocks._rf_bias = snn._rf_bias 19 | blocks._rec_weight = snn._rec_weight 20 | 21 | return blocks, snn 22 | 23 | 24 | def test_networks_none(in_spikes): 25 | for t_latency in [0, 1, 2, 4, 8]: 26 | blocks, snn = get_models(t_latency=t_latency, recurrent=False) 27 | spikes1 = blocks(in_spikes, mode="train") 28 | spikes2 = snn(in_spikes, mode="train") 29 | 30 | assert torch.allclose(spikes1, spikes2) 31 | 32 | 33 | def test_networks_recurrent(in_spikes): 34 | for t_latency in [0, 1, 2, 4, 8]: 35 | blocks, snn = get_models(t_latency=t_latency, recurrent=True) 36 | spikes1 = blocks(in_spikes, mode="train") 37 | spikes2 = snn(in_spikes, mode="train") 38 | 39 | assert torch.allclose(spikes1, spikes2) 40 | 41 | 42 | def test_adaption(in_spikes): 43 | for t_latency in [0, 1, 2, 4, 8]: 44 | blocks, snn = get_models(t_latency=t_latency, recurrent=True, 
adapt=True) 45 | spikes1, mem1, x1, z1, v_th1 = blocks(in_spikes, mode="val") 46 | spikes2, mem2, x2, v_th2 = snn(in_spikes, mode="val") 47 | 48 | # v_th should equal at the start of every block 49 | for i in range(blocks._t_len // blocks._t_len_block): 50 | assert torch.allclose(v_th1[:, :, i*blocks._t_len_block], v_th2[:, :, i*blocks._t_len_block]) 51 | --------------------------------------------------------------------------------