├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── input_700_250_25.pkl
│   ├── small_test_dvs_gesture.pickle
│   ├── small_train_dvs_gesture.pickle
│   ├── smile100.pkl
│   ├── smile30.pkl
│   ├── smile50.pkl
│   ├── smile70.pkl
│   └── smile95.pkl
├── figures
│   ├── .DS_Store
│   ├── ._torch_wage_acc_cifar10_210810.png
│   ├── ._torch_wage_acc_cifar10_21088.png
│   ├── ._torch_wage_acc_cifar10_2888.png
│   ├── ._torch_wage_acc_cifar10_310810.png
│   ├── ._utorch_wage_acc_cifar10_21088.png
│   ├── ICONS_PQ_distr.png
│   ├── ICONS_QuantSNN.png
│   ├── ICONS_curves.pdf
│   ├── ICONS_data_set_gest.png
│   ├── ICONS_data_set_poker.png
│   ├── ICONS_sur.png
│   ├── ICONS_unscatter.pdf
│   ├── ISCAS_schem1.png
│   └── ISCAS_smile_black.png
├── localQ.py
├── prepGesture.py
├── prepPoker.py
├── qsnn_decolle.py
├── qsnn_precise.py
├── qsnn_util.py
└── quantization.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/*
results/*
.DS_Store
.README.md.icloud
data/train_dvs_gesture.pickle
data/test_dvs_gesture.pickle
data/slow_poker_500_train.pickle
data/slow_poker_500_test.pickle

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Siddharth Joshi, Clemens JS Schaefer

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Quantized Spiking Neural Networks

This repository contains the models and training scripts used in the papers ["Quantizing Spiking Neural Networks with Integers"](https://dl.acm.org/doi/abs/10.1145/3407197.3407203) (ICONS 2020) and ["Memory Organization for Energy-Efficient Learning and Inference in Digital Neuromorphic Accelerators"](https://ieeexplore.ieee.org/document/9180443) (ISCAS 2020).

## Requirements

- Python (pickle and argparse ship with the standard library)
- PyTorch
- torchvision
- NumPy
- Matplotlib
- seaborn

## Quantized SNNs for Spatio-Temporal Patterns
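These experiments train a three-layer SNN (700 inputs, 400 hidden, 250 outputs by default) to emit a target spatio-temporal spike pattern, and the mismatch between output and target spike trains is measured with the van Rossum distance (see `--tau_vr` in the options table below). As a minimal sketch of that distance, assuming an exponential kernel discretized in the same way as the simulator's traces (an illustration, not the repository's exact implementation):

```
import torch

def van_rossum_distance(spikes_a, spikes_b, tau_vr=0.005, time_step=0.001):
    # Filter both 0/1 spike trains (shape: units x steps) with an exponential
    # kernel and accumulate the squared difference of the filtered traces.
    decay = 1.0 - time_step / tau_vr
    trace_a = torch.zeros(spikes_a.shape[0])
    trace_b = torch.zeros(spikes_b.shape[0])
    dist = torch.tensor(0.0)
    for t in range(spikes_a.shape[1]):
        trace_a = decay * trace_a + spikes_a[:, t]
        trace_b = decay * trace_b + spikes_b[:, t]
        dist = dist + ((trace_a - trace_b) ** 2).sum()
    return torch.sqrt(dist * time_step)

out = (torch.rand(250, 250) < 0.05).float()  # hypothetical output spikes (250 units, 250 steps)
tgt = (torch.rand(250, 250) < 0.05).float()  # hypothetical target pattern
print(van_rossum_distance(out, tgt))
```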


All relevant code for the experiments from the ISCAS paper is contained in qsnn_precise.py, quantization.py, and qsnn_util.py. To run the experiments, execute:

```
python qsnn_precise.py
```
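For example, to point the default 2-bit-weight, 8-bit-state configuration at a different target pattern (the flag values here are only illustrative; all options are listed in the table below):

```
python qsnn_precise.py --target ./data/smile50.pkl --global_wb 2 --global_ab 8 --global_gb 8 --global_eb 8
```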


29 | 30 | You can specify desired setting either as command-line arguments or within qsnn_precise.py. 31 | 32 | Optional arguments: 33 | 34 | | Argument | Description | 35 | |:-----------------------|:---------------------------------------------------------| 36 | | --input INPUT | Input pickle file (default: ./data/input_700_250_25.pkl) | 37 | | --target TARGET | Target pattern pickle (default: ./data/smile95.pkl) | 38 | | --global_wb GLOBAL_WB | Weight bitwidth (default: 2) | 39 | | --global_ab GLOBAL_AB | Membrane potential, synapse state bitwidth (default: 8) | 40 | | --global_gb GLOBAL_GB | Gradient bitwidth (default: 8) | 41 | | --global_eb GLOBAL_EB | Error bitwidth (default: 8) | 42 | | --global_rb GLOBAL_RB | Gradient RNG bitwidth (default: 16) | 43 | | --time_step TIME_STEP | Simulation time step size (default: 0.001) | 44 | | --nb_steps NB_STEPS | Simulation steps (default: 250) | 45 | | --nb_epochs NB_EPOCHS | Simulation steps (default: 10000) | 46 | | --tau_mem TAU_MEM | Time constant for membrane potential (default: 0.01) | 47 | | --tau_syn TAU_SYN | Time constant for synapse (default: 0.005) | 48 | | --tau_vr TAU_VR | Time constant for Van Rossum distance (default: 0.005) | 49 | | --alpha ALPHA | Time constant for synapse (default: 0.75) | 50 | | --beta BETA | Time constant for Van Rossum distance (default: 0.875) | 51 | | --nb_inputs NB_INPUTS | Spatial input dimensions (default: 700) | 52 | | --nb_hidden NB_HIDDEN | Spatial hidden dimensions (default: 400) | 53 | | --nb_outputs NB_OUTPUTS| Spatial output dimensions (default: 250) | 54 | 55 | 56 | ## Quantized SNNs for Gesture Detection with Local Learning 57 | 58 |


Download and extract the [DVS Slow Poker](http://www2.imse-cnm.csic.es/caviar/SLOWPOKERDVS.html#:~:text=The%20SLOW%2DPOKER%2DDVS%20database,diamond%2C%20heart%20or%20spade) and [DVS Gesture](https://www.research.ibm.com/dvsgesture/) data sets.

To prepare the data, run prepPoker.py in the directory of the DVS Poker data and prepGesture.py in the directory of the DVS Gesture data:

```
python prepPoker.py
```

```
python prepGesture.py
```

The resulting train/test pickle files are expected under data/ (see the paths listed in .gitignore above).

All relevant code for the experiments from the ICONS paper is contained in qsnn_decolle.py, quantization.py, and localQ.py. To run the experiments, execute:

```
python qsnn_decolle.py
```
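For example, to run the default 8-bit-weight network on DVS Gesture with a shorter schedule (illustrative values; the full option list follows below):

```
python qsnn_decolle.py --data-set Gesture --epochs 80 --lr_div 20 --batch_size 72
```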


85 | 86 | 87 | You can specify desired setting either as command-line arguments or within qsnn_decolle.py. 88 | 89 | Optional arguments: 90 | 91 | | Argument | Description | 92 | |:-----------------------|:---------------------------------------------------------| 93 | | --data-set DATA_SET | Input date set: Poker/Gesture (default: Gesture) | 94 | | --global_wb GLOBAL_WB | Weight bitwidth (default: 8) | 95 | | --global_qb GLOBAL_QB | Synapse bitwidth (default: 10) | 96 | | --global_pb GLOBAL_PB | Membrane trace bitwidth (default: 12) | 97 | | --global_rfb GLOBAL_RFB | Refractory bitwidth (default: 2) | 98 | | --global_sb GLOBAL_SB | Learning signal bitwidth (default: 6) | 99 | | --global_gb GLOBAL_GB | Gradient bitwidth (default: 10) | 100 | | --global_eb GLOBAL_EB | Error bitwidth (default: 6) | 101 | | --global_ub GLOBAL_UB | Membrane Potential bitwidth (default: 6) | 102 | | --global_ab GLOBAL_AB | Activation bitwidth (default: 6) | 103 | | --global_sig GLOBAL_SIG | Sigmoid bitwidth (default: 6) | 104 | | --global_rb GLOBAL_RB | Gradient RNG bitwidth (default: 16) | 105 | | --global_lr GLOBAL_LR | Learning rate for quantized gradients (default: 1) | 106 | | --global_lr_sgd GLOBAL_LR_SGD | Learning rate for SGD (default: 1e-09) | 107 | | --global_beta GLOBAL_BETA | Beta for weight init (default: 1.5) | 108 | | --delta_t DELTA_T | Time step in ms (default: 0.001) | 109 | | --input_mode INPUT_MODE | Spike processing method (default: 0) | 110 | | --ds DS | Downsampling (default: 4) | 111 | | --epochs EPOCHS | Epochs for training (default: 320) | 112 | | --lr_div LR_DIV | Learning rate divide interval (default: 80) | 113 | | --batch_size BATCH_SIZE | Batch size (default: 72) | 114 | | --PQ_cap PQ_CAP | Value cap for membrane and synpase trace (default: 1) | 115 | | --weight_mult WEIGHT_MULT | Weight multiplier (default: 4e-05) | 116 | | --dropout_p DROPOUT_P | Dropout probability (default: 0.5) | 117 | | --lc_ampl LC_AMPL | Magnitude amplifier for weight init (default: 0.5) | 118 | | --l1 L1 | Regularizer 1 (default: 0.001) | 119 | | --l2 L2 | Regularizer 2 (default: 0.001) | 120 | | --tau_mem_lower TAU_MEM_LOWER | Tau mem lower bound (default: 5) | 121 | | --tau_mem_upper TAU_MEM_UPPER | Tau mem upper bound (default: 35) | 122 | | --tau_syn_lower TAU_SYN_LOWER | Tau syn lower bound (default: 5) | 123 | | --tau_syn_upper TAU_SYN_UPPER | Tau syn upper bound (default: 10) | 124 | | --tau_ref TAU_REF | Tau ref (default: 2.857142857142857) | 125 | 126 | -------------------------------------------------------------------------------- /data/input_700_250_25.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/input_700_250_25.pkl -------------------------------------------------------------------------------- /data/small_test_dvs_gesture.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/small_test_dvs_gesture.pickle -------------------------------------------------------------------------------- /data/small_train_dvs_gesture.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/small_train_dvs_gesture.pickle 
-------------------------------------------------------------------------------- /data/smile100.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile100.pkl -------------------------------------------------------------------------------- /data/smile30.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile30.pkl -------------------------------------------------------------------------------- /data/smile50.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile50.pkl -------------------------------------------------------------------------------- /data/smile70.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile70.pkl -------------------------------------------------------------------------------- /data/smile95.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile95.pkl -------------------------------------------------------------------------------- /figures/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/.DS_Store -------------------------------------------------------------------------------- /figures/._torch_wage_acc_cifar10_210810.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._torch_wage_acc_cifar10_210810.png -------------------------------------------------------------------------------- /figures/._torch_wage_acc_cifar10_21088.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._torch_wage_acc_cifar10_21088.png -------------------------------------------------------------------------------- /figures/._torch_wage_acc_cifar10_2888.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._torch_wage_acc_cifar10_2888.png -------------------------------------------------------------------------------- /figures/._torch_wage_acc_cifar10_310810.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._torch_wage_acc_cifar10_310810.png -------------------------------------------------------------------------------- /figures/._utorch_wage_acc_cifar10_21088.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._utorch_wage_acc_cifar10_21088.png -------------------------------------------------------------------------------- /figures/ICONS_PQ_distr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_PQ_distr.png -------------------------------------------------------------------------------- /figures/ICONS_QuantSNN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_QuantSNN.png -------------------------------------------------------------------------------- /figures/ICONS_curves.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_curves.pdf -------------------------------------------------------------------------------- /figures/ICONS_data_set_gest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_data_set_gest.png -------------------------------------------------------------------------------- /figures/ICONS_data_set_poker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_data_set_poker.png -------------------------------------------------------------------------------- /figures/ICONS_sur.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_sur.png -------------------------------------------------------------------------------- /figures/ICONS_unscatter.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_unscatter.pdf -------------------------------------------------------------------------------- /figures/ISCAS_schem1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ISCAS_schem1.png -------------------------------------------------------------------------------- /figures/ISCAS_smile_black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ISCAS_smile_black.png -------------------------------------------------------------------------------- /localQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 
5 | import pickle 6 | import time 7 | import math 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | import uuid 12 | 13 | import quantization 14 | 15 | 16 | global lc_ampl 17 | lc_ampl = .5 18 | 19 | global shift_prob 20 | shift_prob = .5 21 | 22 | 23 | def create_graph(plot_file_name, diff_layers_acc, ds_name, best_test): 24 | 25 | bit_string = str(quantization.global_wb) + str(quantization.global_ub) + str(quantization.global_pb) + str(quantization.global_qb) + str(quantization.global_rfb) + " " + str(quantization.global_sb) + str(quantization.global_ab) + str(quantization.global_sig) + str(quantization.global_eb) + str(quantization.global_gb) 26 | bit_string = bit_string.replace("None", "f") 27 | 28 | 29 | fig, ax1 = plt.subplots() 30 | fig.set_size_inches(8.4, 4.8) 31 | plt.title(ds_name + " "+ bit_string + " Test3: " + str(np.round( best_test.item(), 4)) + " " +str(shift_prob)) 32 | ax1.set_xlabel('Epochs') 33 | ax1.set_ylabel('Accuracy') 34 | t = np.arange(len(diff_layers_acc['loss'])) 35 | ax1.plot(t, diff_layers_acc['train1'], 'g--', label = 'Train 1') 36 | ax1.plot(t, diff_layers_acc['train2'], 'b--', label = 'Train 2') 37 | ax1.plot(t, diff_layers_acc['train3'], 'r--', label = 'Train 3') 38 | ax1.plot(t, diff_layers_acc['test1'], 'g-', label = 'Test 1') 39 | ax1.plot(t, diff_layers_acc['test2'], 'b-', label = 'Test 2') 40 | ax1.plot(t, diff_layers_acc['test3'], 'r-', label = 'Test 3') 41 | ax1.plot([], [], 'k-', label = 'Loss') 42 | ax1.legend(bbox_to_anchor=(1.20,1), loc="upper left") 43 | #ax1.text(1.20, 0.1, str(max(diff_layers_acc['test3']))) 44 | 45 | ax2 = ax1.twinx() 46 | ax2.set_ylabel('Loss') 47 | ax2.plot(t, diff_layers_acc['loss'], 'k-', label = 'Loss') 48 | 49 | fig.tight_layout() 50 | plt.savefig("figures/"+plot_file_name + ".png") 51 | plt.close() 52 | 53 | def create_graph2(plot_file_name, diff_layers_acc, ds_name): 54 | 55 | bit_string = str(quantization.global_wb) + str(quantization.global_ub) + str(quantization.global_pb) + str(quantization.global_qb) + str(quantization.global_rfb) + " " + str(quantization.global_sb) + str(quantization.global_ab) + str(quantization.global_sig) + str(quantization.global_eb) + str(quantization.global_gb) 56 | bit_string = bit_string.replace("None", "f") 57 | 58 | 59 | fig, ax1 = plt.subplots() 60 | fig.set_size_inches(8.4, 4.8) 61 | plt.title(ds_name + " Act "+ bit_string) 62 | ax1.set_xlabel('Epochs') 63 | ax1.set_ylabel('# Spikes/Updates') 64 | t = np.arange(len(diff_layers_acc['loss'])) 65 | 66 | ax1.plot(t, diff_layers_acc['act_train1'], 'g--', label = 'Train 1') 67 | ax1.plot(t, diff_layers_acc['act_train2'], 'b--', label = 'Train 2') 68 | ax1.plot(t, diff_layers_acc['act_train3'], 'r--', label = 'Train 3') 69 | ax1.plot(t, diff_layers_acc['act_test1'], 'g-', label = 'Test 1') 70 | ax1.plot(t, diff_layers_acc['act_test2'], 'b-', label = 'Test 2') 71 | ax1.plot(t, diff_layers_acc['act_test3'], 'r-', label = 'Test 3') 72 | ax1.plot(t, diff_layers_acc['w1update'], 'm-', label = 'W update 1') 73 | ax1.plot(t, diff_layers_acc['w2update'], 'k-', label = 'W update 2') 74 | ax1.plot(t, diff_layers_acc['w3update'], 'y-', label = 'W update 3') 75 | ax1.legend(bbox_to_anchor=(1.20,1), loc="upper left") 76 | #ax1.text(1.20, 0.1, str(max(diff_layers_acc['test3']))) 77 | 78 | #ax2 = ax1.twinx() 79 | #ax2.set_ylabel('Loss') 80 | #ax2.plot(t, diff_layers_acc['loss'], 'k-', label = 'Loss') 81 | 82 | fig.tight_layout() 83 | plt.savefig("figures/"+plot_file_name+ "_act.png") 84 | plt.close() 85 | 86 | 87 
| 88 | def acc_comp(rread_hist_train, y_local, bools = False): 89 | rhts = torch.stack(rread_hist_train, dim = 0) 90 | if bools: 91 | return (rhts.mode(0)[0] == y_local).float() 92 | return (rhts.mode(0)[0] == y_local).float().mean() 93 | 94 | def clee_spikes(T, rates): 95 | spikes = np.ones((T, + np.prod(rates.shape))) 96 | spikes[np.random.binomial(1, (1000. - rates.flatten())/1000, size=(T, np.prod(rates.shape))).astype('bool')] = 0 97 | return spikes.T.reshape((rates.shape + (T,))) 98 | 99 | def prep_input(x_local, input_mode, channels = 2): 100 | #two channel trick / decolle 101 | if input_mode == 0: 102 | x_local[x_local > 0] = 1 103 | 104 | #down_spikes = torch.cat((x_local, x_local), dim = 1) 105 | #mask1 = (down_spikes > 0) # this might change 106 | #mask2 = (down_spikes < 0) 107 | #mask1[:,0,:,:] = False 108 | #mask2[:,1,:,:] = False 109 | #down_spikes = torch.zeros_like(down_spikes) 110 | #down_spikes[mask1] = 1 111 | #down_spikes[mask2] = 1 112 | return x_local 113 | #bi directional 114 | if input_mode == 2: 115 | x_local[:,0,:,:] *= -1 116 | new_spikes = x_local[:,0,:,:] + x_local[:,1,:,:] 117 | new_spikes = new_spikes.reshape([x_local.shape[0], 1, x_local.shape[2], x_local.shape[3]]) 118 | new_spikes[new_spikes > 0] = 1 119 | new_spikes[new_spikes < 0] = -1 120 | return new_spikes 121 | # same same but different 122 | if input_mode == 1: 123 | down_spikes = x_local 124 | down_spikes[down_spikes != 0] = 1 125 | return down_spikes 126 | #bi directional two channels 127 | if input_mode == 3: 128 | x_local[:,0,:,:] *= -1 129 | new_spikes = x_local[:,0,:,:] + x_local[:,1,:,:] 130 | new_spikes = new_spikes.reshape([x_local.shape[0], 1, x_local.shape[2], x_local.shape[3]]) 131 | new_spikes[new_spikes > 0] = 1 132 | new_spikes[new_spikes < 0] = -1 133 | 134 | new_spikes = torch.cat((new_spikes, new_spikes), dim = 1) 135 | return new_spikes 136 | else: 137 | return x_local 138 | 139 | 140 | def sparse_data_generator_DVSPoker(X, y, batch_size, nb_steps, shuffle, device, test = False): 141 | number_of_batches = int(np.ceil(len(y)/batch_size)) 142 | sample_index = np.arange(len(y)) 143 | nb_steps = nb_steps -1 144 | y = np.array(y) 145 | 146 | if shuffle: 147 | np.random.shuffle(sample_index) 148 | 149 | total_batch_count = 0 150 | counter = 0 151 | while counter 0] = 1 167 | 168 | sparse_matrix = sparse_matrix.reshape(torch.Size([sparse_matrix.shape[0], 1, sparse_matrix.shape[1], sparse_matrix.shape[2], sparse_matrix.shape[3]])) 169 | 170 | y_batch = torch.tensor(y[batch_index], dtype = int) 171 | try: 172 | torch.cuda.empty_cache() 173 | yield sparse_matrix.to(device=device), y_batch.to(device=device) 174 | counter += 1 175 | except StopIteration: 176 | return 177 | 178 | def sparse_data_generator_DVSGesture(X, y, batch_size, nb_steps, shuffle, device, ds = 4, test = False, x_size = 32, y_size = 32): 179 | number_of_batches = int(np.ceil(len(y)/batch_size)) 180 | sample_index = np.arange(len(y)) 181 | nb_steps = nb_steps -1 182 | y = np.array(y) 183 | 184 | if shuffle: 185 | np.random.shuffle(sample_index) 186 | 187 | total_batch_count = 0 188 | counter = 0 189 | while counter we sample 197 | if test: 198 | start_ts = 0 199 | else: 200 | start_ts = np.random.choice(np.arange(np.max(X[idx][:,0]) - nb_steps),1) 201 | temp = X[idx][X[idx][:,0] >= start_ts] 202 | temp = temp[temp[:,0] <= start_ts+nb_steps] 203 | temp = np.append(np.ones((temp.shape[0], 1))*bc, temp, axis=1) 204 | temp[:,1] = temp[:,1] - start_ts 205 | all_events = np.append(all_events, temp, axis = 0) 206 | 207 | # to 
matrix 208 | #all_events[:,4][all_events[:,4] == 0] = -1 209 | # spike_ind = (x_local == 1).nonzero() 210 | # spike_ind = spike_ind[torch.bernoulli((.5) * torch.ones(spike_ind.shape[0])).bool()] 211 | # spike_ind = spike_ind[torch.randperm(spike_ind.shape[0])] 212 | # split_point = int(spike_ind.shape[0]/2) 213 | # forward_spike = spike_ind[0:split_point] 214 | # backward_spike = spike_ind[split_point:] 215 | 216 | # x_local[torch.sparse.FloatTensor(forward_spike.t(), torch.ones(forward_spike.shape[0]).to(device)).to_dense().bool()] = 0 217 | # forward_spike[:,4] = forward_spike[:,4] + 1 218 | # forward_spike[forward_spike[:,4] == 500] = 499 219 | # x_local[torch.sparse.FloatTensor(forward_spike.t(), torch.ones(forward_spike.shape[0]).to(device)).to_dense().bool()] = 1 220 | 221 | # x_local[torch.sparse.FloatTensor(backward_spike.t(), torch.ones(backward_spike.shape[0]).to(device)).to_dense().bool()] = 0 222 | # backward_spike[:,4] = backward_spike[:,4] - 1 223 | # backward_spike[backward_spike[:,4] == -1] = 0 224 | # x_local[torch.sparse.FloatTensor(backward_spike.t(), torch.ones(backward_spike.shape[0]).to(device)).to_dense().bool()] = 1 225 | 226 | 227 | #change 228 | # by plus minus one process... 229 | # change_mask = torch.bernoulli((shift_prob) * torch.ones(all_events.shape[0])).bool() 230 | # forward_mask = change_mask * torch.bernoulli((.5) * torch.ones(all_events.shape[0])).bool() 231 | # backward_mask = (change_mask != forward_mask) 232 | # all_events[forward_mask, 1] = all_events[forward_mask, 1] + 1 #torch.randn(all_events[forward_mask, 1].shape[0]) 233 | # all_events[backward_mask, 1] = all_events[backward_mask, 1] - 1 234 | 235 | all_events[:, 1] = all_events[:, 1] + (shift_prob*np.random.randn(all_events[:, 1].shape[0])).astype(int) 236 | 237 | neg_ind = (all_events[:,1] < 0) 238 | pos_ind = (all_events[:,1] > nb_steps) 239 | all_events[neg_ind,1] = 0 240 | all_events[pos_ind,1] = int(nb_steps) 241 | 242 | 243 | all_events = all_events[:,[0,4,2,3,1]] 244 | all_events[:, 2] = all_events[:, 2]//ds 245 | all_events[:, 3] = all_events[:, 3]//ds 246 | sparse_matrix = torch.sparse.FloatTensor(torch.LongTensor(all_events[:,[True, True, True, True, True]].T), torch.ones_like(torch.tensor(all_events[:,0])), torch.Size([len(y_batch),2,x_size,y_size,int(nb_steps+1)])).to_dense().type(torch.int16) 247 | 248 | # quick trick... 249 | #sparse_matrix[sparse_matrix != 0] = 1 250 | #sparse_matrix[sparse_matrix > 0] = 1 251 | #sparse_matrix = sparse_matrix.reshape(torch.Size([sparse_matrix.shape[0], 1, sparse_matrix.shape[1], sparse_matrix.shape[2], sparse_matrix.shape[3]])) 252 | 253 | 254 | try: 255 | torch.cuda.empty_cache() 256 | yield sparse_matrix.to(device=device), y_batch.to(device=device) 257 | counter += 1 258 | except StopIteration: 259 | return 260 | 261 | def sparse_data_generator_Static(X, y, batch_size, nb_steps, samples, max_hertz, shuffle=True, device=torch.device("cpu")): 262 | sample_idx = torch.randperm(len(X))[:samples] 263 | number_of_batches = int(np.ceil(samples/batch_size)) 264 | nb_steps = int(nb_steps) 265 | 266 | counter = 0 267 | while counterac', input, weight) 305 | if bias is not None: 306 | output += bias.unsqueeze(0).expand_as(output) 307 | 308 | if quantization.global_sb is not None: 309 | output = output/scale 310 | # quant act 311 | if quantization.global_ab is not None: 312 | output, _ = quantization.quant_act(output) 313 | 314 | ctx.save_for_backward(input, weight, weight_fa, bias) 315 | 316 | # ify part here... 
shall we bring it between 0 and 1 for the targets 317 | return (output+1)/2 318 | 319 | @staticmethod 320 | def backward(ctx, grad_output): 321 | input, weight, weight_fa, bias = ctx.saved_tensors 322 | grad_input = None 323 | 324 | if quantization.global_eb is not None: 325 | quant_error = quantization.quant_err(grad_output) #* clip_info.float() 326 | else: 327 | quant_error = grad_output 328 | 329 | if ctx.needs_input_grad[0]: 330 | grad_input = torch.einsum('ab,bc->ac', quant_error, weight_fa) 331 | 332 | # quantizing here for sigmoid input 333 | if quantization.global_eb is not None: 334 | grad_input = quantization.quant_err(grad_input) 335 | else: 336 | grad_input = grad_input 337 | 338 | return grad_input, None, None, None, None 339 | 340 | class QLinearLayerSign(nn.Module): 341 | '''from https://github.com/L0SG/feedback-alignment-pytorch/''' 342 | def __init__(self, input_features, output_features, pass_through = False, bias = True, dtype = None, device = None): 343 | super(QLinearLayerSign, self).__init__() 344 | self.input_features = input_features 345 | self.output_features = output_features 346 | self.dtype = dtype 347 | self.device = device 348 | 349 | # weight and bias for forward pass 350 | self.weights = nn.Parameter(torch.empty((output_features, input_features), device=device, dtype=dtype, requires_grad=False)) 351 | self.weight_fa = nn.Parameter(torch.empty((output_features, input_features), device=device, dtype=dtype, requires_grad=False)) 352 | self.bias = nn.Parameter(torch.empty((output_features), device=device, dtype=dtype, requires_grad=False)) 353 | 354 | if quantization.global_sb is not None: 355 | self.L_min = quantization.global_beta/quantization.step_d(torch.tensor([float(quantization.global_sb)])) 356 | #self.L = np.sqrt(6/self.input_features) 357 | self.L = lc_ampl/np.sqrt(self.input_features) 358 | self.scale = 2 ** round(math.log(self.L_min / self.L, 2.0)) 359 | self.scale = self.scale if self.scale > 1 else 1.0 360 | self.L = np.max([self.L, self.L_min]) 361 | 362 | #since those weights are fixed lets just initialize them between -1 and 1 to make use of all given bits 363 | self.L = lc_ampl/np.sqrt(self.input_features) 364 | self.scale = 2 ** round(math.log((1-self.L_min)/self.L, 2.0)) 365 | self.scale = self.scale if self.scale > 1 else 1.0 366 | self.L = 1 367 | 368 | torch.nn.init.uniform_(self.weights, a = -self.L, b = self.L) 369 | torch.nn.init.uniform_(self.weight_fa, a = -self.L, b = self.L) 370 | if bias: 371 | torch.nn.init.uniform_(self.bias, a = -self.L, b = self.L) 372 | else: 373 | self.bias = None 374 | 375 | # quantize them 376 | with torch.no_grad(): 377 | self.weights.data = quantization.quant_s(self.weights.data) 378 | self.weight_fa.data = quantization.quant_s(self.weight_fa.data) 379 | if self.bias is not None: 380 | self.bias.data = quantization.quant_s(self.bias.data) 381 | else: 382 | self.scale = 1 383 | self.stdv = lc_ampl/np.sqrt(self.input_features) 384 | torch.nn.init.uniform_(self.weights, a = -self.stdv, b = self.stdv) 385 | torch.nn.init.uniform_(self.weight_fa, a = -self.stdv, b = self.stdv) 386 | if bias: 387 | torch.nn.init.uniform_(self.bias, a = -self.stdv, b = self.stdv) 388 | else: 389 | self.bias = None 390 | 391 | # sign concordant weights in fwd and bwd pass 392 | #self.weight_fa = self.weights 393 | nonzero_mask = (self.weights.data != 0) 394 | self.weight_fa.data[nonzero_mask] *= torch.sign((torch.sign(self.weights.data) == torch.sign(self.weight_fa.data)).type(dtype) -.5)[nonzero_mask] 395 | 396 | 397 | def 
forward(self, input): 398 | return QLinearFunctional.apply(input, self.weights, self.weight_fa, self.bias, self.scale) 399 | 400 | 401 | 402 | class QSConv2dFunctional(torch.autograd.Function): 403 | @staticmethod 404 | def forward(ctx, input, weights, bias, scale, padding = 0, weight_mult = 1): 405 | if quantization.global_wb is not None: 406 | w_quant = quantization.quant_w(weights/weight_mult, 1) *weight_mult 407 | bias_quant = quantization.quant_w(bias/weight_mult, 1) *weight_mult 408 | else: 409 | w_quant = weights 410 | bias_quant = bias 411 | ctx.padding = padding 412 | 413 | output = F.conv2d(input = input, weight = w_quant, bias = bias_quant, padding = ctx.padding) 414 | if quantization.global_wb is not None: 415 | output = output / scale 416 | 417 | ctx.save_for_backward(input, w_quant, bias_quant) 418 | 419 | 420 | return output 421 | 422 | @staticmethod 423 | def backward(ctx, grad_output): 424 | input, w_quant, bias_quant = ctx.saved_tensors 425 | grad_input = grad_weight = grad_bias = None 426 | 427 | if quantization.global_eb is not None: 428 | quant_error = quantization.quant_err(grad_output) 429 | else: 430 | quant_error = grad_output 431 | 432 | # compute quantized error 433 | if ctx.needs_input_grad[0]: 434 | grad_input = torch.nn.grad.conv2d_input(input.shape, w_quant, quant_error, padding = ctx.padding) 435 | # computed quantized gradient 436 | if ctx.needs_input_grad[1]: 437 | if quantization.global_gb is not None: 438 | grad_weight = quantization.quant_grad(torch.nn.grad.conv2d_weight(input, w_quant.shape, quant_error, padding = ctx.padding)).float() 439 | else: 440 | grad_weight = torch.nn.grad.conv2d_weight(input, w_quant.shape, quant_error, padding = ctx.padding) 441 | # computed quantized bias 442 | if bias_quant is not None and ctx.needs_input_grad[2]: 443 | if quantization.global_gb is not None: 444 | grad_bias = quantization.quant_grad(quant_error.sum((0,2,3)).squeeze(0)).float() 445 | else: 446 | grad_bias = quant_error.sum((0,2,3)).squeeze(0) 447 | 448 | if input.shape[2] == 13: 449 | quantization.global_w3update += grad_bias.nonzero().shape[0] + grad_weight.nonzero().shape[0] 450 | if input.shape[2] == 15: 451 | quantization.global_w2update += grad_bias.nonzero().shape[0] + grad_weight.nonzero().shape[0] 452 | if input.shape[2] == 32: 453 | quantization.global_w1update += grad_bias.nonzero().shape[0] + grad_weight.nonzero().shape[0] 454 | return grad_input, grad_weight, grad_bias, None, None, None, None 455 | 456 | 457 | class LIFConv2dLayer(nn.Module): 458 | def __init__(self, inp_shape, kernel_size, out_channels, tau_syn, tau_mem, tau_ref, delta_t, pooling = 1, padding = 0, bias = True, thr = 1, device=torch.device("cpu"), dtype = torch.float, dropout_p = .5, output_neurons = 10, loss_fn = None, l1 = 0, l2 = 0, PQ_cap = 1, weight_mult = 4e-5): 459 | super(LIFConv2dLayer, self).__init__() 460 | self.device = device 461 | self.dtype = dtype 462 | self.inp_shape = inp_shape 463 | self.kernel_size = kernel_size 464 | self.out_channels = out_channels 465 | self.output_neurons = output_neurons 466 | self.padding = padding 467 | self.pooling = pooling 468 | self.thr = thr 469 | self.PQ_cap = PQ_cap 470 | self.weight_mult = weight_mult 471 | self.fan_in = kernel_size * kernel_size * inp_shape[0] 472 | 473 | self.dropout_learning = nn.Dropout(p=dropout_p) 474 | self.dropout_p = dropout_p 475 | self.l1 = l1 476 | self.l2 = l2 477 | self.loss_fn = loss_fn 478 | 479 | self.weights = nn.Parameter(torch.empty((self.out_channels, inp_shape[0], self.kernel_size, 
self.kernel_size), device=device, dtype=dtype, requires_grad=True)) 480 | 481 | # decide which one you like 482 | self.stdv = 1 / np.sqrt(self.fan_in) #/ 250 * 1e-2 483 | #self.stdv = np.sqrt(6 / self.fan_in) #* self.weight_mult 484 | if quantization.global_wb is not None: 485 | self.L_min = quantization.global_beta/quantization.step_d(torch.tensor([float(quantization.global_wb)])) 486 | #self.stdv = np.sqrt(6/self.fan_in) 487 | self.scale = 2 ** round(math.log(self.L_min / self.stdv, 2.0)) 488 | self.scale = self.scale if self.scale > 1 else 1.0 489 | self.L = np.max([self.stdv, self.L_min]) 490 | torch.nn.init.uniform_(self.weights, a = -self.L * self.weight_mult, b = self.L* self.weight_mult) 491 | else: 492 | self.scale = 1 493 | torch.nn.init.uniform_(self.weights, a = -self.stdv * self.weight_mult, b = self.stdv* self.weight_mult) 494 | 495 | # bias has a different scale... just why? 496 | if bias: 497 | self.bias = nn.Parameter(torch.empty(self.out_channels, device=device, dtype=dtype, requires_grad=True)) 498 | if quantization.global_wb is not None: 499 | bias_L = np.max([self.stdv* 1e2, self.L_min]) 500 | torch.nn.init.uniform_(self.bias, a = -bias_L * self.weight_mult, b = bias_L* self.weight_mult) 501 | else: 502 | torch.nn.init.uniform_(self.bias, a = -self.stdv* self.weight_mult* 1e2, b = self.stdv* self.weight_mult * 1e2) 503 | else: 504 | self.register_parameter('bias', None) 505 | 506 | self.mpool = nn.MaxPool2d(kernel_size = self.pooling, stride = self.pooling, padding = (self.pooling-1)//2, return_indices=False) 507 | self.out_shape2 = self.mpool(QSConv2dFunctional.apply(torch.zeros((1,)+self.inp_shape, dtype = dtype).to(device), self.weights, self.bias, self.scale, self.padding)).shape[1:] #self.pooling, 508 | self.out_shape = QSConv2dFunctional.apply(torch.zeros((1,)+self.inp_shape, dtype = dtype).to(device), self.weights, self.bias, self.scale, self.padding).shape[1:] 509 | 510 | self.sign_random_readout = QLinearLayerSign(input_features = np.prod(self.out_shape2), output_features = output_neurons, pass_through = False, bias = False, dtype = self.dtype, device = device).to(device) 511 | 512 | # tau quantization, static hardware friendly values 513 | if tau_syn.shape[0] == 2: 514 | self.tau_syn = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_syn[0], tau_syn[1]).to(device) 515 | self.beta = 1. - 1e-3 / self.tau_syn 516 | self.tau_syn = 1. / (1. - self.beta) 517 | else: 518 | self.beta = torch.tensor([1 - delta_t / tau_syn], dtype = dtype).to(device) 519 | self.tau_syn = 1. / (1. - self.beta) 520 | 521 | 522 | if tau_mem.shape[0] == 2: 523 | self.tau_mem = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_mem[0], tau_mem[1]).to(device) 524 | self.alpha = 1. - 1e-3 / self.tau_mem 525 | self.tau_mem = 1. / (1. - self.alpha) 526 | else: 527 | self.alpha = torch.tensor([1 - delta_t / tau_mem], dtype = dtype).to(device) 528 | self.tau_mem = 1. / (1. - self.alpha) 529 | 530 | 531 | if tau_ref.shape[0] == 2: 532 | self.tau_ref = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_ref[0], tau_ref[1]).to(device) 533 | self.gamma = 1. - 1e-3 / self.tau_gamma 534 | self.tau_ref = 1. / (1. - self.gamma) 535 | else: 536 | self.gamma = torch.tensor([1 - delta_t / tau_ref], dtype = dtype).to(device) 537 | self.tau_ref = 1. / (1. - self.gamma) 538 | 539 | self.r_scale = 1/(1-self.gamma) # the one comes from decolle, best value ? 
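The decay-factor bookkeeping in the constructor above compresses two ideas: with a 1 ms step, a trace with time constant tau decays by 1 - delta_t/tau per step, and the code then re-expresses tau in units of steps as 1/(1 - decay). A small standalone sketch of those conversions and of one step of the P/Q trace updates used in forward() (sizes and inputs are made up, not from the repository):

```
import torch

delta_t = 1e-3                                  # 1 ms step, as in qsnn_decolle.py
tau_syn = torch.empty(4).uniform_(5e-3, 10e-3)  # per-unit synaptic time constants
tau_mem = torch.empty(4).uniform_(5e-3, 35e-3)  # per-unit membrane time constants

beta = 1. - delta_t / tau_syn                   # per-step decay of the synaptic trace Q
alpha = 1. - delta_t / tau_mem                  # per-step decay of the membrane trace P
tau_syn_steps = 1. / (1. - beta)                # effective time constant in steps
tau_mem_steps = 1. / (1. - alpha)

# one time step of the trace dynamics used in forward():
P = torch.zeros(4); Q = torch.zeros(4)
input_t = torch.tensor([1., 0., 1., 0.])        # illustrative input spikes
Q = beta * Q + tau_syn_steps * input_t          # beta * Q + inp_mult_q * input_t
P = alpha * P + tau_mem_steps * Q               # alpha * P + inp_mult_p * Q
```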
540 | #self.q_scale = self.tau_syn/(1-self.beta) 541 | #self.q_scale = self.q_scale.max() 542 | # p_scale should be max overall to differentiate input signals 543 | #self.p_scale = (self.tau_mem * self.q_scale*self.PQ_cap)/(1-self.alpha) 544 | #self.p_scale = self.p_scale.max() 545 | 546 | self.inp_mult_q = self.tau_syn##1/self.PQ_cap * (1-self.beta.max()) # 547 | self.inp_mult_p = self.tau_mem##1/self.PQ_cap * (1-self.alpha.max()) # 548 | #self.pmult = self.p_scale * self.PQ_cap * self.weight_mult 549 | 550 | # those might be clamped as in chop off values. 551 | self.Q_scale = (self.tau_syn/(1-self.beta)).max() 552 | self.P_scale = ((self.tau_mem * self.Q_scale)/(1-self.alpha)).max() 553 | self.Q_scale = (self.tau_syn/(1-self.beta)).max() 554 | self.R_scale = 1/(1-self.gamma) 555 | 556 | if quantization.global_wb is not None: 557 | with torch.no_grad(): 558 | self.weights.data = quantization.quant_w(self.weights.data) 559 | if self.bias is not None: 560 | self.bias.data = quantization.quant_w(self.bias.data) 561 | 562 | 563 | def state_init(self, batch_size): 564 | self.P = torch.zeros((batch_size,) + self.inp_shape, dtype = self.dtype).detach().to(self.device) 565 | self.Q = torch.zeros((batch_size,) + self.inp_shape, dtype = self.dtype).detach().to(self.device) 566 | self.R = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device) 567 | self.S = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device) 568 | self.U = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device) 569 | 570 | 571 | def forward(self, input_t, y_local, train_flag = False, test_flag = False): 572 | # probably dont need to quantize because gb steps are arleady in the right level... just clipping 573 | if quantization.global_gb is not None: 574 | with torch.no_grad(): 575 | self.weights.data = quantization.clip(self.weights.data/self.weight_mult, quantization.global_gb)*self.weight_mult 576 | if self.bias is not None: 577 | self.bias.data = quantization.clip(self.bias.data/self.weight_mult, quantization.global_gb)*self.weight_mult 578 | if quantization.global_rfb is not None: 579 | # R always using full scale? 
580 | self.R = quantization.quant01(self.R/self.R_scale, quantization.global_rfb)*self.R_scale 581 | 582 | #self.P, self.R, self.Q = self.alpha * self.P + self.tau_mem * self.Q, self.gamma * self.R, self.beta * self.Q + self.tau_syn * input_t 583 | #dtype necessary 584 | self.P, self.R, self.Q = self.alpha * self.P + self.inp_mult_p * self.Q, self.gamma * self.R, self.beta * self.Q + self.inp_mult_q * input_t.type(self.dtype) 585 | 586 | if self.PQ_cap != 1: 587 | self.P = torch.clamp(self.P, 0, self.P_scale*self.PQ_cap) 588 | self.Q = torch.clamp(self.Q, 0, self.Q_scale*self.PQ_cap) 589 | 590 | if quantization.global_pb is not None: 591 | self.P = torch.clamp(self.P/(self.P_scale*self.PQ_cap), 0, 1) 592 | self.P = quantization.quant01(self.P, quantization.global_pb)*(self.P_scale*self.PQ_cap) 593 | if quantization.global_qb is not None: 594 | self.Q = torch.clamp(self.Q/(self.Q_scale*self.PQ_cap), 0, 1) 595 | self.Q = quantization.quant01(self.Q, quantization.global_qb)*(self.Q_scale*self.PQ_cap) 596 | 597 | #self.U = QSConv2dFunctional.apply(self.P * self.pmult, self.weights, self.bias, self.scale, self.padding) - self.R 598 | self.U = QSConv2dFunctional.apply(self.P, self.weights, self.bias, self.scale, self.padding, self.weight_mult) - self.R #* self.r_scale 599 | if quantization.global_ub is not None: 600 | self.U = quantU.apply(self.U) 601 | self.S = (self.U >= self.thr).type(self.dtype) #float() 602 | self.R += self.S * 1#(1-self.gamma) 603 | 604 | 605 | if test_flag or train_flag: 606 | self.U_aux = torch.sigmoid(self.U) # quantize this function.... at some point 607 | self.U_aux = self.mpool(self.U_aux) 608 | 609 | rreadout = self.dropout_learning(self.sign_random_readout(self.U_aux.reshape([input_t.shape[0], np.prod(self.out_shape2)]))) * self.dropout_p 610 | 611 | if train_flag: 612 | if quantization.global_eb is not None: 613 | part1 = quantization.SSE(rreadout, y_local) 614 | #part2 = self.l1 * 200e-1 * F.relu((self.U+.01)).mean() 615 | part2 = self.l1 * 200e-1 * F.relu((self.U_aux+.01)).mean() 616 | #part3 = self.l2 *1e-1* F.relu(.1-self.U_aux.mean()) 617 | part3 = self.l2 *1e-1* F.relu(.1-self.U.mean()) 618 | loss_gen = part1 + part2 + part3 619 | else: 620 | part1 = self.loss_fn(rreadout, y_local) 621 | #part2 = self.l1 * 200e-1 * F.relu((self.U+.01)).mean() 622 | part2 = self.l1 * 200e-1 * F.relu((self.U_aux+.01)).mean() 623 | #part3 = self.l2 *1e-1* F.relu(.1-self.U_aux.mean()) 624 | part3 = self.l2 *1e-1* F.relu(.1-self.U.mean()) 625 | loss_gen = part1 + part2 + part3 626 | #loss_gen = self.loss_fn(rreadout, y_local) + self.l1 * 200e-1 * F.relu((self.U+.01)).mean() + self.l2 *1e-1* F.relu(.1-self.U_aux.mean()) 627 | #loss_gen = self.loss_fn(rreadout, y_local) + self.l1 * 200e-1 * F.relu((self.U+.01)).mean() + self.l2 *1e-1* F.relu(.1-self.U_aux.mean()) 628 | else: 629 | part1 = None 630 | part2 = None 631 | part3 = None 632 | loss_gen = None 633 | else: 634 | part1 = None 635 | part2 = None 636 | part3 = None 637 | loss_gen = None 638 | rreadout = torch.tensor([[0]]) 639 | 640 | 641 | return self.mpool(self.S), loss_gen, rreadout.argmax(1), [part1, part2, part3] 642 | 643 | 644 | 645 | class DTNLIFConv2dLayer(nn.Module): 646 | def __init__(self, inp_shape, kernel_size, out_channels, tau_syn, tau_mem, tau_ref, delta_t, pooling = 1, padding = 0, bias = True, thr = 1, device=torch.device("cpu"), dtype = torch.float, dropout_p = .5, output_neurons = 10, loss_fn = None, l1 = 0, l2 = 0, PQ_cap = 1, weight_mult = 4e-5): 647 | super(DTNLIFConv2dLayer, self).__init__() 648 | 
self.device = device 649 | self.dtype = dtype 650 | self.inp_shape = inp_shape 651 | self.kernel_size = kernel_size 652 | self.out_channels = out_channels 653 | self.output_neurons = output_neurons 654 | self.padding = padding 655 | self.pooling = pooling 656 | self.thr = thr 657 | self.PQ_cap = PQ_cap 658 | self.weight_mult = weight_mult 659 | self.fan_in = kernel_size * kernel_size * inp_shape[0] 660 | 661 | self.dropout_learning = nn.Dropout(p=dropout_p) 662 | self.dropout_p = dropout_p 663 | self.l1 = l1 664 | self.l2 = l2 665 | self.loss_fn = loss_fn 666 | 667 | self.weights = nn.Parameter(torch.empty((self.out_channels, inp_shape[0], self.kernel_size, self.kernel_size), device=device, dtype=dtype, requires_grad=True)) 668 | 669 | # decide which one you like 670 | self.stdv = 1 / np.sqrt(self.fan_in) #/ 250 * 1e-2 671 | #self.stdv = np.sqrt(6 / self.fan_in) #* self.weight_mult 672 | if quantization.global_wb is not None: 673 | self.L_min = quantization.global_beta/quantization.step_d(torch.tensor([float(quantization.global_wb)])) 674 | #self.stdv = np.sqrt(6/self.fan_in) 675 | self.scale = 2 ** round(math.log(self.L_min / self.stdv, 2.0)) 676 | self.scale = self.scale if self.scale > 1 else 1.0 677 | self.L = np.max([self.stdv, self.L_min]) 678 | torch.nn.init.uniform_(self.weights, a = -self.L * self.weight_mult, b = self.L* self.weight_mult) 679 | else: 680 | self.scale = 1 681 | torch.nn.init.uniform_(self.weights, a = -self.stdv * self.weight_mult, b = self.stdv* self.weight_mult) 682 | 683 | # bias has a different scale... just why? 684 | if bias: 685 | self.bias = nn.Parameter(torch.empty(self.out_channels, device=device, dtype=dtype, requires_grad=True)) 686 | if quantization.global_wb is not None: 687 | bias_L = np.max([self.stdv* 1e2, self.L_min]) 688 | torch.nn.init.uniform_(self.bias, a = -bias_L * self.weight_mult, b = bias_L* self.weight_mult) 689 | else: 690 | torch.nn.init.uniform_(self.bias, a = -self.stdv* self.weight_mult* 1e2, b = self.stdv* self.weight_mult * 1e2) 691 | else: 692 | self.register_parameter('bias', None) 693 | 694 | self.mpool = nn.MaxPool2d(kernel_size = self.pooling, stride = self.pooling, padding = (self.pooling-1)//2, return_indices=False) 695 | self.out_shape2 = self.mpool(QSConv2dFunctional.apply(torch.zeros((1,)+self.inp_shape, dtype = dtype).to(device), self.weights, self.bias, self.scale, self.padding)).shape[1:] #self.pooling, 696 | self.out_shape = QSConv2dFunctional.apply(torch.zeros((1,)+self.inp_shape, dtype = dtype).to(device), self.weights, self.bias, self.scale, self.padding).shape[1:] 697 | 698 | self.sign_random_readout = QLinearLayerSign(input_features = np.prod(self.out_shape2), output_features = output_neurons, pass_through = False, bias = False, dtype = self.dtype, device = device).to(device) 699 | 700 | # tau quantization, static hardware friendly values 701 | if tau_syn.shape[0] == 2: 702 | self.tau_syn = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_syn[0], tau_syn[1]).to(device) 703 | self.beta = 1. - 1e-3 / self.tau_syn 704 | self.tau_syn = 1. / (1. - self.beta) 705 | else: 706 | self.beta = torch.tensor([1 - delta_t / tau_syn], dtype = dtype).to(device) 707 | self.tau_syn = 1. / (1. - self.beta) 708 | 709 | 710 | if tau_mem.shape[0] == 2: 711 | self.tau_mem = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_mem[0], tau_mem[1]).to(device) 712 | self.alpha = 1. - 1e-3 / self.tau_mem 713 | self.tau_mem = 1. / (1. 
- self.alpha) 714 | else: 715 | self.alpha = torch.tensor([1 - delta_t / tau_mem], dtype = dtype).to(device) 716 | self.tau_mem = 1. / (1. - self.alpha) 717 | 718 | 719 | if tau_ref.shape[0] == 2: 720 | self.tau_ref = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_ref[0], tau_ref[1]).to(device) 721 | self.gamma = 1. - 1e-3 / self.tau_gamma 722 | self.tau_ref = 1. / (1. - self.gamma) 723 | else: 724 | self.gamma = torch.tensor([1 - delta_t / tau_ref], dtype = dtype).to(device) 725 | self.tau_ref = 1. / (1. - self.gamma) 726 | 727 | self.r_scale = 1/(1-self.gamma) # the one comes from decolle, best value ? 728 | #self.q_scale = self.tau_syn/(1-self.beta) 729 | #self.q_scale = self.q_scale.max() 730 | # p_scale should be max overall to differentiate input signals 731 | #self.p_scale = (self.tau_mem * self.q_scale*self.PQ_cap)/(1-self.alpha) 732 | #self.p_scale = self.p_scale.max() 733 | 734 | self.inp_mult_q = self.tau_syn##1/self.PQ_cap * (1-self.beta.max()) # 735 | self.inp_mult_p = self.tau_mem##1/self.PQ_cap * (1-self.alpha.max()) # 736 | #self.pmult = self.p_scale * self.PQ_cap * self.weight_mult 737 | 738 | # those might be clamped as in chop off values. 739 | self.Q_scale = (self.tau_syn/(1-self.beta)).max() 740 | self.P_scale = ((self.tau_mem * self.Q_scale)/(1-self.alpha)).max() 741 | self.Q_scale = (self.tau_syn/(1-self.beta)).max() 742 | self.R_scale = 1/(1-self.gamma) 743 | 744 | if quantization.global_wb is not None: 745 | with torch.no_grad(): 746 | self.weights.data = quantization.quant_w(self.weights.data) 747 | if self.bias is not None: 748 | self.bias.data = quantization.quant_w(self.bias.data) 749 | 750 | 751 | def state_init(self, batch_size): 752 | self.P = torch.zeros((batch_size,) + self.inp_shape, dtype = self.dtype).detach().to(self.device) 753 | self.Q = torch.zeros((batch_size,) + self.inp_shape, dtype = self.dtype).detach().to(self.device) 754 | self.R = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device) 755 | self.S = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device) 756 | self.U = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device) 757 | 758 | 759 | def forward(self, input_t, y_local, train_flag = False, test_flag = False): 760 | # probably dont need to quantize because gb steps are arleady in the right level... just clipping 761 | if quantization.global_gb is not None: 762 | with torch.no_grad(): 763 | self.weights.data = quantization.clip(self.weights.data/self.weight_mult, quantization.global_gb)*self.weight_mult 764 | if self.bias is not None: 765 | self.bias.data = quantization.clip(self.bias.data/self.weight_mult, quantization.global_gb)*self.weight_mult 766 | if quantization.global_rfb is not None: 767 | # R always using full scale? 
768 | self.R = quantization.quant01(self.R/self.R_scale, quantization.global_rfb)*self.R_scale 769 | 770 | #self.P, self.R, self.Q = self.alpha * self.P + self.tau_mem * self.Q, self.gamma * self.R, self.beta * self.Q + self.tau_syn * input_t 771 | #dtype necessary 772 | self.P, self.R, self.Q = self.alpha * self.P + self.inp_mult_p * self.Q, self.gamma * self.R, self.beta * self.Q + self.inp_mult_q * input_t.type(self.dtype) 773 | 774 | if self.PQ_cap != 1: 775 | self.P = torch.clamp(self.P, 0, self.P_scale*self.PQ_cap) 776 | self.Q = torch.clamp(self.Q, 0, self.Q_scale*self.PQ_cap) 777 | 778 | if quantization.global_pb is not None: 779 | self.P = torch.clamp(self.P/(self.P_scale*self.PQ_cap), 0, 1) 780 | self.P = quantization.quant01(self.P, quantization.global_pb)*(self.P_scale*self.PQ_cap) 781 | if quantization.global_qb is not None: 782 | self.Q = torch.clamp(self.Q/(self.Q_scale*self.PQ_cap), 0, 1) 783 | self.Q = quantization.quant01(self.Q, quantization.global_qb)*(self.Q_scale*self.PQ_cap) 784 | 785 | #self.U = QSConv2dFunctional.apply(self.P * self.pmult, self.weights, self.bias, self.scale, self.padding) - self.R 786 | self.U = QSConv2dFunctional.apply(self.P, self.weights, self.bias, self.scale, self.padding, self.weight_mult) - self.R #* self.r_scale 787 | if quantization.global_ub is not None: 788 | self.U = quantU.apply(self.U) 789 | self.S = (self.U >= self.thr).type(self.dtype) 790 | self.S += (self.U <= -self.thr).type(self.dtype)*-1 791 | self.R += self.S * self.thr#(1-self.gamma) 792 | 793 | 794 | if test_flag or train_flag: 795 | self.U_aux = torch.sigmoid(self.U) # quantize this function.... at some point 796 | self.U_aux = self.mpool(self.U_aux) 797 | 798 | rreadout = self.dropout_learning(self.sign_random_readout(self.U_aux.reshape([input_t.shape[0], np.prod(self.out_shape2)]))) * self.dropout_p 799 | 800 | if train_flag: 801 | if quantization.global_eb is not None: 802 | part1 = quantization.SSE(rreadout, y_local) 803 | part2 = self.l1 * 200e-1 * F.relu((self.U+.01)).mean() 804 | part3 = self.l2 *1e-1* F.relu(.1-self.U_aux.mean()) 805 | loss_gen = part1 + part2 + part3 806 | else: 807 | part1 = self.loss_fn(rreadout, y_local) 808 | part2 = self.l1 * 200e-1 * F.relu((self.U+.01)).mean() 809 | part3 = self.l2 *1e-1* F.relu(.1-self.U_aux.mean()) 810 | loss_gen = part1 + part2 + part3 811 | #loss_gen = self.loss_fn(rreadout, y_local) + self.l1 * 200e-1 * F.relu((self.U+.01)).mean() + self.l2 *1e-1* F.relu(.1-self.U_aux.mean()) 812 | else: 813 | loss_gen = None 814 | else: 815 | loss_gen = None 816 | rreadout = torch.tensor([[0]]) 817 | 818 | 819 | return self.mpool(self.S), loss_gen, rreadout.argmax(1), [part1, part2, part3] 820 | 821 | 822 | -------------------------------------------------------------------------------- /prepGesture.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import torch 5 | import pickle 6 | 7 | 8 | 9 | def read_aedat31(filename, labels_f, test_set = False): 10 | # https://inivation.com/support/software/fileformat/#aedat-31 11 | # http://research.ibm.com/dvsgesture/ 12 | gestures_full = [] 13 | labels_full = [] 14 | 15 | # Addresses will be interpreted as 32 bits 16 | print(filename) 17 | f = open(filename, 'r', encoding='latin_1') 18 | labels = np.genfromtxt(labels_f, delimiter=',')[1:] 19 | #Skip header lines 20 | bof = f.tell() 21 | line = f.readline() 22 | while (line[0]=='#'): 23 | print(line, end='') 24 | bof = f.tell() 25 | line = f.readline() 
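For reference, this is what the address unpacking in read_aedat31 does for a single AEDAT 3.1 polarity event: each event is a pair of 32-bit words (address, then timestamp in microseconds), and x, y, and polarity are bit fields of the address word, mirroring the shift-and-mask lines below. A sketch with a fabricated event:

```
import numpy as np

# one fabricated (address, timestamp) pair for illustration
addr = np.uint32((17 << 17) | (42 << 2) | (1 << 1))
ts = np.uint32(123456)                 # microseconds

x        = (addr >> 17) & 0x00001FFF   # pixel x address
y        = (addr >> 2)  & 0x00001FFF   # pixel y address
polarity = (addr >> 1)  & 0x00000001   # ON (1) / OFF (0) event
print(x, y, polarity, ts)              # -> 17 42 1 123456
```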
26 | 27 | # read data 28 | f.seek(bof,0) 29 | dataArray = np.fromfile(f, '> 17 ) & 0x00001FFF 46 | y = ( addr >> 2 ) & 0x00001FFF 47 | polarity = ( addr >> 1 ) & 0x00000001 48 | 49 | # how to access header info 50 | # dataArray[0] >> 16 # event type -> polarity event 51 | # dataArray[0] & 0xFFFF0000 # event source ID 52 | # dataArray[1] # eventSize 53 | # dataArray[2] # eventTSOffset 54 | # dataArray[3] # eventTSOverflow 55 | # dataArray[4] # eventCapacity (always equals eventNumber) 56 | # dataArray[5] # eventNumber (valid + invalid) 57 | # dataArray[6] # eventValid 58 | 59 | stim = np.array([allTs, x, y, polarity]).T#.astype(int) 60 | for i in labels: 61 | 62 | # chop things right 63 | single_gesture = stim[stim[:, 0] >= i[1]] 64 | single_gesture = single_gesture[single_gesture[:, 0] <= i[2]] 65 | 66 | # bin them 1ms 67 | single_gesture[:,0] = np.floor(single_gesture[:,0]/1000) 68 | single_gesture[:,0] = single_gesture[:,0] - np.min(single_gesture[:,0]) 69 | 70 | if test_set: 71 | single_gesture = single_gesture[single_gesture[:,0] <= 1800] 72 | 73 | #if i[0] in labels_full: 74 | # gestures_full[labels_full.index(i[0])] = np.vstack((gestures_full[labels_full.index(i[0])], single_gesture)) 75 | #else: 76 | gestures_full.append(single_gesture) 77 | # record label 78 | labels_full.append(i[0]) 79 | return gestures_full, labels_full 80 | 81 | 82 | # full set 83 | gestures_full = [] 84 | labels_full = [] 85 | with open('trials_to_train.txt') as fp: 86 | for cnt, line in enumerate(fp): 87 | try: 88 | gestures_temp, labels_temp = read_aedat31(line.split(".")[0] + ".aedat", line.split(".")[0] + "_labels.csv") 89 | gestures_full += gestures_temp 90 | labels_full += labels_temp 91 | except: 92 | continue 93 | 94 | with open('train_dvs_gesture.pickle', 'wb') as handle: 95 | pickle.dump((gestures_full, labels_full), handle) 96 | 97 | 98 | 99 | 100 | gestures_full = [] 101 | labels_full = [] 102 | with open('trials_to_test.txt') as fp: 103 | for cnt, line in enumerate(fp): 104 | try: 105 | gestures_temp, labels_temp = read_aedat31(line.split(".")[0] + ".aedat", line.split(".")[0] + "_labels.csv", test_set = True) 106 | gestures_full += gestures_temp 107 | labels_full += labels_temp 108 | except: 109 | continue 110 | 111 | with open('test_dvs_gesture.pickle', 'wb') as handle: 112 | pickle.dump((gestures_full, labels_full), handle) 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /prepPoker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import pickle 4 | import os 5 | 6 | 7 | def dat2mat(filename, retinaSizeX, only_pos=False): 8 | """dat2mat.py: This script converts a aedat file into a list of events. 9 | It only works for 32 unsinged values in the aedat file. 
10 | 11 | filename: name of the dat file 12 | retinaSizeX: one dimension of the retina size 13 | only_pos: True to delete all the negative spikes from the dat file 14 | """ 15 | print('Addresses will be interpreted as 32 bits') 16 | maxEvents = 30e6 17 | numBytesPerEvent = 8 18 | 19 | f = open(filename, 'r', encoding='latin-1') 20 | bof = f.tell() 21 | #Skip header lines 22 | line = f.readline() 23 | while (line[0]=='#'): 24 | print(line) 25 | bof = f.tell() 26 | line = f.readline() 27 | 28 | #Calculate number of events 29 | f.seek(0,2) #EOF 30 | numEvents = (f.tell()-bof)/numBytesPerEvent 31 | if (numEvents>maxEvents): 32 | print("More events than the maximum events!!!") 33 | numEvents = maxEvents 34 | #Read data 35 | f.seek(bof,0) 36 | dataArray = np.fromfile(f, '>u4') 37 | allAddr = dataArray[::2] 38 | allTs = dataArray[1::2] 39 | f.close() 40 | #print allTs 41 | 42 | #Define event format 43 | xmask = 0xFE 44 | ymask = 0x7F00 45 | xshift = 1 46 | yshift = 8 47 | if (retinaSizeX == 32): 48 | xshift=3 #Subsampling of 4 49 | yshift=10 #Subsampling of 4 50 | polmask = 0x1 51 | addr = abs(allAddr) 52 | x = (addr & xmask)>>xshift 53 | y = (addr & ymask)>>yshift 54 | pol = 1 - (2*(addr & polmask)) #1 for ON, -1 for OFF 55 | pol = pol.astype(np.int32) 56 | ''' 57 | #invert x 58 | x = retinaSizeX - x 59 | ''' 60 | #Do relative time 61 | tpo = allTs; 62 | tpo[:] = tpo[:]-tpo[0] 63 | 64 | stim = np.array([tpo, np.zeros(x.size, dtype=np.int), \ 65 | -1*np.ones(x.size, dtype=np.int), x, y, pol]) 66 | stim = np.transpose(stim) 67 | 68 | if (only_pos == True): 69 | res_stim = stim[stim[:,5]==1, :] 70 | else: 71 | res_stim = stim 72 | 73 | # bin them 1ms 74 | res_stim[:,0] = np.floor(res_stim[:,0]/1000) 75 | #res_stim[:,0] = res_stim[:,0] - np.min(res_stim[:,0]) 76 | 77 | return res_stim 78 | 79 | 80 | 81 | chunk_size = 500 82 | chunk_size = 1300 83 | chunk_size = 2400 84 | file_list = ["RetinaTeresa2-club_long.aedat", "RetinaTeresa2-diamond_long.aedat", "RetinaTeresa2-heart_long.aedat", "RetinaTeresa2-spade_long.aedat"] 85 | start_ts = np.arange(0,121000/chunk_size)*chunk_size 86 | end_ts = np.arange(0,121000/chunk_size)*chunk_size + chunk_size #its not 3min... one recording is just 2min! 
81 | chunk_size = 500
82 | chunk_size = 1300
83 | chunk_size = 2400  # the last assignment wins; comment lines out to pick a shorter chunk
84 | file_list = ["RetinaTeresa2-club_long.aedat", "RetinaTeresa2-diamond_long.aedat", "RetinaTeresa2-heart_long.aedat", "RetinaTeresa2-spade_long.aedat"]
85 | start_ts = np.arange(0,121000/chunk_size)*chunk_size
86 | end_ts = np.arange(0,121000/chunk_size)*chunk_size + chunk_size # it's not 3 min... one recording is just 2 min!
87 | cards_full = []
88 | labels_full = []
89 | 
90 | for idx,cur_file in enumerate(file_list):
91 |     stim_cur = dat2mat(cur_file, 128, False)
92 |     for i in np.arange(len(start_ts)):
93 |         temp_cur = stim_cur[stim_cur[:,0] >= start_ts[i]]
94 |         temp_cur = temp_cur[temp_cur[:,0] < end_ts[i]]
95 |         if(len(temp_cur) == 0):
96 |             raise ValueError("no events in chunk %d of %s" % (i, cur_file))
97 |         temp_cur[:,0] = temp_cur[:,0]-start_ts[i]
98 |         cards_full.append(temp_cur)
99 |     labels_full += [idx]*len(start_ts)
100 | 
101 | # 80/20 train/test split
102 | cards_full = np.array(cards_full, dtype=object)  # chunks differ in event count, so the array is ragged
103 | labels_full = np.array(labels_full)
104 | shuffle_idx = np.arange(len(labels_full))
105 | np.random.shuffle(shuffle_idx)
106 | cards_full = cards_full[shuffle_idx]
107 | labels_full = labels_full[shuffle_idx]
108 | 
109 | 
110 | with open('slow_poker_'+str(chunk_size)+'_train.pickle', 'wb') as handle:
111 |     pickle.dump((cards_full[:int(len(labels_full)*.8)], labels_full[:int(len(labels_full)*.8)]), handle)
112 | with open('slow_poker_'+str(chunk_size)+'_test.pickle', 'wb') as handle:
113 |     pickle.dump((cards_full[int(len(labels_full)*.8):], labels_full[int(len(labels_full)*.8):]), handle)
114 | 
115 | 
116 | 
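Both prep scripts bin microsecond timestamps into 1 ms steps with the same floor-divide idiom; in isolation (made-up timestamps):

```
import numpy as np

ts_us = np.array([0, 400, 999, 1000, 2500, 2999])  # event times in microseconds
bins = np.floor(ts_us / 1000).astype(int)          # 1 ms bin index per event
print(bins)  # -> [0 0 0 1 2 2]
```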
-------------------------------------------------------------------------------- /qsnn_decolle.py: --------------------------------------------------------------------------------
1 | import pickle, argparse, time, math, datetime, uuid
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision
7 | import numpy as np
8 | 
9 | import quantization
10 | import localQ
11 | from localQ import sparse_data_generator_Static, sparse_data_generator_DVSGesture, sparse_data_generator_DVSPoker, LIFConv2dLayer, prep_input, acc_comp, create_graph, DTNLIFConv2dLayer, create_graph2
12 | 
13 | 
14 | # Check whether a GPU is available
15 | if torch.cuda.is_available():
16 |     device = torch.device("cuda")
17 | else:
18 |     device = torch.device("cpu")
19 | dtype = torch.float32
20 | ms = 1e-3
21 | 
22 | 
23 | 
24 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
25 | parser.add_argument("--data-set", type=str, default="Gesture", help='Input data set: Poker/Gesture')
26 | 
27 | parser.add_argument("--global_wb", type=int, default=8, help='Weight bitwidth')
28 | parser.add_argument("--global_qb", type=int, default=10, help='Synapse bitwidth')
29 | parser.add_argument("--global_pb", type=int, default=12, help='Membrane trace bitwidth')
30 | parser.add_argument("--global_rfb", type=int, default=2, help='Refractory bitwidth')
31 | 
32 | parser.add_argument("--global_sb", type=int, default=6, help='Learning signal bitwidth')
33 | parser.add_argument("--global_gb", type=int, default=10, help='Gradient bitwidth')
34 | parser.add_argument("--global_eb", type=int, default=6, help='Error bitwidth')
35 | 
36 | parser.add_argument("--global_ub", type=int, default=6, help='Membrane potential bitwidth')
37 | parser.add_argument("--global_ab", type=int, default=6, help='Activation bitwidth')
38 | parser.add_argument("--global_sig", type=int, default=6, help='Sigmoid bitwidth')
39 | 
40 | parser.add_argument("--global_rb", type=int, default=16, help='Gradient RNG bitwidth')
41 | parser.add_argument("--global_lr", type=int, default=1, help='Learning rate for quantized gradients')
42 | parser.add_argument("--global_lr_sgd", type=float, default=1.0e-9, help='Learning rate for SGD')
43 | parser.add_argument("--global_beta", type=float, default=1.5, help='Beta for weight init')
44 | 
45 | parser.add_argument("--delta_t", type=float, default=1*ms, help='Simulation time step in seconds (default: 1 ms)')
46 | parser.add_argument("--input_mode", type=int, default=0, help='Spike processing method')
47 | parser.add_argument("--ds", type=int, default=4, help='Downsampling')
48 | parser.add_argument("--epochs", type=int, default=320, help='Epochs for training')
49 | parser.add_argument("--lr_div", type=int, default=80, help='Learning rate divide interval')
50 | parser.add_argument("--batch_size", type=int, default=72, help='Batch size')
51 | 
52 | parser.add_argument("--PQ_cap", type=float, default=1, help='Value cap for membrane and synapse trace')
53 | parser.add_argument("--weight_mult", type=float, default=4e-5, help='Weight multiplier')
54 | parser.add_argument("--dropout_p", type=float, default=.5, help='Dropout probability')
55 | parser.add_argument("--lc_ampl", type=float, default=.5, help='Magnitude amplifier for weight init')
56 | parser.add_argument("--l1", type=float, default=.001, help='Regularizer 1')
57 | parser.add_argument("--l2", type=float, default=.001, help='Regularizer 2')
58 | 
59 | parser.add_argument("--tau_mem_lower", type=float, default=5, help='Tau mem lower bound (ms)')
60 | parser.add_argument("--tau_mem_upper", type=float, default=35, help='Tau mem upper bound (ms)')
61 | parser.add_argument("--tau_syn_lower", type=float, default=5, help='Tau syn lower bound (ms)')
62 | parser.add_argument("--tau_syn_upper", type=float, default=10, help='Tau syn upper bound (ms)')
63 | parser.add_argument("--tau_ref", type=float, default=1/.35, help='Tau ref (ms)')
64 | 
65 | 
66 | args = parser.parse_args()
67 | 
68 | 
69 | # set quantization levels
70 | quantization.global_wb = args.global_wb
71 | quantization.global_qb = args.global_qb
72 | quantization.global_pb = args.global_pb
73 | quantization.global_rfb = args.global_rfb
74 | 
75 | quantization.global_sb = args.global_sb
76 | quantization.global_gb = args.global_gb
77 | quantization.global_eb = args.global_eb
78 | 
79 | quantization.global_ub = args.global_ub
80 | quantization.global_ab = args.global_ab
81 | quantization.global_sig = args.global_sig
82 | 
83 | quantization.global_rb = args.global_rb
84 | quantization.global_lr = args.global_lr
85 | quantization.global_lr_sgd = args.global_lr_sgd
86 | quantization.global_beta = args.global_beta
87 | quantization.weight_mult = args.weight_mult
88 | 
89 | localQ.lc_ampl = args.lc_ampl
90 | 
91 | tau_mem = torch.tensor([args.tau_mem_lower*ms, args.tau_mem_upper*ms], dtype = dtype).to(device)
92 | tau_ref = torch.tensor([args.tau_ref*ms], dtype = dtype).to(device)
93 | tau_syn = torch.tensor([args.tau_syn_lower*ms, args.tau_syn_upper*ms], dtype = dtype).to(device)
94 | 
95 | 
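The tau bounds above parameterize per-neuron time constants. How such a [lower, upper] range typically becomes decay factors in a discrete-time LIF update is sketched below; this is a generic illustration only, since the actual sampling and state updates live in localQ.LIFConv2dLayer, which is not shown here:

```
import torch

ms = 1e-3
delta_t = 1 * ms
tau_mem = torch.empty(8).uniform_(5 * ms, 35 * ms)  # hypothetical: one tau per neuron
beta = torch.exp(-delta_t / tau_mem)                # membrane decay factors in (0, 1)
print(beta)
```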
96 | if args.data_set == "Poker":
97 |     ds_name = "DVS Poker"
98 |     with open('data/slow_poker_500_train.pickle', 'rb') as f:
99 |         data = pickle.load(f)
100 |     x_train = data[0].tolist()
101 |     for i in range(len(x_train)):
102 |         x_train[i] = x_train[i][:,[0,3,4,5]]
103 |         x_train[i][:,3][x_train[i][:,3] == -1] = 0
104 |         x_train[i] = x_train[i].astype('uint32')
105 |     y_train = data[1]
106 | 
107 |     x_train = np.array(x_train, dtype=object)  # ragged: samples differ in event count
108 |     y_train = np.array(y_train)
109 | 
110 |     idx_temp = np.arange(len(x_train))
111 |     np.random.shuffle(idx_temp)
112 |     idx_train = idx_temp[0:int(len(y_train)*.8)]
113 |     idx_val = idx_temp[int(len(y_train)*.8):]
114 | 
115 |     x_train, x_val = x_train[idx_train], x_train[idx_val]
116 |     y_train, y_val = y_train[idx_train], y_train[idx_val]
117 | 
118 |     with open('data/slow_poker_500_test.pickle', 'rb') as f:
119 |         data = pickle.load(f)
120 |     x_test = data[0].tolist()
121 |     for i in range(len(x_test)):
122 |         x_test[i] = x_test[i][:,[0,3,4,5]]
123 |         x_test[i][:,3][x_test[i][:,3] == -1] = 0
124 |         x_test[i] = x_test[i].astype('uint32')
125 |     y_test = data[1]
126 | 
127 |     output_neurons = 4
128 |     T = 500*ms
129 |     T_test = 500*ms
130 |     burnin = 50*ms
131 |     x_size = 32
132 |     y_size = 32
133 |     train_tflag = True
134 | 
135 | 
136 | 
137 | elif args.data_set == "Gesture":
138 |     ds_name = "DVS Gesture"
139 |     with open('data/train_dvs_gesture88.pickle', 'rb') as f:
140 |         data = pickle.load(f)
141 |     x_train = np.array(data[0], dtype=object)  # ragged: samples differ in event count
142 |     y_train = np.array(data[1], dtype = int) - 1
143 | 
144 |     idx_temp = np.arange(len(x_train))
145 |     np.random.shuffle(idx_temp)
146 |     idx_train = idx_temp[0:int(len(y_train)*.8)]
147 |     idx_val = idx_temp[int(len(y_train)*.8):]
148 | 
149 |     x_train, x_val = x_train[idx_train], x_train[idx_val]
150 |     y_train, y_val = y_train[idx_train], y_train[idx_val]
151 | 
152 | 
153 |     with open('data/test_dvs_gesture88.pickle', 'rb') as f:
154 |         data = pickle.load(f)
155 |     x_test = data[0]
156 |     y_test = np.array(data[1], dtype = int) - 1
157 | 
158 |     output_neurons = 11
159 |     T = 500*ms
160 |     T_test = 1800*ms
161 |     burnin = 50*ms
162 |     x_size = 32
163 |     y_size = 32
164 |     train_tflag = False
165 | else:
166 |     raise Exception("Data set unknown.")
167 | 
168 | sl1_loss = torch.nn.MSELoss()  # readout loss (MSE, despite the name)
169 | 
170 | thr = torch.tensor([.0], dtype = dtype).to(device)
171 | layer1 = LIFConv2dLayer(inp_shape = (2, x_size, y_size), kernel_size = 7, out_channels = 64, tau_mem = tau_mem, tau_syn = tau_syn, tau_ref = tau_ref, delta_t = args.delta_t, pooling = 2, padding = 2, thr = thr, device = device, dropout_p = args.dropout_p, output_neurons = output_neurons, loss_fn = sl1_loss, l1 = args.l1, l2 = args.l2, PQ_cap = args.PQ_cap, weight_mult = args.weight_mult, dtype = dtype).to(device)
172 | 
173 | layer2 = LIFConv2dLayer(inp_shape = layer1.out_shape2, kernel_size = 7, out_channels = 128, tau_mem = tau_mem, tau_syn = tau_syn, tau_ref = tau_ref, delta_t = args.delta_t, pooling = 1, padding = 2, thr = thr, device = device, dropout_p = args.dropout_p, output_neurons = output_neurons, loss_fn = sl1_loss, l1 = args.l1, l2 = args.l2, PQ_cap = args.PQ_cap, weight_mult = args.weight_mult, dtype = dtype).to(device)
174 | 
175 | layer3 = LIFConv2dLayer(inp_shape = layer2.out_shape2, kernel_size = 7, out_channels = 128, tau_mem = tau_mem, tau_syn = tau_syn, tau_ref = tau_ref, delta_t = args.delta_t, pooling = 2, padding = 2, thr = thr, device = device, dropout_p = args.dropout_p, output_neurons = output_neurons, loss_fn = sl1_loss, l1 = args.l1, l2 = args.l2, PQ_cap = args.PQ_cap, weight_mult = args.weight_mult, dtype = dtype).to(device)
176 | 
177 | 
178 | all_parameters = list(layer1.parameters()) + list(layer2.parameters()) + list(layer3.parameters())
179 | 
180 | # initialize optimizer
181 | if quantization.global_gb is not None:
182 |     opt = torch.optim.SGD(all_parameters, lr = 1)
183 | else:
184 |     opt = torch.optim.SGD(all_parameters, lr = quantization.global_lr_sgd)
185 | 
186 | def eval_test():
187 |     batch_corr = {'train1': [], 'test1': [],'train2': [], 'test2': [],'train3': [], 'test3': [], 'loss':[], 'act_train1':0, 'act_train2':0, 'act_train3':0, 'act_test1':0, 'act_test2':0, 'act_test3':0, 'w1u':0, 'w2u':0, 'w3u':0}
188 |     # test accuracy
189 |     for x_local, y_local in sparse_data_generator_DVSGesture(x_test, y_test, batch_size = args.batch_size, nb_steps = T_test / ms, shuffle = True, device = device, test = True, ds = args.ds, x_size = x_size, y_size = y_size):
190 | 
rread_hist1_test = [] 191 | rread_hist2_test = [] 192 | rread_hist3_test = [] 193 | 194 | y_onehot = torch.Tensor(len(y_local), output_neurons).to(device).type(dtype) 195 | y_onehot.zero_() 196 | y_onehot.scatter_(1, y_local.reshape([y_local.shape[0],1]), 1) 197 | 198 | 199 | layer1.state_init(x_local.shape[0]) 200 | layer2.state_init(x_local.shape[0]) 201 | layer3.state_init(x_local.shape[0]) 202 | 203 | for t in range(int(T_test/ms)): 204 | test_flag = (t > int(burnin/ms)) 205 | 206 | out_spikes1, temp_loss1, temp_corr1, _ = layer1.forward(prep_input(x_local[:,:,:,:,t], args.input_mode), y_onehot, test_flag = test_flag) 207 | out_spikes2, temp_loss2, temp_corr2, _ = layer2.forward(out_spikes1, y_onehot, test_flag = test_flag) 208 | out_spikes3, temp_loss3, temp_corr3, _ = layer3.forward(out_spikes2, y_onehot, test_flag = test_flag) 209 | 210 | if test_flag: 211 | rread_hist1_test.append(temp_corr1) 212 | rread_hist2_test.append(temp_corr2) 213 | rread_hist3_test.append(temp_corr3) 214 | 215 | 216 | batch_corr['test1'].append(acc_comp(rread_hist1_test, y_local, True)) 217 | batch_corr['test2'].append(acc_comp(rread_hist2_test, y_local, True)) 218 | batch_corr['test3'].append(acc_comp(rread_hist3_test, y_local, True)) 219 | 220 | return torch.cat(batch_corr['test3']).mean() 221 | 222 | 223 | w1, w2, w3, b1, b2, b3 = None, None, None, None, None, None 224 | 225 | diff_layers_acc = {'train1': [], 'test1': [],'train2': [], 'test2': [],'train3': [], 'test3': [], 'loss':[], 'act_train1':[], 'act_train2':[], 'act_train3':[], 'act_test1':[], 'act_test2':[], 'act_test3':[], 'w1update':[], 'w2update':[], 'w3update':[]} 226 | print("WUPQR SASigEG Quantization: {0}{1}{2}{3}{4} {5}{6}{7}{8}{9} l1 {10:.3f} l2 {11:.3f} Inp {12} LR {13} Drop {14} Cap {15} thr {16}".format(quantization.global_wb, quantization.global_ub, quantization.global_pb, quantization.global_qb, quantization.global_rfb, quantization.global_sb, quantization.global_ab, quantization.global_sig, quantization.global_eb, quantization.global_gb, args.l1, args.l2, args.input_mode, quantization.global_lr if quantization.global_lr != None else quantization.global_lr_sgd, args.dropout_p, args.PQ_cap, thr.item())) 227 | plot_file_name = "DVS_WPQUEG{0}{1}{2}{3}{4}{5}{6}_Inp{7}_LR{8}_Drop{9}_thr{10}".format(quantization.global_wb, quantization.global_pb, quantization.global_qb, quantization.global_ub, quantization.global_eb, quantization.global_gb, quantization.global_sb, args.input_mode, quantization.global_lr, args.dropout_p, thr.item())+datetime.datetime.now().strftime("_%Y%m%d_%H%M%S") 228 | print("Epoch Loss Train1 Train2 Train3 Test1 Test2 Test3 | TrainT TestT") 229 | 230 | best_vali = torch.tensor(0, device = device) 231 | 232 | for e in range(args.epochs): 233 | if ((e+1)%args.lr_div)==0: 234 | if quantization.global_gb is not None: 235 | quantization.global_lr /= 2 236 | else: 237 | opt.param_groups[-1]['lr'] /= 5 238 | 239 | 240 | batch_corr = {'train1': [], 'test1': [],'train2': [], 'test2': [],'train3': [], 'test3': [], 'loss':[], 'act_train1':0, 'act_train2':0, 'act_train3':0, 'act_test1':0, 'act_test2':0, 'act_test3':0, 'w1u':0, 'w2u':0, 'w3u':0} 241 | quantization.global_w1update = 0 242 | quantization.global_w2update = 0 243 | quantization.global_w3update = 0 244 | start_time = time.time() 245 | 246 | # training 247 | for x_local, y_local in sparse_data_generator_DVSGesture(x_train, y_train, batch_size = args.batch_size, nb_steps = T / ms, shuffle = True, test = train_tflag, device = device, ds = args.ds, x_size = x_size, y_size = 
y_size): 248 | 249 | y_onehot = torch.Tensor(len(y_local), output_neurons).to(device).type(dtype) 250 | y_onehot.zero_() 251 | y_onehot.scatter_(1, y_local.reshape([y_local.shape[0],1]), 1) 252 | 253 | rread_hist1_train = [] 254 | rread_hist2_train = [] 255 | rread_hist3_train = [] 256 | loss_hist = [] 257 | 258 | 259 | layer1.state_init(x_local.shape[0]) 260 | layer2.state_init(x_local.shape[0]) 261 | layer3.state_init(x_local.shape[0]) 262 | 263 | for t in range(int(T/ms)): 264 | train_flag = (t > int(burnin/ms)) 265 | 266 | out_spikes1, temp_loss1, temp_corr1, lparts1 = layer1.forward(prep_input(x_local[:,:,:,:,t], args.input_mode), y_onehot, train_flag = train_flag) 267 | out_spikes2, temp_loss2, temp_corr2, lparts2 = layer2.forward(out_spikes1, y_onehot, train_flag = train_flag) 268 | out_spikes3, temp_loss3, temp_corr3, lparts3 = layer3.forward(out_spikes2, y_onehot, train_flag = train_flag) 269 | 270 | 271 | 272 | if train_flag: 273 | loss_gen = temp_loss1 + temp_loss2 + temp_loss3 274 | 275 | loss_gen.backward() 276 | opt.step() 277 | opt.zero_grad() 278 | 279 | loss_hist.append(loss_gen.item()) 280 | rread_hist1_train.append(temp_corr1) 281 | rread_hist2_train.append(temp_corr2) 282 | rread_hist3_train.append(temp_corr3) 283 | 284 | 285 | batch_corr['act_train1'] += int(out_spikes1.sum()) 286 | batch_corr['act_train2'] += int(out_spikes2.sum()) 287 | batch_corr['act_train3'] += int(out_spikes3.sum()) 288 | 289 | 290 | batch_corr['train1'].append(acc_comp(rread_hist1_train, y_local, True)) 291 | batch_corr['train2'].append(acc_comp(rread_hist2_train, y_local, True)) 292 | batch_corr['train3'].append(acc_comp(rread_hist3_train, y_local, True)) 293 | del x_local, y_local, y_onehot 294 | 295 | 296 | train_time = time.time() 297 | 298 | diff_layers_acc['train1'].append(torch.cat(batch_corr['train1']).mean()) 299 | diff_layers_acc['train2'].append(torch.cat(batch_corr['train2']).mean()) 300 | diff_layers_acc['train3'].append(torch.cat(batch_corr['train3']).mean()) 301 | diff_layers_acc['act_train1'].append(batch_corr['act_train1']) 302 | diff_layers_acc['act_train2'].append(batch_corr['act_train2']) 303 | diff_layers_acc['act_train3'].append(batch_corr['act_train3']) 304 | diff_layers_acc['loss'].append(np.mean(loss_hist)/3) 305 | diff_layers_acc['w1update'].append(quantization.global_w1update) 306 | diff_layers_acc['w2update'].append(quantization.global_w2update) 307 | diff_layers_acc['w3update'].append(quantization.global_w3update) 308 | 309 | 310 | # test accuracy 311 | for x_local, y_local in sparse_data_generator_DVSGesture(x_val, y_val, batch_size = args.batch_size, nb_steps = T_test / ms, shuffle = True, device = device, test = True, ds = args.ds, x_size = x_size, y_size = y_size): 312 | rread_hist1_test = [] 313 | rread_hist2_test = [] 314 | rread_hist3_test = [] 315 | 316 | y_onehot = torch.Tensor(len(y_local), output_neurons).to(device).type(dtype) 317 | y_onehot.zero_() 318 | y_onehot.scatter_(1, y_local.reshape([y_local.shape[0],1]), 1) 319 | 320 | 321 | layer1.state_init(x_local.shape[0]) 322 | layer2.state_init(x_local.shape[0]) 323 | layer3.state_init(x_local.shape[0]) 324 | 325 | for t in range(int(T_test/ms)): 326 | test_flag = (t > int(burnin/ms)) 327 | 328 | out_spikes1, temp_loss1, temp_corr1, _ = layer1.forward(prep_input(x_local[:,:,:,:,t], args.input_mode), y_onehot, test_flag = test_flag) 329 | out_spikes2, temp_loss2, temp_corr2, _ = layer2.forward(out_spikes1, y_onehot, test_flag = test_flag) 330 | out_spikes3, temp_loss3, temp_corr3, _ = 
layer3.forward(out_spikes2, y_onehot, test_flag = test_flag)
331 | 
332 |             if test_flag:
333 |                 rread_hist1_test.append(temp_corr1)
334 |                 rread_hist2_test.append(temp_corr2)
335 |                 rread_hist3_test.append(temp_corr3)
336 | 
337 |             batch_corr['act_test1'] += int(out_spikes1.sum())
338 |             batch_corr['act_test2'] += int(out_spikes2.sum())
339 |             batch_corr['act_test3'] += int(out_spikes3.sum())
340 | 
341 |         batch_corr['test1'].append(acc_comp(rread_hist1_test, y_local, True))
342 |         batch_corr['test2'].append(acc_comp(rread_hist2_test, y_local, True))
343 |         batch_corr['test3'].append(acc_comp(rread_hist3_test, y_local, True))
344 |         del x_local, y_local, y_onehot
345 | 
346 |     inf_time = time.time()
347 | 
348 |     if best_vali.item() <= torch.cat(batch_corr['test3']).mean().item():  # '<=' so test_acc_best_vali is always set in epoch one
349 |         best_vali = torch.cat(batch_corr['test3']).mean()
350 |         test_acc_best_vali = eval_test()
351 |         w1 = layer1.weights.data.detach().cpu()
352 |         w2 = layer2.weights.data.detach().cpu()
353 |         w3 = layer3.weights.data.detach().cpu()
354 |         b1 = layer1.bias.data.detach().cpu()
355 |         b2 = layer2.bias.data.detach().cpu()
356 |         b3 = layer3.bias.data.detach().cpu()
357 | 
358 |     diff_layers_acc['test1'].append(torch.cat(batch_corr['test1']).mean())
359 |     diff_layers_acc['test2'].append(torch.cat(batch_corr['test2']).mean())
360 |     diff_layers_acc['test3'].append(torch.cat(batch_corr['test3']).mean())
361 |     diff_layers_acc['act_test1'].append(batch_corr['act_test1'])
362 |     diff_layers_acc['act_test2'].append(batch_corr['act_test2'])
363 |     diff_layers_acc['act_test3'].append(batch_corr['act_test3'])
364 | 
365 |     print("{0:02d} {1:.3E} {2:.4f} {3:.4f} {4:.4f} {5:.4f} {6:.4f} {7:.4f} | {8:.4f} {9:.4f}".format(e+1, diff_layers_acc['loss'][-1], diff_layers_acc['train1'][-1], diff_layers_acc['train2'][-1], diff_layers_acc['train3'][-1], diff_layers_acc['test1'][-1], diff_layers_acc['test2'][-1], diff_layers_acc['test3'][-1], train_time - start_time, inf_time - train_time))
366 |     create_graph(plot_file_name, diff_layers_acc, ds_name, test_acc_best_vali)
367 | 
368 | 
369 | 
370 | # saving results and weights
371 | results = {
372 |     'layer1':[layer1.weights.detach().cpu(), layer1.bias.detach().cpu(), w1, b1, layer1.sign_random_readout.weights.detach().cpu(), layer1.sign_random_readout.weight_fa.detach().cpu(), layer1.tau_mem.cpu(), layer1.tau_syn.cpu(), layer1.tau_ref.cpu()],
373 |     'layer2':[layer2.weights.detach().cpu(), layer2.bias.detach().cpu(), w2, b2, layer2.sign_random_readout.weights.detach().cpu(), layer2.sign_random_readout.weight_fa.detach().cpu(), layer2.tau_mem.cpu(), layer2.tau_syn.cpu(), layer2.tau_ref.cpu()],
374 |     'layer3':[layer3.weights.detach().cpu(), layer3.bias.detach().cpu(), w3, b3, layer3.sign_random_readout.weights.detach().cpu(), layer3.sign_random_readout.weight_fa.detach().cpu(), layer3.tau_mem.cpu(), layer3.tau_syn.cpu(), layer3.tau_ref.cpu()],
375 |     'acc': diff_layers_acc, 'fname':plot_file_name, 'args': args, 'evaled_test':test_acc_best_vali}
376 | with open('results/'+plot_file_name+'.pkl', 'wb') as f:
377 |     pickle.dump(results, f)
378 | 
379 | 
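The one-hot targets built repeatedly in qsnn_decolle.py use scatter_; the pattern in isolation, with made-up labels and class count:

```
import torch

y_local = torch.tensor([2, 0, 1])                  # three samples, labels in 0..3
y_onehot = torch.zeros(len(y_local), 4)
y_onehot.scatter_(1, y_local.reshape([-1, 1]), 1)  # write a 1 at each label's column
print(y_onehot)
```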
-------------------------------------------------------------------------------- /qsnn_precise.py: --------------------------------------------------------------------------------
1 | import argparse, pickle
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torchvision
6 | import numpy as np
7 | 
8 | import qsnn_util as spytorch_util  # qsnn_util provides the shared w1/w2 weight globals
9 | import quantization
10 | 
11 | dtype = torch.float
12 | 
13 | # Check whether a GPU is available
14 | if torch.cuda.is_available():
15 |     device = torch.device("cuda")
16 | else:
17 |     device = torch.device("cpu")
18 | 
19 | # Code is based on: https://github.com/fzenke/spytorch
20 | 
21 | 
22 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
23 | 
24 | parser.add_argument("--input", type=str, default="./data/input_700_250_25.pkl", help='Input pickle')
25 | parser.add_argument("--target", type=str, default="./data/smile95.pkl", help='Target pattern pickle')
26 | 
27 | parser.add_argument("--global_wb", type=int, default=2, help='Weight bitwidth')
28 | parser.add_argument("--global_ab", type=int, default=8, help='Membrane potential, synapse state bitwidth')
29 | parser.add_argument("--global_gb", type=int, default=8, help='Gradient bitwidth')
30 | parser.add_argument("--global_eb", type=int, default=8, help='Error bitwidth')
31 | parser.add_argument("--global_rb", type=int, default=16, help='Gradient RNG bitwidth')
32 | 
33 | parser.add_argument("--time_step", type=float, default=1e-3, help='Simulation time step size')
34 | parser.add_argument("--nb_steps", type=int, default=250, help='Simulation steps')
35 | parser.add_argument("--nb_epochs", type=int, default=10000, help='Training epochs')
36 | 
37 | parser.add_argument("--tau_mem", type=float, default=10e-3, help='Time constant for membrane potential')
38 | parser.add_argument("--tau_syn", type=float, default=5e-3, help='Time constant for synapse')
39 | parser.add_argument("--tau_vr", type=float, default=5e-3, help='Time constant for Van Rossum distance')
40 | parser.add_argument("--alpha", type=float, default=.75, help='Synaptic decay factor')
41 | parser.add_argument("--beta", type=float, default=.875, help='Membrane decay factor')
42 | 
43 | parser.add_argument("--nb_inputs", type=int, default=700, help='Spatial input dimensions')
44 | parser.add_argument("--nb_hidden", type=int, default=400, help='Spatial hidden dimensions')
45 | parser.add_argument("--nb_outputs", type=int, default=250, help='Spatial output dimensions')
46 | 
47 | args = parser.parse_args()
48 | 
49 | 
50 | quantization.global_wb = args.global_wb
51 | quantization.global_ab = args.global_ab
52 | quantization.global_gb = args.global_gb
53 | quantization.global_eb = args.global_eb
54 | quantization.global_rb = args.global_rb
55 | stop_quant_level = 33  # bitwidths >= 33 disable quantization for that quantity
56 | 
57 | time_step = args.time_step
58 | nb_steps = args.nb_steps
59 | tau_mem = args.tau_mem
60 | tau_syn = args.tau_syn
61 | tau_vr = args.tau_vr
62 | 
63 | alpha = args.alpha
64 | beta = args.beta
65 | 
66 | nb_inputs = args.nb_inputs
67 | nb_hidden = args.nb_hidden
68 | nb_outputs = args.nb_outputs
69 | 
70 | def conv_exp_kernel(inputs, time_step, tau, device):  # note: the decay uses the global alpha, not the tau argument
71 |     dtype = torch.float
72 |     nb_hidden = inputs.shape[1]
73 |     nb_steps = inputs.shape[0]
74 | 
75 |     u = torch.zeros((nb_hidden), device=device, dtype=dtype)
76 |     rec_u = []
77 | 
78 |     for t in range(nb_steps):
79 |         u = alpha*u + inputs[t,:]
80 |         rec_u.append(u)
81 | 
82 |     rec_u = torch.stack(rec_u,dim=0)
83 |     return rec_u
84 | 
85 | def van_rossum(x, y, time_step, tau, device):
86 |     tild_x = conv_exp_kernel(x, time_step, tau, device)
87 |     tild_y = conv_exp_kernel(y, time_step, tau, device)
88 |     return torch.sqrt(1/tau*torch.sum((tild_x - tild_y)**2))
89 | 
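A miniature version of the Van Rossum distance computed by van_rossum above: filter both spike trains with the exponential kernel (decay alpha) and take the L2 norm of the difference. The trains are made up, and the 1/tau scaling is omitted for brevity:

```
import torch

alpha = 0.75
a = torch.tensor([0., 1., 0., 0., 1.])
b = torch.tensor([0., 0., 1., 0., 1.])

def filt(s, alpha):
    u, out = torch.tensor(0.), []
    for t in range(len(s)):
        u = alpha * u + s[t]   # leaky accumulation, as in conv_exp_kernel
        out.append(u)
    return torch.stack(out)

d = torch.sqrt(torch.sum((filt(a, alpha) - filt(b, alpha)) ** 2))
print(d)  # identical trains would give 0
```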
90 | class SuperSpike(torch.autograd.Function):
91 |     scale = 100.0 # controls steepness of surrogate gradient
92 |     @staticmethod
93 |     def forward(ctx, input):
94 |         ctx.save_for_backward(input)
95 |         out = torch.zeros_like(input)
96 |         out[input > 0] = 1.0
97 |         return out
98 | 
99 |     @staticmethod
100 |     def backward(ctx, grad_output):
101 |         input, = ctx.saved_tensors
102 |         grad_input = grad_output.clone()
103 |         grad = grad_input/(SuperSpike.scale*torch.abs(input)+1.0)**2
104 |         return grad
105 | 
106 | 
107 | class einsum_linear(torch.autograd.Function):
108 |     @staticmethod
109 |     def forward(ctx, input, weight, scale, bias=None):
110 |         if quantization.global_wb < stop_quant_level:
111 |             w_quant = quantization.quant_w(weight, scale)
112 |         else:
113 |             w_quant = weight
114 | 
115 |         h1 = torch.einsum("bc,cd->bd", (input, w_quant))
116 | 
117 |         if bias is not None:
118 |             h1 += bias.unsqueeze(0).expand_as(h1)  # bug fix: previously referenced an undefined `output`
119 | 
120 |         ctx.save_for_backward(input, w_quant, bias)
121 | 
122 |         return h1
123 | 
124 |     @staticmethod
125 |     def backward(ctx, grad_output):
126 |         input, w_quant, bias = ctx.saved_tensors
127 |         grad_input = grad_weight = grad_bias = None
128 |         if quantization.global_eb < stop_quant_level:
129 |             quant_error = quantization.quant_err(grad_output)
130 |         else:
131 |             quant_error = grad_output
132 | 
133 |         if ctx.needs_input_grad[0]:
134 |             # propagate quantized error
135 |             grad_input = torch.einsum("bc,dc->bd", (quant_error, w_quant))
136 | 
137 |         if ctx.needs_input_grad[1]:
138 |             if quantization.global_gb < stop_quant_level:
139 |                 grad_weight = quantization.quant_grad(torch.einsum("bc,bd->dc", (quant_error, input))).float()
140 |             else:
141 |                 grad_weight = torch.einsum("bc,bd->dc", (quant_error, input))
142 | 
143 |         if bias is not None and ctx.needs_input_grad[3]:
144 |             grad_bias = grad_output.sum(0).squeeze(0)
145 | 
146 |         return grad_input, grad_weight, None, grad_bias  # one grad per forward input: input, weight, scale, bias
147 | 
148 | 
149 | class custom_quant(torch.autograd.Function):
150 |     @staticmethod
151 |     def forward(ctx, input, b_level):
152 |         if quantization.global_ab < stop_quant_level:
153 |             output, clip_info = quantization.quant_act(input)
154 |         else:
155 |             output, clip_info = input, None
156 |         ctx.save_for_backward(clip_info)
157 |         return output
158 | 
159 |     @staticmethod
160 |     def backward(ctx, grad_output):
161 |         clip_info = ctx.saved_tensors
162 |         if quantization.global_eb < stop_quant_level:
163 |             quant_error = quantization.quant_err(grad_output) * clip_info[0].float()
164 |         else:
165 |             quant_error = grad_output
166 |         return quant_error, None
167 | 
168 | 
169 | def run_snn(inputs):
170 |     with torch.no_grad():
171 |         spytorch_util.w1.data = quantization.clip(spytorch_util.w1.data, quantization.global_wb)
172 |         spytorch_util.w2.data = quantization.clip(spytorch_util.w2.data, quantization.global_wb)
173 | 
174 | 
175 |     h1 = einsum_linear.apply(inputs, spytorch_util.w1, scale1)
176 | 
177 |     syn = torch.zeros((nb_hidden), device=device, dtype=dtype)
178 |     mem = torch.zeros((nb_hidden), device=device, dtype=dtype)
179 | 
180 |     mem_rec = []
181 |     spk_rec = []
182 | 
183 |     # Compute hidden layer activity
184 |     for t in range(nb_steps):
185 |         mthr = mem-.9
186 |         mthr = custom_quant.apply(mthr, quantization.global_ab)
187 |         out = spike_fn(mthr)
188 | 
189 |         rst = torch.zeros_like(mem)
190 |         c = (mthr > 0)
191 |         rst[c] = torch.ones_like(mem)[c]
192 | 
193 |         new_syn = alpha*syn + h1[t,:]
194 |         new_syn = custom_quant.apply(new_syn, quantization.global_ab)
195 |         new_mem = beta*mem + syn - rst
196 |         new_mem = custom_quant.apply(new_mem, quantization.global_ab)
197 | 
198 |         syn = new_syn
199 |         mem = new_mem
200 | 
201 |         mem_rec.append(mem)
202 |         spk_rec.append(out)
203 | 
204 |     mem_rec1 = torch.stack(mem_rec,dim=0)
205 |     spk_rec1 = torch.stack(spk_rec,dim=0)
206 | 
207 | 
208 |     # Readout layer
209 |     h2 = einsum_linear.apply(spk_rec1, spytorch_util.w2, 
scale2) 210 | 211 | syn = torch.zeros((nb_outputs), device=device, dtype=dtype) 212 | mem = torch.zeros((nb_outputs), device=device, dtype=dtype) 213 | 214 | mem_rec = [] 215 | spk_rec = [] 216 | 217 | for t in range(nb_steps): 218 | mthr = mem-.9 219 | mthr = custom_quant.apply(mthr, quantization.global_ab) 220 | out = spike_fn(mthr) 221 | 222 | rst = torch.zeros_like(mem) 223 | c = (mthr > 0) 224 | rst[c] = torch.ones_like(mem)[c] 225 | 226 | new_syn = alpha*syn +h2[t,:] 227 | new_syn = custom_quant.apply(new_syn, quantization.global_ab) 228 | new_mem = beta*mem +syn -rst 229 | new_mem = custom_quant.apply(new_mem, quantization.global_ab) 230 | 231 | mem = new_mem 232 | syn = new_syn 233 | 234 | mem_rec.append(mem) 235 | spk_rec.append(out) 236 | 237 | mem_rec2 = torch.stack(mem_rec,dim=0) 238 | spk_rec2 = torch.stack(spk_rec,dim=0) 239 | 240 | 241 | other_recs = [mem_rec1, spk_rec1, mem_rec2] 242 | return spk_rec2, other_recs 243 | 244 | 245 | def train(x_data, y_data, lr=1e-3, nb_epochs=10): 246 | params = [spytorch_util.w1,spytorch_util.w2] 247 | optimizer = torch.optim.Adamax(params, lr=lr, betas=(0.9,0.999)) 248 | 249 | loss_hist = [] 250 | acc_hist = [] 251 | for e in range(nb_epochs): 252 | output,recs = run_snn(x_data) 253 | loss_val = van_rossum(output, y_data, time_step, tau_syn, device) 254 | 255 | optimizer.zero_grad() 256 | loss_val.backward() 257 | optimizer.step() 258 | 259 | loss_hist.append(loss_val.item()) 260 | print("Epoch %i: loss=%.5f"%(e+1,loss_val.item())) 261 | 262 | return loss_hist, output 263 | 264 | spike_fn = SuperSpike.apply 265 | 266 | 267 | quantization.global_beta = quantization.step_d(quantization.global_wb)-.5 268 | with open(args.input, 'rb') as f: 269 | x_train = pickle.load(f).t().to(device) 270 | 271 | 272 | with open(args.target, 'rb') as f: 273 | y_train = torch.tensor(pickle.load(f)).to(device) 274 | y_train = y_train.type(dtype) 275 | 276 | 277 | bit_string = str(quantization.global_wb) + str(quantization.global_ab) + str(quantization.global_gb) + str(quantization.global_eb) 278 | 279 | print("Start Training") 280 | print(bit_string) 281 | 282 | spytorch_util.w1 = torch.empty((nb_inputs, nb_hidden), device=device, dtype=dtype, requires_grad=True) 283 | scale1 = quantization.init_layer_weights(spytorch_util.w1, 28*28).to(device) 284 | 285 | spytorch_util.w2 = torch.empty((nb_hidden, nb_outputs), device=device, dtype=dtype, requires_grad=True) 286 | scale2 = quantization.init_layer_weights(spytorch_util.w2, 28*28).to(device) 287 | 288 | 289 | quantization.global_lr = .1 290 | loss_hist, output = train(x_train, y_train, lr = 1, nb_epochs = args.nb_epochs) 291 | 292 | bit_string = str(quantization.global_wb) + str(quantization.global_ab) + str(quantization.global_gb) + str(quantization.global_eb) 293 | 294 | results = {'bit_string': bit_string ,'loss_hist': loss_hist, 'output': output.cpu()} 295 | 296 | with open('results/snn_smile_precise_'+bit_string+'.pkl', 'wb') as f: 297 | pickle.dump(results, f) 298 | 299 | 300 | 301 | -------------------------------------------------------------------------------- /qsnn_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.gridspec import GridSpec 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torchvision 10 | 11 | 12 | time_step = 1e-3 13 | 14 | # Check whether a GPU is available 15 | if torch.cuda.is_available(): 16 | device = torch.device("cuda") 17 | else: 18 | 
device = torch.device("cpu")
19 | 
20 | w1 = None
21 | w2 = None
22 | 
23 | 
24 | def current2firing_time(x, tau=20, thr=0.2, tmax=1.0, epsilon=1e-7):
25 |     """ Computes first firing time latency for a current input x assuming the charge time of a current based LIF neuron.
26 | 
27 |     Args:
28 |     x -- The "current" values
29 | 
30 |     Keyword args:
31 |     tau -- The membrane time constant of the LIF neuron to be charged
32 |     thr -- The firing threshold value
33 |     tmax -- The maximum time returned
34 |     epsilon -- A generic (small) epsilon > 0
35 | 
36 |     Returns:
37 |     Time to first spike for each "current" x
38 |     """
39 |     idx = x < thr
126 |         out[input > 0] = 1.0
127 |         return out
128 | 
129 |     @staticmethod
130 |     def backward(ctx, grad_output):
131 |         """
132 |         In the backward pass we receive a Tensor and compute the
133 |         surrogate gradient of the loss with respect to the input.
134 |         Here we use the normalized negative part of a fast sigmoid
135 |         as this was done in Zenke & Ganguli (2018).
136 |         """
137 |         input, = ctx.saved_tensors
138 |         grad_input = grad_output.clone()
139 |         grad = grad_input/(SuperSpike.scale*torch.abs(input)+1.0)**2
140 |         return grad
141 | 
142 | 
143 | def sparse_data_generator_DVS(X, y, batch_size, nb_steps, nb_units, shuffle, time_step, device):
144 |     """ This generator takes datasets in analog format and generates spiking network input as sparse tensors.
145 | 
146 |     Args:
147 |     X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
148 |     y: The labels
149 |     """
150 | 
151 |     try:
152 |         labels_ = np.array(y.cpu(), dtype=int)
153 |     except:  # y may already be a plain array on the CPU
154 |         labels_ = np.array(y, dtype=int)
155 |     number_of_batches = len(y)//batch_size
156 |     sample_index = np.arange(len(y))
157 | 
158 | 
159 |     if shuffle:
160 |         np.random.shuffle(sample_index)
161 | 
162 |     total_batch_count = 0
163 |     counter = 0
164 |     while counter < number_of_batches:
-------------------------------------------------------------------------------- /quantization.py: --------------------------------------------------------------------------------
151 | def quant_err(x):
152 |     # if (x > 1).sum() != 0:
153 |     #     import pdb; pdb.set_trace()
154 |     alpha = shift(torch.max(torch.abs(x)))
155 |     return quant(clip(x / alpha, global_eb), global_eb)
156 | 
157 | def init_layer_weights(weights_layer, shape, factor=1):
158 |     fan_in = shape
159 | 
160 |     limit = torch.sqrt(torch.tensor([3*factor/fan_in]))
161 |     Wm = global_beta/step_d(torch.tensor([float(global_wb)]))
162 |     scale = 2 ** round(math.log(Wm / limit, 2.0))
163 |     scale = scale if scale > 1 else 1.0
164 |     limit = Wm if Wm > limit else limit
165 | 
166 |     torch.nn.init.uniform_(weights_layer, a = -float(limit), b = float(limit))
167 |     weights_layer.data = quant_generic(weights_layer.data, global_gb)[0]
168 |     return torch.tensor([float(scale)])
169 | 
170 | # sum of squared errors
171 | def SSE(y_true, y_pred):
172 |     return 0.5 * torch.sum((y_true - y_pred)**2)
173 | 
174 | def to_cat(inp_tensor, num_class, device):
175 |     out_tensor = torch.zeros([inp_tensor.shape[0], num_class], device=device)
176 |     out_tensor[torch.arange(inp_tensor.shape[0]).to(device), torch.tensor(inp_tensor, dtype = int, device=device)] = 1
177 |     return out_tensor
178 | 
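The helpers quant_err relies on (step_d, shift, clip, quant) were lost in the garbled stretch above. The sketch below shows the usual WAGE-style definitions they follow; it is a self-contained approximation, not the repository's exact code:

```
import torch

def step_d(bits):                  # quantization step count for `bits`
    return 2.0 ** (bits - 1)

def shift(x):                      # round a scale to the nearest power of two
    return 2.0 ** torch.round(torch.log2(x))

def clip(x, bits):
    delta = 1.0 / step_d(bits)
    return torch.clamp(x, -1 + delta, 1 - delta)

def quant(x, bits):                # fixed-point rounding
    scale = step_d(bits)
    return torch.round(x * scale) / scale

global_eb = 8
def quant_err(x):                  # mirrors the visible quant_err body above
    alpha = shift(torch.max(torch.abs(x)))
    return quant(clip(x / alpha, global_eb), global_eb)

print(quant_err(torch.tensor([0.3, -1.7, 0.05])))
```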
179 | # Inherit from Function
180 | class clee_LinearFunction(torch.autograd.Function):
181 | 
182 |     # Note that both forward and backward are @staticmethods
183 |     @staticmethod
184 |     # bias is an optional argument
185 |     def forward(ctx, input, weight, scale, act, act_q, bias=None):
186 |         # prep and save
187 |         w_quant = quant_w(weight, scale)
188 |         input = input.float()
189 | 
190 |         # compute output
191 |         output = input.mm(w_quant.t())
192 | 
193 |         relu_mask = torch.ones(output.shape).to(output.device)
194 |         clip_info = torch.ones(output.shape).to(output.device)
195 | 
196 |         # add relu and quant optionally
197 |         if act:
198 |             output = F.relu(output)
199 |             relu_mask = (output != 0)
200 |         if act_q:
201 |             output, clip_info = quant_act(output)
202 |         if bias is not None:
203 |             output += bias.unsqueeze(0).expand_as(output)
204 | 
205 |         gradient_mask = relu_mask * clip_info
206 | 
207 |         ctx.save_for_backward(input, w_quant, bias, gradient_mask)
208 |         return output
209 | 
210 |     # This function has only a single output, so it gets only one gradient
211 |     @staticmethod
212 |     def backward(ctx, grad_output):
213 |         input, w_quant, bias, gradient_mask = ctx.saved_tensors
214 |         grad_input = grad_weight = grad_bias = None
215 |         quant_error = quant_err(grad_output) * gradient_mask.float()
216 | 
217 |         if ctx.needs_input_grad[0]:
218 |             # propagate quantized error
219 |             grad_input = quant_error.mm(w_quant)
220 |         if ctx.needs_input_grad[1]:
221 |             grad_weight = quant_grad(quant_error.t().mm(input)).float()
222 | 
223 |         if bias is not None and ctx.needs_input_grad[5]:
224 |             grad_bias = grad_output.sum(0).squeeze(0)
225 | 
226 |         return grad_input, grad_weight, None, None, None, grad_bias  # one grad per forward input: input, weight, scale, act, act_q, bias
227 | 
228 | # Inherit from Function
229 | class clee_conv2d(torch.autograd.Function):
230 |     # Note that both forward and backward are @staticmethods
231 |     @staticmethod
232 |     # bias is an optional argument
233 |     def forward(ctx, input, weight, scale, act=False, act_q=False, pool=False, bias=None):
234 |         mpool1 = nn.MaxPool2d(2, stride=2, return_indices=True)
235 | 
236 |         # prep and save
237 |         w_quant = quant_w(weight, scale)
238 |         input = input.float()
239 | 
240 |         # compute output
241 |         output = F.conv2d(input, w_quant, bias=None, stride=1, padding=0, dilation=1, groups=1)
242 |         pool_indices = torch.ones(output.shape).to(output.device)
243 |         size_pool = torch.tensor([0])
244 | 
245 |         # add pool, relu, quant optionally
246 |         if pool:
247 |             size_pool = output.shape
248 |             output, pool_indices = mpool1(output)
249 |         relu_mask = torch.ones(output.shape).to(output.device)  # created after pooling so the mask shape matches grad_output in backward
250 |         clip_info = torch.ones(output.shape).to(output.device)
251 |         if act:
252 |             output = F.relu(output)
253 |             relu_mask = (output != 0)
254 |         if act_q:
255 |             output, clip_info = quant_act(output)
256 |         if bias is not None:
257 |             output += bias.unsqueeze(0).expand_as(output)
258 | 
259 |         gradient_mask = relu_mask * clip_info
260 | 
261 |         ctx.save_for_backward(input, w_quant, bias, torch.tensor([pool]), gradient_mask, pool_indices, torch.tensor(size_pool))
262 |         return output
263 | 
264 |     # This function has only a single output, so it gets only one gradient
265 |     @staticmethod
266 |     def backward(ctx, grad_output):
267 |         unpool1 = nn.MaxUnpool2d(2, stride=2, padding = 0)
268 | 
269 |         input, w_quant, bias, pool, gradient_mask, pool_indices, size_pool = ctx.saved_tensors
270 |         grad_input = grad_weight = grad_bias = None
271 | 
272 |         grad_output = grad_output * gradient_mask.float()
273 |         if pool:
274 |             grad_output = unpool1(grad_output, pool_indices, output_size = torch.Size(size_pool))
275 | 
276 |         quant_error = quant_err(grad_output)
277 | 
278 |         if ctx.needs_input_grad[0]:
279 |             # propagate quantized error
280 |             grad_input = torch.nn.grad.conv2d_input(input.shape, w_quant, quant_error)
281 |         if ctx.needs_input_grad[1]:
282 |             grad_weight = quant_grad(torch.nn.grad.conv2d_weight(input, w_quant.shape, quant_error)).float()
283 | 
284 |         if bias is not None and ctx.needs_input_grad[6]:
285 |             grad_bias = grad_output.sum(0).squeeze(0)
286 | 
287 |         return grad_input, grad_weight, None, None, None, None, grad_bias  # one grad per forward input
--------------------------------------------------------------------------------
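Both Functions above follow the custom-autograd convention that backward returns one gradient per forward input, with None for the non-differentiable arguments (scale, act, act_q, pool). A self-contained toy version of that convention, with made-up shapes:

```
import torch

class ScaledLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, scale):
        ctx.save_for_backward(input, weight)
        ctx.scale = scale                        # plain Python number, no grad needed
        return input.mm(weight.t()) * scale

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_output = grad_output * ctx.scale
        grad_input = grad_output.mm(weight)      # grad w.r.t. input
        grad_weight = grad_output.t().mm(input)  # grad w.r.t. weight
        return grad_input, grad_weight, None     # None for the scale argument

x = torch.randn(4, 16, requires_grad=True)
w = torch.randn(8, 16, requires_grad=True)
out = ScaledLinear.apply(x, w, 2.0)
out.sum().backward()
print(x.grad.shape, w.grad.shape)  # torch.Size([4, 16]) torch.Size([8, 16])
```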