├── .gitignore ├── GradientBoosting.md ├── LICENSE ├── README.md ├── SequoiaForest.md ├── data ├── mnist.t.tsv.gz └── mnist.tsv.gz ├── pom.xml ├── scalastyle-config.xml └── src ├── main └── scala │ └── spark_ml │ ├── discretization │ ├── BinFinder.scala │ ├── DiscretizationOptions.scala │ ├── Discretizer.scala │ ├── EntropyMinimizingBinFinderFromSample.scala │ ├── EqualFrequencyBinFinderFromSample.scala │ ├── EqualWidthBinFinder.scala │ ├── NumericBinFinderFromSample.scala │ └── VarianceMinimizingBinFinderFromSample.scala │ ├── gradient_boosting │ ├── GradientBoosting.scala │ ├── GradientBoostingOptions.scala │ ├── GradientBoostingRunner.scala │ ├── PredCache.scala │ └── loss │ │ ├── LossAggregator.scala │ │ ├── LossFunction.scala │ │ └── defaults │ │ ├── AdaBoostLossAggregator.scala │ │ ├── AdaBoostLossFunction.scala │ │ ├── GaussianLossAggregator.scala │ │ ├── GaussianLossFunction.scala │ │ ├── LaplacianLossAggregator.scala │ │ ├── LaplacianLossFunction.scala │ │ ├── LogLossAggregator.scala │ │ ├── LogLossFunction.scala │ │ ├── PoissonLossAggregator.scala │ │ ├── PoissonLossFunction.scala │ │ ├── TruncHingeLossAggregator.scala │ │ └── TruncHingeLossFunction.scala │ ├── model │ ├── DecisionTree.scala │ ├── gb │ │ ├── GradientBoostedTrees.scala │ │ ├── GradientBoostedTreesFactory.scala │ │ └── GradientBoostedTreesStore.scala │ └── rf │ │ └── RandomForestStore.scala │ ├── transformation │ ├── DataTransformationUtils.scala │ └── DistinctValueCounter.scala │ ├── tree_ensembles │ ├── IdCache.scala │ ├── IdLookup.scala │ ├── IdLookupForNodeStats.scala │ ├── IdLookupForSubTreeInfo.scala │ ├── IdLookupForUpdaters.scala │ ├── InfoGainNodeStats.scala │ ├── NodeStats.scala │ ├── QuantizedData_ForTrees.scala │ ├── SubTreeStore.scala │ ├── TreeEnsembleStore.scala │ ├── TreeForestTrainer.scala │ └── VarianceNodeStats.scala │ └── util │ ├── Bagger.scala │ ├── DiscretizedFeatureHandler.scala │ ├── MapWithSequentialIntKeys.scala │ ├── Poisson.scala │ ├── ProgressNotifiee.scala │ ├── RandomSet.scala │ ├── ReservoirSample.scala │ ├── RobustMath.scala │ ├── Selection.scala │ └── Sorting.scala └── test └── scala └── spark_ml ├── discretization ├── EqualFrequencyBinFinderSuite.scala └── EqualWidthBinFinderSuite.scala ├── tree_ensembles ├── ComponentSuite.scala └── TreeEnsembleSuite.scala └── util ├── BinsTestUtil.scala ├── LocalSparkContext.scala └── TestDataGenerator.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *.tmp 2 | *.bak 3 | *.swp 4 | *~.nib 5 | 6 | # SBT 7 | log/ 8 | target/ 9 | *.class 10 | 11 | # Eclipse 12 | .classpath 13 | .project 14 | .settings/ 15 | 16 | # Intellij 17 | .idea/ 18 | .idea_modules/ 19 | *.iml 20 | *.iws 21 | *.ipr 22 | 23 | # Mac 24 | .DS_Store 25 | 26 | -------------------------------------------------------------------------------- /GradientBoosting.md: -------------------------------------------------------------------------------- 1 | # Gradient Boosted Trees 2 | 3 | Instructions coming soon. 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SparkML2 2 | 3 | This is a scalable, high-performance machine learning package for `Spark`. This package is maintained independently from `MLLib` in `Spark`. 4 | The algorithms listed here are going through continuous development, maintenance and updates and thus are provided without guarantees. 5 | Use at your own risk. 
6 | 7 | This package is provided under the Apache license 2.0. 8 | 9 | The current version is `0.9`. 10 | 11 | Any comments or questions should be directed to schung@alpinenow.com 12 | 13 | The currently available algorithms are: 14 | * [Sequoia Forest](SequoiaForest.md) 15 | * [Gradient Boosted Trees](GradientBoosting.md) 16 | 17 | ## Compiling the Code 18 | 19 | ./mvn3 package 20 | 21 | ## Quick Start (for YARN and Linux variants) 22 | 23 | Coming soon. 24 | -------------------------------------------------------------------------------- /SequoiaForest.md: -------------------------------------------------------------------------------- 1 | # Sequoia Forest 2 | 3 | Instructions coming soon. 4 | -------------------------------------------------------------------------------- /data/mnist.t.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlpineNow/SparkML2/a81e1ad835245f556052816e18bb4e9fb7cabe58/data/mnist.t.tsv.gz -------------------------------------------------------------------------------- /data/mnist.tsv.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AlpineNow/SparkML2/a81e1ad835245f556052816e18bb4e9fb7cabe58/data/mnist.tsv.gz -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 4 | 4.0.0 5 | 6 | com.alpine 7 | spark_ml_2.10 8 | 1.1-SNAPSHOT 9 | jar 10 | Alpine's Spark ML Package 11 | 12 | 13 | Apache 2.0 License 14 | http://www.apache.org/licenses/LICENSE-2.0.html 15 | repo 16 | 17 | 18 | 19 | 20 | schung 21 | Sung Hwan Chung 22 | schung@alpinenow.com 23 | Alpine Data Labs 24 | 25 | 26 | 27 | UTF-8 28 | UTF-8 29 | 2.10.4 30 | 2.10 31 | 1.7 32 | 64m 33 | 256m 34 | 256m 35 | 36 | github 37 | 38 | 39 | 40 | 41 | internal.repo 42 | Temporary Staging Repository 43 | file://${project.build.directory}/mvn-repo 44 | 45 | 46 | 47 | 48 | 49 | central 50 | Maven Repository 51 | https://repo1.maven.org/maven2 52 | 53 | true 54 | 55 | 56 | false 57 | 58 | 59 | 60 | 61 | sonatype-releases 62 | Sonatype Repository 63 | https://oss.sonatype.org/content/repositories/releases 64 | 65 | true 66 | 67 | 68 | false 69 | 70 | 71 | 72 | 73 | 74 | 75 | com.databricks 76 | spark-csv_${scala.binary.version} 77 | 1.3.0 78 | provided 79 | 80 | 81 | com.databricks 82 | spark-avro_${scala.binary.version} 83 | 1.0.0 84 | provided 85 | 86 | 87 | org.apache.spark 88 | spark-core_${scala.binary.version} 89 | 1.5.1 90 | provided 91 | 92 | 93 | org.apache.spark 94 | spark-mllib_${scala.binary.version} 95 | 1.5.1 96 | provided 97 | 98 | 99 | org.apache.hadoop 100 | hadoop-client 101 | 1.0.4 102 | provided 103 | 104 | 105 | org.spire-math 106 | spire_${scala.binary.version} 107 | 0.10.1 108 | 109 | 110 | org.scalanlp 111 | breeze_${scala.binary.version} 112 | 0.11.2 113 | provided 114 | 115 | 116 | com.github.scopt 117 | scopt_${scala.binary.version} 118 | 3.3.0 119 | 120 | 121 | org.scalatest 122 | scalatest_${scala.binary.version} 123 | 2.2.1 124 | test 125 | 126 | 127 | 128 | 129 | 130 | 131 | org.apache.maven.plugins 132 | maven-compiler-plugin 133 | 3.3 134 | 135 | ${java.version} 136 | ${java.version} 137 | UTF-8 138 | 1024m 139 | true 140 | 141 | -Xlint:all,-serial,-path,-XX:MaxPermSize=256m 142 | 143 | 144 | 145 | 146 | net.alchim31.maven 147 | scala-maven-plugin 148 | 3.2.2 149 | 150 | 151 | scala-compile-first 152 | process-resources 153 | 154 | 
compile 155 | 156 | 157 | 158 | scala-test-compile-first 159 | process-test-resources 160 | 161 | testCompile 162 | 163 | 164 | 165 | attach-scaladocs 166 | verify 167 | 168 | doc-jar 169 | 170 | 171 | 172 | 173 | ${scala.version} 174 | incremental 175 | true 176 | 177 | -unchecked 178 | -deprecation 179 | -feature 180 | 181 | 182 | -Xms1024m 183 | -Xmx1024m 184 | -XX:PermSize=${PermGen} 185 | -XX:MaxPermSize=${MaxPermGen} 186 | -XX:ReservedCodeCacheSize=${CodeCacheSize} 187 | 188 | 189 | 190 | -source 191 | ${java.version} 192 | -target 193 | ${java.version} 194 | -Xlint:all,-serial,-path 195 | 196 | 197 | 198 | 199 | org.scalastyle 200 | scalastyle-maven-plugin 201 | 0.7.0 202 | 203 | false 204 | true 205 | false 206 | false 207 | ${basedir}/src/main/scala 208 | ${basedir}/src/test/scala 209 | scalastyle-config.xml 210 | ${basedir}/target/scalastyle-output.xml 211 | ${project.build.sourceEncoding} 212 | ${project.reporting.outputEncoding} 213 | 214 | 215 | 216 | 217 | check 218 | 219 | 220 | 221 | 222 | 223 | org.scalatest 224 | scalatest-maven-plugin 225 | 1.0 226 | 227 | false 228 | ${project.build.directory}/surefire-reports 229 | . 230 | WDF TestSuite.txt 231 | -XX:PermSize=${PermGen} 232 | 233 | 234 | 235 | test 236 | 237 | test 238 | 239 | 240 | 241 | 242 | 243 | org.apache.maven.plugins 244 | maven-deploy-plugin 245 | 2.8.2 246 | 247 | internal.repo::default::file://${project.build.directory}/mvn-repo 248 | 249 | 250 | 251 | com.github.github 252 | site-maven-plugin 253 | 0.12 254 | 255 | Maven artifacts for ${project.version} 256 | true 257 | ${project.build.directory}/mvn-repo 258 | refs/heads/mvn-repo 259 | **/* 260 | SparkML2 261 | AlpineNow 262 | 263 | 264 | 265 | 266 | 267 | site 268 | 269 | deploy 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- 1 | 2 | Scalastyle standard configuration 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/discretization/BinFinder.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | import org.apache.spark.rdd.RDD 21 | import spark_ml.util.ProgressNotifiee 22 | 23 | case class MinMaxPair( 24 | var minValue: Double, 25 | var maxValue: Double 26 | ) 27 | 28 | class LabelSummary(val expectedCardinality: Option[Int]) extends Serializable { 29 | val catCounts: Option[Array[Long]] = 30 | if (expectedCardinality.isDefined) { 31 | Some( 32 | Array.fill[Long](expectedCardinality.get)(0L) 33 | ) 34 | } else { 35 | None 36 | } 37 | 38 | var restCount: Long = 0L 39 | var hasNaN: Boolean = false 40 | 41 | var totalCount: Double = 0.0 42 | var runningAvg: Double = 0.0 43 | var runningSqrAvg: Double = 0.0 44 | 45 | private def updateRunningAvgs( 46 | bTotalCount: Double, 47 | bRunningAvg: Double, 48 | bRunningSqrAvg: Double): Unit = { 49 | if (totalCount > 0.0 || bTotalCount > 0.0) { 50 | val newTotalCount = totalCount + bTotalCount 51 | val aRatio = totalCount / newTotalCount 52 | val bRatio = bTotalCount / newTotalCount 53 | val newAvg = aRatio * runningAvg + bRatio * bRunningAvg 54 | val newSqrAvg = aRatio * runningSqrAvg + bRatio * bRunningSqrAvg 55 | 56 | totalCount = newTotalCount 57 | runningAvg = newAvg 58 | runningSqrAvg = newSqrAvg 59 | } 60 | } 61 | 62 | def addLabel(label: Double): Unit = { 63 | if (label.isNaN) { 64 | hasNaN = true 65 | } else { 66 | // Keep track of the running averages. 67 | updateRunningAvgs(1.0, label, label * label) 68 | 69 | if (expectedCardinality.isDefined && label.toLong.toDouble == label) { 70 | val cat = label.toInt 71 | if (cat >= 0 && cat < expectedCardinality.get) { 72 | catCounts.get(cat) += 1L 73 | } else { 74 | restCount += 1L 75 | } 76 | } else { 77 | restCount += 1L 78 | } 79 | } 80 | } 81 | 82 | def mergeInPlace(b: LabelSummary): this.type = { 83 | if (catCounts.isDefined) { 84 | var i = 0 85 | while (i < expectedCardinality.get) { 86 | catCounts.get(i) += b.catCounts.get(i) 87 | i += 1 88 | } 89 | } 90 | 91 | updateRunningAvgs(b.totalCount, b.runningAvg, b.runningSqrAvg) 92 | 93 | restCount += b.restCount 94 | hasNaN ||= b.hasNaN 95 | this 96 | } 97 | } 98 | 99 | /** 100 | * Classes implementing this trait is used to find bins in the dataset. 
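 *
 * A minimal usage sketch, assuming a prepared RDD[(Double, Array[Double])] and
 * a concrete implementation such as EqualWidthBinFinder; the data, column
 * names and notifiee values below are hypothetical placeholders:
 * {{{
 *   val finder = new EqualWidthBinFinder()
 *   val (labelSummary, featureBins) = finder.findBins(
 *     data = labeledData,                          // RDD[(Double, Array[Double])]
 *     columnNames = ("label", Array("f0", "f1")),
 *     catIndices = Set(1),                         // treat f1 as categorical
 *     maxNumBins = 32,
 *     expectedLabelCardinality = Some(2),          // binary classification label
 *     notifiee = progressNotifiee                  // a ProgressNotifiee instance
 *   )
 * }}}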
101 | */ 102 | trait BinFinder { 103 | def findBins( 104 | data: RDD[(Double, Array[Double])], 105 | columnNames: (String, Array[String]), 106 | catIndices: Set[Int], 107 | maxNumBins: Int, 108 | expectedLabelCardinality: Option[Int], 109 | notifiee: ProgressNotifiee 110 | ): (LabelSummary, Seq[Bins]) 111 | 112 | def getDataSummary( 113 | data: RDD[(Double, Array[Double])], 114 | numFeatures: Int, 115 | expectedLabelCardinality: Option[Int] 116 | ): (LabelSummary, Array[Boolean], Array[MinMaxPair]) = { 117 | val (ls, hn, mm) = data.mapPartitions( 118 | itr => { 119 | val labelSummary = new LabelSummary(expectedLabelCardinality) 120 | val nanExists = Array.fill[Boolean](numFeatures)(false) 121 | val minMaxes = Array.fill[MinMaxPair](numFeatures)( 122 | MinMaxPair(minValue = Double.PositiveInfinity, maxValue = Double.NegativeInfinity) 123 | ) 124 | while (itr.hasNext) { 125 | val (label, features) = itr.next() 126 | labelSummary.addLabel(label) 127 | features.zipWithIndex.foreach { 128 | case (featValue, featIdx) => 129 | nanExists(featIdx) = nanExists(featIdx) || featValue.isNaN 130 | if (!featValue.isNaN) { 131 | minMaxes(featIdx).minValue = math.min(featValue, minMaxes(featIdx).minValue) 132 | minMaxes(featIdx).maxValue = math.max(featValue, minMaxes(featIdx).maxValue) 133 | } 134 | } 135 | } 136 | 137 | Array((labelSummary, nanExists, minMaxes)).iterator 138 | } 139 | ).reduce { 140 | case ((labelSummary1, nanExists1, minMaxes1), (labelSummary2, nanExists2, minMaxes2)) => 141 | ( 142 | labelSummary1.mergeInPlace(labelSummary2), 143 | nanExists1.zip(nanExists2).map { 144 | case (nan1, nan2) => nan1 || nan2 145 | }, 146 | minMaxes1.zip(minMaxes2).map { 147 | case (minMax1, minMax2) => 148 | MinMaxPair( 149 | minValue = math.min(minMax1.minValue, minMax2.minValue), 150 | maxValue = math.max(minMax1.maxValue, minMax2.maxValue) 151 | ) 152 | } 153 | ) 154 | } 155 | 156 | // Labels shouldn't have NaN. So throw an exception if we found NaN for the 157 | // label. 158 | if (ls.hasNaN) { 159 | throw InvalidLabelException("Found NaN value for the label.") 160 | } 161 | 162 | (ls, hn, mm) 163 | } 164 | } 165 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/discretization/DiscretizationOptions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | /** 21 | * Discretization options. 22 | * @param discType Discretization type (Equal-Frequency/Width). 23 | * @param maxNumericBins Maximum number of numerical bins. 24 | * @param maxCatCardinality Maximum number of categorical bins. 
25 | * @param useFeatureHashingOnCat Whether we should perform feature hashing on 26 | * categorical features whose cardinality goes 27 | * over 'maxCatCardinality'. 28 | * @param maxSampleSizeForDisc Maximum number of sample rows for certain numeric 29 | * discretizations. E.g., equal frequency 30 | * discretization uses a sample, instead of the 31 | * entire data. 32 | */ 33 | case class DiscretizationOptions( 34 | discType: DiscType.DiscType, 35 | maxNumericBins: Int, 36 | maxCatCardinality: Int, 37 | useFeatureHashingOnCat: Boolean, 38 | maxSampleSizeForDisc: Int) { 39 | override def toString: String = { 40 | "=========================" + "\n" + 41 | "Discretization Options" + "\n" + 42 | "=========================" + "\n" + 43 | "discType : " + discType.toString + "\n" + 44 | "maxNumericBins : " + maxNumericBins + "\n" + 45 | "maxCatCardinality : " + maxCatCardinality + "\n" + 46 | "useFeatureHashingOnCat : " + useFeatureHashingOnCat.toString + "\n" + 47 | "maxSampleSizeForDisc : " + maxSampleSizeForDisc.toString + "\n" 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/discretization/Discretizer.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | import scala.reflect.ClassTag 21 | 22 | import org.apache.spark.rdd.RDD 23 | import spark_ml.util.DiscretizedFeatureHandler 24 | 25 | /** 26 | * Available discretization types. 27 | */ 28 | object DiscType extends Enumeration { 29 | type DiscType = Value 30 | val EqualFrequency = Value(0) 31 | val EqualWidth = Value(1) 32 | val MinimumEntropy = Value(2) 33 | val MinimumVariance = Value(3) 34 | } 35 | 36 | /** 37 | * A numeric bin with a lower bound and an upper bound. 38 | * The lower bound is inclusive. The upper bound is exclusive. 39 | * @param lower Lower bound 40 | * @param upper Upper bound 41 | */ 42 | case class NumericBin(lower: Double, upper: Double) { 43 | def contains(value: Double): Boolean = { 44 | value >= lower && value < upper 45 | } 46 | 47 | override def toString: String = { 48 | "[" + lower.toString + "," + upper.toString + ")" 49 | } 50 | } 51 | 52 | /** 53 | * For both categorical and numerical bins. 54 | */ 55 | trait Bins extends Serializable { 56 | /** 57 | * Get the number of bins (including the missing-value bin) in case of numeric 58 | * bins. 59 | * @return The number of bins. 60 | */ 61 | def getCardinality: Int 62 | 63 | /** 64 | * Given the raw feature value, find the bin Id. 65 | * @param value Raw feature value. 66 | * @return The corresponding bin Id. 
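 *
 * As a small illustration with the NumericBins implementation defined below
 * (the bin boundaries are made-up values):
 * {{{
 *   val bins = NumericBins(Seq(
 *     NumericBin(Double.NegativeInfinity, 0.0),
 *     NumericBin(0.0, 10.0),
 *     NumericBin(10.0, Double.PositiveInfinity)
 *   ))
 *   bins.findBinIdx(3.5)   // returns 1 since 3.5 falls in [0.0, 10.0)
 *   bins.findBinIdx(-2.0)  // returns 0
 * }}}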
67 | */ 68 | def findBinIdx(value: Double): Int 69 | } 70 | 71 | /** 72 | * An exception to throw n case a categorical feature value is not an integer 73 | * value (e.g., can't be 1.1. or 2.4). 74 | * @param msg String message to include in the exception. 75 | */ 76 | case class InvalidCategoricalValueException(msg: String) extends Exception(msg) 77 | 78 | /** 79 | * An exception to throw if the cardinality of a feature exceeds the limit. 80 | * @param msg String message to include. 81 | */ 82 | case class CardinalityOverLimitException(msg: String) extends Exception(msg) 83 | 84 | /** 85 | * An exception to throw if the label has unexpected values. 86 | * @param msg String message to include. 87 | */ 88 | case class InvalidLabelException(msg: String) extends Exception(msg) 89 | 90 | /** 91 | * Numeric bins. 92 | * @param bins An array of ordered non-NaN numeric bins. 93 | * @param missingValueBinIdx The optional bin Id for the NaN values. 94 | */ 95 | case class NumericBins( 96 | bins: Seq[NumericBin], 97 | missingValueBinIdx: Option[Int] = None) extends Bins { 98 | /** 99 | * The cardinality of the bins (the number of bins). 100 | * @return The number of bins. 101 | */ 102 | def getCardinality: Int = 103 | bins.length + (if (missingValueBinIdx.isDefined) 1 else 0) 104 | 105 | /** 106 | * Find the index of the bin that the value belongs to. 107 | * @param value The numeric value we want to search. 108 | * @return The index of the bin that contains the given numeric value. 109 | */ 110 | def findBinIdx(value: Double): Int = { 111 | if (value.isNaN) { 112 | missingValueBinIdx.get 113 | } else { 114 | // Binary search. 115 | var s = 0 116 | var e = bins.length - 1 117 | var cur = (s + e) / 2 118 | var found = bins(cur).contains(value) 119 | while (!found) { 120 | if (bins(cur).lower > value) { 121 | e = cur - 1 122 | } else if (bins(cur).upper <= value) { 123 | s = cur + 1 124 | } 125 | 126 | cur = (s + e) / 2 127 | found = bins(cur).contains(value) 128 | } 129 | 130 | cur 131 | } 132 | } 133 | } 134 | 135 | /** 136 | * For categorical bins, the raw feature value is simply the categorical bin Id. 137 | * We expect the categorical values to go from 0 to (Cardinality - 1) in an 138 | * incremental fashion. 139 | * @param cardinality The cardinality of the feature. 140 | */ 141 | case class CategoricalBins(cardinality: Int) extends Bins { 142 | /** 143 | * @return The number of bins. 144 | */ 145 | def getCardinality: Int = { 146 | cardinality 147 | } 148 | 149 | /** 150 | * Find the bin Id. 151 | * @param value Raw feature value. 152 | * @return The corresponding bin Id. 153 | */ 154 | def findBinIdx(value: Double): Int = { 155 | if (value.toInt.toDouble != value) { 156 | throw InvalidCategoricalValueException( 157 | value + " is not a valid categorical value." 158 | ) 159 | } 160 | 161 | if (value >= cardinality) { 162 | throw CardinalityOverLimitException( 163 | value + " is above the cardinality of this feature " + cardinality 164 | ) 165 | } 166 | 167 | value.toInt 168 | } 169 | } 170 | 171 | object Discretizer { 172 | /** 173 | * Transform the features in the given labeled point row into an array of Bin 174 | * IDs that could be Unsigned Byte or Unsigned Short. 175 | * @param featureBins A sequence of feature bin definitions. 176 | * @param featureHandler A handler for discretized features (unsigned 177 | * Byte/Short). 178 | * @param row The labeled point row that we want to transform. 179 | * @return A transformed array of features. 
180 | */ 181 | private def transformFeatures[@specialized(Byte, Short) T: ClassTag]( 182 | featureBins: Seq[Bins], 183 | featureHandler: DiscretizedFeatureHandler[T])(row: (Double, Array[Double])): Array[T] = { 184 | val (_, features) = row 185 | features.zipWithIndex.map { 186 | case (featureVal, featIdx) => featureHandler.convertToType(featureBins(featIdx).findBinIdx(featureVal)) 187 | } 188 | } 189 | 190 | /** 191 | * Transform the features of the labeled data RDD into bin Ids of either 192 | * unsigned Byte/Short. 193 | * @param input An RDD of Double label and Double feature values. 194 | * @param featureBins A sequence of feature bin definitions. 195 | * @param featureHandler A handler for discretized features (unsigned 196 | * Byte/Short). 197 | * @return A new RDD that has all the features transformed into Bin Ids. 198 | */ 199 | def transformFeatures[@specialized(Byte, Short) T: ClassTag]( 200 | input: RDD[(Double, Array[Double])], 201 | featureBins: Seq[Bins], 202 | featureHandler: DiscretizedFeatureHandler[T]): RDD[Array[T]] = { 203 | input.map(transformFeatures(featureBins, featureHandler)) 204 | } 205 | } 206 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/discretization/EntropyMinimizingBinFinderFromSample.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | /** 21 | * Find bins that minimize the label class entropy. Essentially equivalent to 22 | * building a decision tree on a single feature with IG criteria. 23 | * @param maxSampleSize The maximum sample. 24 | * @param seed The seed to use with random sampling. 25 | */ 26 | class EntropyMinimizingBinFinderFromSample(maxSampleSize: Int, seed: Int) 27 | extends NumericBinFinderFromSample("LabelEntropy", maxSampleSize, seed) { 28 | 29 | /** 30 | * Calculate the label class entropy of the segment. 31 | * @param labelValues A sample label values. 32 | * @param featureValues A sample feature values. 33 | * @param labelSummary This is useful if there's a need to find the label 34 | * class cardinality. 35 | * @param s The segment starting index (inclusive). 36 | * @param e The segment ending index (exclusive). 37 | * @return The loss of the segment. 
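 *
 * For example, a segment holding 8 points of class 0 and 2 points of class 1
 * gets a loss of -(0.8 * ln(0.8) + 0.2 * ln(0.2)), roughly 0.5004, i.e. the
 * natural-log entropy of the segment's class distribution.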
38 | */ 39 | override def findSegmentLoss( 40 | labelValues: Array[Double], 41 | featureValues: Array[Double], 42 | labelSummary: LabelSummary, 43 | s: Int, 44 | e: Int): Double = { 45 | val segSize = (e - s).toDouble 46 | val classCounts = Array.fill[Double](labelSummary.expectedCardinality.get)(0.0) 47 | var i = s 48 | while (i < e) { 49 | val labelValue = labelValues(i).toInt 50 | classCounts(labelValue) += 1.0 51 | i += 1 52 | } 53 | 54 | classCounts.foldLeft(0.0)( 55 | (entropy, cnt) => cnt / segSize match { 56 | case 0.0 => entropy 57 | case p => entropy - p * math.log(p) 58 | } 59 | ) 60 | } 61 | } 62 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/discretization/EqualFrequencyBinFinderFromSample.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | /** 21 | * Find bins that satisfy the 'roughly' equal frequency requirement for each 22 | * feature column. We calculate this from a sample of the data (the sample is 23 | * either at or less than maxSampleSize). 24 | * @param maxSampleSize The maximum sample. 25 | * @param seed The seed to use with random sampling. 26 | */ 27 | class EqualFrequencyBinFinderFromSample(maxSampleSize: Int, seed: Int) 28 | extends NumericBinFinderFromSample("BinSize", maxSampleSize, seed) { 29 | 30 | /** 31 | * Loss for a segment is simply its size. Minimizing weighted sizes of bins 32 | * lead to equal frequency binning in the best case. 33 | * @param labelValues A sample label values. 34 | * @param featureValues A sample feature values. 35 | * @param labelSummary This is useful if there's a need to find the label 36 | * class cardinality. 37 | * @param s The segment starting index (inclusive). 38 | * @param e The segment ending index (exclusive). 39 | * @return The loss of the segment. 40 | */ 41 | override def findSegmentLoss( 42 | labelValues: Array[Double], 43 | featureValues: Array[Double], 44 | labelSummary: LabelSummary, 45 | s: Int, 46 | e: Int): Double = { 47 | // Simply define the unsegmented section size as the loss. 48 | // It doesn't really have a big impact. 49 | (e - s).toDouble 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/discretization/EqualWidthBinFinder.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | import org.apache.spark.rdd.RDD 21 | import spark_ml.util.ProgressNotifiee 22 | 23 | /** 24 | * Compute equal width bins for each numeric column. 25 | */ 26 | class EqualWidthBinFinder extends BinFinder { 27 | def findBins( 28 | data: RDD[(Double, Array[Double])], 29 | columnNames: (String, Array[String]), 30 | catIndices: Set[Int], 31 | maxNumBins: Int, 32 | expectedLabelCardinality: Option[Int], 33 | notifiee: ProgressNotifiee 34 | ): (LabelSummary, Seq[Bins]) = { 35 | val numFeatures = columnNames._2.length 36 | 37 | // Find the label summary as well as the existence of feature values and the 38 | // min/max values. 39 | val (labelSummary, featureHasNan, minMaxValues) = 40 | this.getDataSummary( 41 | data, 42 | numFeatures, 43 | expectedLabelCardinality 44 | ) 45 | 46 | val numericFeatureBins = featureHasNan.zipWithIndex.map { 47 | case (featHasNan, featIdx) if !catIndices.contains(featIdx) => 48 | val minMax = minMaxValues(featIdx) 49 | if (minMax.minValue.isPosInfinity || (minMax.minValue == minMax.maxValue)) { 50 | NumericBins( 51 | Seq( 52 | NumericBin(lower = Double.NegativeInfinity, upper = Double.PositiveInfinity) 53 | ), 54 | if (featHasNan) Some(1) else None 55 | ) 56 | } else { 57 | val nonNaNMaxNumBins = maxNumBins - (if (featHasNan) 1 else 0) 58 | val binWidth = (minMax.maxValue - minMax.minValue) / nonNaNMaxNumBins.toDouble 59 | NumericBins( 60 | (0 to (nonNaNMaxNumBins - 1)).map { 61 | binIdx => { 62 | binIdx match { 63 | case 0 => NumericBin(lower = Double.NegativeInfinity, upper = binWidth + minMax.minValue) 64 | case x if x == (nonNaNMaxNumBins - 1) => NumericBin(lower = minMax.maxValue - binWidth, upper = Double.PositiveInfinity) 65 | case _ => NumericBin(lower = minMax.minValue + binWidth * binIdx.toDouble, upper = minMax.minValue + binWidth * (binIdx + 1).toDouble) 66 | } 67 | } 68 | }, 69 | if (featHasNan) Some(nonNaNMaxNumBins) else None 70 | ) 71 | } 72 | case (featHasNan, featIdx) if catIndices.contains(featIdx) => 73 | CategoricalBins((minMaxValues(featIdx).maxValue + 1.0).toInt) 74 | } 75 | 76 | (labelSummary, numericFeatureBins) 77 | } 78 | } 79 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/discretization/VarianceMinimizingBinFinderFromSample.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | /** 21 | * Find bins that minimize the label variance. Essentially equivalent to 22 | * building a decision tree on a single feature with variance criteria. 23 | * @param maxSampleSize The maximum sample. 24 | * @param seed The seed to use with random sampling. 25 | */ 26 | class VarianceMinimizingBinFinderFromSample(maxSampleSize: Int, seed: Int) 27 | extends NumericBinFinderFromSample("LabelVariance", maxSampleSize, seed) { 28 | 29 | /** 30 | * Calculate the label variance of the segment. 31 | * @param labelValues A sample label values. 32 | * @param featureValues A sample feature values. 33 | * @param labelSummary This is useful if there's a need to find the label 34 | * class cardinality. 35 | * @param s The segment starting index (inclusive). 36 | * @param e The segment ending index (exclusive). 37 | * @return The loss of the segment. 38 | */ 39 | override def findSegmentLoss( 40 | labelValues: Array[Double], 41 | featureValues: Array[Double], 42 | labelSummary: LabelSummary, 43 | s: Int, 44 | e: Int): Double = { 45 | val segSize = (e - s).toDouble 46 | var sum = 0.0 47 | var sqrSum = 0.0 48 | var i = s 49 | while (i < e) { 50 | val labelValue = labelValues(i).toInt 51 | sum += labelValue 52 | sqrSum += labelValue * labelValue 53 | i += 1 54 | } 55 | 56 | val avg = sum / segSize 57 | sqrSum / segSize - avg * avg 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/GradientBoostingOptions.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting 19 | 20 | import spark_ml.gradient_boosting.loss.LossFunction 21 | import spark_ml.tree_ensembles.CatSplitType 22 | import spark_ml.util.BaggingType 23 | 24 | /** 25 | * Gradient boosting trainer options. 26 | * @param numTrees Number of trees to build. The algorithm can also determine 27 | * the optimal number of trees if a validation data is provided. 28 | * @param maxTreeDepth Maximum tree depth allowed per tree. 29 | * @param minSplitSize Min split size allowed per tree. 30 | * @param lossFunction The loss function object to use. 
31 | * @param catSplitType How to split categorical features. 32 | * @param baggingRate Bagging rate. 33 | * @param baggingType Whether to bag with/without replacements. 34 | * @param shrinkage Shrinkage. 35 | * @param fineTuneTerminalNodes Whether to fine-tune tree's terminal nodes so 36 | * that their values are directly optimizing 37 | * against the loss function. 38 | * @param checkpointDir Checkpoint directory. 39 | * @param predCheckpointInterval Intermediate prediction checkpointing interval. 40 | * @param idCacheCheckpointInterval Id cache checkpointing interval. 41 | * @param verbose If true, the algorithm will print as much information through 42 | * the notifiee as possible, including many intermediate 43 | * computation values, etc. 44 | */ 45 | case class GradientBoostingOptions( 46 | numTrees: Int, 47 | maxTreeDepth: Int, 48 | minSplitSize: Int, 49 | lossFunction: LossFunction, 50 | catSplitType: CatSplitType.CatSplitType, 51 | baggingRate: Double, 52 | baggingType: BaggingType.BaggingType, 53 | shrinkage: Double, 54 | fineTuneTerminalNodes: Boolean, 55 | checkpointDir: Option[String], 56 | predCheckpointInterval: Int, 57 | idCacheCheckpointInterval: Int, 58 | verbose: Boolean) { 59 | override def toString: String = { 60 | "=========================" + "\n" + 61 | "Gradient Boosting Options" + "\n" + 62 | "=========================" + "\n" + 63 | "numTrees : " + numTrees + "\n" + 64 | "maxTreeDepth : " + maxTreeDepth + "\n" + 65 | "minSplitSize : " + minSplitSize + "\n" + 66 | "lossFunction : " + lossFunction.getClass.getSimpleName + "\n" + 67 | "catSplitType : " + catSplitType.toString + "\n" + 68 | "baggingRate : " + baggingRate + "\n" + 69 | "baggingType : " + baggingType.toString + "\n" + 70 | "shrinkage : " + shrinkage + "\n" + 71 | "fineTuneTerminalNodes : " + fineTuneTerminalNodes + "\n" + 72 | "checkpointDir : " + (checkpointDir match { case None => "None" case Some(dir) => dir }) + "\n" + 73 | "predCheckpointInterval : " + predCheckpointInterval + "\n" + 74 | "idCacheCheckpointInterval : " + idCacheCheckpointInterval + "\n" + 75 | "verbose : " + verbose + "\n" 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/PredCache.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package spark_ml.gradient_boosting 19 | 20 | import scala.collection.mutable 21 | 22 | import org.apache.hadoop.fs.{FileSystem, Path} 23 | import org.apache.spark.rdd.RDD 24 | import org.apache.spark.storage.StorageLevel 25 | import spark_ml.model.gb.GBInternalTree 26 | import spark_ml.util.DiscretizedFeatureHandler 27 | 28 | object PredCache { 29 | def createPredCache( 30 | initPreds: RDD[Double], 31 | shrinkage: Double, 32 | storageLevel: StorageLevel, 33 | checkpointDir: Option[String], 34 | checkpointInterval: Int): PredCache = { 35 | new PredCache( 36 | curPreds = initPreds, 37 | shrinkage = shrinkage, 38 | storageLevel = storageLevel, 39 | checkpointDir = checkpointDir, 40 | checkpointInterval = checkpointInterval 41 | ) 42 | } 43 | } 44 | 45 | class PredCache( 46 | var curPreds: RDD[Double], 47 | shrinkage: Double, 48 | storageLevel: StorageLevel, 49 | checkpointDir: Option[String], 50 | checkpointInterval: Int) { 51 | private var prevPreds: RDD[Double] = null 52 | private var updateCount: Int = 0 53 | 54 | private val checkpointQueue = new mutable.Queue[RDD[Double]]() 55 | 56 | // Persist the initial predictions. 57 | curPreds = curPreds.persist(storageLevel) 58 | 59 | if (checkpointDir.isDefined && curPreds.sparkContext.getCheckpointDir.isEmpty) { 60 | curPreds.sparkContext.setCheckpointDir(checkpointDir.get) 61 | } 62 | 63 | def getRdd: RDD[Double] = curPreds 64 | 65 | def updatePreds[@specialized(Byte, Short) T]( 66 | discFeatData: RDD[Array[T]], 67 | tree: GBInternalTree, 68 | featureHandler: DiscretizedFeatureHandler[T]): Unit = { 69 | if (prevPreds != null) { 70 | // Unpersist the previous one if one exists. 71 | prevPreds.unpersist(blocking = true) 72 | } 73 | 74 | prevPreds = curPreds 75 | 76 | // Need to do this since we don't want to serialize this object. 77 | val shk = shrinkage 78 | curPreds = discFeatData.zip(curPreds).map { 79 | case (features, curPred) => 80 | curPred + shk * tree.predict(features, featureHandler) 81 | }.persist(storageLevel) 82 | 83 | updateCount += 1 84 | 85 | // Handle checkpointing if the directory is not None. 86 | if (curPreds.sparkContext.getCheckpointDir.isDefined && 87 | (updateCount % checkpointInterval) == 0) { 88 | // See if we can delete previous checkpoints. 89 | var canDelete = true 90 | while (checkpointQueue.size > 1 && canDelete) { 91 | // We can delete the oldest checkpoint iff the next checkpoint actually 92 | // exists in the file system. 93 | if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) { 94 | val old = checkpointQueue.dequeue() 95 | 96 | // Since the old checkpoint is not deleted by Spark, we'll manually 97 | // delete it here. 98 | val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) 99 | println("Deleting a stale PredCache RDD checkpoint at " + old.getCheckpointFile.get) 100 | fs.delete(new Path(old.getCheckpointFile.get), true) 101 | } else { 102 | canDelete = false 103 | } 104 | } 105 | 106 | curPreds.checkpoint() 107 | checkpointQueue.enqueue(curPreds) 108 | } 109 | } 110 | 111 | def close(): Unit = { 112 | // Unpersist and delete all the checkpoints. 
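 // Note: Spark does not remove RDD checkpoint files when an RDD is
 // unpersisted, so any checkpoints still tracked in checkpointQueue are also
 // deleted from the file system explicitly below.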
113 | curPreds.unpersist(blocking = true) 114 | if (prevPreds != null) { 115 | prevPreds.unpersist(blocking = true) 116 | } 117 | 118 | while (checkpointQueue.nonEmpty) { 119 | val old = checkpointQueue.dequeue() 120 | if (old.getCheckpointFile.isDefined) { 121 | val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) 122 | println("Deleting a stale PredCache RDD checkpoint at " + old.getCheckpointFile.get) 123 | fs.delete(new Path(old.getCheckpointFile.get), true) 124 | } 125 | } 126 | } 127 | } 128 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/LossAggregator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss 19 | 20 | /** 21 | * Trait for aggregators used for loss calculations in gradient boosting 22 | * algorithms. The gradient boosting algorithm supports pluggable aggregators. 23 | * The implementations for different loss functions determine the type of 24 | * statistics to aggregate. 25 | */ 26 | trait LossAggregator extends Serializable { 27 | /** 28 | * Add a sample to the aggregator. 29 | * @param label Label of the sample. 30 | * @param weight Weight of the sample. 31 | * @param curPred Current prediction 32 | * (should be computed using available trees). 33 | */ 34 | def addSamplePoint( 35 | label: Double, 36 | weight: Double, 37 | curPred: Double): Unit 38 | 39 | /** 40 | * Compute the gradient of the sample at the current prediction. 41 | * @param label Label of the sample. 42 | * @param curPred Current prediction 43 | * (should be computed using available trees). 44 | */ 45 | def computeGradient( 46 | label: Double, 47 | curPred: Double): Double 48 | 49 | /** 50 | * Merge the aggregated values with another aggregator. 51 | * @param b The other aggregator to merge with. 52 | * @return This. 53 | */ 54 | def mergeInPlace(b: LossAggregator): this.type 55 | 56 | /** 57 | * Using the aggregated values, compute deviance. 58 | * @return Deviance. 59 | */ 60 | def computeDeviance(): Double 61 | 62 | /** 63 | * Using the aggregated values, compute the initial value. 64 | * @return Inital value. 65 | */ 66 | def computeInitialValue(): Double 67 | 68 | /** 69 | * Using the aggregated values for a particular node, 70 | * compute the estimated node value. 71 | * @return Node estimate. 
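 *
 * For example, among the default implementations, the Laplacian aggregator
 * returns an approximate median of the label-prediction differences collected
 * for the node, while the AdaBoost aggregator returns a single Newton-Raphson
 * step toward the exponential-loss optimum.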
72 | */ 73 | def computeNodeEstimate(): Double 74 | } 75 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/LossFunction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss 19 | 20 | /** 21 | * Each loss function should implement this trait and pass it onto the gradient 22 | * boosting algorithm. The job is currently mainly to provide proper aggregators 23 | * for computing losses and gradients. 24 | */ 25 | trait LossFunction extends Serializable { 26 | /** 27 | * String name of the loss function. 28 | * @return The string name of the loss function. 29 | */ 30 | def lossFunctionName: String 31 | 32 | /** 33 | * Create the loss aggregator for this loss function. 34 | * @return 35 | */ 36 | def createAggregator: LossAggregator 37 | 38 | /** 39 | * If this loss function is used for categorical labels, this function returns 40 | * the expected label cardinality. E.g., loss functions like AdaBoost, 41 | * logistic losses are used for binary classification, so this should return 42 | * Some(2). For regressions like the Gaussian loss, this should return None. 43 | * @return either Some(cardinality) or None. 44 | */ 45 | def getLabelCardinality: Option[Int] 46 | 47 | /** 48 | * Whether tree node estimate refinement is possible. 49 | * @return true if node estimates can be refined. false otherwise. 50 | */ 51 | def canRefineNodeEstimate: Boolean 52 | 53 | /** 54 | * Convert the raw tree ensemble prediction into a usable form by applying 55 | * the mean function. E.g., this gives the actual regression prediction and/or 56 | * a probability of a class. 57 | * @param rawPred The raw tree ensemble prediction. E.g., this could be 58 | * unbounded negative or positive numbers for 59 | * AdaaBoost/LogLoss/PoissonLoss. We want to return bounded 60 | * numbers or actual count estimate for those losses, for 61 | * instance. 62 | * @return A mean function applied value. 63 | */ 64 | def applyMeanFunction(rawPred: Double): Double 65 | } 66 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/AdaBoostLossAggregator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossAggregator 21 | import spark_ml.util.RobustMath 22 | 23 | /** 24 | * The AdaBoost loss aggregator. This is used for binary classifications. 25 | * Gradient boosting's AdaBoost is an approximation of the original AdaBoosting 26 | * in that the exponential loss is reduced through gradient steps (original 27 | * AdaBoost has a different optimization routine). 28 | */ 29 | class AdaBoostLossAggregator extends LossAggregator { 30 | private var weightSum: Double = 0.0 31 | private var weightedLabelSum: Double = 0.0 32 | private var weightedLossSum: Double = 0.0 33 | private var weightedNumeratorSum: Double = 0.0 34 | 35 | // If we want to show log loss. 36 | // private var weightedLogLossSum: Double = 0.0 37 | 38 | /** 39 | * Add a sample to the aggregator. 40 | * @param label Label of the sample. 41 | * @param weight Weight of the sample. 42 | * @param curPred Current prediction (e.g., using available trees). 43 | */ 44 | def addSamplePoint( 45 | label: Double, 46 | weight: Double, 47 | curPred: Double): Unit = { 48 | val a = 2 * label - 1.0 49 | val sampleLoss = math.exp(-a * curPred) 50 | 51 | weightSum += weight 52 | weightedLabelSum += weight * label 53 | weightedLossSum += weight * sampleLoss 54 | weightedNumeratorSum += weight * a * sampleLoss 55 | 56 | // val prob = 1.0 / (1.0 + math.exp(-2.0 * curPred)) 57 | // val logLoss = -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob)) 58 | // weightedLogLossSum += weight * logLoss 59 | } 60 | 61 | /** 62 | * Compute the gradient. 63 | * @param label Label of the sample. 64 | * @param curPred Current prediction (e.g., using available trees). 65 | * @return The gradient of the sample. 66 | */ 67 | def computeGradient( 68 | label: Double, 69 | curPred: Double): Double = { 70 | val a = 2 * label - 1.0 71 | a * math.exp(-a * curPred) 72 | } 73 | 74 | /** 75 | * Merge the aggregated values with another aggregator. 76 | * @param b The other aggregator to merge with. 77 | * @return This. 78 | */ 79 | def mergeInPlace(b: LossAggregator): this.type = { 80 | this.weightSum += b.asInstanceOf[AdaBoostLossAggregator].weightSum 81 | this.weightedLabelSum += b.asInstanceOf[AdaBoostLossAggregator].weightedLabelSum 82 | this.weightedLossSum += b.asInstanceOf[AdaBoostLossAggregator].weightedLossSum 83 | this.weightedNumeratorSum += b.asInstanceOf[AdaBoostLossAggregator].weightedNumeratorSum 84 | // this.weightedLogLossSum += b.asInstanceOf[AdaBoostLossAggregator].weightedLogLossSum 85 | this 86 | } 87 | 88 | /** 89 | * Using the aggregated values, compute deviance. 90 | * @return Deviance. 91 | */ 92 | def computeDeviance(): Double = { 93 | weightedLossSum / weightSum 94 | // weightedLogLossSum / weightSum 95 | } 96 | 97 | /** 98 | * Using the aggregated values, compute the initial value. 99 | * @return Inital value. 
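 *
 * With 0/1 labels this is half the log-odds of the positive class: if p is
 * the weighted fraction of positive labels, the returned value is
 * log(p / (1 - p)) / 2.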
100 | */ 101 | def computeInitialValue(): Double = { 102 | RobustMath.log(weightedLabelSum / (weightSum - weightedLabelSum)) / 2.0 103 | } 104 | 105 | /** 106 | * Using the aggregated values for a particular node, compute the estimated 107 | * node value. 108 | * @return Node estimate. 109 | */ 110 | def computeNodeEstimate(): Double = { 111 | // For the adaboost loss (exponential loss), the node estimate is 112 | // approximated as one Newton-Raphson method step's result, as the optimal 113 | // value will be either negative or positive infinities. 114 | weightedNumeratorSum / weightedLossSum 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/AdaBoostLossFunction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossFunction 21 | 22 | /** 23 | * The AdaBoost loss function. This is used for binary classifications. 24 | * Gradient boosting's AdaBoost is an approximation of the original AdaBoosting 25 | * in that the exponential loss is reduced through gradient steps (original 26 | * AdaBoost has a different optimization routine). 27 | */ 28 | class AdaBoostLossFunction extends LossFunction { 29 | private val eps = 1e-15 30 | 31 | def lossFunctionName = "AdaBoost(Exponential)" 32 | def createAggregator = new AdaBoostLossAggregator 33 | def getLabelCardinality: Option[Int] = Some(2) 34 | def canRefineNodeEstimate: Boolean = true 35 | 36 | def applyMeanFunction(rawPred: Double): Double = { 37 | math.min(math.max(1.0 / (1.0 + math.exp(-2.0 * rawPred)), eps), 1.0 - eps) 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/GaussianLossAggregator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossAggregator 21 | 22 | /** 23 | * Aggregator for Gaussian Losses. 24 | */ 25 | class GaussianLossAggregator extends LossAggregator { 26 | private var weightSum: Double = 0.0 27 | private var weightedSqrLossSum = 0.0 28 | private var weightedLabelSum = 0.0 29 | 30 | /** 31 | * Add a sample to the aggregator. 32 | * @param label Label of the sample. 33 | * @param weight Weight of the sample. 34 | * @param curPred Current prediction (e.g., using available trees). 35 | */ 36 | def addSamplePoint( 37 | label: Double, 38 | weight: Double, 39 | curPred: Double): Unit = { 40 | val gradient = computeGradient(label, curPred) 41 | val weightedGradient = weight * gradient 42 | weightSum += weight 43 | weightedSqrLossSum += weightedGradient * gradient 44 | weightedLabelSum += weight * label 45 | } 46 | 47 | /** 48 | * Compute the gaussian gradient. 49 | * @param label Label of the sample. 50 | * @param curPred Current prediction (e.g., using available trees). 51 | * @return The gradient of the sample. 52 | */ 53 | def computeGradient( 54 | label: Double, 55 | curPred: Double): Double = { 56 | label - curPred 57 | } 58 | 59 | /** 60 | * Merge the aggregated values with another aggregator. 61 | * @param b The other aggregator to merge with. 62 | * @return This. 63 | */ 64 | def mergeInPlace(b: LossAggregator): this.type = { 65 | this.weightSum += b.asInstanceOf[GaussianLossAggregator].weightSum 66 | this.weightedSqrLossSum += b.asInstanceOf[GaussianLossAggregator].weightedSqrLossSum 67 | this.weightedLabelSum += b.asInstanceOf[GaussianLossAggregator].weightedLabelSum 68 | this 69 | } 70 | 71 | /** 72 | * Using the aggregated values, compute deviance. 73 | * @return Deviance. 74 | */ 75 | def computeDeviance(): Double = { 76 | weightedSqrLossSum / weightSum 77 | } 78 | 79 | /** 80 | * Using the aggregated values, compute the initial value. 81 | * @return Inital value. 82 | */ 83 | def computeInitialValue(): Double = { 84 | weightedLabelSum / weightSum 85 | } 86 | 87 | /** 88 | * Using the aggregated values for a particular node, compute the estimated 89 | * node value. 90 | * @return Node estimate. 91 | */ 92 | def computeNodeEstimate(): Double = ??? 93 | } 94 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/GaussianLossFunction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossFunction 21 | 22 | /** 23 | * The gaussian loss function. 24 | */ 25 | class GaussianLossFunction extends LossFunction { 26 | def lossFunctionName = "Gaussian" 27 | def createAggregator = new GaussianLossAggregator 28 | def getLabelCardinality: Option[Int] = None 29 | def canRefineNodeEstimate: Boolean = false 30 | 31 | def applyMeanFunction(rawPred: Double): Double = rawPred 32 | } 33 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/LaplacianLossAggregator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import scala.util.Random 21 | 22 | import spark_ml.gradient_boosting.loss.LossAggregator 23 | import spark_ml.util.{Selection, ReservoirSample} 24 | 25 | /** 26 | * Aggregator for Laplacian Losses. 27 | * It uses approximate medians for initial value estimates but that is not 28 | * likely to affect the performance drastically. 29 | */ 30 | class LaplacianLossAggregator extends LossAggregator { 31 | private val maxSample = 10000 32 | 33 | private var labelSample = new ReservoirSample(maxSample) 34 | private var labelPredDiffSample = new ReservoirSample(maxSample) 35 | 36 | private var weightSum: Double = 0.0 37 | private var weightedAbsLabelPredDiffSum: Double = 0.0 38 | 39 | // For reservoir sampling. 40 | @transient private var rng: Random = null 41 | 42 | /** 43 | * Add a sample to the aggregator. 44 | * @param label Label of the sample. 45 | * @param weight Weight of the sample. 46 | * @param curPred Current prediction (e.g., using available trees). 
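 *
 * A hypothetical call (values are illustrative only; zero-weight points are
 * ignored by the guard below):
 * {{{
 *   aggregator.addSamplePoint(label = 3.0, weight = 1.0, curPred = 2.5)
 *   // Feeds both reservoir samples: the label (3.0) and the
 *   // label-minus-prediction difference (0.5) used later for the medians.
 * }}}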
47 | */ 48 | def addSamplePoint( 49 | label: Double, 50 | weight: Double, 51 | curPred: Double): Unit = { 52 | if (weight > 0.0) { 53 | if (this.rng == null) { 54 | this.rng = new Random() 55 | } 56 | 57 | this.labelSample.doReservoirSampling(label, rng) 58 | 59 | val labelPredDiff = label - curPred 60 | this.labelPredDiffSample.doReservoirSampling(labelPredDiff, rng) 61 | 62 | this.weightSum += weight 63 | this.weightedAbsLabelPredDiffSum += weight * math.abs(labelPredDiff) 64 | } 65 | } 66 | 67 | /** 68 | * Compute the gradient. 69 | * @param label Label of the sample. 70 | * @param curPred Current prediction (e.g., using available trees). 71 | * @return The gradient of the sample. 72 | */ 73 | def computeGradient( 74 | label: Double, 75 | curPred: Double): Double = { 76 | math.signum(label - curPred) 77 | } 78 | 79 | /** 80 | * Merge the aggregated values with another aggregator. 81 | * @param b The other aggregator to merge with. 82 | * @return This. 83 | */ 84 | def mergeInPlace(b: LossAggregator): this.type = { 85 | if (this.rng == null) { 86 | this.rng = new Random() 87 | } 88 | 89 | this.weightSum += b.asInstanceOf[LaplacianLossAggregator].weightSum 90 | this.weightedAbsLabelPredDiffSum += b.asInstanceOf[LaplacianLossAggregator].weightedAbsLabelPredDiffSum 91 | // Now merge samples (distributed reservoir sampling). 92 | this.labelSample = ReservoirSample.mergeReservoirSamples( 93 | this.labelSample, 94 | b.asInstanceOf[LaplacianLossAggregator].labelSample, 95 | maxSample, 96 | this.rng 97 | ) 98 | this.labelPredDiffSample = ReservoirSample.mergeReservoirSamples( 99 | this.labelPredDiffSample, 100 | b.asInstanceOf[LaplacianLossAggregator].labelPredDiffSample, 101 | maxSample, 102 | this.rng 103 | ) 104 | this 105 | } 106 | 107 | /** 108 | * Using the aggregated values, compute deviance. 109 | * @return Deviance. 110 | */ 111 | def computeDeviance(): Double = { 112 | weightedAbsLabelPredDiffSum / weightSum 113 | } 114 | 115 | /** 116 | * Using the aggregated values, compute the initial value. 117 | * @return Inital value. 118 | */ 119 | def computeInitialValue(): Double = { 120 | if (this.rng == null) { 121 | this.rng = new Random() 122 | } 123 | 124 | // Get the initial value by computing the label median. 125 | Selection.quickSelect( 126 | this.labelSample.sample, 127 | 0, 128 | this.labelSample.numSamplePoints, 129 | this.labelSample.numSamplePoints / 2, 130 | rng = this.rng 131 | ) 132 | } 133 | 134 | /** 135 | * Using the aggregated values for a particular node, compute the estimated 136 | * node value. 137 | * @return Node estimate. 138 | */ 139 | def computeNodeEstimate(): Double = { 140 | if (this.rng == null) { 141 | this.rng = new Random() 142 | } 143 | 144 | // Get the node estimate by computing the label pred diff median. 145 | Selection.quickSelect( 146 | this.labelPredDiffSample.sample, 147 | 0, 148 | this.labelPredDiffSample.numSamplePoints, 149 | this.labelPredDiffSample.numSamplePoints / 2, 150 | rng = this.rng 151 | ) 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/LaplacianLossFunction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossFunction 21 | 22 | /** 23 | * The Laplacian loss function. 24 | */ 25 | class LaplacianLossFunction extends LossFunction { 26 | def lossFunctionName = "Laplacian" 27 | def createAggregator = new LaplacianLossAggregator 28 | def getLabelCardinality: Option[Int] = None 29 | def canRefineNodeEstimate: Boolean = true 30 | 31 | def applyMeanFunction(rawPred: Double): Double = rawPred 32 | } 33 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/LogLossAggregator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossAggregator 21 | import spark_ml.util.RobustMath 22 | 23 | /** 24 | * The log loss aggregator. This is a binary logistic regression aggregator. 25 | */ 26 | class LogLossAggregator extends LossAggregator { 27 | private var weightSum: Double = 0.0 28 | private var weightedLabelSum: Double = 0.0 29 | private var weightedLossSum: Double = 0.0 30 | private var weightedProbSum: Double = 0.0 31 | private var weightedProbSquareSum: Double = 0.0 32 | 33 | private val eps = 1e-15 34 | 35 | /** 36 | * Add a sample to the aggregator. 37 | * @param label Label of the sample. 38 | * @param weight Weight of the sample. 39 | * @param curPred Current prediction (e.g., using available trees). 40 | */ 41 | def addSamplePoint( 42 | label: Double, 43 | weight: Double, 44 | curPred: Double): Unit = { 45 | val prob = math.min(math.max(1.0 / (1.0 + math.exp(-curPred)), eps), 1.0 - eps) 46 | val logLoss = -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob)) 47 | weightSum += weight 48 | weightedLabelSum += weight * label 49 | weightedLossSum += weight * logLoss 50 | weightedProbSum += weight * prob 51 | weightedProbSquareSum += weight * prob * prob 52 | } 53 | 54 | /** 55 | * Compute the gradient. 
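 * For the log loss the value returned below is label - p, where
 * p = 1 / (1 + exp(-curPred)) is clipped to [eps, 1 - eps]; this is the
 * pseudo-residual (the negative gradient of the loss with respect to the
 * raw prediction).
 *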
56 | * @param label Label of the sample. 57 | * @param curPred Current prediction (e.g., using available trees). 58 | * @return The gradient of the sample. 59 | */ 60 | def computeGradient( 61 | label: Double, 62 | curPred: Double): Double = { 63 | val prob = math.min(math.max(1.0 / (1.0 + math.exp(-curPred)), eps), 1.0 - eps) 64 | label - prob 65 | } 66 | 67 | /** 68 | * Merge the aggregated values with another aggregator. 69 | * @param b The other aggregator to merge with. 70 | * @return This. 71 | */ 72 | def mergeInPlace(b: LossAggregator): this.type = { 73 | this.weightSum += b.asInstanceOf[LogLossAggregator].weightSum 74 | this.weightedLabelSum += b.asInstanceOf[LogLossAggregator].weightedLabelSum 75 | this.weightedLossSum += b.asInstanceOf[LogLossAggregator].weightedLossSum 76 | this.weightedProbSum += b.asInstanceOf[LogLossAggregator].weightedProbSum 77 | this.weightedProbSquareSum += b.asInstanceOf[LogLossAggregator].weightedProbSquareSum 78 | this 79 | } 80 | 81 | /** 82 | * Using the aggregated values, compute deviance. 83 | * @return Deviance. 84 | */ 85 | def computeDeviance(): Double = { 86 | weightedLossSum / weightSum 87 | } 88 | 89 | /** 90 | * Using the aggregated values, compute the initial value. 91 | * @return Inital value. 92 | */ 93 | def computeInitialValue(): Double = { 94 | RobustMath.log(weightedLabelSum / (weightSum - weightedLabelSum)) 95 | } 96 | 97 | /** 98 | * Using the aggregated values for a particular node, compute the estimated 99 | * node value. 100 | * @return Node estimate. 101 | */ 102 | def computeNodeEstimate(): Double = { 103 | // For the log loss, the node estimate is approximated as one Newton-Raphson 104 | // method step's result, as the optimal value will be either negative or 105 | // positive infinities. 106 | (weightedLabelSum - weightedProbSum) / (weightedProbSum - weightedProbSquareSum) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/LogLossFunction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossFunction 21 | 22 | /** 23 | * The LogLoss loss function (also known as cross-entropy). This is equivalent 24 | * to a logistic regression. 
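 *
 * A small sanity check of the mean function (illustrative only):
 * {{{
 *   val loss = new LogLossFunction
 *   loss.applyMeanFunction(0.0)  // 0.5
 *   loss.getLabelCardinality     // Some(2), i.e. binary labels
 * }}}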
25 | */ 26 | class LogLossFunction extends LossFunction { 27 | private val eps = 1e-15 28 | 29 | def lossFunctionName = "LogLoss(LogisticRegression)" 30 | def createAggregator = new LogLossAggregator 31 | def getLabelCardinality: Option[Int] = Some(2) 32 | def canRefineNodeEstimate: Boolean = true 33 | 34 | def applyMeanFunction(rawPred: Double): Double = { 35 | math.min(math.max(1.0 / (1.0 + math.exp(-rawPred)), eps), 1.0 - eps) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/PoissonLossAggregator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossAggregator 21 | import spark_ml.util.RobustMath 22 | 23 | /** 24 | * Aggregator for Poisson Losses. 25 | */ 26 | class PoissonLossAggregator extends LossAggregator { 27 | private var weightSum: Double = 0.0 28 | private var weightedLabelPredMulSum = 0.0 29 | private var weightedExpPredSum: Double = 0.0 30 | private var weightedLabelSum: Double = 0.0 31 | 32 | /** 33 | * Add a sample to the aggregator. 34 | * @param label Label of the sample. 35 | * @param weight Weight of the sample. 36 | * @param curPred Current prediction (e.g., using available trees). 37 | */ 38 | def addSamplePoint( 39 | label: Double, 40 | weight: Double, 41 | curPred: Double): Unit = { 42 | val labelPredMul = label * curPred 43 | val expPred = RobustMath.exp(curPred) 44 | weightSum += weight 45 | weightedLabelPredMulSum += weight * labelPredMul 46 | weightedExpPredSum += weight * expPred 47 | weightedLabelSum += weight * label 48 | } 49 | 50 | /** 51 | * Compute the gradient. 52 | * @param label Label of the sample. 53 | * @param curPred Current prediction (e.g., using available trees). 54 | * @return The gradient of the sample. 55 | */ 56 | def computeGradient( 57 | label: Double, 58 | curPred: Double): Double = { 59 | val expPred = RobustMath.exp(curPred) 60 | label - expPred 61 | } 62 | 63 | /** 64 | * Merge the aggregated values with another aggregator. 65 | * @param b The other aggregator to merge with. 66 | * @return This. 
67 | */ 68 | def mergeInPlace(b: LossAggregator): this.type = { 69 | this.weightSum += b.asInstanceOf[PoissonLossAggregator].weightSum 70 | this.weightedLabelSum += b.asInstanceOf[PoissonLossAggregator].weightedLabelSum 71 | this.weightedLabelPredMulSum += b.asInstanceOf[PoissonLossAggregator].weightedLabelPredMulSum 72 | this.weightedExpPredSum += b.asInstanceOf[PoissonLossAggregator].weightedExpPredSum 73 | this 74 | } 75 | 76 | /** 77 | * Using the aggregated values, compute deviance. 78 | * @return Deviance. 79 | */ 80 | def computeDeviance(): Double = { 81 | -2.0 * (weightedLabelPredMulSum - weightedExpPredSum) / weightSum 82 | } 83 | 84 | /** 85 | * Using the aggregated values, compute the initial value. 86 | * @return Inital value. 87 | */ 88 | def computeInitialValue(): Double = { 89 | RobustMath.log(weightedLabelSum / weightSum) 90 | } 91 | 92 | /** 93 | * Using the aggregated values for a particular node, compute the estimated 94 | * node value. 95 | * @return Node estimate. 96 | */ 97 | def computeNodeEstimate(): Double = { 98 | RobustMath.log(weightedLabelSum / weightedExpPredSum) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/PoissonLossFunction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossFunction 21 | 22 | /** 23 | * The poisson loss function. 24 | */ 25 | class PoissonLossFunction extends LossFunction { 26 | private val expPredLowerLimit = math.exp(-19.0) 27 | private val expPredUpperLimit = math.exp(19.0) 28 | 29 | def lossFunctionName = "Poisson" 30 | def createAggregator = new PoissonLossAggregator 31 | def getLabelCardinality: Option[Int] = None 32 | def canRefineNodeEstimate: Boolean = true 33 | 34 | def applyMeanFunction(rawPred: Double): Double = { 35 | math.min(math.max(math.exp(rawPred), expPredLowerLimit), expPredUpperLimit) 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/TruncHingeLossAggregator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossAggregator 21 | import spark_ml.util.RobustMath 22 | 23 | class TruncHingeLossAggregator(maxProb: Double, eps: Double) extends LossAggregator { 24 | private var weightSum: Double = 0.0 25 | private var weightedLabelSum: Double = 0.0 26 | private var weightedLossSum: Double = 0.0 27 | 28 | private val minProb = 1.0 - maxProb 29 | private val maxLinearValue = RobustMath.log(maxProb / minProb) 30 | 31 | /** 32 | * Add a sample to the aggregator. 33 | * @param label Label of the sample. 34 | * @param weight Weight of the sample. 35 | * @param curPred Current prediction (e.g., using available trees). 36 | */ 37 | def addSamplePoint( 38 | label: Double, 39 | weight: Double, 40 | curPred: Double): Unit = { 41 | val positiveLoss = math.min(math.max(-(curPred - eps), 0.0), maxLinearValue) 42 | val negativeLoss = math.min(math.max(curPred + eps, 0.0), maxLinearValue) 43 | val loss = label * positiveLoss + (1.0 - label) * negativeLoss 44 | weightSum += weight 45 | weightedLabelSum += weight * label 46 | weightedLossSum += weight * loss 47 | } 48 | 49 | /** 50 | * Compute the gradient. 51 | * @param label Label of the sample. 52 | * @param curPred Current prediction (e.g., using available trees). 53 | * @return The gradient of the sample. 54 | */ 55 | def computeGradient( 56 | label: Double, 57 | curPred: Double): Double = { 58 | val positiveLoss = math.max(-(curPred - eps), 0.0) 59 | val positiveLossGradient = 60 | if (positiveLoss == 0.0 || positiveLoss > maxLinearValue) { 61 | 0.0 62 | } else { 63 | 1.0 64 | } 65 | val negativeLoss = math.max(curPred + eps, 0.0) 66 | val negativeLossGradient = 67 | if (negativeLoss == 0.0 || negativeLoss > maxLinearValue) { 68 | 0.0 69 | } else { 70 | -1.0 71 | } 72 | label * positiveLossGradient + (1.0 - label) * negativeLossGradient 73 | } 74 | 75 | /** 76 | * Merge the aggregated values with another aggregator. 77 | * @param b The other aggregator to merge with. 78 | * @return This. 79 | */ 80 | def mergeInPlace(b: LossAggregator): this.type = { 81 | this.weightSum += b.asInstanceOf[TruncHingeLossAggregator].weightSum 82 | this.weightedLabelSum += b.asInstanceOf[TruncHingeLossAggregator].weightedLabelSum 83 | this.weightedLossSum += b.asInstanceOf[TruncHingeLossAggregator].weightedLossSum 84 | this 85 | } 86 | 87 | /** 88 | * Using the aggregated values, compute deviance. 89 | * @return Deviance. 90 | */ 91 | def computeDeviance(): Double = { 92 | weightedLossSum / weightSum 93 | } 94 | 95 | /** 96 | * Using the aggregated values, compute the initial value. 97 | * @return Inital value. 98 | */ 99 | def computeInitialValue(): Double = { 100 | RobustMath.log(weightedLabelSum / (weightSum - weightedLabelSum)) 101 | } 102 | 103 | /** 104 | * Using the aggregated values for a particular node, compute the estimated 105 | * node value. 
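 * Note: node-estimate refinement is not supported for this loss
 * (TruncHingeLossFunction reports canRefineNodeEstimate = false), which is
 * why the implementation below is intentionally left undefined.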
106 | * @return Node estimate. 107 | */ 108 | def computeNodeEstimate(): Double = ??? 109 | } 110 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/gradient_boosting/loss/defaults/TruncHingeLossFunction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.gradient_boosting.loss.defaults 19 | 20 | import spark_ml.gradient_boosting.loss.LossFunction 21 | 22 | class TruncHingeLossFunction extends LossFunction { 23 | private var maxProb = 0.0 24 | private var eps = 0.0 25 | private var minProb = 0.0 26 | private var maxLinearValue = 0.0 27 | 28 | // Default values. 29 | setMaxProb(0.9) 30 | setEps(0.12) 31 | 32 | def setMaxProb(maxp: Double): Unit = { 33 | maxProb = maxp 34 | minProb = 1.0 - maxProb 35 | maxLinearValue = math.log(maxProb / minProb) 36 | } 37 | 38 | def setEps(e: Double): Unit = { 39 | eps = e 40 | } 41 | 42 | def lossFunctionName = "TruncHingeLoss" 43 | def createAggregator = new TruncHingeLossAggregator(maxProb, eps) 44 | def getLabelCardinality: Option[Int] = Some(2) 45 | def canRefineNodeEstimate: Boolean = false 46 | 47 | def applyMeanFunction(rawPred: Double): Double = { 48 | math.min(math.max(1.0 / (1.0 + math.exp(-rawPred)), minProb), maxProb) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/model/gb/GradientBoostedTrees.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package spark_ml.model.gb 19 | 20 | import spark_ml.gradient_boosting.loss.LossFunction 21 | import spark_ml.model.DecisionTree 22 | import spark_ml.transformation.ColumnTransformer 23 | 24 | case class GradientBoostedTreesDefault( 25 | lossFunctionClassName: String, 26 | labelTransformer: ColumnTransformer, 27 | featureTransformers: Array[ColumnTransformer], 28 | labelName: String, 29 | labelIsCat: Boolean, 30 | featureNames: Array[String], 31 | featureIsCat: Array[Boolean], 32 | sortedVarImportance: Seq[(String, java.lang.Double)], 33 | shrinkage: Double, 34 | initValue: Double, 35 | decisionTrees: Array[DecisionTree], 36 | optimalTreeCnt: Option[java.lang.Integer], 37 | trainingDevianceHistory: Seq[java.lang.Double], 38 | validationDevianceHistory: Option[Seq[java.lang.Double]]) extends GradientBoostedTrees { 39 | 40 | def lossFunction: LossFunction = { 41 | Class.forName(lossFunctionClassName).newInstance().asInstanceOf[LossFunction] 42 | } 43 | 44 | def numTrees = decisionTrees.length 45 | 46 | /** 47 | * The prediction is done on raw features. Internally, the model should 48 | * transform them to proper forms before predicting. 49 | * @param rawFeatures Raw features. 50 | * @param useOptimalTreeCnt A flag to indicate that we want to use optimal 51 | * number of trees. 52 | * @return The predicted value. 53 | */ 54 | def predict(rawFeatures: Seq[Any], useOptimalTreeCnt: Boolean): Double = { 55 | // Transform the raw features first. 56 | val transformedFeatures = 57 | featureTransformers.zip(rawFeatures).map { 58 | case (featTransformer, rawFeatVal) => 59 | if (rawFeatVal == null) { 60 | featTransformer.transform(null) 61 | } else { 62 | featTransformer.transform(rawFeatVal.toString) 63 | } 64 | } 65 | val predictedValue = 66 | ( 67 | if (useOptimalTreeCnt && optimalTreeCnt.isDefined) { 68 | decisionTrees.slice(0, optimalTreeCnt.get) 69 | } else { 70 | decisionTrees 71 | } 72 | ).foldLeft(initValue) { 73 | case (curPred, tree) => 74 | curPred + shrinkage * tree.predict(transformedFeatures) 75 | } 76 | 77 | lossFunction.applyMeanFunction(predictedValue) 78 | } 79 | } 80 | 81 | /** 82 | * A public model for the Gradient boosted trees that have been trained should 83 | * extend this trait. 84 | */ 85 | trait GradientBoostedTrees extends Serializable { 86 | def lossFunction: LossFunction 87 | def initValue: Double 88 | def numTrees: Int 89 | def optimalTreeCnt: Option[java.lang.Integer] 90 | 91 | /** 92 | * Sorted variable importance. 93 | * @return Sorted variable importance. 94 | */ 95 | def sortedVarImportance: Seq[(String, java.lang.Double)] 96 | 97 | /** 98 | * The prediction is done on raw features. Internally, the model should 99 | * transform them to proper forms before predicting. 100 | * @param rawFeatures Raw features. 101 | * @param useOptimalTreeCnt A flag to indicate that we want to use optimal 102 | * number of trees. 103 | * @return The predicted value. 104 | */ 105 | def predict(rawFeatures: Seq[Any], useOptimalTreeCnt: Boolean): Double 106 | } 107 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/model/gb/GradientBoostedTreesFactory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.model.gb 19 | 20 | import scala.collection.mutable 21 | 22 | import spark_ml.discretization.Bins 23 | import spark_ml.model.{DecisionTree, DecisionTreeNode, DecisionTreeUtil} 24 | import spark_ml.transformation.ColumnTransformer 25 | 26 | /** 27 | * The trainer will store the trained model using the internal types defined 28 | * in this file. If a developer wants to return a customized GBT type as a 29 | * result, then he/she can implement this factory trait and pass it onto the 30 | * trainer. 31 | */ 32 | trait GradientBoostedTreesFactory { 33 | /** 34 | * The GB tree ensemble factory needs to know how to transform label and 35 | * features from potentially categorical/string values to enumerated numeric 36 | * values. E.g. the final model might use enumerated categorical values to 37 | * perform predictions. 38 | * @param labelTransformer Label transformer. 39 | * @param featureTransformers Feature transformers. 40 | */ 41 | def setColumnTransformers( 42 | labelTransformer: ColumnTransformer, 43 | featureTransformers: Array[ColumnTransformer] 44 | ): Unit 45 | 46 | /** 47 | * The GB tree ensemble factory also needs to know label/feature names and 48 | * types. 49 | * @param labelName Label name. 50 | * @param labelIsCat Whether the label is categorical. 51 | * @param featureNames Feature names. 52 | * @param featureIsCat Whether the individual features are categorical. 53 | */ 54 | def setColumnNamesAndTypes( 55 | labelName: String, 56 | labelIsCat: Boolean, 57 | featureNames: Array[String], 58 | featureIsCat: Array[Boolean] 59 | ): Unit 60 | 61 | /** 62 | * Set the optimal tree count, as determined through validations. 63 | * @param optimalTreeCnt The optimal tree count. 64 | */ 65 | def setOptimalTreeCnt(optimalTreeCnt: Int): Unit 66 | 67 | /** 68 | * Set the training deviance history. 69 | * @param trainingDevianceHistory Training deviance history. 70 | */ 71 | def setTrainingDevianceHistory( 72 | trainingDevianceHistory: mutable.ListBuffer[Double] 73 | ): Unit 74 | 75 | /** 76 | * Set the validation deviance history. 77 | * @param validationDevianceHistory Validation deviance history. 78 | */ 79 | def setValidationDevianceHistory( 80 | validationDevianceHistory: mutable.ListBuffer[Double] 81 | ): Unit 82 | 83 | /** 84 | * Set the feature bins. 85 | * @param featureBins Feature bins. 86 | */ 87 | def setFeatureBins(featureBins: Array[Bins]): Unit 88 | 89 | /** 90 | * Create a final model that incorporates all the transformations and 91 | * discretizations and can predict on the raw feature values. 92 | * @param store The store that contains internally trained models. 93 | * @return A final GBT model. 94 | */ 95 | def createGradientBoostedTrees( 96 | store: GradientBoostedTreesStore 97 | ): GradientBoostedTrees 98 | } 99 | 100 | /** 101 | * A default implementation of the GBT factory. 
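 *
 * A minimal wiring sketch (all values hypothetical; the trainer normally
 * performs these calls before asking for the final model):
 * {{{
 *   val factory = new GradientBoostedTreesFactoryDefault
 *   factory.setColumnNamesAndTypes(
 *     labelName = "label",
 *     labelIsCat = true,
 *     featureNames = Array("f1", "f2"),
 *     featureIsCat = Array(false, true)
 *   )
 *   factory.setColumnTransformers(labelTransformer, featureTransformers)
 *   factory.setFeatureBins(featureBins)
 *   val model = factory.createGradientBoostedTrees(store)
 * }}}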
102 | */ 103 | class GradientBoostedTreesFactoryDefault extends GradientBoostedTreesFactory { 104 | var labelTransformer: Option[ColumnTransformer] = None 105 | var featureTransformers: Option[Array[ColumnTransformer]] = None 106 | var labelName: Option[String] = None 107 | var labelIsCat: Option[Boolean] = None 108 | var featureNames: Option[Array[String]] = None 109 | var featureIsCat: Option[Array[Boolean]] = None 110 | var optimalTreeCnt: Option[Int] = None 111 | var trainingDevianceHistory: Option[mutable.ListBuffer[Double]] = None 112 | var validationDevianceHistory: Option[mutable.ListBuffer[Double]] = None 113 | var featureBins: Option[Array[Bins]] = None 114 | 115 | def setColumnTransformers( 116 | labelTransformer: ColumnTransformer, 117 | featureTransformers: Array[ColumnTransformer] 118 | ): Unit = { 119 | this.labelTransformer = Some(labelTransformer) 120 | this.featureTransformers = Some(featureTransformers) 121 | } 122 | 123 | def setColumnNamesAndTypes( 124 | labelName: String, 125 | labelIsCat: Boolean, 126 | featureNames: Array[String], 127 | featureIsCat: Array[Boolean] 128 | ): Unit = { 129 | this.labelName = Some(labelName) 130 | this.labelIsCat = Some(labelIsCat) 131 | this.featureNames = Some(featureNames) 132 | this.featureIsCat = Some(featureIsCat) 133 | } 134 | 135 | def setOptimalTreeCnt(optimalTreeCnt: Int): Unit = { 136 | this.optimalTreeCnt = Some(optimalTreeCnt) 137 | } 138 | 139 | def setTrainingDevianceHistory( 140 | trainingDevianceHistory: mutable.ListBuffer[Double] 141 | ): Unit = { 142 | this.trainingDevianceHistory = Some(trainingDevianceHistory) 143 | } 144 | 145 | def setValidationDevianceHistory( 146 | validationDevianceHistory: mutable.ListBuffer[Double] 147 | ): Unit = { 148 | this.validationDevianceHistory = Some(validationDevianceHistory) 149 | } 150 | 151 | def setFeatureBins(featureBins: Array[Bins]): Unit = { 152 | this.featureBins = Some(featureBins) 153 | } 154 | 155 | def createGradientBoostedTrees(store: GradientBoostedTreesStore): GradientBoostedTrees = { 156 | val numFeatures = featureNames.get.length 157 | val featureImportance = Array.fill[Double](numFeatures)(0.0) 158 | val decisionTrees = store.trees.map { 159 | case (internalTree) => 160 | val treeNodes = mutable.Map[java.lang.Integer, DecisionTreeNode]() 161 | val _ = 162 | DecisionTreeUtil.createDecisionTreeNode( 163 | internalTree.nodes(1), 164 | internalTree.nodes, 165 | featureImportance, 166 | this.featureBins.get, 167 | treeNodes 168 | ) 169 | DecisionTree(treeNodes.toMap, internalTree.nodes.size) 170 | }.toArray 171 | 172 | GradientBoostedTreesDefault( 173 | lossFunctionClassName = store.lossFunction.getClass.getCanonicalName, 174 | labelTransformer = this.labelTransformer.get, 175 | featureTransformers = this.featureTransformers.get, 176 | labelName = this.labelName.get, 177 | labelIsCat = this.labelIsCat.get, 178 | featureNames = this.featureNames.get, 179 | featureIsCat = this.featureIsCat.get, 180 | sortedVarImportance = 181 | scala.util.Sorting.stableSort( 182 | featureNames.get.zip(featureImportance.map(new java.lang.Double(_))).toSeq, 183 | // We want to sort in a descending importance order. 
184 | (e1: (String, java.lang.Double), e2: (String, java.lang.Double)) => e1._2 > e2._2 185 | ), 186 | shrinkage = store.shrinkage, 187 | initValue = store.initVal, 188 | decisionTrees = decisionTrees, 189 | optimalTreeCnt = this.optimalTreeCnt.map(new java.lang.Integer(_)), 190 | trainingDevianceHistory = 191 | this.trainingDevianceHistory.get.map(new java.lang.Double(_)).toSeq, 192 | validationDevianceHistory = 193 | this.validationDevianceHistory.map(vdh => vdh.map(new java.lang.Double(_)).toSeq) 194 | ) 195 | } 196 | } 197 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/model/gb/GradientBoostedTreesStore.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.model.gb 19 | 20 | import scala.collection.mutable 21 | 22 | import spark_ml.gradient_boosting.loss.{LossAggregator, LossFunction} 23 | import spark_ml.tree_ensembles._ 24 | import spark_ml.util.{DiscretizedFeatureHandler, MapWithSequentialIntKeys} 25 | import spire.implicits._ 26 | 27 | /** 28 | * Used internally to store tree data while GB is training. This is not the 29 | * final trained model. 30 | */ 31 | class GBInternalTree extends Serializable { 32 | val nodes = mutable.Map[Int, NodeInfo]() 33 | val nodeAggregators = new MapWithSequentialIntKeys[LossAggregator]( 34 | initCapacity = 256 35 | ) 36 | 37 | /** 38 | * Add a new node to the tree. 39 | * @param nodeInfo node to add. 40 | */ 41 | def addNode(nodeInfo: NodeInfo): Unit = { 42 | // Sanity check ! 43 | assert( 44 | !nodes.contains(nodeInfo.nodeId), 45 | "A tree node with the Id " + nodeInfo.nodeId + " already exists." 46 | ) 47 | 48 | nodes.put(nodeInfo.nodeId, nodeInfo) 49 | } 50 | 51 | /** 52 | * Initialize aggregators for existing nodes. The aggregators are used to 53 | * aggregate samples per node to update node predictions, instead of using 54 | * predictions from CART training. 55 | */ 56 | def initNodeFinetuners(lossFunction: LossFunction): Unit = { 57 | val (startNodeId, endNodeId) = (nodes.keys.min, nodes.keys.max) 58 | // Sanity checks. 59 | assert( 60 | startNodeId == 1, 61 | "The starting root node must always have the Id 1. But we have " + startNodeId 62 | ) 63 | assert( 64 | nodes.size == (endNodeId - startNodeId + 1), 65 | "The number of nodes should equal " + (endNodeId - startNodeId + 1).toString + 66 | " but instead, we have " + nodes.size + " nodes." 67 | ) 68 | 69 | cfor(startNodeId)(_ <= endNodeId, _ + 1)( 70 | nodeId => nodeAggregators.put(nodeId, lossFunction.createAggregator) 71 | ) 72 | } 73 | 74 | /** 75 | * Add the given sample point to all the nodes that match the point. 
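 * The point is added to the root aggregator (node id 1) first and is then
 * pushed down the tree: at each split, the discretized bin id of the split
 * feature selects a child through chooseChildNode, and that child's
 * aggregator receives the point as well, until a leaf is reached.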
76 | * @param samplePoint Sample point to add. This includes the label, the 77 | * current prediction and the features. 78 | * @param weight Weight of the sample point. 79 | * @param featureHandler Feature handler. 80 | * @tparam T Type of the discretized feature. 81 | */ 82 | def addSamplePointToMatchingNodes[@specialized(Byte, Short) T]( 83 | samplePoint: ((Double, Double), Array[T]), 84 | weight: Double, 85 | featureHandler: DiscretizedFeatureHandler[T]): Unit = { 86 | val ((label, curPred), features) = samplePoint 87 | // First add the point to the root node. 88 | nodeAggregators.get(1).addSamplePoint( 89 | label = label, 90 | weight = weight, 91 | curPred = curPred 92 | ) 93 | var curNode = nodes(1) 94 | 95 | // Then add the point to all the matching descendants. 96 | while (curNode.splitInfo.nonEmpty) { 97 | val splitInfo = curNode.splitInfo.get 98 | val featId = splitInfo.featureId 99 | val binId = featureHandler.convertToInt(features(featId)) 100 | val childId = splitInfo.chooseChildNode(binId).nodeId 101 | 102 | // We can't directly use the child node contained in splitInfo. 103 | // That child node is not the same as the one contained in nodes and 104 | // is incomplete. 105 | curNode = nodes(childId) 106 | val aggregator = nodeAggregators.get(childId) 107 | aggregator.addSamplePoint( 108 | label = label, 109 | weight = weight, 110 | curPred = curPred 111 | ) 112 | } 113 | } 114 | 115 | /** 116 | * Update a node's prediction with the given one. 117 | * @param nodeId Id of the node whose prediction we want to update. 118 | * @param newPrediction The new prediction for the node. 119 | */ 120 | def updateNodePrediction(nodeId: Int, newPrediction: Double): Unit = { 121 | nodes(nodeId).prediction = newPrediction 122 | } 123 | 124 | /** 125 | * Predict on the given features. 126 | * @param features Features. 127 | * @tparam T Type of features. 128 | * @return Prediction result. 129 | */ 130 | def predict[@specialized(Byte, Short) T]( 131 | features: Array[T], 132 | featureHandler: DiscretizedFeatureHandler[T] 133 | ): Double = { 134 | var curNode = nodes(1) 135 | while (curNode.splitInfo.nonEmpty) { 136 | val splitInfo = curNode.splitInfo.get 137 | val featId = splitInfo.featureId 138 | val binId = featureHandler.convertToInt(features(featId)) 139 | 140 | // We can't directly use the child node contained in splitInfo. 141 | // That child node is not the same as the one contained in nodes and 142 | // is incomplete. 143 | val childId = splitInfo.chooseChildNode(binId).nodeId 144 | curNode = nodes(childId) 145 | } 146 | curNode.prediction 147 | } 148 | 149 | /** 150 | * Print a visual representation of the internal tree with ASCII. 151 | * @return A string representation of the internal tree. 152 | */ 153 | override def toString: String = { 154 | val treeStringBuilder = new mutable.StringBuilder() 155 | val queue = new mutable.Queue[Int]() 156 | queue.enqueue(1) 157 | while (queue.nonEmpty) { 158 | val nodeInfo = nodes(queue.dequeue()) 159 | treeStringBuilder. 160 | append("nodeId:"). 161 | append(nodeInfo.nodeId). 162 | append(",prediction:"). 163 | append(nodeInfo.prediction) 164 | if (nodeInfo.splitInfo.isDefined) { 165 | treeStringBuilder. 166 | append(",splitFeatureId:"). 167 | append(nodeInfo.splitInfo.get.featureId) 168 | nodeInfo.splitInfo.get match { 169 | case numericNodeSplitInfo: NumericNodeSplitInfo => 170 | treeStringBuilder. 171 | append(",splitBinId:"). 
172 | append(numericNodeSplitInfo.splitBinId) 173 | if (numericNodeSplitInfo.nanChildNode.isDefined) { 174 | treeStringBuilder. 175 | append(",hasNanChild") 176 | } 177 | case catNodeSplitInfo: CatNodeSplitInfo => 178 | treeStringBuilder. 179 | append(",splitMapping:") 180 | catNodeSplitInfo.binIdToChildNode.foreach { 181 | case (binId, childNodeId) => 182 | treeStringBuilder. 183 | append(";"). 184 | append(binId.toString + "->" + childNodeId.toString) 185 | } 186 | } 187 | 188 | queue.enqueue( 189 | nodeInfo.splitInfo.get.getOrderedChildNodes.map(_.nodeId).toSeq : _* 190 | ) 191 | } 192 | 193 | if (queue.nonEmpty && (nodes(queue.front).depth > nodeInfo.depth)) { 194 | treeStringBuilder.append("\n") 195 | } else { 196 | treeStringBuilder.append(" ") 197 | } 198 | } 199 | 200 | treeStringBuilder.toString() 201 | } 202 | } 203 | 204 | /** 205 | * Gradient boosted trees writer. 206 | * @param store The gradient boosted trees store object. 207 | */ 208 | class GradientBoostedTreesWriter(store: GradientBoostedTreesStore) 209 | extends TreeEnsembleWriter { 210 | var curTree: GBInternalTree = store.curTree 211 | 212 | /** 213 | * Write the node info to the currently active tree. 214 | * @param nodeInfo Node info to write. 215 | */ 216 | def writeNodeInfo(nodeInfo: NodeInfo): Unit = { 217 | curTree.addNode(nodeInfo) 218 | } 219 | } 220 | 221 | /** 222 | * Gradient boosted trees store. 223 | * @param lossFunction Loss function with which the gradient boosted trees are 224 | * trained. 225 | * @param initVal Initial prediction value for the model. 226 | * @param shrinkage Shrinkage value. 227 | */ 228 | class GradientBoostedTreesStore( 229 | val lossFunction: LossFunction, 230 | val initVal: Double, 231 | val shrinkage: Double 232 | ) extends TreeEnsembleStore { 233 | val trees: mutable.ArrayBuffer[GBInternalTree] = mutable.ArrayBuffer[GBInternalTree]() 234 | var curTree: GBInternalTree = null 235 | 236 | /** 237 | * Add a new tree. This tree becomes the new active tree. 238 | */ 239 | def initNewTree(): Unit = { 240 | curTree = new GBInternalTree 241 | trees += curTree 242 | } 243 | 244 | /** 245 | * Get a tree ensemble writer. 246 | * @return A tree ensemble writer. 247 | */ 248 | def getWriter: TreeEnsembleWriter = { 249 | new GradientBoostedTreesWriter(this) 250 | } 251 | 252 | /** 253 | * Get an internal tree. 254 | * @param idx Index of the tree to get. 255 | * @return An internal tree. 256 | */ 257 | def getInternalTree(idx: Int): GBInternalTree = { 258 | trees(idx) 259 | } 260 | } 261 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/model/rf/RandomForestStore.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.model.rf 19 | 20 | import scala.collection.mutable 21 | 22 | import spark_ml.discretization.Bins 23 | import spark_ml.model._ 24 | import spark_ml.tree_ensembles._ 25 | import spark_ml.util.MapWithSequentialIntKeys 26 | 27 | /** 28 | * A forest is simply a collection of equal-weight trees. 29 | * @param trees Trees in the forest. 30 | * @param splitCriteriaStr Tree split criteria (e.g. infogain or variance). 31 | * @param sortedVarImportance Sorted variable importance. 32 | * @param sampleCounts The number of training samples per tree. 33 | */ 34 | case class RandomForest( 35 | trees: Array[DecisionTree], 36 | splitCriteriaStr: String, 37 | sortedVarImportance: Seq[(String, java.lang.Double)], 38 | sampleCounts: Array[Long] 39 | ) { 40 | /** 41 | * Predict from the features. 42 | * @param features A double array of features. 43 | * @return Prediction(s) and corresponding weight(s) (e.g. probabilities or 44 | * variances of predictions, etc.) 45 | */ 46 | def predict(features: Array[Double]): Array[(Double, Double)] = { 47 | SplitCriteria.withName(splitCriteriaStr) match { 48 | case SplitCriteria.Classification_InfoGain => predictClass(features) 49 | case SplitCriteria.Regression_Variance => predictRegression(features) 50 | } 51 | } 52 | 53 | /** 54 | * Predict the class ouput from the given features - features should be in the 55 | * same order as the ones that the tree trained on. 56 | * @param features Feature values. 57 | * @return Predictions and their probabilities an array of (Double, Double) 58 | */ 59 | private def predictClass(features: Array[Double]): Array[(Double, Double)] = { 60 | val predictions = mutable.Map[Double, Double]() // Predicted label and its count. 61 | var treeId = 0 62 | while (treeId < trees.length) { 63 | val tree = trees(treeId) 64 | val prediction = tree.predict(features) 65 | predictions.getOrElseUpdate(prediction, 0) 66 | predictions(prediction) += 1.0 67 | 68 | treeId += 1 69 | } 70 | 71 | // Sort the predictions by the number of occurrences. 72 | // The first element has the highest number of occurrences. 73 | val sortedPredictions = predictions.toArray.sorted( 74 | Ordering.by[(Double, Double), Double](-_._2) 75 | ) 76 | sortedPredictions.map(p => (p._1, p._2 / trees.length.toDouble)) 77 | } 78 | 79 | /** 80 | * Predict a continuous output from the given features - features should be in 81 | * the same order as the ones that the tree trained on. 82 | * @param features Feature values. 83 | * @return Prediction and its variance (a single element array of (Double, Double)) 84 | */ 85 | private def predictRegression(features: Array[Double]): Array[(Double, Double)] = { 86 | var predictionSum = 0.0 87 | var predictionSqrSum = 0.0 88 | var treeId = 0 89 | while (treeId < trees.length) { 90 | val tree = trees(treeId) 91 | val prediction = tree.predict(features) 92 | predictionSum += prediction 93 | predictionSqrSum += prediction * prediction 94 | 95 | treeId += 1 96 | } 97 | 98 | val predAvg = predictionSum / trees.length.toDouble 99 | val predVar = predictionSqrSum / trees.length.toDouble - predAvg * predAvg 100 | Array[(Double, Double)]((predAvg, predVar)) 101 | } 102 | } 103 | 104 | class RFInternalTree extends Serializable { 105 | val nodes = mutable.Map[Int, NodeInfo]() 106 | 107 | /** 108 | * Add a new node to the tree. 109 | * @param nodeInfo node to add. 
110 | */ 111 | def addNode(nodeInfo: NodeInfo): Unit = { 112 | // Sanity check ! 113 | assert( 114 | !nodes.contains(nodeInfo.nodeId), 115 | "A tree node with the Id " + nodeInfo.nodeId + " already exists." 116 | ) 117 | 118 | nodes.put(nodeInfo.nodeId, nodeInfo) 119 | } 120 | } 121 | 122 | class RandomForestWriter(store: RandomForestStore) 123 | extends TreeEnsembleWriter { 124 | def writeNodeInfo(nodeInfo: NodeInfo): Unit = { 125 | if (!store.trees.contains(nodeInfo.treeId)) { 126 | store.trees.put(nodeInfo.treeId, new RFInternalTree) 127 | } 128 | store.trees.get(nodeInfo.treeId).addNode(nodeInfo) 129 | } 130 | } 131 | 132 | /** 133 | * A default random forest store. 134 | * @param splitCriteria The split criteria for trees. 135 | */ 136 | class RandomForestStore( 137 | splitCriteria: SplitCriteria.SplitCriteria, 138 | featureNames: Array[String], 139 | featureBins: Array[Bins] 140 | ) extends TreeEnsembleStore { 141 | val trees = new MapWithSequentialIntKeys[RFInternalTree]( 142 | initCapacity = 100 143 | ) 144 | 145 | def getWriter: TreeEnsembleWriter = { 146 | new RandomForestWriter(this) 147 | } 148 | 149 | def createRandomForest: RandomForest = { 150 | val featureImportance = Array.fill[Double](featureBins.length)(0.0) 151 | val (startTreeId, endTreeId) = this.trees.getKeyRange 152 | val decisionTrees = (startTreeId to endTreeId).map { 153 | case treeId => 154 | val internalTree = this.trees.get(treeId) 155 | val treeNodes = mutable.Map[java.lang.Integer, DecisionTreeNode]() 156 | val _ = 157 | DecisionTreeUtil.createDecisionTreeNode( 158 | internalTree.nodes(1), 159 | internalTree.nodes, 160 | featureImportance, 161 | this.featureBins, 162 | treeNodes 163 | ) 164 | DecisionTree(treeNodes.toMap, internalTree.nodes.size) 165 | }.toArray 166 | 167 | RandomForest( 168 | trees = decisionTrees, 169 | splitCriteriaStr = splitCriteria.toString, 170 | sortedVarImportance = 171 | scala.util.Sorting.stableSort( 172 | featureNames.zip(featureImportance.map(new java.lang.Double(_))).toSeq, 173 | // We want to sort in a descending importance order. 174 | (e1: (String, java.lang.Double), e2: (String, java.lang.Double)) => e1._2 > e2._2 175 | ), 176 | sampleCounts = decisionTrees.map{ _.nodes(1).nodeWeight.toLong } 177 | ) 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/transformation/DataTransformationUtils.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package spark_ml.transformation 19 | 20 | import org.apache.spark.rdd.RDD 21 | import org.apache.spark.sql.DataFrame 22 | import spark_ml.discretization.CardinalityOverLimitException 23 | 24 | /** 25 | * A class that encodes the transformation to a numeric value for a column. 26 | * @param distinctValToInt If this is defined, then the column value would be 27 | * transformed to a mapped integer. 28 | * @param maxCardinality If this is defined, then the column value would be 29 | * transformed into an integer value in 30 | * [0, maxCardinality) through hashing. 31 | * @return a transformed numeric value for a column value. If neither 32 | * distinctValToInt nor maxCardinality is defined, then the value 33 | * would be assumed to be a numeric value already and passed through. 34 | */ 35 | case class ColumnTransformer( 36 | // We are using java.lang.Integer so that this can be easily serialized by 37 | // Gson. 38 | distinctValToInt: Option[Map[String, java.lang.Integer]], 39 | maxCardinality: Option[java.lang.Integer]) { 40 | def transform(value: String): Double = { 41 | if (distinctValToInt.isDefined) { 42 | val nonNullValue = 43 | if (value == null) { 44 | "" 45 | } else { 46 | value 47 | } 48 | distinctValToInt.get(nonNullValue).toDouble 49 | } else if (maxCardinality.isDefined) { 50 | val nonNullValue = 51 | if (value == null) { 52 | "" 53 | } else { 54 | value 55 | } 56 | DataTransformationUtils.getSimpleHashedValue(nonNullValue, maxCardinality.get) 57 | } else { 58 | if (value == null) { 59 | Double.NaN 60 | } else { 61 | value.toDouble 62 | } 63 | } 64 | } 65 | } 66 | 67 | object DataTransformationUtils { 68 | /** 69 | * Convert the given data frame into an RDD of label, feature vector pairs. 70 | * All label/feature values are also converted into Double. 71 | * @param dataFrame Spark Data Frame that we want to convert. 72 | * @param labelColIndex Label column index. 73 | * @param catDistinctValToInt Categorical column distinct value to int maps. 74 | * This is used to map distinct string values to 75 | * numeric values (doubles). 76 | * @param colsToIgnoreIndices Indices of columns in the data frame to be 77 | * ignored (not used as features or label). 78 | * @param maxCatCardinality Maximum categorical cardinality we allow. If the 79 | * cardinality goes over this, feature hashing might 80 | * be used (or will simply throw an exception). 81 | * @param useFeatureHashing Whether feature hashing should be used on 82 | * categorical columns whose unique value counts 83 | * exceed the maximum cardinality. 84 | * @return An RDD of label/feature-vector pairs and column transformer definitions. 
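 *
 * A hypothetical invocation (indices, maps and the data frame are
 * illustrative only):
 * {{{
 *   val (labeledRdd, (labelTransformer, featureTransformers)) =
 *     DataTransformationUtils.transformDataFrameToLabelFeatureRdd(
 *       dataFrame = df,
 *       labelColIndex = 0,
 *       catDistinctValToInt = Map(2 -> Map("a" -> new java.lang.Integer(0))),
 *       colsToIgnoreIndices = Set(5),
 *       maxCatCardinality = 256,
 *       useFeatureHashing = false
 *     )
 * }}}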
85 | */ 86 | def transformDataFrameToLabelFeatureRdd( 87 | dataFrame: DataFrame, 88 | labelColIndex: Int, 89 | catDistinctValToInt: Map[Int, Map[String, java.lang.Integer]], 90 | colsToIgnoreIndices: Set[Int], 91 | maxCatCardinality: Int, 92 | useFeatureHashing: Boolean): (RDD[(Double, Array[Double])], (ColumnTransformer, Array[ColumnTransformer])) = { 93 | val transformedRDD = dataFrame.map(row => { 94 | val labelValue = row.get(labelColIndex) 95 | 96 | Tuple2( 97 | if (catDistinctValToInt.contains(labelColIndex)) { 98 | val nonNullLabelValue = 99 | if (labelValue == null) { 100 | "" 101 | } else { 102 | labelValue.toString 103 | } 104 | mapCategoricalValueToNumericValue( 105 | labelColIndex, 106 | nonNullLabelValue, 107 | catDistinctValToInt(labelColIndex), 108 | maxCatCardinality, 109 | useFeatureHashing 110 | ) 111 | } else { 112 | if (labelValue == null) { 113 | Double.NaN 114 | } else { 115 | labelValue.toString.toDouble 116 | } 117 | }, 118 | row.toSeq.zipWithIndex.flatMap { 119 | case (colVal, idx) => 120 | if (colsToIgnoreIndices.contains(idx) || (labelColIndex == idx)) { 121 | Array[Double]().iterator 122 | } else { 123 | if (catDistinctValToInt.contains(idx)) { 124 | val nonNullColVal = 125 | if (colVal == null) { 126 | "" 127 | } else { 128 | colVal.toString 129 | } 130 | Array( 131 | mapCategoricalValueToNumericValue( 132 | idx, 133 | nonNullColVal, 134 | catDistinctValToInt(idx), 135 | maxCatCardinality, 136 | useFeatureHashing 137 | ) 138 | ).iterator 139 | } else { 140 | val nonNullColVal = 141 | if (colVal == null) { 142 | Double.NaN 143 | } else { 144 | colVal.toString.toDouble 145 | } 146 | Array(nonNullColVal).iterator 147 | } 148 | } 149 | }.toArray 150 | ) 151 | }) 152 | 153 | val numCols = dataFrame.columns.length 154 | val colTransformers = Tuple2( 155 | if (catDistinctValToInt.contains(labelColIndex)) { 156 | if (catDistinctValToInt(labelColIndex).size <= maxCatCardinality) { 157 | ColumnTransformer(Some(catDistinctValToInt(labelColIndex)), None) 158 | } else { 159 | ColumnTransformer(None, Some(maxCatCardinality)) 160 | } 161 | } else { 162 | ColumnTransformer(None, None) 163 | }, 164 | (0 to (numCols - 1)).flatMap { 165 | case (colIdx) => 166 | if (colsToIgnoreIndices.contains(colIdx) || (labelColIndex == colIdx)) { 167 | Array[ColumnTransformer]().iterator 168 | } else { 169 | if (catDistinctValToInt.contains(colIdx)) { 170 | if (catDistinctValToInt(colIdx).size <= maxCatCardinality) { 171 | Array(ColumnTransformer(Some(catDistinctValToInt(colIdx)), None)).iterator 172 | } else { 173 | Array(ColumnTransformer(None, Some(maxCatCardinality))).iterator 174 | } 175 | } else { 176 | Array(ColumnTransformer(None, None)).iterator 177 | } 178 | } 179 | }.toArray 180 | ) 181 | 182 | (transformedRDD, colTransformers) 183 | } 184 | 185 | /** 186 | * Map a categorical value (string) to a numeric value (double). 187 | * @param categoryId Equal to the column index. 188 | * @param catValue Categorical value that we want to map to a numeric value. 189 | * @param catValueToIntMap A map from categorical value to integers. If the 190 | * cardinality of the category is less than or equal 191 | * to maxCardinality, this is used. 192 | * @param maxCardinality The maximum allowed cardinality. 193 | * @param useFeatureHashing Whether feature hashing should be performed if the 194 | * cardinality of the category exceeds the maximum 195 | * cardinality. 196 | * @return A mapped double value. 
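 *
 * A hypothetical example:
 * {{{
 *   val valueMap = Map(
 *     "a" -> new java.lang.Integer(0),
 *     "b" -> new java.lang.Integer(1)
 *   )
 *   // The map's cardinality (2) is within maxCardinality, so the mapped
 *   // integer is simply returned as a double.
 *   DataTransformationUtils.mapCategoricalValueToNumericValue(
 *     categoryId = 3,
 *     catValue = "b",
 *     catValueToIntMap = valueMap,
 *     maxCardinality = 16,
 *     useFeatureHashing = false
 *   )  // 1.0
 * }}}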
197 | */ 198 | def mapCategoricalValueToNumericValue( 199 | categoryId: Int, 200 | catValue: String, 201 | catValueToIntMap: Map[String, java.lang.Integer], 202 | maxCardinality: Int, 203 | useFeatureHashing: Boolean): Double = { 204 | if (catValueToIntMap.size <= maxCardinality) { 205 | catValueToIntMap(catValue).toDouble 206 | } else { 207 | if (useFeatureHashing) { 208 | getSimpleHashedValue(catValue, maxCardinality) 209 | } else { 210 | throw new CardinalityOverLimitException( 211 | "The categorical column with the index " + categoryId + 212 | " has a cardinality that exceeds the limit " + maxCardinality) 213 | } 214 | } 215 | } 216 | 217 | /** 218 | * Compute a simple hash of a categorical value. 219 | * @param catValue Categorical value that we want to compute a simple numeric 220 | * hash for. The hash value would be an integer between 0 and 221 | * (maxCardinality - 1). 222 | * @param maxCardinality The value of the hash is limited within 223 | * [0, maxCardinality). 224 | * @return A hashed value of double. 225 | */ 226 | def getSimpleHashedValue(catValue: String, maxCardinality: Int): Double = { 227 | ((catValue.hashCode.toLong - Int.MinValue.toLong) % maxCardinality.toLong).toDouble 228 | } 229 | } 230 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/transformation/DistinctValueCounter.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.transformation 19 | 20 | import scala.collection.mutable 21 | 22 | import org.apache.spark.sql.DataFrame 23 | 24 | object DistinctValueCounter { 25 | /** 26 | * Get distinct values of categorical columns from the given data frame. 27 | * @param dataFrame Data frame from whose columns we want to gather distinct 28 | * values from. 29 | * @param catColIndices Categorical column indices. 30 | * @param maxCardinality The maximum number of unique values per column. 31 | * @return A map of column index and distinct value sets. 32 | */ 33 | def getDistinctValues( 34 | dataFrame: DataFrame, 35 | catColIndices: Set[Int], 36 | maxCardinality: Int): Map[Int, mutable.Set[String]] = { 37 | dataFrame.mapPartitions( 38 | rowItr => { 39 | val distinctVals = mutable.Map[Int, mutable.Set[String]]() 40 | while (rowItr.hasNext) { 41 | val row = rowItr.next() 42 | row.toSeq.zipWithIndex.map { 43 | case (colVal, colIdx) => 44 | if (catColIndices.contains(colIdx)) { 45 | val colDistinctVals = distinctVals.getOrElseUpdate(colIdx, mutable.Set[String]()) 46 | 47 | // We don't care to count all the unique values if the distinct 48 | // count goes over the given limit. 
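                  // Note that the check below uses "<=", so a per-column set
                  // may grow to (maxCardinality + 1) entries; that one extra
                  // entry is what lets downstream code detect that a column
                  // went over the limit, and the reduce step below likewise
                  // truncates merged sets to (maxCardinality + 1) values.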
49 | if (colDistinctVals.size <= maxCardinality) { 50 | val nonNullColVal = 51 | if (colVal == null) { 52 | "" 53 | } else { 54 | colVal.toString 55 | } 56 | colDistinctVals.add(nonNullColVal) 57 | } 58 | } 59 | } 60 | } 61 | 62 | distinctVals.toIterator 63 | } 64 | ).reduceByKey { 65 | (colDistinctVals1, colDistinctVals2) => 66 | (colDistinctVals1 ++ colDistinctVals2).splitAt(maxCardinality + 1)._1 67 | }.collect().toMap 68 | } 69 | 70 | /** 71 | * Map a set of distinct values to an increasing non-negative numbers. 72 | * E.g., {'Women' -> 0, 'Men' -> 1}, etc. 73 | * @param distinctValues A set of distinct values for different columns (first 74 | * key is index to a column). 75 | * @param useEmptyString Whether an empty string should be used as a distinct 76 | * value. 77 | * @return A map of distinct values to integers for different columns (first 78 | * key is index to a column). 79 | */ 80 | def mapDistinctValuesToIntegers( 81 | distinctValues: Map[Int, mutable.Set[String]], 82 | useEmptyString: Boolean 83 | ): Map[Int, Map[String, java.lang.Integer]] = { 84 | distinctValues.map { 85 | case (colIndex, values) => 86 | colIndex -> values.filter(value => !(value == "" && !useEmptyString)).zipWithIndex.map { 87 | case (value, mappedVal) => value -> new java.lang.Integer(mappedVal) 88 | }.toMap 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/IdCache.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | import scala.collection.mutable 21 | 22 | import org.apache.hadoop.fs.{FileSystem, Path} 23 | import org.apache.spark.rdd.RDD 24 | import org.apache.spark.storage.StorageLevel 25 | import spark_ml.util.DiscretizedFeatureHandler 26 | import spire.implicits._ 27 | 28 | object IdCache { 29 | /** 30 | * Create a new Id cache object, filled with initial Ids (1's). 31 | * @param numTrees Number of trees we expect. 32 | * @param data A data RDD or an RDD that contains the same number of rows as 33 | * the data. 34 | * @param storageLevel Cache storage level. 35 | * @param checkpointDir Checkpoint directory for intermediate checkpointing. 36 | * @param checkpointInterval Checkpointing interval. 37 | * @tparam T Type of data rows. 38 | * @return A new IdCache object. 39 | */ 40 | def createIdCache[T]( 41 | numTrees: Int, 42 | data: RDD[T], 43 | storageLevel: StorageLevel, 44 | checkpointDir: Option[String], 45 | checkpointInterval: Int): IdCache = { 46 | new IdCache( 47 | // All the Ids start with '1', meaning that all the rows are 48 | // assigned to the root node. 
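      // (Later, updateIds will set an Id to 0 for any row that has reached a
      // terminal node in a given tree and no longer needs to be routed.)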
49 | curIds = data.map(_ => Array.fill[Int](numTrees)(1)), 50 | storageLevel = storageLevel, 51 | checkpointDir = checkpointDir, 52 | checkpointInterval = checkpointInterval 53 | ) 54 | } 55 | } 56 | 57 | /** 58 | * Id cache to keep track of Ids of rows to indicate which tree nodes that 59 | * rows to belong to. 60 | * @param curIds RDD of the current Ids per data row. Each row of the RDD 61 | * is an array of Ids, each element an Id for a tree. 62 | * @param storageLevel Cache storage level. 63 | * @param checkpointDir Checkpoint directory. 64 | * @param checkpointInterval Checkpoint interval. 65 | */ 66 | class IdCache( 67 | var curIds: RDD[Array[Int]], 68 | storageLevel: StorageLevel, 69 | checkpointDir: Option[String], 70 | checkpointInterval: Int) { 71 | private var prevIds: RDD[Array[Int]] = null 72 | private var updateCount: Int = 0 73 | 74 | // To keep track of last checkpointed RDDs. 75 | private val checkpointQueue = new mutable.Queue[RDD[Array[Int]]]() 76 | 77 | // Persist the initial Ids. 78 | curIds = curIds.persist(storageLevel) 79 | 80 | // If a checkpoint directory is given, and there's no prior checkpoint 81 | // directory, then set the checkpoint directory with the given one. 82 | if (checkpointDir.isDefined && curIds.sparkContext.getCheckpointDir.isEmpty) { 83 | curIds.sparkContext.setCheckpointDir(checkpointDir.get) 84 | } 85 | 86 | /** 87 | * Get the current Id RDD. 88 | * @return curIds RDD. 89 | */ 90 | def getRdd: RDD[Array[Int]] = curIds 91 | 92 | /** 93 | * Update Ids that are stored in the cache RDD. 94 | * @param data RDD of data rows needed to find the updated Ids. 95 | * @param idLookupForUpdaters Id updaters. 96 | * @param featureHandler Data row feature type handler. 97 | * @tparam T Type of feature. 98 | */ 99 | def updateIds[@specialized(Byte, Short) T]( 100 | data: RDD[((Double, Array[T]), Array[Byte])], 101 | idLookupForUpdaters: IdLookupForUpdaters, 102 | featureHandler: DiscretizedFeatureHandler[T]): Unit = { 103 | if (prevIds != null) { 104 | // Unpersist the previous one if one exists. 105 | prevIds.unpersist(blocking = true) 106 | } 107 | 108 | prevIds = curIds 109 | 110 | // Update Ids. 111 | curIds = data.zip(curIds).map { 112 | case (((label, features), baggedCounts), nodeIds) => 113 | val numTrees = nodeIds.length 114 | cfor(0)(_ < numTrees, _ + 1)( 115 | treeId => { 116 | val curNodeId = nodeIds(treeId) 117 | val rowCnt = baggedCounts(treeId) 118 | if (rowCnt > 0 && curNodeId != 0) { 119 | val idUpdater = idLookupForUpdaters.get( 120 | treeId = treeId, 121 | id = curNodeId 122 | ) 123 | if (idUpdater != null) { 124 | nodeIds(treeId) = idUpdater.updateId(features = features, featureHandler = featureHandler) 125 | } 126 | } 127 | } 128 | ) 129 | 130 | nodeIds 131 | }.persist(storageLevel) 132 | 133 | updateCount += 1 134 | 135 | // Handle checkpointing if the directory is not None. 136 | if (curIds.sparkContext.getCheckpointDir.isDefined && 137 | (updateCount % checkpointInterval) == 0) { 138 | // See if we can delete previous checkpoints. 139 | var canDelete = true 140 | while (checkpointQueue.size > 1 && canDelete) { 141 | // We can delete the oldest checkpoint iff the next checkpoint actually 142 | // exists in the file system. 143 | if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) { 144 | val old = checkpointQueue.dequeue() 145 | 146 | // Since the old checkpoint is not deleted by Spark, we'll manually 147 | // delete it here. 
148 | val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) 149 | println("Deleting a stale IdCache RDD checkpoint at " + old.getCheckpointFile.get) 150 | fs.delete(new Path(old.getCheckpointFile.get), true) 151 | } else { 152 | canDelete = false 153 | } 154 | } 155 | 156 | curIds.checkpoint() 157 | checkpointQueue.enqueue(curIds) 158 | } 159 | } 160 | 161 | /** 162 | * Unpersist all the RDDs stored internally. 163 | */ 164 | def close(): Unit = { 165 | // Unpersist and delete all the checkpoints. 166 | curIds.unpersist(blocking = true) 167 | if (prevIds != null) { 168 | prevIds.unpersist(blocking = true) 169 | } 170 | 171 | while (checkpointQueue.nonEmpty) { 172 | val old = checkpointQueue.dequeue() 173 | if (old.getCheckpointFile.isDefined) { 174 | val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) 175 | println("Deleting a stale IdCache RDD checkpoint at " + old.getCheckpointFile.get) 176 | fs.delete(new Path(old.getCheckpointFile.get), true) 177 | } 178 | } 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/IdLookup.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | import scala.reflect.ClassTag 21 | 22 | import spire.implicits._ 23 | 24 | /** 25 | * The Id lookup object will contain sequential Ids from startId to endId. 26 | * @param startId Start Id. 27 | * @param endId End Id. 28 | */ 29 | case class IdRange( 30 | var startId: Int, 31 | var endId: Int) 32 | 33 | /** 34 | * Each data point belongs to a particular tree's node and is assigned the Id of 35 | * the node (the Id is only for training). This structure is used to find any 36 | * object (e.g. aggregator) that corresponds to the node that the point belongs 37 | * to. 38 | * @param idRanges Id ranges for each tree. 39 | * @tparam T The type of the look up objects. 40 | */ 41 | abstract class IdLookup[T: ClassTag](idRanges: Array[IdRange]) 42 | extends Serializable { 43 | // This is the array that contains the corresponding objects for each Id. 44 | protected var lookUpObjs: Array[Array[T]] = null 45 | var objCnt: Int = 0 46 | 47 | // This is used to initialize the look up obj for tree/node pairs. 48 | protected def initLookUpObjs( 49 | initLookUpObj: (Int, Int) => T): Unit = { 50 | val numTrees = idRanges.length 51 | lookUpObjs = Array.fill[Array[T]](numTrees)(null) 52 | 53 | // Initialize the look up objects for tree/nodes. 54 | // There's one array per tree. 
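    // Trees whose Id range is null get no array at all, and get() returns
    // null for any Id that falls outside a tree's [startId, endId] range.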
55 | cfor(0)(_ < numTrees, _ + 1)( 56 | treeId => { 57 | val idRange = idRanges(treeId) 58 | if (idRange != null) { 59 | val numIds = idRange.endId - idRange.startId + 1 60 | lookUpObjs(treeId) = Array.fill[T](numIds)(null.asInstanceOf[T]) 61 | cfor(idRange.startId)(_ <= idRange.endId, _ + 1)( 62 | id => { 63 | val lookUpObjIdx = id - idRange.startId 64 | lookUpObjs(treeId)(lookUpObjIdx) = initLookUpObj(treeId, id) 65 | objCnt += 1 66 | } 67 | ) 68 | } 69 | } 70 | ) 71 | } 72 | 73 | /** 74 | * Get Id ranges. 75 | * @return Id ranges. 76 | */ 77 | def getIdRanges: Array[IdRange] = idRanges 78 | 79 | /** 80 | * Get a look up object for the given tree/node. 81 | * @param treeId Tree Id. 82 | * @param id Id representing the node. 83 | * During training, nodes can be assigned arbitrary Ids. 84 | * They are not necessarily the final model node Ids. 85 | * @return The corresponding look up object. 86 | */ 87 | def get(treeId: Int, id: Int): T = { 88 | val idRange = idRanges(treeId) 89 | if (idRange != null) { 90 | val startId = idRange.startId 91 | val endId = idRange.endId 92 | if (id >= startId && id <= endId) { 93 | lookUpObjs(treeId)(id - startId) 94 | } else { 95 | null.asInstanceOf[T] 96 | } 97 | } else { 98 | null.asInstanceOf[T] 99 | } 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/IdLookupForNodeStats.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | import scala.collection.mutable 21 | import scala.util.Random 22 | 23 | import spark_ml.discretization.Bins 24 | import spark_ml.util.{MapWithSequentialIntKeys, RandomSet} 25 | import spire.implicits._ 26 | 27 | /** 28 | * Id lookup for aggregating node statistics during training. 29 | * @param idRanges Id range for nodes to look up. 30 | * @param nodeDepths A map containing depths of all the nodes in the ranges. 31 | */ 32 | class IdLookupForNodeStats( 33 | idRanges: Array[IdRange], 34 | nodeDepths: Array[MapWithSequentialIntKeys[Int]] 35 | ) extends IdLookup[NodeStats](idRanges) { 36 | 37 | /** 38 | * Initialize this look up object. 39 | * @param treeType Tree type by the split criteria (e.g. classification based 40 | * on info-gain or regression based on variance.) 41 | * @param featureBinsInfo Feature discretization info. 42 | * @param treeSeeds Random seeds for trees. 43 | * @param mtry mtry. 44 | * @param numClasses Optional number of target classes (for classifications). 
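   *
   * A rough usage sketch (hypothetical; `idRanges`, `nodeDepths`,
   * `featureBins`, `seeds` and `criteria` stand for values prepared by the
   * trainer, and the mtry choice shown here is just one common heuristic):
   * {{{
   *   val lookup = new IdLookupForNodeStats(idRanges, nodeDepths)
   *   lookup.initNodeStats(
   *     treeType = criteria,
   *     featureBinsInfo = featureBins,
   *     treeSeeds = seeds,
   *     mtry = math.sqrt(featureBins.length.toDouble).toInt,
   *     numClasses = Some(2))
   * }}}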
45 | */ 46 | def initNodeStats( 47 | treeType: SplitCriteria.SplitCriteria, 48 | featureBinsInfo: Array[Bins], 49 | treeSeeds: Array[Int], 50 | mtry: Int, 51 | numClasses: Option[Int]): Unit = { 52 | val numFeatures = featureBinsInfo.length 53 | def createNodeStats(treeId: Int, id: Int): NodeStats = { 54 | val mtryFeatureIds = RandomSet.nChooseK( 55 | k = mtry, 56 | n = numFeatures, 57 | rnd = new Random(treeSeeds(treeId) + id) 58 | ) 59 | NodeStats.createNodeStats( 60 | treeId = treeId, 61 | nodeId = id, 62 | nodeDepth = nodeDepths(treeId).get(id), 63 | treeType = treeType, 64 | featureBinsInfo = featureBinsInfo, 65 | mtryFeatureIds = mtryFeatureIds, 66 | numClasses = numClasses 67 | ) 68 | } 69 | 70 | initLookUpObjs(createNodeStats) 71 | } 72 | 73 | /** 74 | * Get an iterator of all the nodestats. Each nodestats gets assigned an 75 | * incrementing hash value. This is useful to even distribute aggregated 76 | * nodestats to different machines to perform distributed splits. 77 | * @return An iterator of pairs of (hashValue, nodestats) and the count of 78 | * node stats. 79 | */ 80 | def toHashedNodeStatsIterator: (Iterator[(Int, NodeStats)], Int) = { 81 | val numTrees = idRanges.length 82 | var curHash = 0 83 | val output = new mutable.ListBuffer[(Int, NodeStats)]() 84 | cfor(0)(_ < numTrees, _ + 1)( 85 | treeId => { 86 | if (idRanges(treeId) != null) { 87 | val numIds = idRanges(treeId).endId - idRanges(treeId).startId + 1 88 | cfor(0)(_ < numIds, _ + 1)( 89 | i => { 90 | val nodeStats = lookUpObjs(treeId)(i) 91 | output += ((curHash, nodeStats)) 92 | curHash += 1 93 | } 94 | ) 95 | } 96 | } 97 | ) 98 | 99 | (output.toIterator, curHash) 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/IdLookupForSubTreeInfo.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | import spark_ml.util.MapWithSequentialIntKeys 21 | 22 | /** 23 | * Sub tree info used during sub-tree training. 24 | * @param id Id of the subtree. 25 | * @param hash Hash value for the sub-tree (used to push hash sub-tree data 26 | * evenly to different executors). 27 | * @param depth Depth of the sub-tree from the parent tree perspective. 28 | * @param parentTreeId Parent tree Id. 29 | */ 30 | case class SubTreeInfo( 31 | id: Int, 32 | hash: Int, 33 | depth: Int, 34 | parentTreeId: Int) { 35 | /** 36 | * Override the hashCode to return the subTreeHash value. 37 | * @return The subTreeHash value. 
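   *
   * For example, `SubTreeInfo(id = 5, hash = 2, depth = 3, parentTreeId = 0).hashCode`
   * is simply 2, so keying rows by their matching SubTreeInfo spreads
   * sub-tree training data across partitions according to the assigned hash.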
38 | */ 39 | override def hashCode: Int = { 40 | hash 41 | } 42 | } 43 | 44 | /** 45 | * Sub tree info lookup used to find matching data points for each sub tree. 46 | * @param idRanges Id ranges for each tree. 47 | */ 48 | class IdLookupForSubTreeInfo( 49 | idRanges: Array[IdRange] 50 | ) extends IdLookup[SubTreeInfo](idRanges) 51 | 52 | object IdLookupForSubTreeInfo { 53 | /** 54 | * Create a new Id lookup object for sub-trees. 55 | * @param idRanges Id ranges of sub trees for parent trees. 56 | * @param subTreeMaps Sub tree info maps for parent trees. 57 | * @return Id lookup object for sub-trees. 58 | */ 59 | def createIdLookupForSubTreeInfo( 60 | idRanges: Array[IdRange], 61 | subTreeMaps: Array[MapWithSequentialIntKeys[SubTreeInfo]]): IdLookupForSubTreeInfo = { 62 | val lookup = new IdLookupForSubTreeInfo(idRanges) 63 | lookup.initLookUpObjs( 64 | (treeId: Int, id: Int) => subTreeMaps(treeId).get(id) 65 | ) 66 | 67 | lookup 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/IdLookupForUpdaters.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | import spark_ml.util.{DiscretizedFeatureHandler, MapWithSequentialIntKeys} 21 | 22 | /** 23 | * Update the Id of the node that a row belongs to via the split information. 24 | * @param splitInfo split information that will determine the new Id for 25 | * the given features. 26 | */ 27 | case class IdUpdater(splitInfo: NodeSplitInfo) { 28 | def updateId[@specialized(Byte, Short) T]( 29 | features: Array[T], 30 | featureHandler: DiscretizedFeatureHandler[T]): Int = { 31 | if (splitInfo == null) { 32 | 0 // 0 indicates that the data point has reached a terminal node. 33 | } else { 34 | val binId = featureHandler.convertToInt(features(splitInfo.featureId)) 35 | 36 | // The nodeId here is usually not the same as the final tree's node Id. 37 | // This is more of a temporary value to refer split Ids during training. 38 | splitInfo.chooseChildNode(binId).nodeId 39 | } 40 | } 41 | } 42 | 43 | /** 44 | * A look up for Id updaters. This is used to update Ids of nodes that 45 | * data points belong to during training. 46 | */ 47 | class IdLookupForUpdaters( 48 | idRanges: Array[IdRange] 49 | ) extends IdLookup[IdUpdater](idRanges) 50 | 51 | object IdLookupForUpdaters { 52 | /** 53 | * Create a new Id lookup object for updaters. 54 | * @param idRanges Id ranges of updaters. 55 | * @param updaterMaps Maps of Id updaters. 56 | * @return Id lookup object for updaters. 
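   *
   * A condensed sketch (hypothetical; `ranges`, `maps`, `idCache`,
   * `baggedData` and `featureHandler` are assumed to have been built by the
   * trainer from the splits found in the previous iteration):
   * {{{
   *   val updaterLookup = IdLookupForUpdaters.createIdLookupForUpdaters(
   *     idRanges = ranges,
   *     updaterMaps = maps)
   *   idCache.updateIds(baggedData, updaterLookup, featureHandler)
   * }}}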
57 | */ 58 | def createIdLookupForUpdaters( 59 | idRanges: Array[IdRange], 60 | updaterMaps: Array[MapWithSequentialIntKeys[IdUpdater]] 61 | ): IdLookupForUpdaters = { 62 | val lookup = new IdLookupForUpdaters(idRanges) 63 | lookup.initLookUpObjs( 64 | (treeId: Int, id: Int) => updaterMaps(treeId).get(id) 65 | ) 66 | 67 | lookup 68 | } 69 | } -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/InfoGainNodeStats.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | import breeze.numerics.log2 21 | import spark_ml.util.DiscretizedFeatureHandler 22 | import spire.implicits._ 23 | 24 | /** 25 | * Information gain node statistics. 26 | * @param treeId Id of the tree that the node belongs to. 27 | * @param nodeId Id of the node. 28 | * @param nodeDepth Depth of the node. 29 | * @param statsArray The actual statistics array. 30 | * @param mtryFeatures Selected feature descriptions. 31 | * @param numElemsPerBin Number of statistical elements per feature bin. 32 | */ 33 | class InfoGainNodeStats( 34 | treeId: Int, 35 | nodeId: Int, 36 | nodeDepth: Int, 37 | statsArray: Array[Double], 38 | mtryFeatures: Array[SelectedFeatureInfo], 39 | numElemsPerBin: Int) extends NodeStats( 40 | treeId = treeId, 41 | nodeId = nodeId, 42 | nodeDepth = nodeDepth, 43 | statsArray = statsArray, 44 | mtryFeatures = mtryFeatures, 45 | numElemsPerBin = numElemsPerBin) { 46 | 47 | /** 48 | * Add statistics related to a sample (label and features). 49 | * @param label Label of the sample. 50 | * @param features Features of the sample. 51 | * @param sampleCnt Sample count of the sample. 52 | * @param featureHandler Feature type handler. 53 | * @tparam T Feature type (Byte or Short). 54 | */ 55 | override def addSample[@specialized(Byte, Short) T]( 56 | label: Double, 57 | features: Array[T], 58 | sampleCnt: Int, 59 | featureHandler: DiscretizedFeatureHandler[T]): Unit = { 60 | val mtry = mtryFeatures.length 61 | cfor(0)(_ < mtry, _ + 1)( 62 | i => { 63 | val featId = mtryFeatures(i).featId 64 | val featOffset = mtryFeatures(i).offset 65 | val binId = featureHandler.convertToInt(features(featId)) 66 | statsArray(featOffset + binId * numElemsPerBin + label.toInt) += 67 | sampleCnt 68 | } 69 | ) 70 | } 71 | 72 | /** 73 | * Calculate the node values from the given class distribution. This also 74 | * includes calculating the entropy. 75 | * @param classWeights The class weight distribution. 76 | * @param offset Starting offset. 77 | * @param output Where the output will be stored. 78 | * @return Node values. 
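   *
   * A small worked example: with two classes and class weights (3.0, 1.0)
   * starting at `offset`, the weight sum is 4.0, the prediction is class 0
   * (the heaviest class), the addendum (class probability) is
   * 3.0 / 4.0 = 0.75, and the impurity is the entropy
   * -(0.75 * log2(0.75) + 0.25 * log2(0.25)), roughly 0.811.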
79 | */ 80 | override def calculateNodeValues( 81 | classWeights: Array[Double], 82 | offset: Int, 83 | output: PreallocatedNodeValues): PreallocatedNodeValues = { 84 | var weightSum: Double = 0.0 85 | var prediction: Double = 0.0 86 | var maxClassWeight: Double = 0.0 87 | var entropy: Double = 0.0 88 | 89 | // Determine weightSum and prediction. 90 | cfor(0)(_ < numElemsPerBin, _ + 1)( 91 | labelId => { 92 | val weight = classWeights(offset + labelId) 93 | output.sumStats(labelId) = weight 94 | if (maxClassWeight < weight) { 95 | maxClassWeight = weight 96 | prediction = labelId 97 | } 98 | 99 | weightSum += weight 100 | } 101 | ) 102 | 103 | // Compute entropy. 104 | cfor(0)(_ < numElemsPerBin, _ + 1)( 105 | labelId => { 106 | val weight = classWeights(offset + labelId) 107 | if (weight > 0.0) { 108 | val prob = weight.toDouble / weightSum 109 | entropy -= prob * log2(prob) 110 | } 111 | } 112 | ) 113 | 114 | output.prediction = prediction 115 | output.addendum = maxClassWeight / weightSum // Probability of the class. 116 | output.weight = weightSum 117 | output.impurity = entropy 118 | 119 | output 120 | } 121 | 122 | /** 123 | * Get bin weights. 124 | * @param statsArray Stats array. 125 | * @param offset Start offset. 126 | * @param numBins Number of bins. 127 | * @param output This is where the weights will be stored. 128 | * @return Returns the same output that was passed in. 129 | */ 130 | override def getBinWeights( 131 | statsArray: Array[Double], 132 | offset: Int, 133 | numBins: Int, 134 | output: Array[Double]): Array[Double] = { 135 | cfor(0)(_ < numBins, _ + 1)( 136 | binId => { 137 | val binOffset = offset + binId * numElemsPerBin 138 | // Sum all the label weights per bin. 139 | cfor(0)(_ < numElemsPerBin, _ + 1)( 140 | labelId => output(binId) += statsArray(binOffset + labelId) 141 | ) 142 | } 143 | ) 144 | 145 | output 146 | } 147 | 148 | /** 149 | * Calculate the label average for the given bin. This is only meaningful for 150 | * binary classifications. 151 | * @param statsArray Stats array. 152 | * @param binOffset Offset to the bin. 153 | * @return The label average for the bin. 154 | */ 155 | override def getBinLabelAverage( 156 | statsArray: Array[Double], 157 | binOffset: Int): Double = { 158 | var labelSum = 0.0 159 | var weightSum = 0.0 160 | cfor(0)(_ < numElemsPerBin, _ + 1)( 161 | labelId => { 162 | val labelWeight = statsArray(binOffset + labelId) 163 | labelSum += labelId.toDouble * labelWeight 164 | weightSum += labelWeight 165 | } 166 | ) 167 | 168 | labelSum / weightSum 169 | } 170 | 171 | /** 172 | * This is for classification, so return true. 173 | * @return true 174 | */ 175 | override def forClassification: Boolean = true 176 | } 177 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/SubTreeStore.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | import scala.collection.mutable 21 | 22 | import spark_ml.util.Sorting 23 | import spire.implicits._ 24 | 25 | /** 26 | * This object contains the descriptions of all the nodes in this sub-tree. 27 | * @param parentTreeId Id of the parent tree. 28 | * @param subTreeId Sub tree Id. 29 | * @param subTreeDepth Depth of the sub-tree from the parent tree's perspective. 30 | */ 31 | class SubTreeDesc( 32 | val parentTreeId: Int, 33 | val subTreeId: Int, 34 | val subTreeDepth: Int) extends Serializable { 35 | // Store the nodes in a mutable map. 36 | var nodes = mutable.Map[Int, NodeInfo]() 37 | 38 | /** 39 | * Add a new trained node info object. 40 | * It's expected that the node Id's increment monotonically, one by one. 41 | * @param nodeInfo Info about the next trained node. 42 | */ 43 | def addNodeInfo(nodeInfo: NodeInfo): Unit = { 44 | assert( 45 | !nodes.contains(nodeInfo.nodeId), 46 | "A node with the node Id " + nodeInfo.nodeId + " already exists in the " + 47 | "sub tree " + subTreeId + " that belongs to the parent tree " + parentTreeId 48 | ) 49 | 50 | nodes.put(nodeInfo.nodeId, nodeInfo) 51 | } 52 | 53 | /** 54 | * Update the tree/node Ids and the node depth to reflect the parent tree's 55 | * reality. 56 | * @param rootId The root node Id will be updated to this. 57 | * @param startChildNodeId The descendant nodes will have updated Ids, 58 | * starting from this number. 59 | * @return The last descendant node Id that was used for this sub-tree. 60 | */ 61 | def updateIdsAndDepths( 62 | rootId: Int, 63 | startChildNodeId: Int): Int = { 64 | val itr = nodes.values.iterator 65 | val updatedNodes = mutable.Map[Int, NodeInfo]() 66 | var largestUpdatedNodeId = 0 67 | while (itr.hasNext) { 68 | val nodeInfo = itr.next() 69 | val updatedNodeInfo = nodeInfo.copy 70 | updatedNodeInfo.treeId = parentTreeId 71 | if (updatedNodeInfo.nodeId == 1) { 72 | updatedNodeInfo.nodeId = rootId 73 | } else { 74 | // Start node Id is used from the first child node of the root node, 75 | // which should have the Id of 2. Therefore, update the node Ids by 76 | // subtracting 2 and then adding startChildNodeId. 77 | updatedNodeInfo.nodeId = updatedNodeInfo.nodeId - 2 + startChildNodeId 78 | } 79 | largestUpdatedNodeId = math.max(updatedNodeInfo.nodeId, largestUpdatedNodeId) 80 | updatedNodeInfo.depth = updatedNodeInfo.depth - 1 + subTreeDepth 81 | if (updatedNodeInfo.splitInfo.nonEmpty) { 82 | val si = updatedNodeInfo.splitInfo.get 83 | val children = si.getOrderedChildNodes 84 | val numChildren = children.length 85 | cfor(0)(_ < numChildren, _ + 1)( 86 | i => { 87 | val child = children(i) 88 | child.treeId = parentTreeId 89 | child.nodeId = child.nodeId - 2 + startChildNodeId 90 | child.depth = child.depth - 1 + subTreeDepth 91 | } 92 | ) 93 | } 94 | 95 | updatedNodes.put(updatedNodeInfo.nodeId, updatedNodeInfo) 96 | } 97 | 98 | this.nodes = updatedNodes 99 | 100 | largestUpdatedNodeId 101 | } 102 | 103 | /** 104 | * Get a sequence of nodeInfo objects ordered by the node Id. 
105 | * @return A sequence of nodeInfo objects ordered by the node Id. 106 | */ 107 | def getOrderedNodeInfoSeq: Seq[NodeInfo] = { 108 | val nodeInfoArray = this.nodes.values.toArray 109 | Sorting.quickSort[NodeInfo](nodeInfoArray)( 110 | Ordering.by[NodeInfo, Int](_.nodeId) 111 | ) 112 | 113 | nodeInfoArray 114 | } 115 | } 116 | 117 | /** 118 | * Local subtree store. Used for locally training sub-trees. 119 | * @param parentTreeId Id of the parent tree. 120 | * @param subTreeId Sub tree Id. 121 | * @param subTreeDepth Depth of the sub-tree from the parent tree's perspective. 122 | */ 123 | class SubTreeStore( 124 | parentTreeId: Int, 125 | subTreeId: Int, 126 | subTreeDepth: Int) extends TreeEnsembleStore { 127 | val subTreeDesc = new SubTreeDesc( 128 | parentTreeId = parentTreeId, 129 | subTreeId = subTreeId, 130 | subTreeDepth = subTreeDepth 131 | ) 132 | 133 | /** 134 | * Get a sub tree writer. 135 | * @return Sub tree writer. 136 | */ 137 | def getWriter: TreeEnsembleWriter = new SubTreeWriter(this) 138 | } 139 | 140 | /** 141 | * Sub tree writer. 142 | * @param subTreeStore The sub tree store this belongs to. 143 | */ 144 | class SubTreeWriter(subTreeStore: SubTreeStore) extends TreeEnsembleWriter { 145 | val subTreeDesc = subTreeStore.subTreeDesc 146 | 147 | /** 148 | * Write node info. 149 | * @param nodeInfo Node info to write. 150 | */ 151 | def writeNodeInfo(nodeInfo: NodeInfo): Unit = { 152 | subTreeDesc.addNodeInfo(nodeInfo) 153 | } 154 | } 155 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/TreeEnsembleStore.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | /** 21 | * Tree ensemble writer. 22 | */ 23 | trait TreeEnsembleWriter { 24 | /** 25 | * Write node info. 26 | * @param nodeInfo Node info to write. 27 | */ 28 | def writeNodeInfo(nodeInfo: NodeInfo): Unit 29 | } 30 | 31 | /** 32 | * Tree ensemble store. 33 | */ 34 | trait TreeEnsembleStore { 35 | /** 36 | * Get a tree ensemble writer. 37 | * @return A tree ensemble writer. 38 | */ 39 | def getWriter: TreeEnsembleWriter 40 | } 41 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/tree_ensembles/VarianceNodeStats.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.tree_ensembles 19 | 20 | import spark_ml.util.DiscretizedFeatureHandler 21 | import spire.implicits._ 22 | 23 | /** 24 | * Variance node statistics. 25 | * @param treeId Id of the tree that the node belongs to. 26 | * @param nodeId Id of the node. 27 | * @param nodeDepth Depth of the node. 28 | * @param statsArray The actual statistics array. 29 | * @param mtryFeatures Selected feature descriptions. 30 | * @param numElemsPerBin Number of statistical elements per feature bin. 31 | */ 32 | class VarianceNodeStats( 33 | treeId: Int, 34 | nodeId: Int, 35 | nodeDepth: Int, 36 | statsArray: Array[Double], 37 | mtryFeatures: Array[SelectedFeatureInfo], 38 | numElemsPerBin: Int) extends NodeStats( 39 | treeId = treeId, 40 | nodeId = nodeId, 41 | nodeDepth = nodeDepth, 42 | statsArray = statsArray, 43 | mtryFeatures = mtryFeatures, 44 | numElemsPerBin = numElemsPerBin) { 45 | 46 | /** 47 | * Add statistics related to a sample (label and features). 48 | * @param label Label of the sample. 49 | * @param features Features of the sample. 50 | * @param sampleCnt Sample count of the sample. 51 | * @param featureHandler Feature type handler. 52 | * @tparam T Feature type (Byte or Short). 53 | */ 54 | override def addSample[@specialized(Byte, Short) T]( 55 | label: Double, 56 | features: Array[T], 57 | sampleCnt: Int, 58 | featureHandler: DiscretizedFeatureHandler[T]): Unit = { 59 | val mtry = mtryFeatures.length 60 | cfor(0)(_ < mtry, _ + 1)( 61 | i => { 62 | // Add to the bins of all the selected features. 63 | val featId = mtryFeatures(i).featId 64 | val featOffset = mtryFeatures(i).offset 65 | val binId = featureHandler.convertToInt(features(featId)) 66 | 67 | val binOffset = featOffset + binId * numElemsPerBin 68 | 69 | val sampleCntInDouble = sampleCnt.toDouble 70 | val labelSum = label * sampleCntInDouble 71 | val labelSqrSum = label * labelSum 72 | 73 | statsArray(binOffset) += labelSum 74 | statsArray(binOffset + 1) += labelSqrSum 75 | statsArray(binOffset + 2) += sampleCntInDouble 76 | } 77 | ) 78 | } 79 | 80 | /** 81 | * Calculate the node values from the summary stats. Also computes variance. 82 | * @param sumStats Summary stats (label sum, label sqr sum, count). 83 | * @param offset Starting offset. 84 | * @param output Where the output will be stored. 85 | * @return Node values. 
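   *
   * A small worked example: for labels 1.0, 2.0 and 3.0, each seen with a
   * sample count of 1, the summary stats are (sum = 6.0, sqrSum = 14.0,
   * count = 3.0), so the prediction is 6.0 / 3.0 = 2.0 and the variance
   * (reported as both the addendum and the impurity) is
   * 14.0 / 3.0 - 2.0 * 2.0 = 2.0 / 3.0.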
86 | */ 87 | override def calculateNodeValues( 88 | sumStats: Array[Double], 89 | offset: Int, 90 | output: PreallocatedNodeValues): PreallocatedNodeValues = { 91 | val prediction = sumStats(offset) / sumStats(offset + 2) 92 | val variance = 93 | sumStats(offset + 1) / sumStats(offset + 2) - prediction * prediction 94 | 95 | output.prediction = prediction 96 | output.addendum = variance 97 | output.weight = sumStats(offset + 2) 98 | output.impurity = variance 99 | output.sumStats(0) = sumStats(offset) 100 | output.sumStats(1) = sumStats(offset + 1) 101 | output.sumStats(2) = sumStats(offset + 2) 102 | 103 | output 104 | } 105 | 106 | /** 107 | * Get bin weights. 108 | * @param statsArray Stats array. 109 | * @param offset Start offset. 110 | * @param numBins Number of bins. 111 | * @param output This is where the weights will be stored. 112 | * @return Returns the same output that was passed in. 113 | */ 114 | override def getBinWeights( 115 | statsArray: Array[Double], 116 | offset: Int, 117 | numBins: Int, 118 | output: Array[Double]): Array[Double] = { 119 | cfor(0)(_ < numBins, _ + 1)( 120 | binId => output(binId) = statsArray(offset + binId * numElemsPerBin + 2) 121 | ) 122 | 123 | output 124 | } 125 | 126 | /** 127 | * Calculate the label average for the given bin. 128 | * @param statsArray Stats array. 129 | * @param binOffset Offset to the bin. 130 | * @return The label average for the bin. 131 | */ 132 | override def getBinLabelAverage( 133 | statsArray: Array[Double], 134 | binOffset: Int): Double = { 135 | statsArray(binOffset) / statsArray(binOffset + 2) 136 | } 137 | 138 | /** 139 | * This is for regression, so return false. 140 | * @return false 141 | */ 142 | override def forClassification: Boolean = false 143 | } 144 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/Bagger.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import scala.util.Random 21 | 22 | import org.apache.spark.rdd.RDD 23 | import spire.implicits._ 24 | 25 | /** 26 | * Available bagging types. 27 | */ 28 | object BaggingType extends Enumeration { 29 | type BaggingType = Value 30 | val WithReplacement = Value(0) 31 | val WithoutReplacement = Value(1) 32 | } 33 | 34 | /** 35 | * Bagger. 36 | */ 37 | object Bagger { 38 | /** 39 | * Create an RDD of bagging info. Each row of the return value is an array of 40 | * sample counts for a corresponding data row. 41 | * The number of samples per row is set by numSamples. 42 | * @param data An RDD of data points. 43 | * @param numSamples Number of samples we want to get per row. 
44 | * @param baggingType Bagging type. 45 | * @param baggingRate Bagging rate. 46 | * @param seed Random seed. 47 | * @tparam T Data row type. 48 | * @return An RDD of an array of sample counts. 49 | */ 50 | def getBagRdd[T]( 51 | data: RDD[T], 52 | numSamples: Int, 53 | baggingType: BaggingType.BaggingType, 54 | baggingRate: Double, 55 | seed: Int): RDD[Array[Byte]] = { 56 | data.mapPartitionsWithIndex( 57 | (index, rows) => { 58 | val poisson = Poisson(baggingRate, seed + index) 59 | val rng = new Random(seed + index) 60 | rows.map( 61 | row => { 62 | val counts = Array.fill[Byte](numSamples)(0) 63 | cfor(0)(_ < numSamples, _ + 1)( 64 | sampleId => { 65 | val sampleCount = 66 | if (baggingType == BaggingType.WithReplacement) { 67 | poisson.sample() 68 | } else { 69 | if (rng.nextDouble() <= baggingRate) 1 else 0 70 | } 71 | 72 | // Only allow a sample count value upto 127 73 | // to save space. 74 | counts(sampleId) = math.min(sampleCount, 127).toByte 75 | } 76 | ) 77 | 78 | counts 79 | } 80 | ) 81 | } 82 | ) 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/DiscretizedFeatureHandler.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | /** 21 | * Discretized features can be either unsigned Byte or Short. This trait is 22 | * responsible for handling the actual data types, freeing the aggregation logic 23 | * from handling different types. 24 | * @tparam T Type of the features. Either Byte or Short. 25 | */ 26 | trait DiscretizedFeatureHandler[@specialized(Byte, Short) T] extends Serializable { 27 | /** 28 | * Convert the given value of type T to an integer value. 29 | * @param value The value that we want to convert. 30 | * @return The integer value. 31 | */ 32 | def convertToInt(value: T): Int 33 | 34 | /** 35 | * Convert the given integer value to the type. 36 | * @param value Integer value that we want to convert. 37 | * @return The converted value. 38 | */ 39 | def convertToType(value: Int): T 40 | 41 | /** 42 | * Get the minimum value you can get for this type. 43 | * @return The minimum value for this type. 44 | */ 45 | def getMinValue: Int 46 | 47 | /** 48 | * get the maximum value you can get for this type. 49 | * @return The maximum value for this type. 50 | */ 51 | def getMaxValue: Int 52 | } 53 | 54 | /** 55 | * Handle unsigned byte features. 56 | */ 57 | class UnsignedByteHandler extends DiscretizedFeatureHandler[Byte] { 58 | /** 59 | * Convert the given value of unsigned Byte to an integer value. 60 | * @param value The value that we want to convert. 61 | * @return The integer value. 
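   *
   * For example, the smallest signed byte value -128 maps to 0 and the
   * largest value 127 maps to 255, so the full signed byte range is
   * reinterpreted as an unsigned bin index in [0, 255].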
62 | */ 63 | def convertToInt(value: Byte): Int = { 64 | value.toInt + 128 65 | } 66 | 67 | /** 68 | * Convert the given integer value to the type. 69 | * @param value Integer value that we want to convert. 70 | * @return The converted value. 71 | */ 72 | def convertToType(value: Int): Byte = { 73 | (value - 128).toByte 74 | } 75 | 76 | /** 77 | * Get the minimum value you can get for this type. 78 | * @return The minimum value for this type. 79 | */ 80 | def getMinValue: Int = 0 81 | 82 | /** 83 | * get the maximum value you can get for this type. 84 | * @return The maximum value for this type. 85 | */ 86 | def getMaxValue: Int = 255 87 | } 88 | 89 | /** 90 | * Handle unsigned short features. 91 | */ 92 | class UnsignedShortHandler extends DiscretizedFeatureHandler[Short] { 93 | /** 94 | * Convert the given value of unsigned Short to an integer value. 95 | * @param value The value that we want to convert. 96 | * @return The integer value. 97 | */ 98 | def convertToInt(value: Short): Int = { 99 | value.toInt + 32768 100 | } 101 | 102 | /** 103 | * Convert the given integer value to the type. 104 | * @param value Integer value that we want to convert. 105 | * @return The converted value. 106 | */ 107 | def convertToType(value: Int): Short = { 108 | (value - 32768).toShort 109 | } 110 | 111 | /** 112 | * Get the minimum value you can get for this type. 113 | * @return The minimum value for this type. 114 | */ 115 | def getMinValue: Int = 0 116 | 117 | /** 118 | * get the maximum value you can get for this type. 119 | * @return The maximum value for this type. 120 | */ 121 | def getMaxValue: Int = 65535 122 | } 123 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/MapWithSequentialIntKeys.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import scala.reflect.ClassTag 21 | 22 | import spire.implicits._ 23 | 24 | /** 25 | * Exception to be thrown in case of unexpected behavior of the following map 26 | * class. 27 | * @param msg Exception msg. 28 | */ 29 | case class UnexpectedKeyException(msg: String) extends Exception(msg) 30 | 31 | /** 32 | * A map with integer keys that are guaranteed to be incrementing one by one. 33 | * A next insert is expected to use a key that's equal to 1 + prev put key. 34 | * A next remove is expected to use a key that's equal to the beginning key. 35 | * @param initCapacity Initial capacity. 
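 *
 * A brief usage sketch (keys must be inserted consecutively and removed from
 * the front; the values here are arbitrary):
 * {{{
 *   val m = new MapWithSequentialIntKeys[String](initCapacity = 2)
 *   m.put(10, "a")
 *   m.put(11, "b")
 *   m.put(12, "c")   // grows the internal array past the initial capacity
 *   m.get(11)        // "b"
 *   m.getKeyRange    // (10, 12)
 *   m.remove(10)     // only the earliest key may be removed
 *   m.contains(10)   // false
 * }}}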
36 | */ 37 | class MapWithSequentialIntKeys[@specialized(Int) T: ClassTag](initCapacity: Int) 38 | extends Serializable { 39 | private var values = new Array[T](initCapacity) 40 | private var capacity = initCapacity 41 | private var size = 0 42 | private var putCursor = 0 43 | private var expectedPutKey = 0 44 | private var firstGetPos = 0 45 | private var firstGetKey = 0 46 | 47 | /** 48 | * Put the next key value. The key is expected to be 1 + the previous one 49 | * unless it's the very first key. 50 | * @param key The key value. 51 | * @param value The value corresponding to the key. 52 | */ 53 | def put(key: Int, value: T): Unit = { 54 | // Make sure that the key is as expected. 55 | if (size > 0 && key != expectedPutKey) { 56 | throw UnexpectedKeyException( 57 | "The put key " + key + 58 | " is different from the expected key " + expectedPutKey 59 | ) 60 | } 61 | 62 | // If the array is full, we need to get a new array. 63 | if (size >= capacity) { 64 | capacity *= 2 65 | val newValues = new Array[T](capacity) 66 | cfor(0)(_ < size, _ + 1)( 67 | i => { 68 | newValues(i) = values(firstGetPos) 69 | firstGetPos += 1 70 | if (firstGetPos >= size) { 71 | firstGetPos = 0 72 | } 73 | } 74 | ) 75 | 76 | values = newValues 77 | putCursor = size 78 | firstGetPos = 0 79 | } 80 | 81 | values(putCursor) = value 82 | putCursor += 1 83 | if (putCursor >= capacity) { 84 | putCursor = 0 85 | } 86 | 87 | if (size == 0) { 88 | firstGetKey = key 89 | } 90 | 91 | size += 1 92 | expectedPutKey = key + 1 93 | } 94 | 95 | /** 96 | * Get the value corresponding to the key. 97 | * @param key The integer key value. 98 | * @return The corresponding value. 99 | */ 100 | def get(key: Int): T = { 101 | val getOffset = key - firstGetKey 102 | 103 | // Make sure that it's within the expected range. 104 | if (getOffset >= size || getOffset < 0) { 105 | throw UnexpectedKeyException( 106 | "The get key " + key + 107 | " is not within [" + firstGetKey + ", " + (firstGetKey + size - 1) + "]" 108 | ) 109 | } 110 | 111 | var idx = firstGetPos + getOffset 112 | if (idx >= capacity) { 113 | idx -= capacity 114 | } 115 | 116 | values(idx) 117 | } 118 | 119 | /** 120 | * Remove a key. 121 | * @param key The key we want to remove. 122 | */ 123 | def remove(key: Int): Unit = { 124 | // We expect the key to firstGetKey. 125 | // Otherwise, this is not being used as expected. 126 | if (key != firstGetKey) { 127 | throw UnexpectedKeyException( 128 | "The remove key " + key + 129 | " is different from the expected key " + firstGetKey 130 | ) 131 | } 132 | 133 | size -= 1 134 | firstGetKey += 1 135 | firstGetPos += 1 136 | if (firstGetPos >= capacity) { 137 | firstGetPos = 0 138 | } 139 | } 140 | 141 | /** 142 | * Get key range in the object. 143 | * @return A pair of (startKey, endKey). 144 | */ 145 | def getKeyRange: (Int, Int) = { 146 | (firstGetKey, firstGetKey + size - 1) 147 | } 148 | 149 | /** 150 | * Whether this map contains the key. 151 | * @param key The integer key that we want to check for. 152 | * @return true if the key is contained. false otherwise. 
153 | */ 154 | def contains(key: Int): Boolean = { 155 | val (startKey, endKey) = getKeyRange 156 | (key >= startKey) && (key <= endKey) 157 | } 158 | } 159 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/Poisson.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import scala.collection.mutable 21 | import scala.util.Random 22 | 23 | /** 24 | * Poisson sampler. 25 | * @param lambda The average count of the Poisson distribution. 26 | * @param seed The random number generator seed. 27 | */ 28 | case class Poisson(lambda: Double, seed: Int) { 29 | private val rng = new Random(seed) 30 | private val tolerance: Double = 0.00001 31 | 32 | private var pdf: Array[Double] = _ 33 | private var cdf: Array[Double] = _ 34 | 35 | { 36 | val pdfBuilder = new mutable.ArrayBuilder.ofDouble 37 | val cdfBuilder = new mutable.ArrayBuilder.ofDouble 38 | 39 | val expPart: Double = math.exp(-lambda) 40 | var curCDF: Double = expPart 41 | var curPDF: Double = expPart 42 | 43 | cdfBuilder += curCDF 44 | pdfBuilder += curPDF 45 | 46 | var i: Double = 1.0 47 | while (curCDF < (1.0 - tolerance)) { 48 | curPDF *= lambda / i 49 | curCDF += curPDF 50 | 51 | cdfBuilder += curCDF 52 | pdfBuilder += curPDF 53 | 54 | i += 1.0 55 | } 56 | 57 | pdf = pdfBuilder.result() 58 | cdf = cdfBuilder.result() 59 | } 60 | 61 | /** 62 | * Sample a poisson distributed value. 63 | * @return A sampled integer value. 64 | */ 65 | def sample(): Int = { 66 | val rnd = rng.nextDouble() 67 | for (i <- 0 to cdf.length - 1) { 68 | if (rnd <= cdf(i)) { 69 | return i 70 | } 71 | } 72 | 73 | cdf.length - 1 74 | } 75 | } 76 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/ProgressNotifiee.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import java.io.Serializable 21 | import java.util.Calendar 22 | 23 | /** 24 | * An object to funnel progress reports to. 25 | */ 26 | trait ProgressNotifiee extends Serializable { 27 | def newProgressMessage(progress: String): Unit 28 | def newStatusMessage(status: String): Unit 29 | def newErrorMessage(error: String): Unit 30 | } 31 | 32 | /** 33 | * Simple console notifiee. 34 | * Prints messages to stdout. 35 | */ 36 | class ConsoleNotifiee extends ProgressNotifiee { 37 | def newProgressMessage(progress: String): Unit = { 38 | println( 39 | "[Progress] [" + Calendar.getInstance().getTime.toString + "] " + progress 40 | ) 41 | } 42 | 43 | def newStatusMessage(status: String): Unit = { 44 | println( 45 | "[Status] [" + Calendar.getInstance().getTime.toString + "] " + status 46 | ) 47 | } 48 | 49 | def newErrorMessage(error: String): Unit = { 50 | println( 51 | "[Error] [" + Calendar.getInstance().getTime.toString + "] " + error 52 | ) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/RandomSet.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import scala.collection.mutable 21 | 22 | /** 23 | * Helper functions to select random samples. 24 | */ 25 | object RandomSet { 26 | def nChooseK(k: Int, n: Int, rnd: scala.util.Random): Array[Int] = { 27 | val indices = new mutable.ArrayBuilder.ofInt 28 | var remains = k 29 | 30 | for (i <- 0 to n - 1) { 31 | if (rnd.nextInt(n - i) < remains) { 32 | indices += i 33 | remains -= 1 34 | } 35 | } 36 | 37 | indices.result() 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/ReservoirSample.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import scala.util.Random 21 | 22 | /** 23 | * Reservoir sample. 24 | * @param maxSample The maximum sample size. 25 | */ 26 | class ReservoirSample(val maxSample: Int) extends Serializable { 27 | var sample = Array.fill[Double](maxSample)(0.0) 28 | var numSamplePoints: Int = 0 29 | var numPointsSeen: Double = 0.0 30 | 31 | def doReservoirSampling(point: Double, rng: Random): Unit = { 32 | numPointsSeen += 1.0 33 | if (numSamplePoints < maxSample) { 34 | sample(numSamplePoints) = point 35 | numSamplePoints += 1 36 | } else { 37 | val randomNumber = math.floor(rng.nextDouble() * numPointsSeen) 38 | if (randomNumber < maxSample.toDouble) { 39 | sample(randomNumber.toInt) = point 40 | } 41 | } 42 | } 43 | } 44 | 45 | object ReservoirSample { 46 | /** 47 | * Merge two reservoir samples and make sure that each sample retains 48 | * uniformness. 49 | * @param a A reservoir sample. 50 | * @param b A reservoir sample. 51 | * @param maxSample Maximum number of reservoir samples we want. 52 | * @param rng A random number generator. 53 | * @return Merged sample. 54 | */ 55 | def mergeReservoirSamples( 56 | a: ReservoirSample, 57 | b: ReservoirSample, 58 | maxSample: Int, 59 | rng: Random): ReservoirSample = { 60 | 61 | assert(maxSample == a.maxSample) 62 | assert(a.maxSample == b.maxSample) 63 | 64 | // Merged samples. 65 | val mergedSample = new ReservoirSample(maxSample) 66 | 67 | // Find out which one has seen more samples. 68 | val (largerSample, smallerSample) = 69 | if (a.numPointsSeen > b.numPointsSeen) { 70 | (a, b) 71 | } else { 72 | (b, a) 73 | } 74 | 75 | // First, fill in the merged samples with the samples that had 'lower' prob 76 | // of being selected. I.e., the sample that has seen more points. 77 | var i = 0 78 | while (i < largerSample.numSamplePoints) { 79 | mergedSample.sample(i) = largerSample.sample(i) 80 | i += 1 81 | } 82 | mergedSample.numSamplePoints = largerSample.numSamplePoints 83 | mergedSample.numPointsSeen = largerSample.numPointsSeen 84 | 85 | // Now, add smaller sample points with probabilities so that they become 86 | // uniform. 87 | var j = 0 // The smaller sample index. 88 | val probSmaller = smallerSample.numPointsSeen / (largerSample.numPointsSeen + smallerSample.numPointsSeen) 89 | while (j < smallerSample.numSamplePoints) { 90 | val samplePoint = smallerSample.sample(j) 91 | if (mergedSample.numSamplePoints > smallerSample.numSamplePoints) { 92 | mergedSample.doReservoirSampling(samplePoint, rng) 93 | } else { 94 | val rnd = rng.nextDouble() 95 | if (rnd < probSmaller) { 96 | mergedSample.sample(j) = samplePoint 97 | } 98 | } 99 | 100 | j += 1 101 | } 102 | 103 | mergedSample 104 | } 105 | } -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/RobustMath.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. 
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | /** 21 | * Some math functions that don't give NaN's for edge cases (e.g. log of 0's) or 22 | * too large/small numbers of exp's. 23 | * These are used for loss function calculations. 24 | */ 25 | object RobustMath { 26 | private val minExponent = -19.0 27 | private val maxExponent = 19.0 28 | private val expPredLowerLimit = math.exp(minExponent) 29 | private val expPredUpperLimit = math.exp(maxExponent) 30 | 31 | def log(value: Double): Double = { 32 | if (value == 0.0) { 33 | minExponent 34 | } else { 35 | math.min(math.max(math.log(value), minExponent), maxExponent) 36 | } 37 | } 38 | 39 | def exp(value: Double): Double = { 40 | math.min(math.max(math.exp(value), expPredLowerLimit), expPredUpperLimit) 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/main/scala/spark_ml/util/Selection.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import scala.util.Random 21 | 22 | /** 23 | * Selection algorithm utils. 24 | * Adopted with some modifications from the java code at : 25 | * http://blog.teamleadnet.com/2012/07/quick-select-algorithm-find-kth-element.html 26 | */ 27 | object Selection { 28 | /** 29 | * A quick select algorithm used to select the n'th number from an array. 30 | * @param array Array of Doubles. 31 | * @param s The starting index of the segment that we are looking at. 32 | * @param e The ending position (ending index + 1) of the segment that we are 33 | * looking at. 34 | * @param n We're looking for the n'th number if the array is sorted. 35 | * @param rng A random number generator. 36 | * @return The n'th number in the array. 
37 | */ 38 | def quickSelect( 39 | array: Array[Double], 40 | s: Int, 41 | e: Int, 42 | n: Int, 43 | rng: Random): Double = { 44 | 45 | var from = s 46 | var to = e - 1 47 | while (from < to) { 48 | var r = from 49 | var w = to 50 | val pivotIdx = rng.nextInt(to - from + 1) + from 51 | val pivotVal = array(pivotIdx) 52 | while (r < w) { 53 | if (array(r) >= pivotVal) { 54 | val tmp = array(w) 55 | array(w) = array(r) 56 | array(r) = tmp 57 | w -= 1 58 | } else { 59 | r += 1 60 | } 61 | } 62 | 63 | if (array(r) > pivotVal) { 64 | r -= 1 65 | } 66 | 67 | if (n <= r) { 68 | to = r 69 | } else { 70 | from = r + 1 71 | } 72 | } 73 | 74 | array(n) 75 | } 76 | } 77 | -------------------------------------------------------------------------------- /src/test/scala/spark_ml/discretization/EqualFrequencyBinFinderSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | import scala.util.{Failure, Success, Try} 21 | 22 | import org.scalatest.FunSuite 23 | import spark_ml.util._ 24 | 25 | /** 26 | * Test equal frequency bin finders. 
27 | */ 28 | class EqualFrequencyBinFinderSuite extends FunSuite with LocalSparkContext { 29 | test("Test the equal frequency bin finder 1") { 30 | val rawData1 = TestDataGenerator.labeledData1 31 | val testDataRDD1 = sc.parallelize(rawData1, 3).cache() 32 | 33 | val (labelSummary1, bins1) = 34 | new EqualFrequencyBinFinderFromSample( 35 | maxSampleSize = 1000, 36 | seed = 0 37 | ).findBins( 38 | data = testDataRDD1, 39 | columnNames = ("Label", Array("Col1", "Col2", "Col3")), 40 | catIndices = Set(1), 41 | maxNumBins = 8, 42 | expectedLabelCardinality = Some(4), 43 | notifiee = new ConsoleNotifiee 44 | ) 45 | 46 | assert(labelSummary1.restCount === 0L) 47 | assert(labelSummary1.catCounts.get.length === 4) 48 | assert(labelSummary1.catCounts.get(0) === 7L) 49 | assert(labelSummary1.catCounts.get(1) === 8L) 50 | assert(labelSummary1.catCounts.get(2) === 11L) 51 | assert(labelSummary1.catCounts.get(3) === 4L) 52 | assert(bins1.length === 3) 53 | assert(bins1(0).getCardinality === 5) 54 | assert(bins1(1).getCardinality === 3) 55 | assert(bins1(2).getCardinality === 8) 56 | 57 | BinsTestUtil.validateNumericalBins( 58 | bins1(0).asInstanceOf[NumericBins], 59 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, Double.PositiveInfinity)), 60 | None 61 | ) 62 | 63 | assert(bins1(0).findBinIdx(0.0) === 0) 64 | assert(bins1(0).findBinIdx(0.5) === 0) 65 | assert(bins1(0).findBinIdx(1.0) === 1) 66 | assert(bins1(0).findBinIdx(3.99999) === 3) 67 | assert(bins1(0).findBinIdx(4.0) === 4) 68 | assert(bins1(0).findBinIdx(10.0) === 4) 69 | 70 | assert(bins1(1).isInstanceOf[CategoricalBins]) 71 | 72 | Try(bins1(1).findBinIdx(1.1)) match { 73 | case Success(idx) => fail("CategoricalBins findBinIdx should've thrown an exception.") 74 | case Failure(ex) => assert(ex.isInstanceOf[InvalidCategoricalValueException]) 75 | } 76 | 77 | Try(bins1(1).findBinIdx(10.0)) match { 78 | case Success(idx) => fail("CategoricalBins findBinIdx should've thrown an exception.") 79 | case Failure(ex) => assert(ex.isInstanceOf[CardinalityOverLimitException]) 80 | } 81 | 82 | BinsTestUtil.validateNumericalBins( 83 | bins1(2).asInstanceOf[NumericBins], 84 | Array( 85 | (Double.NegativeInfinity, -72.87), 86 | (-72.87, -52.28), 87 | (-52.28, -5.63), 88 | (-5.63, 20.88), 89 | (20.88, 25.89), 90 | (25.89, 59.07), 91 | (59.07, 81.67), 92 | (81.67, Double.PositiveInfinity) 93 | ), 94 | None 95 | ) 96 | 97 | val rawData3 = TestDataGenerator.labeledData3 98 | val testDataRDD3 = sc.parallelize(rawData3, 3).cache() 99 | 100 | val (labelSummary3, bins3) = 101 | new EqualFrequencyBinFinderFromSample( 102 | maxSampleSize = 1000, 103 | seed = 0 104 | ).findBins( 105 | data = testDataRDD3, 106 | columnNames = ("Label", Array("Col1", "Col2")), 107 | catIndices = Set(), 108 | maxNumBins = 8, 109 | expectedLabelCardinality = None, 110 | notifiee = new ConsoleNotifiee 111 | ) 112 | 113 | assert(labelSummary3.expectedCardinality.isEmpty) 114 | assert(labelSummary3.catCounts.isEmpty) 115 | assert(labelSummary3.restCount === 30L) 116 | assert(bins3.length === 2) 117 | assert(bins3(0).getCardinality === 5) 118 | assert(bins3(1).getCardinality === 3) 119 | 120 | BinsTestUtil.validateNumericalBins( 121 | bins3(0).asInstanceOf[NumericBins], 122 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, Double.PositiveInfinity)), 123 | None 124 | ) 125 | 126 | BinsTestUtil.validateNumericalBins( 127 | bins3(1).asInstanceOf[NumericBins], 128 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, 
Double.PositiveInfinity)), 129 | None 130 | ) 131 | 132 | val rawData6 = TestDataGenerator.labeledData6 133 | val testDataRDD6 = sc.parallelize(rawData6, 3).cache() 134 | 135 | val (labelSummary6, bins6) = 136 | new EqualFrequencyBinFinderFromSample( 137 | maxSampleSize = 1000, 138 | seed = 0 139 | ).findBins( 140 | data = testDataRDD6, 141 | columnNames = ("Label", Array("Col1", "Col2")), 142 | catIndices = Set(), 143 | maxNumBins = 8, 144 | expectedLabelCardinality = Some(3), 145 | notifiee = new ConsoleNotifiee 146 | ) 147 | 148 | assert(labelSummary6.restCount === 4L) 149 | assert(labelSummary6.catCounts.get(0) === 7L) 150 | assert(labelSummary6.catCounts.get(1) === 8L) 151 | assert(labelSummary6.catCounts.get(2) === 11L) 152 | assert(bins6.length === 2) 153 | assert(bins6(0).getCardinality === 6) 154 | assert(bins6(1).getCardinality === 4) 155 | 156 | BinsTestUtil.validateNumericalBins( 157 | bins6(0).asInstanceOf[NumericBins], 158 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, Double.PositiveInfinity)), 159 | Option(5) 160 | ) 161 | 162 | BinsTestUtil.validateNumericalBins( 163 | bins6(1).asInstanceOf[NumericBins], 164 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, Double.PositiveInfinity)), 165 | Option(3) 166 | ) 167 | } 168 | 169 | test("Test the equal frequency RDD transformation 1") { 170 | val rawData1 = TestDataGenerator.labeledData1 171 | val testDataRDD1 = sc.parallelize(rawData1, 1).cache() 172 | 173 | val (labelSummary1, bins1) = 174 | new EqualFrequencyBinFinderFromSample( 175 | maxSampleSize = 1000, 176 | seed = 0 177 | ).findBins( 178 | data = testDataRDD1, 179 | columnNames = ("Label", Array("Col1", "Col2", "Col3")), 180 | catIndices = Set(1), 181 | maxNumBins = 8, 182 | expectedLabelCardinality = Some(4), 183 | notifiee = new ConsoleNotifiee 184 | ) 185 | 186 | val featureHandler = new UnsignedByteHandler 187 | val transformedFeatures1 = Discretizer.transformFeatures( 188 | input = testDataRDD1, 189 | featureBins = bins1, 190 | featureHandler = featureHandler 191 | ).collect() 192 | 193 | assert(featureHandler.convertToInt(transformedFeatures1(0)(0)) === 0) 194 | assert(featureHandler.convertToInt(transformedFeatures1(0)(1)) === 0) 195 | assert(featureHandler.convertToInt(transformedFeatures1(0)(2)) === 2) 196 | 197 | assert(featureHandler.convertToInt(transformedFeatures1(17)(0)) === 2) 198 | assert(featureHandler.convertToInt(transformedFeatures1(17)(1)) === 2) 199 | assert(featureHandler.convertToInt(transformedFeatures1(17)(2)) === 4) 200 | 201 | val rawData6 = TestDataGenerator.labeledData6 202 | val testDataRDD6 = sc.parallelize(rawData6, 1).cache() 203 | 204 | val (labelSummary6, bins6) = 205 | new EqualFrequencyBinFinderFromSample( 206 | maxSampleSize = 1000, 207 | seed = 0 208 | ).findBins( 209 | data = testDataRDD6, 210 | columnNames = ("Label", Array("Col1", "Col2")), 211 | catIndices = Set(), 212 | maxNumBins = 8, 213 | expectedLabelCardinality = Some(4), 214 | notifiee = new ConsoleNotifiee 215 | ) 216 | 217 | val transformedFeatures6 = Discretizer.transformFeatures( 218 | input = testDataRDD6, 219 | featureBins = bins6, 220 | featureHandler = featureHandler 221 | ).collect() 222 | 223 | assert(featureHandler.convertToInt(transformedFeatures6(2)(0)) === 2) 224 | assert(featureHandler.convertToInt(transformedFeatures6(2)(1)) === 3) 225 | 226 | assert(featureHandler.convertToInt(transformedFeatures6(7)(0)) === 5) 227 | assert(featureHandler.convertToInt(transformedFeatures6(7)(1)) === 1) 228 | } 229 | } 230 | 
-------------------------------------------------------------------------------- /src/test/scala/spark_ml/discretization/EqualWidthBinFinderSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.discretization 19 | 20 | import org.apache.spark.SparkException 21 | import org.scalatest.FunSuite 22 | import spark_ml.util._ 23 | 24 | /** 25 | * Test equal width discretization. 26 | */ 27 | class EqualWidthBinFinderSuite extends FunSuite with LocalSparkContext { 28 | test("Test the equal width bin finder 1") { 29 | val rawData1 = TestDataGenerator.labeledData1 30 | val testDataRDD1 = sc.parallelize(rawData1, 3).cache() 31 | 32 | val (labelSummary1, bins1) = 33 | new EqualWidthBinFinder().findBins( 34 | data = testDataRDD1, 35 | columnNames = ("Label", Array("Col1", "Col2", "Col3")), 36 | catIndices = Set(1), 37 | maxNumBins = 8, 38 | expectedLabelCardinality = Some(4), 39 | notifiee = new ConsoleNotifiee 40 | ) 41 | 42 | assert(labelSummary1.restCount === 0L) 43 | assert(labelSummary1.catCounts.get.length === 4) 44 | assert(labelSummary1.catCounts.get(0) === 7L) 45 | assert(labelSummary1.catCounts.get(1) === 8L) 46 | assert(labelSummary1.catCounts.get(2) === 11L) 47 | assert(labelSummary1.catCounts.get(3) === 4L) 48 | assert(bins1.length === 3) 49 | assert(bins1(0).getCardinality === 8) 50 | assert(bins1(1).getCardinality === 3) 51 | assert(bins1(2).getCardinality === 8) 52 | 53 | assert(bins1(0).findBinIdx(0.0) === 0) 54 | assert(bins1(0).findBinIdx(0.5) === 1) 55 | assert(bins1(0).findBinIdx(1.0) === 2) 56 | assert(bins1(0).findBinIdx(4.0) === 7) 57 | 58 | assert(bins1(1).isInstanceOf[CategoricalBins]) 59 | 60 | assert(bins1(2).findBinIdx(-80.0) === 0) 61 | assert(bins1(2).findBinIdx(-60.0) === 0) 62 | assert(bins1(2).findBinIdx(-58.0) === 1) 63 | 64 | val rawData6 = TestDataGenerator.labeledData6 65 | val testDataRDD6 = sc.parallelize(rawData6, 3).cache() 66 | 67 | val (labelSummary6, bins6) = 68 | new EqualWidthBinFinder().findBins( 69 | data = testDataRDD6, 70 | columnNames = ("Label", Array("Col1", "Col2")), 71 | catIndices = Set(), 72 | maxNumBins = 8, 73 | expectedLabelCardinality = Some(3), 74 | notifiee = new ConsoleNotifiee 75 | ) 76 | 77 | assert(labelSummary6.restCount === 4L) 78 | assert(labelSummary6.catCounts.get(0) === 7L) 79 | assert(labelSummary6.catCounts.get(1) === 8L) 80 | assert(labelSummary6.catCounts.get(2) === 11L) 81 | assert(bins6.length === 2) 82 | assert(bins6(0).getCardinality === 8) 83 | assert(bins6(1).getCardinality === 8) 84 | assert(bins6(0).asInstanceOf[NumericBins].missingValueBinIdx === Some(7)) 85 | 
assert(bins6(1).asInstanceOf[NumericBins].missingValueBinIdx === Some(7)) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/test/scala/spark_ml/util/BinsTestUtil.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import org.scalatest.Assertions._ 21 | import spark_ml.discretization.NumericBins 22 | 23 | object BinsTestUtil { 24 | def validateNumericalBins( 25 | bins: NumericBins, 26 | boundaries: Array[(Double, Double)], 27 | missingBinId: Option[Int]): Unit = { 28 | assert(bins.getCardinality === (boundaries.length + (if (missingBinId.isDefined) 1 else 0))) 29 | assert(bins.missingValueBinIdx === missingBinId) 30 | bins.bins.zip(boundaries).foreach { 31 | case (numericBin, (l, r)) => 32 | assert(numericBin.lower === l) 33 | assert(numericBin.upper === r) 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/test/scala/spark_ml/util/LocalSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | import org.apache.log4j.{Level, Logger} 21 | import org.apache.spark.{SparkConf, SparkContext} 22 | import org.scalatest.{BeforeAndAfterAll, Suite} 23 | 24 | /** 25 | * Start a local spark context for unit testing. 
26 | */ 27 | trait LocalSparkContext extends BeforeAndAfterAll { self: Suite => 28 | @transient var sc: SparkContext = _ 29 | 30 | override def beforeAll() { 31 | super.beforeAll() 32 | // http://stackoverflow.com/questions/27781187/how-to-stop-messages-displaying-on-spark-console 33 | Logger.getLogger("org.apache.spark").setLevel(Level.WARN) 34 | Logger.getLogger("akka").setLevel(Level.WARN) 35 | Thread.sleep(100L) 36 | val conf = new SparkConf() 37 | .setMaster("local[3]") 38 | .setAppName("test") 39 | sc = new SparkContext(conf) 40 | } 41 | 42 | override def afterAll() { 43 | if (sc != null) { 44 | sc.stop() 45 | sc = null 46 | } 47 | super.afterAll() 48 | } 49 | 50 | def numbersAreEqual(x: Double, y: Double, tol: Double = 1E-3): Boolean = { 51 | math.abs(x - y) / (math.abs(y) + 1e-15) < tol 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/test/scala/spark_ml/util/TestDataGenerator.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package spark_ml.util 19 | 20 | /** 21 | * Generate test data. 
22 | */ 23 | object TestDataGenerator { 24 | def labeledData1: Array[(Double, Array[Double])] = { 25 | Array( 26 | (0.0, Array(0.0, 0.0, -52.28)), 27 | (0.0, Array(1.0, 0.0, -32.16)), 28 | (1.0, Array(2.0, 0.0, -73.68)), 29 | (1.0, Array(3.0, 0.0, -26.38)), 30 | (2.0, Array(4.0, 0.0, 13.69)), 31 | (2.0, Array(0.0, 1.0, 42.07)), 32 | (2.0, Array(1.0, 1.0, 22.96)), 33 | (3.0, Array(2.0, 1.0, -33.43)), 34 | (3.0, Array(3.0, 1.0, -61.80)), 35 | (0.0, Array(4.0, 1.0, -81.34)), 36 | (2.0, Array(0.0, 1.0, -68.49)), 37 | (2.0, Array(1.0, 1.0, 64.17)), 38 | (3.0, Array(2.0, 1.0, 20.88)), 39 | (3.0, Array(3.0, 1.0, 27.75)), 40 | (0.0, Array(4.0, 1.0, 59.07)), 41 | (0.0, Array(0.0, 2.0, -53.55)), 42 | (1.0, Array(1.0, 2.0, 25.89)), 43 | (1.0, Array(2.0, 2.0, 22.62)), 44 | (2.0, Array(3.0, 2.0, -5.63)), 45 | (2.0, Array(4.0, 2.0, 81.67)), 46 | (0.0, Array(0.0, 2.0, -72.87)), 47 | (1.0, Array(1.0, 2.0, 25.51)), 48 | (1.0, Array(2.0, 2.0, 43.14)), 49 | (2.0, Array(3.0, 2.0, 60.53)), 50 | (2.0, Array(4.0, 2.0, 88.94)), 51 | (0.0, Array(0.0, 2.0, 17.08)), 52 | (1.0, Array(1.0, 2.0, 69.48)), 53 | (1.0, Array(2.0, 2.0, -76.47)), 54 | (2.0, Array(3.0, 2.0, 90.90)), 55 | (2.0, Array(4.0, 2.0, -79.67)) 56 | ) 57 | } 58 | 59 | def labeledData2: Array[(Double, Array[Double])] = { 60 | Array( 61 | (0.0, Array(0.0, 0.0)), 62 | (0.0, Array(1.0, 0.0)), 63 | (1.0, Array(2.0, 0.0)), 64 | (1.0, Array(3.0, 0.0)), 65 | (2.0, Array(4.0, 0.0)), 66 | (2.0, Array(0.0, 1.0)), 67 | (2.0, Array(1.0, 1.0)), 68 | (3.0, Array(2.0, 1.0)), 69 | (3.0, Array(3.0, 1.0)), 70 | (0.0, Array(4.0, 1.0)), 71 | (2.0, Array(0.0, 1.0)), 72 | (2.0, Array(1.0, 1.0)), 73 | (3.0, Array(2.0, 1.0)), 74 | (3.0, Array(3.0, 1.0)), 75 | (0.0, Array(4.0, 1.0)), 76 | (0.0, Array(0.0, 2.0)), 77 | (1.0, Array(1.0, 2.0)), 78 | (1.0, Array(2.0, 2.0)), 79 | (2.0, Array(3.0, 2.0)), 80 | (2.0, Array(4.0, 2.0)), 81 | (0.0, Array(0.0, 2.0)), 82 | (1.0, Array(1.0, 2.0)), 83 | (1.0, Array(2.0, 2.0)), 84 | (2.0, Array(3.0, 2.0)), 85 | (2.0, Array(4.0, 2.0)), 86 | (0.0, Array(0.0, 2.0)), 87 | (1.0, Array(1.0, 2.0)), 88 | (1.0, Array(2.0, 2.0)), 89 | (2.0, Array(3.0, 2.0)), 90 | (2.0, Array(4.0, 2.0)) 91 | ) 92 | } 93 | 94 | def labeledData3: Array[(Double, Array[Double])] = { 95 | Array( 96 | (0.0, Array(0.0, 0.0)), 97 | (0.0, Array(1.0, 0.0)), 98 | (1.0, Array(2.0, 0.0)), 99 | (1.1, Array(3.0, 0.0)), 100 | (2.0, Array(4.0, 0.0)), 101 | (2.3, Array(0.0, 1.0)), 102 | (2.0, Array(1.0, 1.0)), 103 | (3.0, Array(2.0, 1.0)), 104 | (3.5, Array(3.0, 1.0)), 105 | (0.0, Array(4.0, 1.0)), 106 | (2.0, Array(0.0, 1.0)), 107 | (2.0, Array(1.0, 1.0)), 108 | (3.0, Array(2.0, 1.0)), 109 | (3.2, Array(3.0, 1.0)), 110 | (0.0, Array(4.0, 1.0)), 111 | (0.0, Array(0.0, 2.0)), 112 | (1.0, Array(1.0, 2.0)), 113 | (1.0, Array(2.0, 2.0)), 114 | (2.0, Array(3.0, 2.0)), 115 | (2.0, Array(4.0, 2.0)), 116 | (0.0, Array(0.0, 2.0)), 117 | (1.0, Array(1.0, 2.0)), 118 | (1.0, Array(2.0, 2.0)), 119 | (2.0, Array(3.0, 2.0)), 120 | (2.0, Array(4.0, 2.0)), 121 | (0.0, Array(0.0, 2.0)), 122 | (1.0, Array(1.0, 2.0)), 123 | (1.0, Array(2.0, 2.0)), 124 | (2.0, Array(3.0, 2.0)), 125 | (2.0, Array(4.0, 2.0)) 126 | ) 127 | } 128 | 129 | def labeledData4: Array[(Double, Array[Double])] = { 130 | Array( 131 | (-1.0, Array(0.0, 0.0)), 132 | (0.0, Array(1.0, 0.0)), 133 | (1.0, Array(2.0, 0.0)), 134 | (1.1, Array(3.0, 0.0)), 135 | (2.0, Array(4.0, 0.0)), 136 | (2.3, Array(0.0, 1.0)), 137 | (2.0, Array(1.0, 1.0)), 138 | (3.0, Array(2.0, 1.0)), 139 | (3.5, Array(3.0, 1.0)), 140 | (0.0, Array(4.0, 1.0)), 141 | (2.0, 
Array(0.0, 1.0)), 142 | (2.0, Array(1.0, 1.0)), 143 | (3.0, Array(2.0, 1.0)), 144 | (3.2, Array(3.0, 1.0)), 145 | (0.0, Array(4.0, 1.0)), 146 | (0.0, Array(0.0, 2.0)), 147 | (1.0, Array(1.0, 2.0)), 148 | (1.0, Array(2.0, 2.0)), 149 | (2.0, Array(3.0, 2.0)), 150 | (2.0, Array(4.0, 2.0)), 151 | (0.0, Array(0.0, 2.0)), 152 | (1.0, Array(1.0, 2.0)), 153 | (1.0, Array(2.0, 2.0)), 154 | (2.0, Array(3.0, 2.0)), 155 | (2.0, Array(4.0, 2.0)), 156 | (0.0, Array(0.0, 2.0)), 157 | (1.0, Array(1.0, 2.0)), 158 | (1.0, Array(2.0, 2.0)), 159 | (2.0, Array(3.0, 2.0)), 160 | (2.0, Array(4.0, 2.0)) 161 | ) 162 | } 163 | 164 | def labeledData5: Array[(Double, Array[Double])] = { 165 | Array( 166 | (0.0, Array(0.0, 0.0, -52.28)), 167 | (0.0, Array(1.0, 0.0, -32.16)), 168 | (1.0, Array(2.0, 0.0, -73.68)), 169 | (1.0, Array(3.0, 0.0, -26.38)), 170 | (2.0, Array(4.0, 0.0, 13.69)), 171 | (2.0, Array(0.0, 1.0, 42.07)), 172 | (2.0, Array(1.0, 1.0, 22.96)), 173 | (3.0, Array(2.0, 1.0, -33.43)), 174 | (3.0, Array(3.0, 1.0, -61.80)), 175 | (0.0, Array(4.0, 1.0, -81.34)), 176 | (2.0, Array(0.0, 1.0, -68.49)), 177 | (2.0, Array(1.0, 1.0, 64.17)), 178 | (3.0, Array(2.0, Double.NaN, 20.88)), 179 | (3.0, Array(3.0, 1.0, 27.75)), 180 | (0.0, Array(4.0, 1.0, 59.07)), 181 | (0.0, Array(0.0, 2.0, -53.55)), 182 | (1.0, Array(1.0, 2.0, 25.89)), 183 | (1.0, Array(2.0, 2.0, 22.62)), 184 | (2.0, Array(3.0, 2.0, -5.63)), 185 | (2.0, Array(4.0, 2.0, 81.67)), 186 | (0.0, Array(0.0, 2.0, -72.87)), 187 | (1.0, Array(1.0, 2.0, 25.51)), 188 | (1.0, Array(2.0, 2.0, 43.14)), 189 | (2.0, Array(3.0, 2.0, 60.53)), 190 | (2.0, Array(4.0, 2.0, 88.94)), 191 | (0.0, Array(0.0, 2.0, 17.08)), 192 | (1.0, Array(1.0, 2.0, 69.48)), 193 | (1.0, Array(2.0, 2.0, -76.47)), 194 | (2.0, Array(3.0, 2.0, 90.90)), 195 | (2.0, Array(4.0, 2.0, -79.67)) 196 | ) 197 | } 198 | 199 | def labeledData6: Array[(Double, Array[Double])] = { 200 | Array( 201 | (0.0, Array(0.0, 0.0)), 202 | (0.0, Array(1.0, 0.0)), 203 | (1.0, Array(2.0, Double.NaN)), 204 | (1.0, Array(3.0, Double.NaN)), 205 | (2.0, Array(4.0, 0.0)), 206 | (2.0, Array(0.0, 1.0)), 207 | (2.0, Array(1.0, 1.0)), 208 | (3.0, Array(Double.NaN, 1.0)), 209 | (3.0, Array(Double.NaN, 1.0)), 210 | (0.0, Array(4.0, 1.0)), 211 | (2.0, Array(0.0, 1.0)), 212 | (2.0, Array(1.0, 1.0)), 213 | (3.0, Array(2.0, 1.0)), 214 | (3.0, Array(3.0, 1.0)), 215 | (0.0, Array(4.0, 1.0)), 216 | (0.0, Array(0.0, 2.0)), 217 | (1.0, Array(1.0, 2.0)), 218 | (1.0, Array(2.0, 2.0)), 219 | (2.0, Array(3.0, 2.0)), 220 | (2.0, Array(4.0, 2.0)), 221 | (0.0, Array(0.0, 2.0)), 222 | (1.0, Array(1.0, 2.0)), 223 | (1.0, Array(2.0, 2.0)), 224 | (2.0, Array(3.0, 2.0)), 225 | (2.0, Array(Double.NaN, 2.0)), 226 | (0.0, Array(0.0, 2.0)), 227 | (1.0, Array(1.0, 2.0)), 228 | (1.0, Array(2.0, 2.0)), 229 | (2.0, Array(3.0, 2.0)), 230 | (2.0, Array(4.0, 2.0)) 231 | ) 232 | } 233 | } 234 | --------------------------------------------------------------------------------
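For orientation, here is a minimal, self-contained sketch (Scala) of how a few of the spark_ml.util helpers listed above can be exercised on their own. It assumes only that the package compiles as shown; the driver object UtilSketch is hypothetical and is not part of the repository files.

import scala.util.Random
import spark_ml.util._

// Hypothetical driver; exercises a handful of the utilities defined above.
object UtilSketch {
  def main(args: Array[String]): Unit = {
    // Poisson: samples integers by walking a precomputed CDF with mean `lambda`.
    val poisson = Poisson(lambda = 3.0, seed = 1)
    println("Poisson draw: " + poisson.sample())

    // MapWithSequentialIntKeys: a ring-buffer-backed map whose keys must be
    // inserted consecutively and removed from the front.
    val nodeValues = new MapWithSequentialIntKeys[Double](initCapacity = 2)
    nodeValues.put(10, 1.5)
    nodeValues.put(11, 2.5)
    println("Key range: " + nodeValues.getKeyRange)    // (10,11)
    println("Value for key 11: " + nodeValues.get(11)) // 2.5
    nodeValues.remove(10)                              // only the first key may be removed

    // RobustMath: log and exp clamped to a bounded range for loss-function calculations.
    println("RobustMath.log(0.0) = " + RobustMath.log(0.0)) // -19.0 instead of -Infinity

    // Selection.quickSelect: the n'th smallest value (0-indexed); it reorders the array in place.
    val xs = Array(5.0, 1.0, 4.0, 2.0, 3.0)
    println("Median: " + Selection.quickSelect(xs, 0, xs.length, 2, new Random(7))) // 3.0

    // RandomSet.nChooseK: k distinct indices drawn uniformly from 0 until n.
    println("3 of 10: " + RandomSet.nChooseK(3, 10, new Random(7)).mkString(", "))
  }
}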