├── .gitignore
├── LICENSE
├── data
│   ├── N-MNIST
│   │   └── .gitignore
│   ├── SHD
│   │   └── .gitignore
│   └── ephys
│       └── .gitignore
├── docs
│   └── README.md
├── environment.yml
├── figures
│   ├── figure2.png
│   ├── figure4.png
│   └── figure5.png
├── notebooks
│   ├── ephys.ipynb
│   ├── getting_started.ipynb
│   ├── library_comparison.ipynb
│   ├── speed_benchmarks.ipynb
│   └── supervised_benchmarks.ipynb
├── scripts
│   ├── ephys
│   │   ├── build_data.py
│   │   ├── run.py
│   │   └── train.py
│   ├── run_benchmarks.py
│   └── supervised
│       ├── run_blocks_nmnist.py
│       ├── run_blocks_shd.py
│       ├── run_detach_spikes.py
│       ├── run_standard_nmnist.py
│       ├── run_standard_shd.py
│       └── train.py
├── setup.py
├── src
│   ├── __init__.py
│   ├── benchmark.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── ephys.py
│   │   ├── neuromorphic.py
│   │   ├── synthetic.py
│   │   └── transforms.py
│   ├── metric.py
│   ├── models.py
│   ├── query.py
│   ├── snn
│   │   ├── __init__.py
│   │   ├── block
│   │   │   ├── __init__.py
│   │   │   ├── block.py
│   │   │   ├── blocks.py
│   │   │   └── util.py
│   │   ├── snn.py
│   │   └── surrogate.py
│   └── train.py
└── tests
    ├── __init__.py
    ├── test_block.py
    └── test_blocks.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # OSX
2 | .DS_Store
3 |
4 | # Python
5 | __pycache__/
6 | *.pyc
7 | src.egg-info
8 |
9 | # PyCharm
10 | .idea
11 |
12 | # Jupyter Notebook
13 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, Luke Taylor
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/data/N-MNIST/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/data/SHD/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/data/ephys/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Addressing the speed-accuracy simulation trade-off for adaptive spiking neurons
2 |
3 | A new model for quickly training and simulating adaptive leaky integrate-and-fire spiking neural networks.
4 |
5 |
6 |
7 |
8 |
9 | ## Installing dependencies
10 |
11 | Install all required dependencies and activate the blocks environment using conda.
12 | ```
13 | conda env create -f environment.yml
14 | conda activate blocks
15 | ```
16 |
17 | ## Getting started tutorial
18 |
19 | See the [notebooks/getting_started.ipynb](../notebooks/getting_started.ipynb) notebook for getting started with our model.
20 |
21 | ## Reproducing paper results
22 |
23 | All the paper results can be reproduced using the scripts available in the `scripts` folder.
24 |
25 | ### Running benchmark experiments
26 |
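27 | The `scripts/run_benchmarks.py` script benchmarks the forward and backward pass times of the blocks and standard SNN models for different simulation lengths, absolute refractory periods, batch sizes, hidden-layer sizes and numbers of layers. Set the `root` variable in the script to the project folder before running it.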
28 |
29 | ### Training models
30 |
31 | Ensure that the computer has a CUDA-capable GPU with CUDA 11.7 installed.
32 |
33 | #### 1. Downloading and processing datasets
34 |
35 | #### Machine learning datasets:
36 | The Neuromorphic-MNIST dataset can be [downloaded](https://www.garrickorchard.com/datasets/n-mnist) and unzipped into the `data/N-MNIST` directory. Then run `python convert_nmnist2h5.py` (adapted from Perez-Nieves et al., 2021) to process the raw dataset. The Spiking Heidelberg Digits (SHD) dataset can be [downloaded](https://compneuro.net/posts/2019-spiking-heidelberg-digits/) and unzipped into the `data/SHD` directory.
37 |
38 | #### E-phys dataset:
39 |
40 | Running the `scripts/ephys/build_data.py` script will download and process the necessary data from the Allen Institute.
41 |
42 | #### 2. Train model
43 |
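44 | You can train the blocks and standard SNN models with the `train.py` scripts in `scripts/ephys` (neural fitting on the Allen Institute e-phys data) and `scripts/supervised` (N-MNIST and SHD). Each folder also contains the run scripts for the individual experiments; set the `root` variable at the top of each run script to the project folder before launching them.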
45 |
46 | ## Building result figures
47 |
48 | Speedup plots can be built using: `notebooks/speed_benchmarks.ipynb`
49 |
50 | Machine learning benchmark plots can be built using: `notebooks/supervised_benchmarks.ipynb`
51 |
52 | Neural-fitting plots can be built using: `notebooks/ephys.ipynb`
53 |
54 | ### Machine learning benchmark results
55 |
56 |
57 |
58 |
59 | ### Neural-fitting results
60 |
61 |
62 |
--------------------------------------------------------------------------------
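The README expects the raw datasets in `data/N-MNIST`, `data/SHD` and `data/ephys` under the project folder, each of which ships with only a placeholder `.gitignore`. A minimal standard-library sketch for checking that these directories actually contain data before launching training; the helper name `missing_datasets` is hypothetical and not part of the repository.

```python
from pathlib import Path

# Dataset directories listed in the repository tree (each ships with only a .gitignore)
DATASET_DIRS = ("data/N-MNIST", "data/SHD", "data/ephys")


def missing_datasets(root="."):
    """Return the dataset directories under `root` that are absent or hold only .gitignore."""
    missing = []
    for rel in DATASET_DIRS:
        path = Path(root) / rel
        if path.is_dir():
            # Ignore the placeholder .gitignore that keeps the folder under version control
            contents = [p for p in path.iterdir() if p.name != ".gitignore"]
        else:
            contents = []
        if not contents:
            missing.append(rel)
    return missing
```

For example, calling `missing_datasets(".")` from the project folder lists every dataset directory that still needs to be downloaded or built.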
/environment.yml:
--------------------------------------------------------------------------------
1 | name: blocks
2 | dependencies:
3 | - python=3.8
4 | - pytorch::pytorch
5 | - pytorch::torchvision
6 | - nvidia::cudatoolkit=11.7
7 | - matplotlib
8 | - seaborn
9 | - pandas
10 | - conda-forge::h5py
11 | - nb_conda_kernels
12 | - ipywidgets
13 | - pytest
14 | - pip:
15 | - brainbox==0.0.6
16 | - jupyter
17 | - allensdk
18 | - --editable .
19 |
--------------------------------------------------------------------------------
/figures/figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/figures/figure2.png
--------------------------------------------------------------------------------
/figures/figure4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/figures/figure4.png
--------------------------------------------------------------------------------
/figures/figure5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/figures/figure5.png
--------------------------------------------------------------------------------
/notebooks/getting_started.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 89,
6 | "id": "da7df381",
7 | "metadata": {},
8 | "outputs": [
9 | {
10 | "name": "stdout",
11 | "output_type": "stream",
12 | "text": [
13 | "The autoreload extension is already loaded. To reload it, use:\n",
14 | " %reload_ext autoreload\n"
15 | ]
16 | }
17 | ],
18 | "source": [
19 | "import time\n",
20 | "\n",
21 | "import torch\n",
22 | "import numpy as np\n",
23 | "import matplotlib.pyplot as plt\n",
24 | "import seaborn as sns\n",
25 | "import pandas as pd\n",
26 | "\n",
27 | "from src.datasets import SyntheticSpikes\n",
28 | "from src.snn.snn import SNN\n",
29 | "from src.snn.block.blocks import Blocks\n",
30 | "\n",
31 | "%load_ext autoreload\n",
32 | "%autoreload 2"
33 | ]
34 | },
35 | {
36 | "cell_type": "markdown",
37 | "id": "8702a126",
38 | "metadata": {},
39 | "source": [
40 | "## Model equivalence"
41 | ]
42 | },
43 | {
44 | "cell_type": "markdown",
45 | "id": "ca0f57e9",
46 | "metadata": {},
47 | "source": [
48 | "Let's check whether the blocks model and the standard model produce the same output raster."
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 86,
54 | "id": "286f166a",
55 | "metadata": {},
56 | "outputs": [],
57 | "source": [
58 | "n_in = 100\n",
59 | "n_out = 20\n",
60 | "rf_len = 20\n",
61 | "t_len = 1000 \n",
62 | "block_len = 40\n",
63 | "\n",
64 | "# Instantiate recurrently connected ALIF SNNs (blocks and standard version)\n",
65 | "torch.manual_seed(42)\n",
66 | "input_raster = SyntheticSpikes(t_len, n_in, min_r=0, max_r=100, n_samples=1)[0]\n",
67 | "standard_snn = SNN(n_in, n_out, rf_len, t_len, t_latency=block_len, recurrent=True, init_beta=0.99, init_p=0.99)\n",
68 | "blocks_snn = Blocks(n_in, n_out, rf_len, t_len, t_latency=block_len, recurrent=True, init_beta=0.99, init_p=0.99)\n",
69 | "\n",
70 | "# Ensure blocks model has the same weights as the standard model\n",
71 | "blocks_snn._rf_weight = standard_snn._rf_weight\n",
72 | "blocks_snn._rf_bias = standard_snn._rf_bias\n",
73 | "blocks_snn._rec_weight = standard_snn._rec_weight\n",
74 | "\n",
75 | "# Obtain spikes from blocks and standard model\n",
76 | "with torch.no_grad():\n",
77 | " blocks_spikes = blocks_snn(input_raster.unsqueeze(0), mode=\"train\")\n",
78 | " standard_spikes = standard_snn(input_raster.unsqueeze(0), mode=\"train\")"
79 | ]
80 | },
81 | {
82 | "cell_type": "code",
83 | "execution_count": 62,
84 | "id": "3c475ec8",
85 | "metadata": {},
86 | "outputs": [
87 | {
88 | "name": "stdout",
89 | "output_type": "stream",
90 | "text": [
91 | "Are model outputs the same? True\n"
92 | ]
93 | }
94 | ],
95 | "source": [
96 | "# Drum roll, the moment we have all been waiting for...\n",
97 | "print(f\"Are model outputs the same? {torch.allclose(blocks_spikes, standard_spikes)}\")"
98 | ]
99 | },
100 | {
101 | "cell_type": "code",
102 | "execution_count": 88,
103 | "id": "e476bd42",
104 | "metadata": {
105 | "scrolled": true
106 | },
107 | "outputs": [
108 | {
109 | "data": {
110 | "image/png": "",
111 | "text/plain": [
112 | ""
113 | ]
114 | },
115 | "metadata": {},
116 | "output_type": "display_data"
117 | }
118 | ],
119 | "source": [
120 | "# Convert a (neuron x time) spike tensor into (time, neuron) scatter-plot coordinates\n",
121 | "def spike_tensor_to_points(spike_tensor):\n",
122 | " x = np.array([p[1].item() for p in torch.nonzero(spike_tensor.cpu())])\n",
123 | " y = np.array([p[0].item() for p in torch.nonzero(spike_tensor.cpu())])\n",
124 | " \n",
125 | " return x, y\n",
126 | "\n",
127 | "fig, axs = plt.subplots(1, 3, figsize=(6, 3))\n",
128 | "axs[0].scatter(*spike_tensor_to_points(input_raster), color=\"black\", s=0.5)\n",
129 | "axs[1].scatter(*spike_tensor_to_points(blocks_spikes[0]), color=\"black\", s=0.5)\n",
130 | "axs[2].scatter(*spike_tensor_to_points(standard_spikes[0]), color=\"black\", s=0.5)\n",
131 | "axs[0].set_title(\"Input raster\")\n",
132 | "axs[1].set_title(\"Blocks output\")\n",
133 | "axs[2].set_title(\"Standard output\")\n",
134 | "axs[0].set_ylabel(\"NeuronID\")\n",
135 | "axs[0].set_xlabel(\"Sim step\")\n",
136 | "axs[1].set_xlabel(\"Sim step\")\n",
137 | "axs[2].set_xlabel(\"Sim step\")\n",
138 | "fig.tight_layout()"
139 | ]
140 | },
141 | {
142 | "cell_type": "markdown",
143 | "id": "7def7ba2",
144 | "metadata": {},
145 | "source": [
146 | "## Model speed"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "id": "24c34073",
152 | "metadata": {},
153 | "source": [
154 | "How much faster is our blocks model?"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": 108,
160 | "id": "fb472914",
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "def get_training_duration(model, device=\"cpu\"):\n",
165 | " model = model.to(device)\n",
166 | " data = input_raster.unsqueeze(0).to(device).repeat(128, 1, 1) # 128 batch size\n",
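167 | " if device == \"cuda\": torch.cuda.synchronize()  # Wait for pending GPU work before starting the clock\n",
168 | " start_time = time.time()\n",
169 | " output = model(data)\n",
170 | " loss = output.sum() # Arbitrary loss just so we have something to backpropagate\n",
171 | " loss.backward()\n",
172 | " if device == \"cuda\": torch.cuda.synchronize()  # Wait for the backward pass to finish on the GPU\n",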
173 | " training_duration = time.time() - start_time\n",
174 | " return training_duration"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 103,
180 | "id": "91db4382",
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "def get_training_durations_for_both_models(device):\n",
185 | " standard_training_times = [get_training_duration(standard_snn, device) for _ in range(11)][1:]\n",
186 | " block_training_times = [get_training_duration(blocks_snn, device) for _ in range(11)][1:]\n",
187 | " return pd.DataFrame({\"standard\": standard_training_times, \"block\": block_training_times})"
188 | ]
189 | },
190 | {
191 | "cell_type": "code",
192 | "execution_count": 104,
193 | "id": "b45b4bdf",
194 | "metadata": {},
195 | "outputs": [],
196 | "source": [
197 | "torch.backends.cudnn.benchmark = True # Make sure we use the best conv algorithm"
198 | ]
199 | },
200 | {
201 | "cell_type": "code",
202 | "execution_count": 107,
203 | "id": "204b6a24",
204 | "metadata": {},
205 | "outputs": [
206 | {
207 | "data": {
208 | "image/png": "",
209 | "text/plain": [
210 | ""
211 | ]
212 | },
213 | "metadata": {},
214 | "output_type": "display_data"
215 | }
216 | ],
217 | "source": [
218 | "fig, axs = plt.subplots(1, 2, figsize=(8, 3), sharey=False)\n",
219 | "\n",
220 | "def plot_training_durations(ax, device):\n",
221 | " sns.barplot(get_training_durations_for_both_models(device), ax=ax)\n",
222 | " ax.set(ylabel=\"Training time (sec)\", title=device)\n",
223 | " sns.despine()\n",
224 | " \n",
225 | "plot_training_durations(axs[0], \"cpu\")\n",
226 | "plot_training_durations(axs[1], \"cuda\")"
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "id": "f9e54af9",
232 | "metadata": {},
233 | "source": [
234 | "Voila! A considerable reduction in training time on both CPUs and GPUs!"
235 | ]
236 | }
237 | ],
238 | "metadata": {
239 | "kernelspec": {
240 | "display_name": "Python [conda env:blocks] *",
241 | "language": "python",
242 | "name": "conda-env-blocks-py"
243 | },
244 | "language_info": {
245 | "codemirror_mode": {
246 | "name": "ipython",
247 | "version": 3
248 | },
249 | "file_extension": ".py",
250 | "mimetype": "text/x-python",
251 | "name": "python",
252 | "nbconvert_exporter": "python",
253 | "pygments_lexer": "ipython3",
254 | "version": "3.8.16"
255 | }
256 | },
257 | "nbformat": 4,
258 | "nbformat_minor": 5
259 | }
260 |
--------------------------------------------------------------------------------
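`get_training_duration` above brackets each measurement with `torch.cuda.synchronize()` and `get_training_durations_for_both_models` discards the first of 11 runs as warm-up. The same pattern can be factored into a small, framework-agnostic helper. This is a sketch, not code from the repository: `time_runs`, `summarize` and the `sync` hook are hypothetical names. On a GPU one would pass `sync=torch.cuda.synchronize` so queued kernels finish before the clock is read; on the CPU `sync` can stay `None`.

```python
import statistics
import time


def time_runs(fn, n_runs=11, n_warmup=1, sync=None):
    """Time `fn()` n_runs times and drop the first n_warmup measurements.

    `sync` is an optional zero-argument callable (e.g. torch.cuda.synchronize)
    invoked before starting and before stopping the clock, so asynchronous
    device work is included in each measurement.
    """
    durations = []
    for _ in range(n_runs):
        if sync is not None:
            sync()
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        if sync is not None:
            sync()
        durations.append(time.perf_counter() - start)
    return durations[n_warmup:]


def summarize(durations):
    """Median and sample standard deviation of the retained timings (seconds)."""
    return {"median": statistics.median(durations), "stdev": statistics.stdev(durations)}
```

For instance, `time_runs(lambda: model(data).sum().backward(), sync=torch.cuda.synchronize)` would mirror the notebook's GPU measurement for a given `model` and `data`.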
/notebooks/library_comparison.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 53,
6 | "id": "bef22c5e",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import time\n",
11 | "\n",
12 | "import torch\n",
13 | "import torch.nn as nn\n",
14 | "from spikingjelly.activation_based import neuron, layer, surrogate\n",
15 | "from norse.torch.module.lif import LIFCell\n",
16 | "\n",
17 | "from src.snn.block.blocks import Blocks\n",
18 | "from src.snn.snn import SNN"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "id": "f8051e64",
24 | "metadata": {},
25 | "source": [
26 | "## Setting up the different implementations"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "id": "4ca0b48b",
32 | "metadata": {},
33 | "source": [
34 | "Network benchmarked: 200 input units -> 100 spiking units over 1000 simulation steps using a batch size of 128."
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 93,
40 | "id": "38193107",
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "def time_jelly():\n",
45 | " input_tensor = torch.rand(128, 200, 1000).cuda()\n",
46 | " \n",
47 | " jelly_layer = nn.Sequential(\n",
48 | " layer.Linear(200, 100, bias=False),\n",
49 | " neuron.LIFNode(tau=100.0, surrogate_function=surrogate.ATan())\n",
50 | " ).cuda()\n",
51 | " \n",
52 | " start_time = time.time()\n",
53 | " \n",
54 | " for t in range(1000):\n",
55 | " out = jelly_layer(input_tensor[:, :, t])\n",
56 | " \n",
57 | " end_time = time.time()\n",
58 | " return end_time - start_time"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 94,
64 | "id": "976ed321",
65 | "metadata": {},
66 | "outputs": [],
67 | "source": [
68 | "def time_norse():\n",
69 | " input_tensor = torch.rand(128, 200, 1000).cuda()\n",
70 | " \n",
71 | " norse_layer = nn.Sequential(\n",
72 | " layer.Linear(200, 100, bias=False),\n",
73 | " LIFCell()\n",
74 | " ).cuda()\n",
75 | " \n",
76 | " start_time = time.time()\n",
77 | " \n",
78 | " for t in range(1000):\n",
79 | " out = norse_layer(input_tensor[:, :, t])\n",
80 | " \n",
81 | " end_time = time.time()\n",
82 | " \n",
83 | " return end_time - start_time"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 95,
89 | "id": "932cb5f0",
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "def time_blocks():\n",
94 | " input_tensor = torch.rand(128, 200, 1000).cuda()\n",
95 | " \n",
96 | " blocks_snn = Blocks(200, 100, 1, 1000, t_latency=50, recurrent=False, init_beta=0.99, init_p=0.99).cuda()\n",
97 | " start_time = time.time()\n",
98 | " out = blocks_snn(input_tensor)\n",
99 | " end_time = time.time()\n",
100 | " \n",
101 | " return end_time - start_time\n",
102 | " \n",
103 | "def time_standard():\n",
104 | " input_tensor = torch.rand(128, 200, 1000).cuda()\n",
105 | " \n",
106 | " standard_snn = SNN(200, 100, 1, 1000, t_latency=1, recurrent=False, init_beta=0.99, init_p=0.99).cuda()\n",
107 | " start_time = time.time()\n",
108 | " out = standard_snn(input_tensor)\n",
109 | " end_time = time.time()\n",
110 | " \n",
111 | " return end_time - start_time"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "id": "733c9e80",
117 | "metadata": {},
118 | "source": [
119 | "## Benchmarking the different implementations"
120 | ]
121 | },
122 | {
123 | "cell_type": "code",
124 | "execution_count": 97,
125 | "id": "0ecf6b39",
126 | "metadata": {},
127 | "outputs": [
128 | {
129 | "name": "stdout",
130 | "output_type": "stream",
131 | "text": [
132 | "Norse=0.3080105781555176\n",
133 | "Jelly=0.1869184970855713\n",
134 | "Standard=0.3317074775695801\n",
135 | "Blocks=0.016329288482666016\n"
136 | ]
137 | }
138 | ],
139 | "source": [
140 | "print(f\"Norse={time_norse()}\")\n",
141 | "print(f\"Jelly={time_jelly()}\")\n",
142 | "print(f\"Standard={time_standard()}\")\n",
143 | "print(f\"Blocks={time_blocks()}\")"
144 | ]
145 | }
146 | ],
147 | "metadata": {
148 | "kernelspec": {
149 | "display_name": "Python [conda env:blocks] *",
150 | "language": "python",
151 | "name": "conda-env-blocks-py"
152 | },
153 | "language_info": {
154 | "codemirror_mode": {
155 | "name": "ipython",
156 | "version": 3
157 | },
158 | "file_extension": ".py",
159 | "mimetype": "text/x-python",
160 | "name": "python",
161 | "nbconvert_exporter": "python",
162 | "pygments_lexer": "ipython3",
163 | "version": "3.8.16"
164 | }
165 | },
166 | "nbformat": 4,
167 | "nbformat_minor": 5
168 | }
169 |
--------------------------------------------------------------------------------
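The printed timings can be turned into speedups relative to the blocks model with a few lines of arithmetic. This sketch reuses the numbers from the notebook's cell output. Note that the timing functions in this notebook read `time.time()` without calling `torch.cuda.synchronize()`. Because CUDA kernels run asynchronously, the measured interval may end before the GPU work completes, so these ratios should be treated as indicative rather than exact.

```python
# Timings (seconds) printed by the benchmarking cell of library_comparison.ipynb
timings = {
    "Norse": 0.3080105781555176,
    "Jelly": 0.1869184970855713,
    "Standard": 0.3317074775695801,
    "Blocks": 0.016329288482666016,
}


def speedups(timings, reference="Blocks"):
    """Return how many times longer each implementation took than `reference`."""
    ref = timings[reference]
    return {name: t / ref for name, t in timings.items() if name != reference}


for name, ratio in speedups(timings).items():
    print(f"Blocks vs {name}: {ratio:.1f}x faster")
```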
/scripts/ephys/build_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from src import datasets
4 |
5 |
6 | def main():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--root", type=str, default=".")
9 | args = parser.parse_args()
10 |
11 | # Builds current and spike tensors with DT=0.1ms
12 | # Note: You might want to set manifest_file if you have already downloaded the data from the Allen Institute
13 | builder = datasets.NoiseBuilder()
14 | builder.build(f"{args.root}/data/ephys/train", noise_type="noise1")
15 | builder.build(f"{args.root}/data/ephys/test", noise_type="noise2")
16 |
17 | # Builds current tensors with DT=0.05ms
18 | builder = datasets.NoiseBuilder()
19 | builder.build(f"{args.root}/data/ephys/train", noise_type="noise1", target_sampling_rate=20000)
20 | builder.build(f"{args.root}/data/ephys/test", noise_type="noise2", target_sampling_rate=20000)
21 |
22 |
23 | if __name__ == "__main__":
24 | main()
25 |
--------------------------------------------------------------------------------
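The comments in `build_data.py` and the `dt` logic in `scripts/ephys/train.py` rely on a simple relation between sampling rate and simulation step. A 20 kHz recording (`target_sampling_rate=20000`) gives dt = 1000 / 20000 = 0.05 ms, and a step of 0.1 ms corresponds to 10 kHz. `train.py` then uses dt = 0.1 ms × `downsample`, with `downsample=0.5` reserved for the 20 kHz data. The following sketch shows that mapping for the values swept in `scripts/ephys/run.py`; the function names are hypothetical.

```python
def dt_from_sampling_rate(sampling_rate_hz):
    """Simulation step in milliseconds for a sampling rate in Hz."""
    return 1000.0 / sampling_rate_hz


def dt_from_downsample(downsample, base_dt_ms=0.1):
    """Step used by scripts/ephys/train.py: dt = 0.1 ms * downsample."""
    return base_dt_ms * downsample


# Downsampling factors swept in scripts/ephys/run.py
for downsample in [0.5, 1, 5, 10, 20, 40]:
    print(f"downsample={downsample:<4} -> dt={dt_from_downsample(downsample):.2f} ms")
```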
/scripts/ephys/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | root = "" # TODO: Change this to the project folder
3 |
4 | from src import datasets
5 |
6 |
7 | def launch(method, abs_refac_ms, downsample):
8 | neuron_idx_query = datasets.ValidNeuronQuery(f"{root}/data/ephys")
9 |
10 | dir_name = f"{method}_{abs_refac_ms}_{downsample}"
11 | os.makedirs(f"{root}/results/ephys/{dir_name}", exist_ok=True)
12 |
13 | for neuron_idx in neuron_idx_query.idx:
14 | os.system(f"python {root}/scripts/ephys/train.py --root={root} --method={method} --abs_refac_ms={abs_refac_ms} --downsample={downsample} --neuron_idx={neuron_idx} --dir_name={dir_name} --id={neuron_idx}")
15 |
16 |
17 | # Blocks: Different dt
18 | for downsample in [0.5, 1, 5, 10, 20, 40]:
19 | launch("blocks", abs_refac_ms=2, downsample=downsample)
20 |
21 | # Blocks: Different ARP
22 | for abs_refac_ms in [1, 4, 6, 8, 16]:
23 | launch("blocks", abs_refac_ms=abs_refac_ms, downsample=1)
24 |
25 | # Standard
26 | launch("standard", abs_refac_ms=2, downsample=1)
27 |
--------------------------------------------------------------------------------
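`run.py` launches one `train.py` process per neuron with `os.system` and an interpolated command string, which ignores failures and depends on shell quoting. Below is a sketch of the same loop using `subprocess.run` with an argument list, where `check=True` stops at the first failing fit. `build_train_cmd` and `launch_all` are hypothetical helpers that pass exactly the flags `run.py` already uses. Using `sys.executable` instead of `python` is an assumption; it makes the child process run under the current interpreter.

```python
import os
import subprocess
import sys


def build_train_cmd(root, method, abs_refac_ms, downsample, neuron_idx, dir_name):
    """Argument list equivalent to the command string built in scripts/ephys/run.py."""
    return [
        sys.executable, f"{root}/scripts/ephys/train.py",
        f"--root={root}",
        f"--method={method}",
        f"--abs_refac_ms={abs_refac_ms}",
        f"--downsample={downsample}",
        f"--neuron_idx={neuron_idx}",
        f"--dir_name={dir_name}",
        f"--id={neuron_idx}",
    ]


def launch_all(root, method, abs_refac_ms, downsample, neuron_indices):
    dir_name = f"{method}_{abs_refac_ms}_{downsample}"
    os.makedirs(f"{root}/results/ephys/{dir_name}", exist_ok=True)
    for neuron_idx in neuron_indices:
        cmd = build_train_cmd(root, method, abs_refac_ms, downsample, neuron_idx, dir_name)
        subprocess.run(cmd, check=True)  # raises CalledProcessError if a fit fails
```

In `run.py`, `neuron_indices` would be `datasets.ValidNeuronQuery(f"{root}/data/ephys").idx`.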
/scripts/ephys/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import ast
3 | import logging
4 | import argparse
5 |
6 | from src import datasets, models, train
7 |
8 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
9 |
10 |
11 | def eval(v):
12 | return ast.literal_eval(v)
13 |
14 |
15 | def get_dataset(args):
16 | if args.downsample == 0.5:
17 | # Used dt=0.05ms for neural fits
18 | return datasets.EphysDataset(f"{args.root}/data/ephys", "train", args.neuron_idx, target_sampling_rate=20000)
19 | else:
20 | # Used dt>=0.1ms for neural fits (used by most experiments)
21 | return datasets.EphysDataset(f"{args.root}/data/ephys", "train", args.neuron_idx)
22 |
23 |
24 | def get_model(args):
25 | if args.downsample == 0.5:
26 | # Used dt=0.05ms for neural fits
27 | return models.Neuron(method=args.method, abs_refac_ms=args.abs_refac_ms, downsample=1, dt01ref=True)
28 | else:
29 | # Used dt>=0.1ms for neural fits (used by most experiments)
30 | return models.Neuron(method=args.method, abs_refac_ms=args.abs_refac_ms, downsample=int(args.downsample))
31 |
32 |
33 | def get_trainer(args, model, train_dataset):
34 | n_epochs = 200
35 | batch_size = 5
36 | lr = 0.0001
37 | dt = 0.1 * args.downsample
38 | return train.EphysTrainer(f"{args.root}/results/ephys/{args.dir_name}", model, train_dataset, n_epochs, batch_size, lr, gamma=0.1, dt=dt, epoch_scan=5, max_decay=0, val_dataset=None, device="cuda", id=args.id)
39 |
40 |
41 | def main():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--root", type=str, default=".")
44 |
45 | # Model
46 | parser.add_argument("--method", type=str, default="standard")
47 | parser.add_argument("--abs_refac_ms", type=int, default=10)
48 | parser.add_argument("--downsample", type=float, default=1)
49 |
50 | # Dataset
51 | parser.add_argument("--neuron_idx", type=int, required=True)
52 |
53 | # Trainer
54 | parser.add_argument("--dir_name", type=str, default="")
55 | parser.add_argument("--id", type=str, default="")
56 |
57 | args = parser.parse_args()
58 |
59 | train_dataset = get_dataset(args)
60 | model = get_model(args)
61 | model_trainer = get_trainer(args, model, train_dataset)
62 | model_trainer.train(save=True)
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/scripts/run_benchmarks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.backends.cudnn.benchmark = True # Important in order to use best conv algorithm
3 |
4 | from src.benchmark import Benchmarker
5 |
6 |
7 | def run_different_sim_lengths(root):
8 | n_in = 1000
9 | n_hidden = 128
10 |
11 | for abs_refac in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
12 | for t_len in [2**9, 2**10, 2**11]:
13 | for batch_size in [32, 64, 128]:
14 | for method in ["standard", "blocks"]:
15 | bencher = Benchmarker(method, t_len, abs_refac, n_in, n_hidden, n_layers=1, batch_size=batch_size)
16 | bencher.benchmark()
17 | bencher.save(root)
18 |
19 |
20 | def run_different_layers(root):
21 | n_in = 1000
22 |
23 | for abs_refac in [40]:
24 | for t_len in [2**10]:
25 | for n_hidden in [128, 256, 512]:
26 | for batch_size in [64]:
27 | for n_layers in [2, 3, 4, 5]:
28 | for method in ["standard", "blocks"]:
29 | bencher = Benchmarker(method, t_len, abs_refac, n_in, n_hidden, n_layers=n_layers, batch_size=batch_size)
30 | bencher.benchmark()
31 | bencher.save(root)
32 |
33 |
34 | if __name__ == "__main__":
35 | root = "" # TODO: Change this to the project folder
36 | run_different_sim_lengths(f"{root}/benchmarks/sim_lengths")
37 | run_different_layers(f"{root}/benchmarks/layers")
38 |
39 |
--------------------------------------------------------------------------------
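The two sweeps in `run_benchmarks.py` are plain Cartesian products, so the number of `Benchmarker` runs follows directly. `run_different_sim_lengths` performs 10 refractory periods × 3 simulation lengths × 3 batch sizes × 2 methods = 180 runs, and `run_different_layers` performs 3 hidden sizes × 4 layer counts × 2 methods = 24 runs. The sketch below enumerates the same configurations, in the same loop order, with `itertools.product`. This can help when checking which results have already been saved. It does not call `Benchmarker` itself, and the function names are hypothetical.

```python
from itertools import product

METHODS = ["standard", "blocks"]


def sim_length_configs():
    """Configurations swept by run_different_sim_lengths (n_in=1000, n_hidden=128, n_layers=1)."""
    abs_refacs = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
    t_lens = [2**9, 2**10, 2**11]
    batch_sizes = [32, 64, 128]
    return [
        dict(method=m, t_len=t, abs_refac=a, n_hidden=128, n_layers=1, batch_size=b)
        for a, t, b, m in product(abs_refacs, t_lens, batch_sizes, METHODS)
    ]


def layer_configs():
    """Configurations swept by run_different_layers (abs_refac=40, t_len=1024, batch_size=64)."""
    return [
        dict(method=m, t_len=2**10, abs_refac=40, n_hidden=h, n_layers=n, batch_size=64)
        for h, n, m in product([128, 256, 512], [2, 3, 4, 5], METHODS)
    ]


print(len(sim_length_configs()), len(layer_configs()))
```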
/scripts/supervised/run_blocks_nmnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | root = "" # TODO: Change this to the project folder
3 |
4 |
5 | def launch(abs_refac, surr_grad, dt, method):
6 | for i in range(3):
7 | id = f"nmnist_{method}_{surr_grad}_{abs_refac}_{dt}_{i}"
8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method={method} --abs_refac={abs_refac} --surr_grad={surr_grad} --name=nmnist --dt={dt} --id={id}")
9 |
10 |
11 | for abs_refac in [10, 20, 30, 40, 50]:
12 | launch(abs_refac, "mg", 1, "blocks")
13 |
--------------------------------------------------------------------------------
/scripts/supervised/run_blocks_shd.py:
--------------------------------------------------------------------------------
1 | import os
2 | root = "" # TODO: Change this to the project folder
3 |
4 |
5 | def launch(abs_refac, surr_grad, dt, method):
6 | for i in range(3):
7 | id = f"shd_{method}_{surr_grad}_{abs_refac}_{dt}_{i}"
8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method={method} --abs_refac={abs_refac} --surr_grad={surr_grad} --name=shd --dt={dt} --id={id}")
9 |
10 |
11 | for abs_refac in [10, 20, 30, 40, 50]:
12 | launch(abs_refac, "mg", 1, "blocks")
13 |
--------------------------------------------------------------------------------
/scripts/supervised/run_detach_spikes.py:
--------------------------------------------------------------------------------
1 | import os
2 | root = "" # TODO: Change this to the project folder
3 |
4 |
5 | def launch(abs_refac, surr_grad, dt, detach_spike_grad):
6 | for i in range(3):
7 | id = f"shd_blocks_{surr_grad}_{abs_refac}_{dt}_{detach_spike_grad}_{i}"
8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method=blocks --abs_refac={abs_refac} --surr_grad={surr_grad} --name=shd --dt={dt} --id={id} --detach_spike_grad={detach_spike_grad}")
9 |
10 |
11 | launch(30, "mg", dt=2, detach_spike_grad=False)
12 | launch(30, "fast_sigmoid", dt=2, detach_spike_grad=False)
13 | launch(30, "box_car", dt=2, detach_spike_grad=False)
14 |
--------------------------------------------------------------------------------
/scripts/supervised/run_standard_nmnist.py:
--------------------------------------------------------------------------------
1 | import os
2 | root = "" # TODO: Change this to the project folder
3 |
4 |
5 | def launch(abs_refac, surr_grad, dt, method):
6 | for i in range(3):
7 | id = f"nmnist_{method}_{surr_grad}_{abs_refac}_{dt}_{i}"
8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method={method} --abs_refac={abs_refac} --surr_grad={surr_grad} --name=nmnist --dt={dt} --id={id}")
9 |
10 |
11 | launch(0, "mg", 1, "standard")
12 |
--------------------------------------------------------------------------------
/scripts/supervised/run_standard_shd.py:
--------------------------------------------------------------------------------
1 | import os
2 | root = "" # TODO: Change this to the project folder
3 |
4 |
5 | def launch(abs_refac, surr_grad, dt):
6 | for i in range(3):
7 | id = f"shd_standard_{surr_grad}_{abs_refac}_{dt}_{i}"
8 | os.system(f"python {root}/scripts/supervised/train.py --root={root} --method=standard --abs_refac={abs_refac} --surr_grad={surr_grad} --name=shd --dt={dt} --id={id}")
9 |
10 |
11 | launch(0, "mg", dt=2)
12 |
--------------------------------------------------------------------------------
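The run scripts encode each configuration in the `--id` passed to `train.py`, for example `shd_{method}_{surr_grad}_{abs_refac}_{dt}_{i}`, with an extra `{detach_spike_grad}` field in `run_detach_spikes.py`. Decoding these ids can help when collecting results. Because surrogate names such as `fast_sigmoid` and `box_car` contain underscores, a naive `split("_")` misaligns the fields. The sketch below uses a hypothetical helper (not part of the repository) that reads the numeric fields from the right.

```python
def parse_run_id(run_id, with_detach=False):
    """Split a run id such as "shd_blocks_mg_30_1_0" back into its fields.

    Surrogate names may contain underscores ("fast_sigmoid", "box_car"),
    so the trailing fields are taken from the right-hand end.
    """
    parts = run_id.split("_")
    dataset, method = parts[0], parts[1]
    n_tail = 4 if with_detach else 3  # abs_refac, dt, [detach_spike_grad], repeat index
    tail = parts[-n_tail:]
    fields = {
        "dataset": dataset,
        "method": method,
        "surr_grad": "_".join(parts[2:-n_tail]),
        "abs_refac": float(tail[0]),
        "dt": float(tail[1]),
        "repeat": int(tail[-1]),
    }
    if with_detach:
        fields["detach_spike_grad"] = tail[2] == "True"
    return fields
```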
/scripts/supervised/train.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import ast
3 | import logging
4 | import argparse
5 |
6 | from src import datasets, models, train
7 |
8 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
9 |
10 |
11 | def eval(v):
12 | return ast.literal_eval(v)
13 |
14 |
15 | def get_dataset(args, train=True):
16 | if args.name == "shd":
17 | return datasets.SHDDataset(f"{args.root}/data/SHD", train=train, dt=args.dt)
18 | elif args.name == "nmnist":
19 | return datasets.NMNISTDataset(f"{args.root}/data/N-MNIST", train=train, dt=args.dt)
20 |
21 |
22 | def get_model(args, dataset):
23 | abs_refac = int(args.abs_refac / args.dt)
24 |
25 | if args.name == "shd":
26 | n_in = 700
27 | n_out = 20
28 | elif args.name == "nmnist":
29 | n_in = 1156
30 | n_out = 10
31 |
32 | return models.AuditoryModel(args.method, n_in, args.n_hidden, n_out, dataset.t_len, abs_refac, eval(args.recurrent), args.dt, args.surr_grad, detach_spike_grad=eval(args.detach_spike_grad))
33 |
34 |
35 | def get_trainer(args, model, train_dataset, val_dataset):
36 | gamma = 0.1
37 | if args.name == "shd":
38 | milestones = [15, 15]
39 | epochs = 40
40 | elif args.name == "nmnist":
41 | milestones = [30]
42 | epochs = 20
43 | return train.Trainer(f"{args.root}/results", model, train_dataset, epochs, args.batch_size, args.lr, milestones, gamma, val_dataset, id=args.id)
44 |
45 |
46 | def main():
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument("--root", type=str, default=".")
49 |
50 | # Model
51 | parser.add_argument("--method", type=str, default="standard")
52 | parser.add_argument("--n_hidden", type=int, default=256)
53 | parser.add_argument("--abs_refac", type=float, default=10)
54 | parser.add_argument("--recurrent", type=str, default="True")
55 | parser.add_argument("--detach_spike_grad", type=str, default="True")
56 | parser.add_argument("--surr_grad", type=str, default="fast_sigmoid")
57 |
58 | # Dataset
59 | parser.add_argument("--name", type=str, default="shd")
60 | parser.add_argument("--dt", type=float, default=1)
61 |
62 | # Trainer
63 | parser.add_argument("--batch_size", type=int, default=64)
64 | parser.add_argument("--lr", type=float, default=0.001)
65 | parser.add_argument("--id", type=str, default="")
66 |
67 | args = parser.parse_args()
68 |
69 | train_dataset = get_dataset(args, train=True)
70 | # N-MNIST is trained without a validation set, so its test split is not loaded
71 | use_val = args.name != "nmnist"
72 | val_dataset = get_dataset(args, train=False) if use_val else None
73 | model = get_model(args, train_dataset)
74 | model_trainer = get_trainer(args, model, train_dataset, val_dataset)
75 | model_trainer.train(save=True)
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
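The `--recurrent` and `--detach_spike_grad` flags are declared as strings and converted later with `ast.literal_eval`, because argparse's `type=bool` would turn any non-empty string, including `"False"`, into `True`. A minimal sketch of doing the conversion inside argparse itself, assuming a hypothetical `parse_bool` helper that is not part of this repository:

```python
import argparse
import ast


def parse_bool(value):
    """Hypothetical helper: convert the strings "True"/"False" into booleans."""
    try:
        parsed = ast.literal_eval(value)  # "False" -> False, "1" -> 1
    except (ValueError, SyntaxError):  # e.g. "abc" or "("
        parsed = None
    if not isinstance(parsed, bool):
        raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")
    return parsed


parser = argparse.ArgumentParser()
parser.add_argument("--recurrent", type=parse_bool, default=True)
parser.add_argument("--detach_spike_grad", type=parse_bool, default=True)

args = parser.parse_args(["--recurrent", "False"])
print(args.recurrent, args.detach_spike_grad)  # False True
print(bool("False"))  # True, which is why type=bool is avoided
```

With this approach the call sites could read `args.recurrent` directly, and a malformed value is reported by argparse as a usage error instead of failing later inside `get_model`.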
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="src",
5 | packages=find_packages(),
6 | )
7 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/src/__init__.py
--------------------------------------------------------------------------------
/src/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import torch
5 | import pandas as pd
6 |
7 | from src import datasets
8 | from src import models
9 |
10 |
11 | class Benchmarker:
12 |
13 | def __init__(self, method, t_len, abs_refac, n_in, n_hidden, n_layers, batch_size=16, min_r=0, max_r=200, n_samples=11):
14 | self._method = method
15 | self._t_len = t_len
16 | self._abs_refac = abs_refac
17 | self._n_in = n_in
18 | self._n_hidden = n_hidden
19 | self._n_layers = n_layers
20 | self._batch_size = batch_size
21 | self._model = models.ModelBuilder(method, t_len, abs_refac, n_in, n_hidden, n_layers)
22 |
23 | self._data_loader = self._get_data_loader(t_len, n_in, min_r, max_r, batch_size, n_samples*batch_size)
24 | self._benchmark_results = None
25 |
26 | def benchmark(self, device="cuda"):
27 | timing_list = []
28 |
29 | self._model = self._model.to(device)
30 |
31 | for i, data in enumerate(self._data_loader):
32 | # Benchmark forward pass
33 | data = data.to(device)
34 | torch.cuda.synchronize()  # Make sure the input copy has finished before starting the clock
35 | start_time = time.perf_counter()
36 | output = self._model(data)
37 | torch.cuda.synchronize()
38 | forward_pass_time = time.perf_counter() - start_time
39 |
40 | # Benchmark backward pass
41 | start_time = time.perf_counter()
42 | loss = output.sum()
43 | loss.backward()
44 | torch.cuda.synchronize()
45 | backward_pass_time = time.perf_counter() - start_time
46 |
47 | # Skip the first (warm-up) iteration, which includes one-off costs such as
48 | # lazy CUDA initialisation or cuDNN selecting the fastest convolution algorithm
49 | if i > 0:
50 | timing_row = {"forward_time": forward_pass_time, "backward_time": backward_pass_time}
51 | timing_list.append(timing_row)
52 |
53 | self._benchmark_results = timing_list
54 |
55 | def save(self, path):
56 | results_df = self._to_df()
57 | results_df.to_csv(os.path.join(path, f"{self._get_df_name()}.csv"), index=False)
58 |
59 | def _get_description(self):
60 | return {"method": self._method, "t_len": self._t_len, "abs_refac": self._abs_refac, "units": self._n_hidden, "layers": self._n_layers, "batch": self._batch_size}
61 |
62 | def _get_df_name(self):
63 | return f"{self._method}_{self._t_len}_{self._abs_refac}_{self._n_hidden}_{self._n_layers}_{self._batch_size}"
64 |
65 | def _get_data_loader(self, t_len, n_units, min_r, max_r, batch_size, n_samples):
66 | spikes_dataset = datasets.SyntheticSpikes(t_len, n_units, min_r, max_r, n_samples)
67 | return torch.utils.data.DataLoader(spikes_dataset, batch_size, shuffle=False)
68 |
69 | def _to_df(self):
70 | assert self._benchmark_results is not None
71 | results = []
72 |
73 | for results_row in self._benchmark_results:
74 | results.append({**results_row, **self._get_description()})
75 |
76 | return pd.DataFrame(results)
77 |
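`Benchmarker.benchmark` times every batch, discards the first iteration as warm-up, and `_to_df` then merges each timing row with the run's description. The same pattern, reduced to the standard library and a plain callable, might look like the sketch below. `time_calls` and `to_rows` are hypothetical names; for CUDA work the clock should only be read after `torch.cuda.synchronize()`, as the class does, because kernels are launched asynchronously.

```python
import statistics
import time


def time_calls(fn, n_runs, n_warmup=1):
    """Call fn() n_runs times and return the durations in seconds, skipping warm-up runs."""
    timings = []
    for i in range(n_runs):
        start = time.perf_counter()  # monotonic, high-resolution clock
        fn()
        elapsed = time.perf_counter() - start
        if i >= n_warmup:  # the first call(s) pay one-off setup costs
            timings.append(elapsed)
    return timings


def to_rows(timings, description):
    """Attach the run description to every timing, similar to Benchmarker._to_df."""
    return [{"time": t, **description} for t in timings]


timings = time_calls(lambda: sorted(range(10_000), reverse=True), n_runs=11)
rows = to_rows(timings, {"method": "standard", "batch": 16})
print(len(rows), statistics.median(timings) >= 0)  # 10 True
```

Reporting the median rather than the mean is one way to reduce the influence of occasional slow iterations; the repository itself stores every row and leaves aggregation to the analysis code.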
--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .synthetic import SyntheticSpikes
2 | from .neuromorphic import NMNISTDataset, SHDDataset
3 | from .ephys import ValidNeuronQuery, Builder, NoiseBuilder, EphysDataset
4 |
--------------------------------------------------------------------------------
/src/datasets/ephys.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import pickle
4 | from collections import defaultdict
5 |
6 | import numpy as np
7 | import pandas as pd
8 | import torch
9 | from allensdk.core.cell_types_cache import CellTypesCache
10 | from allensdk.ephys.ephys_extractor import EphysSweepFeatureExtractor
11 |
12 |
13 | class ValidNeuronQuery:
14 |
15 | """
16 | Collects the ids of neurons present in both the train and test splits whose test recordings contain exactly 4 repeats.
17 | """
18 |
19 | def __init__(self, root):
20 | self.root = root
21 | self.query_df = self.build_query_df()
22 |
23 | @property
24 | def idx(self):
25 | query = self.query_df["n_trials"] == 4 # Want 4 repeats
26 | return self.query_df[query]["idx"].values
27 |
28 | def build_query_df(self):
29 | train_idxs = [os.path.basename(v) for v in glob.glob(f"{self.root}/train/*")]
30 | test_idxs = [os.path.basename(v) for v in glob.glob(f"{self.root}/test/*")]
31 | joint_idxs = list(set(train_idxs) & set(test_idxs))
32 |
33 | data_list = []
34 |
35 | for idx in joint_idxs:
36 | try:
37 | v = torch.load(f"{self.root}/test/{idx}/v.pt")
38 | n_trials = v.shape[0]
39 | data_list.append({"idx": idx, "n_trials": n_trials})
40 | except Exception:  # Skip neurons whose test tensor cannot be loaded
41 | pass
42 |
43 | return pd.DataFrame(data_list)
44 |
45 |
46 | class EphysDataset:
47 |
48 | """
49 | PyTorch-friendly e-phys dataset for one neuron, holding its input current, membrane voltage and spike tensors.
50 | """
51 |
52 | LENGTH = 10000  # 1 second at DT = 0.1 ms (10 kHz)
53 |
54 | def __init__(self, root, dataset, neuron_idx, target_sampling_rate=None):
55 | # dataset: either "train" or "test"
56 | info_df = pd.read_csv(f"{root}/info_df.csv").set_index("idx")
57 |
58 | if target_sampling_rate is None:
59 | self.i = torch.load(f"{root}/{dataset}/{neuron_idx}/i.pt")[:, :, :1*EphysDataset.LENGTH]
60 | elif target_sampling_rate == 20000:
61 | self.i = torch.load(f"{root}/{dataset}/{neuron_idx}/i_20k.pt")[:, :, :2*EphysDataset.LENGTH]
62 | self.v = torch.load(f"{root}/{dataset}/{neuron_idx}/v.pt")[:, :, :EphysDataset.LENGTH]
63 | self.s = torch.load(f"{root}/{dataset}/{neuron_idx}/s.pt")[:, :, :EphysDataset.LENGTH]
64 |
65 | vrest = info_df.loc[neuron_idx]["vrest"]
66 | vthresh = info_df.loc[neuron_idx]["vthresh"]
67 | self.i = self.i / (100 * self.i.max())
68 | self.v = (self.v - vrest) / (vthresh - vrest)
69 |
70 | @property
71 | def hyperparams(self):
72 | return {}
73 |
74 | def __getitem__(self, item):
75 | x = self.i[0, item].unsqueeze(0) # Input current is the same across all trials
76 | trace = self.v[0, item].unsqueeze(0)
77 | trace = torch.clamp(trace, -1, 1)
78 | spikes = self.s[0, item].unsqueeze(0) # Take first spike trial (all are the same)
79 |
80 | return x, (trace, spikes)
81 |
82 | def __len__(self):
83 | return 3  # Each noise sweep is split into three stimulus segments
84 |
85 |
86 | class Builder:
87 |
88 | def __init__(self, manifest_file="data/plot_data/allen-brain-observatory/cell_types/manifest.json"):
89 | self.ctc = CellTypesCache(manifest_file=manifest_file)
90 | ephys_df = self.generate_ephys_df()
91 | self.info_df = self.generate_info_df(ephys_df).set_index("idx")
92 |
93 | def save_info_df(self, path):
94 | self.info_df.to_csv(f"{path}/info_df.csv")
95 |
96 | def build(self, path, **kwargs):
97 | for neuron_idx in self.info_df.index:
98 | print(f"Building {neuron_idx}...")
99 | try:
100 | meta, i, v = self.generate_all_sweep_tensor(neuron_idx, **kwargs)
101 | os.mkdir(f"{path}/{neuron_idx}")
102 | torch.save(i, f"{path}/{neuron_idx}/i.pt")
103 | torch.save(v, f"{path}/{neuron_idx}/v.pt")
104 | with open(f"{path}/{neuron_idx}/meta.pkl", "wb") as f:
105 | pickle.dump(meta, f)
106 | except Exception as e:
107 | print(f"Failed {neuron_idx}: {e}")
108 |
109 | def generate_all_sweep_tensor(self, neuron_idx, target_sampling_rate=10000, start_s=1.02, end_s=1.3):
110 | sweep_idxs = self.info_df.loc[neuron_idx]["long_square"]
111 | sweeps = []  # (mean current in pA, i, v, spike_times), one entry per sweep
112 |
113 | for sweep_idx in sweep_idxs:
114 | i, v, spike_times = self.generate_sweep_tensor(neuron_idx, sweep_idx, target_sampling_rate, start_s, end_s)
115 | sweeps.append((i.mean().item(), i, v, spike_times))
116 |
117 | # Sort from lowest to highest mean stimulus current
118 | sweeps.sort(key=lambda sweep: sweep[0])
119 |
120 | meta = pd.DataFrame([{"i": amp, "spikes": spike_times} for amp, _, _, spike_times in sweeps])
121 | i = torch.stack([sweep[1] for sweep in sweeps])
122 | v = torch.stack([sweep[2] for sweep in sweeps])
123 | return meta, i, v
124 |
125 | def generate_sweep_tensor(self, neuron_idx, sweep_number, target_sampling_rate=10000, start_s=1.02, end_s=1.3):
126 | data_set = self.ctc.get_ephys_data(neuron_idx)
127 | sweep_data = data_set.get_sweep(sweep_number)
128 |
129 | index_range = sweep_data["index_range"]
130 | i = sweep_data["stimulus"][0:index_range[1]+1] # in A
131 | v = sweep_data["response"][0:index_range[1]+1] # in V
132 | i *= 1e12 # to pA
133 | v *= 1e3 # to mV
134 |
135 | sampling_rate = int(sweep_data["sampling_rate"]) # in Hz
136 | t = np.arange(0, len(v)) * (1.0 / sampling_rate)
137 | downsample_factor = sampling_rate / target_sampling_rate
138 |
139 | sweep_ext = EphysSweepFeatureExtractor(t=t, v=v, i=i, start=start_s, end=end_s)
140 | sweep_ext.process_spikes()
141 | spike_times = sweep_ext.spike_feature("threshold_t")
142 |
143 | start_idx = int(start_s*sampling_rate)
144 | end_idx = int(end_s*sampling_rate)
145 |
146 | downsampled_v, downsampled_i = [], []
147 | assert v.shape == i.shape
148 |
149 | idx = start_idx
150 | while idx < end_idx:
151 | downsampled_v.append(v[int(idx)])
152 | downsampled_i.append(i[int(idx)])
153 | idx += downsample_factor
154 |
155 | return torch.Tensor(downsampled_i), torch.Tensor(downsampled_v), spike_times
156 |
157 | def generate_ephys_df(self):
158 | cells = {cell["id"]: cell for cell in self.ctc.get_cells()}
159 | ephys_features = self.ctc.get_ephys_features()
160 | ephys_df = pd.DataFrame(ephys_features)
161 | ephys_df['id'] = pd.Series([idx for idx in ephys_df['specimen_id']], index=ephys_df.index)
162 | ephys_df['species'] = pd.Series([cells[idx]['species'] for idx in ephys_df['specimen_id']], index=ephys_df.index)
163 | ephys_df['dendrite_type'] = pd.Series([cells[idx]['dendrite_type'] for idx in ephys_df['specimen_id']], index=ephys_df.index)
164 | ephys_df['structure_layer_name'] = pd.Series([cells[idx]['structure_layer_name'] for idx in ephys_df['specimen_id']], index=ephys_df.index)
165 | ephys_df['disease_state'] = pd.Series([cells[idx]['disease_state'] for idx in ephys_df['specimen_id']], index=ephys_df.index)
166 | query = ephys_df["structure_layer_name"] == "4"
167 | query &= ephys_df["species"] == "Mus musculus"
168 | query &= ephys_df["disease_state"] == ""
169 |
170 | return ephys_df[query]
171 |
172 | def generate_info_df(self, ephys_df):
173 | info_list = []
174 |
175 | neuron_idxs = ephys_df["id"].values
176 |
177 | for neuron_idx in neuron_idxs:
178 | # Sweep info
179 | sweeps = self.ctc.get_ephys_sweeps(neuron_idx)
180 | sweep_numbers = defaultdict(list)
181 | for sweep in sweeps:
182 | sweep_numbers[sweep['stimulus_name']].append(sweep['sweep_number'])
183 |
184 | neuron_type = ephys_df[ephys_df["id"] == neuron_idx]["dendrite_type"].values[0]
185 | vrest = ephys_df[ephys_df["id"] == neuron_idx]["vrest"].values[0]
186 | vthresh = ephys_df[ephys_df["id"] == neuron_idx]["threshold_v_long_square"].values[0]
187 |
188 | info_list.append({"idx": neuron_idx, "type": neuron_type, "long_square": sweep_numbers.get("Long Square"), "noise1": sweep_numbers.get("Noise 1"), "noise2": sweep_numbers.get("Noise 2"), "test": sweep_numbers.get("Test"), "vrest": vrest, "vthresh": vthresh})
189 |
190 | return pd.DataFrame(info_list)
191 |
192 |
193 | class NoiseBuilder(Builder):
194 |
195 | def build(self, path, **kwargs):
196 | for neuron_idx in self.info_df.index:
197 | print(f"Building {neuron_idx}...")
198 | try:
199 | if not os.path.exists(f"{path}/{neuron_idx}"):
200 | os.makedirs(f"{path}/{neuron_idx}")
201 |
202 | i, v, s = self.generate_all_sweep_tensor(neuron_idx, **kwargs)
203 |
204 | # Default used for all experiments
205 | if kwargs.get("target_sampling_rate") is None:
206 | torch.save(i, f"{path}/{neuron_idx}/i.pt")
207 | torch.save(v, f"{path}/{neuron_idx}/v.pt")
208 | torch.save(s, f"{path}/{neuron_idx}/s.pt")
209 | # Some experiments were re-run at DT = 0.05 ms (20 kHz) at a reviewer's request
210 | elif kwargs.get("target_sampling_rate") == 20000:
211 | torch.save(i, f"{path}/{neuron_idx}/i_20k.pt")
212 | torch.save(v, f"{path}/{neuron_idx}/v_20k.pt")
213 | torch.save(s, f"{path}/{neuron_idx}/s_20k.pt")
214 |
215 | except Exception as e:
216 | print(f"Failed {neuron_idx}: {e}")
217 |
218 | def generate_all_sweep_tensor(self, neuron_idx, target_sampling_rate=10000, noise_type="noise1"):
219 | sweep_idxs = self.info_df.loc[neuron_idx][noise_type]
220 | assert sweep_idxs is not None, f"Neuron {neuron_idx} has no {noise_type} sweeps"
221 |
222 | i_list = []
223 | v_list = []
224 | s_list = []
225 |
226 | for sweep_idx in sweep_idxs:
227 | i, v, s = self.generate_noise_sweep_tensor(neuron_idx, sweep_idx, target_sampling_rate)
228 | i_list.append(i)
229 | v_list.append(v)
230 | s_list.append(s)
231 |
232 | return torch.stack(i_list), torch.stack(v_list), torch.stack(s_list)
233 |
234 | def generate_noise_sweep_tensor(self, neuron_idx, sweep_number, target_sampling_rate=10000):
235 | i1, v1, t1 = self.generate_sweep_tensor(neuron_idx, sweep_number, target_sampling_rate, start_s=2, end_s=5)
236 | i2, v2, t2 = self.generate_sweep_tensor(neuron_idx, sweep_number, target_sampling_rate, start_s=10, end_s=13)
237 | i3, v3, t3 = self.generate_sweep_tensor(neuron_idx, sweep_number, target_sampling_rate, start_s=18, end_s=21)
238 | t1 -= 2
239 | t2 -= 10
240 | t3 -= 18
241 | s1 = NoiseBuilder.to_spike_target(t1, target_sampling_rate)
242 | s2 = NoiseBuilder.to_spike_target(t2, target_sampling_rate)
243 | s3 = NoiseBuilder.to_spike_target(t3, target_sampling_rate)
244 |
245 | return torch.stack([i1, i2, i3]), torch.stack([v1, v2, v3]), torch.stack([s1, s2, s3])
246 |
247 | @staticmethod
248 | def to_spike_target(spike_times, target_sampling_rate):
249 | dt = 1.0 / target_sampling_rate  # Bin width in s; the target covers 1 s (target_sampling_rate bins)
250 | spike_target = torch.zeros(target_sampling_rate)
251 | spike_idx = [int(spike_time // dt) for spike_time in spike_times if int(spike_time // dt) < target_sampling_rate]
252 | spike_target[spike_idx] = 1
253 |
254 | return spike_target
255 |
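`NoiseBuilder.to_spike_target` turns the spike threshold times returned by the Allen SDK feature extractor into a binary target of `target_sampling_rate` bins, i.e. one second of signal. A NumPy sketch of the same binning, with the bin width derived from the sampling rate so that it also applies at 20 kHz; the function name and the `n_samples` argument are assumptions for illustration.

```python
import numpy as np


def to_spike_target(spike_times_s, sampling_rate_hz, n_samples):
    """Bin spike times (s, relative to segment start) into a 0/1 array of n_samples bins."""
    bins = np.floor(np.asarray(spike_times_s, dtype=float) * sampling_rate_hz).astype(int)
    bins = bins[(bins >= 0) & (bins < n_samples)]  # drop spikes outside the window
    target = np.zeros(n_samples, dtype=np.float32)
    target[bins] = 1.0  # several spikes in one bin still give a single 1
    return target


target = to_spike_target([0.00005, 0.01234, 0.01236, 1.5], sampling_rate_hz=10_000, n_samples=10_000)
print(np.flatnonzero(target))  # bins 0 and 123; the spike at 1.5 s falls outside the window
```

Multiplying by the sampling rate and flooring keeps the bin index tied to the chosen resolution, whereas a hard-coded bin width silently assumes 10 kHz.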
--------------------------------------------------------------------------------
/src/datasets/neuromorphic.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tables
3 |
4 | import torch
5 | import numpy as np
6 | from brainbox.datasets import BBDataset
7 |
8 | from src.datasets.transforms import SpikeTensorBuilder
9 |
10 |
11 | class H5Dataset(BBDataset):
12 |
13 | def __init__(self, root, train, n_in, n_out, t_len, train_name, test_name, dt):
14 | self._file = None
15 | self._train_name = train_name
16 | self._test_name = test_name
17 | super().__init__(root, train, lambda dataset: H5Dataset.preprocess(dataset), SpikeTensorBuilder(n_in, t_len, dt))
18 | self._file.close()
19 |
20 | self._n_in = n_in
21 | self._n_out = n_out
22 | self._t_len = t_len
23 | self._dt = dt
24 |
25 | @property
26 | def hyperparams(self):
27 | return {**super().hyperparams, "t_len": self._t_len, "dt": self._dt}
28 |
29 | @property
30 | def n_in(self):
31 | return self._n_in
32 |
33 | @property
34 | def n_out(self):
35 | return self._n_out
36 |
37 | @property
38 | def t_len(self):
39 | return self._t_len
40 |
41 | @property
42 | def dt(self):
43 | return self._dt
44 |
45 | @staticmethod
46 | def preprocess(dataset):
47 | processed_dataset = []
48 | units, times = dataset
49 |
50 | for i in range(len(units)):
51 | item_units = torch.Tensor(np.array(units[i], dtype=int))
52 | item_times = torch.Tensor(np.array(times[i], dtype=float))
53 | processed_dataset.append((item_units, item_times))
54 |
55 | return processed_dataset
56 |
57 | @staticmethod
58 | def _open_file(hdf5_file_path):
59 | fileh = tables.open_file(hdf5_file_path, mode="r")
60 | units = fileh.root.spikes.units
61 | times = fileh.root.spikes.times
62 | labels = fileh.root.labels
63 |
64 | return fileh, units, times, labels
65 |
66 | def _load_dataset(self, train):
67 | name = self._train_name if train else self._test_name
68 | fileh, units, times, labels = H5Dataset._open_file(os.path.join(self._root, name))
69 | targets = torch.Tensor(labels)
70 | self._file = fileh
71 |
72 | return (units, times), targets
73 |
74 |
75 | class NMNISTDataset(H5Dataset):
76 |
77 | T_LEN = 400
78 |
79 | def __init__(self, root, train=True, dt=1):
80 | t_len = int(NMNISTDataset.T_LEN / dt)
81 | super().__init__(root, train, n_in=1156, n_out=10, t_len=t_len, train_name="train.h5", test_name="test.h5", dt=dt)
82 |
83 | @property
84 | def name(self):
85 | return "nmnist"
86 |
87 |
88 | class SHDDataset(H5Dataset):
89 |
90 | T_LEN = 1200
91 |
92 | def __init__(self, root, train=True, dt=2):
93 | t_len = int(SHDDataset.T_LEN / dt)
94 | super().__init__(root, train, n_in=700, n_out=20, t_len=t_len, train_name="shd_train.h5", test_name="shd_test.h5", dt=dt)
95 |
96 | @property
97 | def name(self):
98 | return "shd"
99 |
--------------------------------------------------------------------------------
/src/datasets/synthetic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.distributions.poisson import Poisson
3 | from brainbox.datasets import BBDataset
4 |
5 |
6 | class SyntheticSpikes(BBDataset):
7 |
8 | """
9 | Synthetic Poisson spike trains used for the speed benchmarks; each sample draws its rate uniformly from [min_r, max_r].
10 | """
11 |
12 | def __init__(self, t_len, n_units, min_r, max_r, n_samples):
13 | super().__init__(None)
14 | self.t_len = t_len
15 | self.n_units = n_units
16 | self.min_r = min_r
17 | self.max_r = max_r
18 | self.n_samples = n_samples
19 |
20 | def __getitem__(self, i):
21 | rate = torch.FloatTensor(1).uniform_(self.min_r, self.max_r).item()
22 | x = self._create_spikes(rate, self.n_units, self.t_len)
23 |
24 | return x
25 |
26 | def __len__(self):
27 | return self.n_samples
28 |
29 | def _load_dataset(self, train):
30 | return None, None
31 |
32 | def _create_spikes(self, rate, n_units, t_len):
33 | pois_dis = Poisson(rate/t_len)
34 | if isinstance(n_units, tuple):
35 | samples = pois_dis.sample(sample_shape=(*n_units, t_len))
36 | else:
37 | samples = pois_dis.sample(sample_shape=(n_units, t_len))
38 | samples[samples > 1] = 1
39 |
40 | return samples
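`SyntheticSpikes` draws one rate per sample, then samples Poisson counts with mean `rate / t_len` in every time bin and clips them to 1. The expected spike count per unit over the window is therefore just under `rate`, namely `t_len * (1 - exp(-rate / t_len))`, because clipping removes the extra spikes in bins that receive two or more. A NumPy sketch of the same generator, assuming NumPy's `Generator.poisson` in place of `torch.distributions.Poisson`; the function names are hypothetical.

```python
import numpy as np


def synthetic_spikes(rate, n_units, t_len, rng):
    """Binary spike raster (n_units x t_len) with roughly `rate` spikes per unit."""
    counts = rng.poisson(lam=rate / t_len, size=(n_units, t_len))
    return np.minimum(counts, 1).astype(np.float32)  # more than one spike per bin -> 1


def synthetic_batch(batch_size, n_units, t_len, min_r, max_r, rng):
    # Each sample gets its own rate drawn uniformly from [min_r, max_r)
    rates = rng.uniform(min_r, max_r, size=batch_size)
    return np.stack([synthetic_spikes(r, n_units, t_len, rng) for r in rates])


rng = np.random.default_rng(0)
batch = synthetic_batch(16, 100, 1000, 0, 200, rng)
print(batch.shape)  # (16, 100, 1000)
```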
--------------------------------------------------------------------------------
/src/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import brainbox
4 | from brainbox.datasets.transforms import BBTransform
5 |
6 |
7 | class SpikeTensorBuilder(BBTransform):
8 |
9 | def __init__(self, n_units, t_len, dt):
10 | self._n_units = n_units
11 | self._t_len = t_len
12 | self._dt = dt
13 |
14 | def __call__(self, args):
15 | units, times = args[0], args[1]
16 | units = units % self._n_units
17 | times = torch.round(times * 1000. / self._dt).int()
18 |
19 | # Constrain spike length
20 | idxs = (times < self._t_len)
21 | units = units[idxs]
22 | times = times[idxs]
23 |
24 | # Build COO tensor
25 | indices = torch.stack([units.long(), times.long()], dim=0)
26 | shape = torch.Size([self._n_units, self._t_len])
27 | spikes = torch.ones(indices.shape[1])
28 |
29 | return torch.sparse_coo_tensor(indices, spikes, shape).to_dense()
30 |
31 |
32 | class List:
33 |
34 | @staticmethod
35 | def get_nmnist_transform(t_len, use_augmentation=False):
36 | if use_augmentation:
37 | raise NotImplementedError
38 | else:
39 | transform_list = [SpikeTensorBuilder(n_units=1156, t_len=t_len, dt=1)]
40 |
41 | return brainbox.datasets.transforms.Compose(transform_list)
42 |
43 | @staticmethod
44 | def get_shd_transform(t_len):
45 | transform_list = [SpikeTensorBuilder(n_units=700, t_len=t_len, dt=2)]
46 |
47 | return brainbox.datasets.transforms.Compose(transform_list)
48 |
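`SpikeTensorBuilder` maps each event to a `(unit, bin)` index, with `bin = round(t * 1000 / dt)` for times in seconds and `dt` in ms, wraps unit ids modulo `n_units`, drops events at or beyond `t_len`, and densifies a sparse COO tensor. Because that COO tensor is not coalesced, `to_dense()` sums repeated indices, so two events in the same bin give a 2 rather than a 1. A NumPy sketch of the same mapping (the function name is hypothetical), using `np.add.at` to reproduce that accumulation:

```python
import numpy as np


def events_to_dense(units, times_s, n_units, t_len, dt_ms):
    """Dense (n_units x t_len) spike-count raster built from event lists."""
    units = np.asarray(units, dtype=np.int64) % n_units  # wrap unit ids, as the transform does
    # np.round rounds halves to even, matching torch.round
    bins = np.round(np.asarray(times_s, dtype=float) * 1000.0 / dt_ms).astype(np.int64)
    keep = bins < t_len  # drop events past the simulated window
    dense = np.zeros((n_units, t_len), dtype=np.float32)
    np.add.at(dense, (units[keep], bins[keep]), 1.0)  # repeated (unit, bin) pairs accumulate
    return dense


dense = events_to_dense([0, 1, 1, 701, 3], [0.001, 0.002, 0.0021, 0.0004, 0.01], n_units=700, t_len=5, dt_ms=1)
print(dense[0, 1], dense[1, 2], dense[1, 0], dense.sum())  # 1.0 2.0 1.0 4.0
```

If strictly binary spikes were wanted, clipping the result with `np.minimum(dense, 1)` would be one option; the transform as written keeps the counts.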
--------------------------------------------------------------------------------
/src/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | from brainbox.physiology.spiking import SpikeToPSTH
4 |
5 | from src import datasets, train
6 |
7 |
8 | class SpikeTrainEV:
9 |
10 | def __init__(self, st_len, sig):
11 | self._spike_smoother = SpikeToPSTH(st_len, sig)
12 |
13 | def __call__(self, input, target):
14 | # input: b x 3 x t
15 | smooth_input_trains = self.smooth_spike_trains(input)
16 | smooth_target_trains = self.smooth_spike_trains(target)
17 | flatten_input_trains = self.flatten_spike_trains(smooth_input_trains)
18 | flatten_target_trains = self.flatten_spike_trains(smooth_target_trains)
19 |
20 | return SpikeTrainEV.ev(flatten_input_trains, flatten_target_trains.mean(0).unsqueeze(0)).mean()
21 |
22 | def smooth_spike_trains(self, spike_trains):
23 | # spike_trains: b x 3 x t
24 | b = spike_trains.shape[0]
25 |
26 | return self._spike_smoother(spike_trains.view(b*3, -1).unsqueeze(1))[:, 0].view(b, 3, -1)
27 |
28 | def flatten_spike_trains(self, spike_trains):
29 | # spike_trains: b x 3 x t
30 | return spike_trains.flatten(1, 2)
31 |
32 | @staticmethod
33 | def ev(input, target):
34 | # input, target: b x t. Per row this equals 2 * cov(input, target) / (var(input) + var(target))
35 | return (input.var(1) + target.var(1) - (input - target).var(1)) / (input.var(1) + target.var(1))
36 |
37 |
38 | class EphysAnalysis:
39 |
40 | def __init__(self, root, method, abs_refac_ms, downsample, taus=[100]):
41 | self.root = root
42 | self.method = method
43 | self.abs_refac_ms = abs_refac_ms
44 | self.downsample = downsample
45 | self.taus = taus
46 |
47 | self.metrics = {tau: SpikeTrainEV(10000, tau) for tau in taus}
48 | self.neuron_idxs = datasets.ValidNeuronQuery(f"{root}/data/ephys").idx
49 |
50 | # Init test dataset
51 | self.test_dataset = {}
52 | self.norm_factors = {tau: {} for tau in taus}
53 |
54 | for neuron_idx in self.neuron_idxs:
55 | if downsample == 0.5:
56 | self.test_dataset[neuron_idx] = datasets.EphysDataset(f"{root}/data/ephys", "test", int(neuron_idx), target_sampling_rate=20000)
57 | else:
58 | self.test_dataset[neuron_idx] = datasets.EphysDataset(f"{root}/data/ephys", "test", int(neuron_idx))
59 | target_spikes = self.test_dataset[neuron_idx].s.cpu()
60 | for tau in taus:
61 | self.norm_factors[tau][neuron_idx] = self.metrics[tau](target_spikes, target_spikes).item()
62 |
63 | self._norm_df = {tau: pd.Series(self.norm_factors[tau]).to_frame().rename(columns={0: "score"}) for tau in taus}
64 | self._ev_df = None
65 |
66 | def ev_df(self, tau, normalise):
67 | if self._ev_df is None:
68 | self._ev_df = self._build_ev_df().set_index("neuron_idx")
69 |
70 | query = self._ev_df["tau"] == tau
71 | ev_df = self._ev_df[query]["score"].to_frame()
72 |
73 | if normalise:
74 | df = ev_df / self._norm_df[tau]
75 | else:
76 | df = ev_df
77 |
78 | df.index = df.index.map(int)
79 |
80 | return df
81 |
82 | def get_times_df(self):
83 | dir_name = f"{self.method}_{self.abs_refac_ms}_{self.downsample}"
84 |
85 | times_list = []
86 |
87 | for neuron_idx in self.neuron_idxs:
88 | times_csv = pd.read_csv(f"{self.root}/results/ephys/{dir_name}/{neuron_idx}/times.csv")
89 | forward_pass, backward_pass = times_csv.sum()
90 | times_list.append({"neuron_idx": neuron_idx, "forward_pass": forward_pass, "backward_pass": backward_pass})
91 |
92 | return pd.DataFrame(times_list)
93 |
94 | def _build_ev_df(self):
95 | dir_name = f"{self.method}_{self.abs_refac_ms}_{self.downsample}"
96 |
97 | metric_list = []
98 |
99 | for i, neuron_idx in enumerate(self.neuron_idxs):
100 | print(f"Building {i}/{len(self.neuron_idxs)} {neuron_idx}...")
101 | dt01ref = self.downsample == 0.5
102 | neuron = train.EphysTrainer.load_model(f"{self.root}/results/ephys/{dir_name}", neuron_idx, dt01ref=dt01ref)
103 |
104 | with torch.no_grad():
105 | test_dataset = self.test_dataset[neuron_idx]
106 | pred_spikes = neuron(test_dataset.i[0].unsqueeze(1).cuda()).permute(1, 0, 2).cpu()
107 | target_spikes = test_dataset.s.cpu()
108 |
109 | for tau in self.taus:
110 | score = self.metrics[tau](target_spikes, pred_spikes).item()  # Recorded trials are scored against the model prediction, which acts as the reference
111 | metric_list.append({"neuron_idx": neuron_idx, "score": score, "tau": tau})
112 |
113 | return pd.DataFrame(metric_list)
114 |
115 | def load_prediction(self, neuron_idx): # Load prediction for a certain fit and neuron
116 | test_dataset = self.test_dataset[str(neuron_idx)]
117 |
118 | with torch.no_grad():
119 | dir_name = f"{self.method}_{self.abs_refac_ms}_{self.downsample}"
120 | dt01ref = self.downsample == 0.5
121 | neuron = train.EphysTrainer.load_model(f"{self.root}/results/ephys/{dir_name}", neuron_idx, dt01ref=dt01ref)
122 |
123 | output = neuron(test_dataset.i[0].unsqueeze(1).cuda(), mode="val")
124 | spikes = output[0].permute(1, 0, 2).cpu()
125 | mem = output[1].permute(1, 0, 2).cpu()
126 |
127 | return spikes, mem, test_dataset.s, test_dataset.v, test_dataset.i
128 |
129 | def load_params_df(self):
130 | dir_name = f"{self.method}_{self.abs_refac_ms}_{self.downsample}"
131 |
132 | params_list = []
133 |
134 | for i, neuron_idx in enumerate(self.neuron_idxs):
135 | print(f"Building {i}/{len(self.neuron_idxs)} {neuron_idx}...")
136 | neuron = train.EphysTrainer.load_model(f"{self.root}/results/ephys/{dir_name}", neuron_idx, dt01ref=self.downsample == 0.5).neuron
137 |
138 | params_list.append({"neuron_idx": neuron_idx, "beta": neuron.beta.item(), "p": neuron.p.item(), "b": neuron.b.item()})
139 |
140 | return pd.DataFrame(params_list)
141 |
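`SpikeTrainEV.ev` computes `(var(x) + var(y) - var(x - y)) / (var(x) + var(y))` row by row. Since `var(x - y) = var(x) + var(y) - 2 cov(x, y)`, this equals `2 cov(x, y) / (var(x) + var(y))`: 1 for identical signals and -1 for `y = -x`. `__call__` then scores every trial of the first argument against the trial average of the second. A NumPy sketch of those two steps on already-smoothed traces; brainbox's `SpikeToPSTH` smoothing is left out, and the function names are illustrative.

```python
import numpy as np


def explained_variance(x, y):
    """Row-wise EV between x (b x t) and y (b x t or 1 x t)."""
    vx, vy = x.var(axis=1), y.var(axis=1)
    return (vx + vy - (x - y).var(axis=1)) / (vx + vy)


def trial_score(trials, reference_trials):
    # Score each trial against the trial-averaged reference, then average the scores
    reference = reference_trials.mean(axis=0, keepdims=True)
    return explained_variance(trials, reference).mean()


rng = np.random.default_rng(0)
x = rng.random((4, 300))
print(explained_variance(x, x))  # all ones
print(explained_variance(x, -x))  # all minus ones
```

`EphysAnalysis` uses the same call with the recorded trials in both positions to obtain its per-neuron normalisation factor.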
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from brainbox import models
6 |
7 | import src.snn.block.blocks as blocks
8 | import src.snn.snn as snn
9 |
10 |
11 | class AuditoryModel(models.BBModel):
12 |
13 | # AuditoryModel name comes from the SHD dataset being auditory
14 |
15 | HIDDEN_MEM_TIME = 20
16 | HIDDEN_ADAPT_TIME = 150
17 | READOUT_MEM_TIME = 20
18 |
19 | def __init__(self, method, n_in, n_hidden, n_out, t_len, abs_refac, recurrent=True, dt=1, surr_grad="fast_sigmoid", detach_spike_grad=True):
20 | super().__init__()
21 | self._method = method
22 | self._n_in = n_in
23 | self._n_hidden = n_hidden
24 | self._n_out = n_out
25 | self._t_len = t_len
26 | self._abs_refac = abs_refac
27 | self._recurrent = recurrent
28 | self._dt = dt
29 | self._surr_grad = surr_grad
30 |
31 | init_hidden_beta = np.exp(-dt / AuditoryModel.HIDDEN_MEM_TIME)
32 | init_hidden_p = np.exp(-dt / AuditoryModel.HIDDEN_ADAPT_TIME)
33 | init_readout_beta = np.exp(-dt / AuditoryModel.READOUT_MEM_TIME)
34 |
35 | if method == "standard":
36 | self._thalamic_layer = snn.SNN(n_in, n_hidden, 1, t_len, abs_refac, recurrent, beta_grad=True, adapt=True, init_beta=init_hidden_beta, init_p=init_hidden_p, surr_grad=self._surr_grad)
37 | self._cortical_layer = snn.SNN(n_hidden, n_hidden, 1, t_len, abs_refac, recurrent, beta_grad=True, adapt=True, init_beta=init_hidden_beta, init_p=init_hidden_p, surr_grad=self._surr_grad)
38 | self._output = snn.SNNIntegrator(n_hidden, n_out, t_len, init_beta=init_readout_beta)
39 | else:
40 | self._thalamic_layer = blocks.Blocks(n_in, n_hidden, 1, t_len, abs_refac, recurrent, beta_grad=True, adapt=True, init_beta=init_hidden_beta, init_p=init_hidden_p, surr_grad=self._surr_grad, detach_spike_grad=detach_spike_grad)
41 | self._cortical_layer = blocks.Blocks(n_hidden, n_hidden, 1, t_len, abs_refac, recurrent, beta_grad=True, adapt=True, init_beta=init_hidden_beta, init_p=init_hidden_p, surr_grad=self._surr_grad, detach_spike_grad=detach_spike_grad)
42 | self._output = blocks.BlocksIntegrator(n_hidden, n_out, t_len, init_beta=init_readout_beta)
43 |
44 | @property
45 | def hyperparams(self):
46 | return {**super().hyperparams, "method": self._method, "n_in": self._n_in, "n_hidden": self._n_hidden, "n_out": self._n_out, "t_len": self._t_len, "abs_refac": self._abs_refac, "recurrent": self._recurrent, "dt": self._dt, "surr_grad": self._surr_grad}
47 |
48 | def forward(self, x, mode="train"):
49 | # x: b x n x t
50 | thalamic_output = self._thalamic_layer(x, mode)
51 | cortical_output = self._cortical_layer(thalamic_output if mode == "train" else thalamic_output[0], mode)
52 |
53 | if mode == "train":
54 | return self._output(cortical_output, mode).sum(2)
55 | else:
56 | return self._output(cortical_output[0], mode).sum(2), cortical_output, thalamic_output
57 |
58 |
59 | class ModelBuilder(models.BBModel):
60 |
61 | def __init__(self, method, t_len, abs_refac, n_in, n_hidden, n_layers):
62 | super().__init__()
63 | self._layers = nn.ModuleList()
64 |
65 | for i in range(n_layers):
66 | n_in = n_in if i == 0 else n_hidden
67 | if method == "standard":
68 | self._layers.append(snn.SNN(n_in, n_hidden, 1, t_len, abs_refac, recurrent=True, beta_grad=True, adapt=True, init_beta=0.9, init_p=0.9, surr_grad="mg"))
69 | else:
70 | self._layers.append(blocks.Blocks(n_in, n_hidden, 1, t_len, abs_refac, recurrent=True, beta_grad=True, adapt=True, init_beta=0.9, init_p=0.9, surr_grad="mg"))
71 |
72 | def forward(self, x):
73 | for layer in self._layers:
74 | x = layer(x)
75 |
76 | return x
77 |
78 |
79 | class Neuron(models.BBModel):
80 |
81 | def __init__(self, method, abs_refac_ms, downsample=1, dt01ref=False):
82 | super().__init__()
83 | self.method = method
84 | self.abs_refac_ms = abs_refac_ms
85 | self.downsample = downsample
86 | self.dt01ref = dt01ref
87 | self.dt_ms = downsample * 0.1 # Larger dt_ms == downsample temporal resolution
88 |
89 | if not dt01ref:
90 | self.neuron = Neuron.get_neuron(method, abs_refac_ms, self.dt_ms, downsample)
91 | else:
92 | self.neuron = Neuron.get_neuron(method, abs_refac_ms, 0.5*self.dt_ms, 0.5*downsample)
93 | upsample_kernel = torch.zeros(downsample)
94 | upsample_kernel[0] = 1
95 | self.upsample_kernel = nn.Parameter(upsample_kernel.view(1, 1, -1), requires_grad=False)
96 |
97 | @property
98 | def hyperparams(self):
99 | return {**super().hyperparams, "method": self.method, "abs_refac_ms": self.abs_refac_ms, "downsample": self.downsample, "dt01ref": self.dt01ref}
100 |
101 | def forward(self, x, mode="train"):
102 | x = F.avg_pool1d(x, self.downsample, self.downsample) # Down sample the signal (if need be)
103 | spikes = self.neuron(x, mode)
104 |
105 | # Return data for plotting
106 | if mode == "val":
107 | spikes, mem = spikes[0], spikes[1]
108 |
109 | return spikes, mem
110 |
111 | # Return predicted spike train for training
112 | if self.dt01ref:  # Simulated at DT = 0.05 ms (added for extra reviewer-requested experiments); max-pool pairs of steps back to dt_ms
113 | spikes = F.max_pool1d(spikes, 2, 2)
114 | return spikes
115 |
116 | if self.downsample == 1:
117 | return spikes
118 | else:
119 | return F.conv_transpose1d(spikes, self.upsample_kernel, stride=self.downsample)
120 |
121 | @staticmethod
122 | def get_neuron(method, abs_refac_ms, dt_ms, downsample):
123 | init_beta = np.exp(-dt_ms / 20)
124 | init_p = np.exp(-dt_ms / 100)
125 | t_len = int(1000 / dt_ms)
126 | abs_refac_steps = int(abs_refac_ms / dt_ms)  # Absolute refractory period in time steps
127 |
128 | if method == "blocks":
129 | neuron = blocks.Blocks(1, 1, rf_len=1, t_len=t_len, t_latency=abs_refac_steps, recurrent=False, beta_grad=True, adapt=True, init_beta=init_beta, init_p=init_p, detach_spike_grad=True, surr_grad="mg")
130 | else:
131 | neuron = snn.SNN(1, 1, rf_len=1, t_len=t_len, t_latency=abs_refac_steps, recurrent=False, beta_grad=True, adapt=True, init_beta=init_beta, init_p=init_p, detach_spike_grad=True, surr_grad="mg")
132 |
133 | neuron.init_weight(neuron._rf_weight, "constant", c=downsample)
134 | neuron._b = nn.Parameter(data=torch.Tensor([0.1 / downsample]), requires_grad=True)
135 |
136 | return neuron
137 |
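The constants in `get_neuron` and the resampling in `forward` come down to a few lines of arithmetic. The sketch below mirrors them with plain Python lists and no dependencies. It assumes the values hard-coded above: a 20 ms membrane time constant, a 100 ms adaptation time constant, a 1000 ms trial and a 0.1 ms base step. The helpers `neuron_constants`, `downsample_avg` and `upsample_first_bin` are hypothetical stand-ins for `np.exp`, `F.avg_pool1d` and `F.conv_transpose1d` with the `[1, 0, ..., 0]` kernel. They are not functions in the repository.

```python
import math


def neuron_constants(downsample, abs_refac_ms, tau_mem_ms=20.0, tau_adapt_ms=100.0, trial_ms=1000.0):
    """Hypothetical helper mirroring the arithmetic of Neuron.get_neuron."""
    dt_ms = downsample * 0.1                   # base resolution is 0.1 ms
    init_beta = math.exp(-dt_ms / tau_mem_ms)  # membrane decay per step
    init_p = math.exp(-dt_ms / tau_adapt_ms)   # adaptation decay per step
    t_len = int(trial_ms / dt_ms)              # number of steps (int() truncates)
    refac_steps = int(abs_refac_ms / dt_ms)    # refractory period in steps (truncated)
    return dt_ms, init_beta, init_p, t_len, refac_steps


def downsample_avg(signal, factor):
    """Average non-overlapping windows, like F.avg_pool1d(x, factor, factor)."""
    n = len(signal) // factor  # an incomplete tail window is dropped
    return [sum(signal[i * factor:(i + 1) * factor]) / factor for i in range(n)]


def upsample_first_bin(spikes, factor):
    """Put each coarse spike in the first fine bin of its window, like
    F.conv_transpose1d with kernel [1, 0, ..., 0] and stride=factor."""
    out = [0.0] * (len(spikes) * factor)
    for i, s in enumerate(spikes):
        out[i * factor] = s
    return out
```

For example, `neuron_constants(downsample=1, abs_refac_ms=2.0)` gives a 10000-step trial with a 20-step refractory period. Because `int()` truncates, a ratio that lands just below an integer in floating point (such as `0.3 / 0.1 == 2.9999999999999996`) loses one step.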
--------------------------------------------------------------------------------
/src/query.py:
--------------------------------------------------------------------------------
1 | import glob
2 |
3 | import pandas as pd
4 | from brainbox import trainer
5 |
6 | from src import models, train
7 |
8 |
9 | models.snn.BaseSNN.MIN_BETA = 0.01
10 | models.snn.BaseSNN.MAX_BETA = 0.99
11 |
12 |
13 | class BenchmarkQuery:
14 |
15 | def __init__(self, root, batches=[32, 64, 128]):
16 | self._root = root
17 |
18 | self._results_df = self._build_df()
19 | self._results_df = pd.concat([self._query_results(batch=b) for b in batches])
20 |
21 | def _build_df(self):
22 | results_df_list = []
23 |
24 | for path in self._get_paths(self._root):
25 | results_df_list.append(pd.read_csv(path))
26 |
27 | results_df = pd.concat(results_df_list)
28 | results_df["total_time"] = results_df["forward_time"] + results_df["backward_time"]
29 |
30 | return results_df
31 |
32 | def _get_paths(self, root):
33 | return glob.glob(f"{root}/*")
34 |
35 | def _query_results(self, **kwargs):
36 | query = True
37 | for key, value in kwargs.items():
38 | query &= self._results_df[key] == value
39 |
40 | if len(kwargs) > 0:
41 | return self._results_df[query]
42 |
43 | return self._results_df
44 |
45 | def get_speedup(self):
46 | results_df = self._build_df()
47 | standard_times = results_df[results_df["method"] == "standard"].set_index(["t_len", "units", "batch", "abs_refac", "layers"])[["forward_time", "backward_time", "total_time"]]
48 | blocks_times = results_df[results_df["method"] == "blocks"].set_index(["t_len", "units", "batch", "abs_refac", "layers"])[["forward_time", "backward_time", "total_time"]]
49 |
50 | speedup_df = standard_times / blocks_times
51 | speedup_df.rename(columns={"forward_time": "forward_speedup", "backward_time": "backward_speedup", "total_time": "total_speedup"}, inplace=True)
52 |
53 | return speedup_df
54 |
55 |
56 | class SupervisedQuery:
57 |
58 | def __init__(self, root):
59 | self.root = root
60 |
61 | def get_average_duration_per_batch(self, models_root, model_id):
62 | durations_list = []
63 |
64 | duration = trainer.load_log(models_root, model_id)["duration"][1:].mean()
65 | durations_list.append({"model_id": model_id, "duration": duration})
66 |
67 | return pd.DataFrame(durations_list).set_index("model_id").values[0][0]
68 |
69 | def build_results(self, dataset, methods, sgs, abs_refacs, repeats, detach=True, batch_size=500):
70 | results_list = []
71 |
72 | for method in methods:
73 | for sg in sgs:
74 | for abs_refac in abs_refacs:
75 | for i in range(repeats):
76 | if detach:
77 | name = f"{dataset.name}_{method}_{sg}_{abs_refac}_{dataset.dt}_{i}"
78 | else:
79 | name = f"{dataset.name}_{method}_{sg}_{abs_refac}_{dataset.dt}_{detach}_{i}"
80 | print(f"Loading {name}...")
81 | model = train.Trainer.load_model(f"{self.root}/results/supervised", name)
82 | val_acc = train.Trainer.get_acc(model, dataset, batch_size)
83 | avg_time = self.get_average_duration_per_batch(f"{self.root}/results/supervised", name)
84 | results_list.append({"dataset": dataset.name, "method": method, "sg": sg, "abs_refac": abs_refac, "dt": dataset.dt, "i": i, "val_acc": val_acc, "avg_time": avg_time})
85 |
86 | return pd.DataFrame(results_list)
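`BenchmarkQuery.get_speedup` divides two DataFrames that are indexed by configuration. pandas aligns rows on that index before dividing. Below is a minimal dictionary-based sketch of the same computation for total time. It assumes one row per configuration and method. `speedup` and `total_time` are hypothetical helpers, and plain dicts stand in for CSV rows.

```python
import math

# Index columns used by BenchmarkQuery.get_speedup
KEY = ("t_len", "units", "batch", "abs_refac", "layers")


def total_time(row):
    # Same definition as _build_df: forward + backward time
    return row["forward_time"] + row["backward_time"]


def speedup(standard_rows, blocks_rows):
    """Total-time speedup (standard / blocks) per configuration.

    Configurations present on only one side map to NaN, which is what
    pandas index alignment produces for the DataFrame division.
    """
    std = {tuple(r[k] for k in KEY): total_time(r) for r in standard_rows}
    blk = {tuple(r[k] for k in KEY): total_time(r) for r in blocks_rows}
    return {
        key: std[key] / blk[key] if key in std and key in blk else math.nan
        for key in sorted(set(std) | set(blk))
    }
```

A ratio above 1 means the Blocks implementation was faster for that configuration.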
--------------------------------------------------------------------------------
/src/snn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/src/snn/__init__.py
--------------------------------------------------------------------------------
/src/snn/block/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/src/snn/block/__init__.py
--------------------------------------------------------------------------------
/src/snn/block/block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from src.snn.block.util import bconv1d
6 | from src.snn import surrogate
7 |
8 |
9 | class Block(nn.Module):
10 |
11 | def __init__(self, n_in, t_len, surr_grad):
12 | super().__init__()
13 | self._n_in = n_in
14 | self._t_len = t_len
15 | self._surr_grad = surr_grad
16 |
17 | self._beta_ident_base = nn.Parameter(torch.ones(n_in, t_len), requires_grad=False)
18 | self._beta_exp = nn.Parameter(torch.arange(t_len).flip(0).unsqueeze(0).expand(n_in, t_len).float(), requires_grad=False)
19 | self._phi_kernel = nn.Parameter((torch.arange(t_len) + 1).flip(0).float().view(1, 1, 1, t_len), requires_grad=False)
20 |
21 | @staticmethod
22 | def g(faulty_spikes):
23 | negate_faulty_spikes = faulty_spikes.clone().detach()
24 | negate_faulty_spikes[faulty_spikes == 1.0] = 0
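25 | faulty_spikes -= negate_faulty_spikes # Zero every entry that is not exactly 1; the subtracted copy is detached, so gradients still flow through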
26 |
27 | return faulty_spikes
28 |
29 | def forward(self, current, beta, v_init=None, v_th=1, mode="train"):
30 |
31 | if v_init is not None:
32 | current[:, :, 0] += beta * v_init
33 |
34 | pad_current = F.pad(current, pad=(self._t_len - 1, 0)).unsqueeze(1)
35 |
36 | # compute membrane potential without reset
37 | beta_kernel = self.build_beta_kernel(beta)
38 | membrane = bconv1d(pad_current, beta_kernel)
39 |
40 | # map no-reset membrane potentials to output spikes
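41 | v_th = (v_th if torch.is_tensor(v_th) else torch.full_like(current, v_th)).unsqueeze(1) # Also accept the scalar default v_th=1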
42 | faulty_spikes = surrogate.spike(membrane - v_th, self._surr_grad)
43 |
44 | pad_spikes = F.pad(faulty_spikes, pad=(self._t_len - 1, 0))
45 | z = F.conv2d(pad_spikes, self._phi_kernel)
46 | z_copy = z.clone().squeeze(1)
47 |
48 | if mode == "train":
49 | return Block.g(z).squeeze(1), z_copy, membrane.squeeze(1)
50 | elif mode == "val":
51 | return Block.g(z).squeeze(1), z_copy, faulty_spikes, membrane.squeeze(1)
52 |
53 | def build_beta_kernel(self, beta):
54 | beta_base = beta.unsqueeze(1).multiply(self._beta_ident_base)
55 | return torch.pow(beta_base, self._beta_exp).unsqueeze(1).unsqueeze(1)
56 |
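Below is a pure-Python sketch of what `Block.forward` computes for a single neuron. It assumes a scalar threshold and uses lists instead of tensors. `block_forward` and `sequential_spikes` are hypothetical helpers.

The no-reset membrane is the input convolved with `[beta**(T-1), ..., beta, 1]`, which is the kernel built by `build_beta_kernel`. Convolving the threshold crossings with `_phi_kernel = [T, ..., 2, 1]` gives `z[t] = sum_{j<=t} (t - j + 1) * s[j]`. This equals 1 only at the first crossing, so `Block.g` keeps exactly that spike.

```python
def block_forward(current, beta, v_init=0.0, v_th=1.0):
    current = list(current)
    current[0] += beta * v_init  # carry over the membrane from the previous block
    n = len(current)
    # Membrane potential without reset: v[t] = sum_{j<=t} beta**(t - j) * I[j]
    membrane = [sum(beta ** (t - j) * current[j] for j in range(t + 1)) for t in range(n)]
    # Threshold crossings of the no-reset membrane ("faulty" after the first one)
    faulty = [1.0 if v - v_th > 0 else 0.0 for v in membrane]
    # z is 0 before the first crossing, exactly 1 at it, and >= 2 afterwards
    z = [sum((t - j + 1) * faulty[j] for j in range(t + 1)) for t in range(n)]
    spikes = [1.0 if zt == 1.0 else 0.0 for zt in z]  # what Block.g keeps
    return spikes, z, membrane


def sequential_spikes(current, beta, v_th=1.0):
    """Reference LIF with reset to zero, stepped one time step at a time."""
    v, out = 0.0, []
    for inp in current:
        v = beta * v + inp
        s = 1.0 if v - v_th > 0 else 0.0
        v *= 1.0 - s  # reset after a spike
        out.append(s)
    return out
```

Within one block, the two functions agree only up to the first spike. After that spike the no-reset membrane keeps integrating, which is why `Blocks` ties the block length to the refractory period.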
--------------------------------------------------------------------------------
/src/snn/block/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from src.snn.snn import BaseSNN
7 | from src.snn.block.block import Block
8 | from src.snn.block.util import time_cat, bconv1d
9 |
10 |
11 | class Blocks(BaseSNN):
12 |
13 | def __init__(self, n_in, n_out, rf_len, t_len, t_latency, recurrent=True, beta_grad=True, adapt=True, init_beta=1, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid"):
14 | super().__init__(n_in, n_out, rf_len, t_len, t_latency, recurrent, beta_grad, adapt, init_beta, init_p, detach_spike_grad, surr_grad)
15 |
16 | self._t_len_block = t_latency + 1
17 | self._block = Block(n_out, self._t_len_block, surr_grad)
18 | self._n_blocks = math.ceil(t_len / self._t_len_block)
19 | self._t_pad = self._n_blocks * self._t_len_block - self._t_len
20 |
21 | self._p_ident_base = nn.Parameter(torch.ones(n_out, self._t_len_block), requires_grad=False)
22 | self._p_exp = nn.Parameter(torch.arange(1, self._t_len_block + 1).float(), requires_grad=False)
23 |
24 | def process(self, x, mode="train"):
25 | x_init = x
26 | if self._t_pad != 0:
27 | x = F.pad(x, pad=(0, self._t_pad))
28 |
29 | mem_list = []
30 | spikes_list = []
31 | z_list = []
32 |
33 | z = torch.zeros_like(x[:, :, self._t_len_block:])
34 | v_init = torch.zeros_like(x[:, :, 0]).to(x.device)
35 | int_mem = torch.zeros_like(x[:, :, 0]).to(x.device)
36 |
37 | a_kernel = torch.zeros_like(x).to(x.device)[:, :, :self._t_len_block]
38 | v_th = torch.ones_like(x).to(x.device)[:, :, :self._t_len_block]
39 | v_th_list = []
40 |
41 | for i in range(self._n_blocks):
42 | x_slice = x[:, :, i * self._t_len_block: (i+1) * self._t_len_block]
43 |
44 | # Recurrent current and refractory mask only included after first block
45 | if i > 0:
46 | # Add recurrent current to input
47 | if self._recurrent:
48 | rec_current = self.get_rec_input(spikes)
49 | x_slice = x_slice + rec_current
50 |
51 | # Apply refractory mask to input
52 | if self._detach_spike_grad:
53 | spike_mask = spikes.detach().amax(dim=2).bool()
54 | else:
55 | spike_mask = spikes.amax(dim=2).bool()
56 | refac_mask = (z < spike_mask.unsqueeze(2)) * x_slice
57 | x_slice = x_slice - refac_mask # Out-of-place, since x_slice may be a view of x
58 |
59 | # Set initial membrane potentials
60 | v_init = int_mem[:, :, -1] * ~spike_mask # if spiked -> zero initial membrane potential
61 |
62 | # Set initial adaptive params
63 | if self._adapt:
64 | # a at the time of the spike plus the spike's increment of 1. These quantities are stored divided by p
65 | # (a_kernel's exponents start at 1), hence the + 1 / p; see the paper for the full derivation
66 | if self._detach_spike_grad:
67 | a_at_spike = (a_kernel * spikes.detach()).sum(dim=2) + (1 / self.p)
68 | else:
69 | a_at_spike = (a_kernel * spikes).sum(dim=2) + (1 / self.p)
70 | decay_steps = (z > 1).sum(dim=2) # Number of steps after the spike within the previous block
71 | new_a = a_at_spike * torch.pow(self.p.unsqueeze(0), decay_steps)
72 | a = (a_kernel[:, :, -1] * ~spike_mask) + (new_a * spike_mask)
73 |
74 | # Rebuild the adaptation kernel and thresholds for this block from the updated a
75 | a_kernel = self.compute_a_kernel(a, self.p)
76 | v_th = 1 + self.b.view(1, -1, 1) * a_kernel
77 |
78 | if mode == "train":
79 | spikes, z, int_mem = self._block(x_slice, self.beta, v_init=v_init, v_th=v_th, mode="train")
80 | spikes_list.append(spikes)
81 | elif mode == "val":
82 | spikes, z, _, int_mem = self._block(x_slice, self.beta, v_init=v_init, v_th=v_th, mode="val")
83 | spikes_list.append(spikes)
84 | mem_list.append(int_mem)
85 | z_list.append(z)
86 | v_th_list.append(v_th)
87 |
88 | if mode == "train":
89 | return time_cat(spikes_list, self._t_pad)
90 | elif mode == "val":
91 | return time_cat(spikes_list, self._t_pad), time_cat(mem_list, self._t_pad), x_init, time_cat(z_list, self._t_pad), time_cat(v_th_list, self._t_pad)
92 |
93 | def compute_a_kernel(self, a, p):
94 | # a: b x n
95 | # p: n
96 | # output: b x n x t
97 |
98 | return torch.pow(p.unsqueeze(-1) * self._p_ident_base, self._p_exp).unsqueeze(0) * a.unsqueeze(-1)
99 |
100 |
101 | class BlocksIntegrator(BaseSNN):
102 |
103 | def __init__(self, n_in, n_out, t_len, init_beta=1):
104 | super().__init__(n_in, n_out, 1, t_len, t_latency=0, recurrent=False, beta_grad=True, adapt=False, init_beta=init_beta, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid")
105 | self._block = Block(n_out, t_len, "fast_sigmoid")
106 |
107 | def process(self, x, mode="train"):
108 | pad_current = F.pad(x, pad=(self._t_len - 1, 0)).unsqueeze(1)
109 |
110 | # compute membrane potential without reset
111 | beta_kernel = self._block.build_beta_kernel(self.beta)
112 | membrane = bconv1d(pad_current, beta_kernel)
113 |
114 | return membrane.squeeze(1)
115 |
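Below is a dependency-free sketch of the block decomposition in `Blocks.process` for one non-recurrent, non-adaptive neuron. The adaptive-threshold and recurrent paths are left out.

The sketch assumes blocks of length `t_latency + 1`, as in `__init__`, so at most one spike can fall inside a block. Three rules carry state from one block to the next:
- the refractory mask zeroes the steps of the next block that come before the previous spike's position;
- `v_init` carries the membrane forward only when the previous block was silent;
- a stepwise loop that follows `SNN.process` serves as the reference.

All function names here are hypothetical.

```python
import math
import random


def block_layout(t_len, t_latency):
    """Block length, number of blocks and tail padding, as in Blocks.__init__."""
    t_len_block = t_latency + 1
    n_blocks = math.ceil(t_len / t_len_block)
    t_pad = n_blocks * t_len_block - t_len
    return t_len_block, n_blocks, t_pad


def block_step(current, beta, v_init, v_th=1.0):
    """One block: integrate without reset, keep only the first threshold crossing."""
    v, spikes, membrane = v_init, [], []
    for inp in current:
        v = beta * v + inp  # the first step gives beta * v_init + I[0]
        membrane.append(v)
        first = v - v_th > 0 and 1.0 not in spikes
        spikes.append(1.0 if first else 0.0)
    return spikes, membrane


def blocks_spikes(current, beta, t_latency, v_th=1.0):
    """Block-wise simulation of one non-recurrent, non-adaptive neuron."""
    t_len_block, n_blocks, t_pad = block_layout(len(current), t_latency)
    x = list(current) + [0.0] * t_pad
    out, v_init, prev_pos = [], 0.0, None
    for i in range(n_blocks):
        x_slice = x[i * t_len_block:(i + 1) * t_len_block]
        if prev_pos is not None:
            # Steps before the previous spike's position are still refractory
            x_slice = [0.0 if k < prev_pos else xk for k, xk in enumerate(x_slice)]
        spikes, membrane = block_step(x_slice, beta, v_init, v_th)
        out.extend(spikes)
        prev_pos = spikes.index(1.0) if 1.0 in spikes else None
        # After a spike the membrane restarts from zero; otherwise carry it over
        v_init = 0.0 if prev_pos is not None else membrane[-1]
    return out[:len(current)]


def stepwise_spikes(current, beta, t_latency, v_th=1.0):
    """Step-by-step reference following SNN.process (reset to zero, absolute refractory period)."""
    v, spike, refac, out = 0.0, 0.0, t_latency, []
    for inp in current:
        if spike > 0:
            refac = 0  # previous step spiked: start the refractory period
        if refac < t_latency:
            inp = 0.0  # input blocked while refractory
        refac += 1
        v = beta * v + inp
        spike = 1.0 if v - v_th > 0 else 0.0
        v *= 1.0 - spike  # reset after a spike
        out.append(spike)
    return out


def check_equivalence(t_latencies=(0, 1, 2, 4, 8), trials=20, t_len=50, beta=0.9, seed=0):
    """Compare both simulations on random input currents, in the spirit of tests/test_blocks.py."""
    rng = random.Random(seed)
    for t_latency in t_latencies:
        for _ in range(trials):
            current = [rng.uniform(0.0, 0.8) for _ in range(t_len)]
            if blocks_spikes(current, beta, t_latency) != stepwise_spikes(current, beta, t_latency):
                return False
    return True
```

The sketch integrates each block with a loop. The real `Block` computes the same no-reset membrane for all steps of a block at once with a convolution, and that is where the speedup comes from.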
--------------------------------------------------------------------------------
/src/snn/block/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def bconv1d(x, weight, stride=1, dilation=1, padding=0):
6 | # Per-neuron (batched) 1D convolution via a grouped conv1d; only valid for c == 1 (one input channel per neuron), as used in this repo
7 | b, c, n, h = x.shape
8 | n, out_channels, in_channels, kernel_width_size = weight.shape
9 |
10 | out = x.view(b, c * n, h)
11 | weight = weight.view(n * out_channels, in_channels, kernel_width_size)
12 |
13 | out = F.conv1d(out, weight=weight, bias=None, stride=stride, dilation=dilation, groups=n, padding=padding)
14 |
15 | return out.view(b, c, n, -1)
16 |
17 |
18 | def time_cat(tensor_list, t_pad):
19 | tensor = torch.cat(tensor_list, dim=2)
20 |
21 | if t_pad > 0:
22 | tensor = tensor[:, :, :-t_pad]
23 |
24 | return tensor
25 |
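`bconv1d` gives every neuron its own kernel by folding neurons into channels and calling `F.conv1d` with `groups=n`, so kernels never mix neurons. `time_cat` stitches the per-block outputs back together and drops the padded tail. The following is a loop-based sketch of both with hypothetical names. It assumes one input channel per neuron and unpadded ("valid") cross-correlation, which is what `F.conv1d` computes with `padding=0`.

```python
def per_neuron_conv1d(x, kernels):
    """Valid cross-correlation of each neuron's trace with its own kernel.

    x: one input sequence per neuron; kernels: one kernel per neuron.
    """
    out = []
    for seq, k in zip(x, kernels):
        width = len(k)
        out.append([sum(k[j] * seq[t + j] for j in range(width))
                    for t in range(len(seq) - width + 1)])
    return out


def time_cat(blocks, t_pad):
    """Concatenate per-block outputs along time and drop the padded tail."""
    seq = [v for block in blocks for v in block]
    return seq[:len(seq) - t_pad] if t_pad > 0 else seq
```

With a left padding of `T - 1` zeros and the kernel `[beta**2, beta, 1]` (for `T = 3`), the first neuron below gets the leaky integral of a constant input, as in `Block.forward`.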
--------------------------------------------------------------------------------
/src/snn/snn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from brainbox.models import BBModel
6 |
7 | from src.snn import surrogate
8 |
9 | # SNN control models.
10 |
11 |
12 | class BaseSNN(BBModel):
13 |
14 | MIN_BETA = 0.001
15 | MAX_BETA = 0.999
16 |
17 | def __init__(self, n_in, n_out, rf_len, t_len, t_latency, recurrent=True, beta_grad=True, adapt=True, init_beta=1, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid"):
18 | super().__init__()
19 | self._n_in = n_in
20 | self._n_out = n_out
21 | self._rf_len = rf_len
22 | self._t_len = t_len
23 | self._t_latency = t_latency
24 | self._recurrent = recurrent
25 | self._beta_grad = beta_grad
26 | self._adapt = adapt
27 | self._detach_spike_grad = detach_spike_grad
28 | self._surr_grad = surr_grad
29 |
30 | self._beta = nn.Parameter(data=torch.Tensor(n_out * [init_beta]), requires_grad=beta_grad)
31 | self._rf_weight = nn.Parameter(torch.rand(n_out, 1, n_in, self._rf_len), requires_grad=True)
32 | self._rf_bias = nn.Parameter(torch.zeros(n_out), requires_grad=True)
33 |
34 | self._rec_weight = nn.Parameter(torch.rand(n_out, n_out), requires_grad=recurrent)
35 |
36 | self._p = nn.Parameter(data=torch.Tensor(n_out * [init_p]), requires_grad=adapt)
37 | self._b = nn.Parameter(data=torch.Tensor(n_out * [1.8]), requires_grad=adapt)
38 |
39 | self.init_weight(self._rf_weight, "uniform", a=-1 / np.sqrt(n_in * rf_len), b=1 / np.sqrt(n_in * rf_len))
40 | self.init_weight(self._rec_weight, "identity")
41 |
42 | @property
43 | def hyperparams(self):
44 | return {**super().hyperparams, "n_in": self._n_in, "n_out": self._n_out, "rf_len": self._rf_len, "t_len": self._t_len, "t_latency": self._t_latency, "recurrent": self._recurrent, "beta_grad": self._beta_grad, "adapt": self._adapt, "detach_spike_grad": self._detach_spike_grad, "surr_grad": self._surr_grad}
45 |
46 | @property
47 | def p(self):
48 | return torch.clamp(self._p.abs(), min=0, max=0.999)
49 |
50 | @property
51 | def b(self):
52 | return torch.clamp(self._b.abs(), min=0.001, max=1)
53 |
54 | @property
55 | def beta(self):
56 | return torch.clamp(self._beta, min=BaseSNN.MIN_BETA, max=BaseSNN.MAX_BETA)
57 |
58 | @property
59 | def rec_weight(self):
60 | return self._rec_weight
61 |
62 | def get_rec_input(self, spikes):
63 | return torch.einsum("ij, bj...->bi...", self.rec_weight, spikes.detach() if self._detach_spike_grad else spikes)
64 |
65 | def forward(self, x, mode="train"):
66 | # x: b x n x t
67 |
68 | x = F.pad(x, (self._rf_len - 1, 0))
69 | x = x.unsqueeze(1) # Add channel dim
70 | x = F.conv2d(x, self._rf_weight, self._rf_bias)[:, :, 0] # Slice out height dim
71 |
72 | return self.process(x, mode)
73 |
74 | def process(self, x, mode):
75 | raise NotImplementedError
76 |
77 |
78 | class SNN(BaseSNN):
79 |
80 | def __init__(self, n_in, n_out, rf_len, t_len, t_latency, recurrent=False, beta_grad=True, adapt=True, init_beta=1, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid"):
81 | super().__init__(n_in, n_out, rf_len, t_len, t_latency, recurrent, beta_grad, adapt, init_beta, init_p, detach_spike_grad, surr_grad)
82 |
83 | def process(self, x, mode="train"):
84 | # x: b x n x t
85 |
86 | mem_list = []
87 | spikes_list = []
88 | spikes = torch.zeros_like(x).to(x.device)[:, :, 0]
89 | rec_current = torch.zeros_like(x)
90 | mem = torch.zeros_like(x).to(x.device)[:, :, 0]
91 | refac_times = torch.zeros_like(x).to(x.device)[:, :, 0] + self._t_latency
92 |
93 | v_th = torch.ones_like(x).to(x.device)[:, :, 0]
94 | a = torch.zeros_like(x).to(x.device)[:, :, 0]
95 | v_th_list = []
96 |
97 | for t in range(x.shape[2]):
98 | stimulus_current = x[:, :, t]
99 | rec_current[:, :, t] = self.get_rec_input(spikes)
100 |
101 | # Recurrent latency
102 | if t >= self._t_latency and self._recurrent:
103 | input_current = stimulus_current + rec_current[:, :, t-self._t_latency]
104 | else:
105 | input_current = stimulus_current
106 |
107 | # Apply absolute refractory period
108 | refac_times[spikes > 0] = 0
109 | refac_mask = refac_times < self._t_latency
110 | input_current = input_current.masked_fill(refac_mask, 0) # Out-of-place, so x itself is not modified
111 | refac_times += 1
112 |
113 | new_mem = torch.einsum("bn...,n->bn...", mem, self.beta) + input_current
114 | spikes = surrogate.spike(new_mem - v_th, self._surr_grad)
115 |
116 | mem_list.append(new_mem)
117 | if self._detach_spike_grad:
118 | mem = new_mem * (1 - spikes.detach())
119 | else:
120 | mem = new_mem * (1 - spikes)
121 | # Note: an in-place `new_mem -= new_mem * spikes` would also change the entry just appended to mem_list
122 | spikes_list.append(spikes)
123 |
124 | if self._adapt:
125 | a = self.p * a + spikes
126 | v_th = 1 + self.b * a
127 | v_th_list.append(v_th)
128 |
129 | if mode == "train":
130 | return torch.stack(spikes_list, dim=2)
131 | elif mode == "val":
132 | v_th = torch.stack(v_th_list, dim=2)
133 | v_th = torch.roll(v_th, 1, dims=2)
134 | v_th[:, :, :1] = 1
135 |
136 | return torch.stack(spikes_list, dim=2), torch.stack(mem_list, dim=2), x, v_th
137 |
138 |
139 | class SNNIntegrator(SNN):
140 |
141 | def __init__(self, n_in, n_out, t_len, init_beta=1):
142 | super().__init__(n_in, n_out, 1, t_len, t_latency=0, recurrent=False, beta_grad=True, adapt=False, init_beta=init_beta, init_p=1, detach_spike_grad=True, surr_grad="fast_sigmoid")
143 |
144 | def process(self, x, mode="train"):
145 | mem_list = []
146 | mem = torch.zeros_like(x).to(x.device)[:, :, 0]
147 |
148 | for t in range(x.shape[2]):
149 | input_current = x[:, :, t]
150 |
151 | new_mem = torch.einsum("bn...,n->bn...", mem, self.beta) + input_current
152 | mem_list.append(new_mem)
153 | mem = new_mem
154 |
155 | return torch.stack(mem_list, dim=2)
156 |
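The following is a scalar sketch of `SNN.process` for one non-recurrent neuron, including the adaptive threshold. It assumes `p` and `b` are already in their clamped ranges: the `b` property above caps `b` at 1, so the 1.8 initial value is effectively 1. `snn_reference` is a hypothetical name. It also records the threshold that was actually applied at each step, which is what `mode="val"` reconstructs by rolling `v_th_list` by one step.

```python
def snn_reference(current, beta, t_latency=0, p=None, b=0.0):
    """Single-neuron version of SNN.process; adaptation is on when p is given."""
    v, a, v_th = 0.0, 0.0, 1.0
    spike, refac = 0.0, t_latency
    spikes, thresholds = [], []
    for inp in current:
        if spike > 0:
            refac = 0  # a spike on the previous step starts the refractory period
        if refac < t_latency:
            inp = 0.0  # absolute refractory period: input is blocked
        refac += 1
        v = beta * v + inp
        thresholds.append(v_th)  # threshold used at this step
        spike = 1.0 if v - v_th > 0 else 0.0
        v *= 1.0 - spike  # reset to zero after a spike
        spikes.append(spike)
        if p is not None:
            a = p * a + spike  # adaptation decays by p and jumps by 1 per spike
            v_th = 1.0 + b * a  # threshold for the next step
    return spikes, thresholds
```

With a constant input of 0.6 and `beta = 0.5`, the neuron fires every third step. With `p = 0.9` and `b = 0.5`, the raised threshold (1.5, then 1.45, 1.405, ...) suppresses the second spike.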
--------------------------------------------------------------------------------
/src/snn/surrogate.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 |
6 | class FastSigmoid(torch.autograd.Function):
7 |
8 | @staticmethod
9 | def forward(ctx, input, scale=10):
10 | ctx.scale = scale
11 | ctx.save_for_backward(input)
12 |
13 | return input.gt(0).float()
14 |
15 | @staticmethod
16 | def backward(ctx, grad_output):
17 | input, = ctx.saved_tensors
18 | grad_input = grad_output.clone()
19 | grad = grad_input / (ctx.scale * torch.abs(input) + 1.0) ** 2
20 |
21 | return grad, None
22 |
23 |
24 | class BoxCar(torch.autograd.Function):
25 |
26 | @staticmethod
27 | def forward(ctx, input):
28 | ctx.save_for_backward(input)
29 |
30 | return input.gt(0).float()
31 |
32 | @staticmethod
33 | def backward(ctx, grad_output):
34 | input, = ctx.saved_tensors
35 | grad = grad_output.clone()
36 | grad[input <= -0.5] = 0
37 | grad[input > 0.5] = 0
38 |
39 | return grad
40 |
41 |
42 | class MG(torch.autograd.Function):
43 |
44 | @staticmethod
45 | def forward(ctx, input):
46 | ctx.save_for_backward(input)
47 | return input.gt(0).float()
48 |
49 | @staticmethod
50 | def backward(ctx, grad_output):
51 | input, = ctx.saved_tensors
52 | grad = grad_output.clone()
53 | lens = 0.5
54 | height = 0.15
55 | scale = 6
56 | gamma = 0.5
57 |
58 | temp = MG.gaussian(input, mu=0., sigma=lens) * (1. + height) - MG.gaussian(input, mu=lens, sigma=scale * lens) * height - MG.gaussian(input, mu=-lens, sigma=scale * lens) * height
59 |
60 | return gamma * grad * temp.float()
61 |
62 | @staticmethod
63 | def gaussian(x, mu=0., sigma=.5):
64 | return torch.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / math.sqrt(2 * math.pi) / sigma
65 |
66 |
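67 | def spike(x, surr_grad):
68 | # Map the surrogate name to its autograd Function; fail loudly on unknown names
69 | fns = {"fast_sigmoid": FastSigmoid, "box_car": BoxCar, "mg": MG}
70 | if surr_grad not in fns:
71 | raise ValueError(f"Unknown surrogate gradient: {surr_grad!r}")
72 | return fns[surr_grad].apply(x)
73 |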
74 |
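All three surrogates share the same forward pass, the Heaviside step `input > 0`. They differ only in the pseudo-derivative returned by `backward`. Below is a scalar sketch of those derivatives with hypothetical function names. The MG constants (`lens`, `height`, `scale`, `gamma`) are copied from `MG.backward`. Reading MG as a multi-Gaussian surrogate, a positive central Gaussian with two wider negative side lobes, is our interpretation of the formula.

```python
import math


def heaviside(x):
    return 1.0 if x > 0 else 0.0  # forward pass of every surrogate


def fast_sigmoid_grad(x, scale=10.0):
    # Derivative of x / (1 + scale * |x|)
    return 1.0 / (scale * abs(x) + 1.0) ** 2


def box_car_grad(x):
    # Gradient passes only on the window (-0.5, 0.5], as in BoxCar.backward
    return 1.0 if -0.5 < x <= 0.5 else 0.0


def _gaussian(x, mu, sigma):
    return math.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / (math.sqrt(2 * math.pi) * sigma)


def mg_grad(x, lens=0.5, height=0.15, scale=6.0, gamma=0.5):
    # Central Gaussian minus two wider Gaussians centred at +/- lens
    return gamma * (_gaussian(x, 0.0, lens) * (1.0 + height)
                    - _gaussian(x, lens, scale * lens) * height
                    - _gaussian(x, -lens, scale * lens) * height)
```

The MG pseudo-derivative turns slightly negative away from the threshold (for example at `x = 1.5`), whereas the fast-sigmoid and box-car versions never do.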
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import logging
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | import numpy as np
9 | import pandas as pd
10 | from brainbox import trainer
11 | from brainbox.physiology.spiking import VanRossum
12 |
13 | from src import datasets, models
14 |
15 |
16 | torch.backends.cudnn.benchmark = True
17 | logger = logging.getLogger("trainer")
18 | logging.basicConfig(stream=sys.stdout, level=logging.INFO)
19 |
20 |
21 | class Trainer(trainer.Trainer):
22 |
23 | def __init__(self, root, model, dataset, n_epochs, batch_size, lr, milestones=[-1], gamma=0.1, val_dataset=None, device="cuda", id=None):
24 | super().__init__(root, model, dataset, n_epochs, batch_size, lr, torch.optim.Adam, device=device, optimizer_kwargs={"eps": 1e-5}, loader_kwargs={"shuffle": True, "pin_memory": True, "num_workers": 16}, id=id)
25 | self._milestones = milestones
26 | self._gamma = gamma
27 | self._val_dataset = val_dataset
28 |
29 | self._times = {"forward_pass": [], "backward_pass": []}
30 | self._train_acc = []
31 | self._val_acc = []
32 | self._min_loss = np.inf
33 | self._milestone_idx = 0
34 |
35 | @staticmethod
36 | def accuracy_metric(output, target):
37 | _, predictions = torch.max(output, 1)
38 | return (predictions == target).sum().cpu().item()
39 |
40 | @staticmethod
41 | def spike_count(output, target):
42 | _, cortical_output, thalamic_output = output
43 |
44 | count = cortical_output[0].sum().cpu().item()
45 | count += thalamic_output[0].sum().cpu().item()
46 |
47 | return count
48 |
49 | @property
50 | def times_path(self):
51 | return os.path.join(self.root, self.id, "times.csv")
52 |
53 | @property
54 | def train_acc_path(self):
55 | return os.path.join(self.root, self.id, "train_acc.csv")
56 |
57 | @property
58 | def val_acc_path(self):
59 | return os.path.join(self.root, self.id, "val_acc.csv")
60 |
61 | def save_model_log(self):
62 | super().save_model_log()
63 |
64 | # Save times
65 | times_df = pd.DataFrame(self._times)
66 | times_df.to_csv(self.times_path, index=False)
67 |
68 | # Save acc
69 | train_acc_df = pd.DataFrame(self._train_acc)
70 | train_acc_df.to_csv(self.train_acc_path, index=False)
71 | val_acc_df = pd.DataFrame(self._val_acc)
72 | val_acc_df.to_csv(self.val_acc_path, index=False)
73 |
74 | def loss(self, output, target, model):
75 | target = target.long()
76 | loss = F.cross_entropy(output, target, reduction="mean")
77 |
78 | return loss
79 |
80 | def train_for_single_epoch(self):
81 | epoch_loss = 0
82 | n_samples = 0
83 | n_correct = 0
84 |
85 | for batch_id, (data, target) in enumerate(self.train_data_loader):
86 | data = data.to(self.device).type(self.dtype)
87 | target = target.to(self.device).type(self.dtype)
88 | torch.cuda.synchronize()
89 |
90 | # Forward pass
91 | start_time = time.time()
92 | output = self.model(data)
93 | torch.cuda.synchronize()
94 | forward_pass_time = time.time() - start_time
95 | self._times["forward_pass"].append(forward_pass_time)
96 |
97 | # Compute accuracy
98 | _, predictions = torch.max(output, 1)
99 | n_correct += (predictions == target).sum().cpu().item()
100 |
101 | # Compute loss
102 | loss = self.loss(output, target, self.model)
103 |
104 | # Backward pass
105 | start_time = time.time()
106 | loss.backward()
107 | torch.cuda.synchronize()
108 | backward_pass_time = time.time() - start_time
109 | self._times["backward_pass"].append(backward_pass_time)
110 |
111 | self.optimizer.step()
112 | self.optimizer.zero_grad()
113 |
114 | with torch.no_grad():
115 | epoch_loss += (loss.item() * data.shape[0])
116 | n_samples += data.shape[0]
117 |
118 | train_acc = n_correct/n_samples
119 | logging.info(f"Train acc: {train_acc}")
120 | self._train_acc.append(train_acc)
121 |
122 | if self._val_dataset is not None and len(self.log["train_loss"]) % 5 == 0:
123 | val_acc = Trainer.get_acc(self.model, self._val_dataset, self.batch_size)
124 | logging.info(f"Val acc: {val_acc}")
125 | self._val_acc.append(val_acc)
126 |
127 | return epoch_loss / n_samples
128 |
129 | @staticmethod
130 | def get_acc(model, dataset, batch_size):
131 | scores = trainer.compute_metric(model, dataset, Trainer.accuracy_metric, batch_size=batch_size)
132 | return np.sum(scores) / len(dataset)
133 |
134 | @staticmethod
135 | def get_spike_count(model, dataset, batch_size):
136 | scores = trainer.compute_metric(model, dataset, Trainer.spike_count, batch_size=batch_size)
137 | return np.sum(scores) / len(dataset)
138 |
139 | def on_epoch_complete(self, save):
140 | if save:
141 | self.save_model_log()
142 |
143 | epoch_loss = self.log["train_loss"][-1]
144 | if epoch_loss < self._min_loss:
145 | logging.info(f"Saving model...")
146 | self._min_loss = epoch_loss
147 | self.save_model()
148 |
149 | n_epoch = len(self.log["train_loss"])
150 |
151 | if n_epoch == self._milestones[self._milestone_idx]:
152 | logging.info(f"Decaying lr...")
153 | self.lr *= self._gamma
154 | # Load best model
155 | self.model = Trainer.load_model(self.root, self.id, self.device, self.dtype)
156 | self.optimizer = self.optimizer_func(
157 | self.model.parameters(), self.lr, **self.optimizer_kwargs
158 | )
159 |
160 | if self._milestone_idx != len(self._milestones) - 1:
161 | logging.info(f"New milestone target...")
162 | self._milestone_idx += 1
163 |
164 | def on_training_complete(self, save):
165 | pass
166 |
167 | @staticmethod
168 | def hyperparams_loader(hyperparams):
169 | model_params = hyperparams["model"]
170 | del model_params["name"]
171 | del model_params["weight_initializers"]
172 |
173 | return models.AuditoryModel(**model_params)
174 |
175 | @staticmethod
176 | def load_model(root, id, device="cuda", dtype=torch.float):
177 | return trainer.load_model(root, id, Trainer.hyperparams_loader, device, dtype)
178 |
179 |
180 | class EphysTrainer(Trainer):
181 |
182 | def __init__(self, root, model, dataset, n_epochs, batch_size, lr, gamma=0.1, dt=0.1, epoch_scan=5, max_decay=1, val_dataset=None, device="cuda", id=None):
183 | super().__init__(root, model, dataset, n_epochs, batch_size, lr, [-1], gamma, val_dataset, device, id)
184 | self.epoch_scan = epoch_scan
185 | self.van_rossum = VanRossum(datasets.EphysDataset.LENGTH, tau=100, dt=dt).to(device)
186 | self.max_decay = max_decay
187 |
188 | self._decay_count = 0
189 |
190 | def loss(self, spikes_pred, spikes):
191 | spike_loss = self.van_rossum(spikes_pred, spikes)
192 |
193 | return spike_loss
194 |
195 | def train_for_single_epoch(self):
196 | epoch_loss = 0
197 | n_samples = 0
198 |
199 | for batch_id, (data, target) in enumerate(self.train_data_loader):
200 | data = data.to(self.device).type(self.dtype)
201 | trace = target[0].to(self.device).type(self.dtype)
202 | spikes = target[1].to(self.device).type(self.dtype)
203 | torch.cuda.synchronize()
204 |
205 | # Forward pass
206 | start_time = time.time()
207 | spikes_pred = self.model(data)
208 | torch.cuda.synchronize()
209 | forward_pass_time = time.time() - start_time
210 | self._times["forward_pass"].append(forward_pass_time)
211 |
212 | # Compute loss
213 | loss = self.loss(spikes_pred, spikes)
214 |
215 | # Backward pass
216 | start_time = time.time()
217 | loss.backward()
218 | torch.cuda.synchronize()
219 | backward_pass_time = time.time() - start_time
220 | self._times["backward_pass"].append(backward_pass_time)
221 |
222 | self.optimizer.step()
223 | self.optimizer.zero_grad()
224 |
225 | with torch.no_grad():
226 | epoch_loss += (loss.item() * data.shape[0])
227 | n_samples += data.shape[0]
228 |
229 | return epoch_loss / n_samples
230 |
231 | def on_epoch_complete(self, save):
232 | if save:
233 | self.save_model_log()
234 |
235 | epoch_loss = self.log["train_loss"][-1]
236 | if epoch_loss < self._min_loss:
237 | logging.info(f"Saving model...")
238 | self._min_loss = epoch_loss
239 | self.save_model()
240 |
241 | min_loss_over_last_epochs = np.array(self.log["train_loss"][-self.epoch_scan:]).min()
242 |
243 | if min_loss_over_last_epochs > self._min_loss:
244 | if self._decay_count < self.max_decay:
245 | self._min_loss = np.inf
246 | self._last_train_scores = []
247 | logging.info(f"Decaying lr...")
248 | self._decay_count += 1
249 | self.lr *= self._gamma
250 | # Load best model
251 | self.model = EphysTrainer.load_model(self.root, self.id, self.device, self.dtype)
252 | self.optimizer = self.optimizer_func(
253 | self.model.parameters(), self.lr, **self.optimizer_kwargs
254 | )
255 | else:
256 | self.exit = True
257 |
258 | @staticmethod
259 | def load_model(root, id, device="cuda", dtype=torch.float, dt01ref=False):
260 |
261 | def model_loader(hyperparams):
262 | model_params = hyperparams["model"]
263 | del model_params["name"]
264 | del model_params["weight_initializers"]
265 |
266 | model_params = {**model_params, "dt01ref": dt01ref}
267 |
268 | return models.Neuron(**model_params)
269 |
270 | return trainer.load_model(root, id, model_loader, device, dtype)
271 |
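`EphysTrainer.on_epoch_complete` implements a plateau rule. If none of the last `epoch_scan` epochs reached the best loss so far, it decays the learning rate by `gamma`, up to `max_decay` times, and then stops. Below is a minimal sketch that replays this rule on a list of per-epoch losses. `plateau_schedule` is a hypothetical helper, and it omits checkpointing and reloading of the best model.

```python
import math


def plateau_schedule(losses, lr, gamma=0.1, epoch_scan=5, max_decay=1):
    """Return the final learning rate and a list of (epoch, event) pairs."""
    min_loss = math.inf
    decay_count = 0
    history, events = [], []
    for epoch, loss in enumerate(losses, start=1):
        history.append(loss)
        if loss < min_loss:
            min_loss = loss  # new best: the trainer would save a checkpoint here
        # No epoch in the last `epoch_scan` reached the best loss -> plateau
        if min(history[-epoch_scan:]) > min_loss:
            if decay_count < max_decay:
                min_loss = math.inf  # reset, so the next epoch becomes the new reference
                decay_count += 1
                lr *= gamma
                events.append((epoch, "decay"))
            else:
                events.append((epoch, "stop"))
                break
    return lr, events
```

As in the trainer, `min_loss` is reset to infinity right after a decay. The next epoch therefore counts as a new best and is saved, even if its loss is higher than the pre-decay best.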
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/webstorms/Blocks/540c9f28eedd58ef638dcacf75b4c27cdf52baa0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 |
4 | from src.snn.block.block import Block
5 |
6 |
7 | @pytest.fixture
8 | def block():
9 | return Block(2, 4, "fast_sigmoid")
10 |
11 |
12 | def test_beta_kernel(block):
13 | # Use a single beta
14 | assert torch.allclose(block.build_beta_kernel(torch.Tensor([0.1]))[0, 0, 0], torch.Tensor([0.0010, 0.0100, 0.1000, 1.0000]))
15 | assert torch.allclose(block.build_beta_kernel(torch.Tensor([0.1]))[1, 0, 0], torch.Tensor([0.0010, 0.0100, 0.1000, 1.0000]))
16 |
17 | # Use one beta per neuron
18 | assert torch.allclose(block.build_beta_kernel(torch.Tensor([0.1, 0.5]))[0, 0, 0], torch.Tensor([0.0010, 0.0100, 0.1000, 1.0000]))
19 | assert torch.allclose(block.build_beta_kernel(torch.Tensor([0.1, 0.5]))[1, 0, 0], torch.Tensor([0.1250, 0.2500, 0.5000, 1.0000]))
20 |
21 |
22 | def test_phi_kernel(block):
23 | assert torch.allclose(block._phi_kernel, torch.Tensor([[[[4., 3., 2., 1.]]]]))
24 |
25 |
26 | def test_g(block):
27 | phi_spikes = torch.zeros(2, 4)
28 | phi_spikes[0, 0] = 1
29 | phi_spikes[0, 2] = 2
30 | phi_spikes[1, 1] = 1
31 | phi_spikes[1, 3] = 3
32 | assert block.g(phi_spikes).sum() == 2
33 |
34 |
35 | def test_differentiable_vars(block):
36 | assert not block._beta_ident_base.requires_grad
37 | assert not block._beta_exp.requires_grad
38 | assert not block._phi_kernel.requires_grad
--------------------------------------------------------------------------------
/tests/test_blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 |
4 | from src.snn.snn import SNN
5 | from src.snn.block.blocks import Blocks
6 |
7 |
8 | @pytest.fixture
9 | def in_spikes(b=4, n=100, t=200):
10 | return torch.rand(b, n, t)
11 |
12 |
13 | def get_models(n_in=100, n_out=100, rf_len=10, t_len=200, t_latency=0, recurrent=True, adapt=False):
14 | blocks = Blocks(n_in, n_out, rf_len, t_len, t_latency, recurrent=recurrent, adapt=adapt)
15 | snn = SNN(n_in, n_out, rf_len, t_len, t_latency, recurrent=recurrent, adapt=adapt)
16 |
17 | blocks._rf_weight = snn._rf_weight
18 | blocks._rf_bias = snn._rf_bias
19 | blocks._rec_weight = snn._rec_weight
20 |
21 | return blocks, snn
22 |
23 |
24 | def test_networks_none(in_spikes):
25 | for t_latency in [0, 1, 2, 4, 8]:
26 | blocks, snn = get_models(t_latency=t_latency, recurrent=False)
27 | spikes1 = blocks(in_spikes, mode="train")
28 | spikes2 = snn(in_spikes, mode="train")
29 |
30 | assert torch.allclose(spikes1, spikes2)
31 |
32 |
33 | def test_networks_recurrent(in_spikes):
34 | for t_latency in [0, 1, 2, 4, 8]:
35 | blocks, snn = get_models(t_latency=t_latency, recurrent=True)
36 | spikes1 = blocks(in_spikes, mode="train")
37 | spikes2 = snn(in_spikes, mode="train")
38 |
39 | assert torch.allclose(spikes1, spikes2)
40 |
41 |
42 | def test_adaption(in_spikes):
43 | for t_latency in [0, 1, 2, 4, 8]:
44 | blocks, snn = get_models(t_latency=t_latency, recurrent=True, adapt=True)
45 | spikes1, mem1, x1, z1, v_th1 = blocks(in_spikes, mode="val")
46 | spikes2, mem2, x2, v_th2 = snn(in_spikes, mode="val")
47 |
48 | # v_th should be equal at the start of every block
49 | for i in range(blocks._t_len // blocks._t_len_block):
50 | assert torch.allclose(v_th1[:, :, i*blocks._t_len_block], v_th2[:, :, i*blocks._t_len_block])
51 |
--------------------------------------------------------------------------------