├── .gitignore ├── LICENSE ├── README ├── build.xml └── src ├── DecisionTreeTest.java └── dt ├── Algorithm.java ├── Attribute.java ├── BadDecisionException.java ├── DecisionTree.java ├── Decisions.java ├── Examples.java ├── ID3Algorithm.java └── UnknownDecisionException.java /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.jar 3 | *~ 4 | build/ 5 | tags 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2011, John Weaver 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 | 6 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | Neither the name of the John Weaver nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 9 | 10 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 11 | -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | 2 | Copyright(c) 2011 John Weaver. 3 | All Rights Reserved. 4 | 5 | Source code licensed under the BSD license. Please see the LICENSE file for details. 6 | 7 | 8 | decision-tree 9 | ============= 10 | 11 | This package consists of several Java classes to construct and apply a decision 12 | tree to an attribute set. It allows for pluggable algorithms to construct the 13 | tree. The decision tree is constructed from a series of examples of attributes, 14 | where each example either has each of the attributes or does not, and each has 15 | a specified outcome of either true or false. The resulting decision tree is a 16 | binary tree where each leaf node represents the presence or absence of each 17 | attribute named along the path to the root node and the resulting outcome for 18 | the set of decisions. 19 | 20 | Dependencies 21 | ------------ 22 | 23 | * Uses SLF4J for logging.
24 |
--------------------------------------------------------------------------------
/build.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 |
--------------------------------------------------------------------------------
/src/DecisionTreeTest.java:
--------------------------------------------------------------------------------
/**
 * JUnit 4 tests for the dt decision-tree package.
 */

import org.junit.*;
import static org.junit.Assert.*;
import static org.hamcrest.CoreMatchers.*;

import java.util.*;

import dt.*;

public class DecisionTreeTest {
  /** Creates an empty tree with no attributes, decisions or examples. */
  private DecisionTree makeOne() {
    return new DecisionTree();
  }

  /**
   * Builds a tree over the four weather attributes and the fourteen
   * examples listed below.
   */
  private DecisionTree makeOutlookTree() {
    try {
      // example data from http://www.cise.ufl.edu/~ddd/cap6635/Fall-97/Short-papers/2.htm
      return makeOne()
          .setAttributes(new String[]{"Outlook", "Temperature", "Humidity", "Wind"})
          .addExample(new String[]{"Sunny", "Hot", "High", "Weak"}, false)
          .addExample(new String[]{"Sunny", "Hot", "High", "Strong"}, false)
          .addExample(new String[]{"Overcast", "Hot", "High", "Weak"}, true)
          .addExample(new String[]{"Rain", "Mild", "High", "Weak"}, true)
          .addExample(new String[]{"Rain", "Cool", "Normal", "Weak"}, true)
          .addExample(new String[]{"Rain", "Cool", "Normal", "Strong"}, false)
          .addExample(new String[]{"Overcast", "Cool", "Normal", "Strong"}, true)
          .addExample(new String[]{"Sunny", "Mild", "High", "Weak"}, false)
          .addExample(new String[]{"Sunny", "Cool", "Normal", "Weak"}, true)
          .addExample(new String[]{"Rain", "Mild", "Normal", "Weak"}, true)
          .addExample(new String[]{"Sunny", "Mild", "Normal", "Strong"}, true)
          .addExample(new String[]{"Overcast", "Mild", "High", "Strong"}, true)
          .addExample(new String[]{"Overcast", "Hot", "Normal", "Weak"}, true)
          .addExample(new String[]{"Rain", "Mild", "High", "Strong"}, false);
    } catch (UnknownDecisionException e) {
      // No decisions are declared for this tree, so this is a test-setup
      // failure. AssertionError(Object) records a Throwable as the cause.
      throw new AssertionError(e);
    }
  }

  @Test(expected = UnknownDecisionException.class)
  public void testUnknownDecisionThrowsException() throws UnknownDecisionException {
    DecisionTree tree = makeOne()
        .setAttributes(new String[]{"Outlook"})
        .setDecisions("Outlook", new String[]{"Sunny", "Overcast"});

    // "Rain" is not among the declared decisions; this causes the exception.
    tree.addExample(new String[]{"Rain"}, false);
  }

