Elements have a "box" of space, sized based on their probability share
17 | *
"Boxes" start from index 1 and end at the total probability of elements
18 | *
A random number is selected between 1 and the total probability
19 | *
Which "box" the random number falls in is the element that is selected
20 | *
Therefore "boxes" with larger probability have a greater chance of being
21 | * selected than those with smaller probability.
22 | *
23 | *
24 | *
25 | * @param <E> Type of elements
26 | * @version 0.5
27 | *
28 | * @author Lewys Davies
29 | */
public class ProbabilityMap<E> {

    /** Element to probability share, kept in descending order of share. */
    private LinkedHashMap<E, Integer> map = new LinkedHashMap<>();

    /** Cached sum of every probability share currently stored in the map. */
    private int totalProbability = 0;

    /**
     * Construct an empty probability map
     */
    public ProbabilityMap() { }

    /**
     * Construct a probability map with initial elements
     *
     * @param elements element to probability share; every share must be greater than 0
     * @throws IllegalArgumentException if a share is null or not greater than 0,
     *                                  or if the total would exceed {@link Integer#MAX_VALUE}
     */
    public ProbabilityMap(Map<? extends E, Integer> elements) {
        this.addAll(elements);
    }

    /**
     * Add an element to the map. Adding an element that is already present
     * replaces its previous probability share.
     *
     * @param element     element to add
     * @param probability share of the element, must be greater than 0
     * @return always true
     * @throws IllegalArgumentException if probability is not greater than 0, or if
     *                                  the total would exceed {@link Integer#MAX_VALUE}
     */
    public final boolean add(E element, int probability) {
        requirePositive(probability);

        // Validate the new total before touching the map so a rejected call
        // leaves the state untouched.
        long newTotal = (long) this.totalProbability - this.currentShare(element) + probability;
        requireRepresentable(newTotal);

        this.map.put(element, probability);
        this.totalProbability = (int) newTotal;
        this.sortByProbability();
        return true;
    }

    /**
     * Add all elements from a different map to this one. Elements that are
     * already present have their probability share replaced.
     *
     * @param elements element to probability share; every share must be greater than 0
     * @throws IllegalArgumentException if a share is null or not greater than 0,
     *                                  or if the total would exceed {@link Integer#MAX_VALUE}
     */
    public final void addAll(Map<? extends E, Integer> elements) {
        long newTotal = this.totalProbability;

        for (Entry<? extends E, Integer> entry : elements.entrySet()) {
            Integer probability = entry.getValue();
            if (probability == null) {
                throw new IllegalArgumentException("Probability must not be null");
            }
            requirePositive(probability);
            newTotal += (long) probability - this.currentShare(entry.getKey());
        }
        requireRepresentable(newTotal);

        this.map.putAll(elements);
        this.totalProbability = (int) newTotal;
        this.sortByProbability();
    }

    /**
     * Get a random element from the map
     *
     * @return Random element based on probability | null if map is empty
     */
    public final E get() {
        if (this.map.isEmpty()) {
            return null;
        }

        // Random int in [1, totalProbability]. nextInt(bound) + 1 is used instead of
        // nextInt(1, total + 1) because total + 1 overflows at Integer.MAX_VALUE.
        int roll = ThreadLocalRandom.current().nextInt(this.totalProbability) + 1;

        // Walk the "boxes" (i.e. 1-5, 6-10, 11-14, 15, 16); the first box whose
        // upper bound reaches the roll owns it.
        int boxEnd = 0;
        for (Entry<E, Integer> entry : this.map.entrySet()) {
            boxEnd += entry.getValue();
            if (roll <= boxEnd) {
                return entry.getKey();
            }
        }

        // Not reachable while the cached total matches the stored shares.
        return null;
    }

    /**
     * Remove an element from the map
     *
     * @param element element to remove; nothing happens if it is not present
     */
    public final void remove(E element) {
        Integer removedShare = this.map.remove(element);
        if (removedShare != null) {
            // Removing an entry keeps the remaining entries in descending order,
            // so only the cached total needs adjusting.
            this.totalProbability -= removedShare;
        }
    }

