├── .gitignore
├── LICENSE
├── README.md
├── assignment1
├── README.md
├── data
├── experiments
│ ├── ANN.py
│ ├── Boosting.py
│ ├── DT.py
│ ├── KNN.py
│ ├── SVM.py
│ ├── __init__.py
│ ├── base.py
│ └── plotting.py
├── learners
│ ├── ANN.py
│ ├── Boosting.py
│ ├── DT.py
│ ├── KNN.py
│ ├── SVM.py
│ ├── __init__.py
│ └── base.py
├── requirements.txt
└── run_experiment.py
├── assignment2
├── ABAGAIL.jar
├── ABAGAIL
│ ├── ABAGAIL.jar
│ ├── build.xml
│ └── src
│ │ ├── dist
│ │ ├── AbstractConditionalDistribution.java
│ │ ├── AbstractDistribution.java
│ │ ├── ConditionalDistribution.java
│ │ ├── DiscreteDependencyTree.java
│ │ ├── DiscreteDependencyTreeNode.java
│ │ ├── DiscreteDependencyTreeRootNode.java
│ │ ├── DiscreteDistribution.java
│ │ ├── DiscreteDistributionTable.java
│ │ ├── DiscretePermutationDistribution.java
│ │ ├── DiscreteUniformDistribution.java
│ │ ├── Distribution.java
│ │ ├── FixedComponentMixtureDistribution.java
│ │ ├── FixedDistribution.java
│ │ ├── LabelDistribution.java
│ │ ├── MixtureDistribution.java
│ │ ├── MultivariateGaussian.java
│ │ ├── PrecalculatedDistribution.java
│ │ ├── UnivariateGaussian.java
│ │ ├── hmm
│ │ │ ├── ConditionalStateDistributionWrapper.java
│ │ │ ├── ForwardBackwardProbabilityCalculator.java
│ │ │ ├── HiddenMarkovModel.java
│ │ │ ├── HiddenMarkovModelReestimator.java
│ │ │ ├── ModularHiddenMarkovModel.java
│ │ │ ├── SimpleHiddenMarkovModel.java
│ │ │ ├── SimpleStateDistribution.java
│ │ │ ├── SimpleStateDistributionTable.java
│ │ │ ├── StateDistribution.java
│ │ │ └── StateSequenceCalculator.java
│ │ └── test
│ │ │ ├── DiscreteDependencyTreeTest.java
│ │ │ ├── DiscreteDistributionTest.java
│ │ │ ├── HMMCoinTest.java
│ │ │ ├── HMMConditionalMonsterKnowledgeTest.java
│ │ │ ├── HMMConditionalMonsterTest.java
│ │ │ ├── HMMRandomCoinTest.java
│ │ │ ├── HMMWumpusTest.java
│ │ │ ├── MixtureDistributionTest.java
│ │ │ └── MultivariateGaussianTest.java
│ │ ├── func
│ │ ├── AdaBoostClassifier.java
│ │ ├── DecisionStumpClassifier.java
│ │ ├── DecisionTreeClassifier.java
│ │ ├── EMClusterer.java
│ │ ├── FunctionApproximater.java
│ │ ├── FunctionApproximaterSupplier.java
│ │ ├── GaussianProcessRegression.java
│ │ ├── KMeansClusterer.java
│ │ ├── KNNClassifier.java
│ │ ├── NeuralNetworkClassifier.java
│ │ ├── SimpleSupportVectorMachineClassifier.java
│ │ ├── dtree
│ │ │ ├── BinaryDecisionTreeSplit.java
│ │ │ ├── ChiSquarePruningCriteria.java
│ │ │ ├── DecisionTreeNode.java
│ │ │ ├── DecisionTreeSplit.java
│ │ │ ├── DecisionTreeSplitStatistics.java
│ │ │ ├── GINISplitEvaluator.java
│ │ │ ├── InformationGainSplitEvaluator.java
│ │ │ ├── PruningCriteria.java
│ │ │ ├── SplitEvaluator.java
│ │ │ └── StandardDecisionTreeSplit.java
│ │ ├── inst
│ │ │ ├── HyperRectangle.java
│ │ │ ├── KDTree.java
│ │ │ ├── KDTreeNode.java
│ │ │ └── NearestNeighborQueue.java
│ │ ├── nn
│ │ │ ├── Layer.java
│ │ │ ├── LayeredNetwork.java
│ │ │ ├── Link.java
│ │ │ ├── NetworkTrainer.java
│ │ │ ├── NeuralNetwork.java
│ │ │ ├── Neuron.java
│ │ │ ├── activation
│ │ │ │ ├── ActivationFunction.java
│ │ │ │ ├── DifferentiableActivationFunction.java
│ │ │ │ ├── HyperbolicTangentSigmoid.java
│ │ │ │ ├── LinearActivationFunction.java
│ │ │ │ ├── LogisticSigmoid.java
│ │ │ │ └── RELU.java
│ │ │ ├── backprop
│ │ │ │ ├── BackPropagationBiasNode.java
│ │ │ │ ├── BackPropagationLayer.java
│ │ │ │ ├── BackPropagationLink.java
│ │ │ │ ├── BackPropagationNetwork.java
│ │ │ │ ├── BackPropagationNetworkFactory.java
│ │ │ │ ├── BackPropagationNode.java
│ │ │ │ ├── BackPropagationSoftMaxOutputLayer.java
│ │ │ │ ├── BasicUpdateRule.java
│ │ │ │ ├── BatchBackPropagationTrainer.java
│ │ │ │ ├── QuickpropUpdateRule.java
│ │ │ │ ├── RPROPUpdateRule.java
│ │ │ │ ├── StandardUpdateRule.java
│ │ │ │ ├── StochasticBackPropagationTrainer.java
│ │ │ │ └── WeightUpdateRule.java
│ │ │ └── feedfwd
│ │ │ │ ├── FeedForwardBiasNode.java
│ │ │ │ ├── FeedForwardLayer.java
│ │ │ │ ├── FeedForwardNetwork.java
│ │ │ │ ├── FeedForwardNeuralNetworkFactory.java
│ │ │ │ └── FeedForwardNode.java
│ │ ├── svm
│ │ │ ├── Kernel.java
│ │ │ ├── LinearKernel.java
│ │ │ ├── PolynomialKernel.java
│ │ │ ├── RBFKernel.java
│ │ │ ├── SequentialMinimalOptimization.java
│ │ │ ├── SigmoidKernel.java
│ │ │ ├── SingleClassSequentialMinimalOptimization.java
│ │ │ ├── SingleClassSupportVectorMachine.java
│ │ │ └── SupportVectorMachine.java
│ │ └── test
│ │ │ ├── AdaBoostTest.java
│ │ │ ├── DecisionStumpTest.java
│ │ │ ├── DecisionTreeTest.java
│ │ │ ├── EMClustererTest.java
│ │ │ ├── GaussianProcessRegressionTest.java
│ │ │ ├── KMeansClustererTest.java
│ │ │ ├── KNNClassifierAbaloneTest.java
│ │ │ ├── KNNClassifierTest.java
│ │ │ ├── NNBinaryClassificationTest.java
│ │ │ ├── NNClassificationTest.java
│ │ │ ├── NNRegressionTest.java
│ │ │ ├── PruningCriteriaTest.java
│ │ │ ├── SequentialMinimalOptimizationTest.java
│ │ │ ├── SingleClassSequentialMinimalOptimizationTest.java
│ │ │ └── SplitEvaluatorTest.java
│ │ ├── opt
│ │ ├── ContinuousAddOneNeighbor.java
│ │ ├── DiscreteChangeOneNeighbor.java
│ │ ├── EvaluationFunction.java
│ │ ├── GenericHillClimbingProblem.java
│ │ ├── GenericOptimizationProblem.java
│ │ ├── HillClimbingProblem.java
│ │ ├── NeighborFunction.java
│ │ ├── OptimizationAlgorithm.java
│ │ ├── OptimizationProblem.java
│ │ ├── RandomizedHillClimbing.java
│ │ ├── SimulatedAnnealing.java
│ │ ├── SwapNeighbor.java
│ │ ├── example
│ │ │ ├── ContinuousPeaksEvaluationFunction.java
│ │ │ ├── CountOnesEvaluationFunction.java
│ │ │ ├── FlipFlopEvaluationFunction.java
│ │ │ ├── FlipFlopMODEvaluationFunction.java
│ │ │ ├── FourPeaksEvaluationFunction.java
│ │ │ ├── KnapsackEvaluationFunction.java
│ │ │ ├── NeuralNetworkEvaluationFunction.java
│ │ │ ├── NeuralNetworkOptimizationProblem.java
│ │ │ ├── NeuralNetworkWeightDistribution.java
│ │ │ ├── TravelingSalesmanCrossOver.java
│ │ │ ├── TravelingSalesmanEvaluationFunction.java
│ │ │ ├── TravelingSalesmanRouteEvaluationFunction.java
│ │ │ ├── TravelingSalesmanSortEvaluationFunction.java
│ │ │ └── TwoColorsEvaluationFunction.java
│ │ ├── ga
│ │ │ ├── BoardLocation.java
│ │ │ ├── ContinuousAddOneMutation.java
│ │ │ ├── CrossoverFunction.java
│ │ │ ├── DiscreteChangeOneMutation.java
│ │ │ ├── GenericGeneticAlgorithmProblem.java
│ │ │ ├── GeneticAlgorithmProblem.java
│ │ │ ├── MaxKColorFitnessFunction.java
│ │ │ ├── MutationFunction.java
│ │ │ ├── NQueensBoardGame.java
│ │ │ ├── NQueensFitnessFunction.java
│ │ │ ├── SingleCrossOver.java
│ │ │ ├── StandardGeneticAlgorithm.java
│ │ │ ├── SwapMutation.java
│ │ │ ├── TwoPointCrossOver.java
│ │ │ ├── UniformCrossOver.java
│ │ │ └── Vertex.java
│ │ ├── prob
│ │ │ ├── GenericProbabilisticOptimizationProblem.java
│ │ │ ├── MIMIC.java
│ │ │ └── ProbabilisticOptimizationProblem.java
│ │ └── test
│ │ │ ├── AbaloneTest.java
│ │ │ ├── ContinuousPeaksTest.java
│ │ │ ├── CountOnesTest.java
│ │ │ ├── CrossValidationTest.java
│ │ │ ├── FlipFlopTest.java
│ │ │ ├── FourPeaksTest.java
│ │ │ ├── KnapsackTest.java
│ │ │ ├── MaxKColoringTest.java
│ │ │ ├── NQueensTest.java
│ │ │ ├── TravelingSalesmanTest.java
│ │ │ ├── TwoColorsTest.java
│ │ │ ├── XORTest.java
│ │ │ ├── XORTestNoBackprop.java
│ │ │ ├── XORTestNoBackpropGeneticAlgo.java
│ │ │ ├── XORTestNoBackpropSimAnneal.java
│ │ │ └── abalone.txt
│ │ ├── rl
│ │ ├── DecayingEpsilonGreedyStrategy.java
│ │ ├── EpsilonGreedyStrategy.java
│ │ ├── ExplorationStrategy.java
│ │ ├── GreedyStrategy.java
│ │ ├── MarkovDecisionProcess.java
│ │ ├── MazeMarkovDecisionProcess.java
│ │ ├── MazeMarkovDecisionProcessVisualization.java
│ │ ├── NonDeterministicMazeMDP.java
│ │ ├── Policy.java
│ │ ├── PolicyIteration.java
│ │ ├── PolicyLearner.java
│ │ ├── QLambda.java
│ │ ├── SarsaLambda.java
│ │ ├── SimpleMarkovDecisionProcess.java
│ │ ├── ValueIteration.java
│ │ ├── test
│ │ │ ├── MDPTest.java
│ │ │ ├── MazeMDPTest.java
│ │ │ └── NonDeterministicMazeMDPTest.java
│ │ └── tester
│ │ │ └── ExpectedRewardTestMetric.java
│ │ ├── shared
│ │ ├── AbstractDistanceMeasure.java
│ │ ├── AbstractErrorMeasure.java
│ │ ├── AttributeType.java
│ │ ├── ConvergenceTrainer.java
│ │ ├── Copyable.java
│ │ ├── DataSet.java
│ │ ├── DataSetDescription.java
│ │ ├── DataSetWriter.java
│ │ ├── DistanceMeasure.java
│ │ ├── ErrorMeasure.java
│ │ ├── EuclideanDistance.java
│ │ ├── FixedIterationTrainer.java
│ │ ├── GradientErrorMeasure.java
│ │ ├── HammingDistance.java
│ │ ├── Instance.java
│ │ ├── MixedDistanceMeasure.java
│ │ ├── OccasionalPrinter.java
│ │ ├── SumOfSquaresError.java
│ │ ├── ThresholdTrainer.java
│ │ ├── Trainer.java
│ │ ├── filt
│ │ │ ├── ContinuousToDiscreteFilter.java
│ │ │ ├── DataSetFilter.java
│ │ │ ├── DiscreteDistributionFilter.java
│ │ │ ├── DiscreteToBinaryFilter.java
│ │ │ ├── IndependentComponentAnalysis.java
│ │ │ ├── InsignificantComponentAnalysis.java
│ │ │ ├── LabelFilter.java
│ │ │ ├── LabelSelectFilter.java
│ │ │ ├── LabelSplitFilter.java
│ │ │ ├── LinearDiscriminantAnalysis.java
│ │ │ ├── PrincipalComponentAnalysis.java
│ │ │ ├── RandomOrderFilter.java
│ │ │ ├── RandomizedProjectionFilter.java
│ │ │ ├── ReversibleFilter.java
│ │ │ ├── TestTrainSplitFilter.java
│ │ │ ├── VarianceCounter.java
│ │ │ ├── ica
│ │ │ │ ├── ContrastFunction.java
│ │ │ │ └── HyperbolicTangentContrast.java
│ │ │ └── kFoldSplitFilter.java
│ │ ├── reader
│ │ │ ├── ArffDataSetReader.java
│ │ │ ├── CSVDataSetReader.java
│ │ │ ├── DataSetLabelBinarySeperator.java
│ │ │ └── DataSetReader.java
│ │ ├── runner
│ │ │ ├── MultiRunner.java
│ │ │ └── Runner.java
│ │ ├── test
│ │ │ ├── ArffDataSetReaderTest.java
│ │ │ ├── CSVDataSetReaderTest.java
│ │ │ ├── IndepenentComponentAnalysisTest.java
│ │ │ ├── InsignificantComponentAnalysisTest.java
│ │ │ ├── LabelSelectFilterTest.java
│ │ │ ├── LinearDiscriminantAnalysisTest.java
│ │ │ ├── PrincipalComponentAnalysisTest.java
│ │ │ ├── abalone.arff
│ │ │ ├── abalone.data
│ │ │ ├── abalone.names
│ │ │ ├── abalone_label_at_0.data
│ │ │ └── abalone_notes.txt
│ │ ├── tester
│ │ │ ├── AccuracyTestMetric.java
│ │ │ ├── Comparison.java
│ │ │ ├── ConfusionMatrixTestMetric.java
│ │ │ ├── CrossValidationTestMetric.java
│ │ │ ├── NeuralNetworkTester.java
│ │ │ ├── PrecisionTestMetric.java
│ │ │ ├── RawOutputTestMetric.java
│ │ │ ├── RecallTestMetric.java
│ │ │ ├── TestMetric.java
│ │ │ └── Tester.java
│ │ └── writer
│ │ │ ├── CSVWriter.java
│ │ │ └── Writer.java
│ │ └── util
│ │ ├── ABAGAILArrays.java
│ │ ├── MaxHeap.java
│ │ ├── TimeUtil.java
│ │ ├── graph
│ │ ├── DFSTree.java
│ │ ├── Edge.java
│ │ ├── Graph.java
│ │ ├── GraphTransformation.java
│ │ ├── KruskalsMST.java
│ │ ├── Node.java
│ │ ├── Tree.java
│ │ └── WeightedEdge.java
│ │ ├── linalg
│ │ ├── BidiagonalDecomposition.java
│ │ ├── CholeskyFactorization.java
│ │ ├── DenseVector.java
│ │ ├── DiagonalMatrix.java
│ │ ├── GivensRotation.java
│ │ ├── HessenbergDecomposition.java
│ │ ├── HouseholderReflection.java
│ │ ├── LUDecomposition.java
│ │ ├── LowerTriangularMatrix.java
│ │ ├── Matrix.java
│ │ ├── QRDecomposition.java
│ │ ├── RealSchurDecomposition.java
│ │ ├── RectangularMatrix.java
│ │ ├── SingularValueDecomposition.java
│ │ ├── SymmetricEigenvalueDecomposition.java
│ │ ├── TridiagonalDecomposition.java
│ │ ├── UpperTriangularMatrix.java
│ │ └── Vector.java
│ │ └── test
│ │ ├── ABAGAILArraysTest.java
│ │ ├── BidiagonalDecompositionTest.java
│ │ ├── CholeskyFactorizationTest.java
│ │ ├── EigenvalueDecompositionTest.java
│ │ ├── HessenbergDecompositionTest.java
│ │ ├── HouseholderReflectionTest.java
│ │ ├── LUDecompositionTest.java
│ │ ├── LowerTriangularMatrixTest.java
│ │ ├── QRDecompositionTest.java
│ │ ├── SingularValueDecompositionTest.java
│ │ ├── SymmetricEigenvalueDecompositionTest.java
│ │ ├── TridiagonalDecompositionTest.java
│ │ └── UpperTriangularMatrixTest.java
├── NN-Backprop.py
├── NN-GA.py
├── NN-RHC.py
├── NN-SA.py
├── README.md
├── base.py
├── continuouspeaks.py
├── data
├── flipflop.py
├── plotting.py
├── requirements.txt
├── run_experiment.py
└── tsp.py
├── assignment3
├── README.md
├── data
├── experiments
│ ├── ICA.py
│ ├── LDA.py
│ ├── PCA.py
│ ├── RF.py
│ ├── RP.py
│ ├── SVD.py
│ ├── __init__.py
│ ├── base.py
│ ├── benchmark.py
│ ├── clustering.py
│ ├── plotting.py
│ └── scoring.py
├── requirements-no-tables.txt
├── requirements.txt
├── run_clustering.sh
└── run_experiment.py
├── assignment4
├── README.md
├── environments
│ ├── __init__.py
│ ├── cliff_walking.py
│ └── frozen_lake.py
├── experiments
│ ├── __init__.py
│ ├── base.py
│ ├── plotting.py
│ ├── policy_iteration.py
│ ├── q_learner.py
│ └── value_iteration.py
├── requirements.txt
├── run_experiment.py
└── solvers
│ ├── __init__.py
│ ├── base.py
│ ├── policy_iteration.py
│ ├── q_learning.py
│ └── value_iteration.py
└── data
├── HTRU_2.csv
├── abalone.data
├── crx.data
├── default of credit card clients.xls
├── loader.py
├── pendigits.csv
├── spambase.data
└── statlog.vehicle.csv
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | __pycache__
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Chad Maron
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/assignment1/README.md:
--------------------------------------------------------------------------------
1 | # Assignment 1 - Supervised Learning
2 |
3 | There is nothing extra to do here: just pick your datasets, start running the experiments, and come back in a week.
--------------------------------------------------------------------------------
/assignment1/data:
--------------------------------------------------------------------------------
1 | ../data
--------------------------------------------------------------------------------
/assignment1/experiments/DT.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import experiments
4 | import learners
5 |
6 |
class DTExperiment(experiments.BaseExperiment):
    """
    Decision-tree experiment.

    Searches over split criterion, maximum depth and class weighting, and uses the maximum depth as the
    complexity parameter when building the complexity curve.
    """

    def __init__(self, details, verbose=False):
        # BaseExperiment stores both the details and the verbose flag.
        super().__init__(details, verbose=verbose)

    def perform(self):
        """
        Run the decision-tree experiment for the configured dataset.

        :return: None
        """
        # TODO: Clean up the older alpha stuff?
        depth_range = np.arange(1, 51, 1)
        grid = {
            'DT__criterion': ['gini', 'entropy'],
            'DT__max_depth': depth_range,
            'DT__class_weight': ['balanced', None],
        }
        complexity = {'name': 'DT__max_depth', 'display_name': 'Max Depth', 'values': depth_range}

        # Known best params from earlier grid searches. Assigning one of these to best_params skips the grid
        # search and just rebuilds the various graphs:
        #   Dataset 1: {'criterion': 'entropy', 'max_depth': 23, 'class_weight': 'balanced'}
        #   Dataset 2: {'criterion': 'entropy', 'max_depth': 4, 'class_weight': 'balanced'}
        best_params = None

        details = self._details
        tree = learners.DTLearner(random_state=details.seed)
        if best_params is not None:
            tree.set_params(**best_params)

        experiments.perform_experiment(details.ds, details.ds_name, details.ds_readable_name,
                                       tree, 'DT', 'DT', grid,
                                       complexity_param=complexity,
                                       seed=details.seed,
                                       threads=details.threads,
                                       best_params=best_params,
                                       verbose=self._verbose)
39 |
--------------------------------------------------------------------------------
/assignment1/experiments/KNN.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import numpy as np
4 | import sklearn
5 |
6 | import experiments
7 | import learners
8 |
9 |
class KNNExperiment(experiments.BaseExperiment):
    """
    k-nearest-neighbours experiment.

    Searches over distance metric and neighbour count, and uses the neighbour count as the complexity
    parameter when building the complexity curve.
    """

    def __init__(self, details, verbose=False):
        # BaseExperiment already stores the verbose flag; forward it instead of re-assigning it here.
        super().__init__(details, verbose=verbose)

    def perform(self):
        """
        Run the KNN experiment for the configured dataset.

        The grid search samples n_neighbors in steps of 3 (1, 4, ..., 49), while the complexity curve
        evaluates every neighbour count from 1 to 50.

        :return: None
        """
        # Adapted from https://github.com/JonathanTay/CS-7641-assignment-1/blob/master/KNN.py
        params = {'KNN__metric': ['manhattan', 'euclidean', 'chebyshev'], 'KNN__n_neighbors': np.arange(1, 51, 3),
                  'KNN__weights': ['uniform']}
        complexity_param = {'name': 'KNN__n_neighbors', 'display_name': 'Neighbor count',
                            'values': np.arange(1, 51, 1)}

        best_params = None
        # Uncomment to select known best params from grid search. This will skip the grid search and just rebuild
        # the various graphs
        #
        # Dataset 1:
        # best_params = {'metric': 'manhattan', 'n_neighbors': 7, 'weights': 'uniform'}
        #
        # Dataset 2:
        # best_params = {'metric': 'euclidean', 'n_neighbors': 4, 'weights': 'uniform'}

        learner = learners.KNNLearner(n_jobs=self._details.threads)
        if best_params is not None:
            learner.set_params(**best_params)

        experiments.perform_experiment(self._details.ds, self._details.ds_name, self._details.ds_readable_name,
                                       learner, 'KNN', 'KNN',
                                       params, complexity_param=complexity_param,
                                       seed=self._details.seed, best_params=best_params, threads=self._details.threads,
                                       verbose=self._verbose)
40 |
--------------------------------------------------------------------------------
/assignment1/experiments/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from abc import ABC, abstractmethod
4 |
# Configure root logging when this module is imported, so INFO-level records (such as those emitted by
# BaseExperiment.log) are printed with timestamp, logger name and level.
# NOTE(review): this is an import-time side effect that affects the whole process. Per the logging docs,
# basicConfig does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
7 |
8 |
class ExperimentDetails:
    """
    Plain container for the settings every experiment reads.

    :param ds: The dataset object the experiment runs on
    :param ds_name: The dataset's name
    :param ds_readable_name: A human-readable name for the dataset
    :param threads: Parallelism setting forwarded to learners/experiments (e.g. KNN's n_jobs)
    :param seed: Random seed forwarded to learners/experiments (e.g. DTLearner's random_state)
    """

    def __init__(self, ds, ds_name, ds_readable_name, threads, seed):
        self.ds, self.ds_name, self.ds_readable_name = ds, ds_name, ds_readable_name
        self.threads, self.seed = threads, seed
16 |
17 |
class BaseExperiment(ABC):
    """
    Abstract base class for experiments.

    Stores the experiment details (presumably an ExperimentDetails instance) and a verbosity flag;
    subclasses implement perform().
    """

    def __init__(self, details, verbose=False):
        self._details, self._verbose = details, verbose

    @abstractmethod
    def perform(self):
        """Run the experiment. Must be implemented by subclasses."""

    def log(self, msg, *args):
        """
        Log ``msg.format(*args)`` at INFO level, but only when this experiment was created with verbose=True.

        :param msg: A str.format-style template
        :param args: Positional values substituted into the template
        :return: None
        """
        if not self._verbose:
            return
        logger.info(msg.format(*args))
36 |
--------------------------------------------------------------------------------
/assignment1/learners/ANN.py:
--------------------------------------------------------------------------------
1 | from sklearn import neural_network
2 |
3 | import learners
4 |
5 |
6 | class ANNLearner(learners.BaseLearner):
7 | def __init__(self,
8 | hidden_layer_sizes=(100,),
9 | activation="relu",
10 | solver='adam',
11 | alpha=0.0001,
12 | batch_size='auto',
13 | learning_rate="constant",
14 | learning_rate_init=0.001,
15 | power_t=0.5,
16 | max_iter=200,
17 | shuffle=True,
18 | random_state=None,
19 | tol=1e-4,
20 | verbose=False,
21 | warm_start=False,
22 | momentum=0.9,
23 | nesterovs_momentum=True,
24 | early_stopping=False,
25 | validation_fraction=0.1,
26 | beta_1=0.9,
27 | beta_2=0.999,
28 | epsilon=1e-8,
29 | ):
30 | super().__init__(verbose)
31 | self._learner = neural_network.MLPClassifier(
32 | hidden_layer_sizes=hidden_layer_sizes,
33 | activation=activation,
34 | solver=solver,
35 | alpha=alpha,
36 | batch_size=batch_size,
37 | learning_rate=learning_rate,
38 | learning_rate_init=learning_rate_init,
39 | power_t=power_t,
40 | max_iter=max_iter,
41 | shuffle=shuffle,
42 | random_state=random_state,
43 | tol=tol,
44 | verbose=verbose,
45 | warm_start=warm_start,
46 | momentum=momentum,
47 | nesterovs_momentum=nesterovs_momentum,
48 | early_stopping=early_stopping,
49 | validation_fraction=validation_fraction,
50 | beta_1=beta_1,
51 | beta_2=beta_2,
52 | epsilon=epsilon
53 | )
54 |
55 | def learner(self):
56 | return self._learner
57 |
--------------------------------------------------------------------------------
/assignment1/learners/Boosting.py:
--------------------------------------------------------------------------------
1 | from sklearn import ensemble
2 |
3 | import learners
4 |
5 |
class BoostingLearner(learners.BaseLearner):
    """
    Thin wrapper around scikit-learn's ``ensemble.AdaBoostClassifier``.

    ``verbose`` is handed only to the BaseLearner constructor. AdaBoostClassifier receives the remaining
    arguments unchanged.
    """

    def __init__(self,
                 base_estimator=None,
                 n_estimators=50,
                 learning_rate=1.,
                 algorithm='SAMME.R',
                 random_state=None,
                 verbose=False):
        super().__init__(verbose)
        boosting_options = {
            'base_estimator': base_estimator,
            'n_estimators': n_estimators,
            'learning_rate': learning_rate,
            'algorithm': algorithm,
            'random_state': random_state,
        }
        self._learner = ensemble.AdaBoostClassifier(**boosting_options)

    def learner(self):
        """Return the wrapped AdaBoostClassifier instance."""
        return self._learner
24 |
--------------------------------------------------------------------------------
/assignment1/learners/KNN.py:
--------------------------------------------------------------------------------
1 | from sklearn import neighbors
2 |
3 | import learners
4 |
5 |
class KNNLearner(learners.BaseLearner):
    """
    Thin wrapper around scikit-learn's ``neighbors.KNeighborsClassifier``.

    All named arguments except ``verbose``, plus any extra keyword arguments, are forwarded to
    KNeighborsClassifier. ``verbose`` goes only to the BaseLearner constructor.
    """

    def __init__(self,
                 verbose=False,
                 n_neighbors=5,
                 weights='uniform',
                 algorithm='auto',
                 leaf_size=30,
                 p=2,
                 metric='minkowski',
                 metric_params=None,
                 n_jobs=1,
                 **kwargs):
        super().__init__(verbose)
        # kwargs cannot repeat any of the named parameters above, so merging them here cannot collide.
        knn_options = dict(n_neighbors=n_neighbors,
                           weights=weights,
                           algorithm=algorithm,
                           leaf_size=leaf_size,
                           p=p,
                           metric=metric,
                           metric_params=metric_params,
                           n_jobs=n_jobs,
                           **kwargs)
        self._learner = neighbors.KNeighborsClassifier(**knn_options)

    def learner(self):
        """Return the wrapped KNeighborsClassifier instance."""
        return self._learner
