├── .gitignore ├── src └── main │ └── java │ └── net │ └── jacobpeterson │ └── ed │ ├── DirectionDataSetIterator.java │ ├── PixelDisplay.java │ ├── SampleDataGenerator.java │ └── DL4JTest.java └── pom.xml /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | .DS_Store 3 | *.iml 4 | .idea/ 5 | out/ 6 | -------------------------------------------------------------------------------- /src/main/java/net/jacobpeterson/ed/DirectionDataSetIterator.java: -------------------------------------------------------------------------------- 1 | package net.jacobpeterson.ed; 2 | 3 | import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator; 4 | import org.nd4j.linalg.api.ndarray.INDArray; 5 | import org.nd4j.linalg.dataset.DataSet; 6 | import org.nd4j.linalg.primitives.Pair; 7 | 8 | import java.util.function.Consumer; 9 | 10 | public class DirectionDataSetIterator extends INDArrayDataSetIterator { 11 | 12 | public DirectionDataSetIterator(Iterable> iterable, int batchSize) { 13 | super(iterable, batchSize); 14 | } 15 | 16 | @Override 17 | public void forEachRemaining(Consumer action) { 18 | super.forEachRemaining(action); 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/net/jacobpeterson/ed/PixelDisplay.java: -------------------------------------------------------------------------------- 1 | package net.jacobpeterson.ed; 2 | 3 | import org.nd4j.linalg.api.ndarray.INDArray; 4 | import org.nd4j.linalg.factory.Nd4j; 5 | 6 | import javax.swing.JFrame; 7 | import javax.swing.JPanel; 8 | import javax.swing.SwingUtilities; 9 | import java.awt.Color; 10 | import java.awt.Dimension; 11 | import java.awt.Graphics; 12 | import java.awt.Graphics2D; 13 | 14 | public class PixelDisplay { 15 | 16 | private JFrame frame; 17 | private JPanel contentPane; 18 | private INDArray grayScaleRaster; 19 | 20 | public PixelDisplay() { 21 | this.frame = new JFrame(); 22 | this.grayScaleRaster = Nd4j.create(new double[4]); 23 | this.setupWindow(); 24 | } 25 | 26 | private void setupWindow() { 27 | contentPane = new JPanel() { 28 | @Override 29 | public void paint(Graphics g) { 30 | Graphics2D graphics = (Graphics2D) g; 31 | Dimension size = getSize(); 32 | 33 | int topLeftPixel = (int) (grayScaleRaster.getDouble(0) * 255); 34 | int topRightPixel = (int) (grayScaleRaster.getDouble(1) * 255); 35 | int bottomLeftPixel = (int) (grayScaleRaster.getDouble(2) * 255); 36 | int bottomRightPixel = (int) (grayScaleRaster.getDouble(3) * 255); 37 | 38 | graphics.setColor(new Color(topLeftPixel, topLeftPixel, topLeftPixel)); 39 | graphics.fillRect(0, 0, size.width / 2, size.height / 2); 40 | graphics.setColor(new Color(topRightPixel, topRightPixel, topRightPixel)); 41 | graphics.fillRect(size.width / 2, 0, size.width / 2, size.height / 2); 42 | graphics.setColor(new Color(bottomLeftPixel, bottomLeftPixel, bottomLeftPixel)); 43 | graphics.fillRect(0, size.height / 2, size.width / 2, size.height / 2); 44 | graphics.setColor(new Color(bottomRightPixel, bottomRightPixel, bottomRightPixel)); 45 | graphics.fillRect(size.width / 2, size.height / 2, size.width / 2, size.height / 2); 46 | } 47 | }; 48 | 49 | frame.setContentPane(contentPane); 50 | frame.setSize(500, 500); 51 | frame.setLocationRelativeTo(null); // Centers JFrame 52 | frame.setVisible(true); 53 | } 54 | 55 | public void setGrayScaleRaster(INDArray grayScaleRaster) { 56 | this.grayScaleRaster = grayScaleRaster; 57 | SwingUtilities.invokeLater(() -> contentPane.repaint()); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | net.jacobpeterson.ed 8 | DL4JTest 9 | 1.0 10 | 11 | 12 | net.jacobpeterson.ed.DL4JTest 13 | 14 | 15 | 16 | 17 | org.deeplearning4j 18 | deeplearning4j-core 19 | 1.0.0-beta3 20 | 21 | 22 | org.nd4j 23 | nd4j-native-platform 24 | 1.0.0-beta3 25 | 26 | 27 | org.slf4j 28 | slf4j-api 29 | 1.7.5 30 | 31 | 32 | org.slf4j 33 | slf4j-simple 34 | 1.6.4 35 | 36 | 37 | 38 | 39 | 40 | 41 | org.apache.maven.plugins 42 | maven-compiler-plugin 43 | 3.8.0 44 | 45 | 1.8 46 | 1.8 47 | 48 | 49 | 50 | org.apache.maven.plugins 51 | maven-assembly-plugin 52 | 3.1.1 53 | 54 | 55 | 56 | net.jacobpeterson.ed.DL4JTest 57 | 58 | 59 | 60 | jar-with-dependencies 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | -------------------------------------------------------------------------------- /src/main/java/net/jacobpeterson/ed/SampleDataGenerator.java: -------------------------------------------------------------------------------- 1 | package net.jacobpeterson.ed; 2 | 3 | import org.nd4j.linalg.api.ndarray.INDArray; 4 | import org.nd4j.linalg.factory.Nd4j; 5 | import org.nd4j.linalg.primitives.Pair; 6 | 7 | import java.util.concurrent.ThreadLocalRandom; 8 | 9 | public class SampleDataGenerator { 10 | 11 | public static Pair generate2x2Sample(SampleDirection direction) { 12 | SampleDirection[] sampleDirections = SampleDirection.values(); 13 | ThreadLocalRandom random = ThreadLocalRandom.current(); 14 | 15 | if (direction == null) { 16 | direction = sampleDirections[random.nextInt(sampleDirections.length)]; 17 | } 18 | 19 | int imageLength = 2 * 2; // 2x2 grayscale image 20 | int numberOfDirections = sampleDirections.length; 21 | 22 | double[] featureArray = new double[imageLength]; // Feature 23 | double[] labelArray = new double[numberOfDirections]; // Label 24 | 25 | double tinyPixelOffset = 0.05; 26 | double similarPixelBase = random.nextDouble(); 27 | 28 | switch (direction) { 29 | case VERTICAL: 30 | featureArray[0] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 31 | // featureArray[1] = clamp(random.nextDouble() + 32 | // random.nextDouble(-similarPixelBase / 3, similarPixelBase / 3), 0, 1); 33 | featureArray[1] = random.nextDouble(0.3D); 34 | featureArray[2] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 35 | // featureArray[3] = clamp(random.nextDouble() + 36 | // random.nextDouble(-similarPixelBase / 3, similarPixelBase / 3), 0, 1); 37 | featureArray[3] = random.nextDouble(0.7D, 1D); 38 | 39 | // Perform raster column switching for half of the samples 40 | if (random.nextDouble() > 0.5d) { 41 | double topRightTemp = featureArray[1]; 42 | double bottomRightTemp = featureArray[3]; 43 | 44 | featureArray[1] = featureArray[0]; 45 | featureArray[3] = featureArray[2]; 46 | 47 | featureArray[0] = topRightTemp; 48 | featureArray[2] = bottomRightTemp; 49 | } 50 | break; 51 | case HORIZONTAL: 52 | featureArray[0] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 53 | featureArray[1] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 54 | // featureArray[2] = clamp(random.nextDouble() + 55 | // random.nextDouble(-similarPixelBase / 2, similarPixelBase / 2), 0, 1); 56 | // featureArray[3] = clamp(random.nextDouble() + 57 | // random.nextDouble(-similarPixelBase / 2, similarPixelBase / 2), 0, 1); 58 | featureArray[2] = random.nextDouble(0.33D); 59 | featureArray[3] = random.nextDouble(0.66D, 1D); 60 | 61 | // Perform raster row switching for half of the samples 62 | if (random.nextDouble() > 0.5d) { 63 | double topLeftTemp = featureArray[0]; 64 | double topRightTemp = featureArray[1]; 65 | 66 | featureArray[0] = featureArray[2]; 67 | featureArray[1] = featureArray[3]; 68 | 69 | featureArray[2] = topLeftTemp; 70 | featureArray[3] = topRightTemp; 71 | } 72 | break; 73 | case DIAGONAL: 74 | featureArray[0] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 75 | // featureArray[1] = clamp(random.nextDouble() + 76 | // random.nextDouble(-similarPixelBase / 2, similarPixelBase / 2), 0, 1); 77 | featureArray[1] = random.nextDouble(0.33D); 78 | // featureArray[2] = clamp(random.nextDouble() + 79 | // random.nextDouble(-similarPixelBase / 2, similarPixelBase / 2), 0, 1); 80 | featureArray[2] = random.nextDouble(0.66D, 1D); 81 | featureArray[3] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 82 | 83 | // Perform raster diagonal switching for half of the samples 84 | if (random.nextDouble() > 0.5d) { 85 | double topLeftTemp = featureArray[0]; 86 | double topRightTemp = featureArray[1]; 87 | 88 | featureArray[0] = featureArray[2]; 89 | featureArray[1] = featureArray[3]; 90 | 91 | featureArray[2] = topLeftTemp; 92 | featureArray[3] = topRightTemp; 93 | } 94 | break; 95 | case SOLID: 96 | featureArray[0] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 97 | featureArray[1] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 98 | featureArray[2] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 99 | featureArray[3] = clamp(similarPixelBase + random.nextDouble(-tinyPixelOffset, tinyPixelOffset), 0, 1); 100 | break; 101 | case UNKNOWN: 102 | featureArray[0] = random.nextDouble(0, 0.25); 103 | featureArray[1] = random.nextDouble(0.25, 0.5); 104 | featureArray[2] = random.nextDouble(0.5, 0.75); 105 | featureArray[3] = random.nextDouble(0.75, 1); 106 | shuffleArray(random, featureArray); 107 | break; 108 | default: 109 | throw new IllegalStateException("No case known for: " + direction); 110 | } 111 | 112 | // Assign which output neuron should be a 1 and which ones should be 0 113 | labelArray[direction.getOutputNeuronIndex()] = 1d; 114 | 115 | INDArray featureNDArray = Nd4j.create(featureArray); 116 | INDArray labelNDArray = Nd4j.create(labelArray); 117 | 118 | return new Pair<>(featureNDArray, labelNDArray); 119 | } 120 | 121 | private static double clamp(double value, double min, double max) { 122 | return Math.max(min, Math.min(max, value)); 123 | } 124 | 125 | // Implementing Fisher–Yates shuffle from StackOverflow 126 | private static void shuffleArray(ThreadLocalRandom random, double[] ar) { 127 | for (int i = ar.length - 1; i > 0; i--) { 128 | int index = random.nextInt(i + 1); 129 | // Simple swap 130 | double a = ar[index]; 131 | ar[index] = ar[i]; 132 | ar[i] = a; 133 | } 134 | } 135 | 136 | public enum SampleDirection { 137 | // DON'T CHANGE THIS ORDER (unless you want to recreate the network) 138 | VERTICAL(0), 139 | HORIZONTAL(1), 140 | DIAGONAL(2), 141 | SOLID(3), 142 | UNKNOWN(4); 143 | 144 | private final int outputNeuronIndex; 145 | 146 | SampleDirection(int outputNeuronIndex) { 147 | this.outputNeuronIndex = outputNeuronIndex; 148 | } 149 | 150 | public int getOutputNeuronIndex() { 151 | return outputNeuronIndex; 152 | } 153 | 154 | public static SampleDirection fromNeuronIndex(int outputNeuronIndex) { 155 | for (SampleDirection sampleDirection : values()) { 156 | if (sampleDirection.getOutputNeuronIndex() == outputNeuronIndex) { 157 | return sampleDirection; 158 | } 159 | } 160 | return null; 161 | } 162 | } 163 | } 164 | -------------------------------------------------------------------------------- /src/main/java/net/jacobpeterson/ed/DL4JTest.java: -------------------------------------------------------------------------------- 1 | package net.jacobpeterson.ed; 2 | 3 | import org.deeplearning4j.common.resources.DL4JResources; 4 | import org.deeplearning4j.common.resources.ResourceType; 5 | import org.deeplearning4j.datasets.iterator.INDArrayDataSetIterator; 6 | import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator; 7 | import org.deeplearning4j.nn.api.Model; 8 | import org.deeplearning4j.nn.api.OptimizationAlgorithm; 9 | import org.deeplearning4j.nn.conf.BackpropType; 10 | import org.deeplearning4j.nn.conf.MultiLayerConfiguration; 11 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 12 | import org.deeplearning4j.nn.conf.layers.DenseLayer; 13 | import org.deeplearning4j.nn.conf.layers.OutputLayer; 14 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; 15 | import org.deeplearning4j.nn.weights.WeightInit; 16 | import org.deeplearning4j.optimize.api.BaseTrainingListener; 17 | import org.deeplearning4j.optimize.listeners.ScoreIterationListener; 18 | import org.deeplearning4j.util.ModelSerializer; 19 | import org.nd4j.evaluation.classification.Evaluation; 20 | import org.nd4j.linalg.activations.Activation; 21 | import org.nd4j.linalg.api.ndarray.INDArray; 22 | import org.nd4j.linalg.lossfunctions.LossFunctions; 23 | import org.nd4j.linalg.primitives.Pair; 24 | 25 | import java.io.File; 26 | import java.io.IOException; 27 | import java.util.ArrayList; 28 | import java.util.List; 29 | 30 | public class DL4JTest { 31 | 32 | public static void main(String[] args) throws Exception { 33 | File networkSaveFile = new File(System.getProperty("user.home"), ".pixel_direction_nn"); 34 | 35 | // createAndTrainPixelDirectionNetwork(networkSaveFile); 36 | loadAndDisplayPixelDirectionNetwork(networkSaveFile); 37 | } 38 | 39 | public static void createAndTrainPixelDirectionNetwork(File networkSaveFile) throws Exception { 40 | int trainSampleSize = 100_000; 41 | int testSampleSize = trainSampleSize; 42 | int inputNeurons = 2 * 2; // 2x2 grayscale image input 43 | int epochs = 20; 44 | double l2value = 1e-10; 45 | 46 | // Generate Training Data 47 | List> trainList = new ArrayList<>(); 48 | for (int i = 0; i < trainSampleSize; i++) { 49 | trainList.add(SampleDataGenerator.generate2x2Sample(null)); 50 | } 51 | // Generate Testing Data 52 | List> testList = new ArrayList<>(); 53 | for (int i = 0; i < testSampleSize; i++) { 54 | testList.add(SampleDataGenerator.generate2x2Sample(null)); 55 | } 56 | 57 | // Create DataSet iterators for generated samples 58 | INDArrayDataSetIterator dataSetIteratorTrain = new INDArrayDataSetIterator(trainList, 5); 59 | INDArrayDataSetIterator dataSetIteratorTest = new INDArrayDataSetIterator(testList, 5); 60 | 61 | // Configure Hyperparameters for NN 62 | MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() 63 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 64 | .l2(l2value) 65 | .list() 66 | .layer(new DenseLayer.Builder() 67 | .nIn(inputNeurons) 68 | .nOut(100) 69 | .activation(Activation.RELU) 70 | .weightInit(WeightInit.XAVIER) 71 | .build() 72 | ).layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) 73 | .nIn(100) 74 | .nOut(SampleDataGenerator.SampleDirection.values().length) 75 | .activation(Activation.SOFTMAX) 76 | .weightInit(WeightInit.XAVIER) 77 | .build()) 78 | .backpropType(BackpropType.Standard) 79 | .pretrain(false) 80 | .build(); 81 | 82 | // Construct Network 83 | System.out.println("Constructing Network"); 84 | MultiLayerNetwork network = new MultiLayerNetwork(config); 85 | network.init(); 86 | network.addListeners(new BaseTrainingListener() { 87 | private int printIterations = trainSampleSize / 10; 88 | 89 | @Override 90 | public void iterationDone(Model model, int iteration, int epoch) { 91 | if (iteration % printIterations == 0) { 92 | // System.out.print("\r"); 93 | System.out.println("Iteration: " + iteration + " Epoch: " + epoch + " Score: " + model.score()); 94 | } 95 | } 96 | }); 97 | // network.addListeners(new TimeIterationListener(2)); 98 | // network.addListeners(new ScoreIterationListener(10)); // THIS NOW WORKS WITH THE UPDATED POM.XML 99 | 100 | 101 | System.out.println("Training network"); 102 | long currentTime = System.currentTimeMillis(); 103 | // Train network 104 | network.fit(dataSetIteratorTrain, epochs); // TRAIN NETWORK WITH CERTAIN EPOCH COUNT 105 | System.out.println("Training took: " + ((double) (System.currentTimeMillis() - currentTime) / 1000 / 60) + 106 | " minutes"); 107 | 108 | if (networkSaveFile != null) { 109 | if (networkSaveFile.exists()) { 110 | System.out.println("Network Save File already exists! Not saving network to file."); 111 | } else { 112 | System.out.println("Saving Network to: " + networkSaveFile.getAbsolutePath()); 113 | network.save(networkSaveFile, true); 114 | } 115 | } 116 | 117 | System.out.println("Evaluating network"); 118 | // Evaluate Network 119 | Evaluation eval = network.evaluate(dataSetIteratorTest); 120 | System.out.println(eval.stats()); 121 | } 122 | 123 | public static void loadAndDisplayPixelDirectionNetwork(File networkSaveFile) throws Exception { 124 | MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(networkSaveFile, true); 125 | System.out.println("Loaded network from: " + networkSaveFile.getAbsolutePath()); 126 | 127 | // Generate Testing Data 128 | int testSampleSize = 50_000; 129 | List> testList = new ArrayList<>(); 130 | for (int i = 0; i < testSampleSize; i++) { 131 | testList.add(SampleDataGenerator.generate2x2Sample(null)); 132 | } 133 | 134 | // Create DataSet iterators for generated samples 135 | INDArrayDataSetIterator dataSetIteratorTest = new INDArrayDataSetIterator(testList, 1); 136 | 137 | System.out.println("Evaluating network"); 138 | // Evaluate Network 139 | Evaluation eval = network.evaluate(dataSetIteratorTest); 140 | System.out.println(eval.stats()); 141 | 142 | System.out.println("Executing intermittent testing and displaying"); 143 | PixelDisplay pixelDisplay = new PixelDisplay(); 144 | while (true) { 145 | Pair testRaster = SampleDataGenerator.generate2x2Sample(null); 146 | pixelDisplay.setGrayScaleRaster(testRaster.getKey()); 147 | 148 | INDArray expectedNDArray = testRaster.getValue(); 149 | SampleDataGenerator.SampleDirection expectedDirection = null; 150 | for (int i = 0; i < expectedNDArray.length(); i++) { 151 | if (expectedNDArray.getDouble(i) == 1D) { 152 | expectedDirection = SampleDataGenerator.SampleDirection.fromNeuronIndex(i); 153 | break; 154 | } 155 | } 156 | SampleDataGenerator.SampleDirection predictedDirection = 157 | SampleDataGenerator.SampleDirection.fromNeuronIndex(network.predict(testRaster.getKey())[0]); 158 | System.out.println( 159 | "SAME?: " + (predictedDirection.equals(expectedDirection) ? "Yes " : "No ") + 160 | "Expected: " + expectedDirection.name() + 161 | " Prediction: " + predictedDirection.name() 162 | ); 163 | Thread.sleep(3000); 164 | } 165 | } 166 | 167 | public static void displayAndTestPixelDirections() throws Exception { 168 | PixelDisplay pixelDisplay2 = new PixelDisplay(); 169 | while (true) { 170 | Pair testRaster = 171 | SampleDataGenerator.generate2x2Sample(null); 172 | pixelDisplay2.setGrayScaleRaster(testRaster.getKey()); 173 | 174 | for (int i = 0; i < 5; i++) { 175 | if (testRaster.getValue().getDouble(i) == 1D) { 176 | System.out.println(SampleDataGenerator.SampleDirection.values()[i].name()); 177 | break; 178 | } 179 | } 180 | 181 | Thread.sleep(2000); 182 | } 183 | } 184 | 185 | public static void executeMNISTExample() throws IOException { 186 | System.out.println("Loading EmnistDataSet"); 187 | System.out.println("HI:" + DL4JResources.getDirectory(ResourceType.DATASET, "EMNIST").getAbsolutePath()); 188 | 189 | EmnistDataSetIterator.Set emnistSet = EmnistDataSetIterator.Set.DIGITS; 190 | 191 | EmnistDataSetIterator emnistTrain = new EmnistDataSetIterator(emnistSet, 15, true); 192 | EmnistDataSetIterator emnistTest = new EmnistDataSetIterator(emnistSet, 15, false); 193 | 194 | int outputNum = EmnistDataSetIterator.numLabels(emnistSet); 195 | int rngSeed = 123; 196 | int rows = 28; 197 | int cols = 28; 198 | 199 | System.out.println("Configuration"); 200 | 201 | MultiLayerConfiguration config = new NeuralNetConfiguration.Builder() 202 | .seed(rngSeed) 203 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 204 | .l2(1e-4) 205 | .list() 206 | .layer(new DenseLayer.Builder() 207 | .nIn(rows * cols) 208 | .nOut(120) 209 | .activation(Activation.RELU) 210 | .weightInit(WeightInit.XAVIER) 211 | .build() 212 | ).layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) 213 | .nIn(120) 214 | .nOut(outputNum) 215 | .activation(Activation.SOFTMAX) 216 | .weightInit(WeightInit.XAVIER) 217 | .build()) 218 | .pretrain(false) 219 | .backprop(true) 220 | .build(); 221 | 222 | System.out.println("Init network"); 223 | 224 | MultiLayerNetwork network = new MultiLayerNetwork(config); 225 | network.setListeners(new ScoreIterationListener(1)); 226 | network.init(); 227 | 228 | System.out.println("Configuration"); 229 | 230 | System.out.println("Training network"); 231 | network.fit(emnistTrain); 232 | 233 | System.out.println("Evaluating"); 234 | Evaluation eval = network.evaluate(emnistTest); 235 | System.out.println(eval.stats()); 236 | 237 | // List strings = network.predict(emnistTest.next().get(0)); 238 | // System.out.println(strings.toString()); 239 | } 240 | } 241 | --------------------------------------------------------------------------------