├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── Makefile ├── README.md ├── attractors ├── Mi_2014_CANN_1D_oscillatory_tracking.ipynb ├── Wu_2008_CANN.ipynb ├── Wu_2008_CANN_2D.ipynb ├── data │ └── data_to_train_on.npy ├── discrete_hopfield_demo_for_image_reconstruction.ipynb └── discrete_hopfield_network.ipynb ├── brain_inspired_computing ├── OTTT-SNN.py ├── SurrogateGrad_lif.py ├── SurrogateGrad_lif_fashion_mnist.py ├── _datasets.py ├── fashion_mnist_conv_lif.py ├── liquid_time_constant_network.py ├── mnist_lif_readout.py └── spiking_FPTT │ ├── README.md │ ├── add_task.py │ └── mnist_classification.py ├── classical_dynamical_systems ├── Multiscroll_attractor.ipynb ├── Rabinovich_Fabrikant_eq.ipynb ├── fractional_order_chaos.ipynb ├── henon_map.ipynb ├── logistic_map.ipynb ├── lorenz_system.ipynb └── mackey_glass_eq.ipynb ├── conf.py ├── decision_making ├── Wang_2002_decision_making_spiking.ipynb ├── Wang_2002_decision_making_spiking.py └── Wang_2006_decision_making_rate.ipynb ├── dynamics_analysis ├── 1d_simple_systems.ipynb ├── 2d_NaK_model.ipynb ├── 2d_decision_making_model.ipynb ├── 2d_decision_making_with_lowdim_analyzer.ipynb ├── 2d_wilson_cowan_model.ipynb ├── 3d_hindmarsh_rose_model.ipynb ├── highdim_CANN.ipynb └── highdim_gj_coupled_fhn.ipynb ├── ei_nets ├── Brette_2007_COBA.ipynb ├── Brette_2007_COBAHH.ipynb ├── Brette_2007_CUBA.ipynb ├── Tian_2020_EI_net_for_fast_response.ipynb └── Vreeswijk_1996_EI_net.ipynb ├── gj_nets ├── Fazli_2022_gj_coupled_bursting_pituitary_cells.ipynb └── Sherman_1992_gj_antisynchrony.ipynb ├── images ├── cann-decoding.gif ├── cann-encoding.gif ├── cann-tracking.gif ├── cann_1d_oscillatory_tracking.gif ├── cann_2d_encoding.gif ├── cann_2d_tracking.gif ├── decision_model.png └── izhikevich_patterns.jfif ├── index.rst ├── large_scale_modeling ├── 2014_CorticalModel.py ├── EI_net_with_1m_neurons.ipynb ├── Joglekar_2018_InterAreal_Balanced_Amplification_figure1.ipynb ├── 
Joglekar_2018_InterAreal_Balanced_Amplification_figure2.ipynb ├── Joglekar_2018_InterAreal_Balanced_Amplification_figure5.ipynb ├── Joglekar_2018_InterAreal_Balanced_Amplification_taichi_customized_op.ipynb └── Joglekar_2018_data │ ├── efelenMatpython.mat │ ├── hierValspython.mat │ ├── subgraphData.mat │ └── subgraphWiring29.mat ├── make.bat ├── neurons ├── 2018_Fractional_Izhikevich_model.ipynb ├── 2019_Fractional_order_FHR_model.ipynb ├── Gerstner_2005_AdExIF_model.ipynb ├── Izhikevich_2003_Izhikevich_model.ipynb ├── JR_1995_jansen_rit_model.ipynb ├── Niebur_2009_GIF.ipynb ├── Romain_2004_LIF_phase_locking.ipynb └── Susin_2021_gamma_oscillation_nets.py ├── oscillation_synchronization ├── Brunel_Hakim_1999_fast_oscillation.ipynb ├── Diesmann_1999_synfire_chains.ipynb ├── Li_2017_unified_thalamus_oscillation_model.ipynb ├── Susin_Destexhe_2021_gamma_oscillation_AI.ipynb ├── Susin_Destexhe_2021_gamma_oscillation_CHING.ipynb ├── Susin_Destexhe_2021_gamma_oscillation_ING.ipynb ├── Susin_Destexhe_2021_gamma_oscillation_PING.ipynb └── Wang_1996_gamma_oscillation.ipynb ├── others └── Brette_Guigon_2003_spike_timing_reliability.ipynb ├── recurrent_networks ├── Bellec_2020_eprop_evidence_accumulation.ipynb ├── Laje_Buonomano_2013_robust_timing_rnn.ipynb ├── Laje_Buonomano_2013_robust_timing_rnn.py ├── Laje_Buonomano_2013_simulation.ipynb ├── Laje_Buonomano_2013_simulation.py ├── Masse_2019_STP_RNN.ipynb ├── Masse_2019_STP_RNN_tasks.py ├── ParametricWorkingMemory.ipynb ├── Song_2016_EI_RNN.ipynb ├── Sussillo_Abbott_2009_FORCE_Learning.ipynb ├── Yang_2020_RNN_Analysis.ipynb ├── data │ └── DAC_handwriting_output_targets.mat ├── fixed_points_finder.ipynb ├── fixed_points_finder.py ├── fixed_points_finder2.ipynb ├── fixed_points_finder2.py └── integrator_rnn.ipynb ├── requirements.txt ├── reservoir_computing ├── Gauthier_2021_ngrc.ipynb └── predicting_Mackey_Glass_timeseries.ipynb └── working_memory ├── Bouchacourt_2019_Flexible_working_memory.ipynb └── 
Mi_2017_working_memory_capacity.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | _build 9 | recurrent_networks/neurogym 10 | recurrent_networks/Untitled.ipynb 11 | recurrent_networks/Untitled1.ipynb 12 | .idea 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | /brain_inspired_computing/logs/ 136 | /brain_inspired_computing/data/ 137 | /brain_inspired_computing/results/ 138 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | build: 9 | os: "ubuntu-20.04" 10 | tools: 11 | python: "3.10" 12 | 13 | # Build documentation in the docs/ directory with Sphinx 14 | sphinx: 15 | configuration: conf.py 16 | 17 | # Optionally set the version of Python and requirements required to build your docs 18 | python: 19 | install: 20 | - requirements: requirements.txt -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | 
# from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Examples for [BrainPy](https://github.com/brainpy/BrainPy) computation 2 | 3 | Examples needs ``brainpy>=2.3.0``. 4 | 5 | 6 | 7 | 8 | ### Recent activities 9 | 10 | - 2023/10/06, [Discrete Hopfield Network](./attractors/discrete_hopfiled_network.ipynb) 11 | - 2023/10/06, [Discrete Hopfield Network Demo for Image Reconstruction](./attractors/discrete_hopfield_demo_for_image_reconstruction.ipynb) 12 | - 2023/04/08, large-scale simulation example for simulating [1 million neuron EI network using 1GB GPU memory](./large_scale_modeling/EI_net_with_1m_neurons.ipynb) 13 | - 2023/01/29, implementing [liquid time-constant network](./brain_inspired_computing/liquid_time_constant_network.py) 14 | 15 | 16 | 17 | 18 | -------------------------------------------------------------------------------- /attractors/Mi_2014_CANN_1D_oscillatory_tracking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "647060a3", 6 | "metadata": {}, 7 | "source": [ 8 | "# CANN 1D Oscillatory Tracking\n", 9 | "\n", 10 | 
"[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/attractors/Mi_2014_CANN_1D_oscillatory_tracking.ipynb)\n", 11 | "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/attractors/Mi_2014_CANN_1D_oscillatory_tracking.ipynb)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "6d135f04", 17 | "metadata": {}, 18 | "source": [ 19 | "Implementation of the paper:\n", 20 | "\n", 21 | "- Si Wu, Kosuke Hamaguchi, and Shun-ichi Amari. \"Dynamics and computation of continuous attractors.\" Neural computation 20.4 (2008): 994-1025.\n", 22 | "- Mi, Y., Fung, C. C., Wong, M. K. Y., & Wu, S. (2014). Spike frequency adaptation implements anticipative tracking in continuous attractor neural networks. Advances in neural information processing systems, 1(January), 505." 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 1, 28 | "id": "51852ed7", 29 | "metadata": { 30 | "ExecuteTime": { 31 | "end_time": "2023-07-22T04:07:35.934007100Z", 32 | "start_time": "2023-07-22T04:07:35.243879300Z" 33 | }, 34 | "pycharm": { 35 | "is_executing": true 36 | } 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "import brainpy as bp\n", 41 | "import brainpy.math as bm\n", 42 | "\n", 43 | "bm.set_platform('cpu')" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "metadata": { 50 | "ExecuteTime": { 51 | "end_time": "2023-07-22T04:07:35.949695100Z", 52 | "start_time": "2023-07-22T04:07:35.934007100Z" 53 | }, 54 | "collapsed": false 55 | }, 56 | "outputs": [ 57 | { 58 | "data": { 59 | "text/plain": [ 60 | "'2.4.3'" 61 | ] 62 | }, 63 | "execution_count": 2, 64 | "metadata": {}, 65 | "output_type": "execute_result" 66 | } 67 | ], 68 | "source": [ 69 | "bp.__version__" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 2, 75 | "id": 
"433fe4d4", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "class CANN1D(bp.dyn.NeuDyn):\n", 80 | " def __init__(self, num, tau=1., tau_v=50., k=1., a=0.3, A=0.2, J0=1.,\n", 81 | " z_min=-bm.pi, z_max=bm.pi, m=0.3):\n", 82 | " super(CANN1D, self).__init__(size=num)\n", 83 | "\n", 84 | " # parameters\n", 85 | " self.tau = tau # The synaptic time constant\n", 86 | " self.tau_v = tau_v\n", 87 | " self.k = k # Degree of the rescaled inhibition\n", 88 | " self.a = a # Half-width of the range of excitatory connections\n", 89 | " self.A = A # Magnitude of the external input\n", 90 | " self.J0 = J0 # maximum connection value\n", 91 | " self.m = m\n", 92 | "\n", 93 | " # feature space\n", 94 | " self.z_min = z_min\n", 95 | " self.z_max = z_max\n", 96 | " self.z_range = z_max - z_min\n", 97 | " self.x = bm.linspace(z_min, z_max, num) # The encoded feature values\n", 98 | " self.rho = num / self.z_range # The neural density\n", 99 | " self.dx = self.z_range / num # The stimulus density\n", 100 | "\n", 101 | " # The connection matrix\n", 102 | " self.conn_mat = self.make_conn()\n", 103 | "\n", 104 | " # variables\n", 105 | " self.r = bm.Variable(bm.zeros(num))\n", 106 | " self.u = bm.Variable(bm.zeros(num))\n", 107 | " self.v = bm.Variable(bm.zeros(num))\n", 108 | " self.input = bm.Variable(bm.zeros(num))\n", 109 | "\n", 110 | " def dist(self, d):\n", 111 | " d = bm.remainder(d, self.z_range)\n", 112 | " d = bm.where(d > 0.5 * self.z_range, d - self.z_range, d)\n", 113 | " return d\n", 114 | "\n", 115 | " def make_conn(self):\n", 116 | " x_left = bm.reshape(self.x, (-1, 1))\n", 117 | " x_right = bm.repeat(self.x.reshape((1, -1)), len(self.x), axis=0)\n", 118 | " d = self.dist(x_left - x_right)\n", 119 | " conn = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / (bm.sqrt(2 * bm.pi) * self.a)\n", 120 | " return conn\n", 121 | "\n", 122 | " def get_stimulus_by_pos(self, pos):\n", 123 | " return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / 
self.a))\n", 124 | "\n", 125 | " def update(self):\n", 126 | " r1 = bm.square(self.u)\n", 127 | " r2 = 1.0 + self.k * bm.sum(r1)\n", 128 | " self.r.value = r1 / r2\n", 129 | " Irec = bm.dot(self.conn_mat, self.r)\n", 130 | " self.u.value = self.u + (-self.u + Irec + self.input - self.v) / self.tau * bp.share['dt']\n", 131 | " self.v.value = self.v + (-self.v + self.m * self.u) / self.tau_v * bp.share['dt']\n", 132 | " self.input[:] = 0." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 3, 138 | "id": "1c04226c", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "cann = CANN1D(num=512)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 4, 148 | "id": "6bc3315e", 149 | "metadata": { 150 | "scrolled": false 151 | }, 152 | "outputs": [ 153 | { 154 | "data": { 155 | "application/vnd.jupyter.widget-view+json": { 156 | "model_id": "76f9d82a2cd447cc9541cd771fee2086", 157 | "version_major": 2, 158 | "version_minor": 0 159 | }, 160 | "text/plain": [ 161 | " 0%| | 0/26000 [00:00 0.5 * self.z_range, d - self.z_range, d)\n", 141 | " return d\n", 142 | "\n", 143 | " def make_conn(self, x):\n", 144 | " assert bm.ndim(x) == 1\n", 145 | " x_left = bm.reshape(x, (-1, 1))\n", 146 | " x_right = bm.repeat(x.reshape((1, -1)), len(x), axis=0)\n", 147 | " d = self.dist(x_left - x_right)\n", 148 | " Jxx = self.J0 * bm.exp(-0.5 * bm.square(d / self.a)) / \\\n", 149 | " (bm.sqrt(2 * bm.pi) * self.a)\n", 150 | " return Jxx\n", 151 | "\n", 152 | " def get_stimulus_by_pos(self, pos):\n", 153 | " return self.A * bm.exp(-0.25 * bm.square(self.dist(self.x - pos) / self.a))\n", 154 | "\n", 155 | " def update(self):\n", 156 | " self.u.value = self.integral(self.u, bp.share['t'], self.input, bp.share['dt'])\n", 157 | " self.input[:] = 0." 
158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 9, 163 | "id": "64473237", 164 | "metadata": { 165 | "lines_to_next_cell": 2 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "cann = CANN1D(num=512, k=0.1)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "markdown", 174 | "id": "d83942e7", 175 | "metadata": {}, 176 | "source": [ 177 | "## Population coding" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 10, 183 | "metadata": { 184 | "collapsed": false 185 | }, 186 | "outputs": [ 187 | { 188 | "data": { 189 | "application/vnd.jupyter.widget-view+json": { 190 | "model_id": "a34c8dd621604a29b1c02176672f28d2", 191 | "version_major": 2, 192 | "version_minor": 0 193 | }, 194 | "text/plain": [ 195 | " 0%| | 0/170 [00:00 0.0: 267 | self.rng = bm.random.default_rng() 268 | self.reset_state(1) 269 | 270 | def reset_state(self, batch_size=1): 271 | self.v = bp.init.variable_(bm.zeros, self.size, batch_size) 272 | self.spike = bp.init.variable_(bm.zeros, self.size, batch_size) 273 | if self.track_rate: 274 | self.rate_tracking = bp.init.variable_(bm.zeros, self.size, batch_size) 275 | 276 | def update(self, x): 277 | # neuron charge 278 | self.v.value = jax.lax.stop_gradient(self.v.value) + x 279 | # neuron fire 280 | spike = self.f_surrogate(self.v.value - self.v_threshold) 281 | # spike reset 282 | spike_d = jax.lax.stop_gradient(spike) if self.detach_reset else spike 283 | if self.v_reset is None: 284 | self.v -= spike_d * self.v_threshold 285 | else: 286 | self.v.value = (1. 
- spike_d) * self.v + spike_d * self.v_reset 287 | # dropout 288 | if self.dropout > 0.0 and bp.share.load('fit'): 289 | mask = self.rng.bernoulli(1 - self.dropout, self.v.shape) / (1 - self.dropout) 290 | spike = mask * spike 291 | self.spike.value = spike 292 | # spike track 293 | if self.track_rate: 294 | self.rate_tracking += jax.lax.stop_gradient(spike) 295 | # output 296 | if bm.save.load('output_type') == 'spike_rate': 297 | assert self.track_rate 298 | return jnp.concatenate([spike, self.rate_tracking.value]) 299 | else: 300 | return spike 301 | 302 | 303 | class OnlineLIFNode(bp.DynamicalSystemNS): 304 | def __init__( 305 | self, 306 | size, 307 | tau: float = 2., 308 | decay_input: bool = False, 309 | v_threshold: float = 1., 310 | v_reset: float = None, 311 | f_surrogate=bm.surrogate.sigmoid, 312 | detach_reset: bool = True, 313 | track_rate: bool = True, 314 | neuron_dropout: float = 0.0, 315 | name: str = None, 316 | mode: bm.Mode = None 317 | ): 318 | super().__init__(name=name, mode=mode) 319 | bp.check.is_subclass(self.mode, bm.TrainingMode) 320 | 321 | self.size = bp.check.is_sequence(size, elem_type=int) 322 | self.tau = tau 323 | self.decay_input = decay_input 324 | self.v_threshold = v_threshold 325 | self.v_reset = v_reset 326 | self.f_surrogate = f_surrogate 327 | self.detach_reset = detach_reset 328 | self.track_rate = track_rate 329 | self.dropout = neuron_dropout 330 | 331 | if self.dropout > 0.0: 332 | self.rng = bm.random.default_rng() 333 | self.reset_state(1) 334 | 335 | def reset_state(self, batch_size=1): 336 | self.v = bp.init.variable_(bm.zeros, self.size, batch_size) 337 | self.spike = bp.init.variable_(bm.zeros, self.size, batch_size) 338 | if self.track_rate: 339 | self.rate_tracking = bp.init.variable_(bm.zeros, self.size, batch_size) 340 | 341 | def update(self, x): 342 | # neuron charge 343 | if self.decay_input: 344 | x = x / self.tau 345 | if self.v_reset is None or self.v_reset == 0: 346 | self.v = 
jax.lax.stop_gradient(self.v.value) * (1 - 1. / self.tau) + x 347 | else: 348 | self.v = jax.lax.stop_gradient(self.v.value) * (1 - 1. / self.tau) + self.v_reset / self.tau + x 349 | # neuron fire 350 | spike = self.f_surrogate(self.v - self.v_threshold) 351 | # neuron reset 352 | spike_d = jax.lax.stop_gradient(spike) if self.detach_reset else spike 353 | if self.v_reset is None: 354 | self.v -= spike_d * self.v_threshold 355 | else: 356 | self.v = (1. - spike_d) * self.v + spike_d * self.v_reset 357 | # dropout 358 | if self.dropout > 0.0 and bp.share.load('fit'): 359 | mask = self.rng.bernoulli(1 - self.dropout, spike.shape) / (1 - self.dropout) 360 | spike = mask * spike 361 | self.spike.value = spike 362 | # spike 363 | if self.track_rate: 364 | self.rate_tracking.value = jax.lax.stop_gradient(self.rate_tracking * (1 - 1. / self.tau) + spike) 365 | if bp.share.load('output_type') == 'spike_rate': 366 | assert self.track_rate 367 | return jnp.concatenate((spike, self.rate_tracking.value)) 368 | else: 369 | return spike 370 | 371 | 372 | class AverageMeter(object): 373 | def __init__(self): 374 | self.reset() 375 | 376 | def reset(self): 377 | self.val = 0 378 | self.avg = 0 379 | self.sum = 0 380 | self.count = 0 381 | 382 | def update(self, val, n=1): 383 | self.val = val 384 | self.sum += val * n 385 | self.count += n 386 | self.avg = self.sum / self.count 387 | 388 | 389 | @bm.jit(static_argnums=2) 390 | def accuracy(output, target, topk=(1,)): 391 | """Computes the precision@k for the specified values of k""" 392 | maxk = max(topk) 393 | _, pred = jax.vmap(jax.lax.top_k, in_axes=(0, None))(output, maxk) 394 | pred = pred.T 395 | correct = (pred == target.reshape(1, -1)).astype(bm.float_) 396 | res = [] 397 | for k in topk: 398 | correct_k = correct[:k].reshape(-1).sum(0) 399 | res.append(correct_k * 100.0 / target.size) 400 | return res 401 | 402 | 403 | # print(accuracy(jnp.ones(10, ), jnp.ones(20), (1, 3, 5))) 404 | # import sys 405 | # sys.exit() 406 | 
407 | 408 | def classify_cifar(): 409 | parser = argparse.ArgumentParser(description='Classify CIFAR') 410 | parser.add_argument('-T', default=6, type=int, help='simulating time-steps') 411 | parser.add_argument('-tau', default=2., type=float) 412 | parser.add_argument('-b', default=128, type=int, help='batch size') 413 | parser.add_argument('-epochs', default=300, type=int, help='number of total epochs to run') 414 | parser.add_argument('-j', default=4, type=int, help='number of data loading workers (default: 4)') 415 | parser.add_argument('-data_dir', type=str, default=r'/mnt/d/data') 416 | # parser.add_argument('-data_dir', type=str, default=r'D:/data') 417 | parser.add_argument('-dataset', default='cifar10', type=str) 418 | parser.add_argument('-out_dir', default='./logs', type=str, help='root dir for saving logs and checkpoint') 419 | parser.add_argument('-resume', type=str, help='resume from the checkpoint path') 420 | parser.add_argument('-opt', type=str, help='use which optimizer. SGD or Adam', default='SGD') 421 | parser.add_argument('-lr', default=0.1, type=float, help='learning rate') 422 | parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD') 423 | parser.add_argument('-lr_scheduler', default='CosALR', type=str, help='use which schedule. 
StepLR or CosALR') 424 | parser.add_argument('-step_size', default=100, type=float, help='step_size for StepLR') 425 | parser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR') 426 | parser.add_argument('-T_max', default=300, type=int, help='T_max for CosineAnnealingLR') 427 | parser.add_argument('-drop_rate', type=float, default=0.0) 428 | parser.add_argument('-weight_decay', type=float, default=0.0) 429 | parser.add_argument('-loss_lambda', type=float, default=0.05) 430 | parser.add_argument('-online_update', action='store_true') 431 | parser.add_argument('-gpu-id', default='0', type=str, help='gpu id') 432 | args = parser.parse_args() 433 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id 434 | 435 | # datasets 436 | transform_train = transforms.Compose([ 437 | transforms.RandomCrop(32, padding=4), 438 | Cutout(), 439 | transforms.RandomHorizontalFlip(), 440 | transforms.ToTensor(), 441 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 442 | ]) 443 | transform_test = transforms.Compose([ 444 | transforms.ToTensor(), 445 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 446 | ]) 447 | if args.dataset == 'cifar10': 448 | dataloader = datasets.CIFAR10 449 | num_classes = 10 450 | else: 451 | dataloader = datasets.CIFAR100 452 | num_classes = 100 453 | 454 | # network 455 | net = OnlineSpikingVGG(neuron_model=OnlineLIFNode, 456 | neuron_pars=dict(tau=args.tau, 457 | neuron_dropout=args.drop_rate, 458 | f_surrogate=bm.surrogate.sigmoid, 459 | track_rate=True, 460 | v_reset=None), 461 | weight_standardization=True, 462 | num_classes=num_classes, 463 | grad_with_rate=True, 464 | fc_hw=1, 465 | c_in=3) 466 | print('Total Parameters: %.2fM' % ( 467 | sum(p.size for p in net.vars().subset(bm.TrainVar).unique().values()) / 1000000.0)) 468 | print(net) 469 | 470 | trainset = dataloader(root=args.data_dir, train=True, download=True, transform=transform_train) 471 | train_data_loader = 
data.DataLoader(trainset, batch_size=args.b, shuffle=True, num_workers=args.j) 472 | testset = dataloader(root=args.data_dir, train=False, download=False, transform=transform_test) 473 | test_data_loader = data.DataLoader(testset, batch_size=args.b, shuffle=False, num_workers=args.j) 474 | 475 | 476 | # path 477 | out_dir = os.path.join(args.out_dir, f'{args.dataset}_T_{args.T}_{args.opt}_lr_{args.lr}_') 478 | if args.lr_scheduler == 'CosALR': 479 | out_dir += f'CosALR_{args.T_max}' 480 | elif args.lr_scheduler == 'StepLR': 481 | out_dir += f'StepLR_{args.step_size}_{args.gamma}' 482 | else: 483 | raise NotImplementedError(args.lr_scheduler) 484 | if args.online_update: 485 | out_dir += '_online' 486 | os.makedirs(out_dir, exist_ok=True) 487 | with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: 488 | args_txt.write(str(args)) 489 | 490 | t_step = args.T 491 | 492 | def single_step(x, y, fit=True): 493 | bp.share.save('fit', fit) 494 | out = net(x) 495 | if args.loss_lambda > 0.0: 496 | y = bm.one_hot(y, 10, dtype=bm.float_) 497 | l = bp.losses.mean_squared_error(out, y) * args.loss_lambda 498 | l += (1 - args.loss_lambda) * bp.losses.cross_entropy_loss(out, y) 499 | l /= t_step 500 | else: 501 | l = bp.losses.cross_entropy_loss(out, y) / t_step 502 | return l, out 503 | 504 | @bm.jit 505 | def inference_fun(x, y): 506 | l, out = bm.for_loop(lambda _: single_step(x, y, False), jnp.arange(t_step)) 507 | out = out.sum(0) 508 | n = jnp.sum(jnp.argmax(out, axis=1) == y) 509 | return l.sum(), n, out 510 | 511 | grad_fun = bm.grad(single_step, grad_vars=net.train_vars().unique(), return_value=True, has_aux=True) 512 | 513 | if args.lr_scheduler == 'StepLR': 514 | lr = bp.optim.StepLR(args.lr, step_size=args.step_size, gamma=args.gamma) 515 | elif args.lr_scheduler == 'CosALR': 516 | lr = bp.optim.CosineAnnealingLR(args.lr, T_max=args.T_max) 517 | else: 518 | raise NotImplementedError(args.lr_scheduler) 519 | 520 | if args.opt == 'SGD': 521 | 
optimizer = bp.optim.Momentum(lr, net.train_vars().unique(), momentum=args.momentum, 522 | weight_decay=args.weight_decay) 523 | elif args.opt == 'Adam': 524 | optimizer = bp.optim.AdamW(lr, net.train_vars().unique(), weight_decay=args.weight_decay) 525 | else: 526 | raise NotImplementedError(args.opt) 527 | 528 | @bm.jit 529 | def train_fun(x, y): 530 | if args.online_update: 531 | final_loss, final_out = 0., 0. 532 | for _ in range(t_step): 533 | grads, l, out = grad_fun(x, y) 534 | optimizer.update(grads) 535 | final_loss += l 536 | final_out += out 537 | else: 538 | final_grads, final_loss, final_out = grad_fun(x, y) 539 | for _ in range(t_step - 1): 540 | grads, l, out = grad_fun(x, y) 541 | final_grads = jax.tree_util.tree_map(lambda a, b: a + b, final_grads, grads) 542 | final_loss += l 543 | final_out += out 544 | optimizer.update(final_grads) 545 | n = jnp.sum(jnp.argmax(final_out, axis=1) == y) 546 | return final_loss, n, final_out 547 | 548 | start_epoch = 0 549 | max_test_acc = 0 550 | if args.resume: 551 | checkpoint = bp.checkpoints.load_pytree(args.resume) 552 | net.load_state_dict(checkpoint['net']) 553 | optimizer.load_state_dict(checkpoint['optimizer']) 554 | start_epoch = checkpoint['epoch'] + 1 555 | max_test_acc = checkpoint['max_test_acc'] 556 | 557 | train_samples = len(train_data_loader) 558 | test_samples = len(test_data_loader) 559 | for epoch in range(start_epoch, args.epochs): 560 | start_time = time.time() 561 | 562 | batch_time = AverageMeter() 563 | losses = AverageMeter() 564 | top1 = AverageMeter() 565 | top5 = AverageMeter() 566 | end = time.time() 567 | 568 | train_loss = 0 569 | train_acc = 0 570 | pbar = tqdm.tqdm(total=train_samples) 571 | for frame, label in train_data_loader: 572 | frame = jnp.asarray(frame).transpose(0, 2, 3, 1) 573 | label = jnp.asarray(label) 574 | net.reset_state(frame.shape[0]) 575 | batch_loss, n, total_fr = train_fun(frame, label) 576 | prec1, prec5 = accuracy(total_fr, label, (1, 5)) 577 | train_loss 
+= batch_loss * label.size 578 | train_acc += n 579 | losses.update(batch_loss, frame.shape[0]) 580 | top1.update(prec1.item(), frame.shape[0]) 581 | top5.update(prec5.item(), frame.shape[0]) 582 | 583 | # measure elapsed time 584 | batch_time.update(time.time() - end) 585 | end = time.time() 586 | 587 | # plot progress 588 | pbar.update(1) 589 | pbar.set_description( 590 | 'Batch: {bt:.3f}s | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 591 | bt=batch_time.avg, loss=losses.avg, top1=top1.avg, top5=top5.avg, 592 | ) 593 | ) 594 | pbar.close() 595 | 596 | train_loss /= train_samples 597 | train_acc /= train_samples 598 | optimizer.lr.step_epoch() 599 | 600 | batch_time = AverageMeter() 601 | losses = AverageMeter() 602 | top1 = AverageMeter() 603 | top5 = AverageMeter() 604 | end = time.time() 605 | 606 | test_loss = 0 607 | test_acc = 0 608 | pbar = tqdm.tqdm(total=test_samples) 609 | for frame, label in test_data_loader: 610 | frame = jnp.asarray(frame).transpose(0, 2, 3, 1) 611 | label = jnp.asarray(label) 612 | net.reset_state(frame.shape[0]) 613 | total_loss, n, out = inference_fun(frame, label) 614 | test_loss += total_loss * label.size 615 | test_acc += n 616 | prec1, prec5 = accuracy(out, label, (1, 5)) 617 | losses.update(total_loss, frame.shape[0]) 618 | top1.update(prec1.item(), frame.shape[0]) 619 | top5.update(prec5.item(), frame.shape[0]) 620 | batch_time.update(time.time() - end) 621 | end = time.time() 622 | 623 | # plot progress 624 | pbar.update(1) 625 | pbar.set_description( 626 | 'Batch: {bt:.3f}s | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format( 627 | bt=batch_time.avg, loss=losses.avg, top1=top1.avg, top5=top5.avg, 628 | ) 629 | ) 630 | pbar.close() 631 | 632 | test_loss /= test_samples 633 | test_acc /= test_samples 634 | 635 | if test_acc > max_test_acc: 636 | max_test_acc = test_acc 637 | checkpoint = { 638 | 'net': net.state_dict(), 639 | 'optimizer': optimizer.state_dict(), 640 | 'epoch': epoch, 641 | 
'max_test_acc': max_test_acc 642 | } 643 | bp.checkpoints.save_pytree(out_dir + '/checkpoint.bp', checkpoint, overwrite=True) 644 | 645 | total_time = time.time() - start_time 646 | print(f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, ' 647 | f'test_loss={test_loss}, test_acc={test_acc}, max_test_acc={max_test_acc}, ' 648 | f'total_time={total_time}') 649 | 650 | 651 | if __name__ == '__main__': 652 | classify_cifar() 653 | -------------------------------------------------------------------------------- /brain_inspired_computing/SurrogateGrad_lif.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | """ 5 | Reproduce the results of the``spytorch`` tutorial 1: 6 | 7 | - https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial1.ipynb 8 | 9 | """ 10 | 11 | import time 12 | 13 | import jax.numpy as jnp 14 | import numpy as np 15 | import matplotlib.pyplot as plt 16 | from matplotlib.gridspec import GridSpec 17 | 18 | import brainpy as bp 19 | import brainpy.math as bm 20 | 21 | 22 | class SNN(bp.DynSysGroup): 23 | def __init__(self, num_in, num_rec, num_out): 24 | super(SNN, self).__init__() 25 | 26 | # parameters 27 | self.num_in = num_in 28 | self.num_rec = num_rec 29 | self.num_out = num_out 30 | 31 | # synapse: i->r 32 | self.i2r = bp.Sequential( 33 | bp.dnn.Linear(num_in, num_rec, W_initializer=bp.init.KaimingNormal(scale=20.)), 34 | bp.dyn.Expon(num_rec, tau=10.) 35 | ) 36 | # recurrent: r 37 | self.r = bp.dyn.Lif(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.) 38 | # synapse: r->o 39 | self.r2o = bp.Sequential( 40 | bp.dnn.Linear(num_rec, num_out, W_initializer=bp.init.KaimingNormal(scale=20.)), 41 | bp.dyn.Expon(num_out, tau=10.) 
42 | ) 43 | # output: o 44 | self.o = bp.dyn.Leaky(num_out, tau=5) 45 | 46 | def update(self, spike): 47 | return spike >> self.i2r >> self.r >> self.r2o >> self.o 48 | 49 | 50 | def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): 51 | gs = GridSpec(*dim) 52 | mem = 1. * mem 53 | if spk is not None: 54 | mem[spk > 0.0] = spike_height 55 | mem = bm.as_numpy(mem) 56 | for i in range(np.prod(dim)): 57 | if i == 0: 58 | a0 = ax = plt.subplot(gs[i]) 59 | else: 60 | ax = plt.subplot(gs[i], sharey=a0) 61 | ax.plot(mem[i]) 62 | plt.tight_layout() 63 | plt.show() 64 | 65 | 66 | def print_classification_accuracy(output, target): 67 | """ Dirty little helper function to compute classification accuracy. """ 68 | m = bm.max(output, axis=1) # max over time 69 | am = bm.argmax(m, axis=1) # argmax over output units 70 | acc = bm.mean(target == am) # compare to labels 71 | print("Accuracy %.3f" % acc) 72 | 73 | 74 | with bm.environment(mode=bm.training_mode): 75 | net = SNN(100, 4, 2) 76 | 77 | num_step = 2000 78 | num_sample = 256 79 | freq = 5 # Hz 80 | mask = bm.random.rand(num_sample, num_step, net.num_in) 81 | x_data = bm.zeros((num_sample, num_step, net.num_in)) 82 | x_data[mask < freq * bm.get_dt() / 1000.] 
= 1.0 83 | y_data = bm.asarray(bm.random.rand(num_sample) < 0.5, dtype=bm.float_) 84 | rng = bm.random.RandomState(123) 85 | 86 | 87 | # Before training 88 | runner = bp.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) 89 | out = runner.run(inputs=x_data.value, reset_state=True) 90 | plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) 91 | plot_voltage_traces(out) 92 | print_classification_accuracy(out, y_data) 93 | 94 | 95 | def loss(): 96 | key = rng.split_key() 97 | X = bm.random.permutation(x_data, key=key) 98 | Y = bm.random.permutation(y_data, key=key) 99 | looper = bp.DSRunner(net, numpy_mon_after_run=False, progress_bar=False) 100 | predictions = looper.run(inputs=X, reset_state=True) 101 | predictions = bm.max(predictions, axis=1) 102 | return bp.losses.cross_entropy_loss(predictions, Y) 103 | 104 | 105 | grad = bm.grad(loss, grad_vars=net.train_vars().unique(), return_value=True) 106 | optimizer = bp.optim.Adam(lr=2e-3, train_vars=net.train_vars().unique()) 107 | 108 | 109 | def train(_): 110 | grads, l = grad() 111 | optimizer.update(grads) 112 | return l 113 | 114 | 115 | # train the network 116 | net.reset_state(num_sample) 117 | train_losses = [] 118 | b = 100 119 | for i in range(0, 3000, b): 120 | t0 = time.time() 121 | ls = bm.for_loop(train, operands=bm.arange(i, i + b, 1)) 122 | print(f'Train {i + b} epoch, loss = {jnp.mean(ls):.4f}, used time {time.time() - t0:.4f} s') 123 | train_losses.append(ls) 124 | 125 | # visualize the training losses 126 | plt.plot(bm.as_numpy(jnp.concatenate(train_losses))) 127 | plt.xlabel("Epoch") 128 | plt.ylabel("Training Loss") 129 | plt.show() 130 | 131 | # predict the output according to the input data 132 | runner = bp.DSRunner(net, monitors={'r.spike': net.r.spike, 'r.membrane': net.r.V}) 133 | out = runner.run(inputs=x_data, reset_state=True) 134 | plot_voltage_traces(runner.mon.get('r.membrane'), runner.mon.get('r.spike')) 135 | plot_voltage_traces(out) 136 | 
print_classification_accuracy(out, y_data) 137 | -------------------------------------------------------------------------------- /brain_inspired_computing/SurrogateGrad_lif_fashion_mnist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | Reproduce the results of the``spytorch`` tutorial 2 & 3: 5 | 6 | - https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial2.ipynb 7 | - https://github.com/surrogate-gradient-learning/spytorch/blob/master/notebooks/SpyTorchTutorial3.ipynb 8 | 9 | """ 10 | 11 | import brainpy_datasets as bd 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | from matplotlib.gridspec import GridSpec 15 | 16 | import brainpy as bp 17 | import brainpy.math as bm 18 | 19 | bm.set_environment(bm.training_mode) 20 | 21 | 22 | class SNN(bp.DynamicalSystem): 23 | """ 24 | This class implements a spiking neural network model with three layers: 25 | 26 | i >> r >> o 27 | 28 | Each two layers are connected through the exponential synapse model. 29 | """ 30 | 31 | def __init__(self, num_in, num_rec, num_out): 32 | super(SNN, self).__init__() 33 | 34 | # parameters 35 | self.num_in = num_in 36 | self.num_rec = num_rec 37 | self.num_out = num_out 38 | 39 | # synapse: i->r 40 | self.i2r = bp.Sequential( 41 | bp.dnn.Linear(num_in, num_rec, W_initializer=bp.init.KaimingNormal(scale=2.)), 42 | bp.dyn.Expon(num_rec, tau=10.) 43 | ) 44 | # recurrent: r 45 | self.r = bp.dyn.Lif(num_rec, tau=10, V_reset=0, V_rest=0, V_th=1.) 46 | # synapse: r->o 47 | self.r2o = bp.Sequential( 48 | bp.dnn.Linear(num_rec, num_out, W_initializer=bp.init.KaimingNormal(scale=2.)), 49 | bp.dyn.Expon(num_out, tau=10.) 
50 | ) 51 | # output: o 52 | self.o = bp.dyn.Leaky(num_out, tau=5) 53 | 54 | def update(self, spike): 55 | return spike >> self.i2r >> self.r >> self.r2o >> self.o 56 | 57 | 58 | def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5): 59 | gs = GridSpec(*dim) 60 | mem = 1. * mem 61 | if spk is not None: 62 | mem[spk > 0.0] = spike_height 63 | mem = bm.as_numpy(mem) 64 | for i in range(np.prod(dim)): 65 | if i == 0: 66 | a0 = ax = plt.subplot(gs[i]) 67 | else: 68 | ax = plt.subplot(gs[i], sharey=a0) 69 | ax.plot(mem[i]) 70 | ax.axis("off") 71 | plt.tight_layout() 72 | plt.show() 73 | 74 | 75 | def print_classification_accuracy(output, target): 76 | """ Dirty little helper function to compute classification accuracy. """ 77 | m = bm.max(output, axis=1) # max over time 78 | am = bm.argmax(m, axis=1) # argmax over output units 79 | acc = bm.mean(target == am) # compare to labels 80 | print("Accuracy %.3f" % acc) 81 | 82 | 83 | def current2firing_time(x, tau=20., thr=0.2, epsilon=1e-7): 84 | """Computes first firing time latency for a current input x 85 | assuming the charge time of a current based LIF neuron. 86 | 87 | Args: 88 | x -- The "current" values 89 | 90 | Keyword args: 91 | tau -- The membrane time constant of the LIF neuron to be charged 92 | thr -- The firing threshold value 93 | tmax -- The maximum time returned 94 | epsilon -- A generic (small) epsilon > 0 95 | 96 | Returns: 97 | Time to first spike for each "current" x 98 | """ 99 | x = np.clip(x, thr + epsilon, 1e9) 100 | T = tau * np.log(x / (x - thr)) 101 | return T 102 | 103 | 104 | def sparse_data_generator(X, y, batch_size, nb_steps, nb_units, shuffle=True): 105 | """ This generator takes datasets in analog format and 106 | generates spiking network input as sparse tensors. 
107 | 108 | Args: 109 | X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples 110 | y: The labels 111 | """ 112 | 113 | labels_ = np.array(y, dtype=bm.int_) 114 | sample_index = np.arange(len(X)) 115 | 116 | # compute discrete firing times 117 | tau_eff = 2. / bm.get_dt() 118 | unit_numbers = np.arange(nb_units) 119 | firing_times = np.array(current2firing_time(X, tau=tau_eff), dtype=bm.int_) 120 | 121 | if shuffle: 122 | np.random.shuffle(sample_index) 123 | 124 | counter = 0 125 | number_of_batches = len(X) // batch_size 126 | while counter < number_of_batches: 127 | batch_index = sample_index[batch_size * counter:batch_size * (counter + 1)] 128 | all_batch, all_times, all_units = [], [], [] 129 | for bc, idx in enumerate(batch_index): 130 | c = firing_times[idx] < nb_steps 131 | times, units = firing_times[idx][c], unit_numbers[c] 132 | batch = bc * np.ones(len(times), dtype=bm.int_) 133 | all_batch.append(batch) 134 | all_times.append(times) 135 | all_units.append(units) 136 | all_batch = np.concatenate(all_batch).flatten() 137 | all_times = np.concatenate(all_times).flatten() 138 | all_units = np.concatenate(all_units).flatten() 139 | x_batch = bm.zeros((batch_size, nb_steps, nb_units)) 140 | x_batch[all_batch, all_times, all_units] = 1. 141 | y_batch = bm.asarray(labels_[batch_index]) 142 | yield x_batch, y_batch 143 | counter += 1 144 | 145 | 146 | def train(model, x_data, y_data, lr=1e-3, nb_epochs=10, batch_size=128, nb_steps=128, nb_inputs=28 * 28): 147 | def loss_fun(predicts, targets): 148 | predicts, mon = predicts 149 | # Here we set up our regularizer loss 150 | # The strength paramters here are merely a guess and 151 | # there should be ample room for improvement by 152 | # tuning these paramters. 
153 | l1_loss = 1e-5 * bm.sum(mon['r.spike']) # L1 loss on total number of spikes 154 | l2_loss = 1e-5 * bm.mean(bm.sum(bm.sum(mon['r.spike'], axis=0), axis=0) ** 2) # L2 loss on spikes per neuron 155 | # predictions 156 | predicts = bm.max(predicts, axis=1) 157 | loss = bp.losses.cross_entropy_loss(predicts, targets) 158 | return loss + l2_loss + l1_loss 159 | 160 | trainer = bp.BPTT( 161 | model, 162 | loss_fun, 163 | optimizer=bp.optim.Adam(lr=lr), 164 | monitors={'r.spike': net.r.spike}, 165 | ) 166 | trainer.fit(lambda: sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs), 167 | num_epoch=nb_epochs) 168 | return trainer.get_hist_metric('fit') 169 | 170 | 171 | def compute_classification_accuracy(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28): 172 | """ Computes classification accuracy on supplied data in batches. """ 173 | accs = [] 174 | runner = bp.DSRunner(model, progress_bar=False) 175 | for x_local, y_local in sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False): 176 | output = runner.predict(inputs=x_local, reset_state=True) 177 | m = bm.max(output, 1) # max over time 178 | am = bm.argmax(m, 1) # argmax over output units 179 | tmp = bm.mean(y_local == am) # compare to labels 180 | accs.append(tmp) 181 | return bm.mean(bm.asarray(accs)) 182 | 183 | 184 | def get_mini_batch_results(model, x_data, y_data, batch_size=128, nb_steps=100, nb_inputs=28 * 28): 185 | runner = bp.DSRunner(model, 186 | monitors={'r.spike': model.r.spike}, 187 | progress_bar=False) 188 | data = sparse_data_generator(x_data, y_data, batch_size, nb_steps, nb_inputs, shuffle=False) 189 | x_local, y_local = next(data) 190 | output = runner.predict(inputs=x_local, reset_state=True) 191 | return output, runner.mon.get('r.spike') 192 | 193 | 194 | num_input = 28 * 28 195 | net = SNN(num_in=num_input, num_rec=100, num_out=10) 196 | 197 | # load the dataset 198 | root = r"D:\data" 199 | train_dataset = 
bd.vision.FashionMNIST(root, split='train', download=True) 200 | test_dataset = bd.vision.FashionMNIST(root, split='test', download=True) 201 | 202 | # Standardize data 203 | x_train = np.array(train_dataset.data, dtype=bm.float_) 204 | x_train = x_train.reshape(x_train.shape[0], -1) / 255 205 | y_train = np.array(train_dataset.targets, dtype=bm.int_) 206 | x_test = np.array(test_dataset.data, dtype=bm.float_) 207 | x_test = x_test.reshape(x_test.shape[0], -1) / 255 208 | y_test = np.array(test_dataset.targets, dtype=bm.int_) 209 | 210 | # training 211 | train_losses = train(net, x_train, y_train, lr=1e-3, nb_epochs=30, batch_size=256, nb_steps=100, nb_inputs=28 * 28) 212 | 213 | plt.figure(figsize=(3.3, 2), dpi=150) 214 | plt.plot(train_losses) 215 | plt.xlabel("Epoch") 216 | plt.ylabel("Loss") 217 | plt.show() 218 | 219 | print("Training accuracy: %.3f" % (compute_classification_accuracy(net, x_train, y_train, batch_size=512))) 220 | print("Test accuracy: %.3f" % (compute_classification_accuracy(net, x_test, y_test, batch_size=512))) 221 | 222 | outs, spikes = get_mini_batch_results(net, x_train, y_train) 223 | # Let's plot the hidden layer spiking activity for some input stimuli 224 | fig = plt.figure(dpi=100) 225 | plot_voltage_traces(outs) 226 | plt.show() 227 | 228 | nb_plt = 4 229 | gs = GridSpec(1, nb_plt) 230 | plt.figure(figsize=(7, 3), dpi=150) 231 | for i in range(nb_plt): 232 | plt.subplot(gs[i]) 233 | plt.imshow(bm.as_numpy(spikes[i]).T, cmap=plt.cm.gray_r, origin="lower") 234 | if i == 0: 235 | plt.xlabel("Time") 236 | plt.ylabel("Units") 237 | plt.tight_layout() 238 | plt.show() 239 | -------------------------------------------------------------------------------- /brain_inspired_computing/fashion_mnist_conv_lif.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import time 7 | 8 | import brainpy as bp 9 | import brainpy.math as bm 
10 | import brainpy_datasets as bd 11 | import jax 12 | import jax.numpy as jnp 13 | from jax import lax 14 | 15 | bm.set_environment(mode=bm.training_mode, dt=1.) 16 | 17 | 18 | class ConvLIF(bp.DynamicalSystem): 19 | def __init__(self, n_time: int, n_channel: int, tau: float = 5.): 20 | super().__init__() 21 | self.n_time = n_time 22 | 23 | lif_par = dict(keep_size=True, V_rest=0., V_reset=0., V_th=1., 24 | tau=tau, spk_fun=bm.surrogate.arctan) 25 | 26 | self.block1 = bp.Sequential( 27 | bp.layers.Conv2d(1, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None), 28 | bp.layers.BatchNorm2d(n_channel, momentum=0.9), 29 | bp.dyn.Lif((28, 28, n_channel), **lif_par) 30 | ) 31 | self.block2 = bp.Sequential( 32 | bp.layers.MaxPool2d(2, 2), # 14 * 14 33 | bp.layers.Conv2d(n_channel, n_channel, kernel_size=3, padding=(1, 1), b_initializer=None), 34 | bp.layers.BatchNorm2d(n_channel, momentum=0.9), 35 | bp.dyn.Lif((14, 14, n_channel), **lif_par), 36 | ) 37 | self.block3 = bp.Sequential( 38 | bp.layers.MaxPool2d(2, 2), # 7 * 7 39 | bp.layers.Flatten(), 40 | bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4, b_initializer=None), 41 | bp.dyn.Lif(4 * 4 * n_channel, **lif_par), 42 | ) 43 | self.block4 = bp.Sequential( 44 | bp.layers.Dense(n_channel * 4 * 4, 10, b_initializer=None), 45 | bp.dyn.Lif(10, **lif_par), 46 | ) 47 | 48 | def update(self, x): 49 | # x.shape = [B, H, W, C] 50 | x = self.block1(x) 51 | x = self.block2(x) 52 | x = self.block3(x) 53 | return self.block4(x) 54 | 55 | 56 | class IFNode(bp.DynamicalSystem): 57 | """The Integrate-and-Fire neuron. The voltage of the IF neuron will 58 | not decay as that of the LIF neuron. The sub-threshold neural dynamics 59 | of it is as followed: 60 | 61 | .. 
math:: 62 | V[t] = V[t-1] + X[t] 63 | """ 64 | 65 | def __init__(self, size: tuple, v_threshold: float = 1., v_reset: float = 0., 66 | spike_fun=bm.surrogate.arctan, mode=None, reset_mode='soft'): 67 | super().__init__(mode=mode) 68 | bp.check.is_subclass(self.mode, bm.TrainingMode) 69 | 70 | self.size = bp.check.is_sequence(size, elem_type=int, allow_none=False) 71 | self.reset_mode = bp.check.is_string(reset_mode, candidates=['hard', 'soft']) 72 | self.v_threshold = bp.check.is_float(v_threshold) 73 | self.v_reset = bp.check.is_float(v_reset) 74 | self.spike_fun = bp.check.is_callable(spike_fun) 75 | 76 | # variables 77 | self.V = bm.Variable(jnp.zeros((1,) + size, dtype=bm.float_), batch_axis=0) 78 | 79 | def reset_state(self, batch_size): 80 | self.V.value = jnp.zeros((batch_size,) + self.size, dtype=bm.float_) 81 | 82 | def update(self, x): 83 | self.V.value += x 84 | spike = self.spike_fun(self.V - self.v_threshold) 85 | if self.reset_mode == 'hard': 86 | one = lax.convert_element_type(1., bm.float_) 87 | self.V.value = self.v_reset * spike + (one - spike) * self.V 88 | else: 89 | self.V -= spike * self.v_threshold 90 | return spike 91 | 92 | 93 | class ConvIF(bp.DynamicalSystem): 94 | def __init__(self, n_time: int, n_channel: int): 95 | super().__init__() 96 | self.n_time = n_time 97 | 98 | self.block1 = bp.Sequential( 99 | bp.layers.Conv2d(1, n_channel, kernel_size=3, padding=(1, 1), ), 100 | bp.layers.BatchNorm2d(n_channel, momentum=0.9), 101 | IFNode((28, 28, n_channel), spike_fun=bm.surrogate.arctan) 102 | ) 103 | self.block2 = bp.Sequential( 104 | bp.layers.MaxPool([2, 2], 2, channel_axis=-1), # 14 * 14 105 | bp.layers.Conv2d(n_channel, n_channel, kernel_size=3, padding=(1, 1), ), 106 | bp.layers.BatchNorm2d(n_channel, momentum=0.9), 107 | IFNode((14, 14, n_channel), spike_fun=bm.surrogate.arctan), 108 | ) 109 | self.block3 = bp.Sequential( 110 | bp.layers.MaxPool([2, 2], 2, channel_axis=-1), # 7 * 7 111 | bp.layers.Flatten(), 112 | 
bp.layers.Dense(n_channel * 7 * 7, n_channel * 4 * 4, ), 113 | IFNode((4 * 4 * n_channel,), spike_fun=bm.surrogate.arctan), 114 | ) 115 | self.block4 = bp.Sequential( 116 | bp.layers.Dense(n_channel * 4 * 4, 10, ), 117 | IFNode((10,), spike_fun=bm.surrogate.arctan), 118 | ) 119 | 120 | def update(self, x): 121 | x = self.block1(x) # x.shape = [B, H, W, C] 122 | x = self.block2(x) 123 | x = self.block3(x) 124 | x = self.block4(x) 125 | return x 126 | 127 | 128 | class TrainMNIST: 129 | def __init__(self, net, n_time): 130 | self.net = net 131 | self.n_time = n_time 132 | self.f_opt = bp.optim.Adam(bp.optim.ExponentialDecay(0.2, 1, 0.9999), 133 | train_vars=net.train_vars().unique()) 134 | self.f_grad = bm.grad(self.loss, grad_vars=self.f_opt.vars_to_train, 135 | has_aux=True, return_value=True) 136 | 137 | def inference(self, X, fit=False): 138 | def run_net(t): 139 | bp.share.save(t=t, fit=fit) 140 | return self.net(X) 141 | 142 | self.net.reset_state(X.shape[0]) 143 | return bm.for_loop(run_net, jnp.arange(self.n_time, dtype=bm.float_), jit=False) 144 | 145 | def loss(self, X, Y, fit=False): 146 | fr = bm.max(self.inference(X, fit), axis=0) 147 | ys_onehot = bm.one_hot(Y, 10, dtype=bm.float_) 148 | l = bp.losses.mean_squared_error(fr, ys_onehot) 149 | n = bm.sum(fr.argmax(1) == Y) 150 | return l, n 151 | 152 | @bm.cls_jit 153 | def f_predict(self, X, Y): 154 | return self.loss(X, Y, fit=False) 155 | 156 | @bm.cls_jit 157 | def f_train(self, X, Y): 158 | bp.share.save(fit=True) 159 | grads, l, n = self.f_grad(X, Y) 160 | self.f_opt.update(grads) 161 | return l, n 162 | 163 | 164 | def main(): 165 | parser = argparse.ArgumentParser(description='Classify Fashion-MNIST') 166 | parser.add_argument('-platform', default='cpu', help='platform') 167 | parser.add_argument('-model', default='lif', help='Neuron model to use') 168 | parser.add_argument('-n_time', default=4, type=int, help='simulating time-steps') 169 | parser.add_argument('-tau', default=5., type=float, 
help='LIF time constant') 170 | parser.add_argument('-batch', default=128, type=int, help='batch size') 171 | parser.add_argument('-n_channel', default=128, type=int, help='channels of ConvLIF') 172 | parser.add_argument('-n_epoch', default=64, type=int, metavar='N', help='number of total epochs to run') 173 | parser.add_argument('-data-dir', default='d:/data', type=str, help='root dir of Fashion-MNIST dataset') 174 | parser.add_argument('-out-dir', default='./logs', type=str, help='root dir for saving logs and checkpoint') 175 | parser.add_argument('-lr', default=0.1, type=float, help='learning rate') 176 | args = parser.parse_args() 177 | print(args) 178 | 179 | bm.set_platform(args.platform) 180 | 181 | # net 182 | if args.model == 'if': 183 | net = ConvIF(n_time=args.n_time, n_channel=args.n_channel) 184 | out_dir = os.path.join(args.out_dir, 185 | f'{args.model}_T{args.n_time}_b{args.batch}' 186 | f'_lr{args.lr}_c{args.n_channel}') 187 | elif args.model == 'lif': 188 | net = ConvLIF(n_time=args.n_time, n_channel=args.n_channel, tau=args.tau) 189 | out_dir = os.path.join(args.out_dir, 190 | f'{args.model}_T{args.n_time}_b{args.batch}' 191 | f'_lr{args.lr}_c{args.n_channel}_tau{args.tau}') 192 | else: 193 | raise ValueError 194 | 195 | trainer = TrainMNIST(net, args.n_time) 196 | 197 | # dataset 198 | train_set = bd.vision.FashionMNIST(root=args.data_dir, split='train', download=True) 199 | test_set = bd.vision.FashionMNIST(root=args.data_dir, split='test', download=True) 200 | x_train = jnp.asarray(train_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1)) 201 | y_train = jnp.asarray(train_set.targets, dtype=bm.int_) 202 | x_test = jnp.asarray(test_set.data / 255, dtype=bm.float_).reshape((-1, 28, 28, 1)) 203 | y_test = jnp.asarray(test_set.targets, dtype=bm.int_) 204 | 205 | os.makedirs(out_dir, exist_ok=True) 206 | with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: 207 | args_txt.write(str(args)) 208 | args_txt.write('\n') 
209 | args_txt.write(' '.join(sys.argv)) 210 | 211 | max_test_acc = -1 212 | for epoch_i in range(0, args.n_epoch): 213 | start_time = time.time() 214 | loss, train_acc = [], 0. 215 | for i in range(0, x_train.shape[0], args.batch): 216 | xs = x_train[i: i + args.batch] 217 | ys = y_train[i: i + args.batch] 218 | with jax.disable_jit(): 219 | l, n = trainer.f_train(xs, ys) 220 | loss.append(l) 221 | train_acc += n 222 | train_acc /= x_train.shape[0] 223 | train_loss = jnp.mean(jnp.asarray(loss)) 224 | trainer.f_opt.lr.step_epoch() 225 | 226 | loss, test_acc = [], 0. 227 | for i in range(0, x_test.shape[0], args.batch): 228 | xs = x_test[i: i + args.batch] 229 | ys = y_test[i: i + args.batch] 230 | l, n = trainer.f_predict(xs, ys) 231 | loss.append(l) 232 | test_acc += n 233 | test_acc /= x_test.shape[0] 234 | test_loss = jnp.mean(jnp.asarray(loss)) 235 | 236 | t = (time.time() - start_time) / 60 237 | print(f'epoch {epoch_i}, used {t:.3f} min, ' 238 | f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, ' 239 | f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}') 240 | 241 | if max_test_acc < test_acc: 242 | max_test_acc = test_acc 243 | states = { 244 | 'net': net.state_dict(), 245 | 'optimizer': trainer.f_opt.state_dict(), 246 | 'epoch_i': epoch_i, 247 | 'train_acc': train_acc, 248 | 'test_acc': test_acc, 249 | } 250 | bp.checkpoints.save_pytree(os.path.join(out_dir, 'fmnist-conv-lif.bp'), states) 251 | 252 | # inference 253 | state_dict = bp.checkpoints.load_pytree(os.path.join(out_dir, 'fmnist-conv-lif.bp')) 254 | net.load_state_dict(state_dict['net']) 255 | correct_num = 0 256 | for i in range(0, x_test.shape[0], 512): 257 | xs = x_test[i: i + 512] 258 | ys = y_test[i: i + 512] 259 | correct_num += trainer.f_predict(xs, ys)[1] 260 | print('Max test accuracy: ', correct_num / x_test.shape[0]) 261 | 262 | 263 | if __name__ == '__main__': 264 | main() 265 | -------------------------------------------------------------------------------- 
/brain_inspired_computing/mnist_lif_readout.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import time 4 | import argparse 5 | import os.path 6 | import sys 7 | 8 | import brainpy_datasets as bd 9 | 10 | import jax.numpy as jnp 11 | 12 | import brainpy as bp 13 | import brainpy.math as bm 14 | import numpy as np 15 | 16 | parser = argparse.ArgumentParser(description='LIF MNIST Training') 17 | parser.add_argument('-T', default=100, type=int, help='simulating time-steps') 18 | parser.add_argument('-platform', default='cpu', help='device') 19 | parser.add_argument('-batch', default=64, type=int, help='batch size') 20 | parser.add_argument('-epochs', default=15, type=int, metavar='N', help='number of total epochs to run') 21 | parser.add_argument('-out-dir', type=str, default='./logs', help='root dir for saving logs and checkpoint') 22 | parser.add_argument('-lr', default=1e-3, type=float, help='learning rate') 23 | parser.add_argument('-tau', default=2.0, type=float, help='parameter tau of LIF neuron') 24 | args = parser.parse_args() 25 | print(args) 26 | 27 | out_dir = os.path.join(args.out_dir, f'T{args.T}_b{args.batch}_lr{args.lr}') 28 | if not os.path.exists(out_dir): 29 | os.makedirs(out_dir) 30 | with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt: 31 | args_txt.write(str(args)) 32 | args_txt.write('\n') 33 | args_txt.write(' '.join(sys.argv)) 34 | 35 | bm.set_platform(args.platform) 36 | bm.set_environment(mode=bm.training_mode, dt=1.) 
37 | 38 | 39 | class SNN(bp.DynamicalSystem): 40 | def __init__(self, tau): 41 | super().__init__() 42 | self.l1 = bp.dnn.Dense(28 * 28, 10, b_initializer=None) 43 | self.l2 = bp.dyn.Lif(10, V_rest=0., V_reset=0., V_th=1., tau=tau, spk_fun=bm.surrogate.arctan) 44 | 45 | def update(self, x): 46 | return x >> self.l1 >> self.l2 47 | 48 | 49 | net = SNN(args.tau) 50 | 51 | # data 52 | train_data = bd.vision.MNIST(r'D:/data', split='train', download=True) 53 | test_data = bd.vision.MNIST(r'D:/data', split='test', download=True) 54 | x_train = bm.asarray(train_data.data / 255, dtype=bm.float_).reshape(-1, 28 * 28) 55 | y_train = bm.asarray(train_data.targets, dtype=bm.int_) 56 | x_test = bm.asarray(test_data.data / 255, dtype=bm.float_).reshape(-1, 28 * 28) 57 | y_test = bm.asarray(test_data.targets, dtype=bm.int_) 58 | 59 | # loss 60 | encoder = bp.encoding.PoissonEncoder(min_val=0., max_val=1.) 61 | 62 | 63 | def loss_fun(xs, ys): 64 | net.reset_state(batch_size=xs.shape[0]) 65 | xs = encoder.multi_steps(xs, n_time=args.T * bm.get_dt()) 66 | # shared arguments for looping over time 67 | indices = np.arange(args.T) 68 | outs = bm.for_loop(net.step_run, (indices, xs)) 69 | out_fr = bm.mean(outs, axis=0) 70 | ys_onehot = bm.one_hot(ys, 10, dtype=bm.float_) 71 | l = bp.losses.mean_squared_error(out_fr, ys_onehot) 72 | n = bm.sum(out_fr.argmax(1) == ys) 73 | return l, n 74 | 75 | 76 | # gradient 77 | grad_fun = bm.grad(loss_fun, grad_vars=net.train_vars().unique(), has_aux=True, return_value=True) 78 | 79 | # optimizer 80 | optimizer = bp.optim.Adam(lr=args.lr, train_vars=net.train_vars().unique()) 81 | 82 | 83 | # train 84 | @bm.jit 85 | def train(xs, ys): 86 | grads, l, n = grad_fun(xs, ys) 87 | optimizer.update(grads) 88 | return l, n 89 | 90 | 91 | max_test_acc = 0. 92 | 93 | # computing 94 | for epoch_i in range(args.epochs): 95 | bm.random.shuffle(x_train, key=123) 96 | bm.random.shuffle(y_train, key=123) 97 | 98 | t0 = time.time() 99 | loss, train_acc = [], 0. 
100 | for i in range(0, x_train.shape[0], args.batch): 101 | X = x_train[i: i + args.batch] 102 | Y = y_train[i: i + args.batch] 103 | l, correct_num = train(X, Y) 104 | loss.append(l) 105 | train_acc += correct_num 106 | train_acc /= x_train.shape[0] 107 | train_loss = jnp.mean(jnp.asarray(loss)) 108 | optimizer.lr.step_epoch() 109 | 110 | loss, test_acc = [], 0. 111 | for i in range(0, x_test.shape[0], args.batch): 112 | X = x_test[i: i + args.batch] 113 | Y = y_test[i: i + args.batch] 114 | l, correct_num = loss_fun(X, Y) 115 | loss.append(l) 116 | test_acc += correct_num 117 | test_acc /= x_test.shape[0] 118 | test_loss = jnp.mean(jnp.asarray(loss)) 119 | 120 | t = (time.time() - t0) / 60 121 | print(f'epoch {epoch_i}, used {t:.3f} min, ' 122 | f'train_loss = {train_loss:.4f}, train_acc = {train_acc:.4f}, ' 123 | f'test_loss = {test_loss:.4f}, test_acc = {test_acc:.4f}') 124 | 125 | if max_test_acc < test_acc: 126 | max_test_acc = test_acc 127 | states = { 128 | 'net': net.state_dict(), 129 | 'optimizer': optimizer.state_dict(), 130 | 'epoch_i': epoch_i, 131 | 'train_acc': train_acc, 132 | 'test_acc': test_acc, 133 | } 134 | bp.checkpoints.save_pytree(os.path.join(out_dir, 'mnist-lif.bp'), states) 135 | 136 | # inference 137 | state_dict = bp.checkpoints.load_pytree(os.path.join(out_dir, 'mnist-lif.bp')) 138 | net.load_state_dict(state_dict['net']) 139 | 140 | runner = bp.DSRunner(net, data_first_axis='T') 141 | correct_num = 0 142 | for i in range(0, x_test.shape[0], 512): 143 | X = encoder.multi_steps(x_test[i: i + 512], n_time=args.T * bm.get_dt()) 144 | Y = y_test[i: i + 512] 145 | out_fr = bm.mean(runner.predict(inputs=X, reset_state=True), axis=0) 146 | correct_num += bm.sum(out_fr.argmax(1) == Y) 147 | 148 | print('Max test accuracy: ', correct_num / x_test.shape[0]) 149 | -------------------------------------------------------------------------------- /brain_inspired_computing/spiking_FPTT/README.md: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/brain_inspired_computing/spiking_FPTT/README.md -------------------------------------------------------------------------------- /brain_inspired_computing/spiking_FPTT/add_task.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import time 4 | 5 | import brainpy as bp 6 | import brainpy.math as bm 7 | import numba 8 | import numpy as np 9 | 10 | 11 | @numba.njit(fastmath=True, nogil=True) 12 | def _adding_problem_generator(X_num, X_mask, Y, N, seq_len=6, number_of_ones=2): 13 | for i in numba.prange(N):  # njit is not given parallel=True, so prange runs like a plain range 14 | positions1 = np.random.choice(np.arange(math.floor(seq_len / 2)), 15 | size=math.floor(number_of_ones / 2), 16 | replace=False) 17 | positions2 = np.random.choice(np.arange(math.ceil(seq_len / 2), seq_len), 18 | size=math.ceil(number_of_ones / 2), 19 | replace=False) 20 | for p in positions1: 21 | X_mask[p, i] = 1 22 | Y[i, 0] += X_num[p, i, 0] 23 | for p in positions2: 24 | X_mask[p, i] = 1 25 | Y[i, 0] += X_num[p, i, 0] 26 | 27 | 28 | class AddTask: 29 | def __init__(self, seq_len=6, high=1, number_of_ones=2): 30 | self.seq_len = seq_len 31 | self.high = high 32 | self.number_of_ones = number_of_ones 33 | 34 | def __call__(self, batch_size): 35 | x_num = np.random.uniform(low=0, high=self.high, size=(self.seq_len, batch_size, 1)) 36 | x_mask = np.zeros((self.seq_len, batch_size, 1)) 37 | ys = np.ones((batch_size, 1))  # targets start at 1, so each y is 1 + the sum of the marked numbers 38 | _adding_problem_generator(x_num, x_mask, ys, batch_size, self.seq_len, self.number_of_ones) 39 | xs = np.append(x_num, x_mask, axis=2)  # shape (seq_len, batch_size, 2): values and 0/1 markers 40 | return xs, ys 41 | 42 | 43 | class IF(bp.DynamicalSystemNS): 44 | def __init__( 45 | self, size, V_th=0.5, 46 | spike_fun: bm.surrogate.Surrogate = bm.surrogate.MultiGaussianGrad() 47 | ): 48 | super().__init__() 49 | 50 | self.size = size 51 | self.V_th = V_th
52 | self.spike_fun = spike_fun 53 | self.reset_state(self.mode)  # NOTE(review): passes the mode where reset_state expects batch_size; presumably bp.init.variable_ accepts a mode too - confirm 54 | 55 | def reset_state(self, batch_size=1): 56 | self.V = bp.init.variable_(bm.zeros, self.size, batch_size) 57 | self.spike = bp.init.variable_(bm.zeros, self.size, batch_size) 58 | 59 | def update(self, x, tau_m): 60 | mem = self.V + (-self.V + x) * tau_m  # tau_m is the sigmoid gate from SNN._step, mixing the old V with the input x 61 | self.spike.value = self.spike_fun(mem - self.V_th) 62 | self.V.value = (1 - self.spike) * mem 63 | 64 | 65 | class SNN(bp.DynamicalSystemNS): 66 | def __init__(self, input_size, hidden_size): 67 | super().__init__() 68 | 69 | self.lin_inp = bp.layers.Linear(input_size + hidden_size, hidden_size) 70 | self.lin_tau = bp.layers.Linear(hidden_size + hidden_size, hidden_size) 71 | self.rnn = IF(hidden_size) 72 | self.out = bp.layers.Linear(hidden_size, 1, W_initializer=bp.init.XavierNormal()) 73 | self.act = bp.layers.Sigmoid() 74 | 75 | self.loss_func = bp.losses.MSELoss() 76 | 77 | def _step(self, x): 78 | inp = self.lin_inp(bm.cat((x, self.rnn.spike), dim=-1)) 79 | tau = self.act(self.lin_tau(bm.cat((inp, self.rnn.V), dim=-1))) 80 | self.rnn(inp, tau) 81 | 82 | def update(self, xs, y): 83 | # xs: (num_time, num_batch, num_input); y: (num_batch, 1) targets from AddTask 84 | bm.for_loop(self._step, xs) 85 | out = self.out(self.rnn.V) 86 | out = out.squeeze() 87 | y = y.squeeze() 88 | loss = self.loss_func(out, y) 89 | return loss, out 90 | 91 | 92 | class FPTT_Trainer: 93 | def __init__( 94 | self, 95 | net: bp.DynamicalSystemNS, 96 | opt: bp.optim.Optimizer, 97 | clip: float, 98 | alpha: float = 0.1, 99 | beta: float = 0.5, 100 | rho: float = 0.0, 101 | ): 102 | super().__init__() 103 | self.alpha = alpha 104 | self.beta = beta 105 | self.rho = rho 106 | self.clip = clip 107 | 108 | # objects 109 | self.net = net 110 | self.opt = opt 111 | opt.register_train_vars(net.train_vars().unique()) 112 | 113 | # parameters 114 | self.named_params = {} 115 | for name, param in self.opt.vars_to_train.items(): 116 | sm = bm.Variable(param.clone())  # smoothed copy of the parameter, updated in update_params 117 | lm = bm.Variable(bm.zeros_like(param))  # accumulated correction term, starts at zero
118 | self.named_params[name] = (sm, lm) 119 | 120 | def reset_params(self):  # copy each smoothed value sm back into its parameter 121 | for name, param in self.opt.vars_to_train.items(): 122 | param.value = self.named_params[name][0].value 123 | 124 | def update_params(self): 125 | for name, param in self.opt.vars_to_train.items(): 126 | sm, lm = self.named_params[name] 127 | lm += (-self.alpha * (param - sm)) 128 | sm *= (1.0 - self.beta) 129 | sm += (self.beta * param - (self.beta / self.alpha) * lm) 130 | 131 | def dyn_loss(self, lambd=1.): 132 | regularization = 0. 133 | for name, param in self.opt.vars_to_train.items(): 134 | sm, lm = self.named_params[name] 135 | regularization += (self.rho - 1.) * bm.sum(param * lm) 136 | regularization += lambd * 0.5 * self.alpha * bm.sum(bm.square(param - sm)) 137 | return regularization 138 | 139 | def f_loss(self, xs, ys, progress): 140 | l, _ = self.net(xs, ys) 141 | reg = self.dyn_loss() 142 | return l * progress + reg, (l, reg)  # progress scales the task loss; (l, reg) is returned as aux for logging 143 | 144 | @bm.cls_jit 145 | def predict(self, xs, ys): 146 | return self.net(xs, ys)[0] 147 | 148 | @bm.cls_jit 149 | def fit(self, xs, ys, progress): 150 | grads, (loss, reg) = bm.grad(self.f_loss, grad_vars=self.opt.vars_to_train, has_aux=True)(xs, ys, progress) 151 | grads = bm.clip_by_norm(grads, self.clip) 152 | self.opt.update(grads) 153 | self.update_params() 154 | return loss, reg 155 | 156 | 157 | class BPTT_Trainer: 158 | def __init__( 159 | self, 160 | net: bp.DynamicalSystemNS, 161 | opt: bp.optim.Optimizer, 162 | clip: float, 163 | ): 164 | super().__init__() 165 | self.clip = clip 166 | 167 | # objects 168 | self.net = net 169 | self.opt = opt 170 | opt.register_train_vars(net.train_vars().unique()) 171 | 172 | def f_loss(self, xs, ys): 173 | l, _ = self.net(xs, ys) 174 | return l 175 | 176 | @bm.cls_jit 177 | def predict(self, xs, ys): 178 | return self.net(xs, ys)[0] 179 | 180 | @bm.cls_jit 181 | def fit(self, xs, ys): 182 | grads, loss = bm.grad(self.f_loss, grad_vars=self.opt.vars_to_train, return_value=True)(xs, ys) 183 |
grads = bm.clip_by_norm(grads, self.clip) 184 | self.opt.update(grads) 185 | return loss 186 | 187 | 188 | def fptt_training(): 189 | parser = argparse.ArgumentParser() 190 | parser.add_argument('--alpha', type=float, default=.1, help='Alpha') 191 | parser.add_argument('--beta', type=float, default=0.5, help='Beta') 192 | parser.add_argument('--rho', type=float, default=0.0, help='Rho') 193 | parser.add_argument('--bptt', type=int, default=300, help='sequence length') 194 | parser.add_argument('--nhid', type=int, default=128, help='number of hidden units per layer') 195 | parser.add_argument('--lr', type=float, default=3e-3, help='initial learning rate (default: 4e-3)')  # NOTE(review): help text says 4e-3 but the actual default is 3e-3 196 | parser.add_argument('--clip', type=float, default=1., help='gradient clipping') 197 | parser.add_argument('--epochs', type=int, default=1000, help='upper epoch limit (default: 200)')  # NOTE(review): help text says 200 but the actual default is 1000 198 | parser.add_argument('--parts', type=int, default=10, help='Parts to split the sequential input into (default: 10)') 199 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='batch size') 200 | parser.add_argument('--save', type=str, default='', help='path of model to save') 201 | parser.add_argument('--load', type=str, default='', help='path of model to load') 202 | parser.add_argument('--wdecay', type=float, default=0.0, help='weight decay') 203 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 204 | parser.add_argument('--optim', type=str, default='adam', help='optimizer to use') 205 | args = parser.parse_args() 206 | 207 | bm.random.seed(args.seed) 208 | 209 | # model 210 | with bm.environment(mode=bm.TrainingMode(batch_size=args.batch_size)): 211 | model = SNN(2, args.nhid) 212 | 213 | # dataset 214 | dataset = AddTask(args.bptt, number_of_ones=2) 215 | 216 | # optimizer 217 | if args.optim == 'adam': 218 | optimizer = bp.optim.Adam(lr=args.lr, weight_decay=args.wdecay) 219 | elif args.optim == 'sgd': 220 | optimizer = bp.optim.SGD(lr=args.lr,
weight_decay=args.wdecay) 221 | else: 222 | raise ValueError 223 | 224 | # trainer 225 | trainer = FPTT_Trainer(model, optimizer, args.clip, alpha=args.alpha, beta=args.beta, rho=args.rho) 226 | 227 | # loading 228 | if args.load: 229 | states = bp.checkpoints.load_pytree(args.load) 230 | model.load_state_dict(states['model']) 231 | optimizer.load_state_dict(states['opt']) 232 | 233 | # training 234 | step = args.bptt // args.parts  # the last args.bptt % args.parts time steps are never passed to trainer.fit 235 | for epoch in range(1, args.epochs + 1): 236 | model.reset_state(args.batch_size) 237 | 238 | # fitting 239 | s_t = time.time() 240 | x, y = dataset(args.batch_size) 241 | losses, regs = [], [] 242 | for p in range(0, args.parts): 243 | start = p * step 244 | l, r = trainer.fit(x[start: start + step], y, (p + 1) / args.parts) 245 | losses.append(l.item()) 246 | regs.append(r.item()) 247 | 248 | # prediction 249 | x, y = dataset(args.batch_size) 250 | loss_act = trainer.predict(x, y).item() 251 | print(f'Epoch {epoch}, ' 252 | f'time {time.time() - s_t:.4f} s, ' 253 | f'train loss {np.mean(np.array(losses)):.4f}, ' 254 | f'train reg {np.mean(np.array(regs)):.4f}, ' 255 | f'test loss {loss_act:.4f}.
') 256 | trainer.reset_params() 257 | trainer.opt.lr.step_epoch() 258 | 259 | 260 | def bptt_training(): 261 | parser = argparse.ArgumentParser() 262 | parser.add_argument('--bptt', type=int, default=300, help='sequence length') 263 | parser.add_argument('--nhid', type=int, default=128, help='number of hidden units per layer') 264 | parser.add_argument('--lr', type=float, default=3e-3, help='initial learning rate (default: 4e-3)') 265 | parser.add_argument('--clip', type=float, default=1., help='gradient clipping') 266 | parser.add_argument('--epochs', type=int, default=1000, help='upper epoch limit (default: 200)') 267 | parser.add_argument('--parts', type=int, default=10, help='Parts to split the sequential input into (default: 10)') 268 | parser.add_argument('--batch_size', type=int, default=128, metavar='N', help='batch size') 269 | parser.add_argument('--save', type=str, default='', help='path of model to save') 270 | parser.add_argument('--load', type=str, default='', help='path of model to load') 271 | parser.add_argument('--wdecay', type=float, default=0.0, help='weight decay') 272 | parser.add_argument('--seed', type=int, default=1111, help='random seed') 273 | parser.add_argument('--optim', type=str, default='adam', help='optimizer to use') 274 | args = parser.parse_args() 275 | 276 | bm.random.seed(args.seed) 277 | 278 | # model 279 | with bm.environment(mode=bm.TrainingMode(batch_size=args.batch_size)): 280 | model = SNN(2, args.nhid) 281 | 282 | # dataset 283 | dataset = AddTask(args.bptt, number_of_ones=2) 284 | 285 | # optimizer 286 | if args.optim == 'adam': 287 | optimizer = bp.optim.Adam(lr=args.lr, weight_decay=args.wdecay) 288 | elif args.optim == 'sgd': 289 | optimizer = bp.optim.SGD(lr=args.lr, weight_decay=args.wdecay) 290 | else: 291 | raise ValueError 292 | 293 | # trainer 294 | trainer = BPTT_Trainer(model, optimizer, args.clip) 295 | 296 | # loading 297 | if args.load: 298 | states = bp.checkpoints.load_pytree(args.load) 299 | 
model.load_state_dict(states['model']) 300 | optimizer.load_state_dict(states['opt']) 301 | 302 | # training 303 | for epoch in range(1, args.epochs + 1): 304 | model.reset_state(args.batch_size) 305 | 306 | # fitting 307 | s_t = time.time() 308 | x, y = dataset(args.batch_size) 309 | loss = trainer.fit(x, y).item() 310 | 311 | # prediction 312 | x, y = dataset(args.batch_size) 313 | loss_act = trainer.predict(x, y).item() 314 | 315 | print(f'Epoch {epoch}, ' 316 | f'time {time.time() - s_t:.4f} s, ' 317 | f'train loss {loss:.4f}, ' 318 | f'test loss {loss_act:.4f}. ') 319 | trainer.opt.lr.step_epoch() 320 | 321 | 322 | if __name__ == '__main__': 323 | # fptt_training() 324 | bptt_training() 325 | 326 | -------------------------------------------------------------------------------- /brain_inspired_computing/spiking_FPTT/mnist_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import numpy as np 4 | 5 | import brainpy as bp 6 | import brainpy.math as bm 7 | import brainpy_datasets as bdata 8 | from functools import partial 9 | 10 | 11 | class SigmoidBeta(bp.DynamicalSystemNS): 12 | def __init__(self, alpha=1., is_train=False): 13 | super(SigmoidBeta, self).__init__() 14 | if alpha is None: 15 | self.alpha = bm.asarray(1.) 
# create a tensor out of alpha 16 | else: 17 | self.alpha = bm.asarray(alpha) # create a tensor out of alpha 18 | if is_train: 19 | self.alpha = bm.TrainVar(self.alpha) 20 | 21 | def update(self, x): 22 | return bm.sigmoid(self.alpha * x) 23 | 24 | 25 | class LTC_LIF(bp.DynamicalSystemNS): 26 | def __init__( 27 | self, input_size, hidden_size, beta=1.8, b0=0.1, 28 | spike_fun: bm.surrogate.Surrogate = bm.surrogate.MultiGaussianGrad() 29 | ): 30 | super().__init__() 31 | 32 | self.hidden_size = hidden_size 33 | self.beta = beta 34 | self.b0 = b0 35 | self.spike_fun = spike_fun 36 | 37 | self.lin = bp.layers.Linear(input_size, hidden_size, W_initializer=bp.init.XavierNormal()) 38 | self.lr = bp.layers.Linear(hidden_size, hidden_size, W_initializer=bp.init.Orthogonal()) 39 | self.tauM = bp.layers.Linear(hidden_size + hidden_size, hidden_size, 40 | W_initializer=bp.init.XavierNormal()) 41 | self.tauAdp = bp.layers.Linear(hidden_size + hidden_size, hidden_size, 42 | W_initializer=bp.init.XavierNormal()) 43 | self.actM = SigmoidBeta(is_train=True) 44 | self.actAdp = SigmoidBeta(is_train=True) 45 | 46 | self.reset_state(self.mode) 47 | 48 | def reset_state(self, batch_size=1): 49 | self.V = bp.init.variable_(bm.zeros, self.hidden_size, batch_size) 50 | self.b = bp.init.variable_(bm.zeros, self.hidden_size, batch_size) 51 | self.spike = bp.init.variable_(bm.zeros, self.hidden_size, batch_size) 52 | 53 | def update(self, x): 54 | encoding = self.lin(x) + self.lr(self.spike) 55 | tauM = self.actM(self.tauM(bm.cat((encoding, self.V), dim=-1))) 56 | tauAdp = self.actAdp(self.tauAdp(bm.cat((encoding, self.b), dim=-1))) 57 | b = tauAdp * self.b + (1 - tauAdp) * self.spike 58 | self.b.value = b 59 | B = self.b0 + self.beta * b 60 | d_mem = -self.V + encoding 61 | mem = self.V + d_mem * tauM 62 | spike = self.spike_fun(mem - B) 63 | self.V.value = (1 - spike) * mem 64 | self.spike.value = spike 65 | return spike 66 | 67 | 68 | class ReadoutLIF(bp.DynamicalSystemNS): 69 | def 
__init__( 70 | self, input_size, hidden_size, beta=1.8, b0=0.1, 71 | ): 72 | super().__init__() 73 | 74 | self.size = hidden_size 75 | self.beta = beta 76 | self.b0 = b0 77 | 78 | self.lin = bp.layers.Linear(input_size, hidden_size, 79 | W_initializer=bp.init.XavierNormal()) 80 | self.tauM = bp.layers.Linear(hidden_size + hidden_size, hidden_size, 81 | W_initializer=bp.init.XavierNormal()) 82 | self.actM = SigmoidBeta(is_train=True) 83 | 84 | self.reset_state(self.mode) 85 | 86 | def reset_state(self, batch_size=1): 87 | self.V = bp.init.variable_(bm.zeros, self.size, batch_size) 88 | 89 | def update(self, x): 90 | encoding = self.lin(x) 91 | tauM = self.actM(self.tauM(bm.cat((encoding, self.V), dim=-1))) 92 | mem = (1 - tauM) * self.V + tauM * encoding 93 | self.V.value = mem 94 | return mem 95 | 96 | 97 | class SNN(bp.DynamicalSystemNS): 98 | def __init__(self, input_size, hidden_size, output_size): 99 | super(SNN, self).__init__() 100 | 101 | self.input_size = input_size 102 | self.hidden_size = hidden_size 103 | self.output_size = output_size 104 | 105 | self.layer1 = LTC_LIF(input_size, hidden_size) 106 | self.layer2 = LTC_LIF(hidden_size, hidden_size) 107 | self.layer3 = ReadoutLIF(hidden_size, output_size) 108 | 109 | self.fr = bm.Variable(bm.asarray(0.)) 110 | 111 | def update(self, x): 112 | x = bm.expand_dims(x, axis=1) 113 | spk_1 = self.layer1(x) 114 | spk_2 = self.layer2(spk_1) 115 | out = self.layer3(spk_2) 116 | self.fr += (spk_1.mean() + spk_2.mean()) 117 | return out 118 | 119 | 120 | class FPTT_Trainer: 121 | def __init__( 122 | self, 123 | net: bp.DynamicalSystem, 124 | optimizer: bp.optim.Optimizer, 125 | 126 | debias: bool = False, 127 | clip: float = 0., 128 | alpha: float = 0.1, 129 | beta: float = 0.5, 130 | rho: float = 0., 131 | ): 132 | self.clip = clip 133 | self.alpha = alpha 134 | self.beta = beta 135 | self.rho = rho 136 | self.debias = debias 137 | 138 | self.optimizer = optimizer 139 | self.net = net 140 | 
optimizer.register_train_vars(net.train_vars().unique()) 141 | 142 | # parameters 143 | self.named_params = {} 144 | for name, param in optimizer.vars_to_train.items(): 145 | sm = bm.Variable(param.clone()) 146 | lm = bm.Variable(bm.zeros_like(param)) 147 | if debias: 148 | dm = bm.Variable(bm.zeros_like(param)) 149 | self.named_params[name] = (sm, lm, dm) 150 | else: 151 | self.named_params[name] = (sm, lm) 152 | 153 | def reset_params(self): 154 | if not self.debias: 155 | for name, param in self.optimizer.vars_to_train.items(): 156 | param.value = self.named_params[name][0].value 157 | 158 | def update_params(self, epoch): 159 | for name, param in self.optimizer.vars_to_train.items(): 160 | if self.debias: 161 | sm, lm, dm = self.named_params[name] 162 | beta = (1. / (1. + epoch)) 163 | sm *= (1.0 - beta) 164 | sm += (beta * param) 165 | dm *= (1. - beta) 166 | dm += (beta * lm) 167 | else: 168 | sm, lm = self.named_params[name] 169 | lm += (-self.alpha * (param - sm)) 170 | sm *= (1.0 - self.beta) 171 | sm += (self.beta * param - (self.beta / self.alpha) * lm) 172 | 173 | def dyn_loss(self, lambd=1.): 174 | reg = 0. 175 | for name, param in self.optimizer.vars_to_train.items(): 176 | if self.debias: 177 | sm, lm, dm = self.named_params[name] 178 | reg += (self.rho - 1.) * bm.sum(param * lm) 179 | reg += (1. - self.rho) * bm.sum(param * dm) 180 | else: 181 | sm, lm = self.named_params[name] 182 | reg += (self.rho - 1.) 
* bm.sum(param * lm) 183 | reg += lambd * 0.5 * self.alpha * bm.sum(bm.square(param - sm)) 184 | return reg 185 | 186 | def _loss(self, x, y, progress): 187 | out = self.net(x) 188 | loss = progress * bp.losses.cross_entropy_loss(out, y, reduction='mean') 189 | reg = self.dyn_loss() 190 | return loss + reg, (loss, reg) 191 | 192 | def _train(self, x, progress, y, epoch): 193 | grads, (loss, reg) = bm.grad(self._loss, grad_vars=self.optimizer.vars_to_train, has_aux=True)( 194 | x, y, progress 195 | ) 196 | if self.clip > 0.: 197 | grads = bm.clip_by_norm(grads, self.clip) 198 | self.optimizer.update(grads) 199 | self.update_params(epoch) 200 | return loss, reg 201 | 202 | @bm.cls_jit 203 | def fit(self, xs, ys, epoch): # xs: (num_time, num_batch) 204 | progresses = bm.linspace(0., 1., xs.shape[0]) 205 | loss, reg = bm.for_loop(partial(self._train, epoch=epoch, y=ys), (xs, progresses), remat=True) 206 | return loss, reg 207 | 208 | 209 | parser = argparse.ArgumentParser() 210 | parser.add_argument('--alpha', type=float, default=0.1, help='Alpha') 211 | parser.add_argument('--beta', type=float, default=0.5, help='Beta') 212 | parser.add_argument('--rho', type=float, default=0.0, help='Rho') 213 | parser.add_argument('--debias', action='store_true', help='FedDyn debias algorithm') 214 | 215 | parser.add_argument('--bptt', type=int, default=300, help='sequence length') 216 | parser.add_argument('--nhid', type=int, default=256, help='number of hidden units per layer') 217 | parser.add_argument('--lr', type=float, default=5e-3, help='initial learning rate (default: 4e-3)') 218 | parser.add_argument('--clip', type=float, default=1., help='gradient clipping') 219 | 220 | parser.add_argument('--epochs', type=int, default=250, help='upper epoch limit (default: 200)') 221 | parser.add_argument('--batch_size', type=int, default=512, metavar='N', help='batch size') 222 | 223 | parser.add_argument('--wdecay', type=float, default=0., help='weight decay') 224 | 
parser.add_argument('--optim', type=str, default='adam', help='optimizer to use') 225 | parser.add_argument('--when', nargs='+', type=int, default=[10, 30, 50, 75, 90], help='When to decay the learning rate') 226 | parser.add_argument('--load', type=str, default='', help='path to load the model') 227 | parser.add_argument('--save', type=str, default='./models/', help='path to load the model') 228 | parser.add_argument('--permute', action='store_true', help='use permuted dataset (default: False)') 229 | args = parser.parse_args() 230 | 231 | bm.set(mode=bm.TrainingMode(args.batch_size)) # important 232 | 233 | # datasets 234 | n_classes = 10 235 | train_data = bdata.vision.MNIST(r'/mnt/d/data/', download=True, split='train') 236 | x_train = (train_data.data / 255.).reshape(-1, 28 * 28) 237 | y_train = train_data.targets 238 | 239 | 240 | def train_data(): 241 | indices = np.random.permutation(len(x_train)) 242 | for i in range(0, len(x_train), args.batch_size): 243 | idx = indices[i: i + args.batch_size] 244 | yield x_train[idx].T, y_train[idx] 245 | 246 | 247 | # model 248 | model = SNN(1, args.nhid, n_classes) 249 | 250 | # optimizer 251 | lr = bp.optim.MultiStepLR(args.lr, args.when, gamma=0.1) 252 | if args.optim == 'adam': 253 | optimizer = bp.optim.Adam(lr=lr, weight_decay=args.wdecay) 254 | elif args.optim == 'sgd': 255 | optimizer = bp.optim.SGD(lr=lr, weight_decay=args.wdecay) 256 | else: 257 | raise ValueError 258 | 259 | # trainer 260 | trainer = FPTT_Trainer(model, 261 | optimizer, 262 | debias=args.debias, 263 | clip=args.clip, 264 | alpha=args.alpha, 265 | beta=args.beta, 266 | rho=args.rho) 267 | 268 | # training 269 | for epoch in range(1, args.epochs + 1): 270 | num_data = 0 271 | for data, target in train_data(): 272 | t0 = time.time() 273 | model.reset_state(data.shape[0]) 274 | losses, regs = trainer.fit(data, target, epoch) 275 | total_reg_loss = regs.sum().item() 276 | total_clf_loss = losses.sum().item() 277 | 278 | num_data += data.shape[0] 
279 | print( 280 | f'Epoch {epoch} [{num_data}/{len(x_train)}]\t' 281 | f'time: {time.time() - t0:.4f}s\tLoss: {total_clf_loss:.6f}\t' 282 | f'Reg: {total_reg_loss:.6f}' 283 | ) 284 | -------------------------------------------------------------------------------- /conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | 14 | # -- Project information ----------------------------------------------------- 15 | 16 | project = 'BrainPy Examples' 17 | copyright = '2023, BrainPy team' 18 | author = 'BrainPy team' 19 | 20 | # -- General configuration --------------------------------------------------- 21 | 22 | # Add any Sphinx extension module names here, as strings. They can be 23 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 24 | # ones. 25 | extensions = [ 26 | 'sphinx.ext.autodoc', 27 | 'sphinx.ext.autosummary', 28 | 'sphinx.ext.intersphinx', 29 | 'sphinx.ext.mathjax', 30 | 'sphinx.ext.napoleon', 31 | 'sphinx.ext.viewcode', 32 | 'nbsphinx', 33 | ] 34 | # Add any paths that contain templates here, relative to this directory. 
35 | templates_path = ['_templates'] 36 | 37 | 38 | # source_suffix = '.rst' 39 | autosummary_generate = True 40 | napolean_use_rtype = False 41 | 42 | 43 | # Execute notebooks before conversion: 'always', 'never', 'auto' (default) 44 | # We execute all notebooks, exclude the slow ones using 'exclude_patterns' 45 | nbsphinx_execute = 'never' 46 | 47 | 48 | # -- Options for HTML output ------------------------------------------------- 49 | # The suffix(es) of source filenames. 50 | # You can specify multiple suffix as a list of string: 51 | # 52 | source_suffix = '.rst' 53 | 54 | # The master toctree document. 55 | main_doc = 'index' 56 | master_doc = 'index' 57 | 58 | # List of patterns, relative to source directory, that match files and 59 | # directories to ignore when looking for source files. 60 | # This pattern also affects html_static_path and html_extra_path. 61 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 62 | 63 | 64 | # The theme to use for HTML and HTML Help pages. See the documentation for 65 | # a list of builtin themes. 
66 | # 67 | html_theme = 'sphinx_rtd_theme' 68 | 69 | 70 | -------------------------------------------------------------------------------- /decision_making/Wang_2002_decision_making_spiking.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "afe4708f", 6 | "metadata": {}, 7 | "source": [ 8 | "# *(Wang, 2002)* Decision making spiking model\n", 9 | "\n", 10 | "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/decision_making/Wang_2002_decision_making_spiking.ipynb)\n", 11 | "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/decision_making/Wang_2002_decision_making_spiking.ipynb)" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "83acef7d", 17 | "metadata": {}, 18 | "source": [ 19 | "Implementation of the paper: *Wang, Xiao-Jing. 
\"Probabilistic decision making by slow reverberation in cortical circuits.\" Neuron 36.5 (2002): 955-968.*\n", 20 | "\n", 21 | "- Author : Chaoming Wang (chao.brain@qq.com)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "88e26125", 27 | "metadata": {}, 28 | "source": [ 29 | "Please refer to github example folder:\n", 30 | "\n", 31 | "- https://github.com/brainpy/examples/blob/master/decision_making/Wang_2002_decision_making_spiking.py" 32 | ] 33 | } 34 | ], 35 | "metadata": { 36 | "jupytext": { 37 | "encoding": "# -*- coding: utf-8 -*-", 38 | "formats": "ipynb,py:percent" 39 | }, 40 | "kernelspec": { 41 | "display_name": "brainpy", 42 | "language": "python", 43 | "name": "brainpy" 44 | }, 45 | "language_info": { 46 | "codemirror_mode": { 47 | "name": "ipython", 48 | "version": 3 49 | }, 50 | "file_extension": ".py", 51 | "mimetype": "text/x-python", 52 | "name": "python", 53 | "nbconvert_exporter": "python", 54 | "pygments_lexer": "ipython3", 55 | "version": "3.9.12" 56 | }, 57 | "latex_envs": { 58 | "LaTeX_envs_menu_present": true, 59 | "autoclose": false, 60 | "autocomplete": true, 61 | "bibliofile": "biblio.bib", 62 | "cite_by": "apalike", 63 | "current_citInitial": 1, 64 | "eqLabelWithNumbers": true, 65 | "eqNumInitial": 1, 66 | "hotkeys": { 67 | "equation": "Ctrl-E", 68 | "itemize": "Ctrl-I" 69 | }, 70 | "labels_anchors": false, 71 | "latex_user_defs": false, 72 | "report_style_numbering": false, 73 | "user_envs_cfg": false 74 | }, 75 | "toc": { 76 | "base_numbering": 1, 77 | "nav_menu": {}, 78 | "number_sections": true, 79 | "sideBar": true, 80 | "skip_h1_title": false, 81 | "title_cell": "Table of Contents", 82 | "title_sidebar": "Contents", 83 | "toc_cell": false, 84 | "toc_position": {}, 85 | "toc_section_display": true, 86 | "toc_window_display": true 87 | } 88 | }, 89 | "nbformat": 4, 90 | "nbformat_minor": 5 91 | } 92 | -------------------------------------------------------------------------------- 
/decision_making/Wang_2002_decision_making_spiking.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import brainpy as bp 5 | import brainpy.math as bm 6 | import matplotlib.pyplot as plt 7 | 8 | print(bp.__version__) 9 | 10 | # bm.set_platform('cpu') 11 | 12 | 13 | class PoissonStim(bp.NeuGroup): 14 | def __init__(self, size, freq_mean, freq_var, t_interval): 15 | super(PoissonStim, self).__init__(size=size) 16 | 17 | # parameters 18 | self.freq_mean = freq_mean 19 | self.freq_var = freq_var 20 | self.t_interval = t_interval 21 | 22 | # variables 23 | self.freq = bp.init.variable_(bm.zeros, 1, self.mode) 24 | self.freq_t_last_change = bp.init.variable_(lambda s: bm.ones(s) * -1e7, 1, self.mode) 25 | self.spike = bp.init.variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode) 26 | self.rng = bm.random.RandomState() 27 | 28 | def reset_state(self, batch_size=None): 29 | self.freq.value = bp.init.variable_(bm.zeros, 1, batch_size) 30 | self.freq_t_last_change.value = bp.init.variable_(lambda s: bm.ones(s) * -1e7, 1, batch_size) 31 | self.spike.value = bp.init.variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) 32 | 33 | def update(self): 34 | t = bp.share['t'] 35 | dt = bp.share['dt'] 36 | in_interval = bm.logical_and(pre_stimulus_period < t, t < pre_stimulus_period + stimulus_period) 37 | in_interval = bm.ones_like(self.freq, dtype=bool) * in_interval 38 | prev_freq = bm.where(in_interval, self.freq, 0.) 
39 | in_interval = bm.logical_and(in_interval, (t - self.freq_t_last_change) >= self.t_interval) 40 | self.freq.value = bm.where(in_interval, self.rng.normal(self.freq_mean, self.freq_var, self.freq.shape), prev_freq) 41 | self.freq_t_last_change.value = bm.where(in_interval, t, self.freq_t_last_change) 42 | shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, bm.BatchingMode) else self.varshape 43 | self.spike.value = self.rng.random(shape) < self.freq * dt / 1000. 44 | 45 | 46 | class DecisionMaking(bp.Network): 47 | def __init__(self, scale=1., mu0=40., coherence=25.6, f=0.15): 48 | super(DecisionMaking, self).__init__() 49 | 50 | num_exc = int(1600 * scale) 51 | num_inh = int(400 * scale) 52 | num_A = int(f * num_exc) 53 | num_B = int(f * num_exc) 54 | num_N = num_exc - num_A - num_B 55 | print(f'Total network size: {num_exc + num_inh}') 56 | 57 | poisson_freq = 2400. # Hz 58 | w_pos = 1.7 59 | w_neg = 1. - f * (w_pos - 1.) / (1. - f) 60 | g_ext2E_AMPA = 2.1 # nS 61 | g_ext2I_AMPA = 1.62 # nS 62 | g_E2E_AMPA = 0.05 / scale # nS 63 | g_E2I_AMPA = 0.04 / scale # nS 64 | g_E2E_NMDA = 0.165 / scale # nS 65 | g_E2I_NMDA = 0.13 / scale # nS 66 | g_I2E_GABAa = 1.3 / scale # nS 67 | g_I2I_GABAa = 1.0 / scale # nS 68 | 69 | ampa_par = dict(delay_step=int(0.5 / bm.get_dt()), tau=2.0) 70 | gaba_par = dict(delay_step=int(0.5 / bm.get_dt()), tau=5.0) 71 | nmda_par = dict(delay_step=int(0.5 / bm.get_dt()), tau_decay=100, tau_rise=2., a=0.5) 72 | 73 | # E neurons/pyramid neurons 74 | A = bp.neurons.LIF(num_A, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, 75 | tau_ref=2., V_initializer=bp.init.OneInit(-70.)) 76 | B = bp.neurons.LIF(num_B, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, 77 | tau_ref=2., V_initializer=bp.init.OneInit(-70.)) 78 | N = bp.neurons.LIF(num_N, V_rest=-70., V_reset=-55., V_th=-50., tau=20., R=0.04, 79 | tau_ref=2., V_initializer=bp.init.OneInit(-70.)) 80 | # I neurons/interneurons 81 | I = bp.neurons.LIF(num_inh, 
V_rest=-70., V_reset=-55., V_th=-50., tau=10., R=0.05, 82 | tau_ref=1., V_initializer=bp.init.OneInit(-70.)) 83 | 84 | # poisson stimulus 85 | IA = PoissonStim(num_A, freq_var=10., t_interval=50., freq_mean=mu0 + mu0 / 100. * coherence) 86 | IB = PoissonStim(num_B, freq_var=10., t_interval=50., freq_mean=mu0 - mu0 / 100. * coherence) 87 | 88 | # noise neurons 89 | self.noise_B = bp.neurons.PoissonGroup(num_B, freqs=poisson_freq) 90 | self.noise_A = bp.neurons.PoissonGroup(num_A, freqs=poisson_freq) 91 | self.noise_N = bp.neurons.PoissonGroup(num_N, freqs=poisson_freq) 92 | self.noise_I = bp.neurons.PoissonGroup(num_inh, freqs=poisson_freq) 93 | 94 | # define external inputs 95 | self.IA2A = bp.synapses.Exponential(IA, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, 96 | output=bp.synouts.COBA(E=0.), **ampa_par) 97 | self.IB2B = bp.synapses.Exponential(IB, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, 98 | output=bp.synouts.COBA(E=0.), **ampa_par) 99 | 100 | # define E->E/I conn 101 | 102 | self.N2B_AMPA = bp.synapses.Exponential(N, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, 103 | output=bp.synouts.COBA(E=0.), **ampa_par) 104 | self.N2A_AMPA = bp.synapses.Exponential(N, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, 105 | output=bp.synouts.COBA(E=0.), **ampa_par) 106 | self.N2N_AMPA = bp.synapses.Exponential(N, N, bp.conn.All2All(), g_max=g_E2E_AMPA, 107 | output=bp.synouts.COBA(E=0.), **ampa_par) 108 | self.N2I_AMPA = bp.synapses.Exponential(N, I, bp.conn.All2All(), g_max=g_E2I_AMPA, 109 | output=bp.synouts.COBA(E=0.), **ampa_par) 110 | self.N2B_NMDA = bp.synapses.NMDA(N, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, 111 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 112 | self.N2A_NMDA = bp.synapses.NMDA(N, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, 113 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 114 | self.N2N_NMDA = bp.synapses.NMDA(N, N, bp.conn.All2All(), g_max=g_E2E_NMDA, 115 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 
116 | self.N2I_NMDA = bp.synapses.NMDA(N, I, bp.conn.All2All(), g_max=g_E2I_NMDA, 117 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 118 | 119 | self.B2B_AMPA = bp.synapses.Exponential(B, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, 120 | output=bp.synouts.COBA(E=0.), **ampa_par) 121 | self.B2A_AMPA = bp.synapses.Exponential(B, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, 122 | output=bp.synouts.COBA(E=0.), **ampa_par) 123 | self.B2N_AMPA = bp.synapses.Exponential(B, N, bp.conn.All2All(), g_max=g_E2E_AMPA, 124 | output=bp.synouts.COBA(E=0.), **ampa_par) 125 | self.B2I_AMPA = bp.synapses.Exponential(B, I, bp.conn.All2All(), g_max=g_E2I_AMPA, 126 | output=bp.synouts.COBA(E=0.), **ampa_par) 127 | self.B2B_NMDA = bp.synapses.NMDA(B, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, 128 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 129 | self.B2A_NMDA = bp.synapses.NMDA(B, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, 130 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 131 | self.B2N_NMDA = bp.synapses.NMDA(B, N, bp.conn.All2All(), g_max=g_E2E_NMDA, 132 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 133 | self.B2I_NMDA = bp.synapses.NMDA(B, I, bp.conn.All2All(), g_max=g_E2I_NMDA, 134 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 135 | 136 | self.A2B_AMPA = bp.synapses.Exponential(A, B, bp.conn.All2All(), g_max=g_E2E_AMPA * w_neg, 137 | output=bp.synouts.COBA(E=0.), **ampa_par) 138 | self.A2A_AMPA = bp.synapses.Exponential(A, A, bp.conn.All2All(), g_max=g_E2E_AMPA * w_pos, 139 | output=bp.synouts.COBA(E=0.), **ampa_par) 140 | self.A2N_AMPA = bp.synapses.Exponential(A, N, bp.conn.All2All(), g_max=g_E2E_AMPA, 141 | output=bp.synouts.COBA(E=0.), **ampa_par) 142 | self.A2I_AMPA = bp.synapses.Exponential(A, I, bp.conn.All2All(), g_max=g_E2I_AMPA, 143 | output=bp.synouts.COBA(E=0.), **ampa_par) 144 | self.A2B_NMDA = bp.synapses.NMDA(A, B, bp.conn.All2All(), g_max=g_E2E_NMDA * w_neg, 145 | output=bp.synouts.MgBlock(E=0., 
cc_Mg=1.), **nmda_par) 146 | self.A2A_NMDA = bp.synapses.NMDA(A, A, bp.conn.All2All(), g_max=g_E2E_NMDA * w_pos, 147 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 148 | self.A2N_NMDA = bp.synapses.NMDA(A, N, bp.conn.All2All(), g_max=g_E2E_NMDA, 149 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 150 | self.A2I_NMDA = bp.synapses.NMDA(A, I, bp.conn.All2All(), g_max=g_E2I_NMDA, 151 | output=bp.synouts.MgBlock(E=0., cc_Mg=1.), **nmda_par) 152 | 153 | # define I->E/I conn 154 | self.I2B = bp.synapses.Exponential(I, B, bp.conn.All2All(), g_max=g_I2E_GABAa, 155 | output=bp.synouts.COBA(E=-70.), **gaba_par) 156 | self.I2A = bp.synapses.Exponential(I, A, bp.conn.All2All(), g_max=g_I2E_GABAa, 157 | output=bp.synouts.COBA(E=-70.), **gaba_par) 158 | self.I2N = bp.synapses.Exponential(I, N, bp.conn.All2All(), g_max=g_I2E_GABAa, 159 | output=bp.synouts.COBA(E=-70.), **gaba_par) 160 | self.I2I = bp.synapses.Exponential(I, I, bp.conn.All2All(), g_max=g_I2I_GABAa, 161 | output=bp.synouts.COBA(E=-70.), **gaba_par) 162 | 163 | # define external projections 164 | self.noise2B = bp.synapses.Exponential(self.noise_B, B, bp.conn.One2One(), g_max=g_ext2E_AMPA, 165 | output=bp.synouts.COBA(E=0.), **ampa_par) 166 | self.noise2A = bp.synapses.Exponential(self.noise_A, A, bp.conn.One2One(), g_max=g_ext2E_AMPA, 167 | output=bp.synouts.COBA(E=0.), **ampa_par) 168 | self.noise2N = bp.synapses.Exponential(self.noise_N, N, bp.conn.One2One(), g_max=g_ext2E_AMPA, 169 | output=bp.synouts.COBA(E=0.), **ampa_par) 170 | self.noise2I = bp.synapses.Exponential(self.noise_I, I, bp.conn.One2One(), g_max=g_ext2I_AMPA, 171 | output=bp.synouts.COBA(E=0.), **ampa_par) 172 | 173 | # nodes 174 | self.B = B 175 | self.A = A 176 | self.N = N 177 | self.I = I 178 | self.IA = IA 179 | self.IB = IB 180 | 181 | 182 | def visualize_raster(ax, mon, t_start=0., title=None): 183 | bp.visualize.raster_plot(mon['ts'], mon['A.spike'], markersize=1, ax=ax, color='', label="Group A") 184 | 
bp.visualize.raster_plot(mon['ts'], mon['B.spike'], markersize=1, ax=ax, color='', label="Group B") 185 | if title: 186 | ax.set_title(title) 187 | ax.set_ylabel("Neuron Index") 188 | ax.set_xlim(t_start, total_period + 1) 189 | ax.axvline(pre_stimulus_period, linestyle='dashed') 190 | ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') 191 | ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') 192 | ax.legend() 193 | 194 | 195 | def visualize_results(axes, mon, t_start=0., title=None): 196 | ax = axes[0] 197 | bp.visualize.raster_plot(mon['ts'], mon['A.spike'], markersize=1, ax=ax) 198 | if title: 199 | ax.set_title(title) 200 | ax.set_ylabel("Group A") 201 | ax.set_xlim(t_start, total_period + 1) 202 | ax.axvline(pre_stimulus_period, linestyle='dashed') 203 | ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') 204 | ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') 205 | 206 | ax = axes[1] 207 | bp.visualize.raster_plot(mon['ts'], mon['B.spike'], markersize=1, ax=ax) 208 | ax.set_ylabel("Group B") 209 | ax.set_xlim(t_start, total_period + 1) 210 | ax.axvline(pre_stimulus_period, linestyle='dashed') 211 | ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') 212 | ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') 213 | 214 | ax = axes[2] 215 | rateA = bp.measure.firing_rate(mon['A.spike'], width=10.) 216 | rateB = bp.measure.firing_rate(mon['B.spike'], width=10.) 
217 | ax.plot(mon['ts'], rateA, label="Group A") 218 | ax.plot(mon['ts'], rateB, label="Group B") 219 | ax.set_ylabel('Population activity [Hz]') 220 | ax.set_xlim(t_start, total_period + 1) 221 | ax.axvline(pre_stimulus_period, linestyle='dashed') 222 | ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') 223 | ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') 224 | ax.legend() 225 | 226 | ax = axes[3] 227 | ax.plot(mon['ts'], mon['IA.freq'], label="group A") 228 | ax.plot(mon['ts'], mon['IB.freq'], label="group B") 229 | ax.set_ylabel("Input activity [Hz]") 230 | ax.set_xlim(t_start, total_period + 1) 231 | ax.axvline(pre_stimulus_period, linestyle='dashed') 232 | ax.axvline(pre_stimulus_period + stimulus_period, linestyle='dashed') 233 | ax.axvline(pre_stimulus_period + stimulus_period + delay_period, linestyle='dashed') 234 | ax.legend() 235 | ax.set_xlabel("Time [ms]") 236 | 237 | 238 | pre_stimulus_period = 100. 239 | stimulus_period = 1000. 240 | delay_period = 500. 241 | total_period = pre_stimulus_period + stimulus_period + delay_period 242 | 243 | 244 | def single_run(): 245 | net = DecisionMaking(scale=1., coherence=-80., mu0=50.) 246 | runner = bp.DSRunner( 247 | net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'] 248 | ) 249 | runner.run(total_period) 250 | 251 | fig, gs = bp.visualize.get_figure(4, 1, 3, 10) 252 | axes = [fig.add_subplot(gs[i, 0]) for i in range(4)] 253 | visualize_results(axes, mon=runner.mon) 254 | plt.show() 255 | 256 | 257 | def batching_run(): 258 | num_row, num_col = 3, 4 259 | num_batch = 12 260 | coherence = bm.expand_dims(bm.linspace(-100, 100., num_batch), 1) 261 | 262 | with bm.environment(mode=bm.BatchingMode(batch_size=num_batch)): 263 | net = DecisionMaking(scale=1., coherence=coherence, mu0=20.) 
264 | runner = bp.DSRunner( 265 | net, monitors=['A.spike', 'B.spike', 'IA.freq', 'IB.freq'], data_first_axis='B' 266 | ) 267 | runner.run(total_period) 268 | 269 | coherence = bm.as_numpy(coherence) 270 | fig, gs = bp.visualize.get_figure(num_row, num_col, 3, 4) 271 | for i in range(num_row): 272 | for j in range(num_col): 273 | idx = i * num_col + j 274 | if idx < num_batch: 275 | mon = {'A.spike': runner.mon['A.spike'][idx], 276 | 'B.spike': runner.mon['B.spike'][idx], 277 | 'IA.freq': runner.mon['IA.freq'][idx], 278 | 'IB.freq': runner.mon['IB.freq'][idx], 279 | 'ts': runner.mon['ts']} 280 | ax = fig.add_subplot(gs[i, j]) 281 | visualize_raster(ax, mon=mon, title=f'coherence={coherence[idx, 0]}%') 282 | plt.show() 283 | 284 | 285 | if __name__ == '__main__': 286 | single_run() 287 | batching_run() 288 | -------------------------------------------------------------------------------- /decision_making/Wang_2006_decision_making_rate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "34df5f48", 6 | "metadata": { 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "# *(Wong & Wang, 2006)* Decision making rate model\n", 13 | "\n", 14 | "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/decision_making/Wang_2006_decision_making_rate.ipynb)\n", 15 | "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/decision_making/Wang_2006_decision_making_rate.ipynb)" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "ad00c2ea", 21 | "metadata": { 22 | "pycharm": { 23 | "name": "#%% md\n" 24 | } 25 | }, 26 | "source": [ 27 | "Please refer to [Anlysis of A Decision Making 
Model](https://brainpy.readthedocs.io/en/latest/tutorial_analysis/decision_making_model.html)." 28 | ] 29 | } 30 | ], 31 | "metadata": { 32 | "kernelspec": { 33 | "display_name": "brainpy", 34 | "language": "python", 35 | "name": "brainpy" 36 | }, 37 | "language_info": { 38 | "codemirror_mode": { 39 | "name": "ipython", 40 | "version": 3 41 | }, 42 | "file_extension": ".py", 43 | "mimetype": "text/x-python", 44 | "name": "python", 45 | "nbconvert_exporter": "python", 46 | "pygments_lexer": "ipython3", 47 | "version": "3.8.11" 48 | }, 49 | "latex_envs": { 50 | "LaTeX_envs_menu_present": true, 51 | "autoclose": false, 52 | "autocomplete": true, 53 | "bibliofile": "biblio.bib", 54 | "cite_by": "apalike", 55 | "current_citInitial": 1, 56 | "eqLabelWithNumbers": true, 57 | "eqNumInitial": 1, 58 | "hotkeys": { 59 | "equation": "Ctrl-E", 60 | "itemize": "Ctrl-I" 61 | }, 62 | "labels_anchors": false, 63 | "latex_user_defs": false, 64 | "report_style_numbering": false, 65 | "user_envs_cfg": false 66 | }, 67 | "toc": { 68 | "base_numbering": 1, 69 | "nav_menu": {}, 70 | "number_sections": true, 71 | "sideBar": true, 72 | "skip_h1_title": false, 73 | "title_cell": "Table of Contents", 74 | "title_sidebar": "Contents", 75 | "toc_cell": false, 76 | "toc_position": {}, 77 | "toc_section_display": true, 78 | "toc_window_display": true 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 5 83 | } 84 | -------------------------------------------------------------------------------- /dynamics_analysis/2d_decision_making_with_lowdim_analyzer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "34df5f48", 6 | "metadata": { 7 | "pycharm": { 8 | "name": "#%% md\n" 9 | } 10 | }, 11 | "source": [ 12 | "# [2D] Decision Making Model with Low-dimensional Analyzer\n", 13 | "\n", 14 | 
"[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/examples/blob/main/dynamics_analysis/2d_decision_making_with_lowdim_analyzer.ipynb)\n", 15 | "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/examples/blob/main/dynamics_analysis/2d_decision_making_with_lowdim_analyzer.ipynb)" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "id": "ad00c2ea", 21 | "metadata": { 22 | "pycharm": { 23 | "name": "#%% md\n" 24 | } 25 | }, 26 | "source": [ 27 | "Please refer to [Anlysis of A Decision Making Model](https://brainpy.readthedocs.io/en/brainpy-2.x/tutorial_analysis/decision_making_model.html)." 28 | ] 29 | } 30 | ], 31 | "metadata": { 32 | "kernelspec": { 33 | "display_name": "brainpy", 34 | "language": "python", 35 | "name": "brainpy" 36 | }, 37 | "language_info": { 38 | "codemirror_mode": { 39 | "name": "ipython", 40 | "version": 3 41 | }, 42 | "file_extension": ".py", 43 | "mimetype": "text/x-python", 44 | "name": "python", 45 | "nbconvert_exporter": "python", 46 | "pygments_lexer": "ipython3", 47 | "version": "3.8.11" 48 | }, 49 | "latex_envs": { 50 | "LaTeX_envs_menu_present": true, 51 | "autoclose": false, 52 | "autocomplete": true, 53 | "bibliofile": "biblio.bib", 54 | "cite_by": "apalike", 55 | "current_citInitial": 1, 56 | "eqLabelWithNumbers": true, 57 | "eqNumInitial": 1, 58 | "hotkeys": { 59 | "equation": "Ctrl-E", 60 | "itemize": "Ctrl-I" 61 | }, 62 | "labels_anchors": false, 63 | "latex_user_defs": false, 64 | "report_style_numbering": false, 65 | "user_envs_cfg": false 66 | }, 67 | "toc": { 68 | "base_numbering": 1, 69 | "nav_menu": {}, 70 | "number_sections": true, 71 | "sideBar": true, 72 | "skip_h1_title": false, 73 | "title_cell": "Table of Contents", 74 | "title_sidebar": "Contents", 75 | "toc_cell": false, 76 | "toc_position": {}, 77 | "toc_section_display": true, 78 | "toc_window_display": 
true 79 | } 80 | }, 81 | "nbformat": 4, 82 | "nbformat_minor": 5 83 | } 84 | -------------------------------------------------------------------------------- /images/cann-decoding.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann-decoding.gif -------------------------------------------------------------------------------- /images/cann-encoding.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann-encoding.gif -------------------------------------------------------------------------------- /images/cann-tracking.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann-tracking.gif -------------------------------------------------------------------------------- /images/cann_1d_oscillatory_tracking.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann_1d_oscillatory_tracking.gif -------------------------------------------------------------------------------- /images/cann_2d_encoding.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann_2d_encoding.gif -------------------------------------------------------------------------------- /images/cann_2d_tracking.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/cann_2d_tracking.gif 
-------------------------------------------------------------------------------- /images/decision_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/decision_model.png -------------------------------------------------------------------------------- /images/izhikevich_patterns.jfif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/images/izhikevich_patterns.jfif -------------------------------------------------------------------------------- /index.rst: -------------------------------------------------------------------------------- 1 | BrainPy Examples 2 | ================ 3 | 4 | This repository contains examples of using `BrainPy `_ 5 | to implement various models of neurons, synapses, networks, etc. We welcome your implementations, 6 | which can be posted through our `github `_ page. 7 | 8 | If some code fails to run, please tell us through a GitHub issue at https://github.com/brainpy/examples/issues . 9 | 10 | If you find these examples useful for your research, please kindly `cite us `_. 11 | 12 | If you want to add more examples, please fork our GitHub repository https://github.com/brainpy/examples . 13 | 14 | 15 | 16 | Example categories: 17 | 18 | .. contents:: 19 | :local: 20 | :depth: 2 21 | 22 | 23 | 24 | 25 | Neuron Models 26 | ------------- 27 | 28 | - `(Izhikevich, 2003): Izhikevich Model `_ 29 | - `(Brette, Romain. 2004): LIF phase locking `_ 30 | - `(Gerstner, 2005): Adaptive Exponential Integrate-and-Fire model `_ 31 | - `(Niebur, et al., 2009): Generalized integrate-and-fire model `_ 32 | - `(Jansen & Rit, 1995): Jansen-Rit Model `_ 33 | - `(Teka, et al., 2018): Fractional-order Izhikevich neuron model `_ 34 | - `(Mondal, et.
al, 2019): Fractional-order FitzHugh-Rinzel bursting neuron model `_ 35 | 36 | 37 | 38 | Attractor Networks 39 | ------------------ 40 | 41 | - `CANN 1D Oscillatory Tracking `_ 42 | - `(Si Wu, 2008): Continuous-attractor Neural Network 1D `_ 43 | - `(Si Wu, 2008): Continuous-attractor Neural Network 2D `_ 44 | - `Discrete Hopfield Network `_ 45 | - `Discrete Hopfield Network Demo for Image Reconstruction `_ 46 | 47 | 48 | 49 | Decision Making Model 50 | --------------------- 51 | 52 | - `(Wang, 2002): Decision making spiking model `_ 53 | - `(Wong & Wang, 2006): Decision making rate model `_ 54 | 55 | 56 | 57 | 58 | E/I Balanced Network 59 | -------------------- 60 | 61 | 62 | - `(Vreeswijk & Sompolinsky, 1996): E/I balanced network `_ 63 | - `(Brette, et al., 2007): COBA `_ 64 | - `(Brette, et al., 2007): CUBA `_ 65 | - `(Brette, et al., 2007): COBA-HH `_ 66 | - `(Tian, et al., 2020): E/I Net for fast response `_ 67 | 68 | 69 | 70 | Brain-inspired Computing 71 | ------------------------ 72 | 73 | 74 | - `Classify MNIST dataset by a fully connected LIF layer `_ 75 | - `Convolutional SNN to Classify Fashion-MNIST `_ 76 | - `(2022, NeurIPS): Online Training Through Time for Spiking Neural Networks `_ 77 | - `(2019, Zenke, F.): SNN Surrogate Gradient Learning `_ 78 | - `(2019, Zenke, F.): SNN Surrogate Gradient Learning to Classify Fashion-MNIST `_ 79 | - `(2021, Raminmh): Liquid time-constant Networks `_ 80 | 81 | 82 | 83 | Reservoir Computing 84 | ------------------- 85 | 86 | 87 | - `Predicting Mackey-Glass timeseries `_ 88 | - `(Sussillo & Abbott, 2009): FORCE Learning `_ 89 | - `(Gauthier, et.
al, 2021): Next generation reservoir computing `_ 90 | 91 | 92 | 93 | Gap Junction Network 94 | -------------------- 95 | 96 | - `(Fazli & Bertram, 2022): Electrically Coupled Bursting Pituitary Cells `_ 97 | - `(Sherman & Rinzel, 1992): Gap junction leads to anti-synchronization `_ 98 | 99 | 100 | 101 | Oscillation and Synchronization 102 | ------------------------------- 103 | 104 | - `(Wang & Buzsáki, 1996): Gamma Oscillation `_ 105 | - `(Brunel & Hakim, 1999): Fast Global Oscillation `_ 106 | - `(Diesmann, et al., 1999): Synfire Chains `_ 107 | - `(Li, et al., 2017): Unified Thalamus Oscillation Model `_ 108 | - `(Susin & Destexhe, 2021): Asynchronous Network `_ 109 | - `(Susin & Destexhe, 2021): CHING Network for Generating Gamma Oscillation `_ 110 | - `(Susin & Destexhe, 2021): ING Network for Generating Gamma Oscillation `_ 111 | - `(Susin & Destexhe, 2021): PING Network for Generating Gamma Oscillation `_ 112 | 113 | 114 | 115 | Large-Scale Modeling 116 | -------------------- 117 | 118 | - `(Joglekar, et al., 2018): Inter-areal Balanced Amplification Figure 1 `_ 119 | - `(Joglekar, et al., 2018): Inter-areal Balanced Amplification Figure 2 `_ 120 | - `(Joglekar, et al., 2018): Inter-areal Balanced Amplification Figure 5 `_ 121 | - `(Joglekar, et al., 2018): Inter-areal Balanced Amplification Taichi customized operators `_ 122 | - `Simulating 1-million-neuron networks with 1GB GPU memory `_ 123 | 124 | 125 | 126 | Recurrent Neural Network 127 | ------------------------ 128 | 129 | 130 | - `(Sussillo & Abbott, 2009): FORCE Learning `_ 131 | - `Integrator RNN Model `_ 132 | - `Train RNN to Solve Parametric Working Memory `_ 133 | - `(Song, et al., 2016): Training excitatory-inhibitory recurrent network `_ 134 | - `(Masse, et al., 2019): RNN with STP for Working Memory `_ 135 | - `(Yang, 2020): Dynamical system analysis for RNN `_ 136 | - `(Bellec, et.
al, 2020): eprop for Evidence Accumulation Task `_ 137 | 138 | 139 | 140 | Working Memory Model 141 | -------------------- 142 | 143 | - `(Bouchacourt & Buschman, 2019): Flexible Working Memory Model `_ 144 | - `(Mi, et. al., 2017): STP for Working Memory Capacity `_ 145 | - `(Masse, et al., 2019): RNN with STP for Working Memory `_ 146 | 147 | 148 | 149 | Dynamics Analysis 150 | ----------------- 151 | 152 | - `[1D] Simple systems `_ 153 | - `[2D] NaK model analysis `_ 154 | - `[2D] Wilson-Cowan model `_ 155 | - `[2D] Decision Making Model with SlowPointFinder `_ 156 | - `[2D] Decision Making Model with Low-dimensional Analyzer `_ 157 | - `[3D] Hindmarsh Rose Model `_ 158 | - `Continuous-attractor Neural Network `_ 159 | - `Gap junction-coupled FitzHugh-Nagumo Model `_ 160 | - `(Yang, 2020): Dynamical system analysis for RNN `_ 161 | 162 | 163 | 164 | 165 | Classical Dynamical Systems 166 | --------------------------- 167 | 168 | - `Hénon map `_ 169 | - `Logistic map `_ 170 | - `Lorenz system `_ 171 | - `Mackey-Glass equation `_ 172 | - `Multiscroll chaotic attractor (多卷波混沌吸引子) `_ 173 | - `Rabinovich-Fabrikant equations `_ 174 | - `Fractional-order Chaos Gallery `_ 175 | 176 | 177 | 178 | 179 | 180 | Unclassified Models 181 | ------------------- 182 | 183 | - `(Brette & Guigon, 2003): Reliability of spike timing `_ 184 | 185 | 186 | 187 | 188 | 189 | Indices and tables 190 | ================== 191 | 192 | * :ref:`genindex` 193 | * :ref:`modindex` 194 | * :ref:`search` 195 | -------------------------------------------------------------------------------- /large_scale_modeling/2014_CorticalModel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import brainpy as bp 4 | import brainpy.math as bm 5 | import numpy as np 6 | 7 | bp.math.set_platform('cpu') 8 | 9 | 10 | class LIF(bp.NeuGroup): 11 | def __init__(self, size, tau_neu=10., tau_syn=0.5, tau_ref=2., 12 | V_reset=-65., V_th=-50., Cm=0.25, ): 
13 | super(LIF, self).__init__(size=size) 14 | 15 | # parameters 16 | self.tau_neu = tau_neu # membrane time constant [ms] 17 | self.tau_syn = tau_syn # Post-synaptic current time constant [ms] 18 | self.tau_ref = tau_ref # absolute refractory period [ms] 19 | self.Cm = Cm # membrane capacity [nF] 20 | self.V_reset = V_reset # reset potential [mV] 21 | self.V_th = V_th # fixed firing threshold [mV] 22 | self.Iext = 0. # constant external current [nA] 23 | 24 | # variables 25 | self.V = bm.Variable(-65. + 5.0 * bm.random.randn(self.num)) # [mV] 26 | self.I = bm.Variable(bm.zeros(self.num)) # synaptic currents [nA] 27 | self.spike = bm.Variable(bm.zeros(self.num, dtype=bool)) 28 | self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7) 29 | 30 | # function 31 | self.integral = bp.odeint(bp.JointEq([self.dV, self.dI]), method='exp_auto') 32 | 33 | def dV(self, V, t, I): 34 | return (-V + self.V_reset) / self.tau_neu + (I + self.Iext) / self.Cm 35 | 36 | def dI(self, I, t): 37 | return -I / self.tau_syn 38 | 39 | def update(self, _t, _dt): 40 | ref = (_t - self.t_last_spike) <= self.tau_ref 41 | V, I = self.integral(self.V, self.I, _t, _dt) 42 | V = bm.where(ref, self.V, V) 43 | spike = (V >= self.V_th) 44 | self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike) 45 | self.V.value = bm.where(spike, self.V_reset, V) 46 | self.spike.value = spike 47 | self.I.value = I 48 | 49 | 50 | class ExpSyn(bp.dyn.TwoEndConn): 51 | # Synapses parameters 52 | exc_delay = (1.5, 0.75) # Excitatory/Std. delay [ms] 53 | inh_delay = (0.80, 0.4) # Inhibitory/Std. delay [ms] 54 | exc_weight = (0.0878, 0.0088) # excitatory/Std. synaptic weight [nA] 55 | inh_weight_scale = -4. 
# Relative inhibitory synaptic strength 56 | 57 | def __init__(self, pre, post, prob, syn_type='e', conn_type=0): 58 | super(ExpSyn, self).__init__(pre=pre, post=post, conn=None) 59 | self.check_pre_attrs('spike') 60 | self.check_post_attrs('I') 61 | assert syn_type in ['e', 'i'] 62 | # assert conn_type in [0, 1, 2, 3] 63 | assert 0. < prob < 1. 64 | 65 | # parameters 66 | self.syn_type = syn_type 67 | self.conn_type = conn_type 68 | 69 | # connection 70 | if conn_type == 0: 71 | # number of synapses calculated with equation 3 from the article 72 | num = int(np.log(1.0 - prob) / np.log(1.0 - (1.0 / float(pre.num * post.num)))) 73 | self.pre2post = bp.conn.ij2csr(pre_ids=np.random.randint(0, pre.num, num), 74 | post_ids=np.random.randint(0, post.num, num), 75 | num_pre=pre.num) 76 | self.num = self.pre2post[0].size 77 | elif conn_type == 1: 78 | # number of synapses calculated with equation 5 from the article 79 | self.pre2post = bp.conn.FixedProb(prob)(pre.size, post.size).require('pre2post') 80 | self.num = self.pre2post[0].size 81 | elif conn_type == 2: 82 | self.num = int(prob * pre.num * post.num) 83 | self.pre_ids = bm.random.randint(0, pre.num, size=self.num, dtype=bm.uint32) 84 | self.post_ids = bm.random.randint(0, post.num, size=self.num, dtype=bm.uint32) 85 | elif conn_type in [3, 4]: 86 | self.pre2post = bp.conn.FixedProb(prob)(pre.size, post.size).require('pre2post') 87 | self.num = self.pre2post[0].size 88 | self.max_post_conn = bm.diff(self.pre2post[1]).max() 89 | else: 90 | raise ValueError 91 | 92 | # delay 93 | if syn_type == 'e': 94 | self.delay = bm.random.normal(*self.exc_delay, size=pre.num) 95 | elif syn_type == 'i': 96 | self.delay = bm.random.normal(*self.inh_delay, size=pre.num) 97 | else: 98 | raise ValueError 99 | self.delay = bm.where(self.delay < bm.get_dt(), bm.get_dt(), self.delay) 100 | 101 | # weights 102 | self.weights = bm.random.normal(*self.exc_weight, size=self.num) 103 | self.weights = bm.where(self.weights < 0, 0., 
self.weights) 104 | if syn_type == 'i': 105 | self.weights *= self.inh_weight_scale 106 | 107 | # variables 108 | self.pre_sps = bp.ConstantDelay(pre.num, self.delay, bool) 109 | 110 | def update(self, _t, _dt): 111 | self.pre_sps.push(self.pre.spike) 112 | delayed_sps = self.pre_sps.pull() 113 | if self.conn_type in [0, 1]: 114 | post_vs = bm.pre2post_event_sum(delayed_sps, self.pre2post, self.post.num, self.weights) 115 | elif self.conn_type == 2: 116 | post_vs = bm.pre2post_event_sum2(delayed_sps, self.pre_ids, self.post_ids, self.post.num, self.weights) 117 | # post_vs = bm.zeros(self.post.num) 118 | # post_vs = post_vs.value.at[self.post_ids.value].add(delayed_sps[self.pre_ids.value]) 119 | elif self.conn_type == 3: 120 | post_vs = bm.pre2post_event_sum3(delayed_sps, self.pre2post, self.post.num, self.weights, 121 | self.max_post_conn) 122 | elif self.conn_type == 4: 123 | post_vs = bm.pre2post_event_sum4(delayed_sps, self.pre2post, self.post.num, self.weights, 124 | self.max_post_conn) 125 | else: 126 | raise ValueError 127 | self.post.I += post_vs 128 | 129 | 130 | # class PoissonInput(bp.NeuGroup): 131 | # def __init__(self, post, freq=8.): 132 | # base = 20 133 | # super(PoissonInput, self).__init__(size=(post.num, base)) 134 | # 135 | # # parameters 136 | # freq = post.num * freq / base 137 | # self.prob = freq * bm.get_dt() / 1000. 138 | # self.weight = ExpSyn.exc_weight[0] 139 | # self.post = post 140 | # assert hasattr(post, 'I') 141 | # 142 | # # variables 143 | # self.rng = bm.random.RandomState() 144 | # 145 | # def update(self, _t, _dt): 146 | # self.post.I += self.weight * self.rng.random(self.size).sum(axis=1) 147 | 148 | 149 | class PoissonInput(bp.NeuGroup): 150 | def __init__(self, post, freq=8.): 151 | super(PoissonInput, self).__init__(size=(post.num,)) 152 | 153 | # parameters 154 | self.prob = freq * bm.get_dt() / 1000. 
155 | self.loc = post.num * self.prob 156 | self.scale = np.sqrt(post.num * self.prob * (1 - self.prob)) 157 | self.weight = ExpSyn.exc_weight[0] 158 | self.post = post 159 | assert hasattr(post, 'I') 160 | 161 | # variables 162 | self.rng = bm.random.RandomState() 163 | 164 | def update(self, _t, _dt): 165 | self.post.I += self.weight * self.rng.normal(self.loc, self.scale, self.num) 166 | 167 | 168 | class PoissonInput2(bp.NeuGroup): 169 | def __init__(self, pops, freq=8.): 170 | super(PoissonInput2, self).__init__(size=sum([p.num for p in pops])) 171 | 172 | # parameters 173 | self.pops = pops 174 | prob = freq * bm.get_dt() / 1000. 175 | assert (prob * self.num > 5.) and (self.num * (1 - prob) > 5) 176 | self.loc = self.num * prob 177 | self.scale = np.sqrt(self.num * prob * (1 - prob)) 178 | self.weight = ExpSyn.exc_weight[0] 179 | 180 | # variables 181 | self.rng = bm.random.RandomState() 182 | 183 | def update(self, _t, _dt): 184 | sample_weights = self.rng.normal(self.loc, self.scale, self.num) * self.weight 185 | size = 0 186 | for p in self.pops: 187 | p.I += sample_weights[size: size + p.num] 188 | size += p.num 189 | 190 | 191 | class ThalamusInput(bp.TwoEndConn): 192 | def __init__(self, pre, post, conn_prob=0.1): 193 | super(ThalamusInput, self).__init__(pre=pre, post=post, conn=bp.conn.FixedProb(conn_prob)) 194 | self.check_pre_attrs('spike') 195 | self.check_post_attrs('I') 196 | 197 | # connection and weights 198 | self.pre2post = self.conn.require('pre2post') 199 | self.syn_num = self.pre2post[0].size 200 | self.weights = bm.random.normal(*ExpSyn.exc_weight, size=self.syn_num) 201 | self.weights = bm.where(self.weights < 0., 0., self.weights) 202 | 203 | # variables 204 | self.turn_on = bm.Variable(bm.asarray([False])) 205 | 206 | def update(self, _t, _dt): 207 | def true_fn(x): 208 | post_vs = bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.weights) 209 | self.post.I += post_vs 210 | 211 | bm.make_cond(true_fn, lambda _: 
None, dyn_vars=(self.post.I, self.pre.spike))(self.turn_on[0]) 212 | 213 | 214 | class CorticalMicrocircuit(bp.dyn.Network): 215 | # Names for each layer: 216 | layer_name = ['L23e', 'L23i', 'L4e', 'L4i', 'L5e', 'L5i', 'L6e', 'L6i', 'Th'] 217 | 218 | # Population size per layer: 219 | # 2/3e 2/3i 4e 4i 5e 5i 6e 6i Th 220 | layer_num = [20683, 5834, 21915, 5479, 4850, 1065, 14395, 2948, 902] 221 | 222 | # Layer-specific background input [nA]: 223 | # 2/3e 2/3i 4e 4i 5e 5i 6e 6i 224 | layer_specific_bg = np.array([1600, 1500, 2100, 1900, 2000, 1900, 2900, 2100]) / 1000 225 | 226 | # Layer-independent background input [nA]: 227 | # 2/3e 2/3i 4e 4i 5e 5i 6e 6i 228 | layer_independent_bg = np.array([2000, 1850, 2000, 1850, 2000, 1850, 2000, 1850]) / 1000 229 | 230 | # Prob. connection table 231 | conn_table = np.array([[0.101, 0.169, 0.044, 0.082, 0.032, 0.0000, 0.008, 0.000, 0.0000], 232 | [0.135, 0.137, 0.032, 0.052, 0.075, 0.0000, 0.004, 0.000, 0.0000], 233 | [0.008, 0.006, 0.050, 0.135, 0.007, 0.0003, 0.045, 0.000, 0.0983], 234 | [0.069, 0.003, 0.079, 0.160, 0.003, 0.0000, 0.106, 0.000, 0.0619], 235 | [0.100, 0.062, 0.051, 0.006, 0.083, 0.3730, 0.020, 0.000, 0.0000], 236 | [0.055, 0.027, 0.026, 0.002, 0.060, 0.3160, 0.009, 0.000, 0.0000], 237 | [0.016, 0.007, 0.021, 0.017, 0.057, 0.0200, 0.040, 0.225, 0.0512], 238 | [0.036, 0.001, 0.003, 0.001, 0.028, 0.0080, 0.066, 0.144, 0.0196]]) 239 | 240 | def __init__(self, bg_type=0, stim_type=0, conn_type=0, poisson_freq=8., has_thalamus=False): 241 | super(CorticalMicrocircuit, self).__init__() 242 | 243 | # parameters 244 | self.bg_type = bg_type 245 | self.stim_type = stim_type 246 | self.conn_type = conn_type 247 | self.poisson_freq = poisson_freq 248 | self.has_thalamus = has_thalamus 249 | 250 | # NEURON: populations 251 | self.populations = bp.Collector() 252 | for i in range(8): 253 | l_name = self.layer_name[i] 254 | print(f'Creating {l_name} ...') 255 | self.populations[l_name] = LIF(self.layer_num[i]) 256 | 257 | 
# SYNAPSE: synapses 258 | self.synapses = bp.Collector() 259 | for c in range(8): # from 260 | for r in range(8): # to 261 | if self.conn_table[r, c] > 0.: 262 | print(f'Creating Synapses from {self.layer_name[c]} to {self.layer_name[r]} ...') 263 | syn = ExpSyn(pre=self.populations[self.layer_name[c]], 264 | post=self.populations[self.layer_name[r]], 265 | prob=self.conn_table[r, c], 266 | syn_type=self.layer_name[c][-1], 267 | conn_type=conn_type) 268 | self.synapses[f'{self.layer_name[c]}_to_{self.layer_name[r]}'] = syn 269 | # Synaptic weight from L4e to L2/3e is doubled 270 | self.synapses['L4e_to_L23e'].weights *= 2. 271 | 272 | # NEURON & SYNAPSE: poisson inputs 273 | if stim_type == 0: 274 | # print(f'Creating Poisson noise group ...') 275 | # self.populations['Poisson'] = PoissonInput2( 276 | # freq=poisson_freq, pops=[self.populations[k] for k in self.layer_name[:-1]]) 277 | for r in range(0, 8): 278 | l_name = self.layer_name[r] 279 | print(f'Creating Poisson group of {l_name} ...') 280 | N = PoissonInput(freq=poisson_freq, post=self.populations[l_name]) 281 | self.populations[f'Poisson_to_{l_name}'] = N 282 | elif stim_type == 1: 283 | bg_inputs = self._get_bg_inputs(bg_type) 284 | assert bg_inputs is not None 285 | for i, current in enumerate(bg_inputs): 286 | self.populations[self.layer_name[i]].Iext = 0.3512 * current 287 | 288 | # NEURON & SYNAPSE: thalamus inputs 289 | if has_thalamus: 290 | thalamus = bp.dyn.PoissonInput(self.layer_num[-1], freqs=15.) 
291 | self.populations[self.layer_name[-1]] = thalamus 292 | for r in range(0, 8): 293 | l_name = self.layer_name[r] 294 | print(f'Creating Thalamus projection of {l_name} ...') 295 | S = ThalamusInput(pre=thalamus, 296 | post=self.populations[l_name], 297 | conn_prob=self.conn_table[r, 8]) 298 | self.synapses[f'{self.layer_name[-1]}_to_{l_name}'] = S 299 | 300 | # finally, compose them as a network 301 | self.register_implicit_nodes(self.populations) 302 | self.register_implicit_nodes(self.synapses) 303 | 304 | def _get_bg_inputs(self, bg_type): 305 | if bg_type == 0: # layer-specific 306 | bg_layer = self.layer_specific_bg 307 | elif bg_type == 1: # layer-independent 308 | bg_layer = self.layer_independent_bg 309 | elif bg_type == 2: # layer-independent-random 310 | bg_layer = np.zeros(8) 311 | for i in range(0, 8, 2): 312 | # randomly choosing a number for the external input to an excitatory population: 313 | exc_bound = [self.layer_specific_bg[i], self.layer_independent_bg[i]] 314 | exc_input = np.random.uniform(min(exc_bound), max(exc_bound)) 315 | # randomly choosing a number for the external input to an inhibitory population: 316 | T = 0.1 if i != 6 else 0.2 317 | inh_bound = ((1 - T) / (1 + T)) * exc_input # eq. 4 from the article 318 | inh_input = np.random.uniform(inh_bound, exc_input) 319 | # array created to save the values: 320 | bg_layer[i] = int(exc_input) 321 | bg_layer[i + 1] = int(inh_input) 322 | else: 323 | bg_layer = None 324 | return bg_layer 325 | 326 | 327 | bm.random.seed() 328 | net = CorticalMicrocircuit(conn_type=1, poisson_freq=8., stim_type=1, bg_type=0) 329 | sps_monitors = [f'{n}.spike' for n in net.layer_name[:-1]] 330 | runner = bp.StructRunner(net, monitors=sps_monitors) 331 | runner.run(1000.) 
332 | 333 | spikes = np.hstack([runner.mon[name] for name in sps_monitors]) 334 | bp.visualize.raster_plot(runner.mon.ts, spikes, show=True) 335 | 336 | 337 | # bp.visualize.line_plot(runner.mon.ts, runner.mon['L4e.V'], plot_ids=[0, 1, 2], show=True) 338 | 339 | # def run1(): 340 | # fig, gs = bp.visualize.get_figure(8, 1, col_len=8, row_len=1) 341 | # for i in range(8): 342 | # fig.add_subplot(gs[i, 0]) 343 | # name = net.layer_name[i] 344 | # bp.visualize.raster_plot(runner.mon.ts, runner.mon[f'{name}.spike'], 345 | # xlabel='Time [ms]' if i == 7 else None, 346 | # ylabel=name, show=i == 7) 347 | 348 | 349 | # if __name__ == '__main__': 350 | # run1() 351 | -------------------------------------------------------------------------------- /large_scale_modeling/Joglekar_2018_data/efelenMatpython.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/large_scale_modeling/Joglekar_2018_data/efelenMatpython.mat -------------------------------------------------------------------------------- /large_scale_modeling/Joglekar_2018_data/hierValspython.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/large_scale_modeling/Joglekar_2018_data/hierValspython.mat -------------------------------------------------------------------------------- /large_scale_modeling/Joglekar_2018_data/subgraphData.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/large_scale_modeling/Joglekar_2018_data/subgraphData.mat -------------------------------------------------------------------------------- /large_scale_modeling/Joglekar_2018_data/subgraphWiring29.mat: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/large_scale_modeling/Joglekar_2018_data/subgraphWiring29.mat -------------------------------------------------------------------------------- /make.bat: --------------------------------------------------------------------------------
@ECHO OFF
REM Windows entry point for building the Sphinx documentation.
REM Usage: make.bat <builder>   e.g. "make.bat html"
REM Without an argument, the Sphinx "-M help" output is printed instead (see :help).
REM SPHINXBUILD, SPHINXOPTS and O are read from the environment and may be set
REM to override the executable or pass extra options to sphinx-build.

REM Work from this script's own directory so SOURCEDIR/BUILDDIR below resolve relative to it.
pushd %~dp0

REM Command file for Sphinx documentation

REM Default to the sphinx-build found on PATH unless SPHINXBUILD is already set.
if "%SPHINXBUILD%" == "" (
    set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

if "%1" == "" goto help

REM Probe for sphinx-build: cmd.exe sets errorlevel 9009 when a command cannot be found.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
    echo.
    echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
    echo.installed, then set the SPHINXBUILD environment variable to point
    echo.to the full path of the 'sphinx-build' executable. Alternatively you
    echo.may add the Sphinx directory to PATH.
    echo.
    echo.If you don't have Sphinx installed, grab it from
    echo.http://sphinx-doc.org/
    REM NOTE(review): this exits without the matching popd, so the caller's
    REM working directory stays changed after the error.
    exit /b 1
)

REM Sphinx "make mode": run the builder named by the first argument.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
-------------------------------------------------------------------------------- /recurrent_networks/Laje_Buonomano_2013_robust_timing_rnn.py: -------------------------------------------------------------------------------- 1 | # --- 2 | # jupyter: 3 | # jupytext: 4 | # formats: ipynb,py:percent 5 | # text_representation: 6 | # extension: .py 7 | # format_name: percent 8 | # format_version: '1.3' 9 | # jupytext_version: 1.11.5 10 | # kernelspec: 11 | # display_name: brainpy 12 | # language: python 13 | # name: brainpy 14 | # --- 15 | 16 | # %% [markdown] 17 | # # *(Laje & Buonomano, 2013)* Robust Timing in RNN 18 | 19 | # %% [markdown] 20 | # Implementation of the paper: 21 | # 22 | # - Laje, Rodrigo, and Dean V. Buonomano. "Robust timing and motor patterns by taming chaos in recurrent neural networks." Nature neuroscience 16, no. 7 (2013): 925-933. http://www.ncbi.nlm.nih.gov/pmc/articles/PMC3753043 23 | # 24 | # Thanks to the original implementation codes: https://github.com/ReScience-Archives/Vitay-2016 25 | 26 | # %% 27 | import brainpy as bp 28 | import brainpy.math as bm 29 | 30 | 31 | # %% [markdown] 32 | # ## Model Descriptions 33 | 34 | # %% [markdown] 35 | # ### Recurrent network 36 | # 37 | # The recurrent network is composed of $N=800$ neurons, receiving inputs from a variable number of input neurons $N_i$ and sparsely connected with each other.
Each neuron's firing rate $r_i(t)$ applies the `tanh` transfer function to an internal variable $x_i(t)$ which follows a first-order linear ordinary differential equation (ODE): 38 | # 39 | # $$ 40 | # \tau \cdot \frac{d x_i(t)}{dt} = - x_i(t) + \sum_{j=1}^{N_i} W_{ij}^{in} \cdot y_j(t) + \sum_{j=1}^{N} W_{ij}^{rec} \cdot r_j(t) + I^{noise}_i(t) 41 | # $$ 42 | # 43 | # $$ 44 | # r_i(t) = \tanh(x_i(t)) 45 | # $$ 46 | # 47 | # The weights of the input matrix $W^{in}$ are taken from the normal distribution, with mean 0 and variance 1, and multiply the rates of the input neurons $y_j(t)$. The variance of these weights does not depend on the number of inputs, as only one input neuron is activated at a time. The recurrent connectivity matrix $W^{rec}$ is sparse with a connection probability $pc = 0.1$ (i.e. about 64000 non-zero elements) and existing weights are taken randomly from a normal distribution with mean 0 and standard deviation $g/\sqrt{pc \cdot N}$, where $g$ is a scaling factor. It is a well-known result for sparse recurrent networks that for high values of $g$, the network dynamics become chaotic. $I^{noise}_i(t)$ is an additive noise, taken randomly at each time step and for each neuron from a normal distribution with mean 0 and standard deviation $I_0$. $I_0$ is chosen very small in the experiments reproduced here ($I_0 = 0.001$) but is enough to highlight the chaotic behavior of the recurrent neurons (non-reproducibility between two trials). 48 | 49 | # %% [markdown] 50 | # ### Read-out neurons 51 | # 52 | # The read-out neurons simply sum the activity of the recurrent neurons using a matrix $W^{out}$: 53 | # 54 | # $$ 55 | # z_i(t) = \sum_{j=1}^{N} W_{ij}^{out} \cdot r_j(t) 56 | # $$ 57 | # 58 | # The read-out matrix is initialized randomly from the normal distribution with mean 0 and standard deviation $1/\sqrt{N}$.
59 | 60 | # %% [markdown] 61 | # ### Learning rule 62 | # 63 | # The particularity of the reservoir network proposed by *(Laje & Buonomano, Nature Neuroscience, 2013)* is that both the recurrent weights $W^{rec}$ and the read-out weights are trained in a supervised manner. More precisely, in this implementation, each existing recurrent synapse is made plastic with probability 0.6, so only about 60% of the recurrent weights are trained; the other 40% keep the same weights throughout the simulation. 64 | # 65 | # Learning is done using the recursive least squares (RLS) algorithm *(Haykin, 2008)*. It is a supervised error-driven learning rule, i.e. the weight changes depend on the error made by each neuron: the difference between the firing rate of a neuron $r_i(t)$ and a desired value $R_i(t)$. 66 | # 67 | # $$ 68 | # e_i(t) = r_i(t) - R_i(t) 69 | # $$ 70 | # 71 | # For the recurrent neurons, the desired value is the recorded rate of that neuron during an initial trial (to enforce the reproducibility of the trajectory). For the read-out neurons, it is a function which we want the network to reproduce (e.g. handwriting as in Fig. 2). 72 | # 73 | # Unlike the delta learning rule, which modifies weights proportionally to the error and to the direct input to a synapse ($\Delta w_{ij} = - \eta \cdot e_i \cdot r_j$), the RLS learning uses a running estimate of the inverse correlation matrix of the inputs to each neuron: 74 | # 75 | # $$ 76 | # \Delta w_{ij} = - e_i \sum_{k \in \mathcal{B}(i)} P^i_{jk} \cdot r_k 77 | # $$ 78 | # 79 | # Each neuron $i$ therefore stores a square matrix $P^i$, whose size depends on the number of weights arriving at the neuron. Read-out neurons receive synapses from all $N$ recurrent neurons, so the $P$ matrix is $N*N$. Recurrent units have a sparse random connectivity ($pc = 0.1$), so each recurrent neuron stores only a $80*80$ matrix on average. In the previous equation, $\mathcal{B}(i)$ represents those existing weights.
80 | # 81 | # The inverse correlation matrix $P$ is updated at every learning step (every second simulation step inside the training window in this implementation) with the following rule: 82 | # 83 | # $$ 84 | # \Delta P^i_{jk} = - \frac{\sum_{m \in \mathcal{B}(i)} \sum_{n \in \mathcal{B}(i)} P^i_{jm} \cdot r_m \cdot r_n \cdot P^i_{nk} }{ 1 + \sum_{m \in \mathcal{B}(i)} \sum_{n \in \mathcal{B}(i)} r_m \cdot P^i_{mn} \cdot r_n} 85 | # $$ 86 | # 87 | # Each matrix $P^i$ is initialized to the identity matrix scaled by a factor $1/\delta$, where $\delta$ is 1 in the current implementation and can be used to implicitly modify the learning rate *(Sussillo & Abbott, Neuron, 2009)*. 88 | 89 | # %% [markdown] 90 | # **Matrix/Vector mode** 91 | # 92 | # The above algorithms can be written in matrix/vector form. For the recurrent units, each matrix $P^i$ has a different size ($80*80$ on average), so we will still need to iterate over all post-synaptic neurons. If we denote by $\mathbf{W}$ the vector of weights coming to a neuron (80 on average), $\mathbf{r}$ the corresponding vector of firing rates (also 80), $e$ the error of that neuron (a scalar) and $\mathbf{P}$ the inverse correlation matrix (80*80), the update rules become: 93 | # $$ 94 | # \Delta \mathbf{W} = - e \cdot \mathbf{P} \cdot \mathbf{r} 95 | # $$ 96 | # 97 | # $$ 98 | # \Delta \mathbf{P} = - \frac{(\mathbf{P} \cdot \mathbf{r}) \cdot (\mathbf{P} \cdot \mathbf{r})^T}{1 + \mathbf{r}^T \cdot \mathbf{P} \cdot \mathbf{r}} 99 | # $$ 100 | # 101 | # In the original Matlab code, one notices that the weight update $\Delta \mathbf{W}$ is also normalized by the denominator of the update rule for $\mathbf{P}$: 102 | # 103 | # $$ 104 | # \Delta \mathbf{W} = - e \cdot \frac{\mathbf{P} \cdot \mathbf{r}}{1 + \mathbf{r}^T \cdot \mathbf{P} \cdot \mathbf{r}} 105 | # $$ 106 | # 107 | # Removing this normalization from the learning rule impairs learning completely, so we kept this variant of the RLS rule in our implementation.
108 | 109 | # %% [markdown] 110 | # ### Training procedure 111 | # 112 | # The training procedure is split into different trials, which differ from one experiment to another (Fig. 1, Fig. 2 and Fig. 3). Each trial begins with a relaxation period of `t_offset` = 200 ms, followed by a brief input impulse of duration 50 ms and variable amplitude. This impulse has the effect of bringing all recurrent neurons into a deterministic state (due to the `tanh` transfer function, the rates are saturated at either +1 or -1). This impulse is followed by a training window of variable length (in the seconds range) and finally another relaxation period. In Fig. 1 and Fig. 2, an additional impulse (duration 10 ms, smaller amplitude) can be given a certain delay after the initial impulse to test the ability of the network to recover its acquired trajectory after learning. 113 | # 114 | # In all experiments, the first trial is used to acquire an innate trajectory for the recurrent neurons in the absence of noise ($I_0$ is set to 0). The firing rate of all recurrent neurons over the training window is simply recorded and stored in an array without applying the learning rules. This innate trajectory for each recurrent neuron is used in the following trials as the target for the RLS learning rule, this time in the presence of noise ($I_0 = 0.001$). The RLS learning rule itself is only applied to the recurrent neurons during the training window, not during the impulse or the relaxation periods. Such a learning trial is repeated 20 or 30 times depending on the experiments. Once the recurrent weights have converged and the recurrent neurons are able to reproduce the innate trajectory, the read-out weights are trained using a custom desired function as target (10 trials). 115 | 116 | # %% [markdown] 117 | # ### References 118 | # 119 | # - Laje, Rodrigo, and Dean V. Buonomano. "Robust timing and motor patterns by taming chaos in recurrent neural networks." Nature neuroscience 16, no. 
7 (2013): 925-933. 120 | # - Haykin, Simon S. Adaptive filter theory. Pearson Education India, 2008. 121 | # - Sussillo, David, and Larry F. Abbott. "Generating coherent patterns of activity from chaotic neural networks." Neuron 63, no. 4 (2009): 544-557. 122 | 123 | # %% [markdown] 124 | # ## Implementation 125 | 126 | # %% 127 | class RNN(bp.Base): 128 | target_backend = 'numpy' 129 | 130 | def __init__(self, num_input=2, num_rec=800, num_output=1, tau=10.0, 131 | g=1.5, pc=0.1, noise=0.001, delta=1.0, plastic_prob=0.6, dt=1.): 132 | super(RNN, self).__init__() 133 | 134 | # Copy the parameters 135 | self.num_input = num_input # number of input neurons 136 | self.num_rec = num_rec # number of recurrent neurons 137 | self.num_output = num_output # number of read-out neurons 138 | self.tau = tau # time constant of the neurons 139 | self.g = g # synaptic strength scaling 140 | self.pc = pc # connection probability 141 | self.noise = noise # noise variance 142 | self.delta = delta # initial value of the P matrix 143 | self.plastic_prob = plastic_prob # percentage of neurons receiving plastic synapses 144 | self.dt = dt # numerical precision 145 | 146 | # Initializes the network including the weight matrices. 
147 | # --- 148 | 149 | # Recurrent population 150 | self.x = bm.Variable(bm.random.uniform(-1.0, 1.0, (self.num_rec,))) 151 | self.r = bm.Variable(bm.tanh(self.x)) 152 | 153 | # Read-out population 154 | self.z = bm.Variable(bm.zeros(self.num_output, dtype=bm.float_)) 155 | 156 | # Weights between the input and recurrent units 157 | self.W_in = bm.Variable(bm.random.randn(self.num_input, self.num_rec)) 158 | 159 | # Weights between the recurrent units 160 | self.W_rec = bm.Variable(bm.random.randn(self.num_rec, self.num_rec) * self.g / bm.sqrt(self.pc * self.num_rec)) 161 | 162 | # The connection pattern is sparse with p=0.1 163 | self.conn_mask = bm.random.random((self.num_rec, self.num_rec)) < self.pc 164 | diag = bm.arange(self.num_rec) 165 | self.conn_mask[diag, diag] = False 166 | self.W_rec *= bm.asarray(self.conn_mask, dtype=bm.float_) 167 | 168 | # Store the pre-synaptic neurons to each plastic neuron 169 | train_mask = bm.random.random((self.num_rec, self.num_rec)) < plastic_prob 170 | self.train_mask = bm.logical_and(train_mask, self.conn_mask) 171 | 172 | # Store the pre-synaptic neurons to each plastic neuron 173 | self.W_plastic = bm.Variable([bm.array(bm.nonzero(self.train_mask[i])[0]) for i in range(self.num_rec)]) 174 | 175 | # Inverse correlation matrix of inputs for learning recurrent weights 176 | self.P_rec = bm.Variable([bm.identity(len(self.W_plastic[i])) / self.delta for i in range(self.num_rec)]) 177 | 178 | # Output weights 179 | self.W_out = bm.Variable(bm.random.randn(self.num_output, self.num_rec) / bm.sqrt(self.num_rec)) 180 | 181 | # Inverse correlation matrix of inputs for learning readout weights 182 | P_out = bm.expand_dims(bm.identity(self.num_rec) / self.delta, axis=0) 183 | self.P_out = bm.Variable(bm.repeat(P_out, self.num_output, axis=0)) 184 | 185 | # loss 186 | self.loss = bm.Variable(bm.array([0.])) 187 | 188 | def simulate(self, stimulus, noise=True, target_traj=None, 189 | learn_start=-1, learn_stop=-1, 
learn_readout=False): 190 | """Simulates the recurrent network for the given duration, with or without plasticity. 191 | 192 | Parameters 193 | ---------- 194 | 195 | stimulus : bm.ndarray 196 | The external inputs. 197 | noise : bool 198 | if noise should be added to the recurrent units (default: True) 199 | target_traj : bm,ndarray 200 | During learning, defines which output target_traj function should be learned (default: no learning) 201 | learn_start : int, float 202 | Time when learning should start. 203 | learn_stop : int, float 204 | Time when learning should stop. 205 | learn_readout : bool 206 | Defines whether the recurrent (False) or readout (True) weights should be learned. 207 | """ 208 | 209 | length, _ = stimulus.shape 210 | 211 | # Arrays for recording 212 | record_r = bm.zeros((length, self.num_rec), dtype=bm.float_) 213 | record_z = bm.zeros((length, self.num_output), dtype=bm.float_) 214 | 215 | # Reset the recurrent population 216 | self.x[:] = bm.random.uniform(-1.0, 1.0, (self.num_rec,)) 217 | self.r[:] = bm.tanh(self.x) 218 | 219 | # Reset loss term 220 | self.loss[:] = 0.0 221 | 222 | # Simulate for the desired duration 223 | for i in range(length): 224 | # Update the neurons' firing rates 225 | self.update(stimulus[i, :], noise) 226 | 227 | # Recording 228 | record_r[i] = self.r 229 | record_z[i] = self.z 230 | 231 | # Learning 232 | if target_traj is not None and learn_stop > i * self.dt >= learn_start and i % 2 == 0: 233 | if not learn_readout: 234 | self.train_recurrent(target_traj[i]) 235 | else: 236 | self.train_readout(target_traj[i]) 237 | 238 | return record_r, record_z, self.loss 239 | 240 | def update(self, stimulus, noise=True): 241 | """Updates neural variables for a single simulation step.""" 242 | dx = -self.x + bm.dot(stimulus, self.W_in) + bm.dot(self.W_rec, self.r) 243 | if noise: 244 | dx += self.noise * bm.random.randn(self.num_rec) 245 | self.x += dx / self.tau * self.dt 246 | self.r[:] = bm.tanh(self.x) 247 | self.z[:] 
= bm.dot(self.W_out, self.r) 248 | 249 | def train_recurrent(self, target): 250 | """Apply the RLS learning rule to the recurrent weights.""" 251 | 252 | # Compute the error of the recurrent neurons 253 | error = self.r - target # output_traj : (num_rec, ) 254 | self.loss += bm.mean(error ** 2) 255 | 256 | # Apply the FORCE learning rule to the recurrent weights 257 | for i in range(self.num_rec): # for each plastic post neuron 258 | # Get the rates from the plastic synapses only 259 | r_plastic = bm.expand_dims(self.r[self.W_plastic[i]], axis=1) 260 | # Multiply with the inverse correlation matrix P*R 261 | PxR = bm.dot(self.P_rec[i], r_plastic) 262 | # Normalization term 1 + R'*P*R 263 | RxPxR = (1. + bm.dot(r_plastic.T, PxR)) 264 | # Update the inverse correlation matrix P <- P - ((P*R)*(P*R)')/(1+R'*P*R) 265 | self.P_rec[i][:] -= bm.dot(PxR, PxR.T) / RxPxR 266 | # Learning rule W <- W - e * (P*R)/(1+R'*P*R) 267 | for j, idx in enumerate(self.W_plastic[i]): 268 | self.W_rec[i, idx] -= error[i] * (PxR[j, 0] / RxPxR[0, 0]) 269 | 270 | def train_readout(self, target): 271 | """Apply the RLS learning rule to the readout weights.""" 272 | 273 | # Compute the error of the output neurons 274 | error = self.z - target # output_traj : (O, ) 275 | # loss 276 | self.loss += bm.mean(error ** 2) 277 | # Apply the FORCE learning rule to the readout weights 278 | for i in range(self.num_output): # for each readout neuron 279 | # Multiply the rates with the inverse correlation matrix P*R 280 | r = bm.expand_dims(self.r, axis=1) 281 | PxR = bm.dot(self.P_out[i], r) 282 | # Normalization term 1 + R'*P*R 283 | RxPxR = (1. 
+ bm.dot(r.T, PxR)) 284 | # Update the inverse correlation matrix P <- P - ((P*R)*(P*R)')/(1+R'*P*R) 285 | self.P_out[i] -= bm.dot(PxR, PxR.T) / RxPxR 286 | # Learning rule W <- W - e * (P*R)/(1+R'*P*R) 287 | self.W_out[i] -= error[i] * (PxR / RxPxR)[:, 0] 288 | 289 | def init(self): 290 | # Recurrent population 291 | self.x[:] = bm.random.uniform(-1.0, 1.0, (self.num_rec,)) 292 | self.r[:] = bm.tanh(self.x) 293 | 294 | # Read-out population 295 | self.z[:] = bm.zeros((self.num_output,)) 296 | 297 | # Weights between the input and recurrent units 298 | self.W_in[:] = bm.random.randn(self.num_input, self.num_rec) 299 | 300 | # Weights between the recurrent units 301 | self.W_rec[:] = bm.random.randn(self.num_rec, self.num_rec) * self.g / bm.sqrt(self.pc * self.num_rec) 302 | self.W_rec *= self.conn_mask 303 | 304 | # Inverse correlation matrix of inputs for learning recurrent weights 305 | for i in range(self.num_rec): 306 | self.P_rec[i][:] = bm.identity(len(self.W_plastic[i])) / self.delta 307 | 308 | # Output weights 309 | self.W_out[:] = bm.random.randn(self.num_output, self.num_rec) / bm.sqrt(self.num_rec) 310 | 311 | # Inverse correlation matrix of inputs for learning readout weights 312 | P_out = bm.expand_dims(bm.identity(self.num_rec) / self.delta, axis=0) 313 | self.P_out[:] = bm.repeat(P_out, self.num_output, axis=0) 314 | 315 | # loss 316 | self.loss[:] = 0. 317 | 318 | 319 | # %% [markdown] 320 | # ## Results 321 | 322 | # %% 323 | from Laje_Buonomano_2013_simulation import * 324 | 325 | # %% [markdown] 326 | # The ``simulation.py`` can be downloaded from [here](./Laje_Buonomano_2013_simulation.ipynb). 327 | 328 | # %% [markdown] 329 | # ### Fig. 
1: an initially chaotic network can become deterministic after training 330 | 331 | # %% 332 | net = RNN( 333 | num_input=2, # Number of inputs 334 | num_rec=800, # Number of recurrent neurons 335 | num_output=1, # Number of read-out neurons 336 | tau=10.0, # Time constant of the neurons 337 | g=1.8, # Synaptic strength scaling 338 | pc=0.1, # Connection probability 339 | noise=0.001, # Noise variance 340 | delta=1.0, # Initial diagonal value of the P matrix 341 | plastic_prob=0.6, # Percentage of neurons receiving plastic synapses 342 | dt=1., 343 | ) 344 | net.update = bp.math.jit(net.update) 345 | net.train_recurrent = bp.math.jit(net.train_recurrent) 346 | 347 | fig1(net) 348 | 349 | # %% [markdown] 350 | # ### Fig. 2: read-out neurons robustly learn trajectories in the 2D space 351 | 352 | # %% 353 | net = RNN( 354 | num_input=4, # Number of inputs 355 | num_rec=800, # Number of recurrent neurons 356 | num_output=2, # Number of read-out neurons 357 | tau=10.0, # Time constant of the neurons 358 | g=1.5, # Synaptic strength scaling 359 | pc=0.1, # Connection probability 360 | noise=0.001, # Noise variance 361 | delta=1.0, # Initial value of the P matrix 362 | plastic_prob=0.6, # Percentage of neurons receiving plastic synapses 363 | dt=1. 364 | ) 365 | net.update = bp.math.jit(net.update) 366 | net.train_recurrent = bp.math.jit(net.train_recurrent) 367 | net.train_readout = bp.math.jit(net.train_readout) 368 | 369 | fig2(net) 370 | 371 | # %% [markdown] 372 | # ### Fig. 
3: the timing capacity of the recurrent network 373 | 374 | # %% 375 | nets = [] 376 | for _ in range(10): 377 | net = RNN( 378 | num_input=2, # Number of inputs 379 | num_rec=800, # Number of recurrent neurons 380 | num_output=1, # Number of read-out neurons 381 | tau=10.0, # Time constant of the neurons 382 | g=1.5, # Synaptic strength scaling 383 | pc=0.1, # Connection probability 384 | noise=0.001, # Noise variance 385 | delta=1.0, # Initial diagonal value of the P matrix 386 | plastic_prob=0.6, # Percentage of neurons receiving plastic synapses 387 | dt=1. 388 | ) 389 | net.update = bp.math.jit(net.update) 390 | net.train_recurrent = bp.math.jit(net.train_recurrent) 391 | nets.append(net) 392 | 393 | fig3(nets, verbose=False) 394 | -------------------------------------------------------------------------------- /recurrent_networks/Laje_Buonomano_2013_simulation.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import time 3 | 4 | # %% 5 | import brainpy.math as bm 6 | import matplotlib.patches as patches 7 | import matplotlib.pyplot as plt 8 | import scipy.stats 9 | from mpl_toolkits.axes_grid1 import make_axes_locatable 10 | from scipy.io import loadmat 11 | 12 | 13 | # %% 14 | __all__ = [ 15 | 'fig1', 16 | 'fig2', 17 | 'fig3', 18 | ] 19 | 20 | 21 | # %% 22 | dt = 1. 
23 | time2len = lambda t: int(t / dt) 24 | 25 | 26 | # %% 27 | def fig1(net, verbose=True): 28 | # Parameters 29 | # ---------- 30 | num_rec_train = 30 # Number of learning trials for the recurrent weights 31 | num_readout_train = 10 # Number of learning trials for the readout weights 32 | num_perturbation = 5 # Number of perturbation trials 33 | 34 | stimulus_amplitude = 5.0 # Amplitude of the input pulse 35 | t_offset = time2len(200) # Time to wait before the stimulation 36 | d_stim = time2len(50) # Duration of the stimulation 37 | d_trajectory = time2len(2000 + 150) # Duration of the desired target_traj 38 | t_relax = time2len(550) # Duration to relax after the target_traj 39 | trial_duration = t_offset + d_stim + d_trajectory + t_relax # Total duration of a trial 40 | times = bm.arange(0, trial_duration, dt) 41 | 42 | perturbation_amplitude = 0.5 # Amplitude of the perturbation pulse 43 | t_perturb = time2len(500) # Offset for the perturbation 44 | d_perturb = time2len(10) # Duration of the perturbation 45 | 46 | target_baseline = 0.2 # Baseline of the output_traj function 47 | target_amplitude = 1. # Maximal value of the output_traj function 48 | target_width = 30. 
# Width of the Gaussian 49 | target_time = time2len(d_trajectory - 150) # Peak time within the learning interval 50 | 51 | # Input definitions 52 | # ----------------- 53 | 54 | # Impulse after 200 ms 55 | impulse = bm.zeros((trial_duration, net.num_input)) 56 | impulse[t_offset:t_offset + d_stim, 0] = stimulus_amplitude 57 | 58 | # Perturbation during the trial 59 | perturbation = bm.zeros((trial_duration, net.num_input)) 60 | perturbation[t_offset: t_offset + d_stim, 0] = stimulus_amplitude 61 | perturbation[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 1] = perturbation_amplitude 62 | 63 | # Target output for learning the readout weights 64 | output_traj = bm.zeros((trial_duration, net.num_output)) 65 | output_traj[:, 0] = target_baseline + (target_amplitude - target_baseline) * \ 66 | bm.exp(-(t_offset + d_stim + target_time - times) ** 2 / target_width ** 2) 67 | 68 | # Main procedure 69 | # -------------- 70 | tstart = time.time() 71 | 72 | # Initial trial to determine the innate target_traj 73 | if verbose: print('Initial trial to determine a target_traj (without noise)') 74 | initial_traj, initial_output, _ = net.simulate(stimulus=impulse, noise=False) 75 | 76 | # Pre-training test trial 77 | if verbose: print('Pre-training test trial') 78 | pretrain_traj, pretrain_output, _ = net.simulate(stimulus=impulse, noise=True) 79 | 80 | # Perturbation trial 81 | if verbose: print(num_perturbation, 'perturbation trials') 82 | perturbation_initial = [] 83 | for i in range(num_perturbation): 84 | _, perturbation_output, _ = net.simulate(stimulus=perturbation, noise=True) 85 | perturbation_initial.append(perturbation_output) 86 | 87 | # 20 trials of learning for the recurrent weights 88 | for i in range(num_rec_train): 89 | t0 = time.time() 90 | if verbose: print(f'Learning trial recurrent {i + 1} loss: ', end='') 91 | _, _, loss = net.simulate(stimulus=impulse, 92 | target_traj=initial_traj, 93 | learn_start=t_offset + d_stim, 94 | learn_stop=t_offset + 
d_stim + d_trajectory) 95 | if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s') 96 | 97 | # 10 trials of learning for the readout weights 98 | for i in range(num_readout_train): 99 | t0 = time.time() 100 | if verbose: print(f'Learning trial readout {i + 1} loss: ', end='') 101 | _, _, loss = net.simulate(stimulus=impulse, 102 | target_traj=output_traj, 103 | learn_start=t_offset + d_stim, 104 | learn_stop=t_offset + d_stim + d_trajectory, 105 | learn_readout=True) 106 | if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s') 107 | 108 | # Test trial 109 | if verbose: print('2 test trials') 110 | reproductions = [] 111 | final_outputs = [] 112 | for i in range(2): 113 | reproduction, final_output, _ = net.simulate(stimulus=impulse, noise=True) 114 | reproductions.append(reproduction) 115 | final_outputs.append(final_output) 116 | 117 | # Perturbation trial 118 | if verbose: print(num_perturbation, 'perturbation trials') 119 | perturbation_final = [] 120 | for i in range(num_perturbation): 121 | _, perturbation_output, _ = net.simulate(stimulus=perturbation) 122 | perturbation_final.append(perturbation_output) 123 | 124 | if verbose: print('Simulation done in', time.time() - tstart, 'seconds.') 125 | 126 | # Visualization 127 | # ------------- 128 | 129 | plt.figure(figsize=(8, 12)) 130 | 131 | # innate trajectory 132 | ax = plt.subplot2grid((4, 2), (0, 0), colspan=2) 133 | im = ax.imshow(initial_traj[:, :100].T, aspect='auto', origin='lower') 134 | cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05) 135 | plt.colorbar(im, cax=cax) 136 | ymin, ymax = ax.get_ylim() 137 | ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.2)) 138 | ax.set_title('Innate Trajectory') 139 | ax.set_xlabel('Time (ms)') 140 | ax.set_ylabel('Recurrent units') 141 | 142 | # pre-training results 143 | ax = plt.subplot2grid((4, 2), (1, 0)) 144 | for i in range(3): 145 | 
ax.plot(times, initial_traj[:, i] + i * 2 + 1, 'b') 146 | ax.plot(times, pretrain_traj[:, i] + i * 2 + 1, 'r') 147 | ax.set_yticks([i * 2 + 1 for i in range(3)]) 148 | ax.set_yticklabels([0, 0, 0]) 149 | ymin, ymax = ax.get_ylim() 150 | ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7)) 151 | ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1)) 152 | ax.set_title('Pre-training') 153 | ax.set_ylabel('Firing rate $r$ [Hz]') 154 | ax.set_xlim(times[0], times[-1]) 155 | 156 | ax = plt.subplot2grid((4, 2), (2, 0)) 157 | ax.plot(times, initial_output[:, 0], 'b') 158 | ax.plot(times, pretrain_output[:, 0], 'r') 159 | ax.axhline(output_traj[0, 0], c='k') 160 | ax.set_yticks([-2, -1, 0, 1, 2]) 161 | ax.set_ylim((-2, 2)) 162 | ymin, ymax = ax.get_ylim() 163 | ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7)) 164 | ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1)) 165 | ax.set_ylabel('Output (test)') 166 | ax.set_xlim(times[0], times[-1]) 167 | 168 | ax = plt.subplot2grid((4, 2), (3, 0)) 169 | for i in range(num_perturbation): 170 | ax.plot(times, perturbation_initial[i][:, 0]) 171 | ax.axhline(output_traj[0, 0], c='k') 172 | ax.set_yticks([-2, -1, 0, 1, 2]) 173 | ax.set_ylim((-2, 2)) 174 | ymin, ymax = ax.get_ylim() 175 | ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7)) 176 | ax.add_patch(patches.Rectangle((t_offset + t_perturb, ymin), d_perturb, ymax - ymin, color='gray', alpha=0.7)) 177 | ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1)) 178 | ax.set_xlabel('Time (ms)') 179 | ax.set_ylabel('Output (perturbed)') 180 | ax.set_xlim(times[0], times[-1]) 181 | 182 | # post-training results 183 | 184 | ax = plt.subplot2grid((4, 2), (1, 1)) 185 | for i in range(3): 
186 | ax.plot(times, reproductions[0][:, i] + i * 2 + 1, 'b') 187 | ax.plot(times, reproductions[1][:, i] + i * 2 + 1, 'r') 188 | ax.set_yticks([i * 2 + 1 for i in range(3)]) 189 | ax.set_yticklabels([0, 0, 0]) 190 | ymin, ymax = ax.get_ylim() 191 | ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7)) 192 | ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1)) 193 | ax.set_title('Post-training') 194 | ax.set_xlim(times[0], times[-1]) 195 | 196 | ax = plt.subplot2grid((4, 2), (2, 1)) 197 | ax.plot(times, final_outputs[0][:, 0], 'b') 198 | ax.plot(times, final_outputs[1][:, 0], 'r') 199 | ax.axhline(output_traj[0, 0], c='k') 200 | ax.set_yticks([-2, -1, 0, 1, 2]) 201 | ax.set_ylim((-2, 2)) 202 | ymin, ymax = ax.get_ylim() 203 | ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7)) 204 | ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1)) 205 | ax.set_xlim(times[0], times[-1]) 206 | 207 | ax = plt.subplot2grid((4, 2), (3, 1)) 208 | for i in range(num_perturbation): 209 | ax.plot(times, perturbation_final[i][:, 0]) 210 | ax.axhline(output_traj[0, 0], c='k') 211 | ax.set_yticks([-2, -1, 0, 1, 2]) 212 | ax.set_ylim((-2, 2)) 213 | ymin, ymax = ax.get_ylim() 214 | ax.add_patch(patches.Rectangle((t_offset, ymin), d_stim, ymax - ymin, color='gray', alpha=0.7)) 215 | ax.add_patch(patches.Rectangle((t_offset + t_perturb, ymin), d_perturb, ymax - ymin, color='gray', alpha=0.7)) 216 | ax.add_patch(patches.Rectangle((t_offset + d_stim, ymin), d_trajectory, ymax - ymin, color='gray', alpha=0.1)) 217 | ax.set_xlabel('Time (ms)') 218 | ax.set_xlim(times[0], times[-1]) 219 | 220 | plt.tight_layout() 221 | plt.show() 222 | 223 | 224 | # %% 225 | def fig2(net, verbose=True): 226 | # Parameters 227 | # ---------- 228 | 229 | num_rec_train = 30 # Number of learning trials for the recurrent 
weights 230 | num_readout_train = 10 # Number of learning trials for the readout weights 231 | num_test = 5 # Number of test trials 232 | num_perturb = 5 # Number of perturbation trials 233 | 234 | stimulus_amplitude = 2.0 # Amplitude of the input pulse 235 | t_offset = time2len(200) # Time to wait before the stimulation 236 | d_stim = time2len(50) # Duration of the stimulation 237 | t_relax = time2len(150) # Duration to relax after the target_traj 238 | 239 | perturbation_amplitude = 0.2 # Amplitude of the perturbation pulse 240 | t_perturb = time2len(300) # Offset for the perturbation 241 | d_perturb = time2len(10) # Duration of the perturbation 242 | 243 | # Input definitions 244 | # ----------------- 245 | 246 | # Retrieve the targets and reformat them 247 | targets = loadmat('data/DAC_handwriting_output_targets.mat') 248 | chaos = targets['chaos'] 249 | neuron = targets['neuron'] 250 | 251 | # Durations 252 | _, d_chaos = chaos.shape 253 | _, d_neuron = neuron.shape 254 | 255 | # Impulses 256 | impulse_chaos = bm.zeros((t_offset + d_stim + d_chaos + t_relax, net.num_input)) 257 | impulse_chaos[t_offset:t_offset + d_stim, 0] = stimulus_amplitude 258 | impulse_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_input)) 259 | impulse_neuron[t_offset:t_offset + d_stim, 2] = stimulus_amplitude 260 | 261 | # Perturbation 262 | perturbation_chaos = bm.zeros((t_offset + d_stim + d_chaos + t_relax, net.num_input)) 263 | perturbation_chaos[t_offset:t_offset + d_stim, 0] = stimulus_amplitude 264 | perturbation_chaos[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 1] = perturbation_amplitude 265 | perturbation_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_input)) 266 | perturbation_neuron[t_offset:t_offset + d_stim, 2] = stimulus_amplitude 267 | perturbation_neuron[t_offset + t_perturb: t_offset + t_perturb + d_perturb, 3] = perturbation_amplitude 268 | 269 | # Targets 270 | target_chaos = bm.zeros((t_offset + d_stim + d_chaos + 
t_relax, net.num_output)) 271 | target_chaos[t_offset + d_stim: t_offset + d_stim + d_chaos, :] = chaos.T 272 | target_neuron = bm.zeros((t_offset + d_stim + d_neuron + t_relax, net.num_output)) 273 | target_neuron[t_offset + d_stim: t_offset + d_stim + d_neuron, :] = neuron.T 274 | 275 | # Main procedure 276 | # -------------- 277 | tstart = time.time() 278 | 279 | # Initial trial to determine the innate target_traj 280 | if verbose: print('Initial chaos trial') 281 | trajectory_chaos, initial_chaos_output, _ = net.simulate(stimulus=impulse_chaos, noise=False) 282 | if verbose: print('Initial neuron trial') 283 | trajectory_neuron, initial_neuron_output, _ = net.simulate(stimulus=impulse_neuron, noise=False) 284 | 285 | # learning for the recurrent weights 286 | for i in range(num_rec_train): 287 | t0 = time.time() 288 | if verbose: print(f'Learning recurrent {i + 1} "chaos" loss: ', end='') 289 | _, _, loss = net.simulate(stimulus=impulse_chaos, 290 | target_traj=trajectory_chaos, 291 | learn_start=t_offset + d_stim, 292 | learn_stop=t_offset + d_stim + d_chaos) 293 | if verbose: print(f'{(2 * loss[0] / d_chaos):.5f} used {(time.time() - t0):5f} s, ', end='') 294 | 295 | t0 = time.time() 296 | _, _, loss = net.simulate(stimulus=impulse_neuron, 297 | target_traj=trajectory_neuron, 298 | learn_start=t_offset + d_stim, 299 | learn_stop=t_offset + d_stim + d_neuron) 300 | if verbose: print(f'"neuron" loss: {(2 * loss[0] / d_chaos):5f} used {(time.time() - t0):5f} s') 301 | 302 | # learning for the readout weights 303 | for i in range(num_readout_train): 304 | t0 = time.time() 305 | if verbose: print(f'Learning readout {i + 1} "chaos" loss: ', end='') 306 | _, _, loss = net.simulate(stimulus=impulse_chaos, 307 | target_traj=target_chaos, 308 | learn_start=t_offset + d_stim, 309 | learn_stop=t_offset + d_stim + d_chaos, 310 | learn_readout=True) 311 | if verbose: print(f'{(2 * loss[0] / d_chaos):.5f} used {(time.time() - t0):5f} s, ', end='') 312 | 313 | t0 = 
time.time() 314 | _, _, loss = net.simulate(stimulus=impulse_neuron, 315 | target_traj=target_neuron, 316 | learn_start=t_offset + d_stim, 317 | learn_stop=t_offset + d_stim + d_neuron, 318 | learn_readout=True) 319 | if verbose: print(f'"neuron" loss: {(2 * loss[0] / d_chaos):5f} used {(time.time() - t0):5f} s') 320 | 321 | # Test trials 322 | final_output_chaos = [] 323 | final_output_neuron = [] 324 | for _ in range(num_test): 325 | if verbose: print('Test chaos trial') 326 | _, o, _ = net.simulate(stimulus=impulse_chaos) 327 | final_output_chaos.append(o) 328 | if verbose: print('Test neuron trial') 329 | _, o, _ = net.simulate(stimulus=impulse_neuron) 330 | final_output_neuron.append(o) 331 | 332 | # Perturbation trials 333 | perturbation_output_chaos = [] 334 | perturbation_output_neuron = [] 335 | for _ in range(num_perturb): 336 | if verbose: print('Perturbation chaos trial') 337 | _, o, _ = net.simulate(stimulus=perturbation_chaos) 338 | perturbation_output_chaos.append(o) 339 | if verbose: print('Perturbation neuron trial') 340 | _, o, _ = net.simulate(stimulus=perturbation_neuron) 341 | perturbation_output_neuron.append(o) 342 | 343 | if verbose: print('Simulation done in', time.time() - tstart, 'seconds.') 344 | 345 | # Visualization 346 | # ------------- 347 | subsampling_chaos = (t_offset + d_stim + bm.linspace(0, d_chaos, 20)).astype(bm.int32) 348 | subsampling_chaos = bm.unique(subsampling_chaos) 349 | subsampling_neuron = (t_offset + d_stim + bm.linspace(0, d_neuron, 20)).astype(bm.int32) 350 | subsampling_neuron = bm.unique(subsampling_neuron) 351 | 352 | plt.figure(figsize=(12, 8)) 353 | ax = plt.subplot2grid((2, 2), (0, 0)) 354 | ax.plot(chaos[0, :], chaos[1, :], linewidth=2.) 
355 | for i in range(num_perturb): 356 | ax.plot(final_output_chaos[i][t_offset + d_stim: t_offset + d_stim + d_chaos, 0], 357 | final_output_chaos[i][t_offset + d_stim: t_offset + d_stim + d_chaos, 1]) 358 | ax.plot(final_output_chaos[i][subsampling_chaos, 0], 359 | final_output_chaos[i][subsampling_chaos, 1], 'bo') 360 | ax.set_xlabel('x') 361 | ax.set_ylabel('Without perturbation\ny') 362 | 363 | ax = plt.subplot2grid((2, 2), (0, 1)) 364 | ax.plot(neuron[0, :], neuron[1, :], linewidth=2.) 365 | for i in range(num_perturb): 366 | ax.plot(final_output_neuron[i][t_offset + d_stim: t_offset + d_stim + d_neuron, 0], 367 | final_output_neuron[i][t_offset + d_stim: t_offset + d_stim + d_neuron, 1]) 368 | ax.plot(final_output_neuron[i][subsampling_neuron, 0], 369 | final_output_neuron[i][subsampling_neuron, 1], 'bo') 370 | ax.set_xlabel('x') 371 | ax.set_ylabel('y') 372 | 373 | ax = plt.subplot2grid((2, 2), (1, 0)) 374 | ax.plot(chaos[0, :], chaos[1, :], linewidth=2.) 375 | for i in range(num_perturb): 376 | ax.plot(perturbation_output_chaos[i][t_offset + d_stim: t_offset + d_stim + d_chaos, 0], 377 | perturbation_output_chaos[i][t_offset + d_stim: t_offset + d_stim + d_chaos, 1]) 378 | ax.plot(perturbation_output_chaos[i][subsampling_chaos, 0], 379 | perturbation_output_chaos[i][subsampling_chaos, 1], 'bo') 380 | ax.set_xlabel('x') 381 | ax.set_ylabel('y') 382 | ax.set_ylabel('With perturbation\ny') 383 | 384 | ax = plt.subplot2grid((2, 2), (1, 1)) 385 | ax.plot(neuron[0, :], neuron[1, :], linewidth=2.) 
386 | for i in range(num_perturb): 387 | ax.plot(perturbation_output_neuron[i][t_offset + d_stim: t_offset + d_stim + d_neuron, 0], 388 | perturbation_output_neuron[i][t_offset + d_stim: t_offset + d_stim + d_neuron, 1]) 389 | ax.plot(perturbation_output_neuron[i][subsampling_neuron, 0], 390 | perturbation_output_neuron[i][subsampling_neuron, 1], 'bo') 391 | ax.set_xlabel('x') 392 | ax.set_ylabel('y') 393 | 394 | plt.show() 395 | 396 | 397 | # %% 398 | def fig3(nets, verbose=True): 399 | # Parameters 400 | # ---------- 401 | 402 | num_rec_train = 20 # Number of learning trials for the recurrent weights 403 | num_readout_train = 10 # Number of learning trials for the readout weights 404 | 405 | stimulus_amplitude = 5.0 # Amplitude of the input pulse 406 | t_offset = time2len(200) # Time to wait before the stimulation 407 | d_stim = time2len(50) # Duration of the stimulation 408 | t_relax = time2len(150) # Duration to relax after the target_traj 409 | 410 | target_baseline = 0.2 # Baseline of the output_traj function 411 | target_amplitude = 1. # Maximal value of the output_traj function 412 | target_width = 30. 
# Width of the Gaussian 413 | 414 | # Main procedure 415 | # -------------- 416 | 417 | # Vary the timing interval 418 | delays = [250, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000] 419 | 420 | # Store the Pearson correlation coefficients 421 | pearsons = [] 422 | 423 | # Iterate over the delays 424 | for target_time in delays: 425 | if verbose: 426 | print('*' * 60) 427 | print('Learning a delay of', target_time) 428 | print('*' * 60) 429 | print() 430 | d_trajectory = time2len(target_time + 150) # Duration of the desired target_traj 431 | trial_duration = t_offset + d_stim + d_trajectory + t_relax # Total duration of a trial 432 | pearsons.append([]) 433 | 434 | for n, net in enumerate(nets): # 10 networks per delay 435 | if verbose: print(f'## Network {n + 1} ##', ) 436 | net.init() 437 | 438 | # Impulse input after 200 ms 439 | impulse = bm.zeros((trial_duration, net.num_input)) 440 | impulse[t_offset:t_offset + d_stim, 0] = stimulus_amplitude 441 | 442 | # Target output for learning the readout weights 443 | target = bm.zeros((trial_duration, net.num_output)) 444 | time_axis = bm.linspace(0, trial_duration, trial_duration) 445 | target[:, 0] = target_baseline + (target_amplitude - target_baseline) * \ 446 | bm.exp(-(t_offset + d_stim + target_time - time_axis) ** 2 / target_width ** 2) 447 | 448 | # Initial trial to determine the innate target_traj 449 | if verbose: print('Initial trial to determine a target_traj (without noise)') 450 | trajectory, initial_output, _ = net.simulate(stimulus=impulse, noise=False) 451 | 452 | # 20 trials of learning for the recurrent weights 453 | for i in range(num_rec_train): 454 | t0 = time.time() 455 | if verbose: print(f'Learning trial recurrent {i + 1} loss: ', end='') 456 | _, _, loss = net.simulate(stimulus=impulse, 457 | target_traj=trajectory, 458 | learn_start=t_offset + d_stim, 459 | learn_stop=t_offset + d_stim + d_trajectory) 460 | if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - 
t0} s') 461 | 462 | # 10 trials of learning for the readout weights 463 | for i in range(num_readout_train): 464 | t0 = time.time() 465 | if verbose: print(f'Learning trial readout {i + 1} loss: ', end='') 466 | _, _, loss = net.simulate(stimulus=impulse, 467 | target_traj=target, 468 | learn_start=t_offset + d_stim, 469 | learn_stop=t_offset + d_stim + d_trajectory, 470 | learn_readout=True) 471 | if verbose: print(f'{(2 * loss[0] / d_trajectory):5f}, time: {time.time() - t0} s') 472 | 473 | # Test trial 474 | if verbose: print('Test trial') 475 | reproduction, final_output, _ = net.simulate(stimulus=impulse) 476 | 477 | # Pearson correlation coefficient 478 | pred = final_output[t_offset + d_stim:t_offset + d_stim + d_trajectory, 0] 479 | desired = target[t_offset + d_stim:t_offset + d_stim + d_trajectory, 0] 480 | r, p = scipy.stats.pearsonr(desired, pred) 481 | pearsons[-1].append(r) 482 | 483 | # Save the results 484 | pearsons = bm.asarray(pearsons) 485 | 486 | # Visualization 487 | # ------------- 488 | 489 | plt.figure(figsize=(8, 6)) 490 | correlation_mean = bm.mean(pearsons ** 2, axis=1) 491 | correlation_std = bm.std(pearsons ** 2, axis=1) 492 | plt.errorbar(bm.array(delays) / 1000., correlation_mean, correlation_std / bm.sqrt(10), linestyle='-', marker='^') 493 | plt.xlim((0., 8.5)) 494 | plt.ylim((-0.1, 1.1)) 495 | plt.xlabel('Interval (s)') 496 | plt.ylabel('Performance ($R^2$)') 497 | plt.show() 498 | -------------------------------------------------------------------------------- /recurrent_networks/data/DAC_handwriting_output_targets.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainpy/examples/f0515f37b1c5b2a465fcd547766c06c6c3d3fe2e/recurrent_networks/data/DAC_handwriting_output_targets.mat -------------------------------------------------------------------------------- /recurrent_networks/fixed_points_finder.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # %% [markdown] 3 | # # Find Fixed Points 4 | 5 | # %% [markdown] 6 | # The goal of this tutorial is to learn about fixed point finding by running the algorithm on a simple data generator, a Gated Recurrent Unit (GRU) that is trained to make a binary decision, namely whether the integral of the white noise input is in total positive or negative, outputing either a +1 or a -1. 7 | 8 | # %% [markdown] 9 | # In this tutorial we do a few things: 10 | # 11 | # - Train the decision making GRU 12 | # - Find the fixed points of the GRU 13 | 14 | # %% 15 | import brainpy as bp 16 | import brainpy.math as bm 17 | bm.set_platform('cpu') 18 | 19 | # %% 20 | import time 21 | from functools import partial 22 | 23 | import numpy as np 24 | import matplotlib.pyplot as plt 25 | import matplotlib.lines as mlines 26 | 27 | # %% [markdown] 28 | # ## Parameters 29 | 30 | # %% 31 | # Integration parameters 32 | 33 | T = 1.0 # Arbitrary amount time, roughly physiological. 34 | dt = 0.04 35 | num_step = int(T / dt) # Divide T into this many bins 36 | bval = 0.01 # bias value limit 37 | sval = 0.025 # standard deviation (before dividing by sqrt(dt)) 38 | 39 | # %% 40 | # Optimization hyperparameters 41 | 42 | l2reg = 0.00002 # amount of L2 regularization on the weights 43 | num_train = int(2e4) # Total number of batches to train on. 44 | num_batch = 150 # How many examples in each batch 45 | # Gradient clipping is HUGELY important for training RNNs 46 | # max gradient norm before clipping, clip to this value. 47 | max_grad_norm = 10.0 48 | 49 | 50 | # %% [markdown] 51 | # ## Helpers 52 | 53 | # %% 54 | def plot_examples(num_time, inputs, hiddens, outputs, targets, num_example=1, 55 | num_plot=10, start_id=0): 56 | """Plot some input/hidden/output triplets. 
57 | 58 | Parameters 59 | ---------- 60 | inputs: ndarray of (num_time, num_batch) 61 | hiddens: ndarray of (num_time, num_batch, num_hidden) 62 | outputs: ndarray of (num_time, num_batch, num_output) 63 | targets: ndarray of (num_time, num_batch, num_output) 64 | """ 65 | plt.figure(figsize=(num_example * 5, 14)) 66 | selected_ids = list(range(start_id, start_id + num_example)) 67 | 68 | for i, bidx in enumerate(selected_ids): 69 | plt.subplot(3, num_example, i + 1) 70 | plt.plot(inputs[:, bidx], 'k') 71 | plt.xlim([0, num_time]) 72 | plt.title('Example %d' % bidx) 73 | if bidx == 0: plt.ylabel('Input Units') 74 | 75 | closeness = 0.25 76 | for i, bidx in enumerate(selected_ids): 77 | plt.subplot(3, num_example, num_example + i + 1) 78 | plt.plot(hiddens[:, bidx, 0:num_plot] + closeness * np.arange(num_plot), 'b') 79 | plt.xlim([0, num_time]) 80 | if bidx == 0: plt.ylabel('Hidden Units') 81 | 82 | for i, bidx in enumerate(selected_ids): 83 | plt.subplot(3, num_example, 2 * num_example + i + 1) 84 | plt.plot(outputs[:, bidx, :], 'r', label='predict') 85 | plt.plot(targets[:, bidx, :], 'k', label='target') 86 | plt.xlim([0, num_time]) 87 | plt.xlabel('Time steps') 88 | plt.legend() 89 | if bidx == 0: plt.ylabel('Output Units') 90 | 91 | plt.show() 92 | 93 | 94 | # %% 95 | def plot_params(rnn, show_top_eig=0): 96 | """Plot the parameters. 
""" 97 | assert isinstance(rnn, GRU) 98 | 99 | plt.figure(figsize=(16, 8)) 100 | plt.subplot(231) 101 | plt.stem(rnn.w_ro.numpy()[:, 0]) 102 | plt.title('W_ro - output weights') 103 | 104 | plt.subplot(232) 105 | plt.stem(rnn.h0) 106 | plt.title('h0 - initial hidden state') 107 | 108 | a = bm.concatenate([rnn.w_ir, rnn.w_iz], axis=0) 109 | b = bm.concatenate([rnn.w_hr, rnn.w_hz], axis=0) 110 | c = bm.concatenate([a, b], axis=0).numpy() 111 | 112 | plt.subplot(233) 113 | plt.imshow(c, interpolation=None) 114 | plt.colorbar() 115 | plt.title('[W_ir, W_iz, W_hr, W_hz]') 116 | 117 | plt.subplot(234) 118 | a = bm.concatenate([rnn.w_ia, rnn.w_ha], axis=0).numpy() 119 | plt.imshow(a, interpolation=None) 120 | plt.colorbar() 121 | plt.title('[W_ia, W_ha]') 122 | 123 | plt.subplot(235) 124 | plt.stem(bm.concatenate([rnn.bz, rnn.br]).numpy()) 125 | plt.title('[bz, br] - recurrent biases') 126 | 127 | plt.subplot(236) 128 | dFdh = bm.jacobian(rnn.cell)(rnn.h0.value, bm.zeros(rnn.num_input)) 129 | evals, _ = np.linalg.eig(bm.as_numpy(dFdh)) 130 | x = np.linspace(-1, 1, 1000) 131 | plt.plot(x, np.sqrt(1 - x ** 2), 'k') 132 | plt.plot(x, -np.sqrt(1 - x ** 2), 'k') 133 | plt.plot(np.real(evals), np.imag(evals), '.') 134 | 135 | if show_top_eig > 0: 136 | print(np.sort(np.real(evals))[-show_top_eig:]) 137 | 138 | plt.axis('equal') 139 | plt.xlabel('Real($\lambda$)') 140 | plt.ylabel('Imaginary($\lambda$)') 141 | plt.title('Eigenvalues of $dF/dh(h_0)$') 142 | 143 | plt.show() 144 | 145 | 146 | # %% 147 | def plot_data(num_time, inputs, targets=None, outputs=None, errors=None, num_plot=10): 148 | """Plot some white noise / integrated white noise examples. 
149 | 150 | Parameters 151 | ---------- 152 | num_time : int 153 | num_plot : int 154 | inputs: ndarray 155 | with the shape of (num_batch, num_time, num_input) 156 | targets: ndarray 157 | with the shape of (num_batch, num_time, num_output) 158 | outputs: ndarray 159 | with the shape of (num_batch, num_time, num_output) 160 | errors: ndarray 161 | with the shape of (num_batch, num_time, num_output) 162 | """ 163 | num = 1 164 | if errors is not None: num += 1 165 | if (targets is not None) or (outputs is not None): num += 1 166 | plt.figure(figsize=(14, 4 * num)) 167 | 168 | # inputs 169 | plt.subplot(num, 1, 1) 170 | plt.plot(inputs[:, 0:num_plot, 0]) 171 | plt.xlim([0, num_time]) 172 | plt.ylabel('Noise') 173 | 174 | legends = [] 175 | if outputs is not None: 176 | plt.subplot(num, 1, 2) 177 | plt.plot(outputs[:, 0:num_plot, 0]) 178 | plt.xlim([0, num_time]) 179 | legends.append(mlines.Line2D([], [], color='k', linestyle='-', label='predict')) 180 | if targets is not None: 181 | plt.subplot(num, 1, 2) 182 | plt.plot(targets[:, 0:num_plot, 0], '--') 183 | plt.xlim([0, num_time]) 184 | plt.ylabel("Integration") 185 | legends.append(mlines.Line2D([], [], color='k', linestyle='--', label='target')) 186 | if len(legends): plt.legend(handles=legends) 187 | 188 | if errors is not None: 189 | plt.subplot(num, 1, 3) 190 | plt.plot(errors[:, 0:num_plot, 0], '--') 191 | plt.xlim([0, num_time]) 192 | plt.ylabel("|Errors|") 193 | 194 | plt.xlabel('Time steps') 195 | plt.show() 196 | 197 | 198 | # %% 199 | @partial(bm.jit, dyn_vars={'a': bm.random.DEFAULT}, static_argnames=('num_batch', 'num_step', 'dt')) 200 | def build_inputs_and_targets(mean, scale, num_batch, num_step, dt): 201 | """Build white noise input and integration targets.""" 202 | 203 | # Create the white noise input. 
204 | sample = bm.random.normal(size=(num_batch,)) 205 | bias = mean * 2.0 * (sample - 0.5) 206 | samples = bm.random.normal(size=(num_step, num_batch)) 207 | noise_t = scale / dt ** 0.5 * samples 208 | white_noise_t = bias + noise_t 209 | inputs_txbx1 = bm.expand_dims(white_noise_t, axis=2) 210 | 211 | # * dt, intentionally left off to get output scaling in O(1). 212 | integration_txbx1 = bm.expand_dims(bm.cumsum(white_noise_t, axis=0), axis=2) 213 | targets_txbx1 = bm.zeros_like(integration_txbx1) 214 | targets_txbx1[-1] = 2.0 * ((integration_txbx1[-1] > 0.0) - 0.5) 215 | # targets_mask = bm.ones((num_batch, 1)) * (num_step - 1) 216 | return inputs_txbx1, targets_txbx1 217 | 218 | 219 | # %% 220 | # # Plot the example inputs and targets for the RNN. 221 | # _ints, _outs = build_inputs_and_targets(bval, sval, num_batch=num_batch, num_step=num_step, dt=dt) 222 | 223 | # plot_data(num_step, inputs=_ints, targets=_outs) 224 | 225 | # %% [markdown] 226 | # ## Model 227 | 228 | # %% 229 | class GRU(bp.DynamicalSystem): 230 | def __init__(self, num_hidden, num_input, num_output, num_batch, 231 | g=1.0, l2_reg=0., forget_bias=0.5, **kwargs): 232 | super(GRU, self).__init__(**kwargs) 233 | 234 | # parameters 235 | self.l2_reg = l2_reg 236 | self.num_input = num_input 237 | self.num_batch = num_batch 238 | self.num_hidden = num_hidden 239 | self.num_output = num_output 240 | self.rng = bm.random.RandomState() 241 | self.forget_bias = forget_bias 242 | 243 | # recurrent weights 244 | self.w_iz = bm.TrainVar(self.rng.normal(scale=1 / num_input ** 0.5, size=(num_input, num_hidden))) 245 | self.w_ir = bm.TrainVar(self.rng.normal(scale=1 / num_input ** 0.5, size=(num_input, num_hidden))) 246 | self.w_ia = bm.TrainVar(self.rng.normal(scale=1 / num_input ** 0.5, size=(num_input, num_hidden))) 247 | self.w_hz = bm.TrainVar(self.rng.normal(scale=g / num_hidden ** 0.5, size=(num_hidden, num_hidden))) 248 | self.w_hr = bm.TrainVar(self.rng.normal(scale=g / num_hidden ** 0.5, 
size=(num_hidden, num_hidden))) 249 | self.w_ha = bm.TrainVar(self.rng.normal(scale=g / num_hidden ** 0.5, size=(num_hidden, num_hidden))) 250 | self.bz = bm.TrainVar(bm.zeros((num_hidden,))) 251 | self.br = bm.TrainVar(bm.zeros((num_hidden,))) 252 | self.ba = bm.TrainVar(bm.zeros((num_hidden,))) 253 | self.h0 = bm.TrainVar(self.rng.normal(scale=0.1, size=(num_hidden,))) 254 | 255 | # output weights 256 | self.w_ro = bm.TrainVar(self.rng.normal(scale=1 / num_hidden ** 0.5, size=(num_hidden, num_output))) 257 | self.b_ro = bm.TrainVar(bm.zeros((num_output,))) 258 | 259 | # variables 260 | self.h = bm.Variable(self.rng.normal(scale=0.1, size=(num_batch, self.num_hidden))) 261 | self.o = self.h @ self.w_ro + self.b_ro 262 | 263 | # loss 264 | self.total_loss = bm.Variable(bm.zeros(1)) 265 | self.l2_loss = bm.Variable(bm.zeros(1)) 266 | self.mse_loss = bm.Variable(bm.zeros(1)) 267 | 268 | def cell(self, h, x): 269 | r = bm.sigmoid(x @ self.w_ir + h @ self.w_hr + self.br) 270 | z = bm.sigmoid(x @ self.w_iz + h @ self.w_hz + self.bz) 271 | a = bm.tanh(x @ self.w_ia + (r * h) @ self.w_ha + self.ba) 272 | return (1. 
- z) * h + z * a 273 | 274 | def readout(self, h): 275 | return h @ self.w_ro + self.b_ro 276 | 277 | def make_update(self, h: bm.JaxArray, o: bm.JaxArray): 278 | def update(x): 279 | h.value = self.cell(h, x) 280 | o.value = self.readout(h) 281 | 282 | return update 283 | 284 | def predict(self, xs): 285 | self.h[:] = self.h0 286 | f = bm.make_loop(self.make_update(self.h, self.o), 287 | dyn_vars=self.vars().unique(), 288 | out_vars=[self.h, self.o]) 289 | return f(xs) 290 | 291 | def loss(self, xs, ys): 292 | hs, os = self.predict(xs) 293 | l2 = self.l2_reg * bm.losses.l2_norm(self.train_vars().dict()) ** 2 294 | mse = bm.losses.mean_squared_error(os[-1], ys[-1]) 295 | total = l2 + mse 296 | self.total_loss[0] = total 297 | self.l2_loss[0] = l2 298 | self.mse_loss[0] = mse 299 | return total 300 | 301 | 302 | # %% 303 | net = GRU(num_input=1, num_hidden=100, num_output=1, num_batch=num_batch, l2_reg=l2reg) 304 | 305 | # plot_params(net) 306 | 307 | # %% 308 | lr = bp.optim.ExponentialDecay(lr=0.04, decay_steps=1, decay_rate=0.9999) 309 | optimizer = bp.optim.Adam(lr=lr, train_vars=net.train_vars(), eps=1e-1) 310 | 311 | 312 | @bm.jit 313 | @bm.function(nodes=(net, optimizer)) 314 | def train(inputs, targets): 315 | grad_f = bm.grad(net.loss, dyn_vars=net.vars(), grad_vars=net.train_vars(), return_value=True) 316 | grads, loss = grad_f(inputs, targets) 317 | clipped_grads = bm.clip_by_norm(grads, max_grad_norm) 318 | optimizer.update(clipped_grads) 319 | return loss 320 | 321 | 322 | # %% [markdown] 323 | # ## Training 324 | 325 | # %% 326 | t0 = time.time() 327 | train_losses = {'total': [], 'l2': [], 'mse': []} 328 | for i in range(num_train): 329 | _ins, _outs = build_inputs_and_targets(bval, sval, num_batch=num_batch, num_step=num_step, dt=dt) 330 | loss = train(inputs=_ins, targets=_outs) 331 | if (i + 1) % 400 == 0: 332 | print(f"Run batch {i + 1} in {time.time() - t0:0.3f} s, learning rate: {lr():.5f}, training loss {loss:0.4f}") 333 | 
train_losses['total'].append(net.total_loss[0]) 334 | train_losses['l2'].append(net.l2_loss[0]) 335 | train_losses['mse'].append(net.mse_loss[0]) 336 | 337 | # %% 338 | # net.save_states('./data/fixed_points-80.h5') 339 | # net.save_states('./data/fixed_points-1200.h5') 340 | # net.load_states('./data/fixed_points-80.h5') 341 | 342 | # %% 343 | # Show the loss through training. 344 | plt.figure(figsize=(12, 4)) 345 | plt.subplot(131) 346 | plt.plot(train_losses['total'], 'k') 347 | plt.title('Total loss') 348 | plt.xlabel('Trail') 349 | 350 | plt.subplot(132) 351 | plt.plot(train_losses['mse'], 'r') 352 | plt.title('Least mean square loss') 353 | plt.xlabel('Trail') 354 | 355 | plt.subplot(133) 356 | plt.plot(train_losses['l2'], 'g') 357 | plt.title('L2 loss') 358 | plt.xlabel('Trail') 359 | plt.show() 360 | 361 | # %% 362 | # show the trained weights 363 | plot_params(net, show_top_eig=2) 364 | 365 | # %% [markdown] 366 | # ## Testing 367 | 368 | # %% 369 | # net.load_states('./data/fixed_points.h5') 370 | 371 | # %% 372 | inputs, targets = build_inputs_and_targets(bval, sval, num_batch=num_batch, num_step=num_step, dt=dt) 373 | hiddens, outputs = net.predict(inputs) 374 | 375 | # plot_data(num_step, inputs=inputs, targets=targets, outputs=outputs, errors=np.abs(targets - outputs), num_plot=16) 376 | 377 | # %% 378 | plot_examples(num_step, inputs=inputs, targets=targets, outputs=outputs, hiddens=hiddens, num_example=4) 379 | 380 | # %% 381 | plot_examples(num_step, inputs=inputs, targets=targets, outputs=outputs, hiddens=hiddens, num_example=4, start_id=10) 382 | 383 | # %% [markdown] 384 | # ## Fixed point analysis 385 | 386 | # %% [markdown] 387 | # Now that we've trained up this GRU to decide whether or not the perfect integral of the input is positive or negative, we can analyze the system via fixed point analysis. 
# %% [markdown]
# The update rule of the GRU cell from the current state to the next state can be computed by:

# %%
def f_cell(h):
    """Advance the trained GRU one step from state ``h`` with an all-zero input."""
    return net.cell(h, bm.zeros(net.num_input))


# %% [markdown]
# The function to determine the fixed point loss is given by the squared error of a point $(h - F(h))^2$:

# %% [markdown]
# Let's try to find the fixed points given the initial states.

# %%
# Every hidden state visited during the test run becomes a starting point for the search.
fp_candidates = hiddens.reshape((-1, net.num_hidden))

# %%
finder = bp.analysis.SlowPointFinder(f_cell=f_cell, f_type='discrete')
fp_optimizer = bp.optim.Adam(lr=bp.optim.ExponentialDecay(0.2, 1, 0.9999), eps=1e-8)
finder.find_fps_with_gd_method(candidates=fp_candidates,
                               tolerance=5e-7,
                               num_batch=400,
                               optimizer=fp_optimizer)
# Post-process the optimised points with the finder's filter_loss / keep_unique /
# exclude_outliers steps, then read off the surviving fixed points.
finder.filter_loss(tolerance=1e-7)
finder.keep_unique(tolerance=0.0005)
finder.exclude_outliers(0.1)
fps = finder.fixed_points

# %% [markdown]
# ### Verify fixed points

# %% [markdown]
# Plotting the quality of the fixed points.

# %%
# # %matplotlib inline

# %%
fig, gs = bp.visualize.get_figure(1, 2, 4, 6)

# left panel: per-point losses reported by the finder, on a log scale
fig.add_subplot(gs[0, 0])
plt.semilogy(finder.losses)
plt.xlabel('Fixed point #')
plt.ylabel('Fixed point loss')

# right panel: histogram of log10 losses of the final fixed points
fig.add_subplot(gs[0, 1])
plt.hist(np.log10(finder.f_loss_batch(fps)), 50)
plt.xlabel('log10(FP loss)')
plt.show()

# %% [markdown]
# Let's run the system starting at these fixed points, without input, and make sure the system is at equilibrium there. Note one can have fixed points that are very unstable, but that does not show up in this example.
# %%
# Run the autonomous network (zero input) from every fixed point.
num_example = len(fps)
check_h = bm.Variable(fps)
check_o = bm.Variable(bm.zeros((num_example, net.num_output)))

f_check_update = bm.make_loop(net.make_update(check_h, check_o),
                              dyn_vars=list(net.vars().values()) + [check_h, check_o],
                              out_vars=[check_h, check_o])

_ins = bm.zeros((num_step, num_example, net.num_input))
_outs = bm.zeros((num_step, num_example, net.num_output))
slow_hiddens, slow_outputs = f_check_update(_ins)

# %%
plot_examples(num_step, inputs=_ins, targets=_outs, outputs=slow_outputs,
              hiddens=slow_hiddens, num_example=4)

# %%
plot_examples(num_step, inputs=_ins, targets=_outs, outputs=slow_outputs,
              hiddens=slow_hiddens, num_example=4, start_id=4)

# %% [markdown]
# Try to get a nice representation of the line using the fixed points.

# %%
# Sort the fixed points by their projection onto the readout.
fp_readouts = np.squeeze(net.readout(fps))
fp_ro_sidxs = np.argsort(fp_readouts)
sorted_fp_readouts = fp_readouts[fp_ro_sidxs]
sorted_fps = fps[fp_ro_sidxs]

downsample_fps = 1  # Use this if too many fps
# `[::k]` keeps every k-th point including the last; `[0:-1:k]` always dropped the last one.
sorted_fp_readouts = sorted_fp_readouts[::downsample_fps]
sorted_fps = sorted_fps[::downsample_fps]

# %% [markdown]
# ### Visualize fixed points

# %% [markdown]
# Now, through a series of plots and dot products, we will see how the GRU solved the binary decision task. First, let's plot the fixed points in the space of the first three principal components of the fixed point candidates (the hidden states the optimization was seeded with). Each star is a fixed point; its color is the projection of the fixed point onto the readout vector, clipped to [-1, 1]. All stars have the same size, and the candidate points themselves are not drawn.

# %%
# # %matplotlib qt

# %%
from sklearn.decomposition import PCA

# Fit the projection on the candidates (the hidden states visited during testing).
pca = PCA(n_components=3).fit(fp_candidates)
# pca = PCA(n_components=3).fit(fps)

# %%
fig = plt.figure(figsize=(16, 16))
ax = fig.add_subplot(111, projection='3d')

emax = fps.shape[0]

# Plot the fixed points, colored by their readout value mapped from [-1, 1] onto [0, 1].
hstar_pca = pca.transform(fps)
color = np.clip(bm.as_numpy(np.squeeze(net.readout(fps))), -1.0, 1.0)
color = (color + 1.0) / 2.0
marker_style = dict(marker='*', s=100, edgecolor='gray')
ax.scatter(hstar_pca[0:emax, 0], hstar_pca[0:emax, 1], hstar_pca[0:emax, 2], c=color[0:emax], **marker_style)

plt.title('Fixed point structure and fixed point candidate starting points.')
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.set_zlabel('PC 3')
plt.show()

# %% [markdown]
# So in this example, we see that the fixed point structure implements an approximate line attractor, which is the one-dimensional manifold likely used to integrate the white noise and ultimately lead to the decision.
538 | 539 | # %% [markdown] 540 | # Note also the shape of the manifold relative to the color. The color is the based on the readout value of the fixed point, so it appears that there may be three parts to the line attractor. The middle and two sides. The two sides may be integrating, even though the the readout would be +1 or -1. 541 | 542 | # %% [markdown] 543 | # It's worth taking a look at the fixed points, and the trajectories started at the fixed points, without any input, all plotted in the 3D PCA space. 544 | 545 | # %% 546 | fig = plt.figure(figsize=(16,16)) 547 | ax = fig.add_subplot(111, projection='3d') 548 | 549 | alpha = 0.05 550 | emax = len(sorted_fps) 551 | for eidx in range(emax): 552 | h_pca = pca.transform(slow_hiddens[:, eidx, :]) 553 | ax.plot3D(h_pca[:,0], h_pca[:,1], h_pca[:,2], c=[0, 0, 1, alpha]) 554 | 555 | size = 100 556 | hstar_pca = pca.transform(sorted_fps) 557 | color = np.squeeze(net.readout(sorted_fps)) 558 | color = np.where(color > 1.0, 1.0, color) 559 | color = np.where(color < -1.0, -1.0, color) 560 | color = (color + 1.0) / 2.0 561 | marker_style = dict(marker='*', s=size, edgecolor='gray') 562 | 563 | ax.scatter(hstar_pca[0:emax, 0], hstar_pca[0:emax, 1], hstar_pca[0:emax, 2], 564 | c=color[0:emax], **marker_style) 565 | 566 | plt.title('High quality fixed points and the network dynamics initialized from them.') 567 | ax.set_xlabel('PC 1') 568 | ax.set_ylabel('PC 2') 569 | ax.set_zlabel('PC 3') 570 | plt.show() 571 | 572 | # %% 573 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | brainpy 2 | 3 | # docs 4 | pandoc 5 | docutils 6 | Jinja2 7 | sphinx 8 | nbsphinx 9 | sphinx_rtd_theme>=1.0.0 --------------------------------------------------------------------------------