├── L4 ├── __init__.py └── L4.py ├── .gitmodules ├── setup.py ├── LICENSE ├── README.md └── L4_MNIST.ipynb /L4/__init__.py: -------------------------------------------------------------------------------- 1 | from .L4 import L4Mom, L4Adam 2 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "l4-pytorch"] 2 | path = l4-pytorch 3 | url = https://github.com/iovdin/l4-pytorch.git 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup(name='L4', 4 | version='1.0', 5 | description='TensorFlow implementation of L4 stepsize adaptation', 6 | url='https://github.com/martius-lab/l4-optimizer', 7 | author='Michal Rolinek, MPI-IS Tuebingen, Autonomous Learning', 8 | author_email='michalrolinek@gmail.com', 9 | license='MIT', 10 | packages=['L4'], 11 | install_requires=[], 12 | zip_safe=False) 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Autonomous Learning Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # L4 Stepsize Adaptation Scheme 2 | 3 | By [Michal Rolínek](https://scholar.google.de/citations?user=DVdSTFQAAAAJ&hl=en), [Georg Martius](http://georg.playfulmachines.com/). 4 | 5 | [Autonomous Learning Group](https://al.is.tuebingen.mpg.de/), [Max Planck Institute for Intelligent Systems](https://is.tuebingen.mpg.de/). 6 | 7 | ## Table of Contents 8 | 0. [Introduction](#introduction) 9 | 0. [Requirements](#requirements) 10 | 0. [Installation](#installation) 11 | 0. [Usage](#usage) 12 | 0. [Notes](#notes) 13 | 14 | 15 | 16 | ## Introduction 17 | 18 | This repository contains TensorFlow implementation code for the paper ["L4: Practical loss-based stepsize adaptation for deep learning"](https://arxiv.org/abs/1802.05074). This work proposes an explicit rule for stepsize adaptation on top of existing optimizers such as Adam or momentum SGD. 19 | 20 | *Disclaimer*: This code is a PROTOTYPE and most likely contains bugs. It should work with most TensorFlow models, but it most likely does not comply with TensorFlow production code standards. Use at your own risk. 21 | 22 | ## Requirements 23 | 24 | TensorFlow >= 1.4. 25 | 26 | For a PyTorch implementation of L4, see [l4-pytorch](https://github.com/iovdin/l4-pytorch) (also linked here as a submodule).
27 | 28 | ## Installation 29 | 30 | Either install the package with one of the following pip commands, 31 | 32 | ``` 33 | python -m pip install git+https://github.com/martius-lab/l4-optimizer 34 | ``` 35 | 36 | 37 | ``` 38 | python3 -m pip install git+https://github.com/martius-lab/l4-optimizer 39 | ``` 40 | 41 | or simply drop the L4/L4.py file into your project directory. 42 | 43 | ## Usage 44 | 45 | Exactly as you would expect from a TensorFlow optimizer. Empirically, good values for the 'fraction' parameter are 0.1 < fraction < 0.3, where 0.15 is the default and should work well enough in most cases. Decreasing 'fraction' is typically a good idea in the case of a small batch size or, more generally, when the gradients carry very little signal. Values of 'fraction' that are too high behave similarly to learning rates that are too high (i.e. divergence or very early plateauing). 46 | ```python 47 | import L4 48 | 49 | ... 50 | opt = L4.L4Adam(fraction=0.20) 51 | opt.minimize(loss) 52 | ... 53 | ``` 54 | 55 | or 56 | 57 | ```python 58 | import L4 59 | 60 | ... 61 | opt = L4.L4Mom() # default value fraction=0.15 is used 62 | grads_and_vars = opt.compute_gradients(loss) 63 | ... 64 | # Gradient manipulation 65 | ... 66 | 67 | opt.apply_gradients(grads_and_vars) # (!) Passing the loss is no longer needed (!) 68 | ... 69 | ``` 70 | 71 | ## Notes 72 | 73 | *Contribute*: If you spot a bug or some incompatibility, contribute via a pull request! Thank you!
74 | -------------------------------------------------------------------------------- /L4/L4.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018, (Rolinek, Martius) 2 | # MIT License Agreement 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of this software 4 | # and associated documentation files (the "Software"), to deal in the Software without restriction, including without 5 | # limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 6 | # Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 7 | # 8 | # The above copyright notice and this permission notice shall be included in all copies or substantial portions of 9 | # the Software. 10 | # 11 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE 12 | # WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS 13 | # OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 14 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
15 | """Implementation of L4 stepsize adaptation scheme proposed in (Rolinek, Martius, 2018)""" 16 | 17 | import tensorflow as tf 18 | 19 | 20 | def n_inner_product(list_of_tensors1, list_of_tensors2): 21 | return tf.add_n([tf.reduce_sum(t1*t2) for t1, t2 in zip(list_of_tensors1, list_of_tensors2)]) 22 | 23 | 24 | def time_factor(time_step): 25 | """ Routine used for bias correction in exponential moving averages, as in (Kingma, Ba, 2015) """ 26 | global_step = 1 + tf.train.get_or_create_global_step() 27 | decay = 1.0 - 1.0 / time_step 28 | return 1.0 - tf.exp((tf.cast(global_step, tf.float32)) * tf.log(decay)) 29 | 30 | 31 | class AdamTransform(object): 32 | """ 33 | Class implementing Adam (Kingma, Ba 2015) transform of the gradient. 34 | """ 35 | def __init__(self, time_scale_grad=10.0, time_scale_var=1000.0, epsilon=1e-4): 36 | self.time_scale_grad = time_scale_grad 37 | self.time_scale_var = time_scale_var 38 | self.epsilon = epsilon 39 | 40 | self.EMAgrad = tf.train.ExponentialMovingAverage(decay=1.0 - 1.0 / self.time_scale_grad) 41 | self.EMAvar = tf.train.ExponentialMovingAverage(decay=1.0 - 1.0 / self.time_scale_var) 42 | 43 | def __call__(self, grads): 44 | shadow_op_gr = self.EMAgrad.apply(grads) 45 | vars = [tf.square(grad) for grad in grads] 46 | shadow_op_var = self.EMAvar.apply(vars) 47 | 48 | with tf.control_dependencies([shadow_op_gr, shadow_op_var]): 49 | correction_term_1 = time_factor(self.time_scale_grad) 50 | avg_grads = [self.EMAgrad.average(grad) / correction_term_1 for grad in grads] 51 | 52 | correction_term_2 = time_factor(self.time_scale_var) 53 | avg_vars = [self.EMAvar.average(var) / correction_term_2 for var in vars] 54 | return [(grad / (tf.sqrt(var) + self.epsilon)) for grad, var in zip(avg_grads, avg_vars)] 55 | 56 | 57 | class MomentumTransform(object): 58 | """ 59 | Class implementing momentum transform of the gradient (here in the form of exponential moving average) 60 | """ 61 | def __init__(self, time_momentum=10.0): 62 | 
self.time_momentum = time_momentum 63 | self.EMAgrad = tf.train.ExponentialMovingAverage(decay=1.0-1.0/self.time_momentum) 64 | 65 | def __call__(self, grads): 66 | shadow_op_gr = self.EMAgrad.apply(grads) 67 | with tf.control_dependencies([shadow_op_gr]): 68 | correction_term = time_factor(self.time_momentum) 69 | new_grads = [self.EMAgrad.average(grad) / correction_term for grad in grads] 70 | return [tf.identity(grad) for grad in new_grads] 71 | 72 | 73 | string_to_transform = {'momentum': MomentumTransform, 74 | 'adam': AdamTransform} 75 | 76 | 77 | class L4General(tf.train.GradientDescentOptimizer): 78 | """ 79 | Class implementing the general L4 stepsize adaptation scheme as a TensorFlow optimizer. The method for applying 80 | gradients and minimizing a variable are implemented. Note that apply_gradients expects loss as an input parameter. 81 | """ 82 | def __init__(self, fraction=0.15, minloss_factor=0.9, init_factor=0.75, 83 | minloss_forget_time=1000.0, epsilon=1e-12, 84 | gradient_estimator='momentum', gradient_params=None, 85 | direction_estimator='adam', direction_params=None): 86 | """ 87 | :param fraction: [alpha], fraction of 'optimal stepsize' 88 | :param minloss_factor: [gamma], fraction of min seen loss that is considered achievable 89 | :param init_factor: [gamma_0], fraction of initial loss used to initialize L_min 90 | :param minloss_forget_time: [Tau], timescale for forgetting minimum seen loss 91 | :param epsilon: [epsilon], for numerical stability in the division 92 | :param gradient_estimator: [g], a gradient method to be used for gradient estimation 93 | :param gradient_params: dictionary of parameters to pass to gradient_estimator 94 | :param direction_estimator: [v], a gradient method used for update direction 95 | :param direction_params: dictionary of parameters to pass to direction_estimator 96 | """ 97 | tf.train.GradientDescentOptimizer.__init__(self, 1.0) 98 | with tf.variable_scope('L4Optimizer', reuse=tf.AUTO_REUSE): 99 | 
self.min_loss = tf.get_variable(name='min_loss', shape=(), 100 | initializer=tf.constant_initializer(0.0), trainable=False) 101 | self.fraction = fraction 102 | self.minloss_factor = minloss_factor 103 | self.minloss_increase_rate = 1.0 + 1.0 / minloss_forget_time 104 | self.epsilon = epsilon 105 | self.init_factor = init_factor 106 | 107 | if not direction_params: 108 | direction_params = {} 109 | if not gradient_params: 110 | gradient_params = {} 111 | 112 | self.grad_direction = string_to_transform[direction_estimator](**direction_params) 113 | self.deriv_estimate = string_to_transform[gradient_estimator](**gradient_params) 114 | 115 | def compute_gradients(self, loss, *args, **kwargs): 116 | self.loss = loss 117 | return super(L4General, self).compute_gradients(loss, *args, **kwargs) 118 | 119 | def apply_gradients(self, grads_and_vars, global_step=None, name=None): 120 | if not global_step: 121 | global_step = tf.train.get_or_create_global_step() 122 | # Filter variables without a gradient. 
123 | grads_and_vars = [(grad, var) for grad, var in grads_and_vars if grad is not None] 124 | 125 | grads, vars = zip(*grads_and_vars) 126 | 127 | ml_newval = tf.cond(tf.equal(global_step, 0), lambda: self.init_factor*self.loss, 128 | lambda: tf.minimum(self.min_loss, self.loss)) 129 | ml_update = self.min_loss.assign(ml_newval) 130 | 131 | with tf.control_dependencies([ml_update]): 132 | directions = self.grad_direction(grads) 133 | derivatives = self.deriv_estimate(grads) 134 | 135 | min_loss_to_use = self.minloss_factor * self.min_loss 136 | l_rate = self.fraction*(self.loss - min_loss_to_use) / (n_inner_product(directions, derivatives)+self.epsilon) 137 | new_grads = [direction*l_rate for direction in directions] 138 | tf.summary.scalar('effective_learning_rate', l_rate) 139 | tf.summary.scalar('min_loss_estimate', self.min_loss) 140 | ml_update2 = self.min_loss.assign(self.minloss_increase_rate * self.min_loss) 141 | 142 | with tf.control_dependencies([ml_update2]): 143 | return tf.train.GradientDescentOptimizer.apply_gradients(self, zip(new_grads, vars), global_step, name) 144 | 145 | def minimize(self, loss, global_step=None, var_list=None, name=None): 146 | if not var_list: 147 | var_list = tf.trainable_variables() 148 | 149 | grads_and_vars = self.compute_gradients(loss, var_list) 150 | return self.apply_gradients(grads_and_vars, global_step, name) 151 | 152 | 153 | class L4Adam(L4General): 154 | """ 155 | Specialization of the L4 stepsize adaptation with Adam used for gradient updates and Mom for gradient estimation. 
156 | """ 157 | def __init__(self, fraction=0.15, minloss_factor=0.9, init_factor=0.75, minloss_forget_time=1000.0, 158 | epsilon=1e-12, adam_params=None): 159 | L4General.__init__(self, fraction, minloss_factor, init_factor, minloss_forget_time, 160 | epsilon, gradient_estimator='momentum', direction_estimator='adam', 161 | direction_params=adam_params) 162 | 163 | 164 | class L4Mom(L4General): 165 | """ 166 | Specialization of the L4 stepsize adaptation with Mom used for both gradient estimation and an update direction. 167 | """ 168 | def __init__(self, fraction=0.15, minloss_factor=0.9, init_factor=0.75, minloss_forget_time=1000.0, 169 | epsilon=1e-12, mom_params=None): 170 | L4General.__init__(self, fraction, minloss_factor, init_factor, minloss_forget_time, 171 | epsilon, gradient_estimator='momentum', direction_estimator='momentum', 172 | direction_params=mom_params) 173 | -------------------------------------------------------------------------------- /L4_MNIST.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# L4 stepsize adaptation performance on MNIST" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "This short notebook contains a minimum working example of L4 optimizers [Rolinek, Martius 2018](https://arxiv.org/abs/1802.05074) performing on the classical MNIST dataset.\n", 15 | "If you find bugs or make improvements: [pull requests welcome](https://github.com/martius-lab/l4-optimizer/). 
" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Imports" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 2, 28 | "metadata": { 29 | "ExecuteTime": { 30 | "end_time": "2018-02-20T17:10:31.336385Z", 31 | "start_time": "2018-02-20T17:10:31.325392Z" 32 | } 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "from __future__ import absolute_import\n", 37 | "from __future__ import division\n", 38 | "from __future__ import print_function\n", 39 | "\n", 40 | "import tensorflow as tf\n", 41 | "import numpy as np\n", 42 | "\n", 43 | "from tensorflow.examples.tutorials.mnist import input_data\n", 44 | "from tensorflow.contrib import layers\n", 45 | "\n", 46 | "import L4.L4 as L4\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Network structure" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": { 60 | "ExecuteTime": { 61 | "end_time": "2018-02-20T17:10:36.230497Z", 62 | "start_time": "2018-02-20T17:10:36.215673Z" 63 | } 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "def mlp(x, hidden=(300,100), num_output=10):\n", 68 | " in_dim = x.get_shape().as_list()[1]\n", 69 | " y_layer = x\n", 70 | " for l,n in enumerate(hidden):\n", 71 | " W = tf.get_variable(\"W{}\".format(l), [in_dim, n],\n", 72 | " initializer=layers.xavier_initializer())\n", 73 | " b = tf.get_variable(\"b{}\".format(l), [n],\n", 74 | " initializer=tf.zeros_initializer())\n", 75 | " y_layer = tf.nn.relu(tf.matmul(y_layer, W) + b)\n", 76 | " in_dim = n\n", 77 | " W = tf.get_variable(\"W_final\", [in_dim, num_output],\n", 78 | " initializer=tf.zeros_initializer())\n", 79 | " b = tf.get_variable(\"b_final\", [num_output],\n", 80 | " initializer=tf.zeros_initializer())\n", 81 | " y = tf.matmul(y_layer, W) + b\n", 82 | " return y" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## Training parameters" 
90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 4, 95 | "metadata": { 96 | "ExecuteTime": { 97 | "end_time": "2018-02-20T17:10:37.827488Z", 98 | "start_time": "2018-02-20T17:10:37.824356Z" 99 | } 100 | }, 101 | "outputs": [], 102 | "source": [ 103 | "config = {'data_dir': 'data',\n", 104 | " 'epochs': 35,\n", 105 | " 'batch_size': 64,\n", 106 | " 'hidden': [300, 100],\n", 107 | " 'epochs_per_report': 1}\n" 108 | ] 109 | }, 110 | { 111 | "cell_type": "markdown", 112 | "metadata": {}, 113 | "source": [ 114 | "## Computational graph setup" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 5, 120 | "metadata": { 121 | "ExecuteTime": { 122 | "end_time": "2018-02-20T17:10:45.795608Z", 123 | "start_time": "2018-02-20T17:10:44.327630Z" 124 | } 125 | }, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n", 132 | "Extracting data/train-images-idx3-ubyte.gz\n", 133 | "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n", 134 | "Extracting data/train-labels-idx1-ubyte.gz\n", 135 | "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n", 136 | "Extracting data/t10k-images-idx3-ubyte.gz\n", 137 | "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n", 138 | "Extracting data/t10k-labels-idx1-ubyte.gz\n" 139 | ] 140 | } 141 | ], 142 | "source": [ 143 | "MNIST_size = 60000\n", 144 | "mnist = input_data.read_data_sets(config['data_dir'], one_hot=True)\n", 145 | "\n", 146 | "x = tf.placeholder(tf.float32, [None, 784])\n", 147 | "y = mlp(x, hidden=config['hidden'], num_output=10)\n", 148 | "y_ = tf.placeholder(tf.float32, [None, 10])\n", 149 | "cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))\n", 150 | "\n", 151 | "correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))\n", 152 | "accuracy = 
tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "## Optimizer choice" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": { 165 | "ExecuteTime": { 166 | "end_time": "2018-02-19T23:18:29.654940", 167 | "start_time": "2018-02-19T23:18:29.646661" 168 | } 169 | }, 170 | "source": [ 171 | "Choose an optimizer: for L4 you have the choice between L4Adam and L4Mom. Without arguments the default parameters are used, e.g. fraction=0.15\n", 172 | "\n", 173 | "The hyperparameters (if specified) are the best for achieving minimal loss after ~32 epochs." 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 6, 179 | "metadata": { 180 | "ExecuteTime": { 181 | "end_time": "2018-02-20T17:10:51.073582Z", 182 | "start_time": "2018-02-20T17:10:50.943453Z" 183 | } 184 | }, 185 | "outputs": [], 186 | "source": [ 187 | "opt = L4.L4Adam(fraction=0.25) # (L4Adam)\n", 188 | "#opt = L4.L4Adam() # default value (L4Adam*)\n", 189 | "#opt = L4.L4Mom(fraction=0.25) # (L4Mom)\n", 190 | "#opt = tf.train.AdamOptimizer(0.001, epsilon=1e-4) # (Adam)\n", 191 | "#opt = tf.train.MomentumOptimizer(learning_rate=0.05, momentum=0.9) # (mSGD)\n", 192 | "#opt = tf.train.GradientDescentOptimizer(learning_rate=0.7) # (SGD)\n", 193 | "\n", 194 | "train_op = opt.minimize(cross_entropy)" 195 | ] 196 | }, 197 | { 198 | "cell_type": "markdown", 199 | "metadata": {}, 200 | "source": [ 201 | "## Training Session" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": 7, 207 | "metadata": { 208 | "ExecuteTime": { 209 | "end_time": "2018-02-20T17:12:45.146386Z", 210 | "start_time": "2018-02-20T17:10:59.951343Z" 211 | } 212 | }, 213 | "outputs": [ 214 | { 215 | "name": "stdout", 216 | "output_type": "stream", 217 | "text": [ 218 | "Epoch 0; Current Batch Loss: 2.3025853633880615\n", 219 | "Epoch 1; Current Batch Loss: 0.18231457471847534\n", 
220 | "Epoch 2; Current Batch Loss: 0.02866493910551071\n", 221 | "Epoch 3; Current Batch Loss: 0.07279851287603378\n", 222 | "Epoch 4; Current Batch Loss: 0.025335367769002914\n", 223 | "Epoch 5; Current Batch Loss: 0.002143927151337266\n", 224 | "Epoch 6; Current Batch Loss: 0.017017511650919914\n", 225 | "Epoch 7; Current Batch Loss: 0.00016670141485519707\n", 226 | "Epoch 8; Current Batch Loss: 6.159079475764884e-06\n", 227 | "Epoch 9; Current Batch Loss: 0.002517925575375557\n", 228 | "Epoch 10; Current Batch Loss: 5.092173978482606e-06\n", 229 | "Epoch 11; Current Batch Loss: 6.684324944217224e-06\n", 230 | "Epoch 12; Current Batch Loss: 1.1290010661468841e-05\n", 231 | "Epoch 13; Current Batch Loss: 2.963222868856974e-06\n", 232 | "Epoch 14; Current Batch Loss: 5.531650458578952e-06\n", 233 | "Epoch 15; Current Batch Loss: 2.998848458446446e-07\n", 234 | "Epoch 16; Current Batch Loss: 0.0\n", 235 | "Epoch 17; Current Batch Loss: 3.520362952258438e-07\n", 236 | "Epoch 18; Current Batch Loss: 9.313223969797946e-09\n", 237 | "Epoch 19; Current Batch Loss: 3.3713513403199613e-07\n", 238 | "Epoch 20; Current Batch Loss: 9.313223081619526e-09\n", 239 | "Epoch 21; Current Batch Loss: 1.8626450382086546e-09\n", 240 | "Epoch 22; Current Batch Loss: 0.0\n", 241 | "Epoch 23; Current Batch Loss: 0.0\n", 242 | "Epoch 24; Current Batch Loss: 0.0\n", 243 | "Epoch 25; Current Batch Loss: 1.8626450382086546e-09\n", 244 | "Epoch 26; Current Batch Loss: 0.0\n", 245 | "Epoch 27; Current Batch Loss: 0.0\n", 246 | "Epoch 28; Current Batch Loss: 0.0\n", 247 | "Epoch 29; Current Batch Loss: 0.0\n", 248 | "Epoch 30; Current Batch Loss: 0.0\n", 249 | "Epoch 31; Current Batch Loss: 0.0\n", 250 | "Epoch 32; Current Batch Loss: 0.0\n", 251 | "Epoch 33; Current Batch Loss: 0.0\n", 252 | "Epoch 34; Current Batch Loss: 0.0\n", 253 | "Epoch 35; Current Batch Loss: 0.0\n", 254 | "Test accuracy: 0.9839000105857849\n" 255 | ] 256 | } 257 | ], 258 | "source": [ 259 | "sess = 
tf.InteractiveSession() \n", 260 | "tf.global_variables_initializer().run()\n", 261 | "\n", 262 | "batches_per_epoch = (MNIST_size // config['batch_size'])\n", 263 | "batches_to_run = config['epochs'] * batches_per_epoch\n", 264 | "\n", 265 | "for b in range(batches_to_run+1): \n", 266 | " batch_xs, batch_ys = mnist.train.next_batch(config['batch_size'])\n", 267 | " _, loss = sess.run((train_op, cross_entropy), feed_dict={x: batch_xs, y_: batch_ys})\n", 268 | "\n", 269 | " if b % batches_per_epoch == 0:\n", 270 | " epoch_nr = b // batches_per_epoch\n", 271 | " if epoch_nr % config['epochs_per_report'] == 0:\n", 272 | " print(\"Epoch {}; Current Batch Loss: {}\".format(epoch_nr, loss))\n", 273 | "\n", 274 | "# Test trained model\n", 275 | "accuracy_value = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})\n", 276 | "print(\"Test accuracy: {}\".format(accuracy_value))" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [] 285 | } 286 | ], 287 | "metadata": { 288 | "kernelspec": { 289 | "display_name": "Python 3", 290 | "language": "python", 291 | "name": "python3" 292 | }, 293 | "language_info": { 294 | "codemirror_mode": { 295 | "name": "ipython", 296 | "version": 3 297 | }, 298 | "file_extension": ".py", 299 | "mimetype": "text/x-python", 300 | "name": "python", 301 | "nbconvert_exporter": "python", 302 | "pygments_lexer": "ipython3", 303 | "version": "3.5.2" 304 | } 305 | }, 306 | "nbformat": 4, 307 | "nbformat_minor": 2 308 | } 309 | --------------------------------------------------------------------------------