├── .gitignore ├── README.md ├── build.gradle ├── demo.gif ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat └── src └── main └── kotlin ├── Dashboard.kt ├── PredictorModel.kt ├── TomNeuralNetwork.kt └── Util.kt /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | /build/ 3 | /.idea/ 4 | /.gradle/ 5 | *.iml 6 | gradle.properties 7 | secring.gpg 8 | /gradle.properties 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Showcases of Different ML Algorithms 2 | 3 | ### Using Background Color Classification 4 | 5 | This is a Kotlin application that experiments with different machine learning algorithms to recommend a light/dark font against different background colors. 6 | 7 | Some algorithms are built completely from scratch but others are showcased using a library. 
8 | 9 | The current algorithms/library implementations: 10 | 11 | - Formulaic 12 | - Linear regression (w/ hill climbing) 13 | - Logistic regression (w/ hill climbing) 14 | - Decision Tree 15 | - Random Forest 16 | - Neural Network (w/ hill climbing) 17 | - Neural Network (w/ simulated annealing) 18 | - OjAlgo Neural Network 19 | - DL4J Neural Network 20 | 21 | Planned to be added: 22 | 23 | - Linear regression (w/ gradient descent) 24 | - Logistic regression (w/ gradient descent) 25 | - Continuous Naive Bayes 26 | 27 | ![](https://i.imgur.com/SPVFfQ6.png) 28 | -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | buildscript { 2 | ext.kotlin_version = '1.3.31' 3 | 4 | repositories { 5 | maven { url 'https://repo1.maven.org/maven2' } 6 | } 7 | dependencies { 8 | classpath "org.jetbrains.kotlin:kotlin-gradle-plugin:$kotlin_version" 9 | } 10 | } 11 | 12 | 13 | apply plugin: "kotlin" 14 | apply plugin: 'application' 15 | 16 | compileKotlin { 17 | kotlinOptions.jvmTarget= "1.8" 18 | } 19 | 20 | repositories { 21 | maven { url 'https://repo1.maven.org/maven2' } 22 | maven { url 'https://jitpack.io' } 23 | } 24 | 25 | dependencies { 26 | 27 | compile "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version" 28 | 29 | compile 'no.tornado:tornadofx:1.+' 30 | 31 | compile 'org.ojalgo:ojalgo:47.2.0' 32 | compile 'org.deeplearning4j:deeplearning4j-core:1.0.0-beta2' 33 | compile 'org.nd4j:nd4j-native-platform:1.0.0-beta2' 34 | implementation 'org.nield:kotlin-statistics:1.2.1' 35 | } 36 | 37 | task fatJar(type: Jar) { 38 | manifest { 39 | attributes 'Implementation-Title': 'Kotlin ML Demos', 40 | 'Implementation-Version': 1.0, 41 | 'Main-Class': 'UIKt' 42 | } 43 | baseName = project.name 44 | from { configurations.compile.collect { it.isDirectory() ? 
it : zipTree(it) } } 45 | with jar 46 | } 47 | -------------------------------------------------------------------------------- /demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasnield/kotlin-machine-learning-demos/429efdbe579f7b9a7c86b1451e11e3beef46a5c5/demo.gif -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thomasnield/kotlin-machine-learning-demos/429efdbe579f7b9a7c86b1451e11e3beef46a5c5/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | distributionBase=GRADLE_USER_HOME 2 | distributionPath=wrapper/dists 3 | distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.2-bin.zip 4 | zipStoreBase=GRADLE_USER_HOME 5 | zipStorePath=wrapper/dists 6 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | # 4 | # Copyright 2015 the original author or authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
17 | # 18 | 19 | ############################################################################## 20 | ## 21 | ## Gradle start up script for UN*X 22 | ## 23 | ############################################################################## 24 | 25 | # Attempt to set APP_HOME 26 | # Resolve links: $0 may be a link 27 | PRG="$0" 28 | # Need this for relative symlinks. 29 | while [ -h "$PRG" ] ; do 30 | ls=`ls -ld "$PRG"` 31 | link=`expr "$ls" : '.*-> \(.*\)$'` 32 | if expr "$link" : '/.*' > /dev/null; then 33 | PRG="$link" 34 | else 35 | PRG=`dirname "$PRG"`"/$link" 36 | fi 37 | done 38 | SAVED="`pwd`" 39 | cd "`dirname \"$PRG\"`/" >/dev/null 40 | APP_HOME="`pwd -P`" 41 | cd "$SAVED" >/dev/null 42 | 43 | APP_NAME="Gradle" 44 | APP_BASE_NAME=`basename "$0"` 45 | 46 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 47 | DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' 48 | 49 | # Use the maximum available, or set MAX_FD != -1 to use that value. 50 | MAX_FD="maximum" 51 | 52 | warn () { 53 | echo "$*" 54 | } 55 | 56 | die () { 57 | echo 58 | echo "$*" 59 | echo 60 | exit 1 61 | } 62 | 63 | # OS specific support (must be 'true' or 'false'). 64 | cygwin=false 65 | msys=false 66 | darwin=false 67 | nonstop=false 68 | case "`uname`" in 69 | CYGWIN* ) 70 | cygwin=true 71 | ;; 72 | Darwin* ) 73 | darwin=true 74 | ;; 75 | MINGW* ) 76 | msys=true 77 | ;; 78 | NONSTOP* ) 79 | nonstop=true 80 | ;; 81 | esac 82 | 83 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 84 | 85 | # Determine the Java command to use to start the JVM. 86 | if [ -n "$JAVA_HOME" ] ; then 87 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 88 | # IBM's JDK on AIX uses strange locations for the executables 89 | JAVACMD="$JAVA_HOME/jre/sh/java" 90 | else 91 | JAVACMD="$JAVA_HOME/bin/java" 92 | fi 93 | if [ ! 
-x "$JAVACMD" ] ; then 94 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 95 | 96 | Please set the JAVA_HOME variable in your environment to match the 97 | location of your Java installation." 98 | fi 99 | else 100 | JAVACMD="java" 101 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 102 | 103 | Please set the JAVA_HOME variable in your environment to match the 104 | location of your Java installation." 105 | fi 106 | 107 | # Increase the maximum file descriptors if we can. 108 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 109 | MAX_FD_LIMIT=`ulimit -H -n` 110 | if [ $? -eq 0 ] ; then 111 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 112 | MAX_FD="$MAX_FD_LIMIT" 113 | fi 114 | ulimit -n $MAX_FD 115 | if [ $? -ne 0 ] ; then 116 | warn "Could not set maximum file descriptor limit: $MAX_FD" 117 | fi 118 | else 119 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 120 | fi 121 | fi 122 | 123 | # For Darwin, add options to specify how the application appears in the dock 124 | if $darwin; then 125 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 126 | fi 127 | 128 | # For Cygwin or MSYS, switch paths to Windows format before running java 129 | if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then 130 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 131 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 132 | JAVACMD=`cygpath --unix "$JAVACMD"` 133 | 134 | # We build the pattern for arguments to be converted via cygpath 135 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 136 | SEP="" 137 | for dir in $ROOTDIRSRAW ; do 138 | ROOTDIRS="$ROOTDIRS$SEP$dir" 139 | SEP="|" 140 | done 141 | OURCYGPATTERN="(^($ROOTDIRS))" 142 | # Add a user-defined pattern to the cygpath arguments 143 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 144 | 
OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 145 | fi 146 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 147 | i=0 148 | for arg in "$@" ; do 149 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 150 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 151 | 152 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 153 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 154 | else 155 | eval `echo args$i`="\"$arg\"" 156 | fi 157 | i=$((i+1)) 158 | done 159 | case $i in 160 | (0) set -- ;; 161 | (1) set -- "$args0" ;; 162 | (2) set -- "$args0" "$args1" ;; 163 | (3) set -- "$args0" "$args1" "$args2" ;; 164 | (4) set -- "$args0" "$args1" "$args2" "$args3" ;; 165 | (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 166 | (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 167 | (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 168 | (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 169 | (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 170 | esac 171 | fi 172 | 173 | # Escape application args 174 | save () { 175 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 176 | echo " " 177 | } 178 | APP_ARGS=$(save "$@") 179 | 180 | # Collect all arguments for the java command, following the shell quoting and substitution rules 181 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 182 | 183 | # by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong 184 | if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then 185 | cd "$(dirname "$0")" 186 | fi 187 | 188 | exec "$JAVACMD" "$@" 189 | -------------------------------------------------------------------------------- /gradlew.bat: 
-------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 15 | @rem 16 | 17 | @if "%DEBUG%" == "" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%" == "" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 33 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 34 | 35 | @rem Find java.exe 36 | if defined JAVA_HOME goto findJavaFromJavaHome 37 | 38 | set JAVA_EXE=java.exe 39 | %JAVA_EXE% -version >NUL 2>&1 40 | if "%ERRORLEVEL%" == "0" goto init 41 | 42 | echo. 43 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 44 | echo. 45 | echo Please set the JAVA_HOME variable in your environment to match the 46 | echo location of your Java installation. 
47 | 48 | goto fail 49 | 50 | :findJavaFromJavaHome 51 | set JAVA_HOME=%JAVA_HOME:"=% 52 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 53 | 54 | if exist "%JAVA_EXE%" goto init 55 | 56 | echo. 57 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 58 | echo. 59 | echo Please set the JAVA_HOME variable in your environment to match the 60 | echo location of your Java installation. 61 | 62 | goto fail 63 | 64 | :init 65 | @rem Get command-line arguments, handling Windows variants 66 | 67 | if not "%OS%" == "Windows_NT" goto win9xME_args 68 | 69 | :win9xME_args 70 | @rem Slurp the command line arguments. 71 | set CMD_LINE_ARGS= 72 | set _SKIP=2 73 | 74 | :win9xME_args_slurp 75 | if "x%~1" == "x" goto execute 76 | 77 | set CMD_LINE_ARGS=%* 78 | 79 | :execute 80 | @rem Setup the command line 81 | 82 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 83 | 84 | @rem Execute Gradle 85 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 86 | 87 | :end 88 | @rem End local scope for the variables with windows NT shell 89 | if "%ERRORLEVEL%"=="0" goto mainEnd 90 | 91 | :fail 92 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 93 | rem the _cmd.exe /c_ return code! 
94 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 95 | exit /b 1 96 | 97 | :mainEnd 98 | if "%OS%"=="Windows_NT" endlocal 99 | 100 | :omega 101 | -------------------------------------------------------------------------------- /src/main/kotlin/Dashboard.kt: -------------------------------------------------------------------------------- 1 | import javafx.application.Application 2 | import javafx.beans.property.ReadOnlyObjectWrapper 3 | import javafx.beans.property.SimpleObjectProperty 4 | import javafx.geometry.Insets 5 | import javafx.geometry.Orientation 6 | import javafx.scene.layout.Background 7 | import javafx.scene.layout.BackgroundFill 8 | import javafx.scene.layout.CornerRadii 9 | import javafx.scene.paint.Color 10 | import javafx.scene.text.FontWeight 11 | import tornadofx.* 12 | 13 | 14 | fun main() = Application.launch(MainApp::class.java) 15 | 16 | class MainApp: App(MainView::class) 17 | 18 | class MainView: View() { 19 | 20 | val backgroundColor = SimpleObjectProperty(Color.GRAY) 21 | 22 | fun assignRandomColor() = randomColor() 23 | .also { backgroundColor.set(it) } 24 | 25 | override val root = splitpane { 26 | style = "-fx-font-size: 16pt; " 27 | orientation = Orientation.VERTICAL 28 | 29 | splitpane { 30 | 31 | title = "Light/Dark Text Suggester" 32 | orientation = Orientation.HORIZONTAL 33 | 34 | borderpane { 35 | 36 | top = label("TRAIN") { 37 | style { 38 | textFill = Color.RED 39 | fontWeight = FontWeight.BOLD 40 | } 41 | } 42 | 43 | center = form { 44 | fieldset { 45 | 46 | field("Which looks better?").hbox { 47 | button("DARK") { 48 | textFill = Color.BLACK 49 | useMaxWidth = true 50 | 51 | backgroundProperty().bind( 52 | backgroundColor.select { ReadOnlyObjectWrapper(Background(BackgroundFill(it, CornerRadii.EMPTY, Insets.EMPTY))) } 53 | ) 54 | 55 | setOnAction { 56 | 57 | PredictorModel += LabeledColor(backgroundColor.get(), FontShade.DARK) 58 | assignRandomColor() 59 | } 60 | } 61 | 62 | button("LIGHT") { 63 | textFill = Color.WHITE 64 | 
useMaxWidth = true 65 | 66 | backgroundProperty().bind( 67 | backgroundColor.select { ReadOnlyObjectWrapper(Background(BackgroundFill(it, CornerRadii.EMPTY, Insets.EMPTY))) } 68 | ) 69 | 70 | setOnAction { 71 | PredictorModel += LabeledColor(backgroundColor.get(), FontShade.LIGHT) 72 | 73 | assignRandomColor() 74 | } 75 | } 76 | } 77 | } 78 | 79 | fieldset { 80 | field("Model") { 81 | combobox(PredictorModel.selectedPredictor) { 82 | 83 | PredictorModel.Predictor.values().forEach { items.add(it) } 84 | } 85 | } 86 | } 87 | 88 | fieldset { 89 | field("Pre-Train") { 90 | button("Train 1345 Colors") { 91 | useMaxWidth = true 92 | setOnAction { 93 | PredictorModel.preTrainData() 94 | isDisable = true 95 | } 96 | } 97 | } 98 | } 99 | } 100 | 101 | } 102 | 103 | borderpane { 104 | 105 | top = label("PREDICT") { 106 | style { 107 | textFill = Color.RED 108 | fontWeight = FontWeight.BOLD 109 | } 110 | } 111 | 112 | center = form { 113 | fieldset { 114 | field("Background") { 115 | colorpicker { 116 | valueProperty().onChange { 117 | backgroundColor.set(it) 118 | } 119 | 120 | customColors.forEach { println(it) } 121 | } 122 | } 123 | field("Result") { 124 | label("LOREM IPSUM") { 125 | backgroundProperty().bind( 126 | backgroundColor.select { ReadOnlyObjectWrapper(Background(BackgroundFill(it, CornerRadii.EMPTY, Insets.EMPTY))) } 127 | ) 128 | 129 | backgroundColor.onChange { 130 | val result = PredictorModel.predict(it!!) 
131 | 132 | text = result.toString() 133 | textFill = result.color 134 | } 135 | 136 | } 137 | } 138 | } 139 | } 140 | } 141 | } 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/main/kotlin/PredictorModel.kt: -------------------------------------------------------------------------------- 1 | import javafx.beans.property.SimpleObjectProperty 2 | import javafx.collections.FXCollections 3 | import javafx.scene.paint.Color 4 | import org.apache.commons.math3.distribution.NormalDistribution 5 | import org.deeplearning4j.nn.api.OptimizationAlgorithm 6 | import org.deeplearning4j.nn.conf.NeuralNetConfiguration 7 | import org.deeplearning4j.nn.conf.layers.DenseLayer 8 | import org.deeplearning4j.nn.conf.layers.OutputLayer 9 | import org.deeplearning4j.nn.multilayer.MultiLayerNetwork 10 | import org.deeplearning4j.nn.weights.WeightInit 11 | import org.nd4j.linalg.activations.Activation 12 | import org.nd4j.linalg.factory.Nd4j 13 | import org.nd4j.linalg.learning.config.Nesterovs 14 | import org.nield.kotlinstatistics.countBy 15 | import org.nield.kotlinstatistics.random 16 | import org.nield.kotlinstatistics.randomFirst 17 | import org.ojalgo.ann.ArtificialNeuralNetwork 18 | import org.ojalgo.array.Primitive64Array 19 | import java.net.URL 20 | import java.util.concurrent.ThreadLocalRandom 21 | import kotlin.math.exp 22 | import kotlin.math.ln 23 | import kotlin.math.pow 24 | 25 | object PredictorModel { 26 | 27 | val inputs = FXCollections.observableArrayList() 28 | 29 | val selectedPredictor = SimpleObjectProperty(Predictor.OJALGO_NEURAL_NETWORK) 30 | 31 | fun predict(color: Color) = selectedPredictor.get().predict(color) 32 | 33 | operator fun plusAssign(labeledColor: LabeledColor) { 34 | inputs += labeledColor 35 | Predictor.values().forEach { it.retrainFlag = true } 36 | } 37 | operator fun plusAssign(categorizedInput: Pair) { 38 | inputs += categorizedInput.let { LabeledColor(it.first, it.second) } 39 | 
Predictor.values().forEach { it.retrainFlag = true } 40 | } 41 | 42 | fun preTrainData() { 43 | 44 | URL("https://tinyurl.com/y2qmhfsr") 45 | .readText().split(Regex("\\r?\\n")) 46 | .asSequence() 47 | .drop(1) 48 | .filter { it.isNotBlank() } 49 | .map { s -> 50 | s.split(",").map { it.toInt() } 51 | } 52 | .map { Color.rgb(it[0], it[1], it[2]) } 53 | .map { LabeledColor(it, Predictor.FORMULAIC.predict(it)) } 54 | .toList() 55 | .forEach { 56 | inputs += it 57 | } 58 | 59 | Predictor.values().forEach { it.retrainFlag = true } 60 | } 61 | 62 | 63 | enum class Predictor { 64 | 65 | /** 66 | * Uses a simple formula to classify colors as LIGHT or DARK 67 | */ 68 | FORMULAIC { 69 | override fun predict(color: Color) = (0.299 * color.red + 0.587 * color.green + 0.114 * color.blue) 70 | .let { if (it > .5) FontShade.DARK else FontShade.LIGHT } 71 | }, 72 | 73 | LINEAR_REGRESSION_HILL_CLIMBING { 74 | 75 | override fun predict(color: Color): FontShade { 76 | 77 | var redWeightCandidate = 0.0 78 | var greenWeightCandidate = 0.0 79 | var blueWeightCandidate = 0.0 80 | 81 | var currentLoss = Double.MAX_VALUE 82 | 83 | val normalDistribution = NormalDistribution(0.0, 1.0) 84 | 85 | fun predict(color: Color) = 86 | (redWeightCandidate * color.red + greenWeightCandidate * color.green + blueWeightCandidate * color.blue) 87 | 88 | repeat(10000) { 89 | 90 | val selectedColor = (0..2).asSequence().randomFirst() 91 | val adjust = normalDistribution.sample() 92 | 93 | // make random adjustment to two of the colors 94 | when { 95 | selectedColor == 0 -> redWeightCandidate += adjust 96 | selectedColor == 1 -> greenWeightCandidate += adjust 97 | selectedColor == 2 -> blueWeightCandidate += adjust 98 | } 99 | 100 | // Calculate the loss, which is sum of squares 101 | val newLoss = inputs.asSequence() 102 | .map { (color, fontShade) -> 103 | (predict(color) - fontShade.intValue).pow(2) 104 | }.sum() 105 | 106 | // If improvement doesn't happen, undo the move 107 | if (newLoss < 
currentLoss) { 108 | currentLoss = newLoss 109 | } else { 110 | // revert if no improvement happens 111 | when { 112 | selectedColor == 0 -> redWeightCandidate -= adjust 113 | selectedColor == 1 -> greenWeightCandidate -= adjust 114 | selectedColor == 2 -> blueWeightCandidate -= adjust 115 | } 116 | } 117 | } 118 | 119 | println("${redWeightCandidate}R + ${greenWeightCandidate}G + ${blueWeightCandidate}B") 120 | 121 | val formulasLoss = inputs.asSequence() 122 | .map { (color, fontShade) -> 123 | ( (0.299 * color.red + 0.587 * color.green + 0.114 * color.blue) - fontShade.intValue).pow(2) 124 | }.average() 125 | 126 | println("BEST LOSS: $currentLoss, FORMULA'S LOSS: $formulasLoss \r\n") 127 | 128 | return predict(color) 129 | .let { if (it > .5) FontShade.DARK else FontShade.LIGHT } 130 | } 131 | }, 132 | 133 | LOGISTIC_REGRESSION_HILL_CLIMBING { 134 | 135 | 136 | var b0 = .01 // constant 137 | var b1 = .01 // red beta 138 | var b2 = .01 // green beta 139 | var b3 = .01 // blue beta 140 | 141 | 142 | fun predictProbability(color: Color) = 1.0 / (1 + exp(-(b0 + b1 * color.red + b2 * color.green + b3 * color.blue))) 143 | 144 | // Helpful Resources: 145 | // StatsQuest on YouTube: https://www.youtube.com/watch?v=yIYKR4sgzI8&list=PLblh5JKOoLUKxzEP5HA2d-Li7IJkHfXSe 146 | // Brandon Foltz on YouTube: https://www.youtube.com/playlist?list=PLIeGtxpvyG-JmBQ9XoFD4rs-b3hkcX7Uu 147 | override fun predict(color: Color): FontShade { 148 | 149 | 150 | if (retrainFlag) { 151 | var bestLikelihood = -10_000_000.0 152 | 153 | // use hill climbing for optimization 154 | val normalDistribution = NormalDistribution(0.0, 1.0) 155 | 156 | b0 = .01 // constant 157 | b1 = .01 // red beta 158 | b2 = .01 // green beta 159 | b3 = .01 // blue beta 160 | 161 | // 1 = DARK FONT, 0 = LIGHT FONT 162 | 163 | repeat(50000) { 164 | 165 | val selectedBeta = (0..3).asSequence().randomFirst() 166 | val adjust = normalDistribution.sample() 167 | 168 | // make random adjustment to two of the colors 169 | 
when { 170 | selectedBeta == 0 -> b0 += adjust 171 | selectedBeta == 1 -> b1 += adjust 172 | selectedBeta == 2 -> b2 += adjust 173 | selectedBeta == 3 -> b3 += adjust 174 | } 175 | 176 | // calculate maximum likelihood 177 | val darkEstimates = inputs.asSequence() 178 | .filter { it.fontShade == FontShade.DARK } 179 | .map { ln(predictProbability(it.color)) } 180 | .sum() 181 | 182 | val lightEstimates = inputs.asSequence() 183 | .filter { it.fontShade == FontShade.LIGHT } 184 | .map { ln(1 - predictProbability(it.color)) } 185 | .sum() 186 | 187 | val likelihood = darkEstimates + lightEstimates 188 | 189 | if (bestLikelihood < likelihood) { 190 | bestLikelihood = likelihood 191 | } else { 192 | // revert if no improvement happens 193 | when { 194 | selectedBeta == 0 -> b0 -= adjust 195 | selectedBeta == 1 -> b1 -= adjust 196 | selectedBeta == 2 -> b2 -= adjust 197 | selectedBeta == 3 -> b3 -= adjust 198 | } 199 | } 200 | } 201 | 202 | println("1.0 / (1 + exp(-($b0 + $b1*R + $b2*G + $b3*B))") 203 | println("BEST LIKELIHOOD: $bestLikelihood") 204 | retrainFlag = false 205 | } 206 | 207 | return predictProbability(color) 208 | .let { if (it > .5) FontShade.DARK else FontShade.LIGHT } 209 | } 210 | }, 211 | 212 | DECISION_TREE { 213 | 214 | // Helpful Resources: 215 | // StatusQuest on YouTube: https://www.youtube.com/watch?v=7VeUPuFGJHk 216 | 217 | inner class Feature(val name: String, val mapper: (Color) -> Double) { 218 | override fun toString() = name 219 | } 220 | 221 | val features = listOf( 222 | Feature("Red") { it.red * 255.0 }, 223 | Feature("Green") { it.green * 255.0 }, 224 | Feature("Blue") { it.blue * 255.0 } 225 | ) 226 | 227 | fun giniImpurity(samples: List): Double { 228 | 229 | val totalSampleCount = samples.count().toDouble() 230 | 231 | return 1.0 - (samples.count { it.fontShade == FontShade.DARK }.toDouble() / totalSampleCount).pow(2) - 232 | (samples.count { it.fontShade == FontShade.LIGHT }.toDouble() / totalSampleCount).pow(2) 233 | } 234 | 235 
| fun giniImpurityForSplit(feature: Feature, splitValue: Double, samples: List): Double { 236 | val positiveFeatureSamples = samples.filter { feature.mapper(it.color) >= splitValue } 237 | val negativeFeatureSamples = samples.filter { feature.mapper(it.color) < splitValue } 238 | 239 | val positiveImpurity = giniImpurity(positiveFeatureSamples) 240 | val negativeImpurity = giniImpurity(negativeFeatureSamples) 241 | 242 | return (positiveImpurity * (positiveFeatureSamples.count().toDouble() / samples.count().toDouble())) + 243 | (negativeImpurity * (negativeFeatureSamples.count().toDouble() / samples.count().toDouble())) 244 | } 245 | 246 | fun splitContinuousVariable(feature: Feature, samples: List): Double? { 247 | 248 | val featureValues = samples.asSequence().map { feature.mapper(it.color) }.distinct().toList().sorted() 249 | 250 | val bestSplit = featureValues.asSequence().zipWithNext { value1, value2 -> (value1 + value2) / 2.0 } 251 | .minBy { giniImpurityForSplit(feature, it, samples) } 252 | 253 | return bestSplit 254 | } 255 | 256 | 257 | inner class FeatureAndSplit(val feature: Feature, val split: Double) 258 | 259 | fun buildLeaf(samples: List, previousLeaf: TreeLeaf? = null, featureSampleSize: Int? = null ): TreeLeaf? { 260 | 261 | val fs = (if (featureSampleSize == null) features else features.random(featureSampleSize) ) 262 | .asSequence() 263 | .filter { splitContinuousVariable(it, samples) != null } 264 | .map { feature -> 265 | FeatureAndSplit(feature, splitContinuousVariable(feature, samples)!!) 
266 | }.minBy { fs -> 267 | giniImpurityForSplit(fs.feature, fs.split, samples) 268 | } 269 | 270 | return if (previousLeaf == null || 271 | (fs != null && giniImpurityForSplit(fs.feature, fs.split, samples) < previousLeaf.giniImpurity)) 272 | TreeLeaf(fs!!.feature, fs.split, samples) 273 | else 274 | null 275 | } 276 | 277 | 278 | inner class TreeLeaf(val feature: Feature, 279 | val splitValue: Double, 280 | val samples: List) { 281 | 282 | val goodWeatherItems = samples.filter { it.fontShade == FontShade.DARK } 283 | val badWeatherItems = samples.filter { it.fontShade == FontShade.LIGHT } 284 | 285 | val positiveItems = samples.filter { feature.mapper(it.color) >= splitValue } 286 | val negativeItems = samples.filter { feature.mapper(it.color) < splitValue } 287 | 288 | val giniImpurity = giniImpurityForSplit(feature, splitValue, samples) 289 | 290 | val featurePositiveLeaf: TreeLeaf? = buildLeaf(samples.filter { feature.mapper(it.color) >= splitValue }, this) 291 | val featureNegativeLeaf: TreeLeaf? 
= buildLeaf(samples.filter { feature.mapper(it.color) < splitValue }, this) 292 | 293 | 294 | fun predict(color: Color): Double { 295 | 296 | val featureValue = feature.mapper(color) 297 | 298 | 299 | return when { 300 | featureValue >= splitValue -> when { 301 | featurePositiveLeaf == null -> (goodWeatherItems.count { feature.mapper(it.color) >= splitValue }.toDouble() / samples.count { feature.mapper(it.color) >= splitValue }.toDouble()) 302 | else -> featurePositiveLeaf.predict(color) 303 | } 304 | else -> when { 305 | featureNegativeLeaf == null -> (goodWeatherItems.count { feature.mapper(it.color) < splitValue }.toDouble() / samples.count { feature.mapper(it.color) < splitValue }.toDouble()) 306 | else -> featureNegativeLeaf.predict(color) 307 | } 308 | } 309 | } 310 | 311 | override fun toString() = "$feature split on $splitValue, ${negativeItems.count()}|${positiveItems.count()}, Impurity: $giniImpurity" 312 | 313 | } 314 | 315 | 316 | fun recurseAndPrintTree(leaf: TreeLeaf?, depth: Int = 0) { 317 | 318 | if (leaf != null) { 319 | println("\t".repeat(depth) + "($depth): $leaf") 320 | recurseAndPrintTree(leaf.featureNegativeLeaf, depth + 1) 321 | recurseAndPrintTree(leaf.featurePositiveLeaf, depth + 1) 322 | } 323 | } 324 | 325 | 326 | override fun predict(color: Color): FontShade { 327 | 328 | val tree = buildLeaf(inputs) 329 | recurseAndPrintTree(tree) 330 | 331 | return if (tree!!.predict(color) >= .5) FontShade.DARK else FontShade.LIGHT 332 | } 333 | }, 334 | 335 | RANDOM_FOREST { 336 | 337 | // Helpful Resources: 338 | // StatusQuest on YouTube: https://www.youtube.com/watch?v=7VeUPuFGJHk 339 | 340 | 341 | inner class Feature(val name: String, val mapper: (Color) -> Double) { 342 | override fun toString() = name 343 | } 344 | 345 | 346 | val features = listOf( 347 | Feature("Red") { it.red * 255.0 }, 348 | Feature("Green") { it.green * 255.0 }, 349 | Feature("Blue") { it.blue * 255.0 } 350 | ) 351 | 352 | fun giniImpurity(samples: List): Double { 353 | 
354 | val totalSampleCount = samples.count().toDouble() 355 | 356 | return 1.0 - (samples.count { it.fontShade == FontShade.DARK }.toDouble() / totalSampleCount).pow(2) - 357 | (samples.count { it.fontShade == FontShade.LIGHT }.toDouble() / totalSampleCount).pow(2) 358 | } 359 | 360 | fun giniImpurityForSplit(feature: Feature, splitValue: Double, samples: List): Double { 361 | val positiveFeatureSamples = samples.filter { feature.mapper(it.color) >= splitValue } 362 | val negativeFeatureSamples = samples.filter { feature.mapper(it.color) < splitValue } 363 | 364 | val positiveImpurity = giniImpurity(positiveFeatureSamples) 365 | val negativeImpurity = giniImpurity(negativeFeatureSamples) 366 | 367 | return (positiveImpurity * (positiveFeatureSamples.count().toDouble() / samples.count().toDouble())) + 368 | (negativeImpurity * (negativeFeatureSamples.count().toDouble() / samples.count().toDouble())) 369 | } 370 | 371 | fun splitContinuousVariable(feature: Feature, samples: List): Double? { 372 | 373 | val featureValues = samples.asSequence().map { feature.mapper(it.color) }.distinct().toList().sorted() 374 | 375 | val bestSplit = featureValues.asSequence().zipWithNext { value1, value2 -> (value1 + value2) / 2.0 } 376 | .minBy { giniImpurityForSplit(feature, it, samples) } 377 | 378 | return bestSplit 379 | } 380 | 381 | 382 | inner class FeatureAndSplit(val feature: Feature, val split: Double) 383 | 384 | fun buildLeaf(samples: List, previousLeaf: TreeLeaf? = null, featureSampleSize: Int? = null ): TreeLeaf? { 385 | 386 | val fs = (if (featureSampleSize == null) features else features.random(featureSampleSize) ) 387 | .asSequence() 388 | .filter { splitContinuousVariable(it, samples) != null } 389 | .map { feature -> 390 | FeatureAndSplit(feature, splitContinuousVariable(feature, samples)!!) 
391 | }.minBy { fs -> 392 | giniImpurityForSplit(fs.feature, fs.split, samples) 393 | } 394 | 395 | return if (previousLeaf == null || 396 | (fs != null && giniImpurityForSplit(fs.feature, fs.split, samples) < previousLeaf.giniImpurity)) 397 | TreeLeaf(fs!!.feature, fs.split, samples) 398 | else 399 | null 400 | } 401 | 402 | inner class TreeLeaf(val feature: Feature, 403 | val splitValue: Double, 404 | val samples: List) { 405 | 406 | val darkItems = samples.filter { it.fontShade == FontShade.DARK } 407 | val lightItems = samples.filter { it.fontShade == FontShade.LIGHT } 408 | 409 | val positiveItems = samples.filter { feature.mapper(it.color) >= splitValue } 410 | val negativeItems = samples.filter { feature.mapper(it.color) < splitValue } 411 | 412 | val giniImpurity = giniImpurityForSplit(feature, splitValue, samples) 413 | 414 | val featurePositiveLeaf: TreeLeaf? = buildLeaf(samples.filter { feature.mapper(it.color) >= splitValue }, this) 415 | val featureNegativeLeaf: TreeLeaf? = buildLeaf(samples.filter { feature.mapper(it.color) < splitValue }, this) 416 | 417 | fun predict(color: Color): Double { 418 | 419 | val featureValue = feature.mapper(color) 420 | 421 | return when { 422 | featureValue >= splitValue -> when { 423 | featurePositiveLeaf == null -> (darkItems.count { feature.mapper(it.color) >= splitValue }.toDouble() / samples.count { feature.mapper(it.color) >= splitValue }.toDouble()) 424 | else -> featurePositiveLeaf.predict(color) 425 | } 426 | else -> when { 427 | featureNegativeLeaf == null -> (darkItems.count { feature.mapper(it.color) < splitValue }.toDouble() / samples.count { feature.mapper(it.color) < splitValue }.toDouble()) 428 | else -> featureNegativeLeaf.predict(color) 429 | } 430 | } 431 | } 432 | 433 | override fun toString() = "$feature split on $splitValue, ${negativeItems.count()}|${positiveItems.count()}, Impurity: $giniImpurity" 434 | 435 | } 436 | 437 | fun recurseAndPrintTree(leaf: TreeLeaf?, depth: Int = 0) { 438 | 439 | if 
(leaf != null) {
            println("\t".repeat(depth) + "($leaf)")
            recurseAndPrintTree(leaf.featureNegativeLeaf, depth + 1)
            recurseAndPrintTree(leaf.featurePositiveLeaf, depth + 1)
        }
    }

    // The trained forest: 300 trees, each built from a bootstrap sample.
    // FIX(review): the element type was dropped during extraction; a bare `List`
    // does not compile in Kotlin, so the generic argument is restored here.
    lateinit var randomForest: List<TreeLeaf>

    /**
     * Trains (once, until retrainFlag is raised again) a random forest of 300
     * trees, each fit on a 2/3 bootstrap sample with 2 candidate features per
     * split, then classifies [color] by majority vote of the trees.
     */
    override fun predict(color: Color): FontShade {

        val bootStrapSampleCount = (inputs.count() * (2.0 / 3.0)).toInt()

        if (retrainFlag) {
            randomForest = (1..300).asSequence()
                .map {
                    buildLeaf(samples = inputs.random(bootStrapSampleCount), featureSampleSize = 2)!!
                }.toList()

            retrainFlag = false
        }

        // Each tree votes DARK when its estimated P(DARK) is at least .5.
        val votes = randomForest.asSequence().countBy {
            if (it.predict(color) >= .5) FontShade.DARK else FontShade.LIGHT
        }
        println(votes)
        return votes.maxBy { it.value }!!.key
    }
},

NEURAL_NETWORK_HILL_CLIMBING {

    lateinit var artificialNeuralNetwork: NeuralNetwork

    /**
     * Trains a small 3-3-2 network (tanh hidden layer, softmax output) by
     * random hill climbing, then predicts LIGHT/DARK from the softmax pair.
     */
    override fun predict(color: Color): FontShade {

        if (retrainFlag) {
            artificialNeuralNetwork = neuralnetwork {
                inputlayer(3)
                hiddenlayer(3, ActivationFunction.TANH)
                outputlayer(2, ActivationFunction.SOFTMAX)
            }

            // Inputs are RGB triples; targets are the one-hot [LIGHT, DARK] arrays.
            val trainingData = inputs.map { colorAttributes(it.color) to it.fontShade.outputArray }

            artificialNeuralNetwork.trainEntriesHillClimbing(trainingData)
            retrainFlag = false
        }
        // Output slot 0 is LIGHT, slot 1 is DARK (see FontShade.outputArray).
        return artificialNeuralNetwork.predictEntry(colorAttributes(color)).let {
            println("${it[0]} ${it[1]}")
            if (it[0] > it[1]) FontShade.LIGHT else FontShade.DARK
        }
    }
},

NEURAL_NETWORK_SIMULATED_ANNEALING {

    lateinit var artificialNeuralNetwork: NeuralNetwork

    /**
     * Same topology as the hill-climbing entry, but trained with simulated
     * annealing so the search can escape local minima.
     */
    override fun predict(color: Color): FontShade {

        if (retrainFlag) {
            artificialNeuralNetwork = neuralnetwork {
                inputlayer(3)
                hiddenlayer(3, ActivationFunction.TANH)
                outputlayer(2, ActivationFunction.SOFTMAX)
            }
507 |
// Inputs are RGB triples; targets are the one-hot [LIGHT, DARK] arrays.
508 | val trainingData = inputs.map { colorAttributes(it.color) to it.fontShade.outputArray }
509 |
510 | artificialNeuralNetwork.trainEntriesSimulatedAnnealing(trainingData)
511 | retrainFlag = false
512 | }
// Output slot 0 is LIGHT, slot 1 is DARK (see FontShade.outputArray).
513 | return artificialNeuralNetwork.predictEntry(colorAttributes(color)).let {
514 | println("${it[0]} ${it[1]}")
515 | if (it[0] > it[1]) FontShade.LIGHT else FontShade.DARK
516 | }
517 | }
518 | },
519 |
// ojAlgo's built-in ANN: 3-3-2 with ReLU hidden layer and softmax output,
// trained by the library itself (rate .05, cross-entropy loss).
520 | OJALGO_NEURAL_NETWORK {
521 |
522 | lateinit var artificialNeuralNetwork: ArtificialNeuralNetwork
523 |
524 | override fun predict(color: Color): FontShade {
525 |
526 | if (retrainFlag) {
527 | artificialNeuralNetwork = ArtificialNeuralNetwork.builder(3, 3, 2).apply {
528 |
// Activator index is the calculated-layer index: 0 = hidden, 1 = output.
529 | activator(0, ArtificialNeuralNetwork.Activator.RECTIFIER)
530 | activator(1, ArtificialNeuralNetwork.Activator.SOFTMAX)
531 |
532 | rate(.05)
533 | error(ArtificialNeuralNetwork.Error.CROSS_ENTROPY)
534 |
535 | val inputValues = inputs.asSequence().map { Primitive64Array.FACTORY.copy(* colorAttributes(it.color)) }
536 | .toList()
537 |
538 | val outputValues = inputs.asSequence().map { Primitive64Array.FACTORY.copy(*it.fontShade.outputArray) }
539 | .toList()
540 |
541 | train(inputValues, outputValues)
542 | }.get()
543 |
544 | retrainFlag = false
545 | }
546 |
547 | return artificialNeuralNetwork.invoke(Primitive64Array.FACTORY.copy(*colorAttributes(color))).let {
548 | println("${it[0]} ${it[1]}")
549 | if (it[0] > it[1]) FontShade.LIGHT else FontShade.DARK
550 | }
551 | }
552 | },
553 |
554 | /**
555 | * Uses DeepLearning4J, a heavyweight neural network library that is probably overkill for this toy problem.
556 | * However, DL4J is a good library to use for large real-world projects.
557 | */ 558 | DL4J_NEURAL_NETWORK { 559 | override fun predict(color: Color): FontShade { 560 | 561 | val dl4jNN = NeuralNetConfiguration.Builder() 562 | .weightInit(WeightInit.UNIFORM) 563 | .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 564 | .updater(Nesterovs(.006, .9)) 565 | .l2(1e-4) 566 | .list( 567 | DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.RELU).build(), 568 | OutputLayer.Builder().nIn(3).nOut(2).activation(Activation.SOFTMAX).build() 569 | ).pretrain(false) 570 | .backprop(true) 571 | .build() 572 | .let(::MultiLayerNetwork).apply { init() } 573 | 574 | val examples = inputs.asSequence() 575 | .map { colorAttributes(it.color) } 576 | .toList().toTypedArray() 577 | .let { Nd4j.create(it) } 578 | 579 | val outcomes = inputs.asSequence() 580 | .map { it.fontShade.outputArray } 581 | .toList().toTypedArray() 582 | .let { Nd4j.create(it) } 583 | 584 | 585 | // train for 1000 iterations (epochs) 586 | repeat(1000) { 587 | dl4jNN.fit(examples, outcomes) 588 | } 589 | 590 | // Test the input color and predict it as LIGHT or DARK 591 | val result = dl4jNN.output(Nd4j.create(colorAttributes(color))).toDoubleVector() 592 | 593 | println(result.joinToString(", ")) 594 | 595 | return if (result[0] > result[1]) FontShade.LIGHT else FontShade.DARK 596 | 597 | } 598 | }; 599 | 600 | var retrainFlag = true 601 | 602 | abstract fun predict(color: Color): FontShade 603 | override fun toString() = name.replace("_", " ") 604 | } 605 | 606 | } 607 | 608 | data class LabeledColor( 609 | val color: Color, 610 | val fontShade: FontShade 611 | ) 612 | 613 | enum class FontShade(val color: Color, val intValue: Double, val outputArray: DoubleArray){ 614 | DARK(Color.BLACK, 1.0, doubleArrayOf(0.0, 1.0)), 615 | LIGHT(Color.WHITE, 0.0, doubleArrayOf(1.0,0.0)) 616 | } 617 | 618 | // UTILITIES 619 | 620 | fun randomInt(lower: Int, upper: Int) = ThreadLocalRandom.current().nextInt(lower, upper + 1) 621 | 622 | 623 | fun randomColor() = 
(1..3).asSequence()
    .map { randomInt(0, 255) }
    .toList()
    .let { Color.rgb(it[0], it[1], it[2]) }

/** The three RGB components of [c] (each 0.0..1.0) as the model input vector. */
fun colorAttributes(c: Color) = doubleArrayOf(
    c.red,
    c.green,
    c.blue
)
--------------------------------------------------------------------------------
/src/main/kotlin/TomNeuralNetwork.kt:
--------------------------------------------------------------------------------
import org.apache.commons.math3.distribution.TDistribution
import org.nield.kotlinstatistics.randomFirst
import org.nield.kotlinstatistics.weightedCoinFlip
import tornadofx.singleAssign
import java.util.concurrent.ThreadLocalRandom
import kotlin.math.exp
import kotlin.math.pow

/** DSL entry point: configures a [NeuralNetworkBuilder] and builds the network. */
fun neuralnetwork(op: NeuralNetworkBuilder.() -> Unit): NeuralNetwork {
    val nn = NeuralNetworkBuilder()
    nn.op()
    return nn.build()
}

/**
 * A simple feed-forward network trained by metaheuristics (hill climbing or
 * simulated annealing) rather than backpropagation.
 *
 * FIX(review): the generic arguments below were stripped during extraction
 * (a bare `List`/`Iterable` does not compile); restored from builder usage.
 */
class NeuralNetwork(
    inputNodeCount: Int,
    hiddenLayers: List<NeuralNetworkBuilder.HiddenLayerBuilder>,
    outputLayer: NeuralNetworkBuilder.HiddenLayerBuilder
) {

    val inputLayer = InputLayer(inputNodeCount)

    // Each hidden layer feeds from the previous one (the input layer for the first).
    val hiddenLayers = hiddenLayers.asSequence()
        .mapIndexed { index, hiddenLayer ->
            CalculatedLayer(index, hiddenLayer.nodeCount, hiddenLayer.activationFunction)
        }.toList().also { layers ->
            layers.withIndex().forEach { (i, layer) ->
                layer.feedingLayer = (if (i == 0) inputLayer else layers[i - 1])
            }
        }

    val outputLayer = CalculatedLayer(hiddenLayers.count(), outputLayer.nodeCount, outputLayer.activationFunction).also {
        it.feedingLayer = (if (this.hiddenLayers.isNotEmpty()) this.hiddenLayers.last() else inputLayer)
    }

    // Every layer whose values are derived (hidden layers + output layer).
    val calculatedLayers = this.hiddenLayers.plusElement(this.outputLayer)

    /**
     * Trains by simple hill climbing: randomly nudge one weight or bias, keep
     * the change when total squared loss improves, revert otherwise.
     */
    fun trainEntriesHillClimbing(inputsAndTargets: Iterable<Pair<DoubleArray, DoubleArray>>) {

        val entries = inputsAndTargets.toList()

        // use simple hill climbing
| var bestLoss = Double.MAX_VALUE
50 |
// Heavy-tailed t-distribution (df=3) allows occasional large mutation steps.
51 | val tDistribution = TDistribution(3.0)
52 |
53 | val allCalculatedNodes = calculatedLayers.asSequence().flatMap {
54 | it.nodes.asSequence()
55 | }.toList()
56 |
57 | println("Training with ${entries.count()}")
58 |
59 | val learningRate = .1
60 |
// One flat index space over all weights and biases; an index above
// weightCutOff means "mutate a bias" this iteration, otherwise "mutate a weight".
61 | val weightsPlusBiasesIndices = calculatedLayers.asSequence()
62 | .map { it.weights.count() + it.biases.count() }
63 | .sum()
64 | .let { 0 until it }
65 | .toList().toIntArray()
66 |
67 |
68 | val weightCutOff = calculatedLayers.asSequence()
69 | .map { it.weights.count() }
70 | .sum() - 1
71 |
72 | repeat(100_000) { epoch ->
73 |
// NOTE(review): the index only decides weight-vs-bias; the actual node and
// feeding node mutated are drawn independently below — confirm this two-stage
// sampling is intended (it biases selection toward smaller layers).
74 | val randomVariableIndex = weightsPlusBiasesIndices.random()
75 |
76 | val randomlySelectedNode = allCalculatedNodes.randomFirst()
77 | val randomlySelectedFeedingNode = randomlySelectedNode.layer.feedingLayer.nodes.randomFirst()
78 | val selectedWeightKey = WeightKey(randomlySelectedNode.layer.index, randomlySelectedFeedingNode.index, randomlySelectedNode.index)
79 |
80 | val randomAdjust = if (randomVariableIndex <= weightCutOff) {
81 |
82 | val currentWeightValue = randomlySelectedNode.layer.weights[selectedWeightKey]!!
83 |
// Sampled step scaled by learningRate, clamped so the weight stays in [-1, 1].
84 | val randomAdjust = tDistribution.sample().let { it * learningRate }.let {
85 | when {
86 | currentWeightValue + it < -1.0 -> -1.0 - currentWeightValue
87 | currentWeightValue + it > 1.0 -> 1.0 - currentWeightValue
88 | else -> it
89 | }
90 | }
91 |
92 | randomlySelectedNode.layer.modifyWeight(selectedWeightKey, randomAdjust)
93 | randomAdjust
94 | } else {
95 |
96 | val currentBiasValue = randomlySelectedNode.layer.biases[randomlySelectedNode.index]!!
97 | 98 | val randomAdjust = tDistribution.sample().let { it * learningRate }.let { 99 | when { 100 | currentBiasValue + it < 0.0 -> 0.0 - currentBiasValue 101 | currentBiasValue + it > 1.0 -> 1.0 - currentBiasValue 102 | else -> it 103 | } 104 | } 105 | 106 | randomlySelectedNode.layer.modifyBias(randomlySelectedNode.index, randomAdjust) 107 | randomAdjust 108 | } 109 | 110 | 111 | val totalLoss = entries 112 | .asSequence() 113 | .flatMap { (input,label) -> 114 | label.asSequence() 115 | .zip(predictEntry(input).asSequence()) { actual, predicted -> (actual-predicted).pow(2) } 116 | }.sum() 117 | 118 | if (totalLoss < bestLoss) { 119 | println("epoch $epoch: $bestLoss -> $totalLoss") 120 | bestLoss = totalLoss 121 | } else { 122 | if (randomVariableIndex <= weightCutOff) { 123 | randomlySelectedNode.layer.modifyWeight(selectedWeightKey, -randomAdjust) 124 | } else { 125 | randomlySelectedNode.layer.modifyBias(randomlySelectedNode.index, -randomAdjust) 126 | } 127 | } 128 | } 129 | 130 | calculatedLayers.forEach { println(it.weights) } 131 | } 132 | 133 | fun trainEntriesSimulatedAnnealing(inputsAndTargets: Iterable>) { 134 | 135 | val entries = inputsAndTargets.toList() 136 | 137 | // use simulated annealing 138 | var bestLoss = Double.MAX_VALUE 139 | var currentLoss = bestLoss 140 | var bestConfig = calculatedLayers.map { it.index to it.weights.toMap() }.toMap() 141 | 142 | val tDistribution = TDistribution(3.0) 143 | 144 | val allCalculatedNodes = calculatedLayers.asSequence().flatMap { 145 | it.nodes.asSequence() 146 | }.toList() 147 | 148 | println("Training with ${entries.count()}") 149 | 150 | val learningRate = .1 151 | 152 | val weightsPlusBiasesIndices = calculatedLayers.asSequence() 153 | .map { it.weights.count() + it.biases.count() } 154 | .sum() 155 | .let { 0 until it } 156 | .toList().toIntArray() 157 | 158 | val weightCutOff = calculatedLayers.asSequence() 159 | .map { it.weights.count() } 160 | .sum() - 1 161 | 162 | sequenceOf( 163 | 
generateSequence(80.0) { t -> t - .005 }.takeWhile { it >= 0 } 164 | ).flatMap { it }.forEach { temp -> 165 | 166 | val randomVariableIndex = weightsPlusBiasesIndices.random() 167 | 168 | val randomlySelectedNode = allCalculatedNodes.randomFirst() 169 | val randomlySelectedFeedingNode = randomlySelectedNode.layer.feedingLayer.nodes.randomFirst() 170 | val selectedWeightKey = WeightKey(randomlySelectedNode.layer.index, randomlySelectedFeedingNode.index, randomlySelectedNode.index) 171 | 172 | val randomAdjust = if (randomVariableIndex <= weightCutOff) { 173 | 174 | val currentWeightValue = randomlySelectedNode.layer.weights[selectedWeightKey]!! 175 | 176 | val randomAdjust = tDistribution.sample().let { it * learningRate }.let { 177 | when { 178 | currentWeightValue + it < -1.0 -> -1.0 - currentWeightValue 179 | currentWeightValue + it > 1.0 -> 1.0 - currentWeightValue 180 | else -> it 181 | } 182 | } 183 | 184 | randomlySelectedNode.layer.modifyWeight(selectedWeightKey, randomAdjust) 185 | randomAdjust 186 | } else { 187 | 188 | val currentBiasValue = randomlySelectedNode.layer.biases[randomlySelectedNode.index]!! 
189 | 190 | val randomAdjust = tDistribution.sample().let { it * learningRate }.let { 191 | when { 192 | currentBiasValue + it < 0.0 -> 0.0 - currentBiasValue 193 | currentBiasValue + it > 1.0 -> 1.0 - currentBiasValue 194 | else -> it 195 | } 196 | } 197 | 198 | randomlySelectedNode.layer.modifyBias(randomlySelectedNode.index, randomAdjust) 199 | randomAdjust 200 | } 201 | 202 | val newLoss = entries 203 | .asSequence() 204 | .flatMap { (input,label) -> 205 | label.asSequence() 206 | .zip(predictEntry(input).asSequence()) { actual, predicted -> (actual-predicted).pow(2) } 207 | }.sum() 208 | 209 | if (newLoss < currentLoss) { 210 | 211 | currentLoss = newLoss 212 | 213 | if (newLoss < bestLoss) { 214 | println("temp $temp: $bestLoss -> $newLoss") 215 | bestLoss = newLoss 216 | bestConfig = calculatedLayers.asSequence().map { it.index to it.weights.toMap() }.toMap() 217 | } 218 | } else if (weightedCoinFlip(exp((-(newLoss - currentLoss) ) / temp))) { 219 | //println("temp $temp: $newLoss <- $bestLoss") 220 | currentLoss = newLoss 221 | } else { 222 | if (randomVariableIndex <= weightCutOff) { 223 | randomlySelectedNode.layer.modifyWeight(selectedWeightKey, -randomAdjust) 224 | } else { 225 | randomlySelectedNode.layer.modifyBias(randomlySelectedNode.index, -randomAdjust) 226 | } 227 | } 228 | } 229 | 230 | calculatedLayers.forEach { cl -> bestConfig[cl.index]!!.forEach { w -> cl.weights.set(w.key, w.value) }} 231 | calculatedLayers.forEach { println(it.weights) } 232 | } 233 | fun predictEntry(inputValues: DoubleArray): DoubleArray { 234 | 235 | 236 | // assign input values to input nodes 237 | inputValues.withIndex().forEach { (i,v) -> inputLayer.nodes[i].value = v } 238 | 239 | // calculate new hidden and output node values 240 | return outputLayer.map { it.value }.toDoubleArray() 241 | } 242 | } 243 | 244 | 245 | data class WeightKey(val calculatedLayerIndex: Int, val feedingNodeIndex: Int, val nodeIndex: Int) 246 | 247 | 248 | 249 | // LAYERS 250 | sealed class 
Layer: Iterable<Node> {
    // FIX(review): the generic arguments on Iterable/List were stripped during
    // extraction (restored to Node from usage); behavior is otherwise unchanged.
    abstract val nodes: List<Node>
    override fun iterator() = nodes.iterator()
}

/**
 * An `InputLayer` belongs to the first layer and accepts the input values for each `InputNode`
 */
class InputLayer(nodeCount: Int) : Layer() {

    override val nodes = (0 until nodeCount).asSequence()
        .map { InputNode(it) }
        .toList()
}

/**
 * A `CalculatedLayer` is used for the hidden and output layers, and is derived off weights and values off each previous layer
 */
class CalculatedLayer(val index: Int, nodeCount: Int, val activationFunction: ActivationFunction) : Layer() {

    // Assigned exactly once, after construction, by the NeuralNetwork wiring.
    var feedingLayer: Layer by singleAssign()

    override val nodes by lazy {
        (0 until nodeCount).asSequence()
            .map { CalculatedNode(it, this) }
            .toList()
    }

    // weights are paired for feeding layer and this layer; initialized uniformly in (-1, 1).
    // Lazy because feedingLayer is only wired after construction.
    val weights by lazy {
        (0 until feedingLayer.nodes.count())
            .asSequence()
            .flatMap { feedingNodeIndex ->
                (0 until nodeCount).asSequence()
                    .map { nodeIndex ->
                        WeightKey(index, feedingNodeIndex, nodeIndex) to randomWeightValue()
                    }
            }.toMap().toMutableMap()
    }

    // One bias per node, starting at 0.0.
    val biases by lazy {
        (0 until nodeCount).asSequence()
            .map { it to 0.0 }
            .toMap().toMutableMap()
    }

    /** Adds [adjustment] to the weight at [key] (pass a negative value to undo). */
    fun modifyWeight(key: WeightKey, adjustment: Double) =
        weights.compute(key) { _, v -> v!! + adjustment }

    /** Adds [adjustment] to the bias of node [nodeId]. */
    fun modifyBias(nodeId: Int, adjustment: Double) =
        biases.compute(nodeId) { _, v -> v!!
+ adjustment }
301 | }
302 |
303 |
304 | // NODES
// Base node type: an index within its layer and a current value.
305 | sealed class Node(val index: Int) {
306 | abstract val value: Double
307 | }
308 |
309 | class InputNode(index: Int): Node(index) {
310 | override var value = 0.0
311 | }
312 |
313 |
// Weighted sum of the feeding layer plus bias, passed through the activation
// function. NOTE(review): `value` is recomputed on every read (no caching), so
// each prediction re-walks the whole upstream graph — deliberate here because
// the trainers mutate weights between reads, but costly for deeper networks.
314 | class CalculatedNode(index: Int, val layer: CalculatedLayer): Node(index) {
315 |
316 | override val value: Double get() = layer.feedingLayer.asSequence()
317 | .map { feedingNode ->
318 | val weightKey = WeightKey(layer.index, feedingNode.index, index)
319 | layer.weights[weightKey]!! * feedingNode.value
320 | }.plus(layer.biases[index]!!).sum()
321 | .let { v ->
322 |
// This lambda supplies the sibling nodes' raw pre-activation sums for the
// layer-wide activations (SOFTMAX, MAX); it is only evaluated if they need it.
323 | layer.activationFunction.invoke(v) {
324 | layer.asSequence().map { node ->
325 | node.layer.feedingLayer.asSequence()
326 | .map { feedingNode ->
327 | val weightKey = WeightKey(layer.index, feedingNode.index, node.index)
328 | layer.weights[weightKey]!! * feedingNode.value
329 | }.plus(layer.biases[node.index]!!).sum()
330 | }.toList().toDoubleArray()
331 | }
332 | }
333 | }
334 |
335 | fun randomWeightValue() = ThreadLocalRandom.current().nextDouble(-1.0,1.0)
336 |
// Activation functions; `otherValues` is a lazy supplier of the whole layer's
// raw sums, used only by the layer-wide functions (MAX, SOFTMAX).
337 | enum class ActivationFunction {
338 |
339 | IDENTITY {
340 | override fun invoke(x: Double, otherValues: () -> DoubleArray) = x
341 | },
342 | SIGMOID {
343 | override fun invoke(x: Double, otherValues: () -> DoubleArray) = 1.0 / (1.0 + exp(-x))
344 | },
345 | TANH {
346 | override fun invoke(x: Double, otherValues: () -> DoubleArray) = kotlin.math.tanh(x)
347 | },
348 | RELU {
349 | override fun invoke(x: Double, otherValues: () -> DoubleArray) = if (x < 0.0) 0.0 else x
350 | },
// NOTE(review): exact Double equality against max() — fragile if two nodes tie
// or values differ only by rounding; confirm winner-take-all is intended.
351 | MAX {
352 | override fun invoke (x: Double, otherValues: () -> DoubleArray) = if (x == otherValues().max()) x else 0.0
353 | },
// NOTE(review): unstabilized softmax (no max-subtraction); exp() can overflow
// for large inputs, though inputs here are small in practice.
354 | SOFTMAX {
355 | override fun invoke(x: Double, otherValues: () -> DoubleArray) =
356 | (exp(x) / otherValues().asSequence().map { exp(it) }.sum())
357 | };
358 |
359 | abstract fun invoke(x: Double, otherValues: () -> DoubleArray): Double
360 | }
361 |
362 | // BUILDERS
/**
 * Collects layer specs from the `neuralnetwork { }` DSL and builds the network.
 *
 * FIX(review): `mutableListOf()` lost its type argument during extraction
 * (a bare untyped call does not compile); restored to HiddenLayerBuilder
 * from how `hidden` is populated and consumed.
 */
class NeuralNetworkBuilder {

    var input = 0
    var hidden = mutableListOf<HiddenLayerBuilder>()
    // Default is a 0-node RELU output; outputlayer(...) is expected to be called.
    var output: HiddenLayerBuilder = HiddenLayerBuilder(0, ActivationFunction.RELU)

    /** Spec for one calculated (hidden or output) layer. */
    class HiddenLayerBuilder(val nodeCount: Int, val activationFunction: ActivationFunction)

    fun inputlayer(nodeCount: Int) {
        input = nodeCount
    }

    fun hiddenlayer(nodeCount: Int, activationFunction: ActivationFunction) {
        hidden.add(HiddenLayerBuilder(nodeCount, activationFunction))
    }

    fun outputlayer(nodeCount: Int, activationFunction: ActivationFunction) {
        output = HiddenLayerBuilder(nodeCount, activationFunction)
    }

    fun build() = NeuralNetwork(input, hidden, output)
}
--------------------------------------------------------------------------------
/src/main/kotlin/Util.kt:
--------------------------------------------------------------------------------
/*
NOTE(review): this entire file is commented out — dead ojAlgo matrix helpers
kept for reference only. Consider deleting or moving to a branch. The generic
parameters on these extensions were also stripped during extraction.
import org.ojalgo.algebra.Operation
import org.ojalgo.algebra.ScalarOperation
import org.ojalgo.matrix.BasicMatrix
import org.ojalgo.matrix.ComplexMatrix
import org.ojalgo.matrix.PrimitiveMatrix
import org.ojalgo.matrix.RationalMatrix
import org.ojalgo.scalar.ComplexNumber
import org.ojalgo.scalar.RationalNumber
import java.math.BigDecimal

fun Sequence.toPrimitiveMatrix(vararg selectors: (T) -> N): PrimitiveMatrix {
    val items = toList()

    return primitivematrix(items.count(), selectors.count()) {
        populate { row, col ->
            selectors[col.toInt()](items[row.toInt()])
        }
    }
}
fun Iterable.toPrimitiveMatrix(vararg selectors: (T) -> N): PrimitiveMatrix {
    val items = toList()

    return primitivematrix(items.count(), selectors.count()) {
        populate { row, col ->
            selectors[col.toInt()](items[row.toInt()])
        }
    }
}


fun Sequence.toComplexMatrix(vararg selectors: (T) -> N): ComplexMatrix {
    val items = toList()
NOTE(review): continuation of the fully commented-out Util.kt — dead code
retained verbatim below; no live definitions in this span.
35 | return complexmatrix(items.count(), selectors.count()) { 36 | populate { row, col -> 37 | selectors[col.toInt()](items[row.toInt()]) 38 | } 39 | } 40 | } 41 | fun Iterable.toComplexMatrix(vararg selectors: (T) -> N): ComplexMatrix { 42 | val items = toList() 43 | 44 | return complexmatrix(items.count(), selectors.count()) { 45 | populate { row, col -> 46 | selectors[col.toInt()](items[row.toInt()]) 47 | } 48 | } 49 | } 50 | 51 | 52 | fun Sequence.toRationalMatrix(vararg selectors: (T) -> N): RationalMatrix { 53 | val items = toList() 54 | 55 | return rationalmatrix(items.count(), selectors.count()) { 56 | populate { row, col -> 57 | selectors[col.toInt()](items[row.toInt()]) 58 | } 59 | } 60 | } 61 | fun Iterable.toRationalMatrix(vararg selectors: (T) -> N): RationalMatrix { 62 | val items = toList() 63 | 64 | return rationalmatrix(items.count(), selectors.count()) { 65 | populate { row, col -> 66 | selectors[col.toInt()](items[row.toInt()]) 67 | } 68 | } 69 | } 70 | 71 | 72 | fun vectorOf(vararg values: Int) = primitivematrix(values.count(), 1) { 73 | populate { row, col -> values[row.toInt()] } 74 | } 75 | 76 | fun vectorOf(vararg values: Double) = primitivematrix(values.count(), 1) { 77 | populate { row, col -> values[row.toInt()] } 78 | } 79 | 80 | fun vectorOf(vararg values: Long) = primitivematrix(values.count(), 1) { 81 | populate { row, col -> values[row.toInt()] } 82 | } 83 | 84 | fun vectorOf(vararg values: BigDecimal) = rationalmatrix(values.count(), 1) { 85 | populate { row, col -> values[row.toInt()] } 86 | } 87 | 88 | 89 | fun primitivematrix(rows: Int, cols: Int, op: (BasicMatrix.PhysicalBuilder.() -> Unit)? = null) = 90 | PrimitiveMatrix.FACTORY.getBuilder(rows,cols).also { 91 | if (op != null) op(it) 92 | }.build() 93 | 94 | 95 | fun complexmatrix(rows: Int, cols: Int, op: (BasicMatrix.PhysicalBuilder.() -> Unit)?
= null) = 96 | ComplexMatrix.FACTORY.getBuilder(rows,cols).also { 97 | if (op != null) op(it) 98 | }.build() 99 | 100 | fun rationalmatrix(rows: Int, cols: Int, op: (BasicMatrix.PhysicalBuilder.() -> Unit)? = null) = 101 | RationalMatrix.FACTORY.getBuilder(rows,cols).also { 102 | if (op != null) op(it) 103 | }.build() 104 | 105 | fun BasicMatrix.PhysicalBuilder.populate(op: (Long,Long) -> Number) = 106 | loopAll { row, col -> set(row, col, op(row,col)) } 107 | 108 | fun BasicMatrix.PhysicalBuilder.setAll(vararg values: Number) { 109 | 110 | var index = 0 111 | 112 | for (r in 0..(countRows()-1)) { 113 | for (c in 0..(countColumns()-1)) { 114 | set(r,c,values[index++]) 115 | } 116 | } 117 | } 118 | 119 | 120 | fun BasicMatrix.scalarApply(op: (Number) -> Number) = primitivematrix(countRows().toInt(), countColumns().toInt()) { 121 | populate { row, col -> op(this@scalarApply[row, col]) } 122 | build() 123 | } 124 | 125 | 126 | operator fun Operation.Addition.plus(t: T) = add(t) 127 | operator fun Operation.Division.div(t: T) = divide(t) 128 | operator fun Operation.Multiplication.times(t: T) = multiply(t) 129 | operator fun Operation.Subtraction.minus(t: T) = subtract(t) 130 | 131 | operator fun ScalarOperation.Addition.plus(number: N) = add(number) 132 | operator fun ScalarOperation.Division.div(number: N) = divide(number) 133 | operator fun ScalarOperation.Multiplication.times(number: N) = multiply(number) 134 | operator fun ScalarOperation.Subtraction.minus(number: N) = subtract(number) 135 | 136 | 137 | 138 | */ 139 | --------------------------------------------------------------------------------