├── .gitignore ├── AUTHORS ├── LICENSE.md ├── NeuralNet2.m ├── NeuralNetApp.m ├── README.md └── demos └── demo_nn.m /.gitignore: -------------------------------------------------------------------------------- 1 | # Windows default autosave extension 2 | *.asv 3 | 4 | # OSX / *nix default autosave extension 5 | *.m~ 6 | 7 | # Compiled MEX binaries (all platforms) 8 | *.mex* 9 | 10 | # Simulink Code Generation 11 | slprj/ 12 | 13 | # Session info 14 | octave-workspace 15 | -------------------------------------------------------------------------------- /AUTHORS: -------------------------------------------------------------------------------- 1 | NeuralNetPlayground is maintained by members of the Stack Overflow MATLAB 2 | Chat Room (https://github.com/StackOverflowMATLABchat). 3 | Visit: http://chat.stackoverflow.com/rooms/81987/matlab-and-octave 4 | 5 | Authors: 6 | 7 | Amro (https://github.com/amroamroamro) 8 | Ray Phan (https://github.com/rayryeng) 9 | 10 | Here is a list of much-appreciated contributors -- people who have reported 11 | bugs, submitted patches, and helped with debugging and testing: 12 | 13 | Jonathan Suever (https://github.com/suever) 14 | Qi Wang (https://github.com/GameOfThrow) 15 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 Code of StackOverflow MATLAB Chat 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /NeuralNet2.m: -------------------------------------------------------------------------------- 1 | classdef NeuralNet2 < handle 2 | % NeuralNet2 Neural Network implementation for the NeuralNetPlayground tool 3 | % This class implements the training (forward and backpropagation) and 4 | % predictions using Artificial Neural Networks (ANN). The primary purpose 5 | % is to assist in the construction of the NeuralNetPlayground framework 6 | % as well as providing a framework for training neural networks that 7 | % uses core MATLAB functionality only (i.e. no toolbox dependencies). 8 | % 9 | % This class is initialized by specifying the total number of input 10 | % layer neurons, hidden layer neurons for each hidden layer desired 11 | % and the total number of output layer neurons as a single vector. 12 | % With specified training examples and expected outputs, the neural 13 | % network weights are learned with Stochastic Gradient Descent. 14 | % 15 | % For regression, the number of output neurons is usually 1. For 16 | % binary class classification, the number of output neurons is usually 1 17 | % and you perform the appropriate thresholding to decide which class 18 | % an input belongs to. 
For multi-class classification, the number of 19 | % output neurons is usually the total number of classes in your data. 20 | % 21 | % To do non-linear regression, you usually use a non-linear hidden 22 | % activation function (sigmoid, tanh) with a linear output 23 | % activation function. However, if you want to do linear regression it may 24 | % be more stable to specify no hidden layers and just have input 25 | % and output layers only. For classification, you can specify all layers 26 | % to have the same activation function. You also have the option of 27 | % specifying no hidden layers and having a single input and output layer 28 | % which will reduce to logistic regression. 29 | % 30 | % The learned weights can be used to predict the outputs of new 31 | % examples. The loss function used in this implementation is 32 | % the sum of squared differences or Euclidean loss function. 33 | % 34 | % NeuralNet2 Properties: 35 | % LearningRate - The learning rate for Stochastic Gradient Descent 36 | % Must be strictly positive (i.e. > 0). 37 | % Default value is 0.03. 38 | % 39 | % ActivationFunction - The activation function to be applied to each 40 | % neuron in the hidden and output layer. 41 | % The ones you can choose from are: 42 | % 'linear': Linear 43 | % 'relu': Rectified Linear Unit (i.e. ramp) 44 | % 'tanh': Hyperbolic Tangent 45 | % 'sigmoid': Sigmoidal 46 | % Default is 'tanh'. 47 | % 48 | % OutputActivationFunction - The activation function to be applied 49 | % specifically to the output layer. This will 50 | % change the output layer activation function 51 | % but leave the hidden layer one intact. 52 | % The ones you can choose are the same as 53 | % seen in the ActivationFunction property. 54 | % To set this to use the same activation function 55 | % as the hidden layer, set this to empty ([]). 56 | % Default is []. 57 | % 58 | % RegularizationType - Apply regularization to the training process 59 | % if desired. 
60 | % The ones you can choose from are: 61 | % 'L1': L1 regularization 62 | % 'L2': L2 regularization 63 | % 'none': No regularization 64 | % Default is 'none' 65 | % Note: The method used for L1 regularization 66 | % comes from: 67 | % 68 | % Tsuruoka, Y., Tsujii, J., Ananiadou, S., 69 | % Stochastic Gradient Descent Training for 70 | % L1-regularized Log-linear Models with Cumulative 71 | % Penalty 72 | % http://aclweb.org/anthology/P/P09/P09-1054.pdf 73 | % 74 | % RegularizationRate - The rate of regularization to apply (if desired) 75 | % Must be non-negative (i.e. >= 0). 76 | % Default is 0. 77 | % 78 | % BatchSize - Number of training examples selected per iteration 79 | % Choosing 1 example would implement true Stochastic 80 | % Gradient Descent while choosing the total number 81 | % of examples would implement Batch Gradient Descent. 82 | % Choosing any value in between implements mini-batch 83 | % Stochastic Gradient Descent. 84 | % Must be strictly positive (i.e. > 0) and integer. 85 | % Default is 10. If the batch size is larger than 86 | % the total number of examples, the batch size 87 | % reverts to using the total number of examples. 
88 | % 89 | % ========================== 90 | % 91 | % Example Use - Binary Classification 92 | % ----------- 93 | % X = [0 0; 0 1; 1 0; 1 1]; % Define XOR data 94 | % Y = [-1; 1; 1; -1]; 95 | % net = NeuralNet2([2 4 1]); % Create Neural Network object 96 | % % Two input layer neurons, one hidden 97 | % % layer with four neurons and one output 98 | % % layer neuron 99 | % N = 5000; % Do 5000 iterations of Stochastic Gradient Descent 100 | % 101 | % % Customize Neural Network engine 102 | % net.LearningRate = 0.1; % Learning rate is set to 0.1 103 | % net.RegularizationType = 'L2'; % Regularization is L2 104 | % net.RegularizationRate = 0.001; % Regularization rate is 0.001 105 | % 106 | % perf = net.train(X, Y, N); % Train the Neural Network 107 | % Yraw = net.sim(X); % Use trained object on original examples 108 | % Ypred = ones(size(Yraw)); % Perform classification with thresholding 109 | % Ypred(Yraw < 0) = -1; 110 | % plot(1:N, perf); % Plot cost function per iteration 111 | % 112 | % % Display results 113 | % disp('Training Examples and expected labels'); display(X); display(Y); 114 | % disp('Predicted outputs'); display(Ypred); 115 | % disp('Classification accuracy: ') 116 | % disp(100 * sum(Y == Ypred) / size(X, 1)); 117 | % 118 | % 119 | % ========================== 120 | % 121 | % Example Use - Multiclass Classification 122 | % ----------- 123 | % load fisheriris; % Load the Fisher iris dataset 124 | % numFeatures = size(meas, 2); % Total number of features - should be 4 125 | % [~, ~, IDs] = unique(species); % Convert character labels to unique IDs 126 | % numClasses = max(IDs); % Get total number of possible classes 127 | % M = size(meas, 1); % Number of examples 128 | % Y = full(sparse(1 : M, IDs.', 1, M, numClasses)); % Create an output 129 | % % matrix where each row is an 130 | % % example and each column denotes 131 | % % the class the example belongs to. 132 | % % That class gets assigned 1 while 133 | % % the others get 0. 
This should be 134 | % % -1 if you are using tanh. 135 | % net = NeuralNet2([numFeatures 4 numClasses]); % Create Neural Network object 136 | % % Four input layer neurons, one hidden 137 | % % layer with four neurons and three output layer 138 | % % neurons 139 | % N = 5000; % Do 5000 iterations of Stochastic Gradient Descent 140 | % % Customize Neural Network engine 141 | % net.LearningRate = 0.1; % Learning rate is set to 0.1 142 | % net.RegularizationType = 'L2'; % Regularization is L2 143 | % net.RegularizationRate = 0.001; % Regularization rate is 0.001 144 | % net.ActivationFunction = 'sigmoid'; % sigmoid hidden activation function 145 | % perf = net.train(meas, Y, N); % Train the Neural Network 146 | % Yraw = net.sim(meas); % Use trained object on original examples 147 | % [~, Ypred] = max(Yraw, [], 2); % Determine which class has the largest 148 | % % response per example 149 | % plot(1:N, perf); % Plot cost function per iteration 150 | % % Display results 151 | % disp('Training Examples and expected labels'); display(meas); display(Y); 152 | % disp('Predicted outputs'); display(Ypred); 153 | % disp('Classification accuracy: '); 154 | % disp(100 * sum(IDs == Ypred) / M); 155 | % 156 | % 157 | % ========================== 158 | % 159 | % Example Use - Regression 160 | % ----------- 161 | % [x, y] = meshgrid(-2:0.5:2); % Define some sample two feature values 162 | % z = x .* exp(-y); % Output is the function: x*e^(-y) 163 | % X = [x(:) y(:)]; % Unroll the features so they are each in a single column 164 | % Y = z(:); % Unroll output of function into single column 165 | % net = NeuralNet2([2 10 1]); % Create Neural Network object 166 | % % Two input layer neurons, one hidden 167 | % % layer with 10 neurons and one output layer 168 | % % neuron 169 | % N = 10000; % Do 10000 iterations of Stochastic Gradient Descent 170 | % % Customize Neural Network engine 171 | % net.LearningRate = 0.1; % Learning rate is set to 0.1 172 | % net.RegularizationType = 'L2'; % 
Regularization is L2 173 | % net.RegularizationRate = 0.001; % Regularization rate is 0.001 174 | % net.ActivationFunction = 'Tanh'; % tanh hidden activation function 175 | % net.OutputActivationFunction = 'linear'; % linear output activation function 176 | % perf = net.train(X, Y, N); % Train the Neural Network 177 | % Ypred = net.sim(X); % Use trained object on original examples 178 | % plot(1:N, perf); % Plot cost function per iteration 179 | % 180 | % % Display results 181 | % disp('Absolute difference between predicted and true values'); 182 | % disp(abs(Y - Ypred)); 183 | % 184 | % ========================== 185 | % 186 | % See also NEURALNETAPP 187 | % 188 | % StackOverflowMATLABchat - http://chat.stackoverflow.com/rooms/81987/matlab-and-octave 189 | % Authors: Raymond Phan - http://stackoverflow.com/users/3250829/rayryeng 190 | % Amro - http://stackoverflow.com/users/97160/amro 191 | 192 | properties (Access = public) 193 | LearningRate % The learning rate (positive number) 194 | ActivationFunction % The desired activation function (string) 195 | OutputActivationFunction % The desired output activation function (string) 196 | RegularizationType % The type of regularization (string) 197 | RegularizationRate % The regularization rate (non-negative number) 198 | BatchSize % The size of the batch per iteration (positive integer number) 199 | end 200 | 201 | properties (Access = private) 202 | inputSize % Single value denoting how many neurons are in the input layer 203 | hiddenSizes % Vector denoting how many neurons per hidden layer 204 | outputSize % Single value denoting how many neurons are in the output layer 205 | weights % The weights of the neural network per layer 206 | end 207 | 208 | methods 209 | function this = NeuralNet2(layerSizes) 210 | % NeuralNet2 Create a Neural Network Instance 211 | % The constructor takes in a vector of layer sizes where the 212 | % first element denotes how many neurons are in the input layer, 213 | % the next N elements 
denote how many neurons are in each desired 214 | % hidden layer and the last element denotes how many neurons are 215 | % in the output layer. Take note that the number of neurons 216 | % per layer that you specify does not include the bias units. 217 | % These will be included when training the network. Therefore, 218 | % the expected size of the vector is N + 2 where N is the total 219 | % number of hidden layers for the neural network. The exception 220 | % to this rule is when you specify a vector of two elements. This 221 | % is interpreted as having an input layer, an output and no 222 | % hidden layers. This situation is when you would like to perform 223 | % simple linear or logistic regression. 224 | % 225 | % The following example creates a neural network with 1 input 226 | % neuron (plus a bias) in the input layer, 2 hidden neurons 227 | % (plus a bias) in the first hidden layer, 3 hidden neurons 228 | % (plus a bias) in the second hidden layer and 1 output neuron 229 | % in the output layer 230 | % 231 | % layerSizes = [1 2 3 1]; 232 | % net = NeuralNet2(layerSizes); 233 | 234 | % default params 235 | this.LearningRate = 0.03; 236 | this.ActivationFunction = 'Tanh'; 237 | this.OutputActivationFunction = []; 238 | this.RegularizationType = 'None'; 239 | this.RegularizationRate = 0; 240 | this.BatchSize = 10; 241 | 242 | assert(numel(layerSizes) >= 2, 'Total number of layers should be at least 2'); 243 | 244 | % network structure (fully-connected feed-forward) 245 | % Obtain input layer neuron size 246 | this.inputSize = layerSizes(1); 247 | 248 | % Obtain the hidden layer neuron sizes 249 | if numel(layerSizes) ~= 2 250 | this.hiddenSizes = layerSizes(2:end-1); 251 | else 252 | this.hiddenSizes = 0; 253 | end 254 | 255 | % Obtain the output layer neuron size 256 | this.outputSize = layerSizes(end); 257 | 258 | % Initialize matrices relating between the ith layer 259 | % and (i+1)th layer 260 | this.weights = cell(1, numel(layerSizes) - 1); 261 | for i = 
1 : numel(layerSizes) - 1 262 | this.weights{i} = zeros(layerSizes(i) + 1, layerSizes(i + 1)); 263 | end 264 | 265 | % Initialize weights 266 | init(this); 267 | end 268 | 269 | % ??? 270 | function configure(this, X, Y) 271 | % check correct sizes 272 | [xrows,xcols] = size(X); 273 | [yrows,ycols] = size(Y); 274 | assert(xrows == yrows); 275 | assert(xcols == this.inputSize); 276 | assert(ycols == this.outputSize); 277 | 278 | % min/max of inputs/outputs 279 | inMin = min(X); 280 | inMax = max(X); 281 | outMin = min(Y); 282 | outMax = max(Y); 283 | end 284 | 285 | function init(this) 286 | % init Initialize the Neural Network Weights 287 | % This method initializes the neural network weights 288 | % for connections using the initialization suggested by 289 | % Kaiming He: 290 | % 291 | % He, K., Zhang, X., Ren, S., & Sun, J. (2015). 292 | % Delving deep into rectifiers: Surpassing human-level performance 293 | % on ImageNet classification. arXiv preprint arXiv:1502.01852. 294 | % https://arxiv.org/abs/1502.01852 295 | % 296 | % Take note that this method is run when you create an instance 297 | % of the object. You would call init if you want to reinitialize 298 | % the neural network and start from the beginning. 299 | % 300 | % Uses: 301 | % net = NeuralNet2([1 2 1]); 302 | % % Other code... 303 | % % ... 304 | % % (Re-)initialize weights 305 | % net.init(); 306 | 307 | for ii = 1 : numel(this.weights) 308 | num = numel(this.weights{ii}); 309 | % Kaiming He et al. initialization strategy 310 | this.weights{ii}(:) = 2.0 * randn(num, 1) / ... 311 | sqrt(size(this.weights{ii}, 1)); 312 | end 313 | end 314 | 315 | function perf = train(this, X, Y, numIter) 316 | % train Perform neural network training with Stochastic Gradient Descent (SGD) 317 | % This method performs training on the neural network structure 318 | % that was specified when creating an instance of the class. 
319 | % Using training example features and their expected outcomes, 320 | % trained network weights are created to facilitate future 321 | % predictions. 322 | % 323 | % Inputs: 324 | % X - Training example features as a 2D matrix of size M x N 325 | % M is the total number of examples and N is the 326 | % total number of features. N is expected to be 327 | % equal to the number of input layer neurons. 328 | % 329 | % Y - Training example expected outputs as a 2D matrix of size 330 | % M x P where M is the total number of examples and P is 331 | % the total number of output neurons in the output layer. 332 | % 333 | % numIter - Number of iterations Stochastic Gradient Descent 334 | % should take while training. This is an optional 335 | % parameter and the number of iterations defaults to 336 | % 1 if omitted. 337 | % 338 | % Outputs: 339 | % perf - An array of size 1 x numIter which denotes 340 | % the cost between the predicted outputs and 341 | % expected outputs at each iteration of learning 342 | % the weights 343 | % 344 | % Uses: 345 | % net = NeuralNet2([1 2 1]); % Create NN object 346 | % % Create example data here stored in X and Y 347 | % %... 348 | % %... 349 | % perf = net.train(X, Y); % Perform 1 iteration 350 | % perf2 = net.train(X, Y, 500); % Perform 500 iterations 351 | % 352 | % See also NEURALNET2, SIM 353 | 354 | % If the number of iterations is not specified, assume 1 355 | if nargin < 4, numIter = 1; end 356 | 357 | % Ensure correct sizes 358 | assert(size(X, 1) == size(Y, 1), ['Total number of examples ' ... 359 | 'in the inputs and outputs should match']) 360 | 361 | % Ensure regularization rate and batch size are proper 362 | assert(this.BatchSize >= 1, 'Batch size should be 1 or more'); 363 | assert(this.RegularizationRate >= 0, ['Regularization rate ' ... 
364 | 'should be 0 or larger']); 365 | 366 | % Check if we have specified the right regularization type 367 | regType = this.RegularizationType; 368 | assert(any(strcmpi(regType, {'l1', 'l2', 'none'})), ... 369 | ['Ensure that you choose one of ''l1'', ''l2'' or '... 370 | '''none'' for the regularization type']); 371 | 372 | % Ensure number of iterations is strictly positive 373 | assert(numIter >= 1, 'Number of iterations should be positive'); 374 | 375 | % Initialize cost function array 376 | perf = zeros(1, numIter); 377 | 378 | % Total number of examples 379 | N = size(X, 1); 380 | 381 | % Total number of applicable layers 382 | L = numel(this.weights); 383 | 384 | % Get batch size 385 | % Remove decimal places in case of improper input 386 | B = floor(this.BatchSize); 387 | 388 | % Safely catch if batch size is larger than total number 389 | % of examples 390 | if B > N 391 | B = N; 392 | end 393 | 394 | % Cell array to store the inputs into each neuron 395 | sNeuron = cell(1, L); 396 | 397 | % Cell array to store the outputs of each layer (first cell is the input layer) 398 | xNeuron = cell(1, L + 1); 399 | 400 | % Cell array for storing the sensitivities 401 | delta = cell(1, L); 402 | 403 | % For L1 regularization 404 | if strcmpi(regType, 'l1') 405 | % This represents the total L1 penalty that each 406 | % weight could have received up to current point 407 | uk = 0; 408 | 409 | % Total penalty for each weight that was received up to 410 | % current point 411 | qk = cell(1, L); 412 | for ii = 1 : L 413 | qk{ii} = zeros(size(this.weights{ii})); 414 | end 415 | end 416 | 417 | % Get activation function for the hidden layer 418 | fcn = getActivationFunction(this.ActivationFunction); 419 | 420 | % Get derivative of activation function for the hidden layer 421 | dfcn = getDerivativeActivationFunction(this.ActivationFunction); 422 | 423 | % Do the same for the output layer 424 | if isempty(this.OutputActivationFunction) 425 | fcnOut = fcn; 426 | dfcnOut = dfcn; 427 | else 428 | fcnOut = 
getActivationFunction(this.OutputActivationFunction); 429 | dfcnOut = getDerivativeActivationFunction(this.OutputActivationFunction); 430 | end 431 | 432 | % For each iteration... 433 | for ii = 1 : numIter 434 | % If the batch size is equal to the total number of examples 435 | % don't bother with random selection as this will be a full 436 | % batch gradient descent 437 | if N == B 438 | ind = 1 : N; 439 | else 440 | % Randomly select examples corresponding to the batch size 441 | % if the batch size is not equal to the number of examples 442 | ind = randperm(N); 443 | ind = ind(1 : B); 444 | end 445 | 446 | % Select out the training example features and expected outputs 447 | IN = X(ind, :); 448 | OUT = Y(ind, :); 449 | 450 | % Initialize input layer 451 | xNeuron{1} = [IN ones(B, 1)]; 452 | 453 | %%% Perform forward propagation 454 | % Make sure you save the inputs and outputs into each neuron 455 | % at the hidden and output layers 456 | for jj = 1 : L 457 | % Compute inputs into next layer 458 | sNeuron{jj} = xNeuron{jj} * this.weights{jj}; 459 | 460 | % Compute outputs of this layer 461 | if jj == L 462 | xNeuron{jj + 1} = fcnOut(sNeuron{jj}); 463 | else 464 | xNeuron{jj + 1} = [fcn(sNeuron{jj}) ones(B, 1)]; 465 | end 466 | end 467 | 468 | %%% Perform backpropagation 469 | 470 | % Compute sensitivities for output layer 471 | delta{end} = (xNeuron{end} - OUT) .* dfcnOut(sNeuron{end}); 472 | 473 | % Compute the sensitivities for the rest of the layers 474 | for jj = L - 1 : -1 : 1 475 | delta{jj} = dfcn(sNeuron{jj}) .* ... 
476 | (delta{jj + 1}*(this.weights{jj + 1}(1 : end - 1, :)).'); 477 | end 478 | 479 | %%% Compute weight updates 480 | alpha = this.LearningRate; 481 | lambda = this.RegularizationRate; 482 | for jj = 1 : L 483 | % Obtain the outputs and sensitivities for each 484 | % affected layer 485 | XX = xNeuron{jj}; 486 | D = delta{jj}; 487 | 488 | % Calculate batch weight update 489 | weight_update = (1 / B) * (XX.') * D; 490 | 491 | % Apply L2 regularization if required 492 | if strcmpi(regType, 'l2') 493 | weight_update(1 : end-1, :) = weight_update(1 : end - 1, :) + ... 494 | (lambda/B)*this.weights{jj}(1 : end - 1, :); 495 | end 496 | 497 | % Compute the final update 498 | this.weights{jj} = this.weights{jj} - alpha * weight_update; 499 | end 500 | 501 | % Apply L1 regularization if required 502 | if strcmpi(regType, 'l1') 503 | % Step #1 - Accumulate total L1 penalty that each 504 | % weight could have received up to this point 505 | uk = uk + (alpha * lambda / B); 506 | 507 | % Step #2 508 | % Using the updated weights, now apply the penalties 509 | for jj = 1 : L 510 | % 2a - Save previous weights and penalties 511 | % Make sure to remove bias terms 512 | z = this.weights{jj}(1 : end - 1,:); 513 | q = qk{jj}(1 : end - 1,:); 514 | 515 | % 2b - Using the previous weights, find the weights 516 | % that are positive and negative 517 | w = z; 518 | indwp = w > 0; 519 | indwn = w < 0; 520 | 521 | % 2c - Perform the update on each condition 522 | % individually 523 | w(indwp) = max(0, w(indwp) - (uk + q(indwp))); 524 | w(indwn) = min(0, w(indwn) + (uk - q(indwn))); 525 | 526 | % 2d - Update the actual penalties 527 | qk{jj}(1:end-1,:) = q + (w - z); 528 | 529 | % Don't forget to update the actual weights! 
530 | this.weights{jj}(1 : end - 1, :) = w; 531 | end 532 | end 533 | 534 | % Compute cost at this iteration 535 | perf(ii) = (0.5 / B)*sum(sum((xNeuron{end} - OUT).^2)); 536 | 537 | % Add in regularization if necessary 538 | if strcmpi(regType, 'l1') 539 | for jj = 1 : L 540 | perf(ii) = perf(ii) + ... 541 | (lambda / B) * sum(sum(abs(this.weights{jj}(1 : end - 1, :)))); 542 | end 543 | elseif strcmpi(regType, 'l2') 544 | for jj = 1 : L 545 | perf(ii) = perf(ii) + ... 546 | (0.5 * lambda / B) * sum(sum((this.weights{jj}(1 : end - 1, :)).^2)); 547 | end 548 | end 549 | end 550 | end 551 | 552 | function [OUT, OUTS] = sim(this, X) 553 | % sim Perform Neural Network Predictions 554 | % This method performs forward propagation using the 555 | % learned weights after training. Forward propagation 556 | % uses the learned weights to propagate information 557 | % throughout the neural network and the predicted outcome 558 | % is seen at the output layer. 559 | % 560 | % Inputs: 561 | % X - Training examples to predict their outcomes 562 | % This is a M x N matrix where M is the total number of 563 | % examples and N is the total number of features. 564 | % N must equal the total number of neurons in the input 565 | % layer 566 | % 567 | % Outputs: 568 | % OUT - The predicted outputs using the training examples in X. 569 | % This is a M x P matrix where M is the total number of 570 | % training examples and P is the total number of output 571 | % neurons 572 | % 573 | % OUTS - This is a 1D cell array of size 1 x NN where NN 574 | % is the total number of layers in the neural network 575 | % including the input and output layers. Therefore, if 576 | % K is equal to the total number of hidden layers, 577 | % NN = K+2. This cell array contains the outputs of 578 | % each example per layer. Specifically, each element 579 | % OUTS{ii} would be a M x Q matrix where Q would be the 580 | % total number of neurons in layer ii without the bias 581 | % unit. 
ii=1 is the input layer and ii=NN is the output 582 | % layer. Remember if you specify a network with no hidden 583 | % layers, this would mean that K=0 and so NN = 2. 584 | % OUTS{ii} would contain all of the outputs for each 585 | % example in layer ii. 586 | % 587 | % Uses: 588 | % net = NeuralNet2([1 2 1]); % Create NN object 589 | % % Create your training data and train your neural network 590 | % % here... 591 | % % Also create test data here stored in XX... 592 | % % ... 593 | % OUT = net.sim(XX); % Find predictions 594 | % [OUT,OUTS] = net.sim(XX); % Find predictions and outputs 595 | % % per layer 596 | % 597 | % See also NEURALNET2, TRAIN 598 | 599 | % Check if the total number of features matches the 600 | % total number of input neurons 601 | assert(size(X, 2) == this.inputSize, ['Number of features '... 602 | 'should match the number of input neurons']); 603 | 604 | % Get total number of examples 605 | N = size(X, 1); 606 | 607 | %%% Begin algorithm 608 | % Start with input layer 609 | OUT = X; 610 | 611 | % Also initialize cell array with input layer's contents 612 | OUTS = cell(1, numel(this.weights) + 1); 613 | OUTS{1} = OUT; 614 | 615 | % Get activation function 616 | fcn = getActivationFunction(this.ActivationFunction); 617 | 618 | if isempty(this.OutputActivationFunction) 619 | fcnOut = fcn; 620 | else 621 | fcnOut = getActivationFunction(this.OutputActivationFunction); 622 | end 623 | 624 | % For each layer... 
625 | for ii = 1 : numel(this.weights) 626 | % Compute inputs into each neuron and corresponding 627 | % outputs 628 | if ii == numel(this.weights) 629 | OUT = fcnOut([OUT ones(N, 1)] * this.weights{ii}); 630 | else 631 | OUT = fcn([OUT ones(N, 1)] * this.weights{ii}); 632 | end 633 | OUTS{ii + 1} = OUT; 634 | end 635 | end 636 | end 637 | end 638 | 639 | function fcn = getActivationFunction(activation) 640 | switch lower(activation) 641 | case 'linear' 642 | fcn = @f_linear; 643 | case 'relu' 644 | fcn = @f_relu; 645 | case 'tanh' 646 | fcn = @f_tanh; 647 | case 'sigmoid' 648 | fcn = @f_sigmoid; 649 | otherwise 650 | error('Unknown activation function'); 651 | end 652 | end 653 | 654 | function fcn = getDerivativeActivationFunction(activation) 655 | switch lower(activation) 656 | case 'linear' 657 | fcn = @fd_linear; 658 | case 'relu' 659 | fcn = @fd_relu; 660 | case 'tanh' 661 | fcn = @fd_tanh; 662 | case 'sigmoid' 663 | fcn = @fd_sigmoid; 664 | otherwise 665 | error('Unknown activation function'); 666 | end 667 | end 668 | 669 | % activation functions and their derivatives 670 | function y = f_linear(x) 671 | % See also: purelin 672 | y = x; 673 | end 674 | 675 | function y = fd_linear(x) 676 | % See also: dpurelin 677 | y = ones(size(x)); 678 | end 679 | 680 | function y = f_relu(x) 681 | % See also: poslin 682 | y = max(x, 0); 683 | end 684 | 685 | function y = fd_relu(x) 686 | % See also: dposlin 687 | y = double(x >= 0); 688 | end 689 | 690 | function y = f_tanh(x) 691 | % See also: tansig 692 | %y = 2 ./ (1 + exp(-2*x)) - 1; 693 | y = tanh(x); 694 | end 695 | 696 | function y = fd_tanh(x) 697 | % See also: dtansig 698 | y = f_tanh(x); 699 | y = 1 - y.^2; 700 | end 701 | 702 | function y = f_sigmoid(x) 703 | % See also: logsig 704 | y = 1 ./ (1 + exp(-x)); 705 | end 706 | 707 | function y = fd_sigmoid(x) 708 | % See also: dlogsig 709 | y = f_sigmoid(x); 710 | y = y .* (1 - y); 711 | end 712 | 
-------------------------------------------------------------------------------- /NeuralNetApp.m: -------------------------------------------------------------------------------- 1 | classdef NeuralNetApp < handle 2 | %NEURALNETAPP Neural network Application 3 | % 4 | % Inspired by TensorFlow playground: 5 | % http://playground.tensorflow.org/ 6 | % 7 | % See also: NeuralNet2 8 | % 9 | 10 | %% Properties 11 | properties (SetAccess = private) 12 | % structure containing graphics handles 13 | handles 14 | % structure containing data 15 | data 16 | % neural network object 17 | net 18 | end 19 | 20 | properties (Access = private) 21 | % app state (running or paused) 22 | isRunning = false; 23 | end 24 | 25 | properties (Access = private, Constant = true) 26 | % maximum number of hidden layers in the UI 27 | MAX_LAYERS = 5; 28 | % maximum number of neurons in each hidden layer in the UI 29 | MAX_NEURONS = 8; 30 | % data domain for both X/Y dimensions 31 | DOM = [-6 6]; 32 | % colormap 33 | CMAP = getColormap(); 34 | end 35 | 36 | %% Constructor 37 | methods 38 | function this = NeuralNetApp() 39 | %NEURALNETAPP Constructor 40 | 41 | % initialize UI 42 | createGUI(this); 43 | updateHiddenLayers(this, [4 2]); 44 | 45 | % generate data 46 | genData(this); 47 | updatePlotScatter(this); 48 | 49 | % create network 50 | buildNet(this); 51 | reset(this); 52 | end 53 | 54 | function delete(this) 55 | %DELETE Destructor 56 | 57 | % close figure 58 | %delete(this.handles.hFig); 59 | end 60 | end 61 | 62 | %% Private Methods 63 | methods (Access = private) 64 | function genData(this) 65 | %GENDATA Generate new data according to selected dataset and options 66 | 67 | % generate 2D points with corresponding binary labels {-1,+1} 68 | N = 500; % number of points 69 | ratio = get(this.handles.hSlidRatio, 'Value') / 100; 70 | noise = get(this.handles.hSlidNoise, 'Value') / 100; 71 | switch get(this.handles.hPopData, 'Value') 72 | case 1 73 | [points,labels] = genDataCircle(N, noise); 
74 | case 2 75 | [points,labels] = genDataXOR(N, noise); 76 | case 3 77 | [points,labels] = genDataGaussian(N, noise); 78 | case 4 79 | [points,labels] = genDataSpiral(N, noise); 80 | otherwise 81 | error('Unrecognized dataset'); 82 | end 83 | 84 | % clip to [-6,6] range 85 | %points = min(max(points, this.DOM(1)), this.DOM(2)); 86 | 87 | % store data, partition indices, and mapped features 88 | this.data.points = points; 89 | this.data.labels = labels; 90 | this.data.trainIDX = splitData(labels, ratio); 91 | this.data.inputs = mapPoints(points, getInputsMask(this)); 92 | end 93 | 94 | function buildNet(this) 95 | %BUILDNET Create new neural network object 96 | 97 | % build net using specified layer sizes 98 | inputSize = size(this.data.inputs, 2); 99 | hiddenSizes = getHiddenSizes(this); 100 | outputSize = 1; 101 | this.net = NeuralNet2([inputSize hiddenSizes outputSize]); 102 | 103 | % set network parameters from GUI options 104 | this.net.BatchSize = round(get(this.handles.hSlidBatch, 'Value')); 105 | 106 | vals = cellstr(get(this.handles.hPopLearnRate, 'String')); 107 | idx = get(this.handles.hPopLearnRate, 'Value'); 108 | this.net.LearningRate = str2double(vals{idx}); 109 | 110 | vals = cellstr(get(this.handles.hPopActFunc, 'String')); 111 | idx = get(this.handles.hPopActFunc, 'Value'); 112 | this.net.ActivationFunction = vals{idx}; 113 | 114 | vals = cellstr(get(this.handles.hPopRegType, 'String')); 115 | idx = get(this.handles.hPopRegType, 'Value'); 116 | this.net.RegularizationType = vals{idx}; 117 | 118 | vals = cellstr(get(this.handles.hPopRegRate, 'String')); 119 | idx = get(this.handles.hPopRegRate, 'Value'); 120 | this.net.RegularizationRate = str2double(vals{idx}); 121 | 122 | vals = cellstr(get(this.handles.hPopProblem, 'String')); 123 | idx = get(this.handles.hPopProblem, 'Value'); 124 | switch vals{idx} 125 | case 'Classification' 126 | case 'Regression' 127 | end 128 | 129 | % configure network 130 | configure(this.net, 
this.data.inputs(this.data.trainIDX,:), ... 131 | this.data.labels(this.data.trainIDX)); 132 | end 133 | 134 | function reset(this) 135 | %RESET Reset UI state 136 | 137 | % paused, reset iterations, reset loss plots 138 | this.isRunning = false; 139 | set(this.handles.hBtnRun, 'String','Run'); 140 | set(this.handles.hTxtIter, 'String','000'); 141 | set(this.handles.hLineLoss, 'XData',NaN, 'YData',NaN); 142 | set(this.handles.hLgndLoss, 'String',... 143 | {'Test loss: 0.000', 'Training loss: 0.000'}); 144 | 145 | % re-initialize network, and run a pass to update UI 146 | init(this.net); 147 | step(this, true); 148 | drawnow(); 149 | end 150 | 151 | function step(this, skipTrain) 152 | %STEP Run one iteration: train, evaluate, and update UI 153 | 154 | if nargin < 2, skipTrain = false; end 155 | 156 | % data 157 | inputTrain = this.data.inputs(this.data.trainIDX,:); 158 | inputTest = this.data.inputs(~this.data.trainIDX,:); 159 | labelsTrain = this.data.labels(this.data.trainIDX); 160 | labelsTest = this.data.labels(~this.data.trainIDX); 161 | 162 | if ~skipTrain 163 | % increment iteration count 164 | iter = str2double(get(this.handles.hTxtIter, 'String')); 165 | set(this.handles.hTxtIter, 'String',sprintf('%03d',iter+1)); 166 | 167 | % train network 168 | train(this.net, inputTrain, labelsTrain); 169 | end 170 | 171 | % evaluate network 172 | lossTrain = mseLoss(sim(this.net, inputTrain), labelsTrain); 173 | lossTest = mseLoss(sim(this.net, inputTest), labelsTest); 174 | 175 | % update plots 176 | updatePlotLoss(this, lossTest, lossTrain); 177 | updatePlotHeatmaps(this); 178 | end 179 | 180 | function updatePlotLoss(this, lossTest, lossTrain) 181 | %UPDATEPLOTLOSS Update loss plot 182 | 183 | % add new point to each line 184 | y1 = [get(this.handles.hLineLoss(1), 'YData'), lossTest]; 185 | y2 = [get(this.handles.hLineLoss(2), 'YData'), lossTrain]; 186 | x = [NaN 2:numel(y1)]; 187 | set(this.handles.hLineLoss(1), 'XData',x, 'YData',y1); 188 | 
set(this.handles.hLineLoss(2), 'XData',x, 'YData',y2); 189 | 190 | % update legend strings 191 | set(this.handles.hLgndLoss, 'String',... 192 | {sprintf('Test loss: %.3f',lossTest), ... 193 | sprintf('Training loss: %.3f',lossTrain)}); 194 | end 195 | 196 | function updatePlotHeatmaps(this) 197 | %UPDATEPLOTHEATMAPS Update heatmap plots (hidden and final output) 198 | 199 | % options 200 | inputsIdx = getInputsMask(this); 201 | discretize = logical(get(this.handles.hCBoxDiscretize, 'Value')); 202 | 203 | % update hidden neurons heatmaps 204 | [X1,X2] = meshgrid(linspace(this.DOM(1),this.DOM(2),30)); 205 | [~,A] = sim(this.net, mapPoints([X1(:) X2(:)], inputsIdx)); 206 | A = A(2:end-1); % ignore input/output layers outputs 207 | for layer=1:numel(A) 208 | a = reshape(A{layer}, 30, 30, []); 209 | for neuron=1:size(a,3) 210 | img = genThumbnail(a(:,:,neuron), this.CMAP, discretize); 211 | set(this.handles.hBtnNeuron(neuron,layer), 'CData',img); 212 | end 213 | end 214 | 215 | % update output heatmap 216 | [X1,X2] = meshgrid(linspace(this.DOM(1),this.DOM(2),250)); 217 | a = sim(this.net, mapPoints([X1(:) X2(:)], inputsIdx)); 218 | a = reshape(a, 250, 250); 219 | if discretize 220 | a = sign(a); 221 | end 222 | set(this.handles.hImgOut, 'CData',a); 223 | end 224 | 225 | function updatePlotScatter(this) 226 | %UPDATEPLOTSCATTER Update scatter plots 227 | 228 | % data 229 | ptsTrain = this.data.points(this.data.trainIDX,:); 230 | ptsTest = this.data.points(~this.data.trainIDX,:); 231 | labelsTrain = this.data.labels(this.data.trainIDX); 232 | labelsTest = this.data.labels(~this.data.trainIDX); 233 | 234 | % update scatter points 235 | klass = [-1 1]; 236 | for k=1:numel(klass) 237 | idx = (labelsTrain == klass(k)); 238 | set(this.handles.hLineTrain(k), ... 239 | 'XData',ptsTrain(idx,1), 'YData',ptsTrain(idx,2)); 240 | idx = (labelsTest == klass(k)); 241 | set(this.handles.hLineTest(k), ... 
242 | 'XData',ptsTest(idx,1), 'YData',ptsTest(idx,2)); 243 | end 244 | end 245 | 246 | function updateHiddenLayers(this, hiddenSizes) 247 | %UPDATEHIDDENLAYERS Update hidden layers UI to match new size 248 | 249 | % mask of active layers and neurons 250 | maskL = ((1:this.MAX_LAYERS) <= numel(hiddenSizes)); 251 | maskN = zeros(1, this.MAX_LAYERS); 252 | maskN(maskL) = hiddenSizes; 253 | maskN = bsxfun(@le, (1:this.MAX_NEURONS)', maskN); 254 | 255 | % refresh layers count and add/remove buttons 256 | set(this.handles.hTxtLayerNum, ... 257 | 'String',pluralize(numel(hiddenSizes), 'hidden layer')); 258 | [sAdd, sDel] = getButtonState(numel(hiddenSizes), this.MAX_LAYERS); 259 | set(this.handles.hBtnLayerAdd, 'Enable',sAdd); 260 | set(this.handles.hBtnLayerDel, 'Enable',sDel); 261 | 262 | % refresh neurons heatmaps 263 | vals = {'off'; 'on'}; 264 | set(this.handles.hBtnNeuron(:), {'Visible'},vals(maskN(:)+1)); 265 | 266 | % refresh neurons count 267 | set(this.handles.hTxtNeuronNum(~maskL), 'String','0 neuron'); 268 | set(this.handles.hTxtNeuronNum(maskL), {'String'},... 269 | cellstr(pluralize(hiddenSizes(:), 'neuron'))); 270 | 271 | % refresh neurons add/remove buttons 272 | set([this.handles.hBtnNeuronAdd(~maskL), ... 
273 | this.handles.hBtnNeuronDel(~maskL)], 'Enable','off'); 274 | for i=1:numel(hiddenSizes) 275 | [sAdd, sDel] = getButtonState(hiddenSizes(i), this.MAX_NEURONS); 276 | set(this.handles.hBtnNeuronAdd(i), 'Enable',sAdd); 277 | set(this.handles.hBtnNeuronDel(i), 'Enable',sDel); 278 | end 279 | end 280 | 281 | function idx = getInputsMask(this) 282 | %GETINPUTSMASK Return a mask of selected inputs from UI 283 | 284 | val = get(this.handles.hTBtnInput, 'Value'); 285 | idx = logical(cell2mat(val)); 286 | end 287 | 288 | function sz = getHiddenSizes(this) 289 | %GETHIDDENSIZES Return hidden layers sizes from UI 290 | 291 | str = get(this.handles.hTxtNeuronNum, 'String'); 292 | sz = nonzeros(cellfun(@(s) sscanf(s, '%d'), str)).'; 293 | end 294 | end 295 | 296 | %% UI 297 | methods (Access = private) 298 | function createGUI(this) 299 | %CREATEGUI Build the UI 300 | 301 | % main figure 302 | hFig = figure('Menubar','none', 'Toolbar','none', ... 303 | 'NumberTitle','off', 'Name','Neural Network Playground', ... 304 | 'Colormap',this.CMAP, 'Resize','off', 'Visible','off', ... 305 | 'Units','pixels', 'Position',[100 100 1000 600]); 306 | 307 | this.handles = struct(); 308 | this.handles.hFig = hFig; 309 | 310 | % create panels 311 | hPan = createGUI_Panels(this, hFig); 312 | createGUI_PanelHeader(this, hPan(1)); 313 | createGUI_PanelData(this, hPan(2)); 314 | createGUI_PanelInput(this, hPan(3)); 315 | createGUI_PanelLayers(this, hPan(4)); 316 | createGUI_PanelOutput(this, hPan(5)); 317 | 318 | % setup event handlers 319 | registerCallbacks(this); 320 | 321 | % make figure visible 322 | set(hFig, 'Visible','on'); 323 | end 324 | 325 | function hPan = createGUI_Panels(this, hParent) 326 | %CREATEGUI_PANELS Create the layout of the top-level panels 327 | 328 | % panel properties 329 | props = {'ForegroundColor',0.3*[1 1 1]}; 330 | titles = upper({'', 'Data', 'Input', 'Hidden Layers', 'Output'}); 331 | pos = {[5 520 990 75], [5 5 135 510], [145 5 115 510], ... 
332 | [265 5 400 510], [670 5 325 510]}; 333 | 334 | % panels 335 | hPan = gobj(1, numel(titles)); 336 | for i=1:numel(titles) 337 | hPan(i) = uipanel('Parent',hParent, 'Title',titles{i}, ... 338 | props{:}, 'Units','pixels', 'Position',pos{i}); 339 | end 340 | end 341 | 342 | function createGUI_PanelHeader(this, hParent) 343 | %CREATEGUI_PANELHEADER Create the header panel UI 344 | 345 | % dropdown menu properties and values 346 | props = {'ForegroundColor',0.5*[1 1 1], ... 347 | 'HorizontalAlignment','left'}; 348 | valsLR = bsxfun(@times, [1;3], 10.^(-5:1)); 349 | valsLR = valsLR(:); 350 | valsAct = {'ReLU', 'Tanh', 'Sigmoid', 'Linear'}; 351 | valsRegType = {'None', 'L1', 'L2'}; 352 | valsReg = bsxfun(@times, [1;3], 10.^(-3:1)); 353 | valsReg = [0; valsReg(:)]; 354 | valsProb = {'Classification', 'Regression'}; 355 | 356 | % run/step/reset buttons 357 | hBtnRun = uicontrol('Parent',hParent, 'Style','pushbutton', ... 358 | 'String','Run', 'Units','pixels', 'Position',[10 45 60 20]); 359 | hBtnStep = uicontrol('Parent',hParent, 'Style','pushbutton', ... 360 | 'String','Step', 'Units','pixels', 'Position',[10 25 60 20]); 361 | hBtnReset = uicontrol('Parent',hParent, 'Style','pushbutton', ... 362 | 'String','Reset', 'Units','pixels', 'Position',[10 5 60 20]); 363 | 364 | % dropdown menus 365 | uicontrol('Parent',hParent, 'Style','text', props{:}, ... 366 | 'String','Iterations', ... 367 | 'Units','pixels', 'Position',[100 40 70 20]); 368 | hTxtIter = uicontrol('Parent',hParent, 'Style','text', ... 369 | 'String','000', 'HorizontalAlignment','right', ... 370 | 'FontSize',14, 'FontWeight','bold', ... 371 | 'Units','pixels', 'Position',[100 15 70 25]); 372 | 373 | uicontrol('Parent',hParent, 'Style','text', props{:}, ... 374 | 'String','Learning rate', ... 375 | 'Units','pixels', 'Position',[210 40 120 20]); 376 | hPopLearnRate = uicontrol('Parent',hParent, 'Style','popupmenu', ... 377 | 'String',valsLR, 'Value',8, ... 
378 | 'Units','pixels', 'Position',[210 15 120 20]); 379 | 380 | uicontrol('Parent',hParent, 'Style','text', props{:}, ... 381 | 'String','Activation', 'TooltipString','Activation function', ... 382 | 'Units','pixels', 'Position',[370 40 120 20]); 383 | hPopActFunc = uicontrol('Parent',hParent, 'Style','popupmenu', ... 384 | 'String',valsAct, 'Value',2, ... 385 | 'Units','pixels', 'Position',[370 15 120 20]); 386 | 387 | uicontrol('Parent',hParent, 'Style','text', props{:}, ... 388 | 'String','Regularization', ... 389 | 'Units','pixels', 'Position',[530 40 120 20]); 390 | hPopRegType = uicontrol('Parent',hParent, 'Style','popupmenu', ... 391 | 'String',valsRegType, 'Value',1, ... 392 | 'Units','pixels', 'Position',[530 15 120 20]); 393 | 394 | uicontrol('Parent',hParent, 'Style','text', props{:}, ... 395 | 'String','Regularization rate', ... 396 | 'Units','pixels', 'Position',[690 40 120 20]); 397 | hPopRegRate = uicontrol('Parent',hParent, 'Style','popupmenu', ... 398 | 'String',valsReg, 'Value',1, ... 399 | 'Units','pixels', 'Position',[690 15 120 20]); 400 | 401 | uicontrol('Parent',hParent, 'Style','text', props{:}, ... 402 | 'String','Problem type', ... 403 | 'Units','pixels', 'Position',[850 40 120 20]); 404 | hPopProblem = uicontrol('Parent',hParent, 'Style','popupmenu', ... 405 | 'String',valsProb, 'Value',1, ... 
406 | 'Units','pixels', 'Position',[850 15 120 20]); 407 | 408 | % store handles 409 | this.handles.hBtnRun = hBtnRun; 410 | this.handles.hBtnStep = hBtnStep; 411 | this.handles.hBtnReset = hBtnReset; 412 | this.handles.hTxtIter = hTxtIter; 413 | this.handles.hPopLearnRate = hPopLearnRate; 414 | this.handles.hPopActFunc = hPopActFunc; 415 | this.handles.hPopRegType = hPopRegType; 416 | this.handles.hPopRegRate = hPopRegRate; 417 | this.handles.hPopProblem = hPopProblem; 418 | end 419 | 420 | function createGUI_PanelData(this, hParent) 421 | %CREATEGUI_PANELDATA Create the data panel UI 422 | 423 | % properties 424 | props = {'ForegroundColor',0.5*[1 1 1], ... 425 | 'HorizontalAlignment','left'}; 426 | valsDatasets = {'Circle', 'XOR', 'Gaussian', 'Spiral'}; 427 | 428 | % dataset dropdown menu 429 | uicontrol('Parent',hParent, 'Style','text', props{:}, ... 430 | 'String',{'Which dataset do','you want to use?'}, ... 431 | 'Units','pixels', 'Position',[5 460 120 30]); 432 | hPopData = uicontrol('Parent',hParent, 'Style','popupmenu', ... 433 | 'String',valsDatasets, 'Value',1, ... 434 | 'Units','pixels', 'Position',[5 435 120 20]); 435 | 436 | % data sliders 437 | hTxtRatio = uicontrol('Parent',hParent, 'Style','text', props{:}, ... 438 | 'String',{'Ratio of training to','test data: 50%'}, ... 439 | 'Units','pixels', 'Position',[5 385 120 30]); 440 | hSlidRatio = uicontrol('Parent',hParent, 'Style','slider', ... 441 | 'Value',50, 'Min',10, 'Max',90, 'SliderStep',[5 10]./(90-10), ... 442 | 'Units','pixels', 'Position',[5 360 120 20]); 443 | 444 | hTxtNoise = uicontrol('Parent',hParent, 'Style','text', props{:}, ... 445 | 'String','Noise: 10', ... 446 | 'Units','pixels', 'Position',[5 315 120 20]); 447 | hSlidNoise = uicontrol('Parent',hParent, 'Style','slider', ... 448 | 'Value',10, 'Min',0, 'Max',50, 'SliderStep',[5 10]./(50-0), ... 449 | 'Units','pixels', 'Position',[5 295 120 20]); 450 | 451 | hTxtBatch = uicontrol('Parent',hParent, 'Style','text', props{:}, ... 
452 | 'String','Batch Size: 10', ... 453 | 'Units','pixels', 'Position',[5 250 120 20]); 454 | hSlidBatch = uicontrol('Parent',hParent, 'Style','slider', ... 455 | 'Value',10, 'Min',1, 'Max',30, 'SliderStep',[1 10]./(30-1), ... 456 | 'Units','pixels', 'Position',[5 230 120 20]); 457 | 458 | % gen button 459 | hBtnGen = uicontrol('Parent',hParent, 'Style','pushbutton', ... 460 | 'String','Regenerate', ... 461 | 'Units','pixels', 'Position',[5 180 120 20]); 462 | 463 | % store handles 464 | this.handles.hPopData = hPopData; 465 | this.handles.hTxtRatio = hTxtRatio; 466 | this.handles.hSlidRatio = hSlidRatio; 467 | this.handles.hTxtNoise = hTxtNoise; 468 | this.handles.hSlidNoise = hSlidNoise; 469 | this.handles.hTxtBatch = hTxtBatch; 470 | this.handles.hSlidBatch = hSlidBatch; 471 | this.handles.hBtnGen = hBtnGen; 472 | end 473 | 474 | function createGUI_PanelInput(this, hParent) 475 | %CREATEGUI_PANELINPUT Create the input panel UI 476 | 477 | % render thumbnail image of each feature 478 | [X1,X2] = meshgrid(linspace(this.DOM(1),this.DOM(2),30)); 479 | [Z,inputNames] = mapPoints([X1(:) X2(:)]); 480 | numInputs = numel(inputNames); 481 | Z = reshape(Z, [30 30 numInputs]); 482 | imgs = cell(1, numInputs); 483 | for i=1:numInputs 484 | imgs{i} = genThumbnail(Z(:,:,i), this.CMAP, false); 485 | end 486 | 487 | uicontrol('Parent',hParent, 'Style','text', ... 488 | 'String',{'Which properties','do you want to','feed in?'}, ... 489 | 'ForegroundColor',0.5*[1 1 1], 'HorizontalAlignment','left', ... 490 | 'Units','pixels', 'Position',[5 445 100 45]); 491 | 492 | % input labels 493 | pos = cumsum([380 -45 -45 -45 -45 -45 -45]); 494 | for i=1:numInputs 495 | uicontrol('Parent',hParent, 'Style','text', ... 496 | 'String',inputNames{i}, ... 497 | 'HorizontalAlignment','right', 'FontWeight','bold', ... 
498 | 'Units','pixels', 'Position',[5 pos(i) 50 20]); 499 | end 500 | 501 | % input toggle-buttons 502 | pos = cumsum([370 -45 -45 -45 -45 -45 -45]); 503 | hTBtnInput = gobj(1, numInputs); 504 | for i=1:numInputs 505 | hTBtnInput(i) = uicontrol('Parent',hParent, ... 506 | 'Style','togglebutton', 'CData',imgs{i}, 'Value',0, ... 507 | 'Units','pixels', 'Position',[60 pos(i) 40 40]); 508 | end 509 | set(hTBtnInput(1:2), 'Value',1); 510 | 511 | % store handles 512 | this.handles.hTBtnInput = hTBtnInput; 513 | end 514 | 515 | function createGUI_PanelLayers(this, hParent) 516 | %CREATEGUI_PANELLAYERS Create the layers panel UI 517 | 518 | % layers components 519 | hBtnLayerAdd = uicontrol('Parent',hParent, ... 520 | 'Style','pushbutton', 'String','+', ... 521 | 'Units','pixels', 'Position',[130 465 20 20]); 522 | hBtnLayerDel = uicontrol('Parent',hParent, ... 523 | 'Style','pushbutton', 'String','-', ... 524 | 'Units','pixels', 'Position',[150 465 20 20]); 525 | hTxtLayerNum = uicontrol('Parent',hParent, 'Style','text', ... 526 | 'String',pluralize(this.MAX_LAYERS, 'hidden layer'), ... 527 | 'HorizontalAlignment','left', 'FontSize',12, ... 528 | 'Units','pixels', 'Position',[175 465 120 20]); 529 | 530 | % neurons components per layer 531 | hBtnNeuronAdd = gobj(1, this.MAX_LAYERS); 532 | hBtnNeuronDel = gobj(1, this.MAX_LAYERS); 533 | hTxtNeuronNum = gobj(1, this.MAX_LAYERS); 534 | hBtnNeuron = gobj(this.MAX_NEURONS, this.MAX_LAYERS); 535 | img = genThumbnail(zeros(30), this.CMAP, false); 536 | pos = cumsum([0 75 75 75 75]); 537 | for k=1:this.MAX_LAYERS 538 | hBtnNeuronAdd(k) = uicontrol('Parent',hParent, ... 539 | 'Style','pushbutton', 'String','+', 'UserData',k, ... 540 | 'Units','pixels', 'Position',[30+pos(k) 435 20 20]); 541 | hBtnNeuronDel(k) = uicontrol('Parent',hParent, ... 542 | 'Style','pushbutton', 'String','-', 'UserData',k, ... 543 | 'Units','pixels', 'Position',[50+pos(k) 435 20 20]); 544 | hTxtNeuronNum(k) = uicontrol('Parent',hParent, ... 
545 | 'Style','text', 'HorizontalAlignment','left', ... 546 | 'String',pluralize(this.MAX_NEURONS, 'neuron'), ... 547 | 'Units','pixels', 'Position',[25+pos(k) 415 60 20]); 548 | for n=1:this.MAX_NEURONS 549 | hBtnNeuron(n,k) = uicontrol('Parent',hParent, ... 550 | 'Style','pushbutton', 'Enable','inactive', ... 551 | 'CData',img, 'Value',0, 'Units','pixels', ... 552 | 'Position',[30+pos(k) 370-(n-1)*45 40 40]); 553 | end 554 | end 555 | 556 | % store handles 557 | this.handles.hBtnLayerAdd = hBtnLayerAdd; 558 | this.handles.hBtnLayerDel = hBtnLayerDel; 559 | this.handles.hTxtLayerNum = hTxtLayerNum; 560 | this.handles.hBtnNeuronAdd = hBtnNeuronAdd; 561 | this.handles.hBtnNeuronDel = hBtnNeuronDel; 562 | this.handles.hTxtNeuronNum = hTxtNeuronNum; 563 | this.handles.hBtnNeuron = hBtnNeuron; 564 | end 565 | 566 | function createGUI_PanelOutput(this, hParent) 567 | %CREATEGUI_PANELOUTPUT Create the output panel UI 568 | 569 | % properties 570 | props = {'XColor',0.5*[1 1 1], 'YColor',0.5*[1 1 1], ... 571 | 'FontSize',8, 'LineWidth',0.5, 'Box','off'}; 572 | [X1,X2] = meshgrid(linspace(this.DOM(1),this.DOM(2),250)); 573 | clr = brighten(this.CMAP([1 end],:), -0.4); 574 | 575 | % axes loss (lines and legend) 576 | hC = uicontainer(hParent, ... 577 | 'Units','pixels', 'Position',[10 390 300 100]); 578 | hAxLoss = axes('Parent',hC, props{:}, ... 579 | 'ColorOrder',0.44*[0 0 0; 1 1 1], ... 580 | 'Visible','off', 'XTick',[], 'YTick',[], ... 581 | 'Units','pixels', 'Position',[25 5 250 70]); 582 | hLineLoss = line(NaN(1,2), NaN, 'LineWidth',1.5, ... 583 | 'Parent',hAxLoss); 584 | hLgndLoss = legend(hLineLoss, ... 585 | 'String',{'Test loss: 0.000', 'Training loss: 0.000'}, ... 586 | 'FontSize',8, 'Interpreter','none', 'ButtonDownFcn','', ... 587 | 'Location','NorthOutside', 'Orientation','Horizontal', ... 
588 | 'Units','pixels', 'Position',[25 80 250 20]); 589 | legend(hAxLoss, 'boxoff'); 590 | 591 | % axes out (image, colorbar, and scatters) 592 | hC = uicontainer(hParent, ... 593 | 'Units','pixels', 'Position',[10 60 300 325]); 594 | hAxOut = axes('Parent',hC, props{:}, ... 595 | 'XLim',this.DOM, 'YLim',this.DOM, 'CLim',[-1 1], ... 596 | 'XTick',this.DOM(1):1:this.DOM(2), ... 597 | 'YTick',this.DOM(1):1:this.DOM(2), ... 598 | 'TickDir','out', 'YDir','normal', ... 599 | 'XAxisLocation','top', 'YAxisLocation','right', ... 600 | 'Units','pixels', 'Position',[25 50 250 250]); 601 | hImgOut = image('Parent',hAxOut, ... 602 | 'XData',X1(1,:), 'YData',X2(:,1), ... 603 | 'CData',zeros(size(X1)), 'CDataMapping','scaled'); 604 | hLineTrain = gobj(1,2); 605 | hLineTest = gobj(1,2); 606 | for i=1:2 607 | hLineTrain(i) = line('Parent',hAxOut, ... 608 | 'XData',NaN, 'YData',NaN, 'LineStyle','none', ... 609 | 'Marker','o', 'MarkerSize',6, 'LineWidth',0.5, ... 610 | 'MarkerFaceColor',clr(i,:), 'MarkerEdgeColor','w'); 611 | end 612 | for i=1:2 613 | hLineTest(i) = line('Parent',hAxOut, 'Visible','off', ... 614 | 'XData',NaN, 'YData',NaN, 'LineStyle','none', ... 615 | 'Marker','o', 'MarkerSize',6, 'LineWidth',0.5, ... 616 | 'MarkerFaceColor',clr(i,:), 'MarkerEdgeColor','k'); 617 | end 618 | if isHG1() 619 | orientH = {}; 620 | else 621 | orientH = {'Orientation','horizontal'}; 622 | end 623 | hCBar = colorbar('Peer',hAxOut, props{:}, ... 624 | orientH{:}, 'Location','SouthOutside', ... 625 | 'XTick',-1:0.5:1, 'Color',0.5*[1 1 1], ... 626 | 'Units','pixels', 'Position',[25 25 250 15]); 627 | set(hCBar, 'TickDir','out'); 628 | 629 | % checkboxes 630 | hCBoxShowTest = uicontrol('Parent',hParent, ... 631 | 'Style','checkbox', 'String','Show test data', ... 632 | 'Value',0, 'Units','pixels', 'Position',[10 30 120 20]); 633 | hCBoxDiscretize = uicontrol('Parent',hParent, ... 634 | 'Style','checkbox', 'String','Discretize output', ... 
635 | 'Value',0, 'Units','pixels', 'Position',[10 10 120 20]); 636 | 637 | % store handles 638 | this.handles.hLineLoss = hLineLoss; 639 | this.handles.hLgndLoss = hLgndLoss; 640 | this.handles.hImgOut = hImgOut; 641 | this.handles.hLineTrain = hLineTrain; 642 | this.handles.hLineTest = hLineTest; 643 | this.handles.hCBoxShowTest = hCBoxShowTest; 644 | this.handles.hCBoxDiscretize = hCBoxDiscretize; 645 | end 646 | 647 | function registerCallbacks(this) 648 | %REGISTERCALLBACKS Setup event handlers for UI components 649 | 650 | set(this.handles.hBtnRun, 'Callback',@this.onRunPause); 651 | set(this.handles.hBtnStep, 'Callback',@this.onStep); 652 | set(this.handles.hBtnReset, 'Callback',@this.onReset); 653 | 654 | set([this.handles.hPopLearnRate, this.handles.hPopActFunc, ... 655 | this.handles.hPopRegType, this.handles.hPopRegRate, ... 656 | this.handles.hPopProblem, this.handles.hSlidBatch], ... 657 | 'Callback',@this.onParamsChange); 658 | 659 | set([this.handles.hPopData, this.handles.hSlidRatio, ... 660 | this.handles.hSlidNoise, this.handles.hBtnGen], ... 661 | 'Callback',@this.onDataChange); 662 | set(this.handles.hTBtnInput, 'Callback',@this.onInputChange); 663 | 664 | set([this.handles.hBtnLayerAdd, this.handles.hBtnLayerDel], ... 665 | 'Callback',@this.onLayerAddRemove); 666 | set([this.handles.hBtnNeuronAdd, this.handles.hBtnNeuronDel], ... 
667 | 'Callback',@this.onNeuronAddRemove); 668 | 669 | set(this.handles.hCBoxShowTest, 'Callback',@this.onShowTest); 670 | set(this.handles.hCBoxDiscretize, 'Callback',@this.onDiscretize); 671 | end 672 | end 673 | 674 | %% UI Callbacks 675 | methods (Access = private) 676 | function onRunPause(this, ~, ~) 677 | %ONRUNPAUSE Run/Pause buttons event handler 678 | 679 | % toggle run/pause 680 | this.isRunning = ~this.isRunning; 681 | if this.isRunning 682 | set(this.handles.hBtnRun, 'String','Pause'); 683 | else 684 | set(this.handles.hBtnRun, 'String','Run'); 685 | end 686 | 687 | % run in a loop 688 | while this.isRunning && ishghandle(this.handles.hFig) 689 | step(this); 690 | drawnow(); 691 | end 692 | end 693 | 694 | function onStep(this, ~, ~) 695 | %ONSTEP Step button event handler 696 | 697 | % state paused 698 | this.isRunning = false; 699 | set(this.handles.hBtnRun, 'String','Run'); 700 | 701 | % run one pass 702 | step(this); 703 | drawnow(); 704 | end 705 | 706 | function onReset(this, ~, ~) 707 | %ONRESET Reset button event handler 708 | 709 | reset(this); 710 | end 711 | 712 | function onParamsChange(this, source, ~) 713 | %ONPARAMSCHANGE Event handler for change in network options 714 | 715 | % update net params and UI as needed 716 | switch source 717 | case this.handles.hPopLearnRate 718 | vals = cellstr(get(source, 'String')); 719 | idx = get(source, 'Value'); 720 | this.net.LearningRate = str2double(vals{idx}); 721 | case this.handles.hPopRegRate 722 | vals = cellstr(get(source, 'String')); 723 | idx = get(source, 'Value'); 724 | this.net.RegularizationRate = str2double(vals{idx}); 725 | case this.handles.hSlidBatch 726 | val = round(get(source, 'Value')); 727 | set(this.handles.hTxtBatch, 'String', ... 
728 | sprintf('Batch Size: %2d',val)); 729 | this.net.BatchSize = val; 730 | otherwise 731 | % recreate network 732 | buildNet(this); 733 | reset(this); 734 | end 735 | end 736 | 737 | function onDataChange(this, source, ~) 738 | %ONDATACHANGE Event handler for change in data options 739 | 740 | % update UI labels as needed 741 | val = round(get(source, 'Value')); 742 | switch source 743 | case this.handles.hSlidRatio 744 | set(this.handles.hTxtRatio, 'String', ... 745 | {'Ratio of training to',... 746 | sprintf('test data: %2d%%',val)}); 747 | case this.handles.hSlidNoise 748 | set(this.handles.hTxtNoise, 'String', ... 749 | sprintf('Noise: %2d',val)); 750 | end 751 | 752 | % generate new data 753 | genData(this); 754 | updatePlotScatter(this); 755 | 756 | % recreate network 757 | buildNet(this); 758 | reset(this); 759 | end 760 | 761 | function onInputChange(this, source, ~) 762 | %ONINPUTCHANGE Input toggle buttons event handler 763 | 764 | % make sure we have at least one data feature selected 765 | inputsMask = getInputsMask(this); 766 | if ~any(inputsMask) 767 | set(source, 'Value',1); % revert last change 768 | return; 769 | end 770 | 771 | % apply and store new data features 772 | this.data.inputs = mapPoints(this.data.points, inputsMask); 773 | 774 | % recreate network 775 | buildNet(this); 776 | reset(this); 777 | end 778 | 779 | function onLayerAddRemove(this, source, ~) 780 | %ONLAYERADDREMOVE Layers add/remove buttons event handler 781 | 782 | % update hidden layers UI 783 | hiddenSizes = getHiddenSizes(this); 784 | switch get(source, 'String') 785 | case '+' 786 | % add layer 787 | if numel(hiddenSizes) >= this.MAX_LAYERS 788 | return; 789 | end 790 | hiddenSizes(end+1) = 2; 791 | case '-' 792 | % remove layer 793 | if numel(hiddenSizes) <= 1 794 | return; 795 | end 796 | hiddenSizes(end) = []; 797 | end 798 | updateHiddenLayers(this, hiddenSizes); 799 | 800 | % recreate network 801 | buildNet(this); 802 | reset(this); 803 | end 804 | 805 | function 
onNeuronAddRemove(this, source, ~) 806 | %ONNEURONADDREMOVE Neurons add/remove buttons event handler 807 | 808 | % update hidden layers UI 809 | idx = get(source, 'UserData'); % hidden layer index 810 | hiddenSizes = getHiddenSizes(this); 811 | switch get(source, 'String') 812 | case '+' 813 | % add neuron to specified layer 814 | if hiddenSizes(idx) >= this.MAX_NEURONS 815 | return; 816 | end 817 | hiddenSizes(idx) = hiddenSizes(idx) + 1; 818 | case '-' 819 | % remove neuron from specified layer 820 | if hiddenSizes(idx) <= 1 821 | return; 822 | end 823 | hiddenSizes(idx) = hiddenSizes(idx) - 1; 824 | end 825 | updateHiddenLayers(this, hiddenSizes); 826 | 827 | % recreate network 828 | buildNet(this); 829 | reset(this); 830 | end 831 | 832 | function onShowTest(this, ~, ~) 833 | %ONSHOWTEST Test data checkbox event handler 834 | 835 | % toggle test points visibility 836 | if get(this.handles.hCBoxShowTest, 'Value') 837 | val = 'on'; 838 | else 839 | val = 'off'; 840 | end 841 | set(this.handles.hLineTest, 'Visible',val); 842 | end 843 | 844 | function onDiscretize(this, ~, ~) 845 | %ONDISCRETIZE Discretize checkbox event handler 846 | 847 | % if already running, no need to update plots as changes 848 | % will eventually get picked up on next iteration 849 | if ~this.isRunning 850 | updatePlotHeatmaps(this); 851 | end 852 | end 853 | end 854 | 855 | end 856 | 857 | %% Helper Functions 858 | 859 | function b = isHG1() 860 | %ISHG1 Checks if running on HG1 or HG2 graphics 861 | b = verLessThan('matlab','8.4'); 862 | end 863 | 864 | function H = gobj(varargin) 865 | %GOBJ Create an array to store graphic handles 866 | 867 | try 868 | H = gobjects(varargin{:}); 869 | catch 870 | H = zeros(varargin{:}); 871 | end 872 | end 873 | 874 | function x = randU(a,b,varargin) 875 | %RANDU Uniform random numbers 876 | 877 | try 878 | % Statistics Toolbox 879 | x = unifrnd(a, b, varargin{:}); 880 | catch 881 | x = rand(varargin{:}) * (b-a) + a; 882 | end 883 | end 884 | 885 | 
function x = randMVN(mu, S, N) 886 | %RANDMVN Random numbers from multivariate normal distribution 887 | 888 | try 889 | % Statistics Toolbox 890 | x = mvnrnd(mu, S, N); 891 | catch 892 | x = bsxfun(@plus, randn(N,numel(mu))*chol(S), mu); % cholcov 893 | end 894 | end 895 | 896 | function cmap = getColormap() 897 | %GETCOLORMAP Return an orange-to-blue polarized colormap 898 | 899 | % interpolate between colors {'f59322', 'e8eaeb', '0877bd'} 900 | cmap = [245 147 34; 232 234 235; 8 119 189] ./ 255; 901 | cmap = interp1([-1 0 1], cmap, linspace(-1,1,256)); 902 | 903 | % fake transparency 904 | a = 0.75; % 160/255 905 | cmap = a*cmap + (1-a)*1.0; 906 | end 907 | 908 | function Z = input_X1(X1,~) 909 | Z = X1; 910 | end 911 | 912 | function Z = input_X2(~,X2) 913 | Z = X2; 914 | end 915 | 916 | function Z = input_X12(X1,~) 917 | Z = X1.^2; 918 | end 919 | 920 | function Z = input_X22(~,X2) 921 | Z = X2.^2; 922 | end 923 | 924 | function Z = input_X1X2(X1,X2) 925 | Z = X1.*X2; 926 | end 927 | 928 | function Z = input_sinX1(X1,~) 929 | Z = sin(X1); 930 | end 931 | 932 | function Z = input_sinX2(~,X2) 933 | Z = sin(X2); 934 | end 935 | 936 | function [inputs,names] = mapPoints(points, inputsIdx, scale) 937 | %MAPPOINTS Map 2D points to inputs using specified features 938 | 939 | % all available feature functions and their labels 940 | funcs = {@input_X1, @input_X2, @input_X12, @input_X22, @input_X1X2, ... 
941 | @input_sinX1, @input_sinX2}; 942 | names = {'X1', 'X2', 'X1^2', 'X2^2', 'X1*X2', 'sin(X1)', 'sin(X2)'}; 943 | 944 | % default values 945 | if nargin < 3, scale = false; end 946 | if nargin < 2, inputsIdx = true(size(funcs)); end 947 | 948 | % which functions to apply 949 | num = nnz(inputsIdx); 950 | funcs = funcs(inputsIdx); 951 | names = names(inputsIdx); 952 | 953 | % map 2D points to inputs 954 | inputs = zeros(size(points,1), num); 955 | for i=1:num 956 | inputs(:,i) = feval(funcs{i}, points(:,1), points(:,2)); 957 | end 958 | 959 | % optional feature scaling 960 | if scale 961 | s = 6; 962 | scales = [1 1 1/s 1/s 1/s s s]; 963 | inputs = bsxfun(@times, inputs, scales(inputsIdx)); 964 | end 965 | end 966 | 967 | function [data,labels] = genDataCircle(N, noise) 968 | %GENDATACIRCLE Generate 2D points for binary classification (two concentric circles) 969 | 970 | radius = 5; 971 | data = cell(1,2); 972 | labels = cell(1,2); 973 | for i=1:2 974 | if i==1 975 | r = randU(0, radius*0.5, [N/2 1]); % radii inside 976 | else 977 | r = randU(radius*0.7, radius, [N/2 1]); % radii outside 978 | end 979 | t = rand(N/2,1) * 2*pi; % theta angles 980 | xy = bsxfun(@times, [cos(t) sin(t)], r); % points 981 | data{i} = xy; 982 | nz = randU(-radius, radius, size(xy)) * noise; % add noise 983 | % labels: positive/negative points inside/outside the circle 984 | labels{i} = (hypot(xy(:,1)+nz(:,1), xy(:,2)+nz(:,2)) <= radius*0.5); 985 | end 986 | data = vertcat(data{:}); % 2D points in [-5,5]x[-5,5] 987 | labels = vertcat(labels{:})*2 - 1; % labels: {0,1} -> {-1,1} 988 | end 989 | 990 | function [data,labels] = genDataXOR(N, noise) 991 | %GENDATAXOR Generate 2D points for binary classification (Exclusive-OR shape) 992 | 993 | r = 5; 994 | data = rand(N,2)*2*r - r; % 2D points in [-5,5]x[-5,5] 995 | data = data + sign(data)*0.3; % padding away from origin 996 | nz = randU(-r, r, size(data)) * noise; % add noise 997 | labels = (prod(data+nz,2) >= 0)*2 - 1; % labels: {0,1} -> 
{-1,1} 998 | end 999 | 1000 | function [data,labels] = genDataGaussian(N, noise) 1001 | %GENDATAGAUSSIAN Generate 2D points for binary classification (two gaussian blobs) 1002 | 1003 | % means and covariance 1004 | mn = [2 2; -2 -2]; 1005 | sigma = eye(2) .* (noise * 7 + 0.5); % [0.0,0.5] -> [0.5,4.0] 1006 | % generate positive/negative samples from two gaussians 1007 | data = [randMVN(mn(1,:), sigma, N/2); ... 1008 | randMVN(mn(2,:), sigma, N/2)]; 1009 | labels = reshape(repmat([1 -1], N/2, 1), [], 1); 1010 | end 1011 | 1012 | function [data,labels] = genDataSpiral(N, noise) 1013 | %GENDATASPIRAL Generate 2D points for binary classification (spiral shape) 1014 | 1015 | radius = 5; 1016 | v = linspace(0, 1, N/2).'; 1017 | r = v * radius; % radii 1018 | t = v * 1.75 * 2*pi; % angles 1019 | data = cell(1,2); 1020 | for k=1:2 1021 | % phase shift 1022 | if k == 1 1023 | deltaT = 0; % positive examples 1024 | else 1025 | deltaT = pi; % negative examples 1026 | end 1027 | xy = bsxfun(@times, r, [sin(t + deltaT) cos(t + deltaT)]); % points 1028 | nz = randU(-1, 1, size(xy)) * noise; % add noise 1029 | data{k} = xy + nz; 1030 | end 1031 | data = vertcat(data{:}); % 2D points in [-5,5]x[-5,5] 1032 | labels = reshape(repmat([1 -1], N/2, 1), [], 1); 1033 | end 1034 | 1035 | function IDX = splitData(labels, trainRatio) 1036 | %SPLITDATA Partition data 1037 | 1038 | % random partition 1039 | % (without stratification, assumes two classes with equal proportions) 1040 | N = numel(labels); 1041 | %ind = randperm(N, round(trainRatio*N)); 1042 | ind = randperm(N); 1043 | ind = ind(1:round(trainRatio*N)); 1044 | 1045 | % train logical indices 1046 | IDX = false(N,1); 1047 | IDX(ind) = true; 1048 | end 1049 | 1050 | function loss = mseLoss(Yhat, Y) 1051 | %MSELOSS Mean-squared error loss function 1052 | loss = 0.5 * mean((Yhat - Y).^2); 1053 | end 1054 | 1055 | function [stateAdd, stateDel] = getButtonState(num, mx) 1056 | %GETBUTTONSTATE Determine add/remove button on-off state 
based on current count 1057 | 1058 | % add button state 1059 | if num >= mx 1060 | stateAdd = 'off'; 1061 | else 1062 | stateAdd = 'on'; 1063 | end 1064 | 1065 | % remove button state 1066 | if num > 1 1067 | stateDel = 'on'; 1068 | else 1069 | stateDel = 'off'; 1070 | end 1071 | end 1072 | 1073 | function img = genThumbnail(Z, cmap, discretize) 1074 | %GENTHUMBNAIL Render matrix into a truecolor heatmap image 1075 | 1076 | if nargin < 3, discretize = false; end 1077 | if nargin < 2, cmap = getColormap(); end 1078 | 1079 | % discretize: [-1,1] range to {-1,+1} values 1080 | if discretize 1081 | Z = sign(Z); 1082 | end 1083 | 1084 | % scale to 8-bit indexed image: [-1,1] -> [0,1] -> [0,255] 1085 | try 1086 | img = im2uint8(mat2gray(Z, [-1 1])); 1087 | catch 1088 | img = uint8(255 * (Z+1)/2); 1089 | end 1090 | 1091 | % convert to truecolor according to colormap 1092 | img = ind2rgb(img, cmap); 1093 | end 1094 | 1095 | function str = pluralize(num, word) 1096 | %PLURALIZE Return formatted string(s) in singular/plural form according to count 1097 | 1098 | str = cell(size(num)); 1099 | for i=1:numel(num) 1100 | if num(i) > 1 1101 | suffix = 's'; 1102 | else 1103 | suffix = ''; 1104 | end 1105 | str{i} = sprintf('%d %s%c', num(i), word, suffix); 1106 | end 1107 | 1108 | if isscalar(num) 1109 | str = str{1}; 1110 | end 1111 | end 1112 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NeuralNetPlayground 2 | 3 | [![MATLAB FEX](https://img.shields.io/badge/MATLAB%20FEX-57610-green.svg)][1] 4 | [![Minimum Version](https://img.shields.io/badge/Requires-R2009b-blue.svg)][2] 5 | [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE.md) 6 | 7 | A MATLAB implementation of the [TensorFlow Neural Networks Playground][3]. 
8 | 9 | 10 | ## Description 11 | 12 | Inspired by the [TensorFlow Neural Networks Playground][3] interface readily 13 | available online, this is a MATLAB implementation of the same Neural Network 14 | interface for using Artificial Neural Networks for regression and 15 | classification of highly non-linear data. 16 | 17 | The interface uses the *HG1* graphics system in order to be compatible with 18 | older versions of MATLAB. A secondary purpose of this project is to write a 19 | vectorized implementation of training Artificial Neural Networks with 20 | Stochastic Gradient Descent as a means of education and to demonstrate the 21 | power of MATLAB and matrices. 22 | 23 | Given randomly generated training and test data that fall into two classes 24 | conforming to certain shapes or specifications, and given the configuration 25 | of a neural network, the goal of this framework is to perform either 26 | regression or classification of this data and interactively show the 27 | results to the user: a classification or regression map of the 28 | data, as well as numerical performance measures such as the training and test 29 | loss, plotted on a performance curve over each iteration. 30 | The architecture of the neural network is highly configurable, so the results 31 | of each change in the architecture can be seen immediately. 32 | 33 | There are two files that accompany this repo: 34 | 35 | - `NeuralNetApp.m`: The GUI that recreates the interface of the TensorFlow 36 | Neural Networks Playground entirely with MATLAB GUI elements 37 | and widgets. 38 | - `NeuralNet2.m`: The class that performs the Neural Network training via 39 | Stochastic Gradient Descent. This is used by the `NeuralNetApp.m` app. 40 | 41 | 42 | ## Compatible Versions 43 | 44 | Debugged and tested for MATLAB R2009b or newer. 
45 | 46 | This code runs only on R2009b and newer because it uses the 47 | `~` syntax for discarding output variables from functions. If you wish 48 | to use this code on older versions (compatibility is not guaranteed), you 49 | will need to replace every discarded output variable with a dummy 50 | variable, at the cost of a variety of `mlint` warnings. We have not 51 | made this change ourselves because there is very little to gain from supporting 52 | even older versions, so if you want to run this code on them, you will 53 | have to make it yourself. 54 | 55 | 56 | ## Neural Network App 57 | 58 | Ensure that both files `NeuralNetApp.m` and `NeuralNet2.m` are in the same 59 | directory. In the MATLAB Command Window, simply run the `NeuralNetApp.m` file 60 | within this directory. Assuming you are working in the directory where you 61 | stored these files, type the following and press ENTER: 62 | 63 | ``` matlab 64 | >> NeuralNetApp 65 | ``` 66 | 67 | If you want to be explicit, you can use `run` and provide the path to where 68 | this file is located on your system: 69 | 70 | ``` matlab 71 | >> run(fullfile('path', 'to', 'the', 'NeuralNetApp.m')); 72 | ``` 73 | 74 | If all goes well, you should be presented with a GUI. 75 | See [here][4] for the output from a sample run. 76 | 77 | ![screenshot][9] 78 | 79 | 80 | ## Neural Network Class 81 | 82 | The main engine behind the training algorithm is the `NeuralNet2.m` 83 | file. This is a custom, well-documented class that 84 | a MATLAB user can reuse in their own code. 85 | You can type `help NeuralNet2` in the command window, from the directory where this file is 86 | located on your system, for a comprehensive overview of how to use this class. 87 | 88 | Check out [this page][5] for some tips on training the neural network. 
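As a rough illustration of the class in use, the sketch below condenses the steps from `demos/demo_nn.m`. It assumes `NeuralNet2.m` is on the MATLAB path; the property and method names are the ones that demo uses, while the dataset size, layer sizes, hyperparameter values, and iteration count are illustrative choices rather than recommended settings.

``` matlab
% Minimal sketch: train NeuralNet2 on an XOR-shaped dataset.
% Assumption: NeuralNet2.m is on the MATLAB path.
assert(exist('NeuralNet2', 'file') == 2, 'NeuralNet2.m must be on the path');

% XOR data: 2D points with labels in {-1,+1}
m = 200;
X = rand(m,2)*2 - 1;
X = X + sign(X)*0.1;             % padding away from the axes
Y = (prod(X,2) >= 0)*2 - 1;

% 2 inputs, two hidden layers (4 and 2 neurons), 1 output
net = NeuralNet2([2 4 2 1]);
net.LearningRate = 0.1;
net.ActivationFunction = 'Tanh';
net.BatchSize = 10;

% train, then threshold the network outputs at zero
N = 500;                         % number of iterations (illustrative)
costVal = net.train(X, Y, N);
pred = sign(net.sim(X));
fprintf('Final cost: %f\n', costVal(end));
fprintf('Train accuracy: %.2f%%\n', 100*sum(pred(:) == Y) / m);
```

This reports accuracy on the training points only. For a held-out estimate, split the data first, as `demos/demo_nn.m` does, and call `net.sim` on the test half.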
89 | 90 | 91 | ## Resources 92 | 93 | - [Docs][6]: Documentation and description of the various parts of the app. 94 | - [Demos][7]: Examples and demos showing how to use the neural network class. 95 | - [Chat][8]: Stack Overflow MATLAB Chat Room 96 | 97 | 98 | [1]: http://www.mathworks.com/matlabcentral/fileexchange/57610-a-matlab-recreation-of-the-tensorflow-neural-network-playground 99 | [2]: http://www.mathworks.com/products/matlab/ 100 | [3]: http://playground.tensorflow.org/ 101 | [4]: https://github.com/StackOverflowMATLABchat/NeuralNetPlayground/wiki/Sample-Run 102 | [5]: https://github.com/StackOverflowMATLABchat/NeuralNetPlayground/wiki/Training-Tips 103 | [6]: https://github.com/StackOverflowMATLABchat/NeuralNetPlayground/wiki/Overview-of-the-GUI 104 | [7]: https://github.com/StackOverflowMATLABchat/NeuralNetPlayground/wiki/Demos 105 | [8]: http://chat.stackoverflow.com/rooms/81987/matlab-and-octave 106 | [9]: http://i.stack.imgur.com/YQdrm.png 107 | -------------------------------------------------------------------------------- /demos/demo_nn.m: -------------------------------------------------------------------------------- 1 | %% Neural Network Demo 2 | % Demonstrates how to use the |NeuralNet2| class for binary classification. 
3 | 4 | % Ensures we add path to NeuralNet2 class to allow this demo to run 5 | addpath('../'); 6 | %% Data 7 | % Create XOR dataset 8 | 9 | % 2D points in [-1.1,1.1] range with corresponding {-1,+1} labels 10 | m = 400; 11 | X = rand(m,2)*2 - 1; 12 | X = X + sign(X)*0.1; 13 | Y = (prod(X,2) >= 0)*2 - 1; 14 | whos X Y 15 | 16 | % shuffle and split into training and test sets 17 | ratio = 0.5; 18 | mTrain = floor(ratio*m); 19 | mTest = m - mTrain; 20 | indTrain = randperm(m); 21 | Xtrain = X(indTrain(1:mTrain),:); 22 | Ytrain = Y(indTrain(1:mTrain)); 23 | Xtest = X(indTrain(mTrain+1:end),:); 24 | Ytest = Y(indTrain(mTrain+1:end)); 25 | 26 | %% Network 27 | % Create the neural network 28 | 29 | net = NeuralNet2([size(X,2) 4 2 size(Y,2)]); 30 | net.LearningRate = 0.1; 31 | net.RegularizationType = 'L2'; 32 | net.RegularizationRate = 0.01; 33 | net.ActivationFunction = 'Tanh'; 34 | net.BatchSize = 10; 35 | display(net) 36 | 37 | %% Training and Testing Network 38 | 39 | N = 5000; % number of iterations 40 | disp('Training...'); tic 41 | costVal = net.train(Xtrain, Ytrain, N); 42 | toc 43 | 44 | % compute predictions 45 | disp('Test...'); tic 46 | predictTrain = sign(net.sim(Xtrain)); 47 | predictTest = sign(net.sim(Xtest)); 48 | toc 49 | 50 | % classification accuracy 51 | fprintf('Final cost after training: %f\n', costVal(end)); 52 | fprintf('Train accuracy: %.2f%%\n', 100*sum(predictTrain == Ytrain) / mTrain); 53 | fprintf('Test accuracy: %.2f%%\n', 100*sum(predictTest == Ytest) / mTest); 54 | 55 | % plot cost function per epoch 56 | % show the cost for every 10 epochs 57 | figure(1) 58 | plot(1:10:N, costVal(1:10:end)); grid on; box on 59 | title('Cost Function'); xlabel('Epoch'); ylabel('Cost') 60 | 61 | %% Result 62 | 63 | % assign green to be the points with label -1 and red to be the points with 64 | % label +1 65 | clr = [0 0.741 0.447; 0.85 0.325 0.098]; 66 | % generate custom color map roughly based on the parula colour map 67 | % Varies from orange 
(negative values) through white (zero values) to purple 68 | % (positive values) 69 | % The color map is a 256 x 3 matrix 70 | cmap = interp1([-1 0 1], ... 71 | [0.929 0.694 0.125; 1 1 1; 0.494 0.184 0.556], linspace(-1,1,256)); 72 | 73 | % classification grid over domain of data 74 | [X1,X2] = meshgrid(linspace(-1.2,1.2,100)); 75 | 76 | % Use resulting trained neural network to predict each point in the 77 | % classification grid 78 | out = reshape(net.sim([X1(:) X2(:)]), size(X1)); 79 | 80 | % Hard classification by the sign of the network output 81 | predictOut = sign(out); 82 | 83 | % plot predictions, with decision regions and data points overlaid 84 | 85 | % set up figure 86 | figure(2); set(gcf, 'Position',[200 200 560 550]) 87 | imagesc(X1(1,:), X2(:,2), out) 88 | set(gca, 'CLim',[-1 1], 'ALim',[-1 1]) 89 | colormap(cmap); colorbar 90 | hold on 91 | 92 | % Draw classification boundary (i.e. when the neural network output is 0) 93 | contour(X1, X2, out, [0 0], 'LineWidth',2, 'Color','k', ... 94 | 'DisplayName','boundaries') 95 | 96 | % For each label (-1,+1), select the training and test points with that true 97 | % label and draw circles in the corresponding label colour. 98 | % Test set points are darker in colour than the training 99 | % data. Each circle has a black edge surrounding it. 100 | K = [-1 1]; 101 | for ii=1:numel(K) 102 | indTrain = (Ytrain == K(ii)); 103 | indTest = (Ytest == K(ii)); 104 | line(Xtrain(indTrain,1), Xtrain(indTrain,2), 'LineStyle','none', ... 105 | 'Marker','o', 'MarkerSize',6, ... 106 | 'MarkerFaceColor',clr(ii,:), 'MarkerEdgeColor','k', ... 107 | 'DisplayName',sprintf('%+d train',K(ii))) 108 | line(Xtest(indTest,1), Xtest(indTest,2), 'LineStyle','none', ... 109 | 'Marker','o', 'MarkerSize',6, ... 110 | 'MarkerFaceColor',brighten(clr(ii,:),-0.5), 'MarkerEdgeColor','k', ... 
111 | 'DisplayName',sprintf('%+d test',K(ii))) 112 | end 113 | hold off; xlabel('X1'); ylabel('X2'); title('XOR dataset') 114 | legend('show', 'Orientation','Horizontal', 'Location','SouthOutside') 115 | --------------------------------------------------------------------------------