32 |
--------------------------------------------------------------------------------
/assignment1/learners/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .ANN import *
3 | from .Boosting import *
4 | from .DT import *
5 | from .KNN import *
6 | from .SVM import *
7 |
# Names exported by `from learners import *`.
# NOTE(review): these are the submodule names (ANN, Boosting, DT, KNN, SVM), not the learner classes, so a star
# import of this package yields the modules. Classes such as KNNLearner remain reachable as package attributes
# (e.g. learners.KNNLearner) through the star imports above.
__all__ = ['ANN', 'Boosting', 'DT', 'KNN', 'SVM']
9 |
--------------------------------------------------------------------------------
/assignment1/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy == 1.15.1
2 | scipy == 1.1.0
3 | scikit-learn == 0.19.2
4 | pandas == 0.23.4
5 | xlrd == 0.9.0
6 | matplotlib == 2.2.3
7 | seaborn == 0.9.0
8 | scikit-optimize == 0.5.2
9 | gym == 0.10.5
10 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL.jar:
--------------------------------------------------------------------------------
1 | ABAGAIL/ABAGAIL.jar
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/ABAGAIL.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmaron/CS-7641-assignments/ae1d14ffd7ab043ec412faf40aaebdda182f3201/assignment2/ABAGAIL/ABAGAIL.jar
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/build.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/AbstractConditionalDistribution.java:
--------------------------------------------------------------------------------
1 | package dist;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 | * An abstract condtional distribution
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public abstract class AbstractConditionalDistribution extends AbstractDistribution implements ConditionalDistribution {
11 |
12 | /**
13 | * Generate a output given the input
14 | * @param i the input
15 | * @return the output
16 | */
17 | public Instance sample(Instance i) {
18 | return distributionFor(i).sample();
19 | }
20 |
21 | /**
22 | * Generate a output that is most likely given the input
23 | * @param i the input
24 | * @return the output
25 | */
26 | public Instance mode(Instance i) {
27 | return distributionFor(i).sample();
28 | }
29 |
30 | /**
31 | * Probability of an instance
32 | * @parma i the instance
33 | * @return the probability
34 | */
35 | public double p(Instance i) {
36 | return distributionFor(i).p(i.getLabel());
37 | }
38 | }
39 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/AbstractDistribution.java:
--------------------------------------------------------------------------------
1 | package dist;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 | * An abstract distribution
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public abstract class AbstractDistribution implements Distribution {
11 |
12 | /**
13 | * @see dist.Distribution#logp(shared.Instance)
14 | */
15 | public double logp(Instance i) {
16 | double p = p(i);
17 | double logp = Math.log(p);
18 | if (Double.isInfinite(logp)) {
19 | return -Double.MAX_VALUE;
20 | }
21 | return logp;
22 | }
23 |
24 | /**
25 | * Get an unconditional sample
26 | * @return the unconditional sample
27 | */
28 | public Instance sample() {
29 | return sample(null);
30 | }
31 |
32 | /**
33 | * Get an unconditional sample
34 | * @return the unconditional sample
35 | */
36 | public Instance mode() {
37 | return mode(null);
38 | }
39 |
40 | }
41 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/ConditionalDistribution.java:
--------------------------------------------------------------------------------
1 | package dist;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 | * A conditional probability distribution
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public interface ConditionalDistribution extends Distribution {
11 |
12 | /**
13 | * Get the distribution for an instance
14 | * @param i the instance
15 | * @return the distribution
16 | */
17 | public Distribution distributionFor(Instance i);
18 |
19 | }
20 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/DiscretePermutationDistribution.java:
--------------------------------------------------------------------------------
1 | package dist;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import util.ABAGAILArrays;
6 |
7 | /**
8 | * A distribution of all of the permutations
9 | * of a set size.
10 | * @author Andrew Guillory gtg008g@mail.gatech.edu
11 | * @version 1.0
12 | */
13 | public class DiscretePermutationDistribution extends AbstractDistribution {
14 | /**
15 | * The size of the data
16 | */
17 | private int n;
18 |
19 | /**
20 | * The probability
21 | */
22 | private double p;
23 |
24 | /**
25 | * Make a new discrete permutation distribution
26 | * @param n the size of the data
27 | */
28 | public DiscretePermutationDistribution(int n) {
29 | this.n = n;
30 | p = n;
31 | for (int i = n - 1; i >= 1; i--) {
32 | p *= i;
33 | }
34 | p = 1 / p;
35 | }
36 |
37 | /**
38 | * @see dist.Distribution#probabilityOf(shared.Instance)
39 | */
40 | public double p(Instance i) {
41 | return p;
42 | }
43 |
44 | /**
45 | * @see dist.Distribution#generateRandom(shared.Instance)
46 | */
47 | public Instance sample(Instance ignored) {
48 | double[] d = ABAGAILArrays.dindices(n);
49 | ABAGAILArrays.permute(d);
50 | return new Instance(d);
51 | }
52 |
53 | /**
54 | * @see dist.Distribution#generateMostLikely(shared.Instance)
55 | */
56 | public Instance mode(Instance ignored) {
57 | return sample(ignored);
58 | }
59 |
60 | /**
61 | * @see dist.Distribution#estimate(shared.DataSet)
62 | */
63 | public void estimate(DataSet observations) {
64 | return;
65 | }
66 | }
67 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/DiscreteUniformDistribution.java:
--------------------------------------------------------------------------------
1 | package dist;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 |
6 | /**
7 | * A distribution of all of the permutations
8 | * of a set size.
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class DiscreteUniformDistribution extends AbstractDistribution {
13 | /**
14 | * The ranges of the data
15 | */
16 | private int[] n;
17 |
18 | /**
19 | * The probability
20 | */
21 | private double p;
22 |
23 | /**
24 | * Make a new discrete permutation distribution
25 | * @param n the size of the data
26 | */
27 | public DiscreteUniformDistribution(int[] n) {
28 | this.n = n;
29 | p = n[0];
30 | for (int i = 1; i < n.length; i++) {
31 | p *= n[i];
32 | }
33 | p = 1 / p;
34 | }
35 |
36 | /**
37 | * @see dist.Distribution#probabilityOf(shared.Instance)
38 | */
39 | public double p(Instance i) {
40 | return p;
41 | }
42 |
43 | /**
44 | * @see dist.Distribution#generateRandom(shared.Instance)
45 | */
46 | public Instance sample(Instance ignored) {
47 | double[] d = new double[n.length];
48 | for (int i = 0; i < d.length; i++) {
49 | d[i] = random.nextInt(n[i]);
50 | }
51 | return new Instance(d);
52 | }
53 |
54 | /**
55 | * @see dist.Distribution#generateMostLikely(shared.Instance)
56 | */
57 | public Instance mode(Instance ignored) {
58 | return sample(ignored);
59 | }
60 |
61 | /**
62 | * @see dist.Distribution#estimate(shared.DataSet)
63 | */
64 | public void estimate(DataSet observations) {
65 | return;
66 | }
67 | }
68 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/Distribution.java:
--------------------------------------------------------------------------------
1 | package dist;
2 |
3 | import java.io.Serializable;
4 | import java.util.Random;
5 |
6 | import shared.DataSet;
7 | import shared.Instance;
8 |
9 | /**
10 | * A interface for distributions
11 | * @author Andrew Guillory gtg008g@mail.gatech.edu
12 | * @version 1.0
13 | */
14 | public interface Distribution extends Serializable {
15 | /**
16 | * A random number generator
17 | */
18 | public static final Random random = new Random();
19 | /**
20 | * Get the probability of i
21 | * @param i the discrete value to get the probability of
22 | * @return the probability of i
23 | */
24 | public abstract double p(Instance i);
25 | /**
26 | * Calculate the log likelihood
27 | * @param i the instance
28 | * @return the log likelihood
29 | */
30 | public abstract double logp(Instance i);
31 |
32 | /**
33 | * Generate a random value
34 | * @param i the conditional values or null
35 | * @return the value
36 | */
37 | public abstract Instance sample(Instance i);
38 |
39 | /**
40 | * Generate a random value
41 | * @return the value
42 | */
43 | public abstract Instance sample();
44 |
45 | /**
46 | * Get the mode of the distribution
47 | * @param i the instance
48 | * @return the mode
49 | */
50 | public abstract Instance mode(Instance i);
51 |
52 | /**
53 | * Get the mode of the distribution
54 | * @return the mode
55 | */
56 | public abstract Instance mode();
57 |
58 | /**
59 | * Estimate the distribution from data
60 | * @param set the data set to estimate from
61 | */
62 | public abstract void estimate(DataSet set);
63 |
64 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/hmm/StateDistribution.java:
--------------------------------------------------------------------------------
1 | package dist.hmm;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 |
6 |
7 | /**
8 | * An interface for state probalility functions
9 | * that represent the probabilty of transitioning to
10 | * a state and also the probability of starting in a state
11 | * @author Andrew Guillory gtg008g@mail.gatech.edu
12 | * @version 1.0
13 | */
14 | public interface StateDistribution {
15 |
16 | /**
17 | * Get the probability of the next state
18 | * @param nextState the next state
19 | * @param observ the observation
20 | * @return the probability
21 | */
22 | public abstract double p(int nextState, Instance observ);
23 |
24 | /**
25 | * Generate the next state
26 | * @param o the observation
27 | * @return the next state
28 | */
29 | public abstract int generateRandomState(Instance o);
30 |
31 | /**
32 | * Generate the most likely next state
33 | * @param o the observation
34 | * @return the next state
35 | */
36 | public abstract int mostLikelyState(Instance o);
37 |
38 | /**
39 | * Match the given expectations and observations
40 | * @param expectations entry [k][j] is the probability of transitioning
41 | * from this state to state j correpsonding to observation k, k can be
42 | * seen as kind of like t all though it is not in practice since
43 | * observations is many sequences glued together
44 | * @param sequence the sequence of corresponding observations
45 | */
46 | public abstract void estimate(double[][] expectations, DataSet sequence);
47 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/test/DiscreteDependencyTreeTest.java:
--------------------------------------------------------------------------------
1 | package dist.test;
2 |
3 | import dist.DiscreteDependencyTree;
4 | import shared.DataSet;
5 | import shared.Instance;
6 |
7 | /**
8 | * A test of the discrete dependency tree distribution
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class DiscreteDependencyTreeTest {
13 | /**
14 | * The test data
15 | */
16 | private static Instance[] data = new Instance[] {
17 | new Instance(new double[] { 0, 4, 4 , 4}),
18 | new Instance(new double[] { 4, 0, 1 , 0}),
19 | new Instance(new double[] { 4, 1, 0 , 1}),
20 | new Instance(new double[] { 4, 0, 0 , 0}),
21 | };
22 | /**
23 | * The test main
24 | * @param args ignored
25 | */
26 | public static void main(String[] args) {
27 | DataSet dataSet = new DataSet(data);
28 | DiscreteDependencyTree ddtd =
29 | new DiscreteDependencyTree(.001);
30 | ddtd.estimate(dataSet);
31 | System.out.println(ddtd);
32 | for (int i = 0; i < 20; i++) {
33 | System.out.println(ddtd.sample(null));
34 | }
35 | System.out.println("Most likely");
36 | System.out.println(ddtd.mode(null));
37 | System.out.println("Probabilities of training data");
38 | for (int i = 0; i < data.length; i++) {
39 | System.out.println(ddtd.p(data[i]));
40 | }
41 | }
42 |
43 | }
44 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/test/DiscreteDistributionTest.java:
--------------------------------------------------------------------------------
1 | package dist.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import dist.DiscreteDistribution;
6 |
7 | /**
8 | * A multinomial distribution test
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class DiscreteDistributionTest {
13 | /**
14 | * Test main
15 | * @param args
16 | */
17 | public static void main(String[] args) {
18 | double[] ps = new double[] {
19 | .1, .3, .2, .4
20 | };
21 | DiscreteDistribution md = new DiscreteDistribution(ps);
22 | Instance[] samples = new Instance[10000];
23 | for (int i = 0; i < samples.length; i++) {
24 | samples[i] = md.sample();
25 | }
26 | md.estimate(new DataSet(samples));
27 | System.out.println(md);
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/dist/test/MultivariateGaussianTest.java:
--------------------------------------------------------------------------------
1 | package dist.test;
2 |
3 | import dist.MultivariateGaussian;
4 | import shared.DataSet;
5 | import shared.Instance;
6 | import util.linalg.DenseVector;
7 | import util.linalg.RectangularMatrix;
8 |
9 | /**
10 | * Testing
11 | * @author Andrew Guillory gtg008g@mail.gatech.edu
12 | * @version 1.0
13 | */
14 | public class MultivariateGaussianTest {
15 |
16 | /**
17 | * The test main
18 | * @param args ignored
19 | */
20 | public static void main(String[] args) {
21 | Instance[] instances = new Instance[20];
22 | MultivariateGaussian mga = new MultivariateGaussian(new DenseVector(new double[] {100, 100, 100}), RectangularMatrix.eye(3).times(.01));
23 | for (int i = 0; i < instances.length; i++) {
24 | instances[i] = mga.sample();
25 | System.out.println(instances[i]);
26 | }
27 |
28 | DataSet set = new DataSet(instances);
29 | MultivariateGaussian mg = new MultivariateGaussian();
30 | mg.estimate(set);
31 | System.out.println(mg);
32 | System.out.println("Most likely " + mg.mode(null));
33 | for (int i = 0; i < 10; i++) {
34 | System.out.println(mg.sample(null));
35 | }
36 | for (int i = 0; i < instances.length; i++) {
37 | System.out.println("Probability of \n" + instances[i]
38 | + "\n " + mg.p(instances[i]));
39 | }
40 | }
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/FunctionApproximater.java:
--------------------------------------------------------------------------------
1 | package func;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 |
6 | /**
7 | *
8 | * @author Andrew Guillory gtg008g@mail.gatech.edu
9 | * @version 1.0
10 | */
11 | public interface FunctionApproximater {
12 |
13 | /**
14 | * Estimate from the given data set
15 | * @param set the data set
16 | */
17 | public void estimate(DataSet set);
18 |
19 | /**
20 | * Evaluate the function
21 | * @param i the input
22 | * @return the value
23 | */
24 | public Instance value(Instance i);
25 |
26 | }
27 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/FunctionApproximaterSupplier.java:
--------------------------------------------------------------------------------
1 | package func;
2 |
3 | /**
4 | * A supplier for creating {@link FunctionApproximater} instances used to
5 | * pass customized classifiers into {@link AdaBoostClassifier}
6 | */
7 | public interface FunctionApproximaterSupplier {
8 | FunctionApproximater get();
9 | }
10 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/dtree/BinaryDecisionTreeSplit.java:
--------------------------------------------------------------------------------
1 | package func.dtree;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 | * A standard decision tree split
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class BinaryDecisionTreeSplit extends DecisionTreeSplit {
11 |
12 | /**
13 | * The attribute being split on
14 | */
15 | private int attribute;
16 |
17 | /**
18 | * The splitting value
19 | */
20 | private int value;
21 |
22 | /**
23 | * Create a new binary decision tree split
24 | * @param attribute the attribute being split on
25 | * @param value the value split on
26 | */
27 | public BinaryDecisionTreeSplit(int attribute,int value) {
28 | this.attribute = attribute;
29 | this.value = value;
30 | }
31 |
32 | /**
33 | * @see dtrees.DecisionTreeSplit#getNumberOfBranches()
34 | */
35 | public int getNumberOfBranches() {
36 | return 2;
37 | }
38 |
39 |
40 | /**
41 | * @see dtree.DecisionTreeSplit#getBranchOf(shared.Instance)
42 | */
43 | public int getBranchOf(Instance i) {
44 | if (i.getDiscrete(attribute) == value) {
45 | return 0;
46 | } else {
47 | return 1;
48 | }
49 | }
50 |
51 | /**
52 | * @see java.lang.Object#toString()
53 | */
54 | public String toString() {
55 | return "attribute " + attribute + " == " + value;
56 | }
57 |
58 | }
59 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/dtree/DecisionTreeSplit.java:
--------------------------------------------------------------------------------
1 | package func.dtree;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 | * A split in a decision tree
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public abstract class DecisionTreeSplit {
11 |
12 | /**
13 | * Get the number of branches in this split
14 | * @return the number of branches
15 | */
16 | public abstract int getNumberOfBranches();
17 |
18 | /**
19 | * Get the branch of the given data
20 | * @param d the data
21 | * @return the branch
22 | */
23 | public abstract int getBranchOf(Instance i);
24 | }
25 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/dtree/GINISplitEvaluator.java:
--------------------------------------------------------------------------------
1 | package func.dtree;
2 |
3 | /**
4 | * A splitting criteria using GINI index
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class GINISplitEvaluator extends SplitEvaluator {
9 |
10 | /**
11 | * Calculate the GINI of an array of class probabilites
12 | * @param classProbabilities the probabilites
13 | * @return the GINI value
14 | */
15 | private double gini(double[] classProbabilities) {
16 | double gini = 1;
17 | for (int i = 0; i < classProbabilities.length; i++) {
18 | gini -= classProbabilities[i] * classProbabilities[i];
19 | }
20 | return gini;
21 | }
22 |
23 | /**
24 | * @see dtrees.SplitEvaluator#splitValue(dtrees.DecisionTreeSplitStatistics)
25 | */
26 | public double splitValue(DecisionTreeSplitStatistics stats) {
27 | double giniIndex = 0;
28 | for (int i = 0; i < stats.getBranchCount(); i++) {
29 | giniIndex += stats.getBranchProbability(i) *
30 | gini(stats.getConditionalClassProbabilities(i));
31 | }
32 | // we want to minimize the gini index
33 | return 1/giniIndex;
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/dtree/InformationGainSplitEvaluator.java:
--------------------------------------------------------------------------------
1 | package func.dtree;
2 |
3 | /**
4 | * A splitting criteria that uses information gain as a basis
5 | * for deciding the value of a split
6 | * @author Andrew Guillory gtg008g@mail.gatech.edu
7 | * @version 1.0
8 | */
9 | public class InformationGainSplitEvaluator extends SplitEvaluator {
10 |
11 | /**
12 | * The log of 2
13 | */
14 | private static final double LOG2 = Math.log(2);
15 |
16 | /**
17 | * Calculate the entropy of an array of class probabilites
18 | * @param classProbabilities the probabilites
19 | * @return the entropy
20 | */
21 | private double entropy(double[] classProbabilities) {
22 | double entropy = 0;
23 | for (int i = 0; i < classProbabilities.length; i++) {
24 | if (classProbabilities[i] != 0)
25 | entropy -= classProbabilities[i]
26 | * Math.log(classProbabilities[i]) / LOG2;
27 | }
28 | return entropy;
29 | }
30 |
31 | /**
32 | * @see dtrees.SplitEvaluator#splitValue(dtrees.DecisionTreeSplitStatistics)
33 | */
34 | public double splitValue(DecisionTreeSplitStatistics stats) {
35 | // the entropy before splitting
36 | double initialEntropy = entropy(stats.getClassProbabilities());
37 | // and now after
38 | double conditionalEntropy = 0;
39 | for (int i = 0; i < stats.getBranchCount(); i++) {
40 | conditionalEntropy += stats.getBranchProbability(i) *
41 | entropy(stats.getConditionalClassProbabilities(i));
42 | }
43 | // the information gain is just initial minus conditional
44 | return initialEntropy - conditionalEntropy;
45 | }
46 |
47 |
48 |
49 | }
50 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/dtree/PruningCriteria.java:
--------------------------------------------------------------------------------
1 | package func.dtree;
2 |
3 | /**
4 | * A class for deciding whether or not to prune a node
5 | */
6 | public abstract class PruningCriteria {
7 |
8 |
9 | /**
10 | * Decide whether or not to prune based a node
11 | * @param stats the stats of the node
12 | * @return true if we should prune
13 | */
14 | public abstract boolean shouldPrune(DecisionTreeSplitStatistics stats);
15 | }
16 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/dtree/SplitEvaluator.java:
--------------------------------------------------------------------------------
1 | package func.dtree;
2 |
3 | /**
4 | * A criteria for splitting in a decision tree
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public abstract class SplitEvaluator {
9 |
10 | /**
11 | * Get the value of splitting a set of instances
12 | * along the given attribute
13 | * @param stats the statistics for splitting
14 | * @return the value
15 | */
16 | public abstract double splitValue(DecisionTreeSplitStatistics stats);
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/dtree/StandardDecisionTreeSplit.java:
--------------------------------------------------------------------------------
1 | package func.dtree;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 | * A standard decision tree split
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class StandardDecisionTreeSplit extends DecisionTreeSplit {
11 |
12 | /**
13 | * The attribute being split on
14 | */
15 | private int attribute;
16 |
17 | /**
18 | * The range of attributes for the split
19 | */
20 | private int attributeRange;
21 |
22 | /**
23 | * Create a new standard decision tree split
24 | * @param attribute the attribute being split on
25 | * @param attributeRange the range of attributs
26 | */
27 | public StandardDecisionTreeSplit(int attribute, int attributeRange) {
28 | this.attribute = attribute;
29 | this.attributeRange = attributeRange;
30 | }
31 |
32 | /**
33 | * @see dtrees.DecisionTreeSplit#getNumberOfBranches()
34 | */
35 | public int getNumberOfBranches() {
36 | return attributeRange;
37 | }
38 |
39 | /**
40 | * @see dtree.DecisionTreeSplit#getBranchOf(shared.Instance)
41 | */
42 | public int getBranchOf(Instance data) {
43 | return data.getDiscrete(attribute);
44 | }
45 |
46 | /**
47 | * @see java.lang.Object#toString()
48 | */
49 | public String toString() {
50 | return "attribute " + attribute;
51 | }
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/NetworkTrainer.java:
--------------------------------------------------------------------------------
1 | package func.nn;
2 |
3 | import shared.DataSet;
4 | import shared.ErrorMeasure;
5 | import shared.Trainer;
6 |
7 | /**
8 | * A class that represents a trainer for
9 | * a neural network
10 | * @author Andrew Guillory gtg008g@mail.gatech.edu
11 | * @version 1.0
12 | */
13 | public abstract class NetworkTrainer implements Trainer {
14 |
15 | /**
16 | * The patterns that are being trained on
17 | */
18 | private DataSet patterns;
19 |
20 | /**
21 | * The network being trained
22 | */
23 | private NeuralNetwork network;
24 |
25 | /**
26 | * The error measure to use in training
27 | */
28 | private ErrorMeasure errorMeasure;
29 |
30 | /**
31 | * Make a new network trainer
32 | * @param patterns the patterns
33 | * @param network the network
34 | */
35 | public NetworkTrainer(DataSet patterns, NeuralNetwork network,
36 | ErrorMeasure errorMeasure) {
37 | this.patterns = patterns;
38 | this.network = network;
39 | this.errorMeasure = errorMeasure;
40 | }
41 |
42 | /**
43 | * Get the network
44 | * @return the network
45 | */
46 | public NeuralNetwork getNetwork() {
47 | return network;
48 | }
49 |
50 | /**
51 | * Get the error measure to use when training
52 | * @return the error measure
53 | */
54 | public ErrorMeasure getErrorMeasure() {
55 | return errorMeasure;
56 | }
57 |
58 | /**
59 | * Get the patterns
60 | * @return the pattern
61 | */
62 | public DataSet getDataSet() {
63 | return patterns;
64 | }
65 |
66 |
67 | }
68 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/activation/ActivationFunction.java:
--------------------------------------------------------------------------------
1 | package func.nn.activation;
2 |
3 | import java.io.Serializable;
4 |
/**
 * An activation function
 * @author Andrew Guillory gtg008g@mail.gatech.edu
 * @version 1.0
 */
public abstract class ActivationFunction implements Serializable {
    /**
     * Activation of a value
     * @param value the input value
     * @return the activation
     */
    public abstract double value(double value);
}
18 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/activation/DifferentiableActivationFunction.java:
--------------------------------------------------------------------------------
1 |
2 |
3 | package func.nn.activation;
4 | /**
5 | * A activation function that is differentiable
6 | * @author Andrew Guillory gtg008g@mail.gatech.edu
7 | * @version 1.0
8 | */
9 | public abstract class DifferentiableActivationFunction extends ActivationFunction {
10 |
11 | /**
12 | * Perform the derivative of this function on the given value
13 | * @param value the value to perform the derivative on
14 | * @return the result
15 | */
16 | public abstract double derivative(double value);
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/activation/HyperbolicTangentSigmoid.java:
--------------------------------------------------------------------------------
1 | package func.nn.activation;
2 |
3 | /**
4 | * The tanh sigmoid function
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class HyperbolicTangentSigmoid
9 | extends DifferentiableActivationFunction{
10 |
11 | /**
12 | * @see nn.function.DifferentiableActivationFunction#derivative(double)
13 | */
14 | public double derivative(double value) {
15 | double tanhvalue = value(value);
16 | return 1 - tanhvalue * tanhvalue;
17 | }
18 |
19 | /**
20 | * @see nn.function.ActivationFunction#activation(double)
21 | */
22 | public double value(double value) {
23 | double e2x = Math.exp(2 * value);
24 | if (e2x == Double.POSITIVE_INFINITY) {
25 | return 1;
26 | } else {
27 | return (e2x - 1) / (e2x + 1);
28 | }
29 | }
30 |
31 |
32 | }
33 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/activation/LinearActivationFunction.java:
--------------------------------------------------------------------------------
1 | package func.nn.activation;
2 |
3 | /**
4 | * A linear activation function
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class LinearActivationFunction extends DifferentiableActivationFunction {
9 |
10 | /**
11 | * @see nn.function.DifferentiableActivationFunction#derivative(double)
12 | */
13 | public double derivative(double value) {
14 | return 1;
15 | }
16 |
17 | /**
18 | * @see nn.function.ActivationFunction#activation(double)
19 | */
20 | public double value(double value) {
21 | return value;
22 | }
23 |
24 | }
25 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/activation/LogisticSigmoid.java:
--------------------------------------------------------------------------------
1 |
2 |
3 | package func.nn.activation;
4 | /**
5 | * A sigmoid activation function
6 | * @author Andrew Guillory gtg008g@mail.gatech.edu
7 | * @version 1.0
8 | */
9 | public class LogisticSigmoid extends DifferentiableActivationFunction {
10 |
11 | /**
12 | * @see nn.function.ActivationFunction#activation(double)
13 | */
14 | public double value(double value) {
15 | double enx = Math.exp(-value);
16 | if (enx == Double.POSITIVE_INFINITY) {
17 | return 0;
18 | } else {
19 | return 1.0 / (1.0 + enx);
20 | }
21 | }
22 |
23 | /**
24 | * @see nn.function.DifferentiableActivationFunction#derivative(double)
25 | */
26 | public double derivative(double value) {
27 | double logistic = value(value);
28 | return logistic * (1 - logistic);
29 | }
30 |
31 | }
32 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/activation/RELU.java:
--------------------------------------------------------------------------------
1 | package func.nn.activation;
2 |
3 | /**
4 | * The tanh sigmoid function
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class RELU
9 | extends DifferentiableActivationFunction{
10 |
11 | /**
12 | * @see nn.function.DifferentiableActivationFunction#derivative(double)
13 | */
14 | public double derivative(double value) {
15 | if (value < 0){
16 | return 0;
17 | } else {
18 | return 1;
19 | }
20 | }
21 |
22 | /**
23 | * @see nn.function.ActivationFunction#activation(double)
24 | */
25 | public double value(double value) {
26 |
27 | if (value < 0) {
28 | return 0;
29 | } else {
30 | return value;
31 | }
32 | }
33 |
34 |
35 | }
36 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/backprop/BackPropagationBiasNode.java:
--------------------------------------------------------------------------------
1 | package func.nn.backprop;
2 |
3 | import func.nn.Neuron;
4 |
5 | /**
6 | * A bias node, implemented as a node
7 | * that refuses to feed forward values
8 | * or backpropagate values. This is a little
9 | * wasteful since as it is used it has useless
10 | * links that go into it.
11 | * @author Andrew Guillory gtg008g@mail.gatech.edu
12 | * @version 1.0
13 | */
14 | public class BackPropagationBiasNode extends BackPropagationNode {
15 |
16 | /**
17 | * A bias node
18 | * @param bias the bias value to set to
19 | */
20 | public BackPropagationBiasNode(double bias) {
21 | super(null);
22 | setActivation(bias);
23 | }
24 |
25 | /**
26 | * @see func.nn.feedfwd.FeedForwardNode#feedforward()
27 | */
28 | public void feedforward() { }
29 |
30 | /**
31 | * @see func.nn.backprop.BackPropagationNode#backpropagate()
32 | */
33 | public void backpropagate() { }
34 |
35 | /**
36 | * Bias node should not be connected to other bias nodes
37 | * @param neuron other neuron to connect to
38 | */
39 | @Override
40 | public void connect(Neuron neuron) {
41 | if (!neuron.getClass().equals(BackPropagationBiasNode.class)) {
42 | super.connect(neuron);
43 | }
44 | }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/backprop/BackPropagationLayer.java:
--------------------------------------------------------------------------------
1 | package func.nn.backprop;
2 |
3 | import func.nn.feedfwd.FeedForwardLayer;
4 |
5 | /**
6 | * A layer in a backpropagation network
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class BackPropagationLayer extends FeedForwardLayer {
11 |
12 | /**
13 | * Back propagate all the error values for this
14 | * layer.
15 | */
16 | public void backpropagate() {
17 | for (int i = 0; i < getNodeCount(); i++) {
18 | BackPropagationNode node =
19 | (BackPropagationNode) getNode(i);
20 | node.backpropagate();
21 | node.backpropagateLinks();
22 | }
23 | }
24 |
25 | /**
26 | * Clear out the error derivatives in the weights
27 | */
28 | public void clearError() {
29 | for (int i = 0; i < getNodeCount(); i++) {
30 | ((BackPropagationNode) getNode(i)).clearError();
31 | }
32 | }
33 |
34 | /**
35 | * Update weights with the given rule
36 | * @param rule the rule to use
37 | */
38 | public void updateWeights(WeightUpdateRule rule) {
39 | for (int i = 0; i < getNodeCount(); i++) {
40 | ((BackPropagationNode) getNode(i)).updateWeights(rule);
41 | }
42 | }
43 |
44 | /**
45 | * Set the output errors for this layer
46 | * @param errors the output errors
47 | */
48 | public void setOutputErrors(double[] errors) {
49 | for (int i = 0; i < getNodeCount(); i++) {
50 | ((BackPropagationNode) getNode(i)).setOutputError(errors[i]);
51 | }
52 | }
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/backprop/BackPropagationNetwork.java:
--------------------------------------------------------------------------------
1 |
2 |
3 | package func.nn.backprop;
4 |
5 | import func.nn.feedfwd.FeedForwardNetwork;
6 |
7 | /**
8 | * A back propagation network
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class BackPropagationNetwork extends FeedForwardNetwork {
13 |
14 | /**
15 | * Backpropagte through the network.
16 | */
17 | public void backpropagate() {
18 | ((BackPropagationLayer) getOutputLayer()).backpropagate();
19 | for (int i = getHiddenLayerCount() - 1; i >= 0; i--) {
20 | ((BackPropagationLayer) getHiddenLayer(i)).backpropagate();;
21 | }
22 | }
23 |
24 | /**
25 | * Clear out the error values at the end of a batch
26 | * or at the end of a single training for
27 | * stochastic / online training
28 | */
29 | public void clearError() {
30 | ((BackPropagationLayer) getOutputLayer()).clearError();
31 | for (int i = getHiddenLayerCount() - 1; i >= 0; i--) {
32 | ((BackPropagationLayer) getHiddenLayer(i)).clearError();;
33 | }
34 | }
35 |
36 | /**
37 | * Update weights with the given rule
38 | * @param rule the rule to use to update weights
39 | */
40 | public void updateWeights(WeightUpdateRule rule) {
41 | ((BackPropagationLayer) getOutputLayer()).updateWeights(rule);
42 | for (int i = getHiddenLayerCount() - 1; i >= 0; i--) {
43 | ((BackPropagationLayer) getHiddenLayer(i)).updateWeights(rule);;
44 | }
45 | }
46 |
47 | /**
48 | * Set the output errors
49 | * @param errors the output errors
50 | */
51 | public void setOutputErrors(double[] errors) {
52 | ((BackPropagationLayer) getOutputLayer()).setOutputErrors(errors);
53 | }
54 |
55 |
56 |
57 |
58 | }
59 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/backprop/BackPropagationSoftMaxOutputLayer.java:
--------------------------------------------------------------------------------
1 | package func.nn.backprop;
2 |
3 | /**
4 | * A soft max layer in a back propagation network
5 | * that can be used with a standard error measure
6 | * for multi class probability in the output layer
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class BackPropagationSoftMaxOutputLayer
11 | extends BackPropagationLayer {
12 |
13 |
14 | /**
15 | * @see nn.FeedForwardLayer#feedforward()
16 | */
17 | public void feedforward() {
18 | // feed forward to calculate
19 | // the weighted input sums
20 | super.feedforward();
21 | // trick stolen from Torch library for preventing overflows
22 | double shift = ((BackPropagationNode) getNode(0)).getWeightedInputSum();
23 | for (int i = 0; i < getNodeCount(); i++) {
24 | BackPropagationNode node =
25 | (BackPropagationNode) getNode(i);
26 | shift = Math.max(shift, node.getWeightedInputSum());
27 | }
28 | // now override the activation values
29 | // by caculating it ourselves
30 | // with the softmax formula
31 | double sum = 0;
32 | for (int i = 0; i < getNodeCount(); i++) {
33 | BackPropagationNode node =
34 | (BackPropagationNode) getNode(i);
35 | node.setActivation(
36 | Math.exp(node.getWeightedInputSum() - shift));
37 | sum += node.getActivation();
38 | }
39 | for (int i = 0; i < getNodeCount(); i++) {
40 | BackPropagationNode node =
41 | (BackPropagationNode) getNode(i);
42 | node.setActivation(node.getActivation() / sum);
43 | }
44 | }
45 | }
46 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/backprop/BasicUpdateRule.java:
--------------------------------------------------------------------------------
1 | package func.nn.backprop;
2 |
3 | /**
4 | * Very basic update rule with no momentum
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class BasicUpdateRule extends WeightUpdateRule {
9 | /**
10 | * The learning rate to use
11 | */
12 | private double learningRate;
13 |
14 |
15 | /**
16 | * Create a new basic update rule
17 | * @param learningRate the learning rate
18 | */
19 | public BasicUpdateRule(double learningRate) {
20 | this.learningRate = learningRate;
21 | }
22 |
23 | /**
24 | * Create a new basic update rule
25 | */
26 | public BasicUpdateRule() {
27 | this(.01);
28 | }
29 |
30 | /**
31 | * @see nn.backprop.BackPropagationUpdateRule#update(nn.backprop.BackPropagationLink)
32 | */
33 | public void update(BackPropagationLink link) {
34 | link.changeWeight(-learningRate * link.getError());
35 | }
36 |
37 | }
38 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/backprop/QuickpropUpdateRule.java:
--------------------------------------------------------------------------------
1 | package func.nn.backprop;
2 |
3 | /**
4 | * An update rule for the Quickprop algorithm
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class QuickpropUpdateRule extends WeightUpdateRule {
9 |
10 | /**
11 | * The learning rate
12 | */
13 | private double learningRate;
14 |
15 | /**
16 | * Make a new quickprop update rule
17 | * @param learningRate the learning rate
18 | */
19 | public QuickpropUpdateRule(double learningRate) {
20 | this.learningRate = learningRate;
21 | }
22 |
23 | /**
24 | * Make a new quickprop update rule
25 | */
26 | public QuickpropUpdateRule() {
27 | this(.2);
28 | }
29 |
30 | /**
31 | * @see nn.backprop.BackPropagationUpdateRule#update(nn.backprop.BackPropagationLink)
32 | */
33 | public void update(BackPropagationLink link) {
34 | if (link.getLastError() == 0) {
35 | // the first run
36 | link.changeWeight(-learningRate * link.getError());
37 | } else {
38 | // jump to parabola min
39 | link.changeWeight(link.getError()
40 | / (link.getLastError() - link.getError())
41 | * link.getLastChange()
42 | - learningRate * link.getError());
43 | }
44 | }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/backprop/StandardUpdateRule.java:
--------------------------------------------------------------------------------
1 | package func.nn.backprop;
2 |
3 | /**
4 | *
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class StandardUpdateRule extends WeightUpdateRule {
9 |
10 | /**
11 | * The learning rate to use
12 | */
13 | private double learningRate;
14 |
15 | /**
16 | * The momentum to use
17 | */
18 | private double momentum;
19 |
20 | /**
21 | * Create a new standard momentum update rule
22 | * @param learningRate the learning rate
23 | * @param momentum the momentum
24 | */
25 | public StandardUpdateRule(double learningRate, double momentum) {
26 | this.momentum = momentum;
27 | this.learningRate = learningRate;
28 | }
29 |
30 | /**
31 | * Create a new standard update rule
32 | */
33 | public StandardUpdateRule() {
34 | this(.2, .9);
35 | }
36 |
37 | /**
38 | * @see nn.backprop.BackPropagationUpdateRule#update(nn.backprop.BackPropagationLink)
39 | */
40 | public void update(BackPropagationLink link) {
41 | link.changeWeight(-learningRate * link.getError()
42 | + link.getLastChange() * momentum);
43 | }
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/backprop/WeightUpdateRule.java:
--------------------------------------------------------------------------------
1 | package func.nn.backprop;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 |  * Strategy for adjusting the weight of a single back propagation link.
7 |  * Subclasses implement the specific update formula.
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public abstract class WeightUpdateRule implements Serializable {
12 |
13 |     /**
14 |      * Apply this rule to one link.
15 |      * @param link the link whose weight is adjusted
16 |      */
17 |     public abstract void update(BackPropagationLink link);
18 |
19 | }
20 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/feedfwd/FeedForwardBiasNode.java:
--------------------------------------------------------------------------------
1 | package func.nn.feedfwd;
2 |
3 | import func.nn.Neuron;
4 |
5 | /**
6 |  * A feed forward node with a constant activation.  It is built without an
7 |  * activation function and does not recompute its activation on feed forward.
8 |  * @author Jesse Rosalia
9 |  * @date 2013-03-05
10 |  */
11 | public class FeedForwardBiasNode extends FeedForwardNode {
12 |
13 |     /**
14 |      * Make a new bias node
15 |      * @param activation the constant activation of this node
16 |      */
17 |     public FeedForwardBiasNode(double activation) {
18 |         super(null);
19 |         super.setActivation(activation);
20 |     }
21 |
22 |     /**
23 |      * Does nothing: a bias node's activation is not recomputed.
24 |      * @see func.nn.feedfwd.FeedForwardNode#feedforward()
25 |      */
26 |     @Override
27 |     public void feedforward() { }
28 |
29 |     /**
30 |      * Connect this node to another neuron, unless that neuron is itself a
31 |      * bias node (including any subclass of FeedForwardBiasNode), in which
32 |      * case the call is ignored.
33 |      * @param neuron other neuron to connect to
34 |      * @throws NullPointerException if neuron is null
35 |      */
36 |     @Override
37 |     public void connect(Neuron neuron) {
38 |         if (neuron == null) {
39 |             // fail fast rather than handing null on to the superclass
40 |             throw new NullPointerException("neuron");
41 |         }
42 |         if (!(neuron instanceof FeedForwardBiasNode)) {
43 |             super.connect(neuron);
44 |         }
45 |     }
46 |
47 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/feedfwd/FeedForwardLayer.java:
--------------------------------------------------------------------------------
1 | package func.nn.feedfwd;
2 |
3 | import func.nn.Layer;
4 |
5 | /**
6 | * A feed forward layer in a neural network
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class FeedForwardLayer extends Layer {
11 |
12 | /**
13 | * Feed foward all of the nodes in this layer.
14 | */
15 | public void feedforward() {
16 | for (int i = 0; i < getNodeCount(); i++) {
17 | ((FeedForwardNode) getNode(i)).feedforward();
18 | }
19 | }
20 |
21 | }
22 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/feedfwd/FeedForwardNetwork.java:
--------------------------------------------------------------------------------
1 | package func.nn.feedfwd;
2 |
3 | import func.nn.LayeredNetwork;
4 |
5 | /**
6 | * A feed forward network
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class FeedForwardNetwork extends LayeredNetwork {
11 |
12 | /**
13 | * @see nn.Network#run()
14 | */
15 | public void run() {
16 | for (int i = 0; i < getHiddenLayerCount(); i++) {
17 | ((FeedForwardLayer) getHiddenLayer(i)).feedforward();
18 | }
19 | ((FeedForwardLayer) getOutputLayer()).feedforward();
20 | }
21 |
22 |
23 | }
24 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/nn/feedfwd/FeedForwardNode.java:
--------------------------------------------------------------------------------
1 |
2 |
3 | package func.nn.feedfwd;
4 |
5 | import func.nn.Neuron;
6 | import func.nn.activation.ActivationFunction;
7 |
8 | /**
9 |  * A neuron in a feed forward network whose activation is an activation
10 |  * function applied to the weighted sum of its inputs.
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class FeedForwardNode extends Neuron {
15 |
16 |     /** The transfer function applied to the weighted input sum */
17 |     private ActivationFunction activationFunction;
18 |
19 |     /** The weighted input sum recorded by the last effective feed forward */
20 |     private double weightedInputSum;
21 |
22 |     /**
23 |      * Make a new feed forward node
24 |      * @param transfer the transfer function
25 |      */
26 |     public FeedForwardNode(ActivationFunction transfer) {
27 |         activationFunction = transfer;
28 |     }
29 |
30 |     /**
31 |      * @return the transfer function of this node
32 |      */
33 |     public ActivationFunction getActivationFunction() {
34 |         return activationFunction;
35 |     }
36 |
37 |     /**
38 |      * @return the weighted input sum from the most recent feed forward
39 |      * performed while this node had incoming links (0 before any)
40 |      */
41 |     public double getWeightedInputSum() {
42 |         return weightedInputSum;
43 |     }
44 |
45 |     /**
46 |      * Recompute this node's activation from its incoming links.  The
47 |      * weighted values of all incoming links are summed and the sum is
48 |      * stored as the weighted input sum.  The activation function's value
49 |      * at that sum then becomes the activation.  A node with no incoming
50 |      * links is left untouched.
51 |      */
52 |     public void feedforward() {
53 |         if (getInLinkCount() <= 0) {
54 |             // nothing feeds this node, keep its current activation
55 |             return;
56 |         }
57 |         double total = 0;
58 |         for (int link = 0; link < getInLinkCount(); link++) {
59 |             total += getInLink(link).getWeightedInValue();
60 |         }
61 |         weightedInputSum = total;
62 |         setActivation(activationFunction.value(total));
63 |     }
64 |
65 | }
66 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/svm/LinearKernel.java:
--------------------------------------------------------------------------------
1 | package func.svm;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 | * A linear support vector machine kernel
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class LinearKernel extends Kernel {
11 |
12 | /**
13 | * @see svm.Kernel#value(shared.Instance, shared.Instance)
14 | */
15 | public double value(Instance a, Instance b) {
16 | return a.getData().dotProduct(b.getData());
17 | }
18 |
19 | /**
20 | * @see java.lang.Object#toString()
21 | */
22 | public String toString() {
23 | return "Linear Kernel";
24 | }
25 |
26 | }
27 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/svm/RBFKernel.java:
--------------------------------------------------------------------------------
1 | package func.svm;
2 |
3 | import shared.Instance;
4 | import util.linalg.Vector;
5 |
6 | /**
7 |  * A radial basis function kernel, k(a, b) = exp(-||a - b||^2 / (2 sigma^2))
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public class RBFKernel extends Kernel {
12 |
13 |     /**
14 |      * The sigma parameter
15 |      */
16 |     private double sigma;
17 |
18 |     /**
19 |      * The gamma value, -1 / (2 sigma^2), precomputed from sigma
20 |      */
21 |     private double gamma;
22 |
23 |     /**
24 |      * Make a new radial basis function kernel
25 |      * @param sigma the sigma value
26 |      */
27 |     public RBFKernel(double sigma) {
28 |         this.sigma = sigma;
29 |         gamma = -1/(2 * sigma * sigma);
30 |     }
31 |
32 |     /**
33 |      * Compute the kernel value of two instances.
34 |      * <p>
35 |      * The squared distance is expanded as a.a + b.b - 2 a.b.  Through
36 |      * floating point cancellation this can come out slightly negative when
37 |      * the two vectors are (nearly) equal.  It is therefore clamped at zero,
38 |      * so for a non-zero sigma the kernel value never exceeds 1.
39 |      * @param a the first instance
40 |      * @param b the second instance
41 |      * @return the kernel value
42 |      */
43 |     public double value(Instance a, Instance b) {
44 |         Vector va = a.getData();
45 |         Vector vb = b.getData();
46 |         double difference = va.dotProduct(va)
47 |             + vb.dotProduct(vb)
48 |             - 2*va.dotProduct(vb);
49 |         // a squared distance cannot be negative; undo cancellation error
50 |         difference = Math.max(0, difference);
51 |         return Math.exp(gamma * difference);
52 |     }
53 |
54 |     /**
55 |      * @see java.lang.Object#toString()
56 |      */
57 |     public String toString() {
58 |         return "RBF Kernel sigma = " + sigma;
59 |     }
60 |
61 | }
62 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/AdaBoostTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import func.AdaBoostClassifier;
6 | import func.DecisionStumpClassifier;
7 |
8 | /**
9 |  * Manual test: boosts decision stumps on a tiny data set and prints the
10 |  * resulting classifier and its prediction for one unlabeled instance.
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class AdaBoostTest {
15 |
16 |     /**
17 |      * Test main
18 |      * @param args ignored
19 |      */
20 |     public static void main(String[] args) {
21 |         Instance[] training = {
22 |             new Instance(new double[] {1,1,0,0,0,0,0,0}, 0),
23 |             new Instance(new double[] {0,0,1,1,1,0,0,0}, 1),
24 |             new Instance(new double[] {0,0,0,0,1,1,1,1}, 0),
25 |             new Instance(new double[] {1,0,0,0,1,0,1,0}, 1),
26 |             new Instance(new double[] {1,1,1,0,1,1,0,0}, 1),
27 |         };
28 |         Instance[] queries = {
29 |             new Instance(new double[] {1,1,1,0,0,0,0,0}),
30 |         };
31 |         DataSet data = new DataSet(training);
32 |         // NOTE(review): 20 is presumably the number of boosting rounds
33 |         AdaBoostClassifier booster =
34 |             new AdaBoostClassifier(20, DecisionStumpClassifier.class);
35 |         booster.estimate(data);
36 |         System.out.println(booster);
37 |         for (Instance query : queries) {
38 |             System.out.println(booster.value(query));
39 |         }
40 |     }
41 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/DecisionStumpTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import func.DecisionStumpClassifier;
6 | import func.dtree.ChiSquarePruningCriteria;
7 | import func.dtree.GINISplitEvaluator;
8 | import func.dtree.InformationGainSplitEvaluator;
9 | import func.dtree.PruningCriteria;
10 | import func.dtree.SplitEvaluator;
11 |
12 | /**
13 |  * Manual test: fits a decision stump using the information gain split
14 |  * evaluator and prints the stump and its predictions for four instances.
15 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
16 |  * @version 1.0
17 |  */
18 | public class DecisionStumpTest {
19 |
20 |     /**
21 |      * Test main
22 |      * @param args ignored
23 |      */
24 |     public static void main(String[] args) {
25 |         Instance[] training = {
26 |             new Instance(new double[] {0, 0, 0, 1}, 1),
27 |             new Instance(new double[] {1, 0, 0, 0}, 1),
28 |             new Instance(new double[] {1, 0, 0, 0}, 1),
29 |             new Instance(new double[] {1, 0, 0, 0}, 1),
30 |             new Instance(new double[] {1, 0, 0, 1}, 0),
31 |             new Instance(new double[] {1, 0, 0, 1}, 0),
32 |             new Instance(new double[] {1, 0, 0, 1}, 0),
33 |             new Instance(new double[] {1, 0, 0, 1}, 0)
34 |         };
35 |         Instance[] queries = {
36 |             new Instance(new double[] {0, 1, 1, 1}),
37 |             new Instance(new double[] {0, 0, 0, 0}),
38 |             new Instance(new double[] {1, 0, 0, 0}),
39 |             new Instance(new double[] {1, 1, 1, 1})
40 |         };
41 |         DataSet data = new DataSet(training);
42 |         // constructed but unused here; presumably alternatives to swap in
43 |         PruningCriteria chiSquare = new ChiSquarePruningCriteria(0);
44 |         SplitEvaluator gini = new GINISplitEvaluator();
45 |         SplitEvaluator infoGain = new InformationGainSplitEvaluator();
46 |         DecisionStumpClassifier stump = new DecisionStumpClassifier(infoGain);
47 |         stump.estimate(data);
48 |         System.out.println(stump);
49 |         for (Instance query : queries) {
50 |             System.out.println(stump.value(query));
51 |         }
52 |     }
53 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/DecisionTreeTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import func.DecisionTreeClassifier;
6 | import func.dtree.ChiSquarePruningCriteria;
7 | import func.dtree.GINISplitEvaluator;
8 | import func.dtree.InformationGainSplitEvaluator;
9 | import func.dtree.PruningCriteria;
10 | import func.dtree.SplitEvaluator;
11 |
12 | /**
13 |  * Manual test: grows a decision tree with the information gain split
14 |  * evaluator and prints the tree and its predictions for four instances.
15 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
16 |  * @version 1.0
17 |  */
18 | public class DecisionTreeTest {
19 |
20 |     /**
21 |      * Test main
22 |      * @param args ignored
23 |      */
24 |     public static void main(String[] args) {
25 |         Instance[] training = {
26 |             new Instance(new double[] {0, 0, 0, 1}, 1),
27 |             new Instance(new double[] {1, 0, 0, 0}, 1),
28 |             new Instance(new double[] {1, 0, 0, 0}, 1),
29 |             new Instance(new double[] {1, 0, 0, 0}, 1),
30 |             new Instance(new double[] {1, 0, 0, 1}, 0),
31 |             new Instance(new double[] {1, 0, 0, 1}, 0),
32 |             new Instance(new double[] {1, 0, 0, 1}, 0),
33 |             new Instance(new double[] {1, 0, 0, 1}, 0)
34 |         };
35 |         Instance[] queries = {
36 |             new Instance(new double[] {0, 1, 1, 1}),
37 |             new Instance(new double[] {0, 0, 0, 0}),
38 |             new Instance(new double[] {1, 0, 0, 0}),
39 |             new Instance(new double[] {1, 1, 1, 1})
40 |         };
41 |         DataSet data = new DataSet(training);
42 |         // constructed but unused here; presumably alternatives to swap in
43 |         PruningCriteria chiSquare = new ChiSquarePruningCriteria(0);
44 |         SplitEvaluator gini = new GINISplitEvaluator();
45 |         SplitEvaluator infoGain = new InformationGainSplitEvaluator();
46 |         // NOTE(review): meaning of the null and true arguments not visible here
47 |         DecisionTreeClassifier tree =
48 |             new DecisionTreeClassifier(infoGain, null, true);
49 |         tree.estimate(data);
50 |         System.out.println(tree);
51 |         for (Instance query : queries) {
52 |             System.out.println(tree.value(query));
53 |         }
54 |     }
55 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/EMClustererTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import dist.Distribution;
4 | import dist.MultivariateGaussian;
5 | import func.EMClusterer;
6 | import shared.DataSet;
7 | import shared.Instance;
8 | import util.linalg.DenseVector;
9 | import util.linalg.RectangularMatrix;
10 |
11 | /**
12 |  * Manual test: samples 100 points, each drawn with equal probability from
13 |  * one of two 3-dimensional Gaussians, runs EM clustering on them and
14 |  * prints the result.
15 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
16 |  * @version 1.0
17 |  */
18 | public class EMClustererTest {
19 |     /**
20 |      * The test main
21 |      * @param args ignored
22 |      * @throws Exception propagated from the project code
23 |      */
24 |     public static void main(String[] args) throws Exception {
25 |         MultivariateGaussian tight = new MultivariateGaussian(
26 |             new DenseVector(new double[] {100, 100, 100}),
27 |             RectangularMatrix.eye(3).times(.01));
28 |         MultivariateGaussian wide = new MultivariateGaussian(
29 |             new DenseVector(new double[] {-1, -1, -1}),
30 |             RectangularMatrix.eye(3).times(10));
31 |         Instance[] points = new Instance[100];
32 |         for (int i = 0; i < points.length; i++) {
33 |             MultivariateGaussian source =
34 |                 Distribution.random.nextBoolean() ? tight : wide;
35 |             points[i] = source.sample(null);
36 |         }
37 |         DataSet data = new DataSet(points);
38 |         EMClusterer em = new EMClusterer();
39 |         em.estimate(data);
40 |         System.out.println(em);
41 |     }
42 | }
43 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/GaussianProcessRegressionTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import func.GaussianProcessRegression;
6 | import func.svm.LinearKernel;
7 |
8 | /**
9 | * Test the class
10 | * @author Andrew Guillory gtg008g@mail.gatech.edu
11 | * @version 1.0
12 | */
13 | public class GaussianProcessRegressionTest {
14 |
15 | /**
16 | * Test main
17 | * @param args ignored
18 | */
19 | public static void main(String[] args) {
20 | Instance[] instances = {
21 | new Instance(new double[] {1}, -1),
22 | new Instance(new double[] {1}, -1),
23 | new Instance(new double[] {1}, -1),
24 | new Instance(new double[] {1}, -1),
25 | new Instance(new double[] {-1}, 1),
26 | new Instance(new double[] {-1}, 1),
27 | new Instance(new double[] {-1}, 1),
28 | new Instance(new double[] {-1}, 1)
29 | };
30 | Instance[] tests = {
31 | new Instance(new double[] {-1}),
32 | new Instance(new double[] {-1}),
33 | new Instance(new double[] {1}),
34 | new Instance(new double[] {1})
35 | };
36 | DataSet set = new DataSet(instances);
37 | GaussianProcessRegression gp = new GaussianProcessRegression(
38 | new LinearKernel(), .01);
39 | gp.estimate(set);
40 | System.out.println(gp);
41 | for (int i = 0; i < tests.length; i++) {
42 | System.out.println(gp.value(tests[i]));
43 | }
44 | }
45 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/KMeansClustererTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import dist.Distribution;
4 | import dist.MultivariateGaussian;
5 | import func.KMeansClusterer;
6 | import shared.DataSet;
7 | import shared.Instance;
8 | import util.linalg.DenseVector;
9 | import util.linalg.RectangularMatrix;
10 |
11 | /**
12 |  * Manual test: samples 100 points, each drawn with equal probability from
13 |  * one of two 3-dimensional Gaussians, runs k-means clustering on them and
14 |  * prints the result.
15 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
16 |  * @version 1.0
17 |  */
18 | public class KMeansClustererTest {
19 |     /**
20 |      * The test main
21 |      * @param args ignored
22 |      * @throws Exception propagated from the project code
23 |      */
24 |     public static void main(String[] args) throws Exception {
25 |         MultivariateGaussian first = new MultivariateGaussian(
26 |             new DenseVector(new double[] {10, 20, 30}),
27 |             RectangularMatrix.eye(3).times(.5));
28 |         MultivariateGaussian second = new MultivariateGaussian(
29 |             new DenseVector(new double[] {-2, -3, -1}),
30 |             RectangularMatrix.eye(3).times(.4));
31 |         Instance[] points = new Instance[100];
32 |         for (int i = 0; i < points.length; i++) {
33 |             MultivariateGaussian source =
34 |                 Distribution.random.nextBoolean() ? first : second;
35 |             points[i] = source.sample(null);
36 |         }
37 |         DataSet data = new DataSet(points);
38 |         KMeansClusterer km = new KMeansClusterer();
39 |         km.estimate(data);
40 |         System.out.println(km);
41 |     }
42 | }
43 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/KNNClassifierTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import java.util.Arrays;
4 |
5 | import shared.DataSet;
6 | import shared.Instance;
7 |
8 | import func.inst.KDTree;
9 |
10 | /**
11 |  * Manual test: builds a KD tree over six points on the diagonal and prints
12 |  * the result of a k = 4 nearest neighbor query for (2, 2).
13 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
14 |  * @version 1.0
15 |  */
16 | public class KNNClassifierTest {
17 |
18 |     /**
19 |      * The main method
20 |      * @param args ignored
21 |      */
22 |     public static void main(String[] args) {
23 |         Instance[] points = {
24 |             new Instance(new double[] { 1 , 1}),
25 |             new Instance(new double[] { 2 , 2}),
26 |             new Instance(new double[] { 3 , 3}),
27 |             new Instance(new double[] { 4 , 4}),
28 |             new Instance(new double[] { 5 , 5}),
29 |             new Instance(new double[] { 6 , 6}),
30 |         };
31 |         KDTree tree = new KDTree(new DataSet(points));
32 |         Instance target = new Instance(new double[] { 2, 2 });
33 |         Instance[] nearest = tree.knn(target, 4);
34 |         System.out.println(Arrays.asList(nearest));
35 |     }
36 |
37 | }
38 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/NNBinaryClassificationTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import shared.ConvergenceTrainer;
4 | import shared.DataSet;
5 | import shared.Instance;
6 | import shared.SumOfSquaresError;
7 | import func.nn.backprop.*;
8 |
9 | /**
10 |  * A small binary classification test: a back propagation classification
11 |  * network is trained with RPROP until convergence.
12 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
13 |  * @version 1.0
14 |  */
15 | public class NNBinaryClassificationTest {
16 |
17 |     /**
18 |      * Train a classification network on a tiny one-feature data set and
19 |      * print, for every pattern, its label and the network's output.
20 |      * @param args ignored
21 |      */
22 |     public static void main(String[] args) {
23 |         BackPropagationNetworkFactory factory =
24 |             new BackPropagationNetworkFactory();
25 |         double[][][] data = {
26 |             { { 0 }, { 0 } },
27 |             { { 0 }, { 1 } },
28 |             { { 0 }, { 1 } },
29 |         };
30 |         Instance[] patterns = new Instance[data.length];
31 |         for (int i = 0; i < patterns.length; i++) {
32 |             patterns[i] = new Instance(data[i][0]);
33 |             patterns[i].setLabel(new Instance(data[i][1]));
34 |         }
35 |         // size the input and output layers from the data so the network
36 |         // always matches the patterns (each pattern here has one feature)
37 |         int inputSize = data[0][0].length;
38 |         int outputSize = data[0][1].length;
39 |         BackPropagationNetwork network = factory.createClassificationNetwork(
40 |             new int[] { inputSize, 2, outputSize });
41 |         DataSet set = new DataSet(patterns);
42 |         ConvergenceTrainer trainer = new ConvergenceTrainer(
43 |             new BatchBackPropagationTrainer(set, network,
44 |                 new SumOfSquaresError(), new RPROPUpdateRule()));
45 |         trainer.train();
46 |         System.out.println("Convergence in "
47 |             + trainer.getIterations() + " iterations");
48 |         for (int i = 0; i < patterns.length; i++) {
49 |             network.setInputValues(patterns[i].getData());
50 |             network.run();
51 |             System.out.println("~~");
52 |             System.out.println(patterns[i].getLabel());
53 |             System.out.println(network.getOutputValues());
54 |         }
55 |     }
56 |
57 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/NNRegressionTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import shared.ConvergenceTrainer;
4 | import shared.DataSet;
5 | import shared.Instance;
6 | import shared.SumOfSquaresError;
7 | import func.nn.backprop.*;
8 |
9 | /**
10 |  * An XOR test: trains a regression network with RPROP on XOR encoded with
11 |  * -1/1 targets and prints the label and output for every pattern.
12 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
13 |  * @version 1.0
14 |  */
15 | public class NNRegressionTest {
16 |
17 |     /**
18 |      * Train the network until convergence and report its outputs
19 |      * @param args ignored
20 |      */
21 |     public static void main(String[] args) {
22 |         BackPropagationNetworkFactory factory =
23 |             new BackPropagationNetworkFactory();
24 |         double[][][] data = {
25 |             { { 1, 1 }, { -1 } },
26 |             { { 1, 0 }, { 1 } },
27 |             { { 0, 1 }, { 1 } },
28 |             { { 0, 0 }, { -1 } }
29 |         };
30 |         Instance[] patterns = new Instance[data.length];
31 |         for (int p = 0; p < data.length; p++) {
32 |             Instance pattern = new Instance(data[p][0]);
33 |             pattern.setLabel(new Instance(data[p][1]));
34 |             patterns[p] = pattern;
35 |         }
36 |         BackPropagationNetwork network =
37 |             factory.createRegressionNetwork(new int[] { 2, 2, 1 });
38 |         BatchBackPropagationTrainer backprop = new BatchBackPropagationTrainer(
39 |             new DataSet(patterns), network,
40 |             new SumOfSquaresError(), new RPROPUpdateRule());
41 |         ConvergenceTrainer trainer = new ConvergenceTrainer(backprop);
42 |         trainer.train();
43 |         System.out.println("Convergence in "
44 |             + trainer.getIterations() + " iterations");
45 |         for (Instance pattern : patterns) {
46 |             network.setInputValues(pattern.getData());
47 |             network.run();
48 |             System.out.println("~~");
49 |             System.out.println(pattern.getLabel());
50 |             System.out.println(network.getOutputValues());
51 |         }
52 |     }
53 |
54 | }
55 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/PruningCriteriaTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import shared.DataSet;
4 | import shared.DataSetDescription;
5 | import shared.Instance;
6 | import func.dtree.DecisionTreeSplit;
7 | import func.dtree.DecisionTreeSplitStatistics;
8 | import func.dtree.ChiSquarePruningCriteria;
9 | import func.dtree.StandardDecisionTreeSplit;
10 |
11 | /**
12 |  * Manual test: for a split on each of the four attributes of a small data
13 |  * set, prints whether a chi-square pruning criterion would prune it.
14 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
15 |  * @version 1.0
16 |  */
17 | public class PruningCriteriaTest {
18 |
19 |     /**
20 |      * Test main
21 |      * @param args ignored
22 |      */
23 |     public static void main(String[] args) {
24 |         Instance[] instances = {
25 |             new Instance(new double[] {0, 0, 0, 1}, 1),
26 |             new Instance(new double[] {1, 0, 0, 0}, 1),
27 |             new Instance(new double[] {1, 0, 0, 0}, 1),
28 |             new Instance(new double[] {1, 0, 0, 0}, 1),
29 |             new Instance(new double[] {1, 0, 0, 1}, 0),
30 |             new Instance(new double[] {1, 0, 0, 1}, 0),
31 |             new Instance(new double[] {1, 0, 0, 1}, 0),
32 |             new Instance(new double[] {1, 0, 0, 1}, 0)
33 |         };
34 |         DataSet set = new DataSet(instances);
35 |         set.setDescription(new DataSetDescription(set));
36 |         ChiSquarePruningCriteria criteria = new ChiSquarePruningCriteria(1);
37 |         for (int attribute = 0; attribute < 4; attribute++) {
38 |             // NOTE(review): the 2 is presumably the number of attribute values
39 |             DecisionTreeSplit split = new StandardDecisionTreeSplit(attribute, 2);
40 |             DecisionTreeSplitStatistics stats =
41 |                 new DecisionTreeSplitStatistics(split, set);
42 |             System.out.println("\nAttribute " + attribute);
43 |             System.out.println("Should prune? " + criteria.shouldPrune(stats));
44 |         }
45 |     }
46 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/func/test/SplitEvaluatorTest.java:
--------------------------------------------------------------------------------
1 | package func.test;
2 |
3 | import shared.DataSet;
4 | import shared.DataSetDescription;
5 | import shared.Instance;
6 | import func.dtree.DecisionTreeSplit;
7 | import func.dtree.DecisionTreeSplitStatistics;
8 | import func.dtree.GINISplitEvaluator;
9 | import func.dtree.InformationGainSplitEvaluator;
10 | import func.dtree.StandardDecisionTreeSplit;
11 |
12 | /**
13 |  * Manual test: prints the information gain and the GINI index of a split
14 |  * on each of the four attributes of a small data set.
15 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
16 |  * @version 1.0
17 |  */
18 | public class SplitEvaluatorTest {
19 |
20 |     /**
21 |      * Test main
22 |      * @param args ignored
23 |      */
24 |     public static void main(String[] args) {
25 |         Instance[] instances = {
26 |             new Instance(new double[] {0, 0, 0, 1}, 1),
27 |             new Instance(new double[] {1, 0, 0, 0}, 1),
28 |             new Instance(new double[] {1, 0, 0, 0}, 1),
29 |             new Instance(new double[] {1, 0, 0, 0}, 1),
30 |             new Instance(new double[] {1, 0, 0, 1}, 0),
31 |             new Instance(new double[] {1, 0, 0, 1}, 0),
32 |             new Instance(new double[] {1, 0, 0, 1}, 0),
33 |             new Instance(new double[] {1, 0, 0, 1}, 0)
34 |         };
35 |         DataSet set = new DataSet(instances);
36 |         set.setDescription(new DataSetDescription(set));
37 |         InformationGainSplitEvaluator infoGain = new InformationGainSplitEvaluator();
38 |         GINISplitEvaluator gini = new GINISplitEvaluator();
39 |         for (int attribute = 0; attribute < 4; attribute++) {
40 |             DecisionTreeSplit split = new StandardDecisionTreeSplit(attribute, 2);
41 |             DecisionTreeSplitStatistics stats =
42 |                 new DecisionTreeSplitStatistics(split, set);
43 |             System.out.println("\nAttribute " + attribute);
44 |             System.out.println("Information gain: " + infoGain.splitValue(stats));
45 |             System.out.println("GINI index: " + gini.splitValue(stats));
46 |         }
47 |     }
48 | }
49 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ContinuousAddOneNeighbor.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import dist.Distribution;
4 |
5 | import shared.Instance;
6 |
7 | /**
8 |  * A neighbor function for continuous data.  One randomly chosen coordinate
9 |  * is shifted by a uniform random offset that, for a positive amount, lies
10 |  * in [-amount/2, amount/2).
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class ContinuousAddOneNeighbor implements NeighborFunction {
15 |     /**
16 |      * The width of the range the random offset is drawn from
17 |      */
18 |     private double amount;
19 |
20 |     /**
21 |      * Make a neighbor function with the default width of 1
22 |      */
23 |     public ContinuousAddOneNeighbor() {
24 |         this(1);
25 |     }
26 |
27 |     /**
28 |      * Make a neighbor function
29 |      * @param amount the width of the random offset range
30 |      */
31 |     public ContinuousAddOneNeighbor(double amount) {
32 |         this.amount = amount;
33 |     }
34 |
35 |     /**
36 |      * Return d.copy() with one coordinate perturbed.  The original is left
37 |      * alone as long as copy() is a deep copy; verify this in Instance.
38 |      * @param d the instance to find a neighbor of
39 |      * @return the perturbed copy
40 |      * @see opt.NeighborFunction#neighbor(Instance)
41 |      */
42 |     public Instance neighbor(Instance d) {
43 |         int index = Distribution.random.nextInt(d.size());
44 |         Instance copy = (Instance) d.copy();
45 |         double offset = Distribution.random.nextDouble() * amount;
46 |         double value = copy.getContinuous(index) + offset - amount / 2;
47 |         copy.getData().set(index, value);
48 |         return copy;
49 |     }
50 | }
51 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/DiscreteChangeOneNeighbor.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import dist.Distribution;
4 |
5 | import shared.Instance;
6 |
7 | /**
8 |  * A neighbor function for discrete data: one randomly chosen coordinate is
9 |  * changed to a different value within its range.
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class DiscreteChangeOneNeighbor implements NeighborFunction {
14 |
15 |     /**
16 |      * The ranges of the different values; coordinate i is drawn from
17 |      * [0, ranges[i])
18 |      */
19 |     private int[] ranges;
20 |
21 |     /**
22 |      * Make a new change one neighbor function
23 |      * @param ranges the ranges of the data
24 |      */
25 |     public DiscreteChangeOneNeighbor(int[] ranges) {
26 |         this.ranges = ranges;
27 |     }
28 |
29 |     /**
30 |      * Return a copy of d in which one randomly chosen coordinate gets a new
31 |      * value.  When that coordinate's range holds more than one value, the
32 |      * new value is guaranteed to differ from the current (rounded) one, so
33 |      * no draw is wasted on an unchanged instance.  A coordinate whose range
34 |      * is one is set to 0.
35 |      * @param d the instance to find a neighbor of
36 |      * @return the neighbor
37 |      * @see opt.NeighborFunction#neighbor(Instance)
38 |      */
39 |     public Instance neighbor(Instance d) {
40 |         Instance cod = (Instance) d.copy();
41 |         int i = Distribution.random.nextInt(ranges.length);
42 |         int range = ranges[i];
43 |         if (range > 1) {
44 |             int current = (int) Math.round(cod.getContinuous(i));
45 |             // draw from the range minus the current value, then skip over it
46 |             int value = Distribution.random.nextInt(range - 1);
47 |             if (value >= current) {
48 |                 value++;
49 |             }
50 |             cod.getData().set(i, value);
51 |         } else {
52 |             cod.getData().set(i, Distribution.random.nextInt(range));
53 |         }
54 |         return cod;
55 |     }
56 |
57 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/EvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * A function that assigns a value to an instance; optimization problems
7 |  * delegate their value computation to it.
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public interface EvaluationFunction {
12 |
13 |     /**
14 |      * Evaluate one instance
15 |      * @param d the instance to evaluate
16 |      * @return the value of d
17 |      */
18 |     double value(Instance d);
19 |
20 | }
21 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/GenericHillClimbingProblem.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import shared.Instance;
4 | import dist.Distribution;
5 |
6 | /**
7 | * A generic hill climbing problem
8 | * @author Andrew Guillory gtg008g@mail.gatech.edu
9 | * @version 1.0
10 | */
11 | public class GenericHillClimbingProblem extends GenericOptimizationProblem implements HillClimbingProblem {
12 |
13 | /**
14 | * The neighbor function
15 | */
16 | private NeighborFunction neigh;
17 |
18 | /**
19 | * Make a new hill climbing problem
20 | * @param eval the evaulation function
21 | * @param dist the initial distribution
22 | * @param neigh the neighbor function
23 | */
24 | public GenericHillClimbingProblem(EvaluationFunction eval, Distribution dist,
25 | NeighborFunction neigh) {
26 | super(eval, dist);
27 | this.neigh = neigh;
28 | }
29 |
30 | /**
31 | * @see opt.HillClimbingProblem#neighbor(opt.OptimizationData)
32 | */
33 | public Instance neighbor(Instance d) {
34 | return neigh.neighbor(d);
35 | }
36 |
37 | }
38 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/GenericOptimizationProblem.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import shared.Instance;
4 | import dist.Distribution;
5 |
6 |
7 | /**
8 | * A generic continuous optimization problem
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class GenericOptimizationProblem implements OptimizationProblem {
13 | /**
14 | * The evaluation function
15 | */
16 | private EvaluationFunction eval;
17 |
18 | /**
19 | * The intial distribution
20 | */
21 | private Distribution initial;
22 |
23 |
24 | /**
25 | * Make a new generic optimization problem
26 | * @param dist the initial distribution
27 | * @param eval the evaluation function
28 | */
29 | public GenericOptimizationProblem(EvaluationFunction eval, Distribution dist) {
30 | this.initial = dist;
31 | this.eval = eval;
32 | }
33 |
34 |
35 | /**
36 | * @see opt.OptimizationProblem#value(opt.OptimizationData)
37 | */
38 | public double value(Instance d) {
39 | return eval.value(d);
40 | }
41 |
42 |
43 | /**
44 | * @see opt.OptimizationProblem#random()
45 | */
46 | public Instance random() {
47 | return initial.sample(null);
48 | }
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/HillClimbingProblem.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * An optimization problem that can be solved through hill climbing, i.e.
7 |  * one that can produce neighbors of an instance.
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public interface HillClimbingProblem extends OptimizationProblem {
12 |
13 |     /**
14 |      * Find a neighbor of the given instance
15 |      * @param d the instance to find the neighbor of
16 |      * @return the neighbor
17 |      */
18 |     Instance neighbor(Instance d);
19 | }
20 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/NeighborFunction.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * A strategy for generating a neighbor of a candidate solution.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public interface NeighborFunction {
11 |
12 |     /**
13 |      * Produce a neighbor of the given candidate.
14 |      * @param d the candidate
15 |      * @return the neighbor
16 |      */
17 |     Instance neighbor(Instance d);
18 | }
19 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/OptimizationAlgorithm.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import shared.Instance;
4 | import shared.Trainer;
5 |
6 | /**
7 |  * Base class for optimization algorithms. Each algorithm is a
8 |  * {@link Trainer} bound to a single {@link OptimizationProblem}.
9 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
10 |  * @version 1.0
11 |  */
12 | public abstract class OptimizationAlgorithm implements Trainer {
13 |
14 |     /** The problem being optimized; fixed at construction. */
15 |     private final OptimizationProblem op;
16 |
17 |     /**
18 |      * Bind a new optimization algorithm to a problem.
19 |      * @param op the problem to optimize
20 |      */
21 |     public OptimizationAlgorithm(OptimizationProblem op) {
22 |         this.op = op;
23 |     }
24 |
25 |     /**
26 |      * @return the problem this algorithm optimizes
27 |      */
28 |     public OptimizationProblem getOptimizationProblem() {
29 |         return op;
30 |     }
31 |
32 |     /**
33 |      * Get the solution the algorithm currently reports as its result.
34 |      * @return the reported instance
35 |      */
36 |     public abstract Instance getOptimal();
37 | }
38 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/OptimizationProblem.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * An optimization problem: a way to score candidate instances and a way to
7 |  * draw random ones.
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public interface OptimizationProblem {
12 |
13 |     /**
14 |      * Score a candidate instance.
15 |      * @param d the instance to evaluate
16 |      * @return the instance's value
17 |      */
18 |     double value(Instance d);
19 |
20 |     /**
21 |      * Draw a random candidate instance.
22 |      * @return the sampled instance
23 |      */
24 |     Instance random();
25 | }
26 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/RandomizedHillClimbing.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * Randomized hill climbing: repeatedly samples one random neighbor of the
7 |  * current solution and moves there only when it is strictly better.
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public class RandomizedHillClimbing extends OptimizationAlgorithm {
12 |
13 |     /** The current solution. */
14 |     private Instance cur;
15 |
16 |     /** The value of {@link #cur}. */
17 |     private double curVal;
18 |
19 |     /**
20 |      * Make a new randomized hill climber, starting from a random instance
21 |      * drawn from the problem.
22 |      * @param hcp the problem to solve
23 |      */
24 |     public RandomizedHillClimbing(HillClimbingProblem hcp) {
25 |         super(hcp);
26 |         cur = hcp.random();
27 |         curVal = hcp.value(cur);
28 |     }
29 |
30 |     /**
31 |      * Perform one step: evaluate a single random neighbor and adopt it if
32 |      * its value is strictly greater than the current value.
33 |      * @return the current value after the step
34 |      * @see shared.Trainer#train()
35 |      */
36 |     public double train() {
37 |         HillClimbingProblem problem = (HillClimbingProblem) getOptimizationProblem();
38 |         Instance candidate = problem.neighbor(cur);
39 |         double candidateVal = problem.value(candidate);
40 |         // Ties and worse neighbors are rejected, so curVal never decreases.
41 |         if (candidateVal > curVal) {
42 |             cur = candidate;
43 |             curVal = candidateVal;
44 |         }
45 |         return curVal;
46 |     }
47 |
48 |     /**
49 |      * @return the current solution
50 |      * @see opt.OptimizationAlgorithm#getOptimal()
51 |      */
52 |     public Instance getOptimal() {
53 |         return cur;
54 |     }
55 | }
56 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/SimulatedAnnealing.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import dist.Distribution;
4 |
5 | import shared.Instance;
6 |
7 | /**
8 |  * Simulated annealing: a hill climber that may also accept worse neighbors,
9 |  * with a probability that shrinks as the temperature is cooled.
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class SimulatedAnnealing extends OptimizationAlgorithm {
14 |
15 |     /** The current solution (not necessarily the best one visited). */
16 |     private Instance cur;
17 |
18 |     /** The value of {@link #cur}. */
19 |     private double curVal;
20 |
21 |     /** The current temperature. */
22 |     private double t;
23 |
24 |     /** Factor the temperature is multiplied by after every step. */
25 |     private double cooling;
26 |
27 |     /**
28 |      * Make a new simulated annealing optimizer, starting from a random
29 |      * instance drawn from the problem.
30 |      * @param t the starting temperature
31 |      * @param cooling the factor applied to the temperature after each step
32 |      * @param hcp the problem to solve
33 |      */
34 |     public SimulatedAnnealing(double t, double cooling, HillClimbingProblem hcp) {
35 |         super(hcp);
36 |         this.t = t;
37 |         this.cooling = cooling;
38 |         this.cur = hcp.random();
39 |         this.curVal = hcp.value(cur);
40 |     }
41 |
42 |     /**
43 |      * Perform one annealing step: evaluate a random neighbor, move to it if
44 |      * it is better or if a uniform draw falls below
45 |      * {@code exp((neighborValue - currentValue) / t)}, then cool.
46 |      * @return the current value after the step
47 |      * @see shared.Trainer#train()
48 |      */
49 |     public double train() {
50 |         HillClimbingProblem problem = (HillClimbingProblem) getOptimizationProblem();
51 |         Instance candidate = problem.neighbor(cur);
52 |         double candidateVal = problem.value(candidate);
53 |         boolean improves = candidateVal > curVal;
54 |         // Short-circuit: the random draw only happens when the neighbor is not an improvement.
55 |         boolean accept = improves
56 |                 || Distribution.random.nextDouble() < Math.exp((candidateVal - curVal) / t);
57 |         if (accept) {
58 |             cur = candidate;
59 |             curVal = candidateVal;
60 |         }
61 |         t *= cooling;
62 |         return curVal;
63 |     }
64 |
65 |     /**
66 |      * @return the current solution, which may be worse than solutions visited earlier
67 |      * @see opt.OptimizationAlgorithm#getOptimal()
68 |      */
69 |     public Instance getOptimal() {
70 |         return cur;
71 |     }
72 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/SwapNeighbor.java:
--------------------------------------------------------------------------------
1 | package opt;
2 |
3 | import dist.Distribution;
4 |
5 | import shared.Instance;
6 |
7 | /**
8 |  * A neighbor function that swaps the values at two randomly chosen positions.
9 |  * The two positions may coincide, in which case nothing is swapped.
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class SwapNeighbor implements NeighborFunction {
14 |
15 |     /**
16 |      * The swap is applied to {@code d.copy()}, not to {@code d} itself.
17 |      * @param d the instance to derive a neighbor from
18 |      * @return the copy with two positions swapped
19 |      * @see opt.NeighborFunction#neighbor(shared.Instance)
20 |      */
21 |     public Instance neighbor(Instance d) {
22 |         Instance swapped = (Instance) d.copy();
23 |         int size = swapped.getData().size();
24 |         int i = Distribution.random.nextInt(size);
25 |         int j = Distribution.random.nextInt(size);
26 |         double valueAtI = swapped.getContinuous(i);
27 |         double valueAtJ = swapped.getContinuous(j);
28 |         swapped.getData().set(i, valueAtJ);
29 |         swapped.getData().set(j, valueAtI);
30 |         return swapped;
31 |     }
32 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/ContinuousPeaksEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import util.linalg.Vector;
4 | import opt.EvaluationFunction;
5 | import shared.Instance;
6 |
7 | /**
8 |  * A continuous peaks function. The fitness is the length of the longest run
9 |  * of 0s or of 1s. A bonus equal to the instance length is added when both
10 |  * the longest run of 0s and the longest run of 1s exceed the threshold t.
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class ContinuousPeaksEvaluationFunction implements EvaluationFunction {
15 |     /** Threshold both longest runs must exceed to earn the bonus. */
16 |     private int t;
17 |
18 |     /** Number of times {@link #value(Instance)} has been called. */
19 |     public long fevals;
20 |
21 |     /**
22 |      * Make a new continuous peaks function.
23 |      * @param t the threshold the longest runs must exceed for the bonus
24 |      */
25 |     public ContinuousPeaksEvaluationFunction(int t) {
26 |         this.t = t;
27 |     }
28 |
29 |     /**
30 |      * @param d the bit string to evaluate
31 |      * @return max(longest 0-run, longest 1-run), plus d.size() when both
32 |      *         longest runs are greater than t
33 |      * @see opt.EvaluationFunction#value(shared.Instance)
34 |      */
35 |     public double value(Instance d) {
36 |         Vector data = d.getData();
37 |         int max0 = longestRun(data, 0);
38 |         int max1 = longestRun(data, 1);
39 |         int r = 0;
40 |         if (max1 > t && max0 > t) {
41 |             r = data.size();
42 |         }
43 |         fevals++;
44 |         return Math.max(max1, max0) + r;
45 |     }
46 |
47 |     /**
48 |      * Length of the longest run of consecutive entries equal to value.
49 |      * The run counter is reset on every mismatch, and the best length is
50 |      * updated on every match. This means a run that ends at the last
51 |      * position is counted, and separate runs are never merged.
52 |      * @param data the values to scan
53 |      * @param value the value whose runs are measured
54 |      * @return the longest run length, or 0 if value never occurs
55 |      */
56 |     private static int longestRun(Vector data, double value) {
57 |         int longest = 0;
58 |         int run = 0;
59 |         for (int i = 0; i < data.size(); i++) {
60 |             if (data.get(i) == value) {
61 |                 run++;
62 |                 longest = Math.max(longest, run);
63 |             } else {
64 |                 run = 0;
65 |             }
66 |         }
67 |         return longest;
68 |     }
69 | }
70 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/CountOnesEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import util.linalg.Vector;
4 | import opt.EvaluationFunction;
5 | import shared.Instance;
6 |
7 | /**
8 |  * A function that counts the entries equal to 1 in the data.
9 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
10 |  * @version 1.0
11 |  */
12 | public class CountOnesEvaluationFunction implements EvaluationFunction {
13 |     /**
14 |      * Number of times {@link #value(Instance)} has been called. This mirrors
15 |      * the counter exposed by several other evaluation functions in this package.
16 |      */
17 |     public long fevals;
18 |
19 |     /**
20 |      * @param d the instance to evaluate
21 |      * @return the number of entries exactly equal to 1
22 |      * @see opt.EvaluationFunction#value(shared.Instance)
23 |      */
24 |     public double value(Instance d) {
25 |         Vector data = d.getData();
26 |         double val = 0;
27 |         for (int i = 0; i < data.size(); i++) {
28 |             if (data.get(i) == 1) {
29 |                 val++;
30 |             }
31 |         }
32 |         fevals++;
33 |         return val;
34 |     }
35 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/FlipFlopEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import util.linalg.Vector;
4 | import opt.EvaluationFunction;
5 | import shared.Instance;
6 |
7 | /**
8 |  * The flip-flop function: counts the adjacent pairs of entries whose
9 |  * values differ, so an alternating bit string scores highest.
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class FlipFlopEvaluationFunction implements EvaluationFunction {
14 |
15 |     /** Number of times {@link #value(Instance)} has been called. */
16 |     public long fevals;
17 |
18 |     /**
19 |      * @param d the instance to evaluate
20 |      * @return the number of positions i where entry i differs from entry i - 1
21 |      * @see opt.EvaluationFunction#value(shared.Instance)
22 |      */
23 |     public double value(Instance d) {
24 |         Vector data = d.getData();
25 |         double flips = 0;
26 |         for (int i = 1; i < data.size(); i++) {
27 |             if (data.get(i - 1) != data.get(i)) {
28 |                 flips++;
29 |             }
30 |         }
31 |         fevals++;
32 |         return flips;
33 |     }
34 | }
35 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/FlipFlopMODEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import util.linalg.Vector;
4 | import opt.EvaluationFunction;
5 | import shared.Instance;
6 |
7 | /**
8 |  * A modified flip-flop function. It counts the adjacent pairs of entries
9 |  * whose values differ, and adds a fixed bonus when the first entry is
10 |  * positive, which favours a 1 at the start of a bit string.
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class FlipFlopMODEvaluationFunction implements EvaluationFunction {
15 |
16 |     /** Bonus used by the no-argument constructor. */
17 |     private static final double DEFAULT_START_BONUS = 10;
18 |
19 |     /** Bonus added when the first entry is positive. */
20 |     private final double startBonus;
21 |
22 |     /** Number of times {@link #value(Instance)} has been called. */
23 |     public long fevals;
24 |
25 |     /**
26 |      * Make a modified flip-flop function with the default start bonus of 10.
27 |      */
28 |     public FlipFlopMODEvaluationFunction() {
29 |         this(DEFAULT_START_BONUS);
30 |     }
31 |
32 |     /**
33 |      * Make a modified flip-flop function with a custom start bonus.
34 |      * @param startBonus the amount added when the first entry is positive
35 |      */
36 |     public FlipFlopMODEvaluationFunction(double startBonus) {
37 |         this.startBonus = startBonus;
38 |     }
39 |
40 |     /**
41 |      * @param d the instance to evaluate
42 |      * @return the number of adjacent differing pairs, plus the start bonus
43 |      *         when the data is non-empty and its first entry is positive
44 |      * @see opt.EvaluationFunction#value(shared.Instance)
45 |      */
46 |     public double value(Instance d) {
47 |         Vector data = d.getData();
48 |         double val = 0;
49 |         for (int i = 0; i < data.size() - 1; i++) {
50 |             if (data.get(i) != data.get(i + 1)) {
51 |                 val++;
52 |             }
53 |         }
54 |         // Guard the lookup so empty data scores 0 instead of reading index 0 of an empty vector.
55 |         if (data.size() > 0 && data.get(0) > 0) {
56 |             val += startBonus;
57 |         }
58 |         fevals++;
59 |         return val;
60 |     }
61 | }
62 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/FourPeaksEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import util.linalg.Vector;
4 | import opt.EvaluationFunction;
5 | import shared.Instance;
6 |
7 | /**
8 |  * A four peaks evaluation function. The fitness is the larger of the number
9 |  * of leading 1s and the number of trailing 0s. A bonus equal to the instance
10 |  * length is added when both counts exceed the threshold t.
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class FourPeaksEvaluationFunction implements EvaluationFunction {
15 |     /** Threshold both the head and the tail must exceed to earn the bonus. */
16 |     private int t;
17 |
18 |     /**
19 |      * Number of times {@link #value(Instance)} has been called. This mirrors
20 |      * the counter exposed by several other evaluation functions in this package.
21 |      */
22 |     public long fevals;
23 |
24 |     /**
25 |      * Make a new four peaks function.
26 |      * @param t the threshold for the bonus
27 |      */
28 |     public FourPeaksEvaluationFunction(int t) {
29 |         this.t = t;
30 |     }
31 |
32 |     /**
33 |      * @param d the bit string to evaluate
34 |      * @return max(leading 1s, trailing 0s), plus d.size() when both exceed t
35 |      * @see opt.EvaluationFunction#value(shared.Instance)
36 |      */
37 |     public double value(Instance d) {
38 |         Vector data = d.getData();
39 |         int i = 0;
40 |         while (i < data.size() && data.get(i) == 1) {
41 |             i++;
42 |         }
43 |         int head = i;
44 |         i = data.size() - 1;
45 |         while (i >= 0 && data.get(i) == 0) {
46 |             i--;
47 |         }
48 |         int tail = data.size() - 1 - i;
49 |         int r = 0;
50 |         if (head > t && tail > t) {
51 |             r = data.size();
52 |         }
53 |         fevals++;
54 |         return Math.max(tail, head) + r;
55 |     }
56 | }
57 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/NeuralNetworkEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import util.linalg.Vector;
4 | import func.nn.NeuralNetwork;
5 | import opt.EvaluationFunction;
6 | import shared.DataSet;
7 | import shared.ErrorMeasure;
8 | import shared.Instance;
9 |
10 | /**
11 |  * An evaluation function that scores a weight vector by loading it into a
12 |  * neural network and measuring the network's error over a set of examples.
13 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
14 |  * @version 1.0
15 |  */
16 | public class NeuralNetworkEvaluationFunction implements EvaluationFunction {
17 |
18 |     /** The network candidate weights are loaded into (its weights are overwritten on every call). */
19 |     private final NeuralNetwork network;
20 |
21 |     /** The examples the error is summed over. */
22 |     private final DataSet examples;
23 |
24 |     /** Measures the error of one prediction against one example. */
25 |     private final ErrorMeasure measure;
26 |
27 |     /**
28 |      * Make a new neural network evaluation function.
29 |      * @param network the network to load candidate weights into
30 |      * @param examples the examples to evaluate on
31 |      * @param measure the per-example error measure
32 |      */
33 |     public NeuralNetworkEvaluationFunction(NeuralNetwork network,
34 |             DataSet examples, ErrorMeasure measure) {
35 |         this.network = network;
36 |         this.examples = examples;
37 |         this.measure = measure;
38 |     }
39 |
40 |     /**
41 |      * Load the weights in d into the network and run every example through
42 |      * it. Returns the reciprocal of the summed error, so lower error means a
43 |      * higher value. A total error of 0 yields Infinity.
44 |      * @param d the candidate weight vector
45 |      * @return 1 / (total error over all examples)
46 |      * @see opt.EvaluationFunction#value(shared.Instance)
47 |      */
48 |     public double value(Instance d) {
49 |         Vector weights = d.getData();
50 |         network.setWeights(weights);
51 |         double totalError = 0;
52 |         for (int i = 0; i < examples.size(); i++) {
53 |             Instance example = examples.get(i);
54 |             network.setInputValues(example.getData());
55 |             network.run();
56 |             Instance prediction = new Instance(network.getOutputValues());
57 |             totalError += measure.value(prediction, example);
58 |         }
59 |         return 1 / totalError;
60 |     }
61 | }
62 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/NeuralNetworkWeightDistribution.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import dist.AbstractDistribution;
4 |
5 | import shared.DataSet;
6 | import shared.Instance;
7 |
8 | /**
9 |  * A distribution over neural network weight vectors. Each weight is drawn
10 |  * independently as {@code random.nextDouble() - 0.5}.
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class NeuralNetworkWeightDistribution extends AbstractDistribution {
15 |
16 |     /** Length of each sampled weight vector. */
17 |     private int weightCount;
18 |
19 |     /**
20 |      * Make a new neural network weight distribution.
21 |      * @param weightCount the number of weights per sample
22 |      */
23 |     public NeuralNetworkWeightDistribution(int weightCount) {
24 |         this.weightCount = weightCount;
25 |     }
26 |
27 |     /**
28 |      * Always returns 1; this distribution does not compute a real density.
29 |      * @see dist.Distribution#p(shared.Instance)
30 |      */
31 |     public double p(Instance i) {
32 |         return 1;
33 |     }
34 |
35 |     /**
36 |      * Draw a fresh weight vector. The argument is ignored.
37 |      * @see dist.Distribution#sample(shared.Instance)
38 |      */
39 |     public Instance sample(Instance ignored) {
40 |         double[] weights = new double[weightCount];
41 |         for (int k = 0; k < weightCount; k++) {
42 |             // Shift the draw so it is centred on zero.
43 |             weights[k] = random.nextDouble() - .5;
44 |         }
45 |         return new Instance(weights);
46 |     }
47 |
48 |     /**
49 |      * Returns a random sample rather than a single most likely vector.
50 |      * @see dist.Distribution#mode(shared.Instance)
51 |      */
52 |     public Instance mode(Instance ignored) {
53 |         return sample(ignored);
54 |     }
55 |
56 |     /**
57 |      * No-op: the distribution is fixed and is not fitted to data.
58 |      * @see dist.Distribution#estimate(shared.DataSet)
59 |      */
60 |     public void estimate(DataSet observations) {
61 |         // intentionally empty
62 |     }
63 | }
64 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/TravelingSalesmanEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import opt.EvaluationFunction;
4 |
5 | /**
6 |  * Base class for traveling salesman evaluation functions. It precomputes the
7 |  * Euclidean distance between every pair of cities from their first two
8 |  * coordinates. Subclasses decide how an instance encodes a route.
9 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
10 |  * @version 1.0
11 |  */
12 | public abstract class TravelingSalesmanEvaluationFunction implements EvaluationFunction {
13 |     /**
14 |      * Lower-triangular distance table. distances[i][j] holds the distance
15 |      * between cities i and j for j < i, so row i has exactly i entries.
16 |      */
17 |     private double[][] distances;
18 |
19 |     /**
20 |      * Evaluation counter for subclasses to increment.
21 |      * NOTE(review): a subclass that declares its own field named fevals
22 |      * hides this one, so reads through the base type would then stay at 0.
23 |      */
24 |     public long fevals;
25 |
26 |     /**
27 |      * Make a new traveling salesman evaluation function.
28 |      * @param points the city locations; only the first two coordinates of
29 |      *        each point are used
30 |      */
31 |     public TravelingSalesmanEvaluationFunction(double[][] points) {
32 |         int n = points.length;
33 |         distances = new double[n][];
34 |         for (int i = 0; i < n; i++) {
35 |             double[] a = points[i];
36 |             double[] row = new double[i];
37 |             for (int j = 0; j < i; j++) {
38 |                 double[] b = points[j];
39 |                 double dx = a[0] - b[0];
40 |                 double dy = a[1] - b[1];
41 |                 row[j] = Math.sqrt(Math.pow(dx, 2) + Math.pow(dy, 2));
42 |             }
43 |             distances[i] = row;
44 |         }
45 |     }
46 |
47 |     /**
48 |      * Get the distance between two cities (0 when they are the same city).
49 |      * @param i the first city index
50 |      * @param j the second city index
51 |      * @return the Euclidean distance between them
52 |      */
53 |     public double getDistance(int i, int j) {
54 |         if (i == j) {
55 |             return 0;
56 |         }
57 |         // Only the lower triangle is stored, so index with the larger one first.
58 |         return i > j ? distances[i][j] : distances[j][i];
59 |     }
60 | }
61 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/TravelingSalesmanRouteEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * An implementation of the traveling salesman problem where the encoding
7 |  * is a permutation of [0, ..., n] for n+1 cities. In other words, the
8 |  * instance is simply the order in which the cities are visited.
9 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
10 |  * @version 1.0
11 |  */
12 | public class TravelingSalesmanRouteEvaluationFunction extends TravelingSalesmanEvaluationFunction {
13 |
14 |     /**
15 |      * Make a new route evaluation function.
16 |      * @param points the locations of the cities
17 |      */
18 |     public TravelingSalesmanRouteEvaluationFunction(double[][] points) {
19 |         super(points);
20 |     }
21 |
22 |     // No fevals field is declared here on purpose. Redeclaring it would hide
23 |     // the inherited TravelingSalesmanEvaluationFunction.fevals, so code holding
24 |     // a base-class reference would always read 0.
25 |
26 |     /**
27 |      * Compute the reciprocal of the closed tour length, where the route
28 |      * returns from the last city to the first. Shorter tours score higher.
29 |      * @param d the route, one city index per position
30 |      * @return 1 / tour length
31 |      * @see opt.EvaluationFunction#value(shared.Instance)
32 |      */
33 |     public double value(Instance d) {
34 |         double distance = 0;
35 |         for (int i = 0; i < d.size() - 1; i++) {
36 |             distance += getDistance(d.getDiscrete(i), d.getDiscrete(i + 1));
37 |         }
38 |         distance += getDistance(d.getDiscrete(d.size() - 1), d.getDiscrete(0));
39 |         fevals++;
40 |         return 1 / distance;
41 |     }
42 | }
43 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/TravelingSalesmanSortEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import shared.Instance;
4 | import util.ABAGAILArrays;
5 |
6 | /**
7 |  * A traveling salesman evaluation function for routes encoded as sorts.
8 |  * The route is the permutation of indices obtained by sorting the
9 |  * instance's continuous values.
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class TravelingSalesmanSortEvaluationFunction extends TravelingSalesmanEvaluationFunction {
14 |
15 |     /**
16 |      * Make a new traveling salesman evaluation function.
17 |      * @param points the locations of the cities
18 |      */
19 |     public TravelingSalesmanSortEvaluationFunction(double[][] points) {
20 |         super(points);
21 |     }
22 |
23 |     // No fevals field is declared here on purpose. Redeclaring it would hide
24 |     // the inherited TravelingSalesmanEvaluationFunction.fevals, so code holding
25 |     // a base-class reference would always read 0.
26 |
27 |     /**
28 |      * Decode the route by sorting, then return the reciprocal of the closed
29 |      * tour length.
30 |      * @param d the instance whose values' sort order defines the route
31 |      * @return 1 / tour length
32 |      * @see opt.EvaluationFunction#value(shared.Instance)
33 |      */
34 |     public double value(Instance d) {
35 |         double[] ddata = new double[d.size()];
36 |         for (int i = 0; i < ddata.length; i++) {
37 |             ddata[i] = d.getContinuous(i);
38 |         }
39 |         // Presumably quicksort sorts ddata and applies the same permutation
40 |         // to order (verify in ABAGAILArrays).
41 |         int[] order = ABAGAILArrays.indices(d.size());
42 |         ABAGAILArrays.quicksort(ddata, order);
43 |         double distance = 0;
44 |         for (int i = 0; i < order.length - 1; i++) {
45 |             distance += getDistance(order[i], order[i + 1]);
46 |         }
47 |         distance += getDistance(order[order.length - 1], order[0]);
48 |         fevals++;
49 |         return 1 / distance;
50 |     }
51 | }
52 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/example/TwoColorsEvaluationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.example;
2 |
3 | import util.linalg.Vector;
4 | import opt.EvaluationFunction;
5 | import shared.Instance;
6 |
7 | /**
8 |  * A function that evaluates whether a vector represents a 2-colored graph.
9 |  * It counts the interior positions whose value differs from both neighbours.
10 |  * @author Daniel Cohen dcohen@gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class TwoColorsEvaluationFunction implements EvaluationFunction {
14 |
15 |     /** Number of times {@link #value(Instance)} has been called. */
16 |     public long fevals;
17 |
18 |     /**
19 |      * @param d the coloring to evaluate
20 |      * @return the number of positions 1..size-2 whose value differs from
21 |      *         both the previous and the next value
22 |      * @see opt.EvaluationFunction#value(shared.Instance)
23 |      */
24 |     public double value(Instance d) {
25 |         Vector data = d.getData();
26 |         int n = data.size();
27 |         double val = 0;
28 |         // The first and last positions have only one neighbour and are skipped.
29 |         for (int i = 1; i + 1 < n; i++) {
30 |             double here = data.get(i);
31 |             if (here != data.get(i - 1) && here != data.get(i + 1)) {
32 |                 val++;
33 |             }
34 |         }
35 |         fevals++;
36 |         return val;
37 |     }
38 | }
39 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/ContinuousAddOneMutation.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import dist.Distribution;
4 |
5 | import shared.Instance;
6 |
7 | /**
8 |  * A mutation that perturbs one randomly chosen entry by a random offset.
9  |  * For a positive amount, the offset lies in [-amount/2, amount/2).
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class ContinuousAddOneMutation implements MutationFunction {
14 |
15 |     /** Width of the interval the random offset is drawn from. */
16 |     private double amount;
17 |
18 |     /**
19 |      * Make a mutation with a custom perturbation width.
20 |      * @param amount the width of the offset interval
21 |      */
22 |     public ContinuousAddOneMutation(double amount) {
23 |         this.amount = amount;
24 |     }
25 |
26 |     /**
27 |      * Make a mutation with a perturbation width of 1.
28 |      */
29 |     public ContinuousAddOneMutation() {
30 |         this(1);
31 |     }
32 |
33 |     /**
34 |      * Add a random offset to one randomly chosen entry of cod, in place.
35 |      * @param cod the instance to mutate
36 |      * @see opt.ga.MutationFunction#mutate(shared.Instance)
37 |      */
38 |     public void mutate(Instance cod) {
39 |         int index = Distribution.random.nextInt(cod.size());
40 |         double current = cod.getContinuous(index);
41 |         double mutated = current + Distribution.random.nextDouble() * amount - amount / 2;
42 |         cod.getData().set(index, mutated);
43 |     }
44 | }
45 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/CrossoverFunction.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * A crossover operator that combines two parent solutions into one child.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public interface CrossoverFunction {
11 |
12 |     /**
13 |      * Combine two parent solutions.
14 |      * @param a the first parent
15 |      * @param b the second parent
16 |      * @return the child solution
17 |      */
18 |     Instance mate(Instance a, Instance b);
19 | }
20 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/DiscreteChangeOneMutation.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import dist.Distribution;
4 |
5 | import shared.Instance;
6 |
7 | /**
8 |  * A mutation that replaces one randomly chosen entry with a uniformly random
9 |  * value from that entry's range. The new value may equal the old one.
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class DiscreteChangeOneMutation implements MutationFunction {
14 |
15 |     /** ranges[i] is the number of values entry i can take (0 .. ranges[i] - 1). */
16 |     private int[] ranges;
17 |
18 |     /**
19 |      * Make a new discrete change-one mutation.
20 |      * @param ranges the number of possible values for each position
21 |      */
22 |     public DiscreteChangeOneMutation(int[] ranges) {
23 |         this.ranges = ranges;
24 |     }
25 |
26 |     /**
27 |      * Overwrite one random entry of d, in place.
28 |      * @param d the instance to mutate
29 |      * @see opt.ga.MutationFunction#mutate(shared.Instance)
30 |      */
31 |     public void mutate(Instance d) {
32 |         int position = Distribution.random.nextInt(d.size());
33 |         int replacement = Distribution.random.nextInt(ranges[position]);
34 |         d.getData().set(position, replacement);
35 |     }
36 | }
37 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/GenericGeneticAlgorithmProblem.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import dist.Distribution;
4 | import opt.EvaluationFunction;
5 | import opt.GenericOptimizationProblem;
6 | import shared.Instance;
7 |
8 | /**
9 | *
10 | * @author Andrew Guillory gtg008g@mail.gatech.edu
11 | * @version 1.0
12 | */
13 | public class GenericGeneticAlgorithmProblem extends GenericOptimizationProblem implements
14 | GeneticAlgorithmProblem {
15 |
16 | /**
17 | * The cross over function
18 | */
19 | private CrossoverFunction crossover;
20 | /**
21 | * The mutation function
22 | */
23 | private MutationFunction mutation;
24 |
25 | /**
26 | * Make a new generic genetic algorithm problem
27 | * @param crossover the cross over operator
28 | * @param muation the mutation operator
29 | * @param eval the evaluation function
30 | * @param dist the initial distribution
31 | */
32 | public GenericGeneticAlgorithmProblem(EvaluationFunction eval, Distribution dist,
33 | MutationFunction mutation, CrossoverFunction crossover) {
34 | super(eval, dist);
35 | this.mutation = mutation;
36 | this.crossover = crossover;
37 | }
38 | /**
39 | * @see opt.ga.GeneticAlgorithmProblem#mate(opt.Instance, opt.Instance)
40 | */
41 | public Instance mate(Instance a, Instance b) {
42 | return crossover.mate(a, b);
43 | }
44 | /**
45 | * @see opt.ga.GeneticAlgorithmProblem#mutate(opt.Instance)
46 | */
47 | public void mutate(Instance d) {
48 | mutation.mutate(d);
49 | }
50 |
51 | }
52 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/GeneticAlgorithmProblem.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import opt.OptimizationProblem;
4 | import shared.Instance;
5 |
6 | /**
7 |  * An optimization problem that supports the genetic operators: crossover
8 |  * (mating) and mutation.
9 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
10 |  * @version 1.0
11 |  */
12 | public interface GeneticAlgorithmProblem extends OptimizationProblem {
13 |
14 |     /**
15 |      * Combine two candidate solutions into a new one.
16 |      * @param a the first parent
17 |      * @param b the second parent
18 |      * @return the child solution
19 |      */
20 |     Instance mate(Instance a, Instance b);
21 |
22 |     /**
23 |      * Mutate a candidate solution.
24 |      * @param d the solution to mutate
25 |      */
26 |     void mutate(Instance d);
27 | }
28 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/MutationFunction.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * A mutation operator.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public interface MutationFunction {
11 |
12 |     /**
13 |      * Mutate the given candidate solution. Nothing is returned, so the
14 |      * change is applied to d itself.
15 |      * @param d the solution to mutate
16 |      */
17 |     void mutate(Instance d);
18 | }
19 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/SingleCrossOver.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import dist.Distribution;
4 |
5 | import shared.Instance;
6 |
7 | /**
8 |  * Implementation of the single point crossover function for genetic algorithms.
9 |  *
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class SingleCrossOver implements CrossoverFunction {
14 |
15 |     /**
16 |      * Mate two candidate solutions with single point crossover. A cut point
17 |      * p is drawn uniformly from [0, a.size()]. Positions [0, p) are copied
18 |      * from the second solution and positions [p, a.size()) from the first.
19 |      * A cut at either end yields the values of just one parent.
20 |      *
21 |      * @param a the first solution (supplies the tail)
22 |      * @param b the second solution (supplies the head)
23 |      * @return the mated solution
24 |      */
25 |     public Instance mate(Instance a, Instance b) {
26 |         int length = a.size();
27 |         double[] child = new double[length];
28 |         int point = Distribution.random.nextInt(length + 1);
29 |         for (int i = 0; i < point; i++) {
30 |             child[i] = b.getContinuous(i);
31 |         }
32 |         for (int i = point; i < length; i++) {
33 |             child[i] = a.getContinuous(i);
34 |         }
35 |         return new Instance(child);
36 |     }
37 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/SwapMutation.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import dist.Distribution;
4 | import shared.Instance;
5 |
6 | /**
7 |  * A mutation that swaps the values at two randomly chosen positions, in
8 |  * place. The positions may coincide, which leaves the instance unchanged.
9 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
10 |  * @version 1.0
11 |  */
12 | public class SwapMutation implements MutationFunction {
13 |
14 |     /**
15 |      * @param d the instance to mutate
16 |      * @see opt.ga.MutationFunction#mutate(shared.Instance)
17 |      */
18 |     public void mutate(Instance d) {
19 |         int first = Distribution.random.nextInt(d.size());
20 |         int second = Distribution.random.nextInt(d.size());
21 |         double firstValue = d.getContinuous(first);
22 |         double secondValue = d.getContinuous(second);
23 |         d.getData().set(first, secondValue);
24 |         d.getData().set(second, firstValue);
25 |     }
26 | }
27 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/TwoPointCrossOver.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import dist.Distribution;
4 | import shared.Instance;
5 |
6 | /**
7 |  * Implementation of the two-point crossover function for genetic algorithms.
8 |  * Two points are selected on the parent strings. The child takes the segment
9 |  * between them from one parent and everything else from the other.
10 |  *
11 |  * @author Avriel Harvey aharvey32@gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class TwoPointCrossOver implements CrossoverFunction {
15 |
16 |     /**
17 |      * Mate two candidate solutions using two-point crossover.
18 |      * x is drawn uniformly from [0, n] and then y uniformly from [x, n],
19 |      * where n = a.size(). Positions [x, y) come from the first solution and
20 |      * the remaining positions from the second.
21 |      *
22 |      * @param a the first solution (supplies the middle segment)
23 |      * @param b the second solution (supplies the rest)
24 |      * @return the mated solution
25 |      */
26 |     public Instance mate(Instance a, Instance b) {
27 |         int length = a.size();
28 |         double[] child = new double[length];
29 |         int start = Distribution.random.nextInt(length + 1);
30 |         // nextInt's bound is exclusive, so end ranges over [start, length].
31 |         int end = start + Distribution.random.nextInt(length + 1 - start);
32 |         for (int i = 0; i < length; i++) {
33 |             boolean inSegment = start <= i && i < end;
34 |             child[i] = inSegment ? a.getContinuous(i) : b.getContinuous(i);
35 |         }
36 |         return new Instance(child);
37 |     }
38 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/UniformCrossOver.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import shared.Instance;
4 |
5 | import dist.Distribution;
6 |
7 | /**
8 | * Implementation of the uniform crossover function for genetic algorithms.
9 | *
10 | * @author Andrew Guillory gtg008g@mail.gatech.edu
11 | * @version 1.0
12 | */
13 | public class UniformCrossOver implements CrossoverFunction {
14 |
15 | /**
16 | * Mates two candidate solutions by uniformly sampling bits to create the crossover mask, then taking the true
17 | * bits from the first solution and the remaining bits from the second solution.
18 | *
19 | * @param a the first solution
20 | * @param b the second solution
21 | * @return the mated solution.
22 | */
23 | public Instance mate(Instance a, Instance b) {
24 | // Create space for the mated solution
25 | double[] newData = new double[a.size()];
26 |
27 | // Assign bits to the mated solution
28 | for (int i = 0; i < newData.length; i++) {
29 | // Randomly pick a boolean value to determine which parent to take the ith bit from
30 | if (Distribution.random.nextBoolean()) {
31 | newData[i] = a.getContinuous(i);
32 | } else {
33 | newData[i] = b.getContinuous(i);
34 | }
35 | }
36 |
37 | // Return the mated solution
38 | return new Instance(newData);
39 | }
40 |
41 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/ga/Vertex.java:
--------------------------------------------------------------------------------
1 | package opt.ga;
2 |
3 | import java.util.ArrayList;
4 | import java.util.List;
5 |
6 | /**
7 |  * A vertex carrying an adjacency color matrix list and a declared matrix size.
8 |  * NOTE(review): what the list holds is decided by callers; this class never inspects it.
9 |  */
10 | public class Vertex {
11 |
12 |     // Live backing list: the accessors return it directly, so callers mutate it in place.
13 |     private List adjacencyColorMatrix = new ArrayList(1);
14 |
15 |     // Recorded size only; it does not resize or validate adjacencyColorMatrix.
16 |     private int adjMatrixSize = 1;
17 |
18 |     /**
19 |      * Record the adjacency matrix size.
20 |      * @param size the size to record
21 |      */
22 |     public void setAdjMatrixSize(int size){
23 |         adjMatrixSize = size;
24 |     }
25 |
26 |     /**
27 |      * Get the recorded adjacency matrix size.
28 |      * @return the value last passed to setAdjMatrixSize, or 1 if it was never called
29 |      */
30 |     public int getAdjMatrixSize() {
31 |         return adjMatrixSize;
32 |     }
33 |
34 |     /**
35 |      * Get the adjacency color matrix list (the live list, not a copy).
36 |      * @return the backing list
37 |      */
38 |     public List getAdjacencyColorMatrix() {
39 |         return adjacencyColorMatrix;
40 |     }
41 |
42 |     /**
43 |      * Misspelled accessor kept for backward compatibility; prefer getAdjacencyColorMatrix().
44 |      * @return the backing list
45 |      */
46 |     public List getAadjacencyColorMatrix(){
47 |         return getAdjacencyColorMatrix();
48 |     }
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/prob/GenericProbabilisticOptimizationProblem.java:
--------------------------------------------------------------------------------
1 | package opt.prob;
2 |
3 | import dist.Distribution;
4 | import opt.EvaluationFunction;
5 | import opt.GenericOptimizationProblem;
6 |
7 | /**
8 |  * A generic optimization problem that also carries a probability distribution,
9 |  * making it usable by probabilistic optimizers such as MIMIC.
10 |  *
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class GenericProbabilisticOptimizationProblem extends GenericOptimizationProblem
15 |         implements ProbabilisticOptimizationProblem {
16 |     /**
17 |      * The distribution returned by getDistribution(); distinct from the initial
18 |      * distribution handed to the superclass.
19 |      */
20 |     private Distribution dist;
21 |
22 |     /**
23 |      * Make a new generic probabilistic optimization problem.
24 |      * @param eval the evaluation function
25 |      * @param initial the initial distribution, passed through to GenericOptimizationProblem
26 |      * @param distribution the distribution exposed through getDistribution()
27 |      */
28 |     public GenericProbabilisticOptimizationProblem(EvaluationFunction eval, Distribution initial,
29 |             Distribution distribution) {
30 |         super(eval, initial);
31 |         this.dist = distribution;
32 |     }
33 |
34 |     /**
35 |      * @see opt.prob.ProbabilisticOptimizationProblem#getDistribution()
36 |      */
37 |     public Distribution getDistribution() {
38 |         return dist;
39 |     }
40 | }
41 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/prob/ProbabilisticOptimizationProblem.java:
--------------------------------------------------------------------------------
1 | package opt.prob;
2 |
3 | import dist.Distribution;
4 | import opt.OptimizationProblem;
5 |
6 | /**
7 |  * An optimization problem solvable by MIMIC: in addition to the usual optimization
8 |  * problem operations it exposes a probability distribution.
9 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
10 |  * @version 1.0
11 |  */
12 | public interface ProbabilisticOptimizationProblem extends OptimizationProblem {
13 |
14 |     /**
15 |      * Get the probability distribution associated with this problem.
16 |      * @return the distribution
17 |      */
18 |     Distribution getDistribution();
19 |
20 | }
21 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/opt/test/CrossValidationTest.java:
--------------------------------------------------------------------------------
1 | package opt.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import shared.filt.kFoldSplitFilter;
6 | import shared.tester.CrossValidationTestMetric;
7 |
8 | /**
9 |  * Placeholder driver showing how kFoldSplitFilter and CrossValidationTestMetric are meant to be
10 |  * combined. main currently does nothing: the sketch below is kept as comments because it relies
11 |  * on a data set and on algorithm output that this class does not provide.
12 |  */
13 | public class CrossValidationTest {
14 |     public static void main(String[] args) {
15 |         // Sketch of the intended workflow (not compiled):
16 |         //
17 |         // // Load a data set (with a DataSetReader or by building it by hand) into `dataSet`.
18 |         //
19 |         // // Split the data into k folds (k = 10 here).
20 |         // kFoldSplitFilter split = new kFoldSplitFilter(10);
21 |         //
22 |         // // Metric that accumulates the cross-validation results.
23 |         // CrossValidationTestMetric metric = new CrossValidationTestMetric(dataSet.size(), 10);
24 |         //
25 |         // for (DataSet set : split.getFolds()) {
26 |         //     // Train the algorithm on `set`.
27 |         //     for (DataSet fold : split.getValidationFolds(set)) {
28 |         //         for (Instance inst : fold) {
29 |         //             // outputLabel is the label the algorithm produced for inst
30 |         //             // (optionally cleaned up first, e.g. thresholded to 0 or 1).
31 |         //             metric.addResult(inst.getLabel(), new Instance(outputLabel));
32 |         //         }
33 |         //         metric.nextValidationFold(); // move on to the next validation fold
34 |         //     }
35 |         //     metric.nextFold(); // move on to the next training fold
36 |         // }
37 |         //
38 |         // metric.printResults();
39 |     }
40 | }
41 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/rl/DecayingEpsilonGreedyStrategy.java:
--------------------------------------------------------------------------------
1 | package rl;
2 |
3 | import dist.Distribution;
4 |
5 | /**
6 |  * An epsilon greedy exploration strategy whose epsilon shrinks geometrically:
7 |  * it is multiplied by the decay factor on every call to action().
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public class DecayingEpsilonGreedyStrategy implements ExplorationStrategy {
12 |     /**
13 |      * The current probability of taking a random action
14 |      */
15 |     private double epsilon;
16 |     /**
17 |      * The factor epsilon is multiplied by after each action
18 |      */
19 |     private double decay;
20 |
21 |     /**
22 |      * Make a decaying epsilon greedy strategy
23 |      * @param epsilon the initial exploration probability
24 |      * @param decay the multiplicative decay applied to epsilon after every action
25 |      */
26 |     public DecayingEpsilonGreedyStrategy(double epsilon, double decay) {
27 |         this.epsilon = epsilon;
28 |         this.decay = decay;
29 |     }
30 |
31 |     /**
32 |      * With probability epsilon return a random action; otherwise return the index of the
33 |      * largest q-value (the lowest index wins ties).
34 |      * @see rl.ExplorationStrategy#action(double[])
35 |      */
36 |     public int action(double[] qvalues) {
37 |         boolean explore = Distribution.random.nextDouble() < epsilon;
38 |         // Decay on every call, not only on greedy ones: otherwise a strategy started with
39 |         // epsilon >= 1 always explores and its epsilon never decays.
40 |         epsilon *= decay;
41 |         if (explore) {
42 |             return Distribution.random.nextInt(qvalues.length);
43 |         }
44 |         int best = 0;
45 |         for (int i = 1; i < qvalues.length; i++) {
46 |             if (qvalues[best] < qvalues[i]) {
47 |                 best = i;
48 |             }
49 |         }
50 |         return best;
51 |     }
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/rl/EpsilonGreedyStrategy.java:
--------------------------------------------------------------------------------
1 | package rl;
2 |
3 | import dist.Distribution;
4 |
5 | /**
6 |  * An epsilon greedy exploration strategy with a fixed epsilon.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public class EpsilonGreedyStrategy implements ExplorationStrategy {
11 |     /** Probability of choosing a random action instead of the greedy one. */
12 |     private double epsilon;
13 |
14 |     /**
15 |      * Make an epsilon greedy strategy
16 |      * @param epsilon the probability of exploring on each call
17 |      */
18 |     public EpsilonGreedyStrategy(double epsilon) {
19 |         this.epsilon = epsilon;
20 |     }
21 |
22 |     /**
23 |      * Explore with probability epsilon (random action); otherwise return the index of the
24 |      * largest q-value, preferring the lowest index on ties.
25 |      * @see rl.ExplorationStrategy#action(double[])
26 |      */
27 |     public int action(double[] qvalues) {
28 |         boolean explore = Distribution.random.nextDouble() < epsilon;
29 |         if (explore) {
30 |             return Distribution.random.nextInt(qvalues.length);
31 |         }
32 |         int argmax = 0;
33 |         for (int candidate = 1; candidate < qvalues.length; candidate++) {
34 |             argmax = qvalues[candidate] > qvalues[argmax] ? candidate : argmax;
35 |         }
36 |         return argmax;
37 |     }
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/rl/ExplorationStrategy.java:
--------------------------------------------------------------------------------
1 | package rl;
2 |
3 | /**
4 |  * Chooses which action to take given the current action values.
5 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
6 |  * @version 1.0
7 |  */
8 | public interface ExplorationStrategy {
9 |     /**
10 |      * Draw an action from the strategy
11 |      * @param qvalues the q-value of each action, indexed by action
12 |      * @return the index of the chosen action
13 |      */
14 |     int action(double[] qvalues);
15 | }
16 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/rl/GreedyStrategy.java:
--------------------------------------------------------------------------------
1 | package rl;
2 |
3 |
4 | /**
5 |  * A purely greedy strategy: it always exploits and never explores.
6 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
7 |  * @version 1.0
8 |  */
9 | public class GreedyStrategy implements ExplorationStrategy {
10 |
11 |     /**
12 |      * Return the index of the largest q-value; the lowest such index wins ties.
13 |      * @see rl.ExplorationStrategy#action(double[])
14 |      */
15 |     public int action(double[] qvalues) {
16 |         int argmax = 0;
17 |         int candidate = 1;
18 |         while (candidate < qvalues.length) {
19 |             if (qvalues[candidate] > qvalues[argmax]) {
20 |                 argmax = candidate;
21 |             }
22 |             candidate++;
23 |         }
24 |         return argmax;
25 |     }
26 | }
27 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/rl/MarkovDecisionProcess.java:
--------------------------------------------------------------------------------
1 | package rl;
2 |
3 | /**
4 |  * A discrete Markov decision process. States and actions are identified by int indices
5 |  * (presumably 0..getStateCount()-1 and 0..getActionCount()-1).
6 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
7 |  * @version 1.0
8 |  */
9 | public interface MarkovDecisionProcess {
10 |     /** An ok default gamma value */
11 |     double GAMMA = .9;
12 |
13 |     /**
14 |      * Get the number of states in the mdp
15 |      * @return the number of states
16 |      */
17 |     int getStateCount();
18 |
19 |     /**
20 |      * Get the number of actions in the mdp
21 |      * @return the number of actions
22 |      */
23 |     int getActionCount();
24 |
25 |     /**
26 |      * Get the reward for taking an action in a state
27 |      * @param state the state
28 |      * @param action the action
29 |      * @return the reward
30 |      */
31 |     double reward(int state, int action);
32 |
33 |     /**
34 |      * Get the probability of transitioning from state i to state j when taking action a
35 |      * @param i the current state
36 |      * @param j the next state
37 |      * @param a the action
38 |      * @return the transition probability
39 |      */
40 |     double transitionProbability(int i, int j, int a);
41 |
42 |     /**
43 |      * Sample a next state given the current state and the action taken
44 |      * @param i the current state
45 |      * @param a the action
46 |      * @return the sampled next state
47 |      */
48 |     int sampleState(int i, int a);
49 |
50 |     /**
51 |      * Sample an initial state
52 |      * @return the initial state
53 |      */
54 |     int sampleInitialState();
55 |
56 |     /**
57 |      * Check if a state is terminal
58 |      * @param state the state
59 |      * @return true if the state is terminal
60 |      */
61 |     boolean isTerminalState(int state);
62 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/rl/Policy.java:
--------------------------------------------------------------------------------
1 | package rl;
2 |
3 | import dist.Distribution;
4 | import util.ABAGAILArrays;
5 |
6 | /**
7 | * A policy maps states to actions
8 | * @author Andrew Guillory gtg008g@mail.gatech.edu
9 | * @version 1.0
10 | */
11 | public class Policy {
12 | /**
13 | * The actions to perform in each state
14 | */
15 | private int[] actions;
16 |
17 | /**
18 | * Make a new policy
19 | * @param actions the actions
20 | */
21 | public Policy(int[] actions) {
22 | this.actions = actions;
23 | }
24 | /**
25 | * Make a new random policy
26 | * @param numStates the number of states
27 | * @param numActions the number of actions
28 | */
29 | public Policy(int numStates, int numActions) {
30 | actions = new int[numStates];
31 | for (int i = 0; i < actions.length; i++) {
32 | actions[i] = Distribution.random.nextInt(numActions);
33 | }
34 | }
35 |
36 | /**
37 | * Get the action for the given state
38 | * @param state the state
39 | * @return the action
40 | */
41 | public int getAction(int state) {
42 | return actions[state];
43 | }
44 | /**
45 | * Set the action for a state
46 | * @param state the state
47 | * @param action the action
48 | */
49 | public void setAction(int state, int action) {
50 | actions[state] = action;
51 | }
52 | /**
53 | * Get the actions
54 | * @return returns the actions.
55 | */
56 | public int[] getActions() {
57 | return actions;
58 | }
59 | /**
60 | * Set the actions
61 | * @param actions the actions to set.
62 | */
63 | public void setActions(int[] actions) {
64 | this.actions = actions;
65 | }
66 | /**
67 | * @see java.lang.Object#toString()
68 | */
69 | public String toString() {
70 | return ABAGAILArrays.toString(actions);
71 | }
72 |
73 | }
74 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/rl/PolicyLearner.java:
--------------------------------------------------------------------------------
1 | package rl;
2 |
3 | import shared.Trainer;
4 |
5 | /**
6 |  * A trainer that also learns a policy.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public interface PolicyLearner extends Trainer {
11 |     /**
12 |      * Get the best policy
13 |      * @return the policy
14 |      */
15 |     Policy getPolicy();
16 |
17 | }
18 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/rl/tester/ExpectedRewardTestMetric.java:
--------------------------------------------------------------------------------
1 | package rl.tester;
2 |
3 | import rl.Policy;
4 | import rl.MarkovDecisionProcess;
5 |
6 | /**
7 | * Expected Reward Test Metric
8 | * @author Daniel Cohen
9 | * @version 1.0
10 | */
11 | public class ExpectedRewardTestMetric {
12 |
13 | private Policy policy;
14 | private MarkovDecisionProcess mdp;
15 |
16 | /**
17 | * Main constructor
18 | */
19 | public ExpectedRewardTestMetric(Policy p, MarkovDecisionProcess mdp) {
20 | this.policy = p;
21 | this.mdp = mdp;
22 | }
23 |
24 | /**
25 | * Computes the expected value by testing the provided policy
26 | */
27 | public double compute(int trials, int iterations) {
28 | double totalReward = 0.0;
29 | for (int t = 0; t < trials; t++) {
30 | int currentState = this.mdp.sampleInitialState();
31 | for (int i = 0; i < iterations; i++) {
32 | int action = this.policy.getAction(currentState);
33 | totalReward += this.mdp.reward(currentState, action);
34 | currentState = this.mdp.sampleState(currentState, action);
35 |
36 | if (currentState >= this.mdp.getStateCount()) {
37 | currentState = this.mdp.getStateCount() - 1;
38 | }
39 | }
40 | }
41 |
42 | return totalReward / (double) trials / (double) iterations;
43 | }
44 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/AbstractDistanceMeasure.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 | /**
4 | * An abstract distance measure with some extra little things
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public abstract class AbstractDistanceMeasure implements DistanceMeasure {
9 | /**
10 | * Calculate the distance between two data sets
11 | * @param a the first
12 | * @param b the second
13 | * @return the distance
14 | */
15 | public double value(DataSet a, DataSet b) {
16 | double distance = 0;
17 | for (int i = 0; i < a.size(); i++) {
18 | distance += value(a.get(i), b.get(i));
19 | }
20 | return distance;
21 | }
22 |
23 | }
24 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/AbstractErrorMeasure.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 | /**
4 | * An abstract error measure
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public abstract class AbstractErrorMeasure implements ErrorMeasure {
9 | /**
10 | * Calculate the error between two data sets
11 | * @param a the first
12 | * @param b the second
13 | * @return the error
14 | */
15 | public double value(DataSet a, DataSet b) {
16 | double error = 0;
17 | for (int i = 0; i < a.size(); i++) {
18 | error += value(a.get(i), b.get(i));
19 | }
20 | return error;
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/AttributeType.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 | /**
4 |  * An attribute type specifies what type an attribute within a data set is.
5 |  * The constructor is private, so BINARY, DISCRETE and CONTINUOUS are the only instances.
6 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
7 |  * @version 1.0
8 |  */
9 | public class AttributeType {
10 |     /**
11 |      * The binary type
12 |      */
13 |     public static final AttributeType BINARY = new AttributeType(1);
14 |     /**
15 |      * The integer / discrete type
16 |      */
17 |     public static final AttributeType DISCRETE = new AttributeType(2);
18 |     /**
19 |      * The continuous type
20 |      */
21 |     public static final AttributeType CONTINUOUS = new AttributeType(3);
22 |
23 |     /**
24 |      * The numeric code of the attribute type
25 |      */
26 |     private final int type;
27 |
28 |     /**
29 |      * Make a new attribute type
30 |      * @param t the numeric code of the type
31 |      */
32 |     private AttributeType(int t) {
33 |         type = t;
34 |     }
35 |
36 |     /**
37 |      * Two attribute types are equal when they have the same type code. Returns false for
38 |      * null or for objects of another class instead of throwing, as the equals contract requires.
39 |      * @see java.lang.Object#equals(java.lang.Object)
40 |      */
41 |     public boolean equals(Object o) {
42 |         if (this == o) {
43 |             return true;
44 |         }
45 |         if (!(o instanceof AttributeType)) {
46 |             return false;
47 |         }
48 |         return ((AttributeType) o).type == type;
49 |     }
50 |
51 |     /**
52 |      * Consistent with equals: equal attribute types have equal hash codes.
53 |      * @see java.lang.Object#hashCode()
54 |      */
55 |     public int hashCode() {
56 |         return type;
57 |     }
58 |
59 |     /**
60 |      * @see java.lang.Object#toString()
61 |      */
62 |     public String toString() {
63 |         if (this == BINARY) {
64 |             return "BINARY";
65 |         } else if (this == DISCRETE) {
66 |             return "DISCRETE";
67 |         } else if (this == CONTINUOUS) {
68 |             return "CONTINUOUS";
69 |         } else {
70 |             return "UNKNOWN";
71 |         }
72 |     }
73 |
74 | }
75 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/Copyable.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 | /**
4 |  * An interface for objects that can produce a copy of themselves.
5 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
6 |  * @version 1.0
7 |  */
8 | public interface Copyable {
9 |     /**
10 |      * Make a copy of this object
11 |      * @return the copy
12 |      */
13 |     Copyable copy();
14 |
15 | }
16 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/DistanceMeasure.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 |
4 | /**
5 |  * A measure of the distance between two instances. Implementations need not be true
6 |  * metrics: EuclideanDistance, for example, returns the squared distance.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public interface DistanceMeasure {
11 |
12 |     /**
13 |      * Measure the distance between two instances
14 |      * @param va the first instance
15 |      * @param vb the second instance
16 |      * @return the distance between them
17 |      */
18 |     double value(Instance va, Instance vb);
19 |
20 | }
21 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/ErrorMeasure.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 | /**
4 |  * Measures how far a produced output is from an example.
5 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
6 |  * @version 1.0
7 |  */
8 | public interface ErrorMeasure {
9 |
10 |     /**
11 |      * Measure the error for the given output and example
12 |      * @param output the output that was produced
13 |      * @param example the example the output is compared against
14 |      * @return the error
15 |      */
16 |     double value(Instance output, Instance example);
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/EuclideanDistance.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 |
4 | /**
5 |  * The standard Euclidean distance measure, returned SQUARED: value() is the sum of squared
6 |  * coordinate differences with no square root taken. This preserves nearest-neighbour ordering,
7 |  * but callers that need the true Euclidean distance must take Math.sqrt of the result.
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public class EuclideanDistance extends AbstractDistanceMeasure {
12 |
13 |     /**
14 |      * Compute the squared Euclidean distance between two instances over the first
15 |      * va.size() continuous values of each.
16 |      * @param va the first instance
17 |      * @param vb the second instance, at least as long as va
18 |      * @return the sum over i of (va_i - vb_i)^2
19 |      * @see shared.DistanceMeasure#value(shared.Instance, shared.Instance)
20 |      */
21 |     public double value(Instance va, Instance vb) {
22 |         double sum = 0;
23 |         for (int i = 0; i < va.size(); i++) {
24 |             double diff = va.getContinuous(i) - vb.getContinuous(i);
25 |             sum += diff * diff;
26 |         }
27 |         return sum;
28 |     }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/FixedIterationTrainer.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 | /**
4 |  * A trainer that runs an inner trainer a fixed number of times per train() call.
5 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
6 |  * @version 1.0
7 |  */
8 | public class FixedIterationTrainer implements Trainer {
9 |
10 |     /** The trainer being repeated */
11 |     private Trainer trainer;
12 |
13 |     /** How many times train() invokes the inner trainer */
14 |     private int iterations;
15 |
16 |     /**
17 |      * Make a new fixed iterations trainer
18 |      * @param t the trainer to repeat
19 |      * @param iter the number of iterations
20 |      */
21 |     public FixedIterationTrainer(Trainer t, int iter) {
22 |         trainer = t;
23 |         iterations = iter;
24 |     }
25 |
26 |     /**
27 |      * Call the inner trainer {@code iterations} times.
28 |      * @return the mean of the values the inner trainer returned (NaN when iterations is 0)
29 |      * @see shared.Trainer#train()
30 |      */
31 |     public double train() {
32 |         double total = 0;
33 |         int remaining = iterations;
34 |         while (remaining-- > 0) {
35 |             total += trainer.train();
36 |         }
37 |         return total / iterations;
38 |     }
39 |
40 | }
41 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/GradientErrorMeasure.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 |
4 |
5 | /**
6 |  * An error measure that is differentiable with respect to the network outputs.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public interface GradientErrorMeasure extends ErrorMeasure {
11 |
12 |     /**
13 |      * Compute the derivatives of the error with respect to the outputs
14 |      * @param output the outputs of the network
15 |      * @param example the example providing the targets the outputs are compared against
16 |      * @return the error derivatives
17 |      */
18 |     double[] gradient(Instance output, Instance example);
19 |
20 | }
21 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/HammingDistance.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 |
4 | /**
5 |  * The standard Hamming distance: the number of positions whose discrete values differ.
6 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
7 |  * @version 1.0
8 |  */
9 | public class HammingDistance extends AbstractDistanceMeasure {
10 |
11 |     /**
12 |      * Count the positions i < va.size() at which va and vb hold different discrete values.
13 |      * @param va the first instance
14 |      * @param vb the second instance, at least as long as va
15 |      * @return the number of mismatching positions
16 |      * @see shared.DistanceMeasure#value(shared.Instance, shared.Instance)
17 |      */
18 |     public double value(Instance va, Instance vb) {
19 |         double mismatches = 0;
20 |         for (int i = 0; i < va.size(); i++) {
21 |             boolean same = va.getDiscrete(i) == vb.getDiscrete(i);
22 |             mismatches += same ? 0 : 1;
23 |         }
24 |         return mismatches;
25 |     }
26 |
27 | }
28 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/MixedDistanceMeasure.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 |
4 | /**
5 |  * A distance measure for instances whose attributes have mixed types: continuous
6 |  * attributes contribute their squared difference; all other attributes contribute
7 |  * 1 when their discrete values differ and 0 otherwise.
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public class MixedDistanceMeasure extends AbstractDistanceMeasure {
12 |
13 |     /** The type of each attribute, indexed like the instance values */
14 |     private AttributeType[] types;
15 |
16 |     /**
17 |      * Make a new mixed distance measure
18 |      * @param types the type of each attribute; must have at least as many entries as the
19 |      *        instances passed to value()
20 |      */
21 |     public MixedDistanceMeasure(AttributeType[] types) {
22 |         this.types = types;
23 |     }
24 |
25 |     /**
26 |      * Sum the per-attribute contributions over the first va.size() attributes.
27 |      * @param va the first instance
28 |      * @param vb the second instance, at least as long as va
29 |      * @return the mixed distance (squared for the continuous part)
30 |      * @see shared.DistanceMeasure#value(shared.Instance, shared.Instance)
31 |      */
32 |     public double value(Instance va, Instance vb) {
33 |         double distance = 0;
34 |         for (int i = 0; i < va.size(); i++) {
35 |             // Identity comparison is safe: AttributeType has a private constructor and only
36 |             // exposes shared constants.
37 |             if (types[i] == AttributeType.CONTINUOUS) {
38 |                 double diff = va.getContinuous(i) - vb.getContinuous(i);
39 |                 distance += diff * diff;
40 |             } else if (va.getDiscrete(i) != vb.getDiscrete(i)) {
41 |                 distance += 1;
42 |             }
43 |         }
44 |         return distance;
45 |     }
46 |
47 | }
48 |
49 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/OccasionalPrinter.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 | /**
4 |  * An occasional printer prints a trainer every once in a while: on every
5 |  * iterationsPerPrint-th call to train() (starting with the first) it prints the wrapped
6 |  * trainer's toString() before delegating to it.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public class OccasionalPrinter implements Trainer {
11 |     /** The trainer being trained */
12 |     private Trainer trainer;
13 |     /** How many iterations to go between prints */
14 |     private int iterationsPerPrint;
15 |     /** The number of train() calls made so far */
16 |     private int iteration;
17 |     /**
18 |      * Make a new occasional printer
19 |      * @param iterationsPerPrint the number of iterations per print, must be positive
20 |      * @param t the trainer
21 |      * @throws IllegalArgumentException if iterationsPerPrint is not positive (zero would
22 |      *         otherwise surface later as an ArithmeticException from train())
23 |      */
24 |     public OccasionalPrinter(int iterationsPerPrint, Trainer t) {
25 |         if (iterationsPerPrint <= 0) {
26 |             throw new IllegalArgumentException(
27 |                     "iterationsPerPrint must be positive, got " + iterationsPerPrint);
28 |         }
29 |         this.iterationsPerPrint = iterationsPerPrint;
30 |         this.trainer = t;
31 |     }
32 |
33 |     /**
34 |      * Print the wrapped trainer if this call is a multiple of iterationsPerPrint, then train it.
35 |      * @return the value returned by the wrapped trainer's train()
36 |      * @see shared.Trainer#train()
37 |      */
38 |     public double train() {
39 |         if (iteration % iterationsPerPrint == 0) {
40 |             System.out.println(trainer);
41 |         }
42 |         iteration++;
43 |         return trainer.train();
44 |     }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/SumOfSquaresError.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 |
4 |
5 | /**
6 |  * Standard error measure, suitable for use with linear output networks for regression,
7 |  * sigmoid output networks for single class probability, and soft max networks for multi
8 |  * class probabilities. The error is half the weighted sum of squared differences between
9 |  * the output and the example's label.
10 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
11 |  * @version 1.0
12 |  */
13 | public class SumOfSquaresError extends AbstractErrorMeasure
14 |         implements GradientErrorMeasure {
15 |
16 |     /**
17 |      * Compute 0.5 * sum_i (output_i - label_i)^2 * w, where label is example.getLabel()
18 |      * and w is example.getWeight().
19 |      * @see shared.ErrorMeasure#value(shared.Instance, shared.Instance)
20 |      */
21 |     public double value(Instance output, Instance example) {
22 |         Instance target = example.getLabel();
23 |         double total = 0;
24 |         for (int k = 0; k < output.size(); k++) {
25 |             double residual = output.getContinuous(k) - target.getContinuous(k);
26 |             total += residual * residual * example.getWeight();
27 |         }
28 |         return .5 * total;
29 |     }
30 |
31 |     /**
32 |      * Compute the derivative of value() with respect to each output value:
33 |      * (output_i - label_i) * w.
34 |      * @see shared.GradientErrorMeasure#gradient(shared.Instance, shared.Instance)
35 |      */
36 |     public double[] gradient(Instance output, Instance example) {
37 |         double[] grad = new double[output.size()];
38 |         Instance target = example.getLabel();
39 |         for (int k = 0; k < output.size(); k++) {
40 |             double residual = output.getContinuous(k) - target.getContinuous(k);
41 |             grad[k] = residual * example.getWeight();
42 |         }
43 |         return grad;
44 |     }
45 |
46 | }
47 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/Trainer.java:
--------------------------------------------------------------------------------
1 | package shared;
2 |
3 | import java.io.Serializable;
4 |
5 | /**
6 |  * Something that can be trained, one train() call at a time.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public interface Trainer extends Serializable {
11 |     /**
12 |      * Perform a round of training
13 |      * @return the error
14 |      */
15 |     double train();
16 | }
17 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/DataSetFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import shared.DataSet;
4 |
5 | /**
6 |  * A filter that transforms a data set in place.
7 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
8 |  * @version 1.0
9 |  */
10 | public interface DataSetFilter {
11 |     /**
12 |      * Apply the filter to the given data set
13 |      * @param dataSet the data set to operate on (modified in place)
14 |      */
15 |     void filter(DataSet dataSet);
16 |
17 | }
18 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/DiscreteDistributionFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import dist.ConditionalDistribution;
4 | import dist.Distribution;
5 | import dist.DiscreteDistribution;
6 | import shared.DataSet;
7 | import shared.Instance;
8 | import util.linalg.DenseVector;
9 |
10 | /**
11 |  * A filter that replaces each instance's data with the class probabilities that a
12 |  * classifier assigns to it.
13 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
14 |  * @version 1.0
15 |  */
16 | public class DiscreteDistributionFilter implements DataSetFilter {
17 |     /** The classifier in use; its distributions are expected to be DiscreteDistributions */
18 |     private ConditionalDistribution classifier;
19 |
20 |     /**
21 |      * Make a new classifier filter
22 |      * @param classifier the classifier; filter() casts each distribution it returns to
23 |      *        DiscreteDistribution, so any other kind causes a ClassCastException
24 |      */
25 |     public DiscreteDistributionFilter(ConditionalDistribution classifier) {
26 |         this.classifier = classifier;
27 |     }
28 |
29 |     /**
30 |      * Overwrite every instance's data with the probabilities of its predicted distribution.
31 |      * @see shared.filt.DataSetFilter#filter(shared.DataSet)
32 |      */
33 |     public void filter(DataSet dataSet) {
34 |         for (int i = 0; i < dataSet.size(); i++) {
35 |             Instance instance = dataSet.get(i);
36 |             Distribution predicted = classifier.distributionFor(instance);
37 |             DiscreteDistribution discrete = (DiscreteDistribution) predicted;
38 |             instance.setData(new DenseVector(discrete.getProbabilities()));
39 |         }
40 |     }
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/LabelFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import shared.DataSet;
4 |
5 | /**
6 | * A filter that applies a filter to the label data set
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class LabelFilter implements ReversibleFilter {
11 | /**
12 | * The filter to apply
13 | */
14 | private DataSetFilter filter;
15 |
16 | /**
17 | * @see shared.filt.DataSetFilter#filter(shared.DataSet)
18 | */
19 | public void filter(DataSet dataSet) {
20 | filter.filter(dataSet.getLabelDataSet());
21 | }
22 |
23 | /**
24 | * @see shared.filt.ReversibleFilter#reverse(shared.DataSet)
25 | */
26 | public void reverse(DataSet set) {
27 | ((ReversibleFilter) filter).reverse(set.getLabelDataSet());
28 | }
29 |
30 | }
31 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/LabelSelectFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import shared.DataSet;
4 | import shared.DataSetDescription;
5 | import shared.Instance;
6 | import util.linalg.Vector;
7 |
8 | /**
9 |  * A filter that selects a specified value as the label.
10 |  * This is useful for processing datasets with an extra
11 |  * attribute appended to the end (such as files Weka
12 |  * spits out with the cluster appended to each instance)
13 |  *
14 |  * @author Jesse Rosalia
15 |  */
16 | public class LabelSelectFilter implements DataSetFilter {
17 |     /**
18 |      * The index of the attribute that is moved out of the data and into the label
19 |      */
20 |     private int labelIndex;
21 |
22 |     /**
23 |      * Make a new label select filter
24 |      * @param labelIndex the index of the value to use as the label
25 |      */
26 |     public LabelSelectFilter(int labelIndex) {
27 |         this.labelIndex = labelIndex;
28 |     }
29 |
30 |     /**
31 |      * For every instance, remove the value at labelIndex from its data and make that value the
32 |      * instance's single-valued label; then rebuild the data set description.
33 |      * @see shared.filt.DataSetFilter#filter(shared.DataSet)
34 |      */
35 |     public void filter(DataSet dataSet) {
36 |         for (int i = 0; i < dataSet.size(); i++) {
37 |             Instance instance = dataSet.get(i);
38 |             // NOTE(review): get(0, size) presumably returns a copy of the full data vector, so
39 |             // removing the label entry below does not touch the original; confirm in Vector.
40 |             Vector input =
41 |                 instance.getData().get(0, instance.getData().size());
42 |             double output =
43 |                 instance.getData().get(this.labelIndex);
44 |             input = input.remove(this.labelIndex);
45 |             instance.setData(input);
46 |             instance.setLabel(new Instance(output));
47 |         }
48 |         // One attribute was removed from every instance, so recompute the description.
49 |         dataSet.setDescription(new DataSetDescription(dataSet));
50 |     }
51 |
52 | }
53 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/LabelSplitFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import shared.DataSet;
4 | import shared.DataSetDescription;
5 | import shared.Instance;
6 | import util.linalg.Vector;
7 |
8 | /**
9 |  * A filter that splits a data set into
10 |  * data and labels: the last labelCount values of
11 |  * each instance become its label
12 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
13 |  * @version 1.0
14 |  */
15 | public class LabelSplitFilter implements DataSetFilter {
16 |     /**
17 |      * The number of trailing values that form the label
18 |      */
19 |     private int labelCount;
20 |
21 |     /**
22 |      * Make a new label data split filter
23 |      * @param labelCount the number of label items
24 |      */
25 |     public LabelSplitFilter(int labelCount) {
26 |         this.labelCount = labelCount;
27 |     }
28 |
29 |     /**
30 |      * Make a new label split filter that uses the last value as the label
31 |      */
32 |     public LabelSplitFilter() {
33 |         this(1);
34 |     }
35 |
36 |     /**
37 |      * Split each instance's data into data and label. The split point is
38 |      * computed from the first instance and applied to every instance.
39 |      * An empty data set is left unchanged.
40 |      * @see shared.filt.DataSetFilter#filter(shared.DataSet)
41 |      */
42 |     public void filter(DataSet dataSet) {
43 |         if (dataSet.size() == 0) {
44 |             // nothing to split, and dataSet.get(0) below would fail
45 |             return;
46 |         }
47 |         int dataCount = dataSet.get(0).size() - labelCount;
48 |         for (int i = 0; i < dataSet.size(); i++) {
49 |             Instance instance = dataSet.get(i);
50 |             Vector input =
51 |                 instance.getData().get(0, dataCount);
52 |             Vector output =
53 |                 instance.getData().get(dataCount, instance.getData().size());
54 |             instance.setData(input);
55 |             instance.setLabel(new Instance(output));
56 |         }
57 |         dataSet.setDescription(new DataSetDescription(dataSet));
58 |     }
59 |
60 | }
61 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/RandomOrderFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import dist.Distribution;
4 |
5 | import shared.DataSet;
6 | import shared.Instance;
7 |
8 | /**
9 |  * A filter for randomizing the order of a data set in place, using a
10 |  * Fisher-Yates shuffle driven by the shared Distribution.random generator
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class RandomOrderFilter implements DataSetFilter {
15 |
16 |     /**
17 |      * Shuffle the instances of the data set in place
18 |      * @see shared.filt.DataSetFilter#filter(shared.DataSet)
19 |      */
20 |     public void filter(DataSet dataSet) {
21 |         int last = dataSet.size() - 1;
22 |         while (last > 0) {
23 |             // choose uniformly among the positions not yet fixed, [0, last]
24 |             int pick = Distribution.random.nextInt(last + 1);
25 |             Instance held = dataSet.get(last);
26 |             dataSet.set(last, dataSet.get(pick));
27 |             dataSet.set(pick, held);
28 |             last--;
29 |         }
30 |     }
31 |
32 | }
33 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/ReversibleFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import shared.DataSet;
4 |
5 | /**
6 |  * A reversible filter is a data set filter whose effect
7 |  * can be undone with reverse
8 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
9 |  * @version 1.0
10 |  */
11 | public interface ReversibleFilter extends DataSetFilter {
12 |     /**
13 |      * Undo this filter's effect on the given data set
14 |      * @param set the filtered data set to restore
15 |      */
16 |     void reverse(DataSet set);
17 |
18 | }
19 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/TestTrainSplitFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 |
6 | /**
7 |  * A filter that splits a data set, in its current order, into a
8 |  * training set holding the first pctTrain percent of the instances
9 |  * (rounded down) and a testing set holding the rest. The input data
10 |  * set itself is not modified.
11 |  */
12 | public class TestTrainSplitFilter implements DataSetFilter {
13 |
14 |     /** Percentage, from 0 to 100, of instances put in the training set */
15 |     private int pctTrain;
16 |     /** The training set built by the last call to filter */
17 |     private DataSet trainingSet;
18 |     /** The testing set built by the last call to filter */
19 |     private DataSet testingSet;
20 |
21 |     /**
22 |      * Make a new test/train split filter
23 |      * @param pctTrain A percentage from 0 to 100
24 |      * @throws IllegalArgumentException if pctTrain is outside [0, 100]
25 |      */
26 |     public TestTrainSplitFilter(int pctTrain) {
27 |         if (pctTrain < 0 || pctTrain > 100) {
28 |             throw new IllegalArgumentException(
29 |                 "pctTrain must be between 0 and 100, got " + pctTrain);
30 |         }
31 |         this.pctTrain = pctTrain;
32 |     }
33 |
34 |     @Override
35 |     public void filter(DataSet dataSet) {
36 |         int totalInstances = dataSet.getInstances().length;
37 |         // Exact integer arithmetic: with a double fraction, e.g. 100 * 0.29
38 |         // evaluates to 28.999..., which the int cast truncates to 28.
39 |         int trainInstances = (int) ((long) totalInstances * pctTrain / 100);
40 |         int testInstances = totalInstances - trainInstances;
41 |         Instance[] train = new Instance[trainInstances];
42 |         Instance[] test = new Instance[testInstances];
43 |         for (int ii = 0; ii < trainInstances; ii++) {
44 |             train[ii] = dataSet.get(ii);
45 |         }
46 |         for (int ii = trainInstances; ii < totalInstances; ii++) {
47 |             test[ii - trainInstances] = dataSet.get(ii);
48 |         }
49 |
50 |         this.trainingSet = new DataSet(train);
51 |         this.testingSet = new DataSet(test);
52 |     }
53 |
54 |     /**
55 |      * @return the training set from the last call to filter, or null before the first call
56 |      */
57 |     public DataSet getTrainingSet() {
58 |         return this.trainingSet;
59 |     }
60 |
61 |     /**
62 |      * @return the testing set from the last call to filter, or null before the first call
63 |      */
64 |     public DataSet getTestingSet() {
65 |         return this.testingSet;
66 |     }
67 | }
68 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/ica/ContrastFunction.java:
--------------------------------------------------------------------------------
1 | package shared.filt.ica;
2 |
3 | /**
4 |  * A contrast function for use by ICA
5 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
6 |  * @version 1.0
7 |  */
8 | public interface ContrastFunction {
9 |     /**
10 |      * Evaluate the first derivative of the contrast function
11 |      * on the given value
12 |      * @param value the value to evaluate the derivative on
13 |      * @return the evaluated derivative
14 |      */
15 |     public double g(double value);
16 |
17 |     /**
18 |      * Evaluate the second derivative of the contrast function
19 |      * (the derivative of g) on the given value
20 |      * @param value the value to evaluate the second derivative on
21 |      * @return the evaluated second derivative
22 |      */
23 |     public double gprime(double value);
24 | }
25 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/ica/HyperbolicTangentContrast.java:
--------------------------------------------------------------------------------
1 | package shared.filt.ica;
2 |
3 | /**
4 |  * A log hyperbolic cosine contrast function
5 |  * (the first derivative is hyperbolic tangent)
6 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
7 |  * @version 1.0
8 |  */
9 | public class HyperbolicTangentContrast implements ContrastFunction {
10 |
11 |     /**
12 |      * Evaluate tanh(value), the first derivative of log(cosh(value)).
13 |      * Math.tanh is used instead of (e^2x - 1) / (e^2x + 1), which loses
14 |      * precision near zero through cancellation and needs a special case
15 |      * when e^2x overflows.
16 |      * @see shared.filt.ica.ContrastFunction#g(double)
17 |      */
18 |     public double g(double value) {
19 |         return Math.tanh(value);
20 |     }
21 |
22 |     /**
23 |      * Evaluate 1 - tanh(value)^2, the second derivative of log(cosh(value))
24 |      * @see shared.filt.ica.ContrastFunction#gprime(double)
25 |      */
26 |     public double gprime(double value) {
27 |         double tanhvalue = g(value);
28 |         return 1 - tanhvalue * tanhvalue;
29 |     }
30 |
31 | }
32 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/filt/kFoldSplitFilter.java:
--------------------------------------------------------------------------------
1 | package shared.filt;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 |
6 | import java.util.ArrayList;
7 | import java.util.Random;
8 |
9 | /**
10 |  * A filter that supports k-fold splitting of a dataset for cross validation.
11 |  * The usable instances are shuffled and dealt into foldCount folds whose
12 |  * sizes differ by at most one, so no instance is dropped. The input data
13 |  * set is not modified.
14 |  * @author Daniel Cohen dcohen@gatech.edu
15 |  * @version 1.0
16 |  */
17 |
18 | public class kFoldSplitFilter implements DataSetFilter {
19 |     /** The number of folds to split the data into */
20 |     private int foldCount;
21 |     /** The folds built by the most recent call to filter */
22 |     private ArrayList<DataSet> folds;
23 |
24 |     /**
25 |      * Make a new k-fold split filter
26 |      * @param foldCount the number of folds
27 |      * @throws IllegalArgumentException if foldCount is not positive
28 |      */
29 |     public kFoldSplitFilter(int foldCount) {
30 |         if (foldCount <= 0) {
31 |             throw new IllegalArgumentException(
32 |                 "foldCount must be positive, got " + foldCount);
33 |         }
34 |         this.foldCount = foldCount;
35 |         this.folds = new ArrayList<>();
36 |     }
37 |
38 |     /**
39 |      * Split the data into folds, replacing the folds of any earlier call.
40 |      * Null instances and instances without data are skipped.
41 |      * @param data the data set to split; it is left unchanged
42 |      */
43 |     public void filter(DataSet data) {
44 |         folds.clear();
45 |
46 |         // Collect the positions of all usable instances.
47 |         int[] order = new int[data.size()];
48 |         int usable = 0;
49 |         for (int i = 0; i < data.size(); i++) {
50 |             Instance instance = data.get(i);
51 |             if (instance != null && instance.getData() != null) {
52 |                 order[usable++] = i;
53 |             }
54 |         }
55 |
56 |         // Shuffle the positions (Fisher-Yates) rather than the data itself,
57 |         // so the caller's data set is never modified.
58 |         Random rand = new Random();
59 |         for (int i = usable - 1; i > 0; i--) {
60 |             int j = rand.nextInt(i + 1);
61 |             int swap = order[i];
62 |             order[i] = order[j];
63 |             order[j] = swap;
64 |         }
65 |
66 |         // Deal the shuffled positions into folds; the first
67 |         // (usable % foldCount) folds get one extra instance.
68 |         int baseSize = usable / foldCount;
69 |         int remainder = usable % foldCount;
70 |         int next = 0;
71 |         for (int currentFold = 0; currentFold < foldCount; currentFold++) {
72 |             int size = baseSize + (currentFold < remainder ? 1 : 0);
73 |             Instance[] members = new Instance[size];
74 |             for (int k = 0; k < size; k++) {
75 |                 members[k] = data.get(order[next++]);
76 |             }
77 |             this.folds.add(new DataSet(members, data.getDescription()));
78 |         }
79 |     }
80 |
81 |     /**
82 |      * Get every fold except the given one
83 |      * @param currentFold the fold to leave out (its first equal element is removed)
84 |      * @return a new list holding the remaining folds
85 |      */
86 |     public ArrayList<DataSet> getValidationFolds(DataSet currentFold) {
87 |         ArrayList<DataSet> result = new ArrayList<>(this.folds);
88 |         result.remove(currentFold);
89 |         return result;
90 |     }
91 |
92 |     /**
93 |      * @return the folds built by the most recent call to filter
94 |      */
95 |     public ArrayList<DataSet> getFolds() {
96 |         return this.folds;
97 |     }
98 |
99 |     /**
100 |      * @return the number of folds
101 |      */
102 |     public int getFoldCount() {
103 |         return this.foldCount;
104 |     }
105 | }
106 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/reader/CSVDataSetReader.java:
--------------------------------------------------------------------------------
1 | package shared.reader;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.FileReader;
5 | import java.util.ArrayList;
6 | import java.util.List;
7 | import java.util.regex.Pattern;
8 |
9 | import shared.DataSet;
10 | import shared.DataSetDescription;
11 | import shared.Instance;
12 | /**
13 |  * Class to read in data from a CSV file without a specified label.
14 |  * Values may be separated by commas and/or spaces; blank lines are skipped.
15 |  * @author Tim Swihart
16 |  * @date 2013-03-05
17 |  */
18 | public class CSVDataSetReader extends DataSetReader {
19 |
20 |     /**
21 |      * Make a new CSV data set reader
22 |      * @param file the path of the file to read
23 |      */
24 |     public CSVDataSetReader(String file) {
25 |         super(file);
26 |     }
27 |
28 |     /**
29 |      * Read each non-blank line of the file as one unlabeled instance
30 |      * @return the data set, with its description computed from the instances
31 |      * @throws Exception if the file cannot be read or a value is not a number
32 |      */
33 |     @Override
34 |     public DataSet read() throws Exception {
35 |         List<Instance> data = new ArrayList<>();
36 |         Pattern pattern = Pattern.compile("[ ,]+");
37 |         // try-with-resources closes the file even if a line fails to parse
38 |         try (BufferedReader br = new BufferedReader(new FileReader(file))) {
39 |             String line;
40 |             while ((line = br.readLine()) != null) {
41 |                 String trimmed = line.trim();
42 |                 if (trimmed.isEmpty()) {
43 |                     // e.g. blank lines at the end of the file; "" is not a number
44 |                     continue;
45 |                 }
46 |                 String[] split = pattern.split(trimmed);
47 |                 double[] input = new double[split.length];
48 |                 for (int i = 0; i < input.length; i++) {
49 |                     input[i] = Double.parseDouble(split[i]);
50 |                 }
51 |                 data.add(new Instance(input));
52 |             }
53 |         }
54 |         DataSet set = new DataSet(data.toArray(new Instance[0]));
55 |         set.setDescription(new DataSetDescription(set));
56 |         return set;
57 |     }
58 |
59 | }
60 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/reader/DataSetReader.java:
--------------------------------------------------------------------------------
1 | package shared.reader;
2 |
3 | import java.io.BufferedReader;
4 | import java.io.FileReader;
5 | import java.util.ArrayList;
6 | import java.util.List;
7 | import java.util.regex.Pattern;
8 |
9 | import shared.DataSet;
10 |
11 | /**
12 |  * Base class for readers that load a data set from a file.
13 |  * Subclasses implement read() for a particular file format.
14 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
15 |  * @version 1.0
16 |  */
17 | public abstract class DataSetReader {
18 |     /**
19 |      * The path of the file to read from
20 |      */
21 |     protected String file;
22 |
23 |     /**
24 |      * Make a new data set reader
25 |      * @param path the path of the file to read from
26 |      */
27 |     public DataSetReader(String path) {
28 |         file = path;
29 |     }
30 |
31 |     /**
32 |      * Read the file into a data set
33 |      * @return the data read from the file
34 |      * @throws Exception if the file cannot be read or parsed
35 |      */
36 |     public abstract DataSet read() throws Exception;
37 |
38 | }
39 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/test/ArffDataSetReaderTest.java:
--------------------------------------------------------------------------------
1 | package shared.test;
2 |
3 | import java.io.File;
4 |
5 | import shared.DataSet;
6 | import shared.DataSetDescription;
7 | import shared.reader.ArffDataSetReader;
8 | import shared.reader.DataSetReader;
9 | import shared.filt.ContinuousToDiscreteFilter;
10 | import shared.filt.LabelSplitFilter;
11 | import shared.reader.DataSetLabelBinarySeperator;
12 |
13 | /**
14 |  * A manual test that reads abalone.arff, splits off the label, converts
15 |  * continuous values to discrete ones, separates the labels, and prints
16 |  * the result
17 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
18 |  * @version 1.0
19 |  */
20 | public class ArffDataSetReaderTest {
21 |     /**
22 |      * The test main
23 |      * @param args ignored parameters
24 |      */
25 |     public static void main(String[] args) throws Exception {
26 |         String path = new File("").getAbsolutePath() + "/src/shared/test/abalone.arff";
27 |         DataSetReader reader = new ArffDataSetReader(path);
28 |         // read in the raw data
29 |         DataSet dataSet = reader.read();
30 |         // use the last value of each instance as its label
31 |         new LabelSplitFilter().filter(dataSet);
32 |         new ContinuousToDiscreteFilter(10).filter(dataSet);
33 |         DataSetLabelBinarySeperator.seperateLabels(dataSet);
34 |         System.out.println(dataSet);
35 |         System.out.println(new DataSetDescription(dataSet));
36 |     }
37 | }
38 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/test/CSVDataSetReaderTest.java:
--------------------------------------------------------------------------------
1 | package shared.test;
2 |
3 | import java.io.File;
4 |
5 | import shared.DataSet;
6 | import shared.DataSetDescription;
7 | import shared.reader.CSVDataSetReader;
8 | import shared.reader.DataSetReader;
9 | import shared.filt.ContinuousToDiscreteFilter;
10 | import shared.filt.LabelSplitFilter;
11 |
12 | /**
13 |  * A manual test that reads abalone.data as CSV, splits off the label,
14 |  * converts continuous values to discrete ones, and prints the result
15 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
16 |  * @version 1.0
17 |  */
18 | public class CSVDataSetReaderTest {
19 |     /**
20 |      * The test main
21 |      * @param args ignored parameters
22 |      */
23 |     public static void main(String[] args) throws Exception {
24 |         String path = new File("").getAbsolutePath() + "/src/shared/test/abalone.data";
25 |         DataSetReader reader = new CSVDataSetReader(path);
26 |         // read in the raw data
27 |         DataSet dataSet = reader.read();
28 |         // use the last value of each instance as its label
29 |         new LabelSplitFilter().filter(dataSet);
30 |         new ContinuousToDiscreteFilter(10).filter(dataSet);
31 |         System.out.println(dataSet);
32 |         System.out.println(new DataSetDescription(dataSet));
33 |     }
34 | }
35 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/test/IndepenentComponentAnalysisTest.java:
--------------------------------------------------------------------------------
1 | package shared.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import shared.filt.IndependentComponentAnalysis;
6 | import util.linalg.Matrix;
7 | import util.linalg.RectangularMatrix;
8 |
9 | /**
10 |  * A manual test of independent component analysis: two source signals,
11 |  * a sine wave and uniform random noise, are mixed by a fixed 2x2 matrix
12 |  * and ICA is then applied to the mixture
13 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
14 |  * @version 1.0
15 |  */
16 | public class IndepenentComponentAnalysisTest {
17 |
18 |     /**
19 |      * The test main
20 |      * @param args ignored
21 |      */
22 |     public static void main(String[] args) {
23 |         Instance[] sources = new Instance[100];
24 |         for (int t = 0; t < sources.length; t++) {
25 |             double wave = Math.sin(t / 2.0);
26 |             double noise = (Math.random() - .5) * 2; // uniform in [-1, 1)
27 |             sources[t] = new Instance(new double[] {wave, noise});
28 |         }
29 |         DataSet set = new DataSet(sources);
30 |         // the data printed here is the unmixed sources
31 |         System.out.println("Before randomizing");
32 |         System.out.println(set);
33 |         Matrix mixing = new RectangularMatrix(new double[][] {{.6, .6}, {.4, .6}});
34 |         for (int k = 0; k < set.size(); k++) {
35 |             Instance mixed = set.get(k);
36 |             mixed.setData(mixing.times(mixed.getData()));
37 |         }
38 |         System.out.println("Before ICA");
39 |         System.out.println(set);
40 |         IndependentComponentAnalysis ica = new IndependentComponentAnalysis(set, 1);
41 |         ica.filter(set);
42 |         System.out.println("After ICA");
43 |         System.out.println(set);
44 |     }
45 |
46 | }
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/test/InsignificantComponentAnalysisTest.java:
--------------------------------------------------------------------------------
1 | package shared.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import shared.filt.InsignificantComponentAnalysis;
6 | import util.linalg.Matrix;
7 |
8 | /**
9 |  * A manual test of insignificant component analysis on a small 0/1 data
10 |  * set: it prints the eigenvalues and projection, filters the data, then
11 |  * maps it back with the transposed projection plus the mean
12 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
13 |  * @version 1.0
14 |  */
15 | public class InsignificantComponentAnalysisTest {
16 |
17 |     /**
18 |      * The test main
19 |      * @param args ignored
20 |      */
21 |     public static void main(String[] args) {
22 |         Instance[] instances = {
23 |             new Instance(new double[] {1, 1, 0, 0, 0, 0, 0, 0}),
24 |             new Instance(new double[] {0, 0, 1, 1, 1, 0, 0, 0}),
25 |             new Instance(new double[] {0, 0, 0, 0, 1, 1, 1, 1}),
26 |             new Instance(new double[] {1, 0, 1, 0, 1, 0, 1, 0}),
27 |             new Instance(new double[] {1, 1, 0, 0, 1, 1, 0, 0}),
28 |         };
29 |         DataSet set = new DataSet(instances);
30 |         System.out.println("Before ICA");
31 |         System.out.println(set);
32 |         InsignificantComponentAnalysis analysis = new InsignificantComponentAnalysis(set);
33 |         System.out.println(analysis.getEigenValues());
34 |         System.out.println(analysis.getProjection().transpose());
35 |         analysis.filter(set);
36 |         System.out.println("After ICA");
37 |         System.out.println(set);
38 |         System.out.println(analysis.getProjection().transpose());
39 |         // reconstruct each instance as transpose(projection) * data + mean
40 |         Matrix reverse = analysis.getProjection().transpose();
41 |         for (int idx = 0; idx < set.size(); idx++) {
42 |             Instance current = set.get(idx);
43 |             current.setData(reverse.times(current.getData()).plus(analysis.getMean()));
44 |         }
45 |         System.out.println("After reconstructing");
46 |         System.out.println(set);
47 |     }
48 |
49 | }
50 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/test/LabelSelectFilterTest.java:
--------------------------------------------------------------------------------
1 | package shared.test;
2 |
3 | import java.io.File;
4 |
5 | import shared.DataSet;
6 | import shared.DataSetDescription;
7 | import shared.filt.ContinuousToDiscreteFilter;
8 | import shared.filt.LabelSelectFilter;
9 | import shared.filt.LabelSplitFilter;
10 | import shared.reader.ArffDataSetReader;
11 | import shared.reader.DataSetLabelBinarySeperator;
12 | import shared.reader.DataSetReader;
13 |
14 | /**
15 |  * A manual test that reads abalone.arff, moves the value at index 1 into
16 |  * the label, converts continuous values to discrete ones, and prints the
17 |  * result
18 |  */
19 | public class LabelSelectFilterTest {
20 |     /**
21 |      * The test main
22 |      * @param args ignored parameters
23 |      */
24 |     public static void main(String[] args) throws Exception {
25 |         String path = new File("").getAbsolutePath() + "/src/shared/test/abalone.arff";
26 |         DataSetReader reader = new ArffDataSetReader(path);
27 |         // read in the raw data
28 |         DataSet dataSet = reader.read();
29 |         // move the value at index 1 out of the data and into the label
30 |         new LabelSelectFilter(1).filter(dataSet);
31 |         new ContinuousToDiscreteFilter(10).filter(dataSet);
32 |         System.out.println(dataSet);
33 |         System.out.println(new DataSetDescription(dataSet));
34 |     }
35 | }
36 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/test/LinearDiscriminantAnalysisTest.java:
--------------------------------------------------------------------------------
1 | package shared.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import shared.filt.LinearDiscriminantAnalysis;
6 | import util.linalg.DenseVector;
7 |
8 | /**
9 |  * A manual test of linear discriminant analysis on five labeled instances
10 |  * from two classes (labels 0 and 1), including reversing the filter
11 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
12 |  * @version 1.0
13 |  */
14 | public class LinearDiscriminantAnalysisTest {
15 |
16 |     /**
17 |      * The test main
18 |      * @param args ignored
19 |      */
20 |     public static void main(String[] args) {
21 |         Instance[] instances = {
22 |             labeled(new double[] {100, 1, 0, 0, 0, 0, 0, 0}, 1),
23 |             labeled(new double[] {0, 0, 10, 10, 100, 0, 0, 0}, 0),
24 |             labeled(new double[] {0, 0, 0, 0, 1, 1, 10, 10}, 0),
25 |             labeled(new double[] {100, 0, 10, 0, 1, 0, 1, 0}, 1),
26 |             labeled(new double[] {100, 10, 0, 0, 10, 1, 0, 0}, 1),
27 |         };
28 |         DataSet set = new DataSet(instances);
29 |         System.out.println("Before LDA");
30 |         System.out.println(set);
31 |         LinearDiscriminantAnalysis lda = new LinearDiscriminantAnalysis(set);
32 |         lda.filter(set);
33 |         System.out.println(lda.getProjection());
34 |         System.out.println("After LDA");
35 |         System.out.println(set);
36 |         lda.reverse(set);
37 |         System.out.println("After reconstructing");
38 |         System.out.println(set);
39 |     }
40 |
41 |     /**
42 |      * Make an instance with the given data values and integer label
43 |      * @param values the data values
44 |      * @param label the label
45 |      * @return the labeled instance
46 |      */
47 |     private static Instance labeled(double[] values, int label) {
48 |         return new Instance(new DenseVector(values), new Instance(label));
49 |     }
50 |
51 | }
52 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/test/PrincipalComponentAnalysisTest.java:
--------------------------------------------------------------------------------
1 | package shared.test;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 | import shared.filt.PrincipalComponentAnalysis;
6 | import util.linalg.Matrix;
7 |
8 | /**
9 |  * A manual test of principal component analysis on a small 0/1 data set:
10 |  * it prints the eigenvalues and projection, filters the data, then maps it
11 |  * back with the transposed projection plus the mean
12 |  * @author Andrew Guillory gtg008g@mail.gatech.edu
13 |  * @version 1.0
14 |  */
15 | public class PrincipalComponentAnalysisTest {
16 |
17 |     /**
18 |      * The test main
19 |      * @param args ignored
20 |      */
21 |     public static void main(String[] args) {
22 |         Instance[] instances = {
23 |             new Instance(new double[] {1, 1, 0, 0, 0, 0, 0, 0}),
24 |             new Instance(new double[] {0, 0, 1, 1, 1, 0, 0, 0}),
25 |             new Instance(new double[] {0, 0, 0, 0, 1, 1, 1, 1}),
26 |             new Instance(new double[] {1, 0, 1, 0, 1, 0, 1, 0}),
27 |             new Instance(new double[] {1, 1, 0, 0, 1, 1, 0, 0}),
28 |         };
29 |         DataSet set = new DataSet(instances);
30 |         System.out.println("Before PCA");
31 |         System.out.println(set);
32 |         PrincipalComponentAnalysis pca = new PrincipalComponentAnalysis(set);
33 |         System.out.println(pca.getEigenValues());
34 |         System.out.println(pca.getProjection().transpose());
35 |         pca.filter(set);
36 |         System.out.println("After PCA");
37 |         System.out.println(set);
38 |         // reconstruct each instance as transpose(projection) * data + mean
39 |         Matrix reverse = pca.getProjection().transpose();
40 |         for (int idx = 0; idx < set.size(); idx++) {
41 |             Instance current = set.get(idx);
42 |             current.setData(reverse.times(current.getData()).plus(pca.getMean()));
43 |         }
44 |         System.out.println("After reconstructing");
45 |         System.out.println(set);
46 |     }
47 |
48 | }
49 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/test/abalone_notes.txt:
--------------------------------------------------------------------------------
1 | abalone.data notes are in abalone.names
2 |
3 | The first attribute of each instance in abalone.data is the sex, which has been converted from a string code to an integer as follows:
4 | - M -> 1
5 | - F -> -1
6 | - I -> 0
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/tester/AccuracyTestMetric.java:
--------------------------------------------------------------------------------
1 | package shared.tester;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * A test metric for accuracy. This metric reports the percentage of results
7 |  * that were correct and incorrect over a test run.
8 |  *
9 |  * @author Jesse Rosalia
10 |  * @date 2013-03-05
11 |  */
12 | public class AccuracyTestMetric extends TestMetric {
13 |
14 |     /** Number of results added so far */
15 |     private int count;
16 |     /** Number of those results whose values were all correct */
17 |     private int countCorrect;
18 |
19 |     @Override
20 |     public void addResult(Instance expected, Instance actual) {
21 |         Comparison comparison = new Comparison(expected, actual);
22 |         count++;
23 |         if (comparison.isAllCorrect()) {
24 |             countCorrect++;
25 |         }
26 |     }
27 |
28 |     /**
29 |      * @return the fraction of results that were correct, or 1 when no
30 |      *         results have been added
31 |      */
32 |     public double getPctCorrect() {
33 |         if (count <= 0) {
34 |             return 1;
35 |         }
36 |         return (double) countCorrect / count;
37 |     }
38 |
39 |     @Override
40 |     public void printResults() {
41 |         if (count <= 0) {
42 |             System.out.println("No results added.");
43 |             return;
44 |         }
45 |         double pctCorrect = getPctCorrect();
46 |         System.out.println(String.format("Correctly Classified Instances: %.02f%%", 100 * pctCorrect));
47 |         System.out.println(String.format("Incorrectly Classified Instances: %.02f%%", 100 * (1 - pctCorrect)));
48 |     }
49 | }
50 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/tester/NeuralNetworkTester.java:
--------------------------------------------------------------------------------
1 | package shared.tester;
2 |
3 | import shared.Instance;
4 | import shared.reader.DataSetLabelBinarySeperator;
5 | import func.nn.NeuralNetwork;
6 |
7 | /**
8 |  * A tester for neural networks. This will run each instance
9 |  * through the network and report the results to any test metrics
10 |  * specified at instantiation.
11 |  *
12 |  * @author Jesse Rosalia (https://www.github.com/theJenix)
13 |  * @date 2013-03-05
14 |  */
15 | public class NeuralNetworkTester implements Tester {
16 |
17 |     /** The network under test */
18 |     private NeuralNetwork network;
19 |     /** The metrics every result is reported to */
20 |     private TestMetric[] metrics;
21 |
22 |     /**
23 |      * Make a new neural network tester
24 |      * @param network the network to test
25 |      * @param metrics the metrics to report results to
26 |      */
27 |     public NeuralNetworkTester(NeuralNetwork network, TestMetric ... metrics) {
28 |         this.network = network;
29 |         this.metrics = metrics;
30 |     }
31 |
32 |     @Override
33 |     public void test(Instance[] instances) {
34 |         for (Instance instance : instances) {
35 |             // feed the instance data forward through the network
36 |             network.setInputValues(instance.getData());
37 |             network.run();
38 |
39 |             Instance expected = instance.getLabel();
40 |             Instance actual = new Instance(network.getOutputValues());
41 |
42 |             // Collapse the values for statistics reporting.
43 |             // NOTE: assumes discrete labels, with n output nodes for n
44 |             // potential labels, and an activation function that outputs
45 |             // values between 0 and 1.
46 |             Instance expectedOne = DataSetLabelBinarySeperator.combineLabels(expected);
47 |             Instance actualOne = DataSetLabelBinarySeperator.combineLabels(actual);
48 |
49 |             for (TestMetric metric : metrics) {
50 |                 metric.addResult(expectedOne, actualOne);
51 |             }
52 |         }
53 |     }
54 | }
55 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/tester/PrecisionTestMetric.java:
--------------------------------------------------------------------------------
1 | package shared.tester;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * A test metric for precision.
7 |  * Precision is defined as (true positives) / (true positives + false positives).
8 |  * Only the first value of the label is used to determine true/false.
9 |  *
10 |  * @author shashir
11 |  * @date 2014-03-23
12 |  *
13 |  */
14 | public class PrecisionTestMetric extends TestMetric {
15 |
16 |     private int truePositives;
17 |     private int falsePositives;
18 |     private int totalCandidatePositives;
19 |
20 |     @Override
21 |     public void addResult(Instance target, Instance candidate) {
22 |         // Sanity check.
23 |         Comparison comparison = new Comparison(candidate, target);
24 |
25 |         boolean predictedPositive =
26 |             comparison.compare(candidate.getLabel().getContinuous(), 1.0) == 0;
27 |         boolean actualPositive =
28 |             comparison.compare(target.getLabel().getContinuous(), 1.0) == 0;
29 |         if (predictedPositive) {
30 |             if (actualPositive) {
31 |                 truePositives++;
32 |             } else {
33 |                 falsePositives++;
34 |             }
35 |         }
36 |         // precision's denominator: every result predicted positive
37 |         totalCandidatePositives = truePositives + falsePositives;
38 |     }
39 |
40 |     /**
41 |      * @return true positives over predicted positives, or 1 when nothing
42 |      *         has been predicted positive
43 |      */
44 |     public double getPctPrecision() {
45 |         if (totalCandidatePositives <= 0) {
46 |             return 1;
47 |         }
48 |         return (double) truePositives / totalCandidatePositives;
49 |     }
50 |
51 |     @Override
52 |     public void printResults() {
53 |         if (totalCandidatePositives <= 0) {
54 |             System.out.println("No results added.");
55 |             return;
56 |         }
57 |         System.out.printf(
58 |             "Precision (ratio of true positives to predicted positives): %.02f%%\n",
59 |             100 * getPctPrecision());
60 |     }
61 | }
62 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/tester/RawOutputTestMetric.java:
--------------------------------------------------------------------------------
1 | package shared.tester;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * A test metric that records the raw expected and actual values of every
7 |  * result, one line per result, and prints them all at the end.
8 |  */
9 | public class RawOutputTestMetric extends TestMetric {
10 |
11 |     /** Accumulated output, one line per result */
12 |     StringBuilder builder = new StringBuilder();
13 |
14 |     @Override
15 |     public void addResult(Instance expected, Instance actual) {
16 |         builder.append("Expected: ");
17 |         appendValues(expected);
18 |         builder.append(", Actual: ");
19 |         // actual is walked over its own size, so all of its values are
20 |         // reported even when it differs in size from expected
21 |         appendValues(actual);
22 |         builder.append("\n");
23 |     }
24 |
25 |     /**
26 |      * Append the continuous values of an instance as a comma separated list
27 |      * @param instance the instance whose values to append
28 |      */
29 |     private void appendValues(Instance instance) {
30 |         for (int ii = 0; ii < instance.size(); ii++) {
31 |             if (ii > 0) {
32 |                 builder.append(",");
33 |             }
34 |             builder.append(instance.getContinuous(ii));
35 |         }
36 |     }
37 |
38 |     @Override
39 |     public void printResults() {
40 |         System.out.println(builder.toString());
41 |     }
42 | }
43 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/tester/RecallTestMetric.java:
--------------------------------------------------------------------------------
1 | package shared.tester;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 |  * A test metric for recall.
7 |  * Recall is defined as (true positives) / (true positives + false negatives).
8 |  * Only the first value of the label is used to determine true/false.
9 |  *
10 |  * @author shashir
11 |  * @date 2014-03-23
12 |  *
13 |  */
14 | public class RecallTestMetric extends TestMetric {
15 |
16 |     private int truePositives;
17 |     private int falseNegatives;
18 |     private int totalTargetPositives;
19 |
20 |     @Override
21 |     public void addResult(Instance target, Instance candidate) {
22 |         // Sanity check.
23 |         Comparison comparison = new Comparison(candidate, target);
24 |
25 |         boolean predictedPositive =
26 |             comparison.compare(candidate.getLabel().getContinuous(), 1.0) == 0;
27 |         boolean actualPositive =
28 |             comparison.compare(target.getLabel().getContinuous(), 1.0) == 0;
29 |         if (actualPositive) {
30 |             if (predictedPositive) {
31 |                 truePositives++;
32 |             } else {
33 |                 falseNegatives++;
34 |             }
35 |         }
36 |         // recall's denominator: every result whose target is positive
37 |         totalTargetPositives = truePositives + falseNegatives;
38 |     }
39 |
40 |     /**
41 |      * @return true positives over target positives, or 1 when no target
42 |      *         has been positive
43 |      */
44 |     public double getPctRecall() {
45 |         if (totalTargetPositives <= 0) {
46 |             return 1;
47 |         }
48 |         return (double) truePositives / totalTargetPositives;
49 |     }
50 |
51 |     @Override
52 |     public void printResults() {
53 |         if (totalTargetPositives <= 0) {
54 |             System.out.println("No results added.");
55 |             return;
56 |         }
57 |         System.out.printf(
58 |             "Recall (ratio of true positives to target positives): %.02f%%\n",
59 |             100 * getPctRecall());
60 |     }
61 | }
62 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/tester/TestMetric.java:
--------------------------------------------------------------------------------
1 | package shared.tester;
2 |
3 | import shared.DataSet;
4 | import shared.Instance;
5 |
6 | /**
7 | * This interface defines an API for test metrics. Test metrics are notified by the Tester
8 | * after an instance is tested by the classifier. The test metrics are given a chance to compare
9 | * the results and accumulate statics or measurements of error. These data can then be printed
10 | * in a human readable format.
11 | *
12 | * @author Jesse Rosalia (https://www.github.com/theJenix)
13 | * @date 2013-03-05
14 | */
15 | public abstract class TestMetric {
16 |
17 | /**
18 | * Add a test result to the metric. The metric will compare the values and
19 | * accumulate what data it needs.
20 | *
21 | * @param expected The expected value (from the training set)
22 | * @param actual The value produced by the classifier.
23 | */
24 | public abstract void addResult(Instance expected, Instance actual);
25 |
26 | /**
27 | * Bulk add a test results to the metric. The metric will compare the values and
28 | * accumulate what data it needs.
29 | *
30 | * @param expected The expected values from the training set.
31 | * @param actual The values produced by the classifier.
32 | */
33 | public void addResult(DataSet expected, DataSet actual) {
34 | // Sanity check sizes.
35 | if (expected.size() != actual.size()) {
36 | throw new RuntimeException("Something is wrong. "
37 | + "Expected data set and actual data set sizes are not the same.");
38 | }
39 | for (int i = 0; i < expected.size(); i++) {
40 | this.addResult(expected.get(i), actual.get(i));
41 | }
42 | }
43 |
44 | /**
45 | * Print the values collected by this test metric.
46 | *
47 | */
48 | public abstract void printResults();
49 | }
50 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/tester/Tester.java:
--------------------------------------------------------------------------------
1 | package shared.tester;
2 |
3 | import shared.Instance;
4 |
5 | /**
6 | * This interface defines an API for testers, which run test data through
7 | * a classifier and accumulate the results. How the classifier, test metrics,
8 | * and supporting objects are injected is left up to the implementation.
9 | *
10 | * @author Jesse Rosalia (https://www.github.com/theJenix)
11 | * @date 2013-03-05
12 | */
13 | public interface Tester {
14 |
15 | /**
16 | * Test a classifier using the instances passed in. Note that these can
17 | * also be your training instances, to test with your training set.
18 | *
19 | * @param instances
20 | */
21 | public void test(Instance[] instances);
22 | }
23 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/writer/CSVWriter.java:
--------------------------------------------------------------------------------
1 | package shared.writer;
2 |
3 | import java.io.FileWriter;
4 | import java.io.IOException;
5 | import java.util.ArrayList;
6 | import java.util.Arrays;
7 | import java.util.List;
8 |
9 | /**
10 | * Write arbitrary data to a CSV file. This is used to write results out,
11 | * to be consumed by another program (GNUPlot, etc).
12 | *
13 | * @author Jesse Rosalia
14 | * @date 2013-03-07
15 | *
16 | */
17 | public class CSVWriter implements Writer {
18 |
19 | private String fileName;
20 | private List fields;
21 | private List buffer;
22 | private FileWriter fileWriter;
23 |
24 | public CSVWriter(String fileName, String[] fields) {
25 | this.fileName = fileName;
26 | this.fields = Arrays.asList(fields);
27 | this.buffer = new ArrayList();
28 | }
29 |
30 | @Override
31 | public void close() throws IOException {
32 | this.fileWriter.close();
33 | }
34 |
35 | @Override
36 | public void open() throws IOException {
37 | this.fileWriter = new FileWriter(fileName);
38 | writeRow(this.fields);
39 | }
40 |
41 | /**
42 | * @param toWrite
43 | * @throws IOException
44 | */
45 | private void writeRow(List toWrite) throws IOException {
46 | boolean addComma = false;
47 | for (String field : toWrite) {
48 | if (addComma) {
49 | this.fileWriter.append(",");
50 | }
51 | this.fileWriter.append(field);
52 | addComma = true;
53 | }
54 | this.fileWriter.append('\n');
55 | }
56 |
57 | @Override
58 | public void write(String str) throws IOException {
59 | this.buffer.add(str);
60 | }
61 |
62 | @Override
63 | public void nextRecord() throws IOException {
64 | writeRow(buffer);
65 | //clear the buffer for the next record
66 | buffer.clear();
67 | }
68 | }
69 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/shared/writer/Writer.java:
--------------------------------------------------------------------------------
1 | package shared.writer;
2 |
3 | import java.io.IOException;
4 |
/**
 * API for writers, which write results to a file of a particular format.
 *
 * A writer accumulates data points into the current record and can advance
 * to the next record. For example, a CSV writer treats each line as a
 * record: each write adds a comma-separated value to the line, and
 * nextRecord moves on to a new line.
 *
 * @author Jesse Rosalia
 * @date 2013-03-07
 */
