├── AMSgrad.m
├── AdaMax.m
├── Adadelta.m
├── Adam.m
├── Adam2.m
├── Annulus.m
├── DBD.m
├── LICENSE
├── Nadam.m
├── Perceptron.m
├── RAdam.m
├── README.md
├── RMSprop.m
├── RNN.m
└── Vanilla.m

/AMSgrad.m:
--------------------------------------------------------------------------------
function [updates, state] = AMSgrad(gradients, state)
%AMSGRAD AMSGrad optimization.
%   Adam with a non-decreasing second raw moment estimate, so the
%   effective per-parameter step size never grows between iterations.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'beta1')
    state.beta1 = 0.9;
end
if ~isfield(state, 'beta2')
    state.beta2 = 0.999;
end
if ~isfield(state, 'epsilon')
    state.epsilon = 1e-8;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end
if ~isfield(state, 'm')
    state.m = zeros(size(gradients));
end
if ~isfield(state, 'v')
    state.v = zeros(size(gradients));
end
if ~isfield(state, 'vhat')
    state.vhat = zeros(size(gradients));
end
if ~isfield(state, 'alpha')
    state.alpha = 1e-2;
end

% update biased first moment estimate
state.m = state.beta1 * state.m + (1 - state.beta1) * gradients;

% update biased second raw moment estimate
state.v = state.beta2 * state.v + (1 - state.beta2) * gradients.^2;

% keep a non-decreasing second moment estimate
state.vhat = max(state.vhat, state.v);

% update parameters
updates = state.alpha * state.m ./ (sqrt(state.vhat) + state.epsilon);

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
/AdaMax.m:
--------------------------------------------------------------------------------
function [updates, state] = AdaMax(gradients, state)
%ADAMAX AdaMax optimization.
%   Adam variant that scales steps by an exponentially weighted
%   infinity norm of past gradients instead of a second raw moment.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'beta1')
    state.beta1 = 0.9;
end
if ~isfield(state, 'beta2')
    state.beta2 = 0.999;
end
if ~isfield(state, 'epsilon')
    state.epsilon = 1e-8;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end
if ~isfield(state, 'm')
    state.m = zeros(size(gradients));
end
if ~isfield(state, 'u')
    state.u = zeros(size(gradients));
end
if ~isfield(state, 'alpha')
    state.alpha = 1e-2;
end

% update biased first moment estimate
state.m = state.beta1 * state.m + (1 - state.beta1) * gradients;

% update exponentially weighted infinity norm
state.u = max(state.beta2 * state.u, abs(gradients));

% compute bias-corrected first moment estimate
mhat = state.m / (1 - state.beta1^state.iteration);

% update parameters
updates = state.alpha * mhat ./ (state.u + state.epsilon);

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
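Every optimizer in this package shares the calling convention seen above: pass in the current gradients and a state struct (any omitted field falls back to its default), and subtract the returned updates from the parameters. A minimal usage sketch, assuming only that AMSgrad.m is on the MATLAB path; the quadratic objective is made up for illustration:

% sketch (not part of the repository): minimize f(x) = sum(x.^2) with AMSgrad
x = [3; -2];
state = struct('alpha', 0.1);       % omitted fields get their defaults
for k = 1:500
    gradients = 2 * x;              % gradient of f at x
    [updates, state] = AMSgrad(gradients, state);
    x = x - updates;                % the caller applies the update
end
disp(x)                             % approaches [0; 0]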
/Adadelta.m:
--------------------------------------------------------------------------------
function [updates, state] = Adadelta(gradients, state)
%ADADELTA Adadelta optimization.
%   Needs no learning rate: each step is scaled by the ratio of
%   accumulated past update magnitudes to accumulated past gradient
%   magnitudes.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'epsilon')
    state.epsilon = 1e-6;
end
if ~isfield(state, 'rho')
    state.rho = .95;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end
if ~isfield(state, 'history')
    state.history = zeros(size(gradients));
end
if ~isfield(state, 'u')
    state.u = zeros(size(gradients));
end

% accumulate squared gradients
state.history = state.rho * state.history + (1 - state.rho) * gradients.^2;

% update parameters
updates = gradients .* sqrt((state.u + state.epsilon) ./ (state.history + state.epsilon));

% accumulate squared updates
state.u = state.rho * state.u + (1 - state.rho) * updates.^2;

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
/Adam.m:
--------------------------------------------------------------------------------
function [updates, state] = Adam(gradients, state)
%ADAM Adam optimization.
%   Adaptive moment estimation: scales steps by bias-corrected
%   estimates of the first and second raw moments of the gradients.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'beta1')
    state.beta1 = 0.9;
end
if ~isfield(state, 'beta2')
    state.beta2 = 0.999;
end
if ~isfield(state, 'epsilon')
    state.epsilon = 1e-8;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end
if ~isfield(state, 'm')
    state.m = zeros(size(gradients));
end
if ~isfield(state, 'v')
    state.v = zeros(size(gradients));
end
if ~isfield(state, 'alpha')
    state.alpha = 1e-2;
end

% update biased first moment estimate
state.m = state.beta1 * state.m + (1 - state.beta1) * gradients;

% update biased second raw moment estimate
state.v = state.beta2 * state.v + (1 - state.beta2) * gradients.^2;

% compute bias-corrected first moment estimate
mhat = state.m / (1 - state.beta1^state.iteration);

% compute bias-corrected second raw moment estimate
vhat = state.v / (1 - state.beta2^state.iteration);

% update parameters
updates = state.alpha * mhat ./ (sqrt(vhat) + state.epsilon);

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
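A worked example of why Adam's bias correction matters: because m and v start at zero, the raw estimates at iteration 1 are scaled-down copies of the first gradient, and dividing by (1 - beta^t) undoes exactly that scaling. A small numerical check (illustrative only, not part of the repository):

% at iteration t = 1 with the defaults beta1 = 0.9, beta2 = 0.999
g = 4;
m = (1 - 0.9) * g;                  % 0.4, understates g by a factor of 10
v = (1 - 0.999) * g^2;              % 0.016, understates g^2 by a factor of 1000
mhat = m / (1 - 0.9^1)              % 4, recovers g exactly
vhat = v / (1 - 0.999^1)            % 16, recovers g^2 exactly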
/Adam2.m:
--------------------------------------------------------------------------------
function [updates, state] = Adam2(gradients, state)
%ADAM2 AMSGrad-style Adam variant.
%   Keeps a non-decreasing second moment estimate, as in AMSGrad, and
%   folds the bias corrections for both moments into a single scalar
%   factor tracked via running powers of beta1 and beta2.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'beta1')
    state.beta1 = 0.9;
end
if ~isfield(state, 'beta2')
    state.beta2 = 0.999;
end
if ~isfield(state, 'beta1t')
    state.beta1t = state.beta1;
end
if ~isfield(state, 'beta2t')
    state.beta2t = state.beta2;
end
if ~isfield(state, 'epsilon')
    state.epsilon = 1e-8;
end
if ~isfield(state, 'm')
    state.m = zeros(size(gradients));
end
if ~isfield(state, 'v')
    state.v = zeros(size(gradients));
end
if ~isfield(state, 'vhat')
    state.vhat = state.v;
end
if ~isfield(state, 'alpha')
    state.alpha = 1e-2;
end

% update biased first moment estimate
state.m = state.beta1 * state.m + (1 - state.beta1) * gradients;

% update biased second raw moment estimate
state.v = state.beta2 * state.v + (1 - state.beta2) * gradients.^2;

% scalar bias-correction factor for both moments
bc = sqrt(1 - state.beta2t) / (1 - state.beta1t);

% keep a non-decreasing second moment estimate
state.vhat = max(state.vhat, state.v);

% update parameters
updates = state.alpha * state.m * bc ./ (sqrt(state.vhat) + state.epsilon);

% advance the running powers of beta1 and beta2
state.beta1t = state.beta1 * state.beta1t;
state.beta2t = state.beta2 * state.beta2t;

end
--------------------------------------------------------------------------------
/Annulus.m:
--------------------------------------------------------------------------------
function X = Annulus( n, p )
% Uniformly sample n points from an annulus together with a concentric
% inner disk, embedded in ambient space R^p.
% INPUT
% n : Number of points.
% p : Dimension of ambient Euclidean space (>= 2).
% OUTPUT
% X : Data matrix (n x p).
% Written by John Malik on 2018.13.10, john.malik@duke.edu

switch nargin
    case 1
        p = 2;
    case 0
        error('Select a number of points to sample.')
end

% rejection sampling: sqrt(rand) is uniform in area on the unit disk;
% keep radii in the inner disk [0, 1/3) or the annulus [sqrt(8)/3, 1]
rho = nan(n, 1);
i = 1;
inner = sqrt(8) / 3;
while i <= n
    xvec = sqrt(rand(1));
    if xvec < 1 / 3 || xvec >= inner
        rho(i) = xvec;
        i = i + 1;
    else
        continue
    end
end
r = sort(rho);
t = 2 * pi * rand(n, 1);
x = r .* cos(t);
y = r .* sin(t);
X = [x, y];

% rotate into a random 2-dimensional subspace of R^p
if p > 2
    X = X * transpose(orth(randn(p, 2)));
end

end
--------------------------------------------------------------------------------
/DBD.m:
--------------------------------------------------------------------------------
function [updates, state] = DBD(gradients, state)
%DBD Delta-bar-delta optimization.
%   Maintains a per-parameter gain that grows while the gradient keeps
%   its sign and shrinks when the sign flips, combined with momentum.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'alpha')
    state.alpha = 1e-2;
end
if ~isfield(state, 'momentum')
    state.momentum = 0.5;
end
if ~isfield(state, 'mini')
    state.mini = 0.01;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end
if ~isfield(state, 'updates')
    state.updates = zeros(size(gradients));
end
if ~isfield(state, 'gains')
    state.gains = ones(size(gradients));
end
if ~isfield(state, 'kappa')
    state.kappa = 0.2;
end
if ~isfield(state, 'phi')
    state.phi = 0.8;
end

% delta bar delta: true where the gradient sign has flipped relative to
% the previously applied step
dbd = sign(gradients) == sign(state.updates);

% decrease gains where the gradient sign has flipped
state.gains(dbd) = state.gains(dbd) * state.phi;

% increase gains where the gradient keeps its sign
state.gains(~dbd) = state.gains(~dbd) + state.kappa;

% clip gains from below
state.gains = max(state.gains, state.mini);

% update parameters using momentum term
updates = state.alpha * (state.gains .* gradients) - state.momentum * state.updates;

% store the negative update so the sign test above compares against the
% step actually applied to the parameters
state.updates = -updates;

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
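Annulus.m generates the dataset that Perceptron.m later separates. A quick visualization sketch (not part of the repository), coloring the inner disk differently from the outer ring:

X = Annulus(2000, 2);
labels = sum(X.^2, 2) < .5;         % true on the inner disk (radius < 1/3)
figure; scatter(X(:, 1), X(:, 2), 10, labels, 'filled'); axis equal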
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 John Malik

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/Nadam.m:
--------------------------------------------------------------------------------
function [updates, state] = Nadam(gradients, state)
%NADAM Nesterov-accelerated Adam.
%   Adam with Nesterov momentum: the bias-corrected first moment is
%   combined with the current gradient before the step is taken.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'beta1')
    state.beta1 = 0.9;
end
if ~isfield(state, 'beta2')
    state.beta2 = 0.999;
end
if ~isfield(state, 'epsilon')
    state.epsilon = 1e-8;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end
if ~isfield(state, 'm')
    state.m = zeros(size(gradients));
end
if ~isfield(state, 'v')
    state.v = zeros(size(gradients));
end
if ~isfield(state, 'alpha')
    state.alpha = 1e-2;
end

% update biased first moment estimate
state.m = state.beta1 * state.m + (1 - state.beta1) * gradients;

% update biased second raw moment estimate
state.v = state.beta2 * state.v + (1 - state.beta2) * gradients.^2;

% compute bias-corrected first moment estimate, one step ahead
mhat = state.m / (1 - state.beta1^(state.iteration + 1));

% compute bias-corrected second raw moment estimate
vhat = state.v / (1 - state.beta2^state.iteration);

% nesterov momentum: mix the look-ahead moment with the current gradient
mhat = state.beta1 * mhat + (((1 - state.beta1) * gradients) / (1 - state.beta1^state.iteration));

% update parameters
updates = (state.alpha * mhat) ./ (sqrt(vhat) + state.epsilon);

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
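Nadam differs from Adam only in the Nesterov look-ahead applied to the first moment, so from a fresh state Nadam takes a somewhat larger first step in the gradient direction. A quick comparison sketch (illustrative; assumes both files above are on the path):

g = ones(5, 1);                     % a constant gradient
[ua, ~] = Adam(g);
[un, ~] = Nadam(g);
disp([ua(1), un(1)])                % Nadam's first step is the larger one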
/Perceptron.m:
--------------------------------------------------------------------------------
% Train a two-hidden-layer perceptron to separate the inner disk of the
% Annulus dataset from the outer ring.

clear
close all

% Parameters
n = 2000;
p = 2;
X = Annulus(n, p)';
Y = sum(X.^2) < .5;
Y = [Y; ~Y];
classes = size(Y, 1);
nodes = 10;
c = 1e-1;
d = 1e-4;
epochs = 45;
batch = 20;

% Initialize Weights and Biases
W1 = c * randn(nodes, p);
B1 = d * randn(nodes, 1);
W2 = c * randn(nodes, nodes);
B2 = d * randn(nodes, 1);
W3 = c * randn(classes, nodes);
B3 = d * randn(classes, 1);

% ReLU Activation Function
g = @(z) max(z, 0);
dg = @(z) z > 0;

%% Train

% Initialize Optimizer (alpha is read by most optimizers in this package;
% Adadelta, used below, ignores it)
Op = struct;
Op.alpha = 0.001;
W1S = Op;
B1S = Op;
W2S = Op;
B2S = Op;
W3S = Op;
B3S = Op;

for ii = 1:epochs

    % Shuffle Data and Labels
    t = randperm(n);
    XX = X(:, t);
    YY = Y(:, t);

    for jj = 1:ceil(n / batch)

        % Forward Pass
        A0 = XX(:, (jj - 1) * batch + 1:min(jj * batch, n));
        Z1 = W1 * A0 + B1;
        A1 = g(Z1);
        Z2 = W2 * A1 + B2;
        A2 = g(Z2);
        Z3 = W3 * A2 + B3;
        A3 = g(Z3);

        % Back-Propagated Error
        D3 = (A3 - YY(:, (jj - 1) * batch + 1:min(jj * batch, n))) .* dg(Z3);
        D2 = (W3' * D3) .* dg(Z2);
        D1 = (W2' * D2) .* dg(Z1);

        % Gradients, averaged over the minibatch
        W1G = D1 * A0' / size(A0, 2);
        B1G = mean(D1, 2);
        W2G = D2 * A1' / size(A0, 2);
        B2G = mean(D2, 2);
        W3G = D3 * A2' / size(A0, 2);
        B3G = mean(D3, 2);

        % Gradient Descent Optimizer
        [W1U, W1S] = Adadelta(W1G, W1S);
        [B1U, B1S] = Adadelta(B1G, B1S);
        [W2U, W2S] = Adadelta(W2G, W2S);
        [B2U, B2S] = Adadelta(B2G, B2S);
        [W3U, W3S] = Adadelta(W3G, W3S);
        [B3U, B3S] = Adadelta(B3G, B3S);

        % Perform Updates
        W1 = W1 - W1U;
        B1 = B1 - B1U;
        W2 = W2 - W2U;
        B2 = B2 - B2U;
        W3 = W3 - W3U;
        B3 = B3 - B3U;

    end

    % Print Loss
    Z = g(W3 * g(W2 * g(W1 * X + B1) + B2) + B3);
    disp(['Loss: ' num2str(sum((Z(:) - Y(:)).^2))]);

    % Visualize as Function on Unit Square
    [x, y] = meshgrid(linspace(-1, 1, 200), linspace(-1, 1, 200));
    z = g(W3 * g(W2 * g(W1 * [x(:)'; y(:)'] + B1) + B2) + B3);
    mesh(x, y, reshape(z(1, :), [200, 200])); drawnow;

end

% Scatter Plot with Predictions
figure;
[~, id] = max(Z, [], 1);
scatter(X(1, :), X(2, :), 20, id, 'filled');
--------------------------------------------------------------------------------
/RAdam.m:
--------------------------------------------------------------------------------
function [updates, state] = RAdam(gradients, state)
%RADAM Rectified Adam optimization.
%   Rectifies the variance of the adaptive step size: while the
%   approximated SMA length is short (rho <= 4), the adaptive term is
%   disabled; afterwards, the step is scaled by a rectification term.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'beta1')
    state.beta1 = 0.9;
end
if ~isfield(state, 'beta2')
    state.beta2 = 0.999;
end
if ~isfield(state, 'epsilon')
    state.epsilon = 1e-8;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end
if ~isfield(state, 'm')
    state.m = zeros(size(gradients));
end
if ~isfield(state, 'v')
    state.v = zeros(size(gradients));
end
if ~isfield(state, 'alpha')
    state.alpha = 1e-2;
end

rhoinf = 2 / (1 - state.beta2) - 1;

% update biased first moment estimate
state.m = state.beta1 * state.m + (1 - state.beta1) * gradients;

% update biased second raw moment estimate
state.v = state.beta2 * state.v + (1 - state.beta2) * gradients.^2;

% compute bias-corrected first moment estimate
mhat = state.m / (1 - state.beta1^state.iteration);

% length of the approximated SMA
rho = rhoinf - 2 * state.iteration * state.beta2^state.iteration / (1 - state.beta2^state.iteration);

if rho > 4

    % square root of the bias-corrected second raw moment estimate
    vhat = sqrt(state.v / (1 - state.beta2^state.iteration));

    % variance rectification term
    r = sqrt((rho - 4) * (rho - 2) * rhoinf / (rhoinf - 4) / (rhoinf - 2) / rho);

    % update parameters
    updates = state.alpha * r * mhat ./ (vhat + state.epsilon);

else

    % update parameters without the adaptive term
    updates = state.alpha * mhat;

end

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
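RAdam's behavior is easiest to see through its rectification schedule: for the first few iterations rho <= 4, so the update falls back to plain bias-corrected momentum, and afterwards the rectification term r climbs toward 1. A sketch of that schedule under the default beta2 = 0.999 (illustrative only, not part of the repository):

beta2 = 0.999;
rhoinf = 2 / (1 - beta2) - 1;
t = (1:1000)';
rho = rhoinf - 2 * t .* beta2.^t ./ (1 - beta2.^t);
r = nan(size(t));                   % nan where the adaptive term is off
k = rho > 4;
r(k) = sqrt((rho(k) - 4) .* (rho(k) - 2) * rhoinf ./ ((rhoinf - 4) * (rhoinf - 2) * rho(k)));
figure; plot(t, r);
xlabel('iteration'); ylabel('rectification term r')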
/README.md:
--------------------------------------------------------------------------------
# gradient-descent
A MATLAB package implementing numerous gradient descent optimization methods, such as Adam and RMSprop. Every optimizer shares the same interface: pass in the current gradients and a state struct, and get back the parameter updates and the updated state. The newest addition is the Rectified Adam (RAdam) optimizer.

To test the software, see Perceptron.m, which trains a simple multi-layer perceptron, or RNN.m, which trains a recurrent neural network to add binary numbers.
--------------------------------------------------------------------------------
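Since every optimizer here has the same signature, swapping methods is a one-line change through a function handle. A minimal sketch of the pattern (the least-squares problem is made up for illustration):

% fit w in y = A * w, using any optimizer from this package
optimizer = @RAdam;                 % swap for @Adam, @RMSprop, @Adadelta, ...
A = randn(50, 3); w_true = [1; -2; 3]; y = A * w_true;
w = zeros(3, 1);
state = struct('alpha', 1e-1);
for k = 1:2000
    gradients = 2 * A' * (A * w - y) / size(A, 1);
    [updates, state] = optimizer(gradients, state);
    w = w - updates;
end
disp(w')                            % should approach w_true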
/RMSprop.m:
--------------------------------------------------------------------------------
function [updates, state] = RMSprop(gradients, state)
%RMSPROP RMSprop optimization.
%   Divides each gradient by a running root-mean-square of its recent
%   magnitudes.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'alpha')
    state.alpha = 1e-3;
end
if ~isfield(state, 'rho')
    state.rho = 0.9;
end
if ~isfield(state, 'epsilon')
    state.epsilon = 1e-8;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end
if ~isfield(state, 'history')
    state.history = zeros(size(gradients));
end

% accumulate squared gradients
state.history = state.rho * state.history + (1 - state.rho) * gradients.^2;

% update parameters
updates = gradients * state.alpha ./ sqrt(state.history + state.epsilon);

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
/RNN.m:
--------------------------------------------------------------------------------
clear
close all

% train a nonlinear RNN to add binary numbers
% http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf
% error in line 6 of bptt: change z to u from 6 on

% Written by John Malik on 2019.9.3 john.malik@duke.edu

% length of binary sequence
len = 7;

% number of sequences to train on
n = 1e3;

% generate pairs of numbers and their sums
vv = randi([0 sum(pow2(0:len-2))], n, 2);
yy = sum(vv, 2);

% convert to binary and flip to read left to right
v = zeros(2, len, n);
for i = 1:2
    v(i, :, :) = fliplr(dec2bin(vv(:, i), len))' - 48;
end
y(1, :, :) = fliplr(dec2bin(yy, len))' - 48;

% dimensions
vdim = size(v, 1); % number of input variables
hdim = 10; % number of state variables (hidden nodes)
odim = size(y, 1); % number of output variables
mem = size(v, 2); % length of sequence (T)

% initialize weights
init_wts = @(row, col) (2 * rand(row, col) - 1) * sqrt(6 / (row + col));
Whv = init_wts(hdim, vdim);
Whh = init_wts(hdim, hdim);
Woh = init_wts(odim, hdim);

% initialize biases
bh = init_wts(hdim, 1);
bo = init_wts(odim, 1);

% initial state
h0 = init_wts(hdim, 1);

% initialize optimizer with learning rate
rho = 1e-2;
WhvS.alpha = rho;
WhhS.alpha = rho;
WohS.alpha = rho;
bhS.alpha = rho;
boS.alpha = rho;
h0S.alpha = rho;

% number of steps
steps = 2e3;

% batch size
batch = 20;

% flatten matrix helper
flt = @(z) reshape(z, size(z, 1), batch);

% store training loss
E = zeros(steps, 1);

% activation functions
e = @(z) tanh(z);
de = @(z) sech(z).^2;
g = @(z) 1 ./ (1 + exp(-z));
dg = @(z) exp(-z) ./ (1 + exp(-z)).^2;

% minibatch gradient descent optimization
for tt = 1:steps

    % pick a random batch
    j = randi(n, batch, 1);

    % calculate hidden states
    u = zeros(hdim, mem, batch);
    h = zeros(hdim, mem, batch);
    u(:, 1, :) = Whv * flt(v(:, 1, j)) + Whh * h0 + bh;
    h(:, 1, :) = e(u(:, 1, :));
    for i = 2:mem
        u(:, i, :) = Whv * flt(v(:, i, j)) + Whh * flt(h(:, i - 1, :)) + bh;
        h(:, i, :) = e(u(:, i, :));
    end

    % calculate output
    o = zeros(odim, mem, batch);
    for jj = 1:batch
        o(:, :, jj) = Woh * h(:, :, jj) + bo;
    end
    z = g(o);

    % loss per input pair
    E(tt) = sum((z - y(:, :, j)).^2, 'all') / batch;

    % initialize gradients
    dWhv = zeros(size(Whv));
    dWhh = zeros(size(Whh));
    dWoh = zeros(size(Woh));
    dbh = zeros(size(bh));
    dbo = zeros(size(bo));
    dh = zeros(size(h));
    du = zeros(size(u));
    do = zeros(size(o));

    % BPTT, working backwards through the sequence
    % (weight gradients are summed over the batch; bias gradients are averaged)
    for t = mem:-1:1

        do(:, t, :) = dg(o(:, t, :)) .* (2 * (z(:, t, :) - y(:, t, j)));
        dbo = dbo + mean(do(:, t, :), 3);
        dWoh = dWoh + flt(do(:, t, :)) * flt(h(:, t, :))';
        dh(:, t, :) = flt(dh(:, t, :)) + Woh' * flt(do(:, t, :));
        du(:, t, :) = de(u(:, t, :)) .* dh(:, t, :);
        dWhv = dWhv + flt(du(:, t, :)) * flt(v(:, t, j))';
        dbh = dbh + mean(du(:, t, :), 3);

        if t > 1
            dWhh = dWhh + flt(du(:, t, :)) * flt(h(:, t - 1, :))';
            dh(:, t - 1, :) = Whh' * flt(du(:, t, :));
        else
            dWhh = dWhh + flt(du(:, t, :)) * repmat(h0', [batch, 1]);
            dh0 = mean(Whh' * flt(du(:, t, :)), 2);
        end

    end

    % compute updates
    [WhvU, WhvS] = RAdam(dWhv, WhvS);
    [WhhU, WhhS] = RAdam(dWhh, WhhS);
    [WohU, WohS] = RAdam(dWoh, WohS);
    [bhU, bhS] = RAdam(dbh, bhS);
    [boU, boS] = RAdam(dbo, boS);
    [h0U, h0S] = RAdam(dh0, h0S);

    % apply updates
    Whv = Whv - WhvU;
    Whh = Whh - WhhU;
    Woh = Woh - WohU;
    bh = bh - bhU;
    bo = bo - boU;
    h0 = h0 - h0U;

end

%% test on new pairs of binary sequences

clear v y

% number of test examples
n = 1e3;

% generate pairs of numbers and their sums
vv = randi([0 sum(pow2(0:len-2))], n, 2);
yy = sum(vv, 2);

% convert to binary and flip to read left to right
v = zeros(2, len, n);
for i = 1:2
    v(i, :, :) = fliplr(dec2bin(vv(:, i), len))' - 48;
end
y(1, :, :) = fliplr(dec2bin(yy, len))' - 48;

% pass each point one at a time
batch = 1;

% flatten matrix helper
flt = @(z) reshape(z, size(z, 1), batch);

% prediction
yhat = zeros(size(y));
for j = 1:n

    % calculate hidden states
    u = zeros(hdim, mem, batch);
    h = zeros(hdim, mem, batch);
    u(:, 1, :) = Whv * flt(v(:, 1, j)) + Whh * h0 + bh;
    h(:, 1, :) = e(u(:, 1, :));
    for i = 2:mem
        u(:, i, :) = Whv * flt(v(:, i, j)) + Whh * flt(h(:, i - 1, :)) + bh;
        h(:, i, :) = e(u(:, i, :));
    end

    % calculate output
    o = zeros(odim, mem, batch);
    for jj = 1:batch
        o(:, :, jj) = Woh * h(:, :, jj) + bo;
    end
    z = g(o);

    % create binary vector
    yhat(:, :, j) = round(z);

end

% accuracy
ac = mean(all(yhat == y, 2));

figure; plot(E); ylabel('Training Loss')
xlabel('Training Step');
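When modifying the BPTT block in RNN.m, a central-difference check on a tiny instance is a cheap safeguard. A generic sketch of the idea, as a helper one might keep in a (hypothetical) file gradcheck.m; L is any scalar loss handle and dW the analytic gradient to verify:

function err = gradcheck(L, W, dW)
% compare an analytic gradient dW against central differences of L at W
h = 1e-5;
num = zeros(size(W));
for i = 1:numel(W)
    Wp = W; Wp(i) = Wp(i) + h;
    Wm = W; Wm(i) = Wm(i) - h;
    num(i) = (L(Wp) - L(Wm)) / (2 * h);
end
err = max(abs(num(:) - dW(:)));     % should be near zero
end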
title(['Test Accuracy: ' num2str(100*ac)]);
--------------------------------------------------------------------------------
/Vanilla.m:
--------------------------------------------------------------------------------
function [updates, state] = Vanilla(gradients, state)
%VANILLA The most basic gradient descent.
%   Updates are simply the gradients scaled by a fixed learning rate.

if nargin == 1
    state = struct;
end

if ~isfield(state, 'alpha')
    state.alpha = 1e-1;
end
if ~isfield(state, 'iteration')
    state.iteration = 1;
end

% compute updates
updates = state.alpha * gradients;

% update iteration number
state.iteration = state.iteration + 1;

end
--------------------------------------------------------------------------------
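A compact way to compare the methods above is to drive several optimizers on a common test problem. A closing sketch on an ill-conditioned quadratic (illustrative only; the settings are made up, final losses vary widely, and Adadelta in particular starts slowly because it ignores alpha):

% compare optimizers on f(x) = x(1)^2 + 100 * x(2)^2
opts = {@Vanilla, @RMSprop, @Adam, @AMSgrad, @Nadam, @RAdam, @Adadelta};
grad = @(x) [2 * x(1); 200 * x(2)];
for ii = 1:numel(opts)
    x = [5; 5];
    state = struct('alpha', 5e-3);
    for k = 1:2000
        [updates, state] = opts{ii}(grad(x), state);
        x = x - updates;
    end
    fprintf('%10s: f = %.3e\n', func2str(opts{ii}), x(1)^2 + 100 * x(2)^2);
end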