├── .gitmodules ├── README.md ├── snn_utils.py └── tutorial02_dcll_classification.ipynb /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "decolle"] 2 | path = thirdpartylibs 3 | url = git@github.com:nmi-lab/decolle-public.git 4 | [submodule "decolle-public"] 5 | path = decolle-public 6 | url = git@github.com:nmi-lab/decolle-public.git 7 | branch = master 8 | [submodule "decolle_public"] 9 | path = decolle_public 10 | url = git@github.com:nmi-lab/decolle-public.git 11 | branch = master 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Synaptic Plasticity Dynamics for Deep Continuous Local Learning (DECOLLE) 2 | 3 | This repo contains a tutorial for the [PyTorch](https://pytorch.org/) implementation of the DECOLLE learning rule presented in [this paper](https://arxiv.org/abs/1811.10766). 4 | 5 | If you use this code in a scientific publication, please include the following reference in your bibliography: 6 | 7 | ``` 8 | @article{kaiser2018synaptic, 9 | title={Synaptic Plasticity Dynamics for Deep Continuous Local Learning (DECOLLE)}, 10 | author={Kaiser, Jacques and Mostafa, Hesham and Neftci, Emre}, 11 | journal={arXiv preprint arXiv:1811.10766}, 12 | year={2018} 13 | } 14 | ``` 15 | ## Tutorials 16 | 17 | The first notebook under the tutorials is standalone except for snn_utils.py. 18 | Step-by-step instructions for setting up spiking neural networks in PyTorch and setting up DECOLLE are provided. 19 | See for example [tutorial1.ipynb](tutorial1.ipynb). 
# NOTE: the comment lines below carry over, verbatim, the non-Python text that
# shared the first line of this span in the repository dump (the tail of
# README.md and the file-boundary marker), so replacing the span drops nothing.
#
# README.md (continued):
#   ## Google Colab
#   You can open the notebook in colab, but you will have to upload or clone snn_utils.py by yourself
#   [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/surrogate-gradient-learning/pytorch-lif-autograd/blob/master/)
#
# /snn_utils.py:
"""Spike-train generation, plotting and dataset-loader helpers for the DECOLLE tutorial."""
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms, utils
import torch
from torch.utils.data.dataloader import DataLoader
#import tqdm
import copy
from collections import namedtuple


def __gen_ST(N, T, rate, mode = 'regular'):
    '''
    Generate a T x N array of zeros and ones for a single firing rate.

    *N*: number of neurons (columns)
    *T*: number of time steps (rows)
    *rate*: firing rate, interpreted as spikes per 1000 time steps
    *mode*: 'regular' (every neuron spikes on every (1000//rate)-th step) or
            'poisson' (each entry is an independent Bernoulli draw with
            probability rate/1000)

    Raises ValueError for any other *mode*.
    '''
    if mode == 'regular':
        spikes = np.zeros([T, N])
        if rate > 0:
            # Guard the slice step: a rate above 1000 would give step 0 and a
            # float rate would give a float step, both of which numpy rejects.
            step = max(1, int(1000 // rate))
            spikes[::step] = 1
        return spikes
    elif mode == 'poisson':
        # Probability that a given entry stays silent, clipped to [0, 1] so
        # rates outside [0, 1000] do not make np.random.binomial raise.
        p_silent = min(1.0, max(0.0, float(1000. - rate) / 1000))
        spikes = np.ones([T, N])
        spikes[np.random.binomial(1, p_silent, size=(T, N)).astype('bool')] = 0
        return spikes
    else:
        raise ValueError('mode must be regular or Poisson')


def spiketrains(N, T, rates, mode = 'poisson'):
    '''
    *N*: number of neurons
    *T*: number of time steps
    *rates*: vector of firing rates, one per neuron, or a single scalar rate
             shared by all neurons
    *mode*: 'regular' or 'poisson'

    Returns a T x N array of zeros and ones. Per-neuron rates are truncated to
    int, so rates below 1 produce no spikes. Columns beyond len(rates) stay
    silent.
    '''
    if not hasattr(rates, '__iter__'):
        return __gen_ST(N, T, rates, mode)
    rates = np.array(rates)
    spikes = np.zeros([T, N])
    for i in range(rates.shape[0]):
        rate = int(rates[i])
        if rate > 0:
            spikes[:, i] = __gen_ST(1, T, rate, mode = mode).flatten()
    return spikes


def spikes_to_evlist(spikes):
    '''
    Convert a T x N spike array into an event list.

    Returns (t, n): the time-step indices and neuron indices of every non-zero
    entry, ordered by neuron first and by time within each neuron.
    '''
    # Transposing makes np.nonzero scan neuron-major, which is the ordering the
    # raster plot code has always received.
    n, t = np.nonzero(np.asarray(spikes).astype('bool').T)
    return t, n


def plotLIF(U, S, Vplot = 'all', staggering= 1, ax1=None, ax2=None, **kwargs):
    '''
    Plot a spike raster and, optionally, membrane potential traces.

    Inputs:
    *U*: a TxN np.array of membrane potentials, where T are time steps and N
         are the number of neurons, or None to draw the raster only
    *S*: a TxN np.array of zeros and ones indicating spikes
    *Vplot*: which neurons' membrane potentials should be plotted: 'all', an
         iterable of neuron indices, or a scalar k meaning range(k)
         (capped at N). Default: 'all'
    *staggering*: the amount by which each successive trace is shifted
    *ax1*, *ax2*: optional axes for the raster and for the traces; created
         when omitted
    **kwargs are forwarded to the plot calls.

    Returns (ax1, ax2); ax2 is None when *U* is None.
    '''
    V = U
    spikes = S
    #Plot
    t, n = spikes_to_evlist(spikes)
    #f = plt.figure()
    if V is None:
        # Only create axes when the caller did not supply any.
        if ax1 is None:
            ax1 = plt.axes()
        ax2 = None
    elif ax1 is None:
        ax1 = plt.subplot(211)
    ax1.plot(t, n, 'k|', **kwargs)
    ax1.set_ylim([-1, spikes.shape[1] + 1])
    ax1.set_xlim([0, spikes.shape[0]])

    if V is not None:
        # isinstance check first: comparing an ndarray to 'all' is elementwise
        # and cannot be used as a condition.
        if isinstance(Vplot, str) and Vplot == 'all':
            Vplot = range(V.shape[1])
        elif not hasattr(Vplot, '__iter__'):
            Vplot = range(min(int(Vplot), V.shape[1]))

        if ax2 is None:
            ax2 = plt.subplot(212)

        if V.shape[1]>1:
            for i, idx in enumerate(Vplot):
                ax2.plot(V[:,idx]+i*staggering,'-', **kwargs)
        else:
            ax2.plot(V[:,0], '-', **kwargs)

        if staggering!=0:
            plt.yticks([])
        plt.xlabel('time [ms]')
        plt.ylabel('u [au]')

    ax1.set_ylabel('Neuron ')

    plt.xlim([0, spikes.shape[0]])
    plt.ion()
    plt.show()
    return ax1,ax2




# (height, width, channels) of one input image.
input_shape = [28,28,1]


def to_one_hot(t, width):
    '''
    One-hot encode an integer label tensor.

    *t*: tensor of class indices, any shape
    *width*: number of classes

    Returns a float tensor of shape t.shape + (width,) on the same device as
    *t*.
    '''
    index = t.long().unsqueeze(-1)  # scatter_ requires an int64 index
    t_onehot = torch.zeros(*t.shape, width, device=t.device)
    # Scatter along the last axis so labels with more than one dimension work.
    return t_onehot.scatter_(-1, index, 1)



def image2spiketrain(x,y,gain=50,min_duration=None, max_duration=500, num_classes=10):
    '''
    Turn a batch of images into Poisson spike trains and tiled targets.

    *x*: batch of images, first axis is the batch; each image must flatten to
         prod(input_shape) values
    *y*: tensor of integer class labels, one per image
    *gain*: factor applied to pixel values to obtain firing rates
    *min_duration*, *max_duration*: each sample's duration is drawn with
         np.random.randint(min_duration, max_duration), whose upper bound is
         exclusive. With the default min_duration (max_duration-1) every
         sample lasts max_duration-1 steps and the final step is zero padding.
    *num_classes*: width of the one-hot targets

    Returns (inputs, targets): inputs has shape
    (max_duration, batch, channels, height, width); targets is a float32 array
    of shape (max_duration, batch, num_classes) repeating each one-hot label
    at every time step.
    '''
    y = to_one_hot(y, num_classes)
    if min_duration is None:
        min_duration = max_duration-1
    batch_size = x.shape[0]
    T = np.random.randint(min_duration,max_duration,batch_size)
    Nin = int(np.prod(input_shape))
    allinputs = np.zeros([batch_size,max_duration, Nin])
    for i in range(batch_size):
        st = spiketrains(T = T[i], N = Nin, rates=gain*x[i].reshape(-1)).astype(np.float32)
        # Rows past T[i] keep their zero initialisation (the padding).
        allinputs[i, :T[i]] = st
    allinputs = np.transpose(allinputs, (1,0,2))
    height, width, channels = input_shape
    allinputs = allinputs.reshape(allinputs.shape[0],allinputs.shape[1],channels, height,width)

    alltgt = np.zeros([max_duration, batch_size, num_classes], dtype=np.float32)
    alltgt[:] = y.cpu().numpy()  # broadcast the labels over the time axis

    return allinputs, alltgt


def target_convolve(tgt,alpha=8,alphas=5):
    '''
    Smooth targets along time with two cascaded exponential kernels.

    *tgt*: array of shape (time, batch, classes)
    *alpha*, *alphas*: time constants of the two kernels; each kernel is
         sampled at np.linspace's default 50 points over [0, 10*constant]
         and normalised to sum to one

    Returns a new array of the same shape scaled so its maximum is 1. An
    all-zero target is returned unscaled instead of being divided by zero.
    '''
    max_duration = tgt.shape[0]
    kernel_alpha = np.exp(-np.linspace(0,10*alpha,dtype='float')/alpha)
    kernel_alpha /= kernel_alpha.sum()
    kernel_alphas = np.exp(-np.linspace(0,10*alphas,dtype='float')/alphas)
    kernel_alphas /= kernel_alphas.sum()
    tgt = tgt.copy()
    for i in range(tgt.shape[1]):
        for j in range(tgt.shape[2]):
            tmp=np.convolve(np.convolve(tgt[:,i,j],kernel_alpha),kernel_alphas)[:max_duration]
            tgt[:,i,j] = tmp
    peak = tgt.max()
    if peak == 0:
        return tgt
    return tgt/peak


datasetConfig = namedtuple('config',['image_size','batch_size','data_path'])

#class DataLoaderIterPreProcessed(_DataLoaderIter):
#    def __next__(self):
#        indices = next(self.sample_iter)  # may raise StopIteration
#        td, tl = self.dataset.data, self.dataset.targets
#        batch = self.collate_fn([(td[i], tl[i]) for i in indices])
#        return batch
#
#class DataLoaderPreProcessed(DataLoader):
#    def __iter__(self):
#        return DataLoaderIterPreProcessed(self)

def sequester(tensor):
    '''Return a CPU copy of *tensor*, detached from autograd, with the same dtype.'''
    dtype = tensor.dtype
    return torch.tensor(tensor.detach().cpu().numpy(), dtype=dtype)


#def preprocess_dataset(dataset):
#    x, y = dataset.data[0], dataset.targets[0]
#    td = torch.empty(torch.Size([len(dataset)])+x.shape, dtype = torch.float32)
#    if not hasattr(y, 'shape'):
#        tl = torch.empty(torch.Size([len(dataset)]), dtype = torch.int)
#    else:
#        tl = torch.empty(torch.Size([len(dataset)])+y.shape, dtype = y.dtype)
#    for idx in tqdm.tqdm(range(len(dataset)), desc = "Pre-processing dataset"):
#        td[idx], tl[idx] = dataset[idx]
#
#    dataset.data, dataset.targets = td, tl
#
#
#    dataset.transform = None
#    return dataset


def _select_loader(dataset, pre_processed):
    '''
    Return the (dataset, loader class) pair to use.

    The pre-processed path (preprocess_dataset / DataLoaderPreProcessed) is
    commented out above, so requesting it used to fail with NameError. It only
    cached the transformed samples, so fall back to the plain DataLoader and
    warn instead.
    '''
    if pre_processed:
        import warnings
        warnings.warn('pre_processed=True is not available (preprocess_dataset is '
                      'disabled); falling back to the standard DataLoader')
    return dataset, DataLoader


def _label_attr(dataset):
    '''Name of the attribute holding the labels: 'targets' if present, else 'labels'.'''
    return 'targets' if hasattr(dataset, 'targets') else 'labels'


def _partition_in_place(dataset, Nparts, part):
    '''Replace the dataset's data and labels by the requested partition.'''
    data, labels = partition_dataset(dataset, Nparts, part)
    dataset.data = data
    setattr(dataset, _label_attr(dataset), labels)


def pixel_permutation(d_size, r_pix=1.0, seed=0):
    '''
    Build an index array that permutes a fraction of *d_size* positions.

    *r_pix*: fraction of positions that take part in the permutation
    *seed*: the global numpy RNG is reseeded with seed*1313

    Returns an int array of length *d_size*; positions that were not selected
    map to themselves.

    NOTE: this reseeds the global numpy RNG, so it affects every later
    np.random call in the process.
    '''
    n_pix = int(r_pix * d_size)
    np.random.seed(seed*1313)
    pix_sel = np.random.choice(d_size, n_pix, replace=False).astype(np.int32)
    pix_prm = np.copy(pix_sel)
    np.random.shuffle(pix_prm)
    perm_inds = np.arange(d_size)
    perm_inds[pix_sel] = perm_inds[pix_prm]
    return perm_inds

def permute_dataset(dataset, r_pix, seed):
    '''
    Return dataset.data as a float tensor with its pixels permuted.

    The same permutation (see pixel_permutation) is applied to every sample;
    the dataset itself is not modified.
    '''
    data = torch.as_tensor(dataset.data)  # accepts tensor or ndarray storage
    shape = tuple(data.shape[1:])
    n_pixels = int(np.prod(shape))
    datap = data.reshape(-1, n_pixels).detach().cpu().numpy()
    perm = pixel_permutation(n_pixels, r_pix, seed=seed)
    return torch.from_numpy(datap[:,perm].reshape(-1,*shape)).float()

def partition_dataset(dataset, Nparts=60, part=0):
    '''
    Return the (data, labels) of one contiguous partition of *dataset*.

    The dataset is cut into *Nparts* chunks of len(data)//Nparts samples and
    chunk number *part* (0-based) is returned; leftover samples at the end are
    never selected. The dataset is not modified. Labels are read from
    dataset.targets, or dataset.labels when there is no targets attribute.
    '''
    N = len(dataset.data)

    step = (N//Nparts)
    # A slice works for tensors, arrays and plain lists alike.
    chunk = slice(step*part, step*(part+1))

    td = dataset.data[chunk]
    tl = getattr(dataset, _label_attr(dataset))[chunk]
    return td, tl

def dynaload(dataset,
             batch_size,
             name,
             DL,
             perm=0.,
             Nparts=1,
             part=0,
             seed=0,
             taskid=0,
             base_perm=.0,
             base_seed=0,
             train = True,
             **loader_kwargs):
    '''
    Optionally permute the dataset's pixels, then wrap it in a loader.

    *DL*: loader class, called as DL(dataset=..., batch_size=..., shuffle=...)
    *base_perm*/*base_seed*: fraction and seed of a first permutation
    *perm*/*seed*: fraction and seed of a second permutation applied on top
    *part*: only used to build the loader's name; *Nparts* is unused here
    *train*: shuffle flag used when the dataset has no `train` attribute

    The permutations overwrite dataset.data. The returned loader carries the
    extra attributes taskid, name and short_name.
    '''
    if base_perm>0:
        data = permute_dataset(dataset, base_perm, seed=base_seed)
        dataset.data = data
    if perm>0:
        data = permute_dataset(dataset, perm, seed=seed)
        dataset.data = data

    loader = DL(dataset=dataset,
                batch_size=batch_size,
                # Prefer the dataset's own flag; not every dataset defines it.
                shuffle=getattr(dataset, 'train', train),
                **loader_kwargs)

    loader.taskid = taskid
    loader.name = name +'_{}'.format(part)
    loader.short_name = name
    return loader


def mnist_loader_dynamic(
        config,
        train,
        pre_processed = True,
        Nparts=1,
        part=1):
    """Builds the MNIST dataset and returns (dataset, name, loader class)."""

    transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.0,), (1.0,))])


    dataset = datasets.MNIST(root=config.data_path, download=True, transform=transform, train = train)
    if Nparts>1:
        _partition_in_place(dataset, Nparts, part)

    dataset_, DL = _select_loader(dataset, pre_processed)
    name = 'MNIST'

    return dataset_, name, DL

