├── .gitignore
├── .readthedocs.yml
├── LICENSE
├── Makefile
├── README.md
├── attractors
├── Mi_2014_CANN_1D_oscillatory_tracking.ipynb
├── Wu_2008_CANN.ipynb
├── Wu_2008_CANN_2D.ipynb
├── data
│ └── data_to_train_on.npy
├── discrete_hopfield_demo_for_image_reconstruction.ipynb
└── discrete_hopfield_network.ipynb
├── brain_inspired_computing
├── OTTT-SNN.py
├── SurrogateGrad_lif.py
├── SurrogateGrad_lif_fashion_mnist.py
├── _datasets.py
├── fashion_mnist_conv_lif.py
├── liquid_time_constant_network.py
├── mnist_lif_readout.py
└── spiking_FPTT
│ ├── README.md
│ ├── add_task.py
│ └── mnist_classification.py
├── classical_dynamical_systems
├── Multiscroll_attractor.ipynb
├── Rabinovich_Fabrikant_eq.ipynb
├── fractional_order_chaos.ipynb
├── henon_map.ipynb
├── logistic_map.ipynb
├── lorenz_system.ipynb
└── mackey_glass_eq.ipynb
├── conf.py
├── decision_making
├── Wang_2002_decision_making_spiking.ipynb
├── Wang_2002_decision_making_spiking.py
└── Wang_2006_decision_making_rate.ipynb
├── dynamics_analysis
├── 1d_simple_systems.ipynb
├── 2d_NaK_model.ipynb
├── 2d_decision_making_model.ipynb
├── 2d_decision_making_with_lowdim_analyzer.ipynb
├── 2d_wilson_cowan_model.ipynb
├── 3d_hindmarsh_rose_model.ipynb
├── highdim_CANN.ipynb
└── highdim_gj_coupled_fhn.ipynb
├── ei_nets
├── Brette_2007_COBA.ipynb
├── Brette_2007_COBAHH.ipynb
├── Brette_2007_CUBA.ipynb
├── Tian_2020_EI_net_for_fast_response.ipynb
└── Vreeswijk_1996_EI_net.ipynb
├── gj_nets
├── Fazli_2022_gj_coupled_bursting_pituitary_cells.ipynb
└── Sherman_1992_gj_antisynchrony.ipynb
├── images
├── cann-decoding.gif
├── cann-encoding.gif
├── cann-tracking.gif
├── cann_1d_oscillatory_tracking.gif
├── cann_2d_encoding.gif
├── cann_2d_tracking.gif
├── decision_model.png
└── izhikevich_patterns.jfif
├── index.rst
├── large_scale_modeling
├── 2014_CorticalModel.py
├── EI_net_with_1m_neurons.ipynb
├── Joglekar_2018_InterAreal_Balanced_Amplification_figure1.ipynb
├── Joglekar_2018_InterAreal_Balanced_Amplification_figure2.ipynb
├── Joglekar_2018_InterAreal_Balanced_Amplification_figure5.ipynb
├── Joglekar_2018_InterAreal_Balanced_Amplification_taichi_customized_op.ipynb
└── Joglekar_2018_data
│ ├── efelenMatpython.mat
│ ├── hierValspython.mat
│ ├── subgraphData.mat
│ └── subgraphWiring29.mat
├── make.bat
├── neurons
├── 2018_Fractional_Izhikevich_model.ipynb
├── 2019_Fractional_order_FHR_model.ipynb
├── Gerstner_2005_AdExIF_model.ipynb
├── Izhikevich_2003_Izhikevich_model.ipynb
├── JR_1995_jansen_rit_model.ipynb
├── Niebur_2009_GIF.ipynb
├── Romain_2004_LIF_phase_locking.ipynb
└── Susin_2021_gamma_oscillation_nets.py
├── oscillation_synchronization
├── Brunel_Hakim_1999_fast_oscillation.ipynb
├── Diesmann_1999_synfire_chains.ipynb
├── Li_2017_unified_thalamus_oscillation_model.ipynb
├── Susin_Destexhe_2021_gamma_oscillation_AI.ipynb
├── Susin_Destexhe_2021_gamma_oscillation_CHING.ipynb
├── Susin_Destexhe_2021_gamma_oscillation_ING.ipynb
├── Susin_Destexhe_2021_gamma_oscillation_PING.ipynb
└── Wang_1996_gamma_oscillation.ipynb
├── others
└── Brette_Guigon_2003_spike_timing_reliability.ipynb
├── recurrent_networks
├── Bellec_2020_eprop_evidence_accumulation.ipynb
├── Laje_Buonomano_2013_robust_timing_rnn.ipynb
├── Laje_Buonomano_2013_robust_timing_rnn.py
├── Laje_Buonomano_2013_simulation.ipynb
├── Laje_Buonomano_2013_simulation.py
├── Masse_2019_STP_RNN.ipynb
├── Masse_2019_STP_RNN_tasks.py
├── ParametricWorkingMemory.ipynb
├── Song_2016_EI_RNN.ipynb
├── Sussillo_Abbott_2009_FORCE_Learning.ipynb
├── Yang_2020_RNN_Analysis.ipynb
├── data
│ └── DAC_handwriting_output_targets.mat
├── fixed_points_finder.ipynb
├── fixed_points_finder.py
├── fixed_points_finder2.ipynb
├── fixed_points_finder2.py
└── integrator_rnn.ipynb
├── requirements.txt
├── reservoir_computing
├── Gauthier_2021_ngrc.ipynb
└── predicting_Mackey_Glass_timeseries.ipynb
└── working_memory
├── Bouchacourt_2019_Flexible_working_memory.ipynb
└── Mi_2017_working_memory_capacity.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 | _build
9 | recurrent_networks/neurogym
10 | recurrent_networks/Untitled.ipynb
11 | recurrent_networks/Untitled1.ipynb
12 | .idea
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | pip-wheel-metadata/
29 | share/python-wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | MANIFEST
34 |
35 | # PyInstaller
36 | # Usually these files are written by a python script from a template
37 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
38 | *.manifest
39 | *.spec
40 |
41 | # Installer logs
42 | pip-log.txt
43 | pip-delete-this-directory.txt
44 |
45 | # Unit test / coverage reports
46 | htmlcov/
47 | .tox/
48 | .nox/
49 | .coverage
50 | .coverage.*
51 | .cache
52 | nosetests.xml
53 | coverage.xml
54 | *.cover
55 | *.py,cover
56 | .hypothesis/
57 | .pytest_cache/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # IPython
86 | profile_default/
87 | ipython_config.py
88 |
89 | # pyenv
90 | .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 | /brain_inspired_computing/logs/
136 | /brain_inspired_computing/data/
137 | /brain_inspired_computing/results/
138 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | # Required
6 | version: 2
7 |
8 | build:
9 | os: "ubuntu-20.04"
10 | tools:
11 | python: "3.10"
12 |
13 | # Build documentation in the docs/ directory with Sphinx
14 | sphinx:
15 | configuration: conf.py
16 |
17 | # Optionally set the version of Python and requirements required to build your docs
18 | python:
19 | install:
20 | - requirements: requirements.txt
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Examples for [BrainPy](https://github.com/brainpy/BrainPy) computation
2 |
3 | The examples require ``brainpy>=2.3.0``.
4 |
5 |
6 |
7 |
8 | ### Recent activities
9 |
10 | - 2023/10/06, [Discrete Hopfield Network](./attractors/discrete_hopfield_network.ipynb)
11 | - 2023/10/06, [Discrete Hopfield Network Demo for Image Reconstruction](./attractors/discrete_hopfield_demo_for_image_reconstruction.ipynb)
12 | - 2023/04/08, large-scale simulation example for simulating [1 million neuron EI network using 1GB GPU memory](./large_scale_modeling/EI_net_with_1m_neurons.ipynb)
13 | - 2023/01/29, implementing [liquid time-constant network](./brain_inspired_computing/liquid_time_constant_network.py)
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/attractors/Mi_2014_CANN_1D_oscillatory_tracking.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "647060a3",
6 | "metadata": {},
7 | "source": [
8 | "# CANN 1D Oscillatory Tracking\n",
9 | "\n",
10 | "[](https://colab.research.google.com/github/brainpy/examples/blob/main/attractors/Mi_2014_CANN_1D_oscillatory_tracking.ipynb)\n",
11 | "[](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/attractors/Mi_2014_CANN_1D_oscillatory_tracking.ipynb)"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "6d135f04",
17 | "metadata": {},
18 | "source": [
19 | "Implementation of the paper:\n",
20 | "\n",
21 | "- Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. \"Dynamics and computation of continuous attractors.\" Neural computation 20.4 (2008): 994-1025.\n",
22 | "- Mi, Y., Fung, C. C., Wong, M. K. Y., & Wu, S. (2014). Spike frequency adaptation implements anticipative tracking in continuous attractor neural networks. Advances in neural information processing systems, 1(January), 505."
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 1,
28 | "id": "51852ed7",
29 | "metadata": {
30 | "ExecuteTime": {
31 | "end_time": "2023-07-22T04:07:35.934007100Z",
32 | "start_time": "2023-07-22T04:07:35.243879300Z"
33 | },
34 | "pycharm": {
35 | "is_executing": true
36 | }
37 | },
38 | "outputs": [],
39 | "source": [
40 | "import brainpy as bp\n",
41 | "import brainpy.math as bm\n",
42 | "\n",
43 | "bm.set_platform('cpu')"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "metadata": {
50 | "ExecuteTime": {
51 | "end_time": "2023-07-22T04:07:35.949695100Z",
52 | "start_time": "2023-07-22T04:07:35.934007100Z"
53 | },
54 | "collapsed": false
55 | },
56 | "outputs": [
57 | {
58 | "data": {
59 | "text/plain": [
60 | "'2.4.3'"
61 | ]
62 | },
63 | "execution_count": 2,
64 | "metadata": {},
65 | "output_type": "execute_result"
66 | }
67 | ],
68 | "source": [
69 | "bp.__version__"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 2,
75 | "id": "433fe4d4",
76 | "metadata": {},
77 | "outputs": [],
78 | "source": [
79 | "class CANN1D(bp.dyn.NeuDyn):\n",
80 | " def __init__(self, num, tau=1., tau_v=50., k=1., a=0.3, A=0.2, J0=1.,\n",
81 | " z_min=-bm.pi, z_max=bm.pi, m=0.3):\n",
82 | " super(CANN1D, self).__init__(size=num)\n",
83 | "\n",
84 | " # parameters\n",
85 | " self.tau = tau # The synaptic time constant\n",
86 | " self.tau_v = tau_v\n",
87 | " self.k = k # Degree of the rescaled inhibition\n",
88 | " self.a = a # Half-width of the range of excitatory connections\n",
89 | " self.A = A # Magnitude of the external input\n",
90 | " self.J0 = J0 # maximum connection value\n",
91 | " self.m = m\n",
92 | "\n",
93 | " # feature space\n",
94 | " self.z_min = z_min\n",
95 | " self.z_max = z_max\n",
96 | " self.z_range = z_max - z_min\n",
97 | " self.x = bm.linspace(z_min, z_max, num) # The encoded feature values\n",
98 | " self.rho = num / self.z_range # The neural density\n",
99 | " self.dx = self.z_range / num # The stimulus density\n",
100 | "\n",
101 | " # The connection matrix\n",
102 | " self.conn_mat = self.make_conn()\n",
103 | "\n",
104 | " # variables\n",
105 | " self.r = bm.Variable(bm.zeros(num))\n",
106 | " self.u = bm.Variable(bm.zeros(num))\n",
107 | " self.v = bm.Variable(bm.zeros(num))\n",
108 | " self.input = bm.Variable(bm.zeros(num))\n",
109 | "\n",
110 | " def dist(self, d):\n",
111 | " d = bm.remainder(d, self.z_range)\n",
112 | " d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)\n",
113 | " return d\n",
114 | "\n",
115 | " def make_conn(self):\n",
116 | " x_left = bm.reshape(self.x, (-1, 1))\n",
117 | " x_right = bm.repeat(self.x.reshape((1, -1)), len(self.x), axis=0)\n",
118 | " d = self.dist(x_left - x_right)\n",
119 | " conn = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)\n",
120 | " return conn\n",
121 | "\n",
122 | " def get_stimulus_by_pos(self, pos):\n",
123 | " return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))\n",
124 | "\n",
125 | " def update(self):\n",
126 | " r1 = bm.square(self.u)\n",
127 | " r2 = 1.0 + self.k * bm.sum(r1)\n",
128 | " self.r.value = r1 / r2\n",
129 | " Irec = bm.dot(self.conn_mat, self.r)\n",
130 | " self.u.value = self.u + (-self.u + Irec + self.input - self.v) / self.tau * bp.share['dt']\n",
131 | " self.v.value = self.v + (-self.v + self.m * self.u) / self.tau_v * bp.share['dt']\n",
132 | " self.input[:] = 0."
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": 3,
138 | "id": "1c04226c",
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "cann = CANN1D(num=512)"
143 | ]
144 | },
145 | {
146 | "cell_type": "code",
147 | "execution_count": 4,
148 | "id": "6bc3315e",
149 | "metadata": {
150 | "scrolled": false
151 | },
152 | "outputs": [
153 | {
154 | "data": {
155 | "application/vnd.jupyter.widget-view+json": {
156 | "model_id": "76f9d82a2cd447cc9541cd771fee2086",
157 | "version_major": 2,
158 | "version_minor": 0
159 | },
160 | "text/plain": [
161 | " 0%| | 0/26000 [00:00, ?it/s]"
162 | ]
163 | },
164 | "metadata": {},
165 | "output_type": "display_data"
166 | }
167 | ],
168 | "source": [
169 | "dur1, dur2, dur3 = 100., 2000., 500.\n",
170 | "num1 = int(dur1 / bm.get_dt())\n",
171 | "num2 = int(dur2 / bm.get_dt())\n",
172 | "num3 = int(dur3 / bm.get_dt())\n",
173 | "position = bm.zeros(num1 + num2 + num3)\n",
174 | "final_pos = cann.a / cann.tau_v * 0.6 * dur2\n",
175 | "position[num1: num1 + num2] = bm.linspace(0., final_pos, num2)\n",
176 | "position[num1 + num2:] = final_pos\n",
177 | "position = position.reshape((-1, 1))\n",
178 | "Iext = cann.get_stimulus_by_pos(position)\n",
179 | "\n",
180 | "runner = bp.DSRunner(cann,\n",
181 | " inputs=('input', Iext, 'iter'),\n",
182 | " monitors=['u', 'v'])\n",
183 | "runner.run(dur1 + dur2 + dur3)\n",
184 | "_ = bp.visualize.animate_1D(\n",
185 | " dynamical_vars=[\n",
186 | " {'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},\n",
187 | " {'ys': runner.mon.v, 'xs': cann.x, 'legend': 'v'},\n",
188 | " {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}\n",
189 | " ],\n",
190 | " frame_step=30,\n",
191 | " frame_delay=5,\n",
192 | " show=True\n",
193 | ")"
194 | ]
195 | },
196 | {
197 | "cell_type": "markdown",
198 | "id": "4b9b8501",
199 | "metadata": {},
200 | "source": [
201 | ""
202 | ]
203 | }
204 | ],
205 | "metadata": {
206 | "jupytext": {
207 | "cell_metadata_filter": "-all",
208 | "formats": "ipynb,auto:percent",
209 | "notebook_metadata_filter": "-all"
210 | },
211 | "kernelspec": {
212 | "display_name": "brainpy",
213 | "language": "python",
214 | "name": "brainpy"
215 | },
216 | "language_info": {
217 | "codemirror_mode": {
218 | "name": "ipython",
219 | "version": 3
220 | },
221 | "file_extension": ".py",
222 | "mimetype": "text/x-python",
223 | "name": "python",
224 | "nbconvert_exporter": "python",
225 | "pygments_lexer": "ipython3",
226 | "version": "3.9.12"
227 | },
228 | "latex_envs": {
229 | "LaTeX_envs_menu_present": true,
230 | "autoclose": false,
231 | "autocomplete": true,
232 | "bibliofile": "biblio.bib",
233 | "cite_by": "apalike",
234 | "current_citInitial": 1,
235 | "eqLabelWithNumbers": true,
236 | "eqNumInitial": 1,
237 | "hotkeys": {
238 | "equation": "Ctrl-E",
239 | "itemize": "Ctrl-I"
240 | },
241 | "labels_anchors": false,
242 | "latex_user_defs": false,
243 | "report_style_numbering": false,
244 | "user_envs_cfg": false
245 | },
246 | "toc": {
247 | "base_numbering": 1,
248 | "nav_menu": {},
249 | "number_sections": true,
250 | "sideBar": true,
251 | "skip_h1_title": false,
252 | "title_cell": "Table of Contents",
253 | "title_sidebar": "Contents",
254 | "toc_cell": false,
255 | "toc_position": {},
256 | "toc_section_display": true,
257 | "toc_window_display": false
258 | }
259 | },
260 | "nbformat": 4,
261 | "nbformat_minor": 5
262 | }
263 |
--------------------------------------------------------------------------------
/attractors/Wu_2008_CANN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "cfec5599",
6 | "metadata": {},
7 | "source": [
8 | "# _(Si Wu, 2008)_: Continuous-attractor Neural Network 1D\n",
9 | "\n",
10 | "[](https://colab.research.google.com/github/brainpy/examples/blob/main/attractors/Wu_2008_CANN.ipynb)\n",
11 | "[](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/attractors/Wu_2008_CANN.ipynb)"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "9853e1ab",
17 | "metadata": {},
18 | "source": [
19 | "Here we show the implementation of the paper:\n",
20 | "\n",
21 | "- Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. \"Dynamics and computation\n",
22 | " of continuous attractors.\" Neural computation 20.4 (2008): 994-1025.\n",
23 | "\n",
24 | "Author:\n",
25 | "\n",
26 | "- Chaoming Wang (chao.brain@qq.com)"
27 | ]
28 | },
29 | {
30 | "cell_type": "markdown",
31 | "id": "aa926f46",
32 | "metadata": {},
33 | "source": [
34 | "The mathematical equation of the Continuous-Attractor Neural Network (CANN) is given by:\n",
35 | "\n",
36 | "$$\\tau \\frac{du(x,t)}{dt} = -u(x,t) + \\rho \\int dx' J(x,x') r(x',t)+I_{ext}$$\n",
37 | "\n",
38 | "$$r(x,t) = \\frac{u(x,t)^2}{1 + k \\rho \\int dx' u(x',t)^2}$$\n",
39 | "\n",
40 | "$$J(x,x') = \\frac{1}{\\sqrt{2\\pi}a}\\exp(-\\frac{|x-x'|^2}{2a^2})$$\n",
41 | "\n",
42 | "$$I_{ext} = A\\exp\\left[-\\frac{|x-z(t)|^2}{4a^2}\\right]$$"
43 | ]
44 | },
45 | {
46 | "cell_type": "code",
47 | "execution_count": 1,
48 | "id": "655048d7",
49 | "metadata": {
50 | "ExecuteTime": {
51 | "end_time": "2023-07-22T04:07:53.354158Z",
52 | "start_time": "2023-07-22T04:07:52.616728700Z"
53 | }
54 | },
55 | "outputs": [],
56 | "source": [
57 | "import brainpy as bp\n",
58 | "import brainpy.math as bm"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 2,
64 | "metadata": {
65 | "ExecuteTime": {
66 | "end_time": "2023-07-22T04:07:53.369742200Z",
67 | "start_time": "2023-07-22T04:07:53.354158Z"
68 | },
69 | "collapsed": false
70 | },
71 | "outputs": [
72 | {
73 | "data": {
74 | "text/plain": [
75 | "'2.4.3'"
76 | ]
77 | },
78 | "execution_count": 2,
79 | "metadata": {},
80 | "output_type": "execute_result"
81 | }
82 | ],
83 | "source": [
84 | "bp.__version__"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 3,
90 | "id": "2dafb4c1",
91 | "metadata": {
92 | "ExecuteTime": {
93 | "end_time": "2023-07-22T04:07:53.416674700Z",
94 | "start_time": "2023-07-22T04:07:53.369742200Z"
95 | },
96 | "lines_to_next_cell": 1
97 | },
98 | "outputs": [],
99 | "source": [
100 | "class CANN1D(bp.dyn.NeuDyn):\n",
101 | " def __init__(self, num, tau=1., k=8.1, a=0.5, A=10., J0=4.,\n",
102 | " z_min=-bm.pi, z_max=bm.pi, **kwargs):\n",
103 | " super(CANN1D, self).__init__(size=num, **kwargs)\n",
104 | "\n",
105 | " # parameters\n",
106 | " self.tau = tau # The synaptic time constant\n",
107 | " self.k = k # Degree of the rescaled inhibition\n",
108 | " self.a = a # Half-width of the range of excitatory connections\n",
109 | " self.A = A # Magnitude of the external input\n",
110 | " self.J0 = J0 # maximum connection value\n",
111 | "\n",
112 | " # feature space\n",
113 | " self.z_min = z_min\n",
114 | " self.z_max = z_max\n",
115 | " self.z_range = z_max - z_min\n",
116 | " self.x = bm.linspace(z_min, z_max, num) # The encoded feature values\n",
117 | " self.rho = num / self.z_range # The neural density\n",
118 | " self.dx = self.z_range / num # The stimulus density\n",
119 | "\n",
120 | " # variables\n",
121 | " self.u = bm.Variable(bm.zeros(num))\n",
122 | " self.input = bm.Variable(bm.zeros(num))\n",
123 | "\n",
124 | " # The connection matrix\n",
125 | " self.conn_mat = self.make_conn(self.x)\n",
126 | " \n",
127 | " # function\n",
128 | " self.integral = bp.odeint(self.derivative)\n",
129 | "\n",
130 | " def derivative(self, u, t, Iext):\n",
131 | " r1 = bm.square(u)\n",
132 | " r2 = 1.0 + self.k * bm.sum(r1)\n",
133 | " r = r1 / r2\n",
134 | " Irec = bm.dot(self.conn_mat, r)\n",
135 | " du = (-u + Irec + Iext) / self.tau\n",
136 | " return du\n",
137 | "\n",
138 | " def dist(self, d):\n",
139 | " d = bm.remainder(d, self.z_range)\n",
140 | " d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)\n",
141 | " return d\n",
142 | "\n",
143 | " def make_conn(self, x):\n",
144 | " assert bm.ndim(x) == 1\n",
145 | " x_left = bm.reshape(x, (-1, 1))\n",
146 | " x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)\n",
147 | " d = self.dist(x_left - x_right)\n",
148 | " Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / \\\n",
149 | " (bm.sqrt(2 * bm.pi) * self.a)\n",
150 | " return Jxx\n",
151 | "\n",
152 | " def get_stimulus_by_pos(self, pos):\n",
153 | " return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))\n",
154 | "\n",
155 | " def update(self):\n",
156 | " self.u.value = self.integral(self.u, bp.share['t'], self.input, bp.share['dt'])\n",
157 | " self.input[:] = 0."
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 9,
163 | "id": "64473237",
164 | "metadata": {
165 | "lines_to_next_cell": 2
166 | },
167 | "outputs": [],
168 | "source": [
169 | "cann = CANN1D(num=512, k=0.1)"
170 | ]
171 | },
172 | {
173 | "cell_type": "markdown",
174 | "id": "d83942e7",
175 | "metadata": {},
176 | "source": [
177 | "## Population coding"
178 | ]
179 | },
180 | {
181 | "cell_type": "code",
182 | "execution_count": 10,
183 | "metadata": {
184 | "collapsed": false
185 | },
186 | "outputs": [
187 | {
188 | "data": {
189 | "application/vnd.jupyter.widget-view+json": {
190 | "model_id": "a34c8dd621604a29b1c02176672f28d2",
191 | "version_major": 2,
192 | "version_minor": 0
193 | },
194 | "text/plain": [
195 | " 0%| | 0/170 [00:00, ?it/s]"
196 | ]
197 | },
198 | "metadata": {},
199 | "output_type": "display_data"
200 | }
201 | ],
202 | "source": [
203 | "I1 = cann.get_stimulus_by_pos(0.)\n",
204 | "Iext, duration = bp.inputs.section_input(values=[0., I1, 0.],\n",
205 | " durations=[1., 8., 8.],\n",
206 | " return_length=True)\n",
207 | "runner = bp.DSRunner(cann,\n",
208 | " inputs=['input', Iext, 'iter'],\n",
209 | " monitors=['u'])\n",
210 | "runner.run(duration)\n",
211 | "bp.visualize.animate_1D(\n",
212 | " dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},\n",
213 | " {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],\n",
214 | " frame_step=1,\n",
215 | " frame_delay=100,\n",
216 | " show=True,\n",
217 | " # save_path='../../images/cann-encoding.gif'\n",
218 | ")"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "id": "aa8cb3c4",
224 | "metadata": {},
225 | "source": [
226 | ""
227 | ]
228 | },
229 | {
230 | "cell_type": "markdown",
231 | "id": "efb40f33",
232 | "metadata": {},
233 | "source": [
234 | "## Template matching"
235 | ]
236 | },
237 | {
238 | "cell_type": "markdown",
239 | "id": "4e7550da",
240 | "metadata": {},
241 | "source": [
242 | "The CANN can perform efficient population decoding through template matching."
243 | ]
244 | },
245 | {
246 | "cell_type": "code",
247 | "execution_count": 11,
248 | "id": "67aaa53f",
249 | "metadata": {},
250 | "outputs": [
251 | {
252 | "data": {
253 | "application/vnd.jupyter.widget-view+json": {
254 | "model_id": "9b850fbc278b470f8ecded402791df28",
255 | "version_major": 2,
256 | "version_minor": 0
257 | },
258 | "text/plain": [
259 | " 0%| | 0/400 [00:00, ?it/s]"
260 | ]
261 | },
262 | "metadata": {},
263 | "output_type": "display_data"
264 | }
265 | ],
266 | "source": [
267 | "cann.k = 8.1\n",
268 | "\n",
269 | "dur1, dur2, dur3 = 10., 30., 0.\n",
270 | "num1 = int(dur1 / bm.get_dt())\n",
271 | "num2 = int(dur2 / bm.get_dt())\n",
272 | "num3 = int(dur3 / bm.get_dt())\n",
273 | "Iext = bm.zeros((num1 + num2 + num3,) + cann.size)\n",
274 | "Iext[:num1] = cann.get_stimulus_by_pos(0.5)\n",
275 | "Iext[num1:num1 + num2] = cann.get_stimulus_by_pos(0.)\n",
276 | "Iext[num1:num1 + num2] += 0.1 * cann.A * bm.random.randn(num2, *cann.size)\n",
277 | "\n",
278 | "runner = bp.DSRunner(cann,\n",
279 | " inputs=('input', Iext, 'iter'),\n",
280 | " monitors=['u'])\n",
281 | "runner.run(dur1 + dur2 + dur3)\n",
282 | "bp.visualize.animate_1D(\n",
283 | " dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},\n",
284 | " {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],\n",
285 | " frame_step=5,\n",
286 | " frame_delay=50,\n",
287 | " show=True,\n",
288 | " # save_path='../../images/cann-decoding.gif'\n",
289 | ")"
290 | ]
291 | },
292 | {
293 | "cell_type": "markdown",
294 | "id": "deb96b0b",
295 | "metadata": {},
296 | "source": [
297 | ""
298 | ]
299 | },
300 | {
301 | "cell_type": "markdown",
302 | "id": "a373fef3",
303 | "metadata": {},
304 | "source": [
305 | "## Smooth tracking\n",
306 | "\n",
307 | "The CANN can track a moving stimulus."
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": 12,
313 | "id": "42929b22",
314 | "metadata": {},
315 | "outputs": [
316 | {
317 | "data": {
318 | "application/vnd.jupyter.widget-view+json": {
319 | "model_id": "043f17ebeb4b41a69839a8fdcfb6c533",
320 | "version_major": 2,
321 | "version_minor": 0
322 | },
323 | "text/plain": [
324 | " 0%| | 0/600 [00:00, ?it/s]"
325 | ]
326 | },
327 | "metadata": {},
328 | "output_type": "display_data"
329 | }
330 | ],
331 | "source": [
332 | "dur1, dur2, dur3 = 20., 20., 20.\n",
333 | "num1 = int(dur1 / bm.get_dt())\n",
334 | "num2 = int(dur2 / bm.get_dt())\n",
335 | "num3 = int(dur3 / bm.get_dt())\n",
336 | "position = bm.zeros(num1 + num2 + num3)\n",
337 | "position[num1: num1 + num2] = bm.linspace(0., 12., num2)\n",
338 | "position[num1 + num2:] = 12.\n",
339 | "position = position.reshape((-1, 1))\n",
340 | "Iext = cann.get_stimulus_by_pos(position)\n",
341 | "runner = bp.DSRunner(cann,\n",
342 | " inputs=('input', Iext, 'iter'),\n",
343 | " monitors=['u'])\n",
344 | "runner.run(dur1 + dur2 + dur3)\n",
345 | "bp.visualize.animate_1D(\n",
346 | " dynamical_vars=[{'ys': runner.mon.u, 'xs': cann.x, 'legend': 'u'},\n",
347 | " {'ys': Iext, 'xs': cann.x, 'legend': 'Iext'}],\n",
348 | " frame_step=5,\n",
349 | " frame_delay=50,\n",
350 | " show=True,\n",
351 | " # save_path='../../images/cann-tracking.gif'\n",
352 | ")"
353 | ]
354 | },
355 | {
356 | "cell_type": "markdown",
357 | "id": "dc4c39bf",
358 | "metadata": {},
359 | "source": [
360 | ""
361 | ]
362 | }
363 | ],
364 | "metadata": {
365 | "jupytext": {
366 | "formats": "ipynb,py:percent"
367 | },
368 | "kernelspec": {
369 | "display_name": "brainpy",
370 | "language": "python",
371 | "name": "brainpy"
372 | },
373 | "language_info": {
374 | "codemirror_mode": {
375 | "name": "ipython",
376 | "version": 3
377 | },
378 | "file_extension": ".py",
379 | "mimetype": "text/x-python",
380 | "name": "python",
381 | "nbconvert_exporter": "python",
382 | "pygments_lexer": "ipython3",
383 | "version": "3.8.11"
384 | },
385 | "latex_envs": {
386 | "LaTeX_envs_menu_present": true,
387 | "autoclose": false,
388 | "autocomplete": true,
389 | "bibliofile": "biblio.bib",
390 | "cite_by": "apalike",
391 | "current_citInitial": 1,
392 | "eqLabelWithNumbers": true,
393 | "eqNumInitial": 1,
394 | "hotkeys": {
395 | "equation": "Ctrl-E",
396 | "itemize": "Ctrl-I"
397 | },
398 | "labels_anchors": false,
399 | "latex_user_defs": false,
400 | "report_style_numbering": false,
401 | "user_envs_cfg": false
402 | },
403 | "toc": {
404 | "base_numbering": 1,
405 | "nav_menu": {},
406 | "number_sections": false,
407 | "sideBar": true,
408 | "skip_h1_title": false,
409 | "title_cell": "Table of Contents",
410 | "title_sidebar": "Contents",
411 | "toc_cell": false,
412 | "toc_position": {
413 | "height": "calc(100% - 180px)",
414 | "left": "10px",
415 | "top": "150px",
416 | "width": "245.76px"
417 | },
418 | "toc_section_display": true,
419 | "toc_window_display": true
420 | }
421 | },
422 | "nbformat": 4,
423 | "nbformat_minor": 5
424 | }
425 |
--------------------------------------------------------------------------------
/attractors/data/data_to_train_on.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/attractors/data/data_to_train_on.npy
--------------------------------------------------------------------------------
/brain_inspired_computing/OTTT-SNN.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import argparse
5 | import os
6 | import sys
7 | import time
8 |
9 | import brainpy as bp
10 | import brainpy.math as bm
11 | import jax
12 | import jax.numpy as jnp
13 | import numpy as np
14 | import torch.utils.data as data
15 | import torchvision.datasets as datasets
16 | import torchvision.transforms as transforms
17 | import tqdm
18 | from torchtoolbox.transform import Cutout
19 |
# Configure the BrainPy math environment with a TrainingMode instance.
bm.set_environment(bm.TrainingMode())
# Kernel initializer handed to ScaledWSConv2d as its ``w_initializer``.
conv_init = bp.init.KaimingNormal(mode='fan_out', scale=jnp.sqrt(2))
# Weight initializer handed to ScaledWSLinear as its ``W_initializer``.
dense_init = bp.init.Normal(0, 0.01)
23 |
24 |
@jax.custom_gradient
def replace(spike, rate):
  """Return ``rate`` in the forward pass while routing gradients to both inputs.

  The forward value is ``rate`` unchanged; ``spike`` is not used to compute it.
  In the backward pass the incoming cotangent is handed, unmodified, to both
  ``spike`` and ``rate``.
  """

  def _backward(cotangent):
    # Same cotangent for each of the two positional inputs.
    return cotangent, cotangent

  return rate, _backward
31 |
32 |
class ScaledWSConv2d(bp.layers.Conv2d):
  """2D convolution whose kernel is standardized on every forward call.

  Before convolving, the kernel is centred and rescaled using its mean and
  variance over axes 0, 1 and 2, with the variance multiplied by ``fan_in``
  (the product of all kernel dimensions except the last). An optional
  trainable ``gain`` then multiplies the standardized kernel.

  Parameters
  ----------
  in_channels, out_channels, kernel_size, stride, padding, groups
    Forwarded to ``bp.layers.Conv2d``.
  b_initializer
    Bias initializer forwarded to ``bp.layers.Conv2d``.
  gain : bool
    If true, create a trainable gain of shape ``[1, 1, 1, out_channels]``;
    otherwise no gain is applied.
  eps : float
    Constant added under the square root when rescaling the kernel.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
               groups=1, b_initializer=bp.init.ZeroInit(), gain=True, eps=1e-4):
    super().__init__(in_channels=in_channels,
                     out_channels=out_channels,
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
                     groups=groups,
                     w_initializer=conv_init,
                     b_initializer=b_initializer)
    assert self.mode.is_parent_of(bm.TrainingMode)
    self.eps = eps
    self.gain = bm.TrainVar(jnp.ones([1, 1, 1, self.out_channels])) if gain else None

  def update(self, x):
    """Convolve ``x`` with the standardized (and optionally gained) kernel."""
    assert self.mask is None
    self._check_input_dim(x)

    kernel = self.w.value
    reduce_axes = [0, 1, 2]
    fan_in = np.prod(kernel.shape[:-1])

    # Standardize: zero mean, variance scaled by fan_in, eps under the root.
    kernel_mean = jnp.mean(kernel, axis=reduce_axes, keepdims=True)
    kernel_var = jnp.var(kernel, axis=reduce_axes, keepdims=True)
    kernel = (kernel - kernel_mean) / ((kernel_var * fan_in + self.eps) ** 0.5)
    if self.gain is not None:
      kernel = kernel * self.gain

    conv_out = jax.lax.conv_general_dilated(lhs=bm.as_jax(x),
                                            rhs=bm.as_jax(kernel),
                                            window_strides=self.stride,
                                            padding=self.padding,
                                            lhs_dilation=self.lhs_dilation,
                                            rhs_dilation=self.rhs_dilation,
                                            feature_group_count=self.groups,
                                            dimension_numbers=self.dimension_numbers)
    if self.b is None:
      return conv_out
    return conv_out + self.b.value
70 |
71 |
class ScaledWSLinear(bp.layers.Dense):
    """Dense layer whose weight matrix is standardized on every forward call.

    The weight is centred and scaled with the mean/variance taken over its
    first axis, then optionally multiplied by a trainable per-output ``gain``.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        b_initializer: Bias initializer forwarded to ``bp.layers.Dense``.
        gain: If true, create a trainable gain of shape ``(1, out_features)``.
        eps: Constant added under the square root for numerical stability.
    """

    def __init__(self, in_features, out_features, b_initializer=bp.init.ZeroInit(), gain=True, eps=1e-4):
        super(ScaledWSLinear, self).__init__(num_in=in_features,
                                             num_out=out_features,
                                             W_initializer=dense_init,
                                             b_initializer=b_initializer)
        bp.check.is_subclass(self.mode, bm.TrainingMode)
        if gain:
            # The shape must be one sequence: ``jnp.ones(1, n)`` would treat
            # ``n`` as the dtype argument and fail.
            self.gain = bm.TrainVar(jnp.ones([1, self.num_out]))
        else:
            self.gain = None
        self.eps = eps

    def update(self, x):
        """Apply the standardized (and optionally gained) affine map to ``x``."""
        fan_in = self.W.shape[0]
        mean = jnp.mean(self.W.value, axis=0, keepdims=True)
        var = jnp.var(self.W.value, axis=0, keepdims=True)
        weight = (self.W.value - mean) / ((var * fan_in + self.eps) ** 0.5)
        if self.gain is not None:
            weight = weight * self.gain
        if self.b is not None:
            return x @ weight + self.b
        else:
            return x @ weight
96 |
97 |
class Scale(bp.layers.Layer):
    """Layer that multiplies its input by a fixed scalar factor."""

    def __init__(self, scale: float):
        super().__init__()
        # Constant multiplier applied in ``update``.
        self.scale = scale

    def update(self, x):
        """Return ``x`` multiplied by ``self.scale``."""
        scaled = x * self.scale
        return scaled
105 |
106 |
class WrappedSNNOp(bp.layers.Layer):
    """Wrap an operation so its gradient can be computed from tracked rates.

    Args:
        op: Callable layer applied to the input.
        grad_with_rate: Enables the spike/rate gradient substitution when the
            shared ``'fit'`` flag is also truthy.
    """

    def __init__(self, op, grad_with_rate):
        super().__init__()
        self.op = op
        self.grad_with_rate = grad_with_rate

    def update(self, x):
        """Apply ``op``; in rate-gradient mode ``x`` holds spikes and rates stacked on axis 0."""
        if not (bp.share.load('fit') and self.grad_with_rate):
            return self.op(x)
        spike, rate = jnp.split(x, 2, axis=0)
        # Forward value of this branch is op(rate); it only carries the gradient.
        grad_branch = self.op(replace(spike, rate))
        # The value actually returned is op(spike), detached from the gradient.
        value_branch = jax.lax.stop_gradient(self.op(spike))
        return replace(grad_branch, value_branch)

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}(op={self.op}, grad_with_rate={self.grad_with_rate})'
124 |
125 |
class OnlineSpikingVGG(bp.DynamicalSystemNS):
    """VGG-style spiking network built from ``cfg`` for online training.

    Each integer entry of ``cfg`` creates a 3x3 convolution followed by a
    spiking neuron layer; each ``'M'`` entry creates a 2x2 average pooling.

    Args:
        neuron_model: Class used to build every spiking neuron layer. It is
            called as ``neuron_model(size, **neuron_pars)``.
        weight_standardization: Use ``ScaledWSConv2d`` instead of a plain
            ``bp.layers.Conv2d``.
        num_classes: Size of the classifier output.
        neuron_pars: Keyword arguments forwarded to ``neuron_model``.
        light_classifier: Use a single dense read-out instead of the
            three-layer classifier.
        batch_norm: Insert ``BatchNorm2d`` after each convolution (only used
            when ``weight_standardization`` is false).
        grad_with_rate: Wrap layers in ``WrappedSNNOp`` so gradients are
            computed from tracked rates while fitting.
        fc_hw: Spatial size produced by the adaptive pooling of the light
            classifier.
        c_in: Number of input channels.
    """

    cfg = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512]

    def __init__(
        self,
        neuron_model,
        weight_standardization=True,
        num_classes=1000,
        neuron_pars: dict = None,
        light_classifier=True,
        batch_norm=False,
        grad_with_rate: bool = False,
        fc_hw: int = 3,
        c_in: int = 3
    ):
        super(OnlineSpikingVGG, self).__init__()

        if neuron_pars is None:
            neuron_pars = dict()
        self.neuron_pars = neuron_pars
        self.neuron_model = neuron_model
        self.grad_with_rate = grad_with_rate
        self.fc_hw = fc_hw

        # Feature-map sizes of the neuron layers, one per convolution in ``cfg``.
        neuron_sizes = [(32, 32, 64),
                        (32, 32, 128),
                        (16, 16, 256),
                        (16, 16, 256),
                        (8, 8, 512),
                        (8, 8, 512),
                        (4, 4, 512),
                        (4, 4, 512), ]
        neuron_i = 0
        layers = []
        first_conv = True
        in_channels = c_in

        for v in self.cfg:
            if v == 'M':
                layers.append(bp.layers.AvgPool2d(kernel_size=2, stride=2))
                continue
            if weight_standardization:
                conv2d = ScaledWSConv2d(in_channels, v, kernel_size=3, padding=1, stride=1)
            else:
                conv2d = bp.layers.Conv2d(in_channels, v, kernel_size=3, padding=1, stride=1,
                                          w_initializer=conv_init)
            # The first convolution is left unwrapped; all later ones go
            # through WrappedSNNOp.
            if first_conv:
                first_conv = False
            else:
                conv2d = WrappedSNNOp(conv2d, self.grad_with_rate)
            if batch_norm and not weight_standardization:
                layers += [conv2d,
                           bp.layers.BatchNorm2d(v, momentum=0.9),
                           self.neuron_model(neuron_sizes[neuron_i], **self.neuron_pars)]
            else:
                layers += [conv2d,
                           self.neuron_model(neuron_sizes[neuron_i], **self.neuron_pars),
                           Scale(2.74)]
            neuron_i += 1
            in_channels = v
        self.features = bp.Sequential(*layers)

        if light_classifier:
            self.avgpool = bp.layers.AdaptiveAvgPool2d((self.fc_hw, self.fc_hw))
            self.classifier = WrappedSNNOp(bp.layers.Dense(512 * self.fc_hw * self.fc_hw,
                                                           num_classes,
                                                           W_initializer=dense_init),
                                           self.grad_with_rate)
        else:
            self.avgpool = bp.layers.AdaptiveAvgPool2d((7, 7))
            # Classifier neurons never use dropout. Merging the override into a
            # dict avoids a "multiple values for keyword argument" TypeError
            # when ``neuron_pars`` already contains ``neuron_dropout``.
            fc_neuron_pars = {**self.neuron_pars, 'neuron_dropout': 0.0}

            def _maybe_wrap(op):
                # Linear layers are wrapped only in rate-gradient mode.
                return WrappedSNNOp(op, self.grad_with_rate) if self.grad_with_rate else op

            self.classifier = bp.Sequential(
                _maybe_wrap(ScaledWSLinear(512 * 7 * 7, 4096)),
                neuron_model((4096,), **fc_neuron_pars),
                Scale(2.74),
                bp.layers.Dropout(0.5),
                _maybe_wrap(ScaledWSLinear(4096, 4096)),
                neuron_model((4096,), **fc_neuron_pars),
                Scale(2.74),
                bp.layers.Dropout(0.5),
                _maybe_wrap(bp.layers.Dense(4096, num_classes, W_initializer=dense_init)),
            )

    def update(self, x):
        """Run one time step and return the classifier output.

        Publishes the shared ``'output_type'`` flag read by the neuron layers:
        ``'spike_rate'`` when ``grad_with_rate`` is set and the shared
        ``'fit'`` flag is truthy, ``'spike'`` otherwise.
        """
        use_rate = self.grad_with_rate and bp.share.load('fit')
        bp.share.save('output_type', 'spike_rate' if use_rate else 'spike')
        x = self.features(x)
        x = self.avgpool(x)
        x = bm.flatten(x, 1)
        return self.classifier(x)
241 |
242 |
class OnlineIFNode(bp.DynamicalSystemNS):
    """Integrate-and-fire neuron layer for online training.

    Args:
        size: Sequence of ints giving the neuron layer size.
        v_threshold: Firing threshold.
        v_reset: Reset potential. ``None`` selects a subtractive reset by
            ``v_threshold``; otherwise the potential is set to ``v_reset``
            where a spike occurred.
        f_surrogate: Callable producing spikes from ``v - v_threshold``.
        detach_reset: Stop the gradient through the spike used for the reset.
        track_rate: Accumulate emitted spikes in ``rate_tracking``.
        neuron_dropout: Dropout probability applied to spikes while the shared
            ``'fit'`` flag is truthy.
        name: Optional object name.
        mode: Optional computing mode; must be a training mode.
    """

    def __init__(
        self,
        size,
        v_threshold: float = 1.,
        v_reset: float = None,
        f_surrogate=bm.surrogate.sigmoid,
        detach_reset: bool = True,
        track_rate: bool = True,
        neuron_dropout: float = 0.0,
        name: str = None,
        mode: bm.Mode = None
    ):
        super().__init__(name=name, mode=mode)
        bp.check.is_subclass(self.mode, bm.TrainingMode)

        self.size = bp.check.is_sequence(size, elem_type=int)
        self.f_surrogate = bp.check.is_callable(f_surrogate)
        self.detach_reset = detach_reset
        self.v_reset = v_reset
        self.v_threshold = v_threshold
        self.track_rate = track_rate
        self.dropout = neuron_dropout
        if self.dropout > 0.0:
            # A random generator is only needed when dropout is active.
            self.rng = bm.random.default_rng()
        self.reset_state(1)

    def reset_state(self, batch_size=1):
        """Re-create the state variables as zeros for ``batch_size``."""
        self.v = bp.init.variable_(bm.zeros, self.size, batch_size)
        self.spike = bp.init.variable_(bm.zeros, self.size, batch_size)
        if self.track_rate:
            self.rate_tracking = bp.init.variable_(bm.zeros, self.size, batch_size)

    def update(self, x):
        """Advance one step with input ``x``.

        Returns the spikes, or spikes concatenated with the tracked rates when
        the shared ``'output_type'`` flag equals ``'spike_rate'``.
        """
        # neuron charge: the previous potential is detached from the gradient
        self.v.value = jax.lax.stop_gradient(self.v.value) + x
        # neuron fire
        spike = self.f_surrogate(self.v.value - self.v_threshold)
        # spike reset
        spike_d = jax.lax.stop_gradient(spike) if self.detach_reset else spike
        if self.v_reset is None:
            self.v -= spike_d * self.v_threshold
        else:
            self.v.value = (1. - spike_d) * self.v + spike_d * self.v_reset
        # dropout
        if self.dropout > 0.0 and bp.share.load('fit'):
            mask = self.rng.bernoulli(1 - self.dropout, self.v.shape) / (1 - self.dropout)
            spike = mask * spike
        self.spike.value = spike
        # spike track
        if self.track_rate:
            self.rate_tracking += jax.lax.stop_gradient(spike)
        # output: the flag is published through ``bp.share`` by the network,
        # so it must be read from the same store.
        if bp.share.load('output_type') == 'spike_rate':
            assert self.track_rate
            return jnp.concatenate([spike, self.rate_tracking.value])
        else:
            return spike
301 |
302 |
class OnlineLIFNode(bp.DynamicalSystemNS):
    """Leaky integrate-and-fire neuron layer for online training.

    Args:
        size: Sequence of ints giving the neuron layer size.
        tau: Membrane time constant; the potential decays by ``1 - 1 / tau``.
        decay_input: Divide the input by ``tau`` before integrating it.
        v_threshold: Firing threshold.
        v_reset: Reset potential. ``None`` selects a subtractive reset by
            ``v_threshold``; otherwise the potential is set to ``v_reset``
            where a spike occurred.
        f_surrogate: Callable producing spikes from ``v - v_threshold``.
        detach_reset: Stop the gradient through the spike used for the reset.
        track_rate: Maintain a decayed running sum of spikes in ``rate_tracking``.
        neuron_dropout: Dropout probability applied to spikes while the shared
            ``'fit'`` flag is truthy.
        name: Optional object name.
        mode: Optional computing mode; must be a training mode.
    """

    def __init__(
        self,
        size,
        tau: float = 2.,
        decay_input: bool = False,
        v_threshold: float = 1.,
        v_reset: float = None,
        f_surrogate=bm.surrogate.sigmoid,
        detach_reset: bool = True,
        track_rate: bool = True,
        neuron_dropout: float = 0.0,
        name: str = None,
        mode: bm.Mode = None
    ):
        super().__init__(name=name, mode=mode)
        bp.check.is_subclass(self.mode, bm.TrainingMode)

        self.size = bp.check.is_sequence(size, elem_type=int)
        self.tau = tau
        self.decay_input = decay_input
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.f_surrogate = f_surrogate
        self.detach_reset = detach_reset
        self.track_rate = track_rate
        self.dropout = neuron_dropout

        # The random generator exists only when dropout is enabled.
        if self.dropout > 0.0:
            self.rng = bm.random.default_rng()
        self.reset_state(1)

    def reset_state(self, batch_size=1):
        """Re-create the state variables as zeros for ``batch_size``."""
        self.v = bp.init.variable_(bm.zeros, self.size, batch_size)
        self.spike = bp.init.variable_(bm.zeros, self.size, batch_size)
        if self.track_rate:
            self.rate_tracking = bp.init.variable_(bm.zeros, self.size, batch_size)

    def update(self, x):
        """Advance one step with input ``x``.

        Returns the spikes, or spikes concatenated with the tracked rates when
        the shared ``'output_type'`` flag equals ``'spike_rate'``.
        """
        # --- charge ---
        if self.decay_input:
            x = x / self.tau
        no_reset_drive = self.v_reset is None or self.v_reset == 0
        if no_reset_drive:
            self.v = jax.lax.stop_gradient(self.v.value) * (1 - 1. / self.tau) + x
        else:
            self.v = jax.lax.stop_gradient(self.v.value) * (1 - 1. / self.tau) + self.v_reset / self.tau + x

        # --- fire ---
        spike = self.f_surrogate(self.v - self.v_threshold)

        # --- reset ---
        if self.detach_reset:
            spike_d = jax.lax.stop_gradient(spike)
        else:
            spike_d = spike
        if self.v_reset is None:
            self.v -= spike_d * self.v_threshold
        else:
            self.v = (1. - spike_d) * self.v + spike_d * self.v_reset

        # --- dropout (only while fitting) ---
        if self.dropout > 0.0 and bp.share.load('fit'):
            keep_prob = 1 - self.dropout
            mask = self.rng.bernoulli(keep_prob, spike.shape) / (1 - self.dropout)
            spike = mask * spike
        self.spike.value = spike

        # --- rate tracking: decayed running sum, detached from the gradient ---
        if self.track_rate:
            self.rate_tracking.value = jax.lax.stop_gradient(self.rate_tracking * (1 - 1. / self.tau) + spike)

        # --- output ---
        if bp.share.load('output_type') != 'spike_rate':
            return spike
        assert self.track_rate
        return jnp.concatenate((spike, self.rate_tracking.value))
370 |
371 |
class AverageMeter(object):
    """Track the latest value and the weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, average, weighted sum and count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
387 |
388 |
@bm.jit(static_argnums=2)
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k (in percent) for the specified values of k.

    Args:
        output: Per-sample scores; the top-k search runs over each row.
        target: Ground-truth labels, one per row of ``output``.
        topk: Tuple of k values (static under jit).

    Returns:
        A list with one percentage per entry of ``topk``.
    """
    largest_k = max(topk)
    _, top_indices = jax.vmap(jax.lax.top_k, in_axes=(0, None))(output, largest_k)
    # Row r now holds the (r+1)-th best guess of every sample.
    top_indices = top_indices.T
    hits = (top_indices == target.reshape(1, -1)).astype(bm.float_)
    return [hits[:k].reshape(-1).sum(0) * 100.0 / target.size for k in topk]
401 |
402 |
403 | # print(accuracy(jnp.ones(10, ), jnp.ones(20), (1, 3, 5)))
404 | # import sys
405 | # sys.exit()
406 |
407 |
def classify_cifar():
    """Train and evaluate an online spiking VGG on CIFAR-10/100.

    Command-line arguments select the dataset, the optimizer, the learning
    rate schedule and the number of simulation steps. After every epoch the
    test set is evaluated, and a checkpoint is written whenever the test
    accuracy improves.

    NOTE: ``train_acc``/``test_acc`` are fractions in ``[0, 1]`` and the
    losses are per-sample means. A ``max_test_acc`` stored by an older
    checkpoint that used per-batch normalization is on a different scale.
    """
    parser = argparse.ArgumentParser(description='Classify CIFAR')
    parser.add_argument('-T', default=6, type=int, help='simulating time-steps')
    parser.add_argument('-tau', default=2., type=float)
    parser.add_argument('-b', default=128, type=int, help='batch size')
    parser.add_argument('-epochs', default=300, type=int, help='number of total epochs to run')
    parser.add_argument('-j', default=4, type=int, help='number of data loading workers (default: 4)')
    parser.add_argument('-data_dir', type=str, default=r'/mnt/d/data')
    # parser.add_argument('-data_dir', type=str, default=r'D:/data')
    parser.add_argument('-dataset', default='cifar10', type=str)
    parser.add_argument('-out_dir', default='./logs', type=str, help='root dir for saving logs and checkpoint')
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-opt', type=str, help='use which optimizer. SGD or Adam', default='SGD')
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-lr_scheduler', default='CosALR', type=str, help='use which schedule. StepLR or CosALR')
    parser.add_argument('-step_size', default=100, type=float, help='step_size for StepLR')
    parser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR')
    parser.add_argument('-T_max', default=300, type=int, help='T_max for CosineAnnealingLR')
    parser.add_argument('-drop_rate', type=float, default=0.0)
    parser.add_argument('-weight_decay', type=float, default=0.0)
    parser.add_argument('-loss_lambda', type=float, default=0.05)
    parser.add_argument('-online_update', action='store_true')
    parser.add_argument('-gpu-id', default='0', type=str, help='gpu id')
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id

    # datasets
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        Cutout(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    if args.dataset == 'cifar10':
        dataloader = datasets.CIFAR10
        num_classes = 10
    else:
        dataloader = datasets.CIFAR100
        num_classes = 100

    # network
    net = OnlineSpikingVGG(neuron_model=OnlineLIFNode,
                           neuron_pars=dict(tau=args.tau,
                                            neuron_dropout=args.drop_rate,
                                            f_surrogate=bm.surrogate.sigmoid,
                                            track_rate=True,
                                            v_reset=None),
                           weight_standardization=True,
                           num_classes=num_classes,
                           grad_with_rate=True,
                           fc_hw=1,
                           c_in=3)
    print('Total Parameters: %.2fM' % (
        sum(p.size for p in net.vars().subset(bm.TrainVar).unique().values()) / 1000000.0))
    print(net)

    trainset = dataloader(root=args.data_dir, train=True, download=True, transform=transform_train)
    train_data_loader = data.DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j)
    testset = dataloader(root=args.data_dir, train=False, download=False, transform=transform_test)
    test_data_loader = data.DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j)

    # path
    out_dir = os.path.join(args.out_dir, f'{args.dataset}_T_{args.T}_{args.opt}_lr_{args.lr}_')
    if args.lr_scheduler == 'CosALR':
        out_dir += f'CosALR_{args.T_max}'
    elif args.lr_scheduler == 'StepLR':
        out_dir += f'StepLR_{args.step_size}_{args.gamma}'
    else:
        raise NotImplementedError(args.lr_scheduler)
    if args.online_update:
        out_dir += '_online'
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))

    t_step = args.T

    def single_step(x, y, fit=True):
        # One simulation step: returns the loss (already divided by the number
        # of time steps) and the network output.
        bp.share.save('fit', fit)
        out = net(x)
        if args.loss_lambda > 0.0:
            # The one-hot depth must follow the dataset (10 or 100 classes).
            y = bm.one_hot(y, num_classes, dtype=bm.float_)
            l = bp.losses.mean_squared_error(out, y) * args.loss_lambda
            l += (1 - args.loss_lambda) * bp.losses.cross_entropy_loss(out, y)
            l /= t_step
        else:
            l = bp.losses.cross_entropy_loss(out, y) / t_step
        return l, out

    @bm.jit
    def inference_fun(x, y):
        # Present the same input for ``t_step`` steps and sum the outputs.
        l, out = bm.for_loop(lambda _: single_step(x, y, False), jnp.arange(t_step))
        out = out.sum(0)
        n = jnp.sum(jnp.argmax(out, axis=1) == y)
        return l.sum(), n, out

    grad_fun = bm.grad(single_step, grad_vars=net.train_vars().unique(), return_value=True, has_aux=True)

    if args.lr_scheduler == 'StepLR':
        lr = bp.optim.StepLR(args.lr, step_size=args.step_size, gamma=args.gamma)
    elif args.lr_scheduler == 'CosALR':
        lr = bp.optim.CosineAnnealingLR(args.lr, T_max=args.T_max)
    else:
        raise NotImplementedError(args.lr_scheduler)

    if args.opt == 'SGD':
        optimizer = bp.optim.Momentum(lr, net.train_vars().unique(), momentum=args.momentum,
                                      weight_decay=args.weight_decay)
    elif args.opt == 'Adam':
        optimizer = bp.optim.AdamW(lr, net.train_vars().unique(), weight_decay=args.weight_decay)
    else:
        raise NotImplementedError(args.opt)

    @bm.jit
    def train_fun(x, y):
        if args.online_update:
            # Apply an optimizer step after every simulation step.
            final_loss, final_out = 0., 0.
            for _ in range(t_step):
                grads, l, out = grad_fun(x, y)
                optimizer.update(grads)
                final_loss += l
                final_out += out
        else:
            # Sum the gradients over all simulation steps, then update once.
            final_grads, final_loss, final_out = grad_fun(x, y)
            for _ in range(t_step - 1):
                grads, l, out = grad_fun(x, y)
                final_grads = jax.tree_util.tree_map(lambda a, b: a + b, final_grads, grads)
                final_loss += l
                final_out += out
            optimizer.update(final_grads)
        n = jnp.sum(jnp.argmax(final_out, axis=1) == y)
        return final_loss, n, final_out

    start_epoch = 0
    max_test_acc = 0
    if args.resume:
        checkpoint = bp.checkpoints.load_pytree(args.resume)
        net.load_state_dict(checkpoint['net'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = checkpoint['epoch'] + 1
        max_test_acc = checkpoint['max_test_acc']

    # ``len(DataLoader)`` counts batches and is used for the progress bars.
    # Epoch statistics are accumulated per sample, so they are normalized by
    # the dataset sizes instead.
    num_train_batches = len(train_data_loader)
    num_test_batches = len(test_data_loader)
    num_train_samples = len(trainset)
    num_test_samples = len(testset)
    for epoch in range(start_epoch, args.epochs):
        start_time = time.time()

        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        end = time.time()

        train_loss = 0
        train_acc = 0
        pbar = tqdm.tqdm(total=num_train_batches)
        for frame, label in train_data_loader:
            # Move the channel axis last (dims 0, 2, 3, 1) before feeding the network.
            frame = jnp.asarray(frame).transpose(0, 2, 3, 1)
            label = jnp.asarray(label)
            net.reset_state(frame.shape[0])
            batch_loss, n, total_fr = train_fun(frame, label)
            prec1, prec5 = accuracy(total_fr, label, (1, 5))
            train_loss += batch_loss * label.size
            train_acc += n
            losses.update(batch_loss, frame.shape[0])
            top1.update(prec1.item(), frame.shape[0])
            top5.update(prec5.item(), frame.shape[0])

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            pbar.update(1)
            pbar.set_description(
                'Batch: {bt:.3f}s | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                    bt=batch_time.avg, loss=losses.avg, top1=top1.avg, top5=top5.avg,
                )
            )
        pbar.close()

        train_loss /= num_train_samples
        train_acc /= num_train_samples
        optimizer.lr.step_epoch()

        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        end = time.time()

        test_loss = 0
        test_acc = 0
        pbar = tqdm.tqdm(total=num_test_batches)
        for frame, label in test_data_loader:
            frame = jnp.asarray(frame).transpose(0, 2, 3, 1)
            label = jnp.asarray(label)
            net.reset_state(frame.shape[0])
            total_loss, n, out = inference_fun(frame, label)
            test_loss += total_loss * label.size
            test_acc += n
            prec1, prec5 = accuracy(out, label, (1, 5))
            losses.update(total_loss, frame.shape[0])
            top1.update(prec1.item(), frame.shape[0])
            top5.update(prec5.item(), frame.shape[0])
            batch_time.update(time.time() - end)
            end = time.time()

            # plot progress
            pbar.update(1)
            pbar.set_description(
                'Batch: {bt:.3f}s | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                    bt=batch_time.avg, loss=losses.avg, top1=top1.avg, top5=top5.avg,
                )
            )
        pbar.close()

        test_loss /= num_test_samples
        test_acc /= num_test_samples

        # Keep only the best-performing model on the test set.
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            checkpoint = {
                'net': net.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'max_test_acc': max_test_acc
            }
            bp.checkpoints.save_pytree(out_dir + '/checkpoint.bp', checkpoint, overwrite=True)

        total_time = time.time() - start_time
        print(f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, '
              f'test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, '
              f'total_time={total_time}')
649 |
650 |
# Start the CIFAR training run only when this file is executed as a script.
if __name__ == '__main__':
    classify_cifar()
653 |
--------------------------------------------------------------------------------
/brain_inspired_computing/SurrogateGrad_lif.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | Reproduce the results of the ``spytorch`` tutorial 1:
6 |
7 | - https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial1.ipynb
8 |
9 | """
10 |
11 | import time
12 |
13 | import jax.numpy as jnp
14 | import numpy as np
15 | import matplotlib.pyplot as plt
16 | from matplotlib.gridspec import GridSpec
17 |
18 | import brainpy as bp
19 | import brainpy.math as bm
20 |
21 |
class SNN(bp.DynSysGroup):
    """Spiking network with the pipeline input -> recurrent LIF -> leaky read-out.

    Consecutive stages are connected by a linear projection followed by an
    exponential synapse.

    Args:
        num_in: Number of input units.
        num_rec: Number of LIF neurons in the hidden layer.
        num_out: Number of read-out units.
    """

    def __init__(self, num_in, num_rec, num_out):
        super().__init__()

        # parameters
        self.num_in, self.num_rec, self.num_out = num_in, num_rec, num_out

        # synapse: i->r
        in_proj = bp.dnn.Linear(num_in, num_rec, W_initializer=bp.init.KaimingNormal(scale=20.))
        in_syn = bp.dyn.Expon(num_rec, tau=10.)
        self.i2r = bp.Sequential(in_proj, in_syn)
        # recurrent: r
        self.r = bp.dyn.Lif(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
        # synapse: r->o
        out_proj = bp.dnn.Linear(num_rec, num_out, W_initializer=bp.init.KaimingNormal(scale=20.))
        out_syn = bp.dyn.Expon(num_out, tau=10.)
        self.r2o = bp.Sequential(out_proj, out_syn)
        # output: o
        self.o = bp.dyn.Leaky(num_out, tau=5)

    def update(self, spike):
        """Feed ``spike`` through the four stages in order and return the read-out."""
        hidden_input = spike >> self.i2r
        hidden_spikes = hidden_input >> self.r
        readout_input = hidden_spikes >> self.r2o
        return readout_input >> self.o
48 |
49 |
def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5):
    """Plot the leading traces of ``mem`` on a grid of subplots sharing the y-axis.

    Args:
        mem: Traces to plot; ``mem[i]`` is drawn in the i-th panel.
        spk: Optional array; wherever it is positive the plotted value is
            replaced by ``spike_height``.
        dim: Grid layout passed to ``GridSpec``.
        spike_height: Value drawn at spike positions.
    """
    grid = GridSpec(*dim)
    # Multiplying by 1. yields a new array, so the caller's data is not modified.
    traces = 1. * mem
    if spk is not None:
        traces[spk > 0.0] = spike_height
    traces = bm.as_numpy(traces)
    first_axis = None
    for panel in range(np.prod(dim)):
        if first_axis is None:
            first_axis = axis = plt.subplot(grid[panel])
        else:
            axis = plt.subplot(grid[panel], sharey=first_axis)
        axis.plot(traces[panel])
    plt.tight_layout()
    plt.show()
64 |
65 |
def print_classification_accuracy(output, target):
    """Print the fraction of samples whose strongest read-out unit matches ``target``.

    The maximum is taken over axis 1 of ``output`` first, then the argmax over
    axis 1 of the result.
    """
    peak = bm.max(output, axis=1)  # max over time
    predicted = bm.argmax(peak, axis=1)  # argmax over output units
    print("Accuracy %.3f" % bm.mean(target == predicted))  # compare to labels
72 |
73 |
# Build the network while the training mode is the active BrainPy mode.
# NOTE(review): the original indentation is not visible here; only the
# ``net = ...`` line is assumed to belong to the ``with`` block — confirm.
with bm.environment(mode=bm.training_mode):
    net = SNN(100, 4, 2)

# Random input spike trains and random binary labels.
num_step = 2000
num_sample = 256
freq = 5  # Hz
mask = bm.random.rand(num_sample, num_step, net.num_in)
x_data = bm.zeros((num_sample, num_step, net.num_in))
# A spike is placed wherever the uniform draw falls below freq * dt / 1000.
x_data[mask < freq * bm.get_dt() / 1000.] = 1.0
y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_)
# Dedicated generator; its keys are used for shuffling inside ``loss``.
rng = bm.random.RandomState(123)


# Before training
runner = bp.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V})
out = runner.run(inputs=x_data.value, reset_state=True)
plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike'))
plot_voltage_traces(out)
print_classification_accuracy(out, y_data)
93 |
94 |
def loss():
    """Run the network on a shuffled copy of the data and return the cross-entropy loss.

    Inputs and labels are permuted with the same key. The prediction for each
    sample is the maximum of the network output over axis 1.
    """
    shuffle_key = rng.split_key()
    shuffled_x = bm.random.permutation(x_data, key=shuffle_key)
    shuffled_y = bm.random.permutation(y_data, key=shuffle_key)
    sim = bp.DSRunner(net, numpy_mon_after_run=False, progress_bar=False)
    outputs = sim.run(inputs=shuffled_x, reset_state=True)
    peak = bm.max(outputs, axis=1)
    return bp.losses.cross_entropy_loss(peak, shuffled_y)
103 |
104 |
# Gradient of ``loss`` w.r.t. the trainable variables; also returns the loss value.
grad = bm.grad(loss, grad_vars=net.train_vars().unique(), return_value=True)
# Adam optimizer acting on the same set of trainable variables.
optimizer = bp.optim.Adam(lr=2e-3, train_vars=net.train_vars().unique())
107 |
108 |
def train(_):
    """Perform one optimizer step and return the loss value (the argument is ignored)."""
    gradients, loss_value = grad()
    optimizer.update(gradients)
    return loss_value
113 |
114 |
# train the network
net.reset_state(num_sample)
train_losses = []
b = 100  # number of training steps executed per ``for_loop`` call
for i in range(0, 3000, b):
    t0 = time.time()
    # Run ``b`` consecutive training steps; ``ls`` collects their losses.
    ls = bm.for_loop(train, operands=bm.arange(i, i + b, 1))
    print(f'Train {i + b} epoch, loss = {jnp.mean(ls):.4f}, used time {time.time() - t0:.4f} s')
    train_losses.append(ls)

# visualize the training losses
plt.plot(bm.as_numpy(jnp.concatenate(train_losses)))
plt.xlabel("Epoch")
plt.ylabel("Training Loss")
plt.show()

# predict the output according to the input data
runner = bp.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V})
out = runner.run(inputs=x_data, reset_state=True)
plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike'))
plot_voltage_traces(out)
print_classification_accuracy(out, y_data)
137 |
--------------------------------------------------------------------------------
/brain_inspired_computing/SurrogateGrad_lif_fashion_mnist.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | """
4 | Reproduce the results of the ``spytorch`` tutorials 2 & 3:
5 |
6 | - https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial2.ipynb
7 | - https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial3.ipynb
8 |
9 | """
10 |
11 | import brainpy_datasets as bd
12 | import matplotlib.pyplot as plt
13 | import numpy as np
14 | from matplotlib.gridspec import GridSpec
15 |
16 | import brainpy as bp
17 | import brainpy.math as bm
18 |
19 | bm.set_environment(bm.training_mode)
20 |
21 |
class SNN(bp.DynamicalSystem):
    """
    Three-stage spiking neural network:

       i >> r >> o

    Neighbouring stages are linked by a linear projection followed by an
    exponential synapse model.

    Args:
        num_in: Number of input units.
        num_rec: Number of LIF neurons in the hidden layer.
        num_out: Number of read-out units.
    """

    def __init__(self, num_in, num_rec, num_out):
        super().__init__()

        # parameters
        self.num_in, self.num_rec, self.num_out = num_in, num_rec, num_out

        # synapse: i->r
        in_proj = bp.dnn.Linear(num_in, num_rec, W_initializer=bp.init.KaimingNormal(scale=2.))
        in_syn = bp.dyn.Expon(num_rec, tau=10.)
        self.i2r = bp.Sequential(in_proj, in_syn)
        # recurrent: r
        self.r = bp.dyn.Lif(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.)
        # synapse: r->o
        out_proj = bp.dnn.Linear(num_rec, num_out, W_initializer=bp.init.KaimingNormal(scale=2.))
        out_syn = bp.dyn.Expon(num_out, tau=10.)
        self.r2o = bp.Sequential(out_proj, out_syn)
        # output: o
        self.o = bp.dyn.Leaky(num_out, tau=5)

    def update(self, spike):
        """Feed ``spike`` through the four stages in order and return the read-out."""
        hidden_input = spike >> self.i2r
        hidden_spikes = hidden_input >> self.r
        readout_input = hidden_spikes >> self.r2o
        return readout_input >> self.o
