├── .gitignore ├── LICENSE ├── README.md └── src └── neural ├── Matrix.java └── Network.java /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.ear 17 | *.zip 18 | *.tar.gz 19 | *.rar 20 | 21 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 22 | hs_err_pid* 23 | 24 | # Custom 25 | build.xml 26 | build-impl.xml 27 | private.xml 28 | project.xml 29 | *.properties 30 | *.mf 31 | .netbeans* 32 | /prj/* 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Sebastian Gössl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Machine Learning 2 | # Deprecated 3 | **This repository is deprecated! Please use [NeuralNetwork](https://github.com/sebig3000/NeuralNetwork) instead!** 4 | 5 | A collection of Java packages for developing machine learning algorithms that is 6 | - easy to use -> great for small projects or just to learn how machine learning works 7 | - small and simple -> easy to understand and to modify 8 | - lightweight (mostly because I'm a student who just started to learn how to code Java and can't write anything more complex yet :P) 9 | 10 | ## Getting Started 11 | 12 | ### Prerequisites 13 | 14 | This project is written in pure vanilla Java, so nothing is needed other than the standard libraries. 15 | 16 | ### Installation 17 | 18 | Just add all packages with the source files in the [Source folder /src](src) to your project and you are ready to go! 19 | Every class has a main test method. After installation, just run any class to check that the installation was successful. 
20 | 21 | ## Code Example 22 | 23 | ### [Neural Network](src/neural) 24 | 25 | Initialize a new network with a given architecture (number of inputs, number of neurons in the hidden layers and each layer's activation function) 26 | (If you don't know what to choose, here is a rule of thumb for an average-looking network: 27 | - number of hidden layers = 2 28 | - number of neurons per layer: number of inputs (except the last layer, which is the output layer = as many neurons as outputs) 29 | - activation functions: none) 30 | 31 | ``` 32 | //New network 33 | final Network net = new Network( 34 | 2, //2 inputs 35 | new int[]{3, 1}, //2 layers with 3 & 1 neurons 36 | new Network.ActivationFunction[]{ 37 | Network.ActivationFunction.NONE, //both layers with ... 38 | Network.ActivationFunction.NONE}); //... no activation function 39 | ``` 40 | 41 | Then you can seed the weights in the network (= randomize them). 42 | 43 | ``` 44 | net.seedWeights(-1, 1); 45 | ``` 46 | 47 | Prepare your training data and put it into a [Matrix](src/neural/Matrix.java) 48 | 49 | ``` 50 | //Generate 10 training sets 51 | //Every row represents one training set (10 rows = 10 sets) 52 | //Every column gets fed into the same input/comes out of the same output 53 | //(first column gets into the first input) 54 | //(2 columns = 2 inputs / 1 column = 1 output) 55 | final Matrix trainInput = new Matrix(10, 2); 56 | final Matrix trainOutput = new Matrix(10, 1); 57 | //Fill the training sets 58 | //Inputs: two random numbers 59 | //Outputs: average of these two numbers 60 | final Random rand = new Random(); 61 | for(int set=0; set= getHeight() || column < 0 || column >= getWidth()) { 86 | throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); 87 | } 88 | 89 | 90 | matrix[row][column] = value; 91 | } 92 | 93 | /** 94 | * Returns the value of a specific element 95 | * @param row Row index of the element 96 | * @param column Column index of the element 97 | * @return The value of 
the element 98 | * @throws ArrayIndexOutOfBoundsException If the indices are smaller than 0 99 | * or bigger than the width/height -1 100 | */ 101 | public double get(int row, int column) { 102 | if(row < 0 || row >= getHeight() || column < 0 || column >= getWidth()) { 103 | throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); 104 | } 105 | 106 | 107 | return matrix[row][column]; 108 | } 109 | 110 | /** 111 | * Returns the height (number of rows) of the matrix 112 | * @return Height of the matrix 113 | */ 114 | public int getHeight() { 115 | return height; 116 | } 117 | 118 | /** 119 | * Returns the width (number of columns) of the matrix 120 | * @return Width of the matrix 121 | */ 122 | public int getWidth() { 123 | return width; 124 | } 125 | 126 | 127 | /** 128 | * Sets every element of the matrix to the given value 129 | * @param value Value to set every element to 130 | */ 131 | public void fill(double value) { 132 | for(int j=0; j function) { 306 | final Matrix result = new Matrix(getHeight(), getWidth()); 307 | 308 | for(int j=0; j function) { 328 | final Matrix result = new Matrix(getHeight(), getWidth()); 329 | 330 | for(int j=0; j= getWidth()) { 366 | throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); 367 | } 368 | 369 | return getRows(index, index + 1); 370 | } 371 | 372 | /** 373 | * Extracts multiple rows as a new Matrix 374 | * @param fromIndex Index of the first row 375 | * that should be extracted (inclusive) 376 | * @param toIndex Index of the last row that should be extracted (exclusive) 377 | * @return The rows as a new Matrix 378 | * @throws ArrayIndexOutOfBoundsException If an index does not point 379 | * to an existing row 380 | */ 381 | public Matrix getRows(int fromIndex, int toIndex) { 382 | if(fromIndex < 0 || fromIndex >= getHeight() 383 | || toIndex < 0 || toIndex > getHeight()) { 384 | throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); 385 | } 386 | if(fromIndex >= toIndex) { 387 | throw 
new IllegalArgumentException("Illegal index direction!"); 388 | } 389 | 390 | 391 | final Matrix result = new Matrix(toIndex - fromIndex, getWidth()); 392 | 393 | for(int j=0; j= getHeight()) { 445 | throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); 446 | } 447 | 448 | return removeRows(index, index + 1); 449 | } 450 | 451 | /** 452 | * Removes multiple rows of this matrix 453 | * @param fromIndex Index of the first row that should be removed (inclusive) 454 | * @param toIndex Index of the last row that should be removed (exclusive) 455 | * @return Resulting matrix 456 | * @throws ArrayIndexOutOfBoundsException If this matrix is to small to 457 | * remove the rows or an index does not point to an existing row 458 | */ 459 | public Matrix removeRows(int fromIndex, int toIndex) { 460 | if(getHeight() <= toIndex - fromIndex) { 461 | throw new ArrayIndexOutOfBoundsException("Matrix to small!"); 462 | } 463 | if(fromIndex < 0 || fromIndex >= getHeight() 464 | || toIndex < 0 || toIndex > getHeight()) { 465 | throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); 466 | } 467 | if(fromIndex >= toIndex) { 468 | throw new IllegalArgumentException("Illegal index direction!"); 469 | } 470 | 471 | 472 | final Matrix result = 473 | new Matrix(getHeight() - (toIndex - fromIndex), getWidth()); 474 | 475 | for(int j=0; j= getHeight()) { 498 | throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); 499 | } 500 | 501 | return getColumns(index, index + 1); 502 | } 503 | 504 | /** 505 | * Extracts multiple columns as a new Matrix 506 | * @param fromIndex Index of the first column 507 | * that should be extracted (inclusive) 508 | * @param toIndex Index of the last column 509 | * that should be extracted (exclusive) 510 | * @return The columns as a new Matrix 511 | * @throws ArrayIndexOutOfBoundsException If an index does not point 512 | * to an existing column 513 | */ 514 | public Matrix getColumns(int fromIndex, int toIndex) { 515 | if(fromIndex 
< 0 || fromIndex >= getWidth() 516 | || toIndex < 0 || toIndex > getWidth()) { 517 | throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); 518 | } 519 | if(fromIndex >= toIndex) { 520 | throw new IllegalArgumentException("Illegal index direction!"); 521 | } 522 | 523 | 524 | final Matrix result = new Matrix(getHeight(), toIndex - fromIndex); 525 | 526 | for(int j=0; j= getWidth()) { 575 | throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); 576 | } 577 | 578 | return removeColumns(index, index + 1); 579 | } 580 | 581 | /** 582 | * Removes multiple columns of this matrix 583 | * @param fromIndex Index of the first column 584 | * that should be removed (inclusive) 585 | * @param toIndex Index of the last column that should be removed (exclusive) 586 | * @return Resulting matrix 587 | * @throws ArrayIndexOutOfBoundsException If this matrix is to small to 588 | * remove the columns or an index does not point to an existing column 589 | */ 590 | public Matrix removeColumns(int fromIndex, int toIndex) { 591 | if(getWidth() <= toIndex - fromIndex) { 592 | throw new ArrayIndexOutOfBoundsException("Matrix to small!"); 593 | } 594 | if(fromIndex < 0 || fromIndex >= getWidth() 595 | || toIndex < 0 || toIndex > getWidth()) { 596 | throw new ArrayIndexOutOfBoundsException("Indices out of bounds!"); 597 | } 598 | if(fromIndex >= toIndex) { 599 | throw new IllegalArgumentException("Illegal index direction!"); 600 | } 601 | 602 | 603 | final Matrix result = 604 | new Matrix(getHeight(), getWidth() - (toIndex - fromIndex)); 605 | 606 | for(int j=0; j Math.sin(x)) + "\n\n"); 734 | 735 | 736 | System.out.println("Transpose:"); 737 | System.out.println(matrix1.transpose() + "\n\n"); 738 | 739 | 740 | System.out.println("Get row 1:"); 741 | System.out.println(matrix1.getRow(1) + "\n"); 742 | System.out.println("Get rows 1 & 2:"); 743 | System.out.println(matrix1.getRows(1, 3) + "\n"); 744 | System.out.println("Append rows:"); 745 | 
System.out.println(matrix1.appendRows(matrix2) + "\n"); 746 | System.out.println("Remove row 1:"); 747 | System.out.println(matrix1.removeRow(1) + "\n"); 748 | System.out.println("Remove rows 0 & 1:"); 749 | System.out.println(matrix1.removeRows(0, 2) + "\n\n"); 750 | 751 | System.out.println("Get column 1:"); 752 | System.out.println(matrix1.getColumn(1) + "\n"); 753 | System.out.println("Get columns 1 & 2:"); 754 | System.out.println(matrix1.getColumns(1, 3) + "\n"); 755 | System.out.println("Append columns:"); 756 | System.out.println(matrix1.appendColumns(matrix2) + "\n"); 757 | System.out.println("Remove column 1:"); 758 | System.out.println(matrix1.removeColumn(1) + "\n"); 759 | System.out.println("Remove columns 0 & 1:"); 760 | System.out.println(matrix1.removeColumns(0, 2) + "\n"); 761 | 762 | 763 | System.out.println("Randomize within [-1, 1[:"); 764 | matrix1.rand(new Random(), -1, 1); 765 | System.out.println(matrix1 + "\n"); 766 | } 767 | } 768 | -------------------------------------------------------------------------------- /src/neural/Network.java: -------------------------------------------------------------------------------- 1 | package neural; 2 | 3 | import java.io.DataInputStream; 4 | import java.io.DataOutputStream; 5 | import java.io.IOException; 6 | import java.util.Arrays; 7 | import java.util.Random; 8 | import java.util.function.DoubleFunction; 9 | 10 | /** 11 | * Neural network 12 | * 13 | * @author Sebastian Gössl 14 | * @version 1.2 26.03.2018 15 | */ 16 | public class Network { 17 | 18 | /** 19 | * Activation functions 20 | */ 21 | public enum ActivationFunction { 22 | NONE, TANH, SIGMOID, RELU, SOFTPLUS, RELU_LEAKY; 23 | 24 | private static final double RELU_LEAKY_LEAKAGE = 0.01; 25 | 26 | private static final String[] name = { 27 | "None", 28 | "Hyperbolic tangent", 29 | "Sigmoid", 30 | "Rectified linear unit", 31 | "SoftPlus", 32 | "Leaky rectified linear unit" 33 | }; 34 | 35 | private static final DoubleFunction[] function = { 36 
| //None 37 | x -> x, 38 | //Tanh 39 | x -> Math.tanh(x), 40 | //Sigmoid 41 | x -> 1 / (1 + Math.exp(-x)), 42 | //ReLU 43 | x -> { 44 | if(x >= 0) { 45 | return x; 46 | } else { 47 | return 0.0; 48 | }}, 49 | //SoftPlus 50 | x -> Math.log(1 + Math.exp(x)), 51 | //Leaky ReLU 52 | x -> { 53 | if(x >= 0) { 54 | return x; 55 | } else { 56 | return RELU_LEAKY_LEAKAGE * x; 57 | }} 58 | }; 59 | 60 | private static final DoubleFunction[] prime = { 61 | //None 62 | x -> 1.0, 63 | //Tanh 64 | x -> 1 - Math.tanh(x) * Math.tanh(x), 65 | //Sigmoid 66 | x -> Math.exp(-x) / ((1 + Math.exp(-x)) * (1 + Math.exp(-x))), 67 | //ReLU 68 | x -> { 69 | if(x >= 0) { 70 | return 1.0; 71 | } else { 72 | return 0.0; 73 | }}, 74 | //Softplus 75 | x -> 1 / (1 + Math.exp(-x)), 76 | //Leaky ReLU 77 | x -> { 78 | if(x >= 0) { 79 | return 1.0; 80 | } else { 81 | return RELU_LEAKY_LEAKAGE; 82 | }} 83 | }; 84 | 85 | 86 | 87 | /** 88 | * Returns this activation function as a function 89 | * @return Function 90 | */ 91 | public DoubleFunction function() { 92 | return function[ordinal()]; 93 | } 94 | 95 | /** 96 | * Returns this activation function's derivative as a function 97 | * @return Function 98 | */ 99 | public DoubleFunction prime() { 100 | return prime[ordinal()]; 101 | } 102 | 103 | 104 | 105 | @Override 106 | public String toString() { 107 | return name[ordinal()]; 108 | } 109 | } 110 | 111 | 112 | 113 | /** 114 | * Inputs 115 | * 116 | * Weights[0] 117 | * 118 | * Layer[0] 119 | * ActivationZ[0] (Weighted sum) 120 | * ActivationA[0] (Activated sum) 121 | * 122 | * Weights[1] 123 | * 124 | * Layer[1] 125 | * ActivationZ[1] (Weighted sum) 126 | * ActivationA[1] (Activated sum) 127 | * 128 | * ... 
129 | */ 130 | 131 | /** Number of input neurons */ 132 | private final int numberOfInputs; 133 | /** Number of neurons in each layer */ 134 | private final int[] layerSizes; 135 | /** Each layers activation function */ 136 | private final ActivationFunction[] activationFunctions; 137 | 138 | /** Weights */ 139 | private final Matrix[] weights; 140 | /** Activities, needed for backpropagation */ 141 | private final Matrix[] activityA; 142 | private final Matrix[] activityZ; 143 | 144 | 145 | 146 | /** 147 | * Constructs a new copy of an existing network 148 | * @param net Network to copy 149 | */ 150 | public Network(Network net) { 151 | this(net.getNumberOfInputs(), 152 | net.copyLayerSizes(), 153 | net.copyActivationFunctions()); 154 | 155 | setWeights(net.copyWeights()); 156 | } 157 | 158 | /** 159 | * Constructs a new network 160 | * @param numberOfInputs Number of inputs 161 | * @param layerSizes Numbers of neurons in each hidden layer, 162 | * last layer is the output layer (number of outputs) 163 | * @param activationFunctions Activation functions for every layer 164 | * @throws IllegalArgumentException If the number of layers or 165 | * the number of neurons in a layer is smaller than 1 or 166 | * if the number of given activation functions 167 | * does not equal the number of layers 168 | */ 169 | public Network(int numberOfInputs, int[] layerSizes, 170 | ActivationFunction[] activationFunctions) { 171 | if(numberOfInputs < 1) { 172 | throw new IllegalArgumentException( 173 | "Number of input neurons less than 1!"); 174 | } 175 | if(layerSizes.length < 1) { 176 | throw new IllegalArgumentException("Number of layers less than 1!"); 177 | } 178 | if(activationFunctions.length != layerSizes.length) { 179 | throw new IllegalArgumentException( 180 | "Not as many activation functions as layers!"); 181 | } 182 | for(int layerSize : layerSizes) { 183 | if(layerSize < 1) { 184 | throw new IllegalArgumentException( 185 | "Number of neurons in layer less than 1!"); 
186 | } 187 | } 188 | 189 | 190 | //Dimensions 191 | this.numberOfInputs = numberOfInputs; 192 | this.layerSizes = Arrays.copyOf(layerSizes, layerSizes.length); 193 | 194 | //Activation functions 195 | this.activationFunctions = activationFunctions; 196 | 197 | //Weights 198 | weights = new Matrix[layerSizes.length]; 199 | weights[0] = new Matrix(numberOfInputs, layerSizes[0]); 200 | for(int i=1; i= layerSizes.length) { 284 | throw new ArrayStoreException("Index out of bounds!"); 285 | } 286 | 287 | 288 | return layerSizes[index]; 289 | } 290 | 291 | /** 292 | * Returns a copy of the numbers of neurons in every layer 293 | * @return Copy of numbers of neurons in every layer 294 | */ 295 | public int[] copyLayerSizes() { 296 | return Arrays.copyOf(layerSizes, layerSizes.length); 297 | } 298 | 299 | 300 | /** 301 | * Sets the activation function of the specified layer 302 | * @param index Index of the layer 303 | * @param function Activation function 304 | * @throws ArrayIndexOutOfBoundsException If the index does not point 305 | * to an existing layer 306 | */ 307 | public void setActivationFunction(int index, ActivationFunction function) { 308 | if(index < 0 || index >= activationFunctions.length) { 309 | throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); 310 | } 311 | 312 | 313 | activationFunctions[index] = function; 314 | } 315 | 316 | /** 317 | * Returns the activation function of the specific layer 318 | * @param index Index of the layer 319 | * @return Activation function of the layer 320 | * @throws ArrayIndexOutOfBoundsException If the index does not point 321 | * to an existing layer 322 | */ 323 | public ActivationFunction getActivationFunction(int index) { 324 | if(index < 0 || index >= activationFunctions.length) { 325 | throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); 326 | } 327 | 328 | 329 | return activationFunctions[index]; 330 | } 331 | 332 | /** 333 | * Returns the activation functions of every layer 334 | * 
@return Activation functions 335 | */ 336 | public ActivationFunction[] getActivationFunctions() { 337 | return activationFunctions; 338 | } 339 | 340 | /** 341 | * Returns a copy of the activation functions of every layer 342 | * @return Copy of the activation functions of every layer 343 | */ 344 | public ActivationFunction[] copyActivationFunctions() { 345 | return Arrays.copyOf(activationFunctions, activationFunctions.length); 346 | } 347 | 348 | 349 | /** 350 | * Sets the weights of a single layer 351 | * @param index Layer index 352 | * @param layer New weights 353 | * @throws IllegalArgumentException If the index does not point 354 | * to an existing matrix or the given matrix dimensions 355 | * do not equal the needed size 356 | */ 357 | public void setWeights(int index, Matrix layer) { 358 | if(index < 0 || index >= weights.length) { 359 | throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); 360 | } 361 | if(layer.getHeight() != weights[index].getHeight() 362 | || layer.getWidth() != weights[index].getWidth()) { 363 | throw new IllegalArgumentException("Incorrect layer dimensions!"); 364 | } 365 | 366 | 367 | weights[index] = layer; 368 | } 369 | 370 | /** 371 | * Sets the weights for every layer 372 | * @param weights New weights 373 | * @throws IllegalArgumentException If the number of matricies 374 | * does not equal the number of layers 375 | * or the dimensions of a matrix do not equal the needed dimensions 376 | */ 377 | public void setWeights(Matrix[] weights) { 378 | if(weights.length != this.weights.length) { 379 | throw new IllegalArgumentException("Incorrect number of layers!"); 380 | } 381 | for(int i=0; i= weights.length) { 401 | throw new ArrayIndexOutOfBoundsException("Index out of bounds!"); 402 | } 403 | 404 | return weights[index]; 405 | } 406 | 407 | /** 408 | * Returns the weights of every layer 409 | * @return Weights 410 | */ 411 | public Matrix[] getWeights() { 412 | return Arrays.copyOf(weights, weights.length); 413 | } 
414 | 415 | /** 416 | * Returns a copy of all weights 417 | * @return Copy of all weights 418 | */ 419 | public Matrix[] copyWeights() { 420 | final Matrix[] copy = new Matrix[weights.length]; 421 | 422 | for(int i=0; i { 509 | if(Double.isNaN(x)) { 510 | return 0.0; 511 | } else if(x <= Double.NEGATIVE_INFINITY) { 512 | return -Double.MAX_VALUE; 513 | } else if(x >= Double.POSITIVE_INFINITY) { 514 | return Double.MAX_VALUE; 515 | } 516 | 517 | return x; 518 | }); 519 | } 520 | } 521 | 522 | /** 523 | * Keeps weights within the given boundaries 524 | * @param minimum Minimum value 525 | * @param maximum Maximum value 526 | */ 527 | public void keepWeightsInBounds(double minimum, double maximum) { 528 | if(minimum >= maximum) { 529 | throw new IllegalArgumentException( 530 | "Minimum greater than or equal to maximum!"); 531 | } 532 | 533 | 534 | for(int i=0; i { 536 | if(Double.isNaN(x)) { 537 | return (minimum + maximum) / 2; 538 | } else if(x < minimum) { 539 | return minimum; 540 | } else if(x > maximum) { 541 | return maximum; 542 | } 543 | 544 | return x; 545 | }); 546 | } 547 | } 548 | 549 | 550 | 551 | /** 552 | * Forward propagates a matrix of data sets. 
553 | * Every single row represents one data set 554 | * Every column gets feed into one input neuron 555 | * @param input Input sets 556 | * @return Output sets 557 | * @throws IllegalArgumentException If the number of input values (columns) 558 | * does not equal the number of input neurons 559 | */ 560 | public Matrix forward(Matrix input) { 561 | if(input.getWidth() != numberOfInputs) { 562 | throw new IllegalArgumentException("Illegal number of inputs!"); 563 | } 564 | 565 | 566 | activityZ[0] = input.multiply(weights[0]); 567 | activityA[0] = activityZ[0].apply(activationFunctions[0].function()); 568 | 569 | for(int i=1; i0; i--) { 652 | dJdW[i] = activityA[i-1].transpose().multiply(delta); 653 | delta = delta.multiply(weights[i].transpose()).multiplyElementwise( 654 | activityZ[i-1].apply(activationFunctions[i-1].prime())); 655 | } 656 | 657 | dJdW[0] = input.transpose().multiply(delta); 658 | 659 | 660 | return dJdW; 661 | } 662 | 663 | 664 | 665 | /** 666 | * Trains the network. 
667 | * (override the method "keepTraining" to set the continuation condition) 668 | * @param learningRate Initial learning rate 669 | * @param input Input sets 670 | * @param output Wanted output sets 671 | * @param printToConsole Print progress to console 672 | * @return Last cost 673 | * @throws IllegalArgumentException If the number of inputs or outputs 674 | * does not fit the dimensions of this network or the number 675 | * of input sets is not equal to the number of output sets 676 | */ 677 | public double train(double learningRate, 678 | Matrix input, Matrix output, boolean printToConsole) { 679 | if(input.getWidth() != getNumberOfInputs()) { 680 | throw new IllegalArgumentException("Illegal number of inputs!"); 681 | } 682 | if(output.getWidth() != getNumberOfOutputs()) { 683 | throw new IllegalArgumentException("Illegal number of outputs!"); 684 | } 685 | if(input.getHeight() != output.getHeight()) { 686 | throw new IllegalArgumentException( 687 | "Unequal number of input and output sets!"); 688 | } 689 | 690 | 691 | double lastCost = cost(input, output); 692 | 693 | for(int iterations = 0; 694 | keepTraining(iterations, learningRate, lastCost) 695 | && learningRate > 0; 696 | iterations++) 697 | { 698 | final Matrix[] lastWeights = copyWeights(); 699 | 700 | singleGradientDescent(learningRate, input, output); 701 | final double currentCost = cost(input, output); 702 | 703 | if(printToConsole) { 704 | System.out.println(String.format("%d: %e", iterations, currentCost)); 705 | } 706 | 707 | if(currentCost <= lastCost) 708 | { 709 | lastCost = currentCost; 710 | learningRate *= 1.1; 711 | } 712 | else 713 | { 714 | setWeights(lastWeights); 715 | learningRate /= 2; 716 | } 717 | } 718 | 719 | 720 | return lastCost; 721 | } 722 | 723 | /** 724 | * Tells the network how long to continue to train 725 | * @param iterations Number of completed training cycles 726 | * @param learningRate Current learning rate 727 | * @param cost Current cost 728 | * @return If the 
training process should continue 729 | */ 730 | public boolean keepTraining(int iterations, double learningRate, 731 | double cost) { 732 | return iterations < 100; 733 | } 734 | 735 | /** 736 | * Backpropagates and applies the gradient with the given learning rate once 737 | * @param learningRate Learning rate 738 | * @param input Input sets 739 | * @param output Wanted output sets 740 | * @throws IllegalArgumentException If the number of inputs or outputs 741 | * does not fit the dimensions of this network or the number 742 | * of input sets is not equal to the number of output sets 743 | */ 744 | private void singleGradientDescent(double learningRate, 745 | Matrix input, Matrix output) { 746 | final Matrix[] dJdW = costPrime(input, output); 747 | 748 | for(int i=0; i