├── .gitignore ├── neurons ├── __init__.py ├── base.py ├── lif.py └── spsn.py ├── outputs ├── basic │ ├── model_LIF_1675938986.pt │ ├── model_SPSN-GS_1675890063.pt │ ├── model_SPSN-SB_1675889673.pt │ ├── model_Non-Spiking_1675888688.pt │ ├── results_LIF_1675938986.json │ ├── results_SPSN-GS_1675890063.json │ ├── results_SPSN-SB_1675889673.json │ └── results_Non-Spiking_1675888688.json ├── data_aug │ ├── model_LIF_1675940888.pt │ ├── model_SPSN-GS_1675891161.pt │ ├── model_SPSN-SB_1675890736.pt │ ├── model_Non-Spiking_1675889858.pt │ ├── results_Non-Spiking_1675889858.json │ ├── results_LIF_1675940888.json │ ├── results_SPSN-SB_1675890736.json │ └── results_SPSN-GS_1675891161.json └── regul │ ├── model_SPSN-GS_1675891157.pt │ ├── model_SPSN-SB_1675890850.pt │ ├── results_SPSN-GS_1675891157.json │ └── results_SPSN-SB_1675890850.json ├── README.md ├── datasets.py ├── run.py └── network.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets -------------------------------------------------------------------------------- /neurons/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | -------------------------------------------------------------------------------- /outputs/basic/model_LIF_1675938986.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/basic/model_LIF_1675938986.pt -------------------------------------------------------------------------------- /outputs/basic/model_SPSN-GS_1675890063.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/basic/model_SPSN-GS_1675890063.pt -------------------------------------------------------------------------------- 
/outputs/basic/model_SPSN-SB_1675889673.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/basic/model_SPSN-SB_1675889673.pt -------------------------------------------------------------------------------- /outputs/data_aug/model_LIF_1675940888.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/data_aug/model_LIF_1675940888.pt -------------------------------------------------------------------------------- /outputs/regul/model_SPSN-GS_1675891157.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/regul/model_SPSN-GS_1675891157.pt -------------------------------------------------------------------------------- /outputs/regul/model_SPSN-SB_1675890850.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/regul/model_SPSN-SB_1675890850.pt -------------------------------------------------------------------------------- /outputs/data_aug/model_SPSN-GS_1675891161.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/data_aug/model_SPSN-GS_1675891161.pt -------------------------------------------------------------------------------- /outputs/data_aug/model_SPSN-SB_1675890736.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/data_aug/model_SPSN-SB_1675890736.pt 
-------------------------------------------------------------------------------- /outputs/basic/model_Non-Spiking_1675888688.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/basic/model_Non-Spiking_1675888688.pt -------------------------------------------------------------------------------- /outputs/data_aug/model_Non-Spiking_1675889858.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NECOTIS/Stochastic-Parallelizable-Spiking-Neuron-SPSN/HEAD/outputs/data_aug/model_Non-Spiking_1675889858.pt -------------------------------------------------------------------------------- /neurons/base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on February 2023 4 | 5 | @author: Anonymous 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | 11 | 12 | 13 | class Base(torch.nn.Module): 14 | """ 15 | Base class for creating a spiking neural network using PyTorch. 
16 | 17 | Parameters: 18 | - input_size (int): size of input tensor 19 | - hidden_size (int): size of hidden layer 20 | - device (torch.device): device to use for tensor computations, such as 'cpu' or 'cuda' 21 | - fire (bool, optional): flag to determine if the neurons should fire spikes or not (default: True) 22 | - tau_mem (float, optional): time constant for the membrane potential (default: 1e-3) 23 | - tau_syn (float, optional): time constant for the synaptic potential (default: 1e-3) 24 | - time_step (float, optional): step size for updating the LIF model (default: 1e-3) 25 | - debug (bool, optional): flag to turn on/off debugging mode (default: False) 26 | """ 27 | def __init__(self, input_size, hidden_size, device, 28 | fire, tau_mem, tau_syn, time_step, debug): 29 | super(Base, self).__init__() 30 | self.input_size = input_size 31 | self.hidden_size = hidden_size 32 | self.device = device 33 | self.v_th = torch.tensor(1.0) 34 | self.fire = fire 35 | self.debug = debug 36 | self.nb_spike_per_neuron = torch.zeros(self.hidden_size, device=self.device) 37 | 38 | # Neuron time constants 39 | self.alpha = float(np.exp(-time_step/tau_syn)) 40 | self.beta = float(np.exp(-time_step/tau_mem)) 41 | self.beta_1 = 1-self.beta 42 | 43 | # Fully connected layer for synapses 44 | self.fc = torch.nn.Linear(self.input_size, self.hidden_size, device=self.device) 45 | 46 | # Initializing weights 47 | torch.nn.init.kaiming_uniform_(self.fc.weight, a=0, mode='fan_in', nonlinearity='linear') 48 | torch.nn.init.zeros_(self.fc.bias) 49 | if self.debug: 50 | torch.nn.init.ones_(self.fc.weight) 51 | 52 | 53 | -------------------------------------------------------------------------------- /neurons/lif.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on February 2023 4 | 5 | @author: Anonymous 6 | """ 7 | 8 | import torch 9 | from neurons.base import Base 10 | 11 | 12 | class LIF(Base): 13 | """ 14 | 
Class for implementing a Leaky Integrate and Fire (LIF) neuron model 15 | 16 | Parameters: 17 | - input_size (int): size of input tensor 18 | - hidden_size (int): size of hidden layer 19 | - device (torch.device): device to use for tensor computations, such as 'cpu' or 'cuda' 20 | - fire (bool, optional): flag to determine if the neurons should fire spikes or not (default: True) 21 | - tau_mem (float, optional): time constant for the membrane potential (default: 1e-3) 22 | - tau_syn (float, optional): time constant for the synaptic potential (default: 1e-3) 23 | - time_step (float, optional): step size for updating the LIF model (default: 1e-3) 24 | - debug (bool, optional): flag to turn on/off debugging mode (default: False) 25 | """ 26 | def __init__(self, input_size, hidden_size, device, 27 | fire=True, tau_mem=1e-3, tau_syn=1e-3, time_step=1e-3, 28 | debug=False): 29 | 30 | super(LIF, self).__init__(input_size, hidden_size, device, 31 | fire, tau_mem, tau_syn, time_step, 32 | debug) 33 | # Set the spiking function 34 | self.spike_fn = SurrGradSpike.apply 35 | 36 | 37 | def forward(self, inputs): 38 | """ 39 | Perform forward pass of the network 40 | 41 | Parameters: 42 | - inputs (tensor): Input tensor with shape (batch_size, nb_steps, input_size) 43 | 44 | Returns: 45 | - Return membrane potential tensor with shape (batch_size, nb_steps, hidden_size) if 'fire' is False 46 | - Return spiking tensor with shape (batch_size, nb_steps, hidden_size) if 'fire' is True 47 | - Return the tuple (spiking tensor, membrane potential tensor) if 'debug' is True 48 | """ 49 | X = self.fc(inputs) 50 | batch_size,nb_steps,_ = X.shape 51 | syn = torch.zeros((batch_size,self.hidden_size), device=self.device) 52 | mem = torch.zeros((batch_size,self.hidden_size), device=self.device) 53 | mem_rec = [] 54 | spk_rec = [] 55 | 56 | # Iterate over each time step 57 | for t in range(nb_steps): 58 | # Integrating input to synaptic current - Equation (5) 59 | syn = self.alpha*syn + X[:,t] 
60 | # Integrating synaptic current to membrane potential - Equation (6) 61 | mem = self.beta*mem + self.beta_1*syn 62 | if self.fire: 63 | # Spikes generation - Equation (3) 64 | spk = self.spike_fn(mem-self.v_th) 65 | spk_rec.append(spk) 66 | # Membrane potential reseting - Equation (6) 67 | mem = mem * (1-spk.detach()) 68 | mem_rec.append(mem) 69 | 70 | mem_rec = torch.stack(mem_rec,dim=1) 71 | if self.fire: 72 | spk_rec = torch.stack(spk_rec,dim=1) 73 | self.nb_spike_per_neuron = torch.mean(torch.mean(spk_rec,dim=0),dim=0) 74 | return (spk_rec, mem_rec) if self.debug else spk_rec 75 | return mem_rec 76 | 77 | 78 | 79 | # Surrogate gradient implementation from https://github.com/fzenke/spytorch/blob/main/notebooks/SpyTorchTutorial1.ipynb 80 | class SurrGradSpike(torch.autograd.Function): 81 | scale = 100.0 82 | @staticmethod 83 | def forward(ctx, input): 84 | ctx.save_for_backward(input) 85 | out = torch.zeros_like(input) 86 | out[input > 0] = 1.0 87 | return out 88 | @staticmethod 89 | def backward(ctx, grad_output): 90 | input, = ctx.saved_tensors 91 | grad_input = grad_output.clone() 92 | grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2 93 | return grad -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Accelerating SNN Training with Stochastic Parallelizable Spiking Neurons (SPSN) 3 | 4 | > [!WARNING] 5 | > SPSN has evolved into the ParaLIF (Parallelizable Leaky-Integrate-and-Fire) neuron. ParaLIF allows more stochastic and deterministic spiking functions. A recurrent version is also available. Visit https://github.com/NECOTIS/Parallelizable-Leaky-Integrate-and-Fire-Neuron 6 | 7 | This repository contains code for simulating the proposed SPSN to accelerate training of spiking neural networks (SNN). The SPSN is compared to Leaky Integrate and Fire (LIF) neuron on the Spiking Heidelberg Digits (SHD) dataset. 
This repository consists of a few key components: 8 | 9 | - `datasets.py`: This module provides a simple interface for loading and accessing training and test datasets. 10 | 11 | - `network.py`: This module contains the implementation of the neural network itself, including code for training and evaluating the network. 12 | 13 | - `run.py`: This is the main entry point for running the simulation. It provides a simple command-line interface for specifying various options. 14 | 15 | - `datasets` directory: This directory contains training and test datasets as hdf5 files. The SHD dataset needs to be downloaded to this directory from https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/ 16 | 17 | - `neurons` directory: This directory contains implementations of the two neuron types, both extending the base class in `base.py`. The available models are: 18 | 19 | - `lif.py`: The Leaky Integrate-and-Fire model 20 | - `spsn.py`: The Stochastic Parallelizable Spiking Neuron model. It can be simulated with the Sigmoid-Bernoulli firing mode (SPSN-SB) or with the Gumbel-Softmax firing mode (SPSN-GS). 21 | 22 | - `outputs` directory: This directory contains outputs generated by the simulation. 23 | 24 | 25 | ## Usage 26 | The `run.py` script can be run using various arguments. The following are available: 27 | 28 | - `--seed`: Random seed for reproducibility. 29 | - `--dataset`: The dataset to use for training, currently only `heidelberg` is supported. 30 | - `--neuron`: The neuron model to use for training, options include `LIF`, `SPSN-SB`, `SPSN-GS`, and `Non-Spiking`. The `Non-Spiking` model replaces the spiking neurons with conventional linear layers followed by ReLU activations. 31 | - `--nb_epochs`: The number of training epochs. 32 | - `--tau_mem`: The neuron membrane time constant. 33 | - `--tau_syn`: The neuron synaptic current time constant. 34 | - `--batch_size`: The batch size for training. 35 | - `--hidden_size`: The number of neurons in each hidden layer.
36 | - `--nb_layers`: The number of hidden layers. 37 | - `--reg_thr`: The spiking frequency regularization threshold. 38 | - `--loss_mode`: The mode for computing the loss, options include `last`, `max`, and `mean`. 39 | - `--data_augmentation`: Whether to use data augmentation during training, options include `True` and `False`. 40 | - `--h_shift`: The random shift factor for data augmentation. 41 | - `--scale`: The random scale factor for data augmentation. 42 | - `--dir`: The directory to save the results. 43 | - `--save_model`: Whether to save the trained model, options include `True` and `False`. 44 | 45 | ### Examples - Basic 46 | To run the code in the basic mode, the following commands can be used. 47 | ```console 48 | python run.py --seed 0 --neuron 'LIF' 49 | python run.py --seed 0 --neuron 'SPSN-SB' 50 | python run.py --seed 0 --neuron 'SPSN-GS' 51 | python run.py --seed 0 --neuron 'Non-Spiking' 52 | ``` 53 | 54 | ### Examples - Data augmentation 55 | To add data augmentation when training, the following commands can be used. 
56 | ```console 57 | python run.py --seed 0 --neuron 'LIF' --data_augmentation True 58 | python run.py --seed 0 --neuron 'SPSN-SB' --data_augmentation True 59 | python run.py --seed 0 --neuron 'SPSN-GS' --data_augmentation True 60 | python run.py --seed 0 --neuron 'Non-Spiking' --data_augmentation True 61 | ``` 62 | 63 | ### Examples - Regularization 64 | To reduce the spiking frequency of SPSN neurons, the spiking-frequency regularization (`--reg_thr`) can be enabled with the following commands: 65 | ```console 66 | python run.py --seed 0 --neuron 'SPSN-SB' --data_augmentation True --reg_thr 0.4 67 | python run.py --seed 0 --neuron 'SPSN-GS' --data_augmentation True --reg_thr 0.1 68 | ``` 69 | 70 | ## Results 71 | 72 | The results achieved with the commands listed above are summarized in the following tables: 73 | 74 | - Classification accuracy on the test set: 75 | 76 | | Neuron| Basic | Data augmentation | Data augmentation + Regularization | 77 | |----------|----------|----------|----------| 78 | | **LIF** | 71.37 % | 83.03 % | - | 79 | | **SPSN-SB** | 77.16 % | 86.08 % | 89.70 % | 80 | | **SPSN-GS** | 75.66 % | 86.08 % | 89.39 % | 81 | | **Non-Spiking** | 71.82 % | 66.07 % | - | 82 | 83 | - Training duration per epoch: 84 | 85 | | Neuron| Basic | Data augmentation | Data augmentation + Regularization | 86 | |----------|----------|----------|----------| 87 | | **LIF** | 252.2 s | 261.7 s | - | 88 | | **SPSN-SB** | 5.5 s | 10.7 s | 10.7 s | 89 | | **SPSN-GS** | 6.8 s | 12.2 s | 12.2 s | 90 | | **Non-Spiking** | 1.5 s | 6.7 s | - | 91 | 92 | 93 | 94 | 95 | ## Requirements 96 | The libraries required to run the code are: 97 | - h5py 98 | - numpy 99 | - torch 100 | - torchvision 101 | - tqdm -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on February 2023 4 | 5 | @author: Anonymous 6 | """ 7 | import os 8 | import h5py 9 | import torch 10
| import numpy as np 11 | import torch.nn.functional as F 12 | import torchvision.transforms as T 13 | 14 | 15 | # The Spiking Heidelberg Digits data set can be found at https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/ 16 | def heidelberg_dataset(window_size=1, device=None, augment=False, h_shift=0., scale=0.): 17 | """ 18 | Loads the Heidelberg dataset and sets up the training and testing datasets, feature and output sizes, and collate functions for data processing. 19 | 20 | Parameters: 21 | - window_size: size of the time window to be used as input to the model 22 | - device: device to use for PyTorch tensor computations (e.g. CPU or GPU) 23 | - augment: boolean flag indicating whether to perform data augmentation on the training set 24 | - h_shift: horizontal shift applied to the data during augmentation 25 | - scale: scale applied to the data during augmentation 26 | 27 | Returns: 28 | - train_set: PyTorch dataset for training 29 | - test_set: PyTorch dataset for testing 30 | - nb_features: number of features in the dataset 31 | - nb_class: number of outputs in the dataset 32 | - collate_fn_train: collate function for processing the training data 33 | - collate_fn_test: collate function for processing the testing data 34 | """ 35 | 36 | nb_features = 700 37 | dt = 1e-3 38 | nb_class = 20 39 | train_set = Dataset_shd('datasets/shd_train.h5', dt, nb_features, window_size=window_size, device=device) 40 | test_set = Dataset_shd('datasets/shd_test.h5', dt, nb_features, window_size=window_size, device=device) 41 | max_duration = max(train_set.max_bin,test_set.max_bin) 42 | collate_fn_train = spikeTimeToMatrix_shd(device=device, augment=augment, h_shift=h_shift, scale=scale) 43 | collate_fn_test = spikeTimeToMatrix_shd(device=device, augment=False) 44 | 45 | return train_set, test_set, nb_features, nb_class, collate_fn_train, collate_fn_test, max_duration 46 | 47 | 48 | 49 | class Dataset_shd(torch.utils.data.Dataset): 50 | """ 51 | Custom torch 
Dataset for loading spike data from hdf5 file. 52 | 53 | Parameters: 54 | - path: path to the hdf5 file 55 | - dt: time step for spike data 56 | - nb_features: number of features for spike data 57 | - window_size: size of window for spike data (default = 1) 58 | - device: device to use for PyTorch tensor computations (e.g. CPU or GPU) 59 | """ 60 | def __init__(self, path, dt, nb_features, window_size=1, device=None): 61 | super(Dataset_shd, self).__init__() 62 | assert os.path.exists(path), f"shd dataset not found at '{path}'. It is available at https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/" 63 | 64 | self.nb_features = nb_features 65 | self.device = device 66 | 67 | f = h5py.File(path, 'r') 68 | spikes_times_ = [k for k in f['spikes']['times']] 69 | spikes_units = [k.astype(np.int32) for k in f['spikes']['units']] 70 | self.labels = [k for k in f['labels']] 71 | f.close() 72 | 73 | # Get the maximum duration of the spikes data 74 | self.max_duration = int(max([t.max() for t in spikes_times_])/dt) 75 | self.max_bin = int(self.max_duration/window_size)+1 76 | bins = np.linspace(0, self.max_duration, num=self.max_bin) 77 | spikes_times_digitized = [np.digitize(t/dt,bins,right=True) for t in spikes_times_] 78 | # Convert the digitized spike times and units to sparse tensors 79 | self.inputs_data = [self.to_sparse_tensor(spikes_t, spikes_u) for (spikes_t, spikes_u) in zip(spikes_times_digitized,spikes_units)] 80 | 81 | def to_sparse_tensor(self, spikes_times, spikes_units): 82 | """ 83 | Convert digitized spike times and units to a sparse tensor 84 | spikes_times: digitized spike times 85 | spikes_units: units of the spikes 86 | """ 87 | v = torch.ones(len(spikes_times)) 88 | shape = [spikes_times.max()+1, self.nb_features] 89 | t = torch.sparse_coo_tensor(torch.tensor([spikes_times.tolist(), spikes_units.tolist()]), v, shape, dtype=torch.float32, device=self.device) 90 | return t 91 | 92 | def __getitem__(self, index): 93 | return 
self.inputs_data[index], self.labels[index] 94 | 95 | def __len__(self): 96 | return len(self.labels) 97 | 98 | 99 | 100 | class spikeTimeToMatrix_shd(torch.nn.Module): 101 | """ 102 | collate function for processing data. 103 | Apply data augmentation and pad spike trains if necessary 104 | """ 105 | 106 | def __init__(self, device=None, augment=False, h_shift=0.1, scale=0.3): 107 | super().__init__() 108 | self.device = device 109 | self.augment = augment 110 | self.h_shift = h_shift # channels axis 111 | self.scale = scale 112 | self.affine_transfomer = T.RandomAffine(degrees=0, translate=(self.h_shift, 0.), scale=(1-self.scale,1+self.scale)) 113 | 114 | def forward(self, samples): 115 | max_d = max([st[0].shape[0] for st in samples]) 116 | spike_train_batch = [] 117 | labels_batch = [] 118 | 119 | for (spike_train, label) in samples: 120 | # apply data augmentation if augment is True 121 | if self.augment: 122 | spike_train = self.affine_transfomer(spike_train.to_dense().unsqueeze(0)).squeeze(0) 123 | else: 124 | spike_train = spike_train.to_dense() 125 | 126 | # pad spike trains if needed 127 | pad = (0, 0, max_d-spike_train.shape[0], 0) 128 | spike_train_batch.append(F.pad(spike_train, pad, "constant", 0)) 129 | labels_batch.append(label) 130 | 131 | return torch.stack(spike_train_batch), torch.tensor(labels_batch, device=self.device, dtype=torch.long) 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /neurons/spsn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on February 2023 4 | 5 | @author: Anonymous 6 | """ 7 | 8 | import torch 9 | from neurons.base import Base 10 | 11 | 12 | class SPSN(Base): 13 | """ 14 | Class for implementing a Stochastic and Parallelizable Spiking Neuron (SPSN) model 15 | 16 | Parameters: 17 | - input_size (int): size of input tensor 18 | - hidden_size (int): size of hidden layer 19 | - device 
(torch.device): device to use for tensor computations, such as 'cpu' or 'cuda' 20 | - spike_mode (str): "SB" for Sigmoid-Bernoulli or "GS" for Gumbel-Softmax 21 | - nb_steps (int): number of timesteps of inputs 22 | - fire (bool, optional): flag to determine if the neurons should fire spikes or not (default: True) 23 | - tau_mem (float, optional): time constant for the membrane potential (default: 1e-3) 24 | - tau_syn (float, optional): time constant for the synaptic potential (default: 1e-3) 25 | - time_step (float, optional): step size for updating the LIF model (default: 1e-3) 26 | - debug (bool, optional): flag to turn on/off debugging mode (default: False) 27 | """ 28 | 29 | def __init__(self, input_size, hidden_size, device, spike_mode, nb_steps=None, 30 | fire=True, tau_mem=1e-3, tau_syn=1e-3, time_step=1e-3, debug=False): 31 | 32 | super(SPSN, self).__init__(input_size, hidden_size, device, 33 | fire, tau_mem, tau_syn, time_step, 34 | debug) 35 | # Set the spiking function 36 | if spike_mode=="SB": self.spike_fn = SigmoidBernoulli(self.device) 37 | elif spike_mode=="GS": self.spike_fn = GumbelSoftmax(self.device) 38 | 39 | self.nb_steps = nb_steps 40 | # Parameters can be computed upstream if the number of timesteps "nb_steps" is known 41 | self.fft_l_k = self.compute_params_fft(self.nb_steps) 42 | 43 | 44 | def compute_params_fft(self, nb_steps): 45 | """ 46 | Compute the FFT of the parameters for parallel Leaky Integration 47 | 48 | Returns: 49 | fft_l_k: Product of FFT of parameters l and k 50 | """ 51 | if nb_steps is None: return None 52 | 53 | l = torch.pow(self.alpha,torch.arange(nb_steps,device=self.device)) 54 | k = torch.pow(self.beta,torch.arange(nb_steps,device=self.device))*self.beta_1 55 | fft_l = torch.fft.rfft(l, n=2*nb_steps).unsqueeze(1) 56 | fft_k = torch.fft.rfft(k, n=2*nb_steps).unsqueeze(1) 57 | 58 | return fft_l*fft_k 59 | 60 | 61 | def forward(self, inputs): 62 | """ 63 | Perform forward pass of the network 64 | 65 | Parameters: 66 | 
- inputs (tensor): Input tensor with shape (batch_size, nb_steps, input_size) 67 | 68 | Returns: 69 | - Return membrane potential tensor with shape (batch_size, nb_steps, hidden_size) if 'fire' is False 70 | - Return spiking tensor with shape (batch_size, nb_steps, hidden_size) if 'fire' is True 71 | - Return the tuple (spiking tensor, membrane potential tensor) if 'debug' is True 72 | """ 73 | 74 | X = self.fc(inputs) 75 | batch_size,nb_steps,_ = X.shape 76 | 77 | # Recompute FFT params if nb_steps has changed 78 | if self.nb_steps!=nb_steps: 79 | self.fft_l_k = self.compute_params_fft(nb_steps) 80 | self.nb_steps = nb_steps 81 | 82 | # Perform parallel leaky integration 83 | fft_X = torch.fft.rfft(X, n=2*nb_steps, dim=1) 84 | mem_rec = torch.fft.irfft(fft_X*self.fft_l_k, n=2*nb_steps, dim=1)[:,:nb_steps:,] # Equation (15) 85 | 86 | if self.fire: 87 | # Perform stochastic firing 88 | spk_rec = self.spike_fn(mem_rec) 89 | self.nb_spike_per_neuron = torch.mean(torch.mean(spk_rec,dim=0),dim=0) 90 | return (spk_rec, mem_rec) if self.debug else spk_rec 91 | return mem_rec 92 | 93 | 94 | 95 | 96 | class StochasticStraightThrough(torch.autograd.Function): 97 | @staticmethod 98 | def forward(ctx, input): 99 | ctx.save_for_backward(input) 100 | out = torch.bernoulli(input) # Equation (18) 101 | return out 102 | @staticmethod 103 | def backward(ctx, grad_output): 104 | input, = ctx.saved_tensors 105 | grad_input = grad_output.clone() 106 | return grad_input*input # Equation (19) 107 | 108 | 109 | 110 | class SigmoidBernoulli(torch.nn.Module): 111 | def __init__(self, device): 112 | super().__init__() 113 | self.spike_fn = StochasticStraightThrough.apply 114 | 115 | def forward(self, inputs): 116 | spk_prob = torch.sigmoid(inputs) # Equation (17) 117 | spk = self.spike_fn(spk_prob) 118 | return spk 119 | 120 | 121 | 122 | 123 | 124 | 125 | class GumbelSoftmax(torch.nn.Module): 126 | def __init__(self, device, hard=True, tau=1.0): 127 | super().__init__() 128 | 129 | 
self.hard = hard 130 | self.tau = tau 131 | self.uniform = torch.distributions.Uniform(torch.tensor(0.0).to(device), 132 | torch.tensor(1.0).to(device)) 133 | self.softmax = torch.nn.Softmax(dim=0) 134 | 135 | 136 | def forward(self, logits): 137 | # Sample uniform noise 138 | unif = self.uniform.sample(logits.shape + (2,)) 139 | # Compute Gumbel noise from the uniform noise 140 | gumbels = -torch.log(-torch.log(unif)) 141 | # Apply softmax function to the logits and Gumbel noise 142 | y_soft = self.softmax(torch.stack([(logits + gumbels[...,0]) / self.tau, 143 | (-logits + gumbels[...,1]) / self.tau]))[0] 144 | if self.hard: 145 | # Use straight-through estimator 146 | y_hard = torch.where(y_soft > 0.5, 1.0, 0.0) 147 | ret = y_hard - y_soft.detach() + y_soft 148 | else: 149 | # Use reparameterization trick 150 | ret = y_soft 151 | 152 | return ret -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on February 2023 4 | 5 | @author: Anonymous 6 | """ 7 | 8 | import os 9 | import json 10 | import torch 11 | import random 12 | import argparse 13 | import numpy as np 14 | from datetime import datetime 15 | from datasets import heidelberg_dataset 16 | from network import create_network, train, test 17 | 18 | 19 | parser = argparse.ArgumentParser(description="SNN training") 20 | parser.add_argument('--seed', type=int) 21 | parser.add_argument('--dataset', type=str, default='heidelberg', choices=["heidelberg"]) 22 | parser.add_argument('--neuron', type=str, default='LIF', choices=["LIF", "SPSN-SB", "SPSN-GS", "Non-Spiking"]) 23 | parser.add_argument('--nb_epochs', type=int, default=200) 24 | parser.add_argument('--tau_mem', type=float, default=2e-2, help='neuron membrane time constant') 25 | parser.add_argument('--tau_syn', type=float, default=2e-2, help='neuron synaptic current time constant') 26 | 
parser.add_argument('--batch_size', type=int, default=64) 27 | parser.add_argument('--hidden_size', type=int, default=128, help='nb of neurons in the hidden layer') 28 | parser.add_argument('--nb_layers', type=int, default=3, help='nb of hidden layers') 29 | parser.add_argument('--reg_thr', type=float, default=0., help='spiking frequency regularization threshold') 30 | parser.add_argument('--loss_mode', type=str, default='mean', choices=["last", "max", "mean"]) 31 | parser.add_argument('--data_augmentation', type=str, default='False', choices=["True", "False"]) 32 | parser.add_argument('--h_shift', type=float, default=0.1, help='data augmentation random shift factor') 33 | parser.add_argument('--scale', type=float, default=0.3, help='data augmentation random scale factor') 34 | parser.add_argument('--dir', type=str, default='') 35 | parser.add_argument('--save_model', type=str, default='False', choices=["True", "False"]) 36 | 37 | args = parser.parse_args() 38 | PARAMS = { 39 | "seed" : args.seed, 40 | "dataset" : args.dataset, 41 | "neuron" : args.neuron, 42 | "nb_epochs" : args.nb_epochs, 43 | "tau_mem" : args.tau_mem, 44 | "tau_syn" : args.tau_syn, 45 | "batch_size" : args.batch_size, 46 | "hidden_size" : args.hidden_size, 47 | "nb_layers" : args.nb_layers, 48 | "reg_thr" : args.reg_thr, 49 | "loss_mode" : args.loss_mode, 50 | "data_augmentation" : args.data_augmentation=='True', 51 | "h_shift" : args.h_shift, 52 | "scale" : args.scale, 53 | "dir" : args.dir, 54 | "save_model" : args.save_model=='True', 55 | } 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | def save_results(train_results, test_results, PARAMS, model): 64 | """ 65 | This function creates a dictionary of results from the training and testing and save it 66 | as a json file. 67 | If the 'save_model' parameter is set to True, the trained model is also saved. 
68 | """ 69 | outputs = { 70 | 'loss_hist':train_results['loss'], 71 | 'train_accuracies':train_results['acc'], 72 | 'train_duration':train_results['dur'], 73 | 'test_accuracies': test_results['acc'], 74 | 'nb_spikes':test_results['spk'], 75 | 'test_duration':test_results['dur'], 76 | 'PARAMS': PARAMS 77 | } 78 | 79 | output_dir = f"outputs/{PARAMS['dir']}" 80 | timestamp = int(datetime.timestamp(datetime.now())) 81 | filename = output_dir+'results_{}_{}.json'.format(PARAMS['neuron'], str(timestamp)) 82 | os.makedirs(os.path.dirname(filename), exist_ok=True) 83 | 84 | with open(filename, 'w') as f: 85 | json.dump(outputs, f) 86 | 87 | if PARAMS['save_model']: 88 | modelname = output_dir+'model_{}_{}.pt'.format(PARAMS['neuron'], str(timestamp)) 89 | torch.save(model.state_dict(), modelname) 90 | 91 | 92 | 93 | def main(): 94 | """ 95 | This function : 96 | - Enable or not the reproductibility by setting a seed 97 | - Loads the train and test sets 98 | - Create the network 99 | - Train and test the network 100 | - Save the results 101 | """ 102 | print("\n-- Start --\n") 103 | # To enable reproducibility 104 | if PARAMS["seed"] is not None: 105 | seed=PARAMS["seed"] 106 | random.seed(seed) 107 | np.random.seed(seed) 108 | os.environ["PYTHONHASHSEED"] = str(seed) # NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it here does not change hash randomization of the already-running process; it only reaches child processes started afterwards 109 | torch.manual_seed(seed) 110 | if torch.cuda.is_available(): 111 | torch.backends.cudnn.deterministic = True 112 | torch.backends.cudnn.benchmark = False 113 | 114 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 115 | 116 | # Loads the train and test sets 117 | if PARAMS["dataset"]=="heidelberg": # NOTE(review): there is no else branch - any other PARAMS["dataset"] leaves train_set, test_set, the collate functions and max_duration unbound, so the DataLoader construction below fails with NameError; consider raising ValueError for unsupported datasets 118 | (train_set, test_set, input_size, nb_class, collate_fn_train, 119 | collate_fn_test, max_duration) = heidelberg_dataset(device=device, augment=PARAMS["data_augmentation"], 120 | h_shift=PARAMS["h_shift"], scale=PARAMS["scale"]) 121 | PARAMS["input_size"]=input_size # written back into PARAMS because create_network() reads params["input_size"] and params["nb_class"], and PARAMS is stored in the results JSON 122 | PARAMS["nb_class"]=nb_class 123 | 124 | train_loader = torch.utils.data.DataLoader(train_set, 
batch_size=PARAMS['batch_size'], shuffle=True, collate_fn=collate_fn_train) 125 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=PARAMS['batch_size'], shuffle=False, collate_fn=collate_fn_test) 126 | 127 | # Create the network 128 | model = create_network(PARAMS, device, max_duration) 129 | 130 | # Train and test the network 131 | print("\n-- Training --\n") 132 | train_results = train(model, train_loader, nb_epochs=PARAMS['nb_epochs'], loss_mode=PARAMS['loss_mode'], reg_thr=PARAMS['reg_thr']) 133 | print("\n-- Testing --\n") 134 | test_results = test(model, test_loader, loss_mode=PARAMS['loss_mode']) 135 | 136 | # Save train and test results 137 | save_results(train_results, test_results, PARAMS, model) 138 | print("\n-- End --\n") 139 | 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on February 2023 4 | 5 | @author: Anonymous 6 | """ 7 | 8 | import time 9 | import torch 10 | import numpy as np 11 | from tqdm import tqdm 12 | from neurons.lif import LIF 13 | from neurons.spsn import SPSN 14 | import torch.nn.functional as F 15 | 16 | 17 | def create_network(params, device, max_duration): 18 | """ 19 | This function creates a neural network based on the given parameters 20 | """ # Returns a torch.nn.Sequential: an input->hidden layer, nb_layers-1 hidden->hidden layers and a hidden->nb_class readout (LIF/SPSN readouts are built with fire=False); params must already hold "input_size" and "nb_class" 21 | neuron = params["neuron"] 22 | nb_layers = params["nb_layers"] 23 | input_size = params["input_size"] 24 | hidden_size = params["hidden_size"] 25 | nb_class = params["nb_class"] 26 | tau_mem = params["tau_mem"] 27 | tau_syn = params["tau_syn"] 28 | 29 | modules = [] 30 | if neuron=="LIF": 31 | modules.append(LIF(input_size, hidden_size, device, tau_mem=tau_mem, tau_syn=tau_syn)) 32 | for i in range(nb_layers-1): 33 | modules.append(LIF(hidden_size, 
hidden_size, device, tau_mem=tau_mem, tau_syn=tau_syn)) 34 | modules.append(LIF(hidden_size, nb_class, device, tau_mem=tau_mem, tau_syn=tau_syn, fire=False)) 35 | 36 | elif neuron in ["SPSN-SB", "SPSN-GS"]: 37 | spike_mode = neuron.split('-')[-1] # "SB" or "GS", the suffix of the neuron name 38 | modules.append(SPSN(input_size, hidden_size, device, spike_mode, max_duration, tau_mem=tau_mem, tau_syn=tau_syn)) 39 | for i in range(nb_layers-1): 40 | modules.append(SPSN(hidden_size, hidden_size, device, spike_mode, max_duration, tau_mem=tau_mem, tau_syn=tau_syn)) 41 | modules.append(SPSN(hidden_size, nb_class, device, spike_mode, max_duration, tau_mem=tau_mem, tau_syn=tau_syn, fire=False)) 42 | 43 | elif neuron=="Non-Spiking": 44 | modules.append(torch.nn.Linear(input_size, hidden_size, device=device)) 45 | modules.append(torch.nn.ReLU()) 46 | for i in range(nb_layers-1): 47 | modules.append(torch.nn.Linear(hidden_size, hidden_size, device=device)) 48 | modules.append(torch.nn.ReLU()) 49 | modules.append(torch.nn.Linear(hidden_size, nb_class, device=device)) 50 | model = torch.nn.Sequential(*modules) # NOTE(review): an unrecognized params["neuron"] matches none of the branches above, so no error is raised here and an empty Sequential is returned; consider raising ValueError instead 51 | 52 | return model 53 | 54 | 55 | def train(model, data_loader, nb_epochs=100, loss_mode='mean', reg_thr=0.): 56 | """ 57 | This function Train the given model on the train data. 
58 | """ # Returns {'loss': per-epoch logged loss, 'acc': per-epoch training accuracy, 'dur': mean wall-clock seconds per epoch} 59 | optimizer = torch.optim.Adamax(model.parameters(), lr=1e-3) # optimizer and learning rate are hard-coded here 60 | loss_fn = torch.nn.CrossEntropyLoss() 61 | 62 | # If a regularization threshold is set we compute the theta_reg*N parameter of Equation (21) 63 | if reg_thr>0: 64 | reg_thr_sum = reg_thr * np.sum([layer.hidden_size for layer in model if (layer.__class__.__name__ in ['LIF', 'SPSN'] and layer.fire)]) # only bound when reg_thr > 0; it is read only inside the matching reg_thr>0 branch of the batch loop 65 | 66 | loss_hist = [] 67 | acc_hist = [] 68 | progress_bar = tqdm(range(nb_epochs), desc=f"Train {nb_epochs} epochs") 69 | start_time = time.time() 70 | # Loop over the number of epochs 71 | for i_epoch in progress_bar: 72 | local_loss = 0 73 | local_acc = 0 74 | total = 0 75 | nb_batch = len(data_loader) 76 | # Loop over the batches 77 | for i_batch,(x,y) in enumerate(data_loader): 78 | total += len(y) 79 | output = model(x) 80 | # Select the relevant function to process the output based on loss mode 81 | if loss_mode=='last' : output = output[:,-1,:] # the [:,-1,:] indexing implies a rank-3 output; dim 1 is presumably time - TODO confirm 82 | elif loss_mode=='max': output = torch.max(output,1)[0] 83 | else: output = torch.mean(output,1) 84 | 85 | # Here we set up our regularizer loss as in Equation (21) 86 | reg_loss_val = 0 87 | if reg_thr>0: 88 | spks = torch.stack([layer.nb_spike_per_neuron.sum() for layer in model if (layer.__class__.__name__ in ['LIF', 'SPSN'] and layer.fire)]) # NOTE(review): assumes every firing LIF/SPSN layer exposes nb_spike_per_neuron from its latest forward pass (defined outside this file) 89 | reg_loss_val = F.relu(spks.sum()-reg_thr_sum)**2 90 | 91 | # Here we combine supervised loss and the regularizer 92 | loss_val = loss_fn(output, y) + reg_loss_val 93 | 94 | # Backpropagation and weights update 95 | optimizer.zero_grad() 96 | loss_val.backward() 97 | optimizer.step() 98 | 99 | local_loss += loss_val.detach().cpu().item() 100 | _,y_pred = torch.max(output,1) 101 | local_acc += torch.sum((y==y_pred)).detach().cpu().numpy() 102 | progress_bar.set_postfix(loss=local_loss/total, accuracy=local_acc/total, _batch=f"{i_batch+1}/{nb_batch}") 103 | 104 | loss_hist.append(local_loss/total) # NOTE(review): local_loss sums per-batch losses (CrossEntropyLoss defaults to reduction='mean') but is divided by the sample count, so the logged value is roughly the mean loss divided by the batch size (the progress-bar loss has the same scaling) 105 | acc_hist.append(local_acc/total) 106 | 107 | train_duration = 
(time.time()-start_time)/nb_epochs # mean wall-clock seconds per epoch 108 | 109 | return {'loss':loss_hist, 'acc':acc_hist, 'dur':train_duration} 110 | 111 | 112 | def test(model, data_loader, loss_mode='mean'): 113 | """ 114 | This function Computes classification accuracy for the given model on the test data. 115 | """ # Returns {'acc': accuracy over the whole loader, 'spk': per-layer sums of nb_spike_per_neuron for each firing LIF/SPSN layer, averaged over batches, 'dur': total wall-clock seconds for the pass} 116 | acc = 0 117 | total = 0 # NOTE(review): if data_loader yields no batches, total stays 0 and acc/total in the return raises ZeroDivisionError 118 | spk_per_layer = [] 119 | progress_bar = tqdm(data_loader, desc="Test") 120 | start_time = time.time() 121 | # loop through the test data 122 | for x,y in progress_bar: 123 | total += len(y) 124 | with torch.no_grad(): # NOTE(review): model.eval() is never called here (nor model.train() in train); harmless unless the LIF/SPSN layers behave differently by mode, which is not visible in this file 125 | output = model(x) 126 | # Select the relevant function to process the output based on loss mode 127 | if loss_mode=='last' : output = output[:,-1,:] 128 | elif loss_mode=='max': output = torch.max(output,1)[0] 129 | else: output = torch.mean(output,1) 130 | # get the predicted label 131 | _,y_pred = torch.max(output,1) 132 | acc += torch.sum((y==y_pred)).cpu().numpy() 133 | # get the number of spikes per layer for LIF and SPSN layers 134 | spk_per_layer.append([layer.nb_spike_per_neuron.sum().cpu().item() for layer in model if (layer.__class__.__name__ in ['LIF', 'SPSN'] and layer.fire)]) # for a Non-Spiking model no layer matches, so each entry is an empty list 135 | progress_bar.set_postfix(accuracy=acc/total) 136 | test_duration = (time.time()-start_time) # total seconds for one pass, whereas train() reports seconds per epoch 137 | 138 | return {'acc':acc/total, 'spk':np.mean(spk_per_layer,axis=0).tolist(), 'dur':test_duration} 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | -------------------------------------------------------------------------------- /outputs/basic/results_LIF_1675938986.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.04662377902138987, 0.04120861525159072, 0.03331401735149569, 0.02657397970549905, 0.02184807600900202, 0.01907811719100226, 0.016589508797971567, 0.014939766967337526, 0.013392521688726499, 0.012080313763003421, 0.010986351770270741, 0.010108979514792248, 0.009276398048791422, 0.008547651085887486, 0.007953829380038207, 
0.00731691990327227, 0.006885512647469761, 0.00644165251902823, 0.0061049807566442345, 0.005687105580397012, 0.0054866599943077755, 0.005015640957587019, 0.004786640592010773, 0.004554538685662189, 0.004425133282045572, 0.004129272626640751, 0.003974166377655369, 0.0036768371868712343, 0.003538640985167629, 0.0034049382792304345, 0.0031697164568730346, 0.0030593676798597396, 0.0030020431548269315, 0.0026757809552288687, 0.0026249864913775326, 0.0024819139193865878, 0.0023969666545460657, 0.002280415987660775, 0.0021149823981071875, 0.002085057232947312, 0.001956155041781724, 0.0018574703514919438, 0.0017183726057848725, 0.0016950400168913846, 0.0016045707943445038, 0.0015899912765237735, 0.0014300342697637406, 0.0013364932055403755, 0.0013812326576310434, 0.0013204122632705388, 0.001211394010501758, 0.0011492869816720486, 0.001006133497663869, 0.0010484900661708093, 0.0009623238001093671, 0.0009507965794369366, 0.000978994824754119, 0.0008205254228768277, 0.000849667497622347, 0.0007783299651820908, 0.0007162864707993154, 0.0006915775679011681, 0.0006767136475063089, 0.0006262066693224703, 0.0005884733929857429, 0.0005820751345025347, 0.0005427834124757912, 0.00047562573910811887, 0.0004546242124798866, 0.00041689462272350754, 0.0004650995728174748, 0.00036393050910551474, 0.0004283122409111352, 0.0003723436550891768, 0.00033002290858733207, 0.0003167749441362636, 0.00028210855483078834, 0.00028791561291807493, 0.000288672493213478, 0.0002828664467804766, 0.0002554683971080368, 0.0002530137462474775, 0.00023579188803409085, 0.00019643819348977046, 0.00019059624791724488, 0.00017151054694413656, 0.00018542048836728926, 0.00017529178017160256, 0.00014965517667696326, 0.00015195254929584797, 0.0001476430394670772, 0.00015712743878543611, 0.00012730860466282775, 0.00011249287687118708, 0.00011953461193924934, 0.00011309016534881396, 0.00011474141386205577, 0.00010640211727094708, 9.81321571659262e-05, 8.690408765851754e-05, 8.59324569923985e-05, 8.113610696608439e-05, 
7.538916544553464e-05, 9.73684744103919e-05, 6.673408171724731e-05, 5.6738231965199535e-05, 6.115910443961493e-05, 7.455823131708602e-05, 5.280504342821272e-05, 5.404900114806556e-05, 5.223056559512358e-05, 4.490670282622071e-05, 4.3449219080558765e-05, 5.443029604753312e-05, 6.360234811163474e-05, 3.7874017522157394e-05, 3.762441621370506e-05, 3.2561530324889874e-05, 7.577074352570075e-05, 3.069251969405516e-05, 3.0908016637257696e-05, 2.8095548868747067e-05, 2.98474407186179e-05, 2.7890171836961156e-05, 2.7610611004043278e-05, 2.9208206370324006e-05, 2.9939339053600136e-05, 2.550046565308034e-05, 3.3510878033602134e-05, 2.5780260786820656e-05, 2.0890578672919256e-05, 2.1099841701245468e-05, 2.8922284180867107e-05, 2.3642123224706556e-05, 2.0701093131883225e-05, 1.8434174759761082e-05, 1.812147846650524e-05, 2.3659928524254648e-05, 2.1652808791901615e-05, 2.8038565492077743e-05, 1.655187918665276e-05, 1.652653311888992e-05, 2.6138885251130832e-05, 1.347302628286599e-05, 1.4811105064247507e-05, 1.2690382199800976e-05, 1.2282037864520949e-05, 1.2677716361588281e-05, 1.6411174388147676e-05, 1.588036295240302e-05, 1.1218703813765123e-05, 1.0437030312757645e-05, 1.0726027623788523e-05, 1.1032278865582414e-05, 1.3844035146067119e-05, 1.0120940795130598e-05, 9.281792060578038e-06, 1.0124197549799498e-05, 1.8815734566175873e-05, 9.011531471584929e-06, 7.425805923779385e-06, 7.76385733246385e-06, 6.404701894992581e-06, 6.537613097353107e-06, 6.177083258146544e-06, 6.129355515110616e-06, 5.922774781583515e-06, 1.4666104652236199e-05, 7.892146269538357e-06, 6.161244658148876e-06, 6.0589484658870795e-06, 5.1678281789682665e-06, 5.264876458338608e-06, 5.1449423189064405e-06, 6.441715870730514e-06, 5.4298259403502114e-06, 5.713718896808811e-06, 5.818609247186981e-06, 4.8165559993514195e-06, 5.583470132020749e-06, 5.100854080166511e-06, 5.144271657780982e-06, 4.352314255299442e-06, 4.529258082070125e-06, 1.673424934160953e-05, 2.155946813659768e-05, 4.749445007100442e-06, 
3.9918792069403845e-06, 3.823246192881909e-06, 3.559997330766624e-06, 3.317354295079526e-06, 3.916073795043178e-06, 3.3918457453952663e-06, 3.2703796042896287e-06, 3.560200487528704e-06, 3.3531360277140882e-06, 3.220441991634815e-06, 3.2722875720711074e-06, 3.169811575070833e-06, 3.113441601951249e-06], "train_accuracies": [0.0734428641490927, 0.1818293281020108, 0.40129965669445805, 0.533962726826876, 0.6346248160863168, 0.6788867091711623, 0.7338155958803335, 0.7615252574791564, 0.7811427170181462, 0.8055419323197646, 0.8212358999509564, 0.8334968121628249, 0.8499264345267288, 0.8607160372731731, 0.8696665031878372, 0.883766552231486, 0.88805787150564, 0.8987248651299656, 0.9031387935262384, 0.9125796959293772, 0.9127023050514959, 0.9239823442864149, 0.9264345267287887, 0.9303580186365865, 0.9339136831780285, 0.9412702305051496, 0.941883276115743, 0.9499754781755763, 0.9502206964198137, 0.9518146150073565, 0.9584355076017655, 0.9600294261893085, 0.9592937714565963, 0.967999019127023, 0.9659146640510053, 0.969960765080922, 0.9710642471799902, 0.9714320745463463, 0.976213830308975, 0.9759686120647376, 0.9806277587052477, 0.9806277587052477, 0.9843060323688082, 0.9833251593918587, 0.9841834232466895, 0.9840608141245709, 0.9868808239333007, 0.9897008337420304, 0.9883521333987249, 0.9884747425208436, 0.9895782246199117, 0.9916625796959294, 0.99534085335949, 0.991417361451692, 0.9944825895046592, 0.9937469347719471, 0.9936243256498284, 0.9959538989700834, 0.9955860716037274, 0.9959538989700834, 0.9963217263364395, 0.9965669445806769, 0.9968121628249141, 0.9975478175576263, 0.9975478175576263, 0.9975478175576263, 0.9979156449239823, 0.9986512996566944, 0.9985286905345758, 0.9991417361451692, 0.998038254046101, 0.9996321726336439, 0.9985286905345758, 0.9992643452672879, 0.9998773908778813, 0.9993869543894066, 1.0, 0.9997547817557626, 0.9997547817557626, 0.9993869543894066, 1.0, 0.9996321726336439, 0.9995095635115253, 1.0, 0.9998773908778813, 0.9998773908778813, 
0.9995095635115253, 0.9998773908778813, 1.0, 1.0, 1.0, 0.9998773908778813, 0.9998773908778813, 1.0, 1.0, 1.0, 1.0, 0.9998773908778813, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9998773908778813, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9996321726336439, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9998773908778813, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9998773908778813, 0.9998773908778813, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "train_duration": 252.28891336917877, "test_accuracies": 0.7137809187279152, "nb_spikes": [3.1986988451745777, 5.003289355172051, 13.823506010903252], "test_duration": 17.19488525390625, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "LIF", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.0, "loss_mode": "mean", "data_augmentation": false, "h_shift": 0.1, "scale": 0.3, "dir": "basic/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/basic/results_SPSN-GS_1675890063.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.06347647614780742, 0.03538964351875751, 0.030311497331191768, 0.0253143578022353, 0.021958204242983188, 0.01973618987204106, 0.01745963907814774, 0.015882885539397238, 0.014420703294640841, 0.013303151780802458, 0.011806554642888248, 0.01114754442741264, 0.010258743920385867, 0.009212436513493844, 0.00867101992971468, 0.008111945674546856, 0.007446622634018795, 0.006685662523182424, 0.006558272846923733, 0.005869609486931033, 0.005557933543653918, 0.005615185615535809, 0.0046373396077800345, 
0.004540051057269726, 0.004264475257958543, 0.003992874012705507, 0.0038511329734731738, 0.003412355257306922, 0.0033487996745118326, 0.0030590992662078206, 0.002870650577407774, 0.002643045384661569, 0.00253352241032465, 0.0023976087532736555, 0.0024503490345414622, 0.0019135305725582503, 0.0019553163768965686, 0.0019328423993604448, 0.0018471125869583768, 0.0020473133334127226, 0.001393695838764025, 0.0014105589609079761, 0.0012724233107013758, 0.00147494222592451, 0.0012287131207846672, 0.0011951877996362033, 0.0013066865862340334, 0.0009104577043879816, 0.0010391580255619008, 0.0009297774483455695, 0.0011311480978731287, 0.0007869397137494956, 0.0006636087627384434, 0.00069076604154573, 0.0007087971690618316, 0.0009751957041897014, 0.0007428107076370655, 0.000500462358806496, 0.0006112681291613759, 0.0005530801839448691, 0.00045736984449359734, 0.0005086895572058144, 0.0003442757659285252, 0.0005537770626186783, 0.00034133867604330055, 0.0002821664431631609, 0.0003971115505238771, 0.00030539913463592164, 0.00041941416510593197, 0.00029981042241998285, 0.0003142908473075251, 0.00037111867078300887, 0.0004250435475483829, 0.00021061054058647027, 0.00021682487074020827, 0.00029819503345135314, 0.000279983362128621, 0.0002747498942361418, 0.00027046548064465597, 0.0002964730563782733, 0.00020701237875980094, 0.0001296574396196371, 0.00017350408273839033, 0.0003349532998946202, 0.00019448164024261477, 0.00011378097275519573, 0.0001820499899483351, 0.00045267757918555956, 0.00015463376291406346, 8.686182301597923e-05, 8.679748670784497e-05, 8.632953034299588e-05, 8.875531290144882e-05, 0.00031675303334819245, 9.742075085624881e-05, 8.550836224108478e-05, 8.700372995326845e-05, 7.52549593332541e-05, 7.881925958827642e-05, 8.347953163642781e-05, 0.00030062239325471504, 0.00012870569931628214, 6.293760348517218e-05, 7.621822499612081e-05, 0.00012944100563729098, 6.053746072335654e-05, 7.000178003302524e-05, 7.863404462399447e-05, 7.489872906222764e-05, 
0.0002148679557862514, 0.00010241403509867288, 5.81422363518787e-05, 5.563331453384062e-05, 0.00022650193947158635, 9.456016586329779e-05, 4.637446394943837e-05, 3.835847089872132e-05, 3.814913859357258e-05, 4.240467165049363e-05, 0.00030236777494597106, 0.00032498339027675193, 7.875908418584116e-05, 3.9424943433258755e-05, 3.3944869678775364e-05, 3.2298092642931935e-05, 3.633537850176661e-05, 3.5590681956182066e-05, 3.1148609981891326e-05, 3.26272630816268e-05, 3.487672187459197e-05, 0.00011915508673537537, 3.665133710746672e-05, 3.117768573978379e-05, 3.3358745435507424e-05, 3.8446162456285496e-05, 3.091497811221149e-05, 2.824453493097661e-05, 3.6445763300101274e-05, 3.775677267693133e-05, 0.0003139956798559752, 6.963985804334792e-05, 2.7537116930863898e-05, 2.4237779150868832e-05, 2.3446729177880398e-05, 2.538622363374348e-05, 2.304127367075543e-05, 0.00010716855243310747, 0.00017168924828846047, 3.0719646266080366e-05, 2.3641799085969018e-05, 2.1437975470273098e-05, 2.008220563864949e-05, 2.2458928539889283e-05, 2.6314275974077557e-05, 2.1229168272305686e-05, 1.996945468926402e-05, 2.3871009482171063e-05, 2.1723380071157156e-05, 2.1007948072976375e-05, 0.0001861119981872682, 0.00014263312446948984, 2.3731636687053086e-05, 2.0821693071394623e-05, 2.0058710142371052e-05, 1.7837015026261094e-05, 1.7606657796046413e-05, 1.815745031386383e-05, 2.1115808577910286e-05, 1.5772607160239197e-05, 1.583452359870279e-05, 1.7951743553449455e-05, 1.5061963707419425e-05, 1.572163370136462e-05, 5.44076082271548e-05, 0.0002371666140743824, 2.1147228707154038e-05, 1.8197256330392254e-05, 1.777343188857595e-05, 1.537434227434373e-05, 1.466605500961563e-05, 1.3625440082984166e-05, 1.5140177664685448e-05, 1.551838898905103e-05, 1.3201638976687633e-05, 1.3668688360769293e-05, 1.3651335276795763e-05, 1.2087015291193363e-05, 1.8759018419478993e-05, 0.0001771537836912785, 0.00011526671511054522, 1.7841351247211493e-05, 1.4498830433967357e-05, 1.3714544166555142e-05, 
1.3449506487601101e-05, 1.236511648636811e-05, 1.1291809092456101e-05, 1.1337050638277386e-05, 1.1858895214546522e-05, 1.0718732632354861e-05, 1.2052211417912098e-05], "train_accuracies": [0.14737616478666013, 0.297204512015694, 0.39811181951937225, 0.50465914664051, 0.5811672388425699, 0.624816086316822, 0.6748406081412457, 0.7103972535556645, 0.7486512996566944, 0.7650809220205983, 0.8046836684649338, 0.8126532614026484, 0.8258950465914664, 0.8456351152525748, 0.8589995095635116, 0.8659882295242766, 0.8756743501716527, 0.896640510053948, 0.88805787150564, 0.9070622854340363, 0.9131927415399705, 0.90951446787641, 0.9276606179499755, 0.929377145659637, 0.9363658656204021, 0.9383276115743011, 0.941515448749387, 0.952795487984306, 0.9510789602746444, 0.9563511525257479, 0.9562285434036293, 0.9622363903874448, 0.9639529180971064, 0.9661598822952427, 0.962972045120157, 0.9748651299656694, 0.9740068661108386, 0.974619911721432, 0.973516429622364, 0.9689798921039725, 0.9834477685139774, 0.985164296223639, 0.987126042177538, 0.9802599313388917, 0.9863903874448259, 0.9857773418342325, 0.9834477685139774, 0.9906817067189799, 0.9890877881314369, 0.9889651790093182, 0.9858999509563512, 0.9942373712604218, 0.995708680725846, 0.993379107405591, 0.9946051986267779, 0.9884747425208436, 0.9921530161844041, 0.9969347719470328, 0.9944825895046592, 0.9959538989700834, 0.9974252084355076, 0.9958312898479647, 0.9986512996566944, 0.9946051986267779, 0.9986512996566944, 0.9996321726336439, 0.9964443354585582, 0.9992643452672879, 0.9970573810691515, 0.9987739087788131, 0.9982834722903384, 0.998038254046101, 0.9954634624816087, 0.9992643452672879, 0.9996321726336439, 0.9975478175576263, 0.9986512996566944, 0.998038254046101, 0.9982834722903384, 0.9979156449239823, 0.9991417361451692, 1.0, 0.9995095635115253, 0.9961991172143208, 0.9992643452672879, 0.9998773908778813, 0.9986512996566944, 0.9939921530161844, 0.9992643452672879, 1.0, 0.9998773908778813, 0.9998773908778813, 1.0, 
0.9963217263364395, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9959538989700834, 0.9992643452672879, 1.0, 0.9998773908778813, 0.9990191270230505, 1.0, 1.0, 0.9998773908778813, 0.9997547817557626, 0.9987739087788131, 0.9997547817557626, 1.0, 1.0, 0.998038254046101, 0.9996321726336439, 1.0, 1.0, 1.0, 1.0, 0.9958312898479647, 0.9942373712604218, 0.9997547817557626, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9991417361451692, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9998773908778813, 0.9947278077488965, 0.9996321726336439, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9981608631682197, 0.9973025993133889, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9973025993133889, 0.9979156449239823, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9996321726336439, 0.9966895537027954, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9975478175576263, 0.9987739087788131, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], "train_duration": 6.807290356159211, "test_accuracies": 0.7566254416961131, "nb_spikes": [59.09761386447482, 66.4406615363227, 54.353370878431534], "test_duration": 1.2759411334991455, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "SPSN-GS", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.0, "loss_mode": "mean", "data_augmentation": false, "h_shift": 0.1, "scale": 0.3, "dir": "basic/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/basic/results_SPSN-SB_1675889673.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.060015510612401735, 0.04032054810970656, 0.03670479466132875, 0.03325392888304414, 0.029637179738929192, 0.02673836737184211, 0.024501434504956576, 0.022266544244522086, 0.019896487326759985, 0.01789839432747005, 0.016639869219827675, 0.01534333047532404, 0.013915190258524241, 0.012745557783163313, 
0.011711873567209576, 0.010901919456017024, 0.009996820345352887, 0.009051514276983459, 0.008902051417104984, 0.007937458711491313, 0.007557333944444858, 0.007187238634509628, 0.006593864350824043, 0.0063239673427532206, 0.005766394600644891, 0.005445589885561416, 0.005134749700840107, 0.0049261575150922445, 0.004538320207739063, 0.00438428885207976, 0.003918534505595291, 0.0037565621327702, 0.0037325637980513623, 0.0032776619767648676, 0.0033534101990214627, 0.003197238294114489, 0.0028596167590116506, 0.002872925459088912, 0.002572002483128975, 0.002500766430592759, 0.002535589113319955, 0.0024679493850582133, 0.002060097045721436, 0.0019571601201713, 0.002023401427621231, 0.0018286483778877257, 0.0017832699101628365, 0.0015747408492672438, 0.0017725869244975807, 0.0014154334397746278, 0.0014905250481000538, 0.0014088381179565013, 0.0013216709381936812, 0.0012099055126638257, 0.0013566875372167865, 0.0010384648235340279, 0.001160631716331762, 0.0011161995840191753, 0.0009379551605270767, 0.0010758863502584467, 0.0009752537822465048, 0.0008920470584693669, 0.0008460407391055329, 0.0008644254151010057, 0.0007539107729171555, 0.0007548035120627148, 0.0007169800747712126, 0.0007246595719676645, 0.0006877276450377314, 0.0006803295604428783, 0.0007278969989383393, 0.000601286237215319, 0.0005388522419946971, 0.000625594383194246, 0.0005219712784461761, 0.00048460573281831546, 0.0005816841374579388, 0.00041559779423445666, 0.0004992715812499702, 0.00047989254266977337, 0.00038153018385387477, 0.0003920662452755912, 0.0005388083189904923, 0.00033291740542348033, 0.00032706541593208025, 0.00035435381447535523, 0.0003650496789638379, 0.0004515757152441635, 0.00028906642741835645, 0.0003901189265068385, 0.00026088815500805373, 0.0004192023030370328, 0.00034165457190234123, 0.0001921913246388943, 0.00019869070507329305, 0.00025390598239206317, 0.00023119082839272517, 0.0002383264327773942, 0.00021164174525198755, 0.00017479226326473082, 0.00026256521597829827, 
0.00015034487765248593, 0.0001393403078494717, 0.00024441983390351927, 0.0002758967940071657, 0.00011991157699805694, 0.00021671371292044254, 0.000128188443046667, 0.00015132037638752828, 0.00017459938451538137, 0.00014021089277104596, 0.0002206188412540779, 0.00017613553039089226, 0.0002020823945921205, 0.00017964861394203523, 9.989771394932834e-05, 0.00010454599466266192, 0.0001992641471654301, 9.797335579048599e-05, 0.00013487910120572134, 0.00011804698725905161, 0.0002143627988310899, 0.0001981929574480811, 7.515235386462037e-05, 9.714013387115884e-05, 7.77163038464412e-05, 8.668550563078189e-05, 7.886112448492175e-05, 8.834650611108449e-05, 0.00015524235100025432, 7.012602961642732e-05, 0.00017268969231767488, 7.28635410048559e-05, 5.692733535253045e-05, 5.5697542890886236e-05, 0.0001601717403592919, 0.00010088995859164392, 7.359252225990394e-05, 5.8020180539281004e-05, 6.395039581717688e-05, 0.00010516502258116394, 9.335720279416623e-05, 0.00017169130566806724, 6.856176153313169e-05, 0.0001543512870591863, 4.59533506421904e-05, 6.508423746025666e-05, 4.9173973083997154e-05, 6.183462961588982e-05, 0.00015239075185420918, 7.870647483726834e-05, 5.270624785652533e-05, 3.7706790559252794e-05, 6.667060496412535e-05, 0.00017860654247145918, 0.00010598950106893812, 3.835676718275316e-05, 3.618710059468196e-05, 3.2846100577904624e-05, 3.404222532096577e-05, 3.942438824093833e-05, 3.995858034922441e-05, 8.135124407946157e-05, 4.0810445388017405e-05, 0.00013700538984766616, 9.050008495381196e-05, 3.0553712375372305e-05, 2.865953420715013e-05, 2.9355156820156903e-05, 3.090114058254587e-05, 4.131296334728597e-05, 8.39595999978067e-05, 4.7545952853119134e-05, 2.7425884192642974e-05, 2.638767354714818e-05, 2.393112521747181e-05, 2.9178019196524498e-05, 0.00023751311926977354, 5.297298556886378e-05, 2.505005444463958e-05, 2.8575752340269727e-05, 2.3318816479896123e-05, 2.919980787411524e-05, 2.1805660532229745e-05, 2.3059995134335738e-05, 2.750488669385877e-05, 
3.781102262013133e-05, 6.466715794176894e-05, 2.8219024620692385e-05, 2.4618653229501944e-05, 2.1747442737074186e-05, 9.378590963905574e-05, 2.3282992408699718e-05, 2.5044556118729706e-05, 2.4179419708269675e-05, 2.183492562202956e-05, 2.188046629392291e-05, 0.00012812844944325706, 6.018871978926621e-05, 2.369066617830886e-05], "train_accuracies": [0.10274644433545856, 0.19237371260421776, 0.26471309465424225, 0.34526728788621874, 0.43146150073565476, 0.48822952427660615, 0.5391123099558607, 0.5901177047572339, 0.646640510053948, 0.6948258950465914, 0.7138303089749878, 0.7448504168710152, 0.7674104953408534, 0.7867827366356057, 0.8082393330063756, 0.822584600294262, 0.8431829328102011, 0.8593673369298676, 0.8549534085335949, 0.8757969592937714, 0.8822952427660617, 0.8933300637567435, 0.8988474742520843, 0.9003187837175086, 0.9152770966159882, 0.9135605689063266, 0.9252084355076018, 0.9258214811181952, 0.9341589014222658, 0.9362432564982834, 0.9459293771456596, 0.9464198136341344, 0.9447032859244728, 0.9551250613045611, 0.9493624325649829, 0.9515693967631191, 0.9608876900441393, 0.9616233447768514, 0.9659146640510053, 0.9675085826385483, 0.9626042177538009, 0.9639529180971064, 0.974619911721432, 0.9749877390877881, 0.9738842569887199, 0.9757233938205002, 0.9756007846983815, 0.9830799411476214, 0.9767042667974497, 0.9838155958803335, 0.9820990681706719, 0.9822216772927905, 0.9845512506130456, 0.987126042177538, 0.9823442864149092, 0.9897008337420304, 0.9874938695438941, 0.9865129965669446, 0.9915399705738107, 0.9874938695438941, 0.9903138793526238, 0.9908043158410986, 0.9926434526728789, 0.9921530161844041, 0.9936243256498284, 0.9927660617949976, 0.9937469347719471, 0.9939921530161844, 0.9943599803825405, 0.9937469347719471, 0.9931338891613536, 0.9944825895046592, 0.9964443354585582, 0.9935017165277097, 0.9959538989700834, 0.9963217263364395, 0.9950956351152526, 0.9973025993133889, 0.9949730259931339, 0.9968121628249141, 0.9976704266797449, 0.9976704266797449, 
0.9949730259931339, 0.9986512996566944, 0.9985286905345758, 0.9984060814124571, 0.9970573810691515, 0.9954634624816087, 0.9991417361451692, 0.9964443354585582, 0.9988965179009318, 0.9966895537027954, 0.9976704266797449, 0.9998773908778813, 0.9996321726336439, 0.9982834722903384, 0.9985286905345758, 0.9997547817557626, 0.9992643452672879, 0.9996321726336439, 0.9979156449239823, 0.9998773908778813, 1.0, 0.9985286905345758, 0.9974252084355076, 1.0, 0.9985286905345758, 0.9997547817557626, 0.9995095635115253, 0.9990191270230505, 0.9998773908778813, 0.9986512996566944, 0.9988965179009318, 0.9981608631682197, 0.9987739087788131, 1.0, 0.9998773908778813, 0.9984060814124571, 0.9998773908778813, 0.9997547817557626, 0.9997547817557626, 0.9976704266797449, 0.9979156449239823, 1.0, 0.9996321726336439, 1.0, 1.0, 0.9998773908778813, 0.9998773908778813, 0.9992643452672879, 1.0, 0.9981608631682197, 0.9998773908778813, 1.0, 1.0, 0.9985286905345758, 0.9997547817557626, 0.9997547817557626, 1.0, 0.9998773908778813, 0.9992643452672879, 0.9996321726336439, 0.9981608631682197, 0.9998773908778813, 0.9981608631682197, 1.0, 0.9998773908778813, 0.9998773908778813, 1.0, 0.9984060814124571, 0.9997547817557626, 0.9998773908778813, 1.0, 0.9997547817557626, 0.9977930358018636, 0.9991417361451692, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9997547817557626, 1.0, 0.9981608631682197, 0.9995095635115253, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9990191270230505, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9958312898479647, 0.9997547817557626, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9998773908778813, 0.9996321726336439, 1.0, 1.0, 1.0, 0.9995095635115253, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9979156449239823, 0.9996321726336439, 1.0], "train_duration": 5.51313775062561, "test_accuracies": 0.7716431095406361, "nb_spikes": [63.36234368218316, 52.71093591054281, 58.357414139641655], "test_duration": 1.2544794082641602, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "SPSN-SB", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, 
"hidden_size": 128, "nb_layers": 3, "reg_thr": 0.0, "loss_mode": "mean", "data_augmentation": false, "h_shift": 0.1, "scale": 0.3, "dir": "basic/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/data_aug/results_Non-Spiking_1675889858.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.046870894558351375, 0.04513777962731871, 0.04277407123271948, 0.04143538803607591, 0.04047747907128739, 0.03990462876933062, 0.03940175306335157, 0.03879221307699326, 0.03838457296501485, 0.0378855899480583, 0.03748356091619532, 0.0367999798588568, 0.03647782584508884, 0.03596314441462017, 0.03565480245451766, 0.035357435904863475, 0.03480003624350607, 0.0343530820742023, 0.03384571551341178, 0.033483901521983016, 0.03303338074988159, 0.032726279963694, 0.03251601455257242, 0.03211271900239676, 0.03166644959954902, 0.03118491605428459, 0.031065986371613296, 0.030813886114871635, 0.03045448608232398, 0.030111483758193947, 0.030010466045933416, 0.029464231614916857, 0.029660283686075217, 0.029127750036002025, 0.02890721903296298, 0.02895098100286187, 0.028828214260221054, 0.028408601785068596, 0.027893587121079515, 0.02768682885427227, 0.02775059340045755, 0.027461732994872836, 0.027258394131653447, 0.026790603637227597, 0.026755619306316442, 0.026722863703863828, 0.026161102730599965, 0.026326902462840956, 0.025934399548203646, 0.025925550787047814, 0.025788594806469087, 0.025706844434485593, 0.025327526336209688, 0.025328943941627078, 0.025134369606712156, 0.025150318073255868, 0.02474839492075697, 0.024447581781482743, 0.024367365885622775, 0.02435846801421524, 0.024123687203927855, 0.02422247930096902, 0.023744995948320045, 0.023908579828927422, 0.023594058512939782, 0.023678949836307205, 0.02379023090421481, 0.023728334813096934, 0.023464069376978702, 0.023276134631164874, 0.023161954903731224, 
0.023152985281661306, 0.02276062440673413, 0.02274828298586499, 0.022594730020329436, 0.0228018009265069, 0.02250315894269078, 0.02244190908986719, 0.022646123387171625, 0.02249580099922936, 0.022155848728669162, 0.02218919320104168, 0.022087516198268178, 0.021896100602587277, 0.022181985160777122, 0.02184658055798922, 0.02176839390942254, 0.021400935783463405, 0.021572677629610206, 0.021533202685053994, 0.021662698694161774, 0.021275477228999546, 0.02131800164247974, 0.021271777962749643, 0.021137389322891723, 0.021367493594844533, 0.02077683994236853, 0.02092037592201972, 0.021149338314847773, 0.020893463306768435, 0.020871832155608384, 0.020610875069598113, 0.020526694975219677, 0.020379713407154467, 0.020624378897969545, 0.02013100270052176, 0.020334782605956968, 0.020258935790485235, 0.020112090890696377, 0.020201439247112637, 0.020159232179401784, 0.01997367484479514, 0.01992778132188665, 0.019803477983781084, 0.019941411272984853, 0.019917979281050366, 0.0196989065820882, 0.01964344750759121, 0.019478772918987416, 0.01944057706262973, 0.019377224297135294, 0.019447768228553333, 0.019417951830816713, 0.01922030702357552, 0.01895315004599451, 0.019582334140342144, 0.019181470833846668, 0.01930227011895285, 0.019163572237209808, 0.019044181488202215, 0.01899673731640427, 0.019143497632554258, 0.018923351611501743, 0.018987722756992433, 0.019038805289558477, 0.01862997579597971, 0.018689779644213568, 0.018714923333045957, 0.01869161073310285, 0.018601602406522817, 0.01845439762797877, 0.018516962641650054, 0.018510118704206983, 0.018277289938727917, 0.018818356904286623, 0.018469837328392597, 0.018366343611709742, 0.018540962397789126, 0.018398434413362214, 0.0186257594616288, 0.018362233573283093, 0.01829420634494686, 0.01804081635945214, 0.01815244708212057, 0.018049194641473424, 0.017866792253210827, 0.018224402940583147, 0.018077241498913, 0.01793812435301219, 0.017831649651359027, 0.017956570382794076, 0.017896622090494006, 0.01793647006102904, 
0.017957751089711583, 0.017758886864689297, 0.01798399717214939, 0.017754978901793877, 0.01755369820391331, 0.017673173310009505, 0.01740346083078623, 0.01763049602625474, 0.017477871967975622, 0.017521049365101637, 0.017524048057533236, 0.017374398955064055, 0.01747622481597996, 0.01737373777906821, 0.01733680701594893, 0.01730477627934984, 0.017381950036283216, 0.01708248760432458, 0.017289493205560193, 0.01743434349832259, 0.01696512205393774, 0.016993854554802134, 0.017104557766056107, 0.017301981986416737, 0.01717582497016772, 0.017227470801705645, 0.016959486447928756, 0.017115110060992107, 0.016912552693885235, 0.01708119171486584, 0.016826652527557044, 0.017165037223730327, 0.016862838956115875, 0.01685853331711083, 0.01676321930748508, 0.01699561610006713, 0.016723544465203445], "train_accuracies": [0.05750367827366356, 0.10250122609122118, 0.13462481608631682, 0.1602501226091221, 0.1679744973025993, 0.17520843550760176, 0.18477194703285924, 0.19801373222167729, 0.19997547817557626, 0.2132172633643943, 0.21910250122609123, 0.2362677783227072, 0.24350171652770966, 0.24938695438940658, 0.2534330554193232, 0.2594409024031388, 0.27035311427170183, 0.27476704266797447, 0.2816331535066209, 0.28898970083374204, 0.30002452182442374, 0.3030897498773909, 0.30689063266307015, 0.31093673369298674, 0.3249141736145169, 0.3304315841098578, 0.3214811181951937, 0.3412211868563021, 0.34502206964198134, 0.34955860716037274, 0.3526238352133399, 0.36611083864639526, 0.37273173124080433, 0.37616478666012754, 0.3841343795978421, 0.38511525257479157, 0.38352133398724864, 0.3908778813143698, 0.4010544384502207, 0.41012751348700344, 0.4079205492888671, 0.40890142226581655, 0.42447278077488965, 0.4261893084845513, 0.4306032368808239, 0.42533104462972043, 0.43795978420794507, 0.43881804806277586, 0.44249632172633646, 0.4368563021088769, 0.44347719470328595, 0.4439676311917607, 0.4538989700833742, 0.459048553212359, 0.4536537518391368, 0.45782246199117216, 0.47314860225600786, 
0.4639529180971064, 0.47081902893575284, 0.47314860225600786, 0.4778077488965179, 0.47756253065228055, 0.48418342324668956, 0.49092692496321727, 0.4901912702305051, 0.4925208435507602, 0.4932564982834723, 0.49386954389406573, 0.4954634624816086, 0.5007356547327121, 0.5004904364884747, 0.5026974006866111, 0.5115252574791564, 0.5088278567925454, 0.517165277096616, 0.5154487493869544, 0.5180235409514468, 0.5165522314860226, 0.5111574301128003, 0.5153261402648357, 0.5258705247670427, 0.532000980872977, 0.5267287886218734, 0.5342079450711132, 0.5242766061794998, 0.5344531633153506, 0.5342079450711132, 0.5541932319764591, 0.5376410004904365, 0.5411966650318784, 0.5359244727807749, 0.5418097106424719, 0.5459784207945071, 0.5492888670917117, 0.5521088769004414, 0.5483079941147622, 0.5592202059833251, 0.5540706228543404, 0.5593428151054438, 0.5532123589995096, 0.5589749877390878, 0.562408043158411, 0.5655958803334968, 0.5689063266307013, 0.5701324178518882, 0.579941147621383, 0.568661108386464, 0.5723393820500245, 0.57368808239333, 0.5749141736145169, 0.575282000980873, 0.5766307013241785, 0.5730750367827366, 0.5758950465914664, 0.5728298185384992, 0.582270720941638, 0.5809220205983325, 0.583864639529181, 0.5893820500245218, 0.5902403138793526, 0.5850907307503679, 0.5857037763609613, 0.5866846493379108, 0.5881559588033349, 0.5940411966650319, 0.5805541932319764, 0.5933055419323198, 0.5937959784207945, 0.5961255517410495, 0.5977194703285924, 0.5914664051005395, 0.5937959784207945, 0.5940411966650319, 0.5952672878862187, 0.5947768513977439, 0.6075282000980873, 0.5995586071603727, 0.6045855811672388, 0.6065473271211378, 0.6102256007846983, 0.6075282000980873, 0.6120647376164786, 0.613903874448259, 0.6113290828837665, 0.5985777341834232, 0.6191760666993624, 0.6194212849435998, 0.6051986267778323, 0.6083864639529181, 0.6112064737616478, 0.6130456105934281, 0.6083864639529181, 0.6202795487984306, 0.6177047572339383, 0.6166012751348701, 0.6222412947523296, 0.6121873467385973, 
0.616969102501226, 0.6232221677292791, 0.6233447768513978, 0.6190534575772437, 0.6266552231486022, 0.6207699852869053, 0.6202795487984306, 0.6232221677292791, 0.6206473761647867, 0.6253065228052967, 0.6311917606669937, 0.6300882785679255, 0.6297204512015694, 0.6250613045610593, 0.6318048062775871, 0.6253065228052967, 0.6348700343305542, 0.6322952427660617, 0.631436978911231, 0.6273908778813144, 0.63805787150564, 0.6369543894065719, 0.6264100049043649, 0.6397743992153017, 0.6385483079941148, 0.6284943599803825, 0.6408778813143697, 0.6367091711623345, 0.6418587542913192, 0.631069151544875, 0.6293526238352133, 0.6390387444825895, 0.6371996076508092, 0.640019617459539, 0.6416135360470819, 0.6417361451692006, 0.6412457086807258, 0.6292300147130947, 0.6387935262383522, 0.6401422265816576, 0.6452918097106425, 0.6345022069641981, 0.6500735654732712], "train_duration": 6.733256921768189, "test_accuracies": 0.6607773851590106, "nb_spikes": [], "test_duration": 0.30144286155700684, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "Non-Spiking", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.0, "loss_mode": "mean", "data_augmentation": true, "h_shift": 0.1, "scale": 0.3, "dir": "data_aug/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/basic/results_Non-Spiking_1675888688.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.04679242850753126, 0.04342181253456613, 0.038229585629342054, 0.035060111011226365, 0.03297411870757641, 0.03153584712623907, 0.030592488228193156, 0.02984992676707329, 0.029192007409409602, 0.028586722007857168, 0.028038259593805535, 0.02752940841256666, 0.026988956143658437, 0.026500849863546276, 0.025882917008252633, 0.025345538003354637, 0.02477509242287449, 0.024265399424687153, 0.02379699642603277, 
0.02332649156707008, 0.022843754806022774, 0.022259120586866722, 0.02193885519844811, 0.02135551788457297, 0.020963203679001056, 0.020548146808188356, 0.020023198864985003, 0.01960364022840759, 0.019200026594347202, 0.01874336178627828, 0.01856117516887603, 0.01812053756453115, 0.017777383473645596, 0.017439160168960202, 0.01715031331985122, 0.016786185095507822, 0.01648298790929597, 0.016222240367667708, 0.015886560411298902, 0.015645820927011903, 0.015524400324140944, 0.015223374498013248, 0.014947686105045752, 0.014692819596447741, 0.01451741685143518, 0.014363102165433109, 0.014247225145313774, 0.014018132120677804, 0.0137658773785774, 0.013645446011106893, 0.013386322250256297, 0.013334789598847558, 0.013152607103198577, 0.012939757338161852, 0.012764839282832115, 0.012692231460199922, 0.012484374687682185, 0.012149205514471456, 0.012116605040488025, 0.012098390782592693, 0.01205162152231412, 0.011784658778701827, 0.011627739758512563, 0.011539365239386585, 0.011505858944349865, 0.01133237628594165, 0.011075825186089128, 0.011075835161609005, 0.010853469265828593, 0.010883084125177366, 0.010708108379918023, 0.010572067969362194, 0.010452889972980727, 0.010234549937586155, 0.010209539876871685, 0.010030991963716977, 0.010085675194896044, 0.010021313630289281, 0.009852644401388229, 0.00975276578022254, 0.009721661400362345, 0.009626293705631086, 0.009499177033407774, 0.009246158283413948, 0.00952462164311505, 0.009275075484865608, 0.009197202637652781, 0.009037432113279602, 0.008947227985149157, 0.008731784131843825, 0.008699298029205505, 0.008668025859639363, 0.00867734136185031, 0.00858519254448719, 0.008334501116167042, 0.008303825738472118, 0.008210603284771504, 0.008170341593252204, 0.008102049246291545, 0.007965019750150761, 0.00794928055028228, 0.007769216040205055, 0.007823314044538697, 0.007654688404640069, 0.007623682533952171, 0.007478380368409524, 0.007543085748626649, 0.007477419521219301, 0.007319938970611866, 0.007281119368382094, 
0.007241717911982665, 0.00716294184173656, 0.007083639915431463, 0.006943946043801343, 0.007021049228499133, 0.006994363492379415, 0.00696091121679255, 0.0067108754085056215, 0.006807056204938842, 0.006551272465558307, 0.006502678573073096, 0.006653256725306134, 0.006557592158372523, 0.006431950892944322, 0.006441267840326738, 0.006366615147567836, 0.006242265247901086, 0.006166294810890626, 0.006137830744688858, 0.006083055755170085, 0.006105450985477274, 0.0059992275810551565, 0.005893285298110691, 0.0059325131713291665, 0.0058189029693457356, 0.005793291423150054, 0.005766271987620261, 0.005715268809766732, 0.005742898387190406, 0.005517045102644458, 0.00552748458468195, 0.005482798461428222, 0.0054058834523064885, 0.005389189871065519, 0.005422921066572994, 0.0051867793706511095, 0.0052367514177477205, 0.005413765276820125, 0.005182241756960479, 0.005130333598482076, 0.005121094993644043, 0.005166396836678499, 0.0050143278302106865, 0.004945952545754637, 0.004911731369519362, 0.0049296552776676465, 0.004776617671970996, 0.004738059497788702, 0.00475755857889298, 0.004691101745859906, 0.004705627531529409, 0.004699170926531896, 0.004670409831917713, 0.0046366539835389076, 0.004597785286587088, 0.004436164660458766, 0.004396550448692681, 0.004330395724096387, 0.004517302862558077, 0.004503456914962771, 0.004381581033749461, 0.004151889382002399, 0.004215031632054606, 0.004133530185832411, 0.0042567656481891595, 0.004123278602577706, 0.004221697166642469, 0.004070894937010417, 0.004010686936800593, 0.003926593576461559, 0.003980162869106618, 0.003929157221124358, 0.003886193379599011, 0.003840721962775244, 0.003805065741319467, 0.0038612796529915593, 0.0036305637866653256, 0.0036526688819001154, 0.003656516711228966, 0.0037440309943215645, 0.003699304960766746, 0.0036604051844067523, 0.003516997409369678, 0.00344013191427069, 0.003426064662690137, 0.0034675228757972865, 0.00341798295963477, 0.0033631526123752424, 0.0033950177699605174, 0.003290767263020003], 
"train_accuracies": [0.0614271701814615, 0.1272682687591957, 0.2140755272192251, 0.26152525747915645, 0.2842079450711133, 0.31436978911230995, 0.3284698381559588, 0.33766552231486024, 0.34857773418342325, 0.36316821971554686, 0.37408043158410986, 0.38621873467385975, 0.39602746444335457, 0.402893575282001, 0.41932319764590487, 0.42778322707209415, 0.44445806768023544, 0.45855811672388425, 0.47045120156939674, 0.4803825404610103, 0.49190779794016676, 0.5014713094654242, 0.5082148111819519, 0.5275870524767042, 0.532368808239333, 0.5470819028935753, 0.5646150073565473, 0.5681706718979892, 0.5817802844531633, 0.587788131436979, 0.5881559588033349, 0.6055664541441883, 0.6098577734183424, 0.6146395291809711, 0.6217508582638548, 0.6322952427660617, 0.6368317802844532, 0.6370769985286905, 0.650931829328102, 0.6495831289847964, 0.6533840117704757, 0.6607405590975969, 0.6625796959293772, 0.6728788621873467, 0.6768023540951447, 0.6799901912702305, 0.6841589014222658, 0.6799901912702305, 0.691883276115743, 0.6940902403138793, 0.6976459048553212, 0.6959293771456596, 0.7059833251593919, 0.7182442373712604, 0.7091711623344776, 0.7203285924472781, 0.7233938205002453, 0.7311181951937225, 0.7320990681706719, 0.7303825404610103, 0.7274399215301618, 0.7401912702305051, 0.7405590975968612, 0.7446051986267779, 0.7420304070622854, 0.7442373712604218, 0.751961745953899, 0.7555174104953408, 0.7614026483570377, 0.7589504659146641, 0.760544384502207, 0.7656939676311918, 0.7692496321726336, 0.7741539970573811, 0.7742766061794998, 0.7867827366356057, 0.7763609612555175, 0.776483570377636, 0.787027954879843, 0.7871505640019617, 0.7867827366356057, 0.7914418832761158, 0.7935262383521334, 0.8010053948013732, 0.7904610102991663, 0.7976949485041687, 0.7996566944580676, 0.8034575772437469, 0.8055419323197646, 0.81436978911231, 0.8135115252574792, 0.8086071603727317, 0.8111819519372241, 0.818661108386464, 0.8189063266307013, 0.8241785188818048, 0.8212358999509564, 0.821726336439431, 
0.8255272192251103, 0.8320255026974007, 0.8288376655223149, 0.8348455125061305, 0.8354585581167239, 0.8369298675821482, 0.8393820500245218, 0.8395046591466405, 0.8371750858263854, 0.838523786169691, 0.8450220696419813, 0.8451446787641, 0.8441638057871506, 0.851765571358509, 0.8507846983815596, 0.8544629720451201, 0.8522560078469839, 0.8515203531142717, 0.8527464443354585, 0.8581412457086808, 0.8589995095635116, 0.8614516920058852, 0.8601029916625796, 0.8593673369298676, 0.8654977930358019, 0.8624325649828347, 0.8645169200588524, 0.8697891123099558, 0.8734673859735165, 0.8738352133398725, 0.8721186856302109, 0.875183913683178, 0.872486512996567, 0.8784943599803825, 0.8804561059342815, 0.8803334968121628, 0.8788621873467386, 0.8807013241785189, 0.8816821971554684, 0.8798430603236881, 0.8787395782246199, 0.889651790093183, 0.8864639529180971, 0.8890387444825895, 0.8886709171162335, 0.8886709171162335, 0.8898970083374204, 0.8934526728788622, 0.8913683178028445, 0.8914909269249632, 0.8955370279548799, 0.8986022560078469, 0.8951692005885238, 0.8911230995586071, 0.898234428641491, 0.9004413928396273, 0.8999509563511525, 0.9020353114271702, 0.9024031387935263, 0.9022805296714076, 0.9024031387935263, 0.9058361942128494, 0.9066944580676802, 0.9047327121137813, 0.9060814124570868, 0.9092692496321726, 0.9068170671897989, 0.9120892594409024, 0.9136831780284453, 0.9140510053948013, 0.9066944580676802, 0.907184894556155, 0.912211868563021, 0.9158901422265816, 0.9142962236390387, 0.9182197155468367, 0.916135360470819, 0.918097106424718, 0.9144188327611574, 0.9192005885237862, 0.9209171162334477, 0.9234919077979402, 0.9183423246689554, 0.9236145169200588, 0.9228788621873467, 0.9249632172633644, 0.9275380088278568, 0.922756253065228, 0.9291319274153997, 0.931338891613536, 0.9301128003923492, 0.9268023540951447, 0.9268023540951447, 0.924717999019127, 0.9331780284453163, 0.9318293281020108, 0.9350171652770967, 0.9325649828347229, 0.9326875919568416, 0.9361206473761647, 
0.9328102010789603, 0.9383276115743011], "train_duration": 1.506732382774353, "test_accuracies": 0.7181978798586572, "nb_spikes": [], "test_duration": 0.30219149589538574, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "Non-Spiking", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.0, "loss_mode": "mean", "data_augmentation": false, "h_shift": 0.1, "scale": 0.3, "dir": "basic/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/data_aug/results_LIF_1675940888.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.04671013334792522, 0.04414309594138923, 0.04129482034727193, 0.0389302469106677, 0.03652428464014896, 0.034307186936209634, 0.03164662690417564, 0.029509880748550467, 0.027822951202944483, 0.02655038526971414, 0.02486587685951712, 0.024073919006397233, 0.023088758046747482, 0.022235268029937444, 0.021531784670384507, 0.02083961786213548, 0.020032174055806675, 0.019706980424397363, 0.01914468517982825, 0.01894381033579119, 0.018364270925229527, 0.01811775049197434, 0.017477595678961354, 0.017226358505661083, 0.017361732544181042, 0.016910068935597277, 0.016613586655488607, 0.016492775852859107, 0.015944066425934326, 0.015941447227126305, 0.015583227258385719, 0.015277117454228066, 0.015391513668713235, 0.014832939417295564, 0.014742166987817155, 0.014506776742691968, 0.01452737785747906, 0.014261590157027569, 0.014423308980821569, 0.01405399747634998, 0.01397507923850246, 0.013964095650320254, 0.013664605615769257, 0.013418920590282364, 0.013341909759532242, 0.013134861026048777, 0.013352568686563867, 0.013169062764636438, 0.012825654915526195, 0.012842564415615992, 0.012970404225333524, 0.01288703629174263, 0.01265683918533634, 0.012498370736705842, 0.01234030005918992, 0.012010586687809934, 0.012061610843487497, 
0.012075432160135693, 0.0120207257148622, 0.01163112027584308, 0.011872532307487776, 0.011534479120129636, 0.011387844776890919, 0.01151527736250123, 0.011010437937066974, 0.011582353441460335, 0.01117097923660933, 0.01104524288036006, 0.011149368690561582, 0.01105303508779943, 0.010686287149428854, 0.010847738059260438, 0.01067177159506179, 0.010608731161998498, 0.010791894426645163, 0.01037598472561773, 0.010644491756441781, 0.010514788117667167, 0.010329057621920794, 0.010167079186135498, 0.010372717176482278, 0.010254859372556883, 0.01017465929502129, 0.009967945980172113, 0.009784045330942584, 0.010120602829775546, 0.009700633011769289, 0.009683498561879243, 0.009696601648404354, 0.009550692927404734, 0.009615611951832914, 0.009540707735292697, 0.009490964092887928, 0.009411705583845534, 0.009335720901573447, 0.009194681144129009, 0.009267472512672905, 0.00913403164951634, 0.008876471425342232, 0.009234617600813982, 0.009079572041722241, 0.008972607381265499, 0.009078433948179388, 0.008997224866905419, 0.008881543980853822, 0.008840027211810398, 0.00886729506153637, 0.00889474249619839, 0.008734440934342341, 0.008469908980103672, 0.008601969856149487, 0.008818317820418985, 0.008534058066137051, 0.008465640805935731, 0.008471779214686889, 0.008677597023828056, 0.008236926175058326, 0.008301922939114853, 0.008159455333988475, 0.008258876586702888, 0.008214813456589712, 0.00821461226350832, 0.008285276515043034, 0.008327822845434664, 0.007925574074427599, 0.008127516763154171, 0.008225410288750628, 0.007678590185096183, 0.00796650882486972, 0.007950321526840776, 0.007827474879554933, 0.007958977965089023, 0.007762526131366152, 0.007839537247131127, 0.0076253024377970205, 0.0077216811313706325, 0.007592621896563938, 0.00765859623964421, 0.0075856077065942565, 0.00754138195540988, 0.0074641563440433255, 0.007625041947009927, 0.007465285413942833, 0.0071615799658581464, 0.007224652304141881, 0.0073317631479862455, 0.007524734139150121, 0.007192562521965613, 
0.00730243732293135, 0.00725804163712155, 0.007348214490352863, 0.007222819855950988, 0.0071137830041359896, 0.00687416266903917, 0.0072279403926169775, 0.007112231779735774, 0.0071122475304604745, 0.006842966353618496, 0.007128359133294542, 0.006937278793643303, 0.007020234316222044, 0.00677836219752285, 0.0070728521087024656, 0.006718461467989992, 0.00693582613687997, 0.0069733682111848385, 0.006955515936769768, 0.006827432205494038, 0.006735757367832512, 0.006880501434097341, 0.006608779216482103, 0.006695945612454075, 0.006731071751044041, 0.006654428213089339, 0.006650703021449865, 0.006532045253911049, 0.006782952299622007, 0.0065653719676657596, 0.006814785838828712, 0.006624915275783034, 0.0065852784738492706, 0.006780824950581499, 0.006531123229032524, 0.00660568041989124, 0.006418870035750776, 0.0064027811372730295, 0.006128602508179635, 0.006373567715864547, 0.006448727117530735, 0.0062114052426996275, 0.006329829641537902, 0.006150927668681269, 0.0060961974648619, 0.006157759520398337, 0.006086896953336043, 0.006299036362698524, 0.00615026111394357, 0.006255323280155103, 0.006262874905825714, 0.005905382324223895], "train_accuracies": [0.06204021579205493, 0.12653261402648358, 0.18894065718489456, 0.2584600294261893, 0.2942618930848455, 0.33398724865129964, 0.38597351642962235, 0.4283962726826876, 0.4659146640510054, 0.48320255026974007, 0.5143452672878862, 0.5279548798430603, 0.5483079941147622, 0.5507601765571358, 0.579573320255027, 0.5850907307503679, 0.6015203531142717, 0.6104708190289357, 0.6157430112800393, 0.6239578224619912, 0.633030897498774, 0.6389161353604709, 0.6596370769985287, 0.6598822952427661, 0.6506866110838646, 0.6593918587542913, 0.6656449239823443, 0.6628249141736146, 0.681338891613536, 0.6824423737126042, 0.6904119666503188, 0.6884502206964198, 0.6907797940166749, 0.7096615988229524, 0.7010789602746444, 0.7122363903874448, 0.7154242275625307, 0.7115007356547327, 0.7065963707699853, 0.7192251103482099, 0.7192251103482099, 
0.7191025012260912, 0.7248651299656694, 0.7286660127513487, 0.7316086316821971, 0.7355321235899951, 0.737126042177538, 0.7340608141245709, 0.7388425698871997, 0.7416625796959294, 0.7363903874448259, 0.7422756253065228, 0.7497547817557626, 0.7488965179009318, 0.7544139283962726, 0.7556400196174595, 0.7614026483570377, 0.7584600294261893, 0.7596861206473762, 0.769127023050515, 0.7669200588523786, 0.7692496321726336, 0.775380088278568, 0.7710887690044139, 0.7834722903384012, 0.7658165767533105, 0.7785679254536537, 0.7777096615988229, 0.7789357528200098, 0.7817557626287396, 0.7908288376655224, 0.7785679254536537, 0.7896027464443355, 0.7898479646885729, 0.7872731731240804, 0.7957332025502697, 0.7903384011770476, 0.7949975478175576, 0.7936488474742521, 0.799166257969593, 0.7929131927415399, 0.7965914664051006, 0.8006375674350171, 0.7986758214811182, 0.8073810691515448, 0.8041932319764591, 0.8095880333496812, 0.8088523786169691, 0.8097106424717999, 0.8119176066699363, 0.8094654242275625, 0.8097106424717999, 0.8127758705247671, 0.8115497793035802, 0.8187837175085826, 0.8196419813634135, 0.8198871996076508, 0.8178028445316331, 0.8273663560568907, 0.8193967631191761, 0.8201324178518882, 0.8231976459048553, 0.821726336439431, 0.8247915644923982, 0.8285924472780775, 0.8283472290338402, 0.830308974987739, 0.8238106915154487, 0.8301863658656204, 0.8343550760176557, 0.8288376655223149, 0.8269985286905346, 0.8320255026974007, 0.8354585581167239, 0.8344776851397744, 0.829205492888671, 0.8354585581167239, 0.8359489946051987, 0.8397498773908779, 0.8375429131927415, 0.8370524767042667, 0.8386463952918097, 0.8374203040706228, 0.8310446297204512, 0.8420794507111329, 0.8399950956351152, 0.8409759686120647, 0.8496812162824914, 0.8419568415890142, 0.8397498773908779, 0.8487003433055419, 0.8382785679254536, 0.8469838155958803, 0.8500490436488475, 0.8495586071603727, 0.8473516429622364, 0.8534820990681706, 0.8494359980382541, 0.8550760176557136, 0.8554438450220696, 0.8566699362432565, 
0.8516429622363904, 0.8482099068170672, 0.8580186365865621, 0.8577734183423247, 0.8574055909759686, 0.8515203531142717, 0.8548307994114762, 0.8523786169691026, 0.8569151544874939, 0.8561794997547818, 0.8543403629230014, 0.8610838646395291, 0.8616969102501226, 0.8559342815105444, 0.859612555174105, 0.8564247179990191, 0.863903874448259, 0.8540951446787641, 0.8650073565473271, 0.8608386463952918, 0.870892594409024, 0.8610838646395291, 0.8679499754781755, 0.8625551741049534, 0.8636586562040216, 0.8625551741049534, 0.8631682197155468, 0.866233447768514, 0.863903874448259, 0.8739578224619912, 0.8699117214320745, 0.8648847474252085, 0.8707699852869053, 0.8677047572339383, 0.8677047572339383, 0.8646395291809711, 0.870892594409024, 0.8666012751348701, 0.8683178028445316, 0.8684404119666503, 0.8632908288376655, 0.8692986758214811, 0.8679499754781755, 0.8688082393330063, 0.8726091221186856, 0.8788621873467386, 0.8727317312408043, 0.8733447768513978, 0.8787395782246199, 0.8711378126532614, 0.8803334968121628, 0.8776360961255517, 0.8782491417361452, 0.8762873957822462, 0.8735899950956351, 0.879107405590976, 0.8711378126532614, 0.874448258950466, 0.8821726336439432], "train_duration": 261.7646133852005, "test_accuracies": 0.8303886925795053, "nb_spikes": [5.703311761220296, 6.492319491174486, 11.45410230424669], "test_duration": 17.770599603652954, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "LIF", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.0, "loss_mode": "mean", "data_augmentation": true, "h_shift": 0.1, "scale": 0.3, "dir": "data_aug/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/regul/results_SPSN-GS_1675891157.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [24.987427431324523, 0.8481050060332085, 
0.2838696423203645, 0.12181497897629413, 0.07443603029901225, 0.053360947970960464, 0.050109635766316066, 0.04924150200788153, 0.0484842824666976, 0.04824879810820613, 0.04789857165143442, 0.04745814875097121, 0.04728725682642601, 0.04718887697779238, 0.04689448224147667, 0.046418801951256375, 0.04606720964484147, 0.04545619944955509, 0.04493889158528596, 0.04433138172318972, 0.04349870521569731, 0.04258746306296363, 0.04170863671651244, 0.04056676764766512, 0.039418324265426255, 0.03816691004452371, 0.036741508360058725, 0.03553214182509927, 0.03453562882614229, 0.03351412633869682, 0.03269854933388856, 0.03171607843212429, 0.030850294029204035, 0.02992378123371317, 0.02895654105801042, 0.028233018165245073, 0.027153314563092674, 0.026489310313697653, 0.025836095952005467, 0.024784620615592945, 0.02397546541930531, 0.023502415436164253, 0.022592550577516354, 0.022020393197825283, 0.021782851930983574, 0.021276619868163912, 0.02075884444798951, 0.020257375787262125, 0.019865512767511587, 0.019316747974624467, 0.018704862842202012, 0.018416914272214806, 0.01800554110876423, 0.017571677401289162, 0.017190094012264647, 0.017254337098389177, 0.016864171751344127, 0.016429609794018023, 0.01630386548658514, 0.015624406599074733, 0.015155720631827204, 0.014993306152548142, 0.0146515097628042, 0.014577647790773175, 0.014226785348933839, 0.014041909422723611, 0.013868952247662658, 0.013230564490201838, 0.012908632386075685, 0.012735287411765996, 0.012767733294423624, 0.012201513219926215, 0.012167377648633038, 0.01192014822631329, 0.011636634900354766, 0.011303719957123906, 0.01142584472938517, 0.010884673294920499, 0.010897370860119436, 0.010719098413820767, 0.010546779357784808, 0.010243206660849257, 0.010196960742652503, 0.010097039712738207, 0.009621575807256404, 0.009686420603585395, 0.009502325106509004, 0.00922288314830795, 0.00929298370988887, 0.009117716920615528, 0.008945527243052937, 0.008806370721253886, 0.008979266881211942, 0.008215613462224552, 
0.008285407976317254, 0.008265814471253581, 0.008228882069298893, 0.008403390379236163, 0.00798793785299607, 0.007981681619835227, 0.007722362163401281, 0.007818248292843229, 0.007534759189493227, 0.007815887032477747, 0.007581538275476178, 0.007357725677185049, 0.007451790477669432, 0.007057711021639425, 0.007378102949344978, 0.007062020943175349, 0.006986111735028926, 0.00694162450464056, 0.006818871967066965, 0.006955136408432843, 0.00686784608398307, 0.006856321060756891, 0.006444287294994212, 0.0067532050106793185, 0.006564226189683968, 0.006374492771766419, 0.006470233265926231, 0.006514745599545937, 0.006270964825895851, 0.006307922274444312, 0.006049615040954975, 0.006102416642437734, 0.006051098172383427, 0.006063963894518992, 0.006043258349193318, 0.005859431149753184, 0.005860622120859461, 0.006012114777963731, 0.0058481089493087135, 0.005684592001516484, 0.005764326211311702, 0.005751655262115248, 0.005765539134042762, 0.0056107840324921485, 0.005445706359804583, 0.00547471122159392, 0.005640772738150957, 0.005532557178560924, 0.005486327944695219, 0.005376881247001749, 0.005265841535518659, 0.005364825385437695, 0.005229465213986575, 0.005361995865764894, 0.005239069134786236, 0.0052756017951871206, 0.005113837081261158, 0.004952728636873418, 0.005082438352310874, 0.005295906765838649, 0.005036732215706539, 0.005159236578423349, 0.004893022890386773, 0.0047838373284178545, 0.004807073675925502, 0.004907728900448723, 0.004704318032611405, 0.005077485115709932, 0.004721358727230752, 0.00492280880797722, 0.004732495814181824, 0.004690266106060289, 0.004762595189646683, 0.004713046020550024, 0.004710033338639361, 0.004772861718824531, 0.004563920696820039, 0.004696277517856424, 0.00469032109382984, 0.004671666570201465, 0.004544763997116763, 0.004555786864365241, 0.004551220040306149, 0.004360990011411417, 0.004532561604118102, 0.004470718710748602, 0.004504382092361359, 0.004312385179654181, 0.004591310267985715, 0.004571447345323572, 
0.004329848898434943, 0.004246684566849058, 0.00443764402366199, 0.004293585962571602, 0.004208716368846241, 0.004247035093554362, 0.00427901767680258, 0.004301819265210318, 0.004295944034790162, 0.004167693711418974, 0.004306059650430029, 0.004127498256620851, 0.004204530923893605, 0.004176815622771645, 0.004142609545015541, 0.0041336482185221814], "train_accuracies": [0.049901912702305054, 0.04659146640510054, 0.05063756743501716, 0.04536537518391368, 0.04548798430603237, 0.05333496812162825, 0.05161844041196665, 0.053948013732221675, 0.05529671407552722, 0.05468366846493379, 0.05321235899950957, 0.05872976949485042, 0.06817067189798921, 0.07797940166748406, 0.07724374693477194, 0.08631682197155469, 0.09686120647376165, 0.11316821971554683, 0.11194212849435999, 0.12420304070622855, 0.13879352623835214, 0.1503187837175086, 0.15890142226581658, 0.18109367336929869, 0.21174595389897008, 0.23406081412457086, 0.24963217263364396, 0.27378616969102504, 0.28469838155958804, 0.3087297694948504, 0.3247915644923982, 0.34048553212359, 0.3686856302108877, 0.38339872486513, 0.39492398234428644, 0.4146640510053948, 0.4318293281020108, 0.4507111329082884, 0.45696419813634137, 0.4795242766061795, 0.49877390877881317, 0.5144678764100049, 0.5321235899950957, 0.5361696910250122, 0.5486758214811182, 0.5579941147621383, 0.5675576262873958, 0.5779794016674841, 0.5908533594899461, 0.5968612064737616, 0.6121873467385973, 0.6183178028445316, 0.6270230505149583, 0.6358509073075037, 0.6468857282981854, 0.644310936733693, 0.6602501226091221, 0.6624570868072585, 0.6699362432564983, 0.674717999019127, 0.6911476213830309, 0.6905345757724375, 0.706718979892104, 0.7062285434036293, 0.7074546346248161, 0.7241294752329573, 0.7192251103482099, 0.728175576262874, 0.7408043158410986, 0.7415399705738107, 0.7437469347719471, 0.7540461010299166, 0.7590730750367828, 0.7627513487003433, 0.7665522314860226, 0.775380088278568, 0.7720696419813634, 0.7849435998038254, 0.7810201078960275, 0.7864149092692496, 
0.7898479646885729, 0.793281020107896, 0.7965914664051006, 0.799166257969593, 0.8086071603727317, 0.8050514958312899, 0.8109367336929868, 0.812408043158411, 0.810446297204512, 0.8164541441883276, 0.8244237371260422, 0.8223393820500245, 0.8195193722412948, 0.8392594409024031, 0.8347229033840118, 0.8359489946051987, 0.8390142226581657, 0.8311672388425699, 0.8408533594899461, 0.8391368317802844, 0.8457577243746934, 0.8413437959784208, 0.8507846983815596, 0.8420794507111329, 0.8493133889161354, 0.8507846983815596, 0.8488229524276606, 0.860348209906817, 0.8549534085335949, 0.8553212358999509, 0.8592447278077489, 0.8645169200588524, 0.8632908288376655, 0.8604708190289357, 0.8613290828837665, 0.8646395291809711, 0.8712604217753801, 0.8621873467385973, 0.8661108386463953, 0.8712604217753801, 0.8652525747915645, 0.8699117214320745, 0.8723639038744483, 0.8721186856302109, 0.8787395782246199, 0.8742030407062286, 0.8793526238352133, 0.8762873957822462, 0.8807013241785189, 0.8829082883766552, 0.8838891613536047, 0.8792300147130947, 0.8832761157430112, 0.8868317802844532, 0.885728298185385, 0.8818048062775871, 0.8821726336439432, 0.8891613536047082, 0.8907552721922511, 0.8875674350171653, 0.8821726336439432, 0.8878126532614027, 0.8924717999019127, 0.8847474252084355, 0.896272682687592, 0.890019617459539, 0.896640510053948, 0.889651790093183, 0.8950465914664051, 0.8961500735654733, 0.8971309465424228, 0.9008092202059833, 0.8988474742520843, 0.8939431093673369, 0.8921039725355566, 0.8941883276115743, 0.9030161844041197, 0.9030161844041197, 0.8999509563511525, 0.8959048553212359, 0.9052231486022561, 0.8957822461991172, 0.907184894556155, 0.9004413928396273, 0.9039970573810692, 0.9030161844041197, 0.9036292300147131, 0.9058361942128494, 0.9046101029916626, 0.9025257479156449, 0.9054683668464933, 0.9041196665031879, 0.9068170671897989, 0.9025257479156449, 0.9091466405100539, 0.907552721922511, 0.9065718489455615, 0.9113536047081903, 0.9033840117704757, 0.9103727317312408, 
0.9097596861206474, 0.9091466405100539, 0.9032614026483571, 0.9062040215792055, 0.9084109857773418, 0.9117214320745464, 0.9119666503187838, 0.9104953408533595, 0.9149092692496321, 0.9087788131436979, 0.91417361451692, 0.9127023050514959, 0.9079205492888671, 0.9147866601275135, 0.9104953408533595, 0.9171162334477685, 0.9135605689063266, 0.9168710152035311, 0.9140510053948013, 0.9155223148602256], "train_duration": 12.23677794456482, "test_accuracies": 0.8939929328621908, "nb_spikes": [16.030665556589764, 11.549198653962877, 9.689412222968208], "test_duration": 1.2312428951263428, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "SPSN-GS", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.1, "loss_mode": "mean", "data_augmentation": true, "h_shift": 0.1, "scale": 0.3, "dir": "regul/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/data_aug/results_SPSN-SB_1675890736.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.08190464187101033, 0.04388385458633325, 0.04179765565422249, 0.039772562415883964, 0.038323244667333854, 0.03692093636254109, 0.03524230890498552, 0.034288910044473195, 0.03325304253069778, 0.032105579848029905, 0.03103848737849506, 0.030233363505377964, 0.029547779831364786, 0.028459994479568054, 0.027640075119995146, 0.026994696328194484, 0.02628616907539293, 0.02533991946900458, 0.025116044949409482, 0.024245684274918546, 0.023605168786570395, 0.02294090878384671, 0.022755036887846605, 0.022257824690099907, 0.021470459054766362, 0.020940830173768386, 0.02049733715907448, 0.020020400442990558, 0.019347960667437113, 0.019330490535822142, 0.01905394863298911, 0.01834795609766973, 0.018008286422990946, 0.01771635985684313, 0.017490305368253682, 0.01732512715109544, 0.017027035684431226, 0.016640546583205594, 
0.01636856541497968, 0.016159073687640402, 0.016046472694021395, 0.01561426175845026, 0.015308702127965073, 0.015039267932860041, 0.015072111867993062, 0.01457358123538189, 0.014531959135079863, 0.01470691040194812, 0.014336783857308164, 0.014139533642169714, 0.013764268067784611, 0.014019585542961802, 0.013563887030251649, 0.013549670606281547, 0.013620297886274094, 0.013179289821084076, 0.013029244847599345, 0.012646603556518872, 0.012815916323381174, 0.012633513788957114, 0.01260579931730146, 0.012179018842817347, 0.012522518802881825, 0.012414057277300125, 0.011952149244632834, 0.011954257532391495, 0.012083148908065545, 0.01152126319695478, 0.011444514758917793, 0.011539552497799458, 0.011727067910028826, 0.011279614895304492, 0.011210225175384683, 0.011119733785168104, 0.010732565193915262, 0.010676105355532395, 0.011157013914379317, 0.010453593784275447, 0.01047192239524571, 0.010592150492315025, 0.010603160553992292, 0.010283252289185283, 0.01023963798255883, 0.010161121798356763, 0.010228600910243128, 0.010024004722267112, 0.00992642595538209, 0.010031694850540443, 0.009560750242890886, 0.009611810065619557, 0.009479510825805186, 0.00941767764988923, 0.009234982620798648, 0.009373054924170255, 0.009301140352667051, 0.008965595559111526, 0.009251611348371287, 0.00907184851920856, 0.00918836429050122, 0.008752718290141424, 0.008718284885634272, 0.008859405214585411, 0.008616635850856794, 0.00856042225069132, 0.008496999338656561, 0.008431381973026428, 0.00841412433460099, 0.008261817238417369, 0.008394558923597603, 0.008086586214854122, 0.008053583168352042, 0.008112539886668244, 0.008014052667823594, 0.007762117255634394, 0.007842419174823648, 0.007778427782393133, 0.008054995742161055, 0.007791932093097393, 0.007795718489849667, 0.007487037259758308, 0.007373718902022889, 0.007435039531708114, 0.007463198094873115, 0.007504839651092938, 0.007624892897030605, 0.007216393556619404, 0.007737629733113695, 0.007234639617422154, 0.007461294842415314, 
0.0071938515674138314, 0.0072792018962409225, 0.007221030264128184, 0.007105474357089908, 0.0071524171128525575, 0.007155909122147357, 0.0071722391450182955, 0.007086884509909089, 0.00687597863021699, 0.007050289940079852, 0.006736478289153542, 0.006915419599541032, 0.007000306782885005, 0.006771062144963048, 0.006865012303703611, 0.006679349502639013, 0.0066999723114413105, 0.0066108269349303295, 0.006810415932628207, 0.006456899562190127, 0.006703694462922346, 0.006744362523504307, 0.006488448970820051, 0.006383724720776812, 0.0065502559345045645, 0.006410643268356491, 0.00620496510005044, 0.006336645373472576, 0.006541824913773249, 0.006365362240541443, 0.005996008469828676, 0.0062490548242418835, 0.006339189640779364, 0.006048406778947578, 0.0059465348168791015, 0.006388198239011704, 0.006039560336277963, 0.0061728144306625555, 0.006274406550166185, 0.006131060394061553, 0.006150102932526821, 0.006155324044722795, 0.0061455323155777545, 0.006177412160563083, 0.006106591132588688, 0.005839688662309867, 0.0058501789756502925, 0.00585248905027129, 0.005878930959837761, 0.006020229207484263, 0.006140267526682479, 0.0056721595001314245, 0.006023927973131154, 0.0055951409155156695, 0.005515698191872878, 0.00570797791604963, 0.005641692175971579, 0.005615213536029355, 0.005706943805425877, 0.006005713614749815, 0.005583379123507907, 0.005564098919195074, 0.005638347692922847, 0.005798768763699445, 0.005492580046054193, 0.005490732122104854, 0.0054612339909547915, 0.005461738702931084, 0.005326137379681497, 0.0055867239262848405, 0.005594525727397265], "train_accuracies": [0.07540461010299167, 0.12530652280529672, 0.16601275134870033, 0.19801373222167729, 0.23406081412457086, 0.2509808729769495, 0.2708435507601766, 0.29916625796959295, 0.31535066208925944, 0.3429377145659637, 0.3702795487984306, 0.372854340362923, 0.38254046101029915, 0.41025012260912214, 0.4252084355076018, 0.4297449730259931, 0.45340853359489947, 0.4682442373712604, 0.4733938205002452, 
0.4958312898479647, 0.5056400196174595, 0.519862677783227, 0.532000980872977, 0.5359244727807749, 0.5597106424717999, 0.5592202059833251, 0.5785924472780775, 0.5850907307503679, 0.6025012260912211, 0.5984551250613046, 0.6089995095635116, 0.629107405590976, 0.6234673859735165, 0.6375674350171653, 0.6414909269249632, 0.6459048553212359, 0.6452918097106425, 0.6506866110838646, 0.661476213830309, 0.6658901422265816, 0.6688327611574301, 0.679377145659637, 0.6853849926434527, 0.6926189308484552, 0.6869789112309956, 0.6959293771456596, 0.700833742030407, 0.6950711132908288, 0.7091711623344776, 0.7074546346248161, 0.719592937714566, 0.7106424717999019, 0.7244973025993134, 0.7225355566454145, 0.717631191760667, 0.7274399215301618, 0.726581657675331, 0.7411721432074546, 0.7346738597351643, 0.7403138793526238, 0.7436243256498284, 0.7533104462972046, 0.7477930358018636, 0.743379107405591, 0.7551495831289848, 0.7582148111819519, 0.7486512996566944, 0.7685139774399216, 0.7649583128984796, 0.7629965669445806, 0.7625061304561059, 0.767165277096616, 0.7746444335458558, 0.7715792054928887, 0.7806522805296714, 0.7860470819028936, 0.764835703776361, 0.7837175085826386, 0.7837175085826386, 0.7858018636586562, 0.7869053457577244, 0.7844531633153506, 0.7848209906817067, 0.7914418832761158, 0.7894801373222168, 0.7959784207945071, 0.7978175576262874, 0.7892349190779794, 0.799166257969593, 0.7984306032368809, 0.8051741049534086, 0.8041932319764591, 0.810078469838156, 0.8111819519372241, 0.8062775870524767, 0.8155958803334968, 0.8082393330063756, 0.814737616478666, 0.810446297204512, 0.8195193722412948, 0.8239333006375674, 0.8200098087297695, 0.8281020107896028, 0.8231976459048553, 0.8241785188818048, 0.8231976459048553, 0.8284698381559588, 0.8346002942618931, 0.8238106915154487, 0.8310446297204512, 0.8366846493379108, 0.83582638548308, 0.8372976949485041, 0.8445316331535067, 0.8392594409024031, 0.8442864149092693, 0.8344776851397744, 0.8425698871996077, 0.842815105443845, 
0.8491907797940167, 0.8464933791074056, 0.8473516429622364, 0.8461255517410495, 0.8483325159391859, 0.8455125061304561, 0.8494359980382541, 0.8433055419323198, 0.8532368808239333, 0.8455125061304561, 0.8460029426189308, 0.8468612064737616, 0.8511525257479157, 0.8487003433055419, 0.8549534085335949, 0.8528690534575772, 0.8520107896027465, 0.8525012260912211, 0.8650073565473271, 0.8496812162824914, 0.8608386463952918, 0.8567925453653752, 0.8555664541441883, 0.8554438450220696, 0.8605934281510544, 0.8646395291809711, 0.8653751839136832, 0.8636586562040216, 0.8601029916625796, 0.8694212849435998, 0.8592447278077489, 0.8607160372731731, 0.8659882295242766, 0.868563021088769, 0.8634134379597842, 0.8722412947523296, 0.8727317312408043, 0.8648847474252085, 0.8659882295242766, 0.8679499754781755, 0.8803334968121628, 0.8690534575772437, 0.8700343305541932, 0.8778813143697891, 0.8776360961255517, 0.8677047572339383, 0.8784943599803825, 0.8760421775380088, 0.8739578224619912, 0.8717508582638548, 0.8711378126532614, 0.8735899950956351, 0.877513487003433, 0.8756743501716527, 0.870892594409024, 0.8797204512015694, 0.8836439431093673, 0.8804561059342815, 0.8813143697891123, 0.8760421775380088, 0.8718734673859735, 0.8800882785679255, 0.875183913683178, 0.883030897498774, 0.8910004904364884, 0.8807013241785189, 0.8874448258950466, 0.8847474252084355, 0.8848700343305542, 0.8750613045610593, 0.8870769985286905, 0.8854830799411476, 0.883766552231486, 0.877513487003433, 0.8881804806277587, 0.8870769985286905, 0.886096125551741, 0.8871996076508092, 0.887690044139284, 0.8854830799411476, 0.8865865620402158], "train_duration": 10.720955848693848, "test_accuracies": 0.8608657243816255, "nb_spikes": [55.048276053534615, 54.20732169681125, 46.59869352976481], "test_duration": 1.0690453052520752, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "SPSN-SB", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.0, 
"loss_mode": "mean", "data_augmentation": true, "h_shift": 0.1, "scale": 0.3, "dir": "data_aug/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/regul/results_SPSN-SB_1675890850.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.5159534273807064, 0.04842920572866465, 0.045519507504257865, 0.044263104412589606, 0.04343963104933719, 0.0427717064571708, 0.04217807193665366, 0.041237972307696304, 0.04053965150988412, 0.03996404422270778, 0.03951364170984639, 0.03853947570243263, 0.038141224125894393, 0.03762112244371692, 0.037163833153487535, 0.03676580657790606, 0.036005707012782676, 0.035819492602009236, 0.03493066118183295, 0.03462803716049176, 0.03384804219928309, 0.03339769335691107, 0.032717242303462866, 0.03193906866024966, 0.03139914634357544, 0.030823972644380287, 0.03039017962789699, 0.02965229683906907, 0.029232633397414325, 0.028491393279772053, 0.02782389701377416, 0.027102499454848144, 0.026675524695238803, 0.026249423791170706, 0.025622617926534454, 0.025256946880247967, 0.024149288568150594, 0.02368446111737542, 0.023624268158093688, 0.022975669312325104, 0.022511211576971602, 0.021892447077567347, 0.021864263252220417, 0.021233561029908938, 0.020771543131323643, 0.02050328994276712, 0.019950582700187288, 0.019429930471625615, 0.018938715914524673, 0.01916568982449158, 0.018521471181998598, 0.017999455975866482, 0.01786896817353264, 0.01750563748652237, 0.01686397905207078, 0.016986869198835148, 0.016539110140276633, 0.016208751930098374, 0.01594330873221612, 0.015661332860420943, 0.015618548652247625, 0.01495131611268622, 0.014937078904556023, 0.014485722110808159, 0.014222015771519735, 0.01435847606011148, 0.01369486540161318, 0.01372206257130949, 0.013424540425469443, 0.013288473435451495, 0.01304655017120222, 0.012977409013467072, 0.012614870536204586, 0.012190780503279793, 
0.01198792845273848, 0.011840000814406997, 0.01190396902197538, 0.011510846649477914, 0.011445081920206343, 0.011392663106461141, 0.011340330703373105, 0.011236294526689014, 0.01111393648073216, 0.010634349696539149, 0.010567752150211222, 0.010239109616176705, 0.010227865783854873, 0.010128370796336212, 0.01009651629026758, 0.010091770113303767, 0.009780308862946207, 0.0096867154733757, 0.009673586289721297, 0.009525120017164884, 0.009499547965623632, 0.009366154656099185, 0.009307813310430235, 0.009144719015914708, 0.008700741399234858, 0.008962611153962977, 0.008668161607098497, 0.008996609406561287, 0.008688561468512594, 0.008406438382891476, 0.008753282488006303, 0.00841961811283396, 0.008216234389007412, 0.008102629726543445, 0.008267461932015337, 0.008094051448137752, 0.00805423379149023, 0.007774077112868231, 0.007732533495790997, 0.007641511385992732, 0.007996894291301987, 0.007895934711109488, 0.007523336344092075, 0.007597090995241579, 0.0075539080311908275, 0.007434916240859698, 0.007532911544164178, 0.0073344215144124205, 0.0073493176695761464, 0.007379530102759965, 0.007177428989623211, 0.007350772845797705, 0.0072435556621806415, 0.007098554106919773, 0.007015457653631002, 0.006855681851001976, 0.006889452504623515, 0.006927589795909833, 0.007037298332729897, 0.0068006676298809614, 0.0067738091363925565, 0.006627905638838236, 0.006757627659243892, 0.006514030707385272, 0.0064960785425720756, 0.006797782428171659, 0.0065851836826592, 0.0061757470215723295, 0.006603651709741328, 0.006381666348984921, 0.0063698392100138895, 0.006440011063675602, 0.006154712073983604, 0.006256619412607279, 0.0061573320785066905, 0.0061489589395419, 0.006182896714180693, 0.00592475231186966, 0.0058675307749444915, 0.0059707382983469155, 0.00603631968815985, 0.006058634796406838, 0.005910338385395117, 0.005657307240909195, 0.005936301197781757, 0.006073385568416955, 0.005757339786539012, 0.005684256814913518, 0.005857408737335327, 0.005693413298933339, 0.005549485904364214, 
0.005775346822192588, 0.005793732852693747, 0.005869124104808276, 0.00577480806921147, 0.005468958978737143, 0.005675164196145131, 0.005553970152677538, 0.005516028527742276, 0.005690292999250682, 0.005668998782236591, 0.005560549461137085, 0.0054505150445452386, 0.005579785345640178, 0.0053558701712340335, 0.00529819436630763, 0.005352269384924134, 0.005341726561250963, 0.005184532742894357, 0.005348444988931903, 0.005007823307090907, 0.0051588518101979616, 0.0049155281227248274, 0.005274895147446501, 0.0051128689899434055, 0.005143033990078314, 0.005066794475294785, 0.0050364288228759, 0.0050063328401840105, 0.004999767658229901, 0.004851530250773937, 0.005020611463164862, 0.004959431819000216, 0.004989921918770096, 0.004801302379094952, 0.004831725979053021], "train_accuracies": [0.04941147621383031, 0.0734428641490927, 0.1040951446787641, 0.11831780284453164, 0.12358999509563512, 0.13842569887199607, 0.1425944090240314, 0.15988229524276606, 0.17447278077488965, 0.18746934771947033, 0.18575282000980872, 0.2021824423737126, 0.20279548798430602, 0.21358509073075035, 0.22829818538499264, 0.2313634134379598, 0.24865129965669447, 0.2507356547327121, 0.26667484060814123, 0.2731731240804316, 0.29462972045120156, 0.30014713094654244, 0.31461500735654735, 0.32233938205002455, 0.342815105443845, 0.349068170671898, 0.3679499754781756, 0.3805787150564002, 0.38952918097106426, 0.39823442864149095, 0.41589014222658166, 0.42569887199607653, 0.43661108386463954, 0.4538989700833742, 0.4650564001961746, 0.4659146640510054, 0.495708680725846, 0.5001226091221187, 0.5052721922511035, 0.5245218244237371, 0.5251348700343306, 0.5487984306032369, 0.5388670917116234, 0.5615497793035802, 0.5659637076998528, 0.5727072094163805, 0.582270720941638, 0.5966159882295243, 0.6059342815105444, 0.6022560078469839, 0.6204021579205493, 0.6283717508582638, 0.6261647866601275, 0.6398970083374204, 0.6506866110838646, 0.6514222658165768, 0.6554683668464933, 0.6667484060814125, 0.670794507111329, 
0.6715301618440412, 0.6736145169200588, 0.6871015203531142, 0.689921530161844, 0.7045120156939676, 0.7092937714565963, 0.7057381069151545, 0.7187346738597352, 0.7236390387444825, 0.7264590485532123, 0.7298921039725356, 0.7282981853849927, 0.737126042177538, 0.7426434526728789, 0.7515939185875429, 0.7595635115252575, 0.761280039234919, 0.7572339382050024, 0.7625061304561059, 0.7703531142717018, 0.7641000490436488, 0.7746444335458558, 0.7775870524767042, 0.7763609612555175, 0.7894801373222168, 0.7855566454144188, 0.7904610102991663, 0.7968366846493379, 0.7990436488474743, 0.797204512015694, 0.795242766061795, 0.8060323688082394, 0.8064001961745954, 0.8077488965179009, 0.8110593428151054, 0.8092202059833251, 0.8122854340362923, 0.816699362432565, 0.8137567435017166, 0.8290828837665523, 0.8201324178518882, 0.8223393820500245, 0.8153506620892594, 0.8315350662089259, 0.8339872486512997, 0.8228298185384992, 0.8371750858263854, 0.8327611574301128, 0.8336194212849436, 0.8327611574301128, 0.836194212849436, 0.8422020598332516, 0.8446542422756254, 0.8442864149092693, 0.8453898970083374, 0.8391368317802844, 0.8354585581167239, 0.8473516429622364, 0.8500490436488475, 0.8462481608631682, 0.8494359980382541, 0.8458803334968121, 0.8509073075036783, 0.8483325159391859, 0.8551986267778323, 0.8574055909759686, 0.8544629720451201, 0.8571603727317313, 0.8565473271211378, 0.8574055909759686, 0.8647621383030898, 0.8608386463952918, 0.860348209906817, 0.8544629720451201, 0.8632908288376655, 0.8587542913192742, 0.8647621383030898, 0.8632908288376655, 0.8653751839136832, 0.8679499754781755, 0.8648847474252085, 0.8707699852869053, 0.8729769494850417, 0.8659882295242766, 0.8701569396763119, 0.8678273663560568, 0.870524767042668, 0.8783717508582638, 0.8704021579205493, 0.8753065228052967, 0.8738352133398725, 0.8727317312408043, 0.883030897498774, 0.883766552231486, 0.8773908778813144, 0.8739578224619912, 0.8804561059342815, 0.8777587052476704, 0.890387444825895, 0.8782491417361452, 
0.875183913683178, 0.8847474252084355, 0.8868317802844532, 0.8853604708190289, 0.8838891613536047, 0.8897743992153017, 0.8809465424227563, 0.8836439431093673, 0.8869543894065719, 0.8826630701324178, 0.8897743992153017, 0.8829082883766552, 0.8919813634134379, 0.8892839627268269, 0.881436978911231, 0.8875674350171653, 0.8890387444825895, 0.8881804806277587, 0.8867091711623345, 0.8887935262383522, 0.8952918097106425, 0.890387444825895, 0.8889161353604709, 0.8970083374203041, 0.8933300637567435, 0.9017900931829328, 0.8948013732221677, 0.8976213830308974, 0.8918587542913192, 0.8946787641000491, 0.8929622363903874, 0.8978666012751348, 0.9000735654732712, 0.8984796468857283, 0.8983570377636096, 0.9014222658165768, 0.8981118195193722, 0.9008092202059833, 0.8978666012751348, 0.9021579205492889, 0.9022805296714076], "train_duration": 10.744028666019439, "test_accuracies": 0.8970848056537103, "nb_spikes": [49.20482667287191, 29.463692294226753, 29.59486675262451], "test_duration": 0.9459738731384277, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "SPSN-SB", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.4, "loss_mode": "mean", "data_augmentation": true, "h_shift": 0.1, "scale": 0.3, "dir": "regul/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} -------------------------------------------------------------------------------- /outputs/data_aug/results_SPSN-GS_1675891161.json: -------------------------------------------------------------------------------- 1 | {"loss_hist": [0.06516748070424536, 0.04161779587616109, 0.039664721828047235, 0.03777516019759428, 0.03652730921777573, 0.03498056494412088, 0.033708735927756715, 0.03245380322041028, 0.031009843400320976, 0.030089837479322434, 0.02936120904971128, 0.027868958666021768, 0.0266487805822306, 0.026153831650896246, 0.025373106788571645, 0.024628126384464766, 0.02392098763633324, 0.023452502611865596, 0.0232314166922146, 
0.022446201983592626, 0.021751082458000313, 0.02184811167239909, 0.020870560536260636, 0.02099838905500512, 0.020582775002077783, 0.020351333381967138, 0.019736827852212663, 0.01921583867032134, 0.019333220130383473, 0.018676167575970418, 0.018740560316817326, 0.018118636522765484, 0.018274956017864167, 0.017381853284702518, 0.017387296951477758, 0.017280939873262267, 0.016867098342325828, 0.016594823192240503, 0.01606033648153682, 0.016417312426213953, 0.015770142286825904, 0.015706048884656043, 0.015604440088129442, 0.015510082581044412, 0.01472137304951947, 0.014930233015857414, 0.014568331663721972, 0.01457096572715082, 0.014088171371178864, 0.01415152063610625, 0.013909639645406228, 0.013940124761070207, 0.013833195570346594, 0.013530658061975589, 0.01322472503279985, 0.013481984822085935, 0.013017545382493615, 0.012698825147293431, 0.012597394083106325, 0.012346467137512125, 0.012798354555836725, 0.012361274280344638, 0.011987920965617517, 0.011959029762636159, 0.011724360517393789, 0.011946565178809885, 0.011770152454431177, 0.011891741855522182, 0.011167807306303236, 0.011332028100103534, 0.010990018961387502, 0.011253717115633975, 0.010953304762054507, 0.010834520013090324, 0.010892361705796633, 0.01057573012521654, 0.010979718966715589, 0.010457070687783005, 0.010599433239357562, 0.010131887452107776, 0.010459178953617445, 0.01010159920497289, 0.010247192136822757, 0.00974111098571406, 0.009803392256136441, 0.009996897741501795, 0.009703323135427426, 0.009709360159279494, 0.009681360318328657, 0.009639791267329773, 0.009316794735038322, 0.009490846283094389, 0.009376089703867165, 0.008860819330397864, 0.009276783235403999, 0.009089805858137562, 0.00931332193501385, 0.008721200174680815, 0.008750272380364415, 0.008777143781737056, 0.008617839720712216, 0.009074487007675713, 0.008807689430171102, 0.008697599596460407, 0.008767844872506241, 0.008684515989604331, 0.00839402509808482, 0.00842617627538379, 0.008288731110657237, 0.00835034954067011, 
0.008501168373330894, 0.008328001396278355, 0.007969951391395487, 0.008401324080011902, 0.008077291416414466, 0.008117160240536289, 0.007927337174485162, 0.007991363731536753, 0.007887985321632374, 0.007940174153968948, 0.007866441445309137, 0.00800742705541244, 0.00786120552695692, 0.0077368935159955205, 0.007774415951680763, 0.007705279279275523, 0.007692559482918001, 0.007451723868236523, 0.00750972641894138, 0.007691570707926159, 0.00758357467920012, 0.007510209307184168, 0.0074198329825796525, 0.0074170280856323805, 0.007474578657589923, 0.007598637277557313, 0.007435116250033236, 0.00691588302455795, 0.0072753188319507915, 0.007377610527720739, 0.00719940375907565, 0.007387665321921647, 0.007146230735736129, 0.006946857103300539, 0.0071483324480559555, 0.0071264316856334935, 0.007113665835625871, 0.0067655783312878465, 0.00697869350361789, 0.007204308692966974, 0.006834709846020447, 0.0068447418079737535, 0.00731953144139259, 0.006801939152396737, 0.006671982016661052, 0.006937991122497303, 0.006648392382280215, 0.006762697227580574, 0.006825528222228921, 0.006755868637091158, 0.006770059873787078, 0.006856870178412783, 0.0065203448661610795, 0.006760745138923815, 0.00659363248210552, 0.006415540863115335, 0.006570020897926917, 0.0067586521433082615, 0.0064780157554553035, 0.006675922124116312, 0.006405549175953268, 0.006253140017046772, 0.006261809761469244, 0.006193745223592461, 0.00643770308099055, 0.006442957061245607, 0.006339022903435952, 0.006319939728429822, 0.006822188352401377, 0.0063347810419854605, 0.0063551637422167365, 0.006306218987311611, 0.005966967619433013, 0.006313046273309966, 0.0064392529406080996, 0.006125144384593108, 0.005761940236869777, 0.0063448215361141935, 0.006103819876529236, 0.005985583519894683, 0.005887138549021028, 0.006123461240337081, 0.005724377235897386, 0.005904864497910299, 0.00610388238247906, 0.006002219029359224, 0.006247508609525826, 0.0060731959787287395, 0.005968874701321155, 0.006067753578077174], 
"train_accuracies": [0.10201078960274644, 0.17373712604217753, 0.20953898970083373, 0.2372486512996567, 0.262873957822462, 0.2919323197645905, 0.31522805296714074, 0.3448994605198627, 0.36709171162334475, 0.3940657184894556, 0.3950465914664051, 0.42361451692005886, 0.4514467876410005, 0.46088769004413926, 0.4733938205002452, 0.49632172633643945, 0.5049043648847474, 0.513241785188818, 0.5181461500735655, 0.5358018636586562, 0.5441392839627268, 0.5430358018636586, 0.5657184894556155, 0.5560323688082394, 0.5687837175085826, 0.5782246199117215, 0.5831289847964689, 0.6020107896027465, 0.5961255517410495, 0.6078960274644434, 0.6107160372731731, 0.620524767042668, 0.6201569396763119, 0.6356056890632663, 0.6371996076508092, 0.6456596370769986, 0.6549779303580187, 0.6517900931829328, 0.6688327611574301, 0.6576753310446297, 0.6676066699362433, 0.674717999019127, 0.6791319274153997, 0.6758214811181952, 0.6983815595880334, 0.6927415399705739, 0.7003433055419324, 0.6910250122609122, 0.7020598332515939, 0.6999754781755763, 0.7175085826385483, 0.7106424717999019, 0.716037273173124, 0.7161598822952427, 0.7273173124080432, 0.7188572829818538, 0.7338155958803335, 0.7399460519862677, 0.7448504168710152, 0.7501226091221187, 0.7356547327121138, 0.7490191270230505, 0.7502452182442374, 0.7540461010299166, 0.7591956841589014, 0.7557626287395782, 0.7561304561059343, 0.7556400196174595, 0.7682687591956842, 0.7703531142717018, 0.7778322707209416, 0.7730505149583129, 0.7769740068661108, 0.7772192251103482, 0.7758705247670427, 0.7827366356056891, 0.7735409514467877, 0.7854340362923001, 0.7853114271701814, 0.7947523295733202, 0.7920549288867091, 0.8021088769004414, 0.7914418832761158, 0.8037027954879843, 0.8010053948013732, 0.7902157920549289, 0.8000245218244237, 0.7995340853359489, 0.8041932319764591, 0.7995340853359489, 0.8072584600294261, 0.8059097596861207, 0.8090975968612065, 0.8251593918587543, 0.8154732712113781, 0.8175576262873958, 0.805787150564002, 0.8196419813634135, 
0.8178028445316331, 0.8193967631191761, 0.8240559097596861, 0.8140019617459539, 0.820990681706719, 0.82368808239333, 0.8230750367827366, 0.823320255026974, 0.8267533104462972, 0.8283472290338402, 0.8314124570868072, 0.8346002942618931, 0.8258950465914664, 0.8296959293771456, 0.8386463952918097, 0.8224619911721432, 0.8353359489946052, 0.8326385483079941, 0.8401177047572339, 0.8371750858263854, 0.8408533594899461, 0.837788131436979, 0.8370524767042667, 0.8346002942618931, 0.834232466895537, 0.8401177047572339, 0.8363168219715547, 0.8440411966650319, 0.8399950956351152, 0.8479646885728298, 0.8448994605198626, 0.8420794507111329, 0.8455125061304561, 0.8456351152525748, 0.8493133889161354, 0.8515203531142717, 0.8455125061304561, 0.8401177047572339, 0.8460029426189308, 0.8539725355566454, 0.8532368808239333, 0.8534820990681706, 0.8560568906326631, 0.8448994605198626, 0.8566699362432565, 0.8585090730750368, 0.8571603727317313, 0.8495586071603727, 0.8512751348700344, 0.8632908288376655, 0.8580186365865621, 0.8554438450220696, 0.8592447278077489, 0.8583864639529181, 0.8520107896027465, 0.8616969102501226, 0.8659882295242766, 0.8523786169691026, 0.8578960274644434, 0.8679499754781755, 0.8582638548307994, 0.8629230014713095, 0.8580186365865621, 0.8588769004413929, 0.8658656204021579, 0.8626777832270721, 0.862309955860716, 0.8670917116233448, 0.8654977930358019, 0.863903874448259, 0.8664786660127514, 0.8626777832270721, 0.8695438940657185, 0.8721186856302109, 0.8699117214320745, 0.8721186856302109, 0.8677047572339383, 0.8684404119666503, 0.8701569396763119, 0.8695438940657185, 0.8551986267778323, 0.8717508582638548, 0.8710152035311427, 0.8656204021579206, 0.8788621873467386, 0.8713830308974988, 0.8658656204021579, 0.8734673859735165, 0.8777587052476704, 0.8667238842569888, 0.8740804315841099, 0.8766552231486022, 0.8832761157430112, 0.8753065228052967, 0.8786169691025012, 0.874816086316822, 0.8715056400196175, 0.8735899950956351, 0.8688082393330063, 0.8723639038744483, 
0.8742030407062286, 0.8762873957822462], "train_duration": 12.287769885063172, "test_accuracies": 0.8608657243816255, "nb_spikes": [57.94560729132758, 62.49844593471951, 57.93204837375217], "test_duration": 1.251098871231079, "PARAMS": {"seed": 0, "dataset": "heidelberg", "neuron": "SPSN-GS", "nb_epochs": 200, "tau_mem": 0.02, "tau_syn": 0.02, "batch_size": 64, "hidden_size": 128, "nb_layers": 3, "reg_thr": 0.0, "loss_mode": "mean", "data_augmentation": true, "h_shift": 0.1, "scale": 0.3, "dir": "data_aug/", "save_model": true, "debug": false, "input_size": 700, "nb_class": 20}} --------------------------------------------------------------------------------