├── README.md
├── models.py
├── utils.py
├── main.py
├── plot_compute_slowness.ipynb
├── main_pytorch.py
├── Tutorial_PEPITA_FullyConnectedNets_CIFAR-10.ipynb
└── functions.py
/README.md:
--------------------------------------------------------------------------------
1 | # PEPITA
2 | 
3 | Code to run the simulations of the paper:
4 | ### Error-driven Input Modulation: Solving the Credit Assignment Problem without a Backward Pass
5 | 
6 | Giorgia Dellaferrera, Gabriel Kreiman
7 | 
8 | Presented at ICML 2022: https://proceedings.mlr.press/v162/dellaferrera22a.html
9 | 
10 | 
11 | # Requirements
12 | We ran the experiments with the following setups:
13 | 
14 | NumPy framework (fully connected models): Python 3.9.5, NumPy 1.19.5, Keras 2.5.0
15 | 
16 | PyTorch framework (convolutional models): Python 3.7.10, NumPy 1.19.2, PyTorch 1.6.0
17 | 
18 | 
19 | # Experiments
20 | 
21 | ## Fully connected models - PyTorch version
22 | 
23 | The notebook `Tutorial_PEPITA_FullyConnectedNets_CIFAR-10.ipynb` provides a simple tutorial on how to implement and run the PEPITA training scheme for fully connected models. The entire framework is PyTorch-based. The settings and results match those reported in the paper.
24 | 
25 | Training for 100 epochs takes approximately 1.5 hours on CPU.
26 | 
27 | ## Fully connected models - NumPy version
28 | 
29 | The experiments are run through `main.py`, which uses functions in `functions.py` and `utils.py`.
30 | The entire framework is NumPy-based and relies on the Keras library to load the datasets.
31 | 
32 | For example, to run PEPITA with the standard settings on the MNIST dataset:
33 | ```
34 | python main.py --exp_name Experiment1 \
35 | --learn_type ERIN --n_runs 1 --train_epochs 100 \
36 | --sample_passes 2 --n_samples all --eta 0.1 --dropout 0.9 \
37 | --eta_decay --mnist --validation --batch_size 64 \
38 | --update_type mom --w_init he_uniform \
39 | --build auto --struct uniform --start_size 1024 --n_hlayers 1 --act_hidden relu --act_out softmax
40 | ```
41 | 
42 | Note that the training scheme for PEPITA is denoted as ERIN (ERror-INput).
43 | If you train with PEPITA (ERIN), make sure to use `--sample_passes 2`, so that each input undergoes two forward passes (see the sketch in the "PEPITA update rule" section below).
44 | 
45 | Replace `--learn_type ERIN` with `--learn_type BP` to train the network with backpropagation, and remember to set `--sample_passes 1`.
46 | 
47 | ## Convolutional models - PyTorch version
48 | 
49 | The experiments are run through `main_pytorch.py`, which uses functions in `models.py`. The entire framework is PyTorch-based.
50 | 
51 | For example, to run PEPITA with the standard settings on the MNIST dataset:
52 | ```
53 | python main_pytorch.py --exp_name Experiment2 \
54 | --learn_type ERIN --n_runs 1 --train_epochs 100 \
55 | --eta 0.01 --dropout 0.9 --Bstd 0.05 \
56 | --eta_decay --dataset mn --batch_size 50 \
57 | --update_type mom --w_init he_uniform \
58 | --model Net1conv1fcXL
59 | ```
60 | 
61 | The argument `Bstd` defines the standard deviation used to initialize the projection matrix.
62 | Here we use `B` instead of `F` (the notation of the paper) to denote the projection matrix, to avoid confusion with `torch.nn.functional`.
63 | 
64 | ## Convergence rate
65 | 
66 | The notebook `plot_compute_slowness.ipynb` contains the function that extracts the convergence rate as a "slowness" parameter: the slowness `s` is obtained by fitting the test curve with `acc(n) = acc_max * n / (s + n)`, where `n` is the training epoch.
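## PEPITA update rule (sketch)

For reference, the snippet below sketches one PEPITA (ERIN) training step for a one-hidden-layer fully connected network, following the rule described in the paper: a standard forward pass, a second forward pass on the error-modulated input `x + F @ e`, and weight updates computed from the difference between the two passes (no backward pass). This is only a minimal NumPy illustration, not the code used in the experiments (`main.py`/`functions.py` implement the full training loop with dropout, momentum, batching, etc.); the variable names and initialization scales are illustrative.

```python
import numpy as np

def pepita_step(x, target, W1, W2, F, eta=0.1):
    """One PEPITA (ERIN) update: two forward passes, no backward pass."""
    relu = lambda z: np.maximum(z, 0)
    softmax = lambda z: np.exp(z - z.max()) / np.sum(np.exp(z - z.max()))

    # first (standard) forward pass
    h1 = relu(W1 @ x)
    out = softmax(W2 @ h1)
    e = out - target                        # output error

    # second forward pass on the error-modulated input
    x_err = x + F @ e                       # F projects the error back onto the input
    h1_err = relu(W1 @ x_err)

    # weight updates from the difference between the two passes
    W1 -= eta * (h1 - h1_err) @ x_err.T
    W2 -= eta * e @ h1_err.T                # the output layer uses the error directly
    return W1, W2

# illustrative shapes: 784 inputs (MNIST), 1024 hidden units, 10 classes
rng = np.random.default_rng(0)
W1 = rng.uniform(-1, 1, (1024, 784)) * np.sqrt(6.0 / 784)   # he_uniform-style init (illustrative)
W2 = rng.uniform(-1, 1, (10, 1024)) * np.sqrt(6.0 / 1024)
F = rng.uniform(-1, 1, (784, 10)) * 0.05                    # fixed random projection; scale is illustrative

x = rng.random((784, 1))                    # one flattened, normalized input
target = np.zeros((10, 1))
target[3] = 1                               # one-hot target
W1, W2 = pepita_step(x, target, W1, W2, F)
```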
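For the convolutional models, `models.py` provides `compute_delta_w_conv`, which assembles the analogous update for a convolutional kernel from the input patches and the per-position difference between the activations of the two passes. The sketch below shows how it can be wired to `Net1conv1fcXL`; it is again only an illustration, under the assumptions that `B` is drawn from a normal distribution with standard deviation `Bstd` and that the update is applied as `w -= eta * delta_w`, analogously to the fully connected rule (the exact scheme used in the experiments is implemented in `main_pytorch.py`).

```python
import torch
from models import Net1conv1fcXL, compute_delta_w_conv

eta, Bstd = 0.01, 0.05
model = Net1conv1fcXL(ch_input=1, nout=10)
B = torch.empty(10, 28 * 28).normal_(std=Bstd)      # projection from the 10 outputs back to the flattened input

x = torch.rand(50, 1, 28, 28)                        # a batch of MNIST-like images in [0, 1]
target = torch.nn.functional.one_hot(torch.randint(0, 10, (50,)), 10).float()

with torch.no_grad():                                # PEPITA does not use autograd
    out = model(x, do_masks=None)                    # first (standard) pass
    e = out - target
    x_err = x + (e @ B).reshape(50, 1, 28, 28)       # error-modulated input
    conv = torch.relu(model.conv1(x))                # conv activations of the two passes
    conv_err = torch.relu(model.conv1(x_err))
    delta_w = compute_delta_w_conv(x_err, conv - conv_err, model.conv1.weight.shape)
    model.conv1.weight -= eta * delta_w
    # the fully connected layer (model.fc1) would be updated with the dense rule sketched above
```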
67 | 68 | 69 | ## Citation 70 | ``` 71 | 72 | @InProceedings{pmlr-v162-dellaferrera22a, 73 | title = {Error-driven Input Modulation: Solving the Credit Assignment Problem without a Backward Pass}, 74 | author = {Dellaferrera, Giorgia and Kreiman, Gabriel}, 75 | booktitle = {Proceedings of the 39th International Conference on Machine Learning}, 76 | pages = {4937--4955}, 77 | year = {2022}, 78 | editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan}, 79 | volume = {162}, 80 | series = {Proceedings of Machine Learning Research}, 81 | month = {17--23 Jul}, 82 | publisher = {PMLR}, 83 | pdf = {https://proceedings.mlr.press/v162/dellaferrera22a/dellaferrera22a.pdf}, 84 | url = {https://proceedings.mlr.press/v162/dellaferrera22a.html}, 85 | } 86 | 87 | ``` 88 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Oct 15 11:24:59 2021 5 | 6 | @author: anonymous_ICML 7 | """ 8 | 9 | import torch 10 | import torchvision 11 | import torchvision.transforms as transforms 12 | 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | 17 | import copy 18 | 19 | from torch.autograd import Variable 20 | 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | 24 | # convolutional models for PEPITA 25 | # note that the softmax operation needs to be removed to use the models with BP 26 | 27 | class Net1conv1fcXL(nn.Module): 28 | def __init__(self,ch_input,nout): 29 | super().__init__() 30 | self.conv1 = nn.Conv2d(1, 32, 5,bias=False) 31 | self.fc1 = nn.Linear(4608, nout,bias=False) 32 | self.pool = nn.MaxPool2d(2, 2) 33 | 34 | def forward(self, x, do_masks): 35 | x = self.pool(F.relu(self.conv1(x))) 36 | x = torch.flatten(x, 1) # flatten all dimensions except batch 37 | x = F.softmax(self.fc1(x)) 38 | return x 39 | 40 | class Net1conv1fcXL_cif(nn.Module): 41 | def __init__(self,ch_input,nout): 42 | super().__init__() 43 | self.conv1 = nn.Conv2d(3, 32, 5,bias=False) 44 | self.fc1 = nn.Linear(6272, nout,bias=False) 45 | self.pool = nn.MaxPool2d(2, 2) 46 | 47 | def forward(self, x, do_masks): 48 | x = self.pool(F.relu(self.conv1(x))) 49 | x = torch.flatten(x, 1) # flatten all dimensions except batch 50 | x = F.softmax(self.fc1(x)) 51 | return x 52 | 53 | 54 | 55 | # for conv models 56 | def compute_delta_w_conv(inp,out_diff,w_shape,stride=1,sqrt=False, plot=False, plot2d=False): 57 | delta_w = torch.zeros(w_shape) 58 | ch_out = w_shape[0] # number of output channels 59 | size_out = out_diff.shape[-1] # size of output map 60 | ch_in = w_shape[1] # number of input channels 61 | size_in = inp.shape[-1] # size of input map 62 | ks = w_shape[2] # kernel height and width 63 | bs = out_diff.shape[0] 64 | cnt = 0 65 | #print(ch_out,size_out,ch_in,size_in,ks,bs) 66 | if plot: 67 | fig, axs = plt.subplots(1, size_out**2, figsize=(12, 3), sharey=False) 68 | fig2, axs2 = plt.subplots(1, size_out**2, figsize=(12, 3), sharey=False) 69 | if plot2d: 70 | figb, axsb = plt.subplots(1, size_out**2, figsize=(12, 3), sharey=False) 71 | fig2b, axs2b = plt.subplots(1, size_out**2, figsize=(12, 3), sharey=False) 72 | for r in range(0,size_out): # loop over all the output rows 73 | for c in range(0,size_out): # loop over all the output columns 74 | #print(ch_out,size_out,ch_in,size_in,ks) 75 | #print("r,c",r,c) 76 | 
inp_r_start = stride*r 77 | inp_r_end = stride*r+ks 78 | inp_c_start = stride*c 79 | inp_c_end = stride*c+ks 80 | this_out_diff = out_diff[:,:,r,c] 81 | this_inp = inp[:,:,inp_r_start:inp_r_end,inp_c_start:inp_c_end] 82 | partial = ev(this_out_diff, this_inp, bs, ch_in, ch_out, ks).reshape_as(delta_w) # gives the right answer 83 | delta_w += partial 84 | if plot: 85 | axs[cnt].imshow(partial.detach().numpy()[0,0]) 86 | axs2[cnt].imshow(delta_w.detach().numpy()[0,0]) 87 | if plot2d: 88 | axsb[cnt].imshow(partial.detach().numpy()[1,0]) 89 | axs2b[cnt].imshow(delta_w.detach().numpy()[1,0]) 90 | cnt += 1 91 | if sqrt == False: 92 | delta_w *= 1./cnt # dw = dw/n 93 | else: 94 | delta_w *= 1./np.sqrt(cnt) 95 | return delta_w 96 | 97 | 98 | def ev(this_out_diff, this_inp, bs, chin, chout, ks): 99 | prod_mul = torch.mul(this_out_diff.reshape(bs,chout,1,1,1), this_inp.reshape(bs,1,chin,ks,ks)) 100 | prod_mul = torch.mean(prod_mul,axis=0) # average across batchsize 101 | return prod_mul 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Feb 19 14:02:06 2020 4 | 5 | @author: GiorgiaDellaferrera 6 | """ 7 | 8 | import numpy as np 9 | import math 10 | import matplotlib.pyplot as plt 11 | from sklearn.model_selection import train_test_split 12 | from sklearn.utils import shuffle 13 | 14 | 15 | # activation functions 16 | def sigm(x): 17 | return 1/(1+np.exp(-x)) 18 | def d_sigm(x): 19 | return sigm(x) * (1-sigm(x)) 20 | def relu(x): 21 | return np.maximum(x,0) 22 | def step_f(x,bias=0): 23 | return np.heaviside(x,bias) 24 | def Lrelu(x,leakage=0.1): 25 | output = np.copy(x) 26 | output[output<0] *= leakage 27 | return output 28 | def d_Lrelu(x,leakage=0.1): 29 | return np.clip(x>0,leakage,1.0) 30 | def d_step_f(x): 31 | return 1-np.square(np.tanh(x)) 32 | def tanh(x): 33 | return np.tanh(x) 34 | def d_tanh(x): 35 | return 1-np.square(tanh(x)) 36 | def tanh_ciresan(x): 37 | A = 1.7159 38 | B = 0.6666 39 | return np.tanh(B*x)*A 40 | def d_tanh_ciresan(x): 41 | A = 1.7159 42 | B = 0.6666 43 | return A*B*(1-np.square(tanh(B*x))) 44 | def softmax(x): 45 | shiftx = x - np.max(x) 46 | exps = np.exp(shiftx) 47 | return exps / np.sum(exps) 48 | def onehotenc(idx,size): 49 | arr = np.zeros((size,1)) 50 | arr[idx] = 1 51 | return arr 52 | # to compute alignment of matrix 53 | def unit_vector(vector): 54 | """ Returns the unit vector of the vector. 
""" 55 | return vector / np.linalg.norm(vector) 56 | 57 | def angle_between(v1, v2): 58 | """ Returns the angle in radians between vectors 'v1' and 'v2':: 59 | """ 60 | v1 = 2.*(v1 - np.min(v1))/np.ptp(v1)-1 61 | v2 = 2.*(v2 - np.min(v2))/np.ptp(v2)-1 62 | v1_u = unit_vector(v1) 63 | v2_u = unit_vector(v2) 64 | return math.degrees(np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))) 65 | 66 | # prepare dataset 67 | def dataset_simple(n_input,n_output,n_samples,seed=None,VERBOSE=False,plots=False): 68 | x_list = [] 69 | target_list = [] 70 | for s in range(n_samples): 71 | # generate the sample and target 72 | if s == 0: 73 | x = np.zeros((n_input,1)) 74 | #np.random.seed(633) 75 | idx_l = np.random.choice(n_input,2,replace=False) 76 | for i_l in idx_l: 77 | x[i_l] = 1 78 | if VERBOSE: 79 | print('s',s) 80 | if s>0: 81 | flag = 1 82 | while flag>0: 83 | if VERBOSE: 84 | print('drawing ') 85 | x = np.zeros((n_input,1)) 86 | idx_l = np.random.choice(n_input,2,replace=False) 87 | for i_l in idx_l: 88 | x[i_l] = 1 89 | if VERBOSE: 90 | print('check') 91 | flag=0 92 | if VERBOSE: 93 | print('fl',flag) 94 | for i in range(len(x_list)): 95 | if VERBOSE: 96 | print(i) 97 | if np.array_equal(x,x_list[i]): 98 | flag+=1 99 | if VERBOSE: 100 | print('update: flag = ',flag) 101 | if VERBOSE: 102 | print('final ',flag) 103 | #plt.figure() 104 | #plt.imshow(x.reshape((int(np.sqrt(n_input)),int(np.sqrt(n_input))))) 105 | #plt.title('attempt') 106 | if plots: 107 | plt.figure() 108 | plt.imshow(x.reshape((int(np.sqrt(n_input)),int(np.sqrt(n_input))))) 109 | 110 | target = np.zeros((n_output,1)) 111 | #idx = np.random.choice(np.arange(n_output)) 112 | idx = s 113 | target[idx] = 1 114 | # add it to the dataset 115 | x_list.append(x) 116 | target_list.append(target) 117 | return x_list, target_list 118 | 119 | 120 | def dataset_mnist(n_samples,seed=None,plots=False): 121 | print('Loading mnist') 122 | x_max = 255 123 | x_min = 0 124 | # import mnist 125 | import keras 126 | from keras.datasets import mnist 127 | (x_train, y_train), (x_test, y_test) = mnist.load_data() 128 | 129 | if n_samples is not 'all': 130 | if seed is not None: 131 | np.random.seed(seed) 132 | i = np.random.choice(len(y_train)-n_samples-1) 133 | #print(i) 134 | #i=0 # to be removed 135 | x_l = x_train[i:n_samples+i,:,:] 136 | idx_list = y_train[i:n_samples+i] 137 | print("using digits:") 138 | print(idx_list) 139 | else: 140 | print("using the full mnist dataset") 141 | x_l = x_train 142 | idx_list = y_train 143 | x_l_test = x_test 144 | idx_list_test = y_test 145 | # one input neuron encodes one pixel 146 | n_input = np.size(x_l[0]) 147 | n_output = 10 148 | # flattening for samples and one-hot encoding for targets 149 | target_list = [] 150 | x_list = [] 151 | target_list_test = [] 152 | x_list_test = [] 153 | # train 154 | for idx,t in enumerate(idx_list): 155 | x = x_l[idx].reshape((np.size(x_l[idx]),1)) 156 | x = (x-x_min)/(x_max-x_min) 157 | x_list.append(x) 158 | target = np.zeros((n_output,1)) 159 | target[t] = 1 160 | target_list.append(target) 161 | if plots: 162 | plt.figure() 163 | plt.imshow(x.reshape((int(np.sqrt(n_input)),int(np.sqrt(n_input))))) 164 | # test 165 | if n_samples is not 'all': 166 | x_list_test = x_list 167 | target_list_test = target_list 168 | else: 169 | for idx,t in enumerate(idx_list_test): 170 | x = x_l_test[idx].reshape((np.size(x_l_test[idx]),1)) 171 | x = (x-x_min)/(x_max-x_min) 172 | x_list_test.append(x) 173 | target = np.zeros((n_output,1)) 174 | target[t] = 1 175 | target_list_test.append(target) 
176 | 177 | return x_list, target_list, x_list_test, target_list_test 178 | 179 | def dataset_emnist(n_samples,seed=None,plots=False): 180 | print('Loading emnist') 181 | x_max = 255 182 | x_min = 0 183 | # import emnist balanced 184 | from extra_keras_datasets import emnist 185 | (x_train, y_train), (x_test, y_test) = emnist.load_data(type='balanced') 186 | 187 | if n_samples is not 'all': 188 | if seed is not None: 189 | np.random.seed(seed) 190 | i = np.random.choice(len(y_train)-n_samples-1) 191 | #print(i) 192 | x_l = x_train[i:n_samples+i,:,:] 193 | idx_list = y_train[i:n_samples+i] 194 | print("using letters:") 195 | print(idx_list) 196 | else: 197 | print("using the full emnist dataset") 198 | x_l = x_train 199 | idx_list = y_train 200 | x_l_test = x_test 201 | idx_list_test = y_test 202 | # one input neuron encodes one pixel 203 | n_input = np.size(x_l[0]) 204 | n_output = 47 205 | # flattening for samples and one-hot encoding for targets 206 | target_list = [] 207 | x_list = [] 208 | target_list_test = [] 209 | x_list_test = [] 210 | for idx,t in enumerate(idx_list): 211 | x = x_l[idx].reshape((np.size(x_l[idx]),1)) 212 | x = (x-x_min)/(x_max-x_min) 213 | x_list.append(x) 214 | target = np.zeros((n_output,1)) 215 | target[t] = 1 216 | target_list.append(target) 217 | if plots: 218 | plt.figure() 219 | plt.imshow(x.reshape((int(np.sqrt(n_input)),int(np.sqrt(n_input))))) 220 | if n_samples is not 'all': 221 | x_list_test = x_list 222 | target_list_test = target_list 223 | else: 224 | for idx,t in enumerate(idx_list_test): 225 | x = x_l_test[idx].reshape((np.size(x_l_test[idx]),1)) 226 | x = (x-x_min)/(x_max-x_min) 227 | x_list_test.append(x) 228 | target = np.zeros((n_output,1)) 229 | target[t] = 1 230 | target_list_test.append(target) 231 | 232 | return x_list, target_list, x_list_test, target_list_test 233 | 234 | def dataset_fmnist(n_samples,seed=None,plots=False): 235 | print('Loading fmnist') 236 | x_max = 255 237 | x_min = 0 238 | # import mnist 239 | from keras.datasets import fashion_mnist 240 | (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data() 241 | 242 | if n_samples is not 'all': 243 | if seed is not None: 244 | np.random.seed(seed) 245 | i = np.random.choice(len(y_train)-n_samples-1) 246 | #print(i) 247 | x_l = x_train[i:n_samples+i,:,:] 248 | idx_list = y_train[i:n_samples+i] 249 | print("using clothes:") 250 | print(idx_list) 251 | else: 252 | print("using the full fmnist dataset") 253 | x_l = x_train 254 | idx_list = y_train 255 | x_l_test = x_test 256 | idx_list_test = y_test 257 | # one input neuron encodes one pixel 258 | n_input = np.size(x_l[0]) 259 | n_output = 10 260 | # flattening for samples and one-hot encoding for targets 261 | target_list = [] 262 | x_list = [] 263 | target_list_test = [] 264 | x_list_test = [] 265 | for idx,t in enumerate(idx_list): 266 | x = x_l[idx].reshape((np.size(x_l[idx]),1)) 267 | x = (x-x_min)/(x_max-x_min) 268 | x_list.append(x) 269 | target = np.zeros((n_output,1)) 270 | target[t] = 1 271 | target_list.append(target) 272 | if plots: 273 | plt.figure() 274 | plt.imshow(x.reshape((int(np.sqrt(n_input)),int(np.sqrt(n_input))))) 275 | if n_samples is not 'all': 276 | x_list_test = x_list 277 | target_list_test = target_list 278 | else: 279 | for idx,t in enumerate(idx_list_test): 280 | x = x_l_test[idx].reshape((np.size(x_l_test[idx]),1)) 281 | x = (x-x_min)/(x_max-x_min) 282 | x_list_test.append(x) 283 | target = np.zeros((n_output,1)) 284 | target[t] = 1 285 | target_list_test.append(target) 286 | 287 | return 
x_list, target_list, x_list_test, target_list_test 288 | 289 | def dataset_cifar(n_samples,seed=None,plots=False): 290 | print('Loading cifar10') 291 | x_max = 255 292 | x_min = 0 293 | # import cifar10 294 | import keras 295 | from keras.datasets import cifar10 296 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 297 | 298 | class_labels = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck'] 299 | 300 | if n_samples is not 'all': 301 | if seed is not None: 302 | np.random.seed(seed) 303 | i = np.random.choice(len(y_train)-n_samples-1) 304 | #print(i) 305 | x_l = x_train[i:n_samples+i,:,:] 306 | idx_list = y_train[i:n_samples+i] 307 | print("using classes:") 308 | print(idx_list) 309 | print("corresponding to labels:") 310 | for id_ in idx_list: 311 | print(class_labels[id_[0]]) 312 | else: 313 | print("using the full cifar10 dataset") 314 | x_l = x_train 315 | idx_list = y_train 316 | x_l_test = x_test 317 | idx_list_test = y_test 318 | # one input neuron encodes one pixel 319 | n_input = np.size(x_l[0]) 320 | n_output = 10 321 | # flattening for samples and one-hot encoding for targets 322 | target_list = [] 323 | x_list = [] 324 | target_list_test = [] 325 | x_list_test = [] 326 | for idx,t in enumerate(idx_list): 327 | x = x_l[idx].reshape((np.size(x_l[idx]),1)) 328 | x = (x-x_min)/(x_max-x_min) 329 | x_list.append(x) 330 | target = np.zeros((n_output,1)) 331 | target[t] = 1 332 | target_list.append(target) 333 | if plots: 334 | plt.figure(figsize=(2,2)) 335 | plt.imshow(x.reshape((int(np.sqrt(n_input/3)),int(np.sqrt(n_input/3)),3))) 336 | if n_samples is not 'all': 337 | x_list_test = x_list 338 | target_list_test = target_list 339 | else: 340 | for idx,t in enumerate(idx_list_test): 341 | x = x_l_test[idx].reshape((np.size(x_l_test[idx]),1)) 342 | x = (x-x_min)/(x_max-x_min) 343 | x_list_test.append(x) 344 | target = np.zeros((n_output,1)) 345 | target[t] = 1 346 | target_list_test.append(target) 347 | 348 | return x_list, target_list, x_list_test, target_list_test 349 | 350 | def dataset_cifar100(n_samples,seed=None,plots=False): 351 | print('Loading cifar100') 352 | x_max = 255 353 | x_min = 0 354 | # import cifar10 355 | import keras 356 | from keras.datasets import cifar100 357 | (x_train, y_train), (x_test, y_test) = cifar100.load_data() 358 | 359 | #class_labels = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck'] 360 | 361 | if n_samples is not 'all': 362 | if seed is not None: 363 | np.random.seed(seed) 364 | i = np.random.choice(len(y_train)-n_samples-1) 365 | #print(i) 366 | x_l = x_train[i:n_samples+i,:,:] 367 | idx_list = y_train[i:n_samples+i] 368 | print("using classes:") 369 | print(idx_list) 370 | print("corresponding to labels:") 371 | #for id_ in idx_list: 372 | #print(class_labels[id_[0]]) 373 | else: 374 | print("using the full cifar100 dataset") 375 | x_l = x_train 376 | idx_list = y_train 377 | x_l_test = x_test 378 | idx_list_test = y_test 379 | # one input neuron encodes one pixel 380 | n_input = np.size(x_l[0]) 381 | n_output = 100 382 | # flattening for samples and one-hot encoding for targets 383 | target_list = [] 384 | x_list = [] 385 | target_list_test = [] 386 | x_list_test = [] 387 | for idx,t in enumerate(idx_list): 388 | x = x_l[idx].reshape((np.size(x_l[idx]),1)) 389 | x = (x-x_min)/(x_max-x_min) 390 | x_list.append(x) 391 | target = np.zeros((n_output,1)) 392 | target[t] = 1 393 | target_list.append(target) 394 | if plots: 395 | plt.figure(figsize=(2,2)) 396 | 
plt.imshow(x.reshape((int(np.sqrt(n_input/3)),int(np.sqrt(n_input/3)),3))) 397 | if n_samples is not 'all': 398 | x_list_test = x_list 399 | target_list_test = target_list 400 | else: 401 | for idx,t in enumerate(idx_list_test): 402 | x = x_l_test[idx].reshape((np.size(x_l_test[idx]),1)) 403 | x = (x-x_min)/(x_max-x_min) 404 | x_list_test.append(x) 405 | target = np.zeros((n_output,1)) 406 | target[t] = 1 407 | target_list_test.append(target) 408 | 409 | return x_list, target_list, x_list_test, target_list_test 410 | 411 | 412 | def dataset_debug(n_samples,seed=None,plots=False): 413 | print('Loading dataset for debugging') 414 | x_list =[np.array([[1.,0.,0.,0.]]).T] 415 | target_list = [np.array([[0,0.5]]).T] 416 | x_list_test = x_list 417 | target_list_test = target_list 418 | 419 | return x_list, target_list, x_list_test, target_list_test 420 | 421 | def distortion(x_train, y_train): 422 | #from keras.preprocessing.image import ImageDataGenerator 423 | from image import ImageDataGenerator 424 | datagen = ImageDataGenerator( 425 | featurewise_center=False, 426 | featurewise_std_normalization=False, 427 | rotation_range=10, 428 | width_shift_range=0.15, 429 | height_shift_range=0.15, 430 | shear_range=0.15, 431 | zoom_range=0.15, 432 | elastic_RGB=[34,6], 433 | fill_mode='constant', 434 | vertical_flip=False, 435 | horizontal_flip=False) 436 | #print('before',np.shape(x_train)) 437 | #print('before mean',np.mean(x_train)) 438 | x_train = np.reshape(x_train,(np.shape(x_train) + (1,))) 439 | x_train = np.reshape(x_train,(np.shape(x_train)[0],int(np.sqrt(np.shape(x_train)[1])),int(np.sqrt(np.shape(x_train)[1])),1)) 440 | #print('input shape',np.shape(x_train)) 441 | datagen.fit(x_train) 442 | batches = 0 443 | for x_deformed, y_deformed in datagen.flow(x_train, y_train, batch_size=60000): 444 | batches += 1 445 | if batches >= 1: 446 | break 447 | #x_plot = np.reshape(x_deformed,(np.shape(x_deformed)[0],np.shape(x_deformed)[1],np.shape(x_deformed)[2])) 448 | #plt.figure() 449 | #plt.imshow(x_plot[0,:,:]) 450 | #plt.colorbar() 451 | x_deformed = np.reshape(x_deformed,(np.shape(x_deformed)[0],np.shape(x_deformed)[1]*np.shape(x_deformed)[2],1)) 452 | #print('after',np.shape(x_deformed)) 453 | #print('after mean',np.mean(x_deformed)) 454 | x_max = np.max(x_deformed) 455 | x_min = np.min(x_deformed) 456 | x_deformed = (x_deformed-x_min)/(x_max-x_min) 457 | #print('after normaliz',np.mean(x_deformed)) 458 | return x_deformed, y_deformed 459 | 460 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Feb 19 14:02:22 2020 4 | 5 | @author: anonymous_ICML 6 | """ 7 | 8 | import os 9 | os.environ["MKL_NUM_THREADS"] = "1" 10 | os.environ["NUMEXPR_NUM_THREADS"] = "1" 11 | os.environ["OMP_NUM_THREADS"] = "1" 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | import time 15 | from functions import * 16 | from utils import * 17 | 18 | import argparse 19 | from sklearn.model_selection import train_test_split 20 | 21 | plt.close('all') 22 | start_time = time.time() 23 | 24 | 25 | # ask for the arguments 26 | parser = argparse.ArgumentParser() 27 | 28 | parser.add_argument('-en', '--exp_name', 29 | type=str, default='expPEPITA', 30 | help="Experiment name") 31 | parser.add_argument('-lt', '--learn_type', 32 | type=str, default='ERIN', 33 | help="Learning rule: BP, ERIN, ERINsign, FA, DFA") 34 | parser.add_argument('-nf', 
'--nested_folder', 35 | type=str, default='ignore', 36 | help="nested folder where to save the saving folder") 37 | parser.add_argument('-r', '--n_runs', 38 | type=int, default=1, 39 | help="Number of simulations for each model") 40 | parser.add_argument('-trep', '--train_epochs', 41 | type=int, default= 3, 42 | help="Number of training epochs") 43 | parser.add_argument('-sp', '--sample_passes', 44 | type=int, default=2, 45 | help="Number of consecutive passes for each sample") 46 | parser.add_argument('-ns', '--n_samples', 47 | default='all', 48 | help="Size of training set. Choose between an integer or 'all' ") 49 | parser.add_argument('-eta', '--eta', 50 | type=float, default=0.01, 51 | help="Learning rate") 52 | parser.add_argument('-do', '--dropout', 53 | type=float, default=0.9, 54 | help="Dropout") 55 | parser.add_argument('-no_shuffling','--no_shuffling', 56 | default=False, action='store_true', 57 | help="Choose not to do any shuffling") 58 | parser.add_argument('-zm', '--zeromean', 59 | action='store_true', 60 | help="Rescale the dataset to the interval [-1,1]") 61 | parser.add_argument('-check_cos_norm', '--check_cos_norm', 62 | action='store_true', 63 | help="Compute antialignment angle and matrix norm during training") 64 | parser.add_argument('-eta_d', '--eta_decay', 65 | action='store_true', 66 | help="If True, eta is decreased by a factor 0.1 every 60 epochs") 67 | parser.add_argument('-def', '--deformation', 68 | action='store_true', 69 | help="If True, deformations are applied to the images at each epoch") 70 | parser.add_argument('-notav', '--no_test_as_val', 71 | action='store_true', 72 | help="If True, test set is used as validation set") 73 | parser.add_argument('-ct', '--continue_training', 74 | action='store_true', 75 | help="If True, training is continued from some saved weights") 76 | parser.add_argument('-ct_path', '--continue_training_path', 77 | default='res_exp_BP_v7_SGD_1_cir_un_do100_mnist_rep1_tr20_pass1',type=str, 78 | help="Path containing weights to continue training from") 79 | parser.add_argument('-seed', '--seed', 80 | default=None, 81 | help="Random seed. Set to None or to integer") 82 | parser.add_argument('-mn', '--mnist', action='store_true', 83 | help="use mnist as dataset") 84 | parser.add_argument('-emn', '--emnist', action='store_true', 85 | help="use emnist (balanced) as dataset") 86 | parser.add_argument('-fmn', '--fmnist', action='store_true', 87 | help="use fashion mnist as dataset") 88 | parser.add_argument('-cif', '--cifar10', action='store_true', 89 | help="use cifar10 as dataset") 90 | parser.add_argument('-cif100', '--cifar100', action='store_true', 91 | help="use cifar100 as dataset") 92 | parser.add_argument('-datadebug', '--datadebug', action='store_true', 93 | help="use debug dataset as dataset") 94 | parser.add_argument('-V', '--VERBOSE', action='store_true', 95 | help="print some extra variables and results") 96 | parser.add_argument('-pl', '--plots', action='store_true', 97 | help="plot extra figures") 98 | parser.add_argument('-val', '--validation', action='store_true', 99 | help="perform validation") 100 | parser.add_argument('-ut', '--update_type', 101 | type=str, default='mom', 102 | help="Update type: SGD, mom(entum), NAG, rmsprop, Adam ...") 103 | parser.add_argument('-bs', '--batch_size', 104 | default=64,type=int, 105 | help="Batch size during training. 
Choose an integer") 106 | parser.add_argument('-kv', '--keep_variants', 107 | type=str, default='un', 108 | help="Keep variants: ul (until learning), ue (until end) or un (until normalization)") 109 | parser.add_argument('-win', '--w_init', 110 | type=str, default='he_uniform', #'he_uniform', 111 | help="Weight initialization type. Options: rnd, zero, ones, xav, he, he_uniform, nok, cir") 112 | parser.add_argument('-build', '--build', 113 | type=str, default='auto', 114 | help="Building mode: auto or custom") 115 | parser.add_argument('-arch', '--architecture', 116 | default=[784,2000,1500,1000,500,10], 117 | help="Network layers size") 118 | parser.add_argument('-act', '--act_list', 119 | default=[tanh_ciresan,tanh_ciresan,tanh_ciresan,tanh_ciresan,softmax], 120 | help="Network layers activations") 121 | parser.add_argument('-struct', '--struct', 122 | type=str, default='uniform', 123 | help="Network structure: pyramidal or uniform") 124 | parser.add_argument('-ss', '--start_size', 125 | type=int, default=1024, 126 | help="Size of 1st hidden layer") 127 | parser.add_argument('-nh', '--n_hlayers', 128 | type=int, default=2, 129 | help="Number of hidden layers") 130 | parser.add_argument('-act_h', '--act_hidden', 131 | default='relu',type=str, 132 | help="Activation of hidden layers") 133 | parser.add_argument('-act_o', '--act_out', 134 | default='softmax',type=str, 135 | help="Activation of output layer") 136 | args = parser.parse_args() 137 | 138 | #mnist = True 139 | 140 | # save the arguments 141 | # simulation set-up 142 | exp_name = args.exp_name 143 | nested_folder = args.nested_folder 144 | n_runs = args.n_runs 145 | train_epochs = args.train_epochs 146 | sample_passes = args.sample_passes 147 | n_samples = args.n_samples 148 | eta = args.eta 149 | dropout = args.dropout 150 | no_shuffling = args.no_shuffling 151 | zeromean = args.zeromean 152 | check_cos_norm = args.check_cos_norm 153 | if no_shuffling == False: 154 | shuffling = True 155 | print("shuffling") 156 | else: 157 | shuffling = False 158 | print("no shuffling") 159 | dropout_perc = int(dropout*100) 160 | eta_decay = args.eta_decay 161 | eta_decay = True # to be removed 162 | deformation = args.deformation 163 | no_test_as_val = args.no_test_as_val 164 | if no_test_as_val: 165 | test_as_val = False 166 | else: 167 | test_as_val = True 168 | continue_training = args.continue_training 169 | continue_training_path = args.continue_training_path 170 | seed = args.seed 171 | mnist = args.mnist 172 | emnist = args.emnist 173 | fmnist = args.fmnist 174 | cifar10 = args.cifar10 175 | #cifar10 = True # to be removed 176 | cifar100 = args.cifar100 177 | #cifar100 = True # to be removed 178 | datadebug = args.datadebug 179 | # check that one dataset has been chosen 180 | if mnist == emnist == True or mnist == fmnist == True or emnist == fmnist == True or mnist == cifar10 == True or fmnist == cifar10 == True or emnist == cifar10 == True: 181 | print('Warning, two datasets have been chosen') 182 | print('Setting mnist as dataset') 183 | mnist = True 184 | fmnist = False 185 | emnist = False 186 | cifar10 = False 187 | cifar100 = False 188 | if mnist == emnist == fmnist == cifar10 == cifar100 == datadebug == False: 189 | print('Warning, one dataset should be chosen') 190 | print('Setting mnist as dataset') 191 | mnist = True 192 | if mnist or emnist or fmnist or cifar10 or cifar100 or datadebug: 193 | simple_data = False # if True use a simple dataset 194 | else: 195 | simple_data = True 196 | w_init = args.w_init 197 | VERBOSE = 
args.VERBOSE 198 | plots = args.plots 199 | validation = args.validation 200 | validation = True # to be removed 201 | # network set-up 202 | learn_type = args.learn_type 203 | update_type = args.update_type 204 | batch_size = args.batch_size 205 | keep_variants = args.keep_variants 206 | build = args.build 207 | if build == 'auto': 208 | struct = args.struct # pyramidal or uniform 209 | start_size = args.start_size # e.g. 256 210 | n_hlayers = args.n_hlayers # e.g. 2 211 | act_hidden = args.act_hidden 212 | act_hidden_str = args.act_hidden 213 | act_out = args.act_out 214 | if act_hidden == 'sigm': 215 | act_hidden = sigm 216 | elif act_hidden == 'relu': 217 | act_hidden = relu 218 | elif act_hidden == 'Lrelu': 219 | act_hidden = Lrelu 220 | elif act_hidden == 'tanh': 221 | act_hidden = tanh 222 | elif act_hidden == 'tanh_ciresan': 223 | act_hidden = tanh_ciresan 224 | elif act_hidden == 'step_f': 225 | act_hidden = step_f 226 | elif act_hidden == 'softmax': 227 | act_hidden = softmax 228 | if act_out == 'sigm': 229 | act_out = sigm 230 | elif act_out == 'relu': 231 | act_out = relu 232 | elif act_out == 'Lrelu': 233 | act_out = Lrelu 234 | elif act_out == 'tanh': 235 | act_out = tanh 236 | elif act_out == 'tanh_ciresan': 237 | act_out = tanh_ciresan 238 | elif act_out == 'step_f': 239 | act_out = step_f 240 | elif act_out == 'softmax': 241 | act_out = softmax 242 | if mnist or emnist or fmnist: 243 | layers_size = [784] 244 | elif cifar10 or cifar100: 245 | layers_size = [3072] 246 | elif datadebug: 247 | layers_size = [4] 248 | act_list = [] 249 | size_next = start_size 250 | for h in range(n_hlayers): 251 | layers_size.append(size_next) 252 | act_list.append(act_hidden) 253 | if struct == 'pyramidal': 254 | size_next = int(size_next/2) 255 | elif struct == 'uniform': 256 | pass 257 | if mnist or fmnist or cifar10 or simple_data: 258 | layers_size.append(10) 259 | elif emnist: 260 | layers_size.append(47) 261 | elif cifar100: 262 | layers_size.append(100) 263 | elif datadebug: 264 | layers_size.append(2) 265 | act_list.append(act_out) 266 | 267 | elif build == 'custom': 268 | layers_size = args.architecture 269 | act_list = args.act_list 270 | 271 | print(act_list) 272 | print(layers_size) 273 | print('Learning rate:',eta) 274 | #check size and create list of derivatives of activations 275 | try: 276 | a = len(layers_size)-1 277 | b = len(act_list) 278 | assert a == b 279 | except AssertionError: 280 | print ("Assertion Exception Raised.") 281 | else: 282 | print ("layer size and number of activations correctly set up!") 283 | d_act_list = [] 284 | for idx,a in enumerate(act_list): 285 | if a == sigm: 286 | d_act_list.append(d_sigm) 287 | elif a == relu: 288 | d_act_list.append(step_f) 289 | elif a == Lrelu: 290 | d_act_list.append(d_Lrelu) 291 | elif a == tanh: 292 | d_act_list.append(d_tanh) 293 | elif a == tanh_ciresan: 294 | d_act_list.append(d_tanh_ciresan) 295 | elif a == step_f: 296 | d_act_list.append(d_step_f) 297 | elif a == softmax: 298 | d_act_list.append(None) 299 | 300 | # create folder to save all results 301 | if mnist: 302 | arch_name = 'mnist' 303 | elif emnist: 304 | arch_name = 'emnist' 305 | elif fmnist: 306 | arch_name = 'fmnist' 307 | elif cifar10: 308 | arch_name = 'cifar10' 309 | elif cifar100: 310 | arch_name = 'cifar100' 311 | elif datadebug: 312 | arch_name = 'datadebug' 313 | elif simple_data: 314 | arch_name = 'simple' 315 | if eta_decay == False: 316 | savepath = 
"res_"+exp_name+"_"+learn_type+"_"+update_type+"_"+str(batch_size)+"_"+act_hidden_str+"_"+w_init+"_"+keep_variants+"_"+"do"+str(dropout_perc)+"_"+arch_name+"_rep"+str(n_runs)+"_tr"+str(train_epochs)+"_pass"+str(sample_passes) 317 | else: 318 | savepath = "res_"+exp_name+"_"+learn_type+"_"+update_type+"_"+str(batch_size)+"_"+act_hidden_str+"_"+w_init+"_"+keep_variants+"_"+"do"+str(dropout_perc)+"_etad_"+arch_name+"_rep"+str(n_runs)+"_tr"+str(train_epochs)+"_pass"+str(sample_passes) 319 | if nested_folder != "ignore": 320 | savepath = nested_folder + '/' + savepath 321 | if deformation: 322 | savepath = savepath + "_def" 323 | if test_as_val: 324 | savepath = savepath + "_tav" 325 | if continue_training: 326 | savepath = savepath + "_ct" 327 | try: 328 | os.mkdir(savepath) 329 | except OSError: 330 | print ("Creation of the directory %s failed" % savepath) 331 | else: 332 | print ("Successfully created the directory %s " % savepath) 333 | # prepare a file to write the results on 334 | filename = savepath+'/res_summary_'+exp_name+'.txt' 335 | file = open(filename,'w') 336 | file.write('Results for simulation with the following hyperparameters ') 337 | file.write('\n Number of repetitions = ') 338 | file.write(str(n_runs)) 339 | file.write('\n Training epochs = ') 340 | file.write(str(train_epochs)) 341 | file.write('\n Sample passes = ') 342 | file.write(str(sample_passes)) 343 | file.write('\n Learning rate = ') 344 | file.write(str(eta)) 345 | if deformation: 346 | file.write('\n Applying deformation') 347 | if test_as_val: 348 | file.write('\n Test set use for validation') 349 | if continue_training: 350 | file.write('\n Training continued from saved weights in folder '+continue_training_path) 351 | file.write('\n Dropout = ') 352 | file.write(str(dropout)) 353 | file.write('\n Shuffling = ') 354 | file.write(str(shuffling)) 355 | file.write('\n Eta decay = ') 356 | file.write(str(eta_decay)) 357 | file.write('\n Seed = ') 358 | file.write(str(seed)) 359 | file.write('\n Dataset type = ') 360 | file.write(arch_name) 361 | file.write('\n Learn type = ') 362 | file.write(learn_type) 363 | file.write('\n Batch size = ') 364 | file.write(str(batch_size)) 365 | file.write('\n Update type = ') 366 | file.write(update_type) 367 | file.write('\n Keep variants = ') 368 | file.write(keep_variants) 369 | file.write('\n Network architecture = ') 370 | file.write(str(layers_size)) 371 | file.write('\n Activation functions = ') 372 | file.write(str(act_list)) 373 | 374 | # create variables to store results 375 | train_acc_all = np.zeros((n_runs,train_epochs)) 376 | val_acc_all = np.zeros((n_runs,train_epochs)) 377 | test_acc_all = [] 378 | 379 | # loop over the number of simulations 380 | for r in range(n_runs): 381 | print('####### RUN {} #######'.format(r)) 382 | t0 = time.time() 383 | net = general_network(layers_size,act_list,d_act_list,learn_type,batch_size,update_type,keep_variants,w_init,sample_passes,VERBOSE) 384 | 385 | if continue_training: 386 | for i in range(len(layers_size)-1): 387 | weights = np.loadtxt(continue_training_path+'/weights_layer'+str(i)+'.txt') 388 | net.layers[i].w = weights 389 | loadw = False 390 | if loadw: 391 | for i in range(len(layers_size)-1): 392 | weights = np.loadtxt('hewin'+str(i)+'.txt') 393 | net.layers[i].w = weights 394 | print(weights[0][0]) 395 | weights = np.loadtxt('hewinF.txt') 396 | net.layers[-1].F = weights 397 | print(weights[0][0]) 398 | 399 | 400 | if mnist: 401 | x_list, target_list, x_list_test, target_list_test = 
dataset_mnist(n_samples,seed,plots=False) 402 | elif emnist: 403 | x_list, target_list, x_list_test, target_list_test = dataset_emnist(n_samples,seed,plots=False) 404 | elif fmnist: 405 | x_list, target_list, x_list_test, target_list_test = dataset_fmnist(n_samples,seed,plots=False) 406 | elif cifar10: 407 | x_list, target_list, x_list_test, target_list_test = dataset_cifar(n_samples,seed,plots=False) 408 | elif cifar100: 409 | x_list, target_list, x_list_test, target_list_test = dataset_cifar100(n_samples,seed,plots=False) 410 | elif datadebug: 411 | x_list, target_list, x_list_test, target_list_test = dataset_debug(n_samples,seed,plots=False) 412 | elif simple_data: 413 | x_list,target_list = dataset_simple(layers_size[0],layers_size[-1],n_samples,seed,plots) 414 | # normalize to the interval [-1,1] if zeromean is True 415 | if zeromean: 416 | #print("before: min = {} , max = {}".format(np.min(x_list),np.max(x_list))) 417 | print("normalizing to interval [-1,1]") 418 | for i in range(len(x_list)): 419 | x_list[i] = x_list[i]*2 - 1 420 | for i in range(len(x_list_test)): 421 | x_list_test[i] = x_list_test[i]*2 - 1 422 | #print("after: min = {} , max = {}".format(np.min(x_list),np.max(x_list))) 423 | 424 | # train the model 425 | E_curve, train_acc, val_acc = net.train(x_list,target_list,x_list_test,target_list_test,train_epochs,sample_passes,eta,dropout,shuffling,eta_decay,deformation,test_as_val,zeromean,plots,validation,savepath,r,check_cos_norm) 426 | t1 = time.time() 427 | print('Running time for train: {}'.format(np.round(t1-t0,2))) 428 | # test the model 429 | t0 = time.time() 430 | test_acc = net.test(x_list_test,target_list_test,plots) 431 | test_acc = np.array([test_acc]) 432 | print('Final accuracy = {}'.format(test_acc)) 433 | t1 = time.time() 434 | print('Running time for test: {}'.format(np.round(t1-t0,2))) 435 | # save the results for this network 436 | np.savetxt(savepath+'/train_acc_run'+str(r)+'.txt',train_acc) 437 | np.savetxt(savepath+'/test_acc_run'+str(r)+'.txt',test_acc) 438 | 439 | train_acc_all[r,:] = train_acc 440 | test_acc_all.append(test_acc) 441 | if validation: 442 | val_acc_all[r,:] = val_acc 443 | 444 | if plots: 445 | plt.figure() 446 | plt.plot(E_curve,label='error') 447 | plt.title('Error curves for network'+str(learn_type)) 448 | plt.legend() 449 | 450 | # save the results of the runs until the current one 451 | np.savetxt(savepath+'/train_acc_tot_rep'+str(r)+'.txt',train_acc_all[0:r+1,:]) 452 | np.savetxt(savepath+'/test_acc_tot_rep'+str(r)+'.txt',test_acc_all) 453 | if validation: 454 | np.savetxt(savepath+'/val_acc_tot_rep'+str(r)+'.txt',val_acc_all[0:r+1,:]) 455 | 456 | # remove the single accuracy files 457 | #for i in range(train_epochs): 458 | # os.remove(savepath+'/'+'train_acc_epoch'+str(i)+'.txt') 459 | # if validation: 460 | # os.remove(savepath+'/'+'val_acc_epoch'+str(i)+'.txt') 461 | 462 | # save the final train and test curves 463 | np.savetxt(savepath+'/train_acc_tot.txt',train_acc_all) 464 | np.savetxt(savepath+'/test_acc_tot.txt',test_acc_all) 465 | train_acc_mean = np.mean(train_acc_all,axis=0) 466 | train_acc_std = np.std(train_acc_all,axis=0) 467 | test_acc_mean = np.mean(test_acc_all) 468 | test_acc_std = np.std(test_acc_all) 469 | if validation: 470 | np.savetxt(savepath+'/val_acc_tot.txt',val_acc_all) 471 | val_acc_mean = np.mean(val_acc_all,axis=0) 472 | val_acc_std = np.std(val_acc_all,axis=0) 473 | 474 | file.write('\n Final train accuracy: ') 475 | file.write('mean = ') 476 | file.write(str(np.round(train_acc_mean[-1],4))) 477 
| file.write(' std = ') 478 | file.write(str(np.round(train_acc_std[-1],4))) 479 | if validation: 480 | file.write('\n Final validation accuracy: ') 481 | file.write('mean = ') 482 | file.write(str(np.round(val_acc_mean[-1],4))) 483 | file.write(' std = ') 484 | file.write(str(np.round(val_acc_std[-1],4))) 485 | file.write('\n Final test accuracy: ') 486 | file.write('mean = ') 487 | file.write(str(np.round(test_acc_mean,4))) 488 | file.write(' std = ') 489 | file.write(str(np.round(test_acc_std,4))) 490 | 491 | # print final wrap up 492 | print("Mean train accuracy = {} std = {}".format(train_acc_mean,train_acc_std)) 493 | print("Mean test accuracy = {} std = {}".format(test_acc_mean,test_acc_std)) 494 | 495 | 496 | 497 | end_time = time.time() 498 | total_time = round(end_time - start_time,2) 499 | file.write('\n Total computational time : ') 500 | file.write(str(total_time)) 501 | file.write(' seconds ') 502 | 503 | file.close() 504 | 505 | 506 | -------------------------------------------------------------------------------- /plot_compute_slowness.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "pharmaceutical-working", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import matplotlib.pyplot as plt\n", 12 | "from scipy.optimize import curve_fit" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 7, 18 | "id": "ecf09661", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "# function to extract the slowness parameter that provide a quantitative evaluation of the convergence rate\n", 23 | "def extract_slowness(test_data, fit_epochs = 100, ymin=92, ymax=100, ylim=True, in100=False, nclasses=10):\n", 24 | " ### input parameters explained\n", 25 | " # test_data = test curve obtained during training of the model\n", 26 | " # fit_epochs = number of epochs on which we perform the fit --> advice is to use the epochs before the learning rate decay\n", 27 | " # ymin and ymax = limits for the plot\n", 28 | " # in100 = if True the test curve is in the range [0,100], if False it is in [0,1]\n", 29 | " # nclasses = number of classes e.g. 
10 for MNIST and CIFAR10, 100 for CIFAR100\n", 30 | " \n", 31 | " # prepare the data\n", 32 | " if type(test_data) is not list:\n", 33 | " test_data = list(test_data)\n", 34 | " # add the chance level before the test accuracy --> needed for the fit\n", 35 | " if in100:\n", 36 | " test_data = [10]+test_data \n", 37 | " else:\n", 38 | " test_data = [0.1]+test_data\n", 39 | " if nclasses == 100:\n", 40 | " test_data[0] = test_data[0]/10 # the chance level is 0.01 or 1 if there are 100 classes (not 0.1 or 10)\n", 41 | " \n", 42 | " data_fit = test_data[0:fit_epochs] \n", 43 | " data = data_fit[0:fit_epochs]\n", 44 | " acc_max = np.max(data_fit)\n", 45 | " # here perform the fit to extract the slowness data\n", 46 | " [param,res] = curve_fit(lambda X,a: (acc_max * X)/(a + X), np.arange(0,np.shape(data)[0]), data_fit, p0=[0.5])\n", 47 | " slowness = param[0]\n", 48 | " \n", 49 | " # plot the fitted curve and the actual test curve\n", 50 | " x = np.arange(0,fit_epochs)\n", 51 | " y = (acc_max * x)/(slowness + x)\n", 52 | " if in100 == False:\n", 53 | " data_fit = [i * 100 for i in data_fit]\n", 54 | " y = y*100\n", 55 | " plt.figure()\n", 56 | " plt.plot(np.arange(1,len(data)+1),data_fit,label='test curve',ls='--',alpha = 1.0)\n", 57 | " plt.plot(np.arange(1,fit_epochs+1),y,label='fit: s={}'.format(np.round(slowness,3)),ls='-',alpha = 1.0)\n", 58 | " plt.xlabel('Training epochs', fontsize=14)\n", 59 | " plt.ylabel('Accuracy [%]', fontsize=14)\n", 60 | " plt.xticks(fontsize=14)\n", 61 | " plt.yticks(fontsize=14)\n", 62 | " if ylim is not None:\n", 63 | " plt.ylim([ymin,ymax])\n", 64 | "\n", 65 | " plt.legend()\n", 66 | " print('accmax=',acc_max,'slowness=',slowness)" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "id": "3147ab7c", 73 | "metadata": {}, 74 | "outputs": [], 75 | "source": [ 76 | "# load your test curve\n", 77 | "curve = np.load ..." 
78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 8, 83 | "id": "b1289864", 84 | "metadata": {}, 85 | "outputs": [ 86 | { 87 | "name": "stdout", 88 | "output_type": "stream", 89 | "text": [ 90 | "accmax= 0.98643 slowness= 0.026016557401307188\n" 91 | ] 92 | }, 93 | { 94 | "data": { 95 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAEUCAYAAADA7PqTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAA6e0lEQVR4nO3deXxU5dXA8d+ZmewJe9CwKQgoyiaCiqKoiNQN64pWLfZti1qt1lYtWmt932rV2lprrXWhClaL2E1rXYpYsVo3Fi0CsiMSCFsgkD2ZmfP+8VziMJkkE0hmksz5fj7zYebeO/eeJ8A9ee6ziapijDHGRPIlOwBjjDFtjyUHY4wx9VhyMMYYU48lB2OMMfVYcjDGGFOPJQdjjDH1WHIwxhhTT0KTg4jkichDIrJBRCpF5D0RGROx/yARmSkim0WkQkReF5FBiYzRGGNM4msOM4BJwFRgGDAXmCcivUVEgBeBQcBXgaOBDd7+nATHaYwxKU0SNUJaRLKAUuBCVX0pYvsi4DXgGWAlMFJV/+vt8wFbgNtVdUZCAjXGGJPQmkMA8ANVUdsrgXFAhve5br+qhoFqb78xxpgECSTqQqpaKiLvA3eIyFJcjeAyYCywBliBe4z0MxH5NlAG3AT0AQpinVNEpgHTAHJyco454ogjWr0cxhjTkSxatGiHquZHb0/YYyUAETkMeAo4GQgBi4FVwChVPVJEjgF+D4zw9s8DwgCqelZj5x49erQuXLiwFaM3xpiOR0QWqero6O0JbZBW1bWqOh7IBfqq6rFAGrDe279IVUcCXYACVf0K0H3vfmOMMYmRlHEOqlquqkUi0hXXe+mlqP27VXW71411dPR+Y4wxrSthbQ4AIjIJl5BWAAOBB3A9lJ729l8M7MC1PQwDfg28qKpzExmnMcakuoQmB6AzcC+ukXkn8BfgR6pa6+0vAB4EDgKKcN1bf5rgGI0xJuUlNDmo6gvAC43sfxh4OHERGWOMicXmVjLGGFNPoh8rGWMSLBxWNuysoLw6SHrAx8GdM+mUmUYwFKa8JoTfJ/gEfCL4RAj4BJ9P2FFWzSdflFBcXk1pVZCsdD856QFOGNidnnmZ7K6sZXtpNXmZAXIzAmSn+3Gz4DihsFJZG6K6NkRNKEwwpNSEwvTrlk2a30d5dRC/T8hM87dYWUNhxe9zMazeWsrGXRUUl9Wwo6yGHWXVZAR83PoVNx7q0flr2LSrki7ZaXTJSqdrTjr9umVzbP9ucV9PVQmFlWBYyQj49il/vEqratm6p5rqYIiDOmXSIzeDPVW1LNqwix45GfTIS6d7TgbpgcT+Lm/JwZg4lVbVsqu8lqx0P/l5GfX2qyrVwTDBsJKb4f5rbdxZQVVtiJAqwZAbU5Sd7mdAfi4Aq7aWsruyluKyaorLayguq6Fft2y+enRvAH7y0lIqakKIgCqU1wQ59tBuXHVif8JhZcKDb5Ofm0Hfbtn07ZZFv27ZDO/ThYH5OXyxo5Qf/vkTVhSVUFVTiw/Fh/LTyUP46ogClm0s5lszP6rb7iOMiHL/+UM58bBurFm7jfv+tmTf/SiHnX8UPXt14r+rtvLQGysRb79flOx04X/POZK+XTP417Iinv1gA0K47hyCcu/5R9EjJ433lhXx4seb6Jzlp0dOGj1y0shN9zN5RAEBHyxcX8yKLbsJhULUBkPUBENoOMx1pwwAVV79dDOfFu5CUPfD0TDpfuF7pw8CDbN0cSGrt5Yi3nW7+KFHXjpk9AUN039pEem7KqgJBqlVZTuKLy+dY4/uBaq89MkmyqprEUDUlb2gcybjB/UAlL99XEh5dS2iuGuIcki3bE4Y0A1Q/r1yG34f5KT7yE7zk53uo3NmgLzMAFW1IT4t3MWeyiAVNe4aoGQX5NGjezZSVUv1mh1sRilCESDND0MOzuOgvAz2VNawbPNuBKVX50z6XfQz6D2qRf+9J3QQXGuyQXDtUzislFTWsrO8hpKKGnZX1rK7spZxg3rQMy+TpZt285fFhYS9387CqtSGlBsnDKJvt2z+uWwLj761hupgmE6ZaeTnptMz18+1J/WlZ7aPbbtK2VNRSc9sH3lpioRrIVQDoSCEa/n4820s+6KY0spKAoQJECQ7oEwZVQChWl5bspHVRbupqKoiHKolQJju2T6+eUI/CAd59b8bKd5TARqCcBC/hsjPDTDx8O4QDvH2iiLKqqoJEMZHmAAhumUHGNErF8IhlmwspjYYxF+3P0znTB+9O6WDhti8qxwNh/ERwk8YvyhZfshOE9AQVTW1oGHQMD798hgTDwHxoSLuPYKK4BMfiFAbUsJQl9YQ8ImPjDQ/IFQGw6gKeN8PAwG/j8y0AIiP4vIagmEIKoTDoEBuZhrdcjIII2wuqSIt4Cc94CPgd7WOjICftICfMFBdqwRV3TnCEAwrnbPTyU4PUBVUtpVWA0JeVjpdL/gl9D12/34KDQyCs+RgDlhtKMxnRXv478YSdpbXUlET5KwjuzPioDTWb97GY28uwx+swh+qJBCqIhCq4OIR+QzuHmBl4Xb+9MFqMqkhQ2rJoJZMapg4uDMFOcKWnSWsKtxBugTJoJYMqSWdIH07+cmQINXVVdRUV5Gmtfg1SBq1TQfcUnwBgvgJ4SMsAVT8qM+PzxcgKyMdfAEqghDCh4offH5U/PgDAXIzM8Hnp6QqhIqfQCBAWloaaWlp+H3uWGTvnz7vvS9im7e9br+PID5Kq8MEAn7ysjLrtte9Io5135cY2324u+DebRJxbORx3g3V52v4u97Nt+46dX9G7ovYjnej3edzQ9/bewwxvh/xvt7xEddIoNKqWop2V5EZ8NOve3ZCr90USw4mPsFqqCyBqhKoLKGmbCcLVqwnM1RGZricjJB7dQtUk0MlleV7WL2xiGytJFuqyKGKLKpJl1CzL63iJxzIBH8GvrRMJM29J+C9/OkQyPzyvT8dAulfvvene/vS3HtfGsXVULi7ll3VsKtSKa5SJJDOpccNICcrk+qwn7T0DHyBAPjS3Hd9ae4GV/c+AP6A+3Pvq+4GaUz71lBysDaHji4cgrJtULYFyrZD2VYo3+beV+yAimKoKCZcvhMqivEFK/b5ejpwYsTnWvVTShbVGXnkdO1BZnou2Z3zycrrQucuXcnK6YQvPQfSsyEt4s+0LO+V/eX7QGbE+yzEH6Dlmiad7t6rIfVbDowxYMmh/autgpIvYNfnX772FMKeze5VusU
9D48SSsvFn5sP2d1ZVJzB55WHslOHUaK5lPlyGdSvD1ecMgKyurArnE21P5cqfzbVmk5NSMlM89HtoDwEN9TdGNOxWHJoLypLYPsK99rm/bljFezZtO9xgSzo0hfyCqD/eOjUCzoVuM85PVm6J4Pb525l9c4Q79x4Kj1yM1j54ReUVtVycOdMjumWzVG9OpER+PJ3+K6JLakxpg2w5NAW1VbBlk9h00IoXAibFsGuiIlp07Kp7TaIJXIUvUdN4eBDh0DXQ90rJz/ms/DdlbU88M8VPPfhFxyUl8lDlw6jR657qPK14/olplzGmHbDkkNboApbl8LqN2DNm7DxQwh7vW7yekGfY2DUldDzKOh5BJvJ52szPuLz4gqyd/l5fMgxnNS33loddfZU1TLxwbfZUVbNN07oz/fPGFzXD98YY2KxO0SyhEMuESx/CdbMcw3GwNbswXyYdR5nn3Ue/r6j3WOhKCtXbGNPVZDHrzyGX72xiv+ZuYC/fedEhvbuvM9xVbUhMtP8dMpM45vj+nPiwB71jjHGmFisK2ui7S6Ej5+FxX9wDceZneGw01jd6Xhu/SSfj3dl8q1x/bnjnCNRVUoqaumakw58ebMHKKsOkpsRYHdlLc99uIFrTj4Mn+/Lx0mvLy3izpeW8fiVx3B0P2s1MMbE1iZWghORPBF5SEQ2iEiliLwnImMi9ueKyG9EpNDbv1JEbkpkjK1m7Vvw3CXw0DCYfx/kD4ZLnmH3dZ9xMzcx8a2+7PR144/fOo47zjkSgLnLt3Lyz99i5n/Ws2LLHk79xXxeX+pqGHsfC3XOSuM7pwzE5xM27qzg0flruPH5j7nm2cX07JRBjj0+Msbsh0TfOWYAw4GpQCFwBTBPRI5U1U24tRxOB67ELQ16MvCkiOxQ1T8kONaWsetzeP12WPkK5B4M477v2g+6HgpATWk181du49pTDuPGCYP2mYTs8IPyGNmvC3e9vBwR6JGbwWH5OQ1e6vkFX/Dbt9YS8Ak3nT6Y75x6GGl+m3jXGNN8CXusJCJZQClwoaq+FLF9EfCaqt4hIkuBv6jqTyL2vw18qqrXN3b+NvdYqbYS3n0I/vOQm3Zg/C1w/HfcCF6gaHclB3fKRETqHhHFoqq88qmboOz2s4bUTdgWSzis/HlxIUN7debIXp1aoVDGmI6mLYyQDgB+oCpqeyUwznv/LnCuiMxQ1Y0icgIwErecaPux6p/w6s1ucNrQC2HiT6Fz77rdm0oq+epv/8MFR/fmtrOGNNpzSEQ4Z3gvzhlev2E6ms8nXDK6b4sUwRiT2hKWHFS1VETeB/bWELYAlwFjgTXeYTcAjwFfiEjQ2/ZdVf1HouI8YItmwss3Qv4QmPoy9D95n91l1UG+OXMBVTUhLjymT3JiNMaYJiS6zeFK4Clce0MIWAzMBvZORP5d3FQ+k4ENuDaHX4jI56r6evTJRGQaMA2gX782MJBrwe/hle/DwIkw5VlIy9xndyis3DD7Y1ZvK+Ppq8Yw+KC8JAVqjDGNS0pXVhHJATqpapGIzAFygYuA3cDFUW0SM4BDVfX0xs6Z9DaHj550j5IGTYIpf6hrW4j0038s5/fvruee84dy+XGHJCFIY4zZV1toc6ijquVAuYh0BSYBtwJp3it6lrgQbX2t6w8eg9d/CIefBRfPjJkYAE49vCc56X5LDMaYNi+hyUFEJuFu9Ctwk3k+AKwEnlbVWq9n0n0iUoZ7rDQe+DouebRN7/8W/nk7HHEOXPS0W18gys7yGrrlpDNuUA/GDeqRhCCNMaZ5Ev0beWfgEVxyeAbXO+kMVd27fNelwALgOWA5MB34sfedtqdwIfzzdt4OjGXN+N/ETAyfFe1h/M/f4i+LCpMQoDHG7J+E1hxU9QXghUb2bwG+kbiI9t+fFm7kKx//L1mZ3bij9joqf7+Y2d8+jkERjczb9lTxzZkLyMkIcOJAqzEYY9qPtv0sv416b+0O/vS3P5G36d8ETrqJp68+FRG49IkPWLmlFICKmiDfemYhJZW1zJg6moM7ZzZxVmOMaTssOTTTxp0VXPfcYm7P/CvhnJ4w5lsM7JnL89OOJ+AXLnvyAz7fUc5Ncz7h0027efjSo20mVGNMu2OzsjVDeXWQbz+zkFHhTxnJp3DSfW6NZOCw/FyenzaW37+7jt5dsxg7oDvH9e/O6UcelOSojTGm+Sw5NMOs9z9n1dY9zOnzClQVwDH7No/075HD3V8dBsBVJ/ZPRojGGNMi7LFSM1x98mG8enaQztsXwUk/qDcC2hhjOgpLDs3gFzjis4ehc18Y9fVkh2OMMa3GHivF6YF/riB/81tctWkRnPvrBkdBG2NMR2A1hzi9u2o7p2yeAV0OgZGXJzscY4xpVVZziEM4rHTatoBD/Wvg5EfAn5bskIwxplVZzSEOhbsqGRRe5z4MnpTcYIwxJgEsOcRh1dZSBstGajO7QU5+ssMxxphWZ8khDn6/cHRmEdLzSBBJdjjGGNPqrM0hDqcO6gFSCAefkuxQjDEmIazmEAct2QC15dBzSLJDMcaYhEhochCRPBF5SEQ2iEiliLwnImMi9msDr98mMs5IwVCYG38z23046KhkhWGMMQmV6JrDDNyyoFOBYcBcYJ6I9Pb2F0S9zvW2N7gGRGv7vLiCPjWfuw/5RyQrDGOMSaiEJQcRyQIuBKar6nxVXaOqdwFrgGvBLfYT+QLOA1ap6tuJijPaqq2lHO7bSE1ub8jslKwwjDEmoRJZcwgAfqAqanslMC76YBHJwy0b+mTrh9awVVtLOVw24j/oyGSGYYwxCZWw5KCqpcD7wB0i0ltE/CJyBTAW9wgp2mVABjCroXOKyDQRWSgiC7dv394qca/ZsovDfEX4D7b2BmNM6kh0m8OVQBgoBKqBG4DZQCjGsd8GXlTVBu/6qvqEqo5W1dH5+a0zOO2Mg8pIIwg9reZgjEkdCU0OqrpWVccDuUBfVT0WSAPWRx4nIiOB0ST5kRLA5IIS98a6sRpjUkhSxjmoarmqFolIV1zvpZeiDpkGfA7MS3Rskcqrg1RtWgrihx6DkxmKMcYkVEJHSIvIJFxCWgEMBB4AVgJPRxyTDVwO/FxVNZHxRZv32VYy3/03p3TvT4at+maMSSGJrjl0Bh7BJYdngHeBM1S1NuKYKUAOEQkjWVZuKeUI30bSCqwx2hiTWhJac1DVF2hiQJuqPk0bSAwAnxdtp69sw2fdWI0xKcbmVmpE7ZYV+FDrqWSMSTmWHBpQUROkS9kq98GSgzEmxVhyaMQ3B1cR9mdAt/7JDsUYYxLK1nNoQHZ6gCN8hZB/OPj8yQ7HGGMSymoODVi2eTfBLctsmm5jTEqy5NCA376ygED5VhsZbYxJSZYcGqDblrs31hhtjElBlhxi2FNVS/eKde6DJQdjTAqy5BDD6q2lHCFfUJuWB516JTscY4xJuAZ7K4nIw/txvrtUdecBxNMmrNlWxmBfIaEeR5AmkuxwjDEm4Rrryno9bnGemjjPNQ54CGj3yeHkQT3IzyxCCsYmOxRjjEmKpsY5nK+q2+I5kYiUtk
A8bUKBrwRq9sDBQ5MdijHGJEVjbQ7fAHY341xXA1sPLJy2YfOqRQDUdD88yZEYY0xyNJgcVHWWqlbHeyJV/aOqljd2jIjkichDIrJBRCpF5D0RGRN1zGAR+auIlIhIhYgsFpGEDjb4fPUyAKrybNoMY0xqavb0GSJyFHAK4AfeVdXFzfj6DGA4MBW3jvQVwDwROVJVN4lIf+A/uLUeTgNKgCOAsubGeSBEw+6NPz2RlzXGmDajWV1ZReRq4C1gPO7mPV9Ebo3zu1nAhcB0VZ2vqmtU9S5gDXCtd9g9wFxV/YGqLlbVdar6qqpubE6cB85LDtZRyRiTohpNDiKSH7XpBmC4ql6iql8FzgZujvNaAVxtoypqeyUwTkR8wLnAchF5XUS2i8gCEZkS5/lbzt7FSa0bqzEmRTVVc/hIRK6K+FwBRD7/PxLYE8+FVLUU1zX2DhHpLSJ+EbkCGAsUAD2BXOB2YC4wEZgNPCci58Q6p4hME5GFIrJw+/bt8YTRLGLJwRiToppKDuOA80TkTRE5DFdz+IOIbBWRYuD/gO8043pX4p7ZFALV3vlmA6GIWF5S1QdV9RNVfRC3rOh1sU6mqk+o6mhVHZ2fH13J2X/DencCIDvNZjQ3xqSmRu9+qroJOF9ELsT9Nv8EMBg4DHczX6mq0Y+JGjvfWmC8iOQAnVS1SETmAOuBHUAQWB71tc+AS+O9RkvIzXA/Fr/Pag7GmNQUV4O0qv4FOBrY25soU1X/25zEEHW+ci8xdAUm4WoLNcACIHpwwWBgw/5cZ39t2V0JQE1ImzjSGGM6piafm4jIWbh2hv+q6jUiMg54SkTeBH7U1NiGqHNNwiWkFcBA4AFgJfC0d8jPgRdE5B3gX8CpuFrDV+MuUQtYt6OMg4HakGKdWY0xqaip3kq/xN24xwCPi8iPVfVdYBRu9PTHXvKIV2fgEVxyeAZ4FzhDVWsBVPVFYBquB9SnwHeBr6vqK80p1AHzKgzWHm2MSVVN1RymApNUdZGIdAM+AH7q3cx/4rUXPA68Gs/FVPUFXANzY8fMBGbGc77Wszc7JDcKY4xJlqbaHCpw7QwAfYkao6Cqy1X1pNYIrC0QW+7CGJOimrr73QY8IyKbgbeBH7d+SG2AupqDPVYyxqSqprqyPicirwMDgNWqWpKQqJLs6H5dYB2kB/zJDsUYY5Kiyd5KqloMFCcgljYjK81VqHxWdTDGpKgGHyuJyKMikhvviUTkQRHp3jJhJdemEjfOIRi2cQ7GmNTUWJvD1UBWM871LVxX1XZvww43dCOklhyMMampscdKAqwTkXjvkDktEE8bsbdB2norGWNSU2PJ4Rv7cb4OsUxoxJzdSY3CGGOSpcHkoKqzEhlIm2IjpI0xKc6emzTC1nMwxqQqSw4xHHNIF8Cm7DbGpC5LDjFkBNyPxRqkjTGpyu5+MXyx03VlDds4B2NMioorOYjIV0XkgOeSEJE8EXlIRDaISKWIvCciYyL2zxQRjXp9cKDXba6NOyv2BpToSxtjTJsQb83hOWCTiNwvItErtTXHDNzKb1OBYbilR+eJSO+IY+YBBRGv5qwX0aKsQdoYk6riTQ4HAz8BxgPLReRdEfmGtxZ0XEQkC7gQmK6q81V1jareBawBro04tFpVt0S8dsZ7jRZjI6ONMSku3jWkS1X1cVU9Hvcb/4fAvUCRiDwpIsfHcZoA4CdqTQigEhgX8XmciGwTkVXeuXvGE2NrsAZpY0yqavbdT1WXA78CngDSgSnAOyLyoYgMb+R7pcD7wB0i0ltE/CJyBTAW9/gI4HXg68AE4AfAscC/RCQj1jlFZJqILBSRhdu3b29uURphNQdjTGqLOzmISJqIXOKt77AeOA24BjgIOARYBcxp4jRXAmGgEKgGbgBmAyEAVX1eVf+uqp+q6svAmcDhwNmxTqaqT6jqaFUdnZ+fH29RmnRc/27ujbU5GGNSVJPrOQCIyG+Ay3C/Uv8B+L5Xg9irUkR+BHze2HlUdS0w3mur6KSqRd461OsbOH6ziBQCg+KJs6UELCkYY1JcXMkBOBK4HvirqtY0cMxm4NR4Tqaq5UC5iHTF9V66NdZxItID6A0UxRlni1i/o8xbONuShDEmNcWVHFR1QhzHBHHrTDdIRCbhHmWtAAYCDwArgae9hYXuAv6CSwaH4hq9twF/iyfOlrK5pNIlB6tBGGNSVLyD4O4RkWtibL9GRH7ajOt1Bh7BJYdngHeBM1S1FtfuMAx4Cdd+MQuXOMZ6jdkJI9YgbYxJcfE+VroSuDjG9kXAbcCP4zmJqr4AvNDAvkrcI6akq0sNVnMwxqSoeHsr9QRi9RUtxvVW6lhsEJwxJsXFmxy+AE6Ksf1kXLdUY4wxHUi8j5UeB34lIunAv7xtE3ANxve3RmDJdOJh3WFTsqMwxpjkibe30i+9bqUP40ZFA9QAv1bVn7dWcMmjWDdWY0wqi7fmgKreJiJ348Y8CLBcVctaLbIkWrOtjAGILXZhjElZcScHqBu8tqCVYmkztu6ppL81ShtjUljcyUFETsVNodGPLx8tAaCqp7VwXEmn9ljJGJPC4h0EdxXwGpAHnILr1toVGAUsb/CL7ZXVGowxKS7ex+o3A9er6mVALXCbqh4NPAt0yHYHqzkYY1JZvMlhAG75TnBTbed67x8BrmrhmJLO7xMbHG2MSWnxJodi3CMlcCMAhnrvuwNZLR1Ush3fvysBn/VVMsakrngbpN8BzgA+xc2N9LCITMQNhHujlWJLLqs6GGNSWLzJ4Xog03t/LxAETsQlirtbIa6kWrVlDwPC2rx+vsYY04E0+exERALApXs/q2pYVe9X1cmqerOqlsR7MRHJE5GHRGSDiFSKyHsiMqaBY58QERWRm+M9f0vZXlZDKJzoqxpjTNvRZHLwFvF5AEhrgevNwE3LPRW3dsNcYJ6I9I48SEQuAsbgVpdLOFvPwRiT6uJtdf0AOOZALiQiWcCFwHRVna+qa1T1LmANcG3EcYcAvwa+hus2m3CqWHowxqS0eB+rPwn8QkT64Rb4KY/cqaqL47yWH6iK2l4JjIO6R1izgbtV9TNJWqOwpQZjTGqLNzn80fvzwRj7FHfTb5SqlorI+8AdIrIU2IKbjmMsrvYA8L9Asar+Lp6gRGQaMA2gX79+8XwlLhl+sd5KxpiUFm9y6N9C17sSeAq3QFAIWIyrKYwSkfG4AXUj4z2Zqj4BPAEwevToFvt1f/ShXWG79VUyxqSueNdz2NASF1PVtcB4EckBOqlqkYjMAdYDpwIFQFHE4yQ/cL+IfE9V+7REDHEGmrBLGWNMWxRXchCRCxrbr6p/bc5Fvam/y0WkK6730q3Ai8Cfow79J65m8WRzzn+glm/Zw2HBMBmJvKgxxrQh8T47ib5p77X3V+wm2xwARGQSrofUCmAgrovsSuBpVa0FtkUdXwtsUdWVccbZIkrKaghZ7cEYk8Li6sqqqr7IF249h+Nw02qc3IzrdcZN1rcCeAZ4FzjDSwxtis3KaoxJZfvV6uoNjFsgIrcDvwNGxPm9F3BTbsR7nUP3J74DZ
[base64-encoded image/png data omitted: the figure produced by the extract_slowness call in the source cell below]\n", 96 |     "text/plain": [ 97 |      "<Figure
" 98 | ] 99 | }, 100 | "metadata": { 101 | "needs_background": "light" 102 | }, 103 | "output_type": "display_data" 104 | } 105 | ], 106 | "source": [ 107 | "extract_slowness(np.mean(curve, axis=0), fit_epochs=60, ymin=90,ymax=99)" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "id": "8c15a8ff", 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "e0bc2aca", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "e06965af", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "gentle-egyptian", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [] 141 | } 142 | ], 143 | "metadata": { 144 | "kernelspec": { 145 | "display_name": "Python 3 (ipykernel)", 146 | "language": "python", 147 | "name": "python3" 148 | }, 149 | "language_info": { 150 | "codemirror_mode": { 151 | "name": "ipython", 152 | "version": 3 153 | }, 154 | "file_extension": ".py", 155 | "mimetype": "text/x-python", 156 | "name": "python", 157 | "nbconvert_exporter": "python", 158 | "pygments_lexer": "ipython3", 159 | "version": "3.7.10" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 5 164 | } 165 | -------------------------------------------------------------------------------- /main_pytorch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri Oct 15 11:21:01 2021 5 | 6 | @author: anonymous_ICML 7 | """ 8 | 9 | import os 10 | os.environ["MKL_NUM_THREADS"] = "1" 11 | os.environ["NUMEXPR_NUM_THREADS"] = "1" 12 | os.environ["OMP_NUM_THREADS"] = "1" 13 | import torch 14 | import torchvision 15 | import torchvision.transforms as transforms 16 | 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | import torch.optim as optim 20 | from torch.optim.lr_scheduler import StepLR 21 | 22 | import copy 23 | 24 | from torch.autograd import Variable 25 | 26 | import matplotlib.pyplot as plt 27 | import numpy as np 28 | 29 | import argparse 30 | 31 | from models import * 32 | from models_deepFC import * 33 | 34 | 35 | plt.close('all') 36 | 37 | 38 | 39 | # ask for the arguments 40 | parser = argparse.ArgumentParser() 41 | 42 | parser.add_argument('-en', '--exp_name', 43 | type=str, default='exptorch', 44 | help="Experiment name") 45 | parser.add_argument('-lt', '--learn_type', 46 | type=str, default='ERIN', 47 | help="Learning rule: BP, ERIN, ERINsign, FA, DFA") 48 | parser.add_argument('-r', '--n_runs', 49 | type=int, default=1, 50 | help="Number of simulations for each model") 51 | parser.add_argument('-trep', '--train_epochs', 52 | type=int, default= 100, 53 | help="Number of training epochs") 54 | parser.add_argument('-eta', '--eta', 55 | type=float, default=0.01, 56 | help="Learning rate") 57 | parser.add_argument('-do', '--dropout', 58 | type=float, default=0.9, 59 | help="Dropout") 60 | parser.add_argument('-Bstd', '--Bstd', 61 | type=float, default=0.05, 62 | help="Std of fixed matrix") 63 | parser.add_argument('-Bmz', '--B_mean_zero', 64 | action='store_true', 65 | help="Choose if B matrix needs to have mean 0") 66 | parser.add_argument('-check_cos_norm', '--check_cos_norm', 67 | action='store_true', 68 | help="Compute antialignment angle and matrix norm during 
training") 69 | parser.add_argument('-freeze_conv', '--freeze_conv', 70 | action='store_true', 71 | help="Freeze convolutional layers for PEPITA") 72 | parser.add_argument('-sqrt_conv', '--sqrt_conv', 73 | action='store_true', 74 | help="Take the sqrt(n) for the update of convolutional layers with PEPITA") 75 | parser.add_argument('-freeze_bn', '--freeze_bn', 76 | action='store_true', 77 | help="Freeze the training of the batchnorm layers") 78 | parser.add_argument('-eta_d', '--eta_decay', 79 | action='store_true', 80 | help="If True, eta is decreased by a factor 0.1 every 60 epochs") 81 | parser.add_argument('-decs', '--decay_scheme', 82 | type=int, default=1, 83 | help="Code for the learning rate decay scheme") 84 | parser.add_argument('-is_pool', '--is_pool', 85 | action='store_true', 86 | help="Choose if there is pooling in the network") 87 | parser.add_argument('-is_fc', '--is_fc', 88 | action='store_true', 89 | help="Choose if there are only fc layers in the network") 90 | parser.add_argument('-seed', '--seed', 91 | default=None, 92 | help="Random seed. Set to None or to integer") 93 | parser.add_argument('-ds', '--dataset', 94 | type=str, default='cif', 95 | help="Dataset choice. Options: mn,cif,cif100,fmn,emn") 96 | parser.add_argument('-ut', '--update_type', 97 | type=str, default='mom', 98 | help="Update type: SGD, mom(entum), NAG, rmsprop, Adam ...") 99 | parser.add_argument('-bs', '--batch_size', 100 | default=50,type=int, 101 | help="Batch size during training. Choose an integer") 102 | parser.add_argument('-win', '--w_init', 103 | type=str, default='he_uniform', #'he_uniform', 104 | help="Weight initialization type. Options: rnd, zero, ones, xav, he, he_uniform, nok, cir") 105 | parser.add_argument('-mod', '--model', 106 | type=str, default='Net1conv1fcXL_cif', #Net1conv1fcL 107 | help="Network structure. 
Options NetFC1x1024DOcust,NetClark,NetGithub,NetGithub_cif,NetGithub_BP,NetGithub_cif_BP,NetConvHuge,NetConvHuge_BP,NetCroc_cif_BP,NetCroc_BP,NetCroc_cif_BP_bn,NetClark") 108 | args = parser.parse_args() 109 | 110 | #mnist = True 111 | 112 | # save the arguments 113 | # simulation set-up 114 | exp_name = args.exp_name 115 | n_runs = args.n_runs 116 | train_epochs = args.train_epochs 117 | eta = args.eta 118 | print('Learning rate:',eta) 119 | check_cos_norm = args.check_cos_norm 120 | dropout = args.dropout 121 | Bstd = args.Bstd 122 | B_mean_zero = args.B_mean_zero 123 | B_mean_zero = True 124 | is_pool = args.is_pool 125 | is_fc = args.is_fc 126 | freeze_conv = args.freeze_conv 127 | sqrt_conv = args.sqrt_conv 128 | freeze_bn = args.freeze_bn 129 | keep_rate = dropout 130 | eta_decay = args.eta_decay 131 | eta_decay = True # to be removed 132 | decay_scheme = args.decay_scheme 133 | seed = args.seed 134 | dataset = args.dataset 135 | w_init = args.w_init 136 | # network set-up 137 | learn_type = args.learn_type # current options are BP, ERIN 138 | update_type = args.update_type # current options are SGD, mom(entum) 139 | batch_size = args.batch_size 140 | model = args.model 141 | dataset = args.dataset 142 | 143 | criterion = nn.CrossEntropyLoss() 144 | 145 | 146 | # create folder to save all results 147 | savepath = "res_"+exp_name+"_"+dataset+"_"+model+learn_type+"_"+update_type+"_"+str(batch_size)+"_"+w_init+"_"+"_rep"+str(n_runs)+"_tr"+str(train_epochs) 148 | 149 | if eta_decay == True: 150 | savepath += "etad"+str(decay_scheme) 151 | 152 | try: 153 | os.mkdir(savepath) 154 | except OSError: 155 | print ("Creation of the directory %s failed" % savepath) 156 | else: 157 | print ("Successfully created the directory %s " % savepath) 158 | # prepare a file to write the results on 159 | filename = savepath+'/res_summary_'+exp_name+'.txt' 160 | file = open(filename,'w') 161 | file.write('Results for simulation with the following hyperparameters ') 162 | file.write('\n Number of repetitions = ') 163 | file.write(str(n_runs)) 164 | file.write('\n Training epochs = ') 165 | file.write(str(train_epochs)) 166 | file.write('\n Learning rate = ') 167 | file.write(str(eta)) 168 | file.write('\n Eta decay = ') 169 | file.write(str(eta_decay)) 170 | file.write('\n F std = ') 171 | file.write(str(Bstd)) 172 | file.write('\n Seed = ') 173 | file.write(str(seed)) 174 | file.write('\n Dataset = ') 175 | file.write(dataset) 176 | file.write('\n Model = ') 177 | file.write(model) 178 | file.write('\n Learn type = ') 179 | file.write(learn_type) 180 | file.write('\n Batch size = ') 181 | file.write(str(batch_size)) 182 | file.write('\n Update type = ') 183 | file.write(update_type) 184 | file.close() 185 | 186 | 187 | # create variables to store results 188 | train_acc_all = np.zeros((n_runs,train_epochs)) 189 | val_acc_all = np.zeros((n_runs,train_epochs)) 190 | test_acc_all = [] 191 | 192 | 193 | # load dataset 194 | transform = transforms.Compose( 195 | [transforms.ToTensor()]) # this normalizes to [0,1] 196 | if dataset == 'mn': 197 | ch_input = 1 198 | nout = 10 199 | trainset = torchvision.datasets.MNIST(root='./data', train=True, 200 | download=True, transform=transform) 201 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 202 | shuffle=True, num_workers=0) 203 | testset = torchvision.datasets.MNIST(root='./data', train=False, 204 | download=True, transform=transform) 205 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 206 | shuffle=False, 
num_workers=0) 207 | 208 | elif dataset == 'cif': 209 | #transform = transforms.Compose( 210 | # [transforms.ToTensor(), 211 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 212 | 213 | ch_input = 3 214 | nout = 10 215 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 216 | download=True, transform=transform) 217 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 218 | shuffle=True, num_workers=0) 219 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 220 | download=True, transform=transform) 221 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 222 | shuffle=False, num_workers=0) 223 | 224 | elif dataset == 'cif100': 225 | #transform = transforms.Compose( 226 | # [transforms.ToTensor(), 227 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 228 | 229 | ch_input = 3 230 | nout = 100 231 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, 232 | download=True, transform=transform) 233 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, 234 | shuffle=True, num_workers=0) 235 | testset = torchvision.datasets.CIFAR100(root='./data', train=False, 236 | download=True, transform=transform) 237 | testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, 238 | shuffle=False, num_workers=0) 239 | 240 | 241 | 242 | 243 | # loop over the number of simulations 244 | for r in range(n_runs): 245 | print('####### RUN {} #######'.format(r)) 246 | if model == 'NetFC1x1024DOcust': 247 | net = NetFC1x1024DOcust(ch_input,nout) 248 | elif model == 'NetFC1x1024DOcust_cif': 249 | net = NetFC1x1024DOcust_cif(ch_input,nout) 250 | elif model == 'Net1conv1fcXL': 251 | net = Net1conv1fcXL(ch_input,nout) 252 | elif model == 'Net1conv1fcXL_cif': 253 | net = Net1conv1fcXL_cif(ch_input,nout) 254 | 255 | 256 | # set-up for BP 257 | if learn_type == 'BP': 258 | criterion = nn.CrossEntropyLoss() 259 | if update_type == 'SGD': 260 | optimizer = optim.SGD(net.parameters(), lr=eta) 261 | elif update_type == 'mom': 262 | optimizer = optim.SGD(net.parameters(), lr=eta, momentum=0.9) 263 | elif update_type == 'adam': 264 | optimizer = optim.Adam(net.parameters(), lr=eta) 265 | 266 | if eta_decay: 267 | scheduler = StepLR(optimizer, step_size=1, gamma=0.1) 268 | # set-up for ERIN 269 | elif learn_type == 'ERIN': 270 | activation = {} 271 | def get_activation(name): 272 | def hook(model, input, output): 273 | activation[name] = output.detach() 274 | return hook 275 | for name, layer in net.named_modules(): 276 | layer.register_forward_hook(get_activation(name)) 277 | 278 | # define B 279 | if dataset == 'mn': 280 | nin = 28*28*1 281 | nout = 10 282 | elif dataset == 'cif': 283 | nin = 32*32*3 284 | nout = 10 285 | elif dataset == 'cif100': 286 | nin = 32*32*3 287 | nout = 100 288 | sd = np.sqrt(6/nin) 289 | if B_mean_zero: 290 | B = (torch.rand(nin,nout)*2*sd-sd)*Bstd # mean zero 291 | else: 292 | B = (torch.rand(nin,nout)*sd)*Bstd # positive mean 293 | 294 | # save all weight shapes 295 | w_shapes = [] 296 | for l_idx,w in enumerate(net.parameters()): 297 | if len(w.shape)>1: 298 | with torch.no_grad(): 299 | w_shapes.append(w.shape) 300 | # do one forward pass to get the activation size needed for setting up the dropout masks 301 | dataiter = iter(trainloader) 302 | images, labels = dataiter.next() 303 | #if is_fc: 304 | # images = torch.flatten(images, 1) # flatten all dimensions except batch 305 | outputs = net(images,do_masks=None) 306 | layers_act = [] 307 | 
layers_key = [] 308 | flag_fc = 0 309 | for key in activation: 310 | if 'fc' in key and 'bn' not in key or 'conv' in key and 'bn' not in key: 311 | layers_act.append(F.relu(activation[key])) 312 | layers_key.append(key) 313 | if flag_fc == 0 and 'fc' in key: 314 | first_fc = len(layers_key) 315 | flag_fc = 1 316 | # set up for momentum 317 | if update_type == 'mom': 318 | gamma = 0.9 319 | v_w_all = [] 320 | for l_idx,w in enumerate(net.parameters()): 321 | if len(w.shape)>1: 322 | with torch.no_grad(): 323 | v_w_all.append(torch.zeros(w.shape)) 324 | 325 | # freeze the update of batchnorm layer if prescribed 326 | if freeze_bn: 327 | for name ,child in (net.named_children()): 328 | #if name.find('BatchNorm') != -1: 329 | if isinstance(child, nn.BatchNorm2d) or isinstance(child, nn.BatchNorm1d): 330 | for param in child.parameters(): 331 | param.requires_grad = False 332 | #print(name,'without grad') 333 | else: 334 | for param in child.parameters(): 335 | param.requires_grad = True 336 | #print(name,'with grad') 337 | 338 | # load pretrained weights for convolutional layers and freeze the conv layers 339 | load_pretrained = False 340 | if load_pretrained: 341 | first_fc = 3 342 | for l_idx,w in enumerate(net.parameters()): 343 | if len(w.shape)>1 and l_idx+1 < first_fc: # load only fc 344 | #if len(w.shape)>1: # load both conv and fc 345 | with torch.no_grad(): 346 | w_np = np.loadtxt('NetGithub_w'+str(l_idx)+'.txt') 347 | w_np = w_np.reshape(w.shape) 348 | w += -w + w_np 349 | for name ,child in (net.named_children()): 350 | #if name.find('BatchNorm') != -1: 351 | if isinstance(child, nn.Conv2d): 352 | for param in child.parameters(): 353 | param.requires_grad = False 354 | #print(name,'without grad') 355 | else: 356 | for param in child.parameters(): 357 | param.requires_grad = True 358 | 359 | # learning rate decay 360 | if eta_decay: 361 | decay_rate = 0.1 362 | if decay_scheme == 0: 363 | if dataset == 'mn': 364 | decay_epochs = [60] 365 | else: 366 | decay_epochs = [60,90] 367 | elif decay_scheme == 1: 368 | decay_epochs = [10,30,60] 369 | 370 | # train the model 371 | test_accs = [] 372 | losses = [] 373 | for epoch in range(train_epochs): # loop over the dataset multiple times 374 | 375 | # learning rate decay 376 | if eta_decay: 377 | if epoch in decay_epochs: 378 | if learn_type == 'BP': 379 | scheduler.step() 380 | elif learn_type == 'ERIN': 381 | eta = eta * decay_rate 382 | print('At epoch {} learning rate decreased to {}'.format(epoch,eta)) 383 | 384 | running_loss = 0.0 385 | for i, data in enumerate(trainloader, 0): 386 | # get the inputs; data is a list of [inputs, labels] 387 | inputs, target = data 388 | 389 | if learn_type == 'BP': 390 | # zero the parameter gradients 391 | optimizer.zero_grad() 392 | 393 | # forward + backward + optimize 394 | outputs = net(inputs,do_masks=None) 395 | loss = criterion(outputs, target) 396 | loss.backward() 397 | optimizer.step() 398 | 399 | elif learn_type == 'ERIN': 400 | target_onehot = F.one_hot(target,num_classes=nout) 401 | # create dropout mask for the two forward passes 402 | do_masks = [] 403 | for l_idx,l in enumerate(layers_act[:-1]): 404 | if model == 'NetConvHuge' and l_idx < first_fc-1: 405 | input1 = net.pool(l) 406 | else: 407 | input1 = l 408 | do_mask = Variable(torch.bernoulli(input1.data.new(input1.data.size()).fill_(keep_rate)))/keep_rate 409 | do_masks.append(do_mask) 410 | 411 | # forward pass 1 with original input --> keep track of activations 412 | outputs = net(inputs,do_masks) 413 | layers_act = [] 414 | for 
key in activation: 415 | if 'fc' in key and 'bn' not in key or 'conv' in key and 'bn' not in key: 416 | layers_act.append(F.relu(activation[key])) 417 | 418 | error = outputs - target_onehot 419 | 420 | # modify the input with the error 421 | error_input = error @ B.T 422 | error_input = error_input.reshape_as(inputs) 423 | mod_inputs = inputs + error_input 424 | 425 | # forward pass 2 with modified input 426 | mod_outputs = net(mod_inputs,do_masks) 427 | mod_layers_act = [] 428 | for key in activation: 429 | if 'fc' in key and 'bn' not in key or 'conv' in key and 'bn' not in key: 430 | mod_layers_act.append(F.relu(activation[key])) 431 | mod_error = mod_outputs - target_onehot 432 | 433 | # compute the delta_w for the batch 434 | delta_w_all = [] 435 | for l in range(len(layers_key)): 436 | if 'fc' in layers_key[l] and 'bn' not in layers_key[l]: 437 | #print('key for fc',layers_key[l],l) 438 | if l == first_fc-1 and first_fc == len(layers_act): # only fc layers: case with only one fc layer after conv layers 439 | 440 | if is_pool == False: 441 | delta_w = -error.T @ mod_layers_act[-2].flatten(1) 442 | else: 443 | delta_w = -error.T @ net.pool(mod_layers_act[-2]).flatten(1) 444 | 445 | 446 | elif l == len(layers_act)-1: # last layer 447 | #print('last fc') 448 | if len(layers_act)>1: 449 | delta_w = -mod_error.T @ mod_layers_act[-2] 450 | else: 451 | delta_w = -mod_error.T @ mod_inputs 452 | 453 | elif l == first_fc-1: # first layer to be modified 454 | #print('first fc --> apply pool, then reshape') 455 | if first_fc > 1: # convolutional model 456 | if is_pool == False: 457 | input_to_fc = mod_layers_act[l-1] 458 | else: 459 | input_to_fc = net.pool(mod_layers_act[l-1]) 460 | else: # fully connected model 461 | input_to_fc = mod_inputs 462 | 463 | delta_w = -(layers_act[l] - mod_layers_act[l]).T @ input_to_fc.view(batch_size,-1) 464 | 465 | elif l>first_fc-1 and l 1: # do not train the batchnorm layer 497 | with torch.no_grad(): 498 | #print('w',w.shape,'dw',delta_w_all[l_idx].shape) 499 | if update_type == 'SGD': 500 | w += eta * delta_w_all[l_idx]/batch_size # specify for which layer 501 | elif update_type == 'mom': 502 | v_w_all[l_idx] = gamma * v_w_all[l_idx] + eta * delta_w_all[l_idx]/batch_size 503 | w += v_w_all[l_idx] 504 | 505 | l_idx += 1 # needed to skip batchnorm 506 | 507 | 508 | 509 | # keep track of the loss 510 | loss = criterion(outputs, target) 511 | # print statistics 512 | running_loss += loss.item() 513 | 514 | curr_loss = running_loss / i 515 | print('[%d, %5d] loss: %.3f' % 516 | (epoch + 1, i + 1, running_loss / i)) 517 | running_loss = 0.0 518 | losses.append(curr_loss) 519 | 520 | 521 | print('Testing...') 522 | correct = 0 523 | total = 0 524 | # since we're not training, we don't need to calculate the gradients for our outputs 525 | with torch.no_grad(): 526 | for test_data in testloader: 527 | test_images, test_labels = test_data 528 | #test_images = torch.flatten(test_images, 1) # flatten all dimensions except batch 529 | # calculate outputs by running images through the network 530 | test_outputs = net(test_images,do_masks=None) 531 | # the class with the highest energy is what we choose as prediction 532 | _, predicted = torch.max(test_outputs.data, 1) 533 | total += test_labels.size(0) 534 | correct += (predicted == test_labels).sum().item() 535 | 536 | print('Test accuracy: {} %'.format(100 * correct / total)) 537 | test_accs.append(100 * correct / total) 538 | 539 | # save the results for this network 540 | 
np.savetxt(savepath+'/losses_run'+str(r)+'.txt',losses) 541 | np.savetxt(savepath+'/test_acc_run'+str(r)+'.txt',test_accs) 542 | 543 | print('Finished Training') 544 | 545 | 546 | 547 | 548 | 549 | 550 | 551 | 552 | 553 | -------------------------------------------------------------------------------- /Tutorial_PEPITA_FullyConnectedNets_CIFAR-10.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0787a534", 6 | "metadata": {}, 7 | "source": [ 8 | "This notebook illustrates how to train Fully Connected models with PEPITA. We train and test the model on CIFAR-10." 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "ae7a36dc", 14 | "metadata": {}, 15 | "source": [ 16 | "#### Import libraries" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "id": "a56a3e6a", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import torch\n", 27 | "import torchvision\n", 28 | "import torchvision.transforms as transforms\n", 29 | "\n", 30 | "import torch.nn as nn\n", 31 | "import torch.nn.functional as F\n", 32 | "import torch.optim as optim\n", 33 | "\n", 34 | "from torch.autograd import Variable\n", 35 | "\n", 36 | "import copy\n", 37 | "\n", 38 | "import matplotlib.pyplot as plt\n", 39 | "import numpy as np" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "id": "ab6701be", 45 | "metadata": {}, 46 | "source": [ 47 | "#### Define Network architecture" 48 | ] 49 | }, 50 | { 51 | "cell_type": "code", 52 | "execution_count": 8, 53 | "id": "9d192b26", 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "# models with Dropout\n", 58 | "class NetFC1x1024DOcust(nn.Module):\n", 59 | " def __init__(self):\n", 60 | " super().__init__()\n", 61 | " self.fc1 = nn.Linear(32*32*3,1024,bias=False)\n", 62 | " self.fc2 = nn.Linear(1024, 10,bias=False)\n", 63 | " \n", 64 | " # initialize the layers using the He uniform initialization scheme\n", 65 | " fc1_nin = 32*32*3 # Note: if dataset is MNIST --> fc1_nin = 28*28*1\n", 66 | " fc1_limit = np.sqrt(6.0 / fc1_nin)\n", 67 | " torch.nn.init.uniform_(self.fc1.weight, a=-fc1_limit, b=fc1_limit)\n", 68 | " fc2_nin = 1024\n", 69 | " fc2_limit = np.sqrt(6.0 / fc2_nin)\n", 70 | " torch.nn.init.uniform_(self.fc2.weight, a=-fc2_limit, b=fc2_limit)\n", 71 | " \n", 72 | "\n", 73 | " def forward(self, x, do_masks):\n", 74 | " x = F.relu(self.fc1(x))\n", 75 | " # apply dropout --> we use a custom dropout implementation because we need to present the same dropout mask in the two forward passes\n", 76 | " if do_masks is not None:\n", 77 | " x = x * do_masks[0] \n", 78 | " x = F.softmax(self.fc2(x))\n", 79 | " return x\n", 80 | " \n", 81 | "\n" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "id": "0ad144c5", 87 | "metadata": {}, 88 | "source": [ 89 | "#### Set hyperparameters and train+test the model" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 39, 95 | "id": "e2a085cc", 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "Files already downloaded and verified\n", 103 | "Files already downloaded and verified\n", 104 | "norm of w at layer 0 is 45.241451263427734\n", 105 | "norm of w at layer 1 is 4.473217010498047\n", 106 | "torch.Size([3072, 10])\n" 107 | ] 108 | }, 109 | { 110 | "name": "stderr", 111 | "output_type": "stream", 112 | "text": [ 113 | 
"c:\\users\\giorgiadellaferrera\\anaconda3\\envs\\env_pytorch\\lib\\site-packages\\ipykernel_launcher.py:21: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n" 114 | ] 115 | }, 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "[1, 500] loss: 2.210\n", 121 | "Testing...\n", 122 | "Test accuracy epoch 0: 35.12 %\n", 123 | "[2, 500] loss: 2.167\n", 124 | "Testing...\n", 125 | "Test accuracy epoch 1: 38.07 %\n", 126 | "[3, 500] loss: 2.150\n", 127 | "Testing...\n", 128 | "Test accuracy epoch 2: 40.9 %\n", 129 | "[4, 500] loss: 2.137\n", 130 | "Testing...\n", 131 | "Test accuracy epoch 3: 40.41 %\n", 132 | "[5, 500] loss: 2.126\n", 133 | "Testing...\n", 134 | "Test accuracy epoch 4: 41.4 %\n", 135 | "[6, 500] loss: 2.117\n", 136 | "Testing...\n", 137 | "Test accuracy epoch 5: 39.76 %\n", 138 | "[7, 500] loss: 2.109\n", 139 | "Testing...\n", 140 | "Test accuracy epoch 6: 43.73 %\n", 141 | "[8, 500] loss: 2.101\n", 142 | "Testing...\n", 143 | "Test accuracy epoch 7: 45.08 %\n", 144 | "[9, 500] loss: 2.096\n", 145 | "Testing...\n", 146 | "Test accuracy epoch 8: 45.28 %\n", 147 | "[10, 500] loss: 2.090\n", 148 | "Testing...\n", 149 | "Test accuracy epoch 9: 45.05 %\n", 150 | "[11, 500] loss: 2.085\n", 151 | "Testing...\n", 152 | "Test accuracy epoch 10: 44.66 %\n", 153 | "[12, 500] loss: 2.079\n", 154 | "Testing...\n", 155 | "Test accuracy epoch 11: 45.89 %\n", 156 | "[13, 500] loss: 2.074\n", 157 | "Testing...\n", 158 | "Test accuracy epoch 12: 47.39 %\n", 159 | "[14, 500] loss: 2.071\n", 160 | "Testing...\n", 161 | "Test accuracy epoch 13: 46.95 %\n", 162 | "[15, 500] loss: 2.065\n", 163 | "Testing...\n", 164 | "Test accuracy epoch 14: 47.41 %\n", 165 | "[16, 500] loss: 2.061\n", 166 | "Testing...\n", 167 | "Test accuracy epoch 15: 47.9 %\n", 168 | "[17, 500] loss: 2.056\n", 169 | "Testing...\n", 170 | "Test accuracy epoch 16: 48.52 %\n", 171 | "[18, 500] loss: 2.052\n", 172 | "Testing...\n", 173 | "Test accuracy epoch 17: 48.53 %\n", 174 | "[19, 500] loss: 2.048\n", 175 | "Testing...\n", 176 | "Test accuracy epoch 18: 46.39 %\n", 177 | "[20, 500] loss: 2.042\n", 178 | "Testing...\n", 179 | "Test accuracy epoch 19: 48.61 %\n", 180 | "[21, 500] loss: 2.040\n", 181 | "Testing...\n", 182 | "Test accuracy epoch 20: 47.27 %\n", 183 | "[22, 500] loss: 2.034\n", 184 | "Testing...\n", 185 | "Test accuracy epoch 21: 49.05 %\n", 186 | "[23, 500] loss: 2.034\n", 187 | "Testing...\n", 188 | "Test accuracy epoch 22: 49.2 %\n", 189 | "[24, 500] loss: 2.030\n", 190 | "Testing...\n", 191 | "Test accuracy epoch 23: 49.39 %\n", 192 | "[25, 500] loss: 2.027\n", 193 | "Testing...\n", 194 | "Test accuracy epoch 24: 48.68 %\n", 195 | "[26, 500] loss: 2.021\n", 196 | "Testing...\n", 197 | "Test accuracy epoch 25: 49.64 %\n", 198 | "[27, 500] loss: 2.019\n", 199 | "Testing...\n", 200 | "Test accuracy epoch 26: 49.09 %\n", 201 | "[28, 500] loss: 2.015\n", 202 | "Testing...\n", 203 | "Test accuracy epoch 27: 50.06 %\n", 204 | "[29, 500] loss: 2.010\n", 205 | "Testing...\n", 206 | "Test accuracy epoch 28: 49.31 %\n", 207 | "[30, 500] loss: 2.009\n", 208 | "Testing...\n", 209 | "Test accuracy epoch 29: 49.77 %\n", 210 | "[31, 500] loss: 2.005\n", 211 | "Testing...\n", 212 | "Test accuracy epoch 30: 49.89 %\n", 213 | "[32, 500] loss: 1.999\n", 214 | "Testing...\n", 215 | "Test accuracy epoch 31: 50.07 %\n", 216 | "[33, 500] loss: 1.999\n", 217 | "Testing...\n", 218 | "Test accuracy epoch 32: 49.84 %\n", 219 | "[34, 
500] loss: 1.995\n", 220 | "Testing...\n", 221 | "Test accuracy epoch 33: 49.66 %\n", 222 | "[35, 500] loss: 1.994\n", 223 | "Testing...\n", 224 | "Test accuracy epoch 34: 50.19 %\n", 225 | "[36, 500] loss: 1.990\n", 226 | "Testing...\n", 227 | "Test accuracy epoch 35: 50.32 %\n", 228 | "[37, 500] loss: 1.985\n", 229 | "Testing...\n", 230 | "Test accuracy epoch 36: 50.61 %\n", 231 | "[38, 500] loss: 1.982\n", 232 | "Testing...\n", 233 | "Test accuracy epoch 37: 49.24 %\n", 234 | "[39, 500] loss: 1.982\n", 235 | "Testing...\n", 236 | "Test accuracy epoch 38: 50.55 %\n", 237 | "[40, 500] loss: 1.979\n", 238 | "Testing...\n", 239 | "Test accuracy epoch 39: 51.68 %\n", 240 | "[41, 500] loss: 1.972\n", 241 | "Testing...\n", 242 | "Test accuracy epoch 40: 50.78 %\n", 243 | "[42, 500] loss: 1.970\n", 244 | "Testing...\n", 245 | "Test accuracy epoch 41: 50.49 %\n", 246 | "[43, 500] loss: 1.968\n", 247 | "Testing...\n", 248 | "Test accuracy epoch 42: 51.47 %\n", 249 | "[44, 500] loss: 1.966\n", 250 | "Testing...\n", 251 | "Test accuracy epoch 43: 50.79 %\n", 252 | "[45, 500] loss: 1.962\n", 253 | "Testing...\n", 254 | "Test accuracy epoch 44: 50.73 %\n", 255 | "[46, 500] loss: 1.958\n", 256 | "Testing...\n", 257 | "Test accuracy epoch 45: 51.77 %\n", 258 | "[47, 500] loss: 1.958\n", 259 | "Testing...\n", 260 | "Test accuracy epoch 46: 51.32 %\n", 261 | "[48, 500] loss: 1.952\n", 262 | "Testing...\n", 263 | "Test accuracy epoch 47: 51.41 %\n", 264 | "[49, 500] loss: 1.951\n", 265 | "Testing...\n", 266 | "Test accuracy epoch 48: 51.31 %\n", 267 | "[50, 500] loss: 1.951\n", 268 | "Testing...\n", 269 | "Test accuracy epoch 49: 51.36 %\n", 270 | "[51, 500] loss: 1.945\n", 271 | "Testing...\n", 272 | "Test accuracy epoch 50: 51.31 %\n", 273 | "[52, 500] loss: 1.943\n", 274 | "Testing...\n", 275 | "Test accuracy epoch 51: 51.79 %\n", 276 | "[53, 500] loss: 1.940\n", 277 | "Testing...\n", 278 | "Test accuracy epoch 52: 51.32 %\n", 279 | "[54, 500] loss: 1.938\n", 280 | "Testing...\n", 281 | "Test accuracy epoch 53: 51.53 %\n", 282 | "[55, 500] loss: 1.934\n", 283 | "Testing...\n", 284 | "Test accuracy epoch 54: 51.2 %\n", 285 | "[56, 500] loss: 1.933\n", 286 | "Testing...\n", 287 | "Test accuracy epoch 55: 51.77 %\n", 288 | "[57, 500] loss: 1.928\n", 289 | "Testing...\n", 290 | "Test accuracy epoch 56: 50.74 %\n", 291 | "[58, 500] loss: 1.926\n", 292 | "Testing...\n", 293 | "Test accuracy epoch 57: 51.12 %\n", 294 | "[59, 500] loss: 1.926\n", 295 | "Testing...\n", 296 | "Test accuracy epoch 58: 51.42 %\n", 297 | "[60, 500] loss: 1.921\n", 298 | "Testing...\n", 299 | "Test accuracy epoch 59: 50.92 %\n", 300 | "eta decreased to 0.001\n", 301 | "[61, 500] loss: 1.909\n", 302 | "Testing...\n", 303 | "Test accuracy epoch 60: 52.45 %\n", 304 | "[62, 500] loss: 1.900\n", 305 | "Testing...\n", 306 | "Test accuracy epoch 61: 52.58 %\n", 307 | "[63, 500] loss: 1.900\n", 308 | "Testing...\n", 309 | "Test accuracy epoch 62: 52.84 %\n", 310 | "[64, 500] loss: 1.898\n", 311 | "Testing...\n", 312 | "Test accuracy epoch 63: 52.71 %\n", 313 | "[65, 500] loss: 1.898\n", 314 | "Testing...\n", 315 | "Test accuracy epoch 64: 52.51 %\n", 316 | "[66, 500] loss: 1.899\n", 317 | "Testing...\n", 318 | "Test accuracy epoch 65: 52.56 %\n", 319 | "[67, 500] loss: 1.898\n", 320 | "Testing...\n", 321 | "Test accuracy epoch 66: 52.66 %\n", 322 | "[68, 500] loss: 1.898\n", 323 | "Testing...\n", 324 | "Test accuracy epoch 67: 52.67 %\n", 325 | "[69, 500] loss: 1.896\n", 326 | "Testing...\n", 327 | "Test accuracy epoch 68: 52.56 %\n", 328 | 
"[70, 500] loss: 1.894\n", 329 | "Testing...\n", 330 | "Test accuracy epoch 69: 52.57 %\n", 331 | "[71, 500] loss: 1.898\n", 332 | "Testing...\n", 333 | "Test accuracy epoch 70: 52.69 %\n", 334 | "[72, 500] loss: 1.896\n", 335 | "Testing...\n", 336 | "Test accuracy epoch 71: 52.51 %\n", 337 | "[73, 500] loss: 1.896\n", 338 | "Testing...\n", 339 | "Test accuracy epoch 72: 52.42 %\n", 340 | "[74, 500] loss: 1.896\n", 341 | "Testing...\n", 342 | "Test accuracy epoch 73: 52.52 %\n", 343 | "[75, 500] loss: 1.895\n", 344 | "Testing...\n", 345 | "Test accuracy epoch 74: 52.9 %\n", 346 | "[76, 500] loss: 1.893\n", 347 | "Testing...\n", 348 | "Test accuracy epoch 75: 52.28 %\n", 349 | "[77, 500] loss: 1.897\n", 350 | "Testing...\n", 351 | "Test accuracy epoch 76: 52.47 %\n", 352 | "[78, 500] loss: 1.892\n", 353 | "Testing...\n", 354 | "Test accuracy epoch 77: 52.57 %\n", 355 | "[79, 500] loss: 1.893\n", 356 | "Testing...\n", 357 | "Test accuracy epoch 78: 52.29 %\n", 358 | "[80, 500] loss: 1.894\n", 359 | "Testing...\n", 360 | "Test accuracy epoch 79: 52.93 %\n", 361 | "[81, 500] loss: 1.892\n", 362 | "Testing...\n", 363 | "Test accuracy epoch 80: 52.64 %\n", 364 | "[82, 500] loss: 1.891\n", 365 | "Testing...\n", 366 | "Test accuracy epoch 81: 52.73 %\n", 367 | "[83, 500] loss: 1.892\n", 368 | "Testing...\n", 369 | "Test accuracy epoch 82: 52.37 %\n", 370 | "[84, 500] loss: 1.893\n", 371 | "Testing...\n", 372 | "Test accuracy epoch 83: 52.49 %\n", 373 | "[85, 500] loss: 1.892\n", 374 | "Testing...\n", 375 | "Test accuracy epoch 84: 52.46 %\n", 376 | "[86, 500] loss: 1.891\n", 377 | "Testing...\n", 378 | "Test accuracy epoch 85: 52.36 %\n", 379 | "[87, 500] loss: 1.892\n", 380 | "Testing...\n", 381 | "Test accuracy epoch 86: 52.43 %\n", 382 | "[88, 500] loss: 1.893\n", 383 | "Testing...\n", 384 | "Test accuracy epoch 87: 52.32 %\n", 385 | "[89, 500] loss: 1.891\n", 386 | "Testing...\n", 387 | "Test accuracy epoch 88: 52.21 %\n", 388 | "[90, 500] loss: 1.890\n", 389 | "Testing...\n", 390 | "Test accuracy epoch 89: 52.37 %\n", 391 | "eta decreased to 0.0001\n", 392 | "[91, 500] loss: 1.888\n", 393 | "Testing...\n", 394 | "Test accuracy epoch 90: 52.45 %\n", 395 | "[92, 500] loss: 1.889\n", 396 | "Testing...\n", 397 | "Test accuracy epoch 91: 52.43 %\n", 398 | "[93, 500] loss: 1.890\n", 399 | "Testing...\n", 400 | "Test accuracy epoch 92: 52.79 %\n", 401 | "[94, 500] loss: 1.889\n", 402 | "Testing...\n", 403 | "Test accuracy epoch 93: 52.58 %\n", 404 | "[95, 500] loss: 1.887\n", 405 | "Testing...\n", 406 | "Test accuracy epoch 94: 52.62 %\n", 407 | "[96, 500] loss: 1.888\n", 408 | "Testing...\n", 409 | "Test accuracy epoch 95: 52.59 %\n", 410 | "[97, 500] loss: 1.889\n", 411 | "Testing...\n", 412 | "Test accuracy epoch 96: 52.7 %\n", 413 | "[98, 500] loss: 1.886\n", 414 | "Testing...\n", 415 | "Test accuracy epoch 97: 52.57 %\n", 416 | "[99, 500] loss: 1.890\n", 417 | "Testing...\n", 418 | "Test accuracy epoch 98: 52.56 %\n", 419 | "[100, 500] loss: 1.888\n", 420 | "Testing...\n", 421 | "Test accuracy epoch 99: 52.58 %\n", 422 | "Finished Training\n" 423 | ] 424 | } 425 | ], 426 | "source": [ 427 | "# set hyperparameters\n", 428 | "## learning rate\n", 429 | "eta = 0.01 \n", 430 | "## dropout keep rate\n", 431 | "keep_rate = 0.9\n", 432 | "## loss --> used to monitor performance, but not for parameter updates (PEPITA does not backpropagate the loss)\n", 433 | "criterion = nn.CrossEntropyLoss()\n", 434 | "## optimizer (choose 'SGD' o 'mom')\n", 435 | "optim = 'mom' # --> default in the paper\n", 436 | 
"if optim == 'SGD':\n", 437 | " gamma = 0\n", 438 | "elif optim == 'mom':\n", 439 | " gamma = 0.9\n", 440 | "## batch size\n", 441 | "batch_size = 64 # --> default in the paper\n", 442 | "\n", 443 | "# initialize the network\n", 444 | "net = NetFC1x1024DOcust()\n", 445 | "\n", 446 | "# load the dataset\n", 447 | "transform = transforms.Compose(\n", 448 | " [transforms.ToTensor()]) # this normalizes to [0,1]\n", 449 | "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n", 450 | " download=True, transform=transform)\n", 451 | "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n", 452 | " shuffle=True, num_workers=2)\n", 453 | "testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n", 454 | " download=True, transform=transform)\n", 455 | "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n", 456 | " shuffle=False, num_workers=2)\n", 457 | "\n", 458 | "\n", 459 | "# define function to register the activations --> we need this to compare the activations in the two forward passes\n", 460 | "activation = {}\n", 461 | "def get_activation(name):\n", 462 | " def hook(model, input, output):\n", 463 | " activation[name] = output.detach()\n", 464 | " return hook\n", 465 | "for name, layer in net.named_modules():\n", 466 | " layer.register_forward_hook(get_activation(name))\n", 467 | "\n", 468 | "\n", 469 | "# define B --> this is the F projection matrix in the paper (here named B because F is torch.nn.functional)\n", 470 | "nin = 32*32*3\n", 471 | "sd = np.sqrt(6/nin)\n", 472 | "B = (torch.rand(nin,10)*2*sd-sd)*0.05 # B is initialized with the He uniform initialization (like the forward weights)\n", 473 | "\n", 474 | "\n", 475 | "# check cosine similarity before training AND matrix norm\n", 476 | "angles = []\n", 477 | "w_all = []\n", 478 | "norm_w0 = []\n", 479 | "for l_idx,w in enumerate(net.parameters()):\n", 480 | " with torch.no_grad():\n", 481 | " w_all.append(copy.deepcopy(w))\n", 482 | " if l_idx == 0:\n", 483 | " norm_w0.append(torch.norm(w))\n", 484 | " print('norm of w at layer {} is {}'.format(l_idx,torch.norm(w)))\n", 485 | "w_prod = w_all[0].T\n", 486 | "for idx in range(1,len(w_all)):\n", 487 | " w_prod = torch.matmul(w_prod,w_all[idx].T)\n", 488 | " print(w_prod.size())\n", 489 | "\n", 490 | "# do one forward pass to get the activation size needed for setting up the dropout masks\n", 491 | "dataiter = iter(trainloader)\n", 492 | "images, labels = dataiter.next()\n", 493 | "images = torch.flatten(images, 1) # flatten all dimensions except batch \n", 494 | "outputs = net(images,do_masks=None)\n", 495 | "layers_act = []\n", 496 | "for key in activation:\n", 497 | " if 'fc' in key or 'conv' in key:\n", 498 | " layers_act.append(F.relu(activation[key]))\n", 499 | " \n", 500 | "# set up for momentum\n", 501 | "if optim == 'mom':\n", 502 | " gamma = 0.9\n", 503 | " v_w_all = []\n", 504 | " for l_idx,w in enumerate(net.parameters()):\n", 505 | " if len(w.shape)>1:\n", 506 | " with torch.no_grad():\n", 507 | " v_w_all.append(torch.zeros(w.shape))\n", 508 | "\n", 509 | "# Train and test the model\n", 510 | "test_accs = []\n", 511 | "for epoch in range(100): # loop over the dataset multiple times\n", 512 | " \n", 513 | " # learning rate decay\n", 514 | " if epoch in [60,90]: \n", 515 | " eta = eta*0.1\n", 516 | " print('eta decreased to ',eta)\n", 517 | " \n", 518 | " # loop over batches\n", 519 | " running_loss = 0.0\n", 520 | " for i, data in enumerate(trainloader, 0):\n", 521 | " # get the inputs; data is a list of 
[inputs, labels]\n", 522 | " inputs, target = data\n", 523 | " inputs = torch.flatten(inputs, 1) # flatten all dimensions except batch\n", 524 | " target_onehot = F.one_hot(target,num_classes=10)\n", 525 | " \n", 526 | " # create dropout mask for the two forward passes --> we need to use the same mask for the two passes\n", 527 | " do_masks = []\n", 528 | " if keep_rate < 1:\n", 529 | " for l in layers_act[:-1]:\n", 530 | " input1 = l\n", 531 | " do_mask = Variable(torch.ones(inputs.shape[0],input1.data.new(input1.data.size()).shape[1]).bernoulli_(keep_rate))/keep_rate\n", 532 | " do_masks.append(do_mask)\n", 533 | " do_masks.append(1) # for the last layer we don't use dropout --> just set a scalar 1 (needed for when we register activation layer)\n", 534 | " \n", 535 | " # forward pass 1 with original input --> keep track of activations\n", 536 | " outputs = net(inputs,do_masks)\n", 537 | " layers_act = []\n", 538 | " cnt_act = 0\n", 539 | " for key in activation:\n", 540 | " if 'fc' in key or 'conv' in key:\n", 541 | " layers_act.append(F.relu(activation[key])* do_masks[cnt_act]) # Note: we need to register the activations taking into account non-linearity and dropout mask\n", 542 | " cnt_act += 1\n", 543 | " \n", 544 | " # compute the error\n", 545 | " error = outputs - target_onehot \n", 546 | " \n", 547 | " # modify the input with the error\n", 548 | " error_input = error @ B.T\n", 549 | " mod_inputs = inputs + error_input\n", 550 | " \n", 551 | " # forward pass 2 with modified input --> keep track of modulated activations\n", 552 | " mod_outputs = net(mod_inputs,do_masks)\n", 553 | " mod_layers_act = []\n", 554 | " cnt_act = 0\n", 555 | " for key in activation:\n", 556 | " if 'fc' in key or 'conv' in key:\n", 557 | " mod_layers_act.append(F.relu(activation[key])* do_masks[cnt_act]) # Note: we need to register the activations taking into account non-linearity and dropout mask\n", 558 | " cnt_act += 1\n", 559 | " mod_error = mod_outputs - target_onehot\n", 560 | " \n", 561 | " # compute the delta_w for the batch\n", 562 | " delta_w_all = []\n", 563 | " v_w = []\n", 564 | " for l_idx,w in enumerate(net.parameters()):\n", 565 | " v_w.append(torch.zeros(w.shape))\n", 566 | " \n", 567 | " for l in range(len(layers_act)):\n", 568 | " \n", 569 | " # update for the last layer\n", 570 | " if l == len(layers_act)-1:\n", 571 | " \n", 572 | " if len(layers_act)>1:\n", 573 | " delta_w = -mod_error.T @ mod_layers_act[-2]\n", 574 | " else:\n", 575 | " delta_w = -mod_error.T @ mod_inputs\n", 576 | " \n", 577 | " # update for the first layer\n", 578 | " elif l == 0:\n", 579 | " delta_w = -(layers_act[l] - mod_layers_act[l]).T @ mod_inputs\n", 580 | " \n", 581 | " # update for the hidden layers (not first, not last)\n", 582 | " elif l>0 and l9: 420 | if e%int(train_epochs/10)==0: 421 | print('Training epoch {}/{}'.format(e,train_epochs)) 422 | else: 423 | print('Training epoch {}/{}'.format(e,train_epochs)) 424 | if eta_decay: 425 | if e in [60,90]: 426 | eta = np.max((eta*(0.1),1e-6)) 427 | print("Learning rate at epoch {} decreased to {}".format(e,eta)) 428 | acc = [] 429 | val_accuracy = [] 430 | self.new_batch = True 431 | 432 | for s in range(dataset_size): 433 | #print("********** sample {} **********".format(s)) 434 | self.s = s 435 | x = x_list[s] 436 | target = target_list[s] 437 | self.target = target 438 | for p in range(sample_passes): 439 | #print("Pass number",p) 440 | targs.append(np.argmax(target)) 441 | 442 | if self.learn_type != 'ERIN' and self.learn_type != 'ERINsign': 443 | 
y,self.error = self.forward(x,target,dropout,training=True) 444 | self.learning(self.error,eta,dropout) 445 | 446 | elif self.learn_type == 'ERIN': 447 | if p == 0: # no learning, only compute output 448 | y,self.error = self.forward(x,target,dropout,training=True, 449 | error_input=None,compute_diff=False) 450 | if p > 0: # apply the weight change only after the first pass 451 | error_input = self.layers[-1].F @ self.error 452 | x = np.copy(x_list[s]) 453 | y,self.error = self.forward(x,target,dropout,training=True, 454 | error_input=error_input,compute_diff=True) 455 | self.learning(self.error,eta,dropout) 456 | 457 | 458 | elif self.learn_type == 'ERINsign': 459 | if p == 0: # no learning, only compute output 460 | y,self.error = self.forward(x,target,dropout,training=True, 461 | error_input=None,compute_diff=False) 462 | if p > 0: # apply the weight change only after the first pass 463 | error_input = self.layers[-1].F @ np.sign(self.error) 464 | x = np.copy(x_list[s]) 465 | y,self.error = self.forward(x,target,dropout,training=True, 466 | error_input=error_input,compute_diff=True) 467 | self.learning(self.error,eta,dropout) 468 | 469 | # save the error 470 | E_curve.append(np.sum(abs(self.error))) 471 | self.pred = onehotenc(np.argmax(y),np.size(y)) 472 | pred_all.append(np.argmax(self.pred)) 473 | if np.argmax(y) == np.argmax(target): 474 | acc.append(1) 475 | else: 476 | acc.append(0) 477 | 478 | #print("target={}, pred={} at {}%".format(np.argmax(target),np.argmax(y),np.round(np.max(y),2))) 479 | 480 | # check cosine similarity during training AND matrix norm 481 | if self.learn_type == 'ERIN' and check_cos_norm: 482 | w_all = [] 483 | for l_idx,l in enumerate(self.layers): 484 | w_all.append(np.copy(l.w)) 485 | if l_idx == 0: 486 | norm_w0.append(LA.norm(l.w)) 487 | #print('norm of w at layer {} is {}'.format(l_idx,norm_w0)) 488 | w_prod = w_all[0].T 489 | for idx in range(1,len(w_all)): 490 | w_prod = w_prod @ w_all[idx].T 491 | #print(np.shape(w_prod)) 492 | w_prod = w_prod.flatten() 493 | B_flat = np.array(self.layers[-1].F).flatten() 494 | cos = 1-spatial.distance.cosine(w_prod,B_flat) 495 | arccos = np.arccos(cos)*180/np.pi 496 | #print('Angle between Ws and B',arccos) 497 | angles.append(arccos) 498 | 499 | self.acc_all.append(np.mean(acc)) 500 | if train_epochs>9: 501 | if e%int(train_epochs/10)==0: 502 | print('Training accuracy = {}'.format(self.acc_all[-1])) 503 | else: 504 | print('Training accuracy = {}'.format(self.acc_all[-1])) 505 | 506 | if e%1 == 0: 507 | np.savetxt(savepath+'/train_acc_tot.txt',np.array([self.acc_all])) 508 | if check_cos_norm: 509 | np.savetxt(savepath+'/angles.txt',np.array([angles])) 510 | np.savetxt(savepath+'/Anorm.txt',np.array([norm_w0])) 511 | 512 | # save the weights 513 | for i in range(self.n_layers-1): 514 | #np.savetxt(savepath+'/weights_layer'+str(i)+'.txt',self.layers[i].w) 515 | pass 516 | # perform validation 517 | if validation: 518 | for s in range(val_size): 519 | x = x_val[s] 520 | target = target_val[s] 521 | self.target = target 522 | y,self.error = self.forward(x,target,dropout,training=False) 523 | # save the error 524 | self.pred = onehotenc(np.argmax(y),np.size(y)) 525 | #print('target {} pred {}'.format(np.argmax(target),np.argmax(self.pred))) 526 | val_pred_all.append(np.argmax(self.pred)) 527 | if np.argmax(y) == np.argmax(target): 528 | val_accuracy.append(1) 529 | else: 530 | val_accuracy.append(0) 531 | val_targs.append(np.argmax(target)) 532 | 533 | self.val_acc_all.append(np.mean(val_accuracy)) 534 | if 
train_epochs>9: 535 | if e%int(train_epochs/10)==0: 536 | print('Validation accuracy = {}'.format(self.val_acc_all[-1])) 537 | else: 538 | print('Validation accuracy = {}'.format(self.val_acc_all[-1])) 539 | if e%1 == 0: 540 | np.savetxt(savepath+'/val_acc_tot.txt',np.array([self.val_acc_all])) 541 | 542 | if plots: 543 | plt.figure() 544 | plt.plot(val_pred_all,'*',label='Prediction') 545 | plt.plot(val_targs,'.',label='Target') 546 | plt.title(str(self.learn_type)+' - Validation') 547 | plt.legend() 548 | 549 | 550 | if plots: 551 | plt.figure() 552 | plt.plot(pred_all,'*',label='Prediction') 553 | plt.plot(targs,'.',label='Target') 554 | plt.title(str(self.learn_type)) 555 | plt.legend() 556 | 557 | 558 | return E_curve, self.acc_all, self.val_acc_all 559 | 560 | 561 | def test(self,x_list,target_list,plots,plots_test=False): 562 | dataset_size = len(x_list) 563 | pred_all = [] 564 | targs = [] 565 | accuracy = [] 566 | 567 | for s in range(dataset_size): 568 | if dataset_size>9: 569 | if s%int(dataset_size/10)==0: 570 | #print('Testing sample {}/{}'.format(s,dataset_size)) 571 | pass 572 | else: 573 | #print('Testing sample {}/{}'.format(s,dataset_size)) 574 | pass 575 | 576 | x = x_list[s] 577 | target = target_list[s] 578 | self.target = target 579 | y,self.error = self.forward(x,target,self.dropout,training=False) 580 | # save the error 581 | self.pred = onehotenc(np.argmax(y),np.size(y)) 582 | #print('target {} pred {}'.format(np.argmax(target),np.argmax(self.pred))) 583 | pred_all.append(np.argmax(self.pred)) 584 | if np.argmax(y) == np.argmax(target): 585 | accuracy.append(1) 586 | else: 587 | accuracy.append(0) 588 | targs.append(np.argmax(target)) 589 | 590 | accuracy_mean = np.mean(accuracy) 591 | 592 | if plots or plots_test: 593 | plt.figure() 594 | plt.plot(pred_all,'*',label='Prediction') 595 | plt.plot(targs,'.',label='Target') 596 | plt.title(str(self.learn_type)) 597 | plt.legend() 598 | 599 | return accuracy_mean 600 | 601 | --------------------------------------------------------------------------------
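For reference, the core PEPITA (ERIN) update shared by `main_pytorch.py`, the tutorial notebook and the numpy framework can be sketched in a few lines of NumPy. The snippet below is a minimal illustration only: the synthetic data and the names `W1`, `W2`, `F_proj` are assumptions for the sketch, not part of the repository, and dropout, momentum, learning-rate decay and dataset handling are omitted.

```python
# Minimal PEPITA (ERIN) sketch for a 1-hidden-layer fully connected net.
# Synthetic data and names (W1, W2, F_proj) are illustrative only; the real
# training scripts above add dropout, momentum, learning-rate decay, etc.
import numpy as np

rng = np.random.default_rng(0)
n_in, n_hid, n_out, batch = 32 * 32 * 3, 1024, 10, 64   # CIFAR-10-like sizes

def he_uniform(shape, fan_in):
    limit = np.sqrt(6.0 / fan_in)
    return rng.uniform(-limit, limit, size=shape)

W1 = he_uniform((n_hid, n_in), n_in)      # forward weights, hidden layer
W2 = he_uniform((n_out, n_hid), n_hid)    # forward weights, output layer
sd = np.sqrt(6.0 / n_in)
F_proj = (rng.random((n_in, n_out)) * 2 * sd - sd) * 0.05  # fixed projection (B in the code, F in the paper)

relu = lambda z: np.maximum(z, 0.0)
def softmax(z):
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def forward(x):
    h = relu(x @ W1.T)
    return h, softmax(h @ W2.T)

# synthetic stand-in for a batch of flattened images and one-hot targets
x = rng.random((batch, n_in))
targets = np.eye(n_out)[rng.integers(0, n_out, size=batch)]

eta = 0.01
for step in range(3):
    h, y = forward(x)                     # pass 1: clean input
    err = y - targets                     # output error

    x_mod = x + err @ F_proj.T            # modulate the input with the projected error
    h_mod, y_mod = forward(x_mod)         # pass 2: modulated input
    err_mod = y_mod - targets

    dW1 = (h - h_mod).T @ x_mod           # hidden layer: activation difference x modulated input
    dW2 = err_mod.T @ h_mod               # output layer: modulated error x modulated hidden activity
    W1 -= eta * dW1 / batch               # plain SGD step (the scripts default to momentum)
    W2 -= eta * dW2 / batch

    print(f"step {step}: mean |error| = {np.abs(err).mean():.4f}")
```

In the full implementations the same dropout mask must be presented in both forward passes, which is why the tutorial notebook uses a custom dropout instead of `nn.Dropout`, and the weight updates are applied with momentum by default.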