├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── data └── bank │ └── bank-full.csv ├── pom.xml └── src ├── main └── scala │ └── org │ └── apache │ └── spark │ ├── examples │ └── ml │ │ └── GBTLRExample.scala │ └── ml │ └── gbtlr │ └── GBTLRClassifier.scala └── test ├── resources └── log4j.properties └── scala └── org └── apache └── spark ├── SparkFunSuite.scala └── ml └── gbtlr └── GBTLRClassifierSuite.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *.iml 2 | .idea 3 | target 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: java 2 | 3 | jdk: 4 | - openjdk8 5 | 6 | script: 7 | - mvn clean install -B 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2017 by Contributors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spark-GBTLR 2 | [![Build Status](https://travis-ci.org/titicaca/spark-gbtlr.svg?branch=master)](https://travis-ci.org/titicaca/spark-gbtlr) 3 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 4 | 5 | 6 | GBTLRClassifier is a hybrid model of Gradient Boosting Trees and Logistic Regression. 
7 | It is quite practical and popular in many data mining competitions.
8 | In this hybrid model, input features are transformed by means of boosted decision trees.
9 | The output of each individual tree is treated as a categorical input feature to a sparse linear classifier.
10 | Boosted decision trees prove to be very powerful feature transforms.
11 |
12 | Model details about GBTLR can be found in the following paper:
13 | Practical Lessons from Predicting Clicks on Ads at Facebook [1].
14 |
15 | GBTLRClassifier on Spark is designed and implemented by combining GradientBoostedTrees and LogisticRegression in
16 | Spark MLlib. Features are first trained and transformed into sparse vectors via GradientBoostedTrees, and then
17 | the generated sparse features are trained and predicted with a LogisticRegression model.
18 |
19 | ## Usage
20 |
21 | GBTLRClassifier is designed to be easy to use. Its parameters are the combined
22 | parameters of GradientBoostedTrees and LogisticRegression in MLlib.
23 |
24 | ## Examples
25 |
26 | The following code is an example of predicting bank marketing results using the Bank Marketing Dataset [2].
27 | The data is related to direct marketing campaigns (phone calls) of a Portuguese banking institution. The classification goal is to predict whether the client will subscribe to a term deposit (variable y).
28 |
29 |
30 | *Scala API*
31 | ```scala
32 | def main(args: Array[String]): Unit = {
33 |   val spark = SparkSession
34 |     .builder()
35 |     .master("local[2]")
36 |     .appName("gbtlr example")
37 |     .getOrCreate()
38 |
39 |   val startTime = System.currentTimeMillis()
40 |
41 |   val dataset = spark.read.option("header", "true").option("inferSchema", "true")
42 |     .option("delimiter", ";").csv("data/bank/bank-full.csv")
43 |
44 |   val columnNames = Array("job", "marital", "education",
45 |     "default", "housing", "loan", "contact", "month", "poutcome", "y")
46 |   val indexers = columnNames.map(name => new StringIndexer()
47 |     .setInputCol(name).setOutputCol(name + "_index"))
48 |   val pipeline = new Pipeline().setStages(indexers)
49 |   val data1 = pipeline.fit(dataset).transform(dataset)
50 |   val data2 = data1.withColumnRenamed("y_index", "label")
51 |
52 |   val assembler = new VectorAssembler()
53 |   assembler.setInputCols(Array("age", "job_index", "marital_index",
54 |     "education_index", "default_index", "balance", "housing_index",
55 |     "loan_index", "contact_index", "day", "month_index", "duration",
56 |     "campaign", "pdays", "previous", "poutcome_index"))
57 |   assembler.setOutputCol("features")
58 |
59 |   val data3 = assembler.transform(data2)
60 |   val data4 = data3.randomSplit(Array(4, 1))
61 |
62 |   val gBTLRClassifier = new GBTLRClassifier()
63 |     .setFeaturesCol("features")
64 |     .setLabelCol("label")
65 |     .setGBTMaxIter(10)
66 |     .setLRMaxIter(100)
67 |     .setRegParam(0.01)
68 |     .setElasticNetParam(0.5)
69 |
70 |   val model = gBTLRClassifier.fit(data4(0))
71 |   val summary = model.evaluate(data4(1))
72 |   val endTime = System.currentTimeMillis()
73 |   val auc = summary.binaryLogisticRegressionSummary
74 |     .asInstanceOf[BinaryLogisticRegressionSummary].areaUnderROC
75 |   println(s"Training and evaluating cost ${(endTime - startTime) / 1000} seconds")
76 |   println(s"The model's auc: ${auc}")
77 | }
78 | ```
79 |
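## Inspecting the Learned Rules

Each GBT-generated feature corresponds to one leaf of one boosted tree: for a tree with k leaves, a k-dimensional one-hot block is appended to the original feature vector, with a 1 at the position of the leaf the instance lands in. The fitted model's `getRules` method pairs the logistic regression weight learned for each generated feature with a readable description of the decision path that reaches its leaf. A minimal sketch, reusing the `model` fitted in the example above:

```scala
// Each entry is (LR weight, decision-path description) for one GBT leaf feature.
val rules = model.getRules

// Print the five generated features with the largest absolute weights.
rules.sortBy { case (weight, _) => -math.abs(weight) }
  .take(5)
  .foreach { case (weight, rule) => println(f"$weight%.4f  $rule") }
```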
80 | ## Benchmark
81 | TO BE ADDED...
82 |
83 | ## Requirements
84 |
85 | Spark-GBTLR is built on Spark 2.1.1 or later.
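## Model Persistence

`GBTLRClassificationModel` implements Spark's `MLWritable`/`MLReadable` interfaces, so a fitted model can be saved and reloaded for later scoring. A minimal sketch (the output path is a placeholder):

```scala
import org.apache.spark.ml.gbtlr.GBTLRClassificationModel

// Persist the fitted model: params metadata plus the underlying GBT and LR models.
model.write.overwrite().save("/tmp/gbtlr-model")

// Reload and score; transform() regenerates the GBT features before applying LR.
val restored = GBTLRClassificationModel.load("/tmp/gbtlr-model")
val predictions = restored.transform(data4(1))
```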
86 |
87 | ## Build From Source
88 |
89 | `mvn clean package`
90 |
91 | ## Licenses
92 |
93 | Spark-GBTLR is available under the Apache License 2.0.
94 |
95 | ## Acknowledgement
96 |
97 | Spark-GBTLR was designed and implemented together with my former intern Fang, Jie at Transwarp (transwarp.io).
98 | Thanks for his great contribution. In addition, thanks to the Discover Team for its support.
99 |
100 | ## Contact and Feedback
101 |
102 | If you encounter any bugs, feel free to submit an issue or pull request. You can also email
103 | Yang, Fangzhou (fangzhou.yang@hotmail.com).
104 |
105 |
106 | ## References
107 |
108 | [1] He X, Pan J, Jin O, et al. Practical Lessons from Predicting Clicks on Ads at Facebook. In Proceedings of ADKDD'14, 2014: 1-9.
109 |
110 | [2] Moro S, Cortez P, Rita P. A Data-Driven Approach to Predict the Success of Bank Telemarketing.
111 | Decision Support Systems, 2014, 62: 22-31.
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <project xmlns="http://maven.apache.org/POM/4.0.0"
3 |          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4 |          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5 |   <modelVersion>4.0.0</modelVersion>
6 |
7 |   <groupId>org.apache.spark</groupId>
8 |   <artifactId>spark-gbtlr</artifactId>
9 |   <version>2.4.0</version>
10 |
11 |   <properties>
12 |     <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
13 |     <project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
14 |     <java.version>1.8</java.version>
15 |     <scala.version>2.11.8</scala.version>
16 |     <scala.binary.version>2.11</scala.binary.version>
17 |     <slf4j.version>1.7.16</slf4j.version>
18 |     <log4j.version>1.2.17</log4j.version>
19 |     <!-- the element names skipTests and maven.version below are reconstructed
20 |          from a tag-stripped dump; only the values (false, 3.3.9) are original -->
21 |     <skipTests>false</skipTests>
22 |     <maven.version>3.3.9</maven.version>
23 |     <spark.version>2.4.0</spark.version>
24 |   </properties>
25 |
26 |   <dependencies>
27 |     <dependency>
28 |       <groupId>log4j</groupId>
29 |       <artifactId>log4j</artifactId>
30 |       <version>${log4j.version}</version>
31 |       <scope>compile</scope>
32 |     </dependency>
33 |     <dependency>
34 |       <groupId>org.apache.spark</groupId>
35 |       <artifactId>spark-core_${scala.binary.version}</artifactId>
36 |       <version>${spark.version}</version>
37 |     </dependency>
38 |     <dependency>
39 |       <groupId>org.apache.spark</groupId>
40 |       <artifactId>spark-mllib_${scala.binary.version}</artifactId>
41 |       <version>${spark.version}</version>
42 |     </dependency>
43 |     <dependency>
44 |       <groupId>org.apache.spark</groupId>
45 |       <artifactId>spark-mllib_${scala.binary.version}</artifactId>
46 |       <version>${spark.version}</version>
47 |       <type>test-jar</type>
48 |       <scope>test</scope>
49 |     </dependency>
50 |     <dependency>
51 |       <groupId>junit</groupId>
52 |       <artifactId>junit</artifactId>
53 |       <version>4.12</version>
54 |       <scope>test</scope>
55 |     </dependency>
56 |     <dependency>
57 |       <groupId>org.scalatest</groupId>
58 |       <artifactId>scalatest_${scala.binary.version}</artifactId>
59 |       <version>3.0.0</version>
60 |       <scope>test</scope>
61 |     </dependency>
62 |   </dependencies>
63 |
64 |   <build>
65 |     <plugins>
66 |       <plugin>
67 |         <groupId>org.apache.maven.plugins</groupId>
68 |         <artifactId>maven-compiler-plugin</artifactId>
69 |         <version>3.5.1</version>
70 |         <configuration>
71 |           <source>${java.version}</source>
72 |           <target>${java.version}</target>
73 |           <encoding>${project.build.sourceEncoding}</encoding>
74 |         </configuration>
75 |       </plugin>
76 |       <plugin>
77 |         <groupId>net.alchim31.maven</groupId>
78 |         <artifactId>scala-maven-plugin</artifactId>
79 |         <version>3.3.2</version>
80 |         <executions>
81 |           <execution>
82 |             <id>compile</id>
83 |             <goals>
84 |               <goal>compile</goal>
85 |             </goals>
86 |             <phase>compile</phase>
87 |           </execution>
88 |           <execution>
89 |             <id>test-compile</id>
90 |             <goals>
91 |               <goal>testCompile</goal>
92 |             </goals>
93 |             <phase>test-compile</phase>
94 |           </execution>
95 |           <execution>
96 |             <phase>process-resources</phase>
97 |             <goals>
98 |               <goal>compile</goal>
99 |             </goals>
100 |           </execution>
101 |         </executions>
102 |       </plugin>
103 |       <plugin>
104 |         <groupId>org.apache.maven.plugins</groupId>
105 |         <artifactId>maven-surefire-plugin</artifactId>
106 |         <version>2.7</version>
107 |         <configuration>
108 |           <skipTests>true</skipTests>
109 |         </configuration>
110 |       </plugin>
111 |       <plugin>
112 |         <groupId>org.scalatest</groupId>
113 |         <artifactId>scalatest-maven-plugin</artifactId>
114 |         <version>1.0</version>
115 |         <configuration>
116 |           <reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
117 |           <junitxml>.</junitxml>
118 |           <filereports>WDF TestSuite.txt</filereports>
119 |         </configuration>
120 |         <executions>
121 |           <execution>
122 |             <id>test</id>
123 |             <goals>
124 |               <goal>test</goal>
125 |             </goals>
126 |           </execution>
127 |         </executions>
128 |       </plugin>
129 |     </plugins>
130 |   </build>
131 | </project>
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/examples/ml/GBTLRExample.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.examples.ml
2 |
3 | import org.apache.spark.ml.gbtlr.GBTLRClassifier
4 | import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary
5 | import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
6 | import org.apache.spark.ml.Pipeline
7 | import org.apache.spark.sql.SparkSession
8 |
9 | // scalastyle:off println
10 |
11 |
12 | object GBTLRExample {
13 |   def main(args: Array[String]): Unit = {
14 |     val spark = SparkSession
15 |       .builder()
16 |       .master("local[2]")
17 |       .appName("gbtlr example")
18 |       .getOrCreate()
19 |
20 |     val startTime = System.currentTimeMillis()
21 |
22 |     val dataset = spark.read.option("header", "true").option("inferSchema", "true")
23 |       .option("delimiter", ";").csv("data/bank/bank-full.csv")
24 |
25 |     val columnNames = Array("job", "marital", "education",
26 |       "default", "housing", "loan", "contact", "month", "poutcome", "y")
27 |     val indexers = columnNames.map(name => new StringIndexer()
28 |       .setInputCol(name).setOutputCol(name + "_index"))
29 |     val pipeline = new Pipeline().setStages(indexers)
30 |     val data1 = pipeline.fit(dataset).transform(dataset)
31 |     val data2 = data1.withColumnRenamed("y_index", "label")
32 |
33 |     val assembler = new VectorAssembler()
34 |     assembler.setInputCols(Array("age", "job_index", "marital_index",
35 |       "education_index", "default_index", "balance", "housing_index",
36 |       "loan_index", "contact_index", "day", "month_index", "duration",
37 |       "campaign", "pdays", "previous", "poutcome_index"))
38 |     assembler.setOutputCol("features")
39 |
40 |     val data3 = assembler.transform(data2)
41 |     val data4 = data3.randomSplit(Array(4, 1))
42 |
43 |     val gBTLRClassifier = new GBTLRClassifier()
44 |       .setFeaturesCol("features")
45 |       .setLabelCol("label")
46 |       .setGBTMaxIter(10)
47 |       .setLRMaxIter(100)
48 |       .setRegParam(0.01)
49 |       .setElasticNetParam(0.5)
50 |
51 |     val model = gBTLRClassifier.fit(data4(0))
52 |     val summary = model.evaluate(data4(1))
53 |     val endTime = System.currentTimeMillis()
54 |     val auc = summary.binaryLogisticRegressionSummary
55 |       .asInstanceOf[BinaryLogisticRegressionSummary].areaUnderROC
56 |     println(s"Training and evaluating cost ${(endTime - startTime) / 1000} seconds")
57 |     println(s"The model's auc: ${auc}")
58 |   }
59 | }
60 |
61 | // scalastyle:on println
62 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/ml/gbtlr/GBTLRClassifier.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.ml.gbtlr
2 |
3 | import scala.collection.immutable.HashMap
4 | import scala.collection.mutable
5 | import org.apache.hadoop.fs.Path
6 | import org.apache.log4j.Logger
7 | import org.apache.spark.SparkException
8 | import org.apache.spark.ml.{PredictionModel, Predictor}
9 | import org.apache.spark.ml.attribute._
10 | import org.apache.spark.ml.classification._
11 | import org.apache.spark.ml.linalg.{Vector, Vectors}
12 | import org.apache.spark.ml.param._
13 | import org.apache.spark.ml.util._
14 | import org.apache.spark.ml.util.Instrumentation.instrumented
15 | import org.apache.spark.mllib.linalg.{DenseVector => OldDenseVector}
16 | import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
17 | import org.apache.spark.mllib.tree.configuration.{FeatureType, Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
18 | import org.apache.spark.mllib.tree.GradientBoostedTrees
19 | import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
20 | import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
21 | import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, Node => OldNode}
22 | import org.apache.spark.rdd.RDD
23 | import org.apache.spark.sql.{DataFrame, Dataset, Row}
24 | import org.apache.spark.sql.functions._
25 | import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
26 |
27 |
28 | trait GBTLRClassifierParams extends Params {
29 |
30 |   // =====below are GBTClassifier params=====
31 |   /**
32 |    * Param for checkpoint interval (>= 1) or disable checkpoint (-1).
33 |    *
34 |    * E.g. 10 means that the cache will get checkpointed every 10 iterations.
35 |    * @group param
36 |    */
37 |   val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set " +
38 |     "checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache " +
39 |     "will get checkpointed every 10 iterations",
40 |     (interval: Int) => interval == -1 || interval >= 1)
41 |
42 |   /** @group getParam */
43 |   def getCheckpointInterval: Int = $(checkpointInterval)
44 |
45 |   /**
46 |    * Loss function which GBT tries to minimize. (case-insensitive)
47 |    *
48 |    * Supported: "logistic"
49 |    * @group param
50 |    */
51 |   val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" +
52 |     " tries to minimize (case-insensitive). Supported option: logistic",
53 |     (value: String) => value == "logistic")
54 |
55 |   /** @group getParam */
56 |   def getLossType: String = $(lossType).toLowerCase
57 |
58 |   /**
59 |    * Maximum number of bins used for discretizing continuous features and for choosing how to split
60 |    * on features at each node. More bins give higher granularity.
61 |    * Must be >= 2 and >= number of categories in any categorical feature.
62 |    *
63 |    * (default = 32)
64 |    * @group param
65 |    */
66 |   val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
67 |     " discretizing continuous features. Must be >= 2 and >= number of categories for" +
68 |     " any categorical feature.", ParamValidators.gtEq(2))
69 |
70 |   /** @group getParam */
71 |   def getMaxBins: Int = $(maxBins)
72 |
73 |   /**
74 |    * Maximum depth of the tree (>= 0).
75 |    * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
76 |    *
77 |    * (default = 5)
78 |    * @group param
79 |    */
80 |   val maxDepth: IntParam =
81 |     new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" +
82 |       " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
83 |       ParamValidators.gtEq(0))
84 |
85 |   /** @group getParam */
86 |   def getMaxDepth: Int = $(maxDepth)
87 |
88 |   /**
89 |    * If false, the algorithm will pass trees to executors to match instances with nodes.
90 |    *
91 |    * If true, the algorithm will cache node IDs for each instance.
92 |    *
93 |    * Caching can speed up training of deeper trees. Users can set how often the
94 |    * cache should be checkpointed or disable it by setting checkpointInterval.
95 |    *
96 |    * (default = false)
97 |    * @group param
98 |    */
99 |   val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If" +
100 |     " false, the algorithm will pass trees to executors to match instances with nodes." +
101 |     " If true, the algorithm will cache node IDs for each instance. Caching can speed" +
102 |     " up training of deeper trees.")
103 |
104 |   /** @group getParam */
105 |   def getCacheNodeIds: Boolean = $(cacheNodeIds)
106 |
107 |   /**
108 |    * Maximum memory in MB allocated to histogram aggregation. If too small,
109 |    * then 1 node will be split per iteration, and its aggregates may exceed this size.
110 |    *
111 |    * (default = 256 MB)
112 |    * @group param
113 |    */
114 |   val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
115 |     "Maximum memory in MB allocated to histogram aggregation.",
116 |     ParamValidators.gtEq(0))
117 |
118 |   /** @group getParam */
119 |   def getMaxMemoryInMB: Int = $(maxMemoryInMB)
120 |
121 |   /**
122 |    * Minimum number of instances each child must have after split.
123 |    * If a split causes the left or right child to have fewer than minInstancesPerNode,
124 |    * the split will be discarded as invalid.
125 |    * Should be >= 1.
126 |    *
127 |    * (default = 1)
128 |    * @group param
129 |    */
130 |   val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode",
131 |     "Minimum number of instances each child must have after split. If a split causes" +
132 |     " the left or right child to have fewer than minInstancesPerNode, the split" +
133 |     " will be discarded as invalid. Should be >= 1.", ParamValidators.gtEq(1))
134 |
135 |   /** @group getParam */
136 |   def getMinInstancePerNode: Int = $(minInstancesPerNode)
137 |
138 |   /**
139 |    * Minimum information gain for a split to be considered at a tree node.
140 |    * Should be >= 0.0.
141 |    *
142 |    * (default = 0.0)
143 |    * @group param
144 |    */
145 |   val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
146 |     "Minimum information gain for a split to be considered at a tree node.",
147 |     ParamValidators.gtEq(0.0))
148 |
149 |   /** @group getParam */
150 |   def getMinInfoGain: Double = $(minInfoGain)
151 |
152 |   /**
153 |    * Param for maximum number of iterations (>= 0) of GBT.
154 |    * @group param
155 |    */
156 |   val GBTMaxIter: IntParam = new IntParam(this, "GBTMaxIter",
157 |     "maximum number of iterations (>= 0) of GBT",
158 |     ParamValidators.gtEq(0))
159 |
160 |   /** @group getParam */
161 |   def getGBTMaxIter: Int = $(GBTMaxIter)
162 |
163 |   /**
164 |    * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking
165 |    * the contribution of each estimator.
166 |    *
167 |    * (default = 0.1)
168 |    * @group param
169 |    */
170 |   val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " +
171 |     "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of" +
172 |     " each estimator.",
173 |     ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
174 |
175 |   /** @group getParam */
176 |   def getStepSize: Double = $(stepSize)
177 |
178 |   /**
179 |    * Fraction of the training data used for learning each decision tree, in range (0, 1].
180 | * 181 | * (default = 1.0) 182 | * @group param 183 | */ 184 | val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate", 185 | "Fraction of the training data used for learning each decision tree, in range (0, 1].", 186 | ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) 187 | 188 | /** @group getParam */ 189 | def getSubsamplingRate: Double = $(subsamplingRate) 190 | 191 | /** 192 | * Param for random seed. 193 | * @group param 194 | */ 195 | val seed: LongParam = new LongParam(this, "seed", "random seed") 196 | 197 | /** @group getParam */ 198 | def getSeed: Long = $(seed) 199 | 200 | // =====below are LR params===== 201 | 202 | /** 203 | * Param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, 204 | * the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. 205 | * @group param 206 | */ 207 | val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", 208 | "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is" + 209 | " an L2 penalty. For alpha = 1, it is an L1 penalty", ParamValidators.inRange(0, 1)) 210 | 211 | /** @group getParam */ 212 | def getElasticNetParam: Double = $(elasticNetParam) 213 | 214 | /** 215 | * Param for the name of family which is a description of the label distribution 216 | * to be used in the model. 217 | * 218 | * Supported options: 219 | * 220 | * - "auto": Automatically select the family based on the number of classes: 221 | * If numClasses == 1 || numClasses == 2, set to "binomial". 222 | * Else, set to "multinomial" 223 | * 224 | * - "binomial": Binary logistic regression with pivoting. 225 | * 226 | * - "multinomial": Multinomial logistic (softmax) regression without pivoting. 227 | * 228 | * Default is "auto". 229 | * 230 | * @group param 231 | */ 232 | val family: Param[String] = new Param(this, "family", 233 | "The name of family which is a description of the label distribution to be used in the " + 234 | s"model. Supported options: " + 235 | s"${Array("auto", "binomial", "multinomial").map(_.toLowerCase).mkString(", ")}.", 236 | ParamValidators.inArray[String]( 237 | Array("auto", "binomial", "multinomial").map(_.toLowerCase))) 238 | 239 | /** @group getParam */ 240 | def getFamily: String = $(family) 241 | 242 | /** 243 | * Param for whether to fit an intercept term. 244 | * @group param 245 | */ 246 | val fitIntercept: BooleanParam = new BooleanParam(this, "fitIntercept", 247 | "whether to fit an intercept term") 248 | 249 | /** @group getParam */ 250 | def getFitIntercept: Boolean = $(fitIntercept) 251 | 252 | /** 253 | * Param for maximum number of iterations (>= 0) of LR. 254 | * @group param 255 | */ 256 | val LRMaxIter: IntParam = new IntParam(this, "LRMaxIter", 257 | "maximum number of iterations (>= 0) of LR", 258 | ParamValidators.gtEq(0)) 259 | 260 | /** @group getParam */ 261 | def getLRMaxIter: Int = $(LRMaxIter) 262 | 263 | /** 264 | * Param for Column name for predicted class conditional probabilities. 265 | * 266 | * '''Note''': Not all models output well-calibrated probability estimates! 267 | * 268 | * These probabilities should be treated as confidences, not precise probabilities. 269 | * 270 | * @group param 271 | */ 272 | val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", 273 | "Column name for predicted class conditional probabilities. Note: Not all models output" + 274 | " well-calibrated probability estimates! 
These probabilities should be treated as" + 275 | " confidences, not precise probabilities") 276 | 277 | /** @group getParam */ 278 | def getProbabilityCol: String = $(probabilityCol) 279 | 280 | /** 281 | * Param for raw prediction (a.k.a. confidence) column name. 282 | * 283 | * @group param 284 | */ 285 | val rawPredictionCol: Param[String] = new Param[String](this, "rawPredictionCol", 286 | "raw prediction (a.k.a. confidence) column name") 287 | 288 | /** @group getParam */ 289 | def getRawPredictionCol: String = $(rawPredictionCol) 290 | 291 | /** 292 | * Param for gbt generated features column name. 293 | * 294 | * @group param 295 | */ 296 | val gbtGeneratedFeaturesCol: Param[String] = new Param[String](this, "gbtGeneratedCol", 297 | "gbt generated features column name") 298 | 299 | /** @group getParam */ 300 | def getGbtGeneratedFeaturesCol: String = $(gbtGeneratedFeaturesCol) 301 | 302 | /** 303 | * Param for regularization parameter (>= 0). 304 | * @group param 305 | */ 306 | val regParam: DoubleParam = new DoubleParam(this, "regParam", 307 | "regularization parameter (>= 0)", ParamValidators.gtEq(0)) 308 | 309 | /** @group getParam */ 310 | def getRegParam: Double = $(regParam) 311 | 312 | /** 313 | * Param for whether to standardize the training features before fitting the model. 314 | * @group param 315 | */ 316 | val standardization: BooleanParam = new BooleanParam(this, "standardization", 317 | "whether to standardize the training features before fitting the model") 318 | 319 | /** @group getParam */ 320 | def getStandardization: Boolean = $(standardization) 321 | 322 | /** 323 | * Param for threshold in binary classification prediction, in range [0, 1]. 324 | * @group param 325 | */ 326 | val threshold: DoubleParam = new DoubleParam(this, "threshold", 327 | "threshold in binary classification prediction, in range [0, 1]", 328 | ParamValidators.inRange(0, 1)) 329 | 330 | /** 331 | * Get threshold for binary classification. 332 | * 333 | * If `thresholds` is set with length 2 (i.e., binary classification), 334 | * this returns the equivalent threshold: {{{ 1 / (1 + thresholds(0) / thresholds(1)) }}} 335 | * Otherwise, returns `threshold` if set, or its default value if unset. 336 | * 337 | * @group getParam 338 | * @throws IllegalArgumentException if `thresholds` is set to an array of length other than 2. 339 | */ 340 | def getThreshold: Double = { 341 | checkThresholdConsistency() 342 | if (isSet(thresholds)) { 343 | val ts = $(thresholds) 344 | require(ts.length == 2, "Logistic Regression getThreshold only applies to" + 345 | " binary classification, but thresholds has length != 2. thresholds: " + 346 | ts.mkString(",")) 347 | 1.0 / (1.0 + ts(0) / ts(1)) 348 | } else { 349 | $(threshold) 350 | } 351 | } 352 | 353 | /** 354 | * Param for Thresholds in multi-class classification to adjust the probability 355 | * of predicting each class. Array must have length equal to the number of classes, 356 | * with values > 0 excepting that at most one value may be 0. The class with largest 357 | * value p/t is predicted, where p is the original probability of that class and t is 358 | * the class's threshold. 359 | * @group param 360 | */ 361 | val thresholds: DoubleArrayParam = new DoubleArrayParam(this, "thresholds", 362 | "Thresholds in multi-class classification to adjust the probability of predicting" + 363 | " each class. Array must have length equal to the number of classes, with values > 0" + 364 | " excepting that at most one value may be 0. 
The class with largest value p/t is" + 365 | " predicted, where p is the original probability of that class and t is the class's" + 366 | " threshold", (t: Array[Double]) => t.forall(_ >= 0) && t.count(_ == 0) <= 1) 367 | 368 | /** 369 | * Get thresholds for binary or multiclass classification. 370 | * 371 | * If `thresholds` is set, return its value. 372 | * Otherwise, if `threshold` is set, return the equivalent thresholds for binary 373 | * classification: (1-threshold, threshold). 374 | * If neither are set, throw an exception. 375 | * 376 | * @group getParam 377 | */ 378 | def getThresholds: Array[Double] = { 379 | checkThresholdConsistency() 380 | if (!isSet(thresholds) && isSet(threshold)) { 381 | val t = $(threshold) 382 | Array(1-t, t) 383 | } else { 384 | $(thresholds) 385 | } 386 | } 387 | 388 | /** 389 | * Param for the convergence tolerance for iterative algorithms (>= 0). 390 | * @group param 391 | */ 392 | val tol: DoubleParam = new DoubleParam(this, "tol", 393 | "the convergence tolerance for iterative algorithms (>= 0)", ParamValidators.gtEq(0)) 394 | 395 | /** @group getParam */ 396 | def getTol: Double = $(tol) 397 | 398 | /** 399 | * Param for weight column name. If this is not set or empty, we treat all instance 400 | * weights as 1.0. 401 | * @group param 402 | */ 403 | val weightCol: Param[String] = new Param[String](this, "weightCol", 404 | "weight column name. If this is not set or empty, we treat all instance weights as 1.0") 405 | 406 | /** @group getParam */ 407 | def getWeightCol: String = $(weightCol) 408 | 409 | /** 410 | * Param for suggested depth for treeAggregate (>= 2). 411 | * @group expertParam 412 | */ 413 | val aggregationDepth: IntParam = new IntParam(this, "aggregationDepth", 414 | "suggested depth for treeAggregate (>= 2)", ParamValidators.gtEq(2)) 415 | 416 | /** @group expertGetParam */ 417 | def getAggregationDepth: Int = $(aggregationDepth) 418 | 419 | /** 420 | * If `threshold` and `thresholds` are both set, ensures they are consistent. 421 | * 422 | * @throws IllegalArgumentException if `threshold` and `thresholds` are not equivalent 423 | */ 424 | private def checkThresholdConsistency(): Unit = { 425 | if (isSet(threshold) && isSet(thresholds)) { 426 | val ts = $(thresholds) 427 | require(ts.length == 2, "Logistic Regression found inconsistent values for threshold and" + 428 | s" thresholds. Param threshold is set (${$(threshold)}), indicating binary" + 429 | s" classification, but Param thresholds is set with length ${ts.length}." 
+
430 |         " Clear one Param value to fix this problem.")
431 |       val t = 1.0 / (1.0 + ts(0) / ts(1))
432 |       require(math.abs($(threshold) - t) < 1E-5, "Logistic Regression getThreshold found" +
433 |         s" inconsistent values for threshold (${$(threshold)}) and thresholds (equivalent to $t)")
434 |     }
435 |   }
436 |
437 |   setDefault(seed -> this.getClass.getName.hashCode.toLong,
438 |     subsamplingRate -> 1.0, GBTMaxIter -> 20, stepSize -> 0.1, maxDepth -> 5, maxBins -> 32,
439 |     minInstancesPerNode -> 1, minInfoGain -> 0.0, checkpointInterval -> 10, fitIntercept -> true,
440 |     probabilityCol -> "probability", rawPredictionCol -> "rawPrediction", standardization -> true,
441 |     threshold -> 0.5, lossType -> "logistic", cacheNodeIds -> false, maxMemoryInMB -> 256,
442 |     regParam -> 0.0, elasticNetParam -> 0.0, family -> "auto", LRMaxIter -> 100, tol -> 1E-6,
443 |     aggregationDepth -> 2, gbtGeneratedFeaturesCol -> "gbt_generated_features")
444 | }
445 |
446 |
447 | /**
448 |  * GBTLRClassifier is a hybrid model of Gradient Boosting Trees and Logistic Regression.
449 |  * Input features are transformed by means of boosted decision trees. The output of each individual tree is treated
450 |  * as a categorical input feature to a sparse linear classifier. Boosted decision trees prove to be very powerful
451 |  * feature transforms.
452 |  *
453 |  * Model details about GBTLR can be found in the following paper:
454 |  * Practical Lessons from Predicting Clicks on Ads at Facebook
455 |  *
456 |  * GBTLRClassifier on Spark is designed and implemented by combining GradientBoostedTrees and LogisticRegression in
457 |  * Spark MLlib. Features are first trained and transformed into sparse vectors via GradientBoostedTrees, and then
458 |  * the generated sparse features are trained and predicted in a Logistic Regression model.
459 |  *
460 |  * @param uid unique ID for Model
461 |  */
462 | class GBTLRClassifier (override val uid: String)
463 |   extends Predictor[Vector, GBTLRClassifier, GBTLRClassificationModel]
464 |     with GBTLRClassifierParams with DefaultParamsWritable {
465 |
466 |   import GBTLRClassifier._
467 |   import GBTLRUtil._
468 |
469 |   def this() = this(Identifiable.randomUID("gbtlr"))
470 |
471 |   // Set GBTClassifier params
472 |
473 |   /** @group setParam */
474 |   def setMaxDepth(value: Int): this.type = set(maxDepth, value)
475 |
476 |   /** @group setParam */
477 |   def setMaxBins(value: Int): this.type = set(maxBins, value)
478 |
479 |   /** @group setParam */
480 |   def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
481 |
482 |   /** @group setParam */
483 |   def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
484 |
485 |   /** @group setParam */
486 |   def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
487 |
488 |   /** @group setParam */
489 |   def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
490 |
491 |   /** @group setParam */
492 |   def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
493 |
494 |   /**
495 |    * The impurity setting is ignored for GBT models.
496 |    * Individual trees are built using impurity "Variance."
497 |    *
498 |    * @group setParam
499 |    */
500 |   def setImpurity(value: String): this.type = {
501 |     logger.warn("setImpurity is ignored in GBTLRClassifier; trees are always built with Variance impurity")
502 |     this
503 |   }
504 |
505 |   /** @group setParam */
506 |   def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
507 |
508 |   /** @group setParam */
509 |   def setSeed(value: Long): this.type = set(seed, value)
510 |
511 |   /** @group setParam */
512 |   def setGBTMaxIter(value: Int): this.type = set(GBTMaxIter, value)
513 |
514 |   /** @group setParam */
515 |   def setStepSize(value: Double): this.type = set(stepSize, value)
516 |
517 |   /** @group setParam */
518 |   def setLossType(value: String): this.type = set(lossType, value)
519 |
520 |   /** @group setParam */
521 |   def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
522 |
523 |   /** @group setParam */
524 |   def setFamily(value: String): this.type = set(family, value)
525 |
526 |   /** @group setParam */
527 |   def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
528 |
529 |   /** @group setParam */
530 |   def setLRMaxIter(value: Int): this.type = set(LRMaxIter, value)
531 |
532 |   /** @group setParam */
533 |   def setProbabilityCol(value: String): this.type = set(probabilityCol, value)
534 |
535 |   /** @group setParam */
536 |   def setRawPredictionCol(value: String): this.type = set(rawPredictionCol, value)
537 |
538 |   /** @group setParam */
539 |   def setRegParam(value: Double): this.type = set(regParam, value)
540 |
541 |   /** @group setParam */
542 |   def setStandardization(value: Boolean): this.type = set(standardization, value)
543 |
544 |   /** @group setParam */
545 |   def setTol(value: Double): this.type = set(tol, value)
546 |
547 |   /** @group setParam */
548 |   def setWeightCol(value: String): this.type = set(weightCol, value)
549 |
550 |   /** @group setParam */
551 |   def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
552 |
553 |   /** @group setParam */
554 |   def setGbtGeneratedFeaturesCol(value: String): this.type = set(gbtGeneratedFeaturesCol, value)
555 |
556 |   /**
557 |    * Set threshold in binary classification, in range [0, 1].
558 |    *
559 |    * If the estimated probability of class label 1 is greater than threshold, then predict 1,
560 |    * else 0. A high threshold encourages the model to predict 0 more often;
561 |    * a low threshold encourages the model to predict 1 more often.
562 |    *
563 |    * '''Note''': Calling this with threshold p is equivalent to calling
564 |    * `setThresholds(Array(1-p, p))`.
565 |    * When `setThreshold()` is called, any user-set value for `thresholds` will be cleared.
566 |    * If both `threshold` and `thresholds` are set in a ParamMap, then they must be
567 |    * equivalent.
568 |    *
569 |    * Default is 0.5.
570 |    *
571 |    * @group setParam
572 |    */
573 |   // TODO: Implement SPARK-11543?
574 |   def setThreshold(value: Double): this.type = {
575 |     if (isSet(thresholds)) clear(thresholds)
576 |     set(threshold, value)
577 |   }
578 |
579 |   /**
580 |    * Set thresholds in multiclass (or binary) classification to adjust the probability of
581 |    * predicting each class. Array must have length equal to the number of classes,
582 |    * with values greater than 0, excepting that at most one value may be 0.
583 |    * The class with largest value p/t is predicted, where p is the original probability of that
584 |    * class and t is the class's threshold.
585 |    *
586 |    * '''Note''': When `setThresholds()` is called, any user-set value for `threshold`
587 |    * will be cleared.
588 |    * If both `threshold` and `thresholds` are set in a ParamMap, then they must be
589 |    * equivalent.
590 |    *
591 |    * @group setParam
592 |    */
593 |   def setThresholds(value: Array[Double]): this.type = {
594 |     if (isSet(threshold)) clear(threshold)
595 |     set(thresholds, value)
596 |   }
597 |
598 |   /**
599 |    * Examine a schema to identify categorical (Binary and Nominal) features.
600 |    * @param featuresSchema Schema of the features column.
601 |    *
602 |    * If a feature does not have metadata, it is assumed to be continuous.
603 |    *
604 |    * If a feature is Nominal, then it must have the number of values
605 |    * specified.
606 |    * @return Map: feature index to number of categories.
607 |    *
608 |    *         The map's set of keys will be the set of categorical feature indices.
609 |    */
610 |   private def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = {
611 |     val metadata = AttributeGroup.fromStructField(featuresSchema)
612 |     if (metadata.attributes.isEmpty) {
613 |       HashMap.empty[Int, Int]
614 |     } else {
615 |       metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) =>
616 |         if (attr == null) {
617 |           Iterator()
618 |         } else {
619 |           attr match {
620 |             case _: NumericAttribute | UnresolvedAttribute => Iterator()
621 |             case _: BinaryAttribute => Iterator(idx -> 2)
622 |             case nomAttr: NominalAttribute =>
623 |               nomAttr.getNumValues match {
624 |                 case Some(numValues: Int) => Iterator(idx -> numValues)
625 |                 case None => throw new IllegalArgumentException(s"Feature $idx is" +
626 |                   s" marked as Nominal (categorical), but it does not have the number" +
627 |                   s" of values specified.")
628 |               }
629 |           }
630 |         }
631 |       }.toMap
632 |     }
633 |   }
634 |
635 |   /**
636 |    * Create a Strategy instance to use with the old API.
637 |    * @param categoricalFeatures Map: feature index to number of categories.
638 |    * @return Strategy instance
639 |    */
640 |   private def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
641 |     val strategy = OldStrategy.defaultStrategy(OldAlgo.Classification)
642 |     strategy.impurity = OldVariance
643 |     strategy.checkpointInterval = getCheckpointInterval
644 |     strategy.maxBins = getMaxBins
645 |     strategy.maxDepth = getMaxDepth
646 |     strategy.maxMemoryInMB = getMaxMemoryInMB
647 |     strategy.minInfoGain = getMinInfoGain
648 |     strategy.minInstancesPerNode = getMinInstancePerNode
649 |     strategy.useNodeIdCache = getCacheNodeIds
650 |     strategy.numClasses = 2
651 |     strategy.categoricalFeaturesInfo = categoricalFeatures
652 |     strategy.subsamplingRate = getSubsamplingRate
653 |     strategy
654 |   }
655 |
656 |   /**
657 |    * Get the old Gradient Boosting loss type.
658 |    * @return Loss type
659 |    */
660 |   private def getOldLossType: OldLoss = {
661 |     getLossType match {
662 |       case "logistic" => OldLogLoss
663 |       case _ =>
664 |         // Should never happen because of check in setter method.
665 |         throw new RuntimeException(s"GBTClassifier was given bad loss type:" +
666 |           s" $getLossType")
667 |     }
668 |   }
669 |
670 |   /**
671 |    * Train a GBTLRClassificationModel, which consists of a GradientBoostedTreesModel
672 |    * and a LogisticRegressionModel.
673 |    * @param dataset Input data.
674 |    * @return GBTLRClassification model.
675 |    */
676 |   override def train(dataset: Dataset[_]): GBTLRClassificationModel = instrumented { instr =>
677 |     val categoricalFeatures: Map[Int, Int] =
678 |       getCategoricalFeatures(dataset.schema($(featuresCol)))
679 |
680 |     // GBT only supports 2 classes now.
681 |     val oldDataset: RDD[OldLabeledPoint] =
682 |       dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
683 |         case Row(label: Double, features: Vector) =>
684 |           require(label == 0 || label == 1, s"GBTClassifier was given" +
685 |             s" dataset with invalid label $label. Labels must be in {0,1}; note that" +
686 |             s" GBTClassifier currently only supports binary classification.")
687 |           OldLabeledPoint(label, new OldDenseVector(features.toArray))
688 |       }
689 |
690 |     val numFeatures = oldDataset.first().features.size
691 |     val strategy = getOldStrategy(categoricalFeatures)
692 |     val boostingStrategy = new OldBoostingStrategy(strategy, getOldLossType,
693 |       getGBTMaxIter, getStepSize)
694 |
695 |     instr.logPipelineStage(this)
696 |     instr.logNumFeatures(numFeatures)
697 |     instr.logNumClasses(2)
698 |     instr.logDataset(dataset)
699 |     instr.logParams(this)
700 |
701 |     // train a gradient boosted tree model using boostingStrategy.
702 |     val gbtModel = GradientBoostedTrees.train(oldDataset, boostingStrategy)
703 |
704 |     // udf for creating a feature column which consists of original features
705 |     // and gbt model generated features.
706 |     val addFeatureUDF = udf { (features: Vector) =>
707 |       val gbtFeatures = getGBTFeatures(gbtModel, features)
708 |       Vectors.dense(features.toArray ++ gbtFeatures.toArray)
709 |     }
710 |
711 |     val datasetWithCombinedFeatures = dataset.withColumn($(gbtGeneratedFeaturesCol),
712 |       addFeatureUDF(col($(featuresCol))))
713 |
714 |     // create a logistic regression instance.
715 |     val logisticRegression = new LogisticRegression()
716 |       .setRegParam($(regParam))
717 |       .setElasticNetParam($(elasticNetParam))
718 |       .setMaxIter($(LRMaxIter))
719 |       .setTol($(tol))
720 |       .setLabelCol($(labelCol))
721 |       // the LR is trained on the combined (original + GBT-generated) features column.
722 |       .setFitIntercept($(fitIntercept))
723 |       .setFamily($(family))
724 |       .setStandardization($(standardization))
725 |       .setPredictionCol($(predictionCol))
726 |       .setProbabilityCol($(probabilityCol))
727 |       .setRawPredictionCol($(rawPredictionCol))
728 |       .setAggregationDepth($(aggregationDepth))
729 |       .setFeaturesCol($(gbtGeneratedFeaturesCol))
730 |
731 |     if (isSet(weightCol)) logisticRegression.setWeightCol($(weightCol))
732 |     if (isSet(threshold)) logisticRegression.setThreshold($(threshold))
733 |     if (isSet(thresholds)) logisticRegression.setThresholds($(thresholds))
734 |
735 |     // train a logistic regression model with the new combined features.
736 |     val lrModel = logisticRegression.fit(datasetWithCombinedFeatures)
737 |
738 |     val model = copyValues(new GBTLRClassificationModel(uid, gbtModel, lrModel).setParent(this))
739 |     val summary = new GBTLRClassifierTrainingSummary(datasetWithCombinedFeatures, lrModel.summary,
740 |       gbtModel.trees, gbtModel.treeWeights)
741 |     model.setSummary(Some(summary))
742 |     model
743 |   }
744 |
745 |   override def copy(extra: ParamMap): GBTLRClassifier = defaultCopy(extra)
746 | }
747 |
748 | object GBTLRClassifier extends DefaultParamsReadable[GBTLRClassifier] {
749 |
750 |   val logger = Logger.getLogger(GBTLRClassifier.getClass)
751 |
752 |   override def load(path: String): GBTLRClassifier = super.load(path)
753 | }
754 |
755 | class GBTLRClassificationModel (
756 |     override val uid: String,
757 |     val gbtModel: GradientBoostedTreesModel,
758 |     val lrModel: LogisticRegressionModel)
759 |   extends PredictionModel[Vector, GBTLRClassificationModel]
760 |     with GBTLRClassifierParams with MLWritable {
761 |
762 |   import GBTLRUtil._
763 |
764 |   private var trainingSummary: Option[GBTLRClassifierTrainingSummary] = None
765 |
766 |   private[gbtlr] def setSummary(
767 |       summary: Option[GBTLRClassifierTrainingSummary]): this.type = {
768 |     this.trainingSummary = summary
769 |     this
770 |   }
771 |
772 |   /**
773 |    * Returns true if a training summary exists for this model.
774 |    */
775 |   def hasSummary: Boolean = trainingSummary.nonEmpty
776 |
777 |   def summary: GBTLRClassifierTrainingSummary = trainingSummary.getOrElse {
778 |     throw new SparkException(
779 |       s"No training summary available for the ${this.getClass.getSimpleName}"
780 |     )
781 |   }
782 |
783 |   override def write: MLWriter =
784 |     new GBTLRClassificationModel.GBTLRClassificationModelWriter(this)
785 |
786 |   /**
787 |    * Get a combined feature point through the GBT model when given a specific feature point.
788 |    * @param point Original labeled point.
789 |    * @return A labeled point with combined (original + GBT-generated) features.
790 |    */
791 |   def getCombinedFeatures(
792 |       point: OldLabeledPoint): OldLabeledPoint = {
793 |     val numTrees = gbtModel.trees.length
794 |     val treeLeafArray = new Array[Array[Int]](numTrees)
795 |     for (i <- 0 until numTrees)
796 |       treeLeafArray(i) = getLeafNodes(gbtModel.trees(i).topNode)
797 |
798 |     var newFeature = new Array[Double](0)
799 |     val label = point.label
800 |     val features = point.features
801 |     for (i <- 0 until numTrees) {
802 |       val treePredict = predictModify(gbtModel.trees(i).topNode, features.toDense)
803 |       val treeArray = new Array[Double]((gbtModel.trees(i).numNodes + 1) / 2)
804 |       treeArray(treeLeafArray(i).indexOf(treePredict)) = 1
805 |       newFeature = newFeature ++ treeArray
806 |     }
807 |     OldLabeledPoint(label, new OldDenseVector(features.toArray ++ newFeature))
808 |   }
809 |
810 |   // udf for creating a feature column which consists of original features
811 |   // and gbt model generated features.
812 |   private val addFeatureUDF = udf { (features: Vector) =>
813 |     val gbtFeatures = getGBTFeatures(gbtModel, features)
814 |     Vectors.dense(features.toArray ++ gbtFeatures.toArray)
815 |   }
816 |
817 |   override def transform(dataset: Dataset[_]): DataFrame = {
818 |     val datasetWithCombinedFeatures = dataset.withColumn($(gbtGeneratedFeaturesCol),
819 |       addFeatureUDF(col($(featuresCol))))
820 |     val predictions = lrModel.transform(datasetWithCombinedFeatures)
821 |     predictions
822 |   }
823 |
824 |   // just implements the abstract method in PredictionModel, but is not used.
825 |   override def predict(features: Vector): Double = 0.0
826 |
827 |   /**
828 |    * Evaluates the model on a test dataset.
829 |    * @param dataset Test dataset to evaluate the model on.
830 |    */
831 |   def evaluate(dataset: Dataset[_]): GBTLRClassifierSummary = {
832 |     val datasetWithCombinedFeatures = dataset.withColumn($(gbtGeneratedFeaturesCol),
833 |       addFeatureUDF(col($(featuresCol))))
834 |     val lrSummary = lrModel.evaluate(datasetWithCombinedFeatures)
835 |     new GBTLRClassifierSummary(lrSummary)
836 |   }
837 |
838 |   override def copy(extra: ParamMap): GBTLRClassificationModel = {
839 |     val copied = copyValues(new GBTLRClassificationModel(uid, gbtModel, lrModel), extra)
840 |     copied.setSummary(trainingSummary).setParent(this.parent)
841 |   }
842 |
843 |   /**
844 |    * Get the set of rules that reach each leaf node.
845 |    * @param node Root node.
846 |    * @param rule Current rule prefix.
847 |    * @param rules Builder collecting the final rule of every leaf node.
848 |    */
849 |   private def getLeafRules(
850 |       node: OldNode,
851 |       rule: String,
852 |       rules: mutable.ArrayBuilder[String]): Unit = {
853 |     val split = node.split
854 |     if (node.isLeaf) {
855 |       rules += rule
856 |     } else {
857 |       if (split.get.featureType == FeatureType.Continuous) {
858 |         val leftRule = rule + s", feature#${split.get.feature} <= ${split.get.threshold}"
859 |         getLeafRules(node.leftNode.get, leftRule, rules)
860 |         val rightRule = rule + s", feature#${split.get.feature} > ${split.get.threshold}"
861 |         getLeafRules(node.rightNode.get, rightRule, rules)
862 |       } else {
863 |         val leftRule = rule + s", feature#${split.get.feature}'s value is in the Set" +
864 |           split.get.categories.mkString("[", ",", "]")
865 |         getLeafRules(node.leftNode.get, leftRule, rules)
866 |         val rightRule = rule + s", feature#${split.get.feature}'s value is not in the Set" +
867 |           split.get.categories.mkString("[", ",", "]")
868 |         getLeafRules(node.rightNode.get, rightRule, rules)
869 |       }
870 |     }
871 |   }
872 |
873 |   /**
874 |    * Get a description of each GBT-generated feature dimension together with its trained LR weight.
875 |    * @return An array of tuples; in each tuple, the first element is the weight of the extra
876 |    *         feature, and the second element is a description of how this feature is generated.
877 | */ 878 | def getRules: Array[Tuple2[Double, String]] = { 879 | val numTrees = gbtModel.trees.length 880 | val rules = new Array[Array[String]](numTrees) 881 | var numExtraFeatures = 0 882 | for (i <- 0 until numTrees) { 883 | val rulesInEachTree = mutable.ArrayBuilder.make[String] 884 | getLeafRules(gbtModel.trees(i).topNode, "", rulesInEachTree) 885 | val rule = rulesInEachTree.result() 886 | numExtraFeatures += rule.length 887 | rules(i) = rule 888 | } 889 | val weightsInLR = lrModel.coefficients.toArray 890 | val extraWeights = 891 | weightsInLR.slice(weightsInLR.length - numExtraFeatures, weightsInLR.length) 892 | extraWeights.zip(rules.flatMap(x => x)) 893 | } 894 | 895 | } 896 | 897 | object GBTLRClassificationModel extends MLReadable[GBTLRClassificationModel] { 898 | 899 | val logger = Logger.getLogger(GBTLRClassificationModel.getClass) 900 | 901 | 902 | override def read: MLReader[GBTLRClassificationModel] = new GBTLRClassificationModelReader 903 | 904 | override def load(path: String): GBTLRClassificationModel = super.load(path) 905 | 906 | private[GBTLRClassificationModel] class GBTLRClassificationModelWriter( 907 | instance: GBTLRClassificationModel) extends MLWriter { 908 | override def saveImpl(path: String) { 909 | // Save metadata and Params 910 | DefaultParamsWriter.saveMetadata(instance, path, sc) 911 | // Save model data 912 | val gbtDataPath = new Path(path, "gbtData").toString 913 | instance.gbtModel.save(sc, gbtDataPath) 914 | val lrDataPath = new Path(path, "lrData").toString 915 | instance.lrModel.save(lrDataPath) 916 | } 917 | } 918 | 919 | private class GBTLRClassificationModelReader 920 | extends MLReader[GBTLRClassificationModel] { 921 | 922 | private val className = classOf[GBTLRClassificationModel].getName 923 | 924 | override def load(path: String): GBTLRClassificationModel = { 925 | val metadata = DefaultParamsReader.loadMetadata(path, sc, className) 926 | val gbtDataPath = new Path(path, "gbtData").toString 927 | val lrDataPath = new Path(path, "lrData").toString 928 | val gbtModel = GradientBoostedTreesModel.load(sc, gbtDataPath) 929 | val lrModel = LogisticRegressionModel.load(lrDataPath) 930 | val model = new GBTLRClassificationModel(metadata.uid, gbtModel, lrModel) 931 | metadata.getAndSetParams(model) 932 | model 933 | } 934 | } 935 | } 936 | 937 | class GBTLRClassifierTrainingSummary ( 938 | @transient val newDataset: DataFrame, 939 | val logRegSummary: LogisticRegressionTrainingSummary, 940 | val gbtTrees: Array[DecisionTreeModel], 941 | val treeWeights: Array[Double]) extends Serializable { 942 | } 943 | 944 | class GBTLRClassifierSummary ( 945 | val binaryLogisticRegressionSummary: LogisticRegressionSummary) 946 | extends Serializable { 947 | } 948 | 949 | 950 | object GBTLRUtil { 951 | /** 952 | * Get an array of leaf nodes according to the root node of a tree. 953 | * The order of nodes in the array is from left to right. 954 | * 955 | * @param node Root node of a tree. 956 | * @return An array stores the leaf node ids. 957 | */ 958 | def getLeafNodes(node: OldNode): Array[Int] = { 959 | var treeLeafNodes = new Array[Int](0) 960 | if (node.isLeaf) { 961 | treeLeafNodes = treeLeafNodes :+ (node.id) 962 | } else { 963 | treeLeafNodes = treeLeafNodes ++ getLeafNodes(node.leftNode.get) 964 | treeLeafNodes = treeLeafNodes ++ getLeafNodes(node.rightNode.get) 965 | } 966 | treeLeafNodes 967 | } 968 | 969 | /** 970 | * Get the leaf node id at which the features will be located. 971 | * 972 | * @param node Root node of a tree. 
973 |    * @param features Dense vector of features.
974 |    * @return Leaf node id.
975 |    */
976 |   def predictModify(node: OldNode, features: OldDenseVector): Int = {
977 |     val split = node.split
978 |     if (node.isLeaf) {
979 |       node.id
980 |     } else {
981 |       if (split.get.featureType == FeatureType.Continuous) {
982 |         if (features(split.get.feature) <= split.get.threshold) {
983 |           predictModify(node.leftNode.get, features)
984 |         } else {
985 |           predictModify(node.rightNode.get, features)
986 |         }
987 |       } else {
988 |         if (split.get.categories.contains(features(split.get.feature))) {
989 |           predictModify(node.leftNode.get, features)
990 |         } else {
991 |           predictModify(node.rightNode.get, features)
992 |         }
993 |       }
994 |     }
995 |   }
996 | 
997 |   /**
998 |    * Get GBT-generated features for a sample from a trained GBT model.
999 |    *
1000 |    * @param gbtModel Trained gradient boosted trees model.
1001 |    * @param features Original feature vector of the sample.
1002 |    * @return A vector of one-hot encoded leaf indicators, one block per tree.
1003 |    */
1004 |   def getGBTFeatures(gbtModel: GradientBoostedTreesModel, features: Vector): Vector = {
1005 |     val numTrees = gbtModel.trees.length
1006 |     val oldFeatures = new OldDenseVector(features.toArray)
1007 |     val treeLeafArray = new Array[Array[Int]](numTrees)
1008 |     for (i <- 0 until numTrees)
1009 |       treeLeafArray(i) = getLeafNodes(gbtModel.trees(i).topNode)
1010 |     var newFeature = new Array[Double](0)
1011 |     for (i <- 0 until numTrees) {
1012 |       val treePredict = predictModify(gbtModel.trees(i).topNode, oldFeatures)
1013 |       val treeArray = new Array[Double]((gbtModel.trees(i).numNodes + 1) / 2) // one slot per leaf
1014 |       treeArray(treeLeafArray(i).indexOf(treePredict)) = 1.0
1015 |       newFeature = newFeature ++ treeArray
1016 |     }
1017 |     Vectors.dense(newFeature)
1018 |   }
1019 | }
1020 | 
--------------------------------------------------------------------------------
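The interpretability of the hybrid model comes from the pair `getLeafRules` / `getRules` defined on the model above: every GBT-generated feature corresponds to the chain of split conditions on the path to one leaf, and `getRules` pairs that rule string with the weight the logistic regression learned for it. A minimal inspection sketch, not part of the repository; `model` is an assumed, already-fitted GBTLRClassificationModel:

// Hypothetical usage: rank the tree-derived cross features by |LR weight|.
val rules: Array[(Double, String)] = model.getRules
rules.sortBy { case (weight, _) => -math.abs(weight) }
  .take(10)
  .foreach { case (weight, rule) => println(f"$weight%+.4f  $rule") }

Rules with large positive weights describe feature regions that push the LR score toward the positive class; large negative weights do the opposite.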
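`GBTLRUtil.getGBTFeatures` above is the heart of the GBT+LR scheme: each tree contributes a one-hot block with one slot per leaf (ordered left to right by `getLeafNodes`), and `predictModify` picks the slot that is set to 1.0. A hedged sketch under assumed names, where `gbtModel` stands for a trained GradientBoostedTreesModel over ten input features:

// E.g. 3 trees with 4 leaves each and a sample landing in leaves 2, 0, 3
// yields Vectors.dense(0,0,1,0, 1,0,0,0, 0,0,0,1), with 12 = 3 * 4 dimensions.
val generated = GBTLRUtil.getGBTFeatures(gbtModel, Vectors.dense(Array.fill(10)(0.0)))
// Exactly one indicator fires per tree, so the components sum to the tree count.
assert(generated.toArray.sum == gbtModel.trees.length)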
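Persistence splits the composite model into its parts: `saveImpl` writes the params metadata plus the mllib GBT model and the ml LR model under the "gbtData" and "lrData" subdirectories, and the reader reassembles them. A round-trip sketch, assuming the model exposes the standard `write` member of MLWritable (its writer class is defined above) and that the target path is writable:

model.write.overwrite().save("/tmp/gbtlr-model") // path is an assumption
val restored = GBTLRClassificationModel.load("/tmp/gbtlr-model")
assert(restored.lrModel.coefficients == model.lrModel.coefficients)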
/src/test/resources/log4j.properties:
--------------------------------------------------------------------------------
1 | log4j.rootCategory=WARN, console
2 | log4j.appender.console=org.apache.log4j.ConsoleAppender
3 | log4j.appender.console.target=System.err
4 | log4j.appender.console.layout=org.apache.log4j.PatternLayout
5 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n
6 | 
7 | # Settings to quiet third party logs that are too verbose
8 | log4j.logger.org.eclipse.jetty=WARN
9 | log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR
10 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
11 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
12 | 
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/SparkFunSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements. See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License. You may obtain a copy of the License at
8 |  *
9 |  *    http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package org.apache.spark
19 | 
20 | // scalastyle:off
21 | import java.io.File
22 | 
23 | import org.apache.spark.internal.Logging
24 | import org.apache.spark.util.AccumulatorContext
25 | import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome}
26 | 
27 | /**
28 |  * Base abstract class for all unit tests in Spark for handling common functionality.
29 |  */
30 | abstract class SparkFunSuite
31 |   extends FunSuite
32 |   with BeforeAndAfterAll
33 |   with Logging {
34 |   // scalastyle:on
35 | 
36 |   protected override def afterAll(): Unit = {
37 |     try {
38 |       // Avoid leaking map entries in tests that use accumulators without SparkContext
39 |       AccumulatorContext.clear()
40 |     } finally {
41 |       super.afterAll()
42 |     }
43 |   }
44 | 
45 |   // helper function
46 |   protected final def getTestResourceFile(file: String): File = {
47 |     new File(getClass.getClassLoader.getResource(file).getFile)
48 |   }
49 | 
50 |   protected final def getTestResourcePath(file: String): String = {
51 |     getTestResourceFile(file).getCanonicalPath
52 |   }
53 | 
54 |   /**
55 |    * Log the suite name and the test name before and after each test.
56 |    *
57 |    * Subclasses should never override this method. If they wish to run
58 |    * custom code before and after each test, they should mix in the
59 |    * {{org.scalatest.BeforeAndAfter}} trait instead.
60 |    */
61 |   final protected override def withFixture(test: NoArgTest): Outcome = {
62 |     val testName = test.text
63 |     val suiteName = this.getClass.getName
64 |     val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s")
65 |     try {
66 |       logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n")
67 |       test()
68 |     } finally {
69 |       logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
70 |     }
71 |   }
72 | 
73 | }
74 | 
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/gbtlr/GBTLRClassifierSuite.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.ml.gbtlr
2 | 
3 | import org.apache.spark.SparkFunSuite
4 | import org.apache.spark.ml.feature.{LabeledPoint, StringIndexer, VectorAssembler}
5 | import org.apache.spark.ml.linalg.{Vector, Vectors}
6 | import org.apache.spark.ml.util.DefaultReadWriteTest
7 | import org.apache.spark.mllib.linalg.{DenseVector => OldDenseVector}
8 | import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
9 | import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node => OldNode}
10 | import org.apache.spark.mllib.util.MLlibTestSparkContext
11 | import org.apache.spark.sql.{DataFrame, Dataset, Row}
12 | import org.apache.spark.sql.functions._
13 | 
14 | 
15 | class GBTLRClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
16 | 
17 |   @transient var dataset: Dataset[_] = _
18 | 
19 |   override def beforeAll(): Unit = {
20 |     super.beforeAll()
21 |     dataset = spark.createDataFrame(GBTLRClassifierSuite.generateOrderedLabeledPoints(10, 100))
22 |   }
23 | 
24 |   test("default params") {
25 |     val gBTLRClassifier = new GBTLRClassifier()
26 | 
27 |     assert(gBTLRClassifier.getSeed === gBTLRClassifier.getClass.getName.hashCode.toLong)
28 |     assert(gBTLRClassifier.getSubsamplingRate === 1.0)
29 |     assert(gBTLRClassifier.getGBTMaxIter === 20)
30 |     assert(gBTLRClassifier.getStepSize === 0.1)
31 |     assert(gBTLRClassifier.getMaxDepth === 5)
32 |     assert(gBTLRClassifier.getMaxBins === 32)
33 |     assert(gBTLRClassifier.getMinInstancePerNode === 1)
34 |     assert(gBTLRClassifier.getMinInfoGain === 0.0)
35 |     assert(gBTLRClassifier.getCheckpointInterval === 10)
36 |     assert(gBTLRClassifier.getFitIntercept === true)
37 |     assert(gBTLRClassifier.getProbabilityCol === "probability")
38 |     assert(gBTLRClassifier.getRawPredictionCol === "rawPrediction")
39 |     assert(gBTLRClassifier.getStandardization === true)
40 |     assert(gBTLRClassifier.getThreshold === 0.5)
41 |     assert(gBTLRClassifier.getLossType === "logistic")
42 |     assert(gBTLRClassifier.getCacheNodeIds === false)
43 |     assert(gBTLRClassifier.getMaxMemoryInMB === 256)
44 |     assert(gBTLRClassifier.getRegParam === 0.0)
45 |     assert(gBTLRClassifier.getElasticNetParam === 0.0)
46 |     assert(gBTLRClassifier.getFamily === "auto")
47 |     assert(gBTLRClassifier.getLRMaxIter === 100)
48 |     assert(gBTLRClassifier.getTol === 1E-6)
49 |     assert(gBTLRClassifier.getAggregationDepth === 2)
50 |   }
51 | 
52 |   test("set params") {
53 |     val gBTLRClassifier = new GBTLRClassifier()
54 |       .setSeed(123L)
55 |       .setSubsamplingRate(0.5)
56 |       .setGBTMaxIter(10)
57 |       .setStepSize(0.5)
58 |       .setMaxDepth(10)
59 |       .setMaxBins(20)
60 |       .setMinInstancesPerNode(2)
61 |       .setMinInfoGain(1.0)
62 |       .setCheckpointInterval(5)
63 |       .setFitIntercept(false)
64 |       .setProbabilityCol("test_probability")
65 |       .setRawPredictionCol("test_rawPrediction")
66 |       .setStandardization(false)
67 |       .setThreshold(1.0)
68 |       .setCacheNodeIds(true)
69 |       .setMaxMemoryInMB(128)
70 |       .setRegParam(1.0)
71 |       .setElasticNetParam(0.5)
72 |       .setFamily("binomial")
73 |       .setLRMaxIter(50)
74 |       .setTol(1E-3)
75 |       .setAggregationDepth(3)
76 | 
77 |     assert(gBTLRClassifier.getSeed === 123L)
78 |     assert(gBTLRClassifier.getSubsamplingRate === 0.5)
79 |     assert(gBTLRClassifier.getGBTMaxIter === 10)
80 |     assert(gBTLRClassifier.getStepSize === 0.5)
81 |     assert(gBTLRClassifier.getMaxDepth === 10)
82 |     assert(gBTLRClassifier.getMaxBins === 20)
83 |     assert(gBTLRClassifier.getMinInstancePerNode === 2)
84 |     assert(gBTLRClassifier.getMinInfoGain === 1.0)
85 |     assert(gBTLRClassifier.getCheckpointInterval === 5)
86 |     assert(gBTLRClassifier.getFitIntercept === false)
87 |     assert(gBTLRClassifier.getProbabilityCol === "test_probability")
88 |     assert(gBTLRClassifier.getRawPredictionCol === "test_rawPrediction")
89 |     assert(gBTLRClassifier.getStandardization === false)
90 |     assert(gBTLRClassifier.getThreshold === 1.0)
91 |     assert(gBTLRClassifier.getCacheNodeIds === true)
92 |     assert(gBTLRClassifier.getMaxMemoryInMB === 128)
93 |     assert(gBTLRClassifier.getRegParam === 1.0)
94 |     assert(gBTLRClassifier.getElasticNetParam === 0.5)
95 |     assert(gBTLRClassifier.getFamily === "binomial")
96 |     assert(gBTLRClassifier.getLRMaxIter === 50)
97 |     assert(gBTLRClassifier.getTol === 1E-3)
98 |     assert(gBTLRClassifier.getAggregationDepth === 3)
99 |   }
100 | 
101 |   test("combination features") {
102 |     val gBTLRClassifier = new GBTLRClassifier()
103 |     val model = gBTLRClassifier.train(dataset)
104 |     val gbtModel = model.gbtModel
105 |     var numLeafNodes = 0
106 |     for (i <- 0 until gbtModel.trees.length) {
107 |       numLeafNodes += (gbtModel.trees(i).numNodes + 1) / 2
108 |     }
109 |     val newDataset = model.summary.newDataset
110 |     val numFeatures = model.lrModel.numFeatures
111 |     val lrSummary = model.summary.logRegSummary
112 |     val originNumFeatures = dataset.select(col(lrSummary.labelCol),
113 |       col("features")).rdd.map {
114 |       case Row(label: Double, features: Vector) => features
115 |     }.first().size
116 |     assert(numFeatures === originNumFeatures + numLeafNodes)
117 |   }
118 | 
119 |   test("add features") {
120 |     val gBTLRClassifier = new GBTLRClassifier()
121 |     val model = gBTLRClassifier.train(dataset)
122 |     val point = new OldLabeledPoint(1.0, new OldDenseVector(Array.fill(100)(0)))
123 |     val combinedPoint = model.getComibinedFeatures(point)
124 |     val combinedFeatures = combinedPoint.features.toArray
125 |     assert(combinedFeatures.sum === 20.0) // one leaf indicator fires per tree; the default GBT has 20 trees
126 |   }
127 | 
128 |   test("rules") {
129 |     val gBTLRClassifier = new GBTLRClassifier()
130 |     val model = gBTLRClassifier.fit(dataset)
131 |     val numTrees = model.gbtModel.trees.length
132 |     var totalLeafNodes = 0
133 |     for (i <- 0 until numTrees) {
134 |       totalLeafNodes += (model.gbtModel.trees(i).numNodes + 1) / 2
135 |     }
136 |     assert(model.getRules.length === totalLeafNodes)
137 |   }
138 | 
139 |   // Uncomment to run GBTLRClassificationModel read / write tests
140 |   // test("read/write") {
141 |   //
142 |   //   def checkModelData(model1: GBTLRClassificationModel, model2: GBTLRClassificationModel): Unit = {
143 |   //     assert(model1.gbtModel.algo === model2.gbtModel.algo)
144 |   //     try {
145 |   //       model1.gbtModel.trees.zip(model2.gbtModel.trees).foreach {
146 |   //         case (tree1, tree2) => checkModelEqual(tree1, tree2)
147 |   //       }
148 |   //       assert(model1.gbtModel.treeWeights === model2.gbtModel.treeWeights)
149 |   //     } catch {
150 |   //       case ex: Exception =>
151 |   //         throw new AssertionError("checkModelData failed since " +
152 |   //           "the two gbtModels were not identical.\n")
153 |   //     }
154 |   //     assert(model1.lrModel.intercept === model2.lrModel.intercept)
155 |   //     assert(model1.lrModel.coefficients.toArray === model2.lrModel.coefficients.toArray)
156 |   //     assert(model1.lrModel.numFeatures === model2.lrModel.numFeatures)
157 |   //     assert(model1.lrModel.numClasses === model2.lrModel.numClasses)
158 |   //   }
159 |   //
160 |   //   val gBTLRClassifier = new GBTLRClassifier()
161 |   //   testEstimatorAndModelReadWrite(
162 |   //     gBTLRClassifier, dataset,
163 |   //     GBTLRClassifierSuite.allParamSettings,
164 |   //     GBTLRClassifierSuite.allParamSettings,
165 |   //     checkModelData
166 |   //   )
167 |   // }
168 | 
169 |   /**
170 |    * Assert that the two nodes and their descendants are exactly the same.
171 |    */
172 |   private def checkTreeEqual(a: OldNode, b: OldNode): Unit = {
173 |     assert(a.id === b.id)
174 |     assert(a.predict === b.predict)
175 |     assert(a.impurity === b.impurity)
176 |     assert(a.isLeaf === b.isLeaf)
177 |     assert(a.split === b.split)
178 |     (a.stats, b.stats) match {
179 |       case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain)
180 |       case (None, None) =>
181 |       case _ => throw new AssertionError(
182 |         s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})")
183 |     }
184 |     (a.leftNode, b.leftNode) match {
185 |       case (Some(aNode), Some(bNode)) => checkTreeEqual(aNode, bNode)
186 |       case (None, None) =>
187 |       case _ => throw new AssertionError("Only one instance has leftNode defined. " +
188 |         s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})")
189 |     }
190 |     (a.rightNode, b.rightNode) match {
191 |       case (Some(aNode: OldNode), Some(bNode: OldNode)) => checkTreeEqual(aNode, bNode)
192 |       case (None, None) =>
193 |       case _ => throw new AssertionError("Only one instance has rightNode defined. " +
194 |         s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})")
195 |     }
196 |   }
197 | 
198 |   /**
199 |    * Check if the two trees are exactly the same.
200 |    * If the trees are not equal, this throws an exception containing both trees' debug strings.
201 |    */
202 |   private def checkModelEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
203 |     try {
204 |       assert(a.algo === b.algo)
205 |       checkTreeEqual(a.topNode, b.topNode)
206 |     } catch {
207 |       case ex: Exception =>
208 |         throw new AssertionError("checkModelEqual failed since the two trees were not " +
209 |           "identical.\n" + "TREE A:\n" + a.toDebugString + "\n" +
210 |           "TREE B:\n" + b.toDebugString + "\n", ex)
211 |     }
212 |   }
213 | 
214 |   test("model transform") {
215 |     val gBTLRClassifier = new GBTLRClassifier()
216 |     val model = gBTLRClassifier.fit(dataset)
217 |     val prediction = model.transform(dataset)
218 |     assert(model.lrModel.getFeaturesCol === "gbt_generated_features")
219 |     val len1 = prediction.schema.fieldNames.length
220 |     val len2 = dataset.schema.fieldNames.length
221 |     assert(len1 === len2 + 4)
222 |     assert(prediction.schema.fieldNames.contains("gbt_generated_features"))
223 |   }
224 | }
225 | 
226 | object GBTLRClassifierSuite {
227 | 
228 |   def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = {
229 |     val arr = new Array[LabeledPoint](numInstances)
230 |     for (i <- 0 until numInstances) {
231 |       val label = if (i < numInstances / 10) {
232 |         0.0
233 |       } else if (i < numInstances / 2) {
234 |         1.0
235 |       } else if (i < numInstances * 0.9) {
236 |         0.0
237 |       } else {
238 |         1.0
239 |       }
240 |       val features = Array.fill[Double](numFeatures)(i.toDouble)
241 |       arr(i) = LabeledPoint(label, Vectors.dense(features))
242 |     }
243 |     arr
244 |   }
245 | 
246 |   val allParamSettings: Map[String, Any] = Map(
247 |     "seed" -> 123L,
248 |     "subsamplingRate" -> 1.0,
249 |     "GBTMaxIter" -> 20,
250 |     "stepSize" -> 0.1,
251 |     "maxDepth" -> 5,
252 |     "maxBins" -> 32,
253 |     "minInstancesPerNode" -> 1,
254 |     "minInfoGain" -> 0.0,
255 |     "checkpointInterval" -> 10,
256 |     "fitIntercept" -> true,
257 |     "probabilityCol" -> "probability",
258 |     "rawPredictionCol" -> "rawPrediction",
259 |     "standardization" -> true,
260 |     "threshold" -> 0.5,
261 |     "lossType" -> "logistic",
262 |     "cacheNodeIds" -> false,
263 |     "maxMemoryInMB" -> 256,
264 |     "regParam" -> 0.0,
265 |     "elasticNetParam" -> 0.0,
266 |     "family" -> "auto",
267 |     "LRMaxIter" -> 100,
268 |     "tol" -> 1E-6,
269 |     "aggregationDepth" -> 2
270 |   )
271 | 
272 | }
273 | 
--------------------------------------------------------------------------------
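Read end to end, the suite doubles as a usage guide for the estimator API. A condensed sketch of the calls it exercises, assuming a SparkSession and a DataFrame `df` with "features" and "label" columns (both names are assumptions for this sketch):

import org.apache.spark.ml.gbtlr.GBTLRClassifier

val classifier = new GBTLRClassifier()
  .setGBTMaxIter(10) // number of boosted trees
  .setLRMaxIter(100) // LR optimizer iterations
val model = classifier.fit(df)
// transform appends four columns, among them "gbt_generated_features",
// which holds the one-hot leaf encoding fed to the logistic regression.
val predictions = model.transform(df)
model.getRules.take(5).foreach(println) // sample of the learned rule/weight pairs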