56 |
57 |
def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5):
    """Plot the leading traces of ``mem`` on an axis-free grid sharing the y-axis.

    Args:
        mem: Traces to plot; ``mem[i]`` is drawn in the i-th panel.
        spk: Optional array; wherever it is positive the plotted value is
            replaced by ``spike_height``.
        dim: Grid layout passed to ``GridSpec``.
        spike_height: Value drawn at spike positions.
    """
    grid = GridSpec(*dim)
    # Multiplying by 1. yields a new array, so the caller's data is not modified.
    traces = 1. * mem
    if spk is not None:
        traces[spk > 0.0] = spike_height
    traces = bm.as_numpy(traces)
    first_axis = None
    for panel in range(np.prod(dim)):
        if first_axis is None:
            first_axis = axis = plt.subplot(grid[panel])
        else:
            axis = plt.subplot(grid[panel], sharey=first_axis)
        axis.plot(traces[panel])
        axis.axis("off")
    plt.tight_layout()
    plt.show()
73 |
74 |
def print_classification_accuracy(output, target):
    """Print the fraction of samples whose strongest read-out unit matches ``target``.

    The maximum is taken over axis 1 of ``output`` first, then the argmax over
    axis 1 of the result.
    """
    peak = bm.max(output, axis=1)  # max over time
    predicted = bm.argmax(peak, axis=1)  # argmax over output units
    print("Accuracy %.3f" % bm.mean(target == predicted))  # compare to labels