def get_mnist_loader(
        batch_size,
        train,
        perm=0.,
        Nparts=1,
        part=0,
        seed=0,
        taskid=0,
        pre_processed=False,
        base_perm=.0,
        base_seed=0,
        **loader_kwargs):
    """Returns a loader over 28x28 MNIST stored under ./data/mnist (see dynaload)."""

    config = datasetConfig(image_size = [28,28], batch_size = batch_size, data_path = './data/mnist')

    d,name,dl = mnist_loader_dynamic(config, train, pre_processed, Nparts, part)

    return dynaload(
            d,
            config.batch_size,
            name,
            dl,
            perm=perm,
            Nparts=Nparts,
            part=part,
            seed=seed,
            taskid=taskid,
            base_perm=base_perm,
            base_seed=base_seed,
            train = train,
            **loader_kwargs)

def usps_loader_dynamic(config, train, pre_processed=False, Nparts=1, part=1):
    """Builds the USPS dataset and returns (dataset, name, loader class)."""
    from usps_loader import USPS

    transform = transforms.Compose([
        # transforms.ToPILImage(),
        transforms.Resize(config.image_size),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.0,), (1.0,))])

    dataset = USPS(root=config.data_path, download=True, transform=transform, train = train)
    name = 'USPS'

    if Nparts>1:
        # partition_dataset only returns the slice; it has to be stored back.
        _partition_in_place(dataset, Nparts, part)

    dataset, DL = _select_loader(dataset, pre_processed)

    return dataset,name, DL