    /**
     * Remove all elements from the map
     */
    public final void clear() {
        this.map.clear();
        this.totalProbability = 0;
    }

    /**
     * @return Sum of all the element's probability
     */
    public final int getTotalProbability() {
        return this.totalProbability;
    }

    /** @return the share currently stored for the element, or 0 if it is absent. */
    private int currentShare(Object element) {
        Integer share = this.map.get(element);
        return share == null ? 0 : share;
    }

    /** Rebuilds the backing map in descending order of probability share. */
    private void sortByProbability() {
        this.map = this.map.entrySet().stream()
                .sorted(Map.Entry.<E, Integer>comparingByValue().reversed())
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (x, y) -> y, LinkedHashMap::new));
    }

    /** Rejects shares that would produce an empty or negative "box". */
    private static void requirePositive(int probability) {
        if (probability <= 0) {
            throw new IllegalArgumentException("Probability must be greater than 0");
        }
    }

    /** Rejects totals that do not fit the int used for the cached sum. */
    private static void requireRepresentable(long total) {
        if (total > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Total probability must not exceed " + Integer.MAX_VALUE);
        }
    }
}
--------------------------------------------------------------------------------
/src/main/java/com/lewdev/probabilitylib/ProbabilityCollection.java:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2020 Lewys Davies
3 | *
4 | * Permission is hereby granted, free of charge, to any person obtaining a copy
5 | * of this software and associated documentation files (the "Software"), to deal
6 | * in the Software without restriction, including without limitation the rights
7 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | * copies of the Software, and to permit persons to whom the Software is
9 | * furnished to do so, subject to the following conditions:
10 | *
11 | * The above copyright notice and this permission notice shall be included in all
12 | * copies or substantial portions of the Software.
13 | *
14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
20 | * SOFTWARE.
21 | */
22 | package com.lewdev.probabilitylib;
23 |
24 | import java.util.Comparator;
25 | import java.util.Iterator;
26 | import java.util.NavigableSet;
27 | import java.util.Objects;
28 | import java.util.SplittableRandom;
29 | import java.util.TreeSet;
30 |
/**
 * ProbabilityCollection for retrieving random elements based on probability.
 * <p>
 * <b>Selection Algorithm Implementation:</b>
 * <ul>
 * <li>Elements have a "block" of space, sized based on their probability share</li>
 * <li>"Blocks" start from index 1 and end at the total probability of all
 * elements</li>
 * <li>A random number is selected between 1 and the total probability</li>
 * <li>Which "block" the random number falls in is the element that is selected</li>
 * <li>Therefore "block"s with larger probability have a greater chance of being
 * selected than those with smaller probability.</li>
 * </ul>
 * <p>
 * This class performs no synchronization of its own.
 *
 * @author Lewys Davies
 * @version 0.8
 *
 * @param <E> Type of elements
 */
public final class ProbabilityCollection<E> {

    /** Entries ordered by the start index of their "block". */
    private final NavigableSet<ProbabilitySetElement<E>> collection;
    private final SplittableRandom random = new SplittableRandom();

    /** Sum of the probability shares of every entry in the collection. */
    private int totalProbability;

    /**
     * Construct a new Probability Collection
     */
    public ProbabilityCollection() {
        this.collection = new TreeSet<>(Comparator.comparingInt(ProbabilitySetElement::getIndex));
        this.totalProbability = 0;
    }

    /**
     * @return Number of objects inside the collection
     */
    public int size() {
        return this.collection.size();
    }

    /**
     * @return True if collection contains no elements, else False
     */
    public boolean isEmpty() {
        return this.collection.isEmpty();
    }

    /**
     * @param object object to look for
     * @return True if collection contains the object, else False
     * @throws IllegalArgumentException if object is null
     */
    public boolean contains(E object) {
        if (object == null) {
            throw new IllegalArgumentException("Cannot check if null object is contained in this collection");
        }

        return this.collection.stream().anyMatch(entry -> entry.getObject().equals(object));
    }

