├── .gitignore ├── LICENSE ├── README.md └── machine-learning ├── doc ├── allclasses-frame.html ├── allclasses-noframe.html ├── bayes │ ├── BNConditionalQuery.html │ ├── BNDataGenerator.html │ ├── BNEvaluator.html │ ├── BNJointQuery.html │ ├── BNNode.html │ ├── BNNodeManager.html │ ├── BNResultWriter.html │ ├── BNUtility.html │ ├── BayesianNetwork.Type.html │ ├── BayesianNetwork.html │ ├── VariableSet.html │ ├── builders │ │ ├── HillClimbingBuilder.StoppingCriteria.html │ │ ├── HillClimbingBuilder.html │ │ ├── NaiveBayesBuilder.html │ │ ├── NetworkBuilder.html │ │ ├── Operation.Type.html │ │ ├── Operation.html │ │ ├── SparseCandidateBuilder.KlEdgeScorePair.html │ │ ├── SparseCandidateBuilder.html │ │ ├── TANBuilder.html │ │ ├── class-use │ │ │ ├── HillClimbingBuilder.StoppingCriteria.html │ │ │ ├── HillClimbingBuilder.html │ │ │ ├── NaiveBayesBuilder.html │ │ │ ├── NetworkBuilder.html │ │ │ ├── Operation.Type.html │ │ │ ├── Operation.html │ │ │ ├── SparseCandidateBuilder.KlEdgeScorePair.html │ │ │ ├── SparseCandidateBuilder.html │ │ │ └── TANBuilder.html │ │ ├── package-frame.html │ │ ├── package-summary.html │ │ ├── package-tree.html │ │ ├── package-use.html │ │ └── scoring │ │ │ ├── BIC.html │ │ │ ├── ScoringFunction.html │ │ │ ├── class-use │ │ │ ├── BIC.html │ │ │ └── ScoringFunction.html │ │ │ ├── package-frame.html │ │ │ ├── package-summary.html │ │ │ ├── package-tree.html │ │ │ └── package-use.html │ ├── class-use │ │ ├── BNConditionalQuery.html │ │ ├── BNDataGenerator.html │ │ ├── BNEvaluator.html │ │ ├── BNJointQuery.html │ │ ├── BNNode.html │ │ ├── BNNodeManager.html │ │ ├── BNResultWriter.html │ │ ├── BNUtility.html │ │ ├── BayesianNetwork.Type.html │ │ ├── BayesianNetwork.html │ │ └── VariableSet.html │ ├── classifiers │ │ ├── NaiveBayesClassifier.html │ │ ├── class-use │ │ │ └── NaiveBayesClassifier.html │ │ ├── package-frame.html │ │ ├── package-summary.html │ │ ├── package-tree.html │ │ └── package-use.html │ ├── cpd │ │ ├── CPDLeaf.html │ │ ├── 
CPDNode.html │ │ ├── CPDQuery.html │ │ ├── CPDTree.ToStringHelper.html │ │ ├── CPDTree.html │ │ ├── CPDTreeBuilder.html │ │ ├── Split.html │ │ ├── SplitBranch.html │ │ ├── class-use │ │ │ ├── CPDLeaf.html │ │ │ ├── CPDNode.html │ │ │ ├── CPDQuery.html │ │ │ ├── CPDTree.ToStringHelper.html │ │ │ ├── CPDTree.html │ │ │ ├── CPDTreeBuilder.html │ │ │ ├── Split.html │ │ │ └── SplitBranch.html │ │ ├── package-frame.html │ │ ├── package-summary.html │ │ ├── package-tree.html │ │ └── package-use.html │ ├── information │ │ ├── KLDivergence.html │ │ ├── class-use │ │ │ └── KLDivergence.html │ │ ├── package-frame.html │ │ ├── package-summary.html │ │ ├── package-tree.html │ │ └── package-use.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html ├── common │ ├── classification │ │ ├── ClassificationResult.html │ │ ├── Classifier.html │ │ ├── class-use │ │ │ ├── ClassificationResult.html │ │ │ └── Classifier.html │ │ ├── package-frame.html │ │ ├── package-summary.html │ │ ├── package-tree.html │ │ └── package-use.html │ └── kfold │ │ ├── KFoldCreator.html │ │ ├── class-use │ │ └── KFoldCreator.html │ │ ├── package-frame.html │ │ ├── package-summary.html │ │ ├── package-tree.html │ │ └── package-use.html ├── constant-values.html ├── deprecated-list.html ├── graph │ ├── dag │ │ ├── DetectCycles.html │ │ ├── TopologicalSort.html │ │ ├── class-use │ │ │ ├── DetectCycles.html │ │ │ └── TopologicalSort.html │ │ ├── package-frame.html │ │ ├── package-summary.html │ │ ├── package-tree.html │ │ └── package-use.html │ └── prim │ │ ├── Edge.html │ │ ├── Prim.html │ │ ├── class-use │ │ ├── Edge.html │ │ └── Prim.html │ │ ├── package-frame.html │ │ ├── package-summary.html │ │ ├── package-tree.html │ │ └── package-use.html ├── help-doc.html ├── index-files │ ├── index-1.html │ ├── index-10.html │ ├── index-11.html │ ├── index-12.html │ ├── index-13.html │ ├── index-14.html │ ├── index-15.html │ ├── index-16.html │ ├── index-17.html │ ├── 
index-18.html │ ├── index-19.html │ ├── index-2.html │ ├── index-20.html │ ├── index-21.html │ ├── index-22.html │ ├── index-23.html │ ├── index-3.html │ ├── index-4.html │ ├── index-5.html │ ├── index-6.html │ ├── index-7.html │ ├── index-8.html │ └── index-9.html ├── index.html ├── main │ ├── DTMain.html │ ├── MainHillClimbing.html │ ├── MainNaiveBayes.html │ ├── MainSparseCandidate.html │ ├── class-use │ │ ├── DTMain.html │ │ ├── MainHillClimbing.html │ │ ├── MainNaiveBayes.html │ │ └── MainSparseCandidate.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html ├── overview-frame.html ├── overview-summary.html ├── overview-tree.html ├── package-list ├── resources │ ├── background.gif │ ├── tab.gif │ ├── titlebar.gif │ └── titlebar_end.gif ├── stylesheet.css └── tree │ ├── DecisionTree.TreePrinter.html │ ├── DecisionTree.html │ ├── DtLeaf.html │ ├── DtNode.html │ ├── ID3Builder.html │ ├── Node.html │ ├── class-use │ ├── DecisionTree.TreePrinter.html │ ├── DecisionTree.html │ ├── DtLeaf.html │ ├── DtNode.html │ ├── ID3Builder.html │ └── Node.html │ ├── classifiers │ ├── ID3TreeClassifier.html │ ├── class-use │ │ └── ID3TreeClassifier.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html │ ├── evaluate │ ├── BiClassTest.html │ ├── BiClassTestResults.html │ ├── class-use │ │ ├── BiClassTest.html │ │ └── BiClassTestResults.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ ├── package-use.html │ └── train │ ├── Bin.html │ ├── Entropy.html │ ├── Split.html │ ├── SplitBranch.html │ ├── SplitGenerator.html │ ├── class-use │ ├── Bin.html │ ├── Entropy.html │ ├── Split.html │ ├── SplitBranch.html │ └── SplitGenerator.html │ ├── package-frame.html │ ├── package-summary.html │ ├── package-tree.html │ └── package-use.html ├── examples └── examples │ 
├── DecisionTreeExample.java │ ├── GraphAlgorithmExamples.java │ └── HMMExamples.java ├── gold_standard ├── bayes │ └── tan.rtf └── tree │ ├── m10.tree │ ├── m2.tree │ ├── m20.tree │ └── m4.tree ├── lib ├── data_structures.jar ├── guava-18.0.jar ├── hamcrest-core-1.3.jar ├── junit-4.12-beta-3.jar ├── lombok.jar └── mockito-all-1.9.5.jar ├── src ├── applications │ ├── MainHillClimbing.java │ ├── MainNaiveBayes.java │ └── MainSparseCandidate.java ├── bayes │ ├── BNConditionalQuery.java │ ├── BNDataGenerator.java │ ├── BNEvaluator.java │ ├── BNJointQuery.java │ ├── BNNode.java │ ├── BNResultWriter.java │ ├── BNStructure.java │ ├── BNUtility.java │ ├── BayesianNetwork.java │ ├── VariableSet.java │ ├── classifiers │ │ └── NaiveBayesClassifier.java │ ├── cpd │ │ ├── CPDLeaf.java │ │ ├── CPDNode.java │ │ ├── CPDQuery.java │ │ ├── CPDTree.java │ │ ├── CPDTreeBuilder.java │ │ ├── Split.java │ │ └── SplitBranch.java │ ├── information │ │ └── KLDivergence.java │ └── structuresearch │ │ ├── HillClimbingBuilder.java │ │ ├── NaiveBayesBuilder.java │ │ ├── NetworkBuilder.java │ │ ├── Operation.java │ │ ├── SparseCandidateBuilder.java │ │ ├── TANBuilder.java │ │ └── score │ │ ├── BIC.java │ │ └── ScoringFunction.java ├── classify │ ├── ClassificationResult.java │ ├── Classifier.java │ └── evaluate │ │ └── PercentageError.java ├── data │ ├── Attribute.java │ ├── AttributeSet.java │ ├── DataSet.java │ ├── Instance.java │ ├── InstanceSet.java │ ├── fold │ │ └── KFoldCreator.java │ └── reader │ │ └── ArffReader.java ├── distributions │ ├── Distribution.java │ └── GeometricDistribution.java ├── graph │ ├── DirectedGraph.java │ ├── Path.java │ ├── bellmanford │ │ └── BellmanFord.java │ ├── dag │ │ ├── DetectCycles.java │ │ └── TopologicalSort.java │ ├── floydwarshall │ │ ├── AllPairsShortestPaths.java │ │ └── FloydWarshall.java │ └── prim │ │ ├── Edge.java │ │ └── Prim.java ├── hmm │ ├── HMM.java │ ├── State.java │ ├── StateContainer.java │ ├── StateParamsTied.java │ ├── 
StateSilent.java │ ├── Transition.java │ └── algorithms │ │ ├── BackwardAlgorithm.java │ │ ├── DpMatrix.java │ │ ├── DpMatrixElement.java │ │ ├── ForwardAlgorithm.java │ │ └── SortSilentStates.java ├── math │ └── LogP.java └── tree │ ├── DecisionTree.java │ ├── DtLeaf.java │ ├── DtNode.java │ ├── Forest.java │ ├── Node.java │ ├── algorithms │ ├── DecisionTreeBuilder.java │ ├── ID3TreeBuilder.java │ └── RandomForestBuilder.java │ ├── classifiers │ ├── ID3TreeClassifier.java │ └── RandomForestClassifier.java │ ├── evaluate │ ├── BiClassTest.java │ └── BiClassTestResults.java │ └── train │ ├── Bin.java │ ├── Entropy.java │ ├── Split.java │ ├── SplitBranch.java │ └── SplitGenerator.java └── tst └── data └── AttributeTest.java /.gitignore: -------------------------------------------------------------------------------- 1 | .metadata/ 2 | bin 3 | .DS_Store 4 | .settings 5 | .classpath 6 | .project 7 | data_sets 8 | gold_standard 9 | results 10 | .recommenders/ 11 | *.class -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | machine-learning 2 | ================ 3 | 4 | A Java library of machine learning algorithms. These are my own implementations that I am working on for practice. 
5 | 6 | Decision Trees: 7 | 8 | ID3 9 | 10 | Bayesian Networks: 11 | 12 | Naïve Bayes classifier 13 | 14 | Tree-Augmented Naïve Bayes (TAN) classifier 15 | 16 | Hill-Climbing Structure Search 17 | 18 | Sparse Candidate Structure Search 19 | 20 | Bayesian Information Criterion (BIC) 21 | 22 | Artificial Data Generation 23 | 24 | Hidden Markov Models: 25 | 26 | Forward Algorithm 27 | 28 | Backward Algorithm 29 | 30 | Generic Graph Theory Algorithms: 31 | 32 | Bellman-Ford Algorithm 33 | 34 | Floyd-Warshall Algorithm 35 | 36 | Prim's Algorithm 37 | 38 | Cycle Detection 39 | 40 | Topological Sort of Vertices in a DAG 41 | 42 | 43 | 44 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/builders/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | bayes.builders 7 | 8 | 9 | 10 | 11 |

bayes.builders

12 |
13 |

Classes

14 | 23 |

Enums

24 | 28 |
29 | 30 | 31 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/builders/scoring/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | bayes.builders.scoring 7 | 8 | 9 | 10 | 11 |

bayes.builders.scoring

12 |
13 |

Interfaces

14 | 17 |

Classes

18 | 21 |
22 | 23 | 24 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/class-use/BNEvaluator.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Class bayes.BNEvaluator 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Class
bayes.BNEvaluator

67 |
68 |
No usage of bayes.BNEvaluator
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/class-use/BNUtility.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Class bayes.BNUtility 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Class
bayes.BNUtility

67 |
68 |
No usage of bayes.BNUtility
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/classifiers/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | bayes.classifiers 7 | 8 | 9 | 10 | 11 |

bayes.classifiers

12 |
13 |

Classes

14 | 17 |
18 | 19 | 20 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/classifiers/package-use.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Package bayes.classifiers 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Package
bayes.classifiers

67 |
68 |
No usage of bayes.classifiers
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/cpd/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | bayes.cpd 7 | 8 | 9 | 10 | 11 |

bayes.cpd

12 |
13 |

Classes

14 | 24 |
25 | 26 | 27 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/information/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | bayes.information 7 | 8 | 9 | 10 | 11 |

bayes.information

12 |
13 |

Classes

14 | 17 |
18 | 19 | 20 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/information/package-use.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Package bayes.information 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Package
bayes.information

67 |
68 |
No usage of bayes.information
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/bayes/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | bayes 7 | 8 | 9 | 10 | 11 |

bayes

12 |
13 |

Classes

14 | 26 |

Enums

27 | 30 |
31 | 32 | 33 | -------------------------------------------------------------------------------- /machine-learning/doc/common/classification/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | common.classification 7 | 8 | 9 | 10 | 11 |

common.classification

12 |
13 |

Interfaces

14 | 17 |

Classes

18 | 21 |
22 | 23 | 24 | -------------------------------------------------------------------------------- /machine-learning/doc/common/kfold/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | common.kfold 7 | 8 | 9 | 10 | 11 |

common.kfold

12 |
13 |

Classes

14 | 17 |
18 | 19 | 20 | -------------------------------------------------------------------------------- /machine-learning/doc/common/kfold/package-use.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Package common.kfold 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Package
common.kfold

67 |
68 |
No usage of common.kfold
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/graph/dag/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | graph.dag 7 | 8 | 9 | 10 | 11 |

graph.dag

12 |
13 |

Classes

14 | 18 |
19 | 20 | 21 | -------------------------------------------------------------------------------- /machine-learning/doc/graph/dag/package-use.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Package graph.dag 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Package
graph.dag

67 |
68 |
No usage of graph.dag
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/graph/prim/class-use/Prim.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Class graph.prim.Prim 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Class
graph.prim.Prim

67 |
68 |
No usage of graph.prim.Prim
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/graph/prim/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | graph.prim 7 | 8 | 9 | 10 | 11 |

graph.prim

12 |
13 |

Classes

14 | 18 |
19 | 20 | 21 | -------------------------------------------------------------------------------- /machine-learning/doc/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Generated Documentation (Untitled) 7 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | <noscript> 61 | <div>JavaScript is disabled on your browser.</div> 62 | </noscript> 63 | <h2>Frame Alert</h2> 64 | <p>This document is designed to be viewed using the frames feature. If you see this message, you are using a non-frame-capable web client. Link to <a href="overview-summary.html">Non-frame version</a>.</p> 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /machine-learning/doc/main/class-use/DTMain.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Class main.DTMain 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Class
main.DTMain

67 |
68 |
No usage of main.DTMain
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/main/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | main 7 | 8 | 9 | 10 | 11 |

main

12 |
13 |

Classes

14 | 20 |
21 | 22 | 23 | -------------------------------------------------------------------------------- /machine-learning/doc/main/package-use.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Package main 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Package
main

67 |
68 |
No usage of main
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/overview-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Overview List 7 | 8 | 9 | 10 | 11 |
All Classes
12 |
13 |

Packages

14 | 35 |
36 |

 

37 | 38 | 39 | -------------------------------------------------------------------------------- /machine-learning/doc/package-list: -------------------------------------------------------------------------------- 1 | bayes 2 | bayes.builders 3 | bayes.builders.scoring 4 | bayes.classifiers 5 | bayes.cpd 6 | bayes.information 7 | common.classification 8 | common.kfold 9 | data 10 | data.arff 11 | data.attribute 12 | data.instance 13 | graph.dag 14 | graph.prim 15 | main 16 | tree 17 | tree.classifiers 18 | tree.evaluate 19 | tree.train 20 | -------------------------------------------------------------------------------- /machine-learning/doc/resources/background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/doc/resources/background.gif -------------------------------------------------------------------------------- /machine-learning/doc/resources/tab.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/doc/resources/tab.gif -------------------------------------------------------------------------------- /machine-learning/doc/resources/titlebar.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/doc/resources/titlebar.gif -------------------------------------------------------------------------------- /machine-learning/doc/resources/titlebar_end.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/doc/resources/titlebar_end.gif 
-------------------------------------------------------------------------------- /machine-learning/doc/tree/class-use/DtLeaf.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Class tree.DtLeaf 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Class
tree.DtLeaf

67 |
68 |
No usage of tree.DtLeaf
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/tree/class-use/ID3Builder.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Class tree.ID3Builder 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Class
tree.ID3Builder

67 |
68 |
No usage of tree.ID3Builder
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/tree/classifiers/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | tree.classifiers 7 | 8 | 9 | 10 | 11 |

tree.classifiers

12 |
13 |

Classes

14 | 17 |
18 | 19 | 20 | -------------------------------------------------------------------------------- /machine-learning/doc/tree/classifiers/package-use.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Uses of Package tree.classifiers 7 | 8 | 9 | 10 | 11 | 17 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 36 |
37 | 64 | 65 |
66 |

Uses of Package
tree.classifiers

67 |
68 |
No usage of tree.classifiers
69 | 70 |
71 | 72 | 73 | 74 | 75 | 85 |
86 | 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /machine-learning/doc/tree/evaluate/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | tree.evaluate 7 | 8 | 9 | 10 | 11 |

tree.evaluate

12 |
13 |

Classes

14 | 18 |
19 | 20 | 21 | -------------------------------------------------------------------------------- /machine-learning/doc/tree/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | tree 7 | 8 | 9 | 10 | 11 |

tree

12 |
13 |

Classes

14 | 22 |
23 | 24 | 25 | -------------------------------------------------------------------------------- /machine-learning/doc/tree/train/package-frame.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | tree.train 7 | 8 | 9 | 10 | 11 |

tree.train

12 |
13 |

Classes

