├── README.md
├── build.sbt
├── project
│   └── plugins.sbt
└── src
    └── main
        └── scala
            └── org
                └── apache
                    └── spark
                        ├── examples
                        │   └── mllib
                        │       ├── AbstractParams.scala
                        │       └── LambdaMARTRunner.scala
                        └── mllib
                            ├── dataSet
                            │   └── dataSet.scala
                            ├── tree
                            │   ├── DerivativeCalculator.scala
                            │   ├── LambdaMART.scala
                            │   ├── LambdaMARTDecisionTree.scala
                            │   ├── config
                            │   │   ├── Algo.scala
                            │   │   ├── BoostingStrategy.scala
                            │   │   └── Strategy.scala
                            │   ├── configuration
                            │   │   └── algo.scala
                            │   ├── impl
                            │   │   └── FeatureStatsAggregator.scala
                            │   └── model
                            │       ├── GetDerivatives.scala
                            │       ├── Histogram.scala
                            │       ├── SplitInfo.scala
                            │       ├── ensemblemodels
                            │       │   └── GradientBoostedDecisionTreesModel.scala
                            │       ├── impurity
                            │       │   ├── Impurity.scala
                            │       │   └── Variance.scala
                            │       ├── informationgainstats
                            │       │   └── InformationGainStats.scala
                            │       ├── node
                            │       │   └── Node.scala
                            │       ├── nodePredict.scala
                            │       ├── opdtmodel
                            │       │   └── OptimizedDecisionTreeModel.scala
                            │       └── predict
                            │           └── Predict.scala
                            └── util
                                ├── ProbabilityFunctions.scala
                                ├── TreeUtils.scala
                                └── treeAggregatorFormat.scala

/README.md:
--------------------------------------------------------------------------------
## SparkTree

SparkTree is an efficient and scalable tree ensemble learning system built on Spark. We focus on both performance and accuracy, supporting ranking, classification and regression.

## Contributors

* Cui Li ([@cui](https://github.com/girlatsnow))

* Bo Zhao ([@bhoppi](https://github.com/bhoppi))

* Hucheng Zhou ([@hucheng](https://github.com/hucheng))

* Jianglin Liang

--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
name := "LambdaMART"

version := "3.0"

scalaVersion := "2.10.6"
scalaBinaryVersion := "2.10"

dependencyOverrides ++= Set(
  "org.scala-lang" % "scala-library" % scalaVersion.value,
  "org.scala-lang" % "scala-compiler" % scalaVersion.value,
  "org.scala-lang" % "scala-reflect" % scalaVersion.value
)

libraryDependencies ++= Seq(
  "org.apache.spark" % "spark-core_2.10" % "1.6.0" % "provided",
  "org.apache.spark" % "spark-mllib_2.10" % "1.6.0" % "provided",
  "com.github.scopt" % "scopt_2.10" % "3.3.0"
)

--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.1")

--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala:
--------------------------------------------------------------------------------
1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.examples.mllib 19 | 20 | import scala.reflect.runtime.universe._ 21 | 22 | /** 23 | * Abstract class for parameter case classes. 24 | * This overrides the [[toString]] method to print all case class fields by name and value. 25 | * @tparam T Concrete parameter class. 26 | */ 27 | abstract class AbstractParams[T: TypeTag] { 28 | 29 | private def tag: TypeTag[T] = typeTag[T] 30 | 31 | /** 32 | * Finds all case class fields in concrete class instance, and outputs them in JSON-style format: 33 | * { 34 | * [field name]:\t[field value]\n 35 | * [field name]:\t[field value]\n 36 | * ... 37 | * } 38 | */ 39 | override def toString: String = { 40 | val tpe = tag.tpe 41 | val allAccessors = tpe.declarations.collect { 42 | case m: MethodSymbol if m.isCaseAccessor => m 43 | } 44 | val mirror = runtimeMirror(getClass.getClassLoader) 45 | val instanceMirror = mirror.reflect(this) 46 | allAccessors.map { f => 47 | val paramName = f.name.toString 48 | val fieldMirror = instanceMirror.reflectField(f) 49 | val paramValue = fieldMirror.get 50 | s" $paramName:\t$paramValue" 51 | }.mkString("{\n", ",\n", "\n}") 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/mllib/LambdaMARTRunner.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.mllib 2 | 3 | import breeze.collection.mutable.SparseArray 4 | import org.apache.hadoop.fs.Path 5 | import org.apache.spark.mllib.dataSet.{dataSet, dataSetLoader} 6 | import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics 7 | import org.apache.spark.mllib.linalg.{Vector, Vectors} 8 | import org.apache.spark.mllib.tree.config.Algo 9 | import org.apache.spark.mllib.tree.model.SplitInfo 10 | import org.apache.spark.mllib.tree.model.ensemblemodels.GradientBoostedDecisionTreesModel 11 | import org.apache.spark.mllib.tree.{DerivativeCalculator, LambdaMART, config} 12 | import org.apache.spark.mllib.util.{MLUtils, TreeUtils, treeAggregatorFormat} 13 | import org.apache.spark.rdd.RDD 14 | import org.apache.spark.storage.StorageLevel 15 | import org.apache.spark.{HashPartitioner, SparkConf, SparkContext} 16 | import scopt.OptionParser 17 | 18 | import scala.language.reflectiveCalls 19 | import scala.util.Random 20 | 21 | 22 | object LambdaMARTRunner { 23 | 24 | class Params(var trainingData: String = null, 25 | var queryBoundy: String = null, 26 | var label: String = null, 27 | var initScores: String = null, 28 | var testData: String = null, 29 | var testQueryBound: String = null, 30 | var testLabel: String = null, 31 | var validationData: String = null, 32 | var queryBoundyValidate: String = null, 33 | var initScoreValidate: String = null, 34 | var labelValidate: String = null, 35 | var featureNoToFriendlyName: String = null, 36 | var outputTreeEnsemble: String = null, 37 | var expandTreeEnsemble: Boolean = false, 38 | var featureIniFile: String = null, 39 | var gainTableStr: String = null, 40 | var algo: String = "LambdaMart", 41 | var learningStrategy: String = "sgd", 42 | var maxDepth: Array[Int] = null, 43 | var numLeaves: Int = 0, 44 | var numPruningLeaves: Array[Int] = null, 45 | var numIterations: Array[Int] = null, 46 | var maxSplits: Int = 128, 47 | var learningRate: Array[Double] = null, 48 | var minInstancesPerNode: Array[Int] = null, 49 | var 
testSpan: Int = 0, 50 | var sampleFeaturePercent: Double = 1.0, 51 | var sampleQueryPercent: Double = 1.0, 52 | var sampleDocPercent: Double = 1.0, 53 | var numPartitions: Int = 160, 54 | var ffraction: Double = 1.0, 55 | var sfraction: Double = 1.0, 56 | var secondaryMS: Double = 0.0, 57 | var secondaryLE: Boolean = false, 58 | var sigma: Double = 1.0, 59 | var distanceWeight2: Boolean = false, 60 | var baselineAlpha: Array[Double] = null, 61 | var baselineAlphaFilename: String = null, 62 | var entropyCoefft: Double = 0.0, 63 | var featureFirstUsePenalty: Double = 0.0, 64 | var featureReusePenalty: Double = 0.0, 65 | var outputNdcgFilename: String = null, 66 | var active_lambda_learningStrategy: Boolean = false, 67 | var rho_lambda: Double = 0.5, 68 | var active_leaves_value_learningStrategy: Boolean = false, 69 | var rho_leave: Double = 0.5, 70 | var GainNormalization: Boolean = false, 71 | var feature2NameFile: String = null, 72 | var validationSpan: Int = 10, 73 | var useEarlystop: Boolean = true, 74 | var secondGainsFileName: String = null, 75 | var secondaryInverseMaxDcgFileName: String = null, 76 | var secondGains: Array[Double] = null, 77 | var secondaryInverseMaxDcg: Array[Double] = null, 78 | var discountsFilename: String = null, 79 | var discounts: Array[Double] = null, 80 | var sampleWeightsFilename: String = null, 81 | var sampleWeights: Array[Double] = null, 82 | var baselineDcgsFilename: String = null, 83 | var baselineDcgs: Array[Double] = null) extends java.io.Serializable { 84 | 85 | override def toString: String = { 86 | val propertiesStr = s"trainingData = $trainingData\nqueryBoundy = $queryBoundy\nlabel = $label\ninitScores = $initScores\n" + 87 | s"testData = $testData\ntestQueryBound = $testQueryBound\ntestLabel = $testLabel\ntestSpan = $testSpan\n" + 88 | s"sampleFeaturePercent = $sampleFeaturePercent\nsampleQueryPercent = $sampleQueryPercent\nsampleDocPercent = $sampleDocPercent\n" + 89 | s"outputTreeEnsemble = $outputTreeEnsemble\nfeatureNoToFriendlyName = $featureNoToFriendlyName\nvalidationData = $validationData\n" + 90 | s"queryBoundyValidate = $queryBoundyValidate\ninitScoreValidate = $initScoreValidate\nlabelValidate = $labelValidate\n" + 91 | s"expandTreeEnsemble = $expandTreeEnsemble\nfeatureIniFile = $featureIniFile\ngainTableStr = $gainTableStr\n" + 92 | s"algo = $algo\nmaxDepth = ${maxDepth.mkString(":")}\nnumLeaves = $numLeaves\nnumPruningLeaves = ${numPruningLeaves.mkString(":")}\nnumIterations = ${numIterations.mkString(":")}\nmaxSplits = $maxSplits\n" + 93 | s"learningRate = ${learningRate.mkString(":")}\nminInstancesPerNode = ${minInstancesPerNode.mkString(":")}\nffraction = $ffraction\nsfraction = $sfraction\n" 94 | 95 | 96 | propertiesStr 97 | } 98 | } 99 | 100 | def main(args: Array[String]) { 101 | val defaultParams = new Params() 102 | 103 | val parser = new OptionParser[Unit]("LambdaMART") { 104 | head("LambdaMART: an implementation of LambdaMART for FastRank.") 105 | 106 | opt[String]("trainingData") required() foreach { x => 107 | defaultParams.trainingData = x 108 | } text ("trainingData path") 109 | opt[String]("queryBoundy") optional() foreach { x => 110 | defaultParams.queryBoundy = x 111 | } text ("queryBoundy path") 112 | opt[String]("label") required() foreach { x => 113 | defaultParams.label = x 114 | } text ("label path to training dataset") 115 | opt[String]("initScores") optional() foreach { x => 116 | defaultParams.initScores = x 117 | } text (s"initScores path to training dataset. 
If not given, initScores will be {0 ...}.") 118 | 119 | opt[String]("testData") optional() foreach { x => 120 | defaultParams.testData = x 121 | } text ("testData path") 122 | opt[String]("testQueryBound") optional() foreach { x => 123 | defaultParams.testQueryBound = x 124 | } text ("test queryBoundy path") 125 | opt[String]("testLabel") optional() foreach { x => 126 | defaultParams.testLabel = x 127 | } text ("label path to test dataset") 128 | 129 | opt[String]("vd") optional() foreach { x => 130 | defaultParams.validationData = x 131 | } text ("validationData path") 132 | opt[String]("qbv") optional() foreach { x => 133 | defaultParams.queryBoundyValidate = x 134 | } text ("path to queryBoundy for validation data") 135 | opt[String]("lv") optional() foreach { x => 136 | defaultParams.labelValidate = x 137 | } text ("path to label for validation data") 138 | opt[String]("isv") optional() foreach { x => 139 | defaultParams.initScoreValidate = x 140 | } text (s"path to initScore for validation data. If not given, initScores will be {0 ...}.") 141 | 142 | opt[String]("outputTreeEnsemble") required() foreach { x => 143 | defaultParams.outputTreeEnsemble = x 144 | } text ("outputTreeEnsemble path") 145 | opt[String]("ftfn") optional() foreach { x => 146 | defaultParams.featureNoToFriendlyName = x 147 | } text ("path to featureNoToFriendlyName") 148 | opt[Boolean]("expandTreeEnsemble") optional() foreach { x => 149 | defaultParams.expandTreeEnsemble = x 150 | } text (s"expandTreeEnsemble") 151 | opt[String]("featureIniFile") optional() foreach { x => 152 | defaultParams.featureIniFile = x 153 | } text (s"path to featureIniFile") 154 | opt[String]("gainTableStr") required() foreach { x => 155 | defaultParams.gainTableStr = x 156 | } text (s"gainTableStr parameters") 157 | opt[String]("algo") optional() foreach { x => 158 | defaultParams.algo = x 159 | } text (s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") 160 | opt[String]("maxDepth") optional() foreach { x => 161 | defaultParams.maxDepth = x.split(":").map(_.toInt) 162 | } text (s"max depth of the tree, default: ${defaultParams.maxDepth}") 163 | 164 | opt[Int]("numLeaves") optional() foreach { x => 165 | defaultParams.numLeaves = x 166 | } text (s"num of leaves per tree, default: ${defaultParams.numLeaves}. 
Take precedence over --maxDepth.") 167 | opt[String]("numPruningLeaves") optional() foreach { x => 168 | defaultParams.numPruningLeaves = x.split(":").map(_.toInt) 169 | } text (s"num of leaves per tree after pruning, default: ${defaultParams.numPruningLeaves}.") 170 | opt[String]("numIterations") optional() foreach { x => 171 | defaultParams.numIterations = x.split(":").map(_.toInt) 172 | } text (s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") 173 | opt[String]("minInstancesPerNode") optional() foreach { x => 174 | defaultParams.minInstancesPerNode = x.split(":").map(_.toInt) 175 | } text (s"the minimum number of documents allowed in a leaf of the tree, default: ${defaultParams.minInstancesPerNode}") 176 | opt[Int]("maxSplits") optional() foreach { x => 177 | defaultParams.maxSplits = x 178 | } text (s"max Nodes to be split simultaneously, default: ${defaultParams.maxSplits}") validate { x => 179 | if (x > 0 && x <= 512) success else failure("value incorrect; should be between 1 and 512.") 180 | } 181 | opt[String]("learningRate") optional() foreach { x => 182 | defaultParams.learningRate = x.split(":").map(_.toDouble) 183 | } text (s"learning rate of the score update, default: ${defaultParams.learningRate}") 184 | opt[Int]("testSpan") optional() foreach { x => 185 | defaultParams.testSpan = x 186 | } text (s"test span") 187 | opt[Int]("numPartitions") optional() foreach { x => 188 | defaultParams.numPartitions = x 189 | } text (s"number of partitions, default: ${defaultParams.numPartitions}") 190 | opt[Double]("sampleFeaturePercent") optional() foreach { x => 191 | defaultParams.sampleFeaturePercent = x 192 | } text (s"global feature percentage used for training") 193 | opt[Double]("sampleQueryPercent") optional() foreach { x => 194 | defaultParams.sampleQueryPercent = x 195 | } text (s"global query percentage used for training") 196 | opt[Double]("sampleDocPercent") optional() foreach { x => 197 | defaultParams.sampleDocPercent = x 198 | } text (s"global doc percentage used for classification") 199 | opt[Double]("ffraction") optional() foreach { x => 200 | defaultParams.ffraction = x 201 | } text (s"feature percentage used for training for each tree") 202 | opt[Double]("sfraction") optional() foreach { x => 203 | defaultParams.sfraction = x 204 | } text (s"sample percentage used for training for each tree") 205 | opt[Double]("secondaryMS") optional() foreach { x => 206 | defaultParams.secondaryMS = x 207 | } text (s"secondaryMetricShare") 208 | opt[Boolean]("secondaryLE") optional() foreach { x => 209 | defaultParams.secondaryLE = x 210 | } text (s"secondaryIsoLabelExclusive") 211 | opt[Double]("sigma") optional() foreach { x => 212 | defaultParams.sigma = x 213 | } text (s"parameter for init sigmoid table") 214 | opt[Boolean]("dw") optional() foreach { x => 215 | defaultParams.distanceWeight2 = x 216 | } text (s"Distance weight 2 adjustment to cost") 217 | opt[String]("bafn") optional() foreach { x => 218 | defaultParams.baselineAlphaFilename = x 219 | } text (s"Baseline alpha for tradeoffs of risk (0 is normal training)") 220 | opt[Double]("entropyCoefft") optional() foreach { x => 221 | defaultParams.entropyCoefft = x 222 | } text (s"The entropy (regularization) coefficient between 0 and 1") 223 | opt[Double]("ffup") optional() foreach { x => 224 | defaultParams.featureFirstUsePenalty = x 225 | } text (s"The feature first use penalty coefficient") 226 | opt[Double]("frup") optional() foreach { x => 227 | defaultParams.featureReusePenalty = x 228 | 
} text (s"The feature re-use penalty (regularization) coefficient") 229 | opt[String]("learningStrategy") optional() foreach { x => 230 | defaultParams.learningStrategy = x 231 | } text (s"learningStrategy for adaptive gradient descent") 232 | opt[String]("oNDCG") optional() foreach { x => 233 | defaultParams.outputNdcgFilename = x 234 | } text (s"save ndcg of training phase in this file") 235 | opt[Boolean]("allr") optional() foreach { x => 236 | defaultParams.active_lambda_learningStrategy = x 237 | } text (s"active lambda learning strategy or not") 238 | opt[Double]("rhol") optional() foreach { x => 239 | defaultParams.rho_lambda = x 240 | } text (s"rho lambda value") 241 | opt[Boolean]("alvst") optional() foreach { x => 242 | defaultParams.active_leaves_value_learningStrategy = x 243 | } text (s"active leave value learning strategy or not") 244 | opt[Double]("rholv") optional() foreach { x => 245 | defaultParams.rho_leave = x 246 | } text (s"rho parameter for leave value learning strategy") 247 | opt[Boolean]("gn") optional() foreach { x => 248 | defaultParams.GainNormalization = x 249 | } text (s"normalize the gian value in the comment or not") 250 | opt[String]("f2nf") optional() foreach { x => 251 | defaultParams.feature2NameFile = x 252 | } text (s"path to feature to name map file") 253 | opt[Int]("vs") optional() foreach { x => 254 | defaultParams.validationSpan = x 255 | } text (s"validation span") 256 | opt[Boolean]("es") optional() foreach { x => 257 | defaultParams.useEarlystop = x 258 | } text (s"apply early stop or not") 259 | opt[String]("sgfn") optional() foreach { x => 260 | defaultParams.secondGainsFileName = x 261 | } 262 | opt[String]("simdfn") optional() foreach { x => 263 | defaultParams.secondaryInverseMaxDcgFileName = x 264 | } 265 | opt[String]("swfn") optional() foreach { x => 266 | defaultParams.sampleWeightsFilename = x 267 | } 268 | opt[String]("bdfn") optional() foreach { x => 269 | defaultParams.baselineDcgsFilename = x 270 | } 271 | opt[String]("dfn") optional() foreach { x => 272 | defaultParams.discountsFilename = x 273 | } 274 | } 275 | parser.parse(args) 276 | run(defaultParams) 277 | } 278 | 279 | def run(params: Params) { 280 | require(params.numIterations.length == params.learningRate.length && 281 | params.numIterations.length == params.minInstancesPerNode.length && 282 | params.numIterations.length == params.numPruningLeaves.length, 283 | s"numiterations: ${params.numIterations}, learningRate: ${params.learningRate}, " + 284 | s"and minInstancesPerNode: ${params.minInstancesPerNode}, numPruningLeaves: ${params.numPruningLeaves}do not match") 285 | 286 | require(!params.maxDepth.exists(dep => dep > 30), s"value maxDepth:${params.maxDepth} incorrect; should be less than or equals to 30.") 287 | 288 | 289 | println(s"LambdaMARTRunner with parameters:\n${params.toString}") 290 | val conf = new SparkConf().setAppName(s"LambdaMARTRunner with $params") 291 | if (params.numPartitions != 0) 292 | conf.set("lambdaMart_numPartitions", s"${params.numPartitions}") 293 | val sc = new SparkContext(conf) 294 | try { 295 | //load training data 296 | var label = dataSetLoader.loadlabelScores(sc, params.label) 297 | val numSamples = label.length 298 | println(s"numSamples: $numSamples") 299 | 300 | var initScores = if (params.initScores == null) { 301 | new Array[Double](numSamples) 302 | } else { 303 | val loaded = dataSetLoader.loadInitScores(sc, params.initScores) 304 | require(loaded.length == numSamples, s"lengthOfInitScores: ${loaded.length} != numSamples: 
$numSamples") 305 | loaded 306 | } 307 | var queryBoundy = if (params.queryBoundy != null) dataSetLoader.loadQueryBoundy(sc, params.queryBoundy) else null 308 | require(queryBoundy == null || queryBoundy.last == numSamples, s"QueryBoundy ${queryBoundy.last} does not match with data $numSamples !") 309 | val numQuery = if (queryBoundy != null) queryBoundy.length - 1 else 0 310 | println(s"num of data query: $numQuery") 311 | 312 | val numSampleQuery = if (params.sampleQueryPercent < 1) (numQuery * params.sampleQueryPercent).toInt else numQuery 313 | println(s"num of sampling query: $numSampleQuery") 314 | val sampleQueryId = if (params.sampleQueryPercent < 1) { 315 | (new Random(Random.nextInt)).shuffle((0 until queryBoundy.length - 1).toList).take(numSampleQuery).toArray 316 | } else null //query index for training 317 | 318 | if (params.algo=="LambdaMart"&¶ms.sampleQueryPercent < 1) { 319 | // sampling 320 | label = dataSetLoader.getSampleLabels(sampleQueryId, queryBoundy, label) 321 | println(s"num of sampling labels: ${label.length}") 322 | initScores = dataSetLoader.getSampleInitScores(sampleQueryId, queryBoundy, initScores, label.length) 323 | require(label.length == initScores.length, s"num of labels ${label.length} does not match with initScores ${initScores.length}!") 324 | } 325 | else if(params.algo=="Classification" && params.sampleDocPercent!=1){ 326 | val numDocSampling = (params.sampleDocPercent * numSamples).toInt 327 | if (params.sampleDocPercent<1) { 328 | label = label.take(numDocSampling) 329 | initScores = initScores.take(numDocSampling) 330 | } 331 | else{ 332 | val newLabel = new Array[Short](numDocSampling) 333 | val newScores = new Array[Double](numDocSampling) 334 | var is = 0 335 | while(is x._2.default }.filter(_ != 0).count() 360 | // println(s"numFeats sparse on nonZero: $numNonZeros") 361 | 362 | val trainingData_T = genTransposedData(trainingData, numFeats, label.length) 363 | 364 | trainingData = dataSetLoader.getSampleFeatureData(sc, trainingData, params.sampleFeaturePercent) 365 | 366 | if (params.algo == "Classification") { 367 | label = label.map(x => (x * 2 - 1).toShort) 368 | } 369 | 370 | val trainingDataSet = new dataSet(label, initScores, queryBoundy, trainingData) 371 | 372 | var validtionDataSet: dataSet = null 373 | if (params.validationData != null) { 374 | val validationData = dataSetLoader.loadDataTransposed(sc, params.validationData) 375 | val labelV = dataSetLoader.loadlabelScores(sc, params.labelValidate) 376 | val initScoreV = if (params.initScoreValidate == null) { 377 | new Array[Double](labelV.length) 378 | } 379 | else { 380 | dataSetLoader.loadInitScores(sc, params.initScoreValidate) 381 | } 382 | 383 | val queryBoundyV = if (params.queryBoundy != null) dataSetLoader.loadQueryBoundy(sc, params.queryBoundyValidate) else null 384 | 385 | validtionDataSet = new dataSet(labelV, initScoreV, queryBoundyV, dataTransposed = validationData) 386 | } 387 | println(s"validationDataSet: $validtionDataSet") 388 | 389 | 390 | val gainTable = params.gainTableStr.split(':').map(_.toDouble) 391 | 392 | val boostingStrategy = config.BoostingStrategy.defaultParams(params.algo) 393 | boostingStrategy.treeStrategy.maxDepth = params.maxDepth(0) 394 | 395 | 396 | //extract secondGain and secondInverseMaxDcg 397 | if (params.secondGainsFileName != null && params.secondaryInverseMaxDcgFileName != null) { 398 | val spTf_1 = sc.textFile(params.secondGainsFileName) 399 | if (spTf_1.count() > 0) 400 | params.secondGains = spTf_1.first().split(",").map(_.toDouble) 
401 | val spTf_2 = sc.textFile(params.secondaryInverseMaxDcgFileName) 402 | if (spTf_2.count() > 0) 403 | params.secondaryInverseMaxDcg = spTf_2.first().split(",").map(_.toDouble) 404 | } 405 | 406 | if (params.discountsFilename != null) { 407 | val spTf = sc.textFile(params.discountsFilename) 408 | if (spTf.count() > 0) 409 | params.discounts = spTf.first().split(",").map(_.toDouble) 410 | } 411 | 412 | if (params.sampleWeightsFilename != null) { 413 | val spTf = sc.textFile(params.sampleWeightsFilename) 414 | if (spTf.count() > 0) 415 | params.sampleWeights = spTf.first().split(",").map(_.toDouble) 416 | } 417 | 418 | if (params.baselineDcgsFilename != null) { 419 | val spTf = sc.textFile(params.baselineDcgsFilename) 420 | if (spTf.count() > 0) 421 | params.baselineDcgs = spTf.first().split(",").map(_.toDouble) 422 | } 423 | 424 | if (params.baselineAlphaFilename != null) { 425 | val spTf = sc.textFile(params.baselineAlphaFilename) 426 | if (spTf.count() > 0) 427 | params.baselineAlpha = spTf.first().split(",").map(_.toDouble) 428 | } 429 | 430 | val feature2Gain = new Array[Double](numFeats) 431 | if (params.algo == "LambdaMart" || params.algo == "Classification") { 432 | val startTime = System.nanoTime() 433 | val model = LambdaMART.train(trainingDataSet, validtionDataSet, trainingData_T, gainTable, 434 | boostingStrategy, params, feature2Gain) 435 | val elapsedTime = (System.nanoTime() - startTime) / 1e9 436 | println(s"Training time: $elapsedTime seconds") 437 | 438 | // test 439 | if (params.algo == "LambdaMart" && params.testSpan != 0) { 440 | val testNDCG = testModel(sc, model, params, gainTable) 441 | println(s"testNDCG error 0 = " + testNDCG(0)) 442 | for (i <- 1 until testNDCG.length) { 443 | val it = i * params.testSpan 444 | println(s"testNDCG error $it = " + testNDCG(i)) 445 | } 446 | } 447 | else if (params.algo == "Classification" && params.testData != null) { 448 | val testData = MLUtils.loadLibSVMFile(sc, params.testData) 449 | println(s"numSamples: ${testData.count()}") 450 | val scoreAndLabels = testData.map { point => 451 | val predictions = model.trees.map(_.predict(point.features)) 452 | val claPred = new Array[Double](predictions.length) 453 | Range(0, predictions.length).foreach { it => 454 | if (it > 0) predictions(it) += predictions(it - 1) 455 | if (predictions(it) >= 0) claPred(it) = 1.0 456 | else claPred(it) = 0.0 457 | } 458 | 459 | (claPred, point.label) 460 | } 461 | 462 | Range(0, model.trees.length, params.testSpan).foreach { it => 463 | val scoreLabel = scoreAndLabels.map { case (claPred, lb) => (claPred(it), lb) } 464 | val metrics = new BinaryClassificationMetrics(scoreLabel) 465 | val auc = metrics.areaUnderROC() 466 | println(s"AUC $it = $auc") 467 | } 468 | } 469 | 470 | // if (params.featureIniFile != null) { 471 | // val featureIniPath = new Path(params.featureIniFile) 472 | // val featurefs = TreeUtils.getFileSystem(trainingData.context.getConf, featureIniPath) 473 | // featurefs.copyToLocalFile(false, featureIniPath, new Path("treeEnsemble.ini")) 474 | // } 475 | 476 | for (i <- 0 until model.trees.length) { 477 | val evaluator = model.trees(i) 478 | evaluator.sequence("treeEnsemble.ini", evaluator, i + 1) 479 | } 480 | println("save succeeded") 481 | val totalEvaluators = model.trees.length 482 | val evalNodes = Array.tabulate[Int](totalEvaluators)(_ + 1) 483 | treeAggregatorFormat.appendTreeAggregator(params.expandTreeEnsemble, "treeEnsemble.ini", totalEvaluators + 1, evalNodes) 484 | 485 | if (params.feature2NameFile != null) { 486 | val feature2Name = 
dataSetLoader.loadFeature2NameMap(sc, params.feature2NameFile) 487 | treeAggregatorFormat.toCommentFormat("treeEnsemble.ini", params, feature2Name, feature2Gain) 488 | } 489 | val outPath = new Path(params.outputTreeEnsemble) 490 | val fs = TreeUtils.getFileSystem(trainingData.context.getConf, outPath) 491 | fs.copyFromLocalFile(false, true, new Path("treeEnsemble.ini"), outPath) 492 | 493 | if (model.totalNumNodes < 30) { 494 | println(model.toDebugString) // Print full model. 495 | } else { 496 | println(model) // Print model summary. 497 | } 498 | // val testMSE = meanSquaredError(model, testData) 499 | // println(s"Test mean squared error = $testMSE") 500 | } 501 | } finally { 502 | sc.stop() 503 | } 504 | } 505 | 506 | 507 | 508 | def loadTestData(sc: SparkContext, path: String): RDD[Vector] = { 509 | sc.textFile(path).map { line => 510 | if (line.contains("#")) 511 | Vectors.dense(line.split("#")(1).split(",").map(_.toDouble)) 512 | else 513 | Vectors.dense(line.split(",").map(_.toDouble)) 514 | } 515 | 516 | 517 | // val testData = sc.textFile(path).map {line => line.split("#")(1).split(",").map(_.toDouble)} 518 | // val testTrans = testData.zipWithIndex.flatMap { 519 | // case (row, rowIndex) => row.zipWithIndex.map { 520 | // case (number, columnIndex) => columnIndex -> (rowIndex, number) 521 | // } 522 | // } 523 | // val testT = testTrans.groupByKey.sortByKey().values 524 | // .map { 525 | // indexedRow => indexedRow.toSeq.sortBy(_._1).map(_._2).toArray 526 | // } 527 | // testT.map(line => Vectors.dense(line)) 528 | } 529 | 530 | def genTransposedData(trainingData: RDD[(Int, SparseArray[Short], Array[SplitInfo])], 531 | numFeats: Int, 532 | numSamples: Int): RDD[(Int, Array[Array[Short]])] = { 533 | println("generating transposed data...") 534 | // validate that the original data is ordered 535 | val denseAsc = trainingData.mapPartitions { iter => 536 | var res = Iterator.single(true) 537 | if (iter.hasNext) { 538 | var prev = iter.next()._1 539 | val remaining = iter.dropWhile { case (fi, _, _) => 540 | val goodNext = fi - prev == 1 541 | prev = fi 542 | goodNext 543 | } 544 | res = Iterator.single(!remaining.hasNext) 545 | } 546 | res 547 | }.reduce(_ && _) 548 | assert(denseAsc, "the original data must be ordered.") 549 | println("pass data check in transposing") 550 | 551 | val numPartitions = trainingData.partitions.length 552 | val (siMinPP, lcNumSamplesPP) = TreeUtils.getPartitionOffsets(numSamples, numPartitions) 553 | val trainingData_T = trainingData.mapPartitions { iter => 554 | val (metaIter, dataIter) = iter.duplicate 555 | val fiMin = metaIter.next()._1 556 | val lcNumFeats = metaIter.length + 1 557 | val blocksPP = Array.tabulate(numPartitions)(pi => Array.ofDim[Short](lcNumFeats, lcNumSamplesPP(pi))) 558 | dataIter.foreach { case (fi, sparseSamples, _) => 559 | val samples = sparseSamples.toArray 560 | val lfi = fi - fiMin 561 | var pi = 0 562 | while (pi < numPartitions) { 563 | Array.copy(samples, siMinPP(pi), blocksPP(pi)(lfi), 0, lcNumSamplesPP(pi)) 564 | pi += 1 565 | } 566 | } 567 | Range(0, numPartitions).iterator.map(pi => (pi, (fiMin, blocksPP(pi)))) 568 | }.partitionBy(new HashPartitioner(numPartitions)).mapPartitionsWithIndex((pid, iter) => { 569 | val siMin = siMinPP(pid) 570 | val sampleSlice = new Array[Array[Short]](numFeats) 571 | iter.foreach { case (_, (fiMin, blocks)) => 572 | var lfi = 0 573 | while (lfi < blocks.length) { 574 | sampleSlice(lfi + fiMin) = blocks(lfi) 575 | lfi += 1 576 | } 577 | } 578 | Iterator.single((siMin, sampleSlice)) 
579 | }, preservesPartitioning = true) 580 | trainingData_T.persist(StorageLevel.MEMORY_AND_DISK).setName("trainingData_T").count() 581 | trainingData_T 582 | } 583 | 584 | def testModel(sc: SparkContext, model: GradientBoostedDecisionTreesModel, params: Params, gainTable: Array[Double]): Array[Double] = { 585 | val testData = loadTestData(sc, params.testData).cache().setName("TestData") 586 | println(s"numTestFeature: ${testData.first().toArray.length}") 587 | val numTest = testData.count() 588 | println(s"numTest: $numTest") 589 | val testLabels = dataSetLoader.loadlabelScores(sc, params.testLabel) 590 | 591 | println(s"numTestLabels: ${testLabels.length}") 592 | require(testLabels.length == numTest, s"lengthOfLabels: ${testLabels.length} != numTestSamples: $numTest") 593 | val testQueryBound = dataSetLoader.loadQueryBoundy(sc, params.testQueryBound) 594 | require(testQueryBound.last == numTest, s"TestQueryBoundy ${testQueryBound.last} does not match with test data $numTest!") 595 | 596 | val rate = params.testSpan 597 | val predictions = testData.map { features => 598 | val scores = model.trees.map(_.predict(features)) 599 | for (it <- 1 until model.trees.length) { 600 | scores(it) += scores(it - 1) 601 | } 602 | 603 | scores.zipWithIndex.collect { 604 | case (score, it) if it == 0 || (it + 1) % rate == 0 => score 605 | } 606 | } 607 | val predictionsByIter = predictions.zipWithIndex.flatMap { 608 | case (row, rowIndex) => row.zipWithIndex.map { 609 | case (number, columnIndex) => columnIndex ->(rowIndex, number) 610 | } 611 | }.groupByKey.sortByKey().values 612 | .map { 613 | indexedRow => indexedRow.toArray.sortBy(_._1).map(_._2) 614 | } 615 | 616 | val learningRates = params.learningRate 617 | val distanceWeight2 = params.distanceWeight2 618 | val baselineAlpha = params.baselineAlpha 619 | val secondMs = params.secondaryMS 620 | val secondLe = params.secondaryLE 621 | val secondGains = params.secondGains 622 | val secondaryInverseMacDcg = params.secondaryInverseMaxDcg 623 | val discounts = params.discounts 624 | val baselineDcgs = params.baselineDcgs 625 | val dc = new DerivativeCalculator 626 | dc.init(testLabels, gainTable, testQueryBound, 627 | learningRates(0), distanceWeight2, baselineAlpha, 628 | secondMs, secondLe, secondGains, secondaryInverseMacDcg, discounts, baselineDcgs) 629 | val numQueries = testQueryBound.length - 1 630 | val dcBc = sc.broadcast(dc) 631 | 632 | predictionsByIter.map { scores => 633 | val dc = dcBc.value 634 | dc.getPartErrors(scores, 0, numQueries) / numQueries 635 | }.collect() 636 | 637 | } 638 | 639 | /** * 640 | * def meanSquaredError( 641 | * model: { def predict(features: Vector): Double }, 642 | * data: RDD[LabeledPoint]): Double = { 643 | * data.map { y => 644 | * val err = model.predict(y.features) - y.label 645 | * err * err 646 | * }.mean() 647 | * } ***/ 648 | } 649 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/dataSet/dataSet.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.dataSet 2 | 3 | import breeze.collection.mutable.SparseArray 4 | import breeze.linalg.SparseVector 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.mllib.linalg.Vectors 7 | import org.apache.spark.mllib.tree.model.SplitInfo 8 | import org.apache.spark.rdd.RDD 9 | import org.apache.spark.storage.StorageLevel 10 | 11 | import scala.collection.mutable 12 | 13 | /** 14 | * Created by jlinleung on 2016/5/1. 
15 | */ 16 | class dataSet(label: Array[Short] = null, 17 | score: Array[Double] = null, 18 | queryBoundy: Array[Int] = null, 19 | data: RDD[(Int, SparseArray[Short], Array[SplitInfo])] = null, 20 | dataTransposed: RDD[(Int, org.apache.spark.mllib.linalg.Vector)] = null) { 21 | def getData(): RDD[(Int, SparseArray[Short], Array[SplitInfo])] = { 22 | data 23 | } 24 | 25 | def getDataTransposed(): RDD[(Int, org.apache.spark.mllib.linalg.Vector)] = { 26 | dataTransposed 27 | } 28 | 29 | def getLabel(): Array[Short] = { 30 | label 31 | } 32 | 33 | def getScore(): Array[Double] = { 34 | score 35 | } 36 | 37 | def getQueryBoundy(): Array[Int] = { 38 | queryBoundy 39 | } 40 | } 41 | 42 | 43 | object dataSetLoader { 44 | 45 | def loadData(sc: SparkContext, path: String, minPartitions: Int) 46 | : RDD[(Int, SparseVector[Short], Array[SplitInfo])] = { 47 | sc.textFile(path, minPartitions).map { line => 48 | val parts = line.split("#") 49 | val feat = parts(0).toInt 50 | val samples = parts(1).split(',').map(_.toShort) 51 | 52 | var is = 0 53 | var nnz = 0 54 | while (is < samples.length) { 55 | if (samples(is) != 0) { 56 | nnz += 1 57 | } 58 | is += 1 59 | } 60 | val idx = new Array[Int](nnz) 61 | val vas = new Array[Short](nnz) 62 | is = 0 63 | nnz = 0 64 | while (is < samples.length) { 65 | if (samples(is) != 0) { 66 | idx(nnz) = is 67 | vas(nnz) = samples(is) 68 | nnz += 1 69 | } 70 | is += 1 71 | } 72 | val sparseSamples = new SparseVector[Short](idx, vas, nnz, is) 73 | 74 | val splits = if (parts.length > 2) { 75 | parts(2).split(',').map(threshold => new SplitInfo(feat, threshold.toDouble)) 76 | } else { 77 | val maxFeat = sparseSamples.valuesIterator.max + 1 78 | Array.tabulate(maxFeat)(threshold => new SplitInfo(feat, threshold)) 79 | } 80 | (feat, sparseSamples, splits) 81 | }.persist(StorageLevel.MEMORY_AND_DISK).setName("trainingData") 82 | } 83 | 84 | def loadTrainingDataForLambdamart(sc: SparkContext, path: String, minPartitions: Int, 85 | sampleQueryId: Array[Int], QueryBound: Array[Int], numSampling: Int) 86 | : RDD[(Int, SparseArray[Short], Array[SplitInfo])] = { 87 | var rdd = sc.textFile(path, minPartitions).map { line => 88 | val parts = line.split("#") 89 | val feat = parts(0).toInt 90 | val samples = parts(1).split(',').map(_.toShort) 91 | var is = 0 92 | // sampling data 93 | val samplingData = if (sampleQueryId == null) samples 94 | else { 95 | val sd = new Array[Short](numSampling) 96 | var it = 0 97 | var icur = 0 98 | while (it < sampleQueryId.length) { 99 | val query = sampleQueryId(it) 100 | for (is <- QueryBound(query) until QueryBound(query + 1)) { 101 | sd(icur) = samples(is) 102 | icur += 1 103 | } 104 | it += 1 105 | } 106 | sd 107 | } 108 | 109 | // // Sparse data 110 | // is = 0 111 | // var nnz = 0 112 | // while (is < samplingData.length) { 113 | // if (samplingData(is) != 0) { 114 | // nnz += 1 115 | // } 116 | // is += 1 117 | // } 118 | // val idx = new Array[Int](nnz) 119 | // val vas = new Array[Short](nnz) 120 | // is = 0 121 | // nnz = 0 122 | // while (is < samplingData.length) { 123 | // if (samplingData(is) != 0) { 124 | // idx(nnz) = is 125 | // vas(nnz) = samplingData(is) 126 | // nnz += 1 127 | // } 128 | // is += 1 129 | // } 130 | // val sparseSamples = new SparseVector[Short](idx, vas, nnz, is) 131 | 132 | val v2no = new mutable.HashMap[Short, Int]().withDefaultValue(0) 133 | is = 0 134 | while (is < samplingData.length) { 135 | v2no(samplingData(is)) += 1 136 | is += 1 137 | } 138 | val (default, numDefault) = v2no.maxBy(x => x._2) 139 | val 
numAct = samplingData.length - numDefault 140 | val idx = new Array[Int](numAct) 141 | val vas = new Array[Short](numAct) 142 | is = 0 143 | var nnz = 0 144 | while (is < samplingData.length) { 145 | if (samplingData(is) != default) { 146 | idx(nnz) = is 147 | vas(nnz) = samplingData(is) 148 | nnz += 1 149 | } 150 | is += 1 151 | } 152 | val sparseSamples = new SparseArray[Short](idx, vas, nnz, is, default) 153 | 154 | // val sparseSamples = new SparseVector[Short](sparseArr) 155 | 156 | 157 | val splits = if (parts.length > 2) { 158 | parts(2).split(',').map(threshold => new SplitInfo(feat, threshold.toDouble)) 159 | } else { 160 | val maxFeat = samples.max + 1 161 | Array.tabulate(maxFeat)(threshold => new SplitInfo(feat, threshold)) 162 | } 163 | (feat, sparseSamples, splits) 164 | } 165 | rdd = sc.getConf.getOption("lambdaMart_numPartitions").map(_.toInt) match { 166 | case Some(np) => 167 | println("repartitioning") 168 | rdd.sortBy(x => x._1, numPartitions = np) 169 | 170 | case None => rdd 171 | } 172 | 173 | rdd.persist(StorageLevel.MEMORY_AND_DISK).setName("trainingData") 174 | } 175 | 176 | def loadTrainingDataForClassification(sc: SparkContext, path: String, minPartitions: Int, numDoc: Int) 177 | : RDD[(Int, SparseArray[Short], Array[SplitInfo])] = { 178 | var rdd = sc.textFile(path, minPartitions).map { line => 179 | val parts = line.split("#") 180 | val feat = parts(0).toInt 181 | val samples = parts(1).split(',').map(_.toShort) 182 | var is = 0 183 | // sampling data 184 | val samplingData = if (numDoc == samples.length) { 185 | samples 186 | } 187 | else if(numDoc x._2) 206 | val numAct = samplingData.length - numDefault 207 | val idx = new Array[Int](numAct) 208 | val vas = new Array[Short](numAct) 209 | is = 0 210 | var nnz = 0 211 | while (is < samplingData.length) { 212 | if (samplingData(is) != default) { 213 | idx(nnz) = is 214 | vas(nnz) = samplingData(is) 215 | nnz += 1 216 | } 217 | is += 1 218 | } 219 | val sparseSamples = new SparseArray[Short](idx, vas, nnz, numDoc, default) 220 | 221 | val splits = if (parts.length > 2) { 222 | parts(2).split(',').map(threshold => new SplitInfo(feat, threshold.toDouble)) 223 | } else { 224 | val maxFeat = samples.max + 1 225 | Array.tabulate(maxFeat)(threshold => new SplitInfo(feat, threshold)) 226 | } 227 | (feat, sparseSamples, splits) 228 | } 229 | 230 | rdd = sc.getConf.getOption("lambdaMart_numPartitions").map(_.toInt) match { 231 | case Some(np) => 232 | println("repartitioning") 233 | rdd.sortBy(x => x._1, numPartitions = np) 234 | 235 | case None => rdd 236 | } 237 | 238 | rdd.persist(StorageLevel.MEMORY_AND_DISK).setName("trainingData") 239 | } 240 | 241 | def loadDataTransposed(sc: SparkContext, path: String): RDD[(Int, org.apache.spark.mllib.linalg.Vector)] = { 242 | sc.textFile(path).map { line => 243 | val parts = line.split("#") 244 | val sId = parts(0).toInt 245 | val features = parts(1).split(",").map(_.toDouble) 246 | (sId, Vectors.dense(features)) 247 | }.persist(StorageLevel.MEMORY_AND_DISK).setName("validationData") 248 | } 249 | 250 | def loadlabelScores(sc: SparkContext, path: String): Array[Short] = { 251 | sc.textFile(path).first().split(',').map(_.toShort) 252 | } 253 | 254 | def loadInitScores(sc: SparkContext, path: String): Array[Double] = { 255 | sc.textFile(path).first().split(',').map(_.toDouble) 256 | } 257 | 258 | def loadQueryBoundy(sc: SparkContext, path: String): Array[Int] = { 259 | 260 | sc.textFile(path).first().split(',').map(_.toInt) 261 | } 262 | 263 | def loadThresholdMap(sc: 
SparkContext, path: String, numFeats: Int): Array[Array[Double]] = { 264 | val thresholdMapTuples = sc.textFile(path).map { line => 265 | val fields = line.split("#", 2) 266 | (fields(0).toInt, fields(1).split(',').map(_.toDouble)) 267 | }.collect() 268 | val numFeatsTM = thresholdMapTuples.length 269 | assert(numFeats == numFeatsTM, s"ThresholdMap file contains $numFeatsTM features that != $numFeats") 270 | val thresholdMap = new Array[Array[Double]](numFeats) 271 | thresholdMapTuples.foreach { case (fi, thresholds) => 272 | thresholdMap(fi) = thresholds 273 | } 274 | thresholdMap 275 | } 276 | 277 | def loadFeature2NameMap(sc: SparkContext, path: String): Array[String] = { 278 | sc.textFile(path).map(line => line.split("#")(1)).collect() 279 | } 280 | 281 | def getSampleLabels(testQueryId: Array[Int], QueryBound: Array[Int], labels: Array[Short]): Array[Short] = { 282 | println("parse test labels") 283 | val testLabels = new Array[Short](labels.length) 284 | var it = 0 285 | var icur = 0 286 | while (it < testQueryId.length) { 287 | val query = testQueryId(it) 288 | for (is <- QueryBound(query) until QueryBound(query + 1)) { 289 | testLabels(icur) = labels(is) 290 | icur += 1 291 | } 292 | it += 1 293 | } 294 | testLabels.dropRight(labels.length - icur) 295 | } 296 | 297 | def getSampleInitScores(trainingQueryId: Array[Int], QueryBound: Array[Int], scores: Array[Double], len: Int): Array[Double] = { 298 | println("parse init scores") 299 | val trainingScores = new Array[Double](len) 300 | var it = 0 301 | var icur = 0 302 | while (it < trainingQueryId.length) { 303 | val query = trainingQueryId(it) 304 | for (is <- QueryBound(query) until QueryBound(query + 1)) { 305 | trainingScores(icur) = scores(is) 306 | icur += 1 307 | } 308 | it += 1 309 | } 310 | trainingScores 311 | } 312 | 313 | def getSampleQueryBound(QueryId: Array[Int], queryBoundy: Array[Int]): Array[Int] = { 314 | println("get query bound") 315 | val res = new Array[Int](QueryId.length + 1) 316 | res(0) = 0 317 | var qid = 0 318 | while (qid < QueryId.length) { 319 | res(qid + 1) = res(qid) + queryBoundy(QueryId(qid) + 1) - queryBoundy(QueryId(qid)) 320 | qid += 1 321 | } 322 | res 323 | } 324 | 325 | def getSampleFeatureData(sc: SparkContext, trainingData: RDD[(Int, SparseArray[Short], Array[SplitInfo])], sampleFeatPct: Double) = { 326 | // def IsSeleted(ffraction: Double): Boolean = { 327 | // val randomNum = scala.util.Random.nextDouble() 328 | // var active = false 329 | // if(randomNum < ffraction) 330 | // { 331 | // active = true 332 | // } 333 | // active 334 | // } 335 | val rdd = if (sampleFeatPct < 1.0) { 336 | var sampleData = trainingData.sample(false, sampleFeatPct) 337 | // sampleData = sc.getConf.getOption("lambdaMart_numPartitions").map(_.toInt) match { 338 | // case Some(np) => sampleData.sortBy(x => x._1, numPartitions = np) 339 | // case None => sampleData 340 | // } 341 | // val sampleData = trainingData.filter(item =>IsSeleted(sampleFeatPct)) 342 | val numFeats_S = sampleData.count() 343 | println(s"numFeats_sampling: $numFeats_S") 344 | println(s"numPartitions_sampling: ${sampleData.partitions.length}") 345 | trainingData.unpersist(blocking = false) 346 | sampleData 347 | } else trainingData 348 | rdd.persist(StorageLevel.MEMORY_AND_DISK).setName("sampleTrainingData") 349 | } 350 | } 351 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/DerivativeCalculator.scala: 
-------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.tree 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | 5 | class DerivativeCalculator extends Serializable { 6 | val expAsymptote: Double = -50 7 | val sigmoidBins: Int = 1000000 8 | 9 | var sigmoidTable: Array[Double] = null 10 | var minScore: Double = _ 11 | var maxScore: Double = _ 12 | var scoreToSigmoidTableFactor: Double =_ 13 | var _distanceWeight2 = false //TODO CMD PARAMETER 14 | 15 | var discounts: Array[Double] = null 16 | var ratings: Array[Short] = null 17 | var gainTable: Array[Double] = null 18 | 19 | val _normalizeQueryLambdas = true 20 | 21 | val maxNumPositions = 5000 22 | // var labels: Array[Byte] 23 | var secondaryGains: Array[Double] = null 24 | var _secondaryMetricShare: Double =_ 25 | var _secondaryIsolabelExclusive: Boolean=_ 26 | var _secondaryInverseMaxDCGT: Array[Double] = null 27 | 28 | var queryBoundy: Array[Int] = null 29 | var inverseMaxDCGs: Array[Double] = null 30 | var _baselineAlpha: Array[Double] = null 31 | var _baselineDcg: Array[Double] = null 32 | 33 | def init(ratings: Array[Short], gainTable: Array[Double], queryBoundy: Array[Int], sigma: Double, 34 | distanceWeight2: Boolean, baselineAlpha: Array[Double], 35 | secondaryMetricShare: Double, 36 | secondaryIsolabelExclusive: Boolean, 37 | secondaryGain: Array[Double], 38 | secondaryInverseMaxDcg: Array[Double], 39 | discount: Array[Double], 40 | baselineDcgs: Array[Double]): Unit = { 41 | initSigmoidTable(sigma) 42 | _distanceWeight2 = distanceWeight2 43 | _baselineAlpha = baselineAlpha 44 | 45 | 46 | _baselineDcg = baselineDcgs 47 | discounts = if(discount == null) { 48 | Array.tabulate(maxNumPositions)(i => 1.0 / math.log(i + 2.0)) 49 | } 50 | else{ 51 | discount 52 | } 53 | 54 | this.ratings = ratings 55 | this.gainTable = gainTable 56 | 57 | setupSecondaryGains(secondaryMetricShare, 58 | secondaryIsolabelExclusive, 59 | secondaryGain, 60 | secondaryInverseMaxDcg) 61 | 62 | calcInverseMaxDCGs(queryBoundy) 63 | } 64 | 65 | private def initSigmoidTable(sigma: Double): Unit = { 66 | // minScore is such that 2*sigma*score is < expAsymptote if score < minScore 67 | minScore = expAsymptote / sigma / 2 68 | maxScore = -minScore 69 | scoreToSigmoidTableFactor = sigmoidBins / (maxScore - minScore) 70 | 71 | sigmoidTable = new Array[Double](sigmoidBins) 72 | var i = 0 73 | while (i < sigmoidBins) { 74 | val score = (maxScore - minScore) / sigmoidBins * i + minScore 75 | sigmoidTable(i) = if (score > 0.0) { 76 | 2.0 - 2.0 / (1.0 + math.exp(-2.0 * sigma * score)) 77 | } else { 78 | 2.0 / (1.0 + math.exp(2.0 * sigma * score)) 79 | } 80 | i += 1 81 | } 82 | 83 | } 84 | 85 | def setupSecondaryGains(secondaryMetricShare: Double, 86 | secondaryIsolabelExclusive: Boolean, 87 | secondaryGain: Array[Double], 88 | secondaryInverseMaxDcg: Array[Double]): Unit ={ 89 | _secondaryMetricShare = secondaryMetricShare 90 | _secondaryIsolabelExclusive = secondaryIsolabelExclusive 91 | secondaryGains = secondaryGain 92 | 93 | if(secondaryMetricShare != 0.0 && secondaryGains != null){ 94 | _secondaryInverseMaxDCGT = secondaryInverseMaxDcg 95 | } 96 | } 97 | 98 | private def calcInverseMaxDCGs(queryBoundy: Array[Int]): Unit = { 99 | this.queryBoundy = queryBoundy 100 | val numQueries = queryBoundy.length - 1 101 | inverseMaxDCGs = new Array[Double](numQueries) 102 | var qi = 0 103 | while (qi < numQueries) { 104 | val siMin = queryBoundy(qi) 105 | val siEnd = queryBoundy(qi + 1) 106 | val ratings_sorted = 
ratings.view(siMin, siEnd).toSeq.sorted.reverse.toArray 107 | 108 | var MaxDCGQ = 0.0 109 | val numDocs = siEnd - siMin 110 | var odi = 0 111 | while (odi < numDocs) { 112 | MaxDCGQ += gainTable(ratings_sorted(odi)) * discounts(odi) 113 | odi += 1 114 | } 115 | val inverseMaxDCGQ = if (MaxDCGQ == 0.0) 0.0 else 1.0 / MaxDCGQ 116 | inverseMaxDCGs(qi) = inverseMaxDCGQ 117 | //println(">>>>>>>>>>>>>>") 118 | //println(s"query: $qi, numdocs: $numDocs") 119 | //println(ratings.view(siMin, siEnd).mkString("\t")) 120 | //println(s"MaxDcg: $MaxDCGQ") 121 | qi += 1 122 | } 123 | } 124 | 125 | private def ScoreSort(scores: Array[Double], siMin: Int, siEnd: Int): Array[Short] = { 126 | scores.view(siMin, siEnd).map(_.toShort).toSeq.sorted.reverse.toArray 127 | } 128 | 129 | private def docIdxSort(scores: Array[Double], siMin: Int, siEnd: Int): Array[Int] = { 130 | Range(siMin, siEnd).sortBy(scores).reverse.map(_ - siMin).toArray 131 | } 132 | 133 | def getPartDerivatives(scores: Array[Double], qiMin: Int, qiEnd: Int, iteration: Int): (Int, Array[Double], Array[Double]) = { 134 | val siTotalMin = queryBoundy(qiMin) 135 | val numTotalDocs = queryBoundy(qiEnd) - siTotalMin 136 | val lcLambdas = new Array[Double](numTotalDocs) 137 | val lcWeights = new Array[Double](numTotalDocs) 138 | var qi = qiMin 139 | while (qi < qiEnd) { 140 | val lcMin = queryBoundy(qi) - siTotalMin 141 | 142 | val siMin = queryBoundy(qi) 143 | val siEnd = queryBoundy(qi + 1) 144 | val numDocsPerQuery = siEnd - siMin 145 | val permutation = docIdxSort(scores, siMin, siEnd) 146 | 147 | var baselineVersusCurrentDcg = 0.0 148 | if(_baselineDcg != null){ 149 | baselineVersusCurrentDcg = _baselineDcg(qi) 150 | for(i<- 0 until numDocsPerQuery){ 151 | baselineVersusCurrentDcg -= gainTable(ratings(permutation(i) + siMin)) * discounts(i) 152 | } 153 | } 154 | 155 | val baselineAlphaRisk = if(_baselineAlpha == null) 0.0 else _baselineAlpha(iteration) 156 | val secondaryIMDcg = if(_secondaryInverseMaxDCGT == null) 1.0 else _secondaryInverseMaxDCGT(qi) 157 | 158 | calcQueryDerivatives(qi, scores, lcLambdas, lcWeights, permutation, lcMin, 159 | _secondaryMetricShare, _secondaryIsolabelExclusive, _distanceWeight2, 160 | baselineAlphaRisk, secondaryIMDcg, baselineVersusCurrentDcg) 161 | qi += 1 162 | } 163 | (siTotalMin, lcLambdas, lcWeights) 164 | } 165 | 166 | def getPartErrors(scores: Array[Double], qiMin: Int, qiEnd: Int): Double = { 167 | var errors = 0.0 168 | var qi = qiMin 169 | while (qi < qiEnd) { 170 | val siMin = queryBoundy(qi) 171 | val siEnd = queryBoundy(qi + 1) 172 | val numDocs = siEnd - siMin 173 | val permutation = docIdxSort(scores, siMin, siEnd) 174 | var dcg = 0.0 175 | var odi = 0 176 | while (odi < numDocs) { 177 | dcg += gainTable(ratings(permutation(odi) + siMin)) * discounts(odi) 178 | odi += 1 179 | } 180 | errors += 1 - dcg * inverseMaxDCGs(qi) 181 | qi += 1 182 | } 183 | errors 184 | } 185 | 186 | private def calcQueryDerivatives(qi: Int, 187 | scores: Array[Double], 188 | lcLambdas: Array[Double], 189 | lcWeights: Array[Double], 190 | permutation: Array[Int], 191 | lcMin: Int, 192 | secondaryMetricShare: Double, 193 | secondaryExclusive: Boolean, 194 | distanceWeight2: Boolean, 195 | alphaRisk: Double, 196 | secondaryInverseMaxDCG: Double, 197 | baselineVersusCurrentDcg: Double, 198 | costFunctionParam: Char = 'w', 199 | minDoubleValue: Double = Double.MinValue): Unit = { 200 | val siMin = queryBoundy(qi) 201 | val siEnd = queryBoundy(qi + 1) 202 | val numDocs = siEnd - siMin 203 | val tmpArray = new 
ArrayBuffer[Double] 204 | for(i <- 0 until numDocs){ 205 | tmpArray += scores(i + siMin) 206 | } 207 | 208 | val inverseMaxDCG = inverseMaxDCGs(qi) 209 | /** 210 | * println(">>>>>>>>>>>>>>") 211 | * println(s"query: $qi, numdocs: $numDocs") 212 | * println(s"label: " + ratings.view(siMin, siEnd).mkString(",") + "\t" + s"permutation: " + permutation.mkString(",")) 213 | * println(s"scores: " + scores.view(siMin, siEnd).mkString(",") + "\t" + s"discount: " + discounts.view(0,20).mkString(",")) 214 | * println(s"inverseMaxDcg: $inverseMaxDCG") 215 | * println(s"mins: $minScore, maxs: $maxScore, factor: $scoreToSigmoidTableFactor") 216 | ** 217 | * 218 | *println("**************") 219 | *println(tmpArray.toString()) 220 | *println(permutation.mkString(",")) 221 | ** 222 | *println(s"inverseMaxDCG: $inverseMaxDCG") **/ 223 | 224 | val bestScore = scores(permutation.head + siMin) 225 | var worstIndexToConsider = numDocs - 1 226 | while (worstIndexToConsider > 0 && scores(permutation(worstIndexToConsider) + siMin) == minDoubleValue) { 227 | worstIndexToConsider -= 1 228 | } 229 | val worstScore = scores(permutation(worstIndexToConsider) + siMin) 230 | 231 | var lambdaSum = 0.0 232 | 233 | // Should we still run the calculation on those pairs which are ostensibly the same? 234 | val pairSame = secondaryMetricShare != 0.0 235 | 236 | 237 | // Did not help to use pointer match on pPermutation[i] 238 | for (odi <- 0 until numDocs) { 239 | val di = permutation(odi) 240 | val sHigh = di + siMin 241 | val labelHigh =ratings(sHigh) 242 | val scoreHigh = scores(sHigh) 243 | 244 | if (!((labelHigh == 0 && !pairSame) || scoreHigh == minDoubleValue)) { 245 | var deltaLambdasHigh: Double = 0.0 246 | var deltaWeightsHigh: Double = 0.0 247 | 248 | for (odj <- 0 until numDocs) { 249 | 250 | val dj = permutation(odj) 251 | val sLow = dj + siMin 252 | val labelLow = ratings(sLow) 253 | val scoreLow = scores(sLow) 254 | 255 | val flag = if (pairSame) labelHigh < labelLow else labelHigh <= labelLow 256 | if (!(flag || scores(sLow) == minDoubleValue)) { 257 | val scoreHighMinusLow = scoreHigh - scoreLow 258 | if (!(secondaryMetricShare == 0.0 && labelHigh == labelLow && scoreHighMinusLow <= 0)) { 259 | 260 | //println("labelHigh", labelHigh, "aLabels(siLow)", aLabels(siLow), "scoreHighMinusLow", scoreHighMinusLow) 261 | var dcgGap: Double = gainTable(ratings(sHigh)) - gainTable(ratings(sLow)) 262 | var currentInverseMaxDCG = inverseMaxDCG * (1.0 - secondaryMetricShare) 263 | 264 | val pairedDiscount = (discounts(odi) - discounts(odj)).abs 265 | if (alphaRisk > 0) { 266 | var risk = 0.0 267 | val baselineDenorm = baselineVersusCurrentDcg / pairedDiscount 268 | if (baselineVersusCurrentDcg > 0) { 269 | risk = if (scoreHighMinusLow <= 0 && dcgGap > baselineDenorm) baselineDenorm else dcgGap 270 | } else if (scoreHighMinusLow > 0) { 271 | // The baseline is currently lower, but this pair is ranked correctly. 
272 | risk = baselineDenorm + dcgGap 273 | } 274 | if (risk > 0) { 275 | dcgGap += alphaRisk * risk 276 | } 277 | } 278 | 279 | val lambdaP = if (scoreHighMinusLow <= minScore) { 280 | sigmoidTable.head 281 | } else if (scoreHighMinusLow >= maxScore) { 282 | sigmoidTable.last 283 | } else { 284 | sigmoidTable(((scoreHighMinusLow - minScore) * scoreToSigmoidTableFactor).toInt) 285 | } 286 | val weightP = lambdaP * (2.0 - lambdaP) 287 | 288 | var sameLabel = labelHigh == labelLow 289 | if (!(secondaryMetricShare != 0.0 && (sameLabel || currentInverseMaxDCG == 0.0) && secondaryGains != null &&secondaryGains(sHigh) <= secondaryGains(sLow))) { 290 | if (secondaryMetricShare != 0.0) { 291 | if (sameLabel || currentInverseMaxDCG == 0.0) { 292 | // We should use the secondary metric this time. 293 | dcgGap = secondaryGains(sHigh) - secondaryGains(sLow) 294 | currentInverseMaxDCG = secondaryInverseMaxDCG * secondaryMetricShare 295 | sameLabel = false 296 | } else if (!secondaryExclusive && secondaryGains(sHigh) > secondaryGains(sLow)) { 297 | var sIDCG = secondaryInverseMaxDCG * secondaryMetricShare 298 | dcgGap = dcgGap / sIDCG + (secondaryGains(sHigh) - secondaryGains(sLow)) / currentInverseMaxDCG 299 | currentInverseMaxDCG *= sIDCG 300 | } 301 | } 302 | // calculate the deltaNDCGP for this pair 303 | var deltaNDCGP = dcgGap * pairedDiscount * currentInverseMaxDCG 304 | 305 | // apply distanceWeight2 only to regular pairs 306 | if (!sameLabel && distanceWeight2 && bestScore != worstScore) { 307 | deltaNDCGP /= (.01 + (scoreHigh - scoreLow).abs) 308 | } 309 | //println("lambda", lambdaP * deltaNDCGP, "deltaNDCGP", deltaNDCGP, "dcgGap", dcgGap, "pairedDiscount", pairedDiscount, "currentInverseMaxDCG", currentInverseMaxDCG) 310 | // update lambdas and weights 311 | deltaLambdasHigh += lambdaP * deltaNDCGP 312 | // println("*****************") 313 | //println(s"lambdaP: $lambdaP, deltaNDCGP: $deltaNDCGP") 314 | deltaWeightsHigh += weightP * deltaNDCGP 315 | lcLambdas(dj + lcMin) -= lambdaP * deltaNDCGP 316 | lcWeights(dj + lcMin) += weightP * deltaNDCGP 317 | 318 | lambdaSum += 2 * lambdaP * deltaNDCGP 319 | } 320 | } 321 | } 322 | } 323 | //Finally, add the values for the siHigh part of the pair that we accumulated across all the low parts 324 | 325 | lcLambdas(di + lcMin) += deltaLambdasHigh 326 | lcWeights(di + lcMin) += deltaWeightsHigh 327 | } 328 | } 329 | 330 | /*** 331 | *if(qi < 15) 332 | *{ 333 | *println(s"lambdas in query $qi") 334 | *for(i <- 0 until numDocs) 335 | *{ 336 | *print(lcLambdas(lcMin + i) + ",") 337 | *} 338 | *println() 339 | *} ***/ 340 | 341 | if(_normalizeQueryLambdas) 342 | { 343 | if(lambdaSum > 0) 344 | { 345 | val normFactor = (10 * math.log(1 + lambdaSum))/lambdaSum 346 | for(i <- 0 until numDocs) 347 | { 348 | lcLambdas(lcMin + i) = lcLambdas(lcMin + i) * normFactor 349 | lcWeights(lcMin + i) = lcWeights(lcMin + i) * normFactor 350 | } 351 | } 352 | } 353 | 354 | } 355 | } 356 | 357 | /***** 358 | *object Derivate { 359 | *def main(args: Array[String]){ 360 | *val numDocuments = 5; val begin = 0 361 | *val aPermutation = Array(1, 4, 3, 4, 2); val aLabels: Array[Short] = Array(1, 2, 3, 4, 5) 362 | *val aScores = Array(1.0, 3.0, 8.0, 15.0, 31.0) 363 | *val aDiscount = Array(0.2, 0.5, 0.7, 0.8, 0.9) 364 | *val inverseMaxDCG = 0.01 365 | *val aGainLabels = Array(0.3, 0.4, 0.5, 0.8, 0.3) 366 | *val aSecondaryGains = Array(0.3, 0.4, 0.5, 0.8, 0.3); val asigmoidTable =GetDerivatives.FillSigmoidTable() 367 | *val minScore = 0.08; val maxScore = 0.2 368 | *val 
scoreToSigmoidTableFactor = 4 369 | ** 370 | *GetDerivatives.GetDerivatives_lambda_weight( 371 | *numDocuments, begin, 372 | *aPermutation, aLabels, 373 | *aScores, 374 | *aDiscount, aGainLabels, inverseMaxDCG, 375 | *asigmoidTable, minScore, maxScore, scoreToSigmoidTableFactor, aSecondaryGains 376 | *) 377 | * } 378 | *} 379 | *****/ 380 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/LambdaMART.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.tree 2 | 3 | //import akka.io.Udp.SO.Broadcast 4 | import java.io.{File, FileOutputStream, PrintWriter} 5 | 6 | import breeze.linalg.min 7 | import org.apache.hadoop.fs.Path 8 | import org.apache.spark.Logging 9 | import org.apache.spark.broadcast.Broadcast 10 | import org.apache.spark.examples.mllib.LambdaMARTRunner.Params 11 | import org.apache.spark.mllib.dataSet.dataSet 12 | import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics 13 | import org.apache.spark.mllib.tree.config.Algo._ 14 | import org.apache.spark.mllib.tree.config.BoostingStrategy 15 | import org.apache.spark.mllib.tree.impl.TimeTracker 16 | import org.apache.spark.mllib.tree.impurity.Variance 17 | import org.apache.spark.mllib.tree.model.ensemblemodels.GradientBoostedDecisionTreesModel 18 | import org.apache.spark.mllib.tree.model.opdtmodel.OptimizedDecisionTreeModel 19 | import org.apache.spark.mllib.util.TreeUtils 20 | import org.apache.spark.rdd.RDD 21 | 22 | import scala.util.Random 23 | 24 | class LambdaMART(val boostingStrategy: BoostingStrategy, 25 | val params: Params) extends Serializable with Logging { 26 | def run(trainingDataSet: dataSet, 27 | validateDataSet: dataSet, 28 | trainingData_T: RDD[(Int, Array[Array[Short]])], 29 | gainTable: Array[Double], 30 | feature2Gain: Array[Double]): GradientBoostedDecisionTreesModel = { 31 | val algo = boostingStrategy.treeStrategy.algo 32 | 33 | algo match { 34 | case LambdaMart => 35 | LambdaMART.boostMart(trainingDataSet, validateDataSet, trainingData_T, gainTable, 36 | boostingStrategy, params, feature2Gain) 37 | case Regression => 38 | LambdaMART.boostRegression(trainingDataSet, validateDataSet, trainingData_T, gainTable, 39 | boostingStrategy, params, feature2Gain) 40 | case Classification => 41 | 42 | LambdaMART.boostRegression(trainingDataSet, validateDataSet, trainingData_T, gainTable, 43 | boostingStrategy, params, feature2Gain) 44 | case _ => 45 | throw new IllegalArgumentException(s"$algo is not supported by the implementation of LambdaMART.") 46 | } 47 | } 48 | } 49 | 50 | object LambdaMART extends Logging { 51 | def train(trainingDataSet: dataSet, 52 | validateDataSet: dataSet, 53 | trainingData_T: RDD[(Int, Array[Array[Short]])], 54 | gainTable: Array[Double], 55 | boostingStrategy: BoostingStrategy, 56 | params: Params, 57 | feature2Gain: Array[Double]): GradientBoostedDecisionTreesModel = { 58 | 59 | new LambdaMART(boostingStrategy, params) 60 | .run(trainingDataSet, validateDataSet, trainingData_T, gainTable, feature2Gain) 61 | } 62 | 63 | private def boostMart(trainingDataSet: dataSet, 64 | validateDataSet: dataSet, 65 | trainingData_T: RDD[(Int, Array[Array[Short]])], 66 | gainTable: Array[Double], 67 | boostingStrategy: BoostingStrategy, 68 | params: Params, 69 | feature2Gain: Array[Double]): GradientBoostedDecisionTreesModel = { 70 | val timer = new TimeTracker() 71 | timer.start("total") 72 | timer.start("init") 73 | 74 | 
boostingStrategy.assertValid() 75 | val learningStrategy = params.learningStrategy 76 | // Initialize gradient boosting parameters 77 | val numPhases = params.numIterations.length 78 | val numTrees = params.numIterations.sum //different phases different trees number. 79 | var baseLearners = new Array[OptimizedDecisionTreeModel](numTrees) 80 | var baseLearnerWeights = new Array[Double](numTrees) 81 | // val loss = boostingStrategy.loss 82 | val numPruningLeaves = params.numPruningLeaves 83 | 84 | // Prepare strategy for individual trees, which use regression with variance impurity. 85 | val treeStrategy = boostingStrategy.treeStrategy.copy 86 | // val validationTol = boostingStrategy.validationTol 87 | treeStrategy.algo = LambdaMart 88 | treeStrategy.impurity = Variance 89 | treeStrategy.assertValid() 90 | 91 | 92 | val trainingData = trainingDataSet.getData() 93 | val label = trainingDataSet.getLabel() 94 | val queryBoundy = trainingDataSet.getQueryBoundy() 95 | val initScores = trainingDataSet.getScore() 96 | 97 | val sc = trainingData.sparkContext 98 | val numSamples = label.length 99 | val numQueries = queryBoundy.length - 1 100 | val (qiMinPP, lcNumQueriesPP) = TreeUtils.getPartitionOffsets(numQueries, sc.defaultParallelism) 101 | //println(">>>>>>>>>>>") 102 | //println(qiMinPP.mkString(",")) 103 | //println(lcNumQueriesPP.mkString(",")) 104 | val pdcRDD = sc.parallelize(qiMinPP.zip(lcNumQueriesPP)).cache().setName("PDCCtrl") 105 | 106 | val learningRates = params.learningRate 107 | val distanceWeight2 = params.distanceWeight2 108 | val baselineAlpha = params.baselineAlpha 109 | val secondMs = params.secondaryMS 110 | val secondLe = params.secondaryLE 111 | val secondGains = params.secondGains 112 | val secondaryInverseMacDcg = params.secondaryInverseMaxDcg 113 | val discounts = params.discounts 114 | val baselineDcgs = params.baselineDcgs 115 | 116 | val dc = new DerivativeCalculator 117 | //sigma = params.learningRate(0) 118 | dc.init(label, gainTable, queryBoundy, 119 | learningRates(0), distanceWeight2, baselineAlpha, 120 | secondMs, secondLe, secondGains, secondaryInverseMacDcg, discounts, baselineDcgs) 121 | 122 | val dcBc = sc.broadcast(dc) 123 | val lambdas = new Array[Double](numSamples) 124 | val weights = new Array[Double](numSamples) 125 | 126 | timer.stop("init") 127 | 128 | val currentScores = initScores 129 | val initErrors = evaluateErrors(pdcRDD, dcBc, currentScores, numQueries) 130 | println(s"NDCG initError sum = $initErrors") 131 | 132 | var m = 0 133 | var numIterations = 0 134 | 135 | var earlystop = false 136 | val useEarlystop = params.useEarlystop 137 | var phase = 0 138 | val oldRep = new Array[Double](numSamples) 139 | val validationSpan = params.validationSpan 140 | val multiplier_Score = 1.0 141 | var sampleFraction = params.sfraction 142 | 143 | while (phase < numPhases && !earlystop) { 144 | numIterations += params.numIterations(phase) 145 | //initial derivativeCalculator for every phase 146 | val dcPhase = new DerivativeCalculator 147 | dcPhase.init(label, gainTable, queryBoundy, learningRates(phase), 148 | distanceWeight2, baselineAlpha, 149 | secondMs, secondLe, secondGains, secondaryInverseMacDcg, discounts, baselineDcgs) 150 | 151 | val dcPhaseBc = sc.broadcast(dcPhase) 152 | 153 | var qfraction = 0.5 // stochastic sampling fraction of query per tree 154 | while (m < numIterations && !earlystop) { 155 | timer.start(s"building tree $m") 156 | println("\nGradient boosting tree iteration " + m) 157 | 158 | 159 | // println(s"active samples: 
${activeSamples.sum}") 160 | 161 | val iterationBc = sc.broadcast(m) 162 | val currentScoresBc = sc.broadcast(currentScores) 163 | updateDerivatives(pdcRDD, dcPhaseBc, currentScoresBc, iterationBc, lambdas, weights) 164 | currentScoresBc.unpersist(blocking = false) 165 | iterationBc.unpersist(blocking = false) 166 | 167 | //adaptive lambda 168 | if (params.active_lambda_learningStrategy) { 169 | val rho_lambda = params.rho_lambda 170 | if (learningStrategy == "sgd") { 171 | 172 | } 173 | else if (learningStrategy == "momentum") { 174 | Range(0, numSamples).par.foreach { si => 175 | lambdas(si) = rho_lambda * oldRep(si) + lambdas(si) 176 | oldRep(si) = lambdas(si) 177 | } 178 | } 179 | // else if (learningStrategy == "adagrad") { 180 | // Range(0, numSamples).par.foreach { si => 181 | // oldRep(si) += lambdas(si) * lambdas(si) 182 | // lambdas(si) = lambdas(si) / math.sqrt(oldRep(si) + 1e-9) 183 | // } 184 | // } 185 | // else if (learningStrategy == "adadelta") { 186 | // Range(0, numSamples).par.foreach { si => 187 | // oldRep(si) = rho_lambda * oldRep(si) + (1 - rho_lambda) * lambdas(si) * lambdas(si) 188 | // lambdas(si) = learningRate(phase) * lambdas(si) / scala.math.sqrt(oldRep(si) + 1e-9) 189 | // } 190 | // } 191 | } 192 | 193 | val lambdasBc = sc.broadcast(lambdas) 194 | val weightsBc = sc.broadcast(weights) 195 | 196 | 197 | sampleFraction = min(params.sfraction + (1 - params.sfraction) / 10 * (m * 11.0 / numIterations).toInt, 1.0) 198 | println(s"sampleFraction: $sampleFraction") 199 | qfraction = sampleFraction 200 | val initTimer = new TimeTracker() 201 | initTimer.start("topInfo") 202 | val (activeSamples, sumCount, sumTarget, sumSquare, sumWeight): (Array[Byte], Int, Double, Double, Double) = { 203 | if (qfraction >= 1.0) { 204 | (Array.fill[Byte](numSamples)(1), numSamples, lambdas.sum, lambdas.map(x => x * x).sum, weights.sum) 205 | } 206 | else { 207 | val act = new Array[Byte](numSamples) 208 | val (sumC, sumT, sumS, sumW): (Int, Double, Double, Double) = updateActSamples(pdcRDD, dcPhaseBc, lambdasBc, weightsBc, qfraction, act) 209 | (act, sumC, sumT, sumS, sumW) 210 | } 211 | } 212 | println(s"sampleCount: $sumCount") 213 | initTimer.stop("topInfo") 214 | println(s"$initTimer") 215 | 216 | 217 | logDebug(s"Iteration $m: scores: \n" + currentScores.mkString(" ")) 218 | 219 | val featureUseCount = new Array[Int](feature2Gain.length) 220 | var TrainingDataUse = trainingData 221 | if (params.ffraction < 1.0) { 222 | TrainingDataUse = trainingData.filter(item => IsSeleted(params.ffraction)) 223 | } 224 | 225 | treeStrategy.maxDepth = params.maxDepth(phase) 226 | val tree = new LambdaMARTDecisionTree(treeStrategy, params.minInstancesPerNode(phase), 227 | params.numLeaves, params.maxSplits, params.expandTreeEnsemble) 228 | val (model, treeScores) = tree.run(TrainingDataUse, trainingData_T, lambdasBc, weightsBc, numSamples, 229 | params.entropyCoefft, featureUseCount, params.featureFirstUsePenalty, 230 | params.featureReusePenalty, feature2Gain, params.sampleWeights, numPruningLeaves(phase), 231 | (sumCount, sumTarget, sumSquare, sumWeight), actSamples = activeSamples) 232 | lambdasBc.unpersist(blocking = false) 233 | weightsBc.unpersist(blocking = false) 234 | timer.stop(s"building tree $m") 235 | 236 | baseLearners(m) = model 237 | baseLearnerWeights(m) = learningRates(phase) 238 | 239 | Range(0, numSamples).par.foreach(si => 240 | currentScores(si) += baseLearnerWeights(m) * treeScores(si) 241 | ) 242 | //testing continue training 243 | 244 | 245 | /** 246 | * //adaptive 
leaves value 247 | * if(params.active_leaves_value_learningStrategy){ 248 | * val rho_leave = params.rho_leave 249 | * if(learningStrategy == "sgd") { 250 | * Range(0, numSamples).par.foreach(si => 251 | * currentScores(si) += learningRate(phase) * treeScores(si) 252 | * ) 253 | * } 254 | * else if(learningStrategy == "momentum"){ 255 | * Range(0, numSamples).par.foreach { si => 256 | * val deltaScore = rho_leave * oldRep(si) + learningRate(phase) * treeScores(si) 257 | * currentScores(si) += deltaScore 258 | * oldRep(si) = deltaScore 259 | * } 260 | * } 261 | * else if (learningStrategy == "adagrad"){ 262 | * Range(0, numSamples).par.foreach { si => 263 | * oldRep(si) += treeScores(si) * treeScores(si) 264 | * currentScores(si) += learningRate(phase) * treeScores(si) / math.sqrt(oldRep(si) + 1e-9) 265 | * } 266 | * } 267 | * else if (learningStrategy == "adadelta"){ 268 | * Range(0, numSamples).par.foreach { si => 269 | * oldRep(si) = rho_leave * oldRep(si) + (1- rho_leave)*treeScores(si)*treeScores(si) 270 | * currentScores(si) += learningRate(phase) * treeScores(si) / math.sqrt(oldRep(si) + 1e-9) 271 | * } 272 | * } 273 | * } ***/ 274 | 275 | 276 | //validate the model 277 | // println(s"validationDataSet: $validateDataSet") 278 | 279 | if (validateDataSet != null && 0 == (m % validationSpan) && useEarlystop) { 280 | val numQueries_V = validateDataSet.getQueryBoundy().length - 1 281 | val (qiMinPP_V, lcNumQueriesPP_V) = TreeUtils.getPartitionOffsets(numQueries_V, sc.defaultParallelism) 282 | //println(s"") 283 | val pdcRDD_V = sc.parallelize(qiMinPP_V.zip(lcNumQueriesPP_V)).cache().setName("PDCCtrl_V") 284 | 285 | val dc_v = new DerivativeCalculator 286 | dc_v.init(validateDataSet.getLabel(), gainTable, validateDataSet.getQueryBoundy(), 287 | learningRates(phase), params.distanceWeight2, baselineAlpha, 288 | secondMs, secondLe, secondGains, secondaryInverseMacDcg, discounts, baselineDcgs) 289 | 290 | val currentBaseLearners = new Array[OptimizedDecisionTreeModel](m + 1) 291 | val currentBaselearnerWeights = new Array[Double](m + 1) 292 | baseLearners.copyToArray(currentBaseLearners, 0, m + 1) 293 | baseLearnerWeights.copyToArray(currentBaselearnerWeights, 0, m + 1) 294 | val currentModel = new GradientBoostedDecisionTreesModel(Regression, currentBaseLearners, currentBaselearnerWeights) 295 | val currentValidateScore = new Array[Double](validateDataSet.getLabel().length) 296 | 297 | //val currentValidateScore_Bc = sc.broadcast(currentValidateScore) 298 | val currentModel_Bc = sc.broadcast(currentModel) 299 | 300 | validateDataSet.getDataTransposed().map { item => 301 | (item._1, currentModel_Bc.value.predict(item._2)) 302 | }.collect().foreach { case (sid, score) => 303 | currentValidateScore(sid) = score 304 | } 305 | 306 | println(s"currentScores: ${currentValidateScore.mkString(",")}") 307 | 308 | val errors = evaluateErrors(pdcRDD_V, sc.broadcast(dc_v), currentValidateScore, numQueries_V) 309 | 310 | println(s"validation errors: $errors") 311 | 312 | if (errors < 1.0e-6) { 313 | earlystop = true 314 | baseLearners = currentBaseLearners 315 | baseLearnerWeights = currentBaselearnerWeights 316 | } 317 | currentModel_Bc.unpersist(blocking = false) 318 | } 319 | val errors = evaluateErrors(pdcRDD, dcPhaseBc, currentScores, numQueries) 320 | 321 | val pw = new PrintWriter(new FileOutputStream(new File("ndcg.txt"), true)) 322 | pw.write(errors.toString + "\n") 323 | pw.close() 324 | 325 | println(s"NDCG error sum = $errors") 326 | println(s"length:" + model.topNode.internalNodes) 327 | // 
println("error of gbt = " + currentScores.iterator.map(re => re * re).sum / numSamples) 328 | //model.sequence("treeEnsemble.ini", model, m + 1) 329 | m += 1 330 | } 331 | phase += 1 332 | } 333 | 334 | timer.stop("total") 335 | 336 | if (params.outputNdcgFilename != null) { 337 | val outPath = new Path(params.outputNdcgFilename) 338 | val fs = TreeUtils.getFileSystem(trainingData.context.getConf, outPath) 339 | fs.copyFromLocalFile(true, true, new Path("ndcg.txt"), outPath) 340 | } 341 | println("Internal timing for LambdaMARTDecisionTree:") 342 | println(s"$timer") 343 | 344 | trainingData.unpersist(blocking = false) 345 | trainingData_T.unpersist(blocking = false) 346 | 347 | new GradientBoostedDecisionTreesModel(Regression, baseLearners, baseLearnerWeights) 348 | } 349 | 350 | private def boostRegression(trainingDataSet: dataSet, 351 | validateDataSet: dataSet, 352 | trainingData_T: RDD[(Int, Array[Array[Short]])], 353 | gainTable: Array[Double], 354 | boostingStrategy: BoostingStrategy, 355 | params: Params, 356 | feature2Gain: Array[Double]): GradientBoostedDecisionTreesModel = { 357 | val timer = new TimeTracker() 358 | timer.start("total") 359 | timer.start("init") 360 | 361 | boostingStrategy.assertValid() 362 | 363 | // Initialize gradient boosting parameters 364 | val numPhases = params.numIterations.length 365 | val numTrees = params.numIterations.sum //different phases different trees number. 366 | var baseLearners = new Array[OptimizedDecisionTreeModel](numTrees) 367 | var baseLearnerWeights = new Array[Double](numTrees) 368 | // val loss = boostingStrategy.loss 369 | 370 | 371 | // Prepare strategy for individual trees, which use regression with variance impurity. 372 | val treeStrategy = boostingStrategy.treeStrategy.copy 373 | // val validationTol = boostingStrategy.validationTol 374 | treeStrategy.algo = Regression 375 | treeStrategy.impurity = Variance 376 | treeStrategy.assertValid() 377 | val loss = boostingStrategy.loss 378 | 379 | val trainingData = trainingDataSet.getData() 380 | val label = trainingDataSet.getLabel() 381 | val initScores = trainingDataSet.getScore() 382 | 383 | val sc = trainingData.sparkContext 384 | val numSamples = label.length 385 | 386 | val learningRates = params.learningRate 387 | 388 | val lambdas = new Array[Double](numSamples) 389 | var ni = 0 390 | while (ni < numSamples) { 391 | lambdas(ni) = -2 * (label(ni) - initScores(ni)) 392 | ni += 1 393 | } 394 | 395 | // val weights = new Array[Double](numSamples) 396 | val weightsBc = sc.broadcast(Array.empty[Double]) 397 | timer.stop("init") 398 | 399 | val currentScores = initScores 400 | var initErrors = 0.0 401 | ni = 0 402 | while (ni < numSamples) { 403 | initErrors += loss.computeError(currentScores(ni), label(ni)) 404 | ni += 1 405 | } 406 | initErrors /= numSamples 407 | println(s"logloss initError sum = $initErrors") 408 | 409 | var m = 0 410 | var numIterations = 0 411 | 412 | var earlystop = false 413 | val useEarlystop = params.useEarlystop 414 | var phase = 0 415 | val oldRep = if (params.active_lambda_learningStrategy) new Array[Double](numSamples) else Array.empty[Double] 416 | val validationSpan = params.validationSpan 417 | // val multiplier_Score = 1.0 418 | val learningStrategy = params.learningStrategy 419 | while (phase < numPhases && !earlystop) { 420 | numIterations += params.numIterations(phase) 421 | 422 | while (m < numIterations && !earlystop) { 423 | timer.start(s"building tree $m") 424 | println("\nGradient boosting tree iteration " + m) 425 | //update lambda 426 | 
Range(0, numSamples).par.foreach { ni => 427 | lambdas(ni) = -loss.gradient(currentScores(ni), label(ni)) 428 | } 429 | 430 | if (params.active_lambda_learningStrategy) { 431 | val rho_lambda = params.rho_lambda 432 | if (learningStrategy == "sgd") { 433 | } 434 | else if (learningStrategy == "momentum") { 435 | Range(0, numSamples).par.foreach { si => 436 | lambdas(si) = rho_lambda * oldRep(si) + lambdas(si) 437 | oldRep(si) = lambdas(si) 438 | } 439 | } 440 | else if (learningStrategy == "adagrad") { 441 | Range(0, numSamples).par.foreach { si => 442 | oldRep(si) += lambdas(si) * lambdas(si) 443 | lambdas(si) = lambdas(si) / math.sqrt(oldRep(si)+1.0) 444 | } 445 | } 446 | else if (learningStrategy == "adadelta") { 447 | Range(0, numSamples).par.foreach { si => 448 | oldRep(si) = rho_lambda * oldRep(si) + (1 - rho_lambda) * lambdas(si) * lambdas(si) 449 | lambdas(si) = lambdas(si) / scala.math.sqrt(oldRep(si) + 1.0) 450 | 451 | } 452 | } 453 | } 454 | 455 | val lambdasBc = sc.broadcast(lambdas) 456 | 457 | logDebug(s"Iteration $m: scores: \n" + currentScores.mkString(" ")) 458 | 459 | val featureUseCount = new Array[Int](trainingData.count().toInt) 460 | var TrainingDataUse = trainingData 461 | if (params.ffraction < 1.0) { 462 | TrainingDataUse = trainingData.filter(item => IsSeleted(params.ffraction)) 463 | } 464 | 465 | var si = 0 466 | var sumLambda = 0.0 467 | var sumSquare = 0.0 468 | while (si < lambdas.length) { 469 | sumLambda += lambdas(si) 470 | sumSquare += lambdas(si) * lambdas(si) 471 | si += 1 472 | } 473 | val topValue = (numSamples, sumLambda, sumSquare, 0.0) 474 | val tree = new LambdaMARTDecisionTree(treeStrategy, params.minInstancesPerNode(phase), 475 | params.numLeaves, params.maxSplits, params.expandTreeEnsemble) 476 | val (model, treeScores) = tree.run(TrainingDataUse, trainingData_T, lambdasBc, weightsBc, numSamples, 477 | params.entropyCoefft, featureUseCount, params.featureFirstUsePenalty, 478 | params.featureReusePenalty, feature2Gain, params.sampleWeights, params.numPruningLeaves(phase), topValue) 479 | lambdasBc.unpersist(blocking = false) 480 | 481 | timer.stop(s"building tree $m") 482 | 483 | baseLearners(m) = model 484 | baseLearnerWeights(m) = learningRates(phase) 485 | println(s"learningRate: ${baseLearnerWeights(m)}") 486 | Range(0, numSamples).par.foreach(si => 487 | currentScores(si) += baseLearnerWeights(m) * treeScores(si) 488 | ) 489 | 490 | //validate the model 491 | // println(s"validationDataSet: $validateDataSet") 492 | 493 | if (validateDataSet != null && 0 == (m % validationSpan) && useEarlystop) { 494 | val currentBaseLearners = new Array[OptimizedDecisionTreeModel](m + 1) 495 | val currentBaselearnerWeights = new Array[Double](m + 1) 496 | baseLearners.copyToArray(currentBaseLearners, 0, m + 1) 497 | baseLearnerWeights.copyToArray(currentBaselearnerWeights, 0, m + 1) 498 | val currentModel = new GradientBoostedDecisionTreesModel(Regression, currentBaseLearners, currentBaselearnerWeights) 499 | val validateLabel = validateDataSet.getLabel() 500 | val numValidate = validateLabel.length 501 | 502 | //val currentValidateScore_Bc = sc.broadcast(currentValidateScore) 503 | val currentModel_Bc = sc.broadcast(currentModel) 504 | 505 | val currentValidateScore = validateDataSet.getDataTransposed().map { item => 506 | currentModel_Bc.value.predict(item._2) 507 | }.collect() 508 | 509 | 510 | 511 | var errors = 0.0 512 | Range(0, numValidate).foreach { ni => val x = loss.computeError(currentValidateScore(ni), validateLabel(ni)) 513 | errors += x 514 | 
} 515 | 516 | println(s"validation errors: $errors") 517 | 518 | if (errors < 1.0e-6) { 519 | earlystop = true 520 | baseLearners = currentBaseLearners 521 | baseLearnerWeights = currentBaselearnerWeights 522 | } 523 | currentModel_Bc.unpersist(blocking = false) 524 | } 525 | 526 | var errors = 0.0 527 | Range(0, numSamples).foreach { ni => val x = loss.computeError(currentScores(ni), label(ni)) 528 | errors += x 529 | } 530 | errors /= numSamples 531 | 532 | val pw = new PrintWriter(new FileOutputStream(new File("ndcg.txt"), true)) 533 | pw.write(errors.toString + "\n") 534 | pw.close() 535 | 536 | println(s"logloss error sum = $errors") 537 | println(s"length:" + model.topNode.internalNodes) 538 | // println("error of gbt = " + currentScores.iterator.map(re => re * re).sum / numSamples) 539 | //model.sequence("treeEnsemble.ini", model, m + 1) 540 | 541 | if (m % params.testSpan == 0) { 542 | val scoreAndLabels = new Array[(Double, Double)](10000) 543 | Range(0, 1000000, 100).foreach { it => 544 | if (currentScores(it) >= 0) 545 | scoreAndLabels(it / 100) = (1.0, label(it).toDouble) 546 | else 547 | scoreAndLabels(it / 100) = (-1.0, label(it).toDouble) 548 | } 549 | val slRDD = sc.makeRDD(scoreAndLabels) 550 | val metrics = new BinaryClassificationMetrics(slRDD) 551 | val accuracy = metrics.areaUnderROC() 552 | println(s"test Accuracy at $m = $accuracy") 553 | } 554 | if (m % params.validationSpan == 0 || m == numIterations - 1) { 555 | val scoreAndLabels = new Array[(Double, Double)](label.length) 556 | var i = 0 557 | while (i < label.length) { 558 | if (currentScores(i) >= 0) 559 | scoreAndLabels(i) = (1.0, label(i).toDouble) 560 | else 561 | scoreAndLabels(i) = (-1.0, label(i).toDouble) 562 | i += 1 563 | } 564 | val slRDD = sc.makeRDD(scoreAndLabels) 565 | val metrics = new BinaryClassificationMetrics(slRDD) 566 | val accuracy = metrics.areaUnderROC() 567 | println(s"training Accuracy at $m = $accuracy") 568 | } 569 | 570 | m += 1 571 | } 572 | phase += 1 573 | } 574 | 575 | timer.stop("total") 576 | 577 | // val scoreAndLabels = sc.makeRDD(currentScores.map(x=> if(x>0) 1.0 else -1.0).zip(label.map(_.toDouble))) 578 | // val metrics = new BinaryClassificationMetrics(scoreAndLabels) 579 | // val accuracy = metrics.areaUnderROC() 580 | // println(s"Accuracy = $accuracy") 581 | 582 | println("Internal timing for RegressionDecisionTree:") 583 | println(s"$timer") 584 | weightsBc.unpersist(blocking = false) 585 | trainingData.unpersist(blocking = false) 586 | trainingData_T.unpersist(blocking = false) 587 | 588 | new GradientBoostedDecisionTreesModel(Regression, baseLearners, baseLearnerWeights) 589 | } 590 | 591 | 592 | def updateDerivatives(pdcRDD: RDD[(Int, Int)], 593 | dcBc: Broadcast[DerivativeCalculator], 594 | currentScoresBc: Broadcast[Array[Double]], 595 | iterationBc: Broadcast[Int], 596 | lambdas: Array[Double], 597 | weights: Array[Double]): Unit = { 598 | val partDerivs = pdcRDD.mapPartitions { iter => 599 | val dc = dcBc.value 600 | val currentScores = currentScoresBc.value 601 | val iteration = iterationBc.value 602 | iter.map { case (qiMin, lcNumQueries) => 603 | dc.getPartDerivatives(currentScores, qiMin, qiMin + lcNumQueries, iteration) 604 | } 605 | }.collect() 606 | partDerivs.par.foreach { case (siMin, lcLambdas, lcWeights) => 607 | Array.copy(lcLambdas, 0, lambdas, siMin, lcLambdas.length) 608 | Array.copy(lcWeights, 0, weights, siMin, lcWeights.length) 609 | } 610 | } 611 | 612 | def updateActSamples(pdcRDD: RDD[(Int, Int)], dcBc: Broadcast[DerivativeCalculator], 
613 | lambdasBc: Broadcast[Array[Double]], 614 | weightsBc: Broadcast[Array[Double]], 615 | fraction: Double, 616 | act: Array[Byte]): (Int, Double, Double, Double) = { 617 | val partAct = pdcRDD.mapPartitions { iter => 618 | val lambdas = lambdasBc.value 619 | val weights = weightsBc.value 620 | val queryBoundy = dcBc.value.queryBoundy 621 | val frac = fraction 622 | iter.map { case (qiMin, lcNumQueries) => 623 | val qiEnd = qiMin + lcNumQueries 624 | val siTotalMin = queryBoundy(qiMin) 625 | val numTotalDocs = queryBoundy(qiEnd) - siTotalMin 626 | val lcActSamples = new Array[Byte](numTotalDocs) 627 | var lcSumCount = 0 628 | var lcSumTarget = 0.0 629 | var lcSumSquare = 0.0 630 | var lcSumWeight = 0.0 631 | var qi = qiMin 632 | while (qi < qiEnd) { 633 | val lcMin = queryBoundy(qi) - siTotalMin 634 | val siMin = queryBoundy(qi) 635 | val siEnd = queryBoundy(qi + 1) 636 | val numDocsPerQuery = siEnd - siMin 637 | if (Random.nextDouble() <= frac) { 638 | Range(lcMin, lcMin + numDocsPerQuery).foreach { lsi => 639 | lcActSamples(lsi) = 1.toByte 640 | lcSumTarget += lambdas(siTotalMin + lsi) 641 | lcSumSquare += lambdas(siTotalMin + lsi) * lambdas(siTotalMin + lsi) 642 | lcSumWeight += weights(siTotalMin + lsi) 643 | } 644 | lcSumCount += numDocsPerQuery 645 | } 646 | else { 647 | Range(lcMin, lcMin + numDocsPerQuery).foreach { lsi => 648 | lcActSamples(lsi) = 0.toByte 649 | } 650 | } 651 | qi += 1 652 | } 653 | (siTotalMin, lcActSamples, lcSumCount, lcSumTarget, lcSumSquare, lcSumWeight) 654 | 655 | } 656 | 657 | } 658 | val actSamples = partAct.map(x => (x._1, x._2)).collect() 659 | actSamples.par.foreach { case (siMin, lcAct) => 660 | Array.copy(lcAct, 0, act, siMin, lcAct.length) 661 | } 662 | partAct.map(x => (x._3, x._4, x._5, x._6)).reduce((a, b) => (a._1 + b._1, a._2 + b._2, a._3 + b._3, a._4 + b._4)) 663 | } 664 | 665 | 666 | def evaluateErrors(pdcRDD: RDD[(Int, Int)], 667 | dcBc: Broadcast[DerivativeCalculator], 668 | currentScores: Array[Double], 669 | numQueries: Int): Double = { 670 | val sc = pdcRDD.context 671 | val currentScoresBc = sc.broadcast(currentScores) 672 | val sumErrors = pdcRDD.mapPartitions { iter => 673 | val dc = dcBc.value 674 | val currentScores = currentScoresBc.value 675 | iter.map { case (qiMin, lcNumQueries) => 676 | dc.getPartErrors(currentScores, qiMin, qiMin + lcNumQueries) 677 | } 678 | }.sum() 679 | currentScoresBc.unpersist(blocking = false) 680 | sumErrors / numQueries 681 | } 682 | 683 | def IsSeleted(ffraction: Double): Boolean = { 684 | val randomNum = scala.util.Random.nextDouble() 685 | var active = false 686 | if (randomNum < ffraction) { 687 | active = true 688 | } 689 | active 690 | } 691 | 692 | } 693 | 694 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/config/Algo.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.tree.config 19 | 20 | import org.apache.spark.annotation.{Experimental, Since} 21 | 22 | /** 23 | * :: Experimental :: 24 | * Enum to select the algorithm for the decision tree 25 | */ 26 | @Since("1.0.0") 27 | @Experimental 28 | object Algo extends Enumeration { 29 | @Since("1.0.0") 30 | type Algo = Value 31 | @Since("1.0.0") 32 | val Classification, Regression, LambdaMart = Value 33 | 34 | private[mllib] def fromString(name: String): Algo = name match { 35 | case "classification" | "Classification" => Classification 36 | case "regression" | "Regression" => Regression 37 | case "lambdamart" | "Lambdamart" | "LambdaMart" => LambdaMart 38 | case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/config/BoostingStrategy.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.tree.config 19 | 20 | import org.apache.spark.annotation.Since 21 | import org.apache.spark.mllib.tree.config 22 | import org.apache.spark.mllib.tree.config.Algo._ 23 | import org.apache.spark.mllib.tree.loss.{LogLoss, Loss, SquaredError} 24 | 25 | import scala.beans.BeanProperty 26 | 27 | /** 28 | * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. 29 | * 30 | * @param treeStrategy Parameters for the tree algorithm. We support regression and binary 31 | * classification for boosting. Impurity setting will be ignored. 32 | * @param loss Loss function used for minimization during gradient boosting. 33 | * @param numIterations Number of iterations of boosting. In other words, the number of 34 | * weak hypotheses used in the final model. 35 | * @param learningRate Learning rate for shrinking the contribution of each estimator. The 36 | * learning rate should be between in the interval (0, 1] 37 | * @param validationTol validationTol is a condition which decides iteration termination when 38 | * runWithValidation is used. 
39 | * The end of iteration is decided based on below logic: 40 | * If the current loss on the validation set is > 0.01, the diff 41 | * of validation error is compared to relative tolerance which is 42 | * validationTol * (current loss on the validation set). 43 | * If the current loss on the validation set is <= 0.01, the diff 44 | * of validation error is compared to absolute tolerance which is 45 | * validationTol * 0.01. 46 | * Ignored when 47 | * [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used. 48 | */ 49 | @Since("1.2.0") 50 | case class BoostingStrategy @Since("1.4.0") ( 51 | // Required boosting parameters 52 | @Since("1.2.0") @BeanProperty var treeStrategy: Strategy, 53 | @Since("1.2.0") @BeanProperty var loss: Loss, 54 | // Optional boosting parameters 55 | @Since("1.2.0") @BeanProperty var numIterations: Int = 100, 56 | @Since("1.2.0") @BeanProperty var learningRate: Double = 0.1, 57 | @Since("1.4.0") @BeanProperty var validationTol: Double = 0.001) extends Serializable { 58 | 59 | /** 60 | * Check validity of parameters. 61 | * Throws exception if invalid. 62 | */ 63 | private[tree] def assertValid(): Unit = { 64 | treeStrategy.algo match { 65 | case Classification => 66 | require(treeStrategy.numClasses == 2, 67 | "Only binary classification is supported for boosting.") 68 | case Regression => 69 | // nothing 70 | case LambdaMart => 71 | // nothing 72 | case _ => 73 | throw new IllegalArgumentException( 74 | s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." + 75 | s" Valid settings are: Classification, Regression.") 76 | } 77 | require(learningRate > 0 && learningRate <= 1, 78 | "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.") 79 | } 80 | } 81 | 82 | @Since("1.2.0") 83 | object BoostingStrategy { 84 | 85 | /** 86 | * Returns default configuration for the boosting algorithm 87 | * 88 | * @param algo Learning goal. Supported: "Classification" or "Regression" 89 | * @return Configuration for boosting algorithm 90 | */ 91 | @Since("1.2.0") 92 | def defaultParams(algo: String): BoostingStrategy = { 93 | defaultParams(fromString(algo)) 94 | } 95 | 96 | /** 97 | * Returns default configuration for the boosting algorithm 98 | * 99 | * @param algo Learning goal. Supported: 100 | * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], 101 | * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] 102 | * @return Configuration for boosting algorithm 103 | */ 104 | @Since("1.3.0") 105 | def defaultParams(algo: Algo): BoostingStrategy = { 106 | val treeStrategy = config.Strategy.defaultStrategy(algo) 107 | treeStrategy.maxDepth = 3 108 | algo match { 109 | case Classification => 110 | treeStrategy.numClasses = 2 111 | println("Classification") 112 | new BoostingStrategy(treeStrategy, LogLoss) 113 | case Regression => 114 | println("Regression") 115 | new BoostingStrategy(treeStrategy, SquaredError) 116 | case LambdaMart => 117 | println("LambdaMart") 118 | new BoostingStrategy(treeStrategy, SquaredError) 119 | case _ => 120 | throw new IllegalArgumentException(s"$algo is not supported by boosting.") 121 | } 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/config/Strategy.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. 
See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.tree.config 19 | 20 | import org.apache.spark.annotation.Since 21 | import org.apache.spark.mllib.tree.config 22 | import org.apache.spark.mllib.tree.config.Algo._ 23 | import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ 24 | import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} 25 | 26 | import scala.beans.BeanProperty 27 | import scala.collection.JavaConverters._ 28 | 29 | /** 30 | * Stores all the configuration options for tree construction 31 | * 32 | * @param algo Learning goal. Supported: 33 | * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], 34 | * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] 35 | * @param impurity Criterion used for information gain calculation. 36 | * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]], 37 | * [[org.apache.spark.mllib.tree.impurity.Entropy]]. 38 | * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]]. 39 | * @param maxDepth Maximum depth of the tree. 40 | * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. 41 | * @param numClasses Number of classes for classification. 42 | * (Ignored for regression.) 43 | * Default value is 2 (binary classification). 44 | * @param maxBins Maximum number of bins used for discretizing continuous features and 45 | * for choosing how to split on features at each node. 46 | * More bins give higher granularity. 47 | * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: 48 | * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] 49 | * @param categoricalFeaturesInfo A map storing information about the categorical variables and the 50 | * number of discrete values they take. For example, an entry (n -> 51 | * k) implies the feature n is categorical with k categories 0, 52 | * 1, 2, ... , k-1. It's important to note that features are 53 | * zero-indexed. 54 | * @param minInstancesPerNode Minimum number of instances each child must have after split. 55 | * Default value is 1. If a split cause left or right child 56 | * to have less than minInstancesPerNode, 57 | * this split will not be considered as a valid split. 58 | * @param minInfoGain Minimum information gain a split must get. Default value is 0.0. 59 | * If a split has less information gain than minInfoGain, 60 | * this split will not be considered as a valid split. 61 | * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is 62 | * 256 MB. 63 | * @param subsamplingRate Fraction of the training data used for learning decision tree. 
64 | * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will 65 | * maintain a separate RDD of node Id cache for each row. 66 | * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. 67 | * E.g. 10 means that the cache will get checkpointed every 10 updates. If 68 | * the checkpoint directory is not set in 69 | * [[org.apache.spark.SparkContext]], this setting is ignored. 70 | */ 71 | @Since("1.0.0") 72 | class Strategy @Since("1.3.0") ( 73 | @Since("1.0.0") @BeanProperty var algo: Algo, 74 | @Since("1.0.0") @BeanProperty var impurity: Impurity, 75 | @Since("1.0.0") @BeanProperty var maxDepth: Int, 76 | @Since("1.2.0") @BeanProperty var numClasses: Int = 2, 77 | @Since("1.0.0") @BeanProperty var maxBins: Int = 32, 78 | @Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, 79 | @Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), 80 | @Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1, 81 | @Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0, 82 | @Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256, 83 | @Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1, 84 | @Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false, 85 | @Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable { 86 | 87 | /** 88 | */ 89 | @Since("1.2.0") 90 | def isMulticlassClassification: Boolean = { 91 | algo == Classification && numClasses > 2 92 | } 93 | 94 | /** 95 | */ 96 | @Since("1.2.0") 97 | def isMulticlassWithCategoricalFeatures: Boolean = { 98 | isMulticlassClassification && (categoricalFeaturesInfo.size > 0) 99 | } 100 | 101 | /** 102 | * Java-friendly constructor for [[config.Strategy]] 103 | */ 104 | @Since("1.1.0") 105 | def this( 106 | algo: Algo, 107 | impurity: Impurity, 108 | maxDepth: Int, 109 | numClasses: Int, 110 | maxBins: Int, 111 | categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]) { 112 | this(algo, impurity, maxDepth, numClasses, maxBins, Sort, 113 | categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) 114 | } 115 | 116 | /** 117 | * Sets Algorithm using a String. 118 | */ 119 | @Since("1.2.0") 120 | def setAlgo(algo: String): Unit = algo match { 121 | case "Classification" => setAlgo(Classification) 122 | case "Regression" => setAlgo(Regression) 123 | case "LambdaMart" => setAlgo(LambdaMart) 124 | } 125 | 126 | /** 127 | * Sets categoricalFeaturesInfo using a Java Map. 128 | */ 129 | @Since("1.2.0") 130 | def setCategoricalFeaturesInfo( 131 | categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { 132 | this.categoricalFeaturesInfo = 133 | categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap 134 | } 135 | 136 | /** 137 | * Check validity of parameters. 138 | * Throws exception if invalid. 139 | */ 140 | private[tree] def assertValid(): Unit = { 141 | algo match { 142 | case Classification => 143 | require(numClasses >= 2, 144 | s"DecisionTree Strategy for Classification must have numClasses >= 2," + 145 | s" but numClasses = $numClasses.") 146 | require(Set(Gini, Entropy).contains(impurity), 147 | s"DecisionTree Strategy given invalid impurity for Classification: $impurity." + 148 | s" Valid settings: Gini, Entropy") 149 | case Regression => 150 | require(impurity == Variance, 151 | s"DecisionTree Strategy given invalid impurity for Regression: $impurity." 
+ 152 | s" Valid settings: Variance") 153 | case LambdaMart => 154 | require(impurity == Variance, 155 | s"DecisionTree Strategy given invalid impurity for LambdaMart: $impurity." + 156 | s" Valid settings: Variance") 157 | case _ => 158 | throw new IllegalArgumentException( 159 | s"DecisionTree Strategy given invalid algo parameter: $algo." + 160 | s" Valid settings are: Classification, Regression.") 161 | } 162 | require(maxDepth >= 0, s"DecisionTree Strategy given invalid maxDepth parameter: $maxDepth." + 163 | s" Valid values are integers >= 0.") 164 | require(maxBins >= 2, s"DecisionTree Strategy given invalid maxBins parameter: $maxBins." + 165 | s" Valid values are integers >= 2.") 166 | require(minInstancesPerNode >= 1, 167 | s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") 168 | require(maxMemoryInMB <= 10240, 169 | s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") 170 | require(subsamplingRate > 0 && subsamplingRate <= 1, 171 | s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " + 172 | s"$subsamplingRate") 173 | } 174 | 175 | /** 176 | * Returns a shallow copy of this instance. 177 | */ 178 | @Since("1.2.0") 179 | def copy: Strategy = { 180 | new Strategy(algo, impurity, maxDepth, numClasses, maxBins, 181 | quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, 182 | maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval) 183 | } 184 | } 185 | 186 | @Since("1.2.0") 187 | object Strategy { 188 | 189 | /** 190 | * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] 191 | * 192 | * @param algo "Classification" or "Regression" 193 | */ 194 | @Since("1.2.0") 195 | def defaultStrategy(algo: String): Strategy = { 196 | defaultStrategy(Algo.fromString(algo)) 197 | } 198 | 199 | /** 200 | * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] 201 | * 202 | * @param algo Algo.Classification or Algo.Regression 203 | */ 204 | @Since("1.3.0") 205 | def defaultStrategy(algo: Algo): Strategy = algo match { 206 | case Algo.Classification => 207 | new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, 208 | numClasses = 2) 209 | case Algo.Regression => 210 | new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, 211 | numClasses = 0) 212 | case Algo.LambdaMart => 213 | new Strategy(algo = LambdaMart, impurity = Variance, maxDepth = 10, 214 | numClasses = 0) 215 | } 216 | 217 | @deprecated("Use Strategy.defaultStrategy instead.", "1.5.0") 218 | @Since("1.2.0") 219 | def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo) 220 | 221 | } 222 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/configuration/algo.scala: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cloudml/SparkTree/45f7ce7886cc54658348f02261b2d9d9ca40431b/src/main/scala/org/apache/spark/mllib/tree/configuration/algo.scala -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/impl/FeatureStatsAggregator.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.tree.impl 2 | 3 | import org.apache.spark.mllib.tree.model.impurity._ 4 | 5 | class FeatureStatsAggregator( 6 | val numBins: Byte) extends 
Serializable { 7 | 8 | val impurityAggregator: ImpurityAggregator = new VarianceAggregator() 9 | 10 | private val statsSize: Int = impurityAggregator.statsSize 11 | 12 | private val allStatsSize: Int = numBins * statsSize 13 | 14 | private val allStats: Array[Double] = new Array[Double](allStatsSize) 15 | 16 | def getImpurityCalculator(binIndex: Int): ImpurityCalculator = { 17 | impurityAggregator.getCalculator(allStats, binIndex * statsSize) 18 | } 19 | 20 | def update(binIndex: Int, label: Double, instanceWeight: Double, weight: Double): Unit = { 21 | val i = binIndex * statsSize 22 | impurityAggregator.update(allStats, i, label, instanceWeight, weight) 23 | } 24 | 25 | def merge(binIndex: Int, otherBinIndex: Int): Unit = { 26 | impurityAggregator.merge(allStats, binIndex * statsSize, otherBinIndex * statsSize) 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/GetDerivatives.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.tree.model 2 | 3 | import scala.math 4 | import scala.collection.mutable 5 | import scala.collection.mutable.HashMap 6 | import scala.collection.mutable.ArrayBuilder 7 | 8 | object GetDerivatives { 9 | 10 | val _expAsymptote: Double = -50 11 | val _sigmoidBins: Int = 1000000 12 | var _minScore: Double = 0.0 13 | var _maxScore: Double = 0.0 14 | // var _sigmoidTable = _ 15 | var _scoreToSigmoidTableFactor: Double = 0.0 16 | var _minSigmoid: Double = 0.0 17 | var _maxSigmoid: Double = 0.0 18 | 19 | def sortArray(arr: Array[Double], st: Int, num: Int): Array[Int] = { 20 | val score2idx = scala.collection.mutable.HashMap.empty[Double, Int] 21 | for(i <- 0 until num) 22 | score2idx(arr(st + i)) = i 23 | 24 | var res = new Array[Int](num) 25 | var i = 0 26 | var listSorted = score2idx.toList.sorted 27 | listSorted.foreach { case (key, value) => 28 | res(i) = value 29 | i += 1 30 | } 31 | res 32 | } 33 | 34 | def labelSort(arr: Array[Double], st: Int, num: Int): Array[Double] = { 35 | var res = new Array[Double](num) 36 | for(i <- 0 until num){ 37 | res(i) = arr(st + i) 38 | } 39 | res.sortWith(_ > _) 40 | res 41 | } 42 | 43 | def FillSigmoidTable(sigmoidParam: Double = 1.0): Array[Double]= { 44 | // minScore is such that 2*sigmoidParam*score is < expAsymptote if score < minScore 45 | _minScore = _expAsymptote / sigmoidParam / 2 46 | _maxScore = -_minScore 47 | 48 | var _sigmoidTable = new Array[Double](_sigmoidBins) 49 | for (i <- 0 until _sigmoidBins) { 50 | var score: Double = (_maxScore - _minScore) / _sigmoidBins * i + _minScore 51 | 52 | _sigmoidTable(i) = 53 | if (score > 0.0) 54 | 2.0 - 2.0 / (1.0 + scala.math.exp(-2.0 * sigmoidParam * score)) 55 | else 56 | 2.0 / (1.0 + scala.math.exp(2.0 * sigmoidParam * score)) 57 | } 58 | _scoreToSigmoidTableFactor = _sigmoidBins / (_maxScore - _minScore) 59 | _minSigmoid = _sigmoidTable(0) 60 | _maxSigmoid = _sigmoidTable.last 61 | _sigmoidTable 62 | } 63 | 64 | def GetDerivatives_lambda_weight( 65 | numDocuments: Int, begin: Int, 66 | aPermutation: Array[Int], aLabels: Array[Short], 67 | aScores: Array[Double], aLambdas: Array[Double], aWeights: Array[Double], 68 | aDiscount: Array[Double], aGainLabels: Array[Double], inverseMaxDCG: Double, 69 | asigmoidTable: Array[Double], minScore: Double, maxScore: Double, scoreToSigmoidTableFactor: Double, 70 | aSecondaryGains: Array[Double], secondaryMetricShare: Double = 0.0, secondaryExclusive: Boolean = false, 
secondaryInverseMaxDCG: Double = 0.2, 71 | costFunctionParam: Char = 'c', distanceWeight2: Boolean = false, minDoubleValue: Double = 0.01, 72 | alphaRisk: Double = 0.2, baselineVersusCurrentDcg: Double = 0.1) { 73 | // These arrays are shared among many threads, "begin" is the offset by which all arrays are indexed. 74 | // So we shift them all here to avoid having to add 'begin' to every pointer below. 75 | //val pLabels = begin 76 | //val pScores = begin 77 | //val pLambdas = begin 78 | //val pWeights = begin 79 | //val pGainLabels = begin 80 | //println("here0") 81 | //var aLambdas = new Array[Double](aLabels.length) 82 | //var aWeights = new Array[Double](aLabels.length) 83 | 84 | var pSecondaryGains = 0 85 | 86 | if (secondaryMetricShare != 0) 87 | pSecondaryGains = begin 88 | 89 | var bestScore = aScores(aPermutation(0)) 90 | 91 | var worstIndexToConsider = numDocuments - 1 92 | 93 | while (worstIndexToConsider > 0 && aScores(aPermutation(worstIndexToConsider)) == minDoubleValue) { 94 | worstIndexToConsider -= 1 95 | } 96 | var worstScore = aScores(aPermutation(worstIndexToConsider)) 97 | 98 | var lambdaSum = 0.0 99 | 100 | // Should we still run the calculation on those pairs which are ostensibly the same? 101 | var pairSame: Boolean = secondaryMetricShare != 0.0 102 | 103 | // Did not help to use pointer match on pPermutation[i] 104 | for (i <- 0 until numDocuments) 105 | { 106 | //println("here1") 107 | var high = begin + aPermutation(i) 108 | // We are going to loop through all pairs where label[high] > label[low]. If label[high] is 0, it can't be larger 109 | // If score[high] is Double.MinValue, it's being discarded by shifted NDCG 110 | //println("aLabels(high)", aLabels(high), "aScores(high)", aScores(high), "minDoubleValue", minDoubleValue, "pairSame", pairSame) 111 | if (!((aLabels(high) == 0 && !pairSame) || aScores(high) == minDoubleValue)) { // These variables are all looked up just once per loop of 'i', so do it here. 112 | 113 | var gainLabelHigh = aGainLabels(high) 114 | var labelHigh = aLabels(high) 115 | var scoreHigh = aScores(high) 116 | var discountI = aDiscount(i) 117 | // These variables will store the accumulated lambda and weight difference for high, which saves time 118 | var deltaLambdasHigh: Double = 0 119 | var deltaWeightsHigh: Double = 0 120 | 121 | //The below is effectively: for (int j = 0; j < numDocuments; ++j) 122 | var aaDiscountJ: Array[Double] = aDiscount 123 | var aaPermutationJ: Array[Int] = aPermutation 124 | 125 | for (j <- 0 until numDocuments) { 126 | // only consider pairs with different labels, where "high" has a higher label than "low" 127 | // If score[low] is Double.MinValue, it's being discarded by shifted NDCG 128 | var low = begin + aaPermutationJ(j) 129 | var flag = 130 | if (pairSame) labelHigh < aLabels(low) 131 | else 132 | labelHigh <= aLabels(low) 133 | if (!(flag || aScores(low) == minDoubleValue)) { 134 | var scoreHighMinusLow = scoreHigh - aScores(low) 135 | if (!(secondaryMetricShare == 0.0 && labelHigh == aLabels(low) && scoreHighMinusLow <= 0)) { 136 | 137 | //println("labelHigh", labelHigh, "aLabels(low)", aLabels(low), "scoreHighMinusLow", scoreHighMinusLow) 138 | var dcgGap = gainLabelHigh - aGainLabels(low) 139 | var currentInverseMaxDCG = inverseMaxDCG * (1.0 - secondaryMetricShare) 140 | 141 | // Handle risk w.r.t. baseline. 
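// Note (sketch): pairedDiscount below is the absolute difference of the two documents' position
// discounts; combined with dcgGap (the gain difference) and currentInverseMaxDCG it later forms
// deltaNDCGP = dcgGap * pairedDiscount * currentInverseMaxDCG, the NDCG change from swapping the pair.
// When alphaRisk > 0, dcgGap is additionally inflated by alphaRisk * risk, mirroring the
// baseline-risk handling in DerivativeCalculator.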
142 | var pairedDiscount = (discountI - aaDiscountJ(j)).abs 143 | if (alphaRisk > 0) { 144 | var risk: Double = 0.0 145 | var baselineDenorm: Double = baselineVersusCurrentDcg / pairedDiscount 146 | if (baselineVersusCurrentDcg > 0) { 147 | // The baseline is currently higher than the model. 148 | // If we're ranked incorrectly, we can only reduce risk only as much as the baseline current DCG. 149 | risk = 150 | if (scoreHighMinusLow <= 0 && dcgGap > baselineDenorm) baselineDenorm 151 | else 152 | dcgGap 153 | } else if (scoreHighMinusLow > 0) { 154 | // The baseline is currently lower, but this pair is ranked correctly. 155 | risk = baselineDenorm + dcgGap 156 | } 157 | if (risk > 0) { 158 | dcgGap += alphaRisk * risk 159 | } 160 | } 161 | 162 | var sameLabel: Boolean = labelHigh == aLabels(low) 163 | 164 | // calculate the lambdaP for this pair by looking it up in the lambdaTable (computed in LambdaMart.FillLambdaTable) 165 | var lambdaP = 0.0 166 | if (scoreHighMinusLow <= minScore) 167 | lambdaP = asigmoidTable(0) 168 | else if (scoreHighMinusLow >= maxScore) lambdaP = asigmoidTable(asigmoidTable.length - 1) 169 | else lambdaP = asigmoidTable(((scoreHighMinusLow - minScore) * scoreToSigmoidTableFactor).toInt) 170 | 171 | 172 | var weightP = lambdaP * (2.0 - lambdaP) 173 | 174 | if (!(secondaryMetricShare != 0.0 && (sameLabel || currentInverseMaxDCG == 0.0) && aSecondaryGains(high) <= aSecondaryGains(low))) { 175 | if (secondaryMetricShare != 0.0) { 176 | if (sameLabel || currentInverseMaxDCG == 0.0) { 177 | // We should use the secondary metric this time. 178 | dcgGap = aSecondaryGains(high) - aSecondaryGains(low) 179 | currentInverseMaxDCG = secondaryInverseMaxDCG * secondaryMetricShare 180 | sameLabel = false 181 | } else if (!secondaryExclusive && aSecondaryGains(high) > aSecondaryGains(low)) { 182 | var sIDCG = secondaryInverseMaxDCG * secondaryMetricShare 183 | dcgGap = dcgGap / sIDCG + (aSecondaryGains(high) - aSecondaryGains(low)) / currentInverseMaxDCG 184 | currentInverseMaxDCG *= sIDCG 185 | } 186 | } 187 | //println("here2") 188 | //printf("%d-%d : gap %g, currentinv %g\n", high, low, (float)dcgGap, (float)currentInverseMaxDCG); fflush(stdout); 189 | 190 | // calculate the deltaNDCGP for this pair 191 | var deltaNDCGP = dcgGap * pairedDiscount * currentInverseMaxDCG 192 | 193 | // apply distanceWeight2 only to regular pairs 194 | if (!sameLabel && distanceWeight2 && bestScore != worstScore) { 195 | deltaNDCGP /= (.01 + (aScores(high) - aScores(low)).abs) 196 | } 197 | //println("lambda", lambdaP * deltaNDCGP, "deltaNDCGP", deltaNDCGP, "dcgGap", dcgGap, "pairedDiscount", pairedDiscount, "currentInverseMaxDCG", currentInverseMaxDCG) 198 | // update lambdas and weights 199 | deltaLambdasHigh += lambdaP * deltaNDCGP 200 | deltaWeightsHigh += weightP * deltaNDCGP 201 | aLambdas(low) -= lambdaP * deltaNDCGP 202 | aWeights(low) += weightP * deltaNDCGP 203 | 204 | lambdaSum += 2 * lambdaP * deltaNDCGP 205 | } 206 | } 207 | } 208 | } 209 | //Finally, add the values for the high part of the pair that we accumulated across all the low parts 210 | 211 | aLambdas(high) += deltaLambdasHigh 212 | aWeights(high) += deltaWeightsHigh 213 | 214 | //for(i <- 0 until numDocuments) println(aLambdas(begin + i), aWeights(begin + i)) 215 | } 216 | } 217 | (aLambdas, aWeights) 218 | } 219 | } 220 | 221 | 222 | /***** 223 | object Derivate { 224 | def main(args: Array[String]){ 225 | val numDocuments = 5; val begin = 0 226 | val aPermutation = Array(1, 4, 3, 4, 2); val aLabels: Array[Short] = Array(1, 
2, 3, 4, 5) 227 | val aScores = Array(1.0, 3.0, 8.0, 15.0, 31.0) 228 | val aDiscount = Array(0.2, 0.5, 0.7, 0.8, 0.9) 229 | val inverseMaxDCG = 0.01 230 | val aGainLabels = Array(0.3, 0.4, 0.5, 0.8, 0.3) 231 | val aSecondaryGains = Array(0.3, 0.4, 0.5, 0.8, 0.3); val asigmoidTable =GetDerivatives.FillSigmoidTable() 232 | val minScore = 0.08; val maxScore = 0.2 233 | val scoreToSigmoidTableFactor = 4 234 | 235 | GetDerivatives.GetDerivatives_lambda_weight( 236 | numDocuments, begin, 237 | aPermutation, aLabels, 238 | aScores, 239 | aDiscount, aGainLabels, inverseMaxDCG, 240 | asigmoidTable, minScore, maxScore, scoreToSigmoidTableFactor, aSecondaryGains 241 | ) 242 | } 243 | } 244 | *****/ -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/Histogram.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.tree.model 2 | 3 | class Histogram(val numBins: Int) { 4 | private val _counts = new Array[Double](numBins) 5 | private val _scores = new Array[Double](numBins) 6 | private val _squares = new Array[Double](numBins) 7 | private val _scoreWeights = new Array[Double](numBins) 8 | 9 | @inline def counts = _counts 10 | 11 | @inline def scores = _scores 12 | 13 | @inline def squares = _squares 14 | 15 | @inline def scoreWeights = _scoreWeights 16 | 17 | def weightedUpdate(bin: Int, score: Double, scoreWeight: Double, weight: Double = 1.0) = { 18 | _counts(bin) += weight 19 | _scores(bin) += score * weight 20 | _squares(bin) += score * score * weight 21 | _scoreWeights(bin) += scoreWeight 22 | } 23 | def update(bin: Int, sampleWeight: Double, score: Double, scoreWeight: Double) = { 24 | _counts(bin) += sampleWeight 25 | _scores(bin) += score 26 | _squares(bin) += score * score 27 | _scoreWeights(bin) += scoreWeight 28 | } 29 | 30 | def cumulateLeft() = { 31 | var bin = 1 32 | while (bin < numBins) { 33 | _counts(bin) += _counts(bin-1) 34 | _scores(bin) += _scores(bin-1) 35 | _squares(bin) += _squares(bin-1) 36 | _scoreWeights(bin) += _scoreWeights(bin-1) 37 | bin += 1 38 | } 39 | this 40 | } 41 | 42 | def cumulate(info: NodeInfoStats, defaultBin: Int)={ 43 | // cumulate from right to left 44 | var bin = numBins-2 45 | var binRight = 0 46 | while (bin > defaultBin) { 47 | binRight = bin+1 48 | _counts(bin) += _counts(binRight) 49 | _scores(bin) += _scores(binRight) 50 | _squares(bin) += _squares(binRight) 51 | _scoreWeights(bin) += _scoreWeights(binRight) 52 | bin -= 1 53 | } 54 | 55 | if(defaultBin!=0){ 56 | _counts(0)=info.sumCount-_counts(0) 57 | _scores(0)=info.sumScores-_scores(0) 58 | _squares(0)=info.sumSquares-_squares(0) 59 | _scoreWeights(0)=info.sumScoreWeights-_scoreWeights(0) 60 | 61 | bin = 1 62 | var binLeft = 0 63 | while(bin0){ 74 | _counts(bin)=_counts(binLeft) 75 | _scores(bin)=_scores(binLeft) 76 | _squares(bin)=_squares(binLeft) 77 | _scoreWeights(bin)=_scoreWeights(binLeft) 78 | bin-=1 79 | binLeft-=1 80 | } 81 | } 82 | _counts(0)=info.sumCount 83 | _scores(0)=info.sumScores 84 | _squares(0)=info.sumSquares 85 | _scoreWeights(0)=info.sumScoreWeights 86 | 87 | this 88 | 89 | } 90 | } 91 | 92 | class NodeInfoStats(var sumCount: Int, 93 | var sumScores: Double, 94 | var sumSquares: Double, 95 | var sumScoreWeights: Double)extends Serializable { 96 | 97 | override def toString = s"NodeInfoStats( sumCount = $sumCount, sumTarget = $sumScores, sumSquares = $sumSquares, sumScoreWeight = $sumScoreWeights)" 98 | 99 | def canEqual(other: Any): 
Boolean = other.isInstanceOf[NodeInfoStats] 100 | 101 | override def equals(other: Any): Boolean = other match { 102 | case that: NodeInfoStats => 103 | (that canEqual this) && 104 | sumCount == that.sumCount && 105 | sumScores == that.sumScores && 106 | sumSquares == that.sumSquares && 107 | sumScoreWeights == that.sumScoreWeights 108 | case _ => false 109 | } 110 | 111 | override def hashCode(): Int = { 112 | val state = Seq(sumCount, sumScores, sumSquares, sumScoreWeights) 113 | state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/SplitInfo.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.tree.model 2 | 3 | import org.apache.spark.mllib.tree.configuration.FeatureType 4 | 5 | class SplitInfo(feature: Int, threshold: Double) 6 | extends Split(feature, threshold, FeatureType.Continuous, List()) 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/ensemblemodels/GradientBoostedDecisionTreesModel.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.tree.model.ensemblemodels 19 | 20 | import com.github.fommil.netlib.BLAS.{getInstance => blas} 21 | import org.apache.spark.annotation.Experimental 22 | import org.apache.spark.api.java.JavaRDD 23 | import org.apache.spark.mllib.linalg.Vector 24 | import org.apache.spark.mllib.regression.LabeledPoint 25 | import org.apache.spark.mllib.tree.config.Algo 26 | import org.apache.spark.mllib.tree.config.Algo._ 27 | import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ 28 | import org.apache.spark.mllib.tree.loss.Loss 29 | import org.apache.spark.mllib.tree.model.opdtmodel.OptimizedDecisionTreeModel 30 | import org.apache.spark.mllib.util.{Loader, Saveable} 31 | import org.apache.spark.rdd.RDD 32 | import org.apache.spark.sql.SQLContext 33 | import org.apache.spark.util.Utils 34 | import org.apache.spark.{Logging, SparkContext} 35 | import org.json4s.JsonDSL._ 36 | import org.json4s._ 37 | import org.json4s.jackson.JsonMethods._ 38 | 39 | import scala.collection.mutable 40 | 41 | /** 42 | * :: Experimental :: 43 | * Represents a gradient boosted trees model. 
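 * An illustrative usage sketch; `trainedTrees` and the per-tree weight of 0.1 below are
 * placeholders assumed to come from a prior training run:
 * {{{
 *   import org.apache.spark.mllib.linalg.Vectors
 *   import org.apache.spark.mllib.tree.config.Algo
 *
 *   val model = new GradientBoostedDecisionTreesModel(
 *     Algo.Regression, trainedTrees, Array.fill(trainedTrees.length)(0.1))
 *   val score = model.predict(Vectors.dense(0.5, 1.2, 3.0))
 * }}}
 *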
44 | * 45 | * @param algo algorithm for the ensemble model, either Classification or Regression 46 | * @param trees tree ensembles 47 | * @param treeWeights tree ensemble weights 48 | */ 49 | @Experimental 50 | class GradientBoostedDecisionTreesModel( 51 | override val algo: Algo, 52 | override val trees: Array[OptimizedDecisionTreeModel], 53 | override val treeWeights: Array[Double]) 54 | extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) 55 | with Saveable { 56 | 57 | require(trees.length == treeWeights.length) 58 | 59 | override def save(sc: SparkContext, path: String): Unit = { 60 | TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this, 61 | GradientBoostedDecisionTreesModel.SaveLoadV1_0.thisClassName) 62 | } 63 | 64 | /** 65 | * Method to compute error or loss for every iteration of gradient boosting. 66 | * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] 67 | * @param loss evaluation metric. 68 | * @return an array with index i having the losses or errors for the ensemble 69 | * containing the first i+1 trees 70 | */ 71 | def evaluateEachIteration( 72 | data: RDD[LabeledPoint], 73 | loss: Loss): Array[Double] = { 74 | 75 | val sc = data.sparkContext 76 | val remappedData = algo match { 77 | case Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) 78 | case _ => data 79 | } 80 | 81 | val numIterations = trees.length 82 | val evaluationArray = Array.fill(numIterations)(0.0) 83 | val localTreeWeights = treeWeights 84 | 85 | var predictionAndError = GradientBoostedDecisionTreesModel.computeInitialPredictionAndError( 86 | remappedData, localTreeWeights(0), trees(0), loss) 87 | 88 | evaluationArray(0) = predictionAndError.values.mean() 89 | 90 | val broadcastTrees = sc.broadcast(trees) 91 | (1 until numIterations).foreach { nTree => 92 | predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => 93 | val currentTree = broadcastTrees.value(nTree) 94 | val currentTreeWeight = localTreeWeights(nTree) 95 | iter.map { case (point, (pred, error)) => 96 | val newPred = pred + currentTree.predict(point.features) * currentTreeWeight 97 | val newError = loss.computeError(newPred, point.label) 98 | (newPred, newError) 99 | } 100 | } 101 | evaluationArray(nTree) = predictionAndError.values.mean() 102 | } 103 | 104 | broadcastTrees.unpersist() 105 | evaluationArray 106 | } 107 | 108 | override protected def formatVersion: String = GradientBoostedDecisionTreesModel.formatVersion 109 | } 110 | 111 | object GradientBoostedDecisionTreesModel extends Loader[GradientBoostedDecisionTreesModel] { 112 | 113 | /** 114 | * Compute the initial predictions and errors for a dataset for the first 115 | * iteration of gradient boosting. 116 | * @param data: training data. 117 | * @param initTreeWeight: learning rate assigned to the first tree. 118 | * @param initTree: first DecisionTreeModel. 119 | * @param loss: evaluation metric. 120 | * @return a RDD with each element being a zip of the prediction and error 121 | * corresponding to every sample. 
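 *
 * A sketch of how this pairs with [[updatePredictionError]] to track the loss as trees are
 * added (illustrative; `data`, `trees`, `weights` and `loss` are assumed to be provided by
 * the caller):
 * {{{
 *   var predError = GradientBoostedDecisionTreesModel.computeInitialPredictionAndError(
 *     data, weights(0), trees(0), loss)
 *   var i = 1
 *   while (i < trees.length) {
 *     predError = GradientBoostedDecisionTreesModel.updatePredictionError(
 *       data, predError, weights(i), trees(i), loss)
 *     i += 1
 *   }
 *   val finalMeanError = predError.values.mean()
 * }}}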
122 | */ 123 | def computeInitialPredictionAndError( 124 | data: RDD[LabeledPoint], 125 | initTreeWeight: Double, 126 | initTree: OptimizedDecisionTreeModel, 127 | loss: Loss): RDD[(Double, Double)] = { 128 | data.map { lp => 129 | val pred = initTreeWeight * initTree.predict(lp.features) 130 | val error = loss.computeError(pred, lp.label) 131 | (pred, error) 132 | } 133 | } 134 | 135 | /** 136 | * Update a zipped predictionError RDD 137 | * (as obtained with computeInitialPredictionAndError) 138 | * @param data: training data. 139 | * @param predictionAndError: predictionError RDD 140 | * @param treeWeight: Learning rate. 141 | * @param tree: Tree using which the prediction and error should be updated. 142 | * @param loss: evaluation metric. 143 | * @return a RDD with each element being a zip of the prediction and error 144 | * corresponding to each sample. 145 | */ 146 | def updatePredictionError( 147 | data: RDD[LabeledPoint], 148 | predictionAndError: RDD[(Double, Double)], 149 | treeWeight: Double, 150 | tree: OptimizedDecisionTreeModel, 151 | loss: Loss): RDD[(Double, Double)] = { 152 | 153 | val newPredError = data.zip(predictionAndError).mapPartitions { iter => 154 | iter.map { case (lp, (pred, error)) => 155 | val newPred = pred + tree.predict(lp.features) * treeWeight 156 | val newError = loss.computeError(newPred, lp.label) 157 | (newPred, newError) 158 | } 159 | } 160 | newPredError 161 | } 162 | 163 | private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion 164 | 165 | override def load(sc: SparkContext, path: String): GradientBoostedDecisionTreesModel = { 166 | val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path) 167 | val classNameV1_0 = SaveLoadV1_0.thisClassName 168 | (loadedClassName, version) match { 169 | case (className, "1.0") if className == classNameV1_0 => 170 | val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata) 171 | assert(metadata.combiningStrategy == Sum.toString) 172 | val trees = 173 | TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo) 174 | new GradientBoostedDecisionTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights) 175 | case _ => throw new Exception(s"GradientBoostedDecisionTreesModel.load did not recognize model" + 176 | s" with (className, format version): ($loadedClassName, $version). Supported:\n" + 177 | s" ($classNameV1_0, 1.0)") 178 | } 179 | } 180 | 181 | private object SaveLoadV1_0 { 182 | // Hard-code class name string in case it changes in the future 183 | def thisClassName: String = "org.apache.spark.mllib.tree.model.GradientBoostedDecisionTreesModel" 184 | } 185 | 186 | } 187 | 188 | /** 189 | * Represents a tree ensemble model. 190 | * 191 | * @param algo algorithm for the ensemble model, either Classification or Regression 192 | * @param trees tree ensembles 193 | * @param treeWeights tree ensemble weights 194 | * @param combiningStrategy strategy for combining the predictions, not used for regression. 
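 *
 * How the strategies combine (a reading of the implementation below, not additional API
 * surface): `Sum` returns the weighted sum of the individual tree predictions, `Average`
 * divides that sum by the total tree weight, and `Vote` picks the label with the largest
 * accumulated weight. The `Sum` case is equivalent to:
 * {{{
 *   trees.zip(treeWeights).map { case (tree, w) => tree.predict(features) * w }.sum
 * }}}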
195 | */ 196 | private[tree] sealed class TreeEnsembleModel( 197 | protected val algo: Algo, 198 | protected val trees: Array[OptimizedDecisionTreeModel], 199 | protected val treeWeights: Array[Double], 200 | protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { 201 | 202 | require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.") 203 | 204 | private val sumWeights = math.max(treeWeights.sum, 1e-15) 205 | 206 | /** 207 | * Predicts for a single data point using the weighted sum of ensemble predictions. 208 | * 209 | * @param features array representing a single data point 210 | * @return predicted category from the trained model 211 | */ 212 | private def predictBySumming(features: Vector): Double = { 213 | val treePredictions = trees.map(_.predict(features)) 214 | blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) 215 | } 216 | 217 | /** 218 | * Classifies a single data point based on (weighted) majority votes. 219 | */ 220 | private def predictByVoting(features: Vector): Double = { 221 | val votes = mutable.Map.empty[Int, Double] 222 | trees.view.zip(treeWeights).foreach { case (tree, weight) => 223 | val prediction = tree.predict(features).toInt 224 | votes(prediction) = votes.getOrElse(prediction, 0.0) + weight 225 | } 226 | votes.maxBy(_._2)._1 227 | } 228 | 229 | /** 230 | * Predict values for a single data point using the model trained. 231 | * 232 | * @param features array representing a single data point 233 | * @return predicted category from the trained model 234 | */ 235 | def predict(features: Vector): Double = { 236 | (algo, combiningStrategy) match { 237 | case (Regression, Sum) => 238 | predictBySumming(features) 239 | case (Regression, Average) => 240 | predictBySumming(features) / sumWeights 241 | case (Classification, Sum) => // binary classification 242 | val prediction = predictBySumming(features) 243 | // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. 244 | if (prediction > 0.0) 1.0 else 0.0 245 | case (Classification, Vote) => 246 | predictByVoting(features) 247 | case _ => 248 | throw new IllegalArgumentException( 249 | "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " + 250 | s"($algo, $combiningStrategy).") 251 | } 252 | } 253 | 254 | /** 255 | * Predict values for the given data set. 256 | * 257 | * @param features RDD representing data points to be predicted 258 | * @return RDD[Double] where each entry contains the corresponding prediction 259 | */ 260 | def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) 261 | 262 | /** 263 | * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]]. 264 | */ 265 | def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = { 266 | predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] 267 | } 268 | 269 | /** 270 | * Print a summary of the model. 271 | */ 272 | override def toString: String = { 273 | algo match { 274 | case Classification => 275 | s"TreeEnsembleModel classifier with $numTrees trees\n" 276 | case Regression => 277 | s"TreeEnsembleModel regressor with $numTrees trees\n" 278 | case _ => throw new IllegalArgumentException( 279 | s"TreeEnsembleModel given unknown algo parameter: $algo.") 280 | } 281 | } 282 | 283 | /** 284 | * Print the full model to a string. 
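 * The output starts with the `toString` summary and then emits one `[Evaluator:i]` header
 * block per tree (evaluator type and internal-node count), similar to the evaluator format
 * that `OptimizedDecisionTreeModel.sequence` writes to disk.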
285 | */ 286 | def toDebugString: String = { 287 | 288 | val header = toString + "\n" 289 | header + trees.zipWithIndex.map { case (tree, treeIndex) => 290 | val numOfInternalNodes = tree.topNode.internalNodes 291 | s" [Evaluator:$treeIndex]\n EvaluatorType=DecisionTree \n" + 292 | s" NumInernalNodes=$numOfInternalNodes \n" + 293 | s" " 294 | // tree.topNode.subtreeToString(4) 295 | }.fold("")(_ + _) 296 | } 297 | 298 | /** 299 | * Get number of trees in ensemble. 300 | */ 301 | def numTrees: Int = trees.length 302 | 303 | /** 304 | * Get total number of nodes, summed over all trees in the ensemble. 305 | */ 306 | def totalNumNodes: Int = trees.map(_.numNodes).sum 307 | } 308 | 309 | private[tree] object TreeEnsembleModel extends Logging { 310 | 311 | object SaveLoadV1_0 { 312 | 313 | import org.apache.spark.mllib.tree.model.opdtmodel.OptimizedDecisionTreeModel.SaveLoadV1_0.{TreeNodeData, constructTrees} 314 | 315 | def thisFormatVersion: String = "1.0" 316 | 317 | case class Metadata( 318 | algo: String, 319 | treeAlgo: String, 320 | combiningStrategy: String, 321 | treeWeights: Array[Double]) 322 | 323 | /** 324 | * Model data for model import/export. 325 | * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields 326 | * of nested fields; once that is possible, we can use something like: 327 | * case class EnsembleNodeData(treeId: Int, node: NodeData), 328 | * where NodeData is from DecisionTreeModel. 329 | */ 330 | case class EnsembleNodeData(treeId: Int, node: TreeNodeData) 331 | 332 | def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = { 333 | val sqlContext = new SQLContext(sc) 334 | import sqlContext.implicits._ 335 | 336 | // SPARK-6120: We do a hacky check here so users understand why save() is failing 337 | // when they run the ML guide example. 338 | // TODO: Fix this issue for real. 339 | val memThreshold = 512 340 | if (sc.isLocal) { 341 | val driverMemory = sc.getConf.getOption("spark.driver.memory") 342 | .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) 343 | .map(Utils.memoryStringToMb) 344 | .getOrElse(512) 345 | if (driverMemory <= memThreshold) { 346 | logWarning(s"$className.save() was called, but it may fail because of too little" + 347 | s" driver memory (${driverMemory}m)." + 348 | s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).") 349 | } 350 | } else { 351 | if (sc.executorMemory <= memThreshold) { 352 | logWarning(s"$className.save() was called, but it may fail because of too little" + 353 | s" executor memory (${sc.executorMemory}m)." + 354 | s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).") 355 | } 356 | } 357 | 358 | // Create JSON metadata. 359 | implicit val format = DefaultFormats 360 | val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString, 361 | model.combiningStrategy.toString, model.treeWeights) 362 | val metadata = compact(render( 363 | ("class" -> className) ~ ("version" -> thisFormatVersion) ~ 364 | ("metadata" -> Extraction.decompose(ensembleMetadata)))) 365 | sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) 366 | 367 | // Create Parquet data. 
368 | val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) => 369 | tree.topNode.subtreeIterator.toSeq.map(node => TreeNodeData(treeId, node)) 370 | }.toDF() 371 | logInfo("flag test") 372 | dataRDD.write.parquet(Loader.dataPath(path)) 373 | } 374 | 375 | /** 376 | * Read metadata from the loaded JSON metadata. 377 | */ 378 | def readMetadata(metadata: JValue): Metadata = { 379 | implicit val formats = DefaultFormats 380 | (metadata \ "metadata").extract[Metadata] 381 | } 382 | 383 | /** 384 | * Load trees for an ensemble, and return them in order. 385 | * @param path path to load the model from 386 | * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's 387 | * algorithm). 388 | */ 389 | def loadTrees( 390 | sc: SparkContext, 391 | path: String, 392 | treeAlgo: String): Array[OptimizedDecisionTreeModel] = { 393 | val datapath = Loader.dataPath(path) 394 | val sqlContext = new SQLContext(sc) 395 | val nodes = sqlContext.read.parquet(datapath).map(TreeNodeData.apply) 396 | val trees = constructTrees(nodes) 397 | trees.map(new OptimizedDecisionTreeModel(_, Algo.fromString(treeAlgo))) 398 | } 399 | } 400 | 401 | } 402 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/impurity/Impurity.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.mllib.tree.model.impurity 19 | 20 | import org.apache.spark.annotation.{DeveloperApi, Experimental} 21 | import org.apache.spark.mllib.tree.config.Strategy 22 | 23 | /** 24 | * :: Experimental :: 25 | * Trait for calculating information gain. 26 | * This trait is used for 27 | * (a) setting the impurity parameter in [[Strategy]] 28 | * (b) calculating impurity values from sufficient statistics. 
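 *
 * Example with the [[Variance]] implementation from this package (illustrative values):
 * {{{
 *   // sufficient statistics: instance count, sum of labels, sum of squared labels
 *   val impurity = Variance.calculate(4.0, 10.0, 30.0)  // (30 - 10*10/4) / 4 = 1.25
 * }}}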
29 | */ 30 | @Experimental 31 | trait Impurity extends Serializable { 32 | 33 | /** 34 | * :: DeveloperApi :: 35 | * information calculation for multiclass classification 36 | * 37 | * @param counts Array[Double] with counts for each label 38 | * @param totalCount sum of counts for all labels 39 | * @return information value, or 0 if totalCount = 0 40 | */ 41 | @DeveloperApi 42 | def calculate(counts: Array[Double], totalCount: Double): Double 43 | 44 | /** 45 | * :: DeveloperApi :: 46 | * information calculation for regression 47 | * 48 | * @param count number of instances 49 | * @param sum sum of labels 50 | * @param sumSquares summation of squares of the labels 51 | * @return information value, or 0 if count = 0 52 | */ 53 | @DeveloperApi 54 | def calculate(count: Double, sum: Double, sumSquares: Double): Double 55 | } 56 | 57 | /** 58 | * Interface for updating views of a vector of sufficient statistics, 59 | * in order to compute impurity from a sample. 60 | * Note: Instances of this class do not hold the data; they operate on views of the data. 61 | * 62 | * @param statsSize Length of the vector of sufficient statistics for one bin. 63 | */ 64 | private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { 65 | 66 | /** 67 | * Merge the stats from one bin into another. 68 | * 69 | * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. 70 | * @param offset Start index of stats for (node, feature, bin) which is modified by the merge. 71 | * @param otherOffset Start index of stats for (node, feature, other bin) which is not modified. 72 | */ 73 | def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = { 74 | var i = 0 75 | while (i < statsSize) { 76 | allStats(offset + i) += allStats(otherOffset + i) 77 | i += 1 78 | } 79 | } 80 | 81 | /** 82 | * Update stats for one (node, feature, bin) with the given label. 83 | * 84 | * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. 85 | * @param offset Start index of stats for this (node, feature, bin). 86 | */ 87 | def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double, weight: Double): Unit 88 | 89 | /** 90 | * Get an [[ImpurityCalculator]] for a (node, feature, bin). 91 | * 92 | * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. 93 | * @param offset Start index of stats for this (node, feature, bin). 94 | */ 95 | def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator 96 | 97 | } 98 | 99 | /** 100 | * Stores statistics for one (node, feature, bin) for calculating impurity. 101 | * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific 102 | * (node, feature, bin). 103 | * 104 | * @param stats Array of sufficient statistics for a (node, feature, bin). 105 | */ 106 | private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { 107 | 108 | /** 109 | * Make a deep copy of this [[ImpurityCalculator]]. 110 | */ 111 | def copy: ImpurityCalculator 112 | 113 | /** 114 | * Calculate the impurity from the stored sufficient statistics. 115 | */ 116 | def calculate(): Double 117 | 118 | /** 119 | * Add the stats from another calculator into this one, modifying and returning this calculator. 120 | */ 121 | def add(other: ImpurityCalculator): ImpurityCalculator = { 122 | require(stats.length == other.stats.length, 123 | s"Two ImpurityCalculator instances cannot be added with different counts sizes." 
+ 124 | s" Sizes are ${stats.length} and ${other.stats.length}.") 125 | var i = 0 126 | val len = other.stats.length 127 | while (i < len) { 128 | stats(i) += other.stats(i) 129 | i += 1 130 | } 131 | this 132 | } 133 | 134 | /** 135 | * Subtract the stats from another calculator from this one, modifying and returning this 136 | * calculator. 137 | */ 138 | def subtract(other: ImpurityCalculator): ImpurityCalculator = { 139 | require(stats.length == other.stats.length, 140 | s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." + 141 | s" Sizes are ${stats.length} and ${other.stats.length}.") 142 | var i = 0 143 | val len = other.stats.length 144 | while (i < len) { 145 | stats(i) -= other.stats(i) 146 | i += 1 147 | } 148 | this 149 | } 150 | 151 | /** 152 | * Number of data points accounted for in the sufficient statistics. 153 | */ 154 | def count: Long 155 | 156 | /** 157 | * Prediction which should be made based on the sufficient statistics. 158 | */ 159 | def predict: Double 160 | 161 | /** 162 | * Probability of the label given by [[predict]], or -1 if no probability is available. 163 | */ 164 | def prob(label: Double): Double = -1 165 | 166 | /** 167 | * Return the index of the largest array element. 168 | * Fails if the array is empty. 169 | */ 170 | protected def indexOfLargestArrayElement(array: Array[Double]): Int = { 171 | val result = array.foldLeft(-1, Double.MinValue, 0) { 172 | case ((maxIndex, maxValue, currentIndex), currentValue) => 173 | if (currentValue > maxValue) { 174 | (currentIndex, currentValue, currentIndex + 1) 175 | } else { 176 | (maxIndex, maxValue, currentIndex + 1) 177 | } 178 | } 179 | if (result._1 < 0) { 180 | throw new RuntimeException("ImpurityCalculator internal error:" + 181 | " indexOfLargestArrayElement failed") 182 | } 183 | result._1 184 | } 185 | 186 | } 187 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/impurity/Variance.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.mllib.tree.model.impurity 19 | 20 | import org.apache.spark.annotation.{DeveloperApi, Experimental} 21 | 22 | /** 23 | * :: Experimental :: 24 | * Class for calculating variance during regression 25 | */ 26 | @Experimental 27 | object Variance extends Impurity { 28 | 29 | /** 30 | * :: DeveloperApi :: 31 | * information calculation for multiclass classification 32 | * @param counts Array[Double] with counts for each label 33 | * @param totalCount sum of counts for all labels 34 | * @return information value, or 0 if totalCount = 0 35 | */ 36 | @DeveloperApi 37 | override def calculate(counts: Array[Double], totalCount: Double): Double = 38 | throw new UnsupportedOperationException("Variance.calculate") 39 | 40 | /** 41 | * :: DeveloperApi :: 42 | * variance calculation 43 | * @param count number of instances 44 | * @param sum sum of labels 45 | * @param sumSquares summation of squares of the labels 46 | * @return information value, or 0 if count = 0 47 | */ 48 | @DeveloperApi 49 | override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { 50 | if (count == 0) { 51 | return 0 52 | } 53 | val squaredLoss = sumSquares - (sum * sum) / count 54 | squaredLoss / count 55 | } 56 | 57 | /** 58 | * Get this impurity instance. 59 | * This is useful for passing impurity parameters to a Strategy in Java. 60 | */ 61 | def instance: this.type = this 62 | 63 | } 64 | 65 | /** 66 | * Class for updating views of a vector of sufficient statistics, 67 | * in order to compute impurity from a sample. 68 | * Note: Instances of this class do not hold the data; they operate on views of the data. 69 | */ 70 | private[tree] class VarianceAggregator() 71 | extends ImpurityAggregator(statsSize = 4) with Serializable { 72 | 73 | /** 74 | * Update stats for one (node, feature, bin) with the given label. 75 | * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. 76 | * @param offset Start index of stats for this (node, feature, bin). 77 | */ 78 | def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double, weight: Double): Unit = { 79 | allStats(offset) += instanceWeight 80 | allStats(offset + 1) += instanceWeight * label 81 | allStats(offset + 2) += instanceWeight * label * label 82 | allStats(offset + 3) += weight 83 | } 84 | 85 | /** 86 | * Get an [[ImpurityCalculator]] for a (node, feature, bin). 87 | * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. 88 | * @param offset Start index of stats for this (node, feature, bin). 89 | */ 90 | def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { 91 | new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) 92 | } 93 | 94 | } 95 | 96 | /** 97 | * Stores statistics for one (node, feature, bin) for calculating impurity. 98 | * Unlike [[GiniAggregator]], this class stores its own data and is for a specific 99 | * (node, feature, bin). 100 | * @param stats Array of sufficient statistics for a (node, feature, bin). 101 | */ 102 | private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { 103 | 104 | require(stats.size == 4, 105 | s"VarianceCalculator requires sufficient statistics array stats to be of length 4," + 106 | s" but was given array of length ${stats.size}.") 107 | 108 | /** 109 | * Make a deep copy of this [[ImpurityCalculator]]. 
110 | */ 111 | def copy: VarianceCalculator = new VarianceCalculator(stats.clone()) 112 | 113 | /** 114 | * Calculate the impurity from the stored sufficient statistics. 115 | */ 116 | def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2)) 117 | 118 | /** 119 | * Number of data points accounted for in the sufficient statistics. 120 | */ 121 | def count: Long = stats(0).toLong 122 | 123 | /** 124 | * Prediction which should be made based on the sufficient statistics. 125 | */ 126 | def predict: Double = if (stats(3) == 0) { 127 | 0 128 | } else { 129 | stats(1) / stats(3) 130 | } 131 | 132 | override def toString: String = { 133 | s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)}, sumWeight = ${stats(3)})" 134 | } 135 | 136 | } 137 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/informationgainstats/InformationGainStats.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.mllib.tree.model.informationgainstats 19 | 20 | import org.apache.spark.annotation.DeveloperApi 21 | import org.apache.spark.mllib.tree.model.predict.Predict 22 | 23 | /** 24 | * :: DeveloperApi :: 25 | * Information gain statistics for each split 26 | * @param gain information gain value 27 | * @param impurity current node impurity 28 | * @param leftImpurity left node impurity 29 | * @param rightImpurity right node impurity 30 | * @param leftPredict left node predict 31 | * @param rightPredict right node predict 32 | */ 33 | @DeveloperApi 34 | class InformationGainStats( 35 | val gain: Double, 36 | val impurity: Double, 37 | val leftImpurity: Double, 38 | val rightImpurity: Double, 39 | val leftPredict: Predict, 40 | val rightPredict: Predict) extends Serializable { 41 | 42 | override def toString: String = { 43 | s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + 44 | s"right impurity = $rightImpurity" 45 | } 46 | 47 | override def equals(o: Any): Boolean = o match { 48 | case other: InformationGainStats => 49 | gain == other.gain && 50 | impurity == other.impurity && 51 | leftImpurity == other.leftImpurity && 52 | rightImpurity == other.rightImpurity && 53 | leftPredict == other.leftPredict && 54 | rightPredict == other.rightPredict 55 | 56 | case _ => false 57 | } 58 | 59 | override def hashCode: Int = { 60 | com.google.common.base.Objects.hashCode( 61 | gain: java.lang.Double, 62 | impurity: java.lang.Double, 63 | leftImpurity: java.lang.Double, 64 | rightImpurity: java.lang.Double, 65 | leftPredict, 66 | rightPredict) 67 | } 68 | } 69 | 70 | 71 | private[tree] object InformationGainStats { 72 | /** 73 | * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to 74 | * denote that current split doesn't satisfies minimum info gain or 75 | * minimum number of instances per node. 76 | */ 77 | val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, 78 | new Predict(0.0, 0.0), new Predict(0.0, 0.0)) 79 | } 80 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/node/Node.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.mllib.tree.model.node 19 | 20 | import org.apache.spark.annotation.DeveloperApi 21 | import org.apache.spark.Logging 22 | import org.apache.spark.mllib.tree.configuration.FeatureType._ 23 | import org.apache.spark.mllib.linalg.Vector 24 | import org.apache.spark.mllib.tree.model.predict.Predict 25 | import org.apache.spark.mllib.tree.model.Split 26 | import org.apache.spark.mllib.tree.model.informationgainstats.InformationGainStats 27 | import scala.beans.BeanProperty 28 | 29 | /** 30 | * :: DeveloperApi :: 31 | * Node in a decision tree. 32 | * 33 | * About node indexing: 34 | * Nodes are indexed from 1. Node 1 is the root; nodes 2, 3 are the left, right children. 35 | * Node index 0 is not used. 36 | * 37 | * @param id integer node id, from 1 38 | * @param predict predicted value at the node 39 | * @param impurity current node impurity 40 | * @param isLeaf whether the node is a leaf 41 | * @param split split to calculate left and right nodes 42 | * @param leftNode left child 43 | * @param rightNode right child 44 | * @param stats information gain stats 45 | */ 46 | // @DeveloperApi 47 | class Node ( 48 | // @BeanProperty 49 | var id: Int, 50 | var predict: Predict, 51 | var impurity: Double, 52 | var isLeaf: Boolean, 53 | var split: Option[Split], 54 | var leftNode: Option[Node], 55 | var rightNode: Option[Node], 56 | var stats: Option[InformationGainStats]) extends Serializable with Logging { 57 | 58 | def setId(NewInt: Int):Unit = {id = NewInt} 59 | 60 | 61 | override def toString: String = { 62 | s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " + 63 | s"split = $split, stats = $stats" 64 | } 65 | 66 | 67 | /** 68 | * build the left node and right nodes if not leaf 69 | * @param nodes array of nodes 70 | */ 71 | @deprecated("build should no longer be used since trees are constructed on-the-fly in training", 72 | "1.2.0") 73 | def build(nodes: Array[Node]): Unit = { 74 | logDebug("building node " + id + " at level " + Node.indexToLevel(id)) 75 | logDebug("id = " + id + ", split = " + split) 76 | logDebug("stats = " + stats) 77 | logDebug("predict = " + predict) 78 | logDebug("impurity = " + impurity) 79 | if (!isLeaf) { 80 | leftNode = Some(nodes(Node.leftChildIndex(id))) 81 | rightNode = Some(nodes(Node.rightChildIndex(id))) 82 | leftNode.get.build(nodes) 83 | rightNode.get.build(nodes) 84 | } 85 | } 86 | 87 | 88 | /** 89 | * predict value if node is not leaf 90 | * @param features feature value 91 | * @return predicted value 92 | */ 93 | def predict(features: Vector) : Double = { 94 | if (isLeaf) { 95 | predict.predict 96 | } else { 97 | if (split.get.featureType == Continuous) { 98 | if (features(split.get.feature) <= split.get.threshold) { 99 | leftNode.get.predict(features) 100 | } else { 101 | rightNode.get.predict(features) 102 | } 103 | } else { 104 | if (split.get.categories.contains(features(split.get.feature))) { 105 | leftNode.get.predict(features) 106 | } else { 107 | rightNode.get.predict(features) 108 | } 109 | } 110 | } 111 | } 112 | 113 | /** 114 | * Returns a deep copy of the subtree rooted at this node. 
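 * Note that, as implemented below, only the `Node` objects themselves are copied; the
 * `Predict`, `Split` and `InformationGainStats` instances are shared with the original
 * subtree.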
115 | */ 116 | private[tree] def deepCopy(): Node = { 117 | val leftNodeCopy = if (leftNode.isEmpty) { 118 | None 119 | } else { 120 | Some(leftNode.get.deepCopy()) 121 | } 122 | val rightNodeCopy = if (rightNode.isEmpty) { 123 | None 124 | } else { 125 | Some(rightNode.get.deepCopy()) 126 | } 127 | new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) 128 | } 129 | 130 | /** 131 | * Get the number of nodes in tree below this node, including leaf nodes. 132 | * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. 133 | */ 134 | private[tree] def numDescendants: Int = if (isLeaf) { 135 | 0 136 | } else { 137 | 2 + leftNode.get.numDescendants + rightNode.get.numDescendants 138 | } 139 | 140 | /** 141 | * Get depth of tree from this node. 142 | * E.g.: Depth 0 means this is a leaf node. 143 | */ 144 | private[tree] def subtreeDepth: Int = if (isLeaf) { 145 | 0 146 | } else { 147 | 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) 148 | } 149 | 150 | 151 | /* Get the number of internal nodes */ 152 | private[tree] def internalNodes: Int = if (isLeaf) { 153 | 0 154 | } else { 155 | 1 + leftNode.get.internalNodes + rightNode.get.internalNodes 156 | } 157 | 158 | 159 | 160 | 161 | /** 162 | * Recursive print function. 163 | * @param indentFactor The number of spaces to add to each level of indentation. 164 | */ 165 | private[tree] def subtreeToString(indentFactor: Int = 0): String = { 166 | 167 | def splitToString(split: Split, left: Boolean): String = { 168 | split.featureType match { 169 | case Continuous => if (left) { 170 | s"(feature ${split.feature} <= ${split.threshold})" 171 | } else { 172 | s"(feature ${split.feature} > ${split.threshold})" 173 | } 174 | case Categorical => if (left) { 175 | s"(feature ${split.feature} in ${split.categories.mkString("{", ",", "}")})" 176 | } else { 177 | s"(feature ${split.feature} not in ${split.categories.mkString("{", ",", "}")})" 178 | } 179 | } 180 | } 181 | val prefix: String = " " * indentFactor 182 | if (isLeaf) { 183 | prefix + s"Predict: ${predict.predict}\n" 184 | } else { 185 | prefix + s"If ${splitToString(split.get, left = true)}\n" + 186 | leftNode.get.subtreeToString(indentFactor + 1) + 187 | prefix + s"Else ${splitToString(split.get, left = false)}\n" + 188 | rightNode.get.subtreeToString(indentFactor + 1) 189 | } 190 | } 191 | 192 | /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */ 193 | private[tree] def subtreeIterator: Iterator[Node] = { 194 | Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++ 195 | rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty) 196 | } 197 | } 198 | 199 | 200 | 201 | // private[spark] object Node { 202 | object Node { 203 | 204 | /** 205 | * Return a node with the given node id (but nothing else set). 206 | */ 207 | def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, 208 | false, None, None, None, None) 209 | 210 | /** 211 | * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. 212 | * This is used in `DecisionTree.findBestSplits` to construct child nodes 213 | * after finding the best splits for parent nodes. 214 | * Other fields are set at next level. 
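 * Example (illustrative only):
 * {{{
 *   val leaf = Node(7, new Predict(0.42), impurity = 0.0, isLeaf = true)
 *   val placeholder = Node.emptyNode(nodeIndex = 3)
 * }}}
 *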
215 | * @param nodeIndex integer node id, from 1 216 | * @param predict predicted value at the node 217 | * @param impurity current node impurity 218 | * @param isLeaf whether the node is a leaf 219 | * @return new node instance 220 | */ 221 | def apply( 222 | nodeIndex: Int, 223 | predict: Predict, 224 | impurity: Double, 225 | isLeaf: Boolean): Node = { 226 | new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) 227 | } 228 | 229 | /** 230 | * Return the index of the left child of this node. 231 | */ 232 | def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1 233 | 234 | /** 235 | * Return the index of the right child of this node. 236 | */ 237 | def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1 238 | 239 | /** 240 | * Get the parent index of the given node, or 0 if it is the root. 241 | */ 242 | def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1 243 | 244 | /** 245 | * Return the level of a tree which the given node is in. 246 | */ 247 | def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) { 248 | throw new IllegalArgumentException(s"0 is not a valid node index.") 249 | } else { 250 | java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex)) 251 | } 252 | 253 | /** 254 | * Returns true if this is a left child. 255 | * Note: Returns false for the root. 256 | */ 257 | def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0 258 | 259 | /** 260 | * Return the maximum number of nodes which can be in the given level of the tree. 261 | * @param level Level of tree (0 = root). 262 | */ 263 | def maxNodesInLevel(level: Int): Int = 1 << level 264 | 265 | /** 266 | * Return the index of the first node in the given level. 267 | * @param level Level of tree (0 = root). 268 | */ 269 | def startIndexInLevel(level: Int): Int = 1 << level 270 | 271 | /** 272 | * Traces down from a root node to get the node with the given node index. 273 | * This assumes the node exists. 
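 *
 * Example (illustrative; `rootNode` is assumed to have the referenced children):
 * {{{
 *   Node.getNode(1, rootNode)  // rootNode itself
 *   Node.getNode(2, rootNode)  // rootNode.leftNode.get
 *   Node.getNode(5, rootNode)  // rootNode.leftNode.get.rightNode.get
 * }}}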
274 | */ 275 | def getNode(nodeIndex: Int, rootNode: Node): Node = { 276 | var tmpNode: Node = rootNode 277 | var levelsToGo = indexToLevel(nodeIndex) 278 | while (levelsToGo > 0) { 279 | if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { 280 | tmpNode = tmpNode.leftNode.get 281 | } else { 282 | tmpNode = tmpNode.rightNode.get 283 | } 284 | levelsToGo -= 1 285 | } 286 | tmpNode 287 | } 288 | } 289 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/nodePredict.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.tree.model 2 | 3 | import org.apache.spark.broadcast.Broadcast 4 | 5 | object nodePredict { 6 | 7 | def predict(nodeId: Int, 8 | nodeIdTracker: Broadcast[Array[Byte]], 9 | lambdas: Broadcast[Array[Double]], 10 | weights: Broadcast[Array[Double]]): Double = { 11 | var lambdaSum = 0.0 12 | var weightSum = 0.0 13 | 14 | var sampleIdx = 0 15 | while(sampleIdx < nodeIdTracker.value.length) { 16 | if(nodeId == nodeIdTracker.value(sampleIdx)){ 17 | lambdaSum += lambdas.value(sampleIdx) 18 | weightSum += weights.value(sampleIdx) 19 | } 20 | sampleIdx += 1 21 | } 22 | 23 | var leafValue = lambdaSum/weightSum 24 | leafValue 25 | } 26 | //def adjustPredict() 27 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/opdtmodel/OptimizedDecisionTreeModel.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.mllib.tree.model.opdtmodel 19 | 20 | import java.io.{File, FileOutputStream, PrintWriter} 21 | 22 | import org.apache.spark.annotation.Experimental 23 | import org.apache.spark.api.java.JavaRDD 24 | import org.apache.spark.mllib.linalg.Vector 25 | import org.apache.spark.mllib.tree.config.Algo 26 | import org.apache.spark.mllib.tree.config.Algo._ 27 | import org.apache.spark.mllib.tree.configuration.FeatureType 28 | import org.apache.spark.mllib.tree.model.Split 29 | import org.apache.spark.mllib.tree.model.informationgainstats.InformationGainStats 30 | import org.apache.spark.mllib.tree.model.node.Node 31 | import org.apache.spark.mllib.tree.model.predict.Predict 32 | import org.apache.spark.mllib.util.{Loader, Saveable} 33 | import org.apache.spark.rdd.RDD 34 | import org.apache.spark.sql.{DataFrame, Row, SQLContext} 35 | import org.apache.spark.util.Utils 36 | import org.apache.spark.{Logging, SparkContext} 37 | import org.json4s.JsonDSL._ 38 | import org.json4s._ 39 | import org.json4s.jackson.JsonMethods._ 40 | 41 | import scala.collection.mutable 42 | 43 | /** 44 | * :: Experimental :: 45 | * Decision tree model for classification or regression. 46 | * This model stores the decision tree structure and parameters. 47 | * @param topNode root node 48 | * @param algo algorithm type -- classification or regression 49 | */ 50 | @Experimental 51 | class OptimizedDecisionTreeModel(val topNode: Node, val algo: Algo, val expandTreeEnsemble: Boolean = false) 52 | extends Serializable with Saveable { 53 | type Lists = (List[String], List[Double], List[Double], List[Int], List[Int], List[Double], List[Double]) 54 | 55 | /** 56 | * Predict values for a single data point using the model trained. 57 | * 58 | * @param features array representing a single data point 59 | * @return Double prediction from the trained model 60 | */ 61 | def predict(features: Vector): Double = { 62 | topNode.predict(features) 63 | } 64 | 65 | /** 66 | * Predict values for the given data set using the model trained. 67 | * 68 | * @param features RDD representing data points to be predicted 69 | * @return RDD of predictions for each of the given data points 70 | */ 71 | def predict(features: RDD[Vector]): RDD[Double] = { 72 | features.map(x => predict(x)) 73 | } 74 | 75 | /** 76 | * Predict values for the given data set using the model trained. 77 | * 78 | * @param features JavaRDD representing data points to be predicted 79 | * @return JavaRDD of predictions for each of the given data points 80 | */ 81 | def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { 82 | predict(features.rdd) 83 | } 84 | 85 | /** 86 | * Get number of nodes in tree, including leaf nodes. 87 | */ 88 | def numNodes: Int = 10 89 | // { 90 | // 1 + topNode.numDescendants 91 | // } 92 | 93 | /** 94 | * Get depth of tree. 95 | * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes. 96 | */ 97 | def depth: Int = { 98 | 5 99 | // topNode.subtreeDepth 100 | } 101 | 102 | // def internalNodes(rootNode: Node): Int = { 103 | // if(rootNode.isLeaf == true){ 104 | 105 | // } 106 | // internalNodes(rootNode.leftNode)+internalNodes(rootNode.rightNode)+1 107 | // } 108 | 109 | /** 110 | * Print a summary of the model. 
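 * For a regression model the summary looks like
 * `OptimizedDecisionTreeModel regressor with 10 leaf nodes`; classification models print
 * `classifier` instead.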
111 | */ 112 | override def toString: String = algo match { 113 | case Classification => 114 | s"OptimizedDecisionTreeModel classifier with $numNodes leaf nodes" 115 | case Regression => 116 | s"OptimizedDecisionTreeModel regressor with $numNodes leaf nodes" 117 | case _ => throw new IllegalArgumentException( 118 | s"OptimizedDecisionTreeModel given unknown algo parameter: $algo.") 119 | } 120 | 121 | /** 122 | * Print the full model to a string. 123 | */ 124 | def toDebugString: String = { 125 | val header = toString + "\n" 126 | header + topNode.subtreeToString(2) 127 | } 128 | 129 | override def save(sc: SparkContext, path: String): Unit = { 130 | OptimizedDecisionTreeModel.SaveLoadV1_0.save(sc, path, this) 131 | } 132 | 133 | def reformatted: Lists = { 134 | val splitFeatures = new mutable.MutableList[String] 135 | val splitGains = new mutable.MutableList[Double] 136 | val gainPValues = new mutable.MutableList[Double] 137 | val lteChildren = new mutable.MutableList[Int] 138 | val gtChildren = new mutable.MutableList[Int] 139 | val thresholds = new mutable.MutableList[Double] 140 | val outputs = new mutable.MutableList[Double] 141 | 142 | var curNonLeafIdx = 0 143 | var curLeafIdx = 0 144 | val childIdx = (child: Node) => if (child.isLeaf) { 145 | curLeafIdx -= 1 146 | curLeafIdx 147 | } else { 148 | curNonLeafIdx += 1 149 | curNonLeafIdx 150 | } 151 | 152 | val q = new mutable.Queue[Node] 153 | q.enqueue(topNode) 154 | while (q.nonEmpty) { 155 | val node = q.dequeue() 156 | if (!node.isLeaf) { 157 | val split = node.split.get 158 | val stats = node.stats.get 159 | 160 | val offSet = if(expandTreeEnsemble) 2 else 1 161 | splitFeatures += s"I:${split.feature+offSet}" 162 | 163 | splitGains += stats.gain 164 | gainPValues += 0.0 165 | thresholds += split.threshold 166 | val left = node.leftNode.get 167 | val right = node.rightNode.get 168 | lteChildren += childIdx(left) 169 | gtChildren += childIdx(right) 170 | q.enqueue(left) 171 | q.enqueue(right) 172 | } else { 173 | outputs += node.predict.predict 174 | } 175 | } 176 | (splitFeatures.toList, splitGains.toList, gainPValues.toList, lteChildren.toList, gtChildren.toList, 177 | thresholds.toList, outputs.toList) 178 | } 179 | 180 | def sequence(path: String, model: OptimizedDecisionTreeModel, modelId: Int): Unit = { 181 | val (splitFeatures, splitGains, gainPValues, lteChildren, gtChildren, thresholds, outputs) = reformatted 182 | 183 | val pw = new PrintWriter(new FileOutputStream(new File(path), true)) 184 | if(1 == modelId) 185 | pw.write(s"\n") 186 | pw.write(s"[Evaluator:$modelId]\n") 187 | pw.write("EvaluatorType=DecisionTree\n") 188 | pw.write(s"NumInternalNodes=${topNode.internalNodes}\n") 189 | 190 | var str = splitFeatures.mkString("\t") 191 | pw.write(s"SplitFeatures=$str\n") 192 | str = splitGains.mkString("\t") 193 | pw.write(s"SplitGain=$str\n") 194 | str = gainPValues.mkString("\t") 195 | pw.write(s"GainPValue=$str\n") 196 | str = lteChildren.mkString("\t") 197 | pw.write(s"LTEChild=$str\n") 198 | str = gtChildren.mkString("\t") 199 | pw.write(s"GTChild=$str\n") 200 | str = thresholds.mkString("\t") 201 | pw.write(s"Threshold=$str\n") 202 | str = outputs.mkString("\t") 203 | pw.write(s"Output=$str\n") 204 | 205 | pw.write("\n") 206 | pw.close() 207 | 208 | } 209 | 210 | override protected def formatVersion: String = OptimizedDecisionTreeModel.formatVersion 211 | } 212 | 213 | object OptimizedDecisionTreeModel extends Loader[OptimizedDecisionTreeModel] with Logging { 214 | 215 | private[spark] def formatVersion: String = 
"1.0" 216 | 217 | private[tree] object SaveLoadV1_0 { 218 | 219 | def thisFormatVersion: String = "1.0" 220 | 221 | // Hard-code class name string in case it changes in the future 222 | def thisClassName: String = "org.apache.spark.mllib.tree.OptimizedDecisionTreeModel" 223 | 224 | case class PredictData(predict: Double, prob: Double) { 225 | def toPredict: Predict = new Predict(predict, prob) 226 | } 227 | 228 | object PredictData { 229 | def apply(p: Predict): PredictData = PredictData(p.predict, p.prob) 230 | 231 | def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1)) 232 | } 233 | 234 | case class SplitData( 235 | feature: Int, 236 | threshold: Double, 237 | featureType: Int, 238 | categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed 239 | def toSplit: Split = { 240 | new Split(feature, threshold, FeatureType(featureType), categories.toList) 241 | } 242 | } 243 | 244 | object SplitData { 245 | def apply(s: Split): SplitData = { 246 | SplitData(s.feature, s.threshold, s.featureType.id, s.categories) 247 | } 248 | 249 | def apply(r: Row): SplitData = { 250 | SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3)) 251 | } 252 | } 253 | 254 | /** Model data for model import/export */ 255 | case class TreeNodeData( 256 | treeId: Int, 257 | nodeId: Int, 258 | predict: PredictData, 259 | impurity: Double, 260 | isLeaf: Boolean, 261 | split: Option[SplitData], 262 | leftNodeId: Option[Int], 263 | rightNodeId: Option[Int], 264 | infoGain: Option[Double]) 265 | 266 | object TreeNodeData { 267 | def apply(treeId: Int, n: Node): TreeNodeData = { 268 | TreeNodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf, 269 | n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id), 270 | n.stats.map(_.gain)) 271 | } 272 | 273 | def apply(r: Row): TreeNodeData = { 274 | val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5))) 275 | val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6)) 276 | val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7)) 277 | val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8)) 278 | TreeNodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3), 279 | r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain) 280 | } 281 | } 282 | 283 | 284 | 285 | def save(sc: SparkContext, path: String, model: OptimizedDecisionTreeModel): Unit = { 286 | val sqlContext = new SQLContext(sc) 287 | import sqlContext.implicits._ 288 | 289 | // SPARK-6120: We do a hacky check here so users understand why save() is failing 290 | // when they run the ML guide example. 291 | // TODO: Fix this issue for real. 292 | val memThreshold = 768 293 | if (sc.isLocal) { 294 | val driverMemory = sc.getConf.getOption("spark.driver.memory") 295 | .orElse(Option(System.getenv("SPARK_DRIVER_MEMORY"))) 296 | .map(Utils.memoryStringToMb) 297 | .getOrElse(512) 298 | if (driverMemory <= memThreshold) { 299 | logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + 300 | s" driver memory (${driverMemory}m)." + 301 | s" If failure occurs, try setting driver-memory ${memThreshold}m (or larger).") 302 | } 303 | } else { 304 | if (sc.executorMemory <= memThreshold) { 305 | logWarning(s"$thisClassName.save() was called, but it may fail because of too little" + 306 | s" executor memory (${sc.executorMemory}m)." 
+ 307 | s" If failure occurs try setting executor-memory ${memThreshold}m (or larger).") 308 | } 309 | } 310 | 311 | // Create JSON metadata. 312 | val metadata = compact(render( 313 | ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ 314 | ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes))) 315 | sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) 316 | 317 | // Create Parquet data. 318 | val nodes = model.topNode.subtreeIterator.toSeq 319 | val dataRDD: DataFrame = sc.parallelize(nodes) 320 | .map(TreeNodeData.apply(0, _)) 321 | .toDF() 322 | dataRDD.write.parquet(Loader.dataPath(path)) 323 | } 324 | 325 | def load(sc: SparkContext, path: String, algo: String, numNodes: Int): OptimizedDecisionTreeModel = { 326 | val datapath = Loader.dataPath(path) 327 | val sqlContext = new SQLContext(sc) 328 | // Load Parquet data. 329 | val dataRDD = sqlContext.read.parquet(datapath) 330 | // Check schema explicitly since erasure makes it hard to use match-case for checking. 331 | Loader.checkSchema[TreeNodeData](dataRDD.schema) 332 | val nodes = dataRDD.map(TreeNodeData.apply) 333 | // Build node data into a tree. 334 | val trees = constructTrees(nodes) 335 | assert(trees.size == 1, 336 | "Decision tree should contain exactly one tree but got ${trees.size} trees.") 337 | val model = new OptimizedDecisionTreeModel(trees(0), Algo.fromString(algo)) 338 | assert(model.numNodes == numNodes, s"Unable to load OptimizedDecisionTreeModel data from: $datapath." + 339 | s" Expected $numNodes nodes but found ${model.numNodes}") 340 | model 341 | } 342 | 343 | def constructTrees(nodes: RDD[TreeNodeData]): Array[Node] = { 344 | val trees = nodes 345 | .groupBy(_.treeId) 346 | .mapValues(_.toArray) 347 | .collect() 348 | .map { case (treeId, data) => 349 | (treeId, constructTree(data)) 350 | }.sortBy(_._1) 351 | val numTrees = trees.size 352 | val treeIndices = trees.map(_._1).toSeq 353 | assert(treeIndices == (0 until numTrees), 354 | s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.") 355 | trees.map(_._2) 356 | } 357 | 358 | /** 359 | * Given a list of nodes from a tree, construct the tree. 360 | * 361 | * @param data array of all node data in a tree. 362 | */ 363 | def constructTree(data: Array[TreeNodeData]): Node = { 364 | val dataMap: Map[Int, TreeNodeData] = data.map(n => n.nodeId -> n).toMap 365 | assert(dataMap.contains(1), 366 | s"OptimizedDecisionTree missing root node (id = 1).") 367 | constructNode(1, dataMap, mutable.Map.empty) 368 | } 369 | 370 | /** 371 | * Builds a node from the node data map and adds new nodes to the input nodes map. 
372 | */ 373 | private def constructNode( 374 | id: Int, 375 | dataMap: Map[Int, TreeNodeData], 376 | nodes: mutable.Map[Int, Node]): Node = { 377 | if (nodes.contains(id)) { 378 | return nodes(id) 379 | } 380 | val data = dataMap(id) 381 | val node = 382 | if (data.isLeaf) { 383 | Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf) 384 | } else { 385 | val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes) 386 | val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes) 387 | val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity, 388 | rightNode.impurity, leftNode.predict, rightNode.predict) 389 | new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf, 390 | data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats)) 391 | } 392 | nodes += node.id -> node 393 | node 394 | } 395 | } 396 | 397 | override def load(sc: SparkContext, path: String): OptimizedDecisionTreeModel = { 398 | implicit val formats = DefaultFormats 399 | val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) 400 | val algo = (metadata \ "algo").extract[String] 401 | val numNodes = (metadata \ "numNodes").extract[Int] 402 | val classNameV1_0 = SaveLoadV1_0.thisClassName 403 | (loadedClassName, version) match { 404 | case (className, "1.0") if className == classNameV1_0 => 405 | SaveLoadV1_0.load(sc, path, algo, numNodes) 406 | case _ => throw new Exception( 407 | s"OptimizedDecisionTreeModel.load did not recognize model with (className, format version):" + 408 | s"($loadedClassName, $version). Supported:\n" + 409 | s" ($classNameV1_0, 1.0)") 410 | } 411 | } 412 | } 413 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/tree/model/predict/Predict.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.mllib.tree.model.predict 19 | 20 | import org.apache.spark.annotation.DeveloperApi 21 | 22 | /** 23 | * Predicted value for a node 24 | * @param predict predicted value 25 | * @param prob probability of the label (classification only) 26 | */ 27 | @DeveloperApi 28 | class Predict( 29 | var predict: Double, 30 | var prob: Double = 0.0) extends Serializable { 31 | 32 | override def toString: String = s"$predict (prob = $prob)" 33 | 34 | override def equals(other: Any): Boolean = { 35 | other match { 36 | case p: Predict => predict == p.predict && prob == p.prob 37 | case _ => false 38 | } 39 | } 40 | 41 | override def hashCode: Int = { 42 | com.google.common.base.Objects.hashCode(predict: java.lang.Double, prob: java.lang.Double) 43 | } 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/util/ProbabilityFunctions.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.util 2 | 3 | object ProbabilityFunctions{ 4 | //probit function 5 | val ProbA = Array( 3.3871328727963666080e0, 1.3314166789178437745e+2, 1.9715909503065514427e+3, 1.3731693765509461125e+4, 6 | 4.5921953931549871457e+4, 6.7265770927008700853e+4, 3.3430575583588128105e+4, 2.5090809287301226727e+3) 7 | val ProbB = Array(4.2313330701600911252e+1, 6.8718700749205790830e+2, 5.3941960214247511077e+3, 2.1213794301586595867e+4, 8 | 3.9307895800092710610e+4, 2.8729085735721942674e+4, 5.2264952788528545610e+3) 9 | 10 | val ProbC = Array(1.42343711074968357734e0, 4.63033784615654529590e0, 5.76949722146069140550e0, 3.64784832476320460504e0, 11 | 1.27045825245236838258e0, 2.41780725177450611770e-1, 2.27238449892691845833e-2, 7.74545014278341407640e-4) 12 | val ProbD = Array(2.05319162663775882187e0, 1.67638483018380384940e0, 6.89767334985100004550e-1, 1.48103976427480074590e-1, 13 | 1.51986665636164571966e-2, 5.47593808499534494600e-4, 1.05075007164441684324e-9) 14 | 15 | val ProbE = Array(6.65790464350110377720e0, 5.46378491116411436990e0, 1.78482653991729133580e0, 2.96560571828504891230e-1, 16 | 2.65321895265761230930e-2, 1.24266094738807843860e-3, 2.71155556874348757815e-5, 2.01033439929228813265e-7) 17 | val ProbF = Array(5.99832206555887937690e-1, 1.36929880922735805310e-1, 1.48753612908506148525e-2, 7.86869131145613259100e-4, 18 | 1.84631831751005468180e-5, 1.42151175831644588870e-7, 2.04426310338993978564e-15) 19 | 20 | def Probit(p: Double): Double ={ 21 | val q = p - 0.5 22 | var r = 0.0 23 | if(scala.math.abs(q) < 0.425){ 24 | r = 0.180625 - q * q 25 | q * (((((((ProbA(7) * r + ProbA(6)) * r + ProbA(5)) * r + ProbA(4)) * r + ProbA(3)) * r + ProbA(2)) * r + ProbA(1)) * r + ProbA(0)) / 26 | (((((((ProbB(6) * r + ProbB(5)) * r + ProbB(4)) * r + ProbB(3)) * r + ProbB(2)) * r + ProbB(1)) * r + ProbB(0)) * r + 1.0) 27 | } 28 | else{ 29 | if(q < 0) r = p 30 | else r = 1 - p 31 | r = scala.math.sqrt( -scala.math.log(r)) 32 | var retval = 0.0 33 | if(r < 5){ 34 | r = r - 1.6 35 | retval = (((((((ProbC(7) * r + ProbC(6)) * r + ProbC(5)) * r + ProbC(4)) * r + ProbC(3)) * r + ProbC(2)) * r + ProbC(1)) * r + ProbC(0)) / 36 | (((((((ProbD(6) * r + ProbD(5)) * r + ProbD(4)) * r + ProbD(3)) * r + ProbD(2)) * r + ProbD(1)) * r + ProbD(0)) * r + 1.0) 37 | } 38 | else{ 39 | r = r - 5 40 | retval = (((((((ProbE(7) * r + ProbE(6)) * r + ProbE(5)) * r + ProbE(4)) * r + ProbE(3)) * r + ProbE(2)) * r + ProbE(1)) * r + ProbE(0)) / 41 | (((((((ProbF(6) * r + ProbF(5)) * 
r + ProbF(4)) * r + ProbF(3)) * r + ProbF(2)) * r + ProbF(1)) * r + ProbF(0)) * r + 1.0) 42 | } 43 | if(q >= 0) retval else -retval 44 | } 45 | } 46 | 47 | 48 | //The approximate complimentary error function (i.e., 1-erf). 49 | def erfc(x: Double): Double = { 50 | if (x.isInfinity) { 51 | if(x.isPosInfinity) 1.0 else -1.0 52 | } 53 | else { 54 | val p = 0.3275911 55 | val a1 = 0.254829592 56 | val a2 = -0.284496736 57 | val a3 = 1.421413741 58 | val a4 = -1.453152027 59 | val a5 = 1.061405429 60 | 61 | val t = 1.0 / (1.0 + p * math.abs(x)) 62 | val ev = ((((((((a5 * t) + a4) * t) + a3) * t) + a2) * t + a1) * t) * scala.math.exp(-(x * x)) 63 | if (x >= 0) ev else (2-ev) 64 | } 65 | } 66 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/util/TreeUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.util 2 | 3 | import org.apache.hadoop.fs.{FileSystem, Path} 4 | import org.apache.spark.SparkConf 5 | import org.apache.spark.deploy.SparkHadoopUtil 6 | 7 | 8 | object TreeUtils { 9 | def getFileSystem(conf: SparkConf, path: Path): FileSystem = { 10 | val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) 11 | if (sys.env.contains("HADOOP_CONF_DIR") || sys.env.contains("YARN_CONF_DIR")) { 12 | val hdfsConfPath = if (sys.env.get("HADOOP_CONF_DIR").isDefined) { 13 | sys.env.get("HADOOP_CONF_DIR").get + "/core-site.xml" 14 | } else { 15 | sys.env.get("YARN_CONF_DIR").get + "/core-site.xml" 16 | } 17 | hadoopConf.addResource(new Path(hdfsConfPath)) 18 | } 19 | path.getFileSystem(hadoopConf) 20 | } 21 | 22 | def getPartitionOffsets(upper: Int, numPartitions: Int): (Array[Int], Array[Int]) = { 23 | val npp = upper / numPartitions 24 | val nppp = npp + 1 25 | val residual = upper - npp * numPartitions 26 | val boundary = residual * nppp 27 | val startPP = new Array[Int](numPartitions) 28 | val lcLenPP = new Array[Int](numPartitions) 29 | var i = 0 30 | while(i < numPartitions) { 31 | if (i < residual) { 32 | startPP(i) = nppp * i 33 | lcLenPP(i) = nppp 34 | } 35 | else{ 36 | startPP(i) = boundary + (i - residual) * npp 37 | lcLenPP(i) = npp 38 | } 39 | i += 1 40 | } 41 | (startPP, lcLenPP) 42 | 43 | /*** 44 | * println(s"upper:$upper" + s"numPartitions: $numPartitions") 45 | * val kpp = { 46 | * val npp = upper / numPartitions 47 | * if (npp * numPartitions == upper) npp else npp + 1 48 | * } 49 | * val startPP = Array.tabulate(numPartitions)(_ * kpp) 50 | * val lcLensPP = Array.tabulate(numPartitions)(pi => 51 | * if (pi < numPartitions - 1) kpp else (upper - startPP(pi)) 52 | * ) 53 | * (startPP, lcLensPP)**/ 54 | 55 | } 56 | 57 | 58 | } 59 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/mllib/util/treeAggregatorFormat.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.util 2 | 3 | import java.io.{File, FileOutputStream, PrintWriter} 4 | import java.text.SimpleDateFormat 5 | import java.util.Date 6 | 7 | //import org.apache.spark.ml.param.Param 8 | 9 | import org.apache.spark.examples.mllib.LambdaMARTRunner.Params 10 | 11 | import scala.util.Sorting 12 | 13 | 14 | object treeAggregatorFormat{ 15 | def appendTreeAggregator(expandTreeEnsemble: Boolean, 16 | filePath: String, 17 | index: Int, 18 | evalNodes: Array[Int], 19 | evalWeights: Array[Double] = null, 20 | bias: Double = 0.0, 21 | Type: String = "Linear"): 
Unit = { 22 | val pw = new PrintWriter(new FileOutputStream(new File(filePath), true)) 23 | 24 | pw.append(s"[Evaluator:$index]").write("\r\n") 25 | pw.append(s"EvaluatorType=Aggregator").write("\r\n") 26 | 27 | if (evalNodes == null) { 28 | throw new IllegalArgumentException("there are no evaluators to aggregate") 29 | } 30 | val numNodes = if(expandTreeEnsemble) evalNodes.length+1 else evalNodes.length 31 | val defaultWeight = 1.0 32 | pw.append(s"NumNodes=$numNodes").write("\r\n") 33 | pw.append(s"Nodes=").write("") 34 | 35 | if(expandTreeEnsemble) { 36 | pw.append(s"I:1").write("\t") 37 | } 38 | for (eval <- evalNodes) { 39 | pw.append(s"E:$eval").write("\t") 40 | } 41 | pw.write("\r\n") 42 | 43 | 44 | var weights = new Array[Double](numNodes) 45 | if (evalWeights == null) { 46 | for (i <- 0 until numNodes) { 47 | weights(i) = defaultWeight 48 | } 49 | } else { 50 | weights = evalWeights 51 | } 52 | 53 | pw.append(s"Weights=").write("") 54 | for (weight <- weights) { 55 | pw.append(s"$weight").write("\t") 56 | } 57 | 58 | pw.write("\r\n") 59 | 60 | pw.append(s"Bias=$bias").write("\r\n") 61 | pw.append(s"Type=$Type").write("\r\n") 62 | pw.close() 63 | } 64 | 65 | // Append the [Comments] section (run parameters and per-feature gains) to the model file. 66 | def toCommentFormat(filePath: String, param: Params, featureToName: Array[String], 67 | featureToGain: Array[Double]): Unit ={ 68 | val pw = new PrintWriter(new FileOutputStream(new File(filePath), true)) 69 | 70 | val gainSorted = featureToGain.zipWithIndex 71 | Sorting.quickSort(gainSorted)(Ordering.by[(Double, Int), Double](_._1).reverse) 72 | 73 | if(param.GainNormalization){ 74 | val normFactor = gainSorted(0)._1 75 | for (i <- 0 until gainSorted.length) { 76 | gainSorted(i) = (gainSorted(i)._1 / normFactor, gainSorted(i)._2) } 77 | } 78 | 79 | val timeFormat = new SimpleDateFormat("dd/MM/yyyy HH:mm:ss") 80 | val currentTime = timeFormat.format(new Date()) 81 | pw.append(s"\n[Comments]\nC:0=Regression Tree Ensemble\nC:1=Generated using spark FastRank\nC:2=Created on $currentTime\n").write("") 82 | 83 | var skip = 3 84 | val paramsList = param.toString.split("\n") 85 | val NumParams = paramsList.length 86 | for(i <- 0 until NumParams){ 87 | pw.append(s"C:${skip + i}=PARAM:${paramsList(i)}\n").write("") 88 | } 89 | 90 | val offset = if(param.expandTreeEnsemble) 2 else 1 91 | skip += NumParams 92 | val NumFeatures = gainSorted.length 93 | 94 | for(i <- 0 until NumFeatures){ 95 | var substr = featureToName(gainSorted(i)._2) 96 | if(substr.length > 68) 97 | substr = substr.substring(0, 67) + "..." 98 | pw.append(s"C:${skip + i}=FG:I${gainSorted(i)._2 + offset}:$substr:${gainSorted(i)._1}\n").write("") 99 | } 100 | pw.close() 101 | } 102 | 103 | } 104 | --------------------------------------------------------------------------------
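For quick orientation, here is a minimal usage sketch of two helpers defined above; the object name QuickChecks and the sample values are illustrative and not part of the repository. TreeUtils.getPartitionOffsets splits the index range [0, upper) into numPartitions contiguous blocks, giving each of the first upper % numPartitions blocks one extra element, while ProbabilityFunctions.Probit computes the standard normal quantile (inverse CDF).

import org.apache.spark.mllib.util.{ProbabilityFunctions, TreeUtils}

// Illustrative sketch only, not part of the build.
object QuickChecks {
  def main(args: Array[String]): Unit = {
    // Split 10 indices over 4 partitions: the first 10 % 4 = 2 partitions get one extra element.
    val (starts, lengths) = TreeUtils.getPartitionOffsets(10, 4)
    println(starts.mkString(","))   // 0,3,6,8
    println(lengths.mkString(","))  // 3,3,2,2

    // Probit is the standard normal quantile function:
    // Probit(0.5) is exactly 0.0, and Probit(0.975) is approximately 1.96.
    println(ProbabilityFunctions.Probit(0.5))
    println(ProbabilityFunctions.Probit(0.975))
  }
}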