├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── input_700_250_25.pkl
│   ├── small_test_dvs_gesture.pickle
│   ├── small_train_dvs_gesture.pickle
│   ├── smile100.pkl
│   ├── smile30.pkl
│   ├── smile50.pkl
│   ├── smile70.pkl
│   └── smile95.pkl
├── figures
│   ├── .DS_Store
│   ├── ._torch_wage_acc_cifar10_210810.png
│   ├── ._torch_wage_acc_cifar10_21088.png
│   ├── ._torch_wage_acc_cifar10_2888.png
│   ├── ._torch_wage_acc_cifar10_310810.png
│   ├── ._utorch_wage_acc_cifar10_21088.png
│   ├── ICONS_PQ_distr.png
│   ├── ICONS_QuantSNN.png
│   ├── ICONS_curves.pdf
│   ├── ICONS_data_set_gest.png
│   ├── ICONS_data_set_poker.png
│   ├── ICONS_sur.png
│   ├── ICONS_unscatter.pdf
│   ├── ISCAS_schem1.png
│   └── ISCAS_smile_black.png
├── localQ.py
├── prepGesture.py
├── prepPoker.py
├── qsnn_decolle.py
├── qsnn_precise.py
├── qsnn_util.py
└── quantization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/*
2 | results/*
3 | .DS_Store
4 | .README.md.icloud
5 | data/train_dvs_gesture.pickle
6 | data/test_dvs_gesture.pickle
7 | data/slow_poker_500_train.pickle
8 | data/slow_poker_500_test.pickle
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Siddharth Joshi, Clemens JS Schaefer
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Quantized Spiking Neural Networks
2 |
3 | This repository contains the models and training scripts used in the papers: ["Quantizing Spiking Neural Networks with Integers"](https://dl.acm.org/doi/abs/10.1145/3407197.3407203) (ICONS 2020) and ["Memory Organization for Energy-Efficient Learning and Inference in Digital Neuromorphic Accelerators"](https://ieeexplore.ieee.org/document/9180443) (ISCAS 2020).
4 |
5 | ## Requirements
6 |
7 | - Python (pickle and argparse ship with the standard library)
8 | - PyTorch
9 | - torchvision
10 | - NumPy
11 | - Matplotlib
12 | - seaborn
13 |
14 |
15 | ## Quantized SNNs for Spatio-Temporal Patterns
16 |
17 |
18 |
19 |
20 |
21 | All relevant code for the experiments from the ISCAS paper is contained in `qsnn_precise.py`, `quantization.py`, and `qsnn_util.py`. To run the experiments, execute:
22 |
23 | ```
24 | python qsnn_precise.py
25 | ```
26 |
27 |
28 |
29 |
30 | You can specify the desired settings either as command-line arguments or within `qsnn_precise.py`.
31 |
32 | Optional arguments:
33 |
34 | | Argument | Description |
35 | |:-----------------------|:---------------------------------------------------------|
36 | | --input INPUT | Input pickle file (default: ./data/input_700_250_25.pkl) |
37 | | --target TARGET | Target pattern pickle (default: ./data/smile95.pkl) |
38 | | --global_wb GLOBAL_WB | Weight bitwidth (default: 2) |
39 | | --global_ab GLOBAL_AB | Membrane potential, synapse state bitwidth (default: 8) |
40 | | --global_gb GLOBAL_GB | Gradient bitwidth (default: 8) |
41 | | --global_eb GLOBAL_EB | Error bitwidth (default: 8) |
42 | | --global_rb GLOBAL_RB | Gradient RNG bitwidth (default: 16) |
43 | | --time_step TIME_STEP | Simulation time step size (default: 0.001) |
44 | | --nb_steps NB_STEPS | Simulation steps (default: 250) |
45 | | --nb_epochs NB_EPOCHS  | Training epochs (default: 10000)                          |
46 | | --tau_mem TAU_MEM | Time constant for membrane potential (default: 0.01) |
47 | | --tau_syn TAU_SYN | Time constant for synapse (default: 0.005) |
48 | | --tau_vr TAU_VR | Time constant for Van Rossum distance (default: 0.005) |
49 | | --alpha ALPHA          | Decay factor for synaptic current (default: 0.75)         |
50 | | --beta BETA            | Decay factor for membrane potential (default: 0.875)      |
51 | | --nb_inputs NB_INPUTS | Spatial input dimensions (default: 700) |
52 | | --nb_hidden NB_HIDDEN | Spatial hidden dimensions (default: 400) |
53 | | --nb_outputs NB_OUTPUTS| Spatial output dimensions (default: 250) |
54 |
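For example, to train the 70% version of the smile pattern with 4-bit weights (illustrative values; any of the flags above can be combined):

```
python qsnn_precise.py --global_wb 4 --target ./data/smile70.pkl --nb_epochs 5000
```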
55 |
56 | ## Quantized SNNs for Gesture Detection with Local Learning
57 |
58 |
59 |
60 |
61 |
62 | Download and extract the [DVS Slow Poker](http://www2.imse-cnm.csic.es/caviar/SLOWPOKERDVS.html#:~:text=The%20SLOW%2DPOKER%2DDVS%20database,diamond%2C%20heart%20or%20spade) and [DVS Gesture](https://www.research.ibm.com/dvsgesture/) data sets.
63 |
64 | To prepare the data, run the following commands in the respective data set directories (i.e. the directory containing the DVS Poker recordings or the DVS Gesture recordings):
65 |
66 | ```
67 | python prepPoker.py
68 | ```
69 |
70 | ```
71 | python prepGesture.py
72 | ```
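
Each script writes train/test pickles (e.g. `train_dvs_gesture.pickle` and `test_dvs_gesture.pickle`); judging by the paths in `.gitignore`, the training code expects these files under `data/`.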
73 |
74 |
75 | All relevant code for the experiments from the ICONS paper is contained in `qsnn_decolle.py`, `quantization.py`, and `localQ.py`. To run the experiments, execute:
76 |
77 | ```
78 | python qsnn_decolle.py
79 | ```
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 | You can specify the desired settings either as command-line arguments or within `qsnn_decolle.py`.
88 |
89 | Optional arguments:
90 |
91 | | Argument | Description |
92 | |:-----------------------|:---------------------------------------------------------|
93 | | --data-set DATA_SET     | Input data set: Poker/Gesture (default: Gesture)          |
94 | | --global_wb GLOBAL_WB | Weight bitwidth (default: 8) |
95 | | --global_qb GLOBAL_QB | Synapse bitwidth (default: 10) |
96 | | --global_pb GLOBAL_PB | Membrane trace bitwidth (default: 12) |
97 | | --global_rfb GLOBAL_RFB | Refractory bitwidth (default: 2) |
98 | | --global_sb GLOBAL_SB | Learning signal bitwidth (default: 6) |
99 | | --global_gb GLOBAL_GB | Gradient bitwidth (default: 10) |
100 | | --global_eb GLOBAL_EB | Error bitwidth (default: 6) |
101 | | --global_ub GLOBAL_UB  | Membrane potential bitwidth (default: 6)                  |
102 | | --global_ab GLOBAL_AB | Activation bitwidth (default: 6) |
103 | | --global_sig GLOBAL_SIG | Sigmoid bitwidth (default: 6) |
104 | | --global_rb GLOBAL_RB | Gradient RNG bitwidth (default: 16) |
105 | | --global_lr GLOBAL_LR | Learning rate for quantized gradients (default: 1) |
106 | | --global_lr_sgd GLOBAL_LR_SGD | Learning rate for SGD (default: 1e-09) |
107 | | --global_beta GLOBAL_BETA | Beta for weight init (default: 1.5) |
108 | | --delta_t DELTA_T | Time step in ms (default: 0.001) |
109 | | --input_mode INPUT_MODE | Spike processing method (default: 0) |
110 | | --ds DS | Downsampling (default: 4) |
111 | | --epochs EPOCHS | Epochs for training (default: 320) |
112 | | --lr_div LR_DIV | Learning rate divide interval (default: 80) |
113 | | --batch_size BATCH_SIZE | Batch size (default: 72) |
114 | | --PQ_cap PQ_CAP         | Value cap for membrane and synapse trace (default: 1)     |
115 | | --weight_mult WEIGHT_MULT | Weight multiplier (default: 4e-05) |
116 | | --dropout_p DROPOUT_P | Dropout probability (default: 0.5) |
117 | | --lc_ampl LC_AMPL | Magnitude amplifier for weight init (default: 0.5) |
118 | | --l1 L1 | Regularizer 1 (default: 0.001) |
119 | | --l2 L2 | Regularizer 2 (default: 0.001) |
120 | | --tau_mem_lower TAU_MEM_LOWER | Tau mem lower bound (default: 5) |
121 | | --tau_mem_upper TAU_MEM_UPPER | Tau mem upper bound (default: 35) |
122 | | --tau_syn_lower TAU_SYN_LOWER | Tau syn lower bound (default: 5) |
123 | | --tau_syn_upper TAU_SYN_UPPER | Tau syn upper bound (default: 10) |
124 | | --tau_ref TAU_REF | Tau ref (default: 2.857142857142857) |
125 |
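For example, to train on the poker data set with 4-bit weights and a smaller batch (illustrative values; any of the flags above can be combined):

```
python qsnn_decolle.py --data-set Poker --global_wb 4 --batch_size 36
```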
126 |
--------------------------------------------------------------------------------
/data/input_700_250_25.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/input_700_250_25.pkl
--------------------------------------------------------------------------------
/data/small_test_dvs_gesture.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/small_test_dvs_gesture.pickle
--------------------------------------------------------------------------------
/data/small_train_dvs_gesture.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/small_train_dvs_gesture.pickle
--------------------------------------------------------------------------------
/data/smile100.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile100.pkl
--------------------------------------------------------------------------------
/data/smile30.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile30.pkl
--------------------------------------------------------------------------------
/data/smile50.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile50.pkl
--------------------------------------------------------------------------------
/data/smile70.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile70.pkl
--------------------------------------------------------------------------------
/data/smile95.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/data/smile95.pkl
--------------------------------------------------------------------------------
/figures/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/.DS_Store
--------------------------------------------------------------------------------
/figures/._torch_wage_acc_cifar10_210810.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._torch_wage_acc_cifar10_210810.png
--------------------------------------------------------------------------------
/figures/._torch_wage_acc_cifar10_21088.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._torch_wage_acc_cifar10_21088.png
--------------------------------------------------------------------------------
/figures/._torch_wage_acc_cifar10_2888.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._torch_wage_acc_cifar10_2888.png
--------------------------------------------------------------------------------
/figures/._torch_wage_acc_cifar10_310810.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._torch_wage_acc_cifar10_310810.png
--------------------------------------------------------------------------------
/figures/._utorch_wage_acc_cifar10_21088.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/._utorch_wage_acc_cifar10_21088.png
--------------------------------------------------------------------------------
/figures/ICONS_PQ_distr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_PQ_distr.png
--------------------------------------------------------------------------------
/figures/ICONS_QuantSNN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_QuantSNN.png
--------------------------------------------------------------------------------
/figures/ICONS_curves.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_curves.pdf
--------------------------------------------------------------------------------
/figures/ICONS_data_set_gest.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_data_set_gest.png
--------------------------------------------------------------------------------
/figures/ICONS_data_set_poker.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_data_set_poker.png
--------------------------------------------------------------------------------
/figures/ICONS_sur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_sur.png
--------------------------------------------------------------------------------
/figures/ICONS_unscatter.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ICONS_unscatter.pdf
--------------------------------------------------------------------------------
/figures/ISCAS_schem1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ISCAS_schem1.png
--------------------------------------------------------------------------------
/figures/ISCAS_smile_black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Intelligent-Microsystems-Lab/QuantizedSNNs/bb7a4998a5c932ff8d0e1ae961ee19e3c419de54/figures/ISCAS_smile_black.png
--------------------------------------------------------------------------------
/localQ.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision
5 | import pickle
6 | import time
7 | import math
8 | import numpy as np
9 | import matplotlib.pyplot as plt
10 | import seaborn as sns
11 | import uuid
12 |
13 | import quantization
14 |
15 |
16 | global lc_ampl
17 | lc_ampl = .5
18 |
19 | global shift_prob
20 | shift_prob = .5
21 |
22 |
23 | def create_graph(plot_file_name, diff_layers_acc, ds_name, best_test):
24 |
25 | bit_string = str(quantization.global_wb) + str(quantization.global_ub) + str(quantization.global_pb) + str(quantization.global_qb) + str(quantization.global_rfb) + " " + str(quantization.global_sb) + str(quantization.global_ab) + str(quantization.global_sig) + str(quantization.global_eb) + str(quantization.global_gb)
26 | bit_string = bit_string.replace("None", "f")
27 |
28 |
29 | fig, ax1 = plt.subplots()
30 | fig.set_size_inches(8.4, 4.8)
31 | plt.title(ds_name + " "+ bit_string + " Test3: " + str(np.round( best_test.item(), 4)) + " " +str(shift_prob))
32 | ax1.set_xlabel('Epochs')
33 | ax1.set_ylabel('Accuracy')
34 | t = np.arange(len(diff_layers_acc['loss']))
35 | ax1.plot(t, diff_layers_acc['train1'], 'g--', label = 'Train 1')
36 | ax1.plot(t, diff_layers_acc['train2'], 'b--', label = 'Train 2')
37 | ax1.plot(t, diff_layers_acc['train3'], 'r--', label = 'Train 3')
38 | ax1.plot(t, diff_layers_acc['test1'], 'g-', label = 'Test 1')
39 | ax1.plot(t, diff_layers_acc['test2'], 'b-', label = 'Test 2')
40 | ax1.plot(t, diff_layers_acc['test3'], 'r-', label = 'Test 3')
41 | ax1.plot([], [], 'k-', label = 'Loss') # empty series so the loss (drawn on ax2) appears in this legend
42 | ax1.legend(bbox_to_anchor=(1.20,1), loc="upper left")
43 | #ax1.text(1.20, 0.1, str(max(diff_layers_acc['test3'])))
44 |
45 | ax2 = ax1.twinx()
46 | ax2.set_ylabel('Loss')
47 | ax2.plot(t, diff_layers_acc['loss'], 'k-', label = 'Loss')
48 |
49 | fig.tight_layout()
50 | plt.savefig("figures/"+plot_file_name + ".png")
51 | plt.close()
52 |
53 | def create_graph2(plot_file_name, diff_layers_acc, ds_name):
54 |
55 | bit_string = str(quantization.global_wb) + str(quantization.global_ub) + str(quantization.global_pb) + str(quantization.global_qb) + str(quantization.global_rfb) + " " + str(quantization.global_sb) + str(quantization.global_ab) + str(quantization.global_sig) + str(quantization.global_eb) + str(quantization.global_gb)
56 | bit_string = bit_string.replace("None", "f")
57 |
58 |
59 | fig, ax1 = plt.subplots()
60 | fig.set_size_inches(8.4, 4.8)
61 | plt.title(ds_name + " Act "+ bit_string)
62 | ax1.set_xlabel('Epochs')
63 | ax1.set_ylabel('# Spikes/Updates')
64 | t = np.arange(len(diff_layers_acc['loss']))
65 |
66 | ax1.plot(t, diff_layers_acc['act_train1'], 'g--', label = 'Train 1')
67 | ax1.plot(t, diff_layers_acc['act_train2'], 'b--', label = 'Train 2')
68 | ax1.plot(t, diff_layers_acc['act_train3'], 'r--', label = 'Train 3')
69 | ax1.plot(t, diff_layers_acc['act_test1'], 'g-', label = 'Test 1')
70 | ax1.plot(t, diff_layers_acc['act_test2'], 'b-', label = 'Test 2')
71 | ax1.plot(t, diff_layers_acc['act_test3'], 'r-', label = 'Test 3')
72 | ax1.plot(t, diff_layers_acc['w1update'], 'm-', label = 'W update 1')
73 | ax1.plot(t, diff_layers_acc['w2update'], 'k-', label = 'W update 2')
74 | ax1.plot(t, diff_layers_acc['w3update'], 'y-', label = 'W update 3')
75 | ax1.legend(bbox_to_anchor=(1.20,1), loc="upper left")
76 | #ax1.text(1.20, 0.1, str(max(diff_layers_acc['test3'])))
77 |
78 | #ax2 = ax1.twinx()
79 | #ax2.set_ylabel('Loss')
80 | #ax2.plot(t, diff_layers_acc['loss'], 'k-', label = 'Loss')
81 |
82 | fig.tight_layout()
83 | plt.savefig("figures/"+plot_file_name+ "_act.png")
84 | plt.close()
85 |
86 |
87 |
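# acc_comp: majority vote over time -- stack the per-step class predictions and compare each sample's mode against the labels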
88 | def acc_comp(rread_hist_train, y_local, bools = False):
89 | rhts = torch.stack(rread_hist_train, dim = 0)
90 | if bools:
91 | return (rhts.mode(0)[0] == y_local).float()
92 | return (rhts.mode(0)[0] == y_local).float().mean()
93 |
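# clee_spikes: Bernoulli spike trains -- each unit spikes with probability rate/1000 in each of T 1 ms bins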
94 | def clee_spikes(T, rates):
95 | spikes = np.ones((T, np.prod(rates.shape)))
96 | spikes[np.random.binomial(1, (1000. - rates.flatten())/1000, size=(T, np.prod(rates.shape))).astype('bool')] = 0
97 | return spikes.T.reshape((rates.shape + (T,)))
98 |
99 | def prep_input(x_local, input_mode, channels = 2):
100 | #two channel trick / decolle
101 | if input_mode == 0:
102 | x_local[x_local > 0] = 1
103 |
104 | #down_spikes = torch.cat((x_local, x_local), dim = 1)
105 | #mask1 = (down_spikes > 0) # this might change
106 | #mask2 = (down_spikes < 0)
107 | #mask1[:,0,:,:] = False
108 | #mask2[:,1,:,:] = False
109 | #down_spikes = torch.zeros_like(down_spikes)
110 | #down_spikes[mask1] = 1
111 | #down_spikes[mask2] = 1
112 | return x_local
113 | #bi directional
114 | if input_mode == 2:
115 | x_local[:,0,:,:] *= -1
116 | new_spikes = x_local[:,0,:,:] + x_local[:,1,:,:]
117 | new_spikes = new_spikes.reshape([x_local.shape[0], 1, x_local.shape[2], x_local.shape[3]])
118 | new_spikes[new_spikes > 0] = 1
119 | new_spikes[new_spikes < 0] = -1
120 | return new_spikes
121 | # same same but different
122 | if input_mode == 1:
123 | down_spikes = x_local
124 | down_spikes[down_spikes != 0] = 1
125 | return down_spikes
126 | #bi directional two channels
127 | if input_mode == 3:
128 | x_local[:,0,:,:] *= -1
129 | new_spikes = x_local[:,0,:,:] + x_local[:,1,:,:]
130 | new_spikes = new_spikes.reshape([x_local.shape[0], 1, x_local.shape[2], x_local.shape[3]])
131 | new_spikes[new_spikes > 0] = 1
132 | new_spikes[new_spikes < 0] = -1
133 |
134 | new_spikes = torch.cat((new_spikes, new_spikes), dim = 1)
135 | return new_spikes
136 | else:
137 | return x_local
138 |
139 |
140 | def sparse_data_generator_DVSPoker(X, y, batch_size, nb_steps, shuffle, device, test = False):
141 | number_of_batches = int(np.ceil(len(y)/batch_size))
142 | sample_index = np.arange(len(y))
143 | nb_steps = nb_steps -1
144 | y = np.array(y)
145 |
146 | if shuffle:
147 | np.random.shuffle(sample_index)
148 |
149 | total_batch_count = 0
150 | counter = 0
151 | while counter < number_of_batches:
    | # ... (lines 152-165 are missing from this dump; the loop body assembles the
    | #  dense per-batch event tensor, cf. sparse_data_generator_DVSGesture below) ...
166 | sparse_matrix[sparse_matrix > 0] = 1
167 |
168 | sparse_matrix = sparse_matrix.reshape(torch.Size([sparse_matrix.shape[0], 1, sparse_matrix.shape[1], sparse_matrix.shape[2], sparse_matrix.shape[3]]))
169 |
170 | y_batch = torch.tensor(y[batch_index], dtype = int)
171 | try:
172 | torch.cuda.empty_cache()
173 | yield sparse_matrix.to(device=device), y_batch.to(device=device)
174 | counter += 1
175 | except StopIteration:
176 | return
177 |
178 | def sparse_data_generator_DVSGesture(X, y, batch_size, nb_steps, shuffle, device, ds = 4, test = False, x_size = 32, y_size = 32):
179 | number_of_batches = int(np.ceil(len(y)/batch_size))
180 | sample_index = np.arange(len(y))
181 | nb_steps = nb_steps -1
182 | y = np.array(y)
183 |
184 | if shuffle:
185 | np.random.shuffle(sample_index)
186 |
187 | total_batch_count = 0
188 | counter = 0
189 | while counter < number_of_batches:
    | # ... (lines 190-195 are missing from this dump: batch index selection and
    | #  setup of the all_events container) ...
196 | # recordings longer than nb_steps -> we sample
197 | if test:
198 | start_ts = 0
199 | else:
200 | start_ts = np.random.choice(np.arange(np.max(X[idx][:,0]) - nb_steps),1)
201 | temp = X[idx][X[idx][:,0] >= start_ts]
202 | temp = temp[temp[:,0] <= start_ts+nb_steps]
203 | temp = np.append(np.ones((temp.shape[0], 1))*bc, temp, axis=1)
204 | temp[:,1] = temp[:,1] - start_ts
205 | all_events = np.append(all_events, temp, axis = 0)
206 |
207 | # to matrix
208 | #all_events[:,4][all_events[:,4] == 0] = -1
209 | # spike_ind = (x_local == 1).nonzero()
210 | # spike_ind = spike_ind[torch.bernoulli((.5) * torch.ones(spike_ind.shape[0])).bool()]
211 | # spike_ind = spike_ind[torch.randperm(spike_ind.shape[0])]
212 | # split_point = int(spike_ind.shape[0]/2)
213 | # forward_spike = spike_ind[0:split_point]
214 | # backward_spike = spike_ind[split_point:]
215 |
216 | # x_local[torch.sparse.FloatTensor(forward_spike.t(), torch.ones(forward_spike.shape[0]).to(device)).to_dense().bool()] = 0
217 | # forward_spike[:,4] = forward_spike[:,4] + 1
218 | # forward_spike[forward_spike[:,4] == 500] = 499
219 | # x_local[torch.sparse.FloatTensor(forward_spike.t(), torch.ones(forward_spike.shape[0]).to(device)).to_dense().bool()] = 1
220 |
221 | # x_local[torch.sparse.FloatTensor(backward_spike.t(), torch.ones(backward_spike.shape[0]).to(device)).to_dense().bool()] = 0
222 | # backward_spike[:,4] = backward_spike[:,4] - 1
223 | # backward_spike[backward_spike[:,4] == -1] = 0
224 | # x_local[torch.sparse.FloatTensor(backward_spike.t(), torch.ones(backward_spike.shape[0]).to(device)).to_dense().bool()] = 1
225 |
226 |
227 | #change
228 | # by plus minus one process...
229 | # change_mask = torch.bernoulli((shift_prob) * torch.ones(all_events.shape[0])).bool()
230 | # forward_mask = change_mask * torch.bernoulli((.5) * torch.ones(all_events.shape[0])).bool()
231 | # backward_mask = (change_mask != forward_mask)
232 | # all_events[forward_mask, 1] = all_events[forward_mask, 1] + 1 #torch.randn(all_events[forward_mask, 1].shape[0])
233 | # all_events[backward_mask, 1] = all_events[backward_mask, 1] - 1
234 |
235 | all_events[:, 1] = all_events[:, 1] + (shift_prob*np.random.randn(all_events[:, 1].shape[0])).astype(int)
236 |
237 | neg_ind = (all_events[:,1] < 0)
238 | pos_ind = (all_events[:,1] > nb_steps)
239 | all_events[neg_ind,1] = 0
240 | all_events[pos_ind,1] = int(nb_steps)
241 |
242 |
243 | all_events = all_events[:,[0,4,2,3,1]]
244 | all_events[:, 2] = all_events[:, 2]//ds
245 | all_events[:, 3] = all_events[:, 3]//ds
246 | sparse_matrix = torch.sparse.FloatTensor(torch.LongTensor(all_events[:,[True, True, True, True, True]].T), torch.ones_like(torch.tensor(all_events[:,0])), torch.Size([len(y_batch),2,x_size,y_size,int(nb_steps+1)])).to_dense().type(torch.int16)
247 |
248 | # quick trick...
249 | #sparse_matrix[sparse_matrix != 0] = 1
250 | #sparse_matrix[sparse_matrix > 0] = 1
251 | #sparse_matrix = sparse_matrix.reshape(torch.Size([sparse_matrix.shape[0], 1, sparse_matrix.shape[1], sparse_matrix.shape[2], sparse_matrix.shape[3]]))
252 |
253 |
254 | try:
255 | torch.cuda.empty_cache()
256 | yield sparse_matrix.to(device=device), y_batch.to(device=device)
257 | counter += 1
258 | except StopIteration:
259 | return
260 |
261 | def sparse_data_generator_Static(X, y, batch_size, nb_steps, samples, max_hertz, shuffle=True, device=torch.device("cpu")):
262 | sample_idx = torch.randperm(len(X))[:samples]
263 | number_of_batches = int(np.ceil(samples/batch_size))
264 | nb_steps = int(nb_steps)
265 |
266 | counter = 0
267 | while counter < number_of_batches:
    | # ... (lines 268-303 are missing from this dump: the static-data batch loop and
    | #  the start of QLinearFunctional; the forward signature below is inferred from
    | #  the call site in QLinearLayerSign.forward) ...
    |
    | class QLinearFunctional(torch.autograd.Function):
    |     @staticmethod
    |     def forward(ctx, input, weight, weight_fa, bias, scale):
304 | output = torch.einsum('ab,bc->ac', input, weight)
305 | if bias is not None:
306 | output += bias.unsqueeze(0).expand_as(output)
307 |
308 | if quantization.global_sb is not None:
309 | output = output/scale
310 | # quant act
311 | if quantization.global_ab is not None:
312 | output, _ = quantization.quant_act(output)
313 |
314 | ctx.save_for_backward(input, weight, weight_fa, bias)
315 |
316 | # ify part here... shall we bring it between 0 and 1 for the targets
317 | return (output+1)/2
318 |
319 | @staticmethod
320 | def backward(ctx, grad_output):
321 | input, weight, weight_fa, bias = ctx.saved_tensors
322 | grad_input = None
323 |
324 | if quantization.global_eb is not None:
325 | quant_error = quantization.quant_err(grad_output) #* clip_info.float()
326 | else:
327 | quant_error = grad_output
328 |
329 | if ctx.needs_input_grad[0]:
330 | grad_input = torch.einsum('ab,bc->ac', quant_error, weight_fa)
331 |
332 | # quantizing here for sigmoid input
333 | if quantization.global_eb is not None:
334 | grad_input = quantization.quant_err(grad_input)
335 | else:
336 | grad_input = grad_input
337 |
338 | return grad_input, None, None, None, None
339 |
340 | class QLinearLayerSign(nn.Module):
341 | '''from https://github.com/L0SG/feedback-alignment-pytorch/'''
342 | def __init__(self, input_features, output_features, pass_through = False, bias = True, dtype = None, device = None):
343 | super(QLinearLayerSign, self).__init__()
344 | self.input_features = input_features
345 | self.output_features = output_features
346 | self.dtype = dtype
347 | self.device = device
348 |
349 | # weight and bias for forward pass
350 | self.weights = nn.Parameter(torch.empty((output_features, input_features), device=device, dtype=dtype, requires_grad=False))
351 | self.weight_fa = nn.Parameter(torch.empty((output_features, input_features), device=device, dtype=dtype, requires_grad=False))
352 | self.bias = nn.Parameter(torch.empty((output_features), device=device, dtype=dtype, requires_grad=False))
353 |
354 | if quantization.global_sb is not None:
355 | self.L_min = quantization.global_beta/quantization.step_d(torch.tensor([float(quantization.global_sb)]))
356 | #self.L = np.sqrt(6/self.input_features)
357 | self.L = lc_ampl/np.sqrt(self.input_features)
358 | self.scale = 2 ** round(math.log(self.L_min / self.L, 2.0))
359 | self.scale = self.scale if self.scale > 1 else 1.0
360 | self.L = np.max([self.L, self.L_min])
361 |
362 | #since those weights are fixed lets just initialize them between -1 and 1 to make use of all given bits
363 | self.L = lc_ampl/np.sqrt(self.input_features)
364 | self.scale = 2 ** round(math.log((1-self.L_min)/self.L, 2.0))
365 | self.scale = self.scale if self.scale > 1 else 1.0
366 | self.L = 1
367 |
368 | torch.nn.init.uniform_(self.weights, a = -self.L, b = self.L)
369 | torch.nn.init.uniform_(self.weight_fa, a = -self.L, b = self.L)
370 | if bias:
371 | torch.nn.init.uniform_(self.bias, a = -self.L, b = self.L)
372 | else:
373 | self.bias = None
374 |
375 | # quantize them
376 | with torch.no_grad():
377 | self.weights.data = quantization.quant_s(self.weights.data)
378 | self.weight_fa.data = quantization.quant_s(self.weight_fa.data)
379 | if self.bias is not None:
380 | self.bias.data = quantization.quant_s(self.bias.data)
381 | else:
382 | self.scale = 1
383 | self.stdv = lc_ampl/np.sqrt(self.input_features)
384 | torch.nn.init.uniform_(self.weights, a = -self.stdv, b = self.stdv)
385 | torch.nn.init.uniform_(self.weight_fa, a = -self.stdv, b = self.stdv)
386 | if bias:
387 | torch.nn.init.uniform_(self.bias, a = -self.stdv, b = self.stdv)
388 | else:
389 | self.bias = None
390 |
391 | # sign concordant weights in fwd and bwd pass
392 | #self.weight_fa = self.weights
393 | nonzero_mask = (self.weights.data != 0)
394 | self.weight_fa.data[nonzero_mask] *= torch.sign((torch.sign(self.weights.data) == torch.sign(self.weight_fa.data)).type(dtype) -.5)[nonzero_mask]
395 |
396 |
397 | def forward(self, input):
398 | return QLinearFunctional.apply(input, self.weights, self.weight_fa, self.bias, self.scale)
399 |
400 |
401 |
402 | class QSConv2dFunctional(torch.autograd.Function):
403 | @staticmethod
404 | def forward(ctx, input, weights, bias, scale, padding = 0, weight_mult = 1):
405 | if quantization.global_wb is not None:
406 | w_quant = quantization.quant_w(weights/weight_mult, 1) *weight_mult
407 | bias_quant = quantization.quant_w(bias/weight_mult, 1) *weight_mult
408 | else:
409 | w_quant = weights
410 | bias_quant = bias
411 | ctx.padding = padding
412 |
413 | output = F.conv2d(input = input, weight = w_quant, bias = bias_quant, padding = ctx.padding)
414 | if quantization.global_wb is not None:
415 | output = output / scale
416 |
417 | ctx.save_for_backward(input, w_quant, bias_quant)
418 |
419 |
420 | return output
421 |
422 | @staticmethod
423 | def backward(ctx, grad_output):
424 | input, w_quant, bias_quant = ctx.saved_tensors
425 | grad_input = grad_weight = grad_bias = None
426 |
427 | if quantization.global_eb is not None:
428 | quant_error = quantization.quant_err(grad_output)
429 | else:
430 | quant_error = grad_output
431 |
432 | # compute quantized error
433 | if ctx.needs_input_grad[0]:
434 | grad_input = torch.nn.grad.conv2d_input(input.shape, w_quant, quant_error, padding = ctx.padding)
435 | # computed quantized gradient
436 | if ctx.needs_input_grad[1]:
437 | if quantization.global_gb is not None:
438 | grad_weight = quantization.quant_grad(torch.nn.grad.conv2d_weight(input, w_quant.shape, quant_error, padding = ctx.padding)).float()
439 | else:
440 | grad_weight = torch.nn.grad.conv2d_weight(input, w_quant.shape, quant_error, padding = ctx.padding)
441 | # computed quantized bias
442 | if bias_quant is not None and ctx.needs_input_grad[2]:
443 | if quantization.global_gb is not None:
444 | grad_bias = quantization.quant_grad(quant_error.sum((0,2,3)).squeeze(0)).float()
445 | else:
446 | grad_bias = quant_error.sum((0,2,3)).squeeze(0)
447 |
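# count nonzero weight/bias gradient entries per layer; the spatial input sizes
# (32, 15, 13) apparently identify the three conv layers of the gesture network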
448 | if input.shape[2] == 13:
449 | quantization.global_w3update += grad_bias.nonzero().shape[0] + grad_weight.nonzero().shape[0]
450 | if input.shape[2] == 15:
451 | quantization.global_w2update += grad_bias.nonzero().shape[0] + grad_weight.nonzero().shape[0]
452 | if input.shape[2] == 32:
453 | quantization.global_w1update += grad_bias.nonzero().shape[0] + grad_weight.nonzero().shape[0]
454 | return grad_input, grad_weight, grad_bias, None, None, None, None
455 |
456 |
457 | class LIFConv2dLayer(nn.Module):
458 | def __init__(self, inp_shape, kernel_size, out_channels, tau_syn, tau_mem, tau_ref, delta_t, pooling = 1, padding = 0, bias = True, thr = 1, device=torch.device("cpu"), dtype = torch.float, dropout_p = .5, output_neurons = 10, loss_fn = None, l1 = 0, l2 = 0, PQ_cap = 1, weight_mult = 4e-5):
459 | super(LIFConv2dLayer, self).__init__()
460 | self.device = device
461 | self.dtype = dtype
462 | self.inp_shape = inp_shape
463 | self.kernel_size = kernel_size
464 | self.out_channels = out_channels
465 | self.output_neurons = output_neurons
466 | self.padding = padding
467 | self.pooling = pooling
468 | self.thr = thr
469 | self.PQ_cap = PQ_cap
470 | self.weight_mult = weight_mult
471 | self.fan_in = kernel_size * kernel_size * inp_shape[0]
472 |
473 | self.dropout_learning = nn.Dropout(p=dropout_p)
474 | self.dropout_p = dropout_p
475 | self.l1 = l1
476 | self.l2 = l2
477 | self.loss_fn = loss_fn
478 |
479 | self.weights = nn.Parameter(torch.empty((self.out_channels, inp_shape[0], self.kernel_size, self.kernel_size), device=device, dtype=dtype, requires_grad=True))
480 |
481 | # decide which one you like
482 | self.stdv = 1 / np.sqrt(self.fan_in) #/ 250 * 1e-2
483 | #self.stdv = np.sqrt(6 / self.fan_in) #* self.weight_mult
484 | if quantization.global_wb is not None:
485 | self.L_min = quantization.global_beta/quantization.step_d(torch.tensor([float(quantization.global_wb)]))
486 | #self.stdv = np.sqrt(6/self.fan_in)
487 | self.scale = 2 ** round(math.log(self.L_min / self.stdv, 2.0))
488 | self.scale = self.scale if self.scale > 1 else 1.0
489 | self.L = np.max([self.stdv, self.L_min])
490 | torch.nn.init.uniform_(self.weights, a = -self.L * self.weight_mult, b = self.L* self.weight_mult)
491 | else:
492 | self.scale = 1
493 | torch.nn.init.uniform_(self.weights, a = -self.stdv * self.weight_mult, b = self.stdv* self.weight_mult)
494 |
495 | # bias has a different scale... just why?
496 | if bias:
497 | self.bias = nn.Parameter(torch.empty(self.out_channels, device=device, dtype=dtype, requires_grad=True))
498 | if quantization.global_wb is not None:
499 | bias_L = np.max([self.stdv* 1e2, self.L_min])
500 | torch.nn.init.uniform_(self.bias, a = -bias_L * self.weight_mult, b = bias_L* self.weight_mult)
501 | else:
502 | torch.nn.init.uniform_(self.bias, a = -self.stdv* self.weight_mult* 1e2, b = self.stdv* self.weight_mult * 1e2)
503 | else:
504 | self.register_parameter('bias', None)
505 |
506 | self.mpool = nn.MaxPool2d(kernel_size = self.pooling, stride = self.pooling, padding = (self.pooling-1)//2, return_indices=False)
507 | self.out_shape2 = self.mpool(QSConv2dFunctional.apply(torch.zeros((1,)+self.inp_shape, dtype = dtype).to(device), self.weights, self.bias, self.scale, self.padding)).shape[1:] #self.pooling,
508 | self.out_shape = QSConv2dFunctional.apply(torch.zeros((1,)+self.inp_shape, dtype = dtype).to(device), self.weights, self.bias, self.scale, self.padding).shape[1:]
509 |
510 | self.sign_random_readout = QLinearLayerSign(input_features = np.prod(self.out_shape2), output_features = output_neurons, pass_through = False, bias = False, dtype = self.dtype, device = device).to(device)
511 |
512 | # tau quantization, static hardware friendly values
513 | if tau_syn.shape[0] == 2:
514 | self.tau_syn = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_syn[0], tau_syn[1]).to(device)
515 | self.beta = 1. - 1e-3 / self.tau_syn
516 | self.tau_syn = 1. / (1. - self.beta)
517 | else:
518 | self.beta = torch.tensor([1 - delta_t / tau_syn], dtype = dtype).to(device)
519 | self.tau_syn = 1. / (1. - self.beta)
520 |
521 |
522 | if tau_mem.shape[0] == 2:
523 | self.tau_mem = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_mem[0], tau_mem[1]).to(device)
524 | self.alpha = 1. - 1e-3 / self.tau_mem
525 | self.tau_mem = 1. / (1. - self.alpha)
526 | else:
527 | self.alpha = torch.tensor([1 - delta_t / tau_mem], dtype = dtype).to(device)
528 | self.tau_mem = 1. / (1. - self.alpha)
529 |
530 |
531 | if tau_ref.shape[0] == 2:
532 | self.tau_ref = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_ref[0], tau_ref[1]).to(device)
533 | self.gamma = 1. - 1e-3 / self.tau_ref
534 | self.tau_ref = 1. / (1. - self.gamma)
535 | else:
536 | self.gamma = torch.tensor([1 - delta_t / tau_ref], dtype = dtype).to(device)
537 | self.tau_ref = 1. / (1. - self.gamma)
538 |
539 | self.r_scale = 1/(1-self.gamma) # the one comes from decolle, best value ?
540 | #self.q_scale = self.tau_syn/(1-self.beta)
541 | #self.q_scale = self.q_scale.max()
542 | # p_scale should be max overall to differentiate input signals
543 | #self.p_scale = (self.tau_mem * self.q_scale*self.PQ_cap)/(1-self.alpha)
544 | #self.p_scale = self.p_scale.max()
545 |
546 | self.inp_mult_q = self.tau_syn##1/self.PQ_cap * (1-self.beta.max()) #
547 | self.inp_mult_p = self.tau_mem##1/self.PQ_cap * (1-self.alpha.max()) #
548 | #self.pmult = self.p_scale * self.PQ_cap * self.weight_mult
549 |
550 | # those might be clamped as in chop off values.
551 | self.Q_scale = (self.tau_syn/(1-self.beta)).max()
552 | self.P_scale = ((self.tau_mem * self.Q_scale)/(1-self.alpha)).max()
553 | self.Q_scale = (self.tau_syn/(1-self.beta)).max()
554 | self.R_scale = 1/(1-self.gamma)
555 |
556 | if quantization.global_wb is not None:
557 | with torch.no_grad():
558 | self.weights.data = quantization.quant_w(self.weights.data)
559 | if self.bias is not None:
560 | self.bias.data = quantization.quant_w(self.bias.data)
561 |
562 |
563 | def state_init(self, batch_size):
564 | self.P = torch.zeros((batch_size,) + self.inp_shape, dtype = self.dtype).detach().to(self.device)
565 | self.Q = torch.zeros((batch_size,) + self.inp_shape, dtype = self.dtype).detach().to(self.device)
566 | self.R = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device)
567 | self.S = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device)
568 | self.U = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device)
569 |
570 |
571 | def forward(self, input_t, y_local, train_flag = False, test_flag = False):
572 | # probably dont need to quantize because gb steps are arleady in the right level... just clipping
573 | if quantization.global_gb is not None:
574 | with torch.no_grad():
575 | self.weights.data = quantization.clip(self.weights.data/self.weight_mult, quantization.global_gb)*self.weight_mult
576 | if self.bias is not None:
577 | self.bias.data = quantization.clip(self.bias.data/self.weight_mult, quantization.global_gb)*self.weight_mult
578 | if quantization.global_rfb is not None:
579 | # R always using full scale?
580 | self.R = quantization.quant01(self.R/self.R_scale, quantization.global_rfb)*self.R_scale
581 |
582 | #self.P, self.R, self.Q = self.alpha * self.P + self.tau_mem * self.Q, self.gamma * self.R, self.beta * self.Q + self.tau_syn * input_t
583 | #dtype necessary
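# P: membrane trace (decay alpha), Q: synaptic trace (decay beta), R: refractory state (decay gamma)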
584 | self.P, self.R, self.Q = self.alpha * self.P + self.inp_mult_p * self.Q, self.gamma * self.R, self.beta * self.Q + self.inp_mult_q * input_t.type(self.dtype)
585 |
586 | if self.PQ_cap != 1:
587 | self.P = torch.clamp(self.P, 0, self.P_scale*self.PQ_cap)
588 | self.Q = torch.clamp(self.Q, 0, self.Q_scale*self.PQ_cap)
589 |
590 | if quantization.global_pb is not None:
591 | self.P = torch.clamp(self.P/(self.P_scale*self.PQ_cap), 0, 1)
592 | self.P = quantization.quant01(self.P, quantization.global_pb)*(self.P_scale*self.PQ_cap)
593 | if quantization.global_qb is not None:
594 | self.Q = torch.clamp(self.Q/(self.Q_scale*self.PQ_cap), 0, 1)
595 | self.Q = quantization.quant01(self.Q, quantization.global_qb)*(self.Q_scale*self.PQ_cap)
596 |
597 | #self.U = QSConv2dFunctional.apply(self.P * self.pmult, self.weights, self.bias, self.scale, self.padding) - self.R
598 | self.U = QSConv2dFunctional.apply(self.P, self.weights, self.bias, self.scale, self.padding, self.weight_mult) - self.R #* self.r_scale
599 | if quantization.global_ub is not None:
600 | self.U = quantU.apply(self.U)
601 | self.S = (self.U >= self.thr).type(self.dtype) #float()
602 | self.R += self.S * 1#(1-self.gamma)
603 |
604 |
605 | if test_flag or train_flag:
606 | self.U_aux = torch.sigmoid(self.U) # quantize this function.... at some point
607 | self.U_aux = self.mpool(self.U_aux)
608 |
609 | rreadout = self.dropout_learning(self.sign_random_readout(self.U_aux.reshape([input_t.shape[0], np.prod(self.out_shape2)]))) * self.dropout_p
610 |
611 | if train_flag:
612 | if quantization.global_eb is not None:
613 | part1 = quantization.SSE(rreadout, y_local)
614 | #part2 = self.l1 * 200e-1 * F.relu((self.U+.01)).mean()
615 | part2 = self.l1 * 200e-1 * F.relu((self.U_aux+.01)).mean()
616 | #part3 = self.l2 *1e-1* F.relu(.1-self.U_aux.mean())
617 | part3 = self.l2 *1e-1* F.relu(.1-self.U.mean())
618 | loss_gen = part1 + part2 + part3
619 | else:
620 | part1 = self.loss_fn(rreadout, y_local)
621 | #part2 = self.l1 * 200e-1 * F.relu((self.U+.01)).mean()
622 | part2 = self.l1 * 200e-1 * F.relu((self.U_aux+.01)).mean()
623 | #part3 = self.l2 *1e-1* F.relu(.1-self.U_aux.mean())
624 | part3 = self.l2 *1e-1* F.relu(.1-self.U.mean())
625 | loss_gen = part1 + part2 + part3
626 | #loss_gen = self.loss_fn(rreadout, y_local) + self.l1 * 200e-1 * F.relu((self.U+.01)).mean() + self.l2 *1e-1* F.relu(.1-self.U_aux.mean())
627 | #loss_gen = self.loss_fn(rreadout, y_local) + self.l1 * 200e-1 * F.relu((self.U+.01)).mean() + self.l2 *1e-1* F.relu(.1-self.U_aux.mean())
628 | else:
629 | part1 = None
630 | part2 = None
631 | part3 = None
632 | loss_gen = None
633 | else:
634 | part1 = None
635 | part2 = None
636 | part3 = None
637 | loss_gen = None
638 | rreadout = torch.tensor([[0]])
639 |
640 |
641 | return self.mpool(self.S), loss_gen, rreadout.argmax(1), [part1, part2, part3]
642 |
643 |
644 |
645 | class DTNLIFConv2dLayer(nn.Module):
646 | def __init__(self, inp_shape, kernel_size, out_channels, tau_syn, tau_mem, tau_ref, delta_t, pooling = 1, padding = 0, bias = True, thr = 1, device=torch.device("cpu"), dtype = torch.float, dropout_p = .5, output_neurons = 10, loss_fn = None, l1 = 0, l2 = 0, PQ_cap = 1, weight_mult = 4e-5):
647 | super(DTNLIFConv2dLayer, self).__init__()
648 | self.device = device
649 | self.dtype = dtype
650 | self.inp_shape = inp_shape
651 | self.kernel_size = kernel_size
652 | self.out_channels = out_channels
653 | self.output_neurons = output_neurons
654 | self.padding = padding
655 | self.pooling = pooling
656 | self.thr = thr
657 | self.PQ_cap = PQ_cap
658 | self.weight_mult = weight_mult
659 | self.fan_in = kernel_size * kernel_size * inp_shape[0]
660 |
661 | self.dropout_learning = nn.Dropout(p=dropout_p)
662 | self.dropout_p = dropout_p
663 | self.l1 = l1
664 | self.l2 = l2
665 | self.loss_fn = loss_fn
666 |
667 | self.weights = nn.Parameter(torch.empty((self.out_channels, inp_shape[0], self.kernel_size, self.kernel_size), device=device, dtype=dtype, requires_grad=True))
668 |
669 | # decide which one you like
670 | self.stdv = 1 / np.sqrt(self.fan_in) #/ 250 * 1e-2
671 | #self.stdv = np.sqrt(6 / self.fan_in) #* self.weight_mult
672 | if quantization.global_wb is not None:
673 | self.L_min = quantization.global_beta/quantization.step_d(torch.tensor([float(quantization.global_wb)]))
674 | #self.stdv = np.sqrt(6/self.fan_in)
675 | self.scale = 2 ** round(math.log(self.L_min / self.stdv, 2.0))
676 | self.scale = self.scale if self.scale > 1 else 1.0
677 | self.L = np.max([self.stdv, self.L_min])
678 | torch.nn.init.uniform_(self.weights, a = -self.L * self.weight_mult, b = self.L* self.weight_mult)
679 | else:
680 | self.scale = 1
681 | torch.nn.init.uniform_(self.weights, a = -self.stdv * self.weight_mult, b = self.stdv* self.weight_mult)
682 |
683 | # bias has a different scale... just why?
684 | if bias:
685 | self.bias = nn.Parameter(torch.empty(self.out_channels, device=device, dtype=dtype, requires_grad=True))
686 | if quantization.global_wb is not None:
687 | bias_L = np.max([self.stdv* 1e2, self.L_min])
688 | torch.nn.init.uniform_(self.bias, a = -bias_L * self.weight_mult, b = bias_L* self.weight_mult)
689 | else:
690 | torch.nn.init.uniform_(self.bias, a = -self.stdv* self.weight_mult* 1e2, b = self.stdv* self.weight_mult * 1e2)
691 | else:
692 | self.register_parameter('bias', None)
693 |
694 | self.mpool = nn.MaxPool2d(kernel_size = self.pooling, stride = self.pooling, padding = (self.pooling-1)//2, return_indices=False)
695 | self.out_shape2 = self.mpool(QSConv2dFunctional.apply(torch.zeros((1,)+self.inp_shape, dtype = dtype).to(device), self.weights, self.bias, self.scale, self.padding)).shape[1:] #self.pooling,
696 | self.out_shape = QSConv2dFunctional.apply(torch.zeros((1,)+self.inp_shape, dtype = dtype).to(device), self.weights, self.bias, self.scale, self.padding).shape[1:]
697 |
698 | self.sign_random_readout = QLinearLayerSign(input_features = np.prod(self.out_shape2), output_features = output_neurons, pass_through = False, bias = False, dtype = self.dtype, device = device).to(device)
699 |
700 | # tau quantization, static hardware friendly values
701 | if tau_syn.shape[0] == 2:
702 | self.tau_syn = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_syn[0], tau_syn[1]).to(device)
703 | self.beta = 1. - 1e-3 / self.tau_syn
704 | self.tau_syn = 1. / (1. - self.beta)
705 | else:
706 | self.beta = torch.tensor([1 - delta_t / tau_syn], dtype = dtype).to(device)
707 | self.tau_syn = 1. / (1. - self.beta)
708 |
709 |
710 | if tau_mem.shape[0] == 2:
711 | self.tau_mem = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_mem[0], tau_mem[1]).to(device)
712 | self.alpha = 1. - 1e-3 / self.tau_mem
713 | self.tau_mem = 1. / (1. - self.alpha)
714 | else:
715 | self.alpha = torch.tensor([1 - delta_t / tau_mem], dtype = dtype).to(device)
716 | self.tau_mem = 1. / (1. - self.alpha)
717 |
718 |
719 | if tau_ref.shape[0] == 2:
720 | self.tau_ref = torch.empty(torch.Size(self.inp_shape), dtype = dtype).uniform_(tau_ref[0], tau_ref[1]).to(device)
721 | self.gamma = 1. - 1e-3 / self.tau_ref
722 | self.tau_ref = 1. / (1. - self.gamma)
723 | else:
724 | self.gamma = torch.tensor([1 - delta_t / tau_ref], dtype = dtype).to(device)
725 | self.tau_ref = 1. / (1. - self.gamma)
726 |
727 | self.r_scale = 1/(1-self.gamma) # the one comes from decolle, best value ?
728 | #self.q_scale = self.tau_syn/(1-self.beta)
729 | #self.q_scale = self.q_scale.max()
730 | # p_scale should be max overall to differentiate input signals
731 | #self.p_scale = (self.tau_mem * self.q_scale*self.PQ_cap)/(1-self.alpha)
732 | #self.p_scale = self.p_scale.max()
733 |
734 | self.inp_mult_q = self.tau_syn##1/self.PQ_cap * (1-self.beta.max()) #
735 | self.inp_mult_p = self.tau_mem##1/self.PQ_cap * (1-self.alpha.max()) #
736 | #self.pmult = self.p_scale * self.PQ_cap * self.weight_mult
737 |
738 | # those might be clamped as in chop off values.
739 | self.Q_scale = (self.tau_syn/(1-self.beta)).max()
740 | self.P_scale = ((self.tau_mem * self.Q_scale)/(1-self.alpha)).max()
741 | self.Q_scale = (self.tau_syn/(1-self.beta)).max()
742 | self.R_scale = 1/(1-self.gamma)
743 |
744 | if quantization.global_wb is not None:
745 | with torch.no_grad():
746 | self.weights.data = quantization.quant_w(self.weights.data)
747 | if self.bias is not None:
748 | self.bias.data = quantization.quant_w(self.bias.data)
749 |
750 |
751 | def state_init(self, batch_size):
752 | self.P = torch.zeros((batch_size,) + self.inp_shape, dtype = self.dtype).detach().to(self.device)
753 | self.Q = torch.zeros((batch_size,) + self.inp_shape, dtype = self.dtype).detach().to(self.device)
754 | self.R = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device)
755 | self.S = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device)
756 | self.U = torch.zeros((batch_size,) + self.out_shape, dtype = self.dtype).detach().to(self.device)
757 |
758 |
759 | def forward(self, input_t, y_local, train_flag = False, test_flag = False):
760 | # probably dont need to quantize because gb steps are arleady in the right level... just clipping
761 | if quantization.global_gb is not None:
762 | with torch.no_grad():
763 | self.weights.data = quantization.clip(self.weights.data/self.weight_mult, quantization.global_gb)*self.weight_mult
764 | if self.bias is not None:
765 | self.bias.data = quantization.clip(self.bias.data/self.weight_mult, quantization.global_gb)*self.weight_mult
766 | if quantization.global_rfb is not None:
767 | # R always using full scale?
768 | self.R = quantization.quant01(self.R/self.R_scale, quantization.global_rfb)*self.R_scale
769 |
770 | #self.P, self.R, self.Q = self.alpha * self.P + self.tau_mem * self.Q, self.gamma * self.R, self.beta * self.Q + self.tau_syn * input_t
771 | #dtype necessary
772 | self.P, self.R, self.Q = self.alpha * self.P + self.inp_mult_p * self.Q, self.gamma * self.R, self.beta * self.Q + self.inp_mult_q * input_t.type(self.dtype)
773 |
774 | if self.PQ_cap != 1:
775 | self.P = torch.clamp(self.P, 0, self.P_scale*self.PQ_cap)
776 | self.Q = torch.clamp(self.Q, 0, self.Q_scale*self.PQ_cap)
777 |
778 | if quantization.global_pb is not None:
779 | self.P = torch.clamp(self.P/(self.P_scale*self.PQ_cap), 0, 1)
780 | self.P = quantization.quant01(self.P, quantization.global_pb)*(self.P_scale*self.PQ_cap)
781 | if quantization.global_qb is not None:
782 | self.Q = torch.clamp(self.Q/(self.Q_scale*self.PQ_cap), 0, 1)
783 | self.Q = quantization.quant01(self.Q, quantization.global_qb)*(self.Q_scale*self.PQ_cap)
784 |
785 | #self.U = QSConv2dFunctional.apply(self.P * self.pmult, self.weights, self.bias, self.scale, self.padding) - self.R
786 | self.U = QSConv2dFunctional.apply(self.P, self.weights, self.bias, self.scale, self.padding, self.weight_mult) - self.R #* self.r_scale
787 | if quantization.global_ub is not None:
788 | self.U = quantU.apply(self.U)
789 | self.S = (self.U >= self.thr).type(self.dtype)
790 | self.S += (self.U <= -self.thr).type(self.dtype)*-1
791 | self.R += self.S * self.thr#(1-self.gamma)
792 |
793 |
794 | if test_flag or train_flag:
795 | self.U_aux = torch.sigmoid(self.U) # quantize this function.... at some point
796 | self.U_aux = self.mpool(self.U_aux)
797 |
798 | rreadout = self.dropout_learning(self.sign_random_readout(self.U_aux.reshape([input_t.shape[0], np.prod(self.out_shape2)]))) * self.dropout_p
799 |
800 | if train_flag:
801 | if quantization.global_eb is not None:
802 | part1 = quantization.SSE(rreadout, y_local)
803 | part2 = self.l1 * 200e-1 * F.relu((self.U+.01)).mean()
804 | part3 = self.l2 *1e-1* F.relu(.1-self.U_aux.mean())
805 | loss_gen = part1 + part2 + part3
806 | else:
807 | part1 = self.loss_fn(rreadout, y_local)
808 | part2 = self.l1 * 200e-1 * F.relu((self.U+.01)).mean()
809 | part3 = self.l2 *1e-1* F.relu(.1-self.U_aux.mean())
810 | loss_gen = part1 + part2 + part3
811 | #loss_gen = self.loss_fn(rreadout, y_local) + self.l1 * 200e-1 * F.relu((self.U+.01)).mean() + self.l2 *1e-1* F.relu(.1-self.U_aux.mean())
812 | else:
813 | part1 = part2 = part3 = loss_gen = None
814 | else:
815 | part1 = part2 = part3 = loss_gen = None
816 | rreadout = torch.tensor([[0]])
817 |
818 |
819 | return self.mpool(self.S), loss_gen, rreadout.argmax(1), [part1, part2, part3]
820 |
821 |
822 |
--------------------------------------------------------------------------------
/prepGesture.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | import os
4 | import torch
5 | import pickle
6 |
7 |
8 |
9 | def read_aedat31(filename, labels_f, test_set = False):
10 | # https://inivation.com/support/software/fileformat/#aedat-31
11 | # http://research.ibm.com/dvsgesture/
12 | gestures_full = []
13 | labels_full = []
14 |
15 | # Addresses will be interpreted as 32 bits
16 | print(filename)
17 | f = open(filename, 'r', encoding='latin_1')
18 | labels = np.genfromtxt(labels_f, delimiter=',')[1:]
19 | #Skip header lines
20 | bof = f.tell()
21 | line = f.readline()
22 | while (line[0]=='#'):
23 | print(line, end='')
24 | bof = f.tell()
25 | line = f.readline()
26 |
27 | # read data
28 | f.seek(bof,0)
29 | dataArray = np.fromfile(f, '>u4')
   | # ... (lines 30-44 are missing from this dump: extraction of the address and
   | #  timestamp arrays from dataArray, cf. dat2mat in prepPoker.py) ...
45 | x = ( addr >> 17 ) & 0x00001FFF
46 | y = ( addr >> 2 ) & 0x00001FFF
47 | polarity = ( addr >> 1 ) & 0x00000001
48 |
49 | # how to access header info
50 | # dataArray[0] >> 16 # event type -> polarity event
51 | # dataArray[0] & 0xFFFF0000 # event source ID
52 | # dataArray[1] # eventSize
53 | # dataArray[2] # eventTSOffset
54 | # dataArray[3] # eventTSOverflow
55 | # dataArray[4] # eventCapacity (always equals eventNumber)
56 | # dataArray[5] # eventNumber (valid + invalid)
57 | # dataArray[6] # eventValid
58 |
59 | stim = np.array([allTs, x, y, polarity]).T#.astype(int)
60 | for i in labels:
61 |
62 | # chop things right
63 | single_gesture = stim[stim[:, 0] >= i[1]]
64 | single_gesture = single_gesture[single_gesture[:, 0] <= i[2]]
65 |
66 | # bin them 1ms
67 | single_gesture[:,0] = np.floor(single_gesture[:,0]/1000)
68 | single_gesture[:,0] = single_gesture[:,0] - np.min(single_gesture[:,0])
69 |
70 | if test_set:
71 | single_gesture = single_gesture[single_gesture[:,0] <= 1800]
72 |
73 | #if i[0] in labels_full:
74 | # gestures_full[labels_full.index(i[0])] = np.vstack((gestures_full[labels_full.index(i[0])], single_gesture))
75 | #else:
76 | gestures_full.append(single_gesture)
77 | # record label
78 | labels_full.append(i[0])
79 | return gestures_full, labels_full
80 |
81 |
82 | # full set
83 | gestures_full = []
84 | labels_full = []
85 | with open('trials_to_train.txt') as fp:
86 | for cnt, line in enumerate(fp):
87 | try:
88 | gestures_temp, labels_temp = read_aedat31(line.split(".")[0] + ".aedat", line.split(".")[0] + "_labels.csv")
89 | gestures_full += gestures_temp
90 | labels_full += labels_temp
91 | except:
92 | continue
93 |
94 | with open('train_dvs_gesture.pickle', 'wb') as handle:
95 | pickle.dump((gestures_full, labels_full), handle)
96 |
97 |
98 |
99 |
100 | gestures_full = []
101 | labels_full = []
102 | with open('trials_to_test.txt') as fp:
103 | for cnt, line in enumerate(fp):
104 | try:
105 | gestures_temp, labels_temp = read_aedat31(line.split(".")[0] + ".aedat", line.split(".")[0] + "_labels.csv", test_set = True)
106 | gestures_full += gestures_temp
107 | labels_full += labels_temp
108 | except:
109 | continue
110 |
111 | with open('test_dvs_gesture.pickle', 'wb') as handle:
112 | pickle.dump((gestures_full, labels_full), handle)
113 |
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
/prepPoker.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import pickle
4 | import os
5 |
6 |
7 | def dat2mat(filename, retinaSizeX, only_pos=False):
8 | """dat2mat.py: This script converts a aedat file into a list of events.
9 | It only works for 32 unsinged values in the aedat file.
10 |
11 | filename: name of the dat file
12 | retinaSizeX: one dimension of the retina size
13 | only_pos: True to delete all the negative spikes from the dat file
14 | """
15 |     print('Addresses will be interpreted as 32 bits')
16 |     maxEvents = 30e6
17 |     numBytesPerEvent = 8
18 | 
19 |     f = open(filename, 'r', encoding='latin-1')
20 |     bof = f.tell()
21 |     #Skip header lines
22 |     line = f.readline()
23 |     while (line[0]=='#'):
24 |         print(line)
25 |         bof = f.tell()
26 |         line = f.readline()
27 | 
28 |     #Calculate number of events
29 |     f.seek(0,2) #EOF
30 |     numEvents = (f.tell()-bof)/numBytesPerEvent
31 |     if (numEvents>maxEvents):
32 |         print("More events than the maximum events!!!")
33 |         numEvents = maxEvents
34 |     #Read data
35 |     f.seek(bof,0)
36 |     dataArray = np.fromfile(f, '>u4')
37 |     allAddr = dataArray[::2]
38 |     allTs = dataArray[1::2]
39 |     f.close()
40 |     #print allTs
41 | 
42 |     #Define event format
43 |     xmask = 0xFE
44 |     ymask = 0x7F00
45 |     xshift = 1
46 |     yshift = 8
47 |     if (retinaSizeX == 32):
48 |         xshift = 3 #Subsampling of 4
49 |         yshift = 10 #Subsampling of 4
50 |     polmask = 0x1
51 |     addr = abs(allAddr)
52 |     x = (addr & xmask)>>xshift
53 |     y = (addr & ymask)>>yshift
54 |     pol = 1 - (2*(addr & polmask)) #1 for ON, -1 for OFF
55 |     pol = pol.astype(np.int32)
56 |     '''
57 |     #invert x
58 |     x = retinaSizeX - x
59 |     '''
60 |     #Make timestamps relative to the first event
61 |     tpo = allTs
62 |     tpo[:] = tpo[:]-tpo[0]
63 | 
64 |     stim = np.array([tpo, np.zeros(x.size, dtype=int), \
65 |                      -1*np.ones(x.size, dtype=int), x, y, pol])
66 |     stim = np.transpose(stim)
67 | 
68 |     if (only_pos == True):
69 |         res_stim = stim[stim[:,5]==1, :]
70 |     else:
71 |         res_stim = stim
72 | 
73 |     # bin timestamps into 1 ms steps
74 |     res_stim[:,0] = np.floor(res_stim[:,0]/1000)
75 |     #res_stim[:,0] = res_stim[:,0] - np.min(res_stim[:,0])
76 | 
77 |     return res_stim
78 |
79 |
80 |
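  | # chop each ~2-minute suit recording into fixed-length chunks, label each chunk by suit, and pickle an 80/20 train/test split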
81 | #chunk_size = 500    # chunk length in ms; earlier settings kept for reference
82 | #chunk_size = 1300
83 | chunk_size = 2400
84 | file_list = ["RetinaTeresa2-club_long.aedat", "RetinaTeresa2-diamond_long.aedat", "RetinaTeresa2-heart_long.aedat", "RetinaTeresa2-spade_long.aedat"]
85 | start_ts = np.arange(0,121000/chunk_size)*chunk_size
86 | end_ts = np.arange(0,121000/chunk_size)*chunk_size + chunk_size # it's not 3 min... one recording is only 2 min!
87 | cards_full = []
88 | labels_full = []
89 |
90 | for idx,cur_file in enumerate(file_list):
91 |     stim_cur = dat2mat(cur_file, 128, False)
92 |     for i in np.arange(len(start_ts)):
93 |         temp_cur = stim_cur[stim_cur[:,0] >= start_ts[i]]
94 |         temp_cur = temp_cur[temp_cur[:,0] < end_ts[i]]
95 |         if(len(temp_cur) == 0):
96 |             import pdb; pdb.set_trace()  # debugging aid: halts if a chunk contains no events
97 |         temp_cur[:,0] = temp_cur[:,0]-start_ts[i]  # make timestamps chunk-relative
98 |         cards_full.append(temp_cur)
99 |     labels_full += [idx]*len(start_ts)
100 |
101 | #80/20 split train/test
102 | cards_full = np.array(cards_full, dtype=object)  # ragged per-chunk event arrays require an object array on recent NumPy
103 | labels_full = np.array(labels_full)
104 | shuffle_idx = np.arange(len(labels_full))
105 | np.random.shuffle(shuffle_idx)
106 | cards_full = cards_full[shuffle_idx]
107 | labels_full = labels_full[shuffle_idx]
108 |
109 |
110 | with open('slow_poker_'+str(chunk_size)+'_train.pickle', 'wb') as handle:
111 | pickle.dump((cards_full[:int(len(labels_full)*.8) ], labels_full[:int(len(labels_full)*.8) ]), handle)
112 | with open('slow_poker_'+str(chunk_size)+'_test.pickle', 'wb') as handle:
113 | pickle.dump((cards_full[int(len(labels_full)*.8):], labels_full[int(len(labels_full)*.8):]), handle)
114 |
115 |
116 |
--------------------------------------------------------------------------------
/qsnn_decolle.py:
--------------------------------------------------------------------------------
1 | import pickle, argparse, time, math, datetime, uuid
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torchvision
7 | import numpy as np
8 |
9 | import quantization
10 | import localQ
11 | from localQ import sparse_data_generator_Static, sparse_data_generator_DVSGesture, sparse_data_generator_DVSPoker, LIFConv2dLayer, prep_input, acc_comp, create_graph, DTNLIFConv2dLayer, create_graph2
12 |
13 |
14 | # Check whether a GPU is available
15 | if torch.cuda.is_available():
16 | device = torch.device("cuda")
17 | else:
18 | device = torch.device("cpu")
19 | dtype = torch.float32
20 | ms = 1e-3
21 |
22 |
23 |
24 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
25 | parser.add_argument("--data-set", type=str, default="Gesture", help='Input date set: Poker/Gesture')
26 |
27 | parser.add_argument("--global_wb", type=int, default=8, help='Weight bitwidth')
28 | parser.add_argument("--global_qb", type=int, default=10, help='Synapse bitwidth')
29 | parser.add_argument("--global_pb", type=int, default=12, help='Membrane trace bitwidth')
30 | parser.add_argument("--global_rfb", type=int, default=2, help='Refractory bitwidth')
31 |
32 | parser.add_argument("--global_sb", type=int, default=6, help='Learning signal bitwidth')
33 | parser.add_argument("--global_gb", type=int, default=10, help='Gradient bitwidth')
34 | parser.add_argument("--global_eb", type=int, default=6, help='Error bitwidth')
35 |
36 | parser.add_argument("--global_ub", type=int, default=6, help='Membrane Potential bitwidth')
37 | parser.add_argument("--global_ab", type=int, default=6, help='Activation bitwidth')
38 | parser.add_argument("--global_sig", type=int, default=6, help='Sigmoid bitwidth')
39 |
40 | parser.add_argument("--global_rb", type=int, default=16, help='Gradient RNG bitwidth')
41 | parser.add_argument("--global_lr", type=int, default=1, help='Learning rate for quantized gradients')
42 | parser.add_argument("--global_lr_sgd", type=float, default=1.0e-9, help='Learning rate for SGD')
43 | parser.add_argument("--global_beta", type=float, default=1.5, help='Beta for weight init')
44 |
45 | parser.add_argument("--delta_t", type=float, default=1*ms, help='Time step in ms')
46 | parser.add_argument("--input_mode", type=int, default=0, help='Spike processing method')
47 | parser.add_argument("--ds", type=int, default=4, help='Downsampling')
48 | parser.add_argument("--epochs", type=int, default=320, help='Epochs for training')
49 | parser.add_argument("--lr_div", type=int, default=80, help='Learning rate divide interval')
50 | parser.add_argument("--batch_size", type=int, default=72, help='Batch size')
51 |
52 | parser.add_argument("--PQ_cap", type=float, default=1, help='Value cap for membrane and synpase trace')
53 | parser.add_argument("--weight_mult", type=float, default=4e-5, help='Weight multiplier')
54 | parser.add_argument("--dropout_p", type=float, default=.5, help='Dropout probability')
55 | parser.add_argument("--lc_ampl", type=float, default=.5, help='Magnitude amplifier for weight init')
56 | parser.add_argument("--l1", type=float, default=.001, help='Regularizer 1')
57 | parser.add_argument("--l2", type=float, default=.001, help='Regularizer 2')
58 |
59 | parser.add_argument("--tau_mem_lower", type=float, default=5, help='Tau mem lower bound')
60 | parser.add_argument("--tau_mem_upper", type=float, default=35, help='Tau mem upper bound')
61 | parser.add_argument("--tau_syn_lower", type=float, default=5, help='Tau syn lower bound')
62 | parser.add_argument("--tau_syn_upper", type=float, default=10, help='Tau syn upper bound')
63 | parser.add_argument("--tau_ref", type=float, default=1/.35, help='Tau ref')
64 |
65 |
66 | args = parser.parse_args()
67 |
68 |
69 | # set quant level
70 | quantization.global_wb = args.global_wb
71 | quantization.global_qb = args.global_qb
72 | quantization.global_pb = args.global_pb
73 | quantization.global_rfb = args.global_rfb
74 |
75 | quantization.global_sb = args.global_sb
76 | quantization.global_gb = args.global_gb
77 | quantization.global_eb = args.global_eb
78 |
79 | quantization.global_ub = args.global_ub
80 | quantization.global_ab = args.global_ab
81 | quantization.global_sig = args.global_sig
82 |
83 | quantization.global_rb = args.global_rb
84 | quantization.global_lr = args.global_lr
85 | quantization.global_lr_sgd = args.global_lr_sgd
86 | quantization.global_beta = args.global_beta
87 | quantization.weight_mult = args.weight_mult
88 |
89 | localQ.lc_ampl = args.lc_ampl
90 |
91 | tau_mem = torch.tensor([args.tau_mem_lower*ms, args.tau_mem_upper*ms], dtype = dtype).to(device)
92 | tau_ref = torch.tensor([args.tau_ref*ms], dtype = dtype).to(device)
93 | tau_syn = torch.tensor([args.tau_syn_lower*ms, args.tau_syn_upper*ms], dtype = dtype).to(device)
94 |
95 |
96 | if args.data_set == "Poker":
97 | ds_name = "DVS Poker"
98 | with open('data/slow_poker_500_train.pickle', 'rb') as f:
99 | data = pickle.load(f)
100 | x_train = data[0].tolist()
101 | for i in range(len(x_train)):
102 | x_train[i] = x_train[i][:,[0,3,4,5]]
103 | x_train[i][:,3][x_train[i][:,3] == -1] = 0
104 | x_train[i] = x_train[i].astype('uint32')
105 | y_train = data[1]
106 |
107 | x_train = np.array(x_train)
108 | y_train = np.array(y_train)
109 |
110 | idx_temp = np.arange(len(x_train))
111 | np.random.shuffle(idx_temp)
112 | idx_train = idx_temp[0:int(len(y_train)*.8)]
113 | idx_val = idx_temp[int(len(y_train)*.8):]
114 |
115 | x_train, x_val = x_train[idx_train], x_train[idx_val]
116 | y_train, y_val = y_train[idx_train], y_train[idx_val]
117 |
118 | with open('data/slow_poker_500_test.pickle', 'rb') as f:
119 | data = pickle.load(f)
120 | x_test = data[0].tolist()
121 | for i in range(len(x_test)):
122 | x_test[i] = x_test[i][:,[0,3,4,5]]
123 | x_test[i][:,3][x_test[i][:,3] == -1] = 0
124 | x_test[i] = x_test[i].astype('uint32')
125 | y_test = data[1]
126 |
127 | output_neurons = 4
128 | T = 500*ms
129 | T_test = 500*ms
130 | burnin = 50*ms
131 | x_size = 32
132 | y_size = 32
133 | train_tflag = True
134 |
135 |
136 |
137 | elif args.data_set == "Gesture":
138 | ds_name = "DVS Gesture"
139 | with open('data/train_dvs_gesture88.pickle', 'rb') as f:
140 | data = pickle.load(f)
141 | x_train = np.array(data[0])
142 | y_train = np.array(data[1], dtype = int) - 1  # gesture labels in the pickle are 1-based
143 |
144 | idx_temp = np.arange(len(x_train))
145 | np.random.shuffle(idx_temp)
146 | idx_train = idx_temp[0:int(len(y_train)*.8)]
147 | idx_val = idx_temp[int(len(y_train)*.8):]
148 |
149 | x_train, x_val = x_train[idx_train], x_train[idx_val]
150 | y_train, y_val = y_train[idx_train], y_train[idx_val]
151 |
152 |
153 | with open('data/test_dvs_gesture88.pickle', 'rb') as f:
154 | data = pickle.load(f)
155 | x_test = data[0]
156 | y_test = np.array(data[1], dtype = int) - 1
157 |
158 | output_neurons = 11
159 | T = 500*ms
160 | T_test = 1800*ms
161 | burnin = 50*ms
162 | x_size = 32
163 | y_size = 32
164 | train_tflag = False
165 | else:
166 | raise Exception("Data set unknown.")
167 |
168 | sl1_loss = torch.nn.MSELoss()  # per-layer readout loss (MSE, despite the "sl1" name)
169 |
170 | thr = torch.tensor([.0], dtype = dtype).to(device)
171 | layer1 = LIFConv2dLayer(inp_shape = (2, x_size, y_size), kernel_size = 7, out_channels = 64, tau_mem = tau_mem, tau_syn = tau_syn, tau_ref = tau_ref, delta_t = args.delta_t, pooling = 2, padding = 2, thr = thr, device = device, dropout_p = args.dropout_p, output_neurons = output_neurons, loss_fn = sl1_loss, l1 = args.l1, l2 = args.l2, PQ_cap = args.PQ_cap, weight_mult = args.weight_mult, dtype = dtype).to(device)
172 |
173 | layer2 = LIFConv2dLayer(inp_shape = layer1.out_shape2, kernel_size = 7, out_channels = 128, tau_mem = tau_mem, tau_syn = tau_syn, tau_ref = tau_ref, delta_t = args.delta_t, pooling = 1, padding = 2, thr = thr, device = device, dropout_p = args.dropout_p, output_neurons = output_neurons, loss_fn = sl1_loss, l1 = args.l1, l2 = args.l2, PQ_cap = args.PQ_cap, weight_mult = args.weight_mult, dtype = dtype).to(device)
174 |
175 | layer3 = LIFConv2dLayer(inp_shape = layer2.out_shape2, kernel_size = 7, out_channels = 128, tau_mem = tau_mem, tau_syn = tau_syn, tau_ref = tau_ref, delta_t = args.delta_t, pooling = 2, padding = 2, thr = thr, device = device, dropout_p = args.dropout_p, output_neurons = output_neurons, loss_fn = sl1_loss, l1 = args.l1, l2 = args.l2, PQ_cap = args.PQ_cap, weight_mult = args.weight_mult, dtype = dtype).to(device)
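  | # three locally trained LIF conv layers (DECOLLE-style): each layer carries its own
  | # random sign readout and local loss, and the three local losses are summed for the update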
176 |
177 |
178 | all_parameters = list(layer1.parameters()) + list(layer2.parameters()) + list(layer3.parameters())
179 |
180 | # initialize optimizer; with quantized gradients the effective step size comes from quantization.global_lr, so SGD uses lr = 1
181 | if quantization.global_gb is not None:
182 | opt = torch.optim.SGD(all_parameters, lr = 1)
183 | else:
184 | opt = torch.optim.SGD(all_parameters, lr = quantization.global_lr_sgd)
185 |
186 | def eval_test():
187 | batch_corr = {'train1': [], 'test1': [],'train2': [], 'test2': [],'train3': [], 'test3': [], 'loss':[], 'act_train1':0, 'act_train2':0, 'act_train3':0, 'act_test1':0, 'act_test2':0, 'act_test3':0, 'w1u':0, 'w2u':0, 'w3u':0}
188 | # test accuracy
189 | for x_local, y_local in sparse_data_generator_DVSGesture(x_test, y_test, batch_size = args.batch_size, nb_steps = T_test / ms, shuffle = True, device = device, test = True, ds = args.ds, x_size = x_size, y_size = y_size):
190 | rread_hist1_test = []
191 | rread_hist2_test = []
192 | rread_hist3_test = []
193 |
194 | y_onehot = torch.Tensor(len(y_local), output_neurons).to(device).type(dtype)
195 | y_onehot.zero_()
196 | y_onehot.scatter_(1, y_local.reshape([y_local.shape[0],1]), 1)
197 |
198 |
199 | layer1.state_init(x_local.shape[0])
200 | layer2.state_init(x_local.shape[0])
201 | layer3.state_init(x_local.shape[0])
202 |
203 | for t in range(int(T_test/ms)):
204 | test_flag = (t > int(burnin/ms))
205 |
206 | out_spikes1, temp_loss1, temp_corr1, _ = layer1.forward(prep_input(x_local[:,:,:,:,t], args.input_mode), y_onehot, test_flag = test_flag)
207 | out_spikes2, temp_loss2, temp_corr2, _ = layer2.forward(out_spikes1, y_onehot, test_flag = test_flag)
208 | out_spikes3, temp_loss3, temp_corr3, _ = layer3.forward(out_spikes2, y_onehot, test_flag = test_flag)
209 |
210 | if test_flag:
211 | rread_hist1_test.append(temp_corr1)
212 | rread_hist2_test.append(temp_corr2)
213 | rread_hist3_test.append(temp_corr3)
214 |
215 |
216 | batch_corr['test1'].append(acc_comp(rread_hist1_test, y_local, True))
217 | batch_corr['test2'].append(acc_comp(rread_hist2_test, y_local, True))
218 | batch_corr['test3'].append(acc_comp(rread_hist3_test, y_local, True))
219 |
220 | return torch.cat(batch_corr['test3']).mean()
221 |
222 |
223 | w1, w2, w3, b1, b2, b3 = None, None, None, None, None, None
224 |
225 | diff_layers_acc = {'train1': [], 'test1': [],'train2': [], 'test2': [],'train3': [], 'test3': [], 'loss':[], 'act_train1':[], 'act_train2':[], 'act_train3':[], 'act_test1':[], 'act_test2':[], 'act_test3':[], 'w1update':[], 'w2update':[], 'w3update':[]}
226 | print("WUPQR SASigEG Quantization: {0}{1}{2}{3}{4} {5}{6}{7}{8}{9} l1 {10:.3f} l2 {11:.3f} Inp {12} LR {13} Drop {14} Cap {15} thr {16}".format(quantization.global_wb, quantization.global_ub, quantization.global_pb, quantization.global_qb, quantization.global_rfb, quantization.global_sb, quantization.global_ab, quantization.global_sig, quantization.global_eb, quantization.global_gb, args.l1, args.l2, args.input_mode, quantization.global_lr if quantization.global_lr != None else quantization.global_lr_sgd, args.dropout_p, args.PQ_cap, thr.item()))
227 | plot_file_name = "DVS_WPQUEG{0}{1}{2}{3}{4}{5}{6}_Inp{7}_LR{8}_Drop{9}_thr{10}".format(quantization.global_wb, quantization.global_pb, quantization.global_qb, quantization.global_ub, quantization.global_eb, quantization.global_gb, quantization.global_sb, args.input_mode, quantization.global_lr, args.dropout_p, thr.item())+datetime.datetime.now().strftime("_%Y%m%d_%H%M%S")
228 | print("Epoch Loss Train1 Train2 Train3 Test1 Test2 Test3 | TrainT TestT")
229 |
230 | best_vali = torch.tensor(0, device = device)
231 |
232 | for e in range(args.epochs):
233 | if ((e+1)%args.lr_div)==0:
234 | if quantization.global_gb is not None:
235 | quantization.global_lr /= 2
236 | else:
237 | opt.param_groups[-1]['lr'] /= 5
238 |
239 |
240 | batch_corr = {'train1': [], 'test1': [],'train2': [], 'test2': [],'train3': [], 'test3': [], 'loss':[], 'act_train1':0, 'act_train2':0, 'act_train3':0, 'act_test1':0, 'act_test2':0, 'act_test3':0, 'w1u':0, 'w2u':0, 'w3u':0}
241 | quantization.global_w1update = 0
242 | quantization.global_w2update = 0
243 | quantization.global_w3update = 0
244 | start_time = time.time()
245 |
246 | # training
247 | for x_local, y_local in sparse_data_generator_DVSGesture(x_train, y_train, batch_size = args.batch_size, nb_steps = T / ms, shuffle = True, test = train_tflag, device = device, ds = args.ds, x_size = x_size, y_size = y_size):
248 |
249 | y_onehot = torch.Tensor(len(y_local), output_neurons).to(device).type(dtype)
250 | y_onehot.zero_()
251 | y_onehot.scatter_(1, y_local.reshape([y_local.shape[0],1]), 1)
252 |
253 | rread_hist1_train = []
254 | rread_hist2_train = []
255 | rread_hist3_train = []
256 | loss_hist = []
257 |
258 |
259 | layer1.state_init(x_local.shape[0])
260 | layer2.state_init(x_local.shape[0])
261 | layer3.state_init(x_local.shape[0])
262 |
263 | for t in range(int(T/ms)):
264 | train_flag = (t > int(burnin/ms))
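  | # losses, weight updates, and readout statistics start only after the 50 ms burn-in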
265 |
266 | out_spikes1, temp_loss1, temp_corr1, lparts1 = layer1.forward(prep_input(x_local[:,:,:,:,t], args.input_mode), y_onehot, train_flag = train_flag)
267 | out_spikes2, temp_loss2, temp_corr2, lparts2 = layer2.forward(out_spikes1, y_onehot, train_flag = train_flag)
268 | out_spikes3, temp_loss3, temp_corr3, lparts3 = layer3.forward(out_spikes2, y_onehot, train_flag = train_flag)
269 |
270 |
271 |
272 | if train_flag:
273 | loss_gen = temp_loss1 + temp_loss2 + temp_loss3
274 |
275 | loss_gen.backward()
276 | opt.step()
277 | opt.zero_grad()
278 |
279 | loss_hist.append(loss_gen.item())
280 | rread_hist1_train.append(temp_corr1)
281 | rread_hist2_train.append(temp_corr2)
282 | rread_hist3_train.append(temp_corr3)
283 |
284 |
285 | batch_corr['act_train1'] += int(out_spikes1.sum())
286 | batch_corr['act_train2'] += int(out_spikes2.sum())
287 | batch_corr['act_train3'] += int(out_spikes3.sum())
288 |
289 |
290 | batch_corr['train1'].append(acc_comp(rread_hist1_train, y_local, True))
291 | batch_corr['train2'].append(acc_comp(rread_hist2_train, y_local, True))
292 | batch_corr['train3'].append(acc_comp(rread_hist3_train, y_local, True))
293 | del x_local, y_local, y_onehot
294 |
295 |
296 | train_time = time.time()
297 |
298 | diff_layers_acc['train1'].append(torch.cat(batch_corr['train1']).mean())
299 | diff_layers_acc['train2'].append(torch.cat(batch_corr['train2']).mean())
300 | diff_layers_acc['train3'].append(torch.cat(batch_corr['train3']).mean())
301 | diff_layers_acc['act_train1'].append(batch_corr['act_train1'])
302 | diff_layers_acc['act_train2'].append(batch_corr['act_train2'])
303 | diff_layers_acc['act_train3'].append(batch_corr['act_train3'])
304 | diff_layers_acc['loss'].append(np.mean(loss_hist)/3)
305 | diff_layers_acc['w1update'].append(quantization.global_w1update)
306 | diff_layers_acc['w2update'].append(quantization.global_w2update)
307 | diff_layers_acc['w3update'].append(quantization.global_w3update)
308 |
309 |
310 | # test accuracy
311 | for x_local, y_local in sparse_data_generator_DVSGesture(x_val, y_val, batch_size = args.batch_size, nb_steps = T_test / ms, shuffle = True, device = device, test = True, ds = args.ds, x_size = x_size, y_size = y_size):
312 | rread_hist1_test = []
313 | rread_hist2_test = []
314 | rread_hist3_test = []
315 |
316 | y_onehot = torch.Tensor(len(y_local), output_neurons).to(device).type(dtype)
317 | y_onehot.zero_()
318 | y_onehot.scatter_(1, y_local.reshape([y_local.shape[0],1]), 1)
319 |
320 |
321 | layer1.state_init(x_local.shape[0])
322 | layer2.state_init(x_local.shape[0])
323 | layer3.state_init(x_local.shape[0])
324 |
325 | for t in range(int(T_test/ms)):
326 | test_flag = (t > int(burnin/ms))
327 |
328 | out_spikes1, temp_loss1, temp_corr1, _ = layer1.forward(prep_input(x_local[:,:,:,:,t], args.input_mode), y_onehot, test_flag = test_flag)
329 | out_spikes2, temp_loss2, temp_corr2, _ = layer2.forward(out_spikes1, y_onehot, test_flag = test_flag)
330 | out_spikes3, temp_loss3, temp_corr3, _ = layer3.forward(out_spikes2, y_onehot, test_flag = test_flag)
331 |
332 | if test_flag:
333 | rread_hist1_test.append(temp_corr1)
334 | rread_hist2_test.append(temp_corr2)
335 | rread_hist3_test.append(temp_corr3)
336 |
337 | batch_corr['act_test1'] += int(out_spikes1.sum())
338 | batch_corr['act_test2'] += int(out_spikes2.sum())
339 | batch_corr['act_test3'] += int(out_spikes3.sum())
340 |
341 | batch_corr['test1'].append(acc_comp(rread_hist1_test, y_local, True))
342 | batch_corr['test2'].append(acc_comp(rread_hist2_test, y_local, True))
343 | batch_corr['test3'].append(acc_comp(rread_hist3_test, y_local, True))
344 | del x_local, y_local, y_onehot
345 |
346 | inf_time = time.time()
347 |
348 | if best_vali.item() < torch.cat(batch_corr['test3']).mean().item():
349 | best_vali = torch.cat(batch_corr['test3']).mean()
350 | test_acc_best_vali = eval_test()
351 | w1 = layer1.weights.data.detach().cpu()
352 | w2 = layer2.weights.data.detach().cpu()
353 | w3 = layer3.weights.data.detach().cpu()
354 | b1 = layer1.bias.data.detach().cpu()
355 | b2 = layer2.bias.data.detach().cpu()
356 | b3 = layer3.bias.data.detach().cpu()
357 |
358 | diff_layers_acc['test1'].append(torch.cat(batch_corr['test1']).mean())
359 | diff_layers_acc['test2'].append(torch.cat(batch_corr['test2']).mean())
360 | diff_layers_acc['test3'].append(torch.cat(batch_corr['test3']).mean())
361 | diff_layers_acc['act_test1'].append(batch_corr['act_test1'])
362 | diff_layers_acc['act_test2'].append(batch_corr['act_test2'])
363 | diff_layers_acc['act_test3'].append(batch_corr['act_test3'])
364 |
365 | print("{0:02d} {1:.3E} {2:.4f} {3:.4f} {4:.4f} {5:.4f} {6:.4f} {7:.4f} | {8:.4f} {9:.4f}".format(e+1, diff_layers_acc['loss'][-1], diff_layers_acc['train1'][-1], diff_layers_acc['train2'][-1], diff_layers_acc['train3'][-1], diff_layers_acc['test1'][-1], diff_layers_acc['test2'][-1], diff_layers_acc['test3'][-1], train_time - start_time, inf_time - train_time))
366 | create_graph(plot_file_name, diff_layers_acc, ds_name, test_acc_best_vali)
367 |
368 |
369 |
370 | # saving results and weights
371 | results = {
372 | 'layer1':[layer1.weights.detach().cpu(), layer1.bias.detach().cpu(), w1, b1, layer1.sign_random_readout.weights.detach().cpu(), layer1.sign_random_readout.weight_fa.detach().cpu(), layer1.tau_mem.cpu(), layer1.tau_syn.cpu(), layer1.tau_ref.cpu()],
373 | 'layer2':[layer2.weights.detach().cpu(), layer2.bias.detach().cpu(), w2, b2, layer2.sign_random_readout.weights.detach().cpu(), layer2.sign_random_readout.weight_fa.detach().cpu(), layer2.tau_mem.cpu(), layer2.tau_syn.cpu(), layer2.tau_ref.cpu()],
374 | 'layer3':[layer3.weights.detach().cpu(), layer3.bias.detach().cpu(), w3, b3, layer3.sign_random_readout.weights.detach().cpu(), layer3.sign_random_readout.weight_fa.detach().cpu(), layer3.tau_mem.cpu(), layer3.tau_syn.cpu(), layer3.tau_ref.cpu()],
375 | 'acc': diff_layers_acc, 'fname':plot_file_name, 'args': args, 'evaled_test':test_acc_best_vali}
376 | with open('results/'+plot_file_name+'.pkl', 'wb') as f:
377 | pickle.dump(results, f)
378 |
379 |
--------------------------------------------------------------------------------
/qsnn_precise.py:
--------------------------------------------------------------------------------
1 | import argparse, pickle
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision
6 | import numpy as np
7 |
8 | import spytorch_util
9 | import quantization
10 |
11 | dtype = torch.float
12 |
13 | # Check whether a GPU is available
14 | if torch.cuda.is_available():
15 | device = torch.device("cuda")
16 | else:
17 | device = torch.device("cpu")
18 |
19 | # Code is based on: https://github.com/fzenke/spytorch
20 |
21 |
22 | parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
23 |
24 | parser.add_argument("--input", type=str, default="./data/input_700_250_25.pkl", help='Input pickle')
25 | parser.add_argument("--target", type=str, default="./data/smile95.pkl", help='Target pattern pickle')
26 |
27 | parser.add_argument("--global_wb", type=int, default=2, help='Weight bitwidth')
28 | parser.add_argument("--global_ab", type=int, default=8, help='Membrane potential, synapse state bitwidth')
29 | parser.add_argument("--global_gb", type=int, default=8, help='Gradient bitwidth')
30 | parser.add_argument("--global_eb", type=int, default=8, help='Error bitwidth')
31 | parser.add_argument("--global_rb", type=int, default=16, help='Gradient RNG bitwidth')
32 |
33 | parser.add_argument("--time_step", type=float, default=1e-3, help='Simulation time step size')
34 | parser.add_argument("--nb_steps", type=float, default=250, help='Simulation steps')
35 | parser.add_argument("--nb_epochs", type=float, default=10000, help='Simulation steps')
36 |
37 | parser.add_argument("--tau_mem", type=float, default=10e-3, help='Time constant for membrane potential')
38 | parser.add_argument("--tau_syn", type=float, default=5e-3, help='Time constant for synapse')
39 | parser.add_argument("--tau_vr", type=float, default=5e-3, help='Time constant for Van Rossum distance')
40 | parser.add_argument("--alpha", type=float, default=.75, help='Time constant for synapse')
41 | parser.add_argument("--beta", type=float, default=.875, help='Time constant for Van Rossum distance')
42 |
43 | parser.add_argument("--nb_inputs", type=int, default=700, help='Spatial input dimensions')
44 | parser.add_argument("--nb_hidden", type=int, default=400, help='Spatial hidden dimensions')
45 | parser.add_argument("--nb_outputs", type=int, default=250, help='Spatial output dimensions')
46 |
47 | args = parser.parse_args()
48 |
49 |
50 | quantization.global_wb = args.global_wb
51 | quantization.global_ab = args.global_ab
52 | quantization.global_gb = args.global_gb
53 | quantization.global_eb = args.global_eb
54 | quantization.global_rb = args.global_rb
55 | stop_quant_level = 33  # bitwidths of 33 or more are treated as full precision (quantization disabled)
56 |
57 | time_step = args.time_step
58 | nb_steps = args.nb_steps
59 | tau_mem = args.tau_mem
60 | tau_syn = args.tau_syn
61 | tau_vr = args.tau_vr
62 |
63 | alpha = args.alpha
64 | beta = args.beta
65 |
66 | nb_inputs = args.nb_inputs
67 | nb_hidden = args.nb_hidden
68 | nb_outputs = args.nb_outputs
69 |
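  | # leaky integrator: u[t] = alpha*u[t-1] + input[t]; note the time_step and tau arguments are currently unused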
70 | def conv_exp_kernel(inputs, time_step, tau, device):
71 |     dtype = torch.float
72 |     nb_hidden = inputs.shape[1]
73 |     nb_steps = inputs.shape[0]
74 | 
75 |     u = torch.zeros((nb_hidden), device=device, dtype=dtype)
76 |     rec_u = []
77 | 
78 |     for t in range(nb_steps):
79 |         u = alpha*u + inputs[t,:]
80 |         rec_u.append(u)
81 | 
82 |     rec_u = torch.stack(rec_u,dim=0)
83 |     return rec_u
84 |
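  | # Van Rossum distance: exponentially filter both spike trains and take the L2 norm of their difference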
85 | def van_rossum(x, y, time_step, tau, device):
86 |     tild_x = conv_exp_kernel(x, time_step, tau, device)
87 |     tild_y = conv_exp_kernel(y, time_step, tau, device)
88 |     return torch.sqrt(1/tau*torch.sum((tild_x - tild_y)**2))
89 |
90 | class SuperSpike(torch.autograd.Function):
91 |     scale = 100.0 # controls steepness of surrogate gradient
92 |     @staticmethod
93 |     def forward(ctx, input):
94 |         ctx.save_for_backward(input)
95 |         out = torch.zeros_like(input)
96 |         out[input > 0] = 1.0  # Heaviside step: spike where the input crosses zero
97 |         return out
98 | 
99 |     @staticmethod
100 |     def backward(ctx, grad_output):
101 |         input, = ctx.saved_tensors
102 |         grad_input = grad_output.clone()
103 |         grad = grad_input/(SuperSpike.scale*torch.abs(input)+1.0)**2  # fast-sigmoid surrogate gradient (Zenke & Ganguli, 2018)
104 |         return grad
105 |
106 |
107 | class einsum_linear(torch.autograd.Function):
108 | @staticmethod
109 | def forward(ctx, input, weight, scale, bias=None):
110 | if quantization.global_wb < stop_quant_level:
111 | w_quant = quantization.quant_w(weight, scale)
112 | else:
113 | w_quant = weight
114 |
115 | h1 = torch.einsum("bc,cd->bd", (input, w_quant))
116 |
117 |         if bias is not None:
118 |             h1 += bias.unsqueeze(0).expand_as(h1)
119 |
120 | ctx.save_for_backward(input, w_quant, bias)
121 |
122 | return h1
123 |
124 | @staticmethod
125 | def backward(ctx, grad_output):
126 | input, w_quant, bias = ctx.saved_tensors
127 | grad_input = grad_weight = grad_bias = None
128 | if quantization.global_eb < stop_quant_level:
129 | quant_error = quantization.quant_err(grad_output)
130 | else:
131 | quant_error = grad_output
132 |
133 | if ctx.needs_input_grad[0]:
134 | # propagate quantized error
135 | grad_input = torch.einsum("bc,dc->bd", (quant_error, w_quant))
136 |
137 | if ctx.needs_input_grad[1]:
138 | if quantization.global_gb < stop_quant_level:
139 | grad_weight = quantization.quant_grad(torch.einsum("bc,bd->dc", (quant_error, input))).float()
140 | else:
141 | grad_weight = torch.einsum("bc,bd->dc", (quant_error, input))
142 |
143 | if bias is not None and ctx.needs_input_grad[2]:
144 | grad_bias = grad_output.sum(0).squeeze(0)
145 |
146 | return grad_input, grad_weight, grad_bias
147 |
148 |
149 | class custom_quant(torch.autograd.Function):
150 | @staticmethod
151 | def forward(ctx, input, b_level):
152 | if quantization.global_ab < stop_quant_level:
153 | output, clip_info = quantization.quant_act(input)
154 | else:
155 | output, clip_info = input, None
156 | ctx.save_for_backward(clip_info)
157 | return output
158 |
159 | @staticmethod
160 | def backward(ctx, grad_output):
161 | clip_info = ctx.saved_tensors
162 | if quantization.global_eb < stop_quant_level:
163 | quant_error = quantization.quant_err(grad_output) * clip_info[0].float()
164 | else:
165 | quant_error = grad_output
166 | return quant_error, None
167 |
168 |
169 | def run_snn(inputs):
170 | with torch.no_grad():
171 | spytorch_util.w1.data = quantization.clip(spytorch_util.w1.data, quantization.global_wb)
172 | spytorch_util.w2.data = quantization.clip(spytorch_util.w2.data, quantization.global_wb)
173 |
174 |
175 | h1 = einsum_linear.apply(inputs, spytorch_util.w1, scale1)
176 |
177 | syn = torch.zeros((nb_hidden), device=device, dtype=dtype)
178 | mem = torch.zeros((nb_hidden), device=device, dtype=dtype)
179 |
180 | mem_rec = []
181 | spk_rec = []
182 |
183 | # Compute hidden layer activity
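  |     # per time step: spike when mem - 0.9 crosses zero; syn decays by alpha and
  |     # integrates the input drive; mem decays by beta, integrates syn, and is
  |     # reset by one unit after a spike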
184 | for t in range(nb_steps):
185 |         mthr = mem-.9  # membrane potential relative to the firing threshold (0.9)
186 | mthr = custom_quant.apply(mthr, quantization.global_ab)
187 | out = spike_fn(mthr)
188 |
189 | rst = torch.zeros_like(mem)
190 | c = (mthr > 0)
191 | rst[c] = torch.ones_like(mem)[c]
192 |
193 | new_syn = alpha*syn +h1[t,:]
194 | new_syn = custom_quant.apply(new_syn, quantization.global_ab)
195 | new_mem = beta*mem +syn -rst
196 | new_mem = custom_quant.apply(new_mem, quantization.global_ab)
197 |
198 | syn = new_syn
199 | mem = new_mem
200 |
201 | mem_rec.append(mem)
202 | spk_rec.append(out)
203 |
204 | mem_rec1 = torch.stack(mem_rec,dim=0)
205 | spk_rec1 = torch.stack(spk_rec,dim=0)
206 |
207 |
208 | #Readout layer
209 | h2 = einsum_linear.apply(spk_rec1, spytorch_util.w2, scale2)
210 |
211 | syn = torch.zeros((nb_outputs), device=device, dtype=dtype)
212 | mem = torch.zeros((nb_outputs), device=device, dtype=dtype)
213 |
214 | mem_rec = []
215 | spk_rec = []
216 |
217 | for t in range(nb_steps):
218 | mthr = mem-.9
219 | mthr = custom_quant.apply(mthr, quantization.global_ab)
220 | out = spike_fn(mthr)
221 |
222 | rst = torch.zeros_like(mem)
223 | c = (mthr > 0)
224 | rst[c] = torch.ones_like(mem)[c]
225 |
226 | new_syn = alpha*syn +h2[t,:]
227 | new_syn = custom_quant.apply(new_syn, quantization.global_ab)
228 | new_mem = beta*mem +syn -rst
229 | new_mem = custom_quant.apply(new_mem, quantization.global_ab)
230 |
231 | mem = new_mem
232 | syn = new_syn
233 |
234 | mem_rec.append(mem)
235 | spk_rec.append(out)
236 |
237 | mem_rec2 = torch.stack(mem_rec,dim=0)
238 | spk_rec2 = torch.stack(spk_rec,dim=0)
239 |
240 |
241 | other_recs = [mem_rec1, spk_rec1, mem_rec2]
242 | return spk_rec2, other_recs
243 |
244 |
245 | def train(x_data, y_data, lr=1e-3, nb_epochs=10):
246 | params = [spytorch_util.w1,spytorch_util.w2]
247 | optimizer = torch.optim.Adamax(params, lr=lr, betas=(0.9,0.999))
248 |
249 | loss_hist = []
250 | acc_hist = []
251 | for e in range(nb_epochs):
252 | output,recs = run_snn(x_data)
253 | loss_val = van_rossum(output, y_data, time_step, tau_syn, device)
254 |
255 | optimizer.zero_grad()
256 | loss_val.backward()
257 | optimizer.step()
258 |
259 | loss_hist.append(loss_val.item())
260 | print("Epoch %i: loss=%.5f"%(e+1,loss_val.item()))
261 |
262 | return loss_hist, output
263 |
264 | spike_fn = SuperSpike.apply
265 |
266 |
267 | quantization.global_beta = quantization.step_d(quantization.global_wb)-.5
268 | with open(args.input, 'rb') as f:
269 | x_train = pickle.load(f).t().to(device)
270 |
271 |
272 | with open(args.target, 'rb') as f:
273 | y_train = torch.tensor(pickle.load(f)).to(device)
274 | y_train = y_train.type(dtype)
275 |
276 |
277 | bit_string = str(quantization.global_wb) + str(quantization.global_ab) + str(quantization.global_gb) + str(quantization.global_eb)
278 |
279 | print("Start Training")
280 | print(bit_string)
281 |
282 | spytorch_util.w1 = torch.empty((nb_inputs, nb_hidden), device=device, dtype=dtype, requires_grad=True)
283 | scale1 = quantization.init_layer_weights(spytorch_util.w1, 28*28).to(device)
284 |
285 | spytorch_util.w2 = torch.empty((nb_hidden, nb_outputs), device=device, dtype=dtype, requires_grad=True)
286 | scale2 = quantization.init_layer_weights(spytorch_util.w2, 28*28).to(device)
287 |
288 |
289 | quantization.global_lr = .1
290 | loss_hist, output = train(x_train, y_train, lr = 1, nb_epochs = args.nb_epochs)
291 |
292 | bit_string = str(quantization.global_wb) + str(quantization.global_ab) + str(quantization.global_gb) + str(quantization.global_eb)
293 |
294 | results = {'bit_string': bit_string ,'loss_hist': loss_hist, 'output': output.cpu()}
295 |
296 | with open('results/snn_smile_precise_'+bit_string+'.pkl', 'wb') as f:
297 | pickle.dump(results, f)
298 |
299 |
300 |
301 |
--------------------------------------------------------------------------------
/qsnn_util.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | from matplotlib.gridspec import GridSpec
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torchvision
10 |
11 |
12 | time_step = 1e-3
13 |
14 | # Check whether a GPU is available
15 | if torch.cuda.is_available():
16 | device = torch.device("cuda")
17 | else:
18 | device = torch.device("cpu")
19 |
20 | w1 = None
21 | w2 = None
22 |
23 |
24 | def current2firing_time(x, tau=20, thr=0.2, tmax=1.0, epsilon=1e-7):
25 | """ Computes first firing time latency for a current input x assuming the charge time of a current based LIF neuron.
26 |
27 | Args:
28 | x -- The "current" values
29 |
30 | Keyword args:
31 | tau -- The membrane time constant of the LIF neuron to be charged
32 | thr -- The firing threshold value
33 | tmax -- The maximum time returned
34 | epsilon -- A generic (small) epsilon > 0
35 |
36 | Returns:
37 | Time to first spike for each "current" x
38 | """
39 |     idx = x < thr
126 |         out[input > 0] = 1.0
127 |         return out
128 |
129 |     @staticmethod
130 |     def backward(ctx, grad_output):
131 |         """
132 |         In the backward pass we receive the gradient of the loss with respect
133 |         to the output and compute the surrogate gradient of the loss with
134 |         respect to the input. Here we use the normalized negative part of a
135 |         fast sigmoid, as in Zenke & Ganguli (2018).
136 |         """
137 |         input, = ctx.saved_tensors
138 |         grad_input = grad_output.clone()
139 |         grad = grad_input/(SuperSpike.scale*torch.abs(input)+1.0)**2
140 |         return grad
141 |
142 |
143 | def sparse_data_generator_DVS(X, y, batch_size, nb_steps, nb_units, shuffle, time_step, device):
144 |     """ This generator takes datasets in analog format and generates spiking network input as sparse tensors.
145 | 
146 |     Args:
147 |         X: The data ( sample x event x 2 ); the last dim holds (time, neuron) tuples
148 |         y: The labels
149 |     """
150 |
151 |     try:
152 |         labels_ = np.array(y.cpu(), dtype=int)
153 |     except AttributeError:  # y may be a NumPy array rather than a torch tensor
154 |         labels_ = np.array(y, dtype=int)
155 |     number_of_batches = len(y)//batch_size
156 |     sample_index = np.arange(len(y))
157 | 
158 | 
159 |     if shuffle:
160 |         np.random.shuffle(sample_index)
161 | 
162 |     total_batch_count = 0
163 |     counter = 0
164 |     while counter < number_of_batches:
--------------------------------------------------------------------------------
/quantization.py:
--------------------------------------------------------------------------------
151 | def quant_err(x):
152 |     # if (torch.abs(x) > 1).sum() != 0:
153 |     #     import pdb; pdb.set_trace()
154 |     alpha = shift(torch.max(torch.abs(x)))
155 |     return quant(clip(x / alpha, global_eb), global_eb)
156 |
157 | def init_layer_weights(weights_layer, shape, factor=1):
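  |     # WAGE-style init: weights drawn uniformly within the fan-in limit, with a
  |     # power-of-two scale relating the per-bitwidth magnitude Wm to that limit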
158 | fan_in = shape
159 |
160 | limit = torch.sqrt(torch.tensor([3*factor/fan_in]))
161 | Wm = global_beta/step_d(torch.tensor([float(global_wb)]))
162 | scale = 2 ** round(math.log(Wm / limit, 2.0))
163 | scale = scale if scale > 1 else 1.0
164 | limit = Wm if Wm > limit else limit
165 |
166 | torch.nn.init.uniform_(weights_layer, a = -float(limit), b = float(limit))
167 | weights_layer.data = quant_generic(weights_layer.data, global_gb)[0]
168 | return torch.tensor([float(scale)])
169 |
170 | # sum of square errors
171 | def SSE(y_true, y_pred):
172 | return 0.5 * torch.sum((y_true - y_pred)**2)
173 |
174 | def to_cat(inp_tensor, num_class, device):
175 | out_tensor = torch.zeros([inp_tensor.shape[0], num_class], device=device)
176 |     out_tensor[torch.arange(inp_tensor.shape[0], device=device), torch.as_tensor(inp_tensor, dtype=torch.long, device=device)] = 1
177 | return out_tensor
178 |
179 | # Inherit from Function
180 | class clee_LinearFunction(torch.autograd.Function):
181 |
182 | # Note that both forward and backward are @staticmethods
183 | @staticmethod
184 | # bias is an optional argument
185 | def forward(ctx, input, weight, scale, act, act_q, bias=None):
186 | # prep and save
187 | w_quant = quant_w(weight, scale)
188 | input = input.float()
189 |
190 | # compute output
191 | output = input.mm(w_quant.t())
192 |
193 | relu_mask = torch.ones(output.shape).to(output.device)
194 | clip_info = torch.ones(output.shape).to(output.device)
195 |
196 | # add relu and quant optionally
197 | if act:
198 | output = F.relu(output)
199 | relu_mask = (output != 0)
200 | if act_q:
201 | output, clip_info = quant_act(output)
202 | if bias is not None:
203 | output += bias.unsqueeze(0).expand_as(output)
204 |
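  |         # combined straight-through mask: backward multiplies the quantized error by
  |         # this, blocking gradients where ReLU was inactive or the activation quantizer clipped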
205 | gradient_mask = relu_mask * clip_info
206 |
207 | ctx.save_for_backward(input, w_quant, bias, gradient_mask)
208 | return output
209 |
210 | # This function has only a single output, so it gets only one gradient
211 | @staticmethod
212 | def backward(ctx, grad_output):
213 | input, w_quant, bias, gradient_mask = ctx.saved_tensors
214 | grad_input = grad_weight = grad_bias = None
215 | quant_error = quant_err(grad_output) * gradient_mask.float()
216 |
217 | if ctx.needs_input_grad[0]:
218 | # propagate quantized error
219 | grad_input = quant_error.mm(w_quant)
220 | if ctx.needs_input_grad[1]:
221 | grad_weight = quant_grad(quant_error.t().mm(input)).float()
222 |
223 | if bias is not None and ctx.needs_input_grad[2]:
224 | grad_bias = grad_output.sum(0).squeeze(0)
225 |
226 | return grad_input, grad_weight, grad_bias, None, None
227 |
228 | # Inherit from Function
229 | class clee_conv2d(torch.autograd.Function):
230 | # Note that both forward and backward are @staticmethods
231 | @staticmethod
232 | # bias is an optional argument
233 | def forward(ctx, input, weight, scale, act=False, act_q=False, pool=False, bias=None):
234 | mpool1 = nn.MaxPool2d(2, stride=2, return_indices=True)
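  |         # return_indices=True keeps the argmax locations so backward can unpool the gradient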
235 |
236 | # prep and save
237 | w_quant = quant_w(weight, scale)
238 | input = input.float()
239 |
240 | # compute output
241 | output = F.conv2d(input, w_quant, bias=None, stride=1, padding=0, dilation=1, groups=1)
242 | relu_mask = torch.ones(output.shape).to(output.device)
243 | clip_info = torch.ones(output.shape).to(output.device)
244 | pool_indices = torch.ones(output.shape).to(output.device)
245 | size_pool = torch.tensor([0])
246 |
247 | # add pool, relu, quant optionally
248 | if pool:
249 | size_pool = output.shape
250 | output, pool_indices = mpool1(output)
251 | if act:
252 | output = F.relu(output)
253 | relu_mask = (output != 0)
254 | if act_q:
255 | output, clip_info = quant_act(output)
256 | if bias is not None:
257 | output += bias.unsqueeze(0).expand_as(output)
258 |
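  |         # analogous straight-through mask; backward applies it to grad_output before unpooling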
259 | gradient_mask = relu_mask * clip_info
260 |
261 |         ctx.save_for_backward(input, w_quant, bias, torch.tensor([pool]), gradient_mask, pool_indices, torch.as_tensor(size_pool))
262 | return output
263 |
264 | # This function has only a single output, so it gets only one gradient
265 | @staticmethod
266 | def backward(ctx, grad_output):
267 | unpool1 = nn.MaxUnpool2d(2, stride=2, padding = 0)
268 |
269 | input, w_quant, bias, pool, gradient_mask, pool_indices, size_pool = ctx.saved_tensors
270 | grad_input = grad_weight = grad_bias = None
271 |
272 | grad_output = grad_output * gradient_mask.float()
273 | if pool:
274 | grad_output = unpool1(grad_output, pool_indices, output_size = torch.Size(size_pool))
275 |
276 | quant_error = quant_err(grad_output)
277 |
278 | if ctx.needs_input_grad[0]:
279 | # propagate quantized error
280 | grad_input = torch.nn.grad.conv2d_input(input.shape, w_quant, quant_error)
281 | if ctx.needs_input_grad[1]:
282 | grad_weight = quant_grad(torch.nn.grad.conv2d_weight(input, w_quant.shape, quant_error)).float()
283 |
284 | if bias is not None and ctx.needs_input_grad[2]:
285 | grad_bias = grad_output.sum(0).squeeze(0)
286 |
287 | return grad_input, grad_weight, grad_bias, None, None, None
--------------------------------------------------------------------------------