    /**
     * Returns the iterator of the backing set. Calling {@link Iterator#remove()}
     * on it bypasses this collection's bookkeeping (the total probability and the
     * block indexes are not updated); use {@link #remove(Object)} instead.
     *
     * @return Iterator over this collection
     */
    public Iterator<ProbabilitySetElement<E>> iterator() {
        return this.collection.iterator();
    }

    /**
     * Add an object to this collection. The same object may be added more than
     * once; every call creates a separate entry.
     *
     * @param object      object to add. Not null.
     * @param probability share. Must be greater than 0.
     *
     * @throws IllegalArgumentException if object is null
     * @throws IllegalArgumentException if probability <= 0
     * @throws IllegalArgumentException if the total probability would exceed
     *                                  {@link Integer#MAX_VALUE}
     */
    public void add(E object, int probability) {
        if (object == null) {
            throw new IllegalArgumentException("Cannot add null object");
        }

        if (probability <= 0) {
            throw new IllegalArgumentException("Probability must be greater than 0");
        }

        // Reject before mutating: a wrapped-around total would hand out
        // negative block indexes and break selection.
        if (probability > Integer.MAX_VALUE - this.totalProbability) {
            throw new IllegalArgumentException("Total probability must not exceed " + Integer.MAX_VALUE);
        }

        ProbabilitySetElement<E> entry = new ProbabilitySetElement<>(object, probability);
        // The new block starts right after the last existing block.
        entry.setIndex(this.totalProbability + 1);

        this.collection.add(entry);
        this.totalProbability += probability;
    }

    /**
     * Remove an object from this collection. Every entry equal to the object is
     * removed.
     *
     * @param object object to remove
     * @return True if object was removed, else False.
     *
     * @throws IllegalArgumentException if object is null
     */
    public boolean remove(E object) {
        if (object == null) {
            throw new IllegalArgumentException("Cannot remove null object");
        }

        Iterator<ProbabilitySetElement<E>> it = this.collection.iterator();
        boolean removed = false;

        // Remove all instances of the object
        while (it.hasNext()) {
            ProbabilitySetElement<E> entry = it.next();
            if (entry.getObject().equals(object)) {
                this.totalProbability -= entry.getProbability();
                it.remove();
                removed = true;
            }
        }

        // Recalculate remaining elements "block" of space: i.e 1-5, 6-10, 11-14.
        // Indexes only shrink and keep their relative order, so the ordering of
        // the backing set stays valid while they are rewritten.
        if (removed) {
            int blockEnd = 0;
            for (ProbabilitySetElement<E> entry : this.collection) {
                entry.setIndex(blockEnd + 1);
                blockEnd += entry.getProbability();
            }
        }

        return removed;
    }

    /**
     * Remove all objects from this collection
     */
    public void clear() {
        this.collection.clear();
        this.totalProbability = 0;
    }

    /**
     * Get a random object from this collection, based on probability.
     *
     * @return Random object
     *
     * @throws IllegalStateException if this collection is empty
     */
    public E get() {
        if (this.isEmpty()) {
            throw new IllegalStateException("Cannot get an object out of a empty collection");
        }

        // Probe entry carrying only a random index in [1, totalProbability].
        // nextInt(bound) + 1 avoids computing total + 1, which overflows when the
        // total is Integer.MAX_VALUE.
        ProbabilitySetElement<E> toFind = new ProbabilitySetElement<>(null, 0);
        toFind.setIndex(this.random.nextInt(this.totalProbability) + 1);

        // floor() yields the entry whose block starts at or before the index.
        return Objects.requireNonNull(this.collection.floor(toFind).getObject());
    }

    /**
     * @return Sum of all element's probability
     */
    public int getTotalProbability() {
        return this.totalProbability;
    }