81 |
82 |
def current2firing_time(x, tau=20., thr=0.2, epsilon=1e-7):
    """Compute the first-spike latency for a current input ``x``,
    assuming the charge time of a current-based LIF neuron.

    Args:
      x -- The "current" values

    Keyword args:
      tau -- The membrane time constant of the LIF neuron to be charged
      thr -- The firing threshold value
      epsilon -- A generic (small) epsilon > 0; inputs are clipped to
                 ``[thr + epsilon, 1e9]`` so the logarithm stays finite

    Returns:
      Time to first spike for each "current" x, ``tau * log(x / (x - thr))``
    """
    clipped = np.clip(x, thr + epsilon, 1e9)
    ratio = clipped / (clipped - thr)
    return tau * np.log(ratio)
102 |
103 |
def sparse_data_generator(X, y, batch_size, nb_steps, nb_units, shuffle=True):
    """Yield ``(x_batch, y_batch)`` pairs of latency-coded spike inputs and labels.

    Every value of ``X`` is converted to a first-spike time with
    ``current2firing_time``; spikes whose time falls before ``nb_steps`` are
    written as ones into a zero array of shape
    ``(batch_size, nb_steps, nb_units)``.

    Args:
      X: The analog input data, indexed by sample.
         NOTE(review): each sample presumably holds ``nb_units`` values — confirm.
      y: The labels
      batch_size: Number of samples per batch; a trailing partial batch is dropped.
      nb_steps: Number of time steps of each generated batch.
      nb_units: Number of input units.
      shuffle: Shuffle the sample order (with ``np.random``) before batching.
    """

    labels_ = np.array(y, dtype=bm.int_)
    sample_index = np.arange(len(X))

    # compute discrete firing times
    tau_eff = 2. / bm.get_dt()
    unit_numbers = np.arange(nb_units)
    firing_times = np.array(current2firing_time(X, tau=tau_eff), dtype=bm.int_)

    if shuffle:
        np.random.shuffle(sample_index)

    for batch_no in range(len(X) // batch_size):
        lo = batch_size * batch_no
        batch_index = sample_index[lo:lo + batch_size]
        batch_ids, spike_times, spike_units = [], [], []
        for row, sample in enumerate(batch_index):
            sample_times = firing_times[sample]
            # Keep only the units that fire inside the simulated window.
            in_window = sample_times < nb_steps
            kept_times = sample_times[in_window]
            batch_ids.append(np.full(len(kept_times), row, dtype=bm.int_))
            spike_times.append(kept_times)
            spike_units.append(unit_numbers[in_window])
        batch_ids = np.concatenate(batch_ids).flatten()
        spike_times = np.concatenate(spike_times).flatten()
        spike_units = np.concatenate(spike_units).flatten()
        x_batch = bm.zeros((batch_size, nb_steps, nb_units))
        x_batch[batch_ids, spike_times, spike_units] = 1.
        yield x_batch, bm.asarray(labels_[batch_index])
144 |
145 |
def train(model, x_data, y_data, lr=1e-3, nb_epochs=10, batch_size=128, nb_steps=128, nb_inputs=28 * 28):
    """Fit ``model`` with BPTT and return the recorded ``'fit'`` metric history.

    Args:
        model: Network to train; its ``r.spike`` variable is monitored for the
            spike regularizers.
        x_data: Analog inputs handed to ``sparse_data_generator``.
        y_data: Labels handed to ``sparse_data_generator``.
        lr: Learning rate of the Adam optimizer.
        nb_epochs: Number of training epochs.
        batch_size: Number of samples per batch.
        nb_steps: Number of time steps of each generated batch.
        nb_inputs: Number of input units.

    Returns:
        ``trainer.get_hist_metric('fit')``.
    """

    def loss_fun(predicts, targets):
        predicts, mon = predicts
        # Here we set up our regularizer loss
        # The strength parameters here are merely a guess and
        # there should be ample room for improvement by
        # tuning these parameters.
        l1_loss = 1e-5 * bm.sum(mon['r.spike'])  # L1 loss on total number of spikes
        l2_loss = 1e-5 * bm.mean(bm.sum(bm.sum(mon['r.spike'], axis=0), axis=0) ** 2)  # L2 loss on spikes per neuron
        # predictions
        predicts = bm.max(predicts, axis=1)
        loss = bp.losses.cross_entropy_loss(predicts, targets)
        return loss + l2_loss + l1_loss

    trainer = bp.BPTT(
        model,
        loss_fun,
        optimizer=bp.optim.Adam(lr=lr),
        # Monitor the spikes of the model being trained, not of a global network.
        monitors={'r.spike': model.r.spike},
    )
    trainer.fit(lambda: sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs),
                num_epoch=nb_epochs)
    return trainer.get_hist_metric('fit')
169 |
170 |
def compute_classification_accuracy(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28):
    """Return the mean of the per-batch classification accuracies on the supplied data.

    Batches come from ``sparse_data_generator`` in unshuffled order; the
    prediction for each sample is the argmax of the output maximum over axis 1.
    """
    runner = bp.DSRunner(model, progress_bar=False)
    batches = sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False)
    batch_accuracies = []
    for x_local, y_local in batches:
        output = runner.predict(inputs=x_local, reset_state=True)
        peak = bm.max(output, 1)  # max over time
        predicted = bm.argmax(peak, 1)  # argmax over output units
        batch_accuracies.append(bm.mean(y_local == predicted))  # compare to labels
    return bm.mean(bm.asarray(batch_accuracies))
182 |
183 |
def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28):
    """Run ``model`` on the first unshuffled batch.

    Returns:
        A pair ``(output, spikes)``: the runner prediction for the batch and
        the monitored ``'r.spike'`` record of that run.
    """
    runner = bp.DSRunner(
        model,
        monitors={'r.spike': model.r.spike},
        progress_bar=False,
    )
    batches = sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False)
    # Only the inputs of the first batch are needed; the labels are dropped.
    inputs, _labels = next(batches)
    readout = runner.predict(inputs=inputs, reset_state=True)
    return readout, runner.mon.get('r.spike')
192 |
193 |
# One input unit per pixel of a 28 x 28 image.
num_input = 28 * 28
net = SNN(num_in=num_input, num_rec=100, num_out=10)

# load the dataset
# NOTE(review): `root` is a hard-coded Windows path; point it at a local data directory.
root = r"D:\data"
train_dataset = bd.vision.FashionMNIST(root, split='train', download=True)
test_dataset = bd.vision.FashionMNIST(root, split='test', download=True)

# Flatten every image into one row and divide the pixel values by 255.
x_train = np.array(train_dataset.data, dtype=bm.float_)
x_train = x_train.reshape(x_train.shape[0], -1) / 255
y_train = np.array(train_dataset.targets, dtype=bm.int_)
x_test = np.array(test_dataset.data, dtype=bm.float_)
x_test = x_test.reshape(x_test.shape[0], -1) / 255
y_test = np.array(test_dataset.targets, dtype=bm.int_)

# training
train_losses = train(net, x_train, y_train, lr=1e-3, nb_epochs=30, batch_size=256, nb_steps=100, nb_inputs=28 * 28)

# Plot the loss history returned by `train`.
plt.figure(figsize=(3.3, 2), dpi=150)
plt.plot(train_losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.show()

print("Training accuracy: %.3f" % (compute_classification_accuracy(net, x_train, y_train, batch_size=512)))
print("Test accuracy: %.3f" % (compute_classification_accuracy(net, x_test, y_test, batch_size=512)))

outs, spikes = get_mini_batch_results(net, x_train, y_train)
# Plot the network output traces of the first mini-batch.
fig = plt.figure(dpi=100)
plot_voltage_traces(outs)
plt.show()

# Plot the monitored 'r.spike' activity of the first `nb_plt` samples
# (presumably the recurrent/hidden layer -- confirm against the SNN definition).
nb_plt = 4
gs = GridSpec(1, nb_plt)
plt.figure(figsize=(7, 3), dpi=150)
for i in range(nb_plt):
    plt.subplot(gs[i])
    plt.imshow(bm.as_numpy(spikes[i]).T, cmap=plt.cm.gray_r, origin="lower")
    if i == 0:
        # Axis labels only on the first panel.
        plt.xlabel("Time")
        plt.ylabel("Units")
plt.tight_layout()
plt.show()
239 |
--------------------------------------------------------------------------------
/brain_inspired_computing/fashion_mnist_conv_lif.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 | import os
5 | import sys
6 | import time
7 |
8 | import brainpy as bp
9 | import brainpy.math as bm
10 | import brainpy_datasets as bd
11 | import jax
12 | import jax.numpy as jnp
13 | from jax import lax
14 |
15 | bm.set_environment(mode=bm.training_mode, dt=1.)
16 |
17 |
class ConvLIF(bp.DynamicalSystem):
    """Four-stage network whose every stage ends in a ``bp.dyn.Lif`` layer.

    Stages: conv + batch-norm at 28x28, max-pool + conv + batch-norm at 14x14,
    max-pool + flatten + dense, and a final dense layer with 10 outputs.

    Args:
        n_time: stored on the instance as ``self.n_time``.
        n_channel: number of convolution channels; also scales the dense sizes.
        tau: ``tau`` argument passed to every LIF layer.
    """

    def __init__(self, n_time: int, n_channel: int, tau: float = 5.):
        super().__init__()
        self.n_time = n_time

        # Settings shared by all LIF layers.
        neuron_kwargs = {
            'keep_size': True,
            'V_rest': 0.,
            'V_reset': 0.,
            'V_th': 1.,
            'tau': tau,
            'spk_fun': bm.surrogate.arctan,
        }

        conv_in = bp.layers.Conv2d(1, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None)
        norm_in = bp.layers.BatchNorm2d(n_channel, momentum=0.9)
        lif_in = bp.dyn.Lif((28, 28, n_channel), **neuron_kwargs)
        self.block1 = bp.Sequential(conv_in, norm_in, lif_in)

        pool_mid = bp.layers.MaxPool2d(2, 2)  # 14 * 14
        conv_mid = bp.layers.Conv2d(n_channel, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None)
        norm_mid = bp.layers.BatchNorm2d(n_channel, momentum=0.9)
        lif_mid = bp.dyn.Lif((14, 14, n_channel), **neuron_kwargs)
        self.block2 = bp.Sequential(pool_mid, conv_mid, norm_mid, lif_mid)

        pool_out = bp.layers.MaxPool2d(2, 2)  # 7 * 7
        flatten = bp.layers.Flatten()
        dense_hidden = bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4, b_initializer=None)
        lif_hidden = bp.dyn.Lif(4 * 4 * n_channel, **neuron_kwargs)
        self.block3 = bp.Sequential(pool_out, flatten, dense_hidden, lif_hidden)

        dense_readout = bp.layers.Dense(n_channel * 4 * 4, 10, b_initializer=None)
        lif_readout = bp.dyn.Lif(10, **neuron_kwargs)
        self.block4 = bp.Sequential(dense_readout, lif_readout)

    def update(self, x):
        """Pass ``x`` (documented as ``[B, H, W, C]``) through the four blocks in order."""
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        return x
54 |
55 |
class IFNode(bp.DynamicalSystem):
    """The Integrate-and-Fire neuron. The voltage of the IF neuron will
    not decay as that of the LIF neuron. The sub-threshold neural dynamics
    of it is as follows:

    .. math::
        V[t] = V[t-1] + X[t]

    Args:
        size: per-sample shape of the neuron group (any sequence of ints).
        v_threshold: value subtracted from ``V`` before the spike function.
        v_reset: voltage assigned to spiking units under ``'hard'`` reset.
        spike_fun: callable applied to ``V - v_threshold`` to produce spikes.
        mode: forwarded to ``bp.DynamicalSystem``; must be a training mode.
        reset_mode: ``'hard'`` (set to ``v_reset``) or ``'soft'`` (subtract
            ``v_threshold`` from spiking units).
    """

    def __init__(self, size: tuple, v_threshold: float = 1., v_reset: float = 0.,
                 spike_fun=bm.surrogate.arctan, mode=None, reset_mode='soft'):
        super().__init__(mode=mode)
        bp.check.is_subclass(self.mode, bm.TrainingMode)

        # The check accepts any sequence, but the shape arithmetic below
        # concatenates with tuples, so normalize once (a list would otherwise
        # raise ``TypeError`` in ``(1,) + size``).
        self.size = tuple(bp.check.is_sequence(size, elem_type=int, allow_none=False))
        self.reset_mode = bp.check.is_string(reset_mode, candidates=['hard', 'soft'])
        self.v_threshold = bp.check.is_float(v_threshold)
        self.v_reset = bp.check.is_float(v_reset)
        self.spike_fun = bp.check.is_callable(spike_fun)

        # variables: membrane potential with a leading batch axis of size 1
        self.V = bm.Variable(jnp.zeros((1,) + self.size, dtype=bm.float_), batch_axis=0)

    def reset_state(self, batch_size):
        """Reset the membrane potential to zeros of shape ``(batch_size, *size)``."""
        self.V.value = jnp.zeros((batch_size,) + self.size, dtype=bm.float_)

    def update(self, x):
        """Integrate input ``x``, emit spikes and apply the configured reset."""
        self.V.value += x
        spike = self.spike_fun(self.V - self.v_threshold)
        if self.reset_mode == 'hard':
            # Spiking units jump to v_reset; the others keep their voltage.
            one = lax.convert_element_type(1., bm.float_)
            self.V.value = self.v_reset * spike + (one - spike) * self.V
        else:
            # Soft reset: subtract the threshold from spiking units.
            self.V -= spike * self.v_threshold
        return spike
91 |
92 |
class ConvIF(bp.DynamicalSystem):
    """Same four-stage layout as the LIF variant, but built from ``IFNode`` layers.

    Args:
        n_time: stored on the instance as ``self.n_time``.
        n_channel: number of convolution channels; also scales the dense sizes.
    """

    def __init__(self, n_time: int, n_channel: int):
        super().__init__()
        self.n_time = n_time

        surrogate = bm.surrogate.arctan

        conv_in = bp.layers.Conv2d(1, n_channel, kernel_size=3, padding=(1, 1))
        norm_in = bp.layers.BatchNorm2d(n_channel, momentum=0.9)
        if_in = IFNode((28, 28, n_channel), spike_fun=surrogate)
        self.block1 = bp.Sequential(conv_in, norm_in, if_in)

        pool_mid = bp.layers.MaxPool([2, 2], 2, channel_axis=-1)  # 14 * 14
        conv_mid = bp.layers.Conv2d(n_channel, n_channel, kernel_size=3, padding=(1, 1))
        norm_mid = bp.layers.BatchNorm2d(n_channel, momentum=0.9)
        if_mid = IFNode((14, 14, n_channel), spike_fun=surrogate)
        self.block2 = bp.Sequential(pool_mid, conv_mid, norm_mid, if_mid)

        pool_out = bp.layers.MaxPool([2, 2], 2, channel_axis=-1)  # 7 * 7
        flatten = bp.layers.Flatten()
        dense_hidden = bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4)
        if_hidden = IFNode((4 * 4 * n_channel,), spike_fun=surrogate)
        self.block3 = bp.Sequential(pool_out, flatten, dense_hidden, if_hidden)

        dense_readout = bp.layers.Dense(n_channel * 4 * 4, 10)
        if_readout = IFNode((10,), spike_fun=surrogate)
        self.block4 = bp.Sequential(dense_readout, if_readout)

    def update(self, x):
        """Pass ``x`` (documented as ``[B, H, W, C]``) through the four blocks in order."""
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)
        return x
126 |
127 |
class TrainMNIST:
    """Bundle a network with an Adam optimizer, a gradient function and jitted
    train/predict steps.

    Args:
        net: the network to train; called as ``net(X)`` once per time step.
        n_time: number of time steps simulated for every batch.
    """

    def __init__(self, net, n_time):
        self.net = net
        self.n_time = n_time
        # NOTE: the learning-rate schedule is fixed here (ExponentialDecay(0.2, 1, 0.9999)).
        self.f_opt = bp.optim.Adam(bp.optim.ExponentialDecay(0.2, 1, 0.9999),
                                   train_vars=net.train_vars().unique())
        self.f_grad = bm.grad(self.loss, grad_vars=self.f_opt.vars_to_train,
                              has_aux=True, return_value=True)

    def inference(self, X, fit=False):
        """Reset the network and run it ``n_time`` steps on the same input ``X``.

        ``fit`` is written to the shared context on every step, so it decides
        whether layers that read the shared ``fit`` flag run in fitting mode.
        """

        def run_net(t):
            bp.share.save(t=t, fit=fit)
            return self.net(X)

        self.net.reset_state(X.shape[0])
        return bm.for_loop(run_net, jnp.arange(self.n_time, dtype=bm.float_), jit=False)

    def loss(self, X, Y, fit=False):
        """Return ``(mse_loss, n_correct)`` for one batch.

        The network response is the maximum over the time axis (axis 0) of the
        outputs; it is compared with the one-hot labels by mean squared error.
        """
        fr = bm.max(self.inference(X, fit), axis=0)
        ys_onehot = bm.one_hot(Y, 10, dtype=bm.float_)
        l = bp.losses.mean_squared_error(fr, ys_onehot)
        n = bm.sum(fr.argmax(1) == Y)
        return l, n

    @bm.cls_jit
    def f_predict(self, X, Y):
        """Evaluate one batch without fitting; returns ``(loss, n_correct)``."""
        return self.loss(X, Y, fit=False)

    @bm.cls_jit
    def f_train(self, X, Y):
        """Run one optimization step on a batch; returns ``(loss, n_correct)``."""
        bp.share.save(fit=True)
        # Pass fit=True explicitly: `inference` overwrites the shared `fit`
        # flag on every step with the value it receives, and `loss` defaults
        # to fit=False, which would silently train in inference mode.
        grads, l, n = self.f_grad(X, Y, True)
        self.f_opt.update(grads)
        return l, n
162 |
163 |
def main():
    """Train a convolutional spiking network on Fashion-MNIST from the command line.

    Parses the options, builds the requested network, trains for ``-n_epoch``
    epochs, checkpoints the best test accuracy and finally reloads that
    checkpoint to report its test accuracy.

    Raises:
        ValueError: if ``-model`` is neither ``'if'`` nor ``'lif'``.
    """
    parser = argparse.ArgumentParser(description='Classify Fashion-MNIST')
    parser.add_argument('-platform', default='cpu', help='platform')
    parser.add_argument('-model', default='lif', help='Neuron model to use')
    parser.add_argument('-n_time', default=4, type=int, help='simulating time-steps')
    parser.add_argument('-tau', default=5., type=float, help='LIF time constant')
    parser.add_argument('-batch', default=128, type=int, help='batch size')
    parser.add_argument('-n_channel', default=128, type=int, help='channels of ConvLIF')
    parser.add_argument('-n_epoch', default=64, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-data-dir', default='d:/data', type=str, help='root dir of Fashion-MNIST dataset')
    parser.add_argument('-out-dir', default='./logs', type=str, help='root dir for saving logs and checkpoint')
    # NOTE(review): `-lr` only ends up in the output directory name below; it
    # is never handed to the trainer in this function.
    parser.add_argument('-lr', default=0.1, type=float, help='learning rate')
    args = parser.parse_args()
    print(args)

    bm.set_platform(args.platform)

    # net
    if args.model == 'if':
        net = ConvIF(n_time=args.n_time, n_channel=args.n_channel)
        out_dir = os.path.join(args.out_dir,
                               f'{args.model}_T{args.n_time}_b{args.batch}'
                               f'_lr{args.lr}_c{args.n_channel}')
    elif args.model == 'lif':
        net = ConvLIF(n_time=args.n_time, n_channel=args.n_channel, tau=args.tau)
        out_dir = os.path.join(args.out_dir,
                               f'{args.model}_T{args.n_time}_b{args.batch}'
                               f'_lr{args.lr}_c{args.n_channel}_tau{args.tau}')
    else:
        raise ValueError(f"Unknown -model {args.model!r}; expected 'if' or 'lif'.")

    trainer = TrainMNIST(net, args.n_time)

    # dataset: scale pixels by 1/255 and reshape to (-1, 28, 28, 1)
    train_set = bd.vision.FashionMNIST(root=args.data_dir, split='train', download=True)
    test_set = bd.vision.FashionMNIST(root=args.data_dir, split='test', download=True)
    x_train = jnp.asarray(train_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1))
    y_train = jnp.asarray(train_set.targets, dtype=bm.int_)
    x_test = jnp.asarray(test_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1))
    y_test = jnp.asarray(test_set.targets, dtype=bm.int_)

    # Record the parsed arguments and the raw command line next to the logs.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
        args_txt.write('\n')
        args_txt.write(' '.join(sys.argv))

    max_test_acc = -1
    for epoch_i in range(0, args.n_epoch):
        start_time = time.time()

        # training pass over consecutive (unshuffled) mini-batches
        loss, train_acc = [], 0.
        for i in range(0, x_train.shape[0], args.batch):
            xs = x_train[i: i + args.batch]
            ys = y_train[i: i + args.batch]
            # NOTE(review): JIT is disabled around the training step -- confirm
            # whether this is a deliberate workaround or a debugging leftover.
            with jax.disable_jit():
                l, n = trainer.f_train(xs, ys)
            loss.append(l)
            train_acc += n
        train_acc /= x_train.shape[0]
        train_loss = jnp.mean(jnp.asarray(loss))
        trainer.f_opt.lr.step_epoch()

        # evaluation pass
        loss, test_acc = [], 0.
        for i in range(0, x_test.shape[0], args.batch):
            xs = x_test[i: i + args.batch]
            ys = y_test[i: i + args.batch]
            l, n = trainer.f_predict(xs, ys)
            loss.append(l)
            test_acc += n
        test_acc /= x_test.shape[0]
        test_loss = jnp.mean(jnp.asarray(loss))

        t = (time.time() - start_time) / 60
        print(f'epoch {epoch_i}, used {t:.3f} min, '
              f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, '
              f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}')

        # Keep only the checkpoint with the best test accuracy seen so far.
        if max_test_acc < test_acc:
            max_test_acc = test_acc
            states = {
                'net': net.state_dict(),
                'optimizer': trainer.f_opt.state_dict(),
                'epoch_i': epoch_i,
                'train_acc': train_acc,
                'test_acc': test_acc,
            }
            bp.checkpoints.save_pytree(os.path.join(out_dir, 'fmnist-conv-lif.bp'), states)

    # inference with the best checkpoint
    state_dict = bp.checkpoints.load_pytree(os.path.join(out_dir, 'fmnist-conv-lif.bp'))
    net.load_state_dict(state_dict['net'])
    correct_num = 0
    for i in range(0, x_test.shape[0], 512):
        xs = x_test[i: i + 512]
        ys = y_test[i: i + 512]
        correct_num += trainer.f_predict(xs, ys)[1]
    print('Max test accuracy: ', correct_num / x_test.shape[0])
261 |
262 |
# Run the training script only when this file is executed directly.
if __name__ == '__main__':
    main()
265 |
--------------------------------------------------------------------------------
/brain_inspired_computing/mnist_lif_readout.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import time
4 | import argparse
5 | import os.path
6 | import sys
7 |
8 | import brainpy_datasets as bd
9 |
10 | import jax.numpy as jnp
11 |
12 | import brainpy as bp
13 | import brainpy.math as bm
14 | import numpy as np
15 |
# Command-line options of the script.
parser = argparse.ArgumentParser(description='LIF MNIST Training')
parser.add_argument('-T', default=100, type=int, help='simulating time-steps')
parser.add_argument('-platform', default='cpu', help='device')
parser.add_argument('-batch', default=64, type=int, help='batch size')
parser.add_argument('-epochs', default=15, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint')
parser.add_argument('-lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron')
args = parser.parse_args()
print(args)

# Output directory named after the main hyper-parameters. `exist_ok=True`
# avoids the check-then-create race of testing `os.path.exists` first.
out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.batch}_lr{args.lr}')
os.makedirs(out_dir, exist_ok=True)
# Record the parsed arguments and the raw command line next to the logs.
with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
    args_txt.write(str(args))
    args_txt.write('\n')
    args_txt.write(' '.join(sys.argv))

bm.set_platform(args.platform)
bm.set_environment(mode=bm.training_mode, dt=1.)
37 |
38 |
class SNN(bp.DynamicalSystem):
    """A dense projection from 28 * 28 inputs to 10 LIF neurons.

    Args:
        tau: ``tau`` argument of the LIF layer.
    """

    def __init__(self, tau):
        super().__init__()
        self.l1 = bp.dnn.Dense(28 * 28, 10, b_initializer=None)
        self.l2 = bp.dyn.Lif(
            10,
            V_rest=0.,
            V_reset=0.,
            V_th=1.,
            tau=tau,
            spk_fun=bm.surrogate.arctan,
        )

    def update(self, x):
        """Feed ``x`` through the dense layer and then the LIF layer."""
        projected = self.l1(x)
        return self.l2(projected)
47 |
48 |
# The network trained by the rest of this script.
net = SNN(args.tau)

# data
# NOTE(review): the dataset root is a hard-coded Windows path; adjust as needed.
train_data = bd.vision.MNIST(r'D:/data', split='train', download=True)
test_data = bd.vision.MNIST(r'D:/data', split='test', download=True)
# Pixels are divided by 255 and every image is flattened to 28 * 28 values.
x_train = bm.asarray(train_data.data / 255, dtype=bm.float_).reshape(-1, 28 * 28)
y_train = bm.asarray(train_data.targets, dtype=bm.int_)
x_test = bm.asarray(test_data.data / 255, dtype=bm.float_).reshape(-1, 28 * 28)
y_test = bm.asarray(test_data.targets, dtype=bm.int_)

# loss
# Poisson encoder configured for inputs between 0 and 1; used by the loss
# function and by the final inference loop.
encoder = bp.encoding.PoissonEncoder(min_val=0., max_val=1.)
61 |
62 |
def loss_fun(xs, ys):
    """Return ``(mse_loss, n_correct)`` of the global ``net`` on one batch.

    The batch is encoded with the global ``encoder`` for ``args.T`` steps, the
    network is stepped over that sequence, and the time-averaged output is
    compared with the one-hot labels.
    """
    net.reset_state(batch_size=xs.shape[0])
    encoded = encoder.multi_steps(xs, n_time=args.T * bm.get_dt())
    # shared arguments for looping over time
    step_ids = np.arange(args.T)
    step_outputs = bm.for_loop(net.step_run, (step_ids, encoded))
    firing_rate = bm.mean(step_outputs, axis=0)
    targets = bm.one_hot(ys, 10, dtype=bm.float_)
    mse = bp.losses.mean_squared_error(firing_rate, targets)
    n_correct = bm.sum(firing_rate.argmax(1) == ys)
    return mse, n_correct
74 |
75 |
# gradient
# `has_aux=True` because `loss_fun` returns (loss, n_correct);
# `return_value=True` so the loss value is returned next to the gradients.
grad_fun = bm.grad(loss_fun, grad_vars=net.train_vars().unique(), has_aux=True, return_value=True)

# optimizer
# Adam over the same trainable variables, with the learning rate from `-lr`.
optimizer = bp.optim.Adam(lr=args.lr, train_vars=net.train_vars().unique())
81 |
82 |
# One jitted optimization step.
@bm.jit
def train(xs, ys):
    """Apply one optimizer update for the batch and return ``(loss, n_correct)``."""
    gradients, loss_value, n_correct = grad_fun(xs, ys)
    optimizer.update(gradients)
    return loss_value, n_correct
89 |
90 |
max_test_acc = 0.