def get_usps_loader(config, train, perm=0., Nparts=1, part=0, seed=0, taskid=0, pre_processed=False, base_perm=.0, base_seed=0, **loader_kwargs):
    """Returns a loader over the USPS dataset described by *config* (see dynaload)."""
    dataset,name,DL = usps_loader_dynamic(config, train, pre_processed, Nparts, part)

    return dynaload(dataset,
                    config.batch_size,
                    name,
                    DL,
                    perm=perm,
                    Nparts=Nparts,
                    part=part,
                    seed=seed,
                    taskid=taskid,
                    base_perm=base_perm,
                    base_seed=base_seed,
                    train = train,
                    **loader_kwargs)

def svhn_loader_dynamic(config, train, pre_processed=False, Nparts=1, part=1):
    """Builds the SVHN dataset and returns (dataset, name, loader class)."""

    transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0.35,), (0.65,)),
        transforms.Lambda(lambda x: x.view(np.prod(config.image_size))),
        transforms.Lambda(lambda x: x*2-1)])

    name = 'SVHN'

    dataset = datasets.SVHN(root=config.data_path, download=True, transform=transform, split = 'train' if train else 'test')
    # SVHN is selected by `split`; expose a `train` flag for dynaload's shuffle.
    dataset.train = train

    if Nparts>1:
        # partition_dataset only returns the slice; it has to be stored back.
        _partition_in_place(dataset, Nparts, part)

    dataset, DL = _select_loader(dataset, pre_processed)

    return dataset,name, DL

