├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── logback.xsd ├── pom.xml └── src ├── main ├── java │ ├── logback.xsd │ └── quickml │ │ ├── MathUtils.java │ │ ├── collections │ │ ├── MapUtils.java │ │ └── ValueSummingMap.java │ │ ├── data │ │ ├── AttributesMap.java │ │ ├── NegativeWeightsFilter.java │ │ ├── OnespotDateTimeExtractor.java │ │ ├── PredictionMap.java │ │ └── instances │ │ │ ├── ClassifierInstance.java │ │ │ ├── ClassifierInstanceFactory.java │ │ │ ├── Instance.java │ │ │ ├── InstanceFactory.java │ │ │ ├── InstanceImpl.java │ │ │ ├── InstanceWithAttributesMap.java │ │ │ ├── RegressionInstance.java │ │ │ ├── RidgeInstance.java │ │ │ ├── SparseClassifierInstanceFactory.java │ │ │ └── SparseRegressionInstance.java │ │ ├── experiments │ │ ├── GeoDistance.java │ │ ├── TrainingDataGenerator2.java │ │ └── kin88nm.java │ │ ├── supervised │ │ ├── EnhancedPredictiveModelBuilder.java │ │ ├── PredictiveModel.java │ │ ├── PredictiveModelBuilder.java │ │ ├── PredictiveModelsFromPreviousVersionsToBenchMarkAgainst │ │ │ ├── OldScorer.java │ │ │ ├── OldTree.java │ │ │ ├── OldTreeBuilder.java │ │ │ ├── oldScorers │ │ │ │ ├── GiniImpurityOldScorer.java │ │ │ │ ├── InformationGainOldScorer.java │ │ │ │ ├── MSEOldScorer.java │ │ │ │ ├── MSEOldScorerWithCrossValidationCorrection.java │ │ │ │ └── SplitDiffOldScorer.java │ │ │ └── oldTree │ │ │ │ ├── OldAttributeValueWithClassificationCounter.java │ │ │ │ ├── OldBranch.java │ │ │ │ ├── OldCategoricalOldBranch.java │ │ │ │ ├── OldClassificationCounter.java │ │ │ │ ├── OldLeaf.java │ │ │ │ ├── OldNode.java │ │ │ │ ├── OldNumericBranch.java │ │ │ │ └── oldAttributeIgnoringStrategies │ │ │ │ ├── AttributeIgnoringStrategy.java │ │ │ │ ├── AttributeName.java │ │ │ │ ├── AttributeNameAndParent.java │ │ │ │ ├── AttributeProperties.java │ │ │ │ ├── CompositeAttributeIgnoringStrategy.java │ │ │ │ ├── IgnoreAttributesInSet.java │ │ │ │ └── IgnoreAttributesWithConstantProbability.java │ │ ├── Utils.java │ │ ├── 
calibratedPredictiveModel │ │ │ └── CalibratedClassifier.java │ │ ├── classifier │ │ │ ├── AbstractClassifier.java │ │ │ ├── Classifier.java │ │ │ ├── Classifiers.java │ │ │ ├── downsampling │ │ │ │ ├── DownsamplingClassifier.java │ │ │ │ ├── DownsamplingClassifierBuilder.java │ │ │ │ ├── DownsamplingUtils.java │ │ │ │ ├── RandomDroppingInstanceFilter.java │ │ │ │ └── package-info.java │ │ │ ├── logisticRegression │ │ │ │ ├── DataTransformer.java │ │ │ │ ├── DatedAndMeanNormalizedLogisticRegressionDataTransformer.java │ │ │ │ ├── GradientDescent.java │ │ │ │ ├── InstanceTransformerUtils.java │ │ │ │ ├── LogisticRegression.java │ │ │ │ ├── LogisticRegressionBuilder.java │ │ │ │ ├── LogisticRegressionDTO.java │ │ │ │ ├── MeanNormalizedAndDatedLogisticRegressionDTO.java │ │ │ │ ├── SparseClassifierInstance.java │ │ │ │ ├── SparseSGD.java │ │ │ │ ├── StandardDataTransformer.java │ │ │ │ ├── TransformedData.java │ │ │ │ └── TransformedDataWithDates.java │ │ │ ├── splitOnAttribute │ │ │ │ ├── SplitOnAttributeClassifier.java │ │ │ │ ├── SplitOnAttributeClassifierBuilder.java │ │ │ │ └── SplitValTGroupIdMap.java │ │ │ ├── temporallyWeightClassifier │ │ │ │ ├── TemporallyReweightedClassifier.java │ │ │ │ └── TemporallyReweightedClassifierBuilder.java │ │ │ └── twoStageModel │ │ │ │ ├── TwoStageClassifier.java │ │ │ │ └── TwoStageModelBuilder.java │ │ ├── collaborativeFiltering │ │ │ ├── CollaborativeFilter.java │ │ │ ├── UserItem.java │ │ │ └── gradientDescent │ │ │ │ └── GradientDescentCF.java │ │ ├── crossValidation │ │ │ ├── ClassifierLossChecker.java │ │ │ ├── CrossValidator.java │ │ │ ├── EnhancedCrossValidator.java │ │ │ ├── InstanceTargetSelector.java │ │ │ ├── LossChecker.java │ │ │ ├── MultiTargetLossChecker.java │ │ │ ├── PredictionMapResult.java │ │ │ ├── PredictionMapResults.java │ │ │ ├── RegressionLossChecker.java │ │ │ ├── SimpleCrossValidator.java │ │ │ ├── SimpleCrossValidatorWithWriter.java │ │ │ ├── attributeImportance │ │ │ │ ├── 
AttributeImportanceFinder.java │ │ │ │ ├── AttributeImportanceFinderBuilder.java │ │ │ │ ├── AttributeLossSummary.java │ │ │ │ ├── AttributeLossTracker.java │ │ │ │ ├── AttributeWithLoss.java │ │ │ │ ├── LossFunctionTracker.java │ │ │ │ ├── RegAttributeImportanceFinder.java │ │ │ │ ├── RegAttributeImportanceFinderBuilder.java │ │ │ │ ├── RegAttributeLossSummary.java │ │ │ │ ├── RegAttributeLossTracker.java │ │ │ │ └── RegLossFunctionTracker.java │ │ │ ├── data │ │ │ │ ├── FoldedData.java │ │ │ │ ├── FoldedDataFactory.java │ │ │ │ ├── OutOfTimeData.java │ │ │ │ ├── OutOfTimeDataFactory.java │ │ │ │ ├── TrainingDataCycler.java │ │ │ │ └── TrainingDataCyclerFactory.java │ │ │ ├── genAttributeImportance │ │ │ │ ├── AttributeImportanceFinder.java │ │ │ │ ├── AttributeImportanceFinderBuilder.java │ │ │ │ ├── AttributeLossSummary.java │ │ │ │ ├── AttributeLossTracker.java │ │ │ │ ├── AttributeWithLoss.java │ │ │ │ └── LossFunctionTracker.java │ │ │ ├── lossfunctions │ │ │ │ ├── LabelPredictionWeight.java │ │ │ │ ├── LossFunction.java │ │ │ │ ├── LossFunctionCorrectedForDownsampling.java │ │ │ │ ├── LossFunctions.java │ │ │ │ ├── classifierLossFunctions │ │ │ │ │ ├── ClassifierLogCVLossFunction.java │ │ │ │ │ ├── ClassifierLossFunction.java │ │ │ │ │ ├── ClassifierMSELossFunction.java │ │ │ │ │ ├── ClassifierRMSELossFunction.java │ │ │ │ │ └── WeightedAUCCrossValLossFunction.java │ │ │ │ ├── rankingLossFunctions │ │ │ │ │ ├── NDCG.java │ │ │ │ │ └── RankingLossFunction.java │ │ │ │ └── regressionLossFunctions │ │ │ │ │ ├── RegressionLossFunction.java │ │ │ │ │ └── RegressionRMSELossFunction.java │ │ │ ├── movingAverages │ │ │ │ ├── ArithmeticAverage.java │ │ │ │ ├── HoltWintersMovingAverage.java │ │ │ │ └── MovingAverage.java │ │ │ └── utils │ │ │ │ ├── AttributesHashSplitter.java │ │ │ │ ├── DateTimeExtractor.java │ │ │ │ ├── MeanNormalizedDateTimeExtractor.java │ │ │ │ └── SimpleDateFormatExtractor.java │ │ ├── dataProcessing │ │ │ ├── AttributeCharacteristics.java │ │ │ 
├── BasicTrainingDataSurveyor.java │ │ │ ├── BinaryAttributeCharacteristics.java │ │ │ ├── ElementaryDataTransformer.java │ │ │ └── instanceTranformer │ │ │ │ ├── BinaryAndNumericAttributeNormalizer.java │ │ │ │ ├── ClassifierInstance2SparseClassifierInstance.java │ │ │ │ ├── CommonCoocurrenceProductFeatureAppender.java │ │ │ │ ├── InstanceTransformer.java │ │ │ │ ├── LabelToDigitConverter.java │ │ │ │ ├── MeanNormalizeAllNumericAttributes.java │ │ │ │ ├── OneHotEncoder.java │ │ │ │ └── ProductFeatureAppender.java │ │ ├── ensembles │ │ │ └── randomForest │ │ │ │ ├── RandomForest.java │ │ │ │ ├── RandomForestBuilder.java │ │ │ │ ├── randomDecisionForest │ │ │ │ ├── RandomDecisionForest.java │ │ │ │ └── RandomDecisionForestBuilder.java │ │ │ │ └── randomRegressionForest │ │ │ │ ├── RandomRegressionForest.java │ │ │ │ └── RandomRegressionForestBuilder.java │ │ ├── featureEngineering1 │ │ │ ├── AttributesEnrichStrategy.java │ │ │ ├── AttributesEnricher.java │ │ │ ├── FeatureEngineeredClassifier.java │ │ │ ├── FeatureEngineeringClassifierBuilder.java │ │ │ ├── InstanceEnricher.java │ │ │ └── enrichStrategies │ │ │ │ ├── attributeCombiner │ │ │ │ ├── AttributeCombiningEnrichStrategy.java │ │ │ │ └── AttributeCombiningEnricher.java │ │ │ │ └── probabilityInjector │ │ │ │ ├── ProbabilityEnrichStrategy.java │ │ │ │ └── ProbabilityInjectingEnricher.java │ │ ├── inspection │ │ │ ├── AttributeScore.java │ │ │ ├── CategoricalDistributionSampler.java │ │ │ ├── NumericDistributionSampler.java │ │ │ └── RandomForestDumper.java │ │ ├── parametricModels │ │ │ ├── LinearDerivative.java │ │ │ ├── LogisticDerivative.java │ │ │ ├── OptimizableCostFunction.java │ │ │ ├── OptimizableCostFunctionImp.java │ │ │ ├── ParallelizedLogisticDerivative.java │ │ │ └── SGD.java │ │ ├── predictiveModelOptimizer │ │ │ ├── ConfigWithLoss.java │ │ │ ├── FieldValueRecommender.java │ │ │ ├── MultiLossModelTester.java │ │ │ ├── PredictiveModelOptimizer.java │ │ │ ├── 
SimplePredictiveModelOptimizerBuilder.java │ │ │ └── fieldValueRecommenders │ │ │ │ ├── FixedOrderRecommender.java │ │ │ │ └── MonotonicConvergenceRecommender.java │ │ ├── rankingModels │ │ │ ├── ItemToOutcomeMap.java │ │ │ ├── LabelPredictionWeightForRanking.java │ │ │ ├── RankingInstance.java │ │ │ ├── RankingLossChecker.java │ │ │ ├── RankingModel.java │ │ │ ├── RankingPrediction.java │ │ │ └── Utils.java │ │ ├── regressionModel │ │ │ ├── IsotonicRegression │ │ │ │ └── PoolAdjacentViolatorsModel.java │ │ │ ├── LinearRegression │ │ │ │ ├── RidgeLinearModel.java │ │ │ │ └── RidgeLinearModelBuilder.java │ │ │ ├── LinearRegression2 │ │ │ │ ├── LinearModel.java │ │ │ │ ├── LinearRegressionDTO.java │ │ │ │ ├── MeanNormalizedAndDatedLinearRegressionDTO.java │ │ │ │ └── SimpleRidgeRegressionBuilder.java │ │ │ ├── MultiVariableRealValuedFunction.java │ │ │ └── SingleVariableRealValuedFunction.java │ │ └── tree │ │ │ ├── Tree.java │ │ │ ├── TreeBuilderHelper.java │ │ │ ├── attributeIgnoringStrategies │ │ │ ├── AttributeIgnoringStrategy.java │ │ │ ├── CompositeAttributeIgnoringStrategy.java │ │ │ ├── IgnoreAttributesInSet.java │ │ │ └── IgnoreAttributesWithConstantProbability.java │ │ │ ├── attributeValueIgnoringStrategies │ │ │ ├── AttributeValueIgnoringStrategy.java │ │ │ └── AttributeValueIgnoringStrategyBuilder.java │ │ │ ├── bagging │ │ │ ├── Bagging.java │ │ │ └── StationaryBagging.java │ │ │ ├── branchFinders │ │ │ ├── BranchFinder.java │ │ │ ├── BranchFinderAndReducerFactory.java │ │ │ ├── NumericBranchFinder.java │ │ │ ├── SortableLabelsCategoricalBranchFinder.java │ │ │ ├── SplittingUtils.java │ │ │ └── branchFinderBuilders │ │ │ │ ├── AlternativeSelction.java │ │ │ │ └── BranchFinderBuilder.java │ │ │ ├── branchingConditions │ │ │ ├── BranchingConditions.java │ │ │ └── StandardBranchingConditions.java │ │ │ ├── constants │ │ │ ├── AttributeType.java │ │ │ ├── BranchType.java │ │ │ ├── ForestOptions.java │ │ │ └── MissingValue.java │ │ │ ├── decisionTree │ │ │ 
├── DecisionTree.java │ │ │ ├── DecisionTreeBuilder.java │ │ │ ├── DecisionTreeBuilderHelper.java │ │ │ ├── DecisionTreeVisualizer.java │ │ │ ├── OptimizedDecisionForest.java │ │ │ ├── attributeValueIgnoringStrategies │ │ │ │ ├── BinaryClassAttributeValueIgnoringStrategy.java │ │ │ │ ├── BinaryClassAttributeValueIgnoringStrategyBuilder.java │ │ │ │ ├── MultiClassAtributeValueIgnoringStrategy.java │ │ │ │ └── MultiClassAttributeValueIgnoringStrategyBuilder.java │ │ │ ├── branchFinders │ │ │ │ ├── DTBinaryCatBranchFinder.java │ │ │ │ ├── DTNClassCatBranchFinder.java │ │ │ │ ├── DTNumBranchFinder.java │ │ │ │ ├── OldBinCatBranchFinder.java │ │ │ │ └── branchFinderBuilders │ │ │ │ │ ├── DTBinaryCatBranchFinderBuilder.java │ │ │ │ │ ├── DTBranchFinderBuilder.java │ │ │ │ │ ├── DTCatBranchFinderBuilder.java │ │ │ │ │ ├── DTNumBranchFinderBuilder.java │ │ │ │ │ └── OldBinaryCatBranchFinderBuilder.java │ │ │ ├── branchingConditions │ │ │ │ └── DTBranchingConditions.java │ │ │ ├── nodes │ │ │ │ ├── DTCatBranch.java │ │ │ │ ├── DTLeaf.java │ │ │ │ ├── DTLeafBuilder.java │ │ │ │ └── DTNumBranch.java │ │ │ ├── reducers │ │ │ │ ├── DTBinaryCatBranchReducer.java │ │ │ │ ├── DTCatBranchReducer.java │ │ │ │ ├── DTNumBranchReducer.java │ │ │ │ ├── DTOldCatBranchReducer.java │ │ │ │ ├── DTreeReducer.java │ │ │ │ └── reducerFactories │ │ │ │ │ ├── DTBinaryCatBranchReducerFactory.java │ │ │ │ │ ├── DTCatBranchReducerFactory.java │ │ │ │ │ ├── DTNumBranchReducerFactory.java │ │ │ │ │ └── DTOldCatBranchReducerFactory.java │ │ │ ├── scorers │ │ │ │ ├── GRPenalizedGiniImpurityScorer.java │ │ │ │ ├── GRPenalizedGiniImpurityScorerFactory.java │ │ │ │ ├── PenalizedGiniImpurityScorer.java │ │ │ │ ├── PenalizedGiniImpurityScorerFactory.java │ │ │ │ ├── PenalizedInformationGainScorer.java │ │ │ │ ├── PenalizedInformationGainScorerFactory.java │ │ │ │ ├── PenalizedMSEScorer.java │ │ │ │ ├── PenalizedMSEScorerFactory.java │ │ │ │ ├── PenalizedSplitDiffScorer.java │ │ │ │ └── 
PenalizedSplitDiffScorerFactory.java │ │ │ ├── treeBuildContexts │ │ │ │ ├── DTreeContext.java │ │ │ │ └── DTreeContextBuilder.java │ │ │ └── valueCounters │ │ │ │ ├── ClassificationCounter.java │ │ │ │ └── ClassificationCounterProducer.java │ │ │ ├── nodes │ │ │ ├── Branch.java │ │ │ ├── Leaf.java │ │ │ ├── LeafBuilder.java │ │ │ ├── LeafDepthStats.java │ │ │ ├── Node.java │ │ │ ├── NumBranch.java │ │ │ └── WeightAndMeanTracker.java │ │ │ ├── reducers │ │ │ ├── AttributeStatisticsProducer.java │ │ │ ├── AttributeStats.java │ │ │ ├── Reducer.java │ │ │ └── ReducerFactory.java │ │ │ ├── regressionTree │ │ │ ├── OptimizedRegressionForests.java │ │ │ ├── RegressionTree.java │ │ │ ├── RegressionTreeBuilder.java │ │ │ ├── RegressionTreeBuilderHelper.java │ │ │ ├── RegressionTreeVisualizer.java │ │ │ ├── attributeValueIgnoringStrategies │ │ │ │ ├── RegTreeAttributeValueIgnoringStrategy.java │ │ │ │ └── RegTreeAttributeValueIgnoringStrategyBuilder.java │ │ │ ├── branchFinders │ │ │ │ ├── RTCatBranchFinder.java │ │ │ │ ├── RTNumBranchFinder.java │ │ │ │ └── branchFinderBuilders │ │ │ │ │ ├── RTBranchFinderBuilder.java │ │ │ │ │ ├── RTCatBranchFinderBuilder.java │ │ │ │ │ └── RTNumBranchFinderBuilder.java │ │ │ ├── branchingConditions │ │ │ │ └── RTBranchingConditions.java │ │ │ ├── nodes │ │ │ │ ├── RTCatBranch.java │ │ │ │ ├── RTLeaf.java │ │ │ │ ├── RTLeafBuilder.java │ │ │ │ └── RTNumBranch.java │ │ │ ├── reducers │ │ │ │ ├── RTCatBranchReducer.java │ │ │ │ ├── RTNumBranchReducer.java │ │ │ │ ├── RTreeReducer.java │ │ │ │ └── reducerFactories │ │ │ │ │ ├── RTCatBranchReducerFactory.java │ │ │ │ │ └── RTNumBranchReducerFactory.java │ │ │ ├── scorers │ │ │ │ ├── PenalizedMSEScorer.java │ │ │ │ └── RTPenalizedMSEScorerFactory.java │ │ │ ├── treeBuildContexts │ │ │ │ ├── RTreeContext.java │ │ │ │ └── RTreeContextBuilder.java │ │ │ └── valueCounters │ │ │ │ ├── MeanValueCounter.java │ │ │ │ └── MeanValueCounterProducer.java │ │ │ ├── scorers │ │ │ ├── GRImbalancedScorer.java 
│ │ │ ├── GRImbalancedScorerFactory.java │ │ │ ├── GRScorer.java │ │ │ ├── GRScorerFactory.java │ │ │ ├── Scorer.java │ │ │ └── ScorerFactory.java │ │ │ ├── summaryStatistics │ │ │ ├── ValueCounter.java │ │ │ ├── ValueCounterProducer.java │ │ │ ├── ValueStatistics.java │ │ │ └── ValueStatisticsOperations.java │ │ │ └── treeBuildContexts │ │ │ ├── TreeContext.java │ │ │ └── TreeContextBuilder.java │ │ ├── unsupervised │ │ └── clustering │ │ │ └── Clusterer.java │ │ └── utlities │ │ ├── CSVToInstanceReader.java │ │ ├── CSVToInstanceReaderBuilder.java │ │ ├── CSVToMapOfNumericLists.java │ │ ├── InstancesToCsvWriter.java │ │ ├── LibSVMFormatReader.java │ │ ├── LinePlotter.java │ │ ├── LinePlotterBuilder.java │ │ ├── SerializationUtility.java │ │ └── selectors │ │ ├── CSVToMapOfObjectLists.java │ │ ├── CategoricalSelector.java │ │ ├── ExplicitCategoricalSelector.java │ │ ├── ExplicitNumericSelector.java │ │ └── NumericSelector.java └── resources │ └── logback.xml └── test ├── java └── quickml │ ├── BenchmarkTest.java │ ├── InstanceLoader.java │ ├── InstanceLoaderTest.java │ ├── MapUtilsTest.java │ ├── TestUtils.java │ ├── TrainingInstance.java │ ├── collections │ └── ValueSummingMapTest.java │ └── supervised │ ├── JsonInstanceLoader.java │ ├── OldTreeBuildTimeTest.java │ ├── PredictiveAccuracyTests.java │ ├── UtilsTest.java │ ├── classifier │ ├── ClassifiersTest.java │ ├── TreeBuilderTestUtils.java │ ├── logRegression │ │ ├── InstanceTransformerUtilsTest.java │ │ ├── RidgeRegressionBuilderTest.java │ │ └── SparseSGDTest.java │ ├── randomForest │ │ └── TestIrisAccuracy.java │ ├── splitOnAttribute │ │ └── SplitOnAttributeClassifierBuilderTest.java │ └── temporallyWeightClassifier │ │ └── TemporallyReweightedClassifierBuilderTest.java │ ├── crossValidation │ ├── InterfacesCompilationTest.java │ ├── PredictionMapResultsTest.java │ ├── SimpleCrossValidatorIntegrationTest.java │ ├── attributeImportance │ │ ├── AttributeImportanceFinderIntegrationTest.java │ │ ├── 
AttributeImportanceFinderIntegrationTestOld.java │ │ ├── AttributeLossSummaryTest.java │ │ └── LossFunctionTrackerTest.java │ ├── data │ │ ├── FoldedDataTest.java │ │ └── OutOfTimeDataTest.java │ └── lossfunctions │ │ ├── ClassifierMSELossFunctionTest.java │ │ ├── LossFunctionsTest.java │ │ ├── WeightedAUCCrossValLossFunctionTest.java │ │ └── rankingLossFunctions │ │ └── NDCGTest.java │ ├── dataProcessing │ └── instanceTranformer │ │ ├── CommonCoocurrenceProductFeatureAppenderTest.java │ │ ├── OneHotEncoderTest.java │ │ └── ProductFeatureAppenderTest.java │ ├── downsampling │ ├── DownsamplingClassifierBuilderTest.java │ └── DownsamplingPredictiveModelTest.java │ ├── featureEngineering │ ├── AttributeCombiningEnricherTest.java │ ├── FeatureEngineeringClassifierBuilderTest.java │ ├── ProbabilityEnrichStrategyTest.java │ └── ProbabilityInjectingEnricherTest.java │ ├── inspection │ ├── CategoricalDistributionSamplerTest.java │ └── NumericDistributionSamplerTest.java │ ├── predictiveModelOptimizer │ ├── PredictiveModelOptimizerIntegrationTest.java │ ├── PredictiveModelOptimizerTest.java │ └── fieldValueRecommenders │ │ └── MonotonicConvergenceRecommenderTest.java │ ├── regressionModel │ ├── PoolAdjacentViolatorsModelTest.java │ └── RidgeRegressionBuilderTest.java │ └── tree │ ├── branchFinders │ └── SplittingUtilsTest.java │ ├── decisionTree │ ├── DecisionOldOldTreeBuilderTest.java │ ├── OldClassificationCounterTest.java │ ├── attributeIgnoringStrategies │ │ └── AttributeIgnoringStrategiesTests.java │ └── reducers │ │ ├── BinaryCatOldBranchReducerTest.java │ │ ├── DTCatOldBranchReducerTest.java │ │ └── DTNumOldBranchReducerTest.java │ ├── nodes │ └── OldLeafDepthStatsTest.java │ └── scorers │ ├── GiniImpurityScorerTest.java │ ├── PenalizedInformationGainScorerTest.java │ └── PenalizedMSEScorerTest.java └── resources └── quickml ├── advertisingData.csv.gz ├── diabetesDataset.txt.gz ├── iris.data.gz └── mobo1.json.gz /.gitignore: 
-------------------------------------------------------------------------------- 1 | target/ 2 | test-output/ 3 | *.iml 4 | .project 5 | .classpath 6 | testdata/ 7 | /.idea/ 8 | .settings/ 9 | local-* 10 | logback.xsd 11 | src/test/resources/onespot_training_instances_small.json 12 | src/test/resources/onespot_training_instances_large.json 13 | 14 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | sudo: false 3 | after_success: 4 | - mvn clean cobertura:cobertura coveralls:report 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is the source repository for the QuickML Java machine learning library. Please visit [quickml.org](http://quickml.org/) for more information. 2 | -------------------------------------------------------------------------------- /src/main/java/quickml/MathUtils.java: -------------------------------------------------------------------------------- 1 | package quickml; 2 | 3 | /** 4 | * Created by alexanderhawk on 10/12/15. 
5 | */ 6 | public class MathUtils { 7 | public static double sigmoid(double z) { 8 | return 1 / (1 + Math.exp(-z)); 9 | } 10 | 11 | public static double cappedlogBase2(double z, double minZ) { 12 | double x = Math.max(z, minZ); 13 | return Math.log(x)/Math.log(2); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/quickml/collections/MapUtils.java: -------------------------------------------------------------------------------- 1 | package quickml.collections; 2 | 3 | import com.google.common.base.Optional; 4 | import java.util.Map; 5 | import java.util.Random; 6 | 7 | public class MapUtils { 8 | public static final Random random = new Random(); 9 | 10 | public static final > Optional> getEntryWithLowestValue(Map map) { 11 | Optional> entryWithLowestValue = Optional.absent(); 12 | for (Map.Entry kvEntry : map.entrySet()) { 13 | if (!entryWithLowestValue.isPresent() || entryWithLowestValue.get().getValue().compareTo(kvEntry.getValue()) >= 0){ 14 | entryWithLowestValue = Optional.of(kvEntry); 15 | } 16 | } 17 | return entryWithLowestValue; 18 | } 19 | 20 | 21 | public static final > Optional> getEntryWithHighestValue(Map map) { 22 | Optional> entryWithHighestValue = Optional.absent(); 23 | for (Map.Entry kvEntry : map.entrySet()) { 24 | if (!entryWithHighestValue.isPresent() || entryWithHighestValue.get().getValue().compareTo(kvEntry.getValue()) <= 0){ 25 | entryWithHighestValue = Optional.of(kvEntry); 26 | } 27 | } 28 | return entryWithHighestValue; 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/NegativeWeightsFilter.java: -------------------------------------------------------------------------------- 1 | package quickml.data; 2 | 3 | import com.google.common.base.Predicate; 4 | import com.google.common.collect.Iterables; 5 | import quickml.data.instances.Instance; 6 | 7 | import java.util.HashSet; 8 | 9 | /** 10 | * 
Created by alexanderhawk on 5/1/14. 11 | */ 12 | public class NegativeWeightsFilter { 13 | 14 | //TODO[mk] is this being used? 15 | //parametrize training data or subtype it to have right params 16 | public static Iterable> filterNegativeWeights(Iterable> trainingData) { 17 | final HashSet instanceLookUp = new HashSet(); 18 | for (Instance instance : trainingData) 19 | if (instance.getWeight() < 0) 20 | instanceLookUp.add(instance.getAttributes()); 21 | 22 | Predicate> predicate = new Predicate>() { 23 | @Override 24 | public boolean apply(final Instance instance) { 25 | if (instanceLookUp.contains(instance.getAttributes())) 26 | return false; 27 | else 28 | return true; 29 | } 30 | }; 31 | return Iterables.filter(trainingData, predicate); 32 | } 33 | } 34 | 35 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/OnespotDateTimeExtractor.java: -------------------------------------------------------------------------------- 1 | package quickml.data; 2 | 3 | 4 | import org.joda.time.DateTime; 5 | import quickml.data.instances.ClassifierInstance; 6 | import quickml.data.instances.InstanceWithAttributesMap; 7 | import quickml.supervised.crossValidation.utils.DateTimeExtractor; 8 | 9 | 10 | public class OnespotDateTimeExtractor implements DateTimeExtractor { 11 | 12 | @Override 13 | public DateTime extractDateTime(T instance) { 14 | int year = attrVal(instance, "timeOfArrival-year"); 15 | int month = attrVal(instance,"timeOfArrival-monthOfYear"); 16 | int day = attrVal(instance,"timeOfArrival-dayOfMonth"); 17 | int hour = attrVal(instance, "timeOfArrival-hourOfDay"); 18 | int minute = attrVal(instance, "timeOfArrival-minuteOfHour"); 19 | return new DateTime(year, month, day, hour, minute, 0, 0); 20 | } 21 | 22 | private int attrVal(T instance, String attrName) { 23 | return instance.getAttributes().containsKey(attrName) ? 
24 | ((Number) instance.getAttributes().get(attrName)).intValue() : 1 ; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/instances/ClassifierInstance.java: -------------------------------------------------------------------------------- 1 | package quickml.data.instances; 2 | 3 | import org.joda.time.DateTime; 4 | import quickml.data.AttributesMap; 5 | 6 | import java.io.Serializable; 7 | 8 | /** 9 | * Created by alexanderhawk on 4/14/15. 10 | */ 11 | public class ClassifierInstance extends InstanceWithAttributesMap { 12 | public DateTime timeStamp; 13 | public ClassifierInstance(AttributesMap attributes, Serializable label) { 14 | super(attributes, label, 1.0); 15 | } 16 | public ClassifierInstance(AttributesMap attributes, Serializable label, double weight) { 17 | super(attributes, label, weight); 18 | } 19 | public ClassifierInstance(AttributesMap attributes, Serializable label, DateTime timeStamp) { 20 | super(attributes, label, 1.0); 21 | this.timeStamp = timeStamp; 22 | } 23 | 24 | } 25 | 26 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/instances/ClassifierInstanceFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.data.instances; 2 | 3 | import quickml.data.AttributesMap; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * Created by alexanderhawk on 10/14/15. 
9 | */ 10 | public class ClassifierInstanceFactory implements InstanceFactory { 11 | @Override 12 | public ClassifierInstance createInstance(AttributesMap attributes, Serializable label, double weight) { 13 | return new ClassifierInstance(attributes, label, weight); 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/instances/Instance.java: -------------------------------------------------------------------------------- 1 | package quickml.data.instances; 2 | 3 | 4 | import org.joda.time.DateTime; 5 | 6 | import java.io.Serializable; 7 | 8 | /** 9 | * Created with IntelliJ IDEA. 10 | * User: ian 11 | * Date: 6/27/13 12 | * Time: 1:22 PM 13 | * To change this template use File | Settings | File Templates. 14 | */ 15 | public interface Instance { 16 | 17 | public A getAttributes(); // TODO rename to getInput 18 | 19 | public L getLabel(); // TODO rename to getOutput 20 | 21 | public double getWeight(); 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/instances/InstanceFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.data.instances; 2 | 3 | /** 4 | * Created by alexanderhawk on 10/14/15. 5 | */ 6 | public interface InstanceFactory { 7 | I createInstance(A attributes, L label, double weight); 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/instances/RegressionInstance.java: -------------------------------------------------------------------------------- 1 | package quickml.data.instances; 2 | 3 | import quickml.data.AttributesMap; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * Created by alexanderhawk on 4/14/15. 
9 | */ 10 | public class RegressionInstance extends InstanceWithAttributesMap { 11 | public RegressionInstance(AttributesMap attributes, Double label) { 12 | super(attributes, label, 1.0); 13 | } 14 | public RegressionInstance(AttributesMap attributes, Double label, double weight) { 15 | super(attributes, label, weight); 16 | } 17 | public RegressionInstance(AttributesMap attributes, Double label, double weight, double alternativeTarget) { 18 | super(attributes, label, weight); 19 | this.alternativeTarget = alternativeTarget; 20 | } 21 | public double alternativeTarget; 22 | public long id; 23 | 24 | } 25 | 26 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/instances/RidgeInstance.java: -------------------------------------------------------------------------------- 1 | package quickml.data.instances; 2 | 3 | import quickml.data.instances.Instance; 4 | 5 | import java.io.Serializable; 6 | 7 | public class RidgeInstance implements Instance{ 8 | 9 | private double[] attributes; 10 | private Serializable label; 11 | 12 | public RidgeInstance(double[] attributes, Serializable label) { 13 | this.attributes = attributes; 14 | this.label = label; 15 | } 16 | 17 | @Override 18 | public double[] getAttributes() { 19 | return attributes; 20 | } 21 | 22 | @Override 23 | public Serializable getLabel() { 24 | return label; 25 | } 26 | 27 | @Override 28 | public double getWeight() { 29 | return 1.0; 30 | } 31 | 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/quickml/data/instances/SparseClassifierInstanceFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.data.instances; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.supervised.classifier.logisticRegression.SparseClassifierInstance; 5 | 6 | import java.io.Serializable; 7 | import java.util.Map; 8 | 9 | /** 10 | * Created by 
alexanderhawk on 10/14/15. 11 | */ 12 | public class SparseClassifierInstanceFactory implements InstanceFactory { 13 | private Map nameToIndexMap; 14 | 15 | public SparseClassifierInstanceFactory(Map nameToIndexMap) { 16 | this.nameToIndexMap = nameToIndexMap; 17 | } 18 | 19 | 20 | @Override 21 | public SparseClassifierInstance createInstance(AttributesMap attributes, Serializable label, double weight) { 22 | return new SparseClassifierInstance(attributes, label, weight, nameToIndexMap); 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/EnhancedPredictiveModelBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised; 2 | 3 | import quickml.data.instances.Instance; 4 | import quickml.supervised.PredictiveModel; 5 | import quickml.supervised.classifier.logisticRegression.DataTransformer; 6 | import quickml.supervised.classifier.logisticRegression.TransformedData; 7 | 8 | import java.io.Serializable; 9 | import java.util.List; 10 | import java.util.Map; 11 | 12 | /** 13 | * Created by alexanderhawk on 10/30/15. 14 | */ 15 | public interface EnhancedPredictiveModelBuilder

> 16 | extends DataTransformer { 17 | 18 | P buildPredictiveModel(D transformedData); 19 | void updateBuilderConfig(final Map config); 20 | 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModel.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised; 2 | 3 | import java.io.Serializable; 4 | import java.util.Set; 5 | 6 | /** 7 | * A predictive model, typically created by a supervised learning algorithm. 8 | * Given a set of attributes, will generate a prediction. 9 | */ 10 | public interface PredictiveModel extends Serializable { 11 | 12 | P predict(A attributes); 13 | P predictWithoutAttributes(A attributes, Set attributesToIgnore); 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised; 2 | 3 | import quickml.data.instances.Instance; 4 | 5 | import java.io.Serializable; 6 | import java.util.Map; 7 | 8 | /** 9 | * A supervised learning algorithm, which, given data, will generate a PredictiveModel. 
10 | */ 11 | public interface PredictiveModelBuilder { 12 | 13 | public PM buildPredictiveModel(Iterable trainingData); 14 | 15 | public void updateBuilderConfig(Map config); 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/OldScorer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst; 2 | 3 | 4 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldClassificationCounter; 5 | 6 | /** 7 | * The scorer is responsible for assessing the quality of a "split" of data. 8 | */ 9 | public interface OldScorer { 10 | /** 11 | * Assess the quality of a separation of data 12 | * 13 | * @param a 14 | * A count of the number of classifications with a given 15 | * getBestClassification in split a 16 | * @param b 17 | * A count of the number of classifications with a given 18 | * getBestClassification in split b 19 | * @return A score, where a higher value indicates a better split. A value 20 | * of 0 being the lowest, and indicating no value. 
21 | */ 22 | public double scoreSplit(OldClassificationCounter a, OldClassificationCounter b); 23 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldScorers/GiniImpurityOldScorer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldScorers; 2 | 3 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldScorer; 4 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldClassificationCounter; 5 | 6 | import java.io.Serializable; 7 | import java.util.Map; 8 | 9 | /** 10 | * Created by chrisreeves on 6/24/14. 11 | */ 12 | public class GiniImpurityOldScorer implements OldScorer { 13 | @Override 14 | public double scoreSplit(OldClassificationCounter a, OldClassificationCounter b) { 15 | OldClassificationCounter parent = OldClassificationCounter.merge(a, b); 16 | double parentGiniIndex = getGiniIndex(parent); 17 | double aGiniIndex = getGiniIndex(a) * a.getTotal() / parent.getTotal() ; 18 | double bGiniIndex = getGiniIndex(b) * b.getTotal() / parent.getTotal(); 19 | return parentGiniIndex - aGiniIndex - bGiniIndex; 20 | } 21 | 22 | private double getGiniIndex(OldClassificationCounter cc) { 23 | double sum = 0.0d; 24 | for (Map.Entry e : cc.getCounts().entrySet()) { 25 | double error = (cc.getTotal() > 0) ? 
e.getValue() / cc.getTotal() : 0; 26 | sum += error * error; 27 | } 28 | return 1.0d - sum; 29 | } 30 | 31 | @Override 32 | public String toString() { 33 | return "GiniImpurity"; 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldScorers/InformationGainOldScorer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldScorers; 2 | 3 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldScorer; 4 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldClassificationCounter; 5 | 6 | import java.io.Serializable; 7 | import java.util.Map; 8 | 9 | /** 10 | * Created by chrisreeves on 6/24/14. 11 | */ 12 | public class InformationGainOldScorer implements OldScorer { 13 | 14 | @Override 15 | public double scoreSplit(OldClassificationCounter a, OldClassificationCounter b) { 16 | double parentEntropy = calculateEntropy(OldClassificationCounter.merge(a, b)); 17 | double aEntropy = calculateEntropy(a); 18 | double bEntropy = calculateEntropy(b); 19 | return calculateGain(parentEntropy, aEntropy, bEntropy, a.getTotal(), b.getTotal()); 20 | } 21 | 22 | private double calculateEntropy(OldClassificationCounter cc) { 23 | double entropy = 0; 24 | 25 | for (Map.Entry e : cc.getCounts().entrySet()) { 26 | double error = (cc.getTotal() > 0) ? 
e.getValue() / cc.getTotal() : 0; 27 | entropy += -error * (Math.log(error) / Math.log(2)); 28 | } 29 | 30 | return entropy; 31 | } 32 | 33 | private double calculateGain(double rootEntropy, double aEntropy, double bEntropy, double aSize, double bSize) { 34 | double aAdjustedEntropy = (aSize / (aSize+bSize)) * aEntropy; 35 | double bAdjustedEntropy = (bSize / (aSize+bSize)) * bEntropy; 36 | return rootEntropy - aAdjustedEntropy - bAdjustedEntropy; 37 | } 38 | 39 | @Override 40 | public String toString() { 41 | return "InformationGain"; 42 | } 43 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldScorers/SplitDiffOldScorer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldScorers; 2 | 3 | import com.google.common.collect.Sets; 4 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.OldScorer; 5 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldClassificationCounter; 6 | 7 | import java.io.Serializable; 8 | 9 | public final class SplitDiffOldScorer implements OldScorer { 10 | 11 | /* 12 | * The general idea here is that a good split is one where the proportions 13 | * of classifications on each side of the split are as different as 14 | * possible. eg. if 50% of the classifications in set A are "dog", then the 15 | * further away from 50% the proportion of "dog" classifications in set B 16 | * are, the better. 17 | * 18 | * We therefore add up the differences between the proportions, however we 19 | * have another goal, which is that its preferable for the sets to be of 20 | * close to equal size. Without this requirement a split with 0 on one size 21 | * would get a high score because all of the proportions on that side would 22 | * be 0. 
23 | * 24 | * So, we multiply the score by the size of the smallest side, which 25 | * experimentally seems to provide an adequate bias against one-sided 26 | * splits. 27 | */ 28 | 29 | @Override 30 | public double scoreSplit(final OldClassificationCounter a, final OldClassificationCounter b) { 31 | double score = 0; 32 | for (final Serializable value : Sets.union(a.allClassifications(), b.allClassifications())) { 33 | final double aProp = (double) a.getCount(value) / a.getTotal(); 34 | final double bProp = (double) b.getCount(value) / b.getTotal(); 35 | 36 | score += Math.abs(aProp - bProp) * Math.min(a.getTotal(), b.getTotal()); 37 | } 38 | return score; 39 | } 40 | 41 | public String toString() { 42 | return "SplitDiffScorer"; 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldTree/OldAttributeValueWithClassificationCounter.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * Created by alexanderhawk on 7/1/14. 
7 | */ 8 | public class OldAttributeValueWithClassificationCounter { 9 | public Serializable attributeValue; 10 | public OldClassificationCounter classificationCounter; 11 | public OldAttributeValueWithClassificationCounter(Serializable attributeValue, OldClassificationCounter classificationCounter) { 12 | this.attributeValue = attributeValue; 13 | this.classificationCounter = classificationCounter; 14 | } 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldTree/OldCategoricalOldBranch.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree; 2 | 3 | import com.google.common.collect.Sets; 4 | 5 | import java.io.Serializable; 6 | import java.util.Map; 7 | import java.util.Set; 8 | 9 | public final class OldCategoricalOldBranch extends OldBranch { 10 | private static final long serialVersionUID = -1723969623146234761L; 11 | public final Set inSet; 12 | 13 | public OldCategoricalOldBranch(OldNode parent, final String attribute, final Set inSet, double probabilityOfTrueChild) { 14 | super(parent, attribute, probabilityOfTrueChild); 15 | this.inSet = Sets.newHashSet(inSet); 16 | 17 | } 18 | 19 | @Override 20 | public boolean decide(final Map attributes) { 21 | Serializable attributeVal = attributes.get(attribute); 22 | //missing values always go the way of the outset...which strangely seems to be most accurate 23 | return inSet.contains(attributeVal); 24 | } 25 | 26 | @Override 27 | public String toString() { 28 | return attribute + " in " + inSet; 29 | } 30 | 31 | @Override 32 | public String toNotString() { 33 | return attribute + " not in " + inSet; 34 | } 35 | 36 | @Override 37 | public boolean equals(final Object o) { 38 | if (this == o) return true; 39 | if (o == null || getClass() != o.getClass()) return false; 
40 | if (!super.equals(o)) return false; 41 | 42 | final OldCategoricalOldBranch that = (OldCategoricalOldBranch) o; 43 | 44 | if (!inSet.equals(that.inSet)) return false; 45 | 46 | return true; 47 | } 48 | 49 | @Override 50 | public int hashCode() { 51 | int result = super.hashCode(); 52 | result = 31 * result + inSet.hashCode(); 53 | return result; 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldTree/oldAttributeIgnoringStrategies/AttributeIgnoringStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies; 2 | 3 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldBranch; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * Created by alexanderhawk on 2/28/15. 9 | */ 10 | public interface AttributeIgnoringStrategy extends Serializable { 11 | 12 | /** 13 | * Should this attribute be ignored 14 | * @param attribute 15 | * @param parent 16 | * @return 17 | */ 18 | boolean ignoreAttribute(String attribute, OldBranch parent); 19 | 20 | /** 21 | * @return a copy of this AttributeIgnoringStrategy 22 | */ 23 | AttributeIgnoringStrategy copy(); 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldTree/oldAttributeIgnoringStrategies/AttributeName.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies; 2 | 3 | /** 4 | * Created by alexanderhawk on 3/2/15. 
5 | */ 6 | public class AttributeName { 7 | public final String attribute; 8 | 9 | public AttributeName(String attribute) { 10 | this.attribute = attribute; 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldTree/oldAttributeIgnoringStrategies/AttributeNameAndParent.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies; 2 | 3 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldBranch; 4 | 5 | /** 6 | * Created by alexanderhawk on 3/2/15. 7 | */ 8 | public class AttributeNameAndParent { 9 | public final String attribute; 10 | public final OldBranch oldBranch; 11 | 12 | public AttributeNameAndParent(String attribute, OldBranch oldBranch) { 13 | this.attribute = attribute; 14 | this.oldBranch = oldBranch; 15 | } 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldTree/oldAttributeIgnoringStrategies/AttributeProperties.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies; 2 | 3 | /** 4 | * Created by alexanderhawk on 3/2/15. 
5 | */ 6 | public class AttributeProperties { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldTree/oldAttributeIgnoringStrategies/CompositeAttributeIgnoringStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies; 2 | 3 | import com.google.common.collect.Lists; 4 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldBranch; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * Created by alexanderhawk on 2/28/15. 10 | */ 11 | public class CompositeAttributeIgnoringStrategy implements AttributeIgnoringStrategy { 12 | private List attributeIgnoringStrategies = Lists.newArrayList(); 13 | 14 | public CompositeAttributeIgnoringStrategy(List attributeIgnoringStrategies) { 15 | this.attributeIgnoringStrategies = attributeIgnoringStrategies; 16 | } 17 | 18 | @Override 19 | public CompositeAttributeIgnoringStrategy copy() { 20 | List copies = Lists.newArrayList(); 21 | for (AttributeIgnoringStrategy attributeIgnoringStrategy : attributeIgnoringStrategies) { 22 | copies.add(attributeIgnoringStrategy.copy()); 23 | } 24 | return new CompositeAttributeIgnoringStrategy(copies); 25 | } 26 | 27 | @Override 28 | public boolean ignoreAttribute(String attribute, OldBranch parent) { 29 | for (AttributeIgnoringStrategy attributeIgnoringStrategy : attributeIgnoringStrategies) { 30 | if (attributeIgnoringStrategy.ignoreAttribute(attribute, parent)) { 31 | return true; 32 | } 33 | } 34 | return false; 35 | } 36 | 37 | @Override 38 | public String toString() { 39 | return "CompositeAttributeIgnoringStrategy{" + 40 | "oldAttributeIgnoringStrategies=" + attributeIgnoringStrategies + 41 | '}'; 42 | } 43 | } 44 | 
-------------------------------------------------------------------------------- /src/main/java/quickml/supervised/PredictiveModelsFromPreviousVersionsToBenchMarkAgainst/oldTree/oldAttributeIgnoringStrategies/IgnoreAttributesWithConstantProbability.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.oldAttributeIgnoringStrategies; 2 | 3 | import quickml.supervised.PredictiveModelsFromPreviousVersionsToBenchMarkAgainst.oldTree.OldBranch; 4 | 5 | import java.util.concurrent.ThreadLocalRandom; 6 | 7 | /** 8 | * Created by alexanderhawk on 2/28/15. 9 | */ 10 | public class IgnoreAttributesWithConstantProbability implements AttributeIgnoringStrategy { 11 | 12 | private final double ignoreAttributeProbability; 13 | private ThreadLocalRandom random = ThreadLocalRandom.current(); 14 | 15 | public IgnoreAttributesWithConstantProbability(double ignoreAttributeProbability) { 16 | this.ignoreAttributeProbability = ignoreAttributeProbability; 17 | } 18 | 19 | @Override 20 | public IgnoreAttributesWithConstantProbability copy(){ 21 | return new IgnoreAttributesWithConstantProbability(ignoreAttributeProbability); 22 | } 23 | 24 | @Override 25 | public boolean ignoreAttribute(String attribute, OldBranch parent) { 26 | if (random.nextDouble() < ignoreAttributeProbability) { 27 | return true; 28 | } 29 | return false; 30 | } 31 | 32 | public double getIgnoreAttributeProbability() { 33 | return ignoreAttributeProbability; 34 | } 35 | 36 | @Override 37 | public String toString(){ 38 | return "ignoreAttributeProbability = " + ignoreAttributeProbability; 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/AbstractClassifier.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier; 2 | 3 | import 
quickml.data.AttributesMap; 4 | import quickml.data.PredictionMap; 5 | 6 | import java.io.Serializable; 7 | import java.util.Set; 8 | 9 | /** 10 | * Created by alexanderhawk on 8/17/14. 11 | */ 12 | //where do we want Classifier as a generic type...in downsampling PMB. 13 | public abstract class AbstractClassifier implements Classifier { 14 | 15 | 16 | private static final long serialVersionUID = -5052476771686106526L; 17 | public double getProbability(AttributesMap attributes, Serializable classification) { 18 | return predict(attributes).get(classification); 19 | } 20 | 21 | public double getProbabilityWithoutAttributes(AttributesMap attributes, Serializable classification, Set attributesToIgnore) { 22 | return predictWithoutAttributes(attributes, attributesToIgnore).get(classification); 23 | } 24 | 25 | public Serializable getClassificationByMaxProb(AttributesMap attributes) { 26 | PredictionMap predictions = predict(attributes); 27 | Serializable mostProbableClass = null; 28 | double probabilityOfMostProbableClass = 0; 29 | for (Serializable key : predictions.keySet()) { 30 | if (predictions.get(key).doubleValue() > probabilityOfMostProbableClass) { 31 | mostProbableClass = key; 32 | probabilityOfMostProbableClass = predictions.get(key).doubleValue(); 33 | } 34 | } 35 | return mostProbableClass; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/Classifier.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.data.PredictionMap; 5 | import quickml.supervised.PredictiveModel; 6 | 7 | import java.io.Serializable; 8 | import java.util.Map; 9 | import java.util.Set; 10 | 11 | /** 12 | * Created by alexanderhawk on 7/29/14. 
13 | */ 14 | public interface Classifier extends PredictiveModel { 15 | 16 | double getProbability(AttributesMap attributes, Serializable classification); 17 | double getProbabilityWithoutAttributes(AttributesMap attributes, Serializable classification, Set attributesToIgnore); 18 | PredictionMap predict(AttributesMap attributes); 19 | PredictionMap predictWithoutAttributes(AttributesMap attributes, Set attributesToIgnore); 20 | Serializable getClassificationByMaxProb(AttributesMap attributes); 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/downsampling/DownsamplingUtils.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.downsampling; 2 | 3 | /** 4 | * Created by ian on 4/23/14. 5 | */ 6 | public class DownsamplingUtils { 7 | public static double correctProbability(final double dropProbability, final double uncorrectedProbability) { 8 | return (1.0 - dropProbability)*uncorrectedProbability / (1.0 - dropProbability * uncorrectedProbability); 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/downsampling/RandomDroppingInstanceFilter.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.downsampling; 2 | 3 | import com.google.common.base.Predicate; 4 | import quickml.collections.MapUtils; 5 | import quickml.data.instances.InstanceWithAttributesMap; 6 | 7 | import java.io.Serializable; 8 | 9 | 10 | /** 11 | * Created by ian on 4/23/14. 
12 | */ 13 | class RandomDroppingInstanceFilter implements Predicate> { 14 | private final Serializable classificationToDrop; 15 | private final double dropProbability; 16 | 17 | public RandomDroppingInstanceFilter(Serializable classificationToDrop, double dropProbability) { 18 | this.classificationToDrop = classificationToDrop; 19 | this.dropProbability = dropProbability; 20 | } 21 | 22 | @Override 23 | public boolean apply(final InstanceWithAttributesMap Instance) { 24 | if (Instance.getLabel().equals(classificationToDrop)) { 25 | final double rand = MapUtils.random.nextDouble(); 26 | 27 | return rand > dropProbability; 28 | } else { 29 | return true; 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/downsampling/package-info.java: -------------------------------------------------------------------------------- 1 | /** 2 | * A predictive model wrapper and related classes that can be used to improve performance on highly 3 | * imbalanced datasets 4 | * with two possible classifications (a majority classification and a minority classification). 5 | * It works by reducing the imbalance by throwing away, at random, a proportion of the instances 6 | * with the majority classification. It then statistically corrects for this at prediction time. 7 | */ 8 | package quickml.supervised.classifier.downsampling; -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/logisticRegression/DataTransformer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.logisticRegression; 2 | 3 | import quickml.data.instances.Instance; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * Created by alexanderhawk on 10/28/15. 
9 | */ 10 | public interface DataTransformer> { 11 | D transformData(List rawInstance); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/logisticRegression/GradientDescent.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.logisticRegression; 2 | 3 | import quickml.data.instances.Instance; 4 | 5 | import java.io.Serializable; 6 | import java.util.List; 7 | import java.util.Map; 8 | 9 | /** 10 | * Created by alexanderhawk on 10/12/15. 11 | */ 12 | public interface GradientDescent { 13 | double[] minimize(List instances, int numFeatures); 14 | void updateBuilderConfig(final Map config); 15 | 16 | 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/logisticRegression/LogisticRegressionDTO.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.logisticRegression; 2 | 3 | import java.io.Serializable; 4 | import java.util.HashMap; 5 | import java.util.List; 6 | import java.util.Map; 7 | 8 | /** 9 | * Created by alexanderhawk on 10/28/15. 
10 | */ 11 | 12 | public abstract class LogisticRegressionDTO> implements TransformedDataWithDates { 13 | 14 | protected List instances; 15 | protected HashMap nameToIndexMap; 16 | protected Map numericClassLabels; 17 | 18 | 19 | @Override 20 | public List getTransformedInstances() { 21 | return instances; 22 | } 23 | 24 | public HashMap getNameToIndexMap() { 25 | return nameToIndexMap; 26 | } 27 | 28 | 29 | public Map getNumericClassLabels() { 30 | return numericClassLabels; 31 | } 32 | 33 | 34 | 35 | public LogisticRegressionDTO(List instances, 36 | HashMap nameToIndexMap, 37 | Map numericClassLabels) { 38 | this.instances = instances; 39 | this.nameToIndexMap = nameToIndexMap; 40 | this.numericClassLabels = numericClassLabels; 41 | } 42 | 43 | public LogisticRegressionDTO(List instances) { 44 | this.instances = instances; 45 | } 46 | 47 | } 48 | 49 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/logisticRegression/TransformedData.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.logisticRegression; 2 | 3 | import quickml.data.instances.Instance; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * Created by alexanderhawk on 10/30/15. 9 | */ 10 | public interface TransformedData> { 11 | D copyWithJustTrainingSet(List trainingSet); 12 | List getTransformedInstances(); 13 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/logisticRegression/TransformedDataWithDates.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.logisticRegression; 2 | 3 | import quickml.data.instances.Instance; 4 | import quickml.supervised.crossValidation.utils.DateTimeExtractor; 5 | 6 | /** 7 | * Created by alexanderhawk on 10/30/15. 
8 | */ 9 | public interface TransformedDataWithDates> extends TransformedData { 10 | 11 | DateTimeExtractor getDateTimeExtractor(); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/splitOnAttribute/SplitValTGroupIdMap.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.splitOnAttribute; 2 | 3 | import java.io.Serializable; 4 | import java.util.HashMap; 5 | 6 | /** 7 | * Created by alexanderhawk on 2/11/15. 8 | */ 9 | public class SplitValTGroupIdMap extends HashMap { 10 | Integer groupId; 11 | public SplitValTGroupIdMap(Integer groupId){ 12 | super(); 13 | this.groupId = groupId; 14 | } 15 | 16 | @Override 17 | public Integer get(Object key) { 18 | return (super.get(key) != null) ? super.get(key) : groupId; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/classifier/temporallyWeightClassifier/TemporallyReweightedClassifier.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.temporallyWeightClassifier; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.data.PredictionMap; 5 | import quickml.supervised.classifier.AbstractClassifier; 6 | import quickml.supervised.classifier.Classifier; 7 | 8 | import java.io.Serializable; 9 | import java.util.Map; 10 | import java.util.Set; 11 | 12 | /** 13 | * Created by alexanderhawk on 6/20/14. 
14 | */ 15 | public class TemporallyReweightedClassifier extends AbstractClassifier { 16 | 17 | private static final long serialVersionUID = 2642074639257374588L; 18 | private final Classifier wrappedClassifier; 19 | 20 | public TemporallyReweightedClassifier(Classifier classifier) { 21 | this.wrappedClassifier = classifier; 22 | } 23 | 24 | @Override 25 | public double getProbability(final AttributesMap attributes, final Serializable classification) { 26 | return wrappedClassifier.getProbability(attributes, classification); 27 | } 28 | 29 | @Override 30 | public PredictionMap predict(AttributesMap attributes) { 31 | return wrappedClassifier.predict(attributes); 32 | } 33 | 34 | @Override 35 | public double getProbabilityWithoutAttributes(final AttributesMap attributes, final Serializable classification, Set attributesToIgnore) { 36 | return wrappedClassifier.getProbabilityWithoutAttributes(attributes, classification, attributesToIgnore); 37 | } 38 | 39 | @Override 40 | public PredictionMap predictWithoutAttributes(AttributesMap attributes, Set attributesToIgnore) { 41 | return wrappedClassifier.predictWithoutAttributes(attributes, attributesToIgnore); 42 | } 43 | 44 | @Override 45 | public Serializable getClassificationByMaxProb(final AttributesMap attributes) { 46 | return wrappedClassifier.getClassificationByMaxProb(attributes); 47 | } 48 | 49 | public Classifier getWrappedClassifier() { 50 | return wrappedClassifier; 51 | } 52 | 53 | } 54 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/collaborativeFiltering/CollaborativeFilter.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.collaborativeFiltering; 2 | 3 | 4 | import quickml.supervised.PredictiveModel; 5 | 6 | /** 7 | * Created by ian on 8/16/14. 
8 | */ 9 | public abstract class CollaborativeFilter implements PredictiveModel { 10 | private static final long serialVersionUID = -3477404201826772133L; 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/collaborativeFiltering/UserItem.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.collaborativeFiltering; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * Created by ian on 8/16/14. 7 | */ 8 | public class UserItem implements Serializable { 9 | private static final long serialVersionUID = -5759815197196667292L; 10 | private long user, item; 11 | 12 | public UserItem(final long user, final long item) { 13 | 14 | this.user = user; 15 | this.item = item; 16 | } 17 | 18 | public long getUser() { 19 | return user; 20 | } 21 | 22 | public long getItem() { 23 | return item; 24 | } 25 | 26 | @Override 27 | public String toString() { 28 | final StringBuilder sb = new StringBuilder("UserItem{"); 29 | sb.append("user=").append(user); 30 | sb.append(", item=").append(item); 31 | sb.append('}'); 32 | return sb.toString(); 33 | } 34 | 35 | @Override 36 | public boolean equals(final Object o) { 37 | if (this == o) return true; 38 | if (o == null || getClass() != o.getClass()) return false; 39 | 40 | final UserItem userItem = (UserItem) o; 41 | 42 | if (item != userItem.item) return false; 43 | if (user != userItem.user) return false; 44 | 45 | return true; 46 | } 47 | 48 | @Override 49 | public int hashCode() { 50 | int result = (int) (user ^ (user >>> 32)); 51 | result = 31 * result + (int) (item ^ (item >>> 32)); 52 | return result; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/collaborativeFiltering/gradientDescent/GradientDescentCF.java: -------------------------------------------------------------------------------- 1 | package 
quickml.supervised.collaborativeFiltering.gradientDescent; 2 | 3 | import quickml.supervised.collaborativeFiltering.CollaborativeFilter; 4 | import quickml.supervised.collaborativeFiltering.UserItem; 5 | 6 | import java.util.Set; 7 | 8 | /** 9 | * Created by ian on 8/16/14. 10 | */ 11 | public class GradientDescentCF extends CollaborativeFilter { 12 | 13 | private static final long serialVersionUID = 301782468956120672L; 14 | 15 | @Override 16 | public Double predict(final UserItem attributes) { 17 | return null; 18 | } 19 | 20 | @Override 21 | public Double predictWithoutAttributes(final UserItem attributes, Set attributesToIgnore) 22 | { 23 | return null; 24 | } 25 | 26 | 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/ClassifierLossChecker.java: -------------------------------------------------------------------------------- 1 | 2 | package quickml.supervised.crossValidation; 3 | 4 | import quickml.supervised.Utils; 5 | import quickml.data.instances.ClassifierInstance; 6 | import quickml.supervised.classifier.Classifier; 7 | import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction; 8 | 9 | import java.util.List; 10 | 11 | public class ClassifierLossChecker implements LossChecker { 12 | 13 | private ClassifierLossFunction lossFunction; 14 | 15 | public ClassifierLossChecker(ClassifierLossFunction lossFunction) { 16 | this.lossFunction = lossFunction; 17 | } 18 | 19 | @Override 20 | public double calculateLoss(PM predictiveModel, List validationSet) { 21 | return lossFunction.getLoss(Utils.calcResultPredictions(predictiveModel, validationSet)); 22 | } 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/CrossValidator.java: -------------------------------------------------------------------------------- 1 | package 
quickml.supervised.crossValidation; 2 | 3 | import quickml.data.instances.Instance; 4 | import quickml.supervised.PredictiveModel; 5 | 6 | import java.io.Serializable; 7 | import java.util.HashMap; 8 | import java.util.Map; 9 | 10 | /** 11 | * Created by alexanderhawk on 10/30/15. 12 | */ 13 | public interface CrossValidator { 14 | double getLossForModel(); 15 | double getLossForModel(Map config); 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/InstanceTargetSelector.java: -------------------------------------------------------------------------------- 1 | 2 | package quickml.supervised.crossValidation; 3 | 4 | import quickml.data.instances.ClassifierInstance; 5 | 6 | import java.io.Serializable; 7 | 8 | /** 9 | * Created by alexanderhawk on 4/1/15. 10 | */ 11 | public interface InstanceTargetSelector { 12 | Serializable getSingleLabel(T instance); 13 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/LossChecker.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation; 2 | 3 | 4 | import java.util.List; 5 | 6 | /** 7 | * For a given validation set and predictive model, calculate the total loss 8 | * @param 9 | * @param 10 | */ 11 | 12 | public interface LossChecker { 13 | public double calculateLoss(PM predictiveModel, List validationSet); 14 | } 15 | 16 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/MultiTargetLossChecker.java: -------------------------------------------------------------------------------- 1 | 2 | package quickml.supervised.crossValidation; 3 | 4 | import com.google.common.collect.Lists; 5 | import quickml.data.instances.ClassifierInstance; 6 | import quickml.supervised.Utils; 7 | import 
quickml.supervised.classifier.Classifier; 8 | import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction; 9 | 10 | import java.util.List; 11 | 12 | public class MultiTargetLossChecker implements LossChecker { 13 | 14 | private ClassifierLossFunction lossFunction; 15 | private InstanceTargetSelector instanceTargetSelector; 16 | 17 | public MultiTargetLossChecker(ClassifierLossFunction lossFunction, InstanceTargetSelector instanceTargets) { 18 | this.lossFunction = lossFunction; 19 | this.instanceTargetSelector = instanceTargets; 20 | } 21 | 22 | @Override 23 | public double calculateLoss(Classifier predictiveModel, List validationSet) { 24 | List singleTargetValidationSet = Lists.newArrayList(); 25 | for(T instance : validationSet) { 26 | singleTargetValidationSet.add(new ClassifierInstance(instance.getAttributes(), instanceTargetSelector.getSingleLabel(instance), instance.getWeight())); 27 | } 28 | return lossFunction.getLoss(Utils.calcResultPredictions(predictiveModel, singleTargetValidationSet)); 29 | } 30 | 31 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/PredictionMapResult.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation; 2 | 3 | import quickml.data.PredictionMap; 4 | import quickml.supervised.crossValidation.lossfunctions.LabelPredictionWeight; 5 | 6 | import java.io.Serializable; 7 | 8 | import static com.google.common.base.Preconditions.checkArgument; 9 | import static java.lang.Double.isInfinite; 10 | import static java.lang.Double.isNaN; 11 | 12 | public class PredictionMapResult extends LabelPredictionWeight { 13 | private PredictionMap prediction; 14 | private Serializable label; 15 | private double weight; 16 | 17 | public PredictionMapResult(PredictionMap prediction, Serializable label, double weight) { 18 | super(label, prediction, 
weight); 19 | this.prediction = prediction; 20 | this.label = label; 21 | this.weight = weight; 22 | } 23 | 24 | public PredictionMap getPrediction() { 25 | return prediction; 26 | } 27 | 28 | public double getWeight() { 29 | return weight; 30 | } 31 | 32 | public Serializable getLabel() { 33 | return label; 34 | } 35 | 36 | public double getPredictionForLabel() { 37 | Double probability = prediction.get(label); 38 | checkArgument(!isNaN(probability), "Probability must be a natural number, not NaN"); 39 | checkArgument(!isInfinite(probability), "Probability must be a natural number, not infinite"); 40 | 41 | return probability; 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/PredictionMapResults.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation; 2 | 3 | import java.util.Iterator; 4 | import java.util.List; 5 | 6 | import static com.google.common.base.Preconditions.checkArgument; 7 | 8 | public class PredictionMapResults implements Iterable{ 9 | 10 | private final List results; 11 | private final double totalWeight; 12 | 13 | public PredictionMapResults(List results) { 14 | checkArgument(!results.isEmpty(), "Prediction results must not be empty"); 15 | 16 | this.results = results; 17 | this.totalWeight = calcTotalWeight(); 18 | } 19 | 20 | private double calcTotalWeight() { 21 | double totalWeight = 0; 22 | for (PredictionMapResult result : results) { 23 | totalWeight += result.getWeight(); 24 | } 25 | return totalWeight; 26 | } 27 | 28 | public double totalWeight() { 29 | return totalWeight; 30 | } 31 | 32 | @Override 33 | public Iterator iterator() { 34 | return results.iterator(); 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/RegressionLossChecker.java: 
-------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.data.instances.Instance; 5 | import quickml.supervised.PredictiveModel; 6 | import quickml.supervised.Utils; 7 | import quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions.RegressionLossFunction; 8 | 9 | import java.io.BufferedWriter; 10 | import java.util.List; 11 | 12 | /** 13 | * Created by alexanderhawk on 8/12/15. 14 | */ 15 | public class RegressionLossChecker> implements LossChecker { 16 | private RegressionLossFunction lossFunction; 17 | 18 | public RegressionLossChecker(RegressionLossFunction lossFunction) { 19 | this.lossFunction = lossFunction; 20 | } 21 | 22 | @Override 23 | public double calculateLoss(PM predictiveModel, List validationSet) { 24 | return lossFunction.getLoss(Utils.getRegLabelsPredictionsWeights(predictiveModel, validationSet)); 25 | } 26 | 27 | public double calculateLoss(PM predictiveModel, List validationSet, BufferedWriter bw) { 28 | return lossFunction.getLoss(Utils.getRegLabelsPredictionsWeights(predictiveModel, validationSet, bw)); 29 | } 30 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/attributeImportance/AttributeWithLoss.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.attributeImportance; 2 | 3 | public class AttributeWithLoss implements Comparable { 4 | private String attribute; 5 | private double loss; 6 | 7 | public AttributeWithLoss(String attribute, double loss) { 8 | this.attribute = attribute; 9 | this.loss = loss; 10 | } 11 | 12 | // Compare the other loss to this objects loss, we want the attributes with the 13 | // highest loss to come first (since removing them has the biggest affect on loss) 14 | @Override 15 | public int 
compareTo(AttributeWithLoss o) { 16 | return Double.compare(o.loss, loss); 17 | } 18 | 19 | public String getAttribute() { 20 | return attribute; 21 | } 22 | 23 | public double getLoss() { 24 | return loss; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/data/FoldedDataFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.data; 2 | 3 | import quickml.data.instances.Instance; 4 | import quickml.supervised.classifier.logisticRegression.TransformedData; 5 | 6 | /** 7 | * Created by alexanderhawk on 10/30/15. 8 | */ 9 | public class FoldedDataFactory> implements TrainingDataCyclerFactory { 10 | private int numFolds; 11 | private int foldsUsed; 12 | 13 | public FoldedDataFactory(int numFolds, int foldsUsed) { 14 | this.numFolds = numFolds; 15 | this.foldsUsed = foldsUsed; 16 | } 17 | 18 | @Override 19 | public FoldedData getTrainingDataCycler(D data) { 20 | return new FoldedData<>(data.getTransformedInstances(), numFolds, foldsUsed); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/data/OutOfTimeDataFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.data; 2 | 3 | import quickml.data.instances.Instance; 4 | import quickml.supervised.classifier.logisticRegression.TransformedDataWithDates; 5 | 6 | /** 7 | * Created by alexanderhawk on 10/30/15. 
8 | */ 9 | public class OutOfTimeDataFactory> implements TrainingDataCyclerFactory { 10 | private double crossValidationFraction; 11 | private int timeSliceHours; 12 | 13 | public OutOfTimeDataFactory(double crossValidationFraction, int timeSliceHours) { 14 | this.crossValidationFraction = crossValidationFraction; 15 | this.timeSliceHours = timeSliceHours; 16 | } 17 | 18 | @Override 19 | public OutOfTimeData getTrainingDataCycler(D data) { 20 | return new OutOfTimeData<>(data.getTransformedInstances(), crossValidationFraction, timeSliceHours, data.getDateTimeExtractor()); 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/data/TrainingDataCycler.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.data; 2 | 3 | import java.util.List; 4 | 5 | /** 6 | * A training data cycler should take a set of training instances and cycle through different treeBuildContexts of training and 7 | * validation sets. 8 | * @param 9 | */ 10 | public interface TrainingDataCycler { 11 | 12 | List getTrainingSet(); 13 | 14 | List getValidationSet(); 15 | 16 | boolean nextCycle(); 17 | 18 | void reset(); 19 | 20 | List getAllData(); 21 | 22 | boolean hasMore(); 23 | 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/data/TrainingDataCyclerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.data; 2 | 3 | import quickml.data.instances.Instance; 4 | import quickml.supervised.classifier.logisticRegression.TransformedData; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * Created by alexanderhawk on 10/30/15. 
/**
 * Pairs an attribute name with the loss measured for it.
 *
 * The natural ordering is by descending loss, so sorting a collection puts the
 * attribute with the highest loss first.
 */
public class AttributeWithLoss implements Comparable<AttributeWithLoss> {
    private final String attribute;
    private final double loss;

    public AttributeWithLoss(String attribute, double loss) {
        this.loss = loss;
        this.attribute = attribute;
    }

    public String getAttribute() {
        return attribute;
    }

    public double getLoss() {
        return loss;
    }

    /**
     * Orders by loss, largest first: the other object's loss is compared against
     * this object's loss, which reverses the usual ascending order.
     */
    @Override
    public int compareTo(AttributeWithLoss o) {
        final double theirs = o.loss;
        final double ours = this.loss;
        return Double.compare(theirs, ours);
    }
}
6 | */ 7 | public class LabelPredictionWeight { 8 | double weight; 9 | L label; 10 | P prediction; 11 | 12 | public double getWeight() { 13 | return weight; 14 | } 15 | 16 | public L getLabel() { 17 | return label; 18 | } 19 | 20 | public P getPrediction() { 21 | return prediction; 22 | } 23 | 24 | public LabelPredictionWeight(L label, P prediction, double weight) { 25 | this.label = label; 26 | this.prediction = prediction; 27 | this.weight = weight; 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/LossFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions; 2 | 3 | public interface LossFunction { 4 | 5 | public Double getLoss(R results); 6 | 7 | public String getName(); 8 | 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/LossFunctions.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions; 2 | 3 | import quickml.supervised.crossValidation.PredictionMapResult; 4 | import quickml.supervised.crossValidation.PredictionMapResults; 5 | 6 | import java.util.List; 7 | 8 | public class LossFunctions { 9 | 10 | public static double mseClassifierLoss(PredictionMapResults results) { 11 | double totalLoss = 0; 12 | for (PredictionMapResult result : results) { 13 | final double error = (1.0 - result.getPredictionForLabel()); 14 | final double errorSquared = error * error * result.getWeight(); 15 | totalLoss += errorSquared; 16 | } 17 | return results.totalWeight() > 0 ? 
totalLoss / results.totalWeight() : 0; 18 | } 19 | 20 | public static double rmseClassifierLoss(PredictionMapResults results) { 21 | return Math.sqrt(mseClassifierLoss(results)); 22 | } 23 | 24 | public static double mseRegressionLoss(List> results) { 25 | double totalLoss = 0; 26 | double totalWeight = 0; 27 | for (LabelPredictionWeight result : results) { 28 | final double error = (result.getLabel() - result.getPrediction()); 29 | final double errorSquared = error * error * result.getWeight(); 30 | totalLoss += errorSquared*result.getWeight(); 31 | totalWeight += result.getWeight(); 32 | } 33 | return totalWeight > 0 ? totalLoss / totalWeight : 0; 34 | } 35 | 36 | public static double rmseRegressionLoss(List> results) { 37 | return Math.sqrt(mseRegressionLoss(results)); 38 | } 39 | 40 | 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/classifierLossFunctions/ClassifierLogCVLossFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions; 2 | 3 | import quickml.supervised.crossValidation.PredictionMapResult; 4 | import quickml.supervised.crossValidation.PredictionMapResults; 5 | 6 | public class ClassifierLogCVLossFunction extends ClassifierLossFunction { 7 | 8 | private static final double DEFAULT_MIN_PROBABILITY = 10E-7; 9 | public static final String NAME = "LOG_CV"; 10 | public double minProbability; 11 | public double maxError; 12 | 13 | 14 | public ClassifierLogCVLossFunction(double minProbability) { 15 | this.minProbability = minProbability; 16 | this.maxError = -Math.log(minProbability); 17 | } 18 | 19 | private double lossForInstance(double correctProbability, double weight) { 20 | return (correctProbability > minProbability) ? 
-weight * Math.log(correctProbability) : weight * maxError; 21 | } 22 | 23 | @Override 24 | public Double getLoss(PredictionMapResults results) { 25 | double totalLoss = 0; 26 | for (PredictionMapResult result : results) { 27 | totalLoss += lossForInstance(result.getPredictionForLabel(), result.getWeight()); 28 | } 29 | return results.totalWeight() > 0 ? totalLoss / results.totalWeight() : 0; 30 | } 31 | 32 | @Override 33 | public String getName() { 34 | return NAME; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/classifierLossFunctions/ClassifierLossFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions; 2 | 3 | import quickml.supervised.crossValidation.PredictionMapResults; 4 | import quickml.supervised.crossValidation.lossfunctions.LossFunction; 5 | 6 | public abstract class ClassifierLossFunction implements LossFunction { 7 | 8 | public abstract Double getLoss(PredictionMapResults results); 9 | 10 | public abstract String getName(); 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/classifierLossFunctions/ClassifierMSELossFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions; 2 | 3 | import quickml.supervised.crossValidation.PredictionMapResults; 4 | 5 | import static quickml.supervised.crossValidation.lossfunctions.LossFunctions.mseClassifierLoss; 6 | 7 | public class ClassifierMSELossFunction extends ClassifierLossFunction { 8 | 9 | @Override 10 | public Double getLoss(PredictionMapResults results) { 11 | return mseClassifierLoss(results); 12 | } 13 | 14 | @Override 15 | public String 
getName() { 16 | return "MSE"; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/classifierLossFunctions/ClassifierRMSELossFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions; 2 | 3 | import quickml.supervised.crossValidation.PredictionMapResults; 4 | 5 | import static quickml.supervised.crossValidation.lossfunctions.LossFunctions.rmseClassifierLoss; 6 | 7 | public class ClassifierRMSELossFunction extends ClassifierLossFunction { 8 | 9 | @Override 10 | public Double getLoss(PredictionMapResults results) { 11 | return rmseClassifierLoss(results); 12 | } 13 | 14 | @Override 15 | public String getName() { 16 | return "RMSE"; 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/rankingLossFunctions/RankingLossFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions.rankingLossFunctions; 2 | 3 | import quickml.supervised.crossValidation.lossfunctions.LossFunction; 4 | import quickml.supervised.rankingModels.ItemToOutcomeMap; 5 | import quickml.supervised.rankingModels.LabelPredictionWeightForRanking; 6 | import quickml.supervised.rankingModels.RankingPrediction; 7 | 8 | import java.util.List; 9 | import java.util.TreeMap; 10 | 11 | /** 12 | * Created by alexanderhawk on 8/12/15. 
13 | */ 14 | public interface RankingLossFunction extends LossFunction> { 15 | /**Map keys are the rankings, where doubles are the numeric values of the actual outcomes*/ 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/regressionLossFunctions/RegressionLossFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions; 2 | 3 | import quickml.supervised.crossValidation.lossfunctions.LabelPredictionWeight; 4 | import quickml.supervised.crossValidation.lossfunctions.LossFunction; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * Created by alexanderhawk on 8/12/15. 10 | */ 11 | public abstract class RegressionLossFunction implements LossFunction>> 12 | { 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/lossfunctions/regressionLossFunctions/RegressionRMSELossFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions.regressionLossFunctions; 2 | 3 | import quickml.supervised.crossValidation.lossfunctions.LabelPredictionWeight; 4 | 5 | import java.util.List; 6 | 7 | import static quickml.supervised.crossValidation.lossfunctions.LossFunctions.rmseRegressionLoss; 8 | 9 | public class RegressionRMSELossFunction extends RegressionLossFunction { 10 | 11 | @Override 12 | public Double getLoss(List> results) { 13 | return rmseRegressionLoss(results); 14 | } 15 | 16 | @Override 17 | public String getName() { 18 | return "RMSE"; 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/movingAverages/ArithmeticAverage.java: 
-------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.movingAverages; 2 | 3 | import java.util.List; 4 | 5 | /** 6 | * Created by alexanderhawk on 4/29/14. 7 | */ 8 | public class ArithmeticAverage implements MovingAverage { 9 | double average = 0; 10 | @Override 11 | public double getAverage(List values) { 12 | for(Double val : values) 13 | average += val; 14 | average /= values.size(); 15 | return average; 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/movingAverages/HoltWintersMovingAverage.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.movingAverages; 2 | 3 | import java.util.List; 4 | 5 | /** 6 | * Created by alexanderhawk on 4/29/14. 7 | */ 8 | public class HoltWintersMovingAverage implements MovingAverage { 9 | double average = 0; 10 | private double alpha; 11 | private double beta; 12 | 13 | public HoltWintersMovingAverage(double alpha, double beta) { 14 | this.alpha = alpha; 15 | this.beta = beta; 16 | } 17 | 18 | public void setAlpha(double alpha) { 19 | this.alpha = alpha; 20 | } 21 | 22 | public void setBeta(double beta) { 23 | this.beta = beta; 24 | } 25 | 26 | public double getAverage(List values) { 27 | 28 | double s = values.get(1); 29 | double b = values.get(1) - values.get(0); 30 | for(int i = 2; i < values.size(); i++) { 31 | double s_prev = s; 32 | s = alpha * values.get(i) + (1 - alpha) * (s - b); 33 | b = beta * (s - s_prev) + (1 - beta) * b; 34 | } 35 | return s; 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/movingAverages/MovingAverage.java: -------------------------------------------------------------------------------- 1 | package 
quickml.supervised.crossValidation.movingAverages; 2 | 3 | /** 4 | * Created by alexanderhawk on 4/29/14. 5 | */ 6 | 7 | import java.util.List; 8 | 9 | public interface MovingAverage { 10 | 11 | public abstract double getAverage(List values); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/utils/AttributesHashSplitter.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.utils; 2 | 3 | import com.google.common.base.Predicate; 4 | import com.google.common.hash.HashFunction; 5 | import com.google.common.hash.Hashing; 6 | import quickml.data.AttributesMap; 7 | import quickml.data.instances.Instance; 8 | 9 | import java.io.Serializable; 10 | 11 | /** 12 | * Created by ian on 2/28/14. 13 | */ 14 | public class AttributesHashSplitter implements Predicate> { 15 | 16 | private static final HashFunction hashFunction = Hashing.murmur3_32(); 17 | 18 | private final int every; 19 | 20 | public AttributesHashSplitter(int every) { 21 | this.every = every; 22 | } 23 | 24 | @Override 25 | public boolean apply(final Instance instance) { 26 | int hc = hashFunction.hashInt(instance.getAttributes().hashCode()).asInt(); 27 | return Math.abs(hc) % every == 0; 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/utils/DateTimeExtractor.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.utils; 2 | import org.joda.time.DateTime; 3 | 4 | /** 5 | * Created by alexanderhawk on 5/6/14. 
6 | */ 7 | public interface DateTimeExtractor { 8 | DateTime extractDateTime(I instance); 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/utils/MeanNormalizedDateTimeExtractor.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.utils; 2 | 3 | 4 | import org.joda.time.DateTime; 5 | import quickml.data.instances.InstanceWithAttributesMap; 6 | import quickml.supervised.Utils; 7 | import java.util.Map; 8 | 9 | 10 | public class MeanNormalizedDateTimeExtractor implements DateTimeExtractor { 11 | 12 | private Map meanStdMaxMinMap; 13 | 14 | public MeanNormalizedDateTimeExtractor(Map meanStdMaxMinMap) { 15 | this.meanStdMaxMinMap = meanStdMaxMinMap; 16 | } 17 | 18 | @Override 19 | public DateTime extractDateTime(T instance) { 20 | int year = attrVal(instance, "timeOfArrival-year"); 21 | int month = attrVal(instance,"timeOfArrival-monthOfYear"); 22 | int day = attrVal(instance,"timeOfArrival-dayOfMonth"); 23 | int hour = attrVal(instance, "timeOfArrival-hourOfDay"); 24 | int minute = attrVal(instance, "timeOfArrival-minuteOfHour"); 25 | return new DateTime(year, month, day, hour, minute, 0, 0); 26 | } 27 | 28 | private int attrVal(T instance, String attrName) { 29 | int normalizedVal = instance.getAttributes().containsKey(attrName) ? 
30 | ((Number) instance.getAttributes().get(attrName)).intValue() : 1; 31 | double mean = meanStdMaxMinMap.get(attrName).getMean(); 32 | double std = meanStdMaxMinMap.get(attrName).getNonZeroStd(); 33 | return (int)(normalizedVal*std + mean); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/crossValidation/utils/SimpleDateFormatExtractor.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.utils; 2 | 3 | import org.joda.time.DateTime; 4 | import org.joda.time.DateTimeZone; 5 | import org.slf4j.Logger; 6 | import org.slf4j.LoggerFactory; 7 | import quickml.data.AttributesMap; 8 | import quickml.data.instances.InstanceWithAttributesMap; 9 | 10 | import java.text.*; 11 | import java.util.Date; 12 | 13 | /** 14 | * Created by alexanderhawk on 6/22/14. 15 | */ 16 | public class SimpleDateFormatExtractor> implements DateTimeExtractor { 17 | private static final Logger logger = LoggerFactory.getLogger(SimpleDateFormatExtractor.class); 18 | DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); 19 | String dateAttribute = "created_at"; 20 | 21 | public void setDateFormat(String dateFormatString) { 22 | dateFormat = new SimpleDateFormat(dateFormatString); 23 | } 24 | 25 | public void setDateAttribute(String dateAttribute) { 26 | this.dateAttribute = dateAttribute; 27 | } 28 | 29 | @Override 30 | public DateTime extractDateTime(T instance) { 31 | AttributesMap attributes = instance.getAttributes(); 32 | try { 33 | Date currentTimeMillis = dateFormat.parse((String) attributes.get(dateAttribute)); 34 | return new DateTime(currentTimeMillis, DateTimeZone.UTC); 35 | } catch (ParseException e) { 36 | logger.error("Error parsing date", e); 37 | } 38 | return new DateTime(); 39 | } 40 | } 41 | 42 | -------------------------------------------------------------------------------- 
/**
 * Created by alexanderhawk on 10/14/15.
 *
 * Tracks, from a stream of observed values, whether an attribute can still be
 * treated as boolean: at most two distinct values, and not all of them numbers.
 */
public class AttributeCharacteristics {

    public boolean isNumber = true;
    public boolean isBoolean = true;
    private HashSet<Serializable> observedVals = new HashSet<Serializable>();

    /**
     * Feeds one observed value into the boolean check. Null values are ignored,
     * and nothing happens once the attribute has been ruled out as boolean.
     *
     * @param val the observed attribute value, may be null
     */
    public void updateBooleanStatus(Serializable val) {
        if (val == null || !isBoolean) {
            return;
        }

        final int seen = observedVals.size();
        final boolean wouldBeThirdValue = seen == 2 && !observedVals.contains(val);
        if (seen > 2 || wouldBeThirdValue) {
            // A third distinct value rules the attribute out as boolean.
            isBoolean = false;
        } else {
            observedVals.add(val);
        }

        // Values that are all numeric are not treated as boolean either.
        if (bothValsAreNumbers()) {
            isBoolean = false;
        }
    }

    /**
     * @return true when every value observed so far is a {@link Number}
     */
    private boolean bothValsAreNumbers() {
        for (Serializable observed : observedVals) {
            if (!(observed instanceof Number)) {
                return false;
            }
        }
        return true;
    }
}
/**
 * Tracks whether a numeric attribute still counts as "binary" given the values
 * observed so far.
 *
 * NOTE(review): the flag is cleared as soon as a second distinct value is seen,
 * so only a single observed value keeps it set. Presumably the other value is
 * implied by the attribute being absent; confirm the threshold against callers.
 */
public class BinaryAttributeCharacteristics {
    private boolean isBinary = true;
    private HashSet<Double> observedVals = new HashSet<Double>();

    /**
     * @return true while no more than one distinct value has been observed
     */
    public boolean getIsBinary() {
        return isBinary;
    }

    /**
     * Records an observed value. Once the attribute has been ruled out as
     * binary, further values are ignored.
     *
     * @param val the observed attribute value
     */
    public void updateBinaryStatus(double val) {
        if (!isBinary) {
            return;
        }
        observedVals.add(val);
        isBinary = observedVals.size() <= 1;
    }
}
15 | */ 16 | public class LabelToDigitConverter, R extends InstanceWithAttributesMap> implements InstanceTransformer { 17 | final InstanceFactory instanceFactory; 18 | private Map numericClassLabels; 19 | 20 | public LabelToDigitConverter(InstanceFactory instanceFactory, List trainingData) { 21 | numericClassLabels = determineNumericClassLabels(trainingData); 22 | this.instanceFactory = instanceFactory; 23 | } 24 | 25 | public Map getNumericClassLabels() { 26 | return numericClassLabels; 27 | } 28 | 29 | @Override 30 | public R transformInstance(I instance) { 31 | return instanceFactory.createInstance(instance.getAttributes(), (L)numericClassLabels.get(instance.getLabel()), instance.getWeight()); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/dataProcessing/instanceTranformer/ProductFeatureAppender.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.dataProcessing.instanceTranformer; 2 | 3 | import quickml.data.instances.InstanceWithAttributesMap; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * Created by alexanderhawk on 10/22/15. 9 | */ 10 | public interface ProductFeatureAppender { 11 | public List addProductAttributes(List trainingData); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/ensembles/randomForest/RandomForest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.ensembles.randomForest; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.supervised.PredictiveModel; 5 | import quickml.supervised.PredictiveModelBuilder; 6 | import quickml.supervised.classifier.AbstractClassifier; 7 | import quickml.supervised.tree.Tree; 8 | 9 | /** 10 | * Created by alexanderhawk on 4/27/15. 
11 | */ 12 | /** Marker interface: a PredictiveModel that declares no members of its own. */ public interface RandomForest> extends PredictiveModel { 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/ensembles/randomForest/RandomForestBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.ensembles.randomForest; 2 | 3 | import quickml.data.instances.InstanceWithAttributesMap; 4 | import quickml.supervised.PredictiveModelBuilder; 5 | import quickml.supervised.tree.Tree; 6 | 7 | /** 8 | * Created by alexanderhawk on 6/21/15. 9 | */ 10 | /** Abstract base for random-forest builders: stores numTrees (initially 8) and leaves buildPredictiveModel to subclasses. */ public abstract class RandomForestBuilder>, I extends InstanceWithAttributesMap> implements PredictiveModelBuilder { 11 | protected int numTrees = 8; 12 | 13 | 14 | public abstract PM buildPredictiveModel(Iterable trainingData); 15 | 16 | public int getNumTrees() { 17 | return numTrees; 18 | } 19 | } 20 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/featureEngineering1/AttributesEnrichStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.featureEngineering1; 2 | 3 | import quickml.data.instances.InstanceWithAttributesMap; 4 | 5 | /** 6 | * Created by ian on 5/21/14.
7 | */ 8 | /** Factory contract: builds an AttributesEnricher from the supplied training data. */ public interface AttributesEnrichStrategy { 9 | public AttributesEnricher build(Iterable> trainingData); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/featureEngineering1/AttributesEnricher.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.featureEngineering1; 2 | 3 | import com.google.common.base.Function; 4 | import quickml.data.AttributesMap; 5 | 6 | import java.io.Serializable; 7 | 8 | /** 9 | * A Function that takes a set of attributes and returns a set of attributes that has 10 | * been enhanced in some way determined by the specific implementation. 11 | */ 12 | public interface AttributesEnricher extends Function, Serializable { 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/featureEngineering1/InstanceEnricher.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.featureEngineering1; 2 | 3 | import com.google.common.base.Function; 4 | import quickml.data.AttributesMap; 5 | import quickml.data.instances.InstanceWithAttributesMap; 6 | 7 | import javax.annotation.Nullable; 8 | import java.util.List; 9 | 10 | /** 11 | * Created by ian on 5/20/14.
12 | */ 13 | /** Function that runs an instance's attributes through every configured AttributesEnricher, in list order, and returns a new InstanceWithAttributesMap holding the resulting attributes together with the original label and weight. */ public class InstanceEnricher implements Function, InstanceWithAttributesMap> { 14 | private final List attributesEnrichers; 15 | 16 | public InstanceEnricher(List attributesEnrichers) { 17 | this.attributesEnrichers = attributesEnrichers; 18 | } 19 | 20 | @Nullable 21 | @Override 22 | public InstanceWithAttributesMap apply(@Nullable InstanceWithAttributesMap instance) { 23 | /* NOTE(review): instance is annotated @Nullable but is dereferenced here without a null check. */ AttributesMap attributes = instance.getAttributes(); 24 | for (AttributesEnricher attributesEnricher : attributesEnrichers) { 25 | attributes = attributesEnricher.apply(attributes); /* each enricher receives the previous enricher's output */ 26 | } 27 | return new InstanceWithAttributesMap(attributes, instance.getLabel(), instance.getWeight()); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/featureEngineering1/enrichStrategies/attributeCombiner/AttributeCombiningEnrichStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.featureEngineering1.enrichStrategies.attributeCombiner; 2 | 3 | import quickml.data.instances.InstanceWithAttributesMap; 4 | import quickml.supervised.featureEngineering1.AttributesEnrichStrategy; 5 | import quickml.supervised.featureEngineering1.AttributesEnricher; 6 | 7 | import java.util.List; 8 | import java.util.Set; 9 | 10 | /** 11 | * An AttributesEnrichStrategy that takes several lists of attribute keys, and combines 12 | * the values of each of those attributes into a new attribute.
13 | */ 14 | public class AttributeCombiningEnrichStrategy implements AttributesEnrichStrategy { 15 | private final Set> attributesToCombine; 16 | 17 | public AttributeCombiningEnrichStrategy(final Set> attributesToCombine) { 18 | this.attributesToCombine = attributesToCombine; 19 | } 20 | 21 | /** Returns a new AttributeCombiningEnricher for attributesToCombine; trainingData is not read. */ @Override 22 | public AttributesEnricher build(final Iterable> trainingData) { 23 | return new AttributeCombiningEnricher(attributesToCombine); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/inspection/AttributeScore.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.inspection; 2 | 3 | /** 4 | * Created by ian on 3/29/14. 5 | */ 6 | /** Immutable (attribute name, score) pair whose natural ordering is ascending by score. */ public class AttributeScore implements Comparable { 7 | private final String attribute; 8 | private final double score; 9 | 10 | public AttributeScore(final String attribute, final double score) { 11 | this.attribute = attribute; 12 | this.score = score; 13 | } 14 | 15 | /** Orders by score only, using Double.compare; the attribute name does not take part. */ @Override 16 | public int compareTo(final AttributeScore o) { 17 | return Double.compare(score, o.score); 18 | } 19 | 20 | public String getAttribute() { 21 | return attribute; 22 | } 23 | 24 | public double getScore() { 25 | return score; 26 | } 27 | 28 | /** Renders as AttributeScore{attribute='...', score=...}. */ @Override 29 | public String toString() { 30 | final StringBuilder sb = new StringBuilder("AttributeScore{"); 31 | sb.append("attribute='").append(attribute).append('\''); 32 | sb.append(", score=").append(score); 33 | sb.append('}'); 34 | return sb.toString(); 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/parametricModels/OptimizableCostFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.parametricModels; 2 | 3 | import quickml.data.instances.Instance; 4 | 5 | import java.io.Serializable; 6 | import java.util.List;
7 | import java.util.Map; 8 | 9 | /** 10 | * Created by alexanderhawk on 4/1/16. 11 | */ 12 | /** Cost-function contract: computes a cost for instances under given weights, writes a gradient into a caller-supplied array, accepts builder configuration and can be shut down. */ public interface OptimizableCostFunction { 13 | double computeCost(List instances, double[] weights, double minPredictedProbablity); 14 | /* gradient is an output array filled in by the implementation; fixedWeights is declared final */ void updateGradient(final List instances, final double[] fixedWeights, double[] gradient); 15 | public void updateBuilderConfig(final Map config); 16 | public void shutdown(); 17 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/predictiveModelOptimizer/ConfigWithLoss.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.predictiveModelOptimizer; 2 | 3 | import java.io.Serializable; 4 | import java.util.Map; 5 | 6 | /** 7 | * Created by alexanderhawk on 9/24/15. 8 | */ 9 | /** Plain holder pairing a loss value with a config map; both fields are package-private and non-final. */ public class ConfigWithLoss { 10 | 11 | double loss; 12 | Map config; 13 | 14 | public ConfigWithLoss(final double loss, final Map config) { 15 | this.loss = loss; 16 | this.config = config; 17 | } 18 | 19 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/predictiveModelOptimizer/FieldValueRecommender.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.predictiveModelOptimizer; 2 | 3 | import java.io.Serializable; 4 | import java.util.List; 5 | 6 | /** 7 | * Created by ian on 4/12/14.
8 | */ 9 | /** Supplies the candidate values for a field, the first candidate, and a continue/stop decision given a list of losses. */ public interface FieldValueRecommender { 10 | List getValues(); 11 | 12 | Serializable first(); 13 | 14 | boolean shouldContinue(List losses); 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/predictiveModelOptimizer/MultiLossModelTester.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.predictiveModelOptimizer; 2 | 3 | import quickml.data.instances.ClassifierInstance; 4 | import quickml.supervised.PredictiveModelBuilder; 5 | import quickml.supervised.crossValidation.attributeImportance.LossFunctionTracker; 6 | import quickml.supervised.crossValidation.data.TrainingDataCycler; 7 | import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLossFunction; 8 | import quickml.supervised.classifier.Classifier; 9 | 10 | import java.util.List; 11 | 12 | import static quickml.supervised.Utils.calcResultPredictions; 13 | 14 | /** Evaluates one model builder against several loss functions over the folds supplied by a TrainingDataCycler. */ public class MultiLossModelTester { 15 | 16 | private TrainingDataCycler dataCycler; 17 | private final PredictiveModelBuilder modelBuilder; 18 | 19 | public MultiLossModelTester(PredictiveModelBuilder modelBuilder, TrainingDataCycler dataCycler) { 20 | this.dataCycler = dataCycler; 21 | this.modelBuilder = modelBuilder; 22 | } 23 | 24 | /** Resets the cycler, then per cycle builds a model on the training set, computes predictions on the validation set and feeds them to a new LossFunctionTracker, which is returned. */ public LossFunctionTracker getMultilossForModel(List lossFunctions) { 25 | 26 | dataCycler.reset(); 27 | LossFunctionTracker lossFunctionTracker = new LossFunctionTracker(lossFunctions); 28 | 29 | do { /* do/while: the first cycle always runs, before hasMore() is consulted */ 30 | List validationSet = dataCycler.getValidationSet(); 31 | Classifier predictiveModel = modelBuilder.buildPredictiveModel(dataCycler.getTrainingSet()); 32 | lossFunctionTracker.updateLosses(calcResultPredictions(predictiveModel, validationSet)); 33 | dataCycler.nextCycle(); 34 | } while (dataCycler.hasMore()); 35 | 36 | return lossFunctionTracker; 37 | } 38 | 39 | } 40 |
-------------------------------------------------------------------------------- /src/main/java/quickml/supervised/predictiveModelOptimizer/fieldValueRecommenders/FixedOrderRecommender.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders; 2 | 3 | import com.google.common.collect.Lists; 4 | import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender; 5 | 6 | import java.io.Serializable; 7 | import java.util.List; 8 | 9 | import static com.google.common.base.Preconditions.checkArgument; 10 | 11 | /** FieldValueRecommender backed by a fixed list of values given at construction; it never asks the caller to stop. */ public class FixedOrderRecommender implements FieldValueRecommender { 12 | private final List values; 13 | 14 | /** Copies the varargs into a new list; checkArgument rejects an empty argument list. */ public FixedOrderRecommender(Serializable... values) { 15 | checkArgument(values.length > 0, "Must include at least one value"); 16 | this.values = Lists.newArrayList(values); 17 | } 18 | 19 | /** Returns the internal list itself, not a copy. */ @Override 20 | public List getValues() { 21 | return values; 22 | } 23 | 24 | @Override 25 | public Serializable first() { 26 | return values.get(0); 27 | } 28 | 29 | /** Always true; the losses argument is not read. */ @Override 30 | public boolean shouldContinue(List losses) { 31 | return true; 32 | } 33 | 34 | 35 | 36 | } 37 | 38 | 39 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/rankingModels/ItemToOutcomeMap.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.rankingModels; 2 | 3 | import com.google.common.collect.Lists; 4 | 5 | import java.io.Serializable; 6 | import java.util.*; 7 | 8 | /** 9 | * Created by alexanderhawk on 8/13/15.
10 | */ 11 | /** Serializable wrapper around a HashMap from item to outcome; the map is held by reference and is publicly accessible. */ public class ItemToOutcomeMap implements Serializable { 12 | public HashMap itemToOutcome; 13 | 14 | public ItemToOutcomeMap(HashMap itemToOutcome) { 15 | this.itemToOutcome = itemToOutcome; 16 | } 17 | 18 | public Iterator> iterator(){ 19 | return itemToOutcome.entrySet().iterator(); 20 | } 21 | 22 | /* NOTE(review): if the map stores boxed numbers, looking up an item that is not present would make the unboxing to double throw NullPointerException; confirm callers only pass known items. */ public double getOutcome(Serializable item) { 23 | return itemToOutcome.get(item); 24 | } 25 | 26 | /** Returns a new list holding a copy of the key set. */ public List getItems() { 27 | return Lists.newArrayList(itemToOutcome.keySet()); 28 | } 29 | 30 | /** Returns the first key produced by the key-set iterator, or null when the map is empty. */ public Serializable getFirstItem(){ 31 | Iterator items = itemToOutcome.keySet().iterator(); 32 | if (items.hasNext()) { 33 | return items.next(); 34 | } 35 | else{ 36 | return null; 37 | } 38 | } 39 | 40 | public int size() { 41 | return itemToOutcome.size(); 42 | } 43 | 44 | /** Returns the live values view of the underlying map. */ public Collection getOutcomes() { 45 | return itemToOutcome.values(); 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/rankingModels/LabelPredictionWeightForRanking.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.rankingModels; 2 | 3 | import quickml.supervised.crossValidation.lossfunctions.LabelPredictionWeight; 4 | 5 | /** 6 | * Created by alexanderhawk on 8/13/15.
7 | */ 8 | /** LabelPredictionWeight specialisation whose label is an ItemToOutcomeMap and whose prediction is a RankingPrediction. */ public class LabelPredictionWeightForRanking extends LabelPredictionWeight { 9 | 10 | public LabelPredictionWeightForRanking(ItemToOutcomeMap itemToOutcomeMap,RankingPrediction rankingPrediction, double weight) { 11 | super(itemToOutcomeMap, rankingPrediction, weight); 12 | } 13 | 14 | /** Same as the three-argument constructor with a weight of 1.0. */ public LabelPredictionWeightForRanking(ItemToOutcomeMap itemToOutcomeMap,RankingPrediction rankingPrediction) { 15 | super(itemToOutcomeMap, rankingPrediction, 1.0); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/rankingModels/RankingInstance.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.rankingModels; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.data.instances.InstanceWithAttributesMap; 5 | 6 | import java.io.Serializable; 7 | 8 | /** 9 | * Created by alexanderhawk on 8/13/15. 10 | */ 11 | public class RankingInstance extends InstanceWithAttributesMap { 12 | /** The label for a list of recommendations is an ItemToOutcomeMap, i.e. a map from items (Serializables) to outcome values. */ 13 | 14 | /** Creates an instance with a weight of 1.0. */ public RankingInstance(AttributesMap attributes, ItemToOutcomeMap label) { 15 | super(attributes, label, 1.0); 16 | } 17 | 18 | public RankingInstance(AttributesMap attributes, ItemToOutcomeMap label, double weight) { 19 | super(attributes, label, weight); 20 | } 21 | /** Delegates to the label's getFirstItem(). */ public Serializable getFirstItem(){ 22 | return getLabel().getFirstItem(); 23 | } 24 | 25 | } 26 | 27 | 28 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/rankingModels/RankingLossChecker.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.rankingModels; 2 | 3 | import quickml.supervised.crossValidation.LossChecker; 4 | import quickml.supervised.crossValidation.lossfunctions.rankingLossFunctions.RankingLossFunction; 5 | 6 | import java.util.List; 7 |
8 | /** 9 | * Created by alexanderhawk on 8/12/15. 10 | */ 11 | /** LossChecker that scores a ranking model with a single RankingLossFunction. */ public class RankingLossChecker implements LossChecker { 12 | private RankingLossFunction lossFunction; 13 | 14 | public RankingLossChecker(RankingLossFunction lossFunction) { 15 | this.lossFunction = lossFunction; 16 | } 17 | 18 | /** Returns lossFunction.getLoss applied to the label/prediction/weight entries that Utils.getLabelPredictionWeights builds for the validation set. */ @Override 19 | public double calculateLoss(PM predictiveModel, List validationSet) { 20 | return lossFunction.getLoss(Utils.getLabelPredictionWeights(predictiveModel, validationSet)); 21 | } 22 | 23 | } 24 | 25 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/rankingModels/RankingModel.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.rankingModels; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.supervised.PredictiveModel; 5 | 6 | import java.io.Serializable; 7 | import java.util.List; 8 | 9 | /** 10 | * Created by alexanderhawk on 8/13/15. 11 | */ 12 | /** Marker interface: a PredictiveModel that declares no members of its own. */ public interface RankingModel extends PredictiveModel { 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/rankingModels/RankingPrediction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.rankingModels; 2 | 3 | import com.google.common.collect.Lists; 4 | import com.google.common.collect.Maps; 5 | 6 | import javax.annotation.Nullable; 7 | import java.io.Serializable; 8 | import java.util.HashMap; 9 | import java.util.List; 10 | import java.util.Map; 11 | 12 | /** 13 | * Created by alexanderhawk on 8/13/15.
14 | */ 15 | public class RankingPrediction { 16 | private List rankedItems = Lists.newArrayList(); 17 | private Map itemsToRanks = Maps.newHashMap(); 18 | 19 | public RankingPrediction(List rankedItems) { 20 | this.rankedItems = rankedItems; 21 | for (int i = 0; i getRankOrder(){ 27 | return rankedItems; 28 | } 29 | 30 | public int getRankOfItem(Serializable item){ 31 | // System.out.println( "ITEM CLICKED: " +item + ". RANKED ITEMS: " +rankedItems); 32 | return itemsToRanks.containsKey(item) ? itemsToRanks.get(item) : Integer.MAX_VALUE; 33 | 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/rankingModels/Utils.java: -------------------------------------------------------------------------------- 1 | 2 | package quickml.supervised.rankingModels; 3 | 4 | import com.google.common.collect.Lists; 5 | import org.slf4j.Logger; 6 | import org.slf4j.LoggerFactory; 7 | 8 | import java.util.List; 9 | 10 | /** 11 | * Created by alexanderhawk on 8/13/15. 
12 | */ 13 | /** Static helpers for evaluating ranking models. */ public class Utils { 14 | public static final String RANKED_ITEMS = "rankedItems"; 15 | private static final Logger logger = LoggerFactory.getLogger(Utils.class); 16 | 17 | 18 | /** Predicts a ranking for every validation instance and wraps label, prediction and weight into a LabelPredictionWeightForRanking. Also counts, and logs at info level, how many predictions contain the instance's first label item (rank other than Integer.MAX_VALUE). */ public static List getLabelPredictionWeights(RankingModel predictiveModel, List validationSet) { 19 | List results = Lists.newArrayList(); 20 | int resultsContainingValue = 0; 21 | for (RankingInstance instance : validationSet) { 22 | RankingPrediction prediction = predictiveModel.predict(instance.getAttributes()); 23 | LabelPredictionWeightForRanking labelPredictionWeightForRanking = new LabelPredictionWeightForRanking(instance.getLabel(), prediction, instance.getWeight()); 24 | results.add(labelPredictionWeightForRanking); 25 | if (prediction.getRankOfItem(instance.getFirstItem())!= Integer.MAX_VALUE) { 26 | resultsContainingValue++; 27 | } else { 28 | // logger.info("predictions {}, label {}", prediction.getRankOrder().toString(), instance.getFirstItem()); 29 | } 30 | } 31 | logger.info("results containing non zero value {}, out n examples {} ",resultsContainingValue, validationSet.size()); 32 | return results; 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/regressionModel/LinearRegression2/LinearRegressionDTO.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.regressionModel.LinearRegression2; 2 | 3 | import quickml.data.instances.SparseRegressionInstance; 4 | import quickml.supervised.classifier.logisticRegression.TransformedDataWithDates; 5 | 6 | import java.util.HashMap; 7 | import java.util.List; 8 | 9 | /** 10 | * Created by alexanderhawk on 10/28/15.
11 | */ 12 | /** Abstract holder for a list of transformed instances and an optional name-to-index map. */ public abstract class LinearRegressionDTO> implements TransformedDataWithDates { 13 | 14 | protected List instances; 15 | protected HashMap nameToIndexMap; 16 | 17 | 18 | @Override 19 | public List getTransformedInstances() { 20 | return instances; 21 | } 22 | 23 | /** May return null when the one-argument constructor was used. */ public HashMap getNameToIndexMap() { 24 | return nameToIndexMap; 25 | } 26 | 27 | 28 | 29 | 30 | public LinearRegressionDTO(List instances, 31 | HashMap nameToIndexMap) { 32 | this.instances = instances; 33 | this.nameToIndexMap = nameToIndexMap; 34 | } 35 | 36 | /** Sets only the instances; nameToIndexMap is left unassigned (null). */ public LinearRegressionDTO(List instances) { 37 | this.instances = instances; 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/regressionModel/MultiVariableRealValuedFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.regressionModel; 2 | 3 | 4 | import quickml.supervised.PredictiveModel; 5 | 6 | /** 7 | * Created by alexanderhawk on 7/29/14. 8 | */ 9 | /** PredictiveModel that declares a predict overload taking a double array and returning a Double. */ public interface MultiVariableRealValuedFunction extends PredictiveModel { 10 | public abstract Double predict(double[] attributes); 11 | } 12 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/regressionModel/SingleVariableRealValuedFunction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.regressionModel; 2 | 3 | import quickml.supervised.PredictiveModel; 4 | 5 | /** 6 | * Created by alexanderhawk on 7/29/14.
7 | */ 8 | /** PredictiveModel that declares a predict method taking a single Double regressor and returning a Double. */ public interface SingleVariableRealValuedFunction extends PredictiveModel { 9 | public abstract Double predict(Double regressor); 10 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/Tree.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.supervised.PredictiveModel; 5 | import quickml.supervised.tree.nodes.Leaf; 6 | import quickml.supervised.tree.nodes.Node; 7 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 8 | 9 | /** 10 | * Created by alexanderhawk on 4/3/15. 11 | */ 12 | 13 | 14 | /** Marker interface: a PredictiveModel that declares no members of its own. NOTE(review): the declaration is split across a blank line here and its type parameters are not visible. */ public interface Tree

extends PredictiveModel { 15 | } 16 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/attributeIgnoringStrategies/AttributeIgnoringStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.attributeIgnoringStrategies; 2 | 3 | import quickml.supervised.tree.nodes.Branch; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * Created by alexanderhawk on 2/28/15. 9 | */ 10 | public interface AttributeIgnoringStrategy extends Serializable { 11 | 12 | /** 13 | * Decides whether the given attribute should be ignored. 14 | * @param attribute the attribute name being considered 15 | * @param parent the parent branch supplied by the caller 16 | * @return true if the attribute should be ignored 17 | */ 18 | boolean ignoreAttribute(String attribute, Branch parent); 19 | 20 | /** 21 | * @return a copy of this AttributeIgnoringStrategy 22 | */ 23 | AttributeIgnoringStrategy copy(); 24 | } 25 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/attributeIgnoringStrategies/CompositeAttributeIgnoringStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.attributeIgnoringStrategies; 2 | 3 | import com.google.common.collect.Lists; 4 | import quickml.supervised.tree.nodes.Branch; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * Created by alexanderhawk on 2/28/15.
10 | */ 11 | /** Combines several AttributeIgnoringStrategy instances: an attribute is ignored as soon as any one of them says so. */ public class CompositeAttributeIgnoringStrategy implements AttributeIgnoringStrategy { 12 | private static final long serialVersionUID = 0L; 13 | 14 | private List attributeIgnoringStrategies = Lists.newArrayList(); 15 | 16 | /** Stores the given list by reference, replacing the empty list from the field initializer. */ public CompositeAttributeIgnoringStrategy(List attributeIgnoringStrategies) { 17 | this.attributeIgnoringStrategies = attributeIgnoringStrategies; 18 | } 19 | 20 | /** Returns a new composite built from copy() of every contained strategy. */ @Override 21 | public CompositeAttributeIgnoringStrategy copy() { 22 | List copies = Lists.newArrayList(); 23 | for (AttributeIgnoringStrategy attributeIgnoringStrategy : attributeIgnoringStrategies) { 24 | copies.add(attributeIgnoringStrategy.copy()); 25 | } 26 | return new CompositeAttributeIgnoringStrategy(copies); 27 | } 28 | 29 | /** Consults the strategies in list order and stops at the first that returns true; returns false when none does, including for an empty list. */ @Override 30 | public boolean ignoreAttribute(String attribute, Branch parent) { 31 | for (AttributeIgnoringStrategy attributeIgnoringStrategy : attributeIgnoringStrategies) { 32 | if (attributeIgnoringStrategy.ignoreAttribute(attribute, parent)) { 33 | return true; 34 | } 35 | } 36 | return false; 37 | } 38 | 39 | /* NOTE(review): the label printed below is "oldAttributeIgnoringStrategies" although the field is named attributeIgnoringStrategies. */ @Override 40 | public String toString() { 41 | return "CompositeAttributeIgnoringStrategy{" + 42 | "oldAttributeIgnoringStrategies=" + attributeIgnoringStrategies + 43 | '}'; 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/attributeIgnoringStrategies/IgnoreAttributesInSet.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.attributeIgnoringStrategies; 2 | 3 | import com.google.common.collect.Sets; 4 | import quickml.supervised.tree.nodes.Branch; 5 | 6 | import java.util.HashSet; 7 | import java.util.Random; 8 | import java.util.Set; 9 | 10 | /** 11 | * Created by alexanderhawk on 2/28/15.
12 | */ 13 | /** Ignores a randomly chosen subset of a proposed attribute set. The subset is drawn once, in the constructor, and stays fixed for the lifetime of the instance. */ public class IgnoreAttributesInSet implements AttributeIgnoringStrategy { 14 | private static final long serialVersionUID = 0L; 15 | 16 | private final HashSet attributesToIgnore = Sets.newHashSet(); 17 | private final Set proposedAttributesToIgnore; 18 | private final double discardProbability; 19 | private Random random = new Random(); 20 | 21 | public IgnoreAttributesInSet(Set attributesToIgnore, double probabilityOfDiscardingFromAttributesToIgnore) { 22 | this.proposedAttributesToIgnore = attributesToIgnore; 23 | this.discardProbability = probabilityOfDiscardingFromAttributesToIgnore; 24 | setAttributesToIgnore(); 25 | } 26 | 27 | /** Adds each proposed attribute to the ignore set when a fresh random draw exceeds discardProbability. */ private void setAttributesToIgnore() { 28 | for (String attribute : proposedAttributesToIgnore) { 29 | if (random.nextDouble() > discardProbability) { 30 | attributesToIgnore.add(attribute); 31 | } 32 | } 33 | } 34 | 35 | /** Builds a new instance from the same proposed set and probability; because the constructor re-draws, the copy's ignore set may differ from this one's. */ @Override 36 | public IgnoreAttributesInSet copy(){ 37 | return new IgnoreAttributesInSet(proposedAttributesToIgnore, discardProbability); 38 | } 39 | 40 | /** True when attribute is in the drawn ignore set; the Parent argument is not read. */ @Override 41 | public boolean ignoreAttribute(String attribute, Branch Parent) { 42 | if (attributesToIgnore.contains(attribute)) { 43 | return true; 44 | } 45 | return false; 46 | } 47 | 48 | @Override 49 | public String toString() { 50 | return "IgnoreAttributesInSet{" + "proposedAttributesToIgnore=" + proposedAttributesToIgnore + 51 | ", discardProbability=" + discardProbability + 52 | '}'; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/attributeIgnoringStrategies/IgnoreAttributesWithConstantProbability.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.attributeIgnoringStrategies; 2 | 3 | import quickml.supervised.tree.nodes.Branch; 4 | 5 | import java.util.concurrent.ThreadLocalRandom; 6 | 7 | /** 8 | * Created by alexanderhawk on 2/28/15.
9 | */ 10 | /** Attribute-ignoring strategy that ignores every attribute independently with a fixed probability. The random draw uses ThreadLocalRandom.current() at call time, so the generator always belongs to the calling thread. */ public class IgnoreAttributesWithConstantProbability implements AttributeIgnoringStrategy { 11 | private static final long serialVersionUID = 0L; 12 | 13 | private final double ignoreAttributeProbability; 14 | /* No generator is cached in a field: a ThreadLocalRandom obtained in the constructor belongs to the constructing thread and must not be used from other threads that later call ignoreAttribute. */ 15 | 16 | public IgnoreAttributesWithConstantProbability(double ignoreAttributeProbability) { 17 | this.ignoreAttributeProbability = ignoreAttributeProbability; 18 | } 19 | 20 | /** Returns a new strategy with the same probability. */ @Override 21 | public IgnoreAttributesWithConstantProbability copy(){ 22 | return new IgnoreAttributesWithConstantProbability(ignoreAttributeProbability); 23 | } 24 | 25 | /** Returns true with probability ignoreAttributeProbability; neither argument is read. */ @Override 26 | public boolean ignoreAttribute(String attribute, Branch parent) { 27 | if (ThreadLocalRandom.current().nextDouble() < ignoreAttributeProbability) { 28 | return true; 29 | } 30 | return false; 31 | } 32 | 33 | public double getIgnoreAttributeProbability() { 34 | return ignoreAttributeProbability; 35 | } 36 | 37 | @Override 38 | public String toString(){ 39 | return "ignoreAttributeProbability = " + ignoreAttributeProbability; 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/attributeValueIgnoringStrategies/AttributeValueIgnoringStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.attributeValueIgnoringStrategies; 2 | 3 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 4 | 5 | /** 6 | * Created by alexanderhawk on 3/18/15.
7 | */ 8 | /** Single-method contract deciding, from a value counter, whether a value should be ignored. */ public interface AttributeValueIgnoringStrategy> { 9 | 10 | boolean shouldWeIgnoreThisValue(final VC valueCounts); 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/attributeValueIgnoringStrategies/AttributeValueIgnoringStrategyBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.attributeValueIgnoringStrategies; 2 | 3 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 4 | 5 | import java.io.Serializable; 6 | 7 | /** 8 | * Created by alexanderhawk on 4/5/15. 9 | */ 10 | /** Serializable factory for AttributeValueIgnoringStrategy: it can copy itself and create a strategy from a value counter. */ public interface AttributeValueIgnoringStrategyBuilder> extends Serializable{ 11 | AttributeValueIgnoringStrategyBuilder copy(); 12 | AttributeValueIgnoringStrategy createAttributeValueIgnoringStrategy(VC valueCounts); 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/bagging/Bagging.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.bagging; 2 | 3 | import quickml.data.instances.InstanceWithAttributesMap; 4 | 5 | import java.io.Serializable; 6 | import java.util.List; 7 | 8 | /** 9 | * Created by alexanderhawk on 4/5/15.
10 | */ 11 | /** Contract for splitting training data into a training part and an out-of-bag part. */ public interface Bagging { 12 | > TrainingDataPair separateTrainingDataFromOutOfBagData(List trainingData); 13 | 14 | /** Holder for the two lists produced by a split; both fields are public and non-final. */ class TrainingDataPair> { 15 | public List trainingData; 16 | public List outOfBagTrainingData; 17 | 18 | public TrainingDataPair(List trainingData, List outOfBagTrainingData) { 19 | this.trainingData = trainingData; 20 | this.outOfBagTrainingData = outOfBagTrainingData; 21 | } 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/bagging/StationaryBagging.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.bagging; 2 | 3 | import com.google.common.collect.Lists; 4 | import quickml.collections.MapUtils; 5 | import quickml.data.instances.InstanceWithAttributesMap; 6 | 7 | import java.io.Serializable; 8 | import java.util.HashSet; 9 | import java.util.List; 10 | 11 | /** 12 | * Created by alexanderhawk on 4/5/15.
13 | */ 14 | /** Bagging implementation that draws as many random indices as there are training instances, with repeats allowed, and reports the instances whose index was never drawn as out-of-bag. */ public class StationaryBagging implements Bagging { 15 | 16 | private static com.twitter.common.util.Random rand = com.twitter.common.util.Random.Util.fromSystemRandom(MapUtils.random); 17 | 18 | @Override 19 | public > TrainingDataPair separateTrainingDataFromOutOfBagData(List trainingData) { 20 | List baggedTrainingData = Lists.newArrayList(); 21 | List outOfBagTrainingData = Lists.newArrayList(); 22 | 23 | HashSet unusedDataIndices = new HashSet<>(); /* starts with every index; drawn indices are removed below */ 24 | for (int i = 0; i < trainingData.size(); i++) { 25 | unusedDataIndices.add(i); 26 | } 27 | for (int i = 0; i < trainingData.size(); i++) { 28 | int toAdd = rand.nextInt(trainingData.size()); 29 | if (unusedDataIndices.contains(toAdd)) /* this if guards only the remove on the next line; the add after it always runs */ 30 | unusedDataIndices.remove(toAdd); 31 | baggedTrainingData.add(trainingData.get(toAdd)); 32 | } 33 | for (Integer index : unusedDataIndices) { 34 | outOfBagTrainingData.add(trainingData.get(index)); 35 | } 36 | return new TrainingDataPair<>(baggedTrainingData, outOfBagTrainingData); 37 | 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/branchFinders/BranchFinderAndReducerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.branchFinders; 2 | 3 | import quickml.data.instances.InstanceWithAttributesMap; 4 | import quickml.supervised.tree.reducers.ReducerFactory; 5 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 6 | 7 | /** 8 | * Created by alexanderhawk on 6/18/15.
9 | */ 10 | public class BranchFinderAndReducerFactory, VC extends ValueCounter> { 11 | protected BranchFinder branchFinder; 12 | protected ReducerFactory reducerFactory; 13 | 14 | public BranchFinderAndReducerFactory(BranchFinder branchFinder, ReducerFactory reducerFactory) { 15 | this.branchFinder = branchFinder; 16 | this.reducerFactory = reducerFactory; 17 | } 18 | 19 | public BranchFinder getBranchFinder() { 20 | return branchFinder; 21 | } 22 | 23 | public ReducerFactory getReducerFactory() { 24 | return reducerFactory; 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/branchFinders/branchFinderBuilders/AlternativeSelction.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.branchFinders.branchFinderBuilders; 2 | 3 | import com.google.common.collect.Lists; 4 | import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability; 5 | 6 | import java.util.ArrayList; 7 | import java.util.Collections; 8 | 9 | /** 10 | * Created by alexanderhawk on 3/30/16. 11 | */ 12 | public class AlternativeSelction 13 | { 14 | // double ignoreProb = ((IgnoreAttributesWithConstantProbability) attributeIgnoringStrategy).getIgnoreAttributeProbability(); 15 | // ArrayList candidates = Lists.newArrayList(candidateAttributes); 16 | // 17 | // if (ignoreProb == 0.0) { 18 | // return candidates; 19 | //} 20 | // //O(N) way of shuffling the attributes. 
21 | // Collections.shuffle(candidates); 22 | // int numTrialAttributes = (int)(ignoreProb*candidates.size()); 23 | // 24 | // return candidates.subList(0,numTrialAttributes); 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/branchingConditions/BranchingConditions.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.branchingConditions; 2 | 3 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 4 | import quickml.supervised.tree.nodes.Branch; 5 | import quickml.supervised.tree.nodes.Node; 6 | 7 | import java.io.Serializable; 8 | import java.util.Map; 9 | 10 | /** 11 | * Created by alexanderhawk on 4/4/15. 12 | */ 13 | 14 | 15 | public interface BranchingConditions> extends Serializable{ 16 | boolean isInvalidSplit(VC trueValueStats, VC falseValueStats); 17 | 18 | boolean isInvalidSplit(double score); 19 | 20 | boolean isInvalidSplit(VC trueSet, VC falseSet, String attribute); 21 | 22 | boolean canTryAddingChildren(Branch branch, VC VC); 23 | 24 | void update(Map cfg); 25 | 26 | BranchingConditions copy(); 27 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/constants/AttributeType.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.constants; 2 | 3 | import quickml.supervised.tree.nodes.Branch; 4 | 5 | import javax.management.Attribute; 6 | 7 | /** 8 | * Created by alexanderhawk on 6/24/15. 
9 | */ 10 | public enum AttributeType { 11 | CATEGORICAL(), NUMERIC(), BOOLEAN(), RT_CATEGORICAL, RT_NUMERIC; 12 | 13 | public static AttributeType convertBranchTypeToAttributeType(BranchType branchType) { 14 | if (branchType.name().equals(CATEGORICAL.name()) 15 | || branchType.equals(BranchType.BINARY_CATEGORICAL) 16 | || branchType.name().equals(BranchType.RT_CATEGORICAL.name())) { 17 | return CATEGORICAL; 18 | } else if (branchType.name().equals(NUMERIC.name()) 19 | || branchType.name().equals(BranchType.RT_NUMERIC.name())) { 20 | return NUMERIC; 21 | } else if (branchType.name().equals(BOOLEAN.name())) { 22 | return BOOLEAN; 23 | } else { 24 | throw new RuntimeException("unknown branch type: " + branchType.name()); 25 | } 26 | 27 | 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/constants/BranchType.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.constants; 2 | 3 | /** 4 | * Created by alexanderhawk on 3/19/15. 5 | */ 6 | public enum BranchType { 7 | CATEGORICAL(), BINARY_CATEGORICAL(), NUMERIC(), BOOLEAN(), RT_CATEGORICAL, RT_NUMERIC; 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/constants/ForestOptions.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.constants; 2 | 3 | /** 4 | * Created by alexanderhawk on 3/20/15. 
5 | */ 6 | public enum ForestOptions { 7 | BAGGING(), 8 | DOWNSAMPLING_TARGET_MINORITY_PROPORTION(), 9 | PRUNING_STRATEGY(), 10 | SCORER_FACTORY(), 11 | MAX_DEPTH(), 12 | MIN_SCORE(), 13 | MIN_LEAF_INSTANCES(), 14 | MIN_SLPIT_FRACTION(), 15 | IMBALANCE_PENALTY_POWER(), 16 | NUM_TREES(), 17 | ATTRIBUTE_VALUE_THRESHOLD_OBSERVATIONS(), 18 | PENALIZE_CATEGORICAL_SPLITS(), 19 | ATTRIBUTE_IGNORING_STRATEGY(), 20 | ATTRIBUTE_VALUE_IGNORING_STRATEGY(), 21 | ATTRIBUTE_VALUE_IGNORING_STRATEGY_BUILDER(), 22 | DEGREE_OF_GAIN_RATIO_PENALTY(), 23 | BINS_FOR_NUMERIC_SPLITS(), 24 | NUM_SAMPLES_PER_NUMERIC_BIN(), 25 | NUM_NUMERIC_BINS(), 26 | SAMPLES_PER_BIN(), 27 | BRANCH_FINDER_BUILDERS(), 28 | NUMERIC_BRANCH_BUILDER(), 29 | CATEGORICAL_BRANCH_BUILDER(), 30 | LEAF_BUILDER(), 31 | BOOLEAN_BRANCH_BUILDER(), 32 | BRANCHING_CONDITIONS(), 33 | TREE_FACTORY(), 34 | DATA_PROPERTIES_TRANSFORMER(), 35 | MIN_ATTRIBUTE_VALUE_OCCURRENCES(), 36 | EXEMPT_ATTRIBUTES; 37 | } 38 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/constants/MissingValue.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.constants; 2 | 3 | /** 4 | * Created by alexanderhawk on 4/5/15. 
5 | */ 6 | public enum MissingValue { 7 | MISSING_VALUE(); 8 | } 9 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/DecisionTreeBuilderHelper.java: -------------------------------------------------------------------------------- 1 | 2 | package quickml.supervised.tree.decisionTree; 3 | 4 | import org.javatuples.Pair; 5 | import quickml.data.instances.ClassifierInstance; 6 | import quickml.supervised.tree.TreeBuilderHelper; 7 | import quickml.supervised.tree.decisionTree.treeBuildContexts.DTreeContextBuilder; 8 | import quickml.supervised.tree.decisionTree.treeBuildContexts.DTreeContext; 9 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 10 | import quickml.supervised.tree.nodes.Node; 11 | 12 | import java.io.Serializable; 13 | import java.util.List; 14 | import java.util.Set; 15 | 16 | /** 17 | * Created by alexanderhawk on 4/20/15. 18 | */ 19 | public class DecisionTreeBuilderHelper extends TreeBuilderHelper { 20 | 21 | DTreeContextBuilder treeBuildContext; 22 | public DecisionTreeBuilderHelper(DTreeContextBuilder treeBuildContext) { 23 | super(treeBuildContext); 24 | this.treeBuildContext = treeBuildContext; 25 | } 26 | 27 | public Pair, Set> computeNodesAndClasses(List trainingData) { 28 | DTreeContext itbc = treeBuildContext.buildContext(trainingData); 29 | Node root = createNode(null, trainingData, itbc); 30 | return Pair.with(root, itbc.getClassifications()); 31 | } 32 | 33 | } 34 | 35 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/DecisionTreeVisualizer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.nodes.Branch; 5 | import 
quickml.supervised.tree.nodes.Leaf; 6 | import quickml.supervised.tree.nodes.Node; 7 | 8 | import java.io.PrintStream; 9 | 10 | /** 11 | * Created by ian on 7/20/15. 12 | */ 13 | public class DecisionTreeVisualizer { 14 | 15 | public static final int INDENT_AMOUNT = 3; 16 | 17 | public void visualize(DecisionTree tree, PrintStream out) { 18 | visualize(tree.root, out, 0); 19 | } 20 | 21 | private void visualize(final Node node, final PrintStream out, final int depth) { 22 | StringBuilder indentBuilder = new StringBuilder(); 23 | for (int x = 0; x < depth; x++) { 24 | indentBuilder.append(' '); 25 | } 26 | String indent = indentBuilder.toString(); 27 | 28 | if (node instanceof Branch) { 29 | Branch branch = (Branch) node; 30 | out.println(indent + branch.toString() + " TRUE:"); 31 | visualize(branch.getTrueChild(), out, depth + INDENT_AMOUNT); 32 | out.println(indent + branch.toString() + " FALSE:"); 33 | visualize(branch.getFalseChild(), out, depth + INDENT_AMOUNT); 34 | } else if (node instanceof Leaf) { 35 | out.println(indent + "LEAF: " + node.toString()); 36 | } 37 | } 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/attributeValueIgnoringStrategies/BinaryClassAttributeValueIgnoringStrategyBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.attributeValueIgnoringStrategies; 2 | 3 | 4 | import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy; 5 | import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategyBuilder; 6 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 7 | 8 | public class BinaryClassAttributeValueIgnoringStrategyBuilder implements AttributeValueIgnoringStrategyBuilder { 9 | 10 | private static final long serialVersionUID = 0L; 11 | public 
BinaryClassAttributeValueIgnoringStrategyBuilder(int minOccurancesOfAttributeValue) { 12 | this.minOccurancesOfAttributeValue = minOccurancesOfAttributeValue; 13 | } 14 | 15 | @Override 16 | public AttributeValueIgnoringStrategy createAttributeValueIgnoringStrategy(ClassificationCounter cc) { 17 | return new BinaryClassAttributeValueIgnoringStrategy(cc, minOccurancesOfAttributeValue); 18 | } 19 | 20 | private int minOccurancesOfAttributeValue; 21 | 22 | 23 | public BinaryClassAttributeValueIgnoringStrategyBuilder setMinOccurancesOfAttributeValue(int minOccurancesOfAttributeValue) { 24 | this.minOccurancesOfAttributeValue = minOccurancesOfAttributeValue; 25 | return this; 26 | } 27 | 28 | public BinaryClassAttributeValueIgnoringStrategyBuilder copy() { 29 | return new BinaryClassAttributeValueIgnoringStrategyBuilder(minOccurancesOfAttributeValue).setMinOccurancesOfAttributeValue(minOccurancesOfAttributeValue); 30 | } 31 | 32 | 33 | 34 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/attributeValueIgnoringStrategies/MultiClassAtributeValueIgnoringStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.attributeValueIgnoringStrategies; 2 | 3 | import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy; 4 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 5 | 6 | import java.io.Serializable; 7 | import java.util.Map; 8 | 9 | /** 10 | * Created by alexanderhawk on 3/18/15. 
11 | */ 12 | public class MultiClassAtributeValueIgnoringStrategy implements AttributeValueIgnoringStrategy { 13 | private int minOccurancesOfAttributeValue; 14 | public MultiClassAtributeValueIgnoringStrategy(int minOccurancesOfAttributeValue) { 15 | this.minOccurancesOfAttributeValue = minOccurancesOfAttributeValue; 16 | } 17 | 18 | public boolean shouldWeIgnoreThisValue(final ClassificationCounter testValCounts) { 19 | Map counts = testValCounts.getCounts(); 20 | 21 | for (Serializable key : counts.keySet()) { 22 | if (counts.get(key).doubleValue() < minOccurancesOfAttributeValue) { 23 | return true; 24 | } 25 | } 26 | 27 | return false; 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/attributeValueIgnoringStrategies/MultiClassAttributeValueIgnoringStrategyBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.attributeValueIgnoringStrategies; 2 | 3 | 4 | import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy; 5 | import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategyBuilder; 6 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 7 | 8 | public class MultiClassAttributeValueIgnoringStrategyBuilder implements AttributeValueIgnoringStrategyBuilder { 9 | private static final long serialVersionUID = 0L; 10 | 11 | public MultiClassAttributeValueIgnoringStrategyBuilder(int minOccurancesOfAttributeValue) { 12 | this.minOccurancesOfAttributeValue = minOccurancesOfAttributeValue; 13 | } 14 | 15 | @Override 16 | public AttributeValueIgnoringStrategy createAttributeValueIgnoringStrategy(ClassificationCounter cc) { 17 | return new MultiClassAtributeValueIgnoringStrategy(minOccurancesOfAttributeValue); 18 | } 19 | 20 | private int minOccurancesOfAttributeValue; 21 | 22 | 
23 | public MultiClassAttributeValueIgnoringStrategyBuilder setMinOccurancesOfAttributeValue(int minOccurancesOfAttributeValue) { 24 | this.minOccurancesOfAttributeValue = minOccurancesOfAttributeValue; 25 | return this; 26 | } 27 | 28 | public MultiClassAttributeValueIgnoringStrategyBuilder copy() { 29 | return new MultiClassAttributeValueIgnoringStrategyBuilder(minOccurancesOfAttributeValue).setMinOccurancesOfAttributeValue(minOccurancesOfAttributeValue); 30 | } 31 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/branchFinders/branchFinderBuilders/DTBranchFinderBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.branchFinders.branchFinderBuilders; 2 | 3 | import quickml.supervised.tree.branchFinders.branchFinderBuilders.BranchFinderBuilder; 4 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 5 | 6 | /** 7 | * Created by alexanderhawk on 6/21/15. 8 | */ 9 | public abstract class DTBranchFinderBuilder extends BranchFinderBuilder { 10 | 11 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/branchingConditions/DTBranchingConditions.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.branchingConditions; 2 | 3 | import quickml.supervised.tree.branchingConditions.StandardBranchingConditions; 4 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 5 | 6 | /** 7 | * Created by alexanderhawk on 6/21/15. 
8 | */ 9 | public class DTBranchingConditions extends StandardBranchingConditions{ 10 | public DTBranchingConditions(double minScore, int maxDepth, int minLeafInstances, double minSplitFraction) { 11 | super(minScore, maxDepth, minLeafInstances, minSplitFraction); 12 | } 13 | 14 | public DTBranchingConditions() { 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/nodes/DTCatBranch.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.nodes; 2 | 3 | import com.google.common.collect.Sets; 4 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 5 | import quickml.supervised.tree.nodes.Branch; 6 | 7 | import java.io.Serializable; 8 | import java.util.Map; 9 | import java.util.Set; 10 | 11 | /** 12 | * Created by alexanderhawk on 4/27/15. 13 | */ 14 | public class DTCatBranch extends Branch { 15 | private static final long serialVersionUID = -1723969623146234761L; 16 | public final Set trueSet; 17 | 18 | public DTCatBranch(Branch parent, final String attribute, final Set trueSet, double probabilityOfTrueChild, double score, ClassificationCounter aggregateStats ) { 19 | super(parent, attribute, probabilityOfTrueChild, score, aggregateStats); 20 | this.trueSet = Sets.newHashSet(trueSet); 21 | } 22 | 23 | @Override 24 | public boolean decide(final Map attributes) { 25 | Serializable attributeVal = attributes.get(attribute); 26 | //missing values always go the way of the outset...which strangely seems to be most accurate 27 | return trueSet.contains(attributeVal); 28 | } 29 | 30 | @Override 31 | public String toString() { 32 | return attribute + " in " + trueSet; 33 | } 34 | 35 | @Override 36 | public boolean equals(final Object o) { 37 | if (this == o) return true; 38 | if (o == null || getClass() != o.getClass()) return false; 39 | if (!super.equals(o)) return 
false; 40 | 41 | final DTCatBranch that = (DTCatBranch) o; 42 | 43 | if (!trueSet.equals(that.trueSet)) return false; 44 | 45 | return true; 46 | } 47 | 48 | @Override 49 | public int hashCode() { 50 | int result = super.hashCode(); 51 | result = 31 * result + trueSet.hashCode(); 52 | return result; 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/nodes/DTLeafBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.nodes; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.nodes.Branch; 5 | import quickml.supervised.tree.nodes.LeafBuilder; 6 | 7 | /** 8 | * Created by alexanderhawk on 4/24/15. 9 | */ 10 | public class DTLeafBuilder implements LeafBuilder { 11 | private static final long serialVersionUID = 0L; 12 | 13 | public DTLeaf buildLeaf(Branch parent, ClassificationCounter valueCounter){ 14 | return new DTLeaf(parent, valueCounter, parent==null || parent.isEmpty() ? 0 : parent.getDepth()+1); 15 | } 16 | 17 | @Override 18 | public LeafBuilder copy() { 19 | return new DTLeafBuilder(); 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/nodes/DTNumBranch.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.nodes; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.nodes.Branch; 5 | import quickml.supervised.tree.nodes.NumBranch; 6 | 7 | /** 8 | * Created by alexanderhawk on 6/11/15. 
9 | */ 10 | public class DTNumBranch extends NumBranch{ 11 | 12 | public DTNumBranch(Branch parent, String attribute, double probabilityOfTrueChild, double score, ClassificationCounter termStatistics, double threshold) { 13 | super(parent, attribute, probabilityOfTrueChild, score, termStatistics, threshold); 14 | 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/reducers/DTreeReducer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.reducers; 2 | 3 | import quickml.data.instances.ClassifierInstance; 4 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 5 | import quickml.supervised.tree.reducers.Reducer; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * Created by alexanderhawk on 6/21/15. 11 | */ 12 | public abstract class DTreeReducer extends Reducer { 13 | public DTreeReducer(List trainingData) { 14 | super(trainingData); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/reducers/reducerFactories/DTBinaryCatBranchReducerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.reducers.reducerFactories; 2 | 3 | import quickml.data.instances.ClassifierInstance; 4 | import quickml.supervised.tree.decisionTree.reducers.DTBinaryCatBranchReducer; 5 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 6 | import quickml.supervised.tree.reducers.Reducer; 7 | import quickml.supervised.tree.reducers.ReducerFactory; 8 | 9 | import java.io.Serializable; 10 | import java.util.List; 11 | import java.util.Map; 12 | 13 | /** 14 | * Created by alexanderhawk on 7/9/15. 
15 | */ 16 | public class DTBinaryCatBranchReducerFactory implements ReducerFactory{ 17 | private final Serializable minorityClassification; 18 | 19 | public DTBinaryCatBranchReducerFactory(Serializable minorityClassification) { 20 | this.minorityClassification = minorityClassification; 21 | } 22 | 23 | @Override 24 | public Reducer getReducer(List trainingData) { 25 | return new DTBinaryCatBranchReducer<>(trainingData, minorityClassification); 26 | } 27 | 28 | @Override 29 | public void updateBuilderConfig(Map cfg) { 30 | 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/reducers/reducerFactories/DTCatBranchReducerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.reducers.reducerFactories; 2 | 3 | import quickml.data.instances.ClassifierInstance; 4 | import quickml.supervised.tree.decisionTree.reducers.DTCatBranchReducer; 5 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 6 | import quickml.supervised.tree.reducers.Reducer; 7 | import quickml.supervised.tree.reducers.ReducerFactory; 8 | 9 | import java.io.Serializable; 10 | import java.util.List; 11 | import java.util.Map; 12 | 13 | /** 14 | * Created by alexanderhawk on 7/9/15. 
15 | */ 16 | public class DTCatBranchReducerFactory implements ReducerFactory{ 17 | 18 | @Override 19 | public Reducer getReducer(List trainingData) { 20 | return new DTCatBranchReducer<>(trainingData); 21 | } 22 | 23 | @Override 24 | public void updateBuilderConfig(Map cfg) { 25 | 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/reducers/reducerFactories/DTNumBranchReducerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.reducers.reducerFactories; 2 | 3 | import quickml.data.instances.ClassifierInstance; 4 | import quickml.supervised.tree.decisionTree.reducers.DTNumBranchReducer; 5 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 6 | import quickml.supervised.tree.reducers.Reducer; 7 | import quickml.supervised.tree.reducers.ReducerFactory; 8 | 9 | import java.io.Serializable; 10 | import java.util.List; 11 | import java.util.Map; 12 | 13 | import static quickml.supervised.tree.constants.ForestOptions.NUM_NUMERIC_BINS; 14 | import static quickml.supervised.tree.constants.ForestOptions.NUM_SAMPLES_PER_NUMERIC_BIN; 15 | 16 | /** 17 | * Created by alexanderhawk on 7/9/15. 
18 | */ 19 | public class DTNumBranchReducerFactory implements ReducerFactory{ 20 | int numSamplesPerBin; 21 | int numNumericBins; 22 | 23 | 24 | @Override 25 | public Reducer getReducer(List trainingData) { 26 | return new DTNumBranchReducer<>(trainingData, numSamplesPerBin, numNumericBins); 27 | } 28 | 29 | @Override 30 | public void updateBuilderConfig(Map cfg) { 31 | if (cfg.containsKey(NUM_SAMPLES_PER_NUMERIC_BIN.name())) { 32 | numSamplesPerBin = (int) cfg.get(NUM_SAMPLES_PER_NUMERIC_BIN.name()); 33 | } 34 | if (cfg.containsKey(NUM_NUMERIC_BINS.name())) { 35 | numNumericBins = (int) cfg.get(NUM_NUMERIC_BINS.name()); 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/reducers/reducerFactories/DTOldCatBranchReducerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.reducers.reducerFactories; 2 | 3 | import quickml.data.instances.ClassifierInstance; 4 | import quickml.supervised.tree.decisionTree.reducers.DTOldCatBranchReducer; 5 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 6 | import quickml.supervised.tree.reducers.Reducer; 7 | import quickml.supervised.tree.reducers.ReducerFactory; 8 | 9 | import java.io.Serializable; 10 | import java.util.List; 11 | import java.util.Map; 12 | 13 | /** 14 | * Created by alexanderhawk on 7/9/15. 
15 | */ 16 | public class DTOldCatBranchReducerFactory implements ReducerFactory{ 17 | 18 | @Override 19 | public Reducer getReducer(List trainingData) { 20 | return new DTOldCatBranchReducer<>(trainingData); 21 | } 22 | 23 | @Override 24 | public void updateBuilderConfig(Map cfg) { 25 | 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/scorers/GRPenalizedGiniImpurityScorer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.scorers; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.reducers.AttributeStats; 5 | import quickml.supervised.tree.scorers.GRScorer; 6 | 7 | import java.io.Serializable; 8 | import java.util.Map; 9 | 10 | /** 11 | * Created by chrisreeves on 6/24/14. 12 | */ 13 | public class GRPenalizedGiniImpurityScorer extends GRScorer { 14 | 15 | 16 | public GRPenalizedGiniImpurityScorer(double degreeOfGainRatioPenalty, AttributeStats attributeStats) { 17 | super(degreeOfGainRatioPenalty, attributeStats); 18 | } 19 | 20 | @Override 21 | public double scoreSplit(ClassificationCounter a, ClassificationCounter b) { 22 | ClassificationCounter parent = ClassificationCounter.merge(a, b); 23 | double aGiniIndex = getGiniIndex(a) * a.getTotal() / parent.getTotal(); 24 | double bGiniIndex = getGiniIndex(b) * b.getTotal() / parent.getTotal(); 25 | double score = unSplitScore - aGiniIndex - bGiniIndex; 26 | return correctForGainRatio(score); 27 | } 28 | 29 | @Override 30 | public double getUnSplitScore(ClassificationCounter a) { 31 | return getGiniIndex(a); 32 | 33 | } 34 | 35 | private double getGiniIndex(ClassificationCounter cc) { 36 | double sum = 0.0d; 37 | for (Map.Entry e : cc.getCounts().entrySet()) { 38 | double error = (cc.getTotal() > 0) ? 
e.getValue() / cc.getTotal() : 0; 39 | sum += error * error; 40 | } 41 | return 1.0d - sum; 42 | } 43 | 44 | @Override 45 | public String toString() { 46 | return "GiniImpurity"; 47 | } 48 | 49 | } 50 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/scorers/GRPenalizedGiniImpurityScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.scorers; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.reducers.AttributeStats; 5 | import quickml.supervised.tree.scorers.*; 6 | 7 | /** 8 | * Created by alexanderhawk on 7/9/15. 9 | */ 10 | public class GRPenalizedGiniImpurityScorerFactory extends GRScorerFactory { 11 | 12 | public GRPenalizedGiniImpurityScorerFactory() { 13 | } 14 | 15 | public GRPenalizedGiniImpurityScorerFactory(double degreeOfGainRatioPenalty) { 16 | super(degreeOfGainRatioPenalty); 17 | } 18 | 19 | @Override 20 | public GRScorer getScorer(AttributeStats attributeStats) { 21 | return new GRPenalizedGiniImpurityScorer(degreeOfGainRatioPenalty, attributeStats); 22 | } 23 | 24 | @Override 25 | public ScorerFactory copy() { 26 | return new GRPenalizedGiniImpurityScorerFactory(degreeOfGainRatioPenalty); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/scorers/PenalizedGiniImpurityScorer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.scorers; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.reducers.AttributeStats; 5 | import quickml.supervised.tree.scorers.GRImbalancedScorer; 6 | 7 | import java.io.Serializable; 8 | import java.util.Map; 9 | 10 | /** 
11 | * Created by chrisreeves on 6/24/14. 12 | */ 13 | public class PenalizedGiniImpurityScorer extends GRImbalancedScorer { 14 | 15 | public PenalizedGiniImpurityScorer(double degreeOfGainRatioPenalty, double imbalancePenaltyPower, AttributeStats attributeStats) { 16 | super(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats); 17 | } 18 | 19 | @Override 20 | public double scoreSplit(ClassificationCounter a, ClassificationCounter b) { 21 | ClassificationCounter parent = ClassificationCounter.merge(a, b); 22 | double aGiniIndex = getGiniIndex(a) * a.getTotal() / parent.getTotal(); 23 | double bGiniIndex = getGiniIndex(b) * b.getTotal() / parent.getTotal(); 24 | double score = unSplitScore - aGiniIndex - bGiniIndex; 25 | return correctForGainRatio(score)*getPenaltyForImabalance(a, b); 26 | } 27 | 28 | @Override 29 | public double getUnSplitScore(ClassificationCounter a) { 30 | return getGiniIndex(a); 31 | 32 | } 33 | 34 | private double getGiniIndex(ClassificationCounter cc) { 35 | double sum = 0.0d; 36 | for (Map.Entry e : cc.getCounts().entrySet()) { 37 | double error = (cc.getTotal() > 0) ? e.getValue() / cc.getTotal() : 0; 38 | sum += error * error; 39 | } 40 | return 1.0d - sum; 41 | } 42 | 43 | @Override 44 | public String toString() { 45 | return "GiniImpurity"; 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/scorers/PenalizedGiniImpurityScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.scorers; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.reducers.AttributeStats; 5 | import quickml.supervised.tree.scorers.*; 6 | 7 | /** 8 | * Created by alexanderhawk on 7/9/15. 
9 | */ 10 | public class PenalizedGiniImpurityScorerFactory extends GRImbalancedScorerFactory { 11 | 12 | public PenalizedGiniImpurityScorerFactory() { 13 | } 14 | 15 | public PenalizedGiniImpurityScorerFactory(double degreeOfGainRatioPenalty, double imbalancePenaltyPower) { 16 | super(degreeOfGainRatioPenalty, imbalancePenaltyPower); 17 | } 18 | 19 | /** 20 | * Creates the Gini impurity scorer that is configured with both the gain-ratio penalty 21 | * and the imbalance penalty power held by this factory. 22 | */ 23 | @Override 24 | public GRScorer getScorer(AttributeStats attributeStats) { 25 | /* Pass imbalancePenaltyPower through, matching the other Penalized*ScorerFactory classes; previously it was dropped. */ 26 | return new PenalizedGiniImpurityScorer(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats); 27 | } 28 | 29 | @Override 30 | public ScorerFactory copy() { 31 | return new PenalizedGiniImpurityScorerFactory(degreeOfGainRatioPenalty, imbalancePenaltyPower); 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/scorers/PenalizedInformationGainScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.scorers; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.reducers.AttributeStats; 5 | import quickml.supervised.tree.scorers.GRImbalancedScorer; 6 | import quickml.supervised.tree.scorers.GRImbalancedScorerFactory; 7 | import quickml.supervised.tree.scorers.ScorerFactory; 8 | 9 | /** 10 | * Created by alexanderhawk on 7/9/15. 
11 | */ 12 | public class PenalizedInformationGainScorerFactory extends GRImbalancedScorerFactory { 13 | 14 | public PenalizedInformationGainScorerFactory() { 15 | } 16 | 17 | public PenalizedInformationGainScorerFactory(double degreeOfGainRatioPenalty, double imbalancePenaltyPower) { 18 | super(degreeOfGainRatioPenalty, imbalancePenaltyPower); 19 | } 20 | 21 | @Override 22 | public GRImbalancedScorer getScorer(AttributeStats attributeStats) { 23 | return new PenalizedInformationGainScorer(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats); 24 | } 25 | 26 | @Override 27 | public ScorerFactory copy() { 28 | return new PenalizedInformationGainScorerFactory(degreeOfGainRatioPenalty, imbalancePenaltyPower); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/scorers/PenalizedMSEScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.scorers; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.reducers.AttributeStats; 5 | import quickml.supervised.tree.scorers.GRImbalancedScorer; 6 | import quickml.supervised.tree.scorers.GRImbalancedScorerFactory; 7 | import quickml.supervised.tree.scorers.ScorerFactory; 8 | 9 | /** 10 | * Created by alexanderhawk on 7/9/15. 
11 | */ 12 | public class PenalizedMSEScorerFactory extends GRImbalancedScorerFactory { 13 | 14 | public PenalizedMSEScorerFactory() { 15 | } 16 | 17 | public PenalizedMSEScorerFactory(double degreeOfGainRatioPenalty, double imbalancePenaltyPower) { 18 | super(degreeOfGainRatioPenalty, imbalancePenaltyPower); 19 | } 20 | 21 | @Override 22 | public GRImbalancedScorer getScorer(AttributeStats attributeStats) { 23 | return new PenalizedMSEScorer(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats); 24 | } 25 | 26 | @Override 27 | public ScorerFactory copy() { 28 | return new PenalizedMSEScorerFactory(degreeOfGainRatioPenalty, imbalancePenaltyPower); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/scorers/PenalizedSplitDiffScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.scorers; 2 | 3 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 4 | import quickml.supervised.tree.reducers.AttributeStats; 5 | import quickml.supervised.tree.scorers.GRImbalancedScorer; 6 | import quickml.supervised.tree.scorers.GRImbalancedScorerFactory; 7 | import quickml.supervised.tree.scorers.ScorerFactory; 8 | 9 | /** 10 | * Created by alexanderhawk on 7/9/15. 
11 | */ 12 | public class PenalizedSplitDiffScorerFactory extends GRImbalancedScorerFactory { 13 | 14 | public PenalizedSplitDiffScorerFactory() { 15 | } 16 | 17 | public PenalizedSplitDiffScorerFactory(double degreeOfGainRatioPenalty, double imbalancePenaltyPower) { 18 | super(degreeOfGainRatioPenalty, imbalancePenaltyPower); 19 | } 20 | 21 | @Override 22 | public GRImbalancedScorer getScorer(AttributeStats attributeStats) { 23 | return new PenalizedSplitDiffScorer(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats); 24 | } 25 | 26 | @Override 27 | public ScorerFactory copy() { 28 | return new PenalizedSplitDiffScorerFactory(degreeOfGainRatioPenalty, imbalancePenaltyPower); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/treeBuildContexts/DTreeContext.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.treeBuildContexts; 2 | 3 | import quickml.data.instances.ClassifierInstance; 4 | import quickml.supervised.tree.branchFinders.BranchFinderAndReducerFactory; 5 | import quickml.supervised.tree.branchingConditions.BranchingConditions; 6 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 7 | import quickml.supervised.tree.nodes.LeafBuilder; 8 | import quickml.supervised.tree.scorers.ScorerFactory; 9 | import quickml.supervised.tree.summaryStatistics.ValueCounterProducer; 10 | import quickml.supervised.tree.treeBuildContexts.TreeContext; 11 | 12 | import java.io.Serializable; 13 | import java.util.List; 14 | import java.util.Set; 15 | 16 | /** 17 | * Created by alexanderhawk on 6/21/15. 
18 | */ 19 | public class DTreeContext extends TreeContext { 20 | Set classifications; 21 | 22 | public DTreeContext(Set classifications, 23 | BranchingConditions branchingConditions, 24 | ScorerFactory scorerFactory, 25 | List> branchFindersAndReducers, 26 | LeafBuilder leafBuilder, 27 | ValueCounterProducer valueCounterProducer 28 | ) { 29 | super(branchingConditions, scorerFactory, branchFindersAndReducers, leafBuilder, valueCounterProducer); 30 | this.classifications = classifications; 31 | } 32 | 33 | public Set getClassifications() { 34 | return classifications; 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/decisionTree/valueCounters/ClassificationCounterProducer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.valueCounters; 2 | 3 | import quickml.data.instances.ClassifierInstance; 4 | import quickml.supervised.tree.summaryStatistics.ValueCounterProducer; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * Created by alexanderhawk on 4/22/15. 10 | */ 11 | public class ClassificationCounterProducer implements ValueCounterProducer { 12 | @Override 13 | public ClassificationCounter getValueCounter(List instances) { 14 | return ClassificationCounter.countAll(instances); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/nodes/Leaf.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.nodes; 2 | 3 | 4 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 5 | 6 | /** 7 | * Created by alexanderhawk on 4/24/15. 
8 | */ 9 | 10 | public interface Leaf> extends Node { 11 | int getDepth(); 12 | VC getValueCounter(); 13 | } 14 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/nodes/LeafBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.nodes; 2 | 3 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 4 | 5 | import java.io.Serializable; 6 | 7 | 8 | /** 9 | * Created by alexanderhawk on 3/22/15. 10 | */ 11 | public interface LeafBuilder> extends Serializable{ 12 | Leaf buildLeaf(Branch parent, VC valueCounter); 13 | LeafBuilder copy(); 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/nodes/LeafDepthStats.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.nodes; 2 | 3 | import com.google.common.collect.Maps; 4 | 5 | import java.util.TreeMap; 6 | 7 | /** 8 | * Created by alexanderhawk on 4/28/15. 9 | */ 10 | public class LeafDepthStats { 11 | public int ttlDepth = 0; 12 | public int ttlSamples = 0; 13 | public TreeMap depthDistribution = Maps.newTreeMap(); 14 | } 15 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/nodes/Node.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.nodes; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 5 | 6 | /** 7 | * Created by alexanderhawk on 6/18/15. 
8 | */ 9 | public interface Node> { 10 | 11 | 12 | @Override 13 | boolean equals(final Object obj); 14 | 15 | @Override 16 | int hashCode(); 17 | //last 2 are optional 18 | void calcLeafDepthStats(LeafDepthStats stats); 19 | 20 | /** 21 | * Return the number of nodes in this decision oldTree. 22 | * 23 | * @return 24 | */ 25 | int getSize(); 26 | 27 | Leaf getLeaf(AttributesMap attributes); 28 | Node getParent(); 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/nodes/WeightAndMeanTracker.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.nodes; 2 | 3 | 4 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 5 | 6 | /** 7 | * Created by alexanderhawk on 4/9/15. 8 | */ 9 | public abstract class WeightAndMeanTracker extends ValueCounter { 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/reducers/AttributeStatisticsProducer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.reducers; 2 | 3 | 4 | import com.google.common.base.Optional; 5 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 6 | 7 | /** 8 | * Created by alexanderhawk on 4/22/15. 9 | */ 10 | public interface AttributeStatisticsProducer> { 11 | Optional> getAttributeStats(String attribute); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/reducers/AttributeStats.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.reducers; 2 | 3 | 4 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 5 | 6 | import java.util.List; 7 | 8 | /** 9 | * Created by alexanderhawk on 4/16/15. 
10 | */ 11 | public class AttributeStats> { 12 | List attributeValueStatsList; 13 | VC aggregateStats; 14 | String attribute; 15 | 16 | public AttributeStats(List termStats, VC aggregateStats, String attribute) { 17 | this.attributeValueStatsList = termStats; 18 | this.aggregateStats = aggregateStats; 19 | this.attribute = attribute; 20 | } 21 | 22 | public List getStatsOnEachValue() { 23 | return attributeValueStatsList; 24 | } 25 | public VC getAggregateStats() { 26 | return aggregateStats; 27 | } 28 | 29 | public String getAttribute() { 30 | return attribute; 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/reducers/Reducer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.reducers; 2 | 3 | import com.google.common.base.Optional; 4 | import quickml.data.instances.InstanceWithAttributesMap; 5 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * Created by alexanderhawk on 4/16/15. 
11 | */ 12 | public abstract class Reducer, VC extends ValueCounter> implements AttributeStatisticsProducer { 13 | private final List trainingData; 14 | 15 | public Reducer(List trainingData) { 16 | this.trainingData = trainingData; 17 | } 18 | 19 | public List getTrainingData() { 20 | return trainingData; 21 | } 22 | 23 | public abstract Optional> getAttributeStats(String attribute); 24 | 25 | 26 | } 27 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/reducers/ReducerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.reducers; 2 | 3 | import quickml.data.instances.InstanceWithAttributesMap; 4 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 5 | 6 | import java.io.Serializable; 7 | import java.util.List; 8 | import java.util.Map; 9 | 10 | /** 11 | * Created by alexanderhawk on 7/9/15. 12 | */ 13 | public interface ReducerFactory, VC extends ValueCounter> { 14 | 15 | Reducer getReducer(List trainingData); 16 | 17 | void updateBuilderConfig(Map cfg); 18 | 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/RegressionTreeBuilderHelper.java: -------------------------------------------------------------------------------- 1 | 2 | package quickml.supervised.tree.regressionTree; 3 | 4 | import org.javatuples.Pair; 5 | import quickml.data.instances.RegressionInstance; 6 | import quickml.supervised.tree.TreeBuilderHelper; 7 | import quickml.supervised.tree.decisionTree.treeBuildContexts.DTreeContextBuilder; 8 | import quickml.supervised.tree.decisionTree.treeBuildContexts.DTreeContext; 9 | import quickml.supervised.tree.nodes.Node; 10 | import quickml.supervised.tree.regressionTree.treeBuildContexts.RTreeContext; 11 | import quickml.supervised.tree.regressionTree.treeBuildContexts.RTreeContextBuilder; 12 | 
import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 13 | 14 | import java.io.Serializable; 15 | import java.util.List; 16 | import java.util.Set; 17 | 18 | /** 19 | * Created by alexanderhawk on 4/20/15. 20 | */ 21 | public class RegressionTreeBuilderHelper extends TreeBuilderHelper { 22 | 23 | RTreeContextBuilder treeBuildContext; 24 | public RegressionTreeBuilderHelper(RTreeContextBuilder treeBuildContext) { 25 | super(treeBuildContext); 26 | this.treeBuildContext = treeBuildContext; 27 | } 28 | 29 | public Node computeNodes(List trainingData) { 30 | RTreeContext itbc = treeBuildContext.buildContext(trainingData); 31 | Node root = createNode(null, trainingData, itbc); 32 | return root; 33 | } 34 | 35 | } 36 | 37 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/RegressionTreeVisualizer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree; 2 | 3 | import quickml.supervised.tree.decisionTree.DecisionTree; 4 | import quickml.supervised.tree.nodes.Branch; 5 | import quickml.supervised.tree.nodes.Leaf; 6 | import quickml.supervised.tree.nodes.Node; 7 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 8 | 9 | import java.io.PrintStream; 10 | 11 | /** 12 | * Created by ian on 7/20/15. 
13 | */ 14 | public class RegressionTreeVisualizer { 15 | 16 | public static final int INDENT_AMOUNT = 3; 17 | 18 | public void visualize(RegressionTree tree, PrintStream out) { 19 | visualize(tree.root, out, 0); 20 | } 21 | 22 | private void visualize(final Node node, final PrintStream out, final int depth) { 23 | StringBuilder indentBuilder = new StringBuilder(); 24 | for (int x = 0; x < depth; x++) { 25 | indentBuilder.append(' '); 26 | } 27 | String indent = indentBuilder.toString(); 28 | 29 | if (node instanceof Branch) { 30 | Branch branch = (Branch) node; 31 | out.println(indent + branch.toString() + " TRUE:"); 32 | visualize(branch.getTrueChild(), out, depth + INDENT_AMOUNT); 33 | out.println(indent + branch.toString() + " FALSE:"); 34 | visualize(branch.getFalseChild(), out, depth + INDENT_AMOUNT); 35 | } else if (node instanceof Leaf) { 36 | out.println(indent + "LEAF: " + node.toString()); 37 | } 38 | } 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/attributeValueIgnoringStrategies/RegTreeAttributeValueIgnoringStrategy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.attributeValueIgnoringStrategies; 2 | 3 | import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy; 4 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 5 | 6 | /** 7 | * Created by alexanderhawk on 3/18/15. 
8 | */ 9 | public class RegTreeAttributeValueIgnoringStrategy implements AttributeValueIgnoringStrategy { 10 | private final int minOccurancesOfAttributeValue; 11 | 12 | public RegTreeAttributeValueIgnoringStrategy(final int minOccurancesOfAttributeValue) { 13 | this.minOccurancesOfAttributeValue = minOccurancesOfAttributeValue; 14 | } 15 | 16 | public boolean shouldWeIgnoreThisValue(final MeanValueCounter termStatistics) { 17 | if (termStatistics.getTotal() < minOccurancesOfAttributeValue) { 18 | return true; 19 | } 20 | return false; 21 | } 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/attributeValueIgnoringStrategies/RegTreeAttributeValueIgnoringStrategyBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.attributeValueIgnoringStrategies; 2 | 3 | 4 | import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategy; 5 | import quickml.supervised.tree.attributeValueIgnoringStrategies.AttributeValueIgnoringStrategyBuilder; 6 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 7 | 8 | import javax.annotation.Nullable; 9 | 10 | public class RegTreeAttributeValueIgnoringStrategyBuilder implements AttributeValueIgnoringStrategyBuilder { 11 | 12 | private static final long serialVersionUID = 0L; 13 | private int minOccurancesOfAttributeValue; 14 | 15 | public RegTreeAttributeValueIgnoringStrategyBuilder(int minOccurancesOfAttributeValue) { 16 | this.minOccurancesOfAttributeValue = minOccurancesOfAttributeValue; 17 | } 18 | 19 | @Override 20 | public AttributeValueIgnoringStrategy createAttributeValueIgnoringStrategy(MeanValueCounter mv) { 21 | return new RegTreeAttributeValueIgnoringStrategy(minOccurancesOfAttributeValue); 22 | } 23 | 24 | 25 | 26 | public RegTreeAttributeValueIgnoringStrategyBuilder 
setMinOccurancesOfAttributeValue(int minOccurancesOfAttributeValue) { 27 | this.minOccurancesOfAttributeValue = minOccurancesOfAttributeValue; 28 | return this; 29 | } 30 | 31 | public RegTreeAttributeValueIgnoringStrategyBuilder copy() { 32 | return new RegTreeAttributeValueIgnoringStrategyBuilder(minOccurancesOfAttributeValue).setMinOccurancesOfAttributeValue(minOccurancesOfAttributeValue); 33 | } 34 | 35 | 36 | 37 | } -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/branchFinders/branchFinderBuilders/RTBranchFinderBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.branchFinders.branchFinderBuilders; 2 | 3 | import quickml.data.instances.RegressionInstance; 4 | import quickml.supervised.tree.branchFinders.branchFinderBuilders.BranchFinderBuilder; 5 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 6 | 7 | /** 8 | * Created by alexanderhawk on 6/21/15. 9 | */ 10 | public abstract class RTBranchFinderBuilder extends BranchFinderBuilder { 11 | 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/branchingConditions/RTBranchingConditions.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.branchingConditions; 2 | 3 | import quickml.supervised.tree.branchingConditions.StandardBranchingConditions; 4 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 5 | 6 | /** 7 | * Created by alexanderhawk on 6/21/15. 
8 | */ 9 | public class RTBranchingConditions extends StandardBranchingConditions{ 10 | public RTBranchingConditions(double minScore, int maxDepth, int minLeafInstances, double minSplitFraction) { 11 | super(minScore, maxDepth, minLeafInstances, minSplitFraction); 12 | } 13 | 14 | public RTBranchingConditions() { 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/nodes/RTCatBranch.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.nodes; 2 | 3 | import com.google.common.collect.Sets; 4 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 5 | import quickml.supervised.tree.nodes.Branch; 6 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 7 | 8 | import java.io.Serializable; 9 | import java.util.Map; 10 | import java.util.Set; 11 | 12 | /** 13 | * Created by alexanderhawk on 4/27/15. 
14 | */ 15 | public class RTCatBranch extends Branch { 16 | private static final long serialVersionUID = -1723969623146234761L; 17 | public final Set trueSet; 18 | 19 | public RTCatBranch(Branch parent, final String attribute, final Set trueSet, double probabilityOfTrueChild, double score, MeanValueCounter aggregateStats) { 20 | super(parent, attribute, probabilityOfTrueChild, score, aggregateStats); 21 | this.trueSet = Sets.newHashSet(trueSet); 22 | } 23 | 24 | @Override 25 | public boolean decide(final Map attributes) { 26 | Serializable attributeVal = attributes.get(attribute); 27 | //missing values always go the way of the outset...which strangely seems to be most accurate 28 | return trueSet.contains(attributeVal); 29 | } 30 | 31 | @Override 32 | public String toString() { 33 | return attribute + " in " + trueSet; 34 | } 35 | 36 | @Override 37 | public boolean equals(final Object o) { 38 | if (this == o) return true; 39 | if (o == null || getClass() != o.getClass()) return false; 40 | if (!super.equals(o)) return false; 41 | 42 | final RTCatBranch that = (RTCatBranch) o; 43 | 44 | if (!trueSet.equals(that.trueSet)) return false; 45 | 46 | return true; 47 | } 48 | 49 | @Override 50 | public int hashCode() { 51 | int result = super.hashCode(); 52 | result = 31 * result + trueSet.hashCode(); 53 | return result; 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/nodes/RTLeafBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.nodes; 2 | 3 | import quickml.supervised.tree.decisionTree.nodes.DTLeaf; 4 | import quickml.supervised.tree.nodes.Branch; 5 | import quickml.supervised.tree.nodes.LeafBuilder; 6 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 7 | 8 | 9 | /** 10 | * Created by alexanderhawk on 4/24/15. 
11 | */ 12 | public class RTLeafBuilder implements LeafBuilder { 13 | private static final long serialVersionUID = 0L; 14 | 15 | public RTLeaf buildLeaf(Branch parent, MeanValueCounter valueCounter){ 16 | return new RTLeaf(parent, valueCounter, parent==null || parent.isEmpty() ? 0 : parent.getDepth()+1); 17 | } 18 | 19 | @Override 20 | public LeafBuilder copy() { 21 | return new RTLeafBuilder(); 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/nodes/RTNumBranch.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.nodes; 2 | 3 | 4 | import quickml.supervised.tree.nodes.Branch; 5 | import quickml.supervised.tree.nodes.NumBranch; 6 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 7 | 8 | /** 9 | * Created by alexanderhawk on 6/11/15. 10 | */ 11 | public class RTNumBranch extends NumBranch { 12 | 13 | public RTNumBranch(Branch parent, String attribute, double probabilityOfTrueChild, double score, MeanValueCounter termStatistics, double threshold) { 14 | super(parent, attribute, probabilityOfTrueChild, score, termStatistics, threshold); 15 | 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/reducers/RTreeReducer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.reducers; 2 | 3 | import quickml.data.instances.RegressionInstance; 4 | import quickml.supervised.tree.reducers.Reducer; 5 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * Created by alexanderhawk on 6/21/15. 
11 | */ 12 | public abstract class RTreeReducer extends Reducer { 13 | public RTreeReducer(List trainingData) { 14 | super(trainingData); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/reducers/reducerFactories/RTCatBranchReducerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.reducers.reducerFactories; 2 | 3 | import quickml.data.instances.RegressionInstance; 4 | import quickml.supervised.tree.decisionTree.reducers.DTBinaryCatBranchReducer; 5 | import quickml.supervised.tree.reducers.Reducer; 6 | import quickml.supervised.tree.reducers.ReducerFactory; 7 | import quickml.supervised.tree.regressionTree.reducers.RTCatBranchReducer; 8 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 9 | 10 | import java.io.Serializable; 11 | import java.util.List; 12 | import java.util.Map; 13 | 14 | /** 15 | * Created by alexanderhawk on 7/9/15. 
16 | */ 17 | public class RTCatBranchReducerFactory implements ReducerFactory{ 18 | 19 | public RTCatBranchReducerFactory() { 20 | } 21 | 22 | @Override 23 | public Reducer getReducer(List trainingData) { 24 | return new RTCatBranchReducer<>(trainingData); 25 | } 26 | 27 | @Override 28 | public void updateBuilderConfig(Map cfg) { 29 | 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/reducers/reducerFactories/RTNumBranchReducerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.reducers.reducerFactories; 2 | 3 | import quickml.data.instances.RegressionInstance; 4 | import quickml.supervised.tree.reducers.Reducer; 5 | import quickml.supervised.tree.reducers.ReducerFactory; 6 | import quickml.supervised.tree.regressionTree.reducers.RTNumBranchReducer; 7 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 8 | 9 | import java.io.Serializable; 10 | import java.util.List; 11 | import java.util.Map; 12 | 13 | import static quickml.supervised.tree.constants.ForestOptions.NUM_NUMERIC_BINS; 14 | import static quickml.supervised.tree.constants.ForestOptions.NUM_SAMPLES_PER_NUMERIC_BIN; 15 | 16 | /** 17 | * Created by alexanderhawk on 7/9/15. 
18 | */ 19 | public class RTNumBranchReducerFactory implements ReducerFactory{ 20 | int numSamplesPerBin; 21 | int numNumericBins; 22 | 23 | 24 | @Override 25 | public Reducer getReducer(List trainingData) { 26 | return new RTNumBranchReducer<>(trainingData, numSamplesPerBin, numNumericBins); 27 | } 28 | 29 | @Override 30 | public void updateBuilderConfig(Map cfg) { 31 | if (cfg.containsKey(NUM_SAMPLES_PER_NUMERIC_BIN.name())) { 32 | numSamplesPerBin = (int) cfg.get(NUM_SAMPLES_PER_NUMERIC_BIN.name()); 33 | } 34 | if (cfg.containsKey(NUM_NUMERIC_BINS.name())) { 35 | numNumericBins = (int) cfg.get(NUM_NUMERIC_BINS.name()); 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/scorers/PenalizedMSEScorer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.scorers; 2 | 3 | 4 | //TODO: fix oldScorers 5 | import quickml.supervised.tree.reducers.AttributeStats; 6 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 7 | import quickml.supervised.tree.scorers.GRImbalancedScorer; 8 | 9 | /** 10 | * A Scorer intended to estimate the impact on the Mean of the Squared Error (MSE) 11 | * of a branch existing versus not existing. The value returned is the MSE 12 | * without the branch minus the MSE with the branch (so higher is better, as 13 | * is required by the scoreSplit() interface. 
14 | */ 15 | public class PenalizedMSEScorer extends GRImbalancedScorer { 16 | 17 | @Override 18 | protected double getUnSplitScore(MeanValueCounter a) { 19 | return getTotalError(a)/a.getTotal(); 20 | } 21 | 22 | public PenalizedMSEScorer(double degreeOfGainRatioPenalty, double imbalancePenaltyPower, AttributeStats attributeStats) { 23 | super(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats); 24 | } 25 | 26 | @Override 27 | public double scoreSplit(final MeanValueCounter a, final MeanValueCounter b) { 28 | double splitMSE = (getTotalError(a) + getTotalError(b)) / (a.getTotal() + b.getTotal()); 29 | return correctForGainRatio(unSplitScore - splitMSE) * getPenaltyForImabalance(a, b); 30 | } 31 | 32 | private double getTotalError(MeanValueCounter mvc) { 33 | //below: total MSE for using the mvc as a leaf is Sum( (yi- mean)^2 ) = accumulatedSquares - mean^2 *numSamples 34 | double totalError = (mvc.getAccumulatedSquares() - mvc.getAccumulatedValue()*mvc.getAccumulatedValue()/mvc.getTotal()); 35 | return totalError; 36 | } 37 | 38 | @Override 39 | public String toString() { 40 | final StringBuilder sb = new StringBuilder("MSEScorer{"); 41 | sb.append('}'); 42 | return sb.toString(); 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/scorers/RTPenalizedMSEScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.scorers; 2 | 3 | import quickml.supervised.tree.reducers.AttributeStats; 4 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 5 | import quickml.supervised.tree.scorers.GRImbalancedScorer; 6 | import quickml.supervised.tree.scorers.GRImbalancedScorerFactory; 7 | import quickml.supervised.tree.scorers.ScorerFactory; 8 | 9 | /** 10 | * Created by alexanderhawk on 7/9/15. 
11 | */ 12 | public class RTPenalizedMSEScorerFactory extends GRImbalancedScorerFactory { 13 | 14 | public RTPenalizedMSEScorerFactory() { 15 | } 16 | 17 | public RTPenalizedMSEScorerFactory(double degreeOfGainRatioPenalty, double imbalancePenaltyPower) { 18 | super(degreeOfGainRatioPenalty, imbalancePenaltyPower); 19 | } 20 | 21 | @Override 22 | public GRImbalancedScorer getScorer(AttributeStats attributeStats) { 23 | return new PenalizedMSEScorer(degreeOfGainRatioPenalty, imbalancePenaltyPower, attributeStats); 24 | } 25 | 26 | @Override 27 | public ScorerFactory copy() { 28 | return new RTPenalizedMSEScorerFactory(degreeOfGainRatioPenalty, imbalancePenaltyPower); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/treeBuildContexts/RTreeContext.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.treeBuildContexts; 2 | 3 | import quickml.data.instances.RegressionInstance; 4 | import quickml.supervised.tree.branchFinders.BranchFinderAndReducerFactory; 5 | import quickml.supervised.tree.branchingConditions.BranchingConditions; 6 | import quickml.supervised.tree.nodes.LeafBuilder; 7 | import quickml.supervised.tree.regressionTree.valueCounters.MeanValueCounter; 8 | import quickml.supervised.tree.scorers.ScorerFactory; 9 | import quickml.supervised.tree.summaryStatistics.ValueCounterProducer; 10 | import quickml.supervised.tree.treeBuildContexts.TreeContext; 11 | 12 | import java.io.Serializable; 13 | import java.util.List; 14 | import java.util.Set; 15 | 16 | /** 17 | * Created by alexanderhawk on 6/21/15. 
18 | */ 19 | public class RTreeContext extends TreeContext { 20 | 21 | public RTreeContext(BranchingConditions branchingConditions, 22 | ScorerFactory scorerFactory, 23 | List> branchFindersAndReducers, 24 | LeafBuilder leafBuilder, 25 | ValueCounterProducer valueCounterProducer 26 | ) { 27 | super(branchingConditions, scorerFactory, branchFindersAndReducers, leafBuilder, valueCounterProducer); 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/regressionTree/valueCounters/MeanValueCounterProducer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.regressionTree.valueCounters; 2 | 3 | import quickml.data.instances.RegressionInstance; 4 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 5 | import quickml.supervised.tree.summaryStatistics.ValueCounterProducer; 6 | 7 | import java.util.List; 8 | 9 | /** 10 | * Created by alexanderhawk on 4/22/15. 11 | */ 12 | public class MeanValueCounterProducer implements ValueCounterProducer { 13 | @Override 14 | public MeanValueCounter getValueCounter(List instances) { 15 | return MeanValueCounter.accumulateAll(instances); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/scorers/GRImbalancedScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.scorers; 2 | 3 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 4 | 5 | import java.io.Serializable; 6 | import java.util.Map; 7 | 8 | import static quickml.supervised.tree.constants.ForestOptions.IMBALANCE_PENALTY_POWER; 9 | 10 | /** 11 | * Created by alexanderhawk on 7/8/15. 
12 | */ 13 | public abstract class GRImbalancedScorerFactory> extends GRScorerFactory { 14 | protected double imbalancePenaltyPower; 15 | 16 | 17 | public GRImbalancedScorerFactory(){} 18 | 19 | public GRImbalancedScorerFactory(double degreeOfGainRatioPenalty, double imbalancePenaltyPower) { 20 | super(degreeOfGainRatioPenalty); 21 | this.imbalancePenaltyPower = imbalancePenaltyPower; 22 | } 23 | 24 | @Override 25 | public void update(Map cfg) { 26 | super.update(cfg); 27 | if (cfg.containsKey(IMBALANCE_PENALTY_POWER.name())) 28 | imbalancePenaltyPower = (Double) cfg.get(IMBALANCE_PENALTY_POWER.name()); 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/scorers/GRScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.scorers; 2 | 3 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 4 | 5 | import java.io.Serializable; 6 | import java.util.Map; 7 | 8 | import static quickml.supervised.tree.constants.ForestOptions.DEGREE_OF_GAIN_RATIO_PENALTY; 9 | import static quickml.supervised.tree.constants.ForestOptions.IMBALANCE_PENALTY_POWER; 10 | 11 | /** 12 | * Created by alexanderhawk on 7/8/15. 
13 | */ 14 | public abstract class GRScorerFactory> implements ScorerFactory{ 15 | protected double degreeOfGainRatioPenalty = 0.0; 16 | 17 | 18 | public GRScorerFactory(){} 19 | 20 | public GRScorerFactory(double degreeOfGainRatioPenalty) { 21 | this.degreeOfGainRatioPenalty = degreeOfGainRatioPenalty; 22 | } 23 | 24 | @Override 25 | public void update(Map cfg) { 26 | if (cfg.containsKey(DEGREE_OF_GAIN_RATIO_PENALTY.name())) 27 | degreeOfGainRatioPenalty = (Double) cfg.get(DEGREE_OF_GAIN_RATIO_PENALTY.name()); 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/scorers/Scorer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.scorers; 2 | 3 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 4 | 5 | /** 6 | * Created by alexanderhawk on 7/8/15. 7 | */ 8 | public interface Scorer> { 9 | double scoreSplit(VC a, VC b); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/scorers/ScorerFactory.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.scorers; 2 | 3 | import quickml.supervised.tree.reducers.AttributeStats; 4 | import quickml.supervised.tree.summaryStatistics.ValueCounter; 5 | 6 | import java.io.Serializable; 7 | import java.util.Map; 8 | 9 | /** 10 | * Created by alexanderhawk on 7/8/15. 
11 | */ 12 | public interface ScorerFactory> extends Serializable{ 13 | 14 | Scorer getScorer(AttributeStats attributeStats); 15 | 16 | ScorerFactory copy(); 17 | 18 | void update(Map cfg); 19 | 20 | } 21 | 22 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/summaryStatistics/ValueCounter.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.summaryStatistics; 2 | 3 | 4 | import java.io.Serializable; 5 | 6 | /** 7 | * Created by alexanderhawk on 4/23/15. 8 | */ 9 | public abstract class ValueCounter extends ValueStatistics implements ValueStatisticsOperations { 10 | public ValueCounter() { 11 | super(); 12 | } 13 | public ValueCounter(Serializable attrVal) { 14 | super(attrVal); 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/summaryStatistics/ValueCounterProducer.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.summaryStatistics; 2 | 3 | import quickml.data.instances.InstanceWithAttributesMap; 4 | 5 | import java.util.List; 6 | 7 | /** 8 | * Created by alexanderhawk on 4/22/15. 9 | */ 10 | public interface ValueCounterProducer, VC extends ValueCounter> { 11 | public abstract VC getValueCounter(List instances); 12 | } 13 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/summaryStatistics/ValueStatistics.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.summaryStatistics; 2 | 3 | import java.io.Serializable; 4 | 5 | /** 6 | * Created by alexanderhawk on 4/5/15. 
7 | */ 8 | public abstract class ValueStatistics { 9 | 10 | public Serializable attrVal; 11 | 12 | public ValueStatistics(Serializable attrVal) { 13 | this.attrVal = attrVal; 14 | } 15 | 16 | public ValueStatistics(){} 17 | 18 | public Serializable getAttrVal() { 19 | return attrVal; 20 | } 21 | 22 | public abstract double getTotal(); 23 | 24 | public abstract boolean isEmpty(); 25 | 26 | 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/main/java/quickml/supervised/tree/summaryStatistics/ValueStatisticsOperations.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.summaryStatistics; 2 | 3 | 4 | /** 5 | * Created by alexanderhawk on 4/23/15. 6 | */ 7 | public interface ValueStatisticsOperations { 8 | TS add(TS ts); 9 | TS subtract(TS ts); 10 | } 11 | -------------------------------------------------------------------------------- /src/main/java/quickml/unsupervised/clustering/Clusterer.java: -------------------------------------------------------------------------------- 1 | package quickml.unsupervised.clustering; 2 | 3 | /** 4 | * Created by ian on 4/24/15. 5 | */ 6 | public class Clusterer { 7 | } 8 | -------------------------------------------------------------------------------- /src/main/java/quickml/utlities/LinePlotterBuilder.java: -------------------------------------------------------------------------------- 1 | package quickml.utlities; 2 | 3 | /** 4 | * Created by alexanderhawk on 10/2/14. 
5 | */ 6 | 7 | import javax.swing.*; 8 | 9 | public class LinePlotterBuilder extends JFrame { 10 | 11 | int graphSizeInXDimension = 500; 12 | int graphSizeInYDimension = 270; 13 | String xAxisLabel = "X"; 14 | String yAxisLabel = "Y"; 15 | String chartTitle = ""; 16 | 17 | public LinePlotterBuilder chartTitle(String chartTitle) { 18 | this.chartTitle = chartTitle; 19 | return this; 20 | } 21 | 22 | public LinePlotterBuilder xyGraphDimensions(int xDim, int yDim) { 23 | this.graphSizeInXDimension = xDim; 24 | this.graphSizeInYDimension = yDim; 25 | return this; 26 | } 27 | 28 | public LinePlotterBuilder xAxisLabel(String xLabel) { 29 | this.xAxisLabel = xLabel; 30 | return this; 31 | } 32 | 33 | public LinePlotterBuilder yAxisLabel(String xLabel) { 34 | this.yAxisLabel = yAxisLabel; 35 | return this; 36 | } 37 | 38 | public LinePlotter buildLinePlotter(){ 39 | return new LinePlotter(chartTitle, xAxisLabel, yAxisLabel, graphSizeInXDimension, graphSizeInYDimension); 40 | } 41 | 42 | } -------------------------------------------------------------------------------- /src/main/java/quickml/utlities/SerializationUtility.java: -------------------------------------------------------------------------------- 1 | package quickml.utlities; 2 | 3 | import java.io.*; 4 | import java.util.zip.GZIPInputStream; 5 | import java.util.zip.GZIPOutputStream; 6 | 7 | /** 8 | * Created by alexanderhawk on 12/17/14. 
9 | */ 10 |
/**
 * Reads and writes a single Java-serialized object from/to a GZIP-compressed file.
 *
 * NOTE(review): the declaration of the type parameter {@code E} is not visible in this text --
 * presumably it is declared on the class or on each method; confirm against the real file.
 */
public class SerializationUtility {

    /**
     * Loads one serialized object from a GZIP-compressed file.
     *
     * NOTE: this relies on Java native deserialization; only load files from trusted sources.
     *
     * @param modelFile path of the file to read
     * @return the deserialized object, cast (unchecked) to {@code E}
     * @throws RuntimeException wrapping the underlying {@link IOException} or
     *         {@link ClassNotFoundException}
     */
    public E loadObjectFromGZIPFile(final String modelFile) {
        // Each stream is its own resource so the file handle is released even when a
        // wrapping constructor (e.g. GZIP header validation) throws.
        try (InputStream fileIn = new FileInputStream(modelFile);
             InputStream gzipIn = new GZIPInputStream(fileIn);
             ObjectInputStream ois = new ObjectInputStream(gzipIn)) {
            return (E) ois.readObject();
        } catch (IOException | ClassNotFoundException e) {
            throw new RuntimeException("Error reading predictive model", e);
        }
    }

    /**
     * Serializes {@code object} into a GZIP-compressed file, replacing any existing content.
     *
     * @param modelFileName path of the file to write
     * @param object        the object to serialize
     * @throws RuntimeException wrapping the underlying {@link IOException}
     */
    public void writeModelToGZIPFile(final String modelFileName, E object) {
        try (OutputStream fileOut = new FileOutputStream(modelFileName);
             OutputStream gzipOut = new GZIPOutputStream(fileOut);
             ObjectOutputStream oos = new ObjectOutputStream(gzipOut)) {
            oos.writeObject(object);
        } catch (IOException e) {
            // Message previously said "reading", which misreported write failures.
            throw new RuntimeException("Error writing predictive model", e);
        }
    }
}
30 | -------------------------------------------------------------------------------- /src/main/java/quickml/utlities/selectors/CategoricalSelector.java: -------------------------------------------------------------------------------- 1 | package quickml.utlities.selectors; 2 | 3 | /** 4 | * Created by alexanderhawk on 10/4/14. 5 | */ 6 | public interface CategoricalSelector { 7 | boolean isCategorical(String columnName); 8 | String cleanValue(String value); 9 | } 10 | -------------------------------------------------------------------------------- /src/main/java/quickml/utlities/selectors/ExplicitCategoricalSelector.java: -------------------------------------------------------------------------------- 1 | package quickml.utlities.selectors; 2 | 3 | import java.util.Set; 4 | 5 | /** 6 | * Created by alexanderhawk on 10/4/14.
7 | */ 8 | public class ExplicitCategoricalSelector implements CategoricalSelector { 9 | private Set selectionSet; 10 | 11 | public ExplicitCategoricalSelector(Set selectionSet) { 12 | this.selectionSet = selectionSet; 13 | } 14 | 15 | public boolean isCategorical(String columnName) { 16 | return selectionSet.contains(columnName); 17 | } 18 | 19 | public String cleanValue(String value) { 20 | value = value.replaceAll("\\s", ""); 21 | if ((value.startsWith("\"") && value.endsWith("\""))|| (value.startsWith("\'")) && value.endsWith("\'")) { 22 | return value.substring(1, value.length() - 2); 23 | } 24 | ; 25 | return value; 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/main/java/quickml/utlities/selectors/ExplicitNumericSelector.java: -------------------------------------------------------------------------------- 1 | package quickml.utlities.selectors; 2 | 3 | import java.util.Set; 4 | 5 | /** 6 | * Created by alexanderhawk on 10/4/14. 7 | */ 8 | public class ExplicitNumericSelector implements NumericSelector { 9 | protected Set selectionSet; 10 | public ExplicitNumericSelector(Set selectionSet) { 11 | this.selectionSet = selectionSet; 12 | } 13 | public boolean isNumeric(String columnName) { 14 | return selectionSet.contains(columnName); 15 | } 16 | public String cleanValue(String value) { 17 | value = value.replaceAll("\\s", ""); 18 | if ((value.startsWith("\"") && value.endsWith("\""))|| (value.startsWith("\'")) && value.endsWith("\'")) { 19 | return value.substring(1, value.length() - 2); 20 | } 21 | return value; 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /src/main/java/quickml/utlities/selectors/NumericSelector.java: -------------------------------------------------------------------------------- 1 | package quickml.utlities.selectors; 2 | 3 | /** 4 | * Created by alexanderhawk on 10/4/14. 
5 | */ 6 | public interface NumericSelector { 7 | boolean isNumeric(String columnName); 8 | String cleanValue(String value); 9 | 10 | } 11 | -------------------------------------------------------------------------------- /src/main/resources/logback.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | %d{yyyy-MM-dd HH:mm:ss.SSS} %-5level %logger{10} (%file:%L\) - %msg%n 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /src/test/java/quickml/InstanceLoader.java: -------------------------------------------------------------------------------- 1 | package quickml; 2 | 3 | import org.slf4j.Logger; 4 | import org.slf4j.LoggerFactory; 5 | import quickml.data.instances.ClassifierInstance; 6 | import quickml.utlities.CSVToInstanceReader; 7 | import quickml.utlities.CSVToInstanceReaderBuilder; 8 | 9 | import java.io.BufferedReader; 10 | import java.io.InputStreamReader; 11 | import java.util.List; 12 | import java.util.zip.GZIPInputStream; 13 | 14 | /** 15 | * Created by alexanderhawk on 12/30/14. 
16 | */ 17 | public class InstanceLoader { 18 | private static final Logger logger = LoggerFactory.getLogger(InstanceLoader.class); 19 | 20 | public static List getAdvertisingInstances() { 21 | CSVToInstanceReader csvToInstanceReader = new CSVToInstanceReaderBuilder().collumnNameForLabel("outcome").buildCsvReader(); 22 | List advertisingInstances; 23 | try { 24 | final BufferedReader br = new BufferedReader(new InputStreamReader((new GZIPInputStream(InstanceLoader.class.getResourceAsStream("advertisingData.csv.gz"))))); 25 | advertisingInstances = csvToInstanceReader.readCsvFromReader(br); 26 | 27 | } catch (Exception e) { 28 | logger.error("failed to get advertising instances", e); 29 | throw new RuntimeException("failed to get advertising instances"); 30 | 31 | } 32 | return advertisingInstances; 33 | } 34 | 35 | 36 | } 37 | -------------------------------------------------------------------------------- /src/test/java/quickml/InstanceLoaderTest.java: -------------------------------------------------------------------------------- 1 | package quickml; 2 | 3 | import org.junit.Assert; 4 | import org.junit.Test; 5 | import quickml.data.instances.ClassifierInstance; 6 | import quickml.data.instances.InstanceWithAttributesMap; 7 | 8 | import java.util.List; 9 | 10 | import static org.junit.Assert.assertEquals; 11 | 12 | /** 13 | * Created by alexanderhawk on 1/5/15. 
14 | */ 15 | public class InstanceLoaderTest { 16 | 17 | @Test 18 | public void getAdvertisingInstancesTest() { 19 | List instances = InstanceLoader.getAdvertisingInstances(); 20 | assertEquals(instances.size(), 12000); 21 | InstanceWithAttributesMap lastInstance = instances.get(11999); 22 | Assert.assertTrue(lastInstance.getLabel().equals(0.0)); 23 | Assert.assertTrue(lastInstance.getAttributes().get("country").equals("US")); 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/test/java/quickml/MapUtilsTest.java: -------------------------------------------------------------------------------- 1 | package quickml; 2 | 3 | 4 | import com.google.common.collect.Maps; 5 | 6 | import com.google.common.base.Optional; 7 | import org.testng.Assert; 8 | import org.testng.annotations.Test; 9 | import quickml.collections.MapUtils; 10 | 11 | import java.util.Map; 12 | 13 | /** 14 | * Created by ian on 4/13/14. 15 | */ 16 | public class MapUtilsTest { 17 | @Test 18 | public void testGetEntryWithLowestValue() throws Exception { 19 | Map map = Maps.newHashMap(); 20 | map.put("one", 1.0); 21 | map.put("onepointfive", 1.5); 22 | map.put("zeropointfive", 0.5); 23 | final Optional> entryWithLowestValue = MapUtils.getEntryWithLowestValue(map); 24 | Assert.assertTrue(entryWithLowestValue.isPresent(), "Map isn't empty so it should return a result"); 25 | Assert.assertEquals(entryWithLowestValue.get().getKey(), "zeropointfive"); 26 | Assert.assertEquals(entryWithLowestValue.get().getValue(), 0.5); 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/test/java/quickml/TestUtils.java: -------------------------------------------------------------------------------- 1 | package quickml; 2 | 3 | import quickml.data.AttributesMap; 4 | import quickml.data.instances.ClassifierInstance; 5 | 6 | import static org.hamcrest.Matchers.allOf; 7 | import static org.hamcrest.Matchers.anyOf; 8 | import 
static org.hamcrest.Matchers.hasEntry; 9 | 10 | public class TestUtils { 11 | 12 | 13 | public static ClassifierInstance createClassifierInstance(final int day) { 14 | return new ClassifierInstance(createAttributes(day), 1.0D, 0.5); 15 | } 16 | 17 | private static AttributesMap createAttributes(final double day) { 18 | AttributesMap attrs = AttributesMap.newHashMap(); 19 | attrs.put("timeOfArrival-year", 2015d); 20 | attrs.put("timeOfArrival-monthOfYear", 1d); 21 | attrs.put("timeOfArrival-dayOfMonth", day); 22 | attrs.put("timeOfArrival-hourOfDay", 1d); 23 | attrs.put("timeOfArrival-minuteOfHour", 1d); 24 | return attrs; 25 | } 26 | 27 | 28 | } 29 | -------------------------------------------------------------------------------- /src/test/java/quickml/collections/ValueSummingMapTest.java: -------------------------------------------------------------------------------- 1 | package quickml.collections; 2 | 3 | import org.testng.Assert; 4 | import org.testng.annotations.Test; 5 | 6 | /** 7 | * Created by ian on 3/2/14. 
8 | */ 9 | public class ValueSummingMapTest { 10 | @Test 11 | public void simpleTest() { 12 | ValueSummingMap valueSummingMap = new ValueSummingMap(); 13 | Assert.assertEquals(valueSummingMap.getSumOfValues(), 0.0); 14 | valueSummingMap.put("a", 1); 15 | Assert.assertEquals(valueSummingMap.getSumOfValues(), 1.0); 16 | valueSummingMap.put("a", 1); 17 | Assert.assertEquals(valueSummingMap.getSumOfValues(), 1.0); 18 | valueSummingMap.put("b", 1); 19 | Assert.assertEquals(valueSummingMap.getSumOfValues(), 2.0); 20 | valueSummingMap.addToValue("b", 2); 21 | Assert.assertEquals(valueSummingMap.getSumOfValues(), 4.0); 22 | Assert.assertEquals(valueSummingMap.get("b"), 3.0); 23 | valueSummingMap.remove("b"); 24 | Assert.assertEquals(valueSummingMap.getSumOfValues(), 1.0); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/classifier/ClassifiersTest.java: -------------------------------------------------------------------------------- 1 | 2 | package quickml.supervised.classifier; 3 | 4 | import org.javatuples.Pair; 5 | import org.slf4j.Logger; 6 | import org.slf4j.LoggerFactory; 7 | import quickml.data.instances.ClassifierInstance; 8 | import quickml.data.OnespotDateTimeExtractor; 9 | import quickml.supervised.classifier.downsampling.DownsamplingClassifier; 10 | import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction; 11 | 12 | import java.io.Serializable; 13 | import java.util.List; 14 | import java.util.Map; 15 | 16 | import static quickml.InstanceLoader.getAdvertisingInstances; 17 | 18 | 19 | public class ClassifiersTest { 20 | private static final Logger logger = LoggerFactory.getLogger(ClassifiersTest.class); 21 | 22 | 23 | public void getOptimizedDownsampledRandomForestIntegrationTest() throws Exception { 24 | double fractionOfDataForValidation = .2; 25 | int rebuildsPerValidation = 1; 26 | List trainingData = 
getAdvertisingInstances().subList(0, 3000); 27 | OnespotDateTimeExtractor dateTimeExtractor = new OnespotDateTimeExtractor(); 28 | Pair, DownsamplingClassifier> downsamplingClassifierPair = 29 | Classifiers.getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor); 30 | logger.info("logged weighted auc loss should be between 0.25 and 0.28"); 31 | } 32 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/classifier/randomForest/TestIrisAccuracy.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.classifier.randomForest; 2 | 3 | /** 4 | * Created by alexanderhawk on 4/7/15. 5 | */ 6 | import quickml.data.*; 7 | import quickml.data.instances.ClassifierInstance; 8 | import quickml.supervised.*; 9 | import quickml.supervised.ensembles.randomForest.randomDecisionForest.*; 10 | import quickml.supervised.tree.attributeIgnoringStrategies.*; 11 | import quickml.supervised.tree.decisionTree.*; 12 | 13 | import java.io.*; 14 | import java.util.*; 15 | 16 | public class TestIrisAccuracy { 17 | public static void main(String[] args) throws IOException { 18 | List irisDataset = PredictiveAccuracyTests.loadIrisDataset(); 19 | final RandomDecisionForest randomForest = new RandomDecisionForestBuilder<>(new DecisionTreeBuilder<>() 20 | // The default isn't desirable here because this dataset has so few attributes 21 | .attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(0.2))) 22 | .buildPredictiveModel(irisDataset); 23 | 24 | AttributesMap attributes = new AttributesMap(); 25 | attributes.put("sepal-length", 5.84); 26 | attributes.put("sepal-width", 3.05); 27 | attributes.put("petal-length", 3.76); 28 | attributes.put("petal-width", 1.20); 29 | System.out.println("Prediction: " + randomForest.predict(attributes)); 30 | for 
(ClassifierInstance instance : irisDataset) { 31 | System.out.println("classification: " + randomForest.getClassificationByMaxProb(instance.getAttributes())); 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/crossValidation/PredictionMapResultsTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation; 2 | 3 | import org.junit.Test; 4 | 5 | import static java.util.Collections.EMPTY_LIST; 6 | 7 | public class PredictionMapResultsTest { 8 | 9 | 10 | @Test(expected = IllegalArgumentException.class) 11 | public void testTotalLossNoData() { 12 | new PredictionMapResults(EMPTY_LIST); 13 | } 14 | 15 | 16 | 17 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/crossValidation/SimpleCrossValidatorIntegrationTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation; 2 | 3 | import org.junit.Before; 4 | import org.junit.Test; 5 | import quickml.InstanceLoader; 6 | import quickml.data.instances.ClassifierInstance; 7 | import quickml.data.OnespotDateTimeExtractor; 8 | import quickml.supervised.crossValidation.data.OutOfTimeData; 9 | import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierLogCVLossFunction; 10 | import quickml.supervised.tree.decisionTree.DecisionTree; 11 | import quickml.supervised.tree.decisionTree.DecisionTreeBuilder; 12 | import quickml.supervised.tree.decisionTree.scorers.GRPenalizedGiniImpurityScorerFactory; 13 | 14 | import java.util.List; 15 | 16 | /** 17 | * Created by alexanderhawk on 7/8/15. 
18 | */ 19 | public class SimpleCrossValidatorIntegrationTest { 20 | 21 | private List instances; 22 | 23 | @Before 24 | public void setUp() throws Exception { 25 | instances = InstanceLoader.getAdvertisingInstances().subList(0,1000); 26 | } 27 | 28 | 29 | @Test 30 | public void testCrossValidation() throws Exception { 31 | System.out.println("\n \n \n new attrImportanceTest"); 32 | DecisionTreeBuilder modelBuilder = new DecisionTreeBuilder().scorerFactory(new GRPenalizedGiniImpurityScorerFactory()).maxDepth(16).minLeafInstances(0).minAttributeValueOccurences(11).attributeIgnoringStrategy(new quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability(0.7)); 33 | 34 | SimpleCrossValidator cv = new SimpleCrossValidator<>(modelBuilder, 35 | new ClassifierLossChecker(new ClassifierLogCVLossFunction(.000001)), 36 | new OutOfTimeData<>(instances, .25, 12, new OnespotDateTimeExtractor() ) ); 37 | for (int i =0; i<3; i++) { 38 | System.out.println("Loss: " + cv.getLossForModel()); 39 | } 40 | 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/crossValidation/lossfunctions/ClassifierMSELossFunctionTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions; 2 | 3 | import org.junit.Assert; 4 | import org.junit.Test; 5 | import quickml.data.PredictionMap; 6 | import quickml.supervised.crossValidation.PredictionMapResult; 7 | import quickml.supervised.crossValidation.PredictionMapResults; 8 | import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.ClassifierMSELossFunction; 9 | 10 | import static com.google.common.collect.Lists.newArrayList; 11 | 12 | public class ClassifierMSELossFunctionTest { 13 | 14 | @Test 15 | public void testGetTotalLoss() { 16 | ClassifierMSELossFunction crossValLoss = new ClassifierMSELossFunction(); 17 | 
PredictionMapResult result1 = createPredictionMapResult("test1", 0.75, 2.0); 18 | PredictionMapResult result2 = createPredictionMapResult("test1", 0.5, 1.0); 19 | PredictionMapResults predictionMapResults = new PredictionMapResults(newArrayList(result1, result2)); 20 | 21 | Assert.assertEquals(0.125, crossValLoss.getLoss(predictionMapResults), 0.0001); 22 | } 23 | 24 | private PredictionMapResult createPredictionMapResult(final String label, final double prediction, final double weight) { 25 | PredictionMap map = PredictionMap.newMap(); 26 | map.put(label, prediction); 27 | return new PredictionMapResult(map, label, weight); 28 | } 29 | 30 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/crossValidation/lossfunctions/LossFunctionsTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions; 2 | 3 | import org.junit.Test; 4 | 5 | public class LossFunctionsTest { 6 | 7 | 8 | @Test 9 | public void testName() throws Exception { 10 | System.out.println(1 - 0.234); 11 | } 12 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/crossValidation/lossfunctions/rankingLossFunctions/NDCGTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.crossValidation.lossfunctions.rankingLossFunctions; 2 | 3 | import com.beust.jcommander.internal.Lists; 4 | import com.beust.jcommander.internal.Maps; 5 | import org.junit.Assert; 6 | import org.junit.Before; 7 | import org.junit.Test; 8 | import quickml.supervised.rankingModels.ItemToOutcomeMap; 9 | import quickml.supervised.rankingModels.LabelPredictionWeightForRanking; 10 | import quickml.supervised.rankingModels.RankingPrediction; 11 | 12 | import java.io.Serializable; 13 | import java.util.ArrayList; 14 | import java.util.HashMap; 15 | 16 | 
import static org.junit.Assert.*; 17 | 18 | /** 19 | * Created by alexanderhawk on 8/13/15. 20 | */ 21 | public class NDCGTest { 22 | ItemToOutcomeMap itemToOutcomeMap; 23 | RankingPrediction rankingPrediction; 24 | /** Fixture: outcomes 1.0 for "c" and 2.0 for "a", and the predicted ranking a, b, c, d ("b" and "d" have no recorded outcome). */ 25 | @Before 26 | public void setUp() { 27 | HashMap itemToOutcomes = new HashMap<>(); 28 | itemToOutcomes.put("c", 1.0); //has loss 1/2 29 | itemToOutcomes.put("a", 2.0); //has loss 3 30 | itemToOutcomeMap = new ItemToOutcomeMap(itemToOutcomes); 31 | ArrayList rankedList = new ArrayList(); 32 | rankedList.add("a"); 33 | rankedList.add("b"); 34 | rankedList.add("c"); 35 | rankedList.add("d"); 36 | rankingPrediction = new RankingPrediction(rankedList); 37 | } 38 | /** NDCG.dcg on the fixture must be 3.50, i.e. the 3 + 1/2 noted on the fixture entries. JUnit takes the expected value first, so a failure reports the computed DCG as the actual value. NOTE(review): the argument 8 is presumably the rank cutoff - confirm against NDCG. */ 39 | @Test 40 | public void testDcg() throws Exception { 41 | double dcg = NDCG.dcg(new LabelPredictionWeightForRanking(itemToOutcomeMap, rankingPrediction), 8); 42 | Assert.assertEquals(3.50, dcg, 1E-5); 43 | } 44 | /** NDCG.idcg on the same fixture must be 3.6309297535714573 (numerically 3 + 1/log2(3)); expected value first, as JUnit requires. */ 45 | @Test 46 | public void testIdcg() throws Exception { 47 | double idcg = NDCG.idcg(new LabelPredictionWeightForRanking(itemToOutcomeMap, rankingPrediction), 8); 48 | Assert.assertEquals(3.6309297535714573, idcg, 1E-5); 49 | } 50 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/dataProcessing/instanceTranformer/ProductFeatureAppenderTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.dataProcessing.instanceTranformer; 2 | 3 | /** 4 | * Created by alexanderhawk on 10/15/15.
5 | */ 6 | public class ProductFeatureAppenderTest { 7 | /* Empty placeholder: no tests are defined in this class. */ 8 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/downsampling/DownsamplingPredictiveModelTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.downsampling; 2 | 3 | import junit.framework.Assert; 4 | import org.testng.annotations.Test; 5 | import quickml.data.AttributesMap; 6 | import quickml.supervised.classifier.Classifier; 7 | import quickml.supervised.classifier.downsampling.DownsamplingClassifier; 8 | 9 | import java.util.HashMap; 10 | import java.util.Map; 11 | 12 | import static org.mockito.Mockito.*; 13 | 14 | /** 15 | * Created by ian on 4/24/14. 16 | */ 17 | public class DownsamplingPredictiveModelTest { /* simpleTest: the mocked wrapped classifier returns 0.5 for Boolean.TRUE; the DownsamplingClassifier built around it with 0.9 must then report 0.1/1.1 for Boolean.TRUE, to within 1e-7. NOTE(review): presumably 0.9 is the drop probability applied to the majority class Boolean.FALSE - confirm against DownsamplingClassifier. */ 18 | @Test 19 | public void simpleTest() { 20 | final Classifier classifier = mock(Classifier.class); 21 | when(classifier.getProbability(any(AttributesMap.class), eq(Boolean.TRUE))).thenReturn(0.5); 22 | DownsamplingClassifier downsamplingClassifier = new DownsamplingClassifier(classifier, Boolean.FALSE, Boolean.TRUE, 0.9); 23 | double corrected = downsamplingClassifier.getProbability(AttributesMap.newHashMap(), Boolean.TRUE); 24 | double error = Math.abs(corrected - 0.1/1.1); 25 | Assert.assertTrue(String.format("Error (%s) should be negligible", error), error < 0.0000001); 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/featureEngineering/AttributeCombiningEnricherTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.featureEngineering; 2 | 3 | import com.google.common.collect.Sets; 4 | import com.google.common.collect.Lists; 5 | import org.testng.Assert; 6 | import org.testng.annotations.Test; 7 | import quickml.data.AttributesMap; 8 | import
quickml.supervised.featureEngineering1.enrichStrategies.attributeCombiner.AttributeCombiningEnricher; 9 | 10 | import java.util.List; 11 | import java.util.Set; 12 | /** Checks that combining the attributes "k1" and "k2" leaves three attributes in the map, the new one being "k1-k2" with value "ab". Assertions use TestNG's (actual, expected) argument order. */ 13 | public class AttributeCombiningEnricherTest { 14 | @Test 15 | public void simpleTest() { 16 | Set<List<String>> attributesToCombine = Sets.newHashSet(); 17 | attributesToCombine.add(Lists.newArrayList("k1", "k2")); 18 | AttributeCombiningEnricher attributeCombiningEnricher = new AttributeCombiningEnricher(attributesToCombine); 19 | AttributesMap attributes = AttributesMap.newHashMap(); 20 | attributes.put("k1", "a"); 21 | attributes.put("k2", "b"); 22 | final AttributesMap enhancedAttributes = attributeCombiningEnricher.apply(attributes); 23 | Assert.assertEquals(enhancedAttributes.size(), 3); 24 | Assert.assertEquals(enhancedAttributes.get("k1-k2"), "ab"); 25 | } 26 | 27 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/featureEngineering/ProbabilityInjectingEnricherTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.featureEngineering; 2 | 3 | import com.google.common.collect.Maps; 4 | import junit.framework.Assert; 5 | import org.testng.annotations.Test; 6 | import quickml.data.AttributesMap; 7 | import quickml.supervised.featureEngineering1.enrichStrategies.probabilityInjector.ProbabilityInjectingEnricher; 8 | 9 | import java.io.Serializable; 10 | import java.util.Map; 11 | /** Checks that the enricher keeps the input attribute "testkey" and adds "testkey-PROB" carrying the probability registered for the value 5. */ 12 | public class ProbabilityInjectingEnricherTest { 13 | @Test 14 | public void simpleTest() { 15 | final Map<String, Map<Serializable, Double>> valueProbsByAttr = Maps.newHashMap(); 16 | Map<Serializable, Double> valueProbs = Maps.newHashMap(); 17 | valueProbs.put(5, 0.2); 18 | valueProbsByAttr.put("testkey", valueProbs); 19 | ProbabilityInjectingEnricher probabilityInjectingEnricher = new ProbabilityInjectingEnricher(valueProbsByAttr); 20 | AttributesMap inputAttributes = AttributesMap.newHashMap(); 21 | inputAttributes.put("testkey", 5); 22 | final
AttributesMap outputAttributes = probabilityInjectingEnricher.apply(inputAttributes); 23 | Assert.assertEquals("The pre-existing attribute is still there", 5, outputAttributes.get("testkey")); 24 | Assert.assertEquals("The newly added attribute is there", 0.2, outputAttributes.get("testkey-PROB")); 25 | } 26 | 27 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/predictiveModelOptimizer/fieldValueRecommenders/MonotonicConvergenceRecommenderTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders; 2 | 3 | import com.google.common.collect.Lists; 4 | import org.junit.Before; 5 | import org.junit.Ignore; 6 | import org.junit.Test; 7 | 8 | import java.util.Arrays; 9 | import java.util.List; 10 | 11 | import static org.junit.Assert.*; 12 | 13 | public class MonotonicConvergenceRecommenderTest { 14 | /* Recommender under test; re-created before every test by setUp(). */ 15 | private MonotonicConvergenceRecommender recommender; 16 | /** Builds a recommender over the candidate values 1, 5, 10, 20, 40. NOTE(review): the second argument 0.1 is presumably the convergence tolerance - confirm against MonotonicConvergenceRecommender. */ 17 | @Before 18 | public void setUp() throws Exception { 19 | recommender = new MonotonicConvergenceRecommender(Arrays.asList(1, 5, 10, 20, 40), 0.1); 20 | } 21 | 22 | /** Feeds losses that double at every step (2, 4, 8, ...) until shouldContinue returns false or the candidate values run out. The list is typed List<Double> so that losses.get(i-1) can be unboxed into prevLoss. */ 23 | @Test 24 | public void testWeStopIfThresholdIsNotReached() throws Exception { 25 | List<Double> losses = Lists.newArrayList(); 26 | for (int i = 0; i < recommender.getValues().size(); i++) { 27 | double prevLoss = (i>0) ?
losses.get(i-1) : 1.0; 28 | losses.add(prevLoss*2); 29 | if (!recommender.shouldContinue(losses)) 30 | break; 31 | } 32 | /* Five recorded losses means the loop did not stop before the last candidate value. */ 33 | assertEquals(5, losses.size()); 34 | } 35 | 36 | /** Feeds the fixed loss sequence 0.001, 0.002, 0.002001, 0.004, 0.005 and expects the loop to stop after the third loss. NOTE(review): the method name says "continue" while the assertion checks an early stop - confirm the intended naming. */ 37 | @Test 38 | public void testWeContinueIfWeHaventGoneOverTheTolerance() throws Exception { 39 | List<Double> losses = Lists.newArrayList(); 40 | double[] lossValue = new double[]{0.001, 0.002, 0.002001, 0.004, 0.005}; 41 | for (int i = 0; i < recommender.getValues().size(); i++) { 42 | losses.add(lossValue[i]); 43 | if (!recommender.shouldContinue(losses)) 44 | break; 45 | } 46 | 47 | System.out.println("losses = " + losses); 48 | assertEquals(3, losses.size()); 49 | 50 | } 51 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/tree/decisionTree/OldClassificationCounterTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree; 2 | 3 | import org.testng.Assert; 4 | import org.testng.annotations.Test; 5 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 6 | 7 | /** 8 | * Created by ian on 2/27/14.
9 | */ 10 | public class OldClassificationCounterTest { 11 | /** a.add(b) must sum the per-label counts: dog 1.0 + 0.5 and cat 0.5 + 1.0 both give 1.5. Assertions in this class use TestNG's (actual, expected) argument order. */ 12 | @Test 13 | public void testAdd() { 14 | ClassificationCounter a = new ClassificationCounter(); 15 | a.addClassification("dog", 1.0); 16 | a.addClassification("cat", 0.5); 17 | ClassificationCounter b = new ClassificationCounter(); 18 | b.addClassification("dog", 0.5); 19 | b.addClassification("cat", 1.0); 20 | ClassificationCounter c = a.add(b); 21 | Assert.assertEquals(c.getCount("dog"), 1.5); 22 | Assert.assertEquals(c.getCount("cat"), 1.5); 23 | } 24 | /** a.subtract(b) must subtract the per-label counts: dog 1.0 - 0.5 = 0.5 and cat 2.5 - 1.0 = 1.5. */ 25 | @Test 26 | public void testSubtract() { 27 | ClassificationCounter a = new ClassificationCounter(); 28 | a.addClassification("dog", 1.0); 29 | a.addClassification("cat", 2.5); 30 | ClassificationCounter b = new ClassificationCounter(); 31 | b.addClassification("dog", 0.5); 32 | b.addClassification("cat", 1.0); 33 | ClassificationCounter c = a.subtract(b); 34 | Assert.assertEquals(c.getCount("dog"), 0.5); 35 | Assert.assertEquals(c.getCount("cat"), 1.5); 36 | } 37 | /** The static merge(a, b) must report 1.5 for each label and a total of 3.0. */ 38 | @Test 39 | public void testMerge() { 40 | ClassificationCounter a = new ClassificationCounter(); 41 | a.addClassification("dog", 1.0); 42 | a.addClassification("cat", 0.5); 43 | ClassificationCounter b = new ClassificationCounter(); 44 | b.addClassification("dog", 0.5); 45 | b.addClassification("cat", 1.0); 46 | ClassificationCounter merged = ClassificationCounter.merge(a, b); 47 | Assert.assertEquals(merged.getTotal(), 3.0); 48 | Assert.assertEquals(merged.getCount("dog"), 1.5); 49 | Assert.assertEquals(merged.getCount("cat"), 1.5); 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/tree/decisionTree/reducers/BinaryCatOldBranchReducerTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.decisionTree.reducers; 2 | 3 | import com.google.common.base.Optional; 4 | import org.junit.Assert; 5 | import org.junit.Test; 6 | import
quickml.data.instances.ClassifierInstance; 7 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 8 | import quickml.supervised.tree.reducers.AttributeStats; 9 | 10 | import java.util.List; 11 | 12 | /** 13 | * Created by alexanderhawk on 6/29/15. 14 | */ 15 | public class BinaryCatOldBranchReducerTest { 16 | /** getAttributeStats("t") must yield exactly two ClassificationCounters, in order: the first counting 2.0 for both 1.0 and 0.0, the second counting 3.0 for 1.0 and 1.0 for 0.0. JUnit assertions take the expected value first. */ 17 | @Test 18 | public void binCatReducerGetAttributeStatsTest() { 19 | List<ClassifierInstance> instances = DTCatOldBranchReducerTest.getInstances(); 20 | DTBinaryCatBranchReducer<ClassifierInstance> reducer = new DTBinaryCatBranchReducer<>(instances, 0.0); 21 | Optional<AttributeStats<ClassificationCounter>> attributeStatsOptional = reducer.getAttributeStats("t"); 22 | AttributeStats<ClassificationCounter> attributeStats = attributeStatsOptional.get(); 23 | Assert.assertEquals(2, attributeStats.getStatsOnEachValue().size()); 24 | ClassificationCounter first = attributeStats.getStatsOnEachValue().get(0); 25 | 26 | //test correct ordering that first comes before second 27 | Assert.assertEquals(2.0, first.getCount(1.0), 1E-5); 28 | Assert.assertEquals(2.0, first.getCount(0.0), 1E-5); 29 | ClassificationCounter second = attributeStats.getStatsOnEachValue().get(1); 30 | Assert.assertEquals(3.0, second.getCount(1.0), 1E-5); 31 | Assert.assertEquals(1.0, second.getCount(0.0), 1E-5); 32 | 33 | } 34 | 35 | } -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/tree/scorers/PenalizedInformationGainScorerTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.scorers; 2 | 3 | import org.junit.Assert; 4 | 5 | import org.junit.Test; 6 | import quickml.supervised.tree.decisionTree.scorers.PenalizedInformationGainScorer; 7 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 8 | import quickml.supervised.tree.reducers.AttributeStats; 9 | 10 | import java.util.Arrays; 11 | 12 | public class PenalizedInformationGainScorerTest { 13 | 14 | @Test 15 | public void
sameClassificationTest() { 16 | ClassificationCounter a = new ClassificationCounter(); 17 | a.addClassification("a", 4); 18 | ClassificationCounter b = new ClassificationCounter(); 19 | b.addClassification("a", 4); 20 | PenalizedInformationGainScorer scorer = new PenalizedInformationGainScorer(0, 0.0, new AttributeStats<>(Arrays.asList(a, b), a.add(b), "a")); 21 | /* Both sides hold only label "a", so the split is expected to score 0.0; JUnit takes the expected value first. */ 22 | Assert.assertEquals(0.0, scorer.scoreSplit(a, b), 1E-7); 23 | } 24 | /** A split whose sides hold disjoint labels ("a" with weight 4 versus "b" with weight 4) is expected to score 1.0. */ 25 | @Test 26 | public void diffClassificationTest() { 27 | ClassificationCounter a = new ClassificationCounter(); 28 | a.addClassification("a", 4); 29 | ClassificationCounter b = new ClassificationCounter(); 30 | b.addClassification("b", 4); 31 | PenalizedInformationGainScorer scorer = new PenalizedInformationGainScorer(0, 0.0, new AttributeStats<>(Arrays.asList(a, b), a.add(b), "a")); 32 | /* Expected value first, so a failure reports the computed score as the actual value. */ 33 | Assert.assertEquals(1.0, scorer.scoreSplit(a, b), 1E-7); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/test/java/quickml/supervised/tree/scorers/PenalizedMSEScorerTest.java: -------------------------------------------------------------------------------- 1 | package quickml.supervised.tree.scorers; 2 | 3 | import org.junit.Assert; 4 | 5 | import org.junit.Test; 6 | import quickml.supervised.tree.decisionTree.scorers.PenalizedMSEScorer; 7 | import quickml.supervised.tree.decisionTree.valueCounters.ClassificationCounter; 8 | import quickml.supervised.tree.reducers.AttributeStats; 9 | 10 | import java.util.Arrays; 11 | 12 | /** 13 | * Created by ian on 2/27/14.
14 | */ 15 | public class PenalizedMSEScorerTest { 16 | /** Scores the split of counter a (a=4, b=9, c=1) against counter b (a=5, b=9, c=6) and pins the result to 0.021776929 within 1e-9. */ 17 | @Test 18 | public void simpleTest() { 19 | ClassificationCounter a = new ClassificationCounter(); 20 | a.addClassification("a", 4); 21 | a.addClassification("b", 9); 22 | a.addClassification("c", 1); 23 | ClassificationCounter b = new ClassificationCounter(); 24 | b.addClassification("a", 5); 25 | b.addClassification("b", 9); 26 | b.addClassification("c", 6); 27 | PenalizedMSEScorer mseScorer = new PenalizedMSEScorer(0, 0.0, new AttributeStats<>(Arrays.asList(a, b), a.add(b), "a")); 28 | /* assertEquals with a delta instead of a hand-rolled Math.abs check, so a failure message shows the computed score. */ 29 | Assert.assertEquals(0.021776929, mseScorer.scoreSplit(a, b), 0.000000001); 30 | } 31 | } 32 | -------------------------------------------------------------------------------- /src/test/resources/quickml/advertisingData.csv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanity/quickml/3795d6d759ff8845ec5fab26bc49197ba597420e/src/test/resources/quickml/advertisingData.csv.gz -------------------------------------------------------------------------------- /src/test/resources/quickml/diabetesDataset.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanity/quickml/3795d6d759ff8845ec5fab26bc49197ba597420e/src/test/resources/quickml/diabetesDataset.txt.gz -------------------------------------------------------------------------------- /src/test/resources/quickml/iris.data.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanity/quickml/3795d6d759ff8845ec5fab26bc49197ba597420e/src/test/resources/quickml/iris.data.gz -------------------------------------------------------------------------------- /src/test/resources/quickml/mobo1.json.gz: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/sanity/quickml/3795d6d759ff8845ec5fab26bc49197ba597420e/src/test/resources/quickml/mobo1.json.gz --------------------------------------------------------------------------------