├── README.md
├── __init__.py
├── bindsnet
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── analysis
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── pipeline_analysis.cpython-37.pyc
│   │   │   ├── plotting.cpython-37.pyc
│   │   │   └── visualization.cpython-37.pyc
│   │   ├── pipeline_analysis.py
│   │   ├── plotting.py
│   │   └── visualization.py
│   ├── bindsnet-0.2.5.dist-info
│   │   ├── INSTALLER
│   │   ├── METADATA
│   │   ├── RECORD
│   │   ├── WHEEL
│   │   └── top_level.txt
│   ├── conversion
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── conversion.cpython-37.pyc
│   │   └── conversion.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── collate.cpython-37.pyc
│   │   │   ├── dataloader.cpython-37.pyc
│   │   │   ├── davis.cpython-37.pyc
│   │   │   ├── preprocess.cpython-37.pyc
│   │   │   ├── spoken_mnist.cpython-37.pyc
│   │   │   └── torchvision_wrapper.cpython-37.pyc
│   │   ├── collate.py
│   │   ├── dataloader.py
│   │   ├── davis.py
│   │   ├── preprocess.py
│   │   ├── spoken_mnist.py
│   │   └── torchvision_wrapper.py
│   ├── encoding
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── encoders.cpython-37.pyc
│   │   │   ├── encodings.cpython-37.pyc
│   │   │   └── loaders.cpython-37.pyc
│   │   ├── encoders.py
│   │   ├── encodings.py
│   │   └── loaders.py
│   ├── environment
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── environment.cpython-37.pyc
│   │   └── environment.py
│   ├── evaluation
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── evaluation.cpython-37.pyc
│   │   └── evaluation.py
│   ├── learning
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── learning.cpython-37.pyc
│   │   │   └── reward.cpython-37.pyc
│   │   ├── learning.py
│   │   └── reward.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── models.cpython-37.pyc
│   │   └── models.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── monitors.cpython-37.pyc
│   │   │   ├── network.cpython-37.pyc
│   │   │   ├── nodes.cpython-37.pyc
│   │   │   └── topology.cpython-37.pyc
│   │   ├── monitors.py
│   │   ├── network.py
│   │   ├── nodes.py
│   │   └── topology.py
│   ├── pipeline
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   ├── action.cpython-37.pyc
│   │   │   ├── base_pipeline.cpython-37.pyc
│   │   │   ├── dataloader_pipeline.cpython-37.pyc
│   │   │   └── environment_pipeline.cpython-37.pyc
│   │   ├── action.py
│   │   ├── base_pipeline.py
│   │   ├── dataloader_pipeline.py
│   │   └── environment_pipeline.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-37.pyc
│   │   │   └── preprocessing.cpython-37.pyc
│   │   └── preprocessing.py
│   └── utils.py
├── conversion.py
└── vgg.py


/README.md:
--------------------------------------------------------------------------------
# Exploring the Connection Between Binary and Spiking Neural Networks

## Overview

This codebase outlines a training methodology and provides trained models for Full Precision and Binary Spiking Neural Networks (B-SNNs) utilizing [BindsNET](https://github.com/BindsNET/bindsnet) for large-scale datasets, namely CIFAR-100 and ImageNet. Following the procedures and design features proposed in [our work](https://www.frontiersin.org/article/10.3389/fnins.2020.00535), we have shown that B-SNNs achieve near full-precision accuracy even under many SNN-specific constraints. Additionally, we used the ANN-SNN conversion technique for training and explored a novel set of optimizations for generating high-accuracy, low-latency SNNs. These optimization techniques also apply to full-precision ANN-SNN conversion.
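ANN-SNN conversion rests on the observation that the firing rate of an integrate-and-fire (IF) neuron driven by a constant input approximates a clipped ReLU activation. The plain-Python sketch below illustrates that idea for a single neuron. It is not the code in `conversion.py` or BindsNET; the reset-by-subtraction rule, the threshold of 1.0, and the constant-current input are assumptions chosen for illustration.

```python
def if_neuron_rate(x: float, threshold: float = 1.0, time_steps: int = 100) -> float:
    """Simulate one integrate-and-fire (IF) neuron driven by a constant input ``x``.

    The membrane potential integrates ``x`` every step. When it reaches
    ``threshold`` the neuron spikes and the threshold is subtracted
    (reset-by-subtraction, an assumption here), so leftover charge carries over.
    Returns the firing rate in spikes per time step.
    """
    v = 0.0
    spikes = 0
    for _ in range(time_steps):
        v += x
        if v >= threshold:
            spikes += 1
            v -= threshold
    return spikes / time_steps


def relu_clipped(x: float, threshold: float = 1.0) -> float:
    """ANN activation the IF rate approximates: clip(x, 0, threshold) / threshold."""
    return min(max(x, 0.0), threshold) / threshold


if __name__ == "__main__":
    for x in [-0.5, 0.0, 0.13, 0.5, 0.77, 1.5]:
        print(f"x={x:5.2f}  ANN={relu_clipped(x):.3f}  SNN(T=100)={if_neuron_rate(x):.3f}")
```

In this toy model the rate error stays within about 1/T for T simulated steps. This is one way to see why the `--time` argument trades inference latency for accuracy.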
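The `--percentile` argument of `conversion.py`, documented under "Running a simulation", sets the SNN voltage threshold from a percentile of training-set activations instead of their maximum. The sketch below shows that general idea (data-based normalization), assuming per-layer post-ReLU activations have already been recorded as plain Python lists. `layer_thresholds` is a hypothetical helper. The repository may apply the scale to the weights rather than to the thresholds.

```python
import math
from typing import Dict, List, Sequence


def percentile(values: Sequence[float], p: float) -> float:
    """Return the p-th percentile (0 <= p <= 100) using linear interpolation
    between closest ranks (the same convention as numpy.percentile's default)."""
    if not values:
        raise ValueError("values must be non-empty")
    if not 0.0 <= p <= 100.0:
        raise ValueError("p must be in [0, 100]")
    ordered = sorted(values)
    rank = (p / 100.0) * (len(ordered) - 1)
    lo, hi = math.floor(rank), math.ceil(rank)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (rank - lo)


def layer_thresholds(
    activations: Dict[str, List[float]], p: float = 99.9
) -> Dict[str, float]:
    """Hypothetical helper: choose one SNN firing threshold per layer from
    recorded (post-ReLU) ANN activations on training data."""
    return {name: percentile(acts, p) for name, acts in activations.items()}


if __name__ == "__main__":
    # 1000 typical activations in [0, 1) plus a single large outlier.
    acts = [i / 1000 for i in range(1000)] + [50.0]
    print("max:", max(acts))
    print("99.9th percentile:", layer_thresholds({"conv1": acts})["conv1"])
```

With the raw maximum (50.0) as the threshold, an IF neuron receiving a typical input of 0.5 would fire at a rate of 0.5 / 50 = 0.01, about one spike per 100 steps. With the percentile threshold of roughly 1.0 it fires on about every other step, so far fewer time steps are needed to estimate its rate.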

## Requirements

- Python 3.6 or above
- matplotlib, numpy, tqdm, and torchvision
- PyTorch 1.3.0 ([pytorch.org](http://pytorch.org))
- CUDA 10.1
- The ImageNet dataset, if needed (it can be downloaded automatically by a recent version of [torchvision](https://pytorch.org/docs/stable/torchvision/datasets.html#imagenet))

## Training from scratch
We explored various network architectures under the constraints imposed by ANN-SNN conversion. The finalized network structure can be found in `vgg.py`. Further details can be found in the [paper](http://arxiv.org/abs/2002.10064).

### Hyperparameter Settings
| Model | Batch Size | Epochs | Learning Rate | Weight Decay | Optimizer |
| ---- | ---- | ---- | ---- | ---- | ---- |
| CIFAR-100 Full Precision | 256 | 200 | 5e-2, divided by 10 at epochs 81 and 122 | 1e-4 | SGD (momentum=0.9) |
| CIFAR-100 Binary | 256 | 200 | 5e-4, halved every 30 epochs | 5e-4 (0 after 30 epochs) | Adam |
| ImageNet Full Precision | 128 | 100 | 1e-2, divided by 10 every 30 epochs | 1e-4 | SGD (momentum=0.9) |
| ImageNet Binary | 128 | 100 | 5e-4, halved every 30 epochs | 5e-4 (0 after 30 epochs) | Adam (**betas=(0.0, 0.999)**) |

Note that these hyperparameters may be further optimized.

## Evaluating Pre-trained models
We provide pre-trained models of the VGG architecture mentioned in the paper and described above, available for download. Note that the first and last layers of our models are not binarized. The corresponding top-1 accuracies are indicated in parentheses.
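The note that the first and last layers are kept at full precision can be made concrete with a sketch of XNOR-Net-style weight binarization. Each filter W is approximated by alpha * sign(W), where alpha = mean(|W|) minimizes the squared error for that sign pattern. This is plain Python. The nested-list layer layout, the `binarize_model` helper, and mapping zero weights to +1 are illustrative assumptions. The repository's actual binarization comes from the XNOR-Net-PyTorch script credited below.

```python
from typing import List, Tuple

Filter = List[float]
Layer = List[Filter]


def binarize_filter(weights: Filter) -> Tuple[float, Filter]:
    """Approximate one output filter W by alpha * sign(W).

    alpha = mean(|W|) minimizes sum((w - alpha * sign(w))**2) for fixed signs.
    Zero weights are mapped to +1 (an assumption of this sketch).
    """
    alpha = sum(abs(w) for w in weights) / len(weights)
    signs = [1.0 if w >= 0 else -1.0 for w in weights]
    return alpha, signs


def binarize_model(layers: List[Layer], skip_first_and_last: bool = True) -> List[Layer]:
    """Binarize every layer's filters, optionally keeping the first and last
    layers at full precision."""
    last = len(layers) - 1
    out: List[Layer] = []
    for idx, filters in enumerate(layers):
        if skip_first_and_last and idx in (0, last):
            out.append([list(f) for f in filters])  # keep full precision
            continue
        binarized: Layer = []
        for f in filters:
            alpha, signs = binarize_filter(f)
            binarized.append([alpha * s for s in signs])
        out.append(binarized)
    return out


if __name__ == "__main__":
    layers = [[[0.3, -0.2]], [[0.5, -1.5, 1.0, -1.0]], [[2.0, -4.0]]]
    print(binarize_model(layers))
```

Keeping the first and last layers at full precision is a common choice in binary networks. The input layer sees raw pixels, and the classifier layer determines the final scores, so both tend to be sensitive to binarization.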

* [CIFAR-100 Full Precision ANN (64.9%)](https://drive.google.com/open?id=1ZmagwfBdWVVztCdn67gmAWtfQJY3yrev)
* [CIFAR-100 Binary ANN (64.8%)](https://drive.google.com/open?id=1605x2i_noKiQ-Z4OZW9L__deR_ubvfGS)
* [ImageNet Full Precision ANN (69.05%)](https://drive.google.com/open?id=1SHXlvUrkPAkl8nQ8_LCNja5ypkqh59_x)
* [ImageNet Binary ANN (64.4%)](https://drive.google.com/open?id=12WeIAfrVNxD45NFv4HV1nSvrLa3rRZp_)

The Full Precision ANNs are trained using standard [PyTorch training practices](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html). The binarization process uses part of the [XNOR-Net-Pytorch](https://github.com/jiecaoyu/XNOR-Net-PyTorch) script, a PyTorch implementation of the original [XNOR-Net](https://github.com/allenai/XNOR-Net) code.

## Running a simulation

Download a pre-trained model, place it in the same directory as `conversion.py`, and run the following command for each model:

```
python conversion.py --job-dir cifar100_test --gpu --dataset cifar100 --data . --percentile 99.9 --norm 3500 --time 100 --arch vgg15ab --model bin_cifar100.pth.tar
```

Full documentation of the arguments in `conversion.py`:
```
usage: conversion.py [-h] --job-dir JOB_DIR --model MODEL
                     [--results-file RESULTS_FILE] [--seed SEED] [--time TIME]
                     [--batch-size BATCH_SIZE] [--n-workers N_WORKERS]
                     [--norm NORM] [--gpu] [--one-step] [--data DATA_PATH]
                     [--arch ARCH] [--percentile PERCENTILE]
                     [--eval_size EVAL_SIZE] [--dataset DATASET]

optional arguments:
  -h, --help            show this help message and exit
  --job-dir JOB_DIR     The working directory to store results
  --model MODEL         The path to the pretrained model
  --results-file RESULTS_FILE
                        The file to store simulation result
  --seed SEED           A random seed
  --time TIME           Time steps to be simulated by the converted SNN
                        (default: 80)
  --batch-size BATCH_SIZE
                        Mini batch size
  --n-workers N_WORKERS
                        Number of data loaders
  --norm NORM           The amount of data to be normalized at once
  --gpu                 Whether to use GPU or not
  --one-step            Single step inference flag
  --data DATA_PATH      The path to ImageNet data (default: './data/'),
                        CIFAR-100 will be downloaded
  --arch ARCH           ANN architecture to be instantiated
  --percentile PERCENTILE
                        The percentile of activation in the training set to be
                        used for normalization of SNN voltage threshold
  --eval_size EVAL_SIZE
                        The amount of samples to be evaluated (default:
                        evaluate all)
  --dataset DATASET     cifar100 or imagenet
```
Depending on your computing resources, some settings can be adjusted to speed up the simulation or to fit the available device. In particular, `--norm`, `--batch-size`, and `--time` can be tuned for better performance.

## Reference

If you use this code, please cite the following paper:

Sen Lu and Abhronil Sengupta. "Exploring the Connection Between Binary and Spiking Neural Networks", Frontiers in Neuroscience, Vol. 14, pp. 535 (2020).

```
@ARTICLE{10.3389/fnins.2020.00535,
  AUTHOR={Lu, Sen and Sengupta, Abhronil},
  TITLE={Exploring the Connection Between Binary and Spiking Neural Networks},
  JOURNAL={Frontiers in Neuroscience},
  VOLUME={14},
  PAGES={535},
  YEAR={2020},
  URL={https://www.frontiersin.org/article/10.3389/fnins.2020.00535},
  DOI={10.3389/fnins.2020.00535},
  ISSN={1662-453X}
}
```

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
from MLP import MLP
from alexnet import AlexNet

--------------------------------------------------------------------------------
/bindsnet/__init__.py:
--------------------------------------------------------------------------------
from pathlib import Path

from . import (
    utils,
    network,
    models,
    analysis,
    preprocessing,
    datasets,
    encoding,
    pipeline,
    learning,
    evaluation,
    environment,
    conversion,
)

ROOT_DIR = Path(__file__).parents[0].parents[0]

--------------------------------------------------------------------------------
/bindsnet/analysis/__init__.py:
--------------------------------------------------------------------------------
from . import plotting, visualization, pipeline_analysis

--------------------------------------------------------------------------------
/bindsnet/analysis/pipeline_analysis.py:
--------------------------------------------------------------------------------
from abc import ABC, abstractmethod
from typing import Dict, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from tensorboardX import SummaryWriter
from torchvision.utils import make_grid

from .plotting import plot_spikes, plot_voltages, plot_conv2d_weights
from ..utils import reshape_conv2d_weights


class PipelineAnalyzer(ABC):
    # language=rst
    """
    Responsible for pipeline analysis. Subclasses maintain state
    information related to plotting or logging.
    """

    @abstractmethod
    def finalize_step(self) -> None:
        # language=rst
        """
        Flush the output from the current step.
        """
        pass

    @abstractmethod
    def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
        # language=rst
        """
        Pulls the observation from PyTorch and sets up for Matplotlib
        plotting.

        :param obs: A 2D array of floats depicting an input image.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        pass

    @abstractmethod
    def plot_reward(
        self,
        reward_list: list,
        reward_window: int = None,
        tag: str = "reward",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plot the accumulated reward for each episode.

        :param reward_list: The list of recent rewards to be plotted.
        :param reward_window: The length of the window to compute a moving average over.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        pass

    @abstractmethod
    def plot_spikes(
        self,
        spike_record: Dict[str, torch.Tensor],
        tag: str = "spike",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plots all spike records inside of ``spike_record``. Keeps unique
        plots for all unique tags that are given.

        :param spike_record: Dictionary of spikes to be rasterized.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        pass

    @abstractmethod
    def plot_voltages(
        self,
        voltage_record: Dict[str, torch.Tensor],
        thresholds: Optional[Dict[str, torch.Tensor]] = None,
        tag: str = "voltage",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plots all voltage records and given thresholds. Keeps unique
        plots for all unique tags that are given.

        :param voltage_record: Dictionary of voltages for neurons inside of networks
            organized by the layer they correspond to.
        :param thresholds: Optional dictionary of threshold values for neurons.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        pass

    @abstractmethod
    def plot_conv2d_weights(
        self, weights: torch.Tensor, tag: str = "conv2d", step: int = None
    ) -> None:
        # language=rst
        """
        Plot a connection weight matrix of a ``Conv2dConnection``.

        :param weights: Weight matrix of ``Conv2dConnection`` object.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        pass


class MatplotlibAnalyzer(PipelineAnalyzer):
    # language=rst
    """
    Renders output using Matplotlib.

    Matplotlib requires objects to be kept around over the full lifetime
    of the plots; this is done through ``self.plots``. An interactive session
    is needed so that we can continue processing and just update the
    plots.
    """

    def __init__(self, **kwargs) -> None:
        # language=rst
        """
        Initializes the analyzer.

        Keyword arguments:

        :param str volts_type: Type of plotting for voltages (``"color"`` or ``"line"``).
        """
        self.volts_type = kwargs.get("volts_type", "color")
        plt.ion()
        self.plots = {}

    def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
        # language=rst
        """
        Pulls the observation off of torch and sets up for Matplotlib
        plotting.

        :param obs: A 2D array of floats depicting an input image.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        obs = obs.detach().cpu().numpy()
        obs = np.transpose(obs, (1, 2, 0)).squeeze()

        if tag in self.plots:
            obs_ax, obs_im = self.plots[tag]
        else:
            obs_ax, obs_im = None, None

        if obs_im is None and obs_ax is None:
            fig, obs_ax = plt.subplots()
            obs_ax.set_title("Observation")
            obs_ax.set_xticks(())
            obs_ax.set_yticks(())
            obs_im = obs_ax.imshow(obs, cmap="gray")

            self.plots[tag] = obs_ax, obs_im
        else:
            obs_im.set_data(obs)

    def plot_reward(
        self,
        reward_list: list,
        reward_window: int = None,
        tag: str = "reward",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plot the accumulated reward for each episode.

        :param reward_list: The list of recent rewards to be plotted.
        :param reward_window: The length of the window to compute a moving average over.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        if tag in self.plots:
            reward_im, reward_ax, reward_plot = self.plots[tag]
        else:
            reward_im, reward_ax, reward_plot = None, None, None

        # Compute moving average.
        if reward_window is not None:
            # Ensure window size > 0 and < size of reward list.
            window = max(min(len(reward_list), reward_window), 0)

            # Fastest implementation of moving average.
            reward_list_ = (
                pd.Series(reward_list)
                .rolling(window=window, min_periods=1)
                .mean()
                .values
            )
        else:
            reward_list_ = reward_list[:]

        if reward_im is None and reward_ax is None:
            reward_im, reward_ax = plt.subplots()
            reward_ax.set_title("Accumulated reward")
            reward_ax.set_xlabel("Episode")
            reward_ax.set_ylabel("Reward")
            (reward_plot,) = reward_ax.plot(reward_list_)

            self.plots[tag] = reward_im, reward_ax, reward_plot
        else:
            reward_plot.set_data(range(len(reward_list_)), reward_list_)
            reward_ax.relim()
            reward_ax.autoscale_view()

    def plot_spikes(
        self,
        spike_record: Dict[str, torch.Tensor],
        tag: str = "spike",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plots all spike records inside of ``spike_record``. Keeps unique
        plots for all unique tags that are given.

        :param spike_record: Dictionary of spikes to be rasterized.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        if tag not in self.plots:
            self.plots[tag] = plot_spikes(spike_record)
        else:
            s_im, s_ax = self.plots[tag]
            self.plots[tag] = plot_spikes(spike_record, ims=s_im, axes=s_ax)

    def plot_voltages(
        self,
        voltage_record: Dict[str, torch.Tensor],
        thresholds: Optional[Dict[str, torch.Tensor]] = None,
        tag: str = "voltage",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plots all voltage records and given thresholds. Keeps unique
        plots for all unique tags that are given.

        :param voltage_record: Dictionary of voltages for neurons inside of networks
            organized by the layer they correspond to.
        :param thresholds: Optional dictionary of threshold values for neurons.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        if tag not in self.plots:
            self.plots[tag] = plot_voltages(
                voltage_record, plot_type=self.volts_type, thresholds=thresholds
            )
        else:
            v_im, v_ax = self.plots[tag]
            self.plots[tag] = plot_voltages(
                voltage_record,
                ims=v_im,
                axes=v_ax,
                plot_type=self.volts_type,
                thresholds=thresholds,
            )

    def plot_conv2d_weights(
        self, weights: torch.Tensor, tag: str = "conv2d", step: int = None
    ) -> None:
        # language=rst
        """
        Plot a connection weight matrix of a ``Conv2dConnection``.

        :param weights: Weight matrix of ``Conv2dConnection`` object.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        wmin = weights.min().item()
        wmax = weights.max().item()

        if tag not in self.plots:
            self.plots[tag] = plot_conv2d_weights(weights, wmin, wmax)
        else:
            im = self.plots[tag]
            plot_conv2d_weights(weights, wmin, wmax, im=im)

    def finalize_step(self) -> None:
        # language=rst
        """
        Flush the output from the current step
        """
        plt.draw()
        plt.pause(1e-8)
        plt.show()


class TensorboardAnalyzer(PipelineAnalyzer):
    def __init__(self, summary_directory: str = "./logs"):
        # language=rst
        """
        Initializes the analyzer.

        :param summary_directory: Directory to save log files.
        """
        self.writer = SummaryWriter(summary_directory)

    def finalize_step(self) -> None:
        # language=rst
        """
        No-op for ``TensorboardAnalyzer``.
        """
        pass

    def plot_obs(self, obs: torch.Tensor, tag: str = "obs", step: int = None) -> None:
        # language=rst
        """
        Pulls the observation off of torch and sets up for Matplotlib
        plotting.

        :param obs: A 2D array of floats depicting an input image.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        obs_grid = make_grid(obs.float(), nrow=4, normalize=True)
        self.writer.add_image(tag, obs_grid, step)

    def plot_reward(
        self,
        reward_list: list,
        reward_window: int = None,
        tag: str = "reward",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plot the accumulated reward for each episode.

        :param reward_list: The list of recent rewards to be plotted.
        :param reward_window: The length of the window to compute a moving average over.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        self.writer.add_scalar(tag, reward_list[-1], step)

    def plot_spikes(
        self,
        spike_record: Dict[str, torch.Tensor],
        tag: str = "spike",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plots all spike records inside of ``spike_record``. Keeps unique
        plots for all unique tags that are given.

        :param spike_record: Dictionary of spikes to be rasterized.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        for k, spikes in spike_record.items():
            # Shuffle spikes into 1x1x#neuronsxT
            spikes = spikes.view(1, 1, -1, spikes.shape[-1]).float()
            spike_grid_img = make_grid(spikes, nrow=1, pad_value=0.5)

            self.writer.add_image(tag + "_" + str(k), spike_grid_img, step)

    def plot_voltages(
        self,
        voltage_record: Dict[str, torch.Tensor],
        thresholds: Optional[Dict[str, torch.Tensor]] = None,
        tag: str = "voltage",
        step: int = None,
    ) -> None:
        # language=rst
        """
        Plots all voltage records and given thresholds. Keeps unique
        plots for all unique tags that are given.

        :param voltage_record: Dictionary of voltages for neurons inside of networks
            organized by the layer they correspond to.
        :param thresholds: Optional dictionary of threshold values for neurons.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        for k, v in voltage_record.items():
            # Shuffle voltages into 1x1x#neuronsxT
            v = v.view(1, 1, -1, v.shape[-1])
            voltage_grid_img = make_grid(v, nrow=1, pad_value=0)

            self.writer.add_image(tag + "_" + str(k), voltage_grid_img, step)

    def plot_conv2d_weights(
        self, weights: torch.Tensor, tag: str = "conv2d", step: int = None
    ) -> None:
        # language=rst
        """
        Plot a connection weight matrix of a ``Conv2dConnection``.

        :param weights: Weight matrix of ``Conv2dConnection`` object.
        :param tag: A unique tag to associate the data with.
        :param step: The step of the pipeline.
        """
        reshaped = reshape_conv2d_weights(weights).unsqueeze(0)

        reshaped -= reshaped.min()
        reshaped /= reshaped.max()

        self.writer.add_image(tag, reshaped, step)

--------------------------------------------------------------------------------
/bindsnet/analysis/visualization.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

from typing import List, Tuple, Optional


def plot_weights_movie(ws: np.ndarray, sample_every: int = 1) -> None:
    # language=rst
    """
    Create and plot movie of weights.

    :param ws: Array of shape ``[n_examples, source, target, time]``
    :param sample_every: Sub-sample using this parameter.
    """
    weights = []

    # Obtain samples from the weights for every example.
    for i in range(ws.shape[0]):
        sub_sampled_weight = ws[i, :, :, range(0, ws[i].shape[2], sample_every)]
        weights.append(sub_sampled_weight)
    else:
        weights = np.concatenate(weights, axis=0)

    # Initialize plot.
    fig = plt.figure()
    im = plt.imshow(weights[0, :, :], cmap="hot_r", animated=True, vmin=0, vmax=1)
    plt.axis("off")
    plt.colorbar(im)

    # Update function for the animation.
    def update(j):
        im.set_data(weights[j, :, :])
        return [im]

    # Initialize animation; frames index the first axis of ``weights``.
    global ani
    ani = 0
    ani = animation.FuncAnimation(
        fig, update, frames=weights.shape[0], interval=1000, blit=True
    )
    plt.show()


def plot_spike_trains_for_example(
    spikes: torch.Tensor,
    n_ex: Optional[int] = None,
    top_k: Optional[int] = None,
    indices: Optional[List[int]] = None,
) -> None:
    # language=rst
    """
    Plot spike trains for top-k neurons or for specific indices.

    :param spikes: Spikes for one simulation run of shape ``(n_examples, n_neurons, time)``.
    :param n_ex: Allows user to pick which example to plot spikes for.
    :param top_k: Plot k neurons that spiked the most for n_ex example.
    :param indices: Plot specific neurons' spiking activity instead of top_k.
    """
    assert n_ex is not None and 0 <= n_ex < spikes.shape[0]

    plt.figure()

    if top_k is None and indices is None:  # Plot all neurons' spiking activity
        spike_per_neuron = [np.argwhere(i == 1).flatten() for i in spikes[n_ex, :, :]]
        plt.title("Spiking activity for all %d neurons" % spikes.shape[1])

    elif top_k is None:  # Plot based on indices parameter
        assert indices is not None
        spike_per_neuron = [
            np.argwhere(i == 1).flatten() for i in spikes[n_ex, indices, :]
        ]

    elif indices is None:  # Plot based on top_k parameter
        assert top_k is not None
        # Obtain the top k neurons that fired the most
        top_k_loc = np.argsort(np.sum(spikes[n_ex, :, :], axis=1), axis=0)[::-1]
        spike_per_neuron = [
            np.argwhere(i == 1).flatten() for i in spikes[n_ex, top_k_loc[0:top_k], :]
        ]
        plt.title("Spiking activity for top %d neurons" % top_k)

    else:
        raise ValueError('One of "top_k" or "indices" or both must be None')

    plt.eventplot(spike_per_neuron, linelengths=[0.5] * len(spike_per_neuron))
    plt.xlabel("Simulation Time")
    plt.ylabel("Neuron index")
    plt.show()


def plot_voltage(
    voltage: torch.Tensor,
    n_ex: int = 0,
    n_neuron: int = 0,
    time: Optional[Tuple[int, int]] = None,
    threshold: float = None,
) -> None:
    # language=rst
    """
    Plot voltage for a single neuron on a specific example.

    :param voltage: Tensor or array of shape ``[n_examples, n_neurons, time]``.
    :param n_ex: Allows user to pick which example to plot voltage for.
    :param n_neuron: Neuron index for which to plot voltages for.
    :param time: Plot spiking activity of neurons between the given range of time.
    :param threshold: Neuron spiking threshold.
    """
    assert n_ex >= 0 and n_neuron >= 0
    assert n_ex < voltage.shape[0] and n_neuron < voltage.shape[1]

    if time is None:
        time = (0, voltage.shape[-1])
    else:
        assert time[0] < time[1]
        assert time[1] <= voltage.shape[-1]

    timer = np.arange(time[0], time[1])
    time_ticks = np.arange(time[0], time[1] + 1, 10)

    plt.figure()
    plt.plot(voltage[n_ex, n_neuron, timer])
    plt.xlabel("Simulation Time")
    plt.ylabel("Voltage")
    plt.title("Membrane voltage of neuron %d for example %d" % (n_neuron, n_ex + 1))
    locs, labels = plt.xticks()
    locs = range(int(locs[1]), int(locs[-1]), 10)
    plt.xticks(locs, time_ticks)

    # Draw threshold line only if given
    if threshold is not None:
        plt.axhline(threshold, linestyle="--", color="black", zorder=0)

    plt.show()

--------------------------------------------------------------------------------
/bindsnet/bindsnet-0.2.5.dist-info/INSTALLER:
--------------------------------------------------------------------------------
pip

--------------------------------------------------------------------------------
/bindsnet/bindsnet-0.2.5.dist-info/METADATA:
--------------------------------------------------------------------------------
Metadata-Version: 2.1
Name: bindsnet
Version: 0.2.5
Summary: Spiking neural networks for ML in Python
Home-page: http://github.com/Hananel-Hazan/bindsnet
Author: Daniel Saunders, Hananel Hazan, Darpan Sanghavi, Hassaan Khan
Author-email: danjsaund@gmail.com
License: AGPL-3.0
Download-URL: https://github.com/Hananel-Hazan/bindsnet/archive/0.2.5.tar.gz
Platform: UNKNOWN
Description-Content-Type: text/markdown
Requires-Dist: numpy (>=1.14.2)
Requires-Dist: torch (>=1.2.0)
Requires-Dist: torchvision (>=0.4.0)
Requires-Dist: tensorboardX (>=1.7)
Requires-Dist: tqdm (>=4.19.9)
Requires-Dist: matplotlib (>=2.1.0)
Requires-Dist: gym (>=0.10.4)
Requires-Dist: scikit-image (>=0.13.1)
Requires-Dist: scikit-learn (>=0.19.1)
Requires-Dist: opencv-python (>=3.4.0.12)
Requires-Dist: pytest (>=3.4.0)
Requires-Dist: scipy (>=1.1.0)
Requires-Dist: cython (>=0.28.5)
Requires-Dist: pandas (>=0.23.4)


28 | 29 | A Python package used for simulating spiking neural networks (SNNs) on CPUs or GPUs using [PyTorch](http://pytorch.org/) `Tensor` functionality. 30 | 31 | BindsNET is a spiking neural network simulation library geared towards the development of biologically inspired algorithms for machine learning. 32 | 33 | This package is used as part of ongoing research on applying SNNs to machine learning (ML) and reinforcement learning (RL) problems in the [Biologically Inspired Neural & Dynamical Systems (BINDS) lab](http://binds.cs.umass.edu/). 34 | 35 | Check out the [BindsNET experiments repository](https://github.com/djsaunde/bindsnet_experiments) for a collection of experiments, accompanying bash scripts for dispatching on [CICS](https://www.cics.umass.edu/) clusters, functions for the analysis of results, plots of experiment outcomes, and more. 36 | 37 | [![Build Status](https://travis-ci.com/Hananel-Hazan/bindsnet.svg?token=trym5Uzx1rs9Ez2yENEF&branch=master)](https://travis-ci.com/Hananel-Hazan/bindsnet) 38 | [![Documentation Status](https://readthedocs.org/projects/bindsnet-docs/badge/?version=latest)](https://bindsnet-docs.readthedocs.io/?badge=latest) 39 | [![HitCount](http://hits.dwyl.io/Hananel-Hazan/bindsnet.svg)](http://hits.dwyl.io/Hananel-Hazan/bindsnet) 40 | [![Gitter chat](https://badges.gitter.im/gitterHQ/gitter.png)](https://gitter.im/bindsnet_/community) 41 | 42 | ## Requirements 43 | 44 | - Python 3.6 45 | - `requirements.txt` 46 | 47 | ## Setting things up 48 | 49 | ### Using pip 50 | BindsNET is available on PyPI. Issue 51 | 52 | ``` 53 | pip install bindsnet 54 | ``` 55 | 56 | to get the most recent stable release. Or, to build the `bindsnet` package from source, clone the GitHub repository, change directory to the top level of this project, and issue 57 | 58 | ``` 59 | pip install . 60 | ``` 61 | 62 | Or, to install in editable mode (allows modification of package without re-installing): 63 | 64 | ``` 65 | pip install -e . 
66 | ``` 67 | 68 | To install the packages necessary to interface with the [OpenAI gym RL environments library](https://github.com/openai/gym), follow their instructions for installing the packages needed to run the RL environments simulator (on Linux / MacOS). 69 | 70 | ### Using Docker 71 | [Link](https://hub.docker.com/r/hqkhan/bindsnet/) to Docker repository. 72 | 73 | We also provide a Dockerfile with BindsNET and all of its dependencies preinstalled. Issue 74 | 75 | ``` 76 | docker image build . 77 | ``` 78 | at the top level directory of this project to create a Docker image. 79 | 80 | To change the name of the newly built image, issue 81 | ``` 82 | docker tag <image_id> <new_image_name> 83 | ``` 84 | 85 | To run a container and get a bash terminal inside it, issue 86 | 87 | ``` 88 | docker run -it <image_id> bash 89 | ``` 90 | 91 | ## Getting started 92 | 93 | To run a near-replication of the SNN from [this paper](https://www.frontiersin.org/articles/10.3389/fncom.2015.00099/full#), issue 94 | 95 | ``` 96 | cd examples/mnist 97 | python eth_mnist.py 98 | ``` 99 | 100 | There are a number of optional command-line arguments which can be passed in, including `--plot` (displays useful monitoring figures), `--n_neurons [int]` (number of excitatory and inhibitory neurons simulated), `--mode ['train' | 'test']` (sets network operation to the training or testing phase), and more. Run the script with the `--help` or `-h` flag for more information. 101 | 102 | A number of other examples are available in the `examples` directory that are meant to showcase BindsNET's functionality. Take a look, and let us know what you think! 103 | 104 | ## Running the tests 105 | 106 | Issue the following to run the tests: 107 | 108 | ``` 109 | python -m pytest test/ 110 | ``` 111 | 112 | Some tests will fail if OpenAI `gym` is not installed on your machine. 113 | 114 | ## Background 115 | 116 | The simulation of biologically plausible spiking neuron dynamics can be challenging.
It is typically done by solving ordinary differential equations (ODEs) that describe these dynamics. PyTorch does not explicitly support the solution of differential equations (as opposed to [`brian2`](https://github.com/brian-team/brian2), for example), but we can convert the ODEs defining the dynamics into difference equations and solve them at regular, short intervals (a `dt` on the order of 1 millisecond) as an approximation. Under the hood, packages like `brian2` do the same thing. Doing this in [`PyTorch`](http://pytorch.org/) is exciting for a few reasons: 117 | 118 | 1. We can use the powerful and flexible [`torch.Tensor`](http://pytorch.org/) object, an n-dimensional array type similar to the [`numpy.ndarray`](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html) that can be transferred to and from GPU devices. 119 | 120 | 2. We can avoid "reinventing the wheel" by repurposing functions from the [`torch.nn.functional`](http://pytorch.org/docs/master/nn.html#torch-nn-functional) PyTorch submodule in our SNN architectures; e.g., convolution or pooling functions. 121 | 122 | The idea that the ordering and relative timing of neuron spikes encode information is a central theme in neuroscience. [Markram et al. (1997)](http://www.caam.rice.edu/~caam415/lec_gab/g4/markram_etal98.pdf) proposed that synapses between neurons should strengthen or weaken based on this relative timing, and prior to that, [Donald Hebb](https://en.wikipedia.org/wiki/Donald_O._Hebb) proposed the theory of Hebbian learning, often simply stated as "Neurons that fire together, wire together." Markram et al.'s extension of the Hebbian theory is known as spike-timing-dependent plasticity (STDP). 123 | 124 | We are interested in applying SNNs to ML and RL problems. We use STDP to modify the weights of synapses connecting pairs or populations of neurons in SNNs.
In the context of ML, we want to learn a setting of synapse weights that generates data-dependent spiking activity in SNNs. This activity will allow us to subsequently perform some ML task of interest; e.g., discriminating or clustering input data. In the context of RL, we may think of the spiking neural network as an RL agent, whose spiking activity may be converted into actions in an environment's action space. 125 | 126 | We have provided some simple starter scripts for doing unsupervised learning (learning a fully-connected or convolutional representation via STDP), supervised learning (clamping output neurons to desired spiking behavior depending on data labels), and reinforcement learning (converting observations from the Atari game Space Invaders into input for an SNN, and converting network activity back to actions in the game). 127 | 128 | ## Benchmarking 129 | We simulated a network with a population of n Poisson input neurons with firing rates (in Hertz) drawn randomly from U(0, 100), connected all-to-all with an equally sized population of leaky integrate-and-fire (LIF) neurons, with connection weights sampled from N(0,1). We varied n systematically from 250 to 10,000 in steps of 250, and ran each simulation with every library for 1,000 ms with a time resolution of dt = 1.0 ms. We tested BindsNET (with CPU and GPU computation), BRIAN2, PyNEST (the Python interface to the NEST SLI interface that runs the C++ NEST core simulator), ANNarchy (with CPU and GPU computation), and BRIAN2genn (the BRIAN2 front-end to the GeNN simulator). 130 | 131 | Several packages, including BRIAN and PyNEST, allow setting certain global preferences; e.g., the number of CPU threads, the number of OpenMP processes, etc. We chose these settings for our benchmark study in an attempt to maximize each library's speed, but note that BindsNET requires no setting of such options.
Our approach, inheriting the computational model of PyTorch, appears to make the best use of the available hardware and therefore makes it simple for practitioners to get the best performance from their system with the least effort. 132 | 133 |
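The benchmark network described above can be approximated in a few lines of plain Python. This is a minimal sketch. It is not BindsNET code and not the script behind the published benchmark. It rests on several assumptions:

- The LIF constants are illustrative: rest and reset at -65 mV, threshold at -52 mV, and a 100 ms membrane time constant.
- The N(0,1) weights act as instantaneous voltage jumps in mV.
- Each Poisson input is approximated by one Bernoulli draw per `dt`.

The membrane update is the forward-Euler difference equation approach described in the Background section.

```python
import random


def simulate_benchmark_network(n, time_ms=1000.0, dt=1.0, seed=0):
    """Sketch: n Poisson inputs -> n LIF neurons, connected all-to-all.

    Returns the total number of spikes emitted by the LIF population.
    The LIF constants below are illustrative assumptions, not BindsNET defaults.
    """
    rng = random.Random(seed)

    # Input firing rates in Hz, drawn from U(0, 100).
    rates = [rng.uniform(0.0, 100.0) for _ in range(n)]
    # weights[i][j] connects input i to LIF neuron j, sampled from N(0, 1).
    weights = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(n)]

    # Hypothetical LIF parameters (voltages in mV, time constant in ms).
    v_rest, v_reset, v_thresh, tau = -65.0, -65.0, -52.0, 100.0
    v = [v_rest] * n
    total_spikes = 0

    for _ in range(int(time_ms / dt)):
        # Bernoulli approximation of a Poisson process: P(spike in dt) = rate * dt / 1000.
        spiking = [i for i in range(n) if rng.random() < rates[i] * dt / 1000.0]
        for j in range(n):
            # Forward-Euler step of dv/dt = (v_rest - v) / tau, plus this step's input.
            v[j] += dt * (v_rest - v[j]) / tau
            v[j] += sum(weights[i][j] for i in spiking)
            if v[j] >= v_thresh:
                v[j] = v_reset  # emit a spike, then reset the membrane voltage
                total_spikes += 1
    return total_spikes


if __name__ == "__main__":
    print(simulate_benchmark_network(100))
```

The per-neuron Python loop is what makes this sketch slow as n grows. The text above attributes BindsNET's speed to expressing the same kind of per-step update through PyTorch tensor operations on CPU or GPU.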

134 | *Figure: BindsNET Benchmark* 135 |

136 | 137 | All simulations were run on Ubuntu 16.04 LTS with an Intel(R) Xeon(R) CPU E5-2687W v3 @ 3.10 GHz, 128 GB RAM @ 2133 MHz, and two GeForce GTX TITAN X (GM200) GPUs. Python 3.6 was used in all cases. Clock time was recorded for each simulation run. 138 | 139 | ## Citation 140 | 141 | If you use BindsNET in your research, please cite the following [article](https://www.frontiersin.org/article/10.3389/fninf.2018.00089): 142 | 143 | ``` 144 | @ARTICLE{10.3389/fninf.2018.00089, 145 | AUTHOR={Hazan, Hananel and Saunders, Daniel J. and Khan, Hassaan and Patel, Devdhar and Sanghavi, Darpan T. and Siegelmann, Hava T. and Kozma, Robert}, 146 | TITLE={BindsNET: A Machine Learning-Oriented Spiking Neural Networks Library in Python}, 147 | JOURNAL={Frontiers in Neuroinformatics}, 148 | VOLUME={12}, 149 | PAGES={89}, 150 | YEAR={2018}, 151 | URL={https://www.frontiersin.org/article/10.3389/fninf.2018.00089}, 152 | DOI={10.3389/fninf.2018.00089}, 153 | ISSN={1662-5196}, 154 | } 155 | 156 | ``` 157 | 158 | ## Contributors 159 | 160 | - Daniel Saunders ([email](mailto:djsaunde@cs.umass.edu)) 161 | - Hananel Hazan ([email](mailto:hananel@hazan.org.il)) 162 | - Darpan Sanghavi ([email](mailto:dsanghavi@cs.umass.edu)) 163 | - Hassaan Khan ([email](mailto:hqkhan@umass.edu)) 164 | - Devdhar Patel ([email](mailto:devdharpatel@cs.umass.edu)) 165 | 166 | ## License 167 | GNU Affero General Public License v3.0 168 | 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /bindsnet/bindsnet-0.2.5.dist-info/RECORD: -------------------------------------------------------------------------------- 1 | bindsnet-0.2.5.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 2 | bindsnet-0.2.5.dist-info/METADATA,sha256=607kTwrUkJf8qf3_sJm8I7ShSjqIKWWxa4Z_mVfU9Qw,9894 3 | bindsnet-0.2.5.dist-info/RECORD,, 4 | bindsnet-0.2.5.dist-info/WHEEL,sha256=6lWjmfsNs_bGGL5Yl5gRRr5x9whxDoYnVMS85C0Tfgs,98 5 |
bindsnet-0.2.5.dist-info/top_level.txt,sha256=_rRWpJi6ML98T3oY7Yi0UE3ThTbjrPuIcjQ6JPyDoT4,9 6 | bindsnet/__init__.py,sha256=Copph0F10T-N_RBshOt81aLlmAHyMW70wp1kUneO9BI,267 7 | bindsnet/__pycache__/__init__.cpython-37.pyc,, 8 | bindsnet/__pycache__/utils.cpython-37.pyc,, 9 | bindsnet/analysis/__init__.py,sha256=ncx2TPPt_WsrqDb6W9My5dFHTKUqenDbPAiICPF8cdM,57 10 | bindsnet/analysis/__pycache__/__init__.cpython-37.pyc,, 11 | bindsnet/analysis/__pycache__/pipeline_analysis.cpython-37.pyc,, 12 | bindsnet/analysis/__pycache__/plotting.cpython-37.pyc,, 13 | bindsnet/analysis/__pycache__/visualization.cpython-37.pyc,, 14 | bindsnet/analysis/pipeline_analysis.py,sha256=SyP2wB5xQpmjeN0lXmM3lkxEIqLEhMTsOE0bepwcmCw,13629 15 | bindsnet/analysis/plotting.py,sha256=C5N9SBceCmQZdnTZIWepNYSZELxtKV87KwQAvoDOXNc,22172 16 | bindsnet/analysis/visualization.py,sha256=-mEB-ExZWRkDOyN0wq_DVGJ7p7OqKBUf48R0aEhmBUE,4432 17 | bindsnet/conversion/__init__.py,sha256=vAYbOc6MTuRf01ugLqE6D2nJzB1aTmhVClvOlZkpLgU,212 18 | bindsnet/conversion/__pycache__/__init__.cpython-37.pyc,, 19 | bindsnet/conversion/__pycache__/conversion.cpython-37.pyc,, 20 | bindsnet/conversion/conversion.py,sha256=qPEbzSzaZUi-xwqnHXEyJ3eSad9fmBq9oEy7y3HBass,19736 21 | bindsnet/datasets/__init__.py,sha256=JFmWp5g0zkaRoTzW9OftgQ19lW6dN_FWcaG7IrxeT5U,1625 22 | bindsnet/datasets/__pycache__/__init__.cpython-37.pyc,, 23 | bindsnet/datasets/__pycache__/collate.cpython-37.pyc,, 24 | bindsnet/datasets/__pycache__/dataloader.cpython-37.pyc,, 25 | bindsnet/datasets/__pycache__/davis.cpython-37.pyc,, 26 | bindsnet/datasets/__pycache__/preprocess.cpython-37.pyc,, 27 | bindsnet/datasets/__pycache__/spoken_mnist.cpython-37.pyc,, 28 | bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-37.pyc,, 29 | bindsnet/datasets/collate.py,sha256=jGmO8qyoSavSIdORGGiaf_9dBAjZdfHBIbzRSXVqZgM,3144 30 | bindsnet/datasets/dataloader.py,sha256=GgbN-YTdyzbRHTrHq5dPJCKfiJDnppLWRPzQkC6zS5M,818 31 | 
bindsnet/datasets/davis.py,sha256=wH5GsZ2CIy0-L-yfAwdvIjLwJgyP8NfMaFAo-E-AfqI,12586 32 | bindsnet/datasets/preprocess.py,sha256=CVPY1VrIRc9vzBi56TCmP3sustx6u8Y7GbbDFps7s8Y,1342 33 | bindsnet/datasets/spoken_mnist.py,sha256=9hipF6IrL55B_J0EdQU5E8RlBJ9YG9rwBbOpZcH02DM,10438 34 | bindsnet/datasets/torchvision_wrapper.py,sha256=fJ9fl7S4Q0Vo6m0z8I5m_SvpaZhFPJXuDkMI_P8vinA,2825 35 | bindsnet/encoding/__init__.py,sha256=_IbgfqJtClr4I3R8NNVhw_3UwerBtzhI4zHsrx0xdXE,301 36 | bindsnet/encoding/__pycache__/__init__.cpython-37.pyc,, 37 | bindsnet/encoding/__pycache__/encoders.cpython-37.pyc,, 38 | bindsnet/encoding/__pycache__/encodings.cpython-37.pyc,, 39 | bindsnet/encoding/__pycache__/loaders.cpython-37.pyc,, 40 | bindsnet/encoding/encoders.py,sha256=4r6VZzm9i7cHb2G2doGzTl4jAwW6xfJlRPeLlDcHGi8,3177 41 | bindsnet/encoding/encodings.py,sha256=i08jPyCV3Jj4H7qpphqVw8BJPwuKJGukEXPwtDafALA,5965 42 | bindsnet/encoding/loaders.py,sha256=2TbyhK-ib_evgFNHPBkQeZC-aN5yvLFdGsTBirDAXrI,2380 43 | bindsnet/environment/__init__.py,sha256=6JMOGowIuDuvByo9kFIRR1ntJfpOBUQjSZqE7LWaWHE,53 44 | bindsnet/environment/__pycache__/__init__.cpython-37.pyc,, 45 | bindsnet/environment/__pycache__/environment.cpython-37.pyc,, 46 | bindsnet/environment/environment.py,sha256=W3ZLDCKRypPMuouUAVSkrfBulXSt_6yPEfMdGLQ7PtY,8425 47 | bindsnet/evaluation/__init__.py,sha256=umhjCQ4eaNhUHKGcZ40wqwH5Xh9jOJ3cFWrJLt-cF-g,163 48 | bindsnet/evaluation/__pycache__/__init__.cpython-37.pyc,, 49 | bindsnet/evaluation/__pycache__/evaluation.cpython-37.pyc,, 50 | bindsnet/evaluation/evaluation.py,sha256=Y_qS1PvvIc_YHjxEeKFqJ53p5VQSr-dFU89fJTKYemg,9093 51 | bindsnet/learning/__init__.py,sha256=zR7Zaao__m3rAcL1gNFUwaNooEWbJoJ7Om-E3VSn0zc,142 52 | bindsnet/learning/__pycache__/__init__.cpython-37.pyc,, 53 | bindsnet/learning/__pycache__/learning.cpython-37.pyc,, 54 | bindsnet/learning/__pycache__/reward.cpython-37.pyc,, 55 | bindsnet/learning/learning.py,sha256=5Ihyv7um2a3RGGJv_r7MpRa-KBUl0zaW8qr7c7PNCYY,32970 56 | 
bindsnet/learning/reward.py,sha256=4VeLl_T1kkLRWvhaVwVtD22c5e8oM1FSaMgqspDNnPg,2605 57 | bindsnet/models/__init__.py,sha256=Mt-J2zoeyxdoj0vlQ2XCciYoitAeJPqNoimN6PUc4ns,153 58 | bindsnet/models/__pycache__/__init__.cpython-37.pyc,, 59 | bindsnet/models/__pycache__/models.cpython-37.pyc,, 60 | bindsnet/models/models.py,sha256=PcXeFjT-fKHbH7F_qGgbaYf_IEQW8PDgsQb-7H6QBzs,18723 61 | bindsnet/network/__init__.py,sha256=2HHNEEtnmcyI5uAkRJAwby7hsWNls9a3nOrCxqPFZ2U,75 62 | bindsnet/network/__pycache__/__init__.cpython-37.pyc,, 63 | bindsnet/network/__pycache__/monitors.cpython-37.pyc,, 64 | bindsnet/network/__pycache__/network.cpython-37.pyc,, 65 | bindsnet/network/__pycache__/nodes.cpython-37.pyc,, 66 | bindsnet/network/__pycache__/topology.cpython-37.pyc,, 67 | bindsnet/network/monitors.py,sha256=S4i5HtEvAKsHkC9kd4ruvKtBcyEOrbtWUa-nhTHFVog,9678 68 | bindsnet/network/network.py,sha256=2WRqoi5nipJDy1uXOjhMvoqWLg6RzTO41W-HAudCAec,14894 69 | bindsnet/network/nodes.py,sha256=Fhp6BoBO0r-KkFIOfTt94sM4F3ZjxJbdo57ctncW0_o,49303 70 | bindsnet/network/topology.py,sha256=KKwOBuAB_f2OCFJJ_OrjWCzelgUvMr8QyAxbkUop7BY,29391 71 | bindsnet/pipeline/__init__.py,sha256=Sq9muSJcWFuvss53MuRBa5W4q1mSGcdAMrwaFGMjqMg,195 72 | bindsnet/pipeline/__pycache__/__init__.cpython-37.pyc,, 73 | bindsnet/pipeline/__pycache__/action.cpython-37.pyc,, 74 | bindsnet/pipeline/__pycache__/base_pipeline.cpython-37.pyc,, 75 | bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-37.pyc,, 76 | bindsnet/pipeline/__pycache__/environment_pipeline.cpython-37.pyc,, 77 | bindsnet/pipeline/action.py,sha256=qf7Vw5LvwUOx8t6in2NeU4JsWyweev5rkehdvGDggQw,3212 78 | bindsnet/pipeline/base_pipeline.py,sha256=J8viSSWYcKC7Jj4E0tNK9DpaJvqjcJwvDFPhz5PO1CE,7763 79 | bindsnet/pipeline/dataloader_pipeline.py,sha256=qsNJ4fweS3C8I5Fju6odK1x5qqswssd06d5VA_hxeQU,4889 80 | bindsnet/pipeline/environment_pipeline.py,sha256=ZazbqFmPhaw8DSCZSSt4qWupgyQ4tp1kLDuEs4nz4PA,6498 81 | 
bindsnet/preprocessing/__init__.py,sha256=IqFn8cXso7NvrnflQbkcmjcIBRuXMvtUX9D0GVXyr8A,48 82 | bindsnet/preprocessing/__pycache__/__init__.cpython-37.pyc,, 83 | bindsnet/preprocessing/__pycache__/preprocessing.cpython-37.pyc,, 84 | bindsnet/preprocessing/preprocessing.py,sha256=xMGKvEAGp_4AORzYUHaMnXNLs3aI4UQdFM7TwlyAygY,3248 85 | bindsnet/utils.py,sha256=VIE5vMCsZpIyje82QNFbnOOX1Sa-8xDhrbh5M6NhbeY,7538 86 | -------------------------------------------------------------------------------- /bindsnet/bindsnet-0.2.5.dist-info/WHEEL: -------------------------------------------------------------------------------- 1 | Wheel-Version: 1.0 2 | Generator: bdist_wheel (0.33.4) 3 | Root-Is-Purelib: true 4 | Tag: cp37-none-any 5 | 6 | -------------------------------------------------------------------------------- /bindsnet/bindsnet-0.2.5.dist-info/top_level.txt: -------------------------------------------------------------------------------- 1 | bindsnet 2 | -------------------------------------------------------------------------------- /bindsnet/conversion/__init__.py: -------------------------------------------------------------------------------- 1 | from .conversion import ( 2 | Permute, 3 | FeatureExtractor, 4 | SubtractiveResetIFNodes, 5 | PassThroughNodes, 6 | PermuteConnection, 7 | ConstantPad2dConnection, 8 | data_based_normalization, 9 | ann_to_snn, 10 | ) 11 | -------------------------------------------------------------------------------- /bindsnet/conversion/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/conversion/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/conversion/__pycache__/conversion.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/conversion/__pycache__/conversion.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .torchvision_wrapper import create_torchvision_dataset_wrapper 2 | from .spoken_mnist import SpokenMNIST 3 | from .davis import Davis 4 | 5 | from .collate import time_aware_collate 6 | from .dataloader import DataLoader 7 | 8 | 9 | CIFAR10 = create_torchvision_dataset_wrapper("CIFAR10") 10 | CIFAR100 = create_torchvision_dataset_wrapper("CIFAR100") 11 | Cityscapes = create_torchvision_dataset_wrapper("Cityscapes") 12 | CocoCaptions = create_torchvision_dataset_wrapper("CocoCaptions") 13 | CocoDetection = create_torchvision_dataset_wrapper("CocoDetection") 14 | DatasetFolder = create_torchvision_dataset_wrapper("DatasetFolder") 15 | EMNIST = create_torchvision_dataset_wrapper("EMNIST") 16 | FakeData = create_torchvision_dataset_wrapper("FakeData") 17 | FashionMNIST = create_torchvision_dataset_wrapper("FashionMNIST") 18 | Flickr30k = create_torchvision_dataset_wrapper("Flickr30k") 19 | Flickr8k = create_torchvision_dataset_wrapper("Flickr8k") 20 | ImageFolder = create_torchvision_dataset_wrapper("ImageFolder") 21 | KMNIST = create_torchvision_dataset_wrapper("KMNIST") 22 | LSUN = create_torchvision_dataset_wrapper("LSUN") 23 | LSUNClass = create_torchvision_dataset_wrapper("LSUNClass") 24 | MNIST = create_torchvision_dataset_wrapper("MNIST") 25 | Omniglot = create_torchvision_dataset_wrapper("Omniglot") 26 | PhotoTour = create_torchvision_dataset_wrapper("PhotoTour") 27 | SBU = create_torchvision_dataset_wrapper("SBU") 28 | SEMEION = create_torchvision_dataset_wrapper("SEMEION") 29 | STL10 = create_torchvision_dataset_wrapper("STL10") 30 | SVHN = create_torchvision_dataset_wrapper("SVHN") 31 | 
VOCDetection = create_torchvision_dataset_wrapper("VOCDetection") 32 | VOCSegmentation = create_torchvision_dataset_wrapper("VOCSegmentation") 33 | ImageNet = create_torchvision_dataset_wrapper("ImageNet") 34 | -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/collate.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/collate.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/davis.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/davis.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/preprocess.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/preprocess.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/spoken_mnist.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/spoken_mnist.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/datasets/__pycache__/torchvision_wrapper.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/collate.py: -------------------------------------------------------------------------------- 1 | r""" This code is directly pulled from the pytorch version found at: 2 | 3 | https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py 4 | 5 | Modifications exist to have [time, batch, n_0, ... n_k] instead of batch 6 | in dimension 0.
7 | """ 8 | 9 | import torch 10 | import re 11 | from torch._six import container_abcs, string_classes, int_classes 12 | 13 | from torch.utils.data._utils import collate as pytorch_collate 14 | 15 | 16 | def safe_worker_check(): 17 | """ Return True if called inside a DataLoader worker process (so shared memory 18 | should be used); the detection mechanism differs between PyTorch versions. 19 | """ 20 | try: 21 | return torch.utils.data.get_worker_info() is not None 22 | except AttributeError: 23 | return pytorch_collate._use_shared_memory 24 | 25 | 26 | def time_aware_collate(batch): 27 | r"""Puts each data field into a tensor with dimensions [time, 28 | batch size, ...] 29 | 30 | Interpretation of dimensions being input: 31 | - 0 dim (,) - (1, batch_size, 1) 32 | - 1 dim (time,) - (time, batch_size, 1) 33 | - >=2 dim (time, n_0, ...) - (time, batch_size, n_0, ...) 34 | """ 35 | 36 | elem = batch[0] 37 | elem_type = type(elem) 38 | if isinstance(elem, torch.Tensor): 39 | # catch 0 and 1 dimension cases and view as specified 40 | if elem.dim() == 0: 41 | batch = [x.view((1, 1)) for x in batch] 42 | elif elem.dim() == 1: 43 | batch = [x.view((x.shape[0], 1)) for x in batch] 44 | 45 | out = None 46 | if safe_worker_check(): 47 | # If we're in a background process, concatenate directly into a 48 | # shared memory tensor to avoid an extra copy 49 | numel = sum([x.numel() for x in batch]) 50 | storage = elem.storage()._new_shared(numel) 51 | out = elem.new(storage) 52 | return torch.stack(batch, 1, out=out) 53 | elif ( 54 | elem_type.__module__ == "numpy" 55 | and elem_type.__name__ != "str_" 56 | and elem_type.__name__ != "string_" 57 | ): 58 | elem = batch[0] 59 | if elem_type.__name__ == "ndarray": 60 | # array of string classes and object 61 | if ( 62 | pytorch_collate.np_str_obj_array_pattern.search(elem.dtype.str) 63 | is not None 64 | ): 65 | raise TypeError( 66 | 
pytorch_collate.default_collate_err_msg_format.format(elem.dtype) 67 | ) 68 | 69 | return time_aware_collate([torch.as_tensor(b) for b in batch]) 70 | elif elem.shape == (): # scalars 71
| return torch.as_tensor(batch) 72 | elif isinstance(elem, float): 73 | return torch.tensor(batch, dtype=torch.float64) 74 | elif isinstance(elem, int_classes): 75 | return torch.tensor(batch) 76 | elif isinstance(elem, string_classes): 77 | return batch 78 | elif isinstance(elem, container_abcs.Mapping): 79 | return {key: time_aware_collate([d[key] for d in batch]) for key in elem} 80 | elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple 81 | return elem_type(*(time_aware_collate(samples) for samples in zip(*batch))) 82 | elif isinstance(elem, container_abcs.Sequence): 83 | transposed = zip(*batch) 84 | return [time_aware_collate(samples) for samples in transposed] 85 | 86 | raise TypeError(pytorch_collate.default_collate_err_msg_format.format(elem_type)) 87 | -------------------------------------------------------------------------------- /bindsnet/datasets/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .collate import time_aware_collate 4 | 5 | 6 | class DataLoader(torch.utils.data.DataLoader): 7 | def __init__( 8 | self, 9 | dataset, 10 | batch_size=1, 11 | shuffle=False, 12 | sampler=None, 13 | batch_sampler=None, 14 | num_workers=0, 15 | collate_fn=time_aware_collate, 16 | pin_memory=False, 17 | drop_last=False, 18 | timeout=0, 19 | worker_init_fn=None, 20 | ): 21 | super().__init__( 22 | dataset, 23 | sampler=sampler, 24 | shuffle=shuffle, 25 | batch_size=batch_size, 26 | drop_last=drop_last, 27 | pin_memory=pin_memory, 28 | timeout=timeout, 29 | num_workers=num_workers, 30 | worker_init_fn=worker_init_fn, 31 | batch_sampler=batch_sampler, 32 | collate_fn=collate_fn, 33 | ) 34 | -------------------------------------------------------------------------------- /bindsnet/datasets/davis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import shutil 5 | import zipfile 6 | 
import sys 7 | import time 8 | import shutil 9 | 10 | from PIL import Image 11 | from glob import glob 12 | from tqdm import tqdm 13 | from collections import defaultdict 14 | from typing import Optional, Tuple, List, Iterable 15 | from urllib.request import urlretrieve 16 | 17 | import warnings 18 | 19 | 20 | class Davis(torch.utils.data.Dataset): 21 | SUBSET_OPTIONS = ["train", "val", "test-dev", "test-challenge"] 22 | TASKS = ["semi-supervised", "unsupervised"] 23 | RESOLUTION_OPTIONS = ["480p", "Full-Resolution"] 24 | DATASET_WEB = "https://davischallenge.org/davis2017/code.html" 25 | VOID_LABEL = 255 26 | 27 | def __init__( 28 | self, 29 | root, 30 | task="unsupervised", 31 | subset="train", 32 | sequences="all", 33 | resolution="480p", 34 | size=(600, 480), 35 | codalab=False, 36 | download=False, 37 | num_samples: int = -1, 38 | ): 39 | """ 40 | Class to read the DAVIS dataset 41 | :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. 42 | :param task: Task to load the annotations, choose between semi-supervised or unsupervised. 43 | :param subset: Set to load the annotations 44 | :param sequences: Sequences to consider, 'all' to use all the sequences in a set. 
45 | :param resolution: Resolution of the dataset to use; one of '480p' or 'Full-Resolution' 46 | :param download: Specify whether to download the dataset if it is not present 47 | :param num_samples: Number of samples to pass to the batch 48 | """ 49 | super().__init__() 50 | 51 | if subset not in self.SUBSET_OPTIONS: 52 | raise ValueError(f"Subset should be in {self.SUBSET_OPTIONS}") 53 | if task not in self.TASKS: 54 | raise ValueError(f"The only tasks that are supported are {self.TASKS}") 55 | if resolution not in self.RESOLUTION_OPTIONS: 56 | raise ValueError( 57 | f"You may only choose one of these resolutions: {self.RESOLUTION_OPTIONS}" 58 | ) 59 | 60 | self.task = task 61 | self.subset = subset 62 | self.resolution = resolution 63 | self.size = size 64 | 65 | # Images must be rescaled (converted) whenever a non-default size is requested 66 | self.converted = not self.size == (600, 480) 67 | 68 | # Sets a tag for naming the folder containing the dataset 69 | self.tag = "" 70 | if self.task == "unsupervised": 71 | self.tag += "Unsupervised-" 72 | if self.subset == "train" or self.subset == "val": 73 | self.tag += "trainval" 74 | else: 75 | self.tag += self.subset 76 | self.tag += "-" + self.resolution 77 | 78 | # Makes a unique path for a given instance of davis 79 | self.converted_root = os.path.join( 80 | root, self.tag + "-" + str(self.size[0]) + "x" + str(self.size[1]) 81 | ) 82 | self.root = os.path.join(root, self.tag) 83 | self.download = download 84 | self.num_samples = num_samples 85 | self.zip_path = os.path.join(self.root, "repo.zip") 86 | self.img_path = os.path.join(self.root, "JPEGImages", resolution) 87 | annotations_folder = ( 88 | "Annotations" if task == "semi-supervised" else "Annotations_unsupervised" 89 | ) 90 | self.mask_path = os.path.join(self.root, annotations_folder, resolution) 91 | year = ( 92 | "2019" 93 | if task == "unsupervised" 94 | and (subset == "test-dev" or subset == "test-challenge") 95 | else "2017" 96 |
97 | self.imagesets_path = os.path.join(self.root, "ImageSets", year) 98 | 99 | # Makes a converted path for scaled images 100 | if self.converted: 101 | self.converted_img_path = os.path.join( 102 | self.converted_root, "JPEGImages", resolution 103 | ) 104 | self.converted_mask_path = os.path.join( 105 | self.converted_root, annotations_folder, resolution 106 | ) 107 | self.converted_imagesets_path = os.path.join( 108 | self.converted_root, "ImageSets", year 109 | ) 110 | 111 | # Sets sequences_names to the relevant sequences 112 | if sequences == "all": 113 | with open( 114 | os.path.join(self.imagesets_path, f"{self.subset}.txt"), "r" 115 | ) as f: 116 | tmp = f.readlines() 117 | self.sequences_names = [x.strip() for x in tmp] 118 | else: 119 | self.sequences_names = ( 120 | sequences if isinstance(sequences, list) else [sequences] 121 | ) 122 | self.sequences = defaultdict(dict) 123 | 124 | # Check if Davis is installed and download it if necessary 125 | self._check_directories() 126 | 127 | # Collects the (possibly resized) image and mask paths for each sequence 128 | for seq in self.sequences_names: 129 | images = np.sort(glob(os.path.join(self.img_path, seq, "*.jpg"))).tolist() 130 | if len(images) == 0 and not codalab: 131 | raise FileNotFoundError(f"Images for sequence {seq} not found.") 132 | self.sequences[seq]["images"] = images 133 | masks = np.sort(glob(os.path.join(self.mask_path, seq, "*.png"))).tolist() 134 | masks.extend([-1] * (len(images) - len(masks))) 135 | self.sequences[seq]["masks"] = masks 136 | 137 | # Creates an enumeration for the sequences for __getitem__ 138 | self.enum_sequences = [] 139 | for seq in self.sequences_names: 140 | self.enum_sequences.append(self.sequences[seq]) 141 | 142 | def __len__(self): 143 | """ 144 | Calculates the number of sequences the dataset holds 145 | 146 | :return: the number of sequences in the dataset 147 | """ 148 | return len(self.sequences) 149 | 150 | def 
_convert_sequences(self): 151 | """ 152 |
Creates a new root for the converted dataset, then resizes each image and mask to the given size and saves it under that root. 153 | """ 154 | os.makedirs(self.converted_imagesets_path, exist_ok=True) 155 | os.makedirs(self.converted_img_path) 156 | os.makedirs(self.converted_mask_path) 157 | 158 | shutil.copy( 159 | os.path.join(self.imagesets_path, f"{self.subset}.txt"), 160 | os.path.join(self.converted_imagesets_path, f"{self.subset}.txt"), 161 | ) 162 | 163 | print("Converting sequences to size: {0}".format(self.size)) 164 | for seq in tqdm(self.sequences_names): 165 | os.makedirs(os.path.join(self.converted_img_path, seq)) 166 | os.makedirs(os.path.join(self.converted_mask_path, seq)) 167 | images = np.sort(glob(os.path.join(self.img_path, seq, "*.jpg"))).tolist() 168 | if len(images) == 0: 169 | raise FileNotFoundError(f"Images for sequence {seq} not found.") 170 | for ind, img in enumerate(images): 171 | im = Image.open(img) 172 | im.thumbnail(self.size, Image.LANCZOS) 173 | im.save( 174 | os.path.join( 175 | self.converted_img_path, seq, str(ind).zfill(5) + ".jpg" 176 | ) 177 | ) 178 | masks = np.sort(glob(os.path.join(self.mask_path, seq, "*.png"))).tolist() 179 | for ind, msk in enumerate(masks): 180 | im = Image.open(msk) 181 | im.thumbnail(self.size, Image.LANCZOS) 182 | im.convert("RGB").save( 183 | os.path.join( 184 | self.converted_mask_path, seq, str(ind).zfill(5) + ".png" 185 | ) 186 | ) 187 | 188 | def _check_directories(self): 189 | """ 190 | Verifies that the correct dataset is downloaded; downloads if it isn't and download=True. 191 | 192 | :raises: FileNotFoundError if the subset sequence, annotation or root folder is missing.
        """
        if not os.path.exists(self.root):
            if self.download:
                self._download()
            else:
                raise FileNotFoundError(
                    f"DAVIS not found in the specified directory, download it from {self.DATASET_WEB} or add download=True to your call"
                )
        if not os.path.exists(os.path.join(self.imagesets_path, f"{self.subset}.txt")):
            raise FileNotFoundError(
                f"Subset sequences list for {self.subset} not found, download the missing subset "
                f"for the {self.task} task from {self.DATASET_WEB}"
            )
        if self.subset in ["train", "val"] and not os.path.exists(self.mask_path):
            raise FileNotFoundError(
                f"Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}"
            )
        if self.converted:
            if not os.path.exists(self.converted_img_path):
                self._convert_sequences()
            self.img_path = self.converted_img_path
            self.mask_path = self.converted_mask_path
            self.imagesets_path = self.converted_imagesets_path

    def get_frames(self, sequence):
        for img, msk in zip(
            self.sequences[sequence]["images"], self.sequences[sequence]["masks"]
        ):
            image = np.array(Image.open(img))
            mask = None if msk is None else np.array(Image.open(msk))
            yield image, mask

    def _get_all_elements(self, sequence, obj_type):
        obj = np.array(Image.open(self.sequences[sequence][obj_type][0]))
        all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape))
        obj_id = []
        for i, obj in enumerate(self.sequences[sequence][obj_type]):
            all_objs[i, ...]
= np.array(Image.open(obj))
            # The frame id is the file name without its extension (e.g. "00000")
            obj_id.append(os.path.splitext(os.path.basename(obj))[0])
        return all_objs, obj_id

    def get_all_images(self, sequence):
        return self._get_all_elements(sequence, "images")

    def get_all_masks(self, sequence, separate_objects_masks=False):
        masks, masks_id = self._get_all_elements(sequence, "masks")
        masks_void = np.zeros_like(masks)

        # Separate void and object masks
        for i in range(masks.shape[0]):
            masks_void[i, ...] = masks[i, ...] == 255
            masks[i, masks[i, ...] == 255] = 0

        if separate_objects_masks:
            num_objects = int(np.max(masks[0, ...]))
            tmp = np.ones((num_objects, *masks.shape))
            tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None]
            masks = tmp == masks[None, ...]
            masks = masks > 0
        return masks, masks_void, masks_id

    def get_sequences(self):
        for seq in self.sequences:
            yield seq

    def _download(self):
        """
        Downloads the correct dataset based on the given parameters

        Relies on self.tag to determine both the name of the folder created for the dataset and the correct download URL.
        """

        os.makedirs(self.root)

        # Grabs the correct zip url based on parameters
        zip_url = f"https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-{self.tag}.zip"

        print("\nDownloading Davis data set from " + zip_url + "\n")

        # Downloads the relevant dataset
        urlretrieve(zip_url, self.zip_path, reporthook=self.progress)

        print("\nDone!
\n\nUnzipping and restructuring")

        # Extracts the dataset
        z = zipfile.ZipFile(self.zip_path, "r")
        z.extractall(path=self.root)
        z.close()
        os.remove(self.zip_path)

        temp_folder = os.path.join(self.root, "DAVIS")

        # Every download is wrapped in a "DAVIS" folder: move its contents up into root, then delete it
        for file in os.listdir(temp_folder):
            shutil.move(os.path.join(temp_folder, file), self.root)
        os.rmdir(temp_folder)

        print("\nDone!\n")

    def __getitem__(self, ind):
        """
        Gets an item of the Dataset based on index

        :param ind: index of item to take from dataset

        :return: a sequence which contains a list of images and masks
        """
        seq = self.enum_sequences[ind]
        return seq

    # Simple progress indicator for the download of the dataset
    def progress(self, count, block_size, total_size):
        global start_time
        if count == 0:
            start_time = time.time()
            return
        # Guard against a zero duration on clocks with coarse resolution
        duration = max(time.time() - start_time, 1e-6)
        progress_size = int(count * block_size)
        speed = int(progress_size / (1024 * duration))
        percent = min(int(count * block_size * 100 / total_size), 100)
        sys.stdout.write(
            "\r...%d%%, %d MB, %d KB/s, %d seconds passed"
            % (percent, progress_size / (1024 * 1024), speed, duration)
        )
        sys.stdout.flush()

--------------------------------------------------------------------------------
/bindsnet/datasets/preprocess.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np


def gray_scale(image: np.ndarray) -> np.ndarray:
    # language=rst
    """
    Converts RGB image into grayscale.

    :param image: RGB image.
    :return: Gray-scaled image.
    """
    return cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)


def crop(image: np.ndarray, x1: int, x2: int, y1: int, y2: int) -> np.ndarray:
    # language=rst
    """
    Crops an image given coordinates of cropping box.

    :param image: 3-dimensional image.
    :param x1: Start index along the first (row) axis.
    :param x2: End index (exclusive) along the first (row) axis.
    :param y1: Start index along the second (column) axis.
    :param y2: End index (exclusive) along the second (column) axis.
    :return: Image cropped to ``image[x1:x2, y1:y2, :]``.
    """
    return image[x1:x2, y1:y2, :]


def binary_image(image: np.ndarray) -> np.ndarray:
    # language=rst
    """
    Converts input image into black and white (binary).

    :param image: Gray-scaled image.
    :return: Binary image in which pixels greater than 0 are set to 1 and all others to 0.
    """
    return cv2.threshold(image, 0, 1, cv2.THRESH_BINARY)[1]


def subsample(image: np.ndarray, x: int, y: int) -> np.ndarray:
    # language=rst
    """
    Scale the image to width ``x`` and height ``y``.

    :param image: Image to be rescaled.
    :param x: Output width (number of columns); ``cv2.resize`` takes ``dsize`` as ``(width, height)``.
    :param y: Output height (number of rows).
    :return: Re-scaled image.
    """
    return cv2.resize(image, (x, y))

--------------------------------------------------------------------------------
/bindsnet/datasets/spoken_mnist.py:
--------------------------------------------------------------------------------
from typing import Optional, Tuple, List, Iterable
import os
import torch
import numpy as np
import shutil
import zipfile


from urllib.request import urlretrieve
from scipy.io import wavfile

import warnings


class SpokenMNIST(torch.utils.data.Dataset):
    # language=rst
    """
    Handles loading and saving of the Spoken MNIST audio dataset `(link)
    `_.
    """
    train_pickle = "train.pt"
    test_pickle = "test.pt"

    url = "https://github.com/Jakobovski/free-spoken-digit-dataset/archive/master.zip"

    # Examples form the outer loop so that an index-based train/test split
    # contains every digit and speaker (with digits outermost, the test split
    # would hold only the highest digits)
    files = []
    for example in range(50):
        for digit in range(10):
            for speaker in ["jackson", "nicolas", "theo"]:
                files.append("_".join([str(digit), speaker, str(example)]) + ".wav")

    n_files = len(files)

    def __init__(
        self,
        path: str,
        download: bool = False,
        shuffle: bool = True,
        train: bool = True,
        split: float = 0.8,
        num_samples: int = -1,
    ) -> None:
        # language=rst
        """
        Constructor for the ``SpokenMNIST`` object. Makes the data directory if it doesn't already exist.

        :param path: Pathname of directory in which to store the dataset.
        :param download: Whether or not to download the dataset (requires internet connection).
        :param shuffle: Whether to randomly permute order of dataset.
        :param train: Load training split if true else load test split.
        :param split: Train, test split; in range ``(0, 1)``.
        :param num_samples: Number of time frames returned per example; a negative value returns all frames.
        """
        super().__init__()

        if not os.path.isdir(path):
            os.makedirs(path)

        self.path = path
        self.download = download
        self.shuffle = shuffle

        self.zip_path = os.path.join(path, "repo.zip")

        if train:
            self.audio, self.labels = self._get_train(split)
        else:
            self.audio, self.labels = self._get_test(split)

        self.num_samples = num_samples

    def __len__(self):
        return len(self.audio)

    def __getitem__(self, ind):
        audio = self.audio[ind]
        if self.num_samples >= 0:
            # Keep only the first num_samples time frames
            audio = audio[: self.num_samples, :]
        label = self.labels[ind]

        return {"audio": audio, "label": label}

    def _get_train(self, split: float = 0.8) -> Tuple[torch.Tensor, torch.Tensor]:
        # language=rst
        """
        Gets the Spoken MNIST training audio and labels.

        :param split: Train, test split; in range ``(0, 1)``.
        :return: Spoken MNIST training audio and labels.
        """
        split_index = int(split * SpokenMNIST.n_files)
        path = os.path.join(self.path, "_".join([SpokenMNIST.train_pickle, str(split)]))

        if not all([os.path.isfile(os.path.join(self.path, f)) for f in self.files]):
            # Download data if it isn't on disk.
            if self.download:
                print("Downloading Spoken MNIST data.\n")
                self._download()

                # Process data into audio, label (input, output) pairs.
                audio, labels = self.process_data(SpokenMNIST.files[:split_index])

                # Serialize audio data on disk for next time.
                torch.save((audio, labels), path)
            else:
                msg = "Dataset not found on disk; specify 'download=True' to allow downloads."
                raise FileNotFoundError(msg)
        else:
            if not os.path.isfile(path):
                # Process the training split if no serialized file exists yet.
                audio, labels = self.process_data(SpokenMNIST.files[:split_index])

                # Serialize audio data on disk for next time.
                torch.save((audio, labels), path)
            else:
                # Load audio data from disk if it has already been processed.
                print("Loading training data from serialized object file.\n")
                audio, labels = torch.load(path)

        labels = torch.Tensor(labels)

        if self.shuffle:
            perm = np.random.permutation(np.arange(labels.shape[0]))
            audio, labels = [torch.Tensor(audio[_]) for _ in perm], labels[perm]

        return audio, torch.Tensor(labels)

    def _get_test(self, split: float = 0.8) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        # language=rst
        """
        Gets the Spoken MNIST test audio and labels.

        :param split: Train, test split; in range ``(0, 1)``.
        :return: The Spoken MNIST test audio and labels.
        """
        split_index = int(split * SpokenMNIST.n_files)
        path = os.path.join(self.path, "_".join([SpokenMNIST.test_pickle, str(split)]))

        if not all([os.path.isfile(os.path.join(self.path, f)) for f in self.files]):
            # Download data if it isn't on disk.
            if self.download:
                print("Downloading Spoken MNIST data.\n")
                self._download()

                # Process data into audio, label (input, output) pairs.
                audio, labels = self.process_data(SpokenMNIST.files[split_index:])

                # Serialize audio data on disk for next time.
                torch.save((audio, labels), path)
            else:
                msg = "Dataset not found on disk; specify 'download=True' to allow downloads."
                raise FileNotFoundError(msg)
        else:
            if not os.path.isfile(path):
                # Process the test split if no serialized file exists yet.
                audio, labels = self.process_data(SpokenMNIST.files[split_index:])

                # Serialize audio data on disk for next time.
                torch.save((audio, labels), path)
            else:
                # Load audio data from disk if it has already been processed.
                print("Loading test data from serialized object file.\n")
                audio, labels = torch.load(path)

        labels = torch.Tensor(labels)

        if self.shuffle:
            # audio is a list of per-file arrays, so it is permuted element by element
            perm = np.random.permutation(np.arange(labels.shape[0]))
            audio, labels = [torch.Tensor(audio[_]) for _ in perm], labels[perm]

        return audio, torch.Tensor(labels)

    def _download(self) -> None:
        # language=rst
        """
        Downloads and unzips all Spoken MNIST data.
        """
        urlretrieve(SpokenMNIST.url, self.zip_path)

        z = zipfile.ZipFile(self.zip_path, "r")
        z.extractall(path=self.path)
        z.close()

        path = os.path.join(self.path, "free-spoken-digit-dataset-master", "recordings")
        for f in os.listdir(path):
            shutil.move(os.path.join(path, f), os.path.join(self.path))

        cwd = os.getcwd()
        os.chdir(self.path)
        shutil.rmtree("free-spoken-digit-dataset-master")
        os.chdir(cwd)

    def process_data(
        self, file_names: Iterable[str]
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        # language=rst
        """
        Opens files of Spoken MNIST data and converts each recording into a ``numpy`` array of log Mel filter-bank features.

        :param file_names: Names of the files containing Spoken MNIST audio to load.
        :return: Processed Spoken MNIST audio and label data.
        """
        audio, labels = [], []

        for f in file_names:
            label = int(f.split("_")[0])

            sample_rate, signal = wavfile.read(os.path.join(self.path, f))
            pre_emphasis = 0.97
            emphasized_signal = np.append(
                signal[0], signal[1:] - pre_emphasis * signal[:-1]
            )

            # Popular settings are 25 ms for the frame size and a 10 ms stride (15 ms overlap)
            frame_size = 0.025
            frame_stride = 0.01

            # Convert from seconds to samples
            frame_length, frame_step = (
                frame_size * sample_rate,
                frame_stride * sample_rate,
            )
            signal_length = len(emphasized_signal)
            frame_length = int(round(frame_length))
            frame_step = int(round(frame_step))

            # Make sure that we have at least 1 frame
            num_frames = max(
                1, int(np.ceil(float(np.abs(signal_length - frame_length)) / frame_step))
            )

            pad_signal_length = num_frames * frame_step + frame_length
            z = np.zeros((pad_signal_length - signal_length))
            pad_signal = np.append(emphasized_signal, z)  # Pad signal

            indices = (
                np.tile(np.arange(0, frame_length), (num_frames, 1))
                + np.tile(
                    np.arange(0, num_frames * frame_step, frame_step), (frame_length, 1)
                ).T
            )
            frames = pad_signal[indices.astype(np.int32, copy=False)]

            # Hamming Window
            frames *= np.hamming(frame_length)

            # Fast Fourier Transform and Power Spectrum
            NFFT = 512
            mag_frames = np.absolute(np.fft.rfft(frames, NFFT))  # Magnitude of the FFT
            pow_frames = (1.0 / NFFT) * (mag_frames ** 2)  # Power Spectrum

            # Log filter banks
            nfilt = 40
            low_freq_mel = 0
            high_freq_mel = 2595 * np.log10(
                1 + (sample_rate / 2) / 700
            )  # Convert Hz to Mel
            mel_points = np.linspace(
                low_freq_mel, high_freq_mel, nfilt + 2
            )  # Equally spaced in Mel scale
            hz_points = 700 * (10 ** (mel_points / 2595) - 1)  # Convert Mel to Hz
            bin = np.floor((NFFT + 1) * hz_points / sample_rate)

            fbank = np.zeros((nfilt, int(np.floor(NFFT / 2 + 1))))
            for m in range(1, nfilt + 1):
                f_m_minus = int(bin[m - 1])  # left
                f_m = int(bin[m])  # center
                f_m_plus = int(bin[m + 1])  # right

                for k in range(f_m_minus, f_m):
                    fbank[m - 1, k] = (k - bin[m - 1]) / (bin[m] - bin[m - 1])
                for k in range(f_m, f_m_plus):
                    fbank[m - 1, k] = (bin[m + 1] - k) / (bin[m + 1] - bin[m])

            filter_banks = np.dot(pow_frames, fbank.T)
            filter_banks = np.where(
                filter_banks == 0, np.finfo(float).eps, filter_banks
            )  # Numerical Stability
            filter_banks = 20 * np.log10(filter_banks)  # dB

            audio.append(filter_banks), labels.append(label)

        return audio, torch.Tensor(labels)

--------------------------------------------------------------------------------
/bindsnet/datasets/torchvision_wrapper.py:
--------------------------------------------------------------------------------
from typing import Optional, Dict

import torch
import
torchvision

from ..encoding import Encoder, NullEncoder


def create_torchvision_dataset_wrapper(ds_type):
    """ Creates wrapper classes for datasets that output (image, label)
    from __getitem__. This applies to all of the datasets inside torchvision.
    """

    if isinstance(ds_type, str):
        ds_type = getattr(torchvision.datasets, ds_type)

    class TorchvisionDatasetWrapper(ds_type):
        # The BindsNET header is always kept; the parentheses limit the
        # conditional to the documentation appended after it.
        __doc__ = (
            """BindsNET torchvision dataset wrapper for:

        The core difference is that the output of __getitem__ is no longer
        (image, label) but a dictionary containing the image, label,
        and their encoded versions if encoders were provided.

            \n\n"""
            + (str(ds_type) if ds_type.__doc__ is None else ds_type.__doc__)
        )

        def __init__(
            self,
            image_encoder: Optional[Encoder] = None,
            label_encoder: Optional[Encoder] = None,
            *args,
            **kwargs
        ):
            # language=rst
            """
            Constructor for the BindsNET torchvision dataset wrapper.
            For details on the dataset you're interested in, visit:

            https://pytorch.org/docs/stable/torchvision/datasets.html

            :param image_encoder: Spike encoder for use on the image
            :param label_encoder: Spike encoder for use on the label
            :param *args: Arguments for the original dataset
            :param **kwargs: Keyword arguments for the original dataset
            """
            super().__init__(*args, **kwargs)

            self.args = args
            self.kwargs = kwargs

            # Allow the passthrough of None, but change to NullEncoder
            if image_encoder is None:
                image_encoder = NullEncoder()

            if label_encoder is None:
                label_encoder = NullEncoder()

            self.image_encoder = image_encoder
            self.label_encoder = label_encoder

        def __getitem__(self, ind: int) -> Dict[str, torch.Tensor]:
            """
            Utilizes the torchvision dataset parent class to grab the
            data, then encodes it using the supplied encoders.

            :param int ind: Index to grab data at

            :return: The relevant data and encoded data from the
            requested index.
            """

            image, label = super().__getitem__(ind)

            output = {
                "image": image,
                "label": label,
                "encoded_image": self.image_encoder(image),
                "encoded_label": self.label_encoder(label),
            }

            return output

        def __len__(self):
            return super().__len__()

    return TorchvisionDatasetWrapper

--------------------------------------------------------------------------------
/bindsnet/encoding/__init__.py:
--------------------------------------------------------------------------------
from .encodings import single, repeat, bernoulli, poisson, rank_order
from .loaders import bernoulli_loader, poisson_loader, rank_order_loader
from .encoders import (
    Encoder,
    NullEncoder,
    SingleEncoder,
    RepeatEncoder,
    BernoulliEncoder,
    PoissonEncoder,
    RankOrderEncoder,
)

--------------------------------------------------------------------------------
/bindsnet/encoding/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/encoding/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/encoding/__pycache__/encoders.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/encoding/__pycache__/encoders.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/encoding/__pycache__/encodings.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/encoding/__pycache__/encodings.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/encoding/__pycache__/loaders.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/encoding/__pycache__/loaders.cpython-37.pyc
--------------------------------------------------------------------------------
/bindsnet/encoding/encoders.py:
--------------------------------------------------------------------------------
from . import encodings


class Encoder:
    """
    Base class for spike encoding transforms.

    - Calls self.enc from the subclass and passes whatever arguments were
    provided. self.enc must be callable with torch.Tensor, *args, **kwargs
    """

    def __init__(self, *args, **kwargs) -> None:
        self.enc_args = args
        self.enc_kwargs = kwargs

    def __call__(self, img):
        return self.enc(img, *self.enc_args, **self.enc_kwargs)


class NullEncoder(Encoder):
    """
    Pass through of the datum that was input.

    WARNING - this is not a real encoder into spikes. Be careful with
    the usage of this class.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, img):
        return img


class SingleEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, sparsity: float = 0.3, **kwargs):
        """
        Creates a callable SingleEncoder which encodes as defined in
        :code:`bindsnet.encoding.encodings.single`

        :param time: Length of single spike train per input variable.
        :param dt: Simulation time step.
        :param sparsity: Sparsity of the input representation. 0 for no spikes and 1 for all spikes.
        """
        super().__init__(time, dt=dt, sparsity=sparsity, **kwargs)

        self.enc = encodings.single


class RepeatEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, **kwargs):
        """
        Creates a callable RepeatEncoder which encodes as defined in
        :code:`bindsnet.encoding.encodings.repeat`

        :param time: Length of repeat spike train per input variable.
        :param dt: Simulation time step.
        """
        super().__init__(time, dt=dt, **kwargs)

        self.enc = encodings.repeat


class BernoulliEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, **kwargs):
        """
        Creates a callable BernoulliEncoder which encodes as defined in
        :code:`bindsnet.encoding.encodings.bernoulli`

        :param time: Length of Bernoulli spike train per input variable.
        :param dt: Simulation time step.

        Keyword arguments:

        :param float max_prob: Maximum probability of spike per Bernoulli trial.
        """
        super().__init__(time, dt=dt, **kwargs)

        self.enc = encodings.bernoulli


class PoissonEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, **kwargs):
        """
        Creates a callable PoissonEncoder which encodes as defined in
        :code:`bindsnet.encoding.encodings.poisson`

        :param time: Length of Poisson spike train per input variable.
        :param dt: Simulation time step.
        """
        super().__init__(time, dt=dt, **kwargs)

        self.enc = encodings.poisson


class RankOrderEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, **kwargs):
        """
        Creates a callable RankOrderEncoder which encodes as defined in
        :code:`bindsnet.encoding.encodings.rank_order`

        :param time: Length of RankOrder spike train per input variable.
        :param dt: Simulation time step.
        """
        super().__init__(time, dt=dt, **kwargs)

        self.enc = encodings.rank_order

--------------------------------------------------------------------------------
/bindsnet/encoding/encodings.py:
--------------------------------------------------------------------------------
from typing import Optional

import torch
import numpy as np


def single(
    datum: torch.Tensor, time: int, dt: float = 1.0, sparsity: float = 0.3, **kwargs
) -> torch.Tensor:
    # language=rst
    """
    Generates a timing-based single-spike encoding. Features whose value exceeds the
    ``(1 - sparsity)`` quantile of ``datum`` emit one spike at the first time step;
    all other features remain silent.

    :param datum: Tensor of shape ``[n_1, ..., n_k]``.
    :param time: Length of the input and output.
    :param dt: Simulation time step.
    :param sparsity: Fraction of features that spike; 0 for no spikes and 1 for (nearly) all features spiking.
    :return: Tensor of shape ``[time, n_1, ..., n_k]``.
    """
    time = int(time / dt)
    shape = list(datum.shape)
    datum = np.copy(datum)
    quantile = np.quantile(datum, 1 - sparsity)
    s = np.zeros([time, *shape])
    s[0] = np.where(datum > quantile, np.ones(shape), np.zeros(shape))
    return torch.Tensor(s).byte()


def repeat(datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs) -> torch.Tensor:
    # language=rst
    """
    Repeats a tensor along a new dimension in the 0th position for ``int(time / dt)`` timesteps.

    :param datum: Tensor of shape ``[n_1, ..., n_k]``.
    :param time: Length of the repeated output.
    :param dt: Simulation time step.
    :return: Tensor of shape ``[time, n_1, ..., n_k]`` of repeated data along the 0th dimension.
    """
    time = int(time / dt)
    return datum.repeat([time, *([1] * len(datum.shape))])


def bernoulli(
    datum: torch.Tensor, time: Optional[int] = None, dt: float = 1.0, **kwargs
) -> torch.Tensor:
    # language=rst
    """
    Generates Bernoulli-distributed spike trains based on input intensity. Inputs must be non-negative.
    Spikes correspond to successful Bernoulli trials, with success probability equal to ``max_prob`` times
    the input value (inputs are divided by their maximum when it exceeds 1).

    :param datum: Tensor of shape ``[n_1, ..., n_k]``.
    :param time: Length of Bernoulli spike train per input variable; if ``None``, a single step without a time dimension is returned.
    :param dt: Simulation time step.
    :return: Tensor of shape ``[time, n_1, ..., n_k]`` of Bernoulli-distributed spikes.

    Keyword arguments:

    :param float max_prob: Maximum probability of spike per Bernoulli trial.
    """
    # Setting kwargs.
    max_prob = kwargs.get("max_prob", 1.0)

    assert 0 <= max_prob <= 1, "Maximum firing probability must be in range [0, 1]"
    assert (datum >= 0).all(), "Inputs must be non-negative"

    shape, size = datum.shape, datum.numel()
    datum = datum.view(-1)

    if time is not None:
        time = int(time / dt)

    # Normalize inputs and rescale (spike probability proportional to normalized intensity).
    # Out-of-place division, so the caller's tensor (which datum.view shares) is not modified.
    if datum.max() > 1.0:
        datum = datum / datum.max()

    # Make spike data from Bernoulli sampling.
    if time is None:
        spikes = torch.bernoulli(max_prob * datum)
        spikes = spikes.view(*shape)
    else:
        spikes = torch.bernoulli(max_prob * datum.repeat([time, 1]))
        spikes = spikes.view(time, *shape)

    return spikes.byte()


def poisson(datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs) -> torch.Tensor:
    # language=rst
    """
    Generates Poisson-distributed spike trains based on input intensity. Inputs must be non-negative, and give the
    firing rate in Hz.
    Inter-spike intervals (ISIs) for non-negative data are incremented by one to avoid zero intervals
    while maintaining ISI distributions.

    For example, an input of intensity :code:`x` will have an average firing rate of :code:`x` Hz.

    :param datum: Tensor of shape ``[n_1, ..., n_k]``.
    :param time: Length of Poisson spike train per input variable.
    :param dt: Simulation time step.
    :return: Tensor of shape ``[time, n_1, ..., n_k]`` of Poisson-distributed spikes.
    """
    assert (datum >= 0).all(), "Inputs must be non-negative"

    # Get shape and size of data.
    shape, size = datum.shape, datum.numel()
    datum = datum.view(-1)
    time = int(time / dt)

    # Compute the mean inter-spike interval of each input in time steps:
    # (1 / rate in Hz) seconds * 1000 / dt, with dt given in milliseconds.
    rate = torch.zeros(size)
    rate[datum != 0] = 1 / datum[datum != 0] * (1000 / dt)

    # Create Poisson distribution and sample inter-spike intervals
    # (incrementing by 1 to avoid zero intervals).
    dist = torch.distributions.Poisson(rate=rate)
    intervals = dist.sample(sample_shape=torch.Size([time + 1]))
    intervals[:, datum != 0] += (intervals[:, datum != 0] == 0).float()

    # Calculate spike times by cumulatively summing over time dimension.
    times = torch.cumsum(intervals, dim=0).long()
    times[times >= time + 1] = 0

    # Create tensor of spikes.
    spikes = torch.zeros(time + 1, size).byte()
    spikes[times, torch.arange(size)] = 1
    spikes = spikes[1:]

    return spikes.view(time, *shape)


def rank_order(
    datum: torch.Tensor, time: int, dt: float = 1.0, **kwargs
) -> torch.Tensor:
    # language=rst
    """
    Encodes data via a rank order coding-like representation. One spike per neuron, temporally ordered by decreasing
    intensity. Inputs must be non-negative.

    :param datum: Tensor of shape ``[n_1, ..., n_k]``.
    :param time: Length of rank order-encoded spike train per input variable.
    :param dt: Simulation time step.
    :return: Tensor of shape ``[time, n_1, ..., n_k]`` of rank order-encoded spikes.
    """
    assert (datum >= 0).all(), "Inputs must be non-negative"

    shape, size = datum.shape, datum.numel()
    datum = datum.view(-1)
    time = int(time / dt)

    # Create spike times in order of decreasing intensity.
    # Out-of-place division, so the caller's tensor (which datum.view shares) is not modified.
    datum = datum / datum.max()
    times = torch.zeros(size)
    times[datum != 0] = 1 / datum[datum != 0]
    times *= time / times.max()  # Extended through simulation time.
    times = torch.ceil(times).long()

    # Create spike times tensor.
    spikes = torch.zeros(time, size).byte()
    for i in range(size):
        if 0 < times[i] < time:
            spikes[times[i] - 1, i] = 1

    return spikes.reshape(time, *shape)

--------------------------------------------------------------------------------
/bindsnet/encoding/loaders.py:
--------------------------------------------------------------------------------
from typing import Optional, Union, Iterable, Iterator

import torch

from .encodings import bernoulli, poisson, rank_order


def bernoulli_loader(
    data: Union[torch.Tensor, Iterable[torch.Tensor]],
    time: Optional[int] = None,
    dt: float = 1.0,
    **kwargs
) -> Iterator[torch.Tensor]:
    # language=rst
    """
    Lazily invokes ``bindsnet.encoding.bernoulli`` to iteratively encode a sequence of data.

    :param data: Tensor of shape ``[n_samples, n_1, ..., n_k]``.
    :param time: Length of Bernoulli spike train per input variable.
    :param dt: Simulation time step.
    :return: Tensors of shape ``[time, n_1, ..., n_k]`` of Bernoulli-distributed spikes.

    Keyword arguments:

    :param float max_prob: Maximum probability of spike per Bernoulli trial.
    """
    # Setting kwargs.
    max_prob = kwargs.get("max_prob", 1.0)

    for i in range(len(data)):
        # Encode datum as Bernoulli spike trains.
        yield bernoulli(datum=data[i], time=time, dt=dt, max_prob=max_prob)


def poisson_loader(
    data: Union[torch.Tensor, Iterable[torch.Tensor]],
    time: int,
    dt: float = 1.0,
    **kwargs
) -> Iterator[torch.Tensor]:
    # language=rst
    """
    Lazily invokes ``bindsnet.encoding.poisson`` to iteratively encode a sequence of data.

    :param data: Tensor of shape ``[n_samples, n_1, ..., n_k]``.
    :param time: Length of Poisson spike train per input variable.
    :param dt: Simulation time step.
    :return: Tensors of shape ``[time, n_1, ..., n_k]`` of Poisson-distributed spikes.
    """
    for i in range(len(data)):
        # Encode datum as Poisson spike trains.
        yield poisson(datum=data[i], time=time, dt=dt)


def rank_order_loader(
    data: Union[torch.Tensor, Iterable[torch.Tensor]],
    time: int,
    dt: float = 1.0,
    **kwargs
) -> Iterator[torch.Tensor]:
    # language=rst
    """
    Lazily invokes ``bindsnet.encoding.rank_order`` to iteratively encode a sequence of data.

    :param data: Tensor of shape ``[n_samples, n_1, ..., n_k]``.
    :param time: Length of rank order-encoded spike train per input variable.
    :param dt: Simulation time step.
    :return: Tensors of shape ``[time, n_1, ..., n_k]`` of rank order-encoded spikes.
    """
    for i in range(len(data)):
        # Encode datum as rank order-encoded spike trains.
72 | yield rank_order(datum=data[i], time=time, dt=dt) 73 | -------------------------------------------------------------------------------- /bindsnet/environment/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment import Environment, GymEnvironment 2 | -------------------------------------------------------------------------------- /bindsnet/environment/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/environment/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/environment/__pycache__/environment.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/environment/__pycache__/environment.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/environment/environment.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple, Dict, Any 3 | 4 | import gym 5 | import numpy as np 6 | import torch 7 | 8 | from ..datasets.preprocess import subsample, gray_scale, binary_image, crop 9 | from ..encoding import Encoder, NullEncoder 10 | 11 | 12 | class Environment(ABC): 13 | # language=rst 14 | """ 15 | Abstract environment class. 16 | """ 17 | 18 | @abstractmethod 19 | def step(self, a: int) -> Tuple[Any, ...]: 20 | # language=rst 21 | """ 22 | Abstract method header for ``step()``. 23 | 24 | :param a: Integer action to take in environment.
25 | """ 26 | pass 27 | 28 | @abstractmethod 29 | def reset(self) -> None: 30 | # language=rst 31 | """ 32 | Abstract method header for ``reset()``. 33 | """ 34 | pass 35 | 36 | @abstractmethod 37 | def render(self) -> None: 38 | # language=rst 39 | """ 40 | Abstract method header for ``render()``. 41 | """ 42 | pass 43 | 44 | @abstractmethod 45 | def close(self) -> None: 46 | # language=rst 47 | """ 48 | Abstract method header for ``close()``. 49 | """ 50 | pass 51 | 52 | @abstractmethod 53 | def preprocess(self) -> None: 54 | # language=rst 55 | """ 56 | Abstract method header for ``preprocess()``. 57 | """ 58 | pass 59 | 60 | 61 | class GymEnvironment(Environment): 62 | # language=rst 63 | """ 64 | A wrapper around the OpenAI ``gym`` environments. 65 | """ 66 | 67 | def __init__(self, name: str, encoder: Encoder = NullEncoder(), **kwargs) -> None: 68 | # language=rst 69 | """ 70 | Initializes the environment wrapper. This class makes the 71 | assumption that the OpenAI ``gym`` environment will provide an image 72 | of format HxW or CxHxW as an observation (we will add the C 73 | dimension to HxW tensors) or a 1D observation in which case no 74 | dimensions will be added. 75 | 76 | :param name: The name of an OpenAI ``gym`` environment. 77 | :param encoder: Function to encode observations into spike trains. 78 | 79 | Keyword arguments: 80 | 81 | :param float max_prob: Maximum spiking probability. 82 | :param bool clip_rewards: Whether or not to use ``np.sign`` of rewards. 83 | 84 | :param int history_length: Number of observations to keep track of. 85 | :param int delta: Step size to save observations in history. 86 | :param bool add_channel_dim: Whether to add a channel dimension to 2D observations. 87 | """ 88 | self.name = name 89 | self.env = gym.make(name) 90 | self.action_space = self.env.action_space 91 | 92 | self.encoder = encoder 93 | 94 | # Keyword arguments.
95 | self.max_prob = kwargs.get("max_prob", 1.0) 96 | self.clip_rewards = kwargs.get("clip_rewards", True) 97 | 98 | self.history_length = kwargs.get("history_length", None) 99 | self.delta = kwargs.get("delta", 1) 100 | self.add_channel_dim = kwargs.get("add_channel_dim", True) 101 | 102 | if self.history_length is not None and self.delta is not None: 103 | self.history = { 104 | i: torch.Tensor() 105 | for i in range(1, self.history_length * self.delta + 1, self.delta) 106 | } 107 | else: 108 | self.history = {} 109 | 110 | self.episode_step_count = 0 111 | self.history_index = 1 112 | 113 | self.obs = None 114 | self.reward = None 115 | 116 | assert ( 117 | 0.0 < self.max_prob <= 1.0 118 | ), "Maximum spiking probability must be in (0, 1]." 119 | 120 | def step(self, a: int) -> Tuple[torch.Tensor, float, bool, Dict[Any, Any]]: 121 | # language=rst 122 | """ 123 | Wrapper around the OpenAI ``gym`` environment ``step()`` function. 124 | 125 | :param a: Action to take in the environment. 126 | :return: Observation, reward, done flag, and information dictionary. 127 | """ 128 | # Call gym's environment step function. 129 | self.obs, self.reward, self.done, info = self.env.step(a) 130 | 131 | if self.clip_rewards: 132 | self.reward = np.sign(self.reward) 133 | 134 | self.preprocess() 135 | 136 | # Add the raw observation from the gym environment into the info 137 | # for debugging and display. 138 | info["gym_obs"] = self.obs 139 | 140 | # Store frame of history and encode the inputs. 141 | if len(self.history) > 0: 142 | self.update_history() 143 | self.update_index() 144 | # Add the delta observation into the info for debugging and display. 145 | info["delta_obs"] = self.obs 146 | 147 | # The new standard for images is BxTxCxHxW. 148 | # The gym environment doesn't follow exactly the same protocol. 149 | # 150 | # 1D observations will be left as is before the encoder and will become BxTxL. 
151 | # 2D observations are assumed to be mono images and will become BxTx1xHxW. 152 | # 3D observations will become BxTxCxHxW. 153 | if self.obs.dim() == 2 and self.add_channel_dim: 154 | # We want CxHxW; it is currently HxW. 155 | self.obs = self.obs.unsqueeze(0) 156 | 157 | # The encoder will add time - now Tx... 158 | if self.encoder is not None: 159 | self.obs = self.encoder(self.obs) 160 | 161 | # Add the batch - now BxTx... 162 | self.obs = self.obs.unsqueeze(0) 163 | 164 | self.episode_step_count += 1 165 | 166 | # Return converted observations and other information. 167 | return self.obs, self.reward, self.done, info 168 | 169 | def reset(self) -> torch.Tensor: 170 | # language=rst 171 | """ 172 | Wrapper around the OpenAI ``gym`` environment ``reset()`` function. 173 | 174 | :return: Observation from the environment. 175 | """ 176 | # Call gym's environment reset function. 177 | self.obs = self.env.reset() 178 | self.preprocess() 179 | 180 | self.history = {i: torch.Tensor() for i in self.history} 181 | 182 | self.episode_step_count = 0 183 | 184 | return self.obs 185 | 186 | def render(self) -> None: 187 | # language=rst 188 | """ 189 | Wrapper around the OpenAI ``gym`` environment ``render()`` function. 190 | """ 191 | self.env.render() 192 | 193 | def close(self) -> None: 194 | # language=rst 195 | """ 196 | Wrapper around the OpenAI ``gym`` environment ``close()`` function. 197 | """ 198 | self.env.close() 199 | 200 | def preprocess(self) -> None: 201 | # language=rst 202 | """ 203 | Pre-processing step for an observation from a ``gym`` environment. 204 | """ 205 | if self.name == "SpaceInvaders-v0": 206 | self.obs = subsample(gray_scale(self.obs), 84, 110) 207 | self.obs = self.obs[26:104, :] 208 | self.obs = binary_image(self.obs) 209 | elif self.name == "BreakoutDeterministic-v4": 210 | self.obs = subsample(gray_scale(crop(self.obs, 34, 194, 0, 160)), 80, 80) 211 | self.obs = binary_image(self.obs) 212 | else: # Default pre-processing step.
213 | pass 214 | 215 | self.obs = torch.from_numpy(self.obs).float() 216 | 217 | def update_history(self) -> None: 218 | # language=rst 219 | """ 220 | Updates the history by subtracting the sum of the stored previous observations from the most recent 221 | observation and clamping the result to [0, 1]. If there are not yet enough stored observations to take a 222 | difference from, the observation is stored without any differencing. 223 | """ 224 | # Recording initial observations. 225 | if self.episode_step_count < len(self.history) * self.delta: 226 | # Store observation based on delta value. 227 | if self.episode_step_count % self.delta == 0: 228 | self.history[self.history_index] = self.obs 229 | else: 230 | # Take difference between stored frames and current frame. 231 | temp = torch.clamp(self.obs - sum(self.history.values()), 0, 1) 232 | 233 | # Store observation based on delta value. 234 | if self.episode_step_count % self.delta == 0: 235 | self.history[self.history_index] = self.obs 236 | 237 | assert ( 238 | len(self.history) == self.history_length 239 | ), "History size is out of bounds" 240 | self.obs = temp 241 | 242 | def update_index(self) -> None: 243 | # language=rst 244 | """ 245 | Updates the index to keep track of history. For example: ``history_length = 4``, ``delta = 3`` will produce 246 | history keys ``{1, 4, 7, 10}``, and ``self.history_index`` will be advanced by ``self.delta`` 247 | and will wrap around the history dictionary. 248 | """ 249 | if self.episode_step_count % self.delta == 0: 250 | if self.history_index != max(self.history.keys()): 251 | self.history_index += self.delta 252 | else: 253 | # Wrap around the history.
254 | self.history_index = (self.history_index % max(self.history.keys())) + 1 255 | -------------------------------------------------------------------------------- /bindsnet/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluation import ( 2 | assign_labels, 3 | logreg_fit, 4 | logreg_predict, 5 | all_activity, 6 | proportion_weighting, 7 | ngram, 8 | update_ngram_scores, 9 | ) 10 | -------------------------------------------------------------------------------- /bindsnet/evaluation/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/evaluation/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/evaluation/__pycache__/evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/evaluation/__pycache__/evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | from typing import Optional, Tuple, Dict 3 | 4 | import torch 5 | from sklearn.linear_model import LogisticRegression 6 | 7 | 8 | def assign_labels( 9 | spikes: torch.Tensor, 10 | labels: torch.Tensor, 11 | n_labels: int, 12 | rates: Optional[torch.Tensor] = None, 13 | alpha: float = 1.0, 14 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 15 | # language=rst 16 | """ 17 | Assign labels to the neurons based on highest average spiking activity. 
18 | 19 | :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a single layer's spiking activity. 20 | :param labels: Vector of shape ``(n_samples,)`` with data labels corresponding to spiking activity. 21 | :param n_labels: The number of target labels in the data. 22 | :param rates: If passed, these represent spike rates from a previous ``assign_labels()`` call. 23 | :param alpha: Decay factor applied to the previous ``rates`` before the new rates are added. 24 | :return: Tuple of class assignments, per-class spike proportions, and per-class firing rates. 25 | """ 26 | n_neurons = spikes.size(2) 27 | 28 | if rates is None: 29 | rates = torch.zeros(n_neurons, n_labels) 30 | 31 | # Sum over time dimension (spike ordering doesn't matter). 32 | spikes = spikes.sum(1) 33 | 34 | for i in range(n_labels): 35 | # Count the number of samples with this label. 36 | n_labeled = torch.sum(labels == i).float() 37 | 38 | if n_labeled > 0: 39 | # Get indices of samples with this label. 40 | indices = torch.nonzero(labels == i).view(-1) 41 | 42 | # Compute average firing rates for this label. 43 | rates[:, i] = alpha * rates[:, i] + ( 44 | torch.sum(spikes[indices], 0) / n_labeled 45 | ) 46 | 47 | # Compute proportions of spike activity per class. 48 | proportions = rates / rates.sum(1, keepdim=True) 49 | proportions[proportions != proportions] = 0 # Set NaNs to 0 50 | 51 | # Neuron assignments are the labels they fire most for. 52 | assignments = torch.max(proportions, 1)[1] 53 | 54 | return assignments, proportions, rates 55 | 56 | 57 | def logreg_fit( 58 | spikes: torch.Tensor, labels: torch.Tensor, logreg: LogisticRegression 59 | ) -> LogisticRegression: 60 | # language=rst 61 | """ 62 | (Re)fit logistic regression model to spike data summed over time. 63 | 64 | :param spikes: Spikes summed over time, of shape ``(n_examples, n_neurons)``. 65 | :param labels: Vector of shape ``(n_samples,)`` with data labels corresponding to spiking activity.
66 | :param logreg: Logistic regression model from previous fits. 67 | :return: (Re)fitted logistic regression model. 68 | """ 69 | # (Re)fit logistic regression model. 70 | logreg.fit(spikes, labels) 71 | return logreg 72 | 73 | 74 | def logreg_predict(spikes: torch.Tensor, logreg: LogisticRegression) -> torch.Tensor: 75 | # language=rst 76 | """ 77 | Predicts classes according to spike data summed over time. 78 | 79 | :param spikes: Spikes summed over time, of shape ``(n_examples, n_neurons)``. 80 | :param logreg: Logistic regression model from previous fits. 81 | :return: Predictions per example. 82 | """ 83 | # Make class label predictions. 84 | if not hasattr(logreg, "coef_") or logreg.coef_ is None: 85 | return -1 * torch.ones(spikes.size(0)).long() 86 | 87 | predictions = logreg.predict(spikes) 88 | return torch.Tensor(predictions).long() 89 | 90 | 91 | def all_activity( 92 | spikes: torch.Tensor, assignments: torch.Tensor, n_labels: int 93 | ) -> torch.Tensor: 94 | # language=rst 95 | """ 96 | Classifies data as the label with the highest average spiking activity over all neurons. 97 | 98 | :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a layer's spiking activity. 99 | :param assignments: A vector of shape ``(n_neurons,)`` of neuron label assignments. 100 | :param n_labels: The number of target labels in the data. 101 | :return: Predictions tensor of shape ``(n_samples,)`` resulting from the "all activity" classification scheme. 102 | """ 103 | n_samples = spikes.size(0) 104 | 105 | # Sum over time dimension (spike ordering doesn't matter). 106 | spikes = spikes.sum(1) 107 | 108 | rates = torch.zeros(n_samples, n_labels) 109 | for i in range(n_labels): 110 | # Count the number of neurons with this label assignment. 111 | n_assigns = torch.sum(assignments == i).float() 112 | 113 | if n_assigns > 0: 114 | # Get indices of neurons with this label assignment.
115 | indices = torch.nonzero(assignments == i).view(-1) 116 | 117 | # Compute layer-wise firing rate for this label. 118 | rates[:, i] = torch.sum(spikes[:, indices], 1) / n_assigns 119 | 120 | # Predictions are arg-max of layer-wise firing rates. 121 | return torch.sort(rates, dim=1, descending=True)[1][:, 0] 122 | 123 | 124 | def proportion_weighting( 125 | spikes: torch.Tensor, 126 | assignments: torch.Tensor, 127 | proportions: torch.Tensor, 128 | n_labels: int, 129 | ) -> torch.Tensor: 130 | # language=rst 131 | """ 132 | Classifies data as the label with the highest average spiking activity over all neurons, weighted by class-wise 133 | proportion. 134 | 135 | :param spikes: Binary tensor of shape ``(n_samples, time, n_neurons)`` of a single layer's spiking activity. 136 | :param assignments: A vector of shape ``(n_neurons,)`` of neuron label assignments. 137 | :param proportions: A matrix of shape ``(n_neurons, n_labels)`` giving the per-class proportions of neuron spiking 138 | activity. 139 | :param n_labels: The number of target labels in the data. 140 | :return: Predictions tensor of shape ``(n_samples,)`` resulting from the "proportion weighting" classification 141 | scheme. 142 | """ 143 | n_samples = spikes.size(0) 144 | 145 | # Sum over time dimension (spike ordering doesn't matter). 146 | spikes = spikes.sum(1) 147 | 148 | rates = torch.zeros(n_samples, n_labels) 149 | for i in range(n_labels): 150 | # Count the number of neurons with this label assignment. 151 | n_assigns = torch.sum(assignments == i).float() 152 | 153 | if n_assigns > 0: 154 | # Get indices of neurons with this label assignment. 155 | indices = torch.nonzero(assignments == i).view(-1) 156 | 157 | # Compute layer-wise firing rate for this label. 158 | rates[:, i] += ( 159 | torch.sum((proportions[:, i] * spikes)[:, indices], 1) / n_assigns 160 | ) 161 | 162 | # Predictions are arg-max of layer-wise firing rates.
163 | predictions = torch.sort(rates, dim=1, descending=True)[1][:, 0] 164 | 165 | return predictions 166 | 167 | 168 | def ngram( 169 | spikes: torch.Tensor, 170 | ngram_scores: Dict[Tuple[int, ...], torch.Tensor], 171 | n_labels: int, 172 | n: int, 173 | ) -> torch.Tensor: 174 | # language=rst 175 | """ 176 | Predicts one of ``n_labels`` classes per example using ``ngram_scores``. 177 | 178 | :param spikes: Spikes of shape ``(n_examples, time, n_neurons)``. 179 | :param ngram_scores: Previously recorded n-gram scores (e.g., from ``update_ngram_scores()``). 180 | :param n_labels: The number of target labels in the data. 181 | :param n: The size of n-gram to use. 182 | :return: Predictions per example. 183 | """ 184 | predictions = [] 185 | for activity in spikes: 186 | score = torch.zeros(n_labels) 187 | 188 | # Aggregate all of the firing neurons' indices. 189 | fire_order = [] 190 | for t in range(activity.size()[0]): 191 | ordering = torch.nonzero(activity[t].view(-1)) 192 | if ordering.numel() > 0: 193 | fire_order += ordering[:, 0].tolist() 194 | 195 | # Consider all n-gram sequences, including the last one. 196 | for j in range(len(fire_order) - n + 1): 197 | if tuple(fire_order[j : j + n]) in ngram_scores: 198 | score += ngram_scores[tuple(fire_order[j : j + n])] 199 | 200 | predictions.append(torch.argmax(score)) 201 | 202 | return torch.Tensor(predictions).long() 203 | 204 | 205 | def update_ngram_scores( 206 | spikes: torch.Tensor, 207 | labels: torch.Tensor, 208 | n_labels: int, 209 | n: int, 210 | ngram_scores: Dict[Tuple[int, ...], torch.Tensor], 211 | ) -> Dict[Tuple[int, ...], torch.Tensor]: 212 | # language=rst 213 | """ 214 | Updates ngram scores by adding the count of each spike sequence of length n from the past ``n_examples``. 215 | 216 | :param spikes: Spikes of shape ``(n_examples, time, n_neurons)``. 217 | :param labels: The ground truth labels of shape ``(n_examples,)``. 218 | :param n_labels: The number of target labels in the data. 219 | :param n: The size of n-gram to use.
220 | :param ngram_scores: Previously recorded scores to update. 221 | :return: Dictionary mapping n-grams to vectors of per-class spike counts. 222 | """ 223 | for i, activity in enumerate(spikes): 224 | # Obtain firing order for spiking activity. 225 | fire_order = [] 226 | 227 | # Aggregate all of the firing neurons' indices. 228 | for t in range(spikes.size(1)): 229 | # Gets the indices of the neurons which fired on this timestep. 230 | ordering = torch.nonzero(activity[t]).view(-1) 231 | if ordering.numel() > 0: # If at least one neuron spiked... 232 | # Add the indices of spiked neurons to the fire ordering. 233 | ordering = ordering.tolist() 234 | fire_order.append(ordering) 235 | 236 | # Check every sequence of length n. 237 | for order in zip(*(fire_order[k:] for k in range(n))): 238 | for sequence in product(*order): 239 | if sequence not in ngram_scores: 240 | ngram_scores[sequence] = torch.zeros(n_labels) 241 | 242 | ngram_scores[sequence][int(labels[i])] += 1 243 | 244 | return ngram_scores 245 | -------------------------------------------------------------------------------- /bindsnet/learning/__init__.py: -------------------------------------------------------------------------------- 1 | from .learning import ( 2 | LearningRule, 3 | NoOp, 4 | PostPre, 5 | WeightDependentPostPre, 6 | Hebbian, 7 | MSTDP, 8 | MSTDPET, 9 | Rmax, 10 | ) 11 | -------------------------------------------------------------------------------- /bindsnet/learning/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/learning/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/learning/__pycache__/learning.cpython-37.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/learning/__pycache__/learning.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/learning/__pycache__/reward.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/learning/__pycache__/reward.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/learning/reward.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | 6 | class AbstractReward(ABC): 7 | # language=rst 8 | """ 9 | Abstract base class for reward computation. 10 | """ 11 | 12 | @abstractmethod 13 | def compute(self, **kwargs) -> None: 14 | # language=rst 15 | """ 16 | Computes/modifies reward. 17 | """ 18 | pass 19 | 20 | @abstractmethod 21 | def update(self, **kwargs) -> None: 22 | # language=rst 23 | """ 24 | Updates internal variables needed to modify reward. Usually called once per episode. 25 | """ 26 | pass 27 | 28 | 29 | class MovingAvgRPE(AbstractReward): 30 | # language=rst 31 | """ 32 | Computes reward prediction error (RPE) based on an exponential moving average (EMA) of past rewards. 33 | """ 34 | 35 | def __init__(self, **kwargs) -> None: 36 | # language=rst 37 | """ 38 | Constructor for EMA reward prediction error. 39 | """ 40 | self.reward_predict = torch.tensor(0.0) # Predicted reward (per step). 41 | self.reward_predict_episode = torch.tensor(0.0) # Predicted reward per episode. 42 | self.rewards_predict_episode = ( 43 | [] 44 | ) # List of predicted rewards per episode (used for plotting). 
45 | 46 | def compute(self, **kwargs) -> torch.Tensor: 47 | # language=rst 48 | """ 49 | Computes the reward prediction error using EMA. 50 | 51 | Keyword arguments: 52 | 53 | :param Union[float, torch.Tensor] reward: Current reward. 54 | :return: Reward prediction error. 55 | """ 56 | # Get keyword arguments. 57 | reward = kwargs["reward"] 58 | 59 | return reward - self.reward_predict 60 | 61 | def update(self, **kwargs) -> None: 62 | # language=rst 63 | """ 64 | Updates the EMAs. Called once per episode. 65 | 66 | Keyword arguments: 67 | 68 | :param Union[float, torch.Tensor] accumulated_reward: Reward accumulated over one episode. 69 | :param int steps: Steps in that episode. 70 | :param float ema_window: Width of the averaging window. 71 | """ 72 | # Get keyword arguments. 73 | accumulated_reward = kwargs["accumulated_reward"] 74 | steps = torch.tensor(kwargs["steps"]).float() 75 | ema_window = torch.tensor(kwargs.get("ema_window", 10.0)) 76 | 77 | # Compute average reward per step. 78 | reward = accumulated_reward / steps 79 | 80 | # Update EMAs. 
81 | self.reward_predict = ( 82 | 1 - 1 / ema_window 83 | ) * self.reward_predict + 1 / ema_window * reward 84 | self.reward_predict_episode = ( 85 | 1 - 1 / ema_window 86 | ) * self.reward_predict_episode + 1 / ema_window * accumulated_reward 87 | self.rewards_predict_episode.append(self.reward_predict_episode.item()) 88 | -------------------------------------------------------------------------------- /bindsnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ( 2 | TwoLayerNetwork, 3 | DiehlAndCook2015, 4 | DiehlAndCook2015v2, 5 | IncreasingInhibitionNetwork, 6 | LocallyConnectedNetwork, 7 | ) 8 | -------------------------------------------------------------------------------- /bindsnet/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/models/__pycache__/models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/models/__pycache__/models.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .network import Network, load 2 | from . 
import nodes, topology, monitors 3 | -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/monitors.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/monitors.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/network.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/network.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/nodes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/nodes.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/topology.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/network/__pycache__/topology.cpython-37.pyc -------------------------------------------------------------------------------- 
/bindsnet/network/monitors.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | from abc import ABC, abstractmethod 6 | from typing import Union, Optional, Iterable, Dict 7 | 8 | from .nodes import Nodes 9 | from .topology import AbstractConnection 10 | 11 | 12 | class AbstractMonitor(ABC): 13 | # language=rst 14 | """ 15 | Abstract base class for state variable monitors. 16 | """ 17 | 18 | 19 | class Monitor(AbstractMonitor): 20 | # language=rst 21 | """ 22 | Records state variables of interest. 23 | """ 24 | 25 | def __init__( 26 | self, 27 | obj: Union[Nodes, AbstractConnection], 28 | state_vars: Iterable[str], 29 | time: Optional[int] = None, 30 | batch_size: int = 1, 31 | ): 32 | # language=rst 33 | """ 34 | Constructs a ``Monitor`` object. 35 | 36 | :param obj: An object to record state variables from during network simulation. 37 | :param state_vars: Iterable of strings indicating names of state variables to record. 38 | :param time: If not ``None``, only the most recent ``time`` recorded steps are kept. 39 | """ 40 | super().__init__() 41 | 42 | self.obj = obj 43 | self.state_vars = state_vars 44 | self.time = time 45 | self.batch_size = batch_size 46 | 47 | # Recordings are lists of tensors; ``record()`` trims them to the last ``time`` entries when ``time`` is set. 48 | self.recording = {v: [] for v in self.state_vars} 49 | 50 | def get(self, var: str) -> torch.Tensor: 51 | # language=rst 52 | """ 53 | Return recording to user. 54 | 55 | :param var: State variable recording to return. 56 | :return: Tensor of shape ``[time, n_1, ..., n_k]``, where ``[n_1, ..., n_k]`` is the shape of the recorded 57 | state variable. 58 | """ 59 | return torch.cat(self.recording[var], 0) 60 | 61 | def record(self) -> None: 62 | # language=rst 63 | """ 64 | Appends the current value of the recorded state variables to the recording.
65 | """ 66 | for v in self.state_vars: 67 | data = getattr(self.obj, v).unsqueeze(0) 68 | self.recording[v].append(data.detach().clone()) 69 | 70 | # Remove the oldest element (first in the list) once more than ``time`` steps are stored. 71 | if self.time is not None: 72 | for v in self.state_vars: 73 | if len(self.recording[v]) > self.time: 74 | self.recording[v].pop(0) 75 | 76 | def reset_(self) -> None: 77 | # language=rst 78 | """ 79 | Resets recordings to empty lists. 80 | """ 81 | self.recording = {v: [] for v in self.state_vars} 82 | 83 | 84 | class NetworkMonitor(AbstractMonitor): 85 | # language=rst 86 | """ 87 | Records state variables of all layers and connections. 88 | """ 89 | 90 | def __init__( 91 | self, 92 | network: "Network", 93 | layers: Optional[Iterable[str]] = None, 94 | connections: Optional[Iterable[str]] = None, 95 | state_vars: Optional[Iterable[str]] = None, 96 | time: Optional[int] = None, 97 | ): 98 | # language=rst 99 | """ 100 | Constructs a ``NetworkMonitor`` object. 101 | 102 | :param network: Network to record state variables from. 103 | :param layers: Layers to record state variables from. 104 | :param connections: Connections to record state variables from. 105 | :param state_vars: List of strings indicating names of state variables to record. 106 | :param time: If not ``None``, pre-allocate memory for state variable recording. 107 | """ 108 | super().__init__() 109 | 110 | self.network = network 111 | self.layers = layers if layers is not None else list(self.network.layers.keys()) 112 | self.connections = ( 113 | connections 114 | if connections is not None 115 | else list(self.network.connections.keys()) 116 | ) 117 | self.state_vars = state_vars if state_vars is not None else ("v", "s", "w") 118 | self.time = time 119 | 120 | if self.time is not None: 121 | self.i = 0 122 | 123 | # Initialize empty recording.
124 | self.recording = {k: {} for k in self.layers + self.connections} 125 | 126 | # If no simulation time is specified, specify 0-dimensional recordings. 127 | if self.time is None: 128 | for v in self.state_vars: 129 | for l in self.layers: 130 | if hasattr(self.network.layers[l], v): 131 | self.recording[l][v] = torch.Tensor() 132 | 133 | for c in self.connections: 134 | if hasattr(self.network.connections[c], v): 135 | self.recording[c][v] = torch.Tensor() 136 | 137 | # If simulation time is specified, pre-allocate recordings in memory for speed. 138 | else: 139 | for v in self.state_vars: 140 | for l in self.layers: 141 | if hasattr(self.network.layers[l], v): 142 | self.recording[l][v] = torch.zeros( 143 | self.time, *getattr(self.network.layers[l], v).size() 144 | ) 145 | 146 | for c in self.connections: 147 | if hasattr(self.network.connections[c], v): 148 | self.recording[c][v] = torch.zeros( 149 | self.time, *getattr(self.network.connections[c], v).size() 150 | ) 151 | 152 | def get(self) -> Dict[str, Dict[str, torch.Tensor]]: 153 | # language=rst 154 | """ 155 | Return entire recording to user. 156 | 157 | :return: Dictionary of dictionaries of all layers' and connections' recorded state variables. 158 | """ 159 | return self.recording 160 | 161 | def record(self) -> None: 162 | # language=rst 163 | """ 164 | Appends the current value of the recorded state variables to the recording.
165 | """ 166 | if self.time is None: 167 | for v in self.state_vars: 168 | for l in self.layers: 169 | if hasattr(self.network.layers[l], v): 170 | data = getattr(self.network.layers[l], v).unsqueeze(0).float() 171 | self.recording[l][v] = torch.cat( 172 | (self.recording[l][v], data), 0 173 | ) 174 | 175 | for c in self.connections: 176 | if hasattr(self.network.connections[c], v): 177 | data = getattr(self.network.connections[c], v).unsqueeze(0) 178 | self.recording[c][v] = torch.cat( 179 | (self.recording[c][v], data), 0 180 | ) 181 | 182 | else: 183 | for v in self.state_vars: 184 | for l in self.layers: 185 | if hasattr(self.network.layers[l], v): 186 | data = getattr(self.network.layers[l], v).float().unsqueeze(0) 187 | self.recording[l][v] = torch.cat( 188 | (self.recording[l][v][1:].type(data.type()), data), 0 189 | ) 190 | 191 | for c in self.connections: 192 | if hasattr(self.network.connections[c], v): 193 | data = getattr(self.network.connections[c], v).unsqueeze(0) 194 | self.recording[c][v] = torch.cat( 195 | (self.recording[c][v][1:].type(data.type()), data), 0 196 | ) 197 | 198 | self.i += 1 199 | 200 | def save(self, path: str, fmt: str = "npz") -> None: 201 | # language=rst 202 | """ 203 | Write the recording dictionary out to file. 204 | 205 | :param path: Path of the file to which to write the monitor's recording. 206 | :param fmt: Type of file to write to disk. One of ``"pickle"`` or ``"npz"``. 207 | """ 208 | if os.path.dirname(path) and not os.path.exists(os.path.dirname(path)): 209 | os.makedirs(os.path.dirname(path)) 210 | 211 | if fmt == "npz": 212 | # Build a dictionary of arrays to write to disk.
213 | arrays = {} 214 | for o in self.recording: 215 | if type(o) == tuple: 216 | arrays.update( 217 | { 218 | "_".join(["-".join(o), v]): self.recording[o][v] 219 | for v in self.recording[o] 220 | } 221 | ) 222 | elif type(o) == str: 223 | arrays.update( 224 | { 225 | "_".join([o, v]): self.recording[o][v] 226 | for v in self.recording[o] 227 | } 228 | ) 229 | 230 | np.savez_compressed(path, **arrays) 231 | 232 | elif fmt == "pickle": 233 | with open(path, "wb") as f: 234 | torch.save(self.recording, f) 235 | 236 | def reset_(self) -> None: 237 | # language=rst 238 | """ 239 | Resets recordings to empty ``torch.Tensors``. 240 | """ 241 | # Reset to empty recordings 242 | self.recording = {k: {} for k in self.layers + self.connections} 243 | 244 | if self.time is not None: 245 | self.i = 0 246 | 247 | # If no simulation time is specified, specify 0-dimensional recordings. 248 | if self.time is None: 249 | for v in self.state_vars: 250 | for l in self.layers: 251 | if hasattr(self.network.layers[l], v): 252 | self.recording[l][v] = torch.Tensor() 253 | 254 | for c in self.connections: 255 | if hasattr(self.network.connections[c], v): 256 | self.recording[c][v] = torch.Tensor() 257 | 258 | # If simulation time is specified, pre-allocate recordings in memory for speed. 
259 | else: 260 | for v in self.state_vars: 261 | for l in self.layers: 262 | if hasattr(self.network.layers[l], v): 263 | self.recording[l][v] = torch.zeros( 264 | self.time, *getattr(self.network.layers[l], v).size() 265 | ) 266 | 267 | for c in self.connections: 268 | if hasattr(self.network.connections[c], v): 269 | self.recording[c][v] = torch.zeros( 270 | self.time, *getattr(self.network.connections[c], v).size() 271 | ) 272 | -------------------------------------------------------------------------------- /bindsnet/network/network.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from typing import Dict, Optional, Type, Iterable 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from .monitors import AbstractMonitor 8 | from .nodes import AbstractInput, Nodes 9 | from .topology import AbstractConnection 10 | from ..learning.reward import AbstractReward 11 | 12 | def load(file_name: str, map_location: str = "cpu", learning: Optional[bool] = None) -> "Network": 13 | # language=rst 14 | """ 15 | Loads serialized network object from disk. 16 | 17 | :param file_name: Path to serialized network object on disk. 18 | :param map_location: One of ``"cpu"`` or ``"cuda"``. Defaults to ``"cpu"``. 19 | :param learning: Whether to load with learning enabled. Default loads value from disk. 20 | """ 21 | network = torch.load(file_name, map_location=map_location) 22 | if learning is not None and "learning" in vars(network): 23 | network.learning = learning 24 | 25 | return network 26 | 27 | 28 | class Network(torch.nn.Module): 29 | # language=rst 30 | """ 31 | Most important object of the ``bindsnet`` package. Responsible for the simulation and interaction of nodes and 32 | connections. 33 | 34 | **Example:** 35 | 36 | .. 
code-block:: python 37 | 38 | import torch 39 | import matplotlib.pyplot as plt 40 | 41 | from bindsnet import encoding 42 | from bindsnet.network import Network, nodes, topology, monitors 43 | 44 | network = Network(dt=1.0) # Instantiates network. 45 | 46 | X = nodes.Input(100) # Input layer. 47 | Y = nodes.LIFNodes(100) # Layer of LIF neurons. 48 | C = topology.Connection(source=X, target=Y, w=torch.rand(X.n, Y.n)) # Connection from X to Y. 49 | 50 | # Spike monitor objects. 51 | M1 = monitors.Monitor(obj=X, state_vars=['s']) 52 | M2 = monitors.Monitor(obj=Y, state_vars=['s']) 53 | 54 | # Add everything to the network object. 55 | network.add_layer(layer=X, name='X') 56 | network.add_layer(layer=Y, name='Y') 57 | network.add_connection(connection=C, source='X', target='Y') 58 | network.add_monitor(monitor=M1, name='X') 59 | network.add_monitor(monitor=M2, name='Y') 60 | 61 | # Create Poisson-distributed spike train inputs. 62 | data = 15 * torch.rand(100) # Generate random Poisson rates for 100 input neurons. 63 | train = encoding.poisson(datum=data, time=5000) # Encode input as 5000ms Poisson spike trains. 64 | 65 | # Simulate network on generated spike trains. 66 | inpts = {'X' : train} # Create inputs mapping. 67 | network.run(inpts=inpts, time=5000) # Run network simulation. 68 | 69 | # Plot spikes of input and output layers. 
70 | spikes = {'X' : M1.get('s'), 'Y' : M2.get('s')} 71 | 72 | fig, axes = plt.subplots(2, 1, figsize=(12, 7)) 73 | for i, layer in enumerate(spikes): 74 | axes[i].matshow(spikes[layer], cmap='binary') 75 | axes[i].set_title('%s spikes' % layer) 76 | axes[i].set_xlabel('Time'); axes[i].set_ylabel('Index of neuron') 77 | axes[i].set_xticks(()); axes[i].set_yticks(()) 78 | axes[i].set_aspect('auto') 79 | 80 | plt.tight_layout(); plt.show() 81 | """ 82 | 83 | def __init__( 84 | self, 85 | dt: float = 1.0, 86 | batch_size: int = 1, 87 | learning: bool = True, 88 | reward_fn: Optional[Type[AbstractReward]] = None, 89 | ) -> None: 90 | # language=rst 91 | """ 92 | Initializes network object. 93 | 94 | :param dt: Simulation timestep. 95 | :param learning: Whether to allow connection updates. True by default. 96 | :param reward_fn: Optional class allowing for modification of reward in case of reward-modulated learning. 97 | """ 98 | super().__init__() 99 | 100 | self.dt = dt 101 | self.batch_size = batch_size 102 | 103 | self.layers = {} 104 | self.connections = {} 105 | self.monitors = {} 106 | self.train(learning) 107 | 108 | if reward_fn is not None: 109 | self.reward_fn = reward_fn() 110 | else: 111 | self.reward_fn = None 112 | 113 | def add_layer(self, layer: Nodes, name: str) -> None: 114 | # language=rst 115 | """ 116 | Adds a layer of nodes to the network. 117 | 118 | :param layer: A subclass of the ``Nodes`` object. 119 | :param name: Logical name of layer. 120 | """ 121 | self.layers[name] = layer 122 | self.add_module(name, layer) 123 | 124 | layer.train(self.learning) 125 | layer.compute_decays(self.dt) 126 | layer.set_batch_size(self.batch_size) 127 | 128 | def add_connection( 129 | self, connection: AbstractConnection, source: str, target: str 130 | ) -> None: 131 | # language=rst 132 | """ 133 | Adds a connection between layers of nodes to the network. 134 | 135 | :param connection: An instance of class ``Connection``. 
136 | :param source: Logical name of the connection's source layer. 137 | :param target: Logical name of the connection's target layer. 138 | """ 139 | self.connections[(source, target)] = connection 140 | self.add_module(source + "_to_" + target, connection) 141 | 142 | connection.dt = self.dt 143 | connection.train(self.learning) 144 | 145 | def add_monitor(self, monitor: AbstractMonitor, name: str) -> None: 146 | # language=rst 147 | """ 148 | Adds a monitor on a network object to the network. 149 | 150 | :param monitor: An instance of class ``Monitor``. 151 | :param name: Logical name of monitor object. 152 | """ 153 | self.monitors[name] = monitor 154 | monitor.network = self 155 | monitor.dt = self.dt 156 | 157 | def save(self, file_name: str) -> None: 158 | # language=rst 159 | """ 160 | Serializes the network object to disk. 161 | 162 | :param file_name: Path to store serialized network object on disk. 163 | 164 | **Example:** 165 | 166 | .. code-block:: python 167 | 168 | import torch 169 | import matplotlib.pyplot as plt 170 | 171 | from pathlib import Path 172 | from bindsnet.network import * 173 | from bindsnet.network import topology 174 | 175 | # Build simple network. 176 | network = Network(dt=1.0) 177 | 178 | X = nodes.Input(100) # Input layer. 179 | Y = nodes.LIFNodes(100) # Layer of LIF neurons. 180 | C = topology.Connection(source=X, target=Y, w=torch.rand(X.n, Y.n)) # Connection from X to Y. 181 | 182 | # Add everything to the network object. 183 | network.add_layer(layer=X, name='X') 184 | network.add_layer(layer=Y, name='Y') 185 | network.add_connection(connection=C, source='X', target='Y') 186 | 187 | # Save the network to disk. 188 | network.save(str(Path.home()) + '/network.pt') 189 | """ 190 | torch.save(self, open(file_name, "wb")) 191 | 192 | def clone(self) -> "Network": 193 | # language=rst 194 | """ 195 | Returns a cloned network object. 196 | 197 | :return: A copy of this network. 
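``clone`` deep-copies the network by serializing it into a ``tempfile.SpooledTemporaryFile``, rewinding, and deserializing. The sketch below shows the same round trip with the standard library's ``pickle`` standing in for ``torch.save``/``torch.load``, so it is an analogue rather than this method's own code. ``clone_via_spooled_file`` is a hypothetical name. Using the file as a context manager also closes it once the copy has been loaded:

```python
import pickle
import tempfile
from typing import Any


def clone_via_spooled_file(obj: Any) -> Any:
    # SpooledTemporaryFile buffers its data in memory rather than on disk by default.
    with tempfile.SpooledTemporaryFile() as virtual_file:
        pickle.dump(obj, virtual_file)
        # Rewind before reading back, as clone() does with virtual_file.seek(0).
        virtual_file.seek(0)
        return pickle.load(virtual_file)


original = {"layers": ["X", "Y"], "dt": 1.0}
cloned = clone_via_spooled_file(original)
cloned["layers"].append("Z")  # Mutating the copy leaves the original untouched.
print(original["layers"])  # ['X', 'Y']
```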
198 | """ 199 | virtual_file = tempfile.SpooledTemporaryFile() 200 | torch.save(self, virtual_file) 201 | virtual_file.seek(0) 202 | return torch.load(virtual_file) 203 | 204 | def _get_inputs(self, layers: Optional[Iterable] = None) -> Dict[str, torch.Tensor]: 205 | # language=rst 206 | """ 207 | Fetches outputs from network layers to use as input to downstream layers. 208 | 209 | :param layers: Layers to update inputs for. Defaults to all network layers. 210 | :return: Inputs to all layers for the current iteration. 211 | """ 212 | inpts = {} 213 | 214 | if layers is None: 215 | layers = self.layers 216 | 217 | # Loop over network connections. 218 | for c in self.connections: 219 | if c[1] in layers: 220 | # Fetch source and target populations. 221 | source = self.connections[c].source 222 | target = self.connections[c].target 223 | 224 | if c[1] not in inpts: 225 | inpts[c[1]] = torch.zeros( 226 | self.batch_size, *target.shape, device=target.s.device 227 | ) 228 | 229 | # Add to input: source's spikes multiplied by connection weights. 230 | inpts[c[1]] += self.connections[c].compute(source.s) 231 | 232 | return inpts 233 | 234 | def run(self, inpts: Dict[str, torch.Tensor], time: int, acc: np.ndarray, 235 | step: int, labels: torch.Tensor, one_step: bool = False, **kwargs 236 | ) -> None: 237 | # language=rst 238 | """ 239 | Simulate network for given inputs and time. 240 | 241 | :param inpts: Dictionary of ``Tensor``s of shape ``[time, *input_shape]`` or ``[time, batch_size, *input_shape]``. 242 | :param time: Simulation time. 243 | :param acc: Array with at least ``int(time / dt)`` rows; ``acc[t][step]`` is set to the number of correct predictions at timestep ``t``. 244 | :param step: Second index into ``acc`` at which to store this run's results. 245 | :param labels: Ground-truth labels of shape ``[batch_size]``, compared against the argmax of the last layer's ``summed`` attribute. 
246 | :param one_step: Whether to run the network in "feed-forward" mode, where inputs propagate all the way through the network in a single simulation time step. Layers are updated in the order they are added to the network. 247 | 248 | Keyword arguments: 249 | 250 | :param Dict[str, torch.Tensor] clamp: Mapping of layer names to boolean masks of 251 | neurons that should be clamped to spiking.
The ``Tensor``s have shape 252 | ``[n_neurons]`` or ``[time, n_neurons]``. 253 | :param Dict[str, torch.Tensor] unclamp: Mapping of layer names to boolean masks 254 | of neurons that should be clamped to not spiking. The ``Tensor``s should have 255 | shape ``[n_neurons]`` or ``[time, n_neurons]``. 256 | :param Dict[str, torch.Tensor] injects_v: Mapping of layer names to voltages 257 | added to the neurons' membrane potentials. The ``Tensor``s should have shape 258 | ``[n_neurons]`` or ``[time, n_neurons]``. 259 | :param Union[float, torch.Tensor] reward: Scalar value used in reward-modulated 260 | learning. 261 | :param Dict[Tuple[str], torch.Tensor] masks: Mapping of connection names to 262 | boolean masks determining which weights to clamp to zero. 263 | 264 | **Example:** 265 | 266 | .. code-block:: python 267 | 268 | import torch 269 | import matplotlib.pyplot as plt 270 | 271 | from bindsnet.network import Network 272 | from bindsnet.network.nodes import Input 273 | from bindsnet.network.monitors import Monitor 274 | 275 | # Build simple network. 276 | network = Network() 277 | network.add_layer(Input(500), name='I') 278 | network.add_monitor(Monitor(network.layers['I'], state_vars=['s']), 'I') 279 | 280 | # Generate spikes by running Bernoulli trials on Uniform(0, 0.5) samples. 281 | spikes = torch.bernoulli(0.5 * torch.rand(500, 500)) 282 | 283 | # Run network simulation. 284 | network.run(inpts={'I' : spikes}, time=500) 285 | 286 | # Look at input spiking activity. 287 | spikes = network.monitors['I'].get('s') 288 | plt.matshow(spikes, cmap='binary') 289 | plt.xticks(()); plt.yticks(()); 290 | plt.xlabel('Time'); plt.ylabel('Neuron index') 291 | plt.title('Input spiking') 292 | plt.show() 293 | """ 294 | # Parse keyword arguments. 295 | clamps = kwargs.get("clamp", {}) 296 | unclamps = kwargs.get("unclamp", {}) 297 | masks = kwargs.get("masks", {}) 298 | injects_v = kwargs.get("injects_v", {}) 299 | 300 | # Compute reward.
301 | if self.reward_fn is not None: 302 | kwargs["reward"] = self.reward_fn.compute(**kwargs) 303 | 304 | # Dynamic setting of batch size. 305 | if inpts != {}: 306 | for key in inpts: 307 | # goal shape is [time, batch, n_0, ...] 308 | if len(inpts[key].size()) == 1: 309 | # current shape is [n_0] 310 | # unsqueeze twice to make [1, 1, n_0] 311 | inpts[key] = inpts[key].unsqueeze(0).unsqueeze(0) 312 | elif len(inpts[key].size()) == 2: 313 | # current shape is [time, n_0] 314 | # unsqueeze dim 1 so that we have 315 | # [time, 1, n_0] 316 | inpts[key] = inpts[key].unsqueeze(1) 317 | 318 | for key in inpts: 319 | # dimension 1 is the batch dimension; use its size as the batch size 320 | if inpts[key].size(1) != self.batch_size: 321 | self.batch_size = inpts[key].size(1) 322 | 323 | for l in self.layers: 324 | self.layers[l].set_batch_size(self.batch_size) 325 | 326 | for m in self.monitors: 327 | self.monitors[m].reset_() 328 | 329 | break 330 | 331 | # Effective number of timesteps. 332 | timesteps = int(time / self.dt) 333 | 334 | # Get input to all layers (synchronous mode). 335 | if not one_step: 336 | inpts.update(self._get_inputs()) 337 | 338 | # Simulate network activity for `time` timesteps. 339 | 340 | for t in range(timesteps): 341 | for l in self.layers: 342 | if isinstance(self.layers[l], AbstractInput): 343 | # shape is [time, batch, n_0, ...] 344 | self.layers[l].forward(x=inpts[l][t]) 345 | else: 346 | if one_step: 347 | # Get input to this layer (one-step mode). 348 | inpts.update(self._get_inputs(layers=[l])) 349 | self.layers[l].forward(x=inpts[l]) 350 | 351 | # Clamp neurons to spike. 352 | clamp = clamps.get(l, None) 353 | if clamp is not None: 354 | if clamp.ndimension() == 1: 355 | self.layers[l].s[:, clamp] = 1 356 | else: 357 | self.layers[l].s[:, clamp[t]] = 1 358 | 359 | # Clamp neurons not to spike.
360 | unclamp = unclamps.get(l, None) 361 | if unclamp is not None: 362 | if unclamp.ndimension() == 1: 363 | self.layers[l].s[:, unclamp] = 0 364 | else: 365 | self.layers[l].s[:, unclamp[t]] = 0 366 | 367 | # Inject voltage into neurons. 368 | inject_v = injects_v.get(l, None) 369 | if inject_v is not None: 370 | if inject_v.ndimension() == 1: 371 | self.layers[l].v += inject_v 372 | else: 373 | self.layers[l].v += inject_v[t] 374 | 375 | # Run synapse updates. 376 | for c in self.connections: 377 | self.connections[c].update( 378 | mask=masks.get(c, None), learning=self.learning, **kwargs 379 | ) 380 | 381 | # Get input to all layers. 382 | inpts.update(self._get_inputs()) 383 | 384 | # Record state variables of interest. 385 | for m in self.monitors: 386 | self.monitors[m].record() 387 | last_layer = list(self.layers.keys())[-1] 388 | output_voltages = self.layers[last_layer].summed 389 | # Softmax is monotonic, so the argmax of the summed output is the prediction. 390 | prediction = output_voltages.argmax(dim=1) 391 | # Count the correct predictions in the batch at this timestep 392 | # (``labels`` is expected to have shape [batch_size]). 393 | correct = (prediction.cpu() == labels.cpu()).sum().item() 394 | acc[t][step] = correct 395 | # Re-normalize connections. 396 | for c in self.connections: 397 | self.connections[c].normalize() 398 | 399 | def reset_(self) -> None: 400 | # language=rst 401 | """ 402 | Reset state variables of objects in network. 403 | """ 404 | for layer in self.layers: 405 | self.layers[layer].reset_() 406 | 407 | for connection in self.connections: 408 | self.connections[connection].reset_() 409 | 410 | for monitor in self.monitors: 411 | self.monitors[monitor].reset_() 412 | 413 | def train(self, mode: bool = True) -> "torch.nn.Module": 414 | # language=rst 415 | """Sets the network in training mode. 416 | 417 | :param mode: Turn training on or off. 418 | 419 | :return: ``self`` as specified in ``torch.nn.Module``.
420 | """ 421 | self.learning = mode 422 | return super().train(mode) 423 | -------------------------------------------------------------------------------- /bindsnet/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from .environment_pipeline import EnvironmentPipeline 2 | from .base_pipeline import BasePipeline 3 | from .dataloader_pipeline import DataLoaderPipeline, TorchVisionDatasetPipeline 4 | from . import action 5 | -------------------------------------------------------------------------------- /bindsnet/pipeline/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/pipeline/__pycache__/action.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/action.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/pipeline/__pycache__/base_pipeline.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/base_pipeline.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/dataloader_pipeline.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/pipeline/__pycache__/environment_pipeline.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/pipeline/__pycache__/environment_pipeline.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/pipeline/action.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from . import EnvironmentPipeline 5 | 6 | 7 | def select_multinomial(pipeline: EnvironmentPipeline, **kwargs) -> int: 8 | # language=rst 9 | """ 10 | Selects an action probabilistically based on spiking activity from a network layer. 11 | 12 | :param pipeline: EnvironmentPipeline with environment that has an integer action space. 13 | :return: Action sampled from multinomial over activity of similarly-sized output layer. 14 | 15 | Keyword arguments: 16 | 17 | :param str output: Name of output layer whose activity to base action selection on. 18 | """ 19 | try: 20 | output = kwargs["output"] 21 | except KeyError: 22 | raise KeyError('select_multinomial() requires an "output" layer argument.') 23 | 24 | output = pipeline.network.layers[output] 25 | action_space = pipeline.env.action_space 26 | 27 | assert ( 28 | output.n % action_space.n == 0 29 | ), f"Output layer size of {output.n} is not divisible by action space size of {action_space.n}." 30 | 31 | pop_size = int(output.n / action_space.n) 32 | spikes = output.s 33 | _sum = spikes.sum().float() 34 | 35 | # Choose action based on population's spiking. 
36 | if _sum == 0: 37 | action = np.random.choice(pipeline.env.action_space.n) 38 | else: 39 | pop_spikes = torch.Tensor( 40 | [ 41 | spikes[(i * pop_size) : (i * pop_size) + pop_size].sum() 42 | for i in range(action_space.n) 43 | ] 44 | ) 45 | action = torch.multinomial((pop_spikes.float() / _sum).view(-1), 1)[0].item() 46 | 47 | return action 48 | 49 | 50 | def select_softmax(pipeline: EnvironmentPipeline, **kwargs) -> int: 51 | # language=rst 52 | """ 53 | Selects an action using a softmax function based on spiking from a network layer. 54 | 55 | :param pipeline: EnvironmentPipeline with environment that has an integer action space. 56 | :return: Action sampled from softmax over activity of similarly-sized output layer. 57 | 58 | Keyword arguments: 59 | 60 | :param str output: Name of output layer whose activity to base action selection on. 61 | """ 62 | try: 63 | output = kwargs["output"] 64 | except KeyError: 65 | raise KeyError('select_softmax() requires an "output" layer argument.') 66 | 67 | assert ( 68 | pipeline.network.layers[output].n == pipeline.env.action_space.n 69 | ), "Output layer size is not equal to the size of the action space." 70 | 71 | assert hasattr( 72 | pipeline, "spike_record" 73 | ), "EnvironmentPipeline is missing the attribute: spike_record." 74 | 75 | # Sum of previous iterations' spikes (Not yet implemented) 76 | spikes = torch.sum(pipeline.spike_record[output], dim=1) 77 | _sum = torch.sum(torch.exp(spikes.float())) 78 | 79 | if _sum == 0: 80 | action = np.random.choice(pipeline.env.action_space.n) 81 | else: 82 | action = torch.multinomial((torch.exp(spikes.float()) / _sum).view(-1), 1)[0].item() 83 | 84 | return action 85 | 86 | 87 | def select_random(pipeline: EnvironmentPipeline, **kwargs) -> int: 88 | # language=rst 89 | """ 90 | Selects an action randomly from the action space. 91 | 92 | :param pipeline: EnvironmentPipeline with environment that has an integer action space.
93 | :return: Action randomly sampled over size of pipeline's action space. 94 | """ 95 | # Choose action randomly from the action space. 96 | return np.random.choice(pipeline.env.action_space.n) 97 | -------------------------------------------------------------------------------- /bindsnet/pipeline/base_pipeline.py: -------------------------------------------------------------------------------- 1 | import collections.abc as container_abcs 2 | import time 3 | from typing import Tuple, Dict, Any 4 | 5 | import torch 6 | 7 | from ..network import Network 8 | from ..network.monitors import Monitor 9 | 10 | 11 | def recursive_to(item, device): 12 | # language=rst 13 | """ 14 | Recursively transfers everything contained in item to the target 15 | device. 16 | 17 | :param item: An individual tensor or container of tensors. 18 | :param device: ``torch.device`` pointing to ``"cuda"`` or ``"cpu"``. 19 | 20 | :return: A version of the item that has been sent to a device. 21 | """ 22 | 23 | if isinstance(item, torch.Tensor): 24 | return item.to(device) 25 | elif isinstance(item, (str, bytes, int, float, bool)): 26 | return item 27 | elif isinstance(item, container_abcs.Mapping): 28 | return {key: recursive_to(item[key], device) for key in item} 29 | elif isinstance(item, tuple) and hasattr(item, "_fields"): 30 | return type(item)(*(recursive_to(i, device) for i in item)) 31 | elif isinstance(item, container_abcs.Sequence): 32 | return [recursive_to(i, device) for i in item] 33 | else: 34 | raise NotImplementedError(f"Target type {type(item)} not supported.") 35 | 36 | 37 | class BasePipeline: 38 | # language=rst 39 | """ 40 | A generic pipeline that handles high-level functionality. 41 | """ 42 | 43 | def __init__(self, network: Network, **kwargs) -> None: 44 | # language=rst 45 | """ 46 | Initializes the pipeline. 47 | 48 | :param network: Arbitrary network object, will be managed by the ``BasePipeline`` class.
49 | 50 | Keyword arguments: 51 | 52 | :param int save_interval: How often (in steps) to save the network to disk. 53 | :param str save_dir: Path of the file to save the network object to. Defaults to ``"network.pt"``. 54 | :param Dict[str, Any] plot_config: Dict containing the plot configuration. Monitors are only created when its ``"data_step"`` entry is not ``None``; 55 | its ``"data_length"`` entry sets the length of their recordings. 56 | :param int print_interval: Interval (in steps) at which to print text output. 57 | :param bool allow_gpu: Allows automatic transfer to the GPU. 58 | """ 59 | self.network = network 60 | 61 | # Network saving handles caching of intermediate results. 62 | self.save_dir = kwargs.get("save_dir", "network.pt") 63 | self.save_interval = kwargs.get("save_interval", None) 64 | 65 | # Handles plotting of all layer spikes and voltages. 66 | # This constructs a spike monitor (and, where available, a voltage monitor) for every layer. 67 | self.plot_config = kwargs.get( 68 | "plot_config", {"data_step": None, "data_length": 10} 69 | ) 70 | 71 | if self.plot_config["data_step"] is not None: 72 | for l in self.network.layers: 73 | self.network.add_monitor( 74 | Monitor( 75 | self.network.layers[l], ("s",), self.plot_config["data_length"] 76 | ), 77 | name=f"{l}_spikes", 78 | ) 79 | if hasattr(self.network.layers[l], "v"): 80 | self.network.add_monitor( 81 | Monitor( 82 | self.network.layers[l], ("v",), self.plot_config["data_length"] 83 | ), 84 | name=f"{l}_voltages", 85 | ) 86 | 87 | self.print_interval = kwargs.get("print_interval", None) 88 | 89 | self.test_interval = kwargs.get("test_interval", None) 90 | 91 | self.step_count = 0 92 | 93 | self.init_fn() 94 | 95 | self.clock = time.time() 96 | 97 | self.allow_gpu = kwargs.get("allow_gpu", True) 98 | 99 | if torch.cuda.is_available() and self.allow_gpu: 100 | self.device = torch.device("cuda") 101 | else: 102 | self.device = torch.device("cpu") 103 | 104 | self.network.to(self.device) 105 | 106 | def reset_(self) -> None: 107 | # language=rst 108 | """ 109 | Reset the pipeline.
110 | """ 111 | 112 | self.network.reset_() 113 | self.step_count = 0 114 | 115 | def step(self, batch: Any, **kwargs) -> Any: 116 | # language=rst 117 | """ 118 | Single step of any pipeline at a high level. 119 | 120 | :param batch: A batch of inputs to be handed to the ``step_()`` function. 121 | Standard in subclasses of ``BasePipeline``. 122 | 123 | :return: The output from the subclass's ``step_()`` method, which could be anything. 124 | Passed to plotting to accommodate this. 125 | """ 126 | self.step_count += 1 127 | 128 | batch = recursive_to(batch, self.device) 129 | 130 | step_out = self.step_(batch, **kwargs) 131 | 132 | if ( 133 | self.print_interval is not None 134 | and self.step_count % self.print_interval == 0 135 | ): 136 | print( 137 | f"Iteration: {self.step_count} (Time: {time.time() - self.clock:.4f})" 138 | ) 139 | self.clock = time.time() 140 | 141 | # if self.plot_interval is not None and self.step_count % self.plot_interval == 0: 142 | self.plots(batch, step_out) 143 | 144 | if self.save_interval is not None and self.step_count % self.save_interval == 0: 145 | self.network.save(self.save_dir) 146 | 147 | if self.test_interval is not None and self.step_count % self.test_interval == 0: 148 | self.test() 149 | 150 | return step_out 151 | 152 | def get_spike_data(self) -> Dict[str, torch.Tensor]: 153 | # language=rst 154 | """ 155 | Get the spike data from all layers in the pipeline's network. 156 | 157 | :return: A dictionary containing all spike monitors from the network. 158 | """ 159 | return { 160 | l: self.network.monitors[f"{l}_spikes"].get("s") 161 | for l in self.network.layers 162 | } 163 | 164 | def get_voltage_data( 165 | self 166 | ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: 167 | # language=rst 168 | """ 169 | Get the voltage data and threshold value from all applicable layers in the pipeline's network. 170 | 171 | :return: Two dictionaries containing the voltage data and threshold values from the network. 
172 | """ 173 | voltage_record = {} 174 | threshold_value = {} 175 | for l in self.network.layers: 176 | if hasattr(self.network.layers[l], "v"): 177 | voltage_record[l] = self.network.monitors[f"{l}_voltages"].get("v") 178 | if hasattr(self.network.layers[l], "thresh"): 179 | threshold_value[l] = self.network.layers[l].thresh 180 | 181 | return voltage_record, threshold_value 182 | 183 | def step_(self, batch: Any, **kwargs) -> Any: 184 | # language=rst 185 | """ 186 | Perform a pass of the network given the input batch. 187 | 188 | :param batch: The current batch. This could be anything, as long as 189 | the subclass knows how to handle its format. 190 | 191 | :return: Any output that is needed for recording purposes. 192 | """ 193 | raise NotImplementedError("You need to provide a step_ method.") 194 | 195 | def train(self) -> None: 196 | # language=rst 197 | """ 198 | A fully self-contained training loop. 199 | """ 200 | raise NotImplementedError("You need to provide a train method.") 201 | 202 | def test(self) -> None: 203 | # language=rst 204 | """ 205 | A fully self-contained test function. 206 | """ 207 | raise NotImplementedError("You need to provide a test method.") 208 | 209 | def init_fn(self) -> None: 210 | # language=rst 211 | """ 212 | Placeholder function for subclass-specific actions that need to 213 | happen during the construction of the ``BasePipeline``. 214 | """ 215 | raise NotImplementedError("You need to provide an init_fn method.") 216 | 217 | def plots(self, batch: Any, step_out: Any) -> None: 218 | # language=rst 219 | """ 220 | Create any plots and logs for a step given the input batch and step output. 221 | 222 | :param batch: The current batch. This could be anything, as long as 223 | the subclass knows how to handle its format. 224 | :param step_out: The output from the ``step_()`` method.
225 | """ 226 | raise NotImplementedError("You need to provide a plots method.") 227 | -------------------------------------------------------------------------------- /bindsnet/pipeline/dataloader_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Dict 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | from tqdm import tqdm 6 | 7 | from ..network import Network 8 | from .base_pipeline import BasePipeline 9 | from ..analysis.pipeline_analysis import PipelineAnalyzer 10 | from ..datasets import DataLoader 11 | 12 | 13 | class DataLoaderPipeline(BasePipeline): 14 | # language=rst 15 | """ 16 | A generic ``DataLoader`` pipeline that leverages the ``torch.utils.data`` 17 | setup. It still needs to be subclassed to implement the functions 18 | specific to the dataset being used. 19 | An example can be seen in ``TorchVisionDatasetPipeline``. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | network: Network, 25 | train_ds: Dataset, 26 | test_ds: Optional[Dataset] = None, 27 | **kwargs 28 | ) -> None: 29 | # language=rst 30 | """ 31 | Initializes the pipeline. 32 | 33 | :param network: Arbitrary ``network`` object. 34 | :param train_ds: Training ``torch.utils.data.Dataset`` object. 35 | :param test_ds: Optional test ``torch.utils.data.Dataset`` object. 36 | """ 37 | super().__init__(network, **kwargs) 38 | 39 | self.train_ds = train_ds 40 | self.test_ds = test_ds 41 | 42 | self.num_epochs = kwargs.get("num_epochs", 10) 43 | self.batch_size = kwargs.get("batch_size", 1) 44 | self.num_workers = kwargs.get("num_workers", 0) 45 | self.pin_memory = kwargs.get("pin_memory", True) 46 | self.shuffle = kwargs.get("shuffle", True) 47 | 48 | def train(self) -> None: 49 | # language=rst 50 | """ 51 | Training loop that runs for the set number of epochs and creates 52 | a new ``DataLoader`` at each epoch.
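PyTorch's ``torch.utils.data.DataLoader`` defaults to ``drop_last=False``, so an epoch yields ``ceil(len(dataset) / batch_size)`` batches, and the last batch may be smaller than the rest. That count is what the ``tqdm`` progress bar needs as its ``total``. The following is a small sketch of the arithmetic. It assumes the ``DataLoader`` imported from ``..datasets`` keeps PyTorch's default, and ``num_batches`` is a hypothetical helper, not part of either library:

```python
def num_batches(dataset_size: int, batch_size: int, drop_last: bool = False) -> int:
    # drop_last=True discards the incomplete final batch (floor division);
    # otherwise the final, smaller batch is still yielded (ceiling division).
    if drop_last:
        return dataset_size // batch_size
    return (dataset_size + batch_size - 1) // batch_size


print(num_batches(60000, 32))  # 1875
print(num_batches(100, 32))  # 4
print(num_batches(100, 32, drop_last=True))  # 3
```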
53 | """ 54 | for epoch in range(self.num_epochs): 55 | train_dataloader = DataLoader( 56 | self.train_ds, 57 | batch_size=self.batch_size, 58 | num_workers=self.num_workers, 59 | pin_memory=self.pin_memory, 60 | shuffle=self.shuffle, 61 | ) 62 | 63 | for step, batch in enumerate( 64 | tqdm( 65 | train_dataloader, 66 | desc="Epoch %d/%d" % (epoch + 1, self.num_epochs), 67 | total=len(train_dataloader), # includes a final partial batch 68 | ) 69 | ): 70 | self.step(batch) 71 | 72 | def test(self) -> None: 73 | raise NotImplementedError("You need to provide a test function.") 74 | 75 | 76 | class TorchVisionDatasetPipeline(DataLoaderPipeline): 77 | # language=rst 78 | """ 79 | An example implementation of ``DataLoaderPipeline`` that can run any of the 80 | datasets in ``bindsnet.datasets`` that wrap a 81 | ``torchvision.datasets`` class. These are documented in 82 | ``bindsnet/datasets/README.md``. This specific class just runs an 83 | unsupervised network. 84 | """ 85 | 86 | def __init__( 87 | self, 88 | network: Network, 89 | train_ds: Dataset, 90 | pipeline_analyzer: Optional[PipelineAnalyzer] = None, 91 | **kwargs 92 | ) -> None: 93 | # language=rst 94 | """ 95 | Initializes the pipeline. 96 | 97 | :param network: Arbitrary ``network`` object. 98 | :param train_ds: A ``torchvision.datasets`` wrapper dataset from ``bindsnet.datasets``. 99 | 100 | Keyword arguments: 101 | 102 | :param str input_layer: Layer of the network that receives input. 103 | """ 104 | super().__init__(network, train_ds, None, **kwargs) 105 | 106 | self.input_layer = kwargs.get("input_layer", "X") 107 | self.pipeline_analyzer = pipeline_analyzer 108 | 109 | def step_(self, batch: Dict[str, torch.Tensor], **kwargs) -> None: 110 | # language=rst 111 | """ 112 | Perform a pass of the network given the input batch. Training is unsupervised 113 | (all state is stored inside the ``network`` object), so this returns ``None``.
114 | 115 | :param batch: A dictionary of the current batch. Includes image, 116 | label and encoded versions. 117 | """ 118 | self.network.reset_() 119 | inpts = {self.input_layer: batch["encoded_image"]} 120 | self.network.run(inpts, time=batch["encoded_image"].shape[0]) 121 | 122 | def init_fn(self) -> None: 123 | pass 124 | 125 | def plots(self, batch: Dict[str, torch.Tensor], *args) -> None: 126 | # language=rst 127 | """ 128 | Create any plots and logs for a step given the input batch. 129 | 130 | :param batch: A dictionary of the current batch. Includes image, 131 | label and encoded versions. 132 | """ 133 | if self.pipeline_analyzer is not None: 134 | self.pipeline_analyzer.plot_obs( 135 | batch["encoded_image"][0, ...].sum(0), step=self.step_count 136 | ) 137 | 138 | self.pipeline_analyzer.plot_spikes( 139 | self.get_spike_data(), step=self.step_count 140 | ) 141 | 142 | vr, tv = self.get_voltage_data() 143 | self.pipeline_analyzer.plot_voltages(vr, tv, step=self.step_count) 144 | 145 | self.pipeline_analyzer.finalize_step() 146 | 147 | def test_step(self): 148 | pass 149 | -------------------------------------------------------------------------------- /bindsnet/pipeline/environment_pipeline.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Callable, Optional, Tuple, Dict 3 | 4 | import torch 5 | 6 | from .base_pipeline import BasePipeline 7 | from ..analysis.pipeline_analysis import MatplotlibAnalyzer 8 | from ..environment import Environment 9 | from ..network import Network 10 | from ..network.nodes import AbstractInput 11 | 12 | 13 | class EnvironmentPipeline(BasePipeline): 14 | # language=rst 15 | """ 16 | Abstracts the interaction between ``Network``, ``Environment`` and environment feedback action. 
17 | """ 18 | 19 | def __init__( 20 | self, 21 | network: Network, 22 | environment: Environment, 23 | action_function: Optional[Callable] = None, 24 | **kwargs, 25 | ): 26 | # language=rst 27 | """ 28 | Initializes the pipeline. 29 | 30 | :param network: Arbitrary network object. 31 | :param environment: Arbitrary environment. 32 | :param action_function: Function to convert network outputs into environment inputs. 33 | 34 | Keyword arguments: 35 | 36 | :param int num_episodes: Number of episodes to train for. Defaults to 100. 37 | :param str output: String name of the layer from which to take output. 38 | :param int render_interval: Interval to render the environment. 39 | :param int reward_delay: How many iterations to delay delivery of reward. 40 | :param int time: Time for which to run the network. Defaults to the network's timestep. 41 | """ 42 | super().__init__(network, **kwargs) 43 | 44 | self.episode = 0 45 | 46 | self.env = environment 47 | self.action_function = action_function 48 | 49 | self.accumulated_reward = 0.0 50 | self.reward_list = [] 51 | 52 | # Setting kwargs. 53 | self.num_episodes = kwargs.get("num_episodes", 100) 54 | self.output = kwargs.get("output", None) 55 | self.render_interval = kwargs.get("render_interval", None) 56 | self.reward_delay = kwargs.get("reward_delay", None) 57 | self.time = kwargs.get("time", int(network.dt)) 58 | 59 | if self.reward_delay is not None: 60 | assert self.reward_delay > 0 61 | self.rewards = torch.zeros(self.reward_delay) 62 | 63 | # Set up for multiple layers of input layers. 
64 | self.inpts = [ 65 | name 66 | for name, layer in network.layers.items() 67 | if isinstance(layer, AbstractInput) 68 | ] 69 | 70 | self.action = None 71 | 72 | self.voltage_record = None 73 | self.threshold_value = None 74 | self.reward_plot = None 75 | 76 | self.first = True 77 | self.analyzer = MatplotlibAnalyzer(**self.plot_config) 78 | 79 | def init_fn(self) -> None: 80 | pass 81 | 82 | def train(self, **kwargs) -> None: 83 | # language=rst 84 | """ 85 | Trains for the specified number of episodes. Each episode can be of arbitrary length. 86 | """ 87 | while self.episode < self.num_episodes: 88 | self.reset_() 89 | 90 | for _ in itertools.count(): 91 | obs, reward, done, info = self.env_step() 92 | 93 | self.step((obs, reward, done, info), **kwargs) 94 | 95 | if done: 96 | break 97 | 98 | print( 99 | f"Episode: {self.episode} - accumulated reward: {self.accumulated_reward:.2f}" 100 | ) 101 | self.episode += 1 102 | 103 | def env_step(self) -> Tuple[torch.Tensor, float, bool, Dict]: 104 | # language=rst 105 | """ 106 | Single step of the environment, which includes rendering, choosing and performing the action, 107 | and accumulating or delaying rewards. 108 | 109 | :return: An OpenAI ``gym`` compatible tuple with modified reward and info. 110 | """ 111 | # Render game. 112 | if ( 113 | self.render_interval is not None 114 | and self.step_count % self.render_interval == 0 115 | ): 116 | self.env.render() 117 | 118 | # Choose action based on output neuron spiking. 119 | if self.action_function is not None: 120 | self.action = self.action_function(self, output=self.output) 121 | 122 | # Run a step of the environment. 123 | obs, reward, done, info = self.env.step(self.action) 124 | 125 | # Delay reward: deliver the oldest buffered reward, then shift the new one into the buffer. 126 | if self.reward_delay is not None: 127 | reward, self.rewards = self.rewards[-1].item(), torch.tensor( 128 | [reward, *self.rewards[:-1]]).float() 129 | 130 | # Accumulate reward.
131 | self.accumulated_reward += reward 132 | 133 | info["accumulated_reward"] = self.accumulated_reward 134 | 135 | return obs, reward, done, info 136 | 137 | def step_( 138 | self, gym_batch: Tuple[torch.Tensor, float, bool, Dict], **kwargs 139 | ) -> None: 140 | # language=rst 141 | """ 142 | Run a single iteration of the network and update it and the 143 | reward list when done. 144 | 145 | :param gym_batch: An OpenAI ``gym`` compatible tuple. 146 | """ 147 | obs, reward, done, info = gym_batch 148 | 149 | # Place the observations into the inputs. 150 | inpts = {k: obs for k in self.inpts} 151 | 152 | # Run the network on the spike train-encoded inputs. 153 | self.network.run( 154 | inpts=inpts, time=self.time, reward=reward, input_time_dim=1, **kwargs 155 | ) 156 | 157 | if done: 158 | if self.network.reward_fn is not None: 159 | self.network.reward_fn.update( 160 | accumulated_reward=self.accumulated_reward, 161 | steps=self.step_count, 162 | **kwargs, 163 | ) 164 | self.reward_list.append(self.accumulated_reward) 165 | 166 | def reset_(self) -> None: 167 | # language=rst 168 | """ 169 | Reset the pipeline. 170 | """ 171 | self.env.reset() 172 | self.network.reset_() 173 | self.accumulated_reward = 0.0 174 | self.step_count = 0 175 | 176 | def plots(self, gym_batch: Tuple[torch.Tensor, float, bool, Dict], *args) -> None: 177 | # language=rst 178 | """ 179 | Plot the encoded input, layer spikes, and layer voltages. 180 | 181 | :param gym_batch: An OpenAI ``gym`` compatible tuple. 
182 | """ 183 | obs, reward, done, info = gym_batch 184 | 185 | for key, item in self.plot_config.items(): 186 | if key == "obs_step" and item is not None: 187 | if self.step_count % item == 0: 188 | self.analyzer.plot_obs(obs[0, ...].sum(0)) 189 | elif key == "data_step" and item is not None: 190 | if self.step_count % item == 0: 191 | self.analyzer.plot_spikes(self.get_spike_data()) 192 | self.analyzer.plot_voltages(*self.get_voltage_data()) 193 | elif key == "reward_eps" and item is not None: 194 | if self.episode % item == 0 and done: 195 | self.analyzer.plot_reward(self.reward_list) 196 | 197 | self.analyzer.finalize_step() 198 | -------------------------------------------------------------------------------- /bindsnet/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessing import AbstractPreprocessor 2 | -------------------------------------------------------------------------------- /bindsnet/preprocessing/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/preprocessing/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/preprocessing/__pycache__/preprocessing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeuroCompLab-psu/SNN-Conversion/3d26aae1f8f5f284df67b9eed27fbe1ad69ccc27/bindsnet/preprocessing/__pycache__/preprocessing.cpython-37.pyc -------------------------------------------------------------------------------- /bindsnet/preprocessing/preprocessing.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import pickle 4 | import torch 5 | 6 | from abc import abstractmethod, 
ABC 7 | 8 | 9 | class AbstractPreprocessor(ABC): 10 | # language=rst 11 | """ 12 | Abstract base class for preprocessors. 13 | """ 14 | 15 | def process( 16 | self, 17 | csvfile: str, 18 | use_cache: bool = True, 19 | cachedfile: str = "./processed/data.pt", 20 | ) -> torch.Tensor: 21 | # cache dictionary for storing encodings if previously encoded 22 | cache = {"verify": "", "data": None} 23 | 24 | # if caching is enabled 25 | if use_cache: 26 | # generate a hash 27 | cache["verify"] = self.__gen_hash(csvfile) 28 | 29 | # compare hash, if valid return cached value 30 | if self.__check_file(cachedfile, cache): 31 | return cache["data"] 32 | 33 | # otherwise process the data 34 | self._process(csvfile, cache) 35 | 36 | # save if use_cache 37 | if use_cache: 38 | self.__save(cachedfile, cache) 39 | 40 | # return data 41 | return cache["data"] 42 | 43 | @abstractmethod 44 | def _process(self, filename: str, cache: dict): 45 | # language=rst 46 | """ 47 | Method for defining how to preprocess the data. 48 | :param filename: file to load raw data from 49 | :param cache: dict used for caching; ``cache["data"]`` must be set for caching to work 50 | """ 51 | pass 52 | 53 | def __gen_hash(self, filename: str) -> str: 54 | # language=rst 55 | """ 56 | Generates a hash of a csv file's contents and the preprocessor's class name 57 | :param filename: file to generate hash for 58 | :return: hash for the csv file 59 | """ 60 | # read all the lines 61 | with open(filename, "r") as f: 62 | lines = f.readlines() 63 | # generate md5 hash after concatenating all of the lines 64 | pre = "".join(lines) + str(self.__class__.__name__) 65 | m = hashlib.md5(pre.encode("utf-8")) 66 | return m.hexdigest() 67 | 68 | @staticmethod 69 | def __check_file(cachedfile: str, cache: dict) -> bool: 70 | # language=rst 71 | """ 72 | Compares the csv file and the saved file to see if a new encoding needs to be generated.
73 | :param cachedfile: the filename of the cached data 74 | :param cache: dict containing the current csvfile hash. This is updated if the cachefile has valid data 75 | :return: whether the cache is valid 76 | 77 | """ 78 | # try opening the cached file 79 | try: 80 | with open(cachedfile, "rb") as f: 81 | temp = pickle.load(f) 82 | except FileNotFoundError: 83 | temp = {"verify": "", "data": None} 84 | 85 | # if the hash matches up, keep the data from the cache 86 | if cache["verify"] == temp["verify"]: 87 | cache["data"] = temp["data"] 88 | return True 89 | 90 | # otherwise don't do anything 91 | return False 92 | 93 | @staticmethod 94 | def __save(filename: str, data: dict) -> None: 95 | # language=rst 96 | """ 97 | Pickles ``data`` to ``filename``, creating or overwriting the cache file. 98 | :param filename: filename to save to 99 | """ 100 | # create parent directories if needed (dirname is "" for a bare filename) 101 | if os.path.dirname(filename): 102 | os.makedirs(os.path.dirname(filename), exist_ok=True) 103 | 104 | # save file 105 | with open(filename, "wb") as f: 106 | pickle.dump(data, f) 107 | -------------------------------------------------------------------------------- /bindsnet/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | 5 | from torch import Tensor 6 | import torch.nn.functional as F 7 | from numpy import ndarray 8 | from typing import Tuple, Union 9 | from torch.nn.modules.utils import _pair 10 | 11 | 12 | def im2col_indices( 13 | x: Tensor, 14 | kernel_height: int, 15 | kernel_width: int, 16 | padding: Tuple[int, int] = (0, 0), 17 | stride: Tuple[int, int] = (1, 1), 18 | ) -> Tensor: 19 | # language=rst 20 | """ 21 | im2col is a special case of unfold, which PyTorch implements as ``torch.nn.functional.unfold``. 22 | 23 | :param x: Input image tensor to be reshaped to column-wise format. 24 | :param kernel_height: Height of the convolutional kernel in pixels.
25 | :param kernel_width: Width of the convolutional kernel in pixels. 26 | :param padding: Amount of zero padding on the input image. 27 | :param stride: Stride of the convolution over the image. 28 | :return: Input tensor reshaped to column-wise format. 29 | """ 30 | return F.unfold(x, (kernel_height, kernel_width), padding=padding, stride=stride) 31 | 32 | 33 | def col2im_indices( 34 | cols: Tensor, 35 | x_shape: Tuple[int, int, int, int], 36 | kernel_height: int, 37 | kernel_width: int, 38 | padding: Tuple[int, int] = (0, 0), 39 | stride: Tuple[int, int] = (1, 1), 40 | ) -> Tensor: 41 | # language=rst 42 | """ 43 | col2im is a special case of fold, which PyTorch implements as ``torch.nn.functional.fold``. 44 | 45 | :param cols: Image tensor in column-wise format. 46 | :param x_shape: Shape ``(N, C, H, W)`` of the original image tensor; only ``(H, W)`` is passed to ``F.fold``. 47 | :param kernel_height: Height of the convolutional kernel in pixels. 48 | :param kernel_width: Width of the convolutional kernel in pixels. 49 | :param padding: Amount of zero padding on the input image. 50 | :param stride: Stride of the convolution over the image. 51 | :return: Image tensor in original image shape. 52 | """ 53 | return F.fold( 54 | cols, x_shape[-2:], (kernel_height, kernel_width), padding=padding, stride=stride 55 | ) 56 | 57 | 58 | def get_square_weights( 59 | weights: Tensor, n_sqrt: int, side: Union[int, Tuple[int, int]] 60 | ) -> Tensor: 61 | # language=rst 62 | """ 63 | Return a square grid of ``n_sqrt ** 2`` filters with side length(s) ``side``. 64 | 65 | :param weights: Two-dimensional tensor of weights for two-dimensional data. 66 | :param n_sqrt: Square root of no. of filters. 67 | :param side: Side length(s) of filter. 68 | :return: Reshaped weights to square matrix of filters.
69 | """ 70 | if isinstance(side, int): 71 | side = (side, side) 72 | 73 | square_weights = torch.zeros(side[0] * n_sqrt, side[1] * n_sqrt) 74 | for i in range(n_sqrt): 75 | for j in range(n_sqrt): 76 | n = i * n_sqrt + j 77 | 78 | if not n < weights.size(1): 79 | break 80 | 81 | x = i * side[0] 82 | y = (j % n_sqrt) * side[1] 83 | filter_ = weights[:, n].contiguous().view(*side) 84 | square_weights[x : x + side[0], y : y + side[1]] = filter_ 85 | 86 | return square_weights 87 | 88 | 89 | def get_square_assignments(assignments: Tensor, n_sqrt: int) -> Tensor: 90 | # language=rst 91 | """ 92 | Return a grid of assignments. 93 | 94 | :param assignments: Vector of integers corresponding to class labels. 95 | :param n_sqrt: Square root of no. of assignments. 96 | :return: Reshaped square matrix of assignments. 97 | """ 98 | square_assignments = torch.mul(torch.ones(n_sqrt, n_sqrt), -1.0) 99 | for i in range(n_sqrt): 100 | for j in range(n_sqrt): 101 | n = i * n_sqrt + j 102 | 103 | if not n < assignments.size(0): 104 | break 105 | 106 | square_assignments[ 107 | i : (i + 1), (j % n_sqrt) : ((j % n_sqrt) + 1) 108 | ] = assignments[n] 109 | 110 | return square_assignments 111 | 112 | 113 | def reshape_locally_connected_weights( 114 | w: Tensor, 115 | n_filters: int, 116 | kernel_size: Union[int, Tuple[int, int]], 117 | conv_size: Union[int, Tuple[int, int]], 118 | locations: Tensor, 119 | input_sqrt: Union[int, Tuple[int, int]], 120 | ) -> Tensor: 121 | # language=rst 122 | """ 123 | Get the weights from a locally connected layer and reshape them to be two-dimensional and square. 124 | 125 | :param w: Weights from a locally connected layer. 126 | :param n_filters: No. of neuron filters. 127 | :param kernel_size: Side length(s) of convolutional kernel. 128 | :param conv_size: Side length(s) of convolution population. 129 | :param locations: Binary mask indicating receptive fields of convolution population neurons. 130 | :param input_sqrt: Sides length(s) of input neurons. 
131 | :return: Locally connected weights reshaped as a collection of spatially ordered square grids. 132 | """ 133 | kernel_size = _pair(kernel_size) 134 | conv_size = _pair(conv_size) 135 | input_sqrt = _pair(input_sqrt) 136 | 137 | k1, k2 = kernel_size 138 | c1, c2 = conv_size 139 | i1, i2 = input_sqrt 140 | c1sqrt, c2sqrt = int(math.ceil(math.sqrt(c1))), int(math.ceil(math.sqrt(c2))) 141 | fs = int(math.ceil(math.sqrt(n_filters))) 142 | 143 | w_ = torch.zeros((n_filters * k1, k2 * c1 * c2)) 144 | 145 | for n1 in range(c1): 146 | for n2 in range(c2): 147 | for feature in range(n_filters): 148 | n = n1 * c2 + n2 149 | filter_ = w[ 150 | locations[:, n], 151 | feature * (c1 * c2) + (n // c2sqrt) * c2sqrt + (n % c2sqrt), 152 | ].view(k1, k2) 153 | w_[feature * k1 : (feature + 1) * k1, n * k2 : (n + 1) * k2] = filter_ 154 | 155 | if c1 == 1 and c2 == 1: 156 | square = torch.zeros((i1 * fs, i2 * fs)) 157 | 158 | for n in range(n_filters): 159 | square[ 160 | (n // fs) * i1 : ((n // fs) + 1) * i1, 161 | (n % fs) * i2 : ((n % fs) + 1) * i2, 162 | ] = w_[n * i1 : (n + 1) * i1] 163 | 164 | return square 165 | else: 166 | square = torch.zeros((k1 * fs * c1, k2 * fs * c2)) 167 | 168 | for n1 in range(c1): 169 | for n2 in range(c2): 170 | for f1 in range(fs): 171 | for f2 in range(fs): 172 | if f1 * fs + f2 < n_filters: 173 | square[ 174 | k1 * (n1 * fs + f1) : k1 * (n1 * fs + f1 + 1), 175 | k2 * (n2 * fs + f2) : k2 * (n2 * fs + f2 + 1), 176 | ] = w_[ 177 | (f1 * fs + f2) * k1 : (f1 * fs + f2 + 1) * k1, 178 | (n1 * c2 + n2) * k2 : (n1 * c2 + n2 + 1) * k2, 179 | ] 180 | 181 | return square 182 | 183 | 184 | def reshape_conv2d_weights(weights: torch.Tensor) -> torch.Tensor: 185 | # language=rst 186 | """ 187 | Flattens the 4D weight tensor of a ``Conv2dConnection`` into a 2D grid of filters. 188 | 189 | :param weights: Weight tensor of a ``Conv2dConnection`` object. 190 | :return: Two-dimensional tensor tiling every ``(out, in)`` kernel. 191 |
192 | """ 193 | sqrt1 = int(np.ceil(np.sqrt(weights.size(0)))) 194 | sqrt2 = int(np.ceil(np.sqrt(weights.size(1)))) 195 | height, width = weights.size(2), weights.size(3) 196 | reshaped = torch.zeros( 197 | sqrt1 * sqrt2 * weights.size(2), sqrt1 * sqrt2 * weights.size(3) 198 | ) 199 | 200 | for i in range(sqrt1): 201 | for j in range(sqrt1): 202 | for k in range(sqrt2): 203 | for l in range(sqrt2): 204 | if i * sqrt1 + j < weights.size(0) and k * sqrt2 + l < weights.size( 205 | 1 206 | ): 207 | fltr = weights[i * sqrt1 + j, k * sqrt2 + l].view(height, width) 208 | reshaped[ 209 | i * height 210 | + k * height * sqrt1 : (i + 1) * height 211 | + k * height * sqrt1, 212 | (j % sqrt1) * width 213 | + (l % sqrt2) * width * sqrt1 : ((j % sqrt1) + 1) * width 214 | + (l % sqrt2) * width * sqrt1, 215 | ] = fltr 216 | 217 | return reshaped 218 | -------------------------------------------------------------------------------- /conversion.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from time import time 4 | from matplotlib import pyplot as plt 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | from bindsnet.conversion import ann_to_snn 9 | from bindsnet.encoding import RepeatEncoder 10 | from bindsnet.datasets import ImageNet, CIFAR100, DataLoader 11 | import torchvision.transforms as transforms 12 | from vgg import vgg_15_avg_before_relu 13 | 14 | 15 | def main(args): 16 | if args.gpu and torch.cuda.is_available(): 17 | torch.cuda.manual_seed_all(args.seed) 18 | 19 | np.random.seed(args.seed) 20 | torch.manual_seed(args.seed) 21 | 22 | if args.n_workers == -1: 23 | args.n_workers = args.gpu * 4 * torch.cuda.device_count() 24 | 25 | device = torch.device("cuda" if args.gpu else "cpu") 26 | 27 | # Load trained ANN from disk. 
28 | if args.arch == 'vgg15ab': 29 | ann = vgg_15_avg_before_relu(dataset=args.dataset) 30 | # add other architectures here# 31 | else: 32 | raise ValueError('Unknown architecture') 33 | 34 | 35 | ann.features = torch.nn.DataParallel(ann.features) 36 | ann.to(device) 37 | # Create the job directory if needed (no error if it already exists). 38 | os.makedirs(args.job_dir, exist_ok=True) 39 | f = os.path.join('.', args.model) 40 | # The checkpoint may be a bare state dict or wrap it under 'state_dict'. 41 | checkpoint = torch.load(f=f, map_location=device) 42 | dictionary = checkpoint.get('state_dict', checkpoint) 43 | 44 | ann.load_state_dict(state_dict=dictionary, strict=True) 45 | 46 | if args.dataset=='imagenet': 47 | input_shape=(3,224,224) 48 | 49 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 50 | std=[0.229, 0.224, 0.225]) 51 | # the actual data to be evaluated 52 | val_loader = ImageNet( 53 | image_encoder=RepeatEncoder(time=args.time, dt=1.0), 54 | label_encoder=None, 55 | root=args.data, 56 | download=False, 57 | transform=transforms.Compose([ 58 | transforms.Resize((256, 256)), 59 | transforms.CenterCrop(224), 60 | transforms.ToTensor(), 61 | normalize, 62 | ]), 63 | split='val') 64 | # bindsnet DataLoader over the validation set 65 | dataloader = DataLoader( 66 | val_loader, 67 | batch_size=args.batch_size, 68 | shuffle=True, 69 | num_workers=args.n_workers, 70 | pin_memory=args.gpu, 71 | ) 72 | # A loader of samples for normalization of the SNN from the training set 73 | norm_loader = ImageNet( 74 | image_encoder=RepeatEncoder(time=args.time, dt=1.0), 75 | label_encoder=None, 76 | root=args.data, 77 | download=False, 78 | split='train', 79 | transform = transforms.Compose([ 80 | transforms.Resize(256), 81 | transforms.CenterCrop(224), 82 | transforms.ToTensor(), 83 | normalize, ] 84 | ), 85 | ) 86 | 87 | elif args.dataset == 'cifar100': 88 | input_shape=(3, 32, 32) 89 | print('==> Using PyTorch CIFAR-100 Dataset') 90 | normalize = transforms.Normalize(mean=[0.507, 0.487, 0.441], 91 | std=[0.267, 0.256, 0.276]) 92 | val_loader = CIFAR100( 93 | image_encoder=RepeatEncoder(time=args.time, dt=1.0), 94 |
label_encoder=None, 95 | root=args.data, 96 | download=True, 97 | train=False, 98 | transform=transforms.Compose([ 99 | # Evaluation must be deterministic: no random crop or flip here. 100 | transforms.ToTensor(), 101 | normalize, 102 | ] 103 | ) 104 | ) 105 | 106 | dataloader = DataLoader( 107 | val_loader, 108 | batch_size=args.batch_size, 109 | shuffle=True, 110 | num_workers=0, 111 | pin_memory=args.gpu, 112 | ) 113 | 114 | norm_loader = CIFAR100( 115 | image_encoder=RepeatEncoder(time=args.time, dt=1.0), 116 | label_encoder=None, 117 | root=args.data, 118 | download=True, 119 | train=True, 120 | transform=transforms.Compose([ 121 | transforms.RandomCrop(32, padding=4), 122 | transforms.RandomHorizontalFlip(0.5), 123 | transforms.ToTensor(), 124 | normalize, ] 125 | ) 126 | ) 127 | else: 128 | raise ValueError('Unsupported dataset.') 129 | 130 | if args.eval_size == -1: 131 | args.eval_size = len(val_loader) 132 | 133 | # Take one batch of args.norm training images to normalize the SNN thresholds. 134 | norm_batch = next(iter(torch.utils.data.DataLoader(norm_loader, batch_size=args.norm))) 135 | data = norm_batch['image'] 136 | 137 | snn = ann_to_snn(ann, input_shape=input_shape, data=data, percentile=args.percentile) 138 | 139 | 140 | torch.cuda.empty_cache() 141 | snn = snn.to(device) 142 | 143 | correct = 0 144 | t0 = time() 145 | accuracies = np.zeros((args.time, (args.eval_size//args.batch_size)+1), dtype=np.float32) 146 | for step, batch in enumerate(tqdm(dataloader)): 147 | if (step+1)*args.batch_size > args.eval_size: 148 | break 149 | # Prep next input batch.
150 | inputs = batch["encoded_image"] 151 | labels = batch["label"] 152 | inpts = {"Input": inputs} 153 | if args.gpu: 154 | inpts = {k: v.cuda() for k, v in inpts.items()} 155 | 156 | snn.run(inpts=inpts, time=args.time, step=step, acc=accuracies, labels=labels, one_step=args.one_step) 157 | last_layer = list(snn.layers.keys())[-1] 158 | output_voltages = snn.layers[last_layer].summed 159 | prediction = torch.softmax(output_voltages, dim=1).argmax(dim=1) 160 | correct += (prediction.cpu() == labels).sum().item() 161 | snn.reset_() 162 | t1 = time() - t0 163 | 164 | final = accuracies.sum(axis=1) / args.eval_size 165 | 166 | plt.plot(final) 167 | plt.suptitle('{} {} ANN-SNN@{} percentile'.format(args.dataset, args.arch, args.percentile), fontsize=20) 168 | plt.xlabel('Timestep', fontsize=19) 169 | plt.ylabel('Accuracy', fontsize=19) 170 | plt.grid() 171 | plt.savefig('{}/{}_{}.png'.format(args.job_dir, args.arch, args.percentile)) 172 | plt.show() # after savefig: closing the window can leave an empty figure 173 | np.save('{}/voltage_accuracy_{}_{}.npy'.format(args.job_dir, args.arch, args.percentile), final) 174 | 175 | 176 | accuracy = 100 * correct / args.eval_size 177 | 178 | print(f"SNN accuracy: {accuracy:.2f}") 179 | print(f"Clock time used: {t1:.4f} s.") 180 | path = os.path.join(args.job_dir, "results", args.results_file) 181 | os.makedirs(os.path.dirname(path), exist_ok=True) 182 | if not os.path.isfile(path): 183 | with open(path, "w") as f: 184 | f.write("seed,simulation time,batch size,inference time,accuracy\n") 185 | to_write = [args.seed, args.time, args.batch_size, t1, accuracy] 186 | to_write = ",".join(map(str, to_write)) + "\n" 187 | with open(path, "a") as f: 188 | f.write(to_write) 189 | 190 | return t1 191 | 192 | 193 | def parse_args(): 194 | parser = argparse.ArgumentParser() 195 | parser.add_argument("--job-dir", type=str, required=True, help='The working directory to store results') 196 | parser.add_argument("--model", type=str, required=True, help='The path to the pretrained model') 197 |
parser.add_argument("--results-file", type=str, default='sim_result.txt', help='The file to store simulation results in') 198 | parser.add_argument("--seed", type=int, default=0, help='A random seed') 199 | parser.add_argument("--time", type=int, default=80, help='Time steps to be simulated by the converted SNN (default: 80)') 200 | parser.add_argument("--batch-size", type=int, default=100, help='Mini batch size') 201 | parser.add_argument("--n-workers", type=int, default=4, help='Number of data-loading worker processes (-1: 4 per GPU with --gpu, else 0)') 202 | parser.add_argument("--norm", type=int, default=128, help='Number of training samples used to normalize the SNN thresholds') 203 | parser.add_argument("--gpu", action="store_true", help='Whether to use GPU or not') 204 | parser.add_argument("--one-step", action="store_true", help='Single step inference flag') 205 | parser.add_argument('--data', metavar='DATA_PATH', default='./data/', 206 | help="The path to the dataset (default: './data/'); ImageNet must already be there, CIFAR-100 will be downloaded") 207 | parser.add_argument("--arch", type=str, default='vgg15ab', help='ANN architecture to be instantiated') 208 | parser.add_argument("--percentile", type=float, default=99.7, help='Percentile of training-set activations used to normalize the SNN voltage thresholds') 209 | parser.add_argument("--eval_size", type=int, default=-1, help='Number of samples to evaluate (default: evaluate all)') 210 | parser.add_argument("--dataset", type=str, default='cifar100', help='cifar100 or imagenet') 211 | 212 | return parser.parse_args() 213 | 214 | 215 | if __name__ == "__main__": 216 | main(parse_args()) 217 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class VGG_15_avg_before_relu(nn.Module): 5 | def __init__(self, dr=0.1, num_classes=1000, units=512*7*7): 6 | super(VGG_15_avg_before_relu, self).__init__() 7 |
self.features = nn.Sequential( 8 | nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 9 | nn.ReLU(), 10 | nn.Dropout(dr), 11 | nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 12 | nn.AvgPool2d((2, 2), (2, 2)), 13 | nn.ReLU(), 14 | 15 | nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 16 | nn.ReLU(), 17 | nn.Dropout(dr), 18 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 19 | nn.AvgPool2d((2, 2), (2, 2)), 20 | nn.ReLU(), 21 | 22 | nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 23 | nn.ReLU(), 24 | nn.Dropout(dr), 25 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 26 | nn.ReLU(), 27 | nn.Dropout(dr), 28 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 29 | nn.AvgPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), # AvgPool2d, 30 | nn.ReLU(), 31 | 32 | nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 33 | nn.ReLU(), 34 | nn.Dropout(dr), 35 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 36 | nn.ReLU(), 37 | nn.Dropout(dr), 38 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 39 | nn.AvgPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), # AvgPool2d, 40 | nn.ReLU(), 41 | 42 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 43 | nn.ReLU(), 44 | nn.Dropout(dr), 45 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 46 | nn.ReLU(), 47 | nn.Dropout(dr), 48 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), 1, 1, bias=False), 49 | nn.AvgPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 50 | nn.ReLU() 51 | ) 52 | self.classifier = nn.Sequential( 53 | nn.Dropout(dr), 54 | nn.Linear(units, 4096, bias=False), # Linear, 55 | nn.ReLU(), 56 | nn.Dropout(dr), 57 | nn.Linear(4096, num_classes, bias=False) # Linear, 58 | ) 59 | 60 | self._initialize_weights() 61 | 62 | def _initialize_weights(self): 63 | for m in self.modules(): 64 | if isinstance(m, nn.Conv2d): 65 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') 66 | if m.bias is not None: 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.BatchNorm2d): 69 | nn.init.constant_(m.weight, 1) 70 | # nn.init.constant_(m.bias, 0) 71 | elif isinstance(m, nn.Linear): 72 | nn.init.normal_(m.weight, 0, 0.01) 73 | # nn.init.constant_(m.bias, 0) 74 | 75 | def forward(self, x): 76 | x = self.features(x) 77 | x = x.view(x.size(0), -1) 78 | x = self.classifier(x) 79 | return x 80 | 81 | def vgg_15_avg_before_relu(dataset='imagenet' , **kwargs): 82 | if dataset == 'imagenet': 83 | model = VGG_15_avg_before_relu(num_classes=1000, **kwargs) 84 | elif dataset == 'cifar100': 85 | model = VGG_15_avg_before_relu(num_classes=100, units=512,**kwargs) 86 | else: 87 | model = None 88 | raise ValueError('Unsupported Dataset!') 89 | return model 90 | --------------------------------------------------------------------------------
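The `--percentile` flag in `conversion.py` is forwarded to `ann_to_snn` from `bindsnet.conversion`, whose source is not part of this excerpt, so the exact normalization it performs is not shown here. As a hedged illustration of the general technique only, the sketch below applies percentile-based layer-wise weight normalization to a flat `nn.Sequential`. It records the chosen percentile of each ReLU's positive activations on a calibration batch, then rescales every bias-free `Conv2d`/`Linear` so that this percentile maps to 1.0, which lets integrate-and-fire neurons use a threshold of 1. Assumptions: the helper names `relu_percentiles` and `normalize_weights` are hypothetical and not part of bindsnet. Each weighted layer is assumed to be followed only by positively homogeneous operations (`AvgPool2d`, `ReLU`), as in `VGG_15_avg_before_relu.features`. Nested modules and the `view` between `features` and `classifier` are not handled.

```python
from typing import List

import numpy as np
import torch
import torch.nn as nn


def relu_percentiles(model: nn.Sequential, data: torch.Tensor,
                     percentile: float = 99.7) -> List[float]:
    """Hypothetical helper: for each nn.ReLU in a flat nn.Sequential, return
    the given percentile of its positive activations on `data`."""
    model.eval()  # disable dropout so the recorded activations are deterministic
    scales: List[float] = []
    x = data
    with torch.no_grad():
        for module in model:
            x = module(x)
            if isinstance(module, nn.ReLU):
                acts = x[x > 0].cpu().numpy()
                # Fall back to 1.0 if the layer never fires on this batch.
                scales.append(float(np.percentile(acts, percentile)) if acts.size else 1.0)
    return scales


def normalize_weights(model: nn.Sequential, scales: List[float]) -> None:
    """Hypothetical helper: rescale bias-free Conv2d/Linear weights in place so
    the percentile activation of each following ReLU becomes 1.0."""
    prev = 1.0  # scale already applied to this layer's input (1.0 for raw data)
    it = iter(scales)
    with torch.no_grad():
        for module in model:
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # A trailing layer with no ReLU after it keeps scale 1.0.
                cur = next(it, 1.0)
                module.weight.mul_(prev / cur)
                prev = cur
```

Because the final layer has no ReLU after it, it receives a scale of 1.0, so the normalized network produces the same outputs as the original ANN up to floating-point error. A percentile below 100 ignores rare activation outliers, which raises typical firing rates, at the cost of saturating the neurons whose activations exceed it.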