├── .gitignore ├── .vscode └── launch.json ├── README.md ├── assets ├── acquisition_setup.JPG ├── figure_encoding-reconstruct.JPG ├── pipeline1.gif ├── pipeline2.gif └── pipeline3.gif ├── notebooks ├── braille_encoder_rsnn.py ├── braille_reading_ffsnn.ipynb ├── braille_reading_ffsnn_Loihi.ipynb ├── braille_reading_rsnn.ipynb ├── braille_reading_rsnn.py ├── braille_reading_rsnn_2layer.ipynb ├── braille_reading_rsnn_3layer.ipynb ├── braille_reading_rsnn_Loihi.ipynb ├── spytorch_rsnn.ipynb ├── spytorch_rsnn.py └── spytorch_rsnn_encoder.py ├── parameters ├── parameters_th1.txt ├── parameters_th10.txt ├── parameters_th2.txt └── parameters_th5.txt ├── spytorch2loihi ├── SlayerSNN_src │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── learningStats.cpython-38.pyc │ │ ├── optimizer.cpython-38.pyc │ │ ├── quantizeParams.cpython-38.pyc │ │ ├── slayer.cpython-38.pyc │ │ ├── slayerLoihi.cpython-38.pyc │ │ ├── slayerParams.cpython-38.pyc │ │ ├── spikeClassifier.cpython-38.pyc │ │ ├── spikeFileIO.cpython-38.pyc │ │ ├── spikeLoss.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── auto │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── assistant.cpython-38.pyc │ │ │ ├── dataset.cpython-38.pyc │ │ │ └── loihi.cpython-38.pyc │ │ ├── assistant.py │ │ ├── dataset.py │ │ └── loihi.py │ ├── cuda │ │ ├── convKernels.h │ │ ├── shiftKernels.h │ │ ├── slayerKernels.cu │ │ ├── slayerLoihiKernels.cu │ │ ├── spikeKernels.h │ │ └── spikeLoihiKernels.h │ ├── learningStats.py │ ├── optimizer.py │ ├── quantizeParams.py │ ├── slayer.py │ ├── slayerLoihi.py │ ├── slayerParams.py │ ├── spikeClassifier.py │ ├── spikeFileIO.py │ ├── spikeLoss.py │ └── utils.py ├── netsLoihi │ └── netLoihi_rec_th1_6.net ├── spytorch2loihi_export_fsnn.ipynb ├── spytorch2loihi_export_rsnn.ipynb └── weights │ └── SpyTorch_trained_weights_rec_th1_6.pt └── utils └── event_transform.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/* -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | } 15 | ] 16 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Braille letter reading: A benchmark for spatio-temporal pattern recognition on neuromorphic hardware 2 | 3 | You can find the Frontiers publication "Braille letter reading: A benchmark for spatio-temporal pattern recognition on neuromorphic hardware" [here](https://www.frontiersin.org/articles/10.3389/fnins.2022.951164/full). 4 | 5 | # Braille letters 6 | 7 | Braille is a tactile writing system used by people who are visually impaired. These characters have rectangular blocks called *cells* that have tiny bumps called *raised dots*. The number and arrangement of these dots distinguish one character from another. For more details and background information see [here](https://en.wikipedia.org/wiki/Braille). 
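The conversion to events described in **The Dataset** section below (delta coding with a threshold and a refractory period) can be summarized in a few lines. The sketch below is illustrative only; the repository's actual conversion lives in [`utils/event_transform.py`](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/utils/event_transform.py), and the function and parameter names here are hypothetical:

```python
import numpy as np

def delta_encode(trace, threshold=0.5, refractory=0.0025, dt=1.0/40):
    """Send-on-delta encoding of one analog taxel trace sampled at 1/dt Hz."""
    on_events, off_events = [], []
    reference = trace[0]       # change is accumulated relative to this value
    blocked_until = -np.inf    # end of the current refractory period
    for step, sample in enumerate(trace):
        t = step * dt
        if t < blocked_until:  # no events during the refractory period
            continue
        delta = sample - reference
        if delta >= threshold:
            on_events.append(t)     # 'ON' event: signal rose by >= threshold
        elif delta <= -threshold:
            off_events.append(t)    # 'OFF' event: signal fell by >= threshold
        else:
            continue
        reference = sample          # start accumulating change again from here
        blocked_until = t + refractory
    return np.array(on_events), np.array(off_events)
```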
8 | 
9 | ![braille_system_english](https://user-images.githubusercontent.com/60852381/120632860-bb6c9e00-c469-11eb-8b33-47df012f76b0.jpg)
10 | 
11 | # The Dataset
12 | The dataset is composed of different levels of complexity, from single letters to words. The 27 letters (Space + A-Z) were recorded using the iCub fingertip sliding over 3D-printed stimuli. The fingertip was mounted on a 3-axis robot (omega.3, [Force Dimension](https://www.forcedimension.com/products/omega)) and moved over each braille letter 50 times at a constant velocity (0.01 m/s) and a sampling frequency of 40 Hz. The analog data are converted into spike trains afterwards.
13 | Delta coding is used for the conversion. No additional noise is added because the analog recordings already contain sensor noise. A binary event ('ON'/'OFF') is created whenever the accumulated change crosses a predefined threshold, followed by a refractory period. At the end of the refractory period, change is accumulated again until the threshold is crossed and a new event is elicited. The threshold is 0.5 (for both ON and OFF events) and the refractory period is 0.0025 s. The spike trains of single letters are concatenated to compose words.
14 | 
15 | Experimental Setup | Encoding Scheme
16 | :------------:|:------------:
17 | ![experimental_setup](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/assets/acquisition_setup.JPG) | ![encoding_scheme](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/assets/figure_encoding-reconstruct.JPG)
18 | 
19 | Scanning | Sample-based | Event-based
20 | :------------:|:------------:|:------------:
21 | ![scanning](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/assets/pipeline1.gif) | ![sample_based](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/assets/pipeline2.gif) | ![event_based](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/assets/pipeline3.gif)
22 | 
23 | # How-to
24 | 1. Install [Python](https://www.python.org/), [PyTorch](https://pytorch.org/), [NumPy](https://numpy.org/), [scikit-learn](https://scikit-learn.org/stable/), [pandas](https://pandas.pydata.org/) and [matplotlib](https://matplotlib.org/) for plotting
25 | 2. Download the [dataset](https://zenodo.org/record/7050094) from Zenodo
26 | 3. Extract the files and place them in the main folder of this repository
27 | 4. Run the Jupyter notebook for the [feedforward SNN](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/notebooks/braille_reading_ffsnn.ipynb) and/or the [recurrent SNN](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/notebooks/braille_reading_rsnn.ipynb)
28 | 5. 
To use encoding thresholds that are not already provided, run [this file](https://github.com/event-driven-robotics/tactile_braille_reading/blob/main/utils/event_transform.py) with your own parameters and adjust the data loading in the notebooks accordingly
29 | 
--------------------------------------------------------------------------------
/assets/acquisition_setup.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/assets/acquisition_setup.JPG
--------------------------------------------------------------------------------
/assets/figure_encoding-reconstruct.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/assets/figure_encoding-reconstruct.JPG
--------------------------------------------------------------------------------
/assets/pipeline1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/assets/pipeline1.gif
--------------------------------------------------------------------------------
/assets/pipeline2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/assets/pipeline2.gif
--------------------------------------------------------------------------------
/assets/pipeline3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/assets/pipeline3.gif
--------------------------------------------------------------------------------
/notebooks/braille_encoder_rsnn.py:
--------------------------------------------------------------------------------
1 | # from IPython.display import clear_output
2 | import os
3 | import pickle
4 | import gzip
5 | 
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from matplotlib.gridspec import GridSpec
9 | import seaborn as sns
10 | 
11 | import torch
12 | import torch.nn as nn
13 | from torch.utils.data import TensorDataset, DataLoader
14 | 
15 | dtype = torch.float
16 | 
17 | # Check whether a GPU is available
18 | if torch.cuda.is_available():
19 |     device = torch.device("cuda")
20 | else:
21 |     device = torch.device("cpu")
22 | 
23 | # Seed random number generators to ensure reproducibility
24 | seed = 42
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 | 
28 | 
29 | # data structure: [trial number] x ['key'] x [time] x [sensor_nr]
30 | 
31 | file_name = './data/data_braille_letters_all.pkl'
32 | # file = open(file_name, 'rb')
33 | # data_dict = pickle.load(file_name)
34 | # file.close()
35 | 
36 | with open(file_name, 'rb') as fp:
37 |     data_dict = pickle.load(fp)
38 | 
39 | letter_written = ['Space', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K',
40 |                   'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
41 | nb_repetitions = int(len(data_dict)/len(letter_written))  # trials per letter
42 | # Extract data: one trace per letter and repetition, in file order
43 | data = []
44 | for i, letter in enumerate(letter_written):
45 |     for repetition in range(nb_repetitions):
46 |         dat = 1.0 - data_dict['taxel_data'][i*nb_repetitions + repetition]/255  # rescale 8-bit readings to [0, 1] and invert
47 |         data.append(dat)
48 | # labels.append(i)
49 | # data = data_dict['taxel_data'].to_numpy()
50 | # data = data.dtype(np.float64)/255 # set to float type
51 | labels = data_dict['letter'].to_numpy()
52 | labels_str, labels = np.unique(labels, return_inverse=True)
53 | unique_labels = np.unique(labels)
54 | 
55 | # Crop to same length
56 | # data_steps = l = np.min([len(d) for d in data])
57 | # data = torch.tensor([d[:l] for d in data], dtype=dtype)
58 | data_steps = len(data[0])  # 350
59 | data = torch.as_tensor(data, dtype=dtype)
60 | # labels as integer class indices
61 | 
62 | labels = torch.tensor(labels, dtype=torch.long)
63 | 
64 | # Select nonzero inputs
65 | nzid = [1, 2, 6, 10]
66 | data = data[:, :, nzid]
67 | 
68 | # Standardize data
69 | rshp = data.reshape((-1, data.shape[2]))
70 | data = (data-rshp.mean(0))/(rshp.std(0)+1e-3)
71 | 
72 | # Upsample in time (repeat each sample n times)
73 | 
74 | 
75 | def upsample(data, n=2):
76 |     shp = data.shape  # (trials, time, channels)
77 |     tmp = data.reshape((shp[0], shp[1], 1, shp[2]))
78 |     tmp = tmp.tile((1, 1, n, 1))
79 |     return tmp.reshape((shp[0], n*shp[1], shp[2]))
80 | 
81 | 
82 | nb_upsample = 2
83 | data = upsample(data, n=nb_upsample)
84 | 
85 | # Shuffle data
86 | idx = np.arange(len(data))
87 | np.random.shuffle(idx)
88 | data = data[idx]
89 | labels = labels[idx]
90 | 
91 | # Perform train/test split
92 | a = int(0.8*len(idx))
93 | x_train, x_test = data[:a], data[a:]
94 | y_train, y_test = labels[:a], labels[a:]
95 | 
96 | ds_train = TensorDataset(x_train, y_train)
97 | ds_test = TensorDataset(x_test, y_test)
98 | 
99 | # data = torch.tensor([d[:l] for d in data], dtype=dtype)
100 | 
101 | # Visualize single data point
102 | i = 123
103 | plt.plot(data[i])
104 | plt.title("Letter %s" % letter_written[labels[i]])
105 | sns.despine()
106 | 
107 | 
108 | nb_channels = len(nzid)
109 | enc_fan_out = 32  # Num of spiking neurons used to encode each channel
110 | 
111 | # Network parameters
112 | nb_inputs = nb_channels*enc_fan_out
113 | nb_hidden = 450
114 | nb_outputs = len(unique_labels)
115 | # TODO needs to be updated to reflect the correct time scale
116 | time_step = 2e-3/nb_upsample
117 | # TODO We should change this and upsample the input data
118 | nb_steps = nb_upsample*data_steps
119 | 
120 | batch_size = 128
121 | 
122 | print("Number of training data %i" % len(ds_train))
123 | print("Number of testing data %i" % len(ds_test))
124 | print("Number of outputs %i" % nb_outputs)
125 | print("Number of timesteps %i" % nb_steps)
126 | 
127 | tau_mem = 20e-3
128 | tau_syn = 10e-3
129 | 
130 | alpha = float(np.exp(-time_step/tau_syn))
131 | beta = float(np.exp(-time_step/tau_mem))
132 | 
133 | encoder_weight_scale = 1.0
134 | fwd_weight_scale = 3.0
135 | rec_weight_scale = 1e-2*fwd_weight_scale
136 | 
137 | # Parameters
138 | 
139 | # Encoder
140 | enc_gain = torch.empty((nb_inputs,), device=device,
141 |                        dtype=dtype, requires_grad=True)
142 | enc_bias = torch.empty((nb_inputs,), device=device,
143 |                        dtype=dtype, requires_grad=True)
144 | # TODO update this parameter
145 | torch.nn.init.normal_(enc_gain, mean=0.0, std=encoder_weight_scale)
146 | torch.nn.init.normal_(enc_bias, mean=0.0, std=1.0)
147 | 
148 | # Spiking network
149 | w1 = torch.empty((nb_inputs, nb_hidden), device=device,
150 |                  dtype=dtype, requires_grad=True)
151 | torch.nn.init.normal_(w1, mean=0.0, std=fwd_weight_scale/np.sqrt(nb_inputs))
152 | 
153 | w2 = torch.empty((nb_hidden, nb_outputs), device=device,
154 |                  dtype=dtype, requires_grad=True)
155 | torch.nn.init.normal_(w2, mean=0.0, std=fwd_weight_scale/np.sqrt(nb_hidden))
156 | 
157 | v1 = torch.empty((nb_hidden, 
nb_hidden), device=device, 158 | dtype=dtype, requires_grad=True) 159 | torch.nn.init.normal_(v1, mean=0.0, std=rec_weight_scale/np.sqrt(nb_hidden)) 160 | 161 | print("init done") 162 | 163 | 164 | def plot_voltage_traces(mem, spk=None, dim=(3, 5), spike_height=5, **kwargs): 165 | gs = GridSpec(*dim) 166 | if spk is not None: 167 | dat = 1.0*mem 168 | dat[spk > 0.0] = spike_height 169 | dat = dat.detach().cpu().numpy() 170 | else: 171 | dat = mem.detach().cpu().numpy() 172 | for i in range(np.prod(dim)): 173 | if i == 0: 174 | a0 = ax = plt.subplot(gs[i]) 175 | else: 176 | ax = plt.subplot(gs[i], sharey=a0) 177 | ax.plot(dat[i], **kwargs) 178 | ax.axis("off") 179 | 180 | 181 | def live_plot(loss): 182 | if len(loss) == 1: 183 | return 184 | # clear_output(wait=True) 185 | ax = plt.figure(figsize=(3, 2), dpi=150).gca() 186 | ax.plot(range(1, len(loss) + 1), loss) 187 | ax.set_xlabel("Epoch") 188 | ax.set_ylabel("Loss") 189 | ax.xaxis.get_major_locator().set_params(integer=True) 190 | sns.despine() 191 | plt.show() 192 | 193 | 194 | class SurrGradSpike(torch.autograd.Function): 195 | """ 196 | Here we implement our spiking nonlinearity which also implements 197 | the surrogate gradient. By subclassing torch.autograd.Function, 198 | we will be able to use all of PyTorch's autograd functionality. 199 | Here we use the normalized negative part of a fast sigmoid 200 | as this was done in Zenke & Ganguli (2018). 201 | """ 202 | 203 | scale = 20.0 # controls steepness of surrogate gradient 204 | 205 | @staticmethod 206 | def forward(ctx, input): 207 | """ 208 | In the forward pass we compute a step function of the input Tensor 209 | and return it. ctx is a context object that we use to stash information which 210 | we need to later backpropagate our error signals. To achieve this we use the 211 | ctx.save_for_backward method. 212 | """ 213 | ctx.save_for_backward(input) 214 | out = torch.zeros_like(input) 215 | out[input > 0] = 1.0 216 | return out 217 | 218 | @staticmethod 219 | def backward(ctx, grad_output): 220 | """ 221 | In the backward pass we receive a Tensor we need to compute the 222 | surrogate gradient of the loss with respect to the input. 223 | Here we use the normalized negative part of a fast sigmoid 224 | as this was done in Zenke & Ganguli (2018). 
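Concretely, the surrogate derivative applied below is 1/(scale*|input| + 1)**2, i.e. the derivative of the normalized fast sigmoid f(x) = x/(1 + scale*|x|), evaluated at the membrane potential's distance from threshold.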
225 | """ 226 | input, = ctx.saved_tensors 227 | grad_input = grad_output.clone() 228 | grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2 229 | return grad 230 | 231 | 232 | # here we overwrite our naive spike function by the "SurrGradSpike" nonlinearity which implements a surrogate gradient 233 | spike_fn = SurrGradSpike.apply 234 | 235 | 236 | def run_snn(inputs): 237 | bs = inputs.shape[0] 238 | enc = torch.zeros((bs, nb_inputs), device=device, dtype=dtype) 239 | input_spk = torch.zeros((bs, nb_inputs), device=device, dtype=dtype) 240 | syn = torch.zeros((bs, nb_hidden), device=device, dtype=dtype) 241 | mem = -1e-3*torch.ones((bs, nb_hidden), device=device, dtype=dtype) 242 | out = torch.zeros((bs, nb_hidden), device=device, dtype=dtype) 243 | 244 | enc_rec = [] 245 | mem_rec = [] 246 | spk_rec = [] 247 | 248 | # encoder_currents = torch.einsum("abc,c->ab", (inputs.tile((enc_fan_out,)), enc_gain))+enc_bias 249 | encoder_currents = enc_gain*(inputs.tile((enc_fan_out,))+enc_bias) 250 | for t in range(nb_steps): 251 | # Compute encoder activity 252 | new_enc = (beta*enc + (1.0-beta) * 253 | encoder_currents[:, t])*(1.0-input_spk.detach()) 254 | input_spk = spike_fn(enc-1.0) 255 | 256 | # Compute hidden layer activity 257 | h1 = input_spk.mm(w1) + torch.einsum("ab,bc->ac", (out, v1)) 258 | mthr = mem-1.0 259 | out = spike_fn(mthr) 260 | rst = out.detach() # We do not want to backprop through the reset 261 | 262 | new_syn = alpha*syn + h1 263 | new_mem = (beta*mem + (1.0-beta)*syn)*(1.0-rst) 264 | 265 | # Here we store some state variables so we can look at them later. 266 | mem_rec.append(mem.detach()) 267 | spk_rec.append(out.detach()) 268 | enc_rec.append(enc.detach()) 269 | 270 | enc = new_enc 271 | mem = new_mem 272 | syn = new_syn 273 | 274 | enc_rec = torch.stack(enc_rec, dim=1) 275 | mem_rec = torch.stack(mem_rec, dim=1) 276 | spk_rec = torch.stack(spk_rec, dim=1) 277 | 278 | # Readout layer 279 | h2 = torch.einsum("abc,cd->abd", (spk_rec, w2)) 280 | flt = torch.zeros((bs, nb_outputs), device=device, dtype=dtype) 281 | out = torch.zeros((bs, nb_outputs), device=device, dtype=dtype) 282 | out_rec = [out] 283 | for t in range(nb_steps): 284 | new_flt = alpha*flt + h2[:, t] 285 | new_out = beta*out + (1.0-beta)*flt 286 | 287 | flt = new_flt 288 | out = new_out 289 | 290 | out_rec.append(out) 291 | 292 | out_rec = torch.stack(out_rec, dim=1) 293 | other_recs = [enc_rec, mem_rec, spk_rec] 294 | return out_rec, other_recs 295 | 296 | 297 | def train(dataset, lr=1e-3, nb_epochs=10): 298 | 299 | params = [enc_gain, enc_bias, w1, w2, v1] 300 | optimizer = torch.optim.Adamax(params, lr=lr, betas=(0.9, 0.995)) 301 | 302 | log_softmax_fn = nn.LogSoftmax(dim=1) 303 | loss_fn = nn.NLLLoss() 304 | 305 | generator = DataLoader(dataset, batch_size=batch_size, 306 | shuffle=True, num_workers=2) 307 | 308 | loss_hist = [] 309 | for e in range(nb_epochs): 310 | local_loss = [] 311 | for x_local, y_local in generator: 312 | x_local, y_local = x_local.to(device), y_local.to(device) 313 | output, recs = run_snn(x_local) 314 | _, _, spks = recs 315 | m, _ = torch.max(output, 1) 316 | log_p_y = log_softmax_fn(m) 317 | 318 | # Here we can set up our regularizer loss 319 | # e.g., L1 loss on total number of spikes 320 | reg_loss = 1e-3*torch.mean(torch.sum(spks, 1)) 321 | # reg_loss = 0.0 322 | 323 | # Here we combine supervised loss and the regularizer 324 | loss_val = loss_fn(log_p_y, y_local) + reg_loss 325 | 326 | optimizer.zero_grad() 327 | loss_val.backward() 328 | optimizer.step() 329 | 
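# store the scalar batch loss so the epoch mean can be reported below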
local_loss.append(loss_val.item()) 330 | 331 | mean_loss = np.mean(local_loss) 332 | loss_hist.append(mean_loss) 333 | # live_plot(loss_hist) 334 | print("Epoch %i: loss=%.5f" % (e+1, mean_loss)) 335 | 336 | return loss_hist 337 | 338 | 339 | def compute_classification_accuracy(dataset): 340 | """ Computes classification accuracy on supplied data in batches. """ 341 | generator = DataLoader(dataset, batch_size=batch_size, 342 | shuffle=False, num_workers=2) 343 | accs = [] 344 | for x_local, y_local in generator: 345 | x_local, y_local = x_local.to(device), y_local.to(device) 346 | output, _ = run_snn(x_local) 347 | m, _ = torch.max(output, 1) # max over time 348 | _, am = torch.max(m, 1) # argmax over output units 349 | # compare to labels 350 | tmp = np.mean((y_local == am).detach().cpu().numpy()) 351 | accs.append(tmp) 352 | return np.mean(accs) 353 | 354 | 355 | nb_epochs = 200 356 | 357 | loss_hist = train(ds_train, lr=1e-2, nb_epochs=nb_epochs) 358 | 359 | print("Training accuracy: %.3f" % (compute_classification_accuracy(ds_train))) 360 | print("Test accuracy: %.3f" % (compute_classification_accuracy(ds_test))) 361 | 362 | 363 | # Let's run the network on a single batch from the test set 364 | data_loader = DataLoader(ds_test, batch_size=batch_size, shuffle=False) 365 | x_batch, y_batch = next(iter(data_loader)) 366 | output, other_recordings = run_snn(x_batch.to(device)) 367 | enc_rec, mem_rec, spk_rec = other_recordings 368 | 369 | # This is how our spiking encoders convert the current based input into spike trains (we plot 20) 370 | 371 | fig = plt.figure(dpi=150, figsize=(7, 3)) 372 | plot_voltage_traces(enc_rec[:, :, :20], color="black", alpha=0.2) 373 | 374 | # Let's take a look at the readout layer activity 375 | 376 | fig = plt.figure(dpi=150, figsize=(7, 3)) 377 | plot_voltage_traces(output) 378 | 379 | # Let's plot the hiddden layer spiking activity for some input stimuli 380 | 381 | nb_plt = 4 382 | gs = GridSpec(1, nb_plt) 383 | fig = plt.figure(figsize=(7, 3), dpi=150) 384 | for i in range(nb_plt): 385 | plt.subplot(gs[i]) 386 | plt.imshow(spk_rec[i].detach().cpu().numpy().T, 387 | cmap=plt.cm.gray_r, origin="lower") 388 | if i == 0: 389 | plt.xlabel("Time") 390 | plt.ylabel("Units") 391 | 392 | sns.despine() 393 | -------------------------------------------------------------------------------- /parameters/parameters_th1.txt: -------------------------------------------------------------------------------- 1 | scale 5 2 | time_bin_size 5 3 | nb_input_copies 2 4 | tau_mem 0.06 5 | tau_ratio 10 6 | fwd_weight_scale 1 7 | weight_scale_factor 0.01 8 | reg_spikes 0.004 9 | reg_neurons 0.000001 10 | -------------------------------------------------------------------------------- /parameters/parameters_th10.txt: -------------------------------------------------------------------------------- 1 | scale 10 2 | time_bin_size 5 3 | nb_input_copies 2 4 | tau_mem 0.07 5 | tau_ratio 10 6 | fwd_weight_scale 4.0 7 | weight_scale_factor 0.015 8 | reg_spikes 0.0015 9 | reg_neurons 0.0 10 | -------------------------------------------------------------------------------- /parameters/parameters_th2.txt: -------------------------------------------------------------------------------- 1 | scale 15 2 | time_bin_size 3 3 | nb_input_copies 8 4 | tau_mem 0.05 5 | tau_ratio 10 6 | fwd_weight_scale 1 7 | weight_scale_factor 0.02 8 | reg_spikes 0.0015 9 | reg_neurons 0.0 10 | -------------------------------------------------------------------------------- /parameters/parameters_th5.txt: 
-------------------------------------------------------------------------------- 1 | scale 10 2 | time_bin_size 3 3 | nb_input_copies 4 4 | tau_mem 0.07 5 | tau_ratio 10 6 | fwd_weight_scale 1.5 7 | weight_scale_factor 0.035 8 | reg_spikes 0.001 9 | reg_neurons 0.0 10 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__init__.py: -------------------------------------------------------------------------------- 1 | from .slayer import spikeLayer as layer 2 | from .slayerLoihi import spikeLayer as loihi 3 | # from slayer import yamlParams as params 4 | from .slayerParams import yamlParams as params 5 | from .spikeLoss import spikeLoss as loss 6 | from .spikeClassifier import spikeClassifier as predict 7 | from . import spikeFileIO as io 8 | from . import utils 9 | # This will be removed later. Kept for compatibility only 10 | from .quantizeParams import quantizeWeights as quantize 11 | 12 | # from .slayer import spikeLayer as layer 13 | # from .slayerLoihi import spikeLayer as loihi 14 | # # from slayer import yamlParams as params 15 | # from .slayerParams import yamlParams as params 16 | # from .spikeLoss import spikeLoss as loss 17 | # from .spikeClassifier import spikeClassifier as predict 18 | # from . import spikeFileIO as io 19 | # from .quantizeParams import quantizeWeights as quantize 20 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/learningStats.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/learningStats.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/optimizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/optimizer.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/quantizeParams.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/quantizeParams.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/slayer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/slayer.cpython-38.pyc -------------------------------------------------------------------------------- 
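The parameters/parameters_th*.txt files above are plain whitespace-separated name/value pairs. A minimal loader sketch (a hypothetical helper, not part of the repository):

def load_parameters(path):
    # Parse one parameters_th*.txt file into a {name: float} dict, e.g.
    # load_parameters('parameters/parameters_th1.txt')['tau_mem'] -> 0.06.
    # Integer-valued entries (e.g. nb_input_copies) come back as floats.
    params = {}
    with open(path) as f:
        for line in f:
            if not line.strip():
                continue
            name, value = line.split()
            params[name] = float(value)
    return params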
/spytorch2loihi/SlayerSNN_src/__pycache__/slayerLoihi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/slayerLoihi.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/slayerParams.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/slayerParams.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/spikeClassifier.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/spikeClassifier.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/spikeFileIO.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/spikeFileIO.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/spikeLoss.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/spikeLoss.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/auto/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import SlayerDataset as dataset 2 | from .assistant import Assistant as assistant 3 | from . 
import loihi -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/auto/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/auto/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/auto/__pycache__/assistant.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/auto/__pycache__/assistant.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/auto/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/auto/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/auto/__pycache__/loihi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/SlayerSNN_src/auto/__pycache__/loihi.cpython-38.pyc -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/auto/assistant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from ..spikeClassifier import spikeClassifier as predict 4 | # import slayerCuda # LK: debug 5 | from datetime import datetime 6 | 7 | class Assistant: 8 | ''' 9 | This class provides standard assistant functionalities for traiing and testing workflow. 10 | If you want a different workflow than what is available, you should inherit this module and 11 | overload the particular module to your need. 12 | 13 | Arguments: 14 | * ``net``: the SLAYER network to be run. 15 | * ``trainLoader``: training dataloader. 16 | * ``testLoader``: testing dataloader. 17 | * ``error``: a function object or a lamda function that takes (output, target, label) as its input and returns 18 | a scalar error value. 19 | * ``optimizer``: the learning optimizer. 20 | * ``scheduler``: the learning scheduler. Default: ``None`` meaning no scheduler will be used. 21 | * ``stats``: the SLAYER learning stats logger: ``slayerSNN.stats``. Default: ``None`` meaning no stats will be used. 22 | * ``dataParallel``: flag if dataParallel execution needs to be handled. Default: ``False``. 23 | * ``showTimeSteps``: flag to print timesteps of the sample or not. Default: ``False``. 24 | * ``lossScale``: a scale factor to be used while printing the loss. Default: ``None`` meaning no scaling is done. 25 | * ``printInterval``: number of epochs to print the lerning output once. Default: 1. 26 | 27 | Usage: 28 | 29 | .. 
code-block:: python 30 | 31 | assist = assistant(net, trainLoader, testLoader, lambda o, t, l: error.numSpikes(o, t), optimizer, stats) 32 | 33 | for epoch in range(maxEpoch): 34 | assist.train(epoch) 35 | assist.test(epoch) 36 | ''' 37 | def __init__(self, net, trainLoader, testLoader, error, optimizer, scheduler=None, stats=None, 38 | dataParallel=False, showTimeSteps=False, lossScale=None, printInterval=1): 39 | self.net = net 40 | self.module = net.module if dataParallel is True else net 41 | self.error = error 42 | self.device = self.module.slayer.srmKernel.device 43 | self.optimizer = optimizer 44 | self.scheduler = scheduler 45 | self.stats = stats 46 | self.showTimeSteps = showTimeSteps 47 | self.lossScale = lossScale 48 | self.printInterval = printInterval 49 | 50 | self.trainLoader = trainLoader 51 | self.testLoader = testLoader 52 | 53 | def train(self, epoch=0, breakIter = None, printLog=True): 54 | ''' 55 | Training assistant fucntion. 56 | 57 | Arguments: 58 | * ``epoch``: training epoch number. 59 | * ``breakIter``: number of samples to wait before breaking out of the training loop. 60 | ``None`` means go over the complete training samples. Default: ``None``. 61 | ''' 62 | tSt = datetime.now() 63 | for i, (input, target, label) in enumerate(self.trainLoader, 0): 64 | self.net.train() 65 | 66 | input = input.to(self.device) 67 | target = target.to(self.device) 68 | 69 | count = 0 70 | if self.module.countLog is True: 71 | output, count = self.net.forward(input) 72 | else: 73 | output = self.net.forward(input) 74 | 75 | if self.stats is not None: 76 | self.stats.training.correctSamples += torch.sum( predict.getClass(output) == label ).data.item() 77 | self.stats.training.numSamples += len(label) 78 | 79 | loss = self.error(output, target, label) 80 | if self.stats is not None: 81 | self.stats.training.lossSum += loss.cpu().data.item() * (1 if self.lossScale is None else self.lossScale) 82 | 83 | self.optimizer.zero_grad() 84 | loss.backward() 85 | self.optimizer.step() 86 | self.module.clamp() 87 | 88 | if self.stats is not None and i%self.printInterval == 0 and printLog is True: 89 | headerList = ['[{}/{} ({:.0f}%)]'.format(i*self.trainLoader.batch_size, len(self.trainLoader.dataset), 100.0*i/len(self.trainLoader))] 90 | if self.module.countLog is True: 91 | headerList.append('Spike count: ' + ', '.join(['{}'.format(int(c)) for c in torch.sum(count, dim=0).tolist()])) 92 | if self.showTimeSteps is True: 93 | headerList.append('nTimeBins: {}'.format(input.shape[-1])) 94 | 95 | self.stats.print( 96 | epoch, i, 97 | (datetime.now() - tSt).total_seconds() / (i+1) / self.trainLoader.batch_size, 98 | header= headerList, 99 | ) 100 | 101 | if breakIter is not None and i >= breakIter: 102 | break 103 | 104 | if self.scheduler is not None: 105 | self.scheduler.step() 106 | 107 | def test(self, epoch=0, evalLoss=True, slidingWindow=None, breakIter = None, printLog=True): 108 | ''' 109 | Testing assistant fucntion. 110 | 111 | Arguments: 112 | * ``epoch``: training epoch number. 113 | * ``evalLoss``: a flag to enable or disable loss evalutaion. Default: ``True``. 114 | * ``slidingWindow``: the length of sliding window to use for continuous output prediction over time. 115 | ``None`` means total spike count is used to produce one output per sample. If it is not 116 | ``None``, ``evalLoss`` is overwritten to ``False``. Default: ``None``. 117 | * ``breakIter``: number of samples to wait before breaking out of the testing loop. 118 | ``None`` means go over the complete training samples. 
Default: ``None``. 119 | ''' 120 | if slidingWindow is not None: 121 | filter = torch.ones((slidingWindow)).to(self.device) 122 | evalLoss = False 123 | 124 | tSt = datetime.now() 125 | for i, (input, target, label) in enumerate(self.testLoader, 0): 126 | self.net.eval() 127 | 128 | with torch.no_grad(): 129 | input = input.to(self.device) 130 | target = target.to(self.device) 131 | 132 | count = 0 133 | if self.module.countLog is True: 134 | output, count = self.net.forward(input) 135 | else: 136 | output = self.net.forward(input) 137 | 138 | if slidingWindow is None: 139 | if self.stats is not None: 140 | self.stats.testing.correctSamples += torch.sum( predict.getClass(output) == label ).data.item() 141 | self.stats.testing.numSamples += len(label) 142 | else: 143 | filteredOutput = slayerCuda.conv(output.contiguous(), filter, 1)[..., slidingWindow:] 144 | predictions = torch.argmax(filteredOutput.reshape(-1, filteredOutput.shape[-1]), dim=0) 145 | 146 | # print(output.shape, predictions.shape) 147 | # print(predictions[:100]) 148 | # print(label) 149 | # print(torch.sum(predictions == label).item()) 150 | # print(torch.sum(predictions == label).item() / predictions.shape[0]) 151 | 152 | # assert False, 'Just braking' 153 | 154 | if self.stats is not None: 155 | self.stats.testing.correctSamples += torch.sum(predictions == label.to(self.device)).item() 156 | self.stats.testing.numSamples += predictions.shape[0] 157 | 158 | if evalLoss is True: 159 | loss = self.error(output, target, label) 160 | if self.stats is not None: 161 | self.stats.testing.lossSum += loss.cpu().data.item() * (1 if self.lossScale is None else self.lossScale) 162 | else: 163 | if self.stats is not None: 164 | if slidingWindow is None: 165 | self.stats.testing.lossSum += (1 if self.lossScale is None else self.lossScale) 166 | else: 167 | self.stats.testing.lossSum += predictions.shape[0] * (1 if self.lossScale is None else self.lossScale) 168 | 169 | if self.stats is not None and i%self.printInterval == 0 and printLog is True: 170 | headerList = ['[{}/{} ({:.0f}%)]'.format(i*self.testLoader.batch_size, len(self.testLoader.dataset), 100.0*i/len(self.testLoader))] 171 | if self.module.countLog is True: 172 | headerList.append('Spike count: ' + ', '.join(['{}'.format(int(c)) for c in torch.sum(count, dim=0).tolist()])) 173 | if self.showTimeSteps is True: 174 | headerList.append('nTimeBins: {}'.format(input.shape[-1])) 175 | 176 | self.stats.print( 177 | epoch, i, 178 | (datetime.now() - tSt).total_seconds() / (i+1) / self.testLoader.batch_size, 179 | header= headerList, 180 | ) 181 | 182 | if breakIter is not None and i >= breakIter: 183 | break 184 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/auto/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import Dataset 4 | from .. import spikeFileIO as sio 5 | 6 | class SlayerDataset(Dataset): 7 | ''' 8 | This class wraps a basic dataset class to be used in SLAYER training. This allows the use 9 | of the same basic dataset definition on some other platform other than SLAYER, for e.g. for 10 | implementation in a neuromorphic hardware with its SDK. 11 | 12 | The basic dataset must return a numpy array of events where each row consists of an AER event 13 | represented by x, y, polarity and time (in ms). 14 | 15 | Arguments: 16 | * ``dataset``: basic dataset to be wrapped. 
17 | * ``network``: an ``auto`` module network with which the dataset is intended to be used with. 18 | The shape of the tensor is determined from the netowrk definition. 19 | * ``randomShift``: a flag to indicate if the sample must be randomly shifted in time over the 20 | entire sample length. Default: False 21 | * ``binningMode``: the way the overlapping events are binned. Supports ``SUM`` and ``OR`` binning. 22 | Default: ``OR`` 23 | * ``fullDataset``: a flag that indicates weather the full dataset is to be processed or not. 24 | If ``True``, full length of the events is loaded into tensor. This will cause problems with 25 | default batching, as the number of time bins will not match for all the samples in a minibatch. 26 | In this case, the dataloader's ``collate_fn`` must be custom defined or a batch size of 1 should 27 | be used. Default: ``False`` 28 | 29 | Usage: 30 | 31 | .. code-block:: python 32 | 33 | dataset = SlayerDataset(dataset, net) 34 | ''' 35 | # this expects np event and label from dataset 36 | # np event should have events ordered in x, y, p, t(ms) 37 | def __init__(self, dataset, network, randomShift=False, binningMode='OR', fullDataset=False): 38 | # fullDataset = True superseds randomShift, nTimeBins and tensorShape. It is expected to be run with batch size of 1 only 39 | super(SlayerDataset, self).__init__() 40 | self.dataset = dataset 41 | self.samplingTime = network.netParams['simulation']['Ts'] 42 | self.sampleLength = network.netParams['simulation']['tSample'] 43 | self.nTimeBins = int(self.sampleLength/self.samplingTime) 44 | self.inputShape = network.inputShape 45 | self.nOutput = network.nOutput 46 | self.tensorShape = (self.inputShape[2], self.inputShape[1], self.inputShape[0], self.nTimeBins) 47 | self.randomShift = randomShift 48 | self.binningMode = binningMode 49 | self.fullDataset = fullDataset 50 | 51 | def __getitem__(self, index): 52 | event, label = self.dataset[index] 53 | 54 | if self.fullDataset is False: 55 | inputSpikes = sio.event( 56 | event[:, 0], event[:, 1], event[:, 2], event[:, 3] 57 | ).toSpikeTensor( 58 | torch.zeros(self.tensorShape), 59 | samplingTime=self.samplingTime, 60 | randomShift=self.randomShift, 61 | binningMode=self.binningMode, 62 | ) 63 | else: 64 | nTimeBins = int(np.ceil(event[:, 3].max())) 65 | tensorShape = (self.inputShape[2], self.inputShape[1], self.inputShape[0], nTimeBins) 66 | inputSpikes = sio.event( 67 | event[:, 0], event[:, 1], event[:, 2], event[:, 3] 68 | ).toSpikeTensor( 69 | torch.zeros(tensorShape), 70 | samplingTime=self.samplingTime, 71 | randomShift=self.randomShift, 72 | binningMode=self.binningMode, 73 | ) 74 | 75 | desiredClass = torch.zeros((self.nOutput, 1, 1, 1)) 76 | desiredClass[label, ...] = 1 77 | 78 | return inputSpikes, desiredClass, label 79 | 80 | def __len__(self): 81 | return len(self.dataset) -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/auto/loihi.py: -------------------------------------------------------------------------------- 1 | # Sumit Bam Shrestha 09/28/2020 5pm 2 | # ================================= 3 | # This is a wrapper code that generates feedforward slayerSNN network from a 4 | # network config file (*.yaml). This will also include modules to output a 5 | # network description file (*.hdf5) OR SOME OTHER FORMAT (TO BE DECIDED) which 6 | # will be directly loadable in nxsdk (PERHAPS THIS NEEDS REMOVING LATER) 7 | # module to load the trained network in Loihi hardware. 
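# Illustrative usage (a sketch inferred from the class definitions below, not
# part of the original header; netParams is a slayerSNN yamlParams/dict):
#
#     net = Network(netParams)        # builds the layer blocks from netParams['layer']
#     output = net(inputSpikes)       # standard forward pass
#     net.clamp()                     # clamp axonal delays after each optimizer step
#     net.genModel('network.hdf5')    # export network structure + trained parameters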
8 | # 9 | # This module should be merged with slayerSNN.loihi later and served from 10 | # SLAYER-PyTorch module 11 | # It shall be accessible as slayerSNN.auto.loihi 12 | 13 | from .. import utils 14 | from ..slayerLoihi import spikeLayer as loihi 15 | 16 | from collections import _count_elements 17 | import torch 18 | from torch.utils.data import Dataset 19 | import matplotlib.pyplot as plt 20 | import numpy as np 21 | import re 22 | import h5py 23 | 24 | class denseBlock(torch.nn.Module): 25 | ''' 26 | This class creates a dense layer block with Loihi neuron. It groups the 27 | synaptic interaction, Loihi neuron response and the associated delays. 28 | 29 | Arguments: 30 | * ``slayer`` (``slayerLoihi.slayer``): pre-initialized slayer loihi module. 31 | * ``inFeatures``: number of input features. 32 | * ``outFeatures``: number of output features. 33 | * ``weightScale``: scale factor of the defaule initialized weights. Default: 100 34 | * ``preHoodFx``: a function that operates on weight before applying it. Could be used for quantization etc. 35 | * ``weightNorm``: a flag to indicate if weight normalization should be applied or not. Default: False 36 | * ``delay``: a flag to inidicate if axonal delays should be applied or not. Default: False 37 | * ``maxDelay``: maximum allowable delay. Default: 62 38 | * ``countLog``: a flag to indicate if a log of spike count should be maintained and passed around or not. 39 | Default: False 40 | 41 | Usage: 42 | 43 | .. code-block:: python 44 | 45 | blk = denseBlock(self.slayer, 512, 10) 46 | ''' 47 | def __init__(self, slayer, inFeatures, outFeatures, weightScale=100, 48 | preHookFx = lambda x: utils.quantize(x, step=2), weightNorm=False, 49 | delay=False, maxDelay=62, countLog=False): 50 | super(denseBlock, self).__init__() 51 | self.slayer = slayer 52 | self.weightNorm = weightNorm 53 | if weightNorm is True: 54 | self.weightOp = torch.nn.utils.weight_norm(slayer.dense(inFeatures, outFeatures, weightScale, preHookFx), name='weight') 55 | else: 56 | self.weightOp = slayer.dense(inFeatures, outFeatures, weightScale, preHookFx) 57 | self.delayOp = slayer.delay(outFeatures) if delay is True else None 58 | self.countLog = countLog 59 | self.gradLog = True 60 | self.maxDelay = maxDelay 61 | 62 | self.paramsDict = { 63 | 'inFeatures' : inFeatures, 64 | 'outFeatures' : outFeatures, 65 | } 66 | 67 | def forward(self, spike): 68 | spike = self.slayer.spikeLoihi(self.weightOp(spike)) 69 | spike = self.slayer.delayShift(spike, 1) 70 | if self.delayOp is not None: 71 | spike = self.delayOp(spike) 72 | 73 | if self.countLog is True: 74 | return spike, torch.sum(spike) 75 | else: 76 | return spike 77 | 78 | class convBlock(torch.nn.Module): 79 | ''' 80 | This class creates a conv layer block with Loihi neuron. It groups the 81 | synaptic interaction, Loihi neuron response and the associated delays. 82 | 83 | Arguments: 84 | * ``slayer`` (``slayerLoihi.slayer``): pre-initialized slayer loihi module. 85 | * ``inChannels``: number of input channels. 86 | * ``outChannels``: number of output channels. 87 | * ``kernelSize``: size of convolution kernel. 88 | * ``stride``: size of convolution stride. Default: 1 89 | * ``padding``: size of padding. Default: 0 90 | * ``dialtion``: size of convolution dilation. Default: 1 91 | * ``groups``: number of convolution groups. Default: 1 92 | * ``weightScale``: scale factor of the defaule initialized weights. Default: 100 93 | * ``preHoodFx``: a function that operates on weight before applying it. 
Could be used for quantization etc. 94 | Default: quantization in step of 2 (Mixed weight mode in Loihi) 95 | * ``weightNorm``: a flag to indicate if weight normalization should be applied or not. Default: False 96 | * ``delay``: a flag to inidicate if axonal delays should be applied or not. Default: False 97 | * ``maxDelay``: maximum allowable delay. Default: 62 98 | * ``countLog``: a flag to indicate if a log of spike count should be maintained and passed around or not. 99 | Default: False 100 | 101 | Usage: 102 | 103 | .. code-block:: python 104 | 105 | blk = convBlock(self.slayer, 16, 31, 3, padding=1) 106 | spike = blk(spike) 107 | ''' 108 | def __init__(self, slayer, inChannels, outChannels, kernelSize, stride=1, padding=0, dilation=1, groups=1, weightScale=100, 109 | preHookFx = lambda x: utils.quantize(x, step=2), weightNorm=False, 110 | delay=False, maxDelay=62, countLog=False): 111 | super(convBlock, self).__init__() 112 | self.slayer = slayer 113 | self.weightNorm = weightNorm 114 | if weightNorm is True: 115 | self.weightOp = torch.nn.utils.weight_norm( 116 | slayer.conv(inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale, preHookFx), 117 | name='weight', 118 | ) 119 | else: 120 | self.weightOp = slayer.conv(inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale, preHookFx) 121 | # only channel wise delay is supported for conv layer 122 | # for neuron wise delay, one will need to write a custom block as it would require the spatial dimension as well 123 | self.delayOp = slayer.delay(outChannels) if delay is True else None 124 | self.countLog = countLog 125 | self.gradLog = True 126 | self.maxDelay = maxDelay 127 | 128 | self.paramsDict = { 129 | 'inChannels' : inChannels, 130 | 'outChannels' : outChannels, 131 | 'kernelSize' : kernelSize, 132 | 'stride' : stride, 133 | 'padding' : padding, 134 | 'dilation' : dilation, 135 | 'groups' : groups, 136 | } 137 | 138 | def forward(self, spike): 139 | spike = self.slayer.spikeLoihi(self.weightOp(spike)) 140 | spike = self.slayer.delayShift(spike, 1) 141 | if self.delayOp is not None: 142 | spike = self.delayOp(spike) 143 | 144 | if self.countLog is True: 145 | return spike, torch.sum(spike) 146 | else: 147 | return spike 148 | 149 | class poolBlock(torch.nn.Module): 150 | ''' 151 | This class creates a pool layer block with Loihi neuron. It groups the 152 | synaptic interaction, Loihi neuron response and the associated delays. 153 | 154 | Arguments: 155 | * ``slayer`` (``slayerLoihi.slayer``): pre-initialized slayer loihi module. 156 | * ``kernelSize``: size of pooling kernel. 157 | * ``stride``: size of pooling stride. Default: None(same as ``kernelSize``) 158 | * ``padding``: size of padding. Default: 0 159 | * ``dialtion``: size of convolution dilation. Default: 1 160 | * ``countLog``: a flag to indicate if a log of spike count should be maintained and passed around or not. 161 | Default: False 162 | 163 | Usage: 164 | 165 | .. 
code-block:: python 166 | 167 | blk = poolBlock(self.slayer, 2) 168 | spike = blk(spike) 169 | ''' 170 | def __init__(self, slayer, kernelSize, stride=None, padding=0, dilation=1, countLog=False): 171 | super(poolBlock, self).__init__() 172 | self.slayer = slayer 173 | self.weightOp = slayer.pool(kernelSize, stride, padding, dilation) 174 | self.countLog = countLog 175 | self.delayOp = None # it does not make sense to have axonal delays after pool block 176 | self.gradLog = False # no need to monitor gradients 177 | self.paramsDict = { 178 | 'kernelSize' : kernelSize, 179 | 'stride' : kernelSize if stride is None else stride, 180 | 'padding' : padding, 181 | 'dilation' : dilation, 182 | } 183 | 184 | def forward(self, spike): 185 | spike = self.slayer.spikeLoihi(self.weightOp(spike)) 186 | spike = self.slayer.delayShift(spike, 1) 187 | 188 | if self.countLog is True: 189 | return spike, None # return None for count. It does not make sense to count for pool layer 190 | else: 191 | return spike 192 | 193 | class flattenBlock(torch.nn.Module): 194 | ''' 195 | This class flattens the spatial dimension. The resulting tensor is compatible with dense layer. 196 | 197 | Arguments: 198 | * ``countLog``: a flag to indicate if a log of spike count should be maintained and passed around or not. 199 | Default: False 200 | 201 | Usage: 202 | 203 | .. code-block:: python 204 | 205 | blk = flattenBlock(self.slayer, True) 206 | spike = blk(spike) 207 | ''' 208 | def __init__(self, countLog=False): 209 | super(flattenBlock, self).__init__() 210 | self.delayOp = None 211 | self.weightOp = None 212 | self.gradLog = False 213 | self.countLog = countLog 214 | self.paramsDict = {} 215 | 216 | def forward(self, spike): 217 | if self.countLog is True: 218 | return spike.reshape((spike.shape[0], -1, 1, 1, spike.shape[-1])), None 219 | else: 220 | return spike.reshape((spike.shape[0], -1, 1, 1, spike.shape[-1])) 221 | 222 | class averageBlock(torch.nn.Module): 223 | ''' 224 | This class averages the spikes among n different output groups for population voting. 225 | 226 | Arguments: 227 | * ``nOutputs``: number of output groups (Equal to the number of ouptut classes). 228 | * ``countLog``: a flag to indicate if a log of spike count should be maintained and passed around or not. 229 | Default: False 230 | 231 | Usage: 232 | 233 | .. code-block:: python 234 | 235 | blk = averageBlock(self.slayer, nOutputs=10) 236 | spike = blk(spike) 237 | ''' 238 | def __init__(self, nOutputs, countLog=False): 239 | super(averageBlock, self).__init__() 240 | self.nOutputs = nOutputs 241 | self.delayOp = None 242 | self.weightOp = None 243 | self.gradLog = False 244 | self.countLog = countLog 245 | self.paramsDict = {} 246 | 247 | def forward(self, spike): 248 | N, _, _, _, T = spike.shape 249 | if self.countLog is True: 250 | return torch.mean(spike.reshape((N, self.nOutputs, -1, 1, T)), dim=2, keepdim=True), None 251 | else: 252 | return torch.mean(spike.reshape((N, self.nOutputs, -1, 1, T)), dim=2, keepdim=True) 253 | 254 | 255 | 256 | 257 | class Network(torch.nn.Module): 258 | ''' 259 | This class encapsulates the network creation from the networks described in netParams 260 | configuration. A netParams configuration is ``slayerSNN.slayerParams.yamlParams`` which 261 | can be initialized from a yaml config file or a dictionary. 
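An illustrative netParams layout (the top-level keys and ``layer`` entries follow how they are read in this file; the neuron values are placeholder assumptions, not a recommended configuration):

.. code-block:: yaml

    simulation:
        Ts: 1          # sampling time
        tSample: 350   # sample length
    neuron:
        type: LOIHI
        vThMant: 80
        vDecay: 128
        iDecay: 1024
        refDelay: 1
    layer:
        - dim: 12x1    # input layer: width x height (a channel count may be appended)
        - dim: 450     # dense hidden layer
          delay: True
        - dim: 27      # dense output layer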
262 | 263 | In addition to the standard network ``forward`` function, it also includes ``clamp`` function 264 | for clamping delays, ``gradFlow`` function for monitioring the gradient flow, and ``genModel`` 265 | function for exporting a hdf5 file which is a packs network specification and trained 266 | parameter into a single file that can be possibly used to generate the inference network 267 | specific to a hardware, with some support. 268 | 269 | Arguments: 270 | * ``nOutputs``: number of output groups (Equal to the number of ouptut classes). 271 | * ``countLog``: a flag to indicate if a log of spike count should be maintained and passed around or not. 272 | Default: False 273 | 274 | Usage: 275 | 276 | .. code-block:: python 277 | 278 | blk = averageBlock(self.slayer, nOutputs=10) 279 | spike = blk(spike) 280 | ''' 281 | def __init__(self, netParams, preHookFx=lambda x: utils.quantize(x, step=2), weightNorm=False, countLog=False): 282 | super(Network, self).__init__() 283 | 284 | self.netParams = netParams 285 | self.netParams.print('simulation') 286 | print('') 287 | self.netParams.print('neuron') 288 | print('') 289 | # TODO print netParams 290 | 291 | # initialize slayer 292 | slayer = loihi(netParams['neuron'], netParams['simulation']) 293 | self.slayer = slayer 294 | 295 | self.inputShape = None 296 | self.nOutput = None 297 | self.weightNorm = weightNorm 298 | self.preHookFx =preHookFx 299 | self.countLog = countLog 300 | self.layerDims = [] 301 | 302 | # parse the layer information 303 | self.blocks = self._parseLayers() 304 | 305 | # TODO pass through core usage estimator 306 | print('TODO core usage estimator') 307 | 308 | 309 | def _layerType(self, dim): 310 | if type(dim) is int: 311 | return 'dense' 312 | elif dim.find('c') != -1: 313 | return 'conv' 314 | elif dim.find('avg') != -1: 315 | return 'average' 316 | elif dim.find('a') != -1: 317 | return 'pool' 318 | elif dim.find('x') != -1: 319 | return 'input' 320 | else: 321 | raise Exception('Could not parse the layer description. 
Found {}'.format(dim)) 322 | # return [int(i) for i in re.findall(r'\d+', dim)] 323 | 324 | def _tableStr(self, typeStr='', width=None, height=None, channel=None, kernel=None, stride=None, 325 | padding=None, delay=False, numParams=None, header=False, footer=False): 326 | if header is True: 327 | return '|{:10s}|{:5s}|{:5s}|{:5s}|{:5s}|{:5s}|{:5s}|{:5s}|{:10s}|'.format( 328 | ' Type ', ' W ', ' H ', ' C ', ' ker ', ' str ', ' pad ', 'delay', ' params ') 329 | elif footer is True and numParams is not None: 330 | return '|{:10s} {:5s} {:5s} {:5s} {:5s} {:5s} {:5s} {:5s}|{:-10d}|'.format( 331 | 'Total', '', '', '', '', '', '', '', numParams) 332 | else: 333 | entry = '|' 334 | entry += '{:10s}|'.format(typeStr) 335 | entry += '{:-5d}|'.format(width) 336 | entry += '{:-5d}|'.format(height) 337 | entry += '{:-5d}|'.format(channel) 338 | entry += '{:-5d}|'.format(kernel) if kernel is not None else '{:5s}|'.format('') 339 | entry += '{:-5d}|'.format(stride) if stride is not None else '{:5s}|'.format('') 340 | entry += '{:-5d}|'.format(padding) if padding is not None else '{:5s}|'.format('') 341 | entry += '{:5s}|'.format(str(delay)) 342 | entry += '{:-10d}|'.format(numParams) if numParams is not None else '{:10s}|'.format('') 343 | 344 | return entry 345 | 346 | def _parseLayers(self): 347 | i = 0 348 | blocks = torch.nn.ModuleList() 349 | layerDim = [] # CHW 350 | is1Dconv = False 351 | 352 | print('\nNetwork Architecture:') 353 | # print('=====================') 354 | print(self._tableStr(header=True)) 355 | 356 | for layer in self.netParams['layer']: 357 | layerType = self._layerType(layer['dim']) 358 | # print(i, layerType) 359 | 360 | # if layer has neuron feild, then use the slayer initialized with it and self.netParams['simulation'] 361 | if 'neuron' in layer.keys(): 362 | print(layerType, 'using individual slayer') 363 | slayer = loihi(layer['neuron'], self.netParams['simulation']) 364 | else: 365 | slayer = self.slayer 366 | 367 | if i==0 and self.inputShape is None: 368 | if layerType == 'input': 369 | self.inputShape = tuple([int(numStr) for numStr in re.findall(r'\d+', layer['dim'])]) 370 | if len(self.inputShape) == 3: 371 | layerDim = list(self.inputShape)[::-1] 372 | elif len(self.inputShape) == 2: 373 | layerDim = [1, self.inputShape[1], self.inputShape[0]] 374 | else: 375 | raise Exception('Could not parse the input dimension. Got {}'.format(self.inputShape)) 376 | elif layerType == 'dense': 377 | self.inputShape = tuple([layer['dim']]) 378 | layerDim = [layer['dim'], 1, 1] 379 | else: 380 | raise Exception('Input dimension could not be determined! 
It should be the first entry in the' 381 | + "'layer' feild.") 382 | # print(self.inputShape) 383 | print(self._tableStr('Input', layerDim[2], layerDim[1], layerDim[0])) 384 | if layerDim[1] == 1: 385 | is1Dconv = True 386 | else: 387 | # print(i, layer['dim'], self._layerType(layer['dim'])) 388 | if layerType == 'conv': 389 | params = [int(i) for i in re.findall(r'\d+', layer['dim'])] 390 | inChannels = layerDim[0] 391 | outChannels = params[0] 392 | kernelSize = params[1] 393 | stride = layer['stride'] if 'stride' in layer.keys() else 1 394 | padding = layer['padding'] if 'padding' in layer.keys() else kernelSize//2 395 | dilation = layer['dilation'] if 'dilation' in layer.keys() else 1 396 | groups = layer['groups'] if 'groups' in layer.keys() else 1 397 | weightScale = layer['wScale'] if 'wScale' in layer.keys() else 100 398 | delay = layer['delay'] if 'delay' in layer.keys() else False 399 | maxDelay = layer['maxDelay'] if 'maxDelay' in layer.keys() else 62 400 | # print(i, inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale) 401 | 402 | if is1Dconv is False: 403 | blocks.append(convBlock(slayer, inChannels, outChannels, kernelSize, stride, padding, 404 | dilation, groups, weightScale, self.preHookFx, self.weightNorm, 405 | delay, maxDelay, self.countLog)) 406 | layerDim[0] = outChannels 407 | layerDim[1] = int(np.floor((layerDim[1] + 2*padding - dilation * (kernelSize - 1) - 1)/stride + 1)) 408 | layerDim[2] = int(np.floor((layerDim[2] + 2*padding - dilation * (kernelSize - 1) - 1)/stride + 1)) 409 | else: 410 | blocks.append(convBlock(slayer, inChannels, outChannels, [1, kernelSize], [1, stride], [0, padding], 411 | [1, dilation], groups, weightScale, self.preHookFx, self.weightNorm, 412 | delay, maxDelay, self.countLog)) 413 | layerDim[0] = outChannels 414 | layerDim[1] = 1 415 | layerDim[2] = int(np.floor((layerDim[2] + 2*padding - dilation * (kernelSize - 1) - 1)/stride + 1)) 416 | self.layerDims.append(layerDim.copy()) 417 | 418 | print(self._tableStr('Conv', layerDim[2], layerDim[1], layerDim[0], kernelSize, stride, padding, 419 | delay, sum(p.numel() for p in blocks[-1].parameters() if p.requires_grad))) 420 | elif layerType == 'pool': 421 | params = [int(i) for i in re.findall(r'\d+', layer['dim'])] 422 | # print(params[0]) 423 | 424 | blocks.append(poolBlock(slayer, params[0], countLog=self.countLog)) 425 | layerDim[1] = int(np.ceil(layerDim[1] / params[0])) 426 | layerDim[2] = int(np.ceil(layerDim[2] / params[0])) 427 | self.layerDims.append(layerDim.copy()) 428 | 429 | print(self._tableStr('Pool', layerDim[2], layerDim[1], layerDim[0], params[0])) 430 | elif layerType == 'dense': 431 | params = layer['dim'] 432 | # print(params) 433 | if layerDim[1] != 1 or layerDim[2] != 1: # needs flattening of layers 434 | blocks.append(flattenBlock(self.countLog )) 435 | layerDim[0] = layerDim[0] * layerDim[1] * layerDim[2] 436 | layerDim[1] = layerDim[2] = 1 437 | self.layerDims.append(layerDim.copy()) 438 | weightScale = layer['wScale'] if 'wScale' in layer.keys() else 100 439 | delay = layer['delay'] if 'delay' in layer.keys() else False 440 | maxDelay = layer['maxDelay'] if 'maxDelay' in layer.keys() else 62 441 | 442 | blocks.append(denseBlock(slayer, layerDim[0], params, weightScale, self.preHookFx, 443 | self.weightNorm, delay, maxDelay, self.countLog)) 444 | layerDim[0] = params 445 | layerDim[1] = layerDim[2] = 1 446 | self.layerDims.append(layerDim.copy()) 447 | 448 | print(self._tableStr('Dense', layerDim[2], layerDim[1], layerDim[0], 
delay=delay, 449 | numParams=sum(p.numel() for p in blocks[-1].parameters() if p.requires_grad))) 450 | elif layerType == 'average': 451 | params = [int(i) for i in re.findall(r'\d+', layer['dim'])] 452 | layerDim[0] = params[0] 453 | layerDim[1] = layerDim[2] = 1 454 | self.layerDims.append(layerDim.copy()) 455 | 456 | blocks.append(averageBlock(nOutputs=layerDim[0], countLog=self.countLog)) 457 | print(self._tableStr('Average', 1, 1, params[0])) 458 | 459 | i += 1 460 | self.nOutput = layerDim[0] * layerDim[1] * layerDim[2] 461 | print(self._tableStr(numParams=sum(p.numel() for p in blocks.parameters() if p.requires_grad), footer=True)) 462 | return blocks 463 | 464 | def forward(self, spike): 465 | ''' 466 | Forward operation of the network. 467 | 468 | Arguments: 469 | * ``spike``: Input spike tensor. 470 | 471 | Usage: 472 | 473 | .. code-block:: python 474 | 475 | net = Network(netParams) 476 | spikeOut = net.forward(spike) 477 | ''' 478 | count = [] 479 | 480 | for b in self.blocks: 481 | # print(b) 482 | # print(b.countLog) 483 | if self.countLog is True: 484 | spike, cnt = b(spike) 485 | if cnt is not None: 486 | count.append(cnt.item()) 487 | else: 488 | spike = b(spike) 489 | # print(spike.shape) 490 | 491 | if self.countLog is True: 492 | return spike, torch.tensor(count).reshape((1, -1)).to(spike.device) 493 | else: 494 | return spike 495 | 496 | def clamp(self): 497 | ''' 498 | Clamp routine for delay parameters after gradient step to ensure positive values and limit 499 | the maximum value. 500 | 501 | Usage: 502 | 503 | .. code-block:: python 504 | 505 | net = Network(netParams) 506 | net.clamp() 507 | ''' 508 | for d in self.blocks: 509 | if d.delayOp is not None: 510 | # d.delayOp.delay.data.clamp_(0, 62) 511 | d.delayOp.delay.data.clamp_(0, d.maxDelay) 512 | # print(d.maxDelay) 513 | # print() 514 | # print(d.delayOp.delay.shape) 515 | 516 | def gradFlow(self, path): 517 | ''' 518 | A method to monitor the flow of gradient across the layers. Use it to monitor exploding and 519 | vanishing gradients. ``scaleRho`` must be tweaked to ensure proper gradient flow. Usually 520 | monitoring it for the first few epochs is good enough. 521 | 522 | Usage: 523 | 524 | .. code-block:: python 525 | 526 | net = Network(netParams) 527 | net.gradFlow(path_to_save) 528 | ''' 529 | gradNorm = lambda x: torch.norm(x).item()/torch.numel(x) 530 | grad = [] 531 | 532 | for l in self.blocks: 533 | # print(l) 534 | if l.gradLog is True: 535 | if l.weightNorm is True: 536 | grad.append(gradNorm(l.weightOp.weight_g.grad)) 537 | else: 538 | grad.append(gradNorm(l.weightOp.weight.grad)) 539 | 540 | plt.figure() 541 | plt.semilogy(grad) 542 | plt.savefig(path + 'gradFlow.png') 543 | plt.close() 544 | 545 | def genModel(self, fname): 546 | ''' 547 | This function exports an hdf5 file encapsulating the neuron parameters, network structure, and the weight 548 | and delay parameters of the trained network. This is intended to be a platform-independent 549 | representation of the network. The basic protocol of the file is as follows: 550 | 551 | .. code-block:: 552 | 553 | |->simulation # simulation description 554 | | |->Ts # sampling time. 
Usually 1 555 | | |->tSample # length of the sample to run 556 | |->layer # description of network layer blocks such as input, dense, conv, pool, flatten, average 557 | |->0 558 | | |->{shape, type, ...} # each layer description has at least shape and type attributes 559 | |->1 560 | | |->{shape, type, ...} 561 | : 562 | |->n 563 | |->{shape, type, ...} 564 | 565 | input : {shape, type} 566 | flatten: {shape, type} 567 | average: {shape, type} 568 | dense : {shape, type, neuron, inFeatures, outFeatures, weight, delay(if available)} 569 | pool : {shape, type, neuron, kernelSize, stride, padding, dilation, weight} 570 | conv : {shape, type, neuron, inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weight, delay(if available)} 571 | |-> this is the description of the compartment parameters 572 | |-> {iDecay, vDecay, vThMant, refDelay, ... (other additional parameters can exist)} 573 | 574 | Usage: 575 | 576 | .. code-block:: python 577 | 578 | net = Network(netParams) 579 | net.genModel(path_to_save) 580 | ''' 581 | qWeights = lambda x: self.preHookFx(x).cpu().data.numpy().squeeze() 582 | qDelays = lambda d: torch.floor(d).flatten().cpu().data.numpy().squeeze() 583 | 584 | h = h5py.File(fname, 'w') 585 | 586 | simulation = h.create_group('simulation') 587 | 588 | for key, value in self.netParams['simulation'].items(): 589 | # print(key, value) 590 | simulation[key] = value 591 | 592 | layer = h.create_group('layer') 593 | layer.create_dataset('0/type', (1, ), 'S10', [b'input']) 594 | layer.create_dataset('0/shape', data=np.array([self.inputShape[2], self.inputShape[1], self.inputShape[0]])) 595 | for i, block in enumerate(self.blocks): 596 | # print(block.__class__.__name__, self.layerDims[i]) 597 | # find the layerType from the block name. Exclude last 5 characters: Block 598 | layerType = block.__class__.__name__[:-5] 599 | # print(layerType.encode('ascii', 'ignore')) 600 | layer.create_dataset('{}/type'.format(i+1), (1, ), 'S10', [layerType.encode('ascii', 'ignore')]) 601 | # print(i, self.layerDims[i]) 602 | layer.create_dataset('{}/shape'.format(i+1), data=np.array(self.layerDims[i])) 603 | 604 | if block.weightOp is not None: 605 | if self.weightNorm is True and layerType != 'pool': 606 | torch.nn.utils.remove_weight_norm(block.weightOp, name='weight') 607 | layer.create_dataset('{}/weight'.format(i+1), data=qWeights(block.weightOp.weight)) 608 | 609 | if block.delayOp is not None: 610 | layer.create_dataset('{}/delay'.format(i+1), data=qDelays(block.delayOp.delay)) 611 | 612 | for key, param in block.paramsDict.items(): 613 | layer.create_dataset('{}/{}'.format(i+1, key), data=param) 614 | if layerType != 'flatten' and layerType != 'average': 615 | for key, value in block.slayer.neuron.items(): 616 | # print(i, key, value) 617 | layer.create_dataset('{}/neuron/{}'.format(i+1, key), data=value) 618 | 619 | h.close() 620 | 621 | 622 | def loadModel(self, fname): 623 | ''' 624 | This function loads the network from a previously saved hdf5 file using ``genModel``. 625 | 626 | Usage: 627 | 628 | .. code-block:: python 629 | 630 | net = Network(netParams) 631 | net.loadModel(path_of_model) 632 | ''' 633 | # only the layer weights and delays shall be loaded 634 | h = h5py.File(fname, 'r') 635 | 636 | # one more layer for input layer in the hdf5 file 637 | assert len(h['layer']) == len(self.blocks) + 1, 'The number of layers in the network does not match the number of layers in the file {}. 
Expected {}, found {}'.format(fname, len(self.blocks) + 1, len(h['layer'])) 638 | 639 | for i, block in enumerate(self.blocks): 640 | idxKey = '{}'.format(i+1) 641 | blockTypeStr = block.__class__.__name__[:-5] 642 | layerTypeStr = h['layer'][idxKey]['type'][()][0].decode('utf-8') 643 | assert layerTypeStr == blockTypeStr, 'The layer typestrings do not match. Found {} in network and {} in file.'.format(blockTypeStr, layerTypeStr) 644 | 645 | if block.weightOp is not None: 646 | if self.weightNorm is True and layerTypeStr != 'pool': 647 | torch.nn.utils.remove_weight_norm(block.weightOp, name='weight') 648 | block.weightOp.weight.data = torch.FloatTensor(h['layer'][idxKey]['weight'][()]).reshape(block.weightOp.weight.shape).to(block.weightOp.weight.device) 649 | if self.weightNorm is True and layerTypeStr != 'pool': 650 | block.weightOp = torch.nn.utils.weight_norm(block.weightOp, name='weight') 651 | 652 | if block.delayOp is not None: 653 | block.delayOp.delay.data = torch.FloatTensor(h['layer'][idxKey]['delay'][()]).reshape(block.delayOp.delay.shape).to(block.delayOp.delay.device) 654 | 655 | 656 | 657 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/cuda/convKernels.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Author: Sumit Bam Shrestha 3 | * 09/05/2019 6:00 PM 4 | * This header contains routines to perform time based convolution and correlation of signal 5 | * These operations are key in forward propagation and backpropagation routines in SLAYER 6 | */ 7 | #ifndef CONVKERNELS_H_INCLUDED 8 | #define CONVKERNELS_H_INCLUDED 9 | 10 | template <class T> 11 | __global__ void convKernel( T* output, 12 | const T* input, const T* filter, 13 | unsigned signalSize, unsigned filterSize, unsigned nNeurons, 14 | float Ts) 15 | { 16 | // calculate the threadID 17 | // this is the index of the signal along time axis 18 | unsigned tID = blockIdx.x * blockDim.x + threadIdx.x; 19 | unsigned nID = blockIdx.y * blockDim.y + threadIdx.y; 20 | 21 | if(tID >= signalSize) return; 22 | if(nID >= nNeurons) return; 23 | 24 | // declare local variables 25 | float result = 0.0f; 26 | 27 | // calculate convolution sum 28 | for(unsigned i=0; i<filterSize; ++i) 29 | { 30 | int id = tID - i; 31 | if(id >= 0) result += input[id + nID * signalSize] * filter[i]; 32 | } 33 | output[tID + nID * signalSize] = result * Ts; 34 | return; 35 | } 36 | 37 | template <class T> 38 | __global__ void corrKernel( T* output, 39 | const T* input, const T* filter, 40 | unsigned signalSize, unsigned filterSize, unsigned nNeurons, 41 | float Ts) 42 | { 43 | // calculate the threadID 44 | // this is the index of the signal along time axis 45 | unsigned tID = blockIdx.x * blockDim.x + threadIdx.x; 46 | unsigned nID = blockIdx.y * blockDim.y + threadIdx.y; 47 | 48 | if(tID >= signalSize) return; 49 | if(nID >= nNeurons) return; 50 | 51 | // declare local variables 52 | float result = 0.0f; 53 | 54 | // calculate correlation sum 55 | for(unsigned i=0; i<filterSize; ++i) 56 | { 57 | int id = tID + i; 58 | if(id < signalSize) result += input[id + nID * signalSize] * filter[i]; 59 | } 60 | output[tID + nID * signalSize] = result * Ts; 61 | return; 62 | } 63 | 64 | template <class T> 65 | void conv( T* output, 66 | const T* input, const T* filter, 67 | unsigned signalSize, unsigned filterSize, unsigned nNeurons, 68 | float Ts) 69 | { 70 | dim3 thread(128, 8, 1); 71 | 72 | // int nGrid = 128; 73 | int nGrid = ceil( 1.0f * nNeurons / thread.y / 65535); 74 | int neuronsPerGrid = ceil(1.0f * nNeurons / nGrid); 75 | 76 | for(auto i=0; i<nGrid; ++i) 77 | { 78 | int startOffset = i * neuronsPerGrid; 79 | int neuronsInGrid = (startOffset + neuronsPerGrid <= nNeurons) ? neuronsPerGrid : nNeurons - startOffset; 80 | 81 | if(neuronsInGrid <= 0) break; 82 | 83 | dim3 block( ceil(1.0f * signalSize / thread.x), 84 | ceil(1.0f * neuronsInGrid / thread.y), 85 | 1 ); 86 | 87 | // these should never be triggered 88 | if(block.y >= 65535) AT_ERROR("maximum blockDim.y exceeded."); 89 | if(block.z >= 65535) AT_ERROR("maximum blockDim.z exceeded."); 90 | 91 | convKernel<<< block, thread >>>( output + startOffset * signalSize, 92 | input + 
startOffset * signalSize, 93 | filter, signalSize, filterSize, 94 | neuronsInGrid, Ts); 95 | } 96 | } 97 | 98 | template 99 | void corr( T* output, 100 | const T* input, const T* filter, 101 | unsigned signalSize, unsigned filterSize, unsigned nNeurons, 102 | float Ts) 103 | { 104 | dim3 thread(128, 8, 1); 105 | 106 | // int nGrid = 128; 107 | int nGrid = ceil( 1.0f * nNeurons / thread.y / 65535 ); 108 | int neuronsPerGrid = ceil(1.0f * nNeurons / nGrid); 109 | for(auto i=0; i= 65535) AT_ERROR("maximum blockDim.y exceeded."); 122 | if(block.z >= 65535) AT_ERROR("maximum blockDim.z exceeded."); 123 | 124 | corrKernel<<< block, thread >>>( output + startOffset * signalSize, 125 | input + startOffset * signalSize, 126 | filter, signalSize, filterSize, 127 | neuronsInGrid, Ts); 128 | } 129 | } 130 | 131 | #endif // CONVKERNELS_H_INCLUDED 132 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/cuda/shiftKernels.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Author: Sumit Bam Shrestha 3 | * 10/06/2019 6:30 PM 4 | * This header contains routines to perform tensor shifts as defined by the shift parameter 5 | */ 6 | #ifndef SHIFTKERNELS_H_INCLUDED 7 | #define SHIFTKERNELS_H_INCLUDED 8 | 9 | template 10 | __global__ void shiftKernel(T* output, 11 | const T* input, 12 | const T shiftValue, 13 | unsigned signalSize, unsigned nNeurons, float Ts) 14 | { 15 | // calcualte the threadID 16 | // this is the index of the signal along time axis 17 | unsigned tID = blockIdx.x * blockDim.x + threadIdx.x; 18 | unsigned nID = blockIdx.y * blockDim.y + threadIdx.y; 19 | 20 | if(tID >= signalSize) return; 21 | if(nID >= nNeurons) return; 22 | 23 | // floor the shift to integer value 24 | int shiftBlocks = static_cast(shiftValue/Ts); 25 | 26 | float temp = 0; 27 | auto neuronOffset = signalSize * nID; 28 | // shift the elements 29 | int id = tID - shiftBlocks; 30 | if(id >= 0 && id 37 | __global__ void shiftKernel(T* output, 38 | const T* input, 39 | const T* shiftLUT, 40 | unsigned signalSize, unsigned nNeurons, float Ts) 41 | { 42 | // calcualte the threadID 43 | // this is the index of the signal along time axis 44 | unsigned tID = blockIdx.x * blockDim.x + threadIdx.x; 45 | unsigned nID = blockIdx.y * blockDim.y + threadIdx.y; 46 | 47 | if(tID >= signalSize) return; 48 | if(nID >= nNeurons) return; 49 | 50 | // floor the shift to integer value 51 | int shiftBlocks = static_cast(shiftLUT[nID]/Ts); 52 | 53 | float temp = 0; 54 | auto neuronOffset = signalSize * nID; 55 | // shift the elements 56 | int id = tID - shiftBlocks; 57 | if(id >= 0 && id 64 | void shift( T* output, 65 | const T* input, 66 | const T shiftValue, 67 | unsigned signalSize, unsigned nNeurons, float Ts) 68 | { 69 | dim3 thread(128, 8, 1); 70 | int nGrid = ceil( 1.0f * nNeurons / thread.y / 65535 ); 71 | int neuronsPerGrid = ceil(1.0f * nNeurons / nGrid); 72 | 73 | for(auto i=0; i= 65535) AT_ERROR("maximum blockDim.y exceeded."); 86 | if(block.z >= 65535) AT_ERROR("maximum blockDim.z exceeded."); 87 | 88 | // std::cout << "Thread: (" << thread.x << ", " << thread.y << ", " << thread.z << ")" << std::endl; 89 | // std::cout << "Block : (" << block.x << ", " << block.y << ", " << block.z << ")" << std::endl; 90 | 91 | shiftKernel<<< block, thread >>>(output + startOffset * signalSize, 92 | input + startOffset * signalSize, 93 | shiftValue, 94 | signalSize, neuronsInGrid, Ts); 95 | } 96 | 97 | // cudaDeviceSynchronize(); 98 | } 99 | 100 | 
template 101 | void shift( T* output, 102 | const T* input, 103 | const T* shiftLUT, 104 | unsigned signalSize, unsigned nNeurons, unsigned nBatch, float Ts) 105 | { 106 | dim3 thread(128, 8, 1); 107 | 108 | int nGrid = ceil( 1.0f * nNeurons / thread.y / 65535); 109 | int neuronsPerGrid = ceil(1.0f * nNeurons / nGrid); 110 | 111 | for(unsigned i=0; i= 65535) AT_ERROR("maximum blockDim.y exceeded."); 126 | if(block.z >= 65535) AT_ERROR("maximum blockDim.z exceeded."); 127 | 128 | shiftKernel<<< block, thread >>>(output + (i * nNeurons + startOffset) * signalSize, 129 | input + (i * nNeurons + startOffset) * signalSize, 130 | shiftLUT + startOffset, 131 | signalSize, neuronsInGrid, Ts); 132 | } 133 | } 134 | 135 | // cudaDeviceSynchronize(); 136 | } 137 | 138 | 139 | #endif // SHIFTKERNELS_H_INCLUDED -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/cuda/slayerKernels.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "spikeKernels.h" 4 | #include "convKernels.h" 5 | #include "shiftKernels.h" 6 | 7 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | #define CHECK_DEVICE(x, y) AT_ASSERTM(x.device().index() == y.device().index(), #x " and " #y " must be in same CUDA device") 11 | 12 | // C++ Python interface 13 | 14 | torch::Tensor getSpikesCuda( 15 | torch::Tensor d_u, 16 | const torch::Tensor& d_nu, 17 | const float theta, 18 | const float Ts) 19 | { 20 | CHECK_INPUT(d_u); 21 | CHECK_INPUT(d_nu); 22 | 23 | // check if tensor are in same device 24 | CHECK_DEVICE(d_u, d_nu); 25 | 26 | auto d_s = torch::empty_like(d_u); 27 | 28 | // TODO implement for different data types 29 | 30 | // set the current cuda device to wherever the tensor d_u resides 31 | cudaSetDevice(d_u.device().index()); 32 | 33 | unsigned nuSize = d_nu.size(-1); 34 | unsigned Ns = d_u.size(-1); 35 | unsigned nNeurons = d_u.size(0) * d_u.size(1) * d_u.size(2) * d_u.size(3); 36 | getSpikes(d_s.data(), d_u.data(), d_nu.data(), nNeurons, nuSize, Ns, theta, Ts); 37 | 38 | return d_s; 39 | } 40 | 41 | torch::Tensor convCuda(torch::Tensor input, torch::Tensor filter, float Ts) 42 | { 43 | CHECK_INPUT(input); 44 | CHECK_INPUT(filter); 45 | CHECK_DEVICE(input, filter); 46 | 47 | cudaSetDevice(input.device().index()); 48 | 49 | auto output = torch::empty_like(input); 50 | 51 | unsigned signalSize = input.size(-1); 52 | unsigned filterSize = filter.numel(); 53 | unsigned nNeurons = input.numel()/input.size(-1); 54 | conv(output.data(), input.data(), filter.data(), signalSize, filterSize, nNeurons, Ts); 55 | 56 | return output; 57 | } 58 | 59 | torch::Tensor corrCuda(torch::Tensor input, torch::Tensor filter, float Ts) 60 | { 61 | CHECK_INPUT(input); 62 | CHECK_INPUT(filter); 63 | CHECK_DEVICE(input, filter); 64 | 65 | cudaSetDevice(input.device().index()); 66 | 67 | auto output = torch::empty_like(input); 68 | 69 | unsigned signalSize = input.size(-1); 70 | unsigned filterSize = filter.numel(); 71 | unsigned nNeurons = input.numel()/input.size(-1); 72 | corr(output.data(), input.data(), filter.data(), signalSize, filterSize, nNeurons, Ts); 73 | 74 | return output; 75 | } 76 | 77 | torch::Tensor shiftCuda(torch::Tensor input, torch::Tensor shiftLUT, float Ts) 78 | { 79 | CHECK_INPUT(input); 80 | CHECK_INPUT(shiftLUT); 81 | 
CHECK_DEVICE(input, shiftLUT); 82 | 83 | cudaSetDevice(input.device().index()); 84 | 85 | auto output = torch::empty_like(input); 86 | 87 | if(shiftLUT.numel() == 1) 88 | { 89 | unsigned signalSize = input.size(-1); 90 | unsigned nNeurons = input.numel()/signalSize; 91 | 92 | float shiftValue = shiftLUT.item(); 93 | 94 | shift(output.data(), input.data(), shiftValue, signalSize, nNeurons, Ts); 95 | } 96 | else 97 | { 98 | unsigned signalSize = input.size(-1); 99 | unsigned nBatch = input.size(0); 100 | unsigned nNeurons = input.numel()/signalSize/nBatch; 101 | 102 | AT_ASSERTM(shiftLUT.numel() == nNeurons, "shift and number of neurons must be same"); 103 | 104 | shift(output.data(), input.data(), shiftLUT.data(), signalSize, nNeurons, nBatch, Ts); 105 | } 106 | 107 | return output; 108 | } 109 | 110 | torch::Tensor shift1Cuda(torch::Tensor input, torch::Tensor shiftLUT) 111 | { 112 | return shiftCuda(input, shiftLUT, 1.0f); 113 | } 114 | 115 | torch::Tensor shiftFlCuda(torch::Tensor input, float shiftLUT, float Ts) 116 | { 117 | CHECK_INPUT(input); 118 | 119 | cudaSetDevice(input.device().index()); 120 | 121 | auto output = torch::empty_like(input); 122 | 123 | unsigned signalSize = input.size(-1); 124 | unsigned nNeurons = input.numel()/signalSize; 125 | 126 | shift(output.data(), input.data(), shiftLUT, signalSize, nNeurons, Ts); 127 | 128 | return output; 129 | } 130 | 131 | torch::Tensor shiftFl1Cuda(torch::Tensor input, float shiftLUT) 132 | { 133 | return shiftFlCuda(input, shiftLUT, 1.0f); 134 | } 135 | 136 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 137 | { 138 | m.def("getSpikes", &getSpikesCuda, "Get spikes (CUDA)"); 139 | m.def("conv" , &convCuda , "Convolution in time (CUDA)"); 140 | m.def("corr" , &corrCuda , "Correlation in time (CUDA)"); 141 | m.def("shift" , &shiftCuda , "Element shift in time (CUDA)"); 142 | m.def("shift" , &shift1Cuda , "Element shift in time (CUDA)"); 143 | m.def("shift" , &shiftFlCuda , "Element shift in time (CUDA)"); 144 | m.def("shift" , &shiftFl1Cuda , "Element shift in time (CUDA)"); 145 | } 146 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/cuda/slayerLoihiKernels.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "spikeLoihiKernels.h" 4 | 5 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 6 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 7 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 8 | #define CHECK_DEVICE(x, y) AT_ASSERTM(x.device().index() == y.device().index(), #x " and " #y " must be in same CUDA device") 9 | 10 | // C++ Python interface 11 | 12 | std::vector getSpikesCuda( 13 | torch::Tensor weightedSpikes, 14 | // const unsigned weightScale, 15 | const unsigned wgtExp, 16 | const unsigned theta, 17 | const unsigned iDecay, 18 | const unsigned vDecay, 19 | const unsigned refDelay) 20 | { 21 | CHECK_INPUT(weightedSpikes); 22 | 23 | auto current = torch::empty_like(weightedSpikes); 24 | auto voltage = torch::empty_like(weightedSpikes); 25 | auto spike = torch::empty_like(weightedSpikes); 26 | 27 | // set the current cuda device to wherever the tensor d_u resides 28 | cudaSetDevice(weightedSpikes.device().index()); 29 | 30 | unsigned Ns = weightedSpikes.size(-1); 31 | unsigned nNeurons = weightedSpikes.numel()/weightedSpikes.size(-1); 32 | 33 | // std::cout << "Ns = " << Ns << std::endl 34 | // << "nNeurons = " << 
nNeurons << std::endl; 35 | // std::cout << "refDelay = " << refDelay << std::endl; 36 | 37 | getSpikes(spike.data(), 38 | voltage.data(), 39 | current.data(), 40 | weightedSpikes.data(), 41 | // weightScale, nNeurons, Ns, iDecay, vDecay, theta); 42 | wgtExp, nNeurons, Ns, iDecay, vDecay, refDelay, theta); 43 | 44 | // return {weightedSpikes, weightedSpikes, weightedSpikes}; 45 | return {spike, voltage, current}; 46 | } 47 | 48 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 49 | { 50 | m.def("getSpikes", &getSpikesCuda, "Get spikes for Loihi neuron (CUDA)"); 51 | } 52 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/cuda/spikeKernels.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Author: Sumit Bam Shrestha 3 | * 09/05/2019 4:00 PM 4 | * Contains routines that converts membrane potential of neuron into spikes 5 | */ 6 | #ifndef SPIKEKERNELS_H_INCLUDED 7 | #define SPIKEKERNELS_H_INCLUDED 8 | 9 | template 10 | __global__ void getSpikesKernel( 11 | T* __restrict__ d_s, 12 | T* __restrict__ d_u, 13 | const T* __restrict__ d_nu, 14 | unsigned nNeurons, unsigned nuSize, unsigned Ns, 15 | float theta, float Ts) 16 | { 17 | unsigned neuronID = blockIdx.x * blockDim.x + threadIdx.x; 18 | const T spike = 1.0f/Ts; 19 | 20 | if(neuronID >= nNeurons) return; 21 | 22 | for(unsigned i=0; i= theta) 26 | { 27 | d_s[linearID] = spike; 28 | // dynamic parallelism seems to be slower because of race condition!!! 29 | // ahpKernel<<< block, thread >>>(d_u + linearID, d_nu, nuSize); 30 | // cudaDeviceSynchronize(); 31 | for(unsigned j=0; j 42 | __global__ void evalRhoKernel(T* d_rho, const T* d_u, float theta, float tau, unsigned nNeurons, unsigned Ns, float scale) 43 | { 44 | unsigned timeID = blockIdx.x * blockDim.x + threadIdx.x; 45 | unsigned nID = blockIdx.y * blockDim.y + threadIdx.y; 46 | 47 | if(timeID >= Ns || nID >= nNeurons) return; 48 | 49 | unsigned linearID = timeID + nID * Ns; 50 | 51 | d_rho[linearID] = scale/tau * exp(-fabs(theta - d_u[linearID])/tau); 52 | } 53 | 54 | template 55 | void getSpikes(T* d_s, T* d_u, const T* d_nu, unsigned nNeurons, unsigned nuSize, unsigned Ns, float theta, float Ts) 56 | { 57 | unsigned thread = 256; 58 | unsigned block = ceil(1.0f * nNeurons / thread); 59 | getSpikesKernel<<< block, thread >>>(d_s, d_u, d_nu, nNeurons, nuSize, Ns, theta, Ts); 60 | } 61 | 62 | template 63 | void evalRho(T* d_rho, const T* d_u, float theta, float tauRho, float scaleRho, unsigned nNeurons, unsigned Ns) 64 | { 65 | dim3 thread, block; 66 | thread.x = 128; 67 | thread.y = 8; 68 | block.x = ceil(1.0f * Ns/thread.x); 69 | block.y = ceil(1.0f * nNeurons/thread.y); 70 | if(block.y >= 65535) AT_ERROR("maximum blockDim.y exceeded"); 71 | if(block.z >= 65535) AT_ERROR("maximum blockDim.z exceeded"); 72 | 73 | // slayerio::cout << "scaleRho = " << scaleRho << ", tauRho = " << tauRho << std::endl; 74 | 75 | // evalRhoKernel<<< block, thread >>>(rho, u, theta, tau, info.nNeurons, Ns); 76 | // evalRhoKernel<<< block, thread >>>(rho, u, theta, tau, info.nNeurons, Ns, 1.0/10); 77 | evalRhoKernel<<< block, thread >>>(d_rho, d_u, theta, tauRho * theta, nNeurons, Ns, scaleRho); 78 | } 79 | 80 | #endif // SPIKEKERNELS_H_INCLUDED -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/cuda/spikeLoihiKernels.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Author: Sumit Bam Shrestha 3 | * 
10/05/2019 11:00 AM 4 | * Contains routines that converts membrane potential of neuron into spikes 5 | */ 6 | #ifndef SPIKELOIHIKERNELS_H_INCLUDED 7 | #define SPIKELOIHIKERNELS_H_INCLUDED 8 | 9 | template 10 | __global__ void getSpikesKernel( 11 | T* __restrict__ s, 12 | T* __restrict__ v, 13 | T* __restrict__ u, 14 | const T* __restrict__ weightedSpikes, 15 | const unsigned weightScale, 16 | const unsigned nNeurons, 17 | const unsigned Ns, 18 | const unsigned iDecay, 19 | const unsigned vDecay, 20 | const unsigned refDelay, 21 | const int theta) // int because using unsigned value is giving errors when comparing the result with signed int 22 | { 23 | unsigned neuronID = blockIdx.x * blockDim.x + threadIdx.x; 24 | 25 | if(neuronID >= nNeurons) return; 26 | 27 | int uOld = 0; 28 | int vOld = 0; 29 | unsigned refState = 0; 30 | unsigned spike = 0; 31 | 32 | for(unsigned i=0; i= 0) ? 1 : -1 ; 45 | int vSign = (vOld >= 0) ? 1 : -1 ; 46 | 47 | int uTemp = uSign * ( ( uSign * uOld * ( (1<<12) - iDecay ) ) >> 12 ) + weightScale * int(weightedSpikes[linearID]); 48 | int vTemp = vSign * ( ( vSign * vOld * ( (1<<12) - vDecay ) ) >> 12 ) + uTemp; 49 | 50 | // s[linearID] = 0; 51 | // u[linearID] = uOld = uTemp; 52 | 53 | // if( vTemp > theta ) 54 | // { 55 | // s[linearID] = 1; 56 | // v[linearID] = vDecay; 57 | // vOld = 0; 58 | // } 59 | // else 60 | // v[linearID] = vOld = vTemp; 61 | 62 | if(i>=refDelay) refState -= unsigned(s[linearID-refDelay]); 63 | spike = (vTemp > theta) * (refState == 0); 64 | vOld = vTemp * (1 - spike) * (refState == 0); 65 | refState += spike; 66 | 67 | s[linearID] = spike; 68 | u[linearID] = uOld = uTemp; 69 | v[linearID] = vOld; 70 | v[linearID] = spike>0 ? int(vDecay) : vOld; 71 | } 72 | } 73 | 74 | 75 | template 76 | void getSpikes( 77 | T* __restrict__ s, 78 | T* __restrict__ v, 79 | T* __restrict__ u, 80 | const T* __restrict__ weightedSpikes, 81 | // const unsigned weightScale, 82 | const unsigned wgtExp, 83 | const unsigned nNeurons, 84 | const unsigned Ns, 85 | const unsigned iDecay, 86 | const unsigned vDecay, 87 | const unsigned refDelay, 88 | const unsigned theta) 89 | { 90 | unsigned thread = 256; 91 | unsigned block = ceil(1.0f * nNeurons / thread); 92 | // std::cout << "Ns : " << Ns << std::endl; 93 | // std::cout << "iDecay : " << iDecay << std::endl; 94 | // std::cout << "vDecay : " << vDecay << std::endl; 95 | getSpikesKernel<<< block, thread >>>(s, v, u, weightedSpikes, 1 << (6 + wgtExp), nNeurons, Ns, iDecay, vDecay, refDelay, theta); 96 | } 97 | 98 | #endif // SPIKELOIHIKERNELS_H_INCLUDED -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/learningStats.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | class learningStat(): 5 | ''' 6 | This class collect the learning statistics over the epoch. 7 | 8 | Usage: 9 | 10 | This class is designed to be used with learningStats instance although it can be used separately. 11 | 12 | >>> trainingStat = learningStat() 13 | ''' 14 | def __init__(self): 15 | self.lossSum = 0 16 | self.correctSamples = 0 17 | self.numSamples = 0 18 | self.minloss = None 19 | self.maxAccuracy = None 20 | self.lossLog = [] 21 | self.accuracyLog = [] 22 | self.bestLoss = False 23 | self.bestAccuracy = False 24 | 25 | def reset(self): 26 | ''' 27 | Reset the learning staistics. 
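As a small illustration of the accumulate/reset cycle this class implements, here is a minimal sketch (hypothetical batch numbers; only members defined in this file are used):

```python
stat = learningStat()
stat.reset()                          # clear accumulators at the start of an epoch
for batchLoss, correct, n in [(0.9, 40, 64), (0.7, 50, 64)]:
    stat.lossSum += batchLoss * n     # accumulate sample-weighted loss
    stat.correctSamples += correct    # accumulate correct classifications
    stat.numSamples += n              # accumulate processed samples
print(stat.loss(), stat.accuracy())   # running averages: 0.8 and ~0.703
stat.update()                         # log the epoch and track best loss/accuracy
```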
28 | This should usually be done before the start of an epoch so that new statistics counts can be accumulated. 29 | 30 | Usage: 31 | 32 | >>> trainingStat.reset() 33 | ''' 34 | self.lossSum = 0 35 | self.correctSamples = 0 36 | self.numSamples = 0 37 | 38 | def loss(self): 39 | ''' 40 | Returns the average loss calculated from the point the stats was reset. 41 | 42 | Usage: 43 | 44 | >>> loss = trainingStat.loss() 45 | ''' 46 | if self.numSamples > 0: 47 | return self.lossSum/self.numSamples 48 | else: 49 | return None 50 | 51 | def accuracy(self): 52 | ''' 53 | Returns the average accuracy calculated from the point the stats was reset. 54 | 55 | Usage: 56 | 57 | >>> accuracy = trainingStat.accuracy() 58 | ''' 59 | if self.numSamples > 0 and self.correctSamples > 0: 60 | return self.correctSamples/self.numSamples 61 | else: 62 | return None 63 | 64 | def update(self): 65 | ''' 66 | Updates the stats of the current session and resets the measures for next session. 67 | 68 | Usage: 69 | 70 | >>> trainingStat.update() 71 | ''' 72 | currentLoss = self.loss() 73 | self.lossLog.append(currentLoss) 74 | if self.minloss is None: 75 | self.minloss = currentLoss 76 | else: 77 | if currentLoss < self.minloss: 78 | self.minloss = currentLoss 79 | self.bestLoss = True 80 | else: 81 | self.bestLoss = False 82 | # self.minloss = self.minloss if self.minloss < currentLoss else currentLoss 83 | 84 | currentAccuracy = self.accuracy() 85 | self.accuracyLog.append(currentAccuracy) 86 | if self.maxAccuracy is None: 87 | self.maxAccuracy = currentAccuracy 88 | else: 89 | if currentAccuracy > self.maxAccuracy: 90 | self.maxAccuracy = currentAccuracy 91 | self.bestAccuracy = True 92 | else: 93 | self.bestAccuracy = False 94 | # self.maxAccuracy = self.maxAccuracy if self.maxAccuracy > currentAccuracy else currentAccuracy 95 | 96 | def displayString(self): 97 | loss = self.loss() 98 | accuracy = self.accuracy() 99 | minloss = self.minloss 100 | maxAccuracy = self.maxAccuracy 101 | 102 | if loss is None: # no stats available 103 | return None 104 | elif accuracy is None: 105 | if minloss is None: # accuracy and minloss stats is not available 106 | return 'loss = %-11.5g'%(loss) 107 | else: # accuracy is not available but minloss is available 108 | return 'loss = %-11.5g (min = %-11.5g)'%(loss, minloss) 109 | else: 110 | if minloss is None and maxAccuracy is None: # minloss and maxAccuracy is available 111 | return 'loss = %-11.5g %-11s accuracy = %-8.5g %-8s '%(loss, ' ', accuracy, ' ') 112 | else: # all stats are available 113 | return 'loss = %-11.5g (min = %-11.5g) accuracy = %-8.5g (max = %-8.5g)'%(loss, minloss, accuracy, maxAccuracy) 114 | 115 | class learningStats(): 116 | ''' 117 | This class provides mechanism to collect learning stats for training and testing, and displaying them efficiently. 118 | 119 | Usage: 120 | 121 | .. 
code-block:: python 122 | 123 | stats = learningStats() 124 | 125 | for epoch in range(100): 126 | tSt = datetime.now() 127 | 128 | stats.training.reset() 129 | for i in trainingLoop: 130 | # other main stuffs 131 | stats.training.correctSamples += numberOfCorrectClassification 132 | stats.training.numSamples += numberOfSamplesProcessed 133 | stats.training.lossSum += currentLoss 134 | stats.print(epoch, i, (datetime.now() - tSt).total_seconds()) 135 | stats.training.update() 136 | 137 | stats.testing.reset() 138 | for i in testingLoop 139 | # other main stuffs 140 | stats.testing.correctSamples += numberOfCorrectClassification 141 | stats.testing.numSamples += numberOfSamplesProcessed 142 | stats.testing.lossSum += currentLoss 143 | stats.print(epoch, i) 144 | stats.training.update() 145 | 146 | ''' 147 | def __init__(self): 148 | self.linesPrinted = 0 149 | self.training = learningStat() 150 | self.testing = learningStat() 151 | 152 | def update(self): 153 | ''' 154 | Updates the stats for training and testing and resets the measures for next session. 155 | 156 | Usage: 157 | 158 | >>> stats.update() 159 | ''' 160 | self.training.update() 161 | self.training.reset() 162 | self.testing.update() 163 | self.testing.reset() 164 | 165 | def print(self, epoch, iter=None, timeElapsed=None, header=None, footer=None): 166 | ''' 167 | Prints the available learning statistics from the current session on the console. 168 | For Linux systems, prints the data on same terminal space (might not work properly on other systems). 169 | 170 | Arguments: 171 | * ``epoch``: epoch counter to display (required). 172 | * ``iter``: iteration counter to display (not required). 173 | * ``timeElapsed``: runtime information (not required). 174 | * ``header``: things to be printed before printing learning statistics. Default: ``None``. 175 | * ``footer``: things to be printed after printing learning statistics. Default: ``None``. 176 | 177 | Usage: 178 | 179 | .. code-block:: python 180 | 181 | # prints stats with epoch index provided 182 | stats.print(epoch) 183 | 184 | # prints stats with epoch index and iteration index provided 185 | stats.print(epoch, iter=i) 186 | 187 | # prints stats with epoch index, iteration index and time elapsed information provided 188 | stats.print(epoch, iter=i, timeElapsed=time) 189 | ''' 190 | print('\033[%dA'%(self.linesPrinted)) 191 | 192 | self.linesPrinted = 1 193 | 194 | epochStr = 'Epoch : %10d'%(epoch) 195 | iterStr = '' if iter is None else '(i = %7d)'%(iter) 196 | profileStr = '' if timeElapsed is None else ', %12.4f ms elapsed'%(timeElapsed * 1000) 197 | 198 | if header is not None: 199 | for h in header: 200 | print('\033[2K'+str(h)) 201 | self.linesPrinted +=1 202 | 203 | print(epochStr + iterStr + profileStr) 204 | print(self.training.displayString()) 205 | self.linesPrinted += 2 206 | if self.testing.displayString() is not None: 207 | print(self.testing.displayString()) 208 | self.linesPrinted += 1 209 | 210 | if footer is not None: 211 | for f in footer: 212 | print('\033[2K'+str(f)) 213 | self.linesPrinted +=1 214 | 215 | 216 | def plot(self, figures=(1, 2), saveFig=False, path=''): 217 | ''' 218 | Plots the available learning statistics. 219 | 220 | Arguments: 221 | * ``figures``: Index of figure ID to plot on. Default is figure(1) for loss plot and figure(2) for accuracy plot. 222 | * ``saveFig``(``bool``): flag to save figure into a file. 223 | * ``path``: path to save the file. Defaule is ``''``. 224 | 225 | Usage: 226 | 227 | .. 
code-block:: python 228 | 229 | # plot stats 230 | stats.plot() 231 | 232 | # plot stats figures specified 233 | stats.print(figures=(10, 11)) 234 | ''' 235 | plt.figure(figures[0]) 236 | plt.cla() 237 | if len(self.training.lossLog) > 0: 238 | plt.semilogy(self.training.lossLog, label='Training') 239 | if len(self.testing.lossLog) > 0: 240 | plt.semilogy(self.testing .lossLog, label='Testing') 241 | plt.xlabel('Epoch') 242 | plt.ylabel('Loss') 243 | plt.legend() 244 | if saveFig is True: 245 | plt.savefig(path + 'loss.png') 246 | # plt.close() 247 | 248 | plt.figure(figures[1]) 249 | plt.cla() 250 | if len(self.training.accuracyLog) > 0: 251 | plt.plot(self.training.accuracyLog, label='Training') 252 | if len(self.testing.accuracyLog) > 0: 253 | plt.plot(self.testing .accuracyLog, label='Testing') 254 | plt.xlabel('Epoch') 255 | plt.ylabel('Accuracy') 256 | plt.legend() 257 | if saveFig is True: 258 | plt.savefig(path + 'accuracy.png') 259 | # plt.close() 260 | 261 | def save(self, filename=''): 262 | ''' 263 | Saves the learning satatistics logs. 264 | 265 | Arguments: 266 | * ``filename``: filename to save the logs. ``accuracy.txt`` and ``loss.txt`` will be appended. 267 | 268 | Usage: 269 | 270 | .. code-block:: python 271 | 272 | # save stats 273 | stats.save() 274 | 275 | # save stats filename specified 276 | stats.save(filename='Run101-0.001-') # Run101-0.001-accuracy.txt and Run101-0.001-loss.txt 277 | ''' 278 | 279 | with open(filename + 'loss.txt', 'wt') as loss: 280 | loss.write('#%11s %11s\r\n'%('Train', 'Test')) 281 | for i in range(len(self.training.lossLog)): 282 | loss.write('%12.6g %12.6g \r\n'%(self.training.lossLog[i], self.testing.lossLog[i])) 283 | 284 | with open(filename + 'accuracy.txt', 'wt') as accuracy: 285 | accuracy.write('#%11s %11s\r\n'%('Train', 'Test')) 286 | if self.training.accuracyLog != [None]*len(self.training.accuracyLog): 287 | for i in range(len(self.training.accuracyLog)): 288 | accuracy.write('%12.6g %12.6g \r\n'%( 289 | self.training.accuracyLog[i], 290 | self.testing.accuracyLog[i] if self.testing.accuracyLog[i] is not None else 0, 291 | )) 292 | 293 | def load(self, filename='', numEpoch=None, modulo=1): 294 | ''' 295 | Loads the learning statistics logs from saved files. 296 | 297 | Arguments: 298 | * ``filename``: filename to save the logs. ``accuracy.txt`` and ``loss.txt`` will be appended. 299 | * ``numEpoch``: number of epochs of logs to load. Default: None. ``numEpoch`` will be automatically determined from saved files. 300 | * ``modulo``: the gap in number of epoch before model was saved. 301 | 302 | Usage: 303 | 304 | .. 
code-block:: python 305 | 306 | # save stats 307 | stats.load(epoch=10) 308 | 309 | # save stats filename specified 310 | stats.save(filename='Run101-0.001-', epoch=50) # Run101-0.001-accuracy.txt and Run101-0.001-loss.txt 311 | ''' 312 | saved = {} 313 | saved['accuracy'] = np.loadtxt(filename + 'accuracy.txt') 314 | saved['loss'] = np.loadtxt(filename + 'loss.txt') 315 | if numEpoch is None: 316 | saved['epoch'] = saved['loss'].shape[0] // modulo * modulo + 1 317 | else: 318 | saved['epoch'] = numEpoch 319 | 320 | self.training.lossLog = saved['loss'][:saved['epoch'], 0].tolist() 321 | self.testing .lossLog = saved['loss'][:saved['epoch'], 1].tolist() 322 | self.training.minloss = saved['loss'][:saved['epoch'], 0].min() 323 | self.testing .minloss = saved['loss'][:saved['epoch'], 1].min() 324 | self.training.accuracyLog = saved['accuracy'][:saved['epoch'], 0].tolist() 325 | self.testing .accuracyLog = saved['accuracy'][:saved['epoch'], 1].tolist() 326 | self.training.maxAccuracy = saved['accuracy'][:saved['epoch'], 0].max() 327 | self.testing .maxAccuracy = saved['accuracy'][:saved['epoch'], 1].max() 328 | 329 | return saved['epoch'] 330 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/optimizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | # from .optimizer import Optimizer 4 | 5 | 6 | class Nadam(torch.optim.Optimizer): 7 | ''' 8 | Implements Nadam algorithm. (Modified Adam from PyTorch_) 9 | 10 | It has been proposed in `Incorporating Nesterov Momentum into Adam`_. 11 | 12 | Arguments: 13 | * ``params`` (iterable): iterable of parameters to optimize or dicts defining parameter groups. 14 | * ``lr`` (``float``, optional): learning rate (default: 1e-3). 15 | * ``betas`` (Tuple[``float``, ``float``], optional): coefficients used for computing 16 | running averages of gradient and its square (default: (0.9, 0.999)). 17 | * ``eps`` (``float``, optional): term added to the denominator to improve 18 | numerical stability (default: 1e-8). 19 | * ``weight_decay`` (``float``, optional): weight decay (L2 penalty) (default: 0). 20 | * ``amsgrad`` (``boolean``, optional): whether to use the AMSGrad variant of this 21 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 22 | (default: False). 23 | 24 | .. _PyTorch: 25 | https://pytorch.org/docs/stable/_modules/torch/optim/adam.html#Adam 26 | .. _Incorporating Nesterov Momentum into Adam: 27 | https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ 28 | .. 
_On the Convergence of Adam and Beyond: 29 | https://openreview.net/forum?id=ryQu7f-RZ 30 | ''' 31 | 32 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 33 | weight_decay=0, amsgrad=False): 34 | if not 0.0 <= lr: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if not 0.0 <= eps: 37 | raise ValueError("Invalid epsilon value: {}".format(eps)) 38 | if not 0.0 <= betas[0] < 1.0: 39 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 40 | if not 0.0 <= betas[1] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 42 | defaults = dict(lr=lr, betas=betas, eps=eps, 43 | weight_decay=weight_decay, amsgrad=amsgrad) 44 | super(Nadam, self).__init__(params, defaults) 45 | 46 | def __setstate__(self, state): 47 | super(Nadam, self).__setstate__(state) 48 | for group in self.param_groups: 49 | group.setdefault('amsgrad', False) 50 | 51 | def step(self, closure=None): 52 | ''' 53 | Performs a single optimization step. 54 | 55 | Arguments: 56 | * ``closure`` (callable, optional): A closure that reevaluates the model 57 | and returns the loss. 58 | ''' 59 | loss = None 60 | if closure is not None: 61 | loss = closure() 62 | 63 | for group in self.param_groups: 64 | for p in group['params']: 65 | if p.grad is None: 66 | continue 67 | grad = p.grad.data 68 | amsgrad = group['amsgrad'] 69 | 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state['step'] = 0 75 | # Exponential moving average of gradient values 76 | state['exp_avg'] = torch.zeros_like(p.data) 77 | # Exponential moving average of squared gradient values 78 | state['exp_avg_sq'] = torch.zeros_like(p.data) 79 | if amsgrad: 80 | # Maintains max of all exp. moving avg. of sq. grad. values 81 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 82 | 83 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 84 | if amsgrad: 85 | max_exp_avg_sq = state['max_exp_avg_sq'] 86 | beta1, beta2 = group['betas'] 87 | 88 | state['step'] += 1 89 | 90 | if group['weight_decay'] != 0: 91 | grad.add_(group['weight_decay'], p.data) 92 | 93 | # Decay the first and second moment running average coefficient 94 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 95 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 96 | if amsgrad: 97 | # Maintains the maximum of all 2nd moment running avg. till now 98 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 99 | # Use the max. for normalizing running avg. of gradient 100 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 101 | else: 102 | denom = exp_avg_sq.sqrt().add_(group['eps']) 103 | 104 | bias_correction1 = 1 - beta1 ** state['step'] 105 | bias_correction2 = 1 - beta2 ** state['step'] 106 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 107 | 108 | 109 | # only change is here 110 | # p.data.addcdiv_(-step_size, exp_avg, denom) 111 | 112 | p.data.addcdiv_(-step_size, beta1 * exp_avg + (1-beta1) * grad, denom) 113 | 114 | return loss -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/quantizeParams.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class quantizeWeights(torch.autograd.Function): 4 | ''' 5 | This class provides routine to quantize the weights during forward propagation pipeline. 6 | The backward propagation pipeline passes the gradient as it it, without any modification. 
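A tiny sketch of this straight-through behavior (illustrative values, using ``quantizeWeights`` as defined below):

```python
import torch

w = torch.tensor([0.3, 1.2, -0.8], requires_grad=True)
wq = quantizeWeights.apply(w, 0.5)  # forward: snap to the 0.5 grid -> [0.5, 1.0, -1.0]
wq.sum().backward()                 # backward: gradient passes through unmodified
print(w.grad)                       # tensor([1., 1., 1.])
```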
7 | 8 | Arguments; 9 | * ``weights``: full precision weight tensor. 10 | * ``step``: quantization step size. Default: 1 11 | 12 | Usage: 13 | 14 | >>> # Quantize weights in step of 0.5 15 | >>> stepWeights = quantizeWeights.apply(fullWeights, 0.5) 16 | ''' 17 | @staticmethod 18 | def forward(ctx, weights, step=1): 19 | ''' 20 | ''' 21 | # return weights 22 | # print('Weights qunatized with step', step) 23 | return torch.round(weights / step) * step 24 | 25 | @staticmethod 26 | def backward(ctx, gradOutput): 27 | ''' 28 | ''' 29 | return gradOutput, None 30 | 31 | def quantize(weights, step=1): 32 | ''' 33 | This function provides a wrapper around quantizeWeights. 34 | 35 | Arguments; 36 | * ``weights``: full precision weight tensor. 37 | * ``step``: quantization step size. Default: 1 38 | 39 | Usage: 40 | 41 | >>> # Quantize weights in step of 0.5 42 | >>> stepWeights = quantize(fullWeights, step=0.5) 43 | ''' 44 | return quantizeWeights.apply(weights, step) 45 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/slayerLoihi.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | 3 | CURRENT_SRC_DIR = os.path.dirname(os.path.realpath(__file__)) 4 | sys.path.append(CURRENT_SRC_DIR + "/../../slayerPyTorch/src") 5 | 6 | import math 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from . import slayer 12 | # import slayerCuda # LK: debug 13 | # import slayerLoihiCuda # LK: debug 14 | from .quantizeParams import quantizeWeights, quantize 15 | 16 | class spikeLayer(slayer.spikeLayer): 17 | ''' 18 | This class defines the main engine of SLAYER Loihi module. 19 | It is derived from ``slayer.spikeLayer`` with Loihi specific implementation for 20 | neuron model, weight quantization. 21 | All of the routines available for ``slayer.spikeLayer`` are applicable. 22 | 23 | Arguments: 24 | * ``neuronDesc`` (``slayerParams.yamlParams``): spiking neuron descriptor. 25 | .. code-block:: python 26 | 27 | neuron: 28 | type: LOIHI # neuron type 29 | vThMant: 80 # neuron threshold mantessa 30 | vDecay: 128 # compartment voltage decay 31 | iDecay: 1024 # compartment current decay 32 | refDelay: 1 # refractory delay 33 | wgtExp: 0 # weight exponent 34 | tauRho: 1 # spike function derivative time constant (relative to theta) 35 | scaleRho: 1 # spike function derivative scale factor 36 | * ``simulationDesc`` (``slayerParams.yamlParams``): simulation descriptor 37 | .. 
code-block:: python 38 | 39 | simulation: 40 | Ts: 1.0 # sampling time (ms) 41 | tSample: 300 # time length of sample (ms) 42 | 43 | Usage: 44 | 45 | >>> snnLayer = slayerLoihi.spikeLayer(neuronDesc, simulationDesc) 46 | ''' 47 | def __init__(self, neuronDesc, simulationDesc): 48 | if neuronDesc['type'] == 'LOIHI': 49 | neuronDesc['theta'] = neuronDesc['vThMant'] * 2**6 50 | 51 | super(spikeLayer, self).__init__(neuronDesc, simulationDesc) 52 | 53 | self.maxPspKernel = torch.max(self.srmKernel).cpu().data.item() 54 | print('Max PSP kernel:', self.maxPspKernel) 55 | print('Scaling neuron[scaleRho] by Max PSP Kernel @slayerLoihi') 56 | neuronDesc['scaleRho'] /= self.maxPspKernel 57 | 58 | def calculateSrmKernel(self): 59 | srmKernel = self._calculateLoihiPSP() 60 | return torch.tensor(srmKernel) 61 | 62 | def calculateRefKernel(self, SCALE=1000): 63 | refKernel = self._calculateLoihiRefKernel(SCALE) 64 | return torch.tensor(refKernel) 65 | 66 | def _calculateLoihiPSP(self): 67 | # u = [0] 68 | # v = [0] 69 | u = [] 70 | v = [] 71 | u.append( 1 << (6 + self.neuron['wgtExp'] + 1) ) # +1 to compensate for weight resolution of 2 for mixed synapse mode 72 | v.append( u[-1] ) # we do not consider bias in slayer 73 | while v[-1] > 0: 74 | uNext = ( ( u[-1] * ( (1<<12) - self.neuron['iDecay']) ) >> 12 ) 75 | vNext = ( ( v[-1] * ( (1<<12) - self.neuron['vDecay']) ) >> 12 ) + uNext # again, we do not consider bias in slayer 76 | u.append(uNext) 77 | v.append(vNext) 78 | 79 | return [float(x)/2 for x in v] # scale by half to compensate for 1 in the initial weight 80 | 81 | def _calculateLoihiRefKernel(self, SCALE=1000): 82 | absoluteRefKernel = np.ones(self.neuron['refDelay']) * (-SCALE * self.neuron['theta']) 83 | absoluteRefKernel[0] = 0 84 | relativeRefKernel = [ self.neuron['theta'] ] 85 | while relativeRefKernel[-1] > 0: 86 | nextRefKernel = ( relativeRefKernel[-1] * ( (1<<12) - self.neuron['vDecay']) ) >> 12 87 | relativeRefKernel.append(nextRefKernel) 88 | refKernel = np.concatenate( (absoluteRefKernel, -2 * np.array(relativeRefKernel) ) ).astype('float32') 89 | return refKernel 90 | 91 | def spikeLoihi(self, weightedSpikes): 92 | ''' 93 | Applies Loihi neuron dynamics to weighted spike inputs and returns output spike tensor. 94 | The output tensor dimension is same as input. 95 | 96 | NOTE: This function is different than the default ``spike`` function which takes membrane potential (weighted spikes with psp filter applied). 97 | Since the dynamics is modeled internally, it just takes in weightedSpikes (NOT FILTERED WITH PSP) for accurate Loihi neuron simulation. 98 | 99 | Arguments: 100 | * ``weightedSpikes``: input spikes weighted by their corresponding synaptic weights. 101 | 102 | Usage: 103 | 104 | >>> outSpike = snnLayer.spikeLoihi(weightedSpikes) 105 | ''' 106 | return _spike.apply(weightedSpikes, self.srmKernel, self.neuron, self.simulation['Ts']) 107 | 108 | def spikeLoihiFull(self, weightedSpikes): 109 | ''' 110 | Applies Loihi neuron dynamics to weighted spike inputs and returns output spike, voltage and current. 111 | The output tensor dimension is same as input. 112 | 113 | NOTE: This function does not have autograd routine in the computational graph. 114 | 115 | Arguments: 116 | * ``weightedSpikes``: input spikes weighted by their corresponding synaptic weights. 
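For a concrete picture of these dynamics, a minimal sketch (assuming the usual SLAYER tensor layout of batch x channel x height x width x time; the descriptor objects are as documented above) that drives one neuron with a single input spike and inspects its internal state:

```python
import torch

snnLayer = spikeLayer(neuronDesc, simulationDesc)  # descriptors as above
fc = snnLayer.dense(1, 1)                          # one input, one output neuron
inSpike = torch.zeros((1, 1, 1, 1, 300))           # 300 time steps
inSpike[..., 50] = 1                               # single input spike at t = 50
outSpike, outVoltage, outCurrent = snnLayer.spikeLoihiFull(fc(inSpike))
```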
117 | 118 | Usage: 119 | 120 | >>> outSpike, outVoltage, outCurrent = snnLayer.spikeLoihiFull(weightedSpikes) 121 | ''' 122 | return _spike.loihi(weightedSpikes, self.neuron, self.simulation['Ts']) 123 | 124 | def dense(self, inFeatures, outFeatures, weightScale=100, preHookFx=lambda x: quantize(x, step=2)): 125 | ''' 126 | This function behaves similarly to :meth:`slayer.spikeLayer.dense`. 127 | The only difference is that the weights are quantized with step of 2 (as is the case for signed weights in Loihi). 128 | One can, however, skip the quantization step altogether as well. 129 | 130 | Arguments: 131 | The arguments that are different from :meth:`slayer.spikeLayer.dense` are listed. 132 | 133 | * ``weightScale``: scale factor of default initialized weights. Default: 100 134 | * ``preHookFx``: a function that operates on weight before applying it. Could be used for quantization etc. Default: quantizes in step of 2. 135 | Usage: 136 | Same as :meth:`slayer.spikeLayer.dense` 137 | ''' 138 | # return _denseLayer(inFeatures, outFeatures, weightScale, quantize) 139 | return super(spikeLayer, self).dense(inFeatures, outFeatures, weightScale, preHookFx) 140 | 141 | def conv(self, inChannels, outChannels, kernelSize, stride=1, padding=0, dilation=1, groups=1, weightScale=100, preHookFx=lambda x: quantize(x, step=2)): 142 | ''' 143 | This function behaves similarly to :meth:`slayer.spikeLayer.conv`. 144 | The only difference is that the weights are quantized with step of 2 (as is the case for signed weights in Loihi). 145 | One can, however, skip the quantization step altogether as well. 146 | 147 | Arguments: 148 | The arguments that are different from :meth:`slayer.spikeLayer.conv` are listed. 149 | 150 | * ``weightScale``: scale factor of default initialized weights. Default: 100 151 | * ``preHookFx``: a function that operates on weight before applying it. Could be used for quantization etc. Default: quantizes in step of 2. 152 | Usage: 153 | Same as :meth:`slayer.spikeLayer.conv` 154 | ''' 155 | # return _convLayer(inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale, quantize) 156 | return super(spikeLayer, self).conv(inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale, preHookFx) 157 | 158 | def pool(self, kernelSize, stride=None, padding=0, dilation=1, preHookFx=None): 159 | ''' 160 | This function behaves similarly to :meth:`slayer.spikeLayer.pool`. 161 | The only difference is that the weights are quantized with step of 2 (as is the case for signed weights in Loihi). 162 | One can, however, skip the quantization step altogether as well. 163 | 164 | Arguments: 165 | The argument set is the same as :meth:`slayer.spikeLayer.pool`. 166 | 167 | Usage: 168 | Same as :meth:`slayer.spikeLayer.pool` 169 | ''' 170 | requiredWeight = quantizeWeights.apply(torch.tensor(1.1 * self.neuron['theta'] / self.maxPspKernel), 2).cpu().data.item() 171 | # print('Required pool layer weight =', requiredWeight) 172 | return slayer._poolLayer(requiredWeight/ 1.1, # to compensate for maxPsp 173 | kernelSize, stride, padding, dilation, preHookFx) 174 | 175 | def convTranspose(self, inChannels, outChannels, kernelSize, stride=1, padding=0, dilation=1, groups=1, weightScale=100, preHookFx=lambda x: quantize(x, step=2)): 176 | ''' 177 | This function behaves similarly to :meth:`slayer.spikeLayer.convTranspose`. 178 | The only difference is that the weights are quantized with step of 2 (as is the case for signed weights in Loihi). 
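To see what the ``requiredWeight`` computation in ``pool`` above does, a worked sketch with hypothetical numbers (``vThMant = 80`` so ``theta = 5120``, and a PSP kernel peaking at, say, 4096):

```python
theta, maxPspKernel = 80 * 2**6, 4096.0
raw = 1.1 * theta / maxPspKernel   # 1.375: threshold-crossing weight with 10% headroom
quantized = round(raw / 2) * 2     # 2: snapped to Loihi's signed-weight step of 2
poolWeight = quantized / 1.1       # ~1.818: the 1.1 headroom is divided back out
print(poolWeight * maxPspKernel)   # ~7447, comfortably above theta = 5120
```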
179 | One can, however, skip the quantization step altogether as well. 180 | 181 | Arguments: 182 | The arguments that are different from :meth:`slayer.spikeLayer.conv` are listed. 183 | 184 | * ``weightScale``: sale factor of default initialized weights. Default: 100 185 | * ``preHookFx``: a function that operates on weight before applying it. Could be used for quantization etc. Default: quantizes in step of 2. 186 | Usage: 187 | Same as :meth:`slayer.spikeLayer.convTranspose` 188 | ''' 189 | return super(spikeLayer, self).convTranspose(inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale, preHookFx) 190 | 191 | def unpool(self, kernelSize, stride=None, padding=0, dilation=1, preHookFx=None): 192 | ''' 193 | This function behaves similar to :meth:`slayer.spikeLayer.unpool`. 194 | The only difference is that the weights are qunatized with step of 2 (as is the case for signed weights in Loihi). 195 | One can, however, skip the quantization step altogether as well. 196 | 197 | Arguments: 198 | The arguments set is same as :meth:`slayer.spikeLayer.unpool`. 199 | 200 | Usage: 201 | Same as :meth:`slayer.spikeLayer.pool` 202 | ''' 203 | requiredWeight = quantizeWeights.apply(torch.tensor(1.1 * self.neuron['theta'] / self.maxPspKernel), 2).cpu().data.item() 204 | return slayer._unpoolLayer(requiredWeight/ 1.1, # to compensate for maxPsp 205 | kernelSize, stride, padding, dilation, preHookFx) 206 | 207 | def getVoltage(self, membranePotential): 208 | Ns = int(self.simulation['tSample'] / self.simulation['Ts']) 209 | voltage = membranePotential.reshape((-1, Ns)).cpu().data.numpy() 210 | return np.where(voltage <= -500*self.neuron['theta'], self.neuron['theta'] + 1, voltage) 211 | 212 | # class _denseLayer(slayer._denseLayer): 213 | # def __init__(self, inFeatures, outFeatures, weightScale=1, quantize=True): 214 | # self.quantize = quantize 215 | # super(_denseLayer, self).__init__(inFeatures, outFeatures, weightScale) 216 | 217 | # def forward(self, input): 218 | # if self.quantize is True: 219 | # return F.conv3d(input, 220 | # quantizeWeights.apply(self.weight, 2), self.bias, 221 | # self.stride, self.padding, self.dilation, self.groups) 222 | # else: 223 | # return F.conv3d(input, 224 | # self.weight, self.bias, 225 | # self.stride, self.padding, self.dilation, self.groups) 226 | 227 | # class _convLayer(slayer._convLayer): 228 | # def __init__(self, inFeatures, outFeatures, kernelSize, stride=1, padding=0, dilation=1, groups=1, weightScale=1, quantize=True): 229 | # self.quantize = quantize 230 | # super(_convLayer, self).__init__(inFeatures, outFeatures, kernelSize, stride, padding, dilation, groups, weightScale) 231 | 232 | # def forward(self, input): 233 | # if self.quantize is True: 234 | # return F.conv3d(input, 235 | # quantizeWeights.apply(self.weight, 2), self.bias, 236 | # self.stride, self.padding, self.dilation, self.groups) 237 | # else: 238 | # return F.conv3d(input, 239 | # self.weight, self.bias, 240 | # self.stride, self.padding, self.dilation, self.groups) 241 | 242 | 243 | class _spike(torch.autograd.Function): 244 | ''' 245 | ''' 246 | @staticmethod 247 | def loihi(weightedSpikes, neuron, Ts): 248 | iDecay = neuron['iDecay'] 249 | vDecay = neuron['vDecay'] 250 | theta = neuron['theta'] 251 | # wScale = 1 << (6 + neuron['wgtExp']) 252 | wgtExp = neuron['wgtExp'] 253 | refDelay = neuron['refDelay'] 254 | 255 | if weightedSpikes.dtype == torch.int32: 256 | Ts = 1 257 | 258 | spike, voltage, current = slayerLoihiCuda.getSpikes((weightedSpikes * 
Ts).contiguous(), wgtExp, theta, iDecay, vDecay, refDelay) 259 | 260 | return spike/Ts, voltage, current 261 | 262 | @staticmethod 263 | def forward(ctx, weightedSpikes, srmKernel, neuron, Ts): 264 | device = weightedSpikes.device 265 | dtype = weightedSpikes.dtype 266 | pdfScale = torch.autograd.Variable(torch.tensor(neuron['scaleRho'] , device=device, dtype=dtype), requires_grad=False) 267 | pdfTimeConstant = torch.autograd.Variable(torch.tensor(neuron['tauRho'] * neuron['theta'] , device=device, dtype=dtype), requires_grad=False) # needs to be scaled by theta 268 | threshold = torch.autograd.Variable(torch.tensor(neuron['theta'] , device=device, dtype=dtype), requires_grad=False) 269 | Ts = torch.autograd.Variable(torch.tensor(Ts, device=device, dtype=dtype), requires_grad=False) 270 | srmKernel = torch.autograd.Variable(srmKernel.clone().detach(), requires_grad=False) 271 | 272 | 273 | spike, voltage, current = _spike.loihi(weightedSpikes, neuron, Ts) 274 | 275 | ctx.save_for_backward(voltage, threshold, pdfTimeConstant, pdfScale, srmKernel, Ts) 276 | return spike 277 | 278 | @staticmethod 279 | def backward(ctx, gradOutput): 280 | (membranePotential, threshold, pdfTimeConstant, pdfScale, srmKernel, Ts) = ctx.saved_tensors 281 | spikePdf = pdfScale / pdfTimeConstant * torch.exp( -torch.abs(membranePotential - threshold) / pdfTimeConstant) 282 | 283 | return slayerCuda.corr(gradOutput * spikePdf, srmKernel, Ts), None, None, None 284 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/slayerParams.py: -------------------------------------------------------------------------------- 1 | from numpy.lib.arraysetops import isin 2 | import yaml 3 | 4 | # Consider dictionary for easier iteration and better scalability 5 | class yamlParams(object): 6 | ''' 7 | This class reads yaml parameter file and allows dictionary like access to the members. 8 | 9 | Usage: 10 | 11 | .. 
code-block:: python 12 | 13 | import slayerSNN as snn 14 | netParams = snn.params('path_to_yaml_file') # OR 15 | netParams = yamlParams('path_to_yaml_file') 16 | 17 | netParams['training']['learning']['etaW'] = 0.01 18 | print('Simulation step size ', netParams['simulation']['Ts']) 19 | print('Spiking neuron time constant', netParams['neuron']['tauSr']) 20 | print('Spiking neuron threshold ', netParams['neuron']['theta']) 21 | 22 | netParams.save('filename.yaml') 23 | ''' 24 | def __init__(self, parameter_file_path=None, dict=None): 25 | if dict is None: 26 | with open(parameter_file_path, 'r') as param_file: 27 | self.parameters = yaml.safe_load(param_file) 28 | else: 29 | self.parameters = dict 30 | 31 | # Allow dictionary like access 32 | def __getitem__(self, key): 33 | return self.parameters[key] 34 | 35 | def __setitem__(self, key, value): 36 | self.parameters[key] = value 37 | 38 | def save(self, filename): 39 | with open(filename, 'w') as f: 40 | yaml.dump(self.parameters, f) 41 | 42 | def print(self, key=None): 43 | if key is None: 44 | printConfig(self.parameters) 45 | else: 46 | print(key + ':') 47 | printConfig(self.parameters[key], pre=' ') 48 | 49 | def printConfig(obj, pre=''): 50 | if isinstance(obj, dict): 51 | for key, value in obj.items(): 52 | if isinstance(value, dict) or isinstance(value, list): 53 | print(pre + key + ' :') 54 | printConfig(value, pre=pre+' ') 55 | else: 56 | print(pre + '{:10s} : {}'.format(str(key), value)) 57 | elif isinstance(obj, list): 58 | for l in obj: 59 | printConfig(pre + '- {}'.format(l)) 60 | else: 61 | print(pre + '{}'.format(obj)) 62 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/spikeClassifier.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | class spikeClassifier: 5 | ''' 6 | It provides classification modules for SNNs. 7 | All the functions it supplies are static and can be called without making an instance of the class. 8 | ''' 9 | @staticmethod 10 | def getClass(spike): 11 | ''' 12 | Returns the predicted class label. 13 | It assigns a single class to the SNN output for the whole simulation runtime. 14 | 15 | Usage: 16 | 17 | >>> predictedClass = spikeClassifier.getClass(spikeOut) 18 | ''' 19 | numSpikes = torch.sum(spike, 4, keepdim=True).cpu() 20 | return torch.max(numSpikes.reshape((numSpikes.shape[0], -1)), 1)[1] 21 | # numSpikes = torch.sum(spike, 4, keepdim=True).cpu().data.numpy() 22 | # return np.argmax(numSpikes.reshape((numSpikes.shape[0], -1)), 1) -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/spikeFileIO.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib import pyplot as plt 3 | from matplotlib import animation 4 | from matplotlib import cm 5 | 6 | class event(): 7 | ''' 8 | This class provides a way to store, read, write and visualize spike events. 9 | 10 | Members: 11 | * ``x`` (numpy ``int`` array): `x` index of spike event. 12 | * ``y`` (numpy ``int`` array): `y` index of spike event (not used if the spatial dimension is 1). 13 | * ``p`` (numpy ``int`` array): `polarity` or `channel` index of spike event. 14 | * ``t`` (numpy ``double`` array): `timestamp` of spike event. Time is assumed to be in ms. 
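Non-integer ``x``, ``y`` and ``p`` arrays are cast to integers on construction, and polarities are shifted so that the smallest polarity value becomes 0.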
15 | 16 | Usage: 17 | 18 | >>> TD = spikeFileIO.event(xEvent, yEvent, pEvent, tEvent) 19 | ''' 20 | def __init__(self, xEvent, yEvent, pEvent, tEvent): 21 | if yEvent is None: 22 | self.dim = 1 23 | else: 24 | self.dim = 2 25 | 26 | self.x = xEvent if type(xEvent) is np.ndarray else np.asarray(xEvent) # x spatial dimension 27 | self.y = yEvent if type(yEvent) is np.ndarray else np.asarray(yEvent) # y spatial dimension 28 | self.p = pEvent if type(pEvent) is np.ndarray else np.asarray(pEvent) # spike polarity 29 | self.t = tEvent if type(tEvent) is np.ndarray else np.asarray(tEvent) # time stamp in ms 30 | 31 | if not issubclass(self.x.dtype.type, np.integer): self.x = self.x.astype('int') 32 | if not issubclass(self.p.dtype.type, np.integer): self.p = self.p.astype('int') 33 | 34 | if self.dim == 2: 35 | if not issubclass(self.y.dtype.type, np.integer): self.y = self.y.astype('int') 36 | 37 | self.p -= self.p.min() 38 | 39 | def toSpikeArray(self, samplingTime=1, dim=None): # Sampling time in ms 40 | ''' 41 | Returns a numpy tensor that contains the spike events sampled in bins of `samplingTime`. 42 | The array is of dimension (channels, height, time) or``CHT`` for 1D data. 43 | The array is of dimension (channels, height, width, time) or``CHWT`` for 2D data. 44 | 45 | Arguments: 46 | * ``samplingTime``: the width of time bin to use. 47 | * ``dim``: the dimension of the desired tensor. Assignes dimension itself if not provided. 48 | 49 | Usage: 50 | 51 | >>> spike = TD.toSpikeArray() 52 | ''' 53 | if self.dim == 1: 54 | if dim is None: dim = ( np.round(max(self.p)+1).astype(int), 55 | np.round(max(self.x)+1).astype(int), 56 | np.round(max(self.t)/samplingTime+1).astype(int) ) 57 | frame = np.zeros((dim[0], 1, dim[1], dim[2])) 58 | elif self.dim == 2: 59 | if dim is None: dim = ( np.round(max(self.p)+1).astype(int), 60 | np.round(max(self.y)+1).astype(int), 61 | np.round(max(self.x)+1).astype(int), 62 | np.round(max(self.t)/samplingTime+1).astype(int) ) 63 | frame = np.zeros((dim[0], dim[1], dim[2], dim[3])) 64 | return self.toSpikeTensor(frame, samplingTime).reshape(dim) 65 | 66 | def toSpikeTensor(self, emptyTensor, samplingTime=1, randomShift=False, binningMode='OR'): # Sampling time in ms 67 | ''' 68 | Returns a numpy tensor that contains the spike events sampled in bins of `samplingTime`. 69 | The tensor is of dimension (channels, height, width, time) or``CHWT``. 70 | 71 | Arguments: 72 | * ``emptyTensor`` (``numpy or torch tensor``): an empty tensor to hold spike data 73 | * ``samplingTime``: the width of time bin to use. 74 | * ``randomShift``: flag to shift the sample in time or not. Default: False. 75 | * ``binningMode``: the way spikes are binned. 'SUM' or 'OR' are supported. 
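With 'OR', all events that fall into the same bin collapse into a single event, while 'SUM' accumulates them.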
Default: 'OR' 76 | 77 | Usage: 78 | 79 | >>> spike = TD.toSpikeTensor( torch.zeros((2, 240, 180, 5000)) ) 80 | ''' 81 | 82 | if randomShift is True: 83 | tSt = np.random.randint( 84 | max( 85 | int(self.t.min() / samplingTime), 86 | int(self.t.max() / samplingTime) - emptyTensor.shape[3], 87 | emptyTensor.shape[3] - int(self.t.max() / samplingTime), 88 | 1, 89 | ) 90 | ) 91 | else: 92 | tSt = 0 93 | 94 | xEvent = np.round(self.x).astype(int) 95 | pEvent = np.round(self.p).astype(int) 96 | tEvent = np.round(self.t/samplingTime).astype(int) - tSt 97 | 98 | # print('shifted sequence by', tSt) 99 | 100 | if self.dim == 1: 101 | validInd = np.argwhere((xEvent < emptyTensor.shape[2]) & 102 | (pEvent < emptyTensor.shape[0]) & 103 | (tEvent < emptyTensor.shape[3]) & 104 | (xEvent >= 0) & 105 | (pEvent >= 0) & 106 | (tEvent >= 0)) 107 | if binningMode.upper() == 'OR': 108 | emptyTensor[pEvent[validInd], 109 | 0, 110 | xEvent[validInd], 111 | tEvent[validInd]] = 1/samplingTime 112 | elif binningMode.upper() == 'SUM': 113 | emptyTensor[pEvent[validInd], 114 | 0, 115 | xEvent[validInd], 116 | tEvent[validInd]] += 1/samplingTime 117 | else: 118 | raise Exception('Unsupported binningMode. It was {}'.format(binningMode)) 119 | 120 | elif self.dim == 2: 121 | yEvent = np.round(self.y).astype(int) 122 | validInd = np.argwhere((xEvent < emptyTensor.shape[2]) & 123 | (yEvent < emptyTensor.shape[1]) & 124 | (pEvent < emptyTensor.shape[0]) & 125 | (tEvent < emptyTensor.shape[3]) & 126 | (xEvent >= 0) & 127 | (yEvent >= 0) & 128 | (pEvent >= 0) & 129 | (tEvent >= 0)) 130 | 131 | if binningMode.upper() == 'OR': 132 | emptyTensor[pEvent[validInd], 133 | yEvent[validInd], 134 | xEvent[validInd], 135 | tEvent[validInd]] = 1/samplingTime 136 | elif binningMode.upper() == 'SUM': 137 | emptyTensor[pEvent[validInd], 138 | yEvent[validInd], 139 | xEvent[validInd], 140 | tEvent[validInd]] += 1/samplingTime 141 | else: 142 | raise Exception('Unsupported binningMode. It was {}'.format(binningMode)) 143 | 144 | return emptyTensor 145 | 146 | def spikeArrayToEvent(spikeMat, samplingTime=1): 147 | ''' 148 | Returns TD event from a numpy array (of dimension 3 or 4). 149 | The numpy array must be of dimension (channels, height, time) or``CHT`` for 1D data. 150 | The numpy array must be of dimension (channels, height, width, time) or``CHWT`` for 2D data. 151 | 152 | Arguments: 153 | * ``spikeMat``: numpy array with spike information. 154 | * ``samplingTime``: time width of each time bin. 155 | 156 | Usage: 157 | 158 | >>> TD = spikeFileIO.spikeArrayToEvent(spike) 159 | ''' 160 | if spikeMat.ndim == 3: 161 | spikeEvent = np.argwhere(spikeMat > 0) 162 | xEvent = spikeEvent[:,1] 163 | yEvent = None 164 | pEvent = spikeEvent[:,0] 165 | tEvent = spikeEvent[:,2] 166 | elif spikeMat.ndim == 4: 167 | spikeEvent = np.argwhere(spikeMat > 0) 168 | xEvent = spikeEvent[:,2] 169 | yEvent = spikeEvent[:,1] 170 | pEvent = spikeEvent[:,0] 171 | tEvent = spikeEvent[:,3] 172 | else: 173 | raise Exception('Expected numpy array of 3 or 4 dimension. It was {}'.format(spikeMat.ndim)) 174 | 175 | return event(xEvent, yEvent, pEvent, tEvent * samplingTime) 176 | 177 | def read1Dspikes(filename): 178 | ''' 179 | Reads one dimensional binary spike file and returns a TD event. 180 | 181 | The binary file is encoded as follows: 182 | * Each spike event is represented by a 40 bit number. 183 | * First 16 bits (bits 39-24) represent the neuronID. 184 | * Bit 23 represents the sign of spike event: 0=>OFF event, 1=>ON event. 
185 | * The last 23 bits (bits 22-0) represent the spike event timestamp in microseconds. 186 | 187 | Arguments: 188 | * ``filename`` (``string``): path to the binary file. 189 | 190 | Usage: 191 | 192 | >>> TD = spikeFileIO.read1Dspikes(file_path) 193 | ''' 194 | with open(filename, 'rb') as inputFile: 195 | inputByteArray = inputFile.read() 196 | inputAsInt = np.asarray([x for x in inputByteArray]) 197 | xEvent = (inputAsInt[0::5] << 8) | inputAsInt[1::5] 198 | pEvent = inputAsInt[2::5] >> 7 199 | tEvent =( (inputAsInt[2::5] << 16) | (inputAsInt[3::5] << 8) | (inputAsInt[4::5]) ) & 0x7FFFFF 200 | return event(xEvent, None, pEvent, tEvent/1000) # convert spike times to ms 201 | 202 | def encode1Dspikes(filename, TD): 203 | ''' 204 | Writes one dimensional binary spike file from a TD event. 205 | 206 | The binary file is encoded as follows: 207 | * Each spike event is represented by a 40 bit number. 208 | * First 16 bits (bits 39-24) represent the neuronID. 209 | * Bit 23 represents the sign of spike event: 0=>OFF event, 1=>ON event. 210 | * The last 23 bits (bits 22-0) represent the spike event timestamp in microseconds. 211 | 212 | Arguments: 213 | * ``filename`` (``string``): path to the binary file. 214 | * ``TD`` (a ``spikeFileIO.event``): TD event. 215 | 216 | Usage: 217 | 218 | >>> spikeFileIO.encode1Dspikes(file_path, TD) 219 | ''' 220 | if TD.dim != 1: raise Exception('Expected TD dimension to be 1. It was: {}'.format(TD.dim)) 221 | xEvent = np.round(TD.x).astype(int) 222 | pEvent = np.round(TD.p).astype(int) 223 | tEvent = np.round(TD.t * 1000).astype(int) # encode spike time in us 224 | outputByteArray = bytearray(len(tEvent) * 5) 225 | outputByteArray[0::5] = np.uint8( (xEvent >> 8) & 0xFF ).tobytes() 226 | outputByteArray[1::5] = np.uint8( (xEvent & 0xFF) ).tobytes() 227 | outputByteArray[2::5] = np.uint8(((tEvent >> 16) & 0x7F) | (pEvent.astype(int) << 7) ).tobytes() 228 | outputByteArray[3::5] = np.uint8( (tEvent >> 8 ) & 0xFF ).tobytes() 229 | outputByteArray[4::5] = np.uint8( tEvent & 0xFF ).tobytes() 230 | with open(filename, 'wb') as outputFile: 231 | outputFile.write(outputByteArray) 232 | 233 | def read2Dspikes(filename): 234 | ''' 235 | Reads two dimensional binary spike file and returns a TD event. 236 | It is the same format used in neuromorphic datasets NMNIST & NCALTECH101. 237 | 238 | The binary file is encoded as follows: 239 | * Each spike event is represented by a 40 bit number. 240 | * First 8 bits (bits 39-32) represent the xID of the neuron. 241 | * Next 8 bits (bits 31-24) represent the yID of the neuron. 242 | * Bit 23 represents the sign of spike event: 0=>OFF event, 1=>ON event. 243 | * The last 23 bits (bits 22-0) represent the spike event timestamp in microseconds. 244 | 245 | Arguments: 246 | * ``filename`` (``string``): path to the binary file. 247 | 248 | Usage: 249 | 250 | >>> TD = spikeFileIO.read2Dspikes(file_path) 251 | ''' 252 | with open(filename, 'rb') as inputFile: 253 | inputByteArray = inputFile.read() 254 | inputAsInt = np.asarray([x for x in inputByteArray]) 255 | xEvent = inputAsInt[0::5] 256 | yEvent = inputAsInt[1::5] 257 | pEvent = inputAsInt[2::5] >> 7 258 | tEvent =( (inputAsInt[2::5] << 16) | (inputAsInt[3::5] << 8) | (inputAsInt[4::5]) ) & 0x7FFFFF 259 | return event(xEvent, yEvent, pEvent, tEvent/1000) # convert spike times to ms 260 | 261 | def encode2Dspikes(filename, TD): 262 | ''' 263 | Writes two dimensional binary spike file from a TD event. 264 | It is the same format used in neuromorphic datasets NMNIST & NCALTECH101. 
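A file written with this function can be read back with ``spikeFileIO.read2Dspikes``.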
265 | 266 | The binary file is encoded as follows: 267 | * Each spike event is represented by a 40 bit number. 268 | * First 8 bits (bits 39-32) represent the xID of the neuron. 269 | * Next 8 bits (bits 31-24) represent the yID of the neuron. 270 | * Bit 23 represents the sign of spike event: 0=>OFF event, 1=>ON event. 271 | * The last 23 bits (bits 22-0) represent the spike event timestamp in microseconds. 272 | 273 | Arguments: 274 | * ``filename`` (``string``): path to the binary file. 275 | * ``TD`` (a ``spikeFileIO.event``): TD event. 276 | 277 | Usage: 278 | 279 | >>> spikeFileIO.encode2Dspikes(file_path, TD) 280 | ''' 281 | if TD.dim != 2: raise Exception('Expected TD dimension to be 2. It was: {}'.format(TD.dim)) 282 | xEvent = np.round(TD.x).astype(int) 283 | yEvent = np.round(TD.y).astype(int) 284 | pEvent = np.round(TD.p).astype(int) 285 | tEvent = np.round(TD.t * 1000).astype(int) # encode spike time in us 286 | outputByteArray = bytearray(len(tEvent) * 5) 287 | outputByteArray[0::5] = np.uint8(xEvent).tobytes() 288 | outputByteArray[1::5] = np.uint8(yEvent).tobytes() 289 | outputByteArray[2::5] = np.uint8(((tEvent >> 16) & 0x7F) | (pEvent.astype(int) << 7) ).tobytes() 290 | outputByteArray[3::5] = np.uint8( (tEvent >> 8 ) & 0xFF ).tobytes() 291 | outputByteArray[4::5] = np.uint8( tEvent & 0xFF ).tobytes() 292 | with open(filename, 'wb') as outputFile: 293 | outputFile.write(outputByteArray) 294 | 295 | def read3Dspikes(filename): 296 | ''' 297 | Reads binary spike file for spike event in height, width and channel dimension and returns a TD event. 298 | 299 | The binary file is encoded as follows: 300 | * Each spike event is represented by a 56 bit number. 301 | * First 12 bits (bits 55-44) represent the xID of the neuron. 302 | * Next 12 bits (bits 43-32) represent the yID of the neuron. 303 | * Next 8 bits (bits 31-24) represent the channel ID of the neuron. 304 | * The last 24 bits (bits 23-0) represent the spike event timestamp in microseconds. 305 | 306 | Arguments: 307 | * ``filename`` (``string``): path to the binary file. 308 | 309 | Usage: 310 | 311 | >>> TD = spikeFileIO.read3Dspikes(file_path) 312 | ''' 313 | with open(filename, 'rb') as inputFile: 314 | inputByteArray = inputFile.read() 315 | inputAsInt = np.asarray([x for x in inputByteArray]) 316 | xEvent = (inputAsInt[0::7] << 4 ) | (inputAsInt[1::7] >> 4 ) 317 | yEvent = (inputAsInt[2::7] ) | ( (inputAsInt[1::7] & 0x0F) << 8 ) 318 | pEvent = inputAsInt[3::7] 319 | tEvent =( (inputAsInt[4::7] << 16) | (inputAsInt[5::7] << 8) | (inputAsInt[6::7]) ) 320 | return event(xEvent, yEvent, pEvent, tEvent/1000) # convert spike times to ms 321 | 322 | def encode3Dspikes(filename, TD): 323 | ''' 324 | Writes binary spike file for TD event in height, width and channel dimension. 325 | 326 | The binary file is encoded as follows: 327 | * Each spike event is represented by a 56 bit number. 328 | * First 12 bits (bits 55-44) represent the xID of the neuron. 329 | * Next 12 bits (bits 43-32) represent the yID of the neuron. 330 | * Next 8 bits (bits 31-24) represent the channel ID of the neuron. 331 | * The last 24 bits (bits 23-0) represent the spike event timestamp in microseconds. 332 | 333 | Arguments: 334 | * ``filename`` (``string``): path to the binary file. 335 | * ``TD`` (a ``spikeFileIO.event``): TD event. 336 | 337 | Usage: 338 | 339 | >>> spikeFileIO.encode3Dspikes(file_path, TD) 340 | ''' 341 | if TD.dim != 2: raise Exception('Expected Td dimension to be 2. 
It was: {}'.format(TD.dim)) 342 | xEvent = np.round(TD.x).astype(int) 343 | yEvent = np.round(TD.y).astype(int) 344 | pEvent = np.round(TD.p).astype(int) 345 | tEvent = np.round(TD.t * 1000).astype(int) # encode spike time in us 346 | outputByteArray = bytearray(len(tEvent) * 7) 347 | outputByteArray[0::7] = np.uint8(xEvent >> 4).tobytes() 348 | outputByteArray[1::7] = np.uint8( ((xEvent << 4) & 0xFF) | ((yEvent >> 8) & 0x0F) ).tobytes() 349 | outputByteArray[2::7] = np.uint8( yEvent & 0xFF ).tobytes() 350 | outputByteArray[3::7] = np.uint8( pEvent ).tobytes() 351 | outputByteArray[4::7] = np.uint8( (tEvent >> 16 ) & 0xFF ).tobytes() 352 | outputByteArray[5::7] = np.uint8( (tEvent >> 8 ) & 0xFF ).tobytes() 353 | outputByteArray[6::7] = np.uint8( tEvent & 0xFF ).tobytes() 354 | with open(filename, 'wb') as outputFile: 355 | outputFile.write(outputByteArray) 356 | 357 | def read1DnumSpikes(filename): 358 | ''' 359 | Reads a tuple specifying neuron, start of spike region, end of spike region and number of spikes from binary spike file. 360 | 361 | The binary file is encoded as follows: 362 | * Number of spikes data is represented by an 80 bit number. 363 | * First 16 bits (bits 79-64) represent the neuronID. 364 | * Next 24 bits (bits 63-40) represent the start time in microseconds. 365 | * Next 24 bits (bits 39-16) represent the end time in microseconds. 366 | * Last 16 bits (bits 15-0) represent the number of spikes. 367 | 368 | Arguments: 369 | * ``filename`` (``string``): path to the binary file 370 | 371 | Usage: 372 | 373 | >>> nID, tSt, tEn, nSp = spikeFileIO.read1DnumSpikes(file_path) 374 | ``tSt`` and ``tEn`` are returned in milliseconds 375 | ''' 376 | with open(filename, 'rb') as inputFile: 377 | inputByteArray = inputFile.read() 378 | inputAsInt = np.asarray([x for x in inputByteArray]) 379 | neuronID = (inputAsInt[0::10] << 8) | inputAsInt[1::10] 380 | tStart = (inputAsInt[2::10] << 16) | (inputAsInt[3::10] << 8) | (inputAsInt[4::10]) 381 | tEnd = (inputAsInt[5::10] << 16) | (inputAsInt[6::10] << 8) | (inputAsInt[7::10]) 382 | nSpikes = (inputAsInt[8::10] << 8) | inputAsInt[9::10] 383 | return neuronID, tStart/1000, tEnd/1000, nSpikes # convert spike times to ms 384 | 385 | def encode1DnumSpikes(filename, nID, tSt, tEn, nSp): 386 | ''' 387 | Writes binary spike file given a tuple specifying neuron, start of spike region, end of spike region and number of spikes. 
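For instance, a hypothetical record stating that neuron 3 fired 12 times between 0 ms and 100 ms occupies a single 80-bit entry.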
388 | 389 | The binary file is encoded as follows: 390 | * Number of spikes data is represented by an 80 bit number 391 | * First 16 bits (bits 79-64) represent the neuronID 392 | * Next 24 bits (bits 63-40) represent the start time in microseconds 393 | * Next 24 bits (bits 39-16) represent the end time in microseconds 394 | * Last 16 bits (bits 15-0) represent the number of spikes 395 | 396 | Arguments: 397 | * ``filename`` (``string``): path to the binary file 398 | * ``nID`` (``numpy array``): neuron ID 399 | * ``tSt`` (``numpy array``): region start time (in milliseconds) 400 | * ``tEn`` (``numpy array``): region end time (in milliseconds) 401 | * ``nSp`` (``numpy array``): number of spikes in the region 402 | 403 | Usage: 404 | 405 | >>> spikeFileIO.encode1DnumSpikes(file_path, nID, tSt, tEn, nSp) 406 | ''' 407 | neuronID = np.round(nID).astype(int) 408 | tStart = np.round(tSt * 1000).astype(int) # encode spike time in us 409 | tEnd = np.round(tEn * 1000).astype(int) # encode spike time in us 410 | nSpikes = np.round(nSp).astype(int) 411 | outputByteArray = bytearray(len(neuronID) * 10) 412 | outputByteArray[0::10] = np.uint8( neuronID >> 8 ).tobytes() 413 | outputByteArray[1::10] = np.uint8( neuronID ).tobytes() 414 | outputByteArray[2::10] = np.uint8( tStart >> 16 ).tobytes() 415 | outputByteArray[3::10] = np.uint8( tStart >> 8 ).tobytes() 416 | outputByteArray[4::10] = np.uint8( tStart ).tobytes() 417 | outputByteArray[5::10] = np.uint8( tEnd >> 16 ).tobytes() 418 | outputByteArray[6::10] = np.uint8( tEnd >> 8 ).tobytes() 419 | outputByteArray[7::10] = np.uint8( tEnd ).tobytes() 420 | outputByteArray[8::10] = np.uint8( nSpikes >> 8 ).tobytes() 421 | outputByteArray[9::10] = np.uint8( nSpikes ).tobytes() 422 | with open(filename, 'wb') as outputFile: 423 | outputFile.write(outputByteArray) 424 | 425 | def readNpSpikes(filename, fmt='xypt', timeUnit=1e-3): 426 | ''' 427 | Reads numpy spike event and returns a TD event. 428 | The numpy array is assumed to be of dimension nEvents x eventDim. 429 | 430 | Arguments: 431 | * ``filename`` (``string``): path to the file. 432 | * ``fmt`` (``string``): format of event. E.g. 'xypt' means the event data is arranged as x data, y data, p data and time data. 433 | * ``timeUnit`` (``double``): factor to scale the time data to convert it into seconds. Default: 1e-3 (ms). 434 | 435 | Usage: 436 | 437 | >>> TD = spikeFileIO.readNpSpikes(file_path) 438 | >>> TD = spikeFileIO.readNpSpikes(file_path, fmt='xypt') 439 | >>> TD = spikeFileIO.readNpSpikes(file_path, timeUnit=1e-6) 440 | ''' 441 | npEvent = np.load(filename) 442 | if fmt=='xypt': 443 | if npEvent.shape[1] == 3: 444 | return event(npEvent[:, 0].astype('int'), None, npEvent[:, 1], npEvent[:, 2] * timeUnit * 1e3) 445 | elif npEvent.shape[1] == 4: 446 | return event(npEvent[:, 0], npEvent[:, 1], npEvent[:, 2], npEvent[:, 3] * timeUnit * 1e3) 447 | else: 448 | raise Exception('Numpy array format did not match. Expected it to be nEvents x eventDim.') 449 | else: 450 | raise Exception("fmt='%s' not implemented."%(fmt)) 451 | 452 | 453 | def encodeNpSpikes(filename, TD, fmt='xypt', timeUnit=1e-3): 454 | ''' 455 | Writes TD event into numpy file. 456 | 457 | Arguments: 458 | * ``filename`` (``string``): path to the numpy file. 459 | * ``TD`` (a ``spikeFileIO.event``): TD event. 
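* ``fmt`` (``string``): format of the event data. Only 'xypt' is currently implemented. * ``timeUnit`` (``double``): accepted for symmetry with ``readNpSpikes``; the current implementation writes ``TD.t`` unscaled.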
460 | 461 | Usage: 462 | 463 | >>> spikeFileIO.encodeNpSpikes(file_path, TD) 464 | >>> spikeFileIO.encodeNpSpikes(file_path, TD, fmt='xypt') 465 | ''' 466 | if fmt=='xypt': 467 | if TD.dim ==1: 468 | npEvent = np.zeros((len(TD.x), 3)) 469 | npEvent[:, 0] = TD.x 470 | npEvent[:, 1] = TD.p 471 | npEvent[:, 2] = TD.t 472 | elif TD.dim==2: 473 | npEvent = np.zeros((len(TD.x), 4)) 474 | npEvent[:, 0] = TD.x 475 | npEvent[:, 1] = TD.y 476 | npEvent[:, 2] = TD.p 477 | npEvent[:, 3] = TD.t 478 | else: 479 | raise Exception('Numpy array format did not match. Expected it to be nEvents x eventDim.') 480 | else: 481 | raise Exception("fmt='%s' not implemented."%(fmt)) 482 | np.save(filename, npEvent) 483 | 484 | def _showTD1D(TD, fig=None, frameRate=24, preComputeFrames=True, repeat=False, plot=True): 485 | if TD.dim !=1: raise Exception('Expected TD dimension to be 1. It was: {}'.format(TD.dim)) 486 | if fig is None: fig = plt.figure() 487 | interval = 1e3 / frameRate # in ms 488 | xDim = TD.x.max()+1 489 | tMax = TD.t.max() 490 | tMin = TD.t.min() 491 | pMax = TD.p.max()+1 492 | minFrame = int(np.floor(tMin / interval)) 493 | maxFrame = int(np.ceil(tMax / interval )) + 1 494 | 495 | # ignore preComputeFrames 496 | 497 | raster, = plt.plot([], [], '.') 498 | scanLine, = plt.plot([], []) 499 | plt.axis((tMin -0.1*tMax, 1.1*tMax, -0.1*xDim, 1.1*xDim)) 500 | 501 | def animate(i): 502 | tEnd = (i + minFrame + 1) * interval 503 | ind = (TD.t < tEnd) 504 | # update raster 505 | raster.set_data(TD.t[ind], TD.x[ind]) 506 | # update raster scan line 507 | scanLine.set_data([tEnd + interval, tEnd + interval], [0, xDim]) 508 | 509 | 510 | anim = animation.FuncAnimation(fig, animate, frames=maxFrame, interval=interval, repeat=repeat) 511 | 512 | if plot is True: plt.show() 513 | return anim 514 | 515 | def _showTD2D(TD, fig=None, frameRate=24, preComputeFrames=True, repeat=False, plot=True): 516 | if TD.dim != 2: raise Exception('Expected Td dimension to be 2. 
It was: {}'.format(TD.dim)) 517 | if fig is None: fig = plt.figure() 518 | interval = 1e3 / frameRate # in ms 519 | xDim = TD.x.max()+1 520 | yDim = TD.y.max()+1 521 | 522 | if preComputeFrames is True: 523 | minFrame = int(np.floor(TD.t.min() / interval)) 524 | maxFrame = int(np.ceil(TD.t.max() / interval )) 525 | image = plt.imshow(np.zeros((yDim, xDim, 3))) 526 | frames = np.zeros( (maxFrame-minFrame, yDim, xDim, 3)) 527 | 528 | # precompute frames 529 | for i in range(len(frames)): 530 | tStart = (i + minFrame) * interval 531 | tEnd = (i + minFrame + 1) * interval 532 | timeMask = (TD.t >= tStart) & (TD.t < tEnd) 533 | rInd = (timeMask & (TD.p == 1)) 534 | gInd = (timeMask & (TD.p == 2)) 535 | bInd = (timeMask & (TD.p == 0)) 536 | frames[i, TD.y[rInd], TD.x[rInd], 0] = 1 537 | frames[i, TD.y[gInd], TD.x[gInd], 1] = 1 538 | frames[i, TD.y[bInd], TD.x[bInd], 2] = 1 539 | 540 | def animate(frame): 541 | image.set_data(frame) 542 | return image 543 | 544 | anim = animation.FuncAnimation(fig, animate, frames=frames, interval=interval, repeat=repeat) 545 | 546 | else: 547 | minFrame = int(np.floor(TD.t.min() / interval)) 548 | maxFrame = int(np.ceil(TD.t.max() / interval )) 549 | image = plt.imshow(np.zeros((yDim, xDim, 3))) 550 | def animate(i): 551 | tStart = (i + minFrame) * interval 552 | tEnd = (i + minFrame + 1) * interval 553 | frame = np.zeros((yDim, xDim, 3)) 554 | timeMask = (TD.t >= tStart) & (TD.t < tEnd) 555 | rInd = (timeMask & (TD.p == 1)) 556 | gInd = (timeMask & (TD.p == 2)) 557 | bInd = (timeMask & (TD.p == 0)) 558 | frame[TD.y[rInd], TD.x[rInd], 0] = 1 559 | frame[TD.y[gInd], TD.x[gInd], 1] = 1 560 | frame[TD.y[bInd], TD.x[bInd], 2] = 1 561 | image.set_data(frame) 562 | return image 563 | 564 | anim = animation.FuncAnimation(fig, animate, frames=maxFrame-minFrame, interval=interval, repeat=repeat) 565 | 566 | # # save the animation as an mp4. This requires ffmpeg or mencoder to be 567 | # # installed. The extra_args ensure that the x264 codec is used, so that 568 | # # the video can be embedded in html5. You may need to adjust this for 569 | # # your system: for more information, see 570 | # # http://matplotlib.sourceforge.net/api/animation_api.html 571 | # if saveAnimation: anim.save('showTD_animation.mp4', fps=30) 572 | 573 | if plot is True: plt.show() 574 | return anim 575 | 576 | def showTD(TD, fig=None, frameRate=24, preComputeFrames=True, repeat=False): 577 | ''' 578 | Visualizes TD event. 579 | 580 | Arguments: 581 | * ``TD``: spike event to visualize. 582 | * ``fig``: figure to plot animation. Default is ``None``, in which case a figure is created. 583 | * ``frameRate``: framerate of visualization. 584 | * ``preComputeFrames``: flag to enable precomputation of frames for faster visualization. Default is ``True``. 585 | * ``repeat``: flag to enable repeat of animation. Default is ``False``. 586 | 587 | Usage: 588 | 589 | >>> showTD(TD) 590 | ''' 591 | 592 | if fig is None: fig = plt.figure() 593 | if TD.dim == 1: 594 | _showTD1D(TD, fig, frameRate=frameRate, preComputeFrames=preComputeFrames, repeat=repeat) 595 | else: 596 | _showTD2D(TD, fig, frameRate=frameRate, preComputeFrames=preComputeFrames, repeat=repeat) 597 | 598 | def animTD(TD, fig=None, frameRate=24, preComputeFrames=True, repeat=True): 599 | ''' 600 | Reutrn animation object for TD event. 601 | 602 | Arguments: 603 | * ``TD``: spike event to visualize. 604 | * ``fig``: figure to plot animation. Default is ``None``, in which case a figure is created. 605 | * ``frameRate``: framerate of visualization. 
606 | * ``preComputeFrames``: flag to enable precomputation of frames for faster visualization. Default is ``True``. 607 | * ``repeat``: flag to enable repeat of animation. Default is ``True``. 608 | 609 | Usage: 610 | 611 | >>> anim = animTD(TD) 612 | ''' 613 | if fig is None: fig = plt.figure() 614 | if TD.dim == 1: 615 | anim = _showTD1D(TD, fig, frameRate=frameRate, preComputeFrames=preComputeFrames, repeat=repeat, plot=False) 616 | else: 617 | anim = _showTD2D(TD, fig, frameRate=frameRate, preComputeFrames=preComputeFrames, repeat=repeat, plot=False) 618 | 619 | plt.close(anim._fig) 620 | return anim 621 | 622 | 623 | # def spikeMat2TD(spikeMat, samplingTime=1): # Sampling time in ms 624 | # addressEvent = np.argwhere(spikeMat > 0) 625 | # # print(addressEvent.shape) 626 | # return event(addressEvent[:,2], addressEvent[:,1], addressEvent[:,0], addressEvent[:,3] * samplingTime) 627 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/spikeLoss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from .slayer import spikeLayer 6 | 7 | class spikeLoss(torch.nn.Module): 8 | ''' 9 | This class defines different spike based loss modules that can be used to optimize the SNN. 10 | 11 | NOTE: By default, this class uses the spike kernels from ``slayer.spikeLayer`` (``snn.layer``). 12 | In some cases, you may want to explicitly use different spike kernels, for e.g. ``slayerLoihi.spikeLayer`` (``snn.loihi``). 13 | In that scenario, you can explicitly pass the class name: ``slayerClass=snn.loihi`` 14 | 15 | Usage: 16 | 17 | >>> error = spikeLoss.spikeLoss(networkDescriptor) 18 | >>> error = spikeLoss.spikeLoss(errorDescriptor, neuronDesc, simulationDesc) 19 | >>> error = spikeLoss.spikeLoss(netParams, slayerClass=slayerLoihi.spikeLayer) 20 | ''' 21 | def __init__(self, errorDescriptor, neuronDesc, simulationDesc, slayerClass=spikeLayer): 22 | super(spikeLoss, self).__init__() 23 | self.neuron = neuronDesc 24 | self.simulation = simulationDesc 25 | self.errorDescriptor = errorDescriptor 26 | # self.slayer = spikeLayer(neuronDesc, simulationDesc) 27 | self.slayer = slayerClass(self.neuron, self.simulation) 28 | 29 | def __init__(self, networkDescriptor, slayerClass=spikeLayer): 30 | super(spikeLoss, self).__init__() 31 | self.neuron = networkDescriptor['neuron'] 32 | self.simulation = networkDescriptor['simulation'] 33 | self.errorDescriptor = networkDescriptor['training']['error'] 34 | # self.slayer = spikeLayer(self.neuron, self.simulation) 35 | self.slayer = slayerClass(self.neuron, self.simulation) 36 | 37 | def spikeTime(self, spikeOut, spikeDesired): 38 | ''' 39 | Calculates spike loss based on spike time. 40 | The loss is similar to van Rossum distance between output and desired spike train. 41 | 42 | .. 
math:: 43 | 44 | E = \int_0^T \\left( \\varepsilon * (output -desired) \\right)(t)^2\\ \\text{d}t 45 | 46 | Arguments: 47 | * ``spikeOut`` (``torch.tensor``): spike tensor 48 | * ``spikeDesired`` (``torch.tensor``): desired spike tensor 49 | 50 | Usage: 51 | 52 | >>> loss = error.spikeTime(spikeOut, spikeDes) 53 | ''' 54 | # Tested with autograd, it works 55 | assert self.errorDescriptor['type'] == 'SpikeTime', "Error type is not SpikeTime" 56 | # error = self.psp(spikeOut - spikeDesired) 57 | error = self.slayer.psp(spikeOut - spikeDesired) 58 | return 1/2 * torch.sum(error**2) * self.simulation['Ts'] 59 | 60 | def numSpikes(self, spikeOut, desiredClass, numSpikesScale=1): 61 | ''' 62 | Calculates spike loss based on number of spikes within a `target region`. 63 | The `target region` and `desired spike count` is specified in ``error.errorDescriptor['tgtSpikeRegion']`` 64 | Any spikes outside the target region are penalized with ``error.spikeTime`` loss.. 65 | 66 | .. math:: 67 | e(t) &= 68 | \\begin{cases} 69 | \\frac{acutalSpikeCount - desiredSpikeCount}{targetRegionLength} & \\text{for }t \in targetRegion\\\\ 70 | \\left(\\varepsilon * (output - desired)\\right)(t) & \\text{otherwise} 71 | \\end{cases} 72 | 73 | E &= \\int_0^T e(t)^2 \\text{d}t 74 | 75 | Arguments: 76 | * ``spikeOut`` (``torch.tensor``): spike tensor 77 | * ``desiredClass`` (``torch.tensor``): one-hot encoded desired class tensor. Time dimension should be 1 and rest of the tensor dimensions should be same as ``spikeOut``. 78 | 79 | Usage: 80 | 81 | >>> loss = error.numSpikes(spikeOut, target) 82 | ''' 83 | # Tested with autograd, it works 84 | assert self.errorDescriptor['type'] == 'NumSpikes', "Error type is not NumSpikes" 85 | # desiredClass should be one-hot tensor with 5th dimension 1 86 | tgtSpikeRegion = self.errorDescriptor['tgtSpikeRegion'] 87 | tgtSpikeCount = self.errorDescriptor['tgtSpikeCount'] 88 | startID = np.rint( tgtSpikeRegion['start'] / self.simulation['Ts'] ).astype(int) 89 | stopID = np.rint( tgtSpikeRegion['stop' ] / self.simulation['Ts'] ).astype(int) 90 | 91 | actualSpikes = torch.sum(spikeOut[...,startID:stopID], 4, keepdim=True).cpu().detach().numpy() * self.simulation['Ts'] 92 | desiredSpikes = np.where(desiredClass.cpu() == True, tgtSpikeCount[True], tgtSpikeCount[False]) 93 | # print('actualSpikes :', actualSpikes.flatten()) 94 | # print('desiredSpikes:', desiredSpikes.flatten()) 95 | errorSpikeCount = (actualSpikes - desiredSpikes) / (stopID - startID) * numSpikesScale 96 | targetRegion = np.zeros(spikeOut.shape) 97 | targetRegion[:,:,:,:,startID:stopID] = 1; 98 | spikeDesired = torch.FloatTensor(targetRegion * spikeOut.cpu().data.numpy()).to(spikeOut.device) 99 | 100 | # error = self.psp(spikeOut - spikeDesired) 101 | error = self.slayer.psp(spikeOut - spikeDesired) 102 | error += torch.FloatTensor(errorSpikeCount * targetRegion).to(spikeOut.device) 103 | 104 | return 1/2 * torch.sum(error**2) * self.simulation['Ts'] 105 | 106 | def probSpikes(spikeOut, spikeDesired, probSlidingWindow = 20): 107 | assert self.errorDescriptor['type'] == 'ProbSpikes', "Error type is not ProbSpikes" 108 | pass 109 | 110 | # def numSpikesII(self, membranePotential, desiredClass, numSpikeScale=1): 111 | # assert self.errorDescriptor['type'] == 'NumSpikes', "Error type is not NumSpikes" 112 | # # desiredClass should be one-hot tensor with 5th dimension 1 113 | # tgtSpikeRegion = self.errorDescriptor['tgtSpikeRegion'] 114 | # tgtSpikeCount = self.errorDescriptor['tgtSpikeCount'] 115 | # startID = np.rint( 
tgtSpikeRegion['start'] / self.simulation['Ts'] ).astype(int) 116 | # stopID = np.rint( tgtSpikeRegion['stop' ] / self.simulation['Ts'] ).astype(int) 117 | 118 | # spikeOut = self.slayer.spike(membranePotential) 119 | # spikeDes = torch.zeros(spikeOut.shape, dtype=spikeOut.dtype).to(spikeOut.device) 120 | 121 | # actualSpikes = torch.sum(spikeOut[...,startID:stopID], 4, keepdim=True).cpu().detach().numpy() * self.simulation['Ts'] 122 | # desiredSpikes = np.where(desiredClass.cpu() == True, tgtSpikeCount[True], tgtSpikeCount[False]) 123 | 124 | # spikesAER = spikeOut.nonzero().tolist() 125 | 126 | # for n in range(spikeOut.shape[0]): 127 | # for c in range(spikeOut.shape[1]): 128 | # for h in range(spikeOut.shape[2]): 129 | # for w in range(spikeOut.shape[3]): 130 | # diff = desiredSpikes[n,c,h,w] - acutalSpikes[n,c,h,w] 131 | # if diff < 0: 132 | # spikesAER[n,c,h,w] = spikesAER[n,c,h,w,:diff] 133 | # elif diff > 0: 134 | # spikeDes[n,c,h,w,(actualInd[:diff] + startID)] = 1 / self.simulation['Ts'] 135 | # probableInds = np.random.randint(low=startID, high=stopID, size = diff) 136 | 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /spytorch2loihi/SlayerSNN_src/utils.py: -------------------------------------------------------------------------------- 1 | from .quantizeParams import quantize 2 | from .learningStats import learningStats as stats 3 | from . import optimizer as optim 4 | from .slayerParams import printConfig -------------------------------------------------------------------------------- /spytorch2loihi/netsLoihi/netLoihi_rec_th1_6.net: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/netsLoihi/netLoihi_rec_th1_6.net -------------------------------------------------------------------------------- /spytorch2loihi/spytorch2loihi_export_fsnn.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Imports\n", 10 | "import sys, os\n", 11 | "import zipfile\n", 12 | "import numpy as np\n", 13 | "import h5py\n", 14 | "import torch\n", 15 | "import re\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import math\n", 18 | "\n", 19 | "from SlayerSNN_src.auto.loihi import denseBlock, convBlock, flattenBlock, poolBlock, Network\n", 20 | "from SlayerSNN_src.slayerLoihi import spikeLayer as loihi\n", 21 | "from SlayerSNN_src import utils\n", 22 | "from torch.utils.data import Dataset, DataLoader\n", 23 | "from IPython.display import HTML\n", 24 | "\n", 25 | "from SlayerSNN_src.slayer import spikeLayer as layer\n", 26 | "from SlayerSNN_src.slayerLoihi import spikeLayer as loihi\n", 27 | "from SlayerSNN_src.slayerParams import yamlParams as params\n", 28 | "from SlayerSNN_src.spikeLoss import spikeLoss as loss\n", 29 | "from SlayerSNN_src.spikeClassifier import spikeClassifier as predict\n", 30 | "from SlayerSNN_src import spikeFileIO as io\n", 31 | "from SlayerSNN_src import utils\n", 32 | "# This will be removed later. 
Kept for compatibility only\n", 33 | "from SlayerSNN_src.quantizeParams import quantizeWeights as quantize\n", 34 | "\n", 35 | "# Added for debug\n", 36 | "%matplotlib inline" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "# Threshold\n", 46 | "threshold = 10 # 1, 2, 5, 10\n", 47 | "run = \"_3\"\n", 48 | "\n", 49 | "if threshold == 1:\n", 50 | " time_bin_size = 5\n", 51 | " nb_input_copies = 2\n", 52 | " tau_mem = 0.06\n", 53 | " tau_ratio = 10\n", 54 | "elif threshold == 2:\n", 55 | " time_bin_size = 3\n", 56 | " nb_input_copies = 8\n", 57 | " tau_mem = 0.05\n", 58 | " tau_ratio = 10\n", 59 | "elif threshold == 5:\n", 60 | " time_bin_size = 3\n", 61 | " nb_input_copies = 4\n", 62 | " tau_mem = 0.07\n", 63 | " tau_ratio = 10\n", 64 | "elif threshold == 10:\n", 65 | " time_bin_size = 5\n", 66 | " nb_input_copies = 2\n", 67 | " tau_mem = 0.07\n", 68 | " tau_ratio = 10\n", 69 | "\n", 70 | "# SpyTorch weights\n", 71 | "weights_path = \"../weights/SpyTorch_trained_weights_fwd_th\" + str(threshold) + run + \".pt\"" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 3, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "# Import weights\n", 81 | "path = weights_path\n", 82 | "SpyTorch_weights = torch.load(path, map_location=torch.device('cpu'))\n", 83 | "\n", 84 | "wgt1_in2hid = SpyTorch_weights[0].detach().numpy()\n", 85 | "wgt2_hid2out = SpyTorch_weights[1].detach().numpy()" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "(450, 28)\n", 98 | "[[ 0.09514701 -0.2498929 -0.21983077 ... -0.42935914 0.093075\n", 99 | " -0.22871235]\n", 100 | " [-0.37213632 -0.02317753 -0.06820287 ... -0.22371764 -0.07909746\n", 101 | " -0.14899771]\n", 102 | " [ 0.01635543 -0.2587211 -0.03930256 ... 0.2052648 0.14091352\n", 103 | " -0.5218253 ]\n", 104 | " ...\n", 105 | " [ 0.21863535 0.16354123 0.47647065 ... 0.12791806 0.12644301\n", 106 | " -0.8379483 ]\n", 107 | " [-0.02220986 -0.42480463 -0.06455369 ... -0.16587071 0.43334395\n", 108 | " 0.15958126]\n", 109 | " [ 0.21037246 0.02932504 -0.16226019 ... 
-0.10203745 0.04537127\n", 110 | " -0.2431798 ]]\n" 111 | ] 112 | } 113 | ], 114 | "source": [ 115 | "print(np.shape(SpyTorch_weights[1].detach().numpy()))\n", 116 | "print(SpyTorch_weights[1].detach().numpy())" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 5, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "(48, 450)\n" 129 | ] 130 | } 131 | ], 132 | "source": [ 133 | "print(np.shape(SpyTorch_weights[0].detach().numpy()))" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 6, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "name": "stdout", 143 | "output_type": "stream", 144 | "text": [ 145 | "wgt_scale_calc = 99\n", 146 | "vDecay_calc = 282\n", 147 | "iDecay_calc = 2090\n", 148 | "time_bins = 270\n" 149 | ] 150 | } 151 | ], 152 | "source": [ 153 | "# Loihi inference parameters approximation\n", 154 | "wgt_max = np.amax([np.amax(np.abs(wgt1_in2hid)), \n", 155 | " np.amax(np.abs(wgt2_hid2out))])\n", 156 | "wgt_scale_calc = math.floor(256/wgt_max) # round down\n", 157 | "tau_syn = tau_mem/tau_ratio\n", 158 | "alpha = float(np.exp(-(time_bin_size/1000)/tau_syn))\n", 159 | "beta = float(np.exp(-(time_bin_size/1000)/tau_mem))\n", 160 | "vDecay_calc = int(4096-4096*beta)\n", 161 | "iDecay_calc = int(4096-4096*alpha)\n", 162 | "time_bins = math.ceil(1350/time_bin_size) # round up\n", 163 | "\n", 164 | "print(\"wgt_scale_calc = \", wgt_scale_calc)\n", 165 | "print(\"vDecay_calc = \", vDecay_calc)\n", 166 | "print(\"iDecay_calc = \", iDecay_calc)\n", 167 | "print(\"time_bins = \", time_bins)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 7, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "# Loihi inference parameters\n", 177 | "wgt_scale = wgt_scale_calc # scale for all weight\n", 178 | "vThMant = wgt_scale_calc # vth = vthMant ∗ 64\n", 179 | "vDecay = vDecay_calc # tau_mem\n", 180 | "iDecay = iDecay_calc # tau_syn\n", 181 | "\n", 182 | "qtz_step = 2 # weights quantization step\n", 183 | "rec_scale = 1 # extra scale for recurrent weights\n", 184 | "refDelay = 1 # refractory delay\n", 185 | "wgtExp = 0 # 2**(6+wgtExp) * W * spike_input" 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 8, 191 | "metadata": {}, 192 | "outputs": [ 193 | { 194 | "data": { 195 | "text/plain": [ 196 | "(array([ 7., 68., 473., 2047., 4895., 6715., 5019., 1914., 408.,\n", 197 | " 54.]),\n", 198 | " array([-255.82756 , -209.13321 , -162.43886 , -115.7445 , -69.05015 ,\n", 199 | " -22.355797, 24.338556, 71.032906, 117.727264, 164.42162 ,\n", 200 | " 211.11597 ], dtype=float32),\n", 201 | " )" 202 | ] 203 | }, 204 | "execution_count": 8, 205 | "metadata": {}, 206 | "output_type": "execute_result" 207 | }, 208 | { 209 | "data": { 210 | "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAX0AAAD7CAYAAACG50QgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAASrElEQVR4nO3dYYxd5X3n8e+vOKRV2o3tMOu1bGdNVStd+iKJdwSuWlXdsDXGVDGVGkS0WmZZS94XbJVIlRpn8wItNBLsSs0GacuuVbxrqmwomzay1bClUydRtS8gDAl1Ag7rgYBsy+BpxiHtotIl/e+LeYbeODPMHZi5A36+H+nqPud/nnvueR7Zv3t87rnHqSokSX34sbXeAUnS6Bj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdWTL0k7wvyRMDj+8n+XiSjUkmk5xqzxta/yS5J8l0khNJdg5sa6L1P5VkYjUHJkn6UVnOdfpJLgPOAtcAtwGzVXVXkoPAhqr6RJK9wG8Ae1u/z1bVNUk2AlPAOFDA48A/raoLKzoiSdKi1i2z/7XAM1X1fJJ9wC+3+hHgq8AngH3A/TX3afJIkvVJNre+k1U1C5BkEtgDfH6xN7viiitq+/bty9xFSerb448//pdVNbbQuuWG/s38fUhvqqpzrf0CsKm1twCnB15zptUWqy9q+/btTE1NLXMXJalvSZ5fbN3QX+QmuRz4MPA/L17XjupX5H4OSQ4kmUoyNTMzsxKblCQ1y7l653rg61X1Ylt+sZ22oT2fb/WzwLaB121ttcXqP6SqDlXVeFWNj40t+K8TSdIbtJzQ/yg/fP79GDB/Bc4EcHSgfku7imcX8FI7DfQwsDvJhnalz+5WkySNyFDn9JO8C/gV4N8MlO8CHkyyH3geuKnVH2Luyp1p4GXgVoCqmk1yJ/BY63fH/Je6kqTRWNYlm6M2Pj5efpErScuT5PGqGl9onb/IlaSOGPqS1BFDX5I6YuhLUkeW+4tcSc32g19ak/d97q4b1uR9dWnwSF+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6MlToJ1mf5AtJvp3kZJKfT7IxyWSSU+15Q+ubJPckmU5yIsnOge1MtP6nkkys1qAkSQsb9kj/s8CfVNXPAu8HTgIHgeNVtQM43pYBrgd2tMcB4F6AJBuB24FrgKuB2+c/KCRJo7Fk6Cd5N/BLwH0AVfW3VfU9YB9wpHU7AtzY2vuA+2vOI8D6JJuB64DJqpqtqgvAJLBnBcciSVrCMEf6VwIzwH9L8o0kv5fkXcCmqjrX+rwAbGrtLcDpgdefabXF6pKkERkm9NcBO4F7q+qDwP/l70/lAFBVBdRK7FCSA0mmkkzNzMysxCYlSc0woX8GOFNVj7blLzD3IfBiO21Dez7f1p8Ftg28fmurLVb/IVV1qKrGq2p8bGxsOWORJC1hydCvqheA00ne10rXAk8Bx4D5K3AmgKOtfQy4pV3Fswt4qZ0GehjYnWRD+wJ3d6tJkkZk3ZD9fgP4XJLLgWeBW5n7wHgwyX7geeCm1vchYC8wDbzc+lJVs0nuBB5r/e6oqtkVGYUkaShDhX5VPQGML7Dq2gX6FnDbIts5DBxexv5JklaQv8iVpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdGSr0kzyX5JtJnkgy1Wobk0wmOdWeN7R6ktyTZDrJiSQ7B7Yz0fqfSjKxOkOSJC1mOUf6/6yqPlBV4235IHC8qnYAx9sywPXAjvY4ANwLcx8SwO3ANcDVwO3zHxSSpNF4M6d39gFHWvsIcONA/f6a8wiwPslm4Dpgsqpmq+oCMAnseRPvL0lapnVD9ivgT5MU8F+r6hCwqarOtfUvAJtaewtweuC1Z1ptsbqkZdh+8Etr9t7P3XXDmr23Vsawof+LVXU2yT8EJpN8e3BlVVX7QHjTkhxg7rQQ733ve1dik5KkZqjTO1V1tj2fB77I3Dn5F9tpG9rz+db9LLBt4OVbW22x+sXvdaiqxqtqfGxsbHmjkSS9riVDP8m7kvzUfBvYDXwLOAbMX4EzARxt7WPALe0qnl3AS+000MPA7iQb2he4u1tNkjQiw5ze2QR8Mcl8//9RVX+S5DHgwST7geeBm1r/h4C9wDTwMnArQFXNJrkTeKz1u6OqZldsJOrSWp7flt6Olgz9qnoWeP8C9e8C1y5QL+C2RbZ1GDi8/N2UJK0Ef5ErSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6MnToJ7ksyTeS/HFbvjLJo0mmk/xBkstb/Z1tebqt3z6wjU+2+tNJrlvx0UiSXtdyjvQ/BpwcWL4b+ExV/QxwAdjf6vuBC63+mdaPJFcBNwM/B+wBfjfJZW9u9yVJyzFU6CfZCtwA/F5bDvAh4AutyxHgxtbe15Zp669t/fcBD1TVK1X1HWAauHoFxiBJGtKwR/r/Cfgt4O/a8nuA71XVq235DLCltbcApwHa+pda/9fqC7zmNUkOJJlKMjUzMzP8SCRJS1oy9JP8KnC+qh4fwf5QVYeqaryqxsfGxkbxlpLUjXVD9PkF4MNJ9gI/DvwD4LPA+iTr2tH8VuBs638W2AacSbIOeDfw3YH6vMHXSJJGYMkj/ar6ZFVtrartzH0R++Wq+hfAV4Bfb90mgKOtfawt09Z/uaqq1W9uV/dcCewAvrZiI5EkLWmYI/3FfAJ4IMlvA98A7mv1+4DfTzINzDL3QUFVPZnkQeAp4FXgtqr6wZt4f0nSMi0r9Kvqq8BXW/tZFrj6pqr+BvjIIq//NPDp5e6kJGll+ItcSeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUkSVDP8mPJ/lakr9I8mSSf9/qVyZ5NMl0kj9Icnmrv7MtT7f12we29clWfzrJdas2KknSgoY50n8F+FBVvR/4ALAnyS7gbuAzVfUzwAVgf+u/H7jQ6p9p/UhyFXAz8HPAHuB3k1y2gmORJC1hydCvOX/dFt/RHgV8CPhCqx8BbmztfW2Ztv7aJGn1B6rqlar6DjANXL0Sg5AkDWeoc/pJLkvyBHAemASeAb5XVa+2LmeALa29BTgN0Na/BLxnsL7Aawbf60CSqSRTMzMzyx6QJGlxQ4V+Vf2gqj4AbGXu6PxnV2uHqupQVY1X1fjY2NhqvY0kdWlZV+9U1feArwA/D6xPsq6t2gqcbe2zwD
aAtv7dwHcH6wu8RpI0AsNcvTOWZH1r/wTwK8BJ5sL/11u3CeBoax9ry7T1X66qavWb29U9VwI7gK+t0DgkSUNYt3QXNgNH2pU2PwY8WFV/nOQp4IEkvw18A7iv9b8P+P0k08Asc1fsUFVPJnkQeAp4Fbitqn6wssORJL2eJUO/qk4AH1yg/iwLXH1TVX8DfGSRbX0a+PTyd1OStBL8Ra4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdMfQlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSR5YM/STbknwlyVNJnkzysVbfmGQyyan2vKHVk+SeJNNJTiTZObCtidb/VJKJ1RuWJGkhwxzpvwr8ZlVdBewCbktyFXAQOF5VO4DjbRngemBHexwA7oW5DwngduAa4Grg9vkPCknSaCwZ+lV1rqq+3tp/BZwEtgD7gCOt2xHgxtbeB9xfcx4B1ifZDFwHTFbVbFVdACaBPSs5GEnS61u3nM5JtgMfBB4FNlXVubbqBWBTa28BTg+87EyrLVbXJWD7wS+t9S5IGsLQX+Qm+UngD4GPV9X3B9dVVQG1EjuU5ECSqSRTMzMzK7FJSVIzVOgneQdzgf+5qvqjVn6xnbahPZ9v9bPAtoGXb221xeo/pKoOVdV4VY2PjY0tZyySpCUMc/VOgPuAk1X1OwOrjgHzV+BMAEcH6re0q3h2AS+100APA7uTbGhf4O5uNUnSiAxzTv8XgH8JfDPJE63274C7gAeT7AeeB25q6x4C9gLTwMvArQBVNZvkTuCx1u+OqppdiUFIGo21+u7mubtuWJP3vRQtGfpV9b+BLLL62gX6F3DbIts6DBxezg5KklaOv8iVpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOGPqS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktQRQ1+SOmLoS1JHDH1J6oihL0kdWTL0kxxOcj7JtwZqG5NMJjnVnje0epLck2Q6yYkkOwdeM9H6n0oysTrDkSS9nmGO9P87sOei2kHgeFXtAI63ZYDrgR3tcQC4F+Y+JIDbgWuAq4Hb5z8oJEmjs2ToV9WfA7MXlfcBR1r7CHDjQP3+mvMIsD7JZuA6YLKqZqvqAjDJj36QSJJW2Rs9p7+pqs619gvAptbeApwe6Hem1RarS5JG6E1/kVtVBdQK7AsASQ4kmUoyNTMzs1KblSTxxkP/xXbahvZ8vtXPAtsG+m1ttcXqP6KqDlXVeFWNj42NvcHdkyQt5I2G/jFg/gqcCeDoQP2WdhXPLuCldhroYWB3kg3tC9zdrSZJGqF1S3VI8nngl4Erkpxh7iqcu4AHk+wHngduat0fAvYC08DLwK0AVTWb5E7gsdbvjqq6+MthSdIqWzL0q+qji6y6doG+Bdy2yHYOA4eXtXeSpBXlL3IlqSOGviR1xNCXpI4Y+pLUEUNfkjpi6EtSRwx9SeqIoS9JHTH0Jakjhr4kdcTQl6SOLHnvHb19bD/4pbXeBUlvcR7pS1JHPNKX9Ja3Vv+Kfe6uG9bkfVeTR/qS1BFDX5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjow89JPsSfJ0kukkB0f9/pLUs5GGfpLLgP8MXA9cBXw0yVWj3AdJ6tmob8NwNTBdVc8CJHkA2Ac8NeL9WFXe+Ey6NKzl3+XVugXEqE/vbAFODyyfaTVJ0gi85W64luQAcKAt/nWSp9dyf1bYFcBfrvVOrDHnwDkA52DeovOQu9/Udv/xYitGHfpngW0Dy1tb7TVVdQg4NMqdGpUkU1U1vtb7sZacA+cAnIN5azEPoz698xiwI8mVSS4HbgaOjXgfJKlbIz3Sr6pXk/xb4GHgMuBwVT05yn2QpJ6N/Jx+VT0EPDTq932LuCRPWy2Tc+AcgHMwb+TzkKoa9XtKktaIt2GQpI4Y+qsgyX9M8u0kJ5J8Mcn6gXWfbLegeDrJdQP1S+72FEk+kuTJJH+XZPyidd3Mw6BLfXzzkhxOcj7JtwZqG5NMJjnVnje0epLc0+bkRJKda7fnKyfJtiRfSfJU+3vwsVZf23moKh8r/AB2A+ta+27g7ta+CvgL4J3AlcAzzH2hfVlr/zRweetz1VqPYwXm4Z8A7wO+CowP1Luah4FxX9Lju2isvwTsBL41UPsPwMHWPjjw92Iv8L+AALuAR9d6/1doDjYDO1v7p4D/0/7sr+k8eKS/CqrqT6vq1bb4CHO/R4C5W048UFWvVNV3gGnmbk3x2u0pqupvgfnbU7ytVdXJqlrox3VdzcOAS318r6mqPwdmLyrvA4609hHgxoH6/TXnEWB9ks0j2dFVVFXnqurrrf1XwEnm7kCwpvNg6K++f83cpzcsfhuK3m5P0es8XOrjW8qmqjrX2i8Am1r7kp+XJNuBDwKPssbz8Ja7DcPbRZI/A/7RAqs+VVVHW59PAa8Cnxvlvo3SMPMgXayqKkkXlw4m+UngD4GPV9X3k7y2bi3mwdB/g6rqn7/e+iT/CvhV4NpqJ+x4/dtQvO7tKd6qlpqHRVxy8zCkJW9Dcol7McnmqjrXTlucb/VLdl6SvIO5wP9cVf1RK6/pPHh6ZxUk2QP8FvDhqnp5YNUx4OYk70xyJbAD+Br93Z6i13m41Me3lGPARGtPAEcH6re0q1d2AS8NnP5428rcIf19wMmq+p2BVWs7D2v9Dfel+GDui8nTwBPt8V8G1n2KuSs4ngauH6jvZe7b/WeYOzWy5uNYgXn4NebOS74CvAg83OM8XDQnl/T4Bsb5eeAc8P/an4H9wHuA48Ap4M+Aja1vmPvPlZ4BvsnAlV5v5wfwi0ABJwayYO9az4O/yJWkjnh6R5I6YuhLUkcMfUnqiKEvSR0x9CWpI4a+JHXE0Jekjhj6ktSR/w9X/xMeWORJyQAAAABJRU5ErkJggg==", 211 | "text/plain": [ 212 | "
" 213 | ] 214 | }, 215 | "metadata": { 216 | "needs_background": "light" 217 | }, 218 | "output_type": "display_data" 219 | } 220 | ], 221 | "source": [ 222 | "plt.hist(SpyTorch_weights[0].detach().numpy().flatten() * wgt_scale)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 9, 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "# Reshape weights\n", 232 | "spy_weights = []\n", 233 | "\n", 234 | "spy_weights.append([]) # flatten layer\n", 235 | "\n", 236 | "spy_weights.append(np.reshape(\n", 237 | " np.transpose(\n", 238 | " SpyTorch_weights[0].detach().numpy()\n", 239 | " ), \n", 240 | " (450, 24*nb_input_copies, 1, 1, 1)) * wgt_scale)\n", 241 | "\n", 242 | "spy_weights.append(np.reshape(\n", 243 | " np.transpose(\n", 244 | " SpyTorch_weights[1].detach().numpy()\n", 245 | " ),\n", 246 | " (28, 450, 1, 1, 1)) * wgt_scale)" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 10, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "# Describe the network\n", 256 | "netDesc = {\n", 257 | " 'simulation' : {\n", 258 | " 'Ts': 1,\n", 259 | " 'tSample': time_bins,\n", 260 | " },\n", 261 | " 'neuron' : {\n", 262 | " 'type' : 'LOIHI',\n", 263 | " 'vThMant' : vThMant,\n", 264 | " 'vDecay' : vDecay,\n", 265 | " 'iDecay' : iDecay,\n", 266 | " 'refDelay' : refDelay,\n", 267 | " 'wgtExp' : wgtExp,\n", 268 | " 'tauRho' : 1, # useless in inference\n", 269 | " 'scaleRho' : 1, # useless in inference\n", 270 | " },\n", 271 | " 'layer' : [\n", 272 | " {'dim' : \"'\" + str(24*nb_input_copies) + \"x1x1\"}, # Width x Height x Channels\n", 273 | " {'dim' : 450, 'delay' : False},\n", 274 | " {'dim' : 28, 'delay' : False}\n", 275 | " ]\n", 276 | "}\n", 277 | "\n", 278 | "netParams = params(dict=netDesc)" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 11, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "class recurrentBlock(torch.nn.Module):\n", 288 | " def __init__(self, slayer, inFeatures, outFeatures, weightScale, \n", 289 | " preHookFx = lambda x: utils.quantize(x, step=qtz_step), weightNorm=False, \n", 290 | " delay=False, maxDelay=62, countLog=False):\n", 291 | " super(recurrentBlock, self).__init__()\n", 292 | " self.slayer = slayer\n", 293 | " self.weightNorm = weightNorm\n", 294 | " if weightNorm is True:\n", 295 | " self.weightOp = torch.nn.utils.weight_norm(slayer.dense(\n", 296 | " inFeatures, outFeatures, weightScale, preHookFx), name='weight')\n", 297 | " self.recWeightOp = torch.nn.utils.weight_norm(slayer.dense(\n", 298 | " outFeatures, outFeatures, weightScale, preHookFx), name='recWeight')\n", 299 | " else:\n", 300 | " self.weightOp = slayer.dense(inFeatures, outFeatures, weightScale, preHookFx)\n", 301 | " self.recWeightOp = slayer.dense(outFeatures, outFeatures, weightScale, preHookFx)\n", 302 | " self.delayOp = slayer.delay(outFeatures) if delay is True else None\n", 303 | " self.countLog = countLog\n", 304 | " self.gradLog = True\n", 305 | " self.maxDelay = maxDelay\n", 306 | " \n", 307 | " self.paramsDict = {\n", 308 | " 'inFeatures' : inFeatures,\n", 309 | " 'outFeatures' : outFeatures,\n", 310 | " }\n", 311 | " \n", 312 | " def forward(self, spike):\n", 313 | " spike = self.slayer.spikeLoihi(self.weightOp(spike) + self.recWeightOp(spike))\n", 314 | " spike = self.slayer.delayShift(spike, 1)\n", 315 | " if self.delayOp is not None:\n", 316 | " spike = self.delayOp(spike)\n", 317 | " if self.countLog is True:\n", 318 | " return spike, torch.sum(spike)\n", 319 
| " else:\n", 320 | " return spike" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 12, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "class SpyTorch2Loihi(Network):\n", 330 | " \n", 331 | " def __init__(self, netParams, weights):\n", 332 | " super(SpyTorch2Loihi, self).__init__(netParams)\n", 333 | " self.preHookFx = lambda x: utils.quantize(x, step=qtz_step)\n", 334 | " \n", 335 | " # Load the weights trained in spytorch\n", 336 | " self.weights = weights\n", 337 | " self.recWeights = None\n", 338 | " \n", 339 | " \n", 340 | " def _layerType(self, dim):\n", 341 | " if type(dim) is int:\n", 342 | " return 'dense'\n", 343 | " elif dim.find('c') != -1:\n", 344 | " return 'conv'\n", 345 | " elif dim.find('avg') != -1:\n", 346 | " return 'average'\n", 347 | " elif dim.find('a') != -1:\n", 348 | " return 'pool'\n", 349 | " elif dim.find('x') != -1:\n", 350 | " return 'input'\n", 351 | " elif dim.find('r') != -1:\n", 352 | " return 'recurrent'\n", 353 | " else:\n", 354 | " raise Exception('Could not parse the layer description. Found {}'.format(dim))\n", 355 | " # return [int(i) for i in re.findall(r'\\d+', dim)]\n", 356 | "\n", 357 | " \n", 358 | " def _parseLayers(self):\n", 359 | " i = 0\n", 360 | " blocks = torch.nn.ModuleList()\n", 361 | " layerDim = [] # CHW\n", 362 | " is1Dconv = False\n", 363 | "\n", 364 | " print('\\nNetwork Architecture:')\n", 365 | " # print('=====================')\n", 366 | " print(self._tableStr(header=True))\n", 367 | "\n", 368 | " for layer in self.netParams['layer']:\n", 369 | " layerType = self._layerType(layer['dim'])\n", 370 | " # print(i, layerType)\n", 371 | "\n", 372 | " # If layer has neuron feild, then use the slayer initialized with it and self.netParams['simulation']\n", 373 | " if 'neuron' in layer.keys():\n", 374 | " print(layerType, 'using individual slayer')\n", 375 | " slayer = loihi(layer['neuron'], self.netParams['simulation'])\n", 376 | " else:\n", 377 | " slayer = self.slayer\n", 378 | "\n", 379 | " if i==0 and self.inputShape is None: \n", 380 | " if layerType == 'input':\n", 381 | " self.inputShape = tuple([int(numStr) for numStr in re.findall(r'\\d+', layer['dim'])])\n", 382 | " if len(self.inputShape) == 3:\n", 383 | " layerDim = list(self.inputShape)[::-1]\n", 384 | " elif len(self.inputShape) == 2:\n", 385 | " layerDim = [1, self.inputShape[1], self.inputShape[0]]\n", 386 | " else:\n", 387 | " raise Exception('Could not parse the input dimension. Got {}'.format(self.inputShape))\n", 388 | " elif layerType == 'dense':\n", 389 | " self.inputShape = tuple([layer['dim']])\n", 390 | " layerDim = [layer['dim'], 1, 1]\n", 391 | " else:\n", 392 | " raise Exception('Input dimension could not be determined! 
It should be the first entry in the' \n", 393 | " + \"'layer' feild.\")\n", 394 | " # print(self.inputShape)\n", 395 | " print(self._tableStr('Input', layerDim[2], layerDim[1], layerDim[0]))\n", 396 | " if layerDim[1] == 1:\n", 397 | " is1Dconv = True\n", 398 | " else:\n", 399 | " # print(i, layer['dim'], self._layerType(layer['dim']))\n", 400 | " if layerType == 'conv':\n", 401 | " params = [int(i) for i in re.findall(r'\\d+', layer['dim'])]\n", 402 | " inChannels = layerDim[0]\n", 403 | " outChannels = params[0]\n", 404 | " kernelSize = params[1]\n", 405 | " stride = layer['stride'] if 'stride' in layer.keys() else 1\n", 406 | " padding = layer['padding'] if 'padding' in layer.keys() else kernelSize//2\n", 407 | " dilation = layer['dilation'] if 'dilation' in layer.keys() else 1\n", 408 | " groups = layer['groups'] if 'groups' in layer.keys() else 1\n", 409 | " weightScale = layer['wScale'] if 'wScale' in layer.keys() else 100\n", 410 | " delay = layer['delay'] if 'delay' in layer.keys() else False\n", 411 | " maxDelay = layer['maxDelay'] if 'maxDelay' in layer.keys() else 62\n", 412 | " # print(i, inChannels, outChannels, kernelSize, stride, padding, dilation, groups, weightScale)\n", 413 | " \n", 414 | " if is1Dconv is False:\n", 415 | " blocks.append(convBlock(slayer, inChannels, outChannels, kernelSize, stride, padding, \n", 416 | " dilation, groups, weightScale, self.preHookFx, self.weightNorm, \n", 417 | " delay, maxDelay, self.countLog))\n", 418 | " layerDim[0] = outChannels\n", 419 | " layerDim[1] = int(np.floor((layerDim[1] + 2*padding - dilation * (kernelSize - 1) - 1)/stride + 1))\n", 420 | " layerDim[2] = int(np.floor((layerDim[2] + 2*padding - dilation * (kernelSize - 1) - 1)/stride + 1))\n", 421 | " else:\n", 422 | " blocks.append(convBlock(slayer, inChannels, outChannels, [1, kernelSize], [1, stride], [0, padding], \n", 423 | " [1, dilation], groups, weightScale, self.preHookFx, self.weightNorm, \n", 424 | " delay, maxDelay, self.countLog))\n", 425 | " layerDim[0] = outChannels\n", 426 | " layerDim[1] = 1\n", 427 | " layerDim[2] = int(np.floor((layerDim[2] + 2*padding - dilation * (kernelSize - 1) - 1)/stride + 1))\n", 428 | " self.layerDims.append(layerDim.copy())\n", 429 | "\n", 430 | " print(self._tableStr('Conv', layerDim[2], layerDim[1], layerDim[0], kernelSize, stride, padding, \n", 431 | " delay, sum(p.numel() for p in blocks[-1].parameters() if p.requires_grad)))\n", 432 | " elif layerType == 'pool':\n", 433 | " params = [int(i) for i in re.findall(r'\\d+', layer['dim'])]\n", 434 | " # print(params[0])\n", 435 | " \n", 436 | " blocks.append(poolBlock(slayer, params[0], countLog=self.countLog))\n", 437 | " layerDim[1] = int(np.ceil(layerDim[1] / params[0]))\n", 438 | " layerDim[2] = int(np.ceil(layerDim[2] / params[0]))\n", 439 | " self.layerDims.append(layerDim.copy())\n", 440 | "\n", 441 | " print(self._tableStr('Pool', layerDim[2], layerDim[1], layerDim[0], params[0]))\n", 442 | " elif layerType == 'dense':\n", 443 | " params = layer['dim']\n", 444 | " # print(params)\n", 445 | " if layerDim[1] != 1 or layerDim[2] != 1: # needs flattening of layers\n", 446 | " blocks.append(flattenBlock(self.countLog ))\n", 447 | " layerDim[0] = layerDim[0] * layerDim[1] * layerDim[2]\n", 448 | " layerDim[1] = layerDim[2] = 1\n", 449 | " self.layerDims.append(layerDim.copy())\n", 450 | " weightScale = layer['wScale'] if 'wScale' in layer.keys() else 100\n", 451 | " delay = layer['delay'] if 'delay' in layer.keys() else False\n", 452 | " maxDelay = layer['maxDelay'] if 'maxDelay' 
in layer.keys() else 62\n", 453 | " \n", 454 | " blocks.append(denseBlock(slayer, layerDim[0], params, weightScale, self.preHookFx, \n", 455 | " self.weightNorm, delay, maxDelay, self.countLog))\n", 456 | " layerDim[0] = params\n", 457 | " layerDim[1] = layerDim[2] = 1\n", 458 | " self.layerDims.append(layerDim.copy())\n", 459 | "\n", 460 | " print(self._tableStr('Dense', layerDim[2], layerDim[1], layerDim[0], delay=delay, \n", 461 | " numParams=sum(p.numel() for p in blocks[-1].parameters() if p.requires_grad)))\n", 462 | " elif layerType == 'recurrent':\n", 463 | " #params = layer['dim']\n", 464 | " params = [int(i) for i in re.findall(r'\\d+', layer['dim'])]\n", 465 | " # print(params)\n", 466 | " if layerDim[1] != 1 or layerDim[2] != 1: # needs flattening of layers\n", 467 | " blocks.append(flattenBlock(self.countLog ))\n", 468 | " layerDim[0] = layerDim[0] * layerDim[1] * layerDim[2]\n", 469 | " layerDim[1] = layerDim[2] = 1\n", 470 | " self.layerDims.append(layerDim.copy())\n", 471 | " weightScale = layer['wScale'] if 'wScale' in layer.keys() else 100\n", 472 | " delay = layer['delay'] if 'delay' in layer.keys() else False\n", 473 | " maxDelay = layer['maxDelay'] if 'maxDelay' in layer.keys() else 62\n", 474 | " \n", 475 | " blocks.append(recurrentBlock(slayer, layerDim[0], params[0], weightScale, self.preHookFx, \n", 476 | " self.weightNorm, delay, maxDelay, self.countLog))\n", 477 | " layerDim[0] = params[0]\n", 478 | " layerDim[1] = layerDim[2] = 1\n", 479 | " self.layerDims.append(layerDim.copy())\n", 480 | "\n", 481 | " print(self._tableStr('Recurrent', layerDim[2], layerDim[1], layerDim[0], delay=delay, \n", 482 | " numParams=sum(p.numel() for p in blocks[-1].parameters() if p.requires_grad)))\n", 483 | " elif layerType == 'average':\n", 484 | " params = [int(i) for i in re.findall(r'\\d+', layer['dim'])]\n", 485 | " layerDim[0] = params[0]\n", 486 | " layerDim[1] = layerDim[2] = 1\n", 487 | " self.layerDims.append(layerDim.copy())\n", 488 | "\n", 489 | " blocks.append(averageBlock(nOutputs=layerDim[0], countLog=self.countLog))\n", 490 | " print(self._tableStr('Average', 1, 1, params[0]))\n", 491 | "\n", 492 | " i += 1\n", 493 | " self.nOutput = layerDim[0] * layerDim[1] * layerDim[2]\n", 494 | " print(self._tableStr(numParams=sum(p.numel() for p in blocks.parameters() if p.requires_grad), footer=True))\n", 495 | " return blocks\n", 496 | " \n", 497 | " \n", 498 | " def genSpyModel(self, fname):\n", 499 | " qWeights = lambda x: self.preHookFx(x).cpu().data.numpy().squeeze()\n", 500 | "\n", 501 | " h = h5py.File(fname, 'w')\n", 502 | "\n", 503 | " simulation = h.create_group('simulation')\n", 504 | "\n", 505 | " for key, value in self.netParams['simulation'].items():\n", 506 | " # print(key, value)\n", 507 | " simulation[key] = value\n", 508 | "\n", 509 | " layer = h.create_group('layer')\n", 510 | " layer.create_dataset('0/type', (1, ), 'S10', [b'input'])\n", 511 | " layer.create_dataset('0/shape', data=np.array([self.inputShape[2], self.inputShape[1], self.inputShape[0]]))\n", 512 | " \n", 513 | " for i, block in enumerate(self.blocks):\n", 514 | " print(\"\\nblock %d / %d\" % (i, len(self.blocks)))\n", 515 | " layerType = block.__class__.__name__[:-5]\n", 516 | " \n", 517 | " print(layerType.encode('ascii', 'ignore'))\n", 518 | " layer.create_dataset('{}/type'.format(i+1), (1, ), 'S10', [layerType.encode('ascii', 'ignore')])\n", 519 | " \n", 520 | " print(i, self.layerDims[i])\n", 521 | " layer.create_dataset('{}/shape'.format(i+1), data=np.array(self.layerDims[i]))\n", 522 | 
" \n", 523 | " if layerType != 'flatten':\n", 524 | " layer.create_dataset('{}/weight'.format(i+1), data=qWeights(torch.Tensor(self.weights[i])))\n", 525 | " if layerType == 'recurrent':\n", 526 | " layer.create_dataset('{}/recWeight'.format(i+1), data=qWeights(torch.Tensor(self.recWeights)))\n", 527 | " \n", 528 | " for key, param in block.paramsDict.items():\n", 529 | " layer.create_dataset('{}/{}'.format(i+1, key), data=param)\n", 530 | " \n", 531 | " if layerType != 'flatten' and layerType != 'average':\n", 532 | " for key, value in block.slayer.neuron.items():\n", 533 | " # print(i, key, value)\n", 534 | " layer.create_dataset('{}/neuron/{}'.format(i+1, key), data=value)\n", 535 | " h.close()" 536 | ] 537 | }, 538 | { 539 | "cell_type": "code", 540 | "execution_count": 13, 541 | "metadata": {}, 542 | "outputs": [ 543 | { 544 | "name": "stdout", 545 | "output_type": "stream", 546 | "text": [ 547 | "simulation:\n", 548 | " Ts : 1\n", 549 | " tSample : 270\n", 550 | "\n", 551 | "neuron:\n", 552 | " type : LOIHI\n", 553 | " vThMant : 99\n", 554 | " vDecay : 282\n", 555 | " iDecay : 2090\n", 556 | " refDelay : 1\n", 557 | " wgtExp : 0\n", 558 | " tauRho : 1\n", 559 | " scaleRho : 1\n", 560 | "\n", 561 | "Max PSP kernel: 99.0\n", 562 | "Scaling neuron[scaleRho] by Max PSP Kernel @slayerLoihi\n", 563 | "\n", 564 | "Network Architecture:\n", 565 | "| Type | W | H | C | ker | str | pad |delay| params |\n", 566 | "|Input | 48| 1| 1| | | |False| |\n", 567 | "|Dense | 1| 1| 450| | | |False| 21600|\n", 568 | "|Dense | 1| 1| 28| | | |False| 12600|\n", 569 | "|Total | 34200|\n", 570 | "TODO core usage estimator\n" 571 | ] 572 | } 573 | ], 574 | "source": [ 575 | "# Create the network\n", 576 | "netLoihi = SpyTorch2Loihi(netParams, spy_weights)" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 14, 582 | "metadata": {}, 583 | "outputs": [ 584 | { 585 | "name": "stdout", 586 | "output_type": "stream", 587 | "text": [ 588 | "\n", 589 | "block 0 / 3\n", 590 | "b'flatten'\n", 591 | "0 [48, 1, 1]\n", 592 | "\n", 593 | "block 1 / 3\n", 594 | "b'dense'\n", 595 | "1 [450, 1, 1]\n", 596 | "\n", 597 | "block 2 / 3\n", 598 | "b'dense'\n", 599 | "2 [28, 1, 1]\n" 600 | ] 601 | } 602 | ], 603 | "source": [ 604 | "# Export the model\n", 605 | "netLoihi.genSpyModel('../netsLoihi/netLoihi_fwd_th' + str(threshold) + run + '.net')" 606 | ] 607 | }, 608 | { 609 | "cell_type": "code", 610 | "execution_count": null, 611 | "metadata": {}, 612 | "outputs": [], 613 | "source": [] 614 | } 615 | ], 616 | "metadata": { 617 | "interpreter": { 618 | "hash": "9bf43d2b4a4b64acce80ec436e45972ca3a2814376cdde651dbbddec46144e4b" 619 | }, 620 | "kernelspec": { 621 | "display_name": "Python 3.8.10 64-bit ('pyenv_pytorch')", 622 | "language": "python", 623 | "name": "python3" 624 | }, 625 | "language_info": { 626 | "codemirror_mode": { 627 | "name": "ipython", 628 | "version": 3 629 | }, 630 | "file_extension": ".py", 631 | "mimetype": "text/x-python", 632 | "name": "python", 633 | "nbconvert_exporter": "python", 634 | "pygments_lexer": "ipython3", 635 | "version": "3.8.10" 636 | } 637 | }, 638 | "nbformat": 4, 639 | "nbformat_minor": 2 640 | } 641 | -------------------------------------------------------------------------------- /spytorch2loihi/weights/SpyTorch_trained_weights_rec_th1_6.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/event-driven-robotics/tactile_braille_reading/e1e2d65ad6761ea814fc6883b0315ed455f7c6c0/spytorch2loihi/weights/SpyTorch_trained_weights_rec_th1_6.pt
--------------------------------------------------------------------------------
/utils/event_transform.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (C) 2021
3 | Authors: Alejandro Pequeno-Zurro
4 | 
5 | This program is free software: you can redistribute it and/or modify it under
6 | the terms of the GNU General Public License as published by the Free Software
7 | Foundation, either version 3 of the License, or (at your option) any later version.
8 | This program is distributed in the hope that it will be useful, but WITHOUT ANY
9 | WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
10 | PARTICULAR PURPOSE. See the GNU General Public License for more details.
11 | You should have received a copy of the GNU General Public License along with
12 | this program. If not, see <https://www.gnu.org/licenses/>.
13 | '''
14 | import os
15 | import numpy as np
16 | import matplotlib.pyplot as plt
17 | import pickle
18 | from scipy.signal import argrelextrema
19 | from scipy.interpolate import interp1d
20 | 
21 | DEBUG = False
22 | DirOut = "../plots/"
23 | save_fig = False
24 | 
25 | def sample_to_changes(sample, f, threshold, save):
26 |     ''' Convert one time-based sample to event-based spike trains.
27 |     sample: time-based sample, array of shape (time, taxels)
28 |     f: sampling frequency of the time-based sequence (Hz)
29 |     threshold: change in sensor value that triggers an event
30 |     Finds the local maxima and minima of the sequence and applies linear interpolation
31 |     in time, based on the threshold, to find the corresponding event times.
32 |     '''
33 |     Precision = 4  # Round to fix float errors from np.arange (e.g. 29.000000000000014)
34 |     n = sample.shape[0]
35 |     dt = 1/f
36 |     taxel_samples = np.transpose(sample, (1, 0)).tolist()
37 |     sample_list = list()
38 |     for nt, taxel in enumerate(taxel_samples):
39 |         # Find indices in the sequence with local maxima and minima to apply interpolation.
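        # Editorial aside -- a small worked example of the scheme implemented
        # below, assuming threshold = 2: a taxel trace [0, 1, 5, 2] has extrema
        # values [0, 5, 2] at indices [0, 2, 3]. On the rising pair 0 -> 5 the
        # trace crosses the levels 2 and 4 (last_value + k*threshold), giving
        # ON events at the interpolated times 1.25*dt and 1.75*dt; on the
        # falling pair 5 -> 2 it crosses level 2 (last spike level 4 minus
        # threshold), giving one OFF event at 3*dt.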
40 |         txl = np.array(taxel, dtype=int)
41 |         # max
42 |         ind_max = np.squeeze(np.array(argrelextrema(txl, np.greater_equal)))
43 |         d_ixtr = np.insert(np.diff(ind_max), 0, -1)  # Prepend -1 to match lengths; keep the first index of each consecutive run
44 |         max_p = ind_max[d_ixtr != 1]
45 |         # min
46 |         ind_min = np.squeeze(np.array(argrelextrema(txl, np.less_equal)))
47 |         d_ixtr = np.insert(np.diff(ind_min), 0, -1)  # Prepend -1 to match lengths; keep the first index of each consecutive run
48 |         min_p = ind_min[d_ixtr != 1]
49 |         # Extend plateaus: also index the next sample when its value is unchanged
50 |         all_indx = np.append(max_p, min_p)
51 |         i = 0
52 |         while i < len(all_indx):
53 |             try:
54 |                 ival = all_indx[i]
55 |                 if txl[ival + 1] - txl[ival] == 0:
56 |                     all_indx = np.append(all_indx, np.array(ival + 1))
57 |             except IndexError:
58 |                 pass
59 |             i += 1
60 |         # Corresponding values in the sequence
61 |         all_t = np.unique(np.sort(all_indx))
62 |         all_values = txl[all_t]
63 |         # Find the events [ON, OFF]
64 |         taxel_list = list()
65 |         on_events = np.array([]); off_events = np.array([])
66 |         # Compare each pair of extrema and generate event times based on the threshold
67 |         last_value = all_values[0]  # Value at the last event; lets sub-threshold changes accumulate across extrema
68 |         for i in range(len(all_values) - 1):
69 |             d_pair = all_values[i+1] - last_value
70 |             if d_pair > 0:
71 |                 start = last_value + threshold
72 |                 stop = all_values[i+1] + 0.0001  # Epsilon so the extremum value itself can elicit an event
73 |                 spk_values = np.round(np.arange(start, stop, threshold), Precision)
74 |                 # Interpolation with all the values of the pair
75 |                 pts = all_t[i+1] - all_t[i] + 1
76 |                 t_interp = np.linspace(all_t[i], all_t[i+1], pts, dtype=int)
77 |                 vals_interp = txl[t_interp]
78 |                 f_interp = interp1d(vals_interp, t_interp.astype(float), 'linear')  # value -> time; renamed to avoid shadowing the frequency f
79 |                 on_events = np.append(on_events, np.apply_along_axis(f_interp, 0, spk_values))
80 |                 last_value = spk_values[-1] if spk_values.size > 0 else last_value  # Update the reference value only when a spike occurred
81 |             elif d_pair < 0:
82 |                 start = last_value - threshold
83 |                 stop = all_values[i+1] - 0.0001  # Epsilon so the extremum value itself can elicit an event
84 |                 spk_values = np.round(np.arange(start, stop, -1*threshold), Precision)
85 |                 # Interpolation with all the values of the pair
86 |                 pts = all_t[i+1] - all_t[i] + 1
87 |                 t_interp = np.linspace(all_t[i], all_t[i+1], pts, dtype=int)
88 |                 vals_interp = txl[t_interp]
89 |                 f_interp = interp1d(vals_interp, t_interp.astype(float), 'linear')
90 |                 off_events = np.append(off_events, np.apply_along_axis(f_interp, 0, spk_values))
91 |                 last_value = spk_values[-1] if spk_values.size > 0 else last_value  # Update the reference value only when a spike occurred
92 |         # Assign events
93 |         taxel_list.append((on_events * dt).tolist())
94 |         taxel_list.append((off_events * dt).tolist())
95 |         sample_list.append(taxel_list)
96 |         # Plot conversions. Run in debug mode
97 |         if DEBUG:
98 |             plt.rcParams['text.usetex'] = True
99 |             f1 = plt.figure()
100 |             axes = plt.axes()
101 |             n = len(txl)
102 |             scale = 1/5
103 |             axes.set_xlim([0, ((scale * n) - 0.5) * dt])
104 |             axes.set_ylim([-0.5, 0.5])
105 |             if taxel_list[0]:
106 |                 plt.eventplot(taxel_list[0], lineoffsets=0.15,
107 |                               colors='green', linelength=0.25)
108 |             if taxel_list[1]:
109 |                 plt.eventplot(taxel_list[1], lineoffsets=-0.15,
110 |                               colors='red', linelength=0.25)
111 | 
112 |             axes.set_ylabel(r'$\vartheta = ${}'.format(str(threshold)))
113 |             if save:
114 |                 plt.savefig('{}encoding_TH{}_taxel_{}_events.png'.format(DirOut, str(threshold), str(nt)), dpi=200)
115 |             f2 = plt.figure()
116 |             axes = plt.axes()
117 |             axes.set_xlim([0, ((scale * n) - 0.5) * dt])
118 |             plt.plot(np.arange(start=0, stop=(n - 0.5) * dt, step=dt), txl - txl[0], '-o')
119 |             axes.set_ylabel("Sensor value")
120 |             axes.set_xlabel('t(s)')
121 |             if save:
122 |                 plt.savefig('{}encoding_TH{}_taxel_{}_sample.png'.format(DirOut, str(threshold), str(nt)), dpi=200)
123 | 
124 |     return sample_list
125 | 
126 | def extract_data_icub_raw_integers(file_name):
127 |     ''' Read a raw dataset file and extract taxel data and labels.
128 |     file_name: filename of the dataset, a pickled list of dicts with keys 'taxel_data' and 'letter'
129 |     '''
130 |     data = []
131 |     labels = []
132 |     print("file name {}".format(file_name))
133 |     with open(file_name, 'rb') as infile:
134 |         data_dict = pickle.load(infile)
135 |     for item in data_dict:
136 |         dat = np.abs(255 - item['taxel_data'][:])  # Invert the 8-bit raw values
137 |         data.append(dat)
138 |         labels.append(item['letter'])
139 |     return data, labels
140 | 
141 | def main():
142 |     ''' Convert time-based data into event-based data '''
143 |     Spk_threshold = 1  # 1, 2, 5, 10 (default)
144 |     Events_filename_out = '../data/data_braille_letters_th{}'.format(str(Spk_threshold))
145 |     f = 40  # Hz
146 |     data_raw, labels_raw = extract_data_icub_raw_integers('../data/data_braille_letters_raw')
147 |     samples = list()
148 |     if save_fig:
149 |         isExist = os.path.exists(DirOut)
150 |         if not isExist:
151 |             os.makedirs(DirOut)
152 |     # Convert each time-based sequence sample to events
153 |     for sample_raw, label in zip(data_raw, labels_raw):
154 |         data_dict_events = {}
155 |         events_per_samples = sample_to_changes(sample_raw, f, Spk_threshold, save=save_fig)
156 |         # Dict of the sample
157 |         data_dict_events['letter'] = label
158 |         data_dict_events['events'] = events_per_samples
159 |         samples.append(data_dict_events)
160 |     print('Finished conversion')
161 |     with open(Events_filename_out, 'wb') as outf:
162 |         pickle.dump(samples, outf)
163 | 
164 | if __name__ == "__main__":
165 |     main()
--------------------------------------------------------------------------------
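As a closing note on how these converted event files plug into the training notebooks: below is a minimal sketch (ours, not part of the repository) that loads a converted file and bins the per-taxel ON/OFF event times back into a dense binary spike array of the kind the SNN notebooks consume. The 5 ms bin width, the 1.35 s duration, and the helper name `events_to_dense` are illustrative assumptions; only the file path and the `{'letter', 'events'}` dict layout come from the code above.

    import pickle
    import numpy as np

    def events_to_dense(events, duration_s, bin_size_s=0.005):
        # events: list over taxels; events[t] == [on_times, off_times] in
        # seconds, as produced by sample_to_changes(). Returns a binary
        # array of shape (n_bins, 2 * n_taxels), one column per polarity.
        n_bins = int(np.ceil(duration_s / bin_size_s))
        dense = np.zeros((n_bins, 2 * len(events)), dtype=np.uint8)
        for taxel, (on_times, off_times) in enumerate(events):
            for polarity, times in enumerate((on_times, off_times)):
                idx = np.floor(np.asarray(times) / bin_size_s).astype(int)
                dense[idx[idx < n_bins], 2 * taxel + polarity] = 1
        return dense

    with open('../data/data_braille_letters_th1', 'rb') as infile:
        samples = pickle.load(infile)

    sample = samples[0]
    spikes = events_to_dense(sample['events'], duration_s=1.35)  # duration assumed
    print(sample['letter'], spikes.shape)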