    /**
     * Used internally to store information about a object's state in a collection.
     * Specifically, the probability and index within the collection.
     *
     * Indexes refer to the start position of this element's "block" of space. The
     * space between element "block"s represents their probability of being selected
     *
     * @author Lewys Davies
     *
     * @param <T> Type of element
     */
    public final static class ProbabilitySetElement<T> {
        private final T object;
        private final int probability;
        private int index;

        /**
         * @param object      the stored object
         * @param probability share within the collection
         */
        protected ProbabilitySetElement(T object, int probability) {
            this.object = object;
            this.probability = probability;
        }

        /**
         * @return The actual object
         */
        public T getObject() {
            return this.object;
        }

        /**
         * @return Probability share in this collection
         */
        public int getProbability() {
            return this.probability;
        }

        // Used internally, see this class's documentation
        private int getIndex() {
            return this.index;
        }

        // Used internally, see this class's documentation
        private void setIndex(int index) {
            this.index = index;
        }
    }
}
248 |
--------------------------------------------------------------------------------
/src/test/java/com/lewdev/probabilitylib/ProbabilityCollectionTest.java:
--------------------------------------------------------------------------------
1 | package com.lewdev.probabilitylib;
2 |
3 | import static org.junit.jupiter.api.Assertions.*;
4 |
5 | import org.junit.jupiter.api.RepeatedTest;
6 | import org.junit.jupiter.api.Test;
7 |
8 | /**
9 | * @author Lewys Davies
10 | */
11 | public class ProbabilityCollectionTest {
12 |
13 | @RepeatedTest(value = 10_000)
14 | public void test_insert() {
15 | ProbabilityCollection collection = new ProbabilityCollection<>();
16 | assertEquals(0, collection.size());
17 | assertTrue(collection.isEmpty());
18 | assertEquals(0, collection.getTotalProbability());
19 |
20 | collection.add("A", 2);
21 | assertTrue(collection.contains("A"));
22 | assertEquals(1, collection.size());
23 | assertFalse(collection.isEmpty());
24 | assertEquals(2, collection.getTotalProbability()); // 2
25 |
26 | collection.add("B", 5);
27 | assertTrue(collection.contains("B"));
28 | assertEquals(2, collection.size());
29 | assertFalse(collection.isEmpty());
30 | assertEquals(7, collection.getTotalProbability()); // 5 + 2
31 |
32 | collection.add("C", 10);
33 | assertTrue(collection.contains("C"));
34 | assertEquals(3, collection.size());
35 | assertFalse(collection.isEmpty());
36 | assertEquals(17, collection.getTotalProbability()); // 5 + 2 + 10
37 |
38 | for(int i = 0; i < 100; i++) {
39 | collection.add("C", 1);
40 |
41 | assertTrue(collection.contains("C"));
42 | assertEquals(4 + i, collection.size()); // 4 + i
43 | assertFalse(collection.isEmpty());
44 | assertEquals(18 + i, collection.getTotalProbability()); // 5 + 2 + 10 + i
45 | }
46 | }
47 |
48 | @RepeatedTest(value = 10_000)
49 | public void test_remove() {
50 | ProbabilityCollection collection = new ProbabilityCollection<>();
51 | assertEquals(0, collection.size());
52 | assertTrue(collection.isEmpty());
53 | assertEquals(0, collection.getTotalProbability());
54 |
55 | String t1 = "Hello";
56 | String t2 = "World";
57 | String t3 = "!";
58 |
59 | collection.add(t1, 10);
60 | collection.add(t2, 10);
61 | collection.add(t3, 10);
62 |
63 | assertEquals(3, collection.size());
64 | assertFalse(collection.isEmpty());
65 | assertEquals(30, collection.getTotalProbability());
66 |
67 | // Remove t2
68 | assertTrue(collection.remove(t2));
69 |
70 | assertEquals(2, collection.size());
71 | assertFalse(collection.isEmpty());
72 | assertEquals(20, collection.getTotalProbability());
73 |
74 | // Remove t1
75 | assertTrue(collection.remove(t1));
76 |
77 | assertEquals(1, collection.size());
78 | assertFalse(collection.isEmpty());
79 | assertEquals(10, collection.getTotalProbability());
80 |
81 | //Remove t3
82 | assertTrue(collection.remove(t3));
83 |
84 | assertEquals(0, collection.size());
85 | assertTrue(collection.isEmpty());
86 | assertEquals(0, collection.getTotalProbability());
87 | }
88 |
89 | @RepeatedTest(value = 10_000)
90 | public void test_remove_duplicates() {
91 | ProbabilityCollection collection = new ProbabilityCollection<>();
92 | assertEquals(0, collection.size());
93 | assertTrue(collection.isEmpty());
94 |
95 | String t1 = "Hello";
96 | String t2 = "World";
97 | String t3 = "!";
98 |
99 | for(int i = 0; i < 10; i++) {
100 | collection.add(t1, 10);
101 | }
102 |
103 | for(int i = 0; i < 10; i++) {
104 | collection.add(t2, 10);
105 | }
106 |
107 | for(int i = 0; i < 10; i++) {
108 | collection.add(t3, 10);
109 | }
110 |
111 | assertEquals(30, collection.size());
112 | assertFalse(collection.isEmpty());
113 | assertEquals(300, collection.getTotalProbability());
114 |
115 | //Remove t2
116 | assertTrue(collection.remove(t2));
117 |
118 | assertEquals(20, collection.size());
119 | assertFalse(collection.isEmpty());
120 | assertEquals(200, collection.getTotalProbability());
121 |
122 | // Remove t1
123 | assertTrue(collection.remove(t1));
124 |
125 | assertEquals(10, collection.size());
126 | assertFalse(collection.isEmpty());
127 | assertEquals(100, collection.getTotalProbability());
128 |
129 | //Remove t3
130 | assertTrue(collection.remove(t3));
131 |
132 | assertEquals(0, collection.size());
133 | assertTrue(collection.isEmpty());
134 | assertEquals(0, collection.getTotalProbability());
135 | }
136 |
137 | @RepeatedTest(value = 10_000)
138 | public void test_clear() {
139 | ProbabilityCollection collection = new ProbabilityCollection<>();
140 | assertEquals(0, collection.size());
141 | assertTrue(collection.isEmpty());
142 | assertEquals(0, collection.getTotalProbability());
143 |
144 | collection.clear();
145 |
146 | assertEquals(0, collection.size());
147 | assertTrue(collection.isEmpty());
148 | assertEquals(0, collection.getTotalProbability());
149 |
150 | collection.add("tmp", 1);
151 |
152 | assertEquals(1, collection.size());
153 | assertFalse(collection.isEmpty());
154 | assertEquals(1, collection.getTotalProbability());
155 |
156 | collection.clear();
157 |
158 | assertEquals(0, collection.size());
159 | assertTrue(collection.isEmpty());
160 | assertEquals(0, collection.getTotalProbability());
161 |
162 | String t1 = "Hello";
163 | String t2 = "World";
164 | String t3 = "!";
165 |
166 | for(int i = 0; i < 10; i++) {
167 | collection.add(t1, 10);
168 | }
169 |
170 | for(int i = 0; i < 10; i++) {
171 | collection.add(t2, 10);
172 | }
173 |
174 | for(int i = 0; i < 10; i++) {
175 | collection.add(t3, 10);
176 | }
177 |
178 | assertEquals(30, collection.size());
179 | assertFalse(collection.isEmpty());
180 | assertEquals(300, collection.getTotalProbability());
181 |
182 | collection.clear();
183 |
184 | assertEquals(0, collection.getTotalProbability());
185 | assertEquals(0, collection.size());
186 | assertTrue(collection.isEmpty());
187 | }
188 |
189 | @RepeatedTest(1_000_000)
190 | public void test_probability() {
191 | ProbabilityCollection collection = new ProbabilityCollection<>();
192 |
193 | assertEquals(0, collection.size());
194 | assertTrue(collection.isEmpty());
195 | assertEquals(0, collection.getTotalProbability());
196 |
197 | collection.add("A", 50);
198 | collection.add("B", 25);
199 | collection.add("C", 10);
200 |
201 | int a = 0, b = 0, c = 0;
202 |
203 | int totalGets = 100_000;
204 |
205 | for(int i = 0; i < totalGets; i++) {
206 | String random = collection.get();
207 |
208 | if(random.equals("A")) a++;
209 | else if(random.equals("B")) b++;
210 | else if(random.equals("C")) c++;
211 | }
212 |
213 | double aProb = 50.0 / (double) collection.getTotalProbability() * 100;
214 | double bProb = 25.0 / (double) collection.getTotalProbability() * 100;
215 | double cProb = 10.0 / (double) collection.getTotalProbability() * 100;
216 |
217 | double aResult = a / (double) totalGets * 100;
218 | double bResult = b / (double) totalGets * 100;
219 | double cResult = c / (double) totalGets * 100;
220 |
221 | double acceptableDeviation = 1; // %
222 |
223 | assertTrue(Math.abs(aProb - aResult) <= acceptableDeviation);
224 | assertTrue(Math.abs(bProb - bResult) <= acceptableDeviation);
225 | assertTrue(Math.abs(cProb - cResult) <= acceptableDeviation);
226 | }
227 |
228 | @RepeatedTest(1_000_000)
229 | public void test_get_never_null() {
230 | ProbabilityCollection collection = new ProbabilityCollection<>();
231 | // Tests get will never return null
232 | // Just one smallest element get, must not return null
233 | collection.add("A", 1);
234 | assertNotNull(collection.get());
235 |
236 | // Reset state
237 | collection.remove("A");
238 | assertEquals(0, collection.size());
239 | assertTrue(collection.isEmpty());
240 |
241 | // Just one large element, must not return null
242 | collection.add("A", 5_000_000);
243 | assertNotNull(collection.get());
244 | }
245 |
246 | @Test
247 | public void test_Errors() {
248 | ProbabilityCollection collection = new ProbabilityCollection<>();
249 |
250 | assertEquals(0, collection.size());
251 | assertTrue(collection.isEmpty());
252 | assertEquals(0, collection.getTotalProbability());
253 |
254 | // Cannot get from empty collection
255 | assertThrows(IllegalStateException.class, () -> {
256 | collection.get();
257 | });
258 |
259 | assertEquals(0, collection.size());
260 | assertTrue(collection.isEmpty());
261 | assertEquals(0, collection.getTotalProbability());
262 |
263 | // Cannot add null object
264 | assertThrows(IllegalArgumentException.class, () -> {
265 | collection.add(null, 1);
266 | });
267 |
268 | assertEquals(0, collection.size());
269 | assertTrue(collection.isEmpty());
270 | assertEquals(0, collection.getTotalProbability());
271 |
272 | // Cannot add prob 0
273 | assertThrows(IllegalArgumentException.class, () -> {
274 | collection.add("A", 0);
275 | });
276 |
277 | assertEquals(0, collection.size());
278 | assertTrue(collection.isEmpty());
279 | assertEquals(0, collection.getTotalProbability());
280 |
281 | // Cannot remove null
282 | assertThrows(IllegalArgumentException.class, () -> {
283 | collection.remove(null);
284 | });
285 |
286 | assertEquals(0, collection.size());
287 | assertTrue(collection.isEmpty());
288 | assertEquals(0, collection.getTotalProbability());
289 |
290 | // Cannot contains null
291 | assertThrows(IllegalArgumentException.class, () -> {
292 | collection.contains(null);
293 | });
294 |
295 | assertEquals(0, collection.size());
296 | assertTrue(collection.isEmpty());
297 | assertEquals(0, collection.getTotalProbability());
298 | }
299 | }
300 |
--------------------------------------------------------------------------------