  @Test
  public void testOutlookOvercastApplyReturnsTrue() throws BadDecisionException {
    Map<String, String> case1 = new HashMap<String, String>();
    case1.put("Outlook", "Overcast");
    case1.put("Temperature", "Hot");
    case1.put("Humidity", "High");
    case1.put("Wind", "Strong");
    assertTrue(makeOutlookTree().apply(case1));
  }

  @Test(expected = BadDecisionException.class)
  public void testOutlookRainInsufficientDataThrowsException() throws BadDecisionException {
    Map<String, String> case1 = new HashMap<String, String>();
    case1.put("Outlook", "Rain");
    case1.put("Temperature", "Mild");
    makeOutlookTree().apply(case1);
  }

  /**
   * Recursively asserts that no non-leaf attribute below {@code node} has a
   * name already present in {@code attributes}, recording each name as it is
   * visited.
   */
  public void attributeIsUsedOnlyOnceInTree(Attribute node, List<String> attributes) {
    for (Attribute child : node.getDecisions().values()) {
      if (!child.isLeaf()) {
        assertFalse(attributes.contains(child.getName()));
        attributes.add(child.getName());
        attributeIsUsedOnlyOnceInTree(child, attributes);
      }
    }
  }

  @Test
  public void testAttributeIsUsedOnlyOnceInTree() {
    DecisionTree tree = makeOutlookTree();
    tree.compile();

    List<String> attributeList = new LinkedList<String>();
    attributeList.add(tree.getRoot().getName());
    attributeIsUsedOnlyOnceInTree(tree.getRoot(), attributeList);
  }

  public static void main(String args[]) {
    org.junit.runner.JUnitCore.main("DecisionTreeTest");
  }
}

--------------------------------------------------------------------------------
/src/dt/Algorithm.java:
--------------------------------------------------------------------------------
/**
 *
 */

package dt;

import java.util.*;


/** Strategy used by {@link DecisionTree} to pick the attribute for each node. */
public interface Algorithm {
  /**
   * Find the next attribute.
   *
   * For the initial attribute, pass an empty chosenAttributes and use the
   * returned attribute as the rootAttribute. Then, walk the decision tree
   * pre-order. At each decision, call this method with the attribute/decision
   * pairs that led to that node in chosenAttributes. Attach the returned
   * Attribute to the decision.
   *
   * @param chosenAttributes attribute name to decision value for the path
   *                         leading to the node being filled
   * @param usedAttributes   names of attributes that must not be returned
   * @return a leaf attribute carrying a classification, or a non-leaf
   *         attribute naming the attribute to split on
   */
  Attribute nextAttribute(Map<String, String> chosenAttributes, Set<String> usedAttributes);
}

--------------------------------------------------------------------------------
/src/dt/Attribute.java:
--------------------------------------------------------------------------------
/**
 *
 */

package dt;

import java.util.*;


/**
 * A node of the decision tree: either a leaf holding a boolean
 * classification, or a named attribute whose decisions lead to child nodes.
 */
public class Attribute {
  /**
   * Indicates if this attribute yields a classification (true) or has child
   * decisions that point to further attributes (false).
   */
  private final boolean leaf;

  /** Attribute name; null for a leaf. */
  private final String attributeName;
  private final Decisions decisions;
  private boolean classification;

  /** Creates a leaf that yields the given classification. */
  public Attribute(boolean classification) {
    leaf = true;
    this.classification = classification;
    decisions = new Decisions();
    attributeName = null;
  }

  /** Creates a non-leaf node for the named attribute, with no decisions yet. */
  public Attribute(String name) {
    leaf = false;
    attributeName = name;
    decisions = new Decisions();
  }

  public String getName() {
    return attributeName;
  }

  public boolean isLeaf() {
    return leaf;
  }

  public void setClassification(boolean classification) {
    assert (leaf);

    this.classification = classification;
  }

  /**
   * Returns the classification of the followed decision.
   *
   * Undefined if isLeaf() returns false.
   */
  public boolean getClassification() {
    assert (leaf);

    return classification;
  }

  /**
   * Classifies {@code data} by following, at each non-leaf node, the decision
   * matching the value stored in {@code data} under that node's attribute
   * name, until a leaf is reached.
   *
   * @param data attribute name to value
   * @return the classification of the leaf that was reached
   * @throws BadDecisionException if a node has no decision for the value
   *         found in {@code data}
   */
  public boolean apply(Map<String, String> data) throws BadDecisionException {
    if (isLeaf()) {
      return getClassification();
    }

    Attribute nextAttribute = decisions.apply(data.get(attributeName));
    return nextAttribute.apply(data);
  }

