├── .gitignore ├── LICENSE ├── README.md ├── build.sbt ├── project └── plugins.sbt └── src ├── main └── scala │ └── org │ └── apache │ └── spark │ └── ml │ └── tree │ ├── impl │ ├── Yggdrasil.scala │ ├── YggdrasilClassification.scala │ ├── YggdrasilRegression.scala │ └── YggdrasilUtil.scala │ ├── impurities.scala │ └── ygg │ ├── Node.scala │ └── Split.scala └── test └── scala └── org └── apache └── spark ├── ml └── tree │ └── impl │ ├── YggdrasilSuite.scala │ └── YggdrasilUtilSuite.scala └── mllib └── util ├── MLlibTestSparkContext.scala └── SparkFunSuite.scala /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.class 3 | *.swp 4 | project/ 5 | target/ 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Yggdrasil: Faster Decision Trees Using Column Partitioning in Spark 2 | 3 | Yggdrasil is a more efficient way in [Apache Spark](http://spark.apache.org) to 4 | train decision trees for large depths and datasets with a high number of 5 | features. For depths greater than 10, Yggdrasil is an order of magnitude faster 6 | than Spark MLlib v1.6.0. 
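Both `YggdrasilClassifier` and `YggdrasilRegressor` extend Spark ML's `Predictor`, so training and prediction follow the standard estimator/transformer lifecycle. The sketch below is illustrative only: the `training`/`test` DataFrames, column names, and parameter values are placeholders rather than part of this repository, and the label column is assumed to have been indexed (e.g. with `StringIndexer`) so that the number of classes is available in its metadata. Dependency setup is described under Usage below.

```scala
import org.apache.spark.ml.tree.impl.YggdrasilClassifier

// Configure the estimator exactly as you would Spark ML's DecisionTreeClassifier.
val dt = new YggdrasilClassifier()
  .setFeaturesCol("indexedFeatures")
  .setLabelCol("indexedLabel")
  .setMaxDepth(15)

// Standard lifecycle: fit returns a DecisionTreeClassificationModel,
// and transform appends prediction columns to the input DataFrame.
val model = dt.fit(training)
val predictions = model.transform(test)
```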
7 | 8 | ## Usage 9 | 10 | Add the dependency to your SBT project by including the following in `build.sbt` 11 | (see the [Spark Packages 12 | listing](http://spark-packages.org/package/fabuzaid21/yggdrasil) for 13 | spark-submit and Maven instructions): 14 | 15 | ```scala 16 | resolvers += "Spark Packages Repo" at "http://dl.bintray.com/spark-packages/maven" 17 | 18 | libraryDependencies += "fabuzaid21" % "yggdrasil" % "1.0" 19 | ``` 20 | 21 | Then use Yggdrasil as follows: 22 | 23 | ```scala 24 | import org.apache.spark.ml.tree.impl.YggdrasilClassifier // YggdrasilRegressor 25 | 26 | // Identical to the Spark MLlib Decision Tree API 27 | val dt = new YggdrasilClassifier() 28 | .setFeaturesCol("indexedFeatures") 29 | .setLabelCol(labelColName) 30 | .setMaxDepth(params.maxDepth) 31 | .setMaxBins(params.maxBins) 32 | .setMinInstancesPerNode(params.minInstancesPerNode) 33 | .setMinInfoGain(params.minInfoGain) 34 | .setCacheNodeIds(params.cacheNodeIds) 35 | .setCheckpointInterval(params.checkpointInterval) 36 | ``` 37 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | name := "Yggdrasil" 2 | 3 | version := "1.0.1" 4 | 5 | scalaVersion := "2.10.4" 6 | 7 | spName := "fabuzaid21/yggdrasil" 8 | 9 | sparkVersion := "1.6.0" 10 | 11 | sparkComponents ++= Seq("mllib", "sql") 12 | 13 | libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.4" % "test" 14 | 15 | spShortDescription := "Yggdrasil: Faster Decision Trees Using Column Partitioning in Spark" 16 | 17 | spDescription := """Yggdrasil is a more efficient way in [Apache Spark](http://spark.apache.org) 18 | | to train decision trees for large depths and datasets with a 19 | | high number of features. For depths greater than 10, Yggdrasil is an order 20 | | of magnitude faster than Spark MLlib v1.6.0.""".stripMargin 21 | 22 | // You must have an Open Source License. Some common licenses can be found in: http://opensource.org/licenses 23 | licenses += "Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0") 24 | 25 | credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials") -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/" 2 | 3 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.4") 4 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/tree/impl/Yggdrasil.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.tree.impl 19 | 20 | import org.apache.spark.Logging 21 | import org.apache.spark.broadcast.Broadcast 22 | import org.apache.spark.ml.Predictor 23 | import org.apache.spark.ml.classification.DecisionTreeClassificationModel 24 | import org.apache.spark.ml.param.ParamMap 25 | import org.apache.spark.ml.regression.DecisionTreeRegressionModel 26 | import org.apache.spark.ml.tree.impl.YggdrasilUtil._ 27 | import org.apache.spark.ml.tree.{ygg, Node => SparkNode, _} 28 | import org.apache.spark.ml.util.{Identifiable, MetadataUtils} 29 | import org.apache.spark.mllib.linalg.Vector 30 | import org.apache.spark.mllib.regression.LabeledPoint 31 | import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} 32 | import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata 33 | import org.apache.spark.mllib.tree.impurity._ 34 | import org.apache.spark.mllib.tree.model.{ImpurityStats, Predict} 35 | import org.apache.spark.rdd.RDD 36 | import org.apache.spark.sql.DataFrame 37 | import org.apache.spark.util.collection.{BitSet, SortDataFormat, Sorter} 38 | import org.roaringbitmap.RoaringBitmap 39 | 40 | final class YggdrasilClassifier(override val uid: String) 41 | extends Predictor[Vector, YggdrasilClassifier, DecisionTreeClassificationModel] 42 | with DecisionTreeParams with TreeClassifierParams { 43 | 44 | def this() = this(Identifiable.randomUID("yggc")) 45 | 46 | // Override parameter setters from parent trait for Java API compatibility. 47 | override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) 48 | 49 | override def setMaxBins(value: Int): this.type = super.setMaxBins(value) 50 | 51 | override def setMinInstancesPerNode(value: Int): this.type = 52 | super.setMinInstancesPerNode(value) 53 | 54 | override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) 55 | 56 | override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) 57 | 58 | override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) 59 | 60 | override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) 61 | 62 | override def setImpurity(value: String): this.type = super.setImpurity(value) 63 | 64 | override def setSeed(value: Long): this.type = super.setSeed(value) 65 | 66 | override def copy(extra: ParamMap): YggdrasilClassifier = defaultCopy(extra) 67 | 68 | def train( 69 | input: RDD[LabeledPoint], 70 | transposedDataset: RDD[(Int, Array[Double])], 71 | categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = { 72 | 73 | val numClasses: Int = input.map(_.label).distinct().count().toInt 74 | val strategy = getOldStrategy(categoricalFeatures, numClasses) 75 | val model = Yggdrasil.train(input, strategy, Some(transposedDataset), parentUID = Some(uid)) 76 | model.asInstanceOf[DecisionTreeClassificationModel] 77 | } 78 | 79 | override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { 80 | val categoricalFeatures: Map[Int, Int] = 81 | MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) 82 | val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { 83 | case Some(n: Int) => n 84 | case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" + 85 | s" with invalid label column ${$(labelCol)}, without the number of 
classes" + 86 | " specified. See StringIndexer.") 87 | // TODO: Automatically index labels: SPARK-7126 88 | } 89 | 90 | val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) 91 | val strategy = getOldStrategy(categoricalFeatures, numClasses) 92 | val model = Yggdrasil.train(oldDataset, strategy, colStoreInput = None, parentUID = Some(uid)) 93 | model.asInstanceOf[DecisionTreeClassificationModel] 94 | } 95 | 96 | /** Create a Strategy instance to use with the old API. */ 97 | private[impl] def getOldStrategy(categoricalFeatures: Map[Int, Int], numClasses: Int): OldStrategy = { 98 | super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity, 99 | subsamplingRate = 1.0) 100 | } 101 | } 102 | 103 | final class YggdrasilRegressor(override val uid: String) 104 | extends Predictor[Vector, YggdrasilRegressor, DecisionTreeRegressionModel] 105 | with DecisionTreeParams with TreeRegressorParams { 106 | 107 | def this() = this(Identifiable.randomUID("yggr")) 108 | 109 | // Override parameter setters from parent trait for Java API compatibility. 110 | override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value) 111 | 112 | override def setMaxBins(value: Int): this.type = super.setMaxBins(value) 113 | 114 | override def setMinInstancesPerNode(value: Int): this.type = 115 | super.setMinInstancesPerNode(value) 116 | 117 | override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value) 118 | 119 | override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value) 120 | 121 | override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value) 122 | 123 | override def setCheckpointInterval(value: Int): this.type = super.setCheckpointInterval(value) 124 | 125 | override def setImpurity(value: String): this.type = super.setImpurity(value) 126 | 127 | override def setSeed(value: Long): this.type = super.setSeed(value) 128 | 129 | override def copy(extra: ParamMap): YggdrasilRegressor = defaultCopy(extra) 130 | 131 | def train( 132 | input: RDD[LabeledPoint], 133 | transposedDataset: RDD[(Int, Array[Double])], 134 | categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = { 135 | val strategy = getOldStrategy(categoricalFeatures) 136 | val model = Yggdrasil.train(input, strategy, Some(transposedDataset), parentUID = Some(uid)) 137 | model.asInstanceOf[DecisionTreeRegressionModel] 138 | } 139 | 140 | override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { 141 | val categoricalFeatures: Map[Int, Int] = 142 | MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) 143 | val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) 144 | val strategy = getOldStrategy(categoricalFeatures) 145 | val model = Yggdrasil.train(oldDataset, strategy, colStoreInput = None, parentUID = Some(uid)) 146 | model.asInstanceOf[DecisionTreeRegressionModel] 147 | } 148 | 149 | /** Create a Strategy instance to use with the old API. */ 150 | private[impl] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = { 151 | super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity, 152 | subsamplingRate = 1.0) 153 | } 154 | } 155 | 156 | /** 157 | * DecisionTree which partitions data by feature. 158 | * 159 | * Algorithm: 160 | * - Repartition data, grouping by feature. 161 | * - Prep data (sort continuous features). 162 | * - On each partition, initialize instance--node map with each instance at root node. 
163 | * - Iterate, training 1 new level of the tree at a time: 164 | * - On each partition, for each feature on the partition, select the best split for each node. 165 | * - Aggregate best split for each node. 166 | * - Aggregate bit vector (1 bit/instance) indicating whether each instance splits 167 | * left or right. 168 | * - Broadcast bit vector. On each partition, update instance--node map. 169 | * 170 | * TODO: Update to use a sparse column store. 171 | */ 172 | private[ml] object Yggdrasil extends Logging { 173 | 174 | private[impl] class YggdrasilMetadata( 175 | val numClasses: Int, 176 | val maxBins: Int, 177 | val minInfoGain: Double, 178 | val impurity: Impurity, 179 | val categoricalFeaturesInfo: Map[Int, Int]) extends Serializable { 180 | 181 | private val unorderedSplits = { 182 | /** 183 | * borrowed from [[DecisionTreeMetadata.buildMetadata]] 184 | */ 185 | if (numClasses > 2) { 186 | // Multiclass classification 187 | val maxCategoriesForUnorderedFeature = 188 | ((math.log(maxBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt 189 | categoricalFeaturesInfo.filter { case (featureIndex, numCategories) => 190 | numCategories > 1 && numCategories <= maxCategoriesForUnorderedFeature 191 | }.map { case (featureIndex, numCategories) => 192 | // Hack: If a categorical feature has only 1 category, we treat it as continuous. 193 | // TODO(SPARK-9957): Handle this properly by filtering out those features. 194 | // Decide if some categorical features should be treated as unordered features, 195 | // which require 2 * ((1 << numCategories - 1) - 1) bins. 196 | // We do this check with log values to prevent overflows in case numCategories is large. 197 | // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins 198 | featureIndex -> findSplits(featureIndex, numCategories) 199 | } 200 | } else { 201 | Map.empty[Int, Array[ygg.CategoricalSplit]] 202 | } 203 | } 204 | 205 | /** 206 | * Returns all possible subsets of features for categorical splits. 
207 | * Borrowed from [[RandomForest.findSplits]] 208 | */ 209 | private def findSplits( 210 | featureIndex: Int, 211 | featureArity: Int): Array[ygg.CategoricalSplit] = { 212 | // Unordered features 213 | // 2^(featureArity - 1) - 1 combinations 214 | val numSplits = (1 << (featureArity - 1)) - 1 215 | val splits = new Array[ygg.CategoricalSplit](numSplits) 216 | 217 | var splitIndex = 0 218 | while (splitIndex < numSplits) { 219 | val categories: List[Double] = 220 | RandomForest.extractMultiClassCategories(splitIndex + 1, featureArity) 221 | splits(splitIndex) = 222 | new ygg.CategoricalSplit(featureIndex, categories.toArray, featureArity) 223 | splitIndex += 1 224 | } 225 | splits 226 | } 227 | 228 | def getUnorderedSplits(featureIndex: Int): Array[ygg.CategoricalSplit] = unorderedSplits(featureIndex) 229 | 230 | def isClassification: Boolean = numClasses >= 2 231 | 232 | def isMulticlass: Boolean = numClasses > 2 233 | 234 | def isUnorderedFeature(featureIndex: Int): Boolean = unorderedSplits.contains(featureIndex) 235 | 236 | def createImpurityAggregator(): ImpurityAggregatorSingle = { 237 | impurity match { 238 | case Entropy => new EntropyAggregatorSingle(numClasses) 239 | case Gini => new GiniAggregatorSingle(numClasses) 240 | case Variance => new VarianceAggregatorSingle 241 | } 242 | } 243 | } 244 | 245 | private[impl] object YggdrasilMetadata { 246 | def fromStrategy(strategy: OldStrategy): YggdrasilMetadata = new YggdrasilMetadata(strategy.numClasses, 247 | strategy.maxBins, strategy.minInfoGain, strategy.impurity, strategy.categoricalFeaturesInfo) 248 | } 249 | 250 | /** 251 | * Method to train a decision tree model over an RDD. 252 | */ 253 | def train( 254 | input: RDD[LabeledPoint], 255 | strategy: OldStrategy, 256 | colStoreInput: Option[RDD[(Int, Array[Double])]] = None, 257 | parentUID: Option[String] = None): DecisionTreeModel = { 258 | // TODO: Check validity of params 259 | // TODO: Check for empty dataset 260 | val numFeatures = input.first().features.size 261 | val rootNode = trainImpl(input, strategy, colStoreInput) 262 | finalizeTree(rootNode, strategy.algo, strategy.numClasses, numFeatures, 263 | parentUID) 264 | } 265 | 266 | private[impl] def finalizeTree( 267 | rootNode: SparkNode, 268 | algo: OldAlgo.Algo, 269 | numClasses: Int, 270 | numFeatures: Int, 271 | parentUID: Option[String]): DecisionTreeModel = { 272 | parentUID match { 273 | case Some(uid) => 274 | if (algo == OldAlgo.Classification) { 275 | new DecisionTreeClassificationModel(uid, rootNode, numFeatures = numFeatures, 276 | numClasses = numClasses) 277 | } else { 278 | new DecisionTreeRegressionModel(uid, rootNode, numFeatures = numFeatures) 279 | } 280 | case None => 281 | if (algo == OldAlgo.Classification) { 282 | new DecisionTreeClassificationModel(rootNode, numFeatures = numFeatures, 283 | numClasses = numClasses) 284 | } else { 285 | new DecisionTreeRegressionModel(rootNode, numFeatures = numFeatures) 286 | } 287 | } 288 | } 289 | 290 | private[impl] def getPredict(impurityCalculator: ImpurityCalculator): Predict = { 291 | val pred = impurityCalculator.predict 292 | new Predict(predict = pred, prob = impurityCalculator.prob(pred)) 293 | } 294 | 295 | private[impl] def trainImpl( 296 | input: RDD[LabeledPoint], 297 | strategy: OldStrategy, 298 | colStoreInput: Option[RDD[(Int, Array[Double])]]): SparkNode = { 299 | val metadata = YggdrasilMetadata.fromStrategy(strategy) 300 | 301 | // The case with 1 node (depth = 0) is handled separately. 
302 | // This allows all iterations in the depth > 0 case to use the same code. 303 | // TODO: Check that learning works when maxDepth > 0 but learning stops at 1 node (because of 304 | // other parameters). 305 | if (strategy.maxDepth == 0) { 306 | val impurityAggregator: ImpurityAggregatorSingle = 307 | input.aggregate(metadata.createImpurityAggregator())( 308 | (agg, lp) => agg.update(lp.label, 1.0), 309 | (agg1, agg2) => agg1.add(agg2)) 310 | val impurityCalculator = impurityAggregator.getCalculator 311 | return new LeafNode(getPredict(impurityCalculator).predict, impurityCalculator.calculate(), 312 | impurityCalculator) 313 | } 314 | 315 | // Prepare column store. 316 | // Note: rowToColumnStoreDense checks to make sure numRows < Int.MaxValue. 317 | // TODO: Is this mapping from arrays to iterators to arrays (when constructing learningData)? 318 | // Or is the mapping implicit (i.e., not costly)? 319 | val colStoreInit: RDD[(Int, Array[Double])] = colStoreInput.getOrElse( 320 | rowToColumnStoreDense(input.map(_.features))) 321 | val numRows: Int = colStoreInit.first()._2.length 322 | if (metadata.numClasses > 1 && metadata.numClasses <= 32) { 323 | YggdrasilClassification.trainImpl(input, colStoreInit, metadata, numRows, strategy.maxDepth) 324 | } else { 325 | YggdrasilRegression.trainImpl(input, colStoreInit, metadata, numRows, strategy.maxDepth) 326 | } 327 | } 328 | 329 | 330 | /** 331 | * On driver: Grow tree based on chosen splits, and compute new set of active nodes. 332 | * @param oldPeriphery Old periphery of active nodes. 333 | * @param bestSplitsAndGains Best (split, gain) pairs, which can be zipped with the old 334 | * periphery. These stats will be used to replace the stats in 335 | * any nodes which are split. 336 | * @param minInfoGain Threshold for min info gain required to split a node. 337 | * @return New active node periphery. 338 | * If a node is split, then this method will update its fields. 339 | */ 340 | private[impl] def computeActiveNodePeriphery( 341 | oldPeriphery: Array[ygg.LearningNode], 342 | bestSplitsAndGains: Array[(Option[ygg.Split], ImpurityStats)], 343 | minInfoGain: Double): Array[ygg.LearningNode] = { 344 | bestSplitsAndGains.zipWithIndex.flatMap { 345 | case ((split, stats), nodeIdx) => 346 | val node = oldPeriphery(nodeIdx) 347 | if (split.nonEmpty && stats.gain > minInfoGain) { 348 | // TODO: remove node id 349 | node.leftChild = Some(ygg.LearningNode(node.id * 2, isLeaf = false, 350 | new ImpurityStats(Double.NaN, stats.leftImpurity, stats.leftImpurityCalculator, 351 | null, null, true))) 352 | node.rightChild = Some(ygg.LearningNode(node.id * 2 + 1, isLeaf = false, 353 | new ImpurityStats(Double.NaN, stats.rightImpurity, stats.rightImpurityCalculator, 354 | null, null, true))) 355 | node.split = split 356 | node.isLeaf = false 357 | node.stats = stats 358 | Iterator(node.leftChild.get, node.rightChild.get) 359 | } else { 360 | node.isLeaf = true 361 | Iterator() 362 | } 363 | } 364 | } 365 | 366 | /** 367 | * Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right. 368 | * - Send chosen splits to workers. 369 | * - Each worker creates part of the bit vector corresponding to the splits it created. 370 | * - Aggregate the partial bit vectors to create one vector (of length numRows). 371 | * Correction: Aggregate only the pieces of that vector corresponding to instances at 372 | * active nodes. 
373 | * @param partitionInfos RDD with feature data, plus current status metadata 374 | * @param bestSplits Split for each active node, or None if that node will not be split 375 | * @return Array of bit vectors, ordered by offset ranges 376 | */ 377 | private[impl] def aggregateBitVector( 378 | partitionInfos: RDD[PartitionInfo], 379 | bestSplits: Array[Option[ygg.Split]], 380 | numRows: Int): RoaringBitmap = { 381 | val bestSplitsBc: Broadcast[Array[Option[ygg.Split]]] = 382 | partitionInfos.sparkContext.broadcast(bestSplits) 383 | val workerBitSubvectors: RDD[RoaringBitmap] = partitionInfos.map { 384 | case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], 385 | activeNodes: BitSet, fullImpurities: Array[ImpurityAggregatorSingle]) => 386 | val localBestSplits: Array[Option[ygg.Split]] = bestSplitsBc.value 387 | // localFeatureIndex[feature index] = index into PartitionInfo.columns 388 | val localFeatureIndex: Map[Int, Int] = columns.map(_.featureIndex).zipWithIndex.toMap 389 | val bitSetForNodes: Iterator[RoaringBitmap] = activeNodes.iterator 390 | .zip(localBestSplits.iterator).flatMap { 391 | case (nodeIndexInLevel: Int, Some(split: ygg.Split)) => 392 | if (localFeatureIndex.contains(split.featureIndex)) { 393 | // This partition has the column (feature) used for this split. 394 | val fromOffset = nodeOffsets(nodeIndexInLevel) 395 | val toOffset = nodeOffsets(nodeIndexInLevel + 1) 396 | val colIndex: Int = localFeatureIndex(split.featureIndex) 397 | Iterator(bitVectorFromSplit(columns(colIndex), fromOffset, toOffset, split, numRows)) 398 | } else { 399 | Iterator() 400 | } 401 | case (nodeIndexInLevel: Int, None) => 402 | // Do not create a bitVector when there is no split. 403 | // PartitionInfo.update will detect that there is no 404 | // split, by how many instances go left/right. 405 | Iterator() 406 | } 407 | if (bitSetForNodes.isEmpty) { 408 | new RoaringBitmap() 409 | } else { 410 | bitSetForNodes.reduce[RoaringBitmap] { (acc, bitv) => acc.or(bitv); acc } 411 | } 412 | } 413 | val aggBitVector: RoaringBitmap = workerBitSubvectors.reduce { (acc, bitv) => 414 | acc.or(bitv) 415 | acc 416 | } 417 | bestSplitsBc.unpersist() 418 | aggBitVector 419 | } 420 | 421 | /** 422 | * For a given feature, for a given node, apply a split and return a bit vector indicating the 423 | * outcome of the split for each instance at that node. 424 | * 425 | * @param col Column for feature 426 | * @param from Start offset in col for the node 427 | * @param to End offset in col for the node 428 | * @param split Split to apply to instances at this node. 429 | * @return Bits indicating splits for instances at this node. 430 | * These bits are sorted by the row indices, in order to guarantee an ordering 431 | * understood by all workers. 432 | * Thus, the bit indices used are based on 2-level sorting: first by node, and 433 | * second by sorted row indices within the node's rows. 434 | * bit[index in sorted array of row indices] = false for left, true for right 435 | */ 436 | private[impl] def bitVectorFromSplit( 437 | col: FeatureVector, 438 | from: Int, 439 | to: Int, 440 | split: ygg.Split, 441 | numRows: Int): RoaringBitmap = { 442 | val bitv = new RoaringBitmap() 443 | var i = from 444 | while (i < to) { 445 | val value = col.values(i) 446 | val idx = col.indices(i) 447 | if (!split.shouldGoLeft(value)) { 448 | bitv.add(idx) 449 | } 450 | i += 1 451 | } 452 | bitv 453 | } 454 | 455 | /** 456 | * Intermediate data stored on each partition during learning. 
457 | * 458 | * Node indexing for nodeOffsets, activeNodes: 459 | * Nodes are indexed left-to-right along the periphery of the tree, with 0-based indices. 460 | * The periphery is the set of leaf nodes (active and inactive). 461 | * 462 | * @param columns Subset of columns (features) stored in this partition. 463 | * Each column is sorted first by nodes (left-to-right along the tree periphery); 464 | * all columns share this first level of sorting. 465 | * Within each node's group, each column is sorted based on feature value; 466 | * this second level of sorting differs across columns. 467 | * @param nodeOffsets Offsets into the columns indicating the first level of sorting (by node). 468 | * The rows corresponding to node i are in the range 469 | * [nodeOffsets(i), nodeOffsets(i+1)). 470 | * @param activeNodes Nodes which are active (still being split). 471 | * Inactive nodes are known to be leafs in the final tree. 472 | * TODO: Should this (and even nodeOffsets) not be stored in PartitionInfo, 473 | * but instead on the driver? 474 | */ 475 | private[impl] case class PartitionInfo( 476 | columns: Array[FeatureVector], 477 | nodeOffsets: Array[Int], 478 | activeNodes: BitSet, 479 | fullImpurityAggs: Array[ImpurityAggregatorSingle]) extends Serializable { 480 | 481 | // pre-allocated temporary buffers that we use to sort 482 | // instances in left and right children during update 483 | val tempVals: Array[Double] = new Array[Double](columns(0).values.length) 484 | val tempIndices: Array[Int] = new Array[Int](columns(0).values.length) 485 | 486 | /** For debugging */ 487 | override def toString: String = { 488 | "PartitionInfo(" + 489 | " columns: {\n" + 490 | columns.mkString(",\n") + 491 | " },\n" + 492 | s" nodeOffsets: ${nodeOffsets.mkString(", ")},\n" + 493 | s" activeNodes: ${activeNodes.iterator.mkString(", ")},\n" + 494 | ")\n" 495 | } 496 | 497 | /** 498 | * Update columns and nodeOffsets for the next level of the tree. 499 | * 500 | * Update columns: 501 | * For each column, 502 | * For each (previously) active node, 503 | * Sort corresponding range of instances based on bit vector. 504 | * Update nodeOffsets, activeNodes: 505 | * Split offsets for nodes which split (which can be identified using the bit vector). 506 | * 507 | * @param instanceBitVector Bit vector encoding splits for the next level of the tree. 508 | * These must follow a 2-level ordering, where the first level is by node 509 | * and the second level is by row index. 510 | * bitVector(i) = false iff instance i goes to the left child. 511 | * For instances at inactive (leaf) nodes, the value can be arbitrary. 512 | * @return Updated partition info 513 | */ 514 | def update(instanceBitVector: BitSet, newNumNodeOffsets: Int, 515 | labels: Array[Byte], metadata: YggdrasilMetadata): PartitionInfo = { 516 | // Create a 2-level representation of the new nodeOffsets (to be flattened). 517 | // These 2 levels correspond to original nodes and their children (if split). 518 | val newNodeOffsets = nodeOffsets.map(Array(_)) 519 | val newFullImpurityAggs = fullImpurityAggs.map(Array(_)) 520 | 521 | val newColumns = columns.zipWithIndex.map { case (col, index) => 522 | index match { 523 | case 0 => first(col, instanceBitVector, metadata, labels, newNodeOffsets, newFullImpurityAggs) 524 | case _ => rest(col, instanceBitVector, newNodeOffsets) 525 | } 526 | col 527 | } 528 | 529 | // Identify the new activeNodes based on the 2-level representation of the new nodeOffsets. 
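// A node that was split has a two-element offset array (start of the left child's range, start of the right child's range), so both of its children are marked active below; an unsplit node keeps its single offset and stays inactive (it becomes a leaf).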
530 | val newActiveNodes = new BitSet(newNumNodeOffsets - 1) 531 | var newNodeOffsetsIdx = 0 532 | var i = 0 533 | while (i < newNodeOffsets.length) { 534 | val offsets = newNodeOffsets(i) 535 | if (offsets.length == 2) { 536 | newActiveNodes.set(newNodeOffsetsIdx) 537 | newActiveNodes.set(newNodeOffsetsIdx + 1) 538 | newNodeOffsetsIdx += 2 539 | } else { 540 | newNodeOffsetsIdx += 1 541 | } 542 | i += 1 543 | } 544 | PartitionInfo(newColumns, newNodeOffsets.flatten, newActiveNodes, newFullImpurityAggs.flatten) 545 | } 546 | 547 | 548 | /** 549 | * Sort the very first column in the [[PartitionInfo.columns]]. While 550 | * we sort the column, we also update [[PartitionInfo.nodeOffsets]] 551 | * (by modifying @param newNodeOffsets) and [[PartitionInfo.fullImpurityAggs]] 552 | * (by modifying @param newFullImpurityAggs). 553 | * @param col The very first column in [[PartitionInfo.columns]] 554 | * @param metadata Used to create new [[ImpurityAggregatorSingle]] for a new child 555 | * node in the tree 556 | * @param labels Labels are read as we sort column to populate stats for each 557 | * new ImpurityAggregatorSingle 558 | */ 559 | private def first( 560 | col: FeatureVector, 561 | instanceBitVector: BitSet, 562 | metadata: YggdrasilMetadata, 563 | labels: Array[Byte], 564 | newNodeOffsets: Array[Array[Int]], 565 | newFullImpurityAggs: Array[Array[ImpurityAggregatorSingle]]) = { 566 | activeNodes.iterator.foreach { nodeIdx => 567 | // WHAT TO OPTIMIZE: 568 | // - try skipping numBitsSet 569 | // - maybe uncompress bitmap 570 | val from = nodeOffsets(nodeIdx) 571 | val to = nodeOffsets(nodeIdx + 1) 572 | 573 | // If this is the very first time we split, 574 | // we don't use rangeIndices to count the number of bits set; 575 | // the entire bit vector will be used, so getCardinality 576 | // will give us the same result more cheaply. 577 | val numBitsSet = { 578 | if (nodeOffsets.length == 2) instanceBitVector.cardinality() 579 | else { 580 | var count = 0 581 | var i = from 582 | while (i < to) { 583 | if (instanceBitVector.get(col.indices(i))) { 584 | count += 1 585 | } 586 | i += 1 587 | } 588 | count 589 | } 590 | } 591 | 592 | val numBitsNotSet = to - from - numBitsSet // number of instances splitting left 593 | val oldOffset = newNodeOffsets(nodeIdx).head 594 | 595 | // If numBitsNotSet or numBitsSet equals 0, then this node was not split, 596 | // so we do not need to update its part of the column. Otherwise, we update it. 597 | if (numBitsNotSet != 0 && numBitsSet != 0) { 598 | newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) 599 | 600 | val leftImpurity = metadata.createImpurityAggregator() 601 | val rightImpurity = metadata.createImpurityAggregator() 602 | 603 | // BEGIN SORTING 604 | // We sort the [from, to) slice of col based on instance bit, then 605 | // instance value. This is required to match the bit vector across all 606 | // workers. All instances going "left" in the split (which are false) 607 | // should be ordered before the instances going "right". The instanceBitVector 608 | // gives us the bit value for each instance based on the instance's index. 609 | // Then both [from, numBitsNotSet) and [numBitsNotSet, to) need to be sorted 610 | // by value. 611 | // Since the column is already sorted by value, we can compute 612 | // this sort in a single pass over the data. 
We iterate from start to finish 613 | // (which preserves the sorted order), and then copy the values 614 | // into @tempVals and @tempIndices either: 615 | // 1) in the [from, numBitsNotSet) range if the bit is false, or 616 | // 2) in the [numBitsNotSet, to) range if the bit is true. 617 | var (leftInstanceIdx, rightInstanceIdx) = (from, from + numBitsNotSet) 618 | var idx = from 619 | while (idx < to) { 620 | val indexForVal = col.indices(idx) 621 | val bit = instanceBitVector.get(indexForVal) 622 | val label = labels(indexForVal) 623 | if (bit) { 624 | rightImpurity.update(label) 625 | tempVals(rightInstanceIdx) = col.values(idx) 626 | tempIndices(rightInstanceIdx) = indexForVal 627 | rightInstanceIdx += 1 628 | } else { 629 | leftImpurity.update(label) 630 | tempVals(leftInstanceIdx) = col.values(idx) 631 | tempIndices(leftInstanceIdx) = indexForVal 632 | leftInstanceIdx += 1 633 | } 634 | idx += 1 635 | } 636 | // END SORTING 637 | 638 | newFullImpurityAggs(nodeIdx) = Array(leftImpurity, rightImpurity) 639 | // update the column values and indices 640 | // with the corresponding indices 641 | System.arraycopy(tempVals, from, col.values, from, to - from) 642 | System.arraycopy(tempIndices, from, col.indices, from, to - from) 643 | } 644 | } 645 | } 646 | 647 | /** 648 | * Update columns and nodeOffsets for the next level of the tree. 649 | * 650 | * Update columns: 651 | * For each column, 652 | * For each (previously) active node, 653 | * Sort corresponding range of instances based on bit vector. 654 | * Update nodeOffsets, activeNodes: 655 | * Split offsets for nodes which split (which can be identified using the bit vector). 656 | * 657 | * @param instanceBitVector Bit vector encoding splits for the next level of the tree. 658 | * These must follow a 2-level ordering, where the first level is by node 659 | * and the second level is by row index. 660 | * bitVector(i) = false iff instance i goes to the left child. 661 | * For instances at inactive (leaf) nodes, the value can be arbitrary. 662 | * @return Updated partition info 663 | */ 664 | def update(instanceBitVector: BitSet, newNumNodeOffsets: Int, 665 | labels: Array[Double], metadata: YggdrasilMetadata): PartitionInfo = { 666 | // Create a 2-level representation of the new nodeOffsets (to be flattened). 667 | // These 2 levels correspond to original nodes and their children (if split). 668 | val newNodeOffsets = nodeOffsets.map(Array(_)) 669 | val newFullImpurityAggs = fullImpurityAggs.map(Array(_)) 670 | 671 | val newColumns = columns.zipWithIndex.map { case (col, index) => 672 | index match { 673 | case 0 => first(col, instanceBitVector, metadata, labels, newNodeOffsets, newFullImpurityAggs) 674 | case _ => rest(col, instanceBitVector, newNodeOffsets) 675 | } 676 | col 677 | } 678 | 679 | // Identify the new activeNodes based on the 2-level representation of the new nodeOffsets. 680 | val newActiveNodes = new BitSet(newNumNodeOffsets - 1) 681 | var newNodeOffsetsIdx = 0 682 | var i = 0 683 | while (i < newNodeOffsets.length) { 684 | val offsets = newNodeOffsets(i) 685 | if (offsets.length == 2) { 686 | newActiveNodes.set(newNodeOffsetsIdx) 687 | newActiveNodes.set(newNodeOffsetsIdx + 1) 688 | newNodeOffsetsIdx += 2 689 | } else { 690 | newNodeOffsetsIdx += 1 691 | } 692 | i += 1 693 | } 694 | PartitionInfo(newColumns, newNodeOffsets.flatten, newActiveNodes, newFullImpurityAggs.flatten) 695 | } 696 | 697 | 698 | /** 699 | * Sort the very first column in the [[PartitionInfo.columns]]. 
While 700 | * we sort the column, we also update [[PartitionInfo.nodeOffsets]] 701 | * (by modifying @param newNodeOffsets) and [[PartitionInfo.fullImpurityAggs]] 702 | * (by modifying @param newFullImpurityAggs). 703 | * @param col The very first column in [[PartitionInfo.columns]] 704 | * @param metadata Used to create new [[ImpurityAggregatorSingle]] for a new child 705 | * node in the tree 706 | * @param labels Labels are read as we sort column to populate stats for each 707 | * new ImpurityAggregatorSingle 708 | */ 709 | private def first( 710 | col: FeatureVector, 711 | instanceBitVector: BitSet, 712 | metadata: YggdrasilMetadata, 713 | labels: Array[Double], 714 | newNodeOffsets: Array[Array[Int]], 715 | newFullImpurityAggs: Array[Array[ImpurityAggregatorSingle]]) = { 716 | activeNodes.iterator.foreach { nodeIdx => 717 | // WHAT TO OPTIMIZE: 718 | // - try skipping numBitsSet 719 | // - maybe uncompress bitmap 720 | val from = nodeOffsets(nodeIdx) 721 | val to = nodeOffsets(nodeIdx + 1) 722 | 723 | // If this is the very first time we split, 724 | // we don't use rangeIndices to count the number of bits set; 725 | // the entire bit vector will be used, so getCardinality 726 | // will give us the same result more cheaply. 727 | val numBitsSet = { 728 | if (nodeOffsets.length == 2) instanceBitVector.cardinality() 729 | else { 730 | var count = 0 731 | var i = from 732 | while (i < to) { 733 | if (instanceBitVector.get(col.indices(i))) { 734 | count += 1 735 | } 736 | i += 1 737 | } 738 | count 739 | } 740 | } 741 | 742 | val numBitsNotSet = to - from - numBitsSet // number of instances splitting left 743 | val oldOffset = newNodeOffsets(nodeIdx).head 744 | 745 | // If numBitsNotSet or numBitsSet equals 0, then this node was not split, 746 | // so we do not need to update its part of the column. Otherwise, we update it. 747 | if (numBitsNotSet != 0 && numBitsSet != 0) { 748 | newNodeOffsets(nodeIdx) = Array(oldOffset, oldOffset + numBitsNotSet) 749 | 750 | val leftImpurity = metadata.createImpurityAggregator() 751 | val rightImpurity = metadata.createImpurityAggregator() 752 | 753 | // BEGIN SORTING 754 | // We sort the [from, to) slice of col based on instance bit, then 755 | // instance value. This is required to match the bit vector across all 756 | // workers. All instances going "left" in the split (which are false) 757 | // should be ordered before the instances going "right". The instanceBitVector 758 | // gives us the bit value for each instance based on the instance's index. 759 | // Then both [from, numBitsNotSet) and [numBitsNotSet, to) need to be sorted 760 | // by value. 761 | // Since the column is already sorted by value, we can compute 762 | // this sort in a single pass over the data. We iterate from start to finish 763 | // (which preserves the sorted order), and then copy the values 764 | // into @tempVals and @tempIndices either: 765 | // 1) in the [from, numBitsNotSet) range if the bit is false, or 766 | // 2) in the [numBitsNotSet, to) range if the bit is true. 
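// Concretely, left-going values are written to the [from, from + numBitsNotSet) range of the temp buffers and right-going values to [from + numBitsNotSet, to), before being copied back into the column.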
767 | var (leftInstanceIdx, rightInstanceIdx) = (from, from + numBitsNotSet) 768 | var idx = from 769 | while (idx < to) { 770 | val indexForVal = col.indices(idx) 771 | val bit = instanceBitVector.get(indexForVal) 772 | val label = labels(indexForVal) 773 | if (bit) { 774 | rightImpurity.update(label) 775 | tempVals(rightInstanceIdx) = col.values(idx) 776 | tempIndices(rightInstanceIdx) = indexForVal 777 | rightInstanceIdx += 1 778 | } else { 779 | leftImpurity.update(label) 780 | tempVals(leftInstanceIdx) = col.values(idx) 781 | tempIndices(leftInstanceIdx) = indexForVal 782 | leftInstanceIdx += 1 783 | } 784 | idx += 1 785 | } 786 | // END SORTING 787 | 788 | newFullImpurityAggs(nodeIdx) = Array(leftImpurity, rightImpurity) 789 | // update the column values and indices 790 | // with the corresponding indices 791 | System.arraycopy(tempVals, from, col.values, from, to - from) 792 | System.arraycopy(tempIndices, from, col.indices, from, to - from) 793 | } 794 | } 795 | } 796 | 797 | /** 798 | * Sort the remaining columns in the [[PartitionInfo.columns]]. Since 799 | * we already computed [[PartitionInfo.nodeOffsets]] and 800 | * [[PartitionInfo.fullImpurityAggs]] while we sorted the first column, 801 | * we skip the computation for those here. 802 | * @param col The very first column in [[PartitionInfo.columns]] 803 | * @param newNodeOffsets Instead of re-computing number of bits set/not set 804 | * per split, we read those values from here 805 | */ 806 | private def rest( 807 | col: FeatureVector, 808 | instanceBitVector: BitSet, 809 | newNodeOffsets: Array[Array[Int]]) = { 810 | activeNodes.iterator.foreach { nodeIdx => 811 | val from = nodeOffsets(nodeIdx) 812 | val to = nodeOffsets(nodeIdx + 1) 813 | val newOffsets = newNodeOffsets(nodeIdx) 814 | 815 | // We determined that this node was split in first() 816 | if (newOffsets.length == 2) { 817 | val numBitsNotSet = newOffsets(1) - newOffsets(0) 818 | 819 | // Same as above, but we don't compute the left and right impurities for 820 | // the resulitng child nodes 821 | var (leftInstanceIdx, rightInstanceIdx) = (from, from + numBitsNotSet) 822 | var idx = from 823 | while (idx < to) { 824 | val indexForVal = col.indices(idx) 825 | val bit = instanceBitVector.get(indexForVal) 826 | if (bit) { 827 | tempVals(rightInstanceIdx) = col.values(idx) 828 | tempIndices(rightInstanceIdx) = indexForVal 829 | rightInstanceIdx += 1 830 | } else { 831 | tempVals(leftInstanceIdx) = col.values(idx) 832 | tempIndices(leftInstanceIdx) = indexForVal 833 | leftInstanceIdx += 1 834 | } 835 | idx += 1 836 | } 837 | 838 | System.arraycopy(tempVals, from, col.values, from, to - from) 839 | System.arraycopy(tempIndices, from, col.indices, from, to - from) 840 | } 841 | } 842 | } 843 | 844 | } 845 | 846 | /** 847 | * Feature vector types are based on (feature type, representation). 848 | * The feature type can be continuous or categorical. 849 | * 850 | * Features are sorted by value, so we must store indices + values. 851 | * These values are currently stored in a dense representation only. 852 | * TODO: Support sparse storage (to optimize deeper levels of the tree), and maybe compressed 853 | * storage (to optimize upper levels of the tree). 854 | * @param featureArity For categorical features, this gives the number of categories. 855 | * For continuous features, this should be set to 0. 
856 | */ 857 | private[impl] class FeatureVector( 858 | val featureIndex: Int, 859 | val featureArity: Int, 860 | val values: Array[Double], 861 | val indices: Array[Int]) 862 | extends Serializable { 863 | 864 | def isCategorical: Boolean = featureArity > 0 865 | 866 | /** For debugging */ 867 | override def toString: String = { 868 | " FeatureVector(" + 869 | s" featureIndex: $featureIndex,\n" + 870 | s" featureType: ${if (featureArity == 0) "Continuous" else "Categorical"},\n" + 871 | s" featureArity: $featureArity,\n" + 872 | s" values: ${values.mkString(", ")},\n" + 873 | s" indices: ${indices.mkString(", ")},\n" + 874 | " )" 875 | } 876 | 877 | def deepCopy(): FeatureVector = 878 | new FeatureVector(featureIndex, featureArity, values.clone(), indices.clone()) 879 | 880 | override def equals(other: Any): Boolean = { 881 | other match { 882 | case o: FeatureVector => 883 | featureIndex == o.featureIndex && featureArity == o.featureArity && 884 | values.sameElements(o.values) && indices.sameElements(o.indices) 885 | case _ => false 886 | } 887 | } 888 | } 889 | 890 | private[impl] object FeatureVector { 891 | /** Store column sorted by feature values. */ 892 | def fromOriginal( 893 | featureIndex: Int, 894 | featureArity: Int, 895 | values: Array[Double]): FeatureVector = { 896 | val indices = values.indices.toArray 897 | val fv = new FeatureVector(featureIndex, featureArity, values, indices) 898 | val sorter = new Sorter(new FeatureVectorSortByValue(featureIndex, featureArity)) 899 | sorter.sort(fv, 0, values.length, Ordering[KeyWrapper]) 900 | fv 901 | } 902 | } 903 | 904 | /** 905 | * Sort FeatureVector by values column; @see [[FeatureVector.fromOriginal()]] 906 | * @param featureIndex @param featureArity Passed in so that, if a new 907 | * FeatureVector is allocated during sorting, that new object 908 | * also has the same featureIndex and featureArity 909 | */ 910 | private class FeatureVectorSortByValue(featureIndex: Int, featureArity: Int) 911 | extends SortDataFormat[KeyWrapper, FeatureVector] { 912 | 913 | override def newKey(): KeyWrapper = new KeyWrapper() 914 | 915 | override def getKey(data: FeatureVector, 916 | pos: Int, 917 | reuse: KeyWrapper): KeyWrapper = { 918 | if (reuse == null) { 919 | new KeyWrapper().setKey(data.values(pos)) 920 | } else { 921 | reuse.setKey(data.values(pos)) 922 | } 923 | } 924 | 925 | override def getKey(data: FeatureVector, 926 | pos: Int): KeyWrapper = { 927 | getKey(data, pos, null) 928 | } 929 | 930 | private def swapElements(data: Array[Double], 931 | pos0: Int, 932 | pos1: Int): Unit = { 933 | val tmp = data(pos0) 934 | data(pos0) = data(pos1) 935 | data(pos1) = tmp 936 | } 937 | 938 | private def swapElements(data: Array[Int], 939 | pos0: Int, 940 | pos1: Int): Unit = { 941 | val tmp = data(pos0) 942 | data(pos0) = data(pos1) 943 | data(pos1) = tmp 944 | } 945 | 946 | override def swap(data: FeatureVector, pos0: Int, pos1: Int): Unit = { 947 | swapElements(data.values, pos0, pos1) 948 | swapElements(data.indices, pos0, pos1) 949 | } 950 | 951 | override def copyRange(src: FeatureVector, 952 | srcPos: Int, 953 | dst: FeatureVector, 954 | dstPos: Int, 955 | length: Int): Unit = { 956 | System.arraycopy(src.values, srcPos, dst.values, dstPos, length) 957 | System.arraycopy(src.indices, srcPos, dst.indices, dstPos, length) 958 | } 959 | 960 | override def allocate(length: Int): FeatureVector = { 961 | new FeatureVector(featureIndex, featureArity, new Array[Double](length), new Array[Int](length)) 962 | } 963 | 964 | override def 
copyElement(src: FeatureVector, 965 | srcPos: Int, 966 | dst: FeatureVector, 967 | dstPos: Int): Unit = { 968 | dst.values(dstPos) = src.values(srcPos) 969 | dst.indices(dstPos) = src.indices(srcPos) 970 | } 971 | } 972 | 973 | /** 974 | * A wrapper that holds a primitive key – borrowed from [[org.apache.spark.ml.recommendation.ALS]] 975 | */ 976 | private class KeyWrapper extends Ordered[KeyWrapper] { 977 | 978 | var key: Double = _ 979 | 980 | override def compare(that: KeyWrapper): Int = { 981 | scala.math.Ordering.Double.compare(key, that.key) 982 | } 983 | 984 | def setKey(key: Double): this.type = { 985 | this.key = key 986 | this 987 | } 988 | } 989 | } 990 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/tree/impl/YggdrasilClassification.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.tree.impl 2 | 3 | import org.apache.spark.broadcast.Broadcast 4 | import org.apache.spark.ml.tree.impl.Yggdrasil.{FeatureVector, PartitionInfo, YggdrasilMetadata} 5 | import org.apache.spark.ml.tree.{ImpurityAggregatorSingle, ygg, Node => SparkNode} 6 | import org.apache.spark.mllib.regression.LabeledPoint 7 | import org.apache.spark.mllib.tree.model.ImpurityStats 8 | import org.apache.spark.rdd.RDD 9 | import org.apache.spark.serializer.KryoSerializer 10 | import org.apache.spark.storage.StorageLevel 11 | import org.apache.spark.util.collection.BitSet 12 | import org.roaringbitmap.RoaringBitmap 13 | 14 | object YggdrasilClassification { 15 | 16 | def trainImpl( 17 | input: RDD[LabeledPoint], 18 | colStoreInit: RDD[(Int, Array[Double])], 19 | metadata: YggdrasilMetadata, 20 | numRows: Int, 21 | maxDepth: Int): SparkNode = { 22 | 23 | val labels = new Array[Byte](numRows) 24 | input.map(_.label).zipWithIndex().collect().foreach { case (label: Double, rowIndex: Long) => 25 | labels(rowIndex.toInt) = label.toByte 26 | } 27 | val labelsBc = input.sparkContext.broadcast(labels) 28 | // NOTE: Labels are not sorted with features since that would require 1 copy per feature, 29 | // rather than 1 copy per worker. This means a lot of random accesses. 30 | // We could improve this by applying first-level sorting (by node) to labels. 31 | 32 | // Sort each column by feature values. 33 | val colStore: RDD[FeatureVector] = colStoreInit.map { case (featureIndex, col) => 34 | val featureArity: Int = metadata.categoricalFeaturesInfo.getOrElse(featureIndex, 0) 35 | FeatureVector.fromOriginal(featureIndex, featureArity, col) 36 | } 37 | // Group columns together into one array of columns per partition. 38 | // TODO: Test avoiding this grouping, and see if it matters. 39 | val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { 40 | iterator: Iterator[FeatureVector] => 41 | if (iterator.nonEmpty) Iterator(iterator.toArray) else Iterator() 42 | } 43 | groupedColStore.persist(StorageLevel.MEMORY_AND_DISK) 44 | 45 | // Initialize partitions with 1 node (each instance at the root node). 46 | val fullImpurityAgg = metadata.createImpurityAggregator() 47 | var i = 0 48 | while (i < labels.length) { 49 | fullImpurityAgg.update(labels(i)) 50 | i += 1 51 | } 52 | var partitionInfos: RDD[PartitionInfo] = groupedColStore.map { groupedCols => 53 | val initActive = new BitSet(1) 54 | initActive.set(0) 55 | 56 | new PartitionInfo(groupedCols, Array[Int](0, numRows), initActive, Array(fullImpurityAgg)) 57 | } 58 | 59 | // Initialize model. 
60 | // Note: We do not use node indices. 61 | val rootNode = ygg.LearningNode.emptyNode(1) // TODO: remove node id 62 | // Active nodes (still being split), updated each iteration 63 | var activeNodePeriphery: Array[ygg.LearningNode] = Array(rootNode) 64 | var numNodeOffsets: Int = 2 65 | 66 | // Iteratively learn, one level of the tree at a time. 67 | var currentLevel = 0 68 | var doneLearning = false 69 | while (currentLevel < maxDepth && !doneLearning) { 70 | // Compute best split for each active node. 71 | val bestSplitsAndGains: Array[(Option[ygg.Split], ImpurityStats)] = 72 | computeBestSplits(partitionInfos, labelsBc, metadata) 73 | /* 74 | // NOTE: The actual active nodes (activeNodePeriphery) may be a subset of the nodes under 75 | // bestSplitsAndGains since 76 | assert(activeNodePeriphery.length == bestSplitsAndGains.length, 77 | s"activeNodePeriphery.length=${activeNodePeriphery.length} does not equal" + 78 | s" bestSplitsAndGains.length=${bestSplitsAndGains.length}") 79 | */ 80 | 81 | // Update current model and node periphery. 82 | // Note: This flatMap has side effects (on the model). 83 | activeNodePeriphery = 84 | Yggdrasil.computeActiveNodePeriphery(activeNodePeriphery, bestSplitsAndGains, metadata.minInfoGain) 85 | // We keep all old nodeOffsets and add one for each node split. 86 | // Each node split adds 2 nodes to activeNodePeriphery. 87 | // TODO: Should this be calculated after filtering for impurity?? 88 | numNodeOffsets = numNodeOffsets + activeNodePeriphery.length / 2 89 | 90 | // Filter active node periphery by impurity. 91 | val estimatedRemainingActive = activeNodePeriphery.count(_.stats.impurity > 0.0) 92 | 93 | // TODO: Check to make sure we split something, and stop otherwise. 94 | doneLearning = currentLevel + 1 >= maxDepth || estimatedRemainingActive == 0 95 | 96 | if (!doneLearning) { 97 | val splits: Array[Option[ygg.Split]] = bestSplitsAndGains.map(_._1) 98 | 99 | // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right 100 | val aggBitVector: RoaringBitmap = Yggdrasil.aggregateBitVector(partitionInfos, splits, numRows) 101 | val bv = new BitSet(numRows) 102 | val iter = aggBitVector.getIntIterator 103 | while(iter.hasNext) { 104 | bv.set(iter.next) 105 | } 106 | val ser = new KryoSerializer(input.sparkContext.getConf).newInstance() 107 | val buf = ser.serialize(aggBitVector) 108 | val buf2 = ser.serialize(bv) 109 | println(s"currentLevel: $currentLevel, RoaringBitmap num bytes: ${buf.remaining()}, BitSet num bytes: ${buf2.remaining()}") 110 | val newPartitionInfos = partitionInfos.map { partitionInfo => 111 | partitionInfo.update(bv, numNodeOffsets, labelsBc.value, metadata) 112 | } 113 | // TODO: remove. For some reason, this is needed to make things work. 114 | // Probably messing up somewhere above... 115 | newPartitionInfos.cache().count() 116 | partitionInfos = newPartitionInfos 117 | 118 | } 119 | currentLevel += 1 120 | } 121 | 122 | // Done with learning 123 | groupedColStore.unpersist() 124 | labelsBc.unpersist() 125 | rootNode.toSparkNode 126 | } 127 | 128 | /** 129 | * Find the best splits for all active nodes. 130 | * - On each partition, for each feature on the partition, select the best split for each node. 131 | * Each worker returns: For each active node, best split + info gain 132 | * - The splits across workers are aggregated to the driver. 
133 | * @return Array over active nodes of (best split, impurity stats for split), 134 | * where the split is None if no useful split exists 135 | */ 136 | private[impl] def computeBestSplits( 137 | partitionInfos: RDD[PartitionInfo], 138 | labelsBc: Broadcast[Array[Byte]], 139 | metadata: YggdrasilMetadata) = { 140 | // On each partition, for each feature on the partition, select the best split for each node. 141 | // This will use: 142 | // - groupedColStore (the features) 143 | // - partitionInfos (the node -> instance mapping) 144 | // - labelsBc (the labels column) 145 | // Each worker returns: 146 | // for each active node, best split + info gain, 147 | // where the best split is None if no useful split exists 148 | val partBestSplitsAndGains: RDD[Array[(Option[ygg.Split], ImpurityStats)]] = partitionInfos.map { 149 | case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], 150 | activeNodes: BitSet, fullImpurityAggs: Array[ImpurityAggregatorSingle]) => 151 | val localLabels = labelsBc.value 152 | // Iterate over the active nodes in the current level. 153 | val toReturn = new Array[(Option[ygg.Split], ImpurityStats)](activeNodes.cardinality()) 154 | val iter: Iterator[Int] = activeNodes.iterator 155 | var i = 0 156 | while (iter.hasNext) { 157 | val nodeIndexInLevel = iter.next 158 | val fromOffset = nodeOffsets(nodeIndexInLevel) 159 | val toOffset = nodeOffsets(nodeIndexInLevel + 1) 160 | val fullImpurityAgg = fullImpurityAggs(nodeIndexInLevel) 161 | val splitsAndStats = 162 | columns.map { col => 163 | chooseSplit(col, localLabels, fromOffset, toOffset, fullImpurityAgg, metadata) 164 | } 165 | toReturn(i) = splitsAndStats.maxBy(_._2.gain) 166 | i += 1 167 | } 168 | toReturn 169 | } 170 | 171 | // Aggregate best split for each active node. 172 | partBestSplitsAndGains.treeReduce { case (splitsGains1, splitsGains2) => 173 | splitsGains1.zip(splitsGains2).map { case ((split1, gain1), (split2, gain2)) => 174 | if (gain1.gain >= gain2.gain) { 175 | (split1, gain1) 176 | } else { 177 | (split2, gain2) 178 | } 179 | } 180 | } 181 | } 182 | 183 | /** 184 | * Choose the best split for a feature at a node. 185 | * TODO: Return null or None when the split is invalid, such as putting all instances on one 186 | * child node. 187 | * 188 | * @return (best split, statistics for split) If the best split actually puts all instances 189 | * in one leaf node, then it will be set to None. 190 | */ 191 | private[impl] def chooseSplit( 192 | col: FeatureVector, 193 | labels: Array[Byte], 194 | fromOffset: Int, 195 | toOffset: Int, 196 | fullImpurityAgg: ImpurityAggregatorSingle, 197 | metadata: YggdrasilMetadata): (Option[ygg.Split], ImpurityStats) = { 198 | if (col.isCategorical) { 199 | if (metadata.isUnorderedFeature(col.featureIndex)) { 200 | val splits: Array[ygg.CategoricalSplit] = metadata.getUnorderedSplits(col.featureIndex) 201 | chooseUnorderedCategoricalSplit(col.featureIndex, col.values, col.indices, labels, fromOffset, toOffset, 202 | metadata, col.featureArity, splits) 203 | } else { 204 | chooseOrderedCategoricalSplit(col.featureIndex, col.values, col.indices, labels, fromOffset, toOffset, 205 | metadata, col.featureArity) 206 | } 207 | } else { 208 | chooseContinuousSplit(col.featureIndex, col.values, col.indices, labels, fromOffset, toOffset, 209 | fullImpurityAgg, metadata) 210 | } 211 | } 212 | 213 | /** 214 | * Find the best split for an ordered categorical feature at a single node. 215 | * 216 | * Algorithm: 217 | * - For each category, compute a "centroid." 
218 | * - For multiclass classification, the centroid is the label impurity. 219 | * - For binary classification and regression, the centroid is the average label. 220 | * - Sort the centroids, and consider splits anywhere in this order. 221 | * Thus, with K categories, we consider K - 1 possible splits. 222 | * 223 | * @param featureIndex Index of feature being split. 224 | * @param values Feature values at this node. Sorted in increasing order. 225 | * @param labels Labels corresponding to values, in the same order. 226 | * @return (best split, statistics for split) If the best split actually puts all instances 227 | * in one leaf node, then it will be set to None. The impurity stats maybe still be 228 | * useful, so they are returned. 229 | */ 230 | private[impl] def chooseOrderedCategoricalSplit( 231 | featureIndex: Int, 232 | values: Array[Double], 233 | indices: Array[Int], 234 | labels: Array[Byte], 235 | from: Int, 236 | to: Int, 237 | metadata: YggdrasilMetadata, 238 | featureArity: Int): (Option[ygg.Split], ImpurityStats) = { 239 | // TODO: Support high-arity features by using a single array to hold the stats. 240 | 241 | // aggStats(category) = label statistics for category 242 | val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( 243 | _ => metadata.createImpurityAggregator()) 244 | var i = from 245 | while (i < to) { 246 | val cat = values(i) 247 | val label = labels(indices(i)) 248 | aggStats(cat.toInt).update(label) 249 | i += 1 250 | } 251 | 252 | // Compute centroids. centroidsForCategories is a list: (category, centroid) 253 | val centroidsForCategories: Seq[(Int, Double)] = if (metadata.isMulticlass) { 254 | // For categorical variables in multiclass classification, 255 | // the bins are ordered by the impurity of their corresponding labels. 256 | Range(0, featureArity).map { case featureValue => 257 | val categoryStats = aggStats(featureValue) 258 | val centroid = if (categoryStats.getCount != 0) { 259 | categoryStats.getCalculator.calculate() 260 | } else { 261 | Double.MaxValue 262 | } 263 | (featureValue, centroid) 264 | } 265 | } else if (metadata.isClassification) { // binary classification 266 | // For categorical variables in binary classification, 267 | // the bins are ordered by the centroid of their corresponding labels. 268 | Range(0, featureArity).map { case featureValue => 269 | val categoryStats = aggStats(featureValue) 270 | val centroid = if (categoryStats.getCount != 0) { 271 | assert(categoryStats.stats.length == 2) 272 | (categoryStats.stats(1) - categoryStats.stats(0)) / categoryStats.getCount 273 | } else { 274 | Double.MaxValue 275 | } 276 | (featureValue, centroid) 277 | } 278 | } else { // regression 279 | // For categorical variables in regression, 280 | // the bins are ordered by the centroid of their corresponding labels. 281 | Range(0, featureArity).map { case featureValue => 282 | val categoryStats = aggStats(featureValue) 283 | val centroid = if (categoryStats.getCount != 0) { 284 | categoryStats.getCalculator.predict 285 | } else { 286 | Double.MaxValue 287 | } 288 | (featureValue, centroid) 289 | } 290 | } 291 | 292 | val categoriesSortedByCentroid: List[Int] = centroidsForCategories.toList.sortBy(_._2).map(_._1) 293 | 294 | // Cumulative sums of bin statistics for left, right parts of split. 
295 | val leftImpurityAgg = metadata.createImpurityAggregator() 296 | val rightImpurityAgg = metadata.createImpurityAggregator() 297 | var j = 0 298 | val length = aggStats.length 299 | while (j < length) { 300 | rightImpurityAgg.add(aggStats(j)) 301 | j += 1 302 | } 303 | 304 | var bestSplitIndex: Int = -1 // index into categoriesSortedByCentroid 305 | val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() 306 | var bestGain: Double = 0.0 307 | val fullImpurity = rightImpurityAgg.getCalculator.calculate() 308 | var leftCount: Double = 0.0 309 | var rightCount: Double = rightImpurityAgg.getCount 310 | val fullCount: Double = rightCount 311 | 312 | // Consider all splits. These only cover valid splits, with at least one category on each side. 313 | val numSplits = categoriesSortedByCentroid.length - 1 314 | var sortedCatIndex = 0 315 | while (sortedCatIndex < numSplits) { 316 | val cat = categoriesSortedByCentroid(sortedCatIndex) 317 | // Update left, right stats 318 | val catStats = aggStats(cat) 319 | leftImpurityAgg.add(catStats) 320 | rightImpurityAgg.subtract(catStats) 321 | leftCount += catStats.getCount 322 | rightCount -= catStats.getCount 323 | // Compute impurity 324 | val leftWeight = leftCount / fullCount 325 | val rightWeight = rightCount / fullCount 326 | val leftImpurity = leftImpurityAgg.getCalculator.calculate() 327 | val rightImpurity = rightImpurityAgg.getCalculator.calculate() 328 | val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity 329 | if (leftCount != 0 && rightCount != 0 && gain > bestGain && gain > metadata.minInfoGain) { 330 | bestSplitIndex = sortedCatIndex 331 | System.arraycopy(leftImpurityAgg.stats, 0, bestLeftImpurityAgg.stats, 0, leftImpurityAgg.stats.length) 332 | bestGain = gain 333 | } 334 | sortedCatIndex += 1 335 | } 336 | 337 | val categoriesForSplit = 338 | categoriesSortedByCentroid.slice(0, bestSplitIndex + 1).map(_.toDouble) 339 | val bestFeatureSplit = 340 | new ygg.CategoricalSplit(featureIndex, categoriesForSplit.toArray, featureArity) 341 | val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg) 342 | val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg) 343 | val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator, 344 | bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator) 345 | 346 | if (bestSplitIndex == -1 || bestGain == 0.0) { 347 | (None, bestImpurityStats) 348 | } else { 349 | (Some(bestFeatureSplit), bestImpurityStats) 350 | } 351 | } 352 | 353 | /** 354 | * Find the best split for an unordered categorical feature at a single node. 355 | * 356 | * Algorithm: 357 | * - Considers all possible subsets (exponentially many) 358 | * 359 | * @param featureIndex Index of feature being split. 360 | * @param values Feature values at this node. Sorted in increasing order. 361 | * @param labels Labels corresponding to values, in the same order. 362 | * @return (best split, statistics for split) If the best split actually puts all instances 363 | * in one leaf node, then it will be set to None. The impurity stats maybe still be 364 | * useful, so they are returned. 
365 | */ 366 | private[impl] def chooseUnorderedCategoricalSplit( 367 | featureIndex: Int, 368 | values: Array[Double], 369 | indices: Array[Int], 370 | labels: Array[Byte], 371 | from: Int, 372 | to: Int, 373 | metadata: YggdrasilMetadata, 374 | featureArity: Int, 375 | splits: Array[ygg.CategoricalSplit]): (Option[ygg.Split], ImpurityStats) = { 376 | 377 | // Label stats for each category 378 | val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( 379 | _ => metadata.createImpurityAggregator()) 380 | var i = from 381 | while (i < to) { 382 | val cat = values(i) 383 | val label = labels(indices(i)) 384 | // NOTE: we assume the values for categorical features are Ints in [0,featureArity) 385 | aggStats(cat.toInt).update(label) 386 | i += 1 387 | } 388 | 389 | // Aggregated statistics for left part of split and entire split. 390 | val leftImpurityAgg = metadata.createImpurityAggregator() 391 | val fullImpurityAgg = metadata.createImpurityAggregator() 392 | aggStats.foreach(fullImpurityAgg.add) 393 | val fullImpurity = fullImpurityAgg.getCalculator.calculate() 394 | 395 | if (featureArity == 1) { 396 | // All instances go right 397 | val impurityStats = new ImpurityStats(0.0, fullImpurityAgg.getCalculator.calculate(), 398 | fullImpurityAgg.getCalculator, leftImpurityAgg.getCalculator, 399 | fullImpurityAgg.getCalculator) 400 | (None, impurityStats) 401 | } else { 402 | // TODO: We currently add and remove the stats for all categories for each split. 403 | // A better way to do it would be to consider splits in an order such that each iteration 404 | // only requires addition/removal of a single category and a single add/subtract to 405 | // leftCount and rightCount. 406 | // TODO: Use more efficient encoding such as gray codes 407 | var bestSplit: Option[ygg.CategoricalSplit] = None 408 | val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() 409 | var bestGain: Double = -1.0 410 | val fullCount: Double = to - from 411 | for (split <- splits) { 412 | // Update left, right impurity stats 413 | split.leftCategories.foreach(c => leftImpurityAgg.add(aggStats(c.toInt))) 414 | val rightImpurityAgg = fullImpurityAgg.deepCopy().subtract(leftImpurityAgg) 415 | val leftCount = leftImpurityAgg.getCount 416 | val rightCount = rightImpurityAgg.getCount 417 | // Compute impurity 418 | val leftWeight = leftCount / fullCount 419 | val rightWeight = rightCount / fullCount 420 | val leftImpurity = leftImpurityAgg.getCalculator.calculate() 421 | val rightImpurity = rightImpurityAgg.getCalculator.calculate() 422 | val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity 423 | if (leftCount != 0 && rightCount != 0 && gain > bestGain && gain > metadata.minInfoGain) { 424 | bestSplit = Some(split) 425 | System.arraycopy(leftImpurityAgg.stats, 0, bestLeftImpurityAgg.stats, 0, leftImpurityAgg.stats.length) 426 | bestGain = gain 427 | } 428 | // Reset left impurity stats 429 | leftImpurityAgg.clear() 430 | } 431 | 432 | val bestFeatureSplit = bestSplit match { 433 | case Some(split) => Some( 434 | new ygg.CategoricalSplit(featureIndex, split.leftCategories, featureArity)) 435 | case None => None 436 | 437 | } 438 | val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg) 439 | val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, 440 | fullImpurityAgg.getCalculator, bestLeftImpurityAgg.getCalculator, 441 | bestRightImpurityAgg.getCalculator) 442 | (bestFeatureSplit, bestImpurityStats) 443 | } 444 | } 445 | 446 | /** 447 | * Choose 
the splitting rule: feature value <= threshold 448 | * @return (best split, statistics for split) If the best split actually puts all instances 449 | * in one leaf node, then it will be set to None. The impurity stats may still be 450 | * useful, so they are returned. 451 | */ 452 | private[impl] def chooseContinuousSplit( 453 | featureIndex: Int, 454 | values: Array[Double], 455 | indices: Array[Int], 456 | labels: Array[Byte], 457 | from: Int, 458 | to: Int, 459 | fullImpurityAgg: ImpurityAggregatorSingle, 460 | metadata: YggdrasilMetadata): (Option[ygg.Split], ImpurityStats) = { 461 | 462 | val leftImpurityAgg = metadata.createImpurityAggregator() 463 | val rightImpurityAgg = fullImpurityAgg.deepCopy() 464 | 465 | var bestThreshold: Double = Double.NegativeInfinity 466 | val bestLeftImpurityAgg = metadata.createImpurityAggregator() 467 | var bestGain: Double = 0.0 468 | val fullImpurity = rightImpurityAgg.getCalculator.calculate() 469 | var leftCount: Int = 0 470 | var rightCount: Int = to - from 471 | val fullCount: Double = rightCount 472 | var currentThreshold = values.headOption.getOrElse(bestThreshold) 473 | var j = from 474 | while (j < to) { 475 | val value = values(j) 476 | val label = labels(indices(j)) 477 | if (value != currentThreshold) { 478 | // Check gain 479 | val leftWeight = leftCount / fullCount 480 | val rightWeight = rightCount / fullCount 481 | val leftImpurity = leftImpurityAgg.getCalculator.calculate() 482 | val rightImpurity = rightImpurityAgg.getCalculator.calculate() 483 | val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity 484 | if (leftCount != 0 && rightCount != 0 && gain > bestGain && gain > metadata.minInfoGain) { 485 | bestThreshold = currentThreshold 486 | System.arraycopy(leftImpurityAgg.stats, 0, bestLeftImpurityAgg.stats, 0, leftImpurityAgg.stats.length) 487 | bestGain = gain 488 | } 489 | currentThreshold = value 490 | } 491 | // Move this instance from right to left side of split. 
492 | leftImpurityAgg.update(label, 1) 493 | rightImpurityAgg.update(label, -1) 494 | leftCount += 1 495 | rightCount -= 1 496 | j += 1 497 | } 498 | 499 | val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg) 500 | val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator, 501 | bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator) 502 | val split = if (bestThreshold != Double.NegativeInfinity && bestThreshold != values.last) { 503 | Some(new ygg.ContinuousSplit(featureIndex, bestThreshold)) 504 | } else { 505 | None 506 | } 507 | (split, bestImpurityStats) 508 | } 509 | } 510 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/tree/impl/YggdrasilRegression.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.tree.impl 2 | 3 | import org.apache.spark.broadcast.Broadcast 4 | import org.apache.spark.ml.tree.{ImpurityAggregatorSingle, ygg, Node => SparkNode} 5 | import org.apache.spark.ml.tree.impl.Yggdrasil.{FeatureVector, PartitionInfo, YggdrasilMetadata} 6 | import org.apache.spark.mllib.regression.LabeledPoint 7 | import org.apache.spark.mllib.tree.model.ImpurityStats 8 | import org.apache.spark.rdd.RDD 9 | import org.apache.spark.storage.StorageLevel 10 | import org.apache.spark.util.collection.BitSet 11 | import org.roaringbitmap.RoaringBitmap 12 | 13 | object YggdrasilRegression { 14 | 15 | def trainImpl( 16 | input: RDD[LabeledPoint], 17 | colStoreInit: RDD[(Int, Array[Double])], 18 | metadata: YggdrasilMetadata, 19 | numRows: Int, 20 | maxDepth: Int): SparkNode = { 21 | 22 | val labels = new Array[Double](numRows) 23 | input.map(_.label).zipWithIndex().collect().foreach { case (label: Double, rowIndex: Long) => 24 | labels(rowIndex.toInt) = label.toDouble 25 | } 26 | val labelsBc = input.sparkContext.broadcast(labels) 27 | // NOTE: Labels are not sorted with features since that would require 1 copy per feature, 28 | // rather than 1 copy per worker. This means a lot of random accesses. 29 | // We could improve this by applying first-level sorting (by node) to labels. 30 | 31 | // Sort each column by feature values. 32 | val colStore: RDD[FeatureVector] = colStoreInit.map { case (featureIndex, col) => 33 | val featureArity: Int = metadata.categoricalFeaturesInfo.getOrElse(featureIndex, 0) 34 | FeatureVector.fromOriginal(featureIndex, featureArity, col) 35 | } 36 | // Group columns together into one array of columns per partition. 37 | // TODO: Test avoiding this grouping, and see if it matters. 38 | val groupedColStore: RDD[Array[FeatureVector]] = colStore.mapPartitions { 39 | iterator: Iterator[FeatureVector] => 40 | if (iterator.nonEmpty) Iterator(iterator.toArray) else Iterator() 41 | } 42 | groupedColStore.persist(StorageLevel.MEMORY_AND_DISK) 43 | 44 | // Initialize partitions with 1 node (each instance at the root node). 45 | val fullImpurityAgg = metadata.createImpurityAggregator() 46 | var i = 0 47 | while (i < labels.length) { 48 | fullImpurityAgg.update(labels(i)) 49 | i += 1 50 | } 51 | var partitionInfos: RDD[PartitionInfo] = groupedColStore.map { groupedCols => 52 | val initActive = new BitSet(1) 53 | initActive.set(0) 54 | 55 | new PartitionInfo(groupedCols, Array[Int](0, numRows), initActive, Array(fullImpurityAgg)) 56 | } 57 | 58 | // Initialize model. 59 | // Note: We do not use node indices. 
60 | val rootNode = ygg.LearningNode.emptyNode(1) // TODO: remove node id 61 | // Active nodes (still being split), updated each iteration 62 | var activeNodePeriphery: Array[ygg.LearningNode] = Array(rootNode) 63 | var numNodeOffsets: Int = 2 64 | 65 | // Iteratively learn, one level of the tree at a time. 66 | var currentLevel = 0 67 | var doneLearning = false 68 | while (currentLevel < maxDepth && !doneLearning) { 69 | // Compute best split for each active node. 70 | val bestSplitsAndGains: Array[(Option[ygg.Split], ImpurityStats)] = 71 | computeBestSplits(partitionInfos, labelsBc, metadata) 72 | /* 73 | // NOTE: The actual active nodes (activeNodePeriphery) may be a subset of the nodes under 74 | // bestSplitsAndGains since 75 | assert(activeNodePeriphery.length == bestSplitsAndGains.length, 76 | s"activeNodePeriphery.length=${activeNodePeriphery.length} does not equal" + 77 | s" bestSplitsAndGains.length=${bestSplitsAndGains.length}") 78 | */ 79 | 80 | // Update current model and node periphery. 81 | // Note: This flatMap has side effects (on the model). 82 | activeNodePeriphery = 83 | Yggdrasil.computeActiveNodePeriphery(activeNodePeriphery, bestSplitsAndGains, metadata.minInfoGain) 84 | // We keep all old nodeOffsets and add one for each node split. 85 | // Each node split adds 2 nodes to activeNodePeriphery. 86 | // TODO: Should this be calculated after filtering for impurity?? 87 | numNodeOffsets = numNodeOffsets + activeNodePeriphery.length / 2 88 | 89 | // Filter active node periphery by impurity. 90 | val estimatedRemainingActive = activeNodePeriphery.count(_.stats.impurity > 0.0) 91 | 92 | // TODO: Check to make sure we split something, and stop otherwise. 93 | doneLearning = currentLevel + 1 >= maxDepth || estimatedRemainingActive == 0 94 | 95 | if (!doneLearning) { 96 | val splits: Array[Option[ygg.Split]] = bestSplitsAndGains.map(_._1) 97 | 98 | // Aggregate bit vector (1 bit/instance) indicating whether each instance goes left/right 99 | val aggBitVector: RoaringBitmap = Yggdrasil.aggregateBitVector(partitionInfos, splits, numRows) 100 | val newPartitionInfos = partitionInfos.map { partitionInfo => 101 | val bv = new BitSet(numRows) 102 | val iter = aggBitVector.getIntIterator 103 | while(iter.hasNext) { 104 | bv.set(iter.next) 105 | } 106 | partitionInfo.update(bv, numNodeOffsets, labelsBc.value, metadata) 107 | } 108 | // TODO: remove. For some reason, this is needed to make things work. 109 | // Probably messing up somewhere above... 110 | newPartitionInfos.cache().count() 111 | partitionInfos = newPartitionInfos 112 | 113 | } 114 | currentLevel += 1 115 | } 116 | 117 | // Done with learning 118 | groupedColStore.unpersist() 119 | labelsBc.unpersist() 120 | rootNode.toSparkNode 121 | } 122 | 123 | /** 124 | * Find the best splits for all active nodes. 125 | * - On each partition, for each feature on the partition, select the best split for each node. 126 | * Each worker returns: For each active node, best split + info gain 127 | * - The splits across workers are aggregated to the driver. 128 | * @return Array over active nodes of (best split, impurity stats for split), 129 | * where the split is None if no useful split exists 130 | */ 131 | private[impl] def computeBestSplits( 132 | partitionInfos: RDD[PartitionInfo], 133 | labelsBc: Broadcast[Array[Double]], 134 | metadata: YggdrasilMetadata) = { 135 | // On each partition, for each feature on the partition, select the best split for each node. 
136 | // This will use: 137 | // - groupedColStore (the features) 138 | // - partitionInfos (the node -> instance mapping) 139 | // - labelsBc (the labels column) 140 | // Each worker returns: 141 | // for each active node, best split + info gain, 142 | // where the best split is None if no useful split exists 143 | val partBestSplitsAndGains: RDD[Array[(Option[ygg.Split], ImpurityStats)]] = partitionInfos.map { 144 | case PartitionInfo(columns: Array[FeatureVector], nodeOffsets: Array[Int], 145 | activeNodes: BitSet, fullImpurityAggs: Array[ImpurityAggregatorSingle]) => 146 | val localLabels = labelsBc.value 147 | // Iterate over the active nodes in the current level. 148 | val toReturn = new Array[(Option[ygg.Split], ImpurityStats)](activeNodes.cardinality()) 149 | val iter: Iterator[Int] = activeNodes.iterator 150 | var i = 0 151 | while (iter.hasNext) { 152 | val nodeIndexInLevel = iter.next 153 | val fromOffset = nodeOffsets(nodeIndexInLevel) 154 | val toOffset = nodeOffsets(nodeIndexInLevel + 1) 155 | val fullImpurityAgg = fullImpurityAggs(nodeIndexInLevel) 156 | val splitsAndStats = 157 | columns.map { col => 158 | chooseSplit(col, localLabels, fromOffset, toOffset, fullImpurityAgg, metadata) 159 | } 160 | toReturn(i) = splitsAndStats.maxBy(_._2.gain) 161 | i += 1 162 | } 163 | toReturn 164 | } 165 | 166 | // Aggregate best split for each active node. 167 | partBestSplitsAndGains.treeReduce { case (splitsGains1, splitsGains2) => 168 | splitsGains1.zip(splitsGains2).map { case ((split1, gain1), (split2, gain2)) => 169 | if (gain1.gain >= gain2.gain) { 170 | (split1, gain1) 171 | } else { 172 | (split2, gain2) 173 | } 174 | } 175 | } 176 | } 177 | 178 | /** 179 | * Choose the best split for a feature at a node. 180 | * TODO: Return null or None when the split is invalid, such as putting all instances on one 181 | * child node. 182 | * 183 | * @return (best split, statistics for split) If the best split actually puts all instances 184 | * in one leaf node, then it will be set to None. 185 | */ 186 | private[impl] def chooseSplit( 187 | col: FeatureVector, 188 | labels: Array[Double], 189 | fromOffset: Int, 190 | toOffset: Int, 191 | fullImpurityAgg: ImpurityAggregatorSingle, 192 | metadata: YggdrasilMetadata): (Option[ygg.Split], ImpurityStats) = { 193 | if (col.isCategorical) { 194 | if (metadata.isUnorderedFeature(col.featureIndex)) { 195 | val splits: Array[ygg.CategoricalSplit] = metadata.getUnorderedSplits(col.featureIndex) 196 | chooseUnorderedCategoricalSplit(col.featureIndex, col.values, col.indices, labels, fromOffset, toOffset, 197 | metadata, col.featureArity, splits) 198 | } else { 199 | chooseOrderedCategoricalSplit(col.featureIndex, col.values, col.indices, labels, fromOffset, toOffset, 200 | metadata, col.featureArity) 201 | } 202 | } else { 203 | chooseContinuousSplit(col.featureIndex, col.values, col.indices, labels, fromOffset, toOffset, 204 | fullImpurityAgg, metadata) 205 | } 206 | } 207 | 208 | /** 209 | * Find the best split for an ordered categorical feature at a single node. 210 | * 211 | * Algorithm: 212 | * - For each category, compute a "centroid." 213 | * - For multiclass classification, the centroid is the label impurity. 214 | * - For binary classification and regression, the centroid is the average label. 215 | * - Sort the centroids, and consider splits anywhere in this order. 216 | * Thus, with K categories, we consider K - 1 possible splits. 217 | * 218 | * @param featureIndex Index of feature being split. 
219 | * @param values Feature values at this node. Sorted in increasing order. 220 | * @param labels Labels corresponding to values, in the same order. 221 | * @return (best split, statistics for split) If the best split actually puts all instances 222 | * in one leaf node, then it will be set to None. The impurity stats maybe still be 223 | * useful, so they are returned. 224 | */ 225 | private[impl] def chooseOrderedCategoricalSplit( 226 | featureIndex: Int, 227 | values: Array[Double], 228 | indices: Array[Int], 229 | labels: Array[Double], 230 | from: Int, 231 | to: Int, 232 | metadata: YggdrasilMetadata, 233 | featureArity: Int): (Option[ygg.Split], ImpurityStats) = { 234 | // TODO: Support high-arity features by using a single array to hold the stats. 235 | 236 | // aggStats(category) = label statistics for category 237 | val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( 238 | _ => metadata.createImpurityAggregator()) 239 | var i = from 240 | while (i < to) { 241 | val cat = values(i) 242 | val label = labels(indices(i)) 243 | aggStats(cat.toInt).update(label) 244 | i += 1 245 | } 246 | 247 | // Compute centroids. centroidsForCategories is a list: (category, centroid) 248 | val centroidsForCategories: Seq[(Int, Double)] = if (metadata.isMulticlass) { 249 | // For categorical variables in multiclass classification, 250 | // the bins are ordered by the impurity of their corresponding labels. 251 | Range(0, featureArity).map { case featureValue => 252 | val categoryStats = aggStats(featureValue) 253 | val centroid = if (categoryStats.getCount != 0) { 254 | categoryStats.getCalculator.calculate() 255 | } else { 256 | Double.MaxValue 257 | } 258 | (featureValue, centroid) 259 | } 260 | } else if (metadata.isClassification) { // binary classification 261 | // For categorical variables in binary classification, 262 | // the bins are ordered by the centroid of their corresponding labels. 263 | Range(0, featureArity).map { case featureValue => 264 | val categoryStats = aggStats(featureValue) 265 | val centroid = if (categoryStats.getCount != 0) { 266 | assert(categoryStats.stats.length == 2) 267 | (categoryStats.stats(1) - categoryStats.stats(0)) / categoryStats.getCount 268 | } else { 269 | Double.MaxValue 270 | } 271 | (featureValue, centroid) 272 | } 273 | } else { // regression 274 | // For categorical variables in regression, 275 | // the bins are ordered by the centroid of their corresponding labels. 276 | Range(0, featureArity).map { case featureValue => 277 | val categoryStats = aggStats(featureValue) 278 | val centroid = if (categoryStats.getCount != 0) { 279 | categoryStats.getCalculator.predict 280 | } else { 281 | Double.MaxValue 282 | } 283 | (featureValue, centroid) 284 | } 285 | } 286 | 287 | val categoriesSortedByCentroid: List[Int] = centroidsForCategories.toList.sortBy(_._2).map(_._1) 288 | 289 | // Cumulative sums of bin statistics for left, right parts of split. 
290 | val leftImpurityAgg = metadata.createImpurityAggregator() 291 | val rightImpurityAgg = metadata.createImpurityAggregator() 292 | var j = 0 293 | val length = aggStats.length 294 | while (j < length) { 295 | rightImpurityAgg.add(aggStats(j)) 296 | j += 1 297 | } 298 | 299 | var bestSplitIndex: Int = -1 // index into categoriesSortedByCentroid 300 | val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() 301 | var bestGain: Double = 0.0 302 | val fullImpurity = rightImpurityAgg.getCalculator.calculate() 303 | var leftCount: Double = 0.0 304 | var rightCount: Double = rightImpurityAgg.getCount 305 | val fullCount: Double = rightCount 306 | 307 | // Consider all splits. These only cover valid splits, with at least one category on each side. 308 | val numSplits = categoriesSortedByCentroid.length - 1 309 | var sortedCatIndex = 0 310 | while (sortedCatIndex < numSplits) { 311 | val cat = categoriesSortedByCentroid(sortedCatIndex) 312 | // Update left, right stats 313 | val catStats = aggStats(cat) 314 | leftImpurityAgg.add(catStats) 315 | rightImpurityAgg.subtract(catStats) 316 | leftCount += catStats.getCount 317 | rightCount -= catStats.getCount 318 | // Compute impurity 319 | val leftWeight = leftCount / fullCount 320 | val rightWeight = rightCount / fullCount 321 | val leftImpurity = leftImpurityAgg.getCalculator.calculate() 322 | val rightImpurity = rightImpurityAgg.getCalculator.calculate() 323 | val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity 324 | if (leftCount != 0 && rightCount != 0 && gain > bestGain && gain > metadata.minInfoGain) { 325 | bestSplitIndex = sortedCatIndex 326 | System.arraycopy(leftImpurityAgg.stats, 0, bestLeftImpurityAgg.stats, 0, leftImpurityAgg.stats.length) 327 | bestGain = gain 328 | } 329 | sortedCatIndex += 1 330 | } 331 | 332 | val categoriesForSplit = 333 | categoriesSortedByCentroid.slice(0, bestSplitIndex + 1).map(_.toDouble) 334 | val bestFeatureSplit = 335 | new ygg.CategoricalSplit(featureIndex, categoriesForSplit.toArray, featureArity) 336 | val fullImpurityAgg = leftImpurityAgg.deepCopy().add(rightImpurityAgg) 337 | val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg) 338 | val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator, 339 | bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator) 340 | 341 | if (bestSplitIndex == -1 || bestGain == 0.0) { 342 | (None, bestImpurityStats) 343 | } else { 344 | (Some(bestFeatureSplit), bestImpurityStats) 345 | } 346 | } 347 | 348 | /** 349 | * Find the best split for an unordered categorical feature at a single node. 350 | * 351 | * Algorithm: 352 | * - Considers all possible subsets (exponentially many) 353 | * 354 | * @param featureIndex Index of feature being split. 355 | * @param values Feature values at this node. Sorted in increasing order. 356 | * @param labels Labels corresponding to values, in the same order. 357 | * @return (best split, statistics for split) If the best split actually puts all instances 358 | * in one leaf node, then it will be set to None. The impurity stats maybe still be 359 | * useful, so they are returned. 
360 | */ 361 | private[impl] def chooseUnorderedCategoricalSplit( 362 | featureIndex: Int, 363 | values: Array[Double], 364 | indices: Array[Int], 365 | labels: Array[Double], 366 | from: Int, 367 | to: Int, 368 | metadata: YggdrasilMetadata, 369 | featureArity: Int, 370 | splits: Array[ygg.CategoricalSplit]): (Option[ygg.Split], ImpurityStats) = { 371 | 372 | // Label stats for each category 373 | val aggStats = Array.tabulate[ImpurityAggregatorSingle](featureArity)( 374 | _ => metadata.createImpurityAggregator()) 375 | var i = from 376 | while (i < to) { 377 | val cat = values(i) 378 | val label = labels(indices(i)) 379 | // NOTE: we assume the values for categorical features are Ints in [0,featureArity) 380 | aggStats(cat.toInt).update(label) 381 | i += 1 382 | } 383 | 384 | // Aggregated statistics for left part of split and entire split. 385 | val leftImpurityAgg = metadata.createImpurityAggregator() 386 | val fullImpurityAgg = metadata.createImpurityAggregator() 387 | aggStats.foreach(fullImpurityAgg.add) 388 | val fullImpurity = fullImpurityAgg.getCalculator.calculate() 389 | 390 | if (featureArity == 1) { 391 | // All instances go right 392 | val impurityStats = new ImpurityStats(0.0, fullImpurityAgg.getCalculator.calculate(), 393 | fullImpurityAgg.getCalculator, leftImpurityAgg.getCalculator, 394 | fullImpurityAgg.getCalculator) 395 | (None, impurityStats) 396 | } else { 397 | // TODO: We currently add and remove the stats for all categories for each split. 398 | // A better way to do it would be to consider splits in an order such that each iteration 399 | // only requires addition/removal of a single category and a single add/subtract to 400 | // leftCount and rightCount. 401 | // TODO: Use more efficient encoding such as gray codes 402 | var bestSplit: Option[ygg.CategoricalSplit] = None 403 | val bestLeftImpurityAgg = leftImpurityAgg.deepCopy() 404 | var bestGain: Double = -1.0 405 | val fullCount: Double = to - from 406 | for (split <- splits) { 407 | // Update left, right impurity stats 408 | split.leftCategories.foreach(c => leftImpurityAgg.add(aggStats(c.toInt))) 409 | val rightImpurityAgg = fullImpurityAgg.deepCopy().subtract(leftImpurityAgg) 410 | val leftCount = leftImpurityAgg.getCount 411 | val rightCount = rightImpurityAgg.getCount 412 | // Compute impurity 413 | val leftWeight = leftCount / fullCount 414 | val rightWeight = rightCount / fullCount 415 | val leftImpurity = leftImpurityAgg.getCalculator.calculate() 416 | val rightImpurity = rightImpurityAgg.getCalculator.calculate() 417 | val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity 418 | if (leftCount != 0 && rightCount != 0 && gain > bestGain && gain > metadata.minInfoGain) { 419 | bestSplit = Some(split) 420 | System.arraycopy(leftImpurityAgg.stats, 0, bestLeftImpurityAgg.stats, 0, leftImpurityAgg.stats.length) 421 | bestGain = gain 422 | } 423 | // Reset left impurity stats 424 | leftImpurityAgg.clear() 425 | } 426 | 427 | val bestFeatureSplit = bestSplit match { 428 | case Some(split) => Some( 429 | new ygg.CategoricalSplit(featureIndex, split.leftCategories, featureArity)) 430 | case None => None 431 | 432 | } 433 | val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg) 434 | val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, 435 | fullImpurityAgg.getCalculator, bestLeftImpurityAgg.getCalculator, 436 | bestRightImpurityAgg.getCalculator) 437 | (bestFeatureSplit, bestImpurityStats) 438 | } 439 | } 440 | 441 | /** 442 | * Choose 
the splitting rule: feature value <= threshold 443 | * @return (best split, statistics for split) If the best split actually puts all instances 444 | * in one leaf node, then it will be set to None. The impurity stats may still be 445 | * useful, so they are returned. 446 | */ 447 | private[impl] def chooseContinuousSplit( 448 | featureIndex: Int, 449 | values: Array[Double], 450 | indices: Array[Int], 451 | labels: Array[Double], 452 | from: Int, 453 | to: Int, 454 | fullImpurityAgg: ImpurityAggregatorSingle, 455 | metadata: YggdrasilMetadata): (Option[ygg.Split], ImpurityStats) = { 456 | 457 | val leftImpurityAgg = metadata.createImpurityAggregator() 458 | val rightImpurityAgg = fullImpurityAgg.deepCopy() 459 | 460 | var bestThreshold: Double = Double.NegativeInfinity 461 | val bestLeftImpurityAgg = metadata.createImpurityAggregator() 462 | var bestGain: Double = 0.0 463 | val fullImpurity = rightImpurityAgg.getCalculator.calculate() 464 | var leftCount: Int = 0 465 | var rightCount: Int = to - from 466 | val fullCount: Double = rightCount 467 | var currentThreshold = values.headOption.getOrElse(bestThreshold) 468 | var j = from 469 | while (j < to) { 470 | val value = values(j) 471 | val label = labels(indices(j)) 472 | if (value != currentThreshold) { 473 | // Check gain 474 | val leftWeight = leftCount / fullCount 475 | val rightWeight = rightCount / fullCount 476 | val leftImpurity = leftImpurityAgg.getCalculator.calculate() 477 | val rightImpurity = rightImpurityAgg.getCalculator.calculate() 478 | val gain = fullImpurity - leftWeight * leftImpurity - rightWeight * rightImpurity 479 | if (leftCount != 0 && rightCount != 0 && gain > bestGain && gain > metadata.minInfoGain) { 480 | bestThreshold = currentThreshold 481 | System.arraycopy(leftImpurityAgg.stats, 0, bestLeftImpurityAgg.stats, 0, leftImpurityAgg.stats.length) 482 | bestGain = gain 483 | } 484 | currentThreshold = value 485 | } 486 | // Move this instance from right to left side of split. 487 | leftImpurityAgg.update(label, 1) 488 | rightImpurityAgg.update(label, -1) 489 | leftCount += 1 490 | rightCount -= 1 491 | j += 1 492 | } 493 | 494 | val bestRightImpurityAgg = fullImpurityAgg.deepCopy().subtract(bestLeftImpurityAgg) 495 | val bestImpurityStats = new ImpurityStats(bestGain, fullImpurity, fullImpurityAgg.getCalculator, 496 | bestLeftImpurityAgg.getCalculator, bestRightImpurityAgg.getCalculator) 497 | val split = if (bestThreshold != Double.NegativeInfinity && bestThreshold != values.last) { 498 | Some(new ygg.ContinuousSplit(featureIndex, bestThreshold)) 499 | } else { 500 | None 501 | } 502 | (split, bestImpurityStats) 503 | } 504 | } 505 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/tree/impl/YggdrasilUtil.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.tree.impl 19 | 20 | import scala.collection.mutable.ArrayBuffer 21 | 22 | import org.apache.spark.annotation.DeveloperApi 23 | import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} 24 | import org.apache.spark.rdd.RDD 25 | 26 | 27 | @DeveloperApi 28 | object YggdrasilUtil { 29 | 30 | /** 31 | * Convert a dataset of [[Vector]] from row storage to column storage. 32 | * This can take any [[Vector]] type but stores data as [[DenseVector]]. 33 | * 34 | * WARNING: This shuffles the ENTIRE dataset across the network, so it is a VERY EXPENSIVE 35 | * operation. This can also fail if 1 column is too large to fit on 1 partition. 36 | * 37 | * This maintains sparsity in the data. 38 | * 39 | * This maintains matrix structure. I.e., each partition of the output RDD holds adjacent 40 | * columns. The number of partitions will be min(input RDD's number of partitions, numColumns). 41 | * 42 | * @param rowStore The input vectors are data rows/instances. 43 | * @return RDD of (columnIndex, columnValues) pairs, 44 | * where each pair corresponds to one entire column. 45 | * If either dimension of the given data is 0, this returns an empty RDD. 46 | * If vector lengths do not match, this throws an exception. 47 | * 48 | * TODO: Add implementation for sparse data. 49 | * For sparse data, distribute more evenly based on number of non-zeros. 50 | * (First collect stats to decide how to partition.) 51 | * TODO: Move elsewhere in MLlib. 52 | */ 53 | def rowToColumnStoreDense(rowStore: RDD[Vector]): RDD[(Int, Array[Double])] = { 54 | 55 | val numRows = { 56 | val longNumRows: Long = rowStore.count() 57 | require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + 58 | s" but can handle at most ${Int.MaxValue} rows") 59 | longNumRows.toInt 60 | } 61 | if (numRows == 0) { 62 | return rowStore.sparkContext.parallelize(Seq.empty[(Int, Array[Double])]) 63 | } 64 | val numCols = rowStore.take(1)(0).size 65 | if (numCols == 0) { 66 | return rowStore.sparkContext.parallelize(Seq.empty[(Int, Array[Double])]) 67 | } 68 | 69 | val numSourcePartitions = rowStore.partitions.length 70 | val approxNumTargetPartitions = Math.min(numCols, numSourcePartitions) 71 | val maxColumnsPerPartition = Math.ceil(numCols / approxNumTargetPartitions.toDouble).toInt 72 | val numTargetPartitions = Math.ceil(numCols / maxColumnsPerPartition.toDouble).toInt 73 | 74 | def getNumColsInGroup(groupIndex: Int) = { 75 | if (groupIndex + 1 < numTargetPartitions) { 76 | maxColumnsPerPartition 77 | } else { 78 | numCols - (numTargetPartitions - 1) * maxColumnsPerPartition // last partition 79 | } 80 | } 81 | 82 | /* On each partition, re-organize into groups of columns: 83 | (groupIndex, (sourcePartitionIndex, partCols)), 84 | where partCols(colIdx) = partial column. 85 | The groupIndex will be used to groupByKey. 86 | The sourcePartitionIndex is used to ensure instance indices match up after the shuffle. 87 | The partial columns will be stacked into full columns after the shuffle. 
88 | Note: By design, partCols will always have at least 1 column. 89 | */ 90 | val partialColumns: RDD[(Int, (Int, Array[Array[Double]]))] = 91 | rowStore.mapPartitionsWithIndex { case (sourcePartitionIndex, iterator) => 92 | // columnSets(groupIndex)(colIdx) 93 | // = column values for each instance in sourcePartitionIndex, 94 | // where colIdx is a 0-based index for columns for groupIndex 95 | val columnSets = new Array[Array[ArrayBuffer[Double]]](numTargetPartitions) 96 | var groupIndex = 0 97 | while(groupIndex < numTargetPartitions) { 98 | columnSets(groupIndex) = 99 | Array.fill[ArrayBuffer[Double]](getNumColsInGroup(groupIndex))(ArrayBuffer[Double]()) 100 | groupIndex += 1 101 | } 102 | while (iterator.hasNext) { 103 | val row = iterator.next.toArray 104 | var groupIndex = 0 105 | while (groupIndex < numTargetPartitions) { 106 | val fromCol = groupIndex * maxColumnsPerPartition 107 | val numColsInTargetPartition = getNumColsInGroup(groupIndex) 108 | // TODO: match-case here on row as Dense or Sparse Vector (for speed) 109 | var colIdx = 0 110 | while (colIdx < numColsInTargetPartition) { 111 | columnSets(groupIndex)(colIdx) += row(fromCol + colIdx) 112 | colIdx += 1 113 | } 114 | groupIndex += 1 115 | } 116 | } 117 | Range(0, numTargetPartitions).map { groupIndex => 118 | (groupIndex, (sourcePartitionIndex, columnSets(groupIndex).map(_.toArray))) 119 | }.toIterator 120 | } 121 | 122 | // Shuffle data 123 | val groupedPartialColumns: RDD[(Int, Iterable[(Int, Array[Array[Double]])])] = 124 | partialColumns.groupByKey() 125 | 126 | // Each target partition now holds its set of columns. 127 | // Group the partial columns into full columns. 128 | val fullColumns = groupedPartialColumns.flatMap { case (groupIndex, iterable) => 129 | // We do not know the number of rows per group, so we need to collect the groups 130 | // before filling the full columns. 131 | val collectedPartCols = new Array[Array[Array[Double]]](numSourcePartitions) 132 | val iter = iterable.iterator 133 | while (iter.hasNext) { 134 | val (sourcePartitionIndex, partCols) = iter.next() 135 | collectedPartCols(sourcePartitionIndex) = partCols 136 | } 137 | val rowOffsets: Array[Int] = collectedPartCols.map(_(0).length).scanLeft(0)(_ + _) 138 | val numRows = rowOffsets.last 139 | // Initialize full columns 140 | val fromCol = groupIndex * maxColumnsPerPartition 141 | val numColumnsInPartition = getNumColsInGroup(groupIndex) 142 | val partitionColumns: Array[Array[Double]] = 143 | Array.fill[Array[Double]](numColumnsInPartition)(new Array[Double](numRows)) 144 | var colIdx = 0 // index within group 145 | while (colIdx < numColumnsInPartition) { 146 | var sourcePartitionIndex = 0 147 | while (sourcePartitionIndex < numSourcePartitions) { 148 | val partColLength = 149 | rowOffsets(sourcePartitionIndex + 1) - rowOffsets(sourcePartitionIndex) 150 | Array.copy(collectedPartCols(sourcePartitionIndex)(colIdx), 0, 151 | partitionColumns(colIdx), rowOffsets(sourcePartitionIndex), partColLength) 152 | sourcePartitionIndex += 1 153 | } 154 | colIdx += 1 155 | } 156 | val columnIndices = Range(0, numColumnsInPartition).map(_ + fromCol) 157 | // val columns = partitionColumns.map(Vectors.dense) 158 | columnIndices.zip(partitionColumns) 159 | } 160 | 161 | fullColumns 162 | } 163 | 164 | /** 165 | * This checks for an empty RDD (0 rows or 0 columns). 166 | * This will throw an exception if any columns have non-matching numbers of features. 167 | * @param rowStore Dataset of vectors which all have the same length (number of columns). 
168 | * @return Array over columns of the number of non-zero elements in each column. 169 | * Returns empty array if the RDD is empty. 170 | */ 171 | private def countNonZerosPerColumn(rowStore: RDD[Vector]): Array[Long] = { 172 | val firstRow = rowStore.take(1) 173 | if (firstRow.length == 0) { 174 | return Array.empty[Long] 175 | } 176 | val numCols = firstRow(0).size 177 | val colSizes: Array[Long] = rowStore.mapPartitions { iterator => 178 | val partColSizes = Array.fill[Long](numCols)(0) 179 | iterator.foreach { 180 | case dv: DenseVector => 181 | var col = 0 182 | while (col < dv.size) { 183 | if (dv(col) != 0.0) partColSizes(col) += 1 184 | col += 1 185 | } 186 | case sv: SparseVector => 187 | var k = 0 188 | while (k < sv.indices.length) { 189 | if (sv.values(k) != 0.0) partColSizes(sv.indices(k)) += 1 190 | k += 1 191 | } 192 | } 193 | Iterator(partColSizes) 194 | }.fold(Array.fill[Long](numCols)(0)){ 195 | case (v1, v2) => v1.zip(v2).map(v12 => v12._1 + v12._2) 196 | } 197 | colSizes 198 | } 199 | 200 | /** 201 | * The returned RDD sets the number of partitions as follows: 202 | * - The targeted number is: 203 | * numTargetPartitions = min(rowStore num partitions, num columns) * overPartitionFactor. 204 | * - The actual number will be in the range [numTargetPartitions, 2 * numTargetPartitions]. 205 | * Partitioning is done such that each partition holds consecutive columns. 206 | * 207 | * TODO: Update this to adaptively make columns dense or sparse based on a sparsity threshold. 208 | * 209 | * TODO: Cache rowStore temporarily. 210 | * 211 | * @param rowStore RDD of dataset rows 212 | * @param overPartitionFactor Multiplier for the targeted number of partitions. This parameter 213 | * helps to ensure that P partitions handled by P compute cores 214 | * do not get split into slightly more than P partitions; 215 | * if that occurred, then work would not be shared evenly. 216 | * @return RDD of (column index, column) pairs 217 | */ 218 | def rowToColumnStoreSparse( 219 | rowStore: RDD[Vector], 220 | overPartitionFactor: Int = 3): RDD[(Int, Vector)] = { 221 | 222 | val numRows = { 223 | val longNumRows: Long = rowStore.count() 224 | require(longNumRows < Int.MaxValue, s"rowToColumnStore given RDD with $longNumRows rows," + 225 | s" but can handle at most ${Int.MaxValue} rows") 226 | longNumRows.toInt 227 | } 228 | if (numRows == 0) { 229 | return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)]) 230 | } 231 | 232 | // Compute the number of non-zeros in each column. 233 | val colSizes: Array[Long] = countNonZerosPerColumn(rowStore) 234 | val numCols = colSizes.length 235 | val numSourcePartitions = rowStore.partitions.length 236 | if (numCols == 0 || numSourcePartitions == 0) { 237 | return rowStore.sparkContext.parallelize(Seq.empty[(Int, Vector)]) 238 | } 239 | val totalNonZeros = colSizes.sum 240 | 241 | // Split columns into groups. 242 | // Groups are chosen greedily and sequentially, putting as many columns as possible in each 243 | // group (limited by the number of non-zeros). Try to limit the number of non-zeros per 244 | // group to at most targetNonZerosPerPartition. 
245 | val numTargetPartitions = math.min(numSourcePartitions, numCols) * overPartitionFactor 246 | val targetNonZerosPerPartition = (totalNonZeros / numTargetPartitions.toDouble).floor.toLong 247 | val groupStartColumns: Array[Int] = { 248 | val startCols = new ArrayBuffer[Int]() 249 | startCols += 0 250 | var currentStartCol = 0 251 | var currentNonZeros: Long = 0 252 | var col = 0 253 | while (col < numCols) { 254 | if (currentNonZeros >= targetNonZerosPerPartition && col != startCols.last) { 255 | startCols += col 256 | currentStartCol = col 257 | currentNonZeros = 0 258 | } else { 259 | currentNonZeros += colSizes(col) 260 | } 261 | col += 1 262 | } 263 | startCols += numCols 264 | startCols.toArray 265 | } 266 | val numGroups = groupStartColumns.length - 1 // actual number of destination partitions 267 | 268 | /* On each partition, re-organize into groups of columns: 269 | (groupIndex, (sourcePartitionIndex, partCols)), 270 | where partCols(colIdx) = partial column. 271 | The groupIndex will be used to groupByKey. 272 | The sourcePartitionIndex is used to ensure instance indices match up after the shuffle. 273 | The partial columns will be stacked into full columns after the shuffle. 274 | Note: By design, partCols will always have at least 1 column. 275 | */ 276 | val partialColumns: RDD[(Int, (Int, Array[SparseVector]))] = 277 | rowStore.zipWithIndex().mapPartitionsWithIndex { case (sourcePartitionIndex, iterator) => 278 | type SparseVectorBuffer = (Int, ArrayBuffer[Int], ArrayBuffer[Double]) 279 | // columnSets(groupIndex)(colIdx) 280 | // = column values for each instance in sourcePartitionIndex, 281 | // where colIdx is a 0-based index for columns for groupIndex, 282 | // and where column values are in sparse format: (size, indices, values) 283 | val columnSetSizes = new Array[Array[Int]](numGroups) 284 | val columnSetIndices = new Array[Array[ArrayBuffer[Int]]](numGroups) 285 | val columnSetValues = new Array[Array[ArrayBuffer[Double]]](numGroups) 286 | var groupIndex = 0 287 | while (groupIndex < numGroups) { 288 | val numColsInGroup = groupStartColumns(groupIndex + 1) - groupStartColumns(groupIndex) 289 | columnSetSizes(groupIndex) = Array.fill[Int](numColsInGroup)(0) 290 | columnSetIndices(groupIndex) = 291 | Array.fill[ArrayBuffer[Int]](numColsInGroup)(new ArrayBuffer[Int]) 292 | columnSetValues(groupIndex) = 293 | Array.fill[ArrayBuffer[Double]](numColsInGroup)(new ArrayBuffer[Double]) 294 | groupIndex += 1 295 | } 296 | iterator.foreach { 297 | case (dv: DenseVector, rowIndex: Long) => 298 | var groupIndex = 0 299 | while (groupIndex < numGroups) { 300 | val fromCol = groupStartColumns(groupIndex) 301 | val numColsInGroup = groupStartColumns(groupIndex + 1) - groupStartColumns(groupIndex) 302 | var colIdx = 0 303 | while (colIdx < numColsInGroup) { 304 | columnSetSizes(groupIndex)(colIdx) += 1 305 | columnSetIndices(groupIndex)(colIdx) += rowIndex.toInt 306 | columnSetValues(groupIndex)(colIdx) += dv(fromCol + colIdx) 307 | colIdx += 1 308 | } 309 | groupIndex += 1 310 | } 311 | case (sv: SparseVector, rowIndex: Long) => 312 | /* 313 | A sparse vector is chopped into groups (destination partitions). 314 | We iterate through the non-zeros (indexed by k), going to the next group sv.indices(k) 315 | passes the current group's boundary. 
316 | */ 317 | var groupIndex = 0 318 | var k = 0 // index into SparseVector non-zeros 319 | val nnz = sv.indices.length 320 | while (groupIndex < numGroups && k < nnz) { 321 | val fromColumn = groupStartColumns(groupIndex) 322 | val groupEndColumn = groupStartColumns(groupIndex + 1) 323 | while (k < nnz && sv.indices(k) < groupEndColumn) { 324 | val columnIndex = sv.indices(k) // index in full row 325 | val colIdx = columnIndex - fromColumn // index in group of columns 326 | columnSetSizes(groupIndex)(colIdx) += 1 327 | columnSetIndices(groupIndex)(colIdx) += rowIndex.toInt 328 | columnSetValues(groupIndex)(colIdx) += sv.values(k) 329 | k += 1 330 | } 331 | groupIndex += 1 332 | } 333 | } 334 | Range(0, numGroups).map { groupIndex => 335 | val numColsInGroup = groupStartColumns(groupIndex + 1) - groupStartColumns(groupIndex) 336 | val groupPartialColumns: Array[SparseVector] = Range(0, numColsInGroup).map { colIdx => 337 | new SparseVector(columnSetSizes(groupIndex)(colIdx), 338 | columnSetIndices(groupIndex)(colIdx).toArray, 339 | columnSetValues(groupIndex)(colIdx).toArray) 340 | }.toArray 341 | (groupIndex, (sourcePartitionIndex, groupPartialColumns)) 342 | }.toIterator 343 | } 344 | 345 | // Shuffle data 346 | val groupedPartialColumns: RDD[(Int, Iterable[(Int, Array[SparseVector])])] = 347 | partialColumns.groupByKey() 348 | 349 | // Each target partition now holds its set of columns. 350 | // Group the partial columns into full columns. 351 | val fullColumns = groupedPartialColumns.flatMap { case (groupIndex, iterable) => 352 | val numColsInGroup = groupStartColumns(groupIndex + 1) - groupStartColumns(groupIndex) 353 | 354 | // We do not know the number of rows or non-zeros per group, so we need to collect the groups 355 | // before filling the full columns. 
356 | // collectedPartCols(sourcePartitionIndex)(colIdx) = partial column 357 | val collectedPartCols = new Array[Array[SparseVector]](numSourcePartitions) 358 | // nzCounts(colIdx)(sourcePartitionIndex) = number of non-zeros 359 | val nzCounts = Array.fill[Array[Int]](numColsInGroup)(Array.fill[Int](numSourcePartitions)(0)) 360 | val iter = iterable.iterator 361 | while (iter.hasNext) { 362 | val (sourcePartitionIndex, partCols) = iter.next() 363 | collectedPartCols(sourcePartitionIndex) = partCols 364 | var colIdx = 0 365 | while (colIdx < partCols.length) { 366 | val partCol = partCols(colIdx) 367 | nzCounts(colIdx)(sourcePartitionIndex) += partCol.indices.length 368 | colIdx += 1 369 | } 370 | } 371 | // nzOffsets(colIdx)(sourcePartitionIndex) = cumulative number of non-zeros 372 | val nzOffsets: Array[Array[Int]] = nzCounts.map(_.scanLeft(0)(_ + _)) 373 | 374 | // Initialize full columns 375 | val columnNZIndices: Array[Array[Int]] = 376 | nzOffsets.map(colNZOffsets => new Array[Int](colNZOffsets.last)) 377 | val columnNZValues: Array[Array[Double]] = 378 | nzOffsets.map(colNZOffsets => new Array[Double](colNZOffsets.last)) 379 | 380 | // Fill columns 381 | var colIdx = 0 // index within group 382 | while (colIdx < numColsInGroup) { 383 | var sourcePartitionIndex = 0 384 | while (sourcePartitionIndex < numSourcePartitions) { 385 | val nzStartOffset = nzOffsets(colIdx)(sourcePartitionIndex) 386 | val partColLength = nzOffsets(colIdx)(sourcePartitionIndex + 1) - nzStartOffset 387 | Array.copy(collectedPartCols(sourcePartitionIndex)(colIdx).indices, 0, 388 | columnNZIndices(colIdx), nzStartOffset, partColLength) 389 | Array.copy(collectedPartCols(sourcePartitionIndex)(colIdx).values, 0, 390 | columnNZValues(colIdx), nzStartOffset, partColLength) 391 | sourcePartitionIndex += 1 392 | } 393 | colIdx += 1 394 | } 395 | val columns = columnNZIndices.zip(columnNZValues).map { case (indices, values) => 396 | Vectors.sparse(numRows, indices, values) 397 | } 398 | val fromColumn = groupStartColumns(groupIndex) 399 | val columnIndices = Range(0, numColsInGroup).map(_ + fromColumn) 400 | columnIndices.zip(columns) 401 | } 402 | 403 | fullColumns 404 | } 405 | } 406 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/tree/impurities.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.tree 19 | 20 | import org.apache.spark.mllib.tree.impurity.{EntropyCalculator, GiniCalculator, ImpurityCalculator, 21 | VarianceCalculator} 22 | 23 | /** 24 | * Version of impurity aggregator which owns its data and is only for 1 node. 
25 | */ 26 | private[tree] abstract class ImpurityAggregatorSingle(val stats: Array[Double]) 27 | extends Serializable { 28 | 29 | def statsSize: Int = stats.length 30 | 31 | /** 32 | * Add two aggregators: this + other 33 | * @return This aggregator (modified). 34 | */ 35 | def add(other: ImpurityAggregatorSingle): this.type = { 36 | var i = 0 37 | while (i < statsSize) { 38 | stats(i) += other.stats(i) 39 | i += 1 40 | } 41 | this 42 | } 43 | 44 | /** 45 | * Subtract another aggregators from this one: this - other 46 | * @return This aggregator (modified). 47 | */ 48 | def subtract(other: ImpurityAggregatorSingle): this.type = { 49 | var i = 0 50 | while (i < statsSize) { 51 | stats(i) -= other.stats(i) 52 | i += 1 53 | } 54 | this 55 | } 56 | 57 | /** 58 | * Update stats with the given label and instance weight. 59 | * @return This aggregator (modified). 60 | */ 61 | def update(label: Double, instanceWeight: Double): this.type 62 | 63 | /** 64 | * Update stats with the given label. 65 | * @return This aggregator (modified). 66 | */ 67 | def update(label: Double): this.type = update(label, 1.0) 68 | 69 | /** Get an [[ImpurityCalculator]] for the current stats. */ 70 | def getCalculator: ImpurityCalculator 71 | 72 | def deepCopy(): ImpurityAggregatorSingle 73 | 74 | /** Total (weighted) count of instances in this aggregator */ 75 | def getCount: Double 76 | 77 | /** Resets this aggregator as though nothing has been added to it. */ 78 | def clear(): this.type = { 79 | var i = 0 80 | while (i < statsSize) { 81 | stats(i) = 0.0 82 | i += 1 83 | } 84 | this 85 | } 86 | } 87 | 88 | /** 89 | * Version of Entropy aggregator which owns its data and is only for one node. 90 | */ 91 | private[tree] class EntropyAggregatorSingle private (stats: Array[Double]) 92 | extends ImpurityAggregatorSingle(stats) with Serializable { 93 | 94 | def this(numClasses: Int) = this(new Array[Double](numClasses)) 95 | 96 | def update(label: Double, instanceWeight: Double): this.type = { 97 | if (label >= statsSize) { 98 | throw new IllegalArgumentException(s"EntropyAggregatorSingle given label $label" + 99 | s" but requires label < numClasses (= $statsSize).") 100 | } 101 | stats(label.toInt) += instanceWeight 102 | this 103 | } 104 | 105 | def getCalculator: EntropyCalculator = new EntropyCalculator(stats) 106 | 107 | override def deepCopy(): ImpurityAggregatorSingle = new EntropyAggregatorSingle(stats.clone()) 108 | 109 | override def getCount: Double = stats.sum 110 | } 111 | 112 | /** 113 | * Version of Gini aggregator which owns its data and is only for one node. 114 | */ 115 | private[tree] class GiniAggregatorSingle private (stats: Array[Double]) 116 | extends ImpurityAggregatorSingle(stats) with Serializable { 117 | 118 | def this(numClasses: Int) = this(new Array[Double](numClasses)) 119 | 120 | def update(label: Double, instanceWeight: Double): this.type = { 121 | if (label >= statsSize) { 122 | throw new IllegalArgumentException(s"GiniAggregatorSingle given label $label" + 123 | s" but requires label < numClasses (= $statsSize).") 124 | } 125 | stats(label.toInt) += instanceWeight 126 | this 127 | } 128 | 129 | def getCalculator: GiniCalculator = new GiniCalculator(stats) 130 | 131 | override def deepCopy(): ImpurityAggregatorSingle = new GiniAggregatorSingle(stats.clone()) 132 | 133 | override def getCount: Double = stats.sum 134 | } 135 | 136 | /** 137 | * Version of Variance aggregator which owns its data and is only for one node. 
138 | */ 139 | private[tree] class VarianceAggregatorSingle 140 | extends ImpurityAggregatorSingle(new Array[Double](3)) with Serializable { 141 | 142 | def update(label: Double, instanceWeight: Double): this.type = { 143 | stats(0) += instanceWeight 144 | stats(1) += instanceWeight * label 145 | stats(2) += instanceWeight * label * label 146 | this 147 | } 148 | 149 | def getCalculator: VarianceCalculator = new VarianceCalculator(stats) 150 | 151 | override def deepCopy(): ImpurityAggregatorSingle = { 152 | val tmp = new VarianceAggregatorSingle() 153 | stats.copyToArray(tmp.stats) 154 | tmp 155 | } 156 | 157 | override def getCount: Double = stats(0) 158 | } 159 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/tree/ygg/Node.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.tree.ygg 19 | 20 | import org.apache.spark.annotation.DeveloperApi 21 | import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node => SparkNode} 22 | import org.apache.spark.mllib.linalg.Vector 23 | import org.apache.spark.mllib.tree.impurity.ImpurityCalculator 24 | import org.apache.spark.mllib.tree.model.{ImpurityStats, InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} 25 | 26 | /** 27 | * :: DeveloperApi :: 28 | * Decision tree node interface. 29 | */ 30 | @DeveloperApi 31 | sealed abstract class Node extends Serializable { 32 | 33 | // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree 34 | // code into the new API and deprecate the old API. SPARK-3727 35 | 36 | /** Prediction a leaf node makes, or which an internal node would make if it were a leaf node */ 37 | def prediction: Double 38 | 39 | /** Impurity measure at this node (for training data) */ 40 | def impurity: Double 41 | 42 | /** 43 | * Statistics aggregated from training data at this node, used to compute prediction, impurity, 44 | * and probabilities. 45 | * For classification, the array of class counts must be normalized to a probability distribution. 46 | */ 47 | private[ml] def impurityStats: ImpurityCalculator 48 | 49 | /** Recursive prediction helper method */ 50 | private[ml] def predictImpl(features: Vector): LeafNode 51 | 52 | /** 53 | * Get the number of nodes in tree below this node, including leaf nodes. 54 | * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. 55 | */ 56 | private[tree] def numDescendants: Int 57 | 58 | /** 59 | * Recursive print function. 60 | * @param indentFactor The number of spaces to add to each level of indentation. 
61 | */ 62 | private[tree] def subtreeToString(indentFactor: Int = 0): String 63 | 64 | /** 65 | * Get depth of tree from this node. 66 | * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes. 67 | */ 68 | private[tree] def subtreeDepth: Int 69 | 70 | /** 71 | * Create a copy of this node in the old Node format, recursively creating child nodes as needed. 72 | * @param id Node ID using old format IDs 73 | */ 74 | private[ml] def toOld(id: Int): OldNode 75 | 76 | /** 77 | * Trace down the tree, and return the largest feature index used in any split. 78 | * @return Max feature index used in a split, or -1 if there are no splits (single leaf node). 79 | */ 80 | private[ml] def maxSplitFeatureIndex(): Int 81 | } 82 | 83 | private[ml] object Node { 84 | 85 | /** 86 | * Create a new Node from the old Node format, recursively creating child nodes as needed. 87 | */ 88 | def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): SparkNode = { 89 | if (oldNode.isLeaf) { 90 | // TODO: Once the implementation has been moved to this API, then include sufficient 91 | // statistics here. 92 | new LeafNode(prediction = oldNode.predict.predict, 93 | impurity = oldNode.impurity, impurityStats = null) 94 | } else { 95 | val gain = if (oldNode.stats.nonEmpty) { 96 | oldNode.stats.get.gain 97 | } else { 98 | 0.0 99 | } 100 | new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, 101 | gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), 102 | rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), 103 | split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) 104 | } 105 | } 106 | } 107 | 108 | /** 109 | * Version of a node used in learning. This uses vars so that we can modify nodes as we split the 110 | * tree by adding children, etc. 111 | * 112 | * For now, we use node IDs. These will be kept internal since we hope to remove node IDs 113 | * in the future, or at least change the indexing (so that we can support much deeper trees). 114 | * 115 | * This node can either be: 116 | * - a leaf node, with leftChild, rightChild, split set to null, or 117 | * - an internal node, with all values set 118 | * 119 | * @param id We currently use the same indexing as the old implementation in 120 | * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. 121 | * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, 122 | * so that we do not need to consider splitting it further. 123 | * @param stats Impurity statistics for this node. 124 | */ 125 | private[tree] class LearningNode( 126 | var id: Int, 127 | var leftChild: Option[LearningNode], 128 | var rightChild: Option[LearningNode], 129 | var split: Option[Split], 130 | var isLeaf: Boolean, 131 | var stats: ImpurityStats) extends Serializable { 132 | 133 | /** 134 | * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. 135 | */ 136 | def toSparkNode: SparkNode = { 137 | if (leftChild.nonEmpty) { 138 | assert(rightChild.nonEmpty && split.nonEmpty && stats != null, 139 | "Unknown error during Decision Tree learning. 
Could not convert LearningNode to Node.") 140 | new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, 141 | leftChild.get.toSparkNode, rightChild.get.toSparkNode, split.get.toSparkSplit, stats.impurityCalculator) 142 | } else { 143 | if (stats.valid) { 144 | new LeafNode(stats.impurityCalculator.predict, stats.impurity, 145 | stats.impurityCalculator) 146 | } else { 147 | // Here we want to keep same behavior with the old mllib.DecisionTreeModel 148 | new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) 149 | } 150 | } 151 | } 152 | 153 | /** 154 | * Get the node index corresponding to this data point. 155 | * This function mimics prediction, passing an example from the root node down to a leaf 156 | * or unsplit node; that node's index is returned. 157 | * 158 | * @param binnedFeatures Binned feature vector for data point. 159 | * @param splits possible splits for all features, indexed (numFeatures)(numSplits) 160 | * @return Leaf index if the data point reaches a leaf. 161 | * Otherwise, last node reachable in tree matching this example. 162 | * Note: This is the global node index, i.e., the index used in the tree. 163 | * This index is different from the index used during training a particular 164 | * group of nodes on one call to 165 | * [[org.apache.spark.ml.tree.impl.RandomForest.findBestSplits()]]. 166 | */ 167 | def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = { 168 | if (this.isLeaf || this.split.isEmpty) { 169 | this.id 170 | } else { 171 | val split = this.split.get 172 | val featureIndex = split.featureIndex 173 | val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex)) 174 | if (this.leftChild.isEmpty) { 175 | // Not yet split. Return next layer of nodes to train 176 | if (splitLeft) { 177 | LearningNode.leftChildIndex(this.id) 178 | } else { 179 | LearningNode.rightChildIndex(this.id) 180 | } 181 | } else { 182 | if (splitLeft) { 183 | this.leftChild.get.predictImpl(binnedFeatures, splits) 184 | } else { 185 | this.rightChild.get.predictImpl(binnedFeatures, splits) 186 | } 187 | } 188 | } 189 | } 190 | } 191 | 192 | private[tree] object LearningNode { 193 | 194 | /** Create a node with some of its fields set. */ 195 | def apply( 196 | id: Int, 197 | isLeaf: Boolean, 198 | stats: ImpurityStats): LearningNode = { 199 | new LearningNode(id, None, None, None, isLeaf, stats) 200 | } 201 | 202 | /** Create an empty node with the given node index. Values must be set later on. */ 203 | def emptyNode(id: Int): LearningNode = { 204 | new LearningNode(id, None, None, None, false, null) 205 | } 206 | 207 | // The below indexing methods were copied from spark.mllib.tree.model.Node 208 | 209 | /** 210 | * Return the index of the left child of this node. 211 | */ 212 | def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1 213 | 214 | /** 215 | * Return the index of the right child of this node. 216 | */ 217 | def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1 218 | 219 | /** 220 | * Get the parent index of the given node, or 0 if it is the root. 221 | */ 222 | def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1 223 | 224 | /** 225 | * Return the level of a tree which the given node is in. 
226 | */ 227 | def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) { 228 | throw new IllegalArgumentException(s"0 is not a valid node index.") 229 | } else { 230 | java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex)) 231 | } 232 | 233 | /** 234 | * Returns true if this is a left child. 235 | * Note: Returns false for the root. 236 | */ 237 | def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0 238 | 239 | /** 240 | * Return the maximum number of nodes which can be in the given level of the tree. 241 | * @param level Level of tree (0 = root). 242 | */ 243 | def maxNodesInLevel(level: Int): Int = 1 << level 244 | 245 | /** 246 | * Return the index of the first node in the given level. 247 | * @param level Level of tree (0 = root). 248 | */ 249 | def startIndexInLevel(level: Int): Int = 1 << level 250 | 251 | /** 252 | * Traces down from a root node to get the node with the given node index. 253 | * This assumes the node exists. 254 | */ 255 | def getNode(nodeIndex: Int, rootNode: LearningNode): LearningNode = { 256 | var tmpNode: LearningNode = rootNode 257 | var levelsToGo = indexToLevel(nodeIndex) 258 | while (levelsToGo > 0) { 259 | if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { 260 | tmpNode = tmpNode.leftChild.get 261 | } else { 262 | tmpNode = tmpNode.rightChild.get 263 | } 264 | levelsToGo -= 1 265 | } 266 | tmpNode 267 | } 268 | 269 | } 270 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/tree/ygg/Split.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.tree.ygg 19 | 20 | import org.apache.spark.annotation.DeveloperApi 21 | import org.apache.spark.ml.tree.{CategoricalSplit => SparkCategoricalSplit, ContinuousSplit => SparkContinuousSplit, Split => SparkSplit} 22 | import org.apache.spark.mllib.linalg.Vector 23 | import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType} 24 | import org.apache.spark.mllib.tree.model.{Split => OldSplit} 25 | 26 | 27 | /** 28 | * :: DeveloperApi :: 29 | * Interface for a "Split," which specifies a test made at a decision tree node 30 | * to choose the left or right path. 31 | */ 32 | @DeveloperApi 33 | sealed trait Split extends Serializable { 34 | /** convert to org.apache.ml.tree.Split **/ 35 | def toSparkSplit: SparkSplit 36 | 37 | /** Index of feature which this split tests */ 38 | def featureIndex: Int 39 | 40 | /** 41 | * Return true (split to left) or false (split to right).
42 | * @param features Vector of features (original values, not binned). 43 | */ 44 | private[ml] def shouldGoLeft(features: Vector): Boolean 45 | 46 | /** 47 | * Return true (split to left) or false (split to right). 48 | * @param binnedFeature Binned feature value. 49 | * @param splits All splits for the given feature. 50 | */ 51 | private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean 52 | 53 | /** 54 | * Return true (split to left) or false (split to right). 55 | * @param feature Feature value (original value, not binned) 56 | */ 57 | private[tree] def shouldGoLeft(feature: Double): Boolean 58 | 59 | /** Convert to old Split format */ 60 | private[tree] def toOld: OldSplit 61 | } 62 | 63 | private[tree] object Split { 64 | 65 | def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): SparkSplit = { 66 | oldSplit.featureType match { 67 | case OldFeatureType.Categorical => 68 | new SparkCategoricalSplit(featureIndex = oldSplit.feature, 69 | _leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature)) 70 | case OldFeatureType.Continuous => 71 | new SparkContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold) 72 | } 73 | } 74 | } 75 | 76 | /** 77 | * :: DeveloperApi :: 78 | * Split which tests a categorical feature. 79 | * @param featureIndex Index of the feature to test 80 | * @param _leftCategories If the feature value is in this set of categories, then the split goes 81 | * left. Otherwise, it goes right. 82 | * @param numCategories Number of categories for this feature. 83 | */ 84 | @DeveloperApi 85 | final class CategoricalSplit private[ml] ( 86 | override val featureIndex: Int, 87 | _leftCategories: Array[Double], 88 | private val numCategories: Int) 89 | extends Split { 90 | 91 | require(_leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" + 92 | s" (should be in range [0, $numCategories)): ${_leftCategories.mkString(",")}") 93 | 94 | /** 95 | * If true, then "categories" is the set of categories for splitting to the left, and vice versa. 96 | */ 97 | private val isLeft: Boolean = _leftCategories.length <= numCategories / 2 98 | 99 | /** Set of categories determining the splitting rule, along with [[isLeft]]. 
*/ 100 | private val categories: Set[Double] = { 101 | if (isLeft) { 102 | _leftCategories.toSet 103 | } else { 104 | setComplement(_leftCategories.toSet) 105 | } 106 | } 107 | 108 | override private[ml] def shouldGoLeft(features: Vector): Boolean = { 109 | if (isLeft) { 110 | categories.contains(features(featureIndex)) 111 | } else { 112 | !categories.contains(features(featureIndex)) 113 | } 114 | } 115 | 116 | override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = { 117 | if (isLeft) { 118 | categories.contains(binnedFeature.toDouble) 119 | } else { 120 | !categories.contains(binnedFeature.toDouble) 121 | } 122 | } 123 | 124 | override private[tree] def shouldGoLeft(feature: Double): Boolean = { 125 | if (isLeft) { 126 | categories.contains(feature) 127 | } else { 128 | !categories.contains(feature) 129 | } 130 | } 131 | 132 | override def equals(o: Any): Boolean = { 133 | o match { 134 | case other: CategoricalSplit => featureIndex == other.featureIndex && 135 | isLeft == other.isLeft && categories == other.categories 136 | case _ => false 137 | } 138 | } 139 | 140 | override private[tree] def toOld: OldSplit = { 141 | val oldCats = if (isLeft) { 142 | categories 143 | } else { 144 | setComplement(categories) 145 | } 146 | OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList) 147 | } 148 | 149 | /** Get sorted categories which split to the left */ 150 | def leftCategories: Array[Double] = { 151 | val cats = if (isLeft) categories else setComplement(categories) 152 | cats.toArray.sorted 153 | } 154 | 155 | /** Get sorted categories which split to the right */ 156 | def rightCategories: Array[Double] = { 157 | val cats = if (isLeft) setComplement(categories) else categories 158 | cats.toArray.sorted 159 | } 160 | 161 | /** [0, numCategories) \ cats */ 162 | private def setComplement(cats: Set[Double]): Set[Double] = { 163 | Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet 164 | } 165 | 166 | /** convert to org.apache.ml.tree.Split **/ 167 | override def toSparkSplit: SparkSplit = { 168 | new SparkCategoricalSplit(featureIndex, _leftCategories, numCategories) 169 | } 170 | } 171 | 172 | /** 173 | * :: DeveloperApi :: 174 | * Split which tests a continuous feature. 175 | * @param featureIndex Index of the feature to test 176 | * @param threshold If the feature value is <= this threshold, then the split goes left. 177 | * Otherwise, it goes right. 
178 | */ 179 | @DeveloperApi 180 | final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) 181 | extends Split { 182 | 183 | override private[ml] def shouldGoLeft(features: Vector): Boolean = { 184 | features(featureIndex) <= threshold 185 | } 186 | 187 | override private[tree] def shouldGoLeft(binnedFeature: Int, splits: Array[Split]): Boolean = { 188 | if (binnedFeature == splits.length) { 189 | // > last split, so split right 190 | false 191 | } else { 192 | val featureValueUpperBound = splits(binnedFeature).asInstanceOf[ContinuousSplit].threshold 193 | featureValueUpperBound <= threshold 194 | } 195 | } 196 | 197 | override private[tree] def shouldGoLeft(feature: Double): Boolean = { 198 | feature <= threshold 199 | } 200 | 201 | override def equals(o: Any): Boolean = { 202 | o match { 203 | case other: ContinuousSplit => 204 | featureIndex == other.featureIndex && threshold == other.threshold 205 | case _ => 206 | false 207 | } 208 | } 209 | 210 | override private[tree] def toOld: OldSplit = { 211 | OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double]) 212 | } 213 | 214 | /** convert to org.apache.ml.tree.Split **/ 215 | override def toSparkSplit: SparkSplit = { 216 | new SparkContinuousSplit(featureIndex, threshold) 217 | } 218 | } 219 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/tree/impl/YggdrasilSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.tree.impl 19 | 20 | import org.apache.spark.ml.tree.{ygg, _} 21 | import org.apache.spark.ml.tree.impl.Yggdrasil.{FeatureVector, PartitionInfo, YggdrasilMetadata} 22 | import org.apache.spark.mllib.linalg.Vectors 23 | import org.apache.spark.mllib.regression.LabeledPoint 24 | import org.apache.spark.mllib.tree.impurity._ 25 | import org.apache.spark.mllib.tree.model.ImpurityStats 26 | import org.apache.spark.mllib.util.{MLlibTestSparkContext, SparkFunSuite} 27 | import org.apache.spark.util.collection.BitSet 28 | 29 | import scala.util.Random 30 | 31 | /** 32 | * Test suite for [[Yggdrasil]]. 
33 | */ 34 | class YggdrasilSuite extends SparkFunSuite with MLlibTestSparkContext { 35 | 36 | /* * * * * * * * * * * Integration tests * * * * * * * * * * */ 37 | 38 | test("run deep example") { 39 | val data = Range(0, 3).map(x => LabeledPoint(math.pow(x, 3), Vectors.dense(x))) 40 | val df = sqlContext.createDataFrame(data) 41 | val dt = new YggdrasilRegressor() 42 | .setFeaturesCol("features") // indexedFeatures 43 | .setLabelCol("label") 44 | .setMaxDepth(10) 45 | val model = dt.fit(df) 46 | assert(model.rootNode.isInstanceOf[InternalNode]) 47 | val root = model.rootNode.asInstanceOf[InternalNode] 48 | assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[LeafNode]) 49 | val left = root.leftChild.asInstanceOf[InternalNode] 50 | assert(left.leftChild.isInstanceOf[LeafNode] && left.rightChild.isInstanceOf[LeafNode]) 51 | } 52 | 53 | test("run example") { 54 | val data = Range(0, 8).map(x => LabeledPoint(x, Vectors.dense(x))) 55 | val df = sqlContext.createDataFrame(data) 56 | val dt = new YggdrasilRegressor() 57 | .setFeaturesCol("features") 58 | .setLabelCol("label") 59 | .setMaxDepth(10) 60 | val model = dt.fit(df) 61 | assert(model.rootNode.isInstanceOf[InternalNode]) 62 | val root = model.rootNode.asInstanceOf[InternalNode] 63 | assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[InternalNode]) 64 | val left = root.leftChild.asInstanceOf[InternalNode] 65 | val right = root.rightChild.asInstanceOf[InternalNode] 66 | val grandkids = Array(left.leftChild, left.rightChild, right.leftChild, right.rightChild) 67 | assert(grandkids.forall(_.isInstanceOf[InternalNode])) 68 | } 69 | 70 | test("example with imbalanced tree") { 71 | val data = Seq( 72 | (0.0, Vectors.dense(0.0, 0.0)), 73 | (0.0, Vectors.dense(0.0, 0.0)), 74 | (1.0, Vectors.dense(0.0, 1.0)), 75 | (0.0, Vectors.dense(0.0, 1.0)), 76 | (1.0, Vectors.dense(1.0, 0.0)), 77 | (1.0, Vectors.dense(1.0, 0.0)), 78 | (1.0, Vectors.dense(1.0, 1.0)), 79 | (1.0, Vectors.dense(1.0, 1.0)) 80 | ).map { case (l, p) => LabeledPoint(l, p) } 81 | val df = sqlContext.createDataFrame(data) 82 | val dt = new YggdrasilRegressor() 83 | .setFeaturesCol("features") 84 | .setLabelCol("label") 85 | .setMaxDepth(5) 86 | val model = dt.fit(df) 87 | assert(model.depth === 2) 88 | assert(model.numNodes === 5) 89 | } 90 | 91 | test("example providing transposed dataset") { 92 | val data = sc.parallelize(Range(0, 8).map(x => LabeledPoint(x, Vectors.dense(x)))) 93 | val transposedDataset = YggdrasilUtil.rowToColumnStoreDense(data.map(_.features)) 94 | val dt = new YggdrasilRegressor() 95 | .setFeaturesCol("features") 96 | .setLabelCol("label") 97 | .setMaxDepth(10) 98 | val model = dt.train(data, transposedDataset, Map.empty[Int, Int]) 99 | assert(model.rootNode.isInstanceOf[InternalNode]) 100 | val root = model.rootNode.asInstanceOf[InternalNode] 101 | assert(root.leftChild.isInstanceOf[InternalNode] && root.rightChild.isInstanceOf[InternalNode]) 102 | val left = root.leftChild.asInstanceOf[InternalNode] 103 | val right = root.rightChild.asInstanceOf[InternalNode] 104 | val grandkids = Array(left.leftChild, left.rightChild, right.leftChild, right.rightChild) 105 | assert(grandkids.forall(_.isInstanceOf[InternalNode])) 106 | } 107 | 108 | /* * * * * * * * * * * Helper classes * * * * * * * * * * */ 109 | 110 | test("FeatureVector") { 111 | val v = new FeatureVector(1, 0, Array(0.1, 0.3, 0.7), Array(1, 2, 0)) 112 | 113 | val vCopy = v.deepCopy() 114 | vCopy.values(0) = 1000 115 | assert(v.values(0) !==
vCopy.values(0)) 116 | 117 | val original = Array(0.7, 0.1, 0.3) 118 | val v2 = FeatureVector.fromOriginal(1, 0, original) 119 | assert(v === v2) 120 | } 121 | 122 | test("FeatureVectorSortByValue") { 123 | val values = Array(0.1, 0.2, 0.4, 0.6, 0.7, 0.9, 1.5, 1.55) 124 | val col = Random.shuffle(values.toIterator).toArray 125 | val unsortedIndices = col.indices 126 | val sortedIndices = unsortedIndices.sortBy(x => col(x)).toArray 127 | val featureIndex = 3 128 | val featureArity = 0 129 | val fvSorted = 130 | FeatureVector.fromOriginal(featureIndex, featureArity, col) 131 | assert(fvSorted.featureIndex === featureIndex) 132 | assert(fvSorted.featureArity === featureArity) 133 | assert(fvSorted.values.deep === values.deep) 134 | assert(fvSorted.indices.deep === sortedIndices.deep) 135 | } 136 | 137 | test("PartitionInfo") { 138 | val numRows = 4 139 | val col1 = 140 | FeatureVector.fromOriginal(0, 0, Array(0.8, 0.2, 0.1, 0.6)) 141 | val col2 = 142 | FeatureVector.fromOriginal(1, 3, Array(0, 1, 0, 2)) 143 | val labels = Array(0, 0, 0, 1, 1, 1, 1).map(_.toByte) 144 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, Entropy, Map(1 -> 3)) 145 | val fullImpurityAgg = metadata.createImpurityAggregator() 146 | labels.foreach(label => fullImpurityAgg.update(label)) 147 | 148 | assert(col1.values.length === numRows) 149 | assert(col2.values.length === numRows) 150 | 151 | val nodeOffsets = Array(0, numRows) 152 | val activeNodes = new BitSet(1) 153 | activeNodes.set(0) 154 | 155 | val info = PartitionInfo(Array(col1, col2), nodeOffsets, activeNodes, Array(fullImpurityAgg)) 156 | 157 | // Create bitVector for splitting the 4 rows: L, R, L, R 158 | // New groups are {0, 2}, {1, 3} 159 | val bitVector = new BitSet(4) 160 | bitVector.set(1) 161 | bitVector.set(3) 162 | 163 | // for these tests, use the activeNodes for nodeSplitBitVector 164 | val newInfo = info.update(bitVector, newNumNodeOffsets = 3, labels, metadata) 165 | 166 | assert(newInfo.columns.length === 2) 167 | val expectedCol1a = 168 | new FeatureVector(0, 0, Array(0.1, 0.8, 0.2, 0.6), Array(2, 0, 1, 3)) 169 | val expectedCol1b = 170 | new FeatureVector(1, 3, Array(0, 0, 1, 2), Array(0, 2, 1, 3)) 171 | assert(newInfo.columns(0) === expectedCol1a) 172 | assert(newInfo.columns(1) === expectedCol1b) 173 | assert(newInfo.nodeOffsets === Array(0, 2, 4)) 174 | assert(newInfo.activeNodes.iterator.toSet === Set(0, 1)) 175 | 176 | // stats for the two child nodes should be correct 177 | // val fullImpurityStatsArray = 178 | // Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble) 179 | // val fullImpurity = Entropy.calculate(fullImpurityStatsArray, labels.length) 180 | // val stats = newInfo.fullImpurityAggs(0).getCalculator. 
181 | // assert(stats.gain === 0.0) 182 | // assert(stats.impurity === fullImpurity) 183 | // assert(stats.impurityCalculator.stats === fullImpurityStatsArray) 184 | 185 | // Create 2 bitVectors for splitting into: 0, 2, 1, 3 186 | val bitVector2 = new BitSet(4) 187 | bitVector2.set(2) // 2 goes to the right 188 | bitVector2.set(3) // 3 goes to the right 189 | 190 | val newInfo2 = newInfo.update(bitVector2, newNumNodeOffsets = 5, labels, metadata) 191 | 192 | assert(newInfo2.columns.length === 2) 193 | val expectedCol2a = 194 | new FeatureVector(0, 0, Array(0.8, 0.1, 0.2, 0.6), Array(0, 2, 1, 3)) 195 | val expectedCol2b = 196 | new FeatureVector(1, 3, Array(0, 0, 1, 2), Array(0, 2, 1, 3)) 197 | assert(newInfo2.columns(0) === expectedCol2a) 198 | assert(newInfo2.columns(1) === expectedCol2b) 199 | assert(newInfo2.nodeOffsets === Array(0, 1, 2, 3, 4)) 200 | assert(newInfo2.activeNodes.iterator.toSet === Set(0, 1, 2, 3)) 201 | } 202 | 203 | /* * * * * * * * * * * Choosing Splits * * * * * * * * * * */ 204 | 205 | test("computeBestSplits") { 206 | // TODO 207 | } 208 | 209 | test("chooseSplit: choose correct type of split") { 210 | val labels = Array(0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0) 211 | val labelsAsBytes = labels.map(_.toByte) 212 | val fromOffset = 1 213 | val toOffset = 4 214 | val impurity = Entropy 215 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity, Map(1 -> 3)) 216 | val fullImpurityAgg = metadata.createImpurityAggregator() 217 | labels.foreach(label => fullImpurityAgg.update(label)) 218 | 219 | val col1 = FeatureVector.fromOriginal(featureIndex = 0, featureArity = 0, 220 | values = Array(0.8, 0.1, 0.1, 0.2, 0.3, 0.5, 0.6)) 221 | val (split1, _) = YggdrasilClassification.chooseSplit(col1, labelsAsBytes, fromOffset, toOffset, fullImpurityAgg, metadata) 222 | assert(split1.nonEmpty && split1.get.isInstanceOf[ygg.ContinuousSplit]) 223 | 224 | val col2 = FeatureVector.fromOriginal(featureIndex = 1, featureArity = 3, 225 | values = Array(0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0)) 226 | val (split2, _) = YggdrasilRegression.chooseSplit(col2, labels, fromOffset, toOffset, fullImpurityAgg, metadata) 227 | assert(split2.nonEmpty && split2.get.isInstanceOf[ygg.CategoricalSplit]) 228 | } 229 | 230 | test("chooseOrderedCategoricalSplit: basic case") { 231 | val featureIndex = 0 232 | val values = Array(0, 0, 1, 2, 2, 2, 2).map(_.toDouble) 233 | val featureArity = values.max.toInt + 1 234 | 235 | def testHelper( 236 | labels: Array[Byte], 237 | expectedLeftCategories: Array[Double], 238 | expectedLeftStats: Array[Double], 239 | expectedRightStats: Array[Double]): Unit = { 240 | val expectedRightCategories = Range(0, featureArity) 241 | .filter(c => !expectedLeftCategories.contains(c)).map(_.toDouble).toArray 242 | val impurity = Entropy 243 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, 244 | impurity, Map.empty[Int, Int]) 245 | val (split, stats) = 246 | YggdrasilClassification.chooseOrderedCategoricalSplit(featureIndex, values, values.indices.toArray, 247 | labels, 0, values.length, metadata, featureArity) 248 | split match { 249 | case Some(s: ygg.CategoricalSplit) => 250 | assert(s.featureIndex === featureIndex) 251 | assert(s.leftCategories === expectedLeftCategories) 252 | assert(s.rightCategories === expectedRightCategories) 253 | case _ => 254 | throw new AssertionError( 255 | s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") 256 | } 257 | val fullImpurityStatsArray = 258 | 
Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble) 259 | val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length) 260 | assert(stats.gain === fullImpurity) 261 | assert(stats.impurity === fullImpurity) 262 | assert(stats.impurityCalculator.stats === fullImpurityStatsArray) 263 | assert(stats.leftImpurityCalculator.stats === expectedLeftStats) 264 | assert(stats.rightImpurityCalculator.stats === expectedRightStats) 265 | assert(stats.valid) 266 | } 267 | 268 | val labels1 = Array(0, 0, 1, 1, 1, 1, 1).map(_.toByte) 269 | testHelper(labels1, Array(0.0), Array(2.0, 0.0), Array(0.0, 5.0)) 270 | 271 | val labels2 = Array(0, 0, 0, 1, 1, 1, 1).map(_.toByte) 272 | testHelper(labels2, Array(0.0, 1.0), Array(3.0, 0.0), Array(0.0, 4.0)) 273 | } 274 | 275 | test("chooseOrderedCategoricalSplit: return bad split if we should not split") { 276 | val featureIndex = 0 277 | val values = Array(0, 0, 1, 2, 2, 2, 2).map(_.toDouble) 278 | val featureArity = values.max.toInt + 1 279 | 280 | val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0) 281 | 282 | val impurity = Entropy 283 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity, 284 | Map(featureIndex -> featureArity)) 285 | val (split, stats) = 286 | YggdrasilRegression.chooseOrderedCategoricalSplit(featureIndex, values, values.indices.toArray, 287 | labels, 0, values.length, metadata, featureArity) 288 | assert(split.isEmpty) 289 | val fullImpurityStatsArray = 290 | Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble) 291 | val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length) 292 | assert(stats.gain === 0.0) 293 | assert(stats.impurity === fullImpurity) 294 | assert(stats.impurityCalculator.stats === fullImpurityStatsArray) 295 | assert(stats.valid) 296 | } 297 | 298 | test("chooseUnorderedCategoricalSplit: basic case") { 299 | val featureIndex = 0 300 | val featureArity = 4 301 | val values = Array(3.0, 1.0, 0.0, 2.0, 2.0) 302 | val labels = Array(0.0, 0.0, 1.0, 1.0, 2.0) 303 | val impurity = Entropy 304 | val metadata = new YggdrasilMetadata(numClasses = 3, maxBins = 16, minInfoGain = 0.0, impurity, 305 | Map(featureIndex -> featureArity)) 306 | val allSplits = metadata.getUnorderedSplits(featureIndex) 307 | val (split, _) = YggdrasilRegression.chooseUnorderedCategoricalSplit(featureIndex, values, values.indices.toArray, 308 | labels, 0, values.length, metadata, featureArity, allSplits) 309 | split match { 310 | case Some(s: ygg.CategoricalSplit) => 311 | assert(s.featureIndex === featureIndex) 312 | assert(s.leftCategories.toSet === Set(0.0, 2.0)) 313 | assert(s.rightCategories.toSet === Set(1.0, 3.0)) 314 | // TODO: test correctness of stats 315 | case _ => 316 | throw new AssertionError( 317 | s"Expected CategoricalSplit but got ${split.getClass.getSimpleName}") 318 | } 319 | } 320 | 321 | test("chooseUnorderedCategoricalSplit: return bad split if we should not split") { 322 | val featureIndex = 0 323 | val featureArity = 4 324 | val values = Array(3.0, 1.0, 0.0, 2.0, 2.0) 325 | val labels = Array(1.0, 1.0, 1.0, 1.0, 1.0).map(_.toByte) 326 | val impurity = Entropy 327 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity, 328 | Map(featureIndex -> featureArity)) 329 | val (split, stats) = 330 | YggdrasilClassification.chooseOrderedCategoricalSplit(featureIndex, values, values.indices.toArray, 331 | labels, 0, values.length, metadata, featureArity) 332 | assert(split.isEmpty) 333 | 
val fullImpurityStatsArray = 334 | Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble) 335 | val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length) 336 | assert(stats.gain === 0.0) 337 | assert(stats.impurity === fullImpurity) 338 | assert(stats.impurityCalculator.stats === fullImpurityStatsArray) 339 | assert(stats.valid) 340 | } 341 | 342 | test("chooseContinuousSplit: basic case") { 343 | val featureIndex = 0 344 | val values = Array(0.1, 0.2, 0.3, 0.4, 0.5) 345 | val labels = Array(0.0, 0.0, 1.0, 1.0, 1.0) 346 | val impurity = Entropy 347 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity, Map.empty[Int, Int]) 348 | val fullImpurityAgg = metadata.createImpurityAggregator() 349 | labels.foreach(label => fullImpurityAgg.update(label)) 350 | 351 | val (split, stats) = YggdrasilRegression.chooseContinuousSplit(featureIndex, values, 352 | values.indices.toArray, labels, 0, values.length, fullImpurityAgg, metadata) 353 | split match { 354 | case Some(s: ygg.ContinuousSplit) => 355 | assert(s.featureIndex === featureIndex) 356 | assert(s.threshold === 0.2) 357 | case _ => 358 | throw new AssertionError( 359 | s"Expected ContinuousSplit but got ${split.getClass.getSimpleName}") 360 | } 361 | val fullImpurityStatsArray = 362 | Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble) 363 | val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length) 364 | assert(stats.gain === fullImpurity) 365 | assert(stats.impurity === fullImpurity) 366 | assert(stats.impurityCalculator.stats === fullImpurityStatsArray) 367 | assert(stats.leftImpurityCalculator.stats === Array(2.0, 0.0)) 368 | assert(stats.rightImpurityCalculator.stats === Array(0.0, 3.0)) 369 | assert(stats.valid) 370 | } 371 | 372 | test("chooseContinuousSplit: return bad split if we should not split") { 373 | val featureIndex = 0 374 | val values = Array(0.1, 0.2, 0.3, 0.4, 0.5) 375 | val labels = Array(0.0, 0.0, 0.0, 0.0, 0.0).map(_.toByte) 376 | val impurity = Entropy 377 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, impurity, Map.empty[Int, Int]) 378 | val fullImpurityAgg = metadata.createImpurityAggregator() 379 | labels.foreach(label => fullImpurityAgg.update(label)) 380 | 381 | val (split, stats) = YggdrasilClassification.chooseContinuousSplit(featureIndex, values, values.indices.toArray, 382 | labels, 0, values.length, fullImpurityAgg, metadata) 383 | // split should be None 384 | assert(split.isEmpty) 385 | // stats for parent node should be correct 386 | val fullImpurityStatsArray = 387 | Array(labels.count(_ == 0.0).toDouble, labels.count(_ == 1.0).toDouble) 388 | val fullImpurity = impurity.calculate(fullImpurityStatsArray, labels.length) 389 | assert(stats.gain === 0.0) 390 | assert(stats.impurity === fullImpurity) 391 | assert(stats.impurityCalculator.stats === fullImpurityStatsArray) 392 | } 393 | 394 | /* * * * * * * * * * * Bit subvectors * * * * * * * * * * */ 395 | 396 | test("bitSubvectorFromSplit: 1 node") { 397 | val col = 398 | FeatureVector.fromOriginal(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7)) 399 | val fromOffset = 0 400 | val toOffset = col.values.length 401 | val numRows = toOffset 402 | val split = new ygg.ContinuousSplit(0, threshold = 0.5) 403 | val bitv = Yggdrasil.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows) 404 | assert(bitv.toArray.toSet === Set(3, 4)) 405 | } 406 | 407 | test("bitSubvectorFromSplit: 2 nodes") { 408 | // Initially, 1 
split: (0, 2, 4) | (1, 3) 409 | val col = new FeatureVector(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7), 410 | Array(4, 2, 0, 1, 3)) 411 | def checkSplit(fromOffset: Int, toOffset: Int, threshold: Double, 412 | expectedRight: Set[Int]): Unit = { 413 | val split = new ygg.ContinuousSplit(0, threshold) 414 | val numRows = col.values.length 415 | val bitv = Yggdrasil.bitVectorFromSplit(col, fromOffset, toOffset, split, numRows) 416 | assert(bitv.toArray.toSet === expectedRight) 417 | } 418 | // Left child node 419 | checkSplit(0, 3, 0.05, Set(0, 2, 4)) 420 | checkSplit(0, 3, 0.15, Set(0, 2)) 421 | checkSplit(0, 3, 0.2, Set(0)) 422 | checkSplit(0, 3, 0.5, Set()) 423 | // Right child node 424 | checkSplit(3, 5, 0.1, Set(1, 3)) 425 | checkSplit(3, 5, 0.65, Set(3)) 426 | checkSplit(3, 5, 0.8, Set()) 427 | } 428 | 429 | test("collectBitVectors with 1 vector") { 430 | val col = 431 | FeatureVector.fromOriginal(0, 0, Array(0.1, 0.2, 0.4, 0.6, 0.7)) 432 | val numRows = col.values.length 433 | val activeNodes = new BitSet(1) 434 | activeNodes.set(0) 435 | val labels = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble) 436 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, Entropy, Map(1 -> 3)) 437 | val fullImpurityAgg = metadata.createImpurityAggregator() 438 | labels.foreach(label => fullImpurityAgg.update(label)) 439 | 440 | val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes, Array(fullImpurityAgg)) 441 | val partitionInfos = sc.parallelize(Seq(info)) 442 | val bestSplit = new ygg.ContinuousSplit(0, threshold = 0.5) 443 | val bitVector = Yggdrasil.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows) 444 | assert(bitVector.toArray.toSet === Set(3, 4)) 445 | } 446 | 447 | test("collectBitVectors with 1 vector, with tied threshold") { 448 | val col = new FeatureVector(0, 0, 449 | Array(-4.0, -4.0, -2.0, -2.0, -1.0, -1.0, 1.0, 1.0), 450 | Array(3, 7, 2, 6, 1, 5, 0, 4)) 451 | val numRows = col.values.length 452 | val activeNodes = new BitSet(1) 453 | activeNodes.set(0) 454 | val labels = Array(0, 0, 0, 1, 1, 1, 1).map(_.toDouble) 455 | val metadata = new YggdrasilMetadata(numClasses = 2, maxBins = 4, minInfoGain = 0.0, Entropy, Map(1 -> 3)) 456 | val fullImpurityAgg = metadata.createImpurityAggregator() 457 | labels.foreach(label => fullImpurityAgg.update(label)) 458 | 459 | val info = PartitionInfo(Array(col), Array(0, numRows), activeNodes, Array(fullImpurityAgg)) 460 | val partitionInfos = sc.parallelize(Seq(info)) 461 | val bestSplit = new ygg.ContinuousSplit(0, threshold = -2.0) 462 | val bitVector = Yggdrasil.aggregateBitVector(partitionInfos, Array(Some(bestSplit)), numRows) 463 | assert(bitVector.toArray.toSet === Set(0, 1, 4, 5)) 464 | } 465 | 466 | /* * * * * * * * * * * Active nodes * * * * * * * * * * */ 467 | 468 | test("computeActiveNodePeriphery") { 469 | 470 | def exactlyEquals(a: ImpurityStats, b: ImpurityStats): Boolean = { 471 | a.gain == b.gain && a.impurity == b.impurity && 472 | a.impurityCalculator.stats.sameElements(b.impurityCalculator.stats) && 473 | a.leftImpurityCalculator.stats.sameElements(b.leftImpurityCalculator.stats) && 474 | a.rightImpurityCalculator.stats.sameElements(b.rightImpurityCalculator.stats) && 475 | a.valid == b.valid 476 | } 477 | // old periphery: 2 nodes 478 | val left = ygg.LearningNode.emptyNode(id = 1) 479 | val right = ygg.LearningNode.emptyNode(id = 2) 480 | val oldPeriphery: Array[ygg.LearningNode] = Array(left, right) 481 | // bestSplitsAndGains: Do not split left, but split right node. 
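// The right node's gain computed below follows the standard impurity-decrease formula:
// gain = impurity(parent) - (nLeft / n) * impurity(left) - (nRight / n) * impurity(right),
// where the counts come from the Entropy stat arrays defined next.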
482 | val lCalc = new EntropyCalculator(Array(8.0, 1.0)) 483 | val lStats = new ImpurityStats(0.0, lCalc.calculate(), 484 | lCalc, lCalc, new EntropyCalculator(Array(0.0, 0.0))) 485 | 486 | val rSplit = new ygg.ContinuousSplit(featureIndex = 1, threshold = 0.6) 487 | val rCalc = new EntropyCalculator(Array(5.0, 7.0)) 488 | val rRightChildCalc = new EntropyCalculator(Array(1.0, 5.0)) 489 | val rLeftChildCalc = new EntropyCalculator(Array( 490 | rCalc.stats(0) - rRightChildCalc.stats(0), 491 | rCalc.stats(1) - rRightChildCalc.stats(1))) 492 | val rGain = { 493 | val rightWeight = rRightChildCalc.stats.sum / rCalc.stats.sum 494 | val leftWeight = rLeftChildCalc.stats.sum / rCalc.stats.sum 495 | rCalc.calculate() - 496 | rightWeight * rRightChildCalc.calculate() - leftWeight * rLeftChildCalc.calculate() 497 | } 498 | val rStats = 499 | new ImpurityStats(rGain, rCalc.calculate(), rCalc, rLeftChildCalc, rRightChildCalc) 500 | 501 | val bestSplitsAndGains: Array[(Option[ygg.Split], ImpurityStats)] = 502 | Array((None, lStats), (Some(rSplit), rStats)) 503 | 504 | // Test A: Split right node 505 | val newPeriphery1: Array[ygg.LearningNode] = 506 | Yggdrasil.computeActiveNodePeriphery(oldPeriphery, bestSplitsAndGains, minInfoGain = 0.0) 507 | // Expect 2 active nodes 508 | assert(newPeriphery1.length === 2) 509 | // Confirm right node was updated 510 | assert(right.split.get === rSplit) 511 | assert(!right.isLeaf) 512 | assert(exactlyEquals(right.stats, rStats)) 513 | assert(right.leftChild.nonEmpty && right.leftChild.get === newPeriphery1(0)) 514 | assert(right.rightChild.nonEmpty && right.rightChild.get === newPeriphery1(1)) 515 | // Confirm new active nodes have stats but no children 516 | assert(newPeriphery1(0).leftChild.isEmpty && newPeriphery1(0).rightChild.isEmpty && 517 | newPeriphery1(0).split.isEmpty && 518 | newPeriphery1(0).stats.impurityCalculator.stats.sameElements(rLeftChildCalc.stats)) 519 | assert(newPeriphery1(1).leftChild.isEmpty && newPeriphery1(1).rightChild.isEmpty && 520 | newPeriphery1(1).split.isEmpty && 521 | newPeriphery1(1).stats.impurityCalculator.stats.sameElements(rRightChildCalc.stats)) 522 | 523 | // Test B: Increase minInfoGain, so split nothing 524 | val newPeriphery2: Array[ygg.LearningNode] = 525 | Yggdrasil.computeActiveNodePeriphery(oldPeriphery, bestSplitsAndGains, minInfoGain = 1000.0) 526 | assert(newPeriphery2.isEmpty) 527 | } 528 | } 529 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/tree/impl/YggdrasilUtilSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.ml.tree.impl 19 | 20 | import org.apache.spark.ml.tree.impl.YggdrasilUtil._ 21 | import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} 22 | import org.apache.spark.mllib.util.{MLlibTestSparkContext, SparkFunSuite} 23 | 24 | import scala.collection.mutable 25 | 26 | /** 27 | * Test suite for [[YggdrasilUtil]]. 28 | */ 29 | class YggdrasilUtilSuite extends SparkFunSuite with MLlibTestSparkContext { 30 | 31 | private def checkDense(rows: Seq[Vector]): Unit = { 32 | val numRowPartitions = 2 33 | val rowStore = sc.parallelize(rows, numRowPartitions) 34 | val colStore = rowToColumnStoreDense(rowStore) 35 | val numColPartitions = colStore.partitions.length 36 | val cols: Map[Int, Array[Double]] = colStore.collect().toMap 37 | val numRows = rows.size 38 | if (numRows == 0) { 39 | assert(cols.isEmpty) 40 | return 41 | } 42 | val numCols = rows.head.size 43 | if (numCols == 0) { 44 | assert(cols.isEmpty) 45 | return 46 | } 47 | rows.zipWithIndex.foreach { case (row, i) => 48 | var j = 0 49 | while (j < numCols) { 50 | assert(row(j) == cols(j)(i)) 51 | j += 1 52 | } 53 | } 54 | val expectedNumColPartitions = math.min(rowStore.partitions.length, numCols) 55 | assert(numColPartitions === expectedNumColPartitions) 56 | } 57 | 58 | private def checkSparse(rows: Seq[Vector]): Unit = { 59 | val numRowPartitions = 2 60 | val overPartitionFactor = 2 61 | val rowStore = sc.parallelize(rows, numRowPartitions) 62 | val colStore = rowToColumnStoreSparse(rowStore, overPartitionFactor) 63 | val cols: Map[Int, Vector] = colStore.collect().toMap 64 | val numRows = rows.size 65 | // Check cases with 0 rows or cols 66 | if (numRows == 0) { 67 | assert(cols.isEmpty) 68 | return 69 | } 70 | val numCols = rows.head.size 71 | if (numCols == 0) { 72 | assert(cols.isEmpty) 73 | return 74 | } 75 | // Check values (and count non-zeros too) 76 | var expectedNumNonZeros = 0 77 | rows.zipWithIndex.foreach { case (row, i) => 78 | var j = 0 79 | while (j < numCols) { 80 | assert(row(j) == cols(j)(i)) 81 | if (row(j) != 0) expectedNumNonZeros += 1 82 | j += 1 83 | } 84 | } 85 | // Check sparsity 86 | val numNonZeros = cols.values.map { 87 | case sv: SparseVector => sv.indices.length 88 | case _ => throw new RuntimeException( 89 | "checkSparse() found column which was not converted to SparseVector.") 90 | }.sum 91 | assert(numNonZeros === expectedNumNonZeros) 92 | // Check partitions to make sure they each contain consecutive columns. 
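// For example, a partition holding columns (3, 4, 5) passes this check, while one holding
// (3, 5) would fail, since rowToColumnStoreSparse assigns each destination partition a
// contiguous range of column indices.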
93 | val colsByPartition: Array[(Int, Array[(Int, Vector)])] = colStore.mapPartitionsWithIndex { 94 | case (partitionIndex, iterator) => 95 | val partCols = new mutable.ArrayBuffer[(Int, Vector)] 96 | iterator.foreach(col => partCols += col) 97 | Iterator((partitionIndex, partCols.toArray)) 98 | }.collect() 99 | colsByPartition.foreach { case (partitionIndex, partCols) => 100 | var j = 0 101 | while (j + 1 < partCols.length) { 102 | val curColIndex = partCols(j)._1 103 | val nextColIndex = partCols(j + 1)._1 104 | assert(curColIndex + 1 == nextColIndex) 105 | j += 1 106 | } 107 | } 108 | } 109 | 110 | test("rowToColumnStore: small dense") { 111 | val rows = Seq( 112 | Vectors.dense(1.0, 2.0, 3.0, 4.0), 113 | Vectors.dense(1.1, 2.1, 3.1, 4.1), 114 | Vectors.dense(1.2, 2.2, 3.2, 4.2) 115 | ) 116 | checkDense(rows) 117 | checkSparse(rows) 118 | } 119 | 120 | test("rowToColumnStore: small sparse") { 121 | val rows = Seq( 122 | Vectors.sparse(4, Array(0, 1), Array(1.0, 2.0)), 123 | Vectors.sparse(4, Array(1, 2), Array(1.1, 2.1)), 124 | Vectors.sparse(4, Array(2, 3), Array(1.2, 2.2)) 125 | ) 126 | checkDense(rows) 127 | checkSparse(rows) 128 | } 129 | 130 | test("rowToColumnStore: large dense") { 131 | // Note: All values must be non-zero since rowToColumnStoreSparse() automatically ignores 132 | // zero-valued elements. 133 | val numRows = 100 134 | val numCols = 90 135 | val rows = Range(0, numRows).map { i => 136 | Vectors.dense(Range(0, numCols).map(_ + numCols * i + 1.0).toArray) 137 | } 138 | checkDense(rows) 139 | checkSparse(rows) 140 | } 141 | 142 | test("rowToColumnStore: mixed dense and sparse") { 143 | val rows = Seq( 144 | Vectors.dense(1.0, 2.0, 3.0, 4.0), 145 | Vectors.sparse(4, Array(1, 2), Array(1.1, 2.1)), 146 | Vectors.dense(1.2, 2.2, 3.2, 4.2), 147 | Vectors.sparse(4, Array(0, 2), Array(1.3, 2.3)) 148 | ) 149 | checkDense(rows) 150 | checkSparse(rows) 151 | } 152 | 153 | test("rowToColumnStore: 0 rows") { 154 | val rows = Seq.empty[Vector] 155 | checkDense(rows) 156 | checkSparse(rows) 157 | } 158 | 159 | test("rowToColumnStore: 0 cols") { 160 | val rows = Seq( 161 | Vectors.dense(Array.empty[Double]), 162 | Vectors.dense(Array.empty[Double]), 163 | Vectors.dense(Array.empty[Double]) 164 | ) 165 | checkDense(rows) 166 | checkSparse(rows) 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License.
16 | */ 17 | 18 | package org.apache.spark.mllib.util 19 | 20 | import org.scalatest.{BeforeAndAfterAll, Suite} 21 | 22 | import org.apache.spark.{SparkConf, SparkContext} 23 | import org.apache.spark.sql.SQLContext 24 | 25 | trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => 26 | @transient var sc: SparkContext = _ 27 | @transient var sqlContext: SQLContext = _ 28 | 29 | override def beforeAll() { 30 | super.beforeAll() 31 | val conf = new SparkConf() 32 | .setMaster("local[2]") 33 | .setAppName("MLlibUnitTest") 34 | sc = new SparkContext(conf) 35 | SQLContext.clearActive() 36 | sqlContext = new SQLContext(sc) 37 | SQLContext.setActive(sqlContext) 38 | } 39 | 40 | override def afterAll() { 41 | sqlContext = null 42 | SQLContext.clearActive() 43 | if (sc != null) { 44 | sc.stop() 45 | } 46 | sc = null 47 | super.afterAll() 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/mllib/util/SparkFunSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.util 2 | 3 | import org.apache.spark.Logging 4 | import org.scalatest.{FunSuite, Outcome} 5 | 6 | /** 7 | * Created by fabuzaid21 on 6/7/16. 8 | */ 9 | abstract class SparkFunSuite extends FunSuite with Logging { 10 | // scalastyle:on 11 | 12 | /** 13 | * Log the suite name and the test name before and after each test. 14 | * 15 | * Subclasses should never override this method. If they wish to run 16 | * custom code before and after each test, they should mix in the 17 | * {{org.scalatest.BeforeAndAfter}} trait instead. 18 | */ 19 | final protected override def withFixture(test: NoArgTest): Outcome = { 20 | val testName = test.text 21 | val suiteName = this.getClass.getName 22 | val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") 23 | try { 24 | logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") 25 | test() 26 | } finally { 27 | logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") 28 | } 29 | } 30 | } 31 | --------------------------------------------------------------------------------
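A quick worked illustration of the 1-based, heap-style node indexing used by the LearningNode helpers in ygg/Node.scala above. This is a minimal sketch, not a file in the repository, and it assumes the caller lives under the org.apache.spark.ml.tree package (LearningNode is private[tree]):

import org.apache.spark.ml.tree.ygg.LearningNode

// The root has index 1 and sits at level 0; children of node i are 2*i and 2*i + 1.
LearningNode.leftChildIndex(1)    // 1 << 1 == 2
LearningNode.rightChildIndex(1)   // (1 << 1) + 1 == 3
LearningNode.parentIndex(3)       // 3 >> 1 == 1
LearningNode.indexToLevel(4)      // highest one bit of 4 is 4, so level 2
LearningNode.isLeftChild(2)       // true: even index and not the root
LearningNode.maxNodesInLevel(2)   // 1 << 2 == 4 nodes at level 2
LearningNode.startIndexInLevel(2) // level 2 starts at index 1 << 2 == 4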