├── .gitignore
├── Data
│   ├── adult_names.txt
│   ├── output.txt
│   ├── test.data
│   ├── test_data_preprocessed.data
│   ├── textbookExample.data
│   ├── train.data
│   └── train_data_preprocessed.data
├── Documentation.pdf
├── LICENSE
├── README.md
├── bin
│   ├── Comp.class
│   ├── ID3.class
│   ├── Main.class
│   ├── Node.class
│   ├── Preprocess.class
│   ├── RandomForest.class
│   └── ReducedErrorPruning.class
└── src
    ├── ID3.java
    ├── Main.java
    ├── Node.java
    ├── Preprocess.java
    ├── RandomForest.java
    └── ReducedErrorPruning.java
/.gitignore:
--------------------------------------------------------------------------------
1 | # Compiled class file
2 | *.class
3 |
4 | # Log file
5 | *.log
6 |
7 | # BlueJ files
8 | *.ctxt
9 |
10 | # Mobile Tools for Java (J2ME)
11 | .mtj.tmp/
12 |
13 | # Package Files #
14 | *.jar
15 | *.war
16 | *.ear
17 | *.zip
18 | *.tar.gz
19 | *.rar
20 |
21 | # virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
22 | hs_err_pid*
23 |
--------------------------------------------------------------------------------
/Data/adult_names.txt:
--------------------------------------------------------------------------------
1 | | This data was extracted from the census bureau database found at
2 | | http://www.census.gov/ftp/pub/DES/www/welcome.html
3 | | Donor: Ronny Kohavi and Barry Becker,
4 | | Data Mining and Visualization
5 | | Silicon Graphics.
6 | | e-mail: ronnyk@sgi.com for questions.
7 | | Split into train-test using MLC++ GenCVFiles (2/3, 1/3 random).
8 | | 48842 instances, mix of continuous and discrete (train=32561, test=16281)
9 | | 45222 if instances with unknown values are removed (train=30162, test=15060)
10 | | Duplicate or conflicting instances : 6
11 | | Class probabilities for adult.all file
12 | | Probability for the label '>50K' : 23.93% / 24.78% (without unknowns)
13 | | Probability for the label '<=50K' : 76.07% / 75.22% (without unknowns)
14 | |
15 | | Extraction was done by Barry Becker from the 1994 Census database. A set of
16 | | reasonably clean records was extracted using the following conditions:
17 | | ((AAGE>16) && (AGI>100) && (AFNLWGT>1)&& (HRSWK>0))
18 | |
19 | | Prediction task is to determine whether a person makes over 50K
20 | | a year.
21 | |
22 | | First cited in:
23 | | @inproceedings{kohavi-nbtree,
24 | | author={Ron Kohavi},
25 | | title={Scaling Up the Accuracy of Naive-Bayes Classifiers: a
26 | | Decision-Tree Hybrid},
27 | | booktitle={Proceedings of the Second International Conference on
28 | | Knowledge Discovery and Data Mining},
29 | | year = 1996,
30 | | pages={to appear}}
31 | |
32 | | Error Accuracy reported as follows, after removal of unknowns from
33 | | train/test sets):
34 | | C4.5 : 84.46+-0.30
35 | | Naive-Bayes: 83.88+-0.30
36 | | NBTree : 85.90+-0.28
37 | |
38 | |
39 | | Following algorithms were later run with the following error rates,
40 | | all after removal of unknowns and using the original train/test split.
41 | | All these numbers are straight runs using MLC++ with default values.
42 | |
43 | | Algorithm Error
44 | | -- ---------------- -----
45 | | 1 C4.5 15.54
46 | | 2 C4.5-auto 14.46
47 | | 3 C4.5 rules 14.94
48 | | 4 Voted ID3 (0.6) 15.64
49 | | 5 Voted ID3 (0.8) 16.47
50 | | 6 T2 16.84
51 | | 7 1R 19.54
52 | | 8 NBTree 14.10
53 | | 9 CN2 16.00
54 | | 10 HOODG 14.82
55 | | 11 FSS Naive Bayes 14.05
56 | | 12 IDTM (Decision table) 14.46
57 | | 13 Naive-Bayes 16.12
58 | | 14 Nearest-neighbor (1) 21.42
59 | | 15 Nearest-neighbor (3) 20.35
60 | | 16 OC1 15.04
61 | | 17 Pebls Crashed. Unknown why (bounds WERE increased)
62 | |
63 | | Conversion of original data as follows:
64 | | 1. Discretized agrossincome into two ranges with threshold 50,000.
65 | | 2. Convert U.S. to US to avoid periods.
66 | | 3. Convert Unknown to "?"
67 | | 4. Run MLC++ GenCVFiles to generate data,test.
68 | |
69 | | Description of fnlwgt (final weight)
70 | |
71 | | The weights on the CPS files are controlled to independent estimates of the
72 | | civilian noninstitutional population of the US. These are prepared monthly
73 | | for us by Population Division here at the Census Bureau. We use 3 sets of
74 | | controls.
75 | | These are:
76 | | 1. A single cell estimate of the population 16+ for each state.
77 | | 2. Controls for Hispanic Origin by age and sex.
78 | | 3. Controls by Race, age and sex.
79 | |
80 | | We use all three sets of controls in our weighting program and "rake" through
81 | | them 6 times so that by the end we come back to all the controls we used.
82 | |
83 | | The term estimate refers to population totals derived from CPS by creating
84 | | "weighted tallies" of any specified socio-economic characteristics of the
85 | | population.
86 | |
87 | | People with similar demographic characteristics should have
88 | | similar weights. There is one important caveat to remember
89 | | about this statement. That is that since the CPS sample is
90 | | actually a collection of 51 state samples, each with its own
91 | | probability of selection, the statement only applies within
92 | | state.
93 |
94 |
95 | >50K, <=50K.
96 |
97 | age: continuous.
98 | workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
99 | fnlwgt: continuous.
100 | education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
101 | education-num: continuous.
102 | marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
103 | occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
104 | relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
105 | race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
106 | sex: Female, Male.
107 | capital-gain: continuous.
108 | capital-loss: continuous.
109 | hours-per-week: continuous.
110 | native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
111 |
--------------------------------------------------------------------------------
/Data/output.txt:
--------------------------------------------------------------------------------
1 | Start...
2 |
3 | Preprocessing Training data
4 |
5 | Preprocessing Testing data
6 |
7 | Generating Decision Tree using ID3 Algorithm
8 | Training Time=1.979secs
9 | Accuracy=0.807874209200909
10 | Precision=0.8762364294330519 Recall=0.8727272727272727 F-Score=0.874478330658106
11 | No of nodes in tree = 33223
12 |
13 | Applying Reduced Error Pruning on the decision tree generated
14 | Training Time=10.7secs
15 | Accuracy=0.8404889134574043
16 | Precision=0.9467631684760756 Recall=0.8588415523781733 F-Score=0.9006617450177867
17 | No of nodes in tree = 2640
18 |
19 | Initializing Random Forest with 10 trees, 0.5 fraction of attributes and 0.33 fraction of training instances in each tree
20 | Training Time=1.618secs
21 | Accuracy=0.8313371414532277
22 | Precision=0.944270205066345 Recall=0.8511779630300834 F-Score=0.8953107129241327
23 |
24 | End...
25 |
--------------------------------------------------------------------------------
/Data/textbookExample.data:
--------------------------------------------------------------------------------
1 | Sunny,Hot,High,Weak,No
2 | Sunny,Hot,High,Strong,No
3 | Overcast,Hot,High,Weak,Yes
4 | Rain,Mild,High,Weak,Yes
5 | Rain,Cool,Normal,Weak,Yes
6 | Rain,Cool,Normal,Strong,No
7 | Overcast,Cool,Normal,Strong,Yes
8 | Sunny,Mild,High,Weak,No
9 | Sunny,Cool,Normal,Weak,Yes
10 | Rain,Mild,Normal,Weak,Yes
11 | Sunny,Mild,Normal,Strong,Yes
12 | Overcast,Mild,High,Strong,Yes
13 | Overcast,Hot,Normal,Weak,Yes
14 | Rain,Mild,High,Strong,No
--------------------------------------------------------------------------------
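
Note: this is the 14-row PlayTennis dataset (attributes Outlook, Temperature, Humidity, Wind; class label in the last column) from chapter 3 of Tom M. Mitchell's Machine Learning, small enough to verify the ID3 tree by hand.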
/Documentation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/Documentation.pdf
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Keval Morabia
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ID3-Decision-Tree-Classifier-in-Java
2 |
3 | ```
4 | Classes: 1 = <=50K, 2 = >50K
5 |
6 | Attributes
7 | age: continuous.
8 | workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
9 | fnlwgt: continuous.
10 | education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
11 | education-num: continuous.
12 | marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
13 | occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
14 | relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
15 | race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
16 | sex: Female, Male.
17 | capital-gain: continuous.
18 | capital-loss: continuous.
19 | hours-per-week: continuous.
20 | native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
21 |
22 | ```
23 |
24 |
25 | **Procedure:**
26 | 1. A decision tree was generated from the data provided, using the ID3 algorithm described in Tom M. Mitchell's *Machine Learning* (see the entropy/gain sketch below).
27 | 2. Missing values were filled in with the value that appeared most frequently in that attribute's column.
28 | 3. Continuous values were handled as described in section 3.7.2 of Mitchell: the values were sorted in ascending order, gain was calculated at each point where the value changed, and the column was
29 | split at the point where the maximum gain was obtained.
30 | 4. Reduced Error Pruning was performed by turning each internal node into a leaf (one at a time) and re-checking the accuracy: if the accuracy improved, the node stayed pruned; otherwise it was restored and the next node was checked.
31 | 5. A random forest of 10 trees was generated, each tree using a random 50% of the attributes and 33% of the training instances; accuracy increased compared to
32 | the unpruned ID3 tree.
33 |
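The entropy and information-gain computation at the heart of steps 1 and 3 can be sketched as follows. This is a minimal, self-contained illustration; the class and method names (`GainSketch`, `entropy`, `informationGain`) are ours, not the repository's, and the real implementation in `src/ID3.java` uses the natural log, which ranks splits identically:

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class GainSketch {
    //entropy of a two-class sample with c1 instances of one class and c2 of the other
    static double entropy(long c1, long c2) {
        double total = c1 + c2, e = 0;
        if (c1 > 0) e -= (c1 / total) * (Math.log(c1 / total) / Math.log(2));
        if (c2 > 0) e -= (c2 / total) * (Math.log(c2 / total) / Math.log(2));
        return e;
    }

    //Gain(S, A) = Entropy(S) - sum over each value v of A of (|S_v|/|S|) * Entropy(S_v)
    static double informationGain(List<String[]> rows, int attr, String posLabel) {
        int labelCol = rows.get(0).length - 1;
        long pos = rows.stream().filter(r -> r[labelCol].equals(posLabel)).count();
        double gain = entropy(pos, rows.size() - pos);
        Map<String, List<String[]>> byValue =
                rows.stream().collect(Collectors.groupingBy(r -> r[attr]));
        for (List<String[]> sv : byValue.values()) {
            long p = sv.stream().filter(r -> r[labelCol].equals(posLabel)).count();
            gain -= ((double) sv.size() / rows.size()) * entropy(p, sv.size() - p);
        }
        return gain;
    }

    public static void main(String[] args) {
        //two rows in the format of Data/textbookExample.data, just to exercise the methods
        List<String[]> rows = List.of(
                new String[]{"Sunny", "Hot", "High", "Weak", "No"},
                new String[]{"Overcast", "Hot", "High", "Weak", "Yes"});
        //splitting on attribute 0 (Outlook) separates the two labels, so the gain is 1.0 bit
        System.out.println(informationGain(rows, 0, "Yes"));
    }
}
```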
34 |
35 |
36 | ```
37 | Output:
38 | Start...
39 |
40 | Preprocessing Training data
41 |
42 | Preprocessing Testing data
43 |
44 | Generating Decision Tree using ID3 Algorithm
45 | Training Time=1.979secs
46 | Accuracy=0.807874209200909
47 | Precision=0.8762364294330519 Recall=0.8727272727272727 F-Score=0.874478330658106
48 | No of nodes in tree = 33223
49 |
50 | Applying Reduced Error Pruning on the decision tree generated
51 | Training Time=10.7secs
52 | Accuracy=0.8404889134574043
53 | Precision=0.9467631684760756 Recall=0.8588415523781733 F-Score=0.9006617450177867
54 | No of nodes in tree = 2640
55 |
56 | Initializing Random Forest with 10 trees, 0.5 fraction of attributes and 0.33 fraction of training instances in each tree
57 | Training Time=1.618secs
58 | Accuracy=0.8313371414532277
59 | Precision=0.944270205066345 Recall=0.8511779630300834 F-Score=0.8953107129241327
60 |
61 | End...
62 | ```
63 |
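**Running (assumed invocation, not documented in the repository):** the entry point is `Main`, which reads `Data/train.data` and `Data/test.data` through relative paths, so compile and run from the repository root:

```
javac -d bin src/*.java
java -cp bin Main
```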
--------------------------------------------------------------------------------
/bin/Comp.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/Comp.class
--------------------------------------------------------------------------------
/bin/ID3.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/ID3.class
--------------------------------------------------------------------------------
/bin/Main.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/Main.class
--------------------------------------------------------------------------------
/bin/Node.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/Node.class
--------------------------------------------------------------------------------
/bin/Preprocess.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/Preprocess.class
--------------------------------------------------------------------------------
/bin/RandomForest.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/RandomForest.class
--------------------------------------------------------------------------------
/bin/ReducedErrorPruning.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kevalmorabia97/ID3-Decision-Tree-Classifier-in-Java/f7bb7de6512268b9cd60361d5461cf7be14d1aff/bin/ReducedErrorPruning.class
--------------------------------------------------------------------------------
/src/ID3.java:
--------------------------------------------------------------------------------
1 | import java.util.ArrayList;
2 | import java.util.HashMap;
3 |
4 | public class ID3 {
5 |     HashMap<Integer, ArrayList<String>> discreteValues;
6 |     String class1, class2;
7 |     ArrayList<String[]> data, testData;//data = training data
8 |     Node root;
9 |     double trainingTime=0, precision, recall, fscore, accuracy;
10 |     int noOfNodes;
11 |
12 |     public ID3(ArrayList<String[]> data, ArrayList<String[]> testData, int noOfClass1, int noOfClass2, String class1, String class2, HashMap<Integer, ArrayList<String>> discreteValues, ArrayList<Integer> remAttr){
13 |         trainingTime = System.currentTimeMillis();
14 |
15 |         this.data = data;
16 |         this.class1 = class1;
17 |         this.class2 = class2;
18 |         this.testData = testData;
19 |         this.discreteValues = discreteValues;
20 |
21 |         root = new Node();
22 |         double p1 = noOfClass1/(noOfClass1+noOfClass2+0.0), p2 = noOfClass2/(noOfClass1+noOfClass2+0.0);
23 |         root.entropy = -pLogP(p1) - pLogP(p2);
24 |         root.data = data;
25 |         root.noOfClass1 = noOfClass1;
26 |         root.noOfClass2 = noOfClass2;
27 |
28 |         root.remainingAttributes = remAttr;
29 |         generateDecisionTree(root);
30 |
31 |         trainingTime = (System.currentTimeMillis() - trainingTime)/1000.0;
32 |
33 |         analyse();//calculate precision, recall, fscore, accuracy
34 |     }
35 |
36 |     public void analyse(){
37 |         int correctClassification=0, incorrectClassification=0;
38 |         int truePositive = 0, falsePositive = 0, falseNegative = 0;
39 |         for(String[] s : testData){
40 |             int predicted = Node.predictClass(root, s, discreteValues), actual = s[s.length-1].equals(class1)?1:2;
41 |             if(predicted == actual) correctClassification++;
42 |             else incorrectClassification++;
43 |
44 |             //class 1 is treated as the positive class
45 |             if(predicted==1 && actual==1) truePositive++;
46 |             else if(predicted==1 && actual==2) falsePositive++;//predicted positive, actually negative
47 |             else if(predicted==2 && actual==1) falseNegative++;//predicted negative, actually positive
48 |         }
49 |         precision = truePositive/(truePositive+falsePositive+0.0);
50 |         recall = truePositive/(truePositive+falseNegative+0.0);
51 |         fscore = 2*precision*recall/(precision+recall);
52 |         accuracy = correctClassification/(correctClassification+incorrectClassification+0.0);
53 |
54 |         noOfNodes = 0;
55 |         countNodes(root);
56 |     }
57 |
58 |     public void printAnalysis(){
59 |         System.out.println("Accuracy="+accuracy+"\nPrecision="+precision+" Recall="+recall+" F-Score="+fscore);
60 |         System.out.println("No of nodes in tree = "+noOfNodes);
61 |     }
62 |
63 |     public void printTrainingTime(){
64 |         System.out.println("Training Time="+trainingTime+"secs");
65 |     }
66 |
67 |     public void countNodes(Node root){
68 |         if(root==null) return;
69 |         noOfNodes++;
70 |         if(root.isLeaf) return;
71 |         for(Node n : root.children) countNodes(n);
72 |     }
73 |
74 |     private static double pLogP(double p){
75 |         return p==0?0:p*Math.log(p);//guard against 0*log(0) = NaN
76 |     }
77 |
78 |     private void generateDecisionTree(Node root){
79 |         if(root==null) return;
80 |         if(root.remainingAttributes.size()==1){//only one attribute left: stop and take the majority class
81 |             root.isLeaf = true;
82 |             root.classification = root.noOfClass1>=root.noOfClass2?1:2;
83 |         }else if(root.noOfClass1==0 || root.noOfClass2==0 || root.data.size()==0){//pure or empty node: leaf
84 |             root.isLeaf = true;
85 |             root.classification = root.noOfClass1==0?2:1;
86 |         }else{
87 |             //split on the attribute which gives max gain
88 |             root.children = splitAttribute(root);
89 |             ArrayList<String> discreteValuesOfThisAttribute = discreteValues.get(root.attribute);
90 |             for(int j=0; j < discreteValuesOfThisAttribute.size(); j++){
91 |                 root.children[j].data = new ArrayList<>();
92 |                 root.children[j].remainingAttributes = new ArrayList<>();
93 |                 for(int rem : root.remainingAttributes){
94 |                     if(rem!=root.attribute) root.children[j].remainingAttributes.add(rem);
95 |                 }
96 |                 String curr = discreteValuesOfThisAttribute.get(j);
97 |                 for(String[] s : root.data){
98 |                     if(s[root.attribute].equals(curr)){
99 |                         root.children[j].data.add(s);
100 |                     }
101 |                 }
102 |                 generateDecisionTree(root.children[j]);
103 |             }
104 |         }
105 |     }
106 |
107 |     public Node[] splitAttribute(Node root){
108 |         double maxGain = -1.0;
109 |         Node[] ans = null;
110 |         for(int i : root.remainingAttributes){
111 |             ArrayList<String> discreteValuesOfThisAttribute = discreteValues.get(i);
112 |             Node[] child = new Node[discreteValuesOfThisAttribute.size()];
113 |             for(int j=0; j < discreteValuesOfThisAttribute.size(); j++){
114 |                 String curr = discreteValuesOfThisAttribute.get(j);
115 |                 child[j] = new Node();
116 |                 for(String[] s : root.data){
117 |                     if(s[i].equals(curr)){
118 |                         if(s[s.length-1].equals(class1)) child[j].noOfClass1++;
119 |                         else child[j].noOfClass2++;
120 |                     }
121 |                 }
122 |             }
123 |             int total = root.data.size();
124 |             double gain = root.entropy;
125 |             for(int j = 0; j < discreteValuesOfThisAttribute.size(); j++){
126 |                 int c1 = child[j].noOfClass1, c2 = child[j].noOfClass2;
127 |                 if(c1==0 && c2==0) continue;
128 |                 double p1 = c1/(c1+c2+0.0), p2 = c2/(c1+c2+0.0);
129 |                 child[j].entropy = -pLogP(p1) - pLogP(p2);
130 |                 gain -= ((c1+c2)/(total+0.0))*child[j].entropy;
131 |             }
132 |             if(gain > maxGain){
133 |                 root.attribute = i;
134 |                 maxGain = gain;
135 |                 ans = child;
136 |             }
137 |         }
138 |         return ans;
139 |     }
140 | }
141 |
--------------------------------------------------------------------------------
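
Note: a hypothetical driver (not part of the repository) showing how the ID3 class above can be exercised directly on Data/textbookExample.data, which contains only discrete attributes and therefore needs none of the Preprocess steps:

import java.io.BufferedReader;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.HashMap;

public class TextbookDemo {
    public static void main(String[] args) throws Exception {
        //load the 14 comma-separated rows; the last column is the Yes/No class label
        ArrayList<String[]> data = new ArrayList<>();
        int yes = 0, no = 0;
        BufferedReader br = new BufferedReader(new FileReader("Data/textbookExample.data"));
        for (String line; (line = br.readLine()) != null; ) {
            String[] row = line.split(",");
            data.add(row);
            if (row[row.length - 1].equals("Yes")) yes++; else no++;
        }
        br.close();

        //collect the discrete values seen in each attribute column
        HashMap<Integer, ArrayList<String>> discreteValues = new HashMap<>();
        ArrayList<Integer> remAttr = new ArrayList<>();
        for (int i = 0; i < data.get(0).length - 1; i++) {
            remAttr.add(i);
            ArrayList<String> vals = new ArrayList<>();
            for (String[] row : data) if (!vals.contains(row[i])) vals.add(row[i]);
            discreteValues.put(i, vals);
        }

        //train and evaluate on the same 14 rows (an illustration only, not a fair evaluation)
        ID3 tree = new ID3(data, data, yes, no, "Yes", "No", discreteValues, remAttr);
        tree.printAnalysis();
    }
}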
/src/Main.java:
--------------------------------------------------------------------------------
1 | import java.io.File;
2 | import java.io.IOException;
3 | import java.util.ArrayList;
4 |
5 | public class Main {
6 |
7 |     public static void main(String[] args) throws IOException {
8 |         System.out.println("Start...");
9 |
10 |         System.out.println("\nPreprocessing Training data");
11 |         String class1 = "<=50K", class2 = ">50K";
12 |         ArrayList<Integer> continuousAttributes = new ArrayList<>();
13 |         continuousAttributes.add(0);//age
14 |         continuousAttributes.add(2);//fnlwgt
15 |         continuousAttributes.add(4);//education-num
16 |         continuousAttributes.add(10);//capital-gain
17 |         continuousAttributes.add(11);//capital-loss
18 |         continuousAttributes.add(12);//hours-per-week
19 |
20 |         Preprocess p = new Preprocess(class1, class2, continuousAttributes);
21 |         ArrayList<String[]> trainData = p.discretize(new File("Data/train.data"));
22 |         int noOfClass1 = p.c1, noOfClass2 = p.c2;
23 |         Preprocess.predictMissingValues(trainData, "train");
24 |
25 |         Preprocess.computeDiscreteValues(trainData);
26 |
27 |         System.out.println("\nPreprocessing Testing data");
28 |         ArrayList<String[]> testData = Preprocess.preprocessTestData(new File("Data/test.data"), continuousAttributes);
29 |
30 |         ArrayList<Integer> remAttr = new ArrayList<>();
31 |         for(int i = 0; i < trainData.get(0).length-1; i++) remAttr.add(i);
32 |         System.out.println("\nGenerating Decision Tree using ID3 Algorithm");
33 |         //<=50K = class1, >50K = class2
34 |         ID3 decisionTree = new ID3(trainData, testData, noOfClass1, noOfClass2, class1, class2, Preprocess.discreteValues, remAttr);
35 |         decisionTree.printTrainingTime();
36 |         decisionTree.printAnalysis();
37 |
38 |         System.out.println("\nApplying Reduced Error Pruning on the decision tree generated");
39 |         //to preserve the original tree, train another ID3 instance and pass that instead
40 |         ReducedErrorPruning rep = new ReducedErrorPruning(decisionTree);
41 |         rep.tree.printAnalysis();
42 |
43 |         int noOfTrees = 10;
44 |         double fractionOfAttributesToTake = 0.5, fractionOfTrainingInstancesToTake = 0.33;
45 |         System.out.println("\nInitializing Random Forest with "+noOfTrees+" trees, "+fractionOfAttributesToTake
46 |                 +" fraction of attributes and "+fractionOfTrainingInstancesToTake+" fraction of training instances in each tree");
47 |         RandomForest rf = new RandomForest(noOfTrees, fractionOfAttributesToTake, fractionOfTrainingInstancesToTake,
48 |                 trainData, testData, noOfClass1, noOfClass2, class1, class2, Preprocess.discreteValues);
49 |         rf.printAnalysis();
50 |
51 |         System.out.println("\nEnd...");
52 |     }
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/src/Node.java:
--------------------------------------------------------------------------------
1 | import java.util.ArrayList;
2 | import java.util.HashMap;
3 |
4 | public class Node {
5 |     int attribute;//attribute on which this node splits
6 |     ArrayList<Integer> remainingAttributes;
7 |     int noOfClass1, noOfClass2;//class1 = <=50K, class2 = >50K
8 |     double entropy;
9 |     Node[] children;
10 |     ArrayList<String[]> data;//instances satisfying the conditions of this node and all its ancestors
11 |     boolean isLeaf = false;
12 |     int classification;//applicable if leaf: 1 or 2 for class1 and class2 respectively
13 |
14 |     public Node(int attr, int c1, int c2, double entropy, ArrayList<String[]> data, ArrayList<Integer> remainingAttributes){
15 |         attribute = attr;
16 |         noOfClass1 = c1;
17 |         noOfClass2 = c2;
18 |         this.entropy = entropy;
19 |         this.data = data;
20 |         this.remainingAttributes = remainingAttributes;
21 |     }
22 |
23 |     public Node(int classification){
24 |         isLeaf = true;
25 |         this.classification = classification;
26 |     }
27 |
28 |     public Node(){
29 |
30 |     }
31 |
32 |     public String toString(){
33 |         return isLeaf ? "Leaf,Classification="+classification : "Entropy="+entropy+",Split="+attribute;
34 |     }
35 |
36 |     //walk down the tree along the branch matching the instance's value for each split attribute
37 |     public static int predictClass(Node root, String[] data, HashMap<Integer, ArrayList<String>> discreteValues){
38 |         if(root==null) return 1;
39 |         else if(root.isLeaf) return root.classification;
40 |         String s = data[root.attribute];
41 |         ArrayList<String> discrVals = discreteValues.get(root.attribute);
42 |         for(int i = 0; i < discrVals.size(); i++){
43 |             if(s.equals(discrVals.get(i))){
44 |                 return predictClass(root.children[i], data, discreteValues);
45 |             }
46 |         }
47 |         return 1;//value never seen during training: fall back to class 1
48 |     }
49 |
50 |     public static void inOrder(Node root){
51 |         if(root==null) return;
52 |         System.out.println(root);
53 |         if(root.isLeaf) return;
54 |         for(Node c : root.children) inOrder(c);
55 |     }
56 | }
57 |
--------------------------------------------------------------------------------
/src/Preprocess.java:
--------------------------------------------------------------------------------
1 | import java.io.BufferedReader;
2 | import java.io.BufferedWriter;
3 | import java.io.File;
4 | import java.io.FileReader;
5 | import java.io.FileWriter;
6 | import java.io.IOException;
7 | import java.util.ArrayList;
8 | import java.util.Arrays;
9 | import java.util.Collections;
10 | import java.util.Comparator;
11 | import java.util.HashMap;
12 | import java.util.StringTokenizer;
13 |
14 | //Discretize each continuous attribute into 2 bins, <=val and >val, choosing val so that the resulting entropy is minimum
15 | public class Preprocess {
16 |     int c1, c2;
17 |     static int[] partitionAt;//the value at which each continuous attribute is partitioned
18 |     String class1, class2;
19 |     ArrayList<Integer> continuousAttributes;
20 |     /*distinct values observed for each attribute index
21 |      * e.g.: 0, [<=19, >19]
22 |      *       1, [Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked]
23 |      */
24 |     static HashMap<Integer, ArrayList<String>> discreteValues = new HashMap<>();
25 |
26 |     public Preprocess(String class1, String class2, ArrayList<Integer> continuousAttributes){
27 |         this.class1 = class1;
28 |         this.class2 = class2;
29 |         this.continuousAttributes = continuousAttributes;
30 |     }
31 |
32 |     public ArrayList<String[]> discretize(File trainOriginal) throws IOException{
33 |         BufferedReader br = new BufferedReader(new FileReader(trainOriginal));
34 |
35 |         String s = br.readLine();
36 |         br.close();
37 |         StringTokenizer st = new StringTokenizer(s, ",");
38 |         int noOfAttributes = st.countTokens();
39 |         partitionAt = new int[noOfAttributes];
40 |         ArrayList<String[]> data = new ArrayList<>();
41 |         br = new BufferedReader(new FileReader(trainOriginal));
42 |         while((s = br.readLine())!=null){
43 |             st = new StringTokenizer(s, ",");
44 |             String[] dataset = new String[noOfAttributes];
45 |             for(int i = 0; i < noOfAttributes; i++){
46 |                 dataset[i] = st.nextToken();
47 |             }
48 |             if(dataset[noOfAttributes-1].equals(class1)) c1++;
49 |             else c2++;
50 |             data.add(dataset);
51 |         }
52 |         br.close();
53 |         for(int i : continuousAttributes){
54 |             int c1LeftPart = 0, c1RightPart = c1;
55 |             int c2LeftPart = 0, c2RightPart = c2;
56 |             int total = c1+c2;
57 |             Collections.sort(data, new Comp(i));
58 |             double minEntropy = Double.MAX_VALUE;
59 |             int partAt = 1;
60 |             String prev = data.get(0)[i];
61 |             for(String[] dataset : data){
62 |                 if(dataset[noOfAttributes-1].equals(class1)){
63 |                     c1LeftPart++;
64 |                     c1RightPart--;
65 |                 }else{
66 |                     c2LeftPart++;
67 |                     c2RightPart--;
68 |                 }
69 |                 String curr = dataset[i];
70 |                 if(curr.equals(prev) || curr.equals("?")) continue;
71 |                 //weighted entropy of the two partitions induced by splitting between prev and curr
72 |                 double p11 = (c1LeftPart+0.0)/(c1LeftPart+c2LeftPart), p12 = (c2LeftPart+0.0)/(c1LeftPart+c2LeftPart);
73 |                 double p21 = (c1RightPart+0.0)/(c1RightPart+c2RightPart), p22 = (c2RightPart+0.0)/(c1RightPart+c2RightPart);
74 |                 double entropy = ((c1LeftPart+c2LeftPart)/(total+0.0))*(-pLogP(p11) - pLogP(p12)) + ((c1RightPart+c2RightPart)/(total+0.0))*(-pLogP(p21) - pLogP(p22));
75 |                 if(entropy < minEntropy){
76 |                     minEntropy = entropy;
77 |                     partAt = (Integer.parseInt(prev)+Integer.parseInt(curr))/2;
78 |                     partitionAt[i] = partAt;
79 |                 }
80 |                 prev = curr;
81 |             }
82 |             String newc1Name = "<="+partAt, newc2Name = ">"+partAt;
83 |             for(String[] dataset : data){
84 |                 if(dataset[i].equals("?")) continue;
85 |                 if(Integer.parseInt(dataset[i]) <= partAt) dataset[i] = newc1Name;
86 |                 else dataset[i] = newc2Name;
87 |             }
88 |         }
89 |         return data;
90 |     }
91 |
92 |     private static double pLogP(double p){
93 |         return p==0?0:p*Math.log(p);//guard against 0*log(0) = NaN
94 |     }
95 |
96 |     /*
97 |      * data must already be discretized
98 |      * replaces each "?" with the value that occurs most often in that attribute's column,
99 |      * then writes the result to Data/<trainOrTest>_data_preprocessed.data
100 |      */
101 |     public static void predictMissingValues(ArrayList<String[]> data, String trainOrTest) throws IOException{
102 |         int n = data.get(0).length-1;
103 |         String[] predictedValue = new String[n];
104 |         for(int i = 0; i < n; i++){
105 |             HashMap<String, Integer> count = new HashMap<>();
106 |             for(String[] s : data){
107 |                 if(s[i].equals("?")) continue;
108 |                 if(count.containsKey(s[i])) count.put(s[i], count.get(s[i])+1);
109 |                 else count.put(s[i], 1);
110 |             }
111 |             int max = 0;
112 |             for(String s : count.keySet()){
113 |                 int c = count.get(s);
114 |                 if(c > max){
115 |                     max = c;
116 |                     predictedValue[i] = s;
117 |                 }
118 |             }
119 |         }
120 |
121 |         for(String[] s : data){
122 |             for(int i = 0; i < n; i++){
123 |                 if(s[i].equals("?")) s[i] = predictedValue[i];
124 |             }
125 |         }
126 |
127 |         File trainDiscretized = new File("Data/"+trainOrTest+"_data_preprocessed.data");
128 |         BufferedWriter bw = new BufferedWriter(new FileWriter(trainDiscretized));
129 |         for(String[] dataset : data){
130 |             bw.write(Arrays.toString(dataset).replace("[", "").replace("]", "").replace(" ", "")+"\n");
131 |         }
132 |         bw.close();
133 |     }
134 |
135 |     //use the partitionAt thresholds calculated while discretizing the train data
136 |     public static ArrayList<String[]> preprocessTestData(File test, ArrayList<Integer> continuousAttributes) throws IOException{
137 |         BufferedReader br = new BufferedReader(new FileReader(test));
138 |
139 |         String s = br.readLine();
140 |         br.close();
141 |         StringTokenizer st = new StringTokenizer(s, ",");
142 |         int noOfAttributes = st.countTokens();
143 |         ArrayList<String[]> data = new ArrayList<>();
144 |         br = new BufferedReader(new FileReader(test));
145 |         while((s = br.readLine())!=null){
146 |             st = new StringTokenizer(s, ",");
147 |             String[] dataset = new String[noOfAttributes];
148 |             for(int i = 0; i < noOfAttributes; i++){
149 |                 dataset[i] = st.nextToken();
150 |             }
151 |             data.add(dataset);
152 |         }
153 |         br.close();
154 |         for(int i : continuousAttributes){
155 |             String newc1Name = "<="+partitionAt[i], newc2Name = ">"+partitionAt[i];
156 |             for(String[] dataset : data){
157 |                 if(dataset[i].equals("?")) continue;
158 |                 if(Integer.parseInt(dataset[i]) <= partitionAt[i]) dataset[i] = newc1Name;
159 |                 else dataset[i] = newc2Name;
160 |             }
161 |         }
162 |         predictMissingValues(data, "test");
163 |         return data;
164 |     }
165 |
166 |     public static void computeDiscreteValues(ArrayList<String[]> data){
167 |         String[] sampleDataset = data.get(0);
168 |         for(int i = 0; i < sampleDataset.length-1; i++){
169 |             discreteValues.put(i, new ArrayList<String>());
170 |         }
171 |         for(String[] dataset : data){
172 |             for(int i = 0; i < dataset.length-1; i++){//last column is the class label
173 |                 String attribute = dataset[i];
174 |                 if(attribute.equals("?")) continue;
175 |                 ArrayList<String> seen = discreteValues.get(i);
176 |                 //track observed values per column (a single global set would drop a value shared by two columns)
177 |                 if(!seen.contains(attribute)) seen.add(attribute);
178 |             }
179 |         }
180 |     }
181 | }
182 |
183 | //orders rows numerically by one column; "?" (missing) sorts after every number
184 | class Comp implements Comparator<String[]>{
185 |     int index;
186 |     public Comp(int i){index = i;}
187 |     public int compare(String[] s1, String[] s2){
188 |         boolean m1 = s1[index].equals("?"), m2 = s2[index].equals("?");
189 |         if(m1 || m2) return Boolean.compare(m1, m2);
190 |         return Integer.compare(Integer.parseInt(s1[index]), Integer.parseInt(s2[index]));
191 |     }
192 | }
--------------------------------------------------------------------------------
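
Note: a worked micro-example of the threshold search in Preprocess.discretize, with made-up numbers. For a column whose sorted values are 19, 19, 20, 25 with class labels No, No, Yes, Yes, the weighted entropy is evaluated only where the value changes (between 19 and 20, and between 20 and 25); the first candidate threshold is (19+20)/2 = 19, and since the split "<=19 / >19" separates the labels perfectly (weighted entropy 0), every value in the column is rewritten as "<=19" or ">19".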
/src/RandomForest.java:
--------------------------------------------------------------------------------
1 | import java.util.ArrayList;
2 | import java.util.HashMap;
3 | import java.util.Random;
4 |
5 | public class RandomForest {
6 |     int noOfTrees;
7 |     HashMap<Integer, ArrayList<String>> discreteValues;
8 |     String class1, class2;
9 |     ArrayList<String[]> data, testData;//data = training data
10 |     double trainingTime=0, precision, recall, fscore, accuracy;
11 |     ID3[] tree;
12 |
13 |     public RandomForest(int noOfTrees, double fractionOfAttributesToTake, double fractionOfInstancesToTake, ArrayList<String[]> data, ArrayList<String[]> testData, int noOfClass1, int noOfClass2,
14 |             String class1, String class2, HashMap<Integer, ArrayList<String>> discreteValues){
15 |         if(fractionOfAttributesToTake>1 || fractionOfAttributesToTake<0 || fractionOfInstancesToTake>1 || fractionOfInstancesToTake<0){
16 |             System.out.println("Random Forest input invalid");
17 |             return;
18 |         }
19 |
20 |         this.noOfTrees = noOfTrees;
21 |         this.data = data;
22 |         this.class1 = class1;
23 |         this.class2 = class2;
24 |         this.testData = testData;
25 |         this.discreteValues = discreteValues;
26 |
27 |         int noOfAttributes = data.get(0).length-1, noOfTrainingInstances = data.size();
28 |         int noOfRandomInstances = (int)(noOfTrainingInstances*fractionOfInstancesToTake);
29 |         int noOfRandomAttributes = (int)(noOfAttributes*fractionOfAttributesToTake);
30 |
31 |         String[][] dataInArray = new String[noOfTrainingInstances][];//to access random rows in constant time
32 |         int x = 0;
33 |         for(String[] s : data){
34 |             dataInArray[x++] = s;
35 |         }
36 |
37 |         trainingTime = System.currentTimeMillis();
38 |
39 |         Random rand = new Random();
40 |         tree = new ID3[noOfTrees];
41 |         for(int i = 0; i < noOfTrees; i++){
42 |             //sample attributes without replacement
43 |             ArrayList<Integer> remAttr = new ArrayList<>();
44 |             for(int j = 0; j < noOfRandomAttributes; j++){
45 |                 int r = rand.nextInt(noOfAttributes);
46 |                 if(remAttr.contains(r)) j--;
47 |                 else remAttr.add(r);
48 |             }
49 |             //sample training instances with replacement (bootstrap)
50 |             ArrayList<String[]> randData = new ArrayList<>();
51 |             for(int j = 0; j < noOfRandomInstances; j++){
52 |                 randData.add(dataInArray[rand.nextInt(noOfTrainingInstances)]);
53 |             }
54 |             //note: the class counts passed here are for the full training set, which only offsets the root entropy by a constant and does not affect which attribute is chosen
55 |             tree[i] = new ID3(randData, testData, noOfClass1, noOfClass2, class1, class2, discreteValues, remAttr);
56 |         }
57 |
58 |         trainingTime = (System.currentTimeMillis() - trainingTime)/1000.0;
59 |
60 |         analyse();
61 |     }
62 |
63 |     public void analyse(){
64 |         if(tree==null) return;//invalid inputs
65 |         int correctClassification=0, incorrectClassification=0;
66 |         int truePositive = 0, falsePositive = 0, falseNegative = 0;
67 |         for(String[] s : testData){
68 |             //majority vote over all trees
69 |             int votesClass1=0, votesClass2=0;
70 |             for(int i = 0; i < noOfTrees; i++){
71 |                 if(Node.predictClass(tree[i].root, s, discreteValues)==1) votesClass1++;
72 |                 else votesClass2++;
73 |             }
74 |             int predicted = votesClass1>=votesClass2?1:2, actual = s[s.length-1].equals(class1)?1:2;
75 |             if(predicted == actual) correctClassification++;
76 |             else incorrectClassification++;
77 |
78 |             //class 1 is treated as the positive class
79 |             if(predicted==1 && actual==1) truePositive++;
80 |             else if(predicted==1 && actual==2) falsePositive++;//predicted positive, actually negative
81 |             else if(predicted==2 && actual==1) falseNegative++;//predicted negative, actually positive
82 |         }
83 |         precision = truePositive/(truePositive+falsePositive+0.0);
84 |         recall = truePositive/(truePositive+falseNegative+0.0);
85 |         fscore = 2*precision*recall/(precision+recall);
86 |         accuracy = correctClassification/(correctClassification+incorrectClassification+0.0);
87 |     }
88 |
89 |     public void printAnalysis(){
90 |         System.out.println("Training Time="+trainingTime+"secs");
91 |         System.out.println("Accuracy="+accuracy+"\nPrecision="+precision+" Recall="+recall+" F-Score="+fscore);
92 |     }
93 | }
94 |
--------------------------------------------------------------------------------
/src/ReducedErrorPruning.java:
--------------------------------------------------------------------------------
1 |
2 | public class ReducedErrorPruning {
3 |     ID3 tree;
4 |     double trainingTime;
5 |     double initialAccuracy, maxAccuracy;
6 |
7 |     public ReducedErrorPruning(ID3 tree){
8 |         this.tree = tree;
9 |         initialAccuracy = tree.accuracy;
10 |         maxAccuracy = initialAccuracy;
11 |         trainingTime = System.currentTimeMillis();
12 |         pruneTree(tree.root);
13 |         trainingTime = (System.currentTimeMillis() - trainingTime)/1000.0;
14 |         System.out.println("Training Time="+trainingTime+"secs");
15 |     }
16 |
17 |     //tentatively turn root into a leaf; keep the pruning if accuracy improves, otherwise undo it and recurse into the children
18 |     private void pruneTree(Node root){
19 |         if(root==null) return;
20 |         root.isLeaf = true;
21 |         root.classification = root.noOfClass1>=root.noOfClass2?1:2;
22 |         tree.analyse();
23 |         if(tree.accuracy > maxAccuracy){
24 |             maxAccuracy = tree.accuracy;
25 |             return;
26 |         }
27 |         root.isLeaf = false;
28 |         for(Node c : root.children){
29 |             if(c.isLeaf) continue;
30 |             pruneTree(c);
31 |         }
32 |     }
33 | }
34 |
--------------------------------------------------------------------------------