├── .gitignore
├── LICENSE
├── README.md
└── machine-learning
├── doc
├── allclasses-frame.html
├── allclasses-noframe.html
├── bayes
│ ├── BNConditionalQuery.html
│ ├── BNDataGenerator.html
│ ├── BNEvaluator.html
│ ├── BNJointQuery.html
│ ├── BNNode.html
│ ├── BNNodeManager.html
│ ├── BNResultWriter.html
│ ├── BNUtility.html
│ ├── BayesianNetwork.Type.html
│ ├── BayesianNetwork.html
│ ├── VariableSet.html
│ ├── builders
│ │ ├── HillClimbingBuilder.StoppingCriteria.html
│ │ ├── HillClimbingBuilder.html
│ │ ├── NaiveBayesBuilder.html
│ │ ├── NetworkBuilder.html
│ │ ├── Operation.Type.html
│ │ ├── Operation.html
│ │ ├── SparseCandidateBuilder.KlEdgeScorePair.html
│ │ ├── SparseCandidateBuilder.html
│ │ ├── TANBuilder.html
│ │ ├── class-use
│ │ │ ├── HillClimbingBuilder.StoppingCriteria.html
│ │ │ ├── HillClimbingBuilder.html
│ │ │ ├── NaiveBayesBuilder.html
│ │ │ ├── NetworkBuilder.html
│ │ │ ├── Operation.Type.html
│ │ │ ├── Operation.html
│ │ │ ├── SparseCandidateBuilder.KlEdgeScorePair.html
│ │ │ ├── SparseCandidateBuilder.html
│ │ │ └── TANBuilder.html
│ │ ├── package-frame.html
│ │ ├── package-summary.html
│ │ ├── package-tree.html
│ │ ├── package-use.html
│ │ └── scoring
│ │ │ ├── BIC.html
│ │ │ ├── ScoringFunction.html
│ │ │ ├── class-use
│ │ │ ├── BIC.html
│ │ │ └── ScoringFunction.html
│ │ │ ├── package-frame.html
│ │ │ ├── package-summary.html
│ │ │ ├── package-tree.html
│ │ │ └── package-use.html
│ ├── class-use
│ │ ├── BNConditionalQuery.html
│ │ ├── BNDataGenerator.html
│ │ ├── BNEvaluator.html
│ │ ├── BNJointQuery.html
│ │ ├── BNNode.html
│ │ ├── BNNodeManager.html
│ │ ├── BNResultWriter.html
│ │ ├── BNUtility.html
│ │ ├── BayesianNetwork.Type.html
│ │ ├── BayesianNetwork.html
│ │ └── VariableSet.html
│ ├── classifiers
│ │ ├── NaiveBayesClassifier.html
│ │ ├── class-use
│ │ │ └── NaiveBayesClassifier.html
│ │ ├── package-frame.html
│ │ ├── package-summary.html
│ │ ├── package-tree.html
│ │ └── package-use.html
│ ├── cpd
│ │ ├── CPDLeaf.html
│ │ ├── CPDNode.html
│ │ ├── CPDQuery.html
│ │ ├── CPDTree.ToStringHelper.html
│ │ ├── CPDTree.html
│ │ ├── CPDTreeBuilder.html
│ │ ├── Split.html
│ │ ├── SplitBranch.html
│ │ ├── class-use
│ │ │ ├── CPDLeaf.html
│ │ │ ├── CPDNode.html
│ │ │ ├── CPDQuery.html
│ │ │ ├── CPDTree.ToStringHelper.html
│ │ │ ├── CPDTree.html
│ │ │ ├── CPDTreeBuilder.html
│ │ │ ├── Split.html
│ │ │ └── SplitBranch.html
│ │ ├── package-frame.html
│ │ ├── package-summary.html
│ │ ├── package-tree.html
│ │ └── package-use.html
│ ├── information
│ │ ├── KLDivergence.html
│ │ ├── class-use
│ │ │ └── KLDivergence.html
│ │ ├── package-frame.html
│ │ ├── package-summary.html
│ │ ├── package-tree.html
│ │ └── package-use.html
│ ├── package-frame.html
│ ├── package-summary.html
│ ├── package-tree.html
│ └── package-use.html
├── common
│ ├── classification
│ │ ├── ClassificationResult.html
│ │ ├── Classifier.html
│ │ ├── class-use
│ │ │ ├── ClassificationResult.html
│ │ │ └── Classifier.html
│ │ ├── package-frame.html
│ │ ├── package-summary.html
│ │ ├── package-tree.html
│ │ └── package-use.html
│ └── kfold
│ │ ├── KFoldCreator.html
│ │ ├── class-use
│ │ └── KFoldCreator.html
│ │ ├── package-frame.html
│ │ ├── package-summary.html
│ │ ├── package-tree.html
│ │ └── package-use.html
├── constant-values.html
├── deprecated-list.html
├── graph
│ ├── dag
│ │ ├── DetectCycles.html
│ │ ├── TopologicalSort.html
│ │ ├── class-use
│ │ │ ├── DetectCycles.html
│ │ │ └── TopologicalSort.html
│ │ ├── package-frame.html
│ │ ├── package-summary.html
│ │ ├── package-tree.html
│ │ └── package-use.html
│ └── prim
│ │ ├── Edge.html
│ │ ├── Prim.html
│ │ ├── class-use
│ │ ├── Edge.html
│ │ └── Prim.html
│ │ ├── package-frame.html
│ │ ├── package-summary.html
│ │ ├── package-tree.html
│ │ └── package-use.html
├── help-doc.html
├── index-files
│ ├── index-1.html
│ ├── index-10.html
│ ├── index-11.html
│ ├── index-12.html
│ ├── index-13.html
│ ├── index-14.html
│ ├── index-15.html
│ ├── index-16.html
│ ├── index-17.html
│ ├── index-18.html
│ ├── index-19.html
│ ├── index-2.html
│ ├── index-20.html
│ ├── index-21.html
│ ├── index-22.html
│ ├── index-23.html
│ ├── index-3.html
│ ├── index-4.html
│ ├── index-5.html
│ ├── index-6.html
│ ├── index-7.html
│ ├── index-8.html
│ └── index-9.html
├── index.html
├── main
│ ├── DTMain.html
│ ├── MainHillClimbing.html
│ ├── MainNaiveBayes.html
│ ├── MainSparseCandidate.html
│ ├── class-use
│ │ ├── DTMain.html
│ │ ├── MainHillClimbing.html
│ │ ├── MainNaiveBayes.html
│ │ └── MainSparseCandidate.html
│ ├── package-frame.html
│ ├── package-summary.html
│ ├── package-tree.html
│ └── package-use.html
├── overview-frame.html
├── overview-summary.html
├── overview-tree.html
├── package-list
├── resources
│ ├── background.gif
│ ├── tab.gif
│ ├── titlebar.gif
│ └── titlebar_end.gif
├── stylesheet.css
└── tree
│ ├── DecisionTree.TreePrinter.html
│ ├── DecisionTree.html
│ ├── DtLeaf.html
│ ├── DtNode.html
│ ├── ID3Builder.html
│ ├── Node.html
│ ├── class-use
│ ├── DecisionTree.TreePrinter.html
│ ├── DecisionTree.html
│ ├── DtLeaf.html
│ ├── DtNode.html
│ ├── ID3Builder.html
│ └── Node.html
│ ├── classifiers
│ ├── ID3TreeClassifier.html
│ ├── class-use
│ │ └── ID3TreeClassifier.html
│ ├── package-frame.html
│ ├── package-summary.html
│ ├── package-tree.html
│ └── package-use.html
│ ├── evaluate
│ ├── BiClassTest.html
│ ├── BiClassTestResults.html
│ ├── class-use
│ │ ├── BiClassTest.html
│ │ └── BiClassTestResults.html
│ ├── package-frame.html
│ ├── package-summary.html
│ ├── package-tree.html
│ └── package-use.html
│ ├── package-frame.html
│ ├── package-summary.html
│ ├── package-tree.html
│ ├── package-use.html
│ └── train
│ ├── Bin.html
│ ├── Entropy.html
│ ├── Split.html
│ ├── SplitBranch.html
│ ├── SplitGenerator.html
│ ├── class-use
│ ├── Bin.html
│ ├── Entropy.html
│ ├── Split.html
│ ├── SplitBranch.html
│ └── SplitGenerator.html
│ ├── package-frame.html
│ ├── package-summary.html
│ ├── package-tree.html
│ └── package-use.html
├── examples
└── examples
│ ├── DecisionTreeExample.java
│ ├── GraphAlgorithmExamples.java
│ └── HMMExamples.java
├── gold_standard
├── bayes
│ └── tan.rtf
└── tree
│ ├── m10.tree
│ ├── m2.tree
│ ├── m20.tree
│ └── m4.tree
├── lib
├── data_structures.jar
├── guava-18.0.jar
├── hamcrest-core-1.3.jar
├── junit-4.12-beta-3.jar
├── lombok.jar
└── mockito-all-1.9.5.jar
├── src
├── applications
│ ├── MainHillClimbing.java
│ ├── MainNaiveBayes.java
│ └── MainSparseCandidate.java
├── bayes
│ ├── BNConditionalQuery.java
│ ├── BNDataGenerator.java
│ ├── BNEvaluator.java
│ ├── BNJointQuery.java
│ ├── BNNode.java
│ ├── BNResultWriter.java
│ ├── BNStructure.java
│ ├── BNUtility.java
│ ├── BayesianNetwork.java
│ ├── VariableSet.java
│ ├── classifiers
│ │ └── NaiveBayesClassifier.java
│ ├── cpd
│ │ ├── CPDLeaf.java
│ │ ├── CPDNode.java
│ │ ├── CPDQuery.java
│ │ ├── CPDTree.java
│ │ ├── CPDTreeBuilder.java
│ │ ├── Split.java
│ │ └── SplitBranch.java
│ ├── information
│ │ └── KLDivergence.java
│ └── structuresearch
│ │ ├── HillClimbingBuilder.java
│ │ ├── NaiveBayesBuilder.java
│ │ ├── NetworkBuilder.java
│ │ ├── Operation.java
│ │ ├── SparseCandidateBuilder.java
│ │ ├── TANBuilder.java
│ │ └── score
│ │ ├── BIC.java
│ │ └── ScoringFunction.java
├── classify
│ ├── ClassificationResult.java
│ ├── Classifier.java
│ └── evaluate
│ │ └── PercentageError.java
├── data
│ ├── Attribute.java
│ ├── AttributeSet.java
│ ├── DataSet.java
│ ├── Instance.java
│ ├── InstanceSet.java
│ ├── fold
│ │ └── KFoldCreator.java
│ └── reader
│ │ └── ArffReader.java
├── distributions
│ ├── Distribution.java
│ └── GeometricDistribution.java
├── graph
│ ├── DirectedGraph.java
│ ├── Path.java
│ ├── bellmanford
│ │ └── BellmanFord.java
│ ├── dag
│ │ ├── DetectCycles.java
│ │ └── TopologicalSort.java
│ ├── floydwarshall
│ │ ├── AllPairsShortestPaths.java
│ │ └── FloydWarshall.java
│ └── prim
│ │ ├── Edge.java
│ │ └── Prim.java
├── hmm
│ ├── HMM.java
│ ├── State.java
│ ├── StateContainer.java
│ ├── StateParamsTied.java
│ ├── StateSilent.java
│ ├── Transition.java
│ └── algorithms
│ │ ├── BackwardAlgorithm.java
│ │ ├── DpMatrix.java
│ │ ├── DpMatrixElement.java
│ │ ├── ForwardAlgorithm.java
│ │ └── SortSilentStates.java
├── math
│ └── LogP.java
└── tree
│ ├── DecisionTree.java
│ ├── DtLeaf.java
│ ├── DtNode.java
│ ├── Forest.java
│ ├── Node.java
│ ├── algorithms
│ ├── DecisionTreeBuilder.java
│ ├── ID3TreeBuilder.java
│ └── RandomForestBuilder.java
│ ├── classifiers
│ ├── ID3TreeClassifier.java
│ └── RandomForestClassifier.java
│ ├── evaluate
│ ├── BiClassTest.java
│ └── BiClassTestResults.java
│ └── train
│ ├── Bin.java
│ ├── Entropy.java
│ ├── Split.java
│ ├── SplitBranch.java
│ └── SplitGenerator.java
└── tst
└── data
└── AttributeTest.java
/.gitignore:
--------------------------------------------------------------------------------
1 | .metadata/
2 | bin
3 | .DS_Store
4 | .settings
5 | .classpath
6 | .project
7 | data_sets
8 | gold_standard
9 | results
10 | .recommenders/
11 | *.class
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/LICENSE
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | machine-learning
2 | ================
3 |
4 | A Java library of machine learning algorithms. These are my own implementations that I am working on for practice.
5 |
6 | Decision Trees:
7 |
8 | ID3
9 |
10 | Bayesian Networks:
11 |
12 | Naïve Bayes classifier
13 |
14 | Tree Augmented Naïve (TAN) Bayes classifier
15 |
16 | Hill Climbing Structure Search
17 |
18 | Sparse Candidate Structure Search
19 |
20 | Bayesian Information Criterion (BIC)
21 |
22 | Artificial Data Generation
23 |
24 | Hidden Markov Models:
25 |
26 | Forward Algorithm
27 |
28 | Backward Algorithm
29 |
30 | Generic Graph Theory Algorithms:
31 |
32 | Bellman-Ford Algorithm
33 |
34 | Floyd-Warshall Algorithm
35 |
36 | Prim's Algorithm
37 |
38 | Cycle Detection
39 |
40 | Topological Sort of Vertices in a DAG
41 |
42 |
43 |
44 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/builders/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | bayes.builders
7 |
8 |
9 |
10 |
11 |
12 |
13 |
Classes
14 |
23 |
Enums
24 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/builders/scoring/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | bayes.builders.scoring
7 |
8 |
9 |
10 |
11 |
12 |
13 |
Interfaces
14 |
17 |
Classes
18 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/class-use/BNEvaluator.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Class bayes.BNEvaluator
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of bayes.BNEvaluator
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/class-use/BNUtility.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Class bayes.BNUtility
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of bayes.BNUtility
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/classifiers/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | bayes.classifiers
7 |
8 |
9 |
10 |
11 |
12 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/classifiers/package-use.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Package bayes.classifiers
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of bayes.classifiers
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/cpd/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | bayes.cpd
7 |
8 |
9 |
10 |
11 |
12 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/information/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | bayes.information
7 |
8 |
9 |
10 |
11 |
12 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/information/package-use.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Package bayes.information
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of bayes.information
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/bayes/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | bayes
7 |
8 |
9 |
10 |
11 |
12 |
13 |
Classes
14 |
26 |
Enums
27 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/machine-learning/doc/common/classification/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | common.classification
7 |
8 |
9 |
10 |
11 |
12 |
13 |
Interfaces
14 |
17 |
Classes
18 |
21 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/machine-learning/doc/common/kfold/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | common.kfold
7 |
8 |
9 |
10 |
11 |
12 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/machine-learning/doc/common/kfold/package-use.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Package common.kfold
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of common.kfold
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/graph/dag/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | graph.dag
7 |
8 |
9 |
10 |
11 |
12 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/machine-learning/doc/graph/dag/package-use.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Package graph.dag
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of graph.dag
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/graph/prim/class-use/Prim.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Class graph.prim.Prim
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of graph.prim.Prim
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/graph/prim/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | graph.prim
7 |
8 |
9 |
10 |
11 |
12 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/machine-learning/doc/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Generated Documentation (Untitled)
7 |
52 |
53 |
67 |
68 |
--------------------------------------------------------------------------------
/machine-learning/doc/main/class-use/DTMain.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Class main.DTMain
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of main.DTMain
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/main/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | main
7 |
8 |
9 |
10 |
11 |
12 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/machine-learning/doc/main/package-use.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Package main
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of main
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/overview-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Overview List
7 |
8 |
9 |
10 |
11 |
12 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/machine-learning/doc/package-list:
--------------------------------------------------------------------------------
1 | bayes
2 | bayes.builders
3 | bayes.builders.scoring
4 | bayes.classifiers
5 | bayes.cpd
6 | bayes.information
7 | common.classification
8 | common.kfold
9 | data
10 | data.arff
11 | data.attribute
12 | data.instance
13 | graph.dag
14 | graph.prim
15 | main
16 | tree
17 | tree.classifiers
18 | tree.evaluate
19 | tree.train
20 |
--------------------------------------------------------------------------------
/machine-learning/doc/resources/background.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/doc/resources/background.gif
--------------------------------------------------------------------------------
/machine-learning/doc/resources/tab.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/doc/resources/tab.gif
--------------------------------------------------------------------------------
/machine-learning/doc/resources/titlebar.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/doc/resources/titlebar.gif
--------------------------------------------------------------------------------
/machine-learning/doc/resources/titlebar_end.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/doc/resources/titlebar_end.gif
--------------------------------------------------------------------------------
/machine-learning/doc/tree/class-use/DtLeaf.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Class tree.DtLeaf
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of tree.DtLeaf
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/tree/class-use/ID3Builder.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Class tree.ID3Builder
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of tree.ID3Builder
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/tree/classifiers/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | tree.classifiers
7 |
8 |
9 |
10 |
11 |
12 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/machine-learning/doc/tree/classifiers/package-use.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Uses of Package tree.classifiers
7 |
8 |
9 |
10 |
11 |
17 |
20 |
21 |
37 |
64 |
65 |
68 | No usage of tree.classifiers
69 |
70 |
86 |
113 |
114 |
115 |
116 |
--------------------------------------------------------------------------------
/machine-learning/doc/tree/evaluate/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | tree.evaluate
7 |
8 |
9 |
10 |
11 |
12 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/machine-learning/doc/tree/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | tree
7 |
8 |
9 |
10 |
11 |
12 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/machine-learning/doc/tree/train/package-frame.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | tree.train
7 |
8 |
9 |
10 |
11 |
12 |
22 |
23 |
24 |
--------------------------------------------------------------------------------
/machine-learning/examples/examples/DecisionTreeExample.java:
--------------------------------------------------------------------------------
1 | package examples;
2 | import tree.classifiers.ID3TreeClassifier;
3 | import classify.ClassificationResult;
4 | import data.DataSet;
5 | import data.reader.ArffReader;
6 |
7 |
8 | /**
9 | * This example shows how to learn an ID3 decision tree classifier on a
10 | * training data set and then how to classify a test data set.
11 | * Arguments: train ARFF path, test ARFF path, class attribute name, minimum instance count.
12 | */
13 | public class DecisionTreeExample
14 | {
15 | public static void main( String[] args )
16 | {
17 | /*
18 | * Paths to the training and test ARFF files (args[0] and args[1])
19 | */
20 | String trainArffPath = args[0];
21 | String testArffPath = args[1];
22 |
23 | /*
24 | * Read both data sets from their ARFF files
25 | */
26 | DataSet trainingData = ArffReader.readFile(trainArffPath);
27 | DataSet testData = ArffReader.readFile(testArffPath);
28 |
29 | /*
30 | * Specify the class attribute name (args[2]) on both data sets
31 | */
32 | String classAttribute = args[2];
33 | trainingData.setClassAttribute(classAttribute);
34 | testData.setClassAttribute(classAttribute);
35 |
36 | /*
37 | * Minimum number of instances at a node for splitting on a new attribute (args[3]).
38 | */
39 | int minInstancesForStopping = Integer.decode(args[3]);
40 |
41 | /*
42 | * Build the classifier from the training data and print it
43 | */
44 | ID3TreeClassifier classifier = new ID3TreeClassifier(minInstancesForStopping, trainingData);
45 | System.out.println(classifier);
46 |
47 | /*
48 | * Classify the test data set and print the result
49 | */
50 | ClassificationResult result = classifier.classifyData(testData);
51 | System.out.println(result);
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/machine-learning/examples/examples/GraphAlgorithmExamples.java:
--------------------------------------------------------------------------------
1 | package examples;
2 | import java.util.Map;
3 |
4 | import graph.DirectedGraph;
5 | import graph.Path;
6 | import graph.bellmanford.BellmanFord;
7 | import graph.floydwarshall.AllPairsShortestPaths;
8 | import graph.floydwarshall.FloydWarshall;
9 |
10 | /** Examples that run Bellman-Ford and Floyd-Warshall on a small hard-coded directed graph. */
11 | public class GraphAlgorithmExamples
12 | {
13 | /** Runs the Bellman-Ford example followed by the Floyd-Warshall example. */
14 | public static void main(String[] args)
15 | {
16 | bellmanFordExample();
17 | floydWarshallExample();
18 | }
19 |
20 |
21 | /**
22 | * Run the Bellman-Ford algorithm on the toy graph with source node "A"
23 | * and print every value of the returned map, one per line.
24 | */
25 | public static void bellmanFordExample()
26 | {
27 | DirectedGraph graph = createToyGraph_NoNegativeCycles();
28 |
29 | // Designate the source node
30 | String sourceNode = "A";
31 |
32 | // Run Bellman-Ford
33 | Map> shortestPaths = BellmanFord.runBellmanFord(graph, sourceNode); // NOTE(review): generic type arguments look lost in this copy ("Map>") - restore from the original file
34 |
35 | // Print each path
36 | for (Path path : shortestPaths.values())
37 | {
38 | System.out.println(path);
39 | }
40 | }
41 | /** Run Floyd-Warshall on the toy graph and print the result of getPath("A", "B"). */
42 | public static void floydWarshallExample()
43 | {
44 | DirectedGraph graph = createToyGraph_NoNegativeCycles();
45 |
46 | AllPairsShortestPaths paths = FloydWarshall.runFloydWarshall(graph);
47 |
48 | System.out.println(paths.getPath("A", "B"));
49 | }
50 |
51 | /**
52 | * @return an example graph on nodes "A".."E"; several edges have negative weights, but there are no negative cycles.
53 | */
54 | public static DirectedGraph createToyGraph_NoNegativeCycles()
55 | {
56 | DirectedGraph graph = new DirectedGraph<>();
57 |
58 | graph.addEdge("A", "B", 3d);
59 | graph.addEdge("A", "E", 7d);
60 | graph.addEdge("B", "C", 5d);
61 | graph.addEdge("B", "E", 8d);
62 | graph.addEdge("B", "D", -4d);
63 | graph.addEdge("C", "B", -2d);
64 | graph.addEdge("D", "C", 7d);
65 | graph.addEdge("D", "A", 2d);
66 | graph.addEdge("E", "C", -3d);
67 | graph.addEdge("E", "D", 9d);
68 |
69 | return graph;
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/machine-learning/examples/examples/HMMExamples.java:
--------------------------------------------------------------------------------
1 | package examples;
2 | import math.LogP;
3 | import pair.Pair;
4 |
5 | import hmm.HMM;
6 | import hmm.State;
7 | import hmm.StateSilent;
8 | import hmm.Transition;
9 | import hmm.algorithms.BackwardAlgorithm;
10 | import hmm.algorithms.DpMatrix;
11 | import hmm.algorithms.ForwardAlgorithm;
12 |
13 | /** Examples of building a small HMM and running the forward and backward algorithms on it. */
14 | public class HMMExamples
15 | {
16 | /**
17 | * Example running the forward and backward algorithms on the toy HMM for the sequence x, y, x, x.
18 | */
19 | public static void main(String[] args)
20 | {
21 |
22 | HMM toyHmm = buildToyHMM();
23 | String[] sequence = {"x", "y", "x", "x"};
24 |
25 | /*
26 | * Forward algorithm
27 | */
28 | Pair resultF = ForwardAlgorithm.run(toyHmm, sequence);
29 |
30 | /*
31 | * Backward algorithm
32 | */
33 | Pair resultB = BackwardAlgorithm.run(toyHmm, sequence); // NOTE(review): resultB is never read after this line
34 |
35 | System.out.println("Probability of sequence: "
36 | + LogP.exp(resultF.getFirst()));
37 | }
38 |
39 | /**
40 | * Example building an HMM object with silent states A-D and emitting states E and G.
41 | *
42 | * @return an example HMM with begin state id "A" and end state id "E"
43 | */
44 | public static HMM buildToyHMM()
45 | {
46 | HMM hmm = new HMM();
47 |
48 | /*
49 | * Create states: A, B, C and D are StateSilent; E and G are plain State
50 | */
51 | State A = new StateSilent("A");
52 | State B = new StateSilent("B");
53 | State C = new StateSilent("C");
54 | State D = new StateSilent("D");
55 | State E = new State("E");
56 | State G = new State("G");
57 |
58 | /*
59 | * Add transitions between states (probabilities are passed through LogP.ln)
60 | */
61 | A.addTransition(new Transition("A", "E", LogP.ln(1.0 / 3.0)));
62 | A.addTransition(new Transition("A", "X", LogP.ln(1.0 / 3.0))); // NOTE(review): no state with id "X" is created or added in this method - confirm the intended target
63 | A.addTransition(new Transition("A", "B", LogP.ln(1.0 / 3.0)));
64 | B.addTransition(new Transition("B", "C", LogP.ln(0.5)));
65 | B.addTransition(new Transition("B", "G", LogP.ln(0.5)));
66 | C.addTransition(new Transition("C", "D", LogP.ln(1.0)));
67 | D.addTransition(new Transition("D", "E", LogP.ln(1.0)));
68 | E.addTransition(new Transition("E", "E", LogP.ln(0.5)));
69 | E.addTransition(new Transition("E", "G", LogP.ln(0.5)));
70 |
71 | /*
72 | * Create emission distribution for state E
73 | */
74 | E.addEmission("x", LogP.ln(0.5));
75 | E.addEmission("y", LogP.ln(0.5));
76 |
77 | /*
78 | * Create emission distribution for state G
79 | */
80 | G.addEmission("x", LogP.ln(0.9));
81 | G.addEmission("y", LogP.ln(0.1));
82 |
83 | /*
84 | * Add states to model
85 | */
86 | hmm.addState(B);
87 | hmm.addState(D);
88 | hmm.addState(C);
89 | hmm.addState(E);
90 | hmm.addState(G);
91 | hmm.addState(A);
92 |
93 | /*
94 | * Set the begin and end states
95 | */
96 | hmm.setBeginStateId("A");
97 | hmm.setEndStateId("E");
98 |
99 | return hmm;
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/machine-learning/gold_standard/tree/m10.tree:
--------------------------------------------------------------------------------
1 | thal = fixed_defect [4 6]
2 | | ca <= 0.500000 [4 0]: negative
3 | | ca > 0.500000 [0 6]: positive
4 | thal = normal [84 19]
5 | | thalach <= 111.500000 [0 4]: positive
6 | | thalach > 111.500000 [84 15]
7 | | | age <= 55.500000 [56 4]
8 | | | | trestbps <= 113.500000 [9 3]
9 | | | | | oldpeak <= 0.300000 [3 3]: negative
10 | | | | | oldpeak > 0.300000 [6 0]: negative
11 | | | | trestbps > 113.500000 [47 1]
12 | | | | | oldpeak <= 3.550000 [47 0]: negative
13 | | | | | oldpeak > 3.550000 [0 1]: positive
14 | | | age > 55.500000 [28 11]
15 | | | | chol <= 248.500000 [14 1]
16 | | | | | oldpeak <= 2.800000 [14 0]: negative
17 | | | | | oldpeak > 2.800000 [0 1]: positive
18 | | | | chol > 248.500000 [14 10]
19 | | | | | sex = female [13 3]
20 | | | | | | cp = typ_angina [1 0]: negative
21 | | | | | | cp = asympt [3 3]: negative
22 | | | | | | cp = non_anginal [7 0]: negative
23 | | | | | | cp = atyp_angina [2 0]: negative
24 | | | | | sex = male [1 7]: positive
25 | thal = reversable_defect [20 67]
26 | | cp = typ_angina [3 1]: negative
27 | | cp = asympt [5 53]
28 | | | oldpeak <= 0.650000 [5 10]
29 | | | | chol <= 240.500000 [5 2]: negative
30 | | | | chol > 240.500000 [0 8]: positive
31 | | | oldpeak > 0.650000 [0 43]: positive
32 | | cp = non_anginal [9 10]
33 | | | oldpeak <= 1.900000 [9 5]
34 | | | | trestbps <= 122.500000 [6 0]: negative
35 | | | | trestbps > 122.500000 [3 5]: positive
36 | | | oldpeak > 1.900000 [0 5]: positive
37 | | cp = atyp_angina [3 3]: negative
--------------------------------------------------------------------------------
/machine-learning/gold_standard/tree/m2.tree:
--------------------------------------------------------------------------------
1 |
2 | thal = fixed_defect [4 6]
3 | | ca <= 0.500000 [4 0]: negative
4 | | ca > 0.500000 [0 6]: positive
5 | thal = normal [84 19]
6 | | thalach <= 111.500000 [0 4]: positive
7 | | thalach > 111.500000 [84 15]
8 | | | age <= 55.500000 [56 4]
9 | | | | trestbps <= 113.500000 [9 3]
10 | | | | | oldpeak <= 0.300000 [3 3]
11 | | | | | | cp = typ_angina [0 0]: negative
12 | | | | | | cp = asympt [0 2]: positive
13 | | | | | | cp = non_anginal [1 1]
14 | | | | | | | age <= 44.000000 [1 0]: negative
15 | | | | | | | age > 44.000000 [0 1]: positive
16 | | | | | | cp = atyp_angina [2 0]: negative
17 | | | | | oldpeak > 0.300000 [6 0]: negative
18 | | | | trestbps > 113.500000 [47 1]
19 | | | | | oldpeak <= 3.550000 [47 0]: negative
20 | | | | | oldpeak > 3.550000 [0 1]: positive
21 | | | age > 55.500000 [28 11]
22 | | | | chol <= 248.500000 [14 1]
23 | | | | | oldpeak <= 2.800000 [14 0]: negative
24 | | | | | oldpeak > 2.800000 [0 1]: positive
25 | | | | chol > 248.500000 [14 10]
26 | | | | | sex = female [13 3]
27 | | | | | | cp = typ_angina [1 0]: negative
28 | | | | | | cp = asympt [3 3]
29 | | | | | | | age <= 58.000000 [2 0]: negative
30 | | | | | | | age > 58.000000 [1 3]
31 | | | | | | | | chol <= 362.000000 [0 3]: positive
32 | | | | | | | | chol > 362.000000 [1 0]: negative
33 | | | | | | cp = non_anginal [7 0]: negative
34 | | | | | | cp = atyp_angina [2 0]: negative
35 | | | | | sex = male [1 7]
36 | | | | | | age <= 65.500000 [0 5]: positive
37 | | | | | | age > 65.500000 [1 2]
38 | | | | | | | age <= 66.500000 [1 0]: negative
39 | | | | | | | age > 66.500000 [0 2]: positive
40 | thal = reversable_defect [20 67]
41 | | cp = typ_angina [3 1]
42 | | | oldpeak <= 0.700000 [0 1]: positive
43 | | | oldpeak > 0.700000 [3 0]: negative
44 | | cp = asympt [5 53]
45 | | | oldpeak <= 0.650000 [5 10]
46 | | | | chol <= 240.500000 [5 2]
47 | | | | | chol <= 192.000000 [1 2]
48 | | | | | | age <= 62.000000 [0 2]: positive
49 | | | | | | age > 62.000000 [1 0]: negative
50 | | | | | chol > 192.000000 [4 0]: negative
51 | | | | chol > 240.500000 [0 8]: positive
52 | | | oldpeak > 0.650000 [0 43]: positive
53 | | cp = non_anginal [9 10]
54 | | | oldpeak <= 1.900000 [9 5]
55 | | | | trestbps <= 122.500000 [6 0]: negative
56 | | | | trestbps > 122.500000 [3 5]
57 | | | | | chol <= 232.500000 [3 1]
58 | | | | | | trestbps <= 129.000000 [0 1]: positive
59 | | | | | | trestbps > 129.000000 [3 0]: negative
60 | | | | | chol > 232.500000 [0 4]: positive
61 | | | oldpeak > 1.900000 [0 5]: positive
62 | | cp = atyp_angina [3 3]
63 | | | age <= 47.000000 [2 0]: negative
64 | | | age > 47.000000 [1 3]
65 | | | | trestbps <= 109.000000 [1 0]: negative
66 | | | | trestbps > 109.000000 [0 3]: positive
--------------------------------------------------------------------------------
/machine-learning/gold_standard/tree/m20.tree:
--------------------------------------------------------------------------------
1 |
2 | thal = fixed_defect [4 6]: positive
3 | thal = normal [84 19]
4 | | thalach <= 111.500000 [0 4]: positive
5 | | thalach > 111.500000 [84 15]
6 | | | age <= 55.500000 [56 4]
7 | | | | trestbps <= 113.500000 [9 3]: negative
8 | | | | trestbps > 113.500000 [47 1]
9 | | | | | oldpeak <= 3.550000 [47 0]: negative
10 | | | | | oldpeak > 3.550000 [0 1]: positive
11 | | | age > 55.500000 [28 11]
12 | | | | chol <= 248.500000 [14 1]: negative
13 | | | | chol > 248.500000 [14 10]
14 | | | | | sex = female [13 3]: negative
15 | | | | | sex = male [1 7]: positive
16 | thal = reversable_defect [20 67]
17 | | cp = typ_angina [3 1]: negative
18 | | cp = asympt [5 53]
19 | | | oldpeak <= 0.650000 [5 10]: positive
20 | | | oldpeak > 0.650000 [0 43]: positive
21 | | cp = non_anginal [9 10]: positive
22 | | cp = atyp_angina [3 3]: negative
--------------------------------------------------------------------------------
/machine-learning/gold_standard/tree/m4.tree:
--------------------------------------------------------------------------------
1 |
2 | thal = fixed_defect [4 6]
3 | | ca <= 0.500000 [4 0]: negative
4 | | ca > 0.500000 [0 6]: positive
5 | thal = normal [84 19]
6 | | thalach <= 111.500000 [0 4]: positive
7 | | thalach > 111.500000 [84 15]
8 | | | age <= 55.500000 [56 4]
9 | | | | trestbps <= 113.500000 [9 3]
10 | | | | | oldpeak <= 0.300000 [3 3]
11 | | | | | | cp = typ_angina [0 0]: negative
12 | | | | | | cp = asympt [0 2]: positive
13 | | | | | | cp = non_anginal [1 1]: negative
14 | | | | | | cp = atyp_angina [2 0]: negative
15 | | | | | oldpeak > 0.300000 [6 0]: negative
16 | | | | trestbps > 113.500000 [47 1]
17 | | | | | oldpeak <= 3.550000 [47 0]: negative
18 | | | | | oldpeak > 3.550000 [0 1]: positive
19 | | | age > 55.500000 [28 11]
20 | | | | chol <= 248.500000 [14 1]
21 | | | | | oldpeak <= 2.800000 [14 0]: negative
22 | | | | | oldpeak > 2.800000 [0 1]: positive
23 | | | | chol > 248.500000 [14 10]
24 | | | | | sex = female [13 3]
25 | | | | | | cp = typ_angina [1 0]: negative
26 | | | | | | cp = asympt [3 3]
27 | | | | | | | age <= 58.000000 [2 0]: negative
28 | | | | | | | age > 58.000000 [1 3]
29 | | | | | | | | chol <= 362.000000 [0 3]: positive
30 | | | | | | | | chol > 362.000000 [1 0]: negative
31 | | | | | | cp = non_anginal [7 0]: negative
32 | | | | | | cp = atyp_angina [2 0]: negative
33 | | | | | sex = male [1 7]
34 | | | | | | age <= 65.500000 [0 5]: positive
35 | | | | | | age > 65.500000 [1 2]: positive
36 | thal = reversable_defect [20 67]
37 | | cp = typ_angina [3 1]
38 | | | oldpeak <= 0.700000 [0 1]: positive
39 | | | oldpeak > 0.700000 [3 0]: negative
40 | | cp = asympt [5 53]
41 | | | oldpeak <= 0.650000 [5 10]
42 | | | | chol <= 240.500000 [5 2]
43 | | | | | chol <= 192.000000 [1 2]: positive
44 | | | | | chol > 192.000000 [4 0]: negative
45 | | | | chol > 240.500000 [0 8]: positive
46 | | | oldpeak > 0.650000 [0 43]: positive
47 | | cp = non_anginal [9 10]
48 | | | oldpeak <= 1.900000 [9 5]
49 | | | | trestbps <= 122.500000 [6 0]: negative
50 | | | | trestbps > 122.500000 [3 5]
51 | | | | | chol <= 232.500000 [3 1]
52 | | | | | | trestbps <= 129.000000 [0 1]: positive
53 | | | | | | trestbps > 129.000000 [3 0]: negative
54 | | | | | chol > 232.500000 [0 4]: positive
55 | | | oldpeak > 1.900000 [0 5]: positive
56 | | cp = atyp_angina [3 3]
57 | | | age <= 47.000000 [2 0]: negative
58 | | | age > 47.000000 [1 3]
59 | | | | trestbps <= 109.000000 [1 0]: negative
60 | | | | trestbps > 109.000000 [0 3]: positive
--------------------------------------------------------------------------------
/machine-learning/lib/data_structures.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/data_structures.jar
--------------------------------------------------------------------------------
/machine-learning/lib/guava-18.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/guava-18.0.jar
--------------------------------------------------------------------------------
/machine-learning/lib/hamcrest-core-1.3.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/hamcrest-core-1.3.jar
--------------------------------------------------------------------------------
/machine-learning/lib/junit-4.12-beta-3.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/junit-4.12-beta-3.jar
--------------------------------------------------------------------------------
/machine-learning/lib/lombok.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/lombok.jar
--------------------------------------------------------------------------------
/machine-learning/lib/mockito-all-1.9.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mbernste/machine-learning/463b182442de6e637f929fde4da9f21e890a2efb/machine-learning/lib/mockito-all-1.9.5.jar
--------------------------------------------------------------------------------
/machine-learning/src/applications/MainHillClimbing.java:
--------------------------------------------------------------------------------
1 | package applications;
2 |
3 | import java.io.FileNotFoundException;
4 | import java.io.PrintWriter;
5 | import java.util.ArrayList;
6 | import java.util.List;
7 |
8 |
9 | import data.DataSet;
10 | import data.fold.KFoldCreator;
11 | import data.reader.ArffReader;
12 |
13 |
14 | import pair.Pair;
15 |
16 | import bayes.BNEvaluator;
17 | import bayes.BNResultWriter;
18 | import bayes.BayesianNetwork;
19 | import bayes.structuresearch.HillClimbingBuilder;
20 | import bayes.structuresearch.score.BIC;
21 |
22 | public class MainHillClimbing
23 | {
24 | public static void main(String[] args)
25 | {
26 | try
27 | {
28 | PrintWriter out = new PrintWriter(args[1]);
29 |
30 | /*
31 | * Read the training data from the arff file
32 | */
33 | ArffReader reader = new ArffReader();
34 | DataSet data = reader.readFile(args[0]);
35 |
36 | /*
37 | * Scoring function
38 | */
39 | BIC bic = new BIC();
40 |
41 | List> folds = KFoldCreator.create(data, 5);
42 |
43 |
44 | out.println("Result on folds:");
45 | Double scoreSum = 0.0;
46 | for (int i = 0; i < folds.size(); i++)
47 | {
48 | /*
49 | * TODO: BAD!
50 | */
51 | BNResultWriter.WRITER = new PrintWriter(args[2]+ "_" + i);
52 |
53 | Pair fold = folds.get(i);
54 |
55 | HillClimbingBuilder hcBuilder = new HillClimbingBuilder();
56 | BayesianNetwork net = hcBuilder.buildNetwork(fold.getFirst(),
57 | 1,
58 | bic,
59 | null);
60 |
61 | Double score = BNEvaluator.calculateLogLikelihood(net,
62 | fold.getSecond());
63 |
64 | scoreSum += score;
65 |
66 | out.println("\n\n-------- Fold " + i + " --------\n");
67 |
68 | out.println("Net Structure: ");
69 | out.print(net);
70 |
71 | out.println("Likelihood of test data: ");
72 | out.println(score);
73 |
74 | /*
75 | * TODO: BAD!
76 | */
77 | BNResultWriter.WRITER.close();
78 | }
79 | out.println("Average likelihood:");
80 | out.println(scoreSum / folds.size());
81 | out.close();
82 | }
83 | catch(FileNotFoundException e)
84 | {
85 | System.out.println("Error instantiating output file writer.");
86 | System.exit(1);
87 | }
88 |
89 | }
90 |
91 | }
92 |
--------------------------------------------------------------------------------
/machine-learning/src/applications/MainNaiveBayes.java:
--------------------------------------------------------------------------------
1 | package applications;
2 |
3 |
4 |
5 | import classify.ClassificationResult;
6 |
7 | import data.DataSet;
8 | import data.reader.ArffReader;
9 | import bayes.classifiers.NaiveBayesClassifier;
10 |
11 | /**
12 |  * Command-line driver that trains a naive Bayes or TAN classifier on one ARFF
13 |  * file, prints the classifier, and prints the result of classifying a second
14 |  * ARFF file.
15 |  *
16 |  * Arguments:
17 |  *   args[0] - ARFF file with the training data
18 |  *   args[1] - ARFF file with the test data
19 |  *   args[2] - "n" for simple naive Bayes, "t" for TAN
20 |  */
21 | public class MainNaiveBayes
22 | {
23 |
24 |     /** Name of the attribute set as the class attribute of both data sets */
25 |     private static final String CLASS_ATTR_NAME = "class";
26 |
27 |     /** Laplace count passed to the classifier */
28 |     private static final Integer LAPLACE_COUNT = 1;
29 |
30 |     /**
31 |      * Trains on args[0], classifies args[1], and prints both the classifier
32 |      * and the classification result to standard output.
33 |      *
34 |      * @param args the three arguments described in the class comment
35 |      */
36 |     public static void main(String[] args)
37 |     {
38 |         /*
39 |          * Print usage unless exactly three arguments were given
40 |          */
41 |         if (args.length != 3)
42 |         {
43 |             printUsage();
44 |             return;
45 |         }
46 |
47 |         /*
48 |          * Determine whether to use TAN or simple Naive Bayes
49 |          */
50 |         boolean tan;
51 |         if (args[2].equals("t"))
52 |         {
53 |             tan = true;
54 |         }
55 |         else if (args[2].equals("n"))
56 |         {
57 |             tan = false;
58 |         }
59 |         else
60 |         {
61 |             printUsage();
62 |             return;
63 |         }
64 |
65 |         /*
66 |          * Read the training data from the arff file
67 |          */
68 |         ArffReader reader = new ArffReader();
69 |         DataSet data = reader.readFile(args[0]);
70 |         data.setClassAttribute(CLASS_ATTR_NAME);
71 |
72 |         /*
73 |          * Create Naive Bayes classifier
74 |          */
75 |         NaiveBayesClassifier nbClassifier =
76 |                 new NaiveBayesClassifier(data, LAPLACE_COUNT, tan);
77 |
78 |         /*
79 |          * Print network
80 |          */
81 |         System.out.println(nbClassifier);
82 |
83 |         /*
84 |          * Read the test data from the arff file
85 |          */
86 |         DataSet testData = reader.readFile(args[1]);
87 |         testData.setClassAttribute(CLASS_ATTR_NAME);
88 |
89 |         /*
90 |          * Classify the data
91 |          */
92 |         ClassificationResult result = nbClassifier.classifyData(testData);
93 |
94 |         /*
95 |          * Print results to standard output
96 |          */
97 |         System.out.print("\n\n");
98 |         System.out.println(result);
99 |     }
100 |
101 |     /**
102 |      * Prints this program's usage, including the three expected arguments,
103 |      * to standard output
104 |      */
105 |     public static void printUsage()
106 |     {
107 |         System.out.println("\nUsage: bayes <train-set-file> <test-set-file> <n|t>\n");
108 |     }
109 | }
110 |
--------------------------------------------------------------------------------
/machine-learning/src/applications/MainSparseCandidate.java:
--------------------------------------------------------------------------------
1 | package applications;
2 |
3 | import java.io.FileNotFoundException;
4 | import java.io.PrintWriter;
5 | import java.util.ArrayList;
6 | import java.util.List;
7 |
8 |
9 | import data.DataSet;
10 | import data.fold.KFoldCreator;
11 | import data.reader.ArffReader;
12 |
13 |
14 | import pair.Pair;
15 |
16 | import bayes.BNEvaluator;
17 | import bayes.BNResultWriter;
18 | import bayes.BayesianNetwork;
19 | import bayes.structuresearch.SparseCandidateBuilder;
20 | import bayes.structuresearch.score.BIC;
21 |
22 | public class MainSparseCandidate
23 | {
24 | public static void main(String[] args)
25 | {
26 | try
27 | {
28 | PrintWriter out = new PrintWriter(args[1]);
29 |
30 | /*
31 | * Read the training data from the arff file
32 | */
33 | ArffReader reader = new ArffReader();
34 | DataSet data = reader.readFile(args[0]);
35 |
36 | /*
37 | * Scoring function
38 | */
39 | BIC bic = new BIC();
40 |
41 | List> folds = KFoldCreator.create(data, 5);
42 |
43 |
44 | out.println("Result on folds:");
45 | Double scoreSum = 0.0;
46 | for (int i = 0; i < folds.size(); i++)
47 | {
48 | /*
49 | * TODO: BAD!
50 | */
51 | BNResultWriter.WRITER = new PrintWriter(args[2]+ "_" + i);
52 |
53 | Pair fold = folds.get(i);
54 |
55 | SparseCandidateBuilder spBuilder = new SparseCandidateBuilder();
56 | BayesianNetwork net = spBuilder.buildNetwork(fold.getFirst(),
57 | 1,
58 | bic,
59 | null);
60 |
61 | Double score = BNEvaluator.calculateLogLikelihood(net,
62 | fold.getSecond());
63 |
64 | scoreSum += score;
65 |
66 | out.println("\n\n-------- Fold " + i + " --------\n");
67 |
68 | out.println("Net Structure: ");
69 | out.print(net);
70 |
71 | out.println("Likelihood of test data: ");
72 | out.println(score);
73 |
74 | /*
75 | * TODO: BAD!
76 | */
77 | BNResultWriter.WRITER.close();
78 | }
79 | out.println("Average likelihood:");
80 | out.println(scoreSum / folds.size());
81 | out.close();
82 | }
83 | catch(FileNotFoundException e)
84 | {
85 | System.out.println("Error instantiating output file writer.");
86 | System.exit(1);
87 | }
88 |
89 | }
90 |
91 | }
92 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/BNEvaluator.java:
--------------------------------------------------------------------------------
1 | package bayes;
2 |
3 | import java.util.ArrayList;
4 |
5 |
6 | import data.Attribute;
7 | import data.DataSet;
8 | import data.Instance;
9 |
10 | /**
11 |  * Scores data sets against a BayesianNetwork.
12 |  */
13 | public class BNEvaluator
14 | {
15 |
16 |     /**
17 |      * For each instance in the data, query the net for the probability of
18 |      * every node's value given the values of that node's parents, and
19 |      * accumulate the negated natural logarithm of each probability.
20 |      *
21 |      * NOTE(review): because every term is {@code -Math.log(p)}, the value
22 |      * returned is the negative of the log-likelihood. The sign is left
23 |      * untouched because callers print this value directly - confirm that
24 |      * this is the intended convention.
25 |      *
26 |      * @param net the Bayes net used to calculate the likelihood of the data
27 |      * @param data the data for which we want to know the likelihood
28 |      * @return the sum of {@code -log(p)} over all instances and all nodes
29 |      */
30 |     public static Double calculateLogLikelihood(BayesianNetwork net, DataSet data)
31 |     {
32 |         double negLogSum = 0.0;
33 |
34 |         for (Instance instance : data.getInstanceSet().getInstances())
35 |         {
36 |             ArrayList<BNConditionalQuery> queries = createQueries(instance, net, data);
37 |
38 |             /*
39 |              * Add up the -log probability of each of the instance's queries
40 |              */
41 |             for (BNConditionalQuery query : queries)
42 |             {
43 |                 Double p = net.queryConditionalProbability(query);
44 |                 negLogSum += -Math.log(p);
45 |             }
46 |         }
47 |
48 |         return negLogSum;
49 |     }
50 |
51 |     /**
52 |      * Creates one conditional probability query per node of the net: the
53 |      * node's attribute at the value the instance has for it, conditioned on
54 |      * the instance's values for the attributes of the node's parents.
55 |      *
56 |      * @param instance the instance under examination
57 |      * @param net the Bayes net
58 |      * @param data the data set (not read by this method)
59 |      * @return a query for each node in the net
60 |      */
61 |     public static ArrayList<BNConditionalQuery> createQueries(Instance instance,
62 |                                                               BayesianNetwork net,
63 |                                                               DataSet data)
64 |     {
65 |         ArrayList<BNConditionalQuery> queries = new ArrayList<BNConditionalQuery>();
66 |
67 |         for (BNNode targetNode : net.getNodes())
68 |         {
69 |             BNConditionalQuery query = new BNConditionalQuery();
70 |
71 |             /*
72 |              * Set target attribute/value
73 |              */
74 |             Attribute targetAttr = targetNode.getAttribute();
75 |             Integer targetAttrValue
76 |                 = instance.getAttributeValue(targetAttr).intValue();
77 |             query.setTargetVariable(targetAttr, targetAttrValue);
78 |
79 |             /*
80 |              * Add a condition attribute/value for every parent of the
81 |              * target. All nodes are visited so that conditions are added
82 |              * in the net's node order.
83 |              */
84 |             for (BNNode conditionNode : net.getNodes())
85 |             {
86 |                 Attribute conditionAttr = conditionNode.getAttribute();
87 |
88 |                 boolean isParentOfTarget = targetNode.getParents().contains(conditionNode);
89 |
90 |                 if (!conditionNode.equals(targetNode) && isParentOfTarget)
91 |                 {
92 |                     Integer conditionAttrValue
93 |                         = instance.getAttributeValue(conditionAttr).intValue();
94 |
95 |                     query.addConditionVariable(conditionAttr, conditionAttrValue);
96 |                 }
97 |             }
98 |
99 |             queries.add(query);
100 |         }
101 |
102 |         return queries;
103 |     }
104 |
105 | }
106 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/BNJointQuery.java:
--------------------------------------------------------------------------------
1 | package bayes;
2 |
3 | import java.util.ArrayList;
4 |
5 | import data.Attribute;
6 |
7 | import pair.Pair;
8 |
9 | /**
10 |  * Used for querying a {@code BayesianNetwork} object for a joint
11 |  * probability of a set of attribute/value pairs. For example, if we wish to
12 |  * query a Bayes net for the following probability: P(A = a, E = e, D = d),
13 |  * this object would represent the query, (A = a, E = e, D = d)
14 |  *
15 |  * @author Matthew Bernstein - matthewb@cs.wisc.edu
16 |  *
17 |  */
18 | public class BNJointQuery
19 | {
20 |
21 |     /**
22 |      * The attribute/value pairs of all of the variables in the joint
23 |      * probability query; each value is the nominal value ID of its attribute.
24 |      *
25 |      * For example, given the following query: P(A = a, D = d, E = e), this
26 |      * set of pairs would be as follows: (A,a), (D,d), (E,e).
27 |      */
28 |     private VariableSet variables;
29 |
30 |     /**
31 |      * Constructor
32 |      *
33 |      * @param variables the set of variables in the joint probability query
34 |      */
35 |     public BNJointQuery(VariableSet variables)
36 |     {
37 |         this.variables = variables;
38 |     }
39 |
40 |     /**
41 |      * Constructor that starts with an empty set of variables
42 |      */
43 |     public BNJointQuery()
44 |     {
45 |         this.variables = new VariableSet();
46 |     }
47 |
48 |     /**
49 |      * Add an attribute/value pair to the set of attribute/value pairs
50 |      * used in the query.
51 |      *
52 |      * @param attr the attribute to be added to the query
53 |      * @param nomValueId the nominal value ID specified for this Attribute
54 |      */
55 |     public void addVariable(Attribute attr, Integer nomValueId)
56 |     {
57 |         variables.addVariable(attr, nomValueId);
58 |     }
59 |
60 |     /**
61 |      * Determines whether this query specifies a value for a specific
62 |      * attribute
63 |      *
64 |      * @param attr the target Attribute
65 |      * @return true if this query is specifying a value for this specific
66 |      * Attribute
67 |      */
68 |     public Boolean containsAttribute(Attribute attr)
69 |     {
70 |         return variables.containsAttribute(attr);
71 |     }
72 |
73 |     /**
74 |      * Gets the value for a specific attribute in the set of
75 |      * attributes. If this attribute is not specified by this query, this
76 |      * method returns null.
77 |      *
78 |      * @param attr the target Attribute
79 |      * @return the value specified for the target Attribute in this query;
80 |      * null if this Attribute is not specified in this query
81 |      */
82 |     public Integer getValueForAttribute(Attribute attr)
83 |     {
84 |         return variables.getValueForAttribute(attr);
85 |     }
86 |
87 |     /**
88 |      * @return the list of attribute/value pairs in this joint probability
89 |      * query. Each pair is the attribute and the nominal value ID
90 |      * of the value of this attribute.
91 |      */
92 |     public ArrayList<Pair<Attribute, Integer>> getVariables()
93 |     {
94 |         return this.variables.getVariables();
95 |     }
96 |
97 |     /**
98 |      * @return the query rendered as "P(A = a, B = b)"; "P()" when the query
99 |      * holds no variables
100 |      */
101 |     @Override
102 |     public String toString()
103 |     {
104 |         StringBuilder result = new StringBuilder("P(");
105 |
106 |         /*
107 |          * Write the separator before every pair but the first, so that no
108 |          * trailing ", " has to be cut off (cutting it off mangled the
109 |          * output of an empty query)
110 |          */
111 |         String separator = "";
112 |         for (Pair<Attribute, Integer> attrValPair : this.variables.getVariables())
113 |         {
114 |             Attribute attr = attrValPair.getFirst();
115 |             Integer attrValue = attrValPair.getSecond();
116 |             result.append(separator);
117 |             result.append(attr.getName() + " = " + attr.getNominalValueName(attrValue));
118 |             separator = ", ";
119 |         }
120 |
121 |         /*
122 |          * Closing parenthesis
123 |          */
124 |         result.append(")");
125 |
126 |         return result.toString();
127 |     }
128 |
129 | }
130 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/BNResultWriter.java:
--------------------------------------------------------------------------------
1 | package bayes;
2 |
3 | import java.io.PrintWriter;
4 | /** Holds the PrintWriter that applications.MainHillClimbing and applications.MainSparseCandidate assign before each fold and close after it. */
5 | public class BNResultWriter
6 | {
7 | public static PrintWriter WRITER; // global mutable sink, null until a driver assigns it. NOTE(review): presumably written to by the structure-search builders - confirm
8 | }
9 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/BNUtility.java:
--------------------------------------------------------------------------------
1 | package bayes;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | /**
7 |  * A utility class for miscellaneous methods needed for Bayesian network
8 |  * learning.
9 |  *
10 |  * @author Matthew Bernstein - matthewb@cs.wisc.edu
11 |  *
12 |  */
13 | public class BNUtility
14 | {
15 |     /**
16 |      * Converts a list of nodes to an adjacency matrix indexed by each node's
17 |      * position in the list.
18 |      *
19 |      * @param nodes the nodes of the graph
20 |      * @return a matrix in which graph[P][C] is 1.0 if C is a child of P and
21 |      * null otherwise
22 |      */
23 |     public static Double[][] convertToAdjacencyMatrix(List<BNNode> nodes)
24 |     {
25 |         return buildAdjacencyMatrix(nodes, nodes.size());
26 |     }
27 |
28 |     /**
29 |      * Converts a network to an adjacency matrix indexed by each node's
30 |      * position in {@code net.getNodes()}.
31 |      *
32 |      * @param net the network whose nodes are converted
33 |      * @return a matrix of size net.getNumNodes() in which graph[P][C] is 1.0
34 |      * if C is a child of P and null otherwise
35 |      */
36 |     public static Double[][] convertToAdjacencyMatrix(BayesianNetwork net)
37 |     {
38 |         int numNodes = net.getNumNodes();
39 |         List<BNNode> nodes = net.getNodes();
40 |
41 |         return buildAdjacencyMatrix(nodes, numNodes);
42 |     }
43 |
44 |     /**
45 |      * Shared implementation of the two public conversions.
46 |      *
47 |      * @param nodes the nodes of the graph
48 |      * @param numNodes the number of rows and columns of the matrix
49 |      * @return the adjacency matrix
50 |      */
51 |     private static Double[][] buildAdjacencyMatrix(List<BNNode> nodes, int numNodes)
52 |     {
53 |         /*
54 |          * Every element of a newly allocated Double[][] is already null,
55 |          * so no explicit initialization is needed
56 |          */
57 |         Double[][] graph = new Double[numNodes][numNodes];
58 |
59 |         /*
60 |          * Assign graph[P][C] to 1.0 if C is a child of P
61 |          */
62 |         for (int pIndex = 0; pIndex < nodes.size(); pIndex++)
63 |         {
64 |             BNNode currNode = nodes.get(pIndex);
65 |
66 |             for (BNNode child : currNode.getChildren())
67 |             {
68 |                 /*
69 |                  * indexOf yields -1 for a child that is missing from the
70 |                  * list, which fails below with an
71 |                  * ArrayIndexOutOfBoundsException
72 |                  */
73 |                 int cIndex = nodes.indexOf(child);
74 |                 graph[pIndex][cIndex] = 1.0;
75 |             }
76 |         }
77 |
78 |         return graph;
79 |     }
80 | }
81 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/VariableSet.java:
--------------------------------------------------------------------------------
1 | package bayes;
2 |
3 | import java.util.ArrayList;
4 |
5 | import pair.Pair;
6 | import data.Attribute;
7 |
8 | /**
9 |  * This class represents a set of attribute/value pairs. This is
10 |  * a component of Bayes net query objects.
11 |  *
12 |  * @author matthewbernstein
13 |  */
14 | public class VariableSet
15 | {
16 |     /**
17 |      * Each pair holds an Attribute and the nominal value ID of that
18 |      * attribute, for all of the variables in this set.
19 |      */
20 |     private final ArrayList<Pair<Attribute, Integer>> variables;
21 |
22 |     /**
23 |      * Constructor; creates an empty set
24 |      */
25 |     public VariableSet()
26 |     {
27 |         this.variables = new ArrayList<Pair<Attribute, Integer>>();
28 |     }
29 |
30 |     /**
31 |      * Add an attribute/value pair to this set. No check for an existing
32 |      * pair with the same attribute is made.
33 |      *
34 |      * @param attr the attribute to be added to the set
35 |      * @param nomValueId the nominal value ID specified for this Attribute
36 |      */
37 |     public void addVariable(Attribute attr, Integer nomValueId)
38 |     {
39 |         Pair<Attribute, Integer> newItem =
40 |                 new Pair<Attribute, Integer>(attr, nomValueId);
41 |
42 |         variables.add(newItem);
43 |     }
44 |
45 |     /**
46 |      * Determines whether this set includes the value of a specific
47 |      * attribute in the set of attributes
48 |      *
49 |      * @param attr the target Attribute
50 |      * @return true if this set has a value for this specific attribute
51 |      */
52 |     public Boolean containsAttribute(Attribute attr)
53 |     {
54 |         // Linear search for a matching Attribute
55 |         for (Pair<Attribute, Integer> item : variables)
56 |         {
57 |             if (item.getFirst().equals(attr))
58 |             {
59 |                 return true;
60 |             }
61 |         }
62 |
63 |         return false;
64 |     }
65 |
66 |     /**
67 |      * Gets the value for a specific attribute in the set of
68 |      * attributes. If this attribute is not in this set, this method returns
69 |      * null.
70 |      *
71 |      * @param attr the target Attribute
72 |      * @return the value of the first pair whose attribute equals the target
73 |      * Attribute; null if the target is null or is not in this set
74 |      */
75 |     public Integer getValueForAttribute(Attribute attr)
76 |     {
77 |         if (attr == null)
78 |         {
79 |             return null;
80 |         }
81 |
82 |         /*
83 |          * Linear search for the attribute
84 |          */
85 |         for (Pair<Attribute, Integer> item : variables)
86 |         {
87 |             if (item.getFirst().equals(attr))
88 |             {
89 |                 return item.getSecond();
90 |             }
91 |         }
92 |
93 |         return null;
94 |     }
95 |
96 |     /**
97 |      * @return the list of attribute/value pairs in this set. Each pair is
98 |      * the attribute and the nominal value ID of the value of this
99 |      * attribute. The list returned is the internal list, not a copy.
100 |      */
101 |     public ArrayList<Pair<Attribute, Integer>> getVariables()
102 |     {
103 |         return this.variables;
104 |     }
105 |
106 |     /**
107 |      * @return a new VariableSet holding a new pair for every pair of this
108 |      * set; the Attribute and Integer objects themselves are shared
109 |      */
110 |     @Override
111 |     public Object clone()
112 |     {
113 |         VariableSet copy = new VariableSet();
114 |
115 |         for (Pair<Attribute, Integer> pair : this.variables)
116 |         {
117 |             copy.addVariable(pair.getFirst(), pair.getSecond());
118 |         }
119 |
120 |         return copy;
121 |     }
122 | }
123 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/cpd/CPDLeaf.java:
--------------------------------------------------------------------------------
1 | package bayes.cpd;
2 |
3 | import data.Attribute;
4 |
5 | /**
6 |  * A leaf node in a CPD tree.
7 |  *
8 |  * @author Matthew Bernstein - matthewb@cs.wisc.edu
9 |  *
10 |  */
11 | public class CPDLeaf extends CPDNode
12 | {
13 |     /** Probability stored at this leaf; computed when the parent is set */
14 |     protected double probability;
15 |
16 |     /** Laplace count used to smooth the probability in setParent */
17 |     private final Integer laplaceCount;
18 |
19 |     /**
20 |      * Constructor
21 |      *
22 |      * @param attribute the attribute of this leaf
23 |      * @param nodeValue the value of the attribute that this leaf stands for;
24 |      * compared against the query value in calculateProbability
25 |      * @param numInstances the instance count of this leaf; the numerator of
26 |      * the probability computed in setParent
27 |      * @param laplaceCount the count added to the numerator and, once per
28 |      * nominal value of the attribute, to the denominator of the probability
29 |      */
30 |     public CPDLeaf(Attribute attribute,
31 |                    Integer nodeValue,
32 |                    Integer numInstances,
33 |                    Integer laplaceCount)
34 |     {
35 |         super(attribute, nodeValue, numInstances);
36 |
37 |         this.laplaceCount = laplaceCount;
38 |     }
39 |
40 |     @Override
41 |     public String toString()
42 |     {
43 |         String result = super.toString();
44 |         result += " [" + probability + "]";
45 |         return result;
46 |     }
47 |
48 |     /**
49 |      * @return the probability at this leaf node
50 |      */
51 |     public double getProbability()
52 |     {
53 |         return this.probability;
54 |     }
55 |
56 |     /**
57 |      * Calculate the probability of a specific query on this leaf node
58 |      *
59 |      * @param query the query object used to specify specific values of the
60 |      * attributes for which this CPD's leaf attribute is conditioned on
61 |      * @return this leaf's probability if the query names no value for this
62 |      * leaf's attribute or names this leaf's value; 0.0 otherwise
63 |      */
64 |     @Override
65 |     public Double calculateProbability(CPDQuery query)
66 |     {
67 |         /*
68 |          * Get the query value for this node's attribute
69 |          */
70 |         Integer queryValue = query.getValueForQueryAttribute(this.attribute);
71 |
72 |         /*
73 |          * Return this leaf's probability if no specific value for this
74 |          * CPDNode's attribute was specified in the query or if the value
75 |          * matches this CPDNode's attribute in the query. Otherwise, we
76 |          * return 0.0.
77 |          *
78 |          * The values are compared with equals(): == on two boxed Integers
79 |          * compares references and is only reliable for small cached values.
80 |          */
81 |         if (queryValue == null || queryValue.equals(this.nodeValue))
82 |         {
83 |             return this.probability;
84 |         }
85 |         else
86 |         {
87 |             return 0.0;
88 |         }
89 |     }
90 |
91 |     /**
92 |      * Set the parent for this leaf node. This, in turn, will calculate the
93 |      * probability at this leaf.
94 |      *
95 |      * @param parent the parent CPDNode; its numInstances is the base of the
96 |      * denominator
97 |      */
98 |     @Override
99 |     public void setParent(CPDNode parent)
100 |     {
101 |         /*
102 |          * Set the parent
103 |          */
104 |         this.parent = parent;
105 |
106 |         /*
107 |          * Calculate leaf probability using Laplace counts
108 |          */
109 |         double numerator = (double) numInstances + laplaceCount;
110 |         double denominator = parent.numInstances
111 |                 + (laplaceCount * attribute.getNominalValueMap().size());
112 |
113 |         this.probability = numerator / denominator;
114 |     }
115 | }
116 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/cpd/CPDQuery.java:
--------------------------------------------------------------------------------
1 | package bayes.cpd;
2 |
3 | import java.util.HashMap;
4 | import java.util.Map;
5 | import java.util.Map.Entry;
6 |
7 | import data.Attribute;
8 |
9 | /**
10 | * Objects of this class are used for querying a conditional probability
11 | * distribution for a single BNNode in a BayesianNetwork
12 | *
13 | * @author Matthew Bernstein - matthewb@cs.wisc.edu
14 | *
15 | */
16 | public class CPDQuery
17 | {
18 | /**
19 | * Maps an attribute to the value of this attribute specified in the query
20 | */
21 | private Map queryItems;
22 |
23 | /**
24 | * Constructor
25 | */
26 | public CPDQuery()
27 | {
28 | this.queryItems = new HashMap();
29 | }
30 |
31 | /**
32 | * Add a query item to this query
33 | *
34 | * @param attr the Attribute this query is querying for
35 | * @param nominalValueId the nominal value ID specified for this Attribute
36 | */
37 | public void addQueryItem(Attribute attr, Integer nomValueId)
38 | {
39 | if (attr.isValidNominalValueId(nomValueId))
40 | {
41 | queryItems.put(attr, nomValueId);
42 | }
43 | else
44 | {
45 | throw new RuntimeException(nomValueId + " is not a valid nominal" +
46 | " value ID for the attribute " +
47 | attr.getName());
48 | }
49 | }
50 |
51 | /**
52 | * Determines whether this BNQuery includes the value of a specific
53 | * attribute
54 | *
55 | * @param attr the target Attribute
56 | * @return true if this query is specifying a value for this specific
57 | * Attribute
58 | */
59 | public Boolean containsAttribute(Attribute attr)
60 | {
61 | return queryItems.keySet().contains(attr);
62 | }
63 |
64 | /**
65 | * Gets the value for a specific query attribute. If this attribute
66 | * is not specified by this query, this method returns null.
67 | *
68 | * @param attr the target Attribute
69 | * @return the value of target Attribute specified in this query. null if
70 | * this Attribute is not specified in this query
71 | */
72 | public Integer getValueForQueryAttribute(Attribute attr)
73 | {
74 | return queryItems.get(attr);
75 | }
76 |
77 | @Override
78 | public String toString()
79 | {
80 | String result = "CPD(";
81 |
82 | for (Entry entry : queryItems.entrySet())
83 | {
84 | Attribute attr = entry.getKey();
85 | Integer attrValue = entry.getValue();
86 | result += attr.getName() + " = "
87 | + attr.getNominalValueName(attrValue) + ", ";
88 | }
89 |
90 | result = result.substring(0, result.length() - 2);
91 | result += ")";
92 |
93 | return result;
94 | }
95 |
96 | }
97 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/cpd/CPDTree.java:
--------------------------------------------------------------------------------
1 | package bayes.cpd;
2 |
3 | import data.AttributeSet;
4 |
5 | /**
6 | * A tree data structure used for representing and storing the
7 | * conditional probability distribution (CPD) for a specific node in a
8 | * Bayesian network.
9 | *
10 | * @author Matthew Bernstein - matthewb@cs.wisc.edu
11 | *
12 | */
13 | public class CPDTree
14 | {
15 | /**
16 | * Total instances in the training data set used for generating this CPD
17 | */
18 | protected static Integer totalInstances;
19 |
20 | /**
21 | * The set of attributes in the training set used for generating this CPD
22 | */
23 | protected AttributeSet attributeSet;
24 |
25 | /**
26 | * The root node of the CPD tree
27 | */
28 | protected CPDNode root;
29 |
30 | /**
31 | * @return the root node of this CPD tree
32 | */
33 | public CPDNode getRoot()
34 | {
35 | return root;
36 | }
37 |
38 | /**
39 | * Print the subtree rooted at the input Node to standard output.
40 | *
41 | * @param root the node that roots the subtree being printed
42 | */
43 | public String toString()
44 | {
45 | String result = "";
46 |
47 | for (CPDNode child : root.getChildren())
48 | {
49 | ToStringHelper treePrinter = new ToStringHelper(this.attributeSet);
50 | result += treePrinter.getString(child, 0);
51 | }
52 |
53 | return result;
54 | }
55 |
56 | /**
57 | * A private helper class used for traversing the tree recursively
58 | * in order to convert the tree to a String
59 | */
60 | private static class ToStringHelper
61 | {
62 | @SuppressWarnings("unused")
63 | private AttributeSet attributeSet;
64 |
65 | public ToStringHelper(AttributeSet attributeSet)
66 | {
67 | this.attributeSet = attributeSet;
68 | }
69 |
70 | public String getString(CPDNode node, Integer depth)
71 | {
72 | String result = "";
73 |
74 | /*
75 | * Print the indentated "|" characters
76 | */
77 | for (int i = 0; i < depth; i++)
78 | {
79 | result += "| ";
80 | }
81 |
82 | /*
83 | * Print the value at the current node
84 | */
85 | result += node;
86 |
87 | result += "\n";
88 |
89 | /*
90 | * Generate the string for the child nodes of this node
91 | */
92 | for (CPDNode child : node.getChildren())
93 | {
94 | result += getString(child, depth + 1);
95 | }
96 |
97 | return result;
98 | }
99 | }
100 |
101 | /**
102 | * Make a query on this CPD tree
103 | *
104 | * @param query the query object
105 | * @return the probability of this query given this CPD
106 | */
107 | public Double query(CPDQuery query)
108 | {
109 | return this.root.calculateProbability(query);
110 | }
111 | }
112 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/cpd/Split.java:
--------------------------------------------------------------------------------
1 | package bayes.cpd;
2 |
3 | import java.util.ArrayList;
4 |
5 |
6 | import data.Attribute;
7 | import data.DataSet;
8 | import data.Instance;
9 | import data.InstanceSet;
10 |
11 | /**
12 | * Splits all instances along a specific attribute.
13 | *
14 | * @author Matthew Bernstien - matthewb@cs.wisc.edu
15 | *
16 | */
17 | public class Split
18 | {
19 | /**
20 | * All of the branches for this split. For splits on nominal attributes,
21 | * there will be one branch per nominal value. For continuous attributes,
22 | * there will be two branches. One branch for the all instances with a value
23 | * greater than the threshold and one branch for all instances less than or
24 | * equal to the threshold value.
25 | */
26 | private ArrayList branches;
27 |
28 | /**
29 | * The attribute this Split splits instances on
30 | */
31 | private Attribute attribute;
32 |
33 | public Split(Attribute attribute)
34 | {
35 | branches = new ArrayList();
36 | this.attribute = attribute;
37 | }
38 |
39 | public Attribute getAttribute()
40 | {
41 | return attribute;
42 | }
43 |
44 | /**
45 | * Split a set of instances along this split's attribute. Each split is
46 | * stored in one of this Split's SplitBranch objects
47 | *
48 | * @param instances the set of instances we wish to split
49 | */
50 | public void splitInstances(InstanceSet instances)
51 | {
52 | for (Instance instance : instances.getInstances())
53 | {
54 | for (SplitBranch branch : this.branches)
55 | {
56 | branch.tryAddInstance(instance);
57 | }
58 | }
59 | }
60 |
61 | /**
62 | * @return all of this split's branches
63 | */
64 | public ArrayList getSplitBranches()
65 | {
66 | return branches;
67 | }
68 |
69 | /**
70 | * Add a branch to the split
71 | *
72 | * @param newBranch the new branch
73 | */
74 | public void addBranch(SplitBranch newBranch)
75 | {
76 | branches.add(newBranch);
77 | }
78 |
79 | /**
80 | * A helper method for generating a split along a nominal attribute
81 | *
82 | * @param attrId
83 | * @return
84 | */
85 | public static Split createSplitNominal(Attribute attr, DataSet data)
86 | {
87 | Split split = new Split(attr);
88 |
89 | for (Integer nominalValueId : attr.getNominalValueMap().values())
90 | {
91 | SplitBranch newBranch = new SplitBranch(attr,
92 | new Double(nominalValueId));
93 | split.addBranch(newBranch);
94 | }
95 |
96 | split.splitInstances(data.getInstanceSet());
97 |
98 | return split;
99 | }
100 |
101 | }
102 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/cpd/SplitBranch.java:
--------------------------------------------------------------------------------
1 | package bayes.cpd;
2 |
3 |
4 | import data.Attribute;
5 | import data.Instance;
6 | import data.InstanceSet;
7 |
8 | /**
9 | * This class stores all instances for which a specific attribute matches a
10 | * specific value.
11 | *
12 | * @author Matthew Bernstein - matthewb@cs.wisc.edu
13 | *
14 | */
15 | public class SplitBranch
16 | {
17 |
18 | /**
19 | * The value that an instance's attribute (the attribute determined by this
20 | * SplitBranch's attribute) is tested against to make this split.
21 | */
22 | private Double branchValue;
23 |
24 | /**
25 | * The attribute this branch tests
26 | */
27 | private Attribute attribute;
28 |
29 | /**
30 | * All instances that fall to this branch
31 | */
32 | private InstanceSet instanceSet;
33 |
34 | /**
35 | * Constructor
36 | *
37 | * @param attribute the attribute this branch tests
38 | * @param branchValue the value that an instance's attribute (this
39 | * SplitBranch's attribute) is tested against to make this split.
40 | */
41 | public SplitBranch(Attribute attribute, Double branchValue)
42 | {
43 | this.instanceSet = new InstanceSet();
44 | this.attribute = attribute;
45 | this.branchValue = branchValue;
46 | }
47 |
48 | public InstanceSet getInstanceSet()
49 | {
50 | return instanceSet;
51 | }
52 |
53 | /**
54 | * Attempt to add an instance to the this split branch. The instance is
55 | * only add if it passes this branches test.
56 | *
57 | * @param instance
58 | */
59 | public void tryAddInstance(Instance instance)
60 | {
61 | if (this.doesInstanceMakeSplit(instance))
62 | {
63 | instanceSet.addInstance(instance);
64 | }
65 | }
66 |
67 | /**
68 | * @return the attribute this branch tests
69 | */
70 | public Attribute getAttribute()
71 | {
72 | return attribute;
73 | }
74 |
75 | /**
76 | * @return the value that an instance is tested against to make this split
77 | */
78 | public Double getValue()
79 | {
80 | return branchValue;
81 | }
82 |
83 | /**
84 | * Tests whether an instance makes this split branch.
85 | *
86 | * @param instance the instance we are testing whether or not it makes
87 | * this SplitBranch
88 | * @return true if the instance's attribute's value matches this
89 | * SplitBranch's value
90 | */
91 | public Boolean doesInstanceMakeSplit(Instance instance)
92 | {
93 | Double instanceAttrValue =
94 | instance.getAttributeValue(this.attribute);
95 |
96 | return (instanceAttrValue.doubleValue() == branchValue.doubleValue());
97 | }
98 |
99 | }
100 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/structuresearch/NaiveBayesBuilder.java:
--------------------------------------------------------------------------------
1 | package bayes.structuresearch;
2 |
3 |
4 | import bayes.BNNode;
5 | import bayes.BayesianNetwork;
6 |
7 |
8 | import data.Attribute;
9 | import data.DataSet;
10 |
11 | /**
12 | * Builds a Bayesian Network with a Naive Bayes Structure
13 | *
14 | */
15 | public class NaiveBayesBuilder extends NetworkBuilder
16 | {
17 | /**
18 | * Builds Bayesian network with a Naive bayes structure.
19 | *
20 | * @param data the data set used to construct the parameters. This
21 | * DataSet's class attribute must be set to a valid attribute.
22 | */
23 | @Override
24 | public BayesianNetwork buildNetwork(DataSet data, Integer laplaceCount)
25 | {
26 | BayesianNetwork net = super.setupNetwork(data, laplaceCount);
27 | net.setNetStructureAlgorithm(BayesianNetwork.StructureAlgorithm.NAIVE_BAYES);
28 |
29 | /*
30 | * Create edges from the class Node to all other nodes
31 | */
32 | BNNode classAttrNode = net.getNode(data.getClassAttribute());
33 | for (BNNode node : net.getNodes())
34 | {
35 | if (!node.equals( classAttrNode ))
36 | {
37 | net.createEdge(classAttrNode, node, data, laplaceCount);
38 | }
39 | }
40 |
41 | return net;
42 | }
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/structuresearch/NetworkBuilder.java:
--------------------------------------------------------------------------------
1 | package bayes.structuresearch;
2 |
3 |
4 | import bayes.BNNode;
5 | import bayes.BayesianNetwork;
6 | import data.Attribute;
7 | import data.DataSet;
8 |
9 | /**
10 | * Constructs a {@code BayesianNetwork} object
11 | *
12 | */
13 | abstract class NetworkBuilder
14 | {
15 | /**
16 | * The Laplace count used when generating all parameters in the network
17 | */
18 | protected Integer laplaceCount;
19 |
20 | public abstract BayesianNetwork buildNetwork(DataSet data, Integer laplaceCount);
21 |
22 | /**
23 | * Builds a new Bayesian network given a dataset.
24 | *
25 | * @param data the data set used to construct the network
26 | * @param laplaceCount the Laplace count used when generating all
27 | * parameters in the network
28 | * @return a constructed Bayesian network
29 | */
30 | public BayesianNetwork setupNetwork(DataSet data, Integer laplaceCount)
31 | {
32 | this.laplaceCount = laplaceCount;
33 |
34 | BayesianNetwork net = new BayesianNetwork();
35 |
36 | /*
37 | * Create a node corresponding to each nominal attribute in the
38 | * dataset. Continuous attributes are ignored.
39 | */
40 | for (Attribute attr : data.getAttributeSet().getAttributes())
41 | {
42 | if (attr.getType() == Attribute.Type.NOMINAL)
43 | {
44 | BNNode newNode = new BNNode(attr);
45 | net.addNode( newNode, data, this.laplaceCount );
46 | }
47 | }
48 |
49 | return net;
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/structuresearch/Operation.java:
--------------------------------------------------------------------------------
1 | package bayes.structuresearch;
2 |
3 | import bayes.BNNode;
4 |
5 | /**
6 | * Represents a single operation that can be performed on the network's
7 | * structure.
8 | *
9 | * @author Matthew Bernstein - matthewb@cs.wisc.edu
10 | *
11 | */
12 | public class Operation
13 | {
14 | /**
15 | * Types of operations
16 | */
17 | public enum Type {ADD, REMOVE, REVERSE};
18 |
19 | /**
20 | * Parent node of the edge
21 | */
22 | private BNNode parent;
23 |
24 | /**
25 | * Child node of the edge
26 | */
27 | private BNNode child;
28 |
29 | /**
30 | * This operation's type
31 | */
32 | private Type type;
33 |
34 | /**
35 | * Constructor
36 | */
37 | public Operation() {}
38 |
39 | /**
40 | * Constructor
41 | *
42 | * @param type this operation's type
43 | * @param parent the parent node of the edge
44 | * @param child the child node of the edge
45 | */
46 | public Operation(Operation.Type type, BNNode parent, BNNode child)
47 | {
48 | this.type = type;
49 | this.parent = parent;
50 | this.child = child;
51 | }
52 |
53 | /**
54 | * @param parent the parent node of the edge
55 | */
56 | public void setParent(BNNode parent)
57 | {
58 | this.parent = parent;
59 | }
60 |
61 | /**
62 | * @param child the child node of the edge
63 | */
64 | public void setChild(BNNode child)
65 | {
66 | this.child = child;
67 | }
68 |
69 | /**
70 | * @param type this operation's type
71 | */
72 | public void setType(Operation.Type type)
73 | {
74 | this.type = type;
75 | }
76 |
77 | /**
78 | * @return the child node of the edge
79 | */
80 | public BNNode getChild()
81 | {
82 | return child;
83 | }
84 |
85 | /**
86 | * @return the parent node of the edge
87 | */
88 | public BNNode getParent()
89 | {
90 | return parent;
91 | }
92 |
93 | /**
94 | * @return this operation's type
95 | */
96 | public Operation.Type getType()
97 | {
98 | return this.type;
99 | }
100 |
101 | @Override
102 | public String toString()
103 | {
104 | String result = "";
105 |
106 | switch(this.type)
107 | {
108 | case ADD:
109 | result += "ADD ";
110 | break;
111 | case REVERSE:
112 | result += "REVERSE ";
113 | break;
114 | case REMOVE:
115 | result += "REMOVE ";
116 | break;
117 | }
118 |
119 | result += this.parent.getName() + " -> " + this.child.getName();
120 |
121 | return result;
122 | }
123 | }
124 |
--------------------------------------------------------------------------------
/machine-learning/src/bayes/structuresearch/score/ScoringFunction.java:
--------------------------------------------------------------------------------
1 | package bayes.structuresearch.score;
2 |
3 | import data.DataSet;
4 |
5 | import bayes.BayesianNetwork;
6 |
7 | public interface ScoringFunction
8 | {
9 | public Double scoreNet(BayesianNetwork net, DataSet data);
10 | }
11 |
--------------------------------------------------------------------------------
/machine-learning/src/classify/ClassificationResult.java:
--------------------------------------------------------------------------------
1 | package classify;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | import data.Attribute;
7 | import data.DataSet;
8 | import pair.Pair;
9 |
10 | /**
11 | * Objects of this class are used to store the results of a single
12 | * classification experiment.
13 | *
14 | */
15 | public class ClassificationResult
16 | {
17 |
18 | /**
19 | * String representation of the results
20 | */
21 | private String resultStr = "";
22 |
23 | /**
24 | * Accuracy of the classification task
25 | */
26 | private Double accuracy;
27 |
28 | /**
29 | * Size of the test set used in the classification
30 | */
31 | private Integer testDataSize;
32 |
33 | /**
34 | * Constructor
35 | *
36 | * @param resultList an ArrayList of predictions. Each Pair in this list
37 | * represents the classification of a single instance in the test data.
38 | * The first element of each pair refers to the predicted nominal value ID
39 | * the class attribute. The second element refers to the confidence of the
40 | * classifier.
41 | *
42 | * @param testData the DataSet object containing all test instances
43 | */
44 | public ClassificationResult(List> resultList,
45 | DataSet testData)
46 | {
47 | Attribute classAttr = testData.getClassAttribute();
48 |
49 | int correctCount = 0;
50 |
51 | for (int i = 0; i < resultList.size(); i++)
52 | {
53 | Integer classification = resultList.get(i).getFirst();
54 | Integer truth = testData.getInstanceSet()
55 | .getInstanceById(i)
56 | .getAttributeValue(classAttr).intValue();
57 |
58 | /*
59 | * Check for correct classification
60 | */
61 | if (classification == truth)
62 | {
63 | correctCount++;
64 | }
65 |
66 | resultStr += classAttr.getNominalValueName(classification);
67 | resultStr += " ";
68 | resultStr += classAttr.getNominalValueName(truth);
69 | resultStr += " ";
70 | resultStr += resultList.get(i).getSecond();
71 | resultStr += "\n";
72 | }
73 |
74 | resultStr += "\n";
75 | resultStr += correctCount;
76 |
77 | // Set metrics
78 | this.testDataSize = testData.getInstanceSet().getInstances().size();
79 | this.accuracy = (double) correctCount / this.testDataSize;
80 | }
81 |
82 | /**
83 | * @return the classification accuracy from this experiment
84 | */
85 | public Double getAccuracy()
86 | {
87 | return this.accuracy;
88 | }
89 |
90 | @Override
91 | public String toString()
92 | {
93 | return resultStr;
94 | }
95 |
96 | }
97 |
--------------------------------------------------------------------------------
/machine-learning/src/classify/Classifier.java:
--------------------------------------------------------------------------------
1 | package classify;
2 |
3 | import data.DataSet;
4 |
5 | /**
6 | * Any model learned to perform a supervised classification task should
7 | * implement this interface.
8 | *
9 | */
10 | public interface Classifier
11 | {
12 | public ClassificationResult classifyData(DataSet testData);
13 |
14 | public Object getModel();
15 | }
16 |
--------------------------------------------------------------------------------
/machine-learning/src/classify/evaluate/PercentageError.java:
--------------------------------------------------------------------------------
1 | package classify.evaluate;
2 |
/**
 * Methods for evaluating percentage errors from classification experiments.
 *
 */
public class PercentageError
{
    /**
     * Mean absolute percentage error, expressed in percent: the average of
     * {@link #percentageError} over all value pairs, multiplied by 100.
     *
     * The result is NaN when both arrays are empty, and is infinite or NaN
     * when any true value is zero.
     *
     * @param truthVals the true values
     * @param predictedVals the predicted values; must have the same length
     * as truthVals
     * @return the mean percentage error between the true and predicted values
     * @throws IllegalArgumentException if the two arrays differ in length
     */
    public static double meanPercentageError(Double[] truthVals,
            Double[] predictedVals)
    {
        // Values are paired by index, so a length mismatch means the
        // caller's data is inconsistent; fail with a clear message instead
        // of ignoring the surplus or running off the shorter array.
        if (truthVals.length != predictedVals.length)
        {
            throw new IllegalArgumentException(
                    "truthVals and predictedVals must have the same length,"
                    + " but got " + truthVals.length + " and "
                    + predictedVals.length);
        }

        double sum = 0.0;

        for (int i = 0; i < truthVals.length; i++)
        {
            sum += percentageError(truthVals[i], predictedVals[i]);
        }

        return (100.0 / truthVals.length) * sum;
    }

    /**
     * Absolute percentage error of a single prediction, as a fraction (not
     * multiplied by 100): |truth - predicted| / |truth|.
     *
     * @param truth the true value
     * @param predicted the predicted value
     * @return the percentage error between the true and predicted values;
     * infinite or NaN when truth is zero
     */
    public static double percentageError(Double truth, Double predicted)
    {
        return Math.abs((truth - predicted) / truth);
    }
}
41 |
--------------------------------------------------------------------------------
/machine-learning/src/data/AttributeSet.java:
--------------------------------------------------------------------------------
1 | package data;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collection;
5 | import java.util.HashMap;
6 | import java.util.List;
7 | import java.util.Map;
8 | import java.util.Map.Entry;
9 |
10 | import com.google.common.collect.ImmutableMap;
11 |
12 | /**
13 | * Stores a set of attributes.
14 | *
15 | */
16 | public class AttributeSet
17 | {
18 | /**
19 | * Maps the name of an attribute to the object
20 | */
21 | private final Map nameAttrMap;
22 |
23 | /**
24 | * The attribute that denotes the "class" or "concept"
25 | */
26 | private String classAttribute;
27 |
28 | /**
29 | * Constructor
30 | */
31 | public AttributeSet(List attributes)
32 | {
33 | ImmutableMap.Builder builder = new ImmutableMap.Builder<>();
34 |
35 | for (Attribute attr : attributes)
36 | {
37 | builder.put(attr.getName(), attr);
38 | }
39 |
40 | this.nameAttrMap = builder.build();
41 | }
42 |
43 | /**
44 | * Get an attribute by its attribute name.
45 | *
46 | * @param attrName - the attribute name
47 | * @return the attribute with this name
48 | */
49 | public Attribute getAttributeByName(String attrName)
50 | {
51 | return nameAttrMap.get(attrName);
52 | }
53 |
54 | /**
55 | * Get the nominal value ID for a specific attribute name and nominal
56 | * value of that attribute
57 | *
58 | * @param attrName the name of the target attribute
59 | * @param attrValue the name of the target nominal value
60 | * @return the unique integer ID of that nominal value
61 | */
62 | public Integer getNominalValueId(String attrName, String attrValue)
63 | {
64 | return nameAttrMap.get(attrName).getNominalValueId(attrValue);
65 | }
66 |
67 | /**
68 | * @return a list of all attributes in the attribute set
69 | */
70 | public List getAttributes()
71 | {
72 | return new ArrayList<>(nameAttrMap.values());
73 | }
74 |
75 | /**
76 | * Sets the attribute that will be used as attribute set's class attribute
77 | *
78 | * @param attrName
79 | */
80 | public void setClass(String attrName)
81 | {
82 | if (this.containsAttrWithName(attrName))
83 | {
84 | classAttribute = attrName;
85 | }
86 | else
87 | {
88 | throw new RuntimeException("Trying to set an invalid attribute, " +
89 | attrName + " as the class attribute.");
90 | }
91 | }
92 |
93 | /**
94 | * @return the name of the attribute denoted as the "class" or "concept"
95 | * attribute
96 | */
97 | public String getClassAttrName()
98 | {
99 | return classAttribute;
100 | }
101 |
102 | /**
103 | * Determines whether a given attribute name is in the attribute set.
104 | *
105 | * @param attrName The name of the attribute for which we are checking is
106 | * valid
107 | */
108 | public Boolean containsAttrWithName(String attrName)
109 | {
110 | return nameAttrMap.containsKey(attrName);
111 | }
112 |
113 | }
114 |
--------------------------------------------------------------------------------
/machine-learning/src/data/Instance.java:
--------------------------------------------------------------------------------
1 | package data;
2 |
3 | import java.util.HashMap;
4 | import java.util.Map;
5 | import java.util.Map.Entry;
6 |
7 |
8 | /**
9 | * Represents an instance.
10 | *
11 | */
12 | public class Instance
13 | {
14 | /**
15 | * This instance's attribute value. The map maps an attribute ID
16 | * to a valid value for that attribute
17 | */
18 | private final Map attributesToValues;
19 |
20 | /**
21 | * Constructor
22 | */
23 | public Instance()
24 | {
25 | attributesToValues = new HashMap<>();
26 | }
27 |
28 |
29 | /**
30 | * Get the value for an attribute.
31 | *
32 | * @param attr the specified attribute
33 | * @return this instance's value of the specified attribute
34 | */
35 | public Double getAttributeValue(Attribute attr)
36 | {
37 | return attributesToValues.get(attr);
38 | }
39 |
40 | /**
41 | * Add an attribute-value pair to the instance.
42 | *
43 | * @param attrId the attribute ID of the attribute being added
44 | * @param value the value of the corresponding attribute
45 | */
46 | public void addAttributeValue(Attribute attr, Double value)
47 | {
48 | attributesToValues.put(attr, value);
49 | }
50 |
51 | /**
52 | * Checks if this Instance is equal to another Instance.
53 | *
54 | * @param o the other Instance
55 | * @return
56 | */
57 | @Override
58 | public boolean equals(Object o)
59 | {
60 | Map other = ((Instance)o).attributesToValues;
61 | for(Entry attr: attributesToValues.entrySet())
62 | {
63 | if(!other.get(attr.getKey()).equals(attr.getValue()))
64 | {
65 | return false;
66 | }
67 | }
68 | return true;
69 | }
70 |
71 | public String toString()
72 | {
73 | String result = "";
74 | for(Entry entry: attributesToValues.entrySet())
75 | {
76 | if (entry.getKey().getType() == Attribute.Type.NOMINAL)
77 | {
78 | result += "(" + entry.getKey().getName() + " = " + entry.getKey().getNominalValueName(entry.getValue().intValue()) + ") ";
79 | }
80 | else
81 | {
82 | result += "(" + entry.getKey().getName() + " = " + entry.getValue() + ") ";
83 | }
84 | }
85 | return result;
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/machine-learning/src/data/InstanceSet.java:
--------------------------------------------------------------------------------
1 | package data;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | /**
7 | * This class represents a set of {@code Instance} objects.
8 | *
9 | * @author Matthew Bernstein - matthewb@cs.wisc.edu
10 | *
11 | */
12 | public class InstanceSet
13 | {
14 | /**
15 | * All instances in this instance set
16 | */
17 | private final List instances;
18 |
19 | /**
20 | * Constructor
21 | */
22 | public InstanceSet()
23 | {
24 | instances = new ArrayList();
25 | }
26 |
27 | /**
28 | * @return a list of all instances in this instance set
29 | */
30 | public List getInstances()
31 | {
32 | return instances;
33 | }
34 |
35 | /**
36 | * Add an instance to this instance set
37 | *
38 | * @param newInstance the new instance
39 | */
40 | public void addInstance(Instance newInstance)
41 | {
42 | instances.add(newInstance);
43 | }
44 |
45 | /**
46 | * @param id the unique ID of a specific instance
47 | * @return the instance with specified ID
48 | */
49 | public Instance getInstanceById(int id)
50 | {
51 | return instances.get(id);
52 | }
53 |
54 | public String toString(){
55 | return instances.toString();
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
/machine-learning/src/data/fold/KFoldCreator.java:
--------------------------------------------------------------------------------
1 | package data.fold;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collections;
5 | import java.util.List;
6 |
7 | import data.DataSet;
8 | import data.Instance;
9 | import data.InstanceSet;
10 | import data.reader.ArffReader;
11 |
12 | import pair.Pair;
13 |
14 | public class KFoldCreator {
15 |
16 | public static void main(String[] args){
17 | ArffReader ar = new ArffReader();
18 | DataSet data = ar.readFile("data/kfold_test.arff");
19 | data.setClassAttribute("digit");
20 | List> pairs = KFoldCreator.create(data, 5);
21 | for(Pair pair: pairs){
22 | System.out.println("TRAIN:");
23 | System.out.println(pair.getFirst());
24 | System.out.println("\nTEST:");
25 | System.out.println(pair.getSecond());
26 | System.out.println("\n\n");
27 | }
28 | }
29 |
30 | public static List> create(DataSet data, int K) {
31 |
32 | InstanceSet is = data.getInstanceSet();
33 |
34 | List instances = is.getInstances();
35 | Collections.shuffle(instances);
36 |
37 |
38 | int numPerSplice = instances.size()/K;
39 | List> pairs = new ArrayList<>();
40 |
41 | //For each fold
42 | for(int i = 0; i < K; i++){
43 | //fill training and testing lists
44 | List trainInst = new ArrayList<>();
45 | List testInst = new ArrayList<>();
46 | int left = i * numPerSplice;
47 | int right = i * numPerSplice + numPerSplice;
48 | trainInst.addAll(instances.subList(0, left));
49 | if(i + 1 == K){
50 | testInst.addAll(instances.subList(left, instances.size()));
51 | } else {
52 | testInst.addAll(instances.subList(left, right));
53 | trainInst.addAll(instances.subList(right, instances.size()));
54 | }
55 |
56 | //put in sets
57 | InstanceSet trainSet = new InstanceSet();
58 | for(Instance inst: trainInst){
59 | trainSet.addInstance(inst);
60 | }
61 |
62 | InstanceSet testSet = new InstanceSet();
63 | for(Instance inst: testInst){
64 | testSet.addInstance(inst);
65 | }
66 | //create data sets
67 | DataSet trainData = new DataSet(data.getAttributeSet(), trainSet);
68 | DataSet testData = new DataSet(data.getAttributeSet(),testSet);
69 |
70 | //Add to list of pairs
71 | pairs.add(new Pair(trainData, testData));
72 |
73 | }
74 | return pairs;
75 | }
76 |
77 | }
78 |
--------------------------------------------------------------------------------
/machine-learning/src/distributions/Distribution.java:
--------------------------------------------------------------------------------
1 | package distributions;
2 |
/**
 * A distribution from which numeric samples can be drawn.
 */
public interface Distribution
{
    /**
     * @return a single sample drawn from this distribution
     */
    double sample();

    /**
     * @param numSamples the number of samples requested
     * @return the requested samples
     */
    double[] sampleMany(int numSamples);
}
9 |
--------------------------------------------------------------------------------
/machine-learning/src/distributions/GeometricDistribution.java:
--------------------------------------------------------------------------------
1 | package distributions;
2 |
3 | import java.util.Random;
4 |
5 | public class GeometricDistribution implements Distribution
6 | {
7 | private final Random RNG;
8 | private final double PARAMETER;
9 |
10 | public GeometricDistribution(double parameter)
11 | {
12 | this.PARAMETER = parameter;
13 | this.RNG = new Random();
14 | }
15 |
16 | public double sample()
17 | {
18 | int result = 1;
19 | while (true)
20 | {
21 | if (RNG.nextDouble() < PARAMETER)
22 | {
23 | return result;
24 | }
25 | else
26 | {
27 | result++;
28 | }
29 | }
30 | }
31 |
32 | public double[] sampleMany(int numSamples)
33 | {
34 | double[] samples = new double[numSamples];
35 | for (int i = 0; i < numSamples; i++)
36 | {
37 | samples[i] = sample();
38 | }
39 | return samples;
40 | }
41 | }
42 |
--------------------------------------------------------------------------------
/machine-learning/src/graph/Path.java:
--------------------------------------------------------------------------------
1 | package graph;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
/**
 * A path through a graph: an ordered list of nodes together with the
 * accumulated length of the edges between them.
 *
 * @param <T> the type of the nodes on the path
 */
public class Path<T>
{
    /**
     * The first node on the path
     */
    private T origin;

    /**
     * The final node on the path
     */
    private T destination;

    /**
     * The path's distance
     */
    private double length;

    /**
     * List of nodes representing the path
     */
    private List<T> path;

    /**
     * Creates a path that so far consists of the origin only and has
     * length 0. The destination is only recorded; it is not added to the
     * node list and is not checked against the nodes appended later.
     *
     * @param origin the first node on the path
     * @param destinationNode the node the path is meant to end at
     */
    public Path(T origin, T destinationNode)
    {
        this.origin = origin;
        this.destination = destinationNode;
        this.path = new ArrayList<>();
        this.path.add(this.origin);
    }

    /**
     * Extends the path by one node.
     *
     * @param nextNode the node appended to the end of the path
     * @param edgeLength the length of the edge leading to nextNode; it is
     * added to the path length and must not be null
     */
    public void appendNodeToPath(T nextNode, Double edgeLength)
    {
        length += edgeLength;
        path.add(nextNode);
    }

    /**
     * @return the nodes on the path, origin first (the internal list, not
     * a copy)
     */
    public List<T> getNodesOnPath()
    {
        return this.path;
    }

    /**
     * @return the first node on the path
     */
    public T getPathOrigin()
    {
        return this.origin;
    }

    /**
     * @return the destination given at construction time
     */
    public T getPathDestination()
    {
        return this.destination;
    }

    /**
     * @return the sum of the lengths of all appended edges
     */
    public double getPathLength()
    {
        return this.length;
    }

    /**
     * Renders the path as {@code a --> b --> c : length}.
     *
     * @return the textual form of this path
     */
    @Override
    public String toString()
    {
        StringBuilder str = new StringBuilder();

        // The node list always holds at least the origin, so the last
        // index is valid.
        int last = path.size() - 1;
        for (int i = 0; i < last; i++)
        {
            str.append(path.get(i)).append(" --> ");
        }
        str.append(path.get(last));
        str.append(" : ").append(this.length);
        return str.toString();
    }
}
75 |
76 |
--------------------------------------------------------------------------------
/machine-learning/src/graph/dag/DetectCycles.java:
--------------------------------------------------------------------------------
1 | package graph.dag;
2 |
3 | import java.util.ArrayList;
4 |
/**
 * Detects cycles in a directed graph given as an adjacency matrix.
 */
public class DetectCycles
{
    /**
     * Detects if there is a cycle in the graph.
     *
     * The check repeatedly removes vertices that have no incoming edge
     * from a vertex that is still present. If every vertex can be removed
     * this way the graph is acyclic; any vertex left over lies on a cycle
     * or is reachable only through one. A self-loop counts as a cycle.
     *
     * The graph passed in is only read; it is never modified.
     *
     * @param graph the adjacency matrix representing the graph, where
     * graph[r][c] is the edge from vertex r to vertex c. Non-edges are
     * represented by the null reference. The matrix is expected to be
     * square.
     * @return true if a cycle has been detected, false otherwise
     */
    public static Boolean run(Double[][] graph)
    {
        int vertexCount = graph.length;

        /*
         * Count the parents of every vertex. Working on these counts,
         * instead of cutting edges out of the matrix, leaves the caller's
         * graph intact.
         */
        int[] parentCount = new int[vertexCount];
        for (int r = 0; r < vertexCount; r++)
        {
            for (int c = 0; c < vertexCount; c++)
            {
                if (graph[r][c] != null)
                {
                    parentCount[c]++;
                }
            }
        }

        /*
         * Work list of vertices without remaining parents. Every vertex is
         * queued at most once, so an array of vertexCount slots suffices.
         */
        int[] ready = new int[vertexCount];
        int head = 0;
        int tail = 0;
        for (int v = 0; v < vertexCount; v++)
        {
            if (parentCount[v] == 0)
            {
                ready[tail++] = v;
            }
        }

        /*
         * Remove parentless vertices one by one; removing a vertex takes
         * one parent away from each of its children.
         */
        while (head < tail)
        {
            int vertex = ready[head++];
            for (int c = 0; c < vertexCount; c++)
            {
                if (graph[vertex][c] != null && --parentCount[c] == 0)
                {
                    ready[tail++] = c;
                }
            }
        }

        /*
         * tail is the number of vertices that could be removed. Anything
         * left over means a cycle.
         */
        return tail != vertexCount;
    }

}
116 |
--------------------------------------------------------------------------------
/machine-learning/src/graph/dag/TopologicalSort.java:
--------------------------------------------------------------------------------
1 | package graph.dag;
2 |
3 | import java.util.ArrayList;
4 |
/**
 * Topologically sorts the vertices in a Directed Acyclic Graph (DAG).
 * <p>
 * The sort works by repeatedly removing vertices that have no incoming
 * edges. The adjacency matrix passed to {@link #run(Double[][])} is
 * modified in place: outgoing edges of every sorted vertex are set to
 * {@code null}.
 *
 * @author Matthew Bernstein - matthewb@cs.wisc.edu
 *
 */
public class TopologicalSort
{
    /**
     * Run the topological sort
     *
     * @param graph the adjacency matrix representing the DAG. Non-edges are
     * represented by the null reference. The matrix is modified in place.
     * @return a sorted list of vertices (row indices), parents before
     * children
     * @throws IllegalArgumentException if the graph contains a cycle, in
     * which case no topological order exists
     */
    public static ArrayList<Integer> run(Double[][] graph)
    {
        /*
         * All unsorted vertices
         */
        ArrayList<Integer> toProcess = getAllVertices(graph);

        /*
         * Sorted vertices
         */
        ArrayList<Integer> sorted = new ArrayList<Integer>(toProcess.size());

        while (!toProcess.isEmpty())
        {
            ArrayList<Integer> released = runIteration(graph, toProcess);

            /*
             * If a pass releases nothing, every remaining vertex still has
             * a parent, i.e. they sit on a cycle. Without this check the
             * loop would never terminate.
             */
            if (released.isEmpty())
            {
                throw new IllegalArgumentException(
                        "Graph contains a cycle; cannot topologically sort "
                        + "vertices " + toProcess);
            }

            sorted.addAll(released);
        }

        return sorted;
    }

    /**
     * Run a single iteration of the topological sort
     *
     * @param graph the graph, modified in place
     * @param toProcess a list of vertices that have yet to be added to the
     * list of sorted vertices; released vertices are removed from it
     * @return a list of vertices from the list of unsorted vertices that must
     * be appended to the list of sorted vertices
     */
    private static ArrayList<Integer> runIteration(Double[][] graph,
                                                   ArrayList<Integer> toProcess)
    {
        ArrayList<Integer> newSorted = new ArrayList<Integer>();

        int i = 0;
        while (i < toProcess.size())
        {
            int vertex = toProcess.get(i);

            if (hasParent(graph, vertex))
            {
                i++;
            }
            else
            {
                /*
                 * Removing slot i shifts the next vertex into slot i, so
                 * the index stays put; advancing would skip that vertex.
                 */
                cutVertex(graph, vertex);
                toProcess.remove(i);
                newSorted.add(vertex);
            }
        }

        return newSorted;
    }

    /**
     * Determine whether any edge still points into a vertex.
     *
     * @param graph the graph
     * @param vertex the vertex to inspect
     * @return true if some row holds a non-null entry in the vertex's column
     */
    private static boolean hasParent(Double[][] graph, int vertex)
    {
        for (int r = 0; r < graph.length; r++)
        {
            if (graph[r][vertex] != null)
            {
                return true;
            }
        }
        return false;
    }

    /**
     * Cut a vertex from the graph. That is, cut all outgoing edges from this
     * vertex.
     *
     * @param graph the graph
     * @param vertex the vertex to cut from the graph
     */
    private static void cutVertex(Double[][] graph, Integer vertex)
    {
        for (int c = 0; c < graph.length; c++)
        {
            graph[vertex][c] = null;
        }
    }

    /**
     * Get all vertices in a graph
     *
     * @param graph the graph
     * @return all vertices in the graph, as row indices in ascending order
     */
    private static ArrayList<Integer> getAllVertices(Double[][] graph)
    {
        ArrayList<Integer> allNodes = new ArrayList<Integer>(graph.length);

        for (int r = 0; r < graph.length; r++)
        {
            allNodes.add(r);
        }

        return allNodes;
    }
}
128 |
--------------------------------------------------------------------------------
/machine-learning/src/graph/floydwarshall/AllPairsShortestPaths.java:
--------------------------------------------------------------------------------
1 | package graph.floydwarshall;
2 |
3 | import graph.Path;
4 |
5 | import java.util.Map;
6 | import java.util.Map.Entry;
7 |
8 | import bimap.BiMap;
9 |
10 | public class AllPairsShortestPaths
11 | {
12 | private final BiMap indexToNode;
13 |
14 | private final Double[][] distanceMatrix;
15 |
16 | private final Integer[][] nextNodeMatrix;
17 |
18 | public AllPairsShortestPaths(Double[][] distanceMatrix, Integer[][] nextNodeMatrix, Map indexToNode)
19 | {
20 | this.indexToNode = new BiMap<>();
21 | for (Entry e : indexToNode.entrySet())
22 | {
23 | this.indexToNode.put(e.getKey(), e.getValue());
24 | }
25 |
26 | this.distanceMatrix = distanceMatrix;
27 | this.nextNodeMatrix = nextNodeMatrix;
28 | }
29 |
30 | public Path getPath(T origin, T destination)
31 | {
32 | if (!pathExists(origin, destination))
33 | {
34 | return null;
35 | }
36 |
37 | Path path = new Path<>(origin, destination);
38 |
39 | Integer currNodeIndex = indexToNode.getKey(origin);
40 | Integer destNodeIndex = indexToNode.getKey(destination);
41 |
42 | while (!currNodeIndex.equals(destNodeIndex))
43 | {
44 | Integer nextNodeIndex = nextNodeMatrix[currNodeIndex][destNodeIndex];
45 |
46 | Double currDistance = distanceMatrix[currNodeIndex][destNodeIndex];
47 | Double nextNodeDistance = distanceMatrix[nextNodeIndex][destNodeIndex];
48 | Double edgeWeight = currDistance - nextNodeDistance;
49 |
50 | path.appendNodeToPath(indexToNode.getValue(nextNodeIndex), edgeWeight); // TODO CALCULATE THE ACTUAL DISTANCE
51 | currNodeIndex = nextNodeIndex;
52 | }
53 |
54 | return path;
55 | }
56 |
57 | public boolean pathExists(T origin, T destination)
58 | {
59 | if (nextNodeMatrix[indexToNode.getKey(origin)][indexToNode.getKey(destination)] == null)
60 | {
61 | return false;
62 | }
63 | return true;
64 | }
65 |
66 | }
67 |
--------------------------------------------------------------------------------
/machine-learning/src/graph/floydwarshall/FloydWarshall.java:
--------------------------------------------------------------------------------
1 | package graph.floydwarshall;
2 |
3 | import graph.DirectedGraph;
4 | import graph.Path;
5 |
6 | import java.util.HashMap;
7 | import java.util.Map;
8 |
9 | import pair.Pair;
10 |
11 | /**
12 | * An implementation of the Floyd-Warshall algorithm for computing all-pairs
13 | * shortest paths on a directed graph with no negative cycles.
14 | *
15 | */
16 | public class FloydWarshall
17 | {
18 | public static AllPairsShortestPaths runFloydWarshall(DirectedGraph graph)
19 | {
20 | Map indexToNode = mapNodesToIndices(graph);
21 |
22 | Pair matrices = initializeMatrices(graph, indexToNode);
23 |
24 | Double[][] distanceMatrix = matrices.getFirst();
25 | Integer[][] nextNodeMatrix = matrices.getSecond();
26 |
27 | matrices = computeShortestPaths(distanceMatrix, nextNodeMatrix, graph, indexToNode);
28 |
29 | return new AllPairsShortestPaths<>(matrices.getFirst(), matrices.getSecond(), indexToNode);
30 | }
31 |
32 | private static Map mapNodesToIndices(DirectedGraph graph)
33 | {
34 | Map indexToNode = new HashMap<>();
35 | int nodeIndex = 0;
36 | for (T node : graph.getNodes())
37 | {
38 | indexToNode.put(new Integer(nodeIndex++), node);
39 | }
40 | return indexToNode;
41 | }
42 |
43 | /**
44 | * Initialize the matrices for storing shortest distances between nodes and the matrix
45 | * storing the next node from each node along that shortest path.
46 | *
47 | * @param graph the input graph
48 | * @param indexToNode map of the node index to the node object
49 | * @return the distance matrix and next node matrix
50 | */
51 | private static Pair initializeMatrices(DirectedGraph graph,
52 | Map indexToNode)
53 | {
54 | int numNodes = graph.getNodes().size();
55 | Double[][] distanceMatrix = new Double[numNodes][numNodes];
56 | Integer[][] nextNodeMatrix = new Integer[numNodes][numNodes];
57 |
58 | for (int origIndex = 0; origIndex < numNodes; origIndex++)
59 | {
60 | for (int destIndex = 0; destIndex < numNodes; destIndex++)
61 | {
62 | T origin = indexToNode.get(origIndex);
63 | T destination = indexToNode.get(destIndex);
64 |
65 | if (origIndex == destIndex)
66 | {
67 | distanceMatrix[origIndex][destIndex] = 0d;
68 | nextNodeMatrix[origIndex][destIndex] = origIndex;
69 | }
70 | else if (graph.edgeExists(origin, destination))
71 | {
72 | distanceMatrix[origIndex][destIndex] = graph.getEdgeWeight(origin, destination);
73 | nextNodeMatrix[origIndex][destIndex] = destIndex;
74 | }
75 | else
76 | {
77 | distanceMatrix[origIndex][destIndex] = Double.POSITIVE_INFINITY;
78 | nextNodeMatrix[origIndex][destIndex] = null;
79 | }
80 | }
81 | }
82 |
83 | return new Pair<>(distanceMatrix, nextNodeMatrix);
84 | }
85 |
86 | private static Pair computeShortestPaths(Double[][] distanceMatrix,
87 | Integer[][] nextNodeMatrix,
88 | DirectedGraph graph,
89 | Map indexToNode)
90 | {
91 | int numNodes = graph.getNodes().size();
92 | for (int k = 0; k < numNodes; k++)
93 | {
94 | for (int i = 0; i < numNodes; i++)
95 | {
96 | for (int j = 0; j < numNodes; j++)
97 | {
98 | if (distanceMatrix[i][j] > distanceMatrix[i][k] + distanceMatrix[k][j])
99 | {
100 | distanceMatrix[i][j] = distanceMatrix[i][k] + distanceMatrix[k][j];
101 | nextNodeMatrix[i][j] = nextNodeMatrix[i][k];
102 | }
103 | }
104 | }
105 | }
106 | return new Pair<>(distanceMatrix, nextNodeMatrix);
107 | }
108 |
109 | }
110 |
--------------------------------------------------------------------------------
/machine-learning/src/graph/prim/Edge.java:
--------------------------------------------------------------------------------
1 | package graph.prim;
2 |
3 | import java.util.Comparator;
4 |
5 | import pair.Pair;
6 |
7 | /**
8 | * Implements a directed edge in a weighted graph where edge weights are
9 | * floating point values and nodes are represented by unique integers.
10 | *
11 | */
12 | public class Edge
13 | {
14 | /**
15 | * The two vertices that constitute this edge. We consider the edge
16 | * points:
17 | *
18 | *
19 | * first -> second
20 | */
21 | Pair vertices;
22 |
23 | /**
24 | * The edge weight
25 | */
26 | Double weight;
27 |
28 | /**
29 | * Compares edges by weight
30 | */
31 | public static final Comparator EDGE_ORDER =
32 | new Comparator()
33 | {
34 | public int compare(Edge e1, Edge e2)
35 | {
36 | if (e1.weight == e2.weight)
37 | {
38 | return 0;
39 | }
40 | else if (e1.weight < e2.weight)
41 | {
42 | return 1;
43 | }
44 | else
45 | {
46 | return -1;
47 | }
48 | }
49 | };
50 |
51 | /**
52 | * Constructor
53 | *
54 | * @param vertices the two vertices in the edge
55 | * @param weight the weight of the edge
56 | */
57 | public Edge(Pair vertices, Double weight)
58 | {
59 | this.vertices = vertices;
60 | this.weight = weight;
61 | }
62 |
63 | public Edge(Integer origin, Integer destination, Double weight)
64 | {
65 | this.vertices = new Pair(origin, destination);
66 | }
67 |
68 | /**
69 | * @return the edge weight
70 | */
71 | public Double getWeight()
72 | {
73 | return this.weight;
74 | }
75 |
76 | /**
77 | * @return the pair of vertices
78 | */
79 | public Pair getVertices()
80 | {
81 | return this.vertices;
82 | }
83 |
84 | /**
85 | * @return the first vertex
86 | */
87 | public Integer getFirstVertex()
88 | {
89 | return this.vertices.getFirst();
90 | }
91 |
92 | /**
93 | * @return the second vertex
94 | */
95 | public Integer getSecondVertex()
96 | {
97 | return this.vertices.getSecond();
98 | }
99 |
100 | @Override
101 | public String toString()
102 | {
103 | String result = "<";
104 | result += this.vertices.getFirst();
105 | result += "--";
106 | result += this.weight;
107 | result += "--";
108 | result += this.vertices.getSecond();
109 | result += ">";
110 |
111 | return result;
112 | }
113 | }
114 |
115 |
--------------------------------------------------------------------------------
/machine-learning/src/hmm/StateContainer.java:
--------------------------------------------------------------------------------
1 | package hmm;
2 |
3 | import java.util.ArrayList;
4 | import java.util.Collection;
5 | import java.util.HashMap;
6 | import java.util.Map;
7 |
8 | /**
9 | * This class is used to store the state objects that comprise the Markov
10 | * model.
11 | */
12 | public class StateContainer
13 | {
14 | /**
15 | * Maps a state ID to a state object
16 | */
17 | Map states;
18 |
19 | /**
20 | * Constructor.
21 | */
22 | public StateContainer()
23 | {
24 | states = new HashMap();
25 | }
26 |
27 | /**
28 | * @return the array list of all of the states in the model
29 | */
30 | public Collection getStates()
31 | {
32 | return states.values();
33 | }
34 |
35 | /*
36 | * TODO OPTIMIZE THIS!!!!
37 | */
38 | public Collection getSilentStates()
39 | {
40 | ArrayList silent = new ArrayList();
41 | for (State s : states.values())
42 | {
43 | if (s.isSilent)
44 | silent.add(s);
45 | }
46 | return silent;
47 | }
48 |
49 | /**
50 | * Add a state to this container.
51 | *
52 | * @param newState the new state
53 | */
54 | public void addState(State newState)
55 | {
56 | states.put(newState.getId(), newState);
57 | }
58 |
59 | /**
60 | * Retrieve a state with a specified unique ID from this state container.
61 | *
62 | * @param id the unique integer of ID of the state to be retrieved
63 | * @return The state that stores the specified ID. If a state with the
64 | * specified ID does not exist in this container, this method returns
65 | * null.
66 | */
67 | public State getStateById(String id)
68 | {
69 | return states.get(id);
70 | }
71 |
72 | /**
73 | * Determines whether this state container contains a state with the
74 | * specified ID
75 | *
76 | * @param id the ID of the state we are querying for
77 | * @return true if this StateContainer has a state with specified ID,
78 | * false otherwise
79 | */
80 | public boolean containsState(String id)
81 | {
82 | if (states.containsKey(id))
83 | {
84 | return true;
85 | }
86 | else
87 | {
88 | return false;
89 | }
90 | }
91 | }
92 |
--------------------------------------------------------------------------------
/machine-learning/src/hmm/StateParamsTied.java:
--------------------------------------------------------------------------------
1 | package hmm;
2 |
3 | import java.util.HashMap;
4 | import java.util.Map;
5 | import java.util.Map.Entry;
6 |
7 | import math.LogP;
8 |
9 |
10 | /**
11 | * Implements a state whose emission probability distribution is tied to that
12 | * of another state.
13 | *
14 | * @author Matthew Bernstein - matthewb@cs.wisc.edu
15 | *
16 | */
17 | public class StateParamsTied extends State
18 | {
19 | /**
20 | * Stores all emission probability distributions for all states with
21 | * tied emission parameters
22 | */
23 | public static Map> tiedEmissionParams
24 | = new HashMap>();
25 |
26 | /**
27 | * The ID of the parameters that this State uses
28 | */
29 | private String paramsKey;
30 |
31 | /**
32 | * Constructor.
33 | *
34 | * @param paramsId the ID of the parameters that this State uses
35 | */
36 | public StateParamsTied(String paramsKey, String id)
37 | {
38 | super(id);
39 | this.paramsKey = paramsKey;
40 | this.initializeParams();
41 | }
42 |
43 | public StateParamsTied(State orig, String paramsKey)
44 | {
45 | super(orig);
46 | this.paramsKey = paramsKey;
47 | this.initializeParams();
48 | }
49 |
50 | public StateParamsTied(StateParamsTied orig)
51 | {
52 | super(orig);
53 | this.paramsKey = orig.paramsKey;
54 | }
55 |
56 | /**
57 | * @return the emission probabilities from this state
58 | */
59 | @Override
60 | public Map getEmissionProbabilites()
61 | {
62 | return tiedEmissionParams.get(this.paramsKey);
63 | }
64 |
65 | /**
66 | * Add an emission probability to this State
67 | *
68 | * @param symbol the symbol this state will emmit
69 | * @param probability the probability the state will emmit this symbol
70 | */
71 | @Override
72 | public void addEmission(String symbol, Double probability)
73 | {
74 | tiedEmissionParams.get(this.paramsKey).put(symbol, probability);
75 | }
76 |
77 | /**
78 | * Get the emission probability of a specific symbol from this state
79 | *
80 | * @param symbol the symbol of interest
81 | * @return the emission probability
82 | */
83 | public double getEmissionProb(String symbol)
84 | {
85 | if (tiedEmissionParams.get(this.paramsKey).containsKey(symbol))
86 | {
87 | return tiedEmissionParams.get(this.paramsKey).get(symbol);
88 | }
89 | else
90 | {
91 | return LogP.ln(0.0);
92 | }
93 | }
94 |
95 | public String getParamsKey()
96 | {
97 | return this.paramsKey;
98 | }
99 |
100 | private void initializeParams()
101 | {
102 | if (!StateParamsTied.tiedEmissionParams.containsKey(this.paramsKey))
103 | {
104 | StateParamsTied.tiedEmissionParams.put(this.paramsKey,
105 | new HashMap());
106 |
107 | // TODO INITIALIZE PARAMETER PROBABILITIES IN MAP
108 | }
109 | }
110 |
111 | @Override
112 | public String toString()
113 | {
114 | String result = "";
115 | result += "[";
116 | result += this.id;
117 | result += "]";
118 | result += "\n";
119 |
120 | result += "............\n";
121 |
122 | for (Entry e : transitions.entrySet())
123 | {
124 | String destStateId = e.getKey();
125 | result += (LogP.exp(e.getValue().getTransitionProbability()) +
126 | " --> ");
127 | result += ("[" + destStateId + "]");
128 | result += "\n";
129 | }
130 |
131 | result += "............\n";
132 |
133 | for (Entry entry :
134 | StateParamsTied.tiedEmissionParams.get(this.paramsKey).entrySet())
135 | {
136 | result += (entry.getKey() + " >> " + LogP.exp(entry.getValue()) + "\n");
137 | }
138 |
139 | result += "............\n";
140 | result += "\n";
141 |
142 | return result;
143 | }
144 | }
145 |
--------------------------------------------------------------------------------
/machine-learning/src/hmm/StateSilent.java:
--------------------------------------------------------------------------------
1 | package hmm;
2 |
3 | import java.util.Map;
4 | import java.util.Map.Entry;
5 |
6 | import math.LogP;
7 |
8 |
9 |
10 | /**
11 | * Implements a state that does not emit any symbols.
12 | *
13 | * @author Matthew Bernstein - matthewb@cs.wisc.edu
14 | *
15 | */
16 | public class StateSilent extends State
17 | {
18 | /**
19 | * Constructor
20 | */
21 | public StateSilent()
22 | {
23 | super();
24 | this.isSilent = true;
25 | }
26 |
27 | public StateSilent(String id)
28 | {
29 | super(id);
30 | this.isSilent = true;
31 | }
32 |
33 | @Override
34 | public Map getEmissionProbabilites()
35 | {
36 | System.err.println("Attempting to retrieve emission probabilities on " +
37 | "emission probabilities for silent state " +
38 | this.id);
39 | return null;
40 | }
41 |
42 | @Override
43 | public void addEmission(String symbol, Double probability)
44 | {
45 | System.err.println("Attempting to add emission probabilities to " +
46 | " silent state " + this.id);
47 | }
48 |
49 | @Override
50 | public double getEmissionProb(String symbol)
51 | {
52 | // TODO CHECK IF THIS IS CORRECT
53 | return LogP.ln(0.0);
54 | }
55 |
56 | @Override
57 | public String toString()
58 | {
59 | String result = "";
60 | result += "[";
61 | result += this.id;
62 | result += "]";
63 | result += "\n";
64 |
65 | result += "............\n";
66 |
67 | for (Entry e : transitions.entrySet())
68 | {
69 | String destStateId = e.getKey();
70 | result += LogP.exp(e.getValue().getTransitionProbability()) +
71 | " --> ";
72 | result += ("[" + destStateId + "]");
73 | result += "\n";
74 | }
75 |
76 | result += "............\n";
77 |
78 | result += "silent\n";
79 |
80 | result += "............\n";
81 | result += "\n";
82 |
83 | return result;
84 | }
85 |
86 | }
87 |
--------------------------------------------------------------------------------
/machine-learning/src/hmm/Transition.java:
--------------------------------------------------------------------------------
1 | package hmm;
2 |
3 | import math.LogP;
4 |
5 | /**
6 | * This class implements a transition between two states in the Markov model:
7 | * an "origin" state and a "destination" state. That is, each Transition
8 | * object represents a transition from the origin state to the
9 | * destination state. Each Transition object stores a "count" that records the
10 | * number of times we observe in the original text that the word associated
11 | * with the destination state follows the word associate with the origin state.
12 | */
13 | public class Transition
14 | {
15 | /**
16 | * ID of the origin state
17 | */
18 | private String originId;
19 |
20 | /**
21 | * ID of the destination state
22 | */
23 | private String destinationId;
24 |
25 | /**
26 | * The transition's associated probability
27 | */
28 | private double probability;
29 |
30 | /**
31 | * Constructor
32 | *
33 | * @param originId ID of the origin state
34 | * @param destinationId ID of the destination state
35 | * @param probability probability of taking that transition
36 | */
37 | public Transition(String originId, String destinationId, double probability)
38 | {
39 | this.originId = originId;
40 | this.destinationId = destinationId;
41 | this.probability = probability;
42 | }
43 |
44 | /**
45 | * Copy constructor
46 | */
47 | public Transition(Transition t)
48 | {
49 | this.originId = t.originId;
50 | this.destinationId = t.destinationId;
51 |
52 | this.probability = t.probability;
53 | }
54 |
55 | /**
56 | * Get the integer ID of the state this transition moves to.
57 | *
58 | * @return the destination ID
59 | */
60 | public String getDestinationId()
61 | {
62 | return destinationId;
63 | }
64 |
65 | /**
66 | * Set the integer ID of the state this transition moves to.
67 | *
68 | * @param destinationId the destination ID
69 | */
70 | public void setDestinationId(String destinationId)
71 | {
72 | this.destinationId = destinationId;
73 | }
74 |
75 | /**
76 | * Get the ID of the state this transition moves from.
77 | *
78 | * @return
79 | */
80 | public String getOriginId()
81 | {
82 | return originId;
83 | }
84 |
85 | /**
86 | * Set the ID of the state this transition moves from.
87 | *
88 | * @param originId the ID of the state this transition moves from
89 | */
90 | public void setOriginId(String originId)
91 | {
92 | this.originId = originId;
93 | }
94 |
95 | /**
96 | * @param the transition probability
97 | */
98 | public void setTransitionProbability(double probability)
99 | {
100 | this.probability = probability;
101 | }
102 |
103 | /**
104 | * @return the transition probability
105 | */
106 | public double getTransitionProbability()
107 | {
108 | return this.probability;
109 | }
110 |
111 | public void incrementTransitionValue(double value)
112 | {
113 | this.probability = LogP.sum(this.probability, value);
114 | }
115 |
116 | }
117 |
--------------------------------------------------------------------------------
/machine-learning/src/hmm/algorithms/DpMatrix.java:
--------------------------------------------------------------------------------
1 | package hmm.algorithms;
2 |
3 | import hmm.HMM;
4 | import hmm.State;
5 |
6 | import java.util.ArrayList;
7 |
8 | import math.LogP;
9 |
10 |
11 |
12 | import bimap.BiMap;
13 |
14 | public class DpMatrix
15 | {
16 | private int numRows;
17 | private int numCols;
18 |
19 | /**
20 | * Maps each state to a row
21 | */
22 | private BiMap stateRowMap;
23 |
24 | /**
25 | * Map each time step (i.e. column) to a symbol
26 | */
27 | private ArrayList colSymbolMap;
28 |
29 | /**
30 | * The matrix
31 | */
32 | private DpMatrixElement[][] matrix;
33 |
34 | public DpMatrix(HMM model, String[] sequence)
35 | {
36 | numRows = model.getNumStates();
37 |
38 | numCols = sequence.length + 1;
39 |
40 | initStateRowMap(model);
41 | initColSymbolMap(sequence);
42 | initMatrix();
43 | }
44 |
45 | public void initStateRowMap(HMM model)
46 | {
47 | stateRowMap = new BiMap();
48 |
49 | ArrayList states = new ArrayList(model.getStateContainer()
50 | .getStates());
51 |
52 | for (int index = 0; index < states.size(); index++)
53 | {
54 | stateRowMap.put(states.get(index), index);
55 | }
56 | }
57 |
58 | public void initColSymbolMap(String[] sequence)
59 | {
60 | colSymbolMap = new ArrayList();
61 |
62 | /*
63 | * The first time unit does not see a symbol emitted
64 | */
65 | colSymbolMap.add("-");
66 |
67 | for (int i = 0; i < sequence.length; i++)
68 | {
69 | colSymbolMap.add(sequence[i]);
70 | }
71 | }
72 |
73 | public double getValue(State state, int timeUnit)
74 | {
75 | int row = stateRowMap.getValue(state);
76 | return matrix[row][timeUnit].getValue();
77 | }
78 |
79 | public void setValue(State state, int timeUnit, double value)
80 | {
81 | int row = stateRowMap.getValue(state);
82 | matrix[row][timeUnit].setValue(value);
83 | }
84 |
85 | public void setPreviousState(State currState, int timeUnit, State prevState)
86 | {
87 | /*
88 | * The current state's row
89 | */
90 | int currRow = stateRowMap.getValue(currState);
91 |
92 | /*
93 | * The previous state's row
94 | */
95 | int backRow = stateRowMap.getValue(prevState);
96 |
97 | matrix[currRow][timeUnit].setRowPointer(backRow);
98 | }
99 |
100 | public State getPreviousState(State currState, int timeUnit)
101 | {
102 | /*
103 | * The current state's row
104 | */
105 | int currRow = stateRowMap.getValue(currState);
106 |
107 | /*
108 | * Get the row of the previous state from the current state
109 | */
110 | int rowPointer = matrix[currRow][timeUnit].getRowPointer();
111 |
112 | /*
113 | * Return the state associated with this row
114 | */
115 | return stateRowMap.getKey(rowPointer);
116 | }
117 |
118 | public void initMatrix()
119 | {
120 | matrix = new DpMatrixElement[numRows][numCols];
121 |
122 | /*
123 | * Initialize elements
124 | */
125 | for (int r = 0; r < numRows; r++)
126 | {
127 | for (int c = 0; c < numCols; c++)
128 | {
129 | matrix[r][c] = new DpMatrixElement();
130 | }
131 | }
132 | }
133 |
134 | /**
135 | * Print the score of each element in the dynamic
136 | * programming matrix to standard output.
137 | */
138 | @Override
139 | public String toString()
140 | {
141 | String result = "";
142 |
143 | result += "\nDynamic Programming Matrix:\n\n";
144 |
145 | result += "\t\t";
146 |
147 | // Print the character over each columns
148 | for (String c : colSymbolMap)
149 | {
150 | result += (c + "\t");
151 | }
152 | result += "\n";
153 |
154 | for (int r = 0; r < numRows; r++)
155 | {
156 | result += ("[" + stateRowMap.getKey(r).getId() + "]\t");
157 |
158 | for (int c = 0; c < numCols; c++)
159 | {
160 | result += (LogP.exp(matrix[r][c].getValue()) + "\t");
161 | }
162 | result += "\n";
163 | }
164 | result += "\n";
165 |
166 | return result;
167 | }
168 |
169 | public int getNumColumns()
170 | {
171 | return this.numCols;
172 | }
173 |
174 | }
175 |
--------------------------------------------------------------------------------
/machine-learning/src/hmm/algorithms/DpMatrixElement.java:
--------------------------------------------------------------------------------
1 | package hmm.algorithms;
2 |
/**
 * One cell of a dynamic programming matrix: a value plus a back pointer
 * (row and column) to the cell that determined it.
 */
public class DpMatrixElement
{
    /**
     * This element's value; NaN until a value is set
     */
    private double value = Double.NaN;

    /**
     * Row of the element that determined this element's value
     */
    private int rowPointer;

    /**
     * Column of the element that determined this element's value
     */
    private int columnPointer;

    /**
     * @return this element's value
     */
    public double getValue()
    {
        return value;
    }

    /**
     * @param value the new value of this element
     */
    public void setValue(double value)
    {
        this.value = value;
    }

    /**
     * @return the row half of the back pointer
     */
    public int getRowPointer()
    {
        return rowPointer;
    }

    /**
     * @param row the row half of the back pointer
     */
    public void setRowPointer(int row)
    {
        rowPointer = row;
    }

    /**
     * @return the column half of the back pointer
     */
    public int getColumnPointer()
    {
        return columnPointer;
    }

    /**
     * @param column the column half of the back pointer
     */
    public void setColumnPointer(int column)
    {
        columnPointer = column;
    }

    /**
     * Set both halves of the back pointer at once.
     *
     * @param row the row half of the back pointer
     * @param column the column half of the back pointer
     */
    public void setBackPointer(int row, int column)
    {
        rowPointer = row;
        columnPointer = column;
    }

}
59 |
--------------------------------------------------------------------------------
/machine-learning/src/hmm/algorithms/SortSilentStates.java:
--------------------------------------------------------------------------------
1 | package hmm.algorithms;
2 |
3 | import graph.dag.TopologicalSort;
4 | import hmm.HMM;
5 | import hmm.State;
6 | import hmm.Transition;
7 |
8 | import java.util.ArrayList;
9 | import java.util.Collection;
10 |
11 | import bimap.BiMap;
12 |
13 | public class SortSilentStates
14 | {
15 | public static ArrayList run(HMM model)
16 | {
17 | Collection silentStates = model.getSilentStates();
18 |
19 | ArrayList sorted = new ArrayList();
20 |
21 | BiMap indices = new BiMap();
22 |
23 | int numNodes = silentStates.size();
24 | Double[][] graph = new Double[numNodes][numNodes];
25 |
26 | /*
27 | * Initialize every element in graph to null
28 | */
29 | for (int r = 0; r < numNodes; r++)
30 | {
31 | for (int c = 0; c < numNodes; c++)
32 | {
33 | graph[r][c] = null;
34 | }
35 | }
36 |
37 | int count = 0;
38 | for (State s : silentStates)
39 | {
40 | indices.put(count++, s.getId());
41 | }
42 |
43 | for (State s : silentStates)
44 | {
45 | for (Transition t : s.getTransitions())
46 | {
47 | String dId = t.getDestinationId();
48 | String oId = s.getId();
49 | if (indices.containsValue(t.getDestinationId()))
50 | {
51 | graph[indices.getKey(oId)][indices.getKey(dId)] = 1.0;
52 | }
53 |
54 | }
55 | }
56 |
57 | ArrayList sortedIndices = TopologicalSort.run(graph);
58 |
59 | for (Integer i : sortedIndices)
60 | {
61 | sorted.add(model.getStateById(indices.getValue(i)));
62 | }
63 |
64 | return sorted;
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/machine-learning/src/math/LogP.java:
--------------------------------------------------------------------------------
1 | package math;
2 |
/**
 * Methods for dealing in log-probability space. This class includes methods
 * for converting to and from log-values as well as taking products and
 * summations over log-probabilities.
 * <p>
 * Throughout this class the log of a zero probability is represented by
 * {@code Double.NaN}.
 *
 */
public class LogP
{
    /**
     * Raises E to the power of a log-probability thereby converting the
     * log-probability to a proper probability.
     *
     * @param x the log-probability; NaN stands for the log of zero
     * @return the proper probability; 0 when x is NaN
     */
    public static double exp(double x)
    {
        if (Double.isNaN(x))
        {
            return 0;
        }
        else
        {
            return Math.exp(x);
        }
    }

    /**
     * Take the natural logarithm of a probability thereby converting it to a
     * log-probability.
     *
     * @param x a probability
     * @return the log-probability; NaN when x is exactly zero
     * @throws IllegalArgumentException if x is negative (or NaN)
     */
    public static double ln(double x)
    {
        if (x == 0.0)
        {
            return Double.NaN;
        }
        else if (x > 0.0)
        {
            return Math.log(x);
        }
        else
        {
            throw new IllegalArgumentException("Passed 'eLn' function the " +
                                               "negative value " + x + ". " +
                                               "Argument must be greater than " +
                                               "zero.");
        }
    }

    /**
     * Take the sum of two log-probabilities, i.e. compute
     * {@code ln(e^eLnX + e^eLnY)} without leaving log space.
     *
     * @param eLnX the first log-probability
     * @param eLnY the second log-probability
     * @return the sum; if one argument is NaN (log of zero) the other
     * argument is returned unchanged
     */
    public static double sum(double eLnX, double eLnY)
    {
        if (Double.isNaN(eLnX))
        {
            return eLnY;
        }
        if (Double.isNaN(eLnY))
        {
            return eLnX;
        }

        /*
         * Factor out the larger term so the exponent is never positive.
         * log1p keeps the contribution of the smaller term even when
         * 1 + e^d would round to exactly 1.
         */
        double larger = Math.max(eLnX, eLnY);
        double smaller = Math.min(eLnX, eLnY);
        return larger + Math.log1p(Math.exp(smaller - larger));
    }

    /**
     * Take the product of two log-probabilities
     *
     * @param eLnX the first log-probability
     * @param eLnY the second log-probability
     * @return the product; NaN (log of zero) if either argument is NaN
     */
    public static double prod(double eLnX, double eLnY)
    {
        if (Double.isNaN(eLnX) || Double.isNaN(eLnY))
        {
            return Double.NaN;
        }
        else
        {
            return eLnX + eLnY;
        }
    }

    /**
     * Take the quotient of two log-probabilities
     *
     * @param eLnX the dividend
     * @param eLnY the divisor
     * @return the quotient; NaN (log of zero) if the dividend is NaN
     * @throws IllegalArgumentException if the divisor is NaN, i.e. a
     * division by a zero probability
     */
    public static double div(double eLnX, double eLnY)
    {
        if (Double.isNaN(eLnY))
        {
            throw new IllegalArgumentException("Passed 'eLnDivision' " +
                    "function a NaN divisor. Argument must be real.");
        }
        else if (Double.isNaN(eLnX))
        {
            return Double.NaN;
        }
        else
        {
            return eLnX - eLnY;
        }
    }
}
132 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/DtLeaf.java:
--------------------------------------------------------------------------------
1 | package tree;
2 |
3 | import data.Attribute;
4 |
5 | /**
6 | * The leaf of a decision tree.
7 | *
8 | */
9 | public class DtLeaf extends DtNode
10 | {
11 | /**
12 | * This integer must be a valid nominal value of the class attribute for the
13 | * current learning problem
14 | */
15 | private Integer classLabel;
16 |
17 | public DtLeaf(Attribute attribute,
18 | Double value,
19 | Relation relation,
20 | Integer classLabel)
21 | {
22 | super(attribute, value, relation);
23 | this.classLabel = classLabel;
24 | }
25 |
26 | @Override
27 | public void addChild(Node child)
28 | {
29 | throw new UnsupportedOperationException("Error. Attempting to add" +
30 | " a child Node to Leaf " + super.toString());
31 | }
32 |
33 | public Integer getClassLabel()
34 | {
35 | return classLabel;
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/Forest.java:
--------------------------------------------------------------------------------
1 | package tree;
2 |
3 | import java.util.Set;
4 |
5 | import pair.Pair;
6 | import data.Instance;
7 |
/**
 * Placeholder for a collection of trees. The class currently has no
 * behavior of its own.
 */
public class Forest
{
    /**
     * The members of the forest.
     * NOTE(review): presumably decision trees; the element type is not
     * declared here - confirm before use.
     */
    private Set trees;
}
13 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/Node.java:
--------------------------------------------------------------------------------
1 | package tree;
2 |
3 | import java.util.HashSet;
4 | import java.util.Set;
5 |
6 | /**
7 | * A node in a tree.
8 | *
9 | */
10 | public class Node
11 | {
12 | /**
13 | * This nodes parent in the tree
14 | */
15 | protected Node parent = null;
16 |
17 | /**
18 | * This node's child in the tree
19 | */
20 | protected final Set children;
21 |
22 | /**
23 | * Constructor for creating a node
24 | * @param nodeId
25 | */
26 | public Node()
27 | {
28 | children = new HashSet();
29 | }
30 |
31 | /**
32 | * Add a child Node.
33 | *
34 | * @param child
35 | */
36 | public void addChild(Node child)
37 | {
38 | child.setParent(this);
39 | children.add(child);
40 | }
41 |
42 | /**
43 | * Remove a child Node.
44 | *
45 | * @param child
46 | * @return true if the Node existed as child of the current Node
47 | * and was successfully removed
48 | */
49 | public Boolean removeChild(Node child)
50 | {
51 | child.setParent(null);
52 | return children.remove(child);
53 | }
54 |
55 | /**
56 | * @return this node's children nodes
57 | */
58 | public Set getChildren()
59 | {
60 | return children;
61 | }
62 |
63 | /**
64 | * Set the parent of the node.
65 | *
66 | * @param parent the node's parent
67 | */
68 | public void setParent(Node parent)
69 | {
70 | this.parent = parent;
71 | }
72 | }
73 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/algorithms/DecisionTreeBuilder.java:
--------------------------------------------------------------------------------
1 | package tree.algorithms;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | import classify.Classifier;
7 | import tree.DecisionTree;
8 | import tree.DtLeaf;
9 | import tree.DtNode;
10 | import tree.Node;
11 | import tree.DtNode.Relation;
12 | import tree.train.Split;
13 | import tree.train.SplitBranch;
14 | import tree.train.SplitGenerator;
15 | import data.Attribute;
16 | import data.DataSet;
17 |
18 | public abstract class DecisionTreeBuilder
19 | {
20 |
21 | /**
22 | * The decision tree under construction
23 | */
24 | protected DecisionTree decisionTree = null;
25 |
26 | /**
27 | * Determine when the recursion should stop and a leaf node should be constructed.
28 | *
29 | * @param data the data set used to make this decision
30 | * @param availAttributes available attributes
31 | * @param candidateSplits splits that are under consideration to split on
32 | * @return whether or not the recursion should terminate
33 | */
34 | protected abstract boolean checkStoppingCriteria(DataSet data,
35 | List availAttributes,
36 | List candidateSplits);
37 |
38 | /**
39 | * Determine the best attribute to split on.
40 | *
41 | * @param data the data set used to make this decision
42 | * @param candidateSplits the set of candidate split
43 | * @return the best split among the candidates
44 | */
45 | protected abstract Split determineBestSplit(DataSet data, List candidateSplits);
46 |
47 | public DecisionTree buildDecisionTree(DataSet data)
48 | {
49 | List availAttributes = new ArrayList<>(data.getAttributeSet().getAttributes());
50 | availAttributes.remove(data.getClassAttribute());
51 |
52 | DtNode root = makeSubTree(
53 | data,
54 | null,
55 | null,
56 | null,
57 | availAttributes);
58 |
59 | decisionTree = new DecisionTree(root, data.getClassAttribute());
60 | return decisionTree;
61 | }
62 |
63 | private DtNode makeSubTree(
64 | DataSet data,
65 | Attribute attribute,
66 | Double value,
67 | DtNode.Relation relation,
68 | List availAttrs)
69 | {
70 | DtNode newNode = null;
71 | List candidateSplits = SplitGenerator.generateSplits(data, availAttrs);
72 |
73 | /*
74 | * If the stopping criteria is met, create a leaf node with a decision
75 | * class label that is the majority class of the instances at this node
76 | */
77 | if (checkStoppingCriteria(data, availAttrs, candidateSplits))
78 | {
79 | DtLeaf leaf = new DtLeaf(attribute,
80 | value,
81 | relation,
82 | data.getMajorityClass());
83 | leaf.setClassCounts(data.getClassCounts());
84 | newNode = leaf;
85 | }
86 | else
87 | {
88 | newNode = new DtNode(attribute, value, relation);
89 | newNode.setClassCounts(data.getClassCounts());
90 |
91 | /*
92 | * Find the best split among the candidate splits. For each branch of the best split, create a
93 | * new node that roots a subtree
94 | */
95 | Split bestSplit = determineBestSplit(data, candidateSplits);
96 | for (SplitBranch branch : bestSplit.getSplitBranches())
97 | {
98 | DataSet subsetData = new DataSet(data.getAttributeSet(), branch.getInstanceSet());
99 | subsetData.setClassAttribute(data.getClassAttribute().getName());
100 |
101 | /*
102 | * Determine attributes that are still available after the split.
103 | * We only remove the attribute if it is nominal.
104 | */
105 | List newAvailAttrs = null;
106 | if (branch.getAttribute().getType() == Attribute.Type.NOMINAL)
107 | {
108 | newAvailAttrs = new ArrayList<>(availAttrs);
109 | newAvailAttrs.remove(bestSplit.getAttribute());
110 | }
111 | else
112 | {
113 | newAvailAttrs = availAttrs;
114 | }
115 |
116 | /*
117 | * Make the recursive call to make a subtree at each child node
118 | */
119 | Node child = makeSubTree(
120 | subsetData,
121 | bestSplit.getAttribute(),
122 | branch.getValue(),
123 | branch.getRelation(),
124 | newAvailAttrs);
125 |
126 | newNode.addChild(child);
127 | }
128 | }
129 |
130 | return newNode;
131 | }
132 |
133 | }
134 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/algorithms/ID3TreeBuilder.java:
--------------------------------------------------------------------------------
1 | package tree.algorithms;
2 |
3 | import java.util.List;
4 |
5 | import tree.train.Split;
6 | import data.Attribute;
7 | import data.DataSet;
8 |
9 | /**
10 | * Builds a decision tree using the ID3 algorithm. Splits are
11 | * determined by finding the split that yields maximal information gain on each
12 | * iteration. The stopping criteria is met when either a minimum number of
13 | * instances are found at the leaf node, all instances at the leaf node are of
14 | * the same class, or there are no more splits to split on.
15 | *
16 | */
17 | public class ID3TreeBuilder extends DecisionTreeBuilder
18 | {
19 | /**
20 | * The minimum number of instances at the leaf node for
21 | * the stopping criteria to be met.
22 | */
23 | private final int minInstances;
24 |
25 | public ID3TreeBuilder(Integer minInstances)
26 | {
27 | this.minInstances = minInstances;
28 | }
29 |
30 | @Override
31 | public boolean checkStoppingCriteria(DataSet data,
32 | List availAttributes,
33 | List candidateSplits) {
34 |
35 | int numInstances = data.getInstanceSet().getInstances().size();
36 | return isAllCandidateSplitsNegativeInfoGain(candidateSplits) ||
37 | availAttributes.isEmpty() ||
38 | numInstances < minInstances ||
39 | isAllInstancesOfSameClass(data);
40 | }
41 |
42 | /**
43 | * Find the Split with the highest information gain among the candidate
44 | * splits
45 | *
46 | * @param data
47 | * @param candidateSplits
48 | * @return
49 | */
50 | @Override
51 | public Split determineBestSplit(DataSet data, List candidateSplits)
52 | {
53 | Split bestSplit = null;
54 | double maxInfoGain = -Double.MAX_VALUE;
55 | for (Split split : candidateSplits)
56 | {
57 | if (split.getInfoGain().doubleValue() > maxInfoGain)
58 | {
59 | maxInfoGain = split.getInfoGain().doubleValue();
60 | bestSplit = split;
61 | }
62 | }
63 | return bestSplit;
64 | }
65 |
66 | /**
67 | * @param candidateSplits all candidate splits
68 | * @return true if all candidate splits have negative information gain.
69 | * False otherwise.
70 | */
71 | private boolean isAllCandidateSplitsNegativeInfoGain(List candidateSplits)
72 | {
73 | for (Split split : candidateSplits)
74 | {
75 | if (split.getInfoGain() > 0)
76 | {
77 | return false;
78 | }
79 | }
80 | return true;
81 | }
82 |
83 | /**
84 | * @param data the dataset
85 | * @return true if if all instances are of the same class. False otherwise.
86 | */
87 | private boolean isAllInstancesOfSameClass(DataSet data)
88 | {
89 | for (Integer count : data.getClassCounts().values())
90 | {
91 | int numInstances = data.getInstanceSet().getInstances().size();
92 | if (count == numInstances)
93 | {
94 | return true;
95 | }
96 | }
97 | return false;
98 | }
99 |
100 | }
101 |
102 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/algorithms/RandomForestBuilder.java:
--------------------------------------------------------------------------------
1 | package tree.algorithms;
2 |
3 | import java.util.Set;
4 |
5 | import data.AttributeSet;
6 | import data.DataSet;
7 | import data.InstanceSet;
8 | import tree.DecisionTree;
9 |
10 | public class RandomForestBuilder
11 | {
12 |
13 | private DataSet sampleDataSet(DataSet fullData)
14 | {
15 | AttributeSet attributes = sampleAttributes(fullData.getAttributeSet());
16 | InstanceSet instances = sampleInstances(fullData.getInstanceSet());
17 | return new DataSet(attributes, instances);
18 | }
19 |
20 | private AttributeSet sampleAttributes(AttributeSet fullAttributeSet)
21 | {
22 | return null;
23 | }
24 |
25 | private InstanceSet sampleInstances(InstanceSet fullInstanceSet)
26 | {
27 | return null;
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/classifiers/ID3TreeClassifier.java:
--------------------------------------------------------------------------------
1 | package tree.classifiers;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | import classify.ClassificationResult;
7 | import classify.Classifier;
8 | import pair.Pair;
9 | import tree.DecisionTree;
10 | import tree.algorithms.ID3TreeBuilder;
11 | import data.DataSet;
12 | import data.Instance;
13 |
14 | public class ID3TreeClassifier implements Classifier
15 | {
16 | private DecisionTree dtTree;
17 |
18 | public ID3TreeClassifier(int minInstances, DataSet trainData)
19 | {
20 | ID3TreeBuilder id3Builder = new ID3TreeBuilder(minInstances);
21 | dtTree = id3Builder.buildDecisionTree(trainData);
22 | }
23 |
24 | @Override
25 | public ClassificationResult classifyData(DataSet testData)
26 | {
27 | /*
28 | * Classify each instance in the test dataset
29 | */
30 | List> resultList = new ArrayList<>();
31 | for (Instance instance : testData.getInstanceSet().getInstances())
32 | {
33 | resultList.add( dtTree.classifyInstance(instance) );
34 | }
35 |
36 | /*
37 | * Process the results
38 | */
39 | ClassificationResult result = new ClassificationResult(resultList, testData);
40 | return result;
41 | }
42 |
43 | @Override
44 | public Object getModel()
45 | {
46 | return this.dtTree;
47 | }
48 |
49 | @Override
50 | public String toString()
51 | {
52 | return "ID3\n\n" + dtTree;
53 | }
54 |
55 | }
56 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/classifiers/RandomForestClassifier.java:
--------------------------------------------------------------------------------
1 | package tree.classifiers;
2 |
3 | import java.util.Set;
4 |
5 | import tree.DecisionTree;
6 |
/**
 * Placeholder for a random forest classifier. The class currently has no
 * behavior of its own.
 */
public class RandomForestClassifier
{
    /**
     * The trees of the forest.
     * NOTE(review): presumably decision trees; the element type is not
     * declared here - confirm before use.
     */
    private Set trees;
}
12 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/evaluate/BiClassTest.java:
--------------------------------------------------------------------------------
1 | package tree.evaluate;
2 |
3 | import java.util.Set;
4 |
5 | import tree.DecisionTree;
6 | import tree.DtLeaf;
7 | import tree.DtNode;
8 | import data.Attribute;
9 | import data.DataSet;
10 | import data.Instance;
11 |
12 | public class BiClassTest
13 | {
14 | public static BiClassTestResults runTest(DataSet data, DecisionTree dt)
15 | {
16 | BiClassTestResults results = new BiClassTestResults();
17 |
18 | for (Instance instance : data.getInstanceSet().getInstances())
19 | {
20 | DtNode currNode = (DtNode) dt.getRoot();
21 |
22 | while (!(currNode instanceof DtLeaf))
23 | {
24 | @SuppressWarnings("unchecked")
25 | Set children = ((Set) ((Set>) currNode.getChildren()));
26 |
27 | for (DtNode node : children)
28 | {
29 | if (node.doesInstanceSatisfyNode(instance))
30 | {
31 | currNode = node;
32 | break;
33 | }
34 | }
35 | }
36 |
37 | Attribute classAttr = data.getClassAttribute();
38 |
39 | DtLeaf leaf = (DtLeaf) currNode;
40 |
41 | String prediction = data.getClassAttribute().getNominalValueName(
42 | leaf.getClassLabel().intValue() );
43 | String truth = data.getClassAttribute().getNominalValueName(
44 | instance.getAttributeValue(classAttr).intValue() );
45 |
46 | // Print result of classification
47 | System.out.print(prediction);
48 | System.out.print(" ");
49 | System.out.print(truth);
50 | System.out.print("\n");
51 |
52 | // Add the prediction to the test results
53 | results.addClassification(leaf.getClassLabel().intValue(),
54 | instance.getAttributeValue(classAttr).intValue());
55 | }
56 |
57 | System.out.println("\n");
58 |
59 | return results;
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/evaluate/BiClassTestResults.java:
--------------------------------------------------------------------------------
1 | package tree.evaluate;
2 |
/**
 * Tallies the outcomes of a two-class classification test. Labels other
 * than {@code 0} (negative) and {@code 1} (positive) are not counted.
 */
public class BiClassTestResults
{
    private static final int NEGATIVE = 0;
    private static final int POSITIVE = 1;

    private int posPredictions = 0;
    private int negPredictions = 0;

    private int falsePos = 0;
    private int truePos = 0;
    private int falseNeg = 0;
    private int trueNeg = 0;

    /**
     * Record one classification.
     *
     * @param predictedLabel the label produced by the classifier
     * @param trueLabel the actual label of the instance
     */
    public void addClassification(Integer predictedLabel, Integer trueLabel)
    {
        boolean predictedPositive = predictedLabel.equals(POSITIVE);
        boolean predictedNegative = predictedLabel.equals(NEGATIVE);
        boolean correct = predictedLabel.equals(trueLabel);

        if (predictedPositive)
        {
            posPredictions++;
            if (correct)
            {
                truePos++;
            }
            else
            {
                falsePos++;
            }
        }
        else if (predictedNegative)
        {
            negPredictions++;
            if (correct)
            {
                trueNeg++;
            }
            else
            {
                falseNeg++;
            }
        }
    }

    /**
     * Print "correct total" on one line: the number of correct predictions
     * followed by the number of counted predictions.
     */
    public void printResults()
    {
        int correctCount = truePos + trueNeg;
        int totalCount = posPredictions + negPredictions;
        System.out.println(correctCount + " " + totalCount);
    }

}
64 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/train/Bin.java:
--------------------------------------------------------------------------------
1 | package tree.train;
2 |
3 | import java.util.Comparator;
4 | import java.util.HashMap;
5 | import java.util.Map;
6 |
/**
 * This helper class keeps track of which class labels are represented at each
 * value of a specific continuous attribute. This class is specifically used
 * for generating all possible splits along a continuous attribute.
 *
 *
 * For example, say we have two instances with attribute A = a. One of these
 * instances has Class = positive, the other instance has Class = negative.
 * This class stores the knowledge that both class values are represented in
 * the instances whose attribute A = a.
 *
 * @author Matthew Bernstein - matthewb@cs.wisc.edu
 *
 */
public class Bin
{
    /**
     * Ordering of bins by the value they represent, largest value first
     * (the comparison is deliberately reversed: b2 is compared to b1).
     */
    public static final Comparator<Bin> BIN_ORDER =
            new Comparator<Bin>()
            {
                @Override
                public int compare(Bin b1, Bin b2)
                {
                    return b2.getValue().compareTo(b1.getValue());
                }
            };

    /**
     * The value of the target attribute represented by the bin
     */
    private final Double value;

    /**
     * A mapping of a nominal value ID of the class attribute to
     * a Boolean variable. This Boolean is true if an instance with the given
     * class label is in this bin.
     */
    private final Map<Integer, Boolean> classExistenceMap;

    /**
     * @param value the value of the target attribute represented by the bin
     */
    public Bin(Double value)
    {
        classExistenceMap = new HashMap<>();
        this.value = value;
    }

    /**
     * @return the bin's value of the target attribute.
     */
    public Double getValue()
    {
        return this.value;
    }

    /**
     * @return a mapping of a nominal value ID of the class attribute to
     * a Boolean variable. This Boolean is true if an instance with the given
     * class label is in this bin. Labels never seen have no entry.
     */
    public Map<Integer, Boolean> getExistenceMap()
    {
        return classExistenceMap;
    }

    /**
     * @param classLabel a nominal value ID of the class attribute that is
     * included in this bin
     */
    public void includeInstance(Integer classLabel)
    {
        classExistenceMap.put(classLabel, true);
    }
}
80 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/train/Entropy.java:
--------------------------------------------------------------------------------
1 | package tree.train;
2 |
3 | import java.util.Map;
4 |
5 |
6 | import data.DataSet;
7 |
8 | /**
9 | * Used for calculating information theory metrics used for determining the
10 | * best splits in a decision tree.
11 | *
12 | * @author Matthew Bernstein - matthewb@cs.wisc.edu
13 | *
14 | */
15 | public class Entropy
16 | {
17 | /**
18 | * Calculate the information gain of the class attribute on a given split.
19 | * A split consists of an attribute and a set of instances. This method
20 | * calculates
21 | *
22 | *
23 | * H(C) - H(C | X)
24 | *
25 | *
26 | * where H(C) is the entropy of the class attribute and H(C | X) is the
27 | * conditional entropy of the class attribute given some other attribute, X
28 | *
29 | * @param data the data set
30 | * @param split a split on the data
31 | * @return the information gained on the dataset's class attribute by
32 | * knowing the split
33 | */
34 | public static Double informationGain(DataSet data, Split split)
35 | {
36 | Double entropy = entropy(data);
37 | Double conditionalEntropy = conditionalEntropy(data, split);
38 | Double infoGain = entropy - conditionalEntropy;
39 |
40 | return infoGain;
41 | }
42 |
43 | /**
44 | * Calculate the entropy of the class attribute in a data set. This method
45 | * calculates
46 | *
47 | *
48 | * H(C)
49 | *
50 | *
51 | * where C is the class attribute
52 | *
53 | * @param data the data set for which to calculate entropy
54 | * @return the entropy of the class attribute
55 | */
56 | public static Double entropy(DataSet data)
57 | {
58 | double entropy = 0;
59 | double totalInstances = data.getInstanceSet().getInstances().size();
60 | Map classCounts = data.getClassCounts();
61 |
62 | /*
63 | * Calculate the entropy by summer P*log(P) for each P,
64 | * where P is the probability of seeing a class label
65 | */
66 | for (Integer count : classCounts.values())
67 | {
68 | if (count > 0)
69 | {
70 | double P = count / totalInstances;
71 | entropy += ( P * Math.log(P) / Math.log(2d) );
72 | }
73 | }
74 |
75 | entropy *= -1;
76 |
77 | return entropy;
78 | }
79 |
80 | /**
81 | * Calculate the conditional entropy of the class attribute given another
82 | * attribute. That is, this method calculates
83 | *
84 | *
85 | * H(C | X)
86 | *
87 | *
88 | * where C is the class attribute and X is some other attribute
89 | * in the data set.
90 | *
91 | * @param data the data set
92 | * @param split the split of the attribute for which we condition the class
93 | * attribute
94 | * @return
95 | */
96 | public static Double conditionalEntropy(DataSet data, Split split)
97 | {
98 | double conditionalEntropy = 0;
99 | double totalInstances = data.getInstanceSet().getInstances().size();
100 |
101 | if (totalInstances > 0)
102 | {
103 | for (SplitBranch branch : split.getSplitBranches())
104 | {
105 | DataSet branchData = new DataSet(data.getAttributeSet(), branch.getInstanceSet());
106 | branchData.setClassAttribute(data.getClassAttribute().getName());
107 |
108 | int branchNumInstances = branchData.getInstanceSet().getInstances().size();
109 | conditionalEntropy +=
110 | ((branchNumInstances / totalInstances) * entropy(branchData));
111 | }
112 | }
113 |
114 | return conditionalEntropy;
115 | }
116 |
117 | }
118 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/train/Split.java:
--------------------------------------------------------------------------------
1 | package tree.train;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | import data.Attribute;
7 | import data.DataSet;
8 | import data.Instance;
9 | import data.InstanceSet;
10 |
11 | /**
12 | * This class splits a set of instances along an attribute. It stores the
13 | * separated instances sorted by the value of this split's attribute.
14 | *
15 | */
16 | public class Split
17 | {
18 | /**
19 | * All of the branches for this split. For splits on nominal attributes,
20 | * there will be one branch per nominal value. For continuous attributes,
21 | * there will be two branches. One branch for the all instances with a
22 | * value greater than the threshold and one branch for all instances less
23 | * than or equal to the threshold value.
24 | */
25 | private List branches;
26 |
27 | /**
28 | * The attribute this Split splits instances on
29 | */
30 | private Attribute attribute;
31 |
32 | /**
33 | * The information gain on this split
34 | */
35 | private Double infoGain;
36 |
37 | /**
38 | * Constructor
39 | *
40 | * @param attribute the attribute along which this split splits
41 | */
42 | public Split(Attribute attribute)
43 | {
44 | branches = new ArrayList();
45 | this.attribute = attribute;
46 | }
47 |
48 | /**
49 | * @return this split's attribute
50 | */
51 | public Attribute getAttribute()
52 | {
53 | return attribute;
54 | }
55 |
56 | /**
57 | * @return the information gain along this split for the dataset's
58 | * class attribute.
59 | */
60 | public Double getInfoGain()
61 | {
62 | return infoGain;
63 | }
64 |
65 | // TODO: REFACTOR THIS!
66 | @Deprecated
67 | public void setInfoGain(Double infoGain)
68 | {
69 | this.infoGain = infoGain;
70 | }
71 |
72 | /**
73 | * Split a set of instances along this split.
74 | *
75 | * @param data the dataset containing the instances to be split
76 | */
77 | public void splitInstances(DataSet data)
78 | {
79 | InstanceSet instances = data.getInstanceSet();
80 |
81 | for (Instance instance : instances.getInstances())
82 | {
83 | for (SplitBranch branch : this.branches)
84 | {
85 | branch.tryAddInstance(instance);
86 | }
87 | }
88 |
89 | this.infoGain = Entropy.informationGain(data, this);
90 | }
91 |
92 | /**
93 | * @return each branch along this split
94 | */
95 | public List getSplitBranches()
96 | {
97 | return branches;
98 | }
99 |
100 | /**
101 | * Add a branch to the split
102 | *
103 | * @param newBranch the new branch
104 | */
105 | protected void addBranch(SplitBranch newBranch)
106 | {
107 | branches.add(newBranch);
108 | }
109 | }
110 |
--------------------------------------------------------------------------------
/machine-learning/src/tree/train/SplitBranch.java:
--------------------------------------------------------------------------------
1 | package tree.train;
2 |
3 | import tree.DtNode;
4 | import data.Attribute;
5 | import data.Instance;
6 | import data.InstanceSet;
7 |
8 | /**
9 | * This class represents a single branch along a split. If a split splits
10 | * instances along attribute A where A can take values {a1, a2, a3}, this split
11 | * will store 3 split branches for storing instances whose value of A is a1, a2,
12 | * and a3.
13 | */
14 | public class SplitBranch
15 | {
16 | /**
17 | * This describes the relation to the {@branchValue} that an instance is
18 | * tested against. Possible relations are defined in {@code DtNode}:
19 | *
20 | * EQUALS : instance value == branch value
21 | * GREATER_THAN : instance value > branch value
22 | * GREATER_THAN_EQUAL_TO : instance value >= branch value
23 | */
24 | private DtNode.Relation relation;
25 |
26 | /**
27 | * The value that an instance is tested against to make this split
28 | */
29 | private Double branchValue;
30 |
31 | /**
32 | * The attribute this branch tests
33 | */
34 | private Attribute attribute;
35 |
36 | /**
37 | * All instances that fall to this branch
38 | */
39 | private InstanceSet instanceSet;
40 |
41 | /**
42 | * Constructor
43 | *
44 | * @param attribute the attribute along which this branch's split splits
45 | * instances
46 | * @param branchValue the value of the attribute that this branch tests
47 | * @param relation the relation to the attribute an instance must be to
48 | * make this branch
49 | */
50 | protected SplitBranch(Attribute attribute,
51 | Double branchValue,
52 | DtNode.Relation relation)
53 | {
54 | this.instanceSet = new InstanceSet();
55 | this.attribute = attribute;
56 | this.branchValue = branchValue;
57 | this.relation = relation;
58 | }
59 |
60 | /**
61 | * @return the set of instances that have made this branch
62 | */
63 | public InstanceSet getInstanceSet()
64 | {
65 | return instanceSet;
66 | }
67 |
68 | /**
69 | * Attempt to add an instance to the this split branch. The instance is only
70 | * add if it passes this branches test.
71 | *
72 | * @param instance
73 | */
74 | public void tryAddInstance(Instance instance)
75 | {
76 | if (this.doesInstanceMakeSplit(instance))
77 | {
78 | instanceSet.addInstance(instance);
79 | }
80 | }
81 |
82 | /**
83 | * @return the split's attribute
84 | */
85 | public Attribute getAttribute()
86 | {
87 | return attribute;
88 | }
89 |
90 | /**
91 | * @return the relation to the attribute for which all instances in this
92 | * split branch fall
93 | */
94 | public DtNode.Relation getRelation()
95 | {
96 | return relation;
97 | }
98 |
99 | /**
100 | * @return the value of the split branch
101 | */
102 | public Double getValue()
103 | {
104 | return branchValue;
105 | }
106 |
107 | /**
108 | * Tests whether an instance makes this split branch
109 | *
110 | * @param instance the instance we are testing
111 | * @return true if the instance makes the split branch. False otherwise.
112 | */
113 | public Boolean doesInstanceMakeSplit(Instance instance)
114 | {
115 | Double instanceAttrValue = instance.getAttributeValue(this.attribute);
116 |
117 | switch(this.relation)
118 | {
119 | case EQUALS:
120 | return (instanceAttrValue.doubleValue() == branchValue.doubleValue());
121 | case GREATER_THAN:
122 | return (instanceAttrValue.doubleValue() > branchValue.doubleValue());
123 | case LESS_THAN_EQUAL_TO:
124 | return (instanceAttrValue.doubleValue() <= branchValue.doubleValue());
125 | default:
126 | throw new RuntimeException("Error testing instance in branch. " +
127 | "This branch's relation is not set to a valid relation.");
128 | }
129 | }
130 | }
131 |
--------------------------------------------------------------------------------
/machine-learning/tst/data/AttributeTest.java:
--------------------------------------------------------------------------------
1 | package data;
2 |
3 | import static org.junit.Assert.*;
4 |
5 | import java.util.HashMap;
6 | import java.util.List;
7 | import java.util.Map;
8 |
9 | import org.junit.Before;
10 | import org.junit.Test;
11 |
12 | import com.google.common.collect.ImmutableList;
13 | import com.google.common.collect.ImmutableMap;
14 |
15 |
16 | public class AttributeTest
17 | {
18 | private static final String ATTR_NAME = "Color";
19 | private static final List NOMINAL_VALUES = ImmutableList.of("Red", "Yellow", "Blue");
20 | private static final List NOMINAL_IDS = ImmutableList.of(0, 1, 2);
21 |
22 | private static final Map NOMINAL_VALUE_IDS = generateNominalValueIds();
23 |
24 | @Before
25 | public void before()
26 | {
27 |
28 | }
29 |
30 | @Test
31 | public void test_Constructor()
32 | {
33 | Attribute nominalAttr = new Attribute(ATTR_NAME, Attribute.Type.NOMINAL, (String[]) NOMINAL_VALUES.toArray());
34 |
35 | assertEquals(nominalAttr.getName(), ATTR_NAME);
36 |
37 | //assertEquals(nominalAttr.getNominalValueId(attrValueName), NOMINAL_VALUE_IDS.keySet());
38 | assertEquals(nominalAttr.getNominalValueMap().values(), NOMINAL_VALUE_IDS.values());
39 | }
40 |
41 | private static Map generateNominalValueIds()
42 | {
43 | return new ImmutableMap.Builder()
44 | .put(NOMINAL_VALUES.get(0), NOMINAL_IDS.get(0))
45 | .put(NOMINAL_VALUES.get(1), NOMINAL_IDS.get(1))
46 | .put(NOMINAL_VALUES.get(2), NOMINAL_IDS.get(2))
47 | .build();
48 | }
49 |
50 | }
51 |
--------------------------------------------------------------------------------