# computing
for epoch_i in range(args.epochs):
    # Draw one permutation per epoch and apply it to both arrays so that
    # samples and labels stay aligned. The seed depends on the epoch, which
    # keeps runs reproducible while giving every epoch a different order.
    perm = np.random.RandomState(123 + epoch_i).permutation(x_train.shape[0])
    x_epoch = x_train[perm]
    y_epoch = y_train[perm]

    t0 = time.time()

    # training pass
    loss, train_acc = [], 0.
    for i in range(0, x_epoch.shape[0], args.batch):
        X = x_epoch[i: i + args.batch]
        Y = y_epoch[i: i + args.batch]
        l, correct_num = train(X, Y)
        loss.append(l)
        train_acc += correct_num
    train_acc /= x_train.shape[0]
    train_loss = jnp.mean(jnp.asarray(loss))
    optimizer.lr.step_epoch()

    # evaluation pass (calls the loss function directly, no parameter update)
    loss, test_acc = [], 0.
    for i in range(0, x_test.shape[0], args.batch):
        X = x_test[i: i + args.batch]
        Y = y_test[i: i + args.batch]
        l, correct_num = loss_fun(X, Y)
        loss.append(l)
        test_acc += correct_num
    test_acc /= x_test.shape[0]
    test_loss = jnp.mean(jnp.asarray(loss))

    t = (time.time() - t0) / 60
    print(f'epoch {epoch_i}, used {t:.3f} min, '
          f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, '
          f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}')

    # Keep only the checkpoint with the best test accuracy seen so far.
    if max_test_acc < test_acc:
        max_test_acc = test_acc
        states = {
            'net': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch_i': epoch_i,
            'train_acc': train_acc,
            'test_acc': test_acc,
        }
        bp.checkpoints.save_pytree(os.path.join(out_dir, 'mnist-lif.bp'), states)

# inference with the best checkpoint
state_dict = bp.checkpoints.load_pytree(os.path.join(out_dir, 'mnist-lif.bp'))
net.load_state_dict(state_dict['net'])

runner = bp.DSRunner(net, data_first_axis='T')
correct_num = 0
for i in range(0, x_test.shape[0], 512):
    X = encoder.multi_steps(x_test[i: i + 512], n_time=args.T * bm.get_dt())
    Y = y_test[i: i + 512]
    # Average the runner output over axis 0 (time, per data_first_axis='T').
    out_fr = bm.mean(runner.predict(inputs=X, reset_state=True), axis=0)
    correct_num += bm.sum(out_fr.argmax(1) == Y)

print('Max test accuracy: ', correct_num / x_test.shape[0])
149 |
--------------------------------------------------------------------------------
/brain_inspired_computing/spiking_FPTT/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/brain_inspired_computing/spiking_FPTT/README.md
--------------------------------------------------------------------------------
/brain_inspired_computing/spiking_FPTT/add_task.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import time
4 |
5 | import brainpy as bp
6 | import brainpy.math as bm
7 | import numba
8 | import numpy as np
9 |
10 |
@numba.njit(fastmath=True, nogil=True)
def _adding_problem_generator(X_num, X_mask, Y, N, seq_len=6, number_of_ones=2):
    # Fill X_mask and Y in place for N samples: mark randomly chosen positions
    # (floor(number_of_ones / 2) from the first half of the sequence,
    # ceil(number_of_ones / 2) from the second half) and add the matching
    # X_num values onto Y[i, 0].
    first_half = np.arange(math.floor(seq_len / 2))
    second_half = np.arange(math.ceil(seq_len / 2), seq_len)
    n_first = math.floor(number_of_ones / 2)
    n_second = math.ceil(number_of_ones / 2)
    for sample in numba.prange(N):
        picks_first = np.random.choice(first_half, size=n_first, replace=False)
        picks_second = np.random.choice(second_half, size=n_second, replace=False)
        for pos in picks_first:
            X_mask[pos, sample] = 1
            Y[sample, 0] += X_num[pos, sample, 0]
        for pos in picks_second:
            X_mask[pos, sample] = 1
            Y[sample, 0] += X_num[pos, sample, 0]
26 |
27 |
class AddTask:
    """Batch generator for the adding problem.

    Args:
        seq_len: number of time steps per sequence.
        high: upper bound of the uniform distribution the numbers are drawn from.
        number_of_ones: how many positions are marked in every sequence.
    """

    def __init__(self, seq_len=6, high=1, number_of_ones=2):
        self.seq_len = seq_len
        self.high = high
        self.number_of_ones = number_of_ones

    def __call__(self, batch_size):
        """Return ``(xs, ys)`` for ``batch_size`` fresh sequences.

        ``xs`` has shape ``(seq_len, batch_size, 2)`` (numbers and mask joined
        on the last axis); ``ys`` has shape ``(batch_size, 1)``.
        """
        x_num = np.random.uniform(low=0, high=self.high, size=(self.seq_len, batch_size, 1))
        x_mask = np.zeros((self.seq_len, batch_size, 1))
        # The generator accumulates the marked numbers into `ys` with `+=`, so
        # the targets must start at zero to equal the sum of the marked values
        # (starting from ones would shift every target by 1).
        ys = np.zeros((batch_size, 1))
        _adding_problem_generator(x_num, x_mask, ys, batch_size, self.seq_len, self.number_of_ones)
        xs = np.append(x_num, x_mask, axis=2)
        return xs, ys
41 |
42 |
class IF(bp.DynamicalSystemNS):
    """Spiking unit whose leak factor ``tau_m`` is supplied on every update.

    Args:
        size: size passed to ``bp.init.variable_`` for the state variables.
        V_th: value subtracted from the candidate potential before spiking.
        spike_fun: surrogate spike function applied to ``mem - V_th``.
    """

    def __init__(
        self,
        size,
        V_th=0.5,
        spike_fun: bm.surrogate.Surrogate = bm.surrogate.MultiGaussianGrad()
    ):
        super().__init__()

        self.size = size
        self.V_th = V_th
        self.spike_fun = spike_fun
        # Create the state variables according to the current mode.
        self.reset_state(self.mode)

    def reset_state(self, batch_size=1):
        """(Re)create zero-initialized ``V`` and ``spike`` variables."""
        self.V = bp.init.variable_(bm.zeros, self.size, batch_size)
        self.spike = bp.init.variable_(bm.zeros, self.size, batch_size)

    def update(self, x, tau_m):
        """Move ``V`` towards ``x`` by the factor ``tau_m``, spike and reset to zero."""
        candidate = self.V + (x - self.V) * tau_m
        self.spike.value = self.spike_fun(candidate - self.V_th)
        # Units that spiked are zeroed; the others keep the candidate value.
        self.V.value = (1 - self.spike) * candidate
63 |
64 |
class SNN(bp.DynamicalSystemNS):
    """Recurrent spiking network with an input-dependent time constant and a
    linear readout of the final membrane potential.

    Args:
        input_size: number of input features per time step.
        hidden_size: number of recurrent units.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        self.lin_inp = bp.layers.Linear(input_size + hidden_size, hidden_size)
        self.lin_tau = bp.layers.Linear(hidden_size + hidden_size, hidden_size)
        self.rnn = IF(hidden_size)
        self.out = bp.layers.Linear(hidden_size, 1, W_initializer=bp.init.XavierNormal())
        self.act = bp.layers.Sigmoid()

        self.loss_func = bp.losses.MSELoss()

    def _step(self, x):
        """Advance the recurrent layer by one time step with input ``x``."""
        recurrent_in = bm.cat((x, self.rnn.spike), dim=-1)
        drive = self.lin_inp(recurrent_in)
        tau_in = bm.cat((drive, self.rnn.V), dim=-1)
        tau = self.act(self.lin_tau(tau_in))
        self.rnn(drive, tau)

    def update(self, xs, y):
        """Run the whole sequence ``xs`` and return ``(loss, prediction)``.

        ``xs`` is iterated along its first axis (time); the prediction is read
        out from the membrane potential left after the last step.
        """
        bm.for_loop(self._step, xs)
        prediction = self.out(self.rnn.V).squeeze()
        target = y.squeeze()
        return self.loss_func(prediction, target), prediction
90 |
91 |
class FPTT_Trainer:
    """Trainer that adds a dynamic regularizer to the task loss.

    For every trainable variable it keeps two companion variables: ``sm``
    (starts as a copy of the parameter) and ``lm`` (starts at zero). Both are
    updated after every optimizer step in ``update_params``.

    Args:
        net: network called as ``net(xs, ys)``; the first returned item is the loss.
        opt: optimizer; the network's trainable variables are registered on it here.
        clip: value passed to ``bm.clip_by_norm`` for the gradients.
        alpha, beta, rho: coefficients used by ``update_params`` and ``dyn_loss``.
    """

    def __init__(
        self,
        net: bp.DynamicalSystemNS,
        opt: bp.optim.Optimizer,
        clip: float,
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.0,
    ):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.rho = rho
        self.clip = clip

        # objects
        self.net = net
        self.opt = opt
        opt.register_train_vars(net.train_vars().unique())

        # parameters: name -> (sm, lm) companion variables
        self.named_params = {}
        for name, param in self.opt.vars_to_train.items():
            sm = bm.Variable(param.clone())
            lm = bm.Variable(bm.zeros_like(param))
            self.named_params[name] = (sm, lm)

    def reset_params(self):
        """Overwrite every trainable variable with its ``sm`` companion."""
        for name, param in self.opt.vars_to_train.items():
            param.value = self.named_params[name][0].value

    def update_params(self):
        """Update ``lm`` and then ``sm`` in place from the current parameters.

        The order matters: ``sm`` is updated with the already-updated ``lm``.
        """
        for name, param in self.opt.vars_to_train.items():
            sm, lm = self.named_params[name]
            lm += (-self.alpha * (param - sm))
            sm *= (1.0 - self.beta)
            sm += (self.beta * param - (self.beta / self.alpha) * lm)

    def dyn_loss(self, lambd=1.):
        """Return the regularizer summed over all trainable variables.

        Per variable: ``(rho - 1) * sum(param * lm)`` plus
        ``lambd * 0.5 * alpha * sum((param - sm) ** 2)``.
        """
        regularization = 0.
        for name, param in self.opt.vars_to_train.items():
            sm, lm = self.named_params[name]
            regularization += (self.rho - 1.) * bm.sum(param * lm)
            regularization += lambd * 0.5 * self.alpha * bm.sum(bm.square(param - sm))
        return regularization

    def f_loss(self, xs, ys, progress):
        """Return ``(task_loss * progress + regularizer, (task_loss, regularizer))``."""
        l, _ = self.net(xs, ys)
        reg = self.dyn_loss()
        return l * progress + reg, (l, reg)

    @bm.cls_jit
    def predict(self, xs, ys):
        """Return the first item of ``net(xs, ys)`` (the loss) without any update."""
        return self.net(xs, ys)[0]

    @bm.cls_jit
    def fit(self, xs, ys, progress):
        """Do one optimizer step and return ``(task_loss, regularizer)``."""
        grads, (loss, reg) = bm.grad(self.f_loss, grad_vars=self.opt.vars_to_train, has_aux=True)(xs, ys, progress)
        grads = bm.clip_by_norm(grads, self.clip)
        self.opt.update(grads)
        # Refresh the companion variables after the parameters changed.
        self.update_params()
        return loss, reg
155 |
156 |
class BPTT_Trainer:
    """Plain gradient trainer: differentiate the network loss, clip, update.

    Args:
        net: network called as ``net(xs, ys)``; the first returned item is the loss.
        opt: optimizer; the network's trainable variables are registered on it here.
        clip: value passed to ``bm.clip_by_norm`` for the gradients.
    """

    def __init__(
        self,
        net: bp.DynamicalSystemNS,
        opt: bp.optim.Optimizer,
        clip: float,
    ):
        super().__init__()
        self.clip = clip

        # objects
        self.net = net
        self.opt = opt
        opt.register_train_vars(net.train_vars().unique())

    def f_loss(self, xs, ys):
        """Return only the loss part of ``net(xs, ys)``."""
        loss_value, _ = self.net(xs, ys)
        return loss_value

    @bm.cls_jit
    def predict(self, xs, ys):
        """Return the first item of ``net(xs, ys)`` (the loss) without any update."""
        return self.net(xs, ys)[0]

    @bm.cls_jit
    def fit(self, xs, ys):
        """Do one optimizer step on the batch and return the loss."""
        grad_fn = bm.grad(self.f_loss, grad_vars=self.opt.vars_to_train, return_value=True)
        gradients, loss_value = grad_fn(xs, ys)
        self.opt.update(bm.clip_by_norm(gradients, self.clip))
        return loss_value
186 |
187 |
def fptt_training():
    """Train the adding-task network with the FPTT trainer from the command line.

    Every epoch draws a fresh batch, splits the sequence into ``--parts``
    consecutive chunks and performs one trainer step per chunk; a second fresh
    batch is then used to report a test loss.

    Raises:
        ValueError: if ``--optim`` is neither ``'adam'`` nor ``'sgd'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--alpha', type=float, default=.1, help='Alpha')
    parser.add_argument('--beta', type=float, default=0.5, help='Beta')
    parser.add_argument('--rho', type=float, default=0.0, help='Rho')
    parser.add_argument('--bptt', type=int, default=300, help='sequence length')
    parser.add_argument('--nhid', type=int, default=128, help='number of hidden units per layer')
    parser.add_argument('--lr', type=float, default=3e-3, help='initial learning rate (default: 3e-3)')
    parser.add_argument('--clip', type=float, default=1., help='gradient clipping')
    parser.add_argument('--epochs', type=int, default=1000, help='upper epoch limit (default: 1000)')
    parser.add_argument('--parts', type=int, default=10, help='Parts to split the sequential input into (default: 10)')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='batch size')
    parser.add_argument('--save', type=str, default='', help='path of model to save')
    parser.add_argument('--load', type=str, default='', help='path of model to load')
    parser.add_argument('--wdecay', type=float, default=0.0, help='weight decay')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--optim', type=str, default='adam', help='optimizer to use')
    args = parser.parse_args()

    bm.random.seed(args.seed)

    # model
    with bm.environment(mode=bm.TrainingMode(batch_size=args.batch_size)):
        model = SNN(2, args.nhid)

    # dataset
    dataset = AddTask(args.bptt, number_of_ones=2)

    # optimizer
    if args.optim == 'adam':
        optimizer = bp.optim.Adam(lr=args.lr, weight_decay=args.wdecay)
    elif args.optim == 'sgd':
        optimizer = bp.optim.SGD(lr=args.lr, weight_decay=args.wdecay)
    else:
        raise ValueError(f"Unknown --optim {args.optim!r}; expected 'adam' or 'sgd'.")

    # trainer
    trainer = FPTT_Trainer(model, optimizer, args.clip, alpha=args.alpha, beta=args.beta, rho=args.rho)

    # loading
    if args.load:
        states = bp.checkpoints.load_pytree(args.load)
        model.load_state_dict(states['model'])
        optimizer.load_state_dict(states['opt'])

    # training
    # NOTE: integer division -- if --bptt is not a multiple of --parts the
    # trailing time steps of every sequence are never used for fitting.
    step = args.bptt // args.parts
    for epoch in range(1, args.epochs + 1):
        model.reset_state(args.batch_size)

        # fitting: one trainer step per chunk, weighted by the chunk's progress
        s_t = time.time()
        x, y = dataset(args.batch_size)
        losses, regs = [], []
        for p in range(0, args.parts):
            start = p * step
            l, r = trainer.fit(x[start: start + step], y, (p + 1) / args.parts)
            losses.append(l.item())
            regs.append(r.item())

        # prediction on a freshly drawn batch
        x, y = dataset(args.batch_size)
        loss_act = trainer.predict(x, y).item()
        print(f'Epoch {epoch}, '
              f'time {time.time() - s_t:.4f} s, '
              f'train loss {np.mean(np.array(losses)):.4f}, '
              f'train reg {np.mean(np.array(regs)):.4f}, '
              f'test loss {loss_act:.4f}. ')
        trainer.reset_params()
        trainer.opt.lr.step_epoch()
258 |
259 |
def bptt_training():
    """Train the adding-task network with the BPTT trainer from the command line.

    Every epoch draws one fresh batch for a single trainer step over the whole
    sequence and a second fresh batch to report a test loss.

    Raises:
        ValueError: if ``--optim`` is neither ``'adam'`` nor ``'sgd'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--bptt', type=int, default=300, help='sequence length')
    parser.add_argument('--nhid', type=int, default=128, help='number of hidden units per layer')
    parser.add_argument('--lr', type=float, default=3e-3, help='initial learning rate (default: 3e-3)')
    parser.add_argument('--clip', type=float, default=1., help='gradient clipping')
    parser.add_argument('--epochs', type=int, default=1000, help='upper epoch limit (default: 1000)')
    parser.add_argument('--parts', type=int, default=10, help='Parts to split the sequential input into (default: 10)')
    parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='batch size')
    parser.add_argument('--save', type=str, default='', help='path of model to save')
    parser.add_argument('--load', type=str, default='', help='path of model to load')
    parser.add_argument('--wdecay', type=float, default=0.0, help='weight decay')
    parser.add_argument('--seed', type=int, default=1111, help='random seed')
    parser.add_argument('--optim', type=str, default='adam', help='optimizer to use')
    args = parser.parse_args()

    bm.random.seed(args.seed)

    # model
    with bm.environment(mode=bm.TrainingMode(batch_size=args.batch_size)):
        model = SNN(2, args.nhid)

    # dataset
    dataset = AddTask(args.bptt, number_of_ones=2)

    # optimizer
    if args.optim == 'adam':
        optimizer = bp.optim.Adam(lr=args.lr, weight_decay=args.wdecay)
    elif args.optim == 'sgd':
        optimizer = bp.optim.SGD(lr=args.lr, weight_decay=args.wdecay)
    else:
        raise ValueError(f"Unknown --optim {args.optim!r}; expected 'adam' or 'sgd'.")

    # trainer
    trainer = BPTT_Trainer(model, optimizer, args.clip)

    # loading
    if args.load:
        states = bp.checkpoints.load_pytree(args.load)
        model.load_state_dict(states['model'])
        optimizer.load_state_dict(states['opt'])

    # training
    for epoch in range(1, args.epochs + 1):
        model.reset_state(args.batch_size)

        # fitting on the whole sequence at once
        s_t = time.time()
        x, y = dataset(args.batch_size)
        loss = trainer.fit(x, y).item()

        # prediction on a freshly drawn batch
        x, y = dataset(args.batch_size)
        loss_act = trainer.predict(x, y).item()

        print(f'Epoch {epoch}, '
              f'time {time.time() - s_t:.4f} s, '
              f'train loss {loss:.4f}, '
              f'test loss {loss_act:.4f}. ')
        trainer.opt.lr.step_epoch()
320 |
321 |
# Entry point: BPTT training runs by default; swap the commented call to run
# the FPTT variant instead.
if __name__ == '__main__':
    # fptt_training()
    bptt_training()
325 |
326 |
--------------------------------------------------------------------------------
/brain_inspired_computing/spiking_FPTT/mnist_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import numpy as np
4 |
5 | import brainpy as bp
6 | import brainpy.math as bm
7 | import brainpy_datasets as bdata
8 | from functools import partial
9 |
10 |
class SigmoidBeta(bp.DynamicalSystemNS):
    """Sigmoid with a slope factor: ``sigmoid(alpha * x)``.

    Args:
        alpha: slope factor; ``None`` is treated as ``1.``.
        is_train: when true, ``alpha`` is wrapped in a ``bm.TrainVar``.
    """

    def __init__(self, alpha=1., is_train=False):
        super().__init__()
        # create a tensor out of alpha (falling back to 1. for None)
        slope = bm.asarray(1. if alpha is None else alpha)
        self.alpha = bm.TrainVar(slope) if is_train else slope

    def update(self, x):
        """Return ``sigmoid(alpha * x)``."""
        scaled = self.alpha * x
        return bm.sigmoid(scaled)
23 |
24 |
class LTC_LIF(bp.DynamicalSystemNS):
    """Recurrent spiking layer with learned gates for the membrane and for an
    adaptive threshold.

    Args:
        input_size: number of input features.
        hidden_size: number of units in the layer.
        beta: weight of the adaptation variable in the firing threshold.
        b0: constant part of the firing threshold.
        spike_fun: surrogate spike function applied to ``mem - threshold``.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        beta=1.8,
        b0=0.1,
        spike_fun: bm.surrogate.Surrogate = bm.surrogate.MultiGaussianGrad()
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.beta = beta
        self.b0 = b0
        self.spike_fun = spike_fun

        gate_in = hidden_size + hidden_size
        self.lin = bp.layers.Linear(input_size, hidden_size, W_initializer=bp.init.XavierNormal())
        self.lr = bp.layers.Linear(hidden_size, hidden_size, W_initializer=bp.init.Orthogonal())
        self.tauM = bp.layers.Linear(gate_in, hidden_size, W_initializer=bp.init.XavierNormal())
        self.tauAdp = bp.layers.Linear(gate_in, hidden_size, W_initializer=bp.init.XavierNormal())
        self.actM = SigmoidBeta(is_train=True)
        self.actAdp = SigmoidBeta(is_train=True)

        # Create the state variables according to the current mode.
        self.reset_state(self.mode)

    def reset_state(self, batch_size=1):
        """(Re)create zero-initialized ``V``, ``b`` and ``spike`` variables."""
        self.V = bp.init.variable_(bm.zeros, self.hidden_size, batch_size)
        self.b = bp.init.variable_(bm.zeros, self.hidden_size, batch_size)
        self.spike = bp.init.variable_(bm.zeros, self.hidden_size, batch_size)

    def update(self, x):
        """Advance the layer by one step with input ``x`` and return the spikes."""
        # Feed-forward drive plus recurrent drive from the previous spikes.
        drive = self.lin(x) + self.lr(self.spike)
        gate_m = self.actM(self.tauM(bm.cat((drive, self.V), dim=-1)))
        gate_adp = self.actAdp(self.tauAdp(bm.cat((drive, self.b), dim=-1)))

        # Adaptation variable: gated mix of its old value and the last spikes.
        adaptation = gate_adp * self.b + (1 - gate_adp) * self.spike
        self.b.value = adaptation
        threshold = self.b0 + self.beta * adaptation

        # Candidate membrane potential, then spike and reset spiking units to zero.
        candidate = self.V + (-self.V + drive) * gate_m
        fired = self.spike_fun(candidate - threshold)
        self.V.value = (1 - fired) * candidate
        self.spike.value = fired
        return fired
66 |
67 |
class ReadoutLIF(bp.DynamicalSystemNS):
    """Non-spiking readout layer that leaks towards its encoded input.

    The membrane is updated as ``(1 - tauM) * V + tauM * encoding`` where
    ``tauM`` is a sigmoid gate computed from the encoding and the membrane.

    Args:
        input_size: Feature size of the input.
        hidden_size: Number of readout units.
        beta: Stored on the instance; not used by ``update``.
        b0: Stored on the instance; not used by ``update``.
    """

    def __init__(
        self, input_size, hidden_size, beta=1.8, b0=0.1,
    ):
        super().__init__()

        self.size = hidden_size
        self.beta = beta
        self.b0 = b0

        self.lin = bp.layers.Linear(
            input_size, hidden_size, W_initializer=bp.init.XavierNormal()
        )
        # The gate reads the encoding concatenated with the membrane.
        self.tauM = bp.layers.Linear(
            hidden_size + hidden_size, hidden_size, W_initializer=bp.init.XavierNormal()
        )
        self.actM = SigmoidBeta(is_train=True)

        self.reset_state(self.mode)

    def reset_state(self, batch_size=1):
        """Re-create the membrane variable ``V`` as zeros."""
        self.V = bp.init.variable_(bm.zeros, self.size, batch_size)

    def update(self, x):
        """Advance one step and return the new membrane values."""
        drive = self.lin(x)
        gate = self.actM(self.tauM(bm.cat((drive, self.V), dim=-1)))
        membrane = (1 - gate) * self.V + gate * drive
        self.V.value = membrane
        return membrane
95 |
96 |
class SNN(bp.DynamicalSystemNS):
    """Two ``LTC_LIF`` layers followed by a ``ReadoutLIF`` layer.

    Args:
        input_size: Feature size fed to the first layer.
        hidden_size: Number of units in both spiking layers.
        output_size: Number of readout units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.layer1 = LTC_LIF(input_size, hidden_size)
        self.layer2 = LTC_LIF(hidden_size, hidden_size)
        self.layer3 = ReadoutLIF(hidden_size, output_size)

        # Running sum of the mean activity of both spiking layers.
        self.fr = bm.Variable(bm.asarray(0.))

    def update(self, x):
        """Run one step through all layers and return the readout values."""
        # Adds a feature axis at position 1.
        # assumes x is (batch,) so the result is (batch, 1) -- TODO confirm
        inputs = bm.expand_dims(x, axis=1)
        hidden_1 = self.layer1(inputs)
        hidden_2 = self.layer2(hidden_1)
        readout = self.layer3(hidden_2)
        self.fr += (hidden_1.mean() + hidden_2.mean())
        return readout
118 |
119 |
class FPTT_Trainer:
    """Per-time-step trainer with a dynamic regulariser on the parameters.

    For every time step of a sequence the trainer computes
    ``progress * cross_entropy + dyn_loss()``, takes its gradient with respect
    to the optimizer's trainable variables, optionally clips it, applies the
    optimizer update and then refreshes the per-parameter running statistics.

    Args:
        net: Model called as ``net(x)`` on the input of each time step.
        optimizer: Optimizer; the model's trainable variables are registered
            with it here.
        debias: Selects the alternative statistics update / regulariser
            (a third statistic ``dm`` is kept per parameter).
        clip: Gradient-norm clipping threshold; ``<= 0`` disables clipping.
        alpha: Weight of the quadratic regulariser and of the ``lm`` update.
        beta: Mixing factor of the running average ``sm`` (non-debias mode).
        rho: Coefficient used in the linear regulariser terms.
    """

    def __init__(
        self,
        net: bp.DynamicalSystem,
        optimizer: bp.optim.Optimizer,

        debias: bool = False,
        clip: float = 0.,
        alpha: float = 0.1,
        beta: float = 0.5,
        rho: float = 0.,
    ):
        self.clip = clip
        self.alpha = alpha
        self.beta = beta
        self.rho = rho
        self.debias = debias

        self.optimizer = optimizer
        self.net = net
        optimizer.register_train_vars(net.train_vars().unique())

        # Per-parameter statistics: (sm, lm) or, with debias, (sm, lm, dm).
        #   sm: starts as a copy of the parameter, lm / dm: start at zero.
        self.named_params = {}
        for name, param in optimizer.vars_to_train.items():
            stats = [
                bm.Variable(param.clone()),
                bm.Variable(bm.zeros_like(param)),
            ]
            if debias:
                stats.append(bm.Variable(bm.zeros_like(param)))
            self.named_params[name] = tuple(stats)

    def reset_params(self):
        """Copy ``sm`` back into every parameter (no-op in debias mode)."""
        if self.debias:
            return
        for name, param in self.optimizer.vars_to_train.items():
            param.value = self.named_params[name][0].value

    def update_params(self, epoch):
        """Refresh the running statistics after an optimizer step."""
        for name, param in self.optimizer.vars_to_train.items():
            stats = self.named_params[name]
            if self.debias:
                sm, lm, dm = stats
                # Mixing factor shrinks as the epoch index grows.
                mix = (1. / (1. + epoch))
                sm *= (1.0 - mix)
                sm += (mix * param)
                dm *= (1. - mix)
                dm += (mix * lm)
                continue

            sm, lm = stats
            lm += (-self.alpha * (param - sm))
            sm *= (1.0 - self.beta)
            sm += (self.beta * param - (self.beta / self.alpha) * lm)

    def dyn_loss(self, lambd=1.):
        """Return the regulariser summed over all trainable parameters."""
        reg = 0.
        for name, param in self.optimizer.vars_to_train.items():
            stats = self.named_params[name]
            # The linear term on ``lm`` is shared by both modes.
            reg += (self.rho - 1.) * bm.sum(param * stats[1])
            if self.debias:
                reg += (1. - self.rho) * bm.sum(param * stats[2])
            else:
                reg += lambd * 0.5 * self.alpha * bm.sum(bm.square(param - stats[0]))
        return reg

    def _loss(self, x, y, progress):
        """Return ``(total, (classification_loss, regulariser))`` for one step."""
        logits = self.net(x)
        clf = progress * bp.losses.cross_entropy_loss(logits, y, reduction='mean')
        penalty = self.dyn_loss()
        return clf + penalty, (clf, penalty)

    def _train(self, x, progress, y, epoch):
        """Do one gradient step for a single time step of the sequence."""
        grad_fn = bm.grad(self._loss, grad_vars=self.optimizer.vars_to_train, has_aux=True)
        grads, (loss, reg) = grad_fn(x, y, progress)
        if self.clip > 0.:
            grads = bm.clip_by_norm(grads, self.clip)
        self.optimizer.update(grads)
        self.update_params(epoch)
        return loss, reg

    @bm.cls_jit
    def fit(self, xs, ys, epoch):  # xs: (num_time, num_batch)
        """Train on one batch of sequences; return per-step losses and regs."""
        # The classification loss weight ramps linearly from 0 to 1 over time.
        progresses = bm.linspace(0., 1., xs.shape[0])
        step = partial(self._train, epoch=epoch, y=ys)
        losses, regs = bm.for_loop(step, (xs, progresses), remat=True)
        return losses, regs
207 |
208 |
# Command-line options for the training script.  Help texts state the actual
# defaults so that ``--help`` does not contradict the code.
parser = argparse.ArgumentParser()
parser.add_argument('--alpha', type=float, default=0.1, help='Alpha')
parser.add_argument('--beta', type=float, default=0.5, help='Beta')
parser.add_argument('--rho', type=float, default=0.0, help='Rho')
parser.add_argument('--debias', action='store_true', help='FedDyn debias algorithm')

parser.add_argument('--bptt', type=int, default=300, help='sequence length')
parser.add_argument('--nhid', type=int, default=256, help='number of hidden units per layer')
parser.add_argument('--lr', type=float, default=5e-3, help='initial learning rate (default: 5e-3)')
parser.add_argument('--clip', type=float, default=1., help='gradient clipping')

parser.add_argument('--epochs', type=int, default=250, help='upper epoch limit (default: 250)')
parser.add_argument('--batch_size', type=int, default=512, metavar='N', help='batch size')

parser.add_argument('--wdecay', type=float, default=0., help='weight decay')
parser.add_argument('--optim', type=str, default='adam', help='optimizer to use')
parser.add_argument('--when', nargs='+', type=int, default=[10, 30, 50, 75, 90], help='When to decay the learning rate')
parser.add_argument('--load', type=str, default='', help='path to load the model')
parser.add_argument('--save', type=str, default='./models/', help='path to save the model')
parser.add_argument('--permute', action='store_true', help='use permuted dataset (default: False)')
args = parser.parse_args()
230 |
# Switch brainpy to training mode with the requested batch size.
# NOTE(review): presumably this has to run before the model is built, because
# the layers size their state from ``self.mode`` -- confirm.
bm.set(mode=bm.TrainingMode(args.batch_size))  # important

# datasets
n_classes = 10
# The dataset object gets its own name so that it is not confused with the
# ``train_data`` batch generator defined below.
# NOTE(review): the dataset root is a machine-specific path.
mnist_train = bdata.vision.MNIST(r'/mnt/d/data/', download=True, split='train')
# Scale pixels by 1/255 and flatten every image to a 784-long sequence.
x_train = (mnist_train.data / 255.).reshape(-1, 28 * 28)
y_train = mnist_train.targets
238 |
239 |
def train_data():
    """Yield shuffled mini-batches as ``(inputs, targets)`` pairs.

    The inputs are transposed, so each yielded array has the 784 pixel
    positions on axis 0 and the samples of the batch on axis 1.  The last
    batch may hold fewer than ``args.batch_size`` samples.
    """
    num_samples = len(x_train)
    order = np.random.permutation(num_samples)
    step = args.batch_size
    for start in range(0, num_samples, step):
        batch = order[start: start + step]
        yield x_train[batch].T, y_train[batch]
245 |
246 |
# model: one input feature per time step, ``args.nhid`` hidden units.
model = SNN(1, args.nhid, n_classes)

# optimizer: learning rate is multiplied by 0.1 at every milestone in --when.
lr = bp.optim.MultiStepLR(args.lr, args.when, gamma=0.1)
if args.optim == 'adam':
    optimizer = bp.optim.Adam(lr=lr, weight_decay=args.wdecay)
elif args.optim == 'sgd':
    optimizer = bp.optim.SGD(lr=lr, weight_decay=args.wdecay)
else:
    # Tell the user which value was rejected and what is accepted.
    raise ValueError(
        f"Unsupported --optim value {args.optim!r}; expected 'adam' or 'sgd'."
    )

# trainer
trainer = FPTT_Trainer(model,
                       optimizer,
                       debias=args.debias,
                       clip=args.clip,
                       alpha=args.alpha,
                       beta=args.beta,
                       rho=args.rho)
267 |
# training
for epoch in range(1, args.epochs + 1):
    num_data = 0
    for data, target in train_data():
        t0 = time.time()
        # ``data`` is time-major, (num_time, num_batch): axis 0 holds the 784
        # pixel steps, axis 1 the samples.  The batch size therefore comes
        # from axis 1, both for sizing the network state and for counting
        # how many samples have been processed.
        batch_size = data.shape[1]
        model.reset_state(batch_size)
        losses, regs = trainer.fit(data, target, epoch)
        # Sum the per-time-step values returned by the trainer.
        total_reg_loss = regs.sum().item()
        total_clf_loss = losses.sum().item()

        num_data += batch_size
        print(
            f'Epoch {epoch} [{num_data}/{len(x_train)}]\t'
            f'time: {time.time() - t0:.4f}s\tLoss: {total_clf_loss:.6f}\t'
            f'Reg: {total_reg_loss:.6f}'
        )
284 |
--------------------------------------------------------------------------------
/conf.py:
--------------------------------------------------------------------------------
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#

# -- Project information -----------------------------------------------------

project = 'BrainPy Examples'
copyright = '2023, BrainPy team'
author = 'BrainPy team'

# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autosummary',
    'sphinx.ext.intersphinx',
    'sphinx.ext.mathjax',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'nbsphinx',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']


autosummary_generate = True
# The option name must be spelled ``napoleon_...``; a misspelled name is not
# read by the napoleon extension and therefore has no effect.
napoleon_use_rtype = False


# Execute notebooks before conversion: 'always', 'never', 'auto' (default).
# Notebooks are never executed during the build; they are converted as stored.
nbsphinx_execute = 'never'


# -- Options for HTML output -------------------------------------------------
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
source_suffix = '.rst'

# The master toctree document.
# NOTE(review): Sphinx documents ``root_doc`` (legacy name ``master_doc``);
# ``main_doc`` is kept unchanged here -- confirm whether it is still needed.
main_doc = 'index'
master_doc = 'index'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']


# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
68 |
69 |
70 |
--------------------------------------------------------------------------------
/decision_making/Wang_2002_decision_making_spiking.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "afe4708f",
6 | "metadata": {},
7 | "source": [
8 | "# *(Wang, 2002)* Decision making spiking model\n",
9 | "\n",
10 | "[](https://colab.research.google.com/github/brainpy/examples/blob/main/decision_making/Wang_2002_decision_making_spiking.ipynb)\n",
11 | "[](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/decision_making/Wang_2002_decision_making_spiking.ipynb)"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "id": "83acef7d",
17 | "metadata": {},
18 | "source": [
19 | "Implementation of the paper: *Wang, Xiao-Jing. \"Probabilistic decision making by slow reverberation in cortical circuits.\" Neuron 36.5 (2002): 955-968.*\n",
20 | "\n",
21 | "- Author : Chaoming Wang (chao.brain@qq.com)"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "id": "88e26125",
27 | "metadata": {},
28 | "source": [
29 | "Please refer to github example folder:\n",
30 | "\n",
31 | "- https://github.com/brainpy/examples/blob/master/decision_making/Wang_2002_decision_making_spiking.py"
32 | ]
33 | }
34 | ],
35 | "metadata": {
36 | "jupytext": {
37 | "encoding": "# -*- coding: utf-8 -*-",
38 | "formats": "ipynb,py:percent"
39 | },
40 | "kernelspec": {
41 | "display_name": "brainpy",
42 | "language": "python",
43 | "name": "brainpy"
44 | },
45 | "language_info": {
46 | "codemirror_mode": {
47 | "name": "ipython",
48 | "version": 3
49 | },
50 | "file_extension": ".py",
51 | "mimetype": "text/x-python",
52 | "name": "python",
53 | "nbconvert_exporter": "python",
54 | "pygments_lexer": "ipython3",
55 | "version": "3.9.12"
56 | },
57 | "latex_envs": {
58 | "LaTeX_envs_menu_present": true,
59 | "autoclose": false,
60 | "autocomplete": true,
61 | "bibliofile": "biblio.bib",
62 | "cite_by": "apalike",
63 | "current_citInitial": 1,
64 | "eqLabelWithNumbers": true,
65 | "eqNumInitial": 1,
66 | "hotkeys": {
67 | "equation": "Ctrl-E",
68 | "itemize": "Ctrl-I"
69 | },
70 | "labels_anchors": false,
71 | "latex_user_defs": false,
72 | "report_style_numbering": false,
73 | "user_envs_cfg": false
74 | },
75 | "toc": {
76 | "base_numbering": 1,
77 | "nav_menu": {},
78 | "number_sections": true,
79 | "sideBar": true,
80 | "skip_h1_title": false,
81 | "title_cell": "Table of Contents",
82 | "title_sidebar": "Contents",
83 | "toc_cell": false,
84 | "toc_position": {},
85 | "toc_section_display": true,
86 | "toc_window_display": true
87 | }
88 | },
89 | "nbformat": 4,
90 | "nbformat_minor": 5
91 | }
92 |
--------------------------------------------------------------------------------
/decision_making/Wang_2002_decision_making_spiking.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | import brainpy as bp
5 | import brainpy.math as bm
6 | import matplotlib.pyplot as plt
7 |
8 | print(bp.__version__)
9 |
10 | # bm.set_platform('cpu')
11 |
12 |
class PoissonStim(bp.NeuGroup):
    """Random spike source whose rate is redrawn during the stimulus window.

    While the simulation time lies strictly between ``pre_stimulus_period``
    and ``pre_stimulus_period + stimulus_period`` (module-level values), a new
    rate is drawn from ``normal(freq_mean, freq_var)`` whenever at least
    ``t_interval`` has passed since the last draw.  Outside that window the
    rate is zero.

    Args:
        size: Number of units in the group.
        freq_mean: First argument passed to the normal draw of the rate.
        freq_var: Second argument passed to the normal draw of the rate.
        t_interval: Minimum time between two rate draws.
    """

    @staticmethod
    def _long_ago(shape):
        """Initial 'last change' time: far enough back to force a first draw."""
        return bm.ones(shape) * -1e7

    @staticmethod
    def _no_spikes(shape):
        """Boolean array of the given shape with no unit spiking."""
        return bm.zeros(shape, dtype=bool)

    def __init__(self, size, freq_mean, freq_var, t_interval):
        super().__init__(size=size)

        # parameters
        self.freq_mean = freq_mean
        self.freq_var = freq_var
        self.t_interval = t_interval

        # variables
        self.freq = bp.init.variable_(bm.zeros, 1, self.mode)
        self.freq_t_last_change = bp.init.variable_(self._long_ago, 1, self.mode)
        self.spike = bp.init.variable_(self._no_spikes, self.varshape, self.mode)
        self.rng = bm.random.RandomState()

    def reset_state(self, batch_size=None):
        """Reset rate, last-change time and spikes to their initial values."""
        self.freq.value = bp.init.variable_(bm.zeros, 1, batch_size)
        self.freq_t_last_change.value = bp.init.variable_(self._long_ago, 1, batch_size)
        self.spike.value = bp.init.variable_(self._no_spikes, self.varshape, batch_size)

    def update(self):
        """Redraw the rate if due, then sample this step's spikes."""
        t = bp.share['t']
        dt = bp.share['dt']

        # True (broadcast to the shape of ``freq``) while the stimulus is on.
        stim_start = pre_stimulus_period
        stim_on = bm.logical_and(stim_start < t, t < pre_stimulus_period + stimulus_period)
        stim_on = bm.ones_like(self.freq, dtype=bool) * stim_on

        # Outside the stimulus window the held rate falls back to zero.
        held_freq = bm.where(stim_on, self.freq, 0.)

        # A new rate is drawn only when the stimulus is on and the hold
        # interval has elapsed.
        redraw = bm.logical_and(stim_on, (t - self.freq_t_last_change) >= self.t_interval)
        drawn = self.rng.normal(self.freq_mean, self.freq_var, self.freq.shape)
        self.freq.value = bm.where(redraw, drawn, held_freq)
        self.freq_t_last_change.value = bm.where(redraw, t, self.freq_t_last_change)

        # In batching mode the spike array carries a leading batch axis.
        if isinstance(self.mode, bm.BatchingMode):
            shape = self.spike.shape[:1] + self.varshape
        else:
            shape = self.varshape
        # presumably freq is in Hz and dt in ms, hence the /1000 -- verify
        self.spike.value = self.rng.random(shape) < self.freq * dt / 1000.