public interface Writer {

    /**
     * Open the writer so it can accept data.
     *
     * @throws IOException if the destination cannot be opened
     */
    void open() throws IOException;

    /**
     * Add one data point to the current record.
     *
     * @param str the value to write
     * @throws IOException if writing fails
     */
    void write(String str) throws IOException;

    /**
     * Finish the current record and advance to the next one.
     *
     * @throws IOException if writing fails
     */
    void nextRecord() throws IOException;

    /**
     * Close the writer and flush its contents.
     *
     * @throws IOException if closing fails
     */
    void close() throws IOException;
}
48 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/TimeUtil.java:
--------------------------------------------------------------------------------
1 | package util;
2 |
/**
 * Helpers for preparing and presenting run-time measurements.
 *
 * @author Jesse Rosalia
 * @date 2013-03-07
 */
public class TimeUtil {

    /**
     * Format a duration as "MM:SS".
     *
     * Sub-second remainders are truncated. Minutes are not folded into
     * hours, so 7,200,000 ms prints as "120:00".
     *
     * @param time the duration in milliseconds (divided by 1000 to get seconds)
     * @return the formatted duration
     */
    public static String formatTime(long time) {
        long totalSeconds = time / 1000;
        return String.format("%02d:%02d", totalSeconds / 60, totalSeconds % 60);
    }
}
18 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/graph/DFSTree.java:
--------------------------------------------------------------------------------
1 | package util.graph;
2 |
3 | /**
4 | * A class for producing a tree traversal
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class DFSTree implements GraphTransformation {
9 | /**
10 | * Whether or not a node has been visited
11 | */
12 | boolean[] visited;
13 |
14 | /**
15 | * @see graph.GraphTransform#transform(graph.Graph)
16 | */
17 | public Graph transform(Graph g) {
18 | visited = new boolean[g.getNodeCount()];
19 | for (int i = 0; i < visited.length; i++) {
20 | visited[i] = false;
21 | }
22 | dfs(g.getNode(0));
23 | Tree result = new Tree(g.getNode(0));
24 | result.setNodes(g.getNodes());
25 | visited = null;
26 | return result;
27 | }
28 |
29 | /**
30 | * Perform a depth first search on the graph
31 | * @param g the graph to search
32 | */
33 | private void dfs(Node n) {
34 | visited[n.getLabel()] = true;
35 | for (int i = 0; i < n.getEdgeCount(); i++) {
36 | Edge edge = n.getEdge(i);
37 | Node other = edge.getOther(n);
38 | if (visited[other.getLabel()]) {
39 | n.removeEdge(i);
40 | i--;
41 | } else {
42 | dfs(other);
43 | }
44 | }
45 | }
46 |
47 | }
48 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/graph/Edge.java:
--------------------------------------------------------------------------------
1 | package util.graph;
2 |
3 | /**
4 | * An edge
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class Edge {
9 | /**
10 | * The in node
11 | */
12 | private Node a;
13 | /**
14 | * The out node
15 | */
16 | private Node b;
17 |
18 | /**
19 | * Get the in node
20 | * @return the in node
21 | */
22 | public Node getA() {
23 | return a;
24 | }
25 |
26 | /**
27 | * Get the out node
28 | * @return the out node
29 | */
30 | public Node getB() {
31 | return b;
32 | }
33 |
34 | /**
35 | * Get the other node
36 | * @param n the node
37 | * @return the other node
38 | */
39 | public Node getOther(Node n) {
40 | if (n == a) {
41 | return b;
42 | } else {
43 | return a;
44 | }
45 | }
46 |
47 | /**
48 | * Set the in node
49 | * @param node the in node
50 | */
51 | public void setA(Node node) {
52 | a = node;
53 | }
54 |
55 | /**
56 | * Set the out node
57 | * @param node the out node
58 | */
59 | public void setB(Node node) {
60 | b = node;
61 | }
62 |
63 | /**
64 | * @see java.lang.Object#toString()
65 | */
66 | public String toString() {
67 | return a.getLabel() + " -> " + b.getLabel();
68 | }
69 |
70 | }
71 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/graph/GraphTransformation.java:
--------------------------------------------------------------------------------
1 | package util.graph;
2 |
3 | /**
4 | * A graph transform is a tranformation of a graph
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public interface GraphTransformation {
9 |
10 | /**
11 | * Transform the given graph
12 | * @param g the graph to transform
13 | * @return the transformed graph
14 | */
15 | public Graph transform(Graph g);
16 |
17 | }
18 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/graph/Tree.java:
--------------------------------------------------------------------------------
1 | package util.graph;
2 |
3 | /**
4 | * A tree is a directed graph with a root
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class Tree extends Graph {
9 |
10 | /**
11 | * The root node
12 | */
13 | private Node root;
14 |
15 | /**
16 | * Make a rooted graph
17 | */
18 | public Tree() {
19 | }
20 |
21 | /**
22 | * Make a new tree
23 | * @param root the root
24 | */
25 | public Tree(Node root) {
26 | this.root = root;
27 | }
28 |
29 | /**
30 | * Get the root
31 | * @return the root
32 | */
33 | public Node getRoot() {
34 | return root;
35 | }
36 |
37 | /**
38 | * Set the root
39 | * @param node the root
40 | */
41 | public void setRoot(Node node) {
42 | root = node;
43 | }
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/graph/WeightedEdge.java:
--------------------------------------------------------------------------------
1 | package util.graph;
2 |
3 | /**
4 | * A class representing a weighted edge
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class WeightedEdge extends Edge implements Comparable {
9 |
10 | /**
11 | * The weight of the edge
12 | */
13 | private double weight;
14 |
15 | /**
16 | * Make a new weighted edge
17 | * @param weight the weight of the edge
18 | */
19 | public WeightedEdge(double weight) {
20 | this.weight = weight;
21 | }
22 |
23 | /**
24 | * Get the weight
25 | * @return the weight
26 | */
27 | public double getWeight() {
28 | return weight;
29 | }
30 |
31 | /**
32 | * Set the weight
33 | * @param d the new weight
34 | */
35 | public void setWeight(double d) {
36 | weight = d;
37 | }
38 |
39 | /**
40 | * @see java.lang.Comparable#compareTo(java.lang.Object)
41 | */
42 | public int compareTo(Object o) {
43 | WeightedEdge e = (WeightedEdge) o;
44 | if (getWeight() > e.getWeight()) {
45 | return 1;
46 | } else if (getWeight() < e.getWeight()) {
47 | return -1;
48 | } else {
49 | return 0;
50 | }
51 | }
52 |
53 | /**
54 | * @see java.lang.Object#toString()
55 | */
56 | public String toString() {
57 | return super.toString() + " x " + weight;
58 | }
59 |
60 | }
61 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/linalg/DenseVector.java:
--------------------------------------------------------------------------------
1 | package util.linalg;
2 |
3 | /**
4 | * An implementation of a vector that is dense
5 | * @author Andrew Guillory gtg008g@mail.gatech.edu
6 | * @version 1.0
7 | */
8 | public class DenseVector extends Vector {
9 |
10 | /**
11 | * The data
12 | */
13 | private double[] data;
14 |
15 | /**
16 | * Make a new dense vector
17 | * @param data the data
18 | */
19 | public DenseVector(double[] data) {
20 | this.data = data;
21 | }
22 |
23 | /**
24 | * Make a new dense vector of the given size
25 | * @param size the size to make it
26 | */
27 | public DenseVector(int size) {
28 | data = new double[size];
29 | }
30 |
31 | /**
32 | * @see linalg.Vector#size()
33 | */
34 | public int size() {
35 | return data.length;
36 | }
37 |
38 | /**
39 | * @see linalg.Vector#get(int)
40 | */
41 | public double get(int i) {
42 | return data[i];
43 | }
44 |
45 | /**
46 | * @see linalg.Vector#set(int, double)
47 | */
48 | public void set(int i, double value) {
49 | data[i] = value;
50 | }
51 |
52 | /**
53 | * Make an identity vector
54 | * @param i the dimension of identity
55 | * @param size the size of the vector
56 | * @return the identity vector
57 | */
58 | public static Vector e(int i, int size) {
59 | double[] result = new double[size];
60 | result[i] = 1;
61 | return new DenseVector(result);
62 | }
63 |
64 | /**
65 | * Get the identity 1 vector of the given size
66 | * @param size the size
67 | * @return the identity vector
68 | */
69 | public static Vector e(int size) {
70 | return e(0, size);
71 | }
72 |
73 |
74 | }
75 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/ABAGAILArraysTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.ABAGAILArrays;
4 |
5 | /**
6 | * A test main for the ABAGAIL utilities
7 | * @author Andrew Guillory gtg008g@mail.gatech.edu
8 | * @version 1.0
9 | */
10 | public class ABAGAILArraysTest {
11 |
12 | /**
13 | * Test main
14 | * @param args ignored
15 | */
16 | public static void main(String[] args) {
17 | double[] numbers = new double[100];
18 | for (int i = 0; i < numbers.length; i++) {
19 | numbers[i] = i;
20 | }
21 | System.out.println(ABAGAILArrays.randomizedSelect(numbers, 11));
22 | System.out.println(ABAGAILArrays.search(numbers, 21));
23 | double[] test = new double[] {.1, 1};
24 | System.out.println(ABAGAILArrays.search(test, .2));
25 | }
26 |
27 | }
28 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/BidiagonalDecompositionTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.BidiagonalDecomposition;
4 | import util.linalg.Matrix;
5 | import util.linalg.RectangularMatrix;
6 |
7 | /**
8 | * A test of the bidiagonal decomposition
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class BidiagonalDecompositionTest {
13 |
14 | /**
15 | * Test main, creates a matrix and decomposes and reconstructs
16 | * @param args ignored
17 | */
18 | public static void main(String[] args) {
19 | // double[][] a = {
20 | // { 1, 2, 3, 1 },
21 | // { 3, 5, 2, 4 },
22 | // { 1, 5, 2, 2 },
23 | // { 0, 5, 2, 3 },
24 | // { 1, 2, 1, 1 },
25 | // };
26 | // double[][] a = {
27 | // { 1, 2, 3, 1, 1, 2 },
28 | // { 3, 5, 2, 4, 4, 2 },
29 | // { 5, 4, 6, 2, 2, 2 },
30 | // };
31 | double[][] a = {
32 | { 1, 2, 3, 1},
33 | { 3, 5, 2, 1},
34 | { 2, 4, 6, 2},
35 | };
36 | Matrix m = new RectangularMatrix(a);
37 | BidiagonalDecomposition bd = new BidiagonalDecomposition(m);
38 | System.out.println(m);
39 | System.out.println(bd.getU());
40 | System.out.println(bd.getV());
41 | System.out.println(
42 | bd.getU().times(bd.getB()).times(bd.getV().transpose()));
43 | System.out.println(
44 | bd.getU().times(bd.getU().transpose()));
45 | System.out.println(
46 | bd.getV().times(bd.getV().transpose()));
47 | System.out.println(bd.getB());
48 | System.out.println(
49 | bd.getU().transpose().times(m).times(bd.getV()));
50 | }
51 |
52 |
53 | }
54 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/CholeskyFactorizationTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.CholeskyFactorization;
4 | import util.linalg.DenseVector;
5 | import util.linalg.Matrix;
6 | import util.linalg.RectangularMatrix;
7 | import util.linalg.Vector;
8 |
9 | /**
10 | * A test of the cholesky factorization
11 | * @author Andrew Guillory gtg008g@mail.gatech.edu
12 | * @version 1.0
13 | */
14 | public class CholeskyFactorizationTest {
15 |
16 |
17 | /**
18 | * The test main
19 | * @param args ingored
20 | */
21 | public static void main(String[] args) {
22 | double[][] a = {
23 | { 4, 3, 2, 1},
24 | { 3, 4, 3, 2},
25 | { 2, 3, 4, 3},
26 | { 1, 2, 3, 4}
27 | };
28 | Matrix m = new RectangularMatrix(a);
29 | CholeskyFactorization cf = new CholeskyFactorization(m);
30 | System.out.println(m);
31 | System.out.println(cf.getL());
32 | System.out.println(cf.getL().times(cf.getL().transpose()));
33 | System.out.println(cf.determinant());
34 | double[] b = {1, 0, 0, 0};
35 | Vector v = new DenseVector(b);
36 | Vector x = cf.solve(v);
37 | System.out.println(x);
38 | System.out.println(m.times(x));
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/EigenvalueDecompositionTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.RealSchurDecomposition;
4 | import util.linalg.Matrix;
5 | import util.linalg.RectangularMatrix;
6 |
7 | /**
8 | * A test of the eigenvalue decomposition
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class EigenvalueDecompositionTest {
13 | /**
14 | * The test main
15 | * @param args ingored
16 | */
17 | public static void main(String[] args) {
18 | // double[][] a = {
19 | // { 1, 2, 3, 1 },
20 | // { 3, 5, 2, 4 },
21 | // { 1, 5, 2, 2 },
22 | // { 0, 5, 2, 3 },
23 | // };
24 | double[][] a = {
25 | { 1, .79072 },
26 | { .79072, 1 }
27 | };
28 | // double[][] a = {
29 | // { 1, 2, 3 },
30 | // { 4, 5, 6 },
31 | // { 7, 8, 0 }
32 | // };
33 | // double[][] a = {
34 | // { 4, 3, 2, 1},
35 | // { 3, 4, 3, 2},
36 | // { 2, 3, 4, 3},
37 | // { 1, 2, 3, 4}
38 | // };
39 | Matrix m = new RectangularMatrix(a);
40 | RealSchurDecomposition ed = new RealSchurDecomposition(m);
41 | System.out.println(m);
42 | System.out.println(ed.getU());
43 | System.out.println(ed.getT());
44 | System.out.println(ed.getU().transpose().times(ed.getU()));
45 | System.out.println(
46 | ed.getU().times(ed.getT()).times(ed.getU().transpose()));
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/HessenbergDecompositionTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.HessenbergDecomposition;
4 | import util.linalg.Matrix;
5 | import util.linalg.RectangularMatrix;
6 |
7 | /**
8 | * A test for the hessenberg decomposition class
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class HessenbergDecompositionTest {
13 |
14 | /**
15 | * Test main, creates a matrix and decomposes and reconstructs
16 | * @param args ignored
17 | */
18 | public static void main(String[] args) {
19 | // double[][] a = {
20 | // { 1, 5, 7, 8 },
21 | // { 3, 0, 6, 8 },
22 | // { 4, 3, 1, 8 },
23 | // { 1, 2, 3, 4 }
24 | // };
25 | double[][] a = {
26 | { 4, 3, 2, 1},
27 | { 3, 4, 3, 2},
28 | { 2, 3, 4, 3},
29 | { 1, 2, 3, 4}
30 | };
31 |
32 | Matrix m = new RectangularMatrix(a);
33 | HessenbergDecomposition hd = new HessenbergDecomposition(m);
34 | System.out.println(m);
35 | System.out.println(hd.getU());
36 | System.out.println(hd.getH());
37 | System.out.println(
38 | hd.getU().times(hd.getH()).times(hd.getU().transpose()));
39 | System.out.println(
40 | hd.getU().times(hd.getU().transpose()));
41 | }
42 |
43 | }
44 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/HouseholderReflectionTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.HouseholderReflection;
4 | import util.linalg.Matrix;
5 | import util.linalg.RectangularMatrix;
6 | import util.linalg.Vector;
7 |
8 | /**
9 | * A test of householder reflections
10 | * @author Andrew Guillory gtg008g@mail.gatech.edu
11 | * @version 1.0
12 | */
13 | public class HouseholderReflectionTest {
14 |
15 | /**
16 | * The test main
17 | * @param args ignored
18 | */
19 | public static void main(String[] args) {
20 | double[][] a = {
21 | { 1, 2, 3 },
22 | { 3, 5, 2 },
23 | { 1, 5, 2 },
24 | { 1, 2, 1 }
25 | };
26 | Matrix m = new RectangularMatrix(a);
27 | Vector x = m.getRow(0);
28 | System.out.println(x);
29 | HouseholderReflection hr1 = new HouseholderReflection(x);
30 | hr1.applyRight(m, 0, m.m(), 0, m.n());
31 | System.out.println(m);
32 | x = m.getRow(1);
33 | x = x.get(1, x.size());
34 | System.out.println(x);
35 | HouseholderReflection hr2 = new HouseholderReflection(x);
36 | hr2.applyRight(m, 1, m.m(), 1, m.n());
37 | System.out.println(m);
38 | Matrix q = RectangularMatrix.eye(3);
39 | hr1.applyLeft(q, 0, q.m(), 0, q.n());
40 | hr2.applyLeft(q, 1, q.m(), 0, q.n());
41 | System.out.println(m.times(q));
42 | }
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/LUDecompositionTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.DenseVector;
4 | import util.linalg.LUDecomposition;
5 | import util.linalg.Matrix;
6 | import util.linalg.RectangularMatrix;
7 | import util.linalg.Vector;
8 |
9 | /**
10 | * A test of the LU decomposition
11 | * @author Andrew Guillory gtg008g@mail.gatech.edu
12 | * @version 1.0
13 | */
14 | public class LUDecompositionTest {
15 |
16 | /**
17 | * Test main, creates a matrix and decomposes and reconstructs
18 | * @param args ignored
19 | */
20 | public static void main(String[] args) {
21 | // double[][] a = {
22 | // {1, 2},
23 | // {3, 4},
24 | // {5, 6}
25 | // };
26 | // double[][] a = {
27 | // { 1, 2, 3},
28 | // { 4, 5, 6}
29 | // };
30 | double[][] a = {
31 | { 1, 2, 3 },
32 | { 4, 5, 6 },
33 | { 7, 8, 0 }
34 | };
35 | Matrix m = new RectangularMatrix(a);
36 | LUDecomposition lu = new LUDecomposition(m);
37 | System.out.println(m);
38 | System.out.println(lu.getL());
39 | System.out.println(lu.getU());
40 | System.out.println(lu.getL().times(lu.getU()));
41 | double[] b = {2, 4, 3};
42 | Vector v = new DenseVector(b);
43 | Vector x = lu.solve(v);
44 | System.out.println(x);
45 | System.out.println(m.times(x));
46 | System.out.println(lu.determinant());
47 | }
48 |
49 | }
50 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/LowerTriangularMatrixTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.LowerTriangularMatrix;
4 | import util.linalg.RectangularMatrix;
5 |
6 | /**
7 | * A test for lower triangular matrices
8 | * @author Andrew Guillory gtg008g@mail.gatech.edu
9 | * @version 1.0
10 | */
11 | public class LowerTriangularMatrixTest {
12 |
13 | /**
14 | * Test main,
15 | * @param args ignored
16 | */
17 | public static void main(String[] args) {
18 | double[][] a = {
19 | { 1, 0, 0, 0 },
20 | { 3, 5, 0, 0 },
21 | { 4, 3, 6, 0 },
22 | { 1, 2, 3, 4 }
23 | };
24 |
25 | LowerTriangularMatrix lm = new LowerTriangularMatrix(new RectangularMatrix(a));
26 | System.out.println(lm);
27 | System.out.println(lm.inverse());
28 | System.out.println(lm.inverse().times(lm));
29 | System.out.println(lm.times(lm.inverse()));
30 | }
31 |
32 | }
33 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/QRDecompositionTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.Matrix;
4 | import util.linalg.QRDecomposition;
5 | import util.linalg.RectangularMatrix;
6 |
7 | /**
8 | * A test of the QR Decomposition
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class QRDecompositionTest {
13 |
14 | /**
15 | * Test main, creates a matrix and decomposes and reconstructs
16 | * @param args ignored
17 | */
18 | public static void main(String[] args) {
19 | // double[][] a = {
20 | // { 1, 2, 3, 4 },
21 | // { 3, 5, 2, 5 },
22 | // { 1, 5, 2, 6 },
23 | // };
24 | // double[][] a = {
25 | // { 1, 2, 3},
26 | // { 3, 5, 2},
27 | // { 1, 5, 2},
28 | // { 6, 3, 2}
29 | // };
30 | double[][] a = {
31 | { 1, 2, 3 },
32 | { 3, 5, 2 },
33 | { 1, 5, 2 },
34 | };
35 | Matrix m = new RectangularMatrix(a);
36 | QRDecomposition qrd = new QRDecomposition(m);
37 | System.out.println(m);
38 | System.out.println(qrd.getQ());
39 | System.out.println(qrd.getR());
40 | System.out.println(qrd.getQ().times(qrd.getR()));
41 | System.out.println(qrd.getQ().times(qrd.getQ().transpose()));
42 | }
43 |
44 | }
45 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/SymmetricEigenvalueDecompositionTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.Matrix;
4 | import util.linalg.RectangularMatrix;
5 | import util.linalg.SymmetricEigenvalueDecomposition;
6 |
7 | /**
8 | * Test of the symmetric eigenvalue decomposition
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class SymmetricEigenvalueDecompositionTest {
13 |
14 | /**
15 | * The test main
16 | * @param args ingored
17 | */
18 | public static void main(String[] args) {
19 | // double[][] a = {
20 | // { 1, .79072 },
21 | // { .79072, 1 }
22 | // };
23 | double[][] a = {
24 | { 4, 3, 2, 1},
25 | { 3, 4, 3, 2},
26 | { 2, 3, 4, 3},
27 | { 1, 2, 3, 4}
28 | };
29 | Matrix m = new RectangularMatrix(a);
30 | SymmetricEigenvalueDecomposition ed =
31 | new SymmetricEigenvalueDecomposition(m);
32 | System.out.println(m);
33 | System.out.println(ed.getU());
34 | System.out.println(ed.getD());
35 | System.out.println(ed.getU().transpose().times(ed.getU()));
36 | System.out.println(
37 | ed.getU().times(ed.getD()).times(ed.getU().transpose()));
38 | }
39 |
40 | }
41 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/TridiagonalDecompositionTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.Matrix;
4 | import util.linalg.RectangularMatrix;
5 | import util.linalg.TridiagonalDecomposition;
6 |
7 | /**
8 | * A test of the tridiagonal decomposition
9 | * @author Andrew Guillory gtg008g@mail.gatech.edu
10 | * @version 1.0
11 | */
12 | public class TridiagonalDecompositionTest {
13 |
14 |
15 | /**
16 | * Test main, creates a matrix and decomposes and reconstructs
17 | * @param args ignored
18 | */
19 | public static void main(String[] args) {
20 | double[][] a = {
21 | { 4, 3, 2, 1},
22 | { 3, 4, 3, 2},
23 | { 2, 3, 4, 3},
24 | { 1, 2, 3, 4}
25 | };
26 | // double[][] a = {
27 | // { 1, 3, 4 },
28 | // { 3, 2, 8 },
29 | // { 4, 8, 3 }
30 | // };
31 | Matrix m = new RectangularMatrix(a);
32 | TridiagonalDecomposition td = new TridiagonalDecomposition(m);
33 | System.out.println(m);
34 | System.out.println(td.getU());
35 | System.out.println(td.getT());
36 | System.out.println(
37 | td.getU().times(td.getT()).times(td.getU().transpose()));
38 | System.out.println(
39 | td.getU().times(td.getU().transpose()));
40 | }
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/assignment2/ABAGAIL/src/util/test/UpperTriangularMatrixTest.java:
--------------------------------------------------------------------------------
1 | package util.test;
2 |
3 | import util.linalg.RectangularMatrix;
4 | import util.linalg.UpperTriangularMatrix;
5 |
6 | /**
7 | * A test for lower triangular matrices
8 | * @author Andrew Guillory gtg008g@mail.gatech.edu
9 | * @version 1.0
10 | */
11 | public class UpperTriangularMatrixTest {
12 |
13 | /**
14 | * Test main,
15 | * @param args ignored
16 | */
17 | public static void main(String[] args) {
18 | double[][] a = {
19 | { 1, 5, 3, 7 },
20 | { 0, 5, 1, 6 },
21 | { 0, 0, 6, 2 },
22 | { 0, 0, 0, 4 }
23 | };
24 |
25 | UpperTriangularMatrix um = new UpperTriangularMatrix(new RectangularMatrix(a));
26 | System.out.println(um);
27 | System.out.println(um.inverse());
28 | System.out.println(um.inverse().times(um));
29 | System.out.println(um.times(um.inverse()));
30 | System.out.println(um.inverse().transpose().times(um.transpose()));
31 | }
32 |
33 | }
34 |
--------------------------------------------------------------------------------
/assignment2/data:
--------------------------------------------------------------------------------
1 | ../data
--------------------------------------------------------------------------------
/assignment2/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy == 1.15.1
2 | scipy == 1.1.0
3 | scikit-learn == 0.20.0
4 | pandas == 0.23.4
5 | xlrd == 0.9.0
6 | matplotlib == 2.2.3
7 | seaborn == 0.9.0
8 | scikit-optimize == 0.5.2
9 | kneed == 0.1.0
10 |
--------------------------------------------------------------------------------
/assignment3/data:
--------------------------------------------------------------------------------
1 | ../data
--------------------------------------------------------------------------------
/assignment3/experiments/scoring.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics import make_scorer, accuracy_score, f1_score
2 | from sklearn.utils import compute_sample_weight
3 |
4 |
# Adapted from https://github.com/JonathanTay/CS-7641-assignment-1/blob/master/helpers.py
def balanced_accuracy(truth, pred):
    """Return accuracy with every sample weighted inversely to its class size.

    The weights come from ``compute_sample_weight('balanced', truth)``, so each
    class contributes equally to the score however many samples it has.
    """
    return accuracy_score(
        truth,
        pred,
        sample_weight=compute_sample_weight('balanced', truth),
    )
9 |
10 |
def f1_accuracy(truth, pred):
    """Return the binary F1 score with class-balanced sample weights.

    Samples are weighted by ``compute_sample_weight('balanced', truth)`` and
    the score uses ``average="binary"``.
    """
    balanced_weights = compute_sample_weight('balanced', truth)
    score = f1_score(
        truth,
        pred,
        average="binary",
        sample_weight=balanced_weights,
    )
    return score
14 |
15 |
# Scorer objects wrapping the metrics above, for use as ``scoring=`` in
# scikit-learn model-selection utilities.
scorer, f1_scorer = make_scorer(balanced_accuracy), make_scorer(f1_accuracy)
18 |
19 |
def get_scorer(dataset):
    """Return a ``(sklearn_scorer, metric_function)`` pair suited to *dataset*.

    A dataset whose ``balanced`` attribute is truthy is scored with balanced
    accuracy; otherwise the class-weighted F1 score is used.
    """
    if dataset.balanced:
        return scorer, balanced_accuracy
    return f1_scorer, f1_accuracy
25 |
--------------------------------------------------------------------------------
/assignment3/requirements-no-tables.txt:
--------------------------------------------------------------------------------
1 | numpy == 1.15.1
2 | scipy == 1.1.0
3 | scikit-learn == 0.20.0
4 | pandas == 0.23.4
5 | xlrd == 0.9.0
6 | matplotlib == 2.2.3
7 | seaborn == 0.9.0
8 | scikit-optimize == 0.5.2
9 | kneed == 0.1.0
10 |
--------------------------------------------------------------------------------
/assignment3/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy == 1.15.1
2 | scipy == 1.1.0
3 | scikit-learn == 0.20.0
4 | pandas == 0.23.4
5 | tables == 3.4.4
6 | xlrd == 0.9.0
7 | matplotlib == 2.2.3
8 | seaborn == 0.9.0
9 | scikit-optimize == 0.5.2
10 | kneed == 0.1.0
11 |
--------------------------------------------------------------------------------
/assignment3/run_clustering.sh:
--------------------------------------------------------------------------------
#!/bin/sh

# Replace each 'X' below with the optimal number of dimensions found for that
# algorithm/dataset pair.
# To generate the data and updated datasets first, remove the "--skiprerun"
# flag from the command inside run_clustering() below.

# run_clustering ALGORITHM DATASET DIM
# Runs one clustering experiment, sending stdout and stderr to
# ALGORITHM-DATASET-clustering.log.
run_clustering() {
    python run_experiment.py --"$1" --"$2" --dim "$3" --skiprerun --verbose --threads -1 > "$1-$2-clustering.log" 2>&1
}

run_clustering ica dataset1 X
run_clustering ica dataset2 X
run_clustering pca dataset1 X
run_clustering pca dataset2 X
run_clustering rp dataset1 X
run_clustering rp dataset2 X
run_clustering rf dataset1 X
run_clustering rf dataset2 X

#run_clustering svd dataset1 X
#run_clustering svd dataset2 X
17 |
--------------------------------------------------------------------------------
/assignment4/README.md:
--------------------------------------------------------------------------------
1 | # Markov Decision Processes
2 |
3 | ## Output
4 | Output CSVs and images are written to `./output` and `./output/images`, respectively. Sub-folders are created for
5 | each algorithm (policy iteration (PI), value iteration (VI), and Q-learning (Q)) as well as one for the final report data.
6 |
7 | If these folders do not exist, the experiments module will attempt to create them.
8 |
9 | Graphing:
10 | ---------
11 |
12 | The `run_experiment.py` script can be used to generate plots via:
13 |
14 | ```
15 | python run_experiment.py --plot
16 | ```
17 |
18 | Because the experiment output files follow a common naming scheme, the script infers the problem, algorithm,
19 | and parameters from the file names and writes its output to sub-folders in `./output/images` and `./output/report`.
20 |
21 |
--------------------------------------------------------------------------------
/assignment4/environments/__init__.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym.envs.registration import register
3 |
4 | from .cliff_walking import *
5 | from .frozen_lake import *
6 |
__all__ = ['RewardingFrozenLakeEnv', 'WindyCliffWalkingEnv']


def _register_custom_environments():
    """Register this package's environment ids with gym.

    This is invoked once, at module import time, so that the ids below can be
    passed to ``gym.make``.
    """
    frozen_lake_entry_point = 'environments:RewardingFrozenLakeEnv'
    # (gym id, constructor kwargs) for each RewardingFrozenLakeEnv variant,
    # in registration order.
    frozen_lake_variants = (
        ('RewardingFrozenLake-v0', {'map_name': '4x4'}),
        ('RewardingFrozenLake8x8-v0', {'map_name': '8x8'}),
        ('RewardingFrozenLakeNoRewards20x20-v0', {'map_name': '20x20', 'rewarding': False}),
        ('RewardingFrozenLakeNoRewards8x8-v0', {'map_name': '8x8', 'rewarding': False}),
    )
    for env_id, env_kwargs in frozen_lake_variants:
        register(id=env_id, entry_point=frozen_lake_entry_point, kwargs=env_kwargs)

    # The windy cliff-walking environment is registered without constructor kwargs.
    register(
        id='WindyCliffWalking-v0',
        entry_point='environments:WindyCliffWalkingEnv',
    )


_register_custom_environments()
37 |
38 |
def _make_env(env_id):
    """Return ``gym.make(env_id)``."""
    return gym.make(env_id)


def get_rewarding_frozen_lake_environment():
    """Return the 8x8 RewardingFrozenLake environment (``RewardingFrozenLake8x8-v0``)."""
    return _make_env('RewardingFrozenLake8x8-v0')


def get_frozen_lake_environment():
    """Return the ``FrozenLake-v0`` environment (an id this module does not register)."""
    return _make_env('FrozenLake-v0')


def get_rewarding_no_reward_frozen_lake_environment():
    """Return the 8x8 RewardingFrozenLake variant registered with ``rewarding=False``."""
    return _make_env('RewardingFrozenLakeNoRewards8x8-v0')


def get_large_rewarding_no_reward_frozen_lake_environment():
    """Return the 20x20 RewardingFrozenLake variant registered with ``rewarding=False``."""
    return _make_env('RewardingFrozenLakeNoRewards20x20-v0')


def get_cliff_walking_environment():
    """Return the ``CliffWalking-v0`` environment (an id this module does not register)."""
    return _make_env('CliffWalking-v0')


def get_windy_cliff_walking_environment():
    """Return the ``WindyCliffWalking-v0`` environment (backed by ``WindyCliffWalkingEnv``)."""
    return _make_env('WindyCliffWalking-v0')
61 |
--------------------------------------------------------------------------------
/assignment4/experiments/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .policy_iteration import *
3 | from .value_iteration import *
4 | from .q_learner import *
5 | # from .plotting import *
6 |
# NOTE(review): these entries are submodule names. `from experiments import *` therefore
# binds the submodules themselves, not the classes/functions re-exported by the star
# imports above, and `base` is not listed. Confirm this is intended.
__all__ = ['policy_iteration', 'value_iteration', 'q_learner']
8 |
--------------------------------------------------------------------------------
/assignment4/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy == 1.15.1
2 | scipy == 1.1.0
3 | scikit-learn == 0.19.2
4 | pandas == 0.23.4
5 | xlrd == 0.9.0
6 | matplotlib == 2.2.3
7 | seaborn == 0.9.0
8 | scikit-optimize == 0.5.2
9 | gym == 0.10.5
10 |
--------------------------------------------------------------------------------
/assignment4/solvers/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .base import *
3 |
4 | from .policy_iteration import *
5 | from .q_learning import *
6 | from .value_iteration import *
7 |
# NOTE(review): these entries are submodule names. `from solvers import *` therefore
# binds the submodules themselves, not the names re-exported by the star imports
# above. Confirm this is intended.
__all__ = ['base', 'policy_iteration', 'q_learning', 'value_iteration']
9 |
10 |
--------------------------------------------------------------------------------
/data/default of credit card clients.xls:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cmaron/CS-7641-assignments/ae1d14ffd7ab043ec412faf40aaebdda182f3201/data/default of credit card clients.xls
--------------------------------------------------------------------------------