def get_svhn_loader_dynamic(config, train, perm=0, taskid=0, seed=0, Nparts=1, part=0, pre_processed=False, base_perm=.0, base_seed=0, **loader_kwargs):
    """Returns a loader over the SVHN dataset described by *config* (see dynaload)."""
    dataset,name,DL = svhn_loader_dynamic(config, train, pre_processed, Nparts, part)

    return dynaload(dataset,
                    config.batch_size,
                    name,
                    DL,
                    perm=perm,
                    Nparts=Nparts,
                    part=part,
                    seed=seed,
                    taskid=taskid,
                    base_perm=base_perm,
                    base_seed=base_seed,
                    train = train,
                    **loader_kwargs)
-------------------------------------------------------------------------------- /tutorial02_dcll_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import snn_utils\n", 11 | "import pylab as plt\n", 12 | "device = 'cuda'" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "This is a tutorial for applying a DECOLLE network to a traditional MNIST problem using local errors. The reader is assumed to be familiar with Python and PyTorch. If you need to cite this code, please use (Kaiser, Mostafa, Neftci, 2019), bibtex:" 20 | ] 21 | }, 22 | { 23 | "cell_type": "raw", 24 | "metadata": {}, 25 | "source": [ 26 | "@Article{Kaiser_etal18,\n", 27 | "author\t\t= {Kaiser, J. and Mostafa, H. and Neftci, E.},\n", 28 | "booktitle\t= {arXiv preprint},\n", 29 | "journal\t\t= {arXiv preprint arXiv:1811.10766},\n", 30 | "link\t\t= {http://arxiv.org/pdf/1811.10766},\n", 31 | "title\t\t= {Synaptic Plasticity for Deep Continuous Local Learning},\n", 32 | "year\t\t= {2018},\n", 33 | "contrib\t\t= {80\\%}\n", 34 | "}" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "### Loading MNIST data as spiketrains" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "The following function will load the MNIST dataset using torchvision modules. It will download and pre-process the data for faster usage." 
49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 2, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "gen_train = snn_utils.get_mnist_loader(100, Nparts=100, train=True)\n", 58 | "gen_test = snn_utils.get_mnist_loader(100, Nparts=100, train=False)" 59 | ] 60 | }, 61 | { 62 | "cell_type": "markdown", 63 | "metadata": {}, 64 | "source": [ 65 | "Because MNIST samples are images, we need to transform them into spiketrains. The function __image2spiketrain__ in snn_utils takes care of this." 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 3, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "datait = iter(gen_train)\n", 75 | "raw_input, raw_labels = next(datait)\n", 76 | "data, labels1h = snn_utils.image2spiketrain(raw_input, raw_labels, max_duration=1000, gain=20)\n", 77 | "data_t = torch.FloatTensor(data)\n", 78 | "labels_t = torch.Tensor(labels1h)" 79 | ] 80 | }, 81 | { 82 | "cell_type": "markdown", 83 | "metadata": {}, 84 | "source": [ 85 | "Let's examine the shape of data:" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 4, 91 | "metadata": {}, 92 | "outputs": [ 93 | { 94 | "data": { 95 | "text/plain": [ 96 | "(1000, 100, 1, 28, 28)" 97 | ] 98 | }, 99 | "execution_count": 4, 100 | "metadata": {}, 101 | "output_type": "execute_result" 102 | } 103 | ], 104 | "source": [ 105 | "data.shape" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "1000 here corresponds to the number of time steps, 100 is the batch size, 1 is the number of channels and 28,28 are the height and width, respectively. The last three dimensions will be important when we'll use convolutional or locally connected layers, but for the moment, our network has no structure." 
113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "Here is what one sample looks like" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 5, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "image/png": "\n", 130 | "text/plain": [ 131 | "
" 132 | ] 133 | }, 134 | "metadata": { 135 | "needs_background": "light" 136 | }, 137 | "output_type": "display_data" 138 | }, 139 | { 140 | "data": { 141 | "text/plain": [ 142 | "(, None)" 143 | ] 144 | }, 145 | "execution_count": 5, 146 | "metadata": {}, 147 | "output_type": "execute_result" 148 | } 149 | ], 150 | "source": [ 151 | "snn_utils.plotLIF(U=None,S=data_t[:,0].view(1000,-1).data.numpy())" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "The average reate here is more revealing. Our MNIST input spike trains are simply flattened spiketrains representing the digit image in the firing rates." 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": 6, 164 | "metadata": {}, 165 | "outputs": [ 166 | { 167 | "data": { 168 | "text/plain": [ 169 | "" 170 | ] 171 | }, 172 | "execution_count": 6, 173 | "metadata": {}, 174 | "output_type": "execute_result" 175 | }, 176 | { 177 | "data": { 178 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+j8jraAAANUUlEQVR4nO3dbYxU53nG8esywSwmYExsMAVaiEsao1Ql6QqnxWpdWbGIpQg7Uprg1qWRJfIhbmM1qmqlamOpUoXaJFY/pJFwTU3cxFFaB5kPlmNEIiEnhHhtAQbTGtshMS9m6+KIdRNe9+6HPa4WvPPMMnPmBe7/TxrNzLnnzLk17MU5M8+ceRwRAnD5u6LXDQDoDsIOJEHYgSQIO5AEYQeSeFc3N3alp8WAZnRzk0AqJ/W/Oh2nPFGtrbDbXiXpnyRNkfQvEbG+9PgBzdBNvrWdTQIo2BnbGtZaPoy3PUXSVyV9VNIySWtsL2v1+QB0Vjvv2VdIejkiXo2I05K+JWl1PW0BqFs7YV8g6bVx9w9Vy85je53tIdtDZ3Sqjc0BaEc7YZ/oQ4B3fPc2IjZExGBEDE7VtDY2B6Ad7YT9kKRF4+4vlHSkvXYAdEo7YX9W0lLbS2xfKelTkrbU0xaAurU89BYRZ23fK+m7Ght62xgR+2rrDECt2hpnj4gnJT1ZUy8AOoivywJJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRFenbAbGO/DIbxfr7xo4U6zfsP5ssT6668WL7ulyxp4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnB0ddeKuDzesLVv80+K6N856vVjfsX5JsT5jVbGcTltht31Q0oikc5LORsRgHU0BqF8de/Y/iIg3angeAB3Ee3YgiXbDHpKetv2c7XUTPcD2OttDtofO6FS
bmwPQqnYP41dGxBHbcyVttf2fEbF9/AMiYoOkDZI0y3Oize0BaFFbe/aIOFJdD0vaLGlFHU0BqF/LYbc9w/bMt29Luk3S3roaA1Cvdg7j50nabPvt5/lmRDxVS1e4ZExZ+t5y/U+GG9Ze+/ns4rp/v3hzsf7Ml24q1qVXm9RzaTnsEfGqpN+qsRcAHcTQG5AEYQeSIOxAEoQdSIKwA0lwiivaEjMGivXZA2+2/Nx37/p0sX5Vy8+cE3t2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCcXa05ZX7m/wJvX5dw9L0JlMy//Gv/7hY/96L5VNcR4vVfNizA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLOjKFYuL9bnzj5RrM8e+GXD2ks/XFxc96GTK4v1hXOvLNanFqv5sGcHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQYZ0fRL/6mPI5++OC1xfrw7JMNa/9x14PFde/cfF+xPm3HvmKd89nP13TPbnuj7WHbe8ctm2N7q+0D1fU1nW0TQLsmcxj/iKRVFyy7X9K2iFgqaVt1H0Afaxr2iNgu6fgFi1dL2lTd3iTpjpr7AlCzVj+gmxcRRyWpup7b6IG219kesj10Rqda3ByAdnX80/iI2BARgxExOFXTOr05AA20GvZjtudLUnU9XF9LADqh1bBvkbS2ur1W0hP1tAOgU5qOs9t+TNItkq61fUjSFyWtl/Rt2/dI+pmkT3SySfSvWdePFOuPLv/XhrVdJxcW1x2dXh4pHx0pbxvnaxr2iFjToHRrzb0A6CC+LgskQdiBJAg7kARhB5Ig7EASnOKa3JnbBps84sLTIs73yxdnF+uP/urvNKx9ZNbehjVJWvro6WIdF4c9O5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwTj7Za7ZlMsH74pi/ear3yjWZ/9u4ymZJWn/iesb1p546TeL6y75wa5iHReHPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4+2Xu+PunF+t+szzOvvvYgmJ95kB5Sq+Rk41nAZq1dUZxXdSLPTuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJME4+2XuzQ+Ux9Hn3zhcrJ/893nF+sjHfl6sL5rduD66vTxGf65YxcVqume3vdH2sO2945Y9YPuw7V3V5fbOtgmgXZM5jH9E0qoJlj8YEcury5P1tgWgbk3DHhHb1WwOIAB9r50P6O61vac6zL+m0YNsr7M9ZHvojMrv0QB0Tqth/5qkGyQtl3RU0pcbPTAiNkTEYEQMTlXjkyIAdFZLYY+IYxFxLiJGJT0kaUW9bQGoW0thtz1/3N07JZXn3gXQc03H2W0/JukWSdfaPiTpi5Jusb1cUkg6KOkzHewRTVwxc2bD2sd/f2dx3dLvukvS4eku1t/6ydXF+ut7G8/f/p4DO4rrol5Nwx4RayZY/HAHegHQQXxdFkiCsANJEHYgCcIOJEHYgSQ4xfUS0Gza5Vf+bLRx7anyzzWfm14+BXbKzW+V68WqdN3f/qRhrXHX6AT27EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBOPsl4Bm0y5PH2j8c823rfpRcd0dw0uK9X9+/zeL9dVP/XmxPjoyUqyje9izA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EASjLNfAn7xK+Wfc54z0HharWOnGv/MtCStXri7WP+jr/5Fsf6+f/xhsY7+wZ4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0PXLF8WbH+6U9+t1j/yzmvNKwt2bKuuO7u6xcU69ftPl2s49LRdM9ue5Ht79veb3uf7c9Vy+fY3mr7QHV9TefbBdCqyRzGn5X0+Yi4UdKHJX3W9jJJ90vaFhFLJW2r7gPoU03DHhFHI+L56vaIpP2SFkhaLWlT9bBNku7oVJMA2ndRH9DZXizpg5J2SpoXEUelsf8QJM1tsM4620O2h86o8Xe4AXTWpMNu+92SHpd0X0ScmOx
6EbEhIgYjYnCqprXSI4AaTCrstqdqLOjfiIjvVIuP2Z5f1edLGu5MiwDq0HTozbYlPSxpf0R8ZVxpi6S1ktZX1090pMMEjjxQnjb5ob0ri/V/G1jRsDb9cPmf+KrtVxfrU58u/xQ1Lh2TGWdfKeluSS/Y3lUt+4LGQv5t2/dI+pmkT3SmRQB1aBr2iHhGUqNfT7i13nYAdApflwWSIOxAEoQdSIKwA0kQdiAJTnHtA6M/KJ8wuHrNjmL98T0faly8ofwV5UV/91yxjssHe3YgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9j4w7Xj5fPa755TH2Y/9RuNpmf/n41cV1z1brOJywp4dSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JwRHmMt06zPCduMj9IC3TKztimE3F8wl+DZs8OJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0k0DbvtRba/b3u/7X22P1ctf8D2Ydu7qsvtnW8XQKsm8+MVZyV9PiKetz1T0nO2t1a1ByPiS51rD0BdJjM/+1FJR6vbI7b3S1rQ6cYA1Oui3rPbXizpg5J2Vovutb3H9kbbE85hZHud7SHbQ2dUnooIQOdMOuy23y3pcUn3RcQJSV+TdIOk5Rrb8395ovUiYkNEDEbE4FRNq6FlAK2YVNhtT9VY0L8REd+RpIg4FhHnImJU0kOSVnSuTQDtmsyn8Zb0sKT9EfGVccvnj3vYnZL21t8egLpM5tP4lZLulvSC7V3Vsi9IWmN7uaSQdFDSZzrSIYBaTObT+GckTXR+7JP1twOgU/gGHZAEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IImuTtls+78l/XTcomslvdG1Bi5Ov/bWr31J9NaqOnv7tYi4bqJCV8P+jo3bQxEx2LMGCvq1t37tS6K3VnWrNw7jgSQIO5BEr8O+ocfbL+nX3vq1L4neWtWV3nr6nh1A9/R6zw6gSwg7kERPwm57le3/sv2y7ft70UMjtg/afqGahnqox71stD1se++4ZXNsb7V9oLqecI69HvXWF9N4F6YZ7+lr1+vpz7v+nt32FEkvSfqIpEOSnpW0JiJe7GojDdg+KGkwInr+BQzbvyfpLUlfj4gPVMv+QdLxiFhf/Ud5TUT8VZ/09oCkt3o9jXc1W9H88dOMS7pD0p+qh69doa8/VBdet17s2VdIejkiXo2I05K+JWl1D/roexGxXdLxCxavlrSpur1JY38sXdegt74QEUcj4vnq9oikt6cZ7+lrV+irK3oR9gWSXht3/5D6a773kPS07edsr+t1MxOYFxFHpbE/Hklze9zPhZpO491NF0wz3jevXSvTn7erF2GfaCqpfhr/WxkRH5L0UUmfrQ5XMTmTmsa7WyaYZrwvtDr9ebt6EfZDkhaNu79Q0pEe9DGhiDhSXQ9L2qz+m4r62Nsz6FbXwz3u5//10zTeE00zrj547Xo5/Xkvwv6spKW2l9i+UtKnJG3pQR/vYHtG9cGJbM+QdJv6byrqLZLWVrfXSnqih72cp1+m8W40zbh6/Nr1fPrziOj6RdLtGvtE/hVJf92LHhr09V5Ju6vLvl73JukxjR3WndHYEdE9kt4jaZukA9X1nD7q7VFJL0jao7Fgze9Rbzdr7K3hHkm7qsvtvX7tCn115XXj67JAEnyDDkiCsANJEHYgCcIOJEHYgSQIO5AEYQeS+D+d4/ALFjNYogAAAABJRU5ErkJggg==\n", 179 | "text/plain": [ 180 | "
" 181 | ] 182 | }, 183 | "metadata": { 184 | "needs_background": "light" 185 | }, 186 | "output_type": "display_data" 187 | } 188 | ], 189 | "source": [ 190 | "plt.imshow(data_t[:,0].data.numpy().mean(axis=0)[0])" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 7, 196 | "metadata": {}, 197 | "outputs": [ 198 | { 199 | "name": "stdout", 200 | "output_type": "stream", 201 | "text": [ 202 | "tensor(1)\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "print(labels_t[:,0].argmax(1)[-1])" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "Let's create an iterator function that does all these steps:" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 8, 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [ 223 | "def iter_mnist(gen_train, batchsize=100, T=1000, max_rate = 20):\n", 224 | " datait = iter(gen_train)\n", 225 | " for raw_input, raw_labels in datait:\n", 226 | " data, labels1h = snn_utils.image2spiketrain(raw_input, raw_labels, max_duration=T, gain=max_rate)\n", 227 | " data_t = torch.FloatTensor(data)\n", 228 | " labels_t = torch.Tensor(labels1h)\n", 229 | " yield data_t, labels_t " 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "metadata": {}, 235 | "source": [ 236 | "### Creating the MNIST Nwtwork" 237 | ] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": {}, 242 | "source": [ 243 | "We make use of decolle, a python module contains a cleanly written DECOLLE module based on the principles described in tutorial 1. decolle should have been cloned under lib when you clones this repository. (In git terms, it is a submodule). 
If the decolle module isn't there run commented cell below" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 9, 249 | "metadata": {}, 250 | "outputs": [ 251 | { 252 | "name": "stdout", 253 | "output_type": "stream", 254 | "text": [ 255 | "Submodule path 'decolle_public': checked out '62179e23bfea4cbec29d372f665cf2ed8251b935'\n", 256 | "Entering 'decolle_public'\n", 257 | "From github.com:nmi-lab/decolle-public\n", 258 | " * branch master -> FETCH_HEAD\n", 259 | "Updating 62179e2..8d56959\n", 260 | "Fast-forward\n", 261 | " .gitignore | 1 \u001b[32m+\u001b[m\n", 262 | " README.md | 31 \u001b[32m++\u001b[m\n", 263 | " decolle/base_model.py | 94 \u001b[32m+++\u001b[m\u001b[31m--\u001b[m\n", 264 | " decolle/lenet_decolle_model.py | 23 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n", 265 | " decolle/lenet_decolle_model_errortriggered.py | 112 \u001b[32m++++++\u001b[m\n", 266 | " decolle/lenet_decolle_model_fa.py | 139 \u001b[32m++++++++\u001b[m\n", 267 | " decolle/lenet_delle.py | 29 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n", 268 | " decolle/snn_utils.py | 390 \u001b[32m+++++++++++++++++++++\u001b[m\n", 269 | " decolle/utils.py | 220 \u001b[32m+++++++++\u001b[m\u001b[31m---\u001b[m\n", 270 | " pull_files_only.sh | 2 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n", 271 | " scripts/parameters/experiment_params.yml | 34 \u001b[31m--\u001b[m\n", 272 | " scripts/parameters/params.yml | 45 \u001b[32m++\u001b[m\u001b[31m-\u001b[m\n", 273 | " scripts/parameters/params_delle.yml | 37 \u001b[31m--\u001b[m\n", 274 | " .../{params_unittest.yml => params_dvsgesture.yml} | 16 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n", 275 | " .../params_dvsgestures_torchneuromorphic.yml | 52 \u001b[32m+++\u001b[m\n", 276 | " scripts/parameters/params_nmnist.yml | 37 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n", 277 | " .../parameters/params_ntidigits_errtrig_mlp.yml | 52 \u001b[31m---\u001b[m\n", 278 | " scripts/train_lenet_decolle.py | 39 \u001b[32m++\u001b[m\u001b[31m-\u001b[m\n", 279 
| " scripts/train_lenet_decolle_error_triggered.py | 149 \u001b[32m++++++++\u001b[m\n", 280 | " scripts/train_lenet_decolle_fa.py | 134 \u001b[32m+++++++\u001b[m\n", 281 | " setup.py | 7 \u001b[32m+\u001b[m\u001b[31m-\u001b[m\n", 282 | " 21 files changed, 1374 insertions(+), 269 deletions(-)\n", 283 | " create mode 100644 decolle/lenet_decolle_model_errortriggered.py\n", 284 | " create mode 100644 decolle/lenet_decolle_model_fa.py\n", 285 | " create mode 100644 decolle/snn_utils.py\n", 286 | " delete mode 100644 scripts/parameters/experiment_params.yml\n", 287 | " delete mode 100644 scripts/parameters/params_delle.yml\n", 288 | " rename scripts/parameters/{params_unittest.yml => params_dvsgesture.yml} (74%)\n", 289 | " create mode 100644 scripts/parameters/params_dvsgestures_torchneuromorphic.yml\n", 290 | " delete mode 100644 scripts/parameters/params_ntidigits_errtrig_mlp.yml\n", 291 | " create mode 100644 scripts/train_lenet_decolle_error_triggered.py\n", 292 | " create mode 100644 scripts/train_lenet_decolle_fa.py\n" 293 | ] 294 | } 295 | ], 296 | "source": [ 297 | "!git submodule update --init decolle_public \n", 298 | "!git submodule foreach git pull origin master" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 10, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [ 307 | "import decolle_public.decolle as decolle\n", 308 | "from decolle_public.decolle import lenet_decolle_model " 309 | ] 310 | }, 311 | { 312 | "cell_type": "markdown", 313 | "metadata": {}, 314 | "source": [ 315 | "The following creates a three layer DECOLLE network" 316 | ] 317 | }, 318 | { 319 | "cell_type": "code", 320 | "execution_count": 11, 321 | "metadata": {}, 322 | "outputs": [ 323 | { 324 | "name": "stdout", 325 | "output_type": "stream", 326 | "text": [ 327 | "torch.Size([300, 100, 1, 28, 28])\n" 328 | ] 329 | } 330 | ], 331 | "source": [ 332 | "T = 300 #duration of sequence\n", 333 | "data, target = next(iter_mnist(gen_train, T=T))\n", 334 | 
"data = data.to(device)\n", 335 | "target = target.to(device)\n", 336 | "print(data.shape)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 12, 342 | "metadata": {}, 343 | "outputs": [], 344 | "source": [ 345 | "net = lenet_decolle_model.LenetDECOLLE(input_shape = data.shape[2:], Mhid = [150,120], num_conv_layers=0, num_mlp_layers=2, alpha=[.95],beta=[.92], lc_ampl=.5, out_channels=10).to(device)" 346 | ] 347 | }, 348 | { 349 | "cell_type": "markdown", 350 | "metadata": {}, 351 | "source": [ 352 | "Let's examine the created network" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "execution_count": 13, 358 | "metadata": {}, 359 | "outputs": [ 360 | { 361 | "data": { 362 | "text/plain": [ 363 | "LenetDECOLLE(\n", 364 | " (LIF_layers): ModuleList(\n", 365 | " (0): LIFLayer(\n", 366 | " (base_layer): Linear(in_features=784, out_features=150, bias=True)\n", 367 | " )\n", 368 | " (1): LIFLayer(\n", 369 | " (base_layer): Linear(in_features=150, out_features=120, bias=True)\n", 370 | " )\n", 371 | " )\n", 372 | " (readout_layers): ModuleList(\n", 373 | " (0): Linear(in_features=150, out_features=10, bias=True)\n", 374 | " (1): Linear(in_features=120, out_features=10, bias=True)\n", 375 | " )\n", 376 | " (pool_layers): ModuleList(\n", 377 | " (0): Sequential()\n", 378 | " (1): Sequential()\n", 379 | " )\n", 380 | " (dropout_layers): ModuleList(\n", 381 | " (0): Dropout(p=0.5, inplace=False)\n", 382 | " (1): Dropout(p=0.5, inplace=False)\n", 383 | " )\n", 384 | ")" 385 | ] 386 | }, 387 | "execution_count": 13, 388 | "metadata": {}, 389 | "output_type": "execute_result" 390 | } 391 | ], 392 | "source": [ 393 | "net" 394 | ] 395 | }, 396 | { 397 | "cell_type": "markdown", 398 | "metadata": {}, 399 | "source": [ 400 | "The network consists in 2 LIF layers, 150 and 120 neurons, with a readout layer and dropout layer associated to each. There are no pool_layers here, therefore they are represented as pass through layers (Sequential()). 
Dropout modules are used for the readout. Next we set up loss and optimization" 401 | ] 402 | }, 403 | { 404 | "cell_type": "code", 405 | "execution_count": 14, 406 | "metadata": {}, 407 | "outputs": [], 408 | "source": [ 409 | "loss = torch.nn.SmoothL1Loss()\n", 410 | "opt = torch.optim.Adamax(net.get_trainable_parameters(), lr=1e-8, betas=[0., .95])\n", 411 | "\n", 412 | "def decolle_loss(r, s, tgt):\n", 413 | " loss_tv = 0\n", 414 | " for i in range(len(r)):\n", 415 | " loss_tv += loss(r[i],tgt) \n", 416 | " return loss_tv" 417 | ] 418 | }, 419 | { 420 | "cell_type": "markdown", 421 | "metadata": {}, 422 | "source": [ 423 | "Initialize the DECOLLE network with the correct batch size. To avoid problems with initialization there is a burn-in period where the state variables are allowed to settle. There is a little quirk here in that we need to swap the batch dimension and the time dimension." 424 | ] 425 | }, 426 | { 427 | "cell_type": "code", 428 | "execution_count": 15, 429 | "metadata": {}, 430 | "outputs": [], 431 | "source": [ 432 | "net.init(data.transpose(0,1), len(net))" 433 | ] 434 | }, 435 | { 436 | "cell_type": "markdown", 437 | "metadata": {}, 438 | "source": [ 439 | "Initialize the parameters" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 16, 445 | "metadata": {}, 446 | "outputs": [], 447 | "source": [ 448 | "net.init_parameters(data.transpose(0,1))" 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "metadata": {}, 454 | "source": [ 455 | "Train the network. The error is computed at every epoch. The readout is based on total output across the entire sequence. The readout can be improved (see DECOLLE paper)." 
456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": 17, 461 | "metadata": {}, 462 | "outputs": [ 463 | { 464 | "name": "stdout", 465 | "output_type": "stream", 466 | "text": [ 467 | "Training Error tensor(0.2983)\n", 468 | "Epoch 0 Loss tensor(24.0427, device='cuda:0')\n", 469 | "Training Error tensor(0.1617)\n", 470 | "Epoch 1 Loss tensor(20.5774, device='cuda:0')\n", 471 | "Training Error tensor(0.1333)\n", 472 | "Epoch 2 Loss tensor(18.5669, device='cuda:0')\n", 473 | "Training Error tensor(0.1017)\n", 474 | "Epoch 3 Loss tensor(18.5628, device='cuda:0')\n", 475 | "Training Error tensor(0.1150)\n", 476 | "Epoch 4 Loss tensor(17.9603, device='cuda:0')\n", 477 | "Training Error tensor(0.1100)\n", 478 | "Epoch 5 Loss tensor(16.6971, device='cuda:0')\n", 479 | "Training Error tensor(0.0883)\n", 480 | "Epoch 6 Loss tensor(15.3596, device='cuda:0')\n", 481 | "Training Error tensor(0.0900)\n", 482 | "Epoch 7 Loss tensor(16.1182, device='cuda:0')\n", 483 | "Training Error tensor(0.0767)\n", 484 | "Epoch 8 Loss tensor(15.9183, device='cuda:0')\n", 485 | "Training Error tensor(0.0833)\n", 486 | "Epoch 9 Loss tensor(15.6708, device='cuda:0')\n" 487 | ] 488 | } 489 | ], 490 | "source": [ 491 | "for e in range(10): \n", 492 | " error = []\n", 493 | " for data, label in iter_mnist(gen_train, T=T):\n", 494 | " net.train()\n", 495 | " loss_hist = 0\n", 496 | " data_d = data.to(device)\n", 497 | " label_d = label.to(device)\n", 498 | " net.init(data_d.transpose(0,1), burnin=100)\n", 499 | " readout = 0\n", 500 | " for n in range(T):\n", 501 | " st, rt, ut = net.forward(data_d[n]) \n", 502 | " loss_tv = decolle_loss(rt, st, label_d[n])\n", 503 | " loss_tv.backward()\n", 504 | " opt.step()\n", 505 | " opt.zero_grad()\n", 506 | " loss_hist += loss_tv\n", 507 | " readout += rt[-1]\n", 508 | " error += (readout.argmax(axis=1)!=label_d[-1].argmax(axis=1)).float()\n", 509 | " print('Training Error', torch.mean(torch.Tensor(error)).data)\n", 510 | " \n", 511 | " 
print('Epoch', e, 'Loss', loss_hist.data)" 512 | ] 513 | }, 514 | { 515 | "cell_type": "markdown", 516 | "metadata": {}, 517 | "source": [ 518 | "## Convolutional DECOLLE" 519 | ] 520 | }, 521 | { 522 | "cell_type": "markdown", 523 | "metadata": {}, 524 | "source": [ 525 | "A convolutional DECOLLE network can be obtained by replacing the network generation as follows" 526 | ] 527 | }, 528 | { 529 | "cell_type": "code", 530 | "execution_count": 18, 531 | "metadata": {}, 532 | "outputs": [], 533 | "source": [ 534 | "convnet = lenet_decolle_model.LenetDECOLLE( out_channels=10,\n", 535 | " Nhid=[16,32], #Number of convolution channels\n", 536 | " Mhid=[64],\n", 537 | " kernel_size=[7],\n", 538 | " pool_size=[2,2],\n", 539 | " input_shape=data.shape[2:],\n", 540 | " alpha=[.95],\n", 541 | " alpharp=[.65],\n", 542 | " beta=[.92],\n", 543 | " num_conv_layers=2,\n", 544 | " num_mlp_layers=1,\n", 545 | " lc_ampl=.5).to(device)" 546 | ] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "execution_count": 19, 551 | "metadata": {}, 552 | "outputs": [], 553 | "source": [ 554 | "data, target = next(iter_mnist(gen_train, T=T))\n", 555 | "data_d = data.to(device)\n", 556 | "target_d = target.to(device)\n", 557 | "convnet.init_parameters(data_d.transpose(0,1))" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": 20, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [ 566 | "opt_conv = torch.optim.Adamax(convnet.get_trainable_parameters(), lr=1e-9, betas=[0., .95])" 567 | ] 568 | }, 569 | { 570 | "cell_type": "markdown", 571 | "metadata": {}, 572 | "source": [ 573 | "Train the network" 574 | ] 575 | }, 576 | { 577 | "cell_type": "code", 578 | "execution_count": null, 579 | "metadata": {}, 580 | "outputs": [ 581 | { 582 | "name": "stdout", 583 | "output_type": "stream", 584 | "text": [ 585 | "Training Error tensor(0.3883)\n", 586 | "Epoch 0 Loss tensor(42.2926, device='cuda:0')\n", 587 | "Training Error tensor(0.1317)\n", 588 | "Epoch 1 Loss 
tensor(39.1293, device='cuda:0')\n", 589 | "Training Error tensor(0.0933)\n", 590 | "Epoch 2 Loss tensor(36.5215, device='cuda:0')\n", 591 | "Training Error tensor(0.0817)\n", 592 | "Epoch 3 Loss tensor(35.7263, device='cuda:0')\n", 593 | "Training Error tensor(0.0617)\n", 594 | "Epoch 4 Loss tensor(34.1334, device='cuda:0')\n", 595 | "Training Error tensor(0.0433)\n", 596 | "Epoch 5 Loss tensor(34.7060, device='cuda:0')\n", 597 | "Training Error tensor(0.0383)\n", 598 | "Epoch 6 Loss tensor(33.5180, device='cuda:0')\n", 599 | "Training Error tensor(0.0250)\n", 600 | "Epoch 7 Loss tensor(33.7322, device='cuda:0')\n", 601 | "Training Error tensor(0.0167)\n", 602 | "Epoch 8 Loss tensor(31.7986, device='cuda:0')\n" 603 | ] 604 | } 605 | ], 606 | "source": [ 607 | "for e in range(10): \n", 608 | " error = []\n", 609 | " for data, label in iter_mnist(gen_train, T=T):\n", 610 | " convnet.train()\n", 611 | " loss_hist = 0\n", 612 | " data_d = data.to(device)\n", 613 | " label_d = label.to(device)\n", 614 | " convnet.init(data_d.transpose(0,1), burnin=100)\n", 615 | " readout = 0\n", 616 | " for n in range(T):\n", 617 | " st, rt, ut = convnet.forward(data_d[n]) \n", 618 | " loss_tv = decolle_loss(rt, st, label_d[n])\n", 619 | " loss_tv.backward()\n", 620 | " opt_conv.step()\n", 621 | " opt_conv.zero_grad()\n", 622 | " loss_hist += loss_tv\n", 623 | " readout += rt[-1]\n", 624 | " error += (readout.argmax(axis=1)!=label_d[-1].argmax(axis=1)).float()\n", 625 | " print('Training Error', torch.mean(torch.Tensor(error)).data)\n", 626 | " print('Epoch', e, 'Loss', loss_hist.data)" 627 | ] 628 | } 629 | ], 630 | "metadata": { 631 | "kernelspec": { 632 | "display_name": "Python 3", 633 | "language": "python", 634 | "name": "python3" 635 | }, 636 | "language_info": { 637 | "codemirror_mode": { 638 | "name": "ipython", 639 | "version": 3 640 | }, 641 | "file_extension": ".py", 642 | "mimetype": "text/x-python", 643 | "name": "python", 644 | "nbconvert_exporter": "python", 645 | 
"pygments_lexer": "ipython3", 646 | "version": "3.8.2" 647 | } 648 | }, 649 | "nbformat": 4, 650 | "nbformat_minor": 4 651 | } 652 | --------------------------------------------------------------------------------