├── .github └── workflows │ └── codeql.yml ├── .gitignore ├── LICENSE ├── README.md ├── data └── diabetesprediction │ └── diabetes.csv ├── plots ├── .keep └── exampleplot │ └── examplePlot.PNG ├── pom.xml ├── src ├── main │ └── java │ │ ├── de │ │ └── fhws │ │ │ └── easyml │ │ │ ├── ai │ │ │ ├── backpropagation │ │ │ │ ├── BackpropagationTrainer.java │ │ │ │ └── logger │ │ │ │ │ ├── BackpropagationLogger.java │ │ │ │ │ └── loggers │ │ │ │ │ └── ConsoleLogger.java │ │ │ ├── geneticneuralnet │ │ │ │ ├── NNRandomMutator.java │ │ │ │ ├── NNUniformCrossoverRecombiner.java │ │ │ │ ├── NeuralNetFitnessFunction.java │ │ │ │ ├── NeuralNetIndividual.java │ │ │ │ ├── NeuralNetPopulationSupplier.java │ │ │ │ └── NeuralNetSupplier.java │ │ │ └── neuralnetwork │ │ │ │ ├── Backpropagation.java │ │ │ │ ├── Layer.java │ │ │ │ ├── NeuralNet.java │ │ │ │ ├── activationfunction │ │ │ │ ├── ActivationFunction.java │ │ │ │ ├── Sigmoid.java │ │ │ │ └── Tanh.java │ │ │ │ └── costfunction │ │ │ │ ├── CostFunction.java │ │ │ │ └── SummedCostFunction.java │ │ │ ├── geneticalgorithm │ │ │ ├── GeneticAlgorithm.java │ │ │ ├── Individual.java │ │ │ ├── Population.java │ │ │ ├── evolution │ │ │ │ ├── Mutator.java │ │ │ │ ├── Recombiner.java │ │ │ │ ├── Selector.java │ │ │ │ ├── recombiners │ │ │ │ │ └── FillUpRecombiner.java │ │ │ │ └── selectors │ │ │ │ │ ├── EliteSelector.java │ │ │ │ │ ├── PercentageSelector.java │ │ │ │ │ ├── RouletteWheelSelector.java │ │ │ │ │ └── TournamentSelector.java │ │ │ ├── logger │ │ │ │ ├── Logger.java │ │ │ │ └── loggers │ │ │ │ │ ├── ConsoleLogger.java │ │ │ │ │ ├── IntervalConsoleLogger.java │ │ │ │ │ └── graphplotter │ │ │ │ │ ├── GraphPlotLogger.java │ │ │ │ │ └── lines │ │ │ │ │ ├── AvgFitnessLine.java │ │ │ │ │ ├── LineGenerator.java │ │ │ │ │ ├── MaxFitnessLine.java │ │ │ │ │ ├── NQuantilFitnessLine.java │ │ │ │ │ └── WorstFitnessLine.java │ │ │ ├── populationsupplier │ │ │ │ ├── PopulationByFileSupplier.java │ │ │ │ └── PopulationSupplier.java │ │ │ └── 
saver │ │ │ │ └── IntervalSaver.java │ │ │ ├── linearalgebra │ │ │ ├── ApplyAble.java │ │ │ ├── LinearAlgebra.java │ │ │ ├── Matrix.java │ │ │ ├── Randomizer.java │ │ │ └── Vector.java │ │ │ ├── logger │ │ │ └── LoggerInterface.java │ │ │ └── utility │ │ │ ├── FileHandler.java │ │ │ ├── ListUtility.java │ │ │ ├── MathUtility.java │ │ │ ├── MultiThreadHelper.java │ │ │ ├── StreamUtil.java │ │ │ ├── Validator.java │ │ │ ├── WarningLogger.java │ │ │ └── throwingintefaces │ │ │ ├── ExceptionPrintingRunnable.java │ │ │ └── ThrowingRunnable.java │ │ └── example │ │ ├── SimpleFunctionPredictionExample.java │ │ ├── SnakeGameExample │ │ ├── flatgame │ │ │ ├── GameGraphics.java │ │ │ ├── GameLogic.java │ │ │ ├── GraphicsWindow.java │ │ │ └── Paintable.java │ │ └── snakegame │ │ │ ├── Apple.java │ │ │ ├── Item.java │ │ │ ├── Main.java │ │ │ ├── Part.java │ │ │ ├── Snake.java │ │ │ ├── SnakeAi.java │ │ │ ├── SnakeGame.java │ │ │ └── SnakeGameLogic.java │ │ └── diabetesprediction │ │ ├── DiabetesDataSet.java │ │ ├── InputParser.java │ │ ├── Main.java │ │ └── backpropagation │ │ └── MainBackprop.java └── test │ └── java │ ├── testGeneticAlgorithmBlackBox │ ├── GeneticAlgorithmTester.java │ ├── Graph.java │ ├── TSP.java │ └── TestGeneticAlgorithm.java │ ├── testLinearAlgebra │ └── TestVector.java │ ├── testNetworkTrainer │ └── TestNetworkTrainerBlackBox.java │ ├── testNeuralNetwork │ ├── TestBackpropagation.java │ ├── TestInvalidArguments.java │ ├── TestNeuralNetMaths.java │ └── TestNeuralNetSaveAndRead.java │ └── testmultithreadhelper │ └── TestDoOnCollectionMethods.java └── testFiles └── .keep /.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 
6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "master" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "master" ] 20 | schedule: 21 | - cron: '36 18 * * 6' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'java' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Use only 'java' to analyze code written in Java, Kotlin or both 38 | # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both 39 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 40 | 41 | steps: 42 | - name: Checkout repository 43 | uses: actions/checkout@v3 44 | 45 | # Initializes the CodeQL tools for scanning. 46 | - name: Initialize CodeQL 47 | uses: github/codeql-action/init@v2 48 | with: 49 | languages: ${{ matrix.language }} 50 | # If you wish to specify custom queries, you can do so here or in a config file. 51 | # By default, queries listed here will override any specified in a config file. 52 | # Prefix the list here with "+" to use these queries and those in the config file. 53 | 54 | # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 55 | # queries: security-extended,security-and-quality 56 | 57 | 58 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). 
59 | # If this step fails, then you should remove it and run the build manually (see below) 60 | - name: Autobuild 61 | uses: github/codeql-action/autobuild@v2 62 | 63 | # ℹ️ Command-line programs to run using the OS shell. 64 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 65 | 66 | # If the Autobuild fails above, remove it and uncomment the following three lines. 67 | # Modify them (or add more) to build your code; if your project needs custom build steps, refer to the EXAMPLE below for guidance. 68 | 69 | # - run: | 70 | # echo "Run, Build Application using script" 71 | # ./location_of_script_within_repo/buildscript.sh 72 | 73 | - name: Perform CodeQL Analysis 74 | uses: github/codeql-action/analyze@v2 75 | with: 76 | category: "/language:${{matrix.language}}" 77 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /.idea/shelf/ 3 | /workspace.xml 4 | /.idea/ 5 | *.class 6 | .target 7 | 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 tomLamprecht 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /plots/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomLamprecht/Easy-ML-For-Java/f4fcc6525a1fe652add8e7b7ba835ec2a3f799e8/plots/.keep -------------------------------------------------------------------------------- /plots/exampleplot/examplePlot.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomLamprecht/Easy-ML-For-Java/f4fcc6525a1fe652add8e7b7ba835ec2a3f799e8/plots/exampleplot/examplePlot.PNG -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 5 | 4.0.0 6 | 7 | de.fhws 8 | NetworkTrainer 9 | 1.0-SNAPSHOT 10 | 11 | 12 | 13 | AsposeJavaAPI 14 | Aspose Java API 15 | https://repository.aspose.com/repo/ 16 | 17 | 18 | 19 | 20 | 21 | junit 22 | junit 23 | 4.13.2 24 | test 25 | 26 | 27 | org.jetbrains 28 | annotations 29 | RELEASE 30 | compile 31 | 32 | 33 | com.aspose 34 | aspose-cells 35 | 21.6 36 | 37 | 38 | 39 | 40 | 15 41 | 15 42 | 43 | 44 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/ai/backpropagation/BackpropagationTrainer.java: -------------------------------------------------------------------------------- 1 | package 
de.fhws.easyml.ai.backpropagation;

import de.fhws.easyml.ai.backpropagation.logger.BackpropagationLogger;
import de.fhws.easyml.ai.neuralnetwork.Backpropagation;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
import de.fhws.easyml.ai.neuralnetwork.costfunction.CostFunction;
import de.fhws.easyml.ai.neuralnetwork.costfunction.SummedCostFunction;
import de.fhws.easyml.linearalgebra.Vector;

import java.util.List;
import java.util.function.Supplier;

/**
 * Trains a {@link NeuralNet} with backpropagation, one batch per epoch.
 * For every epoch a new {@link Batch} is taken from the batch supplier and passed to
 * {@link Backpropagation#trainBatch}, which adjusts the weights and biases of the neural net.
 * Instances are created through {@link Builder}.
 */
public class BackpropagationTrainer {

    private final NeuralNet neuralNet;
    private final Supplier batchSupplier;
    private final double learningRate;
    private final int epochs;
    private final CostFunction costFunction;
    // may be null; in that case nothing is logged (see logIfPresent)
    private final BackpropagationLogger logger;


    private BackpropagationTrainer(NeuralNet neuralNet, Supplier batchSupplier, double learningRate, int epochs, CostFunction costFunction, BackpropagationLogger logger) {
        this.neuralNet = neuralNet;
        this.batchSupplier = batchSupplier;
        this.learningRate = learningRate;
        this.epochs = epochs;
        this.costFunction = costFunction;
        this.logger = logger;
    }



    /**
     * Runs the training: {@code epochs} iterations, each on a batch freshly obtained from the batch supplier.
     */
    public void train() {
        for (int i = 0; i < epochs; i++) {
            doEpoch(i, batchSupplier.get());
        }
    }

    /**
     * Trains the neural net on a single batch and hands the result to the logger, if one is set.
     *
     * @param i     index of the current epoch, starting at 0
     * @param batch inputs and expected outputs used for this epoch
     */
    private void doEpoch(int i, Batch batch) {
        Backpropagation.BatchTrainingResult result = new Backpropagation(neuralNet).trainBatch(batch.getInputs(), batch.getExpectedOutputs(), costFunction, learningRate);
        logIfPresent(i, result);
    }

    private void logIfPresent(int i, Backpropagation.BatchTrainingResult result) {
        if (logger != null)
            logger.log(i, result);
    }

    /**
     * Builder for {@link BackpropagationTrainer}.
     * The cost function defaults to {@link SummedCostFunction}; no logger is attached unless
     * {@link #withLogger(BackpropagationLogger)} is called.
     */
    public static class Builder {

        private final NeuralNet neuralNet;
        private final Supplier batchSupplier;
        private final double learningRate;
        private final int epochs;
        private CostFunction costFunction = new SummedCostFunction();
        private BackpropagationLogger logger;

        /**
         * Builder for training a Neural Network using a backpropagation approach.
         *
         * @param neuralNet     the network whose weights and biases get trained
         * @param batchSupplier supplies a new Batch for each epoch
         * @param learningRate  factor for how much impact one batch has on the weights; Backpropagation
         *                      checks that it lies between 0 and 1 whenever a batch is trained
         * @param epochs        number of epochs, i.e. how many batches are drawn from {@code batchSupplier} and trained
         */
        public Builder(NeuralNet neuralNet, Supplier batchSupplier, double learningRate, int epochs) {
            this.neuralNet = neuralNet;
            this.batchSupplier = batchSupplier;
            this.learningRate = learningRate;
            this.epochs = epochs;
        }

        /**
         * Replaces the default {@link SummedCostFunction}.
         */
        public Builder withCostFunction(CostFunction costFunction) {
            this.costFunction = costFunction;
            return this;
        }

        /**
         * Sets a logger that is called once per epoch with the training result of that epoch.
         */
        public Builder withLogger(BackpropagationLogger logger) {
            this.logger = logger;
            return this;
        }

        public BackpropagationTrainer build() {
            return new BackpropagationTrainer(neuralNet, batchSupplier, learningRate, epochs, costFunction, logger);
        }

    }

    /**
     * A batch of training samples: the i-th element of {@code inputs} belongs to the i-th element of
     * {@code expectedOutputs} (Backpropagation validates that both lists have the same size).
     * The lists are stored as given and are not copied.
     */
    public static class Batch{
        private final List inputs;
        private final List expectedOutputs;

        public Batch(List inputs, List expectedOutputs) {
            this.inputs = inputs;
            this.expectedOutputs = expectedOutputs;
        }

        public List getInputs() {
            return inputs;
        }

        public List getExpectedOutputs() {
            return expectedOutputs;
        }
    }


}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/backpropagation/logger/BackpropagationLogger.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.backpropagation.logger;

import de.fhws.easyml.ai.neuralnetwork.Backpropagation;
import de.fhws.easyml.logger.LoggerInterface;

/**
 * Logger that BackpropagationTrainer calls once per epoch with the epoch index and that epoch's
 * {@link Backpropagation.BatchTrainingResult}.
 */
public interface BackpropagationLogger extends LoggerInterface {

}
--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/backpropagation/logger/loggers/ConsoleLogger.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.backpropagation.logger.loggers;

import de.fhws.easyml.ai.backpropagation.logger.BackpropagationLogger;
import de.fhws.easyml.ai.neuralnetwork.Backpropagation;

/**
 * {@link BackpropagationLogger} that prints one line per epoch to standard output.
 */
public class ConsoleLogger implements BackpropagationLogger {

