├── .gitignore
├── README.md
├── experiments
│   ├── cifar10
│   │   ├── bp-V0
│   │   │   ├── experiment_config.txt
│   │   │   ├── log.txt
│   │   │   └── results.json
│   │   ├── fa-V0
│   │   │   ├── experiment_config.txt
│   │   │   ├── log.txt
│   │   │   └── results.json
│   │   ├── kp-V0
│   │   │   ├── experiment_config.txt
│   │   │   ├── log.txt
│   │   │   └── results.json
│   │   └── wm-V0
│   │       ├── experiment_config.txt
│   │       ├── log.txt
│   │       └── results.json
│   └── mnist
│       ├── bp-V0
│       │   ├── experiment_config.txt
│       │   ├── log.txt
│       │   └── results.json
│       ├── fa-V0
│       │   ├── experiment_config.txt
│       │   ├── log.txt
│       │   └── results.json
│       ├── kp-V0
│       │   ├── experiment_config.txt
│       │   ├── log.txt
│       │   └── results.json
│       └── wm-V0
│           ├── experiment_config.txt
│           ├── log.txt
│           └── results.json
├── fcnn
│   ├── FCNN_BP.py
│   ├── FCNN_FA.py
│   ├── FCNN_KP.py
│   ├── FCNN_WM.py
│   └── __init__.py
├── figures
│   ├── cifar10
│   │   ├── delta_angles.png
│   │   ├── loss.png
│   │   ├── test_accuracies.png
│   │   └── weight_angles.png
│   └── mnist
│       ├── delta_angles.png
│       ├── loss.png
│       ├── test_accuracies.png
│       └── weight_angles.png
├── main.py
├── plot_figures.py
├── requirements.txt
├── run.sh
└── script
    ├── autolint
    ├── env
    └── up

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.cache
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
.pytest_cache/

# Environments
.env
.venv
env/
venv/
ENV/
VENV/

# Jupyter Notebook checkpoints
.ipynb_checkpoints

# IntelliJ’s project specific settings files
.idea

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep-Learning-without-Weight-Transport
##### Mohamed Akrout, Collin Wilson, Peter C. Humphreys, Timothy Lillicrap, Douglas Tweed

Current algorithms for deep learning probably cannot run in the brain because they rely on weight transport, where forward-path neurons transmit their synaptic weights to a feedback path, in a way that is likely impossible biologically. In this work, we present two new mechanisms which let the feedback path learn appropriate synaptic weights quickly and accurately even in large networks, without weight transport or complex wiring. One mechanism is a neural circuit called a weight mirror, which learns without sensory input, and so could tune feedback paths in the pauses between physical trials of a task, or even in sleep or in utero. The other mechanism is based on a 1994 algorithm of Kolen and Pollack. Their method worked by transporting weight changes, which is no more biological than transporting the weights themselves, but we have shown that a simple circuit lets forward and feedback synapses compute their changes separately, based on local information, and still evolve as in the Kolen-Pollack algorithm. Tested on the ImageNet visual-recognition task, both the weight mirror and the Kolen-Pollack circuit outperform other recent proposals for biologically feasible learning — feedback alignment and the sign-symmetry method — and nearly match backprop, the standard algorithm of deep learning, which uses weight transport.

[Preprint here](https://arxiv.org/pdf/1904.05391.pdf), feedback welcome! Contact Mohamed and Douglas: makrout@cs.toronto.edu, douglas.tweed@utoronto.ca

# Getting started

## First-time setup

If this is your first time running the code, follow these steps:

1. Run `script/up` to create a virtual environment `.venv` with the required packages
2. Activate the virtual environment by running `source .venv/bin/activate`

## Running experiments

### Arguments

| Argument | Description | Values |
| :--- | :--- | :--- |
| --dataset | Dataset name | Choose from {mnist, cifar10} |
| --algo | Learning algorithm name | Choose from {bp, fa, wm, kp} |
| --n_epochs | Number of epochs to run | 400 (default) |
| --size_hidden_layers | Sizes of the hidden layers | e.g. 500 |
| --batch_size | Batch size | 128 (default) |
| --learning_rate | Learning rate | 0.2 (default) |
| --test_frequency | Number of epochs between evaluations | 1 (default) |

### Example on MNIST

```bash
# Run backpropagation (BP)
python main.py --dataset=mnist --algo=bp --n_epochs=400 --size_hidden_layers 500 --batch_size=128 --learning_rate=0.2 --test_frequency=1

# Run feedback alignment (FA)
python main.py --dataset=mnist --algo=fa --n_epochs=400 --size_hidden_layers 500 --batch_size=128 --learning_rate=0.2 --test_frequency=1

# Run weight mirrors (WM)
python main.py --dataset=mnist --algo=wm --n_epochs=400 --size_hidden_layers 500 --batch_size=128 --learning_rate=0.05 --test_frequency=1

# Run the Kolen-Pollack (KP) algorithm
python main.py --dataset=mnist --algo=kp --n_epochs=400 --size_hidden_layers 500 --batch_size=128 --learning_rate=0.3 --test_frequency=1
```

## Alignment Results

All the figures for these algorithms can be generated by running the script `plot_figures.py`.

1. Delta angles between backprop and the algorithms: feedback alignment, weight mirrors and Kolen-Pollack.

   ![delta angles on MNIST](figures/mnist/delta_angles.png)
   ![delta angles on CIFAR10](figures/cifar10/delta_angles.png)

2. Weight angles between backprop and the algorithms: feedback alignment, weight mirrors and Kolen-Pollack.

   ![weight angles on MNIST](figures/mnist/weight_angles.png)
   ![weight angles on CIFAR10](figures/cifar10/weight_angles.png)

### The two new proposed algorithms

- **Weight Mirrors (WM)**: the second learning mode, which alternates with the engaged mode during training. This algorithm suggests that neurons can discharge their signals noisily and adjust the feedback weights so that they mimic the forward ones. Here is pseudocode for this method:

```python
for layer in range(n_layers):
    # generate the noise of the forward neurons
    noise_x = noise_amplitude * (np.random.rand(forward_weight_size, batch_size) - 0.5)
    # send the noise through the forward weight matrix to the next layer
    noise_y = self.sigmoid(np.matmul(forward_weight, noise_x) + bias)
    # update the backward weight matrix using equation 7 of the paper,
    # i.e. the delta signal becomes the neurons' noise
    backward_weight += mirror_learning_rate * np.matmul(noise_x, noise_y.T)
```

- **Kolen-Pollack algorithm (KP)**: it solves the weight transport problem by transporting the changes in weights instead. At every time step, the forward and backward weights undergo identical adjustments and identical weight-decay factors, as described in equations 16 and 17 of the paper; since the shared decay keeps shrinking any initial mismatch while the shared updates add nothing new to it, the backward weights converge toward the transpose of the forward ones (see the sanity-check sketch further below):

```python
new_forward_weights = weight_decay * current_forward_weights - learning_rate * delta_forward_weights
new_backward_weights = weight_decay * current_backward_weights - learning_rate * delta_backward_weights
```

## Note
This repository provides a Python version of the proprietary TensorFlow/TPU code for the weight mirror and the KP reciprocal network that we used in our tests.

## Credit
The backpropagation code follows the function structure of the backpropagation code in Michael Nielsen's [repository](https://github.com/mnielsen/neural-networks-and-deep-learning). On top of it, we refactored the code, added batch learning, and implemented the two new algorithms proposed in the paper.
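
## Sanity check: alignment of forward and backward weights

As referenced above, here is a minimal, self-contained NumPy sketch of why both update rules drive the backward matrix toward the transpose of the forward one. It is illustrative only, not the repository's `fcnn` implementation: it uses a single linear layer (no sigmoid or bias), a random matrix as a stand-in for the true gradient, and arbitrary constants.

```python
import numpy as np

rng = np.random.default_rng(0)

def angle_deg(a, b):
    # angle in degrees between two matrices, flattened to vectors
    cos = np.sum(a * b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

n_in, n_out, batch_size = 20, 10, 64

# --- Weight mirror: noise correlations push B toward W^T (cf. equation 7) ---
W = rng.standard_normal((n_out, n_in)) * 0.1   # forward weights
B = rng.standard_normal((n_in, n_out)) * 0.1   # backward weights
print("initial angle(B, W^T): %.2f deg" % angle_deg(B, W.T))
noise_amplitude, mirror_lr = 0.1, 0.01
for _ in range(20000):
    noise_x = noise_amplitude * (rng.random((n_in, batch_size)) - 0.5)
    noise_y = np.matmul(W, noise_x)            # linear layer for simplicity
    # E[noise_x @ noise_y.T] is proportional to W^T, so B drifts toward it
    B += mirror_lr * np.matmul(noise_x, noise_y.T)
print("WM angle(B, W^T): %.2f deg" % angle_deg(B, W.T))

# --- Kolen-Pollack: identical updates plus identical decay (cf. eqs. 16-17) ---
W = rng.standard_normal((n_out, n_in)) * 0.1
B = rng.standard_normal((n_in, n_out)) * 0.1
weight_decay, lr = 0.98, 0.1
for _ in range(300):
    grad = rng.standard_normal((n_out, n_in))  # stand-in for a real gradient
    W = weight_decay * W - lr * grad           # forward update
    B = weight_decay * B - lr * grad.T         # same change, transposed
print("KP angle(B, W^T): %.2f deg" % angle_deg(B, W.T))
```

In the weight-mirror run, the component of the accumulated update proportional to `W.T` grows linearly with the number of steps while the zero-mean fluctuations only grow like a random walk, so the angle keeps shrinking; in the Kolen-Pollack run, the difference between `W.T` and `B` is multiplied by `weight_decay` at every step and therefore vanishes geometrically, whatever the gradients are.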
86 | 87 | 88 | ## Citing the paper (bib) 89 | ``` 90 | @article{akrout2019deep, 91 | title={Deep Learning without Weight Transport.}, 92 | author={Akrout, Mohamed and Wilson, Collin and Humphreys, Peter C and Lillicrap, Timothy P and Tweed, Douglas B}, 93 | journal={CoRR, abs/1904.05391}, 94 | year={2019} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /experiments/cifar10/bp-V0/experiment_config.txt: -------------------------------------------------------------------------------- 1 | dataset cifar10 2 | algo bp 3 | n_epochs 400 4 | size_hidden_layers [1000] 5 | batch_size 128 6 | learning_rate 0.2 7 | test_frequency 1 8 | save_dir ./experiments 9 | seed 1111 10 | -------------------------------------------------------------------------------- /experiments/cifar10/bp-V0/log.txt: -------------------------------------------------------------------------------- 1 | 2 | ==================== 3 | Running the code with the dataset cifar10: 4 | learning algorithm: bp 5 | batch_size: 128 6 | learning rate: 0.2 7 | n_epochs: 400 8 | test frequency: 1 9 | size_layers: [3072, 1000, 10] 10 | ==================== 11 | 12 | 13 | Epoch 0 completed in 28.742406129837036 Loss: 3.053297821612102 14 | Test accuracy: 29.46 15 | 16 | Epoch 1 completed in 24.236584424972534 Loss: 2.7544624696857425 17 | Test accuracy: 34.1 18 | 19 | Epoch 2 completed in 29.321950674057007 Loss: 2.657579030362338 20 | Test accuracy: 37.1 21 | 22 | Epoch 3 completed in 28.695401430130005 Loss: 2.5885555230662174 23 | Test accuracy: 39.06 24 | 25 | Epoch 4 completed in 29.68910837173462 Loss: 2.5341922342093666 26 | Test accuracy: 40.57 27 | 28 | Epoch 5 completed in 29.78023099899292 Loss: 2.4886016642975153 29 | Test accuracy: 41.76 30 | 31 | Epoch 6 completed in 31.08112406730652 Loss: 2.4497514553688973 32 | Test accuracy: 42.47 33 | 34 | Epoch 7 completed in 25.864398956298828 Loss: 2.416203075461768 35 | Test accuracy: 43.12 36 | 37 | Epoch 8 completed in 48.85268235206604 Loss: 2.386312792858469 38 | Test accuracy: 43.97 39 | 40 | Epoch 9 completed in 63.38818120956421 Loss: 2.3589836654963596 41 | Test accuracy: 44.64 42 | 43 | Epoch 10 completed in 72.1238944530487 Loss: 2.333677006296486 44 | Test accuracy: 45.24 45 | 46 | Epoch 11 completed in 72.51838231086731 Loss: 2.310066980723325 47 | Test accuracy: 45.84 48 | 49 | Epoch 12 completed in 59.830355167388916 Loss: 2.2879233478364225 50 | Test accuracy: 46.23 51 | 52 | Epoch 13 completed in 58.21703815460205 Loss: 2.2670541495969303 53 | Test accuracy: 46.64 54 | 55 | Epoch 14 completed in 62.4963002204895 Loss: 2.24729992419288 56 | Test accuracy: 47.08 57 | 58 | Epoch 15 completed in 60.24095010757446 Loss: 2.2285314843657877 59 | Test accuracy: 47.34 60 | 61 | Epoch 16 completed in 57.912109375 Loss: 2.2106469899407992 62 | Test accuracy: 47.6 63 | 64 | Epoch 17 completed in 58.23998737335205 Loss: 2.1935450179010476 65 | Test accuracy: 47.87 66 | 67 | Epoch 18 completed in 60.085673332214355 Loss: 2.1771187933720126 68 | Test accuracy: 48.0 69 | 70 | Epoch 19 completed in 59.92277240753174 Loss: 2.161273952346151 71 | Test accuracy: 48.19 72 | 73 | Epoch 20 completed in 62.62262964248657 Loss: 2.1459346099283043 74 | Test accuracy: 48.4 75 | 76 | Epoch 21 completed in 72.37492370605469 Loss: 2.131042156391768 77 | Test accuracy: 48.47 78 | 79 | Epoch 22 completed in 73.89888501167297 Loss: 2.1165507195715287 80 | Test accuracy: 48.66 81 | 82 | Epoch 23 completed in 71.79264974594116 Loss: 2.1024204165816274 83 | 
Test accuracy: 48.65 84 | 85 | Epoch 24 completed in 71.92475771903992 Loss: 2.0886137621787815 86 | Test accuracy: 48.81 87 | 88 | Epoch 25 completed in 72.41856813430786 Loss: 2.07509543172961 89 | Test accuracy: 48.95 90 | 91 | Epoch 26 completed in 74.60004115104675 Loss: 2.0618350155436205 92 | Test accuracy: 49.21 93 | 94 | Epoch 27 completed in 71.045494556427 Loss: 2.0488099756167095 95 | Test accuracy: 49.39 96 | 97 | Epoch 28 completed in 73.01262354850769 Loss: 2.036004031853832 98 | Test accuracy: 49.54 99 | 100 | Epoch 29 completed in 73.08386945724487 Loss: 2.0234030278233885 101 | Test accuracy: 49.67 102 | 103 | Epoch 30 completed in 76.26278519630432 Loss: 2.0109929126773554 104 | Test accuracy: 49.73 105 | 106 | Epoch 31 completed in 72.24049806594849 Loss: 1.9987595909337506 107 | Test accuracy: 49.94 108 | 109 | Epoch 32 completed in 74.89494299888611 Loss: 1.9866888320945661 110 | Test accuracy: 50.06 111 | 112 | Epoch 33 completed in 73.97007942199707 Loss: 1.9747664235481248 113 | Test accuracy: 50.26 114 | 115 | Epoch 34 completed in 75.96859407424927 Loss: 1.9629787613128167 116 | Test accuracy: 50.33 117 | 118 | Epoch 35 completed in 70.94651627540588 Loss: 1.951313162050308 119 | Test accuracy: 50.4 120 | 121 | Epoch 36 completed in 69.79587531089783 Loss: 1.9397579336668678 122 | Test accuracy: 50.43 123 | 124 | Epoch 37 completed in 69.60495567321777 Loss: 1.9283027563723654 125 | Test accuracy: 50.52 126 | 127 | Epoch 38 completed in 70.36987042427063 Loss: 1.9169390230863919 128 | Test accuracy: 50.71 129 | 130 | Epoch 39 completed in 71.01841688156128 Loss: 1.9056596072487555 131 | Test accuracy: 50.79 132 | 133 | Epoch 40 completed in 68.71987199783325 Loss: 1.8944583908445587 134 | Test accuracy: 50.84 135 | 136 | Epoch 41 completed in 71.09561371803284 Loss: 1.88332997300468 137 | Test accuracy: 50.77 138 | 139 | Epoch 42 completed in 71.72982549667358 Loss: 1.8722696718810348 140 | Test accuracy: 50.95 141 | 142 | Epoch 43 completed in 73.74319767951965 Loss: 1.861273592122536 143 | Test accuracy: 50.95 144 | 145 | Epoch 44 completed in 71.06633019447327 Loss: 1.8503384064400727 146 | Test accuracy: 50.97 147 | 148 | Epoch 45 completed in 71.63289976119995 Loss: 1.83946092841614 149 | Test accuracy: 51.04 150 | 151 | Epoch 46 completed in 69.03900051116943 Loss: 1.8286378638049863 152 | Test accuracy: 51.14 153 | 154 | Epoch 47 completed in 71.88930869102478 Loss: 1.8178659230459533 155 | Test accuracy: 51.29 156 | 157 | Epoch 48 completed in 68.9676022529602 Loss: 1.8071421374204262 158 | Test accuracy: 51.33 159 | 160 | Epoch 49 completed in 72.63107872009277 Loss: 1.7964642267969984 161 | Test accuracy: 51.41 162 | 163 | Epoch 50 completed in 70.55856418609619 Loss: 1.785830765124631 164 | Test accuracy: 51.5 165 | 166 | Epoch 51 completed in 74.32642340660095 Loss: 1.7752407527480503 167 | Test accuracy: 51.6 168 | 169 | Epoch 52 completed in 73.71036291122437 Loss: 1.7646926317366858 170 | Test accuracy: 51.65 171 | 172 | Epoch 53 completed in 68.56588673591614 Loss: 1.754183474453197 173 | Test accuracy: 51.76 174 | 175 | Epoch 54 completed in 72.0324296951294 Loss: 1.7437090597635378 176 | Test accuracy: 51.9 177 | 178 | Epoch 55 completed in 71.04318356513977 Loss: 1.7332649420652027 179 | Test accuracy: 51.9 180 | 181 | Epoch 56 completed in 73.22227501869202 Loss: 1.7228475509935774 182 | Test accuracy: 51.9 183 | 184 | Epoch 57 completed in 73.35224771499634 Loss: 1.7124540179898529 185 | Test accuracy: 51.97 186 | 187 | Epoch 58 completed in 
71.65839719772339 Loss: 1.702081366005622 188 | Test accuracy: 51.89 189 | 190 | Epoch 59 completed in 70.8826699256897 Loss: 1.6917264732544823 191 | Test accuracy: 51.93 192 | 193 | Epoch 60 completed in 72.53898239135742 Loss: 1.6813867041029704 194 | Test accuracy: 51.86 195 | 196 | Epoch 61 completed in 72.06977796554565 Loss: 1.6710603013723404 197 | Test accuracy: 51.92 198 | 199 | Epoch 62 completed in 72.96989274024963 Loss: 1.660746224564675 200 | Test accuracy: 51.89 201 | 202 | Epoch 63 completed in 70.69032263755798 Loss: 1.6504439504532697 203 | Test accuracy: 51.81 204 | 205 | Epoch 64 completed in 72.46765995025635 Loss: 1.6401536076002945 206 | Test accuracy: 51.85 207 | 208 | Epoch 65 completed in 70.896484375 Loss: 1.62987651980273 209 | Test accuracy: 51.69 210 | 211 | Epoch 66 completed in 71.28442287445068 Loss: 1.6196162514052428 212 | Test accuracy: 51.74 213 | 214 | Epoch 67 completed in 72.09442639350891 Loss: 1.6093799751616602 215 | Test accuracy: 51.7 216 | 217 | Epoch 68 completed in 73.37898445129395 Loss: 1.5991789831162144 218 | Test accuracy: 51.55 219 | 220 | Epoch 69 completed in 72.7078013420105 Loss: 1.5890257316452252 221 | Test accuracy: 51.42 222 | 223 | Epoch 70 completed in 72.49482345581055 Loss: 1.578925961561165 224 | Test accuracy: 51.29 225 | 226 | Epoch 71 completed in 71.98290657997131 Loss: 1.5688736785761248 227 | Test accuracy: 51.31 228 | 229 | Epoch 72 completed in 73.41173124313354 Loss: 1.5588670329932717 230 | Test accuracy: 51.24 231 | 232 | Epoch 73 completed in 71.9002320766449 Loss: 1.548944845841263 233 | Test accuracy: 51.19 234 | 235 | Epoch 74 completed in 71.8023009300232 Loss: 1.5393038092986322 236 | Test accuracy: 51.08 237 | 238 | Epoch 75 completed in 74.26058197021484 Loss: 1.5300024984051075 239 | Test accuracy: 51.09 240 | 241 | Epoch 76 completed in 73.89085173606873 Loss: 1.527897266640774 242 | Test accuracy: 51.26 243 | 244 | Epoch 77 completed in 74.44192171096802 Loss: 1.5135208774070774 245 | Test accuracy: 51.14 246 | 247 | Epoch 78 completed in 72.57588028907776 Loss: 1.5068405139067385 248 | Test accuracy: 50.99 249 | 250 | Epoch 79 completed in 73.18593335151672 Loss: 1.4945305072492228 251 | Test accuracy: 51.27 252 | 253 | Epoch 80 completed in 72.11677837371826 Loss: 1.4872993788459716 254 | Test accuracy: 50.83 255 | 256 | Epoch 81 completed in 70.91711688041687 Loss: 1.4769146749996116 257 | Test accuracy: 51.11 258 | 259 | Epoch 82 completed in 71.5919120311737 Loss: 1.4665974074120334 260 | Test accuracy: 51.23 261 | 262 | Epoch 83 completed in 67.17302680015564 Loss: 1.4601976003214532 263 | Test accuracy: 51.1 264 | 265 | Epoch 84 completed in 70.25320219993591 Loss: 1.4425586407313118 266 | Test accuracy: 51.13 267 | 268 | Epoch 85 completed in 72.9181215763092 Loss: 1.4395060560085473 269 | Test accuracy: 51.13 270 | 271 | Epoch 86 completed in 73.65581107139587 Loss: 1.4298239779663735 272 | Test accuracy: 50.8 273 | 274 | Epoch 87 completed in 72.60157823562622 Loss: 1.4166283433304228 275 | Test accuracy: 50.95 276 | 277 | Epoch 88 completed in 72.17531037330627 Loss: 1.4138485485077925 278 | Test accuracy: 50.95 279 | 280 | Epoch 89 completed in 72.31492400169373 Loss: 1.4025990300095164 281 | Test accuracy: 50.94 282 | 283 | Epoch 90 completed in 73.80630135536194 Loss: 1.3959181884002312 284 | Test accuracy: 51.0 285 | 286 | Epoch 91 completed in 72.89004039764404 Loss: 1.3805346197981345 287 | Test accuracy: 50.52 288 | 289 | Epoch 92 completed in 71.49877595901489 Loss: 
1.376455387795295 290 | Test accuracy: 51.1 291 | 292 | Epoch 93 completed in 72.46301412582397 Loss: 1.3682275220173614 293 | Test accuracy: 51.1 294 | 295 | Epoch 94 completed in 74.05142736434937 Loss: 1.356601640376385 296 | Test accuracy: 50.48 297 | 298 | Epoch 95 completed in 72.37192821502686 Loss: 1.3503017749006125 299 | Test accuracy: 51.17 300 | 301 | Epoch 96 completed in 72.29743456840515 Loss: 1.3336550699045469 302 | Test accuracy: 50.74 303 | 304 | Epoch 97 completed in 72.54759478569031 Loss: 1.3324539219622293 305 | Test accuracy: 51.11 306 | 307 | Epoch 98 completed in 72.35094499588013 Loss: 1.3205711032967522 308 | Test accuracy: 50.83 309 | 310 | Epoch 99 completed in 74.0632860660553 Loss: 1.3120985488589383 311 | Test accuracy: 51.07 312 | 313 | Epoch 100 completed in 71.27108716964722 Loss: 1.2985679847864886 314 | Test accuracy: 50.97 315 | 316 | Epoch 101 completed in 75.47282600402832 Loss: 1.287547636267768 317 | Test accuracy: 50.57 318 | 319 | Epoch 102 completed in 74.68850922584534 Loss: 1.2882279363402853 320 | Test accuracy: 51.06 321 | 322 | Epoch 103 completed in 71.36130809783936 Loss: 1.2710199947931484 323 | Test accuracy: 51.12 324 | 325 | Epoch 104 completed in 72.11685919761658 Loss: 1.2734094775185405 326 | Test accuracy: 50.78 327 | 328 | Epoch 105 completed in 71.64312267303467 Loss: 1.257790563521754 329 | Test accuracy: 50.67 330 | 331 | Epoch 106 completed in 73.18255352973938 Loss: 1.244044202099717 332 | Test accuracy: 50.7 333 | 334 | Epoch 107 completed in 72.64436364173889 Loss: 1.2459859394029764 335 | Test accuracy: 50.85 336 | 337 | Epoch 108 completed in 69.59485626220703 Loss: 1.2287375325602887 338 | Test accuracy: 50.61 339 | 340 | Epoch 109 completed in 71.88749074935913 Loss: 1.2318849596941224 341 | Test accuracy: 50.62 342 | 343 | Epoch 110 completed in 72.21685695648193 Loss: 1.2179448237873447 344 | Test accuracy: 50.81 345 | 346 | Epoch 111 completed in 73.20576477050781 Loss: 1.2149614370320683 347 | Test accuracy: 50.44 348 | 349 | Epoch 112 completed in 72.87097263336182 Loss: 1.1995577792891874 350 | Test accuracy: 50.67 351 | 352 | Epoch 113 completed in 71.98776531219482 Loss: 1.197275143731275 353 | Test accuracy: 50.28 354 | 355 | Epoch 114 completed in 71.5116274356842 Loss: 1.1862311447232923 356 | Test accuracy: 50.68 357 | 358 | Epoch 115 completed in 75.22151327133179 Loss: 1.174250299314035 359 | Test accuracy: 50.64 360 | 361 | Epoch 116 completed in 75.222088098526 Loss: 1.1698691226246885 362 | Test accuracy: 50.72 363 | 364 | Epoch 117 completed in 72.45111036300659 Loss: 1.1559458142173933 365 | Test accuracy: 50.74 366 | 367 | Epoch 118 completed in 73.09412670135498 Loss: 1.1537784514656675 368 | Test accuracy: 50.29 369 | 370 | Epoch 119 completed in 74.77255964279175 Loss: 1.1484372212011114 371 | Test accuracy: 50.54 372 | 373 | Epoch 120 completed in 72.17480516433716 Loss: 1.1332561777055334 374 | Test accuracy: 50.69 375 | 376 | Epoch 121 completed in 75.22390985488892 Loss: 1.132647291981585 377 | Test accuracy: 50.46 378 | 379 | Epoch 122 completed in 72.53468465805054 Loss: 1.1158431677937604 380 | Test accuracy: 50.95 381 | 382 | Epoch 123 completed in 70.2790732383728 Loss: 1.1204003140444763 383 | Test accuracy: 50.92 384 | 385 | Epoch 124 completed in 75.35371661186218 Loss: 1.1114335921687921 386 | Test accuracy: 50.51 387 | 388 | Epoch 125 completed in 72.98771119117737 Loss: 1.10106841948988 389 | Test accuracy: 50.17 390 | 391 | Epoch 126 completed in 73.59141564369202 Loss: 
1.0903906857938068 392 | Test accuracy: 50.41 393 | 394 | Epoch 127 completed in 73.36245346069336 Loss: 1.0796441455507035 395 | Test accuracy: 50.66 396 | 397 | Epoch 128 completed in 73.15183234214783 Loss: 1.0791847006945505 398 | Test accuracy: 50.62 399 | 400 | Epoch 129 completed in 72.88839197158813 Loss: 1.0716726259243983 401 | Test accuracy: 50.55 402 | 403 | Epoch 130 completed in 70.1832857131958 Loss: 1.0570679739533766 404 | Test accuracy: 50.34 405 | 406 | Epoch 131 completed in 73.11061096191406 Loss: 1.0554593191201072 407 | Test accuracy: 51.17 408 | 409 | Epoch 132 completed in 72.33777236938477 Loss: 1.048038886549545 410 | Test accuracy: 50.3 411 | 412 | Epoch 133 completed in 74.1661274433136 Loss: 1.0518754234551375 413 | Test accuracy: 51.08 414 | 415 | Epoch 134 completed in 72.51139521598816 Loss: 1.0352695514414578 416 | Test accuracy: 50.48 417 | 418 | Epoch 135 completed in 71.86236071586609 Loss: 1.0157572713289291 419 | Test accuracy: 50.89 420 | 421 | Epoch 136 completed in 74.67593741416931 Loss: 1.0168396434306919 422 | Test accuracy: 50.55 423 | 424 | Epoch 137 completed in 71.73832869529724 Loss: 1.0264149733562156 425 | Test accuracy: 50.75 426 | 427 | Epoch 138 completed in 70.36094856262207 Loss: 0.9992703363510664 428 | Test accuracy: 50.95 429 | 430 | Epoch 139 completed in 74.97921085357666 Loss: 0.9834487744836374 431 | Test accuracy: 50.76 432 | 433 | Epoch 140 completed in 75.29244804382324 Loss: 0.9881911178565947 434 | Test accuracy: 50.53 435 | 436 | Epoch 141 completed in 71.16616797447205 Loss: 0.9793821117827768 437 | Test accuracy: 50.74 438 | 439 | Epoch 142 completed in 74.96116375923157 Loss: 0.965372885375053 440 | Test accuracy: 50.49 441 | 442 | Epoch 143 completed in 72.65882730484009 Loss: 0.9841547396333404 443 | Test accuracy: 50.41 444 | 445 | Epoch 144 completed in 71.28412508964539 Loss: 0.9623321892185237 446 | Test accuracy: 50.08 447 | 448 | Epoch 145 completed in 73.62268781661987 Loss: 0.9595198568085187 449 | Test accuracy: 51.03 450 | 451 | Epoch 146 completed in 72.51202058792114 Loss: 0.9419436908063759 452 | Test accuracy: 50.61 453 | 454 | Epoch 147 completed in 76.03339910507202 Loss: 0.9526704004510792 455 | Test accuracy: 50.74 456 | 457 | Epoch 148 completed in 74.17329454421997 Loss: 0.9253438011948603 458 | Test accuracy: 50.18 459 | 460 | Epoch 149 completed in 73.13878583908081 Loss: 0.9317279917406975 461 | Test accuracy: 50.26 462 | 463 | Epoch 150 completed in 70.79054546356201 Loss: 0.930726158964817 464 | Test accuracy: 50.55 465 | 466 | Epoch 151 completed in 72.9537193775177 Loss: 0.9108862741914514 467 | Test accuracy: 50.28 468 | 469 | Epoch 152 completed in 74.4699056148529 Loss: 0.9065845260127082 470 | Test accuracy: 50.22 471 | 472 | Epoch 153 completed in 72.55058073997498 Loss: 0.9051368851866703 473 | Test accuracy: 49.04 474 | 475 | Epoch 154 completed in 72.93606925010681 Loss: 0.8951030757623596 476 | Test accuracy: 50.14 477 | 478 | Epoch 155 completed in 73.51294112205505 Loss: 0.8878596731146343 479 | Test accuracy: 50.06 480 | 481 | Epoch 156 completed in 71.62465834617615 Loss: 0.8849401036231553 482 | Test accuracy: 50.23 483 | 484 | Epoch 157 completed in 71.2976815700531 Loss: 0.8708838068311624 485 | Test accuracy: 50.73 486 | 487 | Epoch 158 completed in 70.75206160545349 Loss: 0.8702117041299111 488 | Test accuracy: 50.13 489 | 490 | Epoch 159 completed in 77.06475877761841 Loss: 0.8642822868252694 491 | Test accuracy: 50.34 492 | 493 | Epoch 160 completed in 74.7978355884552 
Loss: 0.8594092126498233 494 | Test accuracy: 50.28 495 | 496 | Epoch 161 completed in 71.55985140800476 Loss: 0.854594938450758 497 | Test accuracy: 50.09 498 | 499 | Epoch 162 completed in 72.82956099510193 Loss: 0.8403203827486713 500 | Test accuracy: 49.77 501 | 502 | Epoch 163 completed in 71.7395761013031 Loss: 0.8333498032819532 503 | Test accuracy: 49.74 504 | 505 | Epoch 164 completed in 72.50844311714172 Loss: 0.8353561164541388 506 | Test accuracy: 50.03 507 | 508 | Epoch 165 completed in 73.7175874710083 Loss: 0.8431064658981621 509 | Test accuracy: 50.68 510 | 511 | Epoch 166 completed in 74.85490870475769 Loss: 0.8103193470248407 512 | Test accuracy: 50.52 513 | 514 | Epoch 167 completed in 71.54318499565125 Loss: 0.8026834802661839 515 | Test accuracy: 50.44 516 | 517 | Epoch 168 completed in 73.01749396324158 Loss: 0.8054162889587881 518 | Test accuracy: 50.01 519 | 520 | Epoch 169 completed in 71.13349843025208 Loss: 0.8122239779617819 521 | Test accuracy: 49.96 522 | 523 | Epoch 170 completed in 68.54035353660583 Loss: 0.8003557425692371 524 | Test accuracy: 49.81 525 | 526 | Epoch 171 completed in 72.59223699569702 Loss: 0.783970322113187 527 | Test accuracy: 50.44 528 | 529 | Epoch 172 completed in 68.53138136863708 Loss: 0.7866407617735302 530 | Test accuracy: 50.0 531 | 532 | Epoch 173 completed in 69.55573177337646 Loss: 0.771387713402842 533 | Test accuracy: 50.45 534 | 535 | Epoch 174 completed in 69.17532277107239 Loss: 0.7669279982381874 536 | Test accuracy: 49.5 537 | 538 | Epoch 175 completed in 70.86191487312317 Loss: 0.7591560928932742 539 | Test accuracy: 49.76 540 | 541 | Epoch 176 completed in 69.14711046218872 Loss: 0.7737895389083942 542 | Test accuracy: 50.15 543 | 544 | Epoch 177 completed in 71.44998574256897 Loss: 0.7768072272459314 545 | Test accuracy: 49.75 546 | 547 | Epoch 178 completed in 69.05590152740479 Loss: 0.7381474313711627 548 | Test accuracy: 49.22 549 | 550 | Epoch 179 completed in 68.71237206459045 Loss: 0.7492426251514858 551 | Test accuracy: 49.12 552 | 553 | Epoch 180 completed in 65.89209604263306 Loss: 0.7378499842463854 554 | Test accuracy: 49.9 555 | 556 | Epoch 181 completed in 70.64358234405518 Loss: 0.7259899866392359 557 | Test accuracy: 50.15 558 | 559 | Epoch 182 completed in 69.14028668403625 Loss: 0.7399774146565002 560 | Test accuracy: 50.07 561 | 562 | Epoch 183 completed in 69.65728902816772 Loss: 0.7150899542402003 563 | Test accuracy: 49.99 564 | 565 | Epoch 184 completed in 70.2642011642456 Loss: 0.7006130522396126 566 | Test accuracy: 50.15 567 | 568 | Epoch 185 completed in 71.31431293487549 Loss: 0.7236990432398175 569 | Test accuracy: 49.43 570 | 571 | Epoch 186 completed in 69.13694024085999 Loss: 0.7340686457366491 572 | Test accuracy: 50.33 573 | 574 | Epoch 187 completed in 71.43912625312805 Loss: 0.7006170731988347 575 | Test accuracy: 49.57 576 | 577 | Epoch 188 completed in 69.6170585155487 Loss: 0.679801849657919 578 | Test accuracy: 49.21 579 | 580 | Epoch 189 completed in 66.7943525314331 Loss: 0.6855390549646554 581 | Test accuracy: 49.95 582 | 583 | Epoch 190 completed in 70.0716233253479 Loss: 0.6844957052488151 584 | Test accuracy: 48.22 585 | 586 | Epoch 191 completed in 70.2063558101654 Loss: 0.6668315545458783 587 | Test accuracy: 49.77 588 | 589 | Epoch 192 completed in 68.67689681053162 Loss: 0.6555646640639419 590 | Test accuracy: 50.01 591 | 592 | Epoch 193 completed in 69.76357769966125 Loss: 0.6909166839173838 593 | Test accuracy: 49.41 594 | 595 | Epoch 194 completed in 72.0891764163971 
Loss: 0.6704373824899802 596 | Test accuracy: 49.33 597 | 598 | Epoch 195 completed in 70.53363108634949 Loss: 0.6707354564546413 599 | Test accuracy: 49.74 600 | 601 | Epoch 196 completed in 70.15963625907898 Loss: 0.6550735786869909 602 | Test accuracy: 49.59 603 | 604 | Epoch 197 completed in 70.01903963088989 Loss: 0.630238333756259 605 | Test accuracy: 49.72 606 | 607 | Epoch 198 completed in 70.27178287506104 Loss: 0.6343156487706713 608 | Test accuracy: 48.83 609 | 610 | Epoch 199 completed in 72.46687531471252 Loss: 0.6136767549911878 611 | Test accuracy: 49.19 612 | 613 | Epoch 200 completed in 70.26026916503906 Loss: 0.6399702954634329 614 | Test accuracy: 49.1 615 | 616 | Epoch 201 completed in 68.39295721054077 Loss: 0.6339319455793416 617 | Test accuracy: 49.9 618 | 619 | Epoch 202 completed in 69.9641683101654 Loss: 0.6329453620798756 620 | Test accuracy: 50.05 621 | 622 | Epoch 203 completed in 73.41690587997437 Loss: 0.6488046123318139 623 | Test accuracy: 50.18 624 | 625 | Epoch 204 completed in 69.80956792831421 Loss: 0.6005514627165517 626 | Test accuracy: 49.11 627 | 628 | Epoch 205 completed in 69.99297738075256 Loss: 0.6174472354522805 629 | Test accuracy: 49.51 630 | 631 | Epoch 206 completed in 70.86913466453552 Loss: 0.6144585067992689 632 | Test accuracy: 49.92 633 | 634 | Epoch 207 completed in 68.46202516555786 Loss: 0.603719911268876 635 | Test accuracy: 49.93 636 | 637 | Epoch 208 completed in 70.28343176841736 Loss: 0.5905364785059445 638 | Test accuracy: 49.75 639 | 640 | Epoch 209 completed in 72.27253818511963 Loss: 0.5822582430118708 641 | Test accuracy: 47.68 642 | 643 | Epoch 210 completed in 71.5530149936676 Loss: 0.5687415711088983 644 | Test accuracy: 49.91 645 | 646 | Epoch 211 completed in 71.4791259765625 Loss: 0.5685769552059229 647 | Test accuracy: 49.69 648 | 649 | Epoch 212 completed in 69.50167727470398 Loss: 0.5764356261619419 650 | Test accuracy: 49.3 651 | 652 | Epoch 213 completed in 69.88989353179932 Loss: 0.572043352180387 653 | Test accuracy: 49.33 654 | 655 | Epoch 214 completed in 68.9281837940216 Loss: 0.5852667439129502 656 | Test accuracy: 49.84 657 | 658 | Epoch 215 completed in 69.86296582221985 Loss: 0.5567754724142221 659 | Test accuracy: 49.47 660 | 661 | Epoch 216 completed in 67.4602963924408 Loss: 0.5604705977008992 662 | Test accuracy: 49.76 663 | 664 | Epoch 217 completed in 70.04499769210815 Loss: 0.5473058124971276 665 | Test accuracy: 49.91 666 | 667 | Epoch 218 completed in 67.87855958938599 Loss: 0.5565082598972533 668 | Test accuracy: 49.72 669 | 670 | Epoch 219 completed in 67.32552003860474 Loss: 0.5147932201095579 671 | Test accuracy: 49.38 672 | 673 | Epoch 220 completed in 68.1470410823822 Loss: 0.5329322109265836 674 | Test accuracy: 49.59 675 | 676 | Epoch 221 completed in 70.17026472091675 Loss: 0.5618347461930323 677 | Test accuracy: 49.75 678 | 679 | Epoch 222 completed in 67.18737959861755 Loss: 0.5635242147217016 680 | Test accuracy: 49.76 681 | 682 | Epoch 223 completed in 71.29705214500427 Loss: 0.549970324584659 683 | Test accuracy: 49.5 684 | 685 | Epoch 224 completed in 68.83170938491821 Loss: 0.49411860565857085 686 | Test accuracy: 47.1 687 | 688 | Epoch 225 completed in 68.67385745048523 Loss: 0.5170639677280536 689 | Test accuracy: 49.68 690 | 691 | Epoch 226 completed in 70.11818337440491 Loss: 0.5554748183439132 692 | Test accuracy: 48.81 693 | 694 | Epoch 227 completed in 69.15998148918152 Loss: 0.5025571091472281 695 | Test accuracy: 49.79 696 | 697 | Epoch 228 completed in 
69.01920437812805 Loss: 0.5317351941467594 698 | Test accuracy: 49.97 699 | 700 | Epoch 229 completed in 71.58016419410706 Loss: 0.4980491561705108 701 | Test accuracy: 49.72 702 | 703 | Epoch 230 completed in 69.42783665657043 Loss: 0.5074904811280809 704 | Test accuracy: 49.55 705 | 706 | Epoch 231 completed in 71.02840685844421 Loss: 0.502723312367161 707 | Test accuracy: 49.47 708 | 709 | Epoch 232 completed in 70.12740731239319 Loss: 0.5061692830962051 710 | Test accuracy: 49.53 711 | 712 | Epoch 233 completed in 69.68924474716187 Loss: 0.47970488459987526 713 | Test accuracy: 49.67 714 | 715 | Epoch 234 completed in 70.85739731788635 Loss: 0.46156447548316815 716 | Test accuracy: 49.79 717 | 718 | Epoch 235 completed in 70.07641577720642 Loss: 0.4594039121903712 719 | Test accuracy: 49.79 720 | 721 | Epoch 236 completed in 68.86794424057007 Loss: 0.4734331089114226 722 | Test accuracy: 48.7 723 | 724 | Epoch 237 completed in 71.36325645446777 Loss: 0.49071346182209286 725 | Test accuracy: 50.11 726 | 727 | Epoch 238 completed in 65.63136219978333 Loss: 0.482453543982539 728 | Test accuracy: 50.04 729 | 730 | Epoch 239 completed in 68.79189467430115 Loss: 0.47439764642277604 731 | Test accuracy: 50.0 732 | 733 | Epoch 240 completed in 67.71746063232422 Loss: 0.4466710151378898 734 | Test accuracy: 49.57 735 | 736 | Epoch 241 completed in 65.6564781665802 Loss: 0.44755575374227763 737 | Test accuracy: 49.76 738 | 739 | Epoch 242 completed in 67.24408841133118 Loss: 0.4741832462971999 740 | Test accuracy: 49.93 741 | 742 | Epoch 243 completed in 71.28531169891357 Loss: 0.4719313809739829 743 | Test accuracy: 49.77 744 | 745 | Epoch 244 completed in 69.59801292419434 Loss: 0.665724978645717 746 | Test accuracy: 50.14 747 | 748 | Epoch 245 completed in 71.20758199691772 Loss: 0.44958445047360845 749 | Test accuracy: 50.16 750 | 751 | Epoch 246 completed in 69.59970593452454 Loss: 0.4060911753973073 752 | Test accuracy: 49.85 753 | 754 | Epoch 247 completed in 65.49669647216797 Loss: 0.44626861092479675 755 | Test accuracy: 49.78 756 | 757 | Epoch 248 completed in 69.45210647583008 Loss: 0.4137142668500845 758 | Test accuracy: 49.66 759 | 760 | Epoch 249 completed in 68.0122857093811 Loss: 0.42818719397999105 761 | Test accuracy: 50.07 762 | 763 | Epoch 250 completed in 65.71237063407898 Loss: 0.45034449138487426 764 | Test accuracy: 49.91 765 | 766 | Epoch 251 completed in 69.15127205848694 Loss: 0.4196470639603519 767 | Test accuracy: 50.01 768 | 769 | Epoch 252 completed in 68.70920658111572 Loss: 0.4018933982053706 770 | Test accuracy: 49.96 771 | 772 | Epoch 253 completed in 69.21121907234192 Loss: 0.38113126069504866 773 | Test accuracy: 49.78 774 | 775 | Epoch 254 completed in 72.45578050613403 Loss: 0.42149596447371673 776 | Test accuracy: 49.66 777 | 778 | Epoch 255 completed in 65.89324045181274 Loss: 0.3816402812382329 779 | Test accuracy: 49.91 780 | 781 | Epoch 256 completed in 68.71777653694153 Loss: 0.4150739622195174 782 | Test accuracy: 49.79 783 | 784 | Epoch 257 completed in 71.51949262619019 Loss: 0.42498641911468843 785 | Test accuracy: 49.82 786 | 787 | Epoch 258 completed in 69.12933087348938 Loss: 0.39847836884751525 788 | Test accuracy: 49.72 789 | 790 | Epoch 259 completed in 73.34567737579346 Loss: 0.43361971990549086 791 | Test accuracy: 49.54 792 | 793 | Epoch 260 completed in 69.92176103591919 Loss: 0.3891695158849426 794 | Test accuracy: 49.94 795 | 796 | Epoch 261 completed in 71.96049952507019 Loss: 0.7068043450432698 797 | Test accuracy: 49.87 798 | 799 | 
Epoch 262 completed in 68.64918637275696 Loss: 0.4301938394349378 800 | Test accuracy: 50.01 801 | 802 | Epoch 263 completed in 70.50717902183533 Loss: 0.3886397264029654 803 | Test accuracy: 50.1 804 | 805 | Epoch 264 completed in 67.64396905899048 Loss: 0.3667340598561384 806 | Test accuracy: 49.79 807 | 808 | Epoch 265 completed in 68.68903350830078 Loss: 0.35780089120547764 809 | Test accuracy: 50.09 810 | 811 | Epoch 266 completed in 67.54397702217102 Loss: 0.3447011418913576 812 | Test accuracy: 49.97 813 | 814 | Epoch 267 completed in 70.54807615280151 Loss: 0.34365214400014454 815 | Test accuracy: 49.8 816 | 817 | Epoch 268 completed in 69.675039768219 Loss: 0.3775327487984379 818 | Test accuracy: 50.08 819 | 820 | Epoch 269 completed in 71.00943565368652 Loss: 0.3663604293898873 821 | Test accuracy: 49.99 822 | 823 | Epoch 270 completed in 66.47083401679993 Loss: 0.46567765292373026 824 | Test accuracy: 49.9 825 | 826 | Epoch 271 completed in 68.65406036376953 Loss: 0.6593278161154493 827 | Test accuracy: 42.83 828 | 829 | Epoch 272 completed in 68.9050760269165 Loss: 0.41080183540994514 830 | Test accuracy: 49.93 831 | 832 | Epoch 273 completed in 70.11033701896667 Loss: 0.41924891383366175 833 | Test accuracy: 50.22 834 | 835 | Epoch 274 completed in 67.26713156700134 Loss: 0.34389010725864055 836 | Test accuracy: 50.03 837 | 838 | Epoch 275 completed in 68.53717303276062 Loss: 0.3395181958236658 839 | Test accuracy: 50.06 840 | 841 | Epoch 276 completed in 69.95885968208313 Loss: 0.3325470660432857 842 | Test accuracy: 50.09 843 | 844 | Epoch 277 completed in 69.10315895080566 Loss: 0.4407182246867028 845 | Test accuracy: 49.23 846 | 847 | Epoch 278 completed in 68.74274325370789 Loss: 0.5280027963032736 848 | Test accuracy: 49.98 849 | 850 | Epoch 279 completed in 69.93453621864319 Loss: 0.34592754767675393 851 | Test accuracy: 49.9 852 | 853 | Epoch 280 completed in 68.98926162719727 Loss: 0.31600678640282576 854 | Test accuracy: 49.87 855 | 856 | Epoch 281 completed in 67.94819164276123 Loss: 0.3120363241524801 857 | Test accuracy: 49.79 858 | 859 | Epoch 282 completed in 66.07735347747803 Loss: 0.37833726395304845 860 | Test accuracy: 49.99 861 | 862 | Epoch 283 completed in 71.49466800689697 Loss: 0.32649194111068536 863 | Test accuracy: 49.68 864 | 865 | Epoch 284 completed in 69.69392657279968 Loss: 0.6624806551582927 866 | Test accuracy: 48.2 867 | 868 | Epoch 285 completed in 72.48514938354492 Loss: 0.457794253056153 869 | Test accuracy: 49.79 870 | 871 | Epoch 286 completed in 69.76488828659058 Loss: 0.3547179481467169 872 | Test accuracy: 49.89 873 | 874 | Epoch 287 completed in 70.7010931968689 Loss: 0.31631557236398583 875 | Test accuracy: 49.75 876 | 877 | Epoch 288 completed in 67.61465120315552 Loss: 0.3069762845149854 878 | Test accuracy: 49.7 879 | 880 | Epoch 289 completed in 70.98411154747009 Loss: 0.29835245643477404 881 | Test accuracy: 49.76 882 | 883 | Epoch 290 completed in 68.26999950408936 Loss: 0.2948006036874856 884 | Test accuracy: 49.92 885 | 886 | Epoch 291 completed in 69.28280186653137 Loss: 0.29078594432016985 887 | Test accuracy: 49.71 888 | 889 | Epoch 292 completed in 69.53418803215027 Loss: 0.8893957775266087 890 | Test accuracy: 49.29 891 | 892 | Epoch 293 completed in 69.1816291809082 Loss: 0.5914874555431668 893 | Test accuracy: 49.56 894 | 895 | Epoch 294 completed in 67.16020655632019 Loss: 0.4366035156249039 896 | Test accuracy: 49.65 897 | 898 | Epoch 295 completed in 70.81759667396545 Loss: 0.3725255944353288 899 | Test accuracy: 
49.95 900 | 901 | Epoch 296 completed in 69.98427629470825 Loss: 0.34837406038693025 902 | Test accuracy: 49.96 903 | 904 | Epoch 297 completed in 70.06458139419556 Loss: 0.32011770335444073 905 | Test accuracy: 49.98 906 | 907 | Epoch 298 completed in 68.23412680625916 Loss: 0.31683775103405365 908 | Test accuracy: 50.06 909 | 910 | Epoch 299 completed in 70.62664890289307 Loss: 0.3289335351736815 911 | Test accuracy: 50.13 912 | 913 | Epoch 300 completed in 70.19696402549744 Loss: 0.38674510540756796 914 | Test accuracy: 43.32 915 | 916 | Epoch 301 completed in 69.92190408706665 Loss: 0.41690284743406597 917 | Test accuracy: 46.12 918 | 919 | Epoch 302 completed in 70.16865396499634 Loss: 0.3236276578581856 920 | Test accuracy: 50.16 921 | 922 | Epoch 303 completed in 70.18615627288818 Loss: 0.2857397908637491 923 | Test accuracy: 50.06 924 | 925 | Epoch 304 completed in 68.5168514251709 Loss: 0.3300496259288349 926 | Test accuracy: 49.8 927 | 928 | Epoch 305 completed in 69.42228412628174 Loss: 0.4809992387658133 929 | Test accuracy: 50.16 930 | 931 | Epoch 306 completed in 67.29459547996521 Loss: 0.2926683754422488 932 | Test accuracy: 50.34 933 | 934 | Epoch 307 completed in 69.40823984146118 Loss: 0.2723056323698864 935 | Test accuracy: 50.25 936 | 937 | Epoch 308 completed in 67.69834089279175 Loss: 0.2674368340967083 938 | Test accuracy: 50.26 939 | 940 | Epoch 309 completed in 68.9913239479065 Loss: 0.26353320488779786 941 | Test accuracy: 50.22 942 | 943 | Epoch 310 completed in 69.02312660217285 Loss: 0.260198056949888 944 | Test accuracy: 50.23 945 | 946 | Epoch 311 completed in 68.21346306800842 Loss: 0.2571857890420153 947 | Test accuracy: 50.19 948 | 949 | Epoch 312 completed in 68.37097597122192 Loss: 0.2544618591427307 950 | Test accuracy: 50.19 951 | 952 | Epoch 313 completed in 72.02266049385071 Loss: 0.2518323923933629 953 | Test accuracy: 50.15 954 | 955 | Epoch 314 completed in 69.50124192237854 Loss: 0.24922662668148085 956 | Test accuracy: 50.16 957 | 958 | Epoch 315 completed in 68.65312647819519 Loss: 0.24660590566980653 959 | Test accuracy: 50.18 960 | 961 | Epoch 316 completed in 69.92671275138855 Loss: 0.2440756251234993 962 | Test accuracy: 50.17 963 | 964 | Epoch 317 completed in 69.47999119758606 Loss: 0.4265097259972579 965 | Test accuracy: 41.21 966 | 967 | Epoch 318 completed in 71.84600830078125 Loss: 0.6910914799452577 968 | Test accuracy: 48.75 969 | 970 | Epoch 319 completed in 69.4047429561615 Loss: 0.68381896723231 971 | Test accuracy: 49.59 972 | 973 | Epoch 320 completed in 69.1953113079071 Loss: 0.5323811328545356 974 | Test accuracy: 49.81 975 | 976 | Epoch 321 completed in 72.18813872337341 Loss: 0.31989570185723304 977 | Test accuracy: 49.85 978 | 979 | Epoch 322 completed in 70.63432216644287 Loss: 0.27742181689763024 980 | Test accuracy: 50.01 981 | 982 | Epoch 323 completed in 69.29381608963013 Loss: 0.29063052263593475 983 | Test accuracy: 50.12 984 | 985 | Epoch 324 completed in 68.69424891471863 Loss: 0.2585108797025106 986 | Test accuracy: 49.92 987 | 988 | Epoch 325 completed in 70.25533771514893 Loss: 0.25457028774499835 989 | Test accuracy: 49.87 990 | 991 | Epoch 326 completed in 70.04188323020935 Loss: 0.24665438534475828 992 | Test accuracy: 50.06 993 | 994 | Epoch 327 completed in 68.37180399894714 Loss: 0.241582837445594 995 | Test accuracy: 50.05 996 | 997 | Epoch 328 completed in 69.09161376953125 Loss: 0.237506310787252 998 | Test accuracy: 50.06 999 | 1000 | Epoch 329 completed in 71.40907883644104 Loss: 0.23456239119018152 
1001 | Test accuracy: 50.13 1002 | 1003 | Epoch 330 completed in 69.0515308380127 Loss: 0.23231302738803383 1004 | Test accuracy: 50.2 1005 | 1006 | Epoch 331 completed in 66.82807731628418 Loss: 0.2306415292521038 1007 | Test accuracy: 50.1 1008 | 1009 | Epoch 332 completed in 68.735431432724 Loss: 0.27146723671495765 1010 | Test accuracy: 50.14 1011 | 1012 | Epoch 333 completed in 66.981112241745 Loss: 0.3975971973799326 1013 | Test accuracy: 50.06 1014 | 1015 | Epoch 334 completed in 68.49900484085083 Loss: 0.8275217798498532 1016 | Test accuracy: 49.86 1017 | 1018 | Epoch 335 completed in 69.07605481147766 Loss: 0.7110047912249848 1019 | Test accuracy: 49.63 1020 | 1021 | Epoch 336 completed in 71.52160620689392 Loss: 0.34088282247217316 1022 | Test accuracy: 49.91 1023 | 1024 | Epoch 337 completed in 69.38916730880737 Loss: 0.28039807936555017 1025 | Test accuracy: 49.88 1026 | 1027 | Epoch 338 completed in 68.46503400802612 Loss: 0.23889312626997986 1028 | Test accuracy: 50.06 1029 | 1030 | Epoch 339 completed in 69.63969087600708 Loss: 0.23205829052197532 1031 | Test accuracy: 50.03 1032 | 1033 | Epoch 340 completed in 67.39283084869385 Loss: 0.2275830460509578 1034 | Test accuracy: 49.99 1035 | 1036 | Epoch 341 completed in 67.9783730506897 Loss: 0.2239381096009777 1037 | Test accuracy: 49.99 1038 | 1039 | Epoch 342 completed in 67.00987339019775 Loss: 0.2207864161805976 1040 | Test accuracy: 49.97 1041 | 1042 | Epoch 343 completed in 66.17969465255737 Loss: 0.21795672509018701 1043 | Test accuracy: 49.92 1044 | 1045 | Epoch 344 completed in 67.67376208305359 Loss: 0.21535154050485095 1046 | Test accuracy: 49.92 1047 | 1048 | Epoch 345 completed in 64.47255325317383 Loss: 0.21291221644736177 1049 | Test accuracy: 49.93 1050 | 1051 | Epoch 346 completed in 67.77541542053223 Loss: 0.21060128414070856 1052 | Test accuracy: 49.9 1053 | 1054 | Epoch 347 completed in 64.41680884361267 Loss: 0.20839130133266148 1055 | Test accuracy: 49.92 1056 | 1057 | Epoch 348 completed in 71.62799906730652 Loss: 0.2062790504356335 1058 | Test accuracy: 49.97 1059 | 1060 | Epoch 349 completed in 106.8981077671051 Loss: 0.20421933348194932 1061 | Test accuracy: 49.97 1062 | 1063 | Epoch 350 completed in 120.13204526901245 Loss: 0.20212238406321595 1064 | Test accuracy: 49.93 1065 | 1066 | Epoch 351 completed in 116.61565923690796 Loss: 0.20016102347767276 1067 | Test accuracy: 49.93 1068 | 1069 | Epoch 352 completed in 119.27539324760437 Loss: 0.19829708117543265 1070 | Test accuracy: 49.87 1071 | 1072 | Epoch 353 completed in 120.74486064910889 Loss: 0.19643667423939945 1073 | Test accuracy: 49.84 1074 | 1075 | Epoch 354 completed in 118.88974976539612 Loss: 0.19459668624846319 1076 | Test accuracy: 49.83 1077 | 1078 | Epoch 355 completed in 112.7045841217041 Loss: 0.19281817776198104 1079 | Test accuracy: 49.8 1080 | 1081 | Epoch 356 completed in 122.60651588439941 Loss: 0.19110144045607452 1082 | Test accuracy: 49.84 1083 | 1084 | Epoch 357 completed in 117.22458338737488 Loss: 0.1894309680281559 1085 | Test accuracy: 49.87 1086 | 1087 | Epoch 358 completed in 113.88178706169128 Loss: 0.18779145595679406 1088 | Test accuracy: 49.84 1089 | 1090 | Epoch 359 completed in 116.42365384101868 Loss: 0.18618054105389875 1091 | Test accuracy: 49.87 1092 | 1093 | Epoch 360 completed in 114.9848165512085 Loss: 0.1846005896509607 1094 | Test accuracy: 49.92 1095 | 1096 | Epoch 361 completed in 116.95010328292847 Loss: 0.18305238014289174 1097 | Test accuracy: 49.94 1098 | 1099 | Epoch 362 completed in 
117.7083899974823 Loss: 0.18153342784320273 1100 | Test accuracy: 49.92 1101 | 1102 | Epoch 363 completed in 116.17603182792664 Loss: 0.18003903055909878 1103 | Test accuracy: 49.91 1104 | 1105 | Epoch 364 completed in 121.79050374031067 Loss: 0.17856653629318722 1106 | Test accuracy: 49.9 1107 | 1108 | Epoch 365 completed in 116.13130068778992 Loss: 0.17711621967599236 1109 | Test accuracy: 49.92 1110 | 1111 | Epoch 366 completed in 121.2351279258728 Loss: 0.17568936151609893 1112 | Test accuracy: 49.92 1113 | 1114 | Epoch 367 completed in 112.89200806617737 Loss: 0.17428516757706033 1115 | Test accuracy: 49.96 1116 | 1117 | Epoch 368 completed in 117.47594261169434 Loss: 0.17290132249258724 1118 | Test accuracy: 49.97 1119 | 1120 | Epoch 369 completed in 112.31957006454468 Loss: 0.17153497228574285 1121 | Test accuracy: 49.98 1122 | 1123 | Epoch 370 completed in 120.60543370246887 Loss: 0.17018646486935554 1124 | Test accuracy: 49.93 1125 | 1126 | Epoch 371 completed in 119.7005307674408 Loss: 0.16885543286917187 1127 | Test accuracy: 49.87 1128 | 1129 | Epoch 372 completed in 115.47239947319031 Loss: 0.16754527453202483 1130 | Test accuracy: 49.83 1131 | 1132 | Epoch 373 completed in 121.86983442306519 Loss: 0.16624869086308253 1133 | Test accuracy: 49.81 1134 | 1135 | Epoch 374 completed in 119.32853722572327 Loss: 0.1649767830990035 1136 | Test accuracy: 49.79 1137 | 1138 | Epoch 375 completed in 115.69028520584106 Loss: 0.163699113044429 1139 | Test accuracy: 49.74 1140 | 1141 | Epoch 376 completed in 117.2735710144043 Loss: 0.16248915141130688 1142 | Test accuracy: 49.67 1143 | 1144 | Epoch 377 completed in 115.49784564971924 Loss: 0.1612152661387229 1145 | Test accuracy: 49.75 1146 | 1147 | Epoch 378 completed in 119.48982119560242 Loss: 0.16034255224912 1148 | Test accuracy: 49.61 1149 | 1150 | Epoch 379 completed in 116.99925518035889 Loss: 0.8756902114525551 1151 | Test accuracy: 41.82 1152 | 1153 | Epoch 380 completed in 116.51480007171631 Loss: 2.164640277770898 1154 | Test accuracy: 37.12 1155 | 1156 | Epoch 381 completed in 120.60049891471863 Loss: 1.5046877300375388 1157 | Test accuracy: 43.75 1158 | 1159 | Epoch 382 completed in 115.93498468399048 Loss: 1.2301079724724642 1160 | Test accuracy: 46.21 1161 | 1162 | Epoch 383 completed in 98.33029890060425 Loss: 1.002147779038433 1163 | Test accuracy: 46.41 1164 | 1165 | Epoch 384 completed in 98.90586137771606 Loss: 0.9654435192738415 1166 | Test accuracy: 47.18 1167 | 1168 | Epoch 385 completed in 101.10772824287415 Loss: 0.7982344462000202 1169 | Test accuracy: 47.13 1170 | 1171 | Epoch 386 completed in 100.83912396430969 Loss: 0.7560235011437018 1172 | Test accuracy: 47.75 1173 | 1174 | Epoch 387 completed in 99.05274844169617 Loss: 0.6602670991911797 1175 | Test accuracy: 45.83 1176 | 1177 | Epoch 388 completed in 103.05192303657532 Loss: 0.6529552892801775 1178 | Test accuracy: 48.55 1179 | 1180 | Epoch 389 completed in 101.31489038467407 Loss: 0.6308781612409652 1181 | Test accuracy: 48.1 1182 | 1183 | Epoch 390 completed in 98.68316149711609 Loss: 0.5885506028219393 1184 | Test accuracy: 48.22 1185 | 1186 | Epoch 391 completed in 103.74336910247803 Loss: 0.5735586677955061 1187 | Test accuracy: 48.15 1188 | 1189 | Epoch 392 completed in 104.81183195114136 Loss: 0.5180062969072419 1190 | Test accuracy: 47.01 1191 | 1192 | Epoch 393 completed in 100.68421292304993 Loss: 0.5459557001973325 1193 | Test accuracy: 48.37 1194 | 1195 | Epoch 394 completed in 100.58251738548279 Loss: 0.563905098777564 1196 | Test accuracy: 
48.42 1197 | 1198 | Epoch 395 completed in 103.85392951965332 Loss: 0.47720673806729996 1199 | Test accuracy: 48.04 1200 | 1201 | Epoch 396 completed in 101.49203610420227 Loss: 0.4901077750498401 1202 | Test accuracy: 48.0 1203 | 1204 | Epoch 397 completed in 98.23243069648743 Loss: 0.4330924555180104 1205 | Test accuracy: 47.76 1206 | 1207 | Epoch 398 completed in 104.59896755218506 Loss: 0.46083591758897213 1208 | Test accuracy: 48.96 1209 | 1210 | Epoch 399 completed in 106.47909927368164 Loss: 0.4513985798558965 1211 | Test accuracy: 48.18 1212 | -------------------------------------------------------------------------------- /experiments/cifar10/bp-V0/results.json: -------------------------------------------------------------------------------- 1 | {"epoch0": {"loss": 3.053297821612102, "test_accuracy": 29.46}, "epoch1": {"loss": 2.7544624696857425, "test_accuracy": 34.1}, "epoch2": {"loss": 2.657579030362338, "test_accuracy": 37.1}, "epoch3": {"loss": 2.5885555230662174, "test_accuracy": 39.06}, "epoch4": {"loss": 2.5341922342093666, "test_accuracy": 40.57}, "epoch5": {"loss": 2.4886016642975153, "test_accuracy": 41.76}, "epoch6": {"loss": 2.4497514553688973, "test_accuracy": 42.47}, "epoch7": {"loss": 2.416203075461768, "test_accuracy": 43.12}, "epoch8": {"loss": 2.386312792858469, "test_accuracy": 43.97}, "epoch9": {"loss": 2.3589836654963596, "test_accuracy": 44.64}, "epoch10": {"loss": 2.333677006296486, "test_accuracy": 45.24}, "epoch11": {"loss": 2.310066980723325, "test_accuracy": 45.84}, "epoch12": {"loss": 2.2879233478364225, "test_accuracy": 46.23}, "epoch13": {"loss": 2.2670541495969303, "test_accuracy": 46.64}, "epoch14": {"loss": 2.24729992419288, "test_accuracy": 47.08}, "epoch15": {"loss": 2.2285314843657877, "test_accuracy": 47.34}, "epoch16": {"loss": 2.2106469899407992, "test_accuracy": 47.6}, "epoch17": {"loss": 2.1935450179010476, "test_accuracy": 47.87}, "epoch18": {"loss": 2.1771187933720126, "test_accuracy": 48.0}, "epoch19": {"loss": 2.161273952346151, "test_accuracy": 48.19}, "epoch20": {"loss": 2.1459346099283043, "test_accuracy": 48.4}, "epoch21": {"loss": 2.131042156391768, "test_accuracy": 48.47}, "epoch22": {"loss": 2.1165507195715287, "test_accuracy": 48.66}, "epoch23": {"loss": 2.1024204165816274, "test_accuracy": 48.65}, "epoch24": {"loss": 2.0886137621787815, "test_accuracy": 48.81}, "epoch25": {"loss": 2.07509543172961, "test_accuracy": 48.95}, "epoch26": {"loss": 2.0618350155436205, "test_accuracy": 49.21}, "epoch27": {"loss": 2.0488099756167095, "test_accuracy": 49.39}, "epoch28": {"loss": 2.036004031853832, "test_accuracy": 49.54}, "epoch29": {"loss": 2.0234030278233885, "test_accuracy": 49.67}, "epoch30": {"loss": 2.0109929126773554, "test_accuracy": 49.73}, "epoch31": {"loss": 1.9987595909337506, "test_accuracy": 49.94}, "epoch32": {"loss": 1.9866888320945661, "test_accuracy": 50.06}, "epoch33": {"loss": 1.9747664235481248, "test_accuracy": 50.26}, "epoch34": {"loss": 1.9629787613128167, "test_accuracy": 50.33}, "epoch35": {"loss": 1.951313162050308, "test_accuracy": 50.4}, "epoch36": {"loss": 1.9397579336668678, "test_accuracy": 50.43}, "epoch37": {"loss": 1.9283027563723654, "test_accuracy": 50.52}, "epoch38": {"loss": 1.9169390230863919, "test_accuracy": 50.71}, "epoch39": {"loss": 1.9056596072487555, "test_accuracy": 50.79}, "epoch40": {"loss": 1.8944583908445587, "test_accuracy": 50.84}, "epoch41": {"loss": 1.88332997300468, "test_accuracy": 50.77}, "epoch42": {"loss": 1.8722696718810348, "test_accuracy": 50.95}, "epoch43": {"loss": 
1.861273592122536, "test_accuracy": 50.95}, "epoch44": {"loss": 1.8503384064400727, "test_accuracy": 50.97}, "epoch45": {"loss": 1.83946092841614, "test_accuracy": 51.04}, "epoch46": {"loss": 1.8286378638049863, "test_accuracy": 51.14}, "epoch47": {"loss": 1.8178659230459533, "test_accuracy": 51.29}, "epoch48": {"loss": 1.8071421374204262, "test_accuracy": 51.33}, "epoch49": {"loss": 1.7964642267969984, "test_accuracy": 51.41}, "epoch50": {"loss": 1.785830765124631, "test_accuracy": 51.5}, "epoch51": {"loss": 1.7752407527480503, "test_accuracy": 51.6}, "epoch52": {"loss": 1.7646926317366858, "test_accuracy": 51.65}, "epoch53": {"loss": 1.754183474453197, "test_accuracy": 51.76}, "epoch54": {"loss": 1.7437090597635378, "test_accuracy": 51.9}, "epoch55": {"loss": 1.7332649420652027, "test_accuracy": 51.9}, "epoch56": {"loss": 1.7228475509935774, "test_accuracy": 51.9}, "epoch57": {"loss": 1.7124540179898529, "test_accuracy": 51.97}, "epoch58": {"loss": 1.702081366005622, "test_accuracy": 51.89}, "epoch59": {"loss": 1.6917264732544823, "test_accuracy": 51.93}, "epoch60": {"loss": 1.6813867041029704, "test_accuracy": 51.86}, "epoch61": {"loss": 1.6710603013723404, "test_accuracy": 51.92}, "epoch62": {"loss": 1.660746224564675, "test_accuracy": 51.89}, "epoch63": {"loss": 1.6504439504532697, "test_accuracy": 51.81}, "epoch64": {"loss": 1.6401536076002945, "test_accuracy": 51.85}, "epoch65": {"loss": 1.62987651980273, "test_accuracy": 51.69}, "epoch66": {"loss": 1.6196162514052428, "test_accuracy": 51.74}, "epoch67": {"loss": 1.6093799751616602, "test_accuracy": 51.7}, "epoch68": {"loss": 1.5991789831162144, "test_accuracy": 51.55}, "epoch69": {"loss": 1.5890257316452252, "test_accuracy": 51.42}, "epoch70": {"loss": 1.578925961561165, "test_accuracy": 51.29}, "epoch71": {"loss": 1.5688736785761248, "test_accuracy": 51.31}, "epoch72": {"loss": 1.5588670329932717, "test_accuracy": 51.24}, "epoch73": {"loss": 1.548944845841263, "test_accuracy": 51.19}, "epoch74": {"loss": 1.5393038092986322, "test_accuracy": 51.08}, "epoch75": {"loss": 1.5300024984051075, "test_accuracy": 51.09}, "epoch76": {"loss": 1.527897266640774, "test_accuracy": 51.26}, "epoch77": {"loss": 1.5135208774070774, "test_accuracy": 51.14}, "epoch78": {"loss": 1.5068405139067385, "test_accuracy": 50.99}, "epoch79": {"loss": 1.4945305072492228, "test_accuracy": 51.27}, "epoch80": {"loss": 1.4872993788459716, "test_accuracy": 50.83}, "epoch81": {"loss": 1.4769146749996116, "test_accuracy": 51.11}, "epoch82": {"loss": 1.4665974074120334, "test_accuracy": 51.23}, "epoch83": {"loss": 1.4601976003214532, "test_accuracy": 51.1}, "epoch84": {"loss": 1.4425586407313118, "test_accuracy": 51.13}, "epoch85": {"loss": 1.4395060560085473, "test_accuracy": 51.13}, "epoch86": {"loss": 1.4298239779663735, "test_accuracy": 50.8}, "epoch87": {"loss": 1.4166283433304228, "test_accuracy": 50.95}, "epoch88": {"loss": 1.4138485485077925, "test_accuracy": 50.95}, "epoch89": {"loss": 1.4025990300095164, "test_accuracy": 50.94}, "epoch90": {"loss": 1.3959181884002312, "test_accuracy": 51.0}, "epoch91": {"loss": 1.3805346197981345, "test_accuracy": 50.52}, "epoch92": {"loss": 1.376455387795295, "test_accuracy": 51.1}, "epoch93": {"loss": 1.3682275220173614, "test_accuracy": 51.1}, "epoch94": {"loss": 1.356601640376385, "test_accuracy": 50.48}, "epoch95": {"loss": 1.3503017749006125, "test_accuracy": 51.17}, "epoch96": {"loss": 1.3336550699045469, "test_accuracy": 50.74}, "epoch97": {"loss": 1.3324539219622293, "test_accuracy": 51.11}, "epoch98": {"loss": 
1.3205711032967522, "test_accuracy": 50.83}, "epoch99": {"loss": 1.3120985488589383, "test_accuracy": 51.07}, "epoch100": {"loss": 1.2985679847864886, "test_accuracy": 50.97}, "epoch101": {"loss": 1.287547636267768, "test_accuracy": 50.57}, "epoch102": {"loss": 1.2882279363402853, "test_accuracy": 51.06}, "epoch103": {"loss": 1.2710199947931484, "test_accuracy": 51.12}, "epoch104": {"loss": 1.2734094775185405, "test_accuracy": 50.78}, "epoch105": {"loss": 1.257790563521754, "test_accuracy": 50.67}, "epoch106": {"loss": 1.244044202099717, "test_accuracy": 50.7}, "epoch107": {"loss": 1.2459859394029764, "test_accuracy": 50.85}, "epoch108": {"loss": 1.2287375325602887, "test_accuracy": 50.61}, "epoch109": {"loss": 1.2318849596941224, "test_accuracy": 50.62}, "epoch110": {"loss": 1.2179448237873447, "test_accuracy": 50.81}, "epoch111": {"loss": 1.2149614370320683, "test_accuracy": 50.44}, "epoch112": {"loss": 1.1995577792891874, "test_accuracy": 50.67}, "epoch113": {"loss": 1.197275143731275, "test_accuracy": 50.28}, "epoch114": {"loss": 1.1862311447232923, "test_accuracy": 50.68}, "epoch115": {"loss": 1.174250299314035, "test_accuracy": 50.64}, "epoch116": {"loss": 1.1698691226246885, "test_accuracy": 50.72}, "epoch117": {"loss": 1.1559458142173933, "test_accuracy": 50.74}, "epoch118": {"loss": 1.1537784514656675, "test_accuracy": 50.29}, "epoch119": {"loss": 1.1484372212011114, "test_accuracy": 50.54}, "epoch120": {"loss": 1.1332561777055334, "test_accuracy": 50.69}, "epoch121": {"loss": 1.132647291981585, "test_accuracy": 50.46}, "epoch122": {"loss": 1.1158431677937604, "test_accuracy": 50.95}, "epoch123": {"loss": 1.1204003140444763, "test_accuracy": 50.92}, "epoch124": {"loss": 1.1114335921687921, "test_accuracy": 50.51}, "epoch125": {"loss": 1.10106841948988, "test_accuracy": 50.17}, "epoch126": {"loss": 1.0903906857938068, "test_accuracy": 50.41}, "epoch127": {"loss": 1.0796441455507035, "test_accuracy": 50.66}, "epoch128": {"loss": 1.0791847006945505, "test_accuracy": 50.62}, "epoch129": {"loss": 1.0716726259243983, "test_accuracy": 50.55}, "epoch130": {"loss": 1.0570679739533766, "test_accuracy": 50.34}, "epoch131": {"loss": 1.0554593191201072, "test_accuracy": 51.17}, "epoch132": {"loss": 1.048038886549545, "test_accuracy": 50.3}, "epoch133": {"loss": 1.0518754234551375, "test_accuracy": 51.08}, "epoch134": {"loss": 1.0352695514414578, "test_accuracy": 50.48}, "epoch135": {"loss": 1.0157572713289291, "test_accuracy": 50.89}, "epoch136": {"loss": 1.0168396434306919, "test_accuracy": 50.55}, "epoch137": {"loss": 1.0264149733562156, "test_accuracy": 50.75}, "epoch138": {"loss": 0.9992703363510664, "test_accuracy": 50.95}, "epoch139": {"loss": 0.9834487744836374, "test_accuracy": 50.76}, "epoch140": {"loss": 0.9881911178565947, "test_accuracy": 50.53}, "epoch141": {"loss": 0.9793821117827768, "test_accuracy": 50.74}, "epoch142": {"loss": 0.965372885375053, "test_accuracy": 50.49}, "epoch143": {"loss": 0.9841547396333404, "test_accuracy": 50.41}, "epoch144": {"loss": 0.9623321892185237, "test_accuracy": 50.08}, "epoch145": {"loss": 0.9595198568085187, "test_accuracy": 51.03}, "epoch146": {"loss": 0.9419436908063759, "test_accuracy": 50.61}, "epoch147": {"loss": 0.9526704004510792, "test_accuracy": 50.74}, "epoch148": {"loss": 0.9253438011948603, "test_accuracy": 50.18}, "epoch149": {"loss": 0.9317279917406975, "test_accuracy": 50.26}, "epoch150": {"loss": 0.930726158964817, "test_accuracy": 50.55}, "epoch151": {"loss": 0.9108862741914514, "test_accuracy": 50.28}, "epoch152": {"loss": 
0.9065845260127082, "test_accuracy": 50.22}, "epoch153": {"loss": 0.9051368851866703, "test_accuracy": 49.04}, "epoch154": {"loss": 0.8951030757623596, "test_accuracy": 50.14}, "epoch155": {"loss": 0.8878596731146343, "test_accuracy": 50.06}, "epoch156": {"loss": 0.8849401036231553, "test_accuracy": 50.23}, "epoch157": {"loss": 0.8708838068311624, "test_accuracy": 50.73}, "epoch158": {"loss": 0.8702117041299111, "test_accuracy": 50.13}, "epoch159": {"loss": 0.8642822868252694, "test_accuracy": 50.34}, "epoch160": {"loss": 0.8594092126498233, "test_accuracy": 50.28}, "epoch161": {"loss": 0.854594938450758, "test_accuracy": 50.09}, "epoch162": {"loss": 0.8403203827486713, "test_accuracy": 49.77}, "epoch163": {"loss": 0.8333498032819532, "test_accuracy": 49.74}, "epoch164": {"loss": 0.8353561164541388, "test_accuracy": 50.03}, "epoch165": {"loss": 0.8431064658981621, "test_accuracy": 50.68}, "epoch166": {"loss": 0.8103193470248407, "test_accuracy": 50.52}, "epoch167": {"loss": 0.8026834802661839, "test_accuracy": 50.44}, "epoch168": {"loss": 0.8054162889587881, "test_accuracy": 50.01}, "epoch169": {"loss": 0.8122239779617819, "test_accuracy": 49.96}, "epoch170": {"loss": 0.8003557425692371, "test_accuracy": 49.81}, "epoch171": {"loss": 0.783970322113187, "test_accuracy": 50.44}, "epoch172": {"loss": 0.7866407617735302, "test_accuracy": 50.0}, "epoch173": {"loss": 0.771387713402842, "test_accuracy": 50.45}, "epoch174": {"loss": 0.7669279982381874, "test_accuracy": 49.5}, "epoch175": {"loss": 0.7591560928932742, "test_accuracy": 49.76}, "epoch176": {"loss": 0.7737895389083942, "test_accuracy": 50.15}, "epoch177": {"loss": 0.7768072272459314, "test_accuracy": 49.75}, "epoch178": {"loss": 0.7381474313711627, "test_accuracy": 49.22}, "epoch179": {"loss": 0.7492426251514858, "test_accuracy": 49.12}, "epoch180": {"loss": 0.7378499842463854, "test_accuracy": 49.9}, "epoch181": {"loss": 0.7259899866392359, "test_accuracy": 50.15}, "epoch182": {"loss": 0.7399774146565002, "test_accuracy": 50.07}, "epoch183": {"loss": 0.7150899542402003, "test_accuracy": 49.99}, "epoch184": {"loss": 0.7006130522396126, "test_accuracy": 50.15}, "epoch185": {"loss": 0.7236990432398175, "test_accuracy": 49.43}, "epoch186": {"loss": 0.7340686457366491, "test_accuracy": 50.33}, "epoch187": {"loss": 0.7006170731988347, "test_accuracy": 49.57}, "epoch188": {"loss": 0.679801849657919, "test_accuracy": 49.21}, "epoch189": {"loss": 0.6855390549646554, "test_accuracy": 49.95}, "epoch190": {"loss": 0.6844957052488151, "test_accuracy": 48.22}, "epoch191": {"loss": 0.6668315545458783, "test_accuracy": 49.77}, "epoch192": {"loss": 0.6555646640639419, "test_accuracy": 50.01}, "epoch193": {"loss": 0.6909166839173838, "test_accuracy": 49.41}, "epoch194": {"loss": 0.6704373824899802, "test_accuracy": 49.33}, "epoch195": {"loss": 0.6707354564546413, "test_accuracy": 49.74}, "epoch196": {"loss": 0.6550735786869909, "test_accuracy": 49.59}, "epoch197": {"loss": 0.630238333756259, "test_accuracy": 49.72}, "epoch198": {"loss": 0.6343156487706713, "test_accuracy": 48.83}, "epoch199": {"loss": 0.6136767549911878, "test_accuracy": 49.19}, "epoch200": {"loss": 0.6399702954634329, "test_accuracy": 49.1}, "epoch201": {"loss": 0.6339319455793416, "test_accuracy": 49.9}, "epoch202": {"loss": 0.6329453620798756, "test_accuracy": 50.05}, "epoch203": {"loss": 0.6488046123318139, "test_accuracy": 50.18}, "epoch204": {"loss": 0.6005514627165517, "test_accuracy": 49.11}, "epoch205": {"loss": 0.6174472354522805, "test_accuracy": 49.51}, "epoch206": {"loss": 
0.6144585067992689, "test_accuracy": 49.92}, "epoch207": {"loss": 0.603719911268876, "test_accuracy": 49.93}, "epoch208": {"loss": 0.5905364785059445, "test_accuracy": 49.75}, "epoch209": {"loss": 0.5822582430118708, "test_accuracy": 47.68}, "epoch210": {"loss": 0.5687415711088983, "test_accuracy": 49.91}, "epoch211": {"loss": 0.5685769552059229, "test_accuracy": 49.69}, "epoch212": {"loss": 0.5764356261619419, "test_accuracy": 49.3}, "epoch213": {"loss": 0.572043352180387, "test_accuracy": 49.33}, "epoch214": {"loss": 0.5852667439129502, "test_accuracy": 49.84}, "epoch215": {"loss": 0.5567754724142221, "test_accuracy": 49.47}, "epoch216": {"loss": 0.5604705977008992, "test_accuracy": 49.76}, "epoch217": {"loss": 0.5473058124971276, "test_accuracy": 49.91}, "epoch218": {"loss": 0.5565082598972533, "test_accuracy": 49.72}, "epoch219": {"loss": 0.5147932201095579, "test_accuracy": 49.38}, "epoch220": {"loss": 0.5329322109265836, "test_accuracy": 49.59}, "epoch221": {"loss": 0.5618347461930323, "test_accuracy": 49.75}, "epoch222": {"loss": 0.5635242147217016, "test_accuracy": 49.76}, "epoch223": {"loss": 0.549970324584659, "test_accuracy": 49.5}, "epoch224": {"loss": 0.49411860565857085, "test_accuracy": 47.1}, "epoch225": {"loss": 0.5170639677280536, "test_accuracy": 49.68}, "epoch226": {"loss": 0.5554748183439132, "test_accuracy": 48.81}, "epoch227": {"loss": 0.5025571091472281, "test_accuracy": 49.79}, "epoch228": {"loss": 0.5317351941467594, "test_accuracy": 49.97}, "epoch229": {"loss": 0.4980491561705108, "test_accuracy": 49.72}, "epoch230": {"loss": 0.5074904811280809, "test_accuracy": 49.55}, "epoch231": {"loss": 0.502723312367161, "test_accuracy": 49.47}, "epoch232": {"loss": 0.5061692830962051, "test_accuracy": 49.53}, "epoch233": {"loss": 0.47970488459987526, "test_accuracy": 49.67}, "epoch234": {"loss": 0.46156447548316815, "test_accuracy": 49.79}, "epoch235": {"loss": 0.4594039121903712, "test_accuracy": 49.79}, "epoch236": {"loss": 0.4734331089114226, "test_accuracy": 48.7}, "epoch237": {"loss": 0.49071346182209286, "test_accuracy": 50.11}, "epoch238": {"loss": 0.482453543982539, "test_accuracy": 50.04}, "epoch239": {"loss": 0.47439764642277604, "test_accuracy": 50.0}, "epoch240": {"loss": 0.4466710151378898, "test_accuracy": 49.57}, "epoch241": {"loss": 0.44755575374227763, "test_accuracy": 49.76}, "epoch242": {"loss": 0.4741832462971999, "test_accuracy": 49.93}, "epoch243": {"loss": 0.4719313809739829, "test_accuracy": 49.77}, "epoch244": {"loss": 0.665724978645717, "test_accuracy": 50.14}, "epoch245": {"loss": 0.44958445047360845, "test_accuracy": 50.16}, "epoch246": {"loss": 0.4060911753973073, "test_accuracy": 49.85}, "epoch247": {"loss": 0.44626861092479675, "test_accuracy": 49.78}, "epoch248": {"loss": 0.4137142668500845, "test_accuracy": 49.66}, "epoch249": {"loss": 0.42818719397999105, "test_accuracy": 50.07}, "epoch250": {"loss": 0.45034449138487426, "test_accuracy": 49.91}, "epoch251": {"loss": 0.4196470639603519, "test_accuracy": 50.01}, "epoch252": {"loss": 0.4018933982053706, "test_accuracy": 49.96}, "epoch253": {"loss": 0.38113126069504866, "test_accuracy": 49.78}, "epoch254": {"loss": 0.42149596447371673, "test_accuracy": 49.66}, "epoch255": {"loss": 0.3816402812382329, "test_accuracy": 49.91}, "epoch256": {"loss": 0.4150739622195174, "test_accuracy": 49.79}, "epoch257": {"loss": 0.42498641911468843, "test_accuracy": 49.82}, "epoch258": {"loss": 0.39847836884751525, "test_accuracy": 49.72}, "epoch259": {"loss": 0.43361971990549086, "test_accuracy": 49.54}, 
"epoch260": {"loss": 0.3891695158849426, "test_accuracy": 49.94}, "epoch261": {"loss": 0.7068043450432698, "test_accuracy": 49.87}, "epoch262": {"loss": 0.4301938394349378, "test_accuracy": 50.01}, "epoch263": {"loss": 0.3886397264029654, "test_accuracy": 50.1}, "epoch264": {"loss": 0.3667340598561384, "test_accuracy": 49.79}, "epoch265": {"loss": 0.35780089120547764, "test_accuracy": 50.09}, "epoch266": {"loss": 0.3447011418913576, "test_accuracy": 49.97}, "epoch267": {"loss": 0.34365214400014454, "test_accuracy": 49.8}, "epoch268": {"loss": 0.3775327487984379, "test_accuracy": 50.08}, "epoch269": {"loss": 0.3663604293898873, "test_accuracy": 49.99}, "epoch270": {"loss": 0.46567765292373026, "test_accuracy": 49.9}, "epoch271": {"loss": 0.6593278161154493, "test_accuracy": 42.83}, "epoch272": {"loss": 0.41080183540994514, "test_accuracy": 49.93}, "epoch273": {"loss": 0.41924891383366175, "test_accuracy": 50.22}, "epoch274": {"loss": 0.34389010725864055, "test_accuracy": 50.03}, "epoch275": {"loss": 0.3395181958236658, "test_accuracy": 50.06}, "epoch276": {"loss": 0.3325470660432857, "test_accuracy": 50.09}, "epoch277": {"loss": 0.4407182246867028, "test_accuracy": 49.23}, "epoch278": {"loss": 0.5280027963032736, "test_accuracy": 49.98}, "epoch279": {"loss": 0.34592754767675393, "test_accuracy": 49.9}, "epoch280": {"loss": 0.31600678640282576, "test_accuracy": 49.87}, "epoch281": {"loss": 0.3120363241524801, "test_accuracy": 49.79}, "epoch282": {"loss": 0.37833726395304845, "test_accuracy": 49.99}, "epoch283": {"loss": 0.32649194111068536, "test_accuracy": 49.68}, "epoch284": {"loss": 0.6624806551582927, "test_accuracy": 48.2}, "epoch285": {"loss": 0.457794253056153, "test_accuracy": 49.79}, "epoch286": {"loss": 0.3547179481467169, "test_accuracy": 49.89}, "epoch287": {"loss": 0.31631557236398583, "test_accuracy": 49.75}, "epoch288": {"loss": 0.3069762845149854, "test_accuracy": 49.7}, "epoch289": {"loss": 0.29835245643477404, "test_accuracy": 49.76}, "epoch290": {"loss": 0.2948006036874856, "test_accuracy": 49.92}, "epoch291": {"loss": 0.29078594432016985, "test_accuracy": 49.71}, "epoch292": {"loss": 0.8893957775266087, "test_accuracy": 49.29}, "epoch293": {"loss": 0.5914874555431668, "test_accuracy": 49.56}, "epoch294": {"loss": 0.4366035156249039, "test_accuracy": 49.65}, "epoch295": {"loss": 0.3725255944353288, "test_accuracy": 49.95}, "epoch296": {"loss": 0.34837406038693025, "test_accuracy": 49.96}, "epoch297": {"loss": 0.32011770335444073, "test_accuracy": 49.98}, "epoch298": {"loss": 0.31683775103405365, "test_accuracy": 50.06}, "epoch299": {"loss": 0.3289335351736815, "test_accuracy": 50.13}, "epoch300": {"loss": 0.38674510540756796, "test_accuracy": 43.32}, "epoch301": {"loss": 0.41690284743406597, "test_accuracy": 46.12}, "epoch302": {"loss": 0.3236276578581856, "test_accuracy": 50.16}, "epoch303": {"loss": 0.2857397908637491, "test_accuracy": 50.06}, "epoch304": {"loss": 0.3300496259288349, "test_accuracy": 49.8}, "epoch305": {"loss": 0.4809992387658133, "test_accuracy": 50.16}, "epoch306": {"loss": 0.2926683754422488, "test_accuracy": 50.34}, "epoch307": {"loss": 0.2723056323698864, "test_accuracy": 50.25}, "epoch308": {"loss": 0.2674368340967083, "test_accuracy": 50.26}, "epoch309": {"loss": 0.26353320488779786, "test_accuracy": 50.22}, "epoch310": {"loss": 0.260198056949888, "test_accuracy": 50.23}, "epoch311": {"loss": 0.2571857890420153, "test_accuracy": 50.19}, "epoch312": {"loss": 0.2544618591427307, "test_accuracy": 50.19}, "epoch313": {"loss": 0.2518323923933629, 
"test_accuracy": 50.15}, "epoch314": {"loss": 0.24922662668148085, "test_accuracy": 50.16}, "epoch315": {"loss": 0.24660590566980653, "test_accuracy": 50.18}, "epoch316": {"loss": 0.2440756251234993, "test_accuracy": 50.17}, "epoch317": {"loss": 0.4265097259972579, "test_accuracy": 41.21}, "epoch318": {"loss": 0.6910914799452577, "test_accuracy": 48.75}, "epoch319": {"loss": 0.68381896723231, "test_accuracy": 49.59}, "epoch320": {"loss": 0.5323811328545356, "test_accuracy": 49.81}, "epoch321": {"loss": 0.31989570185723304, "test_accuracy": 49.85}, "epoch322": {"loss": 0.27742181689763024, "test_accuracy": 50.01}, "epoch323": {"loss": 0.29063052263593475, "test_accuracy": 50.12}, "epoch324": {"loss": 0.2585108797025106, "test_accuracy": 49.92}, "epoch325": {"loss": 0.25457028774499835, "test_accuracy": 49.87}, "epoch326": {"loss": 0.24665438534475828, "test_accuracy": 50.06}, "epoch327": {"loss": 0.241582837445594, "test_accuracy": 50.05}, "epoch328": {"loss": 0.237506310787252, "test_accuracy": 50.06}, "epoch329": {"loss": 0.23456239119018152, "test_accuracy": 50.13}, "epoch330": {"loss": 0.23231302738803383, "test_accuracy": 50.2}, "epoch331": {"loss": 0.2306415292521038, "test_accuracy": 50.1}, "epoch332": {"loss": 0.27146723671495765, "test_accuracy": 50.14}, "epoch333": {"loss": 0.3975971973799326, "test_accuracy": 50.06}, "epoch334": {"loss": 0.8275217798498532, "test_accuracy": 49.86}, "epoch335": {"loss": 0.7110047912249848, "test_accuracy": 49.63}, "epoch336": {"loss": 0.34088282247217316, "test_accuracy": 49.91}, "epoch337": {"loss": 0.28039807936555017, "test_accuracy": 49.88}, "epoch338": {"loss": 0.23889312626997986, "test_accuracy": 50.06}, "epoch339": {"loss": 0.23205829052197532, "test_accuracy": 50.03}, "epoch340": {"loss": 0.2275830460509578, "test_accuracy": 49.99}, "epoch341": {"loss": 0.2239381096009777, "test_accuracy": 49.99}, "epoch342": {"loss": 0.2207864161805976, "test_accuracy": 49.97}, "epoch343": {"loss": 0.21795672509018701, "test_accuracy": 49.92}, "epoch344": {"loss": 0.21535154050485095, "test_accuracy": 49.92}, "epoch345": {"loss": 0.21291221644736177, "test_accuracy": 49.93}, "epoch346": {"loss": 0.21060128414070856, "test_accuracy": 49.9}, "epoch347": {"loss": 0.20839130133266148, "test_accuracy": 49.92}, "epoch348": {"loss": 0.2062790504356335, "test_accuracy": 49.97}, "epoch349": {"loss": 0.20421933348194932, "test_accuracy": 49.97}, "epoch350": {"loss": 0.20212238406321595, "test_accuracy": 49.93}, "epoch351": {"loss": 0.20016102347767276, "test_accuracy": 49.93}, "epoch352": {"loss": 0.19829708117543265, "test_accuracy": 49.87}, "epoch353": {"loss": 0.19643667423939945, "test_accuracy": 49.84}, "epoch354": {"loss": 0.19459668624846319, "test_accuracy": 49.83}, "epoch355": {"loss": 0.19281817776198104, "test_accuracy": 49.8}, "epoch356": {"loss": 0.19110144045607452, "test_accuracy": 49.84}, "epoch357": {"loss": 0.1894309680281559, "test_accuracy": 49.87}, "epoch358": {"loss": 0.18779145595679406, "test_accuracy": 49.84}, "epoch359": {"loss": 0.18618054105389875, "test_accuracy": 49.87}, "epoch360": {"loss": 0.1846005896509607, "test_accuracy": 49.92}, "epoch361": {"loss": 0.18305238014289174, "test_accuracy": 49.94}, "epoch362": {"loss": 0.18153342784320273, "test_accuracy": 49.92}, "epoch363": {"loss": 0.18003903055909878, "test_accuracy": 49.91}, "epoch364": {"loss": 0.17856653629318722, "test_accuracy": 49.9}, "epoch365": {"loss": 0.17711621967599236, "test_accuracy": 49.92}, "epoch366": {"loss": 0.17568936151609893, "test_accuracy": 49.92}, 
"epoch367": {"loss": 0.17428516757706033, "test_accuracy": 49.96}, "epoch368": {"loss": 0.17290132249258724, "test_accuracy": 49.97}, "epoch369": {"loss": 0.17153497228574285, "test_accuracy": 49.98}, "epoch370": {"loss": 0.17018646486935554, "test_accuracy": 49.93}, "epoch371": {"loss": 0.16885543286917187, "test_accuracy": 49.87}, "epoch372": {"loss": 0.16754527453202483, "test_accuracy": 49.83}, "epoch373": {"loss": 0.16624869086308253, "test_accuracy": 49.81}, "epoch374": {"loss": 0.1649767830990035, "test_accuracy": 49.79}, "epoch375": {"loss": 0.163699113044429, "test_accuracy": 49.74}, "epoch376": {"loss": 0.16248915141130688, "test_accuracy": 49.67}, "epoch377": {"loss": 0.1612152661387229, "test_accuracy": 49.75}, "epoch378": {"loss": 0.16034255224912, "test_accuracy": 49.61}, "epoch379": {"loss": 0.8756902114525551, "test_accuracy": 41.82}, "epoch380": {"loss": 2.164640277770898, "test_accuracy": 37.12}, "epoch381": {"loss": 1.5046877300375388, "test_accuracy": 43.75}, "epoch382": {"loss": 1.2301079724724642, "test_accuracy": 46.21}, "epoch383": {"loss": 1.002147779038433, "test_accuracy": 46.41}, "epoch384": {"loss": 0.9654435192738415, "test_accuracy": 47.18}, "epoch385": {"loss": 0.7982344462000202, "test_accuracy": 47.13}, "epoch386": {"loss": 0.7560235011437018, "test_accuracy": 47.75}, "epoch387": {"loss": 0.6602670991911797, "test_accuracy": 45.83}, "epoch388": {"loss": 0.6529552892801775, "test_accuracy": 48.55}, "epoch389": {"loss": 0.6308781612409652, "test_accuracy": 48.1}, "epoch390": {"loss": 0.5885506028219393, "test_accuracy": 48.22}, "epoch391": {"loss": 0.5735586677955061, "test_accuracy": 48.15}, "epoch392": {"loss": 0.5180062969072419, "test_accuracy": 47.01}, "epoch393": {"loss": 0.5459557001973325, "test_accuracy": 48.37}, "epoch394": {"loss": 0.563905098777564, "test_accuracy": 48.42}, "epoch395": {"loss": 0.47720673806729996, "test_accuracy": 48.04}, "epoch396": {"loss": 0.4901077750498401, "test_accuracy": 48.0}, "epoch397": {"loss": 0.4330924555180104, "test_accuracy": 47.76}, "epoch398": {"loss": 0.46083591758897213, "test_accuracy": 48.96}, "epoch399": {"loss": 0.4513985798558965, "test_accuracy": 48.18}} -------------------------------------------------------------------------------- /experiments/cifar10/fa-V0/experiment_config.txt: -------------------------------------------------------------------------------- 1 | dataset cifar10 2 | algo fa 3 | n_epochs 400 4 | size_hidden_layers [1000] 5 | batch_size 128 6 | learning_rate 0.2 7 | test_frequency 1 8 | save_dir ./experiments 9 | seed 1111 10 | -------------------------------------------------------------------------------- /experiments/cifar10/kp-V0/experiment_config.txt: -------------------------------------------------------------------------------- 1 | dataset cifar10 2 | algo kp 3 | n_epochs 400 4 | size_hidden_layers [1000] 5 | batch_size 128 6 | learning_rate 0.3 7 | test_frequency 1 8 | save_dir ./experiments 9 | seed 1111 10 | -------------------------------------------------------------------------------- /experiments/cifar10/wm-V0/experiment_config.txt: -------------------------------------------------------------------------------- 1 | dataset cifar10 2 | algo wm 3 | n_epochs 400 4 | size_hidden_layers [1000] 5 | batch_size 128 6 | learning_rate 0.05 7 | test_frequency 1 8 | save_dir ./experiments 9 | seed 1111 10 | -------------------------------------------------------------------------------- /experiments/mnist/bp-V0/experiment_config.txt: 
-------------------------------------------------------------------------------- 1 | dataset mnist 2 | algo bp 3 | n_epochs 400 4 | size_hidden_layers [500] 5 | batch_size 128 6 | learning_rate 0.2 7 | test_frequency 1 8 | save_dir ./experiments 9 | seed 1111 10 | -------------------------------------------------------------------------------- /experiments/mnist/bp-V0/log.txt: -------------------------------------------------------------------------------- 1 | 2 | ==================== 3 | Running the code with the dataset mnist: 4 | learning algorithm: bp 5 | batch_size: 128 6 | learning rate: 0.2 7 | n_epochs: 400 8 | test frequency: 1 9 | size_layers: [784, 500, 10] 10 | ==================== 11 | 12 | 13 | Epoch 0 completed in 34.401774644851685 Loss: 1.4687610010241647 14 | Test accuracy: 88.0 15 | 16 | Epoch 1 completed in 42.44343376159668 Loss: 0.8008829692414702 17 | Test accuracy: 89.69 18 | 19 | Epoch 2 completed in 46.4622061252594 Loss: 0.71097715949614 20 | Test accuracy: 90.58 21 | 22 | Epoch 3 completed in 43.53684639930725 Loss: 0.6571831625533974 23 | Test accuracy: 91.28 24 | 25 | Epoch 4 completed in 43.19096398353577 Loss: 0.61214466687775 26 | Test accuracy: 91.84 27 | 28 | Epoch 5 completed in 48.755091190338135 Loss: 0.5707083757889757 29 | Test accuracy: 92.39 30 | 31 | Epoch 6 completed in 48.887269020080566 Loss: 0.5324278237935055 32 | Test accuracy: 92.93 33 | 34 | Epoch 7 completed in 47.08506727218628 Loss: 0.4976917461804425 35 | Test accuracy: 93.38 36 | 37 | Epoch 8 completed in 45.6626923084259 Loss: 0.4665646288715298 38 | Test accuracy: 93.78 39 | 40 | Epoch 9 completed in 41.59304857254028 Loss: 0.4387485861138838 41 | Test accuracy: 94.12 42 | 43 | Epoch 10 completed in 42.766645431518555 Loss: 0.4138198848508743 44 | Test accuracy: 94.41 45 | 46 | Epoch 11 completed in 43.60832858085632 Loss: 0.3913751547424836 47 | Test accuracy: 94.69 48 | 49 | Epoch 12 completed in 42.493377685546875 Loss: 0.37107344931567104 50 | Test accuracy: 94.9 51 | 52 | Epoch 13 completed in 43.775253772735596 Loss: 0.352634303358013 53 | Test accuracy: 95.07 54 | 55 | Epoch 14 completed in 41.851927518844604 Loss: 0.33582625447662523 56 | Test accuracy: 95.24 57 | 58 | Epoch 15 completed in 44.99303150177002 Loss: 0.3204563490482775 59 | Test accuracy: 95.56 60 | 61 | Epoch 16 completed in 45.33008050918579 Loss: 0.30636193841408266 62 | Test accuracy: 95.75 63 | 64 | Epoch 17 completed in 47.678853273391724 Loss: 0.29340443528590415 65 | Test accuracy: 95.92 66 | 67 | Epoch 18 completed in 43.56520700454712 Loss: 0.28146447919195755 68 | Test accuracy: 96.09 69 | 70 | Epoch 19 completed in 39.24727821350098 Loss: 0.27043806828807976 71 | Test accuracy: 96.23 72 | 73 | Epoch 20 completed in 42.433265209198 Loss: 0.26023352940833017 74 | Test accuracy: 96.38 75 | 76 | Epoch 21 completed in 43.026845932006836 Loss: 0.2507693255475835 77 | Test accuracy: 96.44 78 | 79 | Epoch 22 completed in 43.65891170501709 Loss: 0.24197256023477692 80 | Test accuracy: 96.54 81 | 82 | Epoch 23 completed in 46.435009479522705 Loss: 0.23377792567585012 83 | Test accuracy: 96.61 84 | 85 | Epoch 24 completed in 42.11684465408325 Loss: 0.22612688878892268 86 | Test accuracy: 96.69 87 | 88 | Epoch 25 completed in 42.227872133255005 Loss: 0.2189670082246431 89 | Test accuracy: 96.76 90 | 91 | Epoch 26 completed in 41.16417193412781 Loss: 0.21225133871326055 92 | Test accuracy: 96.83 93 | 94 | Epoch 27 completed in 42.88162875175476 Loss: 0.20593790385708335 95 | Test accuracy: 96.92 96 | 97 | Epoch 
28 completed in 45.7510552406311 Loss: 0.19998922427882163 98 | Test accuracy: 97.0 99 | 100 | Epoch 29 completed in 42.59849286079407 Loss: 0.19437188821737075 101 | Test accuracy: 97.07 102 | 103 | Epoch 30 completed in 44.956228256225586 Loss: 0.1890561530772635 104 | Test accuracy: 97.13 105 | 106 | Epoch 31 completed in 44.078301668167114 Loss: 0.18401557061531254 107 | Test accuracy: 97.15 108 | 109 | Epoch 32 completed in 42.706080198287964 Loss: 0.17922663342023648 110 | Test accuracy: 97.17 111 | 112 | Epoch 33 completed in 44.89514946937561 Loss: 0.17466844383369412 113 | Test accuracy: 97.23 114 | 115 | Epoch 34 completed in 45.078850984573364 Loss: 0.17032240776419774 116 | Test accuracy: 97.27 117 | 118 | Epoch 35 completed in 42.33527612686157 Loss: 0.16617195555268427 119 | Test accuracy: 97.31 120 | 121 | Epoch 36 completed in 45.06669545173645 Loss: 0.1622022910436455 122 | Test accuracy: 97.32 123 | 124 | Epoch 37 completed in 43.631144762039185 Loss: 0.15840016887918465 125 | Test accuracy: 97.33 126 | 127 | Epoch 38 completed in 44.4689621925354 Loss: 0.15475369901519803 128 | Test accuracy: 97.41 129 | 130 | Epoch 39 completed in 41.94210696220398 Loss: 0.15125217665905905 131 | Test accuracy: 97.43 132 | 133 | Epoch 40 completed in 45.86351156234741 Loss: 0.1478859353019294 134 | Test accuracy: 97.45 135 | 136 | Epoch 41 completed in 45.87451386451721 Loss: 0.14464622027213242 137 | Test accuracy: 97.49 138 | 139 | Epoch 42 completed in 44.647937297821045 Loss: 0.14152508020964835 140 | Test accuracy: 97.53 141 | 142 | Epoch 43 completed in 43.478699684143066 Loss: 0.13851527396543642 143 | Test accuracy: 97.56 144 | 145 | Epoch 44 completed in 44.84569764137268 Loss: 0.13561019059801496 146 | Test accuracy: 97.6 147 | 148 | Epoch 45 completed in 45.632877349853516 Loss: 0.13280378035592613 149 | Test accuracy: 97.63 150 | 151 | Epoch 46 completed in 45.20951867103577 Loss: 0.1300904948027863 152 | Test accuracy: 97.64 153 | 154 | Epoch 47 completed in 41.34384512901306 Loss: 0.12746523455119663 155 | Test accuracy: 97.65 156 | 157 | Epoch 48 completed in 40.26403260231018 Loss: 0.12492330338682762 158 | Test accuracy: 97.67 159 | 160 | Epoch 49 completed in 40.27131462097168 Loss: 0.12246036784095315 161 | Test accuracy: 97.7 162 | 163 | Epoch 50 completed in 42.649420499801636 Loss: 0.1200724214812643 164 | Test accuracy: 97.71 165 | 166 | Epoch 51 completed in 41.81591296195984 Loss: 0.11775575333484753 167 | Test accuracy: 97.77 168 | 169 | Epoch 52 completed in 45.01831650733948 Loss: 0.11550691994956298 170 | Test accuracy: 97.79 171 | 172 | Epoch 53 completed in 46.03785300254822 Loss: 0.11332272066115481 173 | Test accuracy: 97.79 174 | 175 | Epoch 54 completed in 47.699058055877686 Loss: 0.11120017567902507 176 | Test accuracy: 97.82 177 | 178 | Epoch 55 completed in 43.018518924713135 Loss: 0.10913650664215353 179 | Test accuracy: 97.83 180 | 181 | Epoch 56 completed in 42.7496497631073 Loss: 0.10712911933115235 182 | Test accuracy: 97.85 183 | 184 | Epoch 57 completed in 41.346112966537476 Loss: 0.10517558825337221 185 | Test accuracy: 97.86 186 | 187 | Epoch 58 completed in 43.606868505477905 Loss: 0.10327364284532946 188 | Test accuracy: 97.89 189 | 190 | Epoch 59 completed in 44.10319423675537 Loss: 0.10142115506097324 191 | Test accuracy: 97.88 192 | 193 | Epoch 60 completed in 43.02025818824768 Loss: 0.09961612813671882 194 | Test accuracy: 97.87 195 | 196 | Epoch 61 completed in 44.12937331199646 Loss: 0.09785668634693014 197 | Test accuracy: 97.88 
198 | 199 | Epoch 62 completed in 43.7455108165741 Loss: 0.09614106558970646 200 | Test accuracy: 97.89 201 | 202 | Epoch 63 completed in 44.06430125236511 Loss: 0.09446760467570174 203 | Test accuracy: 97.89 204 | 205 | Epoch 64 completed in 44.91193437576294 Loss: 0.09283473723412336 206 | Test accuracy: 97.9 207 | 208 | Epoch 65 completed in 41.72046613693237 Loss: 0.09124098419795468 209 | Test accuracy: 97.92 210 | 211 | Epoch 66 completed in 43.22236204147339 Loss: 0.08968494687693031 212 | Test accuracy: 97.94 213 | 214 | Epoch 67 completed in 44.62178921699524 Loss: 0.08816530065904996 215 | Test accuracy: 97.93 216 | 217 | Epoch 68 completed in 42.9177508354187 Loss: 0.08668078938681398 218 | Test accuracy: 97.94 219 | 220 | Epoch 69 completed in 44.84669232368469 Loss: 0.08523022042812725 221 | Test accuracy: 97.96 222 | 223 | Epoch 70 completed in 39.808613300323486 Loss: 0.08381246041213777 224 | Test accuracy: 97.97 225 | 226 | Epoch 71 completed in 39.131019830703735 Loss: 0.08242643154607313 227 | Test accuracy: 97.96 228 | 229 | Epoch 72 completed in 39.92457628250122 Loss: 0.08107110839083552 230 | Test accuracy: 97.96 231 | 232 | Epoch 73 completed in 43.880839824676514 Loss: 0.07974551496260994 233 | Test accuracy: 97.98 234 | 235 | Epoch 74 completed in 40.40015649795532 Loss: 0.0784487220442137 236 | Test accuracy: 97.99 237 | 238 | Epoch 75 completed in 43.977158308029175 Loss: 0.0771798446231022 239 | Test accuracy: 98.0 240 | 241 | Epoch 76 completed in 46.83385515213013 Loss: 0.07593803941002635 242 | Test accuracy: 98.01 243 | 244 | Epoch 77 completed in 44.692113637924194 Loss: 0.07472250242358494 245 | Test accuracy: 98.02 246 | 247 | Epoch 78 completed in 44.50692963600159 Loss: 0.07353246664681481 248 | Test accuracy: 98.01 249 | 250 | Epoch 79 completed in 43.4429292678833 Loss: 0.07236719977237531 251 | Test accuracy: 98.03 252 | 253 | Epoch 80 completed in 42.44793081283569 Loss: 0.07122600205509873 254 | Test accuracy: 98.07 255 | 256 | Epoch 81 completed in 38.065776348114014 Loss: 0.07010820428764654 257 | Test accuracy: 98.07 258 | 259 | Epoch 82 completed in 47.55217099189758 Loss: 0.06901316590933923 260 | Test accuracy: 98.07 261 | 262 | Epoch 83 completed in 43.54308032989502 Loss: 0.06794027325181365 263 | Test accuracy: 98.08 264 | 265 | Epoch 84 completed in 42.295071601867676 Loss: 0.06688893791920361 266 | Test accuracy: 98.11 267 | 268 | Epoch 85 completed in 43.315223693847656 Loss: 0.06585859529570981 269 | Test accuracy: 98.1 270 | 271 | Epoch 86 completed in 48.42934226989746 Loss: 0.06484870317001602 272 | Test accuracy: 98.12 273 | 274 | Epoch 87 completed in 45.34317064285278 Loss: 0.0638587404640846 275 | Test accuracy: 98.14 276 | 277 | Epoch 88 completed in 40.77006411552429 Loss: 0.06288820605332689 278 | Test accuracy: 98.17 279 | 280 | Epoch 89 completed in 40.89706516265869 Loss: 0.06193661766579366 281 | Test accuracy: 98.18 282 | 283 | Epoch 90 completed in 43.480767488479614 Loss: 0.061003510849600456 284 | Test accuracy: 98.18 285 | 286 | Epoch 91 completed in 45.18601632118225 Loss: 0.060088437999979975 287 | Test accuracy: 98.18 288 | 289 | Epoch 92 completed in 42.04465961456299 Loss: 0.05919096743981311 290 | Test accuracy: 98.18 291 | 292 | Epoch 93 completed in 36.57961106300354 Loss: 0.05831068254990535 293 | Test accuracy: 98.18 294 | 295 | Epoch 94 completed in 40.32199573516846 Loss: 0.05744718094735896 296 | Test accuracy: 98.19 297 | 298 | Epoch 95 completed in 38.40521502494812 Loss: 0.05660007371191274 299 | 
Test accuracy: 98.2 300 | 301 | Epoch 96 completed in 38.98659801483154 Loss: 0.05576898466094039 302 | Test accuracy: 98.21 303 | 304 | Epoch 97 completed in 38.54790496826172 Loss: 0.05495354967388407 305 | Test accuracy: 98.25 306 | 307 | Epoch 98 completed in 38.008573055267334 Loss: 0.05415341606633391 308 | Test accuracy: 98.24 309 | 310 | Epoch 99 completed in 38.410298347473145 Loss: 0.053368242012929445 311 | Test accuracy: 98.25 312 | 313 | Epoch 100 completed in 40.622010707855225 Loss: 0.052597696017000574 314 | Test accuracy: 98.25 315 | 316 | Epoch 101 completed in 37.76919960975647 Loss: 0.0518414564236445 317 | Test accuracy: 98.25 318 | 319 | Epoch 102 completed in 41.25347828865051 Loss: 0.05109921097197344 320 | Test accuracy: 98.26 321 | 322 | Epoch 103 completed in 39.52380061149597 Loss: 0.05037065638171411 323 | Test accuracy: 98.29 324 | 325 | Epoch 104 completed in 37.859989404678345 Loss: 0.04965549796924587 326 | Test accuracy: 98.3 327 | 328 | Epoch 105 completed in 38.29388427734375 Loss: 0.048953449288490755 329 | Test accuracy: 98.3 330 | 331 | Epoch 106 completed in 38.67134070396423 Loss: 0.04826423179271491 332 | Test accuracy: 98.31 333 | 334 | Epoch 107 completed in 38.35057187080383 Loss: 0.04758757451412534 335 | Test accuracy: 98.32 336 | 337 | Epoch 108 completed in 38.15062761306763 Loss: 0.0469232137590131 338 | Test accuracy: 98.32 339 | 340 | Epoch 109 completed in 34.93558883666992 Loss: 0.04627089281698532 341 | Test accuracy: 98.32 342 | 343 | Epoch 110 completed in 36.85185885429382 Loss: 0.04563036168346673 344 | Test accuracy: 98.32 345 | 346 | Epoch 111 completed in 42.018710136413574 Loss: 0.04500137679509542 347 | Test accuracy: 98.32 348 | 349 | Epoch 112 completed in 41.133124589920044 Loss: 0.04438370077788588 350 | Test accuracy: 98.32 351 | 352 | Epoch 113 completed in 38.35520696640015 Loss: 0.04377710220810704 353 | Test accuracy: 98.32 354 | 355 | Epoch 114 completed in 38.44828176498413 Loss: 0.04318135538576708 356 | Test accuracy: 98.33 357 | 358 | Epoch 115 completed in 38.29985761642456 Loss: 0.042596240120459115 359 | Test accuracy: 98.34 360 | 361 | Epoch 116 completed in 39.20027017593384 Loss: 0.042021541529152676 362 | Test accuracy: 98.34 363 | 364 | Epoch 117 completed in 38.943798780441284 Loss: 0.041457049845361575 365 | Test accuracy: 98.34 366 | 367 | Epoch 118 completed in 40.754639863967896 Loss: 0.04090256023901066 368 | Test accuracy: 98.34 369 | 370 | Epoch 119 completed in 38.25162625312805 Loss: 0.040357872646286486 371 | Test accuracy: 98.34 372 | 373 | Epoch 120 completed in 38.14370560646057 Loss: 0.03982279160879374 374 | Test accuracy: 98.34 375 | 376 | Epoch 121 completed in 39.97275447845459 Loss: 0.03929712612144466 377 | Test accuracy: 98.33 378 | 379 | Epoch 122 completed in 38.66237926483154 Loss: 0.03878068948866027 380 | Test accuracy: 98.34 381 | 382 | Epoch 123 completed in 42.6996328830719 Loss: 0.0382732991886336 383 | Test accuracy: 98.35 384 | 385 | Epoch 124 completed in 40.8653039932251 Loss: 0.03777477674556115 386 | Test accuracy: 98.35 387 | 388 | Epoch 125 completed in 39.51929688453674 Loss: 0.037284947609856886 389 | Test accuracy: 98.37 390 | 391 | Epoch 126 completed in 40.530123472213745 Loss: 0.03680364104639507 392 | Test accuracy: 98.36 393 | 394 | Epoch 127 completed in 42.03773045539856 Loss: 0.036330690030763835 395 | Test accuracy: 98.36 396 | 397 | Epoch 128 completed in 37.38969826698303 Loss: 0.03586593115334168 398 | Test accuracy: 98.36 399 | 400 | Epoch 129 
completed in 37.537506341934204 Loss: 0.03540920453073885 401 | Test accuracy: 98.36 402 | 403 | Epoch 130 completed in 38.872467279434204 Loss: 0.03496035372379298 404 | Test accuracy: 98.36 405 | 406 | Epoch 131 completed in 39.21584343910217 Loss: 0.03451922566090628 407 | Test accuracy: 98.35 408 | 409 | Epoch 132 completed in 37.16961598396301 Loss: 0.03408567056510516 410 | Test accuracy: 98.36 411 | 412 | Epoch 133 completed in 38.095141649246216 Loss: 0.03365954188284936 413 | Test accuracy: 98.37 414 | 415 | Epoch 134 completed in 40.784069299697876 Loss: 0.03324069621237883 416 | Test accuracy: 98.37 417 | 418 | Epoch 135 completed in 37.93044400215149 Loss: 0.03282899322932803 419 | Test accuracy: 98.37 420 | 421 | Epoch 136 completed in 37.57743048667908 Loss: 0.03242429560751624 422 | Test accuracy: 98.37 423 | 424 | Epoch 137 completed in 35.87708234786987 Loss: 0.03202646893327853 425 | Test accuracy: 98.37 426 | 427 | Epoch 138 completed in 35.91650056838989 Loss: 0.03163538161244476 428 | Test accuracy: 98.37 429 | 430 | Epoch 139 completed in 37.48172330856323 Loss: 0.03125090477006892 431 | Test accuracy: 98.36 432 | 433 | Epoch 140 completed in 38.47050356864929 Loss: 0.030872912144172045 434 | Test accuracy: 98.36 435 | 436 | Epoch 141 completed in 38.68788504600525 Loss: 0.030501279975947888 437 | Test accuracy: 98.37 438 | 439 | Epoch 142 completed in 38.33132529258728 Loss: 0.030135886899912725 440 | Test accuracy: 98.37 441 | 442 | Epoch 143 completed in 37.56525707244873 Loss: 0.029776613838166387 443 | Test accuracy: 98.37 444 | 445 | Epoch 144 completed in 38.80171775817871 Loss: 0.029423343903110667 446 | Test accuracy: 98.37 447 | 448 | Epoch 145 completed in 38.162795543670654 Loss: 0.029075962312556936 449 | Test accuracy: 98.36 450 | 451 | Epoch 146 completed in 39.12995910644531 Loss: 0.028734356320169054 452 | Test accuracy: 98.36 453 | 454 | Epoch 147 completed in 39.16180753707886 Loss: 0.02839841516276882 455 | Test accuracy: 98.37 456 | 457 | Epoch 148 completed in 37.792460680007935 Loss: 0.02806803002440847 458 | Test accuracy: 98.38 459 | 460 | Epoch 149 completed in 37.7056143283844 Loss: 0.027743094015557242 461 | Test accuracy: 98.37 462 | 463 | Epoch 150 completed in 38.04186153411865 Loss: 0.02742350216450171 464 | Test accuracy: 98.37 465 | 466 | Epoch 151 completed in 38.9227991104126 Loss: 0.027109151417288296 467 | Test accuracy: 98.38 468 | 469 | Epoch 152 completed in 38.092323541641235 Loss: 0.026799940642299125 470 | Test accuracy: 98.37 471 | 472 | Epoch 153 completed in 36.00245022773743 Loss: 0.026495770635803335 473 | Test accuracy: 98.37 474 | 475 | Epoch 154 completed in 39.201128005981445 Loss: 0.026196544125441052 476 | Test accuracy: 98.38 477 | 478 | Epoch 155 completed in 37.82910895347595 Loss: 0.025902165769420483 479 | Test accuracy: 98.38 480 | 481 | Epoch 156 completed in 38.2085542678833 Loss: 0.025612542150084086 482 | Test accuracy: 98.37 483 | 484 | Epoch 157 completed in 36.839391231536865 Loss: 0.025327581761307807 485 | Test accuracy: 98.37 486 | 487 | Epoch 158 completed in 35.832582235336304 Loss: 0.025047194989857628 488 | Test accuracy: 98.37 489 | 490 | Epoch 159 completed in 36.570793867111206 Loss: 0.02477129409130728 491 | Test accuracy: 98.37 492 | 493 | Epoch 160 completed in 34.65072989463806 Loss: 0.024499793161418176 494 | Test accuracy: 98.36 495 | 496 | Epoch 161 completed in 35.62934684753418 Loss: 0.02423260810402081 497 | Test accuracy: 98.36 498 | 499 | Epoch 162 completed in 37.00571632385254 
Loss: 0.023969656596448747 500 | Test accuracy: 98.36 501 | 502 | Epoch 163 completed in 38.48982381820679 Loss: 0.023710858053498825 503 | Test accuracy: 98.36 504 | 505 | Epoch 164 completed in 38.71276593208313 Loss: 0.023456133590757496 506 | Test accuracy: 98.36 507 | 508 | Epoch 165 completed in 39.43828821182251 Loss: 0.023205405987970947 509 | Test accuracy: 98.36 510 | 511 | Epoch 166 completed in 37.94929814338684 Loss: 0.022958599652967687 512 | Test accuracy: 98.37 513 | 514 | Epoch 167 completed in 38.52124214172363 Loss: 0.022715640586481085 515 | Test accuracy: 98.37 516 | 517 | Epoch 168 completed in 37.57990074157715 Loss: 0.022476456348076178 518 | Test accuracy: 98.37 519 | 520 | Epoch 169 completed in 41.12182927131653 Loss: 0.022240976023265264 521 | Test accuracy: 98.37 522 | 523 | Epoch 170 completed in 39.86361241340637 Loss: 0.022009130191802062 524 | Test accuracy: 98.38 525 | 526 | Epoch 171 completed in 38.91173791885376 Loss: 0.021780850897075064 527 | Test accuracy: 98.38 528 | 529 | Epoch 172 completed in 41.13706707954407 Loss: 0.02155607161647382 530 | Test accuracy: 98.39 531 | 532 | Epoch 173 completed in 41.567359924316406 Loss: 0.021334727232575817 533 | Test accuracy: 98.38 534 | 535 | Epoch 174 completed in 40.16179585456848 Loss: 0.02111675400499156 536 | Test accuracy: 98.39 537 | 538 | Epoch 175 completed in 37.62286639213562 Loss: 0.020902089542708044 539 | Test accuracy: 98.39 540 | 541 | Epoch 176 completed in 40.14611887931824 Loss: 0.020690672776783447 542 | Test accuracy: 98.39 543 | 544 | Epoch 177 completed in 39.70938444137573 Loss: 0.020482443933264195 545 | Test accuracy: 98.39 546 | 547 | Epoch 178 completed in 40.40145301818848 Loss: 0.02027734450621791 548 | Test accuracy: 98.39 549 | 550 | Epoch 179 completed in 37.68731617927551 Loss: 0.020075317230800215 551 | Test accuracy: 98.4 552 | 553 | Epoch 180 completed in 39.96287560462952 Loss: 0.019876306056297654 554 | Test accuracy: 98.4 555 | 556 | Epoch 181 completed in 40.874922037124634 Loss: 0.019680256119113136 557 | Test accuracy: 98.4 558 | 559 | Epoch 182 completed in 36.3519971370697 Loss: 0.019487113715682736 560 | Test accuracy: 98.4 561 | 562 | Epoch 183 completed in 37.57686233520508 Loss: 0.01929682627533389 563 | Test accuracy: 98.41 564 | 565 | Epoch 184 completed in 37.228286027908325 Loss: 0.019109342333113635 566 | Test accuracy: 98.41 567 | 568 | Epoch 185 completed in 37.07357144355774 Loss: 0.018924611502632898 569 | Test accuracy: 98.41 570 | 571 | Epoch 186 completed in 36.12888264656067 Loss: 0.018742584448987472 572 | Test accuracy: 98.41 573 | 574 | Epoch 187 completed in 36.18499684333801 Loss: 0.01856321286182905 575 | Test accuracy: 98.41 576 | 577 | Epoch 188 completed in 39.15977954864502 Loss: 0.018386449428669998 578 | Test accuracy: 98.41 579 | 580 | Epoch 189 completed in 40.07836675643921 Loss: 0.01821224780851335 581 | Test accuracy: 98.41 582 | 583 | Epoch 190 completed in 38.0655779838562 Loss: 0.018040562605904788 584 | Test accuracy: 98.41 585 | 586 | Epoch 191 completed in 36.87221908569336 Loss: 0.01787134934550564 587 | Test accuracy: 98.41 588 | 589 | Epoch 192 completed in 37.52015161514282 Loss: 0.01770456444728589 590 | Test accuracy: 98.41 591 | 592 | Epoch 193 completed in 35.696027517318726 Loss: 0.01754016520243248 593 | Test accuracy: 98.42 594 | 595 | Epoch 194 completed in 37.00619125366211 Loss: 0.017378109750062216 596 | Test accuracy: 98.42 597 | 598 | Epoch 195 completed in 37.01072430610657 Loss: 0.01721835705482008 599 | 
Test accuracy: 98.42 600 | 601 | Epoch 196 completed in 36.84979009628296 Loss: 0.017060866885432274 602 | Test accuracy: 98.42 603 | 604 | Epoch 197 completed in 38.093836307525635 Loss: 0.016905599794270303 605 | Test accuracy: 98.42 606 | 607 | Epoch 198 completed in 38.29017353057861 Loss: 0.016752517097967993 608 | Test accuracy: 98.42 609 | 610 | Epoch 199 completed in 37.971465826034546 Loss: 0.01660158085911731 611 | Test accuracy: 98.42 612 | 613 | Epoch 200 completed in 38.46734571456909 Loss: 0.016452753869053087 614 | Test accuracy: 98.43 615 | 616 | Epoch 201 completed in 39.423102378845215 Loss: 0.016305999631720306 617 | Test accuracy: 98.43 618 | 619 | Epoch 202 completed in 40.34502911567688 Loss: 0.016161282348602313 620 | Test accuracy: 98.43 621 | 622 | Epoch 203 completed in 41.64002513885498 Loss: 0.01601856690467338 623 | Test accuracy: 98.43 624 | 625 | Epoch 204 completed in 37.87933659553528 Loss: 0.015877818855325753 626 | Test accuracy: 98.43 627 | 628 | Epoch 205 completed in 38.34079194068909 Loss: 0.015739004414209157 629 | Test accuracy: 98.43 630 | 631 | Epoch 206 completed in 37.53245830535889 Loss: 0.015602090441910408 632 | Test accuracy: 98.43 633 | 634 | Epoch 207 completed in 39.18299221992493 Loss: 0.01546704443539152 635 | Test accuracy: 98.43 636 | 637 | Epoch 208 completed in 38.79713320732117 Loss: 0.015333834518097701 638 | Test accuracy: 98.43 639 | 640 | Epoch 209 completed in 41.55415987968445 Loss: 0.015202429430640261 641 | Test accuracy: 98.43 642 | 643 | Epoch 210 completed in 38.01646304130554 Loss: 0.01507279852195489 644 | Test accuracy: 98.43 645 | 646 | Epoch 211 completed in 40.20992183685303 Loss: 0.014944911740831974 647 | Test accuracy: 98.42 648 | 649 | Epoch 212 completed in 39.891982078552246 Loss: 0.014818739627713176 650 | Test accuracy: 98.42 651 | 652 | Epoch 213 completed in 39.38285279273987 Loss: 0.014694253306646634 653 | Test accuracy: 98.42 654 | 655 | Epoch 214 completed in 38.1623432636261 Loss: 0.01457142447729244 656 | Test accuracy: 98.42 657 | 658 | Epoch 215 completed in 41.03002977371216 Loss: 0.014450225406870193 659 | Test accuracy: 98.42 660 | 661 | Epoch 216 completed in 39.43681311607361 Loss: 0.014330628921941837 662 | Test accuracy: 98.42 663 | 664 | Epoch 217 completed in 40.0456485748291 Loss: 0.014212608399925702 665 | Test accuracy: 98.42 666 | 667 | Epoch 218 completed in 38.642661333084106 Loss: 0.014096137760242152 668 | Test accuracy: 98.42 669 | 670 | Epoch 219 completed in 37.9597442150116 Loss: 0.013981191454997825 671 | Test accuracy: 98.42 672 | 673 | Epoch 220 completed in 37.68348431587219 Loss: 0.013867744459124682 674 | Test accuracy: 98.42 675 | 676 | Epoch 221 completed in 41.557700634002686 Loss: 0.013755772259902292 677 | Test accuracy: 98.42 678 | 679 | Epoch 222 completed in 40.49189853668213 Loss: 0.013645250845807429 680 | Test accuracy: 98.42 681 | 682 | Epoch 223 completed in 36.29889154434204 Loss: 0.01353615669465426 683 | Test accuracy: 98.42 684 | 685 | Epoch 224 completed in 39.29086232185364 Loss: 0.013428466761011648 686 | Test accuracy: 98.42 687 | 688 | Epoch 225 completed in 39.16224646568298 Loss: 0.013322158462910561 689 | Test accuracy: 98.42 690 | 691 | Epoch 226 completed in 38.90195870399475 Loss: 0.013217209667884403 692 | Test accuracy: 98.42 693 | 694 | Epoch 227 completed in 40.34247279167175 Loss: 0.013113598678416823 695 | Test accuracy: 98.42 696 | 697 | Epoch 228 completed in 37.476614475250244 Loss: 0.013011304216904288 698 | Test accuracy: 98.42 699 
| 700 | Epoch 229 completed in 38.3502995967865 Loss: 0.012910305410272683 701 | Test accuracy: 98.42 702 | 703 | Epoch 230 completed in 38.62120079994202 Loss: 0.012810581774416201 704 | Test accuracy: 98.42 705 | 706 | Epoch 231 completed in 37.11960554122925 Loss: 0.01271211319865128 707 | Test accuracy: 98.42 708 | 709 | Epoch 232 completed in 37.09584617614746 Loss: 0.012614879930395537 710 | Test accuracy: 98.43 711 | 712 | Epoch 233 completed in 37.89085006713867 Loss: 0.012518862560290382 713 | Test accuracy: 98.43 714 | 715 | Epoch 234 completed in 38.94679498672485 Loss: 0.01242404200798434 716 | Test accuracy: 98.42 717 | 718 | Epoch 235 completed in 37.97728872299194 Loss: 0.012330399508781672 719 | Test accuracy: 98.42 720 | 721 | Epoch 236 completed in 37.30201244354248 Loss: 0.012237916601337719 722 | Test accuracy: 98.42 723 | 724 | Epoch 237 completed in 36.18613076210022 Loss: 0.012146575116549187 725 | Test accuracy: 98.42 726 | 727 | Epoch 238 completed in 34.77349257469177 Loss: 0.012056357167746798 728 | Test accuracy: 98.42 729 | 730 | Epoch 239 completed in 38.07384252548218 Loss: 0.011967245142250962 731 | Test accuracy: 98.42 732 | 733 | Epoch 240 completed in 38.8392608165741 Loss: 0.0118792216943025 734 | Test accuracy: 98.42 735 | 736 | Epoch 241 completed in 38.080121755599976 Loss: 0.011792269739331979 737 | Test accuracy: 98.42 738 | 739 | Epoch 242 completed in 41.06998538970947 Loss: 0.011706372449486985 740 | Test accuracy: 98.42 741 | 742 | Epoch 243 completed in 37.90121054649353 Loss: 0.011621513250298574 743 | Test accuracy: 98.42 744 | 745 | Epoch 244 completed in 40.55748200416565 Loss: 0.011537675818338493 746 | Test accuracy: 98.42 747 | 748 | Epoch 245 completed in 38.615729093551636 Loss: 0.011454844079698673 749 | Test accuracy: 98.43 750 | 751 | Epoch 246 completed in 38.41405630111694 Loss: 0.01137300220911416 752 | Test accuracy: 98.43 753 | 754 | Epoch 247 completed in 43.15579152107239 Loss: 0.011292134629549906 755 | Test accuracy: 98.43 756 | 757 | Epoch 248 completed in 39.844804763793945 Loss: 0.011212226012079076 758 | Test accuracy: 98.43 759 | 760 | Epoch 249 completed in 41.66868042945862 Loss: 0.011133261275895078 761 | Test accuracy: 98.43 762 | 763 | Epoch 250 completed in 39.25139880180359 Loss: 0.01105522558831824 764 | Test accuracy: 98.43 765 | 766 | Epoch 251 completed in 40.12653374671936 Loss: 0.010978104364680944 767 | Test accuracy: 98.43 768 | 769 | Epoch 252 completed in 38.536898136138916 Loss: 0.010901883267998234 770 | Test accuracy: 98.43 771 | 772 | Epoch 253 completed in 37.84711241722107 Loss: 0.010826548208354875 773 | Test accuracy: 98.43 774 | 775 | Epoch 254 completed in 38.19525957107544 Loss: 0.010752085341961941 776 | Test accuracy: 98.43 777 | 778 | Epoch 255 completed in 38.08241081237793 Loss: 0.01067848106985619 779 | Test accuracy: 98.43 780 | 781 | Epoch 256 completed in 38.08494305610657 Loss: 0.010605722036232822 782 | Test accuracy: 98.43 783 | 784 | Epoch 257 completed in 41.261803150177 Loss: 0.010533795126416506 785 | Test accuracy: 98.43 786 | 787 | Epoch 258 completed in 36.61922097206116 Loss: 0.010462687464486966 788 | Test accuracy: 98.43 789 | 790 | Epoch 259 completed in 42.443081855773926 Loss: 0.010392386410583779 791 | Test accuracy: 98.43 792 | 793 | Epoch 260 completed in 39.36473250389099 Loss: 0.010322879557920814 794 | Test accuracy: 98.43 795 | 796 | Epoch 261 completed in 39.65820789337158 Loss: 0.010254154729544348 797 | Test accuracy: 98.43 798 | 799 | Epoch 262 completed 
in 40.143985986709595 Loss: 0.010186199974870564 800 | Test accuracy: 98.43 801 | 802 | Epoch 263 completed in 40.10165309906006 Loss: 0.010119003566038114 803 | Test accuracy: 98.43 804 | 805 | Epoch 264 completed in 36.175777435302734 Loss: 0.010052553994110655 806 | Test accuracy: 98.43 807 | 808 | Epoch 265 completed in 38.70072531700134 Loss: 0.009986839965162229 809 | Test accuracy: 98.43 810 | 811 | Epoch 266 completed in 39.38526797294617 Loss: 0.009921850396275952 812 | Test accuracy: 98.43 813 | 814 | Epoch 267 completed in 39.80870532989502 Loss: 0.00985757441148384 815 | Test accuracy: 98.43 816 | 817 | Epoch 268 completed in 39.376097440719604 Loss: 0.009794001337672423 818 | Test accuracy: 98.43 819 | 820 | Epoch 269 completed in 36.32887935638428 Loss: 0.009731120700476152 821 | Test accuracy: 98.43 822 | 823 | Epoch 270 completed in 37.98457884788513 Loss: 0.009668922220177381 824 | Test accuracy: 98.43 825 | 826 | Epoch 271 completed in 40.071518659591675 Loss: 0.009607395807629295 827 | Test accuracy: 98.43 828 | 829 | Epoch 272 completed in 38.5419442653656 Loss: 0.009546531560215336 830 | Test accuracy: 98.43 831 | 832 | Epoch 273 completed in 37.49844789505005 Loss: 0.009486319757856635 833 | Test accuracy: 98.43 834 | 835 | Epoch 274 completed in 40.92298412322998 Loss: 0.009426750859076592 836 | Test accuracy: 98.43 837 | 838 | Epoch 275 completed in 40.10505437850952 Loss: 0.00936781549713018 839 | Test accuracy: 98.43 840 | 841 | Epoch 276 completed in 35.82091665267944 Loss: 0.00930950447620372 842 | Test accuracy: 98.43 843 | 844 | Epoch 277 completed in 38.66067957878113 Loss: 0.009251808767689575 845 | Test accuracy: 98.43 846 | 847 | Epoch 278 completed in 37.22276949882507 Loss: 0.009194719506538966 848 | Test accuracy: 98.43 849 | 850 | Epoch 279 completed in 37.587263107299805 Loss: 0.009138227987695135 851 | Test accuracy: 98.43 852 | 853 | Epoch 280 completed in 40.89028716087341 Loss: 0.00908232566260811 854 | Test accuracy: 98.43 855 | 856 | Epoch 281 completed in 39.5791015625 Loss: 0.009027004135831743 857 | Test accuracy: 98.44 858 | 859 | Epoch 282 completed in 37.6649751663208 Loss: 0.00897225516170298 860 | Test accuracy: 98.44 861 | 862 | Epoch 283 completed in 38.598140478134155 Loss: 0.008918070641102846 863 | Test accuracy: 98.44 864 | 865 | Epoch 284 completed in 37.837477922439575 Loss: 0.00886444261829838 866 | Test accuracy: 98.44 867 | 868 | Epoch 285 completed in 39.3852756023407 Loss: 0.008811363277864186 869 | Test accuracy: 98.44 870 | 871 | Epoch 286 completed in 42.02154541015625 Loss: 0.008758824941682246 872 | Test accuracy: 98.44 873 | 874 | Epoch 287 completed in 39.425365924835205 Loss: 0.008706820066018367 875 | Test accuracy: 98.44 876 | 877 | Epoch 288 completed in 37.812427282333374 Loss: 0.008655341238673462 878 | Test accuracy: 98.44 879 | 880 | Epoch 289 completed in 38.50492572784424 Loss: 0.008604381176207829 881 | Test accuracy: 98.45 882 | 883 | Epoch 290 completed in 39.5103645324707 Loss: 0.008553932721236534 884 | Test accuracy: 98.45 885 | 886 | Epoch 291 completed in 36.48542308807373 Loss: 0.00850398883979386 887 | Test accuracy: 98.45 888 | 889 | Epoch 292 completed in 37.466047286987305 Loss: 0.008454542618764962 890 | Test accuracy: 98.45 891 | 892 | Epoch 293 completed in 37.69159126281738 Loss: 0.008405587263382653 893 | Test accuracy: 98.46 894 | 895 | Epoch 294 completed in 38.147550106048584 Loss: 0.008357116094787431 896 | Test accuracy: 98.46 897 | 898 | Epoch 295 completed in 38.70105028152466 Loss: 
0.008309122547648838 899 | Test accuracy: 98.46 900 | 901 | Epoch 296 completed in 38.11715006828308 Loss: 0.00826160016784624 902 | Test accuracy: 98.47 903 | 904 | Epoch 297 completed in 37.327470779418945 Loss: 0.008214542610207206 905 | Test accuracy: 98.47 906 | 907 | Epoch 298 completed in 39.13540697097778 Loss: 0.008167943636301713 908 | Test accuracy: 98.47 909 | 910 | Epoch 299 completed in 39.332314252853394 Loss: 0.008121797112290443 911 | Test accuracy: 98.47 912 | 913 | Epoch 300 completed in 38.52493762969971 Loss: 0.008076097006825487 914 | Test accuracy: 98.46 915 | 916 | Epoch 301 completed in 37.25596213340759 Loss: 0.008030837389001856 917 | Test accuracy: 98.46 918 | 919 | Epoch 302 completed in 40.353206634521484 Loss: 0.007986012426358238 920 | Test accuracy: 98.46 921 | 922 | Epoch 303 completed in 37.875197410583496 Loss: 0.007941616382925471 923 | Test accuracy: 98.46 924 | 925 | Epoch 304 completed in 36.69093060493469 Loss: 0.007897643617321346 926 | Test accuracy: 98.46 927 | 928 | Epoch 305 completed in 40.98635387420654 Loss: 0.007854088580890248 929 | Test accuracy: 98.46 930 | 931 | Epoch 306 completed in 40.885191440582275 Loss: 0.007810945815886421 932 | Test accuracy: 98.47 933 | 934 | Epoch 307 completed in 39.75351405143738 Loss: 0.007768209953699453 935 | Test accuracy: 98.47 936 | 937 | Epoch 308 completed in 38.90772724151611 Loss: 0.007725875713120819 938 | Test accuracy: 98.47 939 | 940 | Epoch 309 completed in 40.72539258003235 Loss: 0.007683937898650273 941 | Test accuracy: 98.47 942 | 943 | Epoch 310 completed in 35.57447791099548 Loss: 0.0076423913988409135 944 | Test accuracy: 98.47 945 | 946 | Epoch 311 completed in 34.943453311920166 Loss: 0.007601231184681882 947 | Test accuracy: 98.47 948 | 949 | Epoch 312 completed in 36.1629536151886 Loss: 0.007560452308017581 950 | Test accuracy: 98.47 951 | 952 | Epoch 313 completed in 38.005001068115234 Loss: 0.007520049900002426 953 | Test accuracy: 98.47 954 | 955 | Epoch 314 completed in 37.78376388549805 Loss: 0.007480019169590153 956 | Test accuracy: 98.47 957 | 958 | Epoch 315 completed in 37.699337005615234 Loss: 0.00744035540205672 959 | Test accuracy: 98.47 960 | 961 | Epoch 316 completed in 37.912625789642334 Loss: 0.007401053957555903 962 | Test accuracy: 98.47 963 | 964 | Epoch 317 completed in 38.2661018371582 Loss: 0.007362110269706727 965 | Test accuracy: 98.48 966 | 967 | Epoch 318 completed in 39.23182225227356 Loss: 0.007323519844211861 968 | Test accuracy: 98.48 969 | 970 | Epoch 319 completed in 36.286125898361206 Loss: 0.007285278257506193 971 | Test accuracy: 98.48 972 | 973 | Epoch 320 completed in 41.00956320762634 Loss: 0.007247381155434773 974 | Test accuracy: 98.49 975 | 976 | Epoch 321 completed in 38.71486163139343 Loss: 0.0072098242519593805 977 | Test accuracy: 98.49 978 | 979 | Epoch 322 completed in 37.566535234451294 Loss: 0.007172603327892986 980 | Test accuracy: 98.49 981 | 982 | Epoch 323 completed in 35.897369384765625 Loss: 0.007135714229661387 983 | Test accuracy: 98.49 984 | 985 | Epoch 324 completed in 34.60488200187683 Loss: 0.007099152868091355 986 | Test accuracy: 98.49 987 | 988 | Epoch 325 completed in 35.93748235702515 Loss: 0.00706291521722459 989 | Test accuracy: 98.49 990 | 991 | Epoch 326 completed in 36.58331632614136 Loss: 0.007026997313156908 992 | Test accuracy: 98.49 993 | 994 | Epoch 327 completed in 39.135740518569946 Loss: 0.006991395252901946 995 | Test accuracy: 98.49 996 | 997 | Epoch 328 completed in 41.494251012802124 Loss: 
0.0069561051932789005 998 | Test accuracy: 98.49 999 | 1000 | Epoch 329 completed in 36.47455954551697 Loss: 0.006921123349823596 1001 | Test accuracy: 98.49 1002 | 1003 | Epoch 330 completed in 36.72316908836365 Loss: 0.0068864459957224166 1004 | Test accuracy: 98.49 1005 | 1006 | Epoch 331 completed in 40.42720174789429 Loss: 0.0068520694607684795 1007 | Test accuracy: 98.49 1008 | 1009 | Epoch 332 completed in 38.699204444885254 Loss: 0.006817990130339551 1010 | Test accuracy: 98.49 1011 | 1012 | Epoch 333 completed in 40.02722907066345 Loss: 0.006784204444397226 1013 | Test accuracy: 98.49 1014 | 1015 | Epoch 334 completed in 39.33175802230835 Loss: 0.006750708896506758 1016 | Test accuracy: 98.49 1017 | 1018 | Epoch 335 completed in 39.22186803817749 Loss: 0.006717500032877164 1019 | Test accuracy: 98.49 1020 | 1021 | Epoch 336 completed in 38.896955251693726 Loss: 0.006684574451421088 1022 | Test accuracy: 98.49 1023 | 1024 | Epoch 337 completed in 37.8179407119751 Loss: 0.006651928800833937 1025 | Test accuracy: 98.49 1026 | 1027 | Epoch 338 completed in 36.938556432724 Loss: 0.006619559779691875 1028 | Test accuracy: 98.49 1029 | 1030 | Epoch 339 completed in 38.26715445518494 Loss: 0.006587464135568227 1031 | Test accuracy: 98.49 1032 | 1033 | Epoch 340 completed in 38.82295560836792 Loss: 0.006555638664167865 1034 | Test accuracy: 98.49 1035 | 1036 | Epoch 341 completed in 36.705681800842285 Loss: 0.0065240802084791695 1037 | Test accuracy: 98.49 1038 | 1039 | Epoch 342 completed in 37.834657430648804 Loss: 0.006492785657943157 1040 | Test accuracy: 98.49 1041 | 1042 | Epoch 343 completed in 36.23966431617737 Loss: 0.006461751947639371 1043 | Test accuracy: 98.49 1044 | 1045 | Epoch 344 completed in 38.408918619155884 Loss: 0.006430976057488203 1046 | Test accuracy: 98.49 1047 | 1048 | Epoch 345 completed in 39.0126531124115 Loss: 0.006400455011469207 1049 | Test accuracy: 98.49 1050 | 1051 | Epoch 346 completed in 37.629403591156006 Loss: 0.006370185876855075 1052 | Test accuracy: 98.49 1053 | 1054 | Epoch 347 completed in 38.74530506134033 Loss: 0.006340165763460951 1055 | Test accuracy: 98.49 1056 | 1057 | Epoch 348 completed in 37.27450776100159 Loss: 0.006310391822908672 1058 | Test accuracy: 98.49 1059 | 1060 | Epoch 349 completed in 38.51545071601868 Loss: 0.0062808612479056695 1061 | Test accuracy: 98.49 1062 | 1063 | Epoch 350 completed in 40.301727056503296 Loss: 0.006251571271538151 1064 | Test accuracy: 98.49 1065 | 1066 | Epoch 351 completed in 37.786476612091064 Loss: 0.006222519166578285 1067 | Test accuracy: 98.49 1068 | 1069 | Epoch 352 completed in 37.13760185241699 Loss: 0.006193702244805039 1070 | Test accuracy: 98.49 1071 | 1072 | Epoch 353 completed in 38.533994913101196 Loss: 0.006165117856338406 1073 | Test accuracy: 98.49 1074 | 1075 | Epoch 354 completed in 37.92570662498474 Loss: 0.006136763388986703 1076 | Test accuracy: 98.49 1077 | 1078 | Epoch 355 completed in 37.7488112449646 Loss: 0.0061086362676066415 1079 | Test accuracy: 98.49 1080 | 1081 | Epoch 356 completed in 36.991764307022095 Loss: 0.006080733953475927 1082 | Test accuracy: 98.49 1083 | 1084 | Epoch 357 completed in 34.677921533584595 Loss: 0.006053053943678066 1085 | Test accuracy: 98.49 1086 | 1087 | Epoch 358 completed in 37.77248978614807 Loss: 0.00602559377049916 1088 | Test accuracy: 98.49 1089 | 1090 | Epoch 359 completed in 34.8337676525116 Loss: 0.005998351000836383 1091 | Test accuracy: 98.49 1092 | 1093 | Epoch 360 completed in 37.62636756896973 Loss: 0.0059713232356179155 1094 
| Test accuracy: 98.49 1095 | 1096 | Epoch 361 completed in 37.180826902389526 Loss: 0.00594450810923408 1097 | Test accuracy: 98.49 1098 | 1099 | Epoch 362 completed in 39.929572343826294 Loss: 0.005917903288979428 1100 | Test accuracy: 98.49 1101 | 1102 | Epoch 363 completed in 39.2830605506897 Loss: 0.005891506474505546 1103 | Test accuracy: 98.49 1104 | 1105 | Epoch 364 completed in 39.15829420089722 Loss: 0.005865315397284356 1106 | Test accuracy: 98.49 1107 | 1108 | Epoch 365 completed in 38.79205369949341 Loss: 0.0058393278200816676 1109 | Test accuracy: 98.49 1110 | 1111 | Epoch 366 completed in 35.68124270439148 Loss: 0.005813541536440797 1112 | Test accuracy: 98.49 1113 | 1114 | Epoch 367 completed in 37.78404664993286 Loss: 0.005787954370175967 1115 | Test accuracy: 98.49 1116 | 1117 | Epoch 368 completed in 37.40528082847595 Loss: 0.005762564174875373 1118 | Test accuracy: 98.49 1119 | 1120 | Epoch 369 completed in 36.80554008483887 Loss: 0.005737368833413633 1121 | Test accuracy: 98.49 1122 | 1123 | Epoch 370 completed in 37.62854886054993 Loss: 0.0057123662574734575 1124 | Test accuracy: 98.49 1125 | 1126 | Epoch 371 completed in 36.3814480304718 Loss: 0.005687554387076335 1127 | Test accuracy: 98.49 1128 | 1129 | Epoch 372 completed in 38.61729192733765 Loss: 0.00566293119012203 1130 | Test accuracy: 98.49 1131 | 1132 | Epoch 373 completed in 40.23972964286804 Loss: 0.0056384946619367545 1133 | Test accuracy: 98.49 1134 | 1135 | Epoch 374 completed in 38.78841686248779 Loss: 0.005614242824829737 1136 | Test accuracy: 98.49 1137 | 1138 | Epoch 375 completed in 34.980125427246094 Loss: 0.005590173727658136 1139 | Test accuracy: 98.49 1140 | 1141 | Epoch 376 completed in 38.544368505477905 Loss: 0.0055662854454000075 1142 | Test accuracy: 98.49 1143 | 1144 | Epoch 377 completed in 42.62496829032898 Loss: 0.005542576078735229 1145 | Test accuracy: 98.49 1146 | 1147 | Epoch 378 completed in 43.57392716407776 Loss: 0.005519043753634189 1148 | Test accuracy: 98.5 1149 | 1150 | Epoch 379 completed in 37.1950261592865 Loss: 0.0054956866209540794 1151 | Test accuracy: 98.5 1152 | 1153 | Epoch 380 completed in 37.738606452941895 Loss: 0.005472502856042646 1154 | Test accuracy: 98.5 1155 | 1156 | Epoch 381 completed in 39.763909339904785 Loss: 0.005449490658349217 1157 | Test accuracy: 98.5 1158 | 1159 | Epoch 382 completed in 37.54019641876221 Loss: 0.005426648251042878 1160 | Test accuracy: 98.5 1161 | 1162 | Epoch 383 completed in 37.19136333465576 Loss: 0.005403973880637659 1163 | Test accuracy: 98.5 1164 | 1165 | Epoch 384 completed in 35.86076307296753 Loss: 0.005381465816624554 1166 | Test accuracy: 98.5 1167 | 1168 | Epoch 385 completed in 35.96437478065491 Loss: 0.005359122351110235 1169 | Test accuracy: 98.5 1170 | 1171 | Epoch 386 completed in 39.54675030708313 Loss: 0.005336941798462385 1172 | Test accuracy: 98.5 1173 | 1174 | Epoch 387 completed in 38.74086332321167 Loss: 0.005314922494961414 1175 | Test accuracy: 98.5 1176 | 1177 | Epoch 388 completed in 36.47425556182861 Loss: 0.0052930627984585235 1178 | Test accuracy: 98.5 1179 | 1180 | Epoch 389 completed in 37.26449990272522 Loss: 0.005271361088039914 1181 | Test accuracy: 98.5 1182 | 1183 | Epoch 390 completed in 36.476826667785645 Loss: 0.005249815763697066 1184 | Test accuracy: 98.5 1185 | 1186 | Epoch 391 completed in 37.59798741340637 Loss: 0.005228425246002934 1187 | Test accuracy: 98.5 1188 | 1189 | Epoch 392 completed in 37.797717571258545 Loss: 0.0052071879757939555 1190 | Test accuracy: 98.49 1191 | 1192 | 
Epoch 393 completed in 38.63732028007507 Loss: 0.00518610241385775 1193 | Test accuracy: 98.49 1194 | 1195 | Epoch 394 completed in 35.482884883880615 Loss: 0.005165167040626389 1196 | Test accuracy: 98.49 1197 | 1198 | Epoch 395 completed in 35.0946683883667 Loss: 0.005144380355875132 1199 | Test accuracy: 98.5 1200 | 1201 | Epoch 396 completed in 36.59082531929016 Loss: 0.005123740878426509 1202 | Test accuracy: 98.5 1203 | 1204 | Epoch 397 completed in 38.58550405502319 Loss: 0.005103247145859664 1205 | Test accuracy: 98.5 1206 | 1207 | Epoch 398 completed in 35.69037985801697 Loss: 0.005082897714224818 1208 | Test accuracy: 98.5 1209 | 1210 | Epoch 399 completed in 38.30183672904968 Loss: 0.005062691157762785 1211 | Test accuracy: 98.5 1212 | -------------------------------------------------------------------------------- /experiments/mnist/bp-V0/results.json: -------------------------------------------------------------------------------- 1 | {"epoch0": {"loss": 1.4687610010241647, "test_accuracy": 88.0}, "epoch1": {"loss": 0.8008829692414702, "test_accuracy": 89.69}, "epoch2": {"loss": 0.71097715949614, "test_accuracy": 90.58}, "epoch3": {"loss": 0.6571831625533974, "test_accuracy": 91.28}, "epoch4": {"loss": 0.61214466687775, "test_accuracy": 91.84}, "epoch5": {"loss": 0.5707083757889757, "test_accuracy": 92.39}, "epoch6": {"loss": 0.5324278237935055, "test_accuracy": 92.93}, "epoch7": {"loss": 0.4976917461804425, "test_accuracy": 93.38}, "epoch8": {"loss": 0.4665646288715298, "test_accuracy": 93.78}, "epoch9": {"loss": 0.4387485861138838, "test_accuracy": 94.12}, "epoch10": {"loss": 0.4138198848508743, "test_accuracy": 94.41}, "epoch11": {"loss": 0.3913751547424836, "test_accuracy": 94.69}, "epoch12": {"loss": 0.37107344931567104, "test_accuracy": 94.9}, "epoch13": {"loss": 0.352634303358013, "test_accuracy": 95.07}, "epoch14": {"loss": 0.33582625447662523, "test_accuracy": 95.24}, "epoch15": {"loss": 0.3204563490482775, "test_accuracy": 95.56}, "epoch16": {"loss": 0.30636193841408266, "test_accuracy": 95.75}, "epoch17": {"loss": 0.29340443528590415, "test_accuracy": 95.92}, "epoch18": {"loss": 0.28146447919195755, "test_accuracy": 96.09}, "epoch19": {"loss": 0.27043806828807976, "test_accuracy": 96.23}, "epoch20": {"loss": 0.26023352940833017, "test_accuracy": 96.38}, "epoch21": {"loss": 0.2507693255475835, "test_accuracy": 96.44}, "epoch22": {"loss": 0.24197256023477692, "test_accuracy": 96.54}, "epoch23": {"loss": 0.23377792567585012, "test_accuracy": 96.61}, "epoch24": {"loss": 0.22612688878892268, "test_accuracy": 96.69}, "epoch25": {"loss": 0.2189670082246431, "test_accuracy": 96.76}, "epoch26": {"loss": 0.21225133871326055, "test_accuracy": 96.83}, "epoch27": {"loss": 0.20593790385708335, "test_accuracy": 96.92}, "epoch28": {"loss": 0.19998922427882163, "test_accuracy": 97.0}, "epoch29": {"loss": 0.19437188821737075, "test_accuracy": 97.07}, "epoch30": {"loss": 0.1890561530772635, "test_accuracy": 97.13}, "epoch31": {"loss": 0.18401557061531254, "test_accuracy": 97.15}, "epoch32": {"loss": 0.17922663342023648, "test_accuracy": 97.17}, "epoch33": {"loss": 0.17466844383369412, "test_accuracy": 97.23}, "epoch34": {"loss": 0.17032240776419774, "test_accuracy": 97.27}, "epoch35": {"loss": 0.16617195555268427, "test_accuracy": 97.31}, "epoch36": {"loss": 0.1622022910436455, "test_accuracy": 97.32}, "epoch37": {"loss": 0.15840016887918465, "test_accuracy": 97.33}, "epoch38": {"loss": 0.15475369901519803, "test_accuracy": 97.41}, "epoch39": {"loss": 0.15125217665905905, 
"test_accuracy": 97.43}, "epoch40": {"loss": 0.1478859353019294, "test_accuracy": 97.45}, "epoch41": {"loss": 0.14464622027213242, "test_accuracy": 97.49}, "epoch42": {"loss": 0.14152508020964835, "test_accuracy": 97.53}, "epoch43": {"loss": 0.13851527396543642, "test_accuracy": 97.56}, "epoch44": {"loss": 0.13561019059801496, "test_accuracy": 97.6}, "epoch45": {"loss": 0.13280378035592613, "test_accuracy": 97.63}, "epoch46": {"loss": 0.1300904948027863, "test_accuracy": 97.64}, "epoch47": {"loss": 0.12746523455119663, "test_accuracy": 97.65}, "epoch48": {"loss": 0.12492330338682762, "test_accuracy": 97.67}, "epoch49": {"loss": 0.12246036784095315, "test_accuracy": 97.7}, "epoch50": {"loss": 0.1200724214812643, "test_accuracy": 97.71}, "epoch51": {"loss": 0.11775575333484753, "test_accuracy": 97.77}, "epoch52": {"loss": 0.11550691994956298, "test_accuracy": 97.79}, "epoch53": {"loss": 0.11332272066115481, "test_accuracy": 97.79}, "epoch54": {"loss": 0.11120017567902507, "test_accuracy": 97.82}, "epoch55": {"loss": 0.10913650664215353, "test_accuracy": 97.83}, "epoch56": {"loss": 0.10712911933115235, "test_accuracy": 97.85}, "epoch57": {"loss": 0.10517558825337221, "test_accuracy": 97.86}, "epoch58": {"loss": 0.10327364284532946, "test_accuracy": 97.89}, "epoch59": {"loss": 0.10142115506097324, "test_accuracy": 97.88}, "epoch60": {"loss": 0.09961612813671882, "test_accuracy": 97.87}, "epoch61": {"loss": 0.09785668634693014, "test_accuracy": 97.88}, "epoch62": {"loss": 0.09614106558970646, "test_accuracy": 97.89}, "epoch63": {"loss": 0.09446760467570174, "test_accuracy": 97.89}, "epoch64": {"loss": 0.09283473723412336, "test_accuracy": 97.9}, "epoch65": {"loss": 0.09124098419795468, "test_accuracy": 97.92}, "epoch66": {"loss": 0.08968494687693031, "test_accuracy": 97.94}, "epoch67": {"loss": 0.08816530065904996, "test_accuracy": 97.93}, "epoch68": {"loss": 0.08668078938681398, "test_accuracy": 97.94}, "epoch69": {"loss": 0.08523022042812725, "test_accuracy": 97.96}, "epoch70": {"loss": 0.08381246041213777, "test_accuracy": 97.97}, "epoch71": {"loss": 0.08242643154607313, "test_accuracy": 97.96}, "epoch72": {"loss": 0.08107110839083552, "test_accuracy": 97.96}, "epoch73": {"loss": 0.07974551496260994, "test_accuracy": 97.98}, "epoch74": {"loss": 0.0784487220442137, "test_accuracy": 97.99}, "epoch75": {"loss": 0.0771798446231022, "test_accuracy": 98.0}, "epoch76": {"loss": 0.07593803941002635, "test_accuracy": 98.01}, "epoch77": {"loss": 0.07472250242358494, "test_accuracy": 98.02}, "epoch78": {"loss": 0.07353246664681481, "test_accuracy": 98.01}, "epoch79": {"loss": 0.07236719977237531, "test_accuracy": 98.03}, "epoch80": {"loss": 0.07122600205509873, "test_accuracy": 98.07}, "epoch81": {"loss": 0.07010820428764654, "test_accuracy": 98.07}, "epoch82": {"loss": 0.06901316590933923, "test_accuracy": 98.07}, "epoch83": {"loss": 0.06794027325181365, "test_accuracy": 98.08}, "epoch84": {"loss": 0.06688893791920361, "test_accuracy": 98.11}, "epoch85": {"loss": 0.06585859529570981, "test_accuracy": 98.1}, "epoch86": {"loss": 0.06484870317001602, "test_accuracy": 98.12}, "epoch87": {"loss": 0.0638587404640846, "test_accuracy": 98.14}, "epoch88": {"loss": 0.06288820605332689, "test_accuracy": 98.17}, "epoch89": {"loss": 0.06193661766579366, "test_accuracy": 98.18}, "epoch90": {"loss": 0.061003510849600456, "test_accuracy": 98.18}, "epoch91": {"loss": 0.060088437999979975, "test_accuracy": 98.18}, "epoch92": {"loss": 0.05919096743981311, "test_accuracy": 98.18}, "epoch93": {"loss": 0.05831068254990535, 
"test_accuracy": 98.18}, "epoch94": {"loss": 0.05744718094735896, "test_accuracy": 98.19}, "epoch95": {"loss": 0.05660007371191274, "test_accuracy": 98.2}, "epoch96": {"loss": 0.05576898466094039, "test_accuracy": 98.21}, "epoch97": {"loss": 0.05495354967388407, "test_accuracy": 98.25}, "epoch98": {"loss": 0.05415341606633391, "test_accuracy": 98.24}, "epoch99": {"loss": 0.053368242012929445, "test_accuracy": 98.25}, "epoch100": {"loss": 0.052597696017000574, "test_accuracy": 98.25}, "epoch101": {"loss": 0.0518414564236445, "test_accuracy": 98.25}, "epoch102": {"loss": 0.05109921097197344, "test_accuracy": 98.26}, "epoch103": {"loss": 0.05037065638171411, "test_accuracy": 98.29}, "epoch104": {"loss": 0.04965549796924587, "test_accuracy": 98.3}, "epoch105": {"loss": 0.048953449288490755, "test_accuracy": 98.3}, "epoch106": {"loss": 0.04826423179271491, "test_accuracy": 98.31}, "epoch107": {"loss": 0.04758757451412534, "test_accuracy": 98.32}, "epoch108": {"loss": 0.0469232137590131, "test_accuracy": 98.32}, "epoch109": {"loss": 0.04627089281698532, "test_accuracy": 98.32}, "epoch110": {"loss": 0.04563036168346673, "test_accuracy": 98.32}, "epoch111": {"loss": 0.04500137679509542, "test_accuracy": 98.32}, "epoch112": {"loss": 0.04438370077788588, "test_accuracy": 98.32}, "epoch113": {"loss": 0.04377710220810704, "test_accuracy": 98.32}, "epoch114": {"loss": 0.04318135538576708, "test_accuracy": 98.33}, "epoch115": {"loss": 0.042596240120459115, "test_accuracy": 98.34}, "epoch116": {"loss": 0.042021541529152676, "test_accuracy": 98.34}, "epoch117": {"loss": 0.041457049845361575, "test_accuracy": 98.34}, "epoch118": {"loss": 0.04090256023901066, "test_accuracy": 98.34}, "epoch119": {"loss": 0.040357872646286486, "test_accuracy": 98.34}, "epoch120": {"loss": 0.03982279160879374, "test_accuracy": 98.34}, "epoch121": {"loss": 0.03929712612144466, "test_accuracy": 98.33}, "epoch122": {"loss": 0.03878068948866027, "test_accuracy": 98.34}, "epoch123": {"loss": 0.0382732991886336, "test_accuracy": 98.35}, "epoch124": {"loss": 0.03777477674556115, "test_accuracy": 98.35}, "epoch125": {"loss": 0.037284947609856886, "test_accuracy": 98.37}, "epoch126": {"loss": 0.03680364104639507, "test_accuracy": 98.36}, "epoch127": {"loss": 0.036330690030763835, "test_accuracy": 98.36}, "epoch128": {"loss": 0.03586593115334168, "test_accuracy": 98.36}, "epoch129": {"loss": 0.03540920453073885, "test_accuracy": 98.36}, "epoch130": {"loss": 0.03496035372379298, "test_accuracy": 98.36}, "epoch131": {"loss": 0.03451922566090628, "test_accuracy": 98.35}, "epoch132": {"loss": 0.03408567056510516, "test_accuracy": 98.36}, "epoch133": {"loss": 0.03365954188284936, "test_accuracy": 98.37}, "epoch134": {"loss": 0.03324069621237883, "test_accuracy": 98.37}, "epoch135": {"loss": 0.03282899322932803, "test_accuracy": 98.37}, "epoch136": {"loss": 0.03242429560751624, "test_accuracy": 98.37}, "epoch137": {"loss": 0.03202646893327853, "test_accuracy": 98.37}, "epoch138": {"loss": 0.03163538161244476, "test_accuracy": 98.37}, "epoch139": {"loss": 0.03125090477006892, "test_accuracy": 98.36}, "epoch140": {"loss": 0.030872912144172045, "test_accuracy": 98.36}, "epoch141": {"loss": 0.030501279975947888, "test_accuracy": 98.37}, "epoch142": {"loss": 0.030135886899912725, "test_accuracy": 98.37}, "epoch143": {"loss": 0.029776613838166387, "test_accuracy": 98.37}, "epoch144": {"loss": 0.029423343903110667, "test_accuracy": 98.37}, "epoch145": {"loss": 0.029075962312556936, "test_accuracy": 98.36}, "epoch146": {"loss": 0.028734356320169054, 
"test_accuracy": 98.36}, "epoch147": {"loss": 0.02839841516276882, "test_accuracy": 98.37}, "epoch148": {"loss": 0.02806803002440847, "test_accuracy": 98.38}, "epoch149": {"loss": 0.027743094015557242, "test_accuracy": 98.37}, "epoch150": {"loss": 0.02742350216450171, "test_accuracy": 98.37}, "epoch151": {"loss": 0.027109151417288296, "test_accuracy": 98.38}, "epoch152": {"loss": 0.026799940642299125, "test_accuracy": 98.37}, "epoch153": {"loss": 0.026495770635803335, "test_accuracy": 98.37}, "epoch154": {"loss": 0.026196544125441052, "test_accuracy": 98.38}, "epoch155": {"loss": 0.025902165769420483, "test_accuracy": 98.38}, "epoch156": {"loss": 0.025612542150084086, "test_accuracy": 98.37}, "epoch157": {"loss": 0.025327581761307807, "test_accuracy": 98.37}, "epoch158": {"loss": 0.025047194989857628, "test_accuracy": 98.37}, "epoch159": {"loss": 0.02477129409130728, "test_accuracy": 98.37}, "epoch160": {"loss": 0.024499793161418176, "test_accuracy": 98.36}, "epoch161": {"loss": 0.02423260810402081, "test_accuracy": 98.36}, "epoch162": {"loss": 0.023969656596448747, "test_accuracy": 98.36}, "epoch163": {"loss": 0.023710858053498825, "test_accuracy": 98.36}, "epoch164": {"loss": 0.023456133590757496, "test_accuracy": 98.36}, "epoch165": {"loss": 0.023205405987970947, "test_accuracy": 98.36}, "epoch166": {"loss": 0.022958599652967687, "test_accuracy": 98.37}, "epoch167": {"loss": 0.022715640586481085, "test_accuracy": 98.37}, "epoch168": {"loss": 0.022476456348076178, "test_accuracy": 98.37}, "epoch169": {"loss": 0.022240976023265264, "test_accuracy": 98.37}, "epoch170": {"loss": 0.022009130191802062, "test_accuracy": 98.38}, "epoch171": {"loss": 0.021780850897075064, "test_accuracy": 98.38}, "epoch172": {"loss": 0.02155607161647382, "test_accuracy": 98.39}, "epoch173": {"loss": 0.021334727232575817, "test_accuracy": 98.38}, "epoch174": {"loss": 0.02111675400499156, "test_accuracy": 98.39}, "epoch175": {"loss": 0.020902089542708044, "test_accuracy": 98.39}, "epoch176": {"loss": 0.020690672776783447, "test_accuracy": 98.39}, "epoch177": {"loss": 0.020482443933264195, "test_accuracy": 98.39}, "epoch178": {"loss": 0.02027734450621791, "test_accuracy": 98.39}, "epoch179": {"loss": 0.020075317230800215, "test_accuracy": 98.4}, "epoch180": {"loss": 0.019876306056297654, "test_accuracy": 98.4}, "epoch181": {"loss": 0.019680256119113136, "test_accuracy": 98.4}, "epoch182": {"loss": 0.019487113715682736, "test_accuracy": 98.4}, "epoch183": {"loss": 0.01929682627533389, "test_accuracy": 98.41}, "epoch184": {"loss": 0.019109342333113635, "test_accuracy": 98.41}, "epoch185": {"loss": 0.018924611502632898, "test_accuracy": 98.41}, "epoch186": {"loss": 0.018742584448987472, "test_accuracy": 98.41}, "epoch187": {"loss": 0.01856321286182905, "test_accuracy": 98.41}, "epoch188": {"loss": 0.018386449428669998, "test_accuracy": 98.41}, "epoch189": {"loss": 0.01821224780851335, "test_accuracy": 98.41}, "epoch190": {"loss": 0.018040562605904788, "test_accuracy": 98.41}, "epoch191": {"loss": 0.01787134934550564, "test_accuracy": 98.41}, "epoch192": {"loss": 0.01770456444728589, "test_accuracy": 98.41}, "epoch193": {"loss": 0.01754016520243248, "test_accuracy": 98.42}, "epoch194": {"loss": 0.017378109750062216, "test_accuracy": 98.42}, "epoch195": {"loss": 0.01721835705482008, "test_accuracy": 98.42}, "epoch196": {"loss": 0.017060866885432274, "test_accuracy": 98.42}, "epoch197": {"loss": 0.016905599794270303, "test_accuracy": 98.42}, "epoch198": {"loss": 0.016752517097967993, "test_accuracy": 98.42}, "epoch199": 
{"loss": 0.01660158085911731, "test_accuracy": 98.42}, "epoch200": {"loss": 0.016452753869053087, "test_accuracy": 98.43}, "epoch201": {"loss": 0.016305999631720306, "test_accuracy": 98.43}, "epoch202": {"loss": 0.016161282348602313, "test_accuracy": 98.43}, "epoch203": {"loss": 0.01601856690467338, "test_accuracy": 98.43}, "epoch204": {"loss": 0.015877818855325753, "test_accuracy": 98.43}, "epoch205": {"loss": 0.015739004414209157, "test_accuracy": 98.43}, "epoch206": {"loss": 0.015602090441910408, "test_accuracy": 98.43}, "epoch207": {"loss": 0.01546704443539152, "test_accuracy": 98.43}, "epoch208": {"loss": 0.015333834518097701, "test_accuracy": 98.43}, "epoch209": {"loss": 0.015202429430640261, "test_accuracy": 98.43}, "epoch210": {"loss": 0.01507279852195489, "test_accuracy": 98.43}, "epoch211": {"loss": 0.014944911740831974, "test_accuracy": 98.42}, "epoch212": {"loss": 0.014818739627713176, "test_accuracy": 98.42}, "epoch213": {"loss": 0.014694253306646634, "test_accuracy": 98.42}, "epoch214": {"loss": 0.01457142447729244, "test_accuracy": 98.42}, "epoch215": {"loss": 0.014450225406870193, "test_accuracy": 98.42}, "epoch216": {"loss": 0.014330628921941837, "test_accuracy": 98.42}, "epoch217": {"loss": 0.014212608399925702, "test_accuracy": 98.42}, "epoch218": {"loss": 0.014096137760242152, "test_accuracy": 98.42}, "epoch219": {"loss": 0.013981191454997825, "test_accuracy": 98.42}, "epoch220": {"loss": 0.013867744459124682, "test_accuracy": 98.42}, "epoch221": {"loss": 0.013755772259902292, "test_accuracy": 98.42}, "epoch222": {"loss": 0.013645250845807429, "test_accuracy": 98.42}, "epoch223": {"loss": 0.01353615669465426, "test_accuracy": 98.42}, "epoch224": {"loss": 0.013428466761011648, "test_accuracy": 98.42}, "epoch225": {"loss": 0.013322158462910561, "test_accuracy": 98.42}, "epoch226": {"loss": 0.013217209667884403, "test_accuracy": 98.42}, "epoch227": {"loss": 0.013113598678416823, "test_accuracy": 98.42}, "epoch228": {"loss": 0.013011304216904288, "test_accuracy": 98.42}, "epoch229": {"loss": 0.012910305410272683, "test_accuracy": 98.42}, "epoch230": {"loss": 0.012810581774416201, "test_accuracy": 98.42}, "epoch231": {"loss": 0.01271211319865128, "test_accuracy": 98.42}, "epoch232": {"loss": 0.012614879930395537, "test_accuracy": 98.43}, "epoch233": {"loss": 0.012518862560290382, "test_accuracy": 98.43}, "epoch234": {"loss": 0.01242404200798434, "test_accuracy": 98.42}, "epoch235": {"loss": 0.012330399508781672, "test_accuracy": 98.42}, "epoch236": {"loss": 0.012237916601337719, "test_accuracy": 98.42}, "epoch237": {"loss": 0.012146575116549187, "test_accuracy": 98.42}, "epoch238": {"loss": 0.012056357167746798, "test_accuracy": 98.42}, "epoch239": {"loss": 0.011967245142250962, "test_accuracy": 98.42}, "epoch240": {"loss": 0.0118792216943025, "test_accuracy": 98.42}, "epoch241": {"loss": 0.011792269739331979, "test_accuracy": 98.42}, "epoch242": {"loss": 0.011706372449486985, "test_accuracy": 98.42}, "epoch243": {"loss": 0.011621513250298574, "test_accuracy": 98.42}, "epoch244": {"loss": 0.011537675818338493, "test_accuracy": 98.42}, "epoch245": {"loss": 0.011454844079698673, "test_accuracy": 98.43}, "epoch246": {"loss": 0.01137300220911416, "test_accuracy": 98.43}, "epoch247": {"loss": 0.011292134629549906, "test_accuracy": 98.43}, "epoch248": {"loss": 0.011212226012079076, "test_accuracy": 98.43}, "epoch249": {"loss": 0.011133261275895078, "test_accuracy": 98.43}, "epoch250": {"loss": 0.01105522558831824, "test_accuracy": 98.43}, "epoch251": {"loss": 0.010978104364680944, 
"test_accuracy": 98.43}, "epoch252": {"loss": 0.010901883267998234, "test_accuracy": 98.43}, "epoch253": {"loss": 0.010826548208354875, "test_accuracy": 98.43}, "epoch254": {"loss": 0.010752085341961941, "test_accuracy": 98.43}, "epoch255": {"loss": 0.01067848106985619, "test_accuracy": 98.43}, "epoch256": {"loss": 0.010605722036232822, "test_accuracy": 98.43}, "epoch257": {"loss": 0.010533795126416506, "test_accuracy": 98.43}, "epoch258": {"loss": 0.010462687464486966, "test_accuracy": 98.43}, "epoch259": {"loss": 0.010392386410583779, "test_accuracy": 98.43}, "epoch260": {"loss": 0.010322879557920814, "test_accuracy": 98.43}, "epoch261": {"loss": 0.010254154729544348, "test_accuracy": 98.43}, "epoch262": {"loss": 0.010186199974870564, "test_accuracy": 98.43}, "epoch263": {"loss": 0.010119003566038114, "test_accuracy": 98.43}, "epoch264": {"loss": 0.010052553994110655, "test_accuracy": 98.43}, "epoch265": {"loss": 0.009986839965162229, "test_accuracy": 98.43}, "epoch266": {"loss": 0.009921850396275952, "test_accuracy": 98.43}, "epoch267": {"loss": 0.00985757441148384, "test_accuracy": 98.43}, "epoch268": {"loss": 0.009794001337672423, "test_accuracy": 98.43}, "epoch269": {"loss": 0.009731120700476152, "test_accuracy": 98.43}, "epoch270": {"loss": 0.009668922220177381, "test_accuracy": 98.43}, "epoch271": {"loss": 0.009607395807629295, "test_accuracy": 98.43}, "epoch272": {"loss": 0.009546531560215336, "test_accuracy": 98.43}, "epoch273": {"loss": 0.009486319757856635, "test_accuracy": 98.43}, "epoch274": {"loss": 0.009426750859076592, "test_accuracy": 98.43}, "epoch275": {"loss": 0.00936781549713018, "test_accuracy": 98.43}, "epoch276": {"loss": 0.00930950447620372, "test_accuracy": 98.43}, "epoch277": {"loss": 0.009251808767689575, "test_accuracy": 98.43}, "epoch278": {"loss": 0.009194719506538966, "test_accuracy": 98.43}, "epoch279": {"loss": 0.009138227987695135, "test_accuracy": 98.43}, "epoch280": {"loss": 0.00908232566260811, "test_accuracy": 98.43}, "epoch281": {"loss": 0.009027004135831743, "test_accuracy": 98.44}, "epoch282": {"loss": 0.00897225516170298, "test_accuracy": 98.44}, "epoch283": {"loss": 0.008918070641102846, "test_accuracy": 98.44}, "epoch284": {"loss": 0.00886444261829838, "test_accuracy": 98.44}, "epoch285": {"loss": 0.008811363277864186, "test_accuracy": 98.44}, "epoch286": {"loss": 0.008758824941682246, "test_accuracy": 98.44}, "epoch287": {"loss": 0.008706820066018367, "test_accuracy": 98.44}, "epoch288": {"loss": 0.008655341238673462, "test_accuracy": 98.44}, "epoch289": {"loss": 0.008604381176207829, "test_accuracy": 98.45}, "epoch290": {"loss": 0.008553932721236534, "test_accuracy": 98.45}, "epoch291": {"loss": 0.00850398883979386, "test_accuracy": 98.45}, "epoch292": {"loss": 0.008454542618764962, "test_accuracy": 98.45}, "epoch293": {"loss": 0.008405587263382653, "test_accuracy": 98.46}, "epoch294": {"loss": 0.008357116094787431, "test_accuracy": 98.46}, "epoch295": {"loss": 0.008309122547648838, "test_accuracy": 98.46}, "epoch296": {"loss": 0.00826160016784624, "test_accuracy": 98.47}, "epoch297": {"loss": 0.008214542610207206, "test_accuracy": 98.47}, "epoch298": {"loss": 0.008167943636301713, "test_accuracy": 98.47}, "epoch299": {"loss": 0.008121797112290443, "test_accuracy": 98.47}, "epoch300": {"loss": 0.008076097006825487, "test_accuracy": 98.46}, "epoch301": {"loss": 0.008030837389001856, "test_accuracy": 98.46}, "epoch302": {"loss": 0.007986012426358238, "test_accuracy": 98.46}, "epoch303": {"loss": 0.007941616382925471, "test_accuracy": 98.46}, 
"epoch304": {"loss": 0.007897643617321346, "test_accuracy": 98.46}, "epoch305": {"loss": 0.007854088580890248, "test_accuracy": 98.46}, "epoch306": {"loss": 0.007810945815886421, "test_accuracy": 98.47}, "epoch307": {"loss": 0.007768209953699453, "test_accuracy": 98.47}, "epoch308": {"loss": 0.007725875713120819, "test_accuracy": 98.47}, "epoch309": {"loss": 0.007683937898650273, "test_accuracy": 98.47}, "epoch310": {"loss": 0.0076423913988409135, "test_accuracy": 98.47}, "epoch311": {"loss": 0.007601231184681882, "test_accuracy": 98.47}, "epoch312": {"loss": 0.007560452308017581, "test_accuracy": 98.47}, "epoch313": {"loss": 0.007520049900002426, "test_accuracy": 98.47}, "epoch314": {"loss": 0.007480019169590153, "test_accuracy": 98.47}, "epoch315": {"loss": 0.00744035540205672, "test_accuracy": 98.47}, "epoch316": {"loss": 0.007401053957555903, "test_accuracy": 98.47}, "epoch317": {"loss": 0.007362110269706727, "test_accuracy": 98.48}, "epoch318": {"loss": 0.007323519844211861, "test_accuracy": 98.48}, "epoch319": {"loss": 0.007285278257506193, "test_accuracy": 98.48}, "epoch320": {"loss": 0.007247381155434773, "test_accuracy": 98.49}, "epoch321": {"loss": 0.0072098242519593805, "test_accuracy": 98.49}, "epoch322": {"loss": 0.007172603327892986, "test_accuracy": 98.49}, "epoch323": {"loss": 0.007135714229661387, "test_accuracy": 98.49}, "epoch324": {"loss": 0.007099152868091355, "test_accuracy": 98.49}, "epoch325": {"loss": 0.00706291521722459, "test_accuracy": 98.49}, "epoch326": {"loss": 0.007026997313156908, "test_accuracy": 98.49}, "epoch327": {"loss": 0.006991395252901946, "test_accuracy": 98.49}, "epoch328": {"loss": 0.0069561051932789005, "test_accuracy": 98.49}, "epoch329": {"loss": 0.006921123349823596, "test_accuracy": 98.49}, "epoch330": {"loss": 0.0068864459957224166, "test_accuracy": 98.49}, "epoch331": {"loss": 0.0068520694607684795, "test_accuracy": 98.49}, "epoch332": {"loss": 0.006817990130339551, "test_accuracy": 98.49}, "epoch333": {"loss": 0.006784204444397226, "test_accuracy": 98.49}, "epoch334": {"loss": 0.006750708896506758, "test_accuracy": 98.49}, "epoch335": {"loss": 0.006717500032877164, "test_accuracy": 98.49}, "epoch336": {"loss": 0.006684574451421088, "test_accuracy": 98.49}, "epoch337": {"loss": 0.006651928800833937, "test_accuracy": 98.49}, "epoch338": {"loss": 0.006619559779691875, "test_accuracy": 98.49}, "epoch339": {"loss": 0.006587464135568227, "test_accuracy": 98.49}, "epoch340": {"loss": 0.006555638664167865, "test_accuracy": 98.49}, "epoch341": {"loss": 0.0065240802084791695, "test_accuracy": 98.49}, "epoch342": {"loss": 0.006492785657943157, "test_accuracy": 98.49}, "epoch343": {"loss": 0.006461751947639371, "test_accuracy": 98.49}, "epoch344": {"loss": 0.006430976057488203, "test_accuracy": 98.49}, "epoch345": {"loss": 0.006400455011469207, "test_accuracy": 98.49}, "epoch346": {"loss": 0.006370185876855075, "test_accuracy": 98.49}, "epoch347": {"loss": 0.006340165763460951, "test_accuracy": 98.49}, "epoch348": {"loss": 0.006310391822908672, "test_accuracy": 98.49}, "epoch349": {"loss": 0.0062808612479056695, "test_accuracy": 98.49}, "epoch350": {"loss": 0.006251571271538151, "test_accuracy": 98.49}, "epoch351": {"loss": 0.006222519166578285, "test_accuracy": 98.49}, "epoch352": {"loss": 0.006193702244805039, "test_accuracy": 98.49}, "epoch353": {"loss": 0.006165117856338406, "test_accuracy": 98.49}, "epoch354": {"loss": 0.006136763388986703, "test_accuracy": 98.49}, "epoch355": {"loss": 0.0061086362676066415, "test_accuracy": 98.49}, "epoch356": 
{"loss": 0.006080733953475927, "test_accuracy": 98.49}, "epoch357": {"loss": 0.006053053943678066, "test_accuracy": 98.49}, "epoch358": {"loss": 0.00602559377049916, "test_accuracy": 98.49}, "epoch359": {"loss": 0.005998351000836383, "test_accuracy": 98.49}, "epoch360": {"loss": 0.0059713232356179155, "test_accuracy": 98.49}, "epoch361": {"loss": 0.00594450810923408, "test_accuracy": 98.49}, "epoch362": {"loss": 0.005917903288979428, "test_accuracy": 98.49}, "epoch363": {"loss": 0.005891506474505546, "test_accuracy": 98.49}, "epoch364": {"loss": 0.005865315397284356, "test_accuracy": 98.49}, "epoch365": {"loss": 0.0058393278200816676, "test_accuracy": 98.49}, "epoch366": {"loss": 0.005813541536440797, "test_accuracy": 98.49}, "epoch367": {"loss": 0.005787954370175967, "test_accuracy": 98.49}, "epoch368": {"loss": 0.005762564174875373, "test_accuracy": 98.49}, "epoch369": {"loss": 0.005737368833413633, "test_accuracy": 98.49}, "epoch370": {"loss": 0.0057123662574734575, "test_accuracy": 98.49}, "epoch371": {"loss": 0.005687554387076335, "test_accuracy": 98.49}, "epoch372": {"loss": 0.00566293119012203, "test_accuracy": 98.49}, "epoch373": {"loss": 0.0056384946619367545, "test_accuracy": 98.49}, "epoch374": {"loss": 0.005614242824829737, "test_accuracy": 98.49}, "epoch375": {"loss": 0.005590173727658136, "test_accuracy": 98.49}, "epoch376": {"loss": 0.0055662854454000075, "test_accuracy": 98.49}, "epoch377": {"loss": 0.005542576078735229, "test_accuracy": 98.49}, "epoch378": {"loss": 0.005519043753634189, "test_accuracy": 98.5}, "epoch379": {"loss": 0.0054956866209540794, "test_accuracy": 98.5}, "epoch380": {"loss": 0.005472502856042646, "test_accuracy": 98.5}, "epoch381": {"loss": 0.005449490658349217, "test_accuracy": 98.5}, "epoch382": {"loss": 0.005426648251042878, "test_accuracy": 98.5}, "epoch383": {"loss": 0.005403973880637659, "test_accuracy": 98.5}, "epoch384": {"loss": 0.005381465816624554, "test_accuracy": 98.5}, "epoch385": {"loss": 0.005359122351110235, "test_accuracy": 98.5}, "epoch386": {"loss": 0.005336941798462385, "test_accuracy": 98.5}, "epoch387": {"loss": 0.005314922494961414, "test_accuracy": 98.5}, "epoch388": {"loss": 0.0052930627984585235, "test_accuracy": 98.5}, "epoch389": {"loss": 0.005271361088039914, "test_accuracy": 98.5}, "epoch390": {"loss": 0.005249815763697066, "test_accuracy": 98.5}, "epoch391": {"loss": 0.005228425246002934, "test_accuracy": 98.5}, "epoch392": {"loss": 0.0052071879757939555, "test_accuracy": 98.49}, "epoch393": {"loss": 0.00518610241385775, "test_accuracy": 98.49}, "epoch394": {"loss": 0.005165167040626389, "test_accuracy": 98.49}, "epoch395": {"loss": 0.005144380355875132, "test_accuracy": 98.5}, "epoch396": {"loss": 0.005123740878426509, "test_accuracy": 98.5}, "epoch397": {"loss": 0.005103247145859664, "test_accuracy": 98.5}, "epoch398": {"loss": 0.005082897714224818, "test_accuracy": 98.5}, "epoch399": {"loss": 0.005062691157762785, "test_accuracy": 98.5}} -------------------------------------------------------------------------------- /experiments/mnist/fa-V0/experiment_config.txt: -------------------------------------------------------------------------------- 1 | dataset mnist 2 | algo fa 3 | n_epochs 400 4 | size_hidden_layers [500] 5 | batch_size 128 6 | learning_rate 0.2 7 | test_frequency 1 8 | save_dir ./experiments 9 | seed 1111 10 | -------------------------------------------------------------------------------- /experiments/mnist/kp-V0/experiment_config.txt: 
-------------------------------------------------------------------------------- 1 | dataset mnist 2 | algo kp 3 | n_epochs 400 4 | size_hidden_layers [500] 5 | batch_size 128 6 | learning_rate 0.3 7 | test_frequency 1 8 | save_dir ./experiments 9 | seed 1111 10 | -------------------------------------------------------------------------------- /experiments/mnist/wm-V0/experiment_config.txt: -------------------------------------------------------------------------------- 1 | dataset mnist 2 | algo wm 3 | n_epochs 400 4 | size_hidden_layers [500] 5 | batch_size 128 6 | learning_rate 0.05 7 | test_frequency 1 8 | save_dir ./experiments 9 | seed 1111 10 | -------------------------------------------------------------------------------- /fcnn/FCNN_BP.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import numpy as np 5 | 6 | 7 | class FCNN_BP(object): 8 | ''' 9 | Description: Class to define a Fully Connected Neural Network (FCNN) 10 | with backpropagation (BP) as the learning algorithm 11 | ''' 12 | 13 | def __init__(self, sizes, save_dir): 14 | ''' 15 | Description: initialize the biases and weights using a Gaussian 16 | distribution with mean 0 and variance 1 / (fan_in + fan_out), matching the scaling in the code below. 17 | Params: 18 | - sizes: a list of size L, where L is the number of layers 19 | in the deep neural network and each element of the list contains 20 | the number of neurons in that layer. 21 | The first and last elements of the list correspond to the input 22 | layer and output layer respectively; 23 | intermediate layers are hidden layers. 24 | - save_dir: the directory where all the experiment data will be saved 25 | ''' 26 | self.num_layers = len(sizes) 27 | self.save_dir = save_dir 28 | # setting appropriate dimensions for weights and biases 29 | self.biases = [np.sqrt(1. / (x + y)) * np.random.randn(y, 1) 30 | for x, y in zip(sizes[:-1], sizes[1:])] 31 | self.weights = [np.sqrt(1. / (x + y)) * np.random.randn(y, x) 32 | for x, y in zip(sizes[:-1], sizes[1:])] 33 | 34 | # define the variables to save data in during training and testing 35 | self.data = {} 36 | 37 | def print_and_log(self, log_str): 38 | ''' 39 | Description: Print and log messages during experiments 40 | Params: 41 | - log_str: the string to log 42 | ''' 43 | print(log_str) 44 | with open(os.path.join(self.save_dir, 'log.txt'), 'a') as f_: 45 | f_.write(log_str + '\n') 46 | 47 | def sigmoid(self, out): 48 | ''' 49 | Description: the sigmoid activation function 50 | Params: 51 | - out: a list or a matrix to perform the activation on 52 | Outputs: the sigmoid-activated list or matrix 53 | ''' 54 | return 1.0 / (1.0 + np.exp(-out)) 55 | 56 | def delta_sigmoid(self, out): 57 | ''' 58 | Description: the derivative of the sigmoid activation function 59 | Params: 60 | - out: a list or a matrix to perform the activation on 61 | Outputs: the sigmoid-prime values of the list or matrix 62 | ''' 63 | return self.sigmoid(out) * (1 - self.sigmoid(out)) 64 | 65 | def SigmoidCrossEntropyLoss(self, a, y): 66 | """ 67 | Description: the cross-entropy loss 68 | Params: 69 | - a: the last layer activation 70 | - y: the target one-hot vector 71 | Outputs: a loss value 72 | """ 73 | return np.mean(np.sum(np.nan_to_num(-y * np.log(a) - (1 - y) * np.log(1 - a)), axis=0)) 74 | 75 | def feedforward(self, x): 76 | ''' 77 | Description: forward-passes an image feature matrix through the deep neural 78 | network architecture.
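Concretely (in the notation of the code below), each layer computes out_l = w_l a_{l-1} + b_l followed by a_l = sigmoid(out_l), starting from a_0 = x; both sequences are collected and returned.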
79 | Params: 80 | - x: the input signal 81 | Outputs: 2 lists which store the outputs and activations at every layer: 82 | the 1st list holds the pre-activation outputs and the 2nd list the activations 83 | ''' 84 | activation = x 85 | activations = [x] # list to store activations for every layer 86 | outs = [] # list to store out vectors for every layer 87 | for b, w in zip(self.biases, self.weights): 88 | out = np.matmul(w, activation) + b 89 | outs.append(out) 90 | activation = self.sigmoid(out) 91 | activations.append(activation) 92 | 93 | return outs, activations 94 | 95 | def get_batch(self, X, y, batch_size): 96 | ''' 97 | Description: A data iterator for batching of input signals and labels 98 | Params: 99 | - X, y: lists of input signals and their corresponding labels 100 | - batch_size: size of the batch 101 | Outputs: a batch of input signals and labels of size equal to batch_size 102 | ''' 103 | for batch_idx in range(0, X.shape[0], batch_size): 104 | batch = (X[batch_idx:batch_idx + batch_size].T, 105 | y[batch_idx:batch_idx + batch_size].T) 106 | yield batch 107 | 108 | def train(self, X_train, y_train, X_test, y_test, batch_size, learning_rate, epochs, test_frequency): 109 | ''' 110 | Description: Batch-wise trains image features against corresponding labels. 111 | The weights and biases of the neural network are updated through 112 | backpropagation on batches using SGD. 113 | The variables del_b and del_w have the same shapes as the biases and weights 114 | of all the layers, and contain the gradients 115 | used to update the weights and biases. 116 | 117 | Params: 118 | - X_train, y_train: lists of training features and corresponding labels 119 | - X_test, y_test: lists of testing features and corresponding labels 120 | - batch_size: size of the batch 121 | - learning_rate: eta, which controls the size of changes in weights & biases 122 | - epochs: no.
of times to iterate over the whole data 123 | - test_frequency: the frequency of the evaluation on the test data 124 | ''' 125 | n_batches = int(X_train.shape[0] / batch_size) 126 | 127 | for j in range(epochs): 128 | # initialize the epoch field in the data to store 129 | self.data['epoch{}'.format(j)] = {} 130 | 131 | start = time.time() 132 | epoch_loss = [] 133 | batch_iter = self.get_batch(X_train, y_train, batch_size) 134 | 135 | for i in range(n_batches): 136 | (batch_X, batch_y) = next(batch_iter) 137 | batch_loss, del_b, del_w = self.backpropagate(batch_X, batch_y) 138 | epoch_loss.append(batch_loss) 139 | # update weights and biases 140 | self.weights = [w - (learning_rate / batch_size) 141 | * delw for w, delw in zip(self.weights, del_w)] 142 | self.biases = [b - (learning_rate / batch_size) 143 | * delb for b, delb in zip(self.biases, del_b)] 144 | epoch_loss = np.mean(epoch_loss) 145 | self.data['epoch{}'.format(j)]['loss'] = epoch_loss 146 | 147 | # Log the loss 148 | log_str = "\nEpoch {} completed in {:.3f}s, loss: {:.3f}".format(j, time.time() - start, epoch_loss) 149 | self.print_and_log(log_str) 150 | 151 | # Evaluate on test set 152 | test_accuracy = self.eval(X_test, y_test) 153 | log_str = "Test accuracy: {}%".format(test_accuracy) 154 | self.print_and_log(log_str) 155 | self.data['epoch{}'.format(j)]['test_accuracy'] = test_accuracy 156 | 157 | # save results as a json file 158 | with open(os.path.join(self.save_dir, 'results.json'), 'w') as f: 159 | json.dump(self.data, f) 160 | 161 | def backpropagate(self, x, y): 162 | ''' 163 | Description: Based on the derivative (delta) of the cost function, the gradients 164 | (the rate of change of the cost with respect to the weights and biases) are calculated. 165 | The variables del_b and del_w have the same shapes as the biases and weights 166 | of all the layers, and contain the gradients 167 | used to update the weights and biases.
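In equations: the output-layer error is delta_L = a_L - y (the gradient of the sigmoid cross-entropy loss with respect to the output pre-activation); it is propagated backwards via delta_l = (w_{l+1}^T delta_{l+1}) * sigmoid'(out_l), and the per-layer gradients are del_w_l = delta_l a_{l-1}^T and del_b_l = the batch mean of delta_l.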
168 | Params: 169 | - x, y: training feature and corresponding label 170 | Outputs: del_b: gradients of the biases 171 | del_w: gradients of the weights 172 | ''' 173 | del_b = [np.zeros(b.shape) for b in self.biases] 174 | del_w = [np.zeros(w.shape) for w in self.weights] 175 | 176 | outs, activations = self.feedforward(x) 177 | 178 | # Cost function 179 | loss = self.SigmoidCrossEntropyLoss(activations[-1], y) 180 | 181 | # calculate the derivative of the sigmoid cross-entropy cost, which is to be minimized 182 | delta_cost = activations[-1] - y 183 | # backward pass: error signal at the output layer 184 | delta = delta_cost 185 | del_b[-1] = np.expand_dims(np.mean(delta, axis=1), axis=1) 186 | del_w[-1] = np.matmul(delta, activations[-2].T) 187 | 188 | # updating gradients of each layer using reverse or negative indexing, by propagating 189 | # gradients of later layers back to the current layer so that gradients of weights and biases 190 | # at each layer can be calculated 191 | for l in range(2, self.num_layers): 192 | out = outs[-l] 193 | delta_activation = self.delta_sigmoid(out) 194 | delta = np.matmul(self.weights[-l + 1].T, delta) * delta_activation 195 | del_b[-l] = np.expand_dims(np.mean(delta, axis=1), axis=1) 196 | del_w[-l] = np.dot(delta, activations[-l - 1].T) 197 | 198 | return loss, del_b, del_w 199 | 200 | def eval(self, X, y): 201 | ''' 202 | Description: Using the trained (updated) weights and biases, predict a batch of labels, compare 203 | them with the original labels and calculate the accuracy 204 | Params: 205 | - X: test input signals 206 | - y: test labels 207 | Outputs: accuracy of prediction 208 | ''' 209 | outs, activations = self.feedforward(X.T) 210 | # count how often the position of the maximum activation matches the true label 211 | count = np.sum(np.argmax(activations[-1], axis=0) == np.argmax(y.T, axis=0)) 212 | test_accuracy = 100. * count / X.shape[0] 213 | return test_accuracy 214 | -------------------------------------------------------------------------------- /fcnn/FCNN_FA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import numpy as np 5 | 6 | from .FCNN_BP import FCNN_BP 7 | 8 | 9 | class FCNN_FA(FCNN_BP): 10 | ''' 11 | Description: Class to define a Fully Connected Neural Network (FCNN) 12 | with feedback alignment (FA) as the learning algorithm 13 | ''' 14 | 15 | def __init__(self, sizes, save_dir): 16 | ''' 17 | Description: initialize the biases, forward weights and backward weights using 18 | a Gaussian distribution with mean 0 and variance 1 / (fan_in + fan_out). 19 | Params: 20 | - sizes: a list of size L, where L is the number of layers 21 | in the deep neural network and each element of the list contains 22 | the number of neurons in that layer. 23 | The first and last elements of the list correspond to the input 24 | layer and output layer respectively; 25 | intermediate layers are hidden layers. 26 | - save_dir: the directory where all the experiment data will be saved 27 | ''' 28 | super(FCNN_FA, self).__init__(sizes, save_dir) 29 | # setting backward matrices 30 | self.backward_weights = [np.sqrt(1. / (x + y)) * np.random.randn(x, y) for x, y in zip(sizes[:-1], sizes[1:])] 31 | 32 | def train(self, X_train, y_train, X_test, y_test, batch_size, learning_rate, epochs, test_frequency): 33 | ''' 34 | Description: Batch-wise trains image features against corresponding labels.
35 | The forward weights and biases of the neural network are updated through 36 | feedback alignment on batches using SGD. 37 | del_b and del_w have the same shapes as the forward weights and biases 38 | of all the layers, and contain the gradients 39 | used to update the forward weights and biases. 40 | 41 | Params: 42 | - X_train, y_train: lists of training features and corresponding labels 43 | - X_test, y_test: lists of testing features and corresponding labels 44 | - batch_size: size of the batch 45 | - learning_rate: eta, which controls the size of changes in weights & biases 46 | - epochs: no. of times to iterate over the whole data 47 | - test_frequency: the frequency of the evaluation on the test data 48 | ''' 49 | n_batches = int(X_train.shape[0] / batch_size) 50 | 51 | for j in range(epochs): 52 | # initialize the epoch field in the data to store 53 | self.data['epoch_{}'.format(j)] = {} 54 | 55 | start = time.time() 56 | epoch_loss = [] 57 | batch_iter = self.get_batch(X_train, y_train, batch_size) 58 | 59 | for i in range(n_batches): 60 | (batch_X, batch_y) = next(batch_iter) 61 | batch_loss, del_b, del_w = self.backpropagate(batch_X, batch_y) 62 | epoch_loss.append(batch_loss) 63 | # update weights and biases 64 | self.weights = [w - (learning_rate / batch_size) 65 | * delw for w, delw in zip(self.weights, del_w)] 66 | self.biases = [b - (learning_rate / batch_size) 67 | * delb for b, delb in zip(self.biases, del_b)] 68 | epoch_loss = np.mean(epoch_loss) 69 | self.data['epoch_{}'.format(j)]['loss'] = epoch_loss 70 | 71 | # Log the loss 72 | log_str = "\nEpoch {} completed in {:.3f}s, loss: {:.3f}".format(j, time.time() - start, epoch_loss) 73 | self.print_and_log(log_str) 74 | 75 | # Evaluate on test set 76 | test_accuracy = self.eval(X_test, y_test) 77 | log_str = "Test accuracy: {}%".format(test_accuracy) 78 | self.print_and_log(log_str) 79 | self.data['epoch_{}'.format(j)]['test_accuracy'] = test_accuracy 80 | 81 | # Compute angles between both weights and deltas 82 | deltas_angles, weights_angles = self.evaluate_angles(X_train, y_train) 83 | self.data['epoch_{}'.format(j)]['delta_angles'] = deltas_angles 84 | self.data['epoch_{}'.format(j)]['weight_angles'] = weights_angles 85 | 86 | # save results as a json file 87 | with open(os.path.join(self.save_dir, 'results.json'), 'w') as f: 88 | json.dump(self.data, f) 89 | 90 | def backpropagate(self, x, y, eval_delta_angle=False): 91 | ''' 92 | Description: Based on the derivative (delta) of the cost function, the gradients 93 | (the rate of change of the cost with respect to the weights and biases) are calculated. 94 | The variables del_b and del_w have the same shapes as the forward weights and biases 95 | of all the layers, and contain the gradients 96 | used to update the forward weights and biases.
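The backward pass is identical to backprop except that the transported transpose of the forward weights, w_{l+1}^T, is replaced by the fixed random feedback matrix backward_weights_{l+1}: delta_l = (B_{l+1} delta_{l+1}) * sigmoid'(out_l), so no weight transport is required.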
97 | Params: 98 | - x, y: training feature and corresponding label 99 | - eval_delta_angle: a boolean to determine if the angle between deltas should be computed 100 | Outputs: 101 | - del_b: gradients of the biases 102 | - del_w: gradients of the weights 103 | ''' 104 | # Set a variable to store the delta angles during evaluation only 105 | if eval_delta_angle: 106 | deltas_angles = {} 107 | 108 | del_b = [np.zeros(b.shape) for b in self.biases] 109 | del_w = [np.zeros(w.shape) for w in self.weights] 110 | 111 | outs, activations = self.feedforward(x) 112 | 113 | # Cost function 114 | loss = self.SigmoidCrossEntropyLoss(activations[-1], y) 115 | 116 | # calculate the derivative of the sigmoid cross-entropy cost, which is to be minimized 117 | delta_cost = activations[-1] - y 118 | # backward pass: error signal at the output layer 119 | delta = delta_cost 120 | del_b[-1] = np.expand_dims(np.mean(delta, axis=1), axis=1) 121 | del_w[-1] = np.matmul(delta, activations[-2].T) 122 | 123 | # updating gradients of each layer using reverse or negative indexing, by propagating 124 | # gradients of later layers back to the current layer so that gradients of weights and biases 125 | # at each layer can be calculated 126 | for l in range(2, self.num_layers): 127 | out = outs[-l] 128 | delta_activation = self.delta_sigmoid(out) 129 | if eval_delta_angle: 130 | # compute both FA and BP deltas and the angle between them 131 | delta_bp = np.matmul(self.weights[-l + 1].T, delta) * delta_activation 132 | delta = np.matmul(self.backward_weights[-l + 1], delta) * delta_activation 133 | deltas_angles['layer_{}'.format(self.num_layers - l)] = self.angle_between(delta_bp, delta) 134 | else: 135 | delta = np.matmul(self.backward_weights[-l + 1], delta) * delta_activation 136 | del_b[-l] = np.expand_dims(np.mean(delta, axis=1), axis=1) 137 | del_w[-l] = np.dot(delta, activations[-l - 1].T) 138 | if eval_delta_angle: 139 | return deltas_angles 140 | else: 141 | return loss, del_b, del_w 142 | 143 | def angle_between(self, A, B): 144 | ''' 145 | Description: computes the angle between two matrices A and B 146 | Params: 147 | - A: a first matrix 148 | - B: a second matrix 149 | Outputs: 150 | - angle: the angle between the two vectors resulting from vectorizing and normalizing A and B 151 | ''' 152 | flat_A = np.reshape(A, (-1)) 153 | normalized_flat_A = flat_A / np.linalg.norm(flat_A) 154 | 155 | flat_B = np.reshape(B, (-1)) 156 | normalized_flat_B = flat_B / np.linalg.norm(flat_B) 157 | 158 | angle = (180.0 / np.pi) * np.arccos(np.clip(np.dot(normalized_flat_A, normalized_flat_B), -1.0, 1.0)) 159 | return angle 160 | 161 | def evaluate_angles(self, X_train, y_train): 162 | ''' 163 | Description: computes the angles between both: 164 | - the forward and backward matrices 165 | - the delta signals 166 | Params: 167 | - X_train, y_train: training feature and corresponding label 168 | Outputs: 169 | - deltas_angles: the angle between the delta signal and the backpropagation delta signal 170 | - weights_angles: the angle between the forward and backward matrices 171 | ''' 172 | 173 | # Evaluate angles between matrices 174 | weights_angles = {} 175 | for layer, (w, back_w) in enumerate(zip(self.weights, self.backward_weights)): 176 | matrix_angle = self.angle_between(w.T, back_w) 177 | weights_angles['layer_{}'.format(layer)] = matrix_angle 178 | log_str = 'In layer {} angle between matrices: {}'.format(self.num_layers - layer, matrix_angle) 179 | self.print_and_log(log_str) 180 | 181 | # Evaluate angles between delta signals 182 | [sample_x,
sample_y] = list(next(self.get_batch(X_train, y_train, batch_size=1))) 183 | deltas_angles = self.backpropagate(sample_x, sample_y, eval_delta_angle=True) 184 | log_str = 'Angle between deltas: {}'.format(deltas_angles) 185 | self.print_and_log(log_str) 186 | return deltas_angles, weights_angles 187 | -------------------------------------------------------------------------------- /fcnn/FCNN_KP.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import numpy as np 5 | 6 | from .FCNN_FA import FCNN_FA 7 | 8 | 9 | class FCNN_KP(FCNN_FA): 10 | ''' 11 | Description: Class to define a Fully Connected Neural Network (FCNN) 12 | with the Kolen-Pollack (KP) algorithm as the learning algorithm 13 | ''' 14 | 15 | def __init__(self, sizes, save_dir): 16 | ''' 17 | Description: initialize the biases, forward weights and backward weights using 18 | a Gaussian distribution with mean 0 and variance 1 / (fan_in + fan_out). 19 | Params: 20 | - sizes: a list of size L, where L is the number of layers 21 | in the deep neural network and each element of the list contains 22 | the number of neurons in that layer. 23 | The first and last elements of the list correspond to the input 24 | layer and output layer respectively; 25 | intermediate layers are hidden layers. 26 | - save_dir: the directory where all the experiment data will be saved 27 | ''' 28 | super(FCNN_KP, self).__init__(sizes, save_dir) 29 | 30 | def train(self, X_train, y_train, X_test, y_test, batch_size, learning_rate, epochs, test_frequency, weight_decay=1): 31 | ''' 32 | Description: Batch-wise trains image features against corresponding labels. 33 | The forward and backward weights and biases of the neural network are updated through 34 | the Kolen-Pollack algorithm on batches using SGD. 35 | del_b and del_w have the same shapes as the forward weights and biases 36 | of all the layers, and contain the gradients 37 | used to update the forward weights and biases. 38 | 39 | Params: 40 | - X_train, y_train: lists of training features and corresponding labels 41 | - X_test, y_test: lists of testing features and corresponding labels 42 | - batch_size: size of the batch 43 | - learning_rate: eta, which controls the size of changes in weights & biases 44 | - epochs: no.
of times to iterate over the whole data 45 | - test_frequency: the frequency of the evaluation on the test data - weight_decay: the multiplicative decay factor (lambda) applied to the forward and backward weights at each update; 1 disables decay 46 | ''' 47 | n_batches = int(X_train.shape[0] / batch_size) 48 | 49 | for j in range(epochs): 50 | # initialize the epoch field in the data to store 51 | self.data['epoch_{}'.format(j)] = {} 52 | 53 | start = time.time() 54 | epoch_loss = [] 55 | batch_iter = self.get_batch(X_train, y_train, batch_size) 56 | 57 | for i in range(n_batches): 58 | (batch_X, batch_y) = next(batch_iter) 59 | batch_loss, delta_del_b, delta_del_w = self.backpropagate(batch_X, batch_y) 60 | epoch_loss.append(batch_loss) 61 | del_b = delta_del_b 62 | del_w = delta_del_w 63 | # update weights and biases 64 | self.weights = [weight_decay * w - (learning_rate / batch_size) 65 | * delw for w, delw in zip(self.weights, del_w)] 66 | self.biases = [b - (learning_rate / batch_size) 67 | * delb for b, delb in zip(self.biases, del_b)] 68 | # Update the backward matrices of the Kolen-Pollack algorithm 69 | # It is worth noting that updating the backward weight matrices B with the same 70 | # delw term as the forward matrices W is equivalent to the update equations 16 and 17 71 | # of the paper manuscript: W <- lambda * W - eta * delw and B <- lambda * B - eta * delw.T, so both matrices receive matching updates and, thanks to the decay term, B converges toward W.T 72 | self.backward_weights = [weight_decay * w - (learning_rate / batch_size) 73 | * delw.T for w, delw in zip(self.backward_weights, del_w)] 74 | epoch_loss = np.mean(epoch_loss) 75 | self.data['epoch_{}'.format(j)]['loss'] = epoch_loss 76 | 77 | log_str = "\nEpoch {} completed in {:.3f}s, loss: {:.3f}".format(j, time.time() - start, epoch_loss) 78 | self.print_and_log(log_str) 79 | 80 | # Evaluate on test set 81 | test_accuracy = self.eval(X_test, y_test) 82 | log_str = "Test accuracy: {}%".format(test_accuracy) 83 | self.print_and_log(log_str) 84 | self.data['epoch_{}'.format(j)]['test_accuracy'] = test_accuracy 85 | 86 | # Compute angles between both weights and deltas 87 | deltas_angles, weights_angles = self.evaluate_angles(X_train, y_train) 88 | self.data['epoch_{}'.format(j)]['delta_angles'] = deltas_angles 89 | self.data['epoch_{}'.format(j)]['weight_angles'] = weights_angles 90 | 91 | # save results as a json file 92 | with open(os.path.join(self.save_dir, 'results.json'), 'w') as f: 93 | json.dump(self.data, f) 94 | -------------------------------------------------------------------------------- /fcnn/FCNN_WM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import numpy as np 5 | 6 | from .FCNN_FA import FCNN_FA 7 | 8 | 9 | class FCNN_WM(FCNN_FA): 10 | ''' 11 | Description: Class to define a Fully Connected Neural Network (FCNN) 12 | with weight mirrors (WM) 13 | ''' 14 | 15 | def __init__(self, sizes, save_dir): 16 | ''' 17 | Description: initialize the biases, forward weights and backward weights using 18 | a Gaussian distribution with mean 0 and variance 1 / (fan_in + fan_out). 19 | Params: 20 | - sizes: a list of size L, where L is the number of layers 21 | in the deep neural network and each element of the list contains 22 | the number of neurons in that layer. 23 | The first and last elements of the list correspond to the input 24 | layer and output layer respectively; 25 | intermediate layers are hidden layers.
26 | - save_dir: the directory where all the experiment data will be saved 27 | ''' 28 | super(FCNN_WM, self).__init__(sizes, save_dir) 29 | 30 | def mirror(self, batch_size, X_shape=None, mirror_learning_rate=0.01, noise_amplitude=0.1): 31 | ''' 32 | Description: weight mirroring by feeding iid zero-mean noise (uniform in this implementation) through *each* layer of the network. 33 | If the iid noise were generated once and forward-propagated, 34 | the iid property would be lost for the hidden layers. 35 | Params: 36 | - batch_size: size of the mirroring batch 37 | - X_shape: the shape of the noise matrix 38 | - mirror_learning_rate: eta, which controls the size of changes in the backward weights 39 | - noise_amplitude: the amplitude of the iid noise 40 | ''' 41 | if X_shape is not None: 42 | n_batches = int(X_shape[0] / batch_size) 43 | else: 44 | n_batches = 1 45 | 46 | for i in range(n_batches): 47 | for layer, (b, w, back_w) in enumerate(zip(self.biases, self.weights, self.backward_weights)): 48 | noise_x = noise_amplitude * (np.random.rand(w.shape[1], batch_size) - 0.5) 49 | noise_y = self.sigmoid(np.matmul(w, noise_x) + b) 50 | # update the backward weight matrices using the equation 7 of the paper manuscript; since noise_x is zero-mean and iid, the batch average of noise_x . noise_y^T is approximately proportional to w.T for small noise, so back_w drifts toward the transpose of the forward weights 51 | back_w += mirror_learning_rate * np.matmul(noise_x, noise_y.T) 52 | 53 | # Prevent feedback weights growing too large 54 | for layer, (b, w, back_w) in enumerate(zip(self.biases, self.weights, self.backward_weights)): 55 | x = np.random.rand(back_w.shape[1], batch_size) 56 | y = np.matmul(back_w, x) 57 | y_std = np.mean(np.std(y, axis=0)) 58 | back_w *= 0.5 / y_std  # rescale in place so the change reaches self.backward_weights (plain rebinding of back_w would be a no-op) 59 | 60 | def train(self, X_train, y_train, X_test, y_test, batch_size, learning_rate, epochs, test_frequency): 61 | ''' 62 | Description: Batch-wise trains image features against corresponding labels. 63 | The forward weights and biases of the neural network are updated through 64 | SGD on batches, and the backward weights through weight mirroring. 65 | del_b and del_w have the same shapes as the forward weights and biases 66 | of all the layers, and contain the gradients 67 | used to update the forward weights and biases. 68 | 69 | Params: 70 | - X_train, y_train: lists of training features and corresponding labels 71 | - X_test, y_test: lists of testing features and corresponding labels 72 | - batch_size: size of the batch 73 | - learning_rate: eta, which controls the size of changes in weights & biases 74 | - epochs: no.
 82 |             - test_frequency: the frequency of the evaluation on the test data
 83 |         '''
 84 |         n_batches = int(X_train.shape[0] / batch_size)
 85 | 
 86 |         # Start with an initial update of the backward matrices with weight mirroring
 87 |         self.mirror(batch_size, X_train.shape)
 88 | 
 89 |         for j in range(epochs):
 90 |             # initialize the epoch field in the data to store
 91 |             self.data['epoch_{}'.format(j)] = {}
 92 |             start = time.time()
 93 |             epoch_loss = []
 94 |             batch_iter = self.get_batch(X_train, y_train, batch_size)
 95 | 
 96 |             for i in range(n_batches):
 97 |                 (batch_X, batch_y) = next(batch_iter)
 98 |                 batch_loss, del_b, del_w = self.backpropagate(batch_X, batch_y)
 99 |                 epoch_loss.append(batch_loss)
100 |                 # update weights and biases
101 |                 self.weights = [w - (learning_rate / batch_size)
102 |                                 * delw for w, delw in zip(self.weights, del_w)]
103 |                 self.biases = [b - (learning_rate / batch_size)
104 |                                * delb for b, delb in zip(self.biases, del_b)]
105 |                 # update the backward matrices with weight mirroring
106 |                 self.mirror(batch_size=batch_size)
107 | 
108 |             epoch_loss = np.mean(epoch_loss)
109 |             self.data['epoch_{}'.format(j)]['loss'] = epoch_loss
110 | 
111 |             # Log the loss
112 |             log_str = "\nEpoch {} completed in {:.3f}s, loss: {:.3f}".format(j, time.time() - start, epoch_loss)
113 |             self.print_and_log(log_str)
114 | 
115 |             # Evaluate on test set
116 |             test_accuracy = self.eval(X_test, y_test)
117 |             log_str = "Test accuracy: {}%".format(test_accuracy)
118 |             self.print_and_log(log_str)
119 |             self.data['epoch_{}'.format(j)]['test_accuracy'] = test_accuracy
120 | 
121 |             # Compute angles between both weights and deltas
122 |             deltas_angles, weights_angles = self.evaluate_angles(X_train, y_train)
123 |             self.data['epoch_{}'.format(j)]['delta_angles'] = deltas_angles
124 |             self.data['epoch_{}'.format(j)]['weight_angles'] = weights_angles
125 | 
126 |         # save results as a json file
127 |         with open(os.path.join(self.save_dir, 'results.json'), 'w') as f:
128 |             json.dump(self.data, f)
129 | 
--------------------------------------------------------------------------------
/fcnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/fcnn/__init__.py
--------------------------------------------------------------------------------
/figures/cifar10/delta_angles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/figures/cifar10/delta_angles.png
--------------------------------------------------------------------------------
/figures/cifar10/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/figures/cifar10/loss.png
--------------------------------------------------------------------------------
/figures/cifar10/test_accuracies.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/figures/cifar10/test_accuracies.png
--------------------------------------------------------------------------------
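To make the two update rules above concrete, here is a minimal, self-contained NumPy sketch; it is illustrative only, not a file in this repository, and all names in it are made up. It shows why both the Kolen-Pollack update in `FCNN_KP.py` and the mirroring update in `FCNN_WM.py` drive each backward matrix B toward the transpose of its forward matrix W:

```python
# Illustrative sketch (not part of the repo): why the Kolen-Pollack and
# weight-mirror updates both align the backward matrix B with W^T.
import numpy as np

rng = np.random.RandomState(0)
n_out, n_in = 20, 30
W = rng.randn(n_out, n_in)   # forward weights
B = rng.randn(n_in, n_out)   # backward weights, initially random

def angle_deg(A, C):
    """Angle in degrees between two matrices viewed as flattened vectors."""
    a, c = A.ravel(), C.ravel()
    cos = a @ c / (np.linalg.norm(a) * np.linalg.norm(c))
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

# Kolen-Pollack-style updates (cf. the FCNN_KP train loop): both matrices get
# the same decay and the same gradient term (transposed for B), so the
# difference B - W^T shrinks by the decay factor at every step.
decay, lr = 0.9, 0.1
Wk, Bk = W.copy(), B.copy()
for _ in range(100):
    G = rng.randn(n_out, n_in)         # stand-in for a gradient dL/dW
    Wk = decay * Wk - lr * G
    Bk = decay * Bk - lr * G.T
print("KP  angle(B, W^T): {:.2f} deg".format(angle_deg(Bk, Wk.T)))  # ~0

# Weight-mirror-style updates (cf. FCNN_WM.mirror): feed zero-mean iid noise x
# through the layer; for small noise y is roughly W x, so the Hebbian increment
# x y^T has expectation proportional to W^T, and B aligns over many batches.
Bm = np.zeros_like(B)
for _ in range(2000):
    x = 0.1 * (rng.rand(n_in, 1) - 0.5)
    y = W @ x                          # linearized layer response
    Bm += 0.01 * (x @ y.T)
print("WM  angle(B, W^T): {:.2f} deg".format(angle_deg(Bm, W.T)))   # small
```

The two halves need different ingredients: KP needs no noise phase because the shared decay does the aligning regardless of the gradients, while the mirror needs no gradients because zero-mean iid noise makes the Hebbian increment an unbiased estimate of W^T up to a scale factor.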
/figures/cifar10/weight_angles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/figures/cifar10/weight_angles.png
--------------------------------------------------------------------------------
/figures/mnist/delta_angles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/figures/mnist/delta_angles.png
--------------------------------------------------------------------------------
/figures/mnist/loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/figures/mnist/loss.png
--------------------------------------------------------------------------------
/figures/mnist/test_accuracies.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/figures/mnist/test_accuracies.png
--------------------------------------------------------------------------------
/figures/mnist/weight_angles.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/makrout/Deep-Learning-without-Weight-Transport/688f47addd2131684da1b7829b20365f585eee66/figures/mnist/weight_angles.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
  1 | from fcnn.FCNN_BP import FCNN_BP
  2 | from fcnn.FCNN_FA import FCNN_FA
  3 | from fcnn.FCNN_WM import FCNN_WM
  4 | from fcnn.FCNN_KP import FCNN_KP
  5 | from keras.utils.np_utils import to_categorical
  6 | from keras.datasets import cifar10
  7 | from keras.datasets import mnist
  8 | import numpy as np
  9 | import os
 10 | import argparse
 11 | import matplotlib
 12 | matplotlib.use('Agg')
 13 | 
 14 | 
 15 | def parse_args():
 16 |     '''
 17 |     Parse arguments from command line input
 18 |     '''
 19 |     parser = argparse.ArgumentParser(description='Training parameters')
 20 |     parser.add_argument('--dataset', type=str, default='mnist', help="The dataset, either `mnist` or `cifar10`", choices=['mnist', 'cifar10'])
 21 |     parser.add_argument('--algo', type=str, default='bp', help="The training algorithm", choices=['bp', 'fa', 'wm', 'kp'])
 22 |     parser.add_argument('--n_epochs', type=int, default=400, help="The number of epochs")
 23 |     parser.add_argument('--size_hidden_layers', type=int, nargs='+', default=[1000, 200], help="The number of hidden neurons per layer")
 24 |     parser.add_argument('--batch_size', type=int, default=128, help="The training batch size")
 25 |     parser.add_argument('--learning_rate', type=float, default=0.2, help="The learning rate")
 26 |     parser.add_argument('--test_frequency', type=int, default=1, help="The number of epochs after which the model is tested")
 27 |     parser.add_argument('--save_dir', type=str, default='./experiments', help="The folder path to save the experimental config, logs and model")
 28 |     parser.add_argument('--seed', type=int, default=1111, help='random seed for NumPy')
 29 |     args, unknown = parser.parse_known_args()
 30 |     return args
 31 | 
 32 | 
 33 | def preprocess(dataset,
normalize=True):
 34 |     '''
 35 |     Description: helper function to load and preprocess the dataset
 36 |     Params: dataset = the dataset name, i.e. `mnist` or `cifar10`
 37 |             normalize = a boolean to specify if the dataset should be normalized
 38 |     Outputs: Pre-processed image features and labels
 39 |     '''
 40 | 
 41 |     if dataset == 'mnist':
 42 |         (X_train, Y_train), (X_test, Y_test) = mnist.load_data()
 43 |         X_train = np.reshape(X_train, (60000, 784))
 44 |         X_test = np.reshape(X_test, (10000, 784))
 45 |         X_train = X_train.astype('float32')
 46 |         X_test = X_test.astype('float32')
 47 | 
 48 |         Y_train = to_categorical(Y_train, num_classes=10)
 49 |         Y_test = to_categorical(Y_test, num_classes=10)
 50 | 
 51 |     elif dataset == 'cifar10':
 52 |         (X_train, Y_train), (X_test, Y_test) = cifar10.load_data()
 53 |         X_train = np.reshape(X_train, (50000, 3072))
 54 |         X_test = np.reshape(X_test, (10000, 3072))
 55 |         X_train = X_train.astype('float32')
 56 |         X_test = X_test.astype('float32')
 57 | 
 58 |         Y_train = to_categorical(Y_train, num_classes=10)
 59 |         Y_test = to_categorical(Y_test, num_classes=10)
 60 | 
 61 |     # Normalization of pixel values to [0-1] range
 62 |     if normalize:
 63 |         X_train /= 255
 64 |         X_test /= 255
 65 |     return (X_train, Y_train), (X_test, Y_test)
 66 | 
 67 | 
 68 | def main():
 69 |     # Parse arguments
 70 |     args = parse_args()
 71 | 
 72 |     # Set the random seed manually for reproducibility.
 73 |     np.random.seed(args.seed)
 74 | 
 75 |     # Use flags passed to the script to make the name for the experimental dir
 76 |     experiment_path = os.path.join(args.save_dir, args.dataset, args.algo)
 77 | 
 78 |     print('\n########## Setting Up Experiment ######################')
 79 |     # Increment a counter so that previous results with the same args will not be overwritten.
 80 |     i = 0
 81 |     while os.path.exists(experiment_path + "-V" + str(i)):
 82 |         i += 1
 83 |     experiment_path = experiment_path + "-V" + str(i)
 84 | 
 85 |     # Create an experimental directory and dump all the args to a text file
 86 |     os.makedirs(experiment_path)
 87 |     print("\nPutting log in {}".format(experiment_path))
 88 |     with open(os.path.join(experiment_path, 'experiment_config.txt'), 'w') as f:
 89 |         for arg, value in vars(args).items():
 90 |             f.write(arg + ' ' + str(value) + '\n')
 91 | 
 92 |     # load and preprocess the dataset
 93 |     (X_train, Y_train), (X_test, Y_test) = preprocess(args.dataset)
 94 | 
 95 |     # Add the size of the input and output layer depending on the dataset
 96 |     size_layers = [X_train.shape[1]] + args.size_hidden_layers + [Y_train.shape[1]]
 97 | 
 98 |     log_str = ("\n" + "=" * 20 + "\n") + \
 99 |               "Running the code with the dataset {}:\n".format(args.dataset) + \
100 |               "\tlearning algorithm: {}\n".format(args.algo) + \
101 |               "\tbatch_size: {}\n".format(args.batch_size) + \
102 |               "\tlearning rate: {}\n".format(args.learning_rate) + \
103 |               "\tn_epochs: {}\n".format(args.n_epochs) + \
104 |               "\ttest frequency: {}\n".format(args.test_frequency) + \
105 |               "\tsize_layers: {}\n".format(size_layers) + \
106 |               "=" * 20 + "\n"
107 |     print(log_str)
108 |     with open(os.path.join(experiment_path, 'log.txt'), 'a') as f_:
109 |         f_.write(log_str + '\n')
110 | 
111 |     # Select the network with the chosen learning algorithm to run
112 |     if args.algo == "bp":
113 |         model = FCNN_BP(size_layers, experiment_path)
114 |     elif args.algo == "fa":
115 |         model = FCNN_FA(size_layers, experiment_path)
116 |     elif args.algo == "wm":
117 |         model = FCNN_WM(size_layers, experiment_path)
118 |     elif args.algo == "kp":
119 |         model = FCNN_KP(size_layers, experiment_path)
120 | 
121 |     # Run the training
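122 |     # train() loops over the epochs: each epoch draws n_batches mini-batches,
123 |     # backpropagates them through the chosen feedback pathway, applies the SGD
124 |     # updates, evaluates accuracy on the test set, and finally dumps all the
125 |     # logged metrics to results.json inside the experiment directory.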
126 |     model.train(X_train, Y_train, X_test, Y_test, args.batch_size, args.learning_rate, args.n_epochs, args.test_frequency)
127 | 
128 | 
129 | if __name__ == '__main__':
130 |     main()
131 | 
--------------------------------------------------------------------------------
/plot_figures.py:
--------------------------------------------------------------------------------
  1 | import matplotlib.pyplot as plt
  2 | import os
  3 | import json
  4 | import numpy as np
  5 | import matplotlib
  6 | matplotlib.use('Agg')
  7 | 
  8 | 
  9 | def leafcounter(node):
 10 |     '''
 11 |     Description: helper function to count the number of leaf lists in a nested
 12 |                  dictionary. Knowing how many curves there are to plot determines
 13 |                  the number of distinct colors to create
 14 |     Params: node = the dictionary containing the saved data
 15 |     Outputs: the number of lists to plot
 16 |     '''
 17 |     if isinstance(node, dict):
 18 |         return sum([leafcounter(node[n]) for n in node])
 19 |     else:
 20 |         return 1
 21 | 
 22 | 
 23 | def get_spaced_colors(n):
 24 |     '''
 25 |     Description: generate n equally spaced colors
 26 |     Params: n = the number of colors to generate
 27 |     Outputs: a generator yielding the n generated colors
 28 |     '''
 29 |     cm = plt.get_cmap('gist_rainbow')
 30 |     for i in range(n):
 31 |         yield cm(1. * i / n)
 32 | 
 33 | 
 34 | def plot_list_variables(data, dataset_name, x_axis_name, y_axis_name, save_dir):
 35 |     '''
 36 |     Description: plot a non-nested dictionary of data, one curve per algorithm
 37 |     Params: data = the dictionary to plot
 38 |             dataset_name = the name of the dataset
 39 |             x_axis_name = the name of the x-axis
 40 |             y_axis_name = the name of the y-axis
 41 |             save_dir = the directory where to save the figures
 42 |     Outputs: saves the figure as `<y_axis_name>.png` in save_dir
 43 |     '''
 44 |     colors = get_spaced_colors(len(data.keys()))
 45 |     for algo_name, algo_data in data.items():
 46 |         plt.plot(list(np.arange(len(algo_data))), algo_data, color=next(colors))
 47 |     plt.legend(list(data.keys()))
 48 | 
 49 |     plt.xlabel(x_axis_name)
 50 |     plt.ylabel(y_axis_name)
 51 |     plt.ylim((0, 100))
 52 |     plt.title('{} on {}'.format(y_axis_name.replace("_", " "), dataset_name))
 53 |     plt.savefig('{}/{}.png'.format(save_dir, y_axis_name.replace(" ", "_")))
 54 |     plt.clf()
 55 | 
 56 | 
 57 | def plot_dict_variables(data, dataset_name, x_axis_name, y_axis_name, save_dir):
 58 |     '''
 59 |     Description: plot a nested dictionary of data, one curve per algorithm and layer
 60 |     Params: data = the dictionary to plot
 61 |             dataset_name = the name of the dataset
 62 |             x_axis_name = the name of the x-axis
 63 |             y_axis_name = the name of the y-axis
 64 |             save_dir = the directory where to save the figures
 65 |     Outputs: saves the figure as `<y_axis_name>.png` in save_dir
 66 |     '''
 67 |     colors = get_spaced_colors(leafcounter(data))
 68 |     labels = []
 69 |     for algo_name, algo_data in data.items():
 70 |         for layer, layer_data in algo_data.items():
 71 |             plt.plot(list(np.arange(len(layer_data))), layer_data, color=next(colors))
 72 |             labels.append("{} {}".format(algo_name, layer))
 73 |     plt.legend(labels)
 74 | 
 75 |     plt.xlabel(x_axis_name)
 76 |     plt.ylabel(y_axis_name)
 77 |     plt.ylim((0, 120))
 78 |     plt.title('{} on {}'.format(y_axis_name.replace("_", " "), dataset_name))
 79 |     plt.savefig('{}/{}.png'.format(save_dir, y_axis_name.replace(" ", "_")))
 80 |     plt.clf()
 81 | 
 82 | 
 83 | def read_data():
 84 |     '''
 85 |     Description: read the json result files of all experiments
 86 |     Outputs: a dictionary containing the contents of all the json files,
 87 |              keyed by dataset and algorithm
 88 |     '''
 89 |     json_files = {}
 90 |     for subdir, dirs, files in os.walk('./experiments'):
 91 |         for file in files:
 92 |             if file.endswith('.json'):
 93 |                 with open(os.path.join(subdir, file), 'r') as f:
 94 |                     algo = subdir.split("/")[-1].split('-')[0]
 95 |                     dataset = subdir.split("/")[-2]
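 96 |                     # paths look like ./experiments/<dataset>/<algo>-V<version>/results.json
 97 |                     # (cf. main.py), so the algo is the directory name before "-V"
 98 |                     # and the dataset is its parent directory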
 99 |                     if dataset not in json_files:
100 |                         json_files[dataset] = {}
101 |                     json_files[dataset][algo] = json.load(f)
102 |     return json_files
103 | 
104 | 
105 | def generate_dataset_figures(json_files, dataset_name, save_dir):
106 |     '''
107 |     Description: generate the figures of all experiments on one dataset
108 |     Params: json_files = the dictionary of results of all experiments on this dataset
109 |             dataset_name = the name of the dataset
110 |             save_dir = the directory where to save the figures
111 |     Outputs: saves the `loss`, `test_accuracies`, `delta_angles` and
112 |              `weight_angles` figures in save_dir
113 |     '''
114 |     # Initialize variables to plot
115 |     losses = {algo: [] for algo in json_files.keys()}
116 |     test_accuracies = {algo: [] for algo in json_files.keys()}
117 |     delta_angles = {algo: {} for algo in json_files.keys() if algo != 'bp'}
118 |     weight_angles = {algo: {} for algo in json_files.keys() if algo != 'bp'}
119 | 
120 |     for algo_name, algo_data in json_files.items():
121 | 
122 |         for _, epoch_data in algo_data.items():
123 |             losses[algo_name].append(epoch_data['loss'])
124 |             test_accuracies[algo_name].append(epoch_data['test_accuracy'])
125 |             if algo_name != 'bp':
126 |                 for layer, angle in epoch_data['delta_angles'].items():
127 |                     delta_angles[algo_name].setdefault(layer, []).append(angle)
128 | 
129 |                 for layer, angle in epoch_data['weight_angles'].items():
130 |                     weight_angles[algo_name].setdefault(layer, []).append(angle)
131 | 
132 |     # plot the `loss`, `test_accuracy`, `delta_angles` and `weight_angles`
133 |     plot_list_variables(losses, dataset_name, 'epochs', 'loss', save_dir)
134 |     plot_list_variables(test_accuracies, dataset_name, 'epochs', 'test_accuracies', save_dir)
135 |     plot_dict_variables(delta_angles, dataset_name, 'epochs', 'delta_angles', save_dir)
136 |     plot_dict_variables(weight_angles, dataset_name, 'epochs', 'weight_angles', save_dir)
137 | 
138 | 
139 | if __name__ == '__main__':
140 |     json_files = read_data()
141 | 
142 |     # Generate figures
143 |     for dataset_name, data in json_files.items():
144 |         save_dir = os.path.join('figures', dataset_name)
145 |         if not os.path.exists(save_dir):
146 |             os.makedirs(save_dir)
147 |         generate_dataset_figures(data, dataset_name, save_dir)
148 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | flake8
 2 | autoflake
 3 | pep8
 4 | autopep8
 5 | 
 6 | numpy
 7 | matplotlib
 8 | colormap
 9 | tensorflow
10 | keras
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
 1 | if [ "$1" == "bp" ]
 2 | then
 3 |     python main.py --dataset=mnist --algo=bp --n_epochs=400 --size_hidden_layers 500 --batch_size=128 --learning_rate=0.2 --test_frequency=1
 4 |     #python main.py --dataset=cifar10 --algo=bp --n_epochs=400 --size_hidden_layers 1000 --batch_size=128 --learning_rate=0.2 --test_frequency=1
 5 | elif [ "$1" == "fa" ]
 6 | then
 7 |     python main.py --dataset=mnist --algo=fa --n_epochs=400 --size_hidden_layers 500 --batch_size=128 --learning_rate=0.2 --test_frequency=1
 8 |     #python main.py --dataset=cifar10 --algo=fa --n_epochs=400 --size_hidden_layers 1000 --batch_size=128 --learning_rate=0.2 --test_frequency=1
 9 | elif [ "$1" == "wm" ]
10 | then
11 |     python main.py --dataset=mnist --algo=wm --n_epochs=400 --size_hidden_layers 500 --batch_size=128 --learning_rate=0.05 --test_frequency=1
12 |     #python main.py --dataset=cifar10 --algo=wm --n_epochs=400 --size_hidden_layers 1000 --batch_size=128 --learning_rate=0.05 --test_frequency=1
13 | elif [ "$1" == "kp" ]
14 | then
15 |     python main.py --dataset=mnist --algo=kp --n_epochs=400 --size_hidden_layers 500 --batch_size=128 --learning_rate=0.3 --test_frequency=1
16 |     #python main.py --dataset=cifar10 --algo=kp --n_epochs=400 --size_hidden_layers 1000 --batch_size=128 --learning_rate=0.3 --test_frequency=1
17 | else
18 |     # single quotes inside the message: backticks inside double quotes would be
19 |     # executed by the shell as command substitutions
20 |     echo "Invalid input argument. Valid ones are either 'bp', 'fa', 'wm' or 'kp'."
21 |     exit 1
22 | fi
23 | 
--------------------------------------------------------------------------------
/script/autolint:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | . script/env
3 | set -xe
4 | autopep8 --in-place --recursive --max-line-length 1000 --aggressive .
5 | autoflake --in-place --remove-unused-variables --remove-all-unused-imports --recursive .
--------------------------------------------------------------------------------
/script/env:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | VENV_DIR=".venv"
3 | if [ -z "$CI" ]; then
4 |     if [ ! -d "$VENV_DIR" ]; then
5 |         virtualenv "$VENV_DIR" -p python3.6
6 |     fi
7 |     source "$VENV_DIR/bin/activate"
8 | fi
--------------------------------------------------------------------------------
/script/up:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | . script/env
3 | 
4 | pip install --upgrade pip
5 | pip install -r requirements.txt
--------------------------------------------------------------------------------
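For orientation, here is a sketch of the `results.json` layout that the `train()` methods above write and that `plot_figures.py` walks and plots. The numbers are placeholders and the per-layer keys are hypothetical; the real keys come from `evaluate_angles()`, which is defined elsewhere in the repo:

```python
# Sketch of the results.json structure written by train() and read by
# plot_figures.py. All values below are placeholders, and the per-layer
# key names ("layer_1", ...) are hypothetical.
results = {
    "epoch_0": {
        "loss": 0.42,                                          # mean mini-batch loss
        "test_accuracy": 91.3,                                 # in percent
        "delta_angles": {"layer_1": 63.0, "layer_2": 71.5},    # per-layer delta angles
        "weight_angles": {"layer_1": 88.2, "layer_2": 89.7},   # per-layer weight angles
    },
    # ... one such entry per epoch: "epoch_1", "epoch_2", etc.
    # (backprop runs store only "loss" and "test_accuracy" in the plots above)
}
```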