├── .gitattributes ├── .github └── workflows │ └── gradle.yml ├── .gitignore ├── LICENSE.md ├── README.md ├── build.gradle ├── gradle └── wrapper │ ├── gradle-wrapper.jar │ └── gradle-wrapper.properties ├── gradlew ├── gradlew.bat ├── settings.gradle └── src └── main └── java ├── bitmanipulation ├── BasicOperators.java └── problems │ ├── missingnumber │ └── MissingNumber.java │ ├── poweroftwo │ └── PowerOfTwo.java │ └── singlenumber │ └── SingleNumber.java ├── datastructures ├── minpriorityqueue │ ├── fibonacciheap │ │ └── FibonacciHeap.java │ └── minheap │ │ └── MinHeap.java ├── trees │ ├── binarysearchtrees │ │ ├── BinaryTree.java │ │ ├── Color.java │ │ ├── Node.java │ │ ├── TimeComparisonTest.java │ │ ├── TreeVisualizer.java │ │ ├── Utility.java │ │ ├── avltree │ │ │ └── AVLTree.java │ │ ├── binarysearchtree │ │ │ └── BinarySearchTree.java │ │ └── redblacktree │ │ │ └── RedBlackTree.java │ └── trie │ │ └── Trie.java └── unionfind │ └── UnionFind.java ├── dynmanicprogramming └── problems │ ├── houserobber │ └── HouseRobber.java │ └── uniquepaths │ └── UniquePaths.java ├── graphtheory ├── shortestpathalgorithms │ └── singlesource │ │ └── dijkstrasshortestpath │ │ └── DijkstrasShortestPath.java └── traversals │ ├── breadthfirstsearch │ └── BreadthFirstSearch.java │ └── depthfirstsearch │ └── DepthFirstSearch.java ├── math ├── matrices │ └── Matrix.java └── vectors │ └── Vector3.java ├── neuralnetworks ├── EfficientNeuralNetwork.java ├── MNISTTrainer.java ├── MultilayeredNeuralNetwork.java ├── OneHiddenLayerNeuralNetwork.java ├── Perceptron.java ├── Tester.java ├── XORVisualizer.java └── mnistdata │ ├── MnistEntry.java │ ├── MnistLoader.java │ ├── MnistWeights.txt │ ├── test_images.gz │ ├── test_labels.gz │ ├── training_images.gz │ └── training_labels.gz └── sorting ├── bogosort └── BogoSort.java ├── bubblesort └── BubbleSort.java ├── countingsort └── CountingSort.java ├── heapsort └── HeapSort.java ├── insertionsort └── InsertionSort.java ├── mergesort └── MergeSort.java ├── quicksort └── QuickSort.java └── selectionsort └── SelectionSort.java /.gitattributes: -------------------------------------------------------------------------------- 1 | # 2 | # https://help.github.com/articles/dealing-with-line-endings/ 3 | # 4 | # These are explicitly windows files and should use crlf 5 | *.bat text eol=crlf 6 | 7 | -------------------------------------------------------------------------------- /.github/workflows/gradle.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a Java project with Gradle 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/building-and-testing-java-with-gradle 3 | 4 | name: Java CI with Gradle 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up JDK 1.8 19 | uses: actions/setup-java@v1 20 | with: 21 | java-version: 1.8 22 | - name: Grant execute permission for gradlew 23 | run: chmod +x gradlew 24 | - name: Build with Gradle 25 | run: ./gradlew check 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.war 15 | *.nar 16 | *.ear 17 | *.zip 18 | *.tar.gz 19 | *.rar 20 | 21 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 22 | hs_err_pid* 23 | 24 | # Temp files 25 | *~ 26 | 27 | # Make file 28 | Makefile 29 | 30 | # Byte-compiled / optimized / DLL files 31 | __pycache__/ 32 | *.py[cod] 33 | *$py.class 34 | 35 | # C extensions 36 | *.so 37 | 38 | # Distribution / packaging 39 | .Python 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | python/src/test/cpp/lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | share/python-wheels/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | MANIFEST 57 | 58 | # PyInstaller 59 | # Usually these files are written by a python script from a template 60 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 61 | *.manifest 62 | *.spec 63 | 64 | # Installer logs 65 | pip-log.txt 66 | pip-delete-this-directory.txt 67 | 68 | # Unit test / coverage reports 69 | htmlcov/ 70 | .tox/ 71 | .nox/ 72 | .coverage 73 | .coverage.* 74 | .cache 75 | nosetests.xml 76 | coverage.xml 77 | *.cover 78 | *.py,cover 79 | .hypothesis/ 80 | .pytest_cache/ 81 | cover/ 82 | 83 | # Translations 84 | *.mo 85 | *.pot 86 | 87 | # Django stuff: 88 | local_settings.py 89 | db.sqlite3 90 | db.sqlite3-journal 91 | 92 | # Flask stuff: 93 | instance/ 94 | .webassets-cache 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | .pybuilder/ 104 | target/ 105 | 106 | # Jupyter Notebook 107 | .ipynb_checkpoints 108 | 109 | # IPython 110 | profile_default/ 111 | ipython_config.py 112 | 113 | # pyenv 114 | # For a library or package, you might want to ignore these files since the code is 115 | # intended to run in multiple environments; otherwise, check them in: 116 | # .python-version 117 | 118 | # pipenv 119 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 120 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 121 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 122 | # install all needed dependencies. 123 | #Pipfile.lock 124 | 125 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 126 | __pypackages__/ 127 | 128 | # Celery stuff 129 | celerybeat-schedule 130 | celerybeat.pid 131 | 132 | # SageMath parsed files 133 | *.sage.py 134 | 135 | # Environments 136 | .env 137 | .venv 138 | env/ 139 | venv/ 140 | ENV/ 141 | env.bak/ 142 | venv.bak/ 143 | 144 | # Spyder project settings 145 | .spyderproject 146 | .spyproject 147 | 148 | # Rope project settings 149 | .ropeproject 150 | 151 | # mkdocs documentation 152 | /site 153 | 154 | # mypy 155 | .mypy_cache/ 156 | .dmypy.json 157 | dmypy.json 158 | 159 | # Pyre type checker 160 | .pyre/ 161 | 162 | # pytype static type analyzer 163 | .pytype/ 164 | 165 | # Cython debug symbols 166 | cython_debug/ 167 | 168 | # Prerequisites 169 | *.d 170 | 171 | # Compiled Object files 172 | *.slo 173 | *.lo 174 | *.o 175 | *.obj 176 | 177 | # Precompiled Headers 178 | *.gch 179 | *.pch 180 | 181 | # Compiled Dynamic libraries 182 | *.dylib 183 | *.dll 184 | 185 | # Fortran module files 186 | *.mod 187 | *.smod 188 | 189 | # Compiled Static libraries 190 | *.lai 191 | *.la 192 | *.a 193 | 194 | # Executables 195 | *.exe 196 | *.out 197 | *.app 198 | *out* 199 | 200 | # Ignore Gradle project-specific cache directory 201 | .gradle 202 | 203 | # Ignore Gradle build output directory 204 | build 205 | 206 | # Ignore Inteliij Files 207 | .idea/ 208 | *.iml 209 | 210 | # CMake Output 211 | cmake-build-debug/ 212 | 213 | # generated output 214 | *generated* 215 | 216 | # valgrind output 217 | *valgrind* 218 | /GPATH 219 | /GTAGS 220 | /GRTAGS 221 | *# -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 nishantc1527 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Algorithms 2 | 3 | A collection of common algorithms and data structures with source code in Java. 4 | 5 | # Gradle 6 | 7 | This repository uses Gradle. While you don't need it, it will make everything much easier. You don't need to install Gradle as I added the wrapper. Testing is all done with JUnit 5. 8 | 9 | # Dependencies Used 10 | 11 | * JUnit 5 12 | * Apache Commons Lang 13 | * JBlas 14 | 15 | # Using Gradle 16 | 17 | To check if everything is correct, run this command: 18 | 19 | ```bash 20 | ./gradlew check 21 | ``` 22 | 23 | This checks for lots of things, like running the tests, check for compile errors, check for correct google java style guide, etc. If you want to run a single file, then go to the build.gradle file and add this line at the bottom. 24 | 25 | mainClassName = 'sorting.bubblesort.BubbleSort' 26 | 27 | and run 28 | 29 | ```bash 30 | ./gradlew run 31 | ``` 32 | 33 | Replace ```sorting.bubblesort.BubbleSort``` with the file you want to run. Make sure you exclude the ```src.main.java``` part. 34 | 35 | # [Source Code](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java) 36 | 37 | ## [Bit Manipulation](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/bitmanipulation) 38 | 39 | * [Basic Operators](https://github.com/nishantc1527/Algorithms-Java/blob/master/src/main/java/bitmanipulation/BasicOperators.java) 40 | 41 | ### [Problems](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/bitmanipulation/problems) 42 | 43 | * [Missing Number](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/bitmanipulation/problems/missingnumber) 44 | * [Power Of Two](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/bitmanipulation/problems/poweroftwo) 45 | * [Single Number](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/bitmanipulation/problems/singlenumber) 46 | 47 | ## [Data Structures](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures) 48 | 49 | ### [Min Priority Queues](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/minpriorityqueue) 50 | 51 | * [Fibonacci Heap](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/minpriorityqueue/FibonacciHeap) 52 | * [Min Heap](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/minpriorityqueue/MinHeap) 53 | 54 | ### [Trees](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/trees) 55 | 56 | #### [Binary Search Trees](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/trees/binarysearchtrees) 57 | 58 | * [Avl Tree](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/trees/binarysearchtrees/avltree) 59 | * [Binary Search Tree](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/trees/binarysearchtrees/binarysearchtree) 60 | * [Red Black Tree](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/trees/binarysearchtrees/redblacktree) 61 | 62 | #### [Trie](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/trees/trie) 63 | 64 | * [Trie](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/datastructures/trees/trie) 65 | 66 | ## [Dynamic Programming](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/dynmanicprogramming) 67 | 68 | ### [Problems](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/dynmanicprogramming/problems) 69 | 70 | * [House Robber](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/dynmanicprogramming/problems/houserobber) 71 | * [Unique Paths](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/dynmanicprogramming/problems/uniquepaths) 72 | 73 | ## [Graph Theory](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/graphtheory) 74 | 75 | ### [Traversals](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/graphtheory/traversals) 76 | 77 | * [Breadth First Search](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/graphtheory/traversals/breadthfirstsearch) 78 | * [Depth First Search](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/graphtheory/traversals/depthfirstsearch) 79 | 80 | ## [Math](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/math) 81 | 82 | * [Matrices](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/math/matrices) 83 | * [Vectors](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/math/vectors) 84 | 85 | ## [Neural Networks](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/neuralnetworks) 86 | 87 | * [Multilayered Neural Network](https://github.com/nishantc1527/Algorithms-Java/blob/master/src/main/java/neuralnetworks/MultilayeredNeuralNetwork.java) 88 | * [One Hidden Layer Neural Network](https://github.com/nishantc1527/Algorithms-Java/blob/master/src/main/java/neuralnetworks/OneHiddenLayerNeuralNetwork.java) 89 | 90 | ## [Sorting](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting) 91 | 92 | * [Bogo Sort](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting/bogosort) 93 | * [Bubble Sort](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting/bubblesort) 94 | * [Counting Sort](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting/countingsort) 95 | * [Heap Sort](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting/heapsort) 96 | * [Insertion Sort](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting/insertionsort) 97 | * [Merge Sort](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting/mergesort) 98 | * [Quick Sort](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting/quicksort) 99 | * [Selection Sort](https://github.com/nishantc1527/Algorithms-Java/tree/master/src/main/java/sorting/selectionsort) 100 | 101 | # License 102 | 103 | This repository is licensed under the [MIT license](https://mit-license.org/). -------------------------------------------------------------------------------- /build.gradle: -------------------------------------------------------------------------------- 1 | plugins { 2 | id "com.diffplug.spotless" version "5.1.0" 3 | id 'java' 4 | id 'application' 5 | } 6 | 7 | configurations { 8 | compileOnly { 9 | extendsFrom annotationProcessor 10 | } 11 | } 12 | 13 | sourceCompatibility = 1.8 14 | targetCompatibility = 1.8 15 | 16 | repositories { 17 | jcenter() 18 | mavenCentral() 19 | } 20 | 21 | spotless { 22 | java { 23 | importOrder() 24 | removeUnusedImports() 25 | googleJavaFormat() 26 | } 27 | } 28 | 29 | dependencies { 30 | implementation 'org.jblas:jblas:1.2.5' 31 | testCompile 'com.google.guava:guava:31.0.1-jre' 32 | implementation 'commons-lang:commons-lang:2.6' 33 | 34 | compileOnly 'org.projectlombok:lombok:1.18.22' 35 | annotationProcessor 'org.projectlombok:lombok:1.18.22' 36 | 37 | testCompile 'com.google.truth:truth:1.1.3' 38 | testImplementation 'org.junit.jupiter:junit-jupiter-api:5.8.2' 39 | } 40 | 41 | test { 42 | useJUnitPlatform() 43 | } 44 | 45 | application { 46 | mainClassName = '' // Change to whatever class you want to run 47 | } -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishantc1527/Algorithms-Java/060a2169b4fbe80a20adbe5140ff04c2d67c3d62/gradle/wrapper/gradle-wrapper.jar -------------------------------------------------------------------------------- /gradle/wrapper/gradle-wrapper.properties: -------------------------------------------------------------------------------- 1 | #Fri Jul 17 23:08:55 PDT 2020 2 | distributionUrl=https\://services.gradle.org/distributions/gradle-6.8.3-all.zip 3 | distributionBase=GRADLE_USER_HOME 4 | distributionPath=wrapper/dists 5 | zipStorePath=wrapper/dists 6 | zipStoreBase=GRADLE_USER_HOME 7 | -------------------------------------------------------------------------------- /gradlew: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | # 4 | # Copyright 2015 the original author or authors. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # https://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | # 18 | 19 | ############################################################################## 20 | ## 21 | ## Gradle start up script for UN*X 22 | ## 23 | ############################################################################## 24 | 25 | # Attempt to set APP_HOME 26 | # Resolve links: $0 may be a link 27 | PRG="$0" 28 | # Need this for relative symlinks. 29 | while [ -h "$PRG" ] ; do 30 | ls=`ls -ld "$PRG"` 31 | link=`expr "$ls" : '.*-> \(.*\)$'` 32 | if expr "$link" : '/.*' > /dev/null; then 33 | PRG="$link" 34 | else 35 | PRG=`dirname "$PRG"`"/$link" 36 | fi 37 | done 38 | SAVED="`pwd`" 39 | cd "`dirname \"$PRG\"`/" >/dev/null 40 | APP_HOME="`pwd -P`" 41 | cd "$SAVED" >/dev/null 42 | 43 | APP_NAME="Gradle" 44 | APP_BASE_NAME=`basename "$0"` 45 | 46 | # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 47 | DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' 48 | 49 | # Use the maximum available, or set MAX_FD != -1 to use that value. 50 | MAX_FD="maximum" 51 | 52 | warn () { 53 | echo "$*" 54 | } 55 | 56 | die () { 57 | echo 58 | echo "$*" 59 | echo 60 | exit 1 61 | } 62 | 63 | # OS specific support (must be 'true' or 'false'). 64 | cygwin=false 65 | msys=false 66 | darwin=false 67 | nonstop=false 68 | case "`uname`" in 69 | CYGWIN* ) 70 | cygwin=true 71 | ;; 72 | Darwin* ) 73 | darwin=true 74 | ;; 75 | MINGW* ) 76 | msys=true 77 | ;; 78 | NONSTOP* ) 79 | nonstop=true 80 | ;; 81 | esac 82 | 83 | CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar 84 | 85 | # Determine the Java command to use to start the JVM. 86 | if [ -n "$JAVA_HOME" ] ; then 87 | if [ -x "$JAVA_HOME/jre/sh/java" ] ; then 88 | # IBM's JDK on AIX uses strange locations for the executables 89 | JAVACMD="$JAVA_HOME/jre/sh/java" 90 | else 91 | JAVACMD="$JAVA_HOME/bin/java" 92 | fi 93 | if [ ! -x "$JAVACMD" ] ; then 94 | die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME 95 | 96 | Please set the JAVA_HOME variable in your environment to match the 97 | location of your Java installation." 98 | fi 99 | else 100 | JAVACMD="java" 101 | which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 102 | 103 | Please set the JAVA_HOME variable in your environment to match the 104 | location of your Java installation." 105 | fi 106 | 107 | # Increase the maximum file descriptors if we can. 108 | if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then 109 | MAX_FD_LIMIT=`ulimit -H -n` 110 | if [ $? -eq 0 ] ; then 111 | if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then 112 | MAX_FD="$MAX_FD_LIMIT" 113 | fi 114 | ulimit -n $MAX_FD 115 | if [ $? -ne 0 ] ; then 116 | warn "Could not set maximum file descriptor limit: $MAX_FD" 117 | fi 118 | else 119 | warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" 120 | fi 121 | fi 122 | 123 | # For Darwin, add options to specify how the application appears in the dock 124 | if $darwin; then 125 | GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" 126 | fi 127 | 128 | # For Cygwin or MSYS, switch paths to Windows format before running java 129 | if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then 130 | APP_HOME=`cygpath --path --mixed "$APP_HOME"` 131 | CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` 132 | JAVACMD=`cygpath --unix "$JAVACMD"` 133 | 134 | # We build the pattern for arguments to be converted via cygpath 135 | ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` 136 | SEP="" 137 | for dir in $ROOTDIRSRAW ; do 138 | ROOTDIRS="$ROOTDIRS$SEP$dir" 139 | SEP="|" 140 | done 141 | OURCYGPATTERN="(^($ROOTDIRS))" 142 | # Add a user-defined pattern to the cygpath arguments 143 | if [ "$GRADLE_CYGPATTERN" != "" ] ; then 144 | OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" 145 | fi 146 | # Now convert the arguments - kludge to limit ourselves to /bin/sh 147 | i=0 148 | for arg in "$@" ; do 149 | CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` 150 | CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option 151 | 152 | if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition 153 | eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` 154 | else 155 | eval `echo args$i`="\"$arg\"" 156 | fi 157 | i=`expr $i + 1` 158 | done 159 | case $i in 160 | 0) set -- ;; 161 | 1) set -- "$args0" ;; 162 | 2) set -- "$args0" "$args1" ;; 163 | 3) set -- "$args0" "$args1" "$args2" ;; 164 | 4) set -- "$args0" "$args1" "$args2" "$args3" ;; 165 | 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; 166 | 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; 167 | 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; 168 | 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; 169 | 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; 170 | esac 171 | fi 172 | 173 | # Escape application args 174 | save () { 175 | for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done 176 | echo " " 177 | } 178 | APP_ARGS=`save "$@"` 179 | 180 | # Collect all arguments for the java command, following the shell quoting and substitution rules 181 | eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" 182 | 183 | exec "$JAVACMD" "$@" 184 | -------------------------------------------------------------------------------- /gradlew.bat: -------------------------------------------------------------------------------- 1 | @rem 2 | @rem Copyright 2015 the original author or authors. 3 | @rem 4 | @rem Licensed under the Apache License, Version 2.0 (the "License"); 5 | @rem you may not use this file except in compliance with the License. 6 | @rem You may obtain a copy of the License at 7 | @rem 8 | @rem https://www.apache.org/licenses/LICENSE-2.0 9 | @rem 10 | @rem Unless required by applicable law or agreed to in writing, software 11 | @rem distributed under the License is distributed on an "AS IS" BASIS, 12 | @rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | @rem See the License for the specific language governing permissions and 14 | @rem limitations under the License. 15 | @rem 16 | 17 | @if "%DEBUG%" == "" @echo off 18 | @rem ########################################################################## 19 | @rem 20 | @rem Gradle startup script for Windows 21 | @rem 22 | @rem ########################################################################## 23 | 24 | @rem Set local scope for the variables with windows NT shell 25 | if "%OS%"=="Windows_NT" setlocal 26 | 27 | set DIRNAME=%~dp0 28 | if "%DIRNAME%" == "" set DIRNAME=. 29 | set APP_BASE_NAME=%~n0 30 | set APP_HOME=%DIRNAME% 31 | 32 | @rem Resolve any "." and ".." in APP_HOME to make it shorter. 33 | for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi 34 | 35 | @rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 36 | set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" 37 | 38 | @rem Find java.exe 39 | if defined JAVA_HOME goto findJavaFromJavaHome 40 | 41 | set JAVA_EXE=java.exe 42 | %JAVA_EXE% -version >NUL 2>&1 43 | if "%ERRORLEVEL%" == "0" goto init 44 | 45 | echo. 46 | echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. 47 | echo. 48 | echo Please set the JAVA_HOME variable in your environment to match the 49 | echo location of your Java installation. 50 | 51 | goto fail 52 | 53 | :findJavaFromJavaHome 54 | set JAVA_HOME=%JAVA_HOME:"=% 55 | set JAVA_EXE=%JAVA_HOME%/bin/java.exe 56 | 57 | if exist "%JAVA_EXE%" goto init 58 | 59 | echo. 60 | echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% 61 | echo. 62 | echo Please set the JAVA_HOME variable in your environment to match the 63 | echo location of your Java installation. 64 | 65 | goto fail 66 | 67 | :init 68 | @rem Get command-line arguments, handling Windows variants 69 | 70 | if not "%OS%" == "Windows_NT" goto win9xME_args 71 | 72 | :win9xME_args 73 | @rem Slurp the command line arguments. 74 | set CMD_LINE_ARGS= 75 | set _SKIP=2 76 | 77 | :win9xME_args_slurp 78 | if "x%~1" == "x" goto execute 79 | 80 | set CMD_LINE_ARGS=%* 81 | 82 | :execute 83 | @rem Setup the command line 84 | 85 | set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar 86 | 87 | @rem Execute Gradle 88 | "%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% 89 | 90 | :end 91 | @rem End local scope for the variables with windows NT shell 92 | if "%ERRORLEVEL%"=="0" goto mainEnd 93 | 94 | :fail 95 | rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of 96 | rem the _cmd.exe /c_ return code! 97 | if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 98 | exit /b 1 99 | 100 | :mainEnd 101 | if "%OS%"=="Windows_NT" endlocal 102 | 103 | :omega 104 | -------------------------------------------------------------------------------- /settings.gradle: -------------------------------------------------------------------------------- 1 | rootProject.name = 'Algorithms' 2 | -------------------------------------------------------------------------------- /src/main/java/bitmanipulation/BasicOperators.java: -------------------------------------------------------------------------------- 1 | package bitmanipulation; 2 | 3 | @SuppressWarnings("ALL") 4 | public class BasicOperators { 5 | 6 | /** 7 | * Gets the bit of an integer at a certain index 8 | * 9 | * @param num The number holding the bits 10 | * @param index The index of the bit you want (0-indexed). 11 | * @return True if the bit is 1, false if bit is 0. 12 | */ 13 | public static boolean getBit(int num, int index) { 14 | return ((num >> index) & 1) == 1; 15 | } 16 | 17 | /** 18 | * Sets the bit at specific index to 1. 19 | * 20 | * @param num The number whose bit you want to set. 21 | * @param index The index where you want to set it (0-indexed). 22 | * @return The resulting number after setting the bit. 23 | */ 24 | public static int setBit(int num, int index) { 25 | return num | (1 << index); 26 | } 27 | 28 | /** 29 | * Resets the bit at an index to 0. 30 | * 31 | * @param num The number whose bit you want cleared. 32 | * @param index The index of the bit you want cleared. 33 | * @return The new number after clearing the bit. 34 | */ 35 | public static int clearBit(int num, int index) { 36 | return num & ~(1 << index); 37 | } 38 | 39 | /** 40 | * Prints the bits of a number starting from the most significant bit. 41 | * 42 | * @param num The number whose bits you want printed. 43 | */ 44 | @SuppressWarnings("unused") 45 | public static void printBits(int num) { 46 | for (int i = 0; i < 32; i++) { 47 | System.out.print((num >> (31 - i)) & 1); 48 | } 49 | 50 | System.out.println(); 51 | } 52 | } 53 | -------------------------------------------------------------------------------- /src/main/java/bitmanipulation/problems/missingnumber/MissingNumber.java: -------------------------------------------------------------------------------- 1 | package bitmanipulation.problems.missingnumber; 2 | 3 | public class MissingNumber { 4 | 5 | /** 6 | * Given an array nums containing n distinct numbers in the range [0, n], return the only number 7 | * in the range that is missing from the array. 8 | * 9 | * @param arr Array of numbers. 10 | * @return The single number that is missing from the array. 11 | */ 12 | public static int missingNumber(int[] arr) { 13 | int xor = 0; 14 | 15 | for (int i = 0; i < arr.length; i++) { 16 | xor ^= i ^ arr[i]; 17 | } 18 | 19 | return xor ^ arr.length; 20 | } 21 | 22 | public static void main(String[] args) { 23 | System.out.println(MissingNumber.missingNumber(new int[] {0, 5, 3, 2, 4})); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/bitmanipulation/problems/poweroftwo/PowerOfTwo.java: -------------------------------------------------------------------------------- 1 | package bitmanipulation.problems.poweroftwo; 2 | 3 | public class PowerOfTwo { 4 | 5 | /** 6 | * Returns whether a number is a power of two. 7 | * 8 | * @param n The number you are checking. 9 | * @return True if it is a power of two. 10 | */ 11 | public static boolean isPowerOfTwo(int n) { 12 | return n > 0 && (n & (n - 1)) == 0; 13 | } 14 | 15 | public static void main(String[] args) { 16 | System.out.println(PowerOfTwo.isPowerOfTwo(2)); 17 | System.out.println(PowerOfTwo.isPowerOfTwo(5)); 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/bitmanipulation/problems/singlenumber/SingleNumber.java: -------------------------------------------------------------------------------- 1 | package bitmanipulation.problems.singlenumber; 2 | 3 | public class SingleNumber { 4 | 5 | /** 6 | * Given a non-empty array of integers nums, every element appears twice except for one. Find that 7 | * single one. You must implement a solution in linear runtime complexity and use only constant 8 | * extra space. 9 | * 10 | * @param arr Array of integers. 11 | * @return The element that only appears once. 12 | */ 13 | public static int singleNumber(int[] arr) { 14 | int xor = 0; 15 | 16 | for (int value : arr) { 17 | xor ^= value; 18 | } 19 | 20 | return xor; 21 | } 22 | 23 | public static void main(String[] args) { 24 | System.out.println(SingleNumber.singleNumber(new int[] {1, 2, 2, 1, 5})); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/datastructures/minpriorityqueue/fibonacciheap/FibonacciHeap.java: -------------------------------------------------------------------------------- 1 | package datastructures.minpriorityqueue.fibonacciheap; 2 | 3 | import java.util.HashMap; 4 | import java.util.HashSet; 5 | import java.util.Objects; 6 | 7 | public class FibonacciHeap { 8 | 9 | private final HashMap references; 10 | private Node min; 11 | 12 | public FibonacciHeap() { 13 | references = new HashMap<>(); 14 | } 15 | 16 | public static void main(String[] args) { 17 | FibonacciHeap h = new FibonacciHeap(); 18 | 19 | for (int i = 0; i < 10; i++) { 20 | h.insert(i); 21 | } 22 | 23 | h.delete(5); 24 | 25 | while (!h.isEmpty()) { 26 | System.out.println(h.extractMin()); 27 | } 28 | 29 | System.out.println(h); 30 | } 31 | 32 | public void insert(int val) { 33 | Node newNode = new Node(val); 34 | references.put(val, newNode); 35 | 36 | if (min == null) { 37 | min = newNode; 38 | min.left = min; 39 | min.right = min; 40 | } else { 41 | min.insert(newNode); 42 | if (newNode.val < min.val) { 43 | min = newNode; 44 | } 45 | } 46 | } 47 | 48 | public int extractMin() { 49 | int toReturn = min.val; 50 | references.remove(toReturn); 51 | 52 | if (min.left == min && min.child == null) { 53 | min = null; 54 | } else { 55 | min.safeUnlink(); 56 | min = min.left; 57 | consolidate(); 58 | } 59 | 60 | return toReturn; 61 | } 62 | 63 | public void delete(int val) { 64 | Node nodeForm = references.get(val); 65 | nodeForm.val = Integer.MIN_VALUE; 66 | min = nodeForm; 67 | extractMin(); 68 | } 69 | 70 | public boolean isEmpty() { 71 | return min == null; 72 | } 73 | 74 | private void consolidate() { 75 | Node[] degrees = new Node[45]; 76 | Node dummy = min; 77 | HashSet visited = new HashSet<>(); 78 | 79 | do { 80 | if (visited.contains(dummy)) { 81 | break; 82 | } 83 | 84 | if (dummy.val < min.val) { 85 | min = dummy; 86 | } 87 | 88 | while (degrees[dummy.degree] != null) { 89 | Node other = degrees[dummy.degree]; 90 | 91 | if (other.val < dummy.val) { 92 | Node temp = other; 93 | other = dummy; 94 | dummy = temp; 95 | } 96 | 97 | dummy.link(other); 98 | degrees[dummy.degree - 1] = null; 99 | } 100 | 101 | degrees[dummy.degree] = dummy; 102 | visited.add(dummy); 103 | dummy = dummy.right; 104 | } while (dummy != min); 105 | } 106 | 107 | @Override 108 | public String toString() { 109 | StringBuilder sb = new StringBuilder(); 110 | 111 | if (min != null) { 112 | Node dummy = min; 113 | 114 | do { 115 | sb.append(dummy.val).append(" "); 116 | dummy = dummy.right; 117 | } while (dummy != min); 118 | } 119 | 120 | return sb.toString(); 121 | } 122 | 123 | private static class Node { 124 | 125 | public int val; 126 | public Node left, right, child; 127 | public int degree; 128 | public boolean mark; 129 | 130 | public Node(int _val) { 131 | val = _val; 132 | } 133 | 134 | public void insert(Node other) { 135 | if (left == this) { 136 | left = other; 137 | right = other; 138 | left.right = this; 139 | left.left = this; 140 | } else { 141 | Node temp = left; 142 | left = other; 143 | left.right = this; 144 | left.left = temp; 145 | temp.right = left; 146 | } 147 | } 148 | 149 | public void unlink() { 150 | left.right = right; 151 | right.left = left; 152 | } 153 | 154 | public void safeUnlink() { 155 | saveChildren(); 156 | unlink(); 157 | } 158 | 159 | public void link(Node other) { 160 | other.unlink(); 161 | if (child == null) { 162 | child = new Node(other.val); 163 | child.left = child; 164 | child.right = child; 165 | } else { 166 | child.insert(other); 167 | } 168 | 169 | other.mark = false; 170 | degree++; 171 | } 172 | 173 | public void saveChildren() { 174 | if (child != null) { 175 | Node dummy = child; 176 | 177 | do { 178 | Node tempNext = dummy.right; 179 | insert(dummy); 180 | dummy = tempNext; 181 | } while (dummy != child); 182 | 183 | child = null; 184 | } 185 | } 186 | 187 | @Override 188 | public boolean equals(Object o) { 189 | if (this == o) return true; 190 | if (o == null || getClass() != o.getClass()) return false; 191 | Node node = (Node) o; 192 | return val == node.val 193 | && degree == node.degree 194 | && mark == node.mark 195 | && Objects.equals(left, node.left) 196 | && Objects.equals(right, node.right) 197 | && Objects.equals(child, node.child); 198 | } 199 | 200 | @Override 201 | public int hashCode() { 202 | return Objects.hash(val); 203 | } 204 | 205 | @Override 206 | public String toString() { 207 | return val + ""; 208 | } 209 | } 210 | } 211 | -------------------------------------------------------------------------------- /src/main/java/datastructures/minpriorityqueue/minheap/MinHeap.java: -------------------------------------------------------------------------------- 1 | package datastructures.minpriorityqueue.minheap; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Comparator; 5 | import java.util.Iterator; 6 | import java.util.List; 7 | 8 | @SuppressWarnings("unused") 9 | public class MinHeap implements Iterable { 10 | 11 | public final List heap; 12 | private final Comparator comparator; 13 | 14 | public MinHeap(Comparator comparator) { 15 | heap = new ArrayList<>(); 16 | this.comparator = comparator; 17 | } 18 | 19 | private static int getLeft(int i) { 20 | return (i << 1) + 1; 21 | } 22 | 23 | private static int getRight(int i) { 24 | return (i << 1) + 2; 25 | } 26 | 27 | private static int getParent(int i) { 28 | return i % 2 == 0 ? (i >> 1) - 1 : (i >> 1); 29 | } 30 | 31 | public static void main(String[] args) { 32 | MinHeap heap = new MinHeap<>(Integer::compare); 33 | heap.add(7); 34 | heap.add(4); 35 | heap.add(3); 36 | heap.add(2); 37 | heap.add(1); 38 | 39 | while (!heap.isEmpty()) { 40 | System.out.println(heap.extractMin()); 41 | } 42 | } 43 | 44 | public void add(E val) { 45 | heap.add(val); 46 | bubbleUp(heap.size() - 1); 47 | } 48 | 49 | @SuppressWarnings("unused") 50 | public E getMin() { 51 | return heap.get(0); 52 | } 53 | 54 | public E extractMin() { 55 | E toReturn = heap.get(0); 56 | swap(0, heap.size() - 1); 57 | heap.remove(heap.size() - 1); 58 | bubbleDown(0); 59 | return toReturn; 60 | } 61 | 62 | public void decreaseValue(E val, E newVal) { 63 | int index = heap.indexOf(val); 64 | heap.set(index, newVal); 65 | bubbleUp(index); 66 | } 67 | 68 | public boolean isEmpty() { 69 | return heap.size() == 0; 70 | } 71 | 72 | public boolean contains(E i) { 73 | return heap.contains(i); 74 | } 75 | 76 | public void update(E val) { 77 | bubbleUp(heap.indexOf(val)); 78 | } 79 | 80 | private void bubbleUp(int i) { 81 | if (i > 0) { 82 | int parent = getParent(i); 83 | if (comparator.compare(heap.get(parent), heap.get(i)) > 0) { 84 | swap(parent, i); 85 | bubbleUp(parent); 86 | } 87 | } 88 | } 89 | 90 | private void bubbleDown(int i) { 91 | int left = getLeft(i), right = getRight(i), largest = i; 92 | 93 | if (left < heap.size() && comparator.compare(heap.get(left), heap.get(largest)) < 0) { 94 | largest = left; 95 | } 96 | 97 | if (right < heap.size() && comparator.compare(heap.get(right), heap.get(largest)) < 0) { 98 | largest = right; 99 | } 100 | 101 | if (largest != i) { 102 | swap(i, largest); 103 | bubbleDown(largest); 104 | } 105 | } 106 | 107 | private void swap(int i, int j) { 108 | E temp = heap.get(i); 109 | heap.set(i, heap.get(j)); 110 | heap.set(j, temp); 111 | } 112 | 113 | @Override 114 | public Iterator iterator() { 115 | return heap.iterator(); 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/BinaryTree.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees; 2 | 3 | @SuppressWarnings("unused") 4 | public interface BinaryTree> extends Iterable> { 5 | 6 | Node getRoot(); 7 | 8 | void insert(E val); 9 | 10 | void delete(E val); 11 | 12 | boolean contains(E val); 13 | 14 | int getHeight(); 15 | 16 | @SuppressWarnings("unused") 17 | boolean isValid(); 18 | 19 | int numNodes(); 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/Color.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees; 2 | 3 | public enum Color { 4 | RED, 5 | BLACK 6 | } 7 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/Node.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees; 2 | 3 | @SuppressWarnings("ALL") 4 | public interface Node> { 5 | 6 | /** 7 | * Gets the value of this node. 8 | * 9 | * @return The value of this node. 10 | */ 11 | E getVal(); 12 | 13 | /** 14 | * Sets the value of this node. 15 | * 16 | * @param newVal This node's new value. 17 | */ 18 | void setVal(E newVal); 19 | 20 | /** 21 | * Gets the left child of this node. 22 | * 23 | * @return The left child of this node. 24 | */ 25 | Node getLeft(); 26 | 27 | /** 28 | * Sets the left child of this node. 29 | * 30 | * @param newNode This node's new left child. 31 | */ 32 | void setLeft(Node newNode); 33 | 34 | /** 35 | * Gets the right child of this node. 36 | * 37 | * @return The right child of this node. 38 | */ 39 | Node getRight(); 40 | 41 | /** 42 | * Sets the right child of this node. 43 | * 44 | * @param newNode The node's new right child. 45 | */ 46 | void setRight(Node newNode); 47 | 48 | /** 49 | * Gets the parent of this node. 50 | * 51 | * @return The parent of this node. 52 | */ 53 | Node getParent(); 54 | 55 | /** 56 | * Sets the parent of this node. 57 | * 58 | * @param newNode The node's new parent. 59 | */ 60 | void setParent(Node newNode); 61 | 62 | /** 63 | * Checks if this node is null. 64 | * 65 | * @return If this node is null. 66 | */ 67 | boolean isNull(); 68 | 69 | /** 70 | * In red black trees, returns whether it's color is red or black. 71 | * 72 | * @return The color of this node. 73 | */ 74 | Color getColor(); 75 | 76 | /** 77 | * In red black trees, sets the color of this node to specified color. 78 | * 79 | * @param newColor The new color of this node. 80 | */ 81 | void setColor(Color newColor); 82 | 83 | /** Hashing function */ 84 | int hashCode(); 85 | } 86 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/TimeComparisonTest.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees; 2 | 3 | import datastructures.trees.binarysearchtrees.avltree.AVLTree; 4 | import datastructures.trees.binarysearchtrees.binarysearchtree.BinarySearchTree; 5 | import datastructures.trees.binarysearchtrees.redblacktree.RedBlackTree; 6 | import java.util.ArrayList; 7 | import java.util.List; 8 | 9 | public class TimeComparisonTest { 10 | public static void main(String[] args) { 11 | BinaryTree rbtree = new RedBlackTree<>(), 12 | avltree = new AVLTree<>(), 13 | binarySearchTree = new BinarySearchTree<>(); 14 | List rbTreeVals = new ArrayList<>(), 15 | avlTreeVals = new ArrayList<>(), 16 | bstVals = new ArrayList<>(); 17 | int n = 15000; 18 | for (int i = 0; i < n; i++) { 19 | int newVal = (int) (Math.random() * n * 2 - n); 20 | rbTreeVals.add(newVal); 21 | avlTreeVals.add(newVal); 22 | bstVals.add(newVal); 23 | } 24 | 25 | long rbTimeStart, rbTimeFinish, avlTimeStart, avlTimeFinish, bstTimeStart, bstTimeFinish; 26 | 27 | rbTimeStart = System.nanoTime(); 28 | for (Integer rbTreeVal : rbTreeVals) { 29 | rbtree.insert(rbTreeVal); 30 | } 31 | rbTimeFinish = System.nanoTime(); 32 | 33 | avlTimeStart = System.nanoTime(); 34 | for (Integer avlTreeVal : avlTreeVals) { 35 | avltree.insert(avlTreeVal); 36 | } 37 | avlTimeFinish = System.nanoTime(); 38 | 39 | bstTimeStart = System.nanoTime(); 40 | for (Integer bstTreeVal : bstVals) { 41 | binarySearchTree.insert(bstTreeVal); 42 | } 43 | bstTimeFinish = System.nanoTime(); 44 | 45 | System.out.printf( 46 | "%-40s %d values: %-5d milliseconds\n", 47 | "Red Black Tree insertion time for", n, ((rbTimeFinish - rbTimeStart) / 1000000)); 48 | System.out.printf( 49 | "%-40s %d values: %-5d milliseconds\n", 50 | "AVL Tree insertion time for", n, ((avlTimeFinish - avlTimeStart) / 1000000)); 51 | System.out.printf( 52 | "%-40s %d values: %-5d milliseconds\n", 53 | "Binary Search Tree insertion time for", n, ((bstTimeFinish - bstTimeStart) / 1000000)); 54 | 55 | System.out.println("\n"); 56 | 57 | System.out.printf("%-45s%-4d\n", "Red Black Tree height after insertion: ", rbtree.getHeight()); 58 | System.out.printf("%-45s%-4d\n", "AVL Tree height after insertion: ", avltree.getHeight()); 59 | System.out.printf( 60 | "%-45s%-4d\n", "Binary Search Tree height after insertion: ", binarySearchTree.getHeight()); 61 | 62 | System.out.println("\n"); 63 | 64 | int searchN = n * 200; 65 | 66 | rbTimeStart = System.nanoTime(); 67 | for (int i = -searchN; i < searchN; i++) { 68 | rbtree.contains(i); 69 | } 70 | rbTimeFinish = System.nanoTime(); 71 | 72 | avlTimeStart = System.nanoTime(); 73 | for (int i = -searchN; i < searchN; i++) { 74 | avltree.contains(i); 75 | } 76 | avlTimeFinish = System.nanoTime(); 77 | 78 | bstTimeStart = System.nanoTime(); 79 | for (int i = -searchN; i < searchN; i++) { 80 | binarySearchTree.contains(i); 81 | } 82 | bstTimeFinish = System.nanoTime(); 83 | 84 | System.out.printf( 85 | "%-40s %d values: %-5d milliseconds\n", 86 | "Red Black Tree search time for", searchN, ((rbTimeFinish - rbTimeStart) / 1000000)); 87 | System.out.printf( 88 | "%-40s %d values: %-5d milliseconds\n", 89 | "AVL Tree search time for", searchN, ((avlTimeFinish - avlTimeStart) / 1000000)); 90 | System.out.printf( 91 | "%-40s %d values: %-5d milliseconds\n", 92 | "Binary Search Tree search time for", searchN, ((bstTimeFinish - bstTimeStart) / 1000000)); 93 | 94 | System.out.println("\n"); 95 | 96 | rbTimeStart = System.nanoTime(); 97 | for (int i = 0; i < n; i++) { 98 | rbtree.delete(rbTreeVals.remove(0)); 99 | } 100 | rbTimeFinish = System.nanoTime(); 101 | 102 | avlTimeStart = System.nanoTime(); 103 | for (int i = 0; i < n; i++) { 104 | avltree.delete(avlTreeVals.remove(0)); 105 | } 106 | avlTimeFinish = System.nanoTime(); 107 | 108 | bstTimeStart = System.nanoTime(); 109 | for (int i = 0; i < n; i++) { 110 | binarySearchTree.delete(bstVals.remove(0)); 111 | } 112 | bstTimeFinish = System.nanoTime(); 113 | 114 | System.out.printf( 115 | "%-40s %d values: %-5d milliseconds\n", 116 | "Red Black Tree deletion time for", n, ((rbTimeFinish - rbTimeStart) / 1000000)); 117 | System.out.printf( 118 | "%-40s %d values: %-5d milliseconds\n", 119 | "AVL Tree deletion time for", n, ((avlTimeFinish - avlTimeStart) / 1000000)); 120 | System.out.printf( 121 | "%-40s %d values: %-5d milliseconds\n", 122 | "Binary Search Tree deletion time for", n, ((bstTimeFinish - bstTimeStart) / 1000000)); 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/TreeVisualizer.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees; 2 | 3 | import datastructures.trees.binarysearchtrees.avltree.AVLTree; 4 | import java.awt.*; 5 | import java.awt.Color; 6 | import java.awt.event.KeyEvent; 7 | import java.awt.event.KeyListener; 8 | import java.util.HashMap; 9 | import java.util.Map; 10 | import javax.swing.*; 11 | 12 | @SuppressWarnings("unused") 13 | public class TreeVisualizer extends JPanel implements KeyListener { 14 | private final double zoomFactor = 0.5; 15 | private final JScrollPane scrollPane; 16 | private BinaryTree tree; 17 | private Map, Point> map; 18 | private int width, height, rootHeight; 19 | private JFrame frame; 20 | 21 | public TreeVisualizer(BinaryTree tree, int width, int height) { 22 | this.tree = tree; 23 | map = new HashMap<>(); 24 | 25 | JFrame frame = new JFrame("Visualizing the tree!"); 26 | frame.setBounds(100, 100, width, height); 27 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 28 | frame.setResizable(true); 29 | this.setBackground(Color.WHITE); 30 | this.setLayout(null); 31 | this.setPreferredSize(new Dimension(width, height)); 32 | scrollPane = new JScrollPane(this); 33 | scrollPane.setHorizontalScrollBarPolicy(JScrollPane.HORIZONTAL_SCROLLBAR_ALWAYS); 34 | scrollPane.setVerticalScrollBarPolicy(JScrollPane.VERTICAL_SCROLLBAR_ALWAYS); 35 | frame.add(scrollPane); 36 | frame.setVisible(true); 37 | } 38 | 39 | public static void main(String[] args) { 40 | BinaryTree myTree = new AVLTree<>(); 41 | TreeVisualizer treeVisualizer = new TreeVisualizer(myTree, 1600, 900); 42 | for (int i = 0; i < 200; i++) { 43 | // myTree.insert((int)(Math.random()*500)); 44 | myTree.insert(i); 45 | treeVisualizer.setTree(myTree); 46 | try { 47 | Thread.sleep(200); 48 | } catch (InterruptedException e) { 49 | e.printStackTrace(); 50 | } 51 | } 52 | } 53 | 54 | public void setTree(BinaryTree tree) { 55 | this.tree = tree; 56 | repaint(); 57 | } 58 | 59 | private void setCoords() { 60 | map = new HashMap<>(); 61 | this.rootHeight = tree.getHeight(); 62 | this.height = ((int) (rootHeight * 100 * zoomFactor)); 63 | this.width = 64 | ((int) 65 | ((maxLeft(tree.getRoot()) + maxRight(tree.getRoot())) 66 | * (this.height) 67 | * 5 68 | * zoomFactor)); 69 | this.setPreferredSize(new Dimension(width, height)); 70 | setCoords(tree.getRoot(), 1, 0, false); 71 | } 72 | 73 | private int maxLeft(Node root) { 74 | if (root.getLeft() == null || root.getLeft().isNull()) return 1; 75 | else return 1 + maxLeft(root.getLeft()); 76 | } 77 | 78 | private int maxRight(Node root) { 79 | if (root.getRight() == null || root.getRight().isNull()) return 1; 80 | else return 1 + maxRight(root.getRight()); 81 | } 82 | 83 | private void setCoords(Node root, int level, int prevX, boolean leftSide) { 84 | if (root == null || root.isNull()) return; 85 | 86 | int x, y; 87 | if (leftSide) x = prevX - ((int) (width / Math.pow(2, level) * zoomFactor * zoomFactor)); 88 | else x = prevX + ((int) (width / Math.pow(2, level) * zoomFactor * zoomFactor)); 89 | 90 | y = ((int) (height / rootHeight * level * zoomFactor)); 91 | map.put(root, new Point(x, y)); 92 | 93 | setCoords(root.getLeft(), level + 1, x, true); 94 | setCoords(root.getRight(), level + 1, x, false); 95 | } 96 | 97 | protected void paintComponent(Graphics g) { 98 | super.paintComponent(g); 99 | 100 | setCoords(); 101 | this.setPreferredSize(new Dimension(this.width, this.height)); 102 | scrollPane.setViewportView(this); 103 | 104 | g.setColor(Color.BLACK); 105 | g.setFont(new Font("Arial", Font.PLAIN, 20)); 106 | 107 | int x, y, textX, textY, radius = 20; 108 | FontMetrics metrics = g.getFontMetrics(); 109 | for (Node node : tree) { 110 | if (node.isNull()) continue; 111 | x = ((int) map.get(node).getX()); 112 | y = ((int) map.get(node).getY()); 113 | String text = node.getVal().toString(); 114 | JLabel label = new JLabel(text); 115 | // label.set 116 | textX = x - metrics.stringWidth(text) / 2; 117 | textY = y; 118 | g.drawString(text, textX, textY); 119 | g.drawOval(x - radius, y - radius, radius * 2, radius * 2); 120 | 121 | if (map.containsKey(node.getLeft())) 122 | g.drawLine( 123 | x, y, ((int) map.get(node.getLeft()).getX()), ((int) map.get(node.getLeft()).getY())); 124 | if (map.containsKey(node.getRight())) 125 | g.drawLine( 126 | x, y, ((int) map.get(node.getRight()).getX()), ((int) map.get(node.getRight()).getY())); 127 | } 128 | } 129 | 130 | @SuppressWarnings("StatementWithEmptyBody") 131 | public void keyPressed(KeyEvent e) { 132 | if (e.getKeyCode() == KeyEvent.VK_SPACE) {} 133 | } 134 | 135 | public void keyReleased(KeyEvent e) {} 136 | 137 | public void keyTyped(KeyEvent e) {} 138 | } 139 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/Utility.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | import java.util.Stack; 7 | 8 | @SuppressWarnings("unused") 9 | public class Utility { 10 | 11 | /** 12 | * Prints a tree in preorder traversal. 13 | * 14 | * @param tree The tree to traverse. 15 | * @param The type parameter of the tree. 16 | */ 17 | public static > void printInPreorder(BinaryTree tree) { 18 | printInPreorder(tree.getRoot()); 19 | System.out.println(); 20 | } 21 | 22 | /** 23 | * Prints a tree in inorder traversal. 24 | * 25 | * @param tree The tree to traverse. 26 | * @param The type parameter of the tree. 27 | */ 28 | public static > void printInInorder(BinaryTree tree) { 29 | printInInorder(tree.getRoot()); 30 | System.out.println(); 31 | } 32 | 33 | /** 34 | * Prints a tree in post order traversal. 35 | * 36 | * @param tree The tree to traverse. 37 | * @param The type parameter of the tree. 38 | */ 39 | public static > void printInPostOrder(BinaryTree tree) { 40 | printInPostOrder(tree.getRoot()); 41 | System.out.println(); 42 | } 43 | 44 | /** 45 | * Helper method to print in preorder traversal. 46 | * 47 | * @param root The root node of the tree. 48 | * @param The type parameter of the tree. 49 | */ 50 | private static > void printInPreorder(Node root) { 51 | if (root == null || root.isNull()) { 52 | return; 53 | } 54 | 55 | System.out.print(root.getVal() + " "); 56 | printInPreorder(root.getLeft()); 57 | printInPreorder(root.getRight()); 58 | } 59 | 60 | /** 61 | * Helper method to print in inorder traversal. 62 | * 63 | * @param root The root node of the tree. 64 | * @param The type parameter of the tree. 65 | */ 66 | private static > void printInInorder(Node root) { 67 | if (root == null || root.isNull()) { 68 | return; 69 | } 70 | 71 | printInInorder(root.getLeft()); 72 | System.out.print(root.getVal() == null ? "" : root.getVal() + " "); 73 | printInInorder(root.getRight()); 74 | } 75 | 76 | /** 77 | * Helper method to print in post order traversal. 78 | * 79 | * @param root The root node of the tree. 80 | * @param The type parameter of the tree. 81 | */ 82 | private static > void printInPostOrder(Node root) { 83 | if (root == null || root.isNull()) { 84 | return; 85 | } 86 | 87 | printInPostOrder(root.getLeft()); 88 | printInPostOrder(root.getRight()); 89 | System.out.print(root.getVal() + " "); 90 | } 91 | 92 | /** 93 | * Creates a tree with random values. 94 | * 95 | * @param tree The tree who's values you are filling. 96 | * @param size The size of the tree. 97 | * @param max The maximum of every node in the new tree, negative and positive. 98 | */ 99 | public static void createTree(BinaryTree tree, int size, int max) { 100 | for (int i = 0; i < size; i++) { 101 | tree.insert(i); 102 | } 103 | } 104 | 105 | /** 106 | * Determines if a binary search tree is valid, meaning left subtrees are smaller than the root 107 | * and right subtrees are larger. 108 | * 109 | * @param tree The tree to check. 110 | * @param The generic type for the binary tree. 111 | * @return Whether it's valid or not. 112 | */ 113 | public static > boolean isValidBST(BinaryTree tree) { 114 | return isValid(tree.getRoot()); 115 | } 116 | 117 | private static > boolean isValid(Node root) { 118 | if (root == null) return true; 119 | Stack> stack = new Stack<>(); 120 | Node pre = null; 121 | while (root != null || !stack.isEmpty()) { 122 | while (root != null) { 123 | stack.push(root); 124 | root = root.getLeft(); 125 | } 126 | root = stack.pop(); 127 | if (pre != null && root.getVal().compareTo(pre.getVal()) < 0) return false; 128 | pre = root; 129 | root = root.getRight(); 130 | } 131 | return true; 132 | } 133 | 134 | public static > void printTree(BinaryTree tree) { 135 | printNode(tree.getRoot()); 136 | } 137 | 138 | private static > void printNode(Node root) { 139 | int maxLevel = maxLevel(root); 140 | 141 | printNodeInternal(Collections.singletonList(root), 1, maxLevel); 142 | } 143 | 144 | private static > void printNodeInternal( 145 | List> nodes, int level, int maxLevel) { 146 | if (nodes.isEmpty() || isAllElementsNull(nodes)) return; 147 | 148 | int floor = maxLevel - level; 149 | int edgeLines = (int) Math.pow(2, (Math.max(floor - 1, 0))); 150 | int firstSpaces = (int) Math.pow(2, (floor)) - 1; 151 | int betweenSpaces = (int) Math.pow(2, (floor + 1)) - 1; 152 | 153 | printWhitespaces(firstSpaces); 154 | 155 | List> newNodes = new ArrayList<>(); 156 | for (Node node : nodes) { 157 | if (node != null && !node.isNull()) { 158 | System.out.print(node.getVal()); 159 | newNodes.add(node.getLeft()); 160 | newNodes.add(node.getRight()); 161 | } else { 162 | newNodes.add(null); 163 | newNodes.add(null); 164 | System.out.print(" "); 165 | } 166 | 167 | printWhitespaces(betweenSpaces); 168 | } 169 | System.out.println(); 170 | 171 | for (int i = 1; i <= edgeLines; i++) { 172 | for (Node node : nodes) { 173 | printWhitespaces(firstSpaces - i); 174 | if (node == null || node.isNull()) { 175 | printWhitespaces(edgeLines + edgeLines + i + 1); 176 | continue; 177 | } 178 | 179 | if (node.getLeft() != null && !node.getLeft().isNull()) { 180 | System.out.print("/"); 181 | } else { 182 | printWhitespaces(1); 183 | } 184 | 185 | printWhitespaces(i + i - 1); 186 | 187 | if (node.getRight() != null && !node.getRight().isNull()) { 188 | System.out.print("\\"); 189 | } else { 190 | printWhitespaces(1); 191 | } 192 | 193 | printWhitespaces(edgeLines + edgeLines - i); 194 | } 195 | 196 | System.out.println(); 197 | } 198 | 199 | printNodeInternal(newNodes, level + 1, maxLevel); 200 | } 201 | 202 | private static void printWhitespaces(int count) { 203 | for (int i = 0; i < count; i++) { 204 | System.out.print(" "); 205 | } 206 | } 207 | 208 | private static > int maxLevel(Node node) { 209 | if (node == null || node.isNull()) { 210 | return 0; 211 | } 212 | 213 | return Math.max(maxLevel(node.getLeft()), maxLevel(node.getRight())) + 1; 214 | } 215 | 216 | private static > boolean isAllElementsNull(List> list) { 217 | for (Node node : list) { 218 | if (node != null && !node.isNull()) { 219 | return false; 220 | } 221 | } 222 | 223 | return true; 224 | } 225 | } 226 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/avltree/AVLTree.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees.avltree; 2 | 3 | import datastructures.trees.binarysearchtrees.BinaryTree; 4 | import datastructures.trees.binarysearchtrees.Color; 5 | import datastructures.trees.binarysearchtrees.Node; 6 | import java.util.*; 7 | 8 | @SuppressWarnings("ALL") 9 | public class AVLTree> implements BinaryTree, Iterable> { 10 | private final AVLTreeNode rootParent; 11 | private AVLTreeNode root; 12 | 13 | /** Creates a default, empty AVL tree with a null root */ 14 | public AVLTree() { 15 | rootParent = new AVLTreeNode((E) null); 16 | rootParent.left = null; 17 | } 18 | 19 | /** 20 | * Creates a clone of another AVL tree, with clones of each respective tree node in the other AVL 21 | * tree 22 | * 23 | * @param other The AVL tree to clone 24 | */ 25 | @SuppressWarnings("unused") 26 | public AVLTree(AVLTree other) { 27 | this.root = new AVLTreeNode(other.root); 28 | rootParent = new AVLTreeNode((E) null); 29 | rootParent.left = root; 30 | } 31 | 32 | public static void main(String[] args) { 33 | AVLTree myTree; 34 | List treeVals; 35 | int insertCorrect = 0, deleteCorrect = 0, trials = 10000; 36 | 37 | for (int i = 0; i < trials; i++) { 38 | myTree = new AVLTree<>(); 39 | treeVals = new LinkedList<>(); 40 | 41 | for (int j = 0; j < 31; j++) { 42 | int a = (int) (Math.random() * 100); 43 | treeVals.add(a); 44 | myTree.insert(a); 45 | } 46 | 47 | boolean correct1 = myTree.isValid(); 48 | if (correct1) insertCorrect++; 49 | 50 | for (int j = 0; j < 10; j++) { 51 | int a = (int) (Math.random() * treeVals.size()); 52 | myTree.delete(treeVals.remove(a)); 53 | } 54 | 55 | boolean correct2 = myTree.isValid(); 56 | if (correct2) deleteCorrect++; 57 | } 58 | 59 | System.out.println( 60 | "Insertion correct percentage: " + (((double) insertCorrect) / trials * 100) + "%"); 61 | System.out.println( 62 | "Deletion correct percentage: " + (((double) deleteCorrect) / trials * 100) + "%"); 63 | } 64 | 65 | /** 66 | * Finds and returns the root node of the tree 67 | * 68 | * @return The root node 69 | */ 70 | public Node getRoot() { 71 | AVLTreeNode res = new AVLTreeNode(root.val); 72 | res.left = root.left; 73 | res.right = root.right; 74 | return res; 75 | } 76 | 77 | /** 78 | * Fixes the tree after an insertion or deletion Assuming the balance factors are correct in all 79 | * of the tree nodes, the tree is fixed using the following rules: 1. First have to fix both thnne 80 | * left and right subtrees 2. There is no need to fix the current tree node if the balance factor 81 | * of the current node belongs to the set of {-1, 0, 1}. 3. If the balance factor is greater than 82 | * or equal to 2 (the left subtree is heavier than the right subtree) then do the following: a. If 83 | * the current node's left child's balance factor is negative (the right subtree of the current 84 | * node's left subtree is heavier than the left of the left of the current) then first have to 85 | * rotate the current node's left subtree to the left b. Rotate the current node to the right 4. 86 | * If the balance factor of the current node is less than or equal to -2, then perform the mirror 87 | * image of step #3 on the current node 88 | * 89 | * @param root The "current" node to fix 90 | */ 91 | private void fixTree(AVLTreeNode root) { 92 | if (root == null || (root.left == null && root.right == null)) return; 93 | fixTree(root.left); 94 | fixTree(root.right); 95 | 96 | int bf = root.balanceFactor; // root.balanceFactor = root.left.height - root.right.height 97 | if (Math.abs(bf) < 2) return; 98 | if (bf >= 2) { 99 | boolean lr = root.left.balanceFactor < 0; 100 | if (lr) { 101 | leftRotate(root.left); 102 | } 103 | rightRotate(root); 104 | } else if (bf <= -2) { 105 | boolean rl = root.right.balanceFactor > 0; 106 | if (rl) { 107 | rightRotate(root.right); 108 | } 109 | leftRotate(root); 110 | } 111 | this.root = rootParent.left; 112 | } 113 | 114 | /** 115 | * Rotates the 'root' tree node to the right side, used in the fixTree() method This rotation 116 | * preserves the binary search tree property, but is used to reduce the maximum height of the tree 117 | * 118 | * @param root The tree node to be rotated to the right 119 | */ 120 | private void rightRotate(AVLTreeNode root) { 121 | AVLTreeNode leftSide = root.left; 122 | if (root == root.parent.left) { 123 | root.parent.left = leftSide; 124 | } else { 125 | root.parent.right = leftSide; 126 | } 127 | leftSide.parent = root.parent; 128 | 129 | root.left = leftSide.right; 130 | if (leftSide.right != null) leftSide.right.parent = root; 131 | leftSide.right = root; 132 | root.parent = leftSide; 133 | 134 | updateBalanceFactorsAndHeights(root); 135 | } 136 | 137 | /** 138 | * Rotates the 'root' tree node to the left side, used in the fixTree() method This rotation 139 | * preserves the binary search tree property, but is used to reduce the maximum height of the tree 140 | * 141 | * @param root The tree node to be rotated to the left 142 | */ 143 | private void leftRotate(AVLTreeNode root) { 144 | AVLTreeNode rightSide = root.right; 145 | if (root == root.parent.left) { 146 | root.parent.left = rightSide; 147 | } else { 148 | root.parent.right = rightSide; 149 | } 150 | rightSide.parent = root.parent; 151 | 152 | root.right = rightSide.left; 153 | if (rightSide.left != null) rightSide.left.parent = root; 154 | rightSide.left = root; 155 | root.parent = rightSide; 156 | 157 | updateBalanceFactorsAndHeights(root); 158 | } 159 | 160 | /** 161 | * Inserts the value 'val' into the tree using the following steps: 1. Insert the value as though 162 | * inserting into a regular binary search tree 2. Fix the tree, starting from the root node (see 163 | * fixTree(AVLTreeNode) for how this works) 164 | * 165 | * @param val The value to insert into the tree 166 | */ 167 | public void insert(E val) { 168 | AVLTreeNode insertion = new AVLTreeNode(val), parent = root; 169 | if (root == null) { 170 | root = insertion; 171 | root.parent = rootParent; 172 | rootParent.left = root; 173 | calcBalanceFactor(root); 174 | return; 175 | } 176 | 177 | int comparison = val.compareTo(parent.val); 178 | 179 | while ((comparison < 0 && parent.left != null) || (comparison >= 0 && parent.right != null)) { 180 | if (comparison < 0) { 181 | parent = parent.left; 182 | } else { 183 | parent = parent.right; 184 | } 185 | comparison = val.compareTo(parent.val); 186 | } 187 | 188 | if (comparison < 0) { 189 | parent.left = insertion; 190 | } else { 191 | parent.right = insertion; 192 | } 193 | insertion.parent = parent; 194 | 195 | updateBalanceFactorsAndHeights(insertion); 196 | 197 | fixTree(root); 198 | } 199 | 200 | /** 201 | * Deletes the value 'val' from this tree, in 2 steps First, delete the value as if this tree were 202 | * a regular binary search tree Second, fix the tree, using the same fixing method as the 203 | * insertion 204 | * 205 | * @param val The value to be deleted 206 | */ 207 | public void delete(E val) { 208 | AVLTreeNode deletion = root, successor; 209 | if (deletion == null) return; 210 | int comparison = deletion.val.compareTo(val); 211 | while (comparison != 0) { 212 | if (comparison > 0) deletion = deletion.left; 213 | else deletion = deletion.right; 214 | if (deletion == null) return; 215 | comparison = deletion.val.compareTo(val); 216 | } 217 | 218 | if (deletion.left == null && deletion.right == null) { 219 | if (deletion == deletion.parent.left) deletion.parent.left = null; 220 | else deletion.parent.right = null; 221 | updateBalanceFactorsAndHeights(deletion.parent); 222 | } else if (deletion.left == null) { 223 | successor = deletion.right; 224 | if (deletion == deletion.parent.left) deletion.parent.left = successor; 225 | else deletion.parent.right = successor; 226 | successor.parent = deletion.parent; 227 | updateBalanceFactorsAndHeights(deletion.parent); 228 | } else if (deletion.right == null) { 229 | successor = deletion.left; 230 | if (deletion == deletion.parent.left) deletion.parent.left = successor; 231 | else deletion.parent.right = successor; 232 | successor.parent = deletion.parent; 233 | updateBalanceFactorsAndHeights(deletion.parent); 234 | } else { 235 | successor = minimum(deletion.right); 236 | AVLTreeNode rightSide = successor.right, update; 237 | if (rightSide == null) { 238 | if (successor.parent == deletion) update = successor; 239 | else update = successor.parent; 240 | } else update = rightSide; 241 | if (successor == successor.parent.right) successor.parent.right = rightSide; 242 | else successor.parent.left = rightSide; 243 | if (rightSide != null) rightSide.parent = successor.parent; 244 | 245 | successor.parent = deletion.parent; 246 | if (deletion == deletion.parent.left) successor.parent.left = successor; 247 | else successor.parent.right = successor; 248 | successor.left = deletion.left; 249 | assert successor.left != null; 250 | successor.left.parent = successor; 251 | successor.right = deletion.right; 252 | if (successor.right != null) successor.right.parent = successor; 253 | updateBalanceFactorsAndHeights(update); 254 | } 255 | root = rootParent.left; 256 | fixTree(root); 257 | } 258 | 259 | /** 260 | * Gets the tree node with the minimum value in the subtree of the given node 261 | * 262 | * @param root The subtree to search in 263 | * @return The node with minimum value in the subtree 264 | */ 265 | private AVLTreeNode minimum(AVLTreeNode root) { 266 | if (root == null || root.left == null) { 267 | return root; 268 | } else return minimum(root.left); 269 | } 270 | 271 | public int getHeight() { 272 | return getHeight(getRoot()); 273 | } 274 | 275 | private int getHeight(Node root) { 276 | if (root == null) { 277 | return 0; 278 | } 279 | 280 | return 1 + Math.max(getHeight(root.getLeft()), getHeight(root.getRight())); 281 | } 282 | 283 | /** 284 | * Checks if the tree contains a specified value, same logic as a binary search tree search 285 | * 286 | *

