├── .gitignore
├── LICENSE
├── README
├── build.xml
└── src
├── DecisionTreeTest.java
└── dt
├── Algorithm.java
├── Attribute.java
├── BadDecisionException.java
├── DecisionTree.java
├── Decisions.java
├── Examples.java
├── ID3Algorithm.java
└── UnknownDecisionException.java
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.jar
3 | *~
4 | build/
5 | tags
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2011, John Weaver
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5 |
6 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 | Neither the name of the John Weaver nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9 |
10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
11 |
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
1 |
2 | Copyright(c) 2011 John Weaver.
3 | All Rights Reserved.
4 |
5 | Source code licensed under the BSD license. Please see the LICENSE file for details.
6 |
7 |
8 | decision-tree
9 | =============
10 |
11 | This package consists of several Java classes to construct and apply a decision
12 | tree to an attribute set. It allows for pluggable algorithms to construct the
13 | tree. The decision tree is constructed from a series of examples of attributes,
14 | where each example either has each of the attributes or does not, and each has
15 | a specified outcome of either true or false. The resulting decision tree is a
16 | binary tree where each leaf node represents the presence or absence of each
17 | attribute named along the path to the root node and the resulting outcome for
18 | the set of decisions.
19 |
20 | Dependencies
21 | ------------
22 |
23 | * Uses SLF4J for logging.
24 |
--------------------------------------------------------------------------------
/build.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/src/DecisionTreeTest.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | import org.junit.*;
6 | import static org.junit.Assert.*;
7 | import static org.hamcrest.CoreMatchers.*;
8 |
9 | import java.util.*;
10 |
11 | import dt.*;
12 |
13 | public class DecisionTreeTest {
14 | private DecisionTree makeOne() {
15 | return new DecisionTree();
16 | }
17 |
18 | private DecisionTree makeOutlookTree() {
19 | try {
20 | // example data from http://www.cise.ufl.edu/~ddd/cap6635/Fall-97/Short-papers/2.htm
21 | return makeOne().setAttributes(new String[]{"Outlook", "Temperature", "Humidity", "Wind"})
22 | .addExample( new String[]{"Sunny", "Hot", "High", "Weak"}, false)
23 | .addExample( new String[]{"Sunny", "Hot", "High", "Strong"}, false)
24 | .addExample( new String[]{"Overcast", "Hot", "High", "Weak"}, true)
25 | .addExample( new String[]{"Rain", "Mild", "High", "Weak"}, true)
26 | .addExample( new String[]{"Rain", "Cool", "Normal", "Weak"}, true)
27 | .addExample( new String[]{"Rain", "Cool", "Normal", "Strong"}, false)
28 | .addExample( new String[]{"Overcast", "Cool", "Normal", "Strong"}, true)
29 | .addExample( new String[]{"Sunny", "Mild", "High", "Weak"}, false)
30 | .addExample( new String[]{"Sunny", "Cool", "Normal", "Weak"}, true)
31 | .addExample( new String[]{"Rain", "Mild", "Normal", "Weak"}, true)
32 | .addExample( new String[]{"Sunny", "Mild", "Normal", "Strong"}, true)
33 | .addExample( new String[]{"Overcast", "Mild", "High", "Strong"}, true)
34 | .addExample( new String[]{"Overcast", "Hot", "Normal", "Weak"}, true)
35 | .addExample( new String[]{"Rain", "Mild", "High", "Strong"}, false);
36 | } catch ( UnknownDecisionException e ) {
37 | fail();
38 | return makeOne(); // this is here to shut up compiler.
39 | }
40 | }
41 |
42 | @Test (expected=UnknownDecisionException.class) public void testUnknownDecisionThrowsException() throws UnknownDecisionException {
43 | DecisionTree tree = makeOne().setAttributes(new String[]{"Outlook"})
44 | .setDecisions("Outlook", new String[]{"Sunny", "Overcast"});
45 |
46 | // this causes exception
47 | tree.addExample(new String[]{"Rain"}, false);
48 | }
49 |
50 | @Test public void testOutlookOvercastApplyReturnsTrue() throws BadDecisionException {
51 | Map case1 = new HashMap();
52 | case1.put("Outlook", "Overcast");
53 | case1.put("Temperature", "Hot");
54 | case1.put("Humidity", "High");
55 | case1.put("Wind", "Strong");
56 | assertTrue(makeOutlookTree().apply(case1));
57 | }
58 |
59 | @Test (expected=BadDecisionException.class) public void testOutlookRainInsufficientDataThrowsException() throws BadDecisionException {
60 | Map case1 = new HashMap();
61 | case1.put("Outlook", "Rain");
62 | case1.put("Temperature", "Mild");
63 | makeOutlookTree().apply(case1);
64 | }
65 |
66 | public void attributeIsUsedOnlyOnceInTree(Attribute node, List attributes) {
67 | for ( Attribute child : node.getDecisions().values() ) {
68 | if ( !child.isLeaf() ) {
69 | assertFalse( attributes.contains(child.getName()) );
70 | attributes.add(child.getName());
71 | attributeIsUsedOnlyOnceInTree(child, attributes);
72 | }
73 | }
74 | }
75 |
76 | @Test public void testAttributeIsUsedOnlyOnceInTree() {
77 | DecisionTree tree = makeOutlookTree();
78 | tree.compile();
79 |
80 | List attributeList = new LinkedList();
81 | attributeList.add(tree.getRoot().getName());
82 | attributeIsUsedOnlyOnceInTree(tree.getRoot(), attributeList);
83 | }
84 |
85 |
86 | public static void main(String args[]) {
87 | org.junit.runner.JUnitCore.main("DecisionTreeTest");
88 | }
89 | }
90 |
--------------------------------------------------------------------------------
/src/dt/Algorithm.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | package dt;
6 |
7 | import java.util.*;
8 |
9 |
10 | public interface Algorithm {
11 | /**
12 | * Find the next attribute.
13 | *
14 | * For the initial attribute, pass an empty
15 | * chosenAttributes and use the returned attribute as the rootAttribute.
16 | * Then, walk the decision tree pre-order. At each decision, call this method
17 | * with the attribute/decision pairs that led to that node in
18 | * chosenAttributes. Attach the returned Attribute to the decision.
19 | *
20 | */
21 | abstract public Attribute nextAttribute(Map chosenAttributes, Set usedAttributes);
22 | }
23 |
--------------------------------------------------------------------------------
/src/dt/Attribute.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | package dt;
6 |
7 | import java.util.*;
8 |
9 |
10 | public class Attribute {
11 | /**
12 | * Indicates if this attribute yields a classification (true) or has child
13 | * decisions that point to further attributes (false).
14 | */
15 | private boolean leaf;
16 |
17 | private String attributeName;
18 | private Decisions decisions;
19 | private boolean classification;
20 |
21 | public Attribute(boolean classification) {
22 | leaf = true;
23 | this.classification = classification;
24 | decisions = new Decisions();
25 | attributeName = null;
26 | }
27 |
28 | public Attribute(String name) {
29 | leaf = false;
30 | attributeName = name;
31 | decisions = new Decisions();
32 | }
33 |
34 | public String getName() {
35 | return attributeName;
36 | }
37 |
38 | public boolean isLeaf() {
39 | return leaf;
40 | }
41 |
42 | public void setClassification(boolean classification) {
43 | assert ( leaf );
44 |
45 | this.classification = classification;
46 | }
47 |
48 | /**
49 | * Returns the classification of the followed decision.
50 | *
51 | * Undefined if isLeaf() returns false.
52 | */
53 | public boolean getClassification() {
54 | assert ( leaf );
55 |
56 | return classification;
57 | }
58 |
59 | public boolean apply(Map data) throws BadDecisionException {
60 | if ( isLeaf() )
61 | return getClassification();
62 |
63 | Attribute nextAttribute = decisions.apply(data.get(attributeName));
64 | return nextAttribute.apply(data);
65 | }
66 |
67 | public void addDecision(String decision, Attribute attribute) {
68 | assert ( !leaf );
69 |
70 | decisions.put(decision, attribute);
71 | }
72 |
73 | public String toString() {
74 | StringBuffer b = new StringBuffer();
75 |
76 | for ( Map.Entry e : decisions.getMap().entrySet() ) {
77 | b.append(getName());
78 | b.append(" -> ");
79 | if ( e.getValue().isLeaf() )
80 | b.append(e.getValue().getClassification());
81 | else
82 | b.append(e.getValue().getName());
83 | b.append(" [label=\"");
84 | b.append(e.getKey());
85 | b.append("\"]\n");
86 |
87 | b.append(e.getValue().toString());
88 | }
89 |
90 | return b.toString();
91 | }
92 |
93 | public Map getDecisions() {
94 | return decisions.getMap();
95 | }
96 | }
97 |
--------------------------------------------------------------------------------
/src/dt/BadDecisionException.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | package dt;
6 |
/**
 * Thrown when data being classified offers no usable value for a decision.
 */
