├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .gitlint ├── .nengobones.yml ├── .pre-commit-config.yaml ├── CHANGES.rst ├── CONTRIBUTING.rst ├── CONTRIBUTORS.rst ├── LICENSE.rst ├── MANIFEST.in ├── README.rst ├── docs ├── _static │ └── favicon.ico ├── conf.py ├── contributing.rst ├── examples.rst ├── examples │ └── spiking-fashion-mnist.ipynb ├── index.rst ├── installation.rst ├── license.rst ├── nengo-dl-comparison.rst ├── project.rst ├── reference.rst └── release-history.rst ├── pyproject.toml ├── pytorch_spiking ├── __init__.py ├── functional.py ├── modules.py ├── tests │ └── test_modules.py └── version.py ├── setup.cfg └── setup.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI testing 2 | on: 3 | pull_request: {} 4 | push: 5 | branches: 6 | - main 7 | - release-candidate-* 8 | tags: 9 | - v* 10 | workflow_dispatch: 11 | inputs: 12 | debug_enabled: 13 | description: Run the build with SSH debugging enabled 14 | type: boolean 15 | required: false 16 | default: false 17 | 18 | jobs: 19 | static: 20 | runs-on: ubuntu-latest 21 | timeout-minutes: 30 22 | steps: 23 | - uses: nengo/nengo-bones/actions/setup@main 24 | - uses: nengo/nengo-bones/actions/generate-and-check@main 25 | - uses: nengo/nengo-bones/actions/run-script@main 26 | with: 27 | name: static 28 | test: 29 | needs: 30 | - static 31 | timeout-minutes: 60 32 | runs-on: ubuntu-latest 33 | strategy: 34 | matrix: 35 | include: 36 | - script: test 37 | coverage-name: basic 38 | - script: test 39 | coverage-name: oldest 40 | pytorch-version: torch==1.0.0 41 | python: "3.7" 42 | fail-fast: false 43 | env: 44 | PYTORCH_VERSION: ${{ matrix.pytorch-version || 'torch' }} 45 | steps: 46 | - uses: nengo/nengo-bones/actions/setup@main 47 | with: 48 | python-version: ${{ matrix.python || '3.9' }} 49 | - uses: nengo/nengo-bones/actions/generate-and-check@main 50 | - uses: nengo/nengo-bones/actions/run-script@main 51 | with: 52 | name: 
${{ matrix.script }} 53 | - uses: actions/upload-artifact@v3 54 | if: ${{ always() && matrix.coverage-name }} 55 | with: 56 | name: coverage-${{ matrix.coverage-name }} 57 | path: .coverage 58 | remote: 59 | needs: 60 | - test 61 | timeout-minutes: 60 62 | runs-on: ubuntu-latest 63 | strategy: 64 | matrix: 65 | include: 66 | - script: remote-docs 67 | - script: remote-examples 68 | fail-fast: false 69 | env: 70 | PYTORCH_VERSION: torch>=1.0.0 71 | SSH_KEY: ${{ secrets.SSH_KEY }} 72 | SSH_CONFIG: ${{ secrets.SSH_CONFIG }} 73 | GH_TOKEN: ${{ secrets.GH_TOKEN }} 74 | steps: 75 | - uses: nengo/nengo-bones/actions/setup@main 76 | with: 77 | python-version: "3.8" 78 | - name: Write secrets to file 79 | run: | 80 | mkdir -p ~/.ssh 81 | echo '${{ secrets.AZURE_PEM }}' > ~/.ssh/azure.pem 82 | - uses: nengo/nengo-bones/actions/generate-and-check@main 83 | - uses: nengo/nengo-bones/actions/run-script@main 84 | with: 85 | name: ${{ matrix.script }} 86 | coverage: 87 | runs-on: ubuntu-latest 88 | timeout-minutes: 10 89 | needs: 90 | - test 91 | if: ${{ always() }} 92 | steps: 93 | - uses: nengo/nengo-bones/actions/coverage-report@main 94 | deploy: 95 | needs: 96 | - remote 97 | if: >- 98 | startsWith(github.ref_name, 'release-candidate-') || 99 | (github.ref_type == 'tag' && startsWith(github.ref_name, 'v')) 100 | runs-on: ubuntu-latest 101 | timeout-minutes: 30 102 | steps: 103 | - name: Write .pypirc to file 104 | run: | 105 | echo '${{ secrets.PYPIRC_FILE }}' > ~/.pypirc 106 | - uses: actions/checkout@v3 107 | - uses: nengo/nengo-bones/actions/setup@main 108 | with: 109 | python-version: "3.10" 110 | - uses: nengo/nengo-bones/actions/generate-and-check@main 111 | - uses: nengo/nengo-bones/actions/run-script@main 112 | with: 113 | name: deploy 114 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | /.idea 3 | /docs/_build 4 | 
/docs/examples/* 5 | /pytorch_spiking.egg-info 6 | /secret 7 | /tmp 8 | bones-scripts/ 9 | 10 | # Exceptions, must be last 11 | !/docs/examples/*.ipynb 12 | -------------------------------------------------------------------------------- /.gitlint: -------------------------------------------------------------------------------- 1 | [general] 2 | ignore=body-is-missing 3 | 4 | [title-max-length] 5 | line-length=50 6 | 7 | [B1] 8 | # body line length 9 | line-length=72 10 | 11 | [title-match-regex] 12 | regex=^[A-Z] 13 | -------------------------------------------------------------------------------- /.nengobones.yml: -------------------------------------------------------------------------------- 1 | project_name: PyTorchSpiking 2 | pkg_name: pytorch_spiking 3 | repo_name: nengo/pytorch-spiking 4 | 5 | description: Spiking neuron integration for PyTorch 6 | copyright_start: 2020 7 | license: abr-free 8 | main_branch: main 9 | 10 | license_rst: {} 11 | 12 | contributing_rst: {} 13 | 14 | contributors_rst: {} 15 | 16 | manifest_in: {} 17 | 18 | setup_cfg: 19 | pytest: 20 | xfail_strict: True 21 | pylint: 22 | disable: 23 | - not-callable # https://github.com/pytorch/pytorch/issues/24807 24 | known_third_party: 25 | - torch 26 | codespell: 27 | ignore_words: 28 | - hist 29 | 30 | docs_conf_py: 31 | nengo_logo: "" 32 | extensions: 33 | - nengo_sphinx_theme.ext.autoautosummary 34 | doctest_setup: 35 | - import numpy as np 36 | - import torch 37 | autoautosummary_change_modules: 38 | pytorch_spiking: 39 | - pytorch_spiking.modules.SpikingActivation 40 | - pytorch_spiking.modules.Lowpass 41 | - pytorch_spiking.modules.TemporalAvgPool 42 | 43 | ci_scripts: 44 | - template: static 45 | - template: test 46 | coverage: true 47 | pip_install: 48 | - $PYTORCH_VERSION 49 | - nengo[tests] 50 | - template: docs 51 | pip_install: 52 | - $PYTORCH_VERSION 53 | - template: examples 54 | pip_install: 55 | - $PYTORCH_VERSION 56 | - template: remote-script 57 | remote_script: docs 58 | 
output_name: remote-docs 59 | host: azure-docs 60 | azure_name: nengo-dl-docs 61 | azure_group: nengo-ci 62 | remote_vars: 63 | PYTORCH_VERSION: $PYTORCH_VERSION 64 | remote_setup: 65 | - micromamba install -y cudatoolkit=10.2 66 | - template: remote-script 67 | remote_script: examples 68 | output_name: remote-examples 69 | host: azure-examples 70 | azure_name: nengo-dl-examples 71 | azure_group: nengo-ci 72 | remote_vars: 73 | PYTORCH_VERSION: $PYTORCH_VERSION 74 | remote_setup: 75 | - micromamba install -y cudatoolkit=10.2 76 | - template: deploy 77 | wheel: true 78 | 79 | setup_py: 80 | include_package_data: True 81 | install_req: 82 | - numpy>=1.16.0 83 | - torch>=1.0.0 84 | docs_req: 85 | - jupyter>=1.0.0 86 | - matplotlib>=2.0.0 87 | - nbsphinx>=0.3.5 88 | - nengo-sphinx-theme>=1.2.1 89 | - numpydoc>=0.6.0 90 | - sphinx>=3.0.0 91 | - torchvision>=0.7.0 92 | tests_req: 93 | - pylint>=1.9.2 94 | - pytest>=3.6.0 95 | - pytest-allclose>=1.0.0 96 | - pytest-cov>=2.6.0 97 | - pytest-rng>=1.0.0 98 | - pytest-xdist>=1.16.0 99 | classifiers: 100 | - "Development Status :: 3 - Alpha" 101 | - "Intended Audience :: Science/Research" 102 | - "Operating System :: Microsoft :: Windows" 103 | - "Operating System :: POSIX :: Linux" 104 | - "Programming Language :: Python" 105 | - "Programming Language :: Python :: 3.6" 106 | - "Programming Language :: Python :: 3.7" 107 | - "Programming Language :: Python :: 3.8" 108 | - "Topic :: Scientific/Engineering" 109 | - "Topic :: Scientific/Engineering :: Artificial Intelligence" 110 | 111 | pyproject_toml: {} 112 | 113 | pre_commit_config_yaml: {} 114 | 115 | version_py: 116 | major: 0 117 | minor: 1 118 | patch: 1 119 | release: false 120 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # Automatically generated by nengo-bones, do not edit this file directly 2 | 3 | repos: 4 | - repo: 
https://github.com/psf/black 5 | rev: 21.12b0 6 | hooks: 7 | - id: black 8 | files: \.py$ 9 | - repo: https://github.com/pycqa/isort 10 | rev: 5.6.4 11 | hooks: 12 | - id: isort 13 | files: \.py$ 14 | -------------------------------------------------------------------------------- /CHANGES.rst: -------------------------------------------------------------------------------- 1 | Release history 2 | =============== 3 | 4 | .. Changelog entries should follow this format: 5 | 6 | version (release date) 7 | ---------------------- 8 | 9 | **section** 10 | 11 | - One-line description of change (link to GitHub issue/PR) 12 | 13 | .. Changes should be organized in one of several sections: 14 | 15 | - Added 16 | - Changed 17 | - Fixed 18 | - Deprecated 19 | - Removed 20 | 21 | 0.1.1 (unreleased) 22 | ------------------ 23 | 24 | *Compatible with PyTorch 1.0.0 - 1.7.0* 25 | 26 | 0.1.0 (September 9, 2020) 27 | ------------------------- 28 | 29 | *Compatible with PyTorch 1.0.0 - 1.7.0* 30 | 31 | Initial release 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. Automatically generated by nengo-bones, do not edit this file directly 2 | 3 | ****************************** 4 | Contributing to PyTorchSpiking 5 | ****************************** 6 | 7 | Issues and pull requests are always welcome! 8 | We appreciate help from the community to make PyTorchSpiking better. 9 | 10 | Filing issues 11 | ============= 12 | 13 | If you find a bug in PyTorchSpiking, 14 | or think that a certain feature is missing, 15 | please consider 16 | `filing an issue `_! 17 | Please search the currently open issues first 18 | to see if your bug or feature request already exists. 19 | If so, feel free to add a comment to the issue 20 | so that we know that multiple people are affected. 
21 | 22 | Making pull requests 23 | ==================== 24 | 25 | If you want to fix a bug or add a feature to PyTorchSpiking, 26 | we welcome pull requests. 27 | Ensure that you fill out all sections of the pull request template, 28 | deleting the comments as you go. 29 | 30 | Contributor agreement 31 | ===================== 32 | 33 | We require that all contributions be covered under 34 | our contributor assignment agreement. Please see 35 | `the agreement `_ 36 | for instructions on how to sign. 37 | 38 | More details 39 | ============ 40 | 41 | For more details on how to contribute to Nengo, 42 | please see the `developer guide `_. 43 | -------------------------------------------------------------------------------- /CONTRIBUTORS.rst: -------------------------------------------------------------------------------- 1 | .. Automatically generated by nengo-bones, do not edit this file directly 2 | 3 | *************************** 4 | PyTorchSpiking contributors 5 | *************************** 6 | 7 | See https://github.com/nengo/pytorch-spiking/graphs/contributors 8 | for a list of the people who have committed to PyTorchSpiking. 9 | Thank you for your contributions! 10 | 11 | For the full list of the many contributors to the Nengo ecosystem, 12 | see https://www.nengo.ai/people/. 13 | -------------------------------------------------------------------------------- /LICENSE.rst: -------------------------------------------------------------------------------- 1 | .. Automatically generated by nengo-bones, do not edit this file directly 2 | 3 | ********************** 4 | PyTorchSpiking license 5 | ********************** 6 | 7 | Copyright (c) 2020-2023 Applied Brain Research 8 | 9 | **ABR License** 10 | 11 | PyTorchSpiking is made available under a proprietary license, the 12 | "ABR TECHNOLOGY LICENSE AND USE AGREEMENT" (the "ABR License"). 13 | The main ABR License file is available for download at 14 | ``_. 
15 | The entire contents of this ``LICENSE.rst`` file, including any 16 | terms and conditions herein, form part of the ABR License. 17 | 18 | Commercial Use Licenses are available to purchase for a yearly fee. 19 | Academic and Personal Use Licenses for PyTorchSpiking are available at 20 | no cost. 21 | Both types of licences can be obtained from the 22 | ABR store at ``_. 23 | 24 | If you have any sales questions, 25 | please contact ``_. 26 | If you have any technical support questions, please post them on the ABR 27 | community forums at ``_ or contact 28 | ``_. 29 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Automatically generated by nengo-bones, do not edit this file directly 2 | 3 | global-include *.py 4 | global-include *.sh 5 | global-include *.template 6 | include *.rst 7 | 8 | # Include files for CI and recreating the source dist 9 | include *.yml 10 | include *.yaml 11 | include *.toml 12 | include MANIFEST.in 13 | include .gitlint 14 | include .pylintrc 15 | 16 | # Directories to include 17 | graft docs 18 | 19 | # Subdirectories to exclude, if they exist 20 | prune docs/_build 21 | prune dist 22 | prune .git 23 | prune .github 24 | prune .tox 25 | prune .eggs 26 | prune .ci 27 | prune bones-scripts 28 | 29 | # Exclude auto-generated files 30 | recursive-exclude docs *.py 31 | 32 | # Patterns to exclude from any directory 33 | global-exclude *.ipynb_checkpoints* 34 | global-exclude *-checkpoint.ipynb 35 | 36 | # Exclude all bytecode 37 | global-exclude *.pyc *.pyo *.pyd 38 | 39 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. image:: https://img.shields.io/pypi/v/pytorch-spiking.svg 2 | :target: https://pypi.org/project/pytorch-spiking 3 | :alt: Latest PyPI version 4 | 5 | .. 
image:: https://img.shields.io/pypi/pyversions/pytorch-spiking.svg 6 | :target: https://pypi.org/project/pytorch-spiking 7 | :alt: Python versions 8 | 9 | ************** 10 | PyTorchSpiking 11 | ************** 12 | 13 | PyTorchSpiking provides tools for training and running spiking neural networks 14 | directly within the PyTorch framework. The main feature is 15 | ``pytorch_spiking.SpikingActivation``, which can be used to transform 16 | any activation function into a spiking equivalent. For example, we can translate a 17 | non-spiking model, such as 18 | 19 | .. code-block:: python 20 | 21 | torch.nn.Sequential( 22 | torch.nn.Linear(5, 10), 23 | torch.nn.ReLU(), 24 | ) 25 | 26 | into the spiking equivalent: 27 | 28 | .. code-block:: python 29 | 30 | torch.nn.Sequential( 31 | torch.nn.Linear(5, 10), 32 | pytorch_spiking.SpikingActivation(torch.nn.ReLU()), 33 | ) 34 | 35 | Models with SpikingActivation layers can be optimized and evaluated in the same way as 36 | any other PyTorch model. They will automatically take advantage of PyTorchSpiking's 37 | "spiking aware training": using the spiking activations on the forward pass and the 38 | non-spiking (differentiable) activation function on the backwards pass. 39 | 40 | PyTorchSpiking also includes various tools to assist in the training of spiking models, 41 | such as `filtering layers 42 | `_. 43 | 44 | If you are interested in building and optimizing spiking neuron models, you may also 45 | be interested in `NengoDL `_. See 46 | `this page `_ for a 47 | comparison of the different use cases supported by these two packages. 
48 | 49 | **Documentation** 50 | 51 | Check out the `documentation `_ for 52 | 53 | - `Installation instructions 54 | `_ 55 | - `More detailed example introducing the features of PyTorchSpiking 56 | `_ 57 | - `API reference `_ 58 | -------------------------------------------------------------------------------- /docs/_static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengo/pytorch-spiking/e42d24c0955c1b1ef12ed9609d74941126f30e46/docs/_static/favicon.ico -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Automatically generated by nengo-bones, do not edit this file directly 4 | 5 | import pathlib 6 | 7 | import pytorch_spiking 8 | 9 | extensions = [ 10 | "sphinx.ext.autodoc", 11 | "sphinx.ext.autosummary", 12 | "sphinx.ext.doctest", 13 | "sphinx.ext.githubpages", 14 | "sphinx.ext.intersphinx", 15 | "sphinx.ext.mathjax", 16 | "sphinx.ext.todo", 17 | "nbsphinx", 18 | "nengo_sphinx_theme", 19 | "nengo_sphinx_theme.ext.backoff", 20 | "nengo_sphinx_theme.ext.sourcelinks", 21 | "notfound.extension", 22 | "numpydoc", 23 | "nengo_sphinx_theme.ext.autoautosummary", 24 | ] 25 | 26 | # -- sphinx.ext.autodoc 27 | autoclass_content = "both" # class and __init__ docstrings are concatenated 28 | autodoc_default_options = {"members": None} 29 | autodoc_member_order = "bysource" # default is alphabetical 30 | 31 | # -- sphinx.ext.doctest 32 | doctest_global_setup = """ 33 | import pytorch_spiking 34 | import numpy as np 35 | import torch 36 | """ 37 | 38 | # -- sphinx.ext.intersphinx 39 | intersphinx_mapping = { 40 | "nengo": ("https://www.nengo.ai/nengo/", None), 41 | "numpy": ("https://numpy.org/doc/stable", None), 42 | "python": ("https://docs.python.org/3", None), 43 | } 44 | 45 | # -- sphinx.ext.todo 46 | todo_include_todos = True 
47 | 48 | # -- nbsphinx 49 | nbsphinx_timeout = -1 50 | 51 | # -- notfound.extension 52 | notfound_template = "404.html" 53 | notfound_urls_prefix = "/pytorch-spiking/" 54 | 55 | # -- numpydoc config 56 | numpydoc_show_class_members = False 57 | 58 | # -- nengo_sphinx_theme.ext.autoautosummary 59 | autoautosummary_change_modules = { 60 | "pytorch_spiking": [ 61 | "pytorch_spiking.modules.SpikingActivation", 62 | "pytorch_spiking.modules.Lowpass", 63 | "pytorch_spiking.modules.TemporalAvgPool", 64 | ], 65 | } 66 | 67 | # -- nengo_sphinx_theme.ext.sourcelinks 68 | sourcelinks_module = "pytorch_spiking" 69 | sourcelinks_url = "https://github.com/nengo/pytorch-spiking" 70 | 71 | # -- sphinx 72 | nitpicky = True 73 | exclude_patterns = [ 74 | "_build", 75 | "**/.ipynb_checkpoints", 76 | ] 77 | linkcheck_timeout = 30 78 | source_suffix = ".rst" 79 | source_encoding = "utf-8" 80 | master_doc = "index" 81 | linkcheck_ignore = [r"http://localhost:\d+"] 82 | linkcheck_anchors = True 83 | default_role = "py:obj" 84 | pygments_style = "sphinx" 85 | user_agent = "pytorch_spiking" 86 | 87 | project = "PyTorchSpiking" 88 | authors = "Applied Brain Research" 89 | copyright = "2020-2023 Applied Brain Research" 90 | version = ".".join(pytorch_spiking.__version__.split(".")[:2]) # Short X.Y version 91 | release = pytorch_spiking.__version__ # Full version, with tags 92 | 93 | # -- HTML output 94 | templates_path = ["_templates"] 95 | html_static_path = ["_static"] 96 | html_theme = "nengo_sphinx_theme" 97 | html_title = f"PyTorchSpiking {release} docs" 98 | htmlhelp_basename = "PyTorchSpiking" 99 | html_last_updated_fmt = "" # Default output format (suppressed) 100 | html_show_sphinx = False 101 | html_favicon = str(pathlib.Path("_static", "favicon.ico")) 102 | html_theme_options = { 103 | "nengo_logo": "", 104 | "nengo_logo_color": "#a8acaf", 105 | "analytics": """ 106 | 107 | 108 | 114 | 115 | 116 | 133 | 134 | """, 135 | } 136 | 
-------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | These examples can be found in the ``/docs/examples`` directory 5 | (where ```` is the location of the PyTorchSpiking package). The 6 | examples are Jupyter notebooks; if you would like to run them yourself, refer to 7 | the `Jupyter documentation `_. 8 | 9 | .. toctree:: 10 | 11 | examples/spiking-fashion-mnist 12 | -------------------------------------------------------------------------------- /docs/examples/spiking-fashion-mnist.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Classifying Fashion MNIST with spiking activations\n", 8 | "\n", 9 | "In this example we assume that you are already familiar with building and training\n", 10 | "standard, non-spiking neural networks in PyTorch. We would recommend checking out the\n", 11 | "[PyTorch\n", 12 | "documentation](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html)\n", 13 | "if you would like a more basic introduction to how PyTorch works. In this example we\n", 14 | "will walk through how we can convert a non-spiking model into a spiking model using\n", 15 | "PyTorchSpiking, and various techniques that can be used to fine tune performance." 
16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "# pylint: disable=redefined-outer-name\n", 25 | "\n", 26 | "import matplotlib.pyplot as plt\n", 27 | "import numpy as np\n", 28 | "import torch\n", 29 | "import torchvision\n", 30 | "\n", 31 | "import pytorch_spiking\n", 32 | "\n", 33 | "torch.manual_seed(0)\n", 34 | "np.random.seed(0)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Loading data\n", 42 | "\n", 43 | "We'll begin by loading the Fashion MNIST data:" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "train_images, train_labels = zip(\n", 53 | " *torchvision.datasets.FashionMNIST(\".\", train=True, download=True)\n", 54 | ")\n", 55 | "train_images = np.asarray([np.array(img) for img in train_images], dtype=np.float32)\n", 56 | "train_labels = np.asarray(train_labels, dtype=np.int64)\n", 57 | "test_images, test_labels = zip(\n", 58 | " *torchvision.datasets.FashionMNIST(\".\", train=False, download=True)\n", 59 | ")\n", 60 | "test_images = np.asarray([np.array(img) for img in train_images], dtype=np.float32)\n", 61 | "test_labels = np.asarray(train_labels, dtype=np.int64)\n", 62 | "\n", 63 | "# normalize images so values are between 0 and 1\n", 64 | "train_images = train_images / 255.0\n", 65 | "test_images = test_images / 255.0\n", 66 | "\n", 67 | "class_names = [\n", 68 | " \"T-shirt/top\",\n", 69 | " \"Trouser\",\n", 70 | " \"Pullover\",\n", 71 | " \"Dress\",\n", 72 | " \"Coat\",\n", 73 | " \"Sandal\",\n", 74 | " \"Shirt\",\n", 75 | " \"Sneaker\",\n", 76 | " \"Bag\",\n", 77 | " \"Ankle boot\",\n", 78 | "]\n", 79 | "num_classes = len(class_names)\n", 80 | "\n", 81 | "plt.figure(figsize=(10, 10))\n", 82 | "for i in range(25):\n", 83 | " plt.subplot(5, 5, i + 1)\n", 84 | " plt.imshow(train_images[i], cmap=plt.cm.binary)\n", 
85 | " plt.axis(\"off\")\n", 86 | " plt.title(class_names[train_labels[i]])" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "## Non-spiking model\n", 94 | "\n", 95 | "Next we'll build and train a simple non-spiking model to classify the Fashion MNIST\n", 96 | "images." 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": null, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "model = torch.nn.Sequential(\n", 106 | " torch.nn.Linear(784, 128),\n", 107 | " torch.nn.ReLU(),\n", 108 | " torch.nn.Linear(128, 10),\n", 109 | ")\n", 110 | "\n", 111 | "\n", 112 | "def train(input_model, train_x, test_x):\n", 113 | " minibatch_size = 32\n", 114 | " optimizer = torch.optim.Adam(input_model.parameters())\n", 115 | "\n", 116 | " input_model.train()\n", 117 | " for j in range(10):\n", 118 | " train_acc = 0\n", 119 | " for i in range(train_x.shape[0] // minibatch_size):\n", 120 | " input_model.zero_grad()\n", 121 | "\n", 122 | " batch_in = train_x[i * minibatch_size : (i + 1) * minibatch_size]\n", 123 | " # flatten images\n", 124 | " batch_in = batch_in.reshape((-1,) + train_x.shape[1:-2] + (784,))\n", 125 | " batch_label = train_labels[i * minibatch_size : (i + 1) * minibatch_size]\n", 126 | " output = input_model(torch.tensor(batch_in))\n", 127 | "\n", 128 | " # compute sparse categorical cross entropy loss\n", 129 | " logp = torch.nn.functional.log_softmax(output, dim=-1)\n", 130 | " logpy = torch.gather(logp, 1, torch.tensor(batch_label).view(-1, 1))\n", 131 | " loss = -logpy.mean()\n", 132 | "\n", 133 | " loss.backward()\n", 134 | " optimizer.step()\n", 135 | "\n", 136 | " train_acc += torch.mean(\n", 137 | " torch.eq(torch.argmax(output, dim=1), torch.tensor(batch_label)).float()\n", 138 | " )\n", 139 | "\n", 140 | " train_acc /= i + 1\n", 141 | " print(f\"Train accuracy ({j}): {train_acc.numpy()}\")\n", 142 | "\n", 143 | " # compute test accuracy\n", 144 | " input_model.eval()\n", 145 | 
" test_acc = 0\n", 146 | " for i in range(test_x.shape[0] // minibatch_size):\n", 147 | " batch_in = test_x[i * minibatch_size : (i + 1) * minibatch_size]\n", 148 | " batch_in = batch_in.reshape((-1,) + test_x.shape[1:-2] + (784,))\n", 149 | " batch_label = test_labels[i * minibatch_size : (i + 1) * minibatch_size]\n", 150 | " output = input_model(torch.tensor(batch_in))\n", 151 | "\n", 152 | " test_acc += torch.mean(\n", 153 | " torch.eq(torch.argmax(output, dim=1), torch.tensor(batch_label)).float()\n", 154 | " )\n", 155 | "\n", 156 | " test_acc /= i + 1\n", 157 | "\n", 158 | " print(f\"Test accuracy: {test_acc.numpy()}\")\n", 159 | "\n", 160 | "\n", 161 | "train(model, train_images, test_images)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "## Spiking model\n", 169 | "\n", 170 | "Next we will create an equivalent spiking model. There are three important changes here:\n", 171 | "\n", 172 | "1. Add a temporal dimension to the data/model.\n", 173 | "\n", 174 | "Spiking models always run over time (i.e., each forward pass through the model will run\n", 175 | "for some number of timesteps). This means that we need to add a temporal dimension to\n", 176 | "the data, so instead of having shape `(batch_size, ...)` it will have shape\n", 177 | "`(batch_size, n_steps, ...)`. For those familiar with working with RNNs, the principles\n", 178 | "are the same; a spiking neuron accepts temporal data and computes over time, just like\n", 179 | "an RNN.\n", 180 | "\n", 181 | "2. Replace any activation functions with `pytorch_spiking.SpikingActivation`.\n", 182 | "\n", 183 | "`pytorch_spiking.SpikingActivation` can encapsulate any activation function, and will\n", 184 | "produce an equivalent spiking implementation. Neurons will spike at a rate proportional\n", 185 | "to the output of the base activation function. 
For example, if the activation function\n", 186 | "is outputting a value of 10, then the wrapped `SpikingActivation` will output spikes at\n", 187 | "a rate of 10Hz (i.e., 10 spikes per 1 simulated second, where 1 simulated second is\n", 188 | "equivalent to some number of timesteps, determined by the `dt` parameter of\n", 189 | "`SpikingActivation`).\n", 190 | "\n", 191 | "3. Pool across time\n", 192 | "\n", 193 | "The output of our `pytorch_spiking.SpikingActivation` layer is also a timeseries. For\n", 194 | "classification, we need to aggregate that temporal information somehow to generate a\n", 195 | "final prediction. Averaging the output over time is usually a good approach (but not the\n", 196 | "only method; we could also, e.g., look at the output on the last timestep or the time to\n", 197 | "first spike). We add a `pytorch_spiking.TemporalAvgPool` layer to average across the\n", 198 | "temporal dimension of the data." 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "metadata": {}, 205 | "outputs": [], 206 | "source": [ 207 | "# repeat the images for n_steps\n", 208 | "n_steps = 10\n", 209 | "train_sequences = np.tile(train_images[:, None], (1, n_steps, 1, 1))\n", 210 | "test_sequences = np.tile(test_images[:, None], (1, n_steps, 1, 1))" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": null, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "spiking_model = torch.nn.Sequential(\n", 220 | " torch.nn.Linear(784, 128),\n", 221 | " # wrap ReLU in SpikingActivation\n", 222 | " pytorch_spiking.SpikingActivation(torch.nn.ReLU(), spiking_aware_training=False),\n", 223 | " # use average pooling layer to average spiking output over time\n", 224 | " pytorch_spiking.TemporalAvgPool(),\n", 225 | " torch.nn.Linear(128, 10),\n", 226 | ")\n", 227 | "\n", 228 | "# train the model, identically to the non-spiking version,\n", 229 | "# except using the time sequences as input\n", 230 | 
"train(spiking_model, train_sequences, test_sequences)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": {}, 236 | "source": [ 237 | "We can see that while the training accuracy is as good as we expect, the test accuracy\n", 238 | "is not. This is due to a unique feature of `SpikingActivation`; it will automatically\n", 239 | "swap the behaviour of the spiking neurons during training. Because spiking neurons are\n", 240 | "(in general) not differentiable, we cannot directly use the spiking activation function\n", 241 | "during training. Instead, SpikingActivation will use the base (non-spiking) activation\n", 242 | "during training, and the spiking version during inference. So during training above we\n", 243 | "are seeing the performance of the non-spiking model, but during evaluation we are seeing\n", 244 | "the performance of the spiking model.\n", 245 | "\n", 246 | "So the question is, why is the performance of the spiking model so much worse than the\n", 247 | "non-spiking equivalent, and what can we do to fix that?" 248 | ] 249 | }, 250 | { 251 | "cell_type": "markdown", 252 | "metadata": {}, 253 | "source": [ 254 | "## Simulation time\n", 255 | "\n", 256 | "Let's visualize the output of the spiking model, to get a better sense of what is going\n", 257 | "on." 
258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "def check_output(seq_model, modify_dt=None): # noqa: C901\n", 267 | " \"\"\"\n", 268 | " This code is only used for plotting purposes, and isn't necessary to\n", 269 | " understand the rest of this example; feel free to skip it\n", 270 | " if you just want to see the results.\n", 271 | " \"\"\"\n", 272 | "\n", 273 | " # rebuild the model in a form that will let us access the output of\n", 274 | " # intermediate layers\n", 275 | " class Model(torch.nn.Module):\n", 276 | " def __init__(self):\n", 277 | " super().__init__()\n", 278 | "\n", 279 | " self.has_temporal_pooling = False\n", 280 | " for i, module in enumerate(seq_model.modules()):\n", 281 | " if isinstance(module, pytorch_spiking.TemporalAvgPool):\n", 282 | " # remove the pooling so that we can see the model's output over time\n", 283 | " self.has_temporal_pooling = True\n", 284 | " continue\n", 285 | "\n", 286 | " if isinstance(\n", 287 | " module, (pytorch_spiking.SpikingActivation, pytorch_spiking.Lowpass)\n", 288 | " ):\n", 289 | " # update dt, if specified\n", 290 | " if modify_dt is not None:\n", 291 | " module.dt = modify_dt\n", 292 | " # always return the full time series so we can visualize it\n", 293 | " module.return_sequences = True\n", 294 | "\n", 295 | " if isinstance(module, pytorch_spiking.SpikingActivation):\n", 296 | " # save this layer so we can access it later\n", 297 | " self.spike_layer = module\n", 298 | "\n", 299 | " if i > 0:\n", 300 | " self.add_module(str(i), module)\n", 301 | "\n", 302 | " def forward(self, inputs):\n", 303 | " x = inputs\n", 304 | "\n", 305 | " for i, module in enumerate(self.modules()):\n", 306 | " if i > 0:\n", 307 | " x = module(x)\n", 308 | "\n", 309 | " if isinstance(module, pytorch_spiking.SpikingActivation):\n", 310 | " # save this layer so we can access it later\n", 311 | " spike_output = x\n", 312 | "\n", 
313 | " return x, spike_output\n", 314 | "\n", 315 | " func_model = Model()\n", 316 | "\n", 317 | " # run model\n", 318 | " func_model.eval()\n", 319 | " with torch.no_grad():\n", 320 | " output, spikes = func_model(\n", 321 | " torch.tensor(\n", 322 | " test_sequences.reshape(\n", 323 | " test_sequences.shape[0], test_sequences.shape[1], -1\n", 324 | " )\n", 325 | " )\n", 326 | " )\n", 327 | " output = output.numpy()\n", 328 | " spikes = spikes.numpy()\n", 329 | "\n", 330 | " if func_model.has_temporal_pooling:\n", 331 | " # check test accuracy using average output over all timesteps\n", 332 | " predictions = np.argmax(output.mean(axis=1), axis=-1)\n", 333 | " else:\n", 334 | " # check test accuracy using output from last timestep\n", 335 | " predictions = np.argmax(output[:, -1], axis=-1)\n", 336 | " accuracy = np.equal(predictions, test_labels).mean()\n", 337 | " print(f\"Test accuracy: {100 * accuracy:.2f}\")\n", 338 | "\n", 339 | " time = test_sequences.shape[1] * func_model.spike_layer.dt\n", 340 | " n_spikes = spikes * func_model.spike_layer.dt\n", 341 | " rates = np.sum(n_spikes, axis=1) / time\n", 342 | "\n", 343 | " print(\n", 344 | " f\"Spike rate per neuron (Hz): \"\n", 345 | " f\"min={np.min(rates):.2f} mean={np.mean(rates):.2f} max={np.max(rates):.2f}\"\n", 346 | " )\n", 347 | "\n", 348 | " # plot output\n", 349 | " for ii in range(4):\n", 350 | " plt.figure(figsize=(12, 4))\n", 351 | "\n", 352 | " plt.subplot(1, 3, 1)\n", 353 | " plt.title(class_names[test_labels[ii]])\n", 354 | " plt.imshow(test_images[ii], cmap=\"gray\")\n", 355 | " plt.axis(\"off\")\n", 356 | "\n", 357 | " plt.subplot(1, 3, 2)\n", 358 | " plt.title(\"Spikes per neuron per timestep\")\n", 359 | " bin_edges = np.arange(int(np.max(n_spikes[ii])) + 2) - 0.5\n", 360 | " plt.hist(np.ravel(n_spikes[ii]), bins=bin_edges)\n", 361 | " x_ticks = plt.xticks()[0]\n", 362 | " plt.xticks(\n", 363 | " x_ticks[(np.abs(x_ticks - np.round(x_ticks)) < 1e-8) & (x_ticks > -1e-8)]\n", 364 | " )\n", 365 
| " plt.xlabel(\"# of spikes\")\n", 366 | " plt.ylabel(\"Frequency\")\n", 367 | "\n", 368 | " plt.subplot(1, 3, 3)\n", 369 | " plt.title(\"Output predictions\")\n", 370 | " plt.plot(\n", 371 | " np.arange(test_sequences.shape[1]) * func_model.spike_layer.dt,\n", 372 | " torch.softmax(torch.tensor(output[ii]), dim=-1),\n", 373 | " )\n", 374 | " plt.legend(class_names, loc=\"upper left\")\n", 375 | " plt.xlabel(\"Time (s)\")\n", 376 | " plt.ylabel(\"Probability\")\n", 377 | " plt.ylim([-0.05, 1.05])\n", 378 | "\n", 379 | " plt.tight_layout()" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": null, 385 | "metadata": {}, 386 | "outputs": [], 387 | "source": [ 388 | "check_output(spiking_model)" 389 | ] 390 | }, 391 | { 392 | "cell_type": "markdown", 393 | "metadata": {}, 394 | "source": [ 395 | "We can see an immediate problem: the neurons are hardly spiking at all. The mean number\n", 396 | "of spikes we're getting out of each neuron in our SpikingActivation layer is very close\n", 397 | "to zero, and as a result the output is mostly flat.\n", 398 | "\n", 399 | "To help understand why, we need to think more about the temporal nature of spiking\n", 400 | "neurons. Recall that the layer is set up such that if the base activation function were\n", 401 | "to be outputting a value of 1, the spiking equivalent would be spiking at 1Hz (i.e.,\n", 402 | "emitting one spike per second). In the above example we are simulating for 10 timesteps,\n", 403 | "with the default `dt` of 0.001s, so we're simulating a total of 0.01s. If our neurons\n", 404 | "aren't spiking very rapidly, and we're only simulating for 0.01s, then it's not\n", 405 | "surprising that we aren't getting any spikes in that time window.\n", 406 | "\n", 407 | "We can increase the value of `dt`, effectively running the spiking neurons for longer,\n", 408 | "in order to get a more accurate measure of the neuron's output. 
Basically this allows us\n", 409 | "to collect more spikes from each neuron, giving us a better estimate of the neuron's\n", 410 | "actual spike rate. We can see how the number of spikes and accuracy change as we\n", 411 | "increase `dt`:" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "# dt=0.01 * 10 timesteps is equivalent to 0.1s of simulated time\n", 421 | "check_output(spiking_model, modify_dt=0.01)" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": null, 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "check_output(spiking_model, modify_dt=0.1)" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "metadata": {}, 437 | "outputs": [], 438 | "source": [ 439 | "check_output(spiking_model, modify_dt=1)" 440 | ] 441 | }, 442 | { 443 | "cell_type": "markdown", 444 | "metadata": {}, 445 | "source": [ 446 | "We can see that as we increase `dt` the performance of the spiking model increasingly\n", 447 | "approaches the non-spiking performance. In addition, as `dt` increases, the number of\n", 448 | "spikes is increasing. To understand why this improves accuracy, keep in mind that\n", 449 | "although the simulated time is increasing, the actual number of timesteps is still 10 in\n", 450 | "all cases. We're effectively binning all the spikes that occur on each time step. So as\n", 451 | "our bin sizes get larger (increasing `dt`), the spike counts will more closely\n", 452 | "approximate the \"true\" output of the underlying non-spiking activation function.\n", 453 | "\n", 454 | "One might be tempted to simply increase `dt` to a very large value, and thereby always\n", 455 | "get great performance. 
But keep in mind that when we do that we have likely lost any of\n", 456 | "the advantages that were motivating us to investigate spiking models in the first place.\n", 457 | "For example, one prominent advantage of spiking models is temporal sparsity (we only\n", 458 | "need to communicate occasional spikes, rather than continuous values). However, with\n", 459 | "large `dt` the neurons are likely spiking every simulation time step (or multiple times\n", 460 | "per timestep), so the activity is no longer temporally sparse.\n", 461 | "\n", 462 | "Thus setting `dt` represents a trade-off between accuracy and temporal sparsity.\n", 463 | "Choosing the appropriate value will depend on the demands of your application." 464 | ] 465 | }, 466 | { 467 | "cell_type": "markdown", 468 | "metadata": {}, 469 | "source": [ 470 | "## Spiking aware training\n", 471 | "\n", 472 | "As mentioned above, by default SpikingActivation layers will use the non-spiking\n", 473 | "activation function during training and the spiking version during inference. However,\n", 474 | "similar to the idea of\n", 475 | "[quantization aware\n", 476 | "training](https://www.tensorflow.org/model_optimization/guide/quantization/training),\n", 477 | "often we can improve performance by partially incorporating spiking behaviour during\n", 478 | "training. Specifically, we will use the spiking activation on the forward pass, while\n", 479 | "still using the non-spiking version on the backwards pass. This allows the model to\n", 480 | "learn weights that account for the discrete, temporal nature of the spiking activities." 
481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [ 489 | "spikeaware_model = torch.nn.Sequential(\n", 490 | " torch.nn.Linear(784, 128),\n", 491 | " # set spiking_aware_training and a moderate dt\n", 492 | " pytorch_spiking.SpikingActivation(\n", 493 | " torch.nn.ReLU(), dt=0.01, spiking_aware_training=True\n", 494 | " ),\n", 495 | " pytorch_spiking.TemporalAvgPool(),\n", 496 | " torch.nn.Linear(128, 10),\n", 497 | ")\n", 498 | "\n", 499 | "train(spikeaware_model, train_sequences, test_sequences)" 500 | ] 501 | }, 502 | { 503 | "cell_type": "code", 504 | "execution_count": null, 505 | "metadata": {}, 506 | "outputs": [], 507 | "source": [ 508 | "check_output(spikeaware_model)" 509 | ] 510 | }, 511 | { 512 | "cell_type": "markdown", 513 | "metadata": {}, 514 | "source": [ 515 | "We can see that with `spiking_aware_training` we're getting better performance than we\n", 516 | "were with the equivalent `dt` value above. The model has learned weights that are less\n", 517 | "sensitive to the discrete, sparse output produced by the spiking neurons." 518 | ] 519 | }, 520 | { 521 | "cell_type": "markdown", 522 | "metadata": {}, 523 | "source": [ 524 | "## Spike rate regularization\n", 525 | "\n", 526 | "As we saw in the [Simulation time section](#Simulation-time), the spiking rate of the\n", 527 | "neurons is very important. If a neuron is spiking too slowly then we don't have enough\n", 528 | "information to determine its output value. Conversely, if a neuron is spiking too\n", 529 | "quickly then we may lose the spiking advantages we are looking for, such as temporal\n", 530 | "sparsity.\n", 531 | "\n", 532 | "Thus it can be helpful to more directly control the firing rates in the model by\n", 533 | "applying regularization penalties during training. For example, we could add an L2\n", 534 | "penalty to the output of the spiking activation layer." 
535 | ] 536 | }, 537 | { 538 | "cell_type": "code", 539 | "execution_count": null, 540 | "metadata": {}, 541 | "outputs": [], 542 | "source": [ 543 | "# construct model using a generic Module so that we can\n", 544 | "# access the spiking activations in our loss function\n", 545 | "class Model(torch.nn.Module):\n", 546 | " def __init__(self):\n", 547 | " super().__init__()\n", 548 | "\n", 549 | " self.dense0 = torch.nn.Linear(784, 128)\n", 550 | " self.spiking_activation = pytorch_spiking.SpikingActivation(\n", 551 | " torch.nn.ReLU(), dt=0.01, spiking_aware_training=True\n", 552 | " )\n", 553 | " self.temporal_pooling = pytorch_spiking.TemporalAvgPool()\n", 554 | " self.dense1 = torch.nn.Linear(128, 10)\n", 555 | "\n", 556 | " def forward(self, inputs):\n", 557 | " x = self.dense0(inputs)\n", 558 | " spikes = self.spiking_activation(x)\n", 559 | " spike_rates = self.temporal_pooling(spikes)\n", 560 | " output = self.dense1(spike_rates)\n", 561 | "\n", 562 | " return output, spike_rates\n", 563 | "\n", 564 | "\n", 565 | "regularized_model = Model()\n", 566 | "\n", 567 | "minibatch_size = 32\n", 568 | "optimizer = torch.optim.Adam(regularized_model.parameters())\n", 569 | "\n", 570 | "regularized_model.train()\n", 571 | "for j in range(10):\n", 572 | " train_acc = 0\n", 573 | " for i in range(train_sequences.shape[0] // minibatch_size):\n", 574 | " regularized_model.zero_grad()\n", 575 | "\n", 576 | " batch_in = train_sequences[i * minibatch_size : (i + 1) * minibatch_size]\n", 577 | " batch_in = batch_in.reshape((-1,) + train_sequences.shape[1:-2] + (784,))\n", 578 | " batch_label = train_labels[i * minibatch_size : (i + 1) * minibatch_size]\n", 579 | " output, spike_rates = regularized_model(torch.tensor(batch_in))\n", 580 | "\n", 581 | " # compute sparse categorical cross entropy loss\n", 582 | " logp = torch.nn.functional.log_softmax(output, dim=-1)\n", 583 | " logpy = torch.gather(logp, 1, torch.tensor(batch_label).view(-1, 1))\n", 584 | " loss = 
-logpy.mean()\n", 585 | "\n", 586 | " # add activity regularization\n", 587 | " reg_weight = 1e-3 # weight on regularization penalty\n", 588 | " target_rate = 20 # target spike rate (in Hz)\n", 589 | " loss += reg_weight * torch.mean(\n", 590 | " torch.sum((spike_rates - target_rate) ** 2, dim=-1)\n", 591 | " )\n", 592 | "\n", 593 | " loss.backward()\n", 594 | " optimizer.step()\n", 595 | "\n", 596 | " train_acc += torch.mean(\n", 597 | " torch.eq(torch.argmax(output, dim=1), torch.tensor(batch_label)).float()\n", 598 | " )\n", 599 | "\n", 600 | " train_acc /= i + 1\n", 601 | " print(f\"Train accuracy ({j}): {train_acc.numpy()}\")" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": null, 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [ 610 | "check_output(regularized_model)" 611 | ] 612 | }, 613 | { 614 | "cell_type": "markdown", 615 | "metadata": {}, 616 | "source": [ 617 | "We can see that the spike rates have moved towards the 20 Hz target we specified.\n", 618 | "However, the test accuracy has dropped, since we're adding an additional optimization\n", 619 | "constraint. (The accuracy is still higher than the original result with `dt=0.01`, due\n", 620 | "to the higher spike rates.) We could lower the regularization weight to allow more\n", 621 | "freedom in the firing rates. Again, this is a tradeoff that is made between controlling\n", 622 | "the firing rates and optimizing accuracy, and the best value for that tradeoff will\n", 623 | "depend on the particular application (e.g., how important is it that spike rates fall\n", 624 | "within a particular range?)." 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "metadata": {}, 630 | "source": [ 631 | "## Lowpass filtering\n", 632 | "\n", 633 | "Another tool we can employ when working with SpikingActivation layers is filtering. As\n", 634 | "we've seen, the output of a spiking layer consists of discrete, temporally sparse spike\n", 635 | "events. 
This makes it difficult to determine the spike rate of a neuron when just\n", 636 | "looking at a single timestep. In the cases above we have worked around this by using a\n", 637 | "`TemporalAveragePooling` layer to average the output across all timesteps before\n", 638 | "classification.\n", 639 | "\n", 640 | "Another way to achieve this is to compute some kind of moving average of the spiking\n", 641 | "output across timesteps. This is effectively what filtering is doing. PyTorchSpiking\n", 642 | "contains a Lowpass layer, which implements a\n", 643 | "[lowpass filter](https://en.wikipedia.org/wiki/Low-pass_filter). This has a parameter\n", 644 | "`tau`, known as the filter time constant, which controls the degree of smoothing the\n", 645 | "layer will apply. Larger `tau` values will apply more smoothing, meaning that we're\n", 646 | "aggregating information across longer periods of time, but the output will also be\n", 647 | "slower to adapt to changes in the input.\n", 648 | "\n", 649 | "By default the `tau` values are trainable. We can use this in combination with spiking\n", 650 | "aware training to enable the model to learn time constants that best trade off spike\n", 651 | "noise versus response speed.\n", 652 | "\n", 653 | "Unlike `pytorch_spiking.TemporalAvgPool`, `pytorch_spiking.Lowpass` computes outputs for\n", 654 | "all timesteps by default. This makes it possible to apply filtering throughout the\n", 655 | "model—not only on the final layer—in the case that there are multiple spiking layers.\n", 656 | "For the final layer, we can pass `return_sequences=False` to have the layer only return\n", 657 | "the output of the final timestep, rather than the outputs of all timesteps." 
658 | ] 659 | }, 660 | { 661 | "cell_type": "code", 662 | "execution_count": null, 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "dt = 0.01\n", 667 | "\n", 668 | "filtered_model = torch.nn.Sequential(\n", 669 | " torch.nn.Linear(784, 128),\n", 670 | " pytorch_spiking.SpikingActivation(\n", 671 | " torch.nn.ReLU(), spiking_aware_training=True, dt=dt\n", 672 | " ),\n", 673 | " # add a lowpass filter on output of spiking layer\n", 674 | " # note: the lowpass dt doesn't necessarily need to be the same as the\n", 675 | " # SpikingActivation dt, but it's probably a good idea to keep them in sync\n", 676 | " # so that if we change dt the relative effect of the lowpass filter is unchanged\n", 677 | " pytorch_spiking.Lowpass(units=128, tau=0.1, dt=dt, return_sequences=False),\n", 678 | " torch.nn.Linear(128, 10),\n", 679 | ")\n", 680 | "\n", 681 | "train(filtered_model, train_sequences, test_sequences)" 682 | ] 683 | }, 684 | { 685 | "cell_type": "code", 686 | "execution_count": null, 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "check_output(filtered_model)" 691 | ] 692 | }, 693 | { 694 | "cell_type": "markdown", 695 | "metadata": {}, 696 | "source": [ 697 | "We can see that the model performs similarly to the previous\n", 698 | "[spiking aware training](#Spiking-aware-training) example, which makes sense since, for\n", 699 | "a static input image, a moving average is very similar to a global average. We would\n", 700 | "need a more complicated model, with multiple spiking layers or inputs that are changing\n", 701 | "over time, to really see the benefits of a Lowpass layer." 702 | ] 703 | }, 704 | { 705 | "cell_type": "markdown", 706 | "metadata": {}, 707 | "source": [ 708 | "## Summary\n", 709 | "\n", 710 | "We can use `SpikingActivation` layers to convert any activation function to an\n", 711 | "equivalent spiking implementation. 
Models with SpikingActivations can be trained and\n", 712 | "evaluated in the same way as non-spiking models, thanks to the swappable\n", 713 | "training/inference behaviour.\n", 714 | "\n", 715 | "There are also a number of additional features that should be kept in mind in order to\n", 716 | "optimize the performance of a spiking model:\n", 717 | "\n", 718 | "- [Simulation time](#Simulation-time): by adjusting `dt` we can trade off temporal\n", 719 | " sparsity versus accuracy\n", 720 | "- [Spiking aware training](#Spiking-aware-training): incorporating spiking dynamics on\n", 721 | " the forward pass can allow the model to learn weights that are more robust to spiking\n", 722 | " activations\n", 723 | "- [Spike rate regularization](#Spike-rate-regularization): we can gain more control over\n", 724 | " spike rates by directly incorporating activity regularization into the optimization\n", 725 | " process\n", 726 | "- [Lowpass filtering](#Lowpass-filtering): we can achieve better accuracy with fewer\n", 727 | " spikes by aggregating spike data over time" 728 | ] 729 | } 730 | ], 731 | "metadata": { 732 | "language_info": { 733 | "name": "python", 734 | "pygments_lexer": "ipython3" 735 | } 736 | }, 737 | "nbformat": 4, 738 | "nbformat_minor": 4 739 | } 740 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | :end-before: **Documentation** 3 | 4 | .. toctree:: 5 | 6 | installation 7 | examples 8 | reference 9 | project 10 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Installing PyTorchSpiking 5 | ------------------------- 6 | We recommend using ``pip`` to install PyTorchSpiking: 7 | 8 | .. 
code-block:: bash 9 | 10 | pip install pytorch-spiking 11 | 12 | That's it! 13 | 14 | Requirements 15 | ------------ 16 | PyTorchSpiking works with Python 3.6 or later. ``pip`` will do its best to install 17 | all of PyTorchSpiking's requirements automatically. However, if anything 18 | goes wrong during this process you can install the requirements manually and 19 | then try to ``pip install pytorch-spiking`` again. 20 | 21 | Developer installation 22 | ---------------------- 23 | If you want to modify PyTorchSpiking, or get the very latest updates, you will need to 24 | perform a developer installation: 25 | 26 | .. code-block:: bash 27 | 28 | git clone https://github.com/nengo/pytorch-spiking.git 29 | pip install -e ./pytorch-spiking 30 | 31 | Installing PyTorch 32 | --------------------- 33 | The PyTorch documentation has a 34 | `useful tool <https://pytorch.org/get-started/locally/>`_ to determine 35 | the appropriate command to install pytorch on your system. We would recommend using 36 | ``conda``, as it will take care of installing the other GPU-related packages, like 37 | CUDA/cuDNN. 38 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../LICENSE.rst 2 | -------------------------------------------------------------------------------- /docs/nengo-dl-comparison.rst: -------------------------------------------------------------------------------- 1 | PyTorchSpiking versus NengoDL 2 | ============================= 3 | 4 | If you are interested in combining spiking neurons and deep learning methods, you may 5 | be familiar with `NengoDL <https://www.nengo.ai/nengo-dl/>`_ (and wondering what the 6 | difference is between PyTorchSpiking and NengoDL). 7 | 8 | The short answer is that PyTorchSpiking is designed to be a lightweight, minimal 9 | implementation of spiking behaviour that integrates very transparently into PyTorch.
10 | It is designed to get you up and running on building a spiking model with very little 11 | overhead. 12 | 13 | NengoDL provides much more robust, fully-featured tools for building spiking models. 14 | More neuron types, more synapse types, more complex network architectures, more of 15 | everything basically. However, all of those extra features require a more significant 16 | departure from the PyTorch API. There is more of a learning curve to 17 | getting started with NengoDL, and because NengoDL is based on TensorFlow/Keras, the 18 | API is designed to be more familiar to those with Keras experience. 19 | 20 | One particularly significant distinction is that PyTorchSpiking does not really 21 | integrate with the rest of the Nengo ecosystem (e.g., it cannot run models built with 22 | the Nengo API, and models built with PyTorchSpiking cannot run on other Nengo 23 | platforms). 24 | In contrast, NengoDL can run any Nengo model, and models optimized in NengoDL can 25 | be run on other Nengo platforms (such as custom neuromorphic hardware, like NengoLoihi). 26 | 27 | In summary, you should use PyTorchSpiking if you want to get up and running with minimal 28 | departures from the standard PyTorch API. If you find yourself wishing for more control 29 | or more features to build your model, or you would like to run your model on different 30 | hardware platforms, consider checking out NengoDL. 31 | -------------------------------------------------------------------------------- /docs/project.rst: -------------------------------------------------------------------------------- 1 | Project information 2 | =================== 3 | 4 | .. 
toctree:: 5 | 6 | release-history 7 | nengo-dl-comparison 8 | contributing 9 | license 10 | -------------------------------------------------------------------------------- /docs/reference.rst: -------------------------------------------------------------------------------- 1 | API reference 2 | ============= 3 | 4 | Modules 5 | ------- 6 | 7 | .. automodule:: pytorch_spiking.modules 8 | 9 | .. autoautosummary:: pytorch_spiking.modules 10 | :nosignatures: 11 | 12 | Functions 13 | --------- 14 | 15 | .. automodule:: pytorch_spiking.functional 16 | 17 | .. autoautosummary:: pytorch_spiking.functional 18 | :nosignatures: 19 | -------------------------------------------------------------------------------- /docs/release-history.rst: -------------------------------------------------------------------------------- 1 | *************** 2 | Release history 3 | *************** 4 | 5 | .. This extra heading is here to add an extra layer of depth to the TOC, so that 6 | all the date headings don't end up filling up the sidebar toc tree 7 | TODO: there must be a cleaner way to do this 8 | 9 | Changelog 10 | ######### 11 | 12 | .. 
"""Functional implementation of spiking layers."""

import torch


class SpikingActivation(torch.autograd.Function):
    """
    Function for converting an arbitrary activation function to a spiking equivalent.

    Notes
    -----
    We would not recommend calling this directly, use
    `pytorch_spiking.SpikingActivation` instead.
    """

    @staticmethod
    def forward(
        ctx,
        inputs,
        activation,
        dt=0.001,
        initial_state=None,
        spiking_aware_training=True,
        return_sequences=False,
        training=False,
    ):
        """
        Forward pass of SpikingActivation function.

        Parameters
        ----------
        inputs : ``torch.Tensor``
            Array of input values with shape ``(batch_size, n_steps, n_neurons)``.
        activation : callable
            Activation function to be converted to spiking equivalent.
        dt : float
            Length of time (in seconds) represented by one time step.
        initial_state : ``torch.Tensor``
            Initial spiking voltage state (should be an array with shape
            ``(batch_size, n_neurons)``, with values between 0 and 1). Will use a
            uniform distribution if none is specified.
        spiking_aware_training : bool
            If True (default), use the spiking activation function
            for the forward pass and the base activation function for the backward
            pass. If False, use the base activation function for the forward and
            backward pass during training.
        return_sequences : bool
            Whether to return the last output in the output sequence (default), or
            the full sequence.
        training : bool
            Whether this function should be executed in training or evaluation mode
            (this only matters if ``spiking_aware_training=False``).

        Returns
        -------
        output : ``torch.Tensor``
            Spike values (each element ``n/dt``, where ``n`` is the spike count on
            that step), with shape ``(batch_size, n_steps, n_neurons)`` if
            ``return_sequences=True`` else ``(batch_size, n_neurons)``.
        """

        ctx.activation = activation
        ctx.return_sequences = return_sequences
        ctx.save_for_backward(inputs)

        if training and not spiking_aware_training:
            # skip the spiking dynamics entirely and just apply the base
            # activation (standard non-spiking training behaviour)
            output = activation(inputs if return_sequences else inputs[:, -1])
            return output

        if initial_state is None:
            # random initial voltages desynchronize spike timing across neurons;
            # match the input's device so this works on GPU tensors as well
            initial_state = torch.rand(
                inputs.shape[0],
                inputs.shape[2],
                dtype=inputs.dtype,
                device=inputs.device,
            )

        # match inputs to initial state dtype if one was passed in
        inputs = inputs.type(initial_state.dtype)

        # clone so that the in-place voltage updates below do not mutate the
        # caller's tensor (otherwise repeated calls that reuse the same
        # initial_state would silently accumulate voltage between calls)
        voltage = initial_state.clone()
        all_spikes = []
        # rates * dt gives the expected number of spikes per timestep
        rates = activation(inputs) * dt
        for i in range(inputs.shape[1]):
            voltage += rates[:, i]
            # each whole unit of accumulated voltage is emitted as one spike
            n_spikes = torch.floor(voltage)
            voltage -= n_spikes
            if return_sequences:
                all_spikes.append(n_spikes)

        if return_sequences:
            output = torch.stack(all_spikes, dim=1)
        else:
            output = n_spikes

        # scale by 1/dt so the integral of the spike train matches the
        # integral of the underlying (non-spiking) rate signal
        output /= dt

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        Backward pass of SpikingActivation function.

        Uses the gradient of the base (non-spiking) activation as a surrogate
        for the non-differentiable spiking forward pass.
        """

        # TODO: is there a way to reuse the forward pass activations computed in
        # `forward`? recomputing the activation here is redundant work

        inputs = ctx.saved_tensors[0]
        with torch.enable_grad():
            output = ctx.activation(inputs if ctx.return_sequences else inputs[:, -1])
        # forward takes 7 non-ctx arguments; only `inputs` (the first) gets a
        # gradient, so pad with exactly 6 Nones for the non-differentiable args
        # (activation, dt, initial_state, spiking_aware_training,
        # return_sequences, training)
        return (
            torch.autograd.grad(output, inputs, grad_outputs=grad_output)
            + (None,) * 6
        )


spiking_activation = SpikingActivation.apply
class SpikingActivation(torch.nn.Module):  # pylint: disable=abstract-method
    """
    Module for converting an arbitrary activation function to a spiking equivalent.

    Neurons will spike at a rate proportional to the output of the base activation
    function. For example, if the activation function is outputting a value of 10,
    then the wrapped SpikingActivation will output spikes at a rate of 10Hz (i.e.,
    10 spikes per 1 simulated second, where 1 simulated second is equivalent to
    ``1/dt`` time steps). Each spike will have height ``1/dt`` (so that the
    integral of the spiking output will be the same as the integral of the base
    activation output). Note that if the base activation is outputting a negative
    value then the spikes will have height ``-1/dt``. Multiple spikes per timestep
    are also possible, in which case the output will be ``n/dt`` (where ``n`` is
    the number of spikes).

    When applying this layer to an input, make sure that the input has a time axis.
    The spiking output will be computed along the time axis.
    The number of simulation timesteps will depend on the length of that time axis.
    The number of timesteps does not need to be the same during
    training/evaluation/inference. In particular, it may be more efficient
    to use one timestep during training and multiple timesteps during inference
    (often with ``spiking_aware_training=False``, and ``apply_during_training=False``
    on any `.Lowpass` layers).

    Parameters
    ----------
    activation : callable
        Activation function to be converted to spiking equivalent.
    dt : float
        Length of time (in seconds) represented by one time step.
    initial_state : ``torch.Tensor``
        Initial spiking voltage state (should be an array with shape
        ``(batch_size, n_neurons)``, with values between 0 and 1). Will use a
        uniform distribution if none is specified.
    spiking_aware_training : bool
        If True (default), use the spiking activation function
        for the forward pass and the base activation function for the backward pass.
        If False, use the base activation function for the forward and
        backward pass during training.
    return_sequences : bool
        Whether to return the full sequence of output spikes (default),
        or just the spikes on the last timestep.
    """

    def __init__(
        self,
        activation,
        dt=0.001,
        initial_state=None,
        spiking_aware_training=True,
        return_sequences=True,
    ):
        """"""  # empty docstring removes useless parent docstring from docs
        super().__init__()

        self.activation = activation
        self.initial_state = initial_state
        self.dt = dt
        self.spiking_aware_training = spiking_aware_training
        self.return_sequences = return_sequences

    def forward(self, inputs):
        """
        Compute output spikes given inputs.

        Parameters
        ----------
        inputs : ``torch.Tensor``
            Array of input values with shape ``(batch_size, n_steps, n_neurons)``.

        Returns
        -------
        outputs : ``torch.Tensor``
            Array of output spikes with shape ``(batch_size, n_neurons)`` if
            ``return_sequences=False`` else ``(batch_size, n_steps, n_neurons)``.
            Each element will have value ``n/dt``, where ``n`` is the number of
            spikes emitted by that neuron on that time step.
        """
        # delegate to the functional implementation; `self.training` selects
        # between training and inference behaviour
        return spiking_activation(
            inputs,
            self.activation,
            self.dt,
            self.initial_state,
            self.spiking_aware_training,
            self.return_sequences,
            self.training,
        )


class Lowpass(torch.nn.Module):  # pylint: disable=abstract-method
    """
    Module implementing a Lowpass filter.

    The initial filter state and filter time constants are both trainable
    parameters. However, if ``apply_during_training=False`` then the parameters are
    not part of the training loop, and so will never be updated.

    When applying this layer to an input, make sure that the input has a time axis.

    Parameters
    ----------
    tau : float
        Time constant of filter (in seconds); must be positive.
    units : int
        Number of elements in the last dimension of the input (each element is
        filtered with its own trainable state/time constant).
    dt : float
        Length of time (in seconds) represented by one time step.
    apply_during_training : bool
        If False, this layer will effectively be ignored during training (this
        often makes sense in concert with the swappable training behaviour in, e.g.,
        `.SpikingActivation`, since if the activations are not spiking during
        training then we often don't need to filter them either).
    initial_level : ``torch.Tensor``
        Initial filter state (with shape ``(1, units)``); defaults to all zeros.
    return_sequences : bool
        Whether to return the full sequence of filtered output (default),
        or just the output on the last timestep.
    """

    def __init__(
        self,
        tau,
        units,
        dt=0.001,
        apply_during_training=True,
        initial_level=None,
        return_sequences=True,
    ):
        """"""  # empty docstring removes useless parent docstring from docs
        super().__init__()

        if tau <= 0:
            raise ValueError("tau must be a positive number")

        self.tau = tau
        self.units = units
        self.dt = dt
        self.apply_during_training = apply_during_training
        self.initial_level = initial_level
        self.return_sequences = return_sequences

        # apply ZOH discretization
        smoothing_init = np.exp(-self.dt / self.tau)

        # compute inverse sigmoid of tau, so that when we apply the sigmoid
        # later we'll get the tau value specified (the sigmoid keeps the
        # trainable smoothing coefficient in the valid (0, 1) range)
        self.smoothing_init = np.log(smoothing_init / (1 - smoothing_init))

        self.level_var = torch.nn.Parameter(
            torch.zeros(1, units) if self.initial_level is None else self.initial_level
        )

        self.smoothing_var = torch.nn.Parameter(
            torch.ones(1, units) * self.smoothing_init
        )

    def forward(self, inputs):
        """
        Apply filter to inputs.

        Parameters
        ----------
        inputs : ``torch.Tensor``
            Array of input values with shape ``(batch_size, n_steps, units)``.

        Returns
        -------
        outputs : ``torch.Tensor``
            Array of filtered values with shape ``(batch_size, units)`` if
            ``return_sequences=False`` else ``(batch_size, n_steps, units)``.
        """

        if self.training and not self.apply_during_training:
            # pass inputs through unfiltered during training
            return inputs if self.return_sequences else inputs[:, -1]

        level = self.level_var
        # map the unconstrained parameter back into a (0, 1) smoothing factor
        smoothing = torch.sigmoid(self.smoothing_var)

        # cast inputs to module type
        inputs = inputs.type(self.smoothing_var.dtype)

        all_levels = []
        for i in range(inputs.shape[1]):
            # exponential moving average: new level blends the current input
            # with the previous level
            level = (1 - smoothing) * inputs[:, i] + smoothing * level
            if self.return_sequences:
                all_levels.append(level)

        if self.return_sequences:
            return torch.stack(all_levels, dim=1)
        else:
            return level


class TemporalAvgPool(torch.nn.Module):
    """
    Module for taking the average across one dimension of a tensor.

    Parameters
    ----------
    dim : int, optional
        The dimension to average across. Defaults to the second dimension (``dim=1``),
        which is typically the time dimension (for tensors that have a time dimension).
    """

    def __init__(self, dim=1):
        """"""  # empty docstring removes useless parent docstring from docs
        super().__init__()
        self.dim = dim

    def forward(self, inputs):
        """
        Apply average pooling to inputs.

        Parameters
        ----------
        inputs : ``torch.Tensor``
            Array of input values with shape ``(batch_size, n_steps, ...)``.

        Returns
        -------
        outputs : ``torch.Tensor``
            Array of output values with shape ``(batch_size, ...)``.
            The time dimension is fully averaged and removed.
        """
        # reduce over the pooling dimension, dropping it from the output shape
        return torch.mean(inputs, dim=self.dim)
229 | """ 230 | return torch.mean(inputs, dim=self.dim) 231 | -------------------------------------------------------------------------------- /pytorch_spiking/tests/test_modules.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | 7 | from pytorch_spiking import modules 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "activation", (torch.nn.ReLU(), torch.nn.functional.relu, torch.tanh) 12 | ) 13 | def test_activations(activation, rng, allclose): 14 | x = torch.from_numpy(rng.randn(32, 10, 2)) 15 | 16 | ground = activation(x) 17 | 18 | # behaviour equivalent to base activation during training 19 | y = modules.SpikingActivation( 20 | activation, return_sequences=True, spiking_aware_training=False 21 | ).train()(x) 22 | assert allclose(y, ground) 23 | 24 | # not equivalent during inference 25 | y = modules.SpikingActivation( 26 | activation, return_sequences=True, spiking_aware_training=False 27 | ).eval()(x) 28 | assert not allclose(y, ground, record_rmse=False, print_fail=0) 29 | 30 | # equivalent during inference, with large enough dt 31 | y = modules.SpikingActivation( 32 | activation, return_sequences=True, spiking_aware_training=False, dt=1e8 33 | ).eval()(x) 34 | assert allclose(y, ground) 35 | 36 | # not equivalent during training if using spiking_aware_training 37 | y = modules.SpikingActivation( 38 | activation, return_sequences=True, spiking_aware_training=True 39 | ).train()(x) 40 | assert not allclose(y, ground, record_rmse=False, print_fail=0) 41 | 42 | # equivalent with large enough dt 43 | y = modules.SpikingActivation( 44 | activation, 45 | return_sequences=True, 46 | spiking_aware_training=True, 47 | dt=1e8, 48 | ).train()(x) 49 | assert allclose(y, ground) 50 | 51 | 52 | def test_initial_state(seed, allclose): 53 | x = torch.from_numpy(np.ones((2, 100, 10)) * 100) 54 | 55 | init = torch.rand((2, 10), 
generator=torch.random.manual_seed(seed)) 56 | 57 | # layers with the same initial state produce the same output 58 | y0 = modules.SpikingActivation( 59 | torch.nn.ReLU(), return_sequences=True, initial_state=init 60 | )(x) 61 | y1 = modules.SpikingActivation( 62 | torch.nn.ReLU(), return_sequences=True, initial_state=init 63 | )(x) 64 | assert allclose(y0, y1) 65 | 66 | # layers with different initial state produce different output 67 | y2 = modules.SpikingActivation(torch.nn.ReLU(), return_sequences=True)(x) 68 | assert not allclose(y0, y2, record_rmse=False, print_fail=0) 69 | 70 | # the same layer called multiple times will produce the same output (if the initial 71 | # state is set) 72 | layer = modules.SpikingActivation( 73 | torch.nn.ReLU(), return_sequences=True, initial_state=init 74 | ) 75 | assert allclose(layer(x), layer(x)) 76 | 77 | # layer will produce different output each time if initial state not set 78 | layer = modules.SpikingActivation(torch.nn.ReLU(), return_sequences=True) 79 | assert not allclose(layer(x), layer(x), record_rmse=False, print_fail=0) 80 | 81 | 82 | def test_spiking_aware_training(rng, allclose): 83 | layer = modules.SpikingActivation( 84 | torch.nn.ReLU(), spiking_aware_training=False 85 | ).train() 86 | layer_sat = modules.SpikingActivation( 87 | torch.nn.ReLU(), spiking_aware_training=True 88 | ).train() 89 | x = torch.from_numpy(rng.uniform(-1, 1, size=(10, 20, 32))).requires_grad_(True) 90 | y = layer(x)[:, -1] 91 | y_sat = layer_sat(x)[:, -1] 92 | y_ground = torch.nn.ReLU()(x)[:, -1] 93 | 94 | # forward pass is different 95 | assert allclose(y.detach().numpy(), y_ground.detach().numpy()) 96 | assert not allclose( 97 | y_sat.detach().numpy(), 98 | y_ground.detach().numpy(), 99 | record_rmse=False, 100 | print_fail=0, 101 | ) 102 | 103 | # gradients are the same 104 | dy = torch.autograd.grad(y, x, grad_outputs=[torch.ones_like(y)])[0] 105 | dy_ground = torch.autograd.grad(y_ground, x, grad_outputs=[torch.ones_like(y)])[0] 
106 | dy_sat = torch.autograd.grad(y_sat, x, grad_outputs=[torch.ones_like(y)])[0] 107 | assert allclose(dy, dy_ground) 108 | assert allclose(dy_sat, dy_ground) 109 | 110 | 111 | def test_spiking_swap_functional(allclose): 112 | class MyModel(torch.nn.Module): 113 | def __init__(self): 114 | super().__init__() 115 | 116 | self.dense0 = torch.nn.Linear(1, 10) 117 | self.dense1 = torch.nn.Linear(1, 10) 118 | 119 | def forward(self, inputs): 120 | x = inputs.view((-1, inputs.shape[-1])) 121 | 122 | x0 = self.dense0(x) 123 | x0 = x0.view((inputs.shape[0], inputs.shape[1], x0.shape[-1])) 124 | x0 = torch.nn.LeakyReLU(negative_slope=0.3)(x0) 125 | 126 | x1 = self.dense1(x) 127 | x1 = x1.view((inputs.shape[0], inputs.shape[1], x1.shape[-1])) 128 | x1 = modules.SpikingActivation( 129 | torch.nn.LeakyReLU(negative_slope=0.3), 130 | return_sequences=True, 131 | spiking_aware_training=False, 132 | )(x1) 133 | 134 | return x0, x1 135 | 136 | model = MyModel() 137 | loss_func = torch.nn.MSELoss(reduction="none") 138 | optimizer = torch.optim.SGD(model.parameters(), lr=0.1) 139 | 140 | for _ in range(200): 141 | model.zero_grad() 142 | 143 | outputs = model(torch.ones((32, 1, 1))) 144 | loss = sum( 145 | torch.mean(torch.sum(loss_func(o, t), dim=1)) 146 | for o, t in zip( 147 | outputs, 148 | [torch.ones((32, 1, 10)) * torch.arange(1.0, 100.0, 10)] * 2, 149 | ) 150 | ) 151 | loss.backward() 152 | optimizer.step() 153 | 154 | y0, y1 = model(torch.ones((1, 1000, 1))) 155 | assert allclose(y0.detach().numpy(), np.arange(1, 100, 10), atol=1) 156 | assert allclose( 157 | np.sum(y1.detach().numpy() * 0.001, axis=1, keepdims=True), 158 | np.arange(1, 100, 10), 159 | atol=1, 160 | ) 161 | 162 | 163 | @pytest.mark.parametrize("dt", (0.001, 1)) 164 | def test_lowpass_tau(dt, allclose, rng): 165 | nengo = pytest.importorskip("nengo") 166 | 167 | # verify that the pytorch-spiking lowpass implementation matches the nengo lowpass 168 | # implementation 169 | layer = modules.Lowpass(tau=0.1, 
units=32, dt=dt).double() 170 | 171 | with torch.no_grad(): 172 | x = torch.from_numpy(rng.randn(10, 100, 32)) 173 | y = layer(x) 174 | 175 | y_nengo = nengo.Lowpass(0.1).filt(x, axis=1, dt=dt) 176 | 177 | assert allclose(y, y_nengo) 178 | 179 | 180 | def test_lowpass_apply_during_training(allclose, rng): 181 | with torch.no_grad(): 182 | x = torch.from_numpy(rng.randn(10, 100, 32)) 183 | 184 | # apply_during_training=False: 185 | # confirm `output == input` for training=True, but not training=False 186 | layer = modules.Lowpass( 187 | tau=0.1, units=32, apply_during_training=False, return_sequences=True 188 | ) 189 | assert allclose(layer.train()(x), x) 190 | assert not allclose(layer.eval()(x), x, record_rmse=False, print_fail=0) 191 | 192 | # apply_during_training=True: 193 | # confirm `output != input` for both values of `training`, and 194 | # output is equal for both values of `training` 195 | layer = modules.Lowpass( 196 | tau=0.1, units=32, apply_during_training=True, return_sequences=True 197 | ) 198 | assert not allclose(layer.train()(x), x, record_rmse=False, print_fail=0) 199 | assert not allclose(layer.eval()(x), x, record_rmse=False, print_fail=0) 200 | assert allclose(layer.train()(x), layer.eval()(x)) 201 | 202 | 203 | def test_lowpass_trainable(allclose): 204 | class MyModel(torch.nn.Module): 205 | def __init__(self): 206 | super().__init__() 207 | 208 | self.trained = modules.Lowpass(0.01, 1, apply_during_training=True) 209 | self.skip = modules.Lowpass(0.01, 1, apply_during_training=False) 210 | self.untrained = modules.Lowpass(0.01, 1, apply_during_training=True) 211 | for param in self.untrained.parameters(): 212 | param.requires_grad = False 213 | 214 | def forward(self, inputs): 215 | return self.trained(inputs), self.skip(inputs), self.untrained(inputs) 216 | 217 | model = MyModel() 218 | 219 | loss_func = torch.nn.MSELoss(reduction="mean") 220 | optimizer = torch.optim.SGD(model.parameters(), lr=0.5) 221 | 222 | for _ in range(10): 223 | 
model.zero_grad() 224 | 225 | outputs = model(torch.zeros(1, 1, 1)) 226 | loss = sum( 227 | loss_func(o, t) 228 | for o, t in zip( 229 | outputs, 230 | [torch.ones(1, 1)] * 3, 231 | ) 232 | ) 233 | loss.backward() 234 | optimizer.step() 235 | 236 | # trainable layer should learn to output 1 237 | ys = model(torch.zeros((1, 1, 1))) 238 | assert allclose(ys[0].detach(), 1) 239 | assert not allclose(ys[1].detach(), 1, record_rmse=False, print_fail=0) 240 | assert not allclose(ys[2].detach(), 1, record_rmse=False, print_fail=0) 241 | 242 | # for trainable layer, smoothing * initial_level should go to 1 243 | assert allclose( 244 | (torch.sigmoid(model.trained.smoothing_var) * model.trained.level_var).detach(), 245 | 1, 246 | ) 247 | 248 | # other layers should stay at initial value 249 | assert allclose(model.skip.level_var.detach(), 0) 250 | assert allclose(model.untrained.level_var.detach(), 0) 251 | assert allclose(model.skip.smoothing_var.detach(), model.skip.smoothing_init) 252 | assert allclose( 253 | model.untrained.smoothing_var.detach(), model.untrained.smoothing_init 254 | ) 255 | 256 | 257 | def test_lowpass_validation(): 258 | with pytest.raises(ValueError, match="tau must be a positive number"): 259 | modules.Lowpass(tau=0, units=1) 260 | 261 | 262 | def test_temporalavgpool(rng, allclose): 263 | x = rng.randn(32, 10, 2, 5) 264 | tx = torch.from_numpy(x) 265 | for dim in range(x.ndim): 266 | model = torch.nn.Sequential(modules.TemporalAvgPool(dim=dim)) 267 | toutput = model(tx) 268 | assert allclose(toutput.numpy(), x.mean(axis=dim)) 269 | 270 | 271 | @pytest.mark.parametrize( 272 | "module", 273 | ( 274 | modules.SpikingActivation( 275 | torch.nn.ReLU(), initial_state=torch.zeros((32, 50)), dt=1 276 | ), 277 | modules.Lowpass(tau=0.01, units=50, dt=0.001), 278 | ), 279 | ) 280 | def test_return_sequences(module, rng, allclose): 281 | x = torch.tensor(rng.randn(32, 10, 50)) 282 | 283 | with torch.no_grad(): 284 | module_seq = copy.deepcopy(module) 285 | 
module_seq.return_sequences = True 286 | y_seq = module_seq(x) 287 | 288 | module_last = copy.deepcopy(module) 289 | module_last.return_sequences = False 290 | y_last = module_last(x) 291 | 292 | assert y_seq.shape == x.shape 293 | assert y_last.shape == (x.shape[0], x.shape[2]) 294 | assert allclose(y_seq[:, -1], y_last) 295 | -------------------------------------------------------------------------------- /pytorch_spiking/version.py: -------------------------------------------------------------------------------- 1 | # Automatically generated by nengo-bones, do not edit this file directly 2 | 3 | # pylint: disable=consider-using-f-string,bad-string-format-type 4 | 5 | """ 6 | PyTorchSpiking version information. 7 | 8 | We use semantic versioning (see http://semver.org/) and conform to PEP440 (see 9 | https://www.python.org/dev/peps/pep-0440/). '.dev0' will be added to the version 10 | unless the code base represents a release version. Release versions are git 11 | tagged with the version. 
12 | """ 13 | 14 | version_info = (0, 1, 1) 15 | 16 | name = "pytorch-spiking" 17 | dev = 0 18 | 19 | # use old string formatting, so that this can still run in Python <= 3.5 20 | # (since this file is parsed in setup.py, before python_requires is applied) 21 | version = ".".join(str(v) for v in version_info) 22 | if dev is not None: 23 | version += ".dev%d" % dev # pragma: no cover 24 | 25 | copyright = "Copyright (c) 2020-2023 Applied Brain Research" 26 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # Automatically generated by nengo-bones, do not edit this file directly 2 | 3 | [build_sphinx] 4 | source-dir = docs 5 | build-dir = docs/_build 6 | all_files = 1 7 | 8 | [coverage:run] 9 | source = ./ 10 | relative_files = True 11 | 12 | [coverage:report] 13 | # Regexes for lines to exclude from consideration 14 | exclude_lines = 15 | # Have to re-enable the standard pragma 16 | # place ``# pragma: no cover`` at the end of a line to ignore it 17 | pragma: no cover 18 | 19 | # Don't complain if tests don't hit defensive assertion code: 20 | raise NotImplementedError 21 | 22 | # `pass` is just a placeholder, fine if it's not covered 23 | ^[ \t]*pass$ 24 | 25 | 26 | # Patterns for files to exclude from reporting 27 | omit = 28 | */tests/test* 29 | 30 | [flake8] 31 | exclude = 32 | __init__.py 33 | ignore = 34 | E123 35 | E133 36 | E203 37 | E226 38 | E241 39 | E242 40 | E501 41 | E731 42 | F401 43 | W503 44 | max-complexity = 10 45 | max-line-length = 88 46 | 47 | [tool:pytest] 48 | norecursedirs = 49 | .* 50 | *.egg 51 | build 52 | dist 53 | docs 54 | xfail_strict = True 55 | 56 | [pylint] 57 | 58 | [pylint.messages] 59 | disable = 60 | arguments-differ, 61 | assignment-from-no-return, 62 | attribute-defined-outside-init, 63 | blacklisted-name, 64 | comparison-with-callable, 65 | duplicate-code, 66 | fixme, 67 | import-error, 68 | 
invalid-name, 69 | invalid-sequence-index, 70 | len-as-condition, 71 | literal-comparison, 72 | no-else-raise, 73 | no-else-return, 74 | no-member, 75 | no-name-in-module, 76 | not-an-iterable, 77 | not-context-manager, 78 | protected-access, 79 | redefined-builtin, 80 | stop-iteration-return, 81 | too-few-public-methods, 82 | too-many-arguments, 83 | too-many-branches, 84 | too-many-instance-attributes, 85 | too-many-lines, 86 | too-many-locals, 87 | too-many-return-statements, 88 | too-many-statements, 89 | unexpected-keyword-arg, 90 | unidiomatic-typecheck, 91 | unsubscriptable-object, 92 | unsupported-assignment-operation, 93 | unused-argument, 94 | not-callable, 95 | 96 | [pylint.imports] 97 | known-third-party = 98 | matplotlib, 99 | nengo, 100 | numpy, 101 | pytest, 102 | torch, 103 | 104 | [pylint.format] 105 | max-line-length = 88 106 | 107 | [pylint.classes] 108 | valid-metaclass-classmethod-first-arg = metacls 109 | 110 | [pylint.reports] 111 | reports = no 112 | score = no 113 | 114 | [codespell] 115 | skip = ./build,*/_build,*-checkpoint.ipynb,./.eggs,./*.egg-info,./.git,*/_vendor,./.mypy_cache, 116 | ignore-words-list = hist 117 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Automatically generated by nengo-bones, do not edit this file directly 4 | 5 | import io 6 | import pathlib 7 | import runpy 8 | 9 | try: 10 | from setuptools import find_packages, setup 11 | except ImportError: 12 | raise ImportError( 13 | "'setuptools' is required but not installed. 
To install it, " 14 | "follow the instructions at " 15 | "https://pip.pypa.io/en/stable/installing/#installing-with-get-pip-py" 16 | ) 17 | 18 | 19 | def read(*filenames, **kwargs): 20 | encoding = kwargs.get("encoding", "utf-8") 21 | sep = kwargs.get("sep", "\n") 22 | buf = [] 23 | for filename in filenames: 24 | with io.open(filename, encoding=encoding) as f: 25 | buf.append(f.read()) 26 | return sep.join(buf) 27 | 28 | 29 | root = pathlib.Path(__file__).parent 30 | version = runpy.run_path(str(root / "pytorch_spiking" / "version.py"))["version"] 31 | 32 | install_req = [ 33 | "numpy>=1.16.0", 34 | "torch>=1.0.0", 35 | ] 36 | docs_req = [ 37 | "jupyter>=1.0.0", 38 | "matplotlib>=2.0.0", 39 | "nbsphinx>=0.3.5", 40 | "nengo-sphinx-theme>=1.2.1", 41 | "numpydoc>=0.6.0", 42 | "sphinx>=3.0.0", 43 | "torchvision>=0.7.0", 44 | ] 45 | optional_req = [] 46 | tests_req = [ 47 | "pylint>=1.9.2", 48 | "pytest>=3.6.0", 49 | "pytest-allclose>=1.0.0", 50 | "pytest-cov>=2.6.0", 51 | "pytest-rng>=1.0.0", 52 | "pytest-xdist>=1.16.0", 53 | ] 54 | 55 | setup( 56 | name="pytorch-spiking", 57 | version=version, 58 | author="Applied Brain Research", 59 | author_email="info@appliedbrainresearch.com", 60 | packages=find_packages(), 61 | url="https://www.nengo.ai/pytorch-spiking", 62 | include_package_data=True, 63 | license="Free for non-commercial use", 64 | description="Spiking neuron integration for PyTorch", 65 | long_description=read("README.rst", "CHANGES.rst"), 66 | zip_safe=False, 67 | install_requires=install_req, 68 | extras_require={ 69 | "all": docs_req + optional_req + tests_req, 70 | "docs": docs_req, 71 | "optional": optional_req, 72 | "tests": tests_req, 73 | }, 74 | python_requires=">=3.6", 75 | classifiers=[ 76 | "Development Status :: 3 - Alpha", 77 | "Intended Audience :: Science/Research", 78 | "License :: Free for non-commercial use", 79 | "Operating System :: Microsoft :: Windows", 80 | "Operating System :: POSIX :: Linux", 81 | "Programming Language :: Python", 82 
| "Programming Language :: Python :: 3.6", 83 | "Programming Language :: Python :: 3.7", 84 | "Programming Language :: Python :: 3.8", 85 | "Topic :: Scientific/Engineering", 86 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 87 | ], 88 | ) 89 | --------------------------------------------------------------------------------