├── README.md
├── __init__.py
├── bindsnet
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── analysis
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── pipeline_analysis.cpython-37.pyc
│   │   │   ├── plotting.cpython-37.pyc
│   │   │   └── visualization.cpython-37.pyc
│   │   ├── pipeline_analysis.py
│   │   ├── plotting.py
│   │   └── visualization.py
│   ├── bindsnet-0.2.5.dist-info
│   │   ├── INSTALLER
│   │   ├── METADATA
│   │   ├── RECORD
│   │   ├── WHEEL
│   │   └── top_level.txt
│   ├── conversion
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── conversion.cpython-37.pyc
│   │   └── conversion.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── collate.cpython-37.pyc
│   │   │   ├── dataloader.cpython-37.pyc
│   │   │   ├── davis.cpython-37.pyc
│   │   │   ├── preprocess.cpython-37.pyc
│   │   │   ├── spoken_mnist.cpython-37.pyc
│   │   │   └── torchvision_wrapper.cpython-37.pyc
│   │   ├── collate.py
│   │   ├── dataloader.py
│   │   ├── davis.py
│   │   ├── preprocess.py
│   │   ├── spoken_mnist.py
│   │   └── torchvision_wrapper.py
│   ├── encoding
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── encoders.cpython-37.pyc
│   │   │   ├── encodings.cpython-37.pyc
│   │   │   └── loaders.cpython-37.pyc
│   │   ├── encoders.py
│   │   ├── encodings.py
│   │   └── loaders.py
│   ├── environment
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── environment.cpython-37.pyc
│   │   └── environment.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── evaluation.cpython-37.pyc
│   │   └── evaluation.py
│   ├── learning
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── learning.cpython-37.pyc
│   │   │   └── reward.cpython-37.pyc
│   │   ├── learning.py
│   │   └── reward.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── models.cpython-37.pyc
│   │   └── models.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── monitors.cpython-37.pyc
│   │   │   ├── network.cpython-37.pyc
│   │   │   ├── nodes.cpython-37.pyc
│   │   │   └── topology.cpython-37.pyc
│   │   ├── monitors.py
│   │   ├── network.py
│   │   ├── nodes.py
│   │   └── topology.py
│   ├── pipeline
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── action.cpython-37.pyc
│   │   │   ├── base_pipeline.cpython-37.pyc
│   │   │   ├── dataloader_pipeline.cpython-37.pyc
│   │   │   └── environment_pipeline.cpython-37.pyc
│   │   ├── action.py
│   │   ├── base_pipeline.py
│   │   ├── dataloader_pipeline.py
│   │   └── environment_pipeline.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── preprocessing.cpython-37.pyc
│   │   └── preprocessing.py
│   └── utils.py
├── conversion.py
└── vgg.py
/README.md:
--------------------------------------------------------------------------------
1 | # Exploring the Connection Between Binary and Spiking Neural Networks
2 |
3 | ## Overview
4 |
 10 | This codebase outlines a training methodology and provides trained models for Full Precision and Binary Spiking Neural Networks (B-SNNs), built on [BindsNet](https://github.com/BindsNET/bindsnet), for large-scale datasets, namely CIFAR-100 and ImageNet. Following the procedures and design features proposed in [our work](https://www.frontiersin.org/article/10.3389/fnins.2020.00535), we show that B-SNNs reach near full-precision accuracy even under many SNN-specific constraints. Additionally, we use an ANN-SNN conversion technique for training and explore a novel set of optimizations for generating high-accuracy, low-latency SNNs. These optimization techniques also apply to full-precision ANN-SNN conversion.
11 |
12 | ## Requirements
13 |
 14 | - A Python installation, version 3.6 or above
 15 | - The matplotlib, numpy, tqdm, and torchvision packages
 16 | - A PyTorch installation, version 1.3.0 ([pytorch.org](http://pytorch.org))
 17 | - CUDA 10.1
 18 | - The ImageNet dataset, if needed (it can be downloaded by a recent version of [torchvision](https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet))
19 |
20 | ## Training from scratch
21 | We explored various network architectures constrained by ANN-SNN conversion. The finalized network structure can be found in ```vgg.py```. Further details can be found in the [paper](http://arxiv.org/abs/2002.10064).
22 |
23 | ### Hyperparameter Settings
 24 | | Model | Batch Size | Epochs | Learning Rate | Weight Decay | Optimizer |
 25 | | ---- | ---- | ---- | ---- | ---- | ---- |
 26 | | CIFAR-100 Full Precision | 256 | 200 | 5e-2, divided by 10 at epochs 81 and 122 | 1e-4 | SGD (momentum=0.9) |
 27 | | CIFAR-100 Binary | 256 | 200 | 5e-4, halved every 30 epochs | 5e-4 (0 after 30 epochs) | Adam |
 28 | | ImageNet Full Precision | 128 | 100 | 1e-2, divided by 10 every 30 epochs | 1e-4 | SGD (momentum=0.9) |
 29 | | ImageNet Binary | 128 | 100 | 5e-4, halved every 30 epochs | 5e-4 (0 after 30 epochs) | Adam (**betas=(0.0, 0.999)**) |
30 |
31 | Note that these hyper-parameters may be further optimized.
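
As a rough illustration (not the repository's training script), the CIFAR-100 Full Precision row above corresponds to a PyTorch setup along these lines; the model here is a stand-in for the VGG network defined in ```vgg.py```:

```
import torch
import torch.nn as nn

# Stand-in model; substitute the VGG architecture from vgg.py.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 100))

# CIFAR-100 full-precision recipe from the table above:
# SGD, lr 5e-2 divided by 10 at epochs 81 and 122, weight decay 1e-4.
optimizer = torch.optim.SGD(
    model.parameters(), lr=5e-2, momentum=0.9, weight_decay=1e-4
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[81, 122], gamma=0.1
)

for epoch in range(200):
    # ... one training epoch over CIFAR-100 goes here ...
    scheduler.step()
```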
32 |
33 | ## Evaluating Pre-trained models
 34 | We provide pre-trained models of the VGG architecture mentioned in the paper and described above, available for download below. Note that the first and last layers are not binarized in our models. The corresponding top-1 accuracies are indicated in parentheses.
35 |
36 | * [CIFAR-100 Full Precision ANN (64.9%)](https://drive.google.com/open?id=1ZmagwfBdWVVztCdn67gmAWtfQJY3yrev)
37 | * [CIFAR-100 Binary ANN (64.8%)](https://drive.google.com/open?id=1605x2i_noKiQ-Z4OZW9L__deR_ubvfGS)
38 | * [ImageNet Full Precision ANN (69.05%)](https://drive.google.com/open?id=1SHXlvUrkPAkl8nQ8_LCNja5ypkqh59_x)
39 | * [ImageNet Binary ANN (64.4%)](https://drive.google.com/open?id=12WeIAfrVNxD45NFv4HV1nSvrLa3rRZp_)
40 |
 41 | The Full Precision ANNs are trained using standard [PyTorch training practices](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html), and the binarization process uses part of the [XNOR-Net-Pytorch](https://github.com/jiecaoyu/XNOR-Net-PyTorch) script, a Python implementation of the original [XNOR-Net](https://github.com/allenai/XNOR-Net) code.
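
For intuition, here is a minimal sketch of the XNOR-style binarization idea (a simplified illustration, not the exact script used in this repository): each weight tensor is replaced by its sign, scaled by the mean absolute value of the original weights.

```
import torch

def binarize(weight: torch.Tensor) -> torch.Tensor:
    # XNOR-Net-style binarization: sign of the weights scaled by their
    # mean absolute value, so the binary tensor approximates the original.
    alpha = weight.abs().mean()
    return alpha * weight.sign()

w = torch.randn(64, 3, 3, 3)   # e.g., one conv layer's weights
w_bin = binarize(w)            # values in {-alpha, +alpha}
```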
42 |
43 | ## Running a simulation
44 |
 45 | Download a pre-trained model into the same directory as ```conversion.py```, then run a command like the following for each model:
46 |
47 | ```python conversion.py --job-dir cifar100_test --gpu --dataset cifar100 --data . --percentile 99.9 --norm 3500 --time 100 --arch vgg15ab --model bin_cifar100.pth.tar```
48 |
49 | Full documentation of the arguments in `conversion.py`:
50 | ```
51 | usage: conversion.py [-h] --job-dir JOB_DIR --model MODEL
52 | [--results-file RESULTS_FILE] [--seed SEED] [--time TIME]
53 | [--batch-size BATCH_SIZE] [--n-workers N_WORKERS]
54 | [--norm NORM] [--gpu] [--one-step] [--data DATA_PATH]
55 | [--arch ARCH] [--percentile PERCENTILE]
56 | [--eval_size EVAL_SIZE] [--dataset DATASET]
57 |
58 | optional arguments:
59 | -h, --help show this help message and exit
60 | --job-dir JOB_DIR The working directory to store results
61 | --model MODEL The path to the pretrained model
62 | --results-file RESULTS_FILE
63 | The file to store simulation result
64 | --seed SEED A random seed
65 | --time TIME Time steps to be simulated by the converted SNN
66 | (default: 80)
67 | --batch-size BATCH_SIZE
68 | Mini batch size
69 | --n-workers N_WORKERS
70 | Number of data loaders
71 | --norm NORM The amount of data to be normalized at once
72 | --gpu Whether to use GPU or not
73 | --one-step Single step inference flag
 74 |   --data DATA_PATH      The path to ImageNet data (default: './data/'),
75 | CIFAR-100 will be downloaded
76 | --arch ARCH ANN architecture to be instantiated
77 | --percentile PERCENTILE
78 | The percentile of activation in the training set to be
79 | used for normalization of SNN voltage threshold
80 | --eval_size EVAL_SIZE
81 | The amount of samples to be evaluated (default:
82 | evaluate all)
83 | --dataset DATASET cifar100 or imagenet
84 | ```
 85 | Depending on your computing resources, some settings can be changed to speed up the simulation or to fit the available device; in particular, ```--norm```, ```--batch-size```, and ```--time``` can be tuned for better performance.
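
For example, a quicker (and possibly slightly less accurate) CIFAR-100 run could use a smaller normalization batch, a smaller mini-batch, and fewer time steps; the values below are illustrative only:

```python conversion.py --job-dir cifar100_fast --gpu --dataset cifar100 --data . --percentile 99.9 --norm 1000 --batch-size 64 --time 50 --arch vgg15ab --model bin_cifar100.pth.tar```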
86 |
87 | ## Reference
88 |
89 | If you use this code, please cite the following paper:
90 |
91 | Sen Lu and Abhronil Sengupta. "Exploring the Connection Between Binary and Spiking Neural Networks", Frontiers in Neuroscience, Vol. 14, pp. 535 (2020).
92 |
93 | ```
94 | @ARTICLE{10.3389/fnins.2020.00535,
95 | AUTHOR={Lu, Sen and Sengupta, Abhronil},
96 | TITLE={Exploring the Connection Between Binary and Spiking Neural Networks},
97 | JOURNAL={Frontiers in Neuroscience},
98 | VOLUME={14},
99 | PAGES={535},
100 | YEAR={2020},
101 | URL={https://www.frontiersin.org/article/10.3389/fnins.2020.00535},
102 | DOI={10.3389/fnins.2020.00535},
103 | ISSN={1662-453X}
104 | }
105 | ```
106 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | from MLP import MLP
2 | from alexnet import AlexNet
--------------------------------------------------------------------------------
/bindsnet/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from . import (
4 | utils,
5 | network,
6 | models,
7 | analysis,
8 | preprocessing,
9 | datasets,
10 | encoding,
11 | pipeline,
12 | learning,
13 | evaluation,
14 | environment,
15 | conversion,
16 | )
17 |
18 | ROOT_DIR = Path(__file__).parents[0].parents[0]
19 |
--------------------------------------------------------------------------------
/bindsnet/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/analysis/__init__.py:
--------------------------------------------------------------------------------
1 | from . import plotting, visualization, pipeline_analysis
2 |
--------------------------------------------------------------------------------
/bindsnet/analysis/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/analysis/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/analysis/__pycache__/pipeline_analysis.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/analysis/__pycache__/pipeline_analysis.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/analysis/__pycache__/plotting.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/analysis/__pycache__/plotting.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/analysis/__pycache__/visualization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/analysis/__pycache__/visualization.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/analysis/pipeline_analysis.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Optional
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from tensorboardX import SummaryWriter
9 | from torchvision.utils import make_grid
10 |
11 | from .plotting import plot_spikes, plot_voltages, plot_conv2d_weights
12 | from ..utils import reshape_conv2d_weights
13 |
14 |
15 | class PipelineAnalyzer(ABC):
16 | # language=rst
17 | """
18 | Responsible for pipeline analysis. Subclasses maintain state
19 | information related to plotting or logging.
20 | """
21 |
22 | @abstractmethod
23 | def finalize_step(self) -> None:
24 | # language=rst
25 | """
26 | Flush the output from the current step.
27 | """
28 | pass
29 |
30 | @abstractmethod
31 | def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
32 | # language=rst
33 | """
34 | Pulls the observation from PyTorch and sets up for Matplotlib
35 | plotting.
36 |
37 | :param obs: A 2D array of floats depicting an input image.
38 | :param tag: A unique tag to associate the data with.
39 | :param step: The step of the pipeline.
40 | """
41 | pass
42 |
43 | @abstractmethod
44 | def plot_reward(
45 | self,
46 | reward_list: list,
47 | reward_window: int = None,
48 | tag: str = "reward",
49 | step: int = None,
50 | ) -> None:
51 | # language=rst
52 | """
53 | Plot the accumulated reward for each episode.
54 |
55 | :param reward_list: The list of recent rewards to be plotted.
56 | :param reward_window: The length of the window to compute a moving average over.
57 | :param tag: A unique tag to associate the data with.
58 | :param step: The step of the pipeline.
59 | """
60 | pass
61 |
62 | @abstractmethod
63 | def plot_spikes(
64 | self,
65 | spike_record: Dict[str, torch.Tensor],
66 | tag: str = "spike",
67 | step: int = None,
68 | ) -> None:
69 | # language=rst
70 | """
71 | Plots all spike records inside of ``spike_record``. Keeps unique
72 | plots for all unique tags that are given.
73 |
74 | :param spike_record: Dictionary of spikes to be rasterized.
75 | :param tag: A unique tag to associate the data with.
76 | :param step: The step of the pipeline.
77 | """
78 | pass
79 |
80 | @abstractmethod
81 | def plot_voltages(
82 | self,
83 | voltage_record: Dict[str, torch.Tensor],
84 | thresholds: Optional[Dict[str, torch.Tensor]] = None,
85 | tag: str = "voltage",
86 | step: int = None,
87 | ) -> None:
88 | # language=rst
89 | """
90 | Plots all voltage records and given thresholds. Keeps unique
91 | plots for all unique tags that are given.
92 |
93 | :param voltage_record: Dictionary of voltages for neurons inside of networks
94 | organized by the layer they correspond to.
95 | :param thresholds: Optional dictionary of threshold values for neurons.
96 | :param tag: A unique tag to associate the data with.
97 | :param step: The step of the pipeline.
98 | """
99 | pass
100 |
101 | @abstractmethod
102 | def plot_conv2d_weights(
103 | self, weights: torch.Tensor, tag: str = "conv2d", step: int = None
104 | ) -> None:
105 | # language=rst
106 | """
107 | Plot a connection weight matrix of a ``Conv2dConnection``.
108 |
109 | :param weights: Weight matrix of ``Conv2dConnection`` object.
110 | :param tag: A unique tag to associate the data with.
111 | :param step: The step of the pipeline.
112 | """
113 | pass
114 |
115 |
116 | class MatplotlibAnalyzer(PipelineAnalyzer):
117 | # language=rst
118 | """
119 | Renders output using Matplotlib.
120 |
121 | Matplotlib requires objects to be kept around over the full lifetime
122 | of the plots; this is done through ``self.plots``. An interactive session
123 | is needed so that we can continue processing and just update the
124 | plots.
125 | """
126 |
127 | def __init__(self, **kwargs) -> None:
128 | # language=rst
129 | """
130 | Initializes the analyzer.
131 |
132 | Keyword arguments:
133 |
134 | :param str volts_type: Type of plotting for voltages (``"color"`` or ``"line"``).
135 | """
136 | self.volts_type = kwargs.get("volts_type", "color")
137 | plt.ion()
138 | self.plots = {}
139 |
140 | def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
141 | # language=rst
142 | """
 143 |         Pulls the observation from PyTorch and sets it up for Matplotlib
 144 |         plotting.
145 |
146 | :param obs: A 2D array of floats depicting an input image.
147 | :param tag: A unique tag to associate the data with.
148 | :param step: The step of the pipeline.
149 | """
150 | obs = obs.detach().cpu().numpy()
151 | obs = np.transpose(obs, (1, 2, 0)).squeeze()
152 |
153 | if tag in self.plots:
154 | obs_ax, obs_im = self.plots[tag]
155 | else:
156 | obs_ax, obs_im = None, None
157 |
158 | if obs_im is None and obs_ax is None:
159 | fig, obs_ax = plt.subplots()
160 | obs_ax.set_title("Observation")
161 | obs_ax.set_xticks(())
162 | obs_ax.set_yticks(())
163 | obs_im = obs_ax.imshow(obs, cmap="gray")
164 |
165 | self.plots[tag] = obs_ax, obs_im
166 | else:
167 | obs_im.set_data(obs)
168 |
169 | def plot_reward(
170 | self,
171 | reward_list: list,
172 | reward_window: int = None,
173 | tag: str = "reward",
174 | step: int = None,
175 | ) -> None:
176 | # language=rst
177 | """
178 | Plot the accumulated reward for each episode.
179 |
180 | :param reward_list: The list of recent rewards to be plotted.
181 | :param reward_window: The length of the window to compute a moving average over.
182 | :param tag: A unique tag to associate the data with.
183 | :param step: The step of the pipeline.
184 | """
185 | if tag in self.plots:
186 | reward_im, reward_ax, reward_plot = self.plots[tag]
187 | else:
188 | reward_im, reward_ax, reward_plot = None, None, None
189 |
190 | # Compute moving average.
191 | if reward_window is not None:
192 | # Ensure window size > 0 and < size of reward list.
193 | window = max(min(len(reward_list), reward_window), 0)
194 |
 195 |             # Moving average via a pandas rolling mean.
196 | reward_list_ = (
197 | pd.Series(reward_list)
198 | .rolling(window=window, min_periods=1)
199 | .mean()
200 | .values
201 | )
202 | else:
203 | reward_list_ = reward_list[:]
204 |
205 | if reward_im is None and reward_ax is None:
206 | reward_im, reward_ax = plt.subplots()
207 | reward_ax.set_title("Accumulated reward")
208 | reward_ax.set_xlabel("Episode")
209 | reward_ax.set_ylabel("Reward")
210 | (reward_plot,) = reward_ax.plot(reward_list_)
211 |
212 | self.plots[tag] = reward_im, reward_ax, reward_plot
213 | else:
214 | reward_plot.set_data(range(len(reward_list_)), reward_list_)
215 | reward_ax.relim()
216 | reward_ax.autoscale_view()
217 |
218 | def plot_spikes(
219 | self,
220 | spike_record: Dict[str, torch.Tensor],
221 | tag: str = "spike",
222 | step: int = None,
223 | ) -> None:
224 | # language=rst
225 | """
226 | Plots all spike records inside of ``spike_record``. Keeps unique
227 | plots for all unique tags that are given.
228 |
229 | :param spike_record: Dictionary of spikes to be rasterized.
230 | :param tag: A unique tag to associate the data with.
231 | :param step: The step of the pipeline.
232 | """
233 | if tag not in self.plots:
234 | self.plots[tag] = plot_spikes(spike_record)
235 | else:
236 | s_im, s_ax = self.plots[tag]
237 | self.plots[tag] = plot_spikes(spike_record, ims=s_im, axes=s_ax)
238 |
239 | def plot_voltages(
240 | self,
241 | voltage_record: Dict[str, torch.Tensor],
242 | thresholds: Optional[Dict[str, torch.Tensor]] = None,
243 | tag: str = "voltage",
244 | step: int = None,
245 | ) -> None:
246 | # language=rst
247 | """
248 | Plots all voltage records and given thresholds. Keeps unique
249 | plots for all unique tags that are given.
250 |
251 | :param voltage_record: Dictionary of voltages for neurons inside of networks
252 | organized by the layer they correspond to.
253 | :param thresholds: Optional dictionary of threshold values for neurons.
254 | :param tag: A unique tag to associate the data with.
255 | :param step: The step of the pipeline.
256 | """
257 | if tag not in self.plots:
258 | self.plots[tag] = plot_voltages(
259 | voltage_record, plot_type=self.volts_type, thresholds=thresholds
260 | )
261 | else:
262 | v_im, v_ax = self.plots[tag]
263 | self.plots[tag] = plot_voltages(
264 | voltage_record,
265 | ims=v_im,
266 | axes=v_ax,
267 | plot_type=self.volts_type,
268 | thresholds=thresholds,
269 | )
270 |
271 | def plot_conv2d_weights(
272 | self, weights: torch.Tensor, tag: str = "conv2d", step: int = None
273 | ) -> None:
274 | # language=rst
275 | """
276 | Plot a connection weight matrix of a ``Conv2dConnection``.
277 |
278 | :param weights: Weight matrix of ``Conv2dConnection`` object.
279 | :param tag: A unique tag to associate the data with.
280 | :param step: The step of the pipeline.
281 | """
282 | wmin = weights.min().item()
283 | wmax = weights.max().item()
284 |
285 | if tag not in self.plots:
286 | self.plots[tag] = plot_conv2d_weights(weights, wmin, wmax)
287 | else:
288 | im = self.plots[tag]
289 | plot_conv2d_weights(weights, wmin, wmax, im=im)
290 |
291 | def finalize_step(self) -> None:
292 | # language=rst
293 | """
294 | Flush the output from the current step
295 | """
296 | plt.draw()
297 | plt.pause(1e-8)
298 | plt.show()
299 |
300 |
301 | class TensorboardAnalyzer(PipelineAnalyzer):
302 | def __init__(self, summary_directory: str = "./logs"):
303 | # language=rst
304 | """
305 | Initializes the analyzer.
306 |
307 | :param summary_directory: Directory to save log files.
308 | """
309 | self.writer = SummaryWriter(summary_directory)
310 |
311 | def finalize_step(self) -> None:
312 | # language=rst
313 | """
314 | No-op for ``TensorboardAnalyzer``.
315 | """
316 | pass
317 |
318 | def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
319 | # language=rst
320 | """
 321 |         Pulls the observation from PyTorch and writes it to Tensorboard
 322 |         as an image grid.
323 |
324 | :param obs: A 2D array of floats depicting an input image.
325 | :param tag: A unique tag to associate the data with.
326 | :param step: The step of the pipeline.
327 | """
328 | obs_grid = make_grid(obs.float(), nrow=4, normalize=True)
329 | self.writer.add_image(tag, obs_grid, step)
330 |
331 | def plot_reward(
332 | self,
333 | reward_list: list,
334 | reward_window: int = None,
335 | tag: str = "reward",
336 | step: int = None,
337 | ) -> None:
338 | # language=rst
339 | """
340 | Plot the accumulated reward for each episode.
341 |
342 | :param reward_list: The list of recent rewards to be plotted.
343 | :param reward_window: The length of the window to compute a moving average over.
344 | :param tag: A unique tag to associate the data with.
345 | :param step: The step of the pipeline.
346 | """
347 | self.writer.add_scalar(tag, reward_list[-1], step)
348 |
349 | def plot_spikes(
350 | self,
351 | spike_record: Dict[str, torch.Tensor],
352 | tag: str = "spike",
353 | step: int = None,
354 | ) -> None:
355 | # language=rst
356 | """
357 | Plots all spike records inside of ``spike_record``. Keeps unique
358 | plots for all unique tags that are given.
359 |
360 | :param spike_record: Dictionary of spikes to be rasterized.
361 | :param tag: A unique tag to associate the data with.
362 | :param step: The step of the pipeline.
363 | """
364 | for k, spikes in spike_record.items():
 365 |             # Reshape spikes into 1 x 1 x #neurons x T.
366 | spikes = spikes.view(1, 1, -1, spikes.shape[-1]).float()
367 | spike_grid_img = make_grid(spikes, nrow=1, pad_value=0.5)
368 |
369 | self.writer.add_image(tag + "_" + str(k), spike_grid_img, step)
370 |
371 | def plot_voltages(
372 | self,
373 | voltage_record: Dict[str, torch.Tensor],
374 | thresholds: Optional[Dict[str, torch.Tensor]] = None,
375 | tag: str = "voltage",
376 | step: int = None,
377 | ) -> None:
378 | # language=rst
379 | """
380 | Plots all voltage records and given thresholds. Keeps unique
381 | plots for all unique tags that are given.
382 |
383 | :param voltage_record: Dictionary of voltages for neurons inside of networks
384 | organized by the layer they correspond to.
385 | :param thresholds: Optional dictionary of threshold values for neurons.
386 | :param tag: A unique tag to associate the data with.
387 | :param step: The step of the pipeline.
388 | """
389 | for k, v in voltage_record.items():
 390 |             # Reshape voltages into 1 x 1 x #neurons x T.
391 | v = v.view(1, 1, -1, v.shape[-1])
392 | voltage_grid_img = make_grid(v, nrow=1, pad_value=0)
393 |
394 | self.writer.add_image(tag + "_" + str(k), voltage_grid_img, step)
395 |
396 | def plot_conv2d_weights(
397 | self, weights: torch.Tensor, tag: str = "conv2d", step: int = None
398 | ) -> None:
399 | # language=rst
400 | """
401 | Plot a connection weight matrix of a ``Conv2dConnection``.
402 |
403 | :param weights: Weight matrix of ``Conv2dConnection`` object.
404 | :param tag: A unique tag to associate the data with.
405 | :param step: The step of the pipeline.
406 | """
407 | reshaped = reshape_conv2d_weights(weights).unsqueeze(0)
408 |
409 | reshaped -= reshaped.min()
410 | reshaped /= reshaped.max()
411 |
412 | self.writer.add_image(tag, reshaped, step)
413 |
--------------------------------------------------------------------------------
/bindsnet/analysis/visualization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import matplotlib.animation as animation
5 |
6 | from typing import List, Tuple, Optional
7 |
8 |
9 | def plot_weights_movie(ws: np.ndarray, sample_every: int = 1) -> None:
10 | # language=rst
11 | """
12 | Create and plot movie of weights.
13 |
14 | :param ws: Array of shape ``[n_examples, source, target, time]``
15 | :param sample_every: Sub-sample using this parameter.
16 | """
17 | weights = []
18 |
19 | # Obtain samples from the weights for every example.
 20 |     for i in range(ws.shape[0]):
 21 |         sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)]
 22 |         weights.append(sub_sampled_weight)
 23 | 
 24 |     weights = np.concatenate(weights, axis=0)
25 |
26 | # Initialize plot.
27 | fig = plt.figure()
28 | im = plt.imshow(weights[0, :, :], cmap="hot_r", animated=True, vmin=0, vmax=1)
29 | plt.axis("off")
30 | plt.colorbar(im)
31 |
32 | # Update function for the animation.
33 | def update(j):
34 | im.set_data(weights[j, :, :])
35 | return [im]
36 |
37 | # Initialize animation.
38 | global ani
39 | ani = 0
40 | ani = animation.FuncAnimation(
41 | fig, update, frames=weights.shape[-1], interval=1000, blit=True
42 | )
43 | plt.show()
44 |
45 |
46 | def plot_spike_trains_for_example(
47 | spikes: torch.Tensor,
48 | n_ex: Optional[int] = None,
49 | top_k: Optional[int] = None,
50 | indices: Optional[List[int]] = None,
51 | ) -> None:
52 | # language=rst
53 | """
54 | Plot spike trains for top-k neurons or for specific indices.
55 |
56 | :param spikes: Spikes for one simulation run of shape ``(n_examples, n_neurons, time)``.
57 | :param n_ex: Allows user to pick which example to plot spikes for.
58 | :param top_k: Plot k neurons that spiked the most for n_ex example.
59 | :param indices: Plot specific neurons' spiking activity instead of top_k.
60 | """
61 | assert n_ex is not None and 0 <= n_ex < spikes.shape[0]
62 |
63 | plt.figure()
64 |
65 | if top_k is None and indices is None: # Plot all neurons' spiking activity
66 | spike_per_neuron = [np.argwhere(i == 1).flatten() for i in spikes[n_ex, :, :]]
67 | plt.title("Spiking activity for all %d neurons" % spikes.shape[1])
68 |
69 | elif top_k is None: # Plot based on indices parameter
70 | assert indices is not None
71 | spike_per_neuron = [
72 | np.argwhere(i == 1).flatten() for i in spikes[n_ex, indices, :]
73 | ]
74 |
75 | elif indices is None: # Plot based on top_k parameter
76 | assert top_k is not None
77 | # Obtain the top k neurons that fired the most
78 | top_k_loc = np.argsort(np.sum(spikes[n_ex, :, :], axis=1), axis=0)[::-1]
79 | spike_per_neuron = [
80 | np.argwhere(i == 1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]
81 | ]
82 | plt.title("Spiking activity for top %d neurons" % top_k)
83 |
84 | else:
 85 |         raise ValueError('At most one of "top_k" and "indices" may be given; the other must be None')
86 |
87 | plt.eventplot(spike_per_neuron, linelengths=[0.5] * len(spike_per_neuron))
88 | plt.xlabel("Simulation Time")
89 | plt.ylabel("Neuron index")
90 | plt.show()
91 |
92 |
93 | def plot_voltage(
94 | voltage: torch.Tensor,
95 | n_ex: int = 0,
96 | n_neuron: int = 0,
97 | time: Optional[Tuple[int, int]] = None,
98 | threshold: float = None,
99 | ) -> None:
100 | # language=rst
101 | """
102 | Plot voltage for a single neuron on a specific example.
103 |
104 | :param voltage: Tensor or array of shape ``[n_examples, n_neurons, time]``.
105 | :param n_ex: Allows user to pick which example to plot voltage for.
 106 |     :param n_neuron: Index of the neuron for which to plot voltage.
107 | :param time: Plot spiking activity of neurons between the given range of time.
108 | :param threshold: Neuron spiking threshold.
109 | """
110 | assert n_ex >= 0 and n_neuron >= 0
111 | assert n_ex < voltage.shape[0] and n_neuron < voltage.shape[1]
112 |
113 | if time is None:
114 | time = (0, voltage.shape[-1])
115 | else:
116 | assert time[0] < time[1]
117 | assert time[1] <= voltage.shape[-1]
118 |
119 | timer = np.arange(time[0], time[1])
120 | time_ticks = np.arange(time[0], time[1] + 1, 10)
121 |
122 | plt.figure()
123 | plt.plot(voltage[n_ex, n_neuron, timer])
124 | plt.xlabel("Simulation Time")
125 | plt.ylabel("Voltage")
126 | plt.title("Membrane voltage of neuron %d for example %d" % (n_neuron, n_ex + 1))
127 | locs, labels = plt.xticks()
128 | locs = range(int(locs[1]), int(locs[-1]), 10)
129 | plt.xticks(locs, time_ticks)
130 |
131 | # Draw threshold line only if given
132 | if threshold is not None:
133 | plt.axhline(threshold, linestyle="--", color="black", zorder=0)
134 |
135 | plt.show()
136 |
--------------------------------------------------------------------------------
/bindsnet/bindsnet-0.2.5.dist-info/INSTALLER:
--------------------------------------------------------------------------------
1 | pip
2 |
--------------------------------------------------------------------------------
/bindsnet/bindsnet-0.2.5.dist-info/METADATA:
--------------------------------------------------------------------------------
1 | Metadata-Version: 2.1
2 | Name: bindsnet
3 | Version: 0.2.5
4 | Summary: Spiking neural networks for ML in Python
5 | Home-page: http://github.com/Hananel-Hazan/bindsnet
6 | Author: Daniel Saunders, Hananel Hazan, Darpan Sanghavi, Hassaan Khan
7 | Author-email: danjsaund@gmail.com
8 | License: AGPL-3.0
9 | Download-URL: https://github.com/Hananel-Hazan/bindsnet/archive/0.2.5.tar.gz
10 | Platform: UNKNOWN
11 | Description-Content-Type: text/markdown
12 | Requires-Dist: numpy (>=1.14.2)
13 | Requires-Dist: torch (>=1.2.0)
14 | Requires-Dist: torchvision (>=0.4.0)
15 | Requires-Dist: tensorboardX (>=1.7)
16 | Requires-Dist: tqdm (>=4.19.9)
17 | Requires-Dist: matplotlib (>=2.1.0)
18 | Requires-Dist: gym (>=0.10.4)
19 | Requires-Dist: scikit-image (>=0.13.1)
20 | Requires-Dist: scikit-learn (>=0.19.1)
21 | Requires-Dist: opencv-python (>=3.4.0.12)
22 | Requires-Dist: pytest (>=3.4.0)
23 | Requires-Dist: scipy (>=1.1.0)
24 | Requires-Dist: cython (>=0.28.5)
25 | Requires-Dist: pandas (>=0.23.4)
26 |
27 |

28 |
29 | A Python package used for simulating spiking neural networks (SNNs) on CPUs or GPUs using [PyTorch](http://pytorch.org/) `Tensor` functionality.
30 |
31 | BindsNET is a spiking neural network simulation library geared towards the development of biologically inspired algorithms for machine learning.
32 |
33 | This package is used as part of ongoing research on applying SNNs to machine learning (ML) and reinforcement learning (RL) problems in the [Biologically Inspired Neural & Dynamical Systems (BINDS) lab](http://binds.cs.umass.edu/).
34 |
35 | Check out the [BindsNET experiments repository](https://github.com/djsaunde/bindsnet_experiments) for a collection of experiments, accompanying bash scripts for dispatching on [CICS](https://www.cics.umass.edu/) clusters, functions for the analysis of results, plots of experiment outcomes, and more.
36 |
 37 | [Build Status](https://travis-ci.com/Hananel-Hazan/bindsnet)
 38 | [Documentation Status](https://bindsnet-docs.readthedocs.io/?badge=latest)
 39 | [HitCount](http://hits.dwyl.io/Hananel-Hazan/bindsnet)
 40 | [Gitter chat](https://gitter.im/bindsnet_/community)
41 |
42 | ## Requirements
43 |
44 | - Python 3.6
 45 | - The packages listed in `requirements.txt`
46 |
47 | ## Setting things up
48 |
49 | ### Using pip
50 | BindsNET is available on PyPI. Issue
51 |
52 | ```
53 | pip install bindsnet
54 | ```
55 |
56 | to get the most recent stable release. Or, to build the `bindsnet` package from source, clone the GitHub repository, change directory to the top level of this project, and issue
57 |
58 | ```
59 | pip install .
60 | ```
61 |
62 | Or, to install in editable mode (allows modification of package without re-installing):
63 |
64 | ```
65 | pip install -e .
66 | ```
67 |
68 | To install the packages necessary to interface with the [OpenAI gym RL environments library](https://github.com/openai/gym), follow their instructions for installing the packages needed to run the RL environments simulator (on Linux / MacOS).
69 |
70 | ### Using Docker
71 | [Link](https://hub.docker.com/r/hqkhan/bindsnet/) to Docker repository.
72 |
 73 | We also provide a Dockerfile with BindsNET and all of its dependencies pre-installed. Issue
74 |
75 | ```
76 | docker image build .
77 | ```
78 | at the top level directory of this project to create a docker image.
79 |
80 | To change the name of the newly built image, issue
81 | ```
 82 | docker tag <image_id> <new_image_name>
83 | ```
84 |
85 | To run a container and get a bash terminal inside it, issue
86 |
87 | ```
 88 | docker run -it <image_name> bash
89 | ```
90 |
91 | ## Getting started
92 |
93 | To run a near-replication of the SNN from [this paper](https://www.frontiersin.org/articles/10.3389/fncom.2015.00099/full#), issue
94 |
95 | ```
96 | cd examples/mnist
97 | python eth_mnist.py
98 | ```
99 |
100 | There are a number of optional command-line arguments which can be passed in, including `--plot` (displays useful monitoring figures), `--n_neurons [int]` (number of excitatory, inhibitory neurons simulated), `--mode ['train' | 'test']` (sets network operation to the training or testing phase), and more. Run the script with the `--help` or `-h` flag for more information.
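
For example, using the flags documented above (values here are illustrative; `--help` gives the authoritative list and defaults):

```
cd examples/mnist
python eth_mnist.py --plot --n_neurons 100 --mode train
```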
101 |
102 | A number of other examples are available in the `examples` directory that are meant to showcase BindsNET's functionality. Take a look, and let us know what you think!
103 |
104 | ## Running the tests
105 |
106 | Issue the following to run the tests:
107 |
108 | ```
109 | python -m pytest test/
110 | ```
111 |
 112 | Some tests will fail if OpenAI `gym` is not installed on your machine.
113 |
114 | ## Background
115 |
116 | The simulation of biologically plausible spiking neuron dynamics can be challenging. It is typically done by solving ordinary differential equations (ODEs) which describe said dynamics. PyTorch does not explicitly support the solution of differential equations (as opposed to [`brian2`](https://github.com/brian-team/brian2), for example), but we can convert the ODEs defining the dynamics into difference equations and solve them at regular, short intervals (a `dt` on the order of 1 millisecond) as an approximation. Of course, under the hood, packages like `brian2` are doing the same thing. Doing this in [`PyTorch`](http://pytorch.org/) is exciting for a few reasons:
117 |
118 | 1. We can use the powerful and flexible [`torch.Tensor`](http://pytorch.org/) object, a wrapper around the [`numpy.ndarray`](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html) which can be transferred to and from GPU devices.
119 |
120 | 2. We can avoid "reinventing the wheel" by repurposing functions from the [`torch.nn.functional`](http://pytorch.org/docs/master/nn.html#torch-nn-functional) PyTorch submodule in our SNN architectures; e.g., convolution or pooling functions.
121 |
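To make the difference-equation idea concrete, here is a minimal sketch of forward-Euler LIF dynamics written with plain `torch` tensors (parameter values are illustrative, and this is not BindsNET's internal implementation):

```
import torch

n, steps = 100, 1000                    # neurons, simulation steps
dt, tau = 1.0, 100.0                    # time step (ms), membrane time constant (ms)
v_rest, v_reset, v_thresh = -65.0, -65.0, -52.0

v = torch.full((n,), v_rest)            # membrane voltages
for _ in range(steps):
    i_in = torch.rand(n)                # stand-in for synaptic input current
    # dv/dt = (v_rest - v) / tau + I, discretized with step dt.
    v = v + dt * ((v_rest - v) / tau + i_in)
    spikes = v >= v_thresh              # threshold crossings emit spikes
    v[spikes] = v_reset                 # reset the neurons that fired
```
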
122 | The concept that the neuron spike ordering and their relative timing encode information is a central theme in neuroscience. [Markram et al. (1997)](http://www.caam.rice.edu/~caam415/lec_gab/g4/markram_etal98.pdf) proposed that synapses between neurons should strengthen or degrade based on this relative timing, and prior to that, [Donald Hebb](https://en.wikipedia.org/wiki/Donald_O._Hebb) proposed the theory of Hebbian learning, often simply stated as "Neurons that fire together, wire together." Markram et al.'s extension of the Hebbian theory is known as spike-timing-dependent plasticity (STDP).
123 |
124 | We are interested in applying SNNs to ML and RL problems. We use STDP to modify weights of synapses connecting pairs or populations of neurons in SNNs. In the context of ML, we want to learn a setting of synapse weights which will generate data-dependent spiking activity in SNNs. This activity will allow us to subsequently perform some ML task of interest; e.g., discriminating or clustering input data. In the context of RL, we may think of the spiking neural network as an RL agent, whose spiking activity may be converted into actions in an environment's action space.
125 |
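As a toy illustration of this trace-based view of STDP (again a sketch, not BindsNET's learning rules), exponentially decaying pre- and post-synaptic spike traces can drive the weight update:

```
import torch

n_pre, n_post, steps = 10, 5, 100
dt, tau, lr = 1.0, 20.0, 1e-3

w = torch.rand(n_pre, n_post)            # synaptic weights in [0, 1]
x_pre = torch.zeros(n_pre)               # presynaptic spike traces
x_post = torch.zeros(n_post)             # postsynaptic spike traces

for _ in range(steps):
    s_pre = (torch.rand(n_pre) < 0.05).float()    # random presynaptic spikes
    s_post = (torch.rand(n_post) < 0.05).float()  # random postsynaptic spikes

    # Traces decay exponentially and jump on spikes.
    x_pre = x_pre * (1 - dt / tau) + s_pre
    x_post = x_post * (1 - dt / tau) + s_post

    # Potentiate when a post spike follows recent pre activity; depress otherwise.
    dw = x_pre.unsqueeze(1) * s_post.unsqueeze(0) - s_pre.unsqueeze(1) * x_post.unsqueeze(0)
    w = (w + lr * dw).clamp(0, 1)
```
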
126 | We have provided some simple starter scripts for doing unsupervised learning (learning a fully-connected or convolutional representation via STDP), supervised learning (clamping output neurons to desired spiking behavior depending on data labels), and reinforcement learning (converting observations from the Atari game Space Invaders to input to an SNN, and converting network activity back to actions in the game).
127 |
128 | ## Benchmarking
 129 | We simulated a network with a population of n Poisson input neurons, with firing rates (in Hertz) drawn randomly from U(0, 100), connected all-to-all to an equally-sized population of leaky integrate-and-fire (LIF) neurons, with connection weights sampled from N(0, 1). We varied n systematically from 250 to 10,000 in steps of 250, and ran each simulation with every library for 1,000 ms at a time resolution of dt = 1.0 ms. We tested BindsNET (with CPU and GPU computation), BRIAN2, PyNEST (the Python interface to the NEST SLI interface that runs the C++ NEST core simulator), ANNarchy (with CPU and GPU computation), and BRIAN2genn (the BRIAN2 front-end to the GeNN simulator).
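
A rough sketch of one benchmark configuration, built from BindsNET's network primitives (constructor and keyword names are assumptions based on this BindsNET version and may differ in other releases):

```
import torch

from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection

n = 250                                   # smallest population size in the sweep
network = Network(dt=1.0)                 # 1.0 ms time resolution

network.add_layer(Input(n=n), name="X")   # Poisson input population
network.add_layer(LIFNodes(n=n), name="Y")
network.add_connection(
    Connection(source=network.layers["X"], target=network.layers["Y"], w=torch.randn(n, n)),
    source="X",
    target="Y",
)
# Input rates are drawn from U(0, 100) Hz, encoded as Poisson spike trains
# (see bindsnet.encoding), and the network is then simulated for 1,000 ms.
```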
130 |
 131 | Several packages, including BRIAN and PyNEST, allow the setting of certain global preferences; e.g., the number of CPU threads, the number of OpenMP processes, etc. We chose these settings for our benchmark study in an attempt to maximize each library's speed, but note that BindsNET requires no setting of such options. Our approach, inheriting the computational model of PyTorch, appears to make the best use of the available hardware, and therefore makes it simple for practitioners to get the best performance from their system with the least effort.
132 |
133 |
134 |
135 |
136 |
 137 | All simulations were run on Ubuntu 16.04 LTS with an Intel(R) Xeon(R) CPU E5-2687W v3 @ 3.10GHz, 128 GB RAM @ 2133 MHz, and two GeForce GTX TITAN X (GM200) GPUs. Python 3.6 was used in all cases. Wall-clock time was recorded for each simulation run.
138 |
139 | ## Citation
140 |
141 | If you use BindsNET in your research, please cite the following [article](https://www.frontiersin.org/article/10.3389/fninf.2018.00089):
142 |
143 | ```
144 | @ARTICLE{10.3389/fninf.2018.00089,
145 | AUTHOR={Hazan, Hananel and Saunders, Daniel J. and Khan, Hassaan and Patel, Devdhar and Sanghavi, Darpan T. and Siegelmann, Hava T. and Kozma, Robert},
146 | TITLE={BindsNET: A Machine Learning-Oriented Spiking Neural Networks Library in Python},
147 | JOURNAL={Frontiers in Neuroinformatics},
148 | VOLUME={12},
149 | PAGES={89},
150 | YEAR={2018},
151 | URL={https://www.frontiersin.org/article/10.3389/fninf.2018.00089},
152 | DOI={10.3389/fninf.2018.00089},
153 | ISSN={1662-5196},
154 | }
155 |
156 | ```
157 |
158 | ## Contributors
159 |
160 | - Daniel Saunders ([email](mailto:djsaunde@cs.umass.edu))
161 | - Hananel Hazan ([email](mailto:hananel@hazan.org.il))
162 | - Darpan Sanghavi ([email](mailto:dsanghavi@cs.umass.edu))
163 | - Hassaan Khan ([email](mailto:hqkhan@umass.edu))
164 | - Devdhar Patel ([email](mailto:devdharpatel@cs.umass.edu))
165 |
166 | ## License
167 | GNU Affero General Public License v3.0
168 |
169 |
170 |
171 |
172 |
--------------------------------------------------------------------------------
/bindsnet/bindsnet-0.2.5.dist-info/RECORD:
--------------------------------------------------------------------------------
1 | bindsnet-0.2.5.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2 | bindsnet-0.2.5.dist-info/METADATA,sha256=607kTwrUkJf8qf3_sJm8I7ShSjqIKWWxa4Z_mVfU9Qw,9894
3 | bindsnet-0.2.5.dist-info/RECORD,,
4 | bindsnet-0.2.5.dist-info/WHEEL,sha256=6lWjmfsNs_bGGL5Yl5gRRr5x9whxDoYnVMS85C0Tfgs,98
5 | bindsnet-0.2.5.dist-info/top_level.txt,sha256=_rRWpJi6ML98T3oY7Yi0UE3ThTbjrPuIcjQ6JPyDoT4,9
6 | bindsnet/__init__.py,sha256=Copph0F10T-N_RBshOt81aLlmAHyMW70wp1kUneO9BI,267
7 | bindsnet/__pycache__/__init__.cpython-37.pyc,,
8 | bindsnet/__pycache__/utils.cpython-37.pyc,,
9 | bindsnet/analysis/__init__.py,sha256=ncx2TPPt_WsrqDb6W9My5dFHTKUqenDbPAiICPF8cdM,57
10 | bindsnet/analysis/__pycache__/__init__.cpython-37.pyc,,
11 | bindsnet/analysis/__pycache__/pipeline_analysis.cpython-37.pyc,,
12 | bindsnet/analysis/__pycache__/plotting.cpython-37.pyc,,
13 | bindsnet/analysis/__pycache__/visualization.cpython-37.pyc,,
14 | bindsnet/analysis/pipeline_analysis.py,sha256=SyP2wB5xQpmjeN0lXmM3lkxEIqLEhMTsOE0bepwcmCw,13629
15 | bindsnet/analysis/plotting.py,sha256=C5N9SBceCmQZdnTZIWepNYSZELxtKV87KwQAvoDOXNc,22172
16 | bindsnet/analysis/visualization.py,sha256=-mEB-ExZWRkDOyN0wq_DVGJ7p7OqKBUf48R0aEhmBUE,4432
17 | bindsnet/conversion/__init__.py,sha256=vAYbOc6MTuRf01ugLqE6D2nJzB1aTmhVClvOlZkpLgU,212
18 | bindsnet/conversion/__pycache__/__init__.cpython-37.pyc,,
19 | bindsnet/conversion/__pycache__/conversion.cpython-37.pyc,,
20 | bindsnet/conversion/conversion.py,sha256=qPEbzSzaZUi-xwqnHXEyJ3eSad9fmBq9oEy7y3HBass,19736
21 | bindsnet/datasets/__init__.py,sha256=JFmWp5g0zkaRoTzW9OftgQ19lW6dN_FWcaG7IrxeT5U,1625
22 | bindsnet/datasets/__pycache__/__init__.cpython-37.pyc,,
23 | bindsnet/datasets/__pycache__/collate.cpython-37.pyc,,
24 | bindsnet/datasets/__pycache__/dataloader.cpython-37.pyc,,
25 | bindsnet/datasets/__pycache__/davis.cpython-37.pyc,,
26 | bindsnet/datasets/__pycache__/preprocess.cpython-37.pyc,,
27 | bindsnet/datasets/__pycache__/spoken_mnist.cpython-37.pyc,,
28 | bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-37.pyc,,
29 | bindsnet/datasets/collate.py,sha256=jGmO8qyoSavSIdORGGiaf_9dBAjZdfHBIbzRSXVqZgM,3144
30 | bindsnet/datasets/dataloader.py,sha256=GgbN-YTdyzbRHTrHq5dPJCKfiJDnppLWRPzQkC6zS5M,818
31 | bindsnet/datasets/davis.py,sha256=wH5GsZ2CIy0-L-yfAwdvIjLwJgyP8NfMaFAo-E-AfqI,12586
32 | bindsnet/datasets/preprocess.py,sha256=CVPY1VrIRc9vzBi56TCmP3sustx6u8Y7GbbDFps7s8Y,1342
33 | bindsnet/datasets/spoken_mnist.py,sha256=9hipF6IrL55B_J0EdQU5E8RlBJ9YG9rwBbOpZcH02DM,10438
34 | bindsnet/datasets/torchvision_wrapper.py,sha256=fJ9fl7S4Q0Vo6m0z8I5m_SvpaZhFPJXuDkMI_P8vinA,2825
35 | bindsnet/encoding/__init__.py,sha256=_IbgfqJtClr4I3R8NNVhw_3UwerBtzhI4zHsrx0xdXE,301
36 | bindsnet/encoding/__pycache__/__init__.cpython-37.pyc,,
37 | bindsnet/encoding/__pycache__/encoders.cpython-37.pyc,,
38 | bindsnet/encoding/__pycache__/encodings.cpython-37.pyc,,
39 | bindsnet/encoding/__pycache__/loaders.cpython-37.pyc,,
40 | bindsnet/encoding/encoders.py,sha256=4r6VZzm9i7cHb2G2doGzTl4jAwW6xfJlRPeLlDcHGi8,3177
41 | bindsnet/encoding/encodings.py,sha256=i08jPyCV3Jj4H7qpphqVw8BJPwuKJGukEXPwtDafALA,5965
42 | bindsnet/encoding/loaders.py,sha256=2TbyhK-ib_evgFNHPBkQeZC-aN5yvLFdGsTBirDAXrI,2380
43 | bindsnet/environment/__init__.py,sha256=6JMOGowIuDuvByo9kFIRR1ntJfpOBUQjSZqE7LWaWHE,53
44 | bindsnet/environment/__pycache__/__init__.cpython-37.pyc,,
45 | bindsnet/environment/__pycache__/environment.cpython-37.pyc,,
46 | bindsnet/environment/environment.py,sha256=W3ZLDCKRypPMuouUAVSkrfBulXSt_6yPEfMdGLQ7PtY,8425
47 | bindsnet/evaluation/__init__.py,sha256=umhjCQ4eaNhUHKGcZ40wqwH5Xh9jOJ3cFWrJLt-cF-g,163
48 | bindsnet/evaluation/__pycache__/__init__.cpython-37.pyc,,
49 | bindsnet/evaluation/__pycache__/evaluation.cpython-37.pyc,,
50 | bindsnet/evaluation/evaluation.py,sha256=Y_qS1PvvIc_YHjxEeKFqJ53p5VQSr-dFU89fJTKYemg,9093
51 | bindsnet/learning/__init__.py,sha256=zR7Zaao__m3rAcL1gNFUwaNooEWbJoJ7Om-E3VSn0zc,142
52 | bindsnet/learning/__pycache__/__init__.cpython-37.pyc,,
53 | bindsnet/learning/__pycache__/learning.cpython-37.pyc,,
54 | bindsnet/learning/__pycache__/reward.cpython-37.pyc,,
55 | bindsnet/learning/learning.py,sha256=5Ihyv7um2a3RGGJv_r7MpRa-KBUl0zaW8qr7c7PNCYY,32970
56 | bindsnet/learning/reward.py,sha256=4VeLl_T1kkLRWvhaVwVtD22c5e8oM1FSaMgqspDNnPg,2605
57 | bindsnet/models/__init__.py,sha256=Mt-J2zoeyxdoj0vlQ2XCciYoitAeJPqNoimN6PUc4ns,153
58 | bindsnet/models/__pycache__/__init__.cpython-37.pyc,,
59 | bindsnet/models/__pycache__/models.cpython-37.pyc,,
60 | bindsnet/models/models.py,sha256=PcXeFjT-fKHbH7F_qGgbaYf_IEQW8PDgsQb-7H6QBzs,18723
61 | bindsnet/network/__init__.py,sha256=2HHNEEtnmcyI5uAkRJAwby7hsWNls9a3nOrCxqPFZ2U,75
62 | bindsnet/network/__pycache__/__init__.cpython-37.pyc,,
63 | bindsnet/network/__pycache__/monitors.cpython-37.pyc,,
64 | bindsnet/network/__pycache__/network.cpython-37.pyc,,
65 | bindsnet/network/__pycache__/nodes.cpython-37.pyc,,
66 | bindsnet/network/__pycache__/topology.cpython-37.pyc,,
67 | bindsnet/network/monitors.py,sha256=S4i5HtEvAKsHkC9kd4ruvKtBcyEOrbtWUa-nhTHFVog,9678
68 | bindsnet/network/network.py,sha256=2WRqoi5nipJDy1uXOjhMvoqWLg6RzTO41W-HAudCAec,14894
69 | bindsnet/network/nodes.py,sha256=Fhp6BoBO0r-KkFIOfTt94sM4F3ZjxJbdo57ctncW0_o,49303
70 | bindsnet/network/topology.py,sha256=KKwOBuAB_f2OCFJJ_OrjWCzelgUvMr8QyAxbkUop7BY,29391
71 | bindsnet/pipeline/__init__.py,sha256=Sq9muSJcWFuvss53MuRBa5W4q1mSGcdAMrwaFGMjqMg,195
72 | bindsnet/pipeline/__pycache__/__init__.cpython-37.pyc,,
73 | bindsnet/pipeline/__pycache__/action.cpython-37.pyc,,
74 | bindsnet/pipeline/__pycache__/base_pipeline.cpython-37.pyc,,
75 | bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-37.pyc,,
76 | bindsnet/pipeline/__pycache__/environment_pipeline.cpython-37.pyc,,
77 | bindsnet/pipeline/action.py,sha256=qf7Vw5LvwUOx8t6in2NeU4JsWyweev5rkehdvGDggQw,3212
78 | bindsnet/pipeline/base_pipeline.py,sha256=J8viSSWYcKC7Jj4E0tNK9DpaJvqjcJwvDFPhz5PO1CE,7763
79 | bindsnet/pipeline/dataloader_pipeline.py,sha256=qsNJ4fweS3C8I5Fju6odK1x5qqswssd06d5VA_hxeQU,4889
80 | bindsnet/pipeline/environment_pipeline.py,sha256=ZazbqFmPhaw8DSCZSSt4qWupgyQ4tp1kLDuEs4nz4PA,6498
81 | bindsnet/preprocessing/__init__.py,sha256=IqFn8cXso7NvrnflQbkcmjcIBRuXMvtUX9D0GVXyr8A,48
82 | bindsnet/preprocessing/__pycache__/__init__.cpython-37.pyc,,
83 | bindsnet/preprocessing/__pycache__/preprocessing.cpython-37.pyc,,
84 | bindsnet/preprocessing/preprocessing.py,sha256=xMGKvEAGp_4AORzYUHaMnXNLs3aI4UQdFM7TwlyAygY,3248
85 | bindsnet/utils.py,sha256=VIE5vMCsZpIyje82QNFbnOOX1Sa-8xDhrbh5M6NhbeY,7538
86 |
--------------------------------------------------------------------------------
/bindsnet/bindsnet-0.2.5.dist-info/WHEEL:
--------------------------------------------------------------------------------
1 | Wheel-Version: 1.0
2 | Generator: bdist_wheel (0.33.4)
3 | Root-Is-Purelib: true
4 | Tag: cp37-none-any
5 |
6 |
--------------------------------------------------------------------------------
/bindsnet/bindsnet-0.2.5.dist-info/top_level.txt:
--------------------------------------------------------------------------------
1 | bindsnet
2 |
--------------------------------------------------------------------------------
/bindsnet/conversion/__init__.py:
--------------------------------------------------------------------------------
1 | from .conversion import (
2 | Permute,
3 | FeatureExtractor,
4 | SubtractiveResetIFNodes,
5 | PassThroughNodes,
6 | PermuteConnection,
7 | ConstantPad2dConnection,
8 | data_based_normalization,
9 | ann_to_snn,
10 | )
11 |
--------------------------------------------------------------------------------
/bindsnet/conversion/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/conversion/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/conversion/__pycache__/conversion.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/conversion/__pycache__/conversion.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .torchvision_wrapper import create_torchvision_dataset_wrapper
2 | from .spoken_mnist import SpokenMNIST
3 | from .davis import Davis
4 |
5 | from .collate import time_aware_collate
6 | from .dataloader import DataLoader
7 |
8 |
9 | CIFAR10 = create_torchvision_dataset_wrapper("CIFAR10")
10 | CIFAR100 = create_torchvision_dataset_wrapper("CIFAR100")
11 | Cityscapes = create_torchvision_dataset_wrapper("Cityscapes")
12 | CocoCaptions = create_torchvision_dataset_wrapper("CocoCaptions")
13 | CocoDetection = create_torchvision_dataset_wrapper("CocoDetection")
14 | DatasetFolder = create_torchvision_dataset_wrapper("DatasetFolder")
15 | EMNIST = create_torchvision_dataset_wrapper("EMNIST")
16 | FakeData = create_torchvision_dataset_wrapper("FakeData")
17 | FashionMNIST = create_torchvision_dataset_wrapper("FashionMNIST")
18 | Flickr30k = create_torchvision_dataset_wrapper("Flickr30k")
19 | Flickr8k = create_torchvision_dataset_wrapper("Flickr8k")
20 | ImageFolder = create_torchvision_dataset_wrapper("ImageFolder")
21 | KMNIST = create_torchvision_dataset_wrapper("KMNIST")
22 | LSUN = create_torchvision_dataset_wrapper("LSUN")
23 | LSUNClass = create_torchvision_dataset_wrapper("LSUNClass")
24 | MNIST = create_torchvision_dataset_wrapper("MNIST")
25 | Omniglot = create_torchvision_dataset_wrapper("Omniglot")
26 | PhotoTour = create_torchvision_dataset_wrapper("PhotoTour")
27 | SBU = create_torchvision_dataset_wrapper("SBU")
28 | SEMEION = create_torchvision_dataset_wrapper("SEMEION")
29 | STL10 = create_torchvision_dataset_wrapper("STL10")
30 | SVHN = create_torchvision_dataset_wrapper("SVHN")
31 | VOCDetection = create_torchvision_dataset_wrapper("VOCDetection")
32 | VOCSegmentation = create_torchvision_dataset_wrapper("VOCSegmentation")
33 | ImageNet = create_torchvision_dataset_wrapper("ImageNet")
34 |
--------------------------------------------------------------------------------
/bindsnet/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/datasets/__pycache__/collate.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/collate.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/datasets/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/datasets/__pycache__/davis.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/davis.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/datasets/__pycache__/preprocess.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/preprocess.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/datasets/__pycache__/spoken_mnist.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/spoken_mnist.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/datasets/collate.py:
--------------------------------------------------------------------------------
1 | r"""" This code is directly pulled from the pytorch version found at:
2 |
3 | https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py
4 |
5 | Modifications exist to have [time, batch, n_0, ... n_k] instead of batch
6 | in dimension 0.
7 | """
8 |
9 | import torch
10 | import re
11 | from torch._six import container_abcs, string_classes, int_classes
12 |
13 | from torch.utils.data._utils import collate as pytorch_collate
14 |
15 |
16 | def safe_worker_check():
17 | """ Method to check to used shared memory will change in a newer
18 | version of pytorch
19 | """
20 | try:
21 | return torch.utils.data.get_worker_info() is not None
22 | except:
23 | return pytorch_collate._use_shared_memory
24 |
25 |
26 | def time_aware_collate(batch):
27 | r"""Puts each data field into a tensor with dimensions [time,
28 | batch size, ...]
29 |
30 | Interpretation of dimensions being input:
31 | - 0 dim (,) - (1, batch_size, 1)
32 | - 1 dim (time,) - (time, batch_size, 1)
33 | - >2 dim (time, n_0, ...) - (time, batch_size, n_0, ...)
34 | """
35 |
36 | elem = batch[0]
37 | elem_type = type(elem)
38 | if isinstance(elem, torch.Tensor):
39 | # catch 0 and 1 dimension cases and view as specified
40 | if elem.dim() == 0:
41 | batch = [x.view((1, 1)) for x in batch]
42 | elif elem.dim() == 1:
43 | batch = [x.view((x.shape[0], 1)) for x in batch]
44 |
45 | out = None
46 | if safe_worker_check():
47 | # If we're in a background process, concatenate directly into a
48 | # shared memory tensor to avoid an extra copy
49 | numel = sum([x.numel() for x in batch])
50 | storage = elem.storage()._new_shared(numel)
51 | out = elem.new(storage)
52 | return torch.stack(batch, 1, out=out)
53 | elif (
54 | elem_type.__module__ == "numpy"
55 | and elem_type.__name__ != "str_"
56 | and elem_type.__name__ != "string_"
57 | ):
58 | elem = batch[0]
59 | if elem_type.__name__ == "ndarray":
60 | # array of string classes and object
61 | if (
62 | pytorch_collate.np_str_obj_array_pattern.search(elem.dtype.str)
63 | is not None
64 | ):
65 | raise TypeError(
66 | pytorch_collate.default_collate_err_msg_format.format(elem.dtype)
67 | )
68 |
69 | return time_aware_collate([torch.as_tensor(b) for b in batch])
70 | elif elem.shape == (): # scalars
71 | return torch.as_tensor(batch)
72 | elif isinstance(elem, float):
73 | return torch.tensor(batch, dtype=torch.float64)
74 | elif isinstance(elem, int_classes):
75 | return torch.tensor(batch)
76 | elif isinstance(elem, string_classes):
77 | return batch
78 | elif isinstance(elem, container_abcs.Mapping):
79 | return {key: time_aware_collate([d[key] for d in batch]) for key in elem}
80 | elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
81 | return elem_type(*(time_aware_collate(samples) for samples in zip(*batch)))
82 | elif isinstance(elem, container_abcs.Sequence):
83 | transposed = zip(*batch)
84 | return [time_aware_collate(samples) for samples in transposed]
85 |
86 | raise TypeError(pytorch_collate.default_collate_err_msg_format.format(elem_type))
87 |
--------------------------------------------------------------------------------
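A minimal usage sketch for time_aware_collate (assuming bindsnet is installed alongside the pinned PyTorch version that still provides torch._six): per-sample tensors of shape [time, n_0, ...] are stacked along dimension 1, giving a time-major batch.

import torch
from bindsnet.datasets.collate import time_aware_collate

# Four samples, each a [time, height, width] record tensor.
batch = [torch.rand(100, 28, 28) for _ in range(4)]
collated = time_aware_collate(batch)
print(collated.shape)  # torch.Size([100, 4, 28, 28]) -- [time, batch, n_0, n_1]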
/bindsnet/datasets/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .collate import time_aware_collate
4 |
5 |
6 | class DataLoader(torch.utils.data.DataLoader):
7 | def __init__(
8 | self,
9 | dataset,
10 | batch_size=1,
11 | shuffle=False,
12 | sampler=None,
13 | batch_sampler=None,
14 | num_workers=0,
15 | collate_fn=time_aware_collate,
16 | pin_memory=False,
17 | drop_last=False,
18 | timeout=0,
19 | worker_init_fn=None,
20 | ):
21 | super().__init__(
22 | dataset,
23 | sampler=sampler,
24 | shuffle=shuffle,
25 | batch_size=batch_size,
26 | drop_last=drop_last,
27 | pin_memory=pin_memory,
28 | timeout=timeout,
29 | num_workers=num_workers,
30 | worker_init_fn=worker_init_fn,
31 | batch_sampler=batch_sampler,
32 | collate_fn=collate_fn,
33 | )
34 |
--------------------------------------------------------------------------------
/bindsnet/datasets/davis.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import shutil
5 | import zipfile
6 | import sys
7 | import time
8 |
9 |
10 | from PIL import Image
11 | from glob import glob
12 | from tqdm import tqdm
13 | from collections import defaultdict
14 | from typing import Optional, Tuple, List, Iterable
15 | from urllib.request import urlretrieve
16 |
17 | import warnings
18 |
19 |
20 | class Davis(torch.utils.data.Dataset):
21 | SUBSET_OPTIONS = ["train", "val", "test-dev", "test-challenge"]
22 | TASKS = ["semi-supervised", "unsupervised"]
23 | RESOLUTION_OPTIONS = ["480p", "Full-Resolution"]
24 | DATASET_WEB = "https://davischallenge.org/davis2017/code.html"
25 | VOID_LABEL = 255
26 |
27 | def __init__(
28 | self,
29 | root,
30 | task="unsupervised",
31 | subset="train",
32 | sequences="all",
33 | resolution="480p",
34 | size=(600, 480),
35 | codalab=False,
36 | download=False,
37 | num_samples: int = -1,
38 | ):
39 | """
40 | Class to read the DAVIS dataset
41 | :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders.
42 | :param task: Task to load the annotations, choose between semi-supervised or unsupervised.
43 |         :param subset: Subset of the dataset to load; one of 'train', 'val', 'test-dev' or 'test-challenge'.
44 | :param sequences: Sequences to consider, 'all' to use all the sequences in a set.
45 |         :param resolution: Resolution of the dataset to use; choose between '480p' and 'Full-Resolution'.
46 | :param download: Specify whether to download the dataset if it is not present
47 | :param num_samples: Number of samples to pass to the batch
48 | """
49 | super().__init__()
50 |
51 | if subset not in self.SUBSET_OPTIONS:
52 | raise ValueError(f"Subset should be in {self.SUBSET_OPTIONS}")
53 | if task not in self.TASKS:
54 | raise ValueError(f"The only tasks that are supported are {self.TASKS}")
55 | if resolution not in self.RESOLUTION_OPTIONS:
56 | raise ValueError(
57 | f"You may only choose one of these resolutions: {self.RESOLUTION_OPTIONS}"
58 | )
59 |
60 | self.task = task
61 | self.subset = subset
62 | self.resolution = resolution
63 | self.size = size
64 |
65 |         # Set the converted flag if the images must be rescaled from the default (600, 480) size
66 |         self.converted = self.size != (600, 480)
67 |
68 | # Sets a tag for naming the folder containing the dataset
69 | self.tag = ""
70 | if self.task == "unsupervised":
71 | self.tag += "Unsupervised-"
72 | if self.subset == "train" or self.subset == "val":
73 | self.tag += "trainval"
74 | else:
75 | self.tag += self.subset
76 | self.tag += "-" + self.resolution
77 |
78 | # Makes a unique path for a given instance of davis
79 | self.converted_root = os.path.join(
80 | root, self.tag + "-" + str(self.size[0]) + "x" + str(self.size[1])
81 | )
82 | self.root = os.path.join(root, self.tag)
83 | self.download = download
84 | self.num_samples = num_samples
85 | self.zip_path = os.path.join(self.root, "repo.zip")
86 | self.img_path = os.path.join(self.root, "JPEGImages", resolution)
87 | annotations_folder = (
88 | "Annotations" if task == "semi-supervised" else "Annotations_unsupervised"
89 | )
90 | self.mask_path = os.path.join(self.root, annotations_folder, resolution)
91 | year = (
92 | "2019"
93 | if task == "unsupervised"
94 | and (subset == "test-dev" or subset == "test-challenge")
95 | else "2017"
96 | )
97 | self.imagesets_path = os.path.join(self.root, "ImageSets", year)
98 |
99 | # Makes a converted path for scaled images
100 | if self.converted:
101 | self.converted_img_path = os.path.join(
102 | self.converted_root, "JPEGImages", resolution
103 | )
104 | self.converted_mask_path = os.path.join(
105 | self.converted_root, annotations_folder, resolution
106 | )
107 | self.converted_imagesets_path = os.path.join(
108 | self.converted_root, "ImageSets", year
109 | )
110 |
111 |         # Sets sequences_names to the relevant sequences
112 | if sequences == "all":
113 | with open(
114 | os.path.join(self.imagesets_path, f"{self.subset}.txt"), "r"
115 | ) as f:
116 | tmp = f.readlines()
117 | self.sequences_names = [x.strip() for x in tmp]
118 | else:
119 | self.sequences_names = (
120 | sequences if isinstance(sequences, list) else [sequences]
121 | )
122 | self.sequences = defaultdict(dict)
123 |
124 | # Check if Davis is installed and download it if necessary
125 | self._check_directories()
126 |
127 | # Sets the images and masks for each sequence resizing for the given size
128 | for seq in self.sequences_names:
129 | images = np.sort(glob(os.path.join(self.img_path, seq, "*.jpg"))).tolist()
130 | if len(images) == 0 and not codalab:
131 | raise FileNotFoundError(f"Images for sequence {seq} not found.")
132 | self.sequences[seq]["images"] = images
133 | masks = np.sort(glob(os.path.join(self.mask_path, seq, "*.png"))).tolist()
134 | masks.extend([-1] * (len(images) - len(masks)))
135 | self.sequences[seq]["masks"] = masks
136 |
137 | # Creates an enumeration for the sequences for __getitem__
138 | self.enum_sequences = []
139 | for seq in self.sequences_names:
140 | self.enum_sequences.append(self.sequences[seq])
141 |
142 | def __len__(self):
143 | """
144 | Calculates the number of sequences the dataset holds
145 |
146 | :return: the number of sequences in the dataset
147 | """
148 | return len(self.sequences)
149 |
150 | def _convert_sequences(self):
151 | """
152 |         Creates a new root for the converted dataset, then copies each image and mask, rescaled to the given size, into it.
153 | """
154 |         os.makedirs(self.converted_imagesets_path)  # the directory itself, not the subset .txt path
155 | os.makedirs(self.converted_img_path)
156 | os.makedirs(self.converted_mask_path)
157 |
158 | shutil.copy(
159 | os.path.join(self.imagesets_path, f"{self.subset}.txt"),
160 | os.path.join(self.converted_imagesets_path, f"{self.subset}.txt"),
161 | )
162 |
163 | print("Converting sequences to size: {0}".format(self.size))
164 | for seq in tqdm(self.sequences_names):
165 | os.makedirs(os.path.join(self.converted_img_path, seq))
166 | os.makedirs(os.path.join(self.converted_mask_path, seq))
167 | images = np.sort(glob(os.path.join(self.img_path, seq, "*.jpg"))).tolist()
168 |             if len(images) == 0:
169 | raise FileNotFoundError(f"Images for sequence {seq} not found.")
170 | for ind, img in enumerate(images):
171 | im = Image.open(img)
172 | im.thumbnail(self.size, Image.ANTIALIAS)
173 | im.save(
174 | os.path.join(
175 | self.converted_img_path, seq, str(ind).zfill(5) + ".jpg"
176 | )
177 | )
178 | masks = np.sort(glob(os.path.join(self.mask_path, seq, "*.png"))).tolist()
179 | for ind, msk in enumerate(masks):
180 | im = Image.open(msk)
181 | im.thumbnail(self.size, Image.ANTIALIAS)
182 | im.convert("RGB").save(
183 | os.path.join(
184 | self.converted_mask_path, seq, str(ind).zfill(5) + ".png"
185 | )
186 | )
187 |
188 | def _check_directories(self):
189 | """
190 | Verifies that the correct dataset is downloaded; downloads if it isn't and download=True.
191 |
192 | :raises: FileNotFoundError if the subset sequence, annotation or root folder is missing.
193 | """
194 | if not os.path.exists(self.root):
195 | if self.download:
196 | self._download()
197 | else:
198 | raise FileNotFoundError(
199 | f"DAVIS not found in the specified directory, download it from {self.DATASET_WEB} or add download=True to your call"
200 | )
201 | if not os.path.exists(os.path.join(self.imagesets_path, f"{self.subset}.txt")):
202 | raise FileNotFoundError(
203 | f"Subset sequences list for {self.subset} not found, download the missing subset "
204 | f"for the {self.task} task from {self.DATASET_WEB}"
205 | )
206 | if self.subset in ["train", "val"] and not os.path.exists(self.mask_path):
207 | raise FileNotFoundError(
208 | f"Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}"
209 | )
210 | if self.converted:
211 | if not os.path.exists(self.converted_img_path):
212 | self._convert_sequences()
213 | self.img_path = self.converted_img_path
214 | self.mask_path = self.converted_mask_path
215 | self.imagesets_path = self.converted_imagesets_path
216 |
217 | def get_frames(self, sequence):
218 | for img, msk in zip(
219 | self.sequences[sequence]["images"], self.sequences[sequence]["masks"]
220 | ):
221 | image = np.array(Image.open(img))
222 |             mask = None if msk == -1 else np.array(Image.open(msk))  # missing masks were padded with -1
223 | yield image, mask
224 |
225 | def _get_all_elements(self, sequence, obj_type):
226 | obj = np.array(Image.open(self.sequences[sequence][obj_type][0]))
227 | all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape))
228 | obj_id = []
229 | for i, obj in enumerate(self.sequences[sequence][obj_type]):
230 | all_objs[i, ...] = np.array(Image.open(obj))
231 | obj_id.append("".join(obj.split("/")[-1].split(".")[:-1]))
232 | return all_objs, obj_id
233 |
234 | def get_all_images(self, sequence):
235 | return self._get_all_elements(sequence, "images")
236 |
237 | def get_all_masks(self, sequence, separate_objects_masks=False):
238 | masks, masks_id = self._get_all_elements(sequence, "masks")
239 | masks_void = np.zeros_like(masks)
240 |
241 | # Separate void and object masks
242 | for i in range(masks.shape[0]):
243 | masks_void[i, ...] = masks[i, ...] == 255
244 | masks[i, masks[i, ...] == 255] = 0
245 |
246 | if separate_objects_masks:
247 | num_objects = int(np.max(masks[0, ...]))
248 | tmp = np.ones((num_objects, *masks.shape))
249 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None]
250 | masks = tmp == masks[None, ...]
251 | masks = masks > 0
252 | return masks, masks_void, masks_id
253 |
254 | def get_sequences(self):
255 | for seq in self.sequences:
256 | yield seq
257 |
258 | def _download(self):
259 | """
260 | Downloads the correct dataset based on the given parameters
261 |
262 |         Relies on self.tag to determine both the name of the folder created for the dataset and the correct download URL.
263 | """
264 |
265 | os.makedirs(self.root)
266 |
267 | # Grabs the correct zip url based on parameters
268 | zip_url = f"https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-{self.tag}.zip"
269 |
270 | print("\nDownloading Davis data set from " + zip_url + "\n")
271 |
272 | # Downloads the relevant dataset
273 | urlretrieve(zip_url, self.zip_path, reporthook=self.progress)
274 |
275 | print("\nDone! \n\nUnzipping and restructuring")
276 |
277 | # Extracts the dataset
278 | z = zipfile.ZipFile(self.zip_path, "r")
279 | z.extractall(path=self.root)
280 | z.close()
281 | os.remove(self.zip_path)
282 |
283 |         temp_folder = os.path.join(self.root, "DAVIS")
284 |
285 |         # Deletes an unnecessary containing folder "DAVIS" which comes with every download
286 |         for file in os.listdir(temp_folder):
287 |             shutil.move(os.path.join(temp_folder, file), self.root)
288 | cwd = os.getcwd()
289 | os.chdir(self.root)
290 | os.rmdir("DAVIS")
291 | os.chdir(cwd)
292 |
293 | print("\nDone!\n")
294 |
295 | def __getitem__(self, ind):
296 | """
297 | Gets an item of the Dataset based on index
298 |
299 | :param ind: index of item to take from dataset
300 |
301 | :return: a sequence which contains a list of images and masks
302 | """
303 | seq = self.enum_sequences[ind]
304 | return seq
305 |
306 | # Simple progress indicator for the download of the dataset
307 | def progress(self, count, block_size, total_size):
308 | global start_time
309 | if count == 0:
310 | start_time = time.time()
311 | return
312 | duration = time.time() - start_time
313 | progress_size = int(count * block_size)
314 | speed = int(progress_size / (1024 * duration))
315 | percent = min(int(count * block_size * 100 / total_size), 100)
316 | sys.stdout.write(
317 | "\r...%d%%, %d MB, %d KB/s, %d seconds passed"
318 | % (percent, progress_size / (1024 * 1024), speed, duration)
319 | )
320 | sys.stdout.flush()
321 |
--------------------------------------------------------------------------------
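A minimal sketch of reading an already-downloaded DAVIS 2017 split with the class above (the ./data/DAVIS root is an illustrative assumption; pass download=True to fetch the data instead):

from bindsnet.datasets.davis import Davis

# Assumes the unsupervised trainval 480p data already lives under ./data/DAVIS.
davis = Davis(root="./data/DAVIS", task="unsupervised", subset="train", resolution="480p")
sequence = davis[0]  # dict with per-frame "images" and "masks" path lists
print(len(sequence["images"]), len(sequence["masks"]))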
/bindsnet/datasets/preprocess.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def gray_scale(image: np.ndarray) -> np.ndarray:
6 | # language=rst
7 | """
8 | Converts RGB image into grayscale.
9 |
10 | :param image: RGB image.
11 | :return: Gray-scaled image.
12 | """
13 | return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
14 |
15 |
16 | def crop(image: np.ndarray, x1: int, x2: int, y1: int, y2: int) -> np.ndarray:
17 | # language=rst
18 | """
19 | Crops an image given coordinates of cropping box.
20 |
21 |     :param image: 3-dimensional image.
22 |     :param x1: Starting index along the first (row) dimension.
23 |     :param x2: Ending index along the first (row) dimension.
24 |     :param y1: Starting index along the second (column) dimension.
25 |     :param y2: Ending index along the second (column) dimension.
26 |     :return: Image cropped to ``image[x1:x2, y1:y2, :]``.
27 | """
28 | return image[x1:x2, y1:y2, :]
29 |
30 |
31 | def binary_image(image: np.ndarray) -> np.ndarray:
32 | # language=rst
33 | """
34 |     Converts input image into black and white (binary).
35 |
36 | :param image: Gray-scaled image.
37 | :return: Black and white image.
38 | """
39 | return cv2.threshold(image, 0, 1, cv2.THRESH_BINARY)[1]
40 |
41 |
42 | def subsample(image: np.ndarray, x: int, y: int) -> np.ndarray:
43 | # language=rst
44 | """
45 | Scale the image to (x, y).
46 |
47 | :param image: Image to be rescaled.
48 | :param x: Output value for ``image``'s x dimension.
49 | :param y: Output value for ``image``'s y dimension.
50 | :return: Re-scaled image.
51 | """
52 | return cv2.resize(image, (x, y))
53 |
--------------------------------------------------------------------------------
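A minimal sketch (assuming opencv-python and numpy are installed) chaining these helpers the same way the Breakout pre-processing in environment.py does; the input frame is random stand-in data:

import numpy as np
from bindsnet.datasets.preprocess import gray_scale, crop, binary_image, subsample

frame = np.random.randint(0, 256, (210, 160, 3), dtype=np.uint8)  # fake RGB frame
obs = crop(frame, 34, 194, 0, 160)  # keep rows 34:194 and columns 0:160
obs = gray_scale(obs)               # (160, 160) grayscale
obs = subsample(obs, 80, 80)        # resize to 80x80
obs = binary_image(obs)             # threshold to {0, 1}
print(obs.shape)                    # (80, 80)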
/bindsnet/datasets/spoken_mnist.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, List, Iterable
2 | import os
3 | import torch
4 | import numpy as np
5 | import shutil
6 | import zipfile
7 |
8 |
9 | from urllib.request import urlretrieve
10 | from scipy.io import wavfile
11 |
12 | import warnings
13 |
14 |
15 | class SpokenMNIST(torch.utils.data.Dataset):
16 | # language=rst
17 | """
18 | Handles loading and saving of the Spoken MNIST audio dataset `(link)
19 |     <https://github.com/Jakobovski/free-spoken-digit-dataset>`_.
20 | """
21 | train_pickle = "train.pt"
22 | test_pickle = "test.pt"
23 |
24 | url = "https://github.com/Jakobovski/free-spoken-digit-dataset/archive/master.zip"
25 |
26 | files = []
27 | for digit in range(10):
28 | for speaker in ["jackson", "nicolas", "theo"]:
29 | for example in range(50):
30 | files.append("_".join([str(digit), speaker, str(example)]) + ".wav")
31 |
32 | n_files = len(files)
33 |
34 | def __init__(
35 | self,
36 | path: str,
37 | download: bool = False,
38 | shuffle: bool = True,
39 | train: bool = True,
40 | split: float = 0.8,
41 | num_samples: int = -1,
42 | ) -> None:
43 | # language=rst
44 | """
45 | Constructor for the ``SpokenMNIST`` object. Makes the data directory if it doesn't already exist.
46 |
47 | :param path: Pathname of directory in which to store the dataset.
48 | :param download: Whether or not to download the dataset (requires internet connection).
49 | :param shuffle: Whether to randomly permute order of dataset.
50 |         :param train: Load the training split if true, else load the test split.
51 | :param split: Train, test split; in range ``(0, 1)``.
52 |         :param num_samples: Number of audio frames to keep from the start of each example.
53 | """
54 | super().__init__()
55 |
56 | if not os.path.isdir(path):
57 | os.makedirs(path)
58 |
59 | self.path = path
60 | self.download = download
61 | self.shuffle = shuffle
62 |
63 | self.zip_path = os.path.join(path, "repo.zip")
64 |
65 | if train:
66 | self.audio, self.labels = self._get_train(split)
67 | else:
68 | self.audio, self.labels = self._get_test(split)
69 |
70 | self.num_samples = num_samples
71 |
72 | def __len__(self):
73 | return len(self.audio)
74 |
75 | def __getitem__(self, ind):
76 | audio = self.audio[ind][: self.num_samples, :]
77 | label = self.labels[ind]
78 |
79 | return {"audio": audio, "label": label}
80 |
81 | def _get_train(self, split: float = 0.8) -> Tuple[torch.Tensor, torch.Tensor]:
82 | # language=rst
83 | """
84 | Gets the Spoken MNIST training audio and labels.
85 |
86 | :param split: Train, test split; in range ``(0, 1)``.
87 | :return: Spoken MNIST training audio and labels.
88 | """
89 | split_index = int(split * SpokenMNIST.n_files)
90 | path = os.path.join(self.path, "_".join([SpokenMNIST.train_pickle, str(split)]))
91 |
92 | if not all([os.path.isfile(os.path.join(self.path, f)) for f in self.files]):
93 | # Download data if it isn't on disk.
94 | if self.download:
95 | print("Downloading Spoken MNIST data.\n")
96 | self._download()
97 |
98 | # Process data into audio, label (input, output) pairs.
99 | audio, labels = self.process_data(SpokenMNIST.files[:split_index])
100 |
101 | # Serialize image data on disk for next time.
102 | torch.save((audio, labels), open(path, "wb"))
103 | else:
104 | msg = "Dataset not found on disk; specify 'download=True' to allow downloads."
105 | raise FileNotFoundError(msg)
106 | else:
107 |             if not os.path.isfile(path):
108 | # Process image and label data if pickled file doesn't exist.
109 | audio, labels = self.process_data(SpokenMNIST.files)
110 |
111 | # Serialize image data on disk for next time.
112 | torch.save((audio, labels), open(path, "wb"))
113 | else:
114 | # Load image data from disk if it has already been processed.
115 | print("Loading training data from serialized object file.\n")
116 | audio, labels = torch.load(open(path, "rb"))
117 |
118 | labels = torch.Tensor(labels)
119 |
120 | if self.shuffle:
121 | perm = np.random.permutation(np.arange(labels.shape[0]))
122 | audio, labels = [torch.Tensor(audio[_]) for _ in perm], labels[perm]
123 |
124 | return audio, torch.Tensor(labels)
125 |
126 | def _get_test(self, split: float = 0.8) -> Tuple[torch.Tensor, List[torch.Tensor]]:
127 | # language=rst
128 | """
129 |         Gets the Spoken MNIST test audio and labels.
130 |
131 | :param split: Train, test split; in range ``(0, 1)``.
132 | :return: The Spoken MNIST test audio and labels.
133 | """
134 | split_index = int(split * SpokenMNIST.n_files)
135 | path = os.path.join(self.path, "_".join([SpokenMNIST.test_pickle, str(split)]))
136 |
137 | if not all([os.path.isfile(os.path.join(self.path, f)) for f in self.files]):
138 | # Download data if it isn't on disk.
139 | if self.download:
140 | print("Downloading Spoken MNIST data.\n")
141 | self._download()
142 |
143 | # Process data into audio, label (input, output) pairs.
144 | audio, labels = self.process_data(SpokenMNIST.files[split_index:])
145 |
146 | # Serialize image data on disk for next time.
147 | torch.save((audio, labels), open(path, "wb"))
148 | else:
149 | msg = "Dataset not found on disk; specify 'download=True' to allow downloads."
150 | raise FileNotFoundError(msg)
151 | else:
152 |             if not os.path.isfile(path):
153 | # Process image and label data if pickled file doesn't exist.
154 | audio, labels = self.process_data(SpokenMNIST.files)
155 |
156 | # Serialize image data on disk for next time.
157 | torch.save((audio, labels), open(path, "wb"))
158 | else:
159 | # Load image data from disk if it has already been processed.
160 | print("Loading test data from serialized object file.\n")
161 | audio, labels = torch.load(open(path, "rb"))
162 |
163 | labels = torch.Tensor(labels)
164 |
165 | if self.shuffle:
166 | perm = np.random.permutation(np.arange(labels.shape[0]))
167 |             audio, labels = [torch.Tensor(audio[_]) for _ in perm], labels[perm]
168 |
169 | return audio, torch.Tensor(labels)
170 |
171 | def _download(self) -> None:
172 | # language=rst
173 | """
174 | Downloads and unzips all Spoken MNIST data.
175 | """
176 | urlretrieve(SpokenMNIST.url, self.zip_path)
177 |
178 | z = zipfile.ZipFile(self.zip_path, "r")
179 | z.extractall(path=self.path)
180 | z.close()
181 |
182 | path = os.path.join(self.path, "free-spoken-digit-dataset-master", "recordings")
183 | for f in os.listdir(path):
184 |             shutil.move(os.path.join(path, f), self.path)
185 |
186 | cwd = os.getcwd()
187 | os.chdir(self.path)
188 | shutil.rmtree("free-spoken-digit-dataset-master")
189 | os.chdir(cwd)
190 |
191 | def process_data(
192 | self, file_names: Iterable[str]
193 | ) -> Tuple[List[torch.Tensor], torch.Tensor]:
194 | # language=rst
195 | """
196 | Opens files of Spoken MNIST data and processes them into ``numpy`` arrays.
197 |
198 | :param file_names: Names of the files containing Spoken MNIST audio to load.
199 | :return: Processed Spoken MNIST audio and label data.
200 | """
201 | audio, labels = [], []
202 |
203 | for f in file_names:
204 | label = int(f.split("_")[0])
205 |
206 | sample_rate, signal = wavfile.read(os.path.join(self.path, f))
207 | pre_emphasis = 0.97
208 | emphasized_signal = np.append(
209 | signal[0], signal[1:] - pre_emphasis * signal[:-1]
210 | )
211 |
212 | # Popular settings are 25 ms for the frame size and a 10 ms stride (15 ms overlap)
213 | frame_size = 0.025
214 | frame_stride = 0.01
215 |
216 | # Convert from seconds to samples
217 | frame_length, frame_step = (
218 | frame_size * sample_rate,
219 | frame_stride * sample_rate,
220 | )
221 | signal_length = len(emphasized_signal)
222 | frame_length = int(round(frame_length))
223 | frame_step = int(round(frame_step))
224 |
225 | # Make sure that we have at least 1 frame
226 | num_frames = int(
227 | np.ceil(float(np.abs(signal_length - frame_length)) / frame_step)
228 | )
229 |
230 | pad_signal_length = num_frames * frame_step + frame_length
231 | z = np.zeros((pad_signal_length - signal_length))
232 | pad_signal = np.append(emphasized_signal, z) # Pad signal
233 |
234 | indices = (
235 | np.tile(np.arange(0, frame_length), (num_frames, 1))
236 | + np.tile(
237 | np.arange(0, num_frames * frame_step, frame_step), (frame_length, 1)
238 | ).T
239 | )
240 | frames = pad_signal[indices.astype(np.int32, copy=False)]
241 |
242 | # Hamming Window
243 | frames *= np.hamming(frame_length)
244 |
245 | # Fast Fourier Transform and Power Spectrum
246 | NFFT = 512
247 | mag_frames = np.absolute(np.fft.rfft(frames, NFFT)) # Magnitude of the FFT
248 | pow_frames = (1.0 / NFFT) * (mag_frames ** 2) # Power Spectrum
249 |
250 | # Log filter banks
251 | nfilt = 40
252 | low_freq_mel = 0
253 | high_freq_mel = 2595 * np.log10(
254 | 1 + (sample_rate / 2) / 700
255 | ) # Convert Hz to Mel
256 | mel_points = np.linspace(
257 | low_freq_mel, high_freq_mel, nfilt + 2
258 | ) # Equally spaced in Mel scale
259 | hz_points = 700 * (10 ** (mel_points / 2595) - 1) # Convert Mel to Hz
260 | bin = np.floor((NFFT + 1) * hz_points / sample_rate)
261 |
262 | fbank = np.zeros((nfilt, int(np.floor(NFFT / 2 + 1))))
263 | for m in range(1, nfilt + 1):
264 | f_m_minus = int(bin[m - 1]) # left
265 | f_m = int(bin[m]) # center
266 | f_m_plus = int(bin[m + 1]) # right
267 |
268 | for k in range(f_m_minus, f_m):
269 | fbank[m - 1, k] = (k - bin[m - 1]) / (bin[m] - bin[m - 1])
270 | for k in range(f_m, f_m_plus):
271 | fbank[m - 1, k] = (bin[m + 1] - k) / (bin[m + 1] - bin[m])
272 |
273 | filter_banks = np.dot(pow_frames, fbank.T)
274 | filter_banks = np.where(
275 | filter_banks == 0, np.finfo(float).eps, filter_banks
276 | ) # Numerical Stability
277 | filter_banks = 20 * np.log10(filter_banks) # dB
278 |
279 | audio.append(filter_banks), labels.append(label)
280 |
281 | return audio, torch.Tensor(labels)
282 |
--------------------------------------------------------------------------------
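The filter-bank construction above hinges on the Hz-to-Mel conversion; a standalone sketch of those two formulas, assuming the 8 kHz sample rate of the free-spoken-digit recordings:

import numpy as np

sample_rate, nfilt, NFFT = 8000, 40, 512
high_freq_mel = 2595 * np.log10(1 + (sample_rate / 2) / 700)  # Hz -> Mel
mel_points = np.linspace(0, high_freq_mel, nfilt + 2)         # equally spaced on the Mel scale
hz_points = 700 * (10 ** (mel_points / 2595) - 1)             # Mel -> Hz
bins = np.floor((NFFT + 1) * hz_points / sample_rate)         # FFT bin edges of the triangular filters
print(bins[:5], bins[-1])  # low-frequency filters are narrow; the last edge sits near NFFT / 2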
/bindsnet/datasets/torchvision_wrapper.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 |
3 | import torch
4 | import torchvision
5 |
6 | from ..encoding import Encoder, NullEncoder
7 |
8 |
9 | def create_torchvision_dataset_wrapper(ds_type):
10 | """ Creates wrapper classes for datasets that output (image, label)
11 | from __getitem__. This is all of the datasets inside of torchvision.
12 | """
13 |
14 |     if isinstance(ds_type, str):
15 | ds_type = getattr(torchvision.datasets, ds_type)
16 |
17 | class TorchvisionDatasetWrapper(ds_type):
18 | __doc__ = (
19 | """BindsNET torchvision dataset wrapper for:
20 |
21 |         The core difference is that the output of __getitem__ is no longer
22 |         (image, label) but rather a dictionary containing the image, label,
23 | and their encoded versions if encoders were provided.
24 |
25 | \n\n"""
26 | + str(ds_type)
27 | if ds_type.__doc__ is None
28 | else ds_type.__doc__
29 | )
30 |
31 | def __init__(
32 | self,
33 | image_encoder: Optional[Encoder] = None,
34 | label_encoder: Optional[Encoder] = None,
35 | *args,
36 | **kwargs
37 | ):
38 | # language=rst
39 | """
40 | Constructor for the BindsNET torchvision dataset wrapper.
41 | For details on the dataset you're interested in visit
42 |
43 | https://pytorch.org/docs/stable/torchvision/datasets.html
44 |
45 | :param image_encoder: Spike encoder for use on the image
46 | :param label_encoder: Spike encoder for use on the label
47 | :param *args: Arguments for the original dataset
48 | :param **kwargs: Keyword arguments for the original dataset
49 | """
50 | super().__init__(*args, **kwargs)
51 |
52 | self.args = args
53 | self.kwargs = kwargs
54 |
55 | # Allow the passthrough of None, but change to NullEncoder
56 | if image_encoder is None:
57 | image_encoder = NullEncoder()
58 |
59 | if label_encoder is None:
60 | label_encoder = NullEncoder()
61 |
62 | self.image_encoder = image_encoder
63 | self.label_encoder = label_encoder
64 |
65 | def __getitem__(self, ind: int) -> Dict[str, torch.Tensor]:
66 | """
67 | Utilizes the torchvision.dataset parent class to grab the
68 | data, then encodes using the supplied encoders.
69 |
70 | :param int ind: Index to grab data at
71 |
72 | :return: The relevant data and encoded data from the
73 | requested index.
74 | """
75 |
76 | image, label = super().__getitem__(ind)
77 |
78 | output = {
79 | "image": image,
80 | "label": label,
81 | "encoded_image": self.image_encoder(image),
82 | "encoded_label": self.label_encoder(label),
83 | }
84 |
85 | return output
86 |
87 | def __len__(self):
88 | return super().__len__()
89 |
90 | return TorchvisionDatasetWrapper
91 |
--------------------------------------------------------------------------------
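A minimal sketch of wrapping torchvision's MNIST with a Poisson spike encoder (the ./data/MNIST root, 250 ms window and 128 Hz intensity scale are illustrative choices, not fixed by the codebase):

from torchvision import transforms

from bindsnet.datasets.torchvision_wrapper import create_torchvision_dataset_wrapper
from bindsnet.encoding import PoissonEncoder

MNISTWrapper = create_torchvision_dataset_wrapper("MNIST")
dataset = MNISTWrapper(
    image_encoder=PoissonEncoder(time=250, dt=1.0),
    root="./data/MNIST",
    download=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * 128)]  # pixel values -> firing rates in Hz
    ),
)
sample = dataset[0]
print(sample["encoded_image"].shape)  # torch.Size([250, 1, 28, 28]) -- [time, channel, height, width]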
/bindsnet/encoding/__init__.py:
--------------------------------------------------------------------------------
1 | from .encodings import single, repeat, bernoulli, poisson, rank_order
2 | from .loaders import bernoulli_loader, poisson_loader, rank_order_loader
3 | from .encoders import (
4 | Encoder,
5 | NullEncoder,
6 | SingleEncoder,
7 | RepeatEncoder,
8 | BernoulliEncoder,
9 | PoissonEncoder,
10 | RankOrderEncoder,
11 | )
12 |
--------------------------------------------------------------------------------
/bindsnet/encoding/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/encoding/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/encoding/__pycache__/encoders.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/encoding/__pycache__/encoders.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/encoding/__pycache__/encodings.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/encoding/__pycache__/encodings.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/encoding/__pycache__/loaders.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/encoding/__pycache__/loaders.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/encoding/encoders.py:
--------------------------------------------------------------------------------
1 | from . import encodings
2 |
3 |
4 | class Encoder:
5 | """
6 |     Base class for spike encoding transforms.
7 |
8 | - Calls self.enc from the subclass and passes whatever arguments were
9 | provided. self.enc must be callable with torch.Tensor, *args, **kwargs
10 | """
11 |
12 | def __init__(self, *args, **kwargs) -> None:
13 | self.enc_args = args
14 | self.enc_kwargs = kwargs
15 |
16 | def __call__(self, img):
17 | return self.enc(img, *self.enc_args, **self.enc_kwargs)
18 |
19 |
20 | class NullEncoder(Encoder):
21 | """
22 | Pass through of the datum that was input.
23 |
24 | WARNING - this is not a real encoder into spikes. Be careful with
25 | the usage of this class.
26 | """
27 |
28 | def __init__(self):
29 | super().__init__()
30 |
31 | def __call__(self, img):
32 | return img
33 |
34 |
35 | class SingleEncoder(Encoder):
36 | def __init__(self, time: int, dt: float = 1.0, sparsity: float = 0.3, **kwargs):
37 | """
38 | Creates a callable SingleEncoder which encodes as defined in
39 | :code:`bindsnet.encodings.single`
40 |
41 | :param time: Length of single spike train per input variable.
42 | :param dt: Simulation time step.
43 | :param sparsity: Sparsity of the input representation. 0 for no spike and 1 for all spike.
44 | """
45 | super().__init__(time, dt=dt, sparsity=sparsity, **kwargs)
46 |
47 | self.enc = encodings.single
48 |
49 |
50 | class RepeatEncoder(Encoder):
51 | def __init__(self, time: int, dt: float = 1.0, **kwargs):
52 | """
53 | Creates a callable RepeatEncoder which encodes as defined in
54 | :code:`bindsnet.encodings.repeat`
55 |
56 | :param time: Length of repeat spike train per input variable.
57 | :param dt: Simulation time step.
58 | """
59 | super().__init__(time, dt=dt, **kwargs)
60 |
61 | self.enc = encodings.repeat
62 |
63 |
64 | class BernoulliEncoder(Encoder):
65 | def __init__(self, time: int, dt: float = 1.0, **kwargs):
66 | """
67 | Creates a callable BernoulliEncoder which encodes as defined in
68 | :code:`bindsnet.encodings.bernoulli`
69 |
70 | :param time: Length of Bernoulli spike train per input variable.
71 | :param dt: Simulation time step.
72 |
73 | Keyword arguments:
74 |
75 | :param float max_prob: Maximum probability of spike per Bernoulli trial.
76 | """
77 | super().__init__(time, dt=dt, **kwargs)
78 |
79 | self.enc = encodings.bernoulli
80 |
81 |
82 | class PoissonEncoder(Encoder):
83 | def __init__(self, time: int, dt: float = 1.0, **kwargs):
84 | """
85 | Creates a callable PoissonEncoder which encodes as defined in
86 | :code:`bindsnet.encodings.poisson`
87 |
88 | :param time: Length of Poisson spike train per input variable.
89 | :param dt: Simulation time step.
90 | """
91 | super().__init__(time, dt=dt, **kwargs)
92 |
93 | self.enc = encodings.poisson
94 |
95 |
96 | class RankOrderEncoder(Encoder):
97 | def __init__(self, time: int, dt: float = 1.0, **kwargs):
98 | """
99 | Creates a callable RankOrderEncoder which encodes as defined in
100 | :code:`bindsnet.encodings.rank_order`
101 |
102 | :param time: Length of RankOrder spike train per input variable.
103 | :param dt: Simulation time step.
104 | """
105 | super().__init__(time, dt=dt, **kwargs)
106 |
107 | self.enc = encodings.rank_order
108 |
--------------------------------------------------------------------------------
/bindsnet/encoding/encodings.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import numpy as np
5 |
6 |
7 | def single(
8 | datum: torch.Tensor, time: int, dt: float = 1.0, sparsity: float = 0.3, **kwargs
9 | ) -> torch.Tensor:
10 | # language=rst
11 | """
12 |     Generates timing-based single-spike encoding. A spike occurs earlier if the
13 |     intensity of the input feature is higher. Features whose value is below the
14 |     sparsity threshold remain silent.
15 |
16 | :param datum: Tensor of shape ``[n_1, ..., n_k]``.
17 | :param time: Length of the input and output.
18 | :param dt: Simulation time step.
19 | :param sparsity: Sparsity of the input representation. 0 for no spike and 1 for all spike.
20 | :return: Tensor of shape ``[time, n_1, ..., n_k]``.
21 | """
22 | time = int(time / dt)
23 | shape = list(datum.shape)
24 | datum = np.copy(datum)
25 | quantile = np.quantile(datum, 1 - sparsity)
26 | s = np.zeros([time, *shape])
27 | s[0] = np.where(datum > quantile, np.ones(shape), np.zeros(shape))
28 | return torch.Tensor(s).byte()
29 |
30 |
31 | def repeat(datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs) -> torch.Tensor:
32 | # language=rst
33 | """
34 |     :param datum: Tensor of shape ``[n_1, ..., n_k]`` to be repeated along a new 0th dimension for ``int(time / dt)`` timesteps.
35 |     :param time: Length of the repeated spike train per input variable.
36 |     :param dt: Simulation time step.
37 | :return: Tensor of shape ``[time, n_1, ..., n_k]`` of repeated data along the 0th dimension.
38 | """
39 | time = int(time / dt)
40 | return datum.repeat([time, *([1] * len(datum.shape))])
41 |
42 |
43 | def bernoulli(
44 | datum: torch.Tensor, time: Optional[int] = None, dt: float = 1.0, **kwargs
45 | ) -> torch.Tensor:
46 | # language=rst
47 | """
48 |     :param datum: Tensor of shape ``[n_1, ..., n_k]``; generates Bernoulli-distributed spike trains based on input
49 |         intensity. Inputs must be non-negative. Spikes correspond to successful Bernoulli trials, with success
50 |         probability equal to the (normalized in [0, 1]) input value.
51 |     :param time: Length of Bernoulli spike train per input variable.
52 | :param dt: Simulation time step.
53 | :return: Tensor of shape ``[time, n_1, ..., n_k]`` of Bernoulli-distributed spikes.
54 |
55 | Keyword arguments:
56 |
57 | :param float max_prob: Maximum probability of spike per Bernoulli trial.
58 | """
59 | # Setting kwargs.
60 | max_prob = kwargs.get("max_prob", 1.0)
61 |
62 | assert 0 <= max_prob <= 1, "Maximum firing probability must be in range [0, 1]"
63 | assert (datum >= 0).all(), "Inputs must be non-negative"
64 |
65 | shape, size = datum.shape, datum.numel()
66 | datum = datum.view(-1)
67 |
68 | if time is not None:
69 | time = int(time / dt)
70 |
71 | # Normalize inputs and rescale (spike probability proportional to normalized intensity).
72 | if datum.max() > 1.0:
73 | datum /= datum.max()
74 |
75 | # Make spike data from Bernoulli sampling.
76 | if time is None:
77 | spikes = torch.bernoulli(max_prob * datum)
78 | spikes = spikes.view(*shape)
79 | else:
80 | spikes = torch.bernoulli(max_prob * datum.repeat([time, 1]))
81 | spikes = spikes.view(time, *shape)
82 |
83 | return spikes.byte()
84 |
85 |
86 | def poisson(datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs) -> torch.Tensor:
87 | # language=rst
88 | """
89 | Generates Poisson-distributed spike trains based on input intensity. Inputs must be non-negative, and give the
90 |     firing rate in Hz. Inter-spike intervals (ISIs) for non-zero data are incremented by one to avoid zero intervals
91 |     while maintaining ISI distributions.
92 |
93 |     For example, an input of intensity :code:`x` will have an average firing rate of :code:`x` Hz.
94 |
95 | :param datum: Tensor of shape ``[n_1, ..., n_k]``.
96 | :param time: Length of Poisson spike train per input variable.
97 | :param dt: Simulation time step.
98 | :return: Tensor of shape ``[time, n_1, ..., n_k]`` of Poisson-distributed spikes.
99 | """
100 | assert (datum >= 0).all(), "Inputs must be non-negative"
101 |
102 | # Get shape and size of data.
103 | shape, size = datum.shape, datum.numel()
104 | datum = datum.view(-1)
105 | time = int(time / dt)
106 |
107 |     # Compute mean inter-spike intervals (in simulation timesteps) from the
108 |     # input firing rates (in Hz), accounting for the simulation time step.
109 | rate = torch.zeros(size)
110 | rate[datum != 0] = 1 / datum[datum != 0] * (1000 / dt)
111 |
112 | # Create Poisson distribution and sample inter-spike intervals
113 | # (incrementing by 1 to avoid zero intervals).
114 | dist = torch.distributions.Poisson(rate=rate)
115 | intervals = dist.sample(sample_shape=torch.Size([time + 1]))
116 | intervals[:, datum != 0] += (intervals[:, datum != 0] == 0).float()
117 |
118 | # Calculate spike times by cumulatively summing over time dimension.
119 | times = torch.cumsum(intervals, dim=0).long()
120 | times[times >= time + 1] = 0
121 |
122 | # Create tensor of spikes.
123 | spikes = torch.zeros(time + 1, size).byte()
124 | spikes[times, torch.arange(size)] = 1
125 | spikes = spikes[1:]
126 |
127 | return spikes.view(time, *shape)
128 |
129 |
130 | def rank_order(
131 | datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs
132 | ) -> torch.Tensor:
133 | # language=rst
134 | """
135 | Encodes data via a rank order coding-like representation. One spike per neuron, temporally ordered by decreasing
136 | intensity. Inputs must be non-negative.
137 |
138 | :param datum: Tensor of shape ``[n_samples, n_1, ..., n_k]``.
139 | :param time: Length of rank order-encoded spike train per input variable.
140 | :param dt: Simulation time step.
141 | :return: Tensor of shape ``[time, n_1, ..., n_k]`` of rank order-encoded spikes.
142 | """
143 | assert (datum >= 0).all(), "Inputs must be non-negative"
144 |
145 | shape, size = datum.shape, datum.numel()
146 | datum = datum.view(-1)
147 | time = int(time / dt)
148 |
149 | # Create spike times in order of decreasing intensity.
150 | datum /= datum.max()
151 | times = torch.zeros(size)
152 | times[datum != 0] = 1 / datum[datum != 0]
153 | times *= time / times.max() # Extended through simulation time.
154 | times = torch.ceil(times).long()
155 |
156 | # Create spike times tensor.
157 | spikes = torch.zeros(time, size).byte()
158 | for i in range(size):
159 | if 0 < times[i] < time:
160 | spikes[times[i] - 1, i] = 1
161 |
162 | return spikes.reshape(time, *shape)
163 |
--------------------------------------------------------------------------------
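A quick sanity-check sketch for poisson(): with dt = 1 ms and a 1 s window, neurons driven at 100 Hz should emit roughly 100 spikes each.

import torch
from bindsnet.encoding import poisson

rates = torch.full((1000,), 100.0)          # 1000 neurons, each driven at 100 Hz
spikes = poisson(rates, time=1000, dt=1.0)  # [1000 timesteps, 1000 neurons]
print(spikes.shape, spikes.float().sum(0).mean().item())  # roughly 100 spikes per neuron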
/bindsnet/encoding/loaders.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Iterable, Iterator
2 |
3 | import torch
4 |
5 | from .encodings import bernoulli, poisson, rank_order
6 |
7 |
8 | def bernoulli_loader(
9 | data: Union[torch.Tensor, Iterable[torch.Tensor]],
10 | time: Optional[int] = None,
11 | dt: float = 1.0,
12 | **kwargs
13 | ) -> Iterator[torch.Tensor]:
14 | # language=rst
15 | """
16 | Lazily invokes ``bindsnet.encoding.bernoulli`` to iteratively encode a sequence of data.
17 |
18 | :param data: Tensor of shape ``[n_samples, n_1, ..., n_k]``.
19 | :param time: Length of Bernoulli spike train per input variable.
20 | :param dt: Simulation time step.
21 | :return: Tensors of shape ``[time, n_1, ..., n_k]`` of Bernoulli-distributed spikes.
22 |
23 | Keyword arguments:
24 |
25 | :param float max_prob: Maximum probability of spike per Bernoulli trial.
26 | """
27 | # Setting kwargs.
28 |     max_prob = kwargs.get("max_prob", 1.0)
29 |
30 | for i in range(len(data)):
31 | # Encode datum as Bernoulli spike trains.
32 | yield bernoulli(datum=data[i], time=time, dt=dt, max_prob=max_prob)
33 |
34 |
35 | def poisson_loader(
36 | data: Union[torch.Tensor, Iterable[torch.Tensor]],
37 | time: int,
38 | dt: float = 1.0,
39 | **kwargs
40 | ) -> Iterator[torch.Tensor]:
41 | # language=rst
42 | """
43 | Lazily invokes ``bindsnet.encoding.poisson`` to iteratively encode a sequence of data.
44 |
45 | :param data: Tensor of shape ``[n_samples, n_1, ..., n_k]``.
46 | :param time: Length of Poisson spike train per input variable.
47 | :param dt: Simulation time step.
48 | :return: Tensors of shape ``[time, n_1, ..., n_k]`` of Poisson-distributed spikes.
49 | """
50 | for i in range(len(data)):
51 | # Encode datum as Poisson spike trains.
52 | yield poisson(datum=data[i], time=time, dt=dt)
53 |
54 |
55 | def rank_order_loader(
56 | data: Union[torch.Tensor, Iterable[torch.Tensor]],
57 | time: int,
58 | dt: float = 1.0,
59 | **kwargs
60 | ) -> Iterator[torch.Tensor]:
61 | # language=rst
62 | """
63 | Lazily invokes ``bindsnet.encoding.rank_order`` to iteratively encode a sequence of data.
64 |
65 | :param data: Tensor of shape ``[n_samples, n_1, ..., n_k]``.
66 | :param time: Length of rank order-encoded spike train per input variable.
67 | :param dt: Simulation time step.
68 | :return: Tensors of shape ``[time, n_1, ..., n_k]`` of rank order-encoded spikes.
69 | """
70 | for i in range(len(data)):
71 | # Encode datum as rank order-encoded spike trains.
72 | yield rank_order(datum=data[i], time=time, dt=dt)
73 |
--------------------------------------------------------------------------------
/bindsnet/environment/__init__.py:
--------------------------------------------------------------------------------
1 | from .environment import Environment, GymEnvironment
2 |
--------------------------------------------------------------------------------
/bindsnet/environment/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/environment/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/environment/__pycache__/environment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/environment/__pycache__/environment.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/environment/environment.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Tuple, Dict, Any
3 |
4 | import gym
5 | import numpy as np
6 | import torch
7 |
8 | from ..datasets.preprocess import subsample, gray_scale, binary_image, crop
9 | from ..encoding import Encoder, NullEncoder
10 |
11 |
12 | class Environment(ABC):
13 | # language=rst
14 | """
15 | Abstract environment class.
16 | """
17 |
18 | @abstractmethod
19 | def step(self, a: int) -> Tuple[Any, ...]:
20 | # language=rst
21 | """
22 |         Abstract method header for ``step()``.
23 |
24 | :param a: Integer action to take in environment.
25 | """
26 | pass
27 |
28 | @abstractmethod
29 | def reset(self) -> None:
30 | # language=rst
31 | """
32 | Abstract method header for ``reset()``.
33 | """
34 | pass
35 |
36 | @abstractmethod
37 | def render(self) -> None:
38 | # language=rst
39 | """
40 | Abstract method header for ``render()``.
41 | """
42 | pass
43 |
44 | @abstractmethod
45 | def close(self) -> None:
46 | # language=rst
47 | """
48 | Abstract method header for ``close()``.
49 | """
50 | pass
51 |
52 | @abstractmethod
53 | def preprocess(self) -> None:
54 | # language=rst
55 | """
56 | Abstract method header for ``preprocess()``.
57 | """
58 | pass
59 |
60 |
61 | class GymEnvironment(Environment):
62 | # language=rst
63 | """
64 | A wrapper around the OpenAI ``gym`` environments.
65 | """
66 |
67 | def __init__(self, name: str, encoder: Encoder = NullEncoder(), **kwargs) -> None:
68 | # language=rst
69 | """
70 | Initializes the environment wrapper. This class makes the
71 | assumption that the OpenAI ``gym`` environment will provide an image
72 | of format HxW or CxHxW as an observation (we will add the C
73 | dimension to HxW tensors) or a 1D observation in which case no
74 | dimensions will be added.
75 |
76 | :param name: The name of an OpenAI ``gym`` environment.
77 | :param encoder: Function to encode observations into spike trains.
78 |
79 | Keyword arguments:
80 |
81 | :param float max_prob: Maximum spiking probability.
82 | :param bool clip_rewards: Whether or not to use ``np.sign`` of rewards.
83 |
84 |         :param int history_length: Number of observations to keep track of.
85 | :param int delta: Step size to save observations in history.
86 | :param bool add_channel_dim: Allows for the adding of the channel dimension in 2D inputs.
87 | """
88 | self.name = name
89 | self.env = gym.make(name)
90 | self.action_space = self.env.action_space
91 |
92 | self.encoder = encoder
93 |
94 | # Keyword arguments.
95 | self.max_prob = kwargs.get("max_prob", 1.0)
96 | self.clip_rewards = kwargs.get("clip_rewards", True)
97 |
98 | self.history_length = kwargs.get("history_length", None)
99 | self.delta = kwargs.get("delta", 1)
100 | self.add_channel_dim = kwargs.get("add_channel_dim", True)
101 |
102 | if self.history_length is not None and self.delta is not None:
103 | self.history = {
104 | i: torch.Tensor()
105 | for i in range(1, self.history_length * self.delta + 1, self.delta)
106 | }
107 | else:
108 | self.history = {}
109 |
110 | self.episode_step_count = 0
111 | self.history_index = 1
112 |
113 | self.obs = None
114 | self.reward = None
115 |
116 | assert (
117 | 0.0 < self.max_prob <= 1.0
118 | ), "Maximum spiking probability must be in (0, 1]."
119 |
120 | def step(self, a: int) -> Tuple[torch.Tensor, float, bool, Dict[Any, Any]]:
121 | # language=rst
122 | """
123 | Wrapper around the OpenAI ``gym`` environment ``step()`` function.
124 |
125 | :param a: Action to take in the environment.
126 | :return: Observation, reward, done flag, and information dictionary.
127 | """
128 | # Call gym's environment step function.
129 | self.obs, self.reward, self.done, info = self.env.step(a)
130 |
131 | if self.clip_rewards:
132 | self.reward = np.sign(self.reward)
133 |
134 | self.preprocess()
135 |
136 | # Add the raw observation from the gym environment into the info
137 | # for debugging and display.
138 | info["gym_obs"] = self.obs
139 |
140 | # Store frame of history and encode the inputs.
141 | if len(self.history) > 0:
142 | self.update_history()
143 | self.update_index()
144 | # Add the delta observation into the info for debugging and display.
145 | info["delta_obs"] = self.obs
146 |
147 | # The new standard for images is BxTxCxHxW.
148 | # The gym environment doesn't follow exactly the same protocol.
149 | #
150 | # 1D observations will be left as is before the encoder and will become BxTxL.
151 |         # 2D observations are assumed to be mono images and will become BxTx1xHxW
152 | # 3D observations will become BxTxCxHxW
153 | if self.obs.dim() == 2 and self.add_channel_dim:
154 | # We want CxHxW, it is currently HxW.
155 | self.obs = self.obs.unsqueeze(0)
156 |
157 | # The encoder will add time - now Tx...
158 | if self.encoder is not None:
159 | self.obs = self.encoder(self.obs)
160 |
161 | # Add the batch - now BxTx...
162 | self.obs = self.obs.unsqueeze(0)
163 |
164 | self.episode_step_count += 1
165 |
166 | # Return converted observations and other information.
167 | return self.obs, self.reward, self.done, info
168 |
169 | def reset(self) -> torch.Tensor:
170 | # language=rst
171 | """
172 | Wrapper around the OpenAI ``gym`` environment ``reset()`` function.
173 |
174 | :return: Observation from the environment.
175 | """
176 | # Call gym's environment reset function.
177 | self.obs = self.env.reset()
178 | self.preprocess()
179 |
180 | self.history = {i: torch.Tensor() for i in self.history}
181 |
182 | self.episode_step_count = 0
183 |
184 | return self.obs
185 |
186 | def render(self) -> None:
187 | # language=rst
188 | """
189 | Wrapper around the OpenAI ``gym`` environment ``render()`` function.
190 | """
191 | self.env.render()
192 |
193 | def close(self) -> None:
194 | # language=rst
195 | """
196 | Wrapper around the OpenAI ``gym`` environment ``close()`` function.
197 | """
198 | self.env.close()
199 |
200 | def preprocess(self) -> None:
201 | # language=rst
202 | """
203 | Pre-processing step for an observation from a ``gym`` environment.
204 | """
205 | if self.name == "SpaceInvaders-v0":
206 | self.obs = subsample(gray_scale(self.obs), 84, 110)
207 | self.obs = self.obs[26:104, :]
208 | self.obs = binary_image(self.obs)
209 | elif self.name == "BreakoutDeterministic-v4":
210 | self.obs = subsample(gray_scale(crop(self.obs, 34, 194, 0, 160)), 80, 80)
211 | self.obs = binary_image(self.obs)
212 | else: # Default pre-processing step.
213 | pass
214 |
215 | self.obs = torch.from_numpy(self.obs).float()
216 |
217 | def update_history(self) -> None:
218 | # language=rst
219 | """
220 |         Updates the observations inside ``history`` by subtracting the sum of previous observations from the most
221 |         recent observation. If there are not enough observations to take a difference from, simply store the
222 |         observation without any differencing.
223 | """
224 | # Recording initial observations.
225 | if self.episode_step_count < len(self.history) * self.delta:
226 | # Store observation based on delta value.
227 | if self.episode_step_count % self.delta == 0:
228 | self.history[self.history_index] = self.obs
229 | else:
230 | # Take difference between stored frames and current frame.
231 | temp = torch.clamp(self.obs - sum(self.history.values()), 0, 1)
232 |
233 | # Store observation based on delta value.
234 | if self.episode_step_count % self.delta == 0:
235 | self.history[self.history_index] = self.obs
236 |
237 | assert (
238 | len(self.history) == self.history_length
239 | ), "History size is out of bounds"
240 | self.obs = temp
241 |
242 | def update_index(self) -> None:
243 | # language=rst
244 | """
245 | Updates the index to keep track of history. For example: ``history = 4``, ``delta = 3`` will produce
246 | ``self.history = {1, 4, 7, 10}`` and ``self.history_index`` will be updated according to ``self.delta``
247 | and will wrap around the history dictionary.
248 | """
249 | if self.episode_step_count % self.delta == 0:
250 | if self.history_index != max(self.history.keys()):
251 | self.history_index += self.delta
252 | else:
253 | # Wrap around the history.
254 | self.history_index = (self.history_index % max(self.history.keys())) + 1
255 |
--------------------------------------------------------------------------------
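A minimal sketch of driving the wrapper, assuming the gym version this class targets (4-tuple step API) and the Atari ROMs are installed; observation shapes follow the BxTxCxHxW convention described in step():

from bindsnet.encoding import BernoulliEncoder
from bindsnet.environment import GymEnvironment

env = GymEnvironment(
    "BreakoutDeterministic-v4", encoder=BernoulliEncoder(time=50, dt=1.0)
)
env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(obs.shape)  # torch.Size([1, 50, 1, 80, 80]) -- [batch, time, channel, height, width]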
/bindsnet/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluation import (
2 | assign_labels,
3 | logreg_fit,
4 | logreg_predict,
5 | all_activity,
6 | proportion_weighting,
7 | ngram,
8 | update_ngram_scores,
9 | )
10 |
--------------------------------------------------------------------------------
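A minimal sketch of the spike-based evaluation flow these functions implement, on fabricated spike records: assign each neuron a label from its activity, then classify with the "all activity" scheme.

import torch
from bindsnet.evaluation import assign_labels, all_activity

n_samples, time, n_neurons, n_labels = 32, 100, 100, 10
spikes = torch.rand(n_samples, time, n_neurons).bernoulli()  # fake spiking records
labels = torch.randint(0, n_labels, (n_samples,))

assignments, proportions, rates = assign_labels(spikes, labels, n_labels)
predictions = all_activity(spikes, assignments, n_labels)
print((predictions == labels).float().mean().item())  # accuracy on the (random) assignment data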
/bindsnet/evaluation/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/evaluation/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/evaluation/__pycache__/evaluation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/evaluation/__pycache__/evaluation.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | from itertools import product
2 | from typing import Optional, Tuple, Dict
3 |
4 | import torch
5 | from sklearn.linear_model import LogisticRegression
6 |
7 |
8 | def assign_labels(
9 | spikes: torch.Tensor,
10 | labels: torch.Tensor,
11 | n_labels: int,
12 | rates: Optional[torch.Tensor] = None,
13 | alpha: float = 1.0,
14 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
15 | # language=rst
16 | """
17 | Assign labels to the neurons based on highest average spiking activity.
18 |
19 | :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a single layer's spiking activity.
20 | :param labels: Vector of shape ``(n_samples,)`` with data labels corresponding to spiking activity.
21 | :param n_labels: The number of target labels in the data.
22 | :param rates: If passed, these represent spike rates from a previous ``assign_labels()`` call.
23 | :param alpha: Rate of decay of label assignments.
24 | :return: Tuple of class assignments, per-class spike proportions, and per-class firing rates.
25 | """
26 | n_neurons = spikes.size(2)
27 |
28 | if rates is None:
29 |         rates = torch.zeros(n_neurons, n_labels)
30 |
31 | # Sum over time dimension (spike ordering doesn't matter).
32 | spikes = spikes.sum(1)
33 |
34 | for i in range(n_labels):
35 | # Count the number of samples with this label.
36 | n_labeled = torch.sum(labels == i).float()
37 |
38 | if n_labeled > 0:
39 | # Get indices of samples with this label.
40 | indices = torch.nonzero(labels == i).view(-1)
41 |
42 | # Compute average firing rates for this label.
43 | rates[:, i] = alpha * rates[:, i] + (
44 | torch.sum(spikes[indices], 0) / n_labeled
45 | )
46 |
47 | # Compute proportions of spike activity per class.
48 | proportions = rates / rates.sum(1, keepdim=True)
49 | proportions[proportions != proportions] = 0 # Set NaNs to 0
50 |
51 | # Neuron assignments are the labels they fire most for.
52 | assignments = torch.max(proportions, 1)[1]
53 |
54 | return assignments, proportions, rates
55 |
56 |
57 | def logreg_fit(
58 | spikes: torch.Tensor, labels: torch.Tensor, logreg: LogisticRegression
59 | ) -> LogisticRegression:
60 | # language=rst
61 | """
62 | (Re)fit logistic regression model to spike data summed over time.
63 |
64 |     :param spikes: Spikes summed over time, of shape ``(n_examples, n_neurons)``.
65 | :param labels: Vector of shape ``(n_samples,)`` with data labels corresponding to spiking activity.
66 | :param logreg: Logistic regression model from previous fits.
67 | :return: (Re)fitted logistic regression model.
68 | """
69 | # (Re)fit logistic regression model.
70 | logreg.fit(spikes, labels)
71 | return logreg
72 |
73 |
74 | def logreg_predict(spikes: torch.Tensor, logreg: LogisticRegression) -> torch.Tensor:
75 | # language=rst
76 | """
77 | Predicts classes according to spike data summed over time.
78 |
79 |     :param spikes: Spikes summed over time, of shape ``(n_examples, n_neurons)``.
80 | :param logreg: Logistic regression model from previous fits.
81 | :return: Predictions per example.
82 | """
83 | # Make class label predictions.
84 | if not hasattr(logreg, "coef_") or logreg.coef_ is None:
85 | return -1 * torch.ones(spikes.size(0)).long()
86 |
87 | predictions = logreg.predict(spikes)
88 | return torch.Tensor(predictions).long()
89 |
90 |
91 | def all_activity(
92 | spikes: torch.Tensor, assignments: torch.Tensor, n_labels: int
93 | ) -> torch.Tensor:
94 | # language=rst
95 | """
96 | Classify data with the label with highest average spiking activity over all neurons.
97 |
98 | :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a layer's spiking activity.
99 | :param assignments: A vector of shape ``(n_neurons,)`` of neuron label assignments.
100 | :param n_labels: The number of target labels in the data.
101 | :return: Predictions tensor of shape ``(n_samples,)`` resulting from the "all activity" classification scheme.
102 | """
103 | n_samples = spikes.size(0)
104 |
105 | # Sum over time dimension (spike ordering doesn't matter).
106 | spikes = spikes.sum(1)
107 |
108 | rates = torch.zeros(n_samples, n_labels)
109 | for i in range(n_labels):
110 | # Count the number of neurons with this label assignment.
111 | n_assigns = torch.sum(assignments == i).float()
112 |
113 | if n_assigns > 0:
114 | # Get indices of samples with this label.
115 | indices = torch.nonzero(assignments == i).view(-1)
116 |
117 | # Compute layer-wise firing rate for this label.
118 | rates[:, i] = torch.sum(spikes[:, indices], 1) / n_assigns
119 |
120 | # Predictions are arg-max of layer-wise firing rates.
121 | return torch.sort(rates, dim=1, descending=True)[1][:, 0]
122 |
123 |
124 | def proportion_weighting(
125 | spikes: torch.Tensor,
126 | assignments: torch.Tensor,
127 | proportions: torch.Tensor,
128 | n_labels: int,
129 | ) -> torch.Tensor:
130 | # language=rst
131 | """
132 | Classify data with the label with highest average spiking activity over all neurons, weighted by class-wise
133 | proportion.
134 |
135 | :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a single layer's spiking activity.
136 | :param assignments: A vector of shape ``(n_neurons,)`` of neuron label assignments.
137 | :param proportions: A matrix of shape ``(n_neurons, n_labels)`` giving the per-class proportions of neuron spiking
138 | activity.
139 | :param n_labels: The number of target labels in the data.
140 | :return: Predictions tensor of shape ``(n_samples,)`` resulting from the "proportion weighting" classification
141 | scheme.
142 | """
143 | n_samples = spikes.size(0)
144 |
145 | # Sum over time dimension (spike ordering doesn't matter).
146 | spikes = spikes.sum(1)
147 |
148 | rates = torch.zeros(n_samples, n_labels)
149 | for i in range(n_labels):
150 | # Count the number of neurons with this label assignment.
151 | n_assigns = torch.sum(assignments == i).float()
152 |
153 | if n_assigns > 0:
154 | # Get indices of samples with this label.
155 | indices = torch.nonzero(assignments == i).view(-1)
156 |
157 | # Compute layer-wise firing rate for this label.
158 | rates[:, i] += (
159 | torch.sum((proportions[:, i] * spikes)[:, indices], 1) / n_assigns
160 | )
161 |
162 | # Predictions are arg-max of layer-wise firing rates.
163 | predictions = torch.sort(rates, dim=1, descending=True)[1][:, 0]
164 |
165 | return predictions
166 |
167 |
168 | def ngram(
169 | spikes: torch.Tensor,
170 | ngram_scores: Dict[Tuple[int, ...], torch.Tensor],
171 | n_labels: int,
172 | n: int,
173 | ) -> torch.Tensor:
174 | # language=rst
175 | """
176 | Predicts between ``n_labels`` using ``ngram_scores``.
177 |
178 | :param spikes: Spikes of shape ``(n_examples, time, n_neurons)``.
180 |     :param ngram_scores: Previously recorded n-gram scores, used here for prediction only.
180 | :param n_labels: The number of target labels in the data.
181 | :param n: The max size of n-gram to use.
182 | :return: Predictions per example.
183 | """
184 | predictions = []
185 | for activity in spikes:
186 | score = torch.zeros(n_labels)
187 |
188 | # Aggregate all of the firing neurons' indices
189 | fire_order = []
190 | for t in range(activity.size()[0]):
191 | ordering = torch.nonzero(activity[t].view(-1))
192 | if ordering.numel() > 0:
193 | fire_order += ordering[:, 0].tolist()
194 |
195 | # Consider all n-gram sequences.
196 | for j in range(len(fire_order) - n):
197 | if tuple(fire_order[j : j + n]) in ngram_scores:
198 | score += ngram_scores[tuple(fire_order[j : j + n])]
199 |
200 | predictions.append(torch.argmax(score))
201 |
202 | return torch.Tensor(predictions).long()
203 |
204 |
205 | def update_ngram_scores(
206 | spikes: torch.Tensor,
207 | labels: torch.Tensor,
208 | n_labels: int,
209 | n: int,
210 | ngram_scores: Dict[Tuple[int, ...], torch.Tensor],
211 | ) -> Dict[Tuple[int, ...], torch.Tensor]:
212 | # language=rst
213 | """
214 | Updates ngram scores by adding the count of each spike sequence of length n from the past ``n_examples``.
215 |
216 | :param spikes: Spikes of shape ``(n_examples, time, n_neurons)``.
217 | :param labels: The ground truth labels of shape ``(n_examples)``.
218 | :param n_labels: The number of target labels in the data.
219 | :param n: The max size of n-gram to use.
220 | :param ngram_scores: Previously recorded scores to update.
221 | :return: Dictionary mapping n-grams to vectors of per-class spike counts.
222 | """
223 | for i, activity in enumerate(spikes):
224 | # Obtain firing order for spiking activity.
225 | fire_order = []
226 |
227 | # Aggregate all of the firing neurons' indices.
228 | for t in range(spikes.size(1)):
229 | # Gets the indices of the neurons which fired on this timestep.
230 | ordering = torch.nonzero(activity[t]).view(-1)
231 |             if ordering.numel() > 0: # If there was at least one spike...
232 | # Add the indices of spiked neurons to the fire ordering.
233 | ordering = ordering.tolist()
234 | fire_order.append(ordering)
235 |
236 | # Check every sequence of length n.
237 | for order in zip(*(fire_order[k:] for k in range(n))):
238 | for sequence in product(*order):
239 | if sequence not in ngram_scores:
240 | ngram_scores[sequence] = torch.zeros(n_labels)
241 |
242 | ngram_scores[sequence][int(labels[i])] += 1
243 |
244 | return ngram_scores
245 |
--------------------------------------------------------------------------------
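
The readouts above are typically chained together during evaluation: ``assign_labels()`` turns recorded output-layer activity into per-neuron label assignments, which the two voting schemes then consume. Below is a minimal sketch (not part of this repository) assuming a ``spike_record`` tensor and matching labels have already been collected from a monitored output layer, and that the ``rates``/``alpha`` arguments of ``assign_labels()`` keep their defaults; the random tensors are placeholders.

```python
import torch

from bindsnet.evaluation.evaluation import (
    assign_labels,
    all_activity,
    proportion_weighting,
)

n_classes = 10
# Placeholder activity of shape (n_samples, time, n_neurons); in practice this
# comes from a spike monitor on the output layer.
spike_record = torch.bernoulli(0.1 * torch.ones(64, 250, 100))
label_tensor = torch.randint(0, n_classes, (64,))

# (Re)assign a label to each neuron from its class-conditional firing rates.
assignments, proportions, rates = assign_labels(spike_record, label_tensor, n_classes)

# Classify the same activity with both voting schemes defined above.
all_activity_pred = all_activity(spike_record, assignments, n_classes)
proportion_pred = proportion_weighting(spike_record, assignments, proportions, n_classes)

accuracy = (all_activity_pred == label_tensor).float().mean().item()
```
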
/bindsnet/learning/__init__.py:
--------------------------------------------------------------------------------
1 | from .learning import (
2 | LearningRule,
3 | NoOp,
4 | PostPre,
5 | WeightDependentPostPre,
6 | Hebbian,
7 | MSTDP,
8 | MSTDPET,
9 | Rmax,
10 | )
11 |
--------------------------------------------------------------------------------
/bindsnet/learning/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/learning/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/learning/__pycache__/learning.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/learning/__pycache__/learning.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/learning/__pycache__/reward.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/learning/__pycache__/reward.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/learning/reward.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 |
5 |
6 | class AbstractReward(ABC):
7 | # language=rst
8 | """
9 | Abstract base class for reward computation.
10 | """
11 |
12 | @abstractmethod
13 | def compute(self, **kwargs) -> None:
14 | # language=rst
15 | """
16 | Computes/modifies reward.
17 | """
18 | pass
19 |
20 | @abstractmethod
21 | def update(self, **kwargs) -> None:
22 | # language=rst
23 | """
24 | Updates internal variables needed to modify reward. Usually called once per episode.
25 | """
26 | pass
27 |
28 |
29 | class MovingAvgRPE(AbstractReward):
30 | # language=rst
31 | """
32 | Computes reward prediction error (RPE) based on an exponential moving average (EMA) of past rewards.
33 | """
34 |
35 | def __init__(self, **kwargs) -> None:
36 | # language=rst
37 | """
38 | Constructor for EMA reward prediction error.
39 | """
40 | self.reward_predict = torch.tensor(0.0) # Predicted reward (per step).
41 | self.reward_predict_episode = torch.tensor(0.0) # Predicted reward per episode.
42 | self.rewards_predict_episode = (
43 | []
44 | ) # List of predicted rewards per episode (used for plotting).
45 |
46 | def compute(self, **kwargs) -> torch.Tensor:
47 | # language=rst
48 | """
49 | Computes the reward prediction error using EMA.
50 |
51 | Keyword arguments:
52 |
53 | :param Union[float, torch.Tensor] reward: Current reward.
54 | :return: Reward prediction error.
55 | """
56 | # Get keyword arguments.
57 | reward = kwargs["reward"]
58 |
59 | return reward - self.reward_predict
60 |
61 | def update(self, **kwargs) -> None:
62 | # language=rst
63 | """
64 | Updates the EMAs. Called once per episode.
65 |
66 | Keyword arguments:
67 |
68 | :param Union[float, torch.Tensor] accumulated_reward: Reward accumulated over one episode.
69 | :param int steps: Steps in that episode.
70 | :param float ema_window: Width of the averaging window.
71 | """
72 | # Get keyword arguments.
73 | accumulated_reward = kwargs["accumulated_reward"]
74 | steps = torch.tensor(kwargs["steps"]).float()
75 | ema_window = torch.tensor(kwargs.get("ema_window", 10.0))
76 |
77 | # Compute average reward per step.
78 | reward = accumulated_reward / steps
79 |
80 | # Update EMAs.
81 | self.reward_predict = (
82 | 1 - 1 / ema_window
83 | ) * self.reward_predict + 1 / ema_window * reward
84 | self.reward_predict_episode = (
85 | 1 - 1 / ema_window
86 | ) * self.reward_predict_episode + 1 / ema_window * accumulated_reward
87 | self.rewards_predict_episode.append(self.reward_predict_episode.item())
88 |
--------------------------------------------------------------------------------
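
For reference, a small sketch (not from the repository) of how the ``MovingAvgRPE`` interface above is exercised: ``compute()`` turns a raw reward into a prediction error during an episode, and ``update()`` refreshes the moving averages once the episode ends. The reward values here are placeholders.

```python
import torch

from bindsnet.learning.reward import MovingAvgRPE

rpe = MovingAvgRPE()

# During an episode: prediction error = reward - EMA of past per-step rewards.
delta = rpe.compute(reward=1.0)

# At the end of the episode: update the per-step and per-episode EMAs.
rpe.update(accumulated_reward=torch.tensor(42.0), steps=200, ema_window=10.0)

print(delta, rpe.reward_predict, rpe.rewards_predict_episode)
```
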
/bindsnet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import (
2 | TwoLayerNetwork,
3 | DiehlAndCook2015,
4 | DiehlAndCook2015v2,
5 | IncreasingInhibitionNetwork,
6 | LocallyConnectedNetwork,
7 | )
8 |
--------------------------------------------------------------------------------
/bindsnet/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/models/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/models/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/network/__init__.py:
--------------------------------------------------------------------------------
1 | from .network import Network, load
2 | from . import nodes, topology, monitors
3 |
--------------------------------------------------------------------------------
/bindsnet/network/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/network/__pycache__/monitors.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/monitors.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/network/__pycache__/network.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/network.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/network/__pycache__/nodes.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/nodes.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/network/__pycache__/topology.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/topology.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/network/monitors.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 |
5 | from abc import ABC, abstractmethod
6 | from typing import Union, Optional, Iterable, Dict
7 |
8 | from .nodes import Nodes
9 | from .topology import AbstractConnection
10 |
11 |
12 | class AbstractMonitor(ABC):
13 | # language=rst
14 | """
15 | Abstract base class for state variable monitors.
16 | """
17 |
18 |
19 | class Monitor(AbstractMonitor):
20 | # language=rst
21 | """
22 | Records state variables of interest.
23 | """
24 |
25 | def __init__(
26 | self,
27 | obj: Union[Nodes, AbstractConnection],
28 | state_vars: Iterable[str],
29 | time: Optional[int] = None,
30 | batch_size: int = 1,
31 | ):
32 | # language=rst
33 | """
34 | Constructs a ``Monitor`` object.
35 |
36 | :param obj: An object to record state variables from during network simulation.
37 | :param state_vars: Iterable of strings indicating names of state variables to record.
38 |         :param time: If not ``None``, keep only the most recent ``time`` recordings (sliding window).
39 | """
40 | super().__init__()
41 |
42 | self.obj = obj
43 | self.state_vars = state_vars
44 | self.time = time
45 | self.batch_size = batch_size
46 |
47 | # Deal with time later, the same underlying list is used
48 | self.recording = {v: [] for v in self.state_vars}
49 |
50 | def get(self, var: str) -> torch.Tensor:
51 | # language=rst
52 | """
53 | Return recording to user.
54 |
55 | :param var: State variable recording to return.
56 | :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is the shape of the recorded
57 | state variable.
58 | """
59 | return torch.cat(self.recording[var], 0)
60 |
61 | def record(self) -> None:
62 | # language=rst
63 | """
64 | Appends the current value of the recorded state variables to the recording.
65 | """
66 | for v in self.state_vars:
67 | data = getattr(self.obj, v).unsqueeze(0)
68 | self.recording[v].append(data.detach().clone())
69 |
70 | # remove the oldest element (first in the list)
71 | if self.time is not None:
72 | for v in self.state_vars:
73 | if len(self.recording[v]) > self.time:
74 | self.recording[v].pop(0)
75 |
76 | def reset_(self) -> None:
77 | # language=rst
78 | """
79 |         Resets recordings to empty lists.
80 | """
81 | self.recording = {v: [] for v in self.state_vars}
82 |
83 |
84 | class NetworkMonitor(AbstractMonitor):
85 | # language=rst
86 | """
87 | Record state variables of all layers and connections.
88 | """
89 |
90 | def __init__(
91 | self,
92 | network: "Network",
93 | layers: Optional[Iterable[str]] = None,
94 | connections: Optional[Iterable[str]] = None,
95 | state_vars: Optional[Iterable[str]] = None,
96 | time: Optional[int] = None,
97 | ):
98 | # language=rst
99 | """
100 | Constructs a ``NetworkMonitor`` object.
101 |
102 | :param network: Network to record state variables from.
103 | :param layers: Layers to record state variables from.
104 | :param connections: Connections to record state variables from.
105 | :param state_vars: List of strings indicating names of state variables to record.
106 | :param time: If not ``None``, pre-allocate memory for state variable recording.
107 | """
108 | super().__init__()
109 |
110 | self.network = network
111 | self.layers = layers if layers is not None else list(self.network.layers.keys())
112 | self.connections = (
113 | connections
114 | if connections is not None
115 | else list(self.network.connections.keys())
116 | )
117 | self.state_vars = state_vars if state_vars is not None else ("v", "s", "w")
118 | self.time = time
119 |
120 | if self.time is not None:
121 | self.i = 0
122 |
123 | # Initialize empty recording.
124 | self.recording = {k: {} for k in self.layers + self.connections}
125 |
126 | # If no simulation time is specified, specify 0-dimensional recordings.
127 | if self.time is None:
128 | for v in self.state_vars:
129 | for l in self.layers:
130 | if hasattr(self.network.layers[l], v):
131 | self.recording[l][v] = torch.Tensor()
132 |
133 | for c in self.connections:
134 | if hasattr(self.network.connections[c], v):
135 | self.recording[c][v] = torch.Tensor()
136 |
137 | # If simulation time is specified, pre-allocate recordings in memory for speed.
138 | else:
139 | for v in self.state_vars:
140 | for l in self.layers:
141 | if hasattr(self.network.layers[l], v):
142 | self.recording[l][v] = torch.zeros(
143 | self.time, *getattr(self.network.layers[l], v).size()
144 | )
145 |
146 | for c in self.connections:
147 | if hasattr(self.network.connections[c], v):
148 | self.recording[c][v] = torch.zeros(
149 | self.time, *getattr(self.network.connections[c], v).size()
150 | )
151 |
152 | def get(self) -> Dict[str, Dict[str, Union[Nodes, AbstractConnection]]]:
153 | # language=rst
154 | """
155 | Return entire recording to user.
156 |
157 | :return: Dictionary of dictionary of all layers' and connections' recorded state variables.
158 | """
159 | return self.recording
160 |
161 | def record(self) -> None:
162 | # language=rst
163 | """
164 | Appends the current value of the recorded state variables to the recording.
165 | """
166 | if self.time is None:
167 | for v in self.state_vars:
168 | for l in self.layers:
169 | if hasattr(self.network.layers[l], v):
170 | data = getattr(self.network.layers[l], v).unsqueeze(0).float()
171 | self.recording[l][v] = torch.cat(
172 | (self.recording[l][v], data), 0
173 | )
174 |
175 | for c in self.connections:
176 | if hasattr(self.network.connections[c], v):
177 | data = getattr(self.network.connections[c], v).unsqueeze(0)
178 | self.recording[c][v] = torch.cat(
179 | (self.recording[c][v], data), 0
180 | )
181 |
182 | else:
183 | for v in self.state_vars:
184 | for l in self.layers:
185 | if hasattr(self.network.layers[l], v):
186 | data = getattr(self.network.layers[l], v).float().unsqueeze(0)
187 | self.recording[l][v] = torch.cat(
188 | (self.recording[l][v][1:].type(data.type()), data), 0
189 | )
190 |
191 | for c in self.connections:
192 | if hasattr(self.network.connections[c], v):
193 | data = getattr(self.network.connections[c], v).unsqueeze(0)
194 | self.recording[c][v] = torch.cat(
195 | (self.recording[c][v][1:].type(data.type()), data), 0
196 | )
197 |
198 | self.i += 1
199 |
200 | def save(self, path: str, fmt: str = "npz") -> None:
201 | # language=rst
202 | """
203 | Write the recording dictionary out to file.
204 |
205 | :param path: The directory to which to write the monitor's recording.
206 | :param fmt: Type of file to write to disk. One of ``"pickle"`` or ``"npz"``.
207 | """
208 | if not os.path.exists(os.path.dirname(path)):
209 | os.makedirs(os.path.dirname(path))
210 |
211 | if fmt == "npz":
212 | # Build a list of arrays to write to disk.
213 | arrays = {}
214 | for o in self.recording:
215 | if type(o) == tuple:
216 | arrays.update(
217 | {
218 | "_".join(["-".join(o), v]): self.recording[o][v]
219 | for v in self.recording[o]
220 | }
221 | )
222 | elif type(o) == str:
223 | arrays.update(
224 | {
225 | "_".join([o, v]): self.recording[o][v]
226 | for v in self.recording[o]
227 | }
228 | )
229 |
230 | np.savez_compressed(path, **arrays)
231 |
232 | elif fmt == "pickle":
233 | with open(path, "wb") as f:
234 | torch.save(self.recording, f)
235 |
236 | def reset_(self) -> None:
237 | # language=rst
238 | """
239 | Resets recordings to empty ``torch.Tensors``.
240 | """
241 | # Reset to empty recordings
242 | self.recording = {k: {} for k in self.layers + self.connections}
243 |
244 | if self.time is not None:
245 | self.i = 0
246 |
247 | # If no simulation time is specified, specify 0-dimensional recordings.
248 | if self.time is None:
249 | for v in self.state_vars:
250 | for l in self.layers:
251 | if hasattr(self.network.layers[l], v):
252 | self.recording[l][v] = torch.Tensor()
253 |
254 | for c in self.connections:
255 | if hasattr(self.network.connections[c], v):
256 | self.recording[c][v] = torch.Tensor()
257 |
258 | # If simulation time is specified, pre-allocate recordings in memory for speed.
259 | else:
260 | for v in self.state_vars:
261 | for l in self.layers:
262 | if hasattr(self.network.layers[l], v):
263 | self.recording[l][v] = torch.zeros(
264 | self.time, *getattr(self.network.layers[l], v).size()
265 | )
266 |
267 | for c in self.connections:
268 | if hasattr(self.network.connections[c], v):
269 | self.recording[c][v] = torch.zeros(
270 |                             self.time, *getattr(self.network.connections[c], v).size()
271 | )
272 |
--------------------------------------------------------------------------------
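
A short sketch (not from the repository) of attaching a ``NetworkMonitor`` and dumping its recording to disk. ``record()`` is invoked once per timestep by ``Network.run()``, so the pre-allocated buffers below are only filled once a simulation is actually run; the layer size and file path are placeholders.

```python
from bindsnet.network import Network, nodes
from bindsnet.network.monitors import NetworkMonitor

network = Network(dt=1.0)
network.add_layer(nodes.Input(10), name="X")

# Record spikes of every layer/connection over a 50-step window.
net_monitor = NetworkMonitor(network, state_vars=("s",), time=50)
network.add_monitor(net_monitor, name="full_state")

# ... simulate the network here; each timestep calls net_monitor.record() ...

recording = net_monitor.get()                    # {object name: {state variable: tensor}}
net_monitor.save("./logs/state.npz", fmt="npz")  # one compressed array per (object, variable)
```
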
/bindsnet/network/network.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from typing import Dict, Optional, Type, Iterable
3 |
4 | import torch
5 |
6 | from .monitors import AbstractMonitor
7 | from .nodes import AbstractInput, Nodes
8 | from .topology import AbstractConnection
9 | from ..learning.reward import AbstractReward
10 | import numpy as np
11 |
12 | def load(file_name: str, map_location: str = "cpu", learning: Optional[bool] = None) -> "Network":
13 | # language=rst
14 | """
15 | Loads serialized network object from disk.
16 |
17 | :param file_name: Path to serialized network object on disk.
18 | :param map_location: One of ``"cpu"`` or ``"cuda"``. Defaults to ``"cpu"``.
19 | :param learning: Whether to load with learning enabled. Default loads value from disk.
20 | """
21 | network = torch.load(open(file_name, "rb"), map_location=map_location)
22 | if learning is not None and "learning" in vars(network):
23 | network.learning = learning
24 |
25 | return network
26 |
27 |
28 | class Network(torch.nn.Module):
29 | # language=rst
30 | """
31 | Most important object of the ``bindsnet`` package. Responsible for the simulation and interaction of nodes and
32 | connections.
33 |
34 | **Example:**
35 |
36 | .. code-block:: python
37 |
38 | import torch
39 | import matplotlib.pyplot as plt
40 |
41 | from bindsnet import encoding
42 | from bindsnet.network import Network, nodes, topology, monitors
43 |
44 | network = Network(dt=1.0) # Instantiates network.
45 |
46 | X = nodes.Input(100) # Input layer.
47 | Y = nodes.LIFNodes(100) # Layer of LIF neurons.
48 | C = topology.Connection(source=X, target=Y, w=torch.rand(X.n, Y.n)) # Connection from X to Y.
49 |
50 | # Spike monitor objects.
51 | M1 = monitors.Monitor(obj=X, state_vars=['s'])
52 | M2 = monitors.Monitor(obj=Y, state_vars=['s'])
53 |
54 | # Add everything to the network object.
55 | network.add_layer(layer=X, name='X')
56 | network.add_layer(layer=Y, name='Y')
57 | network.add_connection(connection=C, source='X', target='Y')
58 | network.add_monitor(monitor=M1, name='X')
59 | network.add_monitor(monitor=M2, name='Y')
60 |
61 | # Create Poisson-distributed spike train inputs.
62 | data = 15 * torch.rand(100) # Generate random Poisson rates for 100 input neurons.
63 | train = encoding.poisson(datum=data, time=5000) # Encode input as 5000ms Poisson spike trains.
64 |
65 | # Simulate network on generated spike trains.
66 | inpts = {'X' : train} # Create inputs mapping.
67 | network.run(inpts=inpts, time=5000) # Run network simulation.
68 |
69 | # Plot spikes of input and output layers.
70 | spikes = {'X' : M1.get('s'), 'Y' : M2.get('s')}
71 |
72 | fig, axes = plt.subplots(2, 1, figsize=(12, 7))
73 | for i, layer in enumerate(spikes):
74 | axes[i].matshow(spikes[layer], cmap='binary')
75 | axes[i].set_title('%s spikes' % layer)
76 | axes[i].set_xlabel('Time'); axes[i].set_ylabel('Index of neuron')
77 | axes[i].set_xticks(()); axes[i].set_yticks(())
78 | axes[i].set_aspect('auto')
79 |
80 | plt.tight_layout(); plt.show()
81 | """
82 |
83 | def __init__(
84 | self,
85 | dt: float = 1.0,
86 | batch_size: int = 1,
87 | learning: bool = True,
88 | reward_fn: Optional[Type[AbstractReward]] = None,
89 | ) -> None:
90 | # language=rst
91 | """
92 | Initializes network object.
93 |
94 | :param dt: Simulation timestep.
95 | :param learning: Whether to allow connection updates. True by default.
96 | :param reward_fn: Optional class allowing for modification of reward in case of reward-modulated learning.
97 | """
98 | super().__init__()
99 |
100 | self.dt = dt
101 | self.batch_size = batch_size
102 |
103 | self.layers = {}
104 | self.connections = {}
105 | self.monitors = {}
106 | self.train(learning)
107 |
108 | if reward_fn is not None:
109 | self.reward_fn = reward_fn()
110 | else:
111 | self.reward_fn = None
112 |
113 | def add_layer(self, layer: Nodes, name: str) -> None:
114 | # language=rst
115 | """
116 | Adds a layer of nodes to the network.
117 |
118 | :param layer: A subclass of the ``Nodes`` object.
119 | :param name: Logical name of layer.
120 | """
121 | self.layers[name] = layer
122 | self.add_module(name, layer)
123 |
124 | layer.train(self.learning)
125 | layer.compute_decays(self.dt)
126 | layer.set_batch_size(self.batch_size)
127 |
128 | def add_connection(
129 | self, connection: AbstractConnection, source: str, target: str
130 | ) -> None:
131 | # language=rst
132 | """
133 | Adds a connection between layers of nodes to the network.
134 |
135 | :param connection: An instance of class ``Connection``.
136 | :param source: Logical name of the connection's source layer.
137 | :param target: Logical name of the connection's target layer.
138 | """
139 | self.connections[(source, target)] = connection
140 | self.add_module(source + "_to_" + target, connection)
141 |
142 | connection.dt = self.dt
143 | connection.train(self.learning)
144 |
145 | def add_monitor(self, monitor: AbstractMonitor, name: str) -> None:
146 | # language=rst
147 | """
148 | Adds a monitor on a network object to the network.
149 |
150 | :param monitor: An instance of class ``Monitor``.
151 | :param name: Logical name of monitor object.
152 | """
153 | self.monitors[name] = monitor
154 | monitor.network = self
155 | monitor.dt = self.dt
156 |
157 | def save(self, file_name: str) -> None:
158 | # language=rst
159 | """
160 | Serializes the network object to disk.
161 |
162 | :param file_name: Path to store serialized network object on disk.
163 |
164 | **Example:**
165 |
166 | .. code-block:: python
167 |
168 | import torch
169 | import matplotlib.pyplot as plt
170 |
171 | from pathlib import Path
172 | from bindsnet.network import *
173 | from bindsnet.network import topology
174 |
175 | # Build simple network.
176 | network = Network(dt=1.0)
177 |
178 | X = nodes.Input(100) # Input layer.
179 | Y = nodes.LIFNodes(100) # Layer of LIF neurons.
180 | C = topology.Connection(source=X, target=Y, w=torch.rand(X.n, Y.n)) # Connection from X to Y.
181 |
182 | # Add everything to the network object.
183 | network.add_layer(layer=X, name='X')
184 | network.add_layer(layer=Y, name='Y')
185 | network.add_connection(connection=C, source='X', target='Y')
186 |
187 | # Save the network to disk.
188 | network.save(str(Path.home()) + '/network.pt')
189 | """
190 | torch.save(self, open(file_name, "wb"))
191 |
192 | def clone(self) -> "Network":
193 | # language=rst
194 | """
195 | Returns a cloned network object.
196 |
197 | :return: A copy of this network.
198 | """
199 | virtual_file = tempfile.SpooledTemporaryFile()
200 | torch.save(self, virtual_file)
201 | virtual_file.seek(0)
202 | return torch.load(virtual_file)
203 |
204 | def _get_inputs(self, layers: Iterable = None) -> Dict[str, torch.Tensor]:
205 | # language=rst
206 | """
207 | Fetches outputs from network layers to use as input to downstream layers.
208 |
209 | :param layers: Layers to update inputs for. Defaults to all network layers.
210 | :return: Inputs to all layers for the current iteration.
211 | """
212 | inpts = {}
213 |
214 | if layers is None:
215 | layers = self.layers
216 |
217 | # Loop over network connections.
218 | for c in self.connections:
219 | if c[1] in layers:
220 | # Fetch source and target populations.
221 | source = self.connections[c].source
222 | target = self.connections[c].target
223 |
224 | if not c[1] in inpts:
225 | inpts[c[1]] = torch.zeros(
226 | self.batch_size, *target.shape, device=target.s.device
227 | )
228 |
229 | # Add to input: source's spikes multiplied by connection weights.
230 | inpts[c[1]] += self.connections[c].compute(source.s)
231 |
232 | return inpts
233 |
234 |     def run(self, inpts: Dict[str, torch.Tensor], time: int, acc: np.ndarray,
235 |             step: int, labels: torch.Tensor, one_step: bool = False, **kwargs
236 |     ) -> None:
237 | # language=rst
238 | """
239 | Simulate network for given inputs and time.
240 |
241 | :param inpts: Dictionary of ``Tensor``s of shape ``[time, *input_shape]`` or
242 | ``[batch_size, time, *input_shape]``.
243 | :param time: Simulation time.
244 | :param one_step: Whether to run the network in "feed-forward" mode, where inputs
245 | propagate all the way through the network in a single simulation time step.
246 | Layers are updated in the order they are added to the network.
247 |
248 | Keyword arguments:
249 |
250 | :param Dict[str, torch.Tensor] clamp: Mapping of layer names to boolean masks if
251 | neurons should be clamped to spiking. The ``Tensor``s have shape
252 | ``[n_neurons]`` or ``[time, n_neurons]``.
253 | :param Dict[str, torch.Tensor] unclamp: Mapping of layer names to boolean masks
254 | if neurons should be clamped to not spiking. The ``Tensor``s should have
255 | shape ``[n_neurons]`` or ``[time, n_neurons]``.
256 | :param Dict[str, torch.Tensor] injects_v: Mapping of layer names to boolean
257 | masks if neurons should be added voltage. The ``Tensor``s should have shape
258 | ``[n_neurons]`` or ``[time, n_neurons]``.
259 | :param Union[float, torch.Tensor] reward: Scalar value used in reward-modulated
260 | learning.
261 | :param Dict[Tuple[str], torch.Tensor] masks: Mapping of connection names to
262 | boolean masks determining which weights to clamp to zero.
263 |
264 | **Example:**
265 |
266 | .. code-block:: python
267 |
268 | import torch
269 | import matplotlib.pyplot as plt
270 |
271 | from bindsnet.network import Network
272 | from bindsnet.network.nodes import Input
273 | from bindsnet.network.monitors import Monitor
274 |
275 | # Build simple network.
276 | network = Network()
277 | network.add_layer(Input(500), name='I')
278 | network.add_monitor(Monitor(network.layers['I'], state_vars=['s']), 'I')
279 |
280 | # Generate spikes by running Bernoulli trials on Uniform(0, 0.5) samples.
281 | spikes = torch.bernoulli(0.5 * torch.rand(500, 500))
282 |
283 | # Run network simulation.
284 | network.run(inpts={'I' : spikes}, time=500)
285 |
286 | # Look at input spiking activity.
287 | spikes = network.monitors['I'].get('s')
288 | plt.matshow(spikes, cmap='binary')
289 | plt.xticks(()); plt.yticks(());
290 | plt.xlabel('Time'); plt.ylabel('Neuron index')
291 | plt.title('Input spiking')
292 | plt.show()
293 | """
294 | # Parse keyword arguments.
295 | clamps = kwargs.get("clamp", {})
296 | unclamps = kwargs.get("unclamp", {})
297 | masks = kwargs.get("masks", {})
298 | injects_v = kwargs.get("injects_v", {})
299 |
300 | # Compute reward.
301 | if self.reward_fn is not None:
302 | kwargs["reward"] = self.reward_fn.compute(**kwargs)
303 |
304 | # Dynamic setting of batch size.
305 | if inpts != {}:
306 | for key in inpts:
307 | # goal shape is [time, batch, n_0, ...]
308 | if len(inpts[key].size()) == 1:
309 | # current shape is [n_0, ...]
310 | # unsqueeze twice to make [1, 1, n_0, ...]
311 | inpts[key] = inpts[key].unsqueeze(0).unsqueeze(0)
312 | elif len(inpts[key].size()) == 2:
313 | # current shape is [time, n_0, ...]
314 | # unsqueeze dim 1 so that we have
315 | # [time, 1, n_0, ...]
316 | inpts[key] = inpts[key].unsqueeze(1)
317 |
318 | for key in inpts:
319 | # batch dimension is 1, grab this and use for batch size
320 | if inpts[key].size(1) != self.batch_size:
321 | self.batch_size = inpts[key].size(1)
322 |
323 | for l in self.layers:
324 | self.layers[l].set_batch_size(self.batch_size)
325 |
326 | for m in self.monitors:
327 | self.monitors[m].reset_()
328 |
329 | break
330 |
331 | # Effective number of timesteps.
332 | timesteps = int(time / self.dt)
333 |
334 | # Get input to all layers (synchronous mode).
335 | if not one_step:
336 | inpts.update(self._get_inputs())
337 |
338 |         # Simulate network activity for `time` timesteps.
339 |
340 | for t in range(timesteps):
341 | for l in self.layers:
342 | if isinstance(self.layers[l], AbstractInput):
343 | # shape is [time, batch, n_0, ...]
344 | self.layers[l].forward(x=inpts[l][t])
345 | else:
346 | if one_step:
347 | # Get input to this layer (one-step mode).
348 | inpts.update(self._get_inputs(layers=[l]))
349 | self.layers[l].forward(x=inpts[l])
350 |
351 | # Clamp neurons to spike.
352 | clamp = clamps.get(l, None)
353 | if clamp is not None:
354 | if clamp.ndimension() == 1:
355 | self.layers[l].s[:, clamp] = 1
356 | else:
357 | self.layers[l].s[:, clamp[t]] = 1
358 |
359 | # Clamp neurons not to spike.
360 | unclamp = unclamps.get(l, None)
361 | if unclamp is not None:
362 | if unclamp.ndimension() == 1:
363 | self.layers[l].s[unclamp] = 0
364 | else:
365 | self.layers[l].s[unclamp[t]] = 0
366 |
367 | # Inject voltage to neurons.
368 | inject_v = injects_v.get(l, None)
369 | if inject_v is not None:
370 | if inject_v.ndimension() == 1:
371 | self.layers[l].v += inject_v
372 | else:
373 | self.layers[l].v += inject_v[t]
374 |
375 | # Run synapse updates.
376 | for c in self.connections:
377 | self.connections[c].update(
378 | mask=masks.get(c, None), learning=self.learning, **kwargs
379 | )
380 |
381 | # Get input to all layers.
382 | inpts.update(self._get_inputs())
383 |
384 | # Record state variables of interest.
385 | for m in self.monitors:
386 | self.monitors[m].record()
387 | last_layer = list(self.layers.keys())[-1]
388 | output_voltages = self.layers[last_layer].summed
389 | prediction = torch.softmax(output_voltages, dim=1).argmax(dim=1)
390 |             # Count correct predictions for this batch at timestep `t` and log
391 |             # them in the per-timestep accuracy array, indexed by
392 |             # [timestep][batch step].
393 |             correct = (prediction.cpu() == labels).sum().item()
394 |             acc[t][step] = correct
395 | # Re-normalize connections.
396 | for c in self.connections:
397 | self.connections[c].normalize()
398 |
399 | def reset_(self) -> None:
400 | # language=rst
401 | """
402 | Reset state variables of objects in network.
403 | """
404 | for layer in self.layers:
405 | self.layers[layer].reset_()
406 |
407 | for connection in self.connections:
408 | self.connections[connection].reset_()
409 |
410 | for monitor in self.monitors:
411 | self.monitors[monitor].reset_()
412 |
413 | def train(self, mode: bool = True) -> "torch.nn.Module":
414 | # language=rst
415 | """Sets the node in training mode.
416 |
417 | :param mode: Turn training on or off.
418 |
419 | :return: ``self`` as specified in ``torch.nn.Module``.
420 | """
421 | self.learning = mode
422 | return super().train(mode)
423 |
--------------------------------------------------------------------------------
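
Note that ``run()`` above departs from upstream BindsNET: the extra ``acc``, ``step`` and ``labels`` arguments let the simulation log, for every timestep, how many samples of the current batch the final layer classifies correctly (``acc[t][step]``), which is what accuracy-versus-latency curves are built from. A hypothetical evaluation loop under that signature might look as follows; ``network``, ``test_loader`` and the input-layer name ``"Input"`` are assumptions, and the inputs are assumed to be pre-encoded with a leading time dimension.

```python
import numpy as np

timesteps = 100                                   # simulation length per test batch
acc = np.zeros((timesteps, len(test_loader)))     # correct counts per (timestep, batch)
total = 0

for step, (images, labels) in enumerate(test_loader):
    inpts = {"Input": images}                     # assumed shape: [time, batch, ...]
    network.run(inpts=inpts, time=timesteps, acc=acc, step=step, labels=labels)
    network.reset_()
    total += labels.numel()

# Accuracy as a function of inference latency (number of simulated timesteps).
per_timestep_accuracy = acc.sum(axis=1) / total
```
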
/bindsnet/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | from .environment_pipeline import EnvironmentPipeline
2 | from .base_pipeline import BasePipeline
3 | from .dataloader_pipeline import DataLoaderPipeline, TorchVisionDatasetPipeline
4 | from . import action
5 |
--------------------------------------------------------------------------------
/bindsnet/pipeline/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/pipeline/__pycache__/action.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/action.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/pipeline/__pycache__/base_pipeline.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/base_pipeline.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/pipeline/__pycache__/environment_pipeline.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/environment_pipeline.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/pipeline/action.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from . import EnvironmentPipeline
5 |
6 |
7 | def select_multinomial(pipeline: EnvironmentPipeline, **kwargs) -> int:
8 | # language=rst
9 | """
10 | Selects an action probabilistically based on spiking activity from a network layer.
11 |
12 | :param pipeline: EnvironmentPipeline with environment that has an integer action space.
13 | :return: Action sampled from multinomial over activity of similarly-sized output layer.
14 |
15 | Keyword arguments:
16 |
17 | :param str output: Name of output layer whose activity to base action selection on.
18 | """
19 | try:
20 | output = kwargs["output"]
21 | except KeyError:
22 | raise KeyError('select_multinomial() requires an "output" layer argument.')
23 |
24 | output = pipeline.network.layers[output]
25 | action_space = pipeline.env.action_space
26 |
27 | assert (
28 | output.n % action_space.n == 0
29 | ), f"Output layer size of {output.n} is not divisible by action space size of {action_space.n}."
30 |
31 | pop_size = int(output.n / action_space.n)
32 | spikes = output.s
33 | _sum = spikes.sum().float()
34 |
35 | # Choose action based on population's spiking.
36 | if _sum == 0:
37 | action = np.random.choice(pipeline.env.action_space.n)
38 | else:
39 | pop_spikes = torch.Tensor(
40 | [
41 | spikes[(i * pop_size) : (i * pop_size) + pop_size].sum()
42 | for i in range(action_space.n)
43 | ]
44 | )
45 | action = torch.multinomial((pop_spikes.float() / _sum).view(-1), 1)[0].item()
46 |
47 | return action
48 |
49 |
50 | def select_softmax(pipeline: EnvironmentPipeline, **kwargs) -> int:
51 | # language=rst
52 | """
53 | Selects an action using softmax function based on spiking from a network layer.
54 |
55 | :param pipeline: EnvironmentPipeline with environment that has an integer action space.
56 | :return: Action sampled from softmax over activity of similarly-sized output layer.
57 |
58 | Keyword arguments:
59 |
60 | :param str output: Name of output layer whose activity to base action selection on.
61 | """
62 | try:
63 | output = kwargs["output"]
64 | except KeyError:
65 | raise KeyError('select_softmax() requires an "output" layer argument.')
66 |
67 | assert (
68 | pipeline.network.layers[output].n == pipeline.env.action_space.n
69 | ), "Output layer size is not equal to the size of the action space."
70 |
71 | assert hasattr(
72 | pipeline, "spike_record"
73 | ), "EnvironmentPipeline is missing the attribute: spike_record."
74 |
75 | # Sum of previous iterations' spikes (Not yet implemented)
76 | spikes = torch.sum(pipeline.spike_record[output], dim=1)
77 | _sum = torch.sum(torch.exp(spikes.float()))
78 |
79 | if _sum == 0:
80 | action = np.random.choice(pipeline.env.action_space.n)
81 | else:
82 | action = torch.multinomial((torch.exp(spikes.float()) / _sum).view(-1), 1)[0]
83 |
84 | return action
85 |
86 |
87 | def select_random(pipeline: EnvironmentPipeline, **kwargs) -> int:
88 | # language=rst
89 | """
90 | Selects an action randomly from the action space.
91 |
92 | :param pipeline: EnvironmentPipeline with environment that has an integer action space.
93 | :return: Action randomly sampled over size of pipeline's action space.
94 | """
95 | # Choose action randomly from the action space.
96 | return np.random.choice(pipeline.env.action_space.n)
97 |
--------------------------------------------------------------------------------
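
These selection functions are meant to be passed as the ``action_function`` of an ``EnvironmentPipeline`` (defined in ``environment_pipeline.py`` below), which calls them each environment step as ``action_function(pipeline, output=...)``. A sketch, assuming ``network`` and ``env`` already exist and ``"Output"`` is a placeholder layer name whose size is a multiple of the action-space size:

```python
from bindsnet.pipeline import EnvironmentPipeline
from bindsnet.pipeline.action import select_multinomial

pipeline = EnvironmentPipeline(
    network,
    env,
    action_function=select_multinomial,  # invoked as select_multinomial(pipeline, output="Output")
    output="Output",
    num_episodes=100,
    time=250,
)
pipeline.train()
```
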
/bindsnet/pipeline/base_pipeline.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Tuple, Dict, Any
3 |
4 | import torch
5 | from torch._six import container_abcs, string_classes
6 |
7 | from ..network import Network
8 | from ..network.monitors import Monitor
9 |
10 |
11 | def recursive_to(item, device):
12 | # language=rst
13 | """
14 | Recursively transfers everything contained in item to the target
15 | device.
16 |
17 | :param item: An individual tensor or container of tensors.
18 | :param device: ``torch.device`` pointing to ``"cuda"`` or ``"cpu"``.
19 |
20 | :return: A version of the item that has been sent to a device.
21 | """
22 |
23 | if isinstance(item, torch.Tensor):
24 | return item.to(device)
25 | elif isinstance(item, (string_classes, int, float, bool)):
26 | return item
27 | elif isinstance(item, container_abcs.Mapping):
28 | return {key: recursive_to(item[key], device) for key in item}
29 | elif isinstance(item, tuple) and hasattr(item, "_fields"):
30 | return type(item)(*(recursive_to(i, device) for i in item))
31 | elif isinstance(item, container_abcs.Sequence):
32 | return [recursive_to(i, device) for i in item]
33 | else:
34 | raise NotImplementedError(f"Target type {type(item)} not supported.")
35 |
36 |
37 | class BasePipeline:
38 | # language=rst
39 | """
40 | A generic pipeline that handles high level functionality.
41 | """
42 |
43 | def __init__(self, network: Network, **kwargs) -> None:
44 | # language=rst
45 | """
46 | Initializes the pipeline.
47 |
48 | :param network: Arbitrary network object, will be managed by the ``BasePipeline`` class.
49 |
50 | Keyword arguments:
51 |
52 | :param int save_interval: How often to save the network to disk.
53 | :param str save_dir: Directory to save network object to.
54 | :param Dict[str, Any] plot_config: Dict containing the plot configuration. Includes length,
55 | type (``"color"`` or ``"line"``), and interval per plot type.
56 | :param int print_interval: Interval to print text output.
57 | :param bool allow_gpu: Allows automatic transfer to the GPU.
58 | """
59 | self.network = network
60 |
61 | # Network saving handles caching of intermediate results.
62 | self.save_dir = kwargs.get("save_dir", "network.pt")
63 | self.save_interval = kwargs.get("save_interval", None)
64 |
65 | # Handles plotting of all layer spikes and voltages.
66 | # This constructs monitors at every level.
67 | self.plot_config = kwargs.get(
68 | "plot_config", {"data_step": None, "data_length": 10}
69 | )
70 |
71 | if self.plot_config["data_step"] is not None:
72 | for l in self.network.layers:
73 | self.network.add_monitor(
74 | Monitor(
75 | self.network.layers[l], "s", self.plot_config["data_length"]
76 | ),
77 | name=f"{l}_spikes",
78 | )
79 | if hasattr(self.network.layers[l], "v"):
80 | self.network.add_monitor(
81 | Monitor(
82 | self.network.layers[l], "v", self.plot_config["data_length"]
83 | ),
84 | name=f"{l}_voltages",
85 | )
86 |
87 | self.print_interval = kwargs.get("print_interval", None)
88 |
89 | self.test_interval = kwargs.get("test_interval", None)
90 |
91 | self.step_count = 0
92 |
93 | self.init_fn()
94 |
95 | self.clock = time.time()
96 |
97 | self.allow_gpu = kwargs.get("allow_gpu", True)
98 |
99 | if torch.cuda.is_available() and self.allow_gpu:
100 | self.device = torch.device("cuda")
101 | else:
102 | self.device = torch.device("cpu")
103 |
104 | self.network.to(self.device)
105 |
106 | def reset_(self) -> None:
107 | # language=rst
108 | """
109 | Reset the pipeline.
110 | """
111 |
112 | self.network.reset_()
113 | self.step_count = 0
114 |
115 | def step(self, batch: Any, **kwargs) -> Any:
116 | # language=rst
117 | """
118 | Single step of any pipeline at a high level.
119 |
120 | :param batch: A batch of inputs to be handed to the ``step_()`` function.
121 | Standard in subclasses of ``BasePipeline``.
122 |
123 | :return: The output from the subclass's ``step_()`` method, which could be anything.
124 | Passed to plotting to accommodate this.
125 | """
126 | self.step_count += 1
127 |
128 | batch = recursive_to(batch, self.device)
129 |
130 | step_out = self.step_(batch, **kwargs)
131 |
132 | if (
133 | self.print_interval is not None
134 | and self.step_count % self.print_interval == 0
135 | ):
136 | print(
137 | f"Iteration: {self.step_count} (Time: {time.time() - self.clock:.4f})"
138 | )
139 | self.clock = time.time()
140 |
141 | # if self.plot_interval is not None and self.step_count % self.plot_interval == 0:
142 | self.plots(batch, step_out)
143 |
144 | if self.save_interval is not None and self.step_count % self.save_interval == 0:
145 | self.network.save(self.save_dir)
146 |
147 | if self.test_interval is not None and self.step_count % self.test_interval == 0:
148 | self.test()
149 |
150 | return step_out
151 |
152 | def get_spike_data(self) -> Dict[str, torch.Tensor]:
153 | # language=rst
154 | """
155 | Get the spike data from all layers in the pipeline's network.
156 |
157 | :return: A dictionary containing all spike monitors from the network.
158 | """
159 | return {
160 | l: self.network.monitors[f"{l}_spikes"].get("s")
161 | for l in self.network.layers
162 | }
163 |
164 | def get_voltage_data(
165 | self
166 | ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
167 | # language=rst
168 | """
169 | Get the voltage data and threshold value from all applicable layers in the pipeline's network.
170 |
171 | :return: Two dictionaries containing the voltage data and threshold values from the network.
172 | """
173 | voltage_record = {}
174 | threshold_value = {}
175 | for l in self.network.layers:
176 | if hasattr(self.network.layers[l], "v"):
177 | voltage_record[l] = self.network.monitors[f"{l}_voltages"].get("v")
178 | if hasattr(self.network.layers[l], "thresh"):
179 | threshold_value[l] = self.network.layers[l].thresh
180 |
181 | return voltage_record, threshold_value
182 |
183 | def step_(self, batch: Any, **kwargs) -> Any:
184 | # language=rst
185 | """
186 | Perform a pass of the network given the input batch.
187 |
188 | :param batch: The current batch. This could be anything as long as
189 | the subclass agrees upon the format in some way.
190 |
191 | :return: Any output that is need for recording purposes.
192 | """
193 | raise NotImplementedError("You need to provide a step_ method.")
194 |
195 | def train(self) -> None:
196 | # language=rst
197 | """
198 | A fully self-contained training loop.
199 | """
200 | raise NotImplementedError("You need to provide a train method.")
201 |
202 | def test(self) -> None:
203 | # language=rst
204 | """
205 | A fully self contained test function.
206 | """
207 | raise NotImplementedError("You need to provide a test method.")
208 |
209 | def init_fn(self) -> None:
210 | # language=rst
211 | """
212 | Placeholder function for subclass-specific actions that need to
213 | happen during the construction of the ``BasePipeline``.
214 | """
215 | raise NotImplementedError("You need to provide an init_fn method.")
216 |
217 | def plots(self, batch: Any, step_out: Any) -> None:
218 | # language=rst
219 | """
220 | Create any plots and logs for a step given the input batch and step output.
221 |
222 | :param batch: The current batch. This could be anything as long as
223 | the subclass agrees upon the format in some way.
224 | :param step_out: The output from the ``step_()`` method.
225 | """
226 | raise NotImplementedError("You need to provide a plots method.")
227 |
--------------------------------------------------------------------------------
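
Every method above that raises ``NotImplementedError`` is a hook a concrete pipeline must supply. A minimal skeleton (not from the repository) is sketched below, with no-op bodies standing in for real logic:

```python
from typing import Any

from bindsnet.pipeline.base_pipeline import BasePipeline


class SkeletonPipeline(BasePipeline):
    def init_fn(self) -> None:
        # Called at the end of BasePipeline.__init__ for subclass-specific setup.
        pass

    def step_(self, batch: Any, **kwargs) -> Any:
        # Run self.network on one batch; the return value is handed to plots().
        return None

    def plots(self, batch: Any, step_out: Any) -> None:
        # Plotting/logging hook, invoked by BasePipeline.step() every iteration.
        pass

    def train(self) -> None:
        pass

    def test(self) -> None:
        pass
```
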
/bindsnet/pipeline/dataloader_pipeline.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 |
3 | import torch
4 | from torch.utils.data import Dataset
5 | from tqdm import tqdm
6 |
7 | from ..network import Network
8 | from .base_pipeline import BasePipeline
9 | from ..analysis.pipeline_analysis import PipelineAnalyzer
10 | from ..datasets import DataLoader
11 |
12 |
13 | class DataLoaderPipeline(BasePipeline):
14 | # language=rst
15 | """
16 | A generic ``DataLoader`` pipeline that leverages the ``torch.utils.data``
17 |     setup. This still needs to be subclassed with dataset-specific
18 |     implementations of its methods, depending on the dataset that will be used.
19 | An example can be seen in ``TorchVisionDatasetPipeline``.
20 | """
21 |
22 | def __init__(
23 | self,
24 | network: Network,
25 | train_ds: Dataset,
26 | test_ds: Optional[Dataset] = None,
27 | **kwargs
28 | ) -> None:
29 | # language=rst
30 | """
31 | Initializes the pipeline.
32 |
33 | :param network: Arbitrary ``network`` object.
34 | :param train_ds: Arbitrary ``torch.utils.data.Dataset`` object.
35 | :param test_ds: Arbitrary ``torch.utils.data.Dataset`` object.
36 | """
37 | super().__init__(network, **kwargs)
38 |
39 | self.train_ds = train_ds
40 | self.test_ds = test_ds
41 |
42 | self.num_epochs = kwargs.get("num_epochs", 10)
43 | self.batch_size = kwargs.get("batch_size", 1)
44 | self.num_workers = kwargs.get("num_workers", 0)
45 | self.pin_memory = kwargs.get("pin_memory", True)
46 | self.shuffle = kwargs.get("shuffle", True)
47 |
48 | def train(self) -> None:
49 | # language=rst
50 | """
51 | Training loop that runs for the set number of epochs and creates
52 | a new ``DataLoader`` at each epoch.
53 | """
54 | for epoch in range(self.num_epochs):
55 | train_dataloader = DataLoader(
56 | self.train_ds,
57 | batch_size=self.batch_size,
58 | num_workers=self.num_workers,
59 | pin_memory=self.pin_memory,
60 | shuffle=self.shuffle,
61 | )
62 |
63 | for step, batch in enumerate(
64 | tqdm(
65 | train_dataloader,
66 | desc="Epoch %d/%d" % (epoch + 1, self.num_epochs),
67 | total=len(self.train_ds) // self.batch_size,
68 | )
69 | ):
70 | self.step(batch)
71 |
72 | def test(self) -> None:
73 | raise NotImplementedError("You need to provide a test function.")
74 |
75 |
76 | class TorchVisionDatasetPipeline(DataLoaderPipeline):
77 | # language=rst
78 | """
79 | An example implementation of ``DataLoaderPipeline`` that runs all of the
80 |     datasets inside of ``bindsnet.datasets`` that wrap a
81 |     ``torchvision.datasets`` class. These are documented in
82 | ``bindsnet/datasets/README.md``. This specific class just runs an
83 | unsupervised network.
84 | """
85 |
86 | def __init__(
87 | self,
88 | network: Network,
89 | train_ds: Dataset,
90 | pipeline_analyzer: Optional[PipelineAnalyzer] = None,
91 | **kwargs
92 | ) -> None:
93 | # language=rst
94 | """
95 | Initializes the pipeline.
96 |
97 | :param network: Arbitrary ``network`` object.
98 | :param train_ds: A ``torchvision.datasets`` wrapper dataset from ``bindsnet.datasets``.
99 |
100 | Keyword arguments:
101 |
102 | :param str input_layer: Layer of the network that receives input.
103 | """
104 | super().__init__(network, train_ds, None, **kwargs)
105 |
106 | self.input_layer = kwargs.get("input_layer", "X")
107 | self.pipeline_analyzer = pipeline_analyzer
108 |
109 | def step_(self, batch: Dict[str, torch.Tensor], **kwargs) -> None:
110 | # language=rst
111 | """
112 |         Perform a pass of the network given the input batch. Training is unsupervised
113 |         (everything is stored inside the ``network`` object), so this returns ``None``.
114 |
115 | :param batch: A dictionary of the current batch. Includes image,
116 | label and encoded versions.
117 | """
118 | self.network.reset_()
119 | inpts = {self.input_layer: batch["encoded_image"]}
120 | self.network.run(inpts, time=batch["encoded_image"].shape[0])
121 |
122 | def init_fn(self) -> None:
123 | pass
124 |
125 | def plots(self, batch: Dict[str, torch.Tensor], *args) -> None:
126 | # language=rst
127 | """
128 | Create any plots and logs for a step given the input batch.
129 |
130 | :param batch: A dictionary of the current batch. Includes image,
131 | label and encoded versions.
132 | """
133 | if self.pipeline_analyzer is not None:
134 | self.pipeline_analyzer.plot_obs(
135 | batch["encoded_image"][0, ...].sum(0), step=self.step_count
136 | )
137 |
138 | self.pipeline_analyzer.plot_spikes(
139 | self.get_spike_data(), step=self.step_count
140 | )
141 |
142 | vr, tv = self.get_voltage_data()
143 | self.pipeline_analyzer.plot_voltages(vr, tv, step=self.step_count)
144 |
145 | self.pipeline_analyzer.finalize_step()
146 |
147 | def test_step(self):
148 | pass
149 |
--------------------------------------------------------------------------------
/bindsnet/pipeline/environment_pipeline.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import Callable, Optional, Tuple, Dict
3 |
4 | import torch
5 |
6 | from .base_pipeline import BasePipeline
7 | from ..analysis.pipeline_analysis import MatplotlibAnalyzer
8 | from ..environment import Environment
9 | from ..network import Network
10 | from ..network.nodes import AbstractInput
11 |
12 |
13 | class EnvironmentPipeline(BasePipeline):
14 | # language=rst
15 | """
16 | Abstracts the interaction between ``Network``, ``Environment`` and environment feedback action.
17 | """
18 |
19 | def __init__(
20 | self,
21 | network: Network,
22 | environment: Environment,
23 | action_function: Optional[Callable] = None,
24 | **kwargs,
25 | ):
26 | # language=rst
27 | """
28 | Initializes the pipeline.
29 |
30 | :param network: Arbitrary network object.
31 | :param environment: Arbitrary environment.
32 | :param action_function: Function to convert network outputs into environment inputs.
33 |
34 | Keyword arguments:
35 |
36 | :param int num_episodes: Number of episodes to train for. Defaults to 100.
37 | :param str output: String name of the layer from which to take output.
38 | :param int render_interval: Interval to render the environment.
39 | :param int reward_delay: How many iterations to delay delivery of reward.
40 | :param int time: Time for which to run the network. Defaults to the network's timestep.
41 | """
42 | super().__init__(network, **kwargs)
43 |
44 | self.episode = 0
45 |
46 | self.env = environment
47 | self.action_function = action_function
48 |
49 | self.accumulated_reward = 0.0
50 | self.reward_list = []
51 |
52 | # Setting kwargs.
53 | self.num_episodes = kwargs.get("num_episodes", 100)
54 | self.output = kwargs.get("output", None)
55 | self.render_interval = kwargs.get("render_interval", None)
56 | self.reward_delay = kwargs.get("reward_delay", None)
57 | self.time = kwargs.get("time", int(network.dt))
58 |
59 | if self.reward_delay is not None:
60 | assert self.reward_delay > 0
61 | self.rewards = torch.zeros(self.reward_delay)
62 |
63 | # Set up for multiple layers of input layers.
64 | self.inpts = [
65 | name
66 | for name, layer in network.layers.items()
67 | if isinstance(layer, AbstractInput)
68 | ]
69 |
70 | self.action = None
71 |
72 | self.voltage_record = None
73 | self.threshold_value = None
74 | self.reward_plot = None
75 |
76 | self.first = True
77 | self.analyzer = MatplotlibAnalyzer(**self.plot_config)
78 |
79 | def init_fn(self) -> None:
80 | pass
81 |
82 | def train(self, **kwargs) -> None:
83 | # language=rst
84 | """
85 | Trains for the specified number of episodes. Each episode can be of arbitrary length.
86 | """
87 | while self.episode < self.num_episodes:
88 | self.reset_()
89 |
90 | for _ in itertools.count():
91 | obs, reward, done, info = self.env_step()
92 |
93 | self.step((obs, reward, done, info), **kwargs)
94 |
95 | if done:
96 | break
97 |
98 | print(
99 | f"Episode: {self.episode} - accumulated reward: {self.accumulated_reward:.2f}"
100 | )
101 | self.episode += 1
102 |
103 | def env_step(self) -> Tuple[torch.Tensor, float, bool, Dict]:
104 | # language=rst
105 | """
106 | Single step of the environment which includes rendering, getting and performing the action,
107 | and accumulating/delaying rewards.
108 |
109 | :return: An OpenAI ``gym`` compatible tuple with modified reward and info.
110 | """
111 | # Render game.
112 | if (
113 | self.render_interval is not None
114 | and self.step_count % self.render_interval == 0
115 | ):
116 | self.env.render()
117 |
118 | # Choose action based on output neuron spiking.
119 | if self.action_function is not None:
120 | self.action = self.action_function(self, output=self.output)
121 |
122 | # Run a step of the environment.
123 | obs, reward, done, info = self.env.step(self.action)
124 |
125 | # Set reward in case of delay.
126 | if self.reward_delay is not None:
127 | self.rewards = torch.tensor([reward, *self.rewards[1:]]).float()
128 | reward = self.rewards[-1]
129 |
130 | # Accumulate reward.
131 | self.accumulated_reward += reward
132 |
133 | info["accumulated_reward"] = self.accumulated_reward
134 |
135 | return obs, reward, done, info
136 |
137 | def step_(
138 | self, gym_batch: Tuple[torch.Tensor, float, bool, Dict], **kwargs
139 | ) -> None:
140 | # language=rst
141 | """
142 |         Run a single iteration of the network on the current observation and, when the
143 |         episode is done, update the reward function and the reward list.
144 |
145 | :param gym_batch: An OpenAI ``gym`` compatible tuple.
146 | """
147 | obs, reward, done, info = gym_batch
148 |
149 | # Place the observations into the inputs.
150 | inpts = {k: obs for k in self.inpts}
151 |
152 | # Run the network on the spike train-encoded inputs.
153 | self.network.run(
154 | inpts=inpts, time=self.time, reward=reward, input_time_dim=1, **kwargs
155 | )
156 |
157 | if done:
158 | if self.network.reward_fn is not None:
159 | self.network.reward_fn.update(
160 | accumulated_reward=self.accumulated_reward,
161 | steps=self.step_count,
162 | **kwargs,
163 | )
164 | self.reward_list.append(self.accumulated_reward)
165 |
166 | def reset_(self) -> None:
167 | # language=rst
168 | """
169 | Reset the pipeline.
170 | """
171 | self.env.reset()
172 | self.network.reset_()
173 | self.accumulated_reward = 0.0
174 | self.step_count = 0
175 |
176 | def plots(self, gym_batch: Tuple[torch.Tensor, float, bool, Dict], *args) -> None:
177 | # language=rst
178 | """
179 | Plot the encoded input, layer spikes, and layer voltages.
180 |
181 | :param gym_batch: An OpenAI ``gym`` compatible tuple.
182 | """
183 | obs, reward, done, info = gym_batch
184 |
185 | for key, item in self.plot_config.items():
186 | if key == "obs_step" and item is not None:
187 | if self.step_count % item == 0:
188 | self.analyzer.plot_obs(obs[0, ...].sum(0))
189 | elif key == "data_step" and item is not None:
190 | if self.step_count % item == 0:
191 | self.analyzer.plot_spikes(self.get_spike_data())
192 | self.analyzer.plot_voltages(*self.get_voltage_data())
193 | elif key == "reward_eps" and item is not None:
194 | if self.episode % item == 0 and done:
195 | self.analyzer.plot_reward(self.reward_list)
196 |
197 | self.analyzer.finalize_step()
198 |
--------------------------------------------------------------------------------
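``EnvironmentPipeline`` (above) consumes a ``Network`` and an ``Environment``: each step it optionally renders, converts output-layer spikes into an action via ``action_function``, steps the environment, and accumulates (optionally delayed) reward. A minimal usage sketch follows, assuming the ``GymEnvironment`` wrapper and ``select_softmax`` action function that ship with BindsNET; the layer sizes, layer names, and environment ID are illustrative only.

# Hedged sketch: wiring a small network into EnvironmentPipeline.
# GymEnvironment and select_softmax are assumed BindsNET helpers
# (bindsnet.environment / bindsnet.pipeline.action); sizes and names are illustrative.
import torch
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.environment import GymEnvironment
from bindsnet.pipeline import EnvironmentPipeline
from bindsnet.pipeline.action import select_softmax

network = Network(dt=1.0)
network.add_layer(Input(n=6400), name="Input")    # must match the encoded observation size
network.add_layer(LIFNodes(n=4), name="Output")   # one neuron per discrete action
network.add_connection(
    Connection(network.layers["Input"], network.layers["Output"], w=0.01 * torch.rand(6400, 4)),
    source="Input",
    target="Output",
)

environment = GymEnvironment("BreakoutDeterministic-v4")  # illustrative environment ID

pipeline = EnvironmentPipeline(
    network,
    environment,
    action_function=select_softmax,
    output="Output",      # layer whose spikes drive the action
    num_episodes=10,      # kwargs consumed in __init__ above
    reward_delay=None,    # set to an int > 0 to delay reward delivery by that many steps
    time=100,             # simulation time per environment step
)
pipeline.train()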
/bindsnet/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .preprocessing import AbstractPreprocessor
2 |
--------------------------------------------------------------------------------
/bindsnet/preprocessing/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/preprocessing/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/preprocessing/__pycache__/preprocessing.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/preprocessing/__pycache__/preprocessing.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/preprocessing/preprocessing.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import pickle
4 | import torch
5 |
6 | from abc import abstractmethod, ABC
7 |
8 |
9 | class AbstractPreprocessor(ABC):
10 | # language=rst
11 | """
12 | Abstract base class for Preprocessor.
13 | """
14 |
15 | def process(
16 | self,
17 | csvfile: str,
18 | use_cache: bool = True,
19 | cachedfile: str = "./processed/data.pt",
20 | ) -> torch.tensor:
21 | # cache dictionary for storing encodings if previously encoded
22 | cache = {"verify": "", "data": None}
23 |
24 | # if the file exists
25 | if use_cache:
26 | # generate a hash
27 | cache["verify"] = self.__gen_hash(csvfile)
28 |
29 | # compare hash, if valid return cached value
30 | if self.__check_file(cachedfile, cache):
31 | return cache["data"]
32 |
33 | # otherwise process the data
34 | self._process(csvfile, cache)
35 |
36 | # save if use_cache
37 | if use_cache:
38 | self.__save(cachedfile, cache)
39 |
40 | # return data
41 | return cache["data"]
42 |
43 | @abstractmethod
44 | def _process(self, filename: str, cache: dict):
45 | # language=rst
46 | """
47 | Method for defining how to preprocess the data.
48 | :param filename: file to load raw data from
49 |         :param cache: dict used for caching; its 'data' entry needs to be updated for caching to work
50 | """
51 | pass
52 |
53 | def __gen_hash(self, filename: str) -> str:
54 | # language=rst
55 | """
56 |         Generates a hash for a CSV file and the preprocessor name.
57 | :param filename: file to generate hash for
58 | :return: hash for the csv file
59 | """
60 | # read all the lines
61 | with open(filename, "r") as f:
62 | lines = f.readlines()
63 | # generate md5 hash after concatenating all of the lines
64 | pre = "".join(lines) + str(self.__class__.__name__)
65 | m = hashlib.md5(pre.encode("utf-8"))
66 | return m.hexdigest()
67 |
68 | @staticmethod
69 | def __check_file(cachedfile: str, cache: dict) -> bool:
70 | # language=rst
71 | """
72 | Compares the csv file and the saved file to see if a new encoding needs to be generated.
73 | :param cachedfile: the filename of the cached data
74 | :param cache: dict containing the current csvfile hash. This is updated if the cachefile has valid data
75 | :return: whether the cache is valid
76 |
77 | """
78 | # try opening the cached file
79 | try:
80 | with open(cachedfile, "rb") as f:
81 | temp = pickle.load(f)
82 | except FileNotFoundError:
83 | temp = {"verify": "", "data": None}
84 |
85 | # if the hash matches up, keep the data from the cache
86 | if cache["verify"] == temp["verify"]:
87 | cache["data"] = temp["data"]
88 | return True
89 |
90 | # otherwise don't do anything
91 | return False
92 |
93 | @staticmethod
94 | def __save(filename: str, data: dict) -> None:
95 | # language=rst
96 | """
97 |         Creates or overwrites an existing encoding file with the pickled cache dict.
98 |         :param filename: filename to save to; ``data`` is the cache dict to pickle
99 | """
100 | # if the directories in path don't exist create them
101 | if not os.path.exists(os.path.dirname(filename)):
102 | os.makedirs(os.path.dirname(filename), exist_ok=True)
103 |
104 | # save file
105 | with open(filename, "wb") as f:
106 | pickle.dump(data, f)
107 |
--------------------------------------------------------------------------------
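``AbstractPreprocessor.process`` above handles hashing and caching; a concrete subclass only has to implement ``_process`` and write its result into ``cache['data']``. A minimal sketch, assuming an invented one-sample-per-line CSV of comma-separated floats:

# Hypothetical subclass of AbstractPreprocessor; the CSV layout is assumed purely
# to show how _process must populate cache["data"].
import torch
from bindsnet.preprocessing import AbstractPreprocessor


class CSVPreprocessor(AbstractPreprocessor):
    def _process(self, filename: str, cache: dict):
        rows = []
        with open(filename, "r") as f:
            for line in f:
                rows.append([float(x) for x in line.strip().split(",")])
        # process() caches and returns whatever ends up under "data".
        cache["data"] = torch.tensor(rows, dtype=torch.float32)


# Repeated calls reuse ./processed/data.pt until the CSV contents change, because
# process() keys the cache on an md5 hash of the file contents plus the class name.
# data = CSVPreprocessor().process("samples.csv", use_cache=True)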
/bindsnet/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 |
5 | from torch import Tensor
6 | import torch.nn.functional as F
7 | from numpy import ndarray
8 | from typing import Tuple, Union
9 | from torch.nn.modules.utils import _pair
10 |
11 |
12 | def im2col_indices(
13 | x: Tensor,
14 | kernel_height: int,
15 | kernel_width: int,
16 | padding: Tuple[int, int] = (0, 0),
17 | stride: Tuple[int, int] = (1, 1),
18 | ) -> Tensor:
19 | # language=rst
20 | """
21 |     im2col is a special case of unfold, which is implemented inside of PyTorch.
22 |
23 | :param x: Input image tensor to be reshaped to column-wise format.
24 | :param kernel_height: Height of the convolutional kernel in pixels.
25 | :param kernel_width: Width of the convolutional kernel in pixels.
26 | :param padding: Amount of zero padding on the input image.
27 | :param stride: Amount to stride over image by per convolution.
28 | :return: Input tensor reshaped to column-wise format.
29 | """
30 | return F.unfold(x, (kernel_height, kernel_width), padding=padding, stride=stride)
31 |
32 |
33 | def col2im_indices(
34 | cols: Tensor,
35 | x_shape: Tuple[int, int, int, int],
36 | kernel_height: int,
37 | kernel_width: int,
38 | padding: Tuple[int, int] = (0, 0),
39 | stride: Tuple[int, int] = (1, 1),
40 | ) -> Tensor:
41 | # language=rst
42 | """
43 |     col2im is a special case of fold, which is implemented inside of PyTorch.
44 |
45 | :param cols: Image tensor in column-wise format.
46 | :param x_shape: Shape of original image tensor.
47 | :param kernel_height: Height of the convolutional kernel in pixels.
48 | :param kernel_width: Width of the convolutional kernel in pixels.
49 | :param padding: Amount of zero padding on the input image.
50 | :param stride: Amount to stride over image by per convolution.
51 | :return: Image tensor in original image shape.
52 | """
53 | return F.fold(
54 | cols, x_shape, (kernel_height, kernel_width), padding=padding, stride=stride
55 | )
56 |
57 |
58 | def get_square_weights(
59 | weights: Tensor, n_sqrt: int, side: Union[int, Tuple[int, int]]
60 | ) -> Tensor:
61 | # language=rst
62 | """
63 |     Return a grid of ``n_sqrt ** 2`` filters, each with side length(s) ``side``.
64 |
65 | :param weights: Two-dimensional tensor of weights for two-dimensional data.
66 | :param n_sqrt: Square root of no. of filters.
67 | :param side: Side length(s) of filter.
68 | :return: Reshaped weights to square matrix of filters.
69 | """
70 | if isinstance(side, int):
71 | side = (side, side)
72 |
73 | square_weights = torch.zeros(side[0] * n_sqrt, side[1] * n_sqrt)
74 | for i in range(n_sqrt):
75 | for j in range(n_sqrt):
76 | n = i * n_sqrt + j
77 |
78 | if not n < weights.size(1):
79 | break
80 |
81 | x = i * side[0]
82 | y = (j % n_sqrt) * side[1]
83 | filter_ = weights[:, n].contiguous().view(*side)
84 | square_weights[x : x + side[0], y : y + side[1]] = filter_
85 |
86 | return square_weights
87 |
88 |
89 | def get_square_assignments(assignments: Tensor, n_sqrt: int) -> Tensor:
90 | # language=rst
91 | """
92 | Return a grid of assignments.
93 |
94 | :param assignments: Vector of integers corresponding to class labels.
95 | :param n_sqrt: Square root of no. of assignments.
96 | :return: Reshaped square matrix of assignments.
97 | """
98 | square_assignments = torch.mul(torch.ones(n_sqrt, n_sqrt), -1.0)
99 | for i in range(n_sqrt):
100 | for j in range(n_sqrt):
101 | n = i * n_sqrt + j
102 |
103 | if not n < assignments.size(0):
104 | break
105 |
106 | square_assignments[
107 | i : (i + 1), (j % n_sqrt) : ((j % n_sqrt) + 1)
108 | ] = assignments[n]
109 |
110 | return square_assignments
111 |
112 |
113 | def reshape_locally_connected_weights(
114 | w: Tensor,
115 | n_filters: int,
116 | kernel_size: Union[int, Tuple[int, int]],
117 | conv_size: Union[int, Tuple[int, int]],
118 | locations: Tensor,
119 | input_sqrt: Union[int, Tuple[int, int]],
120 | ) -> Tensor:
121 | # language=rst
122 | """
123 | Get the weights from a locally connected layer and reshape them to be two-dimensional and square.
124 |
125 | :param w: Weights from a locally connected layer.
126 | :param n_filters: No. of neuron filters.
127 | :param kernel_size: Side length(s) of convolutional kernel.
128 | :param conv_size: Side length(s) of convolution population.
129 | :param locations: Binary mask indicating receptive fields of convolution population neurons.
130 |     :param input_sqrt: Side length(s) of the input neuron population.
131 | :return: Locally connected weights reshaped as a collection of spatially ordered square grids.
132 | """
133 | kernel_size = _pair(kernel_size)
134 | conv_size = _pair(conv_size)
135 | input_sqrt = _pair(input_sqrt)
136 |
137 | k1, k2 = kernel_size
138 | c1, c2 = conv_size
139 | i1, i2 = input_sqrt
140 | c1sqrt, c2sqrt = int(math.ceil(math.sqrt(c1))), int(math.ceil(math.sqrt(c2)))
141 | fs = int(math.ceil(math.sqrt(n_filters)))
142 |
143 | w_ = torch.zeros((n_filters * k1, k2 * c1 * c2))
144 |
145 | for n1 in range(c1):
146 | for n2 in range(c2):
147 | for feature in range(n_filters):
148 | n = n1 * c2 + n2
149 | filter_ = w[
150 | locations[:, n],
151 | feature * (c1 * c2) + (n // c2sqrt) * c2sqrt + (n % c2sqrt),
152 | ].view(k1, k2)
153 | w_[feature * k1 : (feature + 1) * k1, n * k2 : (n + 1) * k2] = filter_
154 |
155 | if c1 == 1 and c2 == 1:
156 | square = torch.zeros((i1 * fs, i2 * fs))
157 |
158 | for n in range(n_filters):
159 | square[
160 | (n // fs) * i1 : ((n // fs) + 1) * i2,
161 | (n % fs) * i2 : ((n % fs) + 1) * i2,
162 | ] = w_[n * i1 : (n + 1) * i2]
163 |
164 | return square
165 | else:
166 | square = torch.zeros((k1 * fs * c1, k2 * fs * c2))
167 |
168 | for n1 in range(c1):
169 | for n2 in range(c2):
170 | for f1 in range(fs):
171 | for f2 in range(fs):
172 | if f1 * fs + f2 < n_filters:
173 | square[
174 | k1 * (n1 * fs + f1) : k1 * (n1 * fs + f1 + 1),
175 | k2 * (n2 * fs + f2) : k2 * (n2 * fs + f2 + 1),
176 | ] = w_[
177 | (f1 * fs + f2) * k1 : (f1 * fs + f2 + 1) * k1,
178 | (n1 * c2 + n2) * k2 : (n1 * c2 + n2 + 1) * k2,
179 | ]
180 |
181 | return square
182 |
183 |
184 | def reshape_conv2d_weights(weights: torch.Tensor) -> torch.Tensor:
185 | # language=rst
186 | """
187 |     Flattens a connection weight matrix of a Conv2dConnection into a two-dimensional grid of filters.
188 | 
189 |     :param weights: Weight matrix of Conv2dConnection object.
190 |     :return: Two-dimensional tensor with each filter laid out on a square grid.
191 | 
192 | """
193 | sqrt1 = int(np.ceil(np.sqrt(weights.size(0))))
194 | sqrt2 = int(np.ceil(np.sqrt(weights.size(1))))
195 | height, width = weights.size(2), weights.size(3)
196 | reshaped = torch.zeros(
197 | sqrt1 * sqrt2 * weights.size(2), sqrt1 * sqrt2 * weights.size(3)
198 | )
199 |
200 | for i in range(sqrt1):
201 | for j in range(sqrt1):
202 | for k in range(sqrt2):
203 | for l in range(sqrt2):
204 | if i * sqrt1 + j < weights.size(0) and k * sqrt2 + l < weights.size(
205 | 1
206 | ):
207 | fltr = weights[i * sqrt1 + j, k * sqrt2 + l].view(height, width)
208 | reshaped[
209 | i * height
210 | + k * height * sqrt1 : (i + 1) * height
211 | + k * height * sqrt1,
212 | (j % sqrt1) * width
213 | + (l % sqrt2) * width * sqrt1 : ((j % sqrt1) + 1) * width
214 | + (l % sqrt2) * width * sqrt1,
215 | ] = fltr
216 |
217 | return reshaped
218 |
--------------------------------------------------------------------------------
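Two of the helpers above are easiest to understand from their shapes: ``get_square_weights`` tiles per-filter weight columns into one square image for plotting, and ``im2col_indices`` is a thin wrapper over ``torch.nn.functional.unfold``. A small self-contained check (shapes chosen arbitrarily):

# 16 filters over 28x28 inputs -> a 4x4 grid of 28x28 tiles (112x112 overall).
import torch
from bindsnet.utils import get_square_weights, im2col_indices

w = torch.randn(28 * 28, 16)                      # (input size, n_filters)
grid = get_square_weights(w, n_sqrt=4, side=28)
print(grid.shape)                                 # torch.Size([112, 112])

# A 1x1x28x28 image with a 3x3 kernel, padding 1, stride 1 unfolds to
# (batch, C*k*k, n_positions) = (1, 9, 784) columns.
x = torch.randn(1, 1, 28, 28)
cols = im2col_indices(x, 3, 3, padding=(1, 1), stride=(1, 1))
print(cols.shape)                                 # torch.Size([1, 9, 784])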
/conversion.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from time import time
4 | from matplotlib import pyplot as plt
5 | import torch
6 | import numpy as np
7 | from tqdm import tqdm
8 | from bindsnet.conversion import ann_to_snn
9 | from bindsnet.encoding import RepeatEncoder
10 | from bindsnet.datasets import ImageNet, CIFAR100, DataLoader
11 | import torchvision.transforms as transforms
12 | from vgg import vgg_15_avg_before_relu
13 |
14 |
15 | def main(args):
16 | if args.gpu and torch.cuda.is_available():
17 | torch.cuda.manual_seed_all(args.seed)
18 |
19 | np.random.seed(args.seed)
20 | torch.manual_seed(args.seed)
21 |
22 | if args.n_workers == -1:
23 | args.n_workers = args.gpu * 4 * torch.cuda.device_count()
24 |
25 | device = torch.device("cuda" if args.gpu else "cpu")
26 |
27 | # Load trained ANN from disk.
28 | if args.arch == 'vgg15ab':
29 | ann = vgg_15_avg_before_relu(dataset=args.dataset)
30 | # add other architectures here#
31 | else:
32 | raise ValueError('Unknown architecture')
33 |
34 |
35 | ann.features = torch.nn.DataParallel(ann.features)
36 |     ann.to(device)
37 | if not os.path.isdir(args.job_dir):
38 | os.mkdir(args.job_dir)
39 | f = os.path.join('.', args.model)
40 | try:
41 | dictionary = torch.load(f=f)['state_dict']
42 | except KeyError:
43 | dictionary = torch.load(f=f)
44 | ann.load_state_dict(state_dict=dictionary, strict=True)
45 |
46 |     if args.dataset == 'imagenet':
47 |         input_shape = (3, 224, 224)
48 |
49 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
50 | std=[0.229, 0.224, 0.225])
51 | # the actual data to be evaluated
52 | val_loader = ImageNet(
53 | image_encoder=RepeatEncoder(time=args.time, dt=1.0),
54 | label_encoder=None,
55 | root=args.data,
56 | download=False,
57 | transform=transforms.Compose([
58 | transforms.Resize((256, 256)),
59 | transforms.CenterCrop(224),
60 | transforms.ToTensor(),
61 | normalize,
62 | ]),
63 | split='val')
64 | # a wrapper class
65 | dataloader = DataLoader(
66 | val_loader,
67 | batch_size=args.batch_size,
68 | shuffle=True,
69 | num_workers=4,
70 | pin_memory=args.gpu,
71 | )
72 | # A loader of samples for normalization of the SNN from the training set
73 | norm_loader = ImageNet(
74 | image_encoder=RepeatEncoder(time=args.time, dt=1.0),
75 | label_encoder=None,
76 | root=args.data,
77 | download=False,
78 | split='train',
79 | transform = transforms.Compose([
80 | transforms.Resize(256),
81 | transforms.CenterCrop(224),
82 | transforms.ToTensor(),
83 | normalize, ]
84 | ),
85 | )
86 |
87 | elif args.dataset == 'cifar100':
88 |         input_shape = (3, 32, 32)
89 |         print('==> Using PyTorch CIFAR-100 Dataset')
90 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441],
91 | std=[0.267, 0.256, 0.276])
92 | val_loader = CIFAR100(
93 | image_encoder=RepeatEncoder(time=args.time, dt=1.0),
94 | label_encoder=None,
95 | root=args.data,
96 | download=True,
97 | train=False,
98 | transform=transforms.Compose([
99 | transforms.RandomCrop(32, padding=4),
100 | transforms.RandomHorizontalFlip(0.5),
101 | transforms.ToTensor(),
102 | normalize, ]
103 | )
104 | )
105 |
106 | dataloader = DataLoader(
107 | val_loader,
108 | batch_size=args.batch_size,
109 | shuffle=True,
110 | num_workers=0,
111 | pin_memory=args.gpu,
112 | )
113 |
114 | norm_loader = CIFAR100(
115 | image_encoder=RepeatEncoder(time=args.time, dt=1.0),
116 | label_encoder=None,
117 | root=args.data,
118 | download=True,
119 | train=True,
120 | transform=transforms.Compose([
121 | transforms.RandomCrop(32, padding=4),
122 | transforms.RandomHorizontalFlip(0.5),
123 | transforms.ToTensor(),
124 | normalize, ]
125 | )
126 | )
127 | else:
128 | raise ValueError('Unsupported dataset.')
129 |
130 | if args.eval_size == -1:
131 | args.eval_size = len(val_loader)
132 |
133 | for step, batch in enumerate(torch.utils.data.DataLoader(norm_loader, batch_size=args.norm)):
134 | data = batch['image']
135 | break
136 |
137 | snn = ann_to_snn(ann, input_shape=input_shape, data=data, percentile=args.percentile)
138 |
139 |
140 | torch.cuda.empty_cache()
141 | snn = snn.to(device)
142 |
143 | correct = 0
144 | t0 = time()
145 | accuracies = np.zeros((args.time, (args.eval_size//args.batch_size)+1), dtype=np.float32)
146 | for step, batch in enumerate(tqdm(dataloader)):
147 | if (step+1)*args.batch_size > args.eval_size:
148 | break
149 | # Prep next input batch.
150 | inputs = batch["encoded_image"]
151 | labels = batch["label"]
152 | inpts = {"Input": inputs}
153 | if args.gpu:
154 | inpts = {k: v.cuda() for k, v in inpts.items()}
155 |
156 |         snn.run(inpts=inpts, time=args.time, step=step, acc=accuracies, labels=labels, one_step=args.one_step)
157 | last_layer = list(snn.layers.keys())[-1]
158 | output_voltages = snn.layers[last_layer].summed
159 | prediction = torch.softmax(output_voltages, dim=1).argmax(dim=1)
160 | correct += (prediction.cpu() == labels).sum().item()
161 | snn.reset_()
162 | t1 = time() - t0
163 |
164 | final = accuracies.sum(axis=1) / args.eval_size
165 |
166 | plt.plot(final)
167 | plt.suptitle('{} {} ANN-SNN@{} percentile'.format(args.dataset, args.arch, args.percentile), fontsize=20)
168 | plt.xlabel('Timestep', fontsize=19)
169 | plt.ylabel('Accuracy', fontsize=19)
170 | plt.grid()
171 |     plt.savefig('{}/{}_{}.png'.format(args.job_dir, args.arch, args.percentile))
172 |     plt.show()
173 | np.save('{}/voltage_accuracy_{}_{}.npy'.format(args.job_dir, args.arch, args.percentile), final)
174 |
175 |
176 | accuracy = 100 * correct / args.eval_size
177 |
178 | print(f"SNN accuracy: {accuracy:.2f}")
179 |     print(f"Clock time used: {t1:.4f} s.")
180 | path = os.path.join(args.job_dir, "results", args.results_file)
181 | os.makedirs(os.path.dirname(path), exist_ok=True)
182 | if not os.path.isfile(path):
183 | with open(path, "w") as f:
184 | f.write("seed,simulation time,batch size,inference time,accuracy\n")
185 | to_write = [args.seed, args.time, args.batch_size, t1, accuracy]
186 | to_write = ",".join(map(str, to_write)) + "\n"
187 | with open(path, "a") as f:
188 | f.write(to_write)
189 |
190 | return t1
191 |
192 |
193 | def parse_args():
194 | parser = argparse.ArgumentParser()
195 | parser.add_argument("--job-dir", type=str, required=True, help='The working directory to store results')
196 | parser.add_argument("--model", type=str, required=True, help='The path to the pretrained model')
197 | parser.add_argument("--results-file", type=str, default='sim_result.txt', help='The file to store simulation result')
198 | parser.add_argument("--seed", type=int, default=0, help='A random seed')
199 | parser.add_argument("--time", type=int, default=80, help='Time steps to be simulated by the converted SNN (default: 80)')
200 | parser.add_argument("--batch-size", type=int, default=100, help='Mini batch size')
201 | parser.add_argument("--n-workers", type=int, default=4, help='Number of data loaders')
202 | parser.add_argument("--norm", type=int, default=128, help='The amount of data to be normalized at once')
203 | parser.add_argument("--gpu", action="store_true", help='Whether to use GPU or not')
204 | parser.add_argument("--one-step", action="store_true", help='Single step inference flag')
205 | parser.add_argument('--data', metavar='DATA_PATH', default='./data/',
206 |                         help='The path to ImageNet data (default: \'./data/\'), CIFAR-100 will be downloaded')
207 | parser.add_argument("--arch", type=str, default='vgg15ab', help='ANN architecture to be instantiated')
208 | parser.add_argument("--percentile", type=float, default=99.7, help='The percentile of activation in the training set to be used for normalization of SNN voltage threshold')
209 | parser.add_argument("--eval_size", type=int, default=-1, help='The amount of samples to be evaluated (default: evaluate all)')
210 | parser.add_argument("--dataset", type=str, default='cifar100', help='cifar100 or imagenet')
211 |
212 | return parser.parse_args()
213 |
214 |
215 | if __name__ == "__main__":
216 | main(parse_args())
217 |
--------------------------------------------------------------------------------
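The ``--percentile`` flag above is what implements the paper's threshold normalization: rather than scaling each layer by the maximum activation observed on the normalization batch, ``ann_to_snn`` uses a high percentile (99.7 by default), which is far less sensitive to outlier activations and yields lower-latency SNNs. A minimal sketch of the underlying statistic, independent of the BindsNET conversion internals (the hook-based collection here is an assumption for illustration, not the library's actual code):

# Illustrative only: collect the pre-ReLU outputs of Conv/Linear layers on one batch
# of normalization data and take their p-th percentile as a per-layer scale.
import numpy as np
import torch
import torch.nn as nn


def percentile_scales(model: nn.Module, data: torch.Tensor, percentile: float = 99.7) -> dict:
    activations, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            activations[name] = output.detach().flatten().cpu().numpy()
        return hook

    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            handles.append(module.register_forward_hook(make_hook(name)))

    with torch.no_grad():
        model(data)

    for h in handles:
        h.remove()

    return {name: float(np.percentile(act, percentile)) for name, act in activations.items()}


# e.g. scales = percentile_scales(ann, data, args.percentile)  # 'data' as prepared above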
/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class VGG_15_avg_before_relu(nn.Module):
5 | def __init__(self, dr=0.1, num_classes=1000, units=512*7*7):
6 | super(VGG_15_avg_before_relu, self).__init__()
7 | self.features = nn.Sequential(
8 | nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
9 | nn.ReLU(),
10 | nn.Dropout(dr),
11 | nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
12 | nn.AvgPool2d((2, 2), (2, 2)),
13 | nn.ReLU(),
14 |
15 | nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
16 | nn.ReLU(),
17 | nn.Dropout(dr),
18 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
19 | nn.AvgPool2d((2, 2), (2, 2)),
20 | nn.ReLU(),
21 |
22 | nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
23 | nn.ReLU(),
24 | nn.Dropout(dr),
25 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
26 | nn.ReLU(),
27 | nn.Dropout(dr),
28 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
29 | nn.AvgPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), # AvgPool2d,
30 | nn.ReLU(),
31 |
32 | nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
33 | nn.ReLU(),
34 | nn.Dropout(dr),
35 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
36 | nn.ReLU(),
37 | nn.Dropout(dr),
38 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
39 | nn.AvgPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), # AvgPool2d,
40 | nn.ReLU(),
41 |
42 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
43 | nn.ReLU(),
44 | nn.Dropout(dr),
45 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
46 | nn.ReLU(),
47 | nn.Dropout(dr),
48 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False),
49 | nn.AvgPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
50 | nn.ReLU()
51 | )
52 | self.classifier = nn.Sequential(
53 | nn.Dropout(dr),
54 | nn.Linear(units, 4096, bias=False), # Linear,
55 | nn.ReLU(),
56 | nn.Dropout(dr),
57 | nn.Linear(4096, num_classes, bias=False) # Linear,
58 | )
59 |
60 | self._initialize_weights()
61 |
62 | def _initialize_weights(self):
63 | for m in self.modules():
64 | if isinstance(m, nn.Conv2d):
65 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
66 | if m.bias is not None:
67 | nn.init.constant_(m.bias, 0)
68 | elif isinstance(m, nn.BatchNorm2d):
69 | nn.init.constant_(m.weight, 1)
70 | # nn.init.constant_(m.bias, 0)
71 | elif isinstance(m, nn.Linear):
72 | nn.init.normal_(m.weight, 0, 0.01)
73 | # nn.init.constant_(m.bias, 0)
74 |
75 | def forward(self, x):
76 | x = self.features(x)
77 | x = x.view(x.size(0), -1)
78 | x = self.classifier(x)
79 | return x
80 |
81 | def vgg_15_avg_before_relu(dataset='imagenet', **kwargs):
82 | if dataset == 'imagenet':
83 | model = VGG_15_avg_before_relu(num_classes=1000, **kwargs)
84 | elif dataset == 'cifar100':
85 |         model = VGG_15_avg_before_relu(num_classes=100, units=512, **kwargs)
86 | else:
87 | model = None
88 | raise ValueError('Unsupported Dataset!')
89 | return model
90 |
--------------------------------------------------------------------------------
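The ``units`` argument above tracks the flattened feature size. With the five 2x2 average-pooling stages, a 32x32 CIFAR-100 input shrinks 32 -> 16 -> 8 -> 4 -> 2 -> 1, so the classifier sees 512 * 1 * 1 = 512 features (hence ``units=512``); a 224x224 ImageNet input shrinks 224 -> 112 -> 56 -> 28 -> 14 -> 7, giving the default ``units=512*7*7``. A quick shape check (not part of the repo's training or conversion flow):

# Sanity check of the flattened feature size for the CIFAR-100 configuration.
import torch
from vgg import vgg_15_avg_before_relu

with torch.no_grad():
    model = vgg_15_avg_before_relu(dataset='cifar100')
    feats = model.features(torch.randn(1, 3, 32, 32))
    print(feats.view(1, -1).shape)   # torch.Size([1, 512])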