├── Readme.md
├── fmin_adam.m
└── images
    ├── regression_fit.png
    ├── regression_minibatches.png
    └── regression_scatter.png

/Readme.md:
--------------------------------------------------------------------------------
1 | # Adam optimiser
2 | This is a `Matlab` implementation of the Adam optimiser from Kingma and Ba [[1]], designed for stochastic gradient descent. It maintains estimates of the moments of the gradient independently for each parameter.
3 | 
4 | ## Usage
5 | `[x, fval, exitflag, output] = fmin_adam(fun, x0 <, stepSize, beta1, beta2, epsilon, nEpochSize, options>)`
6 | 
7 | `fmin_adam` is an implementation of the Adam optimisation algorithm (gradient descent with Adaptive learning rates individually on each parameter, with Momentum) from Kingma and Ba [[1]]. Adam is designed to work on stochastic gradient descent problems; i.e. when only small batches of data are used to estimate the gradient on each iteration, or when stochastic dropout regularisation is used [[2]].
8 | 
9 | ## Examples
10 | ### Simple regression problem with gradients
11 | 
12 | Set up a simple linear regression problem: ![$$$y = x\cdot\phi_1 + \phi_2 + \zeta$$$](https://latex.codecogs.com/svg.latex?%5Cinline%20y%20%3D%20x%5Ccdot%5Cphi_1%20+%20%5Cphi_2%20+%20%5Czeta), where ![$$$\zeta \sim N(0, 0.1)$$$](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Czeta%20%5Csim%20N%280%2C%200.1%29). We'll take ![$$$\phi = \left[3, 2\right]$$$](https://latex.codecogs.com/svg.latex?%5Cinline%20%5Cphi%20%3D%20%5Cleft%5B3%2C%202%5Cright%5D) for this example. Let's draw some samples from this problem:
13 | 
14 | ```matlab
15 | nDataSetSize = 1000;
16 | vfInput = rand(1, nDataSetSize);
17 | phiTrue = [3 2];
18 | fhProblem = @(phi, vfInput) vfInput .* phi(1) + phi(2);
19 | vfResp = fhProblem(phiTrue, vfInput) + randn(1, nDataSetSize) * .1;
20 | plot(vfInput, vfResp, '.'); hold;
21 | ```
22 | 
23 | 
24 | 
25 | Now we define a cost function to minimise, which returns analytical gradients:
26 | 
27 | ```matlab
28 | function [fMSE, vfGrad] = LinearRegressionMSEGradients(phi, vfInput, vfResp)
29 | % - Compute mean-squared error using the current parameter estimate
30 | vfRespHat = vfInput .* phi(1) + phi(2);
31 | vfDiff = vfRespHat - vfResp;
32 | fMSE = mean(vfDiff.^2) / 2;
33 | 
34 | % - Compute the gradient of MSE for each parameter
35 | vfGrad(1) = mean(vfDiff .* vfInput);
36 | vfGrad(2) = mean(vfDiff);
37 | end
38 | ```
39 | 
40 | Initial parameters `phi0` are Normally distributed. Call the `fmin_adam` optimiser with a learning rate of 0.01.
41 | 
42 | ```matlab
43 | phi0 = randn(2, 1);
44 | phiHat = fmin_adam(@(phi)LinearRegressionMSEGradients(phi, vfInput, vfResp), phi0, 0.01)
45 | plot(vfInput, fhProblem(phiHat, vfInput), '.');
46 | ```
47 | 
48 | Output:
49 | 
50 |     Iteration Func-count f(x) Improvement Step-size
51 |     ---------- ---------- ---------- ---------- ----------
52 |     2130 4262 0.0051 5e-07 0.00013
53 |     ---------- ---------- ---------- ---------- ----------
54 | 
55 |     Finished optimization.
56 |     Reason: Function improvement [5e-07] less than TolFun [1e-06].
57 | 
58 |     phiHat =
59 |         2.9498
60 |         2.0273
61 | 
62 | 
63 | 
64 | ### Linear regression with minibatches
65 | 
66 | Set up a simple linear regression problem, as above.
67 | 
68 | ```matlab
69 | nDataSetSize = 1000;
70 | vfInput = rand(1, nDataSetSize);
71 | phiTrue = [3 2];
72 | fhProblem = @(phi, vfInput) vfInput .* phi(1) + phi(2);
73 | vfResp = fhProblem(phiTrue, vfInput) + randn(1, nDataSetSize) * .1;
74 | ```
75 | 
76 | Configure minibatches.
Minibatches contain random sets of indices into the data. 77 | 78 | ```matlab 79 | nBatchSize = 50; 80 | nNumBatches = 100; 81 | mnBatches = randi(nDataSetSize, nBatchSize, nNumBatches); 82 | cvnBatches = mat2cell(mnBatches, nBatchSize, ones(1, nNumBatches)); 83 | figure; hold; 84 | cellfun(@(b)plot(vfInput(b), vfResp(b), '.'), cvnBatches); 85 | ``` 86 | 87 | 88 | Define the function to minimise; in this case, the mean-square error over the regression problem. The iteration index `nIter` defines which mini-batch to evaluate the problem over. 89 | 90 | ```matlab 91 | fhBatchInput = @(nIter) vfInput(cvnBatches{mod(nIter, nNumBatches-1)+1}); 92 | fhBatchResp = @(nIter) vfResp(cvnBatches{mod(nIter, nNumBatches-1)+1}); 93 | fhCost = @(phi, nIter) mean((fhProblem(phi, fhBatchInput(nIter)) - fhBatchResp(nIter)).^2); 94 | ``` 95 | Turn off analytical gradients for the `adam` optimiser, and ensure that we permit sufficient function calls. 96 | 97 | ```matlab 98 | sOpt = optimset('fmin_adam'); 99 | sOpt.GradObj = 'off'; 100 | sOpt.MaxFunEvals = 1e4; 101 | ``` 102 | 103 | Call the `fmin_adam` optimiser with a learning rate of `0.1`. Initial parameters are Normally distributed. 104 | 105 | ```matlab 106 | phi0 = randn(2, 1); 107 | phiHat = fmin_adam(fhCost, phi0, 0.1, [], [], [], [], sOpt) 108 | ``` 109 | The output of the optimisation process (which will differ over random data and random initialisations): 110 | 111 | Iteration Func-count f(x) Improvement Step-size 112 | ---------- ---------- ---------- ---------- ---------- 113 | 711 2848 0.3 0.0027 3.8e-06 114 | ---------- ---------- ---------- ---------- ---------- 115 | 116 | Finished optimization. 117 | Reason: Step size [3.8e-06] less than TolX [1e-05]. 118 | 119 | phiHat = 120 | 2.8949 121 | 1.9826 122 | 123 | ## Detailed usage 124 | ### Input arguments 125 | `fun` is a function handle `[fCost <, vfCdX>] = @(x <, nIter>)` defining the function to minimise . It must return the cost at the parameter `x`, optionally evaluated over a mini-batch of data. If analytical gradients are available (recommended), then `fun` must return the gradients in `vfCdX`, evaluated at `x` (optionally over a mini-batch). If analytical gradients are not available, then complex-step finite difference estimates will be used. 126 | 127 | To use analytical gradients (default), set `options.GradObj = 'on'`. To force the use of finite difference gradient estimates, set `options.GradObj = 'off'`. 128 | 129 | `fun` must be deterministic in its calculation of `fCost` and `vfCdX`, even if mini-batches are used. To this end, `fun` can accept a parameter `nIter` which specifies the current iteration of the optimisation algorithm. `fun` must return estimates over identical problems for a given value of `nIter`. 130 | 131 | Steps that do not lead to a reduction in the function to be minimised are not taken. 132 | 133 | ### Output arguments 134 | `x` will be a set of parameters estimated to minimise `fCost`. `fval` will be the value returned from `fun` at `x`. 135 | 136 | `exitflag` will be an integer value indicating why the algorithm terminated: 137 | 138 | * 0: An output or plot function indicated that the algorithm should terminate. 139 | * 1: The estimated reduction in 'fCost' was less than TolFun. 140 | * 2: The norm of the current step was less than TolX. 141 | * 3: The number of iterations exceeded MaxIter. 142 | * 4: The number of function evaluations exceeded MaxFunEvals. 
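For example, a run might branch on these exit codes as follows. This is a minimal sketch only, reusing `fhCost`, `phi0` and `sOpt` from the minibatch example above; it is not part of the library itself:

```matlab
% - Run the optimiser and interpret the exit flag afterwards
[phiHat, fval, exitflag] = fmin_adam(fhCost, phi0, 0.1, [], [], [], [], sOpt);

switch (exitflag)
   case {1, 2}
      % - Converged on the TolFun or TolX criterion
      fprintf('Converged: f(x) = %.3g\n', fval);
   case {3, 4}
      % - Iteration or function evaluation budget was exhausted
      warning('fmin_adam stopped before the convergence tolerances were reached.');
   otherwise
      % - An output or plot function requested termination
      disp('Optimisation terminated by an output or plot function.');
end
```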
143 | 144 | `output` will be a structure containing information about the optimisation process: 145 | 146 | * `.stepsize` — Norm of current parameter step 147 | * `.gradient` — Vector of current gradients evaluated at `x` 148 | * `.funccount` — Number of calls to `fun` made so far 149 | * `.iteration` — Current iteration of algorithm 150 | * `.fval` — Value returned by `fun` at `x` 151 | * `.exitflag` — Flag indicating reason that algorithm terminated 152 | * `.improvement` — Current estimated improvement in `fun` 153 | 154 | The optional parameters `stepSize`, `beta1`, `beta2` and `epsilon` are parameters of the Adam optimisation algorithm (see [[1]]). Default values of `{1e-3, 0.9, 0.999, sqrt(eps)}` are reasonable for most problems. 155 | 156 | The optional argument `nEpochSize` specifies how many iterations comprise an epoch. This is used in the convergence detection code. 157 | 158 | The optional argument `options` is used to control the optimisation process (see `optimset`). Relevant fields: 159 | 160 | * `.Display` 161 | * `.GradObj` 162 | * `.DerivativeCheck` 163 | * `.MaxFunEvals` 164 | * `.MaxIter` 165 | * `.TolFun` 166 | * `.TolX` 167 | * `.UseParallel` 168 | 169 | ## References 170 | [[1]] Diederik P. Kingma, Jimmy Ba. "Adam: A Method for Stochastic 171 | Optimization", ICLR 2015. [https://arxiv.org/abs/1412.6980](https://arxiv.org/abs/1412.6980) 172 | 173 | [[2]] Geoffrey E Hinton, Nitish Srivastava, Alex Krizhevsky, Ilya Sutskever, and Ruslan R. Salakhutdinov. "Improving neural networks by preventing co-adaptation of feature detectors." arXiv preprint. [https://arxiv.org/abs/1207.0580](https://arxiv.org/abs/1207.0580) 174 | 175 | 176 | [1]: https://arxiv.org/abs/1412.6980 177 | [2]: https://arxiv.org/abs/1207.0580 178 | 179 | -------------------------------------------------------------------------------- /fmin_adam.m: -------------------------------------------------------------------------------- 1 | function [x, fval, exitflag, output] = fmin_adam(fun, x0, stepSize, beta1, beta2, epsilon, nEpochSize, options) 2 | 3 | % fmin_adam - FUNCTION Adam optimiser, with matlab calling format 4 | % 5 | % Usage: [x, fval, exitflag, output] = fmin_adam(fun, x0 <, stepSize, beta1, beta2, epsilon, nEpochSize, options>) 6 | % 7 | % 'fmin_adam' is an implementation of the Adam optimisation algorithm 8 | % (gradient descent with Adaptive learning rates individually on each 9 | % parameter, with momentum) from [1]. Adam is designed to work on 10 | % stochastic gradient descent problems; i.e. when only small batches of 11 | % data are used to estimate the gradient on each iteration. 12 | % 13 | % 'fun' is a function handle [fCost <, vfCdX>] = @(x <, nIter>) defining 14 | % the function to minimise . It must return the cost at the parameter 'x', 15 | % optionally evaluated over a mini-batch of data. If analytical gradients 16 | % are available (recommended), then 'fun' must return the gradients in 17 | % 'vfCdX', evaluated at 'x' (optionally over a mini-batch). If analytical 18 | % gradients are not available, then complex-step finite difference 19 | % estimates will be used. 20 | % 21 | % To use analytical gradients (default), options.GradObj = 'on'. To force 22 | % the use of finite difference gradient estimates, options.GradObj = 'off'. 23 | % 24 | % 'fun' must be deterministic in its calculation of 'fCost' and 'vfCdX', 25 | % even if mini-batches are used. To this end, 'fun' can accept a parameter 26 | % 'nIter' which specifies the current iteration of the optimisation 27 | % algorithm. 
'fun' must return estimates over identical problems for a 28 | % given value of 'nIter'. 29 | % 30 | % Steps that do not lead to a reduction in the function to be minimised are 31 | % not taken. 32 | % 33 | % 'x' will be a set of parameters estimated to minimise 'fCost'. 'fval' 34 | % will be the value returned from 'fun' at 'x'. 35 | % 36 | % 'exitflag' will be an integer value indicating why the algorithm 37 | % terminated: 38 | % 0: An output or plot function indicated that the algorithm should 39 | % terminate. 40 | % 1: The estimated reduction in 'fCost' was less than TolFun. 41 | % 2: The norm of the current step was less than TolX. 42 | % 3: The number of iterations exceeded MaxIter. 43 | % 4: The number of function evaluations exceeded MaxFunEvals. 44 | % 45 | % 'output' will be a structure containing information about the 46 | % optimisation process: 47 | % .stepsize - Norm of current parameter step 48 | % .gradient - Vector of current gradients 49 | % .funccount - Number of function calls to 'fun' made so far 50 | % .iteration - Current iteration of algorithm 51 | % .fval - Current value returned by 'fun' 52 | % .exitflag - Flag indicating termination reason 53 | % .improvement - Current estimated improvement in 'fun' 54 | % 55 | % The optional parameters 'stepSize', 'beta1', 'beta2' and 'epsilon' are 56 | % parameters of the Adam optimisation algorithm (see [1]). Default values 57 | % of {1e-3, 0.9, 0.999, sqrt(eps)} are reasonable for most problems. 58 | % 59 | % The optional argument 'nEpochSize' specifies how many iterations comprise 60 | % an epoch. This is used in the convergence detection code. 61 | % 62 | % The optional argument 'options' is used to control the optimisation 63 | % process (see 'optimset'). Relevant fields: 64 | % .Display 65 | % .GradObj 66 | % .DerivativeCheck 67 | % .MaxFunEvals 68 | % .MaxIter 69 | % .TolFun 70 | % .TolX 71 | % .UseParallel 72 | % 73 | % 74 | % References 75 | % [1] Diederik P. Kingma, Jimmy Ba. "Adam: A Method for Stochastic 76 | % Optimization", ICLR 2015. 77 | 78 | % Author: Dylan Muir 79 | % Created: 10th February, 2017 80 | 81 | %% - Default parameters 82 | 83 | DEF_stepSize = 0.001; 84 | DEF_beta1 = 0.9; 85 | DEF_beta2 = 0.999; 86 | DEF_epsilon = sqrt(eps); 87 | 88 | % - Default options 89 | if (isequal(fun, 'defaults')) 90 | x = struct('Display', 'final', ... 91 | 'GradObj', 'on', ... 92 | 'DerivativeCheck', 'off', ... 93 | 'MaxFunEvals', 1e4, ... 94 | 'MaxIter', 1e6, ... 95 | 'TolFun', 1e-6, ... 96 | 'TolX', 1e-5, ... 97 | 'UseParallel', false); 98 | return; 99 | end 100 | 101 | 102 | %% - Check arguments and assign defaults 103 | 104 | if (nargin < 2) 105 | help fmin_adam; 106 | error('*** fmin_adam: Incorrect usage.'); 107 | end 108 | 109 | 110 | if (~exist('stepSize', 'var') || isempty(stepSize)) 111 | stepSize = DEF_stepSize; 112 | end 113 | 114 | if (~exist('beta1', 'var') || isempty(beta1)) 115 | beta1 = DEF_beta1; 116 | end 117 | 118 | if (~exist('beta2', 'var') || isempty(beta2)) 119 | beta2 = DEF_beta2; 120 | end 121 | 122 | if (~exist('epsilon', 'var') || isempty(epsilon)) 123 | epsilon = DEF_epsilon; 124 | end 125 | 126 | if (~exist('options', 'var') || isempty(options)) 127 | options = optimset(@fmin_adam); 128 | end 129 | 130 | 131 | %% - Parse options structure 132 | 133 | numberofvariables = numel(x0); 134 | 135 | % - Are analytical gradients provided? 
136 | if (isequal(options.GradObj, 'on')) 137 | % - Check supplied cost function 138 | if (nargout(fun) < 2) && (nargout(fun) ~= -1) 139 | error('*** fmin_adam: The supplied cost function must return analytical gradients, if options.GradObj = ''on''.'); 140 | end 141 | 142 | bUseAnalyticalGradients = true; 143 | nEvalsPerIter = 2; 144 | else 145 | bUseAnalyticalGradients = false; 146 | 147 | % - Wrap cost function for complex step gradients 148 | fun = @(x, nIter)FA_FunComplexStepGrad(fun, x, nIter); 149 | nEvalsPerIter = numberofvariables + 2; 150 | end 151 | 152 | % - Should we check analytical gradients? 153 | bCheckAnalyticalGradients = isequal(options.DerivativeCheck, 'on'); 154 | 155 | % - Get iteration and termination options 156 | MaxIter = FA_eval(options.MaxIter); 157 | options.MaxIter = MaxIter; 158 | options.MaxFunEvals = FA_eval(options.MaxFunEvals); 159 | 160 | % - Parallel operation is not yet implements 161 | if (options.UseParallel) 162 | warning('--- fmin_adam: Warning: ''UseParallel'' is not yet implemented.'); 163 | end 164 | 165 | 166 | %% - Check supplied function 167 | 168 | if (nargin(fun) < 2) 169 | % - Function does not make use of the 'nIter' argument, so make a wrapper 170 | fun = @(x, nIter)fun(x); 171 | end 172 | 173 | % - Check that gradients are identical for a given nIter 174 | if (~bUseAnalyticalGradients) 175 | [~, vfGrad0] = fun(x0, 1); 176 | [~, vfGrad1] = fun(x0, 1); 177 | 178 | if (max(abs(vfGrad0 - vfGrad1)) > eps(max(max(abs(vfGrad0), abs(vfGrad1))))) 179 | error('*** fmin_adam: Cost function must return identical stochastic gradients for a given ''nIter'', when analytical gradients are not provided.'); 180 | end 181 | end 182 | 183 | % - Check analytical gradients 184 | if (bUseAnalyticalGradients && bCheckAnalyticalGradients) 185 | FA_CheckGradients(fun, x0); 186 | end 187 | 188 | % - Check user function for errors 189 | try 190 | [fval0, vfCdX0] = fun(x0, 1); 191 | 192 | catch mErr 193 | error('*** fmin_adam: Error when evaluating function to minimise.'); 194 | end 195 | 196 | % - Check that initial point is reasonable 197 | if (isinf(fval0) || isnan(fval0) || any(isinf(vfCdX0) | isnan(vfCdX0))) 198 | error('*** fmin_adam: Invalid starting point for user function. NaN or Inf returned.'); 199 | end 200 | 201 | 202 | %% - Initialise algorithm 203 | 204 | % - Preallocate cost and parameters 205 | xHist = zeros(numberofvariables, MaxIter+1);%MappedTensor(numberofvariables, MaxIter+1); 206 | xHist(:, 1) = x0; 207 | x = x0; 208 | vfCost = zeros(1, MaxIter); 209 | 210 | if (~exist('nEpochSize', 'var') || isempty(nEpochSize)) 211 | nEpochSize = numberofvariables; 212 | end 213 | 214 | vfCost(1) = fval0; 215 | fLastCost = fval0; 216 | fval = fval0; 217 | 218 | % - Initialise moment estimates 219 | m = zeros(numberofvariables, 1); 220 | v = zeros(numberofvariables, 1); 221 | 222 | % - Initialise optimization values 223 | optimValues = struct('fval', vfCost(1), ... 224 | 'funccount', nEvalsPerIter, ... 225 | 'gradient', vfCdX0, ... 226 | 'iteration', 0, ... 227 | 'improvement', inf, ... 
228 | 'stepsize', 0); 229 | 230 | % - Initial display 231 | FA_Display(options, x, optimValues, 'init', nEpochSize); 232 | FA_Display(options, x, optimValues, 'iter', nEpochSize); 233 | 234 | % - Initialise plot and output functions 235 | FA_CallOutputFunctions(options, x0, optimValues, 'init'); 236 | FA_CallOutputFunctions(options, x0, optimValues, 'iter'); 237 | 238 | 239 | %% - Optimisation loop 240 | while true 241 | % - Next iteration 242 | optimValues.iteration = optimValues.iteration + 1; 243 | nIter = optimValues.iteration; 244 | 245 | % - Update biased 1st moment estimate 246 | m = beta1.*m + (1 - beta1) .* optimValues.gradient(:); 247 | % - Update biased 2nd raw moment estimate 248 | v = beta2.*v + (1 - beta2) .* (optimValues.gradient(:).^2); 249 | 250 | % - Compute bias-corrected 1st moment estimate 251 | mHat = m./(1 - beta1^nIter); 252 | % - Compute bias-corrected 2nd raw moment estimate 253 | vHat = v./(1 - beta2^nIter); 254 | 255 | % - Determine step to take at this iteration 256 | vfStep = stepSize.*mHat./(sqrt(vHat) + epsilon); 257 | 258 | % - Test step for true improvement, reject bad steps 259 | if (fun(x(:) - vfStep(:), nIter) <= fval) 260 | x = x(:) - vfStep(:); 261 | optimValues.stepsize = max(abs(vfStep)); 262 | end 263 | 264 | % - Get next cost and gradient 265 | [fval, optimValues.gradient] = fun(x, nIter+1); 266 | vfCost(nIter + 1) = fval; 267 | optimValues.funccount = optimValues.funccount + nEvalsPerIter; 268 | 269 | % - Call display, output and plot functions 270 | bStop = FA_Display(options, x, optimValues, 'iter', nEpochSize); 271 | bStop = bStop | FA_CallOutputFunctions(options, x, optimValues, 'iter'); 272 | 273 | % - Store historical x 274 | xHist(:, nIter + 1) = x; 275 | 276 | % - Update covergence variables 277 | nFirstCost = max(1, nIter + 1-nEpochSize); 278 | fEstCost = mean(vfCost(nFirstCost:nIter+1)); 279 | fImprEst = abs(fEstCost - fLastCost); 280 | optimValues.improvement = fImprEst; 281 | fLastCost = fEstCost; 282 | optimValues.fval = fEstCost; 283 | 284 | %% - Check termination criteria 285 | if (bStop) 286 | optimValues.exitflag = 0; 287 | break; 288 | end 289 | 290 | if (nIter > nEpochSize) 291 | if (fImprEst < options.TolFun / nEpochSize) 292 | optimValues.exitflag = 1; 293 | break; 294 | end 295 | 296 | if (optimValues.stepsize < options.TolX) 297 | optimValues.exitflag = 2; 298 | break; 299 | end 300 | 301 | if (nIter >= options.MaxIter-1) 302 | optimValues.exitflag = 3; 303 | break; 304 | end 305 | 306 | if (optimValues.funccount > options.MaxFunEvals) 307 | optimValues.exitflag = 4; 308 | break; 309 | end 310 | end 311 | end 312 | 313 | % - Determine best solution 314 | vfCost = vfCost(1:nIter+1); 315 | [~, nBestParams] = nanmin(vfCost); 316 | x = xHist(:, nBestParams); 317 | fval = vfCost(nBestParams); 318 | exitflag = optimValues.exitflag; 319 | output = optimValues; 320 | 321 | % - Final display 322 | FA_Display(options, x, optimValues, 'done', nEpochSize); 323 | FA_CallOutputFunctions(options, x, optimValues, 'done'); 324 | 325 | end 326 | 327 | %% Utility functions 328 | 329 | % FA_FunComplexStepGrad - FUNCTION Compute complex step finite difference 330 | % gradient estimates for an analytial function 331 | function [fVal, vfCdX] = FA_FunComplexStepGrad(fun, x, nIter) 332 | % - Step size 333 | fStep = sqrt(eps); 334 | 335 | % - Get nominal value of function 336 | fVal = fun(x, nIter); 337 | 338 | % - Estimate gradients with complex step 339 | for (nParam = numel(x):-1:1) 340 | xStep = x; 341 | xStep(nParam) = xStep(nParam) + fStep * 
1i;
342 | vfCdX(nParam, 1) = imag(fun(xStep, nIter)) ./ fStep;
343 | end
344 | end
345 | 
346 | % FA_CheckGradients - FUNCTION Check that analytical gradients match finite
347 | % difference estimates
348 | function FA_CheckGradients(fun, x0)
349 | % - Get analytical gradients
350 | [~, vfCdXAnalytical] = fun(x0, 1);
351 | 
352 | % - Get complex-step finite-difference gradient estimates
353 | [~, vfCdXFDE] = FA_FunComplexStepGrad(fun, x0, 1);
354 | 
355 | disp('--- fmin_adam: Checking analytical gradients...');
356 | 
357 | % - Compare gradients
358 | vfGradDiff = abs(vfCdXAnalytical - vfCdXFDE);
359 | [fMaxDiff, nDiffIndex] = max(vfGradDiff);
360 | fTolGrad = eps(max(max(abs(vfCdXFDE), abs(vfCdXAnalytical))));
361 | if (fMaxDiff > fTolGrad)
362 | fprintf(' Gradient check failed.\n');
363 | fprintf(' Maximum difference between analytical and finite-step estimate: %.2g\n', fMaxDiff);
364 | fprintf(' Analytical: %.2g; Finite-step: %.2g\n', vfCdXAnalytical(nDiffIndex), vfCdXFDE(nDiffIndex));
365 | fprintf(' Tolerance: %.2g\n', fTolGrad);
366 | fprintf(' Gradient indices violating tolerance: [');
367 | fprintf('%d, ', find(vfGradDiff > fTolGrad));
368 | fprintf(']\n');
369 | 
370 | error('*** fmin_adam: Gradient check failed.');
371 | end
372 | 
373 | disp(' Gradient check passed. Well done!');
374 | end
375 | 
376 | % FA_Display - FUNCTION Display the current state of the optimisation
377 | % algorithm
378 | function bStop = FA_Display(options, x, optimValues, state, nEpochSize) %#ok
379 | bStop = false;
380 | 
381 | % - Should we display anything?
382 | if (isequal(options.Display, 'none'))
383 | return;
384 | end
385 | 
386 | % - Determine what to display
387 | switch (state)
388 | case 'init'
389 | if (isequal(options.Display, 'iter'))
390 | fprintf('\n\n%10s %10s %10s %10s %10s\n', ...
391 | 'Iteration', 'Func-count', 'f(x)', 'Improvement', 'Step-size');
392 | fprintf('%10s %10s %10s %10s %10s\n', ...
393 | '----------', '----------', '----------', '----------', '----------');
394 | end
395 | 
396 | case 'iter'
397 | if (isequal(options.Display, 'iter') && (mod(optimValues.iteration, nEpochSize) == 0))
398 | fprintf('%10d %10d %10.2g %10.2g %10.2g\n', ...
399 | optimValues.iteration, optimValues.funccount, ...
400 | optimValues.fval, optimValues.improvement, optimValues.stepsize);
401 | end
402 | 
403 | case 'done'
404 | bPrintSummary = isequal(options.Display, 'final') | ...
405 | isequal(options.Display, 'iter') | ...
406 | (isequal(options.Display, 'notify') & (optimValues.exitflag ~= 1) & (optimValues.exitflag ~= 2));
407 | 
408 | if (bPrintSummary)
409 | fprintf('\n\n%10s %10s %10s %10s %10s\n', ...
410 | 'Iteration', 'Func-count', 'f(x)', 'Improvement', 'Step-size');
411 | fprintf('%10s %10s %10s %10s %10s\n', ...
412 | '----------', '----------', '----------', '----------', '----------');
413 | fprintf('%10d %10d %10.2g %10.2g %10.2g\n', ...
414 | optimValues.iteration, optimValues.funccount, ...
415 | optimValues.fval, optimValues.improvement, optimValues.stepsize);
416 | fprintf('%10s %10s %10s %10s %10s\n', ...
417 | '----------', '----------', '----------', '----------', '----------'); 418 | 419 | strExitMessage = FA_GetExitMessage(optimValues, options); 420 | fprintf('\nFinished optimization.\n Reason: %s\n\n', strExitMessage); 421 | end 422 | end 423 | end 424 | 425 | % FA_CallOutputFunctions - FUNCTION Call output and plot functions during 426 | % optimisation 427 | function bStop = FA_CallOutputFunctions(options, x, optimValues, state) 428 | bStop = false; 429 | 430 | if (~isempty(options.OutputFcn)) 431 | bStop = bStop | options.OutputFcn(x, optimValues, state); 432 | drawnow; 433 | end 434 | 435 | if (~isempty(options.PlotFcns)) 436 | if (iscell(options.PlotFcns)) 437 | bStop = bStop | any(cellfun(@(fh)fh(x, optimValues, state), options.PlotFcns)); 438 | else 439 | bStop = bStop | options.PlotFcns(x, optimValues, state); 440 | end 441 | drawnow; 442 | end 443 | end 444 | 445 | % FA_eval - FUNCTION Evaluate a string or return a value 446 | function oResult = FA_eval(oInput) 447 | if (ischar(oInput)) 448 | oResult = evalin('caller', oInput); 449 | else 450 | oResult = oInput; 451 | end 452 | end 453 | 454 | % FA_GetExitMessage - FUNCTION Return the message describing why the 455 | % algorithm terminated 456 | function strMessage = FA_GetExitMessage(optimValues, options) 457 | switch (optimValues.exitflag) 458 | case 0 459 | strMessage = 'Terminated due to output or plot function.'; 460 | 461 | case 1 462 | strMessage = sprintf('Function improvement [%.2g] less than TolFun [%.2g].', optimValues.improvement, options.TolFun); 463 | 464 | case 2 465 | strMessage = sprintf('Step size [%.2g] less than TolX [%.2g].', optimValues.stepsize, options.TolX); 466 | 467 | case 3 468 | strMessage = sprintf('Number of iterations reached MaxIter [%d].', options.MaxIter); 469 | 470 | case 4 471 | strMessage = sprintf('Number of function evaluations reached MaxFunEvals [%d].', options.MaxFunEvals); 472 | 473 | otherwise 474 | strMessage = 'Unknown termination reason.'; 475 | end 476 | end 477 | 478 | % --- END of fmin_adam.m --- 479 | -------------------------------------------------------------------------------- /images/regression_fit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DylanMuir/fmin_adam/fac5d70d01be37f00a601e8bb581103b9e0bf702/images/regression_fit.png -------------------------------------------------------------------------------- /images/regression_minibatches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DylanMuir/fmin_adam/fac5d70d01be37f00a601e8bb581103b9e0bf702/images/regression_minibatches.png -------------------------------------------------------------------------------- /images/regression_scatter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DylanMuir/fmin_adam/fac5d70d01be37f00a601e8bb581103b9e0bf702/images/regression_scatter.png --------------------------------------------------------------------------------