Time complexity: O(log n) 287 | * 288 | * @param val the value to check 289 | * @return if the tree contains the value 290 | */ 291 | public boolean contains(E val) { 292 | return contains(val, root); 293 | } 294 | 295 | private boolean contains(E val, AVLTreeNode root) { 296 | if (root == null) return false; 297 | else if (root.val == val) return true; 298 | else { 299 | if (root.val.compareTo(val) > 0) return contains(val, root.left); 300 | else return contains(val, root.right); 301 | } 302 | } 303 | 304 | private void updateBalanceFactorsAndHeights(AVLTreeNode root) { 305 | if (root == null || this.root == null) return; 306 | while (root != rootParent) { 307 | setBalanceFactors(root); 308 | root = root.parent; 309 | } 310 | } 311 | 312 | /** 313 | * Calculates the balance factor of each node in the subtree The balance factor of a given node is 314 | * calculated by taking the difference of the maximum height of the left and right subtrees The 315 | * balance factor of a leaf node is 0 316 | * 317 | * @param root The subtree in which to calculate the balance factors 318 | */ 319 | private void calcBalanceFactor(AVLTreeNode root) { 320 | if (root == null) return; 321 | calcBalanceFactor(root.left); 322 | calcBalanceFactor(root.right); 323 | setBalanceFactors(root); 324 | } 325 | 326 | private void setBalanceFactors(AVLTreeNode root) { 327 | if (root.left == null && root.right == null) { 328 | root.height = 1; 329 | root.balanceFactor = 0; 330 | } else if (root.right == null) { 331 | root.height = root.left.height + 1; 332 | root.balanceFactor = root.left.height; 333 | } else if (root.left == null) { 334 | root.height = root.right.height + 1; 335 | root.balanceFactor = -root.right.height; 336 | } else { 337 | root.height = Math.max(root.left.height, root.right.height) + 1; 338 | root.balanceFactor = root.left.height - root.right.height; 339 | } 340 | } 341 | 342 | /** 343 | * Calculates the maximum height in each of the nodes, starting from the given node The maximum 344 | * height of a node is calculated by taking the maximum of the heights of the current node's left 345 | * and right subtrees, and adding 1 The height of a leaf node is 1 346 | * 347 | * @param root The node to start calculating heights from 348 | */ 349 | private void calcHeights(AVLTreeNode root) { 350 | if (root != null) { 351 | if (root.left == null && root.right == null) { 352 | root.height = 1; 353 | } else if (root.left == null) { 354 | calcHeights(root.right); 355 | root.height = root.right.height + 1; 356 | } else if (root.right == null) { 357 | calcHeights(root.left); 358 | root.height = root.left.height + 1; 359 | } else { 360 | calcHeights(root.left); 361 | calcHeights(root.right); 362 | root.height = Math.max(root.left.height, root.right.height) + 1; 363 | } 364 | } 365 | } 366 | 367 | @Override 368 | public int numNodes() { 369 | return numNodes(root); 370 | } 371 | 372 | private int numNodes(AVLTreeNode root) { 373 | if (root == null) return 0; 374 | else return numNodes(root.left) + numNodes(root.right) + 1; 375 | } 376 | 377 | /** 378 | * Gets the maximum height of the whole tree 379 | * 380 | * @return The maximum height of the whole tree 381 | */ 382 | public int maxHeight() { 383 | return root.height; 384 | } 385 | 386 | /** 387 | * Checks if the AVL tree is valid, which means the tree has to conform to the following rules: 1. 388 | * The tree must retain the standard binary search tree properties 2. The balance factor of any 389 | * node has to belong to the set of {-1, 0, 1} (see calcBalanceFactor(AVLTreeNode root) for 390 | * definition of balance factor) 391 | * 392 | * @return Whether this tree is a valid AVL tree 393 | */ 394 | public boolean isValid() { 395 | calcBalanceFactor(root); 396 | return isValidBST(root) && isValidAVLTree(root); 397 | } 398 | 399 | private boolean isValidBST(AVLTreeNode root) { 400 | if (root == null) return true; 401 | if (root.left != null && root.left.parent != root) return false; 402 | if (root.right != null && root.right.parent != root) return false; 403 | if (root.parent.left != root && root.parent.right != root) return false; 404 | if (root.left != null && root.val.compareTo(root.left.val) < 0) return false; 405 | if (root.right != null && root.val.compareTo(root.right.val) > 0) return false; 406 | return isValidBST(root.left) && isValidBST(root.right); 407 | } 408 | 409 | private boolean isValidAVLTree(AVLTreeNode root) { 410 | if (root == null) return true; 411 | return Math.abs(root.balanceFactor) <= 1 412 | && isValidAVLTree(root.left) 413 | && isValidAVLTree(root.right); 414 | } 415 | 416 | /** 417 | * Checks if this tree is logically equal to another tree The two trees are considered equal iff 418 | * all of the values are same between both trees and the order of each value is same between both 419 | * trees 420 | * 421 | * @param o The other AVL tree to check equality 422 | * @return Whether this tree is equal to the other given AVL tree 423 | */ 424 | public boolean equals(Object o) { 425 | if (this == o) return true; 426 | if (!(o instanceof AVLTree)) return false; 427 | AVLTree avlTree = (AVLTree) o; 428 | return Objects.equals(root, avlTree.root); 429 | } 430 | 431 | public int hashCode() { 432 | return Objects.hash(root); 433 | } 434 | 435 | /** 436 | * Iterates through the values in this tree in and inorder traversal 437 | * 438 | * @return An iterator that iterates through the values in this tree in and inorder traversal 439 | */ 440 | private Iterator valIterator() { 441 | Iterator> nodeIterator = iterator(); 442 | List res = new LinkedList<>(); 443 | nodeIterator.forEachRemaining((Node node) -> res.add(node.getVal())); 444 | return res.iterator(); 445 | } 446 | 447 | /** 448 | * Iterates through the nodes in this tree in and inorder traversal 449 | * 450 | * @return An iterator that iterates through the nodes in this tree in and inorder traversal 451 | */ 452 | @Override 453 | public Iterator> iterator() { 454 | List> res = new LinkedList<>(); 455 | Queue q = new LinkedList<>(); 456 | if (root != null) q.add(root); 457 | AVLTreeNode current; 458 | while (!q.isEmpty()) { 459 | current = q.remove(); 460 | res.add(current); 461 | if (current.left != null) q.add(current.left); 462 | if (current.right != null) q.add(current.right); 463 | } 464 | return res.iterator(); 465 | } 466 | 467 | /** 468 | * Definition of a tree node held by an AVL tree Contains a value, parent node, left node, right 469 | * node, height, and balance factor Height and balance factor are updated every time a node is 470 | * inserted or deleted from the tree 471 | */ 472 | private class AVLTreeNode implements Node { 473 | /** The value of this binary search tree node. */ 474 | private E val; 475 | /** The height of this tree node */ 476 | private int height; 477 | /** The balance factor of this tree node, used in balancing the tree */ 478 | private int balanceFactor; 479 | /** The left, right, and parent of this node, respectively. */ 480 | private AVLTreeNode left, right, parent; 481 | 482 | /** 483 | * Creates a new node with a certain value. 484 | * 485 | * @param val The value of this node. 486 | */ 487 | public AVLTreeNode(E val) { 488 | this.val = val; 489 | } 490 | 491 | /** 492 | * Creates a node with copied values from the other AVL tree node 493 | * 494 | * @param other The other node to be copied 495 | */ 496 | public AVLTreeNode(AVLTreeNode other) { 497 | if (other == null) return; 498 | this.val = other.val; 499 | if (other.left != null) { 500 | this.left = new AVLTreeNode(other.left); 501 | this.left.parent = this; 502 | } 503 | if (other.right != null) { 504 | this.right = new AVLTreeNode(other.right); 505 | this.right.parent = this; 506 | } 507 | } 508 | 509 | public String toString() { 510 | return "{" + "val=" + val + '}'; 511 | } 512 | 513 | /** 514 | * Gets the value of this node. 515 | * 516 | * @return The value of this node. 517 | */ 518 | @Override 519 | public E getVal() { 520 | return val; 521 | } 522 | 523 | /** 524 | * Sets the value of this node. 525 | * 526 | * @param newVal This node's new value. 527 | */ 528 | @Override 529 | public void setVal(E newVal) { 530 | val = newVal; 531 | } 532 | 533 | /** 534 | * Gets the left child of this node. 535 | * 536 | * @return The left child of this node. 537 | */ 538 | @Override 539 | public Node getLeft() { 540 | return left; 541 | } 542 | 543 | /** 544 | * Sets the left child of this node. 545 | * 546 | * @param newNode This node's new left child. 547 | */ 548 | @Override 549 | public void setLeft(Node newNode) { 550 | left = (AVLTreeNode) newNode; 551 | } 552 | 553 | /** 554 | * Gets the right child of this node. 555 | * 556 | * @return The right child of this node. 557 | */ 558 | @Override 559 | public Node getRight() { 560 | return right; 561 | } 562 | 563 | /** 564 | * Sets the right child of this node. 565 | * 566 | * @param newNode The node's new right child. 567 | */ 568 | @Override 569 | public void setRight(Node newNode) { 570 | right = (AVLTreeNode) newNode; 571 | } 572 | 573 | /** 574 | * Gets the parent of this node. 575 | * 576 | * @return The parent of this node. 577 | */ 578 | @Override 579 | public Node getParent() { 580 | return parent; 581 | } 582 | 583 | /** 584 | * Sets the parent of this node. 585 | * 586 | * @param newNode The node's new parent. 587 | */ 588 | @Override 589 | public void setParent(Node newNode) { 590 | parent = (AVLTreeNode) newNode; 591 | } 592 | 593 | /** 594 | * Checks if this node is null. 595 | * 596 | * @return If this node is null. 597 | */ 598 | @Override 599 | public boolean isNull() { 600 | return false; 601 | } 602 | 603 | @Override 604 | public Color getColor() { 605 | return null; 606 | } 607 | 608 | @Override 609 | public void setColor(Color newColor) {} 610 | 611 | @Override 612 | public boolean equals(Object o) { 613 | if (this != o) { 614 | if (o instanceof AVLTree.AVLTreeNode) { 615 | @SuppressWarnings("unchecked") 616 | AVLTreeNode that = (AVLTreeNode) o; 617 | return Objects.equals(val, that.val) 618 | && Objects.equals(left, that.left) 619 | && Objects.equals(right, that.right); 620 | } else { 621 | return false; 622 | } 623 | } else { 624 | return true; 625 | } 626 | } 627 | 628 | @Override 629 | public int hashCode() { 630 | return Objects.hash(val, left, right); 631 | } 632 | } 633 | } 634 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/binarysearchtree/BinarySearchTree.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees.binarysearchtree; 2 | 3 | import datastructures.trees.binarysearchtrees.BinaryTree; 4 | import datastructures.trees.binarysearchtrees.Color; 5 | import datastructures.trees.binarysearchtrees.Node; 6 | import java.util.Iterator; 7 | import java.util.LinkedList; 8 | import java.util.List; 9 | import java.util.Queue; 10 | 11 | @SuppressWarnings("unused") 12 | public class BinarySearchTree> implements BinaryTree { 13 | 14 | /** Parent of the root, just used to simplify insertion and deletion algorithms */ 15 | private final BSTNode rootParent; 16 | /** The root of the binary search tree. */ 17 | private BSTNode root; 18 | 19 | /** Default constructor of this tree */ 20 | public BinarySearchTree() { 21 | root = null; 22 | rootParent = new BSTNode(null); 23 | } 24 | 25 | public static void main(String[] args) { 26 | BinarySearchTree myTree; 27 | List treeVals; 28 | int insertCorrect = 0, deleteCorrect = 0, trials = 10000; 29 | 30 | for (int i = 0; i < trials; i++) { 31 | myTree = new BinarySearchTree<>(); 32 | treeVals = new LinkedList<>(); 33 | 34 | for (int j = 0; j < 31; j++) { 35 | int a = (int) (Math.random() * 100); 36 | treeVals.add(a); 37 | myTree.insert(a); 38 | } 39 | 40 | boolean correct1 = myTree.isValid(); 41 | if (correct1) insertCorrect++; 42 | 43 | for (int j = 0; j < 10; j++) { 44 | int a = (int) (Math.random() * treeVals.size()); 45 | myTree.delete(treeVals.remove(a)); 46 | } 47 | 48 | boolean correct2 = myTree.isValid(); 49 | if (correct2) deleteCorrect++; 50 | } 51 | 52 | System.out.println( 53 | "Insertion correct percentage: " + (((double) insertCorrect) / trials * 100) + "%"); 54 | System.out.println( 55 | "Deletion correct percentage: " + (((double) deleteCorrect) / trials * 100) + "%"); 56 | } 57 | 58 | /** 59 | * Gets the root of the binary tree. 60 | * 61 | * @return The root of the tree. 62 | */ 63 | @Override 64 | public Node getRoot() { 65 | return root; 66 | } 67 | 68 | /** 69 | * Inserts a node into the binary tree. After inserting, the tree still has to follow the binary 70 | * search tree guidelines. 71 | * 72 | * @param val The value of the new node. 73 | */ 74 | @Override 75 | public void insert(E val) { 76 | if (root == null) { 77 | root = new BSTNode(val); 78 | root.parent = rootParent; 79 | rootParent.left = root; 80 | } else { 81 | Node dummy = root; 82 | 83 | while (true) { 84 | if (val.compareTo(dummy.getVal()) >= 0) { 85 | if (dummy.getRight() == null) { 86 | dummy.setRight(new BSTNode(val)); 87 | dummy.getRight().setParent(dummy); 88 | break; 89 | } 90 | 91 | dummy = dummy.getRight(); 92 | } else { 93 | if (dummy.getLeft() == null) { 94 | dummy.setLeft(new BSTNode(val)); 95 | dummy.getLeft().setParent(dummy); 96 | break; 97 | } 98 | 99 | dummy = dummy.getLeft(); 100 | } 101 | } 102 | } 103 | } 104 | 105 | /** 106 | * Deletes a node from the binary search tree. After deletion, the binary tree still has to follow 107 | * the binary search tree guidelines. If the element doesn't exist, just returns. 108 | * 109 | * @param val The value of the new node. 110 | */ 111 | @Override 112 | public void delete(E val) { 113 | // Node dummy = root; 114 | // 115 | // if(dummy == null) { 116 | // return; 117 | // } 118 | // 119 | // while(!dummy.getVal().equals(val)) { 120 | // if(val.compareTo(dummy.getVal()) > 0) { 121 | // dummy = dummy.getRight(); 122 | // } else { 123 | // dummy = dummy.getLeft(); 124 | // } 125 | // 126 | // if(dummy == null) { 127 | // return; 128 | // } 129 | // } 130 | // 131 | // if(dummy.getLeft() == null && dummy.getRight() == null) { // case 1 132 | // if(dummy == dummy.getParent().getLeft()) { 133 | // dummy.getParent().setLeft(null); 134 | // } else { 135 | // dummy.getParent().setRight(null); 136 | // } 137 | // } else if(dummy.getLeft() == null || dummy.getRight() == null) { // case 2 138 | // if(dummy.getParent() == null) { 139 | // if(dummy.getLeft() == null) { 140 | // root = root.right; 141 | // } else { 142 | // root = root.left; 143 | // } 144 | // } else if(dummy == dummy.getParent().getLeft()) { 145 | // if(dummy.getLeft() == null) { 146 | // dummy.getParent().setLeft(dummy.getRight()); 147 | // } else { 148 | // dummy.getParent().setLeft(dummy.getLeft()); 149 | // } 150 | // } else { 151 | // if(dummy.getLeft() == null) { 152 | // dummy.getParent().setRight(dummy.getRight()); 153 | // } else { 154 | // dummy.getParent().setRight(dummy.getLeft()); 155 | // } 156 | // } 157 | // } else { // case 3 158 | // Node largestInLeft = dummy.getLeft(); 159 | // 160 | // while(largestInLeft.getRight() != null) { 161 | // largestInLeft = largestInLeft.getRight(); 162 | // } 163 | // 164 | // largestInLeft.getParent().setRight(largestInLeft.getLeft()); 165 | // dummy.setVal(largestInLeft.getVal()); 166 | // } 167 | 168 | BSTNode deletion = root, successor; 169 | if (deletion == null) return; 170 | int comparison = deletion.val.compareTo(val); 171 | while (comparison != 0) { 172 | if (comparison > 0) deletion = deletion.left; 173 | else deletion = deletion.right; 174 | if (deletion == null) return; 175 | comparison = deletion.val.compareTo(val); 176 | } 177 | 178 | if (deletion.left == null && deletion.right == null) { 179 | if (deletion == deletion.parent.left) deletion.parent.left = null; 180 | else deletion.parent.right = null; 181 | } else if (deletion.left == null) { 182 | successor = deletion.right; 183 | if (deletion == deletion.parent.left) deletion.parent.left = successor; 184 | else deletion.parent.right = successor; 185 | successor.parent = deletion.parent; 186 | } else if (deletion.right == null) { 187 | successor = deletion.left; 188 | if (deletion == deletion.parent.left) deletion.parent.left = successor; 189 | else deletion.parent.right = successor; 190 | successor.parent = deletion.parent; 191 | } else { 192 | successor = minimum(deletion.right); 193 | BSTNode rightSide = successor.right, update; 194 | if (rightSide == null) { 195 | if (successor.parent == deletion) // noinspection UnusedAssignment 196 | update = successor; 197 | else //noinspection UnusedAssignment 198 | update = successor.parent; 199 | } else //noinspection UnusedAssignment 200 | update = rightSide; 201 | if (successor == successor.parent.right) successor.parent.right = rightSide; 202 | else successor.parent.left = rightSide; 203 | if (rightSide != null) rightSide.parent = successor.parent; 204 | 205 | successor.parent = deletion.parent; 206 | if (deletion == deletion.parent.left) successor.parent.left = successor; 207 | else successor.parent.right = successor; 208 | successor.left = deletion.left; 209 | assert successor.left != null; 210 | successor.left.parent = successor; 211 | successor.right = deletion.right; 212 | if (successor.right != null) successor.right.parent = successor; 213 | } 214 | root = rootParent.left; 215 | } 216 | 217 | private BSTNode minimum(BSTNode root) { 218 | if (root == null || root.left == null) { 219 | return root; 220 | } else return minimum(root.left); 221 | } 222 | 223 | public int getHeight() { 224 | return getHeight(getRoot()); 225 | } 226 | 227 | public boolean isValid() { 228 | return isValid(root); 229 | } 230 | 231 | private boolean isValid(BSTNode root) { 232 | if (root == null) return true; 233 | if (root.left != null && root.left.parent != root) return false; 234 | if (root.right != null && root.right.parent != root) return false; 235 | if (root.parent != null && (root.parent.left != root && root.parent.right != root)) 236 | return false; 237 | if (root.left != null && root.val.compareTo(root.left.val) <= 0) return false; 238 | if (root.right != null && root.val.compareTo(root.right.val) > 0) return false; 239 | return isValid(root.left) && isValid(root.right); 240 | } 241 | 242 | private int getHeight(Node root) { 243 | if (root == null) { 244 | return 0; 245 | } 246 | 247 | return 1 + Math.max(getHeight(root.getLeft()), getHeight(root.getRight())); 248 | } 249 | 250 | /** 251 | * Searches for a node in the tree. 252 | * 253 | * @param val The value of the node you are searching for. 254 | * @return True if the node exists, false otherwise. 255 | */ 256 | @Override 257 | public boolean contains(E val) { 258 | BSTNode dummy = root; 259 | 260 | while (dummy != null) { 261 | if (dummy.val.equals(val)) { 262 | return true; 263 | } else { 264 | if (val.compareTo(dummy.val) > 0) { 265 | dummy = dummy.right; 266 | } else { 267 | dummy = dummy.left; 268 | } 269 | } 270 | } 271 | 272 | return false; 273 | } 274 | 275 | protected Node search(E val) { 276 | if (!contains(val)) return null; 277 | BSTNode dummy = root; 278 | 279 | while (dummy != null) { 280 | if (dummy.val.equals(val)) { 281 | return dummy; 282 | } else { 283 | if (val.compareTo(dummy.val) > 0) { 284 | dummy = dummy.right; 285 | } else { 286 | dummy = dummy.left; 287 | } 288 | } 289 | } 290 | 291 | return null; 292 | } 293 | 294 | @Override 295 | public int numNodes() { 296 | return numNodes(root); 297 | } 298 | 299 | private int numNodes(BSTNode root) { 300 | if (root == null) return 0; 301 | else return numNodes(root.left) + numNodes(root.right) + 1; 302 | } 303 | 304 | @Override 305 | public Iterator> iterator() { 306 | 307 | List> res = new LinkedList<>(); 308 | Queue q = new LinkedList<>(); 309 | if (root != null) q.add(root); 310 | BSTNode current; 311 | while (!q.isEmpty()) { 312 | current = q.remove(); 313 | res.add(current); 314 | if (current.left != null) q.add(current.left); 315 | if (current.right != null) q.add(current.right); 316 | } 317 | return res.iterator(); 318 | } 319 | 320 | /** 321 | * A standard binary search tree node. It has the basic necessities like val, left, right, and 322 | * parent. 323 | */ 324 | protected class BSTNode implements Node { 325 | 326 | /** The value of this binary search tree node. */ 327 | private E val; 328 | 329 | /** The left, right, and parent of this node, respectively. */ 330 | private BSTNode left, right, parent; 331 | 332 | /** 333 | * Creates a new node with a certain value. 334 | * 335 | * @param val The value of this node. 336 | */ 337 | public BSTNode(E val) { 338 | this.val = val; 339 | } 340 | 341 | /** 342 | * Gets the value of this node. 343 | * 344 | * @return The value of this node. 345 | */ 346 | @Override 347 | public E getVal() { 348 | return val; 349 | } 350 | 351 | /** 352 | * Sets the value of this node. 353 | * 354 | * @param newVal This node's new value. 355 | */ 356 | @Override 357 | public void setVal(E newVal) { 358 | val = newVal; 359 | } 360 | 361 | /** 362 | * Gets the left child of this node. 363 | * 364 | * @return The left child of this node. 365 | */ 366 | @Override 367 | public Node getLeft() { 368 | return left; 369 | } 370 | 371 | /** 372 | * Sets the left child of this node. 373 | * 374 | * @param newNode This node's new left child. 375 | */ 376 | @Override 377 | public void setLeft(Node newNode) { 378 | left = (BSTNode) newNode; 379 | } 380 | 381 | /** 382 | * Gets the right child of this node. 383 | * 384 | * @return The right child of this node. 385 | */ 386 | @Override 387 | public Node getRight() { 388 | return right; 389 | } 390 | 391 | /** 392 | * Sets the right child of this node. 393 | * 394 | * @param newNode The node's new right child. 395 | */ 396 | @Override 397 | public void setRight(Node newNode) { 398 | right = (BSTNode) newNode; 399 | } 400 | 401 | /** 402 | * Gets the parent of this node. 403 | * 404 | * @return The parent of this node. 405 | */ 406 | @Override 407 | public Node getParent() { 408 | return parent; 409 | } 410 | 411 | /** 412 | * Sets the parent of this node. 413 | * 414 | * @param newNode The node's new parent. 415 | */ 416 | @Override 417 | public void setParent(Node newNode) { 418 | parent = (BSTNode) newNode; 419 | } 420 | 421 | /** 422 | * Checks if this node is null. 423 | * 424 | * @return If this node is null. 425 | */ 426 | @Override 427 | public boolean isNull() { 428 | return false; 429 | } 430 | 431 | /** Binary search trees don't have colored nodes. */ 432 | @Override 433 | public Color getColor() { 434 | return null; 435 | } 436 | 437 | /** Binary search trees don't have colored nodes. */ 438 | @Override 439 | public void setColor(Color newColor) {} 440 | } 441 | } 442 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/binarysearchtrees/redblacktree/RedBlackTree.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.binarysearchtrees.redblacktree; 2 | 3 | import datastructures.trees.binarysearchtrees.BinaryTree; 4 | import datastructures.trees.binarysearchtrees.Color; 5 | import datastructures.trees.binarysearchtrees.Node; 6 | import java.util.*; 7 | 8 | @SuppressWarnings("ALL") 9 | public class RedBlackTree> implements BinaryTree, Iterable> { 10 | public final RBTreeNode NIL = new RBTreeNode(null, Color.BLACK, null, null, null); 11 | private RBTreeNode root; 12 | 13 | public RedBlackTree() { 14 | root = NIL; 15 | } 16 | 17 | public static void main(String[] args) { 18 | RedBlackTree myTree; 19 | List treeVals; 20 | int insertCorrect = 0, deleteCorrect = 0, trials = 10000; 21 | 22 | for (int i = 0; i < trials; i++) { 23 | myTree = new RedBlackTree<>(); 24 | treeVals = new LinkedList<>(); 25 | 26 | for (int j = 0; j < 31; j++) { 27 | int a = (int) (Math.random() * 100); 28 | treeVals.add(a); 29 | myTree.insert(a); 30 | } 31 | 32 | boolean correct1 = myTree.isValid() && myTree.NIL.getParent() == null; 33 | if (correct1) insertCorrect++; 34 | 35 | for (int j = 0; j < 10; j++) { 36 | int a = (int) (Math.random() * treeVals.size()); 37 | myTree.delete(treeVals.remove(a)); 38 | } 39 | 40 | boolean correct2 = myTree.isValid(); 41 | if (correct2) deleteCorrect++; 42 | } 43 | 44 | System.out.println( 45 | "Insertion correct percentage: " + (((double) insertCorrect) / trials * 100) + "%"); 46 | System.out.println( 47 | "Deletion correct percentage: " + (((double) deleteCorrect) / trials * 100) + "%"); 48 | } 49 | 50 | public Node getRoot() { 51 | return root; 52 | } 53 | 54 | public void insert(E val) { 55 | RBTreeNode x = root, y = NIL; 56 | 57 | while (x != NIL) { 58 | //noinspection SuspiciousNameCombination 59 | y = x; 60 | 61 | if (x.getVal().compareTo(val) > 0) { 62 | x = x.left; 63 | } else x = x.right; 64 | } 65 | 66 | RBTreeNode z = new RBTreeNode(val, Color.RED, y, NIL, NIL); 67 | 68 | if (y == NIL) { 69 | root = z; 70 | } else if (z.getVal().compareTo(y.getVal()) < 0) { 71 | y.left = z; 72 | } else { 73 | y.right = z; 74 | } 75 | insertFix(z); 76 | } 77 | 78 | private void insertFix(RBTreeNode z) { 79 | RBTreeNode y; 80 | while (z.parent.color == Color.RED) { 81 | if (z.parent == z.parent.parent.left) { 82 | //noinspection SuspiciousNameCombination 83 | y = z.parent.parent.right; 84 | if (y.color == Color.RED) { 85 | z.parent.color = Color.BLACK; 86 | y.color = Color.BLACK; 87 | z.parent.parent.color = Color.RED; 88 | z = z.parent.parent; 89 | } else { 90 | if (z == z.parent.right) { 91 | z = z.parent; 92 | leftRotate(z); 93 | } 94 | z.parent.color = Color.BLACK; 95 | z.parent.parent.color = Color.RED; 96 | rightRotate(z.parent.parent); 97 | } 98 | } else { 99 | //noinspection SuspiciousNameCombination 100 | y = z.parent.parent.left; 101 | if (y.color == Color.RED) { 102 | z.parent.color = Color.BLACK; 103 | y.color = Color.BLACK; 104 | z.parent.parent.color = Color.RED; 105 | z = z.parent.parent; 106 | } else { 107 | if (z == z.parent.left) { 108 | z = z.parent; 109 | rightRotate(z); 110 | } 111 | z.parent.color = Color.BLACK; 112 | z.parent.parent.color = Color.RED; 113 | leftRotate(z.parent.parent); 114 | } 115 | } 116 | } 117 | root.setColor(Color.BLACK); 118 | NIL.setParent(null); 119 | } 120 | 121 | private void leftRotate(RBTreeNode x) { 122 | @SuppressWarnings("SuspiciousNameCombination") 123 | RBTreeNode y = x.right; 124 | x.setRight(y.getLeft()); 125 | if (y.getLeft() != NIL) y.getLeft().setParent(x); 126 | y.setParent(x.getParent()); 127 | if (x.getParent() == NIL) root = y; 128 | if (x == x.getParent().getLeft()) // noinspection SuspiciousNameCombination 129 | x.getParent().setLeft(y); 130 | else //noinspection SuspiciousNameCombination 131 | x.getParent().setRight(y); 132 | y.setLeft(x); 133 | x.setParent(y); 134 | } 135 | 136 | private void rightRotate(RBTreeNode y) { 137 | RBTreeNode x = y.left; 138 | y.left = x.right; 139 | if (x.right != NIL) x.right.parent = y; 140 | x.parent = y.parent; 141 | if (y.parent == NIL) root = x; 142 | if (y == y.parent.left) y.parent.left = x; 143 | else y.parent.right = x; 144 | //noinspection SuspiciousNameCombination 145 | x.right = y; 146 | y.parent = x; 147 | } 148 | 149 | public void delete(E key) { 150 | RBTreeNode z; 151 | if ((z = ((RBTreeNode) search(key, root))) == null) return; 152 | RBTreeNode x; 153 | RBTreeNode y = z; // temporary reference y 154 | Color y_original_color = y.getColor(); 155 | 156 | if (z.getLeft() == NIL) { 157 | x = z.getRight(); 158 | transplant(z, z.getRight()); 159 | } else if (z.getRight() == NIL) { 160 | x = z.getLeft(); 161 | transplant(z, z.getLeft()); 162 | } else { 163 | y = successor(z.getRight()); 164 | y_original_color = y.getColor(); 165 | x = y.getRight(); 166 | if (y.getParent() == z) x.setParent(y); 167 | else { 168 | transplant(y, y.getRight()); 169 | y.setRight(z.getRight()); 170 | y.getRight().setParent(y); 171 | } 172 | transplant(z, y); 173 | y.setLeft(z.getLeft()); 174 | y.getLeft().setParent(y); 175 | y.setColor(z.getColor()); 176 | } 177 | if (y_original_color == Color.BLACK) deleteFix(x); 178 | } 179 | 180 | private void deleteFix(RBTreeNode x) { 181 | while (x != root && x.getColor() == Color.BLACK) { 182 | if (x == x.getParent().getLeft()) { 183 | RBTreeNode w = x.getParent().getRight(); 184 | if (w.getColor() == Color.RED) { 185 | w.setColor(Color.BLACK); 186 | x.getParent().setColor(Color.RED); 187 | leftRotate(x.parent); 188 | w = x.getParent().getRight(); 189 | } 190 | if (w.getLeft().getColor() == Color.BLACK && w.getRight().getColor() == Color.BLACK) { 191 | w.setColor(Color.RED); 192 | x = x.getParent(); 193 | continue; 194 | } else if (w.getRight().getColor() == Color.BLACK) { 195 | w.getLeft().setColor(Color.BLACK); 196 | w.setColor(Color.RED); 197 | rightRotate(w); 198 | w = x.getParent().getRight(); 199 | } 200 | if (w.getRight().getColor() == Color.RED) { 201 | w.setColor(x.getParent().getColor()); 202 | x.getParent().setColor(Color.BLACK); 203 | w.getRight().setColor(Color.BLACK); 204 | leftRotate(x.getParent()); 205 | x = root; 206 | } 207 | } else { 208 | RBTreeNode w = (x.getParent().getLeft()); 209 | if (w.color == Color.RED) { 210 | w.color = Color.BLACK; 211 | x.getParent().setColor(Color.RED); 212 | rightRotate(x.getParent()); 213 | w = (x.getParent()).getLeft(); 214 | } 215 | if (w.right.color == Color.BLACK && w.left.color == Color.BLACK) { 216 | w.color = Color.RED; 217 | x = x.getParent(); 218 | continue; 219 | } else if (w.left.color == Color.BLACK) { 220 | w.right.color = Color.BLACK; 221 | w.color = Color.RED; 222 | leftRotate(w); 223 | w = (x.getParent().getLeft()); 224 | } 225 | if (w.left.color == Color.RED) { 226 | w.color = x.getParent().getColor(); 227 | x.getParent().setColor(Color.BLACK); 228 | w.left.color = Color.BLACK; 229 | rightRotate(x.getParent()); 230 | x = root; 231 | } 232 | } 233 | } 234 | x.setColor(Color.BLACK); 235 | } 236 | 237 | private RBTreeNode successor(RBTreeNode root) { 238 | if (root == NIL || root.left == NIL) return root; 239 | else return successor(root.left); 240 | } 241 | 242 | private void transplant(RBTreeNode u, RBTreeNode v) { 243 | // if (u.parent == null) System.out.println(u); 244 | if (u.parent == NIL) { 245 | root = v; 246 | } else if (u == u.parent.left) { 247 | u.parent.left = v; 248 | } else u.parent.right = v; 249 | v.parent = u.parent; 250 | } 251 | 252 | public boolean contains(E val) { 253 | return contains(val, root); 254 | } 255 | 256 | private boolean contains(E val, RBTreeNode root) { 257 | if (root == NIL) return false; 258 | else if (root.getVal().equals(val)) return true; 259 | else if (root.getVal().compareTo(val) < 0) return contains(val, root.right); 260 | else return contains(val, root.left); 261 | } 262 | 263 | private Node search(E val, RBTreeNode root) { 264 | if (root == NIL) return NIL; 265 | else if (root.getVal().equals(val)) return root; 266 | else if (root.getVal().compareTo(val) < 0) return search(val, root.right); 267 | else return search(val, root.left); 268 | } 269 | 270 | public boolean isValid() { 271 | return root.getColor() == Color.BLACK && checkAdjacentReds() && checkBlackHeights(); 272 | } 273 | 274 | private boolean checkAdjacentReds() { 275 | for (Node node : this) { 276 | RBTreeNode current = ((RBTreeNode) node); 277 | if (current == NIL) continue; 278 | if (current.color == Color.RED) { 279 | if (current.parent.color == Color.RED 280 | || current.getLeft().color == Color.RED 281 | || current.getRight().color == Color.RED) { 282 | return false; 283 | } 284 | } 285 | } 286 | 287 | return true; 288 | } 289 | 290 | public int getHeight() { 291 | return getHeight(getRoot()); 292 | } 293 | 294 | private int getHeight(Node root) { 295 | if (root == null || root.isNull()) { 296 | return 0; 297 | } 298 | 299 | return 1 + Math.max(getHeight(root.getLeft()), getHeight(root.getRight())); 300 | } 301 | 302 | private boolean checkBlackHeights() { 303 | for (Node node : this) { 304 | if (node == NIL) continue; 305 | if (countBlacks(root.getLeft()) != countBlacks(root.getRight())) return false; 306 | } 307 | 308 | return true; 309 | } 310 | 311 | private int countBlacks(RBTreeNode root) { 312 | if (root == NIL) { 313 | return 1; 314 | } 315 | return (root.color == Color.BLACK ? 1 : 0) 316 | + Math.max(countBlacks(root.getLeft()), countBlacks(root.getRight())); 317 | } 318 | 319 | public Iterator> iterator() { 320 | Queue> queue = new LinkedList<>(); 321 | List> res = new LinkedList<>(); 322 | queue.offer(root); 323 | Node current; 324 | while (!queue.isEmpty()) { 325 | current = queue.poll(); 326 | if (current == NIL) { 327 | res.add(NIL); 328 | } else { 329 | res.add(current); 330 | queue.offer(current.getLeft()); 331 | queue.offer(current.getRight()); 332 | } 333 | } 334 | return res.iterator(); 335 | } 336 | 337 | public String toString() { 338 | StringBuilder sb = new StringBuilder(); 339 | int length = 0; 340 | sb.append("["); 341 | for (Node node : this) { 342 | sb.append(node).append(", "); 343 | length += node.toString().length(); 344 | if (length >= 40) { 345 | sb.append("\n"); 346 | length = 0; 347 | } 348 | } 349 | sb.deleteCharAt(sb.length() - 1); 350 | sb.append("]"); 351 | return sb.toString(); 352 | } 353 | 354 | @SuppressWarnings("unused") 355 | @Override 356 | public int numNodes() { 357 | return numNodes(root); 358 | } 359 | 360 | private int numNodes(RBTreeNode root) { 361 | if (root == null) return 0; 362 | else return numNodes(root.left) + numNodes(root.right) + 1; 363 | } 364 | 365 | private class RBTreeNode implements Node { 366 | public E val; 367 | public RBTreeNode left, right, parent; 368 | public Color color; 369 | 370 | RBTreeNode(E key, Color color, RBTreeNode parent, RBTreeNode left, RBTreeNode right) { 371 | this.val = key; 372 | this.color = color; 373 | 374 | if (parent == null && left == null && right == null) { 375 | parent = this; 376 | left = this; 377 | right = this; 378 | } 379 | 380 | this.parent = parent; 381 | this.left = left; 382 | this.right = right; 383 | } 384 | 385 | public String toString() { 386 | return "{" + val + ", " + color + '}'; 387 | } 388 | 389 | public E getVal() { 390 | return val; 391 | } 392 | 393 | public void setVal(E val) { 394 | this.val = val; 395 | } 396 | 397 | public boolean isNull() { 398 | return this == NIL; 399 | } 400 | 401 | public RBTreeNode getLeft() { 402 | return left; 403 | } 404 | 405 | public void setLeft(Node left) { 406 | this.left = ((RBTreeNode) left); 407 | } 408 | 409 | public RBTreeNode getRight() { 410 | return right; 411 | } 412 | 413 | public void setRight(Node right) { 414 | this.right = ((RBTreeNode) right); 415 | } 416 | 417 | public RBTreeNode getParent() { 418 | return parent; 419 | } 420 | 421 | public void setParent(Node parent) { 422 | this.parent = ((RBTreeNode) parent); 423 | } 424 | 425 | @Override 426 | public Color getColor() { 427 | return color; 428 | } 429 | 430 | @Override 431 | public void setColor(Color newColor) { 432 | color = newColor; 433 | } 434 | 435 | @Override 436 | public int hashCode() { 437 | if (this == NIL) return 0; 438 | return Objects.hash(val, left, right, color); 439 | } 440 | } 441 | } 442 | -------------------------------------------------------------------------------- /src/main/java/datastructures/trees/trie/Trie.java: -------------------------------------------------------------------------------- 1 | package datastructures.trees.trie; 2 | 3 | public class Trie { 4 | 5 | private final TrieNode root; 6 | 7 | public Trie() { 8 | root = new TrieNode(); 9 | } 10 | 11 | public static void main(String[] args) { 12 | Trie trie = new Trie(); 13 | trie.insert("app"); 14 | trie.insert("apple"); 15 | trie.insert("ap"); 16 | 17 | System.out.println(trie.contains("app")); 18 | System.out.println(trie.contains("appl")); 19 | } 20 | 21 | public void insert(String s) { 22 | TrieNode curr = root; 23 | char[] chars = s.toCharArray(); 24 | 25 | for (char aChar : chars) { 26 | int index = aChar - 'a'; 27 | 28 | if (curr.children[index] == null) { 29 | curr.children[index] = new TrieNode(); 30 | } 31 | 32 | curr = curr.children[index]; 33 | } 34 | 35 | curr.count++; 36 | } 37 | 38 | public boolean contains(String s) { 39 | TrieNode curr = root; 40 | char[] chars = s.toCharArray(); 41 | 42 | for (char aChar : chars) { 43 | int index = aChar - 'a'; 44 | 45 | if (curr.children[index] == null) { 46 | return false; 47 | } 48 | 49 | curr = curr.children[index]; 50 | } 51 | 52 | return curr.count > 0; 53 | } 54 | 55 | private static class TrieNode { 56 | public final TrieNode[] children; 57 | public int count; 58 | 59 | public TrieNode() { 60 | children = new TrieNode[26]; 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/datastructures/unionfind/UnionFind.java: -------------------------------------------------------------------------------- 1 | package datastructures.unionfind; 2 | 3 | public class UnionFind { 4 | 5 | /** Holds the parent node for each element. The parent of root nodes is itself. */ 6 | private final int[] parent; 7 | 8 | /** 9 | * Constructs a {@link UnionFind} with a given size. 10 | * 11 | * @param size The size of the {@link UnionFind}. 12 | */ 13 | public UnionFind(int size) { 14 | if (size <= 0) { 15 | throw new IllegalArgumentException("Can't Have Size Of 0 Or Negative"); 16 | } 17 | 18 | parent = new int[size]; 19 | 20 | // Initially every node is a root, so set the parent to itself. 21 | for (int i = 0; i < size; i++) { 22 | parent[i] = i; 23 | } 24 | } 25 | 26 | /** 27 | * Finds the group number of an element. 28 | * 29 | * @param element The element that you're trying to find the group of. 30 | * @return The group number of the element. 31 | */ 32 | public int find(int element) { 33 | if (parent[element] 34 | == element) { // Parent is itself means that it's a root, meaning that is the group number. 35 | return element; 36 | } else { 37 | parent[element] = 38 | find( 39 | parent[ 40 | element]); // Path compression: set the parent of this element to the root of the 41 | // tree (the find method gets the root). 42 | return parent[element]; // The parent is now the root, so return the parent. 43 | } 44 | } 45 | 46 | /** 47 | * Merges two groups together. 48 | * 49 | * @param element1 The first group, which is the group of the element given. 50 | * @param element2 The second group, which is the group of the element given. 51 | */ 52 | public void union(int element1, int element2) { 53 | int root1 = find(element1); 54 | int root2 = find(element2); 55 | 56 | if (root1 == root2) { // They are already in the same group. 57 | return; 58 | } 59 | 60 | parent[root2] = root1; // Merge the two trees. 61 | } 62 | 63 | /** 64 | * Checks if two elements are in the same group. 65 | * 66 | * @param element1 The first element. 67 | * @param element2 The second element. 68 | * @return True if they are part of the same group, false if they are not. 69 | */ 70 | public boolean sameGroup(int element1, int element2) { 71 | return find(element1) == find(element2); 72 | } 73 | 74 | /** 75 | * Checks if the {@link UnionFind} is currently path compressed. That means that every tree must 76 | * have only two levels. 77 | * 78 | * @return True if the {@link UnionFind} is path compressed. 79 | */ 80 | public boolean isPathCompressed() { 81 | for (int i = 0; i < parent.length; i++) { 82 | if (parent[i] != parent[parent[i]]) { 83 | return false; 84 | } 85 | } 86 | 87 | return true; 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/main/java/dynmanicprogramming/problems/houserobber/HouseRobber.java: -------------------------------------------------------------------------------- 1 | package dynmanicprogramming.problems.houserobber; 2 | 3 | import java.util.Arrays; 4 | 5 | public class HouseRobber { 6 | 7 | private static int[] memo; 8 | 9 | public static int bruteForce(int[] arr) { 10 | return bruteForce(arr, 0); 11 | } 12 | 13 | private static int bruteForce(int[] arr, int i) { 14 | if (i >= arr.length) { 15 | return 0; 16 | } 17 | 18 | return Math.max(arr[i] + bruteForce(arr, i + 2), bruteForce(arr, i + 1)); 19 | } 20 | 21 | public static int memoized(int[] arr) { 22 | memo = new int[arr.length]; 23 | Arrays.fill(memo, -1); 24 | return memoized(arr, 0); 25 | } 26 | 27 | public static int memoized(int[] arr, int i) { 28 | if (i >= arr.length) { 29 | return 0; 30 | } 31 | 32 | if (memo[i] >= 0) { 33 | return memo[i]; 34 | } 35 | 36 | return memo[i] = Math.max(arr[i] + bruteForce(arr, i + 2), bruteForce(arr, i + 1)); 37 | } 38 | 39 | public static int bottomUp(int[] arr) { 40 | if (arr.length == 0) { 41 | return 0; 42 | } else if (arr.length == 1) { 43 | return arr[0]; 44 | } else if (arr.length == 2) { 45 | return Math.max(arr[0], arr[1]); 46 | } 47 | 48 | int[] dp = new int[arr.length]; 49 | dp[0] = arr[0]; 50 | dp[1] = Math.max(arr[0], arr[1]); 51 | 52 | for (int i = 2; i < dp.length; i++) { 53 | dp[i] = Math.max(dp[i - 2] + arr[i], dp[i - 1]); 54 | } 55 | 56 | return dp[dp.length - 1]; 57 | } 58 | 59 | public static void main(String[] args) { 60 | int[] arr = new int[] {1, 2, 3, 1, 4, 1, 3, 4, 3, 2, 1, 3, 4, 3, 2, 1}; // The answer is 21 61 | System.out.println(HouseRobber.memoized(arr)); 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/main/java/dynmanicprogramming/problems/uniquepaths/UniquePaths.java: -------------------------------------------------------------------------------- 1 | package dynmanicprogramming.problems.uniquepaths; 2 | 3 | public class UniquePaths { 4 | 5 | private static int[][] memo; 6 | 7 | public static int topDown(int m, int n) { 8 | memo = new int[m][n]; 9 | return topDown(0, 0, m, n); 10 | } 11 | 12 | private static int topDown(int i, int j, int m, int n) { 13 | if (i >= m || j >= n) { 14 | return 0; 15 | } 16 | 17 | if (i == m - 1 && j == n - 1) { 18 | return 1; 19 | } 20 | 21 | if (memo[i][j] > 0) { 22 | return memo[i][j]; 23 | } 24 | 25 | return memo[i][j] = topDown(i + 1, j, m, n) + topDown(i, j + 1, m, n); 26 | } 27 | 28 | public static int bottomUp(int m, int n) { 29 | int[][] dp = new int[m][n]; 30 | dp[m - 1][n - 1] = 1; 31 | 32 | for (int i = dp.length - 1; i >= 0; i--) { 33 | for (int j = dp[0].length - 1; j >= 0; j--) { 34 | if (i != m - 1 || j != n - 1) { 35 | if (i == m - 1) { 36 | dp[i][j] = dp[i][j + 1]; 37 | } else if (j == n - 1) { 38 | dp[i][j] = dp[i + 1][j]; 39 | } else { 40 | dp[i][j] = dp[i + 1][j] + dp[i][j + 1]; 41 | } 42 | } 43 | } 44 | } 45 | 46 | return dp[0][0]; 47 | } 48 | 49 | public static void main(String[] args) { 50 | System.out.println(UniquePaths.topDown(7, 3)); 51 | System.out.println(UniquePaths.bottomUp(7, 3)); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/graphtheory/shortestpathalgorithms/singlesource/dijkstrasshortestpath/DijkstrasShortestPath.java: -------------------------------------------------------------------------------- 1 | package graphtheory.shortestpathalgorithms.singlesource.dijkstrasshortestpath; 2 | 3 | import datastructures.minpriorityqueue.minheap.MinHeap; 4 | import java.util.Arrays; 5 | import java.util.Comparator; 6 | import java.util.LinkedList; 7 | 8 | public class DijkstrasShortestPath { 9 | 10 | @SuppressWarnings("unchecked") 11 | public static LinkedList[] makeGraph(int V, int[][] E) { 12 | LinkedList[] graph = new LinkedList[V]; 13 | 14 | for (int i = 0; i < graph.length; i++) { 15 | graph[i] = new LinkedList<>(); 16 | } 17 | 18 | for (int[] edge : E) { 19 | graph[edge[0]].add(new int[] {edge[1], edge[2]}); 20 | } 21 | 22 | return graph; 23 | } 24 | 25 | public static int[] dijkstrasShortestPath(LinkedList[] graph, int source) { 26 | int[] weights = new int[graph.length]; 27 | boolean[] states = new boolean[graph.length]; 28 | 29 | Arrays.fill(weights, Integer.MAX_VALUE); 30 | weights[source] = 0; 31 | MinHeap heap = new MinHeap<>(Comparator.comparingInt(i -> weights[i])); 32 | for (int i = 0; i < graph.length; i++) { 33 | heap.add(i); 34 | } 35 | 36 | while (!heap.isEmpty()) { 37 | int nextVertex = heap.extractMin(); 38 | int nextWeight = weights[nextVertex]; 39 | 40 | states[nextVertex] = true; 41 | 42 | for (int[] edge : graph[nextVertex]) { 43 | int neighbor = edge[0]; 44 | int weight = edge[1]; 45 | 46 | if (nextWeight + weight < weights[neighbor]) { 47 | weights[neighbor] = nextWeight + weight; 48 | 49 | if (!states[neighbor]) { 50 | heap.update(neighbor); 51 | } 52 | } 53 | } 54 | } 55 | 56 | return weights; 57 | } 58 | 59 | public static void main(String[] args) { 60 | // 0 -> 1 ↘ 61 | // ⬇ ↗ 3 62 | // 2 -> 4 ↗ 63 | 64 | LinkedList[] graph = 65 | makeGraph( 66 | 5, 67 | new int[][] { 68 | {0, 1, 6}, 69 | {0, 2, 1}, 70 | {2, 1, 1}, 71 | {1, 3, 7}, 72 | {4, 3, 2}, 73 | {2, 4, 3} 74 | }); 75 | 76 | System.out.println(Arrays.toString(dijkstrasShortestPath(graph, 0))); 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/main/java/graphtheory/traversals/breadthfirstsearch/BreadthFirstSearch.java: -------------------------------------------------------------------------------- 1 | package graphtheory.traversals.breadthfirstsearch; 2 | 3 | import java.util.Arrays; 4 | import java.util.LinkedList; 5 | import java.util.Queue; 6 | 7 | @SuppressWarnings("rawtypes") 8 | public class BreadthFirstSearch { 9 | 10 | @SuppressWarnings("unchecked") 11 | public static LinkedList[] makeGraph(int V, int[][] E) { 12 | LinkedList[] graph = new LinkedList[V]; 13 | 14 | for (int i = 0; i < graph.length; i++) { 15 | graph[i] = new LinkedList(); 16 | } 17 | 18 | for (int[] edge : E) { 19 | graph[edge[0]].add(edge[1]); 20 | } 21 | 22 | return graph; 23 | } 24 | 25 | public static int[] breadthFirstSearch(LinkedList[] graph, int start) { 26 | int[] state = new int[graph.length]; 27 | Queue queue = new LinkedList<>(); 28 | queue.offer(start); 29 | state[start] = 1; 30 | int k = 0; 31 | int[] res = new int[graph.length]; 32 | 33 | while (!queue.isEmpty()) { 34 | int next = queue.poll(); 35 | state[next] = 2; 36 | res[k++] = next; 37 | 38 | for (int neighbor : graph[next]) { 39 | if (state[neighbor] == 0) { 40 | state[neighbor] = 1; 41 | queue.offer(neighbor); 42 | } 43 | } 44 | } 45 | 46 | return res; 47 | } 48 | 49 | public static void main(String[] args) { 50 | // 0 -> 1 ↘ 51 | // ⬇ ↙ 2 52 | // 4 <- 3 ↙ 53 | 54 | LinkedList[] graph = 55 | makeGraph(5, new int[][] {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {0, 4}, {1, 4}}); 56 | int[] search = breadthFirstSearch(graph, 0); 57 | System.out.println(Arrays.toString(search)); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/graphtheory/traversals/depthfirstsearch/DepthFirstSearch.java: -------------------------------------------------------------------------------- 1 | package graphtheory.traversals.depthfirstsearch; 2 | 3 | import java.util.Arrays; 4 | import java.util.LinkedList; 5 | import java.util.Stack; 6 | 7 | @SuppressWarnings("rawtypes") 8 | public class DepthFirstSearch { 9 | 10 | @SuppressWarnings("unchecked") 11 | public static LinkedList[] makeGraph(int V, int[][] E) { 12 | LinkedList[] graph = new LinkedList[V]; 13 | 14 | for (int i = 0; i < graph.length; i++) { 15 | graph[i] = new LinkedList(); 16 | } 17 | 18 | for (int[] edge : E) { 19 | graph[edge[0]].add(edge[1]); 20 | } 21 | 22 | return graph; 23 | } 24 | 25 | public static int[] depthFirstSearch(LinkedList[] graph, int start) { 26 | int[] state = new int[graph.length]; 27 | Stack stack = new Stack<>(); 28 | stack.push(start); 29 | state[start] = 1; 30 | int k = 0; 31 | int[] res = new int[graph.length]; 32 | 33 | while (!stack.isEmpty()) { 34 | int next = stack.pop(); 35 | state[next] = 2; 36 | res[k++] = next; 37 | 38 | for (int neighbor : graph[next]) { 39 | if (state[neighbor] == 0) { 40 | state[neighbor] = 1; 41 | stack.push(neighbor); 42 | } 43 | } 44 | } 45 | 46 | return res; 47 | } 48 | 49 | public static void main(String[] args) { 50 | // 0 -> 1 ↘ 51 | // ⬇ ↙ 2 -> 5 52 | // 4 <- 3 ↙ 53 | 54 | LinkedList[] graph = 55 | makeGraph(6, new int[][] {{0, 1}, {1, 2}, {2, 3}, {3, 4}, {0, 4}, {1, 4}, {2, 5}}); 56 | int[] search = depthFirstSearch(graph, 0); 57 | System.out.println(Arrays.toString(search)); 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/java/math/matrices/Matrix.java: -------------------------------------------------------------------------------- 1 | package math.matrices; 2 | 3 | import java.util.Arrays; 4 | import java.util.Objects; 5 | import java.util.function.Function; 6 | 7 | @SuppressWarnings("ALL") 8 | public class Matrix { 9 | /** The representation of the matrix, using a 2D array */ 10 | private double[][] matrix; 11 | 12 | private int rows, cols; 13 | 14 | /** 15 | * Copies the matrix from the other matrix given 16 | * 17 | * @param copy The matrix to copy data from 18 | */ 19 | public Matrix(Matrix copy) { 20 | this(copy.matrix); 21 | rows = copy.rows; 22 | cols = copy.cols; 23 | } 24 | 25 | /** 26 | * Makes an empty matrix with the given number of rows and columns 27 | * 28 | * @param rows The initial number of rows 29 | * @param cols The initial number of columns 30 | */ 31 | public Matrix(int rows, int cols) { 32 | this.rows = rows; 33 | this.cols = cols; 34 | matrix = new double[rows][cols]; 35 | } 36 | 37 | /** 38 | * Sets up the matrix with an initial set of data 39 | * 40 | * @param data The initial set of data 41 | */ 42 | public Matrix(double[][] data) { 43 | if (!isValidData(data)) return; 44 | 45 | rows = data.length; 46 | cols = data[0].length; 47 | matrix = new double[rows][cols]; 48 | setData(data); 49 | } 50 | 51 | /** 52 | * Gets the identity square matrix with a given number of rows and columns 53 | * 54 | * @param len The number of rows and columns 55 | * @return The identity matrix with the given dimensions 56 | */ 57 | public static Matrix identity(int len) { 58 | double[][] res = new double[len][len]; 59 | 60 | for (int i = 0; i < len; i++) { 61 | res[i][i] = 1; 62 | } 63 | 64 | return new Matrix(res); 65 | } 66 | 67 | /** 68 | * Makes a matrix with the given dimensions and filled with random values in the range [-1, 1) 69 | * 70 | * @param rows The number of rows to make the randomized matrix 71 | * @param cols The number of columns to make the randomized matrix 72 | * @return The randomized matrix 73 | */ 74 | public static Matrix randomize(int rows, int cols) { 75 | double[][] res = new double[rows][cols]; 76 | 77 | for (int i = 0; i < res.length; i++) { 78 | for (int j = 0; j < res[i].length; j++) { 79 | res[i][j] = Math.random(); 80 | } 81 | } 82 | 83 | return new Matrix(res); 84 | } 85 | 86 | /** 87 | * Makes a new Matrix with 1 column, given by the array input 88 | * 89 | * @param arr The input to make the new matrix from 90 | * @return The generated matrix 91 | */ 92 | public static Matrix colMatrixFromArray(double[] arr) { 93 | double[][] res = new double[arr.length][1]; 94 | for (int i = 0; i < arr.length; i++) { 95 | res[i][0] = arr[i]; 96 | } 97 | 98 | return new Matrix(res); 99 | } 100 | 101 | /** 102 | * Checks if a set of data, in a 2D array, is a valid matrix 103 | * 104 | * @param data The data to check 105 | * @return Whether the data is valid 106 | */ 107 | public boolean isValidData(double[][] data) { 108 | if (data == null) return false; 109 | if (data.length == 0) return false; 110 | for (double[] arr : data) if (arr == null) return false; 111 | int len = data[0].length; 112 | for (double[] arr : data) { 113 | if (arr.length != len) return false; 114 | } 115 | 116 | return true; 117 | } 118 | 119 | /** 120 | * Sets the data in this matrix at the given position (i, j) 121 | * 122 | * @param i i-position of the value 123 | * @param j j-position of the value 124 | * @param value The value of the data. 125 | */ 126 | public void setData(int i, int j, double value) { 127 | if (i >= 0 && i < rows && j >= 0 && j < cols) { 128 | matrix[i][j] = value; 129 | } 130 | } 131 | 132 | /** 133 | * Gets the data in this matrix at a given (i, j) position 134 | * 135 | * @param i i-position of the value 136 | * @param j j-position of the value 137 | * @return The value at the given position 138 | */ 139 | @SuppressWarnings("unused") 140 | public double getData(int i, int j) { 141 | if (i >= 0 && i < rows && j >= 0 && j < cols) { 142 | return matrix[i][j]; 143 | } else return 0; 144 | } 145 | 146 | /** 147 | * Gets the data that is stored in this matrix, in the form of a 2D array 148 | * 149 | * @return The 2D array representation of the matrix 150 | */ 151 | public double[][] getData() { 152 | return Arrays.copyOf(matrix, matrix.length); 153 | } 154 | 155 | /** 156 | * The sets the data in this matrix to the data given. This operation can fail if the data given 157 | * has different dimensions as the matrix 158 | * 159 | * @param data The data to set into this matrix 160 | */ 161 | public void setData(double[][] data) { 162 | if (data.length != rows) return; 163 | for (int i = 0; i < rows; i++) { 164 | if (data[i].length != cols) return; 165 | matrix[i] = Arrays.copyOf(data[i], data[i].length); 166 | } 167 | } 168 | 169 | /** 170 | * Matrix adds this matrix to the other given matrix. This operation can fail (return null) if the 171 | * size of the matrix is not equal to the size of the other matrix 172 | * 173 | * @param other The matrix to add to 174 | * @return The sum of the two matrices 175 | */ 176 | public Matrix add(Matrix other) { 177 | if (rows != other.rows || cols != other.cols) return null; 178 | double[][] res = new double[rows][cols]; 179 | 180 | for (int i = 0; i < rows; i++) { 181 | for (int j = 0; j < cols; j++) { 182 | res[i][j] = matrix[i][j] + other.matrix[i][j]; 183 | } 184 | } 185 | 186 | return new Matrix(res); 187 | } 188 | 189 | /** 190 | * Subtracts the other matrix from this one 191 | * 192 | * @param other The matrix to be subtracted 193 | * @return The difference of the two matrices 194 | */ 195 | public Matrix subtract(Matrix other) { 196 | return this.add(other.multiply(-1)); 197 | } 198 | 199 | /** 200 | * Multiplies this matrix with another given matrix 201 | * 202 | * @param other The matrix to multiply by 203 | * @return The product of the two matrices 204 | */ 205 | public Matrix multiply(Matrix other) { 206 | if (cols != other.rows) return null; 207 | 208 | double[][] res = new double[rows][other.cols]; 209 | 210 | for (int i = 0; i < rows; i++) { 211 | for (int j = 0; j < other.cols; j++) { 212 | for (int k = 0; k < cols; k++) { 213 | res[i][j] += matrix[i][k] * other.matrix[k][j]; 214 | } 215 | } 216 | } 217 | 218 | return new Matrix(res); 219 | } 220 | 221 | public Matrix hadamardMultiply(Matrix other) { 222 | if (rows != other.rows || cols != other.cols) return null; 223 | 224 | double[][] res = new double[rows][cols]; 225 | 226 | for (int i = 0; i < rows; i++) { 227 | for (int j = 0; j < cols; j++) { 228 | res[i][j] = matrix[i][j] * other.matrix[i][j]; 229 | } 230 | } 231 | 232 | return new Matrix(res); 233 | } 234 | 235 | /** 236 | * Scalar multiplies this matrix with the given scalar 237 | * 238 | * @param scalar The scalar to multiply by 239 | * @return The product of the multiplication 240 | */ 241 | public Matrix multiply(double scalar) { 242 | double[][] res = new double[rows][cols]; 243 | 244 | for (int i = 0; i < res.length; i++) { 245 | for (int j = 0; j < res[i].length; j++) { 246 | res[i][j] = matrix[i][j] * scalar; 247 | } 248 | } 249 | 250 | return new Matrix(res); 251 | } 252 | 253 | /** 254 | * Gets the transposition of this matrix 255 | * 256 | * @return The transposition 257 | */ 258 | public Matrix transpose() { 259 | double[][] res = new double[cols][rows]; 260 | 261 | for (int i = 0; i < rows; i++) { 262 | for (int j = 0; j < cols; j++) { 263 | res[j][i] = matrix[i][j]; 264 | } 265 | } 266 | 267 | return new Matrix(res); 268 | } 269 | 270 | /** 271 | * Gets a specific row of this Matrix, also represented as a matrix 272 | * 273 | * @param row The row to get 274 | * @return The row at the given position 275 | */ 276 | public Matrix getRow(int row) { 277 | if (row > rows) return null; 278 | 279 | double[][] res = new double[1][cols]; 280 | if (cols < 0) { 281 | return new Matrix(res); 282 | } 283 | System.arraycopy(matrix[row], 0, res[0], 0, cols); 284 | 285 | return new Matrix(res); 286 | } 287 | 288 | /** 289 | * Gets a specific col of this Matrix, also represented as a matrix 290 | * 291 | * @param col The column to get 292 | * @return The col at the given position 293 | */ 294 | public Matrix getColumn(int col) { 295 | if (col > cols) return null; 296 | 297 | double[][] res = new double[rows][1]; 298 | for (int i = 0; i < rows; i++) { 299 | res[i][0] = matrix[i][col]; 300 | } 301 | 302 | return new Matrix(res); 303 | } 304 | 305 | public double[] colMatrixToArray() { 306 | double[] res = new double[rows]; 307 | for (int i = 0; i < rows; i++) { 308 | res[i] = matrix[i][0]; 309 | } 310 | 311 | return res; 312 | } 313 | 314 | /** 315 | * Applies the given function to each value in the matrix 316 | * 317 | * @param func The function to apply 318 | * @return The result after applying the function 319 | */ 320 | public Matrix forEach(Function func) { 321 | double[][] res = new double[rows][cols]; 322 | 323 | for (int i = 0; i < matrix.length; i++) { 324 | for (int j = 0; j < matrix[i].length; j++) { 325 | res[i][j] = func.apply(matrix[i][j]); 326 | } 327 | } 328 | 329 | return new Matrix(res); 330 | } 331 | 332 | @Override 333 | public String toString() { 334 | StringBuilder sb = new StringBuilder(); 335 | sb.append("{\n\t"); 336 | for (int i = 0; i < rows - 1; i++) { 337 | sb.append(Arrays.toString(matrix[i])); 338 | sb.append(",\n\t"); 339 | } 340 | sb.append(Arrays.toString(matrix[rows - 1])); 341 | sb.append("\n}").append(", rows = ").append(rows).append(", cols = ").append(cols); 342 | 343 | return sb.toString(); 344 | } 345 | 346 | @Override 347 | public boolean equals(Object o) { 348 | if (this == o) return true; 349 | if (!(o instanceof Matrix)) return false; 350 | Matrix matrix1 = (Matrix) o; 351 | if (rows != matrix1.rows || cols != matrix1.cols) return false; 352 | for (int i = 0; i < rows; i++) { 353 | for (int j = 0; j < matrix[i].length; j++) { 354 | if (matrix[i][j] != matrix1.matrix[i][j]) return false; 355 | } 356 | } 357 | 358 | return true; 359 | } 360 | 361 | @Override 362 | public int hashCode() { 363 | int result = Objects.hash(rows, cols); 364 | result = 31 * result + Arrays.hashCode(matrix); 365 | return result; 366 | } 367 | } 368 | -------------------------------------------------------------------------------- /src/main/java/math/vectors/Vector3.java: -------------------------------------------------------------------------------- 1 | package math.vectors; 2 | 3 | public class Vector3 { 4 | 5 | private final double x, y, z, length; 6 | private Vector3 normalized; 7 | 8 | public Vector3(final double value) { 9 | this(value, value, value); 10 | } 11 | 12 | public Vector3(final double x, final double y, final double z) { 13 | this.x = x; 14 | this.y = y; 15 | this.z = z; 16 | length = Math.sqrt(norm()); 17 | } 18 | 19 | public Vector3 add(final double x, final double y, final double z) { 20 | return new Vector3(this.x + x, this.y + y, this.z + z); 21 | } 22 | 23 | public Vector3 add(Vector3 other) { 24 | return add(other.x, other.y, other.z); 25 | } 26 | 27 | public Vector3 subtract(final double x, final double y, final double z) { 28 | return add(-x, -y, -z); 29 | } 30 | 31 | public Vector3 subtract(Vector3 other) { 32 | return subtract(other.x, other.y, other.z); 33 | } 34 | 35 | public Vector3 multiply(final double x, final double y, final double z) { 36 | return new Vector3(this.x * x, this.y * y, this.z * z); 37 | } 38 | 39 | public Vector3 multiply(final Vector3 other) { 40 | return multiply(other.x, other.y, other.z); 41 | } 42 | 43 | public Vector3 divide(final double x, final double y, final double z) { 44 | return multiply(1 / x, 1 / y, 1 / z); 45 | } 46 | 47 | public Vector3 divide(final Vector3 other) { 48 | return divide(other.x, other.y, other.z); 49 | } 50 | 51 | public Vector3 multiply(final double scalar) { 52 | return new Vector3(x * scalar, y * scalar, z * scalar); 53 | } 54 | 55 | public Vector3 normalized() { 56 | if (normalized == null) { 57 | if (length > 0) { 58 | double inverseLength = 1 / length; 59 | return normalized = new Vector3(x * inverseLength, y * inverseLength, z * inverseLength); 60 | } else { 61 | return normalized = new Vector3(1, 1, 1).normalized; 62 | } 63 | } 64 | 65 | return normalized; 66 | } 67 | 68 | public double dotProduct(final Vector3 other) { 69 | return x * other.x + y * other.y + z * other.z; 70 | } 71 | 72 | public Vector3 crossProduct(final Vector3 other) { 73 | return new Vector3( 74 | y * other.z - z * other.y, z * other.x - x * other.z, x * other.y - y * other.x); 75 | } 76 | 77 | public double norm() { 78 | return dotProduct(this); 79 | } 80 | 81 | public double getX() { 82 | return x; 83 | } 84 | 85 | public double getY() { 86 | return y; 87 | } 88 | 89 | public double getZ() { 90 | return z; 91 | } 92 | 93 | public double getLength() { 94 | return length; 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/EfficientNeuralNetwork.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks; 2 | 3 | import java.util.function.Function; 4 | import org.jblas.DoubleMatrix; 5 | 6 | /** 7 | * You need jblas library installed to run this. 8 | * 9 | * @see http://jblas.org/ 10 | */ 11 | @SuppressWarnings("ALL") 12 | public class EfficientNeuralNetwork { 13 | private final DoubleMatrix[] weights; 14 | private final DoubleMatrix[] biases; 15 | private final double learningRate; 16 | private final Function activation; 17 | private final Function dactivation; 18 | 19 | public EfficientNeuralNetwork(int inputs, int outputs, double learningRate, int... hiddenLayers) { 20 | if (hiddenLayers.length < 1) throw new RuntimeException("Cannot have 0 hidden layers!"); 21 | this.learningRate = learningRate; 22 | 23 | weights = new DoubleMatrix[hiddenLayers.length + 1]; 24 | for (int i = 0; i < weights.length; i++) { 25 | if (i == 0) weights[i] = DoubleMatrix.rand(hiddenLayers[i], inputs); 26 | else if (i == weights.length - 1) 27 | weights[i] = DoubleMatrix.rand(outputs, hiddenLayers[i - 1]); 28 | else weights[i] = DoubleMatrix.rand(hiddenLayers[i], hiddenLayers[i - 1]); 29 | } 30 | biases = new DoubleMatrix[hiddenLayers.length + 1]; 31 | for (int i = 0; i < biases.length; i++) { 32 | if (i == biases.length - 1) biases[i] = DoubleMatrix.rand(outputs, 1); 33 | else biases[i] = DoubleMatrix.rand(hiddenLayers[i], 1); 34 | } 35 | 36 | activation = (x) -> 1 / (1 + Math.exp(-x)); 37 | dactivation = (x) -> x * (1 - x); 38 | } 39 | 40 | public double[] predict(double[] inputArr) { 41 | DoubleMatrix previous = new DoubleMatrix(inputArr); 42 | 43 | for (int i = 0; i < weights.length; i++) { 44 | previous = feedForward(previous, weights[i], biases[i]); 45 | } 46 | 47 | return previous.toArray(); 48 | } 49 | 50 | private DoubleMatrix feedForward(DoubleMatrix inputs, DoubleMatrix weights, DoubleMatrix biases) { 51 | DoubleMatrix res = weights.mmul(inputs); 52 | res = res.add(biases); 53 | for (int i = 0; i < res.data.length; i++) { 54 | res.data[i] = activation.apply(res.data[i]); 55 | } 56 | 57 | return res; 58 | } 59 | 60 | public void train(double[] inputArr, double[] targetArr) { 61 | DoubleMatrix inputs = new DoubleMatrix(inputArr), targets = new DoubleMatrix(targetArr); 62 | 63 | DoubleMatrix[] actual = new DoubleMatrix[weights.length]; 64 | for (int i = 0; i < weights.length; i++) { 65 | if (i == 0) actual[i] = feedForward(inputs, weights[i], biases[i]); 66 | else actual[i] = feedForward(actual[i - 1], weights[i], biases[i]); 67 | } 68 | 69 | DoubleMatrix[] errors = new DoubleMatrix[weights.length]; 70 | for (int i = errors.length - 1; i >= 0; i--) { 71 | if (i == errors.length - 1) errors[i] = targets.sub(actual[i]); 72 | else errors[i] = weights[i + 1].transpose().mmul(errors[i + 1]); 73 | } 74 | 75 | DoubleMatrix[] deltas; 76 | for (int i = weights.length - 1; i >= 0; i--) { 77 | if (i == 0) deltas = findDeltas(actual[i], inputs, errors[i]); 78 | else deltas = findDeltas(actual[i], actual[i - 1], errors[i]); 79 | weights[i] = weights[i].add(deltas[0]); 80 | biases[i] = biases[i].add(deltas[1]); 81 | } 82 | } 83 | 84 | private DoubleMatrix[] findDeltas(DoubleMatrix actual, DoubleMatrix input, DoubleMatrix errors) { 85 | DoubleMatrix gradient = new DoubleMatrix(); 86 | gradient.copy(actual); 87 | for (int i = 0; i < gradient.data.length; i++) { 88 | gradient.data[i] = dactivation.apply(gradient.data[i]); 89 | } 90 | gradient = gradient.mul(errors); 91 | gradient = gradient.mmul(learningRate); 92 | 93 | DoubleMatrix inputsT = input.transpose(); 94 | DoubleMatrix weightsDeltas = gradient.mmul(inputsT); 95 | 96 | return new DoubleMatrix[] {weightsDeltas, gradient}; 97 | } 98 | 99 | @SuppressWarnings("unused") 100 | public DoubleMatrix[] getWeights() { 101 | return weights; 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/MNISTTrainer.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks; 2 | 3 | import java.awt.*; 4 | import java.awt.image.BufferedImage; 5 | import java.io.IOException; 6 | import java.util.Arrays; 7 | import javax.swing.*; 8 | import neuralnetworks.mnistdata.MnistEntry; 9 | import neuralnetworks.mnistdata.MnistLoader; 10 | 11 | public class MNISTTrainer extends JPanel { 12 | private static final int WIDTH = 900, HEIGHT = 900; 13 | private final Object imgLock = new Object(); 14 | private BufferedImage img; 15 | private int label; 16 | private int i; 17 | private int numCorrect; 18 | 19 | public MNISTTrainer() { 20 | JFrame frame = new JFrame("Testing single perceptrons"); 21 | frame.setBounds(100, 100, WIDTH, HEIGHT); 22 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 23 | frame.setResizable(true); 24 | this.setBackground(Color.WHITE); 25 | this.setLayout(null); 26 | this.setPreferredSize(new Dimension(WIDTH, HEIGHT)); 27 | frame.add(this); 28 | frame.setVisible(true); 29 | } 30 | 31 | public static void main(String[] args) { 32 | new MNISTTrainer().run(); 33 | } 34 | 35 | public void run() { 36 | MnistLoader loader = new MnistLoader(); 37 | MnistEntry[] trainingEntries = null, testingEntries = null; 38 | MultilayeredNeuralNetwork nn = 39 | new MultilayeredNeuralNetwork(28 * 28, 10, 0.1, 28 * 28 * 2, 28 * 28 * 2); 40 | double[] inputArr, outputArr; 41 | String path = "src/Main/java/com/nishant/algorithms/NeuralNetworks/MNISTData"; 42 | try { 43 | trainingEntries = loader.readDecompressedTraining(path); 44 | testingEntries = loader.readDecompressedTesting(path); 45 | } catch (IOException e) { 46 | e.printStackTrace(); 47 | } 48 | 49 | assert trainingEntries != null && testingEntries != null; 50 | 51 | inputArr = new double[trainingEntries[0].getImageData().length]; 52 | double[][] expectedOutputs = new double[10][10]; 53 | for (int j = 0; j < expectedOutputs.length; j++) { 54 | expectedOutputs[j][j] = 1; 55 | } 56 | 57 | int imgScale = WIDTH / trainingEntries[0].getNumRows(); 58 | 59 | i = 0; 60 | for (MnistEntry entry : trainingEntries) { 61 | synchronized (imgLock) { 62 | img = scale(entry.createImage(), imgScale, imgScale); 63 | label = entry.getLabel(); 64 | } 65 | 66 | repaint(); 67 | 68 | byte[] bytes = entry.getImageData(); 69 | for (int i = 0; i < bytes.length; i++) { 70 | inputArr[i] = bytes[i] / 128.0; 71 | } 72 | outputArr = expectedOutputs[label]; 73 | nn.train( 74 | Arrays.copyOf(inputArr, inputArr.length), Arrays.copyOf(outputArr, outputArr.length)); 75 | 76 | i++; 77 | 78 | try { 79 | Thread.sleep(5); 80 | } catch (InterruptedException e) { 81 | e.printStackTrace(); 82 | } 83 | } 84 | 85 | i = 0; 86 | numCorrect = 0; 87 | 88 | for (MnistEntry entry : testingEntries) { 89 | synchronized (imgLock) { 90 | img = scale(entry.createImage(), imgScale, imgScale); 91 | label = entry.getLabel(); 92 | } 93 | 94 | byte[] bytes = entry.getImageData(); 95 | for (int i = 0; i < bytes.length; i++) { 96 | inputArr[i] = bytes[i] / 128.0; 97 | } 98 | outputArr = nn.predict(inputArr); 99 | double maxVal = -1, maxIndex = -1; 100 | for (int j = 0; j < outputArr.length; j++) { 101 | if (outputArr[j] > maxVal) { 102 | maxIndex = j; 103 | maxVal = outputArr[j]; 104 | } 105 | } 106 | 107 | if (maxIndex == label) { 108 | numCorrect++; 109 | } 110 | 111 | repaint(); 112 | 113 | i++; 114 | 115 | try { 116 | Thread.sleep(5); 117 | } catch (InterruptedException e) { 118 | e.printStackTrace(); 119 | } 120 | } 121 | 122 | repaint(); 123 | } 124 | 125 | @Override 126 | protected void paintComponent(Graphics g) { 127 | super.paintComponent(g); 128 | 129 | synchronized (imgLock) { 130 | if (img == null) return; 131 | g.drawImage(img, 0, 0, null); 132 | g.setColor(Color.WHITE); 133 | g.setFont(new Font("Arial", Font.BOLD, 50)); 134 | g.drawString(String.valueOf(label), img.getWidth() / 10, img.getHeight() * 9 / 10); 135 | g.drawString(String.valueOf(i), img.getWidth() * 7 / 10, img.getHeight() * 9 / 10); 136 | g.drawString(String.valueOf(numCorrect), img.getWidth() / 10, img.getHeight() / 10); 137 | } 138 | } 139 | 140 | private BufferedImage scale(BufferedImage img, int xscale, int yscale) { 141 | BufferedImage res = 142 | new BufferedImage( 143 | img.getWidth() * xscale, img.getHeight() * yscale, BufferedImage.TYPE_INT_ARGB); 144 | for (int i = 0; i < img.getWidth(); i++) { 145 | for (int j = 0; j < img.getHeight(); j++) { 146 | for (int k = 0; k < xscale; k++) { 147 | for (int l = 0; l < yscale; l++) { 148 | res.setRGB(i * xscale + k, j * yscale + l, img.getRGB(i, j)); 149 | } 150 | } 151 | } 152 | } 153 | 154 | return res; 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/MultilayeredNeuralNetwork.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks; 2 | 3 | import java.util.function.Function; 4 | import math.matrices.Matrix; 5 | 6 | public class MultilayeredNeuralNetwork { 7 | private final Matrix[] weights; 8 | private final Matrix[] biases; 9 | private final double learningRate; 10 | private final Function activation; 11 | private final Function dactivation; 12 | 13 | public MultilayeredNeuralNetwork( 14 | int inputs, int outputs, double learningRate, int... hiddenLayers) { 15 | if (hiddenLayers.length < 1) throw new RuntimeException("Cannot have 0 hidden layers!"); 16 | this.learningRate = learningRate; 17 | 18 | weights = new Matrix[hiddenLayers.length + 1]; 19 | for (int i = 0; i < weights.length; i++) { 20 | if (i == 0) weights[i] = Matrix.randomize(hiddenLayers[i], inputs); 21 | else if (i == weights.length - 1) weights[i] = Matrix.randomize(outputs, hiddenLayers[i - 1]); 22 | else weights[i] = Matrix.randomize(hiddenLayers[i], hiddenLayers[i - 1]); 23 | } 24 | biases = new Matrix[hiddenLayers.length + 1]; 25 | for (int i = 0; i < biases.length; i++) { 26 | if (i == biases.length - 1) biases[i] = Matrix.randomize(outputs, 1); 27 | else biases[i] = Matrix.randomize(hiddenLayers[i], 1); 28 | } 29 | 30 | activation = (x) -> 1 / (1 + Math.exp(-x)); 31 | dactivation = (x) -> x * (1 - x); 32 | } 33 | 34 | public double[] predict(double[] inputArr) { 35 | Matrix previous = Matrix.colMatrixFromArray(inputArr); 36 | 37 | for (int i = 0; i < weights.length; i++) { 38 | previous = feedForward(previous, weights[i], biases[i]); 39 | } 40 | 41 | return previous.colMatrixToArray(); 42 | } 43 | 44 | private Matrix feedForward(Matrix inputs, Matrix weights, Matrix biases) { 45 | Matrix res = weights.multiply(inputs); 46 | res = res.add(biases); 47 | res = res.forEach(this.activation); 48 | 49 | return res; 50 | } 51 | 52 | public void train(double[] inputArr, double[] targetArr) { 53 | Matrix inputs = Matrix.colMatrixFromArray(inputArr), 54 | targets = Matrix.colMatrixFromArray(targetArr); 55 | 56 | Matrix[] actual = new Matrix[weights.length]; 57 | for (int i = 0; i < weights.length; i++) { 58 | if (i == 0) actual[i] = feedForward(inputs, weights[i], biases[i]); 59 | else actual[i] = feedForward(actual[i - 1], weights[i], biases[i]); 60 | } 61 | 62 | Matrix[] errors = new Matrix[weights.length]; 63 | for (int i = errors.length - 1; i >= 0; i--) { 64 | if (i == errors.length - 1) errors[i] = targets.subtract(actual[i]); 65 | else errors[i] = weights[i + 1].transpose().multiply(errors[i + 1]); 66 | } 67 | 68 | Matrix[] deltas; 69 | for (int i = weights.length - 1; i >= 0; i--) { 70 | if (i == 0) deltas = findDeltas(actual[i], inputs, errors[i]); 71 | else deltas = findDeltas(actual[i], actual[i - 1], errors[i]); 72 | weights[i] = weights[i].add(deltas[0]); 73 | biases[i] = biases[i].add(deltas[1]); 74 | } 75 | } 76 | 77 | private Matrix[] findDeltas(Matrix actual, Matrix input, Matrix errors) { 78 | Matrix gradient = actual.forEach(dactivation); 79 | gradient = gradient.hadamardMultiply(errors); 80 | gradient = gradient.multiply(learningRate); 81 | 82 | Matrix inputsT = input.transpose(); 83 | Matrix weightsDeltas = gradient.multiply(inputsT); 84 | 85 | return new Matrix[] {weightsDeltas, gradient}; 86 | } 87 | 88 | @SuppressWarnings("unused") 89 | public Matrix[] getWeights() { 90 | return weights; 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/OneHiddenLayerNeuralNetwork.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks; 2 | 3 | import math.matrices.Matrix; 4 | 5 | public class OneHiddenLayerNeuralNetwork { 6 | private final double learningRate; 7 | private Matrix weightsIH, weightsHO, biasH, biasO; 8 | 9 | public OneHiddenLayerNeuralNetwork(int inputs, int hidden, int output) { 10 | 11 | weightsIH = Matrix.randomize(hidden, inputs); 12 | weightsHO = Matrix.randomize(output, hidden); 13 | biasH = Matrix.randomize(hidden, 1); 14 | biasO = Matrix.randomize(output, 1); 15 | 16 | learningRate = 5; 17 | } 18 | 19 | public double[] predict(double[] inputArray) { 20 | Matrix inputs = Matrix.colMatrixFromArray(inputArray); 21 | 22 | Matrix hidden = weightsIH.multiply(inputs); 23 | hidden = hidden.add(biasH); 24 | hidden = hidden.forEach(this::sigmoid); 25 | 26 | Matrix output = weightsHO.multiply(hidden); 27 | output = output.add(biasO); 28 | output = output.forEach(this::sigmoid); 29 | 30 | return output.colMatrixToArray(); 31 | } 32 | 33 | private double sigmoid(double x) { 34 | return 1 / (1 + Math.exp(-x)); 35 | } 36 | 37 | private double dsigmoid(double x) { 38 | return x * (1 - x); 39 | } 40 | 41 | public void train(double[] inputArr, double[] targetArr) { 42 | Matrix inputs = Matrix.colMatrixFromArray(inputArr); 43 | Matrix hidden = weightsIH.multiply(inputs); 44 | hidden = hidden.add(biasH); 45 | hidden = hidden.forEach(this::sigmoid); 46 | 47 | Matrix outputs = weightsHO.multiply(hidden); 48 | outputs = outputs.add(biasO); 49 | outputs = outputs.forEach(this::sigmoid); 50 | 51 | Matrix targets = Matrix.colMatrixFromArray(targetArr); 52 | 53 | // Error = targets - outputs 54 | Matrix outputErrors = targets.subtract(outputs); 55 | 56 | // Matrix weightsHOT = weightsHO.transpose(); 57 | // Matrix hiddenErrors = weightsHOT.multiply(outputErrors); 58 | 59 | Matrix gradients = outputs.forEach(this::dsigmoid); 60 | gradients = gradients.hadamardMultiply(outputErrors); 61 | gradients = gradients.multiply(learningRate); 62 | 63 | Matrix hiddenT = hidden.transpose(); 64 | Matrix weightsHODeltas = gradients.multiply(hiddenT); 65 | 66 | weightsHO = weightsHO.add(weightsHODeltas); 67 | biasO = biasO.add(gradients); 68 | 69 | // -------------------------------------------------------------------- 70 | 71 | Matrix weightsHOT = weightsHO.transpose(); 72 | Matrix hiddenErrors = weightsHOT.multiply(outputErrors); 73 | 74 | Matrix hiddenGradient = hidden.forEach(this::dsigmoid); 75 | hiddenGradient = hiddenGradient.hadamardMultiply(hiddenErrors); 76 | hiddenGradient = hiddenGradient.multiply(learningRate); 77 | 78 | Matrix inputsT = inputs.transpose(); 79 | Matrix weightsIHDeltas = hiddenGradient.multiply(inputsT); 80 | 81 | weightsIH = weightsIH.add(weightsIHDeltas); 82 | biasH = biasH.add(hiddenGradient); 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/Perceptron.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks; 2 | 3 | import java.util.Arrays; 4 | 5 | @SuppressWarnings("unused") 6 | public class Perceptron { 7 | private final double[] weights; 8 | private final double learningRate; 9 | 10 | public Perceptron() { 11 | weights = new double[2]; 12 | for (int i = 0; i < weights.length; i++) { 13 | weights[i] = Math.random() * 2 - 1; 14 | } 15 | learningRate = 0.4; 16 | 17 | System.out.println(Arrays.toString(weights)); 18 | } 19 | 20 | public int guess(double[] inputs) { 21 | double sum = 0; 22 | 23 | for (int i = 0; i < weights.length; i++) { 24 | sum += inputs[i] * weights[i]; 25 | } 26 | 27 | return ((int) activate(sum)); 28 | } 29 | 30 | private double activate(double n) { 31 | return n > 0 ? 1 : -1; 32 | } 33 | 34 | public void train(double[] inputs, int target) { 35 | int guess = guess(inputs); 36 | double error = target - guess; 37 | 38 | for (int i = 0; i < weights.length; i++) { 39 | weights[i] += error * inputs[i] * learningRate; 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/Tester.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks; 2 | 3 | import java.awt.*; 4 | import java.util.Arrays; 5 | import java.util.HashMap; 6 | import java.util.Map; 7 | import javax.swing.*; 8 | 9 | @SuppressWarnings("ALL") 10 | public class Tester extends JPanel { 11 | private static final int WIDTH = 900, HEIGHT = 900; 12 | private final Map points = new HashMap<>(); 13 | private final int trainingSize; 14 | private Perceptron perceptron; 15 | private int frameCount = 0; 16 | 17 | public Tester(int width, int height) { 18 | JFrame frame = new JFrame("Testing single perceptrons"); 19 | frame.setBounds(100, 100, width, height); 20 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 21 | frame.setResizable(true); 22 | this.setBackground(Color.WHITE); 23 | this.setLayout(null); 24 | this.setPreferredSize(new Dimension(width, height)); 25 | frame.add(this); 26 | frame.setVisible(true); 27 | 28 | trainingSize = 100; 29 | } 30 | 31 | public static void main(String[] args) { 32 | // Tester tester = new Tester(WIDTH, HEIGHT); 33 | // 34 | // tester.perceptron = new Perceptron(); 35 | // 36 | // tester.run(); 37 | 38 | OneHiddenLayerNeuralNetwork nn = new OneHiddenLayerNeuralNetwork(2, 4, 1); 39 | double[][][] trainingData = new double[4][2][]; 40 | trainingData[0][0] = new double[] {1, 1}; 41 | trainingData[0][1] = new double[] {0}; 42 | trainingData[1][0] = new double[] {0, 0}; 43 | trainingData[1][1] = new double[] {0}; 44 | trainingData[2][0] = new double[] {1, 0}; 45 | trainingData[2][1] = new double[] {1}; 46 | trainingData[3][0] = new double[] {0, 1}; 47 | trainingData[3][1] = new double[] {1}; 48 | 49 | for (int i = 0; i < 50000; i++) { 50 | for (int j = 0; j < trainingData.length; j++) { 51 | int index = ((int) (Math.random() * trainingData.length)); 52 | nn.train(trainingData[index][0], trainingData[index][1]); 53 | // nn.train(trainingData[j][0], trainingData[j][1]); 54 | } 55 | } 56 | 57 | System.out.println(Arrays.toString(nn.predict(trainingData[0][0]))); 58 | System.out.println(Arrays.toString(nn.predict(trainingData[1][0]))); 59 | System.out.println(Arrays.toString(nn.predict(trainingData[2][0]))); 60 | System.out.println(Arrays.toString(nn.predict(trainingData[3][0]))); 61 | 62 | // MultilayeredNeuralNetwork mnn = new MultilayeredNeuralNetwork(2, 1, 1, 1); 63 | // System.out.println(Arrays.toString(mnn.predict(new double[]{0, 1}))); 64 | } 65 | 66 | @SuppressWarnings("unused") 67 | public void run() { 68 | Point current; 69 | for (int i = 0; i < trainingSize; i++) { 70 | current = new Point(((int) (Math.random() * WIDTH)), ((int) (Math.random() * HEIGHT))); 71 | if (current.x > current.y) { 72 | points.put(current, 1); 73 | } else points.put(current, -1); 74 | } 75 | 76 | while (true) { 77 | repaint(); 78 | try { 79 | Thread.sleep(100); 80 | } catch (InterruptedException e) { 81 | e.printStackTrace(); 82 | } 83 | } 84 | } 85 | 86 | protected void paintComponent(Graphics g) { 87 | super.paintComponent(g); 88 | 89 | Point current = new Point(((int) (Math.random() * WIDTH)), ((int) (Math.random() * HEIGHT))); 90 | if (current.x > current.y) { 91 | points.put(current, 1); 92 | } else points.put(current, -1); 93 | 94 | if (frameCount % 3 == 0) { 95 | double[] inputs = new double[2]; 96 | for (Point point : points.keySet()) { 97 | inputs[0] = point.x; 98 | inputs[1] = point.y; 99 | perceptron.train(inputs, points.get(point)); 100 | } 101 | } 102 | 103 | g.drawLine(0, 0, WIDTH, HEIGHT); 104 | double guess; 105 | for (Point point : points.keySet()) { 106 | guess = perceptron.guess(new double[] {point.x, point.y}); 107 | if (guess == points.get(point)) g.setColor(Color.GREEN); 108 | else g.setColor(Color.RED); 109 | 110 | g.fillOval(point.x, point.y, 30, 30); 111 | g.setColor(Color.BLACK); 112 | g.drawOval(point.x, point.y, 30, 30); 113 | } 114 | 115 | frameCount++; 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/XORVisualizer.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks; 2 | 3 | import java.awt.*; 4 | import java.awt.image.BufferedImage; 5 | import java.util.Arrays; 6 | import javax.swing.*; 7 | 8 | public class XORVisualizer extends JPanel { 9 | 10 | private static final int WIDTH = 900, HEIGHT = 900; 11 | private EfficientNeuralNetwork nn; 12 | private int frameCount = 0; 13 | 14 | public XORVisualizer() { 15 | JFrame frame = new JFrame("Testing single perceptrons"); 16 | frame.setBounds(100, 100, WIDTH, HEIGHT); 17 | frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); 18 | frame.setResizable(true); 19 | this.setBackground(Color.WHITE); 20 | this.setLayout(null); 21 | this.setPreferredSize(new Dimension(WIDTH, HEIGHT)); 22 | frame.add(this); 23 | frame.setVisible(true); 24 | } 25 | 26 | public static void main(String[] args) { 27 | XORVisualizer v = new XORVisualizer(); 28 | 29 | int[] hiddens = new int[2]; 30 | Arrays.fill(hiddens, 10); 31 | v.nn = new EfficientNeuralNetwork(2, 1, 0.1, hiddens); 32 | 33 | try { 34 | Thread.sleep(100); 35 | } catch (InterruptedException e) { 36 | e.printStackTrace(); 37 | } 38 | 39 | new Thread( 40 | () -> { 41 | while (true) v.train(); 42 | }) 43 | .start(); 44 | 45 | v.repaint(); 46 | while (v.frameCount < Short.MAX_VALUE) { 47 | if (v.frameCount % 20 == 0) v.repaint(); 48 | v.frameCount++; 49 | try { 50 | Thread.sleep(100); 51 | } catch (InterruptedException e) { 52 | e.printStackTrace(); 53 | } 54 | } 55 | } 56 | 57 | protected void paintComponent(Graphics g) { 58 | super.paintComponent(g); 59 | 60 | BufferedImage img = new BufferedImage(WIDTH, HEIGHT, BufferedImage.TYPE_INT_ARGB); 61 | double guess; 62 | Color color; 63 | for (int i = 0; i < img.getWidth(); i++) { 64 | for (int j = 0; j < img.getHeight(); j++) { 65 | guess = nn.predict(new double[] {((double) i) / WIDTH, ((double) j) / HEIGHT})[0]; 66 | System.out.println(guess); 67 | // color = new Color(((int) (guess * (Math.pow(2, 24) - 1)))); 68 | color = new Color(((int) (guess * 256)), ((int) (guess * 256)), ((int) (guess * 256))); 69 | // color = new Color(((int) (guess)) * 256, ((int) (1 / (1 + 70 | // Math.exp(-guess)) * 256)), 71 | // ((int) ((guess + 1) / 2)) * 256); 72 | img.setRGB(i, j, color.getRGB()); 73 | // img.setRGB(i, j, ((int) ((double) i * j / WIDTH / HEIGHT))); 74 | } 75 | } 76 | 77 | g.drawImage(img, 0, 0, null); 78 | } 79 | 80 | private void train() { 81 | double input1, input2; 82 | double[] input = new double[2], output = new double[1]; 83 | for (int i = 0; i < 200; i++) { 84 | input1 = Math.random() < 0.5 ? 0 : 1; 85 | input2 = Math.random() < 0.5 ? 0 : 1; 86 | // input1 = ((int) (Math.random() * 5)) / 5.0; 87 | // input2 = ((int) (Math.random() * 5)) / 5.0; 88 | // input1 = Math.random(); 89 | // input2 = Math.random(); 90 | 91 | input[0] = input1; 92 | input[1] = input2; 93 | 94 | // output[0] = input1 + input2; 95 | output[0] = Math.abs(input1 - input2); 96 | // output[0] = ((int) input1) ^ ((int) input2); 97 | // output[0] = Math.abs(input1 - 0.5) > Math.abs(input2 - 0.5) ? 1 : 0; 98 | // output[0] = input1 == input2 && input2 == 1 ? 1 : 0; 99 | // output[0] = input1 * input2; 100 | // output[0] = input1; 101 | 102 | nn.train(input, output); 103 | } 104 | } 105 | } 106 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/mnistdata/MnistEntry.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks.mnistdata; 2 | 3 | import java.awt.image.BufferedImage; 4 | import java.awt.image.DataBuffer; 5 | import java.awt.image.DataBufferByte; 6 | 7 | /** 8 | * An entry of the MNIST data set. Instances of this class will be passed to the consumer that is 9 | * given to the {@link MnistLoader} and {@link MnistLoader} reading methods. 10 | */ 11 | @SuppressWarnings("unused") 12 | public class MnistEntry { 13 | /** The index of the entry */ 14 | private final int index; 15 | 16 | /** The class label of the entry */ 17 | private final byte label; 18 | 19 | /** The number of rows of the image data */ 20 | private final int numRows; 21 | 22 | /** The number of columns of the image data */ 23 | private final int numCols; 24 | 25 | /** The image data */ 26 | private final byte[] imageData; 27 | 28 | /** 29 | * Default constructor 30 | * 31 | * @param index The index 32 | * @param label The label 33 | * @param numRows The number of rows 34 | * @param numCols The number of columns 35 | * @param imageData The image data 36 | */ 37 | MnistEntry(int index, byte label, int numRows, int numCols, byte[] imageData) { 38 | this.index = index; 39 | this.label = label; 40 | this.numRows = numRows; 41 | this.numCols = numCols; 42 | this.imageData = imageData; 43 | } 44 | 45 | /** 46 | * Returns the index of the entry 47 | * 48 | * @return The index 49 | */ 50 | public int getIndex() { 51 | return index; 52 | } 53 | 54 | /** 55 | * Returns the class label of the entry. This is a value in [0,9], indicating which digit is shown 56 | * in the entry 57 | * 58 | * @return The class label 59 | */ 60 | public byte getLabel() { 61 | return label; 62 | } 63 | 64 | /** 65 | * Returns the number of rows of the image data. This will usually be 28. 66 | * 67 | * @return The number of rows 68 | */ 69 | public int getNumRows() { 70 | return numRows; 71 | } 72 | 73 | /** 74 | * Returns the number of columns of the image data. This will usually be 28. 75 | * 76 | * @return The number of columns 77 | */ 78 | public int getNumCols() { 79 | return numCols; 80 | } 81 | 82 | /** 83 | * Returns a reference to the image data. This will be an array of length 84 | * numRows * numCols, containing values in [0,255] indicating the brightness of the pixels. 85 | * 86 | * @return The image data 87 | */ 88 | public byte[] getImageData() { 89 | return imageData; 90 | } 91 | 92 | /** 93 | * Creates a new buffered image from the image data that is stored in this entry. 94 | * 95 | * @return The image 96 | */ 97 | public BufferedImage createImage() { 98 | BufferedImage image = 99 | new BufferedImage(getNumCols(), getNumRows(), BufferedImage.TYPE_BYTE_GRAY); 100 | DataBuffer dataBuffer = image.getRaster().getDataBuffer(); 101 | DataBufferByte dataBufferByte = (DataBufferByte) dataBuffer; 102 | byte[] data = dataBufferByte.getData(); 103 | System.arraycopy(getImageData(), 0, data, 0, data.length); 104 | return image; 105 | } 106 | 107 | @Override 108 | public String toString() { 109 | String indexString = String.format("%05d", index); 110 | return "MnistEntry[" + "index=" + indexString + "," + "label=" + label + "]"; 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/mnistdata/MnistLoader.java: -------------------------------------------------------------------------------- 1 | package neuralnetworks.mnistdata; 2 | 3 | import java.io.*; 4 | 5 | /** 6 | * A class for reading the MNIST data set from the decompressed (unzipped) files that are 7 | * published at http://yann.lecun.com/exdb/mnist/. 8 | */ 9 | public class MnistLoader { 10 | /** Default constructor */ 11 | public MnistLoader() { 12 | // Default constructor 13 | } 14 | 15 | /** 16 | * Read bytes from the given input stream, filling the given array 17 | * 18 | * @param inputStream The input stream 19 | * @param data The array to be filled 20 | * @throws IOException If the input stream does not contain enough bytes to fill the array, or any 21 | * other IO error occurs 22 | */ 23 | private static void read(InputStream inputStream, byte[] data) throws IOException { 24 | int offset = 0; 25 | while (true) { 26 | int read = inputStream.read(data, offset, data.length - offset); 27 | if (read < 0) { 28 | break; 29 | } 30 | offset += read; 31 | if (offset == data.length) { 32 | return; 33 | } 34 | } 35 | throw new IOException("Tried to read " + data.length + " bytes, but only found " + offset); 36 | } 37 | 38 | /** 39 | * Read the MNIST training data from the given directory. The data is assumed to be located in 40 | * files with their default names, decompressed from the original files: extension) : 41 | * train-images.idx3-ubyte and train-labels.idx1-ubyte. 42 | * 43 | * @param inputDirectoryPath The input directory {@link MnistEntry} instances 44 | * @throws IOException If an IO error occurs 45 | */ 46 | public MnistEntry[] readDecompressedTraining(String inputDirectoryPath) throws IOException { 47 | String trainImagesFileName = "train-images.idx3-ubyte"; 48 | String trainLabelsFileName = "train-labels.idx1-ubyte"; 49 | String imagesFilePath = inputDirectoryPath + "\\" + trainImagesFileName; 50 | String labelsFilePath = inputDirectoryPath + "\\" + trainLabelsFileName; 51 | return readDecompressed(imagesFilePath, labelsFilePath); 52 | } 53 | 54 | /** 55 | * Read the MNIST training data from the given directory. The data is assumed to be located in 56 | * files with their default names, decompressed from the original files: extension) : 57 | * t10k-images.idx3-ubyte and t10k-labels.idx1-ubyte. 58 | * 59 | * @param inputDirectoryPath The input directory {@link MnistEntry} instances 60 | * @throws IOException If an IO error occurs 61 | */ 62 | public MnistEntry[] readDecompressedTesting(String inputDirectoryPath) throws IOException { 63 | String trainImagesFileName = "t10k-images.idx3-ubyte"; 64 | String trainLabelsFileName = "t10k-labels.idx1-ubyte"; 65 | String imagesFilePath = inputDirectoryPath + "\\" + trainImagesFileName; 66 | String labelsFilePath = inputDirectoryPath + "\\" + trainLabelsFileName; 67 | return readDecompressed(imagesFilePath, labelsFilePath); 68 | } 69 | 70 | /** 71 | * Read the MNIST data from the specified (decompressed) files. 72 | * 73 | * @param imagesFilePath The path of the images file 74 | * @param labelsFilePath The path of the labels file {@link MnistEntry} instances 75 | * @throws IOException If an IO error occurs 76 | */ 77 | public MnistEntry[] readDecompressed(String imagesFilePath, String labelsFilePath) 78 | throws IOException { 79 | try (InputStream decompressedImagesInputStream = new FileInputStream(new File(imagesFilePath)); 80 | InputStream decompressedLabelsInputStream = new FileInputStream(new File(labelsFilePath))) { 81 | return readDecompressed(decompressedImagesInputStream, decompressedLabelsInputStream); 82 | } 83 | } 84 | 85 | /** 86 | * Read the MNIST data from the given (decompressed) input streams. The caller is responsible for 87 | * closing the given streams. 88 | * 89 | * @param decompressedImagesInputStream The decompressed input stream containing the image data 90 | * @param decompressedLabelsInputStream The decompressed input stream containing the label data 91 | * {@link MnistEntry} instances 92 | * @throws IOException If an IO error occurs 93 | */ 94 | public MnistEntry[] readDecompressed( 95 | InputStream decompressedImagesInputStream, InputStream decompressedLabelsInputStream) 96 | throws IOException { 97 | DataInputStream imagesDataInputStream = new DataInputStream(decompressedImagesInputStream); 98 | DataInputStream labelsDataInputStream = new DataInputStream(decompressedLabelsInputStream); 99 | 100 | int magicImages = imagesDataInputStream.readInt(); 101 | if (magicImages != 0x803) { 102 | throw new IOException( 103 | "Expected magic header of 0x803 " + "for images, but found " + magicImages); 104 | } 105 | 106 | int magicLabels = labelsDataInputStream.readInt(); 107 | if (magicLabels != 0x801) { 108 | throw new IOException( 109 | "Expected magic header of 0x801 " + "for labels, but found " + magicLabels); 110 | } 111 | 112 | int numberOfImages = imagesDataInputStream.readInt(); 113 | int numberOfLabels = labelsDataInputStream.readInt(); 114 | 115 | if (numberOfImages != numberOfLabels) { 116 | throw new IOException( 117 | "Found " + numberOfImages + " images but " + numberOfLabels + " labels"); 118 | } 119 | 120 | int numRows = imagesDataInputStream.readInt(); 121 | int numCols = imagesDataInputStream.readInt(); 122 | 123 | MnistEntry[] mnistEntries = new MnistEntry[numberOfImages]; 124 | 125 | for (int n = 0; n < numberOfImages; n++) { 126 | byte label = labelsDataInputStream.readByte(); 127 | byte[] imageData = new byte[numRows * numCols]; 128 | read(imagesDataInputStream, imageData); 129 | 130 | MnistEntry mnistEntry = new MnistEntry(n, label, numRows, numCols, imageData); 131 | mnistEntries[n] = mnistEntry; 132 | } 133 | 134 | return mnistEntries; 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/mnistdata/MnistWeights.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishantc1527/Algorithms-Java/060a2169b4fbe80a20adbe5140ff04c2d67c3d62/src/main/java/neuralnetworks/mnistdata/MnistWeights.txt -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/mnistdata/test_images.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishantc1527/Algorithms-Java/060a2169b4fbe80a20adbe5140ff04c2d67c3d62/src/main/java/neuralnetworks/mnistdata/test_images.gz -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/mnistdata/test_labels.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishantc1527/Algorithms-Java/060a2169b4fbe80a20adbe5140ff04c2d67c3d62/src/main/java/neuralnetworks/mnistdata/test_labels.gz -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/mnistdata/training_images.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishantc1527/Algorithms-Java/060a2169b4fbe80a20adbe5140ff04c2d67c3d62/src/main/java/neuralnetworks/mnistdata/training_images.gz -------------------------------------------------------------------------------- /src/main/java/neuralnetworks/mnistdata/training_labels.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishantc1527/Algorithms-Java/060a2169b4fbe80a20adbe5140ff04c2d67c3d62/src/main/java/neuralnetworks/mnistdata/training_labels.gz -------------------------------------------------------------------------------- /src/main/java/sorting/bogosort/BogoSort.java: -------------------------------------------------------------------------------- 1 | package sorting.bogosort; 2 | 3 | public class BogoSort { 4 | 5 | public static void sort(int[] arr) { 6 | while (!isSorted(arr)) { 7 | shuffle(arr); 8 | } 9 | } 10 | 11 | private static void shuffle(int[] arr) { 12 | for (int i = 0; i < arr.length; i++) { 13 | swap(arr, i, (int) (Math.random() * arr.length)); 14 | } 15 | } 16 | 17 | private static void swap(int[] arr, int i, int j) { 18 | int temp = arr[i]; 19 | arr[i] = arr[j]; 20 | arr[j] = temp; 21 | } 22 | 23 | private static boolean isSorted(int[] arr) { 24 | int i = arr.length; 25 | 26 | while (i-- > 1) { 27 | if (arr[i] < arr[i - 1]) { 28 | return false; 29 | } 30 | } 31 | 32 | return true; 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/sorting/bubblesort/BubbleSort.java: -------------------------------------------------------------------------------- 1 | package sorting.bubblesort; 2 | 3 | import java.util.Arrays; 4 | 5 | public class BubbleSort { 6 | 7 | public static void sort(int[] arr) { 8 | boolean didFind = true; 9 | 10 | for (int end = arr.length; end > 0 && didFind; end--) { 11 | didFind = false; 12 | 13 | for (int i = 0; i < end - 1; i++) { 14 | if (arr[i] > arr[i + 1]) { 15 | swap(arr, i, i + 1); 16 | didFind = true; 17 | } 18 | } 19 | } 20 | } 21 | 22 | private static void swap(int[] arr, int pos1, int pos2) { 23 | int temp = arr[pos1]; 24 | arr[pos1] = arr[pos2]; 25 | arr[pos2] = temp; 26 | } 27 | 28 | public static void main(String[] args) { 29 | int[] arr = {1, 5, 3, 5, 3, 2, 4, 5, 2}; 30 | System.out.println(Arrays.toString(arr)); 31 | sort(arr); 32 | System.out.println(Arrays.toString(arr)); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/sorting/countingsort/CountingSort.java: -------------------------------------------------------------------------------- 1 | package sorting.countingsort; 2 | 3 | import java.util.Arrays; 4 | 5 | public class CountingSort { 6 | 7 | public static void sort(int[] arr) { 8 | int[] minMax = minAndMax(arr); 9 | int[] count = new int[minMax[1] - minMax[0] + 1]; 10 | 11 | for (int value : arr) { 12 | count[value - minMax[0]]++; 13 | } 14 | 15 | int k = 0; 16 | for (int i = 0; i < count.length; i++) { 17 | while (count[i] > 0) { 18 | arr[k++] = i + minMax[0]; 19 | count[i]--; 20 | } 21 | } 22 | } 23 | 24 | private static int[] minAndMax(int[] arr) { 25 | int min = Integer.MAX_VALUE, max = Integer.MIN_VALUE; 26 | 27 | for (int value : arr) { 28 | if (value < min) { 29 | min = value; 30 | } 31 | 32 | if (value > max) { 33 | max = value; 34 | } 35 | } 36 | 37 | return new int[] {min, max}; 38 | } 39 | 40 | public static void main(String[] args) { 41 | int[] arr = {1, 5, 3, 5, 3, 2, 4, 5, 2}; 42 | System.out.println(Arrays.toString(arr)); 43 | sort(arr); 44 | System.out.println(Arrays.toString(arr)); 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /src/main/java/sorting/heapsort/HeapSort.java: -------------------------------------------------------------------------------- 1 | package sorting.heapsort; 2 | 3 | import java.util.Arrays; 4 | 5 | public class HeapSort { 6 | 7 | private static int heapSize; 8 | 9 | public static void sort(int[] arr) { 10 | buildHeap(arr); 11 | while (heapSize > 0) { 12 | arr[heapSize - 1] = popMax(arr); 13 | } 14 | } 15 | 16 | private static void buildHeap(int[] arr) { 17 | heapSize = arr.length; 18 | 19 | for (int i = (arr.length >> 1) - 1; i >= 0; i--) { 20 | heapify(arr, i); 21 | } 22 | } 23 | 24 | private static void heapify(int[] arr, int i) { 25 | int left = getLeft(i), right = getRight(i), largest = i; 26 | 27 | if (left < heapSize && arr[left] > arr[largest]) { 28 | largest = left; 29 | } 30 | 31 | if (right < heapSize && arr[right] > arr[largest]) { 32 | largest = right; 33 | } 34 | 35 | if (i != largest) { 36 | swap(arr, i, largest); 37 | heapify(arr, largest); 38 | } 39 | } 40 | 41 | private static int popMax(int[] heap) { 42 | int toReturn = heap[0]; 43 | swap(heap, 0, heapSize - 1); 44 | heapSize--; 45 | heapify(heap, 0); 46 | return toReturn; 47 | } 48 | 49 | private static int getLeft(int i) { 50 | return (i << 1) + 1; 51 | } 52 | 53 | private static int getRight(int i) { 54 | return (i << 1) + 2; 55 | } 56 | 57 | private static void swap(int[] arr, int pos1, int pos2) { 58 | int temp = arr[pos1]; 59 | arr[pos1] = arr[pos2]; 60 | arr[pos2] = temp; 61 | } 62 | 63 | public static void main(String[] args) { 64 | int[] arr = {1, 5, 3, 5, 3, 2, 4, 5, 2}; 65 | System.out.println(Arrays.toString(arr)); 66 | sort(arr); 67 | System.out.println(Arrays.toString(arr)); 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/main/java/sorting/insertionsort/InsertionSort.java: -------------------------------------------------------------------------------- 1 | package sorting.insertionsort; 2 | 3 | import java.util.Arrays; 4 | 5 | public class InsertionSort { 6 | 7 | public static void sort(int[] arr) { 8 | for (int toInsert = 0; toInsert < arr.length; toInsert++) { 9 | int curr = arr[toInsert], i; 10 | 11 | for (i = toInsert; i >= 1 && arr[i - 1] > curr; i--) { 12 | arr[i] = arr[i - 1]; 13 | } 14 | 15 | arr[i] = curr; 16 | } 17 | } 18 | 19 | public static void main(String[] args) { 20 | int[] arr = {1, 5, 3, 5, 3, 2, 4, 5, 2}; 21 | System.out.println(Arrays.toString(arr)); 22 | sort(arr); 23 | System.out.println(Arrays.toString(arr)); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/sorting/mergesort/MergeSort.java: -------------------------------------------------------------------------------- 1 | package sorting.mergesort; 2 | 3 | import java.util.Arrays; 4 | 5 | public class MergeSort { 6 | 7 | public static void sort(int[] arr) { 8 | sort(arr, 0, arr.length); 9 | } 10 | 11 | private static void sort(int[] arr, int left, int right) { 12 | if (left >= right - 1) { 13 | return; 14 | } 15 | 16 | int mid = (left + right) / 2; 17 | sort(arr, left, mid); 18 | sort(arr, mid, right); 19 | 20 | int[] leftSide = new int[mid - left]; 21 | int[] rightSide = new int[right - mid]; 22 | 23 | System.arraycopy(arr, left, leftSide, 0, leftSide.length); 24 | System.arraycopy(arr, mid, rightSide, 0, rightSide.length); 25 | 26 | merge(arr, leftSide, rightSide, left); 27 | } 28 | 29 | private static void merge(int[] arr, int[] left, int[] right, int startPos) { 30 | int k = startPos; 31 | 32 | for (int i = 0, j = 0; i < left.length || j < right.length; ) { 33 | if (i >= left.length) { 34 | arr[k++] = right[j++]; 35 | } else if (j >= right.length) { 36 | arr[k++] = left[i++]; 37 | } else { 38 | if (left[i] < right[j]) { 39 | arr[k++] = left[i++]; 40 | } else { 41 | arr[k++] = right[j++]; 42 | } 43 | } 44 | } 45 | } 46 | 47 | public static void main(String[] args) { 48 | int[] arr = {1, 5, 3, 5, 3, 2, 4, 5, 2}; 49 | System.out.println(Arrays.toString(arr)); 50 | sort(arr); 51 | System.out.println(Arrays.toString(arr)); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/sorting/quicksort/QuickSort.java: -------------------------------------------------------------------------------- 1 | package sorting.quicksort; 2 | 3 | import java.util.Arrays; 4 | 5 | public class QuickSort { 6 | 7 | public static void sort(int[] arr) { 8 | sort(arr, 0, arr.length); 9 | } 10 | 11 | private static void sort(int[] arr, int left, int right) { 12 | if (left >= right - 1) { 13 | return; 14 | } 15 | 16 | int pivotIndex = partition(arr, left, right); 17 | sort(arr, left, pivotIndex); 18 | sort(arr, pivotIndex + 1, right); 19 | } 20 | 21 | private static int partition(int[] arr, int left, int right) { 22 | int pivotIndex = right - 1, partitionIndex = left; 23 | 24 | for (int i = left; i < right - 1; i++) { 25 | if (arr[i] < arr[pivotIndex]) { 26 | if (partitionIndex == pivotIndex) { 27 | pivotIndex = i; 28 | } 29 | 30 | swap(arr, i, partitionIndex++); 31 | } 32 | } 33 | 34 | swap(arr, pivotIndex, partitionIndex); 35 | return partitionIndex; 36 | } 37 | 38 | private static void swap(int[] arr, int pos1, int pos2) { 39 | int temp = arr[pos1]; 40 | arr[pos1] = arr[pos2]; 41 | arr[pos2] = temp; 42 | } 43 | 44 | public static void main(String[] args) { 45 | int[] arr = {1, 5, 3, 5, 3, 2, 4, 5, 2}; 46 | System.out.println(Arrays.toString(arr)); 47 | sort(arr); 48 | System.out.println(Arrays.toString(arr)); 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/java/sorting/selectionsort/SelectionSort.java: -------------------------------------------------------------------------------- 1 | package sorting.selectionsort; 2 | 3 | import java.util.Arrays; 4 | 5 | public class SelectionSort { 6 | 7 | public static void sort(int[] arr) { 8 | for (int start = 0; start < arr.length - 1; start++) { 9 | int min = Integer.MAX_VALUE, minIndex = -1; 10 | 11 | for (int i = start; i < arr.length; i++) { 12 | if (arr[i] < min) { 13 | min = arr[i]; 14 | minIndex = i; 15 | } 16 | } 17 | 18 | swap(arr, start, minIndex); 19 | } 20 | } 21 | 22 | private static void swap(int[] arr, int pos1, int pos2) { 23 | int temp = arr[pos1]; 24 | arr[pos1] = arr[pos2]; 25 | arr[pos2] = temp; 26 | } 27 | 28 | public static void main(String[] args) { 29 | int[] arr = {1, 5, 3, 5, 3, 2, 4, 5, 2}; 30 | System.out.println(Arrays.toString(arr)); 31 | sort(arr); 32 | System.out.println(Arrays.toString(arr)); 33 | } 34 | } 35 | --------------------------------------------------------------------------------