├── src └── main │ └── java │ └── com │ └── github │ └── tjake │ ├── rbm │ ├── minst │ │ ├── MinstItem.java │ │ ├── Demo.java │ │ ├── GenerativeMinstDBN.java │ │ ├── MinstDatasetReader.java │ │ ├── BinaryMinstDBN.java │ │ └── BinaryMinstRBM.java │ ├── Tuple.java │ ├── Layer.java │ ├── BinaryLayer.java │ ├── LayerFactory.java │ ├── GaussianLayer.java │ ├── StackedRBMTrainer.java │ ├── SimpleRBMTrainer.java │ ├── StackedRBM.java │ └── SimpleRBM.java │ └── util │ └── Utilities.java ├── LICENSE ├── readme.md └── pom.xml /src/main/java/com/github/tjake/rbm/minst/MinstItem.java: -------------------------------------------------------------------------------- 1 | package com.github.tjake.rbm.minst; 2 | 3 | /** 4 | * Container class that represents a Minst image and its label; fields are public and mutable 5 | */ 6 | public class MinstItem 7 | { 8 | public String label; //the image's label 9 | public int[] data; //the image data (NOTE(review): presumably pixel intensities - confirm against MinstDatasetReader) 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/com/github/tjake/rbm/Tuple.java: -------------------------------------------------------------------------------- 1 | package com.github.tjake.rbm; 2 | //Holds a visible layer, a hidden layer and the original input layer; the references are final but the Layer objects may still be mutated 3 | public class Tuple 4 | { 5 | public final Layer visible; 6 | public final Layer hidden; 7 | public final Layer input; //For a DBN this is the initial input layer 8 | //Usable by subclasses and package code; Factory.create is the public way to build one 9 | protected Tuple(Layer input, Layer visible, Layer hidden) 10 | { 11 | this.input = input; 12 | this.visible = visible; 13 | this.hidden = hidden; 14 | } 15 | 16 | public static class Factory { 17 | //Creates Tuples that all share the input layer passed to this factory 18 | public final Layer input; 19 | 20 | public Factory(Layer input) { 21 | this.input = input; 22 | } 23 | //Returns a new Tuple pairing the given layers with this factory's input layer 24 | public Tuple create(Layer visible, Layer hidden) 25 | { 26 | return new Tuple(input,visible,hidden); 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2013 T Jake Luciani 2 | 3 | Permission is hereby granted, free of charge,
to any person obtaining a copy of 4 | this software and associated documentation files (the 'Software'), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | -------------------------------------------------------------------------------- /src/main/java/com/github/tjake/util/Utilities.java: -------------------------------------------------------------------------------- 1 | package com.github.tjake.util; 2 | 3 | 4 | import com.github.tjake.rbm.Layer; 5 | import java.util.Random; 6 | 7 | public class Utilities { 8 | 9 | static Random staticRand = new Random(); 10 | 11 | public static float mean(final Layer input) 12 | { 13 | float m = 0.0f; 14 | for (int i=0; i 0) 51 | System.err.println(err); 52 | 53 | System.exit(-1); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/com/github/tjake/rbm/BinaryLayer.java: -------------------------------------------------------------------------------- 1 | package com.github.tjake.rbm; 2 | 3 | 4 | /* 5 | * Converts grayscale intensities to binary values 6 | */ 7 | public class BinaryLayer extends Layer { 8 | 9 | final Layer delegate; 10 | 11 | //Wraps delegate: the constructor binarises it and every element operation below forwards to it 12 | public BinaryLayer(Layer delegate) 13 | { 14 | super(null); 15 | this.delegate = delegate; 16 | 17 | convertToBinary(); 18 | } 19 | 20 | private void convertToBinary() 21 | { 22 | for (int i=0; i 30 ?
1.0f : 0.0f); 26 | } 27 | } 28 | //Maps binary values back to intensities by scaling each one by 255 29 | public static float[] fromBinary(Layer delegate) { 30 | float [] output = new float[delegate.size()]; 31 | for (int i = 0; i < output.length; i++) { 32 | output[i] = delegate.get(i) * 255.0f; 33 | } 34 | return output; 35 | } 36 | 37 | 38 | @Override 39 | public void set(int i, float f) { 40 | delegate.set(i,f); 41 | } 42 | 43 | @Override 44 | public float get(int i) { 45 | return delegate.get(i); 46 | } 47 | 48 | @Override 49 | public void add(int i, float f) { 50 | delegate.add(i,f); 51 | } 52 | 53 | @Override 54 | public void div(int i, float f) { 55 | delegate.div(i,f); 56 | } 57 | 58 | @Override 59 | public void mult(int i, float f) { 60 | delegate.mult(i,f); 61 | } 62 | 63 | @Override 64 | public int size() { 65 | return delegate.size(); 66 | } 67 | 68 | @Override 69 | public Layer clone() { 70 | return delegate.clone(); 71 | } 72 | 73 | @Override 74 | public void clear() { 75 | delegate.clear(); 76 | } 77 | 78 | @Override 79 | public void copy(float[] src) { 80 | delegate.copy(src); 81 | } 82 | 83 | @Override 84 | public float[] get() { 85 | return delegate.get(); 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/main/java/com/github/tjake/rbm/LayerFactory.java: -------------------------------------------------------------------------------- 1 | package com.github.tjake.rbm; 2 | 3 | 4 | import java.awt.image.BufferedImage; 5 | import java.io.DataInput; 6 | import java.io.DataOutput; 7 | import java.io.IOException; 8 | import java.util.Arrays; 9 | 10 | public class LayerFactory { 11 | public static final byte[] MAGIC = {(byte) 0xf0, (byte) 0x0d, (byte) 0x00, (byte) 0x0F}; 12 | 13 | public Layer create(int size) { 14 | return new Layer(size); 15 | } 16 | 17 | public Layer create(float[] start) { 18 | return new Layer(start); 19 | } 20 | //Reads band 0 of every pixel, row by row, into a new layer of width*height values 21 | public Layer create(BufferedImage img) { 22 | Layer layer = create(img.getWidth() * img.getHeight()); 23 | int width = 0, height = 0; java.awt.image.Raster raster = img.getData(); //getData() copies the whole image, so take that copy once 24
| for (int i = 0; i < layer.size(); i++) { 25 | layer.set(i, raster.getSample(width++, height, 0)); 26 | 27 | if (width >= img.getWidth()) { 28 | width = 0; 29 | height++; 30 | } 31 | } 32 | 33 | return layer; 34 | } 35 | 36 | //Writes the 4 MAGIC bytes, the element count as an int, then each value as a float 37 | public void save(Layer layer, DataOutput dataOutput) throws IOException { 38 | //First write magic # 39 | dataOutput.write(MAGIC); 40 | 41 | float[] floats = layer.get(); 42 | if (floats.length != layer.size()) 43 | throw new IOException("get().length != size()"); 44 | 45 | //Number of elements 46 | dataOutput.writeInt(layer.size()); 47 | 48 | for (int i = 0; i < floats.length; i++) 49 | dataOutput.writeFloat(floats[i]); 50 | } 51 | //Reads the format written by save(); throws IOException on a bad magic header or a negative size 52 | public Layer load(DataInput dataInput) throws IOException { 53 | byte[] magic = new byte[4]; 54 | dataInput.readFully(magic); 55 | 56 | if (!Arrays.equals(MAGIC, magic)) 57 | throw new IOException("Bad File Format"); 58 | 59 | int size = dataInput.readInt(); 60 | 61 | if (size < 0) 62 | throw new IOException("Invalid size"); 63 | 64 | float[] input = new float[size]; 65 | for (int i = 0; i < size; i++) 66 | input[i] = dataInput.readFloat(); 67 | 68 | 69 | return create(input); 70 | } 71 | 72 | public GaussianLayer createGaussian(BufferedImage img) { 73 | return new GaussianLayer(create(img)); 74 | } 75 | 76 | } 77 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | rbm-dbn-mnist 2 | ========== 3 | 4 | Learn more about this project from this blog post: 5 | 6 | http://tjake.github.com/blog/2013/02/18/resurgence-in-artificial-intelligence/ 7 | 8 | This project provides an implementation of a Restricted Boltzmann Machine and a Deep Belief Network. 9 | 10 | It uses the [MNIST handwritten dataset](http://yann.lecun.com/exdb/mnist/) to illustrate an example RBM and DBN. 11 | 12 | Usage 13 | ===== 14 | 15 | From source, build the project with Maven: 16 | 17 | 1. 
mvn deploy 18 | 19 | This builds a single jar and downloads the MNIST dataset. 20 | 21 | 2. java -jar target/rbm-dbn-mnist-0.0.1.jar 22 | 23 | Runs the app and shows the usage screen. 24 | 25 | ```` 26 | Usage: [rbm minst-labels.gz minst-images.gz] 27 | [dbn minst-images.gz minst-labels.gz dbn.bin] 28 | [gen dbn.bin] 29 | ```` 30 | 31 | 3. java -jar target/rbm-dbn-mnist-0.0.1.jar rbm target/minst/train-labels-idx1-ubyte.gz target/minst/train-images-idx3-ubyte.gz 32 | 33 | Trains a single RBM with 100 hidden nodes. Each hidden node's weights are rendered in blue alongside the test digit. 34 | 35 | ![RBM Demo](http://tjake.github.com/images/MinstRBM.png) 36 | 37 | 38 | 4. java -jar target/rbm-dbn-mnist-0.0.1.jar dbn target/minst/train-labels-idx1-ubyte.gz target/minst/train-images-idx3-ubyte.gz /tmp/dbn.bin 39 | 40 | Trains a Deep Belief Network made up of three RBMs. It learns to match pictures of digits with their corresponding labels. It takes about 10 minutes to train, but once trained it reaches an accuracy rate of about 95%. The trained DBN is saved to a file. 41 | 42 | 5. java -jar target/rbm-dbn-mnist-0.0.1.jar gen /tmp/dbn.bin 43 | 44 | Takes the trained DBN from step 4 and reverses the flow, generating a visual image of a digit from a digit label. 45 | 46 | License 47 | ======= 48 | 49 | Copyright 2013 T Jake Luciani 50 | 51 | Permission is hereby granted, free of charge, to any person obtaining a copy of 52 | this software and associated documentation files (the 'Software'), to deal in 53 | the Software without restriction, including without limitation the rights to 54 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 55 | the Software, and to permit persons to whom the Software is furnished to do so, 56 | subject to the following conditions: 57 | 58 | The above copyright notice and this permission notice shall be included in all 59 | copies or substantial portions of the Software.
60 | 61 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 62 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 63 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 64 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 65 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 66 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 67 | 68 | -------------------------------------------------------------------------------- /src/main/java/com/github/tjake/rbm/GaussianLayer.java: -------------------------------------------------------------------------------- 1 | package com.github.tjake.rbm; 2 | 3 | 4 | import com.github.tjake.util.Utilities; 5 | 6 | /** 7 | * Converts raw values to standard deviations 8 | */ 9 | public class GaussianLayer extends Layer{ 10 | 11 | final Layer delegate; 12 | private float mean; 13 | private float stddev; 14 | 15 | //Wraps delegate and standardises its values in place using its own mean and stddev 16 | public GaussianLayer(Layer delegate) 17 | { 18 | super(null); 19 | this.delegate = delegate; 20 | 21 | convertToStddev(); 22 | 23 | } 24 | //Wraps delegate WITHOUT converting it; base must be a GaussianLayer (else ClassCastException) whose mean/stddev are reused 25 | public GaussianLayer(Layer delegate, Layer base) 26 | { 27 | super(null); 28 | 29 | GaussianLayer gbase = (GaussianLayer) base; 30 | 31 | this.delegate = delegate; 32 | mean = gbase.mean; 33 | stddev = gbase.stddev; 34 | 35 | } 36 | //Computes mean and stddev (stddev floored at 0.1), then rewrites every delegate value 37 | private void convertToStddev() 38 | { 39 | mean = Utilities.mean(delegate); 40 | stddev = Utilities.stddev(delegate, mean); 41 | stddev = stddev < 0.1f ?
0.1f : stddev; 42 | 43 | double min = Double.MAX_VALUE, max = -Double.MAX_VALUE; 44 | //Double.MIN_VALUE is the smallest POSITIVE double, so -MAX_VALUE is the correct seed for a running max 45 | 46 | for (int i=0; i max) max = v; 50 | if (v < min) min = v; 51 | 52 | delegate.set(i, (float)v); 53 | } 54 | 55 | 56 | } 57 | 58 | public float[] fromGaussian() { 59 | //Inverse of the standardisation: halve values beyond 2 sigma, rescale by stddev/mean, clamp to 0..255 60 | float [] output = new float[delegate.size()]; 61 | for (int i = 0; i < output.length; i++) { 62 | double v = delegate.get(i); 63 | 64 | //Squash > 2 sigma 65 | if (Math.abs(v) > 2) 66 | v /= 2; 67 | 68 | v = v * stddev + mean; 69 | 70 | //Clamp into the 0..255 pixel range; both bounds must apply to the same value 71 | v = Math.max(0.0, Math.min(255.0, v)); 72 | 73 | output[i] = (float) v; 74 | 75 | } 76 | 77 | 78 | return output; 79 | } 80 | 81 | 82 | @Override 83 | public void set(int i, float f) { 84 | delegate.set(i,f); 85 | } 86 | 87 | @Override 88 | public float get(int i) { 89 | return delegate.get(i); 90 | } 91 | 92 | @Override 93 | public void add(int i, float f) { 94 | delegate.add(i,f); 95 | } 96 | 97 | @Override 98 | public void div(int i, float f) { 99 | delegate.div(i,f); 100 | } 101 | 102 | @Override 103 | public void mult(int i, float f) { 104 | delegate.mult(i,f); 105 | } 106 | 107 | @Override 108 | public int size() { 109 | return delegate.size(); 110 | } 111 | 112 | @Override 113 | public Layer clone() { 114 | return delegate.clone(); 115 | } 116 | 117 | @Override 118 | public void clear() { 119 | delegate.clear(); 120 | } 121 | 122 | @Override 123 | public void copy(float[] src) { 124 | delegate.copy(src); 125 | } 126 | 127 | @Override 128 | public float[] get() { 129 | return delegate.get(); 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /src/main/java/com/github/tjake/rbm/StackedRBMTrainer.java: -------------------------------------------------------------------------------- 1 | package com.github.tjake.rbm; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Arrays; 5 | import
java.util.Iterator; 6 | import java.util.List; 7 | 8 | public class StackedRBMTrainer { 9 | 10 | final StackedRBM stackedRBM; 11 | final SimpleRBMTrainer inputTrainer; 12 | final float momentum; 13 | final float l2; 14 | final Float targetSparsity; 15 | float learningRate; 16 | final LayerFactory layerFactory; 17 | 18 | public StackedRBMTrainer(StackedRBM stackedRBM, float momentum, float l2, Float targetSparsity, float learningRate, LayerFactory layerFactory ) 19 | { 20 | this.stackedRBM = stackedRBM; 21 | this.momentum = momentum; 22 | this.l2 = l2; 23 | this.targetSparsity = targetSparsity; 24 | this.learningRate = learningRate; 25 | this.layerFactory = layerFactory; 26 | 27 | inputTrainer = new SimpleRBMTrainer(momentum, l2, targetSparsity, learningRate, layerFactory ); 28 | } 29 | 30 | public void setLearningRate(float newRate){ 31 | learningRate = newRate; 32 | inputTrainer.learningRate = newRate; 33 | } 34 | 35 | //Starts at the bottom of the DBN and uses the output of one RBM as the input of 36 | //the next. This continues till it hits stopAt. Then it trains the RBM with the 37 | //mutated input batch. It also allows a second batch to be appended to a input batch 38 | //So you can combine a deep RBM feature with a second input. 39 | // 40 | //An example being features of a digit picture combined with the digit label. 41 | public double learn(List bottomBatch, List topBatch, int stopAt) 42 | { 43 | if (topBatch != null && !topBatch.isEmpty() && topBatch.size() != bottomBatch.size()) 44 | throw new IllegalArgumentException("TopBatch != BottomBatch"); 45 | 46 | if (stopAt < 0 || stopAt > stackedRBM.innerRBMs.size()) 47 | throw new IllegalArgumentException("Invalid stopAt"); 48 | 49 | 50 | List nextInputs = new ArrayList(bottomBatch); 51 | 52 | for (int i=0; i