├── .github
│   └── workflows
│       └── codeql.yml
├── .gitignore
├── LICENSE
├── README.md
├── data
│   └── diabetesprediction
│       └── diabetes.csv
├── plots
│   ├── .keep
│   └── exampleplot
│       └── examplePlot.PNG
├── pom.xml
├── src
│   ├── main
│   │   └── java
│   │       ├── de
│   │       │   └── fhws
│   │       │       └── easyml
│   │       │           ├── ai
│   │       │           │   ├── backpropagation
│   │       │           │   │   ├── BackpropagationTrainer.java
│   │       │           │   │   └── logger
│   │       │           │   │       ├── BackpropagationLogger.java
│   │       │           │   │       └── loggers
│   │       │           │   │           └── ConsoleLogger.java
│   │       │           │   ├── geneticneuralnet
│   │       │           │   │   ├── NNRandomMutator.java
│   │       │           │   │   ├── NNUniformCrossoverRecombiner.java
│   │       │           │   │   ├── NeuralNetFitnessFunction.java
│   │       │           │   │   ├── NeuralNetIndividual.java
│   │       │           │   │   ├── NeuralNetPopulationSupplier.java
│   │       │           │   │   └── NeuralNetSupplier.java
│   │       │           │   └── neuralnetwork
│   │       │           │       ├── Backpropagation.java
│   │       │           │       ├── Layer.java
│   │       │           │       ├── NeuralNet.java
│   │       │           │       ├── activationfunction
│   │       │           │       │   ├── ActivationFunction.java
│   │       │           │       │   ├── Sigmoid.java
│   │       │           │       │   └── Tanh.java
│   │       │           │       └── costfunction
│   │       │           │           ├── CostFunction.java
│   │       │           │           └── SummedCostFunction.java
│   │       │           ├── geneticalgorithm
│   │       │           │   ├── GeneticAlgorithm.java
│   │       │           │   ├── Individual.java
│   │       │           │   ├── Population.java
│   │       │           │   ├── evolution
│   │       │           │   │   ├── Mutator.java
│   │       │           │   │   ├── Recombiner.java
│   │       │           │   │   ├── Selector.java
│   │       │           │   │   ├── recombiners
│   │       │           │   │   │   └── FillUpRecombiner.java
│   │       │           │   │   └── selectors
│   │       │           │   │       ├── EliteSelector.java
│   │       │           │   │       ├── PercentageSelector.java
│   │       │           │   │       ├── RouletteWheelSelector.java
│   │       │           │   │       └── TournamentSelector.java
│   │       │           │   ├── logger
│   │       │           │   │   ├── Logger.java
│   │       │           │   │   └── loggers
│   │       │           │   │       ├── ConsoleLogger.java
│   │       │           │   │       ├── IntervalConsoleLogger.java
│   │       │           │   │       └── graphplotter
│   │       │           │   │           ├── GraphPlotLogger.java
│   │       │           │   │           └── lines
│   │       │           │   │               ├── AvgFitnessLine.java
│   │       │           │   │               ├── LineGenerator.java
│   │       │           │   │               ├── MaxFitnessLine.java
│   │       │           │   │               ├── NQuantilFitnessLine.java
│   │       │           │   │               └── WorstFitnessLine.java
│   │       │           │   ├── populationsupplier
│   │       │           │   │   ├── PopulationByFileSupplier.java
│   │       │           │   │   └── PopulationSupplier.java
│   │       │           │   └── saver
│   │       │           │       └── IntervalSaver.java
│   │       │           ├── linearalgebra
│   │       │           │   ├── ApplyAble.java
│   │       │           │   ├── LinearAlgebra.java
│   │       │           │   ├── Matrix.java
│   │       │           │   ├── Randomizer.java
│   │       │           │   └── Vector.java
│   │       │           ├── logger
│   │       │           │   └── LoggerInterface.java
│   │       │           └── utility
│   │       │               ├── FileHandler.java
│   │       │               ├── ListUtility.java
│   │       │               ├── MathUtility.java
│   │       │               ├── MultiThreadHelper.java
│   │       │               ├── StreamUtil.java
│   │       │               ├── Validator.java
│   │       │               ├── WarningLogger.java
│   │       │               └── throwingintefaces
│   │       │                   ├── ExceptionPrintingRunnable.java
│   │       │                   └── ThrowingRunnable.java
│   │       └── example
│   │           ├── SimpleFunctionPredictionExample.java
│   │           ├── SnakeGameExample
│   │           │   ├── flatgame
│   │           │   │   ├── GameGraphics.java
│   │           │   │   ├── GameLogic.java
│   │           │   │   ├── GraphicsWindow.java
│   │           │   │   └── Paintable.java
│   │           │   └── snakegame
│   │           │       ├── Apple.java
│   │           │       ├── Item.java
│   │           │       ├── Main.java
│   │           │       ├── Part.java
│   │           │       ├── Snake.java
│   │           │       ├── SnakeAi.java
│   │           │       ├── SnakeGame.java
│   │           │       └── SnakeGameLogic.java
│   │           └── diabetesprediction
│   │               ├── DiabetesDataSet.java
│   │               ├── InputParser.java
│   │               ├── Main.java
│   │               └── backpropagation
│   │                   └── MainBackprop.java
│   └── test
│       └── java
│           ├── testGeneticAlgorithmBlackBox
│           │   ├── GeneticAlgorithmTester.java
│           │   ├── Graph.java
│           │   ├── TSP.java
│           │   └── TestGeneticAlgorithm.java
│           ├── testLinearAlgebra
│           │   └── TestVector.java
│           ├── testNetworkTrainer
│           │   └── TestNetworkTrainerBlackBox.java
│           ├── testNeuralNetwork
│           │   ├── TestBackpropagation.java
│           │   ├── TestInvalidArguments.java
│           │   ├── TestNeuralNetMaths.java
│           │   └── TestNeuralNetSaveAndRead.java
│           └── testmultithreadhelper
│               └── TestDoOnCollectionMethods.java
└── testFiles
    └── .keep
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL"
13 |
14 | on:
15 | push:
16 | branches: [ "master" ]
17 | pull_request:
18 | # The branches below must be a subset of the branches above
19 | branches: [ "master" ]
20 | schedule:
21 | - cron: '36 18 * * 6'
22 |
23 | jobs:
24 | analyze:
25 | name: Analyze
26 | runs-on: ubuntu-latest
27 | permissions:
28 | actions: read
29 | contents: read
30 | security-events: write
31 |
32 | strategy:
33 | fail-fast: false
34 | matrix:
35 | language: [ 'java' ]
36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ]
37 | # Use only 'java' to analyze code written in Java, Kotlin or both
38 | # Use only 'javascript' to analyze code written in JavaScript, TypeScript or both
39 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support
40 |
41 | steps:
42 | - name: Checkout repository
43 | uses: actions/checkout@v3
44 |
45 | # Initializes the CodeQL tools for scanning.
46 | - name: Initialize CodeQL
47 | uses: github/codeql-action/init@v2
48 | with:
49 | languages: ${{ matrix.language }}
50 | # If you wish to specify custom queries, you can do so here or in a config file.
51 | # By default, queries listed here will override any specified in a config file.
52 | # Prefix the list here with "+" to use these queries and those in the config file.
53 |
54 | # For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
55 | # queries: security-extended,security-and-quality
56 |
57 |
58 | # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java).
59 | # If this step fails, then you should remove it and run the build manually (see below)
60 | - name: Autobuild
61 | uses: github/codeql-action/autobuild@v2
62 |
63 | # ℹ️ Command-line programs to run using the OS shell.
64 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
65 |
66 | # If the Autobuild fails above, remove it and uncomment the following three lines,
67 | # then modify them (or add more) to build your code. If your project needs custom build steps, refer to the EXAMPLE below for guidance.
68 |
69 | # - run: |
70 | # echo "Run, Build Application using script"
71 | # ./location_of_script_within_repo/buildscript.sh
72 |
73 | - name: Perform CodeQL Analysis
74 | uses: github/codeql-action/analyze@v2
75 | with:
76 | category: "/language:${{matrix.language}}"
77 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /.idea/shelf/
3 | /workspace.xml
4 | /.idea/
5 | *.class
6 | .target
7 |
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 tomLamprecht
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/plots/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomLamprecht/Easy-ML-For-Java/f4fcc6525a1fe652add8e7b7ba835ec2a3f799e8/plots/.keep
--------------------------------------------------------------------------------
/plots/exampleplot/examplePlot.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tomLamprecht/Easy-ML-For-Java/f4fcc6525a1fe652add8e7b7ba835ec2a3f799e8/plots/exampleplot/examplePlot.PNG
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |     <modelVersion>4.0.0</modelVersion>
6 |
7 |     <groupId>de.fhws</groupId>
8 |     <artifactId>NetworkTrainer</artifactId>
9 |     <version>1.0-SNAPSHOT</version>
10 |
11 |     <repositories>
12 |         <repository>
13 |             <id>AsposeJavaAPI</id>
14 |             <name>Aspose Java API</name>
15 |             <url>https://repository.aspose.com/repo/</url>
16 |         </repository>
17 |     </repositories>
18 |
19 |     <dependencies>
20 |         <dependency>
21 |             <groupId>junit</groupId>
22 |             <artifactId>junit</artifactId>
23 |             <version>4.13.2</version>
24 |             <scope>test</scope>
25 |         </dependency>
26 |         <dependency>
27 |             <groupId>org.jetbrains</groupId>
28 |             <artifactId>annotations</artifactId>
29 |             <version>RELEASE</version>
30 |             <scope>compile</scope>
31 |         </dependency>
32 |         <dependency>
33 |             <groupId>com.aspose</groupId>
34 |             <artifactId>aspose-cells</artifactId>
35 |             <version>21.6</version>
36 |         </dependency>
37 |     </dependencies>
38 |
39 |     <properties>
40 |         <maven.compiler.source>15</maven.compiler.source>
41 |         <maven.compiler.target>15</maven.compiler.target>
42 |     </properties>
43 |
44 | </project>
--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/backpropagation/BackpropagationTrainer.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.backpropagation;
2 |
3 | import de.fhws.easyml.ai.backpropagation.logger.BackpropagationLogger;
4 | import de.fhws.easyml.ai.neuralnetwork.Backpropagation;
5 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
6 | import de.fhws.easyml.ai.neuralnetwork.costfunction.CostFunction;
7 | import de.fhws.easyml.ai.neuralnetwork.costfunction.SummedCostFunction;
8 | import de.fhws.easyml.linearalgebra.Vector;
9 |
10 | import java.util.List;
11 | import java.util.function.Supplier;
12 |
13 | public class BackpropagationTrainer {
14 |
15 | private final NeuralNet neuralNet;
16 | private final Supplier<Batch> batchSupplier;
17 | private final double learningRate;
18 | private final int epochs;
19 | private final CostFunction costFunction;
20 | private final BackpropagationLogger logger;
21 |
22 |
23 | private BackpropagationTrainer(NeuralNet neuralNet, Supplier<Batch> batchSupplier, double learningRate, int epochs, CostFunction costFunction, BackpropagationLogger logger) {
24 | this.neuralNet = neuralNet;
25 | this.batchSupplier = batchSupplier;
26 | this.learningRate = learningRate;
27 | this.epochs = epochs;
28 | this.costFunction = costFunction;
29 | this.logger = logger;
30 | }
31 |
32 |
33 |
34 | public void train() {
35 | for (int i = 0; i < epochs; i++) {
36 | doEpoch(i, batchSupplier.get());
37 | }
38 | }
39 |
40 | private void doEpoch(int i, Batch batch) {
41 | Backpropagation.BatchTrainingResult result = new Backpropagation(neuralNet).trainBatch(batch.getInputs(), batch.getExpectedOutputs(), costFunction, learningRate);
42 | logIfPresent(i, result);
43 | }
44 |
45 | private void logIfPresent(int i, Backpropagation.BatchTrainingResult result) {
46 | if (logger != null)
47 | logger.log(i, result);
48 | }
49 |
50 | public static class Builder {
51 |
52 | private final NeuralNet neuralNet;
53 | private final Supplier<Batch> batchSupplier;
54 | private final double learningRate;
55 | private final int epochs;
56 | private CostFunction costFunction = new SummedCostFunction();
57 | private BackpropagationLogger logger;
58 |
59 | /**
60 | * Builder for training a Neural Network using a backpropagation approach
61 | * @param batchSupplier supplies a new Batch for each epoch
62 | * @param learningRate factor determining how much impact a single batch has on the weights
63 | * @param epochs number of training epochs; one batch is trained per epoch
64 | */
65 | public Builder(NeuralNet neuralNet, Supplier<Batch> batchSupplier, double learningRate, int epochs) {
66 | this.neuralNet = neuralNet;
67 | this.batchSupplier = batchSupplier;
68 | this.learningRate = learningRate;
69 | this.epochs = epochs;
70 | }
71 |
72 | public Builder withCostFunction(CostFunction costFunction) {
73 | this.costFunction = costFunction;
74 | return this;
75 | }
76 |
77 | public Builder withLogger(BackpropagationLogger logger) {
78 | this.logger = logger;
79 | return this;
80 | }
81 |
82 | public BackpropagationTrainer build() {
83 | return new BackpropagationTrainer(neuralNet, batchSupplier, learningRate, epochs, costFunction, logger);
84 | }
85 |
86 | }
87 |
88 | public static class Batch{
89 | private final List<Vector> inputs;
90 | private final List<Vector> expectedOutputs;
91 |
92 | public Batch(List<Vector> inputs, List<Vector> expectedOutputs) {
93 | this.inputs = inputs;
94 | this.expectedOutputs = expectedOutputs;
95 | }
96 |
97 | public List<Vector> getInputs() {
98 | return inputs;
99 | }
100 |
101 | public List<Vector> getExpectedOutputs() {
102 | return expectedOutputs;
103 | }
104 | }
105 |
106 |
107 | }
108 |
--------------------------------------------------------------------------------
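
A minimal usage sketch of the trainer above (not part of the repository): it wires a Batch supplier into the Builder and trains for a fixed number of epochs. createNet() and toVector(...) are hypothetical stand-ins, since NeuralNet and Vector construction are not shown in this excerpt.

import de.fhws.easyml.ai.backpropagation.BackpropagationTrainer;
import de.fhws.easyml.ai.backpropagation.logger.loggers.ConsoleLogger;
import java.util.List;

// Sketch: train on a freshly supplied batch for 100 epochs,
// logging the average cost per epoch via ConsoleLogger.
BackpropagationTrainer trainer = new BackpropagationTrainer.Builder(
        createNet(),                              // hypothetical NeuralNet factory
        () -> new BackpropagationTrainer.Batch(
                List.of(toVector(0.0, 1.0)),      // inputs (hypothetical Vector helper)
                List.of(toVector(1.0))),          // expected outputs
        0.1,                                      // learning rate, validated to lie in [0, 1]
        100)                                      // epochs
        .withLogger(new ConsoleLogger())
        .build();
trainer.train();
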
/src/main/java/de/fhws/easyml/ai/backpropagation/logger/BackpropagationLogger.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.backpropagation.logger;
2 |
3 | import de.fhws.easyml.ai.neuralnetwork.Backpropagation;
4 | import de.fhws.easyml.logger.LoggerInterface;
5 |
6 | public interface BackpropagationLogger extends LoggerInterface<Backpropagation.BatchTrainingResult> {
7 |
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/backpropagation/logger/loggers/ConsoleLogger.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.backpropagation.logger.loggers;
2 |
3 | import de.fhws.easyml.ai.backpropagation.logger.BackpropagationLogger;
4 | import de.fhws.easyml.ai.neuralnetwork.Backpropagation;
5 |
6 | public class ConsoleLogger implements BackpropagationLogger {
7 |
8 | @Override
9 | public void log(int epoch, Backpropagation.BatchTrainingResult input) {
10 | System.out.println("EPOCH " + epoch + ": " + input.avg());
11 | }
12 | }
13 |
--------------------------------------------------------------------------------
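
Because BackpropagationLogger is a single-method interface, a custom logger is just a small class or lambda. A sketch (not in the repository) mirroring ConsoleLogger but printing only every n-th epoch:

import de.fhws.easyml.ai.backpropagation.logger.BackpropagationLogger;
import de.fhws.easyml.ai.neuralnetwork.Backpropagation;

// Sketch: log the batch's average cost only every `interval` epochs.
public class IntervalBackpropLogger implements BackpropagationLogger {
    private final int interval;

    public IntervalBackpropLogger(int interval) {
        this.interval = interval;
    }

    @Override
    public void log(int epoch, Backpropagation.BatchTrainingResult result) {
        if (epoch % interval == 0)
            System.out.println("EPOCH " + epoch + ": " + result.avg());
    }
}
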
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NNRandomMutator.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.geneticneuralnet;
2 |
3 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
4 | import de.fhws.easyml.geneticalgorithm.Population;
5 | import de.fhws.easyml.linearalgebra.Randomizer;
6 | import de.fhws.easyml.utility.MultiThreadHelper;
7 | import de.fhws.easyml.utility.Validator;
8 | import de.fhws.easyml.geneticalgorithm.evolution.Mutator;
9 | import org.jetbrains.annotations.Nullable;
10 |
11 | import java.util.concurrent.ExecutorService;
12 | import java.util.concurrent.ThreadLocalRandom;
13 | import java.util.function.Consumer;
14 | import java.util.function.DoubleUnaryOperator;
15 | import java.util.stream.Stream;
16 |
17 | public class NNRandomMutator implements Mutator<NeuralNetIndividual> {
18 |
19 | private final double outerMutationRate;
20 |
21 | private final DoubleUnaryOperator innerMutator;
22 |
23 | /**
24 | * Creates a random Mutator for Neural Networks.
25 | * Each Neural Net of the population is chosen for mutation with the probability {@code outerMutationRate}.
26 | * Every weight of a chosen Neural Net is then mutated with the probability {@code innerMutationRate}.
27 | * The mutation computes a modifying factor from a random number drawn from {@code mutationFactor}:
28 | * {@code mutationFactor.getInRange() * weight}. If the absolute value of this factor is smaller than
29 | * {@code minMutationAbsolute}, then {@code minMutationAbsolute}, carrying the sign of the computed factor,
30 | * is added to the weight. Otherwise, the factor itself is added to the weight.
31 | *
32 | * @param outerMutationRate probability that an individual of the population gets mutated
33 | * @param innerMutationRate probability that a single weight of a chosen Neural Net gets mutated
34 | * @param mutationFactor randomizer providing the modifying factor
35 | * @param minMutationAbsolute minimum absolute modification value
36 | */
37 | public NNRandomMutator(double outerMutationRate, double innerMutationRate, Randomizer mutationFactor, double minMutationAbsolute ) {
38 | Validator.value( outerMutationRate ).isBetweenOrThrow( 0, 1 );
39 | Validator.value( innerMutationRate ).isBetweenOrThrow( 0, 1 );
40 |
41 | this.outerMutationRate = outerMutationRate;
42 |
43 | innerMutator = d -> {
44 | if ( ThreadLocalRandom.current( ).nextDouble( ) < innerMutationRate ) {
45 | double mutateValue = d * mutationFactor.getInRange( );
46 | return d + ( mutateValue < 0 ? -1 : 1 ) * Math.max( minMutationAbsolute, Math.abs( mutateValue ) );
47 | }
48 | return d;
49 | };
50 | }
51 |
52 | @Override
53 | public void mutate(Population<NeuralNetIndividual> pop, @Nullable ExecutorService executorService ) {
54 | Stream<NeuralNetIndividual> filteredIndividuals = getFilteredIndividuals( pop );
55 |
56 | Consumer<NeuralNetIndividual> mutateConsumer = individual -> mutateNN( individual.getNN( ) );
57 | if ( executorService != null )
58 | MultiThreadHelper.callConsumerOnStream( executorService, filteredIndividuals, mutateConsumer );
59 | else
60 | filteredIndividuals
61 | .forEach( mutateConsumer );
62 | }
63 |
64 | private Stream<NeuralNetIndividual> getFilteredIndividuals( Population<NeuralNetIndividual> pop ) {
65 | return pop.getIndividuals( )
66 | .stream( )
67 | .filter( individual -> ThreadLocalRandom.current( ).nextDouble( ) < outerMutationRate );
68 | }
69 |
70 | private void mutateNN( NeuralNet neuralNet ) {
71 | neuralNet.getLayers( )
72 | .forEach( layer -> {
73 | layer.getWeights( ).apply( innerMutator );
74 | layer.getBias( ).apply( innerMutator );
75 | } );
76 | }
77 |
78 | public static class EnsureSingleThreading extends NNRandomMutator {
79 |
80 | /**
81 | * Creates a Random Mutator for Neural Networks but ensures that the single-threaded implementation is used
82 | * even if an ExecutorService is provided.
83 | * For further documentation see {@link #NNRandomMutator(double, double, Randomizer, double)}
84 | */
85 |
86 | public EnsureSingleThreading( double outerMutationRate, double innerMutationRate, Randomizer mutationFactor, double minMutationAbsolute ) {
87 | super( outerMutationRate, innerMutationRate, mutationFactor, minMutationAbsolute );
88 | }
89 |
90 | @Override
91 | public void mutate( Population<NeuralNetIndividual> pop, ExecutorService executorService ) {
92 | super.mutate( pop, null );
93 | }
94 | }
95 | }
96 |
--------------------------------------------------------------------------------
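
The per-weight rule inside innerMutator reads: scale the weight by a random factor, then enforce a minimum step size while keeping the step's sign. The same arithmetic, isolated as a sketch for illustration:

// With weight = 0.5, randomFactor = -0.02 and minMutationAbsolute = 0.05:
// mutateValue = -0.01, |mutateValue| < 0.05, so the step is clamped to -0.05
// and the mutated weight becomes 0.45.
static double mutateWeight(double weight, double randomFactor, double minMutationAbsolute) {
    double mutateValue = weight * randomFactor;
    int sign = mutateValue < 0 ? -1 : 1;
    return weight + sign * Math.max(minMutationAbsolute, Math.abs(mutateValue));
}
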
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NNUniformCrossoverRecombiner.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.geneticneuralnet;
2 |
3 | import de.fhws.easyml.ai.neuralnetwork.Layer;
4 | import de.fhws.easyml.geneticalgorithm.Population;
5 | import de.fhws.easyml.utility.ListUtility;
6 | import de.fhws.easyml.utility.MultiThreadHelper;
7 | import de.fhws.easyml.utility.Validator;
8 | import de.fhws.easyml.geneticalgorithm.evolution.Recombiner;
9 | import org.jetbrains.annotations.Nullable;
10 |
11 | import java.util.List;
12 | import java.util.concurrent.ExecutorService;
13 | import java.util.concurrent.ThreadLocalRandom;
14 |
15 | public class NNUniformCrossoverRecombiner implements Recombiner<NeuralNetIndividual> {
16 |
17 | private int amountOfParentsPerChild;
18 |
19 | /** Uniform crossover Recombiner for Neural Networks: each weight and bias of a child is copied from one of {@code amountOfParentsPerChild} randomly chosen parents. */
20 | public NNUniformCrossoverRecombiner( int amountOfParentsPerChild ) {
21 | this.amountOfParentsPerChild = amountOfParentsPerChild;
22 | }
23 |
24 | @Override
25 | public void recombine(Population<NeuralNetIndividual> pop, int goalSize, @Nullable ExecutorService executorService ) {
26 | Validator.value( amountOfParentsPerChild ).isBetweenOrThrow( 1, pop.getSize( ) );
27 |
28 | if ( executorService != null ) {
29 | recombineMultiThreaded( pop, goalSize, executorService );
30 | } else {
31 | recombineSingleThreaded( pop, goalSize );
32 | }
33 | }
34 |
35 | private void recombineSingleThreaded( Population<NeuralNetIndividual> pop, int goalSize ) {
36 | while ( pop.getIndividuals( ).size( ) < goalSize ) {
37 | List<NeuralNetIndividual> parents = ListUtility.selectRandomElements( pop.getIndividuals( ), amountOfParentsPerChild );
38 | pop.getIndividuals( ).add( makeChild( parents ) );
39 | }
40 | }
41 |
42 | private void recombineMultiThreaded( Population<NeuralNetIndividual> pop, int goalSize, ExecutorService executorService ) {
43 |
44 | int neededChildren = goalSize - pop.getIndividuals( ).size( );
45 |
46 | pop.getIndividuals( ).addAll(
47 | MultiThreadHelper.getListOutOfSupplier( executorService,
48 | ( ) -> makeChild( ListUtility.selectRandomElements( pop.getIndividuals( ), amountOfParentsPerChild ) ),
49 | neededChildren ) );
50 | }
51 |
52 |
53 | private NeuralNetIndividual makeChild( List<NeuralNetIndividual> parents ) {
54 | NeuralNetIndividual child = parents.get( 0 ).copy( );
55 | for ( int l = 0; l < child.getNN( ).getLayers( ).size( ); l++ ) {
56 | combineWeights( parents, child, l );
57 |
58 | combineBias( parents, child, l );
59 | }
60 | return child;
61 | }
62 |
63 | private void combineWeights( List<NeuralNetIndividual> parents, NeuralNetIndividual child, int layerIndex ) {
64 | Layer layer = child.getNN( ).getLayers( ).get( layerIndex );
65 |
66 | for ( int i = 0; i < layer.getWeights( ).getNumRows( ); i++ ) {
67 | for ( int j = 0; j < layer.getWeights( ).getNumCols( ); j++ ) {
68 | int selectedParent = ThreadLocalRandom.current( ).nextInt( parents.size( ) );
69 | layer.getWeights( )
70 | .set( i, j, parents.get( selectedParent )
71 | .getNN( ).getLayers( )
72 | .get( layerIndex )
73 | .getWeights( )
74 | .get( i, j ) );
75 | }
76 | }
77 | }
78 |
79 |
80 | private void combineBias( List<NeuralNetIndividual> parents, NeuralNetIndividual child, int layerIndex ) {
81 | Layer layer = child.getNN( ).getLayers( ).get( layerIndex );
82 |
83 | for ( int i = 0; i < layer.getBias( ).size( ); i++ ) {
84 | int selectedParent = ThreadLocalRandom.current( ).nextInt( parents.size( ) );
85 | layer.getBias( )
86 | .set( i, parents.get( selectedParent )
87 | .getNN( ).getLayers( )
88 | .get( layerIndex )
89 | .getBias( )
90 | .get( i ) );
91 | }
92 | }
93 |
94 | public static class EnsureSingleThreading extends NNUniformCrossoverRecombiner {
95 |
96 | public EnsureSingleThreading( int amountOfParentsPerChild ) {
97 | super( amountOfParentsPerChild );
98 | }
99 |
100 | @Override
101 | public void recombine( Population<NeuralNetIndividual> pop, int goalSize, ExecutorService executorService ) {
102 | super.recombine( pop, goalSize, null );
103 | }
104 | }
105 |
106 | }
107 |
--------------------------------------------------------------------------------
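
Uniform crossover, as implemented above, picks every weight and bias independently from a randomly chosen parent. A reduced sketch on plain double[] genomes, for illustration only:

import java.util.concurrent.ThreadLocalRandom;

// Each gene is copied from a uniformly chosen parent; combineWeights and
// combineBias apply the same idea per matrix entry and per bias entry.
static double[] uniformCrossover(double[][] parents) {
    double[] child = new double[parents[0].length];
    for (int i = 0; i < child.length; i++) {
        int p = ThreadLocalRandom.current().nextInt(parents.length);
        child[i] = parents[p][i];
    }
    return child;
}
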
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetFitnessFunction.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.geneticneuralnet;
2 |
3 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
4 |
5 | @FunctionalInterface
6 | public interface NeuralNetFitnessFunction {
7 |
8 | double calculateFitness( NeuralNet neuralNet );
9 |
10 | }
11 |
--------------------------------------------------------------------------------
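
Since the interface is a @FunctionalInterface, a fitness function is usually supplied as a lambda. A sketch under the assumption of a single known sample; sampleInput and expected are hypothetical fixtures, and NeuralNet.calcOutput plus Vector.get are inferred from their usage elsewhere in this excerpt:

// Sketch: fitness is the negative squared error on one sample, so a perfect
// prediction yields the maximal fitness of 0.
NeuralNetFitnessFunction fitness = net -> {
    double predicted = net.calcOutput(sampleInput).get(0);  // sampleInput: hypothetical Vector fixture
    double error = predicted - expected;                    // expected: hypothetical target value
    return -error * error;
};
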
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetIndividual.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.geneticneuralnet;
2 |
3 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
4 | import de.fhws.easyml.linearalgebra.Vector;
5 | import de.fhws.easyml.geneticalgorithm.Individual;
6 |
7 |
8 | public class NeuralNetIndividual implements Individual<NeuralNetIndividual> {
9 |
10 | private NeuralNet neuralNet;
11 | private NeuralNetFitnessFunction fitnessFunction;
12 | private double fitness;
13 |
14 | public NeuralNetIndividual(NeuralNet neuralNet, NeuralNetFitnessFunction fitnessFunction) {
15 | this.neuralNet = neuralNet;
16 | this.fitnessFunction = fitnessFunction;
17 | }
18 |
19 |
20 | public NeuralNet getNN() {
21 | return neuralNet;
22 | }
23 |
24 | @Override
25 | public void calcFitness() {
26 | fitness = fitnessFunction.calculateFitness(neuralNet);
27 | }
28 |
29 | @Override
30 | public double getFitness() {
31 | return fitness;
32 | }
33 |
34 | /**
35 | * Copies the Neural Net but uses the same reference for the fitnessFunction
36 | * @return the copy
37 | */
38 | @Override
39 | public NeuralNetIndividual copy() {
40 | NeuralNetIndividual copy = new NeuralNetIndividual(neuralNet.copy(), fitnessFunction);
41 | copy.fitness = this.fitness;
42 | return copy;
43 | }
44 |
45 | public Vector calcOutput(Vector vector) {
46 | return neuralNet.calcOutput(vector);
47 | }
48 |
49 |
50 | public NeuralNetFitnessFunction getFitnessFunction() {
51 | return fitnessFunction;
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetPopulationSupplier.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.geneticneuralnet;
2 |
3 | import de.fhws.easyml.geneticalgorithm.Population;
4 | import de.fhws.easyml.geneticalgorithm.populationsupplier.PopulationSupplier;
5 |
6 | import java.util.stream.Collectors;
7 | import java.util.stream.IntStream;
8 |
9 | public class NeuralNetPopulationSupplier implements PopulationSupplier<NeuralNetIndividual> {
10 |
11 | private final NeuralNetSupplier neuralNetSupplier;
12 | private final int populationSize;
13 | private final NeuralNetFitnessFunction fitnessFunction;
14 |
15 | public NeuralNetPopulationSupplier( NeuralNetSupplier neuralNetSupplier, NeuralNetFitnessFunction fitnessFunction, int populationSize ) {
16 | this.neuralNetSupplier = neuralNetSupplier;
17 | this.populationSize = populationSize;
18 | this.fitnessFunction = fitnessFunction;
19 | }
20 |
21 | @Override
22 | public Population<NeuralNetIndividual> get( ) {
23 | return new Population<>( IntStream.range( 0, populationSize )
24 | .mapToObj( i -> new NeuralNetIndividual( neuralNetSupplier.get(), fitnessFunction ) )
25 | .collect( Collectors.toList( ) ) );
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
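
Putting the two functional interfaces together: a population supplier is built from a NeuralNetSupplier lambda plus a fitness function, and an initial Population can then be drawn from it. createNet() and fitness are stand-ins carried over from the sketches above, since NeuralNet construction is not shown in this excerpt.

// Sketch: a population of 100 freshly supplied networks sharing one fitness function.
NeuralNetSupplier netSupplier = () -> createNet();          // createNet() is a hypothetical factory
NeuralNetPopulationSupplier popSupplier =
        new NeuralNetPopulationSupplier(netSupplier, fitness, 100);
Population<NeuralNetIndividual> initialPopulation = popSupplier.get();
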
/src/main/java/de/fhws/easyml/ai/geneticneuralnet/NeuralNetSupplier.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.geneticneuralnet;
2 |
3 | import de.fhws.easyml.ai.neuralnetwork.NeuralNet;
4 |
5 | import java.util.function.Supplier;
6 |
7 | @FunctionalInterface
8 | public interface NeuralNetSupplier extends Supplier<NeuralNet> {
9 |
10 | }
11 |
--------------------------------------------------------------------------------
/src/main/java/de/fhws/easyml/ai/neuralnetwork/Backpropagation.java:
--------------------------------------------------------------------------------
1 | package de.fhws.easyml.ai.neuralnetwork;
2 |
3 | import de.fhws.easyml.ai.neuralnetwork.costfunction.CostFunction;
4 | import de.fhws.easyml.linearalgebra.ApplyAble;
5 | import de.fhws.easyml.linearalgebra.Matrix;
6 | import de.fhws.easyml.linearalgebra.Vector;
7 | import de.fhws.easyml.utility.Validator;
8 | import org.jetbrains.annotations.NotNull;
9 |
10 | import java.util.ArrayList;
11 | import java.util.HashMap;
12 | import java.util.List;
13 | import java.util.Map;
14 | import java.util.stream.Stream;
15 |
16 | public class Backpropagation {
17 |
18 | private final NeuralNet nn;
19 |
20 | public Backpropagation(NeuralNet neuralNet) {
21 | this.nn = neuralNet;
22 | }
23 |
24 | public BatchTrainingResult trainBatch(List<Vector> inputs, List<Vector> expectedOutputs, CostFunction costFunction, double learningRate ) {
25 | validateTrainingBatchInput(inputs, expectedOutputs, learningRate);
26 |
27 | return doTrainBatch(inputs, expectedOutputs, costFunction, learningRate);
28 | }
29 |
30 | @NotNull
31 | private BatchTrainingResult doTrainBatch(List<Vector> inputs, List<Vector> expectedOutputs, CostFunction costFunction, double learningRate) {
32 | BatchTrainingResult result = new BatchTrainingResult();
33 |
34 | updateLayers(learningRate, getAveragedGradients(
35 | createGradientsFromBatch(inputs, expectedOutputs, costFunction, result )
36 | ) );
37 | return result;
38 | }
39 |
40 | private void validateTrainingBatchInput(List<Vector> inputs, List<Vector> expectedOutputs, double learningRate) {
41 | Validator.value( inputs.size() ).isEqualToOrThrow( expectedOutputs.size() );
42 | Validator.value(learningRate).isBetweenOrThrow( 0, 1 );
43 | }
44 |
45 | private List