├── .github └── workflows │ └── maven.yml ├── .gitignore ├── LICENSE ├── README.md ├── pom.xml └── src ├── main └── java │ └── com │ └── lewdev │ └── probabilitylib │ └── ProbabilityCollection.java └── test └── java └── com └── lewdev └── probabilitylib ├── BenchmarkProbability.java ├── ExampleApp.java ├── ProbabilityCollectionTest.java └── ProbabilityMap.java /.github/workflows/maven.yml: -------------------------------------------------------------------------------- 1 | # This workflow will build a Java project with Maven 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/building-and-testing-java-with-maven 3 | 4 | name: Build 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up JDK 8 20 | uses: actions/setup-java@v4 21 | with: 22 | distribution: temurin 23 | java-version: '8' 24 | - name: Build with Maven 25 | run: mvn -B package --file pom.xml 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target/ 2 | .classpath 3 | .project 4 | *.settings 5 | .editorconfig -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Lewys Davies 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 
12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Java-Probability-Collection 2 | [![Scrutinizer Code Quality](https://scrutinizer-ci.com/g/lewysDavies/Java-Probability-Collection/badges/quality-score.png?b=master)](https://scrutinizer-ci.com/g/lewysDavies/Java-Probability-Collection/?branch=master) [![Build Status](https://scrutinizer-ci.com/g/lewysDavies/Java-Probability-Collection/badges/build.png?b=master)](https://scrutinizer-ci.com/g/lewysDavies/Java-Probability-Collection/build-status/master) [![](https://jitpack.io/v/lewysDavies/Java-Probability-Collection.svg)](https://jitpack.io/#lewysDavies/Java-Probability-Collection)
3 | Generic and Highly Optimised Java Data-Structure for Retrieving Random Elements with Probability 4 | 5 | # Usage 6 | ``` 7 | ProbabilityCollection<String> collection = new ProbabilityCollection<>(); 8 | collection.add("A", 50); // 50 / 85 (total probability) = 0.588 * 100 = 58.8% Chance 9 | collection.add("B", 25); // 25 / 85 (total probability) = 0.294 * 100 = 29.4% Chance 10 | collection.add("C", 10); // 10 / 85 (total probability) = 0.118 * 100 = 11.8% Chance 11 | 12 | String random = collection.get(); 13 | ``` 14 | 15 | # Proven Probability 16 | The probability test is run **1,000,000 times**, each time drawing **100,000** random elements and counting how often each one appears. The test fails if any element's observed share deviates from its expected probability by more than **1 percentage point**. 17 | 18 | A real-world example is provided in ExampleApp.java (in the test folder). Typical output with 100,000 gets: 19 | ``` 20 | Prob | Actual 21 | ----------------------- 22 | A: 58.824% | 58.975% 23 | B: 29.412% | 29.256% 24 | C: 11.765% | 11.769% 25 | ``` 26 | 27 | # Performance 28 | Get performance is significantly better than in my previous map-based implementation. This was achieved by using TreeSets with a custom comparator. 29 | ``` 30 | Benchmark Mode Cnt Score Error Units 31 | BenchmarkProbability.collectionAddSingle avgt 5 501.688 ± 33.925 ns/op 32 | BenchmarkProbability.collectionGet avgt 5 69.373 ± 2.198 ns/op 33 | BenchmarkProbability.mapAddSingle avgt 5 25809.712 ± 984.980 ns/op 34 | BenchmarkProbability.mapGet avgt 5 902.414 ± 22.388 ns/op 35 | ``` 36 | 37 | # Installation 38 | **Super Simple: Copy ProbabilityCollection.java into your project**

39 | Or, for the fancy users, you can use Maven:
40 | **Repository:** 41 | ``` 42 | 43 | jitpack.io 44 | https://jitpack.io 45 | 46 | ``` 47 | **Dependency:** 48 | ``` 49 | 50 | com.github.lewysDavies 51 | Java-Probability-Collection 52 | v0.8 53 | 54 | ``` 55 | **Maven Shade This Dependency:** 56 | ``` 57 | 58 | org.apache.maven.plugins 59 | maven-shade-plugin 60 | 3.1.1 61 | 62 | 63 | 64 | 65 | 66 | 67 | com.lewdev.probabilitylib 68 | ******.probabilitylib 69 | 70 | 71 | 72 | package 73 | 74 | shade 75 | 76 | 77 | 78 | 79 | ``` 80 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 4 | 4.0.0 5 | <!-- probability-lib build: test dependencies (JUnit Jupiter, Hamcrest), JMH, and the compiler and shade plugins --> 6 | com.lewdev 7 | probability-lib 8 | 0.8 9 | jar 10 | 11 | probability-lib 12 | lewdev.uk 13 | 14 | 15 | UTF-8 16 | 1.8 17 | 1.23 18 | 19 | 20 | 21 | 22 | org.junit.jupiter 23 | junit-jupiter-engine 24 | 5.1.0 <!-- NOTE(review): JUnit Jupiter tests only run under maven-surefire-plugin 2.22.0 or newer, and no Surefire version is declared in this POM; confirm the tests actually execute during `mvn package` --> 25 | test 26 | 27 | 28 | 29 | org.hamcrest 30 | hamcrest-core 31 | 1.3 32 | test 33 | 34 | 35 | 36 | org.openjdk.jmh 37 | jmh-core 38 | ${jmh.version} <!-- NOTE(review): unlike junit-jupiter-engine and hamcrest-core, the JMH dependencies declare no test scope, although the only benchmark class lives under src/test; consumers of this library may inherit JMH transitively --> 39 | 40 | 41 | org.openjdk.jmh 42 | jmh-generator-annprocess 43 | ${jmh.version} 44 | 45 | 46 | 47 | 48 | 49 | 50 | org.apache.maven.plugins 51 | maven-compiler-plugin 52 | 3.1 53 | 54 | ${java.version} 55 | ${java.version} 56 | 57 | 58 | 59 | 60 | org.apache.maven.plugins 61 | maven-shade-plugin 62 | 3.2.0 63 | 64 | 65 | package 66 | 67 | shade 68 | 69 | 70 | benchmarks <!-- NOTE(review): this shaded jar uses org.openjdk.jmh.Main as its entry point, but BenchmarkProbability is a test class and is presumably not packaged into it; verify --> 71 | 72 | 74 | org.openjdk.jmh.Main 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /src/main/java/com/lewdev/probabilitylib/ProbabilityCollection.java: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020 Lewys Davies 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a copy 5 | * of this software and associated documentation files (the "Software"), to deal 6 | * in the Software without restriction,
including without limitation the rights 7 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | * copies of the Software, and to permit persons to whom the Software is 9 | * furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in all 12 | * copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 20 | * SOFTWARE. 21 | */ 22 | package com.lewdev.probabilitylib; 23 | 24 | import java.util.Comparator; 25 | import java.util.Iterator; 26 | import java.util.NavigableSet; 27 | import java.util.Objects; 28 | import java.util.SplittableRandom; 29 | import java.util.TreeSet; 30 | 31 | /** 32 | * ProbabilityCollection for retrieving random elements based on probability. 33 | *
34 | *
35 | * Selection Algorithm Implementation: 36 | *

37 | *
38 | * - Each element owns a "block" of space, sized by its probability share 39 | * - Blocks are laid out back to back starting at index 1, so together they span 40 | * 1 to the total probability 41 | * - A random number between 1 and the total probability (inclusive) is drawn 42 | * - The selected element is the one whose block contains that number, found 43 | * with {@link NavigableSet#floor(Object)} on the blocks' start indexes 44 | * - Larger blocks therefore have a proportionally greater chance of being 45 | * selected than smaller ones 46 | *
47 | * 48 | * @author Lewys Davies 49 | * @version 0.8 50 | * 51 | * @param Type of elements 52 | */ 53 | public final class ProbabilityCollection { 54 | 55 | private final NavigableSet> collection; 56 | private final SplittableRandom random = new SplittableRandom(); 57 | 58 | private int totalProbability; 59 | 60 | /** 61 | * Construct a new, empty Probability Collection 62 | */ 63 | public ProbabilityCollection() { 64 | this.collection = new TreeSet<>(Comparator.comparingInt(ProbabilitySetElement::getIndex)); 65 | this.totalProbability = 0; 66 | } 67 | 68 | /** 69 | * @return Number of entries inside the collection (an object added twice counts twice) 70 | */ 71 | public int size() { 72 | return this.collection.size(); 73 | } 74 | 75 | /** 76 | * @return True if collection contains no elements, else False 77 | */ 78 | public boolean isEmpty() { 79 | return this.collection.isEmpty(); 80 | } 81 | 82 | /** 83 | * @param object Element to look for, compared with {@link Object#equals(Object)} 84 | * @return True if collection contains the object, else False 85 | * @throws IllegalArgumentException if object is null 86 | */ 87 | public boolean contains(E object) { 88 | if (object == null) { 89 | throw new IllegalArgumentException("Cannot check if null object is contained in this collection"); 90 | } 91 | 92 | return this.collection.stream().anyMatch(entry -> entry.getObject().equals(object)); 93 | } 94 | 95 | /** 96 | * @return Read-only iterator over the entries in index order. Its remove() throws UnsupportedOperationException; use {@link #remove(Object)} instead so the total probability and indexes stay consistent 97 | */ 98 | public Iterator> iterator() { 99 | return java.util.Collections.unmodifiableNavigableSet(this.collection).iterator(); 100 | } 101 | 102 | /** 103 | * Add an object to this collection 104 | * 105 | * @param object Element to add. Not null. Adding an equal object again creates a separate entry. 106 | * @param probability Share of the total probability. Must be greater than 0.
107 | * @throws ArithmeticException if the total probability would overflow an int 108 | * @throws IllegalArgumentException if object is null 109 | * @throws IllegalArgumentException if probability <= 0 110 | */ 111 | public void add(E object, int probability) { 112 | if (object == null) { 113 | throw new IllegalArgumentException("Cannot add null object"); 114 | } 115 | 116 | if (probability <= 0) { 117 | throw new IllegalArgumentException("Probability must be greater than 0"); 118 | } 119 | int newTotal = Math.addExact(this.totalProbability, probability); // fail fast before any state is mutated 120 | ProbabilitySetElement entry = new ProbabilitySetElement(object, probability); 121 | entry.setIndex(this.totalProbability + 1); 122 | 123 | this.collection.add(entry); 124 | this.totalProbability = newTotal; 125 | } 126 | 127 | /** 128 | * Remove every occurrence of an object from this collection 129 | * 130 | * @param object Element to remove, compared with {@link Object#equals(Object)} 131 | * @return True if at least one entry was removed, else False. 132 | * 133 | * @throws IllegalArgumentException if object is null 134 | */ 135 | public boolean remove(E object) { 136 | if (object == null) { 137 | throw new IllegalArgumentException("Cannot remove null object"); 138 | } 139 | 140 | Iterator> it = this.collection.iterator(); // internal iterator: the public iterator() is read-only 141 | boolean removed = false; 142 | 143 | // Remove all instances of the object 144 | while (it.hasNext()) { 145 | ProbabilitySetElement entry = it.next(); 146 | if (entry.getObject().equals(object)) { 147 | this.totalProbability -= entry.getProbability(); 148 | it.remove(); 149 | removed = true; 150 | } 151 | } 152 | 153 | // Recalculate remaining elements "block" of space: i.e 1-5, 6-10, 11-14. Indexes are reassigned in iteration order, so the TreeSet ordering stays valid 154 | if (removed) { 155 | int previousIndex = 0; 156 | for (ProbabilitySetElement entry : this.collection) { 157 | previousIndex = entry.setIndex(previousIndex + 1) + (entry.getProbability() - 1); 158 | } 159 | } 160 | 161 | return removed; 162 | } 163 | 164 | /** 165 | * Remove all objects from this collection 166 | */ 167 | public void clear() { 168 | this.collection.clear(); 169 | this.totalProbability = 0; 170 | } 171 | 172 | /** 173 | * Get a random object from this collection, based on probability.
174 | * 175 | * @return Random object 176 | * 177 | * @throws IllegalStateException if this collection is empty 178 | */ 179 | public E get() { 180 | if (this.isEmpty()) { 181 | throw new IllegalStateException("Cannot get an object out of a empty collection"); 182 | } 183 | // Search key: floor() returns the entry with the greatest start index <= the drawn number, i.e. the block containing it 184 | ProbabilitySetElement toFind = new ProbabilitySetElement<>(null, 0); 185 | toFind.setIndex(this.random.nextInt(1, this.totalProbability + 1)); 186 | 187 | return Objects.requireNonNull(this.collection.floor(toFind).getObject()); 188 | } 189 | 190 | /** 191 | * @return Sum of all elements' probability shares 192 | */ 193 | public int getTotalProbability() { 194 | return this.totalProbability; 195 | } 196 | 197 | /** 198 | * Used internally to store information about an object's state in a collection. 199 | * Specifically, the probability and index within the collection. 200 | * 201 | * Indexes refer to the start position of this element's "block" of space. The 202 | * space between element "block"s represents their probability of being selected 203 | * 204 | * @author Lewys Davies 205 | * 206 | * @param Type of element 207 | */ 208 | public static final class ProbabilitySetElement { 209 | private final T object; 210 | private final int probability; 211 | private int index; 212 | 213 | /** 214 | * @param object The stored element (null only for the search key built by get()) 215 | * @param probability share within the collection 216 | */ 217 | protected ProbabilitySetElement(T object, int probability) { 218 | this.object = object; 219 | this.probability = probability; 220 | } 221 | 222 | /** 223 | * @return The actual object 224 | */ 225 | public T getObject() { 226 | return this.object; 227 | } 228 | 229 | /** 230 | * @return Probability share in this collection 231 | */ 232 | public int getProbability() { 233 | return this.probability; 234 | } 235 | 236 | // Used internally, see this class's documentation 237 | private int getIndex() { 238 | return this.index; 239 | } 240 | 241 | // Used internally, see this class's documentation. Returns the new index so remove() can chain the block recalculation 242 | private int setIndex(int
index) { 243 | this.index = index; 244 | return this.index; 245 | } 246 | } 247 | } 248 | -------------------------------------------------------------------------------- /src/test/java/com/lewdev/probabilitylib/BenchmarkProbability.java: -------------------------------------------------------------------------------- 1 | package com.lewdev.probabilitylib; 2 | 3 | import java.util.concurrent.TimeUnit; 4 | 5 | import org.openjdk.jmh.annotations.Benchmark; 6 | import org.openjdk.jmh.annotations.BenchmarkMode; 7 | import org.openjdk.jmh.annotations.Fork; 8 | import org.openjdk.jmh.annotations.Level; 9 | import org.openjdk.jmh.annotations.Mode; 10 | import org.openjdk.jmh.annotations.OutputTimeUnit; 11 | import org.openjdk.jmh.annotations.Scope; 12 | import org.openjdk.jmh.annotations.Setup; 13 | import org.openjdk.jmh.annotations.State; 14 | import org.openjdk.jmh.annotations.TearDown; 15 | import org.openjdk.jmh.infra.Blackhole; 16 | import org.openjdk.jmh.runner.Runner; 17 | import org.openjdk.jmh.runner.RunnerException; 18 | import org.openjdk.jmh.runner.options.Options; 19 | import org.openjdk.jmh.runner.options.OptionsBuilder; 20 | 21 | @BenchmarkMode(Mode.AverageTime) 22 | @OutputTimeUnit(TimeUnit.NANOSECONDS) 23 | @State(Scope.Benchmark) 24 | @Fork(value = 2, jvmArgs = {"-Xms2G", "-Xmx2G"}) 25 | public class BenchmarkProbability { 26 | 27 | public static void main(String[] args) throws RunnerException { 28 | Options opt = new OptionsBuilder() 29 | .include(BenchmarkProbability.class.getSimpleName()) 30 | .forks(1) 31 | .build(); 32 | 33 | new Runner(opt).run(); 34 | } 35 | 36 | public int elements = 1_000; 37 | 38 | public int toAdd = elements + 1; 39 | public int toAddProb = 10; 40 | 41 | private ProbabilityMap map; 42 | private ProbabilityCollection collection; 43 | 44 | @Setup(Level.Iteration) 45 | public void setup() { 46 | this.map = new ProbabilityMap<>(); 47 | this.collection = new ProbabilityCollection<>(); 48 | 49 | for(int i = 0; i < elements; i++) {
50 | map.add(i, 1); 51 | collection.add(i, 1); 52 | } 53 | } 54 | 55 | @TearDown(Level.Iteration) 56 | public void tearDown() { 57 | this.map.clear(); 58 | this.collection.clear(); 59 | 60 | this.map = null; 61 | this.collection = null; 62 | } 63 | 64 | @Benchmark 65 | public void mapAddSingle() { 66 | this.map.add(toAdd, toAddProb); 67 | } 68 | 69 | @Benchmark 70 | public void collectionAddSingle() { 71 | this.collection.add(toAdd, toAddProb); 72 | } 73 | 74 | @Benchmark 75 | public void mapGet(Blackhole bh) { 76 | bh.consume(this.map.get()); 77 | } 78 | 79 | @Benchmark 80 | public void collectionGet(Blackhole bh) { 81 | bh.consume(this.collection.get()); 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /src/test/java/com/lewdev/probabilitylib/ExampleApp.java: -------------------------------------------------------------------------------- 1 | package com.lewdev.probabilitylib; 2 | 3 | public class ExampleApp { 4 | 5 | public static void main(String[] args) { 6 | ProbabilityCollection collection = new ProbabilityCollection<>(); 7 | 8 | collection.add("A", 50); 9 | collection.add("B", 25); 10 | collection.add("C", 10); 11 | 12 | int a = 0, b = 0, c = 0; 13 | int totalGets = 100000; 14 | 15 | for(int i = 0; i < totalGets; i++) { 16 | String random = collection.get(); 17 | 18 | if(random.equals("A")) a++; 19 | else if(random.equals("B")) b++; 20 | else if(random.equals("C")) c++; 21 | } 22 | 23 | double aProb = 50.0 / (double) collection.getTotalProbability() * 100; 24 | double bProb = 25.0 / (double) collection.getTotalProbability() * 100; 25 | double cProb = 10.0 / (double) collection.getTotalProbability() * 100; 26 | 27 | double aResult = a / (double) totalGets * 100; 28 | double bResult = b / (double) totalGets * 100; 29 | double cResult = c / (double) totalGets * 100; 30 | 31 | System.out.println(" Prob | Actual"); 32 | System.out.println("-----------------------"); 33 | System.out.printf("A: %.3f%% | %.3f%%
\n", aProb, aResult); 34 | System.out.printf("B: %.3f%% | %.3f%% \n", bProb, bResult); 35 | System.out.printf("C: %.3f%% | %.3f%% \n", cProb, cResult); 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/test/java/com/lewdev/probabilitylib/ProbabilityCollectionTest.java: -------------------------------------------------------------------------------- 1 | package com.lewdev.probabilitylib; 2 | 3 | import static org.junit.jupiter.api.Assertions.*; 4 | 5 | import org.junit.jupiter.api.RepeatedTest; 6 | import org.junit.jupiter.api.Test; 7 | 8 | /** 9 | * @author Lewys Davies 10 | */ 11 | public class ProbabilityCollectionTest { 12 | 13 | @RepeatedTest(value = 10_000) 14 | public void test_insert() { 15 | ProbabilityCollection collection = new ProbabilityCollection<>(); 16 | assertEquals(0, collection.size()); 17 | assertTrue(collection.isEmpty()); 18 | assertEquals(0, collection.getTotalProbability()); 19 | 20 | collection.add("A", 2); 21 | assertTrue(collection.contains("A")); 22 | assertEquals(1, collection.size()); 23 | assertFalse(collection.isEmpty()); 24 | assertEquals(2, collection.getTotalProbability()); // 2 25 | 26 | collection.add("B", 5); 27 | assertTrue(collection.contains("B")); 28 | assertEquals(2, collection.size()); 29 | assertFalse(collection.isEmpty()); 30 | assertEquals(7, collection.getTotalProbability()); // 5 + 2 31 | 32 | collection.add("C", 10); 33 | assertTrue(collection.contains("C")); 34 | assertEquals(3, collection.size()); 35 | assertFalse(collection.isEmpty()); 36 | assertEquals(17, collection.getTotalProbability()); // 5 + 2 + 10 37 | 38 | for(int i = 0; i < 100; i++) { 39 | collection.add("C", 1); 40 | 41 | assertTrue(collection.contains("C")); 42 | assertEquals(4 + i, collection.size()); // 4 + i 43 | assertFalse(collection.isEmpty()); 44 | assertEquals(18 + i, collection.getTotalProbability()); // 5 + 2 + 10 + i 45 | } 46 | } 47 | 48 | @RepeatedTest(value = 10_000) 49 | public void
test_remove() { 50 | ProbabilityCollection collection = new ProbabilityCollection<>(); 51 | assertEquals(0, collection.size()); 52 | assertTrue(collection.isEmpty()); 53 | assertEquals(0, collection.getTotalProbability()); 54 | 55 | String t1 = "Hello"; 56 | String t2 = "World"; 57 | String t3 = "!"; 58 | 59 | collection.add(t1, 10); 60 | collection.add(t2, 10); 61 | collection.add(t3, 10); 62 | 63 | assertEquals(3, collection.size()); 64 | assertFalse(collection.isEmpty()); 65 | assertEquals(30, collection.getTotalProbability()); 66 | 67 | // Remove t2 68 | assertTrue(collection.remove(t2)); 69 | 70 | assertEquals(2, collection.size()); 71 | assertFalse(collection.isEmpty()); 72 | assertEquals(20, collection.getTotalProbability()); 73 | 74 | // Remove t1 75 | assertTrue(collection.remove(t1)); 76 | 77 | assertEquals(1, collection.size()); 78 | assertFalse(collection.isEmpty()); 79 | assertEquals(10, collection.getTotalProbability()); 80 | 81 | //Remove t3 82 | assertTrue(collection.remove(t3)); 83 | 84 | assertEquals(0, collection.size()); 85 | assertTrue(collection.isEmpty()); 86 | assertEquals(0, collection.getTotalProbability()); 87 | } 88 | 89 | @RepeatedTest(value = 10_000) 90 | public void test_remove_duplicates() { 91 | ProbabilityCollection collection = new ProbabilityCollection<>(); 92 | assertEquals(0, collection.size()); 93 | assertTrue(collection.isEmpty()); 94 | 95 | String t1 = "Hello"; 96 | String t2 = "World"; 97 | String t3 = "!"; 98 | 99 | for(int i = 0; i < 10; i++) { 100 | collection.add(t1, 10); 101 | } 102 | 103 | for(int i = 0; i < 10; i++) { 104 | collection.add(t2, 10); 105 | } 106 | 107 | for(int i = 0; i < 10; i++) { 108 | collection.add(t3, 10); 109 | } 110 | 111 | assertEquals(30, collection.size()); 112 | assertFalse(collection.isEmpty()); 113 | assertEquals(300, collection.getTotalProbability()); 114 | 115 | //Remove t2 116 | assertTrue(collection.remove(t2)); 117 | 118 | assertEquals(20, collection.size()); 119 | 
assertFalse(collection.isEmpty()); 120 | assertEquals(200, collection.getTotalProbability()); 121 | 122 | // Remove t1 123 | assertTrue(collection.remove(t1)); 124 | 125 | assertEquals(10, collection.size()); 126 | assertFalse(collection.isEmpty()); 127 | assertEquals(100, collection.getTotalProbability()); 128 | 129 | //Remove t3 130 | assertTrue(collection.remove(t3)); 131 | 132 | assertEquals(0, collection.size()); 133 | assertTrue(collection.isEmpty()); 134 | assertEquals(0, collection.getTotalProbability()); 135 | } 136 | 137 | @RepeatedTest(value = 10_000) 138 | public void test_clear() { 139 | ProbabilityCollection collection = new ProbabilityCollection<>(); 140 | assertEquals(0, collection.size()); 141 | assertTrue(collection.isEmpty()); 142 | assertEquals(0, collection.getTotalProbability()); 143 | 144 | collection.clear(); 145 | 146 | assertEquals(0, collection.size()); 147 | assertTrue(collection.isEmpty()); 148 | assertEquals(0, collection.getTotalProbability()); 149 | 150 | collection.add("tmp", 1); 151 | 152 | assertEquals(1, collection.size()); 153 | assertFalse(collection.isEmpty()); 154 | assertEquals(1, collection.getTotalProbability()); 155 | 156 | collection.clear(); 157 | 158 | assertEquals(0, collection.size()); 159 | assertTrue(collection.isEmpty()); 160 | assertEquals(0, collection.getTotalProbability()); 161 | 162 | String t1 = "Hello"; 163 | String t2 = "World"; 164 | String t3 = "!"; 165 | 166 | for(int i = 0; i < 10; i++) { 167 | collection.add(t1, 10); 168 | } 169 | 170 | for(int i = 0; i < 10; i++) { 171 | collection.add(t2, 10); 172 | } 173 | 174 | for(int i = 0; i < 10; i++) { 175 | collection.add(t3, 10); 176 | } 177 | 178 | assertEquals(30, collection.size()); 179 | assertFalse(collection.isEmpty()); 180 | assertEquals(300, collection.getTotalProbability()); 181 | 182 | collection.clear(); 183 | 184 | assertEquals(0, collection.getTotalProbability()); 185 | assertEquals(0, collection.size()); 186 | assertTrue(collection.isEmpty());
187 | } 188 | 189 | @RepeatedTest(1_000_000) 190 | public void test_probability() { 191 | ProbabilityCollection collection = new ProbabilityCollection<>(); 192 | 193 | assertEquals(0, collection.size()); 194 | assertTrue(collection.isEmpty()); 195 | assertEquals(0, collection.getTotalProbability()); 196 | 197 | collection.add("A", 50); 198 | collection.add("B", 25); 199 | collection.add("C", 10); 200 | 201 | int a = 0, b = 0, c = 0; 202 | 203 | int totalGets = 100_000; 204 | 205 | for(int i = 0; i < totalGets; i++) { 206 | String random = collection.get(); 207 | 208 | if(random.equals("A")) a++; 209 | else if(random.equals("B")) b++; 210 | else if(random.equals("C")) c++; 211 | } 212 | 213 | double aProb = 50.0 / (double) collection.getTotalProbability() * 100; 214 | double bProb = 25.0 / (double) collection.getTotalProbability() * 100; 215 | double cProb = 10.0 / (double) collection.getTotalProbability() * 100; 216 | 217 | double aResult = a / (double) totalGets * 100; 218 | double bResult = b / (double) totalGets * 100; 219 | double cResult = c / (double) totalGets * 100; 220 | 221 | double acceptableDeviation = 1; // % 222 | 223 | assertTrue(Math.abs(aProb - aResult) <= acceptableDeviation); 224 | assertTrue(Math.abs(bProb - bResult) <= acceptableDeviation); 225 | assertTrue(Math.abs(cProb - cResult) <= acceptableDeviation); 226 | } 227 | 228 | @RepeatedTest(1_000_000) 229 | public void test_get_never_null() { 230 | ProbabilityCollection collection = new ProbabilityCollection<>(); 231 | // Tests get will never return null 232 | // Just one smallest element get, must not return null 233 | collection.add("A", 1); 234 | assertNotNull(collection.get()); 235 | 236 | // Reset state 237 | collection.remove("A"); 238 | assertEquals(0, collection.size()); 239 | assertTrue(collection.isEmpty()); 240 | 241 | // Just one large element, must not return null 242 | collection.add("A", 5_000_000); 243 | assertNotNull(collection.get()); 244 | } 245 | 246 | @Test 247 | public
void test_Errors() { 248 | ProbabilityCollection collection = new ProbabilityCollection<>(); 249 | 250 | assertEquals(0, collection.size()); 251 | assertTrue(collection.isEmpty()); 252 | assertEquals(0, collection.getTotalProbability()); 253 | 254 | // Cannot get from empty collection 255 | assertThrows(IllegalStateException.class, () -> { 256 | collection.get(); 257 | }); 258 | 259 | assertEquals(0, collection.size()); 260 | assertTrue(collection.isEmpty()); 261 | assertEquals(0, collection.getTotalProbability()); 262 | 263 | // Cannot add null object 264 | assertThrows(IllegalArgumentException.class, () -> { 265 | collection.add(null, 1); 266 | }); 267 | 268 | assertEquals(0, collection.size()); 269 | assertTrue(collection.isEmpty()); 270 | assertEquals(0, collection.getTotalProbability()); 271 | 272 | // Cannot add prob 0 273 | assertThrows(IllegalArgumentException.class, () -> { 274 | collection.add("A", 0); 275 | }); 276 | 277 | assertEquals(0, collection.size()); 278 | assertTrue(collection.isEmpty()); 279 | assertEquals(0, collection.getTotalProbability()); 280 | 281 | // Cannot remove null 282 | assertThrows(IllegalArgumentException.class, () -> { 283 | collection.remove(null); 284 | }); 285 | 286 | assertEquals(0, collection.size()); 287 | assertTrue(collection.isEmpty()); 288 | assertEquals(0, collection.getTotalProbability()); 289 | 290 | // Cannot contains null 291 | assertThrows(IllegalArgumentException.class, () -> { 292 | collection.contains(null); 293 | }); 294 | 295 | assertEquals(0, collection.size()); 296 | assertTrue(collection.isEmpty()); 297 | assertEquals(0, collection.getTotalProbability()); 298 | } 299 | } 300 | -------------------------------------------------------------------------------- /src/test/java/com/lewdev/probabilitylib/ProbabilityMap.java: -------------------------------------------------------------------------------- 1 | package com.lewdev.probabilitylib; 2 | 3 | import java.util.LinkedHashMap; 4 | import java.util.Map; 5 |
import java.util.Map.Entry; 6 | import java.util.concurrent.ThreadLocalRandom; 7 | import java.util.stream.Collectors; 8 | 9 | /** 10 | * ProbabilityMap to easily handle probability
11 | *
12 | * 13 | * Selection Algorithm Implementation: 14 | *

15 | *

    16 | *
  • Elements have a "box" of space, sized based on their probability share 17 | *
  • "Boxes" start from index 1 and end at the total probability of elements 18 | *
  • A random number is selected between 1 and the total probability 19 | *
  • Which "box" the random number falls in is the element that is selected 20 | *
  • Therefore "boxes" with larger probability have a greater chance of being 21 | * selected than those with smaller probability. 22 | *

    23 | *
24 | * 25 | * @param Type of elements 26 | * @version 0.5 27 | * 28 | * @author Lewys Davies 29 | */ 30 | public class ProbabilityMap { 31 | // Element -> probability share; rebuilt in descending-share order by updateState() 32 | private LinkedHashMap map = new LinkedHashMap<>(); 33 | // Cached sum of all shares, refreshed by updateState() 34 | private int totalProbability = 0; 35 | 36 | /** 37 | * Construct an empty probability map 38 | */ 39 | public ProbabilityMap() { } 40 | 41 | /** 42 | * Construct a probability map with initial elements 43 | * 44 | * @param elements Initial element-to-probability mappings, copied in via addAll 45 | */ 46 | public ProbabilityMap(Map elements) { 47 | this.addAll(elements); 48 | } 49 | 50 | /** 51 | * Add an element to the map. An element that is already present has its probability replaced, not summed 52 | * 53 | * @param element Key to insert 54 | * @param probability x > 0 (NOTE(review): not validated here; a non-positive share produces an empty or inverted box in get()) 55 | * @return always true */ 56 | public final boolean add(E element, int probability) { 57 | this.map.put(element, probability); 58 | this.updateState(); 59 | return true; 60 | } 61 | 62 | /** 63 | * Add all elements from a different map to this one 64 | * 65 | * @param elements Mappings to merge; existing keys take the incoming probability (Map.putAll semantics) 66 | */ 67 | public final void addAll(Map elements) { 68 | this.map.putAll(elements); 69 | this.updateState(); 70 | } 71 | 72 | /** 73 | * Get a random element from the map 74 | * 75 | * @return Random element based on probability | null if map is empty 76 | */ 77 | public final E get() { 78 | if (this.map.isEmpty()) 79 | return null; 80 | 81 | // Map is sorted when changed 82 | // Therefore probability of elements is already in descending order: i.e.
5, 5, 83 | // 4, 1, 1 84 | 85 | // Random int between 1 and total probability (+1 as nextInt bound is exclusive) 86 | int randomProb = ThreadLocalRandom.current().nextInt(1, this.totalProbability + 1); 87 | 88 | int cumulativeProb = 0; 89 | E selectedElm = null; 90 | 91 | for (Entry entry : this.map.entrySet()) { 92 | // Calculate the size of this elements box: i.e 1-5, 6-10, 11-14, 15, 16 93 | int boxStart = cumulativeProb + 1; 94 | int boxEnd = boxStart + (entry.getValue() - 1); 95 | 96 | // Check if the elements box falls within the randomly chosen index 97 | if (randomProb >= boxStart && randomProb <= boxEnd) { 98 | selectedElm = entry.getKey(); 99 | break; 100 | } 101 | 102 | // If not keep searching 103 | cumulativeProb = boxEnd; 104 | } 105 | 106 | return selectedElm; 107 | } 108 | 109 | /** 110 | * Remove an element from the map (no-op if it is absent) 111 | * 112 | * @param element Key to remove 113 | */ 114 | public final void remove(E element) { 115 | this.map.remove(element); 116 | this.updateState(); 117 | } 118 | 119 | /** 120 | * Remove all elements from the map 121 | */ 122 | public final void clear() { 123 | this.map.clear(); 124 | this.updateState(); 125 | } 126 | 127 | /** 128 | * @return Sum of all the elements' probabilities, as cached by the last updateState() 129 | */ 130 | public final int getTotalProbability() { 131 | return this.totalProbability; 132 | } 133 | // Recompute the cached total and rebuild the map as a LinkedHashMap sorted by descending probability 134 | private final void updateState() { 135 | // Update total probability cache 136 | this.totalProbability = this.map.values().stream().mapToInt(Integer::intValue).sum(); 137 | 138 | // Sort LinkedHashMap 139 | this.map = this.map.entrySet().stream().sorted(Map.Entry.comparingByValue().reversed()) 140 | .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (x, y) -> y, LinkedHashMap::new)); 141 | } 142 | } --------------------------------------------------------------------------------