  /** Attaches {@code attribute} as the child reached by {@code decision}. */
  public void addDecision(String decision, Attribute attribute) {
    assert (!leaf);

    decisions.put(decision, attribute);
  }

  /**
   * Renders the subtree rooted here as one
   * {@code parent -> child [label="decision"]} line per decision, each
   * followed by the rendering of that child. A leaf renders as the empty
   * string; a leaf child is written as its classification.
   */
  @Override
  public String toString() {
    StringBuilder b = new StringBuilder();

    for (Map.Entry<String, Attribute> e : decisions.getMap().entrySet()) {
      Attribute child = e.getValue();

      b.append(getName());
      b.append(" -> ");
      if (child.isLeaf()) {
        b.append(child.getClassification());
      } else {
        b.append(child.getName());
      }
      b.append(" [label=\"");
      b.append(e.getKey());
      b.append("\"]\n");

      b.append(child.toString());
    }

    return b.toString();
  }

  /** Returns the map from decision value to child node held by this node. */
  public Map<String, Attribute> getDecisions() {
    return decisions.getMap();
  }
}
--------------------------------------------------------------------------------
/src/dt/BadDecisionException.java:
--------------------------------------------------------------------------------
/**
 *
 */

package dt;

/** Thrown when a tree node has no decision matching the supplied value. */
public class BadDecisionException extends Exception {
  private static final long serialVersionUID = 1L;

  public BadDecisionException() {
    super();
  }

  /** @param message description of the value that could not be matched */
  public BadDecisionException(String message) {
    super(message);
  }
}

--------------------------------------------------------------------------------
/src/dt/DecisionTree.java:
--------------------------------------------------------------------------------
/**
 *
 */

package dt;

import java.util.*;

/**
 * Collects attribute names, optional per-attribute decision sets and
 * classified examples, and lazily compiles them into a tree of
 * {@link Attribute} nodes using a pluggable {@link Algorithm}.
 */
public class DecisionTree {
  /**
   * Contains the set of available attributes.
   */
  private LinkedHashSet<String> attributes;

  /**
   * Maps an attribute name to a set of possible decisions for that attribute.
   * Only holds what was declared through setDecisions().
   */
  private Map<String, Set<String>> decisions;
  private boolean decisionsSpecified;

  /**
   * Contains the examples to be processed into a decision tree.
   *
   * The 'attributes' and 'decisions' member variables should be updated
   * prior to adding examples that refer to new attributes or decisions.
   */
  private Examples examples;

  /**
   * Indicates if the provided data has been processed into a decision tree.
   *
   * This value is initially false, and is reset any time additional data is
   * provided.
   */
  private boolean compiled;

  /**
   * Contains the top-most attribute of the decision tree.
   *
   * For a tree where the decision requires no attributes,
   * the rootAttribute yields a boolean classification.
   */
  private Attribute rootAttribute;

  private Algorithm algorithm;

  public DecisionTree() {
    algorithm = null;
    examples = new Examples();
    attributes = new LinkedHashSet<String>();
    decisions = new HashMap<String, Set<String>>();
    decisionsSpecified = false;
    compiled = false;
  }

  /** Installs ID3 over this tree's examples if no algorithm has been set. */
  private void setDefaultAlgorithm() {
    if (algorithm == null) {
      algorithm = new ID3Algorithm(examples);
    }
  }

  /**
   * Selects the algorithm used by the next compilation. Any previously
   * compiled tree is discarded so it is rebuilt with the new algorithm.
   */
  public void setAlgorithm(Algorithm algorithm) {
    this.algorithm = algorithm;
    compiled = false;
  }

  /**
   * Saves the array of attribute names in an insertion ordered set.
   *
   * The ordering of attribute names is used when addExample is called to
   * determine which values correspond with which names. Previously declared
   * decisions are cleared.
   */
  public DecisionTree setAttributes(String[] attributeNames) {
    compiled = false;

    decisions.clear();
    decisionsSpecified = false;

    attributes.clear();
    attributes.addAll(Arrays.asList(attributeNames));

    return this;
  }

