├── .gitignore
├── .travis.yml
├── README.md
├── example
│   └── RunnableExample.java
├── forgetful-learning.pdf
├── pom.xml
└── src
    ├── main
    │   └── java
    │       └── de
    │           └── daslaboratorium
    │               └── machinelearning
    │                   └── classifier
    │                       ├── Classification.java
    │                       ├── Classifier.java
    │                       ├── IFeatureProbability.java
    │                       └── bayes
    │                           └── BayesClassifier.java
    └── test
        └── java
            └── de
                └── daslaboratorium
                    └── machinelearning
                        └── classifier
                            └── bayes
                                └── BayesClassifierTest.java
/.gitignore:
--------------------------------------------------------------------------------
1 | # Binary files
2 | *.class
3 |
4 | # Maven files
5 | target/
6 |
7 | # Eclipse files
8 | .classpath
9 | .project
10 |
11 | # IntelliJ files
12 | .idea/
13 | *.iml
14 |
15 | *.prefs
16 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: java
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Java Naive Bayes Classifier
2 | ==================
3 |
4 | [![Build Status](https://travis-ci.org/ptnplanet/Java-Naive-Bayes-Classifier.svg?branch=master)](https://travis-ci.org/ptnplanet/Java-Naive-Bayes-Classifier)
5 | [![Release](https://jitpack.io/v/ptnplanet/Java-Naive-Bayes-Classifier.svg)](https://jitpack.io/#ptnplanet/Java-Naive-Bayes-Classifier)
6 |
7 | Nothing special. It works and is well documented, so you should get it running without wasting too much time searching for other alternatives on the net.
8 |
9 | Maven Quick-Start
10 | ------------------
11 |
12 | This Java Naive Bayes Classifier can be installed via the jitpack repository. Make sure to add it to your build file first.
13 |
14 | ```xml
15 | <repositories>
16 |     <repository>
17 |         <id>jitpack.io</id>
18 |         <url>https://jitpack.io</url>
19 |     </repository>
20 | </repositories>
21 | ```
22 |
23 | Then, treat it as any other dependency.
24 |
25 | ```xml
26 | <dependency>
27 |     <groupId>com.github.ptnplanet</groupId>
28 |     <artifactId>Java-Naive-Bayes-Classifier</artifactId>
29 |     <version>1.0.7</version>
30 | </dependency>
31 | ```
32 |
33 | For other build tools (e.g. Gradle), visit https://jitpack.io for configuration snippets.
34 | 
35 | Please also check the releases tab for further releases.
36 |
37 | Overview
38 | ------------------
39 |
40 | I like talking about *features* and *categories*. Objects have features and may belong to a category. The classifier will try matching objects to their categories by looking at the objects' features. It does so by consulting its memory filled with knowledge gathered from training examples.
41 |
42 | Classifying a feature set yields the category with the highest product of 1) the probability of that category occurring and 2) the product of the probabilities of each feature occurring in that category:
43 |
44 | ```classify(feature1, ..., featureN) = argmax(P(category) * PROD(P(feature|category)))```
45 |
46 | This is a so-called maximum a posteriori estimation. Wikipedia actually does a good job explaining it: http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Probabilistic_model
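
A minimal, library-independent sketch of this decision rule may help make it concrete. Everything below is illustrative only; the probability functions are assumed to be estimated elsewhere, and none of these names are part of this project's API:

```java
import java.util.Collection;
import java.util.function.BiFunction;
import java.util.function.Function;

public class ArgmaxSketch {

    // Returns the category maximizing P(category) * PROD(P(feature|category)).
    static <T, K> K classify(Collection<T> features, Collection<K> categories,
            Function<K, Float> categoryProbability,
            BiFunction<T, K, Float> featureProbability) {
        K best = null;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (K category : categories) {
            float score = categoryProbability.apply(category);
            for (T feature : features)
                score *= featureProbability.apply(feature, category);
            if (score > bestScore) {
                bestScore = score;
                best = category;
            }
        }
        return best;
    }
}
```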
47 |
48 | Learning from Examples
49 | ------------------
50 |
51 | Add knowledge by telling the classifier that these features belong to a specific category:
52 |
53 | ```java
54 | String[] positiveText = "I love sunny days".split("\\s");
55 | bayes.learn("positive", Arrays.asList(positiveText));
56 | ```
57 |
58 | Classify unknown objects
59 | ------------------
60 |
61 | Use the gathered knowledge to classify unknown objects with their features. The classifier will return the category that the object most likely belongs to.
62 |
63 | ```java
64 | String[] unknownText1 = "today is a sunny day".split("\\s");
65 | bayes.classify(Arrays.asList(unknownText1)).getCategory();
66 | ```
67 |
68 | Example
69 | ------------------
70 |
71 | Here is an excerpt from the example. The classifier will classify sentences (arrays of features) as sentences with either positive or negative sentiment. Please refer to the full example for more detailed documentation.
72 |
73 | ```java
74 | // Create a new bayes classifier with string categories and string features.
75 | Classifier<String, String> bayes = new BayesClassifier<String, String>();
76 |
77 | // Two examples to learn from.
78 | String[] positiveText = "I love sunny days".split("\\s");
79 | String[] negativeText = "I hate rain".split("\\s");
80 |
81 | // Learn by classifying examples.
82 | // New categories can be added on the fly, when they are first used.
83 | // A classification consists of a category and a list of features
84 | // that resulted in the classification in that category.
85 | bayes.learn("positive", Arrays.asList(positiveText));
86 | bayes.learn("negative", Arrays.asList(negativeText));
87 |
88 | // Here are two unknown sentences to classify.
89 | String[] unknownText1 = "today is a sunny day".split("\\s");
90 | String[] unknownText2 = "there will be rain".split("\\s");
91 |
92 | System.out.println( // will output "positive"
93 | bayes.classify(Arrays.asList(unknownText1)).getCategory());
94 | System.out.println( // will output "negative"
95 | bayes.classify(Arrays.asList(unknownText2)).getCategory());
96 |
97 | // Get more detailed classification result.
98 | ((BayesClassifier<String, String>) bayes).classifyDetailed(
99 | Arrays.asList(unknownText1));
100 |
101 | // Change the memory capacity. New learned classifications (using
102 | // the learn method) are stored in a queue with the size given
103 | // here and used to classify unknown sentences.
104 | bayes.setMemoryCapacity(500);
105 | ```
106 |
107 | Forgetful learning
108 | ------------------
109 |
110 | This classifier is forgetful: it forgets the oldest learned classifications once the number of stored classifications exceeds its memory capacity (1,000 by default). This ensures that the classifier can react to ongoing changes in the user's habits.
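
For illustration, here is a short sketch using the API documented below; the capacity of 2 is deliberately tiny, just to make the forgetting visible:

```java
Classifier<String, String> bayes = new BayesClassifier<String, String>();
bayes.setMemoryCapacity(2); // remember only the two most recent classifications

bayes.learn("positive", Arrays.asList("I love sunny days".split("\\s")));
bayes.learn("negative", Arrays.asList("I hate rain".split("\\s")));
bayes.learn("negative", Arrays.asList("I hate traffic".split("\\s")));

// The first classification has now been forgotten: its feature and
// category counts were decremented when it left the memory queue.
```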
111 |
112 |
113 | Interface
114 | ------------------
115 | The abstract ```Classifier``` serves as a base for the concrete ```BayesClassifier```. Here are its methods. Please also refer to the Javadoc.
116 |
117 | * ```void reset()``` Resets the learned feature and category counts.
118 | * ```Set<T> getFeatures()``` Returns a ```Set``` of features the classifier knows about.
119 | * ```Set<K> getCategories()``` Returns a ```Set``` of categories the classifier knows about.
120 | * ```int getCategoriesTotal()``` Retrieves the total number of categories the classifier knows about.
121 | * ```int getMemoryCapacity()``` Retrieves the memory's capacity.
122 | * ```void setMemoryCapacity(int memoryCapacity)``` Sets the memory's capacity. If the new value is less than the old value, the memory will be truncated accordingly.
123 | * ```void incrementFeature(T feature, K category)``` Increments the count of a given feature in the given category. This is equal to telling the classifier that this feature has occurred in this category.
124 | * ```void incrementCategory(K category)``` Increments the count of a given category. This is equal to telling the classifier that this category has occurred once more.
125 | * ```void decrementFeature(T feature, K category)``` Decrements the count of a given feature in the given category. This is equal to telling the classifier that this feature has occurred once less in this category.
126 | * ```void decrementCategory(K category)``` Decrements the count of a given category. This is equal to telling the classifier that this category has occurred once less.
127 | * ```int getFeatureCount(T feature, K category)``` Retrieves the number of occurrences of the given feature in the given category.
128 | * ```int getFeatureCount(T feature)``` Retrieves the total number of occurrences of the given feature.
129 | * ```int getCategoryCount(K category)``` Retrieves the number of occurrences of the given category.
130 | * ```float featureProbability(T feature, K category)``` (*implements* ```IFeatureProbability.featureProbability```) Returns the probability that the given feature occurs in the given category.
131 | * ```float featureWeighedAverage(T feature, K category)``` Retrieves the weighed average ```P(feature|category)``` with overall weight of ```1.0``` and an assumed probability of ```0.5```. The probability defaults to the overall feature probability. (A worked example follows this list.)
132 | * ```float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator)``` Retrieves the weighed average ```P(feature|category)``` with overall weight of ```1.0```, an assumed probability of ```0.5``` and the given object to use for probability calculation.
133 | * ```float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator, float weight)``` Retrieves the weighed average ```P(feature|category)``` with the given weight, an assumed probability of ```0.5``` and the given object to use for probability calculation.
134 | * ```float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator, float weight, float assumedProbability)``` Retrieves the weighed average ```P(feature|category)``` with the given weight, the given assumed probability and the given object to use for probability calculation.
135 | * ```void learn(K category, Collection<T> features)``` Trains the classifier by telling it that the given features resulted in the given category.
136 | * ```void learn(Classification<T, K> classification)``` Trains the classifier by telling it that the given classification's features resulted in its category.
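
The default weighed average is computed as ```(weight * assumedProbability + totalFeatureCount * featureProbability) / (weight + totalFeatureCount)```. A small worked example, using the two training sentences from above:

```java
Classifier<String, String> bayes = new BayesClassifier<String, String>();
bayes.learn("positive", Arrays.asList("I love sunny days".split("\\s")));
bayes.learn("negative", Arrays.asList("I hate rain".split("\\s")));

// "sunny" was seen once overall, always in "positive":
// featureProbability = 1/1 = 1.0
// weighed average    = (1.0 * 0.5 + 1 * 1.0) / (1.0 + 1) = 0.75
bayes.featureWeighedAverage("sunny", "positive");

// An unseen feature falls back to the assumed probability:
// (1.0 * 0.5 + 0 * 0.0) / (1.0 + 0) = 0.5
bayes.featureWeighedAverage("today", "positive");
```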
137 |
138 | The ```BayesClassifier``` class implements the following abstract method:
139 |
140 | * ```Classification<T, K> classify(Collection<T> features)``` Retrieves the most likely category for the given features. The behaviour depends on the concrete classifier implementation.
141 |
142 | Running the example
143 | ------------------
144 |
145 | ```shell
146 | $ git clone https://github.com/ptnplanet/Java-Naive-Bayes-Classifier.git
147 | $ cd Java-Naive-Bayes-Classifier
148 | $ javac -cp src/main/java example/RunnableExample.java
149 | $ java -cp example:src/main/java RunnableExample
150 | ```
151 |
152 | Possible Performance Issues
153 | ------------------
154 | 
155 | Performance improvements I am currently thinking of:
156 | 
157 | - Store the natural logarithms of the feature probabilities and add them together instead of multiplying the raw probabilities (sketched below)
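
A rough sketch of that idea (hypothetical, not implemented in this library):

```java
// Instead of multiplying the weighed averages, sum their logarithms; this
// avoids floating-point underflow for long feature sets. The argmax is
// preserved because log is monotonic: compare log(P(category)) + logSum
// across categories instead of the raw product.
private float featuresLogProbabilitySum(Collection<T> features, K category) {
    float logSum = 0.0f;
    for (T feature : features)
        logSum += (float) Math.log(this.featureWeighedAverage(feature, category));
    return logSum;
}
```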
158 |
159 | The MIT License (MIT)
160 | ------------------
161 |
162 | Copyright (c) 2012-2017 Philipp Nolte
163 |
164 | Permission is hereby granted, free of charge, to any person obtaining a copy
165 | of this software and associated documentation files (the "Software"), to deal
166 | in the Software without restriction, including without limitation the rights
167 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
168 | copies of the Software, and to permit persons to whom the Software is
169 | furnished to do so, subject to the following conditions:
170 |
171 | The above copyright notice and this permission notice shall be included in
172 | all copies or substantial portions of the Software.
173 |
174 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
175 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
176 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
177 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
178 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
179 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
180 | THE SOFTWARE.
181 |
--------------------------------------------------------------------------------
/example/RunnableExample.java:
--------------------------------------------------------------------------------
1 | import java.util.Arrays;
2 |
3 | import de.daslaboratorium.machinelearning.classifier.bayes.BayesClassifier;
4 | import de.daslaboratorium.machinelearning.classifier.Classifier;
5 |
6 | public class RunnableExample {
7 |
8 | public static void main(String[] args) {
9 |
10 | /*
11 | * Create a new classifier instance. The context features are
12 | * Strings and the context will be classified with a String according
13 | * to the featureset of the context.
14 | */
15 | final Classifier<String, String> bayes =
16 | new BayesClassifier<String, String>();
17 |
18 | /*
19 | * The classifier can learn from classifications that are handed over
20 | * to the learn methods. Imagine a tokenized text as follows. The tokens
21 | * are the text's features. The category of the text will either be
22 | * positive or negative.
23 | */
24 | final String[] positiveText = "I love sunny days".split("\\s");
25 | bayes.learn("positive", Arrays.asList(positiveText));
26 |
27 | final String[] negativeText = "I hate rain".split("\\s");
28 | bayes.learn("negative", Arrays.asList(negativeText));
29 |
30 | /*
31 | * Now that the classifier has "learned" two classifications, it will
32 | * be able to classify similar sentences. The classify method returns
33 | * a Classification object that contains the given featureset,
34 | * classification probability and resulting category.
35 | */
36 | final String[] unknownText1 = "today is a sunny day".split("\\s");
37 | final String[] unknownText2 = "there will be rain".split("\\s");
38 |
39 | System.out.println( // will output "positive"
40 | bayes.classify(Arrays.asList(unknownText1)).getCategory());
41 | System.out.println( // will output "negative"
42 | bayes.classify(Arrays.asList(unknownText2)).getCategory());
43 |
44 | /*
45 | * The BayesClassifier extends the abstract Classifier and provides
46 | * detailed classification results that can be retrieved by calling
47 | * the classifyDetailed Method.
48 | *
49 | * The classification with the highest probability is the resulting
50 | * classification. The returned List will look like this.
51 | * [
52 | * Classification [
53 | * category=negative,
54 | * probability=0.0078125,
55 | * featureset=[today, is, a, sunny, day]
56 | * ],
57 | * Classification [
58 | * category=positive,
59 | * probability=0.0234375,
60 | * featureset=[today, is, a, sunny, day]
61 | * ]
62 | * ]
63 | */
64 | ((BayesClassifier<String, String>) bayes).classifyDetailed(
65 | Arrays.asList(unknownText1));
66 |
67 | /*
68 | * Please note, that this particular classifier implementation will
69 | * "forget" learned classifications after a few learning sessions. The
70 | * number of learning sessions it will record can be set as follows:
71 | */
72 | bayes.setMemoryCapacity(500); // remember the last 500 learned classifications
73 | }
74 |
75 | }
76 |
--------------------------------------------------------------------------------
/forgetful-learning.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ptnplanet/Java-Naive-Bayes-Classifier/3f8d874c98baacbd3f4accae71e7179ab426dfe2/forgetful-learning.pdf
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |     <modelVersion>4.0.0</modelVersion>
6 | 
7 |     <groupId>com.github.ptnplanet</groupId>
8 |     <artifactId>Java-Naive-Bayes-Classifier</artifactId>
9 |     <version>1.0.7</version>
10 | 
11 |     <dependencies>
12 |         <dependency>
13 |             <groupId>junit</groupId>
14 |             <artifactId>junit</artifactId>
15 |             <version>4.12</version>
16 |             <scope>test</scope>
17 |         </dependency>
18 |     </dependencies>
19 | </project>
--------------------------------------------------------------------------------
/src/main/java/de/daslaboratorium/machinelearning/classifier/Classification.java:
--------------------------------------------------------------------------------
1 | package de.daslaboratorium.machinelearning.classifier;
2 |
3 | import java.io.Serializable;
4 | import java.util.Collection;
5 |
6 | /**
7 | * A basic wrapper reflecting a classification. It will store both featureset
8 | * and resulting classification.
9 | *
10 | * @author Philipp Nolte
11 | *
12 | * @param <T>
13 | *            The feature class.
14 | * @param <K>
15 | *            The category class.
16 | */
17 | public class Classification<T, K> implements Serializable {
18 |
19 | /**
20 | * Generated Serial Version UID (generated for v1.0.7).
21 | */
22 | private static final long serialVersionUID = -1210981535415341283L;
23 |
24 | /**
25 | * The classified featureset.
26 | */
27 | private Collection<T> featureset;
28 |
29 | /**
30 | * The category as which the featureset was classified.
31 | */
32 | private K category;
33 |
34 | /**
35 | * The probability that the featureset belongs to the given category.
36 | */
37 | private float probability;
38 |
39 | /**
40 | * Constructs a new Classification with the parameters given and a default
41 | * probability of 1.
42 | *
43 | * @param featureset
44 | * The featureset.
45 | * @param category
46 | * The category.
47 | */
48 | public Classification(Collection<T> featureset, K category) {
49 | this(featureset, category, 1.0f);
50 | }
51 |
52 | /**
53 | * Constructs a new Classification with the parameters given.
54 | *
55 | * @param featureset
56 | * The featureset.
57 | * @param category
58 | * The category.
59 | * @param probability
60 | * The probability.
61 | */
62 | public Classification(Collection<T> featureset, K category, float probability) {
63 | this.featureset = featureset;
64 | this.category = category;
65 | this.probability = probability;
66 | }
67 |
68 | /**
69 | * Retrieves the featureset classified.
70 | *
71 | * @return The featureset.
72 | */
73 | public Collection<T> getFeatureset() {
74 | return featureset;
75 | }
76 |
77 | /**
78 | * Retrieves the classification's probability.
79 | *
80 | * @return The probability.
81 | */
82 | public float getProbability() {
83 | return this.probability;
84 | }
85 |
86 | /**
87 | * Retrieves the category the featureset was classified as.
88 | *
89 | * @return The category.
90 | */
91 | public K getCategory() {
92 | return category;
93 | }
94 |
95 | /**
96 | * {@inheritDoc}
97 | */
98 | @Override
99 | public String toString() {
100 | return "Classification [category=" + this.category + ", probability=" + this.probability + ", featureset="
101 | + this.featureset + "]";
102 | }
103 |
104 | }
105 |
--------------------------------------------------------------------------------
/src/main/java/de/daslaboratorium/machinelearning/classifier/Classifier.java:
--------------------------------------------------------------------------------
1 | package de.daslaboratorium.machinelearning.classifier;
2 |
3 | import java.util.Collection;
4 | import java.util.Dictionary;
5 | import java.util.Enumeration;
6 | import java.util.Hashtable;
7 | import java.util.LinkedList;
8 | import java.util.Queue;
9 | import java.util.Set;
10 |
11 | /**
12 | * Abstract base extended by any concrete classifier. It implements the basic
13 | * functionality for storing categories or features and can be used to calculate
14 | * basic probabilities – both category and feature probabilities. The classify
15 | * function has to be implemented by the concrete classifier class.
16 | *
17 | * @author Philipp Nolte
18 | *
19 | * @param <T>
20 | *            A feature class
21 | * @param <K>
22 | *            A category class
23 | */
24 | public abstract class Classifier<T, K> implements IFeatureProbability<T, K>, java.io.Serializable {
25 |
26 | /**
27 | * Generated Serial Version UID (generated for v1.0.7).
28 | */
29 | private static final long serialVersionUID = 5504911666956811966L;
30 |
31 | /**
32 | * Initial capacity of category dictionaries.
33 | */
34 | private static final int INITIAL_CATEGORY_DICTIONARY_CAPACITY = 16;
35 |
36 | /**
37 | * Initial capacity of feature dictionaries. It should be quite big, because
38 | * the features will quickly outnumber the categories.
39 | */
40 | private static final int INITIAL_FEATURE_DICTIONARY_CAPACITY = 32;
41 |
42 | /**
43 | * The initial memory capacity or how many classifications are memorized.
44 | */
45 | private int memoryCapacity = 1000;
46 |
47 | /**
48 | * A dictionary mapping features to their number of occurrences in each
49 | * known category.
50 | */
51 | private Dictionary<K, Dictionary<T, Integer>> featureCountPerCategory;
52 |
53 | /**
54 | * A dictionary mapping features to their number of occurrences.
55 | */
56 | private Dictionary<T, Integer> totalFeatureCount;
57 |
58 | /**
59 | * A dictionary mapping categories to their number of occurrences.
60 | */
61 | private Dictionary<K, Integer> totalCategoryCount;
62 |
63 | /**
64 | * The classifier's memory. It will forget old classifications as soon as
65 | * they become too old.
66 | */
67 | private Queue<Classification<T, K>> memoryQueue;
68 |
69 | /**
70 | * Constructs a new classifier without any trained knowledge.
71 | */
72 | public Classifier() {
73 | this.reset();
74 | }
75 |
76 | /**
77 | * Resets the learned feature and category counts.
78 | */
79 | public void reset() {
80 | this.featureCountPerCategory = new Hashtable<K, Dictionary<T, Integer>>(
81 | Classifier.INITIAL_CATEGORY_DICTIONARY_CAPACITY);
82 | this.totalFeatureCount = new Hashtable<T, Integer>(Classifier.INITIAL_FEATURE_DICTIONARY_CAPACITY);
83 | this.totalCategoryCount = new Hashtable<K, Integer>(Classifier.INITIAL_CATEGORY_DICTIONARY_CAPACITY);
84 | this.memoryQueue = new LinkedList<Classification<T, K>>();
85 | }
86 |
87 | /**
88 | * Returns a <code>Set</code> of features the classifier knows about.
89 | *
90 | * @return The <code>Set</code> of features the classifier knows about.
91 | */
92 | public Set<T> getFeatures() {
93 | return ((Hashtable<T, Integer>) this.totalFeatureCount).keySet();
94 | }
95 |
96 | /**
97 | * Returns a <code>Set</code> of categories the classifier knows about.
98 | *
99 | * @return The <code>Set</code> of categories the classifier knows about.
100 | */
101 | public Set<K> getCategories() {
102 | return ((Hashtable<K, Integer>) this.totalCategoryCount).keySet();
103 | }
104 |
105 | /**
106 | * Retrieves the total number of categories the classifier knows about.
107 | *
108 | * @return The total category count.
109 | */
110 | public int getCategoriesTotal() {
111 | int toReturn = 0;
112 | for (Enumeration<Integer> e = this.totalCategoryCount.elements(); e.hasMoreElements();) {
113 | toReturn += e.nextElement();
114 | }
115 | return toReturn;
116 | }
117 |
118 | /**
119 | * Retrieves the memory's capacity.
120 | *
121 | * @return The memory's capacity.
122 | */
123 | public int getMemoryCapacity() {
124 | return memoryCapacity;
125 | }
126 |
127 | /**
128 | * Sets the memory's capacity. If the new value is less than the old value,
129 | * the memory will be truncated accordingly.
130 | *
131 | * @param memoryCapacity
132 | * The new memory capacity.
133 | */
134 | public void setMemoryCapacity(int memoryCapacity) {
135 | for (int i = this.memoryCapacity; i > memoryCapacity; i--) {
136 | this.memoryQueue.poll();
137 | }
138 | this.memoryCapacity = memoryCapacity;
139 | }
140 |
141 | /**
142 | * Increments the count of a given feature in the given category. This is
143 | * equal to telling the classifier, that this feature has occurred in this
144 | * category.
145 | *
146 | * @param feature
147 | * The feature, which count to increase.
148 | * @param category
149 | * The category the feature occurred in.
150 | */
151 | public void incrementFeature(T feature, K category) {
152 | Dictionary<T, Integer> features = this.featureCountPerCategory.get(category);
153 | if (features == null) {
154 | this.featureCountPerCategory.put(category,
155 | new Hashtable<T, Integer>(Classifier.INITIAL_FEATURE_DICTIONARY_CAPACITY));
156 | features = this.featureCountPerCategory.get(category);
157 | }
158 | Integer count = features.get(feature);
159 | if (count == null) {
160 | features.put(feature, 0);
161 | count = features.get(feature);
162 | }
163 | features.put(feature, ++count);
164 |
165 | Integer totalCount = this.totalFeatureCount.get(feature);
166 | if (totalCount == null) {
167 | this.totalFeatureCount.put(feature, 0);
168 | totalCount = this.totalFeatureCount.get(feature);
169 | }
170 | this.totalFeatureCount.put(feature, ++totalCount);
171 | }
172 |
173 | /**
174 | * Increments the count of a given category. This is equal to telling the
175 | * classifier, that this category has occurred once more.
176 | *
177 | * @param category
178 | * The category, which count to increase.
179 | */
180 | public void incrementCategory(K category) {
181 | Integer count = this.totalCategoryCount.get(category);
182 | if (count == null) {
183 | this.totalCategoryCount.put(category, 0);
184 | count = this.totalCategoryCount.get(category);
185 | }
186 | this.totalCategoryCount.put(category, ++count);
187 | }
188 |
189 | /**
190 | * Decrements the count of a given feature in the given category. This is
191 | * equal to telling the classifier that this feature has occurred once less
192 | * in this category.
193 | *
194 | * @param feature
195 | * The feature to decrement the count for.
196 | * @param category
197 | * The category.
198 | */
199 | public void decrementFeature(T feature, K category) {
200 | Dictionary features = this.featureCountPerCategory.get(category);
201 | if (features == null) {
202 | return;
203 | }
204 | Integer count = features.get(feature);
205 | if (count == null) {
206 | return;
207 | }
208 | if (count.intValue() == 1) {
209 | features.remove(feature);
210 | if (features.size() == 0) {
211 | this.featureCountPerCategory.remove(category);
212 | }
213 | } else {
214 | features.put(feature, --count);
215 | }
216 |
217 | Integer totalCount = this.totalFeatureCount.get(feature);
218 | if (totalCount == null) {
219 | return;
220 | }
221 | if (totalCount.intValue() == 1) {
222 | this.totalFeatureCount.remove(feature);
223 | } else {
224 | this.totalFeatureCount.put(feature, --totalCount);
225 | }
226 | }
227 |
228 | /**
229 | * Decrements the count of a given category. This is equal to telling the
230 | * classifier, that this category has occurred once less.
231 | *
232 | * @param category
233 | * The category, which count to decrease.
234 | */
235 | public void decrementCategory(K category) {
236 | Integer count = this.totalCategoryCount.get(category);
237 | if (count == null) {
238 | return;
239 | }
240 | if (count.intValue() == 1) {
241 | this.totalCategoryCount.remove(category);
242 | } else {
243 | this.totalCategoryCount.put(category, --count);
244 | }
245 | }
246 |
247 | /**
248 | * Retrieves the number of occurrences of the given feature in the given
249 | * category.
250 | *
251 | * @param feature
252 | * The feature, which count to retrieve.
253 | * @param category
254 | * The category, which the feature occurred in.
255 | * @return The number of occurrences of the feature in the category.
256 | */
257 | public int getFeatureCount(T feature, K category) {
258 | Dictionary<T, Integer> features = this.featureCountPerCategory.get(category);
259 | if (features == null) return 0;
260 | Integer count = features.get(feature);
261 | return (count == null) ? 0 : count.intValue();
262 | }
263 |
264 | /**
265 | * Retrieves the total number of occurrences of the given feature.
266 | *
267 | * @param feature
268 | * The feature, which count to retrieve.
269 | * @return The total number of occurrences of the feature.
270 | */
271 | public int getFeatureCount(T feature) {
272 | Integer count = this.totalFeatureCount.get(feature);
273 | return (count == null) ? 0 : count.intValue();
274 | }
275 |
276 | /**
277 | * Retrieves the number of occurrences of the given category.
278 | *
279 | * @param category
280 | * The category, which count should be retrieved.
281 | * @return The number of occurrences.
282 | */
283 | public int getCategoryCount(K category) {
284 | Integer count = this.totalCategoryCount.get(category);
285 | return (count == null) ? 0 : count.intValue();
286 | }
287 |
288 | /**
289 | * {@inheritDoc}
290 | */
291 | public float featureProbability(T feature, K category) {
292 | final float totalFeatureCount = this.getFeatureCount(feature);
293 |
294 | if (totalFeatureCount == 0) {
295 | return 0;
296 | } else {
297 | return this.getFeatureCount(feature, category) / (float) this.getFeatureCount(feature);
298 | }
299 | }
300 |
301 | /**
302 | * Retrieves the weighed average <code>P(feature|category)</code> with
303 | * overall weight of <code>1.0</code> and an assumed probability of
304 | * <code>0.5</code>. The probability defaults to the overall feature
305 | * probability.
306 | *
307 | * @see de.daslaboratorium.machinelearning.classifier.Classifier#featureProbability(Object,
308 | * Object)
309 | * @see de.daslaboratorium.machinelearning.classifier.Classifier#featureWeighedAverage(Object,
310 | * Object, IFeatureProbability, float, float)
311 | *
312 | * @param feature
313 | * The feature, which probability to calculate.
314 | * @param category
315 | * The category.
316 | * @return The weighed average probability.
317 | */
318 | public float featureWeighedAverage(T feature, K category) {
319 | return this.featureWeighedAverage(feature, category, null, 1.0f, 0.5f);
320 | }
321 |
322 | /**
323 | * Retrieves the weighed average <code>P(feature|category)</code> with
324 | * overall weight of <code>1.0</code>, an assumed probability of
325 | * <code>0.5</code> and the given object to use for probability calculation.
326 | *
327 | * @see de.daslaboratorium.machinelearning.classifier.Classifier#featureWeighedAverage(Object,
328 | * Object, IFeatureProbability, float, float)
329 | *
330 | * @param feature
331 | * The feature, which probability to calculate.
332 | * @param category
333 | * The category.
334 | * @param calculator
335 | * The calculating object.
336 | * @return The weighed average probability.
337 | */
338 | public float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator) {
339 | return this.featureWeighedAverage(feature, category, calculator, 1.0f, 0.5f);
340 | }
341 |
342 | /**
343 | * Retrieves the weighed average <code>P(feature|category)</code> with the
344 | * given weight, an assumed probability of <code>0.5</code> and the given
345 | * object to use for probability calculation.
346 | *
347 | * @see de.daslaboratorium.machinelearning.classifier.Classifier#featureWeighedAverage(Object,
348 | * Object, IFeatureProbability, float, float)
349 | *
350 | * @param feature
351 | * The feature, which probability to calculate.
352 | * @param category
353 | * The category.
354 | * @param calculator
355 | * The calculating object.
356 | * @param weight
357 | * The feature weight.
358 | * @return The weighed average probability.
359 | */
360 | public float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator, float weight) {
361 | return this.featureWeighedAverage(feature, category, calculator, weight, 0.5f);
362 | }
363 |
364 | /**
365 | * Retrieves the weighed average <code>P(feature|category)</code> with the
366 | * given weight, the given assumed probability and the given object to use
367 | * for probability calculation.
368 | *
369 | * @param feature
370 | * The feature, which probability to calculate.
371 | * @param category
372 | * The category.
373 | * @param calculator
374 | * The calculating object.
375 | * @param weight
376 | * The feature weight.
377 | * @param assumedProbability
378 | * The assumed probability.
379 | * @return The weighed average probability.
380 | */
381 | public float featureWeighedAverage(T feature, K category, IFeatureProbability<T, K> calculator, float weight,
382 | float assumedProbability) {
383 |
384 | /*
385 | * use the given calculating object or the default method to calculate
386 | * the probability that the given feature occurred in the given
387 | * category.
388 | */
389 | final float basicProbability = (calculator == null) ? this.featureProbability(feature, category)
390 | : calculator.featureProbability(feature, category);
391 |
392 | Integer totals = this.totalFeatureCount.get(feature);
393 | if (totals == null) totals = 0;
394 | return (weight * assumedProbability + totals * basicProbability) / (weight + totals);
395 | }
396 |
397 | /**
398 | * Train the classifier by telling it that the given features resulted in
399 | * the given category.
400 | *
401 | * @param category
402 | * The category the features belong to.
403 | * @param features
404 | * The features that resulted in the given category.
405 | */
406 | public void learn(K category, Collection<T> features) {
407 | this.learn(new Classification<T, K>(features, category));
408 | }
409 |
410 | /**
411 | * Train the classifier by telling it that the given features resulted in
412 | * the given category.
413 | *
414 | * @param classification
415 | * The classification to learn.
416 | */
417 | public void learn(Classification<T, K> classification) {
418 |
419 | for (T feature : classification.getFeatureset())
420 | this.incrementFeature(feature, classification.getCategory());
421 | this.incrementCategory(classification.getCategory());
422 |
423 | this.memoryQueue.offer(classification);
424 | if (this.memoryQueue.size() > this.memoryCapacity) {
425 | Classification<T, K> toForget = this.memoryQueue.remove();
426 |
427 | for (T feature : toForget.getFeatureset())
428 | this.decrementFeature(feature, toForget.getCategory());
429 | this.decrementCategory(toForget.getCategory());
430 | }
431 | }
432 |
433 | /**
434 | * The classify method. It will retrieve the most likely category for the
435 | * features given and depends on the concrete classifier implementation.
436 | *
437 | * @param features
438 | * The features to classify.
439 | * @return The category most likely.
440 | */
441 | public abstract Classification<T, K> classify(Collection<T> features);
442 |
443 | }
444 |
--------------------------------------------------------------------------------
/src/main/java/de/daslaboratorium/machinelearning/classifier/IFeatureProbability.java:
--------------------------------------------------------------------------------
1 | package de.daslaboratorium.machinelearning.classifier;
2 |
3 | /**
4 | * Simple interface defining the method to calculate the feature probability.
5 | *
6 | * @author Philipp Nolte
7 | *
8 | * @param <T>
9 | *            The feature class.
10 | * @param <K>
11 | *            The category class.
12 | */
13 | public interface IFeatureProbability<T, K> {
14 |
15 | /**
16 | * Returns the probability of a <code>feature</code> being classified as
17 | * <code>category</code> in the learning set.
18 | *
19 | * @param feature
20 | * the feature to return the probability for
21 | * @param category
22 | * the category to check the feature against
23 | * @return the probability p(feature|category)
24 | */
25 | public float featureProbability(T feature, K category);
26 |
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/java/de/daslaboratorium/machinelearning/classifier/bayes/BayesClassifier.java:
--------------------------------------------------------------------------------
1 | package de.daslaboratorium.machinelearning.classifier.bayes;
2 |
3 | import java.util.Collection;
4 | import java.util.Comparator;
5 | import java.util.SortedSet;
6 | import java.util.TreeSet;
7 |
8 | import de.daslaboratorium.machinelearning.classifier.Classification;
9 | import de.daslaboratorium.machinelearning.classifier.Classifier;
10 |
11 | /**
12 | * A concrete implementation of the abstract Classifier class. The Bayes
13 | * classifier implements a naive Bayes approach to classifying a given set of
14 | * features: classify(feat1,...,featN) = argmax(P(cat) * PROD(P(featI|cat)))
15 | *
16 | * @author Philipp Nolte
17 | *
18 | * @see http://en.wikipedia.org/wiki/Naive_Bayes_classifier
19 | *
20 | * @param <T> The feature class.
21 | * @param <K> The category class.
22 | */
23 | public class BayesClassifier<T, K> extends Classifier<T, K> {
24 |
25 | /**
26 | * Calculates the product of all feature probabilities: PROD(P(featI|cat))
27 | *
28 | * @param features The set of features to use.
29 | * @param category The category to test for.
30 | * @return The product of all feature probabilities.
31 | */
32 | private float featuresProbabilityProduct(Collection<T> features,
33 | K category) {
34 | float product = 1.0f;
35 | for (T feature : features)
36 | product *= this.featureWeighedAverage(feature, category);
37 | return product;
38 | }
39 |
40 | /**
41 | * Calculates the probability that the features can be classified as the
42 | * category given.
43 | *
44 | * @param features The set of features to use.
45 | * @param category The category to test for.
46 | * @return The probability that the features can be classified as the
47 | * category.
48 | */
49 | private float categoryProbability(Collection<T> features, K category) {
50 | return ((float) this.getCategoryCount(category)
51 | / (float) this.getCategoriesTotal())
52 | * featuresProbabilityProduct(features, category);
53 | }
54 |
55 | /**
56 | * Retrieves a sorted <code>Set</code> of probabilities that the given set
57 | * of features is classified as the available categories.
58 | *
59 | * @param features The set of features to use.
60 | * @return A sorted <code>Set</code> of category-probability-entries.
61 | */
62 | private SortedSet<Classification<T, K>> categoryProbabilities(
63 | Collection<T> features) {
64 |
65 | /*
66 | * Sort the set according to the probabilities. Because we have to sort
67 | * by the mapped value and not by the mapped key, we can not use a
68 | * sorted tree (TreeMap) and we have to use a set-entry approach to
69 | * achieve the desired functionality. A custom comparator is therefore
70 | * needed.
71 | */
72 | SortedSet<Classification<T, K>> probabilities =
73 | new TreeSet<Classification<T, K>>(
74 | new Comparator<Classification<T, K>>() {
75 |
76 | public int compare(Classification<T, K> o1,
77 | Classification<T, K> o2) {
78 | int toReturn = Float.compare(
79 | o1.getProbability(), o2.getProbability());
80 | if ((toReturn == 0)
81 | && !o1.getCategory().equals(o2.getCategory()))
82 | toReturn = -1;
83 | return toReturn;
84 | }
85 | });
86 |
87 | for (K category : this.getCategories())
88 | probabilities.add(new Classification<T, K>(
89 | features, category,
90 | this.categoryProbability(features, category)));
91 | return probabilities;
92 | }
93 |
94 | /**
95 | * Classifies the given set of features.
96 | *
97 | * @return The category the set of features is classified as.
98 | */
99 | @Override
100 | public Classification<T, K> classify(Collection<T> features) {
101 | SortedSet<Classification<T, K>> probabilities =
102 | this.categoryProbabilities(features);
103 | 
104 | if (probabilities.size() > 0) {
105 | return probabilities.last();
106 | }
107 | return null;
108 | }
109 |
110 | /**
111 | * Classifies the given set of features and returns the full details of the
112 | * classification.
113 | *
114 | * @return The set of categories the set of features is classified as.
115 | */
116 | public Collection<Classification<T, K>> classifyDetailed(
117 | Collection<T> features) {
118 | return this.categoryProbabilities(features);
119 | }
120 |
121 | }
122 |
--------------------------------------------------------------------------------
/src/test/java/de/daslaboratorium/machinelearning/classifier/bayes/BayesClassifierTest.java:
--------------------------------------------------------------------------------
1 | package de.daslaboratorium.machinelearning.classifier.bayes;
2 |
3 | import java.io.ByteArrayOutputStream;
4 | import java.io.IOException;
5 | import java.io.ObjectOutputStream;
6 | import java.util.ArrayList;
7 | import java.util.Arrays;
8 | import java.util.Collection;
9 | import java.util.List;
10 |
11 | import org.junit.Assert;
12 | import org.junit.Before;
13 | import org.junit.Test;
14 |
15 | import de.daslaboratorium.machinelearning.classifier.Classification;
16 | import de.daslaboratorium.machinelearning.classifier.Classifier;
17 |
18 | public class BayesClassifierTest {
19 |
20 | private static final double EPSILON = 0.001;
21 | private static final String CATEGORY_NEGATIVE = "negative";
22 | private static final String CATEGORY_POSITIVE = "positive";
23 | private Classifier<String, String> bayes;
24 |
25 | @Before
26 | public void setUp() {
27 | /*
28 | * Create a new classifier instance. The context features are Strings
29 | * and the context will be classified with a String according to the
30 | * featureset of the context.
31 | */
32 | bayes = new BayesClassifier<String, String>();
33 |
34 | /*
35 | * The classifier can learn from classifications that are handed over to
36 | * the learn methods. Imagine a tokenized text as follows. The tokens are
37 | * the text's features. The category of the text will either be positive
38 | * or negative.
39 | */
40 | final String[] positiveText = "I love sunny days".split("\\s");
41 | bayes.learn(CATEGORY_POSITIVE, Arrays.asList(positiveText));
42 |
43 | final String[] negativeText = "I hate rain".split("\\s");
44 | bayes.learn(CATEGORY_NEGATIVE, Arrays.asList(negativeText));
45 | }
46 |
47 | @Test
48 | public void testStringClassification() {
49 | final String[] unknownText1 = "today is a sunny day".split("\\s");
50 | final String[] unknownText2 = "there will be rain".split("\\s");
51 |
52 | Assert.assertEquals(CATEGORY_POSITIVE, bayes.classify(Arrays.asList(unknownText1)).getCategory());
53 | Assert.assertEquals(CATEGORY_NEGATIVE, bayes.classify(Arrays.asList(unknownText2)).getCategory());
54 | }
55 |
56 | @Test
57 | public void testStringClassificationInDetails() {
58 |
59 | final String[] unknownText1 = "today is a sunny day".split("\\s");
60 |
61 | Collection<Classification<String, String>> classifications = ((BayesClassifier<String, String>) bayes)
62 | .classifyDetailed(Arrays.asList(unknownText1));
63 |
64 | List<Classification<String, String>> list = new ArrayList<Classification<String, String>>(classifications);
65 |
66 | Assert.assertEquals(CATEGORY_NEGATIVE, list.get(0).getCategory());
67 | Assert.assertEquals(0.0078125, list.get(0).getProbability(), EPSILON);
68 |
69 | Assert.assertEquals(CATEGORY_POSITIVE, list.get(1).getCategory());
70 | Assert.assertEquals(0.0234375, list.get(1).getProbability(), EPSILON);
71 | }
72 |
73 | @Test
74 | public void testSerialization() throws IOException {
75 |
76 | new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(bayes);
77 | }
78 | }
--------------------------------------------------------------------------------