public class BadDecisionException extends Exception {
  private static final long serialVersionUID = 1L;

  /** Creates the exception without a detail message. */
  public BadDecisionException() {
    super();
  }

  /**
   * Creates the exception with a detail message.
   *
   * @param message description of the value that could not be matched
   */
  public BadDecisionException(String message) {
    super(message);
  }
}
9 |
10 |
--------------------------------------------------------------------------------
/src/dt/DecisionTree.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | package dt;
6 |
7 | import java.util.*;
8 |
9 | public class DecisionTree {
10 | /**
11 | * Contains the set of available attributes.
12 | */
13 | private LinkedHashSet attributes;
14 |
15 | /**
16 | * Maps a attribute name to a set of possible decisions for that attribute.
17 | */
18 | private Map > decisions;
19 | private boolean decisionsSpecified;
20 |
21 | /**
22 | * Contains the examples to be processed into a decision tree.
23 | *
24 | * The 'attributes' and 'decisions' member variables should be updated
25 | * prior to adding examples that refer to new attributes or decisions.
26 | */
27 | private Examples examples;
28 |
29 | /**
30 | * Indicates if the provided data has been processed into a decision tree.
31 | *
32 | * This value is initially false, and is reset any time additional data is
33 | * provided.
34 | */
35 | private boolean compiled;
36 |
37 | /**
38 | * Contains the top-most attribute of the decision tree.
39 | *
40 | * For a tree where the decision requires no attributes,
41 | * the rootAttribute yields a boolean classification.
42 | *
43 | */
44 | private Attribute rootAttribute;
45 |
46 | private Algorithm algorithm;
47 |
48 | public DecisionTree() {
49 | algorithm = null;
50 | examples = new Examples();
51 | attributes = new LinkedHashSet();
52 | decisions = new HashMap >();
53 | decisionsSpecified = false;
54 | }
55 |
56 | private void setDefaultAlgorithm() {
57 | if ( algorithm == null )
58 | setAlgorithm(new ID3Algorithm(examples));
59 | }
60 |
61 | public void setAlgorithm(Algorithm algorithm) {
62 | this.algorithm = algorithm;
63 | }
64 |
65 | /**
66 | * Saves the array of attribute names in an insertion ordered set.
67 | *
68 | * The ordering of attribute names is used when addExamples is called to
69 | * determine which values correspond with which names.
70 | *
71 | */
72 | public DecisionTree setAttributes(String[] attributeNames) {
73 | compiled = false;
74 |
75 | decisions.clear();
76 | decisionsSpecified = false;
77 |
78 | attributes.clear();
79 |
80 | for ( int i = 0 ; i < attributeNames.length ; i++ )
81 | attributes.add(attributeNames[i]);
82 |
83 | return this;
84 | }
85 |
86 | /**
87 | */
88 | public DecisionTree setDecisions(String attributeName, String[] decisions) {
89 | if ( !attributes.contains(attributeName) ) {
90 | // TODO some kind of warning or something
91 | return this;
92 | }
93 |
94 | compiled = false;
95 | decisionsSpecified = true;
96 |
97 | Set decisionsSet = new HashSet();
98 | for ( int i = 0 ; i < decisions.length ; i++ )
99 | decisionsSet.add(decisions[i]);
100 |
101 | this.decisions.put(attributeName, decisionsSet);
102 |
103 | return this;
104 | }
105 |
106 | /**
107 | */
108 | public DecisionTree addExample(String[] attributeValues, boolean classification) throws UnknownDecisionException {
109 | String[] attributes = this.attributes.toArray(new String[0]);
110 |
111 | if ( decisionsSpecified )
112 | for ( int i = 0 ; i < attributeValues.length ; i++ )
113 | if ( !decisions.get(attributes[i]).contains(attributeValues[i]) ) {
114 | throw new UnknownDecisionException(attributes[i], attributeValues[i]);
115 | }
116 |
117 | compiled = false;
118 |
119 | examples.add(attributes, attributeValues, classification);
120 |
121 | return this;
122 | }
123 |
124 | public DecisionTree addExample(Map attributes, boolean classification) throws UnknownDecisionException {
125 | compiled = false;
126 |
127 | examples.add(attributes, classification);
128 |
129 | return this;
130 | }
131 |
132 | public boolean apply(Map data) throws BadDecisionException {
133 | compile();
134 |
135 | return rootAttribute.apply(data);
136 | }
137 |
138 | private Attribute compileWalk(Attribute current, Map chosenAttributes, Set usedAttributes) {
139 | // if the current attribute is a leaf, then there are no decisions and thus no
140 | // further attributes to find.
141 | if ( current.isLeaf() )
142 | return current;
143 |
144 | // get decisions for the current attribute (from this.decisions)
145 | String attributeName = current.getName();
146 |
147 | // remove this attribute from all further consideration
148 | usedAttributes.add(attributeName);
149 |
150 | for ( String decisionName : decisions.get(attributeName) ) {
151 | // overwrite the attribute decision for each value considered
152 | chosenAttributes.put(attributeName, decisionName);
153 |
154 | // find the next attribute to choose for the considered decision
155 | // build the subtree from this new attribute, pre-order
156 | // insert the newly-built subtree into the open decision slot
157 | current.addDecision(decisionName, compileWalk(algorithm.nextAttribute(chosenAttributes, usedAttributes), chosenAttributes, usedAttributes));
158 | }
159 |
160 | // remove the attribute decision before we walk back up the tree.
161 | chosenAttributes.remove(attributeName);
162 |
163 | // return the subtree so that it can be inserted into the parent tree.
164 | return current;
165 | }
166 |
167 | public void compile() {
168 | // skip compilation if already done.
169 | if ( compiled )
170 | return;
171 |
172 | // if no algorithm is set beforehand, select the default one.
173 | setDefaultAlgorithm();
174 |
175 | Map chosenAttributes = new HashMap();
176 | Set usedAttributes = new HashSet();
177 |
178 | if ( !decisionsSpecified )
179 | decisions = examples.extractDecisions();
180 |
181 | // find the root attribute (either leaf or non)
182 | // walk the tree, adding attributes as needed under each decision
183 | // save the original attribute as the root attribute.
184 | rootAttribute = compileWalk(algorithm.nextAttribute(chosenAttributes, usedAttributes), chosenAttributes, usedAttributes);
185 |
186 | compiled = true;
187 | }
188 |
189 | public String toString() {
190 | compile();
191 |
192 | if ( rootAttribute != null )
193 | return rootAttribute.toString();
194 | else
195 | return "";
196 | }
197 |
198 | public Attribute getRoot() {
199 | return rootAttribute;
200 | }
201 | }
202 |
--------------------------------------------------------------------------------
/src/dt/Decisions.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | package dt;
6 |
7 | import java.util.*;
8 |
9 |
10 | class Decisions {
11 | private Map decisions;
12 |
13 | public Decisions() {
14 | decisions = new HashMap();
15 | }
16 |
17 | public Map getMap() {
18 | return decisions;
19 | }
20 |
21 | public void put(String decision, Attribute attribute) {
22 | decisions.put(decision, attribute);
23 | }
24 |
25 | public void clear() {
26 | decisions.clear();
27 | }
28 |
29 | /**
30 | * Returns the attribute based on the decision matching the provided value.
31 | *
32 | * Throws BadDecisionException if no decision matches.
33 | */
34 | public Attribute apply(String value) throws BadDecisionException {
35 | Attribute result = decisions.get(value);
36 |
37 | if ( result == null )
38 | throw new BadDecisionException();
39 |
40 | return result;
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
/src/dt/Examples.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | package dt;
6 |
7 | import java.util.*;
8 |
9 |
/**
 * The training set: a list of examples, each mapping attribute names to
 * values and carrying a boolean classification, plus the counting queries
 * needed to build a decision tree from them.
 *
 * An example does not have to define every attribute. An example that lacks
 * an attribute never matches a query on that attribute.
 */
class Examples {
  /** One labelled example. Static because it needs nothing from the enclosing list. */
  static class Example {
    private final Map<String, String> values;
    private final boolean classifier;

    /**
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public Example(String[] attributeNames, String[] attributeValues,
                   boolean classifier) {
      if (attributeNames.length != attributeValues.length) {
        throw new IllegalArgumentException("Got " + attributeNames.length
            + " attribute names but " + attributeValues.length + " values");
      }

      values = new HashMap<String, String>();
      for (int i = 0; i < attributeNames.length; i++) {
        values.put(attributeNames[i], attributeValues[i]);
      }

      this.classifier = classifier;
    }

    public Example(Map<String, String> attributes, boolean classifier) {
      this.classifier = classifier;
      // Copy so that later changes to the caller's map cannot alter the example.
      this.values = new HashMap<String, String>(attributes);
    }

    public Set<String> getAttributes() {
      return values.keySet();
    }

    /** Returns the value of the attribute, or null if this example lacks it. */
    public String getAttributeValue(String attribute) {
      return values.get(attribute);
    }

    public boolean matchesClass(boolean classifier) {
      return classifier == this.classifier;
    }

    /** Reports whether this example has the given value for every listed attribute. */
    boolean matchesAll(Map<String, String> attributes) {
      for (Map.Entry<String, String> wanted : attributes.entrySet()) {
        if (!sameValue(values.get(wanted.getKey()), wanted.getValue())) {
          return false;
        }
      }
      return true;
    }
  }

  private final List<Example> examples;

  public Examples() {
    examples = new ArrayList<Example>();
  }

  /** Null-safe string equality; a missing (null) value only equals null. */
  private static boolean sameValue(String actual, String expected) {
    return actual == null ? expected == null : actual.equals(expected);
  }

  public void add(String[] attributeNames, String[] attributeValues,
                  boolean classifier) {
    examples.add(new Example(attributeNames, attributeValues, classifier));
  }

  public void add(Map<String, String> attributes, boolean classifier) {
    examples.add(new Example(attributes, classifier));
  }

  /**
   * Returns the number of examples where the attribute has the specified
   * 'decision' value
   */
  int countDecisions(String attribute, String decision) {
    int count = 0;

    for (Example e : examples) {
      if (sameValue(e.getAttributeValue(attribute), decision)) {
        count++;
      }
    }

    return count;
  }

  /**
   * Returns a map from each attribute name to a set of all values used in the
   * examples for that attribute.
   */
  public Map<String, Set<String>> extractDecisions() {
    Map<String, Set<String>> decisions = new HashMap<String, Set<String>>();

    for (String attribute : extractAttributes()) {
      decisions.put(attribute, extractDecisions(attribute));
    }

    return decisions;
  }

  public int countNegative(String attribute, String decision,
                           Map<String, String> attributes) {
    return countClassifier(false, attribute, decision, attributes);
  }

  public int countPositive(String attribute, String decision,
                           Map<String, String> attributes) {
    return countClassifier(true, attribute, decision, attributes);
  }

  public int countNegative(Map<String, String> attributes) {
    return countClassifier(false, attributes);
  }

  public int countPositive(Map<String, String> attributes) {
    return countClassifier(true, attributes);
  }

  /**
   * Returns the number of examples matching the given attribute values and
   * additionally having 'decision' as the value of 'attribute'. The passed
   * map is not modified.
   */
  public int count(String attribute, String decision, Map<String, String> attributes) {
    Map<String, String> narrowed = new HashMap<String, String>(attributes);
    narrowed.put(attribute, decision);

    return count(narrowed);
  }

  /**
   * Returns the number of examples that have the given value for every
   * listed attribute. An empty map matches every example.
   */
  public int count(Map<String, String> attributes) {
    int count = 0;

    for (Example e : examples) {
      if (e.matchesAll(attributes)) {
        count++;
      }
    }

    return count;
  }

  /**
   * Returns the number of examples that match every listed attribute value
   * and carry the given classification.
   */
  public int countClassifier(boolean classifier, Map<String, String> attributes) {
    int count = 0;

    for (Example e : examples) {
      if (e.matchesAll(attributes) && e.matchesClass(classifier)) {
        count++;
      }
    }

    return count;
  }

  /**
   * Like countClassifier(boolean, Map), with 'attribute' additionally
   * required to have the value 'decision'. The passed map is not modified.
   */
  public int countClassifier(boolean classifier, String attribute,
                             String decision, Map<String, String> attributes) {
    Map<String, String> narrowed = new HashMap<String, String>(attributes);
    narrowed.put(attribute, decision);

    return countClassifier(classifier, narrowed);
  }

  /**
   * Returns the number of examples.
   */
  public int count() {
    return examples.size();
  }

  /**
   * Returns a set of attribute names used in the examples.
   *
   * The set is freshly built on every call, so callers may modify it.
   */
  public Set<String> extractAttributes() {
    Set<String> attributes = new HashSet<String>();

    for (Example e : examples) {
      attributes.addAll(e.getAttributes());
    }

    return attributes;
  }

  /** Collects the distinct values the attribute takes across the examples. */
  private Set<String> extractDecisions(String attribute) {
    Set<String> decisions = new HashSet<String>();

    for (Example e : examples) {
      String value = e.getAttributeValue(attribute);

      // An example that lacks the attribute contributes no decision value.
      if (value != null) {
        decisions.add(value);
      }
    }

    return decisions;
  }
}
186 |
--------------------------------------------------------------------------------
/src/dt/ID3Algorithm.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | package dt;
6 |
7 | import java.util.*;
8 |
9 | import org.slf4j.Logger;
10 | import org.slf4j.LoggerFactory;
11 |
12 |
13 | public class ID3Algorithm implements Algorithm {
14 | final Logger logger = LoggerFactory.getLogger(ID3Algorithm.class);
15 | private Examples examples;
16 |
17 | public ID3Algorithm(Examples examples) {
18 | this.examples = examples;
19 | }
20 |
21 | /**
22 | * Returns the next attribute to be chosen.
23 | *
24 | * chosenAttributes represents the decision path from the root attribute
25 | * to the node under consideration. usedAttributes is the set of all
26 | * attributes that have been incorporated into the tree prior to this
27 | * call to nextAttribute(), even if the attributes were not used in the path
28 | * to the node under consideration.
29 | *
30 | * Results are undefined if examples.count() == 0.
31 | */
32 | public Attribute nextAttribute(Map chosenAttributes, Set usedAttributes) {
33 | double currentGain = 0.0, bestGain = 0.0;
34 | String bestAttribute = "";
35 |
36 | /*
37 | * If there are no positive examples for the already chosen attributes,
38 | * then return a false classifier leaf. If no negative examples,
39 | * then return a true classifier leaf.
40 | */
41 | if ( examples.countPositive(chosenAttributes) == 0 )
42 | return new Attribute(false);
43 | else if ( examples.countNegative(chosenAttributes) == 0 )
44 | return new Attribute(true);
45 |
46 | logger.debug("Choosing attribute out of {} remaining attributes.",
47 | remainingAttributes(usedAttributes).size());
48 | logger.debug("Already chosen attributes/decisions are {}.", chosenAttributes);
49 |
50 | for ( String attribute : remainingAttributes(usedAttributes) ) {
51 | // for each remaining attribute, determine the information gain of using it
52 | // to choose among the examples selected by the chosenAttributes
53 | // if none give any information gain, return a leaf attribute,
54 | // otherwise return the found attribute as a non-leaf attribute
55 | currentGain = informationGain(attribute, chosenAttributes);
56 | logger.debug("Evaluating attribute {}, information gain is {}",
57 | attribute, currentGain);
58 | if ( currentGain > bestGain ) {
59 | bestAttribute = attribute;
60 | bestGain = currentGain;
61 | }
62 | }
63 |
64 | // If no attribute gives information gain, generate leaf attribute.
65 | // Leaf is true if there are any true classifiers.
66 | // If there is at least one negative example, then the information gain
67 | // would be greater than 0.
68 | if ( bestGain == 0.0 ) {
69 | boolean classifier = examples.countPositive(chosenAttributes) > 0;
70 | logger.debug("Creating new leaf attribute with classifier {}.", classifier);
71 | return new Attribute(classifier);
72 | } else {
73 | logger.debug("Creating new non-leaf attribute {}.", bestAttribute);
74 | return new Attribute(bestAttribute);
75 | }
76 | }
77 |
78 | private Set remainingAttributes(Set usedAttributes) {
79 | Set result = examples.extractAttributes();
80 | result.removeAll(usedAttributes);
81 | return result;
82 | }
83 |
84 | private double entropy(Map specifiedAttributes) {
85 | double totalExamples = examples.count();
86 | double positiveExamples = examples.countPositive(specifiedAttributes);
87 | double negativeExamples = examples.countNegative(specifiedAttributes);
88 |
89 | return -nlog2(positiveExamples / totalExamples) -
90 | nlog2(negativeExamples / totalExamples);
91 | }
92 |
93 | private double entropy(String attribute, String decision, Map specifiedAttributes) {
94 | double totalExamples = examples.count(attribute, decision, specifiedAttributes);
95 | double positiveExamples = examples.countPositive(attribute, decision, specifiedAttributes);
96 | double negativeExamples = examples.countNegative(attribute, decision, specifiedAttributes);
97 |
98 | return -nlog2(positiveExamples / totalExamples) -
99 | nlog2(negativeExamples / totalExamples);
100 | }
101 |
102 | private double informationGain(String attribute, Map specifiedAttributes) {
103 | double sum = entropy(specifiedAttributes);
104 | double examplesCount = examples.count(specifiedAttributes);
105 |
106 | if ( examplesCount == 0 )
107 | return sum;
108 |
109 | Map > decisions = examples.extractDecisions();
110 |
111 | for ( String decision : decisions.get(attribute) ) {
112 | double entropyPart = entropy(attribute, decision, specifiedAttributes);
113 | double decisionCount = examples.countDecisions(attribute, decision);
114 |
115 | sum += -(decisionCount / examplesCount) * entropyPart;
116 | }
117 |
118 | return sum;
119 | }
120 |
121 | private double nlog2(double value) {
122 | if ( value == 0 )
123 | return 0;
124 |
125 | return value * Math.log(value) / Math.log(2);
126 | }
127 | }
128 |
--------------------------------------------------------------------------------
/src/dt/UnknownDecisionException.java:
--------------------------------------------------------------------------------
1 | /**
2 | *
3 | */
4 |
5 | package dt;
6 |
7 |
/**
 * Thrown when an example uses a value that was not declared as a possible
 * decision for its attribute.
 */
public class UnknownDecisionException extends Exception {
  private static final long serialVersionUID = 1L;

  private final String attribute;
  private final String decision;

  /**
   * @param attribute name of the attribute whose value was rejected
   * @param decision the rejected value
   */
  public UnknownDecisionException(String attribute, String decision) {
    super("Unknown decision '" + decision + "' for attribute '" + attribute + "'");
    this.attribute = attribute;
    this.decision = decision;
  }

  /** Returns the name of the attribute whose value was rejected. */
  public String getAttribute() {
    return attribute;
  }

  /** Returns the rejected value. */
  public String getDecision() {
    return decision;
  }
}
13 |
14 |
--------------------------------------------------------------------------------