├── Images
│   └── Set14
│       ├── baboon.bmp
│       ├── barbara.bmp
│       ├── bridge.bmp
│       ├── coastguard.bmp
│       ├── comic.bmp
│       ├── face.bmp
│       ├── flowers.bmp
│       ├── foreman.bmp
│       ├── lenna.bmp
│       ├── man.bmp
│       ├── monarch.bmp
│       ├── pepper.bmp
│       ├── ppt3.bmp
│       └── zebra.bmp
├── JavaMachineLearning.jar
├── README.md
├── README_ko.md
├── classification_example.png
├── classification_example2.png
├── classification_example3.png
├── classification_example4.png
├── error_graph_cross_entropy.png
├── error_graph_squared.png
├── mnist_weights_fc.nn
├── nn_linear_regression.png
└── src
    ├── javamachinelearning
    │   ├── drawables
    │   │   ├── MNISTDrawablePanel.java
    │   │   └── MNISTDrawablePanel2.java
    │   ├── graphs
    │   │   ├── Graph.java
    │   │   ├── GraphPanel.java
    │   │   ├── Line.java
    │   │   ├── LineGraph.java
    │   │   └── Point.java
    │   ├── layers
    │   │   ├── Layer.java
    │   │   ├── ParamsLayer.java
    │   │   ├── feedforward
    │   │   │   ├── ActivationLayer.java
    │   │   │   ├── AvgPoolingLayer.java
    │   │   │   ├── ConvLayer.java
    │   │   │   ├── DropoutLayer.java
    │   │   │   ├── FCLayer.java
    │   │   │   ├── FeedForwardLayer.java
    │   │   │   ├── FeedForwardParamsLayer.java
    │   │   │   ├── FlattenLayer.java
    │   │   │   ├── MaxPoolingLayer.java
    │   │   │   └── ScalingLayer.java
    │   │   └── recurrent
    │   │       ├── GRUCell.java
    │   │       ├── RecurrentCell.java
    │   │       └── RecurrentLayer.java
    │   ├── networks
    │   │   ├── NeuralNetwork.java
    │   │   ├── SequentialNN.java
    │   │   └── SupervisedNeuralNetwork.java
    │   ├── optimizers
    │   │   ├── AdaDeltaOptimizer.java
    │   │   ├── AdagradOptimizer.java
    │   │   ├── AdamOptimizer.java
    │   │   ├── MomentumOptimizer.java
    │   │   ├── NAGOptimizer.java
    │   │   ├── Optimizer.java
    │   │   ├── RMSPropOptimizer.java
    │   │   └── SGDOptimizer.java
    │   ├── regularizers
    │   │   ├── ElasticNetRegularizer.java
    │   │   ├── L1Regularizer.java
    │   │   ├── L2Regularizer.java
    │   │   └── Regularizer.java
    │   └── utils
    │       ├── Activation.java
    │       ├── ImageUtils.java
    │       ├── Loss.java
    │       ├── MNISTUtils.java
    │       ├── Tensor.java
    │       ├── TensorUtils.java
    │       └── Utils.java
    └── tests
        ├── Categories2Graph.java
        ├── Categories4Graph.java
        ├── ErrorGraphCrossEntropy.java
        ├── ErrorGraphSquared.java
        ├── GRUTest.java
        ├── LinearGraph.java
        ├── LoadTest.java
        ├── LogicGates.java
        ├── SaveTest.java
        ├── TestImageUtils.java
        ├── TestMNISTDraw1.java
        ├── TestMNISTDraw2.java
        ├── TestMNISTFile.java
        ├── TrainMNISTConv.java
        ├── TrainMNISTConvMemorize.java
        └── TrainMNISTFullyConnected.java
--------------------------------------------------------------------------------
/Images/Set14/baboon.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/baboon.bmp
--------------------------------------------------------------------------------
/Images/Set14/barbara.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/barbara.bmp
--------------------------------------------------------------------------------
/Images/Set14/bridge.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/bridge.bmp
--------------------------------------------------------------------------------
/Images/Set14/coastguard.bmp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/coastguard.bmp
--------------------------------------------------------------------------------
/Images/Set14/comic.bmp:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/comic.bmp -------------------------------------------------------------------------------- /Images/Set14/face.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/face.bmp -------------------------------------------------------------------------------- /Images/Set14/flowers.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/flowers.bmp -------------------------------------------------------------------------------- /Images/Set14/foreman.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/foreman.bmp -------------------------------------------------------------------------------- /Images/Set14/lenna.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/lenna.bmp -------------------------------------------------------------------------------- /Images/Set14/man.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/man.bmp -------------------------------------------------------------------------------- /Images/Set14/monarch.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/monarch.bmp -------------------------------------------------------------------------------- /Images/Set14/pepper.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/pepper.bmp -------------------------------------------------------------------------------- /Images/Set14/ppt3.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/ppt3.bmp -------------------------------------------------------------------------------- /Images/Set14/zebra.bmp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/Images/Set14/zebra.bmp -------------------------------------------------------------------------------- /JavaMachineLearning.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/JavaMachineLearning.jar 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Java Machine Learning Library
2 | Simple machine learning (neural network) library for Java. The library is mainly for educational purposes, and it is far too slow to be used in actual projects.
3 | 
4 | The Korean translation of this README.md is [here](README_ko.md), if you prefer to read it in Korean.
5 | 
6 | (**NOTE: PROBABLY OUTDATED. Just compile from source.**) If you want to download the compiled `.jar` file and include it in your own project, click [here](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/raw/master/JavaMachineLearning.jar).
7 | 
8 | This library recently received an overhaul that fixed many bugs and introduced vectorized operations through a built-in tensor class, among many other features. The source code was also reorganized, and comments were added.
9 | 
10 | ## Features
11 | - Feed-forward layers
12 |   - Fully connected
13 |   - Convolutional (2D convolution on 3D inputs with 4D weights)
14 |   - Max/Average Pooling
15 |   - Dropout
16 |   - Activation
17 |   - Flatten (Conv/Pooling -> FC)
18 |   - Scaling
19 | - Recurrent layer
20 |   - GRU Cells
21 | - Adam, Adagrad, momentum, NAG (Nesterov accelerated gradient), SGD, RMSProp, and AdaDelta optimizers
22 | - Mini-batch gradient descent
23 |   - Average gradients for each weight throughout each batch
24 | - Sigmoid, tanh, relu, hard sigmoid, and softmax activation functions
25 | - L1, L2, and elastic net regularization
26 | - Squared loss, binary cross entropy, and multi-class cross entropy
27 |   - Squared loss for regression
28 |   - Binary cross entropy + sigmoid activation for binary classification
29 |   - Multi-class cross entropy + softmax activation for general classification
30 | - Internally uses "tensors", which are multidimensional arrays/matrices
31 | - Simple graphing class for graphing classification boundaries, points, lines, line plots, etc.
32 | - MNIST dataset loader
33 | - Save/load weights to/from files
34 | - Drawing GUI for MNIST
35 | - A bunch of testing classes and graphing examples
36 | - Image preprocessing
37 | 
38 | ## Tutorial
39 | The API provided by this library is quite elegant (in my opinion) and very high level. A whole network can be created by initializing a `SequentialNN` object. That class provides the tools to add layers and build a complete network. When initializing it, you need to specify the shape of the input as a parameter.
40 | 
41 | Using the `add` method in `SequentialNN`, you can add layers to the sequential model. These layers are evaluated, during forward propagation, in the order they were added. To forward propagate, use the `predict` method and provide the input(s) as tensors. Tensors are multidimensional arrays that are stored internally in a flat, column-major format. However, the class provides a few constructors that accept regular row-major arrays. To train a model, call the `train` method with inputs and expected target outputs. This method has many configurable parameters, such as the loss function, the optimizer, and the regularizer. A callback function can even be provided that runs after every epoch of training.
42 | 
43 | With the addition of a `RecurrentLayer` class, inputs and outputs can span many time steps. For example, when using a fully connected layer after a recurrent layer, the fully connected layer is applied to the outputs of every single time step. Another addition is a flexible `predict` function that allows a custom number of time steps to be evaluated. Recurrent layers can also be stateful across multiple training examples or predictions.
44 | 
45 | ### Vanilla Neural Networks
46 | 
47 | Here is a piece of code that shows how easy it is to run a simple linear regression using a neural network:
48 | ```java
49 | // neural network with 1 input and 1 output, no activation function
50 | SequentialNN nn = new SequentialNN(1);
51 | nn.add(new FCLayer(1));
52 | 
53 | // y = 5x + 3
54 | Tensor[] x = {
55 | 		t(0),
56 | 		t(1),
57 | 		t(2),
58 | 		t(3),
59 | 		t(4)
60 | };
61 | 
62 | Tensor[] y = {
63 | 		t(3 + 0 + 1),
64 | 		t(3 + 5 - 1),
65 | 		t(3 + 10 + 1),
66 | 		t(3 + 15 - 1),
67 | 		t(3 + 20 + 1)
68 | };
69 | 
70 | nn.train(x,
71 | 		y,
72 | 		100, // number of epochs
73 | 		1, // batch size
74 | 		Loss.squared,
75 | 		new SGDOptimizer(0.01),
76 | 		null, // no regularizer
77 | 		false, // do not shuffle data
78 | 		true); // verbose
79 | 
80 | // try the network on new data
81 | System.out.println(nn.predict(t(5)));
82 | ```
83 | You can find the full source file [here](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/LinearGraph.java). Note that the `t` method is just a convenience method for creating 1D tensors. The full code will produce a window with the points and the line formed by the learned weight/bias graphed:
84 | ![linear regression graph](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/nn_linear_regression.png)
85 | 
86 | On a slightly different set of data (y = 5x instead of y = 5x + 3, no noise, and no bias), the loss/error with respect to the weight can be graphed:
87 | ![error wrt weight graph](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/error_graph_squared.png)
88 | The green dots represent weights that the training algorithm "visited" throughout training. The quadratic shape of the graph is due to the squared loss function. Note that training converges to the minimum, where the loss is lowest, and that minimum is centered on x = 5, which is the slope of the linear function that we want to learn.
89 | 
90 | The following piece of code trains a 3-layer neural network for MNIST handwritten digit classification.
91 | ```java
92 | // create a model with 784 input neurons, 300 hidden neurons, and 10 output neurons
93 | // use RELU for the hidden layer and softmax for the output layer
94 | SequentialNN nn = new SequentialNN(784);
95 | nn.add(new FCLayer(300));
96 | nn.add(new ActivationLayer(Activation.relu));
97 | nn.add(new FCLayer(10)); // 10 categories of numbers
98 | nn.add(new ActivationLayer(Activation.softmax));
99 | 
100 | // load the training data
101 | Tensor[] x = MNISTUtils.loadDataSetImages("train-images-idx3-ubyte", Integer.MAX_VALUE);
102 | Tensor[] y = MNISTUtils.loadDataSetLabels("train-labels-idx1-ubyte", Integer.MAX_VALUE);
103 | 
104 | long start = System.currentTimeMillis();
105 | 
106 | nn.train(Utils.flattenAll(x),
107 | 		y,
108 | 		100, // number of epochs
109 | 		100, // batch size
110 | 		Loss.softmaxCrossEntropy,
111 | 		new MomentumOptimizer(0.5, true),
112 | 		new L2Regularizer(0.0001),
113 | 		true, // shuffle the data after every epoch
114 | 		false);
115 | 
116 | System.out.println("Training time: " + Utils.formatElapsedTime(System.currentTimeMillis() - start));
117 | 
118 | // save the learned weights
119 | nn.saveToFile("mnist_weights_fc.nn");
120 | 
121 | // predict on previously unseen testing data
122 | Tensor[] testX = MNISTUtils.loadDataSetImages("t10k-images-idx3-ubyte", Integer.MAX_VALUE);
123 | Tensor[] testY = MNISTUtils.loadDataSetLabels("t10k-labels-idx1-ubyte", Integer.MAX_VALUE);
124 | Tensor[] testResult = nn.predict(Utils.flattenAll(testX));
125 | 
126 | // prints the percent of images classified correctly
127 | System.out.println("Classification accuracy: " + Utils.format(Utils.classificationAccuracy(testResult, testY)));
128 | ```
129 | The full code can be found [here](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/TrainMNISTFullyConnected.java).
130 | 
131 | ### Convolutional Neural Networks
132 | 
133 | The training code that uses convolutional layers for the same digit classification task can be found [here](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/TrainMNISTConv.java). However, the code is very slow, so a simpler test was conducted to see whether the model could directly memorize a few digits. The code is available [here](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/TrainMNISTConvMemorize.java). The architecture is very similar to the previous convolutional network:
134 | ```java
135 | SequentialNN nn = new SequentialNN(28, 28, 1);
136 | 
137 | nn.add(new ConvLayer(5, 32, PaddingType.SAME));
138 | nn.add(new ActivationLayer(Activation.relu));
139 | nn.add(new MaxPoolingLayer(2, 2));
140 | 
141 | nn.add(new ConvLayer(5, 64, PaddingType.SAME));
142 | nn.add(new ActivationLayer(Activation.relu));
143 | nn.add(new MaxPoolingLayer(2, 2));
144 | 
145 | nn.add(new FlattenLayer());
146 | 
147 | nn.add(new FCLayer(1024));
148 | nn.add(new ActivationLayer(Activation.relu));
149 | 
150 | nn.add(new DropoutLayer(0.3));
151 | 
152 | nn.add(new FCLayer(10));
153 | nn.add(new ActivationLayer(Activation.softmax));
154 | ```
155 | Training this network takes around 20 minutes, and it can memorize the input images' classes perfectly.
156 | 
157 | ### Recurrent Neural Networks
158 | 
159 | Creating a recurrent neural network is also very simple. Currently, only GRU cells are supported, and I used them to learn and generate some Shakespeare and Alice's Adventures in Wonderland text.
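To make the setup below concrete, here is a minimal, hypothetical sketch of how a corpus can be sliced into fixed-size windows of one-hot-encoded characters. The class and variable names are illustrative assumptions, and plain arrays stand in for the library's `Tensor` class; see GRUTest.java for the actual data pipeline:

```java
// Illustrative sketch only, not the library's API: each window of winSize
// characters becomes the input, and the same window shifted right by one
// character becomes the target (predict the next character at each step).
public class BuildWindowDemo{
	public static void main(String[] args){
		String text = "to be or not to be that is the question";
		String alphabet = " abcdefghijklmnopqrstuvwxyz";
		int winSize = 20;
		int winStep = 20; // winSize = winStep so substrings are not repeated

		for(int start = 0; start + winSize < text.length(); start += winStep){
			double[][] x = new double[winSize][alphabet.length()];
			double[][] y = new double[winSize][alphabet.length()];
			for(int t = 0; t < winSize; t++){
				// input: one hot vector for the current character
				x[t][alphabet.indexOf(text.charAt(start + t))] = 1.0;
				// target: one hot vector for the next character
				y[t][alphabet.indexOf(text.charAt(start + t + 1))] = 1.0;
			}
			System.out.println("window: \"" + text.substring(start, start + winSize) + "\"");
		}
	}
}
```

Because `winStep` equals `winSize`, consecutive windows do not overlap, which is presumably why the recurrent layers in the model below are kept stateful: the hidden state carries context from one window to the next.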
160 | 
161 | Here are the hyperparameters used:
162 | ```java
163 | int epochs = 500;
164 | int batchSize = 10;
165 | int winSize = 20;
166 | int winStep = 20; // winSize = winStep so substrings are not repeated
167 | int genIter = 5000; // how many characters to generate
168 | double temperature = 0.1; // lower = less randomness
169 | ```
170 | And here is the code that builds the 2-layer recurrent neural network model:
171 | ```java
172 | // for each time step, the input is a one hot vector describing the current character
173 | // for each time step, the output is a one hot vector describing the next character
174 | // the recurrent layers are stateful, which means that the next state relies on the previous states
175 | SequentialNN nn = new SequentialNN(winSize, alphabet.length());
176 | nn.add(new RecurrentLayer(winSize, new GRUCell(), true));
177 | nn.add(new DropoutLayer(0.3));
178 | nn.add(new RecurrentLayer(winSize, new GRUCell(), true));
179 | // the same fully connected layer is applied for every single time step
180 | nn.add(new FCLayer(alphabet.length()));
181 | // scales the values by the temperature before softmax
182 | nn.add(new ScalingLayer(1 / temperature, false));
183 | nn.add(new ActivationLayer(Activation.softmax));
184 | ```
185 | Go [here](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/GRUTest.java) if you want the full code for training the model and generating text.
186 | 
187 | Here is the output from running on Shakespeare's Sonnet #130:
188 | ```
189 | [cxx]x
190 | 
191 | my mistress' eyes are nothing like the sun
192 | coral is far more red, than her lips red
193 | if snow be white, why then her breasts are dun
194 | if hairs be wires, black wires grow on her head.
195 | i have seen roses damask'd, red and white,
196 | but no such roses see i in her cheeks
197 | and in some perfumes is there more delight
198 | than in the breath that from my mistress reeks.
199 | i love to hear her speak, yet well i know
200 | that music hath a far more pleasing sound
201 | i grant i never saw a goddess go,--
202 | my mistress, when she walks, treads on the ground
203 | and yet by heaven, i think my love as rare,
204 | as any she belied with false compare
205 | ```
206 | The text in brackets at the very beginning is the seed text that I entered. The network takes it and generates the rest of the sonnet, plus some extra spaces at the end that I removed.
207 | 
208 | Here is the output from running on an excerpt of Alice's Adventures in Wonderland:
209 | ```
210 | [chapter] i. down the rabbit-hole
211 | 
212 | alice was beginning then she
213 | ray so menty see.
214 | 215 | af the 216 | hing howver be world she was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it pocts tow th the tried to have wondered at the sides it pocts top the rabbit with pink eyes time as she fell very slowly, for she had 217 | plenty of time as she went lothe the down nothing to her owa get in to her that she was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it withing either a waistcoat-pocket, and to ple pfonsidering in her feet, for it flashed across her mind there she fell very slowly, for she had 218 | plenty so it with pink ey her feet, for it flashed across her mind that she ought to have wondered at the sides it pocts top the rabbit with pink eyes time as she fell very slowly, for she had 219 | plenty of time as she went tring to look down and make out what 220 | she was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it poct plap the had pe was coming to, but it was too dark to see anything then she 221 | looked at the sides all seemed quite was beginning then she 222 | ray co peer a watch 223 | to take out of it, and fortunately was just in time to see it pop down a large 224 | rabbit-hole under the well was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it pocts top the rabbit with pink eyes time as she fell very slowly, for she had 225 | plenty so it with pink ey, she she to out it was too dark to tires all, be late! (when she thought it over afterwatd, but it was too dark to te was beginning then somenasy 226 | seen a rabbit with either a waistcoat-pocket, and to ple ppend.n ahelves thok down a jar from one of the shelves had wondel very sleepy and stupid), whether the well was coused it was labelled orange maran 227 | aling she to her feet, for it flashed across her mind there she fell very slowly, for she had 228 | plenty of time as she went tr.e hed as 229 | she pagllidy, nothing then she 230 | ron to happen next. first, she tried to look down and make out what 231 | she was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it poct poud there she fell past it 232 | ``` 233 | There are a lot of misspelled words, but it is pretty cool nonetheless. 234 | 235 | Finally, here is the output of running the network on the entire Act I Scene I of Romeo and Juliet: 236 | ``` 237 | [act i]t sarn'd to the will part thee. 238 | 239 | rom. i do beauty his sunpong sprice. 240 | 241 | rom. i do beauty his groanse the wall. 242 | 243 | samp. i do beaut's thess swords therefere in that is to streponse thee the wall. 244 | 245 | samp. i do beaut's thess swords therefere if thou doth the maids, or ment to the willat them, an thet thee weart of lovers' from the strunce of his will the live his will be comes to the but whett ther theis ments i will the hat with the fair, 246 | bees thee, when the wall the hat with the wall. 247 | 248 | samp. i do beaut's thess swords therefere in that len. 249 | 250 | ben. montague should be so fair mark. 251 | sh therefore i will they will stoul of the maids having the will part thee the hat let pee i pass me not here in sparkling hath the maids hour side i sad the wall. 252 | 253 | samp. i do beaut's thess swords therefere in that len. 254 | 255 | ben. montague should more or here with the maids hours shown so thee with me. 256 | 257 | samp. 
i do beauty his caming makes the he will.
258 | 
259 | samp. no, should montagues!
260 | wher thear hears'd and will they were head the willat to stand, and moved.
261 | 
262 | ben. in sand hear the with my a wist.
263 | 
264 | samp. i do beaut's thess sword morte the wall.
265 | 
266 | samp. i do beaut's this will stans.
267 | 
268 | greg. the heads of the beauty the hasterte me my fair madk to the was the weakes starn,
269 | where is to stor. what her comes is to store.
270 | 
271 | samp. i do beauty the wall the hat here in sparkling his hit tles, and montague and with me.
272 | 
273 | samp. i do betheas in paine.
274 | 
275 | samp. a dog of the fair markman.
276 | 
277 | samp. i do beaut's thess sword morte me they me what hes if othen. i will the wall.
278 | 
279 | samp. i do beaut's thess swords therefere in that me.
280 | 
281 | rom. i dis ment good the was she head the was what, and hent sword of the will part the will part the will.
282 | 
283 | samp. no, sir.
284 | 
285 | samp. no, as the weadt of the wall.
286 | 
287 | samp. i do beaut's thess sword.
288 | 
289 | rom. i dis ment good the wall.
290 | 
291 | samp. a dog of the hat heads her his comes and montague ind sen the wall.
292 | 
293 | samp. i do beaut's thess sword morte and made is that we dows thee i sang the hearty saive in strun.
294 | ```
295 | As you can see, there are some repetitions that would probably disappear if the temperature were increased (which increases the randomness). Originally, I wanted the network to start predicting from the first line, which states the act and scene, but the network started from somewhere else.
296 | 
297 | In all of these examples, the model and hyperparameters were the same. What's cool is that the network learns the structure of the text and properly adds newlines and indents for the first and third examples. Also, I got the texts from [Project Gutenberg](http://www.gutenberg.org/).
298 | 
299 | Many other examples can be found in the [tests folder](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/tree/master/src/tests).
300 | 
301 | I have a blog post on backpropagation and gradient descent equations [here](https://c0deb0t.wordpress.com/2018/06/17/the-math-for-gradient-descent-and-backpropagation/). It has some interesting math stuff!
302 | 
303 | ### Image load
304 | 
305 | If you want to load an image into a `Tensor`, you can use the following code (`readColorImageToTensor` takes the image path and whether to convert the image to grayscale):
306 | 
307 | ```java
308 | ImageUtils imgUtils = new ImageUtils();
309 | Tensor imgTensor = imgUtils.readColorImageToTensor("path/to/image.bmp", false); // example path; false keeps color
310 | ```
311 | 
312 | And if you want to load many images into a `Tensor` array, you can use the following code (`readImages` takes a folder path and the same grayscale flag):
313 | 
314 | ```java
315 | ImageUtils imgUtils = new ImageUtils();
316 | Tensor[] imgTensorArray = imgUtils.readImages("path/to/folder", false); // example path
317 | ```
--------------------------------------------------------------------------------
/README_ko.md:
--------------------------------------------------------------------------------
1 | # Java Machine Learning Library
2 | Java를 이용한 간단한 머신 러닝(신경망) 라이브러리입니다. 이 라이브러리는 교육을 목적으로 하고 있으며, 실제 프로젝트에서 사용하기에는 매우 느립니다.
3 | 
4 | (**주의: 구식코드입니다. 소스를 컴파일하세요.**) 컴파일된 `.jar`파일을 다운로드 하고 당신의 프로젝트에 포함하고 싶으면 [여기](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/raw/master/JavaMachineLearning.jar)를 누르세요.
5 | 
6 | 이 라이브러리는 최근에 많은 버그들을 고치고 내장된 tensor 클래스로 다양한 기능들을 포함한 벡터화된 연산을 사용하는 점검을 했습니다.
7 | 
8 | ## 특징
9 | - Feed-forward layers
10 |   - Fully connected
11 |   - Convolutional (2D convolution on 3D inputs with 4D weights)
12 |   - Max/Average Pooling
13 |   - Dropout
14 |   - Activation
15 |   - Flatten (Conv/Pooling -> FC)
16 |   - Scaling
17 | - Recurrent layer
18 |   - GRU Cells
19 | - Adam, Adagrad, momentum, NAG (Nesterov accelerated gradient), SGD, RMSProp, and AdaDelta optimizers
20 | - Mini-batch gradient descent
21 |   - Average gradients for each weight throughout each batch
22 | - Sigmoid, tanh, relu, hard sigmoid, and softmax activation functions
23 | - L1, L2, and elastic net regularization
24 | - Squared loss, binary cross entropy, and multi-class cross entropy
25 |   - Squared loss for regression
26 |   - Binary cross entropy + sigmoid activation for binary classification
27 |   - Multi-class cross entropy + softmax activation for general classification
28 | - Internally uses "tensors", which are multidimensional arrays/matrices
29 | - Simple graphing class for graphing classification boundaries, points, lines, line plots, etc.
30 | - MNIST dataset loader
31 | - Save/load weights to/from files
32 | - Drawing GUI for MNIST
33 | - A bunch of testing classes and graphing examples
34 | - Image preprocessing
35 | 
36 | ## 사용 지침
37 | 이 라이브러리에서 제공하는 API는 고상하고(제 생각엔) 매우 높은 수준입니다. 모든 네트워크는 `SequentialNN` 클래스로 초기화하여 만들 수 있습니다. 이 클래스는 레이어를 추가하고 완전한 네트워크를 만들어 주는 도구를 제공합니다. 이 클래스를 초기화 할 때, 당신은 매개변수로서 입력의 형태를 명시해야 합니다.
38 | 
39 | `SequentialNN`에 있는 `add` 메소드를 사용하여, 당신은 순차 모델(Sequential model)에 레이어를 추가할 수 있습니다. 이 레이어들은 순전파(forward propagation) 동안에 추가된 순서대로 평가됩니다. 순전파를 하려면, 예측 함수를 사용하고 tensor로 입력(들)을 제공합니다. Tensors는 내부적으로 열 주요 순서인 평면을 나타내는 다차원 배열입니다. 하지만, Tensors는 (정규) 행 주요 배열을 받는 일부 생성자를 제공합니다. 모델을 학습시키기 위해서, 입력 및 예측되는 목표 출력과 함께 `train` 메소드를 호출하세요. 이 메소드는 손실 함수(loss function), optimizer, regularizer 등과 같은 변화할 수 있는 많은 매개변수를 가지고 있습니다. callback 함수는 모든 학습 시기에도 제공될 수 있습니다.
40 | 
41 | `RecurrentLayer` 클래스를 추가하여, 입력과 출력을 여러 시간 단계로 걸칠 수 있습니다. 예를 들어, 순환 레이어(recurrent layer) 후에 전결합 레이어(fully connected layer)를 사용할 때, 전결합 레이어는 각 시간 단계 전체에 출력이 적용됩니다. 다른 추가점은 사용자가 지정한 시간 단계에 평가할 수 있게 하는 유연한 `predict` 함수입니다. 또한 순환 레이어는 다중 훈련 예제 또는 예측을 통해 상태저장을 할 수 있습니다.
42 | 
43 | ### 바닐라 신경망
44 | 
45 | 신경망을 사용한 간단한 선형 회귀(linear regression)를 실행하는 것이 얼마나 쉬운지를 보여주는 코드가 있습니다:
46 | ```java
47 | // neural network with 1 input and 1 output, no activation function
48 | SequentialNN nn = new SequentialNN(1);
49 | nn.add(new FCLayer(1));
50 | 
51 | // y = 5x + 3
52 | Tensor[] x = {
53 | 		t(0),
54 | 		t(1),
55 | 		t(2),
56 | 		t(3),
57 | 		t(4)
58 | };
59 | 
60 | Tensor[] y = {
61 | 		t(3 + 0 + 1),
62 | 		t(3 + 5 - 1),
63 | 		t(3 + 10 + 1),
64 | 		t(3 + 15 - 1),
65 | 		t(3 + 20 + 1)
66 | };
67 | 
68 | nn.train(x,
69 | 		y,
70 | 		100, // number of epochs
71 | 		1, // batch size
72 | 		Loss.squared,
73 | 		new SGDOptimizer(0.01),
74 | 		null, // no regularizer
75 | 		false, // do not shuffle data
76 | 		true); // verbose
77 | 
78 | // try the network on new data
79 | System.out.println(nn.predict(t(5)));
80 | ```
81 | 전체 소스 파일은 [여기](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/LinearGraph.java)에서 찾을 수 있습니다. `t` 메소드는 1차원 텐서를 만들기 위한 편리한 메소드일 뿐임을 주의해주세요.
전체 코드는 점과 선으로 이루어진 가중치(weight)/편향(bias) 그래프 창을 만듭니다: 82 | ![linear regression graph](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/nn_linear_regression.png) 83 | 84 | 약간 다른 데이터 집합(y = 5x + 3 대신 y = 5x, 노이즈가 없고, 편향치가 없음)에서, 가중치에 관한 손실/오류를 그래프로 그릴 수 있습니다: 85 | ![error wrt weight graph](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/error_graph_squared.png) 86 | 초록 점들은 훈련 알고리즘이 훈련을 통해 "찾아간" 가중치를 표시합니다. 손실 함수의 제곱으로 인해 이차식 그래프로 나타납니다. 그래프는 손실이 가장 적을 때인 최솟값으로 수렴되는 점에 주목해주세요. 이 최솟값은 우리가 학습하고자 하는 선형 함수의 기울기인 x = 5 중심에 있습니다. 87 | 88 | 다음 코드는 MNIST 필기 숫자 분류를 위해 3중 레이어 신경망을 훈련 하기 위한 것입니다. 89 | ```java 90 | // create a model with 784 input neurons, 300 hidden neurons, and 10 output neurons 91 | // use RELU for the hidden layer and softmax for the output layer 92 | SequentialNN nn = new SequentialNN(784); 93 | nn.add(new FCLayer(300)); 94 | nn.add(new ActivationLayer(Activation.relu)); 95 | nn.add(new FCLayer(10)); // 10 categories of numbers 96 | nn.add(new ActivationLayer(Activation.softmax)); 97 | 98 | // load the training data 99 | Tensor[] x = MNISTUtils.loadDataSetImages("train-images-idx3-ubyte", Integer.MAX_VALUE); 100 | Tensor[] y = MNISTUtils.loadDataSetLabels("train-labels-idx1-ubyte", Integer.MAX_VALUE); 101 | 102 | long start = System.currentTimeMillis(); 103 | 104 | nn.train(Utils.flattenAll(x), 105 | y, 106 | 100, // number of epochs 107 | 100, // batch size 108 | Loss.softmaxCrossEntropy, 109 | new MomentumOptimizer(0.5, true), 110 | new L2Regularizer(0.0001), 111 | true, // shuffle the data after every epoch 112 | false); 113 | 114 | System.out.println("Training time: " + Utils.formatElapsedTime(System.currentTimeMillis() - start)); 115 | 116 | // save the learned weights 117 | nn.saveToFile("mnist_weights_fc.nn"); 118 | 119 | // predict on previously unseen testing data 120 | Tensor[] testX = MNISTUtils.loadDataSetImages("t10k-images-idx3-ubyte", Integer.MAX_VALUE); 121 | Tensor[] testY = MNISTUtils.loadDataSetLabels("t10k-labels-idx1-ubyte", Integer.MAX_VALUE); 122 | Tensor[] testResult = nn.predict(Utils.flattenAll(testX)); 123 | 124 | // prints the percent of images classified correctly 125 | System.out.println("Classification accuracy: " + Utils.format(Utils.classificationAccuracy(testResult, testY))); 126 | ``` 127 | 전체 소스 파일은 [여기](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/TrainMNISTFullyConnected.java)에서 찾을 수 있습니다. 128 | 129 | ### 콘볼루션 신경망 130 | 131 | 동일한 숫자 분류 작업을 위한 콘볼루션 레이어를 사용한 훈련 코드는 [here](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/TrainMNISTConv.java)에서 찾을 수 있습니다. 하지만, 이 코드는 매우 느려서 모델이 직접적으로 일부의 숫자를 기억할 수 있는지를 보는 더 간단한 테스트를 수행했습니다. 그 코드는 [여기](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/TrainMNISTConvMemorize.java)에서 이용할 수 있습니다. 
그 구조는 이전 콘볼루션 망과 매우 유사합니다: 132 | ```java 133 | SequentialNN nn = new SequentialNN(28, 28, 1); 134 | 135 | nn.add(new ConvLayer(5, 32, PaddingType.SAME)); 136 | nn.add(new ActivationLayer(Activation.relu)); 137 | nn.add(new MaxPoolingLayer(2, 2)); 138 | 139 | nn.add(new ConvLayer(5, 64, PaddingType.SAME)); 140 | nn.add(new ActivationLayer(Activation.relu)); 141 | nn.add(new MaxPoolingLayer(2, 2)); 142 | 143 | nn.add(new FlattenLayer()); 144 | 145 | nn.add(new FCLayer(1024)); 146 | nn.add(new ActivationLayer(Activation.relu)); 147 | 148 | nn.add(new DropoutLayer(0.3)); 149 | 150 | nn.add(new FCLayer(10)); 151 | nn.add(new ActivationLayer(Activation.softmax)); 152 | ``` 153 | 이 망을 훈련하는 데는 20분 쯤 걸리고 입력 이미지 클래스를 완벽히 기억할 수 있습니다. 154 | 155 | ### 순환 신경망 156 | 157 | 순환 신경망을 만드는 것 또한 매우 간단합니다. 현재, GRU cells만 지원되고, 저는 일부 셰익스피어와 이상한 나라의 앨리스 글을 학습하고 생성하기 위해 그 것을 사용했습니다.. 158 | 159 | 여기엔 하이퍼 파라미터가 사용되었습니다: 160 | ```java 161 | int epochs = 500; 162 | int batchSize = 10; 163 | int winSize = 20; 164 | int winStep = 20; // winSize = winStep so substrings are not repeated 165 | int genIter = 5000; // how many characters to generate 166 | double temperature = 0.1; // lower = less randomness 167 | ``` 168 | 그리고 여기 2중 레이어 순환 신경망 모델을 빌드하는 코드가 있습니다. 169 | ```java 170 | // for each time step, the input is a one hot vector describing the current character 171 | // for each time step, the output is a one hot vector describing the next character 172 | // the recurrent layers are stateful, which means that the next state relies on the previous states 173 | SequentialNN nn = new SequentialNN(winSize, alphabet.length()); 174 | nn.add(new RecurrentLayer(winSize, new GRUCell(), true)); 175 | nn.add(new DropoutLayer(0.3)); 176 | nn.add(new RecurrentLayer(winSize, new GRUCell(), true)); 177 | // the same fully connected layer is applied for every single time step 178 | nn.add(new FCLayer(alphabet.length())); 179 | // scales the values by the temperature before softmax 180 | nn.add(new ScalingLayer(1 / temperature, false)); 181 | nn.add(new ActivationLayer(Activation.softmax)); 182 | ``` 183 | 모델을 훈련하고 본문을 생성하는 전체 코드를 원하면 [여기](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/blob/master/src/tests/GRUTest.java)로 오세요. 184 | 185 | 셰익스피어의 소네트 130번을 수행한 출력이 있습니다: 186 | ``` 187 | [cxx]x 188 | 189 | my mistress' eyes are nothing like the sun 190 | coral is far more red, than her lips red 191 | if snow be white, why then her breasts are dun 192 | if hairs be wires, black wires grow on her head. 193 | i have seen roses damask'd, red and white, 194 | but no such roses see i in her cheeks 195 | and in some perfumes is there more delight 196 | than in the breath that from my mistress reeks. 197 | i love to hear her speak, yet well i know 198 | that music hath a far more pleasing sound 199 | i grant i never saw a goddess go,-- 200 | my mistress, when she walks, treads on the ground 201 | and yet by heaven, i think my love as rare, 202 | as any she belied with false compare 203 | ``` 204 | 맨 처음 괄호 안에 있는 글은 제가 입력한 시드 텍스트입니다. 네트워크는 그 것을 받고 제가 지운 끝 부분에 빈 공간을 조금 더해서 소네트의 나머지 부분을 생성합니다. 205 | 206 | 여기 이상한 나라의 앨리스에서 발췌한 문단을 수행한 출력이 있습니다: 207 | ``` 208 | [chapter] i. down the rabbit-hole 209 | 210 | alice was beginning then she 211 | ray so menty see. 
212 | 213 | af the 214 | hing howver be world she was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it pocts tow th the tried to have wondered at the sides it pocts top the rabbit with pink eyes time as she fell very slowly, for she had 215 | plenty of time as she went lothe the down nothing to her owa get in to her that she was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it withing either a waistcoat-pocket, and to ple pfonsidering in her feet, for it flashed across her mind there she fell very slowly, for she had 216 | plenty so it with pink ey her feet, for it flashed across her mind that she ought to have wondered at the sides it pocts top the rabbit with pink eyes time as she fell very slowly, for she had 217 | plenty of time as she went tring to look down and make out what 218 | she was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it poct plap the had pe was coming to, but it was too dark to see anything then she 219 | looked at the sides all seemed quite was beginning then she 220 | ray co peer a watch 221 | to take out of it, and fortunately was just in time to see it pop down a large 222 | rabbit-hole under the well was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it pocts top the rabbit with pink eyes time as she fell very slowly, for she had 223 | plenty so it with pink ey, she she to out it was too dark to tires all, be late! (when she thought it over afterwatd, but it was too dark to te was beginning then somenasy 224 | seen a rabbit with either a waistcoat-pocket, and to ple ppend.n ahelves thok down a jar from one of the shelves had wondel very sleepy and stupid), whether the well was coused it was labelled orange maran 225 | aling she to her feet, for it flashed across her mind there she fell very slowly, for she had 226 | plenty of time as she went tr.e hed as 227 | she pagllidy, nothing then she 228 | ron to happen next. first, she tried to look down and make out what 229 | she was considering in her feet, for it flashed across her mind that she ought to have wondered at the sides it poct poud there she fell past it 230 | ``` 231 | 문단에는 철자가 틀린 단어가 매우 많지만, 꽤 멋져 보입니다. 232 | 233 | 마지막으로, 로미오와 줄리엣 1막 1장 전문으로 네트워크를 수행한 출력입니다: 234 | ``` 235 | [act i]t sarn'd to the will part thee. 236 | 237 | rom. i do beauty his sunpong sprice. 238 | 239 | rom. i do beauty his groanse the wall. 240 | 241 | samp. i do beaut's thess swords therefere in that is to streponse thee the wall. 242 | 243 | samp. i do beaut's thess swords therefere if thou doth the maids, or ment to the willat them, an thet thee weart of lovers' from the strunce of his will the live his will be comes to the but whett ther theis ments i will the hat with the fair, 244 | bees thee, when the wall the hat with the wall. 245 | 246 | samp. i do beaut's thess swords therefere in that len. 247 | 248 | ben. montague should be so fair mark. 249 | sh therefore i will they will stoul of the maids having the will part thee the hat let pee i pass me not here in sparkling hath the maids hour side i sad the wall. 250 | 251 | samp. i do beaut's thess swords therefere in that len. 252 | 253 | ben. montague should more or here with the maids hours shown so thee with me. 254 | 255 | samp. i do beauty his caming makes the he will. 256 | 257 | samp. no, should montagues! 
258 | wher thear hears'd and will they were head the willat to stand, and moved.
259 | 
260 | ben. in sand hear the with my a wist.
261 | 
262 | samp. i do beaut's thess sword morte the wall.
263 | 
264 | samp. i do beaut's this will stans.
265 | 
266 | greg. the heads of the beauty the hasterte me my fair madk to the was the weakes starn,
267 | where is to stor. what her comes is to store.
268 | 
269 | samp. i do beauty the wall the hat here in sparkling his hit tles, and montague and with me.
270 | 
271 | samp. i do betheas in paine.
272 | 
273 | samp. a dog of the fair markman.
274 | 
275 | samp. i do beaut's thess sword morte me they me what hes if othen. i will the wall.
276 | 
277 | samp. i do beaut's thess swords therefere in that me.
278 | 
279 | rom. i dis ment good the was she head the was what, and hent sword of the will part the will part the will.
280 | 
281 | samp. no, sir.
282 | 
283 | samp. no, as the weadt of the wall.
284 | 
285 | samp. i do beaut's thess sword.
286 | 
287 | rom. i dis ment good the wall.
288 | 
289 | samp. a dog of the hat heads her his comes and montague ind sen the wall.
290 | 
291 | samp. i do beaut's thess sword morte and made is that we dows thee i sang the hearty saive in strun.
292 | ```
293 | 당신도 보다시피 temperature가 증가(랜덤하게 증가)하면 아마도 사라질 반복이 있습니다. 원래, 저는 글이 무슨 막, 무슨 장인지 말해주는 첫 라인부터 예측을 시작하는 네트워크를 원했지만 네트워크는 다른 부분에서 실행됩니다.
294 | 
295 | 이 모든 예제들에서, 모델과 하이퍼파라미터는 같습니다. 멋진 점은 이 네트워크는 텍스트의 구조를 학습하고 적절하게 새로운 줄과 들여쓰기를 첫 번째와 세 번째 예제에서 추가했다는 점입니다. 또한, 저는 이 글들을 [Project Gutenberg](http://www.gutenberg.org/)에서 가져왔습니다.
296 | 
297 | 많은 다른 예제들은 [tests folder](https://github.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/tree/master/src/tests)에서 찾을 수 있습니다.
298 | 
299 | 저는 후전파(back propagation)와 경사 하강 방정식(gradient descent equations) 블로그 [게시글](https://c0deb0t.wordpress.com/2018/06/17/the-math-for-gradient-descent-and-backpropagation/)을 가지고 있습니다. 거기엔 흥미로운 수학 내용이 있습니다!
300 | 
301 | ### 이미지 로드
302 | 
303 | 텐서를 이용해 이미지를 로드하고 싶으면, 다음 코드를 따라하세요.
304 | 
305 | ```java
306 | ImageUtils imgUtils = new ImageUtils();
307 | Tensor imgTensor = imgUtils.readColorImageToTensor("path/to/image.bmp", false); // 예시 경로; false는 컬러 유지
308 | ```
309 | 
310 | 그리고 텐서 배열을 이용해 많은 이미지를 로드하고 싶으면, 다음 코드 또한 따라할 수 있습니다.
311 | 
312 | ```java
313 | ImageUtils imgUtils = new ImageUtils();
314 | Tensor[] imgTensorArray = imgUtils.readImages("path/to/folder", false); // 예시 경로
315 | ```
316 | 
--------------------------------------------------------------------------------
/classification_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/classification_example.png
--------------------------------------------------------------------------------
/classification_example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/classification_example2.png
--------------------------------------------------------------------------------
/classification_example3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/classification_example3.png
--------------------------------------------------------------------------------
/classification_example4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/classification_example4.png
--------------------------------------------------------------------------------
/error_graph_cross_entropy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/error_graph_cross_entropy.png
--------------------------------------------------------------------------------
/error_graph_squared.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/error_graph_squared.png
--------------------------------------------------------------------------------
/mnist_weights_fc.nn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/mnist_weights_fc.nn
--------------------------------------------------------------------------------
/nn_linear_regression.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Daniel-Liu-c0deb0t/Java-Machine-Learning/b4706d3788441b2ae8e8add2e80de972cadb5b8a/nn_linear_regression.png
--------------------------------------------------------------------------------
/src/javamachinelearning/drawables/MNISTDrawablePanel.java:
--------------------------------------------------------------------------------
1 | package javamachinelearning.drawables;
2 | 
3 | import java.awt.Color;
4 | import java.awt.Dimension;
5 | import java.awt.Graphics;
6 | import java.awt.Graphics2D;
7 | import java.awt.RenderingHints;
8 | import java.awt.event.MouseEvent;
9 | import java.awt.event.MouseMotionListener;
10 | import java.awt.image.BufferedImage;
11 | 
12 | import javax.swing.JPanel;
13 | 
14 | import javamachinelearning.utils.Tensor;
15 | import javamachinelearning.utils.Utils;
16 | 
17 | 
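// Panel that lets the user draw an MNIST-style digit with the mouse.
// getData() crops the drawing to its bounding box, rescales it, and centers it
// into a Tensor (via Utils.centerData) for feeding into a trained network.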
@SuppressWarnings("serial") 18 | public class MNISTDrawablePanel extends JPanel{ 19 | private int width; 20 | private int height; 21 | private int xSize; 22 | private int ySize; 23 | private BufferedImage image; 24 | private Graphics2D graphics; 25 | 26 | public MNISTDrawablePanel(int width, int height, int xSize, int ySize){ 27 | this.width = width; 28 | this.height = height; 29 | this.xSize = xSize; 30 | this.ySize = ySize; 31 | setPreferredSize(new Dimension(width, height)); 32 | this.image = new BufferedImage(xSize, ySize, BufferedImage.TYPE_INT_RGB); 33 | graphics = this.image.createGraphics(); 34 | graphics.setColor(Color.white); 35 | graphics.fillRect(0, 0, xSize, ySize); 36 | graphics.setColor(Color.black); 37 | addMouseMotionListener(new MouseMotionListener(){ 38 | @Override 39 | public void mouseDragged(MouseEvent e){ 40 | graphics.fillRect( 41 | (int)((double)e.getX() / ((double)width / (double)xSize)), 42 | (int)((double)e.getY() / ((double)height / (double)ySize)), 2, 2); 43 | repaint(); 44 | } 45 | 46 | @Override 47 | public void mouseMoved(MouseEvent e){ 48 | 49 | } 50 | }); 51 | } 52 | 53 | public Tensor getData(int outputX1, int outputY1, int outputX2, int outputY2){ 54 | int minX = Integer.MAX_VALUE; 55 | int maxX = 0; 56 | int minY = Integer.MAX_VALUE; 57 | int maxY = 0; 58 | for(int i = 0; i < xSize; i++){ 59 | for(int j = 0; j < ySize; j++){ 60 | if((image.getRGB(i, j) & 0xFF) < 255){ 61 | minX = Math.min(minX, i); 62 | maxX = Math.max(maxX, i); 63 | minY = Math.min(minY, j); 64 | maxY = Math.max(maxY, j); 65 | } 66 | } 67 | } 68 | if((maxX - minX) * 3 < maxY - minY){ 69 | minX -= (maxX - minX) * 1.5; 70 | maxX += (maxX - minX) * 1.5; 71 | } 72 | BufferedImage temp = new BufferedImage(maxX - minX, maxY - minY, image.getType()); 73 | Graphics2D g = temp.createGraphics(); 74 | g.setColor(Color.white); 75 | g.fillRect(0, 0, maxX - minX, maxY - minY); 76 | g.dispose(); 77 | for(int i = minX; i < maxX; i++){ 78 | for(int j = minY; j < maxY; j++){ 79 | if(i >= 0 && i < xSize && j >= 0 && j < ySize){ 80 | temp.setRGB(i - minX, j - minY, image.getRGB(i, j)); 81 | } 82 | } 83 | } 84 | 85 | BufferedImage result = new BufferedImage(outputX1, outputY1, image.getType()); 86 | g = result.createGraphics(); 87 | g.setColor(Color.white); 88 | g.fillRect(0, 0, outputX1, outputY1); 89 | g.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BICUBIC); 90 | g.drawImage(temp, 0, 0, outputX1, outputY1, 0, 0, temp.getWidth(), temp.getHeight(), null); 91 | g.dispose(); 92 | double[][] arr = new double[outputY1][outputX1]; 93 | for(int i = 0; i < outputY1; i++){ 94 | for(int j = 0; j < outputX1; j++){ 95 | arr[i][j] = 1.0 - (result.getRGB(j, i) & 0xFF) / 255.0; 96 | } 97 | } 98 | return Utils.centerData(arr, outputX2, outputY2); 99 | } 100 | 101 | public void clear(){ 102 | image = new BufferedImage(xSize, ySize, BufferedImage.TYPE_INT_RGB); 103 | graphics = image.createGraphics(); 104 | graphics.setColor(Color.white); 105 | graphics.fillRect(0, 0, xSize, ySize); 106 | graphics.setColor(Color.black); 107 | repaint(); 108 | } 109 | 110 | @Override 111 | public void paintComponent(Graphics g){ 112 | super.paintComponent(g); 113 | g.drawImage(image, 0, 0, width * width / xSize, height * height / ySize, 0, 0, width, height, null); 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/javamachinelearning/drawables/MNISTDrawablePanel2.java: 
--------------------------------------------------------------------------------
1 | package javamachinelearning.drawables;
2 | 
3 | import java.awt.Color;
4 | import java.awt.Dimension;
5 | import java.awt.Graphics;
6 | import java.awt.Graphics2D;
7 | import java.awt.event.MouseEvent;
8 | import java.awt.event.MouseMotionListener;
9 | import java.awt.image.BufferedImage;
10 | 
11 | import javax.swing.JPanel;
12 | 
13 | import javamachinelearning.utils.Tensor;
14 | import javamachinelearning.utils.Utils;
15 | 
16 | @SuppressWarnings("serial")
17 | public class MNISTDrawablePanel2 extends JPanel{
18 | 	private int width;
19 | 	private int height;
20 | 	private int xSize;
21 | 	private int ySize;
22 | 	private BufferedImage image;
23 | 	private Graphics2D graphics;
24 | 
25 | 	public MNISTDrawablePanel2(int width, int height, int xSize, int ySize){
26 | 		this.width = width;
27 | 		this.height = height;
28 | 		this.xSize = xSize;
29 | 		this.ySize = ySize;
30 | 		setPreferredSize(new Dimension(width, height));
31 | 		this.image = new BufferedImage(xSize, ySize, BufferedImage.TYPE_INT_RGB);
32 | 		graphics = this.image.createGraphics();
33 | 		graphics.setColor(Color.white);
34 | 		graphics.fillRect(0, 0, xSize, ySize);
35 | 		graphics.setColor(Color.black);
36 | 		addMouseMotionListener(new MouseMotionListener(){
37 | 			@Override
38 | 			public void mouseDragged(MouseEvent e){
39 | 				graphics.fillRect(
40 | 						(int)((double)e.getX() / ((double)width / (double)xSize)),
41 | 						(int)((double)e.getY() / ((double)height / (double)ySize)), 1, 1);
42 | 				repaint();
43 | 			}
44 | 
45 | 			@Override
46 | 			public void mouseMoved(MouseEvent e){
47 | 
48 | 			}
49 | 		});
50 | 	}
51 | 
52 | 	public Tensor getData(int outputX2, int outputY2){
53 | 		double[][] arr = new double[ySize][xSize];
54 | 		for(int i = 0; i < ySize; i++){
55 | 			for(int j = 0; j < xSize; j++){
56 | 				arr[i][j] = 1.0 - (image.getRGB(j, i) & 0xFF) / 255.0;
57 | 			}
58 | 		}
59 | 		return Utils.centerData(arr, outputX2, outputY2);
60 | 	}
61 | 
62 | 	public void clear(){
63 | 		image = new BufferedImage(xSize, ySize, BufferedImage.TYPE_INT_RGB);
64 | 		graphics = image.createGraphics();
65 | 		graphics.setColor(Color.white);
66 | 		graphics.fillRect(0, 0, xSize, ySize);
67 | 		graphics.setColor(Color.black);
68 | 		repaint();
69 | 	}
70 | 
71 | 	@Override
72 | 	public void paintComponent(Graphics g){
73 | 		super.paintComponent(g);
74 | 		g.drawImage(image, 0, 0, width * width / xSize, height * height / ySize, 0, 0, width, height, null);
75 | 	}
76 | }
77 | 
--------------------------------------------------------------------------------
/src/javamachinelearning/graphs/Graph.java:
--------------------------------------------------------------------------------
1 | package javamachinelearning.graphs;
2 | 
3 | import java.awt.BasicStroke;
4 | import java.awt.Color;
5 | import java.awt.Graphics2D;
6 | import java.awt.geom.AffineTransform;
7 | import java.awt.image.BufferedImage;
8 | import java.io.File;
9 | import java.util.ArrayList;
10 | 
11 | import javax.imageio.ImageIO;
12 | 
13 | import javamachinelearning.utils.Utils;
14 | 
15 | public class Graph{
16 | 	private BufferedImage graph;
17 | 	private Graphics2D graphics;
18 | 	private ArrayList<Point> points = new ArrayList<>();
19 | 	private ArrayList<Line> lines = new ArrayList<>();
20 | 	private ArrayList<LineGraph> lineGraphs = new ArrayList<>();
21 | 	private int width;
22 | 	private int height;
23 | 	private int xTicks;
24 | 	private int yTicks;
25 | 	private int padding;
26 | 	private String xLabel;
27 | 	private String yLabel;
28 | 	private ColorFunction colorFunction;
29 | 	private boolean customScale = false;
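	// custom axis bounds; only used when customScale is true (set by useCustomScale)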
30 | 	private double minX;
31 | 	private double maxX;
32 | 	private double minY;
33 | 	private double maxY;
34 | 
35 | 	public Graph(){
36 | 		this(500, 500);
37 | 	}
38 | 
39 | 	public Graph(int width, int height){
40 | 		this(width, height, null, null, null, null);
41 | 	}
42 | 
43 | 	public Graph(int width, int height, String xLabel, String yLabel){
44 | 		this(width, height, 10, 10, 100, xLabel, yLabel, null, null, null, null);
45 | 	}
46 | 
47 | 	public Graph(ColorFunction colorFunction){
48 | 		this(500, 500, colorFunction);
49 | 	}
50 | 
51 | 	public Graph(int width, int height, ColorFunction colorFunction){
52 | 		this(width, height, null, null, null, colorFunction);
53 | 	}
54 | 
55 | 	public Graph(int width, int height, double[] xData, double[] yData, Color[] cData, ColorFunction colorFunction){
56 | 		this(width, height, "x-axis", "y-axis", xData, yData, cData, colorFunction);
57 | 	}
58 | 
59 | 	public Graph(int width, int height, String xLabel, String yLabel, double[] xData, double[] yData, Color[] cData, ColorFunction colorFunction){
60 | 		this(width, height, 10, 10, 100, xLabel, yLabel, xData, yData, cData, colorFunction);
61 | 	}
62 | 
63 | 	public Graph(int width, int height, int xTicks, int yTicks, int padding, String xLabel, String yLabel, double[] xData, double[] yData, Color[] cData, ColorFunction colorFunction){
64 | 		this.graph = new BufferedImage(width, height, BufferedImage.TYPE_INT_ARGB);
65 | 		this.graphics = this.graph.createGraphics();
66 | 		this.width = width;
67 | 		this.height = height;
68 | 		this.xTicks = xTicks + 1;
69 | 		this.yTicks = yTicks + 1;
70 | 		this.padding = padding;
71 | 		this.xLabel = xLabel;
72 | 		this.yLabel = yLabel;
73 | 		this.colorFunction = colorFunction;
74 | 
75 | 		if(xData != null && yData != null){
76 | 			for(int i = 0; i < xData.length; i++){
77 | 				if(cData == null || cData.length < xData.length)
78 | 					this.points.add(new Point(xData[i], yData[i]));
79 | 				else
80 | 					this.points.add(new Point(xData[i], yData[i], cData[i]));
81 | 			}
82 | 		}
83 | 	}
84 | 
85 | 	public void useCustomScale(double minX, double maxX, double minY, double maxY){
86 | 		this.minX = minX;
87 | 		this.maxX = maxX;
88 | 		this.minY = minY;
89 | 		this.maxY = maxY;
90 | 		this.customScale = true;
91 | 	}
92 | 
93 | 	public void usePointScale(){
94 | 		this.customScale = false;
95 | 	}
96 | 
97 | 	public void draw(){
98 | 		// find graph range
99 | 		double xMax = Double.MIN_VALUE;
100 | 		double xMin = Double.MAX_VALUE;
101 | 		double yMax = Double.MIN_VALUE;
102 | 		double yMin = Double.MAX_VALUE;
103 | 		if(customScale){
104 | 			xMax = maxX;
105 | 			xMin = minX;
106 | 			yMax = maxY;
107 | 			yMin = minY;
108 | 		}else{
109 | 			for(int i = 0; i < points.size(); i++){
110 | 				xMax = Math.max(xMax, points.get(i).getX());
111 | 				xMin = Math.min(xMin, points.get(i).getX());
112 | 				yMax = Math.max(yMax, points.get(i).getY());
113 | 				yMin = Math.min(yMin, points.get(i).getY());
114 | 			}
115 | 			for(int i = 0; i < lineGraphs.size(); i++){
116 | 				ArrayList<Point> arr = lineGraphs.get(i).getPoints();
117 | 				for(int j = 0; j < arr.size(); j++){
118 | 					xMax = Math.max(xMax, arr.get(j).getX());
119 | 					xMin = Math.min(xMin, arr.get(j).getX());
120 | 					yMax = Math.max(yMax, arr.get(j).getY());
121 | 					yMin = Math.min(yMin, arr.get(j).getY());
122 | 				}
123 | 			}
124 | 			if(xMax == Double.MIN_VALUE)
125 | 				xMax = 10;
126 | 			if(xMin == Double.MAX_VALUE)
127 | 				xMin = 0;
128 | 			if(yMax == Double.MIN_VALUE)
129 | 				yMax = 10;
130 | 			if(yMin == Double.MAX_VALUE)
131 | 				yMin = 0;
132 | 			if(xMax - xMin > yMax - yMin){
133 | 				double diff = xMax - xMin - yMax + yMin;
134 | 				yMin -= diff / 2;
135 | 				yMax += diff / 2;
136 | 			}else{
137 |
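				// the y range is larger here, so widen the x range until both axes span equal lengths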
double diff = yMax - yMin - xMax + xMin; 138 | xMin -= diff / 2; 139 | xMax += diff / 2; 140 | } 141 | } 142 | 143 | if(colorFunction != null){ 144 | int xSize = 500; 145 | int ySize = 500; 146 | for(int i = 0; i < xSize; i++){ 147 | for(int j = 0; j < ySize; j++){ 148 | graphics.setColor(colorFunction.getColor( 149 | xMin + i / (double)xSize * (xMax - xMin), 150 | yMin + j / (double)ySize * (yMax - yMin))); 151 | graphics.fillRect( 152 | (int)(padding * 2 + i * (width - padding * 3) / (double)xSize), 153 | (int)(height - padding * 2 - (j + 1) * ((height - padding * 3) / (double)ySize)), 154 | (int)((width - padding * 3) / (double)xSize), 155 | (int)((height - padding * 3) / (double)ySize)); 156 | } 157 | } 158 | } 159 | 160 | // x and y axis 161 | graphics.setColor(Color.black); 162 | graphics.setStroke(new BasicStroke(3)); 163 | graphics.drawLine(padding * 2, height - padding * 2, width - padding, height - padding * 2); 164 | graphics.drawLine(padding * 2, height - padding * 2, padding * 2, padding); 165 | 166 | // x and y axis labels 167 | graphics.setFont(graphics.getFont().deriveFont(50.0f)); 168 | graphics.drawString(xLabel, width / 2 - 5 * xLabel.length(), height - padding); 169 | AffineTransform oldTransform = graphics.getTransform(); 170 | graphics.translate(padding, height / 2 + 5 * yLabel.length()); 171 | graphics.rotate(Math.toRadians(-90.0)); 172 | graphics.drawString(yLabel, 0, 0); 173 | graphics.setTransform(oldTransform); 174 | 175 | // draw tick marks 176 | graphics.setFont(graphics.getFont().deriveFont(25.0f)); 177 | int xTickSpacing = (width - padding * 3) / (xTicks - 1); 178 | int yTickSpacing = (height - padding * 3) / (yTicks - 1); 179 | for(int i = 0; i < xTicks; i++){ 180 | graphics.drawLine( 181 | padding * 2 + xTickSpacing * i, 182 | height - padding * 2, 183 | padding * 2 + xTickSpacing * i, 184 | height - padding * 2 + 10); 185 | graphics.drawString( 186 | Utils.shorterFormat(xMin + (xMax - xMin) / (xTicks - 1) * i), 187 | padding * 2 + xTickSpacing * i - 7, 188 | height - padding * 2 + 40); 189 | } 190 | for(int i = 0; i < yTicks; i++){ 191 | graphics.drawLine( 192 | padding * 2, 193 | height - padding * 2 - yTickSpacing * i, 194 | padding * 2 - 10, 195 | height - padding * 2 - yTickSpacing * i); 196 | graphics.drawString( 197 | Utils.shorterFormat(yMin + (yMax - yMin) / (yTicks - 1) * i), 198 | padding * 2 - 70, 199 | height - padding * 2 - yTickSpacing * i + 10); 200 | } 201 | 202 | // draw points 203 | for(int i = 0; i < points.size(); i++){ 204 | graphics.setColor(Color.black); 205 | graphics.fillOval( 206 | padding * 2 + (int)((points.get(i).getX() - xMin) / (xMax - xMin) * (width - padding * 3)) - 10, 207 | (height - padding * 2) - (int)((points.get(i).getY() - yMin) / (yMax - yMin) * (height - padding * 3)) - 10, 20, 20); 208 | graphics.setColor(points.get(i).getColor()); 209 | graphics.fillOval( 210 | padding * 2 + (int)((points.get(i).getX() - xMin) / (xMax - xMin) * (width - padding * 3)) - 8, 211 | (height - padding * 2) - (int)((points.get(i).getY() - yMin) / (yMax - yMin) * (height - padding * 3)) - 8, 16, 16); 212 | } 213 | 214 | // draw line 215 | for(int i = 0; i < lines.size(); i++){ 216 | graphics.setColor(lines.get(i).getColor()); 217 | double y1 = lines.get(i).getM() * xMin + lines.get(i).getB(); 218 | double y2 = lines.get(i).getM() * xMax + lines.get(i).getB(); 219 | graphics.drawLine( 220 | padding * 2, 221 | (height - padding * 2) - (int)((y1 - yMin) / (yMax - yMin) * (height - padding * 3)), 222 | padding * 2 + width - padding * 3, 223 
| (height - padding * 2) - (int)((y2 - yMin) / (yMax - yMin) * (height - padding * 3))); 224 | } 225 | 226 | // draw line graphs 227 | for(int i = 0; i < lineGraphs.size(); i++){ 228 | Point prev = null; 229 | LineGraph g = lineGraphs.get(i); 230 | for(int j = 0; j < g.getPoints().size(); j++){ 231 | Point p = g.getPoints().get(j); 232 | graphics.setColor(Color.black); 233 | graphics.fillOval( 234 | padding * 2 + (int)((p.getX() - xMin) / (xMax - xMin) * (width - padding * 3)) - 10, 235 | (height - padding * 2) - (int)((p.getY() - yMin) / (yMax - yMin) * (height - padding * 3)) - 10, 20, 20); 236 | graphics.setColor(p.getColor()); 237 | graphics.fillOval( 238 | padding * 2 + (int)((p.getX() - xMin) / (xMax - xMin) * (width - padding * 3)) - 8, 239 | (height - padding * 2) - (int)((p.getY() - yMin) / (yMax - yMin) * (height - padding * 3)) - 8, 16, 16); 240 | 241 | if(j != 0){ 242 | graphics.setColor(g.getColor()); 243 | graphics.drawLine( 244 | padding * 2 + (int)((prev.getX() - xMin) / (xMax - xMin) * (width - padding * 3)), 245 | (height - padding * 2) - (int)((prev.getY() - yMin) / (yMax - yMin) * (height - padding * 3)), 246 | padding * 2 + (int)((p.getX() - xMin) / (xMax - xMin) * (width - padding * 3)), 247 | (height - padding * 2) - (int)((p.getY() - yMin) / (yMax - yMin) * (height - padding * 3))); 248 | } 249 | prev = p; 250 | } 251 | } 252 | } 253 | 254 | public void saveToFile(String path, String type){ 255 | try{ 256 | ImageIO.write(graph, type, new File(path)); 257 | }catch(Exception e){ 258 | e.printStackTrace(); 259 | } 260 | } 261 | 262 | public void addPoint(double x, double y, Color c){ 263 | points.add(new Point(x, y, c)); 264 | } 265 | 266 | public void addPoint(double x, double y){ 267 | points.add(new Point(x, y)); 268 | } 269 | 270 | public void addLine(double m, double b, Color c){ 271 | lines.add(new Line(m, b, c)); 272 | } 273 | 274 | public void addLine(double m, double b){ 275 | lines.add(new Line(m, b)); 276 | } 277 | 278 | public void addLineGraph(double[] xs, double[] ys){ 279 | ArrayList<Point> arr = new ArrayList<>(); 280 | for(int i = 0; i < xs.length; i++){ 281 | arr.add(new Point(xs[i], ys[i])); 282 | } 283 | lineGraphs.add(new LineGraph(arr)); 284 | } 285 | 286 | public void addLineGraph(double[] xs, double[] ys, Color c){ 287 | ArrayList<Point> arr = new ArrayList<>(); 288 | for(int i = 0; i < xs.length; i++){ 289 | arr.add(new Point(xs[i], ys[i])); 290 | } 291 | lineGraphs.add(new LineGraph(arr, c)); 292 | } 293 | 294 | public void addLineGraph(double[] xs, double[] ys, Color[] cs, Color c){ 295 | ArrayList<Point> arr = new ArrayList<>(); 296 | for(int i = 0; i < xs.length; i++){ 297 | arr.add(new Point(xs[i], ys[i], cs[i])); 298 | } 299 | lineGraphs.add(new LineGraph(arr, c)); 300 | } 301 | 302 | public BufferedImage getGraph(){ 303 | return graph; 304 | } 305 | 306 | public int getWidth(){ 307 | return width; 308 | } 309 | 310 | public int getHeight(){ 311 | return height; 312 | } 313 | 314 | public void dispose(){ 315 | graphics.dispose(); 316 | } 317 | 318 | public interface ColorFunction{ 319 | public Color getColor(double x, double y); 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /src/javamachinelearning/graphs/GraphPanel.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.graphs; 2 | 3 | import java.awt.Dimension; 4 | import java.awt.Graphics; 5 | 6 | import javax.swing.JPanel; 7 | 8 | @SuppressWarnings("serial") 9 | public class GraphPanel
extends JPanel{ 10 | private Graph graph; 11 | 12 | public GraphPanel(Graph graph){ 13 | this.graph = graph; 14 | setPreferredSize(new Dimension(graph.getWidth(), graph.getHeight())); 15 | } 16 | 17 | @Override 18 | public void paintComponent(Graphics graphics){ 19 | super.paintComponent(graphics); 20 | graphics.drawImage(graph.getGraph(), 0, 0, graph.getWidth(), graph.getHeight(), 0, 0, graph.getWidth(), graph.getHeight(), null); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/javamachinelearning/graphs/Line.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.graphs; 2 | 3 | import java.awt.Color; 4 | 5 | public class Line{ 6 | private double m, b; 7 | private Color c; 8 | 9 | public Line(double m, double b){ 10 | this.m = m; 11 | this.b = b; 12 | this.c = Color.black; 13 | } 14 | 15 | public Line(double m, double b, Color c){ 16 | this.m = m; 17 | this.b = b; 18 | this.c = c; 19 | } 20 | 21 | public double getM(){ 22 | return m; 23 | } 24 | 25 | public double getB(){ 26 | return b; 27 | } 28 | 29 | public Color getColor(){ 30 | return c; 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/javamachinelearning/graphs/LineGraph.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.graphs; 2 | 3 | import java.awt.Color; 4 | import java.util.ArrayList; 5 | 6 | public class LineGraph{ 7 | private ArrayList<Point> arr; 8 | private Color c; 9 | 10 | public LineGraph(ArrayList<Point> arr){ 11 | this.arr = arr; 12 | this.c = Color.black; 13 | } 14 | 15 | public LineGraph(ArrayList<Point> arr, Color c){ 16 | this.arr = arr; 17 | this.c = c; 18 | } 19 | 20 | public ArrayList<Point> getPoints(){ 21 | return arr; 22 | } 23 | 24 | public Color getColor(){ 25 | return c; 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/javamachinelearning/graphs/Point.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.graphs; 2 | 3 | import java.awt.Color; 4 | 5 | public class Point{ 6 | private double x, y; 7 | private Color c; 8 | 9 | public Point(double x, double y){ 10 | this.x = x; 11 | this.y = y; 12 | this.c = Color.black; 13 | } 14 | 15 | public Point(double x, double y, Color c){ 16 | this.x = x; 17 | this.y = y; 18 | this.c = c; 19 | } 20 | 21 | public double getX(){ 22 | return x; 23 | } 24 | 25 | public double getY(){ 26 | return y; 27 | } 28 | 29 | public Color getColor(){ 30 | return c; 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/Layer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public interface Layer{ 6 | public int[] inputShape(); 7 | public int[] outputShape(); 8 | public void init(int[] inputShape); 9 | public Tensor forwardPropagate(Tensor input, boolean training); 10 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error); 11 | } 12 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/ParamsLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers; 2 | 3 | import java.nio.ByteBuffer; 4 | 5 |
import javamachinelearning.optimizers.Optimizer; 6 | import javamachinelearning.regularizers.Regularizer; 7 | 8 | public interface ParamsLayer extends Layer{ 9 | // if biases shouldn't be used 10 | public ParamsLayer noBias(); 11 | 12 | public void update(Optimizer optimizer, Regularizer regularizer); 13 | public int byteSize(); 14 | public ByteBuffer bytes(); 15 | public void readBytes(ByteBuffer bb); 16 | } 17 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/ActivationLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import javamachinelearning.utils.Activation; 4 | import javamachinelearning.utils.Tensor; 5 | 6 | public class ActivationLayer implements FeedForwardLayer{ 7 | private int[] shape; 8 | private Activation activation; 9 | 10 | public ActivationLayer(Activation activation){ 11 | this.activation = activation; 12 | } 13 | 14 | @Override 15 | public int[] outputShape(){ 16 | return shape; 17 | } 18 | 19 | @Override 20 | public int[] inputShape(){ 21 | return shape; 22 | } 23 | 24 | @Override 25 | public void init(int[] inputShape){ 26 | shape = inputShape; 27 | } 28 | 29 | @Override 30 | public Tensor forwardPropagate(Tensor input, boolean training){ 31 | return activation.activate(input); 32 | } 33 | 34 | @Override 35 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error){ 36 | return error.mul(activation.derivative(output)); 37 | } 38 | 39 | @Override 40 | public String toString(){ 41 | return "Activation: " + activation.toString(); 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/AvgPoolingLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import java.util.Arrays; 4 | 5 | import javamachinelearning.utils.Tensor; 6 | 7 | public class AvgPoolingLayer implements FeedForwardLayer{ 8 | private int[] inputShape; 9 | private int[] outputShape; 10 | private int winWidth, winHeight; 11 | private int strideX, strideY; 12 | 13 | public AvgPoolingLayer(int winWidth, int winHeight, int strideX, int strideY){ 14 | this.winWidth = winWidth; 15 | this.winHeight = winHeight; 16 | this.strideX = strideX; 17 | this.strideY = strideY; 18 | } 19 | 20 | public AvgPoolingLayer(int winSize, int stride){ 21 | this.winWidth = winSize; 22 | this.winHeight = winSize; 23 | this.strideX = stride; 24 | this.strideY = stride; 25 | } 26 | 27 | public AvgPoolingLayer(int winSize){ 28 | this.winWidth = winSize; 29 | this.winHeight = winSize; 30 | this.strideX = 1; 31 | this.strideY = 1; 32 | } 33 | 34 | @Override 35 | public int[] outputShape(){ 36 | return outputShape; 37 | } 38 | 39 | @Override 40 | public int[] inputShape(){ 41 | return inputShape; 42 | } 43 | 44 | @Override 45 | public void init(int[] inputShape){ 46 | this.inputShape = inputShape; 47 | 48 | int temp = inputShape[0] - winWidth; 49 | if(temp % strideX != 0) 50 | throw new IllegalArgumentException("Bad sizes for average pooling!"); 51 | int w = temp / strideX + 1; 52 | 53 | temp = inputShape[1] - winHeight; 54 | if(temp % strideY != 0) 55 | throw new IllegalArgumentException("Bad sizes for average pooling!"); 56 | int h = temp / strideY + 1; 57 | 58 | outputShape = new int[]{w, h, inputShape[2]}; 59 | } 60 | 61 | @Override 62 | public Tensor 
forwardPropagate(Tensor input, boolean training){ 63 | double[] res = new double[outputShape[0] * outputShape[1] * outputShape[2]]; 64 | int[] shape = input.shape(); 65 | int idx = 0; 66 | // slide through and computes the average for each location 67 | // the output should have the same depth as the input 68 | for(int i = 0; i < outputShape[0] * strideX; i += strideX){ 69 | for(int j = 0; j < outputShape[1] * strideY; j += strideY){ 70 | for(int k = 0; k < shape[2]; k++){ // for each depth slice 71 | double sum = 0; 72 | 73 | for(int rx = 0; rx < winWidth; rx++){ // relative x position 74 | for(int ry = 0; ry < winHeight; ry++){ // relative y position 75 | // absolute positions 76 | int x = i + rx; 77 | int y = j + ry; 78 | 79 | sum += input.flatGet(x * shape[1] * shape[2] + y * shape[2] + k); 80 | } 81 | } 82 | 83 | // average of all values 84 | res[idx] = sum / (winWidth * winHeight); 85 | idx++; 86 | } 87 | } 88 | } 89 | 90 | return new Tensor(outputShape, res); 91 | } 92 | 93 | @Override 94 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error){ 95 | double[] res = new double[inputShape[0] * inputShape[1] * inputShape[2]]; 96 | int outIdx = 0; 97 | 98 | for(int i = 0; i < outputShape[0] * strideX; i += strideX){ 99 | for(int j = 0; j < outputShape[1] * strideY; j += strideY){ 100 | for(int k = 0; k < inputShape[2]; k++){ // for each depth slice 101 | for(int rx = 0; rx < winWidth; rx++){ // relative x position 102 | for(int ry = 0; ry < winHeight; ry++){ // relative y position 103 | // absolute positions 104 | int x = i + rx; 105 | int y = j + ry; 106 | int inIdx = x * inputShape[1] * inputShape[2] + y * inputShape[2] + k; 107 | 108 | res[inIdx] += error.flatGet(outIdx) / (winWidth * winHeight); 109 | } 110 | } 111 | 112 | outIdx++; 113 | } 114 | } 115 | } 116 | 117 | return new Tensor(inputShape, res); 118 | } 119 | 120 | @Override 121 | public String toString(){ 122 | return "Average Pooling\tInput Shape: " + Arrays.toString(inputShape()) + "\tOutput Shape: " + Arrays.toString(outputShape()); 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/ConvLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import java.nio.ByteBuffer; 4 | import java.util.Arrays; 5 | 6 | import javamachinelearning.optimizers.Optimizer; 7 | import javamachinelearning.regularizers.Regularizer; 8 | import javamachinelearning.utils.Tensor; 9 | 10 | public class ConvLayer implements FeedForwardParamsLayer{ 11 | private Tensor weights; 12 | private Tensor gradWeights; 13 | private Tensor[] weightExtraParams; 14 | 15 | private Tensor bias; 16 | private Tensor gradBias; 17 | private Tensor[] biasExtraParams; 18 | 19 | private int[] inputShape; 20 | private int[] outputShape; 21 | private int winWidth, winHeight; 22 | private int strideX, strideY; 23 | private int paddingX, paddingY; 24 | private int filterCount; 25 | private int changeCount; 26 | private boolean alreadyInit = false; 27 | private boolean useBias = true; 28 | 29 | public ConvLayer(int winWidth, int winHeight, int strideX, int strideY, int filterCount, int paddingX, int paddingY){ 30 | this.winWidth = winWidth; 31 | this.winHeight = winHeight; 32 | this.strideX = strideX; 33 | this.strideY = strideY; 34 | this.filterCount = filterCount; 35 | this.paddingX = paddingX; 36 | this.paddingY = paddingY; 37 | } 38 | 39 | public ConvLayer(int 
winSize, int stride, int filterCount, int padding){ 40 | this(winSize, winSize, stride, stride, filterCount, padding, padding); 41 | } 42 | 43 | public ConvLayer(int winSize, int filterCount, int padding){ 44 | this(winSize, 1, filterCount, padding); 45 | } 46 | 47 | public ConvLayer(int winWidth, int winHeight, int strideX, int strideY, int filterCount, PaddingType type){ 48 | if(type == PaddingType.VALID){ 49 | this.winWidth = winWidth; 50 | this.winHeight = winHeight; 51 | this.strideX = strideX; 52 | this.strideY = strideY; 53 | this.filterCount = filterCount; 54 | this.paddingX = 0; 55 | this.paddingY = 0; 56 | }else{ 57 | this.winWidth = winWidth; 58 | this.winHeight = winHeight; 59 | this.strideX = strideX; 60 | this.strideY = strideY; 61 | this.filterCount = filterCount; 62 | if((winWidth - 1) % 2 != 0) 63 | throw new IllegalArgumentException("Bad sizes for convolution!"); 64 | this.paddingX = (winWidth - 1) / 2; 65 | if((winHeight - 1) % 2 != 0) 66 | throw new IllegalArgumentException("Bad sizes for convolution!"); 67 | this.paddingY = (winHeight - 1) / 2; 68 | } 69 | } 70 | 71 | public ConvLayer(int winSize, int stride, int filterCount, PaddingType type){ 72 | this(winSize, winSize, stride, stride, filterCount, type); 73 | } 74 | 75 | public ConvLayer(int winSize, int filterCount, PaddingType type){ 76 | this(winSize, 1, filterCount, type); 77 | } 78 | 79 | public ConvLayer(int winSize, int filterCount){ 80 | this(winSize, filterCount, PaddingType.VALID); 81 | } 82 | 83 | @Override 84 | public int[] outputShape(){ 85 | return outputShape; 86 | } 87 | 88 | @Override 89 | public int[] inputShape(){ 90 | return inputShape; 91 | } 92 | 93 | @Override 94 | public void init(int[] inputShape){ 95 | this.inputShape = inputShape; 96 | 97 | int temp = inputShape[0] - winWidth + paddingX * 2; 98 | if(temp % strideX != 0) 99 | throw new IllegalArgumentException("Bad sizes for convolution!"); 100 | int w = temp / strideX + 1; 101 | 102 | temp = inputShape[1] - winHeight + paddingY * 2; 103 | if(temp % strideY != 0) 104 | throw new IllegalArgumentException("Bad sizes for convolution!"); 105 | int h = temp / strideY + 1; 106 | 107 | outputShape = new int[]{w, h, filterCount}; 108 | 109 | if(!alreadyInit){ 110 | weights = new Tensor(new int[]{winWidth, winHeight, inputShape[2], filterCount}, true); 111 | if(useBias) 112 | bias = new Tensor(new int[]{1, 1, filterCount}, false); 113 | } 114 | gradWeights = new Tensor(new int[]{winWidth, winHeight, inputShape[2], filterCount}, false); 115 | if(useBias) 116 | gradBias = new Tensor(new int[]{1, 1, filterCount}, false); 117 | } 118 | 119 | @Override 120 | public FeedForwardParamsLayer withParams(Tensor w, Tensor b){ 121 | weights = w; 122 | if(useBias) 123 | bias = b; 124 | alreadyInit = true; 125 | return this; 126 | } 127 | 128 | @Override 129 | public FeedForwardParamsLayer noBias(){ 130 | useBias = false; 131 | return this; 132 | } 133 | 134 | @Override 135 | public Tensor bias(){ 136 | return bias; 137 | } 138 | 139 | @Override 140 | public Tensor weights(){ 141 | return weights; 142 | } 143 | 144 | @Override 145 | public void setBias(Tensor b){ 146 | if(useBias) 147 | bias = b; 148 | } 149 | 150 | @Override 151 | public void setWeights(Tensor w){ 152 | weights = w; 153 | } 154 | 155 | @Override 156 | public Tensor forwardPropagate(Tensor input, boolean training){ 157 | double[] res = new double[outputShape[0] * outputShape[1] * filterCount]; 158 | int[] inMult = input.mult(); // equals the mult for inputShape because input shape equals 
inputShape 159 | int[] wMult = weights.mult(); 160 | int idx = 0; 161 | 162 | for(int i = 0; i < outputShape[0] * strideX; i += strideX){ 163 | for(int j = 0; j < outputShape[1] * strideY; j += strideY){ 164 | for(int filter = 0; filter < filterCount; filter++){ 165 | // relative to each filter 166 | for(int rx = 0; rx < winWidth; rx++){ 167 | for(int ry = 0; ry < winHeight; ry++){ 168 | for(int depth = 0; depth < inputShape[2]; depth++){ 169 | // absolute positions 170 | int x = i - paddingX + rx; 171 | int y = j - paddingY + ry; 172 | 173 | // handle zero padding 174 | if(x < 0 || x >= inputShape[0] || y < 0 || y >= inputShape[1]) 175 | continue; 176 | 177 | // multiply by weight and accumulate by addition 178 | res[idx] += input.flatGet(x * inMult[0] + y * inMult[1] + depth) * 179 | weights.flatGet(rx * wMult[0] + ry * wMult[1] + depth * wMult[2] + filter); 180 | } 181 | } 182 | } 183 | 184 | // add bias 185 | if(useBias) 186 | res[idx] += bias.flatGet(filter); 187 | 188 | idx++; 189 | } 190 | } 191 | } 192 | 193 | return new Tensor(outputShape, res); 194 | } 195 | 196 | @Override 197 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error){ 198 | // calculate weight gradients and bias gradients 199 | double[] deltaW = new double[weights.size()]; 200 | double[] deltaB = useBias ? new double[bias.size()] : null; // bias is null when noBias() was used, so guard the allocation 201 | int[] inMult = input.mult(); 202 | int[] wMult = weights.mult(); 203 | int gradIdx = 0; 204 | 205 | for(int i = 0; i < outputShape[0] * strideX; i += strideX){ 206 | for(int j = 0; j < outputShape[1] * strideY; j += strideY){ 207 | for(int filter = 0; filter < filterCount; filter++){ 208 | // relative to each filter 209 | for(int rx = 0; rx < winWidth; rx++){ 210 | for(int ry = 0; ry < winHeight; ry++){ 211 | for(int depth = 0; depth < inputShape[2]; depth++){ 212 | // absolute positions 213 | int x = i - paddingX + rx; 214 | int y = j - paddingY + ry; 215 | 216 | // handle zero padding 217 | if(x < 0 || x >= inputShape[0] || y < 0 || y >= inputShape[1]) 218 | continue; 219 | 220 | int wIdx = rx * wMult[0] + ry * wMult[1] + depth * wMult[2] + filter; 221 | 222 | // multiply gradients by previous layer's output 223 | // accumulate gradients for each weight 224 | deltaW[wIdx] += error.flatGet(gradIdx) * 225 | input.flatGet(x * inMult[0] + y * inMult[1] + depth); 226 | } 227 | } 228 | } 229 | 230 | // accumulate gradients for the biases 231 | // one bias per filter!
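// e.g. with illustrative sizes: a 28x28x1 input, a 5x5 window, stride 1, no padding, and 8 filters give a 24x24x8 output, so deltaB[filter] accumulates 24 * 24 = 576 error terms (one per output position) for each of the 8 filters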
232 | if(useBias) 233 | deltaB[filter] += error.flatGet(gradIdx); 234 | 235 | gradIdx++; 236 | } 237 | } 238 | } 239 | 240 | gradWeights = gradWeights.add(new Tensor(weights.shape(), deltaW)); 241 | 242 | if(useBias) 243 | gradBias = gradBias.add(new Tensor(bias.shape(), deltaB)); 244 | 245 | // calculate the gradients wrt input 246 | double[] gradInputs = new double[input.size()]; 247 | gradIdx = 0; 248 | 249 | for(int i = 0; i < outputShape[0] * strideX; i += strideX){ 250 | for(int j = 0; j < outputShape[1] * strideY; j += strideY){ 251 | for(int filter = 0; filter < filterCount; filter++){ 252 | // relative to each filter 253 | for(int rx = 0; rx < winWidth; rx++){ 254 | for(int ry = 0; ry < winHeight; ry++){ 255 | for(int depth = 0; depth < inputShape[2]; depth++){ 256 | // absolute positions 257 | int x = i - paddingX + rx; 258 | int y = j - paddingY + ry; 259 | 260 | // handle zero padding 261 | if(x < 0 || x >= inputShape[0] || y < 0 || y >= inputShape[1]) 262 | continue; 263 | 264 | int inIdx = x * inMult[0] + y * inMult[1] + depth; 265 | 266 | // multiply gradients by each weight 267 | // accumulate gradients for each input 268 | gradInputs[inIdx] += error.flatGet(gradIdx) * 269 | weights.flatGet(rx * wMult[0] + ry * wMult[1] + depth * wMult[2] + filter); 270 | } 271 | } 272 | } 273 | 274 | gradIdx++; 275 | } 276 | } 277 | } 278 | 279 | changeCount++; 280 | 281 | return new Tensor(inputShape, gradInputs); 282 | } 283 | 284 | @Override 285 | public void update(Optimizer optimizer, Regularizer regularizer){ 286 | if(weightExtraParams == null){ 287 | weightExtraParams = new Tensor[optimizer.extraParams()]; 288 | for(int i = 0; i < weightExtraParams.length; i++){ 289 | weightExtraParams[i] = new Tensor(weights.shape(), false); 290 | } 291 | 292 | if(useBias){ 293 | biasExtraParams = new Tensor[optimizer.extraParams()]; 294 | for(int i = 0; i < biasExtraParams.length; i++){ 295 | biasExtraParams[i] = new Tensor(bias.shape(), false); 296 | } 297 | } 298 | } 299 | 300 | if(regularizer == null){ 301 | weights = weights.sub( 302 | optimizer.optimize( 303 | gradWeights.div(Math.max(changeCount, 1)), weightExtraParams)); 304 | }else{ 305 | weights = weights.sub( 306 | optimizer.optimize( 307 | gradWeights.div(Math.max(changeCount, 1)).add( 308 | regularizer.derivative(weights)), weightExtraParams)); 309 | } 310 | gradWeights = new Tensor(gradWeights.shape(), false); 311 | 312 | if(useBias){ 313 | bias = bias.sub( 314 | optimizer.optimize( 315 | gradBias.div(Math.max(changeCount, 1)), biasExtraParams)); 316 | gradBias = new Tensor(gradBias.shape(), false); 317 | } 318 | changeCount = 0; 319 | } 320 | 321 | @Override 322 | public int byteSize(){ 323 | return Double.BYTES * weights.size() + (useBias ? 
Double.BYTES * bias.size() : 0); 324 | } 325 | 326 | @Override 327 | public ByteBuffer bytes(){ 328 | ByteBuffer bb = ByteBuffer.allocate(byteSize()); 329 | for(int i = 0; i < weights.size(); i++){ 330 | bb.putDouble(weights.flatGet(i)); 331 | } 332 | if(useBias){ 333 | for(int i = 0; i < bias.size(); i++){ 334 | bb.putDouble(bias.flatGet(i)); 335 | } 336 | } 337 | bb.flip(); 338 | return bb; 339 | } 340 | 341 | @Override 342 | public void readBytes(ByteBuffer bb){ 343 | double[] w = new double[weights.size()]; 344 | for(int i = 0; i < w.length; i++){ 345 | w[i] = bb.getDouble(); 346 | } 347 | weights = new Tensor(weights.shape(), w); 348 | 349 | if(useBias){ 350 | double[] b = new double[bias.size()]; 351 | for(int i = 0; i < b.length; i++){ 352 | b[i] = bb.getDouble(); 353 | } 354 | bias = new Tensor(bias.shape(), b); 355 | } 356 | } 357 | 358 | @Override 359 | public String toString(){ 360 | return "Convolutional\tInput Shape: " + Arrays.toString(inputShape()) + "\tOutput Shape: " + Arrays.toString(outputShape()); 361 | } 362 | 363 | public enum PaddingType{ 364 | VALID, SAME; 365 | } 366 | } 367 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/DropoutLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import java.util.Random; 4 | 5 | import javamachinelearning.utils.Tensor; 6 | 7 | public class DropoutLayer implements FeedForwardLayer{ 8 | private double dropout; 9 | private int[] shape; 10 | private Tensor mask; 11 | 12 | public DropoutLayer(){ 13 | this.dropout = 0.5; 14 | } 15 | 16 | // chance to drop out an input 17 | public DropoutLayer(double dropout){ 18 | this.dropout = dropout; 19 | } 20 | 21 | @Override 22 | public int[] outputShape(){ 23 | return shape; 24 | } 25 | 26 | @Override 27 | public int[] inputShape(){ 28 | return shape; 29 | } 30 | 31 | @Override 32 | public void init(int[] inputShape){ 33 | shape = inputShape; 34 | } 35 | 36 | @Override 37 | public Tensor forwardPropagate(Tensor input, boolean training){ 38 | if(training){ 39 | double[] arr = new double[input.size()]; 40 | Random r = new Random(); 41 | for(int i = 0; i < input.size(); i++){ 42 | // if not dropout, then scale the inputs 43 | arr[i] = r.nextDouble() < dropout ? 
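// inverted dropout: each surviving activation is scaled up by 1 / (1 - dropout) during training so that its expected value is unchanged, which is why no rescaling is needed at test time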
0.0 : (1.0 / (1.0 - dropout)); 44 | } 45 | mask = new Tensor(input.shape(), arr); 46 | 47 | return input.mul(mask); 48 | }else{ 49 | // do not need to scale inputs 50 | return input; 51 | } 52 | } 53 | 54 | @Override 55 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error){ 56 | // scale the gradients during backpropagation 57 | return error.mul(mask); 58 | } 59 | 60 | @Override 61 | public String toString(){ 62 | return "Dropout Probability: " + dropout; 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/FCLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import java.nio.ByteBuffer; 4 | import java.util.Arrays; 5 | 6 | import javamachinelearning.optimizers.Optimizer; 7 | import javamachinelearning.regularizers.Regularizer; 8 | import javamachinelearning.utils.Tensor; 9 | 10 | public class FCLayer implements FeedForwardParamsLayer{ 11 | private Tensor weights; 12 | private Tensor gradWeights; 13 | private Tensor[] weightExtraParams; // extra optimization parameters for weights 14 | 15 | private Tensor bias; 16 | private Tensor gradBias; 17 | private Tensor[] biasExtraParams; // extra optimization parameters for biases 18 | 19 | private int[] inputShape; 20 | private int[] outputShape; 21 | private int changeCount; 22 | private boolean alreadyInit = false; 23 | private boolean useBias = true; 24 | 25 | public FCLayer(int nextSize){ 26 | this.outputShape = new int[]{-1, nextSize}; 27 | } 28 | 29 | @Override 30 | public int[] outputShape(){ 31 | return outputShape; 32 | } 33 | 34 | @Override 35 | public int[] inputShape(){ 36 | return inputShape; 37 | } 38 | 39 | @Override 40 | public void init(int[] inputShape){ 41 | this.inputShape = inputShape; 42 | this.outputShape[0] = inputShape[0]; 43 | 44 | if(!alreadyInit){ 45 | this.weights = new Tensor(new int[]{this.inputShape[1], this.outputShape[1]}, true); 46 | if(useBias) 47 | this.bias = new Tensor(new int[]{1, this.outputShape[1]}, false); 48 | } 49 | this.gradWeights = new Tensor(new int[]{this.inputShape[1], this.outputShape[1]}, false); 50 | if(useBias) 51 | this.gradBias = new Tensor(new int[]{1, this.outputShape[1]}, false); 52 | } 53 | 54 | @Override 55 | public FeedForwardParamsLayer withParams(Tensor w, Tensor b){ 56 | weights = w; 57 | if(useBias) 58 | bias = b; 59 | alreadyInit = true; 60 | return this; 61 | } 62 | 63 | @Override 64 | public FeedForwardParamsLayer noBias(){ 65 | useBias = false; 66 | return this; 67 | } 68 | 69 | @Override 70 | public Tensor bias(){ 71 | return bias; 72 | } 73 | 74 | @Override 75 | public Tensor weights(){ 76 | return weights; 77 | } 78 | 79 | @Override 80 | public void setBias(Tensor b){ 81 | if(useBias) 82 | bias = b; 83 | } 84 | 85 | @Override 86 | public void setWeights(Tensor w){ 87 | weights = w; 88 | } 89 | 90 | @Override 91 | public Tensor forwardPropagate(Tensor input, boolean training){ 92 | Tensor x = weights.dot(input); 93 | if(useBias){ 94 | // duplicate bias for multiple time steps if needed 95 | x = x.add(bias.dupFirst(x.shape()[0])); 96 | } 97 | return x; 98 | } 99 | 100 | @Override 101 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error){ 102 | // error wrt weight 103 | gradWeights = gradWeights.add(error.dot(input.T())); 104 | 105 | // error wrt bias 106 | // not multiplied by previous outputs! 
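// e.g. with illustrative sizes: after a recurrent layer with 5 time steps and 10 outputs, error holds 5 x 10 values and the reduction below sums them over time into a single 1 x 10 bias gradient, matching the bias shape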
107 | if(useBias){ 108 | // if error contains multiple time steps 109 | // then accumulate the gradients across the time steps 110 | gradBias = gradBias.add(error.T().reduceLast(0, (a, b) -> a + b)); 111 | } 112 | 113 | // new error should be affected by weights 114 | Tensor gradInputs = weights.T().dot(error); 115 | 116 | changeCount++; 117 | 118 | return gradInputs; 119 | } 120 | 121 | @Override 122 | public void update(Optimizer optimizer, Regularizer regularizer){ 123 | // initialize extra parameters 124 | if(weightExtraParams == null){ 125 | weightExtraParams = new Tensor[optimizer.extraParams()]; 126 | for(int i = 0; i < weightExtraParams.length; i++){ 127 | weightExtraParams[i] = new Tensor(weights.shape(), false); 128 | } 129 | 130 | if(useBias){ 131 | biasExtraParams = new Tensor[optimizer.extraParams()]; 132 | for(int i = 0; i < biasExtraParams.length; i++){ 133 | biasExtraParams[i] = new Tensor(bias.shape(), false); 134 | } 135 | } 136 | } 137 | 138 | // handles postponed updates, by averaging accumulated gradients 139 | // add the regularization derivative if needed 140 | 141 | // note that averaging the weight gradients here is the same as 142 | // averaging the loss gradients per mini-batch after forward propagation 143 | if(regularizer == null){ 144 | weights = weights.sub( 145 | optimizer.optimize( 146 | gradWeights.div(Math.max(changeCount, 1)), weightExtraParams)); 147 | }else{ 148 | weights = weights.sub( 149 | optimizer.optimize( 150 | gradWeights.div(Math.max(changeCount, 1)).add( 151 | regularizer.derivative(weights)), weightExtraParams)); 152 | } 153 | gradWeights = new Tensor(gradWeights.shape(), false); 154 | 155 | if(useBias){ 156 | bias = bias.sub( 157 | optimizer.optimize( 158 | gradBias.div(Math.max(changeCount, 1)), biasExtraParams)); 159 | gradBias = new Tensor(gradBias.shape(), false); 160 | } 161 | 162 | changeCount = 0; 163 | } 164 | 165 | @Override 166 | public int byteSize(){ 167 | // 8 bytes for each double 168 | return Double.BYTES * weights.size() + (useBias ? 
Double.BYTES * bias.size() : 0); 169 | } 170 | 171 | @Override 172 | public ByteBuffer bytes(){ 173 | ByteBuffer bb = ByteBuffer.allocate(byteSize()); 174 | for(int i = 0; i < weights.size(); i++){ 175 | bb.putDouble(weights.flatGet(i)); 176 | } 177 | if(useBias){ 178 | for(int i = 0; i < bias.size(); i++){ 179 | bb.putDouble(bias.flatGet(i)); 180 | } 181 | } 182 | bb.flip(); 183 | return bb; 184 | } 185 | 186 | @Override 187 | public void readBytes(ByteBuffer bb){ 188 | double[] w = new double[weights.size()]; 189 | for(int i = 0; i < w.length; i++){ 190 | w[i] = bb.getDouble(); 191 | } 192 | weights = new Tensor(weights.shape(), w); 193 | 194 | if(useBias){ 195 | double[] b = new double[bias.size()]; 196 | for(int i = 0; i < b.length; i++){ 197 | b[i] = bb.getDouble(); 198 | } 199 | bias = new Tensor(bias.shape(), b); 200 | } 201 | } 202 | 203 | @Override 204 | public String toString(){ 205 | return "Fully Connected\tInput Shape: " + Arrays.toString(inputShape()) + "\tOutput Shape: " + Arrays.toString(outputShape()); 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/FeedForwardLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import javamachinelearning.layers.Layer; 4 | 5 | public interface FeedForwardLayer extends Layer{ 6 | 7 | } 8 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/FeedForwardParamsLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import javamachinelearning.layers.ParamsLayer; 4 | import javamachinelearning.utils.Tensor; 5 | 6 | public interface FeedForwardParamsLayer extends FeedForwardLayer, ParamsLayer{ 7 | // withParams should be used when initializing a layer 8 | public FeedForwardParamsLayer withParams(Tensor w, Tensor b); 9 | 10 | public Tensor bias(); 11 | public Tensor weights(); 12 | public void setBias(Tensor b); 13 | public void setWeights(Tensor w); 14 | } 15 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/FlattenLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import java.util.Arrays; 4 | 5 | import javamachinelearning.utils.Tensor; 6 | 7 | public class FlattenLayer implements FeedForwardLayer{ 8 | private int[] inputShape; 9 | private int outputSize; 10 | 11 | public FlattenLayer(){ 12 | // nothing to do 13 | } 14 | 15 | @Override 16 | public int[] outputShape(){ 17 | return new int[]{1, outputSize}; 18 | } 19 | 20 | @Override 21 | public int[] inputShape(){ 22 | return inputShape; 23 | } 24 | 25 | @Override 26 | public void init(int[] inputShape){ 27 | this.inputShape = inputShape; 28 | outputSize = 1; 29 | for(int i = 0; i < inputShape.length; i++){ 30 | outputSize *= inputShape[i]; 31 | } 32 | } 33 | 34 | @Override 35 | public Tensor forwardPropagate(Tensor input, boolean training){ 36 | return input.flatten(); 37 | } 38 | 39 | @Override 40 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error){ 41 | return error.reshape(inputShape); 42 | } 43 | 44 | @Override 45 | public String toString(){ 46 | return "Flatten\tInput Shape: " + Arrays.toString(inputShape()) + "\tOutput Shape: " + 
Arrays.toString(outputShape()); 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/MaxPoolingLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import java.util.Arrays; 4 | 5 | import javamachinelearning.utils.Tensor; 6 | 7 | public class MaxPoolingLayer implements FeedForwardLayer{ 8 | private int[] inputShape; 9 | private int[] outputShape; 10 | private int winWidth, winHeight; 11 | private int strideX, strideY; 12 | private int[][] maxIdx; 13 | 14 | public MaxPoolingLayer(int winWidth, int winHeight, int strideX, int strideY){ 15 | this.winWidth = winWidth; 16 | this.winHeight = winHeight; 17 | this.strideX = strideX; 18 | this.strideY = strideY; 19 | } 20 | 21 | public MaxPoolingLayer(int winSize, int stride){ 22 | this.winWidth = winSize; 23 | this.winHeight = winSize; 24 | this.strideX = stride; 25 | this.strideY = stride; 26 | } 27 | 28 | public MaxPoolingLayer(int winSize){ 29 | this.winWidth = winSize; 30 | this.winHeight = winSize; 31 | this.strideX = 1; 32 | this.strideY = 1; 33 | } 34 | 35 | @Override 36 | public int[] outputShape(){ 37 | return outputShape; 38 | } 39 | 40 | @Override 41 | public int[] inputShape(){ 42 | return inputShape; 43 | } 44 | 45 | @Override 46 | public void init(int[] inputShape){ 47 | this.inputShape = inputShape; 48 | 49 | int temp = inputShape[0] - winWidth; 50 | if(temp % strideX != 0) 51 | throw new IllegalArgumentException("Bad sizes for max pooling!"); 52 | int w = temp / strideX + 1; 53 | 54 | temp = inputShape[1] - winHeight; 55 | if(temp % strideY != 0) 56 | throw new IllegalArgumentException("Bad sizes for max pooling!"); 57 | int h = temp / strideY + 1; 58 | 59 | outputShape = new int[]{w, h, inputShape[2]}; 60 | maxIdx = new int[outputShape[0] * outputShape[1] * outputShape[2]][2]; 61 | } 62 | 63 | @Override 64 | public Tensor forwardPropagate(Tensor input, boolean training){ 65 | double[] res = new double[outputShape[0] * outputShape[1] * outputShape[2]]; 66 | int[] shape = input.shape(); 67 | int idx = 0; 68 | // slide the window through and compute the max at each location 69 | // the output should have the same depth as the input 70 | for(int i = 0; i < outputShape[0] * strideX; i += strideX){ 71 | for(int j = 0; j < outputShape[1] * strideY; j += strideY){ 72 | for(int k = 0; k < shape[2]; k++){ // for each depth slice 73 | double max = Double.NEGATIVE_INFINITY; // must start below any possible input value (Double.MIN_VALUE is the smallest positive double, which breaks all-negative windows) 74 | 75 | for(int rx = 0; rx < winWidth; rx++){ // relative x position 76 | for(int ry = 0; ry < winHeight; ry++){ // relative y position 77 | // absolute positions 78 | int x = i + rx; 79 | int y = j + ry; 80 | double val = input.flatGet(x * shape[1] * shape[2] + y * shape[2] + k); 81 | 82 | if(val > max){ 83 | max = val; 84 | 85 | maxIdx[idx][0] = x; 86 | maxIdx[idx][1] = y; 87 | } 88 | } 89 | } 90 | 91 | // max of all values 92 | res[idx] = max; 93 | idx++; 94 | } 95 | } 96 | } 97 | 98 | return new Tensor(outputShape, res); 99 | } 100 | 101 | @Override 102 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error){ 103 | double[] res = new double[inputShape[0] * inputShape[1] * inputShape[2]]; 104 | int outIdx = 0; 105 | 106 | for(int i = 0; i < outputShape[0] * strideX; i += strideX){ 107 | for(int j = 0; j < outputShape[1] * strideY; j += strideY){ 108 | for(int k = 0; k < inputShape[2]; k++){ // for each depth slice 109 | for(int rx = 0; rx < winWidth; rx++){ // relative x position
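// only the input position recorded in maxIdx for this window receives the error; every other position in the window gets zero gradient, unlike average pooling which spreads the error evenly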
110 | for(int ry = 0; ry < winHeight; ry++){ // relative y position 111 | // absolute positions 112 | int x = i + rx; 113 | int y = j + ry; 114 | int inIdx = x * inputShape[1] * inputShape[2] + y * inputShape[2] + k; 115 | 116 | if(maxIdx[outIdx][0] == x && maxIdx[outIdx][1] == y){ 117 | res[inIdx] += error.flatGet(outIdx); 118 | } 119 | } 120 | } 121 | 122 | outIdx++; 123 | } 124 | } 125 | } 126 | 127 | return new Tensor(inputShape, res); 128 | } 129 | 130 | @Override 131 | public String toString(){ 132 | return "Max Pooling\tInput Shape: " + Arrays.toString(inputShape()) + "\tOutput Shape: " + Arrays.toString(outputShape()); 133 | } 134 | } 135 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/feedforward/ScalingLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.feedforward; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public class ScalingLayer implements FeedForwardLayer{ 6 | private int[] shape; 7 | private double scale; 8 | private boolean useTraining; 9 | 10 | public ScalingLayer(double scale, boolean useTraining){ 11 | this.scale = scale; 12 | this.useTraining = useTraining; 13 | } 14 | 15 | @Override 16 | public int[] outputShape(){ 17 | return shape; 18 | } 19 | 20 | @Override 21 | public int[] inputShape(){ 22 | return shape; 23 | } 24 | 25 | @Override 26 | public void init(int[] inputShape){ 27 | shape = inputShape; 28 | } 29 | 30 | @Override 31 | public Tensor forwardPropagate(Tensor input, boolean training){ 32 | if(!training || useTraining) 33 | return input.mul(scale); 34 | else 35 | return input; 36 | } 37 | 38 | @Override 39 | public Tensor backPropagate(Tensor input, Tensor output, Tensor error){ 40 | if(useTraining) 41 | return error.mul(scale); 42 | else 43 | return error; 44 | } 45 | 46 | @Override 47 | public String toString(){ 48 | return "Scaling Factor: " + scale; 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/recurrent/GRUCell.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.recurrent; 2 | 3 | import java.nio.ByteBuffer; 4 | 5 | import javamachinelearning.optimizers.Optimizer; 6 | import javamachinelearning.regularizers.Regularizer; 7 | import javamachinelearning.utils.Activation; 8 | import javamachinelearning.utils.Tensor; 9 | 10 | public class GRUCell implements RecurrentCell{ 11 | private int size; 12 | private Activation activation; 13 | private Activation gateActivation; 14 | 15 | private Tensor resetW, updateW, memoryW; 16 | private Tensor resetU, updateU, memoryU; 17 | private Tensor resetB, updateB, memoryB; 18 | 19 | private Tensor gradResetW, gradUpdateW, gradMemoryW; 20 | private Tensor gradResetU, gradUpdateU, gradMemoryU; 21 | private Tensor gradResetB, gradUpdateB, gradMemoryB; 22 | 23 | private Tensor[] resetWParams, updateWParams, memoryWParams; 24 | private Tensor[] resetUParams, updateUParams, memoryUParams; 25 | private Tensor[] resetBParams, updateBParams, memoryBParams; 26 | 27 | // cached values for backpropagation 28 | private Tensor[] reset, update, memory; 29 | 30 | private boolean useBias = true; 31 | 32 | public GRUCell(Activation activation, Activation gateActivation){ 33 | this.activation = activation; 34 | this.gateActivation = gateActivation; 35 | } 36 | 37 | public GRUCell(){ 38 | this(Activation.tanh, 
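// standard GRU defaults: tanh squashes the candidate memory into (-1, 1), and the sigmoid below keeps the reset/update gate values in (0, 1)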
Activation.sigmoid); 39 | } 40 | 41 | @Override 42 | public void noBias(){ 43 | resetB = null; 44 | updateB = null; 45 | memoryB = null; 46 | gradResetB = null; 47 | gradUpdateB = null; 48 | gradMemoryB = null; 49 | useBias = false; 50 | } 51 | 52 | @Override 53 | public int[] outputShape(){ 54 | return new int[]{1, size}; 55 | } 56 | 57 | @Override 58 | public int[] inputShape(){ 59 | return new int[]{1, size}; 60 | } 61 | 62 | @Override 63 | public void init(int inputSize, int numTimeSteps){ 64 | size = inputSize; 65 | 66 | // initialize weights/biases and their gradient accumulators 67 | resetW = new Tensor(new int[]{size, size}, true); 68 | updateW = new Tensor(new int[]{size, size}, true); 69 | memoryW = new Tensor(new int[]{size, size}, true); 70 | 71 | resetU = new Tensor(new int[]{size, size}, true); 72 | updateU = new Tensor(new int[]{size, size}, true); 73 | memoryU = new Tensor(new int[]{size, size}, true); 74 | 75 | if(useBias){ 76 | resetB = new Tensor(new int[]{1, size}, false); 77 | updateB = new Tensor(new int[]{1, size}, false); 78 | memoryB = new Tensor(new int[]{1, size}, false); 79 | } 80 | 81 | gradResetW = new Tensor(new int[]{size, size}, false); 82 | gradUpdateW = new Tensor(new int[]{size, size}, false); 83 | gradMemoryW = new Tensor(new int[]{size, size}, false); 84 | 85 | gradResetU = new Tensor(new int[]{size, size}, false); 86 | gradUpdateU = new Tensor(new int[]{size, size}, false); 87 | gradMemoryU = new Tensor(new int[]{size, size}, false); 88 | 89 | if(useBias){ 90 | gradResetB = new Tensor(new int[]{1, size}, false); 91 | gradUpdateB = new Tensor(new int[]{1, size}, false); 92 | gradMemoryB = new Tensor(new int[]{1, size}, false); 93 | } 94 | 95 | // used to cache computed results 96 | reset = new Tensor[numTimeSteps]; 97 | update = new Tensor[numTimeSteps]; 98 | memory = new Tensor[numTimeSteps]; 99 | } 100 | 101 | @Override 102 | public Tensor forwardPropagate(int t, Tensor input, Tensor prevState, boolean training){ 103 | // forward propagate equations 104 | // omits some details regarding matrix multiplications and stuff 105 | // reset = sigmoid(input * resetW + prevState * resetU + resetB) 106 | // update = sigmoid(input * updateW + prevState * updateU + updateB) 107 | // memory = tanh(input * memoryW + (prevState * reset) * memoryU + memoryB) 108 | // state = (1 - update) * memory + update * prevState 109 | 110 | if(useBias) 111 | reset[t] = gateActivation.activate(resetW.dot(input).add(resetU.dot(prevState)).add(resetB)); 112 | else 113 | reset[t] = gateActivation.activate(resetW.dot(input).add(resetU.dot(prevState))); 114 | 115 | if(useBias) 116 | update[t] = gateActivation.activate(updateW.dot(input).add(updateU.dot(prevState)).add(updateB)); 117 | else 118 | update[t] = gateActivation.activate(updateW.dot(input).add(updateU.dot(prevState))); 119 | 120 | // the activation here can be something other than tanh 121 | if(useBias) 122 | memory[t] = activation.activate(memoryW.dot(input).add(memoryU.dot(prevState.mul(reset[t]))).add(memoryB)); 123 | else 124 | memory[t] = activation.activate(memoryW.dot(input).add(memoryU.dot(prevState.mul(reset[t])))); 125 | 126 | return update[t].map(x -> 1.0 - x).mul(memory[t]).add(update[t].mul(prevState)); 127 | } 128 | 129 | @Override 130 | public Tensor[] backPropagate(int t, Tensor input, Tensor prevState, Tensor error){ 131 | // first, gather gradients for the memory, reset, and update equations (multiplied by activation derivatives) 132 | // second, calculate gradients wrt weights/biases 133 | // third, 
accumulate gradients wrt prevState and inputs 134 | 135 | Tensor gradMemory = error.mul(update[t].map(x -> 1.0 - x)).mul(activation.derivative(memory[t])); 136 | 137 | Tensor gradUpdate = error.mul(prevState.sub(memory[t])).mul(gateActivation.derivative(update[t])); 138 | 139 | Tensor gradReset = memoryU.T().dot(gradMemory).mul(prevState).mul(gateActivation.derivative(reset[t])); 140 | 141 | gradResetW = gradResetW.add(gradReset.dot(input.T())); 142 | gradResetU = gradResetU.add(gradReset.dot(prevState.T())); 143 | 144 | gradUpdateW = gradUpdateW.add(gradUpdate.dot(input.T())); 145 | gradUpdateU = gradUpdateU.add(gradUpdate.dot(prevState.T())); 146 | 147 | gradMemoryW = gradMemoryW.add(gradMemory.dot(input.T())); 148 | gradMemoryU = gradMemoryU.add(gradMemory.dot(prevState.mul(reset[t]).T())); 149 | 150 | if(useBias){ 151 | gradResetB = gradResetB.add(gradReset); 152 | gradUpdateB = gradUpdateB.add(gradUpdate); 153 | gradMemoryB = gradMemoryB.add(gradMemory); 154 | } 155 | 156 | Tensor gradInput = resetW.T().dot(gradReset).add( 157 | updateW.T().dot(gradUpdate)).add(memoryW.T().dot(gradMemory)); 158 | 159 | Tensor gradPrevState = resetU.T().dot(gradReset).add( 160 | updateU.T().dot(gradUpdate)).add(memoryU.T().dot(gradMemory).mul(reset[t])).add(error.mul(update[t])); 161 | 162 | return new Tensor[]{gradInput, gradPrevState}; 163 | } 164 | 165 | @Override 166 | public void update(Optimizer optimizer, Regularizer regularizer, int changeCount){ 167 | // initialize all extra parameters that are used for optimization 168 | if(resetWParams == null){ 169 | resetWParams = new Tensor[optimizer.extraParams()]; 170 | updateWParams = new Tensor[optimizer.extraParams()]; 171 | memoryWParams = new Tensor[optimizer.extraParams()]; 172 | 173 | resetUParams = new Tensor[optimizer.extraParams()]; 174 | updateUParams = new Tensor[optimizer.extraParams()]; 175 | memoryUParams = new Tensor[optimizer.extraParams()]; 176 | 177 | // use arrays to make initializing tensors take up less lines of code 178 | Tensor[][] params = {resetWParams, updateWParams, memoryWParams, resetUParams, updateUParams, memoryUParams}; 179 | Tensor[] weights = {resetW, updateW, memoryW, resetU, updateU, memoryU}; 180 | 181 | for(int i = 0; i < params.length; i++){ 182 | for(int j = 0; j < params[i].length; j++){ 183 | params[i][j] = new Tensor(weights[i].shape(), false); 184 | } 185 | } 186 | 187 | if(useBias){ 188 | resetBParams = new Tensor[optimizer.extraParams()]; 189 | updateBParams = new Tensor[optimizer.extraParams()]; 190 | memoryBParams = new Tensor[optimizer.extraParams()]; 191 | 192 | params = new Tensor[][]{resetBParams, updateBParams, memoryBParams}; 193 | weights = new Tensor[]{resetB, updateB, memoryB}; 194 | 195 | for(int i = 0; i < params.length; i++){ 196 | for(int j = 0; j < params[i].length; j++){ 197 | params[i][j] = new Tensor(weights[i].shape(), false); 198 | } 199 | } 200 | } 201 | } 202 | 203 | // average grads 204 | gradResetW = gradResetW.div(Math.max(changeCount, 1)); 205 | gradUpdateW = gradUpdateW.div(Math.max(changeCount, 1)); 206 | gradMemoryW = gradMemoryW.div(Math.max(changeCount, 1)); 207 | gradResetU = gradResetU.div(Math.max(changeCount, 1)); 208 | gradUpdateU = gradUpdateU.div(Math.max(changeCount, 1)); 209 | gradMemoryU = gradMemoryU.div(Math.max(changeCount, 1)); 210 | 211 | // optimize weights using the grads 212 | if(regularizer == null){ 213 | resetW = resetW.sub(optimizer.optimize(gradResetW, resetWParams)); 214 | updateW = updateW.sub(optimizer.optimize(gradUpdateW, updateWParams)); 215 | 
memoryW = memoryW.sub(optimizer.optimize(gradMemoryW, memoryWParams)); 216 | 217 | resetU = resetU.sub(optimizer.optimize(gradResetU, resetUParams)); 218 | updateU = updateU.sub(optimizer.optimize(gradUpdateU, updateUParams)); 219 | memoryU = memoryU.sub(optimizer.optimize(gradMemoryU, memoryUParams)); 220 | }else{ 221 | resetW = resetW.sub(optimizer.optimize(gradResetW.add(regularizer.derivative(resetW)), resetWParams)); 222 | updateW = updateW.sub(optimizer.optimize(gradUpdateW.add(regularizer.derivative(updateW)), updateWParams)); 223 | memoryW = memoryW.sub(optimizer.optimize(gradMemoryW.add(regularizer.derivative(memoryW)), memoryWParams)); 224 | 225 | resetU = resetU.sub(optimizer.optimize(gradResetU.add(regularizer.derivative(resetU)), resetUParams)); 226 | updateU = updateU.sub(optimizer.optimize(gradUpdateU.add(regularizer.derivative(updateU)), updateUParams)); 227 | memoryU = memoryU.sub(optimizer.optimize(gradMemoryU.add(regularizer.derivative(memoryU)), memoryUParams)); 228 | } 229 | 230 | // reset grads 231 | gradResetW = new Tensor(gradResetW.shape(), false); 232 | gradUpdateW = new Tensor(gradUpdateW.shape(), false); 233 | gradMemoryW = new Tensor(gradMemoryW.shape(), false); 234 | gradResetU = new Tensor(gradResetU.shape(), false); 235 | gradUpdateU = new Tensor(gradUpdateU.shape(), false); 236 | gradMemoryU = new Tensor(gradMemoryU.shape(), false); 237 | 238 | if(useBias){ 239 | // average grads 240 | gradResetB = gradResetB.div(Math.max(changeCount, 1)); 241 | gradUpdateB = gradUpdateB.div(Math.max(changeCount, 1)); 242 | gradMemoryB = gradMemoryB.div(Math.max(changeCount, 1)); 243 | 244 | // optimize biases using the grads 245 | resetB = resetB.sub(optimizer.optimize(gradResetB, resetBParams)); 246 | updateB = updateB.sub(optimizer.optimize(gradUpdateB, updateBParams)); 247 | memoryB = memoryB.sub(optimizer.optimize(gradMemoryB, memoryBParams)); 248 | 249 | // reset grads 250 | gradResetB = new Tensor(gradResetB.shape(), false); 251 | gradUpdateB = new Tensor(gradUpdateB.shape(), false); 252 | gradMemoryB = new Tensor(gradMemoryB.shape(), false); 253 | } 254 | } 255 | 256 | @Override 257 | public int byteSize(){ 258 | return Double.BYTES * resetW.size() + Double.BYTES * updateW.size() + Double.BYTES * memoryW.size() + 259 | Double.BYTES * resetU.size() + Double.BYTES * updateU.size() + Double.BYTES * memoryU.size() + 260 | (useBias ? 
(Double.BYTES * resetB.size() + Double.BYTES * updateB.size() + Double.BYTES * memoryB.size()) : 0); 261 | } 262 | 263 | @Override 264 | public ByteBuffer bytes(){ 265 | ByteBuffer bb = ByteBuffer.allocate(byteSize()); 266 | 267 | Tensor[] weights = {resetW, updateW, memoryW, resetU, updateU, memoryU}; 268 | 269 | for(Tensor w : weights){ 270 | for(int i = 0; i < w.size(); i++){ 271 | bb.putDouble(w.flatGet(i)); 272 | } 273 | } 274 | 275 | if(useBias){ 276 | weights = new Tensor[]{resetB, updateB, memoryB}; 277 | 278 | for(Tensor w : weights){ 279 | for(int i = 0; i < w.size(); i++){ 280 | bb.putDouble(w.flatGet(i)); 281 | } 282 | } 283 | } 284 | 285 | bb.flip(); 286 | return bb; 287 | } 288 | 289 | @Override 290 | public void readBytes(ByteBuffer bb){ 291 | double[] rW = new double[resetW.size()]; 292 | for(int i = 0; i < rW.length; i++){ 293 | rW[i] = bb.getDouble(); 294 | } 295 | resetW = new Tensor(resetW.shape(), rW); 296 | 297 | double[] uW = new double[updateW.size()]; 298 | for(int i = 0; i < uW.length; i++){ 299 | uW[i] = bb.getDouble(); 300 | } 301 | updateW = new Tensor(updateW.shape(), uW); 302 | 303 | double[] mW = new double[memoryW.size()]; 304 | for(int i = 0; i < mW.length; i++){ 305 | mW[i] = bb.getDouble(); 306 | } 307 | memoryW = new Tensor(memoryW.shape(), mW); 308 | 309 | double[] rU = new double[resetU.size()]; 310 | for(int i = 0; i < rU.length; i++){ 311 | rU[i] = bb.getDouble(); 312 | } 313 | resetU = new Tensor(resetU.shape(), rU); 314 | 315 | double[] uU = new double[updateU.size()]; 316 | for(int i = 0; i < uU.length; i++){ 317 | uU[i] = bb.getDouble(); 318 | } 319 | updateU = new Tensor(updateU.shape(), uU); 320 | 321 | double[] mU = new double[memoryU.size()]; 322 | for(int i = 0; i < mU.length; i++){ 323 | mU[i] = bb.getDouble(); 324 | } 325 | memoryU = new Tensor(memoryU.shape(), mU); 326 | 327 | if(useBias){ 328 | double[] rB = new double[resetB.size()]; 329 | for(int i = 0; i < rB.length; i++){ 330 | rB[i] = bb.getDouble(); 331 | } 332 | resetB = new Tensor(resetB.shape(), rB); 333 | 334 | double[] uB = new double[updateB.size()]; 335 | for(int i = 0; i < uB.length; i++){ 336 | uB[i] = bb.getDouble(); 337 | } 338 | updateB = new Tensor(updateB.shape(), uB); 339 | 340 | double[] mB = new double[memoryB.size()]; 341 | for(int i = 0; i < mB.length; i++){ 342 | mB[i] = bb.getDouble(); 343 | } 344 | memoryB = new Tensor(memoryB.shape(), mB); 345 | } 346 | } 347 | 348 | @Override 349 | public String toString(){ 350 | return "Gated Recurrent Unit"; 351 | } 352 | } 353 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/recurrent/RecurrentCell.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.recurrent; 2 | 3 | import java.nio.ByteBuffer; 4 | 5 | import javamachinelearning.optimizers.Optimizer; 6 | import javamachinelearning.regularizers.Regularizer; 7 | import javamachinelearning.utils.Tensor; 8 | 9 | public interface RecurrentCell{ 10 | public void noBias(); 11 | 12 | public int[] outputShape(); 13 | public int[] inputShape(); 14 | public void init(int inputSize, int numTimeSteps); 15 | public Tensor forwardPropagate(int t, Tensor input, Tensor prevState, boolean training); 16 | // backpropagation should return two tensors for the input and the previous state gradients 17 | public Tensor[] backPropagate(int t, Tensor input, Tensor prevState, Tensor error); 18 | public void update(Optimizer optimizer, Regularizer regularizer,
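// changeCount is the number of backpropagation passes accumulated since the last update; implementations divide their summed gradients by it to apply the average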
int changeCount); 19 | public int byteSize(); 20 | public ByteBuffer bytes(); 21 | public void readBytes(ByteBuffer bb); 22 | } 23 | -------------------------------------------------------------------------------- /src/javamachinelearning/layers/recurrent/RecurrentLayer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.layers.recurrent; 2 | 3 | import java.nio.ByteBuffer; 4 | import java.util.Arrays; 5 | 6 | import javamachinelearning.layers.ParamsLayer; 7 | import javamachinelearning.optimizers.Optimizer; 8 | import javamachinelearning.regularizers.Regularizer; 9 | import javamachinelearning.utils.Tensor; 10 | import javamachinelearning.utils.TensorUtils; 11 | 12 | public class RecurrentLayer implements ParamsLayer{ 13 | private RecurrentCell cell; 14 | private boolean statefulTrain; 15 | private boolean statefulTest; 16 | 17 | private int numTimeSteps; 18 | private int numOutputs; 19 | private boolean outputAll; 20 | 21 | private Tensor[] states; 22 | private int changeCount; 23 | 24 | // state from the previous forward propagation of this layer 25 | // allows the recurrent cells to continue where it left off before 26 | private Tensor layerPrevState; 27 | // save the previous state before it is updated, for backpropagation 28 | private Tensor layerPrevStateTemp; 29 | 30 | public RecurrentLayer(int numTimeSteps, int numOutputs, RecurrentCell cell, boolean statefulTrain, boolean statefulTest){ 31 | this.numTimeSteps = numTimeSteps; 32 | this.numOutputs = numOutputs; 33 | this.cell = cell; 34 | this.statefulTrain = statefulTrain; 35 | this.statefulTest = statefulTest; 36 | this.outputAll = false; 37 | } 38 | 39 | public RecurrentLayer(int numTimeSteps, RecurrentCell cell, boolean statefulTrain, boolean statefulTest){ 40 | this.numTimeSteps = numTimeSteps; 41 | this.numOutputs = numTimeSteps; 42 | this.cell = cell; 43 | this.statefulTrain = statefulTrain; 44 | this.statefulTest = statefulTest; 45 | this.outputAll = true; 46 | } 47 | 48 | public RecurrentLayer(int numTimeSteps, RecurrentCell cell, boolean stateful){ 49 | this.numTimeSteps = numTimeSteps; 50 | this.numOutputs = numTimeSteps; 51 | this.cell = cell; 52 | this.statefulTrain = stateful; 53 | this.statefulTest = stateful; 54 | this.outputAll = true; 55 | } 56 | 57 | @Override 58 | public int[] outputShape(){ 59 | return new int[]{numOutputs, cell.outputShape()[1]}; 60 | } 61 | 62 | @Override 63 | public int[] inputShape(){ 64 | return new int[]{numTimeSteps, cell.inputShape()[1]}; 65 | } 66 | 67 | @Override 68 | public void init(int[] inputShape){ 69 | states = new Tensor[numTimeSteps]; 70 | 71 | // inputShape[1] = size of input 1D tensor 72 | cell.init(inputShape[1], numTimeSteps); 73 | } 74 | 75 | @Override 76 | public ParamsLayer noBias(){ 77 | cell.noBias(); 78 | return this; 79 | } 80 | 81 | public RecurrentCell cell(){ 82 | return cell; 83 | } 84 | 85 | @Override 86 | public Tensor forwardPropagate(Tensor input, boolean training){ 87 | return forwardPropagate(input, numTimeSteps, training); 88 | } 89 | 90 | // more general method that allows the number of times the cell is propagated through to vary 91 | public Tensor forwardPropagate(Tensor input, int timeSteps, boolean training){ 92 | int outputCount = outputAll ? 
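// if outputAll, every time step is emitted; otherwise only the last numOutputs states (capped by the number of steps actually run) are returned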
timeSteps : Math.min(numOutputs, timeSteps); 93 | boolean stateful = (training && statefulTrain) || (!training && statefulTest); 94 | Tensor[] outputs = new Tensor[outputCount]; 95 | int idx = 0; 96 | 97 | // the same recurrent cell is used across multiple time steps! 98 | // data is fed into the cell repeatedly 99 | for(int i = 0; i < timeSteps; i++){ 100 | Tensor inTensor = i < input.shape()[0] ? 101 | input.get(i) : new Tensor(cell.inputShape(), false); 102 | 103 | Tensor prevState = i == 0 ? 104 | (stateful && layerPrevState != null ? layerPrevState : 105 | new Tensor(cell.inputShape(), false)) : states[i - 1]; 106 | 107 | states[i] = cell.forwardPropagate(i, inTensor, prevState, training); 108 | 109 | // only output the last few cells 110 | if(i >= timeSteps - outputCount){ 111 | outputs[idx] = states[i]; 112 | idx++; 113 | } 114 | } 115 | 116 | // save last state for next time this layer is forward propagated, if necessary 117 | if(stateful){ 118 | layerPrevStateTemp = layerPrevState; 119 | layerPrevState = states[timeSteps - 1]; 120 | } 121 | 122 | return TensorUtils.stack(outputs); 123 | } 124 | 125 | @Override 126 | public Tensor backPropagate(Tensor input, Tensor output, Tensor nextLayerError){ 127 | Tensor[] prevLayerError = new Tensor[numTimeSteps]; 128 | Tensor nextCellError = new Tensor(cell.outputShape(), false); 129 | 130 | for(int i = numTimeSteps - 1; i >= 0; i--){ 131 | Tensor inTensor = i < input.shape()[0] ? 132 | input.get(i) : new Tensor(cell.inputShape(), false); 133 | 134 | Tensor prevState = i == 0 ? 135 | (statefulTrain && layerPrevStateTemp != null ? layerPrevStateTemp : 136 | new Tensor(cell.inputShape(), false)) : states[i - 1]; 137 | 138 | // accumulate the error gradient from the next layer and the next cell 139 | int idx = i - (numTimeSteps - numOutputs); 140 | Tensor totalError = (idx >= 0) ? nextCellError.add(nextLayerError.get(idx)) : nextCellError; 141 | 142 | Tensor[] arr = cell.backPropagate(i, inTensor, prevState, totalError); 143 | 144 | prevLayerError[i] = i < input.shape()[0] ? 
145 | arr[0] : new Tensor(cell.inputShape(), false); 146 | 147 | nextCellError = arr[1]; 148 | } 149 | 150 | changeCount++; 151 | 152 | return TensorUtils.stack(prevLayerError); 153 | } 154 | 155 | @Override 156 | public void update(Optimizer optimizer, Regularizer regularizer){ 157 | cell.update(optimizer, regularizer, changeCount); 158 | changeCount = 0; 159 | } 160 | 161 | // reset the previous state that is saved if this model is stateful 162 | public void resetState(){ 163 | layerPrevState = null; 164 | layerPrevStateTemp = null; 165 | } 166 | 167 | @Override 168 | public int byteSize(){ 169 | return cell.byteSize(); 170 | } 171 | 172 | @Override 173 | public ByteBuffer bytes(){ 174 | ByteBuffer bb = ByteBuffer.allocate(byteSize()); 175 | bb.put(cell.bytes()); 176 | bb.flip(); 177 | return bb; 178 | } 179 | 180 | @Override 181 | public void readBytes(ByteBuffer bb){ 182 | cell.readBytes(bb); 183 | } 184 | 185 | @Override 186 | public String toString(){ 187 | return "Recurrent\tCell: " + cell.toString() + "\tInput Shape: " + Arrays.toString(inputShape()) + "\tOutput Shape: " + Arrays.toString(outputShape()); 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /src/javamachinelearning/networks/NeuralNetwork.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.networks; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public interface NeuralNetwork{ 6 | public Tensor[] predict(Tensor[] input); 7 | public Tensor predict(Tensor input); 8 | public int[] inputShape(); 9 | public int[] outputShape(); 10 | public void saveToFile(String path); 11 | public void loadFromFile(String path); 12 | } 13 | -------------------------------------------------------------------------------- /src/javamachinelearning/networks/SequentialNN.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.networks; 2 | 3 | import java.nio.ByteBuffer; 4 | import java.nio.file.Files; 5 | import java.nio.file.Paths; 6 | import java.util.ArrayList; 7 | 8 | import javamachinelearning.layers.Layer; 9 | import javamachinelearning.layers.ParamsLayer; 10 | import javamachinelearning.layers.recurrent.RecurrentLayer; 11 | import javamachinelearning.optimizers.Optimizer; 12 | import javamachinelearning.regularizers.Regularizer; 13 | import javamachinelearning.utils.Loss; 14 | import javamachinelearning.utils.Tensor; 15 | import javamachinelearning.utils.Utils; 16 | 17 | public class SequentialNN implements NeuralNetwork, SupervisedNeuralNetwork{ 18 | private ArrayList layers = new ArrayList(); 19 | private int[] inputShape; 20 | 21 | public SequentialNN(int... inputShape){ 22 | if(inputShape.length > 1) 23 | this.inputShape = inputShape; 24 | else 25 | this.inputShape = new int[]{1, inputShape[0]}; 26 | } 27 | 28 | public int size(){ 29 | return layers.size(); 30 | } 31 | 32 | public Layer layer(int idx){ 33 | return layers.get(idx); 34 | } 35 | 36 | public void add(Layer l){ 37 | l.init(layers.isEmpty() ? 
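For context on this add() method: each layer is initialized with the previous layer's output shape (or with the network's input shape for the first layer), so a model is defined purely as a sequence of add() calls. A typical construction, taken from the Categories2Graph test near the end of this dump:

SequentialNN net = new SequentialNN(2);           // input shape {1, 2}
net.add(new FCLayer(4));                          // initialized with the network's input shape
net.add(new ActivationLayer(Activation.sigmoid)); // initialized with the FCLayer's output shape
net.add(new FCLayer(1));
net.add(new ActivationLayer(Activation.sigmoid));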
inputShape : layers.get(layers.size() - 1).outputShape()); 38 | layers.add(l); 39 | } 40 | 41 | @Override 42 | public Tensor[] predict(Tensor[] input){ 43 | Tensor[] res = new Tensor[input.length]; 44 | 45 | for(int i = 0; i < input.length; i++){ 46 | res[i] = predict(input[i]); 47 | } 48 | 49 | return res; 50 | } 51 | 52 | @Override 53 | public Tensor predict(Tensor input){ 54 | for(int i = 0; i < layers.size(); i++){ 55 | input = layers.get(i).forwardPropagate(input, false); 56 | } 57 | return input; 58 | } 59 | 60 | // predict using a specific number of time steps 61 | public Tensor predict(Tensor input, int timeSteps){ 62 | for(int i = 0; i < layers.size(); i++){ 63 | if(layers.get(i) instanceof RecurrentLayer) 64 | input = ((RecurrentLayer)layers.get(i)).forwardPropagate(input, timeSteps, false); 65 | else 66 | input = layers.get(i).forwardPropagate(input, false); 67 | } 68 | return input; 69 | } 70 | 71 | // predictTrain should only be used for training! 72 | // it saves the outputs for each layer 73 | public Tensor[] predictTrain(Tensor input){ 74 | Tensor[] res = new Tensor[layers.size() + 1]; 75 | res[0] = input; 76 | for(int i = 1; i < layers.size() + 1; i++){ 77 | input = layers.get(i - 1).forwardPropagate(input, true); 78 | res[i] = input; 79 | } 80 | return res; 81 | } 82 | 83 | @Override 84 | public int[] inputShape(){ 85 | return inputShape; 86 | } 87 | 88 | @Override 89 | public int[] outputShape(){ 90 | return layers.get(layers.size() - 1).outputShape(); 91 | } 92 | 93 | @Override 94 | public void train(Tensor[] input, Tensor[] target, int epochs, int batchSize, Loss loss, Optimizer optimizer, Regularizer regularizer, boolean shuffle){ 95 | train(input, target, epochs, batchSize, loss, optimizer, regularizer, shuffle, false, null); 96 | } 97 | 98 | @Override 99 | public void train(Tensor[] input, Tensor[] target, int epochs, int batchSize, Loss loss, Optimizer optimizer, Regularizer regularizer, boolean shuffle, boolean verbose){ 100 | train(input, target, epochs, batchSize, loss, optimizer, regularizer, shuffle, verbose, null); 101 | } 102 | 103 | @Override 104 | public void train(Tensor[] inputParam, Tensor[] targetParam, int epochs, int batchSize, Loss loss, Optimizer optimizer, Regularizer regularizer, boolean shuffle, boolean verbose, ProgressFunction f){ 105 | // make sure shuffling does not affect the input data 106 | Tensor[] input = inputParam.clone(); 107 | Tensor[] target = targetParam.clone(); 108 | 109 | for(int i = 0; i < epochs; i++){ 110 | double totalLoss = 0.0; 111 | 112 | if(verbose && (i == epochs - 1 || (epochs < 10 ? 0 : (i % (epochs / 10))) == 0)){ 113 | System.out.println(Utils.makeStr('=', 75)); 114 | System.out.println("Epoch " + i); 115 | System.out.println(); 116 | } 117 | 118 | if(shuffle) 119 | Utils.shuffle(input, target); 120 | 121 | for(int j = 0; j < input.length; j++){ 122 | Tensor[] res = predictTrain(input[j]); 123 | 124 | totalLoss += loss.loss(res[res.length - 1], target[j]).reduce(0, (a, b) -> a + b); 125 | 126 | if(verbose && ((i == epochs - 1 || (epochs < 10 ? 0 : (i % (epochs / 10))) == 0) && (input.length < 10 ? 
0 : (j % (input.length / 10))) == 0)){ 127 | System.out.print("Input: "); 128 | System.out.println(input[j]); 129 | System.out.print("Output: "); 130 | System.out.println(res[res.length - 1]); 131 | System.out.print("Target: "); 132 | System.out.println(target[j]); 133 | System.out.println(); 134 | } 135 | 136 | // calculate derivative of the loss function and backpropagate 137 | Tensor lossDerivative = loss.derivative(res[res.length - 1], target[j]); 138 | backPropagate(res, lossDerivative); 139 | 140 | // update weights and biases if batch size is reached 141 | if((j + 1) % batchSize == 0 || j == input.length - 1){ 142 | for(int k = 0; k < layers.size(); k++){ 143 | if(layers.get(k) instanceof ParamsLayer) 144 | ((ParamsLayer)layers.get(k)).update(optimizer, regularizer); 145 | } 146 | 147 | optimizer.update(); 148 | } 149 | } 150 | 151 | if(i == epochs - 1 || (epochs < 10 ? 0 : (i % (epochs / 10))) == 0){ 152 | if(verbose){ 153 | System.out.println(); 154 | }else{ 155 | System.out.print("Epoch " + i + "\t"); 156 | } 157 | System.out.println("Average Loss: " + Utils.format(totalLoss / input.length)); 158 | } 159 | if(verbose && (i == epochs - 1 || (epochs < 10 ? 0 : (i % (epochs / 10))) == 0)){ 160 | System.out.println(Utils.makeStr('=', 75)); 161 | } 162 | 163 | if(f != null) 164 | f.apply(i, totalLoss / input.length); 165 | } 166 | } 167 | 168 | public void backPropagate(Tensor[] result, Tensor error){ 169 | for(int i = layers.size() - 1; i >= 0; i--){ 170 | error = layers.get(i).backPropagate(result[i], result[i + 1], error); 171 | } 172 | } 173 | 174 | // resets the saved states of stateful recurrent layers 175 | public void resetStates(){ 176 | for(int i = 0; i < layers.size(); i++){ 177 | if(layers.get(i) instanceof RecurrentLayer){ 178 | ((RecurrentLayer)layers.get(i)).resetState(); 179 | } 180 | } 181 | } 182 | 183 | @Override 184 | public String toString(){ 185 | StringBuilder b = new StringBuilder(); 186 | b.append("Sequential Neural Network\n"); 187 | b.append(Utils.makeStr('-', 75) + "\n"); 188 | for(int i = 0; i < layers.size(); i++){ 189 | b.append("\n" + layers.get(i).toString()); 190 | b.append("\n\n" + Utils.makeStr('-', 75) + "\n"); 191 | } 192 | return b.toString(); 193 | } 194 | 195 | @Override 196 | public void saveToFile(String path){ 197 | int totalLayerSize = 0; 198 | for(int i = 0; i < layers.size(); i++){ 199 | if(layers.get(i) instanceof ParamsLayer) 200 | totalLayerSize += ((ParamsLayer)layers.get(i)).byteSize(); 201 | } 202 | ByteBuffer bb = ByteBuffer.allocate(totalLayerSize); 203 | for(int i = 0; i < layers.size(); i++){ 204 | if(layers.get(i) instanceof ParamsLayer) 205 | bb.put(((ParamsLayer)layers.get(i)).bytes()); 206 | } 207 | bb.flip(); 208 | try{ 209 | Files.write(Paths.get(path), bb.array()); 210 | }catch(Exception e){ 211 | e.printStackTrace(); 212 | } 213 | } 214 | 215 | @Override 216 | public void loadFromFile(String path){ 217 | byte[] bytes = null; 218 | try{ 219 | bytes = Files.readAllBytes(Paths.get(path)); 220 | }catch(Exception e){ 221 | e.printStackTrace(); 222 | } 223 | ByteBuffer bb = ByteBuffer.wrap(bytes); 224 | for(int i = 0; i < layers.size(); i++){ 225 | if(layers.get(i) instanceof ParamsLayer) 226 | ((ParamsLayer)layers.get(i)).readBytes(bb); 227 | } 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /src/javamachinelearning/networks/SupervisedNeuralNetwork.java: -------------------------------------------------------------------------------- 1 | package 
javamachinelearning.networks;
2 | 
3 | import javamachinelearning.optimizers.Optimizer;
4 | import javamachinelearning.regularizers.Regularizer;
5 | import javamachinelearning.utils.Loss;
6 | import javamachinelearning.utils.Tensor;
7 | 
8 | public interface SupervisedNeuralNetwork{
9 | public void train(Tensor[] input, Tensor[] target, int epochs, int batchSize, Loss loss, Optimizer optimizer, Regularizer regularizer, boolean shuffle);
10 | public void train(Tensor[] input, Tensor[] target, int epochs, int batchSize, Loss loss, Optimizer optimizer, Regularizer regularizer, boolean shuffle, boolean verbose);
11 | public void train(Tensor[] input, Tensor[] target, int epochs, int batchSize, Loss loss, Optimizer optimizer, Regularizer regularizer, boolean shuffle, boolean verbose, ProgressFunction f);
12 | 
13 | // callback function to check the progress of training
14 | public interface ProgressFunction{
15 | public void apply(int epoch, double loss);
16 | }
17 | }
18 | 
-------------------------------------------------------------------------------- /src/javamachinelearning/optimizers/AdaDeltaOptimizer.java: --------------------------------------------------------------------------------
1 | package javamachinelearning.optimizers;
2 | 
3 | import javamachinelearning.utils.Tensor;
4 | 
5 | public class AdaDeltaOptimizer implements Optimizer{
6 | private static final double epsilon = 0.00000001;
7 | private double rho;
8 | private double learnRate; // fixed scale on the adaptive step; standard AdaDelta uses 1.0
9 | 
10 | public AdaDeltaOptimizer() {
11 | this.rho = 0.95;
12 | this.learnRate = 1.0;
13 | }
14 | 
15 | public AdaDeltaOptimizer(double learnRate) {
16 | this.rho = 0.95;
17 | this.learnRate = learnRate;
18 | }
19 | 
20 | @Override
21 | public void update() {
22 | // nothing to do
23 | }
24 | 
25 | @Override
26 | public int extraParams() {
27 | return 2;
28 | }
29 | 
30 | @Override
31 | public Tensor optimize(Tensor grads, Tensor[] params) {
32 | params[0] = params[0].mul(rho).add((grads.mul(grads)).mul(1.0-rho)); // parameter 1: exponential average of squared gradients
33 | Tensor t = grads.mul(params[1].map(x -> Math.sqrt(x)).add(epsilon)).div(params[0].map(x -> Math.sqrt(x)).add(epsilon)).mul(learnRate); // RMS of past updates over RMS of past gradients
34 | params[1] = params[1].mul(rho).add(t.mul(t).mul(1.0-rho)); // parameter 2: exponential average of squared updates
35 | return t;
36 | }
37 | }
38 | 
-------------------------------------------------------------------------------- /src/javamachinelearning/optimizers/AdagradOptimizer.java: --------------------------------------------------------------------------------
1 | package javamachinelearning.optimizers;
2 | 
3 | import javamachinelearning.utils.Tensor;
4 | 
5 | public class AdagradOptimizer implements Optimizer{
6 | private static final double epsilon = 0.00000001;
7 | private double learnRate;
8 | 
9 | public AdagradOptimizer(){
10 | this.learnRate = 0.1;
11 | }
12 | 
13 | public AdagradOptimizer(double learnRate){
14 | this.learnRate = learnRate;
15 | }
16 | 
17 | @Override
18 | public int extraParams(){
19 | return 1;
20 | }
21 | 
22 | @Override
23 | public void update(){
24 | // nothing to do
25 | }
26 | 
27 | @Override
28 | public Tensor optimize(Tensor grads, Tensor[] params){
29 | // parameter: sum of squared gradients
30 | params[0] = params[0].add(grads.mul(grads));
31 | return grads.mul(learnRate).div(params[0].map(x -> Math.sqrt(x)).add(epsilon));
32 | }
33 | }
34 | 
-------------------------------------------------------------------------------- /src/javamachinelearning/optimizers/AdamOptimizer.java: --------------------------------------------------------------------------------
1 | package javamachinelearning.optimizers;
2 | 
3 | import javamachinelearning.utils.Tensor;
4 | 
5
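The three train() overloads in SupervisedNeuralNetwork above differ only in the optional verbosity flag and the ProgressFunction; the callback fires once per epoch with the epoch index and the average loss. A call with a progress callback, adapted from the ErrorGraphCrossEntropy test later in this dump (net, x, and y are assumed to be an existing model and data set):

net.train(x, y, 100, 1, Loss.binaryCrossEntropy, new SGDOptimizer(1), null, false, false,
		(epoch, loss) -> System.out.println("epoch " + epoch + ", avg loss " + loss));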
| public class AdamOptimizer implements Optimizer{ 6 | private static final double epsilon = 0.00000001; 7 | private double learnRate; 8 | private double beta1; 9 | private double beta2; 10 | 11 | private double currBeta1; // these biases are changed while optimizing 12 | private double currBeta2; 13 | 14 | public AdamOptimizer(){ 15 | this.learnRate = 0.001; 16 | this.beta1 = 0.9; 17 | this.beta2 = 0.999; 18 | 19 | currBeta1 = this.beta1; 20 | currBeta2 = this.beta2; 21 | } 22 | 23 | public AdamOptimizer(double learnRate){ 24 | this.learnRate = learnRate; 25 | this.beta1 = 0.9; 26 | this.beta2 = 0.999; 27 | 28 | currBeta1 = this.beta1; 29 | currBeta2 = this.beta2; 30 | } 31 | 32 | public AdamOptimizer(double learnRate, double beta1, double beta2){ 33 | this.learnRate = learnRate; 34 | this.beta1 = beta1; 35 | this.beta2 = beta2; 36 | 37 | currBeta1 = this.beta1; 38 | currBeta2 = this.beta2; 39 | } 40 | 41 | @Override 42 | public int extraParams(){ 43 | return 2; 44 | } 45 | 46 | @Override 47 | public void update(){ 48 | currBeta1 *= beta1; 49 | currBeta2 *= beta2; 50 | } 51 | 52 | @Override 53 | public Tensor optimize(Tensor grads, Tensor[] params){ 54 | // parameter 1: momentum 55 | // parameter 2: velocity 56 | params[0] = params[0].mul(beta1).add(grads.mul(1.0 - beta1)); 57 | params[1] = params[1].mul(beta2).add(grads.mul(grads).mul(1.0 - beta2)); 58 | return params[0].div(1.0 - currBeta1).div( 59 | params[1].div(1.0 - currBeta2).map(x -> Math.sqrt(x)).add(epsilon)).mul(learnRate); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/javamachinelearning/optimizers/MomentumOptimizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.optimizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public class MomentumOptimizer implements Optimizer{ 6 | private double learnRate; 7 | private double mu; // friction to decay momentum 8 | 9 | public MomentumOptimizer(){ 10 | this.learnRate = 0.1; 11 | this.mu = 0.9; 12 | } 13 | 14 | public MomentumOptimizer(double learnRate){ 15 | this.learnRate = learnRate; 16 | this.mu = 0.9; 17 | } 18 | 19 | public MomentumOptimizer(double learnRate, double mu){ 20 | this.learnRate = learnRate; 21 | this.mu = mu; 22 | } 23 | 24 | @Override 25 | public int extraParams(){ 26 | return 1; 27 | } 28 | 29 | @Override 30 | public void update(){ 31 | // nothing to do 32 | } 33 | 34 | @Override 35 | public Tensor optimize(Tensor grads, Tensor[] params){ 36 | // parameter: velocity 37 | Tensor prev = params[0]; 38 | params[0] = params[0].mul(mu).sub(grads.mul(learnRate)); 39 | return prev.mul(mu).sub(params[0].mul(1.0 + mu)); // is negated 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/javamachinelearning/optimizers/NAGOptimizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.optimizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public class NAGOptimizer implements Optimizer{ 6 | private double learnRate; 7 | private double mu; // friction to decay momentum 8 | 9 | public NAGOptimizer(){ 10 | this.learnRate = 0.1; 11 | this.mu = 0.9; 12 | } 13 | 14 | public NAGOptimizer(double learnRate){ 15 | this.learnRate = learnRate; 16 | this.mu = 0.9; 17 | } 18 | 19 | public NAGOptimizer(double learnRate, double mu){ 20 | this.learnRate = learnRate; 21 | this.mu = mu; 22 | } 23 | 24 | @Override 25 | public int 
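A worked check of the bias correction in AdamOptimizer above: update() runs after each optimization step, so at step k the divisor is 1 - beta1^k. On the very first step the momentum is params[0] = (1 - beta1) * grads; with beta1 = 0.9 that is only 0.1 * grads, a tenfold underestimate caused by the zero initialization, and dividing by 1 - currBeta1 = 1 - 0.9 = 0.1 recovers grads exactly. The same argument applies to the velocity term through currBeta2, which is why both currBeta factors must decay together as training proceeds.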
extraParams(){ 26 | return 1; 27 | } 28 | 29 | @Override 30 | public void update(){ 31 | // nothing to do 32 | } 33 | 34 | @Override 35 | public Tensor optimize(Tensor grads, Tensor[] params){ 36 | // parameter: velocity 37 | params[0] = params[0].mul(mu).sub(grads.mul(learnRate)); 38 | return params[0].mul(-1.0); // is negated 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/javamachinelearning/optimizers/Optimizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.optimizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public interface Optimizer{ 6 | // called every training iteration, after optimizing weights/biases 7 | public void update(); 8 | 9 | // how many extra parameters per weight/bias 10 | public int extraParams(); 11 | 12 | // some optimizers might modify the extra params! 13 | public Tensor optimize(Tensor grads, Tensor[] params); 14 | } 15 | -------------------------------------------------------------------------------- /src/javamachinelearning/optimizers/RMSPropOptimizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.optimizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public class RMSPropOptimizer implements Optimizer{ 6 | private static final double epsilon = 0.00000001; 7 | private double learnRate; 8 | private double mu; 9 | 10 | public RMSPropOptimizer(){ 11 | this.mu = 0.9; 12 | this.learnRate = 0.1; 13 | } 14 | 15 | public RMSPropOptimizer(double learnRate){ 16 | this.mu = 0.9; 17 | this.learnRate = learnRate; 18 | } 19 | 20 | @Override 21 | public int extraParams(){ 22 | return 1; 23 | } 24 | 25 | @Override 26 | public void update(){ 27 | // nothing to do 28 | } 29 | @Override 30 | public Tensor optimize(Tensor grads, Tensor[] params){ 31 | // parameter: exponential average of squared gradients 32 | params[0] = params[0].mul(mu).add((grads.mul(grads)).mul(1.0-mu)); 33 | return grads.mul(learnRate).div(params[0].map(x -> Math.sqrt(x)).add(epsilon)); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/javamachinelearning/optimizers/SGDOptimizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.optimizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public class SGDOptimizer implements Optimizer{ 6 | private double learnRate; 7 | 8 | public SGDOptimizer(){ 9 | this.learnRate = 0.01; 10 | } 11 | 12 | public SGDOptimizer(double learnRate){ 13 | this.learnRate = learnRate; 14 | } 15 | 16 | @Override 17 | public int extraParams(){ 18 | return 0; 19 | } 20 | 21 | @Override 22 | public void update(){ 23 | // nothing to do 24 | } 25 | 26 | @Override 27 | public Tensor optimize(Tensor grads, Tensor[] params){ 28 | return grads.mul(learnRate); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/javamachinelearning/regularizers/ElasticNetRegularizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.regularizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public class ElasticNetRegularizer implements Regularizer{ 6 | private double lambdaL1; 7 | private double lambdaL2; 8 | 9 | public ElasticNetRegularizer(){ 10 | this.lambdaL1 = 0.001; 11 | this.lambdaL2 = 0.001; 12 | } 13 | 14 | public 
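Tying the Optimizer interface together: the caller owns extraParams() persistent state tensors per weight tensor, hands them to optimize() every step, and applies the returned tensor by subtracting it from the weights; update() then advances any per-step state (only Adam uses it here). A sketch of the calling side under those assumptions (the weights and grads variables are illustrative, and the subtraction convention is inferred from the "is negated" comments above, not from code shown in this section):

Optimizer opt = new AdamOptimizer(0.001);
Tensor[] state = new Tensor[opt.extraParams()];
for(int i = 0; i < state.length; i++)
	state[i] = new Tensor(weights.shape(), false); // zero-initialized per-weight state
// once per training step:
weights = weights.sub(opt.optimize(grads, state)); // optimize() may also mutate state[]
opt.update(); // advance per-step quantities such as Adam's bias-correction factors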
ElasticNetRegularizer(double lambdaL1, double lambdaL2){ 15 | this.lambdaL1 = lambdaL1; 16 | this.lambdaL2 = lambdaL2; 17 | } 18 | 19 | @Override 20 | public Tensor derivative(Tensor w){ 21 | return w.map(x -> lambdaL1 * Math.signum(x) + lambdaL2 * x); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/javamachinelearning/regularizers/L1Regularizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.regularizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public class L1Regularizer implements Regularizer{ 6 | private double lambda; 7 | 8 | public L1Regularizer(){ 9 | this.lambda = 0.01; 10 | } 11 | 12 | public L1Regularizer(double lambda){ 13 | this.lambda = lambda; 14 | } 15 | 16 | @Override 17 | public Tensor derivative(Tensor w){ 18 | return w.map(x -> lambda * Math.signum(x)); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/javamachinelearning/regularizers/L2Regularizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.regularizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public class L2Regularizer implements Regularizer{ 6 | private double lambda; 7 | 8 | public L2Regularizer(){ 9 | this.lambda = 0.01; 10 | } 11 | 12 | public L2Regularizer(double lambda){ 13 | this.lambda = lambda; 14 | } 15 | 16 | @Override 17 | public Tensor derivative(Tensor w){ 18 | return w.mul(lambda); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/javamachinelearning/regularizers/Regularizer.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.regularizers; 2 | 3 | import javamachinelearning.utils.Tensor; 4 | 5 | public interface Regularizer{ 6 | // no need to actually compute the regularization 7 | public Tensor derivative(Tensor w); 8 | } 9 | -------------------------------------------------------------------------------- /src/javamachinelearning/utils/Activation.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.utils; 2 | 3 | public interface Activation{ 4 | public static final Activation linear = new Activation(){ 5 | @Override 6 | public Tensor activate(Tensor t){ 7 | return t; 8 | } 9 | 10 | @Override 11 | public Tensor derivative(Tensor t){ 12 | return new Tensor(t.shape(), 1.0); 13 | } 14 | 15 | @Override 16 | public String toString(){ 17 | return "Linear"; 18 | } 19 | }; 20 | 21 | public static final Activation sigmoid = new Activation(){ 22 | @Override 23 | public Tensor activate(Tensor t){ 24 | return t.map(x -> 1.0 / (1.0 + Math.exp(-x))); 25 | } 26 | 27 | @Override 28 | public Tensor derivative(Tensor t){ 29 | return t.map(x -> x * (1.0 - x)); 30 | } 31 | 32 | @Override 33 | public String toString(){ 34 | return "Sigmoid"; 35 | } 36 | }; 37 | 38 | // linear approximation of sigmoid 39 | public static final Activation hardSigmoid = new Activation(){ 40 | @Override 41 | public Tensor activate(Tensor t){ 42 | return t.map(x -> Math.min(Math.max(x * 0.2 + 0.5, 0.0), 1.0)); 43 | } 44 | 45 | @Override 46 | public Tensor derivative(Tensor t){ 47 | return t.map(x -> x > 0.0 && x < 1.0 ? 
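On the Regularizer interface above: only the derivative of the penalty is exposed because training never needs the penalty's value, just its gradient added to the loss gradient of each weight. Under that reading, a layer's update step would combine the pieces roughly as follows (a sketch only; the actual ParamsLayer implementations are not shown in this section):

Regularizer reg = new L2Regularizer(0.01);
// for L2, derivative(w) is lambda * w, which pulls every weight toward zero
Tensor totalGrad = grads.add(reg.derivative(weights));
weights = weights.sub(optimizer.optimize(totalGrad, state));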
0.2 : 0.0); 48 | } 49 | 50 | @Override 51 | public String toString(){ 52 | return "Hard Sigmoid"; 53 | } 54 | }; 55 | 56 | public static final Activation tanh = new Activation(){ 57 | @Override 58 | public Tensor activate(Tensor t){ 59 | return t.map(x -> 2.0 / (1.0 + Math.exp(-2.0 * x)) - 1.0); 60 | } 61 | 62 | @Override 63 | public Tensor derivative(Tensor t){ 64 | return t.map(x -> 1.0 - x * x); 65 | } 66 | 67 | @Override 68 | public String toString(){ 69 | return "Hyperbolic Tangent"; 70 | } 71 | }; 72 | 73 | public static final Activation relu = new Activation(){ 74 | @Override 75 | public Tensor activate(Tensor t){ 76 | return t.map(x -> Math.max(0.0, x)); 77 | } 78 | 79 | @Override 80 | public Tensor derivative(Tensor t){ 81 | return t.map(x -> x > 0.0 ? 1.0 : 0.0); 82 | } 83 | 84 | @Override 85 | public String toString(){ 86 | return "Rectified Linear Unit"; 87 | } 88 | }; 89 | 90 | public static final Activation leakyRelu = new Activation(){ 91 | @Override 92 | public Tensor activate(Tensor t){ 93 | // note: hard coded leaky value! 94 | return t.map(x -> x > 0.0 ? x : x * 0.01); 95 | } 96 | 97 | @Override 98 | public Tensor derivative(Tensor t){ 99 | return t.map(x -> x > 0.0 ? 1.0 : 0.01); 100 | } 101 | 102 | @Override 103 | public String toString(){ 104 | return "Leaky Rectified Linear Unit"; 105 | } 106 | }; 107 | 108 | public static final Activation relu6 = new Activation(){ 109 | @Override 110 | public Tensor activate(Tensor t){ 111 | return t.map(x -> Math.min(Math.max(0.0, x), 6.0)); 112 | } 113 | 114 | @Override 115 | public Tensor derivative(Tensor t){ 116 | return t.map(x -> (x > 0.0) && (x < 6.0) ? 1.0 : 0.0); 117 | } 118 | 119 | @Override 120 | public String toString(){ 121 | return "Rectified Linear Unit 6"; 122 | } 123 | }; 124 | 125 | public static final Activation relu3 = new Activation(){ 126 | @Override 127 | public Tensor activate(Tensor t){ 128 | return t.map(x -> Math.min(Math.max(0.0, x), 3.0)); 129 | } 130 | 131 | @Override 132 | public Tensor derivative(Tensor t){ 133 | return t.map(x -> (x > 0.0) && (x < 3.0) ? 1.0 : 0.0); 134 | } 135 | 136 | @Override 137 | public String toString(){ 138 | return "Rectified Linear Unit 3"; 139 | } 140 | }; 141 | 142 | public static final Activation elu = new Activation(){ 143 | double alpha = 1.0; 144 | @Override 145 | public Tensor activate(Tensor t){ 146 | return t.map(x -> Math.max(alpha * (Math.exp(x)-1.0), x)); 147 | } 148 | 149 | @Override 150 | public Tensor derivative(Tensor t){ 151 | return t.map(x -> x > 0 ? 1.0 : alpha * Math.exp(x)); 152 | } 153 | 154 | @Override 155 | public String toString(){ 156 | return "Exponential Linear Unit"; 157 | } 158 | }; 159 | 160 | public static final Activation selu = new Activation(){ 161 | double alpha = 1.6732632423543772848170429916717; 162 | double scale = 1.0507009873554804934193349852946; 163 | @Override 164 | public Tensor activate(Tensor t){ 165 | return t.map(x -> Math.max(scale * alpha * (Math.exp(x)-1.0), x)); 166 | } 167 | 168 | @Override 169 | public Tensor derivative(Tensor t){ 170 | return t.map(x -> x > 0 ? 
1.0 : scale * alpha * Math.exp(x));
171 | }
172 | 
173 | @Override
174 | public String toString(){
175 | return "Scaled Exponential Linear Unit";
176 | }
177 | };
178 | 
179 | public static final Activation softmax = new Activation(){
180 | @Override
181 | public Tensor activate(Tensor t){
182 | Tensor max = t.reduceLast(Double.NEGATIVE_INFINITY, (a, b) -> Math.max(a, b)); // Double.MIN_VALUE is the smallest positive double, so it is not a safe identity for max
183 | max = max.dupLast(t.shape()[t.shape().length - 1]);
184 | 
185 | Tensor exp = t.sub(max).map(x -> Math.exp(x)); // subtracting the max keeps exp() from overflowing
186 | 
187 | Tensor sum = exp.reduceLast(0, (a, b) -> a + b);
188 | sum = sum.dupLast(t.shape()[t.shape().length - 1]);
189 | 
190 | return exp.div(sum);
191 | }
192 | 
193 | @Override
194 | public Tensor derivative(Tensor t){
195 | // because the loss function should be cross entropy
196 | return new Tensor(t.shape(), 1.0);
197 | }
198 | 
199 | @Override
200 | public String toString(){
201 | return "Softmax";
202 | }
203 | };
204 | 
205 | public Tensor activate(Tensor t);
206 | // derivatives are calculated in terms of the activated output
207 | public Tensor derivative(Tensor t);
208 | }
209 | 
-------------------------------------------------------------------------------- /src/javamachinelearning/utils/ImageUtils.java: --------------------------------------------------------------------------------
1 | package javamachinelearning.utils;
2 | 
3 | import javax.imageio.ImageIO;
4 | import java.io.File;
5 | import java.io.IOException;
6 | import java.awt.image.BufferedImage;
7 | 
8 | public class ImageUtils {
9 | private BufferedImage bImage = null;
10 | 
11 | public ImageUtils() {
12 | 
13 | }
14 | 
15 | /*
16 | Read one color image file and return it as a 3D int array
17 | input:
18 | path : Path of image file
19 | output:
20 | 3D int array [height][width][3]
21 | */
22 | public int[][][] readColorImageFile(String path) {
23 | File f;
24 | try {
25 | f = new File(path);
26 | bImage = ImageIO.read(f);
27 | } catch(IOException e) {
28 | System.out.println("Exception occurred : " + e.getMessage());
29 | }
30 | int width = bImage.getWidth();
31 | int height = bImage.getHeight();
32 | int[][][] data = new int[height][width][3];
33 | for( int y=0 ; y<height ; y++ ) {
34 | for( int x=0 ; x<width ; x++ ) {
35 | int p = bImage.getRGB(x, y);
36 | 
37 | int a = (p>>24)&0xff;
38 | int r = (p>>16)&0xff;
39 | int g = (p>>8)&0xff;
40 | int b = p&0xff;
41 | 
42 | data[y][x][0] = r;
43 | data[y][x][1] = g;
44 | data[y][x][2] = b;
45 | }
46 | }
47 | return data;
48 | }
49 | 
50 | /*
51 | Convert RGB to Gray
52 | */
53 | public int[][] convertRGBtoGray(int[][][] colorImg) {
54 | int height = colorImg.length;
55 | int width = colorImg[0].length;
56 | int[][] grayImg = new int[height][width];
57 | 
58 | for( int y=0 ; y<height ; y++ ) {
59 | for( int x=0 ; x<width ; x++ ) {
60 | int r = colorImg[y][x][0];
61 | int g = colorImg[y][x][1];
62 | int b = colorImg[y][x][2];
63 | // Reference = http://entropymine.com/imageworsener/grayscale/
64 | grayImg[y][x] = (int)(0.2126*r + 0.7152*g + 0.0722*b);
65 | }
66 | }
67 | return grayImg;
68 | }
69 | 
70 | /*
71 | Read one image file and return it as a Tensor
72 | (reconstructed from its call site in readImages; the original body was lost from this dump)
73 | input:
74 | path : Path of image file
75 | convertGray : true to convert the image to grayscale first
76 | output:
77 | Tensor of the image data
78 | */
79 | public Tensor readColorImageToTensor(String path, boolean convertGray) {
80 | int[][][] colorImg = readColorImageFile(path);
81 | if( convertGray )
82 | return new Tensor(convertRGBtoGray(colorImg));
83 | return new Tensor(colorImg);
84 | }
85 | 
86 | /*
87 | Convert the image file at the given path to grayscale and save it as test.bmp
88 | (method header reconstructed; the original was lost from this dump)
89 | */
90 | public void convertGrayToImageFile(String path) {
91 | File f;
92 | try {
93 | f = new File(path);
94 | bImage = ImageIO.read(f);
95 | } catch(IOException e) {
96 | System.out.println("Exception occurred : " + e.getMessage());
97 | }
98 | 
99 | int width = bImage.getWidth();
100 | int height = bImage.getHeight();
101 | 
102 | for( int y=0 ; y<height ; y++ ) {
103 | for( int x=0 ; x<width ; x++ ) {
104 | int p = bImage.getRGB(x, y);
105 | 
106 | // split the pixel into its channels, convert the
107 | // color to gray, then pack the channels back together
108 | 
109 | int a = (p>>24)&0xff;
110 | int r = (p>>16)&0xff;
111 | int g = (p>>8)&0xff;
112 | int b = p&0xff;
113 | 
114 | // Reference = http://entropymine.com/imageworsener/grayscale/
115 | int gray = (int)(0.2126*r + 0.7152*g + 0.0722*b);
116 | 
117 | p = (a<<24) | (gray<<16) | (gray<<8) | gray;
118 | bImage.setRGB(x, y, p);
119 | }
120 | }
121 | 
122 | try{
123 | f = new File("test.bmp");
124 | ImageIO.write(bImage, "bmp", f);
125 | } catch( IOException e ) {
126 | System.out.println("Exception occurred :" + e.getMessage());
127 | }
128 | }
129 | 
130 | /*
131 | Read many images (in a folder) and return them as an array of Tensors
132 | input:
133 | path : Image folder path
134 | output:
135 | Array of Tensors, one per image
136 | */
137 | public Tensor[] readImages(String folderPath, boolean convertGray) {
138 | File folder = new File(folderPath);
139 | File[] listOfFiles = folder.listFiles();
140 | Tensor[] tensors = new Tensor[listOfFiles.length];
141 | for (int i = 0; i <
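The luminosity weights in ImageUtils above (0.2126, 0.7152, 0.0722) favor green because the eye is most sensitive to it. A quick worked check of the formula on two extreme pixels:

public class GraySample {
	public static void main(String[] args) {
		System.out.println((int)(0.2126*255 + 0.7152*0 + 0.0722*0)); // pure red prints 54, a dark gray
		System.out.println((int)(0.2126*0 + 0.7152*255 + 0.0722*0)); // pure green prints 182, much brighter
	}
}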
listOfFiles.length; i++) { 142 | tensors[i] = readColorImageToTensor(folderPath + listOfFiles[i].getName(), convertGray); 143 | } 144 | return tensors; 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/javamachinelearning/utils/Loss.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.utils; 2 | 3 | public interface Loss{ 4 | public static final Loss squared = new Loss(){ 5 | @Override 6 | public Tensor loss(Tensor x, Tensor t){ 7 | return x.sub(t).map(val -> val * val); 8 | } 9 | 10 | @Override 11 | public Tensor derivative(Tensor x, Tensor t){ 12 | return x.sub(t); 13 | } 14 | 15 | @Override 16 | public String toString(){ 17 | return "Squared"; 18 | } 19 | }; 20 | 21 | // multi-class classification 22 | public static final Loss softmaxCrossEntropy = new Loss(){ 23 | @Override 24 | public Tensor loss(Tensor x, Tensor t){ 25 | // because the target is a one hot vector 26 | return t.mul(x.map(val -> -Math.log(val))); 27 | } 28 | 29 | @Override 30 | public Tensor derivative(Tensor x, Tensor t){ 31 | // because the output layer has to be softmax 32 | return x.sub(t); 33 | } 34 | 35 | @Override 36 | public String toString(){ 37 | return "Softmax Cross Entropy"; 38 | } 39 | }; 40 | 41 | // binary classification 42 | public static final Loss binaryCrossEntropy = new Loss(){ 43 | @Override 44 | public Tensor loss(Tensor x, Tensor t){ 45 | // because the target is 0 or 1 46 | Tensor a = t.mul(x.map(val -> -Math.log(val))); 47 | Tensor b = t.map(val -> 1.0 - val).mul(x.map(val -> -Math.log(1.0 - val))); 48 | return a.add(b); 49 | } 50 | 51 | @Override 52 | public Tensor derivative(Tensor x, Tensor t){ 53 | // if output layer is sigmoid, the denominator cancels out 54 | return x.sub(t).div(x.map(val -> val * (1.0 - val))); 55 | } 56 | 57 | @Override 58 | public String toString(){ 59 | return "Binary Cross Entropy"; 60 | } 61 | }; 62 | 63 | public Tensor loss(Tensor x, Tensor t); 64 | public Tensor derivative(Tensor x, Tensor t); 65 | } 66 | -------------------------------------------------------------------------------- /src/javamachinelearning/utils/MNISTUtils.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.utils; 2 | 3 | import java.awt.image.BufferedImage; 4 | import java.io.File; 5 | import java.nio.ByteBuffer; 6 | import java.nio.ByteOrder; 7 | import java.nio.file.Files; 8 | import java.nio.file.Paths; 9 | 10 | import javax.imageio.ImageIO; 11 | 12 | public class MNISTUtils{ 13 | public static Tensor[] loadDataSetImages(String path, int num){ 14 | byte[] bytes = null; 15 | try{ 16 | bytes = Files.readAllBytes(Paths.get(path)); 17 | }catch(Exception e){ 18 | e.printStackTrace(); 19 | } 20 | ByteBuffer bb = ByteBuffer.wrap(bytes); 21 | bb.order(ByteOrder.BIG_ENDIAN); 22 | int count = bb.getInt(4); 23 | int row = bb.getInt(8); 24 | int col = bb.getInt(12); 25 | bb.position(16); 26 | Tensor[] res = new Tensor[Math.min(count, num)]; 27 | for(int i = 0; i < Math.min(count, num); i++){ 28 | double[] curr = new double[row * col]; 29 | for(int j = 0; j < row * col; j++){ 30 | curr[j] = Utils.unsignedByteToInt(bb.get()) / 255.0; 31 | } 32 | res[i] = new Tensor(curr).reshape(row, col).T(); 33 | } 34 | return res; 35 | } 36 | 37 | public static Tensor[] loadDataSetLabels(String path, int num){ 38 | byte[] bytes = null; 39 | try{ 40 | bytes = Files.readAllBytes(Paths.get(path)); 41 | }catch(Exception e){ 42 | 
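One line of algebra behind the comment that the denominator cancels out: binaryCrossEntropy.derivative returns (x - t) / (x(1 - x)), and the sigmoid activation's derivative (defined earlier in terms of its output) is x(1 - x). Backpropagation multiplies the two, so the gradient reaching the pre-activation is exactly x - t, the same clean signal the softmax and softmaxCrossEntropy pairing produces; in that pairing the bookkeeping is split the other way, with softmax's derivative returning all ones and the loss derivative already being x - t.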
e.printStackTrace(); 43 | } 44 | ByteBuffer bb = ByteBuffer.wrap(bytes); 45 | bb.order(ByteOrder.BIG_ENDIAN); 46 | int count = bb.getInt(4); 47 | bb.position(8); 48 | Tensor[] res = new Tensor[Math.min(count, num)]; 49 | for(int i = 0; i < Math.min(count, num); i++){ 50 | res[i] = TensorUtils.oneHot(Utils.unsignedByteToInt(bb.get()), 10); 51 | } 52 | return res; 53 | } 54 | 55 | public static Tensor loadImage(String path, int width, int height){ 56 | BufferedImage image = null; 57 | try{ 58 | image = ImageIO.read(new File(path)); 59 | }catch(Exception e){ 60 | e.printStackTrace(); 61 | } 62 | return loadImage(image, width, height); 63 | } 64 | 65 | public static Tensor loadImage(BufferedImage image, int width, int height){ 66 | double[][] arr = new double[image.getHeight()][image.getWidth()]; 67 | for(int i = 0; i < image.getHeight(); i++){ 68 | for(int j = 0; j < image.getWidth(); j++){ 69 | arr[i][j] = 1.0 - ((image.getRGB(j, i) & 0xFF) / 255.0); 70 | } 71 | } 72 | 73 | return Utils.centerData(arr, width, height); 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/javamachinelearning/utils/Tensor.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.utils; 2 | 3 | import java.util.Random; 4 | 5 | public class Tensor{ 6 | private int[] shape; 7 | private double[] data; 8 | 9 | private int[] mult; 10 | private int size; 11 | 12 | public Tensor(int[] shape, boolean rand){ 13 | this.shape = shape; 14 | 15 | calcMult(); 16 | 17 | data = new double[size]; 18 | if(rand){ 19 | Random r = new Random(); 20 | // for initializing weights 21 | int sum = 0; 22 | for(int i = 0; i < shape.length; i++){ 23 | sum += shape[i]; 24 | } 25 | for(int i = 0; i < size; i++){ 26 | // xavier normal initialization (not truncated) 27 | data[i] = r.nextGaussian() * Math.sqrt(2.0 / sum); 28 | } 29 | }else{ 30 | for(int i = 0; i < size; i++){ 31 | data[i] = 0; 32 | } 33 | } 34 | } 35 | 36 | public Tensor(int[] shape, double init){ 37 | this.shape = shape; 38 | 39 | calcMult(); 40 | 41 | data = new double[size]; 42 | for(int i = 0; i < size; i++){ 43 | data[i] = init; 44 | } 45 | } 46 | 47 | // will create a column vector 48 | public Tensor(double[] d){ 49 | shape = new int[]{1, d.length}; 50 | calcMult(); 51 | data = new double[size]; 52 | for(int i = 0; i < d.length; i++){ 53 | data[i] = d[i]; 54 | } 55 | } 56 | 57 | // note that the following two initializers work in row major format! 58 | // however, the data is internally represented as column major, so some swaps happen 59 | public Tensor(double[][] d){ 60 | shape = new int[]{d[0].length, d.length}; 61 | calcMult(); 62 | data = new double[size]; 63 | int idx = 0; 64 | for(int i = 0; i < d[0].length; i++){ 65 | for(int j = 0; j < d.length; j++){ 66 | data[idx] = d[j][i]; 67 | idx++; 68 | } 69 | } 70 | } 71 | 72 | // the first dimension is treated as the depth! 
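In other words, for the constructor that follows, a double[depth][rows][cols] array produces a Tensor whose shape lists the dimensions in reverse:

double[][][] d = new double[2][3][4]; // depth 2, 3 rows, 4 columns
Tensor t = new Tensor(d);             // t.shape() is {4, 3, 2}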
73 | public Tensor(double[][][] d){ 74 | shape = new int[]{d[0][0].length, d[0].length, d.length}; 75 | calcMult(); 76 | data = new double[size]; 77 | int idx = 0; 78 | for(int i = 0; i < d[0][0].length; i++){ 79 | for(int j = 0; j < d[0].length; j++){ 80 | for(int k = 0; k < d.length; k++){ 81 | data[idx] = d[k][j][i]; 82 | idx++; 83 | } 84 | } 85 | } 86 | } 87 | 88 | public Tensor(int[] shape, double[] data){ 89 | this.shape = shape; 90 | calcMult(); 91 | this.data = data; 92 | } 93 | 94 | // Convert int Data to double Data 95 | public Tensor(int[][] intD) { 96 | shape = new int[]{intD[0].length, intD.length}; 97 | calcMult(); 98 | data = new double[size]; 99 | int idx = 0; 100 | for(int i = 0; i < intD[0].length; i++){ 101 | for(int j = 0; j < intD.length; j++){ 102 | data[idx] = (double)intD[j][i]; 103 | idx++; 104 | } 105 | } 106 | } 107 | 108 | public Tensor(int[][][] intD) { 109 | shape = new int[]{intD[0][0].length, intD[0].length, intD.length}; 110 | calcMult(); 111 | data = new double[size]; 112 | int idx = 0; 113 | for(int i = 0; i < intD[0][0].length; i++){ 114 | for(int j = 0; j < intD[0].length; j++){ 115 | for(int k = 0; k < intD.length; k++){ 116 | data[idx] = (double)intD[k][j][i]; 117 | idx++; 118 | } 119 | } 120 | } 121 | } 122 | 123 | public int[] shape(){ 124 | return shape; 125 | } 126 | 127 | public int[] mult(){ 128 | return mult; 129 | } 130 | 131 | public int size(){ 132 | return size; 133 | } 134 | 135 | public Tensor add(Tensor o){ 136 | double[] res = new double[size]; 137 | for(int i = 0; i < size; i++){ 138 | res[i] = data[i] + o.data[i]; 139 | } 140 | return new Tensor(shape, res); 141 | } 142 | 143 | public Tensor add(double d){ 144 | double[] res = new double[size]; 145 | for(int i = 0; i < size; i++){ 146 | res[i] = data[i] + d; 147 | } 148 | return new Tensor(shape, res); 149 | } 150 | 151 | public Tensor sub(Tensor o){ 152 | double[] res = new double[size]; 153 | for(int i = 0; i < size; i++){ 154 | res[i] = data[i] - o.data[i]; 155 | } 156 | return new Tensor(shape, res); 157 | } 158 | 159 | public Tensor sub(double d){ 160 | double[] res = new double[size]; 161 | for(int i = 0; i < size; i++){ 162 | res[i] = data[i] - d; 163 | } 164 | return new Tensor(shape, res); 165 | } 166 | 167 | public Tensor mul(Tensor o){ 168 | double[] res = new double[size]; 169 | for(int i = 0; i < size; i++){ 170 | res[i] = data[i] * o.data[i]; 171 | } 172 | return new Tensor(shape, res); 173 | } 174 | 175 | public Tensor mul(double d){ 176 | double[] res = new double[size]; 177 | for(int i = 0; i < size; i++){ 178 | res[i] = data[i] * d; 179 | } 180 | return new Tensor(shape, res); 181 | } 182 | 183 | public Tensor div(Tensor o){ 184 | double[] res = new double[size]; 185 | for(int i = 0; i < size; i++){ 186 | res[i] = data[i] / o.data[i]; 187 | } 188 | return new Tensor(shape, res); 189 | } 190 | 191 | public Tensor div(double d){ 192 | double[] res = new double[size]; 193 | for(int i = 0; i < size; i++){ 194 | res[i] = data[i] / d; 195 | } 196 | return new Tensor(shape, res); 197 | } 198 | 199 | public Tensor dot(Tensor o){ 200 | // basically matrix multiply 201 | // both must be 2D matrices 202 | 203 | double[] res = new double[shape[1] * o.shape[0]]; 204 | int idx = 0; 205 | 206 | for(int i = 0; i < shape[1]; i++){ 207 | for(int j = 0; j < o.shape[0]; j++){ 208 | for(int k = 0; k < shape[0]; k++){ 209 | res[idx] += data[k * shape[1] + i] * o.data[j * o.shape[1] + k]; 210 | } 211 | idx++; 212 | } 213 | } 214 | 215 | // transpose because the array is column-wise, not 
row-wise 216 | return new Tensor(new int[]{shape[1], o.shape[0]}, res).T(); 217 | } 218 | 219 | public Tensor T(){ // transposes 2D matrix 220 | if(shape.length < 2) 221 | return this; 222 | 223 | double[] res = new double[size]; 224 | int idx = 0; 225 | for(int i = 0; i < shape[1]; i++){ 226 | for(int j = 0; j < shape[0]; j++){ 227 | res[idx] = data[j * shape[1] + i]; 228 | idx++; 229 | } 230 | } 231 | return new Tensor(new int[]{shape[1], shape[0]}, res); 232 | } 233 | 234 | public Tensor flatten(){ 235 | return new Tensor(new int[]{1, size}, data); 236 | } 237 | 238 | public Tensor reshape(int... s){ 239 | return new Tensor(s, data); 240 | } 241 | 242 | public Tensor map(Function f){ 243 | double[] res = new double[size]; 244 | for(int i = 0; i < size; i++){ 245 | res[i] = f.apply(data[i]); 246 | } 247 | return new Tensor(shape, res); 248 | } 249 | 250 | public double reduce(double init, Function2 f){ 251 | double res = init; 252 | for(int i = 0; i < size; i++){ 253 | res = f.apply(res, data[i]); 254 | } 255 | return res; 256 | } 257 | 258 | // reduce only the last dimension 259 | public Tensor reduceLast(double init, Function2 f){ 260 | double[] res = new double[size / shape[shape.length - 1]]; 261 | for(int i = 0; i < res.length; i++){ 262 | res[i] = init; 263 | } 264 | 265 | for(int i = 0; i < size; i++){ 266 | int idx = i / shape[shape.length - 1]; 267 | res[idx] = f.apply(res[idx], data[i]); 268 | } 269 | 270 | int[] newShape; 271 | if(shape.length == 2){ 272 | newShape = new int[]{1, shape[0]}; 273 | }else{ 274 | newShape = new int[shape.length - 1]; 275 | for(int i = 0; i < shape.length - 1; i++){ 276 | newShape[i] = shape[i]; 277 | } 278 | } 279 | return new Tensor(newShape, res); 280 | } 281 | 282 | // duplicate along last dimension + 1 283 | public Tensor dupLast(int length){ 284 | double[] res = new double[size * length]; 285 | for(int i = 0; i < res.length; i++){ 286 | res[i] = data[i / length]; 287 | } 288 | 289 | int[] newShape; 290 | if(shape[0] == 1 && shape.length == 2){ 291 | newShape = new int[]{shape[1], length}; 292 | }else{ 293 | newShape = new int[shape.length + 1]; 294 | for(int i = 0; i < shape.length; i++){ 295 | newShape[i] = shape[i]; 296 | } 297 | newShape[shape.length] = length; 298 | } 299 | 300 | return new Tensor(newShape, res); 301 | } 302 | 303 | // stack copies of the tensor 304 | public Tensor dupFirst(int length){ 305 | double[] res = new double[length * size]; 306 | for(int i = 0; i < res.length; i++){ 307 | int idx = i % size; 308 | res[i] = data[idx]; 309 | } 310 | 311 | int[] newShape; 312 | if(shape[0] == 1 && shape.length == 2){ 313 | newShape = new int[]{length, shape[1]}; 314 | }else{ 315 | newShape = new int[shape.length + 1]; 316 | for(int i = 0; i < shape.length; i++){ 317 | newShape[i + 1] = shape[i]; 318 | } 319 | newShape[0] = length; 320 | } 321 | 322 | return new Tensor(newShape, res); 323 | } 324 | 325 | public double flatGet(int idx){ 326 | return data[idx]; 327 | } 328 | 329 | public Tensor get(int idx){ 330 | double[] res = new double[mult[0]]; 331 | for(int i = 0; i < mult[0]; i++){ 332 | res[i] = data[idx * mult[0] + i]; 333 | } 334 | 335 | int[] newShape; 336 | if(shape.length == 2){ 337 | newShape = new int[]{1, shape[1]}; 338 | }else{ 339 | newShape = new int[shape.length - 1]; 340 | for(int i = 1; i < shape.length; i++){ 341 | newShape[i - 1] = shape[i]; 342 | } 343 | } 344 | 345 | return new Tensor(newShape, res); 346 | } 347 | 348 | public interface Function{ 349 | public double apply(double x); 350 | } 351 | 352 | public 
interface Function2{
353 | public double apply(double a, double b);
354 | }
355 | 
356 | private void calcMult(){
357 | mult = new int[shape.length];
358 | mult[shape.length - 1] = 1;
359 | size = shape[shape.length - 1];
360 | for(int i = shape.length - 2; i >= 0; i--){
361 | mult[i] = mult[i + 1] * shape[i + 1];
362 | size *= shape[i];
363 | }
364 | }
365 | 
366 | @Override
367 | public Tensor clone(){
368 | return new Tensor(shape, data);
369 | }
370 | 
371 | // toString returns a string that is in column major format!
372 | @Override
373 | public String toString(){
374 | return str(0, size, 0);
375 | }
376 | 
377 | private String str(int start, int end, int depth){
378 | if(depth >= shape.length - 1){
379 | StringBuilder b = new StringBuilder();
380 | b.append('[');
381 | for(int i = start; i < end; i += mult[depth]){
382 | b.append(Utils.format(data[i]) + ", ");
383 | }
384 | if(b.length() > 1)
385 | b.delete(b.length() - 2, b.length());
386 | b.append(']');
387 | return b.toString();
388 | }
389 | 
390 | StringBuilder b = new StringBuilder();
391 | b.append('[');
392 | for(int i = start; i < end; i += mult[depth]){
393 | b.append(str(i, i + mult[depth], depth + 1) + ",\n");
394 | if(depth < shape.length - 2)
395 | b.append('\n');
396 | }
397 | if(b.length() > 1)
398 | b.delete(b.length() - 2 - (depth < shape.length - 2 ? 1 : 0), b.length());
399 | b.append(']');
400 | return b.toString();
401 | }
402 | }
403 | 
-------------------------------------------------------------------------------- /src/javamachinelearning/utils/TensorUtils.java: --------------------------------------------------------------------------------
1 | package javamachinelearning.utils;
2 | 
3 | public class TensorUtils{
4 | // simple way of creating column vectors
5 | public static Tensor t(double... vals){
6 | return new Tensor(new int[]{1, vals.length}, vals);
7 | }
8 | 
9 | public static Tensor oneHot(int idx, int size){
10 | double[] res = new double[size];
11 | res[idx] = 1.0;
12 | return new Tensor(res);
13 | }
14 | 
15 | public static Tensor oneHotString(String s, String alphabet){
16 | Tensor[] res = new Tensor[s.length()];
17 | for(int i = 0; i < s.length(); i++){
18 | res[i] = oneHot(alphabet.indexOf(s.charAt(i)), alphabet.length());
19 | }
20 | return stack(res);
21 | }
22 | 
23 | // decode a one hot string
24 | public static String decodeString(Tensor val, boolean rand, String alphabet){
25 | char[] res = new char[val.shape()[0]];
26 | for(int i = 0; i < val.shape()[0]; i++){
27 | res[i] = alphabet.charAt(rand ? randProb(val.get(i)) : argMax(val.get(i)));
28 | }
29 | return new String(res);
30 | }
31 | 
32 | // pick random with probabilities
33 | public static int randProb(Tensor tensor){
34 | double[] pre = new double[tensor.size()];
35 | for(int i = 0; i < pre.length; i++){
36 | pre[i] = tensor.flatGet(i) + (i == 0 ? 0 : pre[i - 1]);
37 | }
38 | 
39 | double rand = Math.random() * pre[pre.length - 1];
40 | for(int i = 0; i < pre.length; i++){
41 | if(pre[i] >= rand)
42 | return i;
43 | }
44 | return -1;
45 | }
46 | 
47 | public static int argMax(Tensor tensor){
48 | double max = Double.NEGATIVE_INFINITY; // Double.MIN_VALUE is the smallest positive double, which would make all-negative tensors return -1
49 | int maxIndex = -1;
50 | for(int i = 0; i < tensor.size(); i++){
51 | if(tensor.flatGet(i) > max){
52 | max = tensor.flatGet(i);
53 | maxIndex = i;
54 | }
55 | }
56 | return maxIndex;
57 | }
58 | 
59 | public static Tensor stack(Tensor...
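A concrete picture of the one-hot helpers above: each character becomes a one-hot vector indexed by its position in the alphabet, and stack() (defined next) joins them along a new first dimension:

Tensor t = TensorUtils.oneHotString("ba", "abc");
// shape {2, 3}: row 0 is [0, 1, 0] for 'b', row 1 is [1, 0, 0] for 'a'
String s = TensorUtils.decodeString(t, false, "abc"); // "ba", via argMax per row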
tensors){ 60 | int[] shape; 61 | if(tensors[0].shape()[0] == 1 && tensors[0].shape().length == 2){ 62 | shape = new int[]{tensors.length, tensors[0].shape()[1]}; 63 | }else{ 64 | shape = new int[tensors[0].shape().length + 1]; 65 | shape[0] = tensors.length; 66 | 67 | for(int i = 0; i < tensors[0].shape().length; i++){ 68 | shape[i + 1] = tensors[0].shape()[i]; 69 | } 70 | } 71 | 72 | double[] res = new double[tensors[0].size() * tensors.length]; 73 | int idx = 0; 74 | for(int i = 0; i < tensors.length; i++){ 75 | for(int j = 0; j < tensors[i].size(); j++){ 76 | res[idx] = tensors[i].flatGet(j); 77 | idx++; 78 | } 79 | } 80 | return new Tensor(shape, res); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/javamachinelearning/utils/Utils.java: -------------------------------------------------------------------------------- 1 | package javamachinelearning.utils; 2 | 3 | import static javamachinelearning.utils.TensorUtils.argMax; 4 | import static javamachinelearning.utils.TensorUtils.t; 5 | 6 | import java.util.Arrays; 7 | import java.util.Random; 8 | 9 | public class Utils{ 10 | public static String format(double num){ 11 | return String.format("%,.7g", num); 12 | } 13 | 14 | public static String shorterFormat(double num){ 15 | return String.format("%,.2g", num); 16 | } 17 | 18 | public static String formatElapsedTime(long ms){ 19 | return String.format("%02d:%02d:%02d.%03d", ms / (3600 * 1000), ms / (60 * 1000) % 60, ms / 1000 % 60, ms % 1000); 20 | } 21 | 22 | public static void printArray(double[][] arr){ 23 | for(int i = 0; i < arr.length; i++){ 24 | for(int j = 0; j < arr[i].length; j++){ 25 | System.out.print(format(arr[i][j]) + " "); 26 | } 27 | System.out.println(); 28 | } 29 | } 30 | 31 | public static void printArray(double[] arr){ 32 | for(int i = 0; i < arr.length; i++){ 33 | System.out.print(format(arr[i]) + " "); 34 | } 35 | System.out.println(); 36 | } 37 | 38 | public static void printArray(double[] arr, int count){ 39 | for(int i = 0; i < count; i++){ 40 | System.out.print(format(arr[i]) + " "); 41 | } 42 | System.out.println(); 43 | } 44 | 45 | public static String pad(String s, int length, char c){ 46 | return s + makeStr(c, length - s.length()); 47 | } 48 | 49 | public static String makeStr(char c, int n){ 50 | char[] result = new char[n]; 51 | Arrays.fill(result, c); 52 | return new String(result); 53 | } 54 | 55 | public static int unsignedByteToInt(byte b){ 56 | int result = 0; 57 | for(int i = 0; i < 8; i++){ 58 | if((b & (1 << i)) != 0) 59 | result += 1 << i; 60 | } 61 | return result; 62 | } 63 | 64 | public static void printImage(Tensor[] image){ 65 | char[][][] chars = new char[image.length][0][0]; 66 | for(int i = 0; i < image.length; i++){ 67 | int s1 = image[i].shape()[0]; 68 | int s2 = image[i].shape()[1]; 69 | chars[i] = new char[s2][s1]; 70 | for(int j = 0; j < s1; j++){ 71 | for(int k = 0; k < s2; k++){ 72 | if(image[i].flatGet(j * s2 + k) < 0.3){ 73 | chars[i][k][j] = ' '; 74 | }else if(image[i].flatGet(j * s2 + k) > 0.6){ 75 | chars[i][k][j] = '#'; 76 | }else{ 77 | chars[i][k][j] = '.'; 78 | } 79 | } 80 | } 81 | } 82 | for(int i = 0; i < chars.length; i++){ 83 | for(int j = 0; j < chars[i].length; j++){ 84 | for(int k = 0; k < chars[i][j].length; k++){ 85 | System.out.print(chars[i][j][k] + " "); 86 | } 87 | System.out.println(); 88 | } 89 | System.out.println(); 90 | } 91 | } 92 | 93 | public static double classificationAccuracy(Tensor[] output, Tensor[] target){ 94 | int totalCorrect = 0; 95 | 
for(int i = 0; i < output.length; i++){ 96 | if(argMax(output[i]) == argMax(target[i])){ 97 | totalCorrect++; 98 | } 99 | } 100 | return (double)totalCorrect / (double)output.length; 101 | } 102 | 103 | public static int[] centerOfMass(double[][] arr){ 104 | double sumX = 0.0; 105 | double sumY = 0.0; 106 | double total = 0.0; 107 | for(int i = 0; i < arr.length; i++){ 108 | for(int j = 0; j < arr[i].length; j++){ 109 | if(arr[i][j] > 0.0){ 110 | sumX += i * arr[i][j]; 111 | sumY += j * arr[i][j]; 112 | total += arr[i][j]; 113 | } 114 | } 115 | } 116 | return new int[]{(int)(sumY / total), (int)(sumX / total)}; 117 | } 118 | 119 | public static double[] flatCombine(Tensor... tensors){ 120 | int sum = 0; 121 | for(int i = 0; i < tensors.length; i++){ 122 | sum += tensors[i].size(); 123 | } 124 | 125 | double[] res = new double[sum]; 126 | int idx = 0; 127 | for(int i = 0; i < tensors.length; i++){ 128 | for(int j = 0; j < tensors[i].size(); j++){ 129 | res[idx] = tensors[i].flatGet(j); 130 | idx++; 131 | } 132 | } 133 | return res; 134 | } 135 | 136 | public static Tensor[] flattenAll(Tensor[] tensors){ 137 | Tensor[] res = new Tensor[tensors.length]; 138 | for(int i = 0; i < tensors.length; i++){ 139 | res[i] = tensors[i].flatten(); 140 | } 141 | return res; 142 | } 143 | 144 | public static Tensor[] reshapeAll(Tensor[] tensors, int... shape){ 145 | Tensor[] res = new Tensor[tensors.length]; 146 | for(int i = 0; i < tensors.length; i++){ 147 | res[i] = tensors[i].reshape(shape); 148 | } 149 | return res; 150 | } 151 | 152 | public static Tensor centerData(double[][] arr, int width, int height){ 153 | double[][] centeredArr = new double[height][width]; 154 | int[] centerOfMass = centerOfMass(arr); 155 | for(int i = 0; i < arr.length; i++){ 156 | for(int j = 0; j < arr[i].length; j++){ 157 | if(arr[i][j] > 0.0){ 158 | int y = (height - arr.length) / 2 + i + arr.length / 2 - centerOfMass[0]; 159 | int x = (width - arr[i].length) / 2 + j + arr[i].length / 2 - centerOfMass[1]; 160 | if(y >= 0 && y < height && x >= 0 && x < width) 161 | centeredArr[y][x] = arr[i][j]; 162 | } 163 | } 164 | } 165 | return new Tensor(centeredArr); 166 | } 167 | 168 | public static Tensor[] standardDist(double x, double y, double s, int n){ 169 | Tensor[] res = new Tensor[n]; 170 | Random r = new Random(); 171 | for(int i = 0; i < n; i++){ 172 | res[i] = t(x + r.nextGaussian() * s, y + r.nextGaussian() * s); 173 | } 174 | return res; 175 | } 176 | 177 | public static Tensor[] concat(Tensor[]... 
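classificationAccuracy above compares argMax of each output against argMax of each target, so it assumes one-hot (or at least max-coded) target tensors. Typical evaluation after training, with net, testImages, and testLabels standing in for an existing model and data set (the names are illustrative):

Tensor[] outputs = net.predict(testImages);
double acc = Utils.classificationAccuracy(outputs, testLabels);
System.out.println("accuracy: " + Utils.format(acc));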
tensors){ 178 | int sum = 0; 179 | for(int i = 0; i < tensors.length; i++){ 180 | sum += tensors[i].length; 181 | } 182 | Tensor[] res = new Tensor[sum]; 183 | int idx = 0; 184 | for(int i = 0; i < tensors.length; i++){ 185 | for(int j = 0; j < tensors[i].length; j++){ 186 | res[idx] = tensors[i][j]; 187 | idx++; 188 | } 189 | } 190 | return res; 191 | } 192 | 193 | public static void shuffle(Tensor[] x, Tensor[] y){ 194 | Random r = new Random(); 195 | for(int i = x.length - 1; i > 0; i--){ 196 | int j = r.nextInt(i + 1); 197 | Tensor xTemp = x[i]; 198 | Tensor yTemp = y[i]; 199 | x[i] = x[j]; 200 | y[i] = y[j]; 201 | x[j] = xTemp; 202 | y[j] = yTemp; 203 | } 204 | } 205 | 206 | // slides a window across a string and returns all of the substrings covered by the window 207 | public static String[] slide(String s, int winSize){ 208 | String[] res = new String[s.length() - winSize + 1]; 209 | for(int i = 0; i < res.length; i++){ 210 | res[i] = s.substring(i, i + winSize); 211 | } 212 | return res; 213 | } 214 | 215 | // remove all characters not in the alphabet 216 | public static String onlyKeepAlphabetChars(String s, String alphabet){ 217 | StringBuilder res = new StringBuilder(); 218 | for(int i = 0; i < s.length(); i++){ 219 | if(alphabet.indexOf(s.charAt(i)) != -1){ 220 | res.append(s.charAt(i)); 221 | } 222 | } 223 | return res.toString(); 224 | } 225 | } 226 | -------------------------------------------------------------------------------- /src/tests/Categories2Graph.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import static javamachinelearning.utils.TensorUtils.t; 4 | 5 | import java.awt.Color; 6 | 7 | import javax.swing.JFrame; 8 | 9 | import javamachinelearning.graphs.Graph; 10 | import javamachinelearning.graphs.GraphPanel; 11 | import javamachinelearning.layers.feedforward.ActivationLayer; 12 | import javamachinelearning.layers.feedforward.FCLayer; 13 | import javamachinelearning.networks.SequentialNN; 14 | import javamachinelearning.optimizers.AdamOptimizer; 15 | import javamachinelearning.utils.Activation; 16 | import javamachinelearning.utils.Loss; 17 | import javamachinelearning.utils.Tensor; 18 | import javamachinelearning.utils.Utils; 19 | 20 | public class Categories2Graph{ 21 | public static void main(String[] args){ 22 | SequentialNN net = new SequentialNN(2); 23 | net.add(new FCLayer(4)); 24 | net.add(new ActivationLayer(Activation.sigmoid)); 25 | net.add(new FCLayer(1)); 26 | net.add(new ActivationLayer(Activation.sigmoid)); 27 | 28 | Tensor[] x = Utils.concat(Utils.standardDist(0, 0, 0.1, 100), Utils.standardDist(0, 0.5, 0.1, 100), 29 | Utils.standardDist(0.5, 0, 0.1, 100), Utils.standardDist(0.5, 0.5, 0.1, 100)); 30 | 31 | Tensor[] y1 = new Tensor[100]; 32 | for(int i = 0; i < y1.length; i++){ 33 | y1[i] = t(0); 34 | } 35 | Tensor[] y2 = new Tensor[100]; 36 | for(int i = 0; i < y2.length; i++){ 37 | y2[i] = t(1); 38 | } 39 | Tensor[] y3 = new Tensor[100]; 40 | for(int i = 0; i < y3.length; i++){ 41 | y3[i] = t(1); 42 | } 43 | Tensor[] y4 = new Tensor[100]; 44 | for(int i = 0; i < y4.length; i++){ 45 | y4[i] = t(0); 46 | } 47 | Tensor[] y = Utils.concat(y1, y2, y3, y4); 48 | 49 | net.train(x, y, 1000, 10, Loss.binaryCrossEntropy, new AdamOptimizer(0.01), null, true, true); 50 | 51 | double[] xData = new double[x.length]; 52 | double[] yData = new double[x.length]; 53 | Color[] cData = new Color[x.length]; 54 | Color[] intToColor1 = {Color.blue, Color.red}; 55 | for(int i = 0; i < x.length; i++){ 56 | 
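As a quick check of Utils.slide defined above, the window advances one character at a time, so a string of length 5 with window 3 yields 5 - 3 + 1 = 3 substrings; this is a windowing helper suited to cutting fixed-length training substrings out of longer text:

String[] w = Utils.slide("hello", 3); // {"hel", "ell", "llo"}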
xData[i] = x[i].flatGet(0); 57 | yData[i] = x[i].flatGet(1); 58 | cData[i] = intToColor1[(int)y[i].flatGet(0)]; 59 | } 60 | 61 | JFrame frame = new JFrame(); 62 | 63 | Graph graph = new Graph(1000, 1000, xData, yData, cData, (x5, y5) -> { 64 | return intToColor1[(int)Math.round(net.predict(t(x5, y5)).flatGet(0))]; 65 | }); 66 | graph.draw(); 67 | frame.add(new GraphPanel(graph)); 68 | 69 | frame.setSize(1200, 1200); 70 | frame.setLocationRelativeTo(null); 71 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 72 | frame.setVisible(true); 73 | 74 | graph.saveToFile("classification_example3.png", "png"); 75 | graph.dispose(); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/tests/Categories4Graph.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import static javamachinelearning.utils.TensorUtils.argMax; 4 | import static javamachinelearning.utils.TensorUtils.oneHot; 5 | import static javamachinelearning.utils.TensorUtils.t; 6 | 7 | import java.awt.Color; 8 | 9 | import javax.swing.JFrame; 10 | 11 | import javamachinelearning.graphs.Graph; 12 | import javamachinelearning.graphs.GraphPanel; 13 | import javamachinelearning.layers.feedforward.ActivationLayer; 14 | import javamachinelearning.layers.feedforward.FCLayer; 15 | import javamachinelearning.networks.SequentialNN; 16 | import javamachinelearning.optimizers.SGDOptimizer; 17 | import javamachinelearning.regularizers.L2Regularizer; 18 | import javamachinelearning.utils.Activation; 19 | import javamachinelearning.utils.Loss; 20 | import javamachinelearning.utils.Tensor; 21 | import javamachinelearning.utils.Utils; 22 | 23 | public class Categories4Graph{ 24 | public static void main(String[] args){ 25 | SequentialNN net = new SequentialNN(2); 26 | net.add(new FCLayer(3)); 27 | net.add(new ActivationLayer(Activation.sigmoid)); 28 | net.add(new FCLayer(4)); 29 | net.add(new ActivationLayer(Activation.softmax)); 30 | 31 | Tensor[] x = Utils.concat(Utils.standardDist(0, 0, 0.1, 100), Utils.standardDist(0, 1, 0.1, 100), Utils.standardDist(1, 0, 0.1, 100), Utils.standardDist(1, 1, 0.1, 100)); 32 | Tensor[] y1 = new Tensor[100]; 33 | for(int i = 0; i < y1.length; i++){ 34 | y1[i] = oneHot(0, 4); 35 | } 36 | Tensor[] y2 = new Tensor[100]; 37 | for(int i = 0; i < y2.length; i++){ 38 | y2[i] = oneHot(1, 4); 39 | } 40 | Tensor[] y3 = new Tensor[100]; 41 | for(int i = 0; i < y3.length; i++){ 42 | y3[i] = oneHot(2, 4); 43 | } 44 | Tensor[] y4 = new Tensor[100]; 45 | for(int i = 0; i < y4.length; i++){ 46 | y4[i] = oneHot(3, 4); 47 | } 48 | Tensor[] y = Utils.concat(y1, y2, y3, y4); 49 | net.train(x, y, 100, 10, Loss.softmaxCrossEntropy, new SGDOptimizer(1), new L2Regularizer(0.01), true, true); 50 | 51 | double[] xData = new double[x.length]; 52 | double[] yData = new double[x.length]; 53 | Color[] cData = new Color[x.length]; 54 | Color[] intToColor1 = {Color.blue, Color.red, Color.yellow, Color.green}; 55 | for(int i = 0; i < x.length; i++){ 56 | xData[i] = x[i].flatGet(0); 57 | yData[i] = x[i].flatGet(1); 58 | cData[i] = intToColor1[argMax(y[i])]; 59 | } 60 | 61 | JFrame frame = new JFrame(); 62 | 63 | Graph graph = new Graph(1000, 1000, xData, yData, cData, (x5, y5) -> { 64 | return intToColor1[argMax(net.predict(t(x5, y5)))]; 65 | }); 66 | graph.draw(); 67 | frame.add(new GraphPanel(graph)); 68 | 69 | frame.setSize(1200, 1200); 70 | frame.setLocationRelativeTo(null); 71 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 
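// the color lambda handed to the Graph above is evaluated per plotted position, so the
// saved image shades the plane by whichever class the trained network predicts there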
72 | frame.setVisible(true); 73 | 74 | graph.saveToFile("classification_example4.png", "png"); 75 | graph.dispose(); 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/tests/ErrorGraphCrossEntropy.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import static javamachinelearning.utils.TensorUtils.t; 4 | 5 | import java.awt.Color; 6 | 7 | import javax.swing.JFrame; 8 | 9 | import javamachinelearning.graphs.Graph; 10 | import javamachinelearning.graphs.GraphPanel; 11 | import javamachinelearning.layers.feedforward.ActivationLayer; 12 | import javamachinelearning.layers.feedforward.FCLayer; 13 | import javamachinelearning.layers.feedforward.FeedForwardParamsLayer; 14 | import javamachinelearning.networks.SequentialNN; 15 | import javamachinelearning.optimizers.SGDOptimizer; 16 | import javamachinelearning.utils.Activation; 17 | import javamachinelearning.utils.Loss; 18 | import javamachinelearning.utils.Tensor; 19 | import javamachinelearning.utils.Utils; 20 | 21 | public class ErrorGraphCrossEntropy{ 22 | public static void main(String[] args) throws Exception{ 23 | SequentialNN nn = new SequentialNN(1); 24 | FeedForwardParamsLayer layer = new FCLayer(1).noBias(); 25 | nn.add(layer); 26 | nn.add(new ActivationLayer(Activation.sigmoid)); 27 | 28 | Tensor[] x = Utils.concat(Utils.standardDist(-0.5, -0.5, 0.1, 100), Utils.standardDist(0.5, 0.5, 0.1, 100)); 29 | x = Utils.reshapeAll(x, 1, 1); 30 | 31 | Tensor[] y = new Tensor[x.length]; 32 | 33 | int idx = 0; 34 | for(int i = 0; i < 100; i++){ 35 | y[idx] = t(0); 36 | idx++; 37 | } 38 | for(int i = 0; i < 100; i++){ 39 | y[idx] = t(1); 40 | idx++; 41 | } 42 | 43 | JFrame frame = new JFrame(); 44 | 45 | Graph graph = new Graph(1000, 1000, "Weight", "Error"); 46 | 47 | double rangeStart = 0; 48 | double rangeEnd = 10; 49 | 50 | graph.useCustomScale(rangeStart, rangeEnd, 0, 1); 51 | 52 | int n = 10; 53 | double[] xs = new double[n]; 54 | double[] ys = new double[n]; 55 | for(int i = 0; i < n; i++){ 56 | double xx = i * (rangeEnd - rangeStart) / n + rangeStart; 57 | xs[i] = xx; 58 | layer.setWeights(t(xx)); 59 | ys[i] = Loss.binaryCrossEntropy.loss(nn.predict(t(0.3)), t(1)).reduce(0, (a, b) -> a + b); 60 | } 61 | graph.addLineGraph(xs, ys); 62 | 63 | graph.draw(); 64 | 65 | layer.setWeights(t(0.01)); 66 | 67 | nn.train(x, y, 100, 1, Loss.binaryCrossEntropy, new SGDOptimizer(1), null, false, false, (epoch, loss) -> { 68 | graph.addPoint(layer.weights().flatGet(0), 69 | Loss.binaryCrossEntropy.loss(nn.predict(t(0.3)), t(1)).reduce(0, (a, b) -> a + b), Color.green); 70 | graph.draw(); 71 | }); 72 | 73 | frame.add(new GraphPanel(graph)); 74 | 75 | frame.setSize(1200, 1200); 76 | frame.setLocationRelativeTo(null); 77 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 78 | frame.setVisible(true); 79 | 80 | graph.saveToFile("error_graph_cross_entropy.png", "png"); 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /src/tests/ErrorGraphSquared.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import static javamachinelearning.utils.TensorUtils.t; 4 | 5 | import java.awt.Color; 6 | 7 | import javax.swing.JFrame; 8 | 9 | import javamachinelearning.graphs.Graph; 10 | import javamachinelearning.graphs.GraphPanel; 11 | import javamachinelearning.layers.feedforward.FCLayer; 12 | import 
javamachinelearning.layers.feedforward.FeedForwardParamsLayer; 13 | import javamachinelearning.networks.SequentialNN; 14 | import javamachinelearning.optimizers.SGDOptimizer; 15 | import javamachinelearning.utils.Loss; 16 | import javamachinelearning.utils.Tensor; 17 | 18 | public class ErrorGraphSquared{ 19 | public static void main(String[] args){ 20 | SequentialNN nn = new SequentialNN(1); 21 | FeedForwardParamsLayer layer = new FCLayer(1).noBias(); 22 | nn.add(layer); 23 | 24 | Tensor[] x = { 25 | t(0), 26 | t(1), 27 | t(2), 28 | t(3), 29 | t(4) 30 | }; 31 | 32 | Tensor[] y = { 33 | t(0), 34 | t(5), 35 | t(10), 36 | t(15), 37 | t(20) 38 | }; 39 | 40 | JFrame frame = new JFrame(); 41 | 42 | Graph graph = new Graph(1000, 1000, "Weight", "Error"); 43 | 44 | double rangeStart = 3; 45 | double rangeEnd = 7; 46 | 47 | graph.useCustomScale(rangeStart, rangeEnd, 0, 50); 48 | 49 | int n = 20; 50 | double[] xs = new double[n]; 51 | double[] ys = new double[n]; 52 | for(int i = 0; i < n; i++){ 53 | double xx = i * (rangeEnd - rangeStart) / n + rangeStart; 54 | xs[i] = xx; 55 | layer.setWeights(t(xx)); 56 | // y = 5x 57 | ys[i] = Loss.squared.loss(nn.predict(t(5)), t(5 * 5)).reduce(0, (a, b) -> a + b); 58 | } 59 | graph.addLineGraph(xs, ys); 60 | 61 | graph.draw(); 62 | 63 | layer.setWeights(t(0.01)); 64 | 65 | nn.train(x, y, 100, 1, Loss.squared, new SGDOptimizer(0.01), null, false, false, (epoch, loss) -> { 66 | graph.addPoint(layer.weights().flatGet(0), 67 | Loss.squared.loss(nn.predict(t(5)), t(5 * 5)).reduce(0, (a, b) -> a + b), Color.green); 68 | graph.draw(); 69 | }); 70 | 71 | frame.add(new GraphPanel(graph)); 72 | 73 | frame.setSize(1200, 1200); 74 | frame.setLocationRelativeTo(null); 75 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 76 | frame.setVisible(true); 77 | 78 | graph.saveToFile("error_graph_squared.png", "png"); 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /src/tests/GRUTest.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import java.io.BufferedReader; 4 | import java.io.InputStreamReader; 5 | import java.nio.file.Files; 6 | import java.nio.file.Paths; 7 | import java.util.ArrayList; 8 | 9 | import javamachinelearning.layers.feedforward.ActivationLayer; 10 | import javamachinelearning.layers.feedforward.DropoutLayer; 11 | import javamachinelearning.layers.feedforward.FCLayer; 12 | import javamachinelearning.layers.feedforward.ScalingLayer; 13 | import javamachinelearning.layers.recurrent.GRUCell; 14 | import javamachinelearning.layers.recurrent.RecurrentLayer; 15 | import javamachinelearning.networks.SequentialNN; 16 | import javamachinelearning.optimizers.AdamOptimizer; 17 | import javamachinelearning.utils.Activation; 18 | import javamachinelearning.utils.Loss; 19 | import javamachinelearning.utils.Tensor; 20 | import javamachinelearning.utils.TensorUtils; 21 | import javamachinelearning.utils.Utils; 22 | 23 | public class GRUTest{ 24 | public static void main(String[] args) throws Exception{ 25 | // all of the letters that can appear in the text 26 | String alphabet = "abcdefghijklmnopqrstuvwxyz .,?!\n'()-"; 27 | boolean readFromFile = true; 28 | 29 | // can optionally read a (long) string from a file 30 | String string; 31 | if(readFromFile) 32 | string = new String(Files.readAllBytes(Paths.get("rnn_training_romeo_juliet.txt"))); 33 | else 34 | string = "hello! what is your name? 
i am a recurrent neural network!"; 35 | 36 | // preprocess the string 37 | string = string.toLowerCase(); 38 | string = string.replace("\r\n", "\n"); 39 | // remove characters that are not found in the alphabet 40 | string = Utils.onlyKeepAlphabetChars(string, alphabet); 41 | 42 | int epochs = 500; 43 | int batchSize = 10; 44 | int winSize = 20; 45 | int winStep = 20; // winSize = winStep so substrings are not repeated 46 | int genIter = 5000; // how many characters to generate 47 | double temperature = 0.1; // lower = less randomness 48 | 49 | // pad the string with spaces so its length is one more than a multiple of winStep 50 | string = Utils.pad(string, (int)Math.ceil((double)string.length() / winStep) * winStep + 1, ' '); 51 | 52 | // builds the network 53 | // for each time step, the input is a one-hot vector describing the current character 54 | // for each time step, the output is a one-hot vector describing the next character 55 | // the recurrent layers are stateful, which means that the next state relies on the previous states 56 | SequentialNN nn = new SequentialNN(winSize, alphabet.length()); 57 | nn.add(new RecurrentLayer(winSize, new GRUCell(), true)); 58 | nn.add(new DropoutLayer(0.3)); 59 | nn.add(new RecurrentLayer(winSize, new GRUCell(), true)); 60 | // the same fully connected layer is applied for every single time step 61 | nn.add(new FCLayer(alphabet.length())); 62 | // divides the values by the temperature before softmax 63 | nn.add(new ScalingLayer(1 / temperature, false)); 64 | nn.add(new ActivationLayer(Activation.softmax)); 65 | 66 | // get all substrings 67 | String[] str = Utils.slide(string, winSize); 68 | 69 | // skip some substrings if needed and one-hot encode the strings 70 | ArrayList<Tensor> xArr = new ArrayList<>(); 71 | ArrayList<Tensor> tArr = new ArrayList<>(); 72 | for(int i = 0; i < str.length - 1; i += winStep){ 73 | xArr.add(TensorUtils.oneHotString(str[i], alphabet)); 74 | tArr.add(TensorUtils.oneHotString(str[i + 1], alphabet)); 75 | } 76 | 77 | Tensor[] xs = xArr.toArray(new Tensor[0]); 78 | Tensor[] ts = tArr.toArray(new Tensor[0]); 79 | 80 | nn.train(xs, 81 | ts, 82 | epochs, 83 | batchSize, 84 | Loss.softmaxCrossEntropy, 85 | new AdamOptimizer(0.01), 86 | null, // no regularization 87 | false, // no shuffling! 
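// shuffling stays off deliberately: the stateful GRU layers carry cell state across
// batches, so the training windows must be presented in their original order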
88 | false, 89 | (epoch, error) -> nn.resetStates()); // reset the GRU cell states every epoch 90 | 91 | // reads the seed string 92 | System.out.print("Input seed string: "); 93 | BufferedReader r = new BufferedReader(new InputStreamReader(System.in)); 94 | String seed = r.readLine(); 95 | r.close(); 96 | 97 | StringBuilder gen = new StringBuilder(); 98 | gen.append(seed); 99 | 100 | // warms up the model with the seed string 101 | if(seed.length() > 1){ 102 | Tensor seedInput = TensorUtils.oneHotString(seed.substring(0, seed.length() - 1), alphabet); 103 | nn.predict(seedInput, seed.length() - 1); 104 | } 105 | 106 | // for each iteration, the previous character is plugged in as one time step 107 | // and the next character is predicted 108 | // the previous states persist throughout the entire generation process 109 | for(int i = 0; i < genIter; i++){ 110 | Tensor inputStr = TensorUtils.oneHotString(gen.charAt(gen.length() - 1) + "", alphabet); 111 | String outputStr = TensorUtils.decodeString(nn.predict(inputStr, 1), true, alphabet); 112 | gen.append(outputStr); 113 | } 114 | 115 | System.out.println("Output: " + gen.toString()); 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/tests/LinearGraph.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import static javamachinelearning.utils.TensorUtils.t; 4 | 5 | import javax.swing.JFrame; 6 | 7 | import javamachinelearning.graphs.Graph; 8 | import javamachinelearning.graphs.GraphPanel; 9 | import javamachinelearning.layers.feedforward.FCLayer; 10 | import javamachinelearning.networks.SequentialNN; 11 | import javamachinelearning.optimizers.SGDOptimizer; 12 | import javamachinelearning.utils.Loss; 13 | import javamachinelearning.utils.Tensor; 14 | import javamachinelearning.utils.Utils; 15 | 16 | public class LinearGraph{ 17 | public static void main(String[] args){ 18 | // neural network with 1 input and 1 output, no activation function 19 | SequentialNN nn = new SequentialNN(1); 20 | nn.add(new FCLayer(1)); 21 | 22 | // y = 5x + 3, with alternating +1/-1 noise added to each target 23 | Tensor[] x = { 24 | t(0), 25 | t(1), 26 | t(2), 27 | t(3), 28 | t(4) 29 | }; 30 | 31 | Tensor[] y = { 32 | t(3 + 0 + 1), 33 | t(3 + 5 - 1), 34 | t(3 + 10 + 1), 35 | t(3 + 15 - 1), 36 | t(3 + 20 + 1) 37 | }; 38 | 39 | nn.train(x, 40 | y, 41 | 100, // number of epochs 42 | 1, // batch size 43 | Loss.squared, 44 | new SGDOptimizer(0.01), 45 | null, // no regularizer 46 | false, // do not shuffle data 47 | true); // verbose 48 | 49 | // try the network on new data 50 | System.out.println(nn.predict(t(5))); 51 | 52 | JFrame frame = new JFrame(); 53 | 54 | // graph the learned line 55 | Graph graph = new Graph(1000, 1000, Utils.flatCombine(x), Utils.flatCombine(y), null, null); 56 | graph.useCustomScale(0, 5, 0, 30); 57 | graph.addLine(((FCLayer)nn.layer(0)).weights().flatGet(0), ((FCLayer)nn.layer(0)).bias().flatGet(0)); 58 | graph.draw(); 59 | frame.add(new GraphPanel(graph)); 60 | 61 | // show the JFrame 62 | frame.setSize(1200, 1200); 63 | frame.setLocationRelativeTo(null); 64 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 65 | frame.setVisible(true); 66 | 67 | // save the plot 68 | graph.saveToFile("nn_linear_regression.png", "png"); 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /src/tests/LoadTest.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import static 
javamachinelearning.utils.TensorUtils.argMax; 4 | import static javamachinelearning.utils.TensorUtils.t; 5 | 6 | import java.awt.Color; 7 | 8 | import javax.swing.JFrame; 9 | 10 | import javamachinelearning.graphs.Graph; 11 | import javamachinelearning.graphs.GraphPanel; 12 | import javamachinelearning.layers.feedforward.ActivationLayer; 13 | import javamachinelearning.layers.feedforward.FCLayer; 14 | import javamachinelearning.networks.SequentialNN; 15 | import javamachinelearning.utils.Activation; 16 | 17 | public class LoadTest{ 18 | public static void main(String[] args){ 19 | SequentialNN net = new SequentialNN(2); 20 | net.add(new FCLayer(3)); 21 | net.add(new ActivationLayer(Activation.relu)); 22 | net.add(new FCLayer(4)); 23 | net.add(new ActivationLayer(Activation.softmax)); 24 | // load the weights from a file 25 | net.loadFromFile("saved_model_test.nn"); 26 | 27 | Color[] intToColor1 = {Color.blue, Color.red, Color.yellow, Color.green}; 28 | 29 | JFrame frame = new JFrame(); 30 | 31 | Graph graph = new Graph(1000, 1000, null, null, null, (x2, y2) -> { 32 | return intToColor1[argMax(net.predict(t(x2, y2)))]; 33 | }); 34 | graph.useCustomScale(0, 1, 0, 1); 35 | graph.draw(); 36 | frame.add(new GraphPanel(graph)); 37 | 38 | frame.setSize(1200, 1200); 39 | frame.setLocationRelativeTo(null); 40 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 41 | frame.setVisible(true); 42 | 43 | graph.saveToFile("classification_example2.png", "png"); 44 | graph.dispose(); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/tests/LogicGates.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import static javamachinelearning.utils.TensorUtils.t; 4 | 5 | import javamachinelearning.layers.feedforward.ActivationLayer; 6 | import javamachinelearning.layers.feedforward.FCLayer; 7 | import javamachinelearning.networks.SequentialNN; 8 | import javamachinelearning.optimizers.MomentumOptimizer; 9 | import javamachinelearning.utils.Activation; 10 | import javamachinelearning.utils.Loss; 11 | import javamachinelearning.utils.Tensor; 12 | 13 | public class LogicGates{ 14 | public static void main(String[] args){ 15 | SequentialNN net = new SequentialNN(2); 16 | net.add(new FCLayer(2)); 17 | net.add(new ActivationLayer(Activation.sigmoid)); 18 | net.add(new FCLayer(1)); 19 | net.add(new ActivationLayer(Activation.sigmoid)); 20 | Tensor[] x = { 21 | t(0, 0), 22 | t(0, 1), 23 | t(1, 0), 24 | t(1, 1) 25 | }; 26 | Tensor[] y = { 27 | t(0), 28 | t(1), 29 | t(1), 30 | t(0) 31 | }; 32 | 33 | System.out.println(net); 34 | 35 | net.train(x, y, 2000, 4, Loss.binaryCrossEntropy, new MomentumOptimizer(0.1), null, true, true); 36 | 37 | System.out.println(net.predict(t(0, 0))); 38 | System.out.println(net.predict(t(1, 0))); 39 | System.out.println(net.predict(t(0, 1))); 40 | System.out.println(net.predict(t(1, 1))); 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/tests/SaveTest.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import static javamachinelearning.utils.TensorUtils.argMax; 4 | import static javamachinelearning.utils.TensorUtils.t; 5 | 6 | import java.awt.Color; 7 | 8 | import javax.swing.JFrame; 9 | 10 | import javamachinelearning.graphs.Graph; 11 | import javamachinelearning.graphs.GraphPanel; 12 | import javamachinelearning.layers.feedforward.ActivationLayer; 13 | import 
javamachinelearning.layers.feedforward.FCLayer; 14 | import javamachinelearning.networks.SequentialNN; 15 | import javamachinelearning.optimizers.SGDOptimizer; 16 | import javamachinelearning.regularizers.L2Regularizer; 17 | import javamachinelearning.utils.Activation; 18 | import javamachinelearning.utils.Loss; 19 | import javamachinelearning.utils.Tensor; 20 | 21 | public class SaveTest{ 22 | public static void main(String[] args){ 23 | SequentialNN net = new SequentialNN(2); 24 | net.add(new FCLayer(3)); 25 | net.add(new ActivationLayer(Activation.relu)); 26 | net.add(new FCLayer(4)); 27 | net.add(new ActivationLayer(Activation.softmax)); 28 | 29 | Tensor[] x = { 30 | t(0, 0), 31 | t(0, 1), 32 | t(1, 0), 33 | t(1, 1), 34 | t(0.1, 0.1), 35 | t(0.1, 0.9), 36 | t(0.9, 0.1), 37 | t(0.9, 0.9) 38 | }; 39 | 40 | Tensor[] y = { 41 | t(1, 0, 0, 0), 42 | t(0, 1, 0, 0), 43 | t(0, 0, 1, 0), 44 | t(0, 0, 0, 1), 45 | t(1, 0, 0, 0), 46 | t(0, 1, 0, 0), 47 | t(0, 0, 1, 0), 48 | t(0, 0, 0, 1) 49 | }; 50 | 51 | net.train(x, y, 1000, 4, Loss.softmaxCrossEntropy, new SGDOptimizer(0.1), new L2Regularizer(0.1), true, true); 52 | 53 | double[] xData = new double[x.length]; 54 | double[] yData = new double[x.length]; 55 | Color[] cData = new Color[x.length]; 56 | Color[] intToColor1 = {Color.blue, Color.red, Color.yellow, Color.green}; 57 | for(int i = 0; i < x.length; i++){ 58 | xData[i] = x[i].flatGet(0); 59 | yData[i] = x[i].flatGet(1); 60 | cData[i] = intToColor1[argMax(y[i])]; 61 | } 62 | 63 | JFrame frame = new JFrame(); 64 | 65 | Graph graph = new Graph(1000, 1000, xData, yData, cData, (x2, y2) -> { 66 | return intToColor1[argMax(net.predict(t(x2, y2)))]; 67 | }); 68 | graph.draw(); 69 | frame.add(new GraphPanel(graph)); 70 | 71 | frame.setSize(1200, 1200); 72 | frame.setLocationRelativeTo(null); 73 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 74 | frame.setVisible(true); 75 | 76 | graph.saveToFile("classification_example.png", "png"); 77 | graph.dispose(); 78 | 79 | net.saveToFile("saved_model_test.nn"); 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /src/tests/TestImageUtils.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import javamachinelearning.utils.ImageUtils; 4 | import javamachinelearning.utils.Tensor; 5 | 6 | public class TestImageUtils { 7 | public static void main(String args[]) { 8 | ImageUtils img = new ImageUtils(); 9 | 10 | // Test readColorImageFile 11 | int[][][] colorImg = img.readColorImageFile("./Images/Set14/comic.bmp"); 12 | System.out.println("---Testing color image---"); 13 | System.out.print(colorImg[0][0][0] + ", "); 14 | System.out.print(colorImg[0][0][1] + ", "); 15 | System.out.println(colorImg[0][0][2]); 16 | 17 | // Test convertRGBtoGray 18 | int[][] grayImg = img.convertRGBtoGray(colorImg); 19 | System.out.println("---Testing converted gray image---"); 20 | System.out.println(grayImg[0][0]); 21 | 22 | // Tensor test(500x480 image) 23 | System.out.println("---Testing read one image to tensor---"); 24 | Tensor imageTensor = img.readColorImageToTensor("./Images/Set14/baboon.bmp", true); 25 | System.out.println("Height : " + imageTensor.shape()[1]); 26 | System.out.println("Width : " + imageTensor.shape()[0]); 27 | //System.out.println(imageTensor.toString()); 28 | 29 | // Operation test and save to "test.bmp" 30 | img.readOneImage_Test("./Images/Set14/baboon.bmp"); 31 | 32 | System.out.println("---Testing read many images to tensors---"); 33 | 
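// readImages presumably applies the same conversion as readColorImageToTensor to every
// bitmap in the directory, returning one tensor per image; the boolean flag matches the one above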
Tensor[] tensors = img.readImages("./Images/Set14/", true); 34 | for( int i=0 ; i<tensors.length ; i++ ){ -------------------------------------------------------------------------------- /src/tests/TestMNISTDraw1.java: -------------------------------------------------------------------------------- 41 | submitButton.addActionListener((e) -> { 42 | Tensor data = drawablePanel.getData(20, 20, 28, 28); 43 | Tensor result = nn.predict(data.flatten()); 44 | label.setText("Result: " + TensorUtils.argMax(result)); 45 | }); 46 | frame.add(submitButton); 47 | 48 | JButton clearButton = new JButton("Clear"); 49 | clearButton.setPreferredSize(new Dimension(200, 100)); 50 | clearButton.setFont(clearButton.getFont().deriveFont(30.0f)); 51 | clearButton.addActionListener((e) -> { 52 | drawablePanel.clear(); 53 | }); 54 | frame.add(clearButton); 55 | 56 | frame.setSize(1200, 1200); 57 | frame.setLocationRelativeTo(null); 58 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 59 | frame.setVisible(true); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/tests/TestMNISTDraw2.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import java.awt.Dimension; 4 | import java.awt.FlowLayout; 5 | 6 | import javax.swing.JButton; 7 | import javax.swing.JFrame; 8 | import javax.swing.JLabel; 9 | 10 | import javamachinelearning.drawables.MNISTDrawablePanel2; 11 | import javamachinelearning.layers.feedforward.ActivationLayer; 12 | import javamachinelearning.layers.feedforward.FCLayer; 13 | import javamachinelearning.networks.SequentialNN; 14 | import javamachinelearning.utils.Activation; 15 | import javamachinelearning.utils.Tensor; 16 | import javamachinelearning.utils.TensorUtils; 17 | 18 | public class TestMNISTDraw2{ 19 | public static void main(String[] args){ 20 | // make sure the drawings are big enough! 21 | 22 | SequentialNN nn = new SequentialNN(784); 23 | nn.add(new FCLayer(300)); 24 | nn.add(new ActivationLayer(Activation.relu)); 25 | nn.add(new FCLayer(10)); 26 | nn.add(new ActivationLayer(Activation.softmax)); 27 | nn.loadFromFile("mnist_weights_fc.nn"); 28 | 29 | JFrame frame = new JFrame(); 30 | frame.setLayout(new FlowLayout()); 31 | MNISTDrawablePanel2 drawablePanel = new MNISTDrawablePanel2(1000, 1000, 20, 20); 32 | frame.add(drawablePanel); 33 | 34 | JLabel label = new JLabel("Result: "); 35 | label.setFont(label.getFont().deriveFont(30.0f)); 36 | frame.add(label); 37 | 38 | JButton submitButton = new JButton("Submit"); 39 | submitButton.setPreferredSize(new Dimension(200, 100)); 40 | submitButton.setFont(submitButton.getFont().deriveFont(30.0f)); 41 | submitButton.addActionListener((e) -> { 42 | Tensor data = drawablePanel.getData(28, 28); 43 | Tensor result = nn.predict(data.flatten()); 44 | label.setText("Result: " + TensorUtils.argMax(result)); 45 | }); 46 | frame.add(submitButton); 47 | 48 | JButton clearButton = new JButton("Clear"); 49 | clearButton.setPreferredSize(new Dimension(200, 100)); 50 | clearButton.setFont(clearButton.getFont().deriveFont(30.0f)); 51 | clearButton.addActionListener((e) -> { 52 | drawablePanel.clear(); 53 | }); 54 | frame.add(clearButton); 55 | 56 | frame.setSize(1200, 1200); 57 | frame.setLocationRelativeTo(null); 58 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 59 | frame.setVisible(true); 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/tests/TestMNISTFile.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import javamachinelearning.layers.feedforward.ActivationLayer; 4 | import javamachinelearning.layers.feedforward.FCLayer; 5 | import 
javamachinelearning.networks.SequentialNN; 6 | import javamachinelearning.utils.Activation; 7 | import javamachinelearning.utils.MNISTUtils; 8 | import javamachinelearning.utils.Tensor; 9 | import javamachinelearning.utils.Utils; 10 | 11 | public class TestMNISTFile{ 12 | public static void main(String[] args){ 13 | SequentialNN nn = new SequentialNN(784); 14 | nn.add(new FCLayer(300)); 15 | nn.add(new ActivationLayer(Activation.relu)); 16 | nn.add(new FCLayer(10)); 17 | nn.add(new ActivationLayer(Activation.softmax)); 18 | nn.loadFromFile("mnist_weights_fc.nn"); 19 | 20 | Tensor[] testX = MNISTUtils.loadDataSetImages("t10k-images-idx3-ubyte", Integer.MAX_VALUE); 21 | Tensor[] testY = MNISTUtils.loadDataSetLabels("t10k-labels-idx1-ubyte", Integer.MAX_VALUE); 22 | Tensor[] testResult = nn.predict(Utils.flattenAll(testX)); 23 | 24 | System.out.println("Classification accuracy: " + Utils.format(Utils.classificationAccuracy(testResult, testY))); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/tests/TrainMNISTConv.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import javamachinelearning.layers.feedforward.ActivationLayer; 4 | import javamachinelearning.layers.feedforward.ConvLayer; 5 | import javamachinelearning.layers.feedforward.DropoutLayer; 6 | import javamachinelearning.layers.feedforward.FCLayer; 7 | import javamachinelearning.layers.feedforward.FlattenLayer; 8 | import javamachinelearning.layers.feedforward.MaxPoolingLayer; 9 | import javamachinelearning.layers.feedforward.ConvLayer.PaddingType; 10 | import javamachinelearning.networks.SequentialNN; 11 | import javamachinelearning.optimizers.AdamOptimizer; 12 | import javamachinelearning.utils.Activation; 13 | import javamachinelearning.utils.Loss; 14 | import javamachinelearning.utils.MNISTUtils; 15 | import javamachinelearning.utils.Tensor; 16 | import javamachinelearning.utils.Utils; 17 | 18 | public class TrainMNISTConv{ 19 | public static void main(String[] args) throws Exception{ 20 | // very slow! 
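// the convolutions run in plain Java on the CPU, so 100 epochs over the full
// 60000-image MNIST training set can take a very long time; see TrainMNISTConvMemorize
// below for a much faster sanity check of the same architecture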
21 | 22 | SequentialNN nn = new SequentialNN(28, 28, 1); 23 | nn.add(new ConvLayer(5, 32, PaddingType.SAME)); 24 | nn.add(new ActivationLayer(Activation.relu)); 25 | nn.add(new MaxPoolingLayer(2, 2)); 26 | nn.add(new ConvLayer(5, 64, PaddingType.SAME)); 27 | nn.add(new ActivationLayer(Activation.relu)); 28 | nn.add(new MaxPoolingLayer(2, 2)); 29 | nn.add(new FlattenLayer()); 30 | nn.add(new FCLayer(1024)); 31 | nn.add(new ActivationLayer(Activation.relu)); 32 | nn.add(new DropoutLayer(0.3)); 33 | nn.add(new FCLayer(10)); 34 | nn.add(new ActivationLayer(Activation.softmax)); 35 | 36 | System.out.println(nn); 37 | 38 | Tensor[] x = MNISTUtils.loadDataSetImages("train-images-idx3-ubyte", Integer.MAX_VALUE); 39 | Tensor[] y = MNISTUtils.loadDataSetLabels("train-labels-idx1-ubyte", Integer.MAX_VALUE); 40 | 41 | long start = System.currentTimeMillis(); 42 | 43 | nn.train(Utils.reshapeAll(x, 28, 28, 1), y, 100, 100, Loss.softmaxCrossEntropy, new AdamOptimizer(0.01), null, true, false); 44 | 45 | System.out.println("Training time: " + Utils.formatElapsedTime(System.currentTimeMillis() - start)); 46 | 47 | nn.saveToFile("mnist_weights_conv.nn"); 48 | 49 | Tensor[] testX = MNISTUtils.loadDataSetImages("t10k-images-idx3-ubyte", Integer.MAX_VALUE); 50 | Tensor[] testY = MNISTUtils.loadDataSetLabels("t10k-labels-idx1-ubyte", Integer.MAX_VALUE); 51 | Tensor[] testResult = nn.predict(Utils.reshapeAll(testX, 28, 28, 1)); 52 | 53 | System.out.println("Classification accuracy: " + Utils.format(Utils.classificationAccuracy(testResult, testY))); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/tests/TrainMNISTConvMemorize.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import javamachinelearning.layers.feedforward.ActivationLayer; 4 | import javamachinelearning.layers.feedforward.ConvLayer; 5 | import javamachinelearning.layers.feedforward.DropoutLayer; 6 | import javamachinelearning.layers.feedforward.FCLayer; 7 | import javamachinelearning.layers.feedforward.FlattenLayer; 8 | import javamachinelearning.layers.feedforward.MaxPoolingLayer; 9 | import javamachinelearning.layers.feedforward.ConvLayer.PaddingType; 10 | import javamachinelearning.networks.SequentialNN; 11 | import javamachinelearning.optimizers.AdamOptimizer; 12 | import javamachinelearning.utils.Activation; 13 | import javamachinelearning.utils.Loss; 14 | import javamachinelearning.utils.MNISTUtils; 15 | import javamachinelearning.utils.Tensor; 16 | import javamachinelearning.utils.Utils; 17 | 18 | public class TrainMNISTConvMemorize{ 19 | public static void main(String[] args) throws Exception{ 20 | // training on the full MNIST data set is way too slow 21 | // to verify that the convolutional layers work, the network is instead trained to memorize a small subset of MNIST images 22 | 23 | // builds a convolutional neural network 24 | SequentialNN nn = new SequentialNN(28, 28, 1); 25 | 26 | nn.add(new ConvLayer(5, 32, PaddingType.SAME)); 27 | nn.add(new ActivationLayer(Activation.relu)); 28 | nn.add(new MaxPoolingLayer(2, 2)); 29 | 30 | nn.add(new ConvLayer(5, 64, PaddingType.SAME)); 31 | nn.add(new ActivationLayer(Activation.relu)); 32 | nn.add(new MaxPoolingLayer(2, 2)); 33 | 34 | nn.add(new FlattenLayer()); 35 | 36 | nn.add(new FCLayer(1024)); 37 | nn.add(new ActivationLayer(Activation.relu)); 38 | 39 | nn.add(new DropoutLayer(0.3)); 40 | 41 | nn.add(new FCLayer(10)); 42 | nn.add(new ActivationLayer(Activation.softmax)); 43 | 44 | // loads the training data 
(only the first 100 images) 45 | Tensor[] x = MNISTUtils.loadDataSetImages("train-images-idx3-ubyte", 100); 46 | Tensor[] y = MNISTUtils.loadDataSetLabels("train-labels-idx1-ubyte", 100); 47 | 48 | long start = System.currentTimeMillis(); 49 | 50 | nn.train(Utils.reshapeAll(x, 28, 28, 1), 51 | y, 52 | 20, // number of epochs 53 | 10, // batch size 54 | Loss.softmaxCrossEntropy, 55 | new AdamOptimizer(0.001), 56 | null, // no regularization 57 | true, // shuffle 58 | false); 59 | 60 | System.out.println("Training time: " + Utils.formatElapsedTime(System.currentTimeMillis() - start)); 61 | 62 | // test on the images that the network was trained on 63 | Tensor[] testResult = nn.predict(Utils.reshapeAll(x, 28, 28, 1)); 64 | 65 | System.out.println("Memorization accuracy: " + Utils.format(Utils.classificationAccuracy(testResult, y))); 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /src/tests/TrainMNISTFullyConnected.java: -------------------------------------------------------------------------------- 1 | package tests; 2 | 3 | import javamachinelearning.layers.feedforward.ActivationLayer; 4 | import javamachinelearning.layers.feedforward.FCLayer; 5 | import javamachinelearning.networks.SequentialNN; 6 | import javamachinelearning.optimizers.MomentumOptimizer; 7 | import javamachinelearning.regularizers.L2Regularizer; 8 | import javamachinelearning.utils.Activation; 9 | import javamachinelearning.utils.Loss; 10 | import javamachinelearning.utils.MNISTUtils; 11 | import javamachinelearning.utils.Tensor; 12 | import javamachinelearning.utils.Utils; 13 | 14 | public class TrainMNISTFullyConnected{ 15 | public static void main(String[] args){ 16 | // create a model with 784 input neurons, 300 hidden neurons, and 10 output neurons 17 | // use RELU for the hidden layer and softmax for the output layer 18 | SequentialNN nn = new SequentialNN(784); 19 | nn.add(new FCLayer(300)); 20 | nn.add(new ActivationLayer(Activation.relu)); 21 | nn.add(new FCLayer(10)); // 10 categories of numbers 22 | nn.add(new ActivationLayer(Activation.softmax)); 23 | 24 | // load the training data 25 | Tensor[] x = MNISTUtils.loadDataSetImages("train-images-idx3-ubyte", Integer.MAX_VALUE); 26 | Tensor[] y = MNISTUtils.loadDataSetLabels("train-labels-idx1-ubyte", Integer.MAX_VALUE); 27 | 28 | long start = System.currentTimeMillis(); 29 | 30 | nn.train(Utils.flattenAll(x), 31 | y, 32 | 100, // number of epochs 33 | 100, // batch size 34 | Loss.softmaxCrossEntropy, 35 | new MomentumOptimizer(0.5), 36 | new L2Regularizer(0.0001), 37 | true, // shuffle the data after every epoch 38 | false); 39 | 40 | System.out.println("Training time: " + Utils.formatElapsedTime(System.currentTimeMillis() - start)); 41 | 42 | // save the learned weights 43 | nn.saveToFile("mnist_weights_fc.nn"); 44 | 45 | // predict on previously unseen testing data 46 | Tensor[] testX = MNISTUtils.loadDataSetImages("t10k-images-idx3-ubyte", Integer.MAX_VALUE); 47 | Tensor[] testY = MNISTUtils.loadDataSetLabels("t10k-labels-idx1-ubyte", Integer.MAX_VALUE); 48 | Tensor[] testResult = nn.predict(Utils.flattenAll(testX)); 49 | 50 | // prints the percent of images classified correctly 51 | System.out.println("Classification accuracy: " + Utils.format(Utils.classificationAccuracy(testResult, testY))); 52 | } 53 | } 54 | --------------------------------------------------------------------------------