  /**
   * Declares the permitted decision values for one attribute. The call is
   * ignored if the attribute name was not registered with setAttributes().
   */
  public DecisionTree setDecisions(String attributeName, String[] decisions) {
    if (!attributes.contains(attributeName)) {
      // TODO some kind of warning or something
      return this;
    }

    compiled = false;
    decisionsSpecified = true;

    this.decisions.put(attributeName, new HashSet<String>(Arrays.asList(decisions)));

    return this;
  }

  /**
   * Throws if decisions were declared for {@code attributeName} and
   * {@code value} is not one of them. Attributes with no declared decision
   * set are left unconstrained.
   */
  private void checkDecision(String attributeName, String value) throws UnknownDecisionException {
    Set<String> allowed = decisions.get(attributeName);

    if (allowed != null && !allowed.contains(value)) {
      throw new UnknownDecisionException(attributeName, value);
    }
  }

  /**
   * Adds an example whose values are given in the order of the names passed
   * to setAttributes().
   *
   * @throws IllegalArgumentException if the number of values differs from
   *         the number of registered attributes
   * @throws UnknownDecisionException if a value is not among the decisions
   *         declared for its attribute
   */
  public DecisionTree addExample(String[] attributeValues, boolean classification) throws UnknownDecisionException {
    String[] names = this.attributes.toArray(new String[0]);

    if (attributeValues.length != names.length) {
      throw new IllegalArgumentException("Expected " + names.length
          + " attribute values but got " + attributeValues.length);
    }

    if (decisionsSpecified) {
      for (int i = 0; i < attributeValues.length; i++) {
        checkDecision(names[i], attributeValues[i]);
      }
    }

    compiled = false;

    examples.add(names, attributeValues, classification);

    return this;
  }

  /**
   * Adds an example given as a map from attribute name to value.
   *
   * @throws UnknownDecisionException if a value is not among the decisions
   *         declared for its attribute
   */
  public DecisionTree addExample(Map<String, String> attributes, boolean classification) throws UnknownDecisionException {
    if (decisionsSpecified) {
      for (Map.Entry<String, String> entry : attributes.entrySet()) {
        checkDecision(entry.getKey(), entry.getValue());
      }
    }

    compiled = false;

    examples.add(attributes, classification);

    return this;
  }

  /**
   * Compiles the tree if necessary and classifies {@code data}.
   *
   * @throws BadDecisionException if a node on the path has no decision for
   *         the value found in {@code data}
   */
  public boolean apply(Map<String, String> data) throws BadDecisionException {
    compile();

    return rootAttribute.apply(data);
  }

  /**
   * Builds the subtree below {@code current}, pre-order, asking the algorithm
   * for the node to hang under every decision of every non-leaf node.
   *
   * @param knownDecisions attribute name to the decision values to branch on
   */
  private Attribute compileWalk(Attribute current, Map<String, String> chosenAttributes,
      Set<String> usedAttributes, Map<String, Set<String>> knownDecisions) {
    // if the current attribute is a leaf, then there are no decisions and thus no
    // further attributes to find.
    if (current.isLeaf()) {
      return current;
    }

    String attributeName = current.getName();

    // remove this attribute from all further consideration
    usedAttributes.add(attributeName);

    // An algorithm may name an attribute for which no decisions are known;
    // such a node simply gets no children instead of failing with an NPE.
    Set<String> branches = knownDecisions.get(attributeName);
    if (branches != null) {
      for (String decisionName : branches) {
        // overwrite the attribute decision for each value considered
        chosenAttributes.put(attributeName, decisionName);

        // find the next attribute to choose for the considered decision,
        // build the subtree from this new attribute, pre-order, and
        // insert the newly-built subtree into the open decision slot
        Attribute next = algorithm.nextAttribute(chosenAttributes, usedAttributes);
        current.addDecision(decisionName,
            compileWalk(next, chosenAttributes, usedAttributes, knownDecisions));
      }
    }

    // remove the attribute decision before we walk back up the tree.
    chosenAttributes.remove(attributeName);

    // return the subtree so that it can be inserted into the parent tree.
    return current;
  }

