├── .gitignore
├── .travis.yml
├── README.md
├── example
│   └── RunnableExample.java
├── forgetful-learning.pdf
├── pom.xml
└── src
    ├── main
    │   └── java
    │       └── de
    │           └── daslaboratorium
    │               └── machinelearning
    │                   └── classifier
    │                       ├── Classification.java
    │                       ├── Classifier.java
    │                       ├── IFeatureProbability.java
    │                       └── bayes
    │                           └── BayesClassifier.java
    └── test
        └── java
            └── de
                └── daslaboratorium
                    └── machinelearning
                        └── classifier
                            └── bayes
                                └── BayesClassifierTest.java

/.gitignore:
--------------------------------------------------------------------------------
1 | # Binary files
2 | *.class
3 | 
4 | # Maven files
5 | target/
6 | 
7 | # Eclipse files
8 | .classpath
9 | .project
10 | 
11 | # IntelliJ files
12 | .idea/
13 | *.iml
14 | 
15 | *.prefs
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: java
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Java Naive Bayes Classifier
2 | ==================
3 | 
4 | [![Build Status](https://travis-ci.org/ptnplanet/Java-Naive-Bayes-Classifier.svg?branch=master)](https://travis-ci.org/ptnplanet/Java-Naive-Bayes-Classifier)
5 | [![](https://jitpack.io/v/ptnplanet/Java-Naive-Bayes-Classifier.svg)](https://jitpack.io/#ptnplanet/Java-Naive-Bayes-Classifier)
6 | 
7 | Nothing special. It works and is well documented, so you should be able to get it running without wasting too much time searching for alternatives on the net.
8 | 
9 | Maven Quick-Start
10 | ------------------
11 | 
12 | This Java Naive Bayes Classifier can be installed via the jitpack repository. Make sure to add the repository to your build file first.
13 | 
14 | ```xml
15 | <repositories>
16 |   <repository>
17 |     <id>jitpack.io</id>
18 |     <url>https://jitpack.io</url>
19 |   </repository>
20 | </repositories>
21 | ```
22 | 
23 | Then, treat it as any other dependency.
24 | 
25 | ```xml
26 | <dependency>
27 |   <groupId>com.github.ptnplanet</groupId>
28 |   <artifactId>Java-Naive-Bayes-Classifier</artifactId>
29 |   <version>1.0.7</version>
30 | </dependency>
31 | ```
32 | 
33 | For other build tools (e.g. Gradle), visit https://jitpack.io for configuration snippets.
34 | 
35 | Please also check the releases tab for further releases.
36 | 
37 | Overview
38 | ------------------
39 | 
40 | I like talking about *features* and *categories*. Objects have features and may belong to a category. The classifier will try matching objects to their categories by looking at the objects' features. It does so by consulting its memory, filled with knowledge gathered from training examples.
41 | 
42 | Classifying a feature set selects the category with the highest product of 1) the probability of that category occurring and 2) the product of the probabilities of all features occurring in that category:
43 | 
44 | ```classify(feature1, ..., featureN) = argmax(P(category) * PROD(P(feature|category)))```
45 | 
46 | This is a so-called maximum a posteriori estimation. Wikipedia actually does a good job explaining it: http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Probabilistic_model
47 | 
48 | Learning from Examples
49 | ------------------
50 | 
51 | Add knowledge by telling the classifier that these features belong to a specific category:
52 | 
53 | ```java
54 | String[] positiveText = "I love sunny days".split("\\s");
55 | bayes.learn("positive", Arrays.asList(positiveText));
56 | ```
57 | 
58 | Classify unknown objects
59 | ------------------
60 | 
61 | Use the gathered knowledge to classify unknown objects by their features. The classifier will return the category that the object most likely belongs to.
62 | 
63 | ```java
64 | String[] unknownText1 = "today is a sunny day".split("\\s");
65 | bayes.classify(Arrays.asList(unknownText1)).getCategory();
66 | ```
67 | 
68 | Example
69 | ------------------
70 | 
71 | Here is an excerpt from the example. The classifier will classify sentences (arrays of features) as having either positive or negative sentiment. Please refer to the full example for more detailed documentation.
72 | 
73 | ```java
74 | // Create a new bayes classifier with string categories and string features.
75 | Classifier<String, String> bayes = new BayesClassifier<String, String>();
76 | 
77 | // Two examples to learn from.
78 | String[] positiveText = "I love sunny days".split("\\s");
79 | String[] negativeText = "I hate rain".split("\\s");
80 | 
81 | // Learn by classifying examples.
82 | // New categories can be added on the fly, when they are first used.
83 | // A classification consists of a category and a list of features
84 | // that resulted in the classification in that category.
85 | bayes.learn("positive", Arrays.asList(positiveText));
86 | bayes.learn("negative", Arrays.asList(negativeText));
87 | 
88 | // Here are two unknown sentences to classify.
89 | String[] unknownText1 = "today is a sunny day".split("\\s");
90 | String[] unknownText2 = "there will be rain".split("\\s");
91 | 
92 | System.out.println( // will output "positive"
93 |     bayes.classify(Arrays.asList(unknownText1)).getCategory());
94 | System.out.println( // will output "negative"
95 |     bayes.classify(Arrays.asList(unknownText2)).getCategory());
96 | 
97 | // Get a more detailed classification result.
98 | ((BayesClassifier<String, String>) bayes).classifyDetailed(
99 |     Arrays.asList(unknownText1));
100 | 
101 | // Change the memory capacity. New learned classifications (using
102 | // the learn method) are stored in a queue with the size given
103 | // here and used to classify unknown sentences.
104 | bayes.setMemoryCapacity(500);
105 | ```
106 | 
107 | Forgetful learning
108 | ------------------
109 | 
110 | This classifier is forgetful: it will forget old classifications after a given number of newly learned classifications (defaulting to 1,000). This ensures that the classifier can react to ongoing changes in the user's habits.
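The forgetting behaviour described above can be sketched independently of the library's classes. The names below are illustrative, not part of the library's API: learned examples enter a bounded queue, and once the capacity is exceeded the oldest example is un-counted again.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;

// Illustrative sketch only -- class and method names are not the library's API.
// Learned examples enter a bounded queue; once the capacity is exceeded, the
// oldest example is removed and its category count is decremented again,
// mirroring the library's learn/forget cycle.
class ForgetfulMemorySketch {

    private final Deque<String> memory = new ArrayDeque<>();
    private final Map<String, Integer> categoryCount = new HashMap<>();
    private int capacity = 1000; // the library's default memory capacity

    void setCapacity(int capacity) {
        this.capacity = capacity;
    }

    void learn(String category) {
        memory.addLast(category);
        categoryCount.merge(category, 1, Integer::sum);
        while (memory.size() > capacity) {
            String forgotten = memory.removeFirst();
            // Drop the mapping entirely when the count reaches zero.
            categoryCount.computeIfPresent(forgotten, (k, v) -> v == 1 ? null : v - 1);
        }
    }

    int count(String category) {
        return categoryCount.getOrDefault(category, 0);
    }
}
```

With a capacity of 2, learning a third example forgets the first, so the first category's count drops back down.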
111 | 
112 | 
113 | Interface
114 | ------------------
115 | The abstract ```Classifier``` serves as a base for the concrete ```BayesClassifier```. Here are its methods. Please also refer to the Javadoc.
116 | 
117 | * ```void reset()``` Resets the learned feature and category counts.
118 | * ```Set<T> getFeatures()``` Returns a ```Set``` of features the classifier knows about.
119 | * ```Set<K> getCategories()``` Returns a ```Set``` of categories the classifier knows about.
120 | * ```int getCategoriesTotal()``` Retrieves the total number of categories the classifier knows about.
121 | * ```int getMemoryCapacity()``` Retrieves the memory's capacity.
122 | * ```void setMemoryCapacity(int memoryCapacity)``` Sets the memory's capacity. If the new value is less than the old value, the memory will be truncated accordingly.
123 | * ```void incrementFeature(T feature, K category)``` Increments the count of a given feature in the given category. This is equal to telling the classifier that this feature has occurred in this category.
124 | * ```void incrementCategory(K category)``` Increments the count of a given category. This is equal to telling the classifier that this category has occurred once more.
125 | * ```void decrementFeature(T feature, K category)``` Decrements the count of a given feature in the given category. This is equal to telling the classifier that this feature has occurred once less in this category.
126 | * ```void decrementCategory(K category)``` Decrements the count of a given category. This is equal to telling the classifier that this category has occurred once less.
127 | * ```int getFeatureCount(T feature, K category)``` Retrieves the number of occurrences of the given feature in the given category.
128 | * ```int getFeatureCount(T feature)``` Retrieves the total number of occurrences of the given feature.
129 | * ```int getCategoryCount(K category)``` Retrieves the number of occurrences of the given category.
130 | * ```float featureProbability(T feature, K category)``` (*implements* ```IFeatureProbability.featureProbability```) Returns the probability that the given feature occurs in the given category.
131 | * ```float featureWeighedAverage(T feature, K category)``` Retrieves the weighed average ```P(feature|category)``` with an overall weight of ```1.0``` and an assumed probability of ```0.5```. The probability defaults to the overall feature probability.
132 | * ```float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator)``` Retrieves the weighed average ```P(feature|category)``` with an overall weight of ```1.0```, an assumed probability of ```0.5``` and the given object to use for the probability calculation.
133 | * ```float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator, float weight)``` Retrieves the weighed average ```P(feature|category)``` with the given weight, an assumed probability of ```0.5``` and the given object to use for the probability calculation.
134 | * ```float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator, float weight, float assumedProbability)``` Retrieves the weighed average ```P(feature|category)``` with the given weight, the given assumed probability and the given object to use for the probability calculation.
135 | * ```void learn(K category, Collection<T> features)``` Trains the classifier by telling it that the given features resulted in the given category.
136 | * ```void learn(Classification<T, K> classification)``` Trains the classifier by telling it that the given classification's features resulted in its category.
137 | 
138 | The ```BayesClassifier``` class implements the following abstract method:
139 | 
140 | * ```Classification<T, K> classify(Collection<T> features)``` Retrieves the most likely category for the given features; its behaviour depends on the concrete classifier implementation.
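The smoothing performed by the ```featureWeighedAverage``` family can be illustrated standalone. The library computes ```(weight * assumedProbability + totalFeatureCount * basicProbability) / (weight + totalFeatureCount)```; the helper below is a sketch of that arithmetic, not the library class itself.

```java
// Standalone illustration of the weighed-average smoothing used by
// featureWeighedAverage -- a sketch of the arithmetic, not the library class.
class WeighedAverageSketch {

    // totalFeatureCount: how often the feature has been seen overall.
    // basicProbability:  the raw P(feature|category).
    static float featureWeighedAverage(int totalFeatureCount, float basicProbability,
                                       float weight, float assumedProbability) {
        // With no observations the result is the assumed probability; with
        // many observations it converges towards the raw probability.
        return (weight * assumedProbability + totalFeatureCount * basicProbability)
                / (weight + totalFeatureCount);
    }
}
```

An unseen feature therefore yields the assumed probability of ```0.5``` rather than a hard zero, which keeps a single unknown feature from zeroing out the whole product.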
141 | 
142 | Running the example
143 | ------------------
144 | 
145 | ```shell
146 | $ git clone https://github.com/ptnplanet/Java-Naive-Bayes-Classifier.git
147 | $ cd Java-Naive-Bayes-Classifier
148 | $ javac -cp src/main/java example/RunnableExample.java
149 | $ java -cp example:src/main/java RunnableExample
150 | ```
151 | 
152 | Possible performance issues
153 | ------------------
154 | 
155 | Performance improvements I am currently considering:
156 | 
157 | - Store the natural logarithms of the feature probabilities and add them together instead of multiplying the probability values
158 | 
159 | The MIT License (MIT)
160 | ------------------
161 | 
162 | Copyright (c) 2012-2017 Philipp Nolte
163 | 
164 | Permission is hereby granted, free of charge, to any person obtaining a copy
165 | of this software and associated documentation files (the "Software"), to deal
166 | in the Software without restriction, including without limitation the rights
167 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
168 | copies of the Software, and to permit persons to whom the Software is
169 | furnished to do so, subject to the following conditions:
170 | 
171 | The above copyright notice and this permission notice shall be included in
172 | all copies or substantial portions of the Software.
173 | 
174 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
175 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
176 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
177 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
178 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
179 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
180 | THE SOFTWARE.
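The log-probability idea listed under possible performance issues can be sketched independently of the library: multiplying many small per-feature probabilities underflows to zero in floating point, while summing their natural logarithms stays representable, and because the logarithm is monotonic the argmax over categories is unchanged.

```java
// Independent sketch of the log-probability trick: the product of many small
// probabilities underflows to 0.0, while the sum of their natural logarithms
// remains representable. Comparing log scores preserves the argmax.
class LogProbSketch {

    static double product(double[] probabilities) {
        double p = 1.0;
        for (double prob : probabilities) {
            p *= prob;
        }
        return p;
    }

    static double logSum(double[] probabilities) {
        double logP = 0.0;
        for (double prob : probabilities) {
            logP += Math.log(prob);
        }
        return logP;
    }
}
```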
181 | 
--------------------------------------------------------------------------------
/example/RunnableExample.java:
--------------------------------------------------------------------------------
1 | import java.util.Arrays;
2 | 
3 | import de.daslaboratorium.machinelearning.classifier.bayes.BayesClassifier;
4 | import de.daslaboratorium.machinelearning.classifier.Classifier;
5 | 
6 | public class RunnableExample {
7 | 
8 |     public static void main(String[] args) {
9 | 
10 |         /*
11 |          * Create a new classifier instance. The context features are
12 |          * Strings and the context will be classified with a String according
13 |          * to the featureset of the context.
14 |          */
15 |         final Classifier<String, String> bayes =
16 |                 new BayesClassifier<String, String>();
17 | 
18 |         /*
19 |          * The classifier can learn from classifications that are handed over
20 |          * to the learn methods. Imagine a tokenized text as follows. The tokens
21 |          * are the text's features. The category of the text will either be
22 |          * positive or negative.
23 |          */
24 |         final String[] positiveText = "I love sunny days".split("\\s");
25 |         bayes.learn("positive", Arrays.asList(positiveText));
26 | 
27 |         final String[] negativeText = "I hate rain".split("\\s");
28 |         bayes.learn("negative", Arrays.asList(negativeText));
29 | 
30 |         /*
31 |          * Now that the classifier has "learned" two classifications, it will
32 |          * be able to classify similar sentences. The classify method returns
33 |          * a Classification object that contains the given featureset,
34 |          * classification probability and resulting category.
35 |          */
36 |         final String[] unknownText1 = "today is a sunny day".split("\\s");
37 |         final String[] unknownText2 = "there will be rain".split("\\s");
38 | 
39 |         System.out.println( // will output "positive"
40 |                 bayes.classify(Arrays.asList(unknownText1)).getCategory());
41 |         System.out.println( // will output "negative"
42 |                 bayes.classify(Arrays.asList(unknownText2)).getCategory());
43 | 
44 |         /*
45 |          * The BayesClassifier extends the abstract Classifier and provides
46 |          * detailed classification results that can be retrieved by calling
47 |          * the classifyDetailed method.
48 |          *
49 |          * The classification with the highest probability is the resulting
50 |          * classification. The returned List will look like this:
51 |          * [
52 |          *   Classification [
53 |          *     category=negative,
54 |          *     probability=0.0078125,
55 |          *     featureset=[today, is, a, sunny, day]
56 |          *   ],
57 |          *   Classification [
58 |          *     category=positive,
59 |          *     probability=0.0234375,
60 |          *     featureset=[today, is, a, sunny, day]
61 |          *   ]
62 |          * ]
63 |          */
64 |         ((BayesClassifier<String, String>) bayes).classifyDetailed(
65 |                 Arrays.asList(unknownText1));
66 | 
67 |         /*
68 |          * Please note that this particular classifier implementation will
69 |          * "forget" learned classifications after a few learning sessions. The
70 |          * number of learning sessions it will record can be set as follows:
71 |          */
72 |         bayes.setMemoryCapacity(500); // remember the last 500 learned classifications
73 |     }
74 | 
75 | }
--------------------------------------------------------------------------------
/forgetful-learning.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptnplanet/Java-Naive-Bayes-Classifier/3f8d874c98baacbd3f4accae71e7179ab426dfe2/forgetful-learning.pdf
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |     <modelVersion>4.0.0</modelVersion>
6 | 
7 |     <groupId>com.github.ptnplanet</groupId>
8 |     <artifactId>Java-Naive-Bayes-Classifier</artifactId>
9 |     <version>1.0.7</version>
10 | 
11 |     <dependencies>
12 |         <dependency>
13 |             <groupId>junit</groupId>
14 |             <artifactId>junit</artifactId>
15 |             <version>4.12</version>
16 |             <scope>test</scope>
17 |         </dependency>
18 |     </dependencies>
19 | 
20 | </project>
--------------------------------------------------------------------------------
/src/main/java/de/daslaboratorium/machinelearning/classifier/Classification.java:
--------------------------------------------------------------------------------
1 | package de.daslaboratorium.machinelearning.classifier;
2 | 
3 | import java.io.Serializable;
4 | import java.util.Collection;
5 | 
6 | /**
7 |  * A basic wrapper reflecting a classification. It will store both featureset
8 |  * and resulting classification.
9 |  *
10 |  * @author Philipp Nolte
11 |  *
12 |  * @param <T>
13 |  *            The feature class.
14 |  * @param <K>
15 |  *            The category class.
16 |  */
17 | public class Classification<T, K> implements Serializable {
18 | 
19 |     /**
20 |      * Generated Serial Version UID (generated for v1.0.7).
21 |      */
22 |     private static final long serialVersionUID = -1210981535415341283L;
23 | 
24 |     /**
25 |      * The classified featureset.
26 |      */
27 |     private Collection<T> featureset;
28 | 
29 |     /**
30 |      * The category as which the featureset was classified.
31 |      */
32 |     private K category;
33 | 
34 |     /**
35 |      * The probability that the featureset belongs to the given category.
36 |      */
37 |     private float probability;
38 | 
39 |     /**
40 |      * Constructs a new Classification with the parameters given and a default
41 |      * probability of 1.
42 |      *
43 |      * @param featureset
44 |      *            The featureset.
45 |      * @param category
46 |      *            The category.
47 |      */
48 |     public Classification(Collection<T> featureset, K category) {
49 |         this(featureset, category, 1.0f);
50 |     }
51 | 
52 |     /**
53 |      * Constructs a new Classification with the parameters given.
54 |      *
55 |      * @param featureset
56 |      *            The featureset.
57 |      * @param category
58 |      *            The category.
59 |      * @param probability
60 |      *            The probability.
61 |      */
62 |     public Classification(Collection<T> featureset, K category, float probability) {
63 |         this.featureset = featureset;
64 |         this.category = category;
65 |         this.probability = probability;
66 |     }
67 | 
68 |     /**
69 |      * Retrieves the featureset classified.
70 |      *
71 |      * @return The featureset.
72 |      */
73 |     public Collection<T> getFeatureset() {
74 |         return featureset;
75 |     }
76 | 
77 |     /**
78 |      * Retrieves the classification's probability.
79 |      *
80 |      * @return The probability.
81 |      */
82 |     public float getProbability() {
83 |         return this.probability;
84 |     }
85 | 
86 |     /**
87 |      * Retrieves the category the featureset was classified as.
88 |      *
89 |      * @return The category.
90 |      */
91 |     public K getCategory() {
92 |         return category;
93 |     }
94 | 
95 |     /**
96 |      * {@inheritDoc}
97 |      */
98 |     @Override
99 |     public String toString() {
100 |         return "Classification [category=" + this.category + ", probability=" + this.probability + ", featureset="
101 |                 + this.featureset + "]";
102 |     }
103 | 
104 | }
--------------------------------------------------------------------------------
/src/main/java/de/daslaboratorium/machinelearning/classifier/Classifier.java:
--------------------------------------------------------------------------------
1 | package de.daslaboratorium.machinelearning.classifier;
2 | 
3 | import java.util.Collection;
4 | import java.util.Dictionary;
5 | import java.util.Enumeration;
6 | import java.util.Hashtable;
7 | import java.util.LinkedList;
8 | import java.util.Queue;
9 | import java.util.Set;
10 | 
11 | /**
12 |  * Abstract base extended by any concrete classifier. It implements the basic
13 |  * functionality for storing categories or features and can be used to calculate
14 |  * basic probabilities – both category and feature probabilities. The classify
15 |  * function has to be implemented by the concrete classifier class.
16 |  *
17 |  * @author Philipp Nolte
18 |  *
19 |  * @param <T>
20 |  *            A feature class
21 |  * @param <K>
22 |  *            A category class
23 |  */
24 | public abstract class Classifier<T, K> implements IFeatureProbability<T, K>, java.io.Serializable {
25 | 
26 |     /**
27 |      * Generated Serial Version UID (generated for v1.0.7).
28 |      */
29 |     private static final long serialVersionUID = 5504911666956811966L;
30 | 
31 |     /**
32 |      * Initial capacity of category dictionaries.
33 |      */
34 |     private static final int INITIAL_CATEGORY_DICTIONARY_CAPACITY = 16;
35 | 
36 |     /**
37 |      * Initial capacity of feature dictionaries. It should be quite big, because
38 |      * the features will quickly outnumber the categories.
39 |      */
40 |     private static final int INITIAL_FEATURE_DICTIONARY_CAPACITY = 32;
41 | 
42 |     /**
43 |      * The initial memory capacity or how many classifications are memorized.
44 |      */
45 |     private int memoryCapacity = 1000;
46 | 
47 |     /**
48 |      * A dictionary mapping features to their number of occurrences in each
49 |      * known category.
50 |      */
51 |     private Dictionary<K, Dictionary<T, Integer>> featureCountPerCategory;
52 | 
53 |     /**
54 |      * A dictionary mapping features to their number of occurrences.
55 |      */
56 |     private Dictionary<T, Integer> totalFeatureCount;
57 | 
58 |     /**
59 |      * A dictionary mapping categories to their number of occurrences.
60 |      */
61 |     private Dictionary<K, Integer> totalCategoryCount;
62 | 
63 |     /**
64 |      * The classifier's memory. It will forget old classifications as soon as
65 |      * they become too old.
66 |      */
67 |     private Queue<Classification<T, K>> memoryQueue;
68 | 
69 |     /**
70 |      * Constructs a new classifier without any trained knowledge.
71 |      */
72 |     public Classifier() {
73 |         this.reset();
74 |     }
75 | 
76 |     /**
77 |      * Resets the learned feature and category counts.
78 |      */
79 |     public void reset() {
80 |         this.featureCountPerCategory = new Hashtable<K, Dictionary<T, Integer>>(
81 |                 Classifier.INITIAL_CATEGORY_DICTIONARY_CAPACITY);
82 |         this.totalFeatureCount = new Hashtable<T, Integer>(Classifier.INITIAL_FEATURE_DICTIONARY_CAPACITY);
83 |         this.totalCategoryCount = new Hashtable<K, Integer>(Classifier.INITIAL_CATEGORY_DICTIONARY_CAPACITY);
84 |         this.memoryQueue = new LinkedList<Classification<T, K>>();
85 |     }
86 | 
87 |     /**
88 |      * Returns a Set of features the classifier knows about.
89 |      *
90 |      * @return The Set of features the classifier knows about.
91 |      */
92 |     public Set<T> getFeatures() {
93 |         return ((Hashtable<T, Integer>) this.totalFeatureCount).keySet();
94 |     }
95 | 
96 |     /**
97 |      * Returns a Set of categories the classifier knows about.
98 |      *
99 |      * @return The Set of categories the classifier knows about.
100 |      */
101 |     public Set<K> getCategories() {
102 |         return ((Hashtable<K, Integer>) this.totalCategoryCount).keySet();
103 |     }
104 | 
105 |     /**
106 |      * Retrieves the total number of categories the classifier knows about.
107 |      *
108 |      * @return The total category count.
109 |      */
110 |     public int getCategoriesTotal() {
111 |         int toReturn = 0;
112 |         for (Enumeration<Integer> e = this.totalCategoryCount.elements(); e.hasMoreElements();) {
113 |             toReturn += e.nextElement();
114 |         }
115 |         return toReturn;
116 |     }
117 | 
118 |     /**
119 |      * Retrieves the memory's capacity.
120 |      *
121 |      * @return The memory's capacity.
122 |      */
123 |     public int getMemoryCapacity() {
124 |         return memoryCapacity;
125 |     }
126 | 
127 |     /**
128 |      * Sets the memory's capacity. If the new value is less than the old value,
129 |      * the memory will be truncated accordingly.
130 |      *
131 |      * @param memoryCapacity
132 |      *            The new memory capacity.
133 |      */
134 |     public void setMemoryCapacity(int memoryCapacity) {
135 |         for (int i = this.memoryCapacity; i > memoryCapacity; i--) {
136 |             this.memoryQueue.poll();
137 |         }
138 |         this.memoryCapacity = memoryCapacity;
139 |     }
140 | 
141 |     /**
142 |      * Increments the count of a given feature in the given category. This is
143 |      * equal to telling the classifier that this feature has occurred in this
144 |      * category.
145 |      *
146 |      * @param feature
147 |      *            The feature, which count to increase.
148 |      * @param category
149 |      *            The category the feature occurred in.
150 |      */
151 |     public void incrementFeature(T feature, K category) {
152 |         Dictionary<T, Integer> features = this.featureCountPerCategory.get(category);
153 |         if (features == null) {
154 |             this.featureCountPerCategory.put(category,
155 |                     new Hashtable<T, Integer>(Classifier.INITIAL_FEATURE_DICTIONARY_CAPACITY));
156 |             features = this.featureCountPerCategory.get(category);
157 |         }
158 |         Integer count = features.get(feature);
159 |         if (count == null) {
160 |             features.put(feature, 0);
161 |             count = features.get(feature);
162 |         }
163 |         features.put(feature, ++count);
164 | 
165 |         Integer totalCount = this.totalFeatureCount.get(feature);
166 |         if (totalCount == null) {
167 |             this.totalFeatureCount.put(feature, 0);
168 |             totalCount = this.totalFeatureCount.get(feature);
169 |         }
170 |         this.totalFeatureCount.put(feature, ++totalCount);
171 |     }
172 | 
173 |     /**
174 |      * Increments the count of a given category. This is equal to telling the
175 |      * classifier that this category has occurred once more.
176 |      *
177 |      * @param category
178 |      *            The category, which count to increase.
179 |      */
180 |     public void incrementCategory(K category) {
181 |         Integer count = this.totalCategoryCount.get(category);
182 |         if (count == null) {
183 |             this.totalCategoryCount.put(category, 0);
184 |             count = this.totalCategoryCount.get(category);
185 |         }
186 |         this.totalCategoryCount.put(category, ++count);
187 |     }
188 | 
189 |     /**
190 |      * Decrements the count of a given feature in the given category. This is
191 |      * equal to telling the classifier that this feature has occurred once less
192 |      * in the category.
193 |      *
194 |      * @param feature
195 |      *            The feature to decrement the count for.
196 |      * @param category
197 |      *            The category.
198 |      */
199 |     public void decrementFeature(T feature, K category) {
200 |         Dictionary<T, Integer> features = this.featureCountPerCategory.get(category);
201 |         if (features == null) {
202 |             return;
203 |         }
204 |         Integer count = features.get(feature);
205 |         if (count == null) {
206 |             return;
207 |         }
208 |         if (count.intValue() == 1) {
209 |             features.remove(feature);
210 |             if (features.size() == 0) {
211 |                 this.featureCountPerCategory.remove(category);
212 |             }
213 |         } else {
214 |             features.put(feature, --count);
215 |         }
216 | 
217 |         Integer totalCount = this.totalFeatureCount.get(feature);
218 |         if (totalCount == null) {
219 |             return;
220 |         }
221 |         if (totalCount.intValue() == 1) {
222 |             this.totalFeatureCount.remove(feature);
223 |         } else {
224 |             this.totalFeatureCount.put(feature, --totalCount);
225 |         }
226 |     }
227 | 
228 |     /**
229 |      * Decrements the count of a given category. This is equal to telling the
230 |      * classifier that this category has occurred once less.
231 |      *
232 |      * @param category
233 |      *            The category, which count to decrease.
234 |      */
235 |     public void decrementCategory(K category) {
236 |         Integer count = this.totalCategoryCount.get(category);
237 |         if (count == null) {
238 |             return;
239 |         }
240 |         if (count.intValue() == 1) {
241 |             this.totalCategoryCount.remove(category);
242 |         } else {
243 |             this.totalCategoryCount.put(category, --count);
244 |         }
245 |     }
246 | 
247 |     /**
248 |      * Retrieves the number of occurrences of the given feature in the given
249 |      * category.
250 |      *
251 |      * @param feature
252 |      *            The feature, which count to retrieve.
253 |      * @param category
254 |      *            The category, which the feature occurred in.
255 |      * @return The number of occurrences of the feature in the category.
256 |      */
257 |     public int getFeatureCount(T feature, K category) {
258 |         Dictionary<T, Integer> features = this.featureCountPerCategory.get(category);
259 |         if (features == null) return 0;
260 |         Integer count = features.get(feature);
261 |         return (count == null) ? 0 : count.intValue();
262 |     }
263 | 
264 |     /**
265 |      * Retrieves the total number of occurrences of the given feature.
266 |      *
267 |      * @param feature
268 |      *            The feature, which count to retrieve.
269 |      * @return The total number of occurrences of the feature.
270 |      */
271 |     public int getFeatureCount(T feature) {
272 |         Integer count = this.totalFeatureCount.get(feature);
273 |         return (count == null) ? 0 : count.intValue();
274 |     }
275 | 
276 |     /**
277 |      * Retrieves the number of occurrences of the given category.
278 |      *
279 |      * @param category
280 |      *            The category, which count should be retrieved.
281 |      * @return The number of occurrences.
282 |      */
283 |     public int getCategoryCount(K category) {
284 |         Integer count = this.totalCategoryCount.get(category);
285 |         return (count == null) ? 0 : count.intValue();
286 |     }
287 | 
288 |     /**
289 |      * {@inheritDoc}
290 |      */
291 |     public float featureProbability(T feature, K category) {
292 |         final float totalFeatureCount = this.getFeatureCount(feature);
293 | 
294 |         if (totalFeatureCount == 0) {
295 |             return 0;
296 |         } else {
297 |             return this.getFeatureCount(feature, category) / totalFeatureCount;
298 |         }
299 |     }
300 | 
301 |     /**
302 |      * Retrieves the weighed average <code>P(feature|category)</code> with an
303 |      * overall weight of <code>1.0</code> and an assumed probability of
304 |      * <code>0.5</code>. The probability defaults to the overall feature
305 |      * probability.
306 |      *
307 |      * @see de.daslaboratorium.machinelearning.classifier.Classifier#featureProbability(Object,
308 |      *      Object)
309 |      * @see de.daslaboratorium.machinelearning.classifier.Classifier#featureWeighedAverage(Object,
310 |      *      Object, IFeatureProbability, float, float)
311 |      *
312 |      * @param feature
313 |      *            The feature, which probability to calculate.
314 |      * @param category
315 |      *            The category.
316 |      * @return The weighed average probability.
317 |      */
318 |     public float featureWeighedAverage(T feature, K category) {
319 |         return this.featureWeighedAverage(feature, category, null, 1.0f, 0.5f);
320 |     }
321 | 
322 |     /**
323 |      * Retrieves the weighed average <code>P(feature|category)</code> with an
324 |      * overall weight of <code>1.0</code>, an assumed probability of
325 |      * <code>0.5</code> and the given object to use for the probability calculation.
326 |      *
327 |      * @see de.daslaboratorium.machinelearning.classifier.Classifier#featureWeighedAverage(Object,
328 |      *      Object, IFeatureProbability, float, float)
329 |      *
330 |      * @param feature
331 |      *            The feature, which probability to calculate.
332 |      * @param category
333 |      *            The category.
334 |      * @param calculator
335 |      *            The calculating object.
336 |      * @return The weighed average probability.
337 |      */
338 |     public float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator) {
339 |         return this.featureWeighedAverage(feature, category, calculator, 1.0f, 0.5f);
340 |     }
341 | 
342 |     /**
343 |      * Retrieves the weighed average <code>P(feature|category)</code> with the
344 |      * given weight, an assumed probability of <code>0.5</code> and the given
345 |      * object to use for the probability calculation.
346 |      *
347 |      * @see de.daslaboratorium.machinelearning.classifier.Classifier#featureWeighedAverage(Object,
348 |      *      Object, IFeatureProbability, float, float)
349 |      *
350 |      * @param feature
351 |      *            The feature, which probability to calculate.
352 |      * @param category
353 |      *            The category.
354 |      * @param calculator
355 |      *            The calculating object.
356 |      * @param weight
357 |      *            The feature weight.
358 |      * @return The weighed average probability.
359 |      */
360 |     public float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator, float weight) {
361 |         return this.featureWeighedAverage(feature, category, calculator, weight, 0.5f);
362 |     }
363 | 
364 |     /**
365 |      * Retrieves the weighed average <code>P(feature|category)</code> with the
366 |      * given weight, the given assumed probability and the given object to use
367 |      * for the probability calculation.
368 |      *
369 |      * @param feature
370 |      *            The feature, which probability to calculate.
371 |      * @param category
372 |      *            The category.
373 |      * @param calculator
374 |      *            The calculating object.
375 |      * @param weight
376 |      *            The feature weight.
377 |      * @param assumedProbability
378 |      *            The assumed probability.
379 |      * @return The weighed average probability.
380 |      */
381 |     public float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator, float weight,
382 |             float assumedProbability) {
383 | 
384 |         /*
385 |          * Use the given calculating object or the default method to calculate
386 |          * the probability that the given feature occurred in the given
387 |          * category.
388 |          */
389 |         final float basicProbability = (calculator == null) ? this.featureProbability(feature, category)
390 |                 : calculator.featureProbability(feature, category);
391 | 
392 |         Integer totals = this.totalFeatureCount.get(feature);
393 |         if (totals == null) totals = 0;
394 |         return (weight * assumedProbability + totals * basicProbability) / (weight + totals);
395 |     }
396 | 
397 |     /**
398 |      * Train the classifier by telling it that the given features resulted in
399 |      * the given category.
400 |      *
401 |      * @param category
402 |      *            The category the features belong to.
403 |      * @param features
404 |      *            The features that resulted in the given category.
405 |      */
406 |     public void learn(K category, Collection<T> features) {
407 |         this.learn(new Classification<T, K>(features, category));
408 |     }
409 | 
410 |     /**
411 |      * Train the classifier by telling it that the given features resulted in
412 |      * the given category.
413 |      *
414 |      * @param classification
415 |      *            The classification to learn.
416 |      */
417 |     public void learn(Classification<T, K> classification) {
418 | 
419 |         for (T feature : classification.getFeatureset())
420 |             this.incrementFeature(feature, classification.getCategory());
421 |         this.incrementCategory(classification.getCategory());
422 | 
423 |         this.memoryQueue.offer(classification);
424 |         if (this.memoryQueue.size() > this.memoryCapacity) {
425 |             Classification<T, K> toForget = this.memoryQueue.remove();
426 | 
427 |             for (T feature : toForget.getFeatureset())
428 |                 this.decrementFeature(feature, toForget.getCategory());
429 |             this.decrementCategory(toForget.getCategory());
430 |         }
431 |     }
432 | 
433 |     /**
434 |      * The classify method. It will retrieve the most likely category for the
435 |      * features given and depends on the concrete classifier implementation.
436 |      *
437 |      * @param features
438 |      *            The features to classify.
439 |      * @return The category most likely.
440 |      */
441 |     public abstract Classification<T, K> classify(Collection<T> features);
442 | 
443 | }
--------------------------------------------------------------------------------
/src/main/java/de/daslaboratorium/machinelearning/classifier/IFeatureProbability.java:
--------------------------------------------------------------------------------
1 | package de.daslaboratorium.machinelearning.classifier;
2 | 
3 | /**
4 |  * Simple interface defining the method to calculate the feature probability.
5 |  *
6 |  * @author Philipp Nolte
7 |  *
8 |  * @param <T>
9 |  *            The feature class.
10 |  * @param <K>
11 |  *            The category class.
12 | */ 13 | public interface IFeatureProbability<T, K> { 14 | 15 | /** 16 | * Returns the probability of a feature being classified as 17 | * category in the learning set. 18 | * 19 | * @param feature 20 | * the feature to return the probability for 21 | * @param category 22 | * the category to check the feature against 23 | * @return the probability p(feature|category) 24 | */ 25 | public float featureProbability(T feature, K category); 26 | 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/de/daslaboratorium/machinelearning/classifier/bayes/BayesClassifier.java: -------------------------------------------------------------------------------- 1 | package de.daslaboratorium.machinelearning.classifier.bayes; 2 | 3 | import java.util.Collection; 4 | import java.util.Comparator; 5 | import java.util.SortedSet; 6 | import java.util.TreeSet; 7 | 8 | import de.daslaboratorium.machinelearning.classifier.Classification; 9 | import de.daslaboratorium.machinelearning.classifier.Classifier; 10 | 11 | /** 12 | * A concrete implementation of the abstract Classifier class. The Bayes 13 | * classifier implements a naive Bayes approach to classifying a given set of 14 | * features: classify(feat1,...,featN) = argmax(P(cat) * PROD(P(featI|cat))) 15 | * 16 | * @author Philipp Nolte 17 | * 18 | * @see http://en.wikipedia.org/wiki/Naive_Bayes_classifier 19 | * 20 | * @param <T> The feature class. 21 | * @param <K> The category class. 22 | */ 23 | public class BayesClassifier<T, K> extends Classifier<T, K> { 24 | 25 | /** 26 | * Calculates the product of all feature probabilities: PROD(P(featI|cat)) 27 | * 28 | * @param features The set of features to use. 29 | * @param category The category to test for. 30 | * @return The product of all feature probabilities. 
31 | */ 32 | private float featuresProbabilityProduct(Collection<T> features, 33 | K category) { 34 | float product = 1.0f; 35 | for (T feature : features) 36 | product *= this.featureWeighedAverage(feature, category); 37 | return product; 38 | } 39 | 40 | /** 41 | * Calculates the probability that the features can be classified as the 42 | * category given. 43 | * 44 | * @param features The set of features to use. 45 | * @param category The category to test for. 46 | * @return The probability that the features can be classified as the 47 | * category. 48 | */ 49 | private float categoryProbability(Collection<T> features, K category) { 50 | return ((float) this.getCategoryCount(category) 51 | / (float) this.getCategoriesTotal()) 52 | * featuresProbabilityProduct(features, category); 53 | } 54 | 55 | /** 56 | * Retrieves a sorted Set of probabilities that the given set 57 | * of features is classified as the available categories. 58 | * 59 | * @param features The set of features to use. 60 | * @return A sorted Set of category-probability-entries. 61 | */ 62 | private SortedSet<Classification<T, K>> categoryProbabilities( 63 | Collection<T> features) { 64 | 65 | /* 66 | * Sort the set according to the probabilities. Because we have to sort 67 | * by the mapped value and not by the mapped key, we cannot use a 68 | * sorted tree (TreeMap) and we have to use a set-entry approach to 69 | * achieve the desired functionality. A custom comparator is therefore 70 | * needed. 
71 | */ 72 | SortedSet<Classification<T, K>> probabilities = 73 | new TreeSet<Classification<T, K>>( 74 | new Comparator<Classification<T, K>>() { 75 | 76 | public int compare(Classification<T, K> o1, 77 | Classification<T, K> o2) { 78 | int toReturn = Float.compare( 79 | o1.getProbability(), o2.getProbability()); 80 | if ((toReturn == 0) 81 | && !o1.getCategory().equals(o2.getCategory())) 82 | toReturn = -1; 83 | return toReturn; 84 | } 85 | }); 86 | 87 | for (K category : this.getCategories()) 88 | probabilities.add(new Classification<T, K>( 89 | features, category, 90 | this.categoryProbability(features, category))); 91 | return probabilities; 92 | } 93 | 94 | /** 95 | * Classifies the given set of features. 96 | * 97 | * @return The category the set of features is classified as. 98 | */ 99 | @Override 100 | public Classification<T, K> classify(Collection<T> features) { 101 | SortedSet<Classification<T, K>> probabilities = 102 | this.categoryProbabilities(features); 103 | 104 | if (probabilities.size() > 0) { 105 | return probabilities.last(); 106 | } 107 | return null; 108 | } 109 | 110 | /** 111 | * Classifies the given set of features and returns the full details of the 112 | * classification. 113 | * 114 | * @return The set of categories the set of features is classified as. 
115 | */ 116 | public Collection<Classification<T, K>> classifyDetailed( 117 | Collection<T> features) { 118 | return this.categoryProbabilities(features); 119 | } 120 | 121 | } 122 | -------------------------------------------------------------------------------- /src/test/java/de/daslaboratorium/machinelearning/classifier/bayes/BayesClassifierTest.java: -------------------------------------------------------------------------------- 1 | package de.daslaboratorium.machinelearning.classifier.bayes; 2 | 3 | import java.io.ByteArrayOutputStream; 4 | import java.io.IOException; 5 | import java.io.ObjectOutputStream; 6 | import java.util.ArrayList; 7 | import java.util.Arrays; 8 | import java.util.Collection; 9 | import java.util.List; 10 | 11 | import org.junit.Assert; 12 | import org.junit.Before; 13 | import org.junit.Test; 14 | 15 | import de.daslaboratorium.machinelearning.classifier.Classification; 16 | import de.daslaboratorium.machinelearning.classifier.Classifier; 17 | 18 | public class BayesClassifierTest { 19 | 20 | private static final double EPSILON = 0.001; 21 | private static final String CATEGORY_NEGATIVE = "negative"; 22 | private static final String CATEGORY_POSITIVE = "positive"; 23 | private Classifier<String, String> bayes; 24 | 25 | @Before 26 | public void setUp() { 27 | /* 28 | * Create a new classifier instance. The context features are Strings 29 | * and the context will be classified with a String according to the 30 | * featureset of the context. 31 | */ 32 | bayes = new BayesClassifier<String, String>(); 33 | 34 | /* 35 | * The classifier can learn from classifications that are handed over to 36 | * the learn methods. Imagine a tokenized text as follows. The tokens are 37 | * the text's features. The category of the text will either be positive 38 | * or negative. 
39 | */ 40 | final String[] positiveText = "I love sunny days".split("\\s"); 41 | bayes.learn(CATEGORY_POSITIVE, Arrays.asList(positiveText)); 42 | 43 | final String[] negativeText = "I hate rain".split("\\s"); 44 | bayes.learn(CATEGORY_NEGATIVE, Arrays.asList(negativeText)); 45 | } 46 | 47 | @Test 48 | public void testStringClassification() { 49 | final String[] unknownText1 = "today is a sunny day".split("\\s"); 50 | final String[] unknownText2 = "there will be rain".split("\\s"); 51 | 52 | Assert.assertEquals(CATEGORY_POSITIVE, bayes.classify(Arrays.asList(unknownText1)).getCategory()); 53 | Assert.assertEquals(CATEGORY_NEGATIVE, bayes.classify(Arrays.asList(unknownText2)).getCategory()); 54 | } 55 | 56 | @Test 57 | public void testStringClassificationInDetails() { 58 | 59 | final String[] unknownText1 = "today is a sunny day".split("\\s"); 60 | 61 | Collection<Classification<String, String>> classifications = ((BayesClassifier<String, String>) bayes) 62 | .classifyDetailed(Arrays.asList(unknownText1)); 63 | 64 | List<Classification<String, String>> list = new ArrayList<Classification<String, String>>(classifications); 65 | 66 | Assert.assertEquals(CATEGORY_NEGATIVE, list.get(0).getCategory()); 67 | Assert.assertEquals(0.0078125, list.get(0).getProbability(), EPSILON); 68 | 69 | Assert.assertEquals(CATEGORY_POSITIVE, list.get(1).getCategory()); 70 | Assert.assertEquals(0.0234375, list.get(1).getProbability(), EPSILON); 71 | } 72 | 73 | @Test 74 | public void testSerialization() throws IOException { 75 | 76 | new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(bayes); 77 | } 78 | } --------------------------------------------------------------------------------
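The smoothing used by `featureWeighedAverage` in `Classifier` blends an assumed prior probability with the observed P(feature|category), weighted by how often the feature has been seen in total: rarely seen features stay close to the assumed probability, frequently seen features are dominated by their observed probability. A minimal standalone sketch of just that formula (the class and sample values here are illustrative, not part of the library):

```java
public class WeighedAverageSketch {

    /**
     * Standalone version of the weighed-average smoothing:
     * (weight * assumedProbability + count * basicProbability) / (weight + count).
     * With count == 0 this returns the assumed probability; as count grows,
     * the observed probability takes over.
     */
    static float weighedAverage(float basicProbability, int totalFeatureCount,
                                float weight, float assumedProbability) {
        return (weight * assumedProbability + totalFeatureCount * basicProbability)
                / (weight + totalFeatureCount);
    }

    public static void main(String[] args) {
        // Unseen feature (count 0): falls back to the assumed probability 0.5.
        System.out.println(weighedAverage(0.9f, 0, 1.0f, 0.5f));
        // Frequently seen feature (count 99): observed 0.9 dominates.
        System.out.println(weighedAverage(0.9f, 99, 1.0f, 0.5f));
    }
}
```

This is why the default `featureWeighedAverage(feature, category, calculator, weight)` overload passes an assumed probability of `0.5f`: a feature the classifier has never seen contributes a neutral factor instead of a zero that would wipe out the whole product.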