44 |
45 |
class DecisionMaking(bp.Network):
    """Spiking network with two selective excitatory groups and inhibition.

    The network holds three excitatory LIF groups (``A``, ``B`` and the
    remaining ``N``), one inhibitory LIF group ``I``, two ``PoissonStim``
    inputs (``IA`` -> ``A``, ``IB`` -> ``B``) and one Poisson noise group per
    LIF group.  All groups are connected with AMPA-like ``Exponential``,
    ``NMDA`` and GABA-like ``Exponential`` synapses.

    Args:
        scale: Multiplies the neuron counts (1600 excitatory, 400 inhibitory)
            and divides the recurrent conductances.
        mu0: Base rate of the two stimulus inputs.
        coherence: Shifts the stimulus means to ``mu0 +/- mu0 / 100 * coherence``
            for ``IA`` / ``IB`` respectively.
        f: Fraction of the excitatory neurons placed in each of ``A`` and ``B``.
    """

    def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15):
        super(DecisionMaking, self).__init__()

        # Population sizes: A and B each take a fraction ``f`` of the
        # excitatory cells, N takes the rest.
        num_exc = int(1600 * scale)
        num_inh = int(400 * scale)
        num_A = int(f * num_exc)
        num_B = int(f * num_exc)
        num_N = num_exc - num_A - num_B
        print(f'Total network size: {num_exc + num_inh}')

        poisson_freq = 2400.  # Hz
        # Relative weights: ``w_pos`` scales connections inside A or inside B,
        # ``w_neg`` scales connections arriving at A or B from other
        # excitatory groups.
        w_pos = 1.7
        w_neg = 1. - f * (w_pos - 1.) / (1. - f)
        g_ext2E_AMPA = 2.1  # nS
        g_ext2I_AMPA = 1.62  # nS
        g_E2E_AMPA = 0.05 / scale  # nS
        g_E2I_AMPA = 0.04 / scale  # nS
        g_E2E_NMDA = 0.165 / scale  # nS
        g_E2I_NMDA = 0.13 / scale  # nS
        g_I2E_GABAa = 1.3 / scale  # nS
        g_I2I_GABAa = 1.0 / scale  # nS

        # Shared synapse keyword arguments; every synapse type uses a delay of
        # ``0.5 / dt`` steps.
        ampa_par = dict(delay_step=int(0.5 / bm.get_dt()), tau=2.0)
        gaba_par = dict(delay_step=int(0.5 / bm.get_dt()), tau=5.0)
        nmda_par = dict(delay_step=int(0.5 / bm.get_dt()), tau_decay=100, tau_rise=2., a=0.5)

        # E neurons/pyramid neurons
        A = bp.neurons.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                           tau_ref=2., V_initializer=bp.init.OneInit(-70.))
        B = bp.neurons.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                           tau_ref=2., V_initializer=bp.init.OneInit(-70.))
        N = bp.neurons.LIF(num_N, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04,
                           tau_ref=2., V_initializer=bp.init.OneInit(-70.))
        # I neurons/interneurons
        I = bp.neurons.LIF(num_inh, V_rest=-70., V_reset=-55., V_th=-50., tau=10., R=0.05,
                           tau_ref=1., V_initializer=bp.init.OneInit(-70.))

        # poisson stimulus: A receives the higher mean for positive coherence,
        # B the lower one.
        IA = PoissonStim(num_A, freq_var=10., t_interval=50., freq_mean=mu0 + mu0 / 100. * coherence)
        IB = PoissonStim(num_B, freq_var=10., t_interval=50., freq_mean=mu0 - mu0 / 100. * coherence)

        # noise neurons
        self.noise_B = bp.neurons.PoissonGroup(num_B, freqs=poisson_freq)
        self.noise_A = bp.neurons.PoissonGroup(num_A, freqs=poisson_freq)
        self.noise_N = bp.neurons.PoissonGroup(num_N, freqs=poisson_freq)
        self.noise_I = bp.neurons.PoissonGroup(num_inh, freqs=poisson_freq)

        # define external inputs
        self.IA2A = bp.synapses.Exponential(IA, A, bp.conn.One2One(), g_max=g_ext2E_AMPA,
                                            output=bp.synouts.COBA(E=0.), **ampa_par)
        self.IB2B = bp.synapses.Exponential(IB, B, bp.conn.One2One(), g_max=g_ext2E_AMPA,
                                            output=bp.synouts.COBA(E=0.), **ampa_par)

        # define E->E/I conn

        # projections from the non-selective group N
        self.N2B_AMPA = bp.synapses.Exponential(N, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.N2A_AMPA = bp.synapses.Exponential(N, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.N2N_AMPA = bp.synapses.Exponential(N, N, bp.conn.All2All(), g_max=g_E2E_AMPA,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.N2I_AMPA = bp.synapses.Exponential(N, I, bp.conn.All2All(), g_max=g_E2I_AMPA,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.N2B_NMDA = bp.synapses.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.N2A_NMDA = bp.synapses.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.N2N_NMDA = bp.synapses.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.N2I_NMDA = bp.synapses.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)

        # projections from the selective group B (B->B uses w_pos, B->A w_neg)
        self.B2B_AMPA = bp.synapses.Exponential(B, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.B2A_AMPA = bp.synapses.Exponential(B, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.B2N_AMPA = bp.synapses.Exponential(B, N, bp.conn.All2All(), g_max=g_E2E_AMPA,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.B2I_AMPA = bp.synapses.Exponential(B, I, bp.conn.All2All(), g_max=g_E2I_AMPA,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.B2B_NMDA = bp.synapses.NMDA(B, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.B2A_NMDA = bp.synapses.NMDA(B, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.B2N_NMDA = bp.synapses.NMDA(B, N, bp.conn.All2All(), g_max=g_E2E_NMDA,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.B2I_NMDA = bp.synapses.NMDA(B, I, bp.conn.All2All(), g_max=g_E2I_NMDA,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)

        # projections from the selective group A (A->A uses w_pos, A->B w_neg)
        self.A2B_AMPA = bp.synapses.Exponential(A, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.A2A_AMPA = bp.synapses.Exponential(A, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.A2N_AMPA = bp.synapses.Exponential(A, N, bp.conn.All2All(), g_max=g_E2E_AMPA,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.A2I_AMPA = bp.synapses.Exponential(A, I, bp.conn.All2All(), g_max=g_E2I_AMPA,
                                                output=bp.synouts.COBA(E=0.), **ampa_par)
        self.A2B_NMDA = bp.synapses.NMDA(A, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.A2A_NMDA = bp.synapses.NMDA(A, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.A2N_NMDA = bp.synapses.NMDA(A, N, bp.conn.All2All(), g_max=g_E2E_NMDA,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)
        self.A2I_NMDA = bp.synapses.NMDA(A, I, bp.conn.All2All(), g_max=g_E2I_NMDA,
                                         output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par)

        # define I->E/I conn
        self.I2B = bp.synapses.Exponential(I, B, bp.conn.All2All(), g_max=g_I2E_GABAa,
                                           output=bp.synouts.COBA(E=-70.), **gaba_par)
        self.I2A = bp.synapses.Exponential(I, A, bp.conn.All2All(), g_max=g_I2E_GABAa,
                                           output=bp.synouts.COBA(E=-70.), **gaba_par)
        self.I2N = bp.synapses.Exponential(I, N, bp.conn.All2All(), g_max=g_I2E_GABAa,
                                           output=bp.synouts.COBA(E=-70.), **gaba_par)
        self.I2I = bp.synapses.Exponential(I, I, bp.conn.All2All(), g_max=g_I2I_GABAa,
                                           output=bp.synouts.COBA(E=-70.), **gaba_par)

        # define external projections
        self.noise2B = bp.synapses.Exponential(self.noise_B, B, bp.conn.One2One(), g_max=g_ext2E_AMPA,
                                               output=bp.synouts.COBA(E=0.), **ampa_par)
        self.noise2A = bp.synapses.Exponential(self.noise_A, A, bp.conn.One2One(), g_max=g_ext2E_AMPA,
                                               output=bp.synouts.COBA(E=0.), **ampa_par)
        self.noise2N = bp.synapses.Exponential(self.noise_N, N, bp.conn.One2One(), g_max=g_ext2E_AMPA,
                                               output=bp.synouts.COBA(E=0.), **ampa_par)
        self.noise2I = bp.synapses.Exponential(self.noise_I, I, bp.conn.One2One(), g_max=g_ext2I_AMPA,
                                               output=bp.synouts.COBA(E=0.), **ampa_par)

        # nodes: the neuron groups and stimuli are attached last, under the
        # names used by the monitors ('A.spike', 'IA.freq', ...).
        self.B = B
        self.A = A
        self.N = N
        self.I = I
        self.IA = IA
        self.IB = IB
180 |
181 |
def visualize_raster(ax, mon, t_start=0., title=None):
    """Draw the spike rasters of groups A and B on one axis.

    Args:
        ax: Matplotlib axis to draw on.
        mon: Mapping with the keys ``'ts'``, ``'A.spike'`` and ``'B.spike'``.
        t_start: Left limit of the time axis.
        title: Optional axis title.
    """
    # color='' is passed through unchanged.
    # NOTE(review): presumably this lets the default colour cycle pick a
    # different colour per group -- verify against bp.visualize.raster_plot.
    for key, label in (('A.spike', "Group A"), ('B.spike', "Group B")):
        bp.visualize.raster_plot(mon['ts'], mon[key], markersize=1, ax=ax, color='', label=label)
    if title:
        ax.set_title(title)
    ax.set_ylabel("Neuron Index")
    ax.set_xlim(t_start, total_period + 1)

    # Dashed markers at stimulus onset, stimulus offset and end of the delay.
    boundaries = (
        pre_stimulus_period,
        pre_stimulus_period + stimulus_period,
        pre_stimulus_period + stimulus_period + delay_period,
    )
    for boundary in boundaries:
        ax.axvline(boundary, linestyle='dashed')
    ax.legend()
193 |
194 |
def visualize_results(axes, mon, t_start=0., title=None):
    """Plot rasters, population rates and input rates on four axes.

    Args:
        axes: Sequence of at least four matplotlib axes, used in order for
            the raster of A, the raster of B, the population rates and the
            input rates.
        mon: Mapping with the keys ``'ts'``, ``'A.spike'``, ``'B.spike'``,
            ``'IA.freq'`` and ``'IB.freq'``.
        t_start: Left limit of every time axis.
        title: Optional title, placed on the first axis only.
    """

    def _decorate(ax, ylabel):
        # Label, time limits and the three dashed period boundaries.
        ax.set_ylabel(ylabel)
        ax.set_xlim(t_start, total_period + 1)
        for boundary in (
            pre_stimulus_period,
            pre_stimulus_period + stimulus_period,
            pre_stimulus_period + stimulus_period + delay_period,
        ):
            ax.axvline(boundary, linestyle='dashed')

    times = mon['ts']

    # Raster of group A.
    raster_a = axes[0]
    bp.visualize.raster_plot(times, mon['A.spike'], markersize=1, ax=raster_a)
    if title:
        raster_a.set_title(title)
    _decorate(raster_a, "Group A")

    # Raster of group B.
    raster_b = axes[1]
    bp.visualize.raster_plot(times, mon['B.spike'], markersize=1, ax=raster_b)
    _decorate(raster_b, "Group B")

    # Population firing rates of both groups.
    rate_ax = axes[2]
    rateA = bp.measure.firing_rate(mon['A.spike'], width=10.)
    rateB = bp.measure.firing_rate(mon['B.spike'], width=10.)
    rate_ax.plot(times, rateA, label="Group A")
    rate_ax.plot(times, rateB, label="Group B")
    _decorate(rate_ax, 'Population activity [Hz]')
    rate_ax.legend()

    # Rates of the two stimulus inputs.
    input_ax = axes[3]
    input_ax.plot(times, mon['IA.freq'], label="group A")
    input_ax.plot(times, mon['IB.freq'], label="group B")
    _decorate(input_ax, "Input activity [Hz]")
    input_ax.legend()
    input_ax.set_xlabel("Time [ms]")
236 |
237 |
# Lengths of the three simulation phases.  They are read as module-level
# values by PoissonStim.update and by the plotting helpers, which draw them
# on an axis labelled "Time [ms]".
pre_stimulus_period = 100.
stimulus_period = 1000.
delay_period = 500.
# Total simulated time: the three phases back to back.
total_period = pre_stimulus_period + stimulus_period + delay_period
242 |
243 |
def single_run():
    """Simulate one network for ``total_period`` and show the four-panel plot."""
    monitored = ['A.spike', 'B.spike', 'IA.freq', 'IB.freq']
    net = DecisionMaking(scale=1., coherence=-80., mu0=50.)
    runner = bp.DSRunner(net, monitors=monitored)
    runner.run(total_period)

    # Four stacked panels in a single column.
    fig, gs = bp.visualize.get_figure(4, 1, 3, 10)
    axes = []
    for row in range(4):
        axes.append(fig.add_subplot(gs[row, 0]))
    visualize_results(axes, mon=runner.mon)
    plt.show()
255 |
256 |
def batching_run():
    """Simulate 12 coherence levels as one batch and show a 3 x 4 raster grid."""
    num_row, num_col = 3, 4
    num_batch = 12
    # One coherence value per batch entry, evenly spaced in [-100, 100],
    # shaped (num_batch, 1).
    coherence = bm.expand_dims(bm.linspace(-100, 100., num_batch), 1)

    with bm.environment(mode=bm.BatchingMode(batch_size=num_batch)):
        net = DecisionMaking(scale=1., coherence=coherence, mu0=20.)
        runner = bp.DSRunner(
            net,
            monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'],
            data_first_axis='B',
        )
        runner.run(total_period)

    coherence = bm.as_numpy(coherence)
    fig, gs = bp.visualize.get_figure(num_row, num_col, 3, 4)
    batched_keys = ('A.spike', 'B.spike', 'IA.freq', 'IB.freq')
    # Fill the grid row by row; one panel per batch entry.
    for idx in range(min(num_batch, num_row * num_col)):
        row, col = divmod(idx, num_col)
        mon = {key: runner.mon[key][idx] for key in batched_keys}
        mon['ts'] = runner.mon['ts']
        ax = fig.add_subplot(gs[row, col])
        visualize_raster(ax, mon=mon, title=f'coherence={coherence[idx, 0]}%')
    plt.show()
283 |
284 |
# When executed as a script, run the single-network demo first and the
# batched coherence sweep afterwards.
if __name__ == '__main__':
    single_run()
    batching_run()
288 |
--------------------------------------------------------------------------------
/decision_making/Wang_2006_decision_making_rate.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "34df5f48",
6 | "metadata": {
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 | "# *(Wong & Wang, 2006)* Decision making rate model\n",
13 | "\n",
14 | "[](https://colab.research.google.com/github/brainpy/examples/blob/main/decision_making/Wang_2006_decision_making_rate.ipynb)\n",
15 | "[](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/decision_making/Wang_2006_decision_making_rate.ipynb)"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "id": "ad00c2ea",
21 | "metadata": {
22 | "pycharm": {
23 | "name": "#%% md\n"
24 | }
25 | },
26 | "source": [
27 | "Please refer to [Analysis of a Decision Making Model](https://brainpy.readthedocs.io/en/latest/tutorial_analysis/decision_making_model.html)."
28 | ]
29 | }
30 | ],
31 | "metadata": {
32 | "kernelspec": {
33 | "display_name": "brainpy",
34 | "language": "python",
35 | "name": "brainpy"
36 | },
37 | "language_info": {
38 | "codemirror_mode": {
39 | "name": "ipython",
40 | "version": 3
41 | },
42 | "file_extension": ".py",
43 | "mimetype": "text/x-python",
44 | "name": "python",
45 | "nbconvert_exporter": "python",
46 | "pygments_lexer": "ipython3",
47 | "version": "3.8.11"
48 | },
49 | "latex_envs": {
50 | "LaTeX_envs_menu_present": true,
51 | "autoclose": false,
52 | "autocomplete": true,
53 | "bibliofile": "biblio.bib",
54 | "cite_by": "apalike",
55 | "current_citInitial": 1,
56 | "eqLabelWithNumbers": true,
57 | "eqNumInitial": 1,
58 | "hotkeys": {
59 | "equation": "Ctrl-E",
60 | "itemize": "Ctrl-I"
61 | },
62 | "labels_anchors": false,
63 | "latex_user_defs": false,
64 | "report_style_numbering": false,
65 | "user_envs_cfg": false
66 | },
67 | "toc": {
68 | "base_numbering": 1,
69 | "nav_menu": {},
70 | "number_sections": true,
71 | "sideBar": true,
72 | "skip_h1_title": false,
73 | "title_cell": "Table of Contents",
74 | "title_sidebar": "Contents",
75 | "toc_cell": false,
76 | "toc_position": {},
77 | "toc_section_display": true,
78 | "toc_window_display": true
79 | }
80 | },
81 | "nbformat": 4,
82 | "nbformat_minor": 5
83 | }
84 |
--------------------------------------------------------------------------------
/dynamics_analysis/2d_decision_making_with_lowdim_analyzer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "34df5f48",
6 | "metadata": {
7 | "pycharm": {
8 | "name": "#%% md\n"
9 | }
10 | },
11 | "source": [
12 | "# [2D] Decision Making Model with Low-dimensional Analyzer\n",
13 | "\n",
14 | "[](https://colab.research.google.com/github/brainpy/examples/blob/main/dynamics_analysis/2d_decision_making_with_lowdim_analyzer.ipynb)\n",
15 | "[](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/dynamics_analysis/2d_decision_making_with_lowdim_analyzer.ipynb)"
16 | ]
17 | },
18 | {
19 | "cell_type": "markdown",
20 | "id": "ad00c2ea",
21 | "metadata": {
22 | "pycharm": {
23 | "name": "#%% md\n"
24 | }
25 | },
26 | "source": [
27 | "Please refer to [Analysis of a Decision Making Model](https://brainpy.readthedocs.io/en/brainpy-2.x/tutorial_analysis/decision_making_model.html)."
28 | ]
29 | }
30 | ],
31 | "metadata": {
32 | "kernelspec": {
33 | "display_name": "brainpy",
34 | "language": "python",
35 | "name": "brainpy"
36 | },
37 | "language_info": {
38 | "codemirror_mode": {
39 | "name": "ipython",
40 | "version": 3
41 | },
42 | "file_extension": ".py",
43 | "mimetype": "text/x-python",
44 | "name": "python",
45 | "nbconvert_exporter": "python",
46 | "pygments_lexer": "ipython3",
47 | "version": "3.8.11"
48 | },
49 | "latex_envs": {
50 | "LaTeX_envs_menu_present": true,
51 | "autoclose": false,
52 | "autocomplete": true,
53 | "bibliofile": "biblio.bib",
54 | "cite_by": "apalike",
55 | "current_citInitial": 1,
56 | "eqLabelWithNumbers": true,
57 | "eqNumInitial": 1,
58 | "hotkeys": {
59 | "equation": "Ctrl-E",
60 | "itemize": "Ctrl-I"
61 | },
62 | "labels_anchors": false,
63 | "latex_user_defs": false,
64 | "report_style_numbering": false,
65 | "user_envs_cfg": false
66 | },
67 | "toc": {
68 | "base_numbering": 1,
69 | "nav_menu": {},
70 | "number_sections": true,
71 | "sideBar": true,
72 | "skip_h1_title": false,
73 | "title_cell": "Table of Contents",
74 | "title_sidebar": "Contents",
75 | "toc_cell": false,
76 | "toc_position": {},
77 | "toc_section_display": true,
78 | "toc_window_display": true
79 | }
80 | },
81 | "nbformat": 4,
82 | "nbformat_minor": 5
83 | }
84 |
--------------------------------------------------------------------------------
/images/cann-decoding.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann-decoding.gif
--------------------------------------------------------------------------------
/images/cann-encoding.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann-encoding.gif
--------------------------------------------------------------------------------
/images/cann-tracking.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann-tracking.gif
--------------------------------------------------------------------------------
/images/cann_1d_oscillatory_tracking.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann_1d_oscillatory_tracking.gif
--------------------------------------------------------------------------------
/images/cann_2d_encoding.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann_2d_encoding.gif
--------------------------------------------------------------------------------
/images/cann_2d_tracking.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann_2d_tracking.gif
--------------------------------------------------------------------------------
/images/decision_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/decision_model.png
--------------------------------------------------------------------------------
/images/izhikevich_patterns.jfif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/izhikevich_patterns.jfif
--------------------------------------------------------------------------------
/index.rst:
--------------------------------------------------------------------------------
1 | BrainPy Examples
2 | ================
3 |
4 | This repository contains examples of using `BrainPy `_
5 | to implement various models of neurons, synapses, networks, etc. We welcome your implementations,
6 | which can be posted through our `github `_ page.
7 |
8 | If some of the code fails to run, please let us know by opening a GitHub issue at https://github.com/brainpy/examples/issues .
9 |
10 | If you find these examples useful for your research, please kindly `cite us `_.
11 |
12 | If you want to add more examples, please fork our GitHub repository at https://github.com/brainpy/examples .
13 |
14 |
15 |
16 | Example categories:
17 |
18 | .. contents::
19 | :local:
20 | :depth: 2
21 |
22 |
23 |
24 |
25 | Neuron Models
26 | -------------
27 |
28 | - `(Izhikevich, 2003): Izhikevich Model `_
29 | - `(Brette, Romain. 2004): LIF phase locking `_
30 | - `(Gerstner, 2005): Adaptive Exponential Integrate-and-Fire model `_
31 | - `(Niebur, et. al, 2009): Generalized integrate-and-fire model `_
32 | - `(Jansen & Rit, 1995): Jansen-Rit Model `_
33 | - `(Teka, et. al, 2018): Fractional-order Izhikevich neuron model `_
34 | - `(Mondal, et. al, 2019): Fractional-order FitzHugh-Rinzel bursting neuron model `_
35 |
36 |
37 |
38 | Attractor Networks
39 | ------------------
40 |
41 | - `CANN 1D Oscillatory Tracking `_
42 | - `(Si Wu, 2008): Continuous-attractor Neural Network 1D `_
43 | - `(Si Wu, 2008): Continuous-attractor Neural Network 2D `_
44 | - `Discrete Hopfield Network `_
45 | - `Discrete Hopfield Network Demo for Image Reconstruction `_
46 |
47 |
48 |
49 | Decision Making Model
50 | ---------------------
51 |
52 | - `(Wang, 2002): Decision making spiking model `_
53 | - `(Wong & Wang, 2006): Decision making rate model `_
54 |
55 |
56 |
57 |
58 | E/I Balanced Network
59 | --------------------
60 |
61 |
62 | - `(Vreeswijk & Sompolinsky, 1996): E/I balanced network `_
63 | - `(Brette, et, al., 2007): COBA `_
64 | - `(Brette, et, al., 2007): CUBA `_
65 | - `(Brette, et, al., 2007): COBA-HH `_
66 | - `(Tian, et al., 2020): E/I Net for fast response `_
67 |
68 |
69 |
70 | Brain-inspired Computing
71 | ------------------------
72 |
73 |
74 | - `Classify MNIST dataset by a fully connected LIF layer `_
75 | - `Convolutional SNN to Classify Fashion-MNIST `_
76 | - `(2022, NeurIPS): Online Training Through Time for Spiking Neural Networks `_
77 | - `(2019, Zenke, F.): SNN Surrogate Gradient Learning `_
78 | - `(2019, Zenke, F.): SNN Surrogate Gradient Learning to Classify Fashion-MNIST `_
79 | - `(2021, Raminmh): Liquid time-constant Networks `_
80 |
81 |
82 |
83 | Reservoir Computing
84 | -------------------
85 |
86 |
87 | - `Predicting Mackey-Glass timeseries `_
88 | - `(Sussillo & Abbott, 2009): FORCE Learning `_
89 | - `(Gauthier, et. al, 2021): Next generation reservoir computing `_
90 |
91 |
92 |
93 | Gap Junction Network
94 | --------------------
95 |
96 | - `(Fazli and Richard, 2022): Electrically Coupled Bursting Pituitary Cells `_
97 | - `(Sherman & Rinzel, 1992): Gap junction leads to anti-synchronization `_
98 |
99 |
100 |
101 | Oscillation and Synchronization
102 | -------------------------------
103 |
104 | - `(Wang & Buzsáki, 1996): Gamma Oscillation `_
105 | - `(Brunel & Hakim, 1999): Fast Global Oscillation `_
106 | - `(Diesmann, et, al., 1999): Synfire Chains `_
107 | - `(Li, et. al, 2017): Unified Thalamus Oscillation Model `_
108 | - `(Susin & Destexhe, 2021): Asynchronous Network `_
109 | - `(Susin & Destexhe, 2021): CHING Network for Generating Gamma Oscillation `_
110 | - `(Susin & Destexhe, 2021): ING Network for Generating Gamma Oscillation `_
111 | - `(Susin & Destexhe, 2021): PING Network for Generating Gamma Oscillation `_
112 |
113 |
114 |
115 | Large-Scale Modeling
116 | --------------------
117 |
118 | - `(Joglekar, et. al, 2018): Inter-areal Balanced Amplification Figure 1 `_
119 | - `(Joglekar, et. al, 2018): Inter-areal Balanced Amplification Figure 2 `_
120 | - `(Joglekar, et. al, 2018): Inter-areal Balanced Amplification Figure 5 `_
121 | - `(Joglekar, et. al, 2018): Inter-areal Balanced Amplification Taichi customized operators `_
122 | - `Simulating 1-million-neuron networks with 1GB GPU memory `_
123 |
124 |
125 |
126 | Recurrent Neural Network
127 | ------------------------
128 |
129 |
130 | - `(Sussillo & Abbott, 2009): FORCE Learning `_
131 | - `Integrator RNN Model `_
132 | - `Train RNN to Solve Parametric Working Memory `_
133 | - `(Song, et al., 2016): Training excitatory-inhibitory recurrent network `_
134 | - `(Masse, et al., 2019): RNN with STP for Working Memory `_
135 | - `(Yang, 2020): Dynamical system analysis for RNN `_
136 | - `(Bellec, et. al, 2020): eprop for Evidence Accumulation Task `_
137 |
138 |
139 |
140 | Working Memory Model
141 | --------------------
142 |
143 | - `(Bouchacourt & Buschman, 2019): Flexible Working Memory Model `_
144 | - `(Mi, et. al., 2017): STP for Working Memory Capacity `_
145 | - `(Masse, et al., 2019): RNN with STP for Working Memory `_
146 |
147 |
148 |
149 | Dynamics Analysis
150 | -----------------
151 |
152 | - `[1D] Simple systems `_
153 | - `[2D] NaK model analysis `_
154 | - `[2D] Wilson-Cowan model `_
155 | - `[2D] Decision Making Model with SlowPointFinder `_
156 | - `[2D] Decision Making Model with Low-dimensional Analyzer `_
157 | - `[3D] Hindmarsh Rose Model `_
158 | - `Continuous-attractor Neural Network `_
159 | - `Gap junction-coupled FitzHugh-Nagumo Model `_
160 | - `(Yang, 2020): Dynamical system analysis for RNN `_
161 |
162 |
163 |
164 |
165 | Classical Dynamical Systems
166 | ---------------------------
167 |
168 | - `Hénon map `_
169 | - `Logistic map `_
170 | - `Lorenz system `_
171 | - `Mackey-Glass equation `_
172 | - `Multiscroll chaotic attractor (多卷波混沌吸引子) `_
173 | - `Rabinovich-Fabrikant equations `_
174 | - `Fractional-order Chaos Gallery `_
175 |
176 |
177 |
178 |
179 |
180 | Unclassified Models
181 | -------------------
182 |
183 | - `(Brette & Guigon, 2003): Reliability of spike timing `_
184 |
185 |
186 |
187 |
188 |
189 | Indices and tables
190 | ==================
191 |
192 | * :ref:`genindex`
193 | * :ref:`modindex`
194 | * :ref:`search`
195 |
--------------------------------------------------------------------------------
/large_scale_modeling/2014_CorticalModel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import brainpy as bp
4 | import brainpy.math as bm
5 | import numpy as np
6 |
7 | bp.math.set_platform('cpu')
8 |
9 |
class LIF(bp.NeuGroup):
    """Leaky integrate-and-fire population with a decaying synaptic current.

    Every neuron carries a membrane potential ``V`` and a synaptic current
    ``I``. The two are advanced together by one joint ODE integrator. While a
    neuron is within ``tau_ref`` ms of its last spike, ``V`` is held at its
    previous value.

    Parameters
    ----------
    size : population size, forwarded to ``bp.NeuGroup``.
    tau_neu : membrane time constant [ms].
    tau_syn : post-synaptic current time constant [ms].
    tau_ref : absolute refractory period [ms].
    V_reset : reset potential [mV]; also the level ``V`` relaxes towards.
    V_th : fixed firing threshold [mV].
    Cm : membrane capacity [nF].
    """

    def __init__(self, size, tau_neu=10., tau_syn=0.5, tau_ref=2.,
                 V_reset=-65., V_th=-50., Cm=0.25, ):
        super().__init__(size=size)

        # --- constants of the model ---
        self.Iext = 0.  # constant external current [nA]
        self.V_th = V_th  # fixed firing threshold [mV]
        self.V_reset = V_reset  # reset potential [mV]
        self.Cm = Cm  # membrane capacity [nF]
        self.tau_ref = tau_ref  # absolute refractory period [ms]
        self.tau_syn = tau_syn  # post-synaptic current time constant [ms]
        self.tau_neu = tau_neu  # membrane time constant [ms]

        # --- state ---
        # Membrane potentials start scattered around -65 mV (std 5 mV).
        self.V = bm.Variable(-65. + 5.0 * bm.random.randn(self.num))  # [mV]
        # Synaptic currents start at zero.
        self.I = bm.Variable(bm.zeros(self.num))  # [nA]
        self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
        # Far in the past, so that no neuron starts out refractory.
        self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)

        # --- joint integrator for (V, I) ---
        joint_eq = bp.JointEq([self.dV, self.dI])
        self.integral = bp.odeint(joint_eq, method='exp_auto')

    def dV(self, V, t, I):
        """Right-hand side of the membrane equation."""
        return (-V + self.V_reset) / self.tau_neu + (I + self.Iext) / self.Cm

    def dI(self, I, t):
        """Right-hand side of the synaptic current decay."""
        return -I / self.tau_syn

    def update(self, _t, _dt):
        """Advance the population by one step and register spikes."""
        in_refractory = (_t - self.t_last_spike) <= self.tau_ref
        V, I = self.integral(self.V, self.I, _t, _dt)
        # Refractory neurons keep the potential they had before this step.
        V = bm.where(in_refractory, self.V, V)
        fired = (V >= self.V_th)
        self.I.value = I
        self.spike.value = fired
        self.t_last_spike.value = bm.where(fired, _t, self.t_last_spike)
        self.V.value = bm.where(fired, self.V_reset, V)
49 |
class ExpSyn(bp.dyn.TwoEndConn):
    """Delayed, event-driven projection that adds weights onto ``post.I``.

    Parameters
    ----------
    pre, post : source and target populations; ``pre`` must expose ``spike``
        and ``post`` must expose ``I``.
    prob : connection probability, strictly between 0 and 1.
    syn_type : ``'e'`` (excitatory) or ``'i'`` (inhibitory). Inhibitory
        weights are the excitatory draw multiplied by ``inh_weight_scale``.
    conn_type : selects how connectivity is stored and which event-sum
        operator ``update`` calls:
        0 - random (pre, post) pairs, count from equation 3 of the article;
        1 - fixed-probability connectivity in ``pre2post`` form;
        2 - two random id arrays of length ``prob * pre.num * post.num``;
        3, 4 - as 1, plus ``max_post_conn`` handed to the operator.
    """

    # Each tuple below is (mean, standard deviation).
    exc_delay = (1.5, 0.75)  # excitatory delay [ms]
    inh_delay = (0.80, 0.4)  # inhibitory delay [ms]
    exc_weight = (0.0878, 0.0088)  # excitatory synaptic weight [nA]
    inh_weight_scale = -4.  # relative inhibitory synaptic strength

    def __init__(self, pre, post, prob, syn_type='e', conn_type=0):
        super().__init__(pre=pre, post=post, conn=None)
        self.check_pre_attrs('spike')
        self.check_post_attrs('I')
        assert syn_type in ['e', 'i']
        assert 0. < prob < 1.

        self.syn_type = syn_type
        self.conn_type = conn_type

        # ---- connectivity ----
        if conn_type == 0:
            # number of synapses calculated with equation 3 from the article
            n_pairs = float(pre.num * post.num)
            num = int(np.log(1.0 - prob) / np.log(1.0 - (1.0 / n_pairs)))
            # Draw sources first, then targets (keeps the random stream order).
            src = np.random.randint(0, pre.num, num)
            dst = np.random.randint(0, post.num, num)
            self.pre2post = bp.conn.ij2csr(pre_ids=src, post_ids=dst, num_pre=pre.num)
            self.num = self.pre2post[0].size
        elif conn_type in (1, 3, 4):
            connector = bp.conn.FixedProb(prob)
            self.pre2post = connector(pre.size, post.size).require('pre2post')
            self.num = self.pre2post[0].size
            if conn_type != 1:
                # presumably the largest number of targets of any single
                # presynaptic neuron -- TODO confirm the layout of pre2post[1]
                self.max_post_conn = bm.diff(self.pre2post[1]).max()
        elif conn_type == 2:
            self.num = int(prob * pre.num * post.num)
            self.pre_ids = bm.random.randint(0, pre.num, size=self.num, dtype=bm.uint32)
            self.post_ids = bm.random.randint(0, post.num, size=self.num, dtype=bm.uint32)
        else:
            raise ValueError

        # ---- one delay per presynaptic neuron, never below one time step ----
        if syn_type == 'e':
            delay_mean, delay_std = self.exc_delay
        elif syn_type == 'i':
            delay_mean, delay_std = self.inh_delay
        else:
            raise ValueError
        drawn_delay = bm.random.normal(delay_mean, delay_std, size=pre.num)
        step = bm.get_dt()
        self.delay = bm.where(drawn_delay < step, step, drawn_delay)

        # ---- one weight per synapse, negative draws clipped to zero ----
        w_mean, w_std = self.exc_weight
        weights = bm.random.normal(w_mean, w_std, size=self.num)
        weights = bm.where(weights < 0, 0., weights)
        if syn_type == 'i':
            weights *= self.inh_weight_scale
        self.weights = weights

        # ---- buffer holding presynaptic spikes until their delay elapses ----
        self.pre_sps = bp.ConstantDelay(pre.num, self.delay, bool)

    def update(self, _t, _dt):
        """Queue the current spikes and deliver the delayed ones to ``post.I``."""
        self.pre_sps.push(self.pre.spike)
        delayed_sps = self.pre_sps.pull()

        mode = self.conn_type
        n_post = self.post.num
        if mode in [0, 1]:
            post_vs = bm.pre2post_event_sum(delayed_sps, self.pre2post, n_post, self.weights)
        elif mode == 2:
            post_vs = bm.pre2post_event_sum2(delayed_sps, self.pre_ids, self.post_ids,
                                             n_post, self.weights)
        elif mode == 3:
            post_vs = bm.pre2post_event_sum3(delayed_sps, self.pre2post, n_post,
                                             self.weights, self.max_post_conn)
        elif mode == 4:
            post_vs = bm.pre2post_event_sum4(delayed_sps, self.pre2post, n_post,
                                             self.weights, self.max_post_conn)
        else:
            raise ValueError
        self.post.I += post_vs
129 |
130 | # class PoissonInput(bp.NeuGroup):
131 | # def __init__(self, post, freq=8.):
132 | # base = 20
133 | # super(PoissonInput, self).__init__(size=(post.num, base))
134 | #
135 | # # parameters
136 | # freq = post.num * freq / base
137 | # self.prob = freq * bm.get_dt() / 1000.
138 | # self.weight = ExpSyn.exc_weight[0]
139 | # self.post = post
140 | # assert hasattr(post, 'I')
141 | #
142 | # # variables
143 | # self.rng = bm.random.RandomState()
144 | #
145 | # def update(self, _t, _dt):
146 | # self.post.I += self.weight * self.rng.random(self.size).sum(axis=1)
147 |
148 |
class PoissonInput(bp.NeuGroup):
    """Background drive for one population, sampled from a normal distribution.

    At every step each neuron of ``post`` receives
    ``weight * Normal(loc, scale)`` added to its current ``I``. ``loc`` and
    ``scale`` are the mean and standard deviation of a binomial with
    ``post.num`` trials and per-step success probability
    ``freq * dt / 1000``.

    Parameters
    ----------
    post : target population; must have an ``I`` attribute.
    freq : rate used to derive the per-step probability (the division by
        1000 suggests Hz with ``dt`` in ms -- TODO confirm).
    """

    def __init__(self, post, freq=8.):
        super().__init__(size=(post.num,))

        # per-step probability and the matching normal parameters
        p_step = freq * bm.get_dt() / 1000.
        n_trials = post.num
        self.prob = p_step
        self.loc = n_trials * p_step
        self.scale = np.sqrt(n_trials * p_step * (1 - p_step))
        # every event carries the mean excitatory weight
        self.weight = ExpSyn.exc_weight[0]
        self.post = post
        assert hasattr(post, 'I')

        # private random stream for the per-step draws
        self.rng = bm.random.RandomState()

    def update(self, _t, _dt):
        """Add one fresh normal sample per neuron to ``post.I``."""
        drive = self.rng.normal(self.loc, self.scale, self.num)
        self.post.I += self.weight * drive
166 |
167 |
class PoissonInput2(bp.NeuGroup):
    """Single background source shared by several populations.

    One vector of normal samples, as long as all populations together, is
    drawn per step. Consecutive slices of it are added to the ``I`` of each
    population in ``pops``, in order.

    Parameters
    ----------
    pops : sequence of populations, each exposing ``num`` and ``I``.
    freq : rate used to derive the per-step probability.
    """

    def __init__(self, pops, freq=8.):
        total = sum([pop.num for pop in pops])
        super().__init__(size=total)

        self.pops = pops
        p_step = freq * bm.get_dt() / 1000.
        # The normal stand-in for the binomial is only accepted when both
        # expected counts exceed 5.
        assert (p_step * self.num > 5.) and (self.num * (1 - p_step) > 5)
        self.loc = self.num * p_step
        self.scale = np.sqrt(self.num * p_step * (1 - p_step))
        self.weight = ExpSyn.exc_weight[0]

        # private random stream for the per-step draws
        self.rng = bm.random.RandomState()

    def update(self, _t, _dt):
        """Draw once and hand each population its own slice."""
        sample_weights = self.rng.normal(self.loc, self.scale, self.num) * self.weight
        start = 0
        for pop in self.pops:
            stop = start + pop.num
            pop.I += sample_weights[start: stop]
            start = stop
189 |
190 |
class ThalamusInput(bp.TwoEndConn):
    """Switchable, undelayed projection that adds weights onto ``post.I``.

    Spikes of ``pre`` are delivered only while ``turn_on[0]`` is true; the
    flag starts out false.

    Parameters
    ----------
    pre, post : source and target populations; ``pre`` must expose ``spike``
        and ``post`` must expose ``I``.
    conn_prob : probability handed to ``bp.conn.FixedProb``.
    """

    def __init__(self, pre, post, conn_prob=0.1):
        connector = bp.conn.FixedProb(conn_prob)
        super().__init__(pre=pre, post=post, conn=connector)
        self.check_pre_attrs('spike')
        self.check_post_attrs('I')

        # connectivity, and one weight per synapse with negatives clipped to 0
        self.pre2post = self.conn.require('pre2post')
        self.syn_num = self.pre2post[0].size
        w_mean, w_std = ExpSyn.exc_weight
        drawn = bm.random.normal(w_mean, w_std, size=self.syn_num)
        self.weights = bm.where(drawn < 0., 0., drawn)

        # gate: input is delivered only while this flag is True
        self.turn_on = bm.Variable(bm.asarray([False]))

    def update(self, _t, _dt):
        """Deliver the presynaptic spikes when the gate is open."""

        def deliver(_):
            self.post.I += bm.pre2post_event_sum(self.pre.spike, self.pre2post,
                                                 self.post.num, self.weights)

        def stay_silent(_):
            return None

        gated = bm.make_cond(deliver, stay_silent,
                             dyn_vars=(self.post.I, self.pre.spike))
        gated(self.turn_on[0])
213 |
class CorticalMicrocircuit(bp.dyn.Network):
    """Eight-population layered cortical network with optional thalamic drive.

    The network holds one ``LIF`` population per entry of ``layer_name``
    (without the trailing ``'Th'``) and one ``ExpSyn`` projection for every
    non-zero entry of ``conn_table``. Background drive is either a
    ``PoissonInput`` per population or a constant current ``Iext``.

    Parameters
    ----------
    bg_type : which background table ``_get_bg_inputs`` returns
        (0 layer-specific, 1 layer-independent, 2 randomly sampled between
        the two). Only used when ``stim_type == 1``.
    stim_type : 0 attaches a ``PoissonInput`` to every population;
        1 sets a constant ``Iext`` on every population.
    conn_type : forwarded to every ``ExpSyn``.
    poisson_freq : rate forwarded to ``PoissonInput`` when ``stim_type == 0``.
    has_thalamus : when true, a thalamic Poisson population and one
        ``ThalamusInput`` per cortical population are added.
    """

    # Names for each layer:
    layer_name = ['L23e', 'L23i', 'L4e', 'L4i', 'L5e', 'L5i', 'L6e', 'L6i', 'Th']

    # Population size per layer:
    #            2/3e   2/3i   4e     4i    5e    5i    6e     6i    Th
    layer_num = [20683, 5834, 21915, 5479, 4850, 1065, 14395, 2948, 902]

    # Layer-specific background input [nA]:
    #                             2/3e  2/3i  4e    4i    5e    5i    6e    6i
    layer_specific_bg = np.array([1600, 1500, 2100, 1900, 2000, 1900, 2900, 2100]) / 1000

    # Layer-independent background input [nA]:
    #                                2/3e  2/3i  4e    4i    5e    5i    6e    6i
    layer_independent_bg = np.array([2000, 1850, 2000, 1850, 2000, 1850, 2000, 1850]) / 1000

    # Prob. connection table: rows are targets, columns are sources
    # (column 8 is the thalamus).
    conn_table = np.array([[0.101, 0.169, 0.044, 0.082, 0.032, 0.0000, 0.008, 0.000, 0.0000],
                           [0.135, 0.137, 0.032, 0.052, 0.075, 0.0000, 0.004, 0.000, 0.0000],
                           [0.008, 0.006, 0.050, 0.135, 0.007, 0.0003, 0.045, 0.000, 0.0983],
                           [0.069, 0.003, 0.079, 0.160, 0.003, 0.0000, 0.106, 0.000, 0.0619],
                           [0.100, 0.062, 0.051, 0.006, 0.083, 0.3730, 0.020, 0.000, 0.0000],
                           [0.055, 0.027, 0.026, 0.002, 0.060, 0.3160, 0.009, 0.000, 0.0000],
                           [0.016, 0.007, 0.021, 0.017, 0.057, 0.0200, 0.040, 0.225, 0.0512],
                           [0.036, 0.001, 0.003, 0.001, 0.028, 0.0080, 0.066, 0.144, 0.0196]])

    def __init__(self, bg_type=0, stim_type=0, conn_type=0, poisson_freq=8., has_thalamus=False):
        super(CorticalMicrocircuit, self).__init__()

        # parameters
        self.bg_type = bg_type
        self.stim_type = stim_type
        self.conn_type = conn_type
        self.poisson_freq = poisson_freq
        self.has_thalamus = has_thalamus

        # NEURON: populations
        self.populations = bp.Collector()
        for i in range(8):
            l_name = self.layer_name[i]
            print(f'Creating {l_name} ...')
            self.populations[l_name] = LIF(self.layer_num[i])

        # SYNAPSE: synapses
        self.synapses = bp.Collector()
        for c in range(8):  # from
            for r in range(8):  # to
                # Zero-probability pairs get no projection at all.
                if self.conn_table[r, c] > 0.:
                    print(f'Creating Synapses from {self.layer_name[c]} to {self.layer_name[r]} ...')
                    # The last letter of the source name ('e' / 'i') is the synapse type.
                    syn = ExpSyn(pre=self.populations[self.layer_name[c]],
                                 post=self.populations[self.layer_name[r]],
                                 prob=self.conn_table[r, c],
                                 syn_type=self.layer_name[c][-1],
                                 conn_type=conn_type)
                    self.synapses[f'{self.layer_name[c]}_to_{self.layer_name[r]}'] = syn
        # Synaptic weight from L4e to L2/3e is doubled
        self.synapses['L4e_to_L23e'].weights *= 2.

        # NEURON & SYNAPSE: poisson inputs
        if stim_type == 0:
            for r in range(0, 8):
                l_name = self.layer_name[r]
                print(f'Creating Poisson group of {l_name} ...')
                N = PoissonInput(freq=poisson_freq, post=self.populations[l_name])
                self.populations[f'Poisson_to_{l_name}'] = N
        elif stim_type == 1:
            bg_inputs = self._get_bg_inputs(bg_type)
            assert bg_inputs is not None
            for i, current in enumerate(bg_inputs):
                self.populations[self.layer_name[i]].Iext = 0.3512 * current

        # NEURON & SYNAPSE: thalamus inputs
        if has_thalamus:
            thalamus = bp.dyn.PoissonInput(self.layer_num[-1], freqs=15.)
            self.populations[self.layer_name[-1]] = thalamus
            for r in range(0, 8):
                l_name = self.layer_name[r]
                print(f'Creating Thalamus projection of {l_name} ...')
                S = ThalamusInput(pre=thalamus,
                                  post=self.populations[l_name],
                                  conn_prob=self.conn_table[r, 8])
                self.synapses[f'{self.layer_name[-1]}_to_{l_name}'] = S

        # finally, compose them as a network
        self.register_implicit_nodes(self.populations)
        self.register_implicit_nodes(self.synapses)

    def _get_bg_inputs(self, bg_type):
        """Return the eight background inputs for ``bg_type``.

        Returns ``layer_specific_bg`` for 0, ``layer_independent_bg`` for 1,
        a freshly sampled array for 2, and ``None`` for any other value.
        """
        if bg_type == 0:  # layer-specific
            bg_layer = self.layer_specific_bg
        elif bg_type == 1:  # layer-independent
            bg_layer = self.layer_independent_bg
        elif bg_type == 2:  # layer-independent-random
            bg_layer = np.zeros(8)
            for i in range(0, 8, 2):
                # randomly choosing a number for the external input to an excitatory population:
                exc_bound = [self.layer_specific_bg[i], self.layer_independent_bg[i]]
                exc_input = np.random.uniform(min(exc_bound), max(exc_bound))
                # randomly choosing a number for the external input to an inhibitory population:
                T = 0.1 if i != 6 else 0.2
                inh_bound = ((1 - T) / (1 + T)) * exc_input  # eq. 4 from the article
                inh_input = np.random.uniform(inh_bound, exc_input)
                # The two tables are already divided by 1000, so the samples
                # lie roughly between 1 and 3. A plain int() would collapse
                # them to 1 or 2; truncate at 1/1000 resolution instead
                # (presumably the integer step of the undivided tables).
                bg_layer[i] = int(exc_input * 1000) / 1000
                bg_layer[i + 1] = int(inh_input * 1000) / 1000
        else:
            bg_layer = None
        return bg_layer
325 |
326 |
# Re-seed the brainpy random state, then build the microcircuit with
# fixed-probability connectivity and constant layer-specific background current.
bm.random.seed()
net = CorticalMicrocircuit(bg_type=0,
                           stim_type=1,
                           conn_type=1,
                           poisson_freq=8.)

# One spike monitor per entry of layer_name except the last one (thalamus).
sps_monitors = [f'{n}.spike' for n in net.layer_name[:-1]]
runner = bp.StructRunner(net, monitors=sps_monitors)
runner.run(1000.)

# Put all monitored spike arrays side by side and show them as one raster.
per_layer_spikes = [runner.mon[name] for name in sps_monitors]
spikes = np.hstack(per_layer_spikes)
bp.visualize.raster_plot(runner.mon.ts, spikes, show=True)
335 |
336 |
337 | # bp.visualize.line_plot(runner.mon.ts, runner.mon['L4e.V'], plot_ids=[0, 1, 2], show=True)
338 |
339 | # def run1():
340 | # fig, gs = bp.visualize.get_figure(8, 1, col_len=8, row_len=1)
341 | # for i in range(8):
342 | # fig.add_subplot(gs[i, 0])
343 | # name = net.layer_name[i]
344 | # bp.visualize.raster_plot(runner.mon.ts, runner.mon[f'{name}.spike'],
345 | # xlabel='Time [ms]' if i == 7 else None,
346 | # ylabel=name, show=i == 7)
347 |
348 |
349 | # if __name__ == '__main__':
350 | # run1()
351 |
--------------------------------------------------------------------------------
/large_scale_modeling/Joglekar_2018_data/efelenMatpython.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/large_scale_modeling/Joglekar_2018_data/efelenMatpython.mat
--------------------------------------------------------------------------------
/large_scale_modeling/Joglekar_2018_data/hierValspython.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/large_scale_modeling/Joglekar_2018_data/hierValspython.mat
--------------------------------------------------------------------------------
/large_scale_modeling/Joglekar_2018_data/subgraphData.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/large_scale_modeling/Joglekar_2018_data/subgraphData.mat
--------------------------------------------------------------------------------
/large_scale_modeling/Joglekar_2018_data/subgraphWiring29.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/large_scale_modeling/Joglekar_2018_data/subgraphWiring29.mat
--------------------------------------------------------------------------------
/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

REM Run from the directory that contains this script; popd at :end restores
REM the caller's directory.
pushd %~dp0

REM Command file for Sphinx documentation

REM Fall back to "sphinx-build" when the caller has not set SPHINXBUILD.
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

REM Without a target argument, show the Sphinx help instead.
if "%1" == "" goto help

REM Probe run with all output discarded; only the resulting errorlevel is used.
REM An errorlevel of 9009 or above is reported below as "command not found".
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)

REM Forward the requested target (first argument) to sphinx-build's -M mode.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
36 |
--------------------------------------------------------------------------------
/recurrent_networks/Laje_Buonomano_2013_robust_timing_rnn.py:
--------------------------------------------------------------------------------
1 | # ---
2 | # jupyter:
3 | # jupytext:
4 | # formats: ipynb,py:percent
5 | # text_representation:
6 | # extension: .py
7 | # format_name: percent
8 | # format_version: '1.3'
9 | # jupytext_version: 1.11.5
10 | # kernelspec:
11 | # display_name: brainpy
12 | # language: python
13 | # name: brainpy
14 | # ---
15 |
16 | # %% [markdown]
17 | # # *(Laje & Buonomano, 2013)* Robust Timing in RNN
18 |
19 | # %% [markdown]
20 | # Implementation of the paper:
21 | #
22 | # - Laje, Rodrigo, and Dean V. Buonomano. "Robust timing and motor patterns by taming chaos in recurrent neural networks." Nature neuroscience 16, no. 7 (2013): 925-933. http://www.ncbi.nlm.nih.gov/pmc/articles/PMC3753043
23 | #
24 | # Thanks to the original implementation codes: https://github.com/ReScience-Archives/Vitay-2016
25 |
26 | # %%
27 | import brainpy as bp
28 | import brainpy.math as bm
29 |
30 |
31 | # %% [markdown]
32 | # ## Model Descriptions
33 |
34 | # %% [markdown]
35 | # ### Recurrent network
36 | #
37 | # The recurrent network is composed of $N=800$ neurons, receiving inputs from a variable number of input neurons $N_i$ and sparsely connected with each other. Each neuron's firing rate $r_i(t)$ applies the `tanh` transfer function on an internal variable $x_i(t)$ which follows a first-order linear ordinary differential equation (ODE):
38 | #
39 | # $$
40 | # \tau \cdot \frac{d x_i(t)}{dt} = - x_i(t) + \sum_{j=1}^{N_i} W_{ij}^{in} \cdot y_j(t) + \sum_{j=1}^{N} W_{ij}^{rec} \cdot r_j(t) + I^{noise}_i(t)
41 | # $$
42 | #
43 | # $$
44 | # r_i(t) = \tanh(x_i(t))
45 | # $$
46 | #
47 | # The weights of the input matrix $W^{in}$ are taken from the normal distribution, with mean 0 and variance 1, and multiply the rates of the input neurons $y_j(t)$. The variance of these weights does not depend on the number of inputs, as only one input neuron is activated at the same time. The recurrent connectivity matrix $W^{rec}$ is sparse with a connection probability $pc = 0.1$ (i.e. 64000 non-zero elements) and existing weights are taken randomly from a normal distribution with mean 0 and variance $g/\sqrt{pc \cdot N}$, where $g$ is a scaling factor. It is a well known result for sparse recurrent networks that for high values of $g$, the network dynamics become chaotic. $I^{noise}_i(t)$ is an additive noise, taken randomly at each time step and for each neuron from a normal distribution with mean 0 and variance $I_0$. $I_0$ is chosen very small in the experiments reproduced here ($I_0 = 0.001$) but is enough to highlight the chaotic behavior of the recurrent neurons (non-reproducibility between two trials).
48 |
49 | # %% [markdown]
50 | # ### Read-out neurons
51 | #
52 | # The read-out neurons simply sum the activity of the recurrent neurons using a matrix $W^{out}$:
53 | #
54 | # $$
55 | # z_i(t) = \sum_{j=1}^{N} W_{ij}^{out} \cdot r_j(t)
56 | # $$
57 | #
58 | # The read-out matrix is initialized randomly from the normal distribution with mean 0 and variance $1/\sqrt{N}$.
59 |
60 | # %% [markdown]
61 | # ### Learning rule
62 | #
63 | # The particularity of the reservoir network proposed by *(Laje & Buonomano, Nature Neuroscience, 2013)* is that both the recurrent weights $W^{rec}$ and the read-out weights are trained in a supervised manner. More precisely, in this implementation, only 60% of the recurrent neurons have plastic weights, the 40% others keep the same weights throughout the simulation.
64 | #
65 | # Learning is done using the recursive least squares (RLS) algorithm *(Haykin, 2002)*. It is a supervised error-driven learning rule, i.e. the weight changes depend on the error made by each neuron: the difference between the firing rate of a neuron $r_i(t)$ and a desired value $R_i(t)$.
66 | #
67 | # $$
68 | # e_i(t) = r_i(t) - R_i(t)
69 | # $$
70 | #
71 | # For the recurrent neurons, the desired value is the recorded rate of that neuron during an initial trial (to enforce the reproducibility of the trajectory). For the read-out neurons, it is a function which we want the network to reproduce (e.g. handwriting as in Fig. 2).
72 | #
73 | # Contrary to the delta learning rule which modifies weights proportionally to the error and to the direct input to a synapse ($\Delta w_{ij} = - \eta \cdot e_i \cdot r_j$), the RLS learning uses a running estimate of the inverse correlation matrix of the inputs to each neuron:
74 | #
75 | # $$
76 | # \Delta w_{ij} = - e_i \sum_{k \in \mathcal{B}(i)} P^i_{jk} \cdot r_k
77 | # $$
78 | #
79 | # Each neuron $i$ therefore stores a square matrix $P^i$, whose size depends on the number of weights arriving at the neuron. Read-out neurons receive synapses from all $N$ recurrent neurons, so the $P$ matrix is $N*N$. Recurrent units have a sparse random connectivity ($pc = 0.1$), so each recurrent neuron stores only an $80*80$ matrix on average. In the previous equation, $\mathcal{B}(i)$ represents those existing weights.
80 | #
81 | # The inverse correlation matrix $P$ is updated at each time step with the following rule:
82 | #
83 | # $$
84 | # \Delta P^i_{jk} = - \frac{\sum_{m \in \mathcal{B}(i)} \sum_{n \in \mathcal{B}(i)} P^i_{jm} \cdot r_m \cdot r_n \cdot P^i_{nk} }{ 1 + \sum_{m \in \mathcal{B}(i)} \sum_{n \in \mathcal{B}(i)} r_m \cdot P^i_{mn} \cdot r_n}
85 | # $$
86 | #
87 | # Each matrix $P^i$ is initialized to the diagonal matrix and scaled by a factor $1/\delta$, where $\delta$ is 1 in the current implementation and can be used to implicitly modify the learning rate *(Sussillo & Abbott, Neuron, 2009)*.
88 |
89 | # %% [markdown]
90 | # **Matrix/Vector mode**
91 | #
92 | # The above algorithms can be written into the matrix/vector forms. For the recurrent units, each matrix $P^i$ has a different size ($80*80$ on average), so we will still need to iterate over all post-synaptic neurons. If we note $\mathbf{W}$ the vector of weights coming to a neuron (80 on average), $\mathbf{r}$ the corresponding vector of firing rates (also 80), $e$ the error of that neuron (a scalar) and $\mathbf{P}$ the inverse correlation matrix (80*80), the update rules become:
93 | # $$
94 | # \Delta \mathbf{W} = - e \cdot \mathbf{P} \cdot \mathbf{r}
95 | # $$
96 | #
97 | # $$
98 | # \Delta \mathbf{P} = - \frac{(\mathbf{P} \cdot \mathbf{r}) \cdot (\mathbf{P} \cdot \mathbf{r})^T}{1 + \mathbf{r}^T \cdot \mathbf{P} \cdot \mathbf{r}}
99 | # $$
100 | #
101 | # In the original Matlab code, one notices that the weight update $\Delta \mathbf{W}$ is also normalized by the denominator of the update rule for $\mathbf{P}$:
102 | #
103 | # $$
104 | # \Delta \mathbf{W} = - e \cdot \frac{\mathbf{P} \cdot \mathbf{r}}{1 + \mathbf{r}^T \cdot \mathbf{P} \cdot \mathbf{r}}
105 | # $$
106 | #
107 | # Removing this normalization from the learning rule impairs learning completely, so we kept this variant of the RLS rule in our implementation.
108 |
109 | # %% [markdown]
110 | # ### Training procedure
111 | #
112 | # The training procedure is split into different trials, which differ from one experiment to another (Fig. 1, Fig. 2 and Fig. 3). Each trial begins with a relaxation period of `t_offset` = 200 ms, followed by a brief input impulse of duration 50 ms and variable amplitude. This impulse has the effect of bringing all recurrent neurons into a deterministic state (due to the `tanh` transfer function, the rates are saturated at either +1 or -1). This impulse is followed by a training window of variable length (in the seconds range) and finally another relaxation period. In Fig. 1 and Fig. 2, an additional impulse (duration 10 ms, smaller amplitude) can be given a certain delay after the initial impulse to test the ability of the network to recover its acquired trajectory after learning.
113 | #
114 | # In all experiments, the first trial is used to acquire an innate trajectory for the recurrent neurons in the absence of noise ($I_0$ is set to 0). The firing rate of all recurrent neurons over the training window is simply recorded and stored in an array without applying the learning rules. This innate trajectory for each recurrent neuron is used in the following trials as the target for the RLS learning rule, this time in the presence of noise ($I_0 = 0.001$). The RLS learning rule itself is only applied to the recurrent neurons during the training window, not during the impulse or the relaxation periods. Such a learning trial is repeated 20 or 30 times depending on the experiments. Once the recurrent weights have converged and the recurrent neurons are able to reproduce the innate trajectory, the read-out weights are trained using a custom desired function as target (10 trials).
115 |
116 | # %% [markdown]
117 | # ### References
118 | #
119 | # - Laje, Rodrigo, and Dean V. Buonomano. "Robust timing and motor patterns by taming chaos in recurrent neural networks." Nature neuroscience 16, no. 7 (2013): 925-933.
120 | # - Haykin, Simon S. Adaptive filter theory. Pearson Education India, 2008.
121 | # - Sussillo, David, and Larry F. Abbott. "Generating coherent patterns of activity from chaotic neural networks." Neuron 63, no. 4 (2009): 544-557.
122 |
123 | # %% [markdown]
124 | # ## Implementation
125 |
126 | # %%
class RNN(bp.Base):
    """Rate-based recurrent network whose weights are trained with the RLS rule.

    The network has ``num_rec`` recurrent units with state ``x`` and rate
    ``r = tanh(x)``, driven by ``num_input`` input channels and read out by
    ``num_output`` linear units ``z = W_out . r``.  Row ``i`` of ``W_rec``
    holds the weights arriving at recurrent unit ``i``.

    Parameters
    ----------
    num_input, num_rec, num_output : int
        Sizes of the input, recurrent and read-out populations.
    tau : float
        Time constant used when integrating ``x``.
    g : float
        Scaling of the initial recurrent weights.
    pc : float
        Probability that a recurrent connection exists.
    noise : float
        Multiplier of the standard-normal noise added to ``dx`` in ``update``.
    delta : float
        The inverse correlation matrices start as ``identity / delta``.
    plastic_prob : float
        Probability that an existing recurrent connection is plastic.
    dt : float
        Integration step.
    """
    target_backend = 'numpy'

    def __init__(self, num_input=2, num_rec=800, num_output=1, tau=10.0,
                 g=1.5, pc=0.1, noise=0.001, delta=1.0, plastic_prob=0.6, dt=1.):
        super(RNN, self).__init__()

        # Copy the parameters
        self.num_input = num_input  # number of input neurons
        self.num_rec = num_rec  # number of recurrent neurons
        self.num_output = num_output  # number of read-out neurons
        self.tau = tau  # time constant of the neurons
        self.g = g  # synaptic strength scaling
        self.pc = pc  # connection probability
        self.noise = noise  # multiplier of the standard-normal noise in ``update``
        self.delta = delta  # P matrices are initialised to identity / delta
        self.plastic_prob = plastic_prob  # probability that an existing synapse is plastic
        self.dt = dt  # integration step

        # Initializes the network including the weight matrices.
        # ---

        # Recurrent population
        self.x = bm.Variable(bm.random.uniform(-1.0, 1.0, (self.num_rec,)))
        self.r = bm.Variable(bm.tanh(self.x))

        # Read-out population
        self.z = bm.Variable(bm.zeros(self.num_output, dtype=bm.float_))

        # Weights between the input and recurrent units
        self.W_in = bm.Variable(bm.random.randn(self.num_input, self.num_rec))

        # Weights between the recurrent units
        self.W_rec = bm.Variable(bm.random.randn(self.num_rec, self.num_rec) * self.g / bm.sqrt(self.pc * self.num_rec))

        # The connection pattern is sparse (probability ``pc``) and has no self-connections
        self.conn_mask = bm.random.random((self.num_rec, self.num_rec)) < self.pc
        diag = bm.arange(self.num_rec)
        self.conn_mask[diag, diag] = False
        self.W_rec *= bm.asarray(self.conn_mask, dtype=bm.float_)

        # Among the existing connections, draw the ones that are plastic
        train_mask = bm.random.random((self.num_rec, self.num_rec)) < plastic_prob
        self.train_mask = bm.logical_and(train_mask, self.conn_mask)

        # Store the indices of the plastic pre-synaptic neurons of each recurrent neuron
        self.W_plastic = bm.Variable([bm.array(bm.nonzero(self.train_mask[i])[0]) for i in range(self.num_rec)])

        # Inverse correlation matrix of inputs for learning recurrent weights
        # (one square matrix per neuron, sized by its number of plastic inputs)
        self.P_rec = bm.Variable([bm.identity(len(self.W_plastic[i])) / self.delta for i in range(self.num_rec)])

        # Output weights
        self.W_out = bm.Variable(bm.random.randn(self.num_output, self.num_rec) / bm.sqrt(self.num_rec))

        # Inverse correlation matrix of inputs for learning readout weights
        P_out = bm.expand_dims(bm.identity(self.num_rec) / self.delta, axis=0)
        self.P_out = bm.Variable(bm.repeat(P_out, self.num_output, axis=0))

        # loss
        self.loss = bm.Variable(bm.array([0.]))

    def simulate(self, stimulus, noise=True, target_traj=None,
                 learn_start=-1, learn_stop=-1, learn_readout=False):
        """Simulates the recurrent network for the given duration, with or without plasticity.

        The recurrent state is re-drawn at random and the loss is reset to zero
        at the start of every call.

        Parameters
        ----------

        stimulus : bm.ndarray
            The external inputs, one row per simulation step and one column
            per input channel.
        noise : bool
            if noise should be added to the recurrent units (default: True)
        target_traj : bm.ndarray
            During learning, the target indexed by simulation step: the rates of
            the recurrent units, or the read-out values when ``learn_readout``
            is True (default: no learning)
        learn_start : int, float
            Time when learning should start (compared against ``step * dt``).
        learn_stop : int, float
            Time when learning should stop (exclusive).
        learn_readout : bool
            Defines whether the recurrent (False) or readout (True) weights should be learned.

        Returns
        -------
        record_r : bm.ndarray
            Rates of the recurrent units at every step.
        record_z : bm.ndarray
            Read-out values at every step.
        loss : bm.Variable
            ``self.loss``, the sum of the mean squared errors of all learning
            steps of this call.
        """

        length, _ = stimulus.shape

        # Arrays for recording
        record_r = bm.zeros((length, self.num_rec), dtype=bm.float_)
        record_z = bm.zeros((length, self.num_output), dtype=bm.float_)

        # Reset the recurrent population
        self.x[:] = bm.random.uniform(-1.0, 1.0, (self.num_rec,))
        self.r[:] = bm.tanh(self.x)

        # Reset loss term
        self.loss[:] = 0.0

        # Simulate for the desired duration
        for i in range(length):
            # Update the neurons' firing rates
            self.update(stimulus[i, :], noise)

            # Recording
            record_r[i] = self.r
            record_z[i] = self.z

            # Learning: only inside the window, and only on every second step
            if target_traj is not None and learn_stop > i * self.dt >= learn_start and i % 2 == 0:
                if not learn_readout:
                    self.train_recurrent(target_traj[i])
                else:
                    self.train_readout(target_traj[i])

        return record_r, record_z, self.loss

    def update(self, stimulus, noise=True):
        """Updates neural variables for a single simulation step."""
        dx = -self.x + bm.dot(stimulus, self.W_in) + bm.dot(self.W_rec, self.r)
        if noise:
            dx += self.noise * bm.random.randn(self.num_rec)
        self.x += dx / self.tau * self.dt
        self.r[:] = bm.tanh(self.x)
        self.z[:] = bm.dot(self.W_out, self.r)

    def train_recurrent(self, target):
        """Apply the RLS learning rule to the recurrent weights.

        ``target`` holds the desired rate of every recurrent unit at the
        current step.
        """

        # Compute the error of the recurrent neurons
        error = self.r - target  # output_traj : (num_rec, )
        self.loss += bm.mean(error ** 2)

        # Apply the FORCE learning rule to the recurrent weights
        for i in range(self.num_rec):  # for each plastic post neuron
            # Get the rates from the plastic synapses only
            r_plastic = bm.expand_dims(self.r[self.W_plastic[i]], axis=1)
            # Multiply with the inverse correlation matrix P*R
            PxR = bm.dot(self.P_rec[i], r_plastic)
            # Normalization term 1 + R'*P*R
            RxPxR = (1. + bm.dot(r_plastic.T, PxR))
            # Update the inverse correlation matrix P <- P - ((P*R)*(P*R)')/(1+R'*P*R)
            self.P_rec[i][:] -= bm.dot(PxR, PxR.T) / RxPxR
            # Learning rule W <- W - e * (P*R)/(1+R'*P*R)
            for j, idx in enumerate(self.W_plastic[i]):
                self.W_rec[i, idx] -= error[i] * (PxR[j, 0] / RxPxR[0, 0])

    def train_readout(self, target):
        """Apply the RLS learning rule to the readout weights.

        ``target`` holds the desired value of every read-out unit at the
        current step.
        """

        # Compute the error of the output neurons
        error = self.z - target  # output_traj : (O, )
        # loss
        self.loss += bm.mean(error ** 2)
        # The rates are the same for every read-out neuron: build the column vector once
        r = bm.expand_dims(self.r, axis=1)
        # Apply the FORCE learning rule to the readout weights
        for i in range(self.num_output):  # for each readout neuron
            # Multiply the rates with the inverse correlation matrix P*R
            PxR = bm.dot(self.P_out[i], r)
            # Normalization term 1 + R'*P*R
            RxPxR = (1. + bm.dot(r.T, PxR))
            # Update the inverse correlation matrix P <- P - ((P*R)*(P*R)')/(1+R'*P*R)
            self.P_out[i] -= bm.dot(PxR, PxR.T) / RxPxR
            # Learning rule W <- W - e * (P*R)/(1+R'*P*R)
            self.W_out[i] -= error[i] * (PxR / RxPxR)[:, 0]

    def init(self):
        """Re-draw states and weights and reset the P matrices and the loss.

        The connectivity (``conn_mask``) and the set of plastic synapses
        (``W_plastic``) are kept as they are.
        """
        # Recurrent population
        self.x[:] = bm.random.uniform(-1.0, 1.0, (self.num_rec,))
        self.r[:] = bm.tanh(self.x)

        # Read-out population
        self.z[:] = bm.zeros((self.num_output,))

        # Weights between the input and recurrent units
        self.W_in[:] = bm.random.randn(self.num_input, self.num_rec)

        # Weights between the recurrent units (masked the same way as in __init__)
        self.W_rec[:] = bm.random.randn(self.num_rec, self.num_rec) * self.g / bm.sqrt(self.pc * self.num_rec)
        self.W_rec *= bm.asarray(self.conn_mask, dtype=bm.float_)

        # Inverse correlation matrix of inputs for learning recurrent weights
        for i in range(self.num_rec):
            self.P_rec[i][:] = bm.identity(len(self.W_plastic[i])) / self.delta

        # Output weights
        self.W_out[:] = bm.random.randn(self.num_output, self.num_rec) / bm.sqrt(self.num_rec)

        # Inverse correlation matrix of inputs for learning readout weights
        P_out = bm.expand_dims(bm.identity(self.num_rec) / self.delta, axis=0)
        self.P_out[:] = bm.repeat(P_out, self.num_output, axis=0)

        # loss
        self.loss[:] = 0.
317 |
318 |
319 | # %% [markdown]
320 | # ## Results
321 |
322 | # %%
323 | from Laje_Buonomano_2013_simulation import *
324 |
325 | # %% [markdown]
326 | # The ``simulation.py`` can be downloaded from [here](./Laje_Buonomano_2013_simulation.ipynb).
327 |
328 | # %% [markdown]
329 | # ### Fig. 1: an initially chaotic network can become deterministic after training
330 |
331 | # %%
# Network for Fig. 1: 2 input channels, 800 recurrent units, 1 read-out unit.
net = RNN(
    num_input=2, num_rec=800, num_output=1,  # population sizes
    tau=10.0, g=1.8, pc=0.1,  # time constant, weight scaling, connection probability
    noise=0.001, delta=1.0, plastic_prob=0.6,  # noise level, P-matrix scaling, plasticity probability
    dt=1.,  # integration step
)

# Swap the per-step methods for their JIT-wrapped versions.
net.update = bp.math.jit(net.update)
net.train_recurrent = bp.math.jit(net.train_recurrent)

fig1(net)
348 |
349 | # %% [markdown]
350 | # ### Fig. 2: read-out neurons robustly learn trajectories in the 2D space
351 |
352 | # %%
# Network for Fig. 2: 4 input channels, 800 recurrent units, 2 read-out units.
net = RNN(
    num_input=4, num_rec=800, num_output=2,  # population sizes
    tau=10.0, g=1.5, pc=0.1,  # time constant, weight scaling, connection probability
    noise=0.001, delta=1.0, plastic_prob=0.6,  # noise level, P-matrix scaling, plasticity probability
    dt=1.  # integration step
)

# Swap the per-step methods (including read-out training) for JIT-wrapped versions.
net.update = bp.math.jit(net.update)
net.train_recurrent = bp.math.jit(net.train_recurrent)
net.train_readout = bp.math.jit(net.train_readout)

fig2(net)
370 |
371 | # %% [markdown]
372 | # ### Fig. 3: the timing capacity of the recurrent network
373 |
374 | # %%
# Ten independently initialised networks for Fig. 3.
nets = []
for _ in range(10):
    net = RNN(
        num_input=2, num_rec=800, num_output=1,  # population sizes
        tau=10.0, g=1.5, pc=0.1,  # time constant, weight scaling, connection probability
        noise=0.001, delta=1.0, plastic_prob=0.6,  # noise level, P-matrix scaling, plasticity probability
        dt=1.  # integration step
    )
    # Swap the per-step methods for their JIT-wrapped versions.
    net.update = bp.math.jit(net.update)
    net.train_recurrent = bp.math.jit(net.train_recurrent)
    nets.append(net)

fig3(nets, verbose=False)
394 |
--------------------------------------------------------------------------------
/recurrent_networks/Laje_Buonomano_2013_simulation.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import time
3 |
4 | # %%
5 | import brainpy.math as bm
6 | import matplotlib.patches as patches
7 | import matplotlib.pyplot as plt
8 | import scipy.stats
9 | from mpl_toolkits.axes_grid1 import make_axes_locatable
10 | from scipy.io import loadmat
11 |
12 |
13 | # %%
14 | __all__ = [
15 | 'fig1',
16 | 'fig2',
17 | 'fig3',
18 | ]
19 |
20 |
21 | # %%
dt = 1.  # duration of one simulation step


def time2len(t):
    """Return the number of simulation steps of size ``dt`` contained in ``t``.

    The quotient is truncated towards zero by ``int``.
    """
    return int(t / dt)
24 |
25 |
26 | # %%
def fig1(net, verbose=True):
    """Run the Fig. 1 experiment on ``net`` and display the resulting figure.

    The protocol is: one noise-free trial whose recurrent rates become the
    training target, one noisy pre-training trial, several perturbed trials,
    training of the recurrent weights, training of the read-out weights on a
    Gaussian-bump target, and finally noisy test trials and perturbed trials.

    Parameters
    ----------
    net : object
        Network exposing ``num_input``, ``num_output`` and ``simulate``.
        Input channels 0 and 1 and output channel 0 are used.
    verbose : bool
        If true, progress messages and per-trial losses are printed.
    """

    def shade(axis, spans):
        # Cover the whole current y-range of ``axis`` with gray rectangles.
        # Each span is (x_start, width, alpha); the y-limits are read once,
        # before any rectangle is added.
        bottom, top = axis.get_ylim()
        for x_start, width, alpha in spans:
            axis.add_patch(patches.Rectangle((x_start, bottom), width, top - bottom, color='gray', alpha=alpha))

    def draw_units(axis, blue_traj, red_traj):
        # Rates of the first three recurrent units, stacked with a vertical offset.
        for unit in range(3):
            axis.plot(times, blue_traj[:, unit] + unit * 2 + 1, 'b')
            axis.plot(times, red_traj[:, unit] + unit * 2 + 1, 'r')
        axis.set_yticks([unit * 2 + 1 for unit in range(3)])
        axis.set_yticklabels([0, 0, 0])
        shade(axis, [stim_span, learn_span])

    def draw_readout(axis, traces, spans):
        # Read-out channel 0 of each trace, plus the target baseline as a black line.
        for trace, fmt in traces:
            axis.plot(times, trace[:, 0], *fmt)
        axis.axhline(output_traj[0, 0], c='k')
        axis.set_yticks([-2, -1, 0, 1, 2])
        axis.set_ylim((-2, 2))
        shade(axis, spans)

    # Parameters
    # ----------
    num_rec_train = 30  # learning trials for the recurrent weights
    num_readout_train = 10  # learning trials for the read-out weights
    num_perturbation = 5  # perturbation trials

    # Timeline of one trial
    t_offset = time2len(200)  # wait before the stimulation
    d_stim = time2len(50)  # stimulation
    d_trajectory = time2len(2000 + 150)  # learning window
    t_relax = time2len(550)  # relaxation after the learning window
    trial_duration = t_offset + d_stim + d_trajectory + t_relax
    times = bm.arange(0, trial_duration, dt)

    stimulus_amplitude = 5.0  # amplitude of the input pulse
    perturbation_amplitude = 0.5  # amplitude of the perturbation pulse
    t_perturb = time2len(500)  # offset of the perturbation
    d_perturb = time2len(10)  # duration of the perturbation

    # Gaussian bump used as read-out target
    target_baseline = 0.2
    target_amplitude = 1.
    target_width = 30.
    target_time = time2len(d_trajectory - 150)  # peak position inside the learning window

    # Input definitions
    # -----------------

    # Plain trial: a single pulse on channel 0
    impulse = bm.zeros((trial_duration, net.num_input))
    impulse[t_offset:t_offset + d_stim, 0] = stimulus_amplitude

    # Perturbed trial: the same pulse plus a short pulse on channel 1
    perturbation = bm.zeros((trial_duration, net.num_input))
    perturbation[t_offset: t_offset + d_stim, 0] = stimulus_amplitude
    perturbation[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 1] = perturbation_amplitude

    # Read-out target
    peak = t_offset + d_stim + target_time
    output_traj = bm.zeros((trial_duration, net.num_output))
    output_traj[:, 0] = target_baseline + (target_amplitude - target_baseline) * \
                        bm.exp(-(peak - times) ** 2 / target_width ** 2)

    # Learning window shared by both training phases
    window = dict(learn_start=t_offset + d_stim,
                  learn_stop=t_offset + d_stim + d_trajectory)

    # Main procedure
    # --------------
    tstart = time.time()

    # Noise-free trial: its recurrent rates become the training target
    if verbose:
        print('Initial trial to determine a target_traj (without noise)')
    initial_traj, initial_output, _ = net.simulate(stimulus=impulse, noise=False)

    # Noisy trial before any training
    if verbose:
        print('Pre-training test trial')
    pretrain_traj, pretrain_output, _ = net.simulate(stimulus=impulse, noise=True)

    # Perturbed trials before training
    if verbose:
        print(num_perturbation, 'perturbation trials')
    perturbation_initial = [net.simulate(stimulus=perturbation, noise=True)[1]
                            for _ in range(num_perturbation)]

    # Training of the recurrent weights (num_rec_train trials)
    for i in range(num_rec_train):
        t0 = time.time()
        if verbose:
            print(f'Learning trial recurrent {i + 1} loss: ', end='')
        _, _, loss = net.simulate(stimulus=impulse, target_traj=initial_traj, **window)
        if verbose:
            print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s')

    # Training of the read-out weights (num_readout_train trials)
    for i in range(num_readout_train):
        t0 = time.time()
        if verbose:
            print(f'Learning trial readout {i + 1} loss: ', end='')
        _, _, loss = net.simulate(stimulus=impulse, target_traj=output_traj,
                                  learn_readout=True, **window)
        if verbose:
            print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s')

    # Two noisy test trials after training
    if verbose:
        print('2 test trials')
    test_runs = [net.simulate(stimulus=impulse, noise=True) for _ in range(2)]
    reproductions = [run[0] for run in test_runs]
    final_outputs = [run[1] for run in test_runs]

    # Perturbed trials after training
    if verbose:
        print(num_perturbation, 'perturbation trials')
    perturbation_final = [net.simulate(stimulus=perturbation)[1]
                          for _ in range(num_perturbation)]

    if verbose:
        print('Simulation done in', time.time() - tstart, 'seconds.')

    # Visualization
    # -------------
    grid = (4, 2)
    stim_span = (t_offset, d_stim, 0.7)
    learn_span = (t_offset + d_stim, d_trajectory, 0.1)
    perturb_span = (t_offset + t_perturb, d_perturb, 0.7)
    t_first, t_last = times[0], times[-1]

    plt.figure(figsize=(8, 12))

    # Top row: innate trajectory of the first 100 recurrent units
    ax = plt.subplot2grid(grid, (0, 0), colspan=2)
    im = ax.imshow(initial_traj[:, :100].T, aspect='auto', origin='lower')
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    plt.colorbar(im, cax=cax)
    shade(ax, [(t_offset, d_stim, 0.2)])
    ax.set_title('Innate Trajectory')
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Recurrent units')

    # Left column: before training
    ax = plt.subplot2grid(grid, (1, 0))
    draw_units(ax, initial_traj, pretrain_traj)
    ax.set_title('Pre-training')
    ax.set_ylabel('Firing rate $r$ [Hz]')
    ax.set_xlim(t_first, t_last)

    ax = plt.subplot2grid(grid, (2, 0))
    draw_readout(ax,
                 [(initial_output, ('b',)), (pretrain_output, ('r',))],
                 [stim_span, learn_span])
    ax.set_ylabel('Output (test)')
    ax.set_xlim(t_first, t_last)

    ax = plt.subplot2grid(grid, (3, 0))
    draw_readout(ax,
                 [(output, ()) for output in perturbation_initial],
                 [stim_span, perturb_span, learn_span])
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Output (perturbed)')
    ax.set_xlim(t_first, t_last)

    # Right column: after training
    ax = plt.subplot2grid(grid, (1, 1))
    draw_units(ax, reproductions[0], reproductions[1])
    ax.set_title('Post-training')
    ax.set_xlim(t_first, t_last)

    ax = plt.subplot2grid(grid, (2, 1))
    draw_readout(ax,
                 [(final_outputs[0], ('b',)), (final_outputs[1], ('r',))],
                 [stim_span, learn_span])
    ax.set_xlim(t_first, t_last)

    ax = plt.subplot2grid(grid, (3, 1))
    draw_readout(ax,
                 [(output, ()) for output in perturbation_final],
                 [stim_span, perturb_span, learn_span])
    ax.set_xlabel('Time (ms)')
    ax.set_xlim(t_first, t_last)

    plt.tight_layout()
    plt.show()
222 |
223 |
224 | # %%
def fig2(net, verbose=True):
    """Run the Fig. 2 experiment on ``net`` and display the resulting figure.

    Two 2-D target trajectories, stored under the keys ``'chaos'`` and
    ``'neuron'`` of ``data/DAC_handwriting_output_targets.mat``, are learned by
    the same network.  Each trajectory is triggered by its own input channel
    (0 for "chaos", 2 for "neuron") and has its own perturbation channel
    (1 and 3).  The recurrent weights are trained first, on the innate
    noise-free trajectories, then the read-out weights on the loaded targets;
    test and perturbed trials are then simulated and plotted.

    Parameters
    ----------
    net : object
        Network exposing ``num_input``, ``num_output`` and ``simulate``.
        Input channels 0-3 and output channels 0 and 1 are used.
    verbose : bool
        If true, progress messages and per-trial losses are printed.
    """

    # Parameters
    # ----------

    num_rec_train = 30  # Number of learning trials for the recurrent weights
    num_readout_train = 10  # Number of learning trials for the readout weights
    num_test = 5  # Number of test trials
    num_perturb = 5  # Number of perturbation trials

    stimulus_amplitude = 2.0  # Amplitude of the input pulse
    t_offset = time2len(200)  # Time to wait before the stimulation
    d_stim = time2len(50)  # Duration of the stimulation
    t_relax = time2len(150)  # Duration to relax after the target_traj

    perturbation_amplitude = 0.2  # Amplitude of the perturbation pulse
    t_perturb = time2len(300)  # Offset for the perturbation
    d_perturb = time2len(10)  # Duration of the perturbation

    # Input definitions
    # -----------------

    # Retrieve the targets; the second axis of each array is time
    targets = loadmat('data/DAC_handwriting_output_targets.mat')
    chaos = targets['chaos']
    neuron = targets['neuron']

    # Durations
    _, d_chaos = chaos.shape
    _, d_neuron = neuron.shape

    # Impulses
    impulse_chaos = bm.zeros((t_offset + d_stim + d_chaos + t_relax, net.num_input))
    impulse_chaos[t_offset:t_offset + d_stim, 0] = stimulus_amplitude
    impulse_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_input))
    impulse_neuron[t_offset:t_offset + d_stim, 2] = stimulus_amplitude

    # Perturbation
    perturbation_chaos = bm.zeros((t_offset + d_stim + d_chaos + t_relax, net.num_input))
    perturbation_chaos[t_offset:t_offset + d_stim, 0] = stimulus_amplitude
    perturbation_chaos[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 1] = perturbation_amplitude
    perturbation_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_input))
    perturbation_neuron[t_offset:t_offset + d_stim, 2] = stimulus_amplitude
    perturbation_neuron[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 3] = perturbation_amplitude

    # Targets (zero outside the learning window)
    target_chaos = bm.zeros((t_offset + d_stim + d_chaos + t_relax, net.num_output))
    target_chaos[t_offset + d_stim: t_offset + d_stim + d_chaos, :] = chaos.T
    target_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_output))
    target_neuron[t_offset + d_stim: t_offset + d_stim + d_neuron, :] = neuron.T

    # Main procedure
    # --------------
    tstart = time.time()

    # Initial trial to determine the innate target_traj
    if verbose:
        print('Initial chaos trial')
    trajectory_chaos, initial_chaos_output, _ = net.simulate(stimulus=impulse_chaos, noise=False)
    if verbose:
        print('Initial neuron trial')
    trajectory_neuron, initial_neuron_output, _ = net.simulate(stimulus=impulse_neuron, noise=False)

    # learning for the recurrent weights
    for i in range(num_rec_train):
        t0 = time.time()
        if verbose:
            print(f'Learning recurrent {i + 1} "chaos" loss: ', end='')
        _, _, loss = net.simulate(stimulus=impulse_chaos,
                                  target_traj=trajectory_chaos,
                                  learn_start=t_offset + d_stim,
                                  learn_stop=t_offset + d_stim + d_chaos)
        if verbose:
            print(f'{(2 * loss[0] / d_chaos):.5f} used {(time.time() - t0):5f} s, ', end='')

        t0 = time.time()
        _, _, loss = net.simulate(stimulus=impulse_neuron,
                                  target_traj=trajectory_neuron,
                                  learn_start=t_offset + d_stim,
                                  learn_stop=t_offset + d_stim + d_neuron)
        # The "neuron" trial is normalised by its own window length
        if verbose:
            print(f'"neuron" loss: {(2 * loss[0] / d_neuron):5f} used {(time.time() - t0):5f} s')

    # learning for the readout weights
    for i in range(num_readout_train):
        t0 = time.time()
        if verbose:
            print(f'Learning readout {i + 1} "chaos" loss: ', end='')
        _, _, loss = net.simulate(stimulus=impulse_chaos,
                                  target_traj=target_chaos,
                                  learn_start=t_offset + d_stim,
                                  learn_stop=t_offset + d_stim + d_chaos,
                                  learn_readout=True)
        if verbose:
            print(f'{(2 * loss[0] / d_chaos):.5f} used {(time.time() - t0):5f} s, ', end='')

        t0 = time.time()
        _, _, loss = net.simulate(stimulus=impulse_neuron,
                                  target_traj=target_neuron,
                                  learn_start=t_offset + d_stim,
                                  learn_stop=t_offset + d_stim + d_neuron,
                                  learn_readout=True)
        # The "neuron" trial is normalised by its own window length
        if verbose:
            print(f'"neuron" loss: {(2 * loss[0] / d_neuron):5f} used {(time.time() - t0):5f} s')

    # Test trials
    final_output_chaos = []
    final_output_neuron = []
    for _ in range(num_test):
        if verbose:
            print('Test chaos trial')
        _, o, _ = net.simulate(stimulus=impulse_chaos)
        final_output_chaos.append(o)
        if verbose:
            print('Test neuron trial')
        _, o, _ = net.simulate(stimulus=impulse_neuron)
        final_output_neuron.append(o)

    # Perturbation trials
    perturbation_output_chaos = []
    perturbation_output_neuron = []
    for _ in range(num_perturb):
        if verbose:
            print('Perturbation chaos trial')
        _, o, _ = net.simulate(stimulus=perturbation_chaos)
        perturbation_output_chaos.append(o)
        if verbose:
            print('Perturbation neuron trial')
        _, o, _ = net.simulate(stimulus=perturbation_neuron)
        perturbation_output_neuron.append(o)

    if verbose:
        print('Simulation done in', time.time() - tstart, 'seconds.')

    # Visualization
    # -------------

    # Up to 20 evenly spaced step indices inside each learning window (blue dots)
    subsampling_chaos = (t_offset + d_stim + bm.linspace(0, d_chaos, 20)).astype(bm.int32)
    subsampling_chaos = bm.unique(subsampling_chaos)
    subsampling_neuron = (t_offset + d_stim + bm.linspace(0, d_neuron, 20)).astype(bm.int32)
    subsampling_neuron = bm.unique(subsampling_neuron)

    def plot_panel(position, reference, outputs, duration, samples, ylabel):
        # One panel: the reference curve, then every trial in ``outputs``
        # (so each list is drawn in full, whatever its length) with its
        # subsampled points marked.
        ax = plt.subplot2grid((2, 2), position)
        ax.plot(reference[0, :], reference[1, :], linewidth=2.)
        start = t_offset + d_stim
        for out in outputs:
            ax.plot(out[start: start + duration, 0],
                    out[start: start + duration, 1])
            ax.plot(out[samples, 0], out[samples, 1], 'bo')
        ax.set_xlabel('x')
        ax.set_ylabel(ylabel)

    plt.figure(figsize=(12, 8))
    plot_panel((0, 0), chaos, final_output_chaos, d_chaos,
               subsampling_chaos, 'Without perturbation\ny')
    plot_panel((0, 1), neuron, final_output_neuron, d_neuron,
               subsampling_neuron, 'y')
    plot_panel((1, 0), chaos, perturbation_output_chaos, d_chaos,
               subsampling_chaos, 'With perturbation\ny')
    plot_panel((1, 1), neuron, perturbation_output_neuron, d_neuron,
               subsampling_neuron, 'y')

    plt.show()
395 |
396 |
397 | # %%
def fig3(nets, verbose=True):
    """Measure how well networks learn timing intervals of increasing length.

    For each delay, every network in ``nets`` is re-initialized and then
    (1) run once without noise to record its innate trajectory, (2) trained
    on its recurrent weights to reproduce that trajectory, (3) trained on its
    readout weights to produce a Gaussian-shaped output around the target
    time, and (4) tested. The squared Pearson correlation between the test
    output and the target is plotted against the delay as the mean with the
    standard error over the networks.

    Parameters
    ----------
    nets : sequence
        Networks to evaluate. The sequence is iterated once per delay. Each
        network must provide ``init()``, ``simulate(...)``, ``num_input``
        and ``num_output``.
    verbose : bool
        Whether to print progress and the loss of every learning trial.
    """
    # Parameters
    # ----------

    num_rec_train = 20  # Number of learning trials for the recurrent weights
    num_readout_train = 10  # Number of learning trials for the readout weights

    stimulus_amplitude = 5.0  # Amplitude of the input pulse
    t_offset = time2len(200)  # Time to wait before the stimulation
    d_stim = time2len(50)  # Duration of the stimulation
    t_relax = time2len(150)  # Duration to relax after the target_traj

    target_baseline = 0.2  # Baseline of the output_traj function
    target_amplitude = 1.  # Maximal value of the output_traj function
    target_width = 30.  # Width of the Gaussian

    # Main procedure
    # --------------

    # Vary the timing interval
    delays = [250, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000]

    # Store the Pearson correlation coefficients (one row per delay)
    pearsons = []

    # Iterate over the delays
    for target_time in delays:
        if verbose:
            print('*' * 60)
            print('Learning a delay of', target_time)
            print('*' * 60)
            print()
        d_trajectory = time2len(target_time + 150)  # Duration of the desired target_traj
        trial_duration = t_offset + d_stim + d_trajectory + t_relax  # Total duration of a trial
        pearsons.append([])

        # Learning (and scoring) window shared by all phases of this delay
        learn_start = t_offset + d_stim
        learn_stop = t_offset + d_stim + d_trajectory

        for n, net in enumerate(nets):
            if verbose: print(f'## Network {n + 1} ##', )
            net.init()

            # Impulse input after 200 ms
            impulse = bm.zeros((trial_duration, net.num_input))
            impulse[t_offset:t_offset + d_stim, 0] = stimulus_amplitude

            # Target output for learning the readout weights
            target = bm.zeros((trial_duration, net.num_output))
            time_axis = bm.linspace(0, trial_duration, trial_duration)
            target[:, 0] = target_baseline + (target_amplitude - target_baseline) * \
                           bm.exp(-(t_offset + d_stim + target_time - time_axis) ** 2 / target_width ** 2)

            # Initial trial to determine the innate target_traj
            if verbose: print('Initial trial to determine a target_traj (without noise)')
            trajectory, _, _ = net.simulate(stimulus=impulse, noise=False)

            # Learning trials for the recurrent weights
            for i in range(num_rec_train):
                t0 = time.time()
                if verbose: print(f'Learning trial recurrent {i + 1} loss: ', end='')
                _, _, loss = net.simulate(stimulus=impulse,
                                          target_traj=trajectory,
                                          learn_start=learn_start,
                                          learn_stop=learn_stop)
                if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s')

            # Learning trials for the readout weights
            for i in range(num_readout_train):
                t0 = time.time()
                if verbose: print(f'Learning trial readout {i + 1} loss: ', end='')
                _, _, loss = net.simulate(stimulus=impulse,
                                          target_traj=target,
                                          learn_start=learn_start,
                                          learn_stop=learn_stop,
                                          learn_readout=True)
                if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s')

            # Test trial
            if verbose: print('Test trial')
            _, final_output, _ = net.simulate(stimulus=impulse)

            # Pearson correlation coefficient over the learning window
            pred = final_output[learn_start:learn_stop, 0]
            desired = target[learn_start:learn_stop, 0]
            r, _ = scipy.stats.pearsonr(desired, pred)
            pearsons[-1].append(r)

    # Save the results
    pearsons = bm.asarray(pearsons)

    # Visualization
    # -------------

    plt.figure(figsize=(8, 6))
    correlation_mean = bm.mean(pearsons ** 2, axis=1)
    correlation_std = bm.std(pearsons ** 2, axis=1)
    # Standard error of the mean: divide by the square root of the number of
    # networks actually evaluated rather than assuming there are ten of them.
    num_nets = pearsons.shape[1]
    plt.errorbar(bm.array(delays) / 1000., correlation_mean, correlation_std / bm.sqrt(num_nets),
                 linestyle='-', marker='^')
    plt.xlim((0., 8.5))
    plt.ylim((-0.1, 1.1))
    plt.xlabel('Interval (s)')
    plt.ylabel('Performance ($R^2$)')
    plt.show()
498 |
--------------------------------------------------------------------------------
/recurrent_networks/data/DAC_handwriting_output_targets.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/recurrent_networks/data/DAC_handwriting_output_targets.mat
--------------------------------------------------------------------------------
/recurrent_networks/fixed_points_finder.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # %% [markdown]
3 | # # Find Fixed Points
4 |
5 | # %% [markdown]
6 | # The goal of this tutorial is to learn about fixed point finding by running the algorithm on a simple data generator: a Gated Recurrent Unit (GRU) trained to make a binary decision, namely whether the integral of its white-noise input is positive or negative in total, outputting +1 or -1 accordingly.
7 |
8 | # %% [markdown]
9 | # In this tutorial we do a few things:
10 | #
11 | # - Train the decision making GRU
12 | # - Find the fixed points of the GRU
13 |
14 | # %%
15 | import brainpy as bp
16 | import brainpy.math as bm
17 | bm.set_platform('cpu')
18 |
19 | # %%
20 | import time
21 | from functools import partial
22 |
23 | import numpy as np
24 | import matplotlib.pyplot as plt
25 | import matplotlib.lines as mlines
26 |
27 | # %% [markdown]
28 | # ## Parameters
29 |
30 | # %%
# Integration parameters

T = 1.0  # Total simulated time (arbitrary units, roughly physiological)
dt = 0.04  # Width of one time bin
num_step = int(T / dt)  # Number of bins that T is divided into
bval = 0.01  # Limit on the input bias value
sval = 0.025  # Noise standard deviation (before dividing by sqrt(dt))

# %%
# Optimization hyperparameters

l2reg = 2e-5  # Strength of the L2 regularization on the weights
num_train = 20_000  # Total number of batches to train on
num_batch = 150  # Number of examples in each batch
# Gradient clipping is HUGELY important for training RNNs:
# gradients whose norm exceeds this value are clipped to it.
max_grad_norm = 10.0
48 |
49 |
50 | # %% [markdown]
51 | # ## Helpers
52 |
53 | # %%
def plot_examples(num_time, inputs, hiddens, outputs, targets, num_example=1,
                  num_plot=10, start_id=0):
    """Draw input, hidden and output traces for a run of consecutive examples.

    The figure has three rows (inputs, hidden units, outputs) and one column
    per example, for the batch indices ``start_id`` up to
    ``start_id + num_example - 1``.

    Parameters
    ----------
    num_time: int, upper limit of the time axis
    inputs: ndarray of (num_time, num_batch)
    hiddens: ndarray of (num_time, num_batch, num_hidden)
    outputs: ndarray of (num_time, num_batch, num_output)
    targets: ndarray of (num_time, num_batch, num_output)
    num_example: int, number of examples (columns) to draw
    num_plot: int, number of hidden units shown per example
    start_id: int, batch index of the first example
    """
    plt.figure(figsize=(num_example * 5, 14))
    batch_ids = range(start_id, start_id + num_example)

    # Top row: the input signal of each example.
    for column, batch_id in enumerate(batch_ids, start=1):
        plt.subplot(3, num_example, column)
        plt.plot(inputs[:, batch_id], 'k')
        plt.xlim([0, num_time])
        plt.title('Example %d' % batch_id)
        if batch_id == 0:
            plt.ylabel('Input Units')

    # Middle row: the first `num_plot` hidden units, each shifted vertically
    # by a fixed amount so the traces do not overlap.
    closeness = 0.25
    for column, batch_id in enumerate(batch_ids, start=num_example + 1):
        plt.subplot(3, num_example, column)
        plt.plot(hiddens[:, batch_id, 0:num_plot] + closeness * np.arange(num_plot), 'b')
        plt.xlim([0, num_time])
        if batch_id == 0:
            plt.ylabel('Hidden Units')

    # Bottom row: predicted output against the target.
    for column, batch_id in enumerate(batch_ids, start=2 * num_example + 1):
        plt.subplot(3, num_example, column)
        plt.plot(outputs[:, batch_id, :], 'r', label='predict')
        plt.plot(targets[:, batch_id, :], 'k', label='target')
        plt.xlim([0, num_time])
        plt.xlabel('Time steps')
        plt.legend()
        if batch_id == 0:
            plt.ylabel('Output Units')

    plt.show()
92 |
93 |
94 | # %%
def plot_params(rnn, show_top_eig=0):
    """Plot the parameters of a GRU and the eigenvalues of its Jacobian at ``h0``.

    Six panels are drawn: the first column of the readout weights, the
    initial hidden state, the stacked gate weights, the stacked candidate
    weights, the gate biases, and the eigenvalues of ``dF/dh`` evaluated at
    ``h0`` with zero input (shown together with the unit circle).

    Parameters
    ----------
    rnn : GRU
        The network whose parameters are plotted.
    show_top_eig : int
        If positive, print this many of the largest real parts of the
        eigenvalues.
    """
    assert isinstance(rnn, GRU)

    plt.figure(figsize=(16, 8))
    plt.subplot(231)
    plt.stem(rnn.w_ro.numpy()[:, 0])
    plt.title('W_ro - output weights')

    plt.subplot(232)
    # Converted to NumPy like every other parameter handed to matplotlib here.
    plt.stem(rnn.h0.numpy())
    plt.title('h0 - initial hidden state')

    a = bm.concatenate([rnn.w_ir, rnn.w_iz], axis=0)
    b = bm.concatenate([rnn.w_hr, rnn.w_hz], axis=0)
    c = bm.concatenate([a, b], axis=0).numpy()

    plt.subplot(233)
    plt.imshow(c, interpolation=None)
    plt.colorbar()
    plt.title('[W_ir, W_iz, W_hr, W_hz]')

    plt.subplot(234)
    a = bm.concatenate([rnn.w_ia, rnn.w_ha], axis=0).numpy()
    plt.imshow(a, interpolation=None)
    plt.colorbar()
    plt.title('[W_ia, W_ha]')

    plt.subplot(235)
    plt.stem(bm.concatenate([rnn.bz, rnn.br]).numpy())
    plt.title('[bz, br] - recurrent biases')

    plt.subplot(236)
    # Jacobian of one cell step with respect to the hidden state, at h0 and zero input.
    dFdh = bm.jacobian(rnn.cell)(rnn.h0.value, bm.zeros(rnn.num_input))
    evals, _ = np.linalg.eig(bm.as_numpy(dFdh))
    # Unit circle (upper and lower halves) as a reference for the eigenvalues.
    x = np.linspace(-1, 1, 1000)
    plt.plot(x, np.sqrt(1 - x ** 2), 'k')
    plt.plot(x, -np.sqrt(1 - x ** 2), 'k')
    plt.plot(np.real(evals), np.imag(evals), '.')

    if show_top_eig > 0:
        print(np.sort(np.real(evals))[-show_top_eig:])

    plt.axis('equal')
    # Raw strings: '\l' is not a valid escape sequence in a normal string literal.
    plt.xlabel(r'Real($\lambda$)')
    plt.ylabel(r'Imaginary($\lambda$)')
    plt.title('Eigenvalues of $dF/dh(h_0)$')

    plt.show()
144 |
145 |
146 | # %%
def plot_data(num_time, inputs, targets=None, outputs=None, errors=None, num_plot=10):
    """Plot some white noise / integrated white noise examples.

    The first panel always shows the inputs. A second panel is added when
    ``targets`` and/or ``outputs`` are given, and a last panel when ``errors``
    is given. Every array is indexed as ``[:, 0:num_plot, 0]``, i.e. the first
    axis is drawn along the x-axis and the second axis selects the examples.

    Parameters
    ----------
    num_time : int
      Upper limit of the x-axis.
    num_plot : int
      Number of examples drawn in each panel.
    inputs: ndarray
      with the shape of (num_time, num_batch, num_input)
    targets: ndarray
      with the shape of (num_time, num_batch, num_output)
    outputs: ndarray
      with the shape of (num_time, num_batch, num_output)
    errors: ndarray
      with the shape of (num_time, num_batch, num_output)
    """
    # Count the panels: inputs, optional integration panel, optional error panel.
    num = 1
    if errors is not None: num += 1
    if (targets is not None) or (outputs is not None): num += 1
    plt.figure(figsize=(14, 4 * num))

    # inputs
    plt.subplot(num, 1, 1)
    plt.plot(inputs[:, 0:num_plot, 0])
    plt.xlim([0, num_time])
    plt.ylabel('Noise')

    # predictions and targets share the second panel
    legends = []
    if outputs is not None:
        plt.subplot(num, 1, 2)
        plt.plot(outputs[:, 0:num_plot, 0])
        plt.xlim([0, num_time])
        legends.append(mlines.Line2D([], [], color='k', linestyle='-', label='predict'))
    if targets is not None:
        plt.subplot(num, 1, 2)
        plt.plot(targets[:, 0:num_plot, 0], '--')
        plt.xlim([0, num_time])
        plt.ylabel("Integration")
        legends.append(mlines.Line2D([], [], color='k', linestyle='--', label='target'))
    if legends: plt.legend(handles=legends)

    if errors is not None:
        # The error panel is always the last one. Index `num` (instead of a
        # fixed 3) keeps it valid when neither targets nor outputs are given.
        plt.subplot(num, 1, num)
        plt.plot(errors[:, 0:num_plot, 0], '--')
        plt.xlim([0, num_time])
        plt.ylabel("|Errors|")

    plt.xlabel('Time steps')
    plt.show()
196 |
197 |
198 | # %%
@partial(bm.jit, dyn_vars={'a': bm.random.DEFAULT}, static_argnames=('num_batch', 'num_step', 'dt'))
def build_inputs_and_targets(mean, scale, num_batch, num_step, dt):
    """Generate biased white-noise inputs and the matching decision targets.

    Returns a pair ``(inputs, targets)``, both of shape
    ``(num_step, num_batch, 1)``. The targets are zero everywhere except at
    the last step, where they are +1 or -1 depending on the sign of the
    cumulative sum of the input.
    """
    # Draw order matters for the random state: per-example bias first,
    # then the per-step noise.
    bias_draw = bm.random.normal(size=(num_batch,))
    offset = mean * 2.0 * (bias_draw - 0.5)
    noise_draw = bm.random.normal(size=(num_step, num_batch))
    signal = offset + scale / dt ** 0.5 * noise_draw
    inputs_txbx1 = bm.expand_dims(signal, axis=2)

    # The running sum is intentionally not multiplied by dt, which keeps the
    # output scaling in O(1).
    running_sum = bm.expand_dims(bm.cumsum(signal, axis=0), axis=2)
    targets_txbx1 = bm.zeros_like(running_sum)
    # Map the sign of the final sum from {False, True} to {-1, +1}.
    targets_txbx1[-1] = 2.0 * ((running_sum[-1] > 0.0) - 0.5)
    return inputs_txbx1, targets_txbx1
217 |
218 |
219 | # %%
220 | # # Plot the example inputs and targets for the RNN.
221 | # _ints, _outs = build_inputs_and_targets(bval, sval, num_batch=num_batch, num_step=num_step, dt=dt)
222 |
223 | # plot_data(num_step, inputs=_ints, targets=_outs)
224 |
225 | # %% [markdown]
226 | # ## Model
227 |
228 | # %%
class GRU(bp.DynamicalSystem):
    """Gated recurrent unit with a linear readout and a learnable initial state.

    Parameters
    ----------
    num_hidden : int
        Number of hidden units.
    num_input : int
        Dimension of the input at each step.
    num_output : int
        Dimension of the readout.
    num_batch : int
        Number of examples processed in parallel; fixes the shape of ``h``.
    g : float
        Scale factor of the initial hidden-to-hidden weights.
    l2_reg : float
        Coefficient of the L2 penalty used in :meth:`loss`.
    forget_bias : float
        Stored on the instance; not used by :meth:`cell`.
    """

    def __init__(self, num_hidden, num_input, num_output, num_batch,
                 g=1.0, l2_reg=0., forget_bias=0.5, **kwargs):
        super(GRU, self).__init__(**kwargs)

        # parameters
        self.l2_reg = l2_reg
        self.num_input = num_input
        self.num_batch = num_batch
        self.num_hidden = num_hidden
        self.num_output = num_output
        self.rng = bm.random.RandomState()
        self.forget_bias = forget_bias

        # recurrent weights
        self.w_iz = bm.TrainVar(self.rng.normal(scale=1 / num_input ** 0.5, size=(num_input, num_hidden)))
        self.w_ir = bm.TrainVar(self.rng.normal(scale=1 / num_input ** 0.5, size=(num_input, num_hidden)))
        self.w_ia = bm.TrainVar(self.rng.normal(scale=1 / num_input ** 0.5, size=(num_input, num_hidden)))
        self.w_hz = bm.TrainVar(self.rng.normal(scale=g / num_hidden ** 0.5, size=(num_hidden, num_hidden)))
        self.w_hr = bm.TrainVar(self.rng.normal(scale=g / num_hidden ** 0.5, size=(num_hidden, num_hidden)))
        self.w_ha = bm.TrainVar(self.rng.normal(scale=g / num_hidden ** 0.5, size=(num_hidden, num_hidden)))
        self.bz = bm.TrainVar(bm.zeros((num_hidden,)))
        self.br = bm.TrainVar(bm.zeros((num_hidden,)))
        self.ba = bm.TrainVar(bm.zeros((num_hidden,)))
        # learnable initial hidden state, shared by every example in a batch
        self.h0 = bm.TrainVar(self.rng.normal(scale=0.1, size=(num_hidden,)))

        # output weights
        self.w_ro = bm.TrainVar(self.rng.normal(scale=1 / num_hidden ** 0.5, size=(num_hidden, num_output)))
        self.b_ro = bm.TrainVar(bm.zeros((num_output,)))

        # variables
        self.h = bm.Variable(self.rng.normal(scale=0.1, size=(num_batch, self.num_hidden)))
        # `o` is overwritten by the update function inside the loop built in
        # `predict`, so it is declared as a Variable just like `h`; that way
        # it is part of `self.vars()` and tracked by the loop.
        self.o = bm.Variable(self.h @ self.w_ro + self.b_ro)

        # loss components of the most recent call to `loss`
        self.total_loss = bm.Variable(bm.zeros(1))
        self.l2_loss = bm.Variable(bm.zeros(1))
        self.mse_loss = bm.Variable(bm.zeros(1))

    def cell(self, h, x):
        """Return the next hidden state given the state ``h`` and the input ``x``."""
        r = bm.sigmoid(x @ self.w_ir + h @ self.w_hr + self.br)
        z = bm.sigmoid(x @ self.w_iz + h @ self.w_hz + self.bz)
        a = bm.tanh(x @ self.w_ia + (r * h) @ self.w_ha + self.ba)
        return (1. - z) * h + z * a

    def readout(self, h):
        """Return the linear readout of the hidden state ``h``."""
        return h @ self.w_ro + self.b_ro

    def make_update(self, h: bm.JaxArray, o: bm.JaxArray):
        """Build a one-step update function that writes into ``h`` and ``o`` in place."""

        def update(x):
            h.value = self.cell(h, x)
            o.value = self.readout(h)

        return update

    def predict(self, xs):
        """Run the network over the sequence ``xs`` starting from ``h0``.

        Returns the hidden states and the outputs collected at every step.
        """
        # Reset every example in the batch to the learned initial state.
        self.h[:] = self.h0
        f = bm.make_loop(self.make_update(self.h, self.o),
                         dyn_vars=self.vars().unique(),
                         out_vars=[self.h, self.o])
        return f(xs)

    def loss(self, xs, ys):
        """Return the total loss: L2 penalty plus the MSE at the last step.

        The three loss values are also stored in ``total_loss``, ``l2_loss``
        and ``mse_loss``.
        """
        _, outputs = self.predict(xs)
        l2 = self.l2_reg * bm.losses.l2_norm(self.train_vars().dict()) ** 2
        # Only the final step is compared with the target.
        mse = bm.losses.mean_squared_error(outputs[-1], ys[-1])
        total = l2 + mse
        self.total_loss[0] = total
        self.l2_loss[0] = l2
        self.mse_loss[0] = mse
        return total
300 |
301 |
302 | # %%
# Network to be trained on the decision task.
net = GRU(
    num_hidden=100,
    num_input=1,
    num_output=1,
    num_batch=num_batch,
    l2_reg=l2reg,
)

# plot_params(net)

# %%
# Adam optimizer driven by an exponentially decaying learning rate.
lr = bp.optim.ExponentialDecay(
    lr=0.04,
    decay_steps=1,
    decay_rate=0.9999,
)
optimizer = bp.optim.Adam(
    lr=lr,
    train_vars=net.train_vars(),
    eps=1e-1,
)
310 |
311 |
@bm.jit
@bm.function(nodes=(net, optimizer))
def train(inputs, targets):
    """Perform one optimization step on a batch and return its loss."""
    loss_and_grads = bm.grad(net.loss,
                             dyn_vars=net.vars(),
                             grad_vars=net.train_vars(),
                             return_value=True)
    gradients, batch_loss = loss_and_grads(inputs, targets)
    # Clip the gradients by norm before handing them to the optimizer.
    optimizer.update(bm.clip_by_norm(gradients, max_grad_norm))
    return batch_loss
320 |
321 |
322 | # %% [markdown]
323 | # ## Training
324 |
325 | # %%
# Start of training, used to report the elapsed wall-clock time.
t0 = time.time()
# Loss components recorded during training, plotted afterwards.
train_losses = {'total': [], 'l2': [], 'mse': []}
for i in range(num_train):
    # Each step is trained on a freshly generated batch.
    _ins, _outs = build_inputs_and_targets(bval, sval, num_batch=num_batch, num_step=num_step, dt=dt)
    loss = train(inputs=_ins, targets=_outs)
    # Report progress every 400 batches.
    # NOTE(review): the losses are assumed to be recorded at the same interval
    # as the progress message — confirm the nesting against the original file.
    if (i + 1) % 400 == 0:
        print(f"Run batch {i + 1} in {time.time() - t0:0.3f} s, learning rate: {lr():.5f}, training loss {loss:0.4f}")
        train_losses['total'].append(net.total_loss[0])
        train_losses['l2'].append(net.l2_loss[0])
        train_losses['mse'].append(net.mse_loss[0])
336 |
337 | # %%
338 | # net.save_states('./data/fixed_points-80.h5')
339 | # net.save_states('./data/fixed_points-1200.h5')
340 | # net.load_states('./data/fixed_points-80.h5')
341 |
342 | # %%
# Show the recorded loss components through training, one panel each.
plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.plot(train_losses['total'], 'k')
plt.title('Total loss')
plt.xlabel('Trial')  # was misspelled 'Trail'

plt.subplot(132)
plt.plot(train_losses['mse'], 'r')
plt.title('Least mean square loss')
plt.xlabel('Trial')

plt.subplot(133)
plt.plot(train_losses['l2'], 'g')
plt.title('L2 loss')
plt.xlabel('Trial')
plt.show()

# %%
# show the trained weights
plot_params(net, show_top_eig=2)
364 |
365 | # %% [markdown]
366 | # ## Testing
367 |
368 | # %%
369 | # net.load_states('./data/fixed_points.h5')
370 |
371 | # %%
# Run the trained network on a fresh batch.
inputs, targets = build_inputs_and_targets(
    bval, sval,
    num_batch=num_batch,
    num_step=num_step,
    dt=dt,
)
hiddens, outputs = net.predict(inputs)

# plot_data(num_step, inputs=inputs, targets=targets, outputs=outputs, errors=np.abs(targets - outputs), num_plot=16)

# %%
# Examples 0-3 of the batch.
plot_examples(
    num_step,
    inputs=inputs,
    hiddens=hiddens,
    outputs=outputs,
    targets=targets,
    num_example=4,
)

# %%
# Examples 10-13 of the batch.
plot_examples(
    num_step,
    inputs=inputs,
    hiddens=hiddens,
    outputs=outputs,
    targets=targets,
    num_example=4,
    start_id=10,
)
382 |
383 | # %% [markdown]
384 | # ## Fixed point analysis
385 |
386 | # %% [markdown]
387 | # Now that we've trained up this GRU to decide whether or not the perfect integral of the input is positive or negative, we can analyze the system via fixed point analysis.
388 |
389 | # %% [markdown]
390 | # The update rule of the GRU cell from the current state to the next state can be computed by:
391 |
392 | # %%
def f_cell(h):
    """Advance the trained GRU by one step from state ``h`` with zero input."""
    return net.cell(h, bm.zeros(net.num_input))
394 |
395 | # %% [markdown]
396 | # The function to determine the fixed point loss is given by the squared error of a point $(h - F(h))^2$:
397 |
398 | # %% [markdown]
399 | # Let's try to find the fixed points given the initial states.
400 |
401 | # %%
# Every hidden state visited during the test run is a candidate starting point.
fp_candidates = hiddens.reshape((-1, net.num_hidden))

# %%
finder = bp.analysis.SlowPointFinder(f_cell=f_cell, f_type='discrete')
# Gradient-descent search started from the candidates.
finder.find_fps_with_gd_method(
    candidates=fp_candidates,
    tolerance=5e-7,
    num_batch=400,
    optimizer=bp.optim.Adam(
        lr=bp.optim.ExponentialDecay(0.2, 1, 0.9999),
        eps=1e-8,
    ),
)
# Post-processing of the points that were found.
finder.filter_loss(tolerance=1e-7)
finder.keep_unique(tolerance=5e-4)
finder.exclude_outliers(0.1)
fps = finder.fixed_points
418 |
419 | #
420 | # fps, fp_losses, keep_ids, opt_losses = finder.find_fixed_points(candidates=fp_candidates)
421 |
422 | # %% [markdown]
423 | # ### Verify fixed points
424 |
425 | # %% [markdown]
426 | # Plotting the quality of the fixed points.
427 |
428 | # %%
429 | # # %matplotlib inline
430 |
431 | # %%
fig, gs = bp.visualize.get_figure(1, 2, 4, 6)

# Left panel: loss of each fixed point on a logarithmic y-axis.
fig.add_subplot(gs[0, 0])
plt.semilogy(finder.losses)
plt.ylabel('Fixed point loss')
plt.xlabel('Fixed point #')

# Right panel: histogram of the log10 losses of the retained fixed points.
fig.add_subplot(gs[0, 1])
plt.hist(np.log10(finder.f_loss_batch(fps)), bins=50)
plt.xlabel('log10(FP loss)')
plt.show()
443 |
444 | # %% [markdown]
445 | # Let's run the system starting at these fixed points, without input, and make sure the system is at equilibrium there. Note one can have fixed points that are very unstable, but that does not show up in this example.
446 |
447 | # %%
448 | # num_example = len(fps)
449 | # idxs = np.random.randint(0, len(fps), num_example)
450 | # check_h = bm.Variable(fps[idxs])
451 |
# Start one trajectory from every fixed point that was found.
num_example = len(fps)
# NOTE(review): `idxs` is not used below; all fixed points are simulated.
idxs = np.random.randint(0, len(fps), num_example)
# State holders for the check run, kept separate from the network's own `h` and `o`.
check_h = bm.Variable(fps)
check_o = bm.Variable(bm.zeros((num_example, net.num_output)))

# Loop over time that updates `check_h` / `check_o` with the network's update rule.
f_check_update = bm.make_loop(net.make_update(check_h, check_o),
                              dyn_vars=list(net.vars().values()) + [check_h, check_o],
                              out_vars=[check_h, check_o])

# Zero input for every step; `_outs` is an all-zero placeholder passed as the plot targets.
_ins = bm.zeros((num_step, num_example, net.num_input))
_outs = bm.zeros((num_step, num_example, net.num_output))
slow_hiddens, slow_outputs = f_check_update(_ins)

# %%
# Trajectories started from fixed points 0-3.
plot_examples(num_step, inputs=_ins, targets=_outs, outputs=slow_outputs,
              hiddens=slow_hiddens, num_example=4)

# %%
# Trajectories started from fixed points 4-7 (requires at least eight fixed points).
plot_examples(num_step, inputs=_ins, targets=_outs, outputs=slow_outputs,
              hiddens=slow_hiddens, num_example=4, start_id=4)
472 |
473 | # %% [markdown]
474 | # Try to get a nice representation of the line using the fixed points.
475 |
476 | # %%
# Sort the best fixed points by projection onto the readout.
fp_readouts = np.squeeze(net.readout(fps))
fp_ro_sidxs = np.argsort(fp_readouts)
sorted_fp_readouts = fp_readouts[fp_ro_sidxs]
sorted_fps = fps[fp_ro_sidxs]

downsample_fps = 1  # Use this if too many fps
# Keep every `downsample_fps`-th fixed point. The slice runs to the end of the
# array: stopping at -1 would always drop the fixed point with the largest
# readout, even when no downsampling is requested.
sorted_fp_readouts = sorted_fp_readouts[::downsample_fps]
sorted_fps = sorted_fps[::downsample_fps]
486 |
487 | # %% [markdown]
488 | # ### Visualize fixed points
489 |
490 | # %% [markdown]
491 | # Now, through a series of plots and dot products, we will see how the GRU solved the binary decision task. Now, let's plot the fixed points and the fixed point candidates that the fixed point optimization was seeded with. Black shows the original candidate point, the colored stars show the fixed point, where the color of the fixed point is the projection onto the readout vector and the size is commensurate with how slow it is (slower is larger).
492 |
493 | # %%
494 | # # %matplotlib qt
495 |
496 | # %%
497 | from sklearn.decomposition import PCA
498 |
# Fit a 3-component PCA on the candidate states.
pca = PCA(n_components=3).fit(fp_candidates)
# pca = PCA(n_components=3).fit(fps)

# %%
fig = plt.figure(figsize=(16, 16))
ax = fig.add_subplot(111, projection='3d')

emax = fps.shape[0]

# # plot candidates
# h_pca = pca.transform(fp_candidates[keep_ids])
# emax = h_pca.shape[0] if h_pca.shape[0] < 1000 else 1000
# ax.scatter(h_pca[0:emax, 0], h_pca[0:emax, 1], h_pca[0:emax, 2], color=[0, 0, 0, 0.1], s=10)

# Project the fixed points into the PCA space.
hstar_pca = pca.transform(fps)

# Color by readout: limit to [-1, 1], then rescale to [0, 1].
color = np.squeeze(net.readout(fps))
color = np.where(color > 1.0, 1.0, color)
color = np.where(color < -1.0, -1.0, color)
color = (color + 1.0) / 2.0

marker_style = {'marker': '*', 's': 100, 'edgecolor': 'gray'}
ax.scatter(
    hstar_pca[0:emax, 0],
    hstar_pca[0:emax, 1],
    hstar_pca[0:emax, 2],
    c=color[0:emax],
    **marker_style,
)

# alpha = 0.02
# for eidx in range(emax):
#   ax.plot3D([h_pca[eidx,0], hstar_pca[eidx,0]],
#             [h_pca[eidx,1], hstar_pca[eidx,1]],
#             [h_pca[eidx,2], hstar_pca[eidx,2]],
#             c=[0, 0, 1, alpha])

plt.title('Fixed point structure and fixed point candidate starting points.')
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_zlabel('PC 3')
plt.show()
535 |
536 | # %% [markdown]
537 | # So in this example, we see that the fixed point structure implements an approximate line attractor, which is the one-dimensional manifold likely used to integrate the white noise and ultimately lead to the decision.
538 |
539 | # %% [markdown]
540 | # Note also the shape of the manifold relative to the color. The color is based on the readout value of the fixed point, so it appears that there may be three parts to the line attractor: the middle and the two sides. The two sides may be integrating, even though the readout would be +1 or -1.
541 |
542 | # %% [markdown]
543 | # It's worth taking a look at the fixed points, and the trajectories started at the fixed points, without any input, all plotted in the 3D PCA space.
544 |
545 | # %%
fig = plt.figure(figsize=(16, 16))
ax = fig.add_subplot(111, projection='3d')

# Faint blue line for each trajectory started at a fixed point.
alpha = 0.05
emax = len(sorted_fps)
for eidx in range(emax):
    h_pca = pca.transform(slow_hiddens[:, eidx, :])
    ax.plot3D(h_pca[:, 0], h_pca[:, 1], h_pca[:, 2], c=[0, 0, 1, alpha])

# Sorted fixed points projected into the same PCA space.
size = 100
hstar_pca = pca.transform(sorted_fps)

# Color by readout: limit to [-1, 1], then rescale to [0, 1].
color = np.squeeze(net.readout(sorted_fps))
color = np.where(color > 1.0, 1.0, color)
color = np.where(color < -1.0, -1.0, color)
color = (color + 1.0) / 2.0

marker_style = {'marker': '*', 's': size, 'edgecolor': 'gray'}
ax.scatter(
    hstar_pca[0:emax, 0],
    hstar_pca[0:emax, 1],
    hstar_pca[0:emax, 2],
    c=color[0:emax],
    **marker_style,
)

plt.title('High quality fixed points and the network dynamics initialized from them.')
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_zlabel('PC 3')
plt.show()
571 |
572 | # %%
573 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | brainpy
2 |
3 | # docs
4 | pandoc
5 | docutils
6 | Jinja2
7 | sphinx
8 | nbsphinx
9 | sphinx_rtd_theme>=1.0.0
--------------------------------------------------------------------------------