    /**
     * Prints the epoch index followed by {@code input.avg()} (presumably the average cost of the batch).
     */
    @Override
    public void log(int epoch, Backpropagation.BatchTrainingResult input) {
        System.out.println("EPOCH " + epoch + ": " + input.avg());
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NNRandomMutator.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.geneticneuralnet;

import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.utility.MultiThreadHelper;
import de.fhws.easyml.utility.Validator;
import de.fhws.easyml.geneticalgorithm.evolution.Mutator;
import org.jetbrains.annotations.Nullable;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.Consumer;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Stream;

/**
 * Mutator that randomly changes the weights and biases of the neural nets in a population.
 */
public class NNRandomMutator implements Mutator {

    // probability that an individual of the population is selected for mutation
    private final double outerMutationRate;

    // applied to every weight and bias of a selected neural net; mutates each value with innerMutationRate
    private final DoubleUnaryOperator innerMutator;

    /**
     * Creates a Random Mutator for Neural Networks.
     * Each Neural Net of the population is chosen for mutation with the probability {@code outerMutationRate}.
     * Every weight and bias of a chosen Neural Net is then mutated with the probability {@code innerMutationRate}.
     * For a mutation, a change is calculated from a random value of {@code mutationFactor} like
     * {@code mutationFactor.getInRange() * weight}. If the absolute value of that change is smaller than
     * {@code minMutationAbsolute}, {@code minMutationAbsolute} with the sign of the change is added to the weight.
     * Otherwise, the change itself is added to the weight. A change of zero (e.g. for a weight of 0) is treated
     * as positive.
     *
     * @param outerMutationRate probability (between 0 and 1) that an individual of the population is mutated
     * @param innerMutationRate probability (between 0 and 1) that a single weight or bias of a chosen Neural Net is mutated
     * @param mutationFactor supplies the random factor that is multiplied with the weight
     * @param minMutationAbsolute minimum absolute value added to a mutated weight
     */
    public NNRandomMutator(double outerMutationRate, double innerMutationRate, Randomizer mutationFactor, double minMutationAbsolute ) {
        Validator.value( outerMutationRate ).isBetweenOrThrow( 0, 1 );
        Validator.value( innerMutationRate ).isBetweenOrThrow( 0, 1 );

        this.outerMutationRate = outerMutationRate;

        innerMutator = d -> {
            if ( ThreadLocalRandom.current( ).nextDouble( ) < innerMutationRate ) {
                double mutateValue = d * mutationFactor.getInRange( );
                // keep the sign of the change but enforce a magnitude of at least minMutationAbsolute
                return d + ( mutateValue < 0 ? -1 : 1 ) * Math.max( minMutationAbsolute, Math.abs( mutateValue ) );
            }
            return d;
        };
    }

    /**
     * Mutates the neural nets of the selected individuals in place.
     *
     * @param pop             population whose individuals are filtered by {@code outerMutationRate}
     * @param executorService if non-null, the selected individuals are mutated via
     *                        {@link MultiThreadHelper#callConsumerOnStream}; otherwise on the calling thread
     */
    @Override
    public void mutate(Population pop, @Nullable ExecutorService executorService ) {
        Stream filteredIndividuals = getFilteredIndividuals( pop );

        Consumer mutateConsumer = individual -> mutateNN( individual.getNN( ) );
        if ( executorService != null )
            MultiThreadHelper.callConsumerOnStream( executorService, filteredIndividuals, mutateConsumer );
        else
            filteredIndividuals
                    .forEach( mutateConsumer );
    }

    // Lazily keeps each individual with probability outerMutationRate (evaluated when the stream is consumed).
    private Stream getFilteredIndividuals( Population pop ) {
        return pop.getIndividuals( )
                .stream( )
                .filter( individual -> ThreadLocalRandom.current( ).nextDouble( ) < outerMutationRate );
    }

    // Applies innerMutator to every weight and every bias of every layer.
    private void mutateNN( NeuralNet neuralNet ) {
        neuralNet.getLayers( )
                .forEach( layer -> {
                    layer.getWeights( ).apply( innerMutator );
                    layer.getBias( ).apply( innerMutator );
                } );
    }

    public static class EnsureSingleThreading extends NNRandomMutator {

        /**
         * Creates a Random Mutator for Neural Networks that always uses the single-threaded implementation,
         * even if an ExecutorService is provided.
         * For further documentation see {@link NNRandomMutator#NNRandomMutator(double, double, Randomizer, double)}
         */

        public EnsureSingleThreading( double outerMutationRate, double innerMutationRate, Randomizer mutationFactor, double minMutationAbsolute ) {
            super( outerMutationRate, innerMutationRate, mutationFactor, minMutationAbsolute );
        }

        // Ignores the given executor and always mutates on the calling thread.
        @Override
        public void mutate( Population pop, ExecutorService executorService ) {
            super.mutate( pop, null );
        }
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NNUniformCrossoverRecombiner.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.geneticneuralnet;

import de.fhws.easyml.ai.neuralnetwork.Layer;
import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.utility.ListUtility;
import de.fhws.easyml.utility.MultiThreadHelper;
import de.fhws.easyml.utility.Validator;
import de.fhws.easyml.geneticalgorithm.evolution.Recombiner;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;

/**
 * Recombiner that fills a population up to a goal size with children created by uniform crossover.
 * For every child, {@code amountOfParentsPerChild} parents are drawn at random from the individuals that are
 * in the population when {@link #recombine} is called. The child starts as a copy of the first parent, and then
 * every single weight and bias is taken from a parent that is chosen at random for that value.
 */
public class NNUniformCrossoverRecombiner implements Recombiner {

    private final int amountOfParentsPerChild;

    /**
     * @param amountOfParentsPerChild number of parents each child is combined from; must be between 1 and the
     *                                population size at the time {@link #recombine} is called
     */
    public NNUniformCrossoverRecombiner( int amountOfParentsPerChild ) {
        this.amountOfParentsPerChild = amountOfParentsPerChild;
    }

    /**
     * Adds children to {@code pop} until it contains {@code goalSize} individuals.
     * Nothing is added if the population is already that large.
     *
     * @param pop             population to fill up; its current individuals are the parent pool
     * @param goalSize        size the population should have afterwards
     * @param executorService if non-null, children are created in parallel on this executor; otherwise they are
     *                        created on the calling thread
     */
    @Override
    public void recombine(Population pop, int goalSize, @Nullable ExecutorService executorService ) {
        Validator.value( amountOfParentsPerChild ).isBetweenOrThrow( 1, pop.getSize( ) );

        if ( executorService != null ) {
            recombineMultiThreaded( pop, goalSize, executorService );
        } else {
            recombineSingleThreaded( pop, goalSize );
        }
    }

    private void recombineSingleThreaded( Population pop, int goalSize ) {
        // Snapshot the parent pool so that children created in this call are never chosen as parents of later
        // children. This matches recombineMultiThreaded, which only adds the children once all are created.
        List parentPool = new ArrayList<>( pop.getIndividuals( ) );
        while ( pop.getIndividuals( ).size( ) < goalSize ) {
            List parents = ListUtility.selectRandomElements( parentPool, amountOfParentsPerChild );
            pop.getIndividuals( ).add( makeChild( parents ) );
        }
    }

    private void recombineMultiThreaded( Population pop, int goalSize, ExecutorService executorService ) {

        int neededChildren = goalSize - pop.getIndividuals( ).size( );
        // Like the single-threaded loop, add nothing if the population is already large enough
        // (avoids handing a negative count to getListOutOfSupplier).
        if ( neededChildren <= 0 )
            return;

        pop.getIndividuals( ).addAll(
                MultiThreadHelper.getListOutOfSupplier( executorService,
                        ( ) -> makeChild( ListUtility.selectRandomElements( pop.getIndividuals( ), amountOfParentsPerChild ) ),
                        neededChildren ) );
    }


    // Creates a child with the structure of the first parent and every weight/bias taken from a random parent.
    private NeuralNetIndividual makeChild( List parents ) {
        NeuralNetIndividual child = parents.get( 0 ).copy( );
        for ( int l = 0; l < child.getNN( ).getLayers( ).size( ); l++ ) {
            combineWeights( parents, child, l );

            combineBias( parents, child, l );
        }
        return child;
    }

    // Overwrites each weight of the child's layer with the same weight of a randomly chosen parent.
    private void combineWeights( List parents, NeuralNetIndividual child, int layerIndex ) {
        Layer layer = child.getNN( ).getLayers( ).get( layerIndex );

        for ( int i = 0; i < layer.getWeights( ).getNumRows( ); i++ ) {
            for ( int j = 0; j < layer.getWeights( ).getNumCols( ); j++ ) {
                int selectedParent = ThreadLocalRandom.current( ).nextInt( parents.size( ) );
                layer.getWeights( )
                        .set( i, j, parents.get( selectedParent )
                                .getNN( ).getLayers( )
                                .get( layerIndex )
                                .getWeights( )
                                .get( i, j ) );
            }
        }
    }


    // Overwrites each bias of the child's layer with the same bias of a randomly chosen parent.
    private void combineBias( List parents, NeuralNetIndividual child, int layerIndex ) {
        Layer layer = child.getNN( ).getLayers( ).get( layerIndex );

        for ( int i = 0; i < layer.getBias( ).size( ); i++ ) {
            int selectedParent = ThreadLocalRandom.current( ).nextInt( parents.size( ) );
            layer.getBias( )
                    .set( i, parents.get( selectedParent )
                            .getNN( ).getLayers( )
                            .get( layerIndex )
                            .getBias( )
                            .get( i ) );
        }
    }

    /**
     * Variant that always recombines on the calling thread, even if an ExecutorService is provided.
     */
    public static class EnsureSingleThreading extends NNUniformCrossoverRecombiner {

        public EnsureSingleThreading( int amountOfParentsPerChild ) {
            super( amountOfParentsPerChild );
        }

        @Override
        public void recombine( Population pop, int goalSize, ExecutorService executorService ) {
            super.recombine( pop, goalSize, null );
        }
    }

}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetFitnessFunction.java:
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetFitnessFunction.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.ai.geneticneuralnet; 2 | 3 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet; 4 | 5 | @FunctionalInterface 6 | public interface NeuralNetFitnessFunction { 7 | 8 | double calculateFitness( NeuralNet neuralNet ); 9 | 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetIndividual.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.ai.geneticneuralnet; 2 | 3 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet; 4 | import de.fhws.easyml.linearalgebra.Vector; 5 | import de.fhws.easyml.geneticalgorithm.Individual; 6 | 7 | 8 | public class NeuralNetIndividual implements Individual { 9 | 10 | private NeuralNet neuralNet; 11 | private NeuralNetFitnessFunction fitnessFunction; 12 | private double fitness; 13 | 14 | public NeuralNetIndividual(NeuralNet neuralNet, NeuralNetFitnessFunction fitnessFunction) { 15 | this.neuralNet = neuralNet; 16 | this.fitnessFunction = fitnessFunction; 17 | } 18 | 19 | 20 | public NeuralNet getNN() { 21 | return neuralNet; 22 | } 23 | 24 | @Override 25 | public void calcFitness() { 26 | fitness = fitnessFunction.calculateFitness(neuralNet); 27 | } 28 | 29 | @Override 30 | public double getFitness() { 31 | return fitness; 32 | } 33 | 34 | /** 35 | * Copies the Neural Net but uses the same reference for the fitnessFunction 36 | * @return the copy 37 | */ 38 | @Override 39 | public NeuralNetIndividual copy() { 40 | NeuralNetIndividual copy = new NeuralNetIndividual(neuralNet.copy(), fitnessFunction); 41 | copy.fitness = this.fitness; 42 | return copy; 43 | } 44 | 45 | public Vector calcOutput(Vector vector) { 46 | return neuralNet.calcOutput(vector); 47 | } 48 | 49 | 50 | public NeuralNetFitnessFunction 
getFitnessFunction() { 51 | return fitnessFunction; 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetPopulationSupplier.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.ai.geneticneuralnet; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Population; 4 | import de.fhws.easyml.geneticalgorithm.populationsupplier.PopulationSupplier; 5 | 6 | import java.util.stream.Collectors; 7 | import java.util.stream.IntStream; 8 | 9 | public class NeuralNetPopulationSupplier implements PopulationSupplier { 10 | 11 | private final NeuralNetSupplier neuralNetSupplier; 12 | private final int populationSize; 13 | private final NeuralNetFitnessFunction fitnessFunction; 14 | 15 | public NeuralNetPopulationSupplier( NeuralNetSupplier neuralNetSupplier, NeuralNetFitnessFunction fitnessFunction, int populationSize ) { 16 | this.neuralNetSupplier = neuralNetSupplier; 17 | this.populationSize = populationSize; 18 | this.fitnessFunction = fitnessFunction; 19 | } 20 | 21 | @Override 22 | public Population get( ) { 23 | return new Population<>( IntStream.range( 0, populationSize ) 24 | .mapToObj( i -> new NeuralNetIndividual( neuralNetSupplier.get(), fitnessFunction ) ) 25 | .collect( Collectors.toList( ) ) ); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetSupplier.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.ai.geneticneuralnet; 2 | 3 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet; 4 | 5 | import java.util.function.Supplier; 6 | 7 | @FunctionalInterface 8 | public interface NeuralNetSupplier extends Supplier { 9 | 10 | } 11 | -------------------------------------------------------------------------------- 
/src/main/java/de/fhws/easyml/ai/neuralnetwork/Backpropagation.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.ai.neuralnetwork; 2 | 3 | import de.fhws.easyml.ai.neuralnetwork.costfunction.CostFunction; 4 | import de.fhws.easyml.linearalgebra.ApplyAble; 5 | import de.fhws.easyml.linearalgebra.Matrix; 6 | import de.fhws.easyml.linearalgebra.Vector; 7 | import de.fhws.easyml.utility.Validator; 8 | import org.jetbrains.annotations.NotNull; 9 | 10 | import java.util.ArrayList; 11 | import java.util.HashMap; 12 | import java.util.List; 13 | import java.util.Map; 14 | import java.util.stream.Stream; 15 | 16 | public class Backpropagation { 17 | 18 | private final NeuralNet nn; 19 | 20 | public Backpropagation(NeuralNet neuralNet) { 21 | this.nn = neuralNet; 22 | } 23 | 24 | public BatchTrainingResult trainBatch(List inputs, List expectedOutputs, CostFunction costFunction, double learningRate ) { 25 | validateTrainingBatchInput(inputs, expectedOutputs, learningRate); 26 | 27 | return doTrainBatch(inputs, expectedOutputs, costFunction, learningRate); 28 | } 29 | 30 | @NotNull 31 | private BatchTrainingResult doTrainBatch(List inputs, List expectedOutputs, CostFunction costFunction, double learningRate) { 32 | BatchTrainingResult result = new BatchTrainingResult(); 33 | 34 | updateLayers(learningRate, getAveragedGradients( 35 | createGradientsFromBatch(inputs, expectedOutputs, costFunction, result ) 36 | ) ); 37 | return result; 38 | } 39 | 40 | private void validateTrainingBatchInput(List inputs, List expectedOutputs, double learningRate) { 41 | Validator.value( inputs.size() ).isEqualToOrThrow( expectedOutputs.size() ); 42 | Validator.value(learningRate).isBetweenOrThrow( 0, 1 ); 43 | } 44 | 45 | private List> createGradientsFromBatch(List inputs, List expectedOutputs, CostFunction costFunction, BatchTrainingResult batchTrainingResult ) { 46 | List> gradientsList = new ArrayList<>(); 47 | for ( int i = 0; i 
< inputs.size(); i++ ) 48 | gradientsList.add( calcGradients( inputs.get( i ), expectedOutputs.get( i ), costFunction, batchTrainingResult ) ); 49 | return gradientsList; 50 | } 51 | 52 | private void updateLayers( double learningRate, Map resultGradients ) { 53 | for ( int i = 0; i < nn.layers.size(); i++ ) { 54 | updateLayer(learningRate, resultGradients.get(i), i); 55 | } 56 | } 57 | 58 | private void updateLayer(double learningRate, GradientsObject resultGradients, int layerIndex) { 59 | updateLayerWeights(nn.layers.get(layerIndex).getWeights(), learningRate, resultGradients.weightGradients); 60 | updateLayerBiases(nn.layers.get(layerIndex).getBias(), learningRate, resultGradients.biasGradients); 61 | } 62 | 63 | private void updateLayerBiases(Vector cur, double learningRate, Vector gradients){ 64 | for (int i = 0; i < cur.size(); i++) 65 | cur.set(i, cur.get( i ) + learningRate * gradients.get( i ) ); 66 | } 67 | 68 | private void updateLayerWeights(Matrix cur, double learningRate, Matrix gradients ) { 69 | for ( int row = 0; row < cur.getNumRows(); row++ ) { 70 | for ( int col = 0; col < cur.getNumCols(); col++ ) { 71 | cur.set( row, col, cur.get( row, col ) - learningRate * gradients.get( row, col ) ); 72 | } 73 | } 74 | } 75 | 76 | private Map getAveragedGradients( List> gradientsList ) { 77 | Map resultGradients = createEmptyResultGradientMap( gradientsList ); 78 | 79 | sumUpGradients( gradientsList, resultGradients ); 80 | 81 | averageOutGradients( gradientsList, resultGradients ); 82 | 83 | return resultGradients; 84 | } 85 | 86 | private void sumUpGradients( List> gradientsList, Map resultGradients ) { 87 | for ( Map gradients : gradientsList ) { 88 | addGradientsToResult( resultGradients, gradients ); 89 | } 90 | } 91 | 92 | private void averageOutGradients( List> gradientsList, Map resultGradients ) { 93 | resultGradients.values() 94 | .stream() 95 | .flatMap(e -> Stream.>of(e.weightGradients, e.biasGradients)) 96 | .forEach(a -> a.apply(v -> v / 
gradientsList.size() ) ); 97 | } 98 | 99 | private void addGradientsToResult( Map resultGradients, Map gradients ) { 100 | for ( Map.Entry entry : gradients.entrySet() ) { 101 | addWeightGradientsToResult(resultGradients, entry); 102 | addBiasGradientsToResult(resultGradients, entry); 103 | } 104 | } 105 | 106 | private void addBiasGradientsToResult(Map resultGradients, Map.Entry entry) { 107 | Vector resultVector = resultGradients.get( entry.getKey() ).biasGradients; 108 | addEachGradientToResultVector( entry.getValue().biasGradients, resultVector ); 109 | } 110 | 111 | private void addWeightGradientsToResult(Map resultGradients, Map.Entry entry) { 112 | Matrix resultMatrix = resultGradients.get( entry.getKey() ).weightGradients; 113 | addEachGradientToResultMatrix( entry.getValue().weightGradients, resultMatrix ); 114 | } 115 | 116 | private void addEachGradientToResultVector(Vector biasGradients, Vector resultVector) { 117 | for (int i = 0; i < resultVector.size(); i++) { 118 | resultVector.set(i, resultVector.get( i ) + biasGradients.get( i ) ); 119 | } 120 | } 121 | 122 | private void addEachGradientToResultMatrix( Matrix gradients, Matrix resultMatrix ) { 123 | for ( int i = 0; i < resultMatrix.getNumRows(); i++ ) { 124 | for ( int j = 0; j < resultMatrix.getNumCols(); j++ ) { 125 | resultMatrix.set( i, j, resultMatrix.get( i, j ) + gradients.get( i, j ) ); 126 | } 127 | } 128 | } 129 | 130 | private Map createEmptyResultGradientMap( List> gradientsList ) { 131 | Map resultGradients = new HashMap<>(); 132 | for ( int i = 0; i < gradientsList.get( 0 ).size(); i++ ) { 133 | Matrix resultMatrix = new Matrix( gradientsList.get( 0 ).get( i ).weightGradients.getNumRows(), gradientsList.get( 0 ).get( i ).weightGradients.getNumCols() ); 134 | Vector resultVector = new Vector(gradientsList.get( 0 ).get( i ).biasGradients.size() ); 135 | resultGradients.put( i, new GradientsObject(resultMatrix, resultVector)); 136 | } 137 | return resultGradients; 138 | } 139 | 140 | 
private Map calcGradients( Vector input, Vector expectedOutput, CostFunction costFunction, BatchTrainingResult batchTrainingResult ) { 141 | nn.validateInputVector( input ); 142 | 143 | Map layerIndexToTrainingResult = feedForward( input ); 144 | 145 | addCostsToTrainingBatchResult(expectedOutput, costFunction, batchTrainingResult, layerIndexToTrainingResult); 146 | 147 | return backpropagate( expectedOutput, costFunction, layerIndexToTrainingResult, input ); 148 | } 149 | 150 | private void addCostsToTrainingBatchResult(Vector expectedOutput, CostFunction costFunction, BatchTrainingResult batchTrainingResult, Map layerIndexToTrainingResult) { 151 | batchTrainingResult.addCost(costFunction.costs(expectedOutput, layerIndexToTrainingResult.get(nn.layers.size() - 1).getOutputWithActivationFunction())); 152 | } 153 | 154 | private Map backpropagate( Vector expectedOutput, CostFunction costFunction, Map layerIndexToTrainingResult, Vector input ) { 155 | Map layerIndexToGradient = new HashMap<>(); 156 | for ( int i = nn.layers.size() - 1; i >= 0; i-- ) { 157 | calculateDerivativesOfCostFunction( i, expectedOutput, costFunction, layerIndexToTrainingResult ); 158 | Matrix weightGradients = weightGradientsForEachWeightOfLayer( layerIndexToTrainingResult, i, nn.layers.get( i ), input ); 159 | Vector biasGradients = biasGradientsForEachBiasOfLayer( layerIndexToTrainingResult, i, nn.layers.get( i )); 160 | layerIndexToGradient.put( i, new GradientsObject(weightGradients, biasGradients)); 161 | } 162 | return layerIndexToGradient; 163 | } 164 | 165 | private Vector biasGradientsForEachBiasOfLayer(Map layerIndexToTrainingResult, int i, Layer layer) { 166 | Vector gradients = new Vector( layer.getBias().size() ); 167 | for (int j = 0; j < gradients.size(); j++) { 168 | double dervOfActivationFunction = getDerivativeOfActivationFunctionFromLayer(layer, j, layerIndexToTrainingResult.get(i)); 169 | double dervOfCostFunction = layerIndexToTrainingResult.get( i 
).getDervOfCostFunction().get( j ); 170 | double value = dervOfActivationFunction * dervOfCostFunction; 171 | gradients.set(j, value); 172 | } 173 | return gradients; 174 | } 175 | 176 | private double getDerivativeOfActivationFunctionFromLayer(Layer layer, int neuronIndex, Layer.TrainingResult trainingResult) { 177 | return layer.getFActivation().derivative(trainingResult.getOutputStripped().get(neuronIndex)); 178 | } 179 | 180 | private Matrix weightGradientsForEachWeightOfLayer( Map layerIndexToTrainingResult, int i, Layer layer, Vector input ) { 181 | Matrix gradients = new Matrix( layer.getWeights().getNumRows(), layer.getWeights().getNumCols() ); 182 | for ( int j = 0; j < gradients.getNumRows(); j++ ) { 183 | for ( int k = 0; k < gradients.getNumCols(); k++ ) { 184 | double neuronKOfPrevLayer = i > 0 ? layerIndexToTrainingResult.get( i - 1 ).getOutputWithActivationFunction().get( k ) : input.get( k ); 185 | double dervOfActivationFunction = getDerivativeOfActivationFunctionFromLayer(layer, j, layerIndexToTrainingResult.get( i )); 186 | double dervOfCostFunction = layerIndexToTrainingResult.get( i ).getDervOfCostFunction().get( j ); 187 | double value = neuronKOfPrevLayer * dervOfActivationFunction * dervOfCostFunction; 188 | gradients.set( j, k, value ); 189 | } 190 | } 191 | return gradients; 192 | } 193 | 194 | private void calculateDerivativesOfCostFunction( int indexOfLayer, Vector expectedOutput, CostFunction costFunction, Map layerIndexToTrainingResult ) { 195 | if ( indexOfLayer == nn.layers.size() - 1 ) 196 | calculateCostFunctionDerivForLastLayer( expectedOutput, costFunction, layerIndexToTrainingResult ); 197 | else 198 | calculateCostFunctionDerivForAnyLayer( layerIndexToTrainingResult.get( indexOfLayer ), nn.layers.get( indexOfLayer + 1 ), layerIndexToTrainingResult.get( indexOfLayer + 1 ) ); 199 | } 200 | 201 | private void calculateCostFunctionDerivForAnyLayer( Layer.TrainingResult current, Layer followingLayer, Layer.TrainingResult 
followingTrainingResult ) { 202 | current.setDervOfCostFunction( calcCostFunctionDerivative( current, followingLayer, followingTrainingResult ) ); 203 | } 204 | 205 | private void calculateCostFunctionDerivForLastLayer( Vector expectedOutput, CostFunction costFunction, Map layerIndexToTrainingResult ) { 206 | Vector outputVector = layerIndexToTrainingResult.get( nn.layers.size() - 1 ).getOutputWithActivationFunction(); 207 | Vector dervOfOutput = new Vector( outputVector.size() ); 208 | for ( int i = 0; i < nn.layers.get( nn.layers.size() - 1 ).size(); i++ ) { 209 | dervOfOutput.set( i, costFunction.derivativeWithRespectToNeuron( expectedOutput, outputVector, i ) ); 210 | } 211 | layerIndexToTrainingResult.get( nn.layers.size() - 1 ).setDervOfCostFunction( dervOfOutput ); 212 | } 213 | 214 | private Vector calcCostFunctionDerivative( Layer.TrainingResult trainingResultCurrentLayer, Layer followingLayer, Layer.TrainingResult trainingResultFollowingLayer ) { 215 | 216 | Vector derivatives = new Vector( trainingResultCurrentLayer.size() ); 217 | for ( int i = 0; i < trainingResultCurrentLayer.size(); i++ ) { 218 | derivatives.set( i, calcDervOfNeuron( followingLayer, trainingResultFollowingLayer, i ) ); 219 | } 220 | return derivatives; 221 | } 222 | 223 | private double calcDervOfNeuron( Layer followingLayer, Layer.TrainingResult trainingResultFollowingLayer, int neuronIndex ) { 224 | double temp = 0; 225 | for ( int i = 0; i < followingLayer.size(); i++ ) { 226 | double weightFromNeuronIndexToI = followingLayer.getWeights().get( i, neuronIndex ); 227 | double dervOfActivationFunctionWithOutputStripped = getDerivativeOfActivationFunctionFromLayer( followingLayer, i, trainingResultFollowingLayer ); 228 | double dervOfCostFunctionOfNeuronI = trainingResultFollowingLayer.getDervOfCostFunction().get( i ); 229 | temp = weightFromNeuronIndexToI * dervOfActivationFunctionWithOutputStripped * dervOfCostFunctionOfNeuronI; 230 | 231 | } 232 | return temp; 233 | } 234 | 235 | 
private Map feedForward( Vector input ) { 236 | Map trainingResults = new HashMap<>(); 237 | 238 | for ( int i = 0; i < nn.layers.size(); i++ ) { 239 | Layer.TrainingResult result = nn.layers.get( i ).feedForward( input ); 240 | trainingResults.put( i, result ); 241 | input = result.getOutputWithActivationFunction(); 242 | } 243 | 244 | return trainingResults; 245 | } 246 | 247 | private static class GradientsObject { 248 | final Matrix weightGradients; 249 | final Vector biasGradients; 250 | 251 | public GradientsObject(Matrix weightGradients, Vector biasGradients) { 252 | this.weightGradients = weightGradients; 253 | this.biasGradients = biasGradients; 254 | } 255 | 256 | } 257 | 258 | public static class BatchTrainingResult{ 259 | private final List costValues; 260 | 261 | public BatchTrainingResult(List costValues) { 262 | this.costValues = costValues; 263 | } 264 | 265 | public BatchTrainingResult(){ 266 | costValues = new ArrayList<>(); 267 | } 268 | 269 | public void addCost(double cost){ 270 | costValues.add(cost); 271 | } 272 | 273 | public double avg(){ 274 | return costValues.stream().mapToDouble(x -> x).average().orElse(0); 275 | } 276 | 277 | 278 | } 279 | 280 | } 281 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/ai/neuralnetwork/Layer.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.ai.neuralnetwork; 2 | 3 | 4 | import de.fhws.easyml.ai.neuralnetwork.activationfunction.ActivationFunction; 5 | import de.fhws.easyml.linearalgebra.Vector; 6 | import de.fhws.easyml.utility.Validator; 7 | import de.fhws.easyml.linearalgebra.LinearAlgebra; 8 | import de.fhws.easyml.linearalgebra.Matrix; 9 | import de.fhws.easyml.linearalgebra.Randomizer; 10 | 11 | import java.io.Serializable; 12 | 13 | public class Layer implements Serializable { 14 | 15 | private static final long serialVersionUID = -3844443062431620792L; 16 | 17 | private 
final Matrix weights; 18 | private final Vector bias; 19 | private final ActivationFunction fActivation; 20 | 21 | Layer( int size, int sizeOfLayerBefore, ActivationFunction activationFunction ) { 22 | Validator.value( size ).isPositiveOrThrow( ); 23 | Validator.value( sizeOfLayerBefore ).isPositiveOrThrow( ); 24 | 25 | weights = new Matrix( size, sizeOfLayerBefore ); 26 | bias = new Vector( size ); 27 | 28 | this.fActivation = activationFunction; 29 | } 30 | 31 | /** 32 | * Copy Constructor 33 | */ 34 | private Layer( Layer copy ) { 35 | this.weights = new Matrix( copy.weights.getData( ) ); 36 | this.bias = new Vector( copy.bias.getData( ) ); 37 | this.fActivation = copy.fActivation; 38 | } 39 | 40 | /** 41 | * calculates the activation of this layer, based on the given activation of the linked layer 42 | * 43 | * @param activationsOfLayerBefore layer on which the activation is based 44 | * @return a Vector with the activation of this layer as a vector 45 | * @throws IllegalArgumentException if the number of columns of the weights does not fit to the size of activationsOfLayerBefore 46 | */ 47 | public Vector calcActivation( Vector activationsOfLayerBefore ) { 48 | validateActivationsOfLayerBefore( activationsOfLayerBefore ); 49 | 50 | return calcOutput( activationsOfLayerBefore ).apply( fActivation ); 51 | } 52 | 53 | private Vector calcOutput( Vector activationsOfLayerBefore ) { 54 | return LinearAlgebra.multiply( weights, activationsOfLayerBefore ).sub( bias ); 55 | } 56 | 57 | private void validateActivationsOfLayerBefore( Vector activationsOfLayerBefore ) { 58 | Validator.value( weights.getNumCols( ) ) 59 | .isEqualToOrThrow( 60 | activationsOfLayerBefore.size( ), 61 | () -> new IllegalArgumentException( "size of activationsOfLayerBefore must fit with weights columns" ) 62 | ); 63 | } 64 | 65 | 66 | protected TrainingResult feedForward( Vector activationsOfLayerBefore ){ 67 | validateActivationsOfLayerBefore( activationsOfLayerBefore ); 68 | 69 | Vector 
outputStripped = calcOutput( activationsOfLayerBefore ); 70 | Vector outputWithActivationFunction = outputStripped.applyAsCopy( fActivation ); 71 | 72 | return new TrainingResult( outputStripped, outputWithActivationFunction ); 73 | } 74 | 75 | public void randomize( Randomizer weightRand, Randomizer biasRand ) { 76 | weights.randomize( weightRand ); 77 | bias.randomize( biasRand ); 78 | } 79 | 80 | /** 81 | * gets the number of the nodes in this layer 82 | * 83 | * @return number of nodes in this layer 84 | */ 85 | public int size( ) { 86 | return bias.size( ); 87 | } 88 | 89 | /** 90 | * gets the Matrix of weights of this layer 91 | * 92 | * @return Matrix of weights in this layer 93 | */ 94 | public Matrix getWeights( ) { 95 | return weights; 96 | } 97 | 98 | public Vector getBias( ) { 99 | return this.bias; 100 | } 101 | 102 | public ActivationFunction getFActivation() { 103 | return fActivation; 104 | } 105 | 106 | public Layer copy( ) { 107 | return new Layer( this ); 108 | } 109 | 110 | public static class TrainingResult{ 111 | private Vector outputStripped; 112 | private Vector outputWithActivationFunction; 113 | private Vector dervOfCostFunction; 114 | 115 | public TrainingResult( Vector outputStripped, Vector outputWithActivationFunction ) { 116 | this.outputStripped = outputStripped; 117 | this.outputWithActivationFunction = outputWithActivationFunction; 118 | } 119 | 120 | public Vector getOutputStripped() { 121 | return outputStripped; 122 | } 123 | 124 | public Vector getOutputWithActivationFunction() { 125 | return outputWithActivationFunction; 126 | } 127 | 128 | public Vector getDervOfCostFunction() { 129 | return dervOfCostFunction; 130 | } 131 | 132 | public void setDervOfCostFunction( Vector dervOfCostFunction ) { 133 | this.dervOfCostFunction = dervOfCostFunction; 134 | } 135 | 136 | public int size(){ 137 | return outputStripped.size(); 138 | } 139 | } 140 | 141 | } 142 | 
/src/main/java/de/fhws/easyml/ai/neuralnetwork/NeuralNet.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.neuralnetwork;


import de.fhws.easyml.ai.neuralnetwork.activationfunction.ActivationFunction;
import de.fhws.easyml.ai.neuralnetwork.activationfunction.Sigmoid;
import de.fhws.easyml.ai.neuralnetwork.costfunction.CostFunction;
import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.linearalgebra.Vector;
import de.fhws.easyml.utility.StreamUtil;
import de.fhws.easyml.utility.Validator;

import java.io.Serial;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.IntStream;

/**
 * A feed-forward neural network made of fully connected {@link Layer}s.
 * Instances are created with {@link Builder}.
 */
public class NeuralNet implements Serializable {

    @Serial
    private static final long serialVersionUID = -5984131490435879432L;

    int inputSize;
    protected final List<Layer> layers;

    private NeuralNet( int inputSize ) {
        this.inputSize = inputSize;
        layers = new ArrayList<>();
    }

    private NeuralNet( int inputSize, List<Layer> layers ) {
        this.inputSize = inputSize;
        this.layers = layers;
    }

    /**
     * Calculates the output based on the given input vector.
     *
     * @param input vector with the input values; must have the input size given to the Builder
     * @return the calculated output vector (activation of the last layer)
     * @throws IllegalArgumentException if the size of {@code input} is not the specified one
     */
    public Vector calcOutput( Vector input ) {
        return calcAllLayer( input ).get( layers.size() - 1 );
    }



    /**
     * Calculates the activations of all layers for the given input vector.
     *
     * @param input vector with the input values; must have the input size given to the Builder
     * @return all calculated vectors (hidden layers followed by the output layer); the input itself is not included
     * @throws IllegalArgumentException if the size of {@code input} is not the specified one
     */
    public List<Vector> calcAllLayer( Vector input ) {
        validateInputVector( input );

        return doCalcLayers( input );
    }

    /**
     * Checks that {@code input} has exactly {@code inputSize} entries.
     * NOTE(review): the message mentions "the first layer", but the check is against the network's input size.
     *
     * @throws IllegalArgumentException if the size does not match
     */
    protected void validateInputVector( Vector input ) {
        Validator.value( input.size() )
                .isEqualToOrThrow(
                        inputSize,
                        () -> new IllegalArgumentException( "the input vector must be of the same size as the first layer" )
                );
    }

    private List<Vector> doCalcLayers( Vector input ) {
        // slot 0 temporarily holds the input, so layer i reads its predecessor's activation at index i
        final List<Vector> list = new ArrayList<>( layers.size() + 1 );
        list.add( input );

        StreamUtil.of( layers.stream() )
                .forEachIndexed( ( layer, i ) -> list.add( layer.calcActivation( list.get( i ) ) ) );

        list.remove( 0 );
        return list;
    }

    /**
     * Re-initialises the weights and biases of every layer.
     *
     * @return this network, for chaining
     */
    public NeuralNet randomize( Randomizer weightRand, Randomizer biasRand ) {
        layers.forEach( layer -> layer.randomize( weightRand, biasRand ) );

        return this;
    }


    /**
     * @return the live list of layers (not a copy)
     */
    public List<Layer> getLayers() {
        return layers;
    }

    /**
     * Copies the current NeuralNet layer by layer via {@link Layer#copy()}.
     *
     * @return copy of the current NeuralNet
     */
    public NeuralNet copy() {
        final List<Layer> copiedLayers = new ArrayList<>();

        layers.forEach( layer -> copiedLayers.add( layer.copy() ) );

        return new NeuralNet( this.inputSize, copiedLayers );
    }


    /**
     * Step-by-step construction of a {@link NeuralNet}. A Builder can build exactly one network.
     */
    public static class Builder {
        private final int inputSize;
        private final int outputSize;
        private final List<Integer> layerSizes = new ArrayList<>();
        private ActivationFunction activationFunction;
        private Randomizer weightRand = new Randomizer( -1, 1 );
        private Randomizer biasRand = new Randomizer( 0, 1 );
        private final AtomicBoolean isBuilt = new AtomicBoolean( false );

        /**
         * Constructor to create a Builder which is capable of building a NeuralNet.
         * Defaults: {@link Sigmoid} activation, weights from {@code new Randomizer( -1, 1 )},
         * biases from {@code new Randomizer( 0, 1 )}.
         *
         * @param inputSize  number of input values; must be positive
         * @param outputSize number of neurons of the output layer; must be positive
         * @throws IllegalArgumentException if inputSize or outputSize is less than 1
         */
        public Builder( int inputSize, int outputSize ) {
            Validator.value( inputSize ).isPositiveOrThrow();
            Validator.value( outputSize ).isPositiveOrThrow();

            this.inputSize = inputSize;
            this.outputSize = outputSize;
            activationFunction = new Sigmoid();
        }

        /**
         * Sets the activation function used by every layer of the built network.
         * It is only read in {@link #build()}, so it may be called before or after adding layers.
         *
         * @param activationFunction ActivationFunction (maps a double to a double) applied on
         *                           every layer on calculation
         * @return this
         */
        public Builder withActivationFunction( ActivationFunction activationFunction ) {
            this.activationFunction = activationFunction;
            return this;
        }

        /**
         * Sets the randomizer used for the initial weights.
         *
         * @return this
         */
        public Builder withWeightRandomizer( Randomizer weightRandomizer ) {
            this.weightRand = weightRandomizer;
            return this;
        }

        /**
         * Sets the randomizer used for the initial biases.
         *
         * @return this
         */
        public Builder withBiasRandomizer( Randomizer biasRandomizer ) {
            this.biasRand = biasRandomizer;
            return this;
        }

        /**
         * Adds a hidden layer to the neural network.
         *
         * @param sizeOfLayer the number of nodes of the added layer
         * @return this
         * @throws IllegalArgumentException if sizeOfLayer is 0 or smaller
         */
        public Builder addLayer( int sizeOfLayer ) {
            Validator.value( sizeOfLayer ).isPositiveOrThrow();

            layerSizes.add( sizeOfLayer );

            return this;
        }

        /**
         * Adds {@code amountOfToAddedLayers} hidden layers of equal size to the neural network.
         *
         * @param amountOfToAddedLayers the number of layers to add; must be positive
         * @param sizeOfLayers          the number of nodes of each added layer; must be positive
         * @return this
         */
        public Builder addLayers( int amountOfToAddedLayers, int sizeOfLayers ) {
            Validator.value( amountOfToAddedLayers ).isPositiveOrThrow();

            IntStream.range( 0, amountOfToAddedLayers ).forEach( i -> addLayer( sizeOfLayers ) );

            return this;
        }

        /**
         * Adds hidden layers of the specified sizes, in order, to the neural network.
         *
         * @param sizesOfLayers array of the number of nodes of each added layer
         * @return this
         */
        public Builder addLayers( int... sizesOfLayers ) {
            Arrays.stream( sizesOfLayers ).forEach( this::addLayer );

            return this;
        }

        /**
         * Builds the NeuralNet.
         * It consists of the added layers followed by an output layer of {@code outputSize} neurons,
         * all using the configured activation function, and is randomized before it is returned.
         *
         * @return the built NeuralNet
         * @throws IllegalStateException if this builder has already been used for building
         */
        public NeuralNet build() {
            if ( isBuilt.getAndSet( true ) )
                throw new IllegalStateException( "this builder has already been used for building" );

            layerSizes.add( outputSize );

            NeuralNet nn = new NeuralNet( inputSize );

            StreamUtil.of( layerSizes.stream() )
                    .forEachWithBefore( inputSize, ( current, before ) ->
                            nn.layers.add( new Layer( current, before, activationFunction ) ) );

            return nn.randomize( weightRand, biasRand );
        }

    }

}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/neuralnetwork/activationfunction/ActivationFunction.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.neuralnetwork.activationfunction;

import java.io.Serializable;
import java.util.function.DoubleUnaryOperator;

/**
 * Scalar activation function that a layer applies element-wise.
 * It is also usable as a {@link DoubleUnaryOperator}.
 */
@FunctionalInterface
public interface ActivationFunction extends DoubleUnaryOperator, Serializable {
    double applyActivation( double x );

    @Override
    default double applyAsDouble( double x ) {
        return applyActivation( x );
    }

    /**
     * Derivative of the activation function at {@code x}.
     * Backpropagation calls it with the pre-activation value.
     *
     * @throws UnsupportedOperationException if the implementation does not override this method
     */
    default double derivative( double x){
        throw new UnsupportedOperationException("This ActivationFunction doesn't provide an implementation for its derivative");
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/neuralnetwork/activationfunction/Sigmoid.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.neuralnetwork.activationfunction;

/**
 * Logistic sigmoid 1 / (1 + e^-x), evaluated as (1 + tanh(x / 2)) / 2.
 */
public class Sigmoid implements ActivationFunction{
    @Override
    public double applyActivation( double d ) {
        return ( 1 + Math.tanh( d / 2 ) ) / 2;
    }

    /** sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)); the sigmoid is evaluated only once. */
    @Override
    public double derivative( double x ) {
        double sigmoid = applyActivation( x );
        return sigmoid * ( 1 - sigmoid );
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/neuralnetwork/activationfunction/Tanh.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.neuralnetwork.activationfunction;

/**
 * Hyperbolic tangent activation with derivative 1 - tanh(x)^2.
 */
public class Tanh implements ActivationFunction{

    @Override
    public double applyActivation(double x) {
        return Math.tanh(x);
    }

    @Override
    public double derivative(double x) {
        double tanh = applyActivation(x);
        return 1 - (tanh * tanh);
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/neuralnetwork/costfunction/CostFunction.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.ai.neuralnetwork.costfunction;

import de.fhws.easyml.linearalgebra.Vector;

/**
 * Cost of a network output compared to the expected output.
 */
public interface CostFunction {

    /** @return the cost of {@code actual} given {@code expected} */
    double costs( Vector expected, Vector actual);

    /** @return the partial derivative of the cost with respect to output neuron {@code indexOfNeuron} */
    double derivativeWithRespectToNeuron( Vector expected, Vector actual, int indexOfNeuron );
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/neuralnetwork/costfunction/SummedCostFunction.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.ai.neuralnetwork.costfunction; 2 | 3 | import de.fhws.easyml.linearalgebra.Vector; 4 | import de.fhws.easyml.utility.Validator; 5 | 6 | public class SummedCostFunction implements CostFunction{ 7 | 8 | @Override 9 | public double costs( Vector expected, Vector actual ) { 10 | Validator.value( expected.size() ).isEqualToOrThrow( actual.size() ); 11 | double temp = 0; 12 | for ( int i = 0; i < expected.size(); i++ ) { 13 | temp += Math.pow( expected.get( i ) - actual.get( i ), 2 ); 14 | } 15 | return temp; 16 | } 17 | 18 | @Override 19 | public double derivativeWithRespectToNeuron( Vector expected, Vector actual, int indexOfNeuron ) { 20 | return 2 * ( actual.get( indexOfNeuron ) - expected.get( indexOfNeuron ) ); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/GeneticAlgorithm.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm; 2 | 3 | import de.fhws.easyml.geneticalgorithm.populationsupplier.PopulationSupplier; 4 | import de.fhws.easyml.geneticalgorithm.evolution.Mutator; 5 | import de.fhws.easyml.geneticalgorithm.logger.Logger; 6 | import de.fhws.easyml.geneticalgorithm.evolution.Recombiner; 7 | import de.fhws.easyml.geneticalgorithm.evolution.Selector; 8 | import de.fhws.easyml.geneticalgorithm.saver.IntervalSaver; 9 | 10 | import java.io.File; 11 | import java.util.ArrayList; 12 | import java.util.Arrays; 13 | import java.util.List; 14 | import java.util.Optional; 15 | import java.util.concurrent.ExecutorService; 16 | import java.util.concurrent.Executors; 17 | import java.util.concurrent.atomic.AtomicBoolean; 18 | import java.util.function.IntConsumer; 19 | 20 | public class GeneticAlgorithm> { 21 
| 22 | public static final String ILLEGAL_OPERATION_AFTER_SHUTDOWN_MESSAGE = "Genetic Algorithm already shutdowned"; 23 | 24 | private int size; 25 | private final int maxGens; 26 | private final Population population; 27 | private final Selector selector; 28 | private final Recombiner recombiner; 29 | private final Mutator mutator; 30 | 31 | private final IntConsumer getPreparator; 32 | private final IntervalSaver saver; 33 | private final List loggers; 34 | 35 | private final ExecutorService executor; 36 | 37 | private final AtomicBoolean shutdowned = new AtomicBoolean(false); 38 | 39 | private GeneticAlgorithm(PopulationSupplier popSupplier, int maxGens, Selector selector, 40 | Recombiner recombiner, Mutator mutator, IntConsumer getPreparator, IntervalSaver saver, List loggers, 41 | int amountThreads) { 42 | this.population = popSupplier.get(); 43 | this.maxGens = maxGens + population.getGeneration(); 44 | this.size = population.getSize(); 45 | this.selector = selector; 46 | this.recombiner = recombiner; 47 | this.mutator = mutator; 48 | this.getPreparator = getPreparator; 49 | this.saver = saver; 50 | this.loggers = loggers; 51 | 52 | this.executor = amountThreads > 1 ? Executors.newWorkStealingPool( amountThreads ) : null; 53 | 54 | } 55 | 56 | /** 57 | * Solves the given problem with an evolutional approach. 58 | * 59 | * @return T the developed solution for the problem. 
60 | * @throws IllegalStateException if GeneticAlgorithm is already shutdowned 61 | */ 62 | public T solve() { 63 | validateShutdownState(); 64 | 65 | evolute(); 66 | 67 | return getBestAndShutdown(); 68 | } 69 | 70 | private T getBestAndShutdown() 71 | { 72 | T best = getBestIndividual(); 73 | 74 | shutdownGeneticAlgorithm(); 75 | 76 | return best; 77 | } 78 | 79 | private void evolute() 80 | { 81 | for (int i = population.getGeneration(); i < maxGens; i++) 82 | nextGen(); 83 | } 84 | 85 | private void validateShutdownState() 86 | { 87 | if(isShutdowned()) 88 | throw new IllegalStateException(ILLEGAL_OPERATION_AFTER_SHUTDOWN_MESSAGE); 89 | } 90 | 91 | private void shutdownGeneticAlgorithm() 92 | { 93 | shutdowned.set(true); 94 | getExecutor().ifPresent( ExecutorService::shutdown ); 95 | } 96 | 97 | private T getBestIndividual() { 98 | calculateFitnesses(); 99 | return population.getBest(); 100 | } 101 | 102 | private void nextGen() { 103 | prepareNextEvolution(); 104 | 105 | evoluteNextGen(); 106 | 107 | doEvolutionFollowUp(); 108 | } 109 | 110 | private void prepareNextEvolution() { 111 | callGenPreperator(); 112 | 113 | calculateFitnesses(); 114 | 115 | population.incGeneration(); 116 | 117 | callLoggers(); 118 | } 119 | 120 | private void doEvolutionFollowUp() { 121 | callSaver(); 122 | } 123 | 124 | private void callLoggers() { 125 | loggers.forEach(logger -> logger.log(maxGens, population)); 126 | } 127 | 128 | private void callSaver() { 129 | getSaver().ifPresent(saver -> saver.save(population)); 130 | } 131 | 132 | private void evoluteNextGen() { 133 | doSelection(); 134 | 135 | doRecombination(); 136 | 137 | doMutation(); 138 | } 139 | 140 | private void calculateFitnesses() { 141 | population.calcFitnesses(executor); 142 | } 143 | 144 | private void doMutation() { 145 | getMutator().ifPresent(m -> m.mutate(population, executor)); 146 | } 147 | 148 | private void doRecombination() { 149 | getRecombiner().ifPresent(r -> r.recombine(population, size, 
executor)); 150 | size = population.getSize(); 151 | } 152 | 153 | private void doSelection() { 154 | selector.select(population, executor); 155 | } 156 | 157 | private void callGenPreperator() { 158 | getGetPreparator().ifPresent(c -> c.accept(population.getGeneration())); 159 | } 160 | 161 | public boolean isShutdowned() 162 | { 163 | return shutdowned.get(); 164 | } 165 | 166 | public Optional> getRecombiner() { 167 | return Optional.ofNullable(recombiner); 168 | } 169 | 170 | public Optional> getMutator() { 171 | return Optional.ofNullable(mutator); 172 | } 173 | 174 | public Optional getGetPreparator() { 175 | return Optional.ofNullable(getPreparator); 176 | } 177 | 178 | public Optional getSaver() { 179 | return Optional.ofNullable(saver); 180 | } 181 | 182 | public Optional getExecutor() { 183 | return Optional.ofNullable(executor); 184 | } 185 | 186 | public static class Builder> { 187 | 188 | private final int maxGens; 189 | private final PopulationSupplier popSupplier; 190 | private final Selector selector; 191 | private Recombiner recombiner; 192 | private Mutator mutator; 193 | 194 | private IntConsumer genPreperator; 195 | private IntervalSaver saver; 196 | private List loggers = new ArrayList<>(); 197 | 198 | private int amountThreads; 199 | 200 | /** 201 | * Creates a Builder for a Generic Algorithm. 202 | * @param popSupplier provides an initial Population 203 | * @param maxGens are the amount of generation that should be computed 204 | * @param selector is used for the selection process 205 | * @see Java-Doc for Builder-Pattern 206 | */ 207 | public Builder(PopulationSupplier popSupplier, int maxGens, Selector selector) { 208 | this.selector = selector; 209 | this.maxGens = maxGens; 210 | this.popSupplier = popSupplier; 211 | } 212 | 213 | /** 214 | * The underlying Genetic Algorithm will use the given {@code recombiner} to refill the population in the recombination-process. 
215 | * @param recombiner provides the recombination method 216 | * @return itself, due to Builder-Pattern 217 | */ 218 | public Builder withRecombiner(Recombiner recombiner) { 219 | this.recombiner = recombiner; 220 | return this; 221 | } 222 | 223 | /** 224 | * The underlying Genetic Alogrithm will use the give {@code mutator} to mutate the population after the recombination-process 225 | * @param mutator provides the mutation method 226 | * @return itself, due to Builder-Pattern 227 | */ 228 | public Builder withMutator(Mutator mutator) { 229 | this.mutator = mutator; 230 | return this; 231 | } 232 | 233 | /** 234 | * The underlying Genetic Algorithm will call the genPreperator before every new Generation. It can be used e.g. to synchronize Input-Data for the fitness calculation. 235 | * @param genPreperator provides method that gets called 236 | * @return itself, due to Builder-Pattern 237 | */ 238 | public Builder withGenPreperator(IntConsumer genPreperator) { 239 | this.genPreperator = genPreperator; 240 | return this; 241 | } 242 | 243 | /** 244 | * used to save populations to file 245 | * 246 | * @param dir is a directory where the population will be saved in 247 | * @param interval is the interval of which the populations should be saved 248 | * @param override determines whether a new file should be created each time or the old one overridden 249 | * @return itself, due to Builder-Pattern 250 | **/ 251 | public Builder withSaveToFile(File dir, int interval, boolean override) { 252 | this.saver = new IntervalSaver(interval, override, dir); 253 | return this; 254 | } 255 | 256 | /** 257 | * The Loggers will be called after every generation and can be used to log the current state of the solving process 258 | * @param loggers provides logging method 259 | * @return itself, due to Builder-Pattern 260 | */ 261 | public Builder withLoggers(Logger ... 
loggers) { 262 | this.loggers.addAll(Arrays.asList(loggers)); 263 | return this; 264 | } 265 | 266 | /** 267 | * Is used to controll the Multi-Thread behavior of the Genetic Algorith. If {@code amountThreads} is greater than 1 the underlying Genetic Algorithm will use a MultiThread approach if the given Implementations of 268 | * {@link Selector}, {@link Recombiner} and {@link Mutator} support Multi-Threading. 269 | * @param amountThreads are the maximal number of Threads available in the Threadpool 270 | * @return itself, due to Builder-Patten 271 | * @throws IllegalArgumentException if {@code amountThreads} is smaller than 1 272 | */ 273 | public Builder withMultiThreaded(int amountThreads) { 274 | if (amountThreads < 1) 275 | throw new IllegalArgumentException("amount of threads must be in at least 1"); 276 | this.amountThreads = amountThreads; 277 | return this; 278 | } 279 | 280 | /** 281 | * builds the Genetic Algorithm with the previously given Attributes 282 | * @return build Genetic Algorithm 283 | */ 284 | public GeneticAlgorithm build() { 285 | return new GeneticAlgorithm<>(popSupplier, maxGens, selector, recombiner, mutator, genPreperator, saver, 286 | loggers, amountThreads); 287 | } 288 | } 289 | 290 | } 291 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/Individual.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm; 2 | 3 | import java.io.Serializable; 4 | 5 | public interface Individual> extends Comparable>, Serializable{ 6 | 7 | void calcFitness(); 8 | 9 | double getFitness(); 10 | 11 | T copy(); 12 | 13 | //This Warning can be suppressed, because the only way this could crash is 14 | //when a class is implementing Individual but gives another class but itself as the Generic Type 15 | //This is not an expected behavior and the whole program would crash anyway, therefore there is no further 
need to handle this cast 16 | @SuppressWarnings("unchecked") 17 | default T getThis(){ 18 | try { 19 | return (T) this; 20 | } 21 | catch (ClassCastException e) { 22 | throw new RuntimeException("Couldn't get an instance of individual. Maybe your implementation of Individual gives another Generic Type but itself to the Individual."); 23 | } 24 | } 25 | 26 | default int compareTo(Individual o) { 27 | if (getFitness() - o.getFitness() == 0) 28 | return 0; 29 | else 30 | return getFitness() - o.getFitness() < 0 ? -1 : 1; 31 | } 32 | 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/Population.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm; 2 | 3 | import de.fhws.easyml.utility.MultiThreadHelper; 4 | 5 | import java.io.Serializable; 6 | import java.util.*; 7 | import java.util.concurrent.ExecutorService; 8 | import java.util.stream.Collectors; 9 | 10 | public final class Population> implements Serializable { 11 | private List individuals; 12 | 13 | private int generation; 14 | 15 | public Population() { 16 | this.individuals = new ArrayList<>(); 17 | } 18 | 19 | public Population(List individuals) { 20 | this.individuals = new ArrayList<>(individuals); 21 | } 22 | 23 | void calcFitnesses(ExecutorService executor) { 24 | if(executor != null) { 25 | MultiThreadHelper.callConsumerOnCollection(executor, individuals, T::calcFitness); 26 | } 27 | else { 28 | individuals.forEach(T::calcFitness); 29 | } 30 | } 31 | 32 | /** 33 | * sorts the population in DESCENDING order. 34 | * It uses the compareTo method of the solutions for this. 
35 | */ 36 | public void sortPopByFitness() { 37 | individuals.sort(Comparator.reverseOrder()); 38 | } 39 | 40 | public T getBest() { 41 | return Collections.max(individuals); 42 | } 43 | 44 | public double getAverageFitness() { 45 | return individuals.stream().mapToDouble(Individual::getFitness).average().getAsDouble(); 46 | } 47 | 48 | public int getGeneration() { 49 | return generation; 50 | } 51 | 52 | public void incGeneration() { 53 | generation++; 54 | } 55 | 56 | public int getSize() { 57 | return individuals.size(); 58 | } 59 | 60 | /** 61 | * Replaces all Individuals with a copy of the given Individuals if and only if the reference of this Individual occurs more than once in the given collection. 62 | * Otherwise, the Individuals get replaced with the Individuals in {@code collection} 63 | * 64 | * @param collection of Individuals which may be copied 65 | */ 66 | public void replaceAllIndividuals(Collection> collection) { 67 | individuals.clear(); 68 | 69 | Map referenceMap = collection.stream().collect( Collectors.groupingBy( System::identityHashCode, Collectors.counting() ) ); 70 | 71 | individuals.addAll( collection.stream() 72 | .map( individual -> doesElementOccurMoreThanOnce( referenceMap, individual ) ? individual.copy() : individual.getThis() ) 73 | .collect( Collectors.toList() ) ); 74 | } 75 | 76 | private boolean doesElementOccurMoreThanOnce( Map referenceMap, Individual element ) { 77 | return referenceMap.merge( System.identityHashCode( element ), -1L, Long::sum ) > 0; 78 | } 79 | 80 | public List getIndividuals() { 81 | return individuals; 82 | } 83 | 84 | /** 85 | * Replaces all Individuals with the given ones. Be aware that when 2 or more Individuals refer to the same reference there might be unwanted sight effects 86 | * in the mutation process. 
It's therefore recommended using the {@link #replaceAllIndividuals(Collection)} method if one is not sure whether 2 or more Individuals share 87 | * the same reference 88 | * @param individuals that replace the old ones 89 | * @throws NullPointerException if given individuals are null 90 | */ 91 | public void setIndividuals(List individuals) { 92 | if (individuals == null) 93 | throw new NullPointerException("Individuals can't be null"); 94 | this.individuals = individuals; 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/evolution/Mutator.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm.evolution; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Individual; 4 | import de.fhws.easyml.geneticalgorithm.Population; 5 | import org.jetbrains.annotations.Nullable; 6 | 7 | import java.util.concurrent.ExecutorService; 8 | 9 | @FunctionalInterface 10 | public interface Mutator> { 11 | /** 12 | * Used to mutate a population. 13 | * @param pop Population to be mutated 14 | * @param executorService may be null. 
If not it can be used to implement a multithreaded version 15 | */ 16 | void mutate(Population pop, @Nullable ExecutorService executorService); 17 | 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/evolution/Recombiner.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm.evolution; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Individual; 4 | import de.fhws.easyml.geneticalgorithm.Population; 5 | import org.jetbrains.annotations.Nullable; 6 | 7 | import java.util.concurrent.ExecutorService; 8 | 9 | @FunctionalInterface 10 | public interface Recombiner> { 11 | 12 | /** 13 | * Fills up the Population after the Selection step and may recombine the genes of different parents. 14 | * @param pop Population to fill up again 15 | * @param goalSize The total size that the Population needs to have after the Recombination step 16 | * @param executorService may be null. If not it can be used to implement a multithreaded version 17 | */ 18 | void recombine(Population pop, int goalSize, @Nullable ExecutorService executorService); 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/evolution/Selector.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm.evolution; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Individual; 4 | import de.fhws.easyml.geneticalgorithm.Population; 5 | import org.jetbrains.annotations.Nullable; 6 | 7 | import java.util.concurrent.ExecutorService; 8 | 9 | public interface Selector> { 10 | /** 11 | * Throws out the worst Individuals of the population 12 | * @param pop Population that needs to be selected 13 | * @param executorService may be null. 
If not it can be used to implement a multithreaded version 14 | */ 15 | void select(Population pop, @Nullable ExecutorService executorService); 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/evolution/recombiners/FillUpRecombiner.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm.evolution.recombiners; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Individual; 4 | import de.fhws.easyml.geneticalgorithm.Population; 5 | import de.fhws.easyml.geneticalgorithm.evolution.Recombiner; 6 | import org.jetbrains.annotations.Nullable; 7 | 8 | import java.util.concurrent.ExecutorService; 9 | 10 | public class FillUpRecombiner> implements Recombiner { 11 | 12 | @Override 13 | public void recombine(Population pop, int goalSize, @Nullable ExecutorService executorService ) { 14 | while ( pop.getSize() < goalSize ) { 15 | pop.getIndividuals().add( pop.getIndividuals().get( (int) ( Math.random() * pop.getSize() ) ).copy() ); 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/evolution/selectors/EliteSelector.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm.evolution.selectors; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Individual; 4 | import de.fhws.easyml.geneticalgorithm.Population; 5 | import org.jetbrains.annotations.Nullable; 6 | 7 | import java.util.concurrent.ExecutorService; 8 | 9 | public class EliteSelector> extends PercentageSelector { 10 | 11 | public EliteSelector(double percent) { 12 | super(percent); 13 | } 14 | 15 | @Override 16 | public void select(Population pop, @Nullable ExecutorService executorService) { 17 | pop.sortPopByFitness(); 18 | repopulate(pop); 19 | } 20 | 21 | private void repopulate(Population 
pop) { 22 | int goalSize = super.calcGoalSize(pop.getSize()); 23 | while (pop.getSize() > goalSize) 24 | pop.getIndividuals().remove(goalSize); 25 | } 26 | 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/evolution/selectors/PercentageSelector.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm.evolution.selectors; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Individual; 4 | import de.fhws.easyml.geneticalgorithm.evolution.Selector; 5 | 6 | public abstract class PercentageSelector> implements Selector { 7 | private double percent; 8 | 9 | public PercentageSelector(double percent) { 10 | setPercent(percent); 11 | } 12 | 13 | protected int calcGoalSize(int popSize) { 14 | return (int) Math.ceil(popSize * percent); 15 | } 16 | 17 | public double getPercent() { 18 | return percent; 19 | } 20 | 21 | public void setPercent(double percent) { 22 | if(percent < 0 || percent > 1) 23 | throw new IllegalArgumentException("percent must be between 0 and 1 (inclusive)"); 24 | 25 | this.percent = percent; 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/evolution/selectors/RouletteWheelSelector.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm.evolution.selectors; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Individual; 4 | import de.fhws.easyml.geneticalgorithm.Population; 5 | import de.fhws.easyml.utility.MultiThreadHelper; 6 | import de.fhws.easyml.utility.throwingintefaces.ThrowingRunnable; 7 | import org.jetbrains.annotations.Nullable; 8 | 9 | import java.util.ArrayList; 10 | import java.util.Collections; 11 | import java.util.List; 12 | import java.util.concurrent.Callable; 13 | import java.util.concurrent.ExecutorService; 
/**
 * Fitness-proportionate ("roulette wheel") selection: survivors are drawn with replacement, each individual with a
 * probability proportional to its fitness. Only individuals with a strictly positive fitness get a slice of the wheel.
 * If no individual has a positive fitness, survivors are drawn uniformly.
 */
public class RouletteWheelSelector<T extends Individual<T>> extends PercentageSelector<T> {

    private final boolean ensureAddFirst;

    /**
     * @param percent        share of the population that survives the selection (see {@link PercentageSelector})
     * @param ensureAddFirst if true, the best individual is always added as the first survivor
     */
    public RouletteWheelSelector(double percent, boolean ensureAddFirst) {
        super(percent);
        this.ensureAddFirst = ensureAddFirst;
    }

    @Override
    public void select(Population<T> pop, @Nullable ExecutorService executorService) {
        List<Individual<T>> candidates = collectCandidates(pop);

        double totalFitness = calcTotalFitness(candidates);

        // probabilityList.get(i) is the cumulative probability up to and including candidates.get(i)
        List<Double> probabilityList = calcProbabilityList(candidates, totalFitness);

        List<Individual<T>> repopulated = repopulate(pop, candidates, probabilityList, executorService);

        pop.replaceAllIndividuals(repopulated);
    }

    /** @return the individuals that get a slice of the wheel (strictly positive fitness), in population order */
    private List<Individual<T>> collectCandidates(Population<T> pop) {
        List<Individual<T>> candidates = new ArrayList<>(pop.getSize());
        for (Individual<T> individual : pop.getIndividuals())
            if (individual.getFitness() > 0)
                candidates.add(individual);
        return candidates;
    }

    private double calcTotalFitness(List<Individual<T>> candidates) {
        double totalFitness = 0;
        for (Individual<T> candidate : candidates)
            totalFitness += candidate.getFitness();
        return totalFitness;
    }

    /**
     * Builds the cumulative probability list for {@code candidates}. It is index-aligned with {@code candidates}
     * (the old version skipped zero-fitness entries and then indexed the full population, picking wrong individuals).
     */
    private List<Double> calcProbabilityList(List<Individual<T>> candidates, double totalFitness) {
        List<Double> probabilityList = new ArrayList<>(candidates.size());
        double cumulativeProb = 0;
        for (Individual<T> candidate : candidates) {
            cumulativeProb += candidate.getFitness() / totalFitness;
            probabilityList.add(cumulativeProb);
        }
        return probabilityList;
    }

    private List<Individual<T>> repopulate(Population<T> pop, List<Individual<T>> candidates, List<Double> probabilityList,
                                           ExecutorService executorService) {
        int goalSize = super.calcGoalSize(pop.getSize());

        List<Individual<T>> repopulated = new ArrayList<>(goalSize);

        if (ensureAddFirst) {
            repopulated.add(pop.getBest());
        }

        if (executorService != null) {
            return doRepopulateMultiThreaded(pop, candidates, probabilityList, executorService, goalSize, repopulated);
        } else {
            return doRepopulateSingleThreaded(pop, candidates, probabilityList, goalSize, repopulated);
        }
    }

    private List<Individual<T>> doRepopulateSingleThreaded(Population<T> pop, List<Individual<T>> candidates,
                                                           List<Double> probabilityList, int goalSize,
                                                           List<Individual<T>> repopulated) {
        while (repopulated.size() < goalSize) {
            repopulated.add(getElementByProbabilityList(pop, candidates, probabilityList));
        }
        return repopulated;
    }

    private List<Individual<T>> doRepopulateMultiThreaded(Population<T> pop, List<Individual<T>> candidates,
                                                          List<Double> probabilityList, ExecutorService executorService,
                                                          int goalSize, List<Individual<T>> repopulated) {
        List<Individual<T>> synchronizedRepopulated = Collections.synchronizedList(repopulated);

        List<Callable<Void>> calls = getCallsForRepopulation(pop, candidates, probabilityList,
                goalSize - repopulated.size(), synchronizedRepopulated);

        // invokeAll only waits for completion; Future.get() re-throws a failed draw instead of silently losing it
        ThrowingRunnable.unchecked(() -> {
            for (java.util.concurrent.Future<Void> future : executorService.invokeAll(calls))
                future.get();
        }).run();

        return synchronizedRepopulated;
    }

    /** @return {@code amount} tasks that each draw one survivor into {@code synchronizedRepopulated} */
    private List<Callable<Void>> getCallsForRepopulation(Population<T> pop, List<Individual<T>> candidates,
                                                         List<Double> probabilityList, int amount,
                                                         List<Individual<T>> synchronizedRepopulated) {
        List<Callable<Void>> calls = new ArrayList<>();

        for (int i = 0; i < amount; i++) {
            calls.add(MultiThreadHelper.transformToCallableVoid(
                    () -> synchronizedRepopulated.add(getElementByProbabilityList(pop, candidates, probabilityList))
            ));
        }

        return calls;
    }

    /** Spins the wheel once; falls back to a uniform draw from the population if no candidate exists. */
    private Individual<T> getElementByProbabilityList(Population<T> pop, List<Individual<T>> candidates,
                                                      List<Double> probabilityList) {
        if (candidates.isEmpty())
            return pop.getIndividuals().get((int) (Math.random() * pop.getSize()));

        int index = Collections.binarySearch(probabilityList, Math.random());
        int slot = index >= 0 ? index : -index - 1;
        // rounding can leave the last cumulative value slightly below 1.0, which would yield slot == size
        return candidates.get(Math.min(slot, candidates.size() - 1));
    }

    /**
     * A {@link RouletteWheelSelector} that always runs single-threaded, ignoring the given executor.
     */
    public static class EnsuredSingleThread<T extends Individual<T>> extends RouletteWheelSelector<T> {

        public EnsuredSingleThread(double percent, boolean ensureAddFirst) {
            super(percent, ensureAddFirst);
        }

        @Override
        public void select(Population<T> pop, ExecutorService executorService) {
            super.select(pop, null);
        }
    }

    /**
     * Same behaviour as {@link EnsuredSingleThread}; kept as its own type for backward compatibility.
     */
    public static class EnsureSingleThreading<T extends Individual<T>> extends EnsuredSingleThread<T> {

        public EnsureSingleThreading(double percent, boolean ensureAddFirst) {
            super(percent, ensureAddFirst);
        }
    }

}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/evolution/selectors/TournamentSelector.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.evolution.selectors;

import de.fhws.easyml.geneticalgorithm.Individual;
import de.fhws.easyml.geneticalgorithm.Population;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutorService;

/**
 * Selects survivors by repeatedly playing tournaments: {@code tournamentSize} individuals at distinct positions are
 * drawn at random and the fittest of them survives. Survivors are drawn with replacement.
 */
public class TournamentSelector<T extends Individual<T>> extends PercentageSelector<T> {
    private final int tournamentSize;
    private final Random random = new Random();

    public TournamentSelector(double percent, int tournamentSize) {
        super(percent);
        if (tournamentSize < 1)
            throw new IllegalArgumentException("tournamentSize must at least 1");

        this.tournamentSize = tournamentSize;
    }

    /** @throws RuntimeException if the population is smaller than the tournament size */
    @Override
    public void select(Population<T> pop, @Nullable ExecutorService executorService) {
        int goalSize = super.calcGoalSize(pop.getSize());

        if (tournamentSize > pop.getSize())
            throw new RuntimeException("the population size is to small for the tournament size");

        if (pop.getSize() != goalSize) {
            List<Individual<T>> repopulated = repopulate(pop, goalSize);
            pop.replaceAllIndividuals(repopulated);
        }

    }

    /**
     * Draws {@code tournamentSize} distinct positions and returns the fittest individual found there.
     * Distinctness is by position, not by {@code equals}. The loop therefore terminates (select guarantees
     * tournamentSize <= size) even if the same reference occurs several times in the population.
     */
    private Individual<T> playTournament(Population<T> pop) {
        List<T> individuals = pop.getIndividuals();
        Set<Integer> drawnPositions = new HashSet<>();
        List<Individual<T>> participants = new ArrayList<>(tournamentSize);

        while (participants.size() < tournamentSize) {
            int position = random.nextInt(individuals.size());
            if (drawnPositions.add(position))
                participants.add(individuals.get(position));
        }

        return Collections.max(participants);
    }

    private List<Individual<T>> repopulate(Population<T> pop, int goalSize) {
        List<Individual<T>> repopulated = new ArrayList<>(goalSize);

        while (repopulated.size() < goalSize) {
            repopulated.add(playTournament(pop));
        }
        return repopulated;
    }

}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/Logger.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger;

import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.geneticalgorithm.Individual;
import de.fhws.easyml.logger.LoggerInterface;

/**
 * Receives the population after each generation of the genetic algorithm.
 */
public interface Logger extends LoggerInterface<Population<? extends Individual<?>>> {

    @Override
    void log(int maxGen, Population<? extends Individual<?>> population);

}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/loggers/ConsoleLogger.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger.loggers;

import de.fhws.easyml.geneticalgorithm.Population;
import
de.fhws.easyml.geneticalgorithm.Individual;
import de.fhws.easyml.geneticalgorithm.logger.Logger;

/**
 * Writes one summary line per logged generation (generation, best fitness, average fitness) to standard output.
 */
public class ConsoleLogger implements Logger {

    @Override
    public void log(int maxGen, Population<? extends Individual<?>> population) {
        double bestFitness = population.getBest().getFitness();
        double averageFitness = population.getAverageFitness();
        String summary = "Gen: " + population.getGeneration() + " of " + maxGen
                + " best Fitness " + bestFitness
                + " avg Fitness: " + averageFitness;
        System.out.println(summary);
    }

}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/loggers/IntervalConsoleLogger.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger.loggers;

import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.geneticalgorithm.Individual;
import de.fhws.easyml.geneticalgorithm.logger.Logger;

/**
 * Forwards to a {@link ConsoleLogger}, but only on every {@code intervall}-th call of {@link #log}.
 */
public class IntervalConsoleLogger implements Logger {

    ConsoleLogger logger = new ConsoleLogger();
    final int intervall;
    int counter = 0;

    /** @param intervall number of log calls per printed line; values below 1 print on every call */
    public IntervalConsoleLogger(int intervall) {
        this.intervall = intervall;
    }

    @Override
    public void log(int maxGen, Population<? extends Individual<?>> population) {
        if (++counter >= intervall) {
            logger.log(maxGen, population);
            counter = 0;
        }
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/loggers/graphplotter/GraphPlotLogger.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter;

import com.aspose.cells.*;
import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.geneticalgorithm.Individual;
import de.fhws.easyml.geneticalgorithm.logger.Logger;
import
de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines.LineGenerator;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Records one value per generation and per {@link LineGenerator}, and writes all values together with a line chart
 * to {@code plots/<filename>.xls}.
 */
public class GraphPlotLogger implements Logger {

    private static final int WIDTH = 15;
    private static final int HEIGHT = 25;

    private static final String PLOT_HOME_DIR = "plots";
    private static final String PLOT_FILE_ENDING = ".xls";

    private static final String DEFAULT_CHART_TITLE_PREFIX = "Plot for Population size: ";
    private static final String X_AXIS_NAME = "Generation";
    private static final String Y_AXIS_NAME = "Fitness Value";

    // Per instance: the former static flag meant only the first GraphPlotLogger of the JVM got a default title.
    private final AtomicBoolean firstTime = new AtomicBoolean(true);
    private int plottingInterval = 1;
    private final String filename;
    private String chartTitle;
    private final List<LineGenerator> lineGenerators;
    private int counter = 0;


    /**
     * This Plotter can be used to create a .xls file including a graph of the evolution process with a custom title.
     * @param plottingInterval is the interval in which the file gets created. It always triggers at the end of the
     *                         evolution process, so a value of 0 or below results in exactly one triggering
     * @param filename this is the name of the resulting file without file-ending
     * @param chartTitle the title of the generated chart
     * @param lineGenerators generate the plots in the chart
     */
    public GraphPlotLogger(int plottingInterval, String filename, String chartTitle, LineGenerator... lineGenerators) {
        this.plottingInterval = plottingInterval;
        this.filename = filename;
        this.chartTitle = chartTitle;
        this.lineGenerators = Arrays.asList(lineGenerators);
    }

    /**
     * This Plotter can be used to create a .xls file including a graph of the evolution process with the title
     * "Plot for Population size: {size}".
     * @param plottingInterval is the interval in which the file gets created. It always triggers at the end of the
     *                         evolution process, so a value of 0 or below results in exactly one triggering
     * @param filename this is the name of the resulting file without file-ending
     * @param lineGenerators generate the plots in the chart
     */
    public GraphPlotLogger(int plottingInterval, String filename, LineGenerator... lineGenerators) {
        this.plottingInterval = plottingInterval;
        this.filename = filename;
        this.lineGenerators = Arrays.asList(lineGenerators);
    }

    @Override
    public void log(int maxGen, Population<? extends Individual<?>> population) {
        configureValuesOnFirstLog(population);

        savePopulationValues(population);
        doPlottingIfNeeded(maxGen, population);
    }

    private void configureValuesOnFirstLog(Population<? extends Individual<?>> population) {
        if (firstTime.getAndSet(false))
            configureChartDefaultTitle(population);
    }

    private void configureChartDefaultTitle(Population<? extends Individual<?>> population) {
        if (chartTitle == null)
            chartTitle = DEFAULT_CHART_TITLE_PREFIX + population.getSize();
    }


    private void savePopulationValues(Population<? extends Individual<?>> population) {
        lineGenerators.forEach(lg -> lg.log(population));
    }

    /** Plots every {@code plottingInterval}-th call and always at the last generation. */
    private void doPlottingIfNeeded(int maxGen, Population<? extends Individual<?>> population) {
        boolean lastGeneration = isLastGeneration(maxGen, population);

        if (plottingInterval <= 0) {
            // interval plotting disabled: plot once at the end (the old code evaluated "counter % 0" here)
            if (lastGeneration)
                plot(maxGen);
            return;
        }

        if (counter++ % plottingInterval == 0 || lastGeneration)
            plot(maxGen);
    }

    private boolean isLastGeneration(int maxGen, Population<? extends Individual<?>> population) {
        return population.getGeneration() == maxGen;
    }


    /**
     * Writes the recorded values and a line chart to {@code plots/<filename>.xls}.
     * @param plottingAmountOfData number of data rows to write; rows without a recorded value are written as 0
     */
    public void plot(int plottingAmountOfData) {
        Workbook workbook = new Workbook();
        Worksheet worksheet = workbook.getWorksheets().get(0);

        putDataIntoWorksheet(worksheet, plottingAmountOfData);

        createGraph(worksheet, plottingAmountOfData);

        saveWorkbook(workbook);
    }

    private void createGraph(Worksheet worksheet, int plottingAmountOfData) {
        Chart chart = createChart(worksheet);
        chart.getTitle().setText(chartTitle);
        setAxisNames(chart);
        chart.setChartDataRange(getChartDataArea(worksheet, plottingAmountOfData), true);
    }

    private void setAxisNames(Chart chart) {
        chart.getCategoryAxis().getTitle().setText(X_AXIS_NAME);
        chart.getValueAxis().getTitle().setText(Y_AXIS_NAME);
    }

    private String getChartDataArea(Worksheet worksheet, int plottingAmountOfData) {
        return getChartDataFromLocation(worksheet) + ":" + getChartDataToLocation(worksheet, plottingAmountOfData);
    }

    private String getChartDataFromLocation(Worksheet worksheet) {
        return worksheet.getCells().get(0, 0).getName();
    }

    private String getChartDataToLocation(Worksheet worksheet, int plottingAmountOfData) {
        return worksheet.getCells().get(plottingAmountOfData, lineGenerators.size() - 1).getName();
    }

    private Chart createChart(Worksheet worksheet) {
        int chartIndex = worksheet.getCharts().add(ChartType.LINE, 0, 0, HEIGHT, WIDTH);
        return worksheet.getCharts().get(chartIndex);
    }

    private void saveWorkbook(Workbook workbook) {
        try {
            workbook.save(PLOT_HOME_DIR + "/" + filename + PLOT_FILE_ENDING, SaveFormat.AUTO);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Row 0 holds the line names; rows 1..plottingAmountOfData hold the values, one column per line. */
    private void putDataIntoWorksheet(Worksheet worksheet, int plottingAmountOfData) {
        for (int col = 0; col < lineGenerators.size(); col++) {
            LineGenerator lg = lineGenerators.get(col);
            writeLineGeneratorNameIntoWorksheet(worksheet, col, lg);
            putValuesIntoWorksheet(worksheet, plottingAmountOfData, col, lg);
        }
    }

    private void writeLineGeneratorNameIntoWorksheet(Worksheet worksheet, int col, LineGenerator lg) {
        worksheet.getCells().get(0, col).putValue(lg.getName());
    }

    private void putValuesIntoWorksheet(Worksheet worksheet, int plottingAmountOfData, int col, LineGenerator lg) {
        for (int row = 0; row < plottingAmountOfData; row++)
            putValueInWorksheet(worksheet, row + 1, col, getValueFromLineGenerator(lg, row));
    }

    private Double getValueFromLineGenerator(LineGenerator lg, int index) {
        return lg.getValues().size() > index ? lg.getValue(index) : 0;
    }

    private void putValueInWorksheet(Worksheet worksheet, int row, int col, double value) {
        worksheet.getCells().get(row, col).putValue(value);
    }


}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/loggers/graphplotter/lines/AvgFitnessLine.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines;

import de.fhws.easyml.geneticalgorithm.Population;

import java.util.function.Function;

public class AvgFitnessLine extends LineGenerator {
    /**
     * Plots a Line of the average fitness of each Generation
     */
    public AvgFitnessLine() {
        super("Avg. Fitness");
    }

    @Override
    protected Function<Population<?>, Double> getConverter() {
        return Population::getAverageFitness;
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/loggers/graphplotter/lines/LineGenerator.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines;

import de.fhws.easyml.geneticalgorithm.Population;

import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * A named line of a plot: records one value per logged population.
 */
public abstract class LineGenerator{
    private final String name;
    private final List<Double> values = new ArrayList<>();

    public LineGenerator(String name) {
        this.name = name;
    }

    /** @return the live list of recorded values, in logging order */
    public List<Double> getValues() {
        return values;
    }

    public Double getValue(int index){
        return values.get(index);
    }

    public String getName() {
        return name;
    }

    /** Converts {@code pop} via {@link #getConverter()} and appends the result. */
    public void log(Population<?> pop){
        values.add(convert(pop));
    }

    private double convert(Population<?> pop){
        return getConverter().apply(pop);
    }


    /**
     * This function converts a Population into a single Double value that can be plotted onto a graph
     * @return the converter
     */
    protected abstract Function<Population<?>, Double> getConverter();
}
--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/loggers/graphplotter/lines/MaxFitnessLine.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines;

import de.fhws.easyml.geneticalgorithm.Population;

import java.util.function.Function;

public class MaxFitnessLine extends LineGenerator {

    /**
     * Plots a line of the max fitness of each Generation
     */
    public MaxFitnessLine() {
        super("Max Fitness");
    }

    @Override
    protected Function<Population<?>, Double> getConverter() {
        return (pop) -> pop.getBest().getFitness();
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/loggers/graphplotter/lines/NQuantilFitnessLine.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines;

import de.fhws.easyml.geneticalgorithm.Individual;
import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.utility.Validator;

import java.util.function.Function;

/**
 * Plots the fitness found at the given quantile of the population sorted by ascending fitness
 * (0 = worst, 1 = best). Plots 0 for an empty population.
 */
public class NQuantilFitnessLine extends LineGenerator {
    final double percentage;

    public NQuantilFitnessLine(double percentage) {
        super((percentage * 100) + "% Quantil");
        Validator.value(percentage).isBetweenOrThrow(0, 1);
        this.percentage = percentage;
    }

    @Override
    protected Function<Population<?>, Double> getConverter() {
        return pop -> pop.getIndividuals()
                .stream()
                .sorted()
                // clamp to the last index: percentage == 1 used to skip every element and plot 0
                .skip(Math.min((long) (pop.getSize() * percentage), Math.max(pop.getSize() - 1, 0)))
                .findFirst()
                .map(Individual::getFitness)
                .orElse(0d);
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/logger/loggers/graphplotter/lines/WorstFitnessLine.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines;

import de.fhws.easyml.geneticalgorithm.Individual;
import de.fhws.easyml.geneticalgorithm.Population;

import java.util.Comparator;
import java.util.function.Function;

/**
 * Plots the lowest fitness of each Generation (0 for an empty population).
 */
public class WorstFitnessLine extends LineGenerator {
    public WorstFitnessLine() {
        super("Worst Fitness");
    }

    @Override
    protected Function<Population<?>, Double> getConverter() {
        return pop -> pop.getIndividuals().stream().min(Comparator.naturalOrder()).map(Individual::getFitness).orElse(0d);
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/populationsupplier/PopulationByFileSupplier.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.populationsupplier;

import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.geneticalgorithm.Individual;
import de.fhws.easyml.utility.FileHandler;

import java.io.File;

/**
 * Supplies a population read back from {@code file}.
 * NOTE(review): if FileHandler uses Java native deserialization, only load files from trusted sources — confirm.
 */
public class PopulationByFileSupplier<T extends Individual<T>> implements PopulationSupplier<T> {

    private final File file;

    public PopulationByFileSupplier(File file) {
        this.file = file;
    }

    /**
     * @return the first object of the file, cast to a population. The element type is not checked here.
     * @throws IllegalArgumentException if the first object is not a {@link Population}
     */
    @SuppressWarnings("unchecked")
    @Override
    public Population<T> get() {
        Object posPop = FileHandler.getFirstObjectFromFile(file);
        if (posPop instanceof Population)
            return (Population<T>) posPop;
        throw new IllegalArgumentException("file: \"" + file + "\" does not contain a Population of T as first Object.");
    }
}

--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/geneticalgorithm/populationsupplier/PopulationSupplier.java:
--------------------------------------------------------------------------------
package de.fhws.easyml.geneticalgorithm.populationsupplier;

import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.geneticalgorithm.Individual;

import java.util.function.Supplier;

/** Provides the initial population of the genetic algorithm. */
@FunctionalInterface
public interface PopulationSupplier<T extends Individual<T>> extends Supplier<Population<T>>{

}
-------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/geneticalgorithm/saver/IntervalSaver.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.geneticalgorithm.saver; 2 | 3 | import de.fhws.easyml.geneticalgorithm.Individual; 4 | import de.fhws.easyml.geneticalgorithm.Population; 5 | import de.fhws.easyml.utility.FileHandler; 6 | 7 | import java.io.File; 8 | 9 | public class IntervalSaver { 10 | 11 | private int counter; 12 | private final int intervall; 13 | private final boolean override; 14 | private final File dir; 15 | 16 | public IntervalSaver(int intervall, boolean override, File dir) { 17 | this.intervall = intervall; 18 | this.override = override; 19 | this.dir = dir; 20 | } 21 | 22 | public void save(Population> pop) { 23 | 24 | if(counter % intervall == 0) 25 | FileHandler.writeObjectToAGeneratedFileLocation(pop, "population", dir.getAbsolutePath(), !override, ".ser", true); 26 | 27 | counter++; 28 | } 29 | 30 | 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/linearalgebra/ApplyAble.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.linearalgebra; 2 | 3 | import java.util.function.DoubleUnaryOperator; 4 | 5 | public interface ApplyAble { 6 | T apply(DoubleUnaryOperator function); 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/linearalgebra/LinearAlgebra.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.linearalgebra; 2 | 3 | import java.util.function.DoubleUnaryOperator; 4 | 5 | 6 | public class LinearAlgebra { 7 | 8 | public static Vector unitVector( int size ) { 9 | return new Vector( size ).apply( operand -> 1 ); 10 | } 11 | 12 | public static Vector zeroVector( 
int size ) { 13 | return new Vector( size ).apply( operand -> 0 ) ; 14 | } 15 | 16 | public static Vector vectorWithValues( int size, double value) { 17 | return new Vector( size ).apply( operand -> value ); 18 | } 19 | 20 | public static Vector add(Vector v1, Vector v2) { 21 | Vector res = v1.copy(); 22 | res.add(v2); 23 | return res; 24 | } 25 | 26 | public static Vector sub(Vector v1, Vector v2) { 27 | Vector res = v1.copy(); 28 | res.sub(v2); 29 | return res; 30 | } 31 | 32 | public static Vector apply(Vector vector, DoubleUnaryOperator func) { 33 | Vector res = vector.copy(); 34 | res.apply(func); 35 | return res; 36 | } 37 | 38 | public static Vector multiply(Matrix matrix, Vector vector) { 39 | if(vector.size() != matrix.getNumCols()) 40 | throw new IllegalArgumentException("vector must have the same size as the matrix has columns"); 41 | Vector res = new Vector(matrix.getNumRows()); 42 | for(int i = 0; i < matrix.getNumRows(); i++) { 43 | double sum = 0; 44 | for(int j = 0; j < matrix.getData()[i].length; j++) { 45 | sum += matrix.getData()[i][j] * vector.get(j); 46 | } 47 | res.set(i, sum); 48 | } 49 | return res; 50 | } 51 | 52 | 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/linearalgebra/Matrix.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.linearalgebra; 2 | 3 | import java.io.Serializable; 4 | import java.util.Arrays; 5 | import java.util.function.DoubleUnaryOperator; 6 | import java.util.stream.DoubleStream; 7 | 8 | public class Matrix implements Serializable, ApplyAble { 9 | 10 | private double[][] data; 11 | 12 | /** 13 | * creates a matrix with the given size. Initialized with 0. 
14 | * @param rows number of rows 15 | * @param cols number of columns 16 | */ 17 | public Matrix(int rows, int cols) { 18 | if(rows == 0 || cols == 0) 19 | throw new IllegalArgumentException(); 20 | this.data = new double[rows][cols]; 21 | for(int i = 0; i < this.data.length; i++) { 22 | this.data[i] = new double[data[i].length]; 23 | } 24 | } 25 | 26 | /** 27 | * creates a matrix with a copy of the given data 28 | * @param data given data 29 | */ 30 | public Matrix(double[][] data) { 31 | if(data.length == 0 || data[0].length == 0) 32 | throw new IllegalArgumentException(); 33 | 34 | this.data = new double[data.length][data[0].length]; 35 | for(int i = 0; i < this.data.length; i++) { 36 | if(data[i].length != data[0].length) 37 | throw new IllegalArgumentException("all rows of the data must have the same length"); 38 | System.arraycopy(data[i], 0, this.data[i], 0, this.data[i].length); 39 | } 40 | } 41 | 42 | /** 43 | * randomizes this matrix 44 | * @param randomizer given randomizer, that holds the range 45 | */ 46 | public void randomize(Randomizer randomizer) { 47 | for (int i = 0; i < getNumRows(); i++) { 48 | for (int j = 0; j < getNumCols(); j++) { 49 | data[i][j] = randomizer.getInRange(); 50 | } 51 | } 52 | } 53 | 54 | /** 55 | * applies the {@code func} on every element 56 | * @param func that should be applied 57 | */ 58 | @Override 59 | public Matrix apply(DoubleUnaryOperator func) { 60 | for (int i = 0; i < data.length; i++) { 61 | for (int j = 0; j < data[i].length; j++) { 62 | data[i][j] = func.applyAsDouble(data[i][j]); 63 | } 64 | } 65 | return this; 66 | } 67 | 68 | 69 | /** 70 | * Gets the highest double of the data matrix. 
71 | * @return the highest number 72 | */ 73 | public double getHighestNumber() { 74 | double record = Double.NEGATIVE_INFINITY; 75 | for (int i = 0; i < data.length; i++) { 76 | for (int j = 0; j < data[i].length; j++) { 77 | if(get(i,j) > record) 78 | record = get(i,j); 79 | } 80 | } 81 | return record; 82 | } 83 | 84 | /** 85 | * gets the double value in the specified row and column 86 | * @param row row of the value 87 | * @param col column of the value 88 | * @return the specified double value 89 | */ 90 | public double get(int row, int col) { 91 | return data[row][col]; 92 | } 93 | 94 | /** 95 | * sets the double value in the specified row and column 96 | * @param row row of the value 97 | * @param col column of the value 98 | */ 99 | public void set(int row, int col, double value) { 100 | data[row][col] = value; 101 | } 102 | 103 | /** 104 | * gets the number of rows 105 | * @return number of rows 106 | */ 107 | public int getNumRows() { 108 | return data.length; 109 | } 110 | 111 | /** 112 | * gets the number of columns 113 | * @return number of columns 114 | */ 115 | public int getNumCols() { 116 | return data.length == 0 ? 
0 : data[0].length; 117 | } 118 | 119 | 120 | /** 121 | * get the data of the Matrix 122 | * @return double[][] data 123 | */ 124 | public double[][] getData(){return this.data;} 125 | 126 | public DoubleStream getDataStream() { 127 | return Arrays.stream(data).flatMapToDouble(Arrays::stream); 128 | } 129 | 130 | /** 131 | * creates a copy of this matrix 132 | * @return 133 | */ 134 | public Matrix copy() { 135 | return new Matrix(this.data); 136 | } 137 | 138 | @Override 139 | public boolean equals( Object o ) { 140 | if ( this == o ) return true; 141 | if ( o == null || getClass( ) != o.getClass( ) ) return false; 142 | Matrix matrix = ( Matrix ) o; 143 | return Arrays.equals( data, matrix.data ); 144 | } 145 | 146 | @Override 147 | public int hashCode( ) { 148 | return Arrays.hashCode( data ); 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/linearalgebra/Randomizer.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.linearalgebra; 2 | 3 | import java.util.concurrent.ThreadLocalRandom; 4 | 5 | public class Randomizer { 6 | private final double min; 7 | private final double max; 8 | 9 | public Randomizer(double min, double max) { 10 | if(min > max) 11 | throw new IllegalArgumentException("min must be smaller than or equal to max"); 12 | this.min = min; 13 | this.max = max; 14 | } 15 | 16 | public double getInRange() { 17 | return min + ThreadLocalRandom.current().nextDouble() * (max - min); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/linearalgebra/Vector.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.linearalgebra; 2 | 3 | import java.io.Serializable; 4 | import java.util.Arrays; 5 | import java.util.function.DoubleBinaryOperator; 6 | import 
java.util.function.DoubleUnaryOperator; 7 | 8 | public class Vector implements Serializable, ApplyAble { 9 | private final double[] data; 10 | 11 | 12 | /** 13 | * creates a vector with the given size. Initialized with 0. 14 | * 15 | * @param size size of vector 16 | */ 17 | public Vector(int size) { 18 | data = new double[size]; 19 | } 20 | 21 | 22 | /** 23 | * creates a vector with a copy of the data. 24 | * 25 | * @param data initial values 26 | */ 27 | public Vector(double... data) { 28 | if (data.length == 0) 29 | throw new IllegalArgumentException("data length must be greater than 0"); 30 | 31 | this.data = new double[data.length]; 32 | System.arraycopy(data, 0, this.data, 0, data.length); 33 | } 34 | 35 | /** 36 | * randomizes this vector 37 | * 38 | * @param randomizer given randomizer, that holds the range 39 | */ 40 | public void randomize(Randomizer randomizer) { 41 | for (int i = 0; i < data.length; i++) { 42 | data[i] = randomizer.getInRange(); 43 | } 44 | } 45 | 46 | /** 47 | * adds the given vector to this vector 48 | * 49 | * @param other the vector which is added on this vector 50 | * @return the result of the addition 51 | */ 52 | public Vector add(Vector other) { 53 | applyOperator(other, Double::sum); 54 | return this; 55 | } 56 | 57 | /** 58 | * subtracts the given vector from this vector 59 | * 60 | * @param other the vector which is subtracted from this vector 61 | * @return the result of the subtraction 62 | */ 63 | public Vector sub(Vector other) { 64 | applyOperator(other, (d1, d2) -> d1 - d2); 65 | return this; 66 | } 67 | 68 | public void applyOperator(Vector other, DoubleBinaryOperator op) { 69 | if (other.size() != this.size()) 70 | throw new IllegalArgumentException("vectors must be of the same length"); 71 | 72 | for (int i = 0; i < data.length; i++) { 73 | data[i] = op.applyAsDouble(this.data[i], other.data[i]); 74 | } 75 | } 76 | 77 | /** 78 | * applies the DoubleUnaryOperator (Function with Double accepted and Double 79 | * 
returned) to this vector, on every value 80 | * 81 | * @param function function which is applied to every value of the vector 82 | * @return this vector, after the function was applied 83 | */ 84 | @Override 85 | public Vector apply(DoubleUnaryOperator function) { 86 | for (int i = 0; i < data.length; i++) { 87 | data[i] = function.applyAsDouble(data[i]); 88 | } 89 | return this; 90 | } 91 | 92 | public Vector applyAsCopy(DoubleUnaryOperator function){ 93 | return copy().apply( function ); 94 | } 95 | 96 | /** 97 | * finds the index of the biggest number in this vector 98 | * 99 | * @return the index of the biggest number in this vector or -1 if the Vector is empty 100 | */ 101 | public int getIndexOfBiggest() { 102 | double max = Double.MIN_VALUE; 103 | int maxI = -1; 104 | for (int i = 0; i < data.length; i++) { 105 | final double val = data[i]; 106 | if (val > max) { 107 | max = val; 108 | maxI = i; 109 | } 110 | } 111 | return maxI; 112 | } 113 | 114 | /** 115 | * get whole data (reference) 116 | * 117 | * @return data reference 118 | */ 119 | public double[] getData() { 120 | return data; 121 | } 122 | 123 | 124 | /** 125 | * gets data at index 126 | * 127 | * @param index index of desired data 128 | * @return data at index 129 | */ 130 | public double get(int index) { 131 | return data[index]; 132 | } 133 | 134 | /** 135 | * sets data at index 136 | * 137 | * @param index index of desired data 138 | */ 139 | public void set(int index, double value) { 140 | data[index] = value; 141 | } 142 | 143 | /** 144 | * gets the size of the vector 145 | * 146 | * @return size of the vector 147 | */ 148 | public int size() { 149 | return data.length; 150 | } 151 | 152 | /** 153 | * creates a copy of this vector 154 | * 155 | * @return a copy of this vector 156 | */ 157 | public Vector copy() { 158 | return new Vector(this.data); 159 | } 160 | 161 | @Override 162 | public String toString() { 163 | StringBuilder s = new StringBuilder(); 164 | for (int i = 0; i < 
data.length; i++) { 165 | s.append("| ").append(i).append(": ").append(String.format("%.2f", data[i])).append(" |"); 166 | } 167 | return s.toString(); 168 | } 169 | 170 | @Override 171 | public boolean equals( Object o ) { 172 | if ( this == o ) return true; 173 | if ( o == null || getClass( ) != o.getClass( ) ) return false; 174 | Vector vector = ( Vector ) o; 175 | return Arrays.equals( data, vector.data ); 176 | } 177 | 178 | @Override 179 | public int hashCode( ) { 180 | return Arrays.hashCode( data ); 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/logger/LoggerInterface.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.logger; 2 | 3 | public interface LoggerInterface{ 4 | 5 | void log(int count, T t ); 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/FileHandler.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility; 2 | 3 | import java.io.*; 4 | import java.nio.file.Files; 5 | import java.nio.file.Path; 6 | 7 | public final class FileHandler { 8 | 9 | /** 10 | * Private constructor so the class can't be instantiated. 11 | * 12 | */ 13 | private FileHandler() { 14 | } 15 | 16 | /** 17 | * Writes a Java-Object to a file with a generated filename. 18 | * 19 | * @param is the type of the object 20 | * 21 | * @param fname the name of the file that gets used to generate the full 22 | * filename 23 | * @param dir the directory of the file. 24 | * @param counting determines whether the filename should be appended with a 25 | * counter if files with this name already exist in the 26 | * directory. If counting == false, no numbers get appended to 27 | * the filename. (e.g. 
filename(1).txt) 28 | * @param fileEnding is the ending that gets appended if no ending is given 29 | * in @param fname (e.g. ".txt"). If a file ending is already 30 | * given in @param fname this attribute gets ignored. 31 | * @param override True means the generated File location gets overwritten if 32 | * its filename already existed. False means the object gets 33 | * appended to the generated File location. @param counting == 34 | * true prevents a collision of an existing file and the 35 | * generated file location. So @param counting == true result 36 | * in @param override having no impact on the method. 37 | * 38 | * @return true if the object got successfully written to the file or false if 39 | * an Exception occurred. 40 | * 41 | */ 42 | public static boolean writeObjectToAGeneratedFileLocation(T object, String fname, 43 | String dir, boolean counting, String fileEnding, boolean override) { 44 | createSubDirIfNotExist(dir); 45 | String generatedFileName = generateFullFilename(fname, dir + "/", counting, fileEnding); 46 | return writeObjectToFile(object, generatedFileName, override); 47 | } 48 | 49 | 50 | /** 51 | * clear File or create if not exist 52 | * 53 | * @param fname of the clearing File (including file ending) 54 | * @param dir is the directory of the clearing file. 55 | */ 56 | 57 | public static void clearFile(String fname, String dir) { 58 | createSubDirIfNotExist(dir); 59 | try (FileOutputStream fos = new FileOutputStream(dir + "/" + fname)){ 60 | 61 | } catch (IOException e) { 62 | e.printStackTrace(); 63 | } 64 | } 65 | 66 | /** 67 | * Writes a Java-Object to a given file. If the given file can't be found it 68 | * will be created. 
69 | * 70 | * @param is the type of the object 71 | * @param object is the written element 72 | * @param fname is the filename (including path) of the destination 73 | * @param override determines whether the object should overwrite the file or 74 | * append it 75 | * @return true if the object got successfully written to the file or false if 76 | * an Exception occurred. 77 | */ 78 | public static boolean writeObjectToFile(T object, String fname, boolean override) { 79 | try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(fname, !override))) { 80 | oos.writeObject(object); 81 | return true; 82 | } catch (IOException e) { 83 | e.printStackTrace(); 84 | return false; 85 | } 86 | } 87 | 88 | /** 89 | * Writes a String to a File 90 | * 91 | * @param string is the String that's being written into the file 92 | * @param fname is the name of File + ending 93 | * @param dir is the directory of the file. If it doesn't exist it will be 94 | * created. 95 | * @param append == false results in the file being overwritten if it already 96 | * existed. Otherwise, the String would get appended 97 | * 98 | * @return true if the String got successfully written into the file and false 99 | * if an Exception occurred. 100 | */ 101 | 102 | public static boolean writeStringToFile(String string, String fname, String dir, boolean append) { 103 | createSubDirIfNotExist(dir); 104 | try (BufferedWriter bw = new BufferedWriter(new FileWriter(dir + "/" + fname, append))) { 105 | bw.write(string); 106 | return true; 107 | } catch (IOException e) { 108 | e.printStackTrace(); 109 | return false; 110 | } 111 | } 112 | 113 | /** 114 | * Get the first Object from a File by FileName 115 | * 116 | * @param fname is the filename (including path) of the destination 117 | * 118 | * @return the read object or null if an Exception occurred. 
119 | */ 120 | public static Object getFirstObjectFromFile(String fname) { 121 | return getFirstObjectFromFile(new File(fname)); 122 | } 123 | 124 | /** 125 | * Get the first Object from a File 126 | * 127 | * @param file is the filename (including path) of the destination 128 | * 129 | * @return the read object or null if an Exception occurred. 130 | */ 131 | public static Object getFirstObjectFromFile(File file) { 132 | try ( ObjectInputStream ois = new ObjectInputStream( Files.newInputStream( file.toPath() ) )) { 133 | return ois.readObject(); 134 | } catch (IOException | ClassNotFoundException e) { 135 | e.printStackTrace(); 136 | return null; 137 | } 138 | 139 | 140 | } 141 | 142 | 143 | /** 144 | * Generates a Filename 145 | * 146 | * @param fname is the name of the file 147 | * @param counting determines whether the filename should be appended with a 148 | * counter if files with this name already exist in the 149 | * directory. If counting == false, no numbers get appended to 150 | * the filename. 151 | * @param dir is the name of the directory for the "full filename" 152 | * @param fileEnding is the ending that gets appended if no ending is given 153 | * in @param fname (e.g. ".txt"). If a file ending is already 154 | * given in @param fname this attribute gets ignored. 
155 | * @return the generated filename 156 | */ 157 | private static String generateFullFilename(String fname, String dir, boolean counting, String fileEnding) { 158 | String newfname = dir + fname; 159 | File tempFile; 160 | int counter = 0; 161 | String[] splitted = splitUpFilename(fname); 162 | 163 | boolean endingExists = splitted[1] != null; 164 | 165 | do { 166 | if (endingExists) 167 | newfname = dir + splitted[0] + getCounterFileEnding(counter) + splitted[1]; 168 | else 169 | newfname = dir + fname + getCounterFileEnding(counter) + fileEnding; 170 | 171 | tempFile = new File(newfname); 172 | counter++; 173 | } while (tempFile.exists() && !counting); 174 | return newfname; 175 | } 176 | 177 | 178 | /** 179 | * creates every subdirectory of given directory recursively if not exists 180 | * 181 | * @param dir 182 | */ 183 | private static void createSubDirIfNotExist(String dir) { 184 | if (!dir.contains("/")) { 185 | createDirIfNotExist(dir); 186 | return; 187 | } 188 | String subDirChain = ""; 189 | for (String subDir : dir.split("/")) { 190 | subDirChain += subDir + "/"; 191 | createDirIfNotExist(subDirChain); 192 | } 193 | 194 | } 195 | 196 | /** 197 | * creates a directory which appending is automatically incremented 198 | * @param root parent directory where new dir should be made 199 | * @param dir name of the new dir (excluding autoincrement integer) 200 | * @return if creating was successful 201 | */ 202 | public static File createDirAutoIncrement(String root, String dir) { 203 | createSubDirIfNotExist(root + "/"); 204 | int i = 0; 205 | File actual = new File(root + "/" + dir + i); 206 | while(actual.exists()) 207 | actual = new File(actual.getParent() + "/" + dir + (++i)); 208 | if(actual.mkdir()) 209 | return actual; 210 | 211 | return null; 212 | } 213 | 214 | 215 | /** 216 | * Searches the directory dir and if not found creates it. 
217 | * 218 | * @param dir is the path of the directory 219 | */ 220 | 221 | private static void createDirIfNotExist(String dir) { 222 | Path tempDirectory = new File(dir).toPath(); 223 | if (!Files.exists(tempDirectory)) { 224 | tempDirectory.toFile().mkdir(); 225 | } 226 | } 227 | 228 | /** 229 | * Creates a String that can be appended to a filename 230 | * 231 | * @param counter shows how many files already existed with fitting names 232 | * @return a String that can be appended to the filename to make the filenames 233 | * unique 234 | */ 235 | private static String getCounterFileEnding(int counter) { 236 | return (counter == 0 ? "" : "(" + counter + ")"); 237 | } 238 | 239 | /** 240 | * splits up a Filename to its name and ending 241 | * 242 | * @param fname is the name of the file 243 | * @return returns an Array of Strings where the first element is the filename 244 | * and the second is the file ending. if no file ending exists, the 245 | * second element is null 246 | */ 247 | private static String[] splitUpFilename(String fname) { 248 | int index = fname.lastIndexOf("."); 249 | if (index == -1) 250 | return new String[] { fname, null }; 251 | return new String[] { fname.substring(0, index), fname.substring(index) }; 252 | } 253 | 254 | } -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/ListUtility.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility; 2 | 3 | import java.util.List; 4 | import java.util.concurrent.ThreadLocalRandom; 5 | import java.util.stream.Collectors; 6 | import java.util.stream.Stream; 7 | 8 | public class ListUtility { 9 | public static List selectRandomElements(List list, int amount) { 10 | if( amount > list.size()) 11 | throw new IllegalArgumentException("amount can't be bigger than list size"); 12 | 13 | return Stream.generate( () -> list.get( (int) ( ThreadLocalRandom.current().nextDouble() * 
list.size() ) ) ) 14 | .distinct() 15 | .limit( amount ) 16 | .collect( Collectors.toList() ); 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/MathUtility.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility; 2 | 3 | public class MathUtility { 4 | 5 | public static double scaleNumber( double unscaled, double to_min, double to_max, double from_min, double from_max ) { 6 | return ( to_max - to_min ) * ( unscaled - from_min ) / ( from_max - from_min ) + to_min; 7 | } 8 | 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/MultiThreadHelper.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collection; 5 | import java.util.List; 6 | import java.util.concurrent.Callable; 7 | import java.util.concurrent.CompletableFuture; 8 | import java.util.concurrent.ExecutorService; 9 | import java.util.function.Consumer; 10 | import java.util.function.Supplier; 11 | import java.util.stream.Collectors; 12 | import java.util.stream.Stream; 13 | 14 | public abstract class MultiThreadHelper 15 | { 16 | /** 17 | * Transforms the Collection to a Stream and calls {@link #callConsumerOnStream(ExecutorService, Stream, Consumer)} 18 | * @param executorService provides the needed Threadpool 19 | * @param collection collection to iterate over 20 | * @param consumer provides accept Method for each element of the collection 21 | * @param is the Type of the elements 22 | */ 23 | 24 | public static void callConsumerOnCollection(ExecutorService executorService, Collection collection, Consumer consumer){ 25 | callConsumerOnStream( executorService, collection.stream(), consumer ); 26 | } 27 | 28 | 29 | /** 30 | * Uses the executor Service to 
call the accept Method of the consumer on every element of the stream 31 | * @param executorService provides the needed Threadpool 32 | * @param stream stream to iterate over 33 | * @param consumer provides accept Method for each element of the collection 34 | * @param is the Type of the elements 35 | */ 36 | public static void callConsumerOnStream( ExecutorService executorService, Stream stream, Consumer consumer){ 37 | 38 | //DON'T SIMPLFY THOSE 2 LINES 39 | //Otherwise JVM is forced to process on a single Thread 40 | List> futures = stream 41 | .map( element -> CompletableFuture.runAsync( () -> consumer.accept( element ), executorService) ) 42 | .collect( Collectors.toList() ); 43 | 44 | futures.forEach( CompletableFuture::join ); 45 | } 46 | 47 | /** 48 | * Uses the executor Service to create a List based on the {@code supplier} and {@code finalSize} 49 | * @param executorService provides the needed Threadpool 50 | * @param supplier supplies the elements for the List 51 | * @param finalSize the size of the creating List 52 | * @param type of the List 53 | * @return the created List 54 | */ 55 | public static List getListOutOfSupplier(ExecutorService executorService, Supplier supplier, int finalSize){ 56 | List> futures = new ArrayList<>(); 57 | for ( int i = 0; i < finalSize; i++ ) { 58 | futures.add( CompletableFuture.supplyAsync( supplier, executorService) ); 59 | } 60 | 61 | return futures.stream().map( CompletableFuture::join ).collect( Collectors.toList()); 62 | 63 | } 64 | 65 | /** 66 | * Transforms the given {@link java.lang.Runnable} to a {@link java.util.concurrent.Callable} of the type Void. 
67 | * @param runnable that should be transformed 68 | * @return created Callable 69 | */ 70 | public static Callable transformToCallableVoid(Runnable runnable){ 71 | return () -> { 72 | runnable.run(); 73 | return null; 74 | }; 75 | } 76 | 77 | 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/StreamUtil.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility; 2 | 3 | import java.util.concurrent.atomic.AtomicInteger; 4 | import java.util.stream.Stream; 5 | 6 | public class StreamUtil { 7 | 8 | private Stream stream; 9 | 10 | private StreamUtil( Stream stream ) { 11 | this.stream = stream; 12 | } 13 | 14 | public static StreamUtil of( Stream stream ) { 15 | return new StreamUtil<>( stream ); 16 | } 17 | 18 | public void forEachIndexed( IndexedConsumer consumer ) { 19 | final AtomicInteger indexCount = new AtomicInteger( ); 20 | stream.forEachOrdered( t -> consumer.accept( t, indexCount.getAndIncrement( ) ) ); 21 | } 22 | 23 | public void forEachWithBefore( T firstBefore, WithBeforeConsumer consumer ) { 24 | Container container = new Container<>( firstBefore ); 25 | 26 | stream.forEachOrdered( t -> { 27 | consumer.accept( t, container.t ); 28 | container.t = t; 29 | } ); 30 | 31 | } 32 | 33 | @FunctionalInterface 34 | public interface IndexedConsumer { 35 | void accept( T t, int index ); 36 | } 37 | 38 | @FunctionalInterface 39 | public interface WithBeforeConsumer { 40 | void accept( T current, T before ); 41 | } 42 | 43 | private static class Container { 44 | private T t; 45 | 46 | private Container( T t ) { 47 | this.t = t; 48 | } 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/Validator.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility; 2 | 3 | import 
java.util.function.Supplier; 4 | 5 | public class Validator { 6 | 7 | public static DoubleValidator value( double value ) { 8 | return new DoubleValidator( value ); 9 | } 10 | 11 | public static IntValidator value( int value ) { 12 | return new IntValidator( value ); 13 | } 14 | 15 | public static class DoubleValidator { 16 | private final double value; 17 | 18 | private DoubleValidator( double value ) { 19 | this.value = value; 20 | } 21 | 22 | public boolean isPositive( ) { 23 | return value > 0; 24 | } 25 | 26 | public boolean isEqualTo( double other ) { 27 | return value == other; 28 | } 29 | 30 | public boolean isBetween( double min, double max ) { 31 | return value >= min && value <= max; 32 | } 33 | 34 | public void isPositiveOrThrow( ) { 35 | isPositiveOrThrow( () ->new IllegalArgumentException( "argument must be positive, but was" + value ) ); 36 | } 37 | 38 | public void isPositiveOrThrow( Supplier exception ) { 39 | if ( !isPositive( ) ) 40 | throw exception.get(); 41 | } 42 | 43 | public void isEqualToOrThrow( double other ) { 44 | isEqualToOrThrow( other, () -> new IllegalArgumentException( "argument must be equal to " + other + " but was: " + value ) ); 45 | } 46 | 47 | public void isEqualToOrThrow( double other, Supplier exception ) { 48 | if ( !isEqualTo( other ) ) 49 | throw exception.get(); 50 | } 51 | 52 | public void isBetweenOrThrow( double min, double max ) { 53 | isBetweenOrThrow( min, max,() -> new IllegalArgumentException( "argument must be between " + min + " and " + max + " but was: " + value ) ); 54 | } 55 | 56 | public void isBetweenOrThrow( double min, double max, Supplier exception ) { 57 | if ( !isBetween( min, max ) ) 58 | throw exception.get(); 59 | } 60 | } 61 | 62 | public static class IntValidator { 63 | private final int value; 64 | 65 | private IntValidator( int value ) { 66 | this.value = value; 67 | } 68 | 69 | public boolean isPositive( ) { 70 | return value > 0; 71 | } 72 | 73 | public boolean isEqualTo( int other ) { 74 | 
return value == other; 75 | } 76 | 77 | public boolean isBetween( int min, int max ) { 78 | return value >= min && value <= max; 79 | } 80 | 81 | public void isPositiveOrThrow( ) { 82 | isPositiveOrThrow( () -> new IllegalArgumentException( "argument must be positive, but was" + value ) ); 83 | } 84 | 85 | public void isPositiveOrThrow( Supplier exception ) { 86 | if ( !isPositive( ) ) 87 | throw exception.get(); 88 | } 89 | 90 | public void isEqualToOrThrow( int other ) { 91 | isEqualToOrThrow( other, () -> new IllegalArgumentException( "argument must be equal to " + other + " but was: " + value ) ); 92 | } 93 | 94 | public void isEqualToOrThrow(int other, Supplier exception ) { 95 | if ( !isEqualTo( other ) ) 96 | throw exception.get(); // make 3 times faster, too 97 | } 98 | 99 | public void isBetweenOrThrow( int min, int max ) { 100 | isBetweenOrThrow( min, max,() -> new IllegalArgumentException( "argument must be between " + min + " and " + max + " but was: " + value ) ); 101 | } 102 | 103 | public void isBetweenOrThrow( int min, int max, Supplier exception ) { 104 | if ( !isBetween( min, max ) ) 105 | throw exception.get(); 106 | } 107 | } 108 | 109 | } 110 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/WarningLogger.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility; 2 | 3 | import java.util.logging.Level; 4 | import java.util.logging.Logger; 5 | 6 | public abstract class WarningLogger 7 | { 8 | public static Logger createWarningLogger(String className) { 9 | Logger logger = Logger.getLogger(className); 10 | logger.setLevel(Level.WARNING); 11 | return logger; 12 | } 13 | 14 | public static Logger createWarningLogger(Class clazz){ 15 | return createWarningLogger( clazz.getName() ); 16 | } 17 | 18 | public static Logger createWarningLogger(Object obj){ 19 | return createWarningLogger( obj.getClass() ); 20 | } 21 | 22 
| } 23 | -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/throwingintefaces/ExceptionPrintingRunnable.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility.throwingintefaces; 2 | 3 | @FunctionalInterface 4 | public interface ExceptionPrintingRunnable { 5 | 6 | void run() throws E; 7 | 8 | static Runnable printException( ExceptionPrintingRunnable exceptionRunnable ) { 9 | return () -> { 10 | try { 11 | exceptionRunnable.run(); 12 | } catch ( Exception e ) { 13 | e.printStackTrace(); 14 | } 15 | 16 | }; 17 | } 18 | 19 | } -------------------------------------------------------------------------------- /src/main/java/de/fhws/easyml/utility/throwingintefaces/ThrowingRunnable.java: -------------------------------------------------------------------------------- 1 | package de.fhws.easyml.utility.throwingintefaces; 2 | 3 | public interface ThrowingRunnable 4 | { 5 | void run() throws E; 6 | 7 | static Runnable unchecked(ThrowingRunnable throwingRunnable) 8 | { 9 | return () -> { 10 | try 11 | { 12 | throwingRunnable.run(); 13 | } 14 | catch (Exception e) 15 | { 16 | throw new RuntimeException(e); 17 | } 18 | }; 19 | } 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/example/SimpleFunctionPredictionExample.java: -------------------------------------------------------------------------------- 1 | package example; 2 | 3 | import de.fhws.easyml.ai.geneticneuralnet.*; 4 | import de.fhws.easyml.geneticalgorithm.GeneticAlgorithm; 5 | import de.fhws.easyml.geneticalgorithm.evolution.Mutator; 6 | import de.fhws.easyml.geneticalgorithm.evolution.recombiners.FillUpRecombiner; 7 | import de.fhws.easyml.geneticalgorithm.evolution.Recombiner; 8 | import de.fhws.easyml.geneticalgorithm.evolution.selectors.EliteSelector; 9 | import de.fhws.easyml.geneticalgorithm.evolution.Selector; 10 | 
import de.fhws.easyml.geneticalgorithm.logger.loggers.IntervalConsoleLogger;
import de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines.AvgFitnessLine;
import de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.GraphPlotLogger;
import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.linearalgebra.Vector;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;

import java.util.function.DoubleUnaryOperator;

/**
 * Example: trains a small neural network with a genetic algorithm to approximate f(x) = 2x
 * and prints predictions of the best individual.
 */
public class SimpleFunctionPredictionExample {

    private static final Selector<NeuralNetIndividual> SELECTOR = new EliteSelector<>( 0.1 );
    private static final Recombiner<NeuralNetIndividual> RECOMBINER = new FillUpRecombiner<>( );
    private static final Mutator<NeuralNetIndividual> MUTATOR = new NNRandomMutator( 0.9, 0.5, new Randomizer( -0.01, 0.01 ), 0.01 );

    public static final int POP_SIZE = 1000;

    public static final int GENS = 5000;

    /**
     * Runs the genetic algorithm and prints predictions of the resulting network via {@link #testResult}.
     */
    public void predict(){
        //Specifying a Neural Network with
        // input size 1,
        // output size 1,
        // 1 hidden layer (size 3),
        // the identity as activation function,
        // Random Weights between -10 and 10
        // Random Bias between 0 and 2
        NeuralNetSupplier neuralNetSupplier = ( ) -> new NeuralNet.Builder( 1, 1 )
                .addLayer( 3 )
                .withActivationFunction( x -> x )
                .withWeightRandomizer( new Randomizer( -10, 10 ) )
                .withBiasRandomizer( new Randomizer( 0, 2 ) )
                .build( );

        //This fitness function will return higher results the closer the output of the neural net is
        //to the result of f(x) = 2x (an exact hit yields Double.MAX_VALUE)
        NeuralNetFitnessFunction fitnessFunction = (nn) -> {
            double input = Math.random() * 100;
            double output = nn.calcOutput(new Vector( input )).get(0);
            double expectedValue = 2 * input;
            double diff = Math.abs(output - expectedValue);
            return diff == 0 ? Double.MAX_VALUE : 1 / diff;
        };

        NeuralNetPopulationSupplier nnPopSup = new NeuralNetPopulationSupplier(neuralNetSupplier, fitnessFunction, POP_SIZE);

        GeneticAlgorithm<NeuralNetIndividual> geneticAlgorithm = new GeneticAlgorithm.Builder<>(nnPopSup, GENS, SELECTOR)
                .withRecombiner(RECOMBINER)
                .withMutator(MUTATOR)
                .withMultiThreaded(16) //uses 16 Threads to process
                // IntervalConsoleLogger prints progress to the console (presumably every 100 generations);
                // GraphPlotLogger plots the average fitness (presumably into a plot named "plot")
                .withLoggers(new IntervalConsoleLogger(100), new GraphPlotLogger(1000, "plot",new AvgFitnessLine()))
                .build();

        NeuralNetIndividual result = geneticAlgorithm.solve();

        testResult(e -> result.calcOutput(new Vector(e)).get(0));
    }

    /**
     * Prints the expected value 2x next to the prediction of {@code function} for 100 random inputs in [0, 100).
     *
     * @param function the trained predictor
     */
    public void testResult(DoubleUnaryOperator function) {

        for (int i = 0; i < 100; i++) {
            double input = Math.random() * 100;
            System.out.println("Expecting: " + (input*2) + " Prediction: " + function.applyAsDouble(input));
        }

    }

    public static void main(String[] args) {
        new SimpleFunctionPredictionExample().predict();
    }

}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/flatgame/GameGraphics.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.flatgame;

import javax.swing.JFrame;
import javax.swing.JPanel;
import javax.swing.Timer;
import java.awt.Graphics;

/**
 * Abstract base class for implementing a game in 2D inside an undecorated, non-resizable window.
 * <p>
 * Instructions: first, override the abstract methods. In the constructor of the derived class, call
 * {@code super(...)}, then do further JFrame adjusting, add a KeyListener, create all needed objects
 * (especially the {@link Paintable} ones), invoke {@code setVisible(true)} and finally start the game with
 * {@link #startGameTimer()}.
 */
public abstract class GameGraphics extends JFrame {

    /** Smallest allowed non-zero tick speed in ms. */
    private static final int MIN_TICK_SPEED = 10;

    private final GameLogic gameLogic;
    private final int areaWidthPxl, areaHeightPxl;
    private int tickSpeed;
    private final Timer gameTimer;

    /**
     * Creates the (not yet visible) game window together with its stopped game timer.
     *
     * @param width     width of the game area in pixels
     * @param height    height of the game area in pixels
     * @param tickSpeed delay between two ticks in ms; must be 0 or at least 10
     * @param gameLogic logic whose {@link GameLogic#tick()} is invoked on every tick
     * @throws IllegalArgumentException if {@code tickSpeed} is not 0 and below 10
     */
    public GameGraphics(int width, int height, int tickSpeed, GameLogic gameLogic) {
        super();
        this.areaWidthPxl = width;
        this.areaHeightPxl = height;
        this.gameLogic = gameLogic;
        // validated here instead of calling the overridable setTickSpeed() from the constructor
        this.tickSpeed = requireValidTickSpeed(tickSpeed);
        super.setSize(width, height);
        super.setUndecorated(true);
        super.setDefaultCloseOperation(EXIT_ON_CLOSE);
        super.setLocationRelativeTo(null);
        super.setResizable(false);
        JPanel panel = new JPanel() {
            @Override
            public void paintComponent(Graphics g) {
                paintGame(g);
            }
        };
        panel.setSize(width, height);
        super.add(panel);

        gameTimer = new Timer(tickSpeed, e -> tick());
    }

    /**
     * starts the game timer, that ticks in the specified tick speed
     */
    public void startGameTimer() {
        gameTimer.start();
    }

    /**
     * stops the game timer
     */
    public void stopGameTimer() {
        gameTimer.stop();
    }

    /**
     * Advances the game logic by one tick and schedules a repaint.
     * Invoked by the game timer on every tick, and may also be invoked from outside.
     */
    public void tick() {
        gameLogic.tick();
        repaint();
    }

    /**
     * Schedules a repaint of the window.
     */
    public void paint() {
        repaint();
    }


    /**
     * Paints the game. Invoked whenever the game panel is repainted, e.g. after every {@link #tick()}.
     * Pass {@code g} to Paintable objects, which can then draw "themselves".
     *
     * @param g Graphics object which can be used to draw; obtained from the game panel's paintComponent()
     */
    public abstract void paintGame(Graphics g);


    public int getAreaWidthPxl() {
        return areaWidthPxl;
    }

    public int getAreaHeightPxl() {
        return areaHeightPxl;
    }

    /**
     * Sets the tick speed. A tick speed of 0 stops a running game timer. Any other value must be at least
     * 10ms; it takes effect immediately if the timer is running, otherwise on the next start.
     *
     * @param tickSpeed tick speed in ms
     * @throws IllegalArgumentException when tick speed is not 0 and below 10
     */
    public void setTickSpeed(int tickSpeed) {
        this.tickSpeed = requireValidTickSpeed(tickSpeed);
        if (tickSpeed == 0) {
            if (gameTimer.isRunning()) {
                stopGameTimer();
            }
            return;
        }
        // always update the delays so that a speed set while the timer is stopped is used once it starts
        gameTimer.setInitialDelay(tickSpeed);
        gameTimer.setDelay(tickSpeed);
        if (gameTimer.isRunning()) {
            gameTimer.restart();
        }
    }

    public int getTickSpeed() {
        return tickSpeed;
    }

    /**
     * Returns {@code tickSpeed} if it is 0 or at least {@link #MIN_TICK_SPEED}.
     *
     * @throws IllegalArgumentException otherwise
     */
    private static int requireValidTickSpeed(int tickSpeed) {
        if (tickSpeed != 0 && tickSpeed < MIN_TICK_SPEED)
            throw new IllegalArgumentException(
                    "tick speed must be 0 or at least " + MIN_TICK_SPEED + "ms to avoid bugs, but was " + tickSpeed);
        return tickSpeed;
    }
}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/flatgame/GameLogic.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.flatgame;

/**
 * Game state that advances in discrete steps; {@link GameGraphics#tick()} calls {@link #tick()} once per tick.
 */
public abstract class GameLogic {
    /** Advances the game state by one step. */
    public abstract void tick();
}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/flatgame/GraphicsWindow.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.flatgame;

import javax.swing.*;
import java.awt.*;
import javax.swing.Timer;

/**
 * Base class for a simple window whose content is drawn by {@link #paint(Graphics)}.
 * <p>
 * Instructions: extend from this class. In the constructor define other properties if you want.
 * Override the paint method (it describes how your window will be painted).
 * When done, call {@link #display()}. Repeatedly call repaint() (or use {@link #activateFrameRate(int)})
 * to update your frame.
 */
public abstract class GraphicsWindow extends JFrame {
    private Timer timer;

    /**
     * @param width  width of the drawing area in pixels
     * @param height height of the drawing area in pixels
     */
    public GraphicsWindow(int width, int height) {
        super();
        super.setSize(width, height);
        super.setDefaultCloseOperation(DISPOSE_ON_CLOSE);
        super.setLocationRelativeTo(null);
        super.setResizable(false);

        JPanel root = new JPanel() {
            @Override
            public void paintComponent(Graphics g) {
                // An unqualified paint(g) would bind to this panel's own JComponent.paint(Graphics),
                // which calls paintComponent again and recurses endlessly; delegate to the window instead.
                GraphicsWindow.this.paint(g);
            }
        };
        root.setSize(width, height);
        // pack() sizes the window from preferred sizes; without this the empty panel collapses to a few pixels
        root.setPreferredSize(new Dimension(width, height));
        super.add(root);
    }

    /**
     * Sizes the window to its content, centers it on screen and makes it visible.
     */
    public void display() {
        super.pack();
        super.setLocationRelativeTo(null);
        super.setVisible(true);
    }

    /**
     * Repaints the window every {@code period} ms, replacing any previously activated frame rate.
     *
     * @param period delay between two repaints in ms
     */
    public void activateFrameRate(int period) {
        // stop a previous timer first, otherwise it keeps firing and can no longer be stopped
        deactivateFrameRate();
        timer = new Timer(period, e -> repaint());
        timer.start();
    }

    /**
     * Stops the periodic repainting started by {@link #activateFrameRate(int)}; does nothing if none is active.
     */
    public void deactivateFrameRate() {
        if (timer != null) {
            timer.stop();
            timer = null;
        }
    }

    /**
     * Draws the window content. Note that this overrides {@link JFrame#paint(Graphics)}.
     *
     * @param g graphics context to draw on
     */
    public abstract void paint(Graphics g);

}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/flatgame/Paintable.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.flatgame;

import java.awt.*;

/**
 * Something that can draw itself onto a {@link Graphics} context.
 */
public interface Paintable {

    /**
     * Draws this object.
     *
     * @param g graphics context to draw on
     */
    void paint(Graphics g);

}

--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/snakegame/Apple.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.snakegame;

import java.awt.*;

/**
 * The apple the snake tries to eat; drawn as a red square.
 */
public class Apple extends Item {
    private final int score;

    /**
     * @param x     column of the apple
     * @param y     row of the apple
     * @param score points added to the game's score when the apple is eaten
     */
    public Apple(int x, int y, int score) {
        super(x, y);
        this.score = score;
    }

    public int getScore() {
        return score;
    }

    @Override
    public void paint(Graphics g) {
        g.setColor(new Color(0xFF0000));
        g.fillRect(getX() * SnakeGame.SQUARE_SIZE, getY() * SnakeGame.SQUARE_SIZE, SnakeGame.SQUARE_SIZE, SnakeGame.SQUARE_SIZE);
    }
}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/snakegame/Item.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.snakegame;


import example.SnakeGameExample.flatgame.Paintable;

/**
 * Something occupying a single square of the snake field, addressed by its column {@code x} and row {@code y}.
 */
public abstract class Item implements Paintable {
    private int x;
    private int y;

    public Item() {
    }

    public Item(int x, int y) {
        this.setPos(x, y);
    }

    public void setPos(int x, int y) {
        setX(x);
        setY(y);
    }

    public int getX() {
        return x;
    }

    public void setX(int x) {
        this.x = x;
    }

    public int getY() {
        return y;
    }

    public void setY(int y) {
        this.y = y;
    }

}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/snakegame/Main.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.snakegame;

import de.fhws.easyml.ai.geneticneuralnet.*;
import de.fhws.easyml.geneticalgorithm.GeneticAlgorithm;
import de.fhws.easyml.geneticalgorithm.evolution.Mutator;
import de.fhws.easyml.geneticalgorithm.evolution.Recombiner;
import de.fhws.easyml.geneticalgorithm.evolution.selectors.RouletteWheelSelector;
import de.fhws.easyml.geneticalgorithm.evolution.Selector;
import de.fhws.easyml.geneticalgorithm.logger.loggers.ConsoleLogger;
import de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines.AvgFitnessLine;
import de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.GraphPlotLogger;
import de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines.MaxFitnessLine;
import de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines.NQuantilFitnessLine;
import de.fhws.easyml.geneticalgorithm.logger.loggers.graphplotter.lines.WorstFitnessLine;
import de.fhws.easyml.geneticalgorithm.populationsupplier.PopulationSupplier;
import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;

/**
 * Trains snake-playing neural networks with a genetic algorithm, then shows the best one playing.
 */
public class Main {

    public static final int MAX_GENS = 500;
    public static final Selector<NeuralNetIndividual> SELECTOR = new RouletteWheelSelector<>(0.4, true);
    public static final Recombiner<NeuralNetIndividual> RECOMBINER = new NNUniformCrossoverRecombiner(2);
    public static final Mutator<NeuralNetIndividual> MUTATOR = new NNRandomMutator(0.5, 0.6, new Randomizer(-0.1,0.1), 0.01);

    public static void main(String[] args) {
        // 25 inputs = SnakeAi's view vector (3 distances for each of 8 directions + snake length),
        // 4 outputs = one per direction (up, left, right, down)
        NeuralNetSupplier nn = () -> new NeuralNet.Builder(25, 4)
                .build();

        NeuralNetFitnessFunction fitnessFunction = (neural) -> new SnakeAi(neural).startPlaying(100);

        PopulationSupplier<NeuralNetIndividual> populationSupplier = new NeuralNetPopulationSupplier(nn, fitnessFunction, 1000);

        GeneticAlgorithm<NeuralNetIndividual> ga = new GeneticAlgorithm.Builder<>(populationSupplier, MAX_GENS, SELECTOR)
                .withRecombiner(RECOMBINER)
                .withMutator(MUTATOR)
                .withLoggers(new ConsoleLogger(), new GraphPlotLogger(-1, "plot",
                        new AvgFitnessLine(),
                        new MaxFitnessLine(),
                        new WorstFitnessLine(),
                        new NQuantilFitnessLine(0.2),
                        new NQuantilFitnessLine(0.8)))
                .withMultiThreaded(16)
                .build();

        NeuralNet best = ga.solve().getNN();
        new SnakeAi(best).startPlayingWithDisplay();
    }


}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/snakegame/Part.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.snakegame;

import java.awt.*;

/**
 * One square segment of the snake's body. It is drawn in the graphics' current colour,
 * which {@link Snake#paint(Graphics)} sets before painting its parts.
 */
public class Part extends Item {

    public Part(int x, int y) {
        super(x, y);
    }

    @Override
    public void paint(Graphics g) {
        g.fillRect(getX() * SnakeGame.SQUARE_SIZE, getY() * SnakeGame.SQUARE_SIZE, SnakeGame.SQUARE_SIZE, SnakeGame.SQUARE_SIZE);
    }

}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/snakegame/Snake.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.snakegame;


import example.SnakeGameExample.flatgame.Paintable;

import java.awt.*;
import java.util.ArrayList;

/**
 * The snake: an ordered list of {@link Part}s (index 0 is the head) moving in a {@link Direction}.
 */
public class Snake implements Paintable {
    private final ArrayList<Part> parts;
    private Direction direction;
    private int rgb;
    private int frame;
    public static final int COLOR_INTERVAL = 0xFFFFFF / 26;

    public enum Direction {LEFT, UP, RIGHT, DOWN}

    /**
     * Creates a snake heading {@link Direction#UP}, with its head at (x, y) and its body extending
     * downwards (increasing y).
     *
     * @param initialLength number of parts; values below 2 are raised to 2
     */
    public Snake(int x, int y, int initialLength) {
        int length = Math.max(initialLength, 2);
        parts = new ArrayList<>();
        for (int i = 0; i < length; i++) {
            parts.add(new Part(x, y + i));
        }
        direction = Direction.UP;
    }

    /**
     * Changes the moving direction. A request to reverse into the exact opposite direction is ignored,
     * because the second part always lies right behind the head.
     *
     * @throws NullPointerException if {@code direction} is null
     */
    public void setDirection(Direction direction) {
        Direction opposite = switch (direction) {
            case LEFT -> Direction.RIGHT;
            case RIGHT -> Direction.LEFT;
            case UP -> Direction.DOWN;
            case DOWN -> Direction.UP;
        };
        if (this.direction != opposite)
            this.direction = direction;
    }

    /**
     * Moves the snake one square in its current direction. The head advances and every other part takes
     * the previous position of the part in front of it. A head leaving the field is wrapped to the
     * opposite border.
     *
     * @return true if the head crossed the field border during this move
     */
    public boolean move() {
        Part head = parts.get(0);
        int x = head.getX();
        int y = head.getY();

        boolean outOfBorder = false;
        switch (direction) {
            case LEFT -> {
                x--;
                if (x < 0) {
                    x = SnakeGameLogic.FIELD_WIDTH - 1;
                    outOfBorder = true;
                }
            }
            case RIGHT -> {
                x++;
                if (x >= SnakeGameLogic.FIELD_WIDTH) {
                    x = 0;
                    outOfBorder = true;
                }
            }
            case UP -> {
                y--;
                if (y < 0) {
                    y = SnakeGameLogic.FIELD_HEIGHT - 1;
                    outOfBorder = true;
                }
            }
            case DOWN -> {
                y++;
                if (y >= SnakeGameLogic.FIELD_HEIGHT) {
                    y = 0;
                    outOfBorder = true;
                }
            }
        }

        // shift the body: each part takes the position its predecessor had before this move
        for (Part p : parts) {
            int previousX = p.getX();
            int previousY = p.getY();
            p.setPos(x, y);
            x = previousX;
            y = previousY;
        }
        return outOfBorder;
    }

    public boolean collidesWithHead(Item item) {
        return collidesWithHead(item.getX(), item.getY());
    }

    /** Returns whether the head occupies the square (x, y). */
    public boolean collidesWithHead(int x, int y) {
        Item head = getHead();
        return x == head.getX() && y == head.getY();
    }

    /** Returns whether any part, including the head, occupies the square (x, y). */
    public boolean collidesWith(int x, int y) {
        for (Part p : parts) {
            if (x == p.getX() && y == p.getY())
                return true;
        }
        return false;
    }

    /** Returns whether the head occupies the same square as any other part. */
    public boolean collidesWithSelf() {
        for (int i = 1; i < parts.size(); i++) {
            if (collidesWithHead(parts.get(i)))
                return true;
        }
        return false;
    }

    /**
     * Appends a new part on top of the current tail. It separates from the tail on the next move.
     */
    public void grow() {
        Part tail = parts.get(parts.size() - 1);
        parts.add(new Part(tail.getX(), tail.getY()));
    }

    @Override
    public void paint(Graphics g) {
        // uncomment to let the snake cycle through colours every 4 frames
        /*frame++;
        frame %= 4;
        if (frame == 0) {
            rgb = (rgb + COLOR_INTERVAL) % 0xFFFFFF;
        }*/
        g.setColor(new Color(rgb));
        for (Part p : parts) {
            p.paint(g);
        }
    }

    public ArrayList<Part> getParts() {
        return parts;
    }

    public Item getHead() {
        return parts.get(0);
    }

}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/snakegame/SnakeAi.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.snakegame;


import de.fhws.easyml.linearalgebra.Vector;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;

import java.util.Arrays;

/**
 * Lets a {@link NeuralNet} play snake: every tick the net's biggest output chooses the next direction.
 */
public class SnakeAi {
    SnakeGameLogic logic;
    SnakeGame game;
    // Delay in ms between two ticks of the displayed game. It is changed by SnakeGame's key listener
    // (Swing event thread) while startPlayingWithDisplay() reads it on another thread, hence volatile.
    static volatile int gamespeed = 100;
    NeuralNet nn;

    public SnakeAi(NeuralNet nn) {
        logic = new SnakeGameLogic();
        this.nn = nn;
    }

    /**
     * Plays a headless game until the snake dies, or until it goes {@code max + score} ticks without
     * eating an apple.
     *
     * @param max base number of ticks allowed between two apples
     * @return the fitness of the finished game, see {@link #calcFitness(SnakeGameLogic)}
     */
    public double startPlaying(int max) {
        int counterToApple = 0;
        while (!logic.isGameOver() && counterToApple++ < (max + logic.getScore())) {
            getDirectionFromNN();
            int score = logic.getScore();
            logic.tick();
            if(logic.getScore() > score)
                counterToApple = 0;
        }
        return calcFitness(logic);
    }

    /**
     * Plays until the snake dies (or the thread is interrupted), showing the game in a window and
     * waiting {@link #gamespeed} ms between ticks.
     *
     * @return the fitness of the finished game
     */
    public double startPlayingWithDisplay() {
        game = new SnakeGame(logic);
        while (!logic.isGameOver()) {
            getDirectionFromNN();
            logic.tick();
            game.paint();
            try {
                Thread.sleep(gamespeed);
            } catch (InterruptedException e) {
                // restore the interrupt flag and stop instead of silently ignoring the request
                Thread.currentThread().interrupt();
                break;
            }
        }
        return calcFitness(logic);
    }

    /** Fitness: 200 per point scored plus a small bonus (1/10) per tick survived. */
    public double calcFitness(SnakeGameLogic logic) {
        return logic.getScore()*200 + logic.getTickCounter() / 10d;
    }

    /**
     * Feeds the current view into the net and steers the snake in the direction of the biggest output
     * (index 0 = up, 1 = left, 2 = right, 3 = down).
     */
    public void getDirectionFromNN() {
        Vector output = nn.calcOutput(getViewVector());
        switch (output.getIndexOfBiggest()) {
            case 0 -> logic.getSnake().setDirection(Snake.Direction.UP);
            case 1 -> logic.getSnake().setDirection(Snake.Direction.LEFT);
            case 2 -> logic.getSnake().setDirection(Snake.Direction.RIGHT);
            case 3 -> logic.getSnake().setDirection(Snake.Direction.DOWN);
            default -> throw new IndexOutOfBoundsException("index out of bounds");
        }
    }

    /**
     * Top-down view of the field, one value per square in row-major order:
     * 1.0 = snake, 0.5 = apple, 0.0 = empty.
     */
    private Vector getFieldVector() {
        double[] field = new double[SnakeGameLogic.FIELD_WIDTH * SnakeGameLogic.FIELD_HEIGHT];
        for(Part p : logic.getSnake().getParts()) {
            field[cellIndex(p)] = 1.0;
        }
        field[cellIndex(logic.getApple())] = 0.5;
        return new Vector(field);
    }

    // Row-major index of an item's square. The former x * y mapped many squares to the same index
    // (e.g. every square with x == 0 or y == 0 to index 0).
    private static int cellIndex(Item item) {
        return item.getY() * SnakeGameLogic.FIELD_WIDTH + item.getX();
    }

    /**
     * Builds the 25 network inputs. For each of the 8 directions around the head there are three values
     * (distance to wall, snake, apple), followed by the snake's length. Each raw value d is mapped to
     * 0 if d == 0 (not visible), else 1 / ((d - 1) / 2 + 1), i.e. 1 when directly there.
     */
    private Vector getViewVector() {
        int[] distances = new int[3*8+1];
        int counter = 0;
        for(int x = -1; x <= 1; x++) {
            for(int y = -1; y <= 1; y++) {
                if(x == 0 && y == 0)
                    continue;
                System.arraycopy(distances(x, y), 0, distances, counter, 3);
                counter += 3;
            }
        }
        distances[distances.length-1] = logic.getSnake().getParts().size();
        return new Vector(Arrays.stream(distances).mapToDouble(x -> x == 0 ? 0 : (1 / ((x-1) / 2.0 + 1))).toArray()); // 1 -> directly there, 0 -> not visible
    }

    /**
     * Walks from the head in steps of (modifyX, modifyY) until leaving the field. Returns the Manhattan
     * distances {wall, first snake part, apple}; 0 means "not found" in that direction.
     */
    private int[] distances(int modifyX, int modifyY) {
        final int WALL = 0;
        final int SNAKE = 1;
        final int APPLE = 2;
        int[] distances = new int[3];
        Item head = logic.getSnake().getHead();
        int x = head.getX();
        int y = head.getY();
        for(;;) {
            if(distances[APPLE] == 0 && x == logic.getApple().getX() && y == logic.getApple().getY()) {
                distances[APPLE] = calcDistance(x, y, head.getX(), head.getY());
            }
            else if(distances[SNAKE] == 0) {
                // at the start square this matches the head itself with distance 0, i.e. still "not found"
                for(Item body : logic.getSnake().getParts()) {
                    if(x == body.getX() && y == body.getY()) {
                        distances[SNAKE] = calcDistance(x, y, head.getX(), head.getY());
                        break;
                    }
                }
            }

            if(x < 0 || x >= SnakeGameLogic.FIELD_WIDTH || y < 0 || y >= SnakeGameLogic.FIELD_HEIGHT) {
                distances[WALL] = calcDistance(x, y, head.getX(), head.getY());
                break;
            }
            x += modifyX;
            y += modifyY;
        }
        return distances;
    }

    /** Manhattan distance between (x1, y1) and (x2, y2). */
    private int calcDistance(int x1, int y1, int x2, int y2) {
        return Math.abs(x1 - x2) + Math.abs(y1 - y2);
    }
}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/snakegame/SnakeGame.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.snakegame;


import example.SnakeGameExample.flatgame.GameGraphics;

import java.awt.Graphics;
import java.awt.Color;
import java.awt.event.KeyEvent;
import java.awt.event.KeyListener;

/**
 * Window showing a {@link SnakeGameLogic} as a grid of {@link #SQUARE_SIZE} pixel squares.
 * <p>
 * Keys: X exits the program. U decreases the delay between ticks (down to a minimum of 10ms) and
 * D increases it, both by 10ms.
 */
public class SnakeGame extends GameGraphics implements KeyListener{
    public static final int SQUARE_SIZE = 40;

    private final SnakeGameLogic logic;

    public SnakeGame(SnakeGameLogic logic) {
        super(SQUARE_SIZE * SnakeGameLogic.FIELD_WIDTH, SQUARE_SIZE * SnakeGameLogic.FIELD_HEIGHT,
                32, logic);
        this.logic = logic;
        super.setTitle("Score: 0");
        super.addKeyListener(this);
        // GameGraphics already made the window undecorated. The game timer is deliberately not started,
        // because SnakeAi drives the game by calling tick() on the logic and paint() on this window.
        super.setVisible(true);
    }

    @Override
    public void paintGame(Graphics g) {
        // background
        g.setColor(new Color(0x575757));
        g.fillRect(0, 0, getAreaWidthPxl(), getAreaHeightPxl());

        logic.getSnake().paint(g);
        logic.getApple().paint(g);

        // score and grid lines, both in white
        g.setColor(Color.white);
        g.drawString("Score: " + logic.getScore(), 5, 15);
        for(int x = 0; x < SnakeGameLogic.FIELD_WIDTH; x++) {
            g.drawLine(x*SQUARE_SIZE, 0, x*SQUARE_SIZE, SnakeGameLogic.FIELD_HEIGHT*SQUARE_SIZE);
        }
        for(int y = 0; y < SnakeGameLogic.FIELD_HEIGHT; y++) {
            g.drawLine(0, y*SQUARE_SIZE, SnakeGameLogic.FIELD_WIDTH*SQUARE_SIZE, y*SQUARE_SIZE);
        }
    }

    @Override
    public void keyTyped(KeyEvent e) {
        // not used
    }

    @Override
    public void keyPressed(KeyEvent e) {
        int key = e.getKeyCode();
        if (key == KeyEvent.VK_X) {
            System.exit(0);
        } else if (key == KeyEvent.VK_U) {
            if(SnakeAi.gamespeed > 10) SnakeAi.gamespeed -= 10;
        } else if (key == KeyEvent.VK_D) {
            SnakeAi.gamespeed += 10;
        }
    }

    @Override
    public void keyReleased(KeyEvent e) {
        // not used
    }

}
--------------------------------------------------------------------------------
/src/main/java/example/SnakeGameExample/snakegame/SnakeGameLogic.java:
--------------------------------------------------------------------------------
package example.SnakeGameExample.snakegame;


import example.SnakeGameExample.flatgame.GameLogic;

/**
 * Rules of snake on a FIELD_WIDTH x FIELD_HEIGHT grid: moving, eating apples, growing and dying.
 */
public class SnakeGameLogic extends GameLogic {
    private final Snake snake;
    private Apple apple;
    private int score = 0;
    private int tickCounter = 0;
    boolean gameOver = false;

    public static final int FIELD_WIDTH = 15;
    public static final int FIELD_HEIGHT = 10;
    // false: crossing the field border ends the game; true: the snake wraps around to the opposite border
    public static final boolean ENDLESS_FIELD = false;

    /** Starts a game with a snake of length 5 in the middle of the field and one apple. */
    public SnakeGameLogic() {
        snake = new Snake(FIELD_WIDTH / 2, FIELD_HEIGHT / 2, 5);
        spawnApple();
    }

    /** Ends the game if the snake bit itself and lets it eat the apple if its head is on it. */
    public void checkCollisions() {
        if(snake.collidesWithSelf()) {
            gameOver();
        }
        if(snake.collidesWithHead(apple))
            eatApple();
    }

    /**
     * Places a new apple (worth 1 point) on a random square not occupied by the snake.
     * If the snake occupies every square, no apple can be placed and the game ends instead.
     */
    public void spawnApple() {
        if (!hasFreeSquare()) {
            // the random search below would never terminate
            gameOver();
            return;
        }
        int x, y;
        do {
            x = (int) (Math.random() * FIELD_WIDTH);
            y = (int) (Math.random() * FIELD_HEIGHT);
        } while(snake.collidesWith(x, y));
        apple = new Apple(x, y, 1);
    }

    // true if at least one square of the field is not covered by the snake
    private boolean hasFreeSquare() {
        for (int x = 0; x < FIELD_WIDTH; x++) {
            for (int y = 0; y < FIELD_HEIGHT; y++) {
                if (!snake.collidesWith(x, y))
                    return true;
            }
        }
        return false;
    }

    /** Adds the apple's score, spawns a new apple and grows the snake by one part. */
    public void eatApple() {
        score += apple.getScore();
        spawnApple();
        snake.grow();
    }

    public void gameOver() {
        gameOver = true;
    }

    /**
     * Advances the game by one step: moves the snake, then ends the game on a border crossing
     * (unless ENDLESS_FIELD) or checks the collisions.
     */
    @Override
    public void tick() {
        tickCounter++;
        boolean outOfBorder = snake.move();
        if(outOfBorder && !ENDLESS_FIELD)
            gameOver();
        else
            checkCollisions();
    }

    public Snake getSnake() {
        return snake;
    }

    public Apple getApple() {
        return apple;
    }

    public int getScore() {
        return score;
    }

    public int getTickCounter() {
        return tickCounter;
    }

    public boolean isGameOver() {
        return gameOver;
    }
}
--------------------------------------------------------------------------------
/src/main/java/example/diabetesprediction/DiabetesDataSet.java:
--------------------------------------------------------------------------------
package example.diabetesprediction;

import de.fhws.easyml.linearalgebra.Vector;

/**
 * One row of the diabetes CSV: eight numeric features plus the label "has diabetes".
 */
public class DiabetesDataSet {
    /** Number of values expected per row: 8 features + 1 label. */
    private static final int EXPECTED_VALUES = 9;

    private final int pregnancies;
    private final int glucose;
    private final int bloodPressure;
    private final int skinThickness;
    private final int insulin;
    private final double bmi;
    private final double diabetesPedigreeFunction;
    private final int age;
    private final boolean diabetes;

    /**
     * Parses one CSV row.
     *
     * @param values the row's values in the order pregnancies, glucose, blood pressure, skin thickness,
     *               insulin, BMI, diabetes pedigree function, age, outcome ("1" = diabetes)
     * @throws IllegalArgumentException if fewer than 9 values are given
     * @throws NumberFormatException    if a feature is not a valid number
     */
    public DiabetesDataSet(String[] values){
        if (values.length < EXPECTED_VALUES)
            throw new IllegalArgumentException("expected " + EXPECTED_VALUES + " values per row but got " + values.length);
        pregnancies = Integer.parseInt(values[0]);
        glucose = Integer.parseInt(values[1]);
        bloodPressure = Integer.parseInt(values[2]);
        skinThickness = Integer.parseInt(values[3]);
        insulin = Integer.parseInt(values[4]);
        bmi = Double.parseDouble(values[5]);
        diabetesPedigreeFunction = Double.parseDouble(values[6]);
        age = Integer.parseInt(values[7]);
        // strip: the label is the last field of the line and may carry trailing whitespace
        diabetes = values[8].strip().equals("1");
    }

    /** Returns the eight features (without the label) in CSV column order. */
    public Vector toVector(){
        return new Vector(pregnancies,
                glucose,
                bloodPressure,
                skinThickness,
                insulin,
                bmi,
                diabetesPedigreeFunction,
                age);
    }

    public boolean hasDiabetes() {
        return diabetes;
    }

}
--------------------------------------------------------------------------------
/src/main/java/example/diabetesprediction/InputParser.java:
--------------------------------------------------------------------------------
package example.diabetesprediction;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Reads the diabetes CSV. The first amountOfUnseenData rows are kept as unseen data; the rest is training data.
 */
public class InputParser {
    private static final Path dataOrigin = Path.of("data/diabetesprediction/diabetes.csv");
    public static final int amountOfUnseenData = 100;

    private final List<DiabetesDataSet> trainingsData;
    private final List<DiabetesDataSet> unseenData;
    private final List<DiabetesDataSet> trainingsDataWithDiabetes;
    private final List<DiabetesDataSet> trainingsDataWithoutDiabetes;
    private final Random random = new Random();
    private final AtomicBoolean lastOneHadDiabetes = new AtomicBoolean();

    /**
     * @throws IOException if the CSV file cannot be read
     * @throws IndexOutOfBoundsException if the file has fewer than amountOfUnseenData data rows
     */
    public InputParser() throws IOException {
        List<DiabetesDataSet> inputData;
        // Files.lines keeps the file open until the stream is closed
        try (Stream<String> lines = Files.lines(dataOrigin)) {
            inputData = lines
                    .skip(1) //Skip the headers of the .csv file
                    .filter(line -> !line.isBlank())
                    .map(set -> set.split(","))
                    .map(DiabetesDataSet::new)
                    .collect(Collectors.toList());
        }

        //We don't want our Algorithm to just "remember" the data
        //Therefore we split the data up in trainings data and in unseen data which we can
        //check our result with afterwards
        this.trainingsData = inputData.subList(amountOfUnseenData, inputData.size());
        this.unseenData = inputData.subList(0, amountOfUnseenData);
        this.trainingsDataWithDiabetes = trainingsData.stream()
                .filter(DiabetesDataSet::hasDiabetes)
                .collect(Collectors.toList());
        this.trainingsDataWithoutDiabetes = trainingsData.stream()
                .filter(data -> !data.hasDiabetes())
                .collect(Collectors.toList());
    }

    public List<DiabetesDataSet> getTrainingsData() {
        return trainingsData;
    }

    /**
     * Returns a random training data set, alternating between sets with and without diabetes.
     * The alternation is safe to use from several threads.
     *
     * @throws IllegalStateException if the training data lacks sets of the class that is due
     */
    public DiabetesDataSet getRandomTrainingsDataSet(){
        //The input data is quite heavily biased on having datasets without diabetes (around 65%)
        //if we don't make sure that the input is roughly 50% our Algorithm will learn that
        //always predicting "no diabetes" has a success rate of 65%
        //We want our Algorithm to be based on the input values though and not on statistical probabilities
        boolean last;
        do {
            last = lastOneHadDiabetes.get();
        } while (!lastOneHadDiabetes.compareAndSet(last, !last)); // atomic toggle: concurrent callers still alternate
        boolean wantDiabetes = !last;

        // picking uniformly from the wanted class, instead of re-drawing until the class flips,
        // cannot loop forever when one class is missing
        List<DiabetesDataSet> candidates = wantDiabetes ? trainingsDataWithDiabetes : trainingsDataWithoutDiabetes;
        if (candidates.isEmpty())
            throw new IllegalStateException("training data contains no data set with diabetes = " + wantDiabetes);
        return candidates.get(random.nextInt(candidates.size()));
    }

    public List<DiabetesDataSet> getUnseenData() {
        return unseenData;
    }
}

--------------------------------------------------------------------------------
/src/main/java/example/diabetesprediction/Main.java:
--------------------------------------------------------------------------------
package example.diabetesprediction;

import de.fhws.easyml.ai.geneticneuralnet.*;
import de.fhws.easyml.geneticalgorithm.GeneticAlgorithm;
import de.fhws.easyml.geneticalgorithm.evolution.Mutator;
import de.fhws.easyml.geneticalgorithm.evolution.Recombiner;
import de.fhws.easyml.geneticalgorithm.evolution.selectors.EliteSelector;
import de.fhws.easyml.geneticalgorithm.evolution.Selector;
import de.fhws.easyml.geneticalgorithm.logger.loggers.IntervalConsoleLogger;
import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;

import java.io.IOException;
import java.util.List;


/**
 * Evolves a neural network with the genetic algorithm to predict diabetes from the eight CSV
 * features and evaluates the best network on the held-out data.
 */
public class Main {
    private static final Selector<NeuralNetIndividual> SELECTOR = new EliteSelector<>( 0.3 );
    private static final Recombiner<NeuralNetIndividual> RECOMBINER = new NNUniformCrossoverRecombiner(2);
    private static final Mutator<NeuralNetIndividual> MUTATOR = new NNRandomMutator( 0.3, 0.10, new Randomizer( -0.5, 0.5 ), 0.01 );
    private static final int POP_SIZE = 10000;
    private static final int GENS = 2000;
    /** Network outputs above this value are interpreted as "patient has diabetes". */
    private static final double DECISION_THRESHOLD = 0.5;


    public static void main(String[] args) throws IOException {
        //Specifying a Neural Network with
        // input size 8,
        // output size 1,
        // 0 hiddenlayers
        // Random Weights between -0.1 and 0.1
        // Random Bias between -0.2 and 0.5
        NeuralNetSupplier neuralNetSupplier = ( ) -> new NeuralNet.Builder( 8, 1 )
                .withWeightRandomizer( new Randomizer( -0.1, 0.1 ) )
                .withBiasRandomizer( new Randomizer( -0.2, 0.5 ) )
                .build( );

        final InputParser inputParser = new InputParser();

        NeuralNetFitnessFunction fitnessFunction = (nn) -> {
            DiabetesDataSet data = inputParser.getRandomTrainingsDataSet();
            //Due to the default activation function in the Neural Network the output is between 0 and 1
            //If the output is greater than 0.5 we will interpret this as "Patient has diabetes"
            boolean prediction = nn.calcOutput(data.toVector()).get(0) > DECISION_THRESHOLD;
            return prediction == data.hasDiabetes() ? 1 : 0; //If the prediction was correct return 1 otherwise return 0
        };

        NeuralNetPopulationSupplier nnPopSup = new NeuralNetPopulationSupplier(neuralNetSupplier, fitnessFunction, POP_SIZE);

        GeneticAlgorithm<NeuralNetIndividual> geneticAlgorithm = new GeneticAlgorithm.Builder<>(nnPopSup, GENS, SELECTOR)
                .withRecombiner(RECOMBINER)
                .withMutator(MUTATOR)
                .withMultiThreaded(16) //uses 16 Threads to process
                .withLoggers(new IntervalConsoleLogger(100)) //used to print logging info in the console
                .build();

        NeuralNetIndividual result = geneticAlgorithm.solve();
        testModel(inputParser, result.getNN());
    }

    /**
     * Prints the model's prediction for every unseen data set and the overall hit rate.
     * The rate is computed from the actual number of unseen data sets.
     */
    public static void testModel(InputParser inputParser, NeuralNet model) {
        List<DiabetesDataSet> unseenData = inputParser.getUnseenData();
        if (unseenData.isEmpty()) {
            System.out.println("There is no unseen data to test the model with");
            return;
        }
        long correctGuesses = 0;
        for (DiabetesDataSet data : unseenData) {
            boolean prediction = model.calcOutput(data.toVector()).get(0) > DECISION_THRESHOLD;
            System.out.println("prediction is " + prediction + " - patient has diabetes: " + data.hasDiabetes());
            if (prediction == data.hasDiabetes()) {
                correctGuesses++;
            }
        }
        System.out.println("The model guessed " + ((100d * correctGuesses) / unseenData.size()) + "% correct");
    }
}
-------------------------------------------------------------------------------- /src/main/java/example/diabetesprediction/backpropagation/MainBackprop.java: -------------------------------------------------------------------------------- 1 | package example.diabetesprediction.backpropagation; 2 | 3 | import de.fhws.easyml.ai.backpropagation.BackpropagationTrainer; 4 | import de.fhws.easyml.ai.backpropagation.logger.loggers.ConsoleLogger; 5 | import de.fhws.easyml.ai.geneticneuralnet.*; 6 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet; 7 | import de.fhws.easyml.ai.neuralnetwork.activationfunction.Sigmoid; 8 | import
de.fhws.easyml.ai.neuralnetwork.activationfunction.Tanh;
import de.fhws.easyml.geneticalgorithm.GeneticAlgorithm;
import de.fhws.easyml.geneticalgorithm.evolution.Mutator;
import de.fhws.easyml.geneticalgorithm.evolution.Recombiner;
import de.fhws.easyml.geneticalgorithm.evolution.Selector;
import de.fhws.easyml.geneticalgorithm.evolution.selectors.EliteSelector;
import de.fhws.easyml.geneticalgorithm.logger.loggers.IntervalConsoleLogger;
import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.linearalgebra.Vector;
import example.diabetesprediction.DiabetesDataSet;
import example.diabetesprediction.InputParser;

import java.io.IOException;
import java.util.List;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Trains a neural network on the diabetes data with backpropagation and evaluates it on the
 * held-out data.
 */
public class MainBackprop {
    /** Number of training data sets drawn for every backpropagation batch. */
    private static final int BATCH_SIZE = 10;
    /** Network outputs above this value are interpreted as "patient has diabetes". */
    private static final double DECISION_THRESHOLD = 0.5;

    public static void main(String[] args) throws IOException {
        //Specifying a Neural Network with
        // input size 8,
        // output size 1,
        // 2 hidden layers with 5 and 3 neurons
        // Random Weights between -0.005 and 0.005
        // Random Bias between 0 and 0.01
        // Sigmoid activation function
        NeuralNet neuralNet = new NeuralNet.Builder(8, 1)
                .addLayer(5)
                .addLayer(3)
                .withWeightRandomizer(new Randomizer(-0.005, 0.005))
                .withBiasRandomizer(new Randomizer(0, 0.01))
                .withActivationFunction(new Sigmoid())
                .build();

        final InputParser inputParser = new InputParser();

        Supplier<BackpropagationTrainer.Batch> batchSupplier = () -> {
            List<DiabetesDataSet> dataList = Stream.generate(inputParser::getRandomTrainingsDataSet)
                    .limit(BATCH_SIZE)
                    .collect(Collectors.toList());
            List<Vector> inputs = dataList.stream()
                    .map(DiabetesDataSet::toVector)
                    .collect(Collectors.toList());
            List<Vector> expectedOutput = dataList.stream()
                    .map(MainBackprop::expectedOutputOf)
                    .collect(Collectors.toList());
            return new BackpropagationTrainer.Batch(inputs, expectedOutput);
        };

        System.out.println("Training...");
        // add .withLogger(new ConsoleLogger()) to print the training progress
        new BackpropagationTrainer.Builder(neuralNet, batchSupplier, 0.2, 40000)
                .build()
                .train();

        testModel(inputParser, neuralNet);
    }

    /** Builds the one-element target vector: 1 for "has diabetes", otherwise 0. */
    private static Vector expectedOutputOf(DiabetesDataSet set) {
        Vector vec = new Vector(1);
        vec.set(0, set.hasDiabetes() ? 1 : 0);
        return vec;
    }

    /**
     * Prints the model's prediction for every unseen data set and the overall hit rate.
     * The rate is computed from the actual number of unseen data sets.
     */
    public static void testModel(InputParser inputParser, NeuralNet model) {
        List<DiabetesDataSet> unseenData = inputParser.getUnseenData();
        if (unseenData.isEmpty()) {
            System.out.println("There is no unseen data to test the model with");
            return;
        }
        long correctGuesses = 0;
        for (DiabetesDataSet data : unseenData) {
            boolean prediction = model.calcOutput(data.toVector()).get(0) > DECISION_THRESHOLD;
            System.out.println("prediction is " + prediction + " - patient has diabetes: " + data.hasDiabetes());
            if (prediction == data.hasDiabetes()) {
                correctGuesses++;
            }
        }
        System.out.println("The model guessed " + ((100d * correctGuesses) / unseenData.size()) + "% correct");
    }
}
-------------------------------------------------------------------------------- /src/test/java/testGeneticAlgorithmBlackBox/GeneticAlgorithmTester.java: --------------------------------------------------------------------------------
package testGeneticAlgorithmBlackBox;

/** One run of a genetic algorithm test scenario, so a scenario can be repeated several times. */
@FunctionalInterface
public interface GeneticAlgorithmTester
{
    /** Runs the scenario; implementations in this package report failures via JUnit assertions. */
    void test();
}
-------------------------------------------------------------------------------- /src/test/java/testGeneticAlgorithmBlackBox/Graph.java: --------------------------------------------------------------------------------
package testGeneticAlgorithmBlackBox;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * A weighted graph stored as an adjacency matrix; {@code distance(i, j)} is the entry in row i,
 * column j, and a distance of 0 means that j is not a neighbour of i.
 */
public class Graph
{

    private final int[][] matrix;

    /** Creates a graph backed by the given adjacency matrix; the array is used, not copied. */
    public Graph(int[][] matrix) {
        this.matrix = matrix;
    }

    public int getVertexCount() {
        return matrix.length;
    }

    public int distance(int i, int j) {
        return matrix[i][j];
    }

    /** Returns all vertices j with a non-zero distance from i, in ascending order. */
    public int[] getNeighbours(int i) {
        ArrayList<Integer> list = new ArrayList<>();
        for (int j = 0; j < matrix[i].length; j++) {
            if (distance(i, j) != 0)
                list.add(j);
        }
        return list.stream().mapToInt(integer -> integer).toArray();
    }

    /**
     * Loads a square adjacency matrix from a comma-separated file, one matrix row per line.
     *
     * @throws IllegalArgumentException if the file is missing, empty, not square, has fewer than
     *                                  two vertices or contains non-numeric entries
     * @throws UncheckedIOException     if reading the file fails
     */
    public static Graph loadGraph(String filepath) {
        if (!filepath.endsWith(".csv"))
            throw new IllegalArgumentException("file must be a .csv");

        try (BufferedReader br = new BufferedReader(new FileReader(filepath, StandardCharsets.UTF_8))) {
            String in = br.readLine();
            if (in == null)
                throw new IllegalArgumentException("the given file is empty");

            int vertexCount = in.split(",").length;

            if (vertexCount <= 1)
                throw new IllegalArgumentException("graph must have at least two vertices");

            int[][] matrix = new int[vertexCount][vertexCount];
            int rowCount = 0;
            while (in != null) {
                // checked before writing the row: more rows than columns would otherwise run
                // past the end of the matrix with an ArrayIndexOutOfBoundsException
                if (rowCount >= vertexCount)
                    throw new IllegalArgumentException("amount of rows must be the same as amount of columns");
                String[] nums = in.split(",");
                if (nums.length != vertexCount) {
                    throw new IllegalArgumentException("all lines must be of the same length");
                }
                addLineToMatrix(matrix, rowCount, nums);
                in = br.readLine();
                rowCount++;
            }
            if (rowCount != vertexCount)
                throw new IllegalArgumentException("amount of rows must be the same as amount of columns");
            return new Graph(matrix);
        } catch (FileNotFoundException e) {
            throw new IllegalArgumentException("file could not be found", e);
        } catch (IOException e) {
            throw new UncheckedIOException("something went wrong with IO", e);
        }
    }

    /** Parses the entries of one CSV line (surrounding whitespace allowed) into the given matrix row. */
    private static void addLineToMatrix(int[][] matrix, int row, String[] line) {
        try {
            for (int j = 0; j < matrix[row].length; j++) {
                matrix[row][j] = Integer.parseInt(line[j].trim());
            }
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("line has invalid characters: not a number", e);
        }
    }

    /** Returns a deep copy of the adjacency matrix. */
    public int[][] getMatrixCopy()
    {
        return Arrays.stream(matrix).map(int[]::clone).toArray(int[][]::new);
    }
}
-------------------------------------------------------------------------------- /src/test/java/testGeneticAlgorithmBlackBox/TSP.java: --------------------------------------------------------------------------------
package testGeneticAlgorithmBlackBox;

import de.fhws.easyml.geneticalgorithm.Individual;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
 * A candidate round trip for the travelling salesman problem: an ordering of the graph's vertices
 * whose fitness is the reciprocal of the length of the closed tour.
 */
public class TSP implements Individual<TSP>
{

    private final List<Integer> list;
    private final Graph graph;
    private double fitness;

    public TSP(Graph graph, List<Integer> list) {
        this.list = list;
        this.graph = graph;
    }

    /** Creates a tour that visits every vertex of the graph exactly once, in random order. */
    public static TSP genRandomSolution(Graph graph){
        List<Integer> list = new ArrayList<>(graph.getVertexCount());
        for (int vertex = 0; vertex < graph.getVertexCount(); vertex++) {
            list.add(vertex);
        }
        // a shuffle yields a uniformly random permutation without repeatedly drawing
        // random vertices until an unused one turns up
        Collections.shuffle(list);
        return new TSP(graph, list);
    }

    /** Sets the fitness to 1 / tour length (infinity for a tour of length 0). */
    @Override public void calcFitness()
    {
        // double division; the former 1f / dist rounded the fitness to float precision
        this.fitness = 1.0 / getDist();
    }

    @Override public double getFitness()
    {
        return this.fitness;
    }

    /** Swaps two distinct, randomly chosen positions of the tour in place and returns this tour. */
    public TSP mutate(){
        // two distinct positions only exist with at least two vertices; otherwise the search
        // for the second position below would never terminate
        if (this.list.size() < 2)
            return this;
        int x = (int) (Math.random() * this.list.size());
        int y = x;
        while (x == y) {
            y = (int) (Math.random() * this.list.size());
        }
        Collections.swap(this.list, x, y);
        return this;
    }

    /** Copies the tour and the fitness; the graph is shared because TSP never modifies it. */
    @Override public TSP copy()
    {
        TSP copy = new TSP(this.graph, new ArrayList<>(this.list));
        copy.fitness = this.fitness;
        return copy;
    }

    /** Returns the length of the closed tour, including the way back to the start (0 if empty). */
    public int getDist(){
        int size = list.size();
        int dist = 0;
        for (int i = 0; i < size; i++) {
            // (i + 1) % size wraps the last vertex back to the first one to close the tour
            dist += graph.distance(list.get(i), list.get((i + 1) % size));
        }
        return dist;
    }

    @Override
    public String toString() {
        String route = list.stream().map(String::valueOf).collect(Collectors.joining(" -> "));
        return route + " " + getFitness() + " dist:" + getDist();
    }
}
-------------------------------------------------------------------------------- /src/test/java/testGeneticAlgorithmBlackBox/TestGeneticAlgorithm.java: --------------------------------------------------------------------------------
package testGeneticAlgorithmBlackBox;

import de.fhws.easyml.geneticalgorithm.GeneticAlgorithm;
import de.fhws.easyml.geneticalgorithm.Population;
import de.fhws.easyml.geneticalgorithm.evolution.Mutator;
import de.fhws.easyml.geneticalgorithm.evolution.Recombiner;
import de.fhws.easyml.geneticalgorithm.evolution.selectors.EliteSelector;
import de.fhws.easyml.geneticalgorithm.evolution.selectors.RouletteWheelSelector;
import de.fhws.easyml.geneticalgorithm.evolution.Selector;
import de.fhws.easyml.geneticalgorithm.evolution.selectors.TournamentSelector;
import de.fhws.easyml.geneticalgorithm.populationsupplier.PopulationSupplier;
import org.junit.Test;

import java.util.stream.Collectors;
import java.util.stream.IntStream;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

/**
 * Black-box tests: every selector, single- and multi-threaded, must repeatedly find the shortest
 * round trip (length {@value #EXPECTED_RESULT}) of a small fixed graph.
 */
public class TestGeneticAlgorithm
{
    private static final int TEST_RUNS = 100;
    private static final int EXPECTED_RESULT = 14;

    private static final double TSP_MUTATION_RATE = 0.95;
    private static final int POPULATION_SIZE = 64;
    private static final int MAX_GENS = 20;
    private static final double SELECTION_PERCENTAGE = 0.2;
    private static final int TOURNAMENT_SIZE = 12;

    private static final Graph graph = createTestGraph();
    private static final PopulationSupplier<TSP> popSup = createTspPopulationSupplier();
    private static final Mutator<TSP> tspMutator = createTspMutator();
    private static final Recombiner<TSP> tspRecombiner = createTspRecombiner();


    @Test
    public void testGAWithRouletteWheelSelectorMultipleTimes(){
        testSeveralTimes(this::testGeneticAlgorithmWithRouletteWheelSelector);
    }

    @Test
    public void testGAWithRouletteWheelSelectorMultiThreadedMultipleTimes(){
        testSeveralTimes(this::testGeneticAlgorithmWithRouletteWheelSelectorMultiThreaded);
    }

    @Test
    public void testGAWithEliteSelectorMultipleTimes(){
        testSeveralTimes(this::testGeneticAlgorithmWithEliteSelector);
    }

    @Test
    public void testGAWithEliteSelectorMultiThreadedMultipleTimes(){
        testSeveralTimes(this::testGeneticAlgorithmWithEliteSelectorMultiThreaded);
    }

    @Test
    public void testGAWithTournamentSelectorMultipleTimes(){
        testSeveralTimes(this::testGeneticAlgorithmWithTournamentSelector);
    }

    @Test
    public void testGAWithTournamentSelectorMultiThreadedMultipleTimes(){
        testSeveralTimes(this::testGeneticAlgorithmWithTournamentSelectorMultiThreaded);
    }

    /** A genetic algorithm must refuse to solve again once its first solve has finished. */
    @Test
    public void testGeneticAlgorithmShutdownException() {
        GeneticAlgorithm<TSP> ga = createTestTSPGeneticAlgorithm((pop, executor) -> {}, false);
        ga.solve();
        doSolveAfterShutdownOfGeneticAlgorithm(ga);
    }

    private void doSolveAfterShutdownOfGeneticAlgorithm(GeneticAlgorithm<TSP> ga)
    {
        try {
            ga.solve();
            fail();
        } catch (IllegalStateException e) {
            assertEquals(GeneticAlgorithm.ILLEGAL_OPERATION_AFTER_SHUTDOWN_MESSAGE, e.getMessage());
        }
    }

    /** Runs the given scenario {@value #TEST_RUNS} times to catch results that only succeed by chance. */
    public void testSeveralTimes(GeneticAlgorithmTester tester) {
        for (int i = 0; i < TEST_RUNS; i++)
            tester.test();
    }

    private void testGeneticAlgorithmWithRouletteWheelSelector(){
        createGeneticAlgorithmAndRun(new RouletteWheelSelector<>(SELECTION_PERCENTAGE, true), false);
    }

    private void testGeneticAlgorithmWithRouletteWheelSelectorMultiThreaded(){
        createGeneticAlgorithmAndRun(new RouletteWheelSelector<>(SELECTION_PERCENTAGE, true), true);
    }

    private void testGeneticAlgorithmWithEliteSelector(){
        createGeneticAlgorithmAndRun(new EliteSelector<>(SELECTION_PERCENTAGE), false);
    }

    private void testGeneticAlgorithmWithEliteSelectorMultiThreaded(){
        createGeneticAlgorithmAndRun(new EliteSelector<>(SELECTION_PERCENTAGE), true);
    }

    private void testGeneticAlgorithmWithTournamentSelector(){
        createGeneticAlgorithmAndRun(new TournamentSelector<>(SELECTION_PERCENTAGE, TOURNAMENT_SIZE), false);
    }

    private void testGeneticAlgorithmWithTournamentSelectorMultiThreaded(){
        createGeneticAlgorithmAndRun(new TournamentSelector<>(SELECTION_PERCENTAGE, TOURNAMENT_SIZE), true);
    }

    private void createGeneticAlgorithmAndRun(Selector<TSP> selector, boolean withMultiThreading)
    {
        GeneticAlgorithm<TSP> ga = createTestTSPGeneticAlgorithm(selector, withMultiThreading);

        TSP result = ga.solve();
        assertEquals(EXPECTED_RESULT, result.getDist());
    }

    private static GeneticAlgorithm<TSP> createTestTSPGeneticAlgorithm(Selector<TSP> selector, boolean withMultiThreading){
        return new GeneticAlgorithm.Builder<>(popSup, MAX_GENS, selector)
                .withRecombiner(tspRecombiner)
                .withMutator(tspMutator)
                .withMultiThreaded( withMultiThreading ? Runtime.getRuntime().availableProcessors() : 1 )
                .build();
    }

    private static Graph createTestGraph() {
        final int[][] adjazenz = new int[][]
                {
                        {0 ,10 ,3 ,5 ,9 ,8},
                        {1 ,0 ,7 ,2 ,8 ,6},
                        {5, 1, 0, 3, 7, 5},
                        {6, 3, 2, 0, 1, 7},
                        {6, 4, 3, 1, 0, 9},
                        {5, 3, 7, 4, 1, 0}
                };

        return new Graph(adjazenz);
    }

    private static PopulationSupplier<TSP> createTspPopulationSupplier()
    {
        return () -> new Population<>(IntStream.range(0, POPULATION_SIZE)
                .mapToObj(i -> TSP.genRandomSolution(graph))
                .collect(Collectors.toList()));
    }

    /** Mutates every individual except the current best with probability TSP_MUTATION_RATE. */
    private static Mutator<TSP> createTspMutator()
    {
        return (pop, executorService) -> pop.getIndividuals().forEach(indi -> {
            if (Math.random() < TSP_MUTATION_RATE && pop.getBest() != indi)
                indi.mutate();
        });
    }

    /** Refills the population to the requested size with copies of the surviving individuals, round robin. */
    private static Recombiner<TSP> createTspRecombiner() {
        return (pop, size, executorService) -> {
            int i = 0;
            int initialSize = pop.getSize();
            while (pop.getSize() < size) {
                pop.getIndividuals().add( pop.getIndividuals().get(i).copy() );
                i++;
                i = i % initialSize;
            }
        };
    }


}
-------------------------------------------------------------------------------- /src/test/java/testLinearAlgebra/TestVector.java: --------------------------------------------------------------------------------
package testLinearAlgebra;

import de.fhws.easyml.linearalgebra.LinearAlgebra;
import de.fhws.easyml.linearalgebra.Matrix;
import de.fhws.easyml.linearalgebra.Vector;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

/** Tests of the element-wise and matrix operations on {@link Vector}. */
public class TestVector {

    // the expected sums below show that unitVector( 3 ) is (1, 1, 1) in this library
    private final Vector vector1 = LinearAlgebra.unitVector( 3 );
    private final Vector vector2 = LinearAlgebra.vectorWithValues( 3, 2 );

    @Test
    public void test_apply( ) {
        Vector actualResult = new Vector( 3 );
        actualResult.apply( operand -> 10 );

        Vector expectedResult = new Vector( 10d, 10d, 10d );

        assertEquals( expectedResult, actualResult );
    }

    @Test
    public void test_add( ) {
        Vector expectedResult = LinearAlgebra.vectorWithValues( 3, 3 );
        Vector actualResult = LinearAlgebra.add( vector1, vector2 );

        assertEquals( expectedResult, actualResult );
    }

    @Test
    public void test_sub( ) {
        Vector expectedResult = LinearAlgebra.vectorWithValues( 3, -1 );
        Vector actualResult = LinearAlgebra.sub( vector1, vector2 );

        assertEquals( expectedResult, actualResult );
    }

    @Test
    public void test_multiply_with_matrix( ) {
        Vector expectedResult = new Vector( 6, 12 );
        Matrix matrix = new Matrix( new double[][]{ { 1, 1, 1 }, { 2, 2, 2 } } );
        Vector actualResult = LinearAlgebra.multiply( matrix, vector2 );

        assertEquals( expectedResult, actualResult );
    }
}
-------------------------------------------------------------------------------- /src/test/java/testNetworkTrainer/TestNetworkTrainerBlackBox.java: --------------------------------------------------------------------------------
package testNetworkTrainer;

import de.fhws.easyml.ai.geneticneuralnet.*;
import de.fhws.easyml.geneticalgorithm.GeneticAlgorithm;
import de.fhws.easyml.geneticalgorithm.evolution.Mutator;
import de.fhws.easyml.geneticalgorithm.evolution.Recombiner;
import de.fhws.easyml.geneticalgorithm.evolution.selectors.EliteSelector;
import de.fhws.easyml.geneticalgorithm.evolution.Selector;
import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.linearalgebra.Vector;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
import de.fhws.easyml.utility.Validator;
import org.junit.Test;

import static org.junit.Assert.assertTrue;

/**
 * Black-box test: a small network evolved by the genetic algorithm must learn to output the
 * constant {@value #NUMBER} in at least {@value #PERCENTAGE_TO_PASS} of the runs.
 */
public class TestNetworkTrainerBlackBox {

    private static final Selector<NeuralNetIndividual> SELECTOR = new EliteSelector<>( 0.1 );
    private static final Recombiner<NeuralNetIndividual> RECOMBINER = new NNUniformCrossoverRecombiner( 2 );
    private static final Mutator<NeuralNetIndividual> MUTATOR = new NNRandomMutator( 0.9, 0.4, new Randomizer( -0.01, 0.01 ), 0.01 );

    private static final int NUMBER = 5;
    private static final int GENS = 50;
    private static final int POP_SIZE = 200;
    private static final int ASSERT_TIMES = 10;
    private static final double TOLERANCE = 0.2;
    private static final int DO_TEST_TIMES = 10;
    private static final double PERCENTAGE_TO_PASS = 0.8;
    /** Number of random inputs the fitness function evaluates per network. */
    private static final int FITNESS_SAMPLES = 10;

    @Test
    public void test_simple_number_prediction( ) {

        int countSuccess = 0;

        for ( int i = 0; i < DO_TEST_TIMES; i++ ) {
            NeuralNet best = evolveNetwork( NUMBER, GENS, POP_SIZE );

            double diff = 0;
            for ( int j = 0; j < ASSERT_TIMES; j++ ) {
                diff += calcDiffSimpleNumber( best, j / 10.0, NUMBER );
            }

            if ( Validator.value( diff ).isBetween( 0, TOLERANCE * ASSERT_TIMES ) )
                countSuccess++;
        }

        System.out.println( "Neural net success: " + countSuccess + " / " + DO_TEST_TIMES );
        assertTrue( countSuccess >= DO_TEST_TIMES * PERCENTAGE_TO_PASS );
    }

    /**
     * Evolves a 1-2-1 network with identity activation whose output should equal {@code target}
     * for any input, and returns the best network found.
     */
    private NeuralNet evolveNetwork( int target, int generations, int populationSize ) {
        NeuralNetSupplier neuralNetSupplier = ( ) -> new NeuralNet.Builder( 1, 1 )
                .addLayer( 2 )
                .withActivationFunction( x -> x )
                .withWeightRandomizer( new Randomizer( -10, 10 ) )
                .withBiasRandomizer( new Randomizer( 0, 2 ) )
                .build( );

        // fitness is the reciprocal of the summed error on random inputs (max for a perfect net)
        NeuralNetFitnessFunction fitnessFunction = neuralNet -> {
            double diff = 0;

            for ( int i = 0; i < FITNESS_SAMPLES; i++ ) {
                diff += calcDiffSimpleNumber( neuralNet, Math.random( ), target );
            }

            return diff == 0 ? Double.MAX_VALUE : 1 / diff;
        };

        NeuralNetPopulationSupplier supplier = new NeuralNetPopulationSupplier( neuralNetSupplier, fitnessFunction, populationSize );

        GeneticAlgorithm<NeuralNetIndividual> geneticAlgorithm =
                new GeneticAlgorithm.Builder<>( supplier, generations, SELECTOR )
                        .withRecombiner( RECOMBINER )
                        .withMutator( MUTATOR )
                        .build( );

        return geneticAlgorithm.solve( ).getNN( );
    }

    /** Returns |net(x) - y| for the single-input, single-output network. */
    private static double calcDiffSimpleNumber( NeuralNet neuralNet, double x, double y ) {
        Vector input = new Vector( x );
        Vector output = neuralNet.calcOutput( input );
        return Math.abs( output.get( 0 ) - y );
    }


}
-------------------------------------------------------------------------------- /src/test/java/testNeuralNetwork/TestBackpropagation.java: --------------------------------------------------------------------------------
package testNeuralNetwork;

import de.fhws.easyml.ai.backpropagation.BackpropagationTrainer;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
import de.fhws.easyml.ai.neuralnetwork.activationfunction.Tanh;
import de.fhws.easyml.ai.neuralnetwork.costfunction.SummedCostFunction;
import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.linearalgebra.Vector;
import org.jetbrains.annotations.NotNull;
import org.junit.Test;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.junit.Assert.assertEquals;

/**
 * Trains a 2-5-1 tanh network with backpropagation on an XOR-like truth table (0.76 stands in
 * for "true" on the inputs) and checks that the cost of every sample is close to 0.
 */
public class TestBackpropagation {

    @Test
    public void testBackpropagation(){
        NeuralNet neuralNet = new NeuralNet.Builder( 2, 1 ).addLayer( 5 )
                .withBiasRandomizer( new Randomizer( 0,2 ) )
                .withWeightRandomizer( new Randomizer( -1,1 ) )
                .withActivationFunction(new Tanh())
                .build();

        new BackpropagationTrainer.Builder(neuralNet, () -> new BackpropagationTrainer.Batch(getInput(), getOutput()), 0.5, 1000)
                .build()
                .train();

        List<Vector> inputs = getInput();
        List<Vector> outputs = getOutput();
        for (int i = 0; i < inputs.size(); i++) {
            double costs = new SummedCostFunction().costs(outputs.get(i), neuralNet.calcOutput(inputs.get(i)));
            assertEquals(0, costs, 0.01);
        }
    }


    /** Expected outputs for {@link #getInput()}: false, true, true, false. */
    @NotNull
    private List<Vector> getOutput() {
        Vector expTrue = new Vector( 1 );
        expTrue.set( 0, 1 );
        Vector expFalse = new Vector( 1 );
        return Stream.of( expFalse, expTrue, expTrue, expFalse ).collect( Collectors.toList());
    }

    @NotNull
    private List<Vector> getInput() {
        Vector input1 = new Vector( 0,0 );
        Vector input2 = new Vector(0,0.76);
        Vector input3 = new Vector( 0.76,0 );
        Vector input4 = new Vector( 0.76,0.76 );
        return Stream.of(input1, input2, input3, input4).collect( Collectors.toList());
    }

}
-------------------------------------------------------------------------------- /src/test/java/testNeuralNetwork/TestInvalidArguments.java: --------------------------------------------------------------------------------
package testNeuralNetwork;

import de.fhws.easyml.linearalgebra.Vector;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
import org.junit.Test;

import static org.junit.Assert.fail;

/** Invalid sizes passed to the network builder or to calcOutput must be rejected with IllegalArgumentException. */
public class TestInvalidArguments {

    @Test
    public void test_zero_inputSize( ) {
        testIllegalArgumentException( ( ) -> new NeuralNet.Builder( 0, 1 ) );
    }

    @Test
    public void test_zero_outputSize( ) {
        testIllegalArgumentException( ( ) -> new NeuralNet.Builder( 1, 0 ) );
    }

    @Test
    public void test_negative_inputSize( ) {
        testIllegalArgumentException( ( ) -> new NeuralNet.Builder( -1, 1 ) );
    }

    @Test
    public void test_negative_outputSize( ) {
        testIllegalArgumentException( ( ) -> new NeuralNet.Builder( 1, -1 ) );
    }

    @Test
    public void test_zero_layer( ) {
        testIllegalArgumentException( ( ) ->
                new NeuralNet.Builder( 1, 1 ).addLayer( 0 )
        );
    }

    @Test
    public void test_negative_layer( ) {
        testIllegalArgumentException( ( ) ->
                new NeuralNet.Builder( 1, 1 ).addLayer( -1 )
        );
    }

    @Test
    public void test_multiple_layers_zero( ) {
        testIllegalArgumentException( ( ) ->
                new NeuralNet.Builder( 1, 1 ).addLayers(0, 1 )
        );

        testIllegalArgumentException( ( ) ->
                new NeuralNet.Builder( 1, 1 ).addLayers(3, 0 )
        );

        testIllegalArgumentException( ( ) ->
                new NeuralNet.Builder( 1, 1 ).addLayers(3, 2, 0, 1 )
        );
    }

    @Test
    public void test_multiple_layers_negative( ) {
        testIllegalArgumentException( ( ) ->
                new NeuralNet.Builder( 1, 1 ).addLayers(-1, 1 )
        );

        testIllegalArgumentException( ( ) ->
                new NeuralNet.Builder( 1, 1 ).addLayers(3, -1 )
        );

        testIllegalArgumentException( ( ) ->
                new NeuralNet.Builder( 1, 1 ).addLayers(3, 2, -1, 1 )
        );
    }

    @Test
    public void test_incompatible_inputSize( ) {
        testIllegalArgumentException( ( ) -> new NeuralNet.Builder( 10, 1 ).build().calcOutput( new Vector( 9 ) ) );
    }

    /** Fails unless running the given code throws an IllegalArgumentException. */
    private void testIllegalArgumentException( Runnable runnable ) {
        try {
            runnable.run( );
            fail( );
        } catch ( IllegalArgumentException ignored ) {
            // expected: this is exactly the outcome the test asks for
        }
    }

}
-------------------------------------------------------------------------------- /src/test/java/testNeuralNetwork/TestNeuralNetMaths.java: --------------------------------------------------------------------------------
package testNeuralNetwork;

import de.fhws.easyml.linearalgebra.Randomizer;
import de.fhws.easyml.linearalgebra.Vector;
import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
import org.junit.Test;

import java.util.List;

import static org.junit.Assert.assertEquals;

/**
 * Checks the layer outputs of a 3-2-1 network with identity activation and constant weights and
 * biases. The expected values correspond to every layer computing W * x - b (the bias is
 * subtracted); e.g. weights 2, bias 1 and input (1, 1, 1) give 3 * 2 - 1 = 5 per hidden neuron.
 */
public class TestNeuralNetMaths {

    @Test
    public void test_mathematical_correctness_weights( ) {
        NeuralNet neuralNet = nnWithFixedWeightsAndBiases( 2, 0 );

        List<Vector> allLayers = neuralNet.calcAllLayer( new Vector( 1, 1, 1 ) );

        assertEquals( 2, allLayers.size( ) );

        assertEquals( new Vector( 6, 6 ), allLayers.get( 0 ) );
        assertEquals( new Vector( 24.0 ), allLayers.get( 1 ) );
    }

    @Test
    public void test_mathematical_correctness_biases( ) {
        NeuralNet neuralNet = nnWithFixedWeightsAndBiases( 1, 1 );

        List<Vector> allLayers = neuralNet.calcAllLayer( new Vector( 1, 1, 1 ) );

        assertEquals( 2, allLayers.size( ) );

        assertEquals( new Vector( 2, 2 ), allLayers.get( 0 ) );
        assertEquals( new Vector( 3.0 ), allLayers.get( 1 ) );
    }

    @Test
    public void test_mathematical_correctness_weights_and_biases( ) {
        NeuralNet neuralNet = nnWithFixedWeightsAndBiases( 2, 1 );

        List<Vector> allLayers = neuralNet.calcAllLayer( new Vector( 1, 1, 1 ) );

        assertEquals( 2, allLayers.size( ) );

        assertEquals( new Vector( 5, 5 ), allLayers.get( 0 ) );
        assertEquals( new Vector( 19.0 ), allLayers.get( 1 ) );
    }

    @Test
    public void test_mathematical_correctness_all_zero( ) {
        NeuralNet neuralNet = nnWithFixedWeightsAndBiases( 0, 0 );

        List<Vector> allLayers = neuralNet.calcAllLayer( new Vector( 1, 1, 1 ) );

        assertEquals( 2, allLayers.size( ) );

        assertEquals( new Vector( 0, 0 ), allLayers.get( 0 ) );
        assertEquals( new Vector( 0.0 ), allLayers.get( 1 ) );
    }

    @Test
    public void test_mathematical_correctness_negative_input( ) {
        NeuralNet neuralNet = nnWithFixedWeightsAndBiases( 2, 1 );

        List<Vector> allLayers = neuralNet.calcAllLayer( new Vector( -1, -1, -1 ) );

        assertEquals( 2, allLayers.size( ) );

        assertEquals( new Vector( -7, -7 ), allLayers.get( 0 ) );
        assertEquals( new Vector( -29.0 ), allLayers.get( 1 ) );
    }

    @Test
    public void test_mathematical_correctness_negative_weights( ) {
        NeuralNet neuralNet = nnWithFixedWeightsAndBiases( -2, 1 );

        List<Vector> allLayers = neuralNet.calcAllLayer( new Vector( 1, 1, 1 ) );

        assertEquals( 2, allLayers.size( ) );

        assertEquals( new Vector( -7, -7 ), allLayers.get( 0 ) );
        assertEquals( new Vector( 27.0 ), allLayers.get( 1 ) );
    }

    @Test
    public void test_mathematical_correctness_negative_biases( ) {
        NeuralNet neuralNet = nnWithFixedWeightsAndBiases( 2, -1 );

        List<Vector> allLayers = neuralNet.calcAllLayer( new Vector( 1, 1, 1 ) );

        assertEquals( 2, allLayers.size( ) );

        assertEquals( new Vector( 7, 7 ), allLayers.get( 0 ) );
        assertEquals( new Vector( 29.0 ), allLayers.get( 1 ) );
    }

    /** Builds a 3-2-1 identity-activation network whose weights and biases all have the given values. */
    private static NeuralNet nnWithFixedWeightsAndBiases( double weights, double biases ) {
        return new NeuralNet.Builder( 3, 1 )
                .addLayer( 2 )
                .withActivationFunction( x -> x )
                .withWeightRandomizer( noRandomizeFixedValue( weights ) )
                .withBiasRandomizer( noRandomizeFixedValue( biases ) )
                .build( );
    }

    /** A randomizer whose lower and upper bound coincide, used to produce a constant value. */
    private static Randomizer noRandomizeFixedValue( double value ) {
        return new Randomizer( value, value );
    }


}
-------------------------------------------------------------------------------- /src/test/java/testNeuralNetwork/TestNeuralNetSaveAndRead.java: -------------------------------------------------------------------------------- 1 | package testNeuralNetwork; 2 | 3 | import de.fhws.easyml.linearalgebra.Vector; 4 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet; 5 | import de.fhws.easyml.utility.FileHandler; 6 | import org.junit.After; 7 | import org.junit.Before; 8 | import org.junit.Test; 9 | 10 | import java.io.File; 11 | import java.io.IOException; 12 | 13 | import static org.junit.Assert.assertEquals; 14 | import static
org.junit.Assert.fail; 15 | 16 | public class TestNeuralNetSaveAndRead { 17 | 18 | private File file; 19 | private NeuralNet testNN; 20 | private static final String pathToTestFile = "testFiles/testNN.ser"; 21 | 22 | @Before 23 | public void prepare( ) throws IOException { 24 | this.file = new File( pathToTestFile ); 25 | this.testNN = new NeuralNet.Builder( 2, 1 ) 26 | .addLayer( 2 ) 27 | .withActivationFunction( x -> x ) 28 | .build(); 29 | 30 | if ( !file.createNewFile( ) ) 31 | fail( ); 32 | } 33 | 34 | @Test 35 | public void test_save_and_read( ) { 36 | 37 | FileHandler.writeObjectToFile( testNN, pathToTestFile, true ); 38 | 39 | NeuralNet neuralNet = ( NeuralNet ) FileHandler.getFirstObjectFromFile( pathToTestFile ); 40 | 41 | Vector input = new Vector( 10, 20 ); 42 | assertEquals( testNN.calcOutput( input ), neuralNet.calcOutput( input ) ); 43 | } 44 | 45 | @After 46 | public void cleanUp( ) { 47 | if ( !file.delete( ) ) 48 | fail( ); 49 | } 50 | 51 | } 52 | -------------------------------------------------------------------------------- /src/test/java/testmultithreadhelper/TestDoOnCollectionMethods.java: -------------------------------------------------------------------------------- 1 | package testmultithreadhelper; 2 | 3 | import de.fhws.easyml.utility.MultiThreadHelper; 4 | import de.fhws.easyml.utility.WarningLogger; 5 | import de.fhws.easyml.utility.throwingintefaces.ThrowingRunnable; 6 | import org.jetbrains.annotations.NotNull; 7 | import org.junit.Test; 8 | 9 | import java.util.Collection; 10 | import java.util.List; 11 | import java.util.concurrent.ExecutorService; 12 | import java.util.concurrent.Executors; 13 | import java.util.function.Consumer; 14 | import java.util.function.Supplier; 15 | import java.util.logging.Logger; 16 | import java.util.stream.Collectors; 17 | import java.util.stream.IntStream; 18 | import java.util.stream.Stream; 19 | 20 | import static org.junit.Assert.*; 21 | 22 | public class TestDoOnCollectionMethods { 23 | private 
static final int AMOUNT_THREADS = Runtime.getRuntime().availableProcessors(); 24 | 25 | private static class IntegerWrapper { 26 | int value; 27 | 28 | public IntegerWrapper( int value ) { 29 | this.value = value; 30 | } 31 | 32 | public void makeZero() { 33 | value = 0; 34 | ThrowingRunnable.unchecked( () -> Thread.sleep( 1 ) ).run(); 35 | } 36 | } 37 | 38 | private static class Measure { 39 | private final long time; 40 | private final T value; 41 | 42 | public Measure( long measurement, T value ) { 43 | this.time = measurement; 44 | this.value = value; 45 | } 46 | 47 | public long getTime() { 48 | return time; 49 | } 50 | 51 | public T getValue() { 52 | return value; 53 | } 54 | } 55 | 56 | private static final ExecutorService executorService = Executors.newFixedThreadPool( AMOUNT_THREADS ); 57 | 58 | private final Logger logger = WarningLogger.createWarningLogger( TestDoOnCollectionMethods.class.getName() ); 59 | 60 | @Test 61 | public void testCallConsumerOnCollection() { 62 | Consumer zeroValueConsumer = IntegerWrapper::makeZero; 63 | int size = 1000; 64 | 65 | long executionTimeMultiThreading = doConsumerMultiThreadTest( zeroValueConsumer, size ); 66 | 67 | long executionTimeSingleThreading = getConsumerSingleThreadExecutionTime( zeroValueConsumer, size ); 68 | 69 | checkExecutionTimes( executionTimeMultiThreading, executionTimeSingleThreading, "testCallConsumerOnCollection" ); 70 | } 71 | 72 | @Test 73 | public void testBuildListWithSupplier() { 74 | final Supplier negativeOneSupplier = getDelayedNegativeOneSupplier(); 75 | 76 | int size = 1000; 77 | 78 | long executionTimeMultiThreading = doListBuildingMultiThreaded( negativeOneSupplier, size ); 79 | 80 | long executionTimeSingleThreading = doListBuildingSingleThreaded( negativeOneSupplier, size ); 81 | 82 | checkExecutionTimes( executionTimeMultiThreading, executionTimeSingleThreading, "testBuildListWithSupplier" ); 83 | } 84 | 85 | @NotNull 86 | private Supplier getDelayedNegativeOneSupplier() { 87 | 
return () -> { 88 | ThrowingRunnable.unchecked( () -> Thread.sleep( 1 ) ).run(); 89 | return -1; 90 | }; 91 | } 92 | 93 | private long doListBuildingSingleThreaded( Supplier zeroSupplier, int size ) { 94 | return measureTime( () -> Stream.generate( zeroSupplier ).limit( size ).collect( Collectors.toList()) ).getTime(); 95 | } 96 | 97 | private long doListBuildingMultiThreaded( Supplier zeroSupplier, int size ) { 98 | Measure> measure = measureTime( () -> MultiThreadHelper.getListOutOfSupplier( executorService, zeroSupplier, size ) ); 99 | assertTrue( measure.getValue().stream().allMatch( i -> i == -1 ) ); 100 | return measure.getTime(); 101 | } 102 | 103 | private static class TestException extends RuntimeException { 104 | } 105 | 106 | @Test 107 | public void testCallOnConsumerExceptionHandling() { 108 | Consumer exceptionConsumer = ( i ) -> { 109 | throw new TestException(); 110 | }; 111 | try { 112 | MultiThreadHelper.callConsumerOnStream( executorService, IntStream.range( 0, 5 ).boxed(), exceptionConsumer ); 113 | fail(); 114 | } catch ( Exception e ) { 115 | assertTrue( e.getCause() instanceof TestException ); 116 | } 117 | } 118 | 119 | 120 | private long getConsumerSingleThreadExecutionTime( Consumer consumer, int size ) { 121 | Collection testCollectionSingleThread = createTestCollection( size ); 122 | Measure executionTimeSingleThreading = measureTime( () -> testCollectionSingleThread.forEach( consumer ) ); 123 | return executionTimeSingleThreading.getTime(); 124 | } 125 | 126 | private long doConsumerMultiThreadTest( Consumer consumer, int size ) { 127 | final Collection testCollection = createTestCollection( size ); 128 | 129 | Measure measureMultiThreading = measureTime( 130 | () -> MultiThreadHelper.callConsumerOnCollection( executorService, testCollection, consumer ) ); 131 | 132 | testCollection.forEach( i -> assertEquals( 0, i.value ) ); 133 | 134 | 135 | return measureMultiThreading.time; 136 | } 137 | 138 | private void checkExecutionTimes( long 
executionTimeMultiThreading, long executionTimeSingleThreading, String methodName ) { 139 | if ( executionTimeSingleThreading < executionTimeMultiThreading ) 140 | logger.warning( 141 | "Multi Threading implementation in " + methodName + " was slower than single Threading by " + ( executionTimeMultiThreading 142 | - executionTimeSingleThreading ) + " milliseconds" ); 143 | } 144 | 145 | private static Measure measureTime( Runnable runnable ) { 146 | return measureTime( () -> { 147 | runnable.run(); 148 | return null; 149 | } ); 150 | } 151 | 152 | private static Measure measureTime( Supplier supplier ) { 153 | long startTime = System.currentTimeMillis(); 154 | T value = supplier.get(); 155 | long endTime = System.currentTimeMillis(); 156 | return new Measure<>( endTime - startTime, value ); 157 | } 158 | 159 | private static Collection createTestCollection( int size ) { 160 | return IntStream.range( 0, size ).boxed().map( IntegerWrapper::new ).collect( Collectors.toList() ); 161 | } 162 | 163 | } 164 | -------------------------------------------------------------------------------- /testFiles/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tomLamprecht/Easy-ML-For-Java/f4fcc6525a1fe652add8e7b7ba835ec2a3f799e8/testFiles/.keep --------------------------------------------------------------------------------