├── sample_images ├── img1.png ├── img2.png ├── diehl0.png ├── history.png ├── s_trace.gif ├── x_trace.gif ├── diehl10000.png ├── history_sample.png ├── pre_weight_maps.png ├── res_weight_maps.png ├── encode │ ├── lif_encode.png │ ├── repeat_encode.png │ ├── single_encode.png │ ├── bernoili_encode.png │ ├── poisson_encode.png │ ├── rankorder_encode.png │ └── fixed-frequency_encode.png └── result_weight_maps.png ├── wbn ├── __init__.py ├── additional_encoders.py ├── examples │ └── simple_mln.py └── snnlib.py ├── main.py └── README.md /sample_images/img1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/img1.png -------------------------------------------------------------------------------- /sample_images/img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/img2.png -------------------------------------------------------------------------------- /sample_images/diehl0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/diehl0.png -------------------------------------------------------------------------------- /sample_images/history.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/history.png -------------------------------------------------------------------------------- /sample_images/s_trace.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/s_trace.gif -------------------------------------------------------------------------------- /sample_images/x_trace.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/x_trace.gif -------------------------------------------------------------------------------- /sample_images/diehl10000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/diehl10000.png -------------------------------------------------------------------------------- /sample_images/history_sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/history_sample.png -------------------------------------------------------------------------------- /sample_images/pre_weight_maps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/pre_weight_maps.png -------------------------------------------------------------------------------- /sample_images/res_weight_maps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/res_weight_maps.png -------------------------------------------------------------------------------- /sample_images/encode/lif_encode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/encode/lif_encode.png -------------------------------------------------------------------------------- /sample_images/encode/repeat_encode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/encode/repeat_encode.png 
-------------------------------------------------------------------------------- /sample_images/encode/single_encode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/encode/single_encode.png -------------------------------------------------------------------------------- /sample_images/result_weight_maps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/result_weight_maps.png -------------------------------------------------------------------------------- /sample_images/encode/bernoili_encode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/encode/bernoili_encode.png -------------------------------------------------------------------------------- /sample_images/encode/poisson_encode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/encode/poisson_encode.png -------------------------------------------------------------------------------- /sample_images/encode/rankorder_encode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/encode/rankorder_encode.png -------------------------------------------------------------------------------- /sample_images/encode/fixed-frequency_encode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiroshiARAKI/snnlibpy/HEAD/sample_images/encode/fixed-frequency_encode.png -------------------------------------------------------------------------------- /wbn/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .snnlib import ( 2 | Spiking, 3 | PoissonEncoder, 4 | SingleEncoder, 5 | RepeatEncoder, 6 | RankOrderEncoder, 7 | BernoulliEncoder, 8 | __version__ 9 | ) 10 | 11 | from .examples.simple_mln import ( 12 | DiehlCook_unsupervised_model, 13 | MultiLayerNetwork_unsupervised_model, 14 | ) 15 | 16 | from .additional_encoders import ( 17 | FixedFrequencyEncoder, 18 | LIFEncoder, 19 | LIFEncoder2, 20 | ) 21 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from wbn import Spiking 2 | 3 | 4 | if __name__ == '__main__': 5 | 6 | # Build SNNs and decide the number of input neurons and the simulation time. 7 | snn = Spiking(input_l=784, obs_time=300, dt=0.5) 8 | snn.IMAGE_DIR += 'diehl/' 9 | 10 | # Add a layer and give the num of neurons and the neuron model. 11 | snn.add_layer(n=100, 12 | node=snn.DIEHL_COOK, # or snn.DIEHL_COOK 13 | w=snn.W_SIMPLE_RAND, # initialize weights 14 | scale=0.4, # scale of random intensity 15 | rule=snn.SIMPLE_STDP, # learning rule 16 | nu=(1e-4, 1e-2), # learning rate 17 | # norm=150, # L1 weight normalization term 18 | ) 19 | 20 | # Add an inhibitory layer 21 | snn.add_inhibit_layer(inh_w=-128) 22 | 23 | # Load dataset 24 | snn.load_MNIST() 25 | 26 | # Check your network architecture 27 | snn.print_model() 28 | 29 | # If you use a small network, your network computation by GPU may be more slowly than CPU. 30 | # So you can change directly whether using GPU or not as below. 31 | # snn.gpu = False 32 | 33 | # Gpu is available?? If available, make it use. 
34 | snn.to_gpu() 35 | 36 | # Plot weight maps before training 37 | snn.plot(plt_type='wmps', prefix='0', f_shape=(10, 10)) 38 | 39 | # Make my network run 40 | for i in range(3): 41 | snn.run( 42 | # tr_size=10000, # training data size 43 | # unsupervised=False, # do not unsupervised learning? 44 | # alpha=0.8, # assignment decay 45 | # debug=True, # Do you wanna watch neuron's assignments? 46 | # interval=500, # interval of assignment 47 | # ts_size=5000, # If you have little time for experiments, be able to reduce test size 48 | ) 49 | 50 | snn.plot(plt_type='wmps', prefix='{}'.format(i+1), f_shape=(10, 10)) # plot maps 51 | 52 | # Plot test accuracy transition 53 | snn.plot(plt_type='history', prefix='result') 54 | 55 | # Plot weight maps after training 56 | snn.plot(plt_type='wmps', prefix='result', f_shape=(10, 10)) 57 | 58 | # Plot output spike trains after training 59 | snn.plot(plt_type='sp', range=10) 60 | 61 | print(snn.history) 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WrappedBindsNET 2 | ![update](https://img.shields.io/badge/last%20update-2021.03.11-lightgray.svg?style=flat) 3 | 4 | これはBindsNETと呼ばれるPyTorchベースのSpiking Neural Networksフレームワークをさらに使いやすくしよう, 5 | というコンセプトのもと作成中. 6 | この小さなライブラリは,大体[snnlib.py](wbn/snnlib.py)に詰められているので,各種定数などはかなり弄りやすいかと思います. 7 | もちろん,main.pyから直接クラス変数は変更できます. 8 | 完全に個人利用ですが,使いたい人がいればご自由にどうぞ 9 | (結構頻繁に小さな(大したことない)アップデートをしています.) 10 | 11 | **作者の修士課程修了に伴い,大きなアップデートは今後おそらくありませんが,これを拡張して利用することは歓迎いたします.** 12 | 13 | I am making a tiny and user friendly library of Spiking Neural Networks with BindsNET. 14 | All functions are packed to only [snnlib.py](wbn/snnlib.py), so you can use easily. 15 | This library is used by private myself, but if you want to use it, feel free to use. 
16 | 17 | **未完成につきバグがまだある可能性があります.(There may still be bugs because this library is incomplete.)** 18 | 19 | ## 実行保証環境 (Environment) 20 | 以下の環境において問題なく実行可能なことを確認しています. 21 | 22 | * OS.........MacOS 10.15 or Ubuntu 16.04 LTS 23 | * Python.....3.6.* or 3.7.* (or later) 24 | * BindsNET...0.2.7 (does not work on < 0.2.7) 25 | * PyTorch....1.10 26 | (GPU: torch... 1.3.0+cu92, torchvision... 0.4.1+cu92) 27 | 28 | ## Example 29 | * Sample code 30 | ```python 31 | from wbn import Spiking 32 | 33 | 34 | if __name__ == '__main__': 35 | 36 | # Build SNNs and decide the number of input neurons and the simulation time. 37 | snn = Spiking(input_l=784, obs_time=300, dt=0.5) 38 | snn.IMAGE_DIR += 'diehl/' 39 | 40 | # Add a layer and give the num of neurons and the neuron model. 41 | snn.add_layer(n=100, 42 | node=snn.DIEHL_COOK, # neuron model, e.g. snn.DIEHL_COOK or snn.LIF 43 | w=snn.W_SIMPLE_RAND, # initialize weights 44 | rule=snn.SIMPLE_STDP, # learning rule 45 | nu=(1e-4, 1e-2), # learning rate 46 | ) 47 | 48 | # Add an inhibitory layer 49 | snn.add_inhibit_layer(inh_w=-128) 50 | 51 | # Load dataset 52 | snn.load_MNIST() 53 | 54 | # Check your network architecture 55 | snn.print_model() 56 | 57 | # With a small network, computing on the GPU may be slower than on the CPU. 58 | # In that case, you can disable the GPU directly as below. 59 | # snn.gpu = False 60 | 61 | # Use the GPU if it is available.
62 | snn.to_gpu() 63 | 64 | # Plot weight maps before training 65 | snn.plot(plt_type='wmps', prefix='0', f_shape=(10, 10)) 66 | 67 | # Make my network run 68 | for i in range(3): 69 | snn.run() 70 | 71 | snn.plot(plt_type='wmps', prefix='{}'.format(i+1), f_shape=(10, 10)) # plot maps 72 | 73 | # Plot test accuracy transition 74 | snn.plot(plt_type='history', prefix='result') 75 | 76 | # Plot weight maps after training 77 | snn.plot(plt_type='wmps', prefix='result', f_shape=(10, 10)) 78 | 79 | # Plot output spike trains after training 80 | snn.plot(plt_type='sp', range=10) 81 | 82 | print(snn.history) 83 | ``` 84 | 85 | or very simply, 86 | ```python 87 | from wbn import DiehlCook_unsupervised_model # packed sample simulation code 88 | 89 | DiehlCook_unsupervised_model() 90 | ``` 91 | is enough (this function is a backup of my working model, so it is a good way to check whether everything works properly). 92 | 93 | * Generated image samples 94 | * A weight map before training 95 | ![pre_training](sample_images/diehl0.png) 96 | 97 | * A weight map after STDP training with 10,000 MNIST samples 98 | ![post_training](sample_images/diehl10000.png) 99 | 100 | 101 | ## BindsNET references 102 | 【docs】 103 | [Welcome to BindsNET’s documentation!
— bindsnet 0.2.5 documentation](https://bindsnet-docs.readthedocs.io) 104 | 105 | 【Github】 106 | [Hananel-Hazan/bindsnet: Simulation of spiking neural networks (SNNs) using PyTorch.](https://github.com/Hananel-Hazan/bindsnet) 107 | 108 | 【Paper】 109 | [BindsNET: A Machine Learning-Oriented Spiking Neural Networks Library in Python](https://www.frontiersin.org/articles/10.3389/fninf.2018.00089/full) 110 | 111 | -------------------------------------------------------------------------------- /wbn/additional_encoders.py: -------------------------------------------------------------------------------- 1 | from bindsnet.encoding import Encoder 2 | import torch 3 | import numpy as np 4 | 5 | 6 | def fixed_frequency(datum: torch.Tensor, time: int, dt: float = 1.0): 7 | """ 8 | Generates Fixed-frequency spike trains with an input image normalized as frequencies. 9 | :param datum: input image shape[batch, height, width]? 10 | :param time: 11 | :param dt: 12 | :return: 13 | """ 14 | time = int(time / dt) 15 | shape = list(datum.shape) 16 | datum = np.copy(datum) 17 | 18 | periods = 1. / (datum*0.001 + 0.000001) # transform frequencies to periods 19 | periods = periods.astype(int) + 1 # interval between two spike is periods, so needs plus one 20 | 21 | spike = np.arange(1, time + 1, 1) # initialize spike [1, time+1] that means indices of array 22 | spikes = torch.tensor([ 23 | spike for _ in range(784) # also initialize all spikes of img with that initialized spike 24 | ]).numpy().T.reshape((time, shape[1], shape[2])) # transpose and reshape 25 | 26 | # fixed-frequency spike trains generating process (this code is not smart??) 27 | spikes[spikes % periods != 0] = 2 # In all spikes, puts dummy number '2' into non-spike times decided each periods. 
28 | spikes[spikes % periods == 0] = 1 # next, the number '1' into the spike-exsisting times 29 | spikes[spikes == 2] = 0 # and finally, changes dummy numbers to 0 30 | 31 | return torch.tensor(spikes) # to tensor 32 | 33 | 34 | def lif(datum: torch.Tensor, time: int, dt: float = 1.0, rest=-65, th=-40, ref=3, tc_decay=100): 35 | """ 36 | Very simple LIF neuron spike generator. 37 | Membrane formula is below, 38 | v = v + I 39 | v = decay * (v - rest) + rest 40 | :param datum: 41 | :param time: 42 | :param dt: 43 | :param rest: 44 | :param th: 45 | :param ref: 46 | :param tc_decay: 47 | :return: 48 | """ 49 | time = int(time / dt) 50 | shape = list(datum.shape) 51 | datum = np.copy(datum.squeeze()) 52 | 53 | decay = np.exp(-1.0 / tc_decay) 54 | 55 | spikes = np.zeros((time, shape[1], shape[2])) 56 | v = np.full_like(datum, rest) 57 | refrac = np.zeros_like(v) 58 | 59 | for t in range(time): 60 | # makes neurons' membrane potentials outside a refractory period integrate input currents 61 | v[refrac == 0] += datum[refrac == 0] 62 | # and make them decay (leak) 63 | v = decay * (v - rest) + rest 64 | 65 | # neurons whose potential is higher than threshold are fired 66 | spikes[t][v > th] = 1 67 | # sets them to a refractory period 68 | refrac[v > th] = ref 69 | # and make them be back to resting potential 70 | v[v > th] = rest 71 | 72 | # makes refractory period counters decrement 73 | refrac[refrac > 0] -= 1 74 | 75 | return torch.tensor(spikes) # to tensor 76 | 77 | 78 | def lif_v2(datum: torch.Tensor, time: int, dt: float = 1.0, rest=-65, th=-40, ref=3, tc_decay=10): 79 | """ 80 | This is also LIF equation. 81 | Membrane potential formula is below. 
82 | dv/dt = (-v + v_rest + I)/time_const 83 | :param datum: 84 | :param time: 85 | :param dt: 86 | :param rest: 87 | :param th: 88 | :param ref: 89 | :param tc_decay: 90 | :return: 91 | """ 92 | time = int(time / dt) 93 | shape = list(datum.shape) 94 | current = np.copy(datum.squeeze()) 95 | 96 | tlast = np.full_like(current, -ref) # last firing time 97 | spikes = np.zeros((time, shape[1], shape[2])) 98 | v = np.full_like(current, rest) # membrane potentials 99 | # vpeak = 20 # peak of membrane potential which is not needed in case of computing only spikes 100 | 101 | for t in range(time): 102 | # computes amount of voltage increase 103 | dv = ((dt * t) > (tlast + ref)) * (-v + rest + current) / tc_decay 104 | 105 | # updates voltage 106 | v = v + dt * dv 107 | 108 | # if fires, updates last firing time 109 | tlast = tlast + (dt * t - tlast) * (v >= th) 110 | 111 | # v = v + (vpeak - v) * (v >= th) 112 | 113 | # also set 1 to spikes 114 | spikes[t][v >= th] = 1 115 | 116 | # and make voltage decrease to the resting voltage 117 | v = v + (rest - v) * (v >= th) 118 | 119 | return torch.tensor(spikes) # to tensor 120 | 121 | 122 | class FixedFrequencyEncoder(Encoder): 123 | def __init__(self, time: int, dt: float = 1.0, **kwargs): 124 | """ 125 | Generates vey simple spike train by a fixed frequency, which means the all interval between two spikes are even. 126 | Then, each pixels values is used as a fixed frequency, so you have to normalize input images in advance. 127 | :param time: 128 | :param dt: 129 | :param kwargs: 130 | """ 131 | super().__init__(time, dt=dt, **kwargs) 132 | self.enc = fixed_frequency 133 | 134 | 135 | class LIFEncoder(Encoder): 136 | def __init__(self, time: int, dt: float = 1.0, rest=-65, th=-40, ref=3, tc_decay=100, **kwargs): 137 | """ 138 | Generates vey simple spike train by LIF neuron, so the all interval between two spikes are also even. 
139 | Then, each pixels values is used as a input current, so you have to normalize input images in advance. 140 | WARNING: This encoder can be used when batch size is one (non-batch). 141 | :param time: 142 | :param dt: 143 | :param kwargs: 144 | """ 145 | super().__init__(time, dt=dt, rest=rest, th=th, ref=ref, tc_decay=tc_decay, **kwargs) 146 | self.enc = lif 147 | 148 | 149 | class LIFEncoder2(Encoder): 150 | def __init__(self, time: int, dt: float = 1.0, rest=-65, th=-40, ref=3, tc_decay=10, **kwargs): 151 | """ 152 | This is also LIF encoder, but the formula of computing membrane potential is a bit different. 153 | Maybe, this model is more plausible. 'intensity' should be set to about around 100. 154 | WARNING: This encoder can be used when batch size is one (non-batch). 155 | :param time: 156 | :param dt: 157 | :param kwargs: 158 | """ 159 | super().__init__(time, dt=dt, rest=rest, th=th, ref=ref, tc_decay=tc_decay, **kwargs) 160 | self.enc = lif_v2 161 | -------------------------------------------------------------------------------- /wbn/examples/simple_mln.py: -------------------------------------------------------------------------------- 1 | from ..snnlib import Spiking, PoissonEncoder 2 | from typing import List 3 | 4 | def DiehlCook_unsupervised_model( 5 | obs_time=250, 6 | num_of_exc=100, 7 | init_weights_scale=0.3, 8 | learning_rate=(1e-4, 1e-2), 9 | weight_norm=78.4, 10 | inh_w=-128, 11 | tr_size=60000, 12 | ts_size=10000, 13 | encoder=PoissonEncoder, 14 | intensity=128, 15 | gpu=False, 16 | epochs=5, 17 | plt_wmp=True, 18 | plt_history=True, 19 | plt_result_spikes=True, 20 | debug=True, 21 | **kwargs 22 | ): 23 | """ 24 | Sample code: Diehl and Cook model using unsupervised STDP label assignment. 
25 | (This is a backup model) 26 | :param obs_time: 27 | :param num_of_exc: 28 | :param init_weights_scale: 29 | :param learning_rate: 30 | :param weight_norm: 31 | :param inh_w: 32 | :param tr_size: 33 | :param ts_size: 34 | :param encoder: 35 | :param intensity: 36 | :param gpu: 37 | :param epochs: 38 | :param plt_wmp: 39 | :param plt_history: 40 | :param plt_result_spikes: 41 | :param debug: 42 | :return: 43 | """ 44 | # Build SNNs and decide the number of input neurons and the simulation time. 45 | snn = Spiking(input_l=784, obs_time=obs_time) 46 | 47 | # Add a layer and give the num of neurons and the neuron model. 48 | snn.add_layer(n=num_of_exc, 49 | node=snn.ADAPTIVE_LIF, # or snn.LIF 50 | w=snn.W_SIMPLE_RAND, # initialize weights 51 | scale=init_weights_scale, # scale of random intensity 52 | rule=snn.SIMPLE_STDP, # learning rule 53 | nu=learning_rate, # learning rate 54 | norm=weight_norm, # L1 weight normalization term 55 | wmax=kwargs.get('wmax', 1.0), 56 | wmin=kwargs.get('wmin', -1.0), 57 | ) 58 | 59 | # Add an inhibitory layer 60 | snn.add_inhibit_layer(inh_w=inh_w) 61 | 62 | # Load dataset 63 | snn.load_MNIST(encoder=encoder, intensity=intensity, min_lim=kwargs.get('min_lim', 0)) 64 | 65 | # Check your network architecture 66 | snn.print_model() 67 | 68 | # Gpu is available?? If available, make it use. 69 | if gpu: 70 | snn.to_gpu() 71 | else: 72 | snn.gpu = False 73 | 74 | # Plot weight maps before training 75 | if plt_wmp: 76 | snn.plot(plt_type='wmps', prefix='0', f_shape=kwargs.get('wmp_shape', (3, 3))) 77 | 78 | # Make my network run 79 | for i in range(epochs): 80 | snn.run(tr_size=tr_size, # training data size 81 | unsupervised=True, # do unsupervised learning? 82 | # alpha=0.8, # assignment decay 83 | debug=debug, # Do you wanna watch neuron's assignments? 
84 | interval=250, # interval of assignment 85 | ts_size=ts_size, # If you have little time for experiments, be able to reduce test size 86 | ) 87 | if plt_wmp: 88 | snn.plot(plt_type='wmps', prefix='{}'.format(i+1), f_shape=kwargs.get('wmp_shape', (3, 3))) # plot maps 89 | 90 | # Plot test accuracy transition 91 | if plt_history: 92 | snn.plot(plt_type='history', prefix='result') 93 | 94 | # Plot weight maps after training 95 | if plt_wmp: 96 | snn.plot(plt_type='wmps', prefix='result', f_shape=kwargs.get('wmp_shape', (3, 3))) 97 | 98 | # Plot output spike trains after training 99 | if plt_result_spikes: 100 | snn.plot(plt_type='sp', range=10) 101 | 102 | 103 | def MultiLayerNetwork_unsupervised_model( 104 | obs_time=250, 105 | layers: List[int] = None, 106 | init_weights_scale=0.3, 107 | learning_rate=(1e-4, 1e-2), 108 | weight_norm=78.4, 109 | inh_w=-128, 110 | tr_size=60000, 111 | ts_size=10000, 112 | gpu=False, 113 | epochs=5, 114 | plt_wmp=True, 115 | plt_history=True, 116 | plt_result_spikes=True, 117 | debug=True, 118 | ): 119 | """ 120 | Simulate simple MultiLayer SNN with only full-connections. 121 | :param obs_time: 122 | :param layers: list[100, 100] is default 123 | :param init_weights_scale: 124 | :param learning_rate: 125 | :param weight_norm: 126 | :param inh_w: 127 | :param tr_size: 128 | :param ts_size: 129 | :param gpu: 130 | :param epochs: 131 | :param plt_wmp: 132 | :param plt_history: 133 | :param plt_result_spikes: 134 | :param debug: 135 | :return: 136 | """ 137 | # Build SNNs and decide the number of input neurons and the simulation time. 138 | snn = Spiking(input_l=784, obs_time=obs_time) 139 | 140 | for num in layers: 141 | # Add a layer and give the num of neurons and the neuron model. 
142 | snn.add_layer(n=num, 143 | node=snn.ADAPTIVE_LIF, # or snn.LIF 144 | w=snn.W_SIMPLE_RAND, # initialize weights 145 | scale=init_weights_scale, # scale of random intensity 146 | rule=snn.SIMPLE_STDP, # learning rule 147 | nu=learning_rate, # learning rate 148 | norm=weight_norm, # L1 weight normalization term 149 | ) 150 | 151 | # Add an inhibitory layer 152 | snn.add_inhibit_layer(inh_w=inh_w) 153 | 154 | # Load dataset 155 | snn.load_MNIST() 156 | 157 | # Check your network architecture 158 | snn.print_model() 159 | 160 | # Gpu is available?? If available, make it use. 161 | if gpu: 162 | snn.to_gpu() 163 | else: 164 | snn.gpu = False 165 | 166 | # Plot weight maps before training 167 | if plt_wmp: 168 | for l, n in layers: 169 | snn.plot(plt_type='wmps', 170 | prefix='node{}-node{}_0'.format(l, l+1), 171 | layer=l+1) 172 | 173 | # Make my network run 174 | for i in range(epochs): 175 | snn.run(tr_size=tr_size, # training data size 176 | unsupervised=True, # do unsupervised learning? 177 | # alpha=0.8, # assignment decay 178 | debug=debug, # Do you wanna watch neuron's assignments? 
179 | interval=250, # interval of assignment 180 | ts_size=ts_size, # If you have little time for experiments, be able to reduce test size 181 | ) 182 | if plt_wmp: 183 | for l, n in layers: 184 | snn.plot(plt_type='wmps', 185 | prefix='node{}-node{}_{}'.format(l, l + 1, i + 1), 186 | layer=l + 1) 187 | 188 | # Plot test accuracy transition 189 | if plt_history: 190 | snn.plot(plt_type='history', prefix='result') 191 | 192 | # Plot weight maps after training 193 | if plt_wmp: 194 | for l, n in layers: 195 | snn.plot(plt_type='wmps', 196 | prefix='node{}-node{}_result'.format(l, l + 1), 197 | layer=l + 1) 198 | 199 | # Plot output spike trains after training 200 | if plt_result_spikes: 201 | snn.plot(plt_type='sp', range=10) 202 | 203 | -------------------------------------------------------------------------------- /wbn/snnlib.py: -------------------------------------------------------------------------------- 1 | """ 2 | snnlib.py 3 | 4 | @description A tiny library to use BindsNET easily. 
5 | @author Hiroshi ARAKI 6 | @source https://github.com/HiroshiARAKI/snnlibpy 7 | @contact araki@hirlab.net 8 | @Website https://hirlab.net 9 | """ 10 | 11 | from .additional_encoders import FixedFrequencyEncoder, LIFEncoder, LIFEncoder2 12 | 13 | import torch 14 | import torchvision.transforms as transforms 15 | from torch.utils.data import DataLoader 16 | 17 | from bindsnet.network import Network 18 | from bindsnet.network.nodes import (Nodes, Input, LIFNodes, IFNodes, IzhikevichNodes, 19 | SRM0Nodes, DiehlAndCookNodes, AdaptiveLIFNodes) 20 | from bindsnet.network.topology import Connection 21 | from bindsnet.network.monitors import Monitor 22 | from bindsnet.analysis.plotting import plot_spikes 23 | from bindsnet.learning import PostPre, NoOp, WeightDependentPostPre 24 | from bindsnet.encoding import PoissonEncoder, SingleEncoder, RepeatEncoder, RankOrderEncoder, BernoulliEncoder 25 | from bindsnet.datasets import MNIST 26 | from bindsnet.evaluation import assign_labels, all_activity, proportion_weighting 27 | 28 | from tqdm import tqdm 29 | import matplotlib.pyplot as plt 30 | import numpy as np 31 | import os 32 | import sys 33 | from time import time 34 | 35 | __version__ = '0.2.6' 36 | 37 | 38 | class Spiking: 39 | """ 40 | The Class to simulate Spiking Neural Networks. 
41 | """ 42 | 43 | # ======= Constants ======= # 44 | 45 | LIF = LIFNodes 46 | IF = IFNodes 47 | IZHIKEVICH = IzhikevichNodes 48 | SRM = SRM0Nodes 49 | DIEHL_COOK = DiehlAndCookNodes 50 | ADAPTIVE_LIF = AdaptiveLIFNodes 51 | 52 | POISSON = PoissonEncoder 53 | SINGLE = SingleEncoder 54 | REPEAT = RepeatEncoder 55 | RANK_ORDER = RankOrderEncoder 56 | BERNOULI = BernoulliEncoder 57 | FIXED_FREQUENCY = FixedFrequencyEncoder 58 | LIF_ENCODER = LIFEncoder 59 | LIF_ENCODER_2 = LIFEncoder2 60 | 61 | NO_STDP: str = 'No_STDP' 62 | SIMPLE_STDP: str = 'Simple_STDP' 63 | WEIGHT_DEPENDENT_STDP: str = 'Weight_dependent_STDP' 64 | 65 | W_NORMAL_DIST: int = 0 # initialize with Normal Distribution 66 | W_RANDOM: int = 1 # initialize with Uniform Distribution [sw_min, sw_max] 67 | W_SIMPLE_RAND: int = 3 # initialize Uniform Distribution[0, scale] 68 | 69 | PROJECT_ROOT: str = os.getcwd() 70 | IMAGE_DIR: str = PROJECT_ROOT + '/images/' 71 | 72 | DPI: int = 150 # the dpi value of plt.savefig() 73 | 74 | rest_voltage = -65 # [mV] resting potential 75 | reset_voltage = -60 # [mV] reset potential 76 | threshold = -52 # [mV] firing threshold 77 | refractory_period = 5 # [ms] refractory period 78 | 79 | # if you encode images with poisson process, then [Hz] firing rate of input spikes 80 | # this intensity should be changed by the Encoder you use 81 | intensity: float = 128 82 | 83 | seed = 0 # a seed of random 84 | 85 | # ======================== # 86 | 87 | def __init__(self, input_l: int, obs_time: int = 300, dt: float = 1.0, 88 | input_shape=(1, 28, 28)): 89 | """ 90 | Constructor: Build SNN easily. Initialize many variables in backend. 91 | :param input_l: 92 | :param obs_time: 93 | """ 94 | print('You Called Spiking Neural Networks Library "WBN"!!') 95 | print('=> WrappedBindsNET (This Library) :version. 
%s' % __version__) 96 | 97 | self.network: Network = Network(dt=dt) # Core of SNN 98 | 99 | self.layer_index = 0 # index of last fc-layer 100 | self.input_l = input_l # num of input layer neurons 101 | self.monitors = {} # monitors to manege spikes and voltages activities 102 | self.T = obs_time # simulation time (duration) 103 | self.dt = dt # time step 104 | self.input_layer_name = 'in' 105 | 106 | self.train_data = None 107 | self.test_data = None 108 | self.train_loader = None 109 | self.test_loader = None 110 | self.train_data_num = None 111 | self.test_data_num = None 112 | 113 | self.batch = 1 114 | self.label_num = 0 115 | self.layer_names = [] 116 | self.history = { 117 | 'test_acc': [], 'train_acc': [], 118 | 'test_pro': [], 'train_pro': [], 119 | } 120 | 121 | self.gpu = torch.cuda.is_available() # Is GPU available? 122 | 123 | self.workers = self.gpu * 4 * torch.cuda.device_count() 124 | 125 | np.random.seed(self.seed) 126 | 127 | if self.gpu: 128 | torch.cuda.manual_seed_all(self.seed) 129 | else: 130 | torch.manual_seed(self.seed) 131 | 132 | self.assignments = None 133 | self.proportions = None 134 | self.rates = None 135 | 136 | input_layer = Input(n=input_l, traces=True, shape=input_shape) 137 | 138 | self.network.add_layer(layer=input_layer, name=self.input_layer_name) 139 | self.layer_names.append(self.input_layer_name) 140 | 141 | # information of the last added layer 142 | self.pre = { 143 | 'layer': input_layer, 144 | 'name': self.input_layer_name 145 | } 146 | 147 | monitor = Monitor( 148 | obj=input_layer, 149 | state_vars=('s',), 150 | time=self.T 151 | ) 152 | 153 | self.monitors[self.input_layer_name] = monitor 154 | 155 | self.network.add_monitor(monitor=monitor, name=self.input_layer_name) 156 | 157 | def add_layer(self, n: int, name='', node: Nodes = LIF, 158 | w=W_NORMAL_DIST, rule=SIMPLE_STDP, 159 | wmax: float = 1, wmin: float = -1, norm: float = 78.4, 160 | **kwargs): 161 | """ 162 | Add a full connection layer that consists LIF 
neuron. 163 | :param n: 164 | :param name: 165 | :param node: 166 | :param w: 167 | :param rule: 168 | :param wmax: 169 | :param wmin: 170 | :param norm: 171 | :param kwargs: nu (learning rate of STDP), mu, sigma, w_max and w_min are available 172 | :return: 173 | """ 174 | 175 | layer = node(n=n, 176 | traces=True, 177 | rest=self.rest_voltage, 178 | restet=self.reset_voltage, 179 | thresh=self.threshold, 180 | refrac=self.refractory_period, 181 | tc_decay=kwargs.get('tc_decay', 100.0), 182 | theta_plus=kwargs.get('theta_plus', 0.05), 183 | tc_theta_decay=kwargs.get('tc_theta_decay', 1e7), 184 | **kwargs 185 | ) 186 | 187 | if name == '' or name is None: 188 | name = 'fc-' + str(self.layer_index) 189 | self.layer_index += 1 190 | 191 | if type(w) is int: 192 | if w is self.W_NORMAL_DIST: 193 | mu = 0.3 if 'mu' not in kwargs else kwargs['mu'] 194 | sigma = 0.3 if 'sigma' not in kwargs else kwargs['sigma'] 195 | 196 | w = self.weight_norm(self.pre['layer'].n, layer.n, 197 | mu=mu, sigma=sigma) 198 | if w is self.W_RANDOM: 199 | w_max = 0.5 if 'sw_max' not in kwargs else kwargs['sw_max'] 200 | w_min = -0.5 if 'sw_min' not in kwargs else kwargs['sw_min'] 201 | 202 | w = self.weight_rand(self.pre['layer'].n, layer.n, 203 | w_max=w_max, w_min=w_min) 204 | if w is self.W_SIMPLE_RAND: 205 | if 'scale' not in kwargs: 206 | scale = 0.3 207 | else: 208 | scale = kwargs['scale'] 209 | w = self.weight_simple_rand(self.pre['layer'].n, layer.n, scale) 210 | 211 | self.network.add_layer(layer=layer, name=name) 212 | self.layer_names.append(name) 213 | 214 | if 'nu' not in kwargs: 215 | nu = (1e-4, 1e-2) 216 | else: 217 | nu = kwargs['nu'] 218 | 219 | if rule == self.SIMPLE_STDP: 220 | l_rule = PostPre 221 | elif rule == self.WEIGHT_DEPENDENT_STDP: 222 | l_rule = WeightDependentPostPre 223 | elif rule == self.NO_STDP: 224 | l_rule = NoOp 225 | else: 226 | l_rule = NoOp 227 | 228 | connection = Connection( 229 | source=self.pre['layer'], 230 | target=layer, 231 | w=w, 232 | 
wmax=wmax, 233 | wmin=wmin, 234 | update_rule=l_rule, 235 | nu=nu, 236 | norm=norm 237 | ) 238 | 239 | self.network.add_connection(connection, 240 | source=self.pre['name'], 241 | target=name,) 242 | monitor = Monitor( 243 | obj=layer, 244 | state_vars=('s', 'v'), 245 | time=self.T 246 | ) 247 | 248 | self.monitors[name] = monitor 249 | 250 | self.network.add_monitor(monitor=monitor, name=name) 251 | 252 | self.pre['layer'] = layer 253 | self.pre['name'] = name 254 | 255 | print('-- Added', name, 'with the Learning rule,', rule) 256 | 257 | def add_inhibit_layer(self, n: int = None, name='', node: Nodes = LIF, 258 | exc_w: float = 22.5, inh_w: float = -100): 259 | """ 260 | Add an inhibitory layer behind the last layer. 261 | If you added this layer, you can add layers more behind a last normal layer, not an inhibitory layer. 262 | :param n: 263 | :param name: 264 | :param node: 265 | :param exc_w: 266 | :param inh_w: 267 | :return: 268 | """ 269 | if n is None: 270 | n = self.pre['layer'].n 271 | 272 | layer = node( 273 | n=n, 274 | traces=False, 275 | rest=self.rest_voltage, 276 | restet=self.reset_voltage, 277 | thresh=self.threshold, 278 | refrac=self.refractory_period, 279 | ) 280 | 281 | if name == '' or name is None: 282 | name = 'inh[' + self.pre['name'] + ']' 283 | self.layer_index += 1 284 | 285 | self.network.add_layer(layer=layer, name=name) 286 | 287 | n_neurons = self.pre['layer'].n 288 | 289 | # 最終層 - 即抑制層の接続 290 | w = exc_w * torch.diag(torch.ones(n_neurons)) 291 | last_to_inh_conn = Connection( 292 | source=self.pre['layer'], 293 | target=layer, 294 | w=w, 295 | wmin=0, 296 | wmax=exc_w, 297 | ) 298 | 299 | # 即抑制層 - 最終層の接続 300 | w = inh_w * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons))) 301 | inh_to_last_conn = Connection( 302 | source=layer, 303 | target=self.pre['layer'], 304 | w=w, 305 | wmin=inh_w, 306 | wmax=0, 307 | ) 308 | 309 | self.network.add_connection(last_to_inh_conn, 310 | source=self.pre['name'], 311 | 
target=name) 312 | self.network.add_connection(inh_to_last_conn, 313 | source=name, 314 | target=self.pre['name']) 315 | 316 | print('-- Added', name, 'as an inhibitory layer') 317 | 318 | def add_recurrent_layer_exc_inh(self, exc_n: int, inh_n: int, 319 | name: str = ('rec_exc', 'rec_inh'), 320 | node: Nodes = LIF, 321 | init_exc_w: float = 0.3, init_inh_w: float = -0.5, 322 | exc_nu=(1e-4, 1e-2), inh_nu=(1e-2, 1e-4), 323 | rule: bool = SIMPLE_STDP, **kwargs): 324 | """ 325 | Create a Recurrent connection layer 326 | :param exc_n: 327 | :param inh_n: 328 | :param name: 329 | :param node: 330 | :param init_exc_w: 331 | :param init_inh_w: 332 | :param exc_nu: 333 | :param inh_nu: 334 | :param rule: 335 | :param kwargs: 336 | :return: 337 | """ 338 | # exc-neuron layer 339 | exc_l = node(n=exc_n, 340 | traces=True, 341 | rest=kwargs.get('exc_rest', self.rest_voltage), 342 | restet=kwargs.get('exc_reset', self.reset_voltage), 343 | thresh=kwargs.get('exc_th', self.threshold), 344 | refrac=kwargs.get('exc_ref', self.refractory_period), 345 | tc_decay=kwargs.get('exc_tc_decay', 100.0), 346 | theta_plus=kwargs.get('exc_theta_plus', 0.05), 347 | tc_theta_decay=kwargs.get('exc_tc_theta_decay', 1e7), 348 | **kwargs 349 | ) 350 | # inh-neuron layer 351 | inh_l = node(n=inh_n, 352 | traces=True, 353 | rest=kwargs.get('inh_rest', self.rest_voltage), 354 | restet=kwargs.get('inh_reset', self.reset_voltage), 355 | thresh=kwargs.get('inh_th', self.threshold), 356 | refrac=kwargs.get('inh_ref', self.refractory_period), 357 | tc_decay=kwargs.get('inh_tc_decay', 100.0), 358 | theta_plus=kwargs.get('inh_theta_plus', 0.05), 359 | tc_theta_decay=kwargs.get('inh_tc_theta_decay', 1e7), 360 | **kwargs 361 | ) 362 | 363 | self.network.add_layer(layer=exc_l, name=name[0]) 364 | self.network.add_layer(layer=inh_l, name=name[1]) 365 | 366 | if rule == self.SIMPLE_STDP: 367 | l_rule = PostPre 368 | elif rule == self.WEIGHT_DEPENDENT_STDP: 369 | l_rule = WeightDependentPostPre 370 | elif rule 
== self.NO_STDP: 371 | l_rule = NoOp 372 | else: 373 | l_rule = NoOp 374 | 375 | """ Normal Connection (input to exc & inh)""" 376 | input_exc_connection = Connection( 377 | source=self.pre['layer'], # input 378 | target=exc_l, 379 | w=self.weight_simple_rand(n=self.pre['layer'].n, m=exc_n, scale=kwargs.get('scale', 0.3)), 380 | wmax=kwargs.get('wmax', 1.0), 381 | wmin=kwargs.get('wmin', -1.0), 382 | update_rule=l_rule, 383 | nu=kwargs.get('nu', (1e-4, 1e-2)), 384 | norm=kwargs.get('norm', 78.4) 385 | ) 386 | input_inh_connection = Connection( 387 | source=self.pre['layer'], # input 388 | target=inh_l, 389 | w=self.weight_simple_rand(n=self.pre['layer'].n, m=inh_n, scale=kwargs.get('scale', 0.3)), 390 | wmax=kwargs.get('wmax', 1.0), 391 | wmin=kwargs.get('wmin', -1.0), 392 | update_rule=l_rule, 393 | nu=kwargs.get('nu', (1e-4, 1e-2)), 394 | norm=kwargs.get('norm', 78.4) 395 | ) 396 | 397 | self.network.add_connection(input_exc_connection, 398 | source=self.pre['name'], 399 | target=name[0], ) 400 | self.network.add_connection(input_inh_connection, 401 | source=self.pre['name'], 402 | target=name[1], ) 403 | 404 | """ Reccurent Connecion (exc and inh)""" 405 | exc_inh_connection = Connection( 406 | source=exc_l, 407 | target=inh_l, 408 | w=self.weight_simple_rand(exc_n, inh_n, init_exc_w), 409 | wmax=kwargs.get('exc_wmax', 1.0), 410 | wmin=kwargs.get('exc_wmin', 0), 411 | update_rule=l_rule, 412 | nu=exc_nu, 413 | norm=kwargs.get('exc_norm', exc_n / 10) 414 | ) 415 | inh_exc_connection = Connection( 416 | source=inh_l, 417 | target=exc_l, 418 | w=self.weight_simple_rand(inh_n, exc_n, init_inh_w), 419 | wmax=kwargs.get('inh_wmax', 0), 420 | wmin=kwargs.get('inh_wmin', -1.0), 421 | update_rule=l_rule, 422 | nu=inh_nu, 423 | norm=kwargs.get('inh_norm', inh_n / 10) 424 | ) 425 | 426 | self.network.add_connection( 427 | exc_inh_connection, 428 | source=name[0], 429 | target=name[1] 430 | ) 431 | self.network.add_connection( 432 | inh_exc_connection, 433 | source=name[1], 
434 | target=name[0] 435 | ) 436 | 437 | """ Setting monitors """ 438 | exc_monitor = Monitor( 439 | obj=exc_l, 440 | state_vars=('s', 'v'), 441 | time=self.T 442 | ) 443 | inh_monitor = Monitor( 444 | obj=inh_l, 445 | state_vars=('s', 'v'), 446 | time=self.T 447 | ) 448 | 449 | self.monitors[name[0]] = exc_monitor 450 | self.monitors[name[1]] = inh_monitor 451 | self.network.add_monitor(monitor=exc_monitor, name=name[0]) 452 | self.network.add_monitor(monitor=inh_monitor, name=name[1]) 453 | 454 | self.pre['layer'] = exc_l 455 | self.pre['name'] = name[0] 456 | print(' -- Added Reccurent layer (exc: {}, inh: {})'.format(exc_n, inh_n)) 457 | 458 | def load_MNIST(self, batch: int = 1, 459 | encoder=PoissonEncoder, 460 | intensity=intensity, 461 | **kwargs): 462 | """ 463 | Load MNIST dataset from pyTorch. 464 | :param batch: 465 | :param encoder: 466 | :param intensity: 467 | :return: 468 | """ 469 | self.batch = batch 470 | self.train_data_num = 60000 471 | self.test_data_num = 10000 472 | self.label_num = 10 473 | self.train_data = MNIST(encoder(time=self.T, dt=self.dt), 474 | None, 475 | root=self.PROJECT_ROOT+'/data/mnist', 476 | train=True, 477 | download=True, 478 | transform=transforms.Compose( 479 | [transforms.ToTensor(), 480 | transforms.Lambda(lambda x: x*intensity + kwargs.get('min_lim', 0))] 481 | )) 482 | self.test_data = MNIST(encoder(time=self.T, dt=self.dt), 483 | None, 484 | root=self.PROJECT_ROOT+'/data/mnist', 485 | train=False, 486 | download=True, 487 | transform=transforms.Compose( 488 | [transforms.ToTensor(), 489 | transforms.Lambda(lambda x: x*intensity + kwargs.get('min_lim', 0))] 490 | )) 491 | 492 | self.train_loader = DataLoader(self.train_data, 493 | batch_size=batch, 494 | shuffle=True, 495 | pin_memory=self.gpu, 496 | num_workers=self.workers) 497 | self.test_loader = DataLoader(self.test_data, 498 | batch_size=batch, 499 | shuffle=False, 500 | pin_memory=self.gpu, 501 | num_workers=self.workers) 502 | 503 | def run(self, 504 | 
tr_size=None, 505 | unsupervised: bool = True, 506 | alpha: float = 1.0, 507 | interval: int = 250, 508 | debug: bool = False, 509 | **kwargs): 510 | """ 511 | Let the Network run simply. 512 | :param tr_size: 513 | :param unsupervised: 514 | :param alpha: 515 | :param interval: 516 | :param debug: 517 | :return: 518 | """ 519 | if unsupervised: 520 | print('') 521 | if tr_size is None: 522 | tr_size = int(self.train_data_num / self.batch) 523 | else: 524 | tr_size = int(tr_size / self.batch) 525 | 526 | n_out_neurons = self.pre['layer'].n 527 | 528 | assignments = -torch.ones(n_out_neurons) 529 | spikes = torch.zeros(interval, self.T, n_out_neurons) 530 | labels = [] 531 | proportions = torch.zeros(n_out_neurons, self.label_num) 532 | rates = None 533 | 534 | progress = tqdm(enumerate(self.train_loader)) 535 | start = time() 536 | for i, data in progress: 537 | progress.set_description_str('\rProgress: %d / %d (%.4f seconds)' % (i, tr_size, time() - start)) 538 | 539 | inputs_img = {'in': data['encoded_image'].view(int(self.T/self.dt), self.batch, 1, 28, 28)} 540 | 541 | # assign labels 542 | if unsupervised and i % interval == 0 and i > 0: 543 | t_labels = torch.tensor(labels) # to tensor 544 | # Get the assignments of output neurons 545 | assignments, proportions, rates = assign_labels( 546 | spikes=spikes, 547 | labels=t_labels, 548 | n_labels=self.label_num, 549 | rates=rates, 550 | alpha=alpha 551 | ) 552 | labels = [] 553 | 554 | if self.gpu: 555 | inputs_img = {key: img.cuda() for key, img in inputs_img.items()} 556 | 557 | # run! 
558 | self.network.run(inputs_img, time=self.T, input_time_dim=1) 559 | 560 | if unsupervised: 561 | # labels used by assigning 562 | labels.append(data['label']) 563 | # output spike trains 564 | spikes[i % interval] = self.monitors[self.pre['name']].get('s').squeeze() 565 | 566 | self.network.reset_state_variables() 567 | 568 | if i >= tr_size: # if reach training size you specified, break for loop 569 | break 570 | 571 | print('Progress: %d / %d (%.4f seconds)' % (tr_size, tr_size, time() - start)) 572 | 573 | # compute train. and test accuracies 574 | if unsupervised: 575 | print('Computing accuracies...') 576 | 577 | if debug: 578 | print('\n[Neurons assignments]') 579 | print(assignments) 580 | 581 | self.stop_learning() 582 | 583 | train_acc, train_pro = self.calc_train_accuracy(assignments=assignments, 584 | proportions=proportions, 585 | tr_size=tr_size) 586 | 587 | test_acc, test_pro = self.calc_test_accuracy(assignments=assignments, 588 | proportions=proportions, 589 | ts_size=kwargs.get('ts_size', None)) 590 | 591 | self.history['train_acc'].append(train_acc) 592 | self.history['train_pro'].append(train_pro) 593 | self.history['test_acc'].append(test_acc) 594 | self.history['test_pro'].append(test_pro) 595 | 596 | print('\n*** Train accuracy is %4f ***' % self.history['train_acc'][-1]) 597 | print('*** Train accuracy with proportions is %4f ***' % self.history['train_pro'][-1]) 598 | print('*** Test accuracy is %4f ***' % self.history['test_acc'][-1]) 599 | print('*** Test accuracy with proportions is %4f ***\n' % self.history['test_pro'][-1]) 600 | 601 | self.start_learning() 602 | 603 | print('===< Have finished running the network >===\n') 604 | 605 | def calc_test_accuracy(self, 606 | assignments: torch.Tensor, 607 | proportions: torch.Tensor = None, 608 | ts_size: int = None) -> (float, float): 609 | """ 610 | Calculate test accuracy with the assignment. 
611 | :param assignments: 612 | :param proportions: 613 | :param ts_size: 614 | :return: 615 | """ 616 | 617 | if ts_size is None: 618 | ts_size = self.test_data_num 619 | 620 | n_neurons = self.pre['layer'].n 621 | 622 | if proportions is None: 623 | proportions = torch.ones(n_neurons, self.label_num).float() 624 | 625 | interval = 250 626 | spike_record = torch.zeros(interval, self.T, n_neurons) 627 | labels = [] 628 | 629 | count_correct = {'acc': 0, 'pro': 0} 630 | 631 | progress = tqdm(enumerate(self.test_loader)) 632 | print('\n===< Calculate Test accuracy >===') 633 | for i, data in progress: 634 | progress.set_description_str('\rCalculate Test accuracy ... %d / %d ' % (i, ts_size)) 635 | inputs_img = {'in': data['encoded_image'].view(int(self.T/self.dt), self.batch, 1, 28, 28)} 636 | 637 | if self.gpu: 638 | inputs_img = {key: img.cuda() for key, img in inputs_img.items()} 639 | 640 | if i % interval == 0 and i > 0: 641 | # Convert the array of labels into a tensor 642 | label_tensor = torch.tensor(labels) 643 | 644 | # Get network predictions. 645 | all_activity_pred = all_activity( 646 | spikes=spike_record, assignments=assignments, n_labels=self.label_num 647 | ) 648 | proportion_pred = proportion_weighting( 649 | spikes=spike_record, 650 | assignments=assignments, 651 | proportions=proportions, 652 | n_labels=self.label_num, 653 | ) 654 | 655 | count_correct['acc'] += torch.sum(label_tensor.long() == all_activity_pred).item() 656 | count_correct['pro'] += torch.sum(label_tensor.long() == proportion_pred).item() 657 | 658 | labels = [] 659 | 660 | # get labels 661 | labels.append(data["label"]) 662 | 663 | # run! 664 | self.network.run(inputs=inputs_img, time=self.T) 665 | 666 | # get output spike trains 667 | spike_record[i % interval] = self.monitors[self.pre['name']].get("s").squeeze() 668 | 669 | self.network.reset_state_variables() 670 | 671 | if i >= ts_size: 672 | break 673 | 674 | print('\r ... 
done!') 675 | return (float(count_correct['acc']) / float(ts_size), # accuracy 676 | float(count_correct['pro']) / float(ts_size)) # accuracy with proportions 677 | 678 | def calc_train_accuracy(self, 679 | assignments: torch.Tensor, 680 | proportions: torch.Tensor = None, 681 | tr_size: int = None) -> (float, float): 682 | """ 683 | Calculate train accuracy with the assignment. 684 | :param assignments: 685 | :param proportions: 686 | :param tr_size: 687 | :return: 688 | """ 689 | if tr_size is None: 690 | tr_size = self.train_data_num 691 | 692 | n_neurons = self.pre['layer'].n 693 | 694 | if proportions is None: 695 | proportions = torch.ones(n_neurons, self.label_num).float() 696 | 697 | interval = 250 698 | spike_record = torch.zeros(interval, self.T, n_neurons) 699 | labels = [] 700 | 701 | count_correct = {'acc': 0, 'pro': 0} 702 | 703 | progress = tqdm(enumerate(self.train_loader)) 704 | print('\n===< Calculate Training accuracy >===') 705 | for i, data in progress: 706 | progress.set_description_str('\rCalculate Training accuracy ... %d / %d ' % (i, tr_size)) 707 | inputs_img = {'in': data['encoded_image'].view(int(self.T/self.dt), self.batch, 1, 28, 28)} 708 | 709 | if self.gpu: 710 | inputs_img = {key: img.cuda() for key, img in inputs_img.items()} 711 | 712 | if i % interval == 0 and i > 0: 713 | # Convert the array of labels into a tensor 714 | label_tensor = torch.tensor(labels) 715 | 716 | # Get network predictions. 
717 | all_activity_pred = all_activity( 718 | spikes=spike_record, assignments=assignments, n_labels=self.label_num 719 | ) 720 | proportion_pred = proportion_weighting( 721 | spikes=spike_record, 722 | assignments=assignments, 723 | proportions=proportions, 724 | n_labels=self.label_num, 725 | ) 726 | 727 | count_correct['acc'] += torch.sum(label_tensor.long() == all_activity_pred).item() 728 | count_correct['pro'] += torch.sum(label_tensor.long() == proportion_pred).item() 729 | 730 | labels = [] 731 | 732 | # get labels 733 | labels.append(data["label"]) 734 | 735 | # run! 736 | self.network.run(inputs=inputs_img, time=self.T) 737 | 738 | # get output spike trains 739 | spike_record[i % interval] = self.monitors[self.pre['name']].get('s').squeeze() 740 | 741 | self.network.reset_state_variables() 742 | 743 | if i >= tr_size: 744 | break 745 | 746 | print('\r ... done!') 747 | return (float(count_correct['acc']) / float(tr_size), # accuracy 748 | float(count_correct['pro']) / float(tr_size)) # accuracy with proportions 749 | 750 | def test(self, data_num: int): 751 | """ 752 | Calculate test accuracy with the label assignment used training data. 753 | :param data_num: 754 | :return accuracy: 755 | """ 756 | # Stop learning 757 | for layer in self.network.layers: 758 | self.network.layers[layer].train(False) 759 | 760 | # the assignments of output neurons 761 | assignment = torch.zeros(self.label_num, self.pre['layer'].n) 762 | 763 | print('===< Calculate train spikes and assign labels >===') 764 | progress = tqdm(enumerate(self.train_loader)) 765 | for i, data in progress: 766 | progress.set_description_str('\rAssign labels... %d / %d ' % (i, data_num)) 767 | inputs_img = {'in': data['encoded_image'].view(int(self.T/self.dt), self.batch, 1, 28, 28)} 768 | 769 | if self.gpu: 770 | inputs_img = {key: img.cuda() for key, img in inputs_img.items()} 771 | # run! 
772 | self.network.run(inputs=inputs_img, time=self.T) 773 | 774 | # output spike trains 775 | spikes: torch.Tensor = self.monitors[self.pre['name']].get('s') 776 | 777 | # sum of the number of spikes 778 | sum_spikes = spikes.sum(0) 779 | 780 | max_n_fire = sum_spikes.argmax(1) 781 | labels = data['label'] 782 | 783 | for j, l in enumerate(labels): 784 | assignment[l][max_n_fire[j]] += 1 785 | 786 | self.network.reset_state_variables() 787 | 788 | if i >= data_num: 789 | break 790 | 791 | # this result is assignment of output neurons 792 | assignment = assignment.argmax(0) 793 | 794 | # Calculate accuracy 795 | labels_rate = torch.zeros(self.label_num).float() # each firing rate of labels 796 | count_correct = 0 797 | progress = tqdm(enumerate(self.test_loader)) 798 | print('\n===< Calculate Test accuracy >===') 799 | for i, data in progress: 800 | progress.set_description_str('\rCalculate Test accuracy ... %d / %d ' % (i, self.test_data_num)) 801 | inputs_img = {'in': data['encoded_image'].view(int(self.T/self.dt), self.batch, 1, 28, 28)} 802 | 803 | if self.gpu: 804 | inputs_img = {key: img.cuda() for key, img in inputs_img.items()} 805 | # run! 806 | self.network.run(inputs=inputs_img, time=self.T) 807 | 808 | # output spike trains 809 | spikes: torch.Tensor = self.monitors[self.pre['name']].get('s') 810 | 811 | # sum of the number of spikes 812 | sum_spikes = spikes.sum(0) 813 | self.network.reset_state_variables() 814 | 815 | for b in range(self.batch): 816 | for l in range(self.label_num): 817 | if l in assignment: 818 | indices = torch.tensor([i for i, a in enumerate(assignment) if a == l]) 819 | count_assign = torch.sum(assignment == l) 820 | labels_rate[l] += torch.sum(sum_spikes[b][indices]).float() / count_assign.float() 821 | 822 | # if actual prediction equals desired label, increment the count. 
823 | if labels_rate.argmax() == data['label']: 824 | count_correct += 1 825 | 826 | # initialize zeros 827 | labels_rate[:] = 0 828 | 829 | acc = float(count_correct) / float(self.test_data_num) 830 | self.history['test_acc'].append(acc) 831 | 832 | print('\n*** Test accuracy is %4f ***\n' % acc) 833 | 834 | # make learning rates be back 835 | for layer in self.network.layers: 836 | self.network.layers[layer].train(True) 837 | 838 | return acc 839 | 840 | def plot_out_voltage(self, index: int, save: bool = False, 841 | file_name: str = 'out_voltage.png', dpi: int = DPI): 842 | """ 843 | Plot a membrane potential of 'index'th neuron in the final layer. 844 | :param index: 845 | :param save: 846 | :param file_name: 847 | :param dpi: 848 | :return: 849 | """ 850 | os.makedirs(self.IMAGE_DIR, exist_ok=True) 851 | 852 | voltage = self.monitors[self.pre['name']].get('v').numpy().reshape(self.T, self.pre['layer'].n).T[index] 853 | 854 | plt.title('Membrane Voltage at neuron[{0}] in final layer'.format(index)) 855 | plt.plot(voltage) 856 | plt.xlabel('time [ms]') 857 | plt.ylabel('voltage [mV]') 858 | plt.ylim(self.reset_voltage-5, self.threshold+5) 859 | if not save: 860 | plt.show() 861 | else: 862 | plt.savefig(self.IMAGE_DIR+file_name, dpi=dpi) 863 | 864 | plt.close() 865 | 866 | @staticmethod 867 | def plot_spikes_scatter(spikes, save=False, dpi=DPI, file_name='spike.png'): 868 | plt.ioff() 869 | plot_spikes(spikes) 870 | if not save: 871 | plt.show() 872 | else: 873 | plt.savefig(file_name, dpi=dpi) 874 | 875 | def plot_spikes(self, save: bool = False, index: int = 0, 876 | file_name: str = 'spikes.png', dpi: int = DPI): 877 | """ 878 | Plot spike trains of all neurons as a scatter plot. 
879 | :param save: 880 | :param index: 881 | :param file_name: 882 | :param dpi: 883 | :return: 884 | """ 885 | 886 | self.make_image_dir() 887 | self.stop_learning() 888 | 889 | data = self.train_data[index] 890 | label = data['label'] 891 | 892 | inputs_img = {'in': data['encoded_image'].view(int(self.T/self.dt), self.batch, 1, 28, 28)} 893 | if self.gpu: 894 | inputs_img = {key: img.cuda() for key, img in inputs_img.items()} 895 | 896 | self.network.run(inputs=inputs_img, time=self.T) 897 | 898 | spikes = {} 899 | for m_name in self.monitors: 900 | spikes[m_name] = self.monitors[m_name].get('s') 901 | 902 | plt.ioff() 903 | plot_spikes(spikes) 904 | if not save: 905 | plt.show() 906 | else: 907 | plt.savefig(self.IMAGE_DIR+'label_'+str(label)+file_name, dpi=dpi) 908 | plt.close() 909 | 910 | self.start_learning() 911 | 912 | def plot_poisson_img(self, image: torch.Tensor, save: bool = False, 913 | file_name='poisson_img.png', dpi: int = DPI): 914 | """ 915 | Plot a poisson image. 916 | :param image: 917 | :param save: 918 | :param file_name: 919 | :param dpi: 920 | :return: 921 | """ 922 | 923 | self.make_image_dir() 924 | 925 | result_img = np.zeros((28, 28)) 926 | for dt_spike_img in image: 927 | result_img += dt_spike_img.numpy().reshape((28, 28)) 928 | 929 | plt.imshow(result_img, cmap='winter') 930 | plt.colorbar().set_label('# of spikes') 931 | 932 | if not save: 933 | plt.show() 934 | else: 935 | plt.savefig(self.IMAGE_DIR + file_name, dpi=dpi) 936 | 937 | plt.close() 938 | 939 | def plot_output_weight_map(self, index: int, save: bool = False, 940 | file_name: str = 'weight_map.png', dpi: int = DPI, 941 | c_max: float = 1.0, c_min: float = -1.0): 942 | """ 943 | Plot an afferent weight map of the last layer's [index]th neuron. 
944 | :param index: 945 | :param save: 946 | :param file_name: 947 | :param dpi: 948 | :param c_max: max of colormap 949 | :param c_min: min of colormap 950 | :return: 951 | """ 952 | self.make_image_dir() 953 | 954 | names = self.layer_names 955 | last = len(names) - 1 956 | 957 | # Get the connection information of the last layer 958 | weight: torch.Tensor = self.network.connections[(names[last-1], names[last])].w 959 | # to ndarray and trans. 960 | weight = weight.numpy().T if not self.gpu else weight.cpu().numpy().T 961 | 962 | # to shape as (n, n) 963 | weight = weight[index].reshape(28, 28) 964 | 965 | # Set the center of a colormap zero. 966 | wmax = weight.max() if weight.max() > c_max else c_max 967 | wmin = weight.min() if weight.min() < c_min else c_min 968 | abs_max = abs(wmax) if abs(wmax) > abs(wmin) else abs(wmin) 969 | 970 | plt.imshow(weight, cmap='coolwarm', vmax=abs_max, vmin=(-abs_max)) 971 | plt.colorbar() 972 | if not save: 973 | plt.show() 974 | else: 975 | plt.savefig(self.IMAGE_DIR + file_name, dpi=dpi) 976 | plt.close() 977 | 978 | def plot_weight_maps(self, 979 | f_shape: tuple = (3, 3), 980 | layer: int = None, 981 | file_name: str = 'weight_maps.png', 982 | dpi: int = DPI, 983 | c_max: float = 1.0, 984 | c_min: float = -1.0, 985 | save: bool = True, 986 | **kwargs): 987 | """ 988 | Plot weight maps of output connection with the shape of [f_shape]. 989 | :param f_shape: 990 | :param layer: post layer index counted from 0 (default is last layer) 991 | :param file_name: 992 | :param dpi: 993 | :param c_max: 994 | :param c_min: 995 | :param save: 996 | :return: 997 | """ 998 | 999 | self.make_image_dir() 1000 | 1001 | names = self.layer_names 1002 | if layer is None: 1003 | post = len(names) - 1 1004 | else: 1005 | post = layer 1006 | 1007 | # Get the connection information of the post layer 1008 | weight: torch.Tensor = self.network.connections[(names[post - 1], names[post])].w 1009 | # to ndarray and trans. 
1010 | weight = weight.numpy().T if not self.gpu else weight.cpu().numpy().T 1011 | 1012 | # setting of figure 1013 | fig, axes = plt.subplots(ncols=f_shape[0], nrows=f_shape[1]) 1014 | 1015 | index = 0 1016 | im = None 1017 | for cols in axes: 1018 | for ax in cols: 1019 | sys.stdout.write('\rPlot weight map {}/{}'.format(index+1, f_shape[0]*f_shape[1])) 1020 | sys.stdout.flush() 1021 | 1022 | # to shape as (n, n) 1023 | n_neurons = self.network.layers[names[post - 1]].n 1024 | n = int(np.sqrt(n_neurons)) 1025 | tmp_weight = weight[index].reshape(n, n) 1026 | 1027 | # Set the center of a colormap zero. 1028 | wmax = tmp_weight.max() if tmp_weight.max() > c_max else c_max 1029 | wmin = tmp_weight.min() if tmp_weight.min() < c_min else c_min 1030 | abs_max = abs(wmax) if abs(wmax) > abs(wmin) else abs(wmin) 1031 | 1032 | im = ax.imshow(tmp_weight, cmap='coolwarm', vmax=abs_max, vmin=(-abs_max)) 1033 | # ax.set_title('map({})'.format(index)) 1034 | ax.tick_params(labelbottom=False, 1035 | labelleft=False, 1036 | labelright=False, 1037 | labeltop=False, 1038 | bottom=False, 1039 | left=False, 1040 | right=False, 1041 | top=False 1042 | ) 1043 | index += 1 1044 | print('\nPlotting... 
Done!') 1045 | fig.subplots_adjust(right=0.8) 1046 | cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7]) 1047 | fig.colorbar(im, cax=cbar_ax) 1048 | 1049 | fig.suptitle(kwargs.get('title', '')) 1050 | if not save: 1051 | plt.show() 1052 | else: 1053 | plt.savefig(self.IMAGE_DIR + file_name, dpi=dpi) 1054 | plt.close() 1055 | 1056 | def plot(self, plt_type: str, **kwargs): 1057 | """ 1058 | A Shortcut function about plotting 1059 | :param plt_type: 1060 | :param kwargs: 1061 | :return: 1062 | """ 1063 | plt.ioff() 1064 | if 'save' not in kwargs: 1065 | kwargs['save'] = True 1066 | if 'prefix' not in kwargs: 1067 | kwargs['prefix'] = '' 1068 | if 'range' not in kwargs: 1069 | kwargs['range'] = 1 1070 | 1071 | if plt_type == 'wmp': 1072 | for i in range(kwargs['range']): 1073 | self.plot_output_weight_map(index=i, 1074 | save=kwargs['save'], 1075 | file_name='{}_wmp_'.format(kwargs['prefix'])+str(i+1)+'.png') 1076 | elif plt_type == 'sp': 1077 | for i in range(kwargs['range']): 1078 | self.plot_spikes(save=kwargs['save'], 1079 | index=i) 1080 | elif plt_type == 'history': 1081 | epochs = len(self.history['train_acc']) 1082 | epochs = np.arange(1, epochs+1) 1083 | print(self.history) 1084 | plt.plot(epochs, self.history['train_acc'], label='train_acc', marker='.', c='b') 1085 | plt.plot(epochs, self.history['train_pro'], label='train_pro', marker='.', c='b', linestyle='dashed') 1086 | plt.plot(epochs, self.history['test_acc'], label='test_acc', marker='.', c='g') 1087 | plt.plot(epochs, self.history['test_pro'], label='test_pro', marker='.', c='g', linestyle='dashed') 1088 | plt.xlabel('epoch') 1089 | plt.ylabel('accuracy') 1090 | plt.legend() 1091 | if kwargs['save']: 1092 | plt.savefig(self.IMAGE_DIR + '{}_accuracies.png'.format(kwargs['prefix']), dpi=self.DPI) 1093 | else: 1094 | plt.show() 1095 | plt.close() 1096 | 1097 | elif plt_type == 'wmps': 1098 | self.plot_weight_maps(f_shape=kwargs.get('f_shape', (3, 3)), 1099 | layer=kwargs.get('layer', None), 1100 | 
file_name='{}_weight_maps.png'.format(kwargs['prefix']), 1101 | save=kwargs['save'], 1102 | title=kwargs.get('title', '')) 1103 | elif plt_type == 'p_img': 1104 | pass 1105 | elif plt_type == 'v': 1106 | pass 1107 | else: 1108 | print('Not Found the plt_type.') 1109 | 1110 | def get_train_batch(self, index) -> torch.Tensor: 1111 | return self.train_loader[index]['data'] 1112 | 1113 | def make_image_dir(self): 1114 | return os.makedirs(self.IMAGE_DIR, exist_ok=True) 1115 | 1116 | def print_model(self): 1117 | print('=============================') 1118 | print('Show your network information below.') 1119 | layers: dict[str: Nodes] = self.network.layers 1120 | print('Layers:') 1121 | for l in layers: 1122 | print(' '+l+'('+str(layers[l].n)+')', end='\n |\n') 1123 | print(' [END]') 1124 | print('=============================') 1125 | print('Simulation Time: {}, dt: {}'.format(self.T, self.dt)) 1126 | print('=============================') 1127 | 1128 | def to_gpu(self): 1129 | """ 1130 | Set gpu to the network if available. 
1131 | :return: 1132 | """ 1133 | if self.gpu: 1134 | print('GPU computing is available.') 1135 | self.network.to('cuda') 1136 | return True 1137 | else: 1138 | print('You use Only CPU computing.') 1139 | return False 1140 | 1141 | def stop_learning(self): 1142 | """ 1143 | Stop learning 1144 | :return: 1145 | """ 1146 | self.network.train(False) 1147 | 1148 | def start_learning(self): 1149 | """ 1150 | (Re)start learning 1151 | :return: 1152 | """ 1153 | self.network.train(True) 1154 | 1155 | @staticmethod 1156 | def weight_norm(n: int, m: int, mu: float = 0.3, sigma: float = 0.3) -> torch.Tensor: 1157 | return mu + sigma * torch.randn(n, m) 1158 | 1159 | @staticmethod 1160 | def weight_rand(n: int, m: int, w_max: float = 0.5, w_min: float = -0.5) -> torch.Tensor: 1161 | x = torch.rand(n, m) 1162 | x_max = x.max() 1163 | x_min = x.min() 1164 | return ((x - x_min) / (x_max - x_min)) * (w_max - w_min) + w_min 1165 | 1166 | @staticmethod 1167 | def weight_simple_rand(n: int, m: int, scale: float = 0.3) -> torch.Tensor: 1168 | return scale * torch.rand(n, m) 1169 | 1170 | --------------------------------------------------------------------------------