  /** Processes the examples into a tree unless that has already been done. */
  public void compile() {
    // skip compilation if already done.
    if (compiled) {
      return;
    }

    // if no algorithm is set beforehand, select the default one.
    setDefaultAlgorithm();

    Map<String, String> chosenAttributes = new HashMap<String, String>();
    Set<String> usedAttributes = new HashSet<String>();

    // Start from the values seen in the examples and let declared decisions
    // override them. The declared map itself is left untouched so later
    // addExample() validation only ever sees what the caller declared.
    Map<String, Set<String>> knownDecisions = examples.extractDecisions();
    if (decisionsSpecified) {
      knownDecisions.putAll(decisions);
    }

    // find the root attribute (either leaf or non)
    // walk the tree, adding attributes as needed under each decision
    // save the original attribute as the root attribute.
    Attribute root = algorithm.nextAttribute(chosenAttributes, usedAttributes);
    rootAttribute = compileWalk(root, chosenAttributes, usedAttributes, knownDecisions);

    compiled = true;
  }

  /** Compiles the tree if necessary and returns the root node's rendering. */
  @Override
  public String toString() {
    compile();

    if (rootAttribute != null) {
      return rootAttribute.toString();
    } else {
      return "";
    }
  }

  /**
   * Returns the root of the most recently compiled tree. This does not
   * trigger compilation; the value is null until compile() has run.
   */
  public Attribute getRoot() {
    return rootAttribute;
  }
}

--------------------------------------------------------------------------------
/src/dt/Decisions.java:
--------------------------------------------------------------------------------
/**
 *
 */

package dt;

import java.util.*;


/** The decision-value to child-node table held by a non-leaf attribute. */
class Decisions {
  private final Map<String, Attribute> decisions;

  public Decisions() {
    decisions = new HashMap<String, Attribute>();
  }

  public Map<String, Attribute> getMap() {
    return decisions;
  }

  public void put(String decision, Attribute attribute) {
    decisions.put(decision, attribute);
  }

  public void clear() {
    decisions.clear();
  }

  /**
   * Returns the attribute based on the decision matching the provided value.
   *
   * Throws BadDecisionException if no decision matches.
   */
  public Attribute apply(String value) throws BadDecisionException {
    Attribute result = decisions.get(value);

    if (result == null) {
      throw new BadDecisionException("No decision matches value '" + value + "'");
    }

    return result;
  }
}

--------------------------------------------------------------------------------
/src/dt/Examples.java:
--------------------------------------------------------------------------------
/**
 *
 */

package dt;

import java.util.*;


/** The classified examples a tree is built from, with counting helpers. */
class Examples {
  /** One example: attribute name to value, plus its boolean classification. */
  static class Example {
    private final Map<String, String> values;
    private final boolean classifier;

    public Example(String[] attributeNames, String[] attributeValues,
        boolean classifier) {
      assert (attributeNames.length == attributeValues.length);
      values = new HashMap<String, String>();

      for (int i = 0; i < attributeNames.length; i++) {
        values.put(attributeNames[i], attributeValues[i]);
      }

      this.classifier = classifier;
    }

    public Example(Map<String, String> attributes, boolean classifier) {
      this.classifier = classifier;
      // Copy so later changes to the caller's map cannot alter this example.
      this.values = new HashMap<String, String>(attributes);
    }

    public Set<String> getAttributes() {
      return values.keySet();
    }

    public String getAttributeValue(String attribute) {
      return values.get(attribute);
    }

    public boolean matchesClass(boolean classifier) {
      return classifier == this.classifier;
    }

    /**
     * Returns true if this example has, for every entry of {@code required},
     * exactly that value. An attribute this example lacks never matches.
     */
    boolean matches(Map<String, String> required) {
      for (Map.Entry<String, String> entry : required.entrySet()) {
        String actual = values.get(entry.getKey());
        if (actual == null || !actual.equals(entry.getValue())) {
          return false;
        }
      }
      return true;
    }
  }

  private final List<Example> examples;

  public Examples() {
    examples = new ArrayList<Example>();
  }

  public void add(String[] attributeNames, String[] attributeValues,
      boolean classifier) {
    examples.add(new Example(attributeNames, attributeValues, classifier));
  }

  public void add(Map<String, String> attributes, boolean classifier) {
    examples.add(new Example(attributes, classifier));
  }

  /**
   * Returns the number of examples where the attribute has the specified
   * 'decision' value. Examples that lack the attribute are not counted.
   */
  int countDecisions(String attribute, String decision) {
    int count = 0;

    for (Example e : examples) {
      String value = e.getAttributeValue(attribute);
      if (value != null && value.equals(decision)) {
        count++;
      }
    }

    return count;
  }

