├── .github └── workflows │ ├── black.yml │ ├── python-app.yml │ └── pythonpackage.yml ├── .gitignore ├── .pytest_cache ├── CACHEDIR.TAG ├── README.md └── v │ └── cache │ ├── lastfailed │ ├── nodeids │ └── stepwise ├── .readthedocs.yaml ├── .vscode └── launch.json ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── bindsnet ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── utils.cpython-310.pyc ├── analysis │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── pipeline_analysis.cpython-310.pyc │ │ ├── plotting.cpython-310.pyc │ │ └── visualization.cpython-310.pyc │ ├── dotTrace_plotter.py │ ├── pipeline_analysis.py │ ├── plotting.py │ └── visualization.py ├── conversion │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── conversion.cpython-310.pyc │ │ ├── nodes.cpython-310.pyc │ │ └── topology.cpython-310.pyc │ ├── conversion.py │ ├── nodes.py │ └── topology.py ├── datasets │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── alov300.cpython-310.pyc │ │ ├── collate.cpython-310.pyc │ │ ├── dataloader.cpython-310.pyc │ │ ├── davis.cpython-310.pyc │ │ ├── preprocess.cpython-310.pyc │ │ ├── spoken_mnist.cpython-310.pyc │ │ └── torchvision_wrapper.cpython-310.pyc │ ├── alov300.py │ ├── collate.py │ ├── dataloader.py │ ├── davis.py │ ├── preprocess.py │ ├── spoken_mnist.py │ └── torchvision_wrapper.py ├── encoding │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── encoders.cpython-310.pyc │ │ ├── encodings.cpython-310.pyc │ │ └── loaders.cpython-310.pyc │ ├── encoders.py │ ├── encodings.py │ └── loaders.py ├── environment │ ├── README.md │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── environment.cpython-310.pyc │ ├── cue_reward.py │ ├── dot_simulator.py │ └── environment.py ├── evaluation │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── evaluation.cpython-310.pyc │ └── 
evaluation.py ├── learning │ ├── MCC_learning.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── learning.cpython-310.pyc │ │ └── reward.cpython-310.pyc │ ├── learning.py │ └── reward.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── models.cpython-310.pyc │ └── models.py ├── network │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── monitors.cpython-310.pyc │ │ ├── network.cpython-310.pyc │ │ ├── nodes.cpython-310.pyc │ │ └── topology.cpython-310.pyc │ ├── monitors.py │ ├── network.py │ ├── nodes.py │ ├── topology.py │ └── topology_features.py ├── pipeline │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── action.cpython-310.pyc │ │ ├── base_pipeline.cpython-310.pyc │ │ ├── dataloader_pipeline.cpython-310.pyc │ │ └── environment_pipeline.cpython-310.pyc │ ├── action.py │ ├── base_pipeline.py │ ├── dataloader_pipeline.py │ └── environment_pipeline.py ├── preprocessing │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── preprocessing.cpython-310.pyc │ └── preprocessing.py └── utils.py ├── docs ├── BindsNET benchmark.png ├── DotTraceSample.png ├── Makefile.old ├── UML.png ├── directory_structure.png ├── logo.png ├── make.bat.old ├── pipeline.png ├── pyproject.toml ├── requirements.txt └── source │ ├── bindsnet.analysis.rst │ ├── bindsnet.conversion.rst │ ├── bindsnet.datasets.rst │ ├── bindsnet.encoding.rst │ ├── bindsnet.environment.rst │ ├── bindsnet.evaluation.rst │ ├── bindsnet.learning.rst │ ├── bindsnet.models.rst │ ├── bindsnet.network.rst │ ├── bindsnet.pipeline.rst │ ├── bindsnet.preprocessing.rst │ ├── bindsnet.rst │ ├── conf.py │ ├── guide.rst │ ├── guide │ ├── guide_part_i.rst │ ├── guide_part_ii.rst │ ├── spikes.png │ └── voltages.png │ ├── index.rst │ ├── installation.rst │ ├── modules.rst │ └── quickstart.rst ├── examples ├── README.md ├── benchmark │ ├── annarchy.py │ ├── benchmark.py │ ├── gpu_annarchy.py 
│ └── plot_benchmark.py ├── breakout │ ├── breakout.py │ ├── breakout_stdp.py │ ├── play_breakout_from_ANN.py │ ├── random_baseline.py │ ├── random_network_baseline.py │ └── trained_shallow_ANN.pt ├── dotTracing │ └── dot_tracing.py ├── mnist │ ├── MCC_reservoir.py │ ├── SOM_LM-SNNs.py │ ├── batch_eth_mnist.py │ ├── conv1d_MNIST.py │ ├── conv3d_MNIST.py │ ├── conv_mnist.py │ ├── eth_mnist.py │ ├── loc1d_mnist.py │ ├── loc2d_mnist.py │ ├── loc3d_mnist.py │ ├── reservoir.py │ └── supervised_mnist.py └── tensorboard │ └── tensorboard.py ├── logs ├── init │ ├── events.out.tfevents.1656543178.TempWin │ ├── events.out.tfevents.1656548905.TempWin │ ├── events.out.tfevents.1673646087.Spike │ ├── events.out.tfevents.1673648326.Spike │ ├── events.out.tfevents.1678117372.Spike │ ├── events.out.tfevents.1682712186.Spike │ ├── events.out.tfevents.1687464074.Spike │ ├── events.out.tfevents.1687464505.Spike │ ├── events.out.tfevents.1687736499.Spike │ ├── events.out.tfevents.1694374827.Spike │ ├── events.out.tfevents.1694374969.Spike │ ├── events.out.tfevents.1694375010.Spike │ ├── events.out.tfevents.1700165624.Spike │ ├── events.out.tfevents.1703700212.Spike │ ├── events.out.tfevents.1711672057.Spike │ ├── events.out.tfevents.1711672140.Spike │ ├── events.out.tfevents.1711673241.Spike │ ├── events.out.tfevents.1711673760.Spike │ ├── events.out.tfevents.1711674764.Spike │ ├── events.out.tfevents.1711675116.Spike │ ├── events.out.tfevents.1711675170.Spike │ ├── events.out.tfevents.1711675181.Spike │ ├── events.out.tfevents.1711675321.Spike │ ├── events.out.tfevents.1711675865.Spike │ ├── events.out.tfevents.1711721119.Spike │ ├── events.out.tfevents.1711723694.Spike │ ├── events.out.tfevents.1720719086.Spike │ └── events.out.tfevents.1720719342.Spike └── runs │ ├── events.out.tfevents.1656543178.TempWin │ ├── events.out.tfevents.1656548905.TempWin │ ├── events.out.tfevents.1673646087.Spike │ ├── events.out.tfevents.1673648326.Spike │ ├── events.out.tfevents.1678117372.Spike │ ├── 
events.out.tfevents.1682712186.Spike │ ├── events.out.tfevents.1687464074.Spike │ ├── events.out.tfevents.1687464505.Spike │ ├── events.out.tfevents.1687736499.Spike │ ├── events.out.tfevents.1694374827.Spike │ ├── events.out.tfevents.1694374969.Spike │ ├── events.out.tfevents.1694375010.Spike │ ├── events.out.tfevents.1700165624.Spike │ ├── events.out.tfevents.1703700212.Spike │ ├── events.out.tfevents.1711672057.Spike │ ├── events.out.tfevents.1711672140.Spike │ ├── events.out.tfevents.1711673241.Spike │ ├── events.out.tfevents.1711673760.Spike │ ├── events.out.tfevents.1711674764.Spike │ ├── events.out.tfevents.1711675116.Spike │ ├── events.out.tfevents.1711675170.Spike │ ├── events.out.tfevents.1711675181.Spike │ ├── events.out.tfevents.1711675321.Spike │ ├── events.out.tfevents.1711675865.Spike │ ├── events.out.tfevents.1711721119.Spike │ ├── events.out.tfevents.1711723694.Spike │ ├── events.out.tfevents.1720719086.Spike │ └── events.out.tfevents.1720719342.Spike ├── poetry.lock ├── pyproject.toml ├── setup.py ├── test ├── analysis │ ├── __pycache__ │ │ ├── test_analyzers.cpython-310-pytest-7.4.4.pyc │ │ ├── test_analyzers.cpython-310-pytest-8.1.1.pyc │ │ └── test_analyzers.cpython-310-pytest-8.2.2.pyc │ └── test_analyzers.py ├── conversion │ ├── __pycache__ │ │ ├── test_conversion.cpython-310-pytest-7.4.4.pyc │ │ ├── test_conversion.cpython-310-pytest-8.1.1.pyc │ │ └── test_conversion.cpython-310-pytest-8.2.2.pyc │ └── test_conversion.py ├── encoding │ ├── __pycache__ │ │ ├── test_encoding.cpython-310-pytest-7.4.4.pyc │ │ ├── test_encoding.cpython-310-pytest-8.1.1.pyc │ │ └── test_encoding.cpython-310-pytest-8.2.2.pyc │ └── test_encoding.py ├── import │ ├── __pycache__ │ │ ├── test_import.cpython-310-pytest-7.4.4.pyc │ │ ├── test_import.cpython-310-pytest-8.1.1.pyc │ │ └── test_import.cpython-310-pytest-8.2.2.pyc │ └── test_import.py ├── models │ ├── __pycache__ │ │ ├── test_models.cpython-310-pytest-7.4.4.pyc │ │ ├── test_models.cpython-310-pytest-8.1.1.pyc 
│ │ └── test_models.cpython-310-pytest-8.2.2.pyc │ └── test_models.py └── network │ ├── __pycache__ │ ├── test_connections.cpython-310-pytest-7.4.4.pyc │ ├── test_connections.cpython-310-pytest-8.1.1.pyc │ ├── test_connections.cpython-310-pytest-8.2.2.pyc │ ├── test_learning.cpython-310-pytest-7.4.4.pyc │ ├── test_learning.cpython-310-pytest-8.1.1.pyc │ ├── test_learning.cpython-310-pytest-8.2.2.pyc │ ├── test_monitors.cpython-310-pytest-7.4.4.pyc │ ├── test_monitors.cpython-310-pytest-8.1.1.pyc │ ├── test_monitors.cpython-310-pytest-8.2.2.pyc │ ├── test_network.cpython-310-pytest-7.4.4.pyc │ ├── test_network.cpython-310-pytest-8.1.1.pyc │ ├── test_network.cpython-310-pytest-8.2.2.pyc │ ├── test_nodes.cpython-310-pytest-7.4.4.pyc │ ├── test_nodes.cpython-310-pytest-8.1.1.pyc │ └── test_nodes.cpython-310-pytest-8.2.2.pyc │ ├── test_connections.py │ ├── test_learning.py │ ├── test_monitors.py │ ├── test_network.py │ └── test_nodes.py └── tox.ini /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Black Formater 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: actions/setup-python@v2 11 | - uses: psf/black@stable -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: BindsNET build status 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Set up Python 3.13 20 | uses: 
actions/setup-python@v4 21 | with: 22 | python-version: 3.13 23 | - name: Install Poetry 24 | env: 25 | POETRY_VERSION: 2.1.2 26 | run: | 27 | curl -sSL https://install.python-poetry.org | python - -y &&\ 28 | poetry config virtualenvs.create false 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install flake8 pytest 33 | # if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 34 | poetry install 35 | - name: Lint with flake8 36 | run: | 37 | # stop the build if there are Python syntax errors or undefined names 38 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 39 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 40 | # flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 41 | - name: Test with pytest 42 | run: | 43 | pytest 44 | -------------------------------------------------------------------------------- /.github/workflows/pythonpackage.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | max-parallel: 4 11 | matrix: 12 | python-version: ["3.10", "3.11", "3.12", "3.13"] 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install Poetry 21 | env: 22 | POETRY_VERSION: 2.1.2 23 | run: | 24 | curl -sSL https://install.python-poetry.org | python - -y &&\ 25 | poetry config virtualenvs.create false 26 | - name: Install dependencies 27 | run: | 28 | poetry install 29 | - name: Format with black 30 | run: | 31 | black . 
32 | - name: Test with pytest 33 | run: | 34 | pytest 35 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | bindsnet/__pycache__/* 2 | bindsnet/analysis/__pycache__/* 3 | bindsnet/conversion/__pycache__/* 4 | bindsnet/datasets/__pycache__/* 5 | bindsnet/environment/__pycache__/* 6 | bindsnet/evaluation/__pycache__/* 7 | bindsnet/learning/__pycache__/* 8 | bindsnet/encoding/__pycache__/* 9 | bindsnet/models/__pycache__/* 10 | bindsnet/network/__pycache__/* 11 | bindsnet/pipeline/__pycache__/* 12 | bindsnet/preprocessing/__pycache__/* 13 | test/analysis/__pycache__/* 14 | test/conversion/__pycache__/* 15 | test/encoding/__pycache__/* 16 | test/import/__pycache__/* 17 | test/models/__pycache__/* 18 | test/network/__pycache__/* 19 | test/analysis/__pycache__/* 20 | *.pyc 21 | dist/* 22 | logs/* 23 | .pytest_cache/* 24 | .vscode/* 25 | data/* -------------------------------------------------------------------------------- /.pytest_cache/CACHEDIR.TAG: -------------------------------------------------------------------------------- 1 | Signature: 8a477f597d28d172789f06886806bc55 2 | # This file is a cache directory tag created by pytest. 3 | # For information about cache directory tags, see: 4 | # https://bford.info/cachedir/spec.html 5 | -------------------------------------------------------------------------------- /.pytest_cache/README.md: -------------------------------------------------------------------------------- 1 | # pytest cache directory # 2 | 3 | This directory contains data from the pytest's cache plugin, 4 | which provides the `--lf` and `--ff` options, as well as the `cache` fixture. 5 | 6 | **Do not** commit this to version control. 7 | 8 | See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information. 
9 | -------------------------------------------------------------------------------- /.pytest_cache/v/cache/lastfailed: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /.pytest_cache/v/cache/nodeids: -------------------------------------------------------------------------------- 1 | [ 2 | "test/analysis/test_analyzers.py::TestAnalyzer::test_init", 3 | "test/analysis/test_analyzers.py::TestAnalyzer::test_plot_runs", 4 | "test/conversion/test_conversion.py::test_conversion_1", 5 | "test/conversion/test_conversion.py::test_conversion_2", 6 | "test/encoding/test_encoding.py::TestEncodings::test_bernoulli", 7 | "test/encoding/test_encoding.py::TestEncodings::test_bernoulli_loader", 8 | "test/encoding/test_encoding.py::TestEncodings::test_multidim_bernoulli", 9 | "test/encoding/test_encoding.py::TestEncodings::test_poisson", 10 | "test/encoding/test_encoding.py::TestEncodings::test_poisson_loader", 11 | "test/models/test_models.py::TestDiehlAndCook2015::test_init", 12 | "test/models/test_models.py::TestTwoLayerNetwork::test_init", 13 | "test/network/test_learning.py::TestLearningRules::test_hebbian", 14 | "test/network/test_learning.py::TestLearningRules::test_mstdp", 15 | "test/network/test_learning.py::TestLearningRules::test_mstdpet", 16 | "test/network/test_learning.py::TestLearningRules::test_post_pre", 17 | "test/network/test_learning.py::TestLearningRules::test_rmax", 18 | "test/network/test_learning.py::TestLearningRules::test_weight_dependent_post_pre", 19 | "test/network/test_network.py::TestNetwork::test_add_objects", 20 | "test/network/test_network.py::TestNetwork::test_empty", 21 | "test/network/test_nodes.py::TestNodes::test_init", 22 | "test/network/test_nodes.py::TestNodes::test_transfer" 23 | ] -------------------------------------------------------------------------------- /.pytest_cache/v/cache/stepwise: 
-------------------------------------------------------------------------------- 1 | [] -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | builder: html 17 | configuration: docs/source/conf.py 18 | 19 | formats: 20 | - epub 21 | - pdf 22 | 23 | # We recommend specifying your dependencies to enable reproducible builds: 24 | # https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 25 | python: 26 | install: 27 | - requirements: docs/requirements.txt 28 | - method: pip 29 | path: docs/ 30 | # extra_requirements: 31 | # - docs 32 | 33 | # python: 34 | # version: 3.8 35 | # install: 36 | # - method: pip 37 | # path: . 38 | # - requirements: docs/requirements.txt 39 | # system_packages: False -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": false 14 | } 15 | ] 16 | } -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Protocol 2 | 3 | To clone this project locally, issue 4 | 5 | ```shell 6 | git clone https://github.com/Hananel-Hazan/bindsnet.git # clones bindsnet repository 7 | ``` 8 | 9 | in the directory of your choice. This will place the repository's code in a directory titled `bindsnet`. 10 | 11 | Install the project with [Poetry](https://python-poetry.org/) (current supported version - 1.1.8) 12 | 13 | ```shell 14 | poetry install 15 | poetry run pre-commit install 16 | ``` 17 | 18 | 19 | Now you can access the project environment with `poetry shell` or run commands with `poetry run <command>`. For example, `poetry run python examples/mnist/conv_mnist.py`. 20 | 21 | Please make sure the `Poetry` environment is activated when you commit your files! The `git commit` command will invoke `pre-commit`, which is installed with Poetry too. IDEs like PyCharm have plugins for `Poetry` and will activate the environment automatically. 22 | 23 | Run the tests, they all should pass 24 | 25 | ```shell 26 | poetry run pytest 27 | ``` 28 | 29 | All development should take place on a branch separate from master. To create a branch, issue 30 | 31 | ```shell 32 | git branch [branch-name] # create new branch 33 | ``` 34 | 35 | replacing `[branch-name]` with a simple and memorable name of choice; e.g., `git branch dan`. 
Switch to the newly created branch using 36 | 37 | ```shell 38 | git checkout [branch-name] # switch to a different branch of the repository 39 | ``` 40 | 41 | __Note__: Issue `git branch` with no arguments to list all branches currently being tracked, with an asterisk next to the currently used branch; e.g., 42 | 43 | ```shell 44 | $ git branch # list all branches and indicate current branch 45 | * dan 46 | devel 47 | hananel 48 | master 49 | ``` 50 | 51 | If new branches have been created on the remote repository, you may start tracking them with ```git pull --all```, and check them out using ```git checkout [branch-name]```, as before. ```git branch -a``` will list all locally tracked branches, as well as list all remote branches (which can be checked out!). 52 | 53 | After making changes to the repository, issue a `git status` command to see which files have been modified. Then, use 54 | 55 | ```shell 56 | git add [file-name(s) | -A] # add modified or newly created files 57 | ``` 58 | 59 | to add one or more modified files (`file-name(s)`), or all modified files (`-A` or `--all`). These include newly created files. Issue 60 | 61 | ```shell 62 | pre-commit run -a 63 | ``` 64 | 65 | to run the `pre-commit` tool that will automatically format your code with `black`. Issue 66 | 67 | ```shell 68 | git commit -m "[commit-message]" # Useful messages help when reverting / searching through history 69 | ``` 70 | 71 | to "commit" your changes to your local repository, where `[commit-message]` is a _short yet descriptive_ note about what changes have been made. 72 | 73 | Before pushing your changes to the remote repository, you must make sure that you have an up-to-date version of the `master` code. That is, if master has been updated while you have been making your changes, your code will be out of date with respect to the master branch. 
Issue 74 | 75 | ```shell 76 | git pull # gets all changes from remote repository 77 | git merge master # merges changes made in master branch with those made in your branch 78 | ``` 79 | 80 | and fix any merge conflicts that may have resulted, and re-commit after the fix with 81 | 82 | ```shell 83 | git commit # no -m message needed; merge messages are auto-generated 84 | ``` 85 | 86 | Push your changes back to the repository onto the same branch you are developing on. Issue 87 | 88 | ```shell 89 | git push [origin] [branch-name] # verbose; depends on push.default behavior settings 90 | ``` 91 | 92 | or, 93 | 94 | ```shell 95 | git push # concise; again, depends on push.default behavior 96 | ``` 97 | 98 | where `[origin]` is the name of the remote repository, and `[branch-name]` is the name of the branch you have developed on. 99 | 100 | __Note__: See [push.default](https://git-scm.com/docs/git-config#git-config-pushdefault) for more information. 101 | 102 | To merge your changes into the `master` branch (the definitive version of the project's code), open a pull request on the [webpage](https://github.com/Hananel-Hazan/bindsnet) of the project. You can select the `base` branch (typically `master`, to merge changes _into_ the definitive version of the code) and the `compare` branch (say, `dan`, if I added a new feature locally and want to add it to the project code). You may add an optional extended description of your pull request changes. If there are merge conflicts at this stage, you may fix these using GitHub's pull request review interface. 103 | 104 | Assign reviewer(s) from the group of project contributors to perform a code review of your pull request. If the reviewer(s) are happy with your changes, you may then merge it in to the `master` branch. _Code review is crucial for the development of this project_, as the whole team should be held accountable for all changes. 
105 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG DEPS=development 2 | ARG NVIDIA_30XX=false 3 | 4 | FROM nvidia/cuda:11.1-base AS base-default 5 | 6 | ARG DEBIAN_FRONTEND=noninteractive 7 | RUN apt-get update && apt-get install --no-install-recommends --no-install-suggests -y \ 8 | build-essential libgeos-dev liblzma-dev libssl-dev libbz2-dev curl vim python3.8-dev python-dev git libffi-dev \ 9 | libglib2.0-0 libsm6 libxext6 libblas-dev libatlas-base-dev ffmpeg \ 10 | && rm -rf /var/lib/apt/lists/* 11 | 12 | # install pyenv 13 | ENV PYENV_ROOT=$HOME/.pyenv 14 | ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH 15 | RUN curl -L https://github.com/pyenv/pyenv-installer/raw/master/bin/pyenv-installer | bash 16 | RUN echo 'eval "$(pyenv init -)"' >> $HOME/.bashrc 17 | 18 | # install python version specified in the .python-version file 19 | COPY .python-version . 
20 | RUN PYTHON_VERSION=$(cat .python-version) pyenv install $PYTHON_VERSION && pyenv global $PYTHON_VERSION && pyenv rehash 21 | 22 | # install poetry and our package 23 | ENV POETRY_NO_INTERACTION=1\ 24 | # send python output directory to stdout 25 | PYTHONUNBUFFERED=1\ 26 | PIP_NO_CACHE_DIR=off\ 27 | PIP_DISABLE_PIP_VERSION_CHECK=on\ 28 | PIP_DEFAULT_TIMEOUT=100\ 29 | POETRY_HOME="/opt/poetry"\ 30 | VENV_PATH="/opt/pysetup/.venv"\ 31 | 32 | # install poetry and our package 33 | ENV POETRY_NO_INTERACTION=1 \ 34 | # send python output directory to stdout 35 | PYTHONUNBUFFERED=1 \ 36 | PIP_NO_CACHE_DIR=off \ 37 | PIP_DISABLE_PIP_VERSION_CHECK=on \ 38 | PIP_DEFAULT_TIMEOUT=100 39 | 40 | ENV PATH="$POETRY_HOME/bin:$VENV_PATH/bin:$PATH" POETRY_VERSION=1.1.8 41 | 42 | RUN mkdir $HOME/opt/ && \ 43 | curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python - &&\ 44 | poetry config virtualenvs.create false 45 | 46 | WORKDIR /bindsnet 47 | 48 | RUN mkdir bindsnet && touch bindsnet/__init__.py ## empty package for Poetry to add to path 49 | COPY pyproject.toml poetry.lock README.md ./ 50 | 51 | FROM base-default AS base-production 52 | RUN poetry install --no-dev # this will only install production dependencies 53 | 54 | FROM base-default AS base-development 55 | RUN poetry install 56 | 57 | FROM base-${DEPS} AS nvidia-30xx-false 58 | RUN rm -rf $HOME/.cache/pypoetry/artifacts # remove downloaded wheels 59 | 60 | # a fix for NVIDIA 30xx GPUs 61 | FROM installed AS nvidia-30xx-true 62 | 63 | RUN python -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html 64 | 65 | FROM nvidia-30xx-${NVIDIA_30XX} AS final 66 | COPY . . 
67 | -------------------------------------------------------------------------------- /bindsnet/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from bindsnet import ( 4 | analysis, 5 | conversion, 6 | datasets, 7 | encoding, 8 | environment, 9 | evaluation, 10 | learning, 11 | models, 12 | network, 13 | pipeline, 14 | preprocessing, 15 | utils, 16 | ) 17 | 18 | ROOT_DIR = Path(__file__).parents[0].parents[0] 19 | 20 | __all__ = [ 21 | "utils", 22 | "network", 23 | "models", 24 | "analysis", 25 | "preprocessing", 26 | "datasets", 27 | "encoding", 28 | "pipeline", 29 | "learning", 30 | "evaluation", 31 | "environment", 32 | "conversion", 33 | "ROOT_DIR", 34 | ] 35 | -------------------------------------------------------------------------------- /bindsnet/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | from bindsnet.analysis import pipeline_analysis, plotting, visualization 2 | 3 | __all__ = ["plotting", "visualization", "pipeline_analysis"] 4 | -------------------------------------------------------------------------------- /bindsnet/analysis/__pycache__/__init__.cpython-310.pyc: 
"""Interactive plotting of dot-tracing CSV output (grids, rewards, performance)."""

import glob
import os
import sys

import numpy as np

# NOTE: matplotlib is imported inside the plotting functions so that input
# validation and file discovery run (and fail fast) before a GUI backend loads.

# Define grid dimensions globally
ROWS = 28
COLS = 28

# (substring looked for in the file path, file type code, human-readable label)
_FILE_TYPE_HINTS = (
    ("grid", 0, "grid"),
    ("rew", 1, "reward"),
    ("perf", 2, "performance"),
)


def _promptInt(prompt):
    """Ask on stdin until the user enters a valid integer, then return it."""
    while True:
        try:
            return int(input(prompt))
        except ValueError:
            print("Please enter a whole number.")


def _selectRange(grids):
    """
    Prompt for the iteration range to plot.

    :param grids: Number of grids available.
    :return: ``(start, end)``; entering ``0`` as start selects every grid and
        returns ``(0, grids)``.
    """
    while True:
        start = _promptInt("Start: ")
        if start == 0:
            # 0 is the "plot everything" shortcut, so the end must be the last
            # grid (previously it stayed at 1 and only one frame was drawn).
            return 0, grids
        end = _promptInt("End: ")
        if 0 < start <= grids - 1 and 1 <= end <= grids:
            return start, end
        print(f"Please choose Start in [0, {grids - 1}] and End in [1, {grids}].")


def _guessFileType(path):
    """Return the file type code hinted at by ``path``, or ``None`` if unknown."""
    for token, code, label in _FILE_TYPE_HINTS:
        # `in` also matches at index 0, which `0 < path.find(token)` missed.
        if token in path:
            print(f"\nFound '{token}' in name: assuming a {label} file type.")
            return code
    return None


def _pngPath(fname):
    """Return ``fname`` with its extension (whatever its length) replaced by .png."""
    return os.path.splitext(fname)[0] + ".png"


def plotGrids(gridData):
    """
    Interactively animate a stack of ``ROWS`` x ``COLS`` grids as heat maps.

    :param gridData: 2D array of shape ``(k * ROWS, COLS)`` holding ``k`` grids
        stacked vertically.
    :raises ValueError: If ``gridData`` is not 2D, is empty, or its shape is
        not compatible with ``ROWS`` x ``COLS`` grids.
    """
    if (
        gridData.ndim != 2
        or gridData.shape[0] == 0
        or gridData.shape[0] % ROWS != 0
        or gridData.shape[1] != COLS
    ):
        # Must be an exception instance: raising a bare string is a TypeError.
        raise ValueError(
            "Incompatible grid dimensionality: check data and assumed dimensions."
        )

    import matplotlib.pyplot as plt

    grids = gridData.shape[0] // ROWS

    print("Reshaping into", grids, "grids of shape (", ROWS, ",", COLS, ")")
    gridData = gridData.reshape((grids, ROWS, COLS))

    plotAnotherRange = True

    while plotAnotherRange:
        print("Select the range of iterations to generate grid plots from.")
        print("0 means plot all iterations.")
        start, end = _selectRange(grids)

        if start == 0:
            print("\nPlotting whole shebang!")
        else:
            print("\nPlotting range from iteration", start, "to", end)

        # Plotting time!
        plt.figure()
        plt.ion()
        plt.imshow(gridData[start], cmap="hot", interpolation="nearest")
        plt.colorbar()
        plt.pause(0.001)  # Pause so that that GUI can do its thing.
        for g in gridData[start + 1 : end]:
            plt.imshow(g, cmap="hot", interpolation="nearest")
            plt.pause(0.001)  # Pause so that that GUI can do its thing.

        plotAnotherRange = input("Plot another range? (y/n): ").strip().lower() == "y"


def plotRewards(rewData, fname):
    """
    Plot the cumulative reward per timestep and save it next to ``fname``.

    :param rewData: Sequence of per-iteration rewards.
    :param fname: Path of the source data file; the figure is saved under the
        same name with a ``.png`` extension.
    """
    import matplotlib.pyplot as plt

    cumRewards = np.cumsum(rewData)
    tsteps = np.arange(len(cumRewards))

    # Plotting time!
    plt.figure()
    plt.plot(tsteps, cumRewards)
    plt.xlabel("Timesteps")
    plt.ylabel("Cumulative Reward")
    plt.title("Cumulative Reward by Iteration")
    plt.savefig(_pngPath(fname), dpi=200)
    plt.pause(0.001)  # Pause so that that GUI can do its thing.


def plotPerformance(perfData, fname):
    """
    Plot intercepts summed over bins of 10 episodes and save it next to ``fname``.

    :param perfData: Sequence of per-episode performance values.
    :param fname: Path of the source data file; the figure is saved under the
        same name with a ``.png`` extension.
    """
    import matplotlib.pyplot as plt

    # A single-value file is loaded as a 0-d array, which has no len().
    perfData = np.atleast_1d(perfData)

    # Group episodes into consecutive bins of 10 and sum each bin.
    binIdx = np.arange(len(perfData)) // 10
    bins = np.bincount(binIdx, perfData).astype("uint32")

    # Plotting time!
    plt.figure()
    plt.bar(np.unique(binIdx), bins, color="seagreen")
    plt.xlabel("Episode Bins")
    plt.ylabel("Number of Intercepts")
    plt.title("Interception Performance Across Episodes")
    plt.savefig(_pngPath(fname), dpi=200)
    plt.pause(0.001)  # Pause so that that GUI can do its thing.


def main():
    """
    Interactively pick CSV files and plot them according to their type.

    File types:

    0) grid - the 2D matrix observation
    1) reward - list of rewards per iteration
    2) performance - list of performance values
    """
    # By default, we'll search the examples directory, but tweak as needed.
    # The pattern is relative to the current working directory.
    files = glob.glob("../../examples/*/out/*csv")

    if not files:
        print("Could not find any csv files. Exiting...")
        sys.exit()

    plotAnotherFile = True

    while plotAnotherFile:
        print("Select the file to generate grid plots from.")
        for i, f in enumerate(files):
            print(str(i), "-", f)

        # Select the intended file; valid indices are 0 .. len(files) - 1.
        sel = -1
        while sel < 0 or len(files) <= sel:
            sel = _promptInt("\nFile selection: ")

        fileToPlot = files[sel]

        # Check file type, falling back to asking the user.
        fileType = _guessFileType(fileToPlot)
        if fileType is None:
            print("\nUnknown file type. Which type are we plotting?")
            print("\n0) grid\n1) reward\n2) performance")
            fileType = -1
            while fileType < 0 or 2 < fileType:
                fileType = _promptInt("\nFile type: ")

        print("\nPlotting: ", fileToPlot)
        data = np.genfromtxt(fileToPlot, delimiter=",")

        # Plot by file type
        if fileType == 0:
            plotGrids(data)
        elif fileType == 1:
            plotRewards(data, fileToPlot)
        else:
            plotPerformance(data, fileToPlot)

        plotAnotherFile = input("Plot another file? (y/n): ").strip().lower() == "y"


if __name__ == "__main__":
    main()
class SubtractiveResetIFNodes(nodes.Nodes):
    # language=rst
    """
    Layer of integrate-and-fire (IF) neurons using reset by subtraction.

    When a neuron spikes, the threshold is subtracted from its voltage
    instead of the voltage being overwritten with a fixed value (see
    ``forward``).
    """

    def __init__(
        self,
        n: Optional[int] = None,
        shape: Optional[Iterable[int]] = None,
        traces: bool = False,
        traces_additive: bool = False,
        tc_trace: Union[float, torch.Tensor] = 20.0,
        trace_scale: Union[float, torch.Tensor] = 1.0,
        sum_input: bool = False,
        thresh: Union[float, torch.Tensor] = -52.0,
        reset: Union[float, torch.Tensor] = -65.0,
        refrac: Union[int, torch.Tensor] = 5,
        lbound: Optional[float] = None,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Instantiates a layer of IF neurons with the subtractive reset mechanism.

        :param n: The number of neurons in the layer.
        :param shape: The dimensionality of the layer.
        :param traces: Whether to record spike traces.
        :param traces_additive: Whether to record spike traces additively.
        :param tc_trace: Time constant of spike trace decay.
        :param trace_scale: Scaling factor for spike trace.
        :param sum_input: Whether to sum all inputs.
        :param thresh: Spike threshold voltage; also the amount subtracted
            from the voltage of a neuron that spikes.
        :param reset: Voltage the neurons are initialised to by
            ``reset_state_variables`` and ``set_batch_size``.
        :param refrac: Refractory (non-firing) period of the neuron.
        :param lbound: Lower bound of the voltage; ``None`` disables clipping.
        :param kwargs: Accepted but not used by this constructor (they are
            not forwarded to the parent class).
        """
        super().__init__(
            n=n,
            shape=shape,
            traces=traces,
            traces_additive=traces_additive,
            tc_trace=tc_trace,
            trace_scale=trace_scale,
            sum_input=sum_input,
        )

        self.register_buffer(
            "reset", torch.tensor(reset, dtype=torch.float)
        )  # Initial / reset voltage.
        self.register_buffer(
            "thresh", torch.tensor(thresh, dtype=torch.float)
        )  # Spike threshold voltage.
        self.register_buffer(
            "refrac", torch.tensor(refrac)
        )  # Post-spike refractory period.
        # Empty placeholders; both are given their real shape in set_batch_size.
        self.register_buffer("v", torch.FloatTensor())  # Neuron voltages.
        self.register_buffer(
            "refrac_count", torch.FloatTensor()
        )  # Refractory period counters.

        self.lbound = lbound  # Lower bound of voltage.

    def forward(self, x: torch.Tensor) -> None:
        # language=rst
        """
        Runs a single simulation step.

        The order of the steps matters: input is integrated using the
        refractory state from the previous step, the counters are then
        decremented, and only afterwards are new spikes detected.

        :param x: Inputs to the layer.
        """
        # Integrate input voltages (only for neurons that are not refractory).
        self.v += (self.refrac_count == 0).float() * x

        # Decrement refractory counters; counters already at or below zero
        # are mapped to zero by the mask.
        # NOTE(review): self.dt is not set in this class -- presumably
        # provided by the parent class / network; confirm.
        self.refrac_count = (self.refrac_count > 0).float() * (
            self.refrac_count - self.dt
        )

        # Check for spiking neurons.
        self.s = self.v >= self.thresh

        # Refractoriness and voltage reset: spiking neurons restart their
        # refractory counter and have the threshold subtracted from v.
        self.refrac_count.masked_fill_(self.s, self.refrac)
        self.v[self.s] = self.v[self.s] - self.thresh

        # Voltage clipping to lower bound.
        if self.lbound is not None:
            self.v.masked_fill_(self.v < self.lbound, self.lbound)

        super().forward(x)

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Resets relevant state variables: voltages go back to ``reset`` and
        the refractory counters to zero, after the parent's own reset.
        """
        super().reset_state_variables()
        self.v.fill_(self.reset)  # Neuron voltages.
        self.refrac_count.zero_()  # Refractory period counters.

    def set_batch_size(self, batch_size) -> None:
        # language=rst
        """
        Sets mini-batch size. Called when layer is added to a network.

        Re-creates ``v`` and ``refrac_count`` with a leading batch dimension,
        which also discards their previous contents.

        :param batch_size: Mini-batch size.
        """
        super().set_batch_size(batch_size=batch_size)
        # New tensors are created on the devices of the buffers they replace.
        self.v = self.reset * torch.ones(batch_size, *self.shape, device=self.v.device)
        self.refrac_count = torch.zeros_like(self.v, device=self.refrac_count.device)


class PassThroughNodes(nodes.Nodes):
    # language=rst
    """
    Layer of nodes that passes its input through unchanged: ``forward``
    stores the input directly as the layer's spikes ``s``.
    """

    def __init__(
        self,
        n: Optional[int] = None,
        shape: Optional[Iterable[int]] = None,
        traces: bool = False,
        traces_additive: bool = False,
        tc_trace: Union[float, torch.Tensor] = 20.0,
        trace_scale: Union[float, torch.Tensor] = 1.0,
        sum_input: bool = False,
    ) -> None:
        # language=rst
        """
        Instantiates a layer of pass-through nodes.

        :param n: The number of neurons in the layer.
        :param shape: The dimensionality of the layer.
        :param traces: Whether to record spike traces.
        :param traces_additive: Whether to record spike traces additively.
        :param tc_trace: Time constant of spike trace decay.
        :param trace_scale: Scaling factor for spike trace.
        :param sum_input: Whether to sum all inputs.
        """
        super().__init__(
            n=n,
            shape=shape,
            traces=traces,
            traces_additive=traces_additive,
            tc_trace=tc_trace,
            trace_scale=trace_scale,
            sum_input=sum_input,
        )
        # A zero voltage buffer is registered, but no method of this class
        # updates it.
        self.register_buffer("v", torch.zeros(self.shape))

    def forward(self, x: torch.Tensor) -> None:
        # language=rst
        """
        Runs a single simulation step.

        The parent ``forward`` is not called here; the input simply becomes
        the layer's spike tensor.

        :param x: Inputs to the layer.
        """
        self.s = x

    def reset_state_variables(self) -> None:
        # language=rst
        """
        Resets relevant state variables. Only ``s`` is zeroed (in place);
        the parent implementation is not invoked.
        """
        self.s.zero_()
class PermuteConnection(topology.AbstractConnection):
    # language=rst
    """
    Special-purpose connection for emulating the custom ``Permute`` module in
    spiking neural networks.
    """

    def __init__(
        self,
        source: nodes.Nodes,
        target: nodes.Nodes,
        dims: Iterable,
        nu: Optional[Union[float, Iterable[float]]] = None,
        weight_decay: float = 0.0,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Constructor for ``PermuteConnection``.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param dims: Order of dimensions to permute; passed unchanged to
            ``torch.Tensor.permute`` in ``compute``.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param weight_decay: Constant multiple to decay weights by on each
            iteration.

        Keyword arguments (forwarded to the parent constructor):

        :param function update_rule: Modifies connection parameters according
            to some rule.
        :param float wmin: The minimum value on the connection weights.
        :param float wmax: The maximum value on the connection weights.
        :param float norm: Total weight per target neuron normalization.
        """
        super().__init__(source, target, nu, weight_decay, **kwargs)

        self.dims = dims

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Permute input.

        :param s: Input.
        :return: Permuted input, converted to ``float``.
        """
        return s.permute(self.dims).float()


class ConstantPad2dConnection(topology.AbstractConnection):
    # language=rst
    """
    Special-purpose connection for emulating the ``ConstantPad2d`` PyTorch
    module in spiking neural networks.
    """

    def __init__(
        self,
        source: nodes.Nodes,
        target: nodes.Nodes,
        padding: Tuple,
        nu: Optional[Union[float, Iterable[float]]] = None,
        weight_decay: float = 0.0,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Constructor for ``ConstantPad2dConnection``.

        :param source: A layer of nodes from which the connection originates.
        :param target: A layer of nodes to which the connection connects.
        :param padding: Padding of input tensors; passed to
            ``torch.nn.functional.pad``.
        :param nu: Learning rate for both pre- and post-synaptic events.
        :param weight_decay: Constant multiple to decay weights by on each
            iteration.

        Keyword arguments (forwarded to the parent constructor):

        :param function update_rule: Modifies connection parameters according
            to some rule.
        :param float wmin: The minimum value on the connection weights.
        :param float wmax: The maximum value on the connection weights.
        :param float norm: Total weight per target neuron normalization.
        """

        super().__init__(source, target, nu, weight_decay, **kwargs)

        self.padding = padding

    def compute(self, s: torch.Tensor) -> torch.Tensor:
        # language=rst
        """
        Pad input.

        ``F.pad`` is called without a mode or fill value, so its defaults
        apply.

        :param s: Input.
        :return: Padded input, converted to ``float``.
        """
        return F.pad(s, self.padding).float()
9 | 10 | ## Tested 11 | 12 | - CIFAR10 13 | - CIFAR100 14 | - MNIST 15 | - EMNIST 16 | - KMNIST 17 | - FashionMNIST 18 | - STL10 19 | - SVHN 20 | 21 | ## Not tested 22 | 23 | - Cityscapes 24 | - CocoCaptions 25 | - CocoDetection 26 | - DatasetFolder 27 | - FakeData 28 | - Flickr30k 29 | - Flickr8k 30 | - ImageFolder 31 | - LSUN 32 | - LSUNClass 33 | - Omniglot 34 | - PhotoTour 35 | - SEMEION 36 | - SBU 37 | - VOCDetection 38 | - VOCSegmentation 39 | 40 | # SpokenMNIST 41 | 42 | File: `spoken_mnist.py` 43 | URL: https://github.com/Jakobovski/free-spoken-digit-dataset 44 | -------------------------------------------------------------------------------- /bindsnet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from bindsnet.datasets.alov300 import ALOV300 2 | from bindsnet.datasets.collate import time_aware_collate 3 | from bindsnet.datasets.dataloader import DataLoader 4 | from bindsnet.datasets.davis import Davis 5 | from bindsnet.datasets.spoken_mnist import SpokenMNIST 6 | from bindsnet.datasets.torchvision_wrapper import create_torchvision_dataset_wrapper 7 | 8 | CIFAR10 = create_torchvision_dataset_wrapper("CIFAR10") 9 | CIFAR100 = create_torchvision_dataset_wrapper("CIFAR100") 10 | Cityscapes = create_torchvision_dataset_wrapper("Cityscapes") 11 | CocoCaptions = create_torchvision_dataset_wrapper("CocoCaptions") 12 | CocoDetection = create_torchvision_dataset_wrapper("CocoDetection") 13 | DatasetFolder = create_torchvision_dataset_wrapper("DatasetFolder") 14 | EMNIST = create_torchvision_dataset_wrapper("EMNIST") 15 | FakeData = create_torchvision_dataset_wrapper("FakeData") 16 | FashionMNIST = create_torchvision_dataset_wrapper("FashionMNIST") 17 | Flickr30k = create_torchvision_dataset_wrapper("Flickr30k") 18 | Flickr8k = create_torchvision_dataset_wrapper("Flickr8k") 19 | ImageFolder = create_torchvision_dataset_wrapper("ImageFolder") 20 | KMNIST = create_torchvision_dataset_wrapper("KMNIST") 21 | 
LSUN = create_torchvision_dataset_wrapper("LSUN") 22 | LSUNClass = create_torchvision_dataset_wrapper("LSUNClass") 23 | MNIST = create_torchvision_dataset_wrapper("MNIST") 24 | Omniglot = create_torchvision_dataset_wrapper("Omniglot") 25 | PhotoTour = create_torchvision_dataset_wrapper("PhotoTour") 26 | SBU = create_torchvision_dataset_wrapper("SBU") 27 | SEMEION = create_torchvision_dataset_wrapper("SEMEION") 28 | STL10 = create_torchvision_dataset_wrapper("STL10") 29 | SVHN = create_torchvision_dataset_wrapper("SVHN") 30 | VOCDetection = create_torchvision_dataset_wrapper("VOCDetection") 31 | VOCSegmentation = create_torchvision_dataset_wrapper("VOCSegmentation") 32 | 33 | __all__ = [ 34 | "torchvision_wrapper", 35 | "create_torchvision_dataset_wrapper", 36 | "spoken_mnist", 37 | "SpokenMNIST", 38 | "davis", 39 | "Davis", 40 | "preprocess", 41 | "alov300", 42 | "ALOV300", 43 | "collate", 44 | "time_aware_collate", 45 | "dataloader", 46 | "DataLoader", 47 | "CIFAR10", 48 | "CIFAR100", 49 | "Cityscapes", 50 | "CocoCaptions", 51 | "CocoDetection", 52 | "DatasetFolder", 53 | "EMNIST", 54 | "FakeData", 55 | "FashionMNIST", 56 | "Flickr30k", 57 | "Flickr8k", 58 | "ImageFolder", 59 | "KMNIST", 60 | "LSUN", 61 | "LSUNClass", 62 | "MNIST", 63 | "Omniglot", 64 | "PhotoTour", 65 | "SBU", 66 | "SEMEION", 67 | "STL10", 68 | "SVHN", 69 | "VOCDetection", 70 | "VOCSegmentation", 71 | ] 72 | -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/datasets/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/alov300.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/datasets/__pycache__/alov300.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/collate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/datasets/__pycache__/collate.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/dataloader.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/datasets/__pycache__/dataloader.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/davis.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/datasets/__pycache__/davis.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/preprocess.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/datasets/__pycache__/preprocess.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/datasets/__pycache__/spoken_mnist.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/datasets/__pycache__/spoken_mnist.cpython-310.pyc 
def safe_worker_check():
    # language=rst
    """
    Method to check whether to use shared memory.

    :return: ``True`` when ``torch.utils.data.get_worker_info()`` reports a
        worker; if that function is unavailable, the value of
        ``pytorch_collate._use_shared_memory`` instead.
    """
    try:
        return torch.utils.data.get_worker_info() is not None
    except AttributeError:
        # Presumably a PyTorch build without ``get_worker_info``; fall back
        # to the legacy module-level flag. Only the missing attribute is
        # caught so that unrelated errors are no longer swallowed.
        return pytorch_collate._use_shared_memory


def time_aware_collate(batch):
    # language=rst
    """
    Puts each data field into a tensor with dimensions ``[time, batch size, ...]``

    Interpretation of dimensions being input:
    -  0 dim (,) - (1, batch_size, 1)
    -  1 dim (time,) - (time, batch_size, 1)
    - >2 dim (time, n_0, ...) - (time, batch_size, n_0, ...)

    Non-tensor samples are handled recursively: numpy arrays are converted to
    tensors first, numbers become a 1-D tensor, strings are returned as the
    unchanged batch, and mappings / namedtuples / sequences are collated
    field by field.

    :param batch: Non-empty sequence of samples of the same type.
    :return: The collated batch.
    :raises TypeError: If the sample type is not supported.
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        # catch 0 and 1 dimension cases and view as specified
        if elem.dim() == 0:
            batch = [x.view((1, 1)) for x in batch]
        elif elem.dim() == 1:
            batch = [x.view((x.shape[0], 1)) for x in batch]

        out = None
        if safe_worker_check():
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        # Stack along dimension 1 so that time stays in dimension 0.
        return torch.stack(batch, 1, out=out)
    elif (
        elem_type.__module__ == "numpy"
        and elem_type.__name__ != "str_"
        and elem_type.__name__ != "string_"
    ):
        if elem_type.__name__ == "ndarray":
            # array of string classes and object
            if (
                pytorch_collate.np_str_obj_array_pattern.search(elem.dtype.str)
                is not None
            ):
                raise TypeError(
                    pytorch_collate.default_collate_err_msg_format.format(elem.dtype)
                )

            return time_aware_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        # A str is itself a Sequence of one-character strs, so letting it
        # reach the Sequence branch below would recurse without end.
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: time_aware_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, "_fields"):  # namedtuple
        return elem_type(*(time_aware_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        transposed = zip(*batch)
        return [time_aware_collate(samples) for samples in transposed]

    raise TypeError(pytorch_collate.default_collate_err_msg_format.format(elem_type))
class DataLoader(torch.utils.data.DataLoader):
    # language=rst
    """
    ``torch.utils.data.DataLoader`` whose default ``collate_fn`` is
    ``time_aware_collate``.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        collate_fn=time_aware_collate,
        pin_memory=False,
        drop_last=False,
        timeout=0,
        worker_init_fn=None,
        **kwargs,
    ):
        # language=rst
        """
        Forwards every argument to ``torch.utils.data.DataLoader``.

        :param dataset: Dataset to load from.
        :param collate_fn: Function merging samples into a batch; defaults to
            ``time_aware_collate``.
        :param kwargs: Any further keyword arguments accepted by the parent
            ``DataLoader``; passed through unchanged.

        The remaining parameters are handed to the parent constructor as
        keyword arguments of the same name.
        """
        super().__init__(
            dataset,
            sampler=sampler,
            shuffle=shuffle,
            batch_size=batch_size,
            drop_last=drop_last,
            pin_memory=pin_memory,
            timeout=timeout,
            num_workers=num_workers,
            worker_init_fn=worker_init_fn,
            batch_sampler=batch_sampler,
            collate_fn=collate_fn,
            **kwargs,
        )
def create_torchvision_dataset_wrapper(ds_type):
    # language=rst
    """
    Creates wrapper classes for datasets that output ``(image, label)`` from
    ``__getitem__``. This applies to all of the datasets inside of ``torchvision``.

    :param ds_type: Dataset class to wrap, or the name of an attribute of
        ``torchvision.datasets`` to look the class up by.
    :return: A subclass of the dataset class whose ``__getitem__`` returns a
        dictionary instead of an ``(image, label)`` pair.
    """
    # isinstance also accepts str subclasses, unlike an exact type comparison.
    if isinstance(ds_type, str):
        ds_type = getattr(torchDB, ds_type)

    class TorchvisionDatasetWrapper(ds_type):
        # The conditional expression binds more loosely than "+": the wrapper
        # text (plus the class repr) is used only when the wrapped class has
        # no docstring; otherwise the wrapped class's docstring is reused
        # as-is. The inner parentheses make that grouping explicit.
        __doc__ = (
            (
                """BindsNET torchvision dataset wrapper for:

        The core difference is the output of __getitem__ is no longer
        (image, label) rather a dictionary containing the image, label,
        and their encoded versions if encoders were provided.

            \n\n"""
                + str(ds_type)
            )
            if ds_type.__doc__ is None
            else ds_type.__doc__
        )

        def __init__(
            self,
            image_encoder: Optional[Encoder] = None,
            label_encoder: Optional[Encoder] = None,
            *args,
            **kwargs,
        ):
            # language=rst
            """
            Constructor for the BindsNET torchvision dataset wrapper.
            For details on the dataset you're interested in visit

            https://pytorch.org/docs/stable/torchvision/datasets.html

            :param image_encoder: Spike encoder for use on the image
            :param label_encoder: Spike encoder for use on the label
            :param *args: Arguments for the original dataset
            :param **kwargs: Keyword arguments for the original dataset
            """
            super().__init__(*args, **kwargs)

            self.args = args
            self.kwargs = kwargs

            # Allow the passthrough of None, but change to NullEncoder
            if image_encoder is None:
                image_encoder = NullEncoder()

            if label_encoder is None:
                label_encoder = NullEncoder()

            self.image_encoder = image_encoder
            self.label_encoder = label_encoder

        def __getitem__(self, ind: int) -> Dict[str, torch.Tensor]:
            # language=rst
            """
            Utilizes the ``torchvision.dataset`` parent class to grab the data, then
            encodes using the supplied encoders.

            :param int ind: Index to grab data at.
            :return: Dictionary with the keys ``image``, ``label``,
                ``encoded_image`` and ``encoded_label``.
            """
            image, label = super().__getitem__(ind)

            output = {
                "image": image,
                "label": label,
                "encoded_image": self.image_encoder(image),
                "encoded_label": self.label_encoder(label),
            }

            return output

        def __len__(self):
            return super().__len__()

    return TorchvisionDatasetWrapper
class Encoder:
    # language=rst
    """
    Base class for spike encoding transforms.

    An instance remembers the positional and keyword arguments it was built
    with and, when called, hands them to ``self.enc`` after the datum.
    Subclasses are expected to set ``self.enc`` to a callable accepting
    ``(datum, *args, **kwargs)``; the base class itself does not define it.
    """

    def __init__(self, *args, **kwargs) -> None:
        self.enc_kwargs = kwargs
        self.enc_args = args

    def __call__(self, img):
        encode = self.enc
        return encode(img, *self.enc_args, **self.enc_kwargs)


class NullEncoder(Encoder):
    # language=rst
    """
    Identity transform: calling it returns the input datum unchanged.

    .. note::
        This is not a real spike encoder. Be careful with the usage of this class.
    """

    def __init__(self):
        super().__init__()

    def __call__(self, img):
        return img


class SingleEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, sparsity: float = 0.5, **kwargs):
        # language=rst
        """
        Builds a callable encoder backed by ``bindsnet.encoding.single``.

        :param time: Length of single spike train per input variable.
        :param dt: Simulation time step.
        :param sparsity: Sparsity of the input representation. 0 for no spikes and 1 for
            all spikes.
        :param kwargs: Extra keyword arguments stored and passed on to the
            encoding function.
        """
        self.enc = encodings.single
        stored = {"dt": dt, "sparsity": sparsity, **kwargs}
        super().__init__(time, **stored)


class RepeatEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, **kwargs):
        # language=rst
        """
        Builds a callable encoder backed by ``bindsnet.encoding.repeat``.

        :param time: Length of repeat spike train per input variable.
        :param dt: Simulation time step.
        :param kwargs: Extra keyword arguments stored and passed on to the
            encoding function.
        """
        self.enc = encodings.repeat
        stored = {"dt": dt, **kwargs}
        super().__init__(time, **stored)


class BernoulliEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, **kwargs):
        # language=rst
        """
        Builds a callable encoder backed by ``bindsnet.encoding.bernoulli``.

        :param time: Length of Bernoulli spike train per input variable.
        :param dt: Simulation time step.

        Keyword arguments:

        :param float max_prob: Maximum probability of spike per time step.
        """
        self.enc = encodings.bernoulli
        stored = {"dt": dt, **kwargs}
        super().__init__(time, **stored)


class PoissonEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, approx: bool = False, **kwargs):
        # language=rst
        """
        Builds a callable encoder backed by ``bindsnet.encoding.poisson``.

        :param time: Length of Poisson spike train per input variable.
        :param dt: Simulation time step.
        :param approx: Bool: use alternate faster, less accurate computation.
        :param kwargs: Extra keyword arguments stored and passed on to the
            encoding function.
        """
        self.enc = encodings.poisson
        stored = {"dt": dt, "approx": approx, **kwargs}
        super().__init__(time, **stored)


class RankOrderEncoder(Encoder):
    def __init__(self, time: int, dt: float = 1.0, **kwargs):
        # language=rst
        """
        Builds a callable encoder backed by ``bindsnet.encoding.rank_order``.

        :param time: Length of RankOrder spike train per input variable.
        :param dt: Simulation time step.
        :param kwargs: Extra keyword arguments stored and passed on to the
            encoding function.
        """
        self.enc = encodings.rank_order
        stored = {"dt": dt, **kwargs}
        super().__init__(time, **stored)
22 | :return: Tensors of shape ``[time, n_1, ..., n_k]`` of Bernoulli-distributed spikes. 23 | 24 | Keyword arguments: 25 | 26 | :param float max_prob: Maximum probability of spike per Bernoulli trial. 27 | """ 28 | # Setting kwargs. 29 | max_prob = kwargs.get("dt", 1.0) 30 | 31 | for i in range(len(data)): 32 | # Encode datum as Bernoulli spike trains. 33 | yield bernoulli(datum=data[i], time=time, dt=dt, max_prob=max_prob) 34 | 35 | 36 | def poisson_loader( 37 | data: Union[torch.Tensor, Iterable[torch.Tensor]], 38 | time: int, 39 | dt: float = 1.0, 40 | **kwargs, 41 | ) -> Iterator[torch.Tensor]: 42 | # language=rst 43 | """ 44 | Lazily invokes ``bindsnet.encoding.poisson`` to iteratively encode a sequence of 45 | data. 46 | 47 | :param data: Tensor of shape ``[n_samples, n_1, ..., n_k]``. 48 | :param time: Length of Poisson spike train per input variable. 49 | :param dt: Simulation time step. 50 | :return: Tensors of shape ``[time, n_1, ..., n_k]`` of Poisson-distributed spikes. 51 | """ 52 | for i in range(len(data)): 53 | # Encode datum as Poisson spike trains. 54 | yield poisson(datum=data[i], time=time, dt=dt) 55 | 56 | 57 | def rank_order_loader( 58 | data: Union[torch.Tensor, Iterable[torch.Tensor]], 59 | time: int, 60 | dt: float = 1.0, 61 | **kwargs, 62 | ) -> Iterator[torch.Tensor]: 63 | # language=rst 64 | """ 65 | Lazily invokes ``bindsnet.encoding.rank_order`` to iteratively encode a sequence of 66 | data. 67 | 68 | :param data: Tensor of shape ``[n_samples, n_1, ..., n_k]``. 69 | :param time: Length of rank order-encoded spike train per input variable. 70 | :param dt: Simulation time step. 71 | :return: Tensors of shape ``[time, n_1, ..., n_k]`` of rank order-encoded spikes. 72 | """ 73 | for i in range(len(data)): 74 | # Encode datum as rank order-encoded spike trains. 
75 | yield rank_order(datum=data[i], time=time, dt=dt) 76 | -------------------------------------------------------------------------------- /bindsnet/environment/README.md: -------------------------------------------------------------------------------- 1 | ## Dot Simulator 2 | 3 | ### Overview 4 | 5 | This simulator lets us generate dots and make them move in a configurable 2D space, providing a visual to a neural network for training in experiments. 6 | 7 | Specifically, this generates a grid for each timestep, where a specified number of points have values of 1 with fading tails ("decay"), designating the current positions and movements of their corresponding dots. All other points are set to 0. From timestep to timestep, the dots either remain where they are or move one space. 8 | 9 | The 2D observation of the current state is provided every step, as well as the reward, completion flag, and successful interception flag. It may be helpful to scale the grid values when encoding them as spike trains. 10 | 11 | The intended objective is to train a network to use its "network dot" to trace or intercept a moving "target" dot. But this simulator is designed to easily adapt to multiple kinds of experiments. 12 | 13 | 14 | ### Dot Movement 15 | 16 | By default, there is a single "target" dot that moves in a random direction every timestep (or it can stay still, which can be disabled), and as it moves, it leaves a tunable "decay" in the form of a fading tail. The simulator supports four directions of movement (up/down/left/right) by default, as well as remaining still, but the diag parameter allows diagonal movement for more complexity. The rate of the target's randomized movement can also be modified (i.e., a random direction every timestep, or a change of direction only every so often). 17 | 18 | The simulator supports multiple bounds-handling schemes. By default, dots will simply not move past the edges. 
Alternatively, the bound_hand parameter can be set to 'bounce', for a geometric reflection off the edges, or 'trans' which will have a mirrored result: a geometric translation to the opposite side of the grid. 19 | 20 | To add further complexity, additional targets can be added as desired via the dots parameter, and the herrs parameter can be set to generate multiple "red herrings" as distraction dots. The speed of the dots' movements can also be set; it is 1 by default. 21 | 22 |

23 | DotTraceSample 24 |

25 | >The grid visuals provided by the render function will double the value of the network dot; this is a visual aid only, invisible to the network. 26 | 27 | 28 | ### Reward Functions 29 | 30 | This simulator supports multiple reward functions (aka. fitness functions): 31 | - Euclidean (fit_func='euc'): the default option, this function computes the Euclidean (aka. Pythagorean) distance between the network dot and the target dot. 32 | - Displacement (fit_func='disp'): this option computes the x,y displacement of the network dot with respect to the target dot, returning an x,y tuple. Currently, BindsNET only supports single reward values. To use this one, either be creative or update the network code... 33 | - Range Rings (fit_func='rng'): this option uses the Euclidean distance and groups it into range rings. The radial distance of the range rings can be set by the ring_size parameter. 34 | - Directional (fit_func='dir'): the directional option checks to see if the network's decision moved its dot closer, laterally, or further away from the target dot's prior position (ie. before applying movement this timestep) and returns a +1, 0, or -1 accordingly. 35 | 36 | Additionally, upon a successful intercept, the network will receive +10 if the bullseye parameter is active, and its dot will be teleported to another random location if the teleport parameter is active. 37 | 38 | >In the event multiple target dots are generated, the fitness functions only compute rewards with respect to the first target dot. 39 | 40 | 41 | ### Additional Features 42 | 43 | The environment can take a seed for random number generation in python, numpy, and Pytorch; otherwise, it will generate and save a new seed based on the current system time. 44 | 45 | As this simulator was developed in Anaconda Spyder on Windows, it can be run from Windows or Linux. 
Since environments handle plotting differently, and experiments can sometimes be terminated prematurely, this environment supports the recording of grid observations in text files and post-op plotting. Live rendering can also be disabled via the mute parameter, and a text-based alternative using pandas dataframe formatting can be enabled via the pandas parameter. 46 | 47 | Filenames and file paths can be specified for recording grid observations. By default, the filenames will be "grid" followed by "s#_$.csv" where # is the random seed used and $ is the current file number. addFileSuffix(suffix) adds the provided suffix (typically used for "train" or "test") to the filename, and changeFileSuffix(sFrom, sTo) will find sFrom in the filename and replace it with sTo. 48 | 49 | To ensure that files do not become too large to either be saved or be practically useful, cycleOutFiles(newInt) can be used to cycle the current save file, incrementing the file number suffix, or resetting it if newInt is set to a positive number. 50 | 51 | Post-op plotting is supported by dotTrace_plotter.py in the analysis directory. By default, this tool searches the examples directory for csvs in "out" directories, but that path can be easily changed. It supports plotting ranges of grid observations, reward plots, and performance plots. See below for an example of recording reward and performance data for plotting purposes. 52 | 53 | 54 | ### Example 55 | See dot_tracing.py for an example in using the Dot Simulator for training an SNN in BindsNET. 56 | 57 | dot_tracing trains a basic RNN network on the dot simulator and demonstrates how to record reward and performance data (if desired) and plot spiking activity via monitors. 
58 | 59 | 60 | -------------------------------------------------------------------------------- /bindsnet/environment/__init__.py: -------------------------------------------------------------------------------- 1 | from bindsnet.environment.environment import Environment, GymEnvironment 2 | 3 | __all__ = ["Environment", "GymEnvironment"] 4 | -------------------------------------------------------------------------------- /bindsnet/environment/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/environment/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/environment/__pycache__/environment.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/environment/__pycache__/environment.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from bindsnet.evaluation.evaluation import ( 2 | all_activity, 3 | assign_labels, 4 | logreg_fit, 5 | logreg_predict, 6 | ngram, 7 | proportion_weighting, 8 | update_ngram_scores, 9 | ) 10 | 11 | __all__ = [ 12 | "assign_labels", 13 | "logreg_fit", 14 | "logreg_predict", 15 | "all_activity", 16 | "proportion_weighting", 17 | "ngram", 18 | "update_ngram_scores", 19 | ] 20 | -------------------------------------------------------------------------------- /bindsnet/evaluation/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/evaluation/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/evaluation/__pycache__/evaluation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/evaluation/__pycache__/evaluation.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/learning/__init__.py: -------------------------------------------------------------------------------- 1 | from bindsnet.learning.learning import ( 2 | MSTDP, 3 | MSTDPET, 4 | Hebbian, 5 | LearningRule, 6 | NoOp, 7 | PostPre, 8 | Rmax, 9 | WeightDependentPostPre, 10 | ) 11 | 12 | __all__ = [ 13 | "LearningRule", 14 | "NoOp", 15 | "PostPre", 16 | "WeightDependentPostPre", 17 | "Hebbian", 18 | "MSTDP", 19 | "MSTDPET", 20 | "Rmax", 21 | ] 22 | -------------------------------------------------------------------------------- /bindsnet/learning/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/learning/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/learning/__pycache__/learning.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/learning/__pycache__/learning.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/learning/__pycache__/reward.cpython-310.pyc: 
class AbstractReward(ABC):
    # language=rst
    """
    Interface shared by all reward computation schemes.
    """

    @abstractmethod
    def compute(self, **kwargs) -> None:
        # language=rst
        """
        Computes/modifies the reward. Must be provided by subclasses.
        """

    @abstractmethod
    def update(self, **kwargs) -> None:
        # language=rst
        """
        Refreshes the internal variables used to modify the reward. Usually called
        once per episode. Must be provided by subclasses.
        """


class MovingAvgRPE(AbstractReward):
    # language=rst
    """
    Reward prediction error (RPE) computed against an exponential moving average
    (EMA) of previously observed rewards.
    """

    def __init__(self, **kwargs) -> None:
        # language=rst
        """
        Starts all reward predictions at zero with an empty history.
        """
        # Predicted reward (per step).
        self.reward_predict = torch.tensor(0.0)
        # Predicted reward per episode.
        self.reward_predict_episode = torch.tensor(0.0)
        # History of per-episode predictions (used for plotting).
        self.rewards_predict_episode = []

    def compute(self, **kwargs) -> torch.Tensor:
        # language=rst
        """
        Returns the difference between the given reward and the per-step prediction.

        Keyword arguments:

        :param Union[float, torch.Tensor] reward: Current reward.
        :return: Reward prediction error.
        """
        return kwargs["reward"] - self.reward_predict

    def update(self, **kwargs) -> None:
        # language=rst
        """
        Folds one finished episode into both EMAs. Called once per episode.

        Keyword arguments:

        :param Union[float, torch.Tensor] accumulated_reward: Reward accumulated over
            one episode.
        :param int steps: Steps in that episode.
        :param float ema_window: Width of the averaging window (default ``10.0``).
        """
        # Unpack keyword arguments.
        episode_total = kwargs["accumulated_reward"]
        n_steps = torch.tensor(kwargs["steps"]).float()
        window = torch.tensor(kwargs.get("ema_window", 10.0))

        # Average reward per step over the episode.
        step_mean = episode_total / n_steps

        # Weights of the previous estimate and of the new observation.
        keep = 1 - 1 / window
        blend = 1 / window

        # Per-step EMA.
        self.reward_predict = keep * self.reward_predict + blend * step_mean

        # Per-episode EMA, plus a plain-float copy for the plotting history.
        self.reward_predict_episode = (
            keep * self.reward_predict_episode + blend * episode_total
        )
        self.rewards_predict_episode.append(self.reward_predict_episode.item())
-------------------------------------------------------------------------------- /bindsnet/models/__pycache__/models.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/models/__pycache__/models.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/network/__init__.py: -------------------------------------------------------------------------------- 1 | from bindsnet.network import monitors, nodes, topology 2 | from bindsnet.network.network import Network, load 3 | 4 | __all__ = ["Network", "load", "nodes", "topology", "monitors"] 5 | -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/network/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/monitors.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/network/__pycache__/monitors.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/network.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/network/__pycache__/network.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/nodes.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/network/__pycache__/nodes.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/network/__pycache__/topology.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/network/__pycache__/topology.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from bindsnet.pipeline import action 2 | from bindsnet.pipeline.base_pipeline import BasePipeline 3 | from bindsnet.pipeline.dataloader_pipeline import ( 4 | DataLoaderPipeline, 5 | TorchVisionDatasetPipeline, 6 | ) 7 | from bindsnet.pipeline.environment_pipeline import EnvironmentPipeline 8 | 9 | __all__ = [ 10 | "EnvironmentPipeline", 11 | "BasePipeline", 12 | "DataLoaderPipeline", 13 | "TorchVisionDatasetPipeline", 14 | "action", 15 | ] 16 | -------------------------------------------------------------------------------- /bindsnet/pipeline/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/pipeline/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/pipeline/__pycache__/action.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/pipeline/__pycache__/action.cpython-310.pyc 
class DataLoaderPipeline(BasePipeline):
    # language=rst
    """
    A generic ``DataLoader`` pipeline that leverages the ``torch.utils.data`` setup.
    This still needs to be subclassed for specific implementations for functions given
    the dataset that will be used. An example can be seen in
    ``TorchVisionDatasetPipeline``.
    """

    def __init__(
        self,
        network: Network,
        train_ds: Dataset,
        test_ds: Optional[Dataset] = None,
        **kwargs,
    ) -> None:
        # language=rst
        """
        Initializes the pipeline.

        :param network: Arbitrary ``network`` object.
        :param train_ds: Arbitrary ``torch.utils.data.Dataset`` object.
        :param test_ds: Arbitrary ``torch.utils.data.Dataset`` object.

        Keyword arguments:

        :param int num_epochs: Number of training epochs (default ``10``).
        :param int batch_size: Samples per batch (default ``1``).
        :param int num_workers: Worker count passed to the ``DataLoader``
            (default ``0``).
        :param bool pin_memory: ``pin_memory`` flag passed to the ``DataLoader``
            (default ``True``).
        :param bool shuffle: ``shuffle`` flag passed to the ``DataLoader``
            (default ``True``).
        """
        # All keyword arguments are also forwarded to the base pipeline.
        super().__init__(network, **kwargs)

        self.train_ds = train_ds
        self.test_ds = test_ds

        self.num_epochs = kwargs.get("num_epochs", 10)
        self.batch_size = kwargs.get("batch_size", 1)
        self.num_workers = kwargs.get("num_workers", 0)
        self.pin_memory = kwargs.get("pin_memory", True)
        self.shuffle = kwargs.get("shuffle", True)

    def train(self) -> None:
        # language=rst
        """
        Training loop that runs for the set number of epochs and creates a new
        ``DataLoader`` at each epoch.
        """
        for epoch in range(self.num_epochs):
            train_dataloader = DataLoader(
                self.train_ds,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                shuffle=self.shuffle,
            )

            # Ceiling division: the previous floor division under-reported the
            # progress-bar total by one whenever the dataset size is not a multiple
            # of the batch size. This assumes the loader keeps the final partial
            # batch (``drop_last`` is not passed above) -- confirm against
            # ``bindsnet.datasets.DataLoader``.
            num_batches = (
                len(self.train_ds) + self.batch_size - 1
            ) // self.batch_size

            for batch in tqdm(
                train_dataloader,
                desc="Epoch %d/%d" % (epoch + 1, self.num_epochs),
                total=num_batches,
            ):
                self.step(batch)

    def test(self) -> None:
        # language=rst
        """
        Placeholder; subclasses must supply their own test routine.

        :raises NotImplementedError: Always.
        """
        raise NotImplementedError("You need to provide a test function.")
115 | 116 | :param batch: A dictionary of the current batch. Includes image, label and 117 | encoded versions. 118 | """ 119 | self.network.reset_state_variables() 120 | inputs = {self.input_layer: batch["encoded_image"]} 121 | self.network.run(inputs, time=batch["encoded_image"].shape[0]) 122 | 123 | def init_fn(self) -> None: 124 | pass 125 | 126 | def plots(self, batch: Dict[str, torch.Tensor], *args) -> None: 127 | # language=rst 128 | """ 129 | Create any plots and logs for a step given the input batch. 130 | 131 | :param batch: A dictionary of the current batch. Includes image, label and 132 | encoded versions. 133 | """ 134 | if self.pipeline_analyzer is not None: 135 | self.pipeline_analyzer.plot_obs( 136 | batch["encoded_image"][0, ...].sum(0), step=self.step_count 137 | ) 138 | 139 | self.pipeline_analyzer.plot_spikes( 140 | self.get_spike_data(), step=self.step_count 141 | ) 142 | 143 | vr, tv = self.get_voltage_data() 144 | self.pipeline_analyzer.plot_voltages(vr, tv, step=self.step_count) 145 | 146 | self.pipeline_analyzer.finalize_step() 147 | 148 | def test_step(self): 149 | pass 150 | -------------------------------------------------------------------------------- /bindsnet/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .preprocessing import AbstractPreprocessor 2 | 3 | __all__ = ["AbstractPreprocessor"] 4 | -------------------------------------------------------------------------------- /bindsnet/preprocessing/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/bindsnet/preprocessing/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /bindsnet/preprocessing/__pycache__/preprocessing.cpython-310.pyc: 
class AbstractPreprocessor(ABC):
    # language=rst
    """
    Abstract base class for Preprocessor.

    Subclasses implement ``_process``. ``process`` wraps it with an optional on-disk
    cache that is validated by an MD5 digest of the input file's text combined with
    the preprocessor's class name.
    """

    def process(
        self,
        csvfile: str,
        use_cache: bool = True,
        cachedfile: str = "./processed/data.pt",
    ) -> torch.Tensor:
        # language=rst
        """
        Preprocesses ``csvfile``, reusing a previously cached result when possible.

        :param csvfile: File to load raw data from.
        :param use_cache: Whether to read from / write to ``cachedfile``.
        :param cachedfile: Path of the pickle file that stores the cached result.
        :return: Whatever ``_process`` stored under ``cache["data"]``.
        """
        # cache dictionary for storing encodings if previously encoded
        cache = {"verify": "", "data": None}

        # if caching is requested
        if use_cache:
            # generate a hash
            cache["verify"] = self.__gen_hash(csvfile)

            # compare hash, if valid return cached value
            if self.__check_file(cachedfile, cache):
                return cache["data"]

        # otherwise process the data
        self._process(csvfile, cache)

        # save if use_cache
        if use_cache:
            self.__save(cachedfile, cache)

        # return data
        return cache["data"]

    @abstractmethod
    def _process(self, filename: str, cache: dict):
        # language=rst
        """
        Method for defining how to preprocess the data.

        :param filename: File to load raw data from.
        :param cache: Dictionary for caching 'data' needs to be updated for caching to
            work.
        """

    def __gen_hash(self, filename: str) -> str:
        # language=rst
        """
        Generates an hash for a csv file and the preprocessor name.

        :param filename: File to generate hash for.
        :return: Hash for the csv file.
        """
        # read the whole file as text
        with open(filename, "r") as f:
            contents = f.read()

        # MD5 is used purely as a change detector here, not for security
        pre = contents + str(self.__class__.__name__)
        m = hashlib.md5(pre.encode("utf-8"))
        return m.hexdigest()

    @staticmethod
    def __check_file(cachedfile: str, cache: dict) -> bool:
        # language=rst
        """
        Compares the csv file and the saved file to see if a new encoding needs to be
        generated.

        :param cachedfile: The filename of the cached data.
        :param cache: Dictionary containing the current csv file hash. This is updated
            if the cache file has valid data.
        :return: Whether the cache is valid.
        """
        # try opening the cached file
        # NOTE(security): ``pickle.load`` can execute arbitrary code; the cache file
        # must come from a trusted location.
        try:
            with open(cachedfile, "rb") as f:
                temp = pickle.load(f)
        except FileNotFoundError:
            temp = {"verify": "", "data": None}
        except (pickle.UnpicklingError, EOFError):
            # A truncated or corrupt cache is a cache miss; it gets regenerated and
            # overwritten instead of crashing every subsequent run.
            return False

        # anything that is not a matching cache record counts as a miss
        if not isinstance(temp, dict) or temp.get("verify") != cache["verify"]:
            return False

        # the hash matches up, keep the data from the cache
        cache["data"] = temp["data"]
        return True

    @staticmethod
    def __save(filename: str, data: dict) -> None:
        # language=rst
        """
        Creates or overwrites existing encoding file.

        :param filename: Filename to save to.
        :param data: Cache dictionary to pickle.
        """
        # Create missing parent directories. A bare filename has an empty dirname,
        # and ``os.makedirs("")`` raises, so skip creation in that case.
        directory = os.path.dirname(filename)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # save file
        with open(filename, "wb") as f:
            pickle.dump(data, f)
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/UML.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/docs/UML.png -------------------------------------------------------------------------------- /docs/directory_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/docs/directory_structure.png -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/docs/logo.png -------------------------------------------------------------------------------- /docs/make.bat.old: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=bindsnet 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/docs/pipeline.png -------------------------------------------------------------------------------- /docs/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "bindsnet_docs" 3 | dynamic = ["version"] 4 | dependencies = [ 5 | "sphinx==7.2.6", 6 | "sphinx_rtd_theme==1.3.0", 7 | "readthedocs-sphinx-search==0.3.2", 8 | "imagecodecs == 2023.9.18", 9 | "Jinja2 == 3.1.6", 10 | "wheel == 0.41.3", 11 | ] 12 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | # Defining the exact version will make sure things don't break 2 | #sphinx==6.2.1 3 | #sphinx_rtd_theme==1.2.2 4 | #readthedocs-sphinx-search==0.1.1 5 | #imagecodecs == 2023.9.18 6 | #Jinja2 == 3.1.6 7 | 8 | sphinx==7.2.6 9 | sphinx_rtd_theme==1.3.0 10 | readthedocs-sphinx-search==0.3.2 11 | imagecodecs == 2023.9.18 12 | Jinja2 == 3.1.6 13 | wheel == 0.41.3 14 | -------------------------------------------------------------------------------- /docs/source/bindsnet.analysis.rst: -------------------------------------------------------------------------------- 1 | bindsnet.analysis package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.analysis.pipeline\_analysis module 8 | ------------------------------------------- 9 
| 10 | .. automodule:: bindsnet.analysis.pipeline_analysis 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bindsnet.analysis.plotting module 16 | --------------------------------- 17 | 18 | .. automodule:: bindsnet.analysis.plotting 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bindsnet.analysis.visualization module 24 | -------------------------------------- 25 | 26 | .. automodule:: bindsnet.analysis.visualization 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: bindsnet.analysis 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /docs/source/bindsnet.conversion.rst: -------------------------------------------------------------------------------- 1 | bindsnet.conversion package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.conversion.conversion module 8 | ------------------------------------- 9 | 10 | .. automodule:: bindsnet.conversion.conversion 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bindsnet.conversion.nodes module 16 | -------------------------------- 17 | 18 | .. automodule:: bindsnet.conversion.nodes 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bindsnet.conversion.topology module 24 | ----------------------------------- 25 | 26 | .. automodule:: bindsnet.conversion.topology 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. 
automodule:: bindsnet.conversion 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /docs/source/bindsnet.datasets.rst: -------------------------------------------------------------------------------- 1 | bindsnet.datasets package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.datasets.alov300 module 8 | -------------------------------- 9 | 10 | .. automodule:: bindsnet.datasets.alov300 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bindsnet.datasets.collate module 16 | -------------------------------- 17 | 18 | .. automodule:: bindsnet.datasets.collate 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bindsnet.datasets.dataloader module 24 | ----------------------------------- 25 | 26 | .. automodule:: bindsnet.datasets.dataloader 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | bindsnet.datasets.davis module 32 | ------------------------------ 33 | 34 | .. automodule:: bindsnet.datasets.davis 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | bindsnet.datasets.preprocess module 40 | ----------------------------------- 41 | 42 | .. automodule:: bindsnet.datasets.preprocess 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | bindsnet.datasets.spoken\_mnist module 48 | -------------------------------------- 49 | 50 | .. automodule:: bindsnet.datasets.spoken_mnist 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | bindsnet.datasets.torchvision\_wrapper module 56 | --------------------------------------------- 57 | 58 | .. automodule:: bindsnet.datasets.torchvision_wrapper 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | 64 | Module contents 65 | --------------- 66 | 67 | .. 
automodule:: bindsnet.datasets 68 | :members: 69 | :undoc-members: 70 | :show-inheritance: 71 | -------------------------------------------------------------------------------- /docs/source/bindsnet.encoding.rst: -------------------------------------------------------------------------------- 1 | bindsnet.encoding package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.encoding.encoders module 8 | --------------------------------- 9 | 10 | .. automodule:: bindsnet.encoding.encoders 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bindsnet.encoding.encodings module 16 | ---------------------------------- 17 | 18 | .. automodule:: bindsnet.encoding.encodings 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bindsnet.encoding.loaders module 24 | -------------------------------- 25 | 26 | .. automodule:: bindsnet.encoding.loaders 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. automodule:: bindsnet.encoding 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /docs/source/bindsnet.environment.rst: -------------------------------------------------------------------------------- 1 | bindsnet.environment package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.environment.environment module 8 | --------------------------------------- 9 | 10 | .. automodule:: bindsnet.environment.environment 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. 
automodule:: bindsnet.environment 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/bindsnet.evaluation.rst: -------------------------------------------------------------------------------- 1 | bindsnet.evaluation package 2 | =========================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.evaluation.evaluation module 8 | ------------------------------------- 9 | 10 | .. automodule:: bindsnet.evaluation.evaluation 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: bindsnet.evaluation 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/bindsnet.learning.rst: -------------------------------------------------------------------------------- 1 | bindsnet.learning package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.learning.learning module 8 | --------------------------------- 9 | 10 | .. automodule:: bindsnet.learning.learning 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bindsnet.learning.reward module 16 | ------------------------------- 17 | 18 | .. automodule:: bindsnet.learning.reward 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: bindsnet.learning 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/bindsnet.models.rst: -------------------------------------------------------------------------------- 1 | bindsnet.models package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.models.models module 8 | ----------------------------- 9 | 10 | .. 
automodule:: bindsnet.models.models 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: bindsnet.models 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/bindsnet.network.rst: -------------------------------------------------------------------------------- 1 | bindsnet.network package 2 | ======================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.network.monitors module 8 | -------------------------------- 9 | 10 | .. automodule:: bindsnet.network.monitors 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bindsnet.network.network module 16 | ------------------------------- 17 | 18 | .. automodule:: bindsnet.network.network 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bindsnet.network.nodes module 24 | ----------------------------- 25 | 26 | .. automodule:: bindsnet.network.nodes 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | bindsnet.network.topology module 32 | -------------------------------- 33 | 34 | .. automodule:: bindsnet.network.topology 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: bindsnet.network 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/source/bindsnet.pipeline.rst: -------------------------------------------------------------------------------- 1 | bindsnet.pipeline package 2 | ========================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.pipeline.action module 8 | ------------------------------- 9 | 10 | .. 
automodule:: bindsnet.pipeline.action 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bindsnet.pipeline.base\_pipeline module 16 | --------------------------------------- 17 | 18 | .. automodule:: bindsnet.pipeline.base_pipeline 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bindsnet.pipeline.dataloader\_pipeline module 24 | --------------------------------------------- 25 | 26 | .. automodule:: bindsnet.pipeline.dataloader_pipeline 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | bindsnet.pipeline.environment\_pipeline module 32 | ---------------------------------------------- 33 | 34 | .. automodule:: bindsnet.pipeline.environment_pipeline 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | 40 | Module contents 41 | --------------- 42 | 43 | .. automodule:: bindsnet.pipeline 44 | :members: 45 | :undoc-members: 46 | :show-inheritance: 47 | -------------------------------------------------------------------------------- /docs/source/bindsnet.preprocessing.rst: -------------------------------------------------------------------------------- 1 | bindsnet.preprocessing package 2 | ============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | bindsnet.preprocessing.preprocessing module 8 | ------------------------------------------- 9 | 10 | .. automodule:: bindsnet.preprocessing.preprocessing 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | 16 | Module contents 17 | --------------- 18 | 19 | .. automodule:: bindsnet.preprocessing 20 | :members: 21 | :undoc-members: 22 | :show-inheritance: 23 | -------------------------------------------------------------------------------- /docs/source/bindsnet.rst: -------------------------------------------------------------------------------- 1 | bindsnet package 2 | ================ 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# bindsnet documentation build configuration file, created by
# sphinx-quickstart on Tue May 1 21:54:09 2018.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys

# Make the repository root importable so autodoc can import ``bindsnet``.
sys.path.insert(0, os.path.abspath("../.."))

# -- General configuration ------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    "sphinx.ext.autodoc",
    "sphinx.ext.doctest",
    "sphinx.ext.coverage",
    "sphinx.ext.mathjax",
    "sphinx.ext.viewcode",
    "sphinx.ext.githubpages",
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]

# autodoc of module special functions
autoclass_content = "both"

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = ".rst"

# The master toctree document.
master_doc = "index"

# General information about the project.
project = "bindsnet"
copyright = "2019, Daniel Saunders, Hananel Hazan"
author = "Daniel Saunders, Hananel Hazan"

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
# version = "0.2.5"
# The full version, including alpha/beta/rc tags.
# release = "0.2.5"

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
#
# NOTE: Sphinx >= 5 no longer accepts ``None`` here (it warns and falls back
# to "en"), so the language code is spelled out explicitly.
language = "en"

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This patterns also effect to html_static_path and html_extra_path
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = "sphinx"

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = False

# -- Options for HTML output ----------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ["_static"]

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# This is required for the alabaster theme
# refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars
html_sidebars = {
    "**": [
        "about.html",
        "navigation.html",
        "relations.html",  # needs 'show_related': True theme option to display
        "searchbox.html",
        "donate.html",
    ]
}

# -- Options for HTMLHelp output ------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = "bindsnetdoc"

# -- Options for LaTeX output ---------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (
        master_doc,
        "bindsnet.tex",
        "bindsnet Documentation",
        "Daniel Saunders, Hananel Hazan",
        "manual",
    )
]

# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, "bindsnet", "bindsnet Documentation", [author], 1)]

# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "bindsnet",
        "bindsnet Documentation",
        author,
        "bindsnet",
        "One line description of project.",
        "Miscellaneous",
    )
]
toctree:: 8 | :maxdepth: 2 9 | :caption: Table of Contents: 10 | 11 | guide/guide_part_i 12 | guide/guide_part_ii 13 | -------------------------------------------------------------------------------- /docs/source/guide/spikes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/docs/source/guide/spikes.png -------------------------------------------------------------------------------- /docs/source/guide/voltages.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/docs/source/guide/voltages.png -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. bindsnet documentation master file, created by 2 | sphinx-quickstart on Wed Apr 11 13:44:33 2018. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to BindsNET's documentation! 7 | ==================================== 8 | 9 | BindsNET is built on top of the `PyTorch `_ deep learning platform. It is used for the simulation 10 | of spiking neural networks (SNNs) and is geared towards machine learning and reinforcement learning. 11 | 12 | BindsNET takes advantage of the :code:`torch.Tensor` object to build spiking neurons and connections between them, and 13 | simulate them on CPUs or GPUs (for strong acceleration / parallelization) without any extra work. Recently, 14 | :code:`torchvision.datasets` has been integrated into the library to allow the use of popular vision datasets in 15 | training SNNs for computer vision tasks. 
Neural network functionality contained in the :code:`torch.nn.functional` module is 16 | used to implement more complex connections between populations of spiking neurons. 17 | 18 | Spiking neural networks are sometimes referred to as the `third generation of neural networks 19 | `_. Rather than the simple linear layers and nonlinear activation functions of deep learning neural networks, SNNs are composed of neural units which more accurately capture properties of their biological counterparts. An important difference between spiking neurons and the artificial neurons of deep learning is the former's integration of input *in time*; they are naturally short-term memory devices by their maintenance of a (possibly decaying) membrane voltage. As a result, some have argued that SNNs are particularly well-suited to model time-varying data. 20 | 21 | Neurons are connected together with directed edges (*synapses*) which are (in general) plastic. Synapses may have their own dynamics as well, which may or may not `depend on pre- and post-synaptic neural activity <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3395004/>`_ or `other biological signals <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4717313/>`_. The modification of synaptic strengths is thought to be an important mechanism by which organisms learn. Accordingly, BindsNET provides a module (**bindsnet.learning**) which contains functions used for the updating of synapse weights. 22 | 23 | At its core, BindsNET provides software objects and methods which support the simulation of groups of different types of neurons (**bindsnet.network.nodes**), as well as different types of connections between them (**bindsnet.network.topology**). These may be arbitrarily combined together under a single **bindsnet.network.Network** object, which is responsible for the coordination of the simulation logic of all underlying components.
On creation of a network, the user can specify a simulation timestep constant, :math:`dt`, which determines the granularity of the simulation. Choosing this parameter induces a trade-off between simulation speed and numerical precision: large values result in fast simulation, but poor simulation accuracy, and vice versa. Monitors (**bindsnet.network.monitors**) are available for recording state variables from arbitrary network components (e.g., the voltage :math:`v` of a group of neurons). 24 | 25 | The development of BindsNET is supported by the Defense Advanced Research Projects Agency Grant DARPA/MTO HR0011-16-l-0006. 26 | 27 | .. toctree:: 28 | :maxdepth: 2 29 | :caption: Contents: 30 | 31 | installation 32 | quickstart 33 | guide 34 | 35 | .. toctree:: 36 | :maxdepth: 2 37 | :caption: Package reference 38 | 39 | bindsnet 40 | 41 | Indices and tables 42 | ================== 43 | 44 | * :ref:`genindex` 45 | * :ref:`search` 46 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | 5 | Pip install 6 | ----------- 7 | 8 | Issue: 9 | 10 | .. code-block:: bash 11 | 12 | pip install git+https://github.com/BindsNET/bindsnet.git 13 | 14 | 15 | Installing from source 16 | ---------------------- 17 | 18 | On \*nix systems, issue one of the following in a shell: 19 | 20 | .. code-block:: bash 21 | 22 | git clone https://github.com/Hananel-Hazan/bindsnet.git # HTTPS 23 | git clone git@github.com:Hananel-Hazan/bindsnet.git # SSH 24 | 25 | Change directory into :code:`bindsnet` and issue one of the following: 26 | 27 | .. code-block:: bash 28 | 29 | pip install . # Typical install 30 | pip install -e . # Editable mode (package code can be edited without reinstall) 31 | 32 | This will install :code:`bindsnet` and all its dependencies.
33 | 34 | 35 | Running the tests 36 | ----------------- 37 | 38 | If BindsNET is installed from source, install :code:`pytest` and issue the following from BindsNET's installation directory: 39 | 40 | .. code-block:: bash 41 | 42 | python -m pytest test 43 | 44 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | bindsnet 2 | ======== 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | bindsnet 8 | -------------------------------------------------------------------------------- /docs/source/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ========== 3 | 4 | Check out some example use cases for BindsNET in the :code:`examples/` folder 5 | (`link `_). For example, changing directory to 6 | `[bindsnet-root]/examples/mnist` and running the following will result in a near-replication of the architecture of 7 | `Diehl & Cook 2015 `_: 8 | 9 | .. code-block:: bash 10 | 11 | python eth_mnist.py [options] 12 | 13 | 14 | The token :code:`[options]` should be replaced with any command-line arguments you'd like to use to modify the behavior 15 | of the program. 
def ANNarchy_cpu(n_neurons, time):
    """
    Time an ANNarchy run of a Poisson input population projecting all-to-all
    onto an integrate-and-fire population, using the ``openmp`` paradigm.

    :param n_neurons: Size (geometry) of both the input and output populations.
    :param time: Duration passed to ``ANNarchy.simulate``.
    :return: Wall-clock seconds from just after ``setup``/``clear`` until the
        simulation finishes; model construction and compilation are included.
    """
    ANNarchy.setup(paradigm="openmp", dt=1.0)
    ANNarchy.clear()

    # The clock starts only after the global setup / reset above.
    began = t()

    neuron_model = ANNarchy.Neuron(
        parameters="""
            tau_m = 10.0
            tau_e = 5.0
            vt = -54.0
            vr = -60.0
            El = -74.0
            Ee = 0.0
        """,
        equations="""
            tau_m * dv/dt = El - v + g_exc * (Ee - vr) : init = -60.0
            tau_e * dg_exc/dt = - g_exc
        """,
        spike="""
            v > vt
        """,
        reset="""
            v = vr
        """,
    )

    source = ANNarchy.PoissonPopulation(name="Input", geometry=n_neurons, rates=50.0)
    target = ANNarchy.Population(name="Output", geometry=n_neurons, neuron=neuron_model)

    # Dense excitatory projection with weights drawn uniformly from [0, 1].
    synapses = ANNarchy.Projection(pre=source, post=target, target="exc", synapse=None)
    initial_weights = ANNarchy.Uniform(0.0, 1.0)
    synapses.connect_all_to_all(weights=initial_weights)

    ANNarchy.compile()
    ANNarchy.simulate(duration=time)

    finished = t()
    return finished - began
def main(start=100, stop=1000, step=100, time=1000, interval=100, plot=False):
    """
    Benchmark ANNarchy on the CPU over a range of network sizes and store the
    timings as a new column in an existing benchmark CSV file.

    The CSV ``benchmark_{start}_{stop}_{step}_{time}.csv`` must already exist
    in ``benchmark_path``; it is read only after all runs finish, a column is
    added per framework, and the file is written back in place.

    :param start: Smallest number of neurons to benchmark.
    :param stop: Largest number of neurons to benchmark (inclusive when
        ``stop - start`` is a multiple of ``step``).
    :param step: Increment in the number of neurons between runs.
    :param time: Simulation duration handed to each benchmark function; also
        part of the CSV file name.
    :param interval: Unused by this script.
    :param plot: Unused by this script.
    :raises FileNotFoundError: If the benchmark CSV does not exist.
    """
    times = {"ANNarchy_cpu": []}

    f = os.path.join(
        benchmark_path,
        "benchmark_{start}_{stop}_{step}_{time}.csv".format(
            start=start, stop=stop, step=step, time=time
        ),
    )
    # Fail before running any (slow) benchmark if there is nowhere to record it.
    if not os.path.isfile(f):
        raise FileNotFoundError("{0} not found.".format(f))

    for n_neurons in range(start, stop + step, step):
        print("\nRunning benchmark with {0} neurons.".format(n_neurons))
        for framework in times:
            # Large CPU runs are skipped and recorded as missing values.
            if framework == "ANNarchy_cpu" and n_neurons > 5000:
                times[framework].append(np.nan)
                continue

            print("- {0}:".format(framework), end=" ")

            # Benchmark functions are looked up by name among module globals.
            fn = globals()[framework]
            elapsed = fn(n_neurons=n_neurons, time=time)
            times[framework].append(elapsed)

            print("(elapsed: {0:.4f})".format(elapsed))

    df = pd.read_csv(f, index_col=0)

    for framework, measurements in times.items():
        print(pd.Series(measurements))
        df[framework] = measurements

    print()
    print(df)
    print()

    df.to_csv(f)
def ANNarchy_gpu(n_neurons, time):
    """
    Time an ANNarchy run of a Poisson input population projecting all-to-all
    onto an integrate-and-fire population, using the ``cuda`` paradigm.

    :param n_neurons: Size (geometry) of both the input and output populations.
    :param time: Duration passed to ``ANNarchy.simulate``.
    :return: Tuple ``(total, sim)``: seconds elapsed since the function was
        entered (setup, model construction and compilation included), and
        seconds elapsed since compilation finished.
    """
    entered = t()

    ANNarchy.setup(paradigm="cuda", dt=1.0)
    ANNarchy.clear()

    neuron_model = ANNarchy.Neuron(
        parameters="""
            tau_m = 10.0
            tau_e = 5.0
            vt = -54.0
            vr = -60.0
            El = -74.0
            Ee = 0.0
        """,
        equations="""
            tau_m * dv/dt = El - v + g_exc * (Ee - vr) : init = -60.0
            tau_e * dg_exc/dt = - g_exc
        """,
        spike="""
            v > vt
        """,
        reset="""
            v = vr
        """,
    )

    source = ANNarchy.PoissonPopulation(name="Input", geometry=n_neurons, rates=50.0)
    target = ANNarchy.Population(name="Output", geometry=n_neurons, neuron=neuron_model)

    # Dense excitatory projection with weights drawn uniformly from [0, 1].
    synapses = ANNarchy.Projection(pre=source, post=target, target="exc", synapse=None)
    initial_weights = ANNarchy.Uniform(0.0, 1.0)
    synapses.connect_all_to_all(weights=initial_weights)

    ANNarchy.compile()

    # Second reference point: everything after this is pure simulation time.
    compiled = t()

    ANNarchy.simulate(duration=time)

    total = t() - entered
    sim = t() - compiled
    return total, sim
"benchmark_{start}_{stop}_{step}_{time}.csv".format(**locals()) 63 | ) 64 | if not os.path.isfile(f): 65 | raise Exception("{0} not found.".format(f)) 66 | 67 | for n_neurons in range(start, stop + step, step): 68 | print("\nRunning benchmark with {0} neurons.".format(n_neurons)) 69 | for framework in times.keys(): 70 | if "comp" in framework: 71 | continue 72 | 73 | print("- {0}:".format(framework), end=" ") 74 | 75 | fn = globals()[framework] 76 | total, sim = fn(n_neurons=n_neurons, time=time) 77 | times[framework].append(sim) 78 | times[framework + " (w/ comp.)"].append(total) 79 | 80 | print("(total, sim: {0:.4f}, {1:.4f})".format(total, sim)) 81 | 82 | df = pd.read_csv(f, index_col=0) 83 | 84 | for framework in times.keys(): 85 | print(pd.Series(times[framework])) 86 | df[framework] = times[framework] 87 | 88 | print() 89 | print(df) 90 | print() 91 | 92 | df.to_csv(f) 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--start", type=int, default=100) 98 | parser.add_argument("--stop", type=int, default=1000) 99 | parser.add_argument("--step", type=int, default=100) 100 | parser.add_argument("--time", type=int, default=1000) 101 | parser.add_argument("--interval", type=int, default=100) 102 | parser.add_argument("--plot", dest="plot", action="store_true") 103 | parser.set_defaults(plot=False) 104 | args = parser.parse_args() 105 | 106 | print(args) 107 | 108 | main( 109 | start=args.start, 110 | stop=args.stop, 111 | step=args.step, 112 | time=args.time, 113 | interval=args.interval, 114 | plot=args.plot, 115 | ) 116 | -------------------------------------------------------------------------------- /examples/benchmark/plot_benchmark.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | import pandas as pd 6 | from experiments import ROOT_DIR 7 | 8 | benchmark_path = os.path.join(ROOT_DIR, 
"benchmark") 9 | figure_path = os.path.join(ROOT_DIR, "figures") 10 | 11 | if not os.path.isdir(benchmark_path): 12 | os.makedirs(benchmark_path) 13 | 14 | 15 | def main(start=100, stop=1000, step=100, time=1000, interval=100, plot=False): 16 | name = f"benchmark_{start}_{stop}_{step}_{time}" 17 | f = os.path.join(benchmark_path, name + ".csv") 18 | df = pd.read_csv(f, index_col=0) 19 | 20 | plt.plot(df["BindsNET_cpu"], label="BindsNET (CPU)", linestyle="-", color="b") 21 | plt.plot(df["BindsNET_gpu"], label="BindsNET (GPU)", linestyle="-", color="g") 22 | plt.plot(df["BRIAN2"], label="BRIAN2", linestyle="--", color="r") 23 | plt.plot(df["BRIAN2GENN"], label="brian2genn", linestyle="--", color="c") 24 | plt.plot(df["BRIAN2GENN comp."], label="brian2genn comp.", linestyle=":", color="c") 25 | plt.plot(df["PyNEST"], label="PyNEST", linestyle="--", color="y") 26 | plt.plot(df["ANNarchy_cpu"], label="ANNarchy (CPU)", linestyle="--", color="m") 27 | plt.plot(df["ANNarchy_gpu"], label="ANNarchy (GPU)", linestyle="--", color="k") 28 | plt.plot( 29 | df["ANNarchy_gpu comp."], label="ANNarchy (GPU) comp.", linestyle=":", color="k" 30 | ) 31 | 32 | # for c in df.columns: 33 | # if 'BindsNET' in c: 34 | # plt.plot(df[c], label=c, linestyle='-') 35 | # else: 36 | # plt.plot(df[c], label=c, linestyle='--') 37 | 38 | plt.title("Benchmark comparison of SNN simulation libraries") 39 | plt.xticks(range(0, stop + interval, interval)) 40 | plt.xlabel("Number of input / output neurons") 41 | plt.ylabel("Simulation time (seconds)") 42 | plt.legend(loc=1, prop={"size": 5}) 43 | plt.yscale("log") 44 | 45 | plt.savefig(os.path.join(figure_path, name + ".png")) 46 | 47 | if plot: 48 | plt.show() 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("--start", type=int, default=100) 54 | parser.add_argument("--stop", type=int, default=1000) 55 | parser.add_argument("--step", type=int, default=100) 56 | parser.add_argument("--time", 
type=int, default=1000) 57 | parser.add_argument("--interval", type=int, default=1000) 58 | parser.add_argument("--plot", dest="plot", action="store_true") 59 | parser.set_defaults(plot=False) 60 | args = parser.parse_args() 61 | 62 | main( 63 | start=args.start, 64 | stop=args.stop, 65 | step=args.step, 66 | time=args.time, 67 | interval=args.interval, 68 | plot=args.plot, 69 | ) 70 | -------------------------------------------------------------------------------- /examples/breakout/breakout.py: -------------------------------------------------------------------------------- 1 | from bindsnet.encoding import bernoulli 2 | from bindsnet.environment import GymEnvironment 3 | from bindsnet.network import Network 4 | from bindsnet.network.nodes import Input, IzhikevichNodes 5 | from bindsnet.network.topology import Connection 6 | from bindsnet.pipeline import EnvironmentPipeline 7 | from bindsnet.pipeline.action import select_softmax 8 | 9 | # Build network. 10 | network = Network(dt=1.0) 11 | 12 | # Layers of neurons. 13 | inpt = Input(n=80 * 80, shape=[1, 1, 1, 80, 80], traces=True) 14 | middle = IzhikevichNodes(n=100, traces=True) 15 | out = IzhikevichNodes(n=4, refrac=0, traces=True) 16 | 17 | # Connections between layers. 18 | inpt_middle = Connection(source=inpt, target=middle, wmin=0, wmax=1) 19 | middle_out = Connection(source=middle, target=out, wmin=0, wmax=1) 20 | 21 | # Add all layers and connections to the network. 22 | network.add_layer(inpt, name="Input Layer") 23 | network.add_layer(middle, name="Hidden Layer") 24 | network.add_layer(out, name="Output Layer") 25 | network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer") 26 | network.add_connection(middle_out, source="Hidden Layer", target="Output Layer") 27 | 28 | # Load the Breakout environment. 29 | environment = GymEnvironment("BreakoutDeterministic-v4") 30 | environment.reset() 31 | 32 | # Build pipeline from specified components. 
33 | pipeline = EnvironmentPipeline( 34 | network, 35 | environment, 36 | encoding=bernoulli, 37 | action_function=select_softmax, 38 | output="Output Layer", 39 | time=100, 40 | history_length=1, 41 | delta=1, 42 | plot_interval=1, 43 | render_interval=1, 44 | ) 45 | 46 | # Run environment simulation for 100 episodes. 47 | for i in range(100): 48 | total_reward = 0 49 | pipeline.reset_state_variables() 50 | is_done = False 51 | while not is_done: 52 | result = pipeline.env_step() 53 | pipeline.step(result) 54 | 55 | reward = result[1] 56 | total_reward += reward 57 | 58 | is_done = result[2] 59 | print(f"Episode {i} total reward:{total_reward}") 60 | -------------------------------------------------------------------------------- /examples/breakout/breakout_stdp.py: -------------------------------------------------------------------------------- 1 | from bindsnet.encoding import bernoulli 2 | from bindsnet.environment import GymEnvironment 3 | from bindsnet.learning import MSTDP 4 | from bindsnet.network import Network 5 | from bindsnet.network.nodes import Input, LIFNodes 6 | from bindsnet.network.topology import Connection 7 | from bindsnet.pipeline import EnvironmentPipeline 8 | from bindsnet.pipeline.action import select_softmax 9 | 10 | # Build network. 11 | network = Network(dt=1.0) 12 | 13 | # Layers of neurons. 14 | inpt = Input(n=80 * 80, shape=[1, 1, 1, 80, 80], traces=True) 15 | middle = LIFNodes(n=100, traces=True) 16 | out = LIFNodes(n=4, refrac=0, traces=True) 17 | 18 | # Connections between layers. 19 | inpt_middle = Connection(source=inpt, target=middle, wmin=0, wmax=1e-1) 20 | middle_out = Connection( 21 | source=middle, 22 | target=out, 23 | wmin=0, 24 | wmax=1, 25 | update_rule=MSTDP, 26 | nu=1e-1, 27 | norm=0.5 * middle.n, 28 | ) 29 | 30 | # Add all layers and connections to the network. 
31 | network.add_layer(inpt, name="Input Layer") 32 | network.add_layer(middle, name="Hidden Layer") 33 | network.add_layer(out, name="Output Layer") 34 | network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer") 35 | network.add_connection(middle_out, source="Hidden Layer", target="Output Layer") 36 | 37 | # Load the Breakout environment. 38 | environment = GymEnvironment("BreakoutDeterministic-v4", render_mode="human") 39 | environment.reset() 40 | 41 | # Build pipeline from specified components. 42 | environment_pipeline = EnvironmentPipeline( 43 | network, 44 | environment, 45 | encoding=bernoulli, 46 | action_function=select_softmax, 47 | output="Output Layer", 48 | time=100, 49 | history_length=1, 50 | delta=1, 51 | plot_interval=1, 52 | render_interval=1, 53 | ) 54 | 55 | 56 | def run_pipeline(pipeline, episode_count): 57 | for i in range(episode_count): 58 | total_reward = 0 59 | pipeline.reset_state_variables() 60 | is_done = False 61 | while not is_done: 62 | result = pipeline.env_step() 63 | pipeline.step(result) 64 | 65 | reward = result[1] 66 | total_reward += reward 67 | 68 | is_done = result[2] 69 | print(f"Episode {i} total reward:{total_reward}") 70 | 71 | 72 | # enable MSTDP 73 | environment_pipeline.network.learning = True 74 | 75 | print("Training: ") 76 | run_pipeline(environment_pipeline, episode_count=100) 77 | 78 | # stop MSTDP 79 | environment_pipeline.network.learning = False 80 | 81 | print("Testing: ") 82 | run_pipeline(environment_pipeline, episode_count=100) 83 | -------------------------------------------------------------------------------- /examples/breakout/play_breakout_from_ANN.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Iterable, Optional, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from tqdm import tqdm 8 | 9 | from bindsnet.encoding import bernoulli, poisson 10 | from 
bindsnet.environment import GymEnvironment 11 | from bindsnet.network import Network 12 | from bindsnet.network.nodes import ( 13 | AbstractInput, 14 | IFNodes, 15 | Input, 16 | IzhikevichNodes, 17 | LIFNodes, 18 | Nodes, 19 | ) 20 | from bindsnet.network.topology import Connection 21 | from bindsnet.pipeline import EnvironmentPipeline 22 | from bindsnet.pipeline.action import * 23 | 24 | parser = argparse.ArgumentParser(prefix_chars="@") 25 | parser.add_argument("@@seed", type=int, default=42) 26 | parser.add_argument("@@dt", type=float, default=1.0) 27 | parser.add_argument("@@gpu", dest="gpu", action="store_true") 28 | parser.add_argument("@@layer1scale", dest="layer1scale", type=float, default=57.68) 29 | parser.add_argument("@@layer2scale", dest="layer2scale", type=float, default=77.48) 30 | parser.add_argument("@@num_episodes", type=int, default=10) 31 | parser.add_argument("@@plot_interval", type=int, default=1) 32 | parser.add_argument("@@rander_interval", type=int, default=1) 33 | parser.set_defaults(plot=False, render=False, gpu=True, probabilistic=False) 34 | locals().update(vars(parser.parse_args())) 35 | 36 | # Setup PyTorch computing device 37 | device = torch.device("cuda" if torch.cuda.is_available() and gpu else "cpu") 38 | torch.random.manual_seed(seed) 39 | 40 | 41 | # Build ANN 42 | class Net(nn.Module): 43 | def __init__(self): 44 | super(Net, self).__init__() 45 | self.fc1 = nn.Linear(6400, 1000) 46 | self.fc2 = nn.Linear(1000, 4) 47 | 48 | def forward(self, x): 49 | x = F.relu(self.fc1(x)) 50 | x = self.fc2(x) 51 | return x 52 | 53 | 54 | # load ANN 55 | dqn_network = torch.load("trained_shallow_ANN.pt", map_location=device) 56 | 57 | # Build Spiking network. 58 | network = Network(dt=dt).to(device) 59 | 60 | # Layers of neurons. 
61 | inpt = Input(n=6400, traces=False) # Input layer 62 | middle = LIFNodes( 63 | n=1000, refrac=0, traces=True, thresh=-52.0, rest=-65.0 64 | ) # Hidden layer 65 | readout = LIFNodes( 66 | n=4, refrac=0, traces=True, thresh=-52.0, rest=-65.0 67 | ) # Readout layer 68 | layers = {"X": inpt, "M": middle, "R": readout} 69 | 70 | # Set the connections between layers with the values set by the ANN 71 | # Input -> hidden. 72 | inpt_middle = Connection( 73 | source=layers["X"], 74 | target=layers["M"], 75 | w=torch.transpose(dqn_network.fc1.weight, 0, 1) * layer1scale, 76 | ) 77 | # hidden -> readout. 78 | middle_out = Connection( 79 | source=layers["M"], 80 | target=layers["R"], 81 | w=torch.transpose(dqn_network.fc2.weight, 0, 1) * layer2scale, 82 | ) 83 | 84 | # Add all layers and connections to the network. 85 | network.add_layer(inpt, name="Input Layer") 86 | network.add_layer(middle, name="Hidden Layer") 87 | network.add_layer(readout, name="Output Layer") 88 | network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer") 89 | network.add_connection(middle_out, source="Hidden Layer", target="Output Layer") 90 | 91 | # Load the Breakout environment. 92 | environment = GymEnvironment("BreakoutDeterministic-v4") 93 | environment.reset() 94 | 95 | # Build pipeline from specified components. 96 | pipeline = EnvironmentPipeline( 97 | network, 98 | environment, 99 | encoding=poisson, 100 | encode_factor=50, 101 | action_function=select_highest, 102 | percent_of_random_action=0.05, 103 | random_action_after=5, 104 | output="Output Layer", 105 | reset_output_spikes=True, 106 | time=500, 107 | overlay_input=4, 108 | history_length=1, 109 | plot_interval=plot_interval if plot else None, 110 | render_interval=render_interval if render else None, 111 | device=device, 112 | ) 113 | 114 | # Run environment simulation for number of episodes. 
115 | for i in tqdm(range(num_episodes)): 116 | total_reward = 0 117 | pipeline.reset_state_variables() 118 | is_done = False 119 | pipeline.env.step(1) # start with fire the ball 120 | pipeline.env.step(1) # start with fire the ball 121 | while not is_done: 122 | result = pipeline.env_step() 123 | pipeline.step(result) 124 | 125 | reward = result[1] 126 | total_reward += reward 127 | 128 | is_done = result[2] 129 | tqdm.write(f"Episode {i} total reward:{total_reward}") 130 | with open("play-breakout_results.csv", "a") as myfile: 131 | myfile.write(f"{i},{layer1scale},{layer2scale},{total_reward}\n") 132 | -------------------------------------------------------------------------------- /examples/breakout/random_baseline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | 6 | from bindsnet.environment import GymEnvironment 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("-n", type=int, default=1000000) 10 | parser.add_argument("--render", dest="render", action="store_true") 11 | parser.set_defaults(render=False) 12 | 13 | args = parser.parse_args() 14 | 15 | n = args.n 16 | render = args.render 17 | 18 | # Load Breakout environment. 19 | env = GymEnvironment("BreakoutDeterministic-v4") 20 | env.reset() 21 | 22 | total = 0 23 | rewards = [] 24 | avg_rewards = [] 25 | lengths = [] 26 | avg_lengths = [] 27 | 28 | i, j, k = 0, 0, 0 29 | while i < n: 30 | if render: 31 | env.render() 32 | 33 | # Select random action. 34 | a = np.random.choice(4) 35 | 36 | # Step environment with random action. 
37 | obs, reward, done, info = env.step(a) 38 | 39 | total += reward 40 | 41 | rewards.append(reward) 42 | if i == 0: 43 | avg_rewards.append(reward) 44 | else: 45 | avg = (avg_rewards[-1] * (i - 1)) / i + reward / i 46 | avg_rewards.append(avg) 47 | 48 | if i % 100 == 0: 49 | print( 50 | "Iteration %d: last reward: %.2f, average reward: %.2f" 51 | % (i, reward, avg_rewards[-1]) 52 | ) 53 | 54 | if done: 55 | # Restart game if out of lives. 56 | env.reset() 57 | 58 | length = i - j 59 | lengths.append(length) 60 | if j == 0: 61 | avg_lengths.append(length) 62 | else: 63 | avg = (avg_lengths[-1] * (k - 1)) / k + length / k 64 | avg_lengths.append(avg) 65 | 66 | print( 67 | "Episode %d: last length: %.2f, average length: %.2f" 68 | % (k, length, avg_lengths[-1]) 69 | ) 70 | 71 | j += length 72 | k += 1 73 | 74 | i += 1 75 | -------------------------------------------------------------------------------- /examples/breakout/random_network_baseline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | from bindsnet.encoding import bernoulli 6 | from bindsnet.environment import GymEnvironment 7 | from bindsnet.learning import Hebbian 8 | from bindsnet.network import Network 9 | from bindsnet.network.monitors import Monitor 10 | from bindsnet.network.nodes import Input, LIFNodes 11 | from bindsnet.network.topology import Connection 12 | from bindsnet.pipeline import EnvironmentPipeline 13 | from bindsnet.pipeline.action import select_multinomial 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("-n", type=int, default=1000000) 17 | parser.add_argument("--seed", type=int, default=0) 18 | parser.add_argument("--n_neurons", type=int, default=100) 19 | parser.add_argument("--dt", type=float, default=1.0) 20 | parser.add_argument("--plot_interval", type=int, default=10) 21 | parser.add_argument("--render_interval", type=int, default=10) 22 | 
parser.add_argument("--print_interval", type=int, default=100) 23 | parser.add_argument("--gpu", dest="gpu", action="store_true") 24 | parser.set_defaults(plot=False, render=False, gpu=False) 25 | 26 | args = parser.parse_args() 27 | 28 | n = args.n 29 | seed = args.seed 30 | n_neurons = args.n_neurons 31 | dt = args.dt 32 | plot_interval = args.plot_interval 33 | render_interval = args.render_interval 34 | print_interval = args.print_interval 35 | gpu = args.gpu 36 | 37 | if gpu: 38 | torch.set_default_tensor_type("torch.cuda.FloatTensor") 39 | torch.cuda.manual_seed_all(seed) 40 | else: 41 | torch.manual_seed(seed) 42 | 43 | # Build network. 44 | network = Network(dt=dt) 45 | 46 | # Layers of neurons. 47 | inpt = Input(shape=(1, 1, 1, 80, 80), traces=True) # Input layer 48 | exc = LIFNodes(n=n_neurons, refrac=0, traces=True) # Excitatory layer 49 | readout = LIFNodes(n=4, refrac=0, traces=True) # Readout layer 50 | layers = {"X": inpt, "E": exc, "R": readout} 51 | 52 | # Connections between layers. 53 | # Input -> excitatory. 54 | w = 0.01 * torch.rand(layers["X"].n, layers["E"].n) 55 | input_exc_conn = Connection( 56 | source=layers["X"], 57 | target=layers["E"], 58 | w=0.1 * torch.rand(layers["X"].n, layers["E"].n), 59 | wmax=0.02, 60 | norm=0.01 * layers["X"].n, 61 | ) 62 | 63 | # Excitatory -> readout. 64 | exc_readout_conn = Connection( 65 | source=layers["E"], 66 | target=layers["R"], 67 | w=0.1 * torch.rand(layers["E"].n, layers["R"].n), 68 | update_rule=Hebbian, 69 | nu=[1e-2, 1e-2], 70 | norm=0.5 * layers["E"].n, 71 | ) 72 | 73 | # Spike recordings for all layers. 74 | spikes = {} 75 | for layer in layers: 76 | spikes[layer] = Monitor(layers[layer], ["s"], time=plot_interval) 77 | 78 | # Voltage recordings for excitatory and readout layers. 79 | voltages = {} 80 | for layer in set(layers.keys()) - {"X"}: 81 | voltages[layer] = Monitor(layers[layer], ["v"], time=plot_interval) 82 | 83 | # Add all layers and connections to the network. 
84 | for layer in layers: 85 | network.add_layer(layers[layer], name=layer) 86 | 87 | network.add_connection(input_exc_conn, source="X", target="E") 88 | network.add_connection(exc_readout_conn, source="E", target="R") 89 | 90 | # Add all monitors to the network. 91 | for layer in layers: 92 | network.add_monitor(spikes[layer], name="%s_spikes" % layer) 93 | 94 | if layer in voltages: 95 | network.add_monitor(voltages[layer], name="%s_voltages" % layer) 96 | 97 | # Load the Breakout environment. 98 | environment = GymEnvironment("BreakoutDeterministic-v4", render_mode="human") 99 | environment.reset() 100 | 101 | pipeline = EnvironmentPipeline( 102 | network, 103 | environment, 104 | encoding=bernoulli, 105 | history_length=1, 106 | delta=1, 107 | time=100, 108 | plot_interval=plot_interval, 109 | print_interval=print_interval, 110 | render_interval=render_interval, 111 | action_function=select_multinomial, 112 | output="R", 113 | ) 114 | 115 | total = 0 116 | rewards = [] 117 | avg_rewards = [] 118 | lengths = [] 119 | avg_lengths = [] 120 | 121 | i = 0 122 | # pipeline.reset_state_variables() 123 | try: 124 | while i < n: 125 | result = pipeline.env_step() 126 | pipeline.step(result) 127 | 128 | is_done = result[2] 129 | if is_done: 130 | pipeline.reset_state_variables() 131 | 132 | i += 1 133 | 134 | except KeyboardInterrupt: 135 | environment.close() 136 | -------------------------------------------------------------------------------- /examples/breakout/trained_shallow_ANN.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/examples/breakout/trained_shallow_ANN.pt -------------------------------------------------------------------------------- /examples/mnist/conv1d_MNIST.py: -------------------------------------------------------------------------------- 1 | ### Toy example to test Conv1dConnection (the dataset used is MNIST but each 
image is raveled (each sample has shape (784,)). 2 | 3 | import argparse 4 | import os 5 | from time import time as t 6 | 7 | import torch 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | 11 | from bindsnet.datasets import MNIST 12 | from bindsnet.encoding import PoissonEncoder 13 | from bindsnet.learning import PostPre 14 | from bindsnet.network import Network 15 | from bindsnet.network.monitors import Monitor 16 | from bindsnet.network.nodes import DiehlAndCookNodes, Input 17 | from bindsnet.network.topology import Connection, Conv1dConnection 18 | 19 | print() 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--seed", type=int, default=0) 23 | parser.add_argument("--n_epochs", type=int, default=1) 24 | parser.add_argument("--n_test", type=int, default=10000) 25 | parser.add_argument("--n_train", type=int, default=60000) 26 | parser.add_argument("--batch_size", type=int, default=1) 27 | parser.add_argument("--kernel_size", type=int, default=28 * 2) 28 | parser.add_argument("--stride", type=int, default=28) 29 | parser.add_argument("--n_filters", type=int, default=25) 30 | parser.add_argument("--padding", type=int, default=0) 31 | parser.add_argument("--time", type=int, default=50) 32 | parser.add_argument("--dt", type=int, default=1.0) 33 | parser.add_argument("--intensity", type=float, default=128.0) 34 | parser.add_argument("--progress_interval", type=int, default=10) 35 | parser.add_argument("--update_interval", type=int, default=250) 36 | parser.add_argument("--train", dest="train", action="store_true") 37 | parser.add_argument("--test", dest="train", action="store_false") 38 | parser.add_argument("--plot", dest="plot", action="store_true") 39 | parser.add_argument("--gpu", dest="gpu", action="store_true") 40 | parser.set_defaults(plot=True, gpu=True, train=True) 41 | 42 | args = parser.parse_args() 43 | 44 | seed = args.seed 45 | n_epochs = args.n_epochs 46 | n_test = args.n_test 47 | n_train = args.n_train 48 | 
batch_size = args.batch_size
kernel_size = args.kernel_size
stride = args.stride
n_filters = args.n_filters
padding = args.padding
time = args.time
dt = args.dt
intensity = args.intensity
progress_interval = args.progress_interval
update_interval = args.update_interval
train = args.train
plot = args.plot
gpu = args.gpu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if gpu and torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)
    device = "cpu"
    if gpu:
        gpu = False

# os.cpu_count() may return None, and a single-core machine would otherwise
# request 0 threads; always ask for at least one.
torch.set_num_threads(max(1, (os.cpu_count() or 1) - 1))
print("Running on Device = ", device)

if not train:
    update_interval = n_test

conv_size = int((28 * 28 - kernel_size + 2 * padding) / stride) + 1
per_class = int((n_filters * conv_size) / 10)

# Build network.
network = Network()
input_layer = Input(n=28 * 28, shape=(1, 28 * 28), traces=True)

conv_layer = DiehlAndCookNodes(
    n=n_filters * conv_size, shape=(n_filters, conv_size), traces=True
)

conv_conn = Conv1dConnection(
    input_layer,
    conv_layer,
    kernel_size=kernel_size,
    stride=stride,
    # conv_size above accounts for the padding, so the connection has to use
    # the same value or the layer shape and the convolution output disagree.
    padding=padding,
    update_rule=PostPre,
    norm=0.4 * kernel_size,
    nu=[1e-4, 1e-2],
    wmax=1.0,
)

# Lateral inhibition: each unit inhibits the units of every *other* filter at
# the same position. Written as in loc1d_mnist.py: inhibit all filters at the
# position, then clear the self-connection.
w = torch.zeros(n_filters, conv_size, n_filters, conv_size)
for fltr in range(n_filters):
    for i in range(conv_size):
        w[fltr, i, :, i] = -100.0
        w[fltr, i, fltr, i] = 0.0

w = w.view(n_filters * conv_size, n_filters * conv_size)
recurrent_conn = Connection(conv_layer, conv_layer, w=w)

network.add_layer(input_layer, name="X")
network.add_layer(conv_layer, name="Y")
network.add_connection(conv_conn, source="X", target="Y")
network.add_connection(recurrent_conn, source="Y", target="Y")

# Voltage recording for excitatory and inhibitory layers.
voltage_monitor = Monitor(network.layers["Y"], ["v"], time=time)
network.add_monitor(voltage_monitor, name="output_voltage")

if gpu:
    network.to("cuda")

# Load MNIST data.
train_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    "../../data/MNIST",
    download=True,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    ),
)

spikes = {}
for layer in set(network.layers):
    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

voltages = {}
for layer in set(network.layers) - {"X"}:
    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
    network.add_monitor(voltages[layer], name="%s_voltages" % layer)

# Train the network.
print("Begin training.\n")
start = t()

inpt_axes = None
inpt_ims = None
spike_ims = None
spike_axes = None
voltage_ims = None
voltage_axes = None

for epoch in range(n_epochs):
    if epoch % progress_interval == 0:
        print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
        start = t()

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=gpu,
    )

    for step, batch in enumerate(tqdm(train_dataloader)):
        # Get next input sample (raveled to have shape (time, batch_size, 1, 28*28))
        # `step` is zero-based, so stop once n_train batches have been used.
        if step >= n_train:
            break
        inputs = {"X": batch["encoded_image"].view(time, batch_size, 1, 28 * 28)}
        if gpu:
            inputs = {k: v.cuda() for k, v in inputs.items()}
        label = batch["label"]

        # Run the network on the input.
177 | network.run(inputs=inputs, time=time) 178 | 179 | network.reset_state_variables() # Reset state variables. 180 | 181 | print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) 182 | print("Training complete.\n") 183 | -------------------------------------------------------------------------------- /examples/mnist/loc1d_mnist.py: -------------------------------------------------------------------------------- 1 | ### Toy example to test LocanConnection1D (the dataset used is MNIST but each image is raveled (each sample has shape (784,)). 2 | 3 | import torch 4 | from torch.nn.modules.utils import _pair 5 | 6 | from tqdm import tqdm 7 | import os 8 | from bindsnet.network.monitors import Monitor 9 | 10 | import torch 11 | from torchvision import transforms 12 | from tqdm import tqdm 13 | 14 | from time import time as t 15 | from torchvision import transforms 16 | from bindsnet.learning import PostPre 17 | 18 | from bindsnet.network.nodes import AdaptiveLIFNodes 19 | from bindsnet.network.nodes import Input 20 | from bindsnet.network.network import Network 21 | from bindsnet.network.topology import Connection, LocalConnection1D 22 | from bindsnet.encoding import PoissonEncoder 23 | from bindsnet.datasets import MNIST 24 | 25 | # Hyperparameters 26 | in_channels = 1 27 | n_filters = 25 28 | input_shape = 784 29 | kernel_size = 28 * 2 30 | stride = 28 31 | tc_theta_decay = 1e6 32 | theta_plus = 0.05 33 | norm = 0.2 * kernel_size 34 | wmin = 0.0 35 | wmax = 1.0 36 | nu = (1e-4, 1e-2) 37 | inh = 25.0 38 | dt = 1.0 39 | time = 250 40 | intensity = 128 41 | n_epochs = 1 42 | n_train = 500 43 | progress_interval = 10 44 | batch_size = 1 45 | 46 | # Build network 47 | network = Network() 48 | 49 | input_layer = Input(shape=[in_channels, input_shape], traces=True, tc_trace=20) 50 | 51 | compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 52 | conv_size = compute_conv_size(input_shape, kernel_size, stride) 53 | 54 | output_layer = 
AdaptiveLIFNodes( 55 | shape=[n_filters, conv_size], 56 | traces=True, 57 | rest=-65.0, 58 | reset=-60.0, 59 | thresh=-52.0, 60 | refrac=5, 61 | tc_decay=100.0, 62 | tc_trace=20.0, 63 | theta_plus=theta_plus, 64 | tc_theta_decay=tc_theta_decay, 65 | ) 66 | 67 | input_output_conn = LocalConnection1D( 68 | input_layer, 69 | output_layer, 70 | kernel_size=kernel_size, 71 | stride=stride, 72 | n_filters=n_filters, 73 | nu=nu, 74 | update_rule=PostPre, 75 | wmin=wmin, 76 | wmax=wmax, 77 | norm=norm, 78 | ) 79 | 80 | w_inh_LC = torch.zeros(n_filters, conv_size, n_filters, conv_size) 81 | for c in range(n_filters): 82 | for w1 in range(conv_size): 83 | w_inh_LC[c, w1, :, w1] = -inh 84 | w_inh_LC[c, w1, c, w1] = 0 85 | 86 | w_inh_LC = w_inh_LC.reshape(output_layer.n, output_layer.n) 87 | recurrent_conn = Connection(output_layer, output_layer, w=w_inh_LC) 88 | 89 | network.add_layer(input_layer, name="X") 90 | network.add_layer(output_layer, name="Y") 91 | network.add_connection(input_output_conn, source="X", target="Y") 92 | network.add_connection(recurrent_conn, source="Y", target="Y") 93 | 94 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 95 | gpu = True 96 | seed = 0 97 | 98 | if gpu and torch.cuda.is_available(): 99 | torch.cuda.manual_seed_all(seed) 100 | else: 101 | torch.manual_seed(seed) 102 | device = "cpu" 103 | if gpu: 104 | gpu = False 105 | 106 | torch.set_num_threads(os.cpu_count() - 1) 107 | print("Running on Device = ", device) 108 | 109 | if gpu: 110 | network.to("cuda") 111 | 112 | # Load MNIST data. 
train_dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    "../../data/MNIST",
    download=True,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
    ),
)

# Record spikes of every layer for the duration of one sample.
spikes = {}
for layer in set(network.layers):
    spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
    network.add_monitor(spikes[layer], name="%s_spikes" % layer)

# Record voltages of every layer except the input layer "X".
voltages = {}
for layer in set(network.layers) - {"X"}:
    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time)
    network.add_monitor(voltages[layer], name="%s_voltages" % layer)

# Train the network.
print("Begin training.\n")
start = t()

for epoch in range(n_epochs):
    if epoch % progress_interval == 0:
        print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start))
        start = t()

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=gpu,
    )

    for step, batch in enumerate(tqdm(train_dataloader)):
        # Get next input sample.
        # `step` is zero-based, so stop once n_train batches have been used.
        if step >= n_train:
            break
        inputs = {"X": batch["encoded_image"].view(time, batch_size, 1, 28 * 28)}
        if gpu:
            inputs = {k: v.cuda() for k, v in inputs.items()}
        label = batch["label"]

        # Run the network on the input.
        network.run(inputs=inputs, time=time)

        network.reset_state_variables()  # Reset state variables.
164 | 165 | print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) 166 | print("Training complete.\n") 167 | -------------------------------------------------------------------------------- /examples/mnist/loc2d_mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.utils import _pair 3 | 4 | from tqdm import tqdm 5 | import os 6 | from bindsnet.network.monitors import Monitor 7 | import matplotlib.pyplot as plt 8 | import torch 9 | from torchvision import transforms 10 | from tqdm import tqdm 11 | 12 | from bindsnet.analysis.plotting import plot_local_connection_2d_weights 13 | 14 | from time import time as t 15 | from torchvision import transforms 16 | from bindsnet.learning import PostPre 17 | 18 | from bindsnet.network.nodes import AdaptiveLIFNodes 19 | from bindsnet.network.nodes import Input 20 | from bindsnet.network.network import Network 21 | from bindsnet.network.topology import Connection, LocalConnection2D 22 | from bindsnet.encoding import PoissonEncoder 23 | from bindsnet.datasets import MNIST 24 | 25 | # Hyperparameters 26 | in_channels = 1 27 | n_filters = 50 28 | input_shape = [20, 20] 29 | kernel_size = _pair(12) 30 | stride = _pair(4) 31 | tc_theta_decay = 1e6 32 | theta_plus = 0.05 33 | norm = 0.2 * kernel_size[0] * kernel_size[1] 34 | wmin = 0.0 35 | wmax = 1.0 36 | nu = (0.0001, 0.01) 37 | inh = 25.0 38 | dt = 1.0 39 | time = 250 40 | intensity = 128 41 | n_epochs = 1 42 | n_train = 2500 43 | progress_interval = 10 44 | batch_size = 1 45 | 46 | plot = True 47 | 48 | # Build network 49 | network = Network() 50 | 51 | input_layer = Input( 52 | shape=[in_channels, input_shape[0], input_shape[1]], traces=True, tc_trace=20 53 | ) 54 | 55 | compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 56 | conv_size = _pair(compute_conv_size(input_shape[0], kernel_size[0], stride[0])) 57 | 58 | output_layer = AdaptiveLIFNodes( 59 | 
shape=[n_filters, conv_size[0], conv_size[1]], 60 | traces=True, 61 | rest=-65.0, 62 | reset=-60.0, 63 | thresh=-52.0, 64 | refrac=5, 65 | tc_trace=20.0, 66 | theta_plus=theta_plus, 67 | tc_theta_decay=tc_theta_decay, 68 | ) 69 | 70 | input_output_conn = LocalConnection2D( 71 | input_layer, 72 | output_layer, 73 | kernel_size=kernel_size, 74 | stride=stride, 75 | n_filters=n_filters, 76 | nu=nu, 77 | update_rule=PostPre, 78 | wmin=wmin, 79 | wmax=wmax, 80 | norm=norm, 81 | ) 82 | 83 | w_inh_LC = torch.zeros( 84 | n_filters, conv_size[0], conv_size[1], n_filters, conv_size[0], conv_size[1] 85 | ) 86 | for c in range(n_filters): 87 | for w1 in range(conv_size[0]): 88 | for w2 in range(conv_size[0]): 89 | w_inh_LC[c, w1, w2, :, w1, w2] = -inh 90 | w_inh_LC[c, w1, w2, c, w1, w2] = 0 91 | 92 | w_inh_LC = w_inh_LC.reshape(output_layer.n, output_layer.n) 93 | recurrent_conn = Connection(output_layer, output_layer, w=w_inh_LC) 94 | 95 | network.add_layer(input_layer, name="X") 96 | network.add_layer(output_layer, name="Y") 97 | network.add_connection(input_output_conn, source="X", target="Y") 98 | network.add_connection(recurrent_conn, source="Y", target="Y") 99 | 100 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 101 | gpu = True 102 | seed = 0 103 | if gpu and torch.cuda.is_available(): 104 | torch.cuda.manual_seed_all(seed) 105 | else: 106 | torch.manual_seed(seed) 107 | device = "cpu" 108 | if gpu: 109 | gpu = False 110 | 111 | torch.set_num_threads(os.cpu_count() - 1) 112 | print("Running on Device = ", device) 113 | 114 | if gpu: 115 | network.to("cuda") 116 | 117 | # Load MNIST data. 
118 | train_dataset = MNIST( 119 | PoissonEncoder(time=time, dt=dt), 120 | None, 121 | "../../data/MNIST", 122 | download=True, 123 | train=True, 124 | transform=transforms.Compose( 125 | [ 126 | transforms.ToTensor(), 127 | transforms.CenterCrop((input_shape[0], input_shape[1])), 128 | transforms.Lambda(lambda x: x * intensity), 129 | ] 130 | ), 131 | ) 132 | 133 | spikes = {} 134 | for layer in set(network.layers): 135 | spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time) 136 | network.add_monitor(spikes[layer], name="%s_spikes" % layer) 137 | 138 | voltages = {} 139 | for layer in set(network.layers) - {"X"}: 140 | voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time) 141 | network.add_monitor(voltages[layer], name="%s_voltages" % layer) 142 | 143 | # Train the network. 144 | print("Begin training.\n") 145 | start = t() 146 | 147 | weights1_im = None 148 | 149 | for epoch in range(n_epochs): 150 | if epoch % progress_interval == 0: 151 | print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) 152 | start = t() 153 | 154 | train_dataloader = torch.utils.data.DataLoader( 155 | train_dataset, 156 | batch_size=batch_size, 157 | shuffle=True, 158 | num_workers=0, 159 | pin_memory=gpu, 160 | ) 161 | 162 | for step, batch in enumerate(tqdm(train_dataloader)): 163 | # Get next input sample. 164 | if step > n_train: 165 | break 166 | inputs = { 167 | "X": batch["encoded_image"].view( 168 | time, batch_size, 1, input_shape[0], input_shape[1] 169 | ) 170 | } 171 | if gpu: 172 | inputs = {k: v.cuda() for k, v in inputs.items()} 173 | label = batch["label"] 174 | 175 | # Run the network on the input. 176 | network.run(inputs=inputs, time=time) 177 | 178 | # Optionally plot various simulation information. 
179 | if plot: 180 | weights1_im = plot_local_connection_2d_weights( 181 | network.connections[("X", "Y")], im=weights1_im 182 | ) 183 | plt.pause(1) 184 | 185 | network.reset_state_variables() # Reset state variables. 186 | 187 | print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) 188 | print("Training complete.\n") 189 | 190 | weights1_im = plot_local_connection_2d_weights(network.connections[("X", "Y")]) 191 | plt.savefig("test.png") 192 | plt.pause(100) 193 | -------------------------------------------------------------------------------- /examples/mnist/loc3d_mnist.py: -------------------------------------------------------------------------------- 1 | ### Toy example to test LocalConnection3D (the dataset used is MNIST but with a dimension replicated 2 | ### for each image (each sample has size (28, 28, 28)) 3 | 4 | import torch 5 | from torch.nn.modules.utils import _triple 6 | 7 | from tqdm import tqdm 8 | import os 9 | from bindsnet.network.monitors import Monitor 10 | 11 | import torch 12 | from torchvision import transforms 13 | from tqdm import tqdm 14 | 15 | from time import time as t 16 | from torchvision import transforms 17 | from bindsnet.learning import PostPre 18 | 19 | from bindsnet.network.nodes import AdaptiveLIFNodes 20 | from bindsnet.network.nodes import Input 21 | from bindsnet.network.network import Network 22 | from bindsnet.network.topology import Connection, LocalConnection3D 23 | from bindsnet.encoding import PoissonEncoder 24 | from bindsnet.datasets import MNIST 25 | 26 | # Hyperparameters 27 | in_channels = 1 28 | n_filters = 25 29 | input_shape = [20, 20, 20] 30 | kernel_size = _triple(16) 31 | stride = _triple(2) 32 | tc_theta_decay = 1e6 33 | theta_plus = 0.05 34 | norm = 0.2 * kernel_size[0] * kernel_size[1] * kernel_size[2] 35 | wmin = 0.0 36 | wmax = 1.0 37 | nu = (0.0001, 0.01) 38 | inh = 25.0 39 | dt = 1.0 40 | time = 250 41 | intensity = 128 42 | n_epochs = 1 43 | n_train = 2500 44 | 
progress_interval = 10 45 | batch_size = 1 46 | 47 | # Build network 48 | network = Network() 49 | 50 | input_layer = Input( 51 | n=input_shape[0] * input_shape[1] * input_shape[2], 52 | shape=(in_channels, input_shape[0], input_shape[1], input_shape[2]), 53 | traces=True, 54 | ) 55 | 56 | compute_conv_size = lambda inp_size, k, s: int((inp_size - k) / s) + 1 57 | conv_size = _triple(compute_conv_size(input_shape[0], kernel_size[0], stride[0])) 58 | 59 | output_layer = AdaptiveLIFNodes( 60 | shape=[n_filters, conv_size[0], conv_size[1], conv_size[2]], 61 | traces=True, 62 | rest=-65.0, 63 | reset=-60.0, 64 | thresh=-52.0, 65 | refrac=5, 66 | tc_trace=20.0, 67 | theta_plus=theta_plus, 68 | tc_theta_decay=tc_theta_decay, 69 | ) 70 | 71 | input_output_conn = LocalConnection3D( 72 | input_layer, 73 | output_layer, 74 | kernel_size=kernel_size, 75 | stride=stride, 76 | n_filters=n_filters, 77 | nu=nu, 78 | update_rule=PostPre, 79 | wmin=wmin, 80 | wmax=wmax, 81 | norm=norm, 82 | ) 83 | 84 | w_inh_LC = torch.zeros( 85 | n_filters, 86 | conv_size[0], 87 | conv_size[1], 88 | conv_size[2], 89 | n_filters, 90 | conv_size[0], 91 | conv_size[1], 92 | conv_size[2], 93 | ) 94 | 95 | for c in range(n_filters): 96 | for w1 in range(conv_size[0]): 97 | for w2 in range(conv_size[1]): 98 | for w3 in range(conv_size[2]): 99 | w_inh_LC[c, w1, w2, w3, :, w1, w2, w3] = -inh 100 | w_inh_LC[c, w1, w2, w3, c, w1, w2, w3] = 0 101 | 102 | w_inh_LC = w_inh_LC.reshape(output_layer.n, output_layer.n) 103 | recurrent_conn = Connection(output_layer, output_layer, w=w_inh_LC) 104 | 105 | network.add_layer(input_layer, name="X") 106 | network.add_layer(output_layer, name="Y") 107 | network.add_connection(input_output_conn, source="X", target="Y") 108 | network.add_connection(recurrent_conn, source="Y", target="Y") 109 | 110 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 111 | gpu = True 112 | seed = 0 113 | if gpu and torch.cuda.is_available(): 114 | 
torch.cuda.manual_seed_all(seed) 115 | else: 116 | torch.manual_seed(seed) 117 | device = "cpu" 118 | if gpu: 119 | gpu = False 120 | 121 | torch.set_num_threads(os.cpu_count() - 1) 122 | print("Running on Device = ", device) 123 | 124 | if gpu: 125 | network.to("cuda") 126 | 127 | # Load MNIST data. 128 | train_dataset = MNIST( 129 | PoissonEncoder(time=time, dt=dt), 130 | None, 131 | "../../data/MNIST", 132 | download=True, 133 | train=True, 134 | transform=transforms.Compose( 135 | [ 136 | transforms.ToTensor(), 137 | transforms.CenterCrop((input_shape[0], input_shape[1])), 138 | transforms.Lambda(lambda x: x * intensity), 139 | ] 140 | ), 141 | ) 142 | 143 | spikes = {} 144 | for layer in set(network.layers): 145 | spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time) 146 | network.add_monitor(spikes[layer], name="%s_spikes" % layer) 147 | 148 | voltages = {} 149 | for layer in set(network.layers) - {"X"}: 150 | voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=time) 151 | network.add_monitor(voltages[layer], name="%s_voltages" % layer) 152 | 153 | # Train the network. 154 | print("Begin training.\n") 155 | start = t() 156 | 157 | for epoch in range(n_epochs): 158 | if epoch % progress_interval == 0: 159 | print("Progress: %d / %d (%.4f seconds)" % (epoch, n_epochs, t() - start)) 160 | start = t() 161 | 162 | train_dataloader = torch.utils.data.DataLoader( 163 | train_dataset, 164 | batch_size=batch_size, 165 | shuffle=True, 166 | num_workers=0, 167 | pin_memory=gpu, 168 | ) 169 | 170 | for step, batch in enumerate(tqdm(train_dataloader)): 171 | # Get next input sample. 
172 | if step > n_train: 173 | break 174 | inputs = { 175 | "X": batch["encoded_image"] 176 | .view(time, batch_size, 1, input_shape[0], input_shape[1]) 177 | .unsqueeze(3) 178 | .repeat(1, 1, 1, input_shape[2], 1, 1) 179 | .float() 180 | } 181 | if gpu: 182 | inputs = {k: v.cuda() for k, v in inputs.items()} 183 | label = batch["label"] 184 | 185 | # Run the network on the input. 186 | network.run(inputs=inputs, time=time) 187 | 188 | print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start)) 189 | print("Training complete.\n") 190 | -------------------------------------------------------------------------------- /examples/tensorboard/tensorboard.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from time import time as t 4 | 5 | import torch 6 | from torchvision import transforms 7 | from tqdm import tqdm 8 | 9 | import bindsnet.datasets 10 | from bindsnet.analysis.pipeline_analysis import MatplotlibAnalyzer, TensorboardAnalyzer 11 | from bindsnet.encoding import NullEncoder, PoissonEncoder 12 | from bindsnet.learning import PostPre 13 | from bindsnet.network import Network 14 | from bindsnet.network.nodes import Input, LIFNodes 15 | from bindsnet.network.topology import Connection, Conv2dConnection 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--dataset", 20 | type=str, 21 | default="MNIST", 22 | choices=["MNIST", "KMNIST", "FashionMNIST", "CIFAR10", "CIFAR100"], 23 | ) 24 | parser.add_argument("--seed", type=int, default=0) 25 | parser.add_argument("--time", type=int, default=50) 26 | parser.add_argument("--dt", type=int, default=1.0) 27 | parser.add_argument("--tensorboard", dest="tensorboard", action="store_true") 28 | parser.set_defaults(plot=False, gpu=False, train=True) 29 | 30 | args = parser.parse_args() 31 | 32 | seed = args.seed 33 | torch.manual_seed(seed) 34 | 35 | # Encoding parameters 36 | time = args.time 37 | dt = args.dt 
# Convolution parameters
kernel_size = 5
stride = 2
n_filters = 5
padding = 0

# Create the datasets and loaders
# This is dynamic so you can test each dataset easily
dataset_type = getattr(bindsnet.datasets, args.dataset)
dataset_path = os.path.join("..", "..", "data", args.dataset)
train_dataset = dataset_type(
    PoissonEncoder(time=time, dt=dt),
    NullEncoder(),
    dataset_path,
    download=True,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Lambda(lambda x: x * 128.0)]
    ),
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True, num_workers=0
)

# Grab the shape of a single sample (not including batch)
# So, TxCxHxW
sample_shape = train_dataset[0]["encoded_image"].shape
print(args.dataset, " has shape ", sample_shape)

# Side length of the convolution output, derived from the last sample
# dimension only (so the output is assumed square).
conv_size = (sample_shape[-1] - kernel_size + 2 * padding) // stride + 1
# Not used further in this script; the divisor hard-codes 10 classes.
per_class = (n_filters * conv_size * conv_size) // 10

# Build a small convolutional network
network = Network()

# Drop the leading time dimension: the input layer shape is CxHxW.
input_layer = Input(shape=(sample_shape[1:]), traces=True)

conv_layer = LIFNodes(
    n=n_filters * conv_size * conv_size,
    shape=(n_filters, conv_size, conv_size),
    traces=True,
)

# `padding` is passed explicitly so the connection always agrees with the
# `conv_size` computed above if the parameter is ever changed from 0.
conv_conn = Conv2dConnection(
    input_layer,
    conv_layer,
    kernel_size=kernel_size,
    stride=stride,
    padding=padding,
    update_rule=PostPre,
    norm=0.4 * kernel_size**2,
    nu=[1e-4, 1e-2],
    wmax=1.0,
)

network.add_layer(input_layer, name="X")
network.add_layer(conv_layer, name="Y")
network.add_connection(conv_conn, source="X", target="Y")

# Train the network.
100 | print("Begin training.\n") 101 | 102 | if args.tensorboard: 103 | analyzer = TensorboardAnalyzer("logs/conv") 104 | else: 105 | analyzer = MatplotlibAnalyzer() 106 | 107 | for step, batch in enumerate(tqdm(train_dataloader)): 108 | # batch contains image, label, encoded_image since an image_encoder 109 | # was provided 110 | 111 | # batch["encoded_image"] is in BxTxCxHxW format 112 | inputs = {"X": batch["encoded_image"].view(time, 1, 1, 28, 28)} 113 | 114 | # Run the network on the input. 115 | # Specify the location of the time dimension 116 | network.run(inputs=inputs, time=time) 117 | 118 | network.reset_state_variables() # Reset state variables. 119 | 120 | analyzer.plot_conv2d_weights(conv_conn.w, step=step) 121 | 122 | analyzer.finalize_step() 123 | -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1656543178.TempWin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1656543178.TempWin -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1656548905.TempWin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1656548905.TempWin -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1673646087.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1673646087.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1673648326.Spike: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1673648326.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1678117372.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1678117372.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1682712186.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1682712186.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1687464074.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1687464074.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1687464505.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1687464505.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1687736499.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1687736499.Spike 
-------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1694374827.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1694374827.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1694374969.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1694374969.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1694375010.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1694375010.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1700165624.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1700165624.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1703700212.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1703700212.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711672057.Spike: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711672057.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711672140.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711672140.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711673241.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711673241.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711673760.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711673760.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711674764.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711674764.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711675116.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711675116.Spike -------------------------------------------------------------------------------- 
/logs/init/events.out.tfevents.1711675170.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711675170.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711675181.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711675181.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711675321.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711675321.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711675865.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711675865.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711721119.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711721119.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1711723694.Spike: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1711723694.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1720719086.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1720719086.Spike -------------------------------------------------------------------------------- /logs/init/events.out.tfevents.1720719342.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/init/events.out.tfevents.1720719342.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1656543178.TempWin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1656543178.TempWin -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1656548905.TempWin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1656548905.TempWin -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1673646087.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1673646087.Spike -------------------------------------------------------------------------------- 
/logs/runs/events.out.tfevents.1673648326.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1673648326.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1678117372.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1678117372.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1682712186.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1682712186.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1687464074.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1687464074.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1687464505.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1687464505.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1687736499.Spike: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1687736499.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1694374827.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1694374827.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1694374969.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1694374969.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1694375010.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1694375010.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1700165624.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1700165624.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1703700212.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1703700212.Spike -------------------------------------------------------------------------------- 
/logs/runs/events.out.tfevents.1711672057.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711672057.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711672140.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711672140.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711673241.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711673241.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711673760.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711673760.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711674764.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711674764.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711675116.Spike: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711675116.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711675170.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711675170.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711675181.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711675181.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711675321.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711675321.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711675865.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711675865.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1711721119.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711721119.Spike -------------------------------------------------------------------------------- 
/logs/runs/events.out.tfevents.1711723694.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1711723694.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1720719086.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1720719086.Spike -------------------------------------------------------------------------------- /logs/runs/events.out.tfevents.1720719342.Spike: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/logs/runs/events.out.tfevents.1720719342.Spike -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "bindsnet" 3 | version = "0.3.3" 4 | description = "Spiking neural networks for ML in Python" 5 | authors = [ "Hananel Hazan ", "Daniel Saunders", "Darpan Sanghavi", "Hassaan Khan" ] 6 | license = "AGPL-3.0-only" 7 | readme = "README.md" 8 | repository = "https://github.com/BindsNET/bindsnet" 9 | documentation = "https://bindsnet-docs.readthedocs.io/" 10 | keywords = ["spiking", "neural", "networks", "pytorch"] 11 | 12 | [tool.poetry.dependencies] 13 | python = ">=3.10" 14 | numpy = "^2" 15 | scipy = "^1" 16 | Cython = "^3" 17 | torch = [ 18 | {version = "2.6.0", markers = "sys_platform != 'darwin'", source = "torch+cu124"}, 19 | {version = "2.6", markers = "sys_platform == 'darwin'" }, 20 | ] 21 | torchvision = [ 22 | {version = "0.21.0", markers = "sys_platform != 'darwin'", source = "torch+cu124"}, 23 | 
{version = "0.21.0", markers = "sys_platform == 'darwin'" }, 24 | ] 25 | torchaudio = [ 26 | {version = "2.6.0", markers = "sys_platform != 'darwin'", source = "torch+cu124"}, 27 | {version = "2.6.0", markers = "sys_platform == 'darwin'" }, 28 | ] 29 | 30 | tensorboardX = "^2.6.2" 31 | tqdm = "^4" 32 | matplotlib = "^3" 33 | ale-py = "^0.10.2" 34 | gymnasium = {extras = ["atari"], version = "^1"} 35 | scikit-build = "^0.18" 36 | scikit-image = "^0.25.2" 37 | scikit-learn = "^1.5" 38 | opencv-python = "^4" 39 | pandas = "^2" 40 | foolbox = "^3" 41 | 42 | [[tool.poetry.source]] 43 | name = "torch+cu118" 44 | url = "https://download.pytorch.org/whl/cu118" 45 | priority = "explicit" 46 | 47 | [[tool.poetry.source]] 48 | name = "torch+cu121" 49 | url = "https://download.pytorch.org/whl/cu121" 50 | priority = "explicit" 51 | 52 | [[tool.poetry.source]] 53 | name = "torch+cu124" 54 | url = "https://download.pytorch.org/whl/cu124" 55 | priority = "explicit" 56 | 57 | [[tool.poetry.source]] 58 | name = "torch+cu126" 59 | url = "https://download.pytorch.org/whl/cu126" 60 | priority = "explicit" 61 | 62 | [tool.poetry.dev-dependencies] 63 | pytest = "^8" 64 | pre-commit = "^3" 65 | notebook = "^7" 66 | jupyterlab = "^4" 67 | isort = "^5.9.3" 68 | black = "^24" 69 | autoflake = "^2" 70 | 71 | [build-system] 72 | requires = ["setuptools", "poetry-core>=1.0.0"] 73 | build-backend = "poetry.core.masonry.api" 74 | 75 | [tool.isort] 76 | profile = "black" 77 | line_length = 88 78 | src_paths = ["bindsnet", "test"] 79 | 80 | [tool.black] 81 | target-version = ['py38'] 82 | include = '\.pyi?$' 83 | exclude = ''' 84 | /( 85 | \.eggs 86 | | \.git 87 | | \.hg 88 | | \.mypy_cache 89 | | \.pytest_cache 90 | | \.venv 91 | | \.github 92 | | build 93 | | dist 94 | | BindsNET.egg-info 95 | | notebooks 96 | | data 97 | | logs 98 | )/ 99 | ''' 100 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import setuptools 4 | 5 | if __name__ == "__main__": 6 | setuptools.setup() 7 | -------------------------------------------------------------------------------- /test/analysis/__pycache__/test_analyzers.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/analysis/__pycache__/test_analyzers.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/analysis/test_analyzers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import torch 5 | 6 | from bindsnet.analysis.pipeline_analysis import MatplotlibAnalyzer, TensorboardAnalyzer 7 | 8 | 9 | class TestAnalyzer: 10 | """ 11 | Sanity checks all plotting functions for analyzers 12 | """ 13 | 14 | def test_init(self): 15 | ma = MatplotlibAnalyzer() 16 | assert plt.isinteractive() 17 | 18 | ta = TensorboardAnalyzer("./logs/init") 
19 | 20 | # check to ensure path was written 21 | assert os.path.isdir("./logs/init") 22 | 23 | # check to ensure we can write data 24 | ta.writer.add_scalar("init_scalar", 100.0, 0) 25 | ta.writer.close() 26 | 27 | def test_plot_runs(self): 28 | ma = MatplotlibAnalyzer() 29 | ta = TensorboardAnalyzer("./logs/runs") 30 | 31 | for analyzer in [ma, ta]: 32 | obs = torch.rand(1, 28, 28) 33 | analyzer.plot_obs(obs) 34 | 35 | # 4 channels out, 1 channel in, 8x8 kernels 36 | conv_weights = torch.rand(4, 1, 8, 8) 37 | analyzer.plot_conv2d_weights(conv_weights) 38 | 39 | rewards = [0, 0, 0, 0, 0] 40 | analyzer.plot_reward(rewards) 41 | 42 | # Monitors have time as last dimension 43 | v = torch.rand(50, 1, 1, 28, 28) 44 | voltage_dict = {"X": v} 45 | threshold_dict = {"X": torch.tensor(0.75)} 46 | analyzer.plot_voltages(voltage_dict, threshold_dict) 47 | 48 | # The monitors have time as last dimension 49 | spikes = torch.rand(50, 1, 1, 28, 28) > 0.5 50 | spike_dict = {"X": spikes} 51 | analyzer.plot_spikes(spike_dict) 52 | 53 | analyzer.finalize_step() 54 | 55 | ta.writer.close() 56 | 57 | 58 | if __name__ == "__main__": 59 | tester = TestAnalyzer() 60 | 61 | tester.test_init() 62 | tester.test_plot_runs() 63 | -------------------------------------------------------------------------------- /test/conversion/__pycache__/test_conversion.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/conversion/__pycache__/test_conversion.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/conversion/__pycache__/test_conversion.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/conversion/test_conversion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from bindsnet.conversion import ann_to_snn 6 | 7 | 8 | class FullyConnectedNetwork(nn.Module): 9 | # language=rst 10 | """ 11 | Simple fully-connected network implemented in PyTorch. 
12 | """ 13 | 14 | def __init__(self): 15 | super(FullyConnectedNetwork, self).__init__() 16 | 17 | self.fc1 = nn.Linear(784, 256) 18 | self.fc2 = nn.Linear(256, 128) 19 | self.fc3 = nn.Linear(128, 10) 20 | 21 | def forward(self, x): 22 | x = F.relu(self.fc1(x)) 23 | x = F.relu(self.fc2(x)) 24 | x = self.fc3(x) 25 | return x 26 | 27 | 28 | def test_conversion_1(): 29 | ann = FullyConnectedNetwork() 30 | snn = ann_to_snn(ann, input_shape=(784,)) 31 | 32 | 33 | def test_conversion_2(): 34 | data = torch.rand(784, 20) 35 | ann = FullyConnectedNetwork() 36 | snn = ann_to_snn(ann, data=data, input_shape=(784,)) 37 | 38 | 39 | def main(): 40 | test_conversion_1() 41 | test_conversion_2() 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /test/encoding/__pycache__/test_encoding.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/encoding/__pycache__/test_encoding.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/encoding/__pycache__/test_encoding.cpython-310-pytest-8.2.2.pyc 
-------------------------------------------------------------------------------- /test/encoding/test_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bindsnet.encoding import * 4 | 5 | 6 | class TestEncodings: 7 | """ 8 | Tests all stable encoding functions and generators. 9 | """ 10 | 11 | def test_bernoulli(self): 12 | for n in [1, 100]: # number of nodes in layer 13 | for t in [1, 100]: # number of timesteps 14 | for m in [0.1, 1.0]: # maximum spiking probability 15 | datum = torch.empty(n).uniform_(0, m) 16 | spikes = bernoulli(datum, time=t, max_prob=m) 17 | 18 | assert spikes.size() == torch.Size((t, n)) 19 | 20 | def test_multidim_bernoulli(self): 21 | for shape in [[5, 5], [10, 10], [25, 25]]: # shape of nodes in layer 22 | for t in [1, 100]: # number of timesteps 23 | for m in [0.1, 1.0]: # maximum spiking probability 24 | datum = torch.empty(shape).uniform_(0, m) 25 | spikes = bernoulli(datum, time=t, max_prob=m) 26 | 27 | assert spikes.size() == torch.Size((t, *shape)) 28 | 29 | def test_bernoulli_loader(self): 30 | for s in [1, 100]: # number of data samples 31 | for n in [1, 100]: # number of nodes in layer 32 | for m in [0.1, 1.0]: # maximum spiking probability 33 | for t in [1, 100]: # number of timesteps 34 | data = torch.empty(s, n).uniform_(0, 1) 35 | spike_loader = bernoulli_loader(data, time=t, max_prob=m) 36 | 37 | for i, spikes in enumerate(spike_loader): 38 | assert spikes.size() == torch.Size((t, n)) 39 | 40 | def test_poisson(self): 41 | for n in [1, 100]: # number of nodes in layer 42 | for t in [1000]: # number of timesteps 43 | datum = torch.empty(n).uniform_(20, 100) # Generate firing rates. 44 | spikes = poisson(datum, time=t) # Encode as spikes. 
45 | 46 | assert spikes.size() == torch.Size((t, n)) 47 | 48 | def test_poisson_loader(self): 49 | for s in [1, 10]: # number of data samples 50 | for n in [1, 100]: # number of nodes in layer 51 | for t in [1000]: # number of timesteps 52 | data = torch.empty(s, n).uniform_(20, 100) # Generate firing rates. 53 | spike_loader = poisson_loader(data, time=t) # Encode as spikes. 54 | 55 | for i, spikes in enumerate(spike_loader): 56 | assert spikes.size() == torch.Size((t, n)) 57 | -------------------------------------------------------------------------------- /test/import/__pycache__/test_import.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/import/__pycache__/test_import.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/import/__pycache__/test_import.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/import/__pycache__/test_import.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/import/__pycache__/test_import.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/import/__pycache__/test_import.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/import/test_import.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/import/test_import.py 
-------------------------------------------------------------------------------- /test/models/__pycache__/test_models.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/models/__pycache__/test_models.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/models/__pycache__/test_models.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/models/__pycache__/test_models.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/models/__pycache__/test_models.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/models/__pycache__/test_models.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/models/test_models.py: -------------------------------------------------------------------------------- 1 | from bindsnet.models import DiehlAndCook2015, TwoLayerNetwork 2 | from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes 3 | from bindsnet.network.topology import Connection 4 | 5 | 6 | class TestTwoLayerNetwork: 7 | def test_init(self): 8 | for n_inpt in [50, 100, 200]: 9 | for n_neurons in [50, 100, 200]: 10 | for dt in [1.0, 2.0]: 11 | network = TwoLayerNetwork(n_inpt, n_neurons=n_neurons, dt=dt) 12 | 13 | assert network.n_inpt == n_inpt 14 | assert network.n_neurons == n_neurons 15 | assert network.dt == dt 16 | 17 | assert ( 18 | isinstance(network.layers["X"], Input) 19 | and network.layers["X"].n == n_inpt 20 | ) 
21 | assert ( 22 | isinstance(network.layers["Y"], LIFNodes) 23 | and network.layers["Y"].n == n_neurons 24 | ) 25 | assert isinstance(network.connections[("X", "Y")], Connection) 26 | assert ( 27 | network.connections[("X", "Y")].source.n == n_inpt 28 | and network.connections[("X", "Y")].target.n == n_neurons 29 | ) 30 | 31 | 32 | class TestDiehlAndCook2015: 33 | def test_init(self): 34 | for n_inpt in [50, 100, 200]: 35 | for n_neurons in [50, 100, 200]: 36 | for dt in [1.0, 2.0]: 37 | for exc in [13.3, 14.53]: 38 | for inh in [10.5, 12.2]: 39 | network = DiehlAndCook2015( 40 | n_inpt=n_inpt, 41 | n_neurons=n_neurons, 42 | exc=exc, 43 | inh=inh, 44 | dt=dt, 45 | ) 46 | 47 | assert network.n_inpt == n_inpt 48 | assert network.n_neurons == n_neurons 49 | assert network.dt == dt 50 | assert network.exc == exc 51 | assert network.inh == inh 52 | 53 | assert ( 54 | isinstance(network.layers["X"], Input) 55 | and network.layers["X"].n == n_inpt 56 | ) 57 | assert ( 58 | isinstance(network.layers["Ae"], DiehlAndCookNodes) 59 | and network.layers["Ae"].n == n_neurons 60 | ) 61 | assert ( 62 | isinstance(network.layers["Ai"], LIFNodes) 63 | and network.layers["Ai"].n == n_neurons 64 | ) 65 | 66 | for conn in [("X", "Ae"), ("Ae", "Ai"), ("Ai", "Ae")]: 67 | assert conn in network.connections 68 | -------------------------------------------------------------------------------- /test/network/__pycache__/test_connections.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_connections.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_connections.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_connections.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_connections.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_connections.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_learning.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_learning.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_learning.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_learning.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_learning.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_learning.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_monitors.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_monitors.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_monitors.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_monitors.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_monitors.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_monitors.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_network.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_network.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_network.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_network.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_network.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_network.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_nodes.cpython-310-pytest-7.4.4.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_nodes.cpython-310-pytest-7.4.4.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_nodes.cpython-310-pytest-8.1.1.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_nodes.cpython-310-pytest-8.1.1.pyc -------------------------------------------------------------------------------- /test/network/__pycache__/test_nodes.cpython-310-pytest-8.2.2.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BindsNET/bindsnet/666dd447c0442a9cddfc6165f79f94fc522446db/test/network/__pycache__/test_nodes.cpython-310-pytest-8.2.2.pyc -------------------------------------------------------------------------------- /test/network/test_monitors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bindsnet.network import Network 4 | from bindsnet.network.monitors import Monitor, NetworkMonitor 5 | from bindsnet.network.nodes import IFNodes, Input 6 | from bindsnet.network.topology import Connection 7 | 8 | 9 | class TestMonitor: 10 | """ 11 | Testing Monitor object. 
12 | """ 13 | 14 | network = Network() 15 | 16 | inpt = Input(75) 17 | network.add_layer(inpt, name="X") 18 | _if = IFNodes(25) 19 | network.add_layer(_if, name="Y") 20 | conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n)) 21 | network.add_connection(conn, source="X", target="Y") 22 | 23 | inpt_mon = Monitor(inpt, state_vars=["s"]) 24 | network.add_monitor(inpt_mon, name="X") 25 | _if_mon = Monitor(_if, state_vars=["s", "v"]) 26 | network.add_monitor(_if_mon, name="Y") 27 | 28 | network.run(inputs={"X": torch.bernoulli(torch.rand(100, inpt.n))}, time=100) 29 | 30 | assert inpt_mon.get("s").size() == torch.Size([100, 1, inpt.n]) 31 | assert _if_mon.get("s").size() == torch.Size([100, 1, _if.n]) 32 | assert _if_mon.get("v").size() == torch.Size([100, 1, _if.n]) 33 | 34 | del network.monitors["X"], network.monitors["Y"] 35 | 36 | inpt_mon = Monitor(inpt, state_vars=["s"], time=500) 37 | network.add_monitor(inpt_mon, name="X") 38 | _if_mon = Monitor(_if, state_vars=["s", "v"], time=500) 39 | network.add_monitor(_if_mon, name="Y") 40 | 41 | network.run(inputs={"X": torch.bernoulli(torch.rand(500, inpt.n))}, time=500) 42 | 43 | assert inpt_mon.get("s").size() == torch.Size([500, 1, inpt.n]) 44 | assert _if_mon.get("s").size() == torch.Size([500, 1, _if.n]) 45 | assert _if_mon.get("v").size() == torch.Size([500, 1, _if.n]) 46 | 47 | 48 | class TestNetworkMonitor: 49 | """ 50 | Testing NetworkMonitor object. 
51 | """ 52 | 53 | network = Network() 54 | 55 | inpt = Input(25) 56 | network.add_layer(inpt, name="X") 57 | _if = IFNodes(75) 58 | network.add_layer(_if, name="Y") 59 | conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n)) 60 | network.add_connection(conn, source="X", target="Y") 61 | 62 | mon = NetworkMonitor(network, state_vars=["s", "v", "w"]) 63 | network.add_monitor(mon, name="monitor") 64 | 65 | network.run(inputs={"X": torch.bernoulli(torch.rand(50, inpt.n))}, time=50) 66 | 67 | recording = mon.get() 68 | 69 | assert recording["X"]["s"].size() == torch.Size([50, 1, inpt.n]) 70 | assert recording["Y"]["s"].size() == torch.Size([50, 1, _if.n]) 71 | assert recording["Y"]["v"].size() == torch.Size([50, 1, _if.n]) 72 | 73 | del network.monitors["monitor"] 74 | 75 | mon = NetworkMonitor(network, state_vars=["s", "v", "w"], time=50) 76 | network.add_monitor(mon, name="monitor") 77 | 78 | network.run(inputs={"X": torch.bernoulli(torch.rand(50, inpt.n))}, time=50) 79 | 80 | recording = mon.get() 81 | 82 | assert recording["X"]["s"].size() == torch.Size([50, 1, inpt.n]) 83 | assert recording["Y"]["s"].size() == torch.Size([50, 1, _if.n]) 84 | assert recording["Y"]["v"].size() == torch.Size([50, 1, _if.n]) 85 | 86 | 87 | if __name__ == "__main__": 88 | tm = TestMonitor() 89 | tnm = TestNetworkMonitor() 90 | -------------------------------------------------------------------------------- /test/network/test_network.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | from bindsnet.network import Network, load 5 | from bindsnet.network.monitors import Monitor 6 | from bindsnet.network.nodes import Input, LIFNodes 7 | from bindsnet.network.topology import Connection 8 | 9 | 10 | class TestNetwork: 11 | """ 12 | Tests basic network functionality. 
13 | """ 14 | 15 | def test_empty(self, tmp_path): 16 | for dt in [0.1, 1.0, 5.0]: 17 | network = Network(dt=dt) 18 | assert network.dt == dt 19 | 20 | network.run(inputs={}, time=1000) 21 | 22 | file_path = str(tmp_path / "net.pt") 23 | network.save(file_path) 24 | _network = load(file_path) 25 | assert _network.dt == dt 26 | assert _network.learning 27 | del _network 28 | 29 | _network = load(file_path, learning=True) 30 | assert _network.dt == dt 31 | assert _network.learning 32 | del _network 33 | 34 | _network = load(file_path, learning=False) 35 | assert _network.dt == dt 36 | assert not _network.learning 37 | del _network 38 | 39 | def test_add_objects(self, tmp_path): 40 | network = Network(dt=1.0, learning=False) 41 | 42 | inpt = Input(100) 43 | network.add_layer(inpt, name="X") 44 | lif = LIFNodes(50) 45 | network.add_layer(lif, name="Y") 46 | 47 | assert inpt == network.layers["X"] 48 | assert lif == network.layers["Y"] 49 | 50 | conn = Connection(inpt, lif) 51 | network.add_connection(conn, source="X", target="Y") 52 | 53 | assert conn == network.connections[("X", "Y")] 54 | 55 | monitor = Monitor(lif, state_vars=["s", "v"]) 56 | network.add_monitor(monitor, "Y") 57 | 58 | assert monitor == network.monitors["Y"] 59 | 60 | file_path = str(tmp_path / "net.pt") 61 | network.save(file_path) 62 | _network = load(file_path, learning=True) 63 | assert _network.learning 64 | assert "X" in _network.layers 65 | assert "Y" in _network.layers 66 | assert ("X", "Y") in _network.connections 67 | assert "Y" in _network.monitors 68 | del _network 69 | -------------------------------------------------------------------------------- /test/network/test_nodes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from bindsnet.network import Network 4 | from bindsnet.network.nodes import ( 5 | AdaptiveLIFNodes, 6 | IFNodes, 7 | Input, 8 | LIFNodes, 9 | McCullochPitts, 10 | Nodes, 11 | SRM0Nodes, 12 | ) 13 | 14 | 15 | 
class TestNodes:
    """
    Tests all stable groups of neurons / nodes.
    """

    def test_init(self):
        """
        Check the state of freshly constructed node layers.

        Three construction variants are covered: default arguments, traces
        enabled with a custom ``tc_trace``, and (for LIF-style nodes) explicit
        ``rest`` / ``reset`` / ``thresh`` / ``refrac`` / ``tc_decay`` values.
        Every layer is also added to a shared ``Network`` under a unique name.
        """
        network = Network()
        sizes = [1, 100, 10000]

        for i, nodes in enumerate(
            [Input, McCullochPitts, IFNodes, LIFNodes, AdaptiveLIFNodes, SRM0Nodes]
        ):
            for n in sizes:
                layer = nodes(n)
                network.add_layer(layer=layer, name=f"{i}_{n}")

                assert layer.n == n
                assert (layer.s.float() == torch.zeros(n)).all()

                if nodes in [LIFNodes, AdaptiveLIFNodes]:
                    assert (layer.v == layer.rest * torch.ones(n)).all()

                layer = nodes(n, traces=True, tc_trace=1e5)
                network.add_layer(layer=layer, name=f"{i}_traces_{n}")

                assert layer.n == n
                assert layer.tc_trace == 1e5
                assert (layer.s.float() == torch.zeros(n)).all()
                assert (layer.x == torch.zeros(n)).all()

                if nodes in [LIFNodes, AdaptiveLIFNodes, SRM0Nodes]:
                    assert (layer.v == layer.rest * torch.ones(n)).all()

        for nodes in [LIFNodes, AdaptiveLIFNodes]:
            for n in sizes:
                layer = nodes(
                    n, rest=0.0, reset=-10.0, thresh=10.0, refrac=3, tc_decay=1.5e3
                )
                # Name the layer after its node class. The previous code reused
                # the stale index ``i`` left over from the loop above, so both
                # node types were registered under the same name.
                network.add_layer(layer=layer, name=f"{nodes.__name__}_params_{n}")

                assert layer.rest == 0.0
                assert layer.reset == -10.0
                assert layer.thresh == 10.0
                assert layer.refrac == 3
                assert layer.tc_decay == 1.5e3
                assert (layer.s.float() == torch.zeros(n)).all()
                assert (layer.v == layer.rest * torch.ones(n)).all()

    @staticmethod
    def _tensor_devices(layer):
        """
        Return ``(names, devices)`` for the tensor entries of ``layer.state_dict()``.

        ``names`` are the state-dict keys whose values are tensors; ``devices``
        holds the device of the attribute of the same name on ``layer``.
        """
        names = [
            k for k, v in layer.state_dict().items() if isinstance(v, torch.Tensor)
        ]
        devices = [getattr(layer, k).device for k in names]
        return names, devices

    def test_transfer(self):
        """
        Check that node tensors live on the GPU after ``.to("cuda:0")``.

        The check is repeated after ``reset_state_variables()`` so that a reset
        which moves tensors to another device is caught. Returns early (which
        is reported as a pass, not a skip) when CUDA is unavailable.
        """
        if not torch.cuda.is_available():
            return

        cuda = torch.device("cuda:0")

        for nodes in Nodes.__subclasses__():
            layer = nodes(10)

            layer.to(cuda)

            layer_tensors, tensor_devs = self._tensor_devices(layer)

            print(f"State dict in {nodes} : {layer.state_dict().keys()}")
            print(f"__dict__ in {nodes} : {layer.__dict__.keys()}")
            print(f"Tensors in {nodes} : {layer_tensors}")
            print(f"Tensor devices {list(zip(layer_tensors, tensor_devs))}")

            for d in tensor_devs:
                print(d, d == cuda)
                assert d == cuda

            print("Reset layer")
            layer.reset_state_variables()
            _, tensor_devs = self._tensor_devices(layer)

            for d in tensor_devs:
                print(d, d == cuda)
                assert d == cuda


if __name__ == "__main__":
    tester = TestNodes()

    tester.test_init()
    tester.test_transfer()