├── .gitignore ├── Data ├── adult_names.txt ├── output.txt ├── test.data ├── test_data_preprocessed.data ├── textbookExample.data ├── train.data └── train_data_preprocessed.data ├── Documentation.pdf ├── LICENSE ├── README.md ├── bin ├── Comp.class ├── ID3.class ├── Main.class ├── Node.class ├── Preprocess.class ├── RandomForest.class └── ReducedErrorPruning.class └── src ├── ID3.java ├── Main.java ├── Node.java ├── Preprocess.java ├── RandomForest.java └── ReducedErrorPruning.java /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled class file 2 | *.class 3 | 4 | # Log file 5 | *.log 6 | 7 | # BlueJ files 8 | *.ctxt 9 | 10 | # Mobile Tools for Java (J2ME) 11 | .mtj.tmp/ 12 | 13 | # Package Files # 14 | *.jar 15 | *.war 16 | *.ear 17 | *.zip 18 | *.tar.gz 19 | *.rar 20 | 21 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml 22 | hs_err_pid* 23 | -------------------------------------------------------------------------------- /Data/adult_names.txt: -------------------------------------------------------------------------------- 1 | | This data was extracted from the census bureau database found at 2 | | http://www.census.gov/ftp/pub/DES/www/welcome.html 3 | | Donor: Ronny Kohavi and Barry Becker, 4 | | Data Mining and Visualization 5 | | Silicon Graphics. 6 | | e-mail: ronnyk@sgi.com for questions. 7 | | Split into train-test using MLC++ GenCVFiles (2/3, 1/3 random). 8 | | 48842 instances, mix of continuous and discrete (train=32561, test=16281) 9 | | 45222 if instances with unknown values are removed (train=30162, test=15060) 10 | | Duplicate or conflicting instances : 6 11 | | Class probabilities for adult.all file 12 | | Probability for the label '>50K' : 23.93% / 24.78% (without unknowns) 13 | | Probability for the label '<=50K' : 76.07% / 75.22% (without unknowns) 14 | | 15 | | Extraction was done by Barry Becker from the 1994 Census database. A set of 16 | | reasonably clean records was extracted using the following conditions: 17 | | ((AAGE>16) && (AGI>100) && (AFNLWGT>1)&& (HRSWK>0)) 18 | | 19 | | Prediction task is to determine whether a person makes over 50K 20 | | a year. 21 | | 22 | | First cited in: 23 | | @inproceedings{kohavi-nbtree, 24 | | author={Ron Kohavi}, 25 | | title={Scaling Up the Accuracy of Naive-Bayes Classifiers: a 26 | | Decision-Tree Hybrid}, 27 | | booktitle={Proceedings of the Second International Conference on 28 | | Knowledge Discovery and Data Mining}, 29 | | year = 1996, 30 | | pages={to appear}} 31 | | 32 | | Error Accuracy reported as follows, after removal of unknowns from 33 | | train/test sets): 34 | | C4.5 : 84.46+-0.30 35 | | Naive-Bayes: 83.88+-0.30 36 | | NBTree : 85.90+-0.28 37 | | 38 | | 39 | | Following algorithms were later run with the following error rates, 40 | | all after removal of unknowns and using the original train/test split. 41 | | All these numbers are straight runs using MLC++ with default values. 42 | | 43 | | Algorithm Error 44 | | -- ---------------- ----- 45 | | 1 C4.5 15.54 46 | | 2 C4.5-auto 14.46 47 | | 3 C4.5 rules 14.94 48 | | 4 Voted ID3 (0.6) 15.64 49 | | 5 Voted ID3 (0.8) 16.47 50 | | 6 T2 16.84 51 | | 7 1R 19.54 52 | | 8 NBTree 14.10 53 | | 9 CN2 16.00 54 | | 10 HOODG 14.82 55 | | 11 FSS Naive Bayes 14.05 56 | | 12 IDTM (Decision table) 14.46 57 | | 13 Naive-Bayes 16.12 58 | | 14 Nearest-neighbor (1) 21.42 59 | | 15 Nearest-neighbor (3) 20.35 60 | | 16 OC1 15.04 61 | | 17 Pebls Crashed. Unknown why (bounds WERE increased) 62 | | 63 | | Conversion of original data as follows: 64 | | 1. Discretized agrossincome into two ranges with threshold 50,000. 65 | | 2. Convert U.S. to US to avoid periods. 66 | | 3. Convert Unknown to "?" 67 | | 4. Run MLC++ GenCVFiles to generate data,test. 68 | | 69 | | Description of fnlwgt (final weight) 70 | | 71 | | The weights on the CPS files are controlled to independent estimates of the 72 | | civilian noninstitutional population of the US. These are prepared monthly 73 | | for us by Population Division here at the Census Bureau. We use 3 sets of 74 | | controls. 75 | | These are: 76 | | 1. A single cell estimate of the population 16+ for each state. 77 | | 2. Controls for Hispanic Origin by age and sex. 78 | | 3. Controls by Race, age and sex. 79 | | 80 | | We use all three sets of controls in our weighting program and "rake" through 81 | | them 6 times so that by the end we come back to all the controls we used. 82 | | 83 | | The term estimate refers to population totals derived from CPS by creating 84 | | "weighted tallies" of any specified socio-economic characteristics of the 85 | | population. 86 | | 87 | | People with similar demographic characteristics should have 88 | | similar weights. There is one important caveat to remember 89 | | about this statement. That is that since the CPS sample is 90 | | actually a collection of 51 state samples, each with its own 91 | | probability of selection, the statement only applies within 92 | | state. 93 | 94 | 95 | >50K, <=50K. 96 | 97 | age: continuous. 98 | workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked. 99 | fnlwgt: continuous. 100 | education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool. 101 | education-num: continuous. 102 | marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse. 103 | occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces. 104 | relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried. 105 | race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black. 106 | sex: Female, Male. 107 | capital-gain: continuous. 108 | capital-loss: continuous. 109 | hours-per-week: continuous. 110 | native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands. 111 | -------------------------------------------------------------------------------- /Data/output.txt: -------------------------------------------------------------------------------- 1 | Start... 2 | 3 | Prepocessing Training data 4 | 5 | Prepocessing Testing data 6 | 7 | Generating Decision Tree using ID3 Algorithm 8 | Training Time=1.979secs 9 | Accuracy=0.807874209200909 10 | Precision=0.8762364294330519 Recall=0.8727272727272727 F-Score=0.874478330658106 11 | No of nodes in tree = 33223 12 | 13 | Applying Reduced Error Pruning on the decision tree generated 14 | Training Time=10.7secs 15 | Accuracy=0.8404889134574043 16 | Precision=0.9467631684760756 Recall=0.8588415523781733 F-Score=0.9006617450177867 17 | No of nodes in tree = 2640 18 | 19 | Initializing Random Forest with 10 trees, 0.5 fraction of attributes and 0.33 fraction of training instances in each tree 20 | Training Time=1.618secs 21 | Accuracy=0.8313371414532277 22 | Precision=0.944270205066345 Recall=0.8511779630300834 F-Score=0.8953107129241327 23 | 24 | End... 25 | -------------------------------------------------------------------------------- /Data/textbookExample.data: -------------------------------------------------------------------------------- 1 | Sunny,Hot,High,Weak,No 2 | Sunny,Hot,High,Strong,No 3 | Overcast,Hot,High,Weak,Yes 4 | Rain,Mild,High,Weak,Yes 5 | Rain,Cool,Normal,Weak,Yes 6 | Rain,Cool,Normal,Strong,No 7 | Overcast,Cool,Normal,Strong,Yes 8 | Sunny,Mild,High,Weak,No 9 | Sunny,Cool,Normal,Weak,Yes 10 | Rain,Mild,Normal,Weak,Yes 11 | Sunny,Mild,Normal,Strong,Yes 12 | Overcast,Mild,High,Strong,Yes 13 | Overcast,Hot,Normal,Weak,Yes 14 | Rain,Mild,High,Strong,No -------------------------------------------------------------------------------- /Documentation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/Documentation.pdf -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Keval Morabia 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ID3-Decision-Tree-Classifier-in-Java 2 | 3 | ``` 4 | Classes: 1 = >50K, 2 = <=50K 5 | 6 | Attributes 7 | age: continuous. 8 | workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked. 9 | fnlwgt: continuous. 10 | education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool. 11 | education-num: continuous. 12 | marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse. 13 | occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces. 14 | relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried. 15 | race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black. 16 | sex: Female, Male. 17 | capital-gain: continuous. 18 | capital-loss: continuous. 19 | hours-per-week: continuous. 20 | native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands. 21 | 22 | ``` 23 |
24 | 25 | **Procedure:** 26 | 1. Decision tree was generated using the data provided and the ID3 algorithm mentioned in Tom. M. Mitchell. 27 | 2. Missing values were filled using the value which appeared most frequently in the particular attribute column. 28 | 3. Continuous values were handled as mentioned in section 3.7.2 of Tom M. Mitchell. First the values were sorted in ascending order, then at the points where value was changing, gain was calculated and finally the column was 29 | splited at the point where maximum gain was obtained. 30 | 4. Reduced Error Pruning was performed by removing a node (one by one) and then checking the accuracy. If accuracy was increased than the node was removed else we move on to check the next node. 31 | 5. Random forests were generated using 50% attributes and 33% data randomly. 10 forests were generated and accuracy increased compared to 32 | the original ID3 algorithm. 33 | 34 |
35 | 36 | ``` 37 | Output: 38 | Start... 39 | 40 | Prepocessing Training data 41 | 42 | Prepocessing Testing data 43 | 44 | Generating Decision Tree using ID3 Algorithm 45 | Training Time=1.979secs 46 | Accuracy=0.807874209200909 47 | Precision=0.8762364294330519 Recall=0.8727272727272727 F-Score=0.874478330658106 48 | No of nodes in tree = 33223 49 | 50 | Applying Reduced Error Pruning on the decision tree generated 51 | Training Time=10.7secs 52 | Accuracy=0.8404889134574043 53 | Precision=0.9467631684760756 Recall=0.8588415523781733 F-Score=0.9006617450177867 54 | No of nodes in tree = 2640 55 | 56 | Initializing Random Forest with 10 trees, 0.5 fraction of attributes and 0.33 fraction of training instances in each tree 57 | Training Time=1.618secs 58 | Accuracy=0.8313371414532277 59 | Precision=0.944270205066345 Recall=0.8511779630300834 F-Score=0.8953107129241327 60 | 61 | End... 62 | ``` 63 | -------------------------------------------------------------------------------- /bin/Comp.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/Comp.class -------------------------------------------------------------------------------- /bin/ID3.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/ID3.class -------------------------------------------------------------------------------- /bin/Main.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/Main.class -------------------------------------------------------------------------------- /bin/Node.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/Node.class -------------------------------------------------------------------------------- /bin/Preprocess.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/Preprocess.class -------------------------------------------------------------------------------- /bin/RandomForest.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/RandomForest.class -------------------------------------------------------------------------------- /bin/ReducedErrorPruning.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/ReducedErrorPruning.class -------------------------------------------------------------------------------- /src/ID3.java: -------------------------------------------------------------------------------- 1 | import java.util.ArrayList; 2 | import java.util.HashMap; 3 | 4 | public class ID3 { 5 | HashMap> discreteValues; 6 | String class1, class2; 7 | ArrayList data, testData;//data = training data 8 | Node root; 9 | double trainingTime=0, precision, recall, fscore, accuracy; 10 | int noOfNodes; 11 | 12 | public ID3(ArrayList data, ArrayList testData, int noOfClass1, int noOfClass2, String class1, String class2, HashMap> discreteValues, ArrayList remAttr){ 13 | trainingTime = System.currentTimeMillis(); 14 | 15 | this.data = data; 16 | this.class1 = class1; 17 | this.class2 = class2; 18 | this.testData = testData; 19 | this.discreteValues = discreteValues; 20 | 21 | root = new Node(); 22 | double p1 = noOfClass1/(noOfClass1+noOfClass2+0.0), p2 = noOfClass2/(noOfClass1+noOfClass2+0.0); 23 | root.entropy = -1*pLogP(p1) - 1*pLogP(p2); 24 | root.data = data; 25 | root.noOfClass1 = noOfClass1; 26 | root.noOfClass2 = noOfClass2; 27 | 28 | root.remainingAttributes = remAttr; 29 | generateDecisionTree(root); 30 | 31 | trainingTime = (System.currentTimeMillis() - trainingTime)/1000.0; 32 | 33 | analyse();//calculate precision, recall, fscore, accuracy 34 | } 35 | 36 | public void analyse(){ 37 | int correctClassification=0, incorrectClassification=0; 38 | int truePositive = 0, falsePositive = 0, falseNegative = 0; 39 | for(String[] s : testData){ 40 | int predicted = Node.predictClass(root, s, discreteValues), actual = s[s.length-1].equals(class1)?1:2; 41 | if(predicted == actual ) correctClassification++; 42 | else incorrectClassification++; 43 | 44 | //1-->yes, 2-->no 45 | if(predicted==1 && actual==1) truePositive++; 46 | else if(predicted==1 && actual==2) falseNegative++; 47 | else if(predicted==2 && actual==1) falsePositive++; 48 | } 49 | precision = truePositive/(truePositive+falsePositive+0.0); 50 | recall = truePositive/(truePositive+falseNegative+0.0); 51 | fscore = 2*precision*recall/(precision+recall); 52 | accuracy = (correctClassification)/(correctClassification+incorrectClassification+ 0.0); 53 | 54 | noOfNodes = 0; 55 | countNodes(root); 56 | } 57 | 58 | public void printAnalysis(){ 59 | System.out.println("Accuracy="+accuracy+"\nPrecision="+precision+" Recall="+recall+" F-Score="+fscore); 60 | System.out.println("No of nodes in tree = "+noOfNodes); 61 | } 62 | 63 | public void printTrainingTime(){ 64 | System.out.println("Training Time="+trainingTime+"secs"); 65 | } 66 | 67 | public void countNodes(Node root){ 68 | if(root==null) return; 69 | noOfNodes++; 70 | if(root.isLeaf) return; 71 | for(Node n : root.children) countNodes(n); 72 | } 73 | 74 | private static double pLogP(double p){ 75 | return p==0?0:p*Math.log(p); 76 | } 77 | 78 | private void generateDecisionTree(Node root){ 79 | if(root==null) return; 80 | if(root.remainingAttributes.size()==1){//leaf node 81 | root.isLeaf = true; 82 | root.classification = root.noOfClass1>=root.noOfClass2?1:2; 83 | }else if(root.noOfClass1==0 || root.noOfClass2==0 || root.data.size()==0){//leaf 84 | root.isLeaf = true; 85 | root.classification = root.noOfClass1==0?2:1; 86 | }else{ 87 | //find split attribute which gives max gain 88 | root.children = splitAttribute(root); 89 | ArrayList discreteValuesOfThisAttribute = discreteValues.get(root.attribute); 90 | for(int j=0; j < discreteValuesOfThisAttribute.size(); j++){ 91 | root.children[j].data = new ArrayList<>(); 92 | root.children[j].remainingAttributes = new ArrayList<>(); 93 | for(int rem : root.remainingAttributes){ 94 | if(rem!=root.attribute) root.children[j].remainingAttributes.add(rem); 95 | } 96 | String curr = discreteValuesOfThisAttribute.get(j); 97 | for(String[] s : root.data){ 98 | if(s[root.attribute].equals(curr)){ 99 | root.children[j].data.add(s); 100 | } 101 | } 102 | generateDecisionTree(root.children[j]); 103 | } 104 | } 105 | } 106 | 107 | public Node[] splitAttribute(Node root){ 108 | double maxGain = -1.0; 109 | Node[] ans = null; 110 | for(int i : root.remainingAttributes){ 111 | ArrayList discreteValuesOfThisAttribute = discreteValues.get(i); 112 | Node[] child = new Node[discreteValuesOfThisAttribute.size()]; 113 | for(int j=0; j < discreteValuesOfThisAttribute.size(); j++){ 114 | String curr = discreteValuesOfThisAttribute.get(j); 115 | child[j] = new Node(); 116 | for(String[] s : root.data){ 117 | if(s[i].equals(curr)){ 118 | if(s[s.length-1].equals(class1)) child[j].noOfClass1++; 119 | else child[j].noOfClass2++; 120 | } 121 | } 122 | } 123 | int total = root.data.size(); 124 | double gain = root.entropy; 125 | for(int j = 0; j < discreteValuesOfThisAttribute.size(); j++){ 126 | int c1 = child[j].noOfClass1, c2 = child[j].noOfClass2; 127 | if(c1==0 && c2==0) continue; 128 | double p1 = c1/(c1+c2+0.0), p2 = c2/(c1+c2+0.0); 129 | child[j].entropy = -1*pLogP(p1) + -1*pLogP(p2); 130 | gain -= ((c1+c2)/(total+0.0))*child[j].entropy; 131 | } 132 | if(gain > maxGain){ 133 | root.attribute = i; 134 | maxGain = gain; 135 | ans = child; 136 | } 137 | } 138 | return ans; 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /src/Main.java: -------------------------------------------------------------------------------- 1 | import java.io.File; 2 | import java.io.IOException; 3 | import java.util.ArrayList; 4 | 5 | public class Main { 6 | 7 | public static void main(String[] args) throws IOException { 8 | System.out.println("Start..."); 9 | 10 | System.out.println("\nPrepocessing Training data"); 11 | String class1 = "<=50K", class2 = ">50K"; 12 | ArrayList continuousAttributes = new ArrayList<>(); 13 | continuousAttributes.add(0); 14 | continuousAttributes.add(2); 15 | continuousAttributes.add(4); 16 | continuousAttributes.add(10); 17 | continuousAttributes.add(11); 18 | continuousAttributes.add(12); 19 | 20 | Preprocess p = new Preprocess(class1, class2, continuousAttributes); 21 | ArrayList trainData = p.discretize(new File("Data/train.data")); 22 | int noOfClass1 = p.c1, noOfClass2 = p.c2; 23 | Preprocess.predictMissingValues(trainData, "train"); 24 | 25 | Preprocess.computeDiscreteValues(trainData); 26 | 27 | System.out.println("\nPrepocessing Testing data"); 28 | ArrayList testData = Preprocess.preprocessTestData(new File("Data/test.data"), continuousAttributes); 29 | 30 | ArrayList remAttr = new ArrayList<>(); 31 | for(int i = 0; i < trainData.get(0).length-1; i++) remAttr.add(i); 32 | System.out.println("\nGenerating Decision Tree using ID3 Algorithm"); 33 | //<=50K = class1, >50K = class2 34 | ID3 decisionTree = new ID3(trainData, testData, noOfClass1, noOfClass2, class1, class2, Preprocess.discreteValues, remAttr); 35 | decisionTree.printTrainingTime(); 36 | decisionTree.printAnalysis(); 37 | 38 | System.out.println("\nApplying Reduced Error Pruning on the decision tree generated"); 39 | //if you want to preserve original tree then create another ID3 instance and pass that to rep 40 | ReducedErrorPruning rep = new ReducedErrorPruning(decisionTree); 41 | rep.tree.printAnalysis(); 42 | 43 | int noOftrees = 10; 44 | double fractionOfAttributesToTake = 0.5, fractionOfTrainingInstancesToTake = 0.33; 45 | System.out.println("\nInitializing Random Forest with "+noOftrees+" trees, "+fractionOfAttributesToTake 46 | +" fraction of attributes and "+fractionOfTrainingInstancesToTake+" fraction of training instances in each tree"); 47 | RandomForest rf = new RandomForest(noOftrees, fractionOfAttributesToTake, fractionOfTrainingInstancesToTake, 48 | trainData, testData, noOfClass1, noOfClass2, class1, class2, Preprocess.discreteValues); 49 | rf.printAnalysis(); 50 | 51 | System.out.println("\nEnd..."); 52 | } 53 | 54 | } 55 | -------------------------------------------------------------------------------- /src/Node.java: -------------------------------------------------------------------------------- 1 | import java.util.ArrayList; 2 | import java.util.HashMap; 3 | 4 | public class Node { 5 | int attribute;//attribute on which classification is done 6 | ArrayList remainingAttributes; 7 | int noOfClass1, noOfClass2;//class1 = <=50K, class2 = >50K 8 | double entropy; 9 | Node[] children; 10 | ArrayList data;//Datasets satisfying conditions of current node and all ancestors 11 | boolean isLeaf = false; 12 | int classification;//applicable if leaf, classification = 1 or 2 for class1 and class2 respectively 13 | 14 | public Node(int attr, String value, int c1, int c2, double d, ArrayList data, ArrayList remainingAttributes){ 15 | attribute = attr; 16 | noOfClass1 = c1; 17 | noOfClass2 = c2; 18 | entropy = d; 19 | this.data = data; 20 | this.remainingAttributes = remainingAttributes; 21 | } 22 | 23 | public Node(int classification){ 24 | isLeaf = true; 25 | this.classification = classification; 26 | } 27 | 28 | public Node(){ 29 | 30 | } 31 | 32 | public String toString(){ 33 | return isLeaf ? "Leaf,Classification="+classification : "Entropy="+entropy+",Split="+attribute; 34 | } 35 | 36 | public static int predictClass(Node root, String[] data, HashMap> discreteValues){ 37 | if(root==null) return 1; 38 | else if(root.isLeaf) return root.classification; 39 | String s = data[root.attribute]; 40 | ArrayList discrVals = discreteValues.get(root.attribute); 41 | for(int i = 0; i < discrVals.size(); i++){ 42 | if(s.equals(discrVals.get(i))){ 43 | return predictClass(root.children[i], data, discreteValues); 44 | } 45 | } 46 | return 1; 47 | } 48 | 49 | public static void inOrder(Node root){ 50 | if(root==null) return; 51 | System.out.println(root); 52 | if(root.isLeaf) return; 53 | for(Node c : root.children) inOrder(c); 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/Preprocess.java: -------------------------------------------------------------------------------- 1 | import java.io.BufferedReader; 2 | import java.io.BufferedWriter; 3 | import java.io.File; 4 | import java.io.FileReader; 5 | import java.io.FileWriter; 6 | import java.io.IOException; 7 | import java.util.ArrayList; 8 | import java.util.Arrays; 9 | import java.util.Collections; 10 | import java.util.Comparator; 11 | import java.util.HashMap; 12 | import java.util.HashSet; 13 | import java.util.StringTokenizer; 14 | 15 | //Discretize into 2 classes: <=val and >val such that resulting entropy is minimum 16 | public class Preprocess { 17 | int c1, c2; 18 | static int[] partitionAt;//the value at which continuous attributes are partitioned 19 | String class1, class2; 20 | ArrayList continuousAttributes; 21 | /*no of distinct values for a attribute type 22 | * e.g.: 0, [<=19,>19] 23 | * 1, [Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked] 24 | */ 25 | static HashMap> discreteValues = new HashMap<>(); 26 | 27 | public Preprocess(String class1, String class2, ArrayList continuousAttributes){ 28 | this.class1 = class1; 29 | this.class2 = class2; 30 | this.continuousAttributes = continuousAttributes; 31 | } 32 | 33 | public ArrayList discretize(File trainOriginal) throws IOException{ 34 | BufferedReader br = new BufferedReader(new FileReader(trainOriginal)); 35 | 36 | String s = br.readLine(); 37 | br.close(); 38 | StringTokenizer st = new StringTokenizer(s, ","); 39 | int noOfAttributes = st.countTokens(); 40 | partitionAt = new int[noOfAttributes]; 41 | ArrayList data = new ArrayList<>(); 42 | br = new BufferedReader(new FileReader(trainOriginal)); 43 | while((s = br.readLine())!=null){ 44 | st = new StringTokenizer(s, ","); 45 | String[] dataset = new String[noOfAttributes]; 46 | for(int i = 0; i < noOfAttributes; i++){ 47 | dataset[i] = st.nextToken(); 48 | } 49 | if(dataset[noOfAttributes-1].equals(class1)) c1++; 50 | else c2++; 51 | data.add(dataset); 52 | } 53 | for(int i : continuousAttributes){ 54 | int c1LeftPart = 0, c1RightPart=c1; 55 | int c2LeftPart = 0, c2RightPart=c2; 56 | int total = c1+c2; 57 | Collections.sort(data, new Comp(i)); 58 | double minEndtropy = Double.MAX_VALUE; 59 | int partAt=1; 60 | String prev = data.get(0)[i]; 61 | for(String[] dataset: data){ 62 | if(dataset[noOfAttributes-1].equals(class1)){ 63 | c1LeftPart++; 64 | c1RightPart--; 65 | }else{ 66 | c2LeftPart++; 67 | c2RightPart--; 68 | } 69 | String curr = dataset[i]; 70 | if(curr.equals(prev) || curr.equals("?")) continue; 71 | double p11 = (c1LeftPart+0.0)/(c1LeftPart+c2LeftPart), p12 = (c2LeftPart+0.0)/(c1LeftPart+c2LeftPart); 72 | double p21 = (c1RightPart+0.0)/(c1RightPart+c2RightPart), p22 = (c2RightPart+0.0)/(c1RightPart+c2RightPart); 73 | double entropy = ((c1LeftPart+c2LeftPart)/total+0.0)*(-1*p11*Math.log(p11) -1*p12*Math.log(p12)) + ((c1RightPart+c2RightPart)/total+0.0)*(-1*p21*Math.log(p21) -1*p22*Math.log(p22)); 74 | if(entropy < minEndtropy){ 75 | minEndtropy = entropy; 76 | partAt = (Integer.parseInt(prev)+Integer.parseInt(curr))/2; 77 | partitionAt[i] = partAt; 78 | } 79 | prev = curr; 80 | } 81 | String newc1Name = "<="+partAt, newc2Name = ">"+partAt; 82 | for(String[] dataset : data){ 83 | if(dataset[i]=="?") continue; 84 | if(Integer.parseInt(dataset[i]) <= partAt) dataset[i] = newc1Name; 85 | else dataset[i] = newc2Name; 86 | } 87 | } 88 | br.close(); 89 | return data; 90 | } 91 | 92 | /* 93 | * data should not be continuous 94 | * Sets ? to the value that occurs max times for that attribute 95 | */ 96 | public static void predictMissingValues(ArrayList data, String trainOrTest) throws IOException{ 97 | int n = data.get(0).length-1; 98 | String[] predictedValue = new String[n]; 99 | for(int i = 0; i < n; i++){ 100 | HashMap count = new HashMap<>(); 101 | for(String[] s : data){ 102 | if(s[i].equals("?")) continue; 103 | if(count.containsKey(s[i])) count.put(s[i], count.get(s[i])+1); 104 | else count.put(s[i], 1); 105 | } 106 | int max = 0; 107 | for(String s : count.keySet()){ 108 | int c = count.get(s); 109 | if(c > max){ 110 | max = c; 111 | predictedValue[i] = s; 112 | } 113 | } 114 | } 115 | 116 | for(String[] s : data){ 117 | for(int i = 0; i < n; i++){ 118 | if(s[i].equals("?")) s[i] = predictedValue[i]; 119 | } 120 | } 121 | 122 | File trainDiscretized = new File("Data/"+trainOrTest+"_data_preprocessed.data"); 123 | BufferedWriter bw = new BufferedWriter(new FileWriter(trainDiscretized)); 124 | for(String[] dataset : data){ 125 | bw.write(Arrays.toString(dataset).replace("[", "").replace("]", "").replace(" ", "")+"\n"); 126 | } 127 | bw.close(); 128 | } 129 | 130 | //use partitionAt calculated when discretizing train data 131 | public static ArrayList preprocessTestData(File test, ArrayList continuousAttributes) throws IOException{ 132 | BufferedReader br = new BufferedReader(new FileReader(test)); 133 | 134 | String s = br.readLine(); 135 | br.close(); 136 | StringTokenizer st = new StringTokenizer(s, ","); 137 | int noOfAttributes = st.countTokens(); 138 | ArrayList data = new ArrayList<>(); 139 | br = new BufferedReader(new FileReader(test)); 140 | while((s = br.readLine())!=null){ 141 | st = new StringTokenizer(s, ","); 142 | String[] dataset = new String[noOfAttributes]; 143 | for(int i = 0; i < noOfAttributes; i++){ 144 | dataset[i] = st.nextToken(); 145 | } 146 | data.add(dataset); 147 | } 148 | for(int i : continuousAttributes){ 149 | String newc1Name = "<="+partitionAt[i], newc2Name = ">"+partitionAt[i]; 150 | for(String[] dataset : data){ 151 | if(dataset[i]=="?") continue; 152 | if(Integer.parseInt(dataset[i]) <= partitionAt[i]) dataset[i] = newc1Name; 153 | else dataset[i] = newc2Name; 154 | } 155 | } 156 | br.close(); 157 | predictMissingValues(data,"test"); 158 | return data; 159 | } 160 | 161 | public static void computeDiscreteValues(ArrayList data){ 162 | HashSet observedAttributes = new HashSet<>(); 163 | String[] sampleDataset = data.get(0); 164 | for(int i = 0; i < sampleDataset.length-1; i++){ 165 | discreteValues.put(i, new ArrayList()); 166 | } 167 | for(String[] dataset : data){ 168 | for(int i = 0; i < dataset.length-1; i++){//last attribute is Y/N 169 | String attribute = dataset[i]; 170 | if(attribute.equals("?")) continue; 171 | if(!observedAttributes.contains(attribute)){ 172 | discreteValues.get(i).add(attribute); 173 | observedAttributes.add(attribute); 174 | } 175 | } 176 | } 177 | } 178 | } 179 | 180 | class Comp implements Comparator{ 181 | int index; 182 | public Comp(int i){index = i;} 183 | public int compare(String[] s1, String[] s2){ 184 | return s1[index].compareTo(s2[index]); 185 | } 186 | } -------------------------------------------------------------------------------- /src/RandomForest.java: -------------------------------------------------------------------------------- 1 | import java.util.ArrayList; 2 | import java.util.HashMap; 3 | import java.util.Random; 4 | 5 | public class RandomForest { 6 | int noOfTrees; 7 | HashMap> discreteValues; 8 | String class1, class2; 9 | ArrayList data, testData;//data = training data 10 | double trainingTime=0, precision, recall, fscore, accuracy; 11 | ID3[] tree; 12 | 13 | public RandomForest(int noOfTrees, double fractionOfAttributesToTake, double fractionOfInstancesToTake, ArrayList data, ArrayList testData, int noOfClass1, int noOfClass2, 14 | String class1, String class2, HashMap> discreteValues){ 15 | if(fractionOfAttributesToTake>1 || fractionOfAttributesToTake<0 || fractionOfInstancesToTake>1 || fractionOfInstancesToTake<0){ 16 | System.out.println("Random Forest input invalid"); 17 | return; 18 | } 19 | 20 | this.noOfTrees = noOfTrees; 21 | this.data = data; 22 | this.class1 = class1; 23 | this.class2 = class2; 24 | this.testData = testData; 25 | this.discreteValues = discreteValues; 26 | 27 | int noOfAttributes = data.get(0).length-1, noOfTrainingInstances = data.size(); 28 | int noOfRandomInstances = (int)(noOfTrainingInstances*fractionOfInstancesToTake); 29 | int noOfRandomAttributes = (int)(noOfAttributes*fractionOfAttributesToTake); 30 | 31 | String[][] dataInArray = new String[noOfTrainingInstances][];//to access randomly in constant time 32 | int x = 0; 33 | for(String[] s : data){ 34 | dataInArray[x++] = s; 35 | } 36 | 37 | trainingTime = System.currentTimeMillis(); 38 | 39 | Random rand = new Random(); 40 | tree = new ID3[noOfTrees]; 41 | for(int i = 0; i < noOfTrees; i++){ 42 | ArrayList remAttr = new ArrayList<>(); 43 | for(int j = 0; j < noOfRandomAttributes; j++){ 44 | int r = rand.nextInt(noOfAttributes); 45 | if(remAttr.contains(r)) j--; 46 | else remAttr.add(r); 47 | } 48 | ArrayList randData = new ArrayList<>(); 49 | for(int j = 0; j < noOfRandomInstances; j++){ 50 | randData.add(dataInArray[rand.nextInt(noOfTrainingInstances)]); 51 | } 52 | tree[i] = new ID3(randData, testData, noOfClass1, noOfClass2, class1, class2, discreteValues, remAttr); 53 | } 54 | 55 | trainingTime = (System.currentTimeMillis() - trainingTime)/1000.0; 56 | 57 | analyse(); 58 | } 59 | 60 | public void analyse(){ 61 | if(tree==null) return;//Invalid inputs 62 | int correctClassification=0, incorrectClassification=0; 63 | int truePositive = 0, falsePositive = 0, falseNegative = 0; 64 | for(String[] s : testData){ 65 | int tempPrediction1=0, tempPrediction2=0; 66 | for(int i = 0; i < noOfTrees; i++){ 67 | if(Node.predictClass(tree[i].root, s, discreteValues)==1) tempPrediction1++; 68 | else tempPrediction2++; 69 | } 70 | int predicted=tempPrediction1>=tempPrediction2?1:2, actual = s[s.length-1].equals(class1)?1:2; 71 | if(predicted == actual ) correctClassification++; 72 | else incorrectClassification++; 73 | 74 | //1-->yes, 2-->no 75 | if(predicted==1 && actual==1) truePositive++; 76 | else if(predicted==1 && actual==2) falseNegative++; 77 | else if(predicted==2 && actual==1) falsePositive++; 78 | } 79 | precision = truePositive/(truePositive+falsePositive+0.0); 80 | recall = truePositive/(truePositive+falseNegative+0.0); 81 | fscore = 2*precision*recall/(precision+recall); 82 | accuracy = (correctClassification)/(correctClassification+incorrectClassification+ 0.0); 83 | } 84 | 85 | public void printAnalysis(){ 86 | System.out.println("Training Time="+trainingTime+"secs"); 87 | System.out.println("Accuracy="+accuracy+"\nPrecision="+precision+" Recall="+recall+" F-Score="+fscore); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /src/ReducedErrorPruning.java: -------------------------------------------------------------------------------- 1 | 2 | public class ReducedErrorPruning { 3 | ID3 tree; 4 | double trainingTime; 5 | double initialAccuracy, maxAccuracy; 6 | 7 | public ReducedErrorPruning(ID3 tree){ 8 | this.tree = tree; 9 | initialAccuracy = tree.accuracy; 10 | maxAccuracy = initialAccuracy; 11 | trainingTime = System.currentTimeMillis(); 12 | pruneTree(tree.root); 13 | trainingTime = (System.currentTimeMillis() - trainingTime)/1000.0; 14 | System.out.println("Training Time="+trainingTime+"secs"); 15 | } 16 | 17 | private void pruneTree(Node root){//return true if more accuracy obtained 18 | if(root==null) return; 19 | root.isLeaf = true; 20 | root.classification = root.noOfClass1>=root.noOfClass2?1:2; 21 | tree.analyse(); 22 | if(tree.accuracy > maxAccuracy){ 23 | maxAccuracy = tree.accuracy; 24 | return; 25 | } 26 | root.isLeaf = false; 27 | for(Node c : root.children){ 28 | if(c.isLeaf) continue; 29 | pruneTree(c); 30 | } 31 | 32 | } 33 | } 34 | --------------------------------------------------------------------------------