14 | 21 |
22 | 23 | 24 | -------------------------------------------------------------------------------- /machine-learning/examples/examples/DecisionTreeExample.java: -------------------------------------------------------------------------------- 1 | package examples; 2 | import tree.classifiers.ID3TreeClassifier; 3 | import classify.ClassificationResult; 4 | import data.DataSet; 5 | import data.reader.ArffReader; 6 | 7 | 8 | /** 9 | * The following example shows how to learn an ID3 decision tree classifier on a 10 | * training data set and then how to classify a test data set. 11 | * 12 | */ 13 | public class DecisionTreeExample 14 | { 15 | public static void main( String[] args ) 16 | { 17 | /* 18 | * Paths to train and test ARFF files 19 | */ 20 | String trainArffPath = args[0]; 21 | String testArffPath = args[1]; 22 | 23 | /* 24 | * Read data 25 | */ 26 | DataSet trainingData = ArffReader.readFile(trainArffPath); 27 | DataSet testData = ArffReader.readFile(testArffPath); 28 | 29 | /* 30 | * Specify the class attribute name 31 | */ 32 | String classAttribute = args[2]; 33 | trainingData.setClassAttribute(classAttribute); 34 | testData.setClassAttribute(classAttribute); 35 | 36 | /* 37 | * Specify the minimum number of instances a node must contain to be split on a new attribute.
38 | */ 39 | int minInstancesForStopping = Integer.decode(args[3]); 40 | 41 | /* 42 | * Build the classifier 43 | */ 44 | ID3TreeClassifier classifier = new ID3TreeClassifier(minInstancesForStopping, trainingData); 45 | System.out.println(classifier); 46 | 47 | /* 48 | * Classify each instance in the test data set 49 | */ 50 | ClassificationResult result = classifier.classifyData(testData); 51 | System.out.println(result); 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /machine-learning/examples/examples/GraphAlgorithmExamples.java: -------------------------------------------------------------------------------- 1 | package examples; 2 | import java.util.Map; 3 | 4 | import graph.DirectedGraph; 5 | import graph.Path; 6 | import graph.bellmanford.BellmanFord; 7 | import graph.floydwarshall.AllPairsShortestPaths; 8 | import graph.floydwarshall.FloydWarshall; 9 | 10 | 11 | public class GraphAlgorithmExamples 12 | { 13 | 14 | public static void main(String[] args) 15 | { 16 | bellmanFordExample(); 17 | floydWarshallExample(); 18 | } 19 | 20 | 21 | /** 22 | * Run the Bellman-Ford algorithm to find the shortest paths to all 23 | * nodes from a source node. 24 | */ 25 | public static void bellmanFordExample() 26 | { 27 | DirectedGraph graph = createToyGraph_NoNegativeCycles(); 28 | 29 | // Designate the source node 30 | String sourceNode = "A"; 31 | 32 | // Run Bellman-Ford 33 | Map> shortestPaths = BellmanFord.runBellmanFord(graph, sourceNode); 34 | 35 | // Print each path 36 | for (Path path : shortestPaths.values()) 37 | { 38 | System.out.println(path); 39 | } 40 | } 41 | 42 | public static void floydWarshallExample() 43 | { 44 | DirectedGraph graph = createToyGraph_NoNegativeCycles(); 45 | 46 | AllPairsShortestPaths paths = FloydWarshall.runFloydWarshall(graph); 47 | 48 | System.out.println(paths.getPath("A", "B")); 49 | } 50 | 51 | /** 52 | * @return an example graph with no negative cycles. 
53 | */ 54 | public static DirectedGraph createToyGraph_NoNegativeCycles() 55 | { 56 | DirectedGraph graph = new DirectedGraph<>(); 57 | 58 | graph.addEdge("A", "B", 3d); 59 | graph.addEdge("A", "E", 7d); 60 | graph.addEdge("B", "C", 5d); 61 | graph.addEdge("B", "E", 8d); 62 | graph.addEdge("B", "D", -4d); 63 | graph.addEdge("C", "B", -2d); 64 | graph.addEdge("D", "C", 7d); 65 | graph.addEdge("D", "A", 2d); 66 | graph.addEdge("E", "C", -3d); 67 | graph.addEdge("E", "D", 9d); 68 | 69 | return graph; 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /machine-learning/examples/examples/HMMExamples.java: -------------------------------------------------------------------------------- 1 | package examples; 2 | import math.LogP; 3 | import pair.Pair; 4 | 5 | import hmm.HMM; 6 | import hmm.State; 7 | import hmm.StateSilent; 8 | import hmm.Transition; 9 | import hmm.algorithms.BackwardAlgorithm; 10 | import hmm.algorithms.DpMatrix; 11 | import hmm.algorithms.ForwardAlgorithm; 12 | 13 | 14 | public class HMMExamples 15 | { 16 | /** 17 | * Example running the forward and backward algorithms on an HMM object. 
18 | */ 19 | public static void main(String[] args) 20 | { 21 | 22 | HMM toyHmm = buildToyHMM(); 23 | String[] sequence = {"x", "y", "x", "x"}; 24 | 25 | /* 26 | * Forward algorithm 27 | */ 28 | Pair resultF = ForwardAlgorithm.run(toyHmm, sequence); 29 | 30 | /* 31 | * Backward algorithm 32 | */ 33 | Pair resultB = BackwardAlgorithm.run(toyHmm, sequence); 34 | 35 | System.out.println("Probability of sequence: " 36 | + LogP.exp(resultF.getFirst())); 37 | } 38 | 39 | /** 40 | * Example building an HMM object 41 | * 42 | * @return an example HMM 43 | */ 44 | public static HMM buildToyHMM() 45 | { 46 | HMM hmm = new HMM(); 47 | 48 | /* 49 | * Create states 50 | */ 51 | State A = new StateSilent("A"); 52 | State B = new StateSilent("B"); 53 | State C = new StateSilent("C"); 54 | State D = new StateSilent("D"); 55 | State E = new State("E"); 56 | State G = new State("G"); 57 | 58 | /* 59 | * Add transitions between states 60 | */ 61 | A.addTransition(new Transition("A", "E", LogP.ln(1.0 / 3.0))); 62 | A.addTransition(new Transition("A", "X", LogP.ln(1.0 / 3.0))); 63 | A.addTransition(new Transition("A", "B", LogP.ln(1.0 / 3.0))); 64 | B.addTransition(new Transition("B", "C", LogP.ln(0.5))); 65 | B.addTransition(new Transition("B", "G", LogP.ln(0.5))); 66 | C.addTransition(new Transition("C", "D", LogP.ln(1.0))); 67 | D.addTransition(new Transition("D", "E", LogP.ln(1.0))); 68 | E.addTransition(new Transition("E", "E", LogP.ln(0.5))); 69 | E.addTransition(new Transition("E", "G", LogP.ln(0.5))); 70 | 71 | /* 72 | * Create emission distribution for state E 73 | */ 74 | E.addEmission("x", LogP.ln(0.5)); 75 | E.addEmission("y", LogP.ln(0.5)); 76 | 77 | /* 78 | * Create emission distribution for state G 79 | */ 80 | G.addEmission("x", LogP.ln(0.9)); 81 | G.addEmission("y", LogP.ln(0.1)); 82 | 83 | /* 84 | * Add states to model 85 | */ 86 | hmm.addState(B); 87 | hmm.addState(D); 88 | hmm.addState(C); 89 | hmm.addState(E); 90 | hmm.addState(G); 91 | hmm.addState(A); 92 | 93 | 
/* 94 | * Set the begin and end states 95 | */ 96 | hmm.setBeginStateId("A"); 97 | hmm.setEndStateId("E"); 98 | 99 | return hmm; 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /machine-learning/gold_standard/tree/m10.tree: -------------------------------------------------------------------------------- 1 | thal = fixed_defect [4 6] 2 | | ca <= 0.500000 [4 0]: negative 3 | | ca > 0.500000 [0 6]: positive 4 | thal = normal [84 19] 5 | | thalach <= 111.500000 [0 4]: positive 6 | | thalach > 111.500000 [84 15] 7 | | | age <= 55.500000 [56 4] 8 | | | | trestbps <= 113.500000 [9 3] 9 | | | | | oldpeak <= 0.300000 [3 3]: negative 10 | | | | | oldpeak > 0.300000 [6 0]: negative 11 | | | | trestbps > 113.500000 [47 1] 12 | | | | | oldpeak <= 3.550000 [47 0]: negative 13 | | | | | oldpeak > 3.550000 [0 1]: positive 14 | | | age > 55.500000 [28 11] 15 | | | | chol <= 248.500000 [14 1] 16 | | | | | oldpeak <= 2.800000 [14 0]: negative 17 | | | | | oldpeak > 2.800000 [0 1]: positive 18 | | | | chol > 248.500000 [14 10] 19 | | | | | sex = female [13 3] 20 | | | | | | cp = typ_angina [1 0]: negative 21 | | | | | | cp = asympt [3 3]: negative 22 | | | | | | cp = non_anginal [7 0]: negative 23 | | | | | | cp = atyp_angina [2 0]: negative 24 | | | | | sex = male [1 7]: positive 25 | thal = reversable_defect [20 67] 26 | | cp = typ_angina [3 1]: negative 27 | | cp = asympt [5 53] 28 | | | oldpeak <= 0.650000 [5 10] 29 | | | | chol <= 240.500000 [5 2]: negative 30 | | | | chol > 240.500000 [0 8]: positive 31 | | | oldpeak > 0.650000 [0 43]: positive 32 | | cp = non_anginal [9 10] 33 | | | oldpeak <= 1.900000 [9 5] 34 | | | | trestbps <= 122.500000 [6 0]: negative 35 | | | | trestbps > 122.500000 [3 5]: positive 36 | | | oldpeak > 1.900000 [0 5]: positive 37 | | cp = atyp_angina [3 3]: negative -------------------------------------------------------------------------------- /machine-learning/gold_standard/tree/m2.tree:
-------------------------------------------------------------------------------- 1 | 2 | thal = fixed_defect [4 6] 3 | | ca <= 0.500000 [4 0]: negative 4 | | ca > 0.500000 [0 6]: positive 5 | thal = normal [84 19] 6 | | thalach <= 111.500000 [0 4]: positive 7 | | thalach > 111.500000 [84 15] 8 | | | age <= 55.500000 [56 4] 9 | | | | trestbps <= 113.500000 [9 3] 10 | | | | | oldpeak <= 0.300000 [3 3] 11 | | | | | | cp = typ_angina [0 0]: negative 12 | | | | | | cp = asympt [0 2]: positive 13 | | | | | | cp = non_anginal [1 1] 14 | | | | | | | age <= 44.000000 [1 0]: negative 15 | | | | | | | age > 44.000000 [0 1]: positive 16 | | | | | | cp = atyp_angina [2 0]: negative 17 | | | | | oldpeak > 0.300000 [6 0]: negative 18 | | | | trestbps > 113.500000 [47 1] 19 | | | | | oldpeak <= 3.550000 [47 0]: negative 20 | | | | | oldpeak > 3.550000 [0 1]: positive 21 | | | age > 55.500000 [28 11] 22 | | | | chol <= 248.500000 [14 1] 23 | | | | | oldpeak <= 2.800000 [14 0]: negative 24 | | | | | oldpeak > 2.800000 [0 1]: positive 25 | | | | chol > 248.500000 [14 10] 26 | | | | | sex = female [13 3] 27 | | | | | | cp = typ_angina [1 0]: negative 28 | | | | | | cp = asympt [3 3] 29 | | | | | | | age <= 58.000000 [2 0]: negative 30 | | | | | | | age > 58.000000 [1 3] 31 | | | | | | | | chol <= 362.000000 [0 3]: positive 32 | | | | | | | | chol > 362.000000 [1 0]: negative 33 | | | | | | cp = non_anginal [7 0]: negative 34 | | | | | | cp = atyp_angina [2 0]: negative 35 | | | | | sex = male [1 7] 36 | | | | | | age <= 65.500000 [0 5]: positive 37 | | | | | | age > 65.500000 [1 2] 38 | | | | | | | age <= 66.500000 [1 0]: negative 39 | | | | | | | age > 66.500000 [0 2]: positive 40 | thal = reversable_defect [20 67] 41 | | cp = typ_angina [3 1] 42 | | | oldpeak <= 0.700000 [0 1]: positive 43 | | | oldpeak > 0.700000 [3 0]: negative 44 | | cp = asympt [5 53] 45 | | | oldpeak <= 0.650000 [5 10] 46 | | | | chol <= 240.500000 [5 2] 47 | | | | | chol <= 192.000000 [1 2] 48 | | | | | | age 
<= 62.000000 [0 2]: positive 49 | | | | | | age > 62.000000 [1 0]: negative 50 | | | | | chol > 192.000000 [4 0]: negative 51 | | | | chol > 240.500000 [0 8]: positive 52 | | | oldpeak > 0.650000 [0 43]: positive 53 | | cp = non_anginal [9 10] 54 | | | oldpeak <= 1.900000 [9 5] 55 | | | | trestbps <= 122.500000 [6 0]: negative 56 | | | | trestbps > 122.500000 [3 5] 57 | | | | | chol <= 232.500000 [3 1] 58 | | | | | | trestbps <= 129.000000 [0 1]: positive 59 | | | | | | trestbps > 129.000000 [3 0]: negative 60 | | | | | chol > 232.500000 [0 4]: positive 61 | | | oldpeak > 1.900000 [0 5]: positive 62 | | cp = atyp_angina [3 3] 63 | | | age <= 47.000000 [2 0]: negative 64 | | | age > 47.000000 [1 3] 65 | | | | trestbps <= 109.000000 [1 0]: negative 66 | | | | trestbps > 109.000000 [0 3]: positive -------------------------------------------------------------------------------- /machine-learning/gold_standard/tree/m20.tree: -------------------------------------------------------------------------------- 1 | 2 | thal = fixed_defect [4 6]: positive 3 | thal = normal [84 19] 4 | | thalach <= 111.500000 [0 4]: positive 5 | | thalach > 111.500000 [84 15] 6 | | | age <= 55.500000 [56 4] 7 | | | | trestbps <= 113.500000 [9 3]: negative 8 | | | | trestbps > 113.500000 [47 1] 9 | | | | | oldpeak <= 3.550000 [47 0]: negative 10 | | | | | oldpeak > 3.550000 [0 1]: positive 11 | | | age > 55.500000 [28 11] 12 | | | | chol <= 248.500000 [14 1]: negative 13 | | | | chol > 248.500000 [14 10] 14 | | | | | sex = female [13 3]: negative 15 | | | | | sex = male [1 7]: positive 16 | thal = reversable_defect [20 67] 17 | | cp = typ_angina [3 1]: negative 18 | | cp = asympt [5 53] 19 | | | oldpeak <= 0.650000 [5 10]: positive 20 | | | oldpeak > 0.650000 [0 43]: positive 21 | | cp = non_anginal [9 10]: positive 22 | | cp = atyp_angina [3 3]: negative -------------------------------------------------------------------------------- /machine-learning/gold_standard/tree/m4.tree: 
-------------------------------------------------------------------------------- 1 | 2 | thal = fixed_defect [4 6] 3 | | ca <= 0.500000 [4 0]: negative 4 | | ca > 0.500000 [0 6]: positive 5 | thal = normal [84 19] 6 | | thalach <= 111.500000 [0 4]: positive 7 | | thalach > 111.500000 [84 15] 8 | | | age <= 55.500000 [56 4] 9 | | | | trestbps <= 113.500000 [9 3] 10 | | | | | oldpeak <= 0.300000 [3 3] 11 | | | | | | cp = typ_angina [0 0]: negative 12 | | | | | | cp = asympt [0 2]: positive 13 | | | | | | cp = non_anginal [1 1]: negative 14 | | | | | | cp = atyp_angina [2 0]: negative 15 | | | | | oldpeak > 0.300000 [6 0]: negative 16 | | | | trestbps > 113.500000 [47 1] 17 | | | | | oldpeak <= 3.550000 [47 0]: negative 18 | | | | | oldpeak > 3.550000 [0 1]: positive 19 | | | age > 55.500000 [28 11] 20 | | | | chol <= 248.500000 [14 1] 21 | | | | | oldpeak <= 2.800000 [14 0]: negative 22 | | | | | oldpeak > 2.800000 [0 1]: positive 23 | | | | chol > 248.500000 [14 10] 24 | | | | | sex = female [13 3] 25 | | | | | | cp = typ_angina [1 0]: negative 26 | | | | | | cp = asympt [3 3] 27 | | | | | | | age <= 58.000000 [2 0]: negative 28 | | | | | | | age > 58.000000 [1 3] 29 | | | | | | | | chol <= 362.000000 [0 3]: positive 30 | | | | | | | | chol > 362.000000 [1 0]: negative 31 | | | | | | cp = non_anginal [7 0]: negative 32 | | | | | | cp = atyp_angina [2 0]: negative 33 | | | | | sex = male [1 7] 34 | | | | | | age <= 65.500000 [0 5]: positive 35 | | | | | | age > 65.500000 [1 2]: positive 36 | thal = reversable_defect [20 67] 37 | | cp = typ_angina [3 1] 38 | | | oldpeak <= 0.700000 [0 1]: positive 39 | | | oldpeak > 0.700000 [3 0]: negative 40 | | cp = asympt [5 53] 41 | | | oldpeak <= 0.650000 [5 10] 42 | | | | chol <= 240.500000 [5 2] 43 | | | | | chol <= 192.000000 [1 2]: positive 44 | | | | | chol > 192.000000 [4 0]: negative 45 | | | | chol > 240.500000 [0 8]: positive 46 | | | oldpeak > 0.650000 [0 43]: positive 47 | | cp = non_anginal [9 10] 48 | | | oldpeak <= 
1.900000 [9 5] 49 | | | | trestbps <= 122.500000 [6 0]: negative 50 | | | | trestbps > 122.500000 [3 5] 51 | | | | | chol <= 232.500000 [3 1] 52 | | | | | | trestbps <= 129.000000 [0 1]: positive 53 | | | | | | trestbps > 129.000000 [3 0]: negative 54 | | | | | chol > 232.500000 [0 4]: positive 55 | | | oldpeak > 1.900000 [0 5]: positive 56 | | cp = atyp_angina [3 3] 57 | | | age <= 47.000000 [2 0]: negative 58 | | | age > 47.000000 [1 3] 59 | | | | trestbps <= 109.000000 [1 0]: negative 60 | | | | trestbps > 109.000000 [0 3]: positive -------------------------------------------------------------------------------- /machine-learning/lib/data_structures.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/data_structures.jar -------------------------------------------------------------------------------- /machine-learning/lib/guava-18.0.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/guava-18.0.jar -------------------------------------------------------------------------------- /machine-learning/lib/hamcrest-core-1.3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/hamcrest-core-1.3.jar -------------------------------------------------------------------------------- /machine-learning/lib/junit-4.12-beta-3.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/junit-4.12-beta-3.jar 
-------------------------------------------------------------------------------- /machine-learning/lib/lombok.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/lombok.jar -------------------------------------------------------------------------------- /machine-learning/lib/mockito-all-1.9.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/mockito-all-1.9.5.jar -------------------------------------------------------------------------------- /machine-learning/src/applications/MainHillClimbing.java: -------------------------------------------------------------------------------- 1 | package applications; 2 | 3 | import java.io.FileNotFoundException; 4 | import java.io.PrintWriter; 5 | import java.util.ArrayList; 6 | import java.util.List; 7 | 8 | 9 | import data.DataSet; 10 | import data.fold.KFoldCreator; 11 | import data.reader.ArffReader; 12 | 13 | 14 | import pair.Pair; 15 | 16 | import bayes.BNEvaluator; 17 | import bayes.BNResultWriter; 18 | import bayes.BayesianNetwork; 19 | import bayes.structuresearch.HillClimbingBuilder; 20 | import bayes.structuresearch.score.BIC; 21 | 22 | public class MainHillClimbing 23 | { 24 | public static void main(String[] args) 25 | { 26 | try 27 | { 28 | PrintWriter out = new PrintWriter(args[1]); 29 | 30 | /* 31 | * Read the training data from the arff file 32 | */ 33 | ArffReader reader = new ArffReader(); 34 | DataSet data = reader.readFile(args[0]); 35 | 36 | /* 37 | * Scoring function 38 | */ 39 | BIC bic = new BIC(); 40 | 41 | List> folds = KFoldCreator.create(data, 5); 42 | 43 | 44 | out.println("Result on folds:"); 45 | Double scoreSum = 0.0; 46 | for (int i = 0; i < folds.size(); i++) 47 | { 48 | /* 49 | * 
TODO: BAD! 50 | */ 51 | BNResultWriter.WRITER = new PrintWriter(args[2]+ "_" + i); 52 | 53 | Pair fold = folds.get(i); 54 | 55 | HillClimbingBuilder hcBuilder = new HillClimbingBuilder(); 56 | BayesianNetwork net = hcBuilder.buildNetwork(fold.getFirst(), 57 | 1, 58 | bic, 59 | null); 60 | 61 | Double score = BNEvaluator.calculateLogLikelihood(net, 62 | fold.getSecond()); 63 | 64 | scoreSum += score; 65 | 66 | out.println("\n\n-------- Fold " + i + " --------\n"); 67 | 68 | out.println("Net Structure: "); 69 | out.print(net); 70 | 71 | out.println("Likelihood of test data: "); 72 | out.println(score); 73 | 74 | /* 75 | * TODO: BAD! 76 | */ 77 | BNResultWriter.WRITER.close(); 78 | } 79 | out.println("Average likelihood:"); 80 | out.println(scoreSum / folds.size()); 81 | out.close(); 82 | } 83 | catch(FileNotFoundException e) 84 | { 85 | System.out.println("Error instantiating output file writer."); 86 | System.exit(1); 87 | } 88 | 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /machine-learning/src/applications/MainNaiveBayes.java: -------------------------------------------------------------------------------- 1 | package applications; 2 | 3 | 4 | 5 | import classify.ClassificationResult; 6 | 7 | import data.DataSet; 8 | import data.reader.ArffReader; 9 | import bayes.classifiers.NaiveBayesClassifier; 10 | 11 | public class MainNaiveBayes 12 | { 13 | 14 | private static final String CLASS_ATTR_NAME = "class"; 15 | private static final Integer LAPLACE_COUNT = 1; 16 | 17 | public static void main(String[] args) 18 | { 19 | /* 20 | * Print useage if not enough arguments 21 | */ 22 | if (args.length != 3) 23 | { 24 | printUsage(); 25 | return; 26 | } 27 | 28 | /* 29 | * Determine whether to use TAN or simple Naive Bayes 30 | */ 31 | boolean tan; 32 | if (args[2].equals("t")) 33 | { 34 | tan = true; 35 | } 36 | else if (args[2].equals("n")) 37 | { 38 | tan = false; 39 | } 40 | else 41 | { 42 | printUsage(); 43 | 
return; 44 | } 45 | 46 | /* 47 | * Read the training data from the arff file 48 | */ 49 | ArffReader reader = new ArffReader(); 50 | DataSet data = reader.readFile(args[0]); 51 | data.setClassAttribute(CLASS_ATTR_NAME); 52 | 53 | /* 54 | * Create Naive Bayes classifier 55 | */ 56 | NaiveBayesClassifier nbClassifier = 57 | new NaiveBayesClassifier(data, LAPLACE_COUNT, tan); 58 | 59 | /* 60 | * Print network 61 | */ 62 | System.out.println(nbClassifier); 63 | 64 | /* 65 | * Read the training data from the arff file 66 | */ 67 | DataSet testData = reader.readFile(args[1]); 68 | testData.setClassAttribute(CLASS_ATTR_NAME); 69 | 70 | /* 71 | * Classify the data 72 | */ 73 | ClassificationResult result = nbClassifier.classifyData(testData); 74 | 75 | /* 76 | * Print results to standard output 77 | */ 78 | System.out.print("\n\n"); 79 | System.out.println(result); 80 | } 81 | 82 | /** 83 | * Prints this program's usage to standard output 84 | */ 85 | public static void printUsage() 86 | { 87 | System.out.println("\nUsage: bayes \n"); 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /machine-learning/src/applications/MainSparseCandidate.java: -------------------------------------------------------------------------------- 1 | package applications; 2 | 3 | import java.io.FileNotFoundException; 4 | import java.io.PrintWriter; 5 | import java.util.ArrayList; 6 | import java.util.List; 7 | 8 | 9 | import data.DataSet; 10 | import data.fold.KFoldCreator; 11 | import data.reader.ArffReader; 12 | 13 | 14 | import pair.Pair; 15 | 16 | import bayes.BNEvaluator; 17 | import bayes.BNResultWriter; 18 | import bayes.BayesianNetwork; 19 | import bayes.structuresearch.SparseCandidateBuilder; 20 | import bayes.structuresearch.score.BIC; 21 | 22 | public class MainSparseCandidate 23 | { 24 | public static void main(String[] args) 25 | { 26 | try 27 | { 28 | PrintWriter out = new PrintWriter(args[1]); 29 | 30 | /* 31 | * Read the training 
data from the arff file 32 | */ 33 | ArffReader reader = new ArffReader(); 34 | DataSet data = reader.readFile(args[0]); 35 | 36 | /* 37 | * Scoring function 38 | */ 39 | BIC bic = new BIC(); 40 | 41 | List> folds = KFoldCreator.create(data, 5); 42 | 43 | 44 | out.println("Result on folds:"); 45 | Double scoreSum = 0.0; 46 | for (int i = 0; i < folds.size(); i++) 47 | { 48 | /* 49 | * TODO: BAD! 50 | */ 51 | BNResultWriter.WRITER = new PrintWriter(args[2]+ "_" + i); 52 | 53 | Pair fold = folds.get(i); 54 | 55 | SparseCandidateBuilder spBuilder = new SparseCandidateBuilder(); 56 | BayesianNetwork net = spBuilder.buildNetwork(fold.getFirst(), 57 | 1, 58 | bic, 59 | null); 60 | 61 | Double score = BNEvaluator.calculateLogLikelihood(net, 62 | fold.getSecond()); 63 | 64 | scoreSum += score; 65 | 66 | out.println("\n\n-------- Fold " + i + " --------\n"); 67 | 68 | out.println("Net Structure: "); 69 | out.print(net); 70 | 71 | out.println("Likelihood of test data: "); 72 | out.println(score); 73 | 74 | /* 75 | * TODO: BAD! 76 | */ 77 | BNResultWriter.WRITER.close(); 78 | } 79 | out.println("Average likelihood:"); 80 | out.println(scoreSum / folds.size()); 81 | out.close(); 82 | } 83 | catch(FileNotFoundException e) 84 | { 85 | System.out.println("Error instantiating output file writer."); 86 | System.exit(1); 87 | } 88 | 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/BNEvaluator.java: -------------------------------------------------------------------------------- 1 | package bayes; 2 | 3 | import java.util.ArrayList; 4 | 5 | 6 | import data.Attribute; 7 | import data.DataSet; 8 | import data.Instance; 9 | 10 | public class BNEvaluator 11 | { 12 | 13 | /** 14 | * 15 | * For each instance in the data, calculate the log-probability of the 16 | * bayes net producing this data. 
Sum the resulting log-probability 17 | * of each instance to get the total log-likelihood of the network 18 | * generating the data. Tha 19 | * 20 | * @param net the Bayes net used to calcuate the likelihood of the data 21 | * @param data the data for which we want to know the likelihood 22 | * @return the log-likelihood of seeing the data given the net 23 | */ 24 | public static Double calculateLogLikelihood(BayesianNetwork net, DataSet data) 25 | { 26 | Double logProduct = 0.0; 27 | 28 | for (Instance instance : data.getInstanceSet().getInstances()) 29 | { 30 | ArrayList queries = createQueries(instance, net, data); 31 | 32 | /* 33 | * Sum over the probability of each instance 34 | */ 35 | for (BNConditionalQuery query : queries) 36 | { 37 | Double p = net.queryConditionalProbability(query); 38 | logProduct += -Math.log(p); 39 | } 40 | } 41 | 42 | return logProduct; 43 | } 44 | 45 | /** 46 | * For each instance we need to create a conditional probability 47 | * query on the value of each instance's attributes given the values of the 48 | * rest of the attributes. 
49 | * 50 | * @param instance the instance under examination 51 | * @param net the Bayes net 52 | * @param data the data set 53 | * @return a query for each attribute in the instance 54 | */ 55 | public static ArrayList createQueries(Instance instance, 56 | BayesianNetwork net, 57 | DataSet data) 58 | { 59 | ArrayList queries = new ArrayList(); 60 | 61 | for (BNNode targetNode : net.getNodes()) 62 | { 63 | BNConditionalQuery query = new BNConditionalQuery(); 64 | 65 | /* 66 | * Set target attribute/value 67 | */ 68 | Attribute targetAttr = targetNode.getAttribute(); 69 | Integer targetAttrValue 70 | = instance.getAttributeValue(targetAttr).intValue(); 71 | query.setTargetVariable(targetAttr, targetAttrValue); 72 | 73 | /* 74 | * Set each condition attribute/value 75 | */ 76 | for (BNNode conditionNode : net.getNodes()) 77 | { 78 | Attribute conditionAttr = conditionNode.getAttribute(); 79 | 80 | boolean isChildOfTarget = targetNode.getParents().contains(conditionNode); 81 | 82 | if (!conditionNode.equals(targetNode) && isChildOfTarget) 83 | { 84 | Integer conditionAttrValue 85 | = instance.getAttributeValue(conditionAttr).intValue(); 86 | 87 | query.addConditionVariable(conditionAttr, conditionAttrValue); 88 | } 89 | } 90 | 91 | queries.add(query); 92 | } 93 | 94 | return queries; 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/BNJointQuery.java: -------------------------------------------------------------------------------- 1 | package bayes; 2 | 3 | import java.util.ArrayList; 4 | 5 | import data.Attribute; 6 | 7 | import pair.Pair; 8 | 9 | /** 10 | * Used for querying a {@code BayesianNetwork} object for a joint 11 | * probability of a set of attribute/value pairs. 
For example, if we wish to 12 | * query a Bayes net for the following probability: P(A = a, E = e, D = d), 13 | * this object would represent the query, (A = a, E = e, D = d) 14 | * 15 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 16 | * 17 | */ 18 | public class BNJointQuery 19 | { 20 | 21 | /** 22 | * Each pair of integers represents an Attribute ID and the nominal value ID 23 | * of the attribute of all of the variables in the joint probability query. 24 | *
25 | *
26 | * For example, given the following query: P(A = a, D = d, E = e), this 27 | * set of pairs would be as follows: (A,a), (D,d), (E,e). 28 | */ 29 | private VariableSet variables; 30 | 31 | /** 32 | * Constructor 33 | * 34 | * @param variables the set of variables in the joint probability query 35 | */ 36 | public BNJointQuery(VariableSet variables) 37 | { 38 | this.variables = variables; 39 | } 40 | 41 | /** 42 | * Constructor 43 | */ 44 | public BNJointQuery() 45 | { 46 | this.variables = new VariableSet(); 47 | } 48 | 49 | /** 50 | * Add an attribute/value pair to the set of attribute/value pairs 51 | * used in the query. 52 | * 53 | * @param attr the ID of the condition attribute to be added to the query 54 | * @param nominalValueId the nominal value ID specified for this Attribute 55 | */ 56 | public void addVariable(Attribute attr, Integer nomValueId) 57 | { 58 | variables.addVariable(attr, nomValueId); 59 | } 60 | 61 | /** 62 | * Determines whether this BNQuery includes the value of a specific 63 | * attribute in the set of condition attributes 64 | * 65 | * @param attr the target Attribute 66 | * @return true if this query is specifying a value for this specific 67 | * Attribute in the set of condition attributes 68 | */ 69 | public Boolean containsAttribute(Attribute attr) 70 | { 71 | return variables.containsAttribute(attr); 72 | } 73 | 74 | /** 75 | * Gets the value for a specific attribute in the set of 76 | * attributes. If this attribute is not specified by this query, this 77 | * method returns null. 78 | * 79 | * @param attr the target Attribute 80 | * @return the value of target Attribute specified in the 81 | * of the query. null if this Attribute is not specified in this query 82 | */ 83 | public Integer getValueForAttribute(Attribute attr) 84 | { 85 | return variables.getValueForAttribute(attr); 86 | } 87 | 88 | /** 89 | * @return the list of attribute/value pairs in this joint probability 90 | * query. 
Each pair is the ID of the attribute and the nominal value ID 91 | * of the value of this attribute. 92 | */ 93 | public ArrayList> getVariables() 94 | { 95 | return this.variables.getVariables(); 96 | } 97 | 98 | @Override 99 | public String toString() 100 | { 101 | String result = "P("; 102 | 103 | /* 104 | * Condition variables 105 | */ 106 | for (Pair attrValPair : this.variables.getVariables()) 107 | { 108 | Attribute attr = attrValPair.getFirst(); 109 | Integer attrValue = attrValPair.getSecond(); 110 | result += attr.getName() + " = " + attr.getNominalValueName(attrValue); 111 | result += ", "; 112 | } 113 | 114 | /* 115 | * Closing paranthesis 116 | */ 117 | result = result.substring(0, result.length() - 2); 118 | result += ")"; 119 | 120 | return result; 121 | } 122 | 123 | } 124 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/BNResultWriter.java: -------------------------------------------------------------------------------- 1 | package bayes; 2 | 3 | import java.io.PrintWriter; 4 | 5 | public class BNResultWriter 6 | { 7 | public static PrintWriter WRITER; 8 | } 9 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/BNUtility.java: -------------------------------------------------------------------------------- 1 | package bayes; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | /** 7 | * A utility class for miscellaneous methods needed for Bayesian network 8 | * learning. 
9 | * 10 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 11 | * 12 | */ 13 | public class BNUtility 14 | { 15 | public static Double[][] convertToAdjacencyMatrix(List nodes) 16 | { 17 | int numNodes = nodes.size(); 18 | Double[][] graph = new Double[numNodes][numNodes]; 19 | 20 | /* 21 | * Initialize every element in graph to null 22 | */ 23 | for (int r = 0; r < numNodes; r++) 24 | { 25 | for (int c = 0; c < numNodes; c++) 26 | { 27 | graph[r][c] = null; 28 | } 29 | } 30 | 31 | /* 32 | * Assign graph[P][C] to 1.0 if C is a child of P 33 | */ 34 | for (int pIndex = 0; pIndex < nodes.size(); pIndex++) 35 | { 36 | BNNode currNode = nodes.get(pIndex); 37 | 38 | for (BNNode child : currNode.getChildren()) 39 | { 40 | int cIndex = nodes.indexOf(child); 41 | graph[pIndex][cIndex] = 1.0; 42 | } 43 | } 44 | 45 | return graph; 46 | } 47 | 48 | public static Double[][] convertToAdjacencyMatrix(BayesianNetwork net) 49 | { 50 | int numNodes = net.getNumNodes(); 51 | Double[][] graph = new Double[numNodes][numNodes]; 52 | 53 | /* 54 | * Initialize every element in graph to null 55 | */ 56 | for (int r = 0; r < numNodes; r++) 57 | { 58 | for (int c = 0; c < numNodes; c++) 59 | { 60 | graph[r][c] = null; 61 | } 62 | } 63 | 64 | /* 65 | * Assign graph[P][C] to 1.0 if C is a child of P 66 | */ 67 | List nodes = net.getNodes(); 68 | for (int pIndex = 0; pIndex < nodes.size(); pIndex++) 69 | { 70 | BNNode currNode = nodes.get(pIndex); 71 | 72 | for (BNNode child : currNode.getChildren()) 73 | { 74 | int cIndex = nodes.indexOf(child); 75 | graph[pIndex][cIndex] = 1.0; 76 | } 77 | } 78 | 79 | return graph; 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/VariableSet.java: -------------------------------------------------------------------------------- 1 | package bayes; 2 | 3 | import java.util.ArrayList; 4 | 5 | import pair.Pair; 6 | import data.Attribute; 7 | 8 | /** 9 | * This class represents 
represents a set of attribute/value pairs. This is 10 | * a component of Bayes net query objects. 11 | * 12 | * @author matthewbernstein 13 | */ 14 | public class VariableSet 15 | { 16 | /** 17 | * Each pair of integers represents an Attribute ID and the nominal value ID 18 | * of the attribute of all of the variables in this set. 19 | */ 20 | private ArrayList> variables; 21 | 22 | /** 23 | * Constructor 24 | */ 25 | public VariableSet() 26 | { 27 | this.variables = new ArrayList>(); 28 | } 29 | 30 | /** 31 | * Add an attribute/value pair to the set of attribute/value pairs 32 | * to this set. 33 | * 34 | * @param attr the ID of the condition attribute to be added to the set 35 | * @param nominalValueId the nominal value ID specified for this Attribute 36 | */ 37 | public void addVariable(Attribute attr, Integer nomValueId) 38 | { 39 | Pair newItem = 40 | new Pair(attr, nomValueId); 41 | 42 | variables.add(newItem); 43 | } 44 | 45 | /** 46 | * Determines whether this set includes the value of a specific 47 | * attribute in the set of attributes 48 | * 49 | * @param attr the target Attribute 50 | * @return true if this set has a value for this specific attribute 51 | */ 52 | public Boolean containsAttribute(Attribute attr) 53 | { 54 | // Linear Search for a match in Attribute ID's 55 | for (Pair item : variables) 56 | { 57 | if (item.getFirst().equals(attr)) 58 | { 59 | return true; 60 | } 61 | } 62 | 63 | return false; 64 | } 65 | 66 | /** 67 | * Gets the value for a specific attribute in the set of 68 | * attributes. If this attribute is not in this set, this method returns 69 | * null. 70 | * 71 | * @param attr the target Attribute 72 | * @return the value of target Attribute specified in the 73 | * of the query. 
null if this Attribute is not specified in this query 74 | */ 75 | public Integer getValueForAttribute(Attribute attr) 76 | { 77 | if (attr == null) 78 | { 79 | return null; 80 | } 81 | 82 | /* 83 | * Linear search for the attribute 84 | */ 85 | for (Pair item : variables) 86 | { 87 | if (item.getFirst().equals(attr)) 88 | { 89 | return item.getSecond(); 90 | } 91 | } 92 | 93 | return null; 94 | } 95 | 96 | /** 97 | * @return the list of attribute/value pairs in this joint probability 98 | * query. Each pair is the ID of the attribute and the nominal value ID 99 | * of the value of this attribute. 100 | */ 101 | public ArrayList> getVariables() 102 | { 103 | return this.variables; 104 | } 105 | 106 | @Override 107 | public Object clone() 108 | { 109 | VariableSet copy = new VariableSet(); 110 | 111 | for (Pair pair : this.variables) 112 | { 113 | copy.addVariable(pair.getFirst(), pair.getSecond()); 114 | } 115 | 116 | return copy; 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/cpd/CPDLeaf.java: -------------------------------------------------------------------------------- 1 | package bayes.cpd; 2 | 3 | import data.Attribute; 4 | 5 | /** 6 | * A leaf node in a CPD tree. 
7 | * 8 | * @author Mathew Bernstien - matthewb@cs.wisc.edu 9 | * 10 | */ 11 | public class CPDLeaf extends CPDNode 12 | { 13 | protected double probability; 14 | 15 | private Integer laplaceCount; 16 | 17 | /** 18 | * Constructor 19 | * 20 | * @param probability the probability at this leaf CPDNode 21 | */ 22 | public CPDLeaf(Attribute attribute, 23 | Integer nodeValue, 24 | Integer numInstances, 25 | Integer laplaceCount) 26 | { 27 | super(attribute, nodeValue, numInstances); 28 | 29 | this.laplaceCount = laplaceCount; 30 | } 31 | 32 | @Override 33 | public String toString() 34 | { 35 | String result = super.toString(); 36 | result += " [" + probability + "]"; 37 | return result; 38 | } 39 | 40 | /** 41 | * @return the probability at this leaf node 42 | */ 43 | public double getProbability() 44 | { 45 | return this.probability; 46 | } 47 | 48 | /** 49 | * Calculate the probability of a specific query on this leaf node 50 | * 51 | * @param query the query object used to specify specific values of the 52 | * attributes for which this CPD's leaf attribute is conditioned on 53 | */ 54 | @Override 55 | public Double calculateProbability(CPDQuery query) 56 | { 57 | /* 58 | * Get the query value for this node's attribute 59 | */ 60 | Integer queryValue = query.getValueForQueryAttribute(this.attribute); 61 | 62 | /* 63 | * Return this leaf's probability if no specific value for this 64 | * CPDNode's attribute was specified in the query or if the value 65 | * matches this CPDNode's attribute in the query. Otherwise, we 66 | * return null. 67 | */ 68 | if (queryValue == null || queryValue == this.nodeValue) 69 | { 70 | return this.probability; 71 | } 72 | else 73 | { 74 | return 0.0; 75 | } 76 | } 77 | 78 | /** 79 | * Set the parent for this leaf node. This, in turn, will calculate the 80 | * probability at this leaf. 
81 | * 82 | * @param parent the parent CPDNode 83 | */ 84 | @Override 85 | public void setParent(CPDNode parent) 86 | { 87 | /* 88 | * Set the parent 89 | */ 90 | this.parent = parent; 91 | 92 | /* 93 | * Calculate leaf probability using Laplace counts 94 | */ 95 | double numerator = (double) numInstances + laplaceCount; 96 | double denominator = parent.numInstances 97 | + (laplaceCount * attribute.getNominalValueMap().size()); 98 | 99 | this.probability = numerator / denominator; 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/cpd/CPDQuery.java: -------------------------------------------------------------------------------- 1 | package bayes.cpd; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | import java.util.Map.Entry; 6 | 7 | import data.Attribute; 8 | 9 | /** 10 | * Objects of this class are used for querying a conditional probability 11 | * distribution for a single BNNode in a BayesianNetwork 12 | * 13 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 14 | * 15 | */ 16 | public class CPDQuery 17 | { 18 | /** 19 | * Maps an attribute to the value of this attribute specified in the query 20 | */ 21 | private Map queryItems; 22 | 23 | /** 24 | * Constructor 25 | */ 26 | public CPDQuery() 27 | { 28 | this.queryItems = new HashMap(); 29 | } 30 | 31 | /** 32 | * Add a query item to this query 33 | * 34 | * @param attr the Attribute this query is querying for 35 | * @param nominalValueId the nominal value ID specified for this Attribute 36 | */ 37 | public void addQueryItem(Attribute attr, Integer nomValueId) 38 | { 39 | if (attr.isValidNominalValueId(nomValueId)) 40 | { 41 | queryItems.put(attr, nomValueId); 42 | } 43 | else 44 | { 45 | throw new RuntimeException(nomValueId + " is not a valid nominal" + 46 | " value ID for the attribute " + 47 | attr.getName()); 48 | } 49 | } 50 | 51 | /** 52 | * Determines whether this BNQuery includes the value of a specific 53 
| * attribute 54 | * 55 | * @param attr the target Attribute 56 | * @return true if this query is specifying a value for this specific 57 | * Attribute 58 | */ 59 | public Boolean containsAttribute(Attribute attr) 60 | { 61 | return queryItems.keySet().contains(attr); 62 | } 63 | 64 | /** 65 | * Gets the value for a specific query attribute. If this attribute 66 | * is not specified by this query, this method returns null. 67 | * 68 | * @param attr the target Attribute 69 | * @return the value of target Attribute specified in this query. null if 70 | * this Attribute is not specified in this query 71 | */ 72 | public Integer getValueForQueryAttribute(Attribute attr) 73 | { 74 | return queryItems.get(attr); 75 | } 76 | 77 | @Override 78 | public String toString() 79 | { 80 | String result = "CPD("; 81 | 82 | for (Entry entry : queryItems.entrySet()) 83 | { 84 | Attribute attr = entry.getKey(); 85 | Integer attrValue = entry.getValue(); 86 | result += attr.getName() + " = " 87 | + attr.getNominalValueName(attrValue) + ", "; 88 | } 89 | 90 | result = result.substring(0, result.length() - 2); 91 | result += ")"; 92 | 93 | return result; 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/cpd/CPDTree.java: -------------------------------------------------------------------------------- 1 | package bayes.cpd; 2 | 3 | import data.AttributeSet; 4 | 5 | /** 6 | * A tree data structure used for representing and storing the 7 | * conditional probability distribution (CPD) for a specific node in a 8 | * Bayesian network. 
9 | * 10 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 11 | * 12 | */ 13 | public class CPDTree 14 | { 15 | /** 16 | * Total instances in the training data set used for generating this CPD 17 | */ 18 | protected static Integer totalInstances; 19 | 20 | /** 21 | * The set of attributes in the training set used for generating this CPD 22 | */ 23 | protected AttributeSet attributeSet; 24 | 25 | /** 26 | * The root node of the CPD tree 27 | */ 28 | protected CPDNode root; 29 | 30 | /** 31 | * @return the root node of this CPD tree 32 | */ 33 | public CPDNode getRoot() 34 | { 35 | return root; 36 | } 37 | 38 | /** 39 | * Print the subtree rooted at the input Node to standard output. 40 | * 41 | * @param root the node that roots the subtree being printed 42 | */ 43 | public String toString() 44 | { 45 | String result = ""; 46 | 47 | for (CPDNode child : root.getChildren()) 48 | { 49 | ToStringHelper treePrinter = new ToStringHelper(this.attributeSet); 50 | result += treePrinter.getString(child, 0); 51 | } 52 | 53 | return result; 54 | } 55 | 56 | /** 57 | * A private helper class used for traversing the tree recursively 58 | * in order to convert the tree to a String 59 | */ 60 | private static class ToStringHelper 61 | { 62 | @SuppressWarnings("unused") 63 | private AttributeSet attributeSet; 64 | 65 | public ToStringHelper(AttributeSet attributeSet) 66 | { 67 | this.attributeSet = attributeSet; 68 | } 69 | 70 | public String getString(CPDNode node, Integer depth) 71 | { 72 | String result = ""; 73 | 74 | /* 75 | * Print the indentated "|" characters 76 | */ 77 | for (int i = 0; i < depth; i++) 78 | { 79 | result += "| "; 80 | } 81 | 82 | /* 83 | * Print the value at the current node 84 | */ 85 | result += node; 86 | 87 | result += "\n"; 88 | 89 | /* 90 | * Generate the string for the child nodes of this node 91 | */ 92 | for (CPDNode child : node.getChildren()) 93 | { 94 | result += getString(child, depth + 1); 95 | } 96 | 97 | return result; 98 | } 99 | } 100 | 
101 | /** 102 | * Make a query on this CPD tree 103 | * 104 | * @param query the query object 105 | * @return the probability of this query given this CPD 106 | */ 107 | public Double query(CPDQuery query) 108 | { 109 | return this.root.calculateProbability(query); 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/cpd/Split.java: -------------------------------------------------------------------------------- 1 | package bayes.cpd; 2 | 3 | import java.util.ArrayList; 4 | 5 | 6 | import data.Attribute; 7 | import data.DataSet; 8 | import data.Instance; 9 | import data.InstanceSet; 10 | 11 | /** 12 | * Splits all instances along a specific attribute. 13 | * 14 | * @author Matthew Bernstien - matthewb@cs.wisc.edu 15 | * 16 | */ 17 | public class Split 18 | { 19 | /** 20 | * All of the branches for this split. For splits on nominal attributes, 21 | * there will be one branch per nominal value. For continuous attributes, 22 | * there will be two branches. One branch for the all instances with a value 23 | * greater than the threshold and one branch for all instances less than or 24 | * equal to the threshold value. 25 | */ 26 | private ArrayList branches; 27 | 28 | /** 29 | * The attribute this Split splits instances on 30 | */ 31 | private Attribute attribute; 32 | 33 | public Split(Attribute attribute) 34 | { 35 | branches = new ArrayList(); 36 | this.attribute = attribute; 37 | } 38 | 39 | public Attribute getAttribute() 40 | { 41 | return attribute; 42 | } 43 | 44 | /** 45 | * Split a set of instances along this split's attribute. 
Each split is 46 | * stored in one of this Split's SplitBranch objects 47 | * 48 | * @param instances the set of instances we wish to split 49 | */ 50 | public void splitInstances(InstanceSet instances) 51 | { 52 | for (Instance instance : instances.getInstances()) 53 | { 54 | for (SplitBranch branch : this.branches) 55 | { 56 | branch.tryAddInstance(instance); 57 | } 58 | } 59 | } 60 | 61 | /** 62 | * @return all of this split's branches 63 | */ 64 | public ArrayList getSplitBranches() 65 | { 66 | return branches; 67 | } 68 | 69 | /** 70 | * Add a branch to the split 71 | * 72 | * @param newBranch the new branch 73 | */ 74 | public void addBranch(SplitBranch newBranch) 75 | { 76 | branches.add(newBranch); 77 | } 78 | 79 | /** 80 | * A helper method for generating a split along a nominal attribute 81 | * 82 | * @param attrId 83 | * @return 84 | */ 85 | public static Split createSplitNominal(Attribute attr, DataSet data) 86 | { 87 | Split split = new Split(attr); 88 | 89 | for (Integer nominalValueId : attr.getNominalValueMap().values()) 90 | { 91 | SplitBranch newBranch = new SplitBranch(attr, 92 | new Double(nominalValueId)); 93 | split.addBranch(newBranch); 94 | } 95 | 96 | split.splitInstances(data.getInstanceSet()); 97 | 98 | return split; 99 | } 100 | 101 | } 102 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/cpd/SplitBranch.java: -------------------------------------------------------------------------------- 1 | package bayes.cpd; 2 | 3 | 4 | import data.Attribute; 5 | import data.Instance; 6 | import data.InstanceSet; 7 | 8 | /** 9 | * This class stores all instances for which a specific attribute matches a 10 | * specific value. 
11 | * 12 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 13 | * 14 | */ 15 | public class SplitBranch 16 | { 17 | 18 | /** 19 | * The value that an instance's attribute (the attribute determined by this 20 | * SplitBranch's attribute) is tested against to make this split. 21 | */ 22 | private Double branchValue; 23 | 24 | /** 25 | * The attribute this branch tests 26 | */ 27 | private Attribute attribute; 28 | 29 | /** 30 | * All instances that fall to this branch 31 | */ 32 | private InstanceSet instanceSet; 33 | 34 | /** 35 | * Constructor 36 | * 37 | * @param attribute the attribute this branch tests 38 | * @param branchValue the value that an instance's attribute (this 39 | * SplitBranch's attribute) is tested against to make this split. 40 | */ 41 | public SplitBranch(Attribute attribute, Double branchValue) 42 | { 43 | this.instanceSet = new InstanceSet(); 44 | this.attribute = attribute; 45 | this.branchValue = branchValue; 46 | } 47 | 48 | public InstanceSet getInstanceSet() 49 | { 50 | return instanceSet; 51 | } 52 | 53 | /** 54 | * Attempt to add an instance to the this split branch. The instance is 55 | * only add if it passes this branches test. 56 | * 57 | * @param instance 58 | */ 59 | public void tryAddInstance(Instance instance) 60 | { 61 | if (this.doesInstanceMakeSplit(instance)) 62 | { 63 | instanceSet.addInstance(instance); 64 | } 65 | } 66 | 67 | /** 68 | * @return the attribute this branch tests 69 | */ 70 | public Attribute getAttribute() 71 | { 72 | return attribute; 73 | } 74 | 75 | /** 76 | * @return the value that an instance is tested against to make this split 77 | */ 78 | public Double getValue() 79 | { 80 | return branchValue; 81 | } 82 | 83 | /** 84 | * Tests whether an instance makes this split branch. 
85 | * 86 | * @param instance the instance we are testing whether or not it makes 87 | * this SplitBranch 88 | * @return true if the instance's attribute's value matches this 89 | * SplitBranch's value 90 | */ 91 | public Boolean doesInstanceMakeSplit(Instance instance) 92 | { 93 | Double instanceAttrValue = 94 | instance.getAttributeValue(this.attribute); 95 | 96 | return (instanceAttrValue.doubleValue() == branchValue.doubleValue()); 97 | } 98 | 99 | } 100 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/structuresearch/NaiveBayesBuilder.java: -------------------------------------------------------------------------------- 1 | package bayes.structuresearch; 2 | 3 | 4 | import bayes.BNNode; 5 | import bayes.BayesianNetwork; 6 | 7 | 8 | import data.Attribute; 9 | import data.DataSet; 10 | 11 | /** 12 | * Builds a Bayesian Network with a Naive Bayes Structure 13 | * 14 | */ 15 | public class NaiveBayesBuilder extends NetworkBuilder 16 | { 17 | /** 18 | * Builds Bayesian network with a Naive bayes structure. 19 | * 20 | * @param data the data set used to construct the parameters. This 21 | * DataSet's class attribute must be set to a valid attribute. 
22 | */ 23 | @Override 24 | public BayesianNetwork buildNetwork(DataSet data, Integer laplaceCount) 25 | { 26 | BayesianNetwork net = super.setupNetwork(data, laplaceCount); 27 | net.setNetStructureAlgorithm(BayesianNetwork.StructureAlgorithm.NAIVE_BAYES); 28 | 29 | /* 30 | * Create edges from the class Node to all other nodes 31 | */ 32 | BNNode classAttrNode = net.getNode(data.getClassAttribute()); 33 | for (BNNode node : net.getNodes()) 34 | { 35 | if (!node.equals( classAttrNode )) 36 | { 37 | net.createEdge(classAttrNode, node, data, laplaceCount); 38 | } 39 | } 40 | 41 | return net; 42 | } 43 | 44 | } 45 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/structuresearch/NetworkBuilder.java: -------------------------------------------------------------------------------- 1 | package bayes.structuresearch; 2 | 3 | 4 | import bayes.BNNode; 5 | import bayes.BayesianNetwork; 6 | import data.Attribute; 7 | import data.DataSet; 8 | 9 | /** 10 | * Constructs a {@code BayesianNetwork} object 11 | * 12 | */ 13 | abstract class NetworkBuilder 14 | { 15 | /** 16 | * The Laplace count used when generating all parameters in the network 17 | */ 18 | protected Integer laplaceCount; 19 | 20 | public abstract BayesianNetwork buildNetwork(DataSet data, Integer laplaceCount); 21 | 22 | /** 23 | * Builds a new Bayesian network given a dataset. 24 | * 25 | * @param data the data set used to construct the network 26 | * @param laplaceCount the Laplace count used when generating all 27 | * parameters in the network 28 | * @return a constructed Bayesian network 29 | */ 30 | public BayesianNetwork setupNetwork(DataSet data, Integer laplaceCount) 31 | { 32 | this.laplaceCount = laplaceCount; 33 | 34 | BayesianNetwork net = new BayesianNetwork(); 35 | 36 | /* 37 | * Create a node corresponding to each nominal attribute in the 38 | * dataset. Continuous attributes are ignored. 
39 | */ 40 | for (Attribute attr : data.getAttributeSet().getAttributes()) 41 | { 42 | if (attr.getType() == Attribute.Type.NOMINAL) 43 | { 44 | BNNode newNode = new BNNode(attr); 45 | net.addNode( newNode, data, this.laplaceCount ); 46 | } 47 | } 48 | 49 | return net; 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/structuresearch/Operation.java: -------------------------------------------------------------------------------- 1 | package bayes.structuresearch; 2 | 3 | import bayes.BNNode; 4 | 5 | /** 6 | * Represents a single operation that can be performed on the network's 7 | * structure. 8 | * 9 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 10 | * 11 | */ 12 | public class Operation 13 | { 14 | /** 15 | * Types of operations 16 | */ 17 | public enum Type {ADD, REMOVE, REVERSE}; 18 | 19 | /** 20 | * Parent node of the edge 21 | */ 22 | private BNNode parent; 23 | 24 | /** 25 | * Child node of the edge 26 | */ 27 | private BNNode child; 28 | 29 | /** 30 | * This operation's type 31 | */ 32 | private Type type; 33 | 34 | /** 35 | * Constructor 36 | */ 37 | public Operation() {} 38 | 39 | /** 40 | * Constructor 41 | * 42 | * @param type this operation's type 43 | * @param parent the parent node of the edge 44 | * @param child the child node of the edge 45 | */ 46 | public Operation(Operation.Type type, BNNode parent, BNNode child) 47 | { 48 | this.type = type; 49 | this.parent = parent; 50 | this.child = child; 51 | } 52 | 53 | /** 54 | * @param parent the parent node of the edge 55 | */ 56 | public void setParent(BNNode parent) 57 | { 58 | this.parent = parent; 59 | } 60 | 61 | /** 62 | * @param child the child node of the edge 63 | */ 64 | public void setChild(BNNode child) 65 | { 66 | this.child = child; 67 | } 68 | 69 | /** 70 | * @param type this operation's type 71 | */ 72 | public void setType(Operation.Type type) 73 | { 74 | this.type = type; 75 | } 76 | 77 | /** 78 | * 
@return the child node of the edge 79 | */ 80 | public BNNode getChild() 81 | { 82 | return child; 83 | } 84 | 85 | /** 86 | * @return the parent node of the edge 87 | */ 88 | public BNNode getParent() 89 | { 90 | return parent; 91 | } 92 | 93 | /** 94 | * @return this operation's type 95 | */ 96 | public Operation.Type getType() 97 | { 98 | return this.type; 99 | } 100 | 101 | @Override 102 | public String toString() 103 | { 104 | String result = ""; 105 | 106 | switch(this.type) 107 | { 108 | case ADD: 109 | result += "ADD "; 110 | break; 111 | case REVERSE: 112 | result += "REVERSE "; 113 | break; 114 | case REMOVE: 115 | result += "REMOVE "; 116 | break; 117 | } 118 | 119 | result += this.parent.getName() + " -> " + this.child.getName(); 120 | 121 | return result; 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /machine-learning/src/bayes/structuresearch/score/ScoringFunction.java: -------------------------------------------------------------------------------- 1 | package bayes.structuresearch.score; 2 | 3 | import data.DataSet; 4 | 5 | import bayes.BayesianNetwork; 6 | 7 | public interface ScoringFunction 8 | { 9 | public Double scoreNet(BayesianNetwork net, DataSet data); 10 | } 11 | -------------------------------------------------------------------------------- /machine-learning/src/classify/ClassificationResult.java: -------------------------------------------------------------------------------- 1 | package classify; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import data.Attribute; 7 | import data.DataSet; 8 | import pair.Pair; 9 | 10 | /** 11 | * Objects of this class are used to store the results of a single 12 | * classification experiment. 
13 | * 14 | */ 15 | public class ClassificationResult 16 | { 17 | 18 | /** 19 | * String representation of the results 20 | */ 21 | private String resultStr = ""; 22 | 23 | /** 24 | * Accuracy of the classification task 25 | */ 26 | private Double accuracy; 27 | 28 | /** 29 | * Size of the test set used in the classification 30 | */ 31 | private Integer testDataSize; 32 | 33 | /** 34 | * Constructor 35 | * 36 | * @param resultList an ArrayList of predictions. Each Pair in this list 37 | * represents the classification of a single instance in the test data. 38 | * The first element of each pair refers to the predicted nominal value ID 39 | * the class attribute. The second element refers to the confidence of the 40 | * classifier. 41 | * 42 | * @param testData the DataSet object containing all test instances 43 | */ 44 | public ClassificationResult(List> resultList, 45 | DataSet testData) 46 | { 47 | Attribute classAttr = testData.getClassAttribute(); 48 | 49 | int correctCount = 0; 50 | 51 | for (int i = 0; i < resultList.size(); i++) 52 | { 53 | Integer classification = resultList.get(i).getFirst(); 54 | Integer truth = testData.getInstanceSet() 55 | .getInstanceById(i) 56 | .getAttributeValue(classAttr).intValue(); 57 | 58 | /* 59 | * Check for correct classification 60 | */ 61 | if (classification == truth) 62 | { 63 | correctCount++; 64 | } 65 | 66 | resultStr += classAttr.getNominalValueName(classification); 67 | resultStr += " "; 68 | resultStr += classAttr.getNominalValueName(truth); 69 | resultStr += " "; 70 | resultStr += resultList.get(i).getSecond(); 71 | resultStr += "\n"; 72 | } 73 | 74 | resultStr += "\n"; 75 | resultStr += correctCount; 76 | 77 | // Set metrics 78 | this.testDataSize = testData.getInstanceSet().getInstances().size(); 79 | this.accuracy = (double) correctCount / this.testDataSize; 80 | } 81 | 82 | /** 83 | * @return the classification accuracy from this experiment 84 | */ 85 | public Double getAccuracy() 86 | { 87 | return 
this.accuracy; 88 | } 89 | 90 | @Override 91 | public String toString() 92 | { 93 | return resultStr; 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /machine-learning/src/classify/Classifier.java: -------------------------------------------------------------------------------- 1 | package classify; 2 | 3 | import data.DataSet; 4 | 5 | /** 6 | * Any model learned to perform a supervised classification task should 7 | * implement this interface. 8 | * 9 | */ 10 | public interface Classifier 11 | { 12 | public ClassificationResult classifyData(DataSet testData); 13 | 14 | public Object getModel(); 15 | } 16 | -------------------------------------------------------------------------------- /machine-learning/src/classify/evaluate/PercentageError.java: -------------------------------------------------------------------------------- 1 | package classify.evaluate; 2 | 3 | /** 4 | * Methods for evaluating percentage errors from classification experiments. 
5 | * 6 | */ 7 | public class PercentageError 8 | { 9 | /** 10 | * Mean percentage error 11 | * 12 | * @param truthVals the true values 13 | * @param predictedVals the predicted values 14 | * @return the mean percentage error between the true and predicted values 15 | */ 16 | public static double meanPercentageError(Double[] truthVals, 17 | Double[] predictedVals) 18 | { 19 | double sum = 0.0; 20 | 21 | for (int i = 0; i < truthVals.length; i++) 22 | { 23 | sum += percentageError(truthVals[i], predictedVals[i]); 24 | } 25 | 26 | return (100.0 / truthVals.length) * sum; 27 | } 28 | 29 | /** 30 | * Percentage error 31 | * 32 | * @param truth the true value 33 | * @param predicted the predicted value 34 | * @return the precentage error between the true and predicted values 35 | */ 36 | public static double percentageError(Double truth, Double predicted) 37 | { 38 | return Math.abs((truth - predicted) / truth); 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /machine-learning/src/data/AttributeSet.java: -------------------------------------------------------------------------------- 1 | package data; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collection; 5 | import java.util.HashMap; 6 | import java.util.List; 7 | import java.util.Map; 8 | import java.util.Map.Entry; 9 | 10 | import com.google.common.collect.ImmutableMap; 11 | 12 | /** 13 | * Stores a set of attributes. 
14 | * 15 | */ 16 | public class AttributeSet 17 | { 18 | /** 19 | * Maps the name of an attribute to the object 20 | */ 21 | private final Map nameAttrMap; 22 | 23 | /** 24 | * The attribute that denotes the "class" or "concept" 25 | */ 26 | private String classAttribute; 27 | 28 | /** 29 | * Constructor 30 | */ 31 | public AttributeSet(List attributes) 32 | { 33 | ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); 34 | 35 | for (Attribute attr : attributes) 36 | { 37 | builder.put(attr.getName(), attr); 38 | } 39 | 40 | this.nameAttrMap = builder.build(); 41 | } 42 | 43 | /** 44 | * Get an attribute by its attribute name. 45 | * 46 | * @param attrName - the attribute name 47 | * @return the attribute with this name 48 | */ 49 | public Attribute getAttributeByName(String attrName) 50 | { 51 | return nameAttrMap.get(attrName); 52 | } 53 | 54 | /** 55 | * Get the nominal value ID for a specific attribute name and nominal 56 | * value of that attribute 57 | * 58 | * @param attrName the name of the target attribute 59 | * @param attrValue the name of the target nominal value 60 | * @return the unique integer ID of that nominal value 61 | */ 62 | public Integer getNominalValueId(String attrName, String attrValue) 63 | { 64 | return nameAttrMap.get(attrName).getNominalValueId(attrValue); 65 | } 66 | 67 | /** 68 | * @return a list of all attributes in the attribute set 69 | */ 70 | public List getAttributes() 71 | { 72 | return new ArrayList<>(nameAttrMap.values()); 73 | } 74 | 75 | /** 76 | * Sets the attribute that will be used as attribute set's class attribute 77 | * 78 | * @param attrName 79 | */ 80 | public void setClass(String attrName) 81 | { 82 | if (this.containsAttrWithName(attrName)) 83 | { 84 | classAttribute = attrName; 85 | } 86 | else 87 | { 88 | throw new RuntimeException("Trying to set an invalid attribute, " + 89 | attrName + " as the class attribute."); 90 | } 91 | } 92 | 93 | /** 94 | * @return the name of the attribute denoted as the 
"class" or "concept" 95 | * attribute 96 | */ 97 | public String getClassAttrName() 98 | { 99 | return classAttribute; 100 | } 101 | 102 | /** 103 | * Determines whether a given attribute name is in the attribute set. 104 | * 105 | * @param attrName The name of the attribute for which we are checking is 106 | * valid 107 | */ 108 | public Boolean containsAttrWithName(String attrName) 109 | { 110 | return nameAttrMap.containsKey(attrName); 111 | } 112 | 113 | } 114 | -------------------------------------------------------------------------------- /machine-learning/src/data/Instance.java: -------------------------------------------------------------------------------- 1 | package data; 2 | 3 | import java.util.HashMap; 4 | import java.util.Map; 5 | import java.util.Map.Entry; 6 | 7 | 8 | /** 9 | * Represents an instance. 10 | * 11 | */ 12 | public class Instance 13 | { 14 | /** 15 | * This instance's attribute value. The map maps an attribute ID 16 | * to a valid value for that attribute 17 | */ 18 | private final Map attributesToValues; 19 | 20 | /** 21 | * Constructor 22 | */ 23 | public Instance() 24 | { 25 | attributesToValues = new HashMap<>(); 26 | } 27 | 28 | 29 | /** 30 | * Get the value for an attribute. 31 | * 32 | * @param attr the specified attribute 33 | * @return this instance's value of the specified attribute 34 | */ 35 | public Double getAttributeValue(Attribute attr) 36 | { 37 | return attributesToValues.get(attr); 38 | } 39 | 40 | /** 41 | * Add an attribute-value pair to the instance. 42 | * 43 | * @param attrId the attribute ID of the attribute being added 44 | * @param value the value of the corresponding attribute 45 | */ 46 | public void addAttributeValue(Attribute attr, Double value) 47 | { 48 | attributesToValues.put(attr, value); 49 | } 50 | 51 | /** 52 | * Checks if this Instance is equal to another Instance. 
53 | * 54 | * @param o the other Instance 55 | * @return 56 | */ 57 | @Override 58 | public boolean equals(Object o) 59 | { 60 | Map other = ((Instance)o).attributesToValues; 61 | for(Entry attr: attributesToValues.entrySet()) 62 | { 63 | if(!other.get(attr.getKey()).equals(attr.getValue())) 64 | { 65 | return false; 66 | } 67 | } 68 | return true; 69 | } 70 | 71 | public String toString() 72 | { 73 | String result = ""; 74 | for(Entry entry: attributesToValues.entrySet()) 75 | { 76 | if (entry.getKey().getType() == Attribute.Type.NOMINAL) 77 | { 78 | result += "(" + entry.getKey().getName() + " = " + entry.getKey().getNominalValueName(entry.getValue().intValue()) + ") "; 79 | } 80 | else 81 | { 82 | result += "(" + entry.getKey().getName() + " = " + entry.getValue() + ") "; 83 | } 84 | } 85 | return result; 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /machine-learning/src/data/InstanceSet.java: -------------------------------------------------------------------------------- 1 | package data; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | /** 7 | * This class represents a set of {@code Instance} objects. 
8 | * 9 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 10 | * 11 | */ 12 | public class InstanceSet 13 | { 14 | /** 15 | * All instances in this instance set 16 | */ 17 | private final List instances; 18 | 19 | /** 20 | * Constructor 21 | */ 22 | public InstanceSet() 23 | { 24 | instances = new ArrayList(); 25 | } 26 | 27 | /** 28 | * @return a list of all instances in this instance set 29 | */ 30 | public List getInstances() 31 | { 32 | return instances; 33 | } 34 | 35 | /** 36 | * Add an instance to this instance set 37 | * 38 | * @param newInstance the new instance 39 | */ 40 | public void addInstance(Instance newInstance) 41 | { 42 | instances.add(newInstance); 43 | } 44 | 45 | /** 46 | * @param id the unique ID of a specific instance 47 | * @return the instance with specified ID 48 | */ 49 | public Instance getInstanceById(int id) 50 | { 51 | return instances.get(id); 52 | } 53 | 54 | public String toString(){ 55 | return instances.toString(); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /machine-learning/src/data/fold/KFoldCreator.java: -------------------------------------------------------------------------------- 1 | package data.fold; 2 | 3 | import java.util.ArrayList; 4 | import java.util.Collections; 5 | import java.util.List; 6 | 7 | import data.DataSet; 8 | import data.Instance; 9 | import data.InstanceSet; 10 | import data.reader.ArffReader; 11 | 12 | import pair.Pair; 13 | 14 | public class KFoldCreator { 15 | 16 | public static void main(String[] args){ 17 | ArffReader ar = new ArffReader(); 18 | DataSet data = ar.readFile("data/kfold_test.arff"); 19 | data.setClassAttribute("digit"); 20 | List> pairs = KFoldCreator.create(data, 5); 21 | for(Pair pair: pairs){ 22 | System.out.println("TRAIN:"); 23 | System.out.println(pair.getFirst()); 24 | System.out.println("\nTEST:"); 25 | System.out.println(pair.getSecond()); 26 | System.out.println("\n\n"); 27 | } 28 | } 29 | 30 | public static 
List> create(DataSet data, int K) { 31 | 32 | InstanceSet is = data.getInstanceSet(); 33 | 34 | List instances = is.getInstances(); 35 | Collections.shuffle(instances); 36 | 37 | 38 | int numPerSplice = instances.size()/K; 39 | List> pairs = new ArrayList<>(); 40 | 41 | //For each fold 42 | for(int i = 0; i < K; i++){ 43 | //fill training and testing lists 44 | List trainInst = new ArrayList<>(); 45 | List testInst = new ArrayList<>(); 46 | int left = i * numPerSplice; 47 | int right = i * numPerSplice + numPerSplice; 48 | trainInst.addAll(instances.subList(0, left)); 49 | if(i + 1 == K){ 50 | testInst.addAll(instances.subList(left, instances.size())); 51 | } else { 52 | testInst.addAll(instances.subList(left, right)); 53 | trainInst.addAll(instances.subList(right, instances.size())); 54 | } 55 | 56 | //put in sets 57 | InstanceSet trainSet = new InstanceSet(); 58 | for(Instance inst: trainInst){ 59 | trainSet.addInstance(inst); 60 | } 61 | 62 | InstanceSet testSet = new InstanceSet(); 63 | for(Instance inst: testInst){ 64 | testSet.addInstance(inst); 65 | } 66 | //create data sets 67 | DataSet trainData = new DataSet(data.getAttributeSet(), trainSet); 68 | DataSet testData = new DataSet(data.getAttributeSet(),testSet); 69 | 70 | //Add to list of pairs 71 | pairs.add(new Pair(trainData, testData)); 72 | 73 | } 74 | return pairs; 75 | } 76 | 77 | } 78 | -------------------------------------------------------------------------------- /machine-learning/src/distributions/Distribution.java: -------------------------------------------------------------------------------- 1 | package distributions; 2 | 3 | public interface Distribution 4 | { 5 | public double sample(); 6 | 7 | public double[] sampleMany(int numSamples); 8 | } 9 | -------------------------------------------------------------------------------- /machine-learning/src/distributions/GeometricDistribution.java: -------------------------------------------------------------------------------- 1 | package 
distributions;

import java.util.Random;

/**
 * Geometric distribution over the number of independent Bernoulli trials
 * needed to observe the first success. Samples are drawn from the support
 * {1, 2, 3, ...}.
 */
public class GeometricDistribution implements Distribution
{
    /**
     * Source of randomness for sampling
     */
    private final Random rng;

    /**
     * Success probability of each trial, in (0, 1]
     */
    private final double parameter;

    /**
     * Constructor. Uses a new, unseeded random number generator.
     *
     * @param parameter the success probability of each trial, in (0, 1]
     * @throws IllegalArgumentException if parameter is not in (0, 1]
     */
    public GeometricDistribution(double parameter)
    {
        this(parameter, new Random());
    }

    /**
     * Constructor with an explicit random number generator (e.g. a seeded one
     * for reproducible samples).
     *
     * @param parameter the success probability of each trial, in (0, 1]
     * @param rng the random number generator to draw from
     * @throws IllegalArgumentException if parameter is not in (0, 1]
     * @throws NullPointerException if rng is null
     */
    public GeometricDistribution(double parameter, Random rng)
    {
        // Written as a negation so that NaN is rejected too. A parameter <= 0
        // would make sample() loop forever.
        if (!(parameter > 0.0 && parameter <= 1.0))
        {
            throw new IllegalArgumentException(
                    "Geometric parameter must be in (0, 1], got " + parameter);
        }
        if (rng == null)
        {
            throw new NullPointerException("rng must not be null");
        }
        this.parameter = parameter;
        this.rng = rng;
    }

    /**
     * Draw one sample by simulating trials until the first success.
     *
     * @return the number of trials up to and including the first success
     */
    public double sample()
    {
        int trials = 1;
        while (rng.nextDouble() >= parameter)
        {
            trials++;
        }
        return trials;
    }

    /**
     * Draw independent samples.
     *
     * @param numSamples the number of samples to draw
     * @return an array holding numSamples samples
     */
    public double[] sampleMany(int numSamples)
    {
        double[] samples = new double[numSamples];
        for (int i = 0; i < numSamples; i++)
        {
            samples[i] = sample();
        }
        return samples;
    }
}

--------------------------------------------------------------------------------
/machine-learning/src/graph/Path.java:
--------------------------------------------------------------------------------
package graph;

import java.util.ArrayList;
import java.util.List;

/**
 * A path through a graph: the ordered nodes visited, starting at an origin,
 * and the accumulated length of the edges appended to it.
 *
 * @param <T> the node type
 */
public class Path<T>
{
    /**
     * The first node on the path
     */
    private final T origin;

    /**
     * The final node on the path
     */
    private final T destination;

    /**
     * Sum of the lengths of all edges appended so far
     */
    private double length;

    /**
     * Nodes on the path in order, beginning with the origin
     */
    private final List<T> path;

    /**
     * Create a path that initially contains only its origin.
     *
     * @param origin the first node on the path
     * @param destinationNode the node the path ends at; it is not added to the
     *        node list until it is appended with appendNodeToPath
     */
    public Path(T origin, T destinationNode)
    {
        this.origin = origin;
        this.destination = destinationNode;
        this.path = new ArrayList<>();
        this.path.add(origin);
    }

    /**
     * Extend the path by one edge.
     *
     * @param nextNode the node reached by the edge
     * @param edgeLength the length of the edge; added to the path length
     */
    public void appendNodeToPath(T nextNode, Double edgeLength)
    {
        length += edgeLength;
        path.add(nextNode);
    }

    /**
     * @return the live (not copied) list of nodes on the path
     */
    public List<T> getNodesOnPath()
    {
        return this.path;
    }

    /**
     * @return the first node on the path
     */
    public T getPathOrigin()
    {
        return this.origin;
    }

    /**
     * @return the final node on the path
     */
    public T getPathDestination()
    {
        return this.destination;
    }

    /**
     * @return the sum of all appended edge lengths
     */
    public double getPathLength()
    {
        return this.length;
    }

    /**
     * @return the nodes joined by " --> ", followed by " : " and the length
     */
    @Override
    public String toString()
    {
        StringBuilder str = new StringBuilder();
        for (int i = 0; i < path.size(); i++)
        {
            if (i > 0)
            {
                str.append(" --> ");
            }
            str.append(path.get(i));
        }
        str.append(" : ").append(this.length);
        return str.toString();
    }
}

--------------------------------------------------------------------------------
/machine-learning/src/graph/dag/DetectCycles.java:
--------------------------------------------------------------------------------
package graph.dag;

import java.util.ArrayDeque;
import java.util.Deque;

/**
 * Detects directed cycles in a graph given as an adjacency matrix.
 */
public class DetectCycles
{
    /**
     * Detects if there is a cycle in the graph.
     * <p>
     * Vertices without remaining incoming edges are removed repeatedly (Kahn's
     * algorithm). A cycle exists exactly when some vertex can never be
     * removed. A self-loop counts as a cycle. The input matrix is NOT
     * modified.
     *
     * @param graph the square adjacency matrix representing the graph, where
     *        graph[r][c] != null means an edge r -> c. Non-edges are
     *        represented by the null reference
     * @return true if a cycle has been detected, false otherwise
     */
    public static Boolean run(Double[][] graph)
    {
        int numVertices = graph.length;
        int[] inDegree = countIncomingEdges(graph);

        /*
         * Vertices whose incoming edges have all been removed
         */
        Deque<Integer> removable = new ArrayDeque<>();
        for (int v = 0; v < numVertices; v++)
        {
            if (inDegree[v] == 0)
            {
                removable.push(v);
            }
        }

        int numRemoved = 0;
        while (!removable.isEmpty())
        {
            int vertex = removable.pop();
            numRemoved++;

            // Removing a vertex deletes its outgoing edges
            for (int child = 0; child < numVertices; child++)
            {
                if (graph[vertex][child] != null && --inDegree[child] == 0)
                {
                    removable.push(child);
                }
            }
        }

        // Vertices that were never removed lie on, or downstream of, a cycle
        return numRemoved < numVertices;
    }

    /**
     * Count the incoming edges of every vertex.
     *
     * @param graph the square adjacency matrix
     * @return array whose element v is the number of non-null entries in
     *         column v
     */
    private static int[] countIncomingEdges(Double[][] graph)
    {
        int numVertices = graph.length;
        int[] inDegree = new int[numVertices];
        for (int r = 0; r < numVertices; r++)
        {
            for (int c = 0; c < numVertices; c++)
            {
                if (graph[r][c] != null)
                {
                    inDegree[c]++;
                }
            }
        }
        return inDegree;
    }
}

--------------------------------------------------------------------------------
/machine-learning/src/graph/dag/TopologicalSort.java:
--------------------------------------------------------------------------------
package graph.dag;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;

/**
 * Topologically sorts the vertices in a Directed Acyclic Graph (DAG).
 *
 * @author Matthew Bernstein - matthewb@cs.wisc.edu
 *
 */
public class TopologicalSort
{
    /**
     * Run the topological sort (Kahn's algorithm). The input matrix is NOT
     * modified.
     *
     * @param graph the square adjacency matrix representing the DAG, where
     *        graph[r][c] != null means an edge r -> c. Non-edges are
     *        represented by the null reference
     * @return the vertex indices ordered so that every vertex appears after
     *         all of its parents
     * @throws IllegalArgumentException if the graph contains a cycle (the
     *         previous implementation looped forever in that case)
     */
    public static ArrayList<Integer> run(Double[][] graph)
    {
        int numVertices = graph.length;
        int[] inDegree = countIncomingEdges(graph);

        /*
         * Vertices whose parents have all been emitted, in FIFO order
         */
        Deque<Integer> ready = new ArrayDeque<>();
        for (int v = 0; v < numVertices; v++)
        {
            if (inDegree[v] == 0)
            {
                ready.addLast(v);
            }
        }

        ArrayList<Integer> sorted = new ArrayList<>(numVertices);
        while (!ready.isEmpty())
        {
            int vertex = ready.removeFirst();
            sorted.add(vertex);

            // Emitting a vertex satisfies one parent of each of its children
            for (int child = 0; child < numVertices; child++)
            {
                if (graph[vertex][child] != null && --inDegree[child] == 0)
                {
                    ready.addLast(child);
                }
            }
        }

        if (sorted.size() < numVertices)
        {
            throw new IllegalArgumentException(
                    "Cannot topologically sort: the graph contains a cycle");
        }
        return sorted;
    }

    /**
     * Count the incoming edges of every vertex.
     *
     * @param graph the square adjacency matrix
     * @return array whose element v is the number of non-null entries in
     *         column v
     */
    private static int[] countIncomingEdges(Double[][] graph)
    {
        int numVertices = graph.length;
        int[] inDegree = new int[numVertices];
        for (int r = 0; r < numVertices; r++)
        {
            for (int c = 0; c < numVertices; c++)
            {
                if (graph[r][c] != null)
                {
                    inDegree[c]++;
                }
            }
        }
        return inDegree;
    }
}

--------------------------------------------------------------------------------
/machine-learning/src/graph/floydwarshall/AllPairsShortestPaths.java:
--------------------------------------------------------------------------------
package graph.floydwarshall;

import graph.Path;

import java.util.Map;
import java.util.Map.Entry;

import bimap.BiMap;

/**
 * Result of an all-pairs shortest path computation: shortest distances and
 * next-hop pointers between every pair of nodes.
 *
 * @param <T> the node type
 */
public class AllPairsShortestPaths<T>
{
    /**
     * Bidirectional map between matrix index and node
     */
    private final BiMap<Integer, T> indexToNode;

    /**
     * distanceMatrix[i][j] is the shortest distance from node i to node j
     */
    private final Double[][] distanceMatrix;

    /**
     * nextNodeMatrix[i][j] is the index of the node after i on a shortest path
     * from i to j, or null when j is unreachable from i
     */
    private final Integer[][] nextNodeMatrix;

    /**
     * Constructor.
     *
     * @param distanceMatrix shortest distances, indexed by node index
     * @param nextNodeMatrix next-hop indices (null = no path)
     * @param indexToNode map from node index to node
     */
    public AllPairsShortestPaths(Double[][] distanceMatrix, Integer[][] nextNodeMatrix, Map<Integer, T> indexToNode)
    {
        this.indexToNode = new BiMap<>();
        for (Entry<Integer, T> e : indexToNode.entrySet())
        {
            this.indexToNode.put(e.getKey(), e.getValue());
        }

        this.distanceMatrix = distanceMatrix;
        this.nextNodeMatrix = nextNodeMatrix;
    }

    /**
     * Reconstruct a shortest path.
     *
     * @param origin the start node
     * @param destination the end node
     * @return the shortest path, or null if no path exists
     */
    public Path<T> getPath(T origin, T destination)
    {
        if (!pathExists(origin, destination))
        {
            return null;
        }

        Path<T> path = new Path<>(origin, destination);

        int currNodeIndex = indexToNode.getKey(origin);
        int destNodeIndex = indexToNode.getKey(destination);

        while (currNodeIndex != destNodeIndex)
        {
            int nextNodeIndex = nextNodeMatrix[currNodeIndex][destNodeIndex];

            // The first hop of a shortest path is itself a shortest path, so
            // the shortest distance curr -> next is the weight of that edge.
            // This avoids subtracting two (possibly large) distances.
            double edgeWeight = distanceMatrix[currNodeIndex][nextNodeIndex];

            path.appendNodeToPath(indexToNode.getValue(nextNodeIndex), edgeWeight);
            currNodeIndex = nextNodeIndex;
        }

        return path;
    }

    /**
     * @param origin the start node
     * @param destination the end node
     * @return true if destination is reachable from origin; false otherwise,
     *         including when either node was not part of the graph
     */
    public boolean pathExists(T origin, T destination)
    {
        Integer originIndex = indexToNode.getKey(origin);
        Integer destIndex = indexToNode.getKey(destination);

        // NOTE(review): assumes BiMap.getKey returns null for unknown values
        if (originIndex == null || destIndex == null)
        {
            return false;
        }
        return nextNodeMatrix[originIndex][destIndex] != null;
    }

}

--------------------------------------------------------------------------------
/machine-learning/src/graph/floydwarshall/FloydWarshall.java:
--------------------------------------------------------------------------------
package graph.floydwarshall;

import graph.DirectedGraph;
import graph.Path;

import java.util.HashMap;
import java.util.Map;

import pair.Pair;

/**
 * An implementation of the Floyd-Warshall algorithm for computing all-pairs
 * shortest paths on a directed graph with no negative cycles.
 *
 */
public class FloydWarshall
{
    /**
     * Compute shortest paths between all pairs of nodes.
     *
     * @param graph the input graph; must not contain negative cycles
     * @return the shortest distances and paths between all node pairs
     */
    public static <T> AllPairsShortestPaths<T> runFloydWarshall(DirectedGraph<T> graph)
    {
        Map<Integer, T> indexToNode = mapNodesToIndices(graph);

        Pair<Double[][], Integer[][]> matrices = initializeMatrices(graph, indexToNode);

        Double[][] distanceMatrix = matrices.getFirst();
        Integer[][] nextNodeMatrix = matrices.getSecond();

        computeShortestPaths(distanceMatrix, nextNodeMatrix);

        return new AllPairsShortestPaths<>(distanceMatrix, nextNodeMatrix, indexToNode);
    }

    /**
     * Assign every node a consecutive index starting at 0, in the iteration
     * order of graph.getNodes().
     *
     * @param graph the input graph
     * @return map from index to node
     */
    private static <T> Map<Integer, T> mapNodesToIndices(DirectedGraph<T> graph)
    {
        Map<Integer, T> indexToNode = new HashMap<>();
        int nodeIndex = 0;
        for (T node : graph.getNodes())
        {
            indexToNode.put(nodeIndex++, node);
        }
        return indexToNode;
    }

    /**
     * Initialize the matrices for storing shortest distances between nodes and the matrix
     * storing the next node from each node along that shortest path.
     *
     * @param graph the input graph
     * @param indexToNode map of the node index to the node object
     * @return the distance matrix and next node matrix
     */
    private static <T> Pair<Double[][], Integer[][]> initializeMatrices(DirectedGraph<T> graph,
                                                                       Map<Integer, T> indexToNode)
    {
        int numNodes = graph.getNodes().size();
        Double[][] distanceMatrix = new Double[numNodes][numNodes];
        Integer[][] nextNodeMatrix = new Integer[numNodes][numNodes];

        for (int origIndex = 0; origIndex < numNodes; origIndex++)
        {
            T origin = indexToNode.get(origIndex);
            for (int destIndex = 0; destIndex < numNodes; destIndex++)
            {
                T destination = indexToNode.get(destIndex);

                if (origIndex == destIndex)
                {
                    // Self-loops are ignored: a node is at distance 0 from itself
                    distanceMatrix[origIndex][destIndex] = 0d;
                    nextNodeMatrix[origIndex][destIndex] = origIndex;
                }
                else if (graph.edgeExists(origin, destination))
                {
                    distanceMatrix[origIndex][destIndex] = graph.getEdgeWeight(origin, destination);
                    nextNodeMatrix[origIndex][destIndex] = destIndex;
                }
                else
                {
                    distanceMatrix[origIndex][destIndex] = Double.POSITIVE_INFINITY;
                    nextNodeMatrix[origIndex][destIndex] = null;
                }
            }
        }

        return new Pair<>(distanceMatrix, nextNodeMatrix);
    }

    /**
     * Run the Floyd-Warshall relaxation, updating both matrices in place.
     *
     * @param distanceMatrix initial distances; overwritten with shortest distances
     * @param nextNodeMatrix initial next hops; overwritten with shortest-path next hops
     */
    private static void computeShortestPaths(Double[][] distanceMatrix,
                                             Integer[][] nextNodeMatrix)
    {
        int numNodes = distanceMatrix.length;
        for (int k = 0; k < numNodes; k++)
        {
            for (int i = 0; i < numNodes; i++)
            {
                double distToK = distanceMatrix[i][k];
                if (distToK == Double.POSITIVE_INFINITY)
                {
                    // k is unreachable from i, so no path i -> k -> j can improve
                    continue;
                }
                for (int j = 0; j < numNodes; j++)
                {
                    double viaK = distToK + distanceMatrix[k][j];
                    if (distanceMatrix[i][j] > viaK)
                    {
                        distanceMatrix[i][j] = viaK;
                        nextNodeMatrix[i][j] = nextNodeMatrix[i][k];
                    }
                }
            }
        }
    }

}
--------------------------------------------------------------------------------
/machine-learning/src/graph/prim/Edge.java:
--------------------------------------------------------------------------------
package graph.prim;

import java.util.Comparator;

import pair.Pair;

/**
 * Implements a directed edge in a weighted graph where edge weights are
 * floating point values and nodes are represented by unique integers.
 *
 */
public class Edge
{
    /**
     * The two vertices that constitute this edge. The edge points from the
     * first vertex to the second: first -> second.
     */
    Pair<Integer, Integer> vertices;

    /**
     * The edge weight
     */
    Double weight;

    /**
     * Compares edges by weight in DESCENDING order: heavier edges sort first.
     */
    public static final Comparator<Edge> EDGE_ORDER =
            new Comparator<Edge>()
            {
                @Override
                public int compare(Edge e1, Edge e2)
                {
                    // Compare values, not references: '==' on two boxed
                    // Doubles is an identity check, which made equal weights
                    // compare as unequal in both directions and broke the
                    // Comparator contract. The arguments are swapped to keep
                    // the original descending order.
                    return Double.compare(e2.weight, e1.weight);
                }
            };

    /**
     * Constructor
     *
     * @param vertices the two vertices in the edge (origin, destination)
     * @param weight the weight of the edge
     */
    public Edge(Pair<Integer, Integer> vertices, Double weight)
    {
        this.vertices = vertices;
        this.weight = weight;
    }

    /**
     * Constructor
     *
     * @param origin the vertex the edge points from
     * @param destination the vertex the edge points to
     * @param weight the weight of the edge (previously dropped, leaving the
     *        weight null)
     */
    public Edge(Integer origin, Integer destination, Double weight)
    {
        this(new Pair<Integer, Integer>(origin, destination), weight);
    }

    /**
     * @return the edge weight
     */
    public Double getWeight()
    {
        return this.weight;
    }

    /**
     * @return the pair of vertices
     */
    public Pair<Integer, Integer> getVertices()
    {
        return this.vertices;
    }

    /**
     * @return the first vertex
     */
    public Integer getFirstVertex()
    {
        return this.vertices.getFirst();
    }

    /**
     * @return the second vertex
     */
    public Integer getSecondVertex()
    {
        return this.vertices.getSecond();
    }

    /**
     * @return the edge formatted as {@code <first--weight--second>}
     */
    @Override
    public String toString()
    {
        return "<" + this.vertices.getFirst()
                + "--" + this.weight
                + "--" + this.vertices.getSecond()
                + ">";
    }
}

--------------------------------------------------------------------------------
/machine-learning/src/hmm/StateContainer.java:
--------------------------------------------------------------------------------
package hmm;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

/**
 * This class is used to store the state objects that comprise the Markov
 * model.
 */
public class StateContainer
{
    /**
     * Maps a state ID to a state object
     */
    Map<String, State> states;

    /**
     * Constructor.
     */
    public StateContainer()
    {
        states = new HashMap<String, State>();
    }

    /**
     * @return a live view of all of the states in the model, backed by this
     *         container
     */
    public Collection<State> getStates()
    {
        return states.values();
    }

    /**
     * @return a new list of every state whose isSilent flag is set
     */
    public Collection<State> getSilentStates()
    {
        // TODO: this scans every state on each call; cache if it becomes hot
        ArrayList<State> silent = new ArrayList<State>();
        for (State s : states.values())
        {
            if (s.isSilent)
            {
                silent.add(s);
            }
        }
        return silent;
    }

    /**
     * Add a state to this container, replacing any state with the same ID.
     *
     * @param newState the new state
     */
    public void addState(State newState)
    {
        states.put(newState.getId(), newState);
    }

    /**
     * Retrieve a state with a specified unique ID from this state container.
     *
     * @param id the unique ID of the state to be retrieved
     * @return The state that stores the specified ID. If a state with the
     * specified ID does not exist in this container, this method returns
     * null.
     */
    public State getStateById(String id)
    {
        return states.get(id);
    }

    /**
     * Determines whether this state container contains a state with the
     * specified ID
     *
     * @param id the ID of the state we are querying for
     * @return true if this StateContainer has a state with specified ID,
     * false otherwise
     */
    public boolean containsState(String id)
    {
        return states.containsKey(id);
    }
}

--------------------------------------------------------------------------------
/machine-learning/src/hmm/StateParamsTied.java:
--------------------------------------------------------------------------------
package hmm;

import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;

import math.LogP;


/**
 * Implements a state whose emission probability distribution is tied to that
 * of another state.
 *
 * @author Matthew Bernstein - matthewb@cs.wisc.edu
 *
 */
public class StateParamsTied extends State
{
    /**
     * Stores all emission probability distributions for all states with
     * tied emission parameters, keyed by parameters key and then by symbol.
     * Values are log-probabilities (they are printed through LogP.exp).
     */
    public static Map<String, Map<String, Double>> tiedEmissionParams
            = new HashMap<String, Map<String, Double>>();

    /**
     * The ID of the parameters that this State uses
     */
    private final String paramsKey;

    /**
     * Constructor.
     *
     * @param paramsKey the ID of the parameters that this State uses
     * @param id the ID of this state
     */
    public StateParamsTied(String paramsKey, String id)
    {
        super(id);
        this.paramsKey = paramsKey;
        this.initializeParams();
    }

    /**
     * Create a tied-parameter state from an existing state.
     *
     * @param orig the state to copy
     * @param paramsKey the ID of the parameters that this State uses
     */
    public StateParamsTied(State orig, String paramsKey)
    {
        super(orig);
        this.paramsKey = paramsKey;
        this.initializeParams();
    }

    /**
     * Copy constructor. The shared parameters already exist for orig's key.
     *
     * @param orig the state to copy
     */
    public StateParamsTied(StateParamsTied orig)
    {
        super(orig);
        this.paramsKey = orig.paramsKey;
    }

    /**
     * @return the emission probabilities from this state (the shared map)
     */
    @Override
    public Map<String, Double> getEmissionProbabilites()
    {
        return tiedEmissionParams.get(this.paramsKey);
    }

    /**
     * Add an emission probability to this State. Because parameters are
     * shared, this affects every state with the same parameters key.
     *
     * @param symbol the symbol this state will emit
     * @param probability the probability the state will emit this symbol
     */
    @Override
    public void addEmission(String symbol, Double probability)
    {
        tiedEmissionParams.get(this.paramsKey).put(symbol, probability);
    }

    /**
     * Get the emission probability of a specific symbol from this state
     *
     * @param symbol the symbol of interest
     * @return the emission probability, or LogP.ln(0.0) for an unknown symbol
     */
    @Override
    public double getEmissionProb(String symbol)
    {
        Double logProb = tiedEmissionParams.get(this.paramsKey).get(symbol);
        return logProb != null ? logProb : LogP.ln(0.0);
    }

    /**
     * @return the ID of the parameters that this State uses
     */
    public String getParamsKey()
    {
        return this.paramsKey;
    }

    /**
     * Create an empty shared parameter map for this key if none exists yet.
     */
    private void initializeParams()
    {
        if (!StateParamsTied.tiedEmissionParams.containsKey(this.paramsKey))
        {
            StateParamsTied.tiedEmissionParams.put(this.paramsKey,
                    new HashMap<String, Double>());

            // TODO INITIALIZE PARAMETER PROBABILITIES IN MAP
        }
    }

    @Override
    public String toString()
    {
        StringBuilder result = new StringBuilder();
        result.append("[").append(this.id).append("]").append("\n");

        result.append("............\n");

        for (Entry<String, Transition> e : transitions.entrySet())
        {
            result.append(LogP.exp(e.getValue().getTransitionProbability()))
                  .append(" --> ")
                  .append("[").append(e.getKey()).append("]")
                  .append("\n");
        }

        result.append("............\n");

        for (Entry<String, Double> entry :
                StateParamsTied.tiedEmissionParams.get(this.paramsKey).entrySet())
        {
            result.append(entry.getKey())
                  .append(" >> ")
                  .append(LogP.exp(entry.getValue()))
                  .append("\n");
        }

        result.append("............\n");
        result.append("\n");

        return result.toString();
    }
}

--------------------------------------------------------------------------------
/machine-learning/src/hmm/StateSilent.java:
--------------------------------------------------------------------------------
package hmm;

import java.util.Map;
import java.util.Map.Entry;

import math.LogP;


/**
 * Implements a state that does not emit any symbols.
 *
 * @author Matthew Bernstein - matthewb@cs.wisc.edu
 *
 */
public class StateSilent extends State
{
    /**
     * Constructor
     */
    public StateSilent()
    {
        super();
        this.isSilent = true;
    }

    /**
     * Constructor
     *
     * @param id the ID of this state
     */
    public StateSilent(String id)
    {
        super(id);
        this.isSilent = true;
    }

    /**
     * Silent states have no emission distribution: reports the misuse on
     * standard error.
     *
     * @return always null
     */
    @Override
    public Map<String, Double> getEmissionProbabilites()
    {
        System.err.println("Attempting to retrieve emission probabilities " +
                           "for silent state " + this.id);
        return null;
    }

    /**
     * Silent states cannot emit: reports the misuse on standard error and
     * ignores the emission.
     */
    @Override
    public void addEmission(String symbol, Double probability)
    {
        System.err.println("Attempting to add emission probabilities to " +
                           "silent state " + this.id);
    }

    /**
     * @return LogP.ln(0.0) for every symbol, i.e. the log-space value for
     *         probability zero (the same value tied states use for unseen
     *         symbols)
     */
    @Override
    public double getEmissionProb(String symbol)
    {
        return LogP.ln(0.0);
    }

    @Override
    public String toString()
    {
        StringBuilder result = new StringBuilder();
        result.append("[").append(this.id).append("]").append("\n");

        result.append("............\n");

        for (Entry<String, Transition> e : transitions.entrySet())
        {
            result.append(LogP.exp(e.getValue().getTransitionProbability()))
                  .append(" --> ")
                  .append("[").append(e.getKey()).append("]")
                  .append("\n");
        }

        result.append("............\n");
        result.append("silent\n");
        result.append("............\n");
        result.append("\n");

        return result.toString();
    }

}

--------------------------------------------------------------------------------
/machine-learning/src/hmm/Transition.java:
--------------------------------------------------------------------------------
package hmm;

import math.LogP;

/**
 * This class implements a transition between two states in the Markov model:
 * an "origin" state and a "destination" state. That is, each Transition
 * object represents a transition from the origin state to the
 * destination state, and stores the probability of taking it. The
 * probability is kept in log space: it is combined with LogP.sum and the
 * state classes print it through LogP.exp.
 */
public class Transition
{
    /**
     * ID of the origin state
     */
    private String originId;

    /**
     * ID of the destination state
     */
    private String destinationId;

    /**
     * The transition's associated (log-)probability
     */
    private double probability;

    /**
     * Constructor
     *
     * @param originId ID of the origin state
     * @param destinationId ID of the destination state
     * @param probability (log-)probability of taking that transition
     */
    public Transition(String originId, String destinationId, double probability)
    {
        this.originId = originId;
        this.destinationId = destinationId;
        this.probability = probability;
    }

    /**
     * Copy constructor
     *
     * @param t the transition to copy
     */
    public Transition(Transition t)
    {
        this.originId = t.originId;
        this.destinationId = t.destinationId;
        this.probability = t.probability;
    }

    /**
     * Get the ID of the state this transition moves to.
     *
     * @return the destination ID
     */
    public String getDestinationId()
    {
        return destinationId;
    }

    /**
     * Set the ID of the state this transition moves to.
     *
     * @param destinationId the destination ID
     */
    public void setDestinationId(String destinationId)
    {
        this.destinationId = destinationId;
    }

    /**
     * Get the ID of the state this transition moves from.
     *
     * @return the origin ID
     */
    public String getOriginId()
    {
        return originId;
    }

    /**
     * Set the ID of the state this transition moves from.
     *
     * @param originId the ID of the state this transition moves from
     */
    public void setOriginId(String originId)
    {
        this.originId = originId;
    }

    /**
     * @param probability the transition (log-)probability
     */
    public void setTransitionProbability(double probability)
    {
        this.probability = probability;
    }

    /**
     * @return the transition (log-)probability
     */
    public double getTransitionProbability()
    {
        return this.probability;
    }

    /**
     * Add a value to the stored probability, combining the two in log space
     * via LogP.sum.
     *
     * @param value the log-space value to add
     */
    public void incrementTransitionValue(double value)
    {
        this.probability = LogP.sum(this.probability, value);
    }

}

--------------------------------------------------------------------------------
/machine-learning/src/hmm/algorithms/DpMatrix.java:
--------------------------------------------------------------------------------
package hmm.algorithms;

import hmm.HMM;
import hmm.State;

import java.util.ArrayList;

import math.LogP;

import bimap.BiMap;

/**
 * Dynamic programming matrix for HMM algorithms: one row per state and one
 * column per time step, where column 0 precedes the first emitted symbol.
 */
public class DpMatrix
{
    private int numRows;
    private int numCols;

    /**
     * Maps each state to a row
     */
    private BiMap<State, Integer> stateRowMap;

    /**
     * Map each time step (i.e. column) to a symbol
     */
    private ArrayList<String> colSymbolMap;

    /**
     * The matrix
     */
    private DpMatrixElement[][] matrix;

    /**
     * Constructor.
     *
     * @param model the HMM whose states label the rows
     * @param sequence the observed symbols; one extra leading column is added
     */
    public DpMatrix(HMM model, String[] sequence)
    {
        numRows = model.getNumStates();

        numCols = sequence.length + 1;

        initStateRowMap(model);
        initColSymbolMap(sequence);
        initMatrix();
    }

    /**
     * Assign each state of the model a row, in the iteration order of the
     * model's state container.
     *
     * @param model the HMM
     */
    public void initStateRowMap(HMM model)
    {
        stateRowMap = new BiMap<State, Integer>();

        int row = 0;
        for (State state : model.getStateContainer().getStates())
        {
            stateRowMap.put(state, row++);
        }
    }

    /**
     * Label each column with its symbol; column 0 is labelled "-".
     *
     * @param sequence the observed symbols
     */
    public void initColSymbolMap(String[] sequence)
    {
        colSymbolMap = new ArrayList<String>(sequence.length + 1);

        /*
         * The first time unit does not see a symbol emitted
         */
        colSymbolMap.add("-");

        for (String symbol : sequence)
        {
            colSymbolMap.add(symbol);
        }
    }

    public double getValue(State state, int timeUnit)
    {
        int row = stateRowMap.getValue(state);
        return matrix[row][timeUnit].getValue();
    }

    public void setValue(State state, int timeUnit, double value)
    {
        int row = stateRowMap.getValue(state);
        matrix[row][timeUnit].setValue(value);
    }

    /**
     * Record which state preceded currState at the given time unit.
     */
    public void setPreviousState(State currState, int timeUnit, State prevState)
    {
        int currRow = stateRowMap.getValue(currState);
        int backRow = stateRowMap.getValue(prevState);

        matrix[currRow][timeUnit].setRowPointer(backRow);
    }

    /**
     * @return the state recorded as preceding currState at the given time
     *         unit (row 0's state if none was ever recorded, since the row
     *         pointer defaults to 0)
     */
    public State getPreviousState(State currState, int timeUnit)
    {
        int currRow = stateRowMap.getValue(currState);
        int rowPointer = matrix[currRow][timeUnit].getRowPointer();

        return stateRowMap.getKey(rowPointer);
    }

    /**
     * Allocate the matrix and fill it with fresh elements (value NaN).
     */
    public void initMatrix()
    {
        matrix = new DpMatrixElement[numRows][numCols];

        for (int r = 0; r < numRows; r++)
        {
            for (int c = 0; c < numCols; c++)
            {
                matrix[r][c] = new DpMatrixElement();
            }
        }
    }

    /**
     * Render the score of each element of the dynamic programming matrix,
     * converted out of log space with LogP.exp.
     */
    @Override
    public String toString()
    {
        StringBuilder result = new StringBuilder();

        result.append("\nDynamic Programming Matrix:\n\n");

        result.append("\t\t");

        // Print the symbol over each column
        for (String c : colSymbolMap)
        {
            result.append(c).append("\t");
        }
        result.append("\n");

        for (int r = 0; r < numRows; r++)
        {
            result.append("[").append(stateRowMap.getKey(r).getId()).append("]\t");

            for (int c = 0; c < numCols; c++)
            {
                result.append(LogP.exp(matrix[r][c].getValue())).append("\t");
            }
            result.append("\n");
        }
        result.append("\n");

        return result.toString();
    }

    public int getNumColumns()
    {
        return this.numCols;
    }

}

--------------------------------------------------------------------------------
/machine-learning/src/hmm/algorithms/DpMatrixElement.java:
--------------------------------------------------------------------------------
package hmm.algorithms;

/**
 * A single cell of a DpMatrix: a value plus a back pointer.
 */
public class DpMatrixElement
{
    /**
     * Pointer to the row of the element that determined this element's
     * value (0 until set)
     */
    private int rowPointer;

    /**
     * Pointer to the column of the element that determined this element's
     * value (0 until set)
     */
    private int columnPointer;

    /**
     * This element's value; NaN until set
     */
    private double value = Double.NaN;

    public double getValue()
    {
        return value;
    }

    public void setValue(double value)
    {
        this.value = value;
    }

    public void setBackPointer(int row, int column)
    {
        this.rowPointer = row;
        this.columnPointer = column;
    }

    public void setRowPointer(int row)
    {
        this.rowPointer = row;
    }

    public void setColumnPointer(int column)
    {
        this.columnPointer = column;
    }

    public int getRowPointer()
    {
        return this.rowPointer;
    }

    public int getColumnPointer()
    {
        return this.columnPointer;
    }

}

--------------------------------------------------------------------------------
/machine-learning/src/hmm/algorithms/SortSilentStates.java:
--------------------------------------------------------------------------------
package hmm.algorithms;

import graph.dag.TopologicalSort;
import hmm.HMM;
import hmm.State;
import hmm.Transition;

import java.util.ArrayList;
import java.util.Collection;

import bimap.BiMap;

/**
 * Orders the silent states of an HMM.
 */
public class SortSilentStates
{
    /**
     * Topologically sort the model's silent states with respect to the
     * transitions between them, so that a silent state appears after every
     * silent state that transitions into it. The silent states must not form
     * a cycle.
     *
     * @param model the HMM
     * @return the silent states in topological order
     */
    public static ArrayList<State> run(HMM model)
    {
        Collection<State> silentStates = model.getSilentStates();
        int numNodes = silentStates.size();

        /*
         * Bidirectional map between matrix index and silent-state ID
         */
        BiMap<Integer, String> indices = new BiMap<Integer, String>();
        int count = 0;
        for (State s : silentStates)
        {
            indices.put(count++, s.getId());
        }

        /*
         * Adjacency matrix restricted to silent -> silent transitions. Java
         * initializes the elements to null, which means "no edge".
         */
        Double[][] graph = new Double[numNodes][numNodes];
        for (State s : silentStates)
        {
            int originIndex = indices.getKey(s.getId());
            for (Transition t : s.getTransitions())
            {
                String destId = t.getDestinationId();
                if (indices.containsValue(destId))
                {
                    graph[originIndex][indices.getKey(destId)] = 1.0;
                }
            }
        }

        ArrayList<State> sorted = new ArrayList<State>(numNodes);
        for (Integer i : TopologicalSort.run(graph))
        {
            sorted.add(model.getStateById(indices.getValue(i)));
        }

        return sorted;
    }
}
-------------------------------------------------------------------------------- /machine-learning/src/math/LogP.java: -------------------------------------------------------------------------------- 1 | package math; 2 | 3 | /** 4 | * Methods for dealing in log-probability space. This class includes methods 5 | * for converting to and from log-values as well as taking products and 6 | * summations over log-probabilities. 7 | * 8 | */ 9 | public class LogP 10 | { 11 | /** 12 | * Raises a log-probability to the power E thereby converting the 13 | * log-probability to a proper probability. 14 | * 15 | * @param x the log-probability 16 | * @return the proper probability 17 | */ 18 | public static double exp(double x) 19 | { 20 | if (Double.isNaN(x)) 21 | { 22 | return 0; 23 | } 24 | else 25 | { 26 | return Math.pow(Math.E, x); 27 | } 28 | } 29 | 30 | /** 31 | * Take the natural logarithm of a probability thereby converting it to a 32 | * log-probability. 33 | * 34 | * @param x a probability 35 | * @return the log-probability 36 | */ 37 | public static double ln(double x) 38 | { 39 | if (x == 0.0) 40 | { 41 | return Double.NaN; 42 | } 43 | else if (x > 0.0) 44 | { 45 | return Math.log(x); 46 | } 47 | else 48 | { 49 | throw new IllegalArgumentException("Passed 'eLn' function the " + 50 | "negative value " + x + ". 
" + 51 | "Argument must be greater than " + 52 | "zero."); 53 | } 54 | } 55 | 56 | /** 57 | * Take the sum of two log-probabilities 58 | * 59 | * @param eLnX the first log-probability 60 | * @param eLnY the second log-probability 61 | * @return the sum 62 | */ 63 | public static double sum(double eLnX, double eLnY) 64 | { 65 | if (Double.isNaN(eLnX) || Double.isNaN(eLnY)) 66 | { 67 | if (Double.isNaN(eLnX)) 68 | { 69 | return eLnY; 70 | } 71 | else 72 | { 73 | return eLnX; 74 | } 75 | } 76 | else 77 | { 78 | if (eLnX > eLnY) 79 | { 80 | return eLnX + ln(1 + Math.pow(Math.E, eLnY - eLnX)); 81 | } 82 | else 83 | { 84 | return eLnY + ln(1 + Math.pow(Math.E, eLnX - eLnY)); 85 | } 86 | } 87 | } 88 | 89 | /** 90 | * Take the product of two log-probabilities 91 | * 92 | * @param eLnX the first log-probability 93 | * @param eLnY the second log-probability 94 | * @return the product 95 | */ 96 | public static double prod(double eLnX, double eLnY) 97 | { 98 | if (Double.isNaN(eLnX) || Double.isNaN(eLnY)) 99 | { 100 | return Double.NaN; 101 | } 102 | else 103 | { 104 | return eLnX + eLnY; 105 | } 106 | } 107 | 108 | /** 109 | * Take the quotient of two log-probabilities 110 | * 111 | * @param eLnX the dividend 112 | * @param eLnY the divisor 113 | * @return the quotient 114 | */ 115 | public static double div(double eLnX, double eLnY) 116 | { 117 | if (Double.isNaN(eLnY)) 118 | { 119 | throw new IllegalArgumentException("Passed 'eLnDivision' " + 120 | "function the a NaN quotient. Argument must be real."); 121 | } 122 | else if (Double.isNaN(eLnX)) 123 | { 124 | return Double.NaN; 125 | } 126 | else 127 | { 128 | return eLnX - eLnY; 129 | } 130 | } 131 | } 132 | -------------------------------------------------------------------------------- /machine-learning/src/tree/DtLeaf.java: -------------------------------------------------------------------------------- 1 | package tree; 2 | 3 | import data.Attribute; 4 | 5 | /** 6 | * The leaf of a decision tree. 
7 | * 8 | */ 9 | public class DtLeaf extends DtNode 10 | { 11 | /** 12 | * This integer must be a valid nominal value of the class attribute for the 13 | * current learning problem 14 | */ 15 | private Integer classLabel; 16 | 17 | public DtLeaf(Attribute attribute, 18 | Double value, 19 | Relation relation, 20 | Integer classLabel) 21 | { 22 | super(attribute, value, relation); 23 | this.classLabel = classLabel; 24 | } 25 | 26 | @Override 27 | public void addChild(Node child) 28 | { 29 | throw new UnsupportedOperationException("Error. Attempting to add" + 30 | " a child Node to Leaf " + super.toString()); 31 | } 32 | 33 | public Integer getClassLabel() 34 | { 35 | return classLabel; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /machine-learning/src/tree/Forest.java: -------------------------------------------------------------------------------- 1 | package tree; 2 | 3 | import java.util.Set; 4 | 5 | import pair.Pair; 6 | import data.Instance; 7 | 8 | public class Forest 9 | { 10 | private Set trees; 11 | 12 | } 13 | -------------------------------------------------------------------------------- /machine-learning/src/tree/Node.java: -------------------------------------------------------------------------------- 1 | package tree; 2 | 3 | import java.util.HashSet; 4 | import java.util.Set; 5 | 6 | /** 7 | * A node in a tree. 8 | * 9 | */ 10 | public class Node 11 | { 12 | /** 13 | * This nodes parent in the tree 14 | */ 15 | protected Node parent = null; 16 | 17 | /** 18 | * This node's child in the tree 19 | */ 20 | protected final Set children; 21 | 22 | /** 23 | * Constructor for creating a node 24 | * @param nodeId 25 | */ 26 | public Node() 27 | { 28 | children = new HashSet(); 29 | } 30 | 31 | /** 32 | * Add a child Node. 33 | * 34 | * @param child 35 | */ 36 | public void addChild(Node child) 37 | { 38 | child.setParent(this); 39 | children.add(child); 40 | } 41 | 42 | /** 43 | * Remove a child Node. 
44 | * 45 | * @param child 46 | * @return true if the Node existed as child of the current Node 47 | * and was successfully removed 48 | */ 49 | public Boolean removeChild(Node child) 50 | { 51 | child.setParent(null); 52 | return children.remove(child); 53 | } 54 | 55 | /** 56 | * @return this node's children nodes 57 | */ 58 | public Set getChildren() 59 | { 60 | return children; 61 | } 62 | 63 | /** 64 | * Set the parent of the node. 65 | * 66 | * @param parent the node's parent 67 | */ 68 | public void setParent(Node parent) 69 | { 70 | this.parent = parent; 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /machine-learning/src/tree/algorithms/DecisionTreeBuilder.java: -------------------------------------------------------------------------------- 1 | package tree.algorithms; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import classify.Classifier; 7 | import tree.DecisionTree; 8 | import tree.DtLeaf; 9 | import tree.DtNode; 10 | import tree.Node; 11 | import tree.DtNode.Relation; 12 | import tree.train.Split; 13 | import tree.train.SplitBranch; 14 | import tree.train.SplitGenerator; 15 | import data.Attribute; 16 | import data.DataSet; 17 | 18 | public abstract class DecisionTreeBuilder 19 | { 20 | 21 | /** 22 | * The decision tree under construction 23 | */ 24 | protected DecisionTree decisionTree = null; 25 | 26 | /** 27 | * Determine when the recursion should stop and a leaf node should be constructed. 28 | * 29 | * @param data the data set used to make this decision 30 | * @param availAttributes available attributes 31 | * @param candidateSplits splits that are under consideration to split on 32 | * @return whether or not the recursion should terminate 33 | */ 34 | protected abstract boolean checkStoppingCriteria(DataSet data, 35 | List availAttributes, 36 | List candidateSplits); 37 | 38 | /** 39 | * Determine the best attribute to split on. 
40 | * 41 | * @param data the data set used to make this decision 42 | * @param candidateSplits the set of candidate split 43 | * @return the best split among the candidates 44 | */ 45 | protected abstract Split determineBestSplit(DataSet data, List candidateSplits); 46 | 47 | public DecisionTree buildDecisionTree(DataSet data) 48 | { 49 | List availAttributes = new ArrayList<>(data.getAttributeSet().getAttributes()); 50 | availAttributes.remove(data.getClassAttribute()); 51 | 52 | DtNode root = makeSubTree( 53 | data, 54 | null, 55 | null, 56 | null, 57 | availAttributes); 58 | 59 | decisionTree = new DecisionTree(root, data.getClassAttribute()); 60 | return decisionTree; 61 | } 62 | 63 | private DtNode makeSubTree( 64 | DataSet data, 65 | Attribute attribute, 66 | Double value, 67 | DtNode.Relation relation, 68 | List availAttrs) 69 | { 70 | DtNode newNode = null; 71 | List candidateSplits = SplitGenerator.generateSplits(data, availAttrs); 72 | 73 | /* 74 | * If the stopping criteria is met, create a leaf node with a decision 75 | * class label that is the majority class of the instances at this node 76 | */ 77 | if (checkStoppingCriteria(data, availAttrs, candidateSplits)) 78 | { 79 | DtLeaf leaf = new DtLeaf(attribute, 80 | value, 81 | relation, 82 | data.getMajorityClass()); 83 | leaf.setClassCounts(data.getClassCounts()); 84 | newNode = leaf; 85 | } 86 | else 87 | { 88 | newNode = new DtNode(attribute, value, relation); 89 | newNode.setClassCounts(data.getClassCounts()); 90 | 91 | /* 92 | * Find the best split among the candidate splits. 
For each branch of the best split, create a 93 | * new node that roots a subtree 94 | */ 95 | Split bestSplit = determineBestSplit(data, candidateSplits); 96 | for (SplitBranch branch : bestSplit.getSplitBranches()) 97 | { 98 | DataSet subsetData = new DataSet(data.getAttributeSet(), branch.getInstanceSet()); 99 | subsetData.setClassAttribute(data.getClassAttribute().getName()); 100 | 101 | /* 102 | * Determine attributes that are still available after the split. 103 | * We only remove the attribute if it is nominal. 104 | */ 105 | List newAvailAttrs = null; 106 | if (branch.getAttribute().getType() == Attribute.Type.NOMINAL) 107 | { 108 | newAvailAttrs = new ArrayList<>(availAttrs); 109 | newAvailAttrs.remove(bestSplit.getAttribute()); 110 | } 111 | else 112 | { 113 | newAvailAttrs = availAttrs; 114 | } 115 | 116 | /* 117 | * Make the recursive call to make a subtree at each child node 118 | */ 119 | Node child = makeSubTree( 120 | subsetData, 121 | bestSplit.getAttribute(), 122 | branch.getValue(), 123 | branch.getRelation(), 124 | newAvailAttrs); 125 | 126 | newNode.addChild(child); 127 | } 128 | } 129 | 130 | return newNode; 131 | } 132 | 133 | } 134 | -------------------------------------------------------------------------------- /machine-learning/src/tree/algorithms/ID3TreeBuilder.java: -------------------------------------------------------------------------------- 1 | package tree.algorithms; 2 | 3 | import java.util.List; 4 | 5 | import tree.train.Split; 6 | import data.Attribute; 7 | import data.DataSet; 8 | 9 | /** 10 | * Builds a decision tree using the ID3 algorithm. Splits are 11 | * determined by finding the split that yields maximal information gain on each 12 | * iteration. The stopping criteria is met when either a minimum number of 13 | * instances are found at the leaf node, all instances at the leaf node are of 14 | * the same class, or there are no more splits to split on. 
15 | * 16 | */ 17 | public class ID3TreeBuilder extends DecisionTreeBuilder 18 | { 19 | /** 20 | * The minimum number of instances at the leaf node for 21 | * the stopping criteria to be met. 22 | */ 23 | private final int minInstances; 24 | 25 | public ID3TreeBuilder(Integer minInstances) 26 | { 27 | this.minInstances = minInstances; 28 | } 29 | 30 | @Override 31 | public boolean checkStoppingCriteria(DataSet data, 32 | List availAttributes, 33 | List candidateSplits) { 34 | 35 | int numInstances = data.getInstanceSet().getInstances().size(); 36 | return isAllCandidateSplitsNegativeInfoGain(candidateSplits) || 37 | availAttributes.isEmpty() || 38 | numInstances < minInstances || 39 | isAllInstancesOfSameClass(data); 40 | } 41 | 42 | /** 43 | * Find the Split with the highest information gain among the candidate 44 | * splits 45 | * 46 | * @param data 47 | * @param candidateSplits 48 | * @return 49 | */ 50 | @Override 51 | public Split determineBestSplit(DataSet data, List candidateSplits) 52 | { 53 | Split bestSplit = null; 54 | double maxInfoGain = -Double.MAX_VALUE; 55 | for (Split split : candidateSplits) 56 | { 57 | if (split.getInfoGain().doubleValue() > maxInfoGain) 58 | { 59 | maxInfoGain = split.getInfoGain().doubleValue(); 60 | bestSplit = split; 61 | } 62 | } 63 | return bestSplit; 64 | } 65 | 66 | /** 67 | * @param candidateSplits all candidate splits 68 | * @return true if all candidate splits have negative information gain. 69 | * False otherwise. 70 | */ 71 | private boolean isAllCandidateSplitsNegativeInfoGain(List candidateSplits) 72 | { 73 | for (Split split : candidateSplits) 74 | { 75 | if (split.getInfoGain() > 0) 76 | { 77 | return false; 78 | } 79 | } 80 | return true; 81 | } 82 | 83 | /** 84 | * @param data the dataset 85 | * @return true if if all instances are of the same class. False otherwise. 
86 | */ 87 | private boolean isAllInstancesOfSameClass(DataSet data) 88 | { 89 | for (Integer count : data.getClassCounts().values()) 90 | { 91 | int numInstances = data.getInstanceSet().getInstances().size(); 92 | if (count == numInstances) 93 | { 94 | return true; 95 | } 96 | } 97 | return false; 98 | } 99 | 100 | } 101 | 102 | -------------------------------------------------------------------------------- /machine-learning/src/tree/algorithms/RandomForestBuilder.java: -------------------------------------------------------------------------------- 1 | package tree.algorithms; 2 | 3 | import java.util.Set; 4 | 5 | import data.AttributeSet; 6 | import data.DataSet; 7 | import data.InstanceSet; 8 | import tree.DecisionTree; 9 | 10 | public class RandomForestBuilder 11 | { 12 | 13 | private DataSet sampleDataSet(DataSet fullData) 14 | { 15 | AttributeSet attributes = sampleAttributes(fullData.getAttributeSet()); 16 | InstanceSet instances = sampleInstances(fullData.getInstanceSet()); 17 | return new DataSet(attributes, instances); 18 | } 19 | 20 | private AttributeSet sampleAttributes(AttributeSet fullAttributeSet) 21 | { 22 | return null; 23 | } 24 | 25 | private InstanceSet sampleInstances(InstanceSet fullInstanceSet) 26 | { 27 | return null; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /machine-learning/src/tree/classifiers/ID3TreeClassifier.java: -------------------------------------------------------------------------------- 1 | package tree.classifiers; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import classify.ClassificationResult; 7 | import classify.Classifier; 8 | import pair.Pair; 9 | import tree.DecisionTree; 10 | import tree.algorithms.ID3TreeBuilder; 11 | import data.DataSet; 12 | import data.Instance; 13 | 14 | public class ID3TreeClassifier implements Classifier 15 | { 16 | private DecisionTree dtTree; 17 | 18 | public ID3TreeClassifier(int minInstances, DataSet 
trainData) 19 | { 20 | ID3TreeBuilder id3Builder = new ID3TreeBuilder(minInstances); 21 | dtTree = id3Builder.buildDecisionTree(trainData); 22 | } 23 | 24 | @Override 25 | public ClassificationResult classifyData(DataSet testData) 26 | { 27 | /* 28 | * Classify each instance in the test dataset 29 | */ 30 | List> resultList = new ArrayList<>(); 31 | for (Instance instance : testData.getInstanceSet().getInstances()) 32 | { 33 | resultList.add( dtTree.classifyInstance(instance) ); 34 | } 35 | 36 | /* 37 | * Process the results 38 | */ 39 | ClassificationResult result = new ClassificationResult(resultList, testData); 40 | return result; 41 | } 42 | 43 | @Override 44 | public Object getModel() 45 | { 46 | return this.dtTree; 47 | } 48 | 49 | @Override 50 | public String toString() 51 | { 52 | return "ID3\n\n" + dtTree; 53 | } 54 | 55 | } 56 | -------------------------------------------------------------------------------- /machine-learning/src/tree/classifiers/RandomForestClassifier.java: -------------------------------------------------------------------------------- 1 | package tree.classifiers; 2 | 3 | import java.util.Set; 4 | 5 | import tree.DecisionTree; 6 | 7 | public class RandomForestClassifier 8 | { 9 | private Set trees; 10 | 11 | } 12 | -------------------------------------------------------------------------------- /machine-learning/src/tree/evaluate/BiClassTest.java: -------------------------------------------------------------------------------- 1 | package tree.evaluate; 2 | 3 | import java.util.Set; 4 | 5 | import tree.DecisionTree; 6 | import tree.DtLeaf; 7 | import tree.DtNode; 8 | import data.Attribute; 9 | import data.DataSet; 10 | import data.Instance; 11 | 12 | public class BiClassTest 13 | { 14 | public static BiClassTestResults runTest(DataSet data, DecisionTree dt) 15 | { 16 | BiClassTestResults results = new BiClassTestResults(); 17 | 18 | for (Instance instance : data.getInstanceSet().getInstances()) 19 | { 20 | DtNode currNode = (DtNode) 
dt.getRoot(); 21 | 22 | while (!(currNode instanceof DtLeaf)) 23 | { 24 | @SuppressWarnings("unchecked") 25 | Set children = ((Set) ((Set) currNode.getChildren())); 26 | 27 | for (DtNode node : children) 28 | { 29 | if (node.doesInstanceSatisfyNode(instance)) 30 | { 31 | currNode = node; 32 | break; 33 | } 34 | } 35 | } 36 | 37 | Attribute classAttr = data.getClassAttribute(); 38 | 39 | DtLeaf leaf = (DtLeaf) currNode; 40 | 41 | String prediction = data.getClassAttribute().getNominalValueName( 42 | leaf.getClassLabel().intValue() ); 43 | String truth = data.getClassAttribute().getNominalValueName( 44 | instance.getAttributeValue(classAttr).intValue() ); 45 | 46 | // Print result of classification 47 | System.out.print(prediction); 48 | System.out.print(" "); 49 | System.out.print(truth); 50 | System.out.print("\n"); 51 | 52 | // Add the prediction to the test results 53 | results.addClassification(leaf.getClassLabel().intValue(), 54 | instance.getAttributeValue(classAttr).intValue()); 55 | } 56 | 57 | System.out.println("\n"); 58 | 59 | return results; 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /machine-learning/src/tree/evaluate/BiClassTestResults.java: -------------------------------------------------------------------------------- 1 | package tree.evaluate; 2 | 3 | public class BiClassTestResults 4 | { 5 | private static final int NEGATIVE = 0; 6 | private static final int POSITIVE = 1; 7 | 8 | private Integer posPredictions = 0; 9 | private Integer negPredictions = 0; 10 | 11 | private Integer falsePos = 0; 12 | private Integer truePos = 0; 13 | private Integer falseNeg = 0; 14 | private Integer trueNeg = 0; 15 | 16 | 17 | public void addClassification(Integer predictedLabel, Integer trueLabel) 18 | { 19 | 20 | if (predictedLabel.equals(POSITIVE)) 21 | { 22 | this.posPredictions++; 23 | } 24 | else if (predictedLabel.equals(NEGATIVE)) 25 | { 26 | this.negPredictions++; 27 | } 28 | 29 | // Compute whether 
this classification was TP, FP, TN, or FN 30 | if (predictedLabel.equals(trueLabel)) 31 | { 32 | if (predictedLabel.equals(POSITIVE)) 33 | { 34 | truePos++; 35 | } 36 | else if (predictedLabel.equals(NEGATIVE)) 37 | { 38 | trueNeg++; 39 | } 40 | } 41 | else 42 | { 43 | if (predictedLabel.equals(POSITIVE)) 44 | { 45 | falsePos++; 46 | } 47 | else if (predictedLabel.equals(NEGATIVE)) 48 | { 49 | falseNeg++; 50 | } 51 | } 52 | } 53 | 54 | public void printResults() 55 | { 56 | System.out.println((truePos + trueNeg) + " " + (posPredictions + negPredictions)); 57 | /* 58 | System.out.println("Correctly classified: " + (truePos + trueNeg)); 59 | System.out.println("Total instances: " + (posPredictions + negPredictions)); 60 | System.out.print("\n");*/ 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /machine-learning/src/tree/train/Bin.java: -------------------------------------------------------------------------------- 1 | package tree.train; 2 | 3 | import java.util.Comparator; 4 | import java.util.HashMap; 5 | import java.util.Map; 6 | 7 | /** 8 | * This helper class keeps track of which class labels are represented at each 9 | * value of a specific continuous attribute. This class is specifically used 10 | * for generating all possible splits along a continuous attribute. 11 | *
12 | *
13 | * For example, say we have two instances with attribute A = a. One of these 14 | * instances has Class = positive, the other instance has Class = negative. 15 | * This class stores the knowledge that both class values are represented in 16 | * the instances whose attribute A = a. 17 | * 18 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 19 | * 20 | */ 21 | public class Bin 22 | { 23 | /** 24 | * Ordering of bins based on value they represent 25 | */ 26 | public static final Comparator BIN_ORDER = 27 | new Comparator() 28 | { 29 | public int compare(Bin b1, Bin b2) 30 | { 31 | return b2.getValue().compareTo(b1.getValue()); 32 | } 33 | }; 34 | 35 | /** 36 | * The nominal value of some target attribute represented by the bin 37 | */ 38 | private Double value; 39 | 40 | /** 41 | * A mapping of a nominal value ID of the class attribute to 42 | * a Boolean variable. This Boolean is true if an instance with the given 43 | * class label is in this bin. 44 | */ 45 | private Map classExistenceMap; 46 | 47 | public Bin(Double value) 48 | { 49 | classExistenceMap = new HashMap(); 50 | this.value = value; 51 | } 52 | 53 | /** 54 | * @return the bin's value of the target attribute. 55 | */ 56 | public Double getValue() 57 | { 58 | return this.value; 59 | } 60 | 61 | /** 62 | * @return a mapping of a nominal value ID of the class attribute to 63 | * a Boolean variable. This Boolean is true if an instance with the given 64 | * class label is in this bin. 
65 | */ 66 | public Map getExistenceMap() 67 | { 68 | return classExistenceMap; 69 | } 70 | 71 | /** 72 | * @param classLabel a nominal value ID of the class attribute that is 73 | * included in this bin 74 | */ 75 | public void includeInstance(Integer classLabel) 76 | { 77 | classExistenceMap.put(classLabel, true); 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /machine-learning/src/tree/train/Entropy.java: -------------------------------------------------------------------------------- 1 | package tree.train; 2 | 3 | import java.util.Map; 4 | 5 | 6 | import data.DataSet; 7 | 8 | /** 9 | * Used for calculating information theory metrics used for determining the 10 | * best splits in a decision tree. 11 | * 12 | * @author Matthew Bernstein - matthewb@cs.wisc.edu 13 | * 14 | */ 15 | public class Entropy 16 | { 17 | /** 18 | * Calculate the information gain of the class attribute on a given split. 19 | * A split consists of an attribute and a set of instances. This method 20 | * calculates 21 | *
22 | *
23 | * H(C) - H(C | X) 24 | *
25 | *
26 | * where H(C) is the entropy of the class attribute and H(C | X) is the 27 | * conditional entropy of the class attribute given some other attribute, X 28 | * 29 | * @param data the data set 30 | * @param split a split on the data 31 | * @return the information gained on the dataset's class attribute by 32 | * knowing the split 33 | */ 34 | public static Double informationGain(DataSet data, Split split) 35 | { 36 | Double entropy = entropy(data); 37 | Double conditionalEntropy = conditionalEntropy(data, split); 38 | Double infoGain = entropy - conditionalEntropy; 39 | 40 | return infoGain; 41 | } 42 | 43 | /** 44 | * Calculate the entropy of the class attribute in a data set. This method 45 | * calculates 46 | *
47 | *
48 | * H(C) 49 | *
50 | *
51 | * where C is the class attribute 52 | * 53 | * @param data the data set for which to calculate entropy 54 | * @return the entropy of the class attribute 55 | */ 56 | public static Double entropy(DataSet data) 57 | { 58 | double entropy = 0; 59 | double totalInstances = data.getInstanceSet().getInstances().size(); 60 | Map classCounts = data.getClassCounts(); 61 | 62 | /* 63 | * Calculate the entropy by summer P*log(P) for each P, 64 | * where P is the probability of seeing a class label 65 | */ 66 | for (Integer count : classCounts.values()) 67 | { 68 | if (count > 0) 69 | { 70 | double P = count / totalInstances; 71 | entropy += ( P * Math.log(P) / Math.log(2d) ); 72 | } 73 | } 74 | 75 | entropy *= -1; 76 | 77 | return entropy; 78 | } 79 | 80 | /** 81 | * Calculate the conditional entropy of the class attribute given another 82 | * attribute. That is, this method calculates 83 | *
84 | *
85 | * H(C | X) 86 | *
87 | *
88 | * where C is the class attribute and X is some other attribute 89 | * in the data set. 90 | * 91 | * @param data the data set 92 | * @param split the split of the attribute for which we condition the class 93 | * attribute 94 | * @return 95 | */ 96 | public static Double conditionalEntropy(DataSet data, Split split) 97 | { 98 | double conditionalEntropy = 0; 99 | double totalInstances = data.getInstanceSet().getInstances().size(); 100 | 101 | if (totalInstances > 0) 102 | { 103 | for (SplitBranch branch : split.getSplitBranches()) 104 | { 105 | DataSet branchData = new DataSet(data.getAttributeSet(), branch.getInstanceSet()); 106 | branchData.setClassAttribute(data.getClassAttribute().getName()); 107 | 108 | int branchNumInstances = branchData.getInstanceSet().getInstances().size(); 109 | conditionalEntropy += 110 | ((branchNumInstances / totalInstances) * entropy(branchData)); 111 | } 112 | } 113 | 114 | return conditionalEntropy; 115 | } 116 | 117 | } 118 | -------------------------------------------------------------------------------- /machine-learning/src/tree/train/Split.java: -------------------------------------------------------------------------------- 1 | package tree.train; 2 | 3 | import java.util.ArrayList; 4 | import java.util.List; 5 | 6 | import data.Attribute; 7 | import data.DataSet; 8 | import data.Instance; 9 | import data.InstanceSet; 10 | 11 | /** 12 | * This class splits a set of instances along an attribute. It stores the 13 | * separated instances sorted by the value of this split's attribute. 14 | * 15 | */ 16 | public class Split 17 | { 18 | /** 19 | * All of the branches for this split. For splits on nominal attributes, 20 | * there will be one branch per nominal value. For continuous attributes, 21 | * there will be two branches. One branch for the all instances with a 22 | * value greater than the threshold and one branch for all instances less 23 | * than or equal to the threshold value. 
24 | */ 25 | private List branches; 26 | 27 | /** 28 | * The attribute this Split splits instances on 29 | */ 30 | private Attribute attribute; 31 | 32 | /** 33 | * The information gain on this split 34 | */ 35 | private Double infoGain; 36 | 37 | /** 38 | * Constructor 39 | * 40 | * @param attribute the attribute along which this split splits 41 | */ 42 | public Split(Attribute attribute) 43 | { 44 | branches = new ArrayList(); 45 | this.attribute = attribute; 46 | } 47 | 48 | /** 49 | * @return this split's attribute 50 | */ 51 | public Attribute getAttribute() 52 | { 53 | return attribute; 54 | } 55 | 56 | /** 57 | * @return the information gain along this split for the dataset's 58 | * class attribute. 59 | */ 60 | public Double getInfoGain() 61 | { 62 | return infoGain; 63 | } 64 | 65 | // TODO: REFACTOR THIS! 66 | @Deprecated 67 | public void setInfoGain(Double infoGain) 68 | { 69 | this.infoGain = infoGain; 70 | } 71 | 72 | /** 73 | * Split a set of instances along this split. 74 | * 75 | * @param data the dataset containing the instances to be split 76 | */ 77 | public void splitInstances(DataSet data) 78 | { 79 | InstanceSet instances = data.getInstanceSet(); 80 | 81 | for (Instance instance : instances.getInstances()) 82 | { 83 | for (SplitBranch branch : this.branches) 84 | { 85 | branch.tryAddInstance(instance); 86 | } 87 | } 88 | 89 | this.infoGain = Entropy.informationGain(data, this); 90 | } 91 | 92 | /** 93 | * @return each branch along this split 94 | */ 95 | public List getSplitBranches() 96 | { 97 | return branches; 98 | } 99 | 100 | /** 101 | * Add a branch to the split 102 | * 103 | * @param newBranch the new branch 104 | */ 105 | protected void addBranch(SplitBranch newBranch) 106 | { 107 | branches.add(newBranch); 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /machine-learning/src/tree/train/SplitBranch.java: 
-------------------------------------------------------------------------------- 1 | package tree.train; 2 | 3 | import tree.DtNode; 4 | import data.Attribute; 5 | import data.Instance; 6 | import data.InstanceSet; 7 | 8 | /** 9 | * This class represents a single branch along a split. If a split splits 10 | * instances along attribute A where A can take values {a1, a2, a3}, this split 11 | * will store 3 split branches for storing instances whose value of A is a1, a2, 12 | * and a3. 13 | */ 14 | public class SplitBranch 15 | { 16 | /** 17 | * The relation between an instance's value of {@code attribute} and {@code branchValue} that an 18 | * instance is tested against. The relations handled by {@link #doesInstanceMakeSplit(Instance)} are: 19 | * 20 | * EQUALS : {@code instance value == branch value} 21 | * GREATER_THAN : {@code instance value > branch value} 22 | * LESS_THAN_EQUAL_TO : {@code instance value <= branch value} (any other relation makes that test throw) 23 | */ 24 | private DtNode.Relation relation; 25 | 26 | /** 27 | * The value that an instance is tested against to make this branch 28 | */ 29 | private Double branchValue; 30 | 31 | /** 32 | * The attribute this branch tests 33 | */ 34 | private Attribute attribute; 35 | 36 | /** 37 | * All instances that fall to this branch 38 | */ 39 | private InstanceSet instanceSet; 40 | 41 | /** 42 | * Constructor 43 | * 44 | * @param attribute the attribute along which this branch's split splits 45 | * instances 46 | * @param branchValue the value of the attribute that this branch tests 47 | * @param relation the relation an instance's attribute value must have to 48 | * {@code branchValue} to make this branch 49 | */ 50 | protected SplitBranch(Attribute attribute, 51 | Double branchValue, 52 | DtNode.Relation relation) 53 | { 54 | this.instanceSet = new InstanceSet(); 55 | this.attribute = attribute; 56 | this.branchValue = branchValue; 57 | this.relation = relation; 58 | } 59 | 60 | /** 61 | * @return the set of instances that have made this branch 62 | */ 63 | public InstanceSet getInstanceSet() 64 | { 65 | return instanceSet; 66 | } 67 | 68 | /**
69 | * Attempt to add an instance to this split branch. The instance is only 70 | * added if it passes this branch's test (see {@link #doesInstanceMakeSplit(Instance)}). 71 | * 72 | * @param instance the instance to test and, if it makes this branch, add 73 | */ 74 | public void tryAddInstance(Instance instance) 75 | { 76 | if (this.doesInstanceMakeSplit(instance)) 77 | { 78 | instanceSet.addInstance(instance); 79 | } 80 | } 81 | 82 | /** 83 | * @return the split's attribute 84 | */ 85 | public Attribute getAttribute() 86 | { 87 | return attribute; 88 | } 89 | 90 | /** 91 | * @return the relation an instance's attribute value must have to the branch value 92 | * for the instance to fall in this split branch 93 | */ 94 | public DtNode.Relation getRelation() 95 | { 96 | return relation; 97 | } 98 | 99 | /** 100 | * @return the value of the split branch 101 | */ 102 | public Double getValue() 103 | { 104 | return branchValue; 105 | } 106 | 107 | /** 108 | * Tests whether an instance makes this split branch. Throws a RuntimeException if this branch's relation is not EQUALS, GREATER_THAN or LESS_THAN_EQUAL_TO. 109 | * 110 | * @param instance the instance we are testing 111 | * @return true if the instance makes the split branch. False otherwise. 112 | */ 113 | public Boolean doesInstanceMakeSplit(Instance instance) 114 | { 115 | Double instanceAttrValue = instance.getAttributeValue(this.attribute); 116 | 117 | switch(this.relation) 118 | { 119 | case EQUALS: 120 | return (instanceAttrValue.doubleValue() == branchValue.doubleValue()); 121 | case GREATER_THAN: 122 | return (instanceAttrValue.doubleValue() > branchValue.doubleValue()); 123 | case LESS_THAN_EQUAL_TO: 124 | return (instanceAttrValue.doubleValue() <= branchValue.doubleValue()); 125 | default: 126 | throw new RuntimeException("Error testing instance in branch. " +
127 | "This branch's relation is not set to a valid relation."); 128 | } 129 | } 130 | } 131 | -------------------------------------------------------------------------------- /machine-learning/tst/data/AttributeTest.java: -------------------------------------------------------------------------------- 1 | package data; 2 | 3 | import static org.junit.Assert.*; 4 | 5 | import java.util.HashMap; 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | import org.junit.Before; 10 | import org.junit.Test; 11 | 12 | import com.google.common.collect.ImmutableList; 13 | import com.google.common.collect.ImmutableMap; 14 | 15 | /** Tests for {@link Attribute}. */ 16 | public class AttributeTest 17 | { 18 | private static final String ATTR_NAME = "Color"; 19 | private static final List NOMINAL_VALUES = ImmutableList.of("Red", "Yellow", "Blue"); 20 | private static final List NOMINAL_IDS = ImmutableList.of(0, 1, 2); 21 | 22 | private static final Map NOMINAL_VALUE_IDS = generateNominalValueIds(); 23 | 24 | @Before 25 | public void before() 26 | { 27 | 28 | } 29 | /** Checks the attribute's name and that the values of its nominal value map are the ids 0, 1, 2. */ 30 | @Test 31 | public void test_Constructor() 32 | { 33 | Attribute nominalAttr = new Attribute(ATTR_NAME, Attribute.Type.NOMINAL, (String[]) NOMINAL_VALUES.toArray(new String[0])); /* the no-arg toArray() returns an Object[], which cannot be cast to String[] */ 34 | 35 | assertEquals(ATTR_NAME, nominalAttr.getName()); 36 | 37 | //assertEquals(nominalAttr.getNominalValueId(attrValueName), NOMINAL_VALUE_IDS.keySet()); 38 | assertEquals(new java.util.HashSet<>(NOMINAL_VALUE_IDS.values()), new java.util.HashSet<>(nominalAttr.getNominalValueMap().values())); /* Map.values() views generally lack content-based equals, so compare their contents as sets */ 39 | } 40 | /** @return the expected nominal value name to id mapping: Red=0, Yellow=1, Blue=2 */ 41 | private static Map generateNominalValueIds() 42 | { 43 | return new ImmutableMap.Builder() 44 | .put(NOMINAL_VALUES.get(0), NOMINAL_IDS.get(0)) 45 | .put(NOMINAL_VALUES.get(1), NOMINAL_IDS.get(1)) 46 | .put(NOMINAL_VALUES.get(2), NOMINAL_IDS.get(2)) 47 | .build(); 48 | } 49 | 50 | } 51 | --------------------------------------------------------------------------------