  /**
   * Returns a map from each attribute name to a set of all values used in the
   * examples for that attribute.
   */
  public Map<String, Set<String>> extractDecisions() {
    Map<String, Set<String>> decisions = new HashMap<String, Set<String>>();

    for (String attribute : extractAttributes()) {
      decisions.put(attribute, extractDecisions(attribute));
    }

    return decisions;
  }

  public int countNegative(String attribute, String decision,
      Map<String, String> attributes) {
    return countClassifier(false, attribute, decision, attributes);
  }

  public int countPositive(String attribute, String decision,
      Map<String, String> attributes) {
    return countClassifier(true, attribute, decision, attributes);
  }

  public int countNegative(Map<String, String> attributes) {
    return countClassifier(false, attributes);
  }

  public int countPositive(Map<String, String> attributes) {
    return countClassifier(true, attributes);
  }

  /** Returns a copy of {@code attributes} with one more pair added. */
  private static Map<String, String> with(Map<String, String> attributes,
      String attribute, String decision) {
    Map<String, String> extended = new HashMap<String, String>(attributes);
    extended.put(attribute, decision);
    return extended;
  }

  /**
   * Returns the number of examples matching {@code attributes} that also
   * have {@code decision} as the value of {@code attribute}.
   */
  public int count(String attribute, String decision, Map<String, String> attributes) {
    return count(with(attributes, attribute, decision));
  }

  /** Returns the number of examples matching every given attribute value. */
  public int count(Map<String, String> attributes) {
    int count = 0;

    for (Example e : examples) {
      if (e.matches(attributes)) {
        count++;
      }
    }

    return count;
  }

  /**
   * Returns the number of examples that match every given attribute value
   * and whose classification equals {@code classifier}.
   */
  public int countClassifier(boolean classifier, Map<String, String> attributes) {
    int count = 0;

    for (Example e : examples) {
      if (e.matches(attributes) && e.matchesClass(classifier)) {
        count++;
      }
    }

    return count;
  }

  public int countClassifier(boolean classifier, String attribute,
      String decision, Map<String, String> attributes) {
    return countClassifier(classifier, with(attributes, attribute, decision));
  }

  /**
   * Returns the number of examples.
   */
  public int count() {
    return examples.size();
  }

  /**
   * Returns a set of attribute names used in the examples.
   */
  public Set<String> extractAttributes() {
    Set<String> attributes = new HashSet<String>();

    for (Example e : examples) {
      attributes.addAll(e.getAttributes());
    }

    return attributes;
  }

  /**
   * Returns every value the examples hold for {@code attribute}; examples
   * that lack the attribute contribute nothing.
   */
  private Set<String> extractDecisions(String attribute) {
    Set<String> decisions = new HashSet<String>();

    for (Example e : examples) {
      String value = e.getAttributeValue(attribute);
      if (value != null) {
        decisions.add(value);
      }
    }

    return decisions;
  }
}

--------------------------------------------------------------------------------
/src/dt/ID3Algorithm.java:
--------------------------------------------------------------------------------
/**
 *
 */

package dt;

import java.util.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Chooses, for each node, the unused attribute with the highest information
 * gain over the examples selected by the decisions made so far.
 */
public class ID3Algorithm implements Algorithm {
  private static final Logger logger = LoggerFactory.getLogger(ID3Algorithm.class);
  private final Examples examples;

  public ID3Algorithm(Examples examples) {
    this.examples = examples;
  }

  /**
   * Returns the next attribute to be chosen.
   *
   * chosenAttributes represents the decision path from the root attribute
   * to the node under consideration. usedAttributes is the set of all
   * attributes that have been incorporated into the tree prior to this
   * call to nextAttribute(), even if the attributes were not used in the path
   * to the node under consideration.
   *
   * With no examples at all, a false leaf is returned.
   */
  public Attribute nextAttribute(Map<String, String> chosenAttributes, Set<String> usedAttributes) {
    /*
     * If there are no positive examples for the already chosen attributes,
     * then return a false classifier leaf. If no negative examples,
     * then return a true classifier leaf.
     */
    if (examples.countPositive(chosenAttributes) == 0) {
      return new Attribute(false);
    } else if (examples.countNegative(chosenAttributes) == 0) {
      return new Attribute(true);
    }

    Set<String> candidates = remainingAttributes(usedAttributes);

    logger.debug("Choosing attribute out of {} remaining attributes.",
        candidates.size());
    logger.debug("Already chosen attributes/decisions are {}.", chosenAttributes);

    // The value sets do not depend on the candidate, so gather them once.
    Map<String, Set<String>> decisions = examples.extractDecisions();

    double bestGain = 0.0;
    String bestAttribute = null;

    for (String attribute : candidates) {
      // for each remaining attribute, determine the information gain of using it
      // to choose among the examples selected by the chosenAttributes
      double currentGain = informationGain(attribute, decisions.get(attribute), chosenAttributes);
      logger.debug("Evaluating attribute {}, information gain is {}",
          attribute, currentGain);
      if (currentGain > bestGain) {
        bestAttribute = attribute;
        bestGain = currentGain;
      }
    }

    // If no attribute gives information gain, generate leaf attribute.
    // Leaf is true if there are any true classifiers.
    if (bestAttribute == null) {
      boolean classifier = examples.countPositive(chosenAttributes) > 0;
      logger.debug("Creating new leaf attribute with classifier {}.", classifier);
      return new Attribute(classifier);
    } else {
      logger.debug("Creating new non-leaf attribute {}.", bestAttribute);
      return new Attribute(bestAttribute);
    }
  }

  /** Returns the attribute names seen in the examples minus the used ones. */
  private Set<String> remainingAttributes(Set<String> usedAttributes) {
    Set<String> result = examples.extractAttributes();
    result.removeAll(usedAttributes);
    return result;
  }

  /**
   * Entropy, in bits, of a set holding the given numbers of positive and
   * negative examples. An empty set has entropy 0 rather than NaN.
   */
  private static double binaryEntropy(double positiveExamples, double negativeExamples) {
    double totalExamples = positiveExamples + negativeExamples;

    if (totalExamples == 0) {
      return 0;
    }

    return -nlog2(positiveExamples / totalExamples) -
        nlog2(negativeExamples / totalExamples);
  }

  /** Entropy of the examples selected by {@code specifiedAttributes}. */
  private double entropy(Map<String, String> specifiedAttributes) {
    // Proportions are taken within the selected subset, not over all examples.
    return binaryEntropy(examples.countPositive(specifiedAttributes),
        examples.countNegative(specifiedAttributes));
  }

  /**
   * Entropy of the selected examples that additionally have
   * {@code decision} as the value of {@code attribute}.
   */
  private double entropy(String attribute, String decision, Map<String, String> specifiedAttributes) {
    return binaryEntropy(
        examples.countPositive(attribute, decision, specifiedAttributes),
        examples.countNegative(attribute, decision, specifiedAttributes));
  }

  /**
   * Information gain of splitting the selected examples on {@code attribute}:
   * the subset's entropy minus the size-weighted entropy of every branch.
   *
   * @param decisionValues the values to branch on; may be null
   */
  private double informationGain(String attribute, Set<String> decisionValues,
      Map<String, String> specifiedAttributes) {
    double sum = entropy(specifiedAttributes);
    double examplesCount = examples.count(specifiedAttributes);

    if (examplesCount == 0 || decisionValues == null) {
      return sum;
    }

    for (String decision : decisionValues) {
      // Weight each branch by its share of the selected subset.
      double decisionCount = examples.count(attribute, decision, specifiedAttributes);

      if (decisionCount == 0) {
        continue;
      }

      sum -= (decisionCount / examplesCount)
          * entropy(attribute, decision, specifiedAttributes);
    }

    return sum;
  }

  /** Returns value * log2(value), taking 0 * log2(0) to be 0. */
  private static double nlog2(double value) {
    if (value == 0) {
      return 0;
    }

    return value * Math.log(value) / Math.log(2);
  }
}

--------------------------------------------------------------------------------
/src/dt/UnknownDecisionException.java:
--------------------------------------------------------------------------------
/**
 *
 */

package dt;


/**
 * Thrown when an example uses a value that is not among the decisions
 * declared for its attribute.
 */
public class UnknownDecisionException extends Exception {
  private static final long serialVersionUID = 1L;

  private final String attribute;
  private final String decision;

  public UnknownDecisionException(String attribute, String decision) {
    super("Unknown decision '" + decision + "' for attribute '" + attribute + "'");
    this.attribute = attribute;
    this.decision = decision;
  }

  /** Returns the name of the attribute the rejected value was given for. */
  public String getAttribute() {
    return attribute;
  }

  /** Returns the rejected decision value. */
  public String getDecision() {
    return decision;
  }
}

--------------------------------------------------------------------------------