├── .gitignore
├── GradientBoosting.md
├── LICENSE
├── README.md
├── SequoiaForest.md
├── data
│   ├── mnist.t.tsv.gz
│   └── mnist.tsv.gz
├── pom.xml
├── scalastyle-config.xml
└── src
    ├── main
    │   └── scala
    │       └── spark_ml
    │           ├── discretization
    │           │   ├── BinFinder.scala
    │           │   ├── DiscretizationOptions.scala
    │           │   ├── Discretizer.scala
    │           │   ├── EntropyMinimizingBinFinderFromSample.scala
    │           │   ├── EqualFrequencyBinFinderFromSample.scala
    │           │   ├── EqualWidthBinFinder.scala
    │           │   ├── NumericBinFinderFromSample.scala
    │           │   └── VarianceMinimizingBinFinderFromSample.scala
    │           ├── gradient_boosting
    │           │   ├── GradientBoosting.scala
    │           │   ├── GradientBoostingOptions.scala
    │           │   ├── GradientBoostingRunner.scala
    │           │   ├── PredCache.scala
    │           │   └── loss
    │           │       ├── LossAggregator.scala
    │           │       ├── LossFunction.scala
    │           │       └── defaults
    │           │           ├── AdaBoostLossAggregator.scala
    │           │           ├── AdaBoostLossFunction.scala
    │           │           ├── GaussianLossAggregator.scala
    │           │           ├── GaussianLossFunction.scala
    │           │           ├── LaplacianLossAggregator.scala
    │           │           ├── LaplacianLossFunction.scala
    │           │           ├── LogLossAggregator.scala
    │           │           ├── LogLossFunction.scala
    │           │           ├── PoissonLossAggregator.scala
    │           │           ├── PoissonLossFunction.scala
    │           │           ├── TruncHingeLossAggregator.scala
    │           │           └── TruncHingeLossFunction.scala
    │           ├── model
    │           │   ├── DecisionTree.scala
    │           │   ├── gb
    │           │   │   ├── GradientBoostedTrees.scala
    │           │   │   ├── GradientBoostedTreesFactory.scala
    │           │   │   └── GradientBoostedTreesStore.scala
    │           │   └── rf
    │           │       └── RandomForestStore.scala
    │           ├── transformation
    │           │   ├── DataTransformationUtils.scala
    │           │   └── DistinctValueCounter.scala
    │           ├── tree_ensembles
    │           │   ├── IdCache.scala
    │           │   ├── IdLookup.scala
    │           │   ├── IdLookupForNodeStats.scala
    │           │   ├── IdLookupForSubTreeInfo.scala
    │           │   ├── IdLookupForUpdaters.scala
    │           │   ├── InfoGainNodeStats.scala
    │           │   ├── NodeStats.scala
    │           │   ├── QuantizedData_ForTrees.scala
    │           │   ├── SubTreeStore.scala
    │           │   ├── TreeEnsembleStore.scala
    │           │   ├── TreeForestTrainer.scala
    │           │   └── VarianceNodeStats.scala
    │           └── util
    │               ├── Bagger.scala
    │               ├── DiscretizedFeatureHandler.scala
    │               ├── MapWithSequentialIntKeys.scala
    │               ├── Poisson.scala
    │               ├── ProgressNotifiee.scala
    │               ├── RandomSet.scala
    │               ├── ReservoirSample.scala
    │               ├── RobustMath.scala
    │               ├── Selection.scala
    │               └── Sorting.scala
    └── test
        └── scala
            └── spark_ml
                ├── discretization
                │   ├── EqualFrequencyBinFinderSuite.scala
                │   └── EqualWidthBinFinderSuite.scala
                ├── tree_ensembles
                │   ├── ComponentSuite.scala
                │   └── TreeEnsembleSuite.scala
                └── util
                    ├── BinsTestUtil.scala
                    ├── LocalSparkContext.scala
                    └── TestDataGenerator.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | *.tmp
2 | *.bak
3 | *.swp
4 | *~.nib
5 |
6 | # SBT
7 | log/
8 | target/
9 | *.class
10 |
11 | # Eclipse
12 | .classpath
13 | .project
14 | .settings/
15 |
16 | # Intellij
17 | .idea/
18 | .idea_modules/
19 | *.iml
20 | *.iws
21 | *.ipr
22 |
23 | # Mac
24 | .DS_Store
25 |
26 |
--------------------------------------------------------------------------------
/GradientBoosting.md:
--------------------------------------------------------------------------------
1 | # Gradient Boosted Trees
2 |
3 | Instructions coming soon.
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SparkML2
2 |
3 | This is a scalable, high-performance machine learning package for `Spark`. It is maintained independently of `MLLib` in `Spark`.
4 | The algorithms listed here are under continuous development, maintenance, and updates, and are therefore provided without guarantees.
5 | Use at your own risk.
6 |
7 | This package is provided under the Apache License 2.0.
8 |
9 | The current version is `0.9`.
10 |
11 | Any comments or questions should be directed to schung@alpinenow.com
12 |
13 | The currently available algorithms are:
14 | * [Sequoia Forest](SequoiaForest.md)
15 | * [Gradient Boosted Trees](GradientBoosting.md)
16 |
17 | ## Compiling the Code
18 |
19 | ./mvn3 package
20 |
21 | ## Quick Start (for YARN and Linux variants)
22 |
23 | Coming soon.
24 |
--------------------------------------------------------------------------------
/SequoiaForest.md:
--------------------------------------------------------------------------------
1 | # Sequoia Forest
2 |
3 | Instructions coming soon.
4 |
--------------------------------------------------------------------------------
/data/mnist.t.tsv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlpineNow/SparkML2/a81e1ad835245f556052816e18bb4e9fb7cabe58/data/mnist.t.tsv.gz
--------------------------------------------------------------------------------
/data/mnist.tsv.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AlpineNow/SparkML2/a81e1ad835245f556052816e18bb4e9fb7cabe58/data/mnist.tsv.gz
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
4 | 4.0.0
5 |
6 | com.alpine
7 | spark_ml_2.10
8 | 1.1-SNAPSHOT
9 | jar
10 | Alpine's Spark ML Package
11 |
12 |
13 | Apache 2.0 License
14 | http://www.apache.org/licenses/LICENSE-2.0.html
15 | repo
16 |
17 |
18 |
19 |
20 | schung
21 | Sung Hwan Chung
22 | schung@alpinenow.com
23 | Alpine Data Labs
24 |
25 |
26 |
27 | UTF-8
28 | UTF-8
29 | 2.10.4
30 | 2.10
31 | 1.7
32 | 64m
33 | 256m
34 | 256m
35 |
36 | github
37 |
38 |
39 |
40 |
41 | internal.repo
42 | Temporary Staging Repository
43 | file://${project.build.directory}/mvn-repo
44 |
45 |
46 |
47 |
48 |
49 | central
50 | Maven Repository
51 | https://repo1.maven.org/maven2
52 |
53 | true
54 |
55 |
56 | false
57 |
58 |
59 |
60 |
61 | sonatype-releases
62 | Sonatype Repository
63 | https://oss.sonatype.org/content/repositories/releases
64 |
65 | true
66 |
67 |
68 | false
69 |
70 |
71 |
72 |
73 |
74 |
75 | com.databricks
76 | spark-csv_${scala.binary.version}
77 | 1.3.0
78 | provided
79 |
80 |
81 | com.databricks
82 | spark-avro_${scala.binary.version}
83 | 1.0.0
84 | provided
85 |
86 |
87 | org.apache.spark
88 | spark-core_${scala.binary.version}
89 | 1.5.1
90 | provided
91 |
92 |
93 | org.apache.spark
94 | spark-mllib_${scala.binary.version}
95 | 1.5.1
96 | provided
97 |
98 |
99 | org.apache.hadoop
100 | hadoop-client
101 | 1.0.4
102 | provided
103 |
104 |
105 | org.spire-math
106 | spire_${scala.binary.version}
107 | 0.10.1
108 |
109 |
110 | org.scalanlp
111 | breeze_${scala.binary.version}
112 | 0.11.2
113 | provided
114 |
115 |
116 | com.github.scopt
117 | scopt_${scala.binary.version}
118 | 3.3.0
119 |
120 |
121 | org.scalatest
122 | scalatest_${scala.binary.version}
123 | 2.2.1
124 | test
125 |
126 |
127 |
128 |
129 |
130 |
131 | org.apache.maven.plugins
132 | maven-compiler-plugin
133 | 3.3
134 |
135 | ${java.version}
136 | ${java.version}
137 | UTF-8
138 | 1024m
139 | true
140 |
141 | -Xlint:all,-serial,-path,-XX:MaxPermSize=256m
142 |
143 |
144 |
145 |
146 | net.alchim31.maven
147 | scala-maven-plugin
148 | 3.2.2
149 |
150 |
151 | scala-compile-first
152 | process-resources
153 |
154 | compile
155 |
156 |
157 |
158 | scala-test-compile-first
159 | process-test-resources
160 |
161 | testCompile
162 |
163 |
164 |
165 | attach-scaladocs
166 | verify
167 |
168 | doc-jar
169 |
170 |
171 |
172 |
173 | ${scala.version}
174 | incremental
175 | true
176 |
177 | -unchecked
178 | -deprecation
179 | -feature
180 |
181 |
182 | -Xms1024m
183 | -Xmx1024m
184 | -XX:PermSize=${PermGen}
185 | -XX:MaxPermSize=${MaxPermGen}
186 | -XX:ReservedCodeCacheSize=${CodeCacheSize}
187 |
188 |
189 |
190 | -source
191 | ${java.version}
192 | -target
193 | ${java.version}
194 | -Xlint:all,-serial,-path
195 |
196 |
197 |
198 |
199 | org.scalastyle
200 | scalastyle-maven-plugin
201 | 0.7.0
202 |
203 | false
204 | true
205 | false
206 | false
207 | ${basedir}/src/main/scala
208 | ${basedir}/src/test/scala
209 | scalastyle-config.xml
210 | ${basedir}/target/scalastyle-output.xml
211 | ${project.build.sourceEncoding}
212 | ${project.reporting.outputEncoding}
213 |
214 |
215 |
216 |
217 | check
218 |
219 |
220 |
221 |
222 |
223 | org.scalatest
224 | scalatest-maven-plugin
225 | 1.0
226 |
227 | false
228 | ${project.build.directory}/surefire-reports
229 | .
230 | WDF TestSuite.txt
231 | -XX:PermSize=${PermGen}
232 |
233 |
234 |
235 | test
236 |
237 | test
238 |
239 |
240 |
241 |
242 |
243 | org.apache.maven.plugins
244 | maven-deploy-plugin
245 | 2.8.2
246 |
247 | internal.repo::default::file://${project.build.directory}/mvn-repo
248 |
249 |
250 |
251 | com.github.github
252 | site-maven-plugin
253 | 0.12
254 |
255 | Maven artifacts for ${project.version}
256 | true
257 | ${project.build.directory}/mvn-repo
258 | refs/heads/mvn-repo
259 | **/*
260 | SparkML2
261 | AlpineNow
262 |
263 |
264 |
265 |
266 |
267 | site
268 |
269 | deploy
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
--------------------------------------------------------------------------------
/scalastyle-config.xml:
--------------------------------------------------------------------------------
1 |
2 | Scalastyle standard configuration
3 |
4 |
5 |
6 |
7 |
8 |
9 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/discretization/BinFinder.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | import org.apache.spark.rdd.RDD
21 | import spark_ml.util.ProgressNotifiee
22 |
23 | case class MinMaxPair(
24 | var minValue: Double,
25 | var maxValue: Double
26 | )
27 |
28 | class LabelSummary(val expectedCardinality: Option[Int]) extends Serializable {
29 | val catCounts: Option[Array[Long]] =
30 | if (expectedCardinality.isDefined) {
31 | Some(
32 | Array.fill[Long](expectedCardinality.get)(0L)
33 | )
34 | } else {
35 | None
36 | }
37 |
38 | var restCount: Long = 0L
39 | var hasNaN: Boolean = false
40 |
41 | var totalCount: Double = 0.0
42 | var runningAvg: Double = 0.0
43 | var runningSqrAvg: Double = 0.0
44 |
45 | private def updateRunningAvgs(
46 | bTotalCount: Double,
47 | bRunningAvg: Double,
48 | bRunningSqrAvg: Double): Unit = {
49 | if (totalCount > 0.0 || bTotalCount > 0.0) {
50 | val newTotalCount = totalCount + bTotalCount
51 | val aRatio = totalCount / newTotalCount
52 | val bRatio = bTotalCount / newTotalCount
53 | val newAvg = aRatio * runningAvg + bRatio * bRunningAvg
54 | val newSqrAvg = aRatio * runningSqrAvg + bRatio * bRunningSqrAvg
55 |
56 | totalCount = newTotalCount
57 | runningAvg = newAvg
58 | runningSqrAvg = newSqrAvg
59 | }
60 | }
61 |
62 | def addLabel(label: Double): Unit = {
63 | if (label.isNaN) {
64 | hasNaN = true
65 | } else {
66 | // Keep track of the running averages.
67 | updateRunningAvgs(1.0, label, label * label)
68 |
69 | if (expectedCardinality.isDefined && label.toLong.toDouble == label) {
70 | val cat = label.toInt
71 | if (cat >= 0 && cat < expectedCardinality.get) {
72 | catCounts.get(cat) += 1L
73 | } else {
74 | restCount += 1L
75 | }
76 | } else {
77 | restCount += 1L
78 | }
79 | }
80 | }
81 |
82 | def mergeInPlace(b: LabelSummary): this.type = {
83 | if (catCounts.isDefined) {
84 | var i = 0
85 | while (i < expectedCardinality.get) {
86 | catCounts.get(i) += b.catCounts.get(i)
87 | i += 1
88 | }
89 | }
90 |
91 | updateRunningAvgs(b.totalCount, b.runningAvg, b.runningSqrAvg)
92 |
93 | restCount += b.restCount
94 | hasNaN ||= b.hasNaN
95 | this
96 | }
97 | }
98 |
99 | /**
100 |  * Classes implementing this trait are used to find bins in the dataset.
101 | */
102 | trait BinFinder {
103 | def findBins(
104 | data: RDD[(Double, Array[Double])],
105 | columnNames: (String, Array[String]),
106 | catIndices: Set[Int],
107 | maxNumBins: Int,
108 | expectedLabelCardinality: Option[Int],
109 | notifiee: ProgressNotifiee
110 | ): (LabelSummary, Seq[Bins])
111 |
112 | def getDataSummary(
113 | data: RDD[(Double, Array[Double])],
114 | numFeatures: Int,
115 | expectedLabelCardinality: Option[Int]
116 | ): (LabelSummary, Array[Boolean], Array[MinMaxPair]) = {
117 | val (ls, hn, mm) = data.mapPartitions(
118 | itr => {
119 | val labelSummary = new LabelSummary(expectedLabelCardinality)
120 | val nanExists = Array.fill[Boolean](numFeatures)(false)
121 | val minMaxes = Array.fill[MinMaxPair](numFeatures)(
122 | MinMaxPair(minValue = Double.PositiveInfinity, maxValue = Double.NegativeInfinity)
123 | )
124 | while (itr.hasNext) {
125 | val (label, features) = itr.next()
126 | labelSummary.addLabel(label)
127 | features.zipWithIndex.foreach {
128 | case (featValue, featIdx) =>
129 | nanExists(featIdx) = nanExists(featIdx) || featValue.isNaN
130 | if (!featValue.isNaN) {
131 | minMaxes(featIdx).minValue = math.min(featValue, minMaxes(featIdx).minValue)
132 | minMaxes(featIdx).maxValue = math.max(featValue, minMaxes(featIdx).maxValue)
133 | }
134 | }
135 | }
136 |
137 | Array((labelSummary, nanExists, minMaxes)).iterator
138 | }
139 | ).reduce {
140 | case ((labelSummary1, nanExists1, minMaxes1), (labelSummary2, nanExists2, minMaxes2)) =>
141 | (
142 | labelSummary1.mergeInPlace(labelSummary2),
143 | nanExists1.zip(nanExists2).map {
144 | case (nan1, nan2) => nan1 || nan2
145 | },
146 | minMaxes1.zip(minMaxes2).map {
147 | case (minMax1, minMax2) =>
148 | MinMaxPair(
149 | minValue = math.min(minMax1.minValue, minMax2.minValue),
150 | maxValue = math.max(minMax1.maxValue, minMax2.maxValue)
151 | )
152 | }
153 | )
154 | }
155 |
156 |     // Labels shouldn't contain NaN, so throw an exception if a NaN label was
157 |     // found.
158 | if (ls.hasNaN) {
159 | throw InvalidLabelException("Found NaN value for the label.")
160 | }
161 |
162 | (ls, hn, mm)
163 | }
164 | }
165 |
--------------------------------------------------------------------------------
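As a rough illustration of how `LabelSummary` behaves (a minimal sketch using only the class shown above; the values are illustrative), adding labels updates the per-class counts and the running label moments, and `mergeInPlace` combines partition-level summaries by count-weighted averaging:

    import spark_ml.discretization.LabelSummary

    // A binary label, so the expected cardinality is Some(2).
    val a = new LabelSummary(Some(2))
    a.addLabel(0.0)
    a.addLabel(1.0)

    val b = new LabelSummary(Some(2))
    b.addLabel(1.0)

    a.mergeInPlace(b)
    // a.catCounts is now Some(Array(1L, 2L)): one 0-label and two 1-labels.
    // a.runningAvg is the count-weighted mean label: (0 + 1 + 1) / 3 = 2/3.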
/src/main/scala/spark_ml/discretization/DiscretizationOptions.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | /**
21 | * Discretization options.
22 | * @param discType Discretization type (Equal-Frequency/Width).
23 | * @param maxNumericBins Maximum number of numerical bins.
24 |  * @param maxCatCardinality Maximum allowed cardinality for categorical features.
25 | * @param useFeatureHashingOnCat Whether we should perform feature hashing on
26 | * categorical features whose cardinality goes
27 | * over 'maxCatCardinality'.
28 | * @param maxSampleSizeForDisc Maximum number of sample rows for certain numeric
29 | * discretizations. E.g., equal frequency
30 | * discretization uses a sample, instead of the
31 | * entire data.
32 | */
33 | case class DiscretizationOptions(
34 | discType: DiscType.DiscType,
35 | maxNumericBins: Int,
36 | maxCatCardinality: Int,
37 | useFeatureHashingOnCat: Boolean,
38 | maxSampleSizeForDisc: Int) {
39 | override def toString: String = {
40 | "=========================" + "\n" +
41 | "Discretization Options" + "\n" +
42 | "=========================" + "\n" +
43 | "discType : " + discType.toString + "\n" +
44 | "maxNumericBins : " + maxNumericBins + "\n" +
45 | "maxCatCardinality : " + maxCatCardinality + "\n" +
46 | "useFeatureHashingOnCat : " + useFeatureHashingOnCat.toString + "\n" +
47 | "maxSampleSizeForDisc : " + maxSampleSizeForDisc.toString + "\n"
48 | }
49 | }
50 |
--------------------------------------------------------------------------------
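A minimal sketch of constructing these options (the particular values below are illustrative only; `DiscType` is the enumeration defined in `Discretizer.scala`):

    import spark_ml.discretization.{DiscretizationOptions, DiscType}

    val discOpts = DiscretizationOptions(
      discType = DiscType.EqualFrequency,
      maxNumericBins = 256,
      maxCatCardinality = 256,
      useFeatureHashingOnCat = false,
      maxSampleSizeForDisc = 50000
    )
    // toString prints the options in the banner format defined above.
    println(discOpts)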
/src/main/scala/spark_ml/discretization/Discretizer.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | import scala.reflect.ClassTag
21 |
22 | import org.apache.spark.rdd.RDD
23 | import spark_ml.util.DiscretizedFeatureHandler
24 |
25 | /**
26 | * Available discretization types.
27 | */
28 | object DiscType extends Enumeration {
29 | type DiscType = Value
30 | val EqualFrequency = Value(0)
31 | val EqualWidth = Value(1)
32 | val MinimumEntropy = Value(2)
33 | val MinimumVariance = Value(3)
34 | }
35 |
36 | /**
37 | * A numeric bin with a lower bound and an upper bound.
38 | * The lower bound is inclusive. The upper bound is exclusive.
39 | * @param lower Lower bound
40 | * @param upper Upper bound
41 | */
42 | case class NumericBin(lower: Double, upper: Double) {
43 | def contains(value: Double): Boolean = {
44 | value >= lower && value < upper
45 | }
46 |
47 | override def toString: String = {
48 | "[" + lower.toString + "," + upper.toString + ")"
49 | }
50 | }
51 |
52 | /**
53 | * For both categorical and numerical bins.
54 | */
55 | trait Bins extends Serializable {
56 | /**
57 |    * Get the number of bins (including the missing-value bin in the case of
58 |    * numeric bins).
59 | * @return The number of bins.
60 | */
61 | def getCardinality: Int
62 |
63 | /**
64 | * Given the raw feature value, find the bin Id.
65 | * @param value Raw feature value.
66 | * @return The corresponding bin Id.
67 | */
68 | def findBinIdx(value: Double): Int
69 | }
70 |
71 | /**
72 |  * An exception to throw in case a categorical feature value is not an integer
73 |  * value (e.g., it cannot be 1.1 or 2.4).
74 | * @param msg String message to include in the exception.
75 | */
76 | case class InvalidCategoricalValueException(msg: String) extends Exception(msg)
77 |
78 | /**
79 | * An exception to throw if the cardinality of a feature exceeds the limit.
80 | * @param msg String message to include.
81 | */
82 | case class CardinalityOverLimitException(msg: String) extends Exception(msg)
83 |
84 | /**
85 | * An exception to throw if the label has unexpected values.
86 | * @param msg String message to include.
87 | */
88 | case class InvalidLabelException(msg: String) extends Exception(msg)
89 |
90 | /**
91 | * Numeric bins.
92 | * @param bins An array of ordered non-NaN numeric bins.
93 | * @param missingValueBinIdx The optional bin Id for the NaN values.
94 | */
95 | case class NumericBins(
96 | bins: Seq[NumericBin],
97 | missingValueBinIdx: Option[Int] = None) extends Bins {
98 | /**
99 | * The cardinality of the bins (the number of bins).
100 | * @return The number of bins.
101 | */
102 | def getCardinality: Int =
103 | bins.length + (if (missingValueBinIdx.isDefined) 1 else 0)
104 |
105 | /**
106 | * Find the index of the bin that the value belongs to.
107 | * @param value The numeric value we want to search.
108 | * @return The index of the bin that contains the given numeric value.
109 | */
110 | def findBinIdx(value: Double): Int = {
111 | if (value.isNaN) {
112 | missingValueBinIdx.get
113 | } else {
114 | // Binary search.
115 | var s = 0
116 | var e = bins.length - 1
117 | var cur = (s + e) / 2
118 | var found = bins(cur).contains(value)
119 | while (!found) {
120 | if (bins(cur).lower > value) {
121 | e = cur - 1
122 | } else if (bins(cur).upper <= value) {
123 | s = cur + 1
124 | }
125 |
126 | cur = (s + e) / 2
127 | found = bins(cur).contains(value)
128 | }
129 |
130 | cur
131 | }
132 | }
133 | }
134 |
135 | /**
136 | * For categorical bins, the raw feature value is simply the categorical bin Id.
137 | * We expect the categorical values to go from 0 to (Cardinality - 1) in an
138 | * incremental fashion.
139 | * @param cardinality The cardinality of the feature.
140 | */
141 | case class CategoricalBins(cardinality: Int) extends Bins {
142 | /**
143 | * @return The number of bins.
144 | */
145 | def getCardinality: Int = {
146 | cardinality
147 | }
148 |
149 | /**
150 | * Find the bin Id.
151 | * @param value Raw feature value.
152 | * @return The corresponding bin Id.
153 | */
154 | def findBinIdx(value: Double): Int = {
155 | if (value.toInt.toDouble != value) {
156 | throw InvalidCategoricalValueException(
157 | value + " is not a valid categorical value."
158 | )
159 | }
160 |
161 | if (value >= cardinality) {
162 | throw CardinalityOverLimitException(
163 | value + " is above the cardinality of this feature " + cardinality
164 | )
165 | }
166 |
167 | value.toInt
168 | }
169 | }
170 |
171 | object Discretizer {
172 | /**
173 | * Transform the features in the given labeled point row into an array of Bin
174 | * IDs that could be Unsigned Byte or Unsigned Short.
175 | * @param featureBins A sequence of feature bin definitions.
176 | * @param featureHandler A handler for discretized features (unsigned
177 | * Byte/Short).
178 | * @param row The labeled point row that we want to transform.
179 | * @return A transformed array of features.
180 | */
181 | private def transformFeatures[@specialized(Byte, Short) T: ClassTag](
182 | featureBins: Seq[Bins],
183 | featureHandler: DiscretizedFeatureHandler[T])(row: (Double, Array[Double])): Array[T] = {
184 | val (_, features) = row
185 | features.zipWithIndex.map {
186 | case (featureVal, featIdx) => featureHandler.convertToType(featureBins(featIdx).findBinIdx(featureVal))
187 | }
188 | }
189 |
190 | /**
191 | * Transform the features of the labeled data RDD into bin Ids of either
192 | * unsigned Byte/Short.
193 | * @param input An RDD of Double label and Double feature values.
194 | * @param featureBins A sequence of feature bin definitions.
195 | * @param featureHandler A handler for discretized features (unsigned
196 | * Byte/Short).
197 | * @return A new RDD that has all the features transformed into Bin Ids.
198 | */
199 | def transformFeatures[@specialized(Byte, Short) T: ClassTag](
200 | input: RDD[(Double, Array[Double])],
201 | featureBins: Seq[Bins],
202 | featureHandler: DiscretizedFeatureHandler[T]): RDD[Array[T]] = {
203 | input.map(transformFeatures(featureBins, featureHandler))
204 | }
205 | }
206 |
--------------------------------------------------------------------------------
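A small sketch of how the bin classes above resolve raw values to bin Ids (the bins are constructed by hand here for illustration; in practice they come from a `BinFinder`):

    import spark_ml.discretization.{CategoricalBins, NumericBin, NumericBins}

    val numericBins = NumericBins(
      bins = Seq(
        NumericBin(Double.NegativeInfinity, 0.0),
        NumericBin(0.0, 10.0),
        NumericBin(10.0, Double.PositiveInfinity)
      ),
      missingValueBinIdx = Some(3)
    )
    numericBins.findBinIdx(5.0)        // 1, found via binary search over the ordered bins
    numericBins.findBinIdx(Double.NaN) // 3, the missing-value bin
    numericBins.getCardinality         // 4, the three numeric bins plus the NaN bin

    val catBins = CategoricalBins(cardinality = 3)
    catBins.findBinIdx(2.0)            // 2; non-integer or out-of-range values throw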
/src/main/scala/spark_ml/discretization/EntropyMinimizingBinFinderFromSample.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | /**
21 | * Find bins that minimize the label class entropy. Essentially equivalent to
22 |  * building a decision tree on a single feature with the information-gain criterion.
23 |  * @param maxSampleSize The maximum sample size.
24 | * @param seed The seed to use with random sampling.
25 | */
26 | class EntropyMinimizingBinFinderFromSample(maxSampleSize: Int, seed: Int)
27 | extends NumericBinFinderFromSample("LabelEntropy", maxSampleSize, seed) {
28 |
29 | /**
30 | * Calculate the label class entropy of the segment.
31 |    * @param labelValues A sample of label values.
32 |    * @param featureValues A sample of feature values.
33 | * @param labelSummary This is useful if there's a need to find the label
34 | * class cardinality.
35 | * @param s The segment starting index (inclusive).
36 | * @param e The segment ending index (exclusive).
37 | * @return The loss of the segment.
38 | */
39 | override def findSegmentLoss(
40 | labelValues: Array[Double],
41 | featureValues: Array[Double],
42 | labelSummary: LabelSummary,
43 | s: Int,
44 | e: Int): Double = {
45 | val segSize = (e - s).toDouble
46 | val classCounts = Array.fill[Double](labelSummary.expectedCardinality.get)(0.0)
47 | var i = s
48 | while (i < e) {
49 | val labelValue = labelValues(i).toInt
50 | classCounts(labelValue) += 1.0
51 | i += 1
52 | }
53 |
54 | classCounts.foldLeft(0.0)(
55 | (entropy, cnt) => cnt / segSize match {
56 | case 0.0 => entropy
57 | case p => entropy - p * math.log(p)
58 | }
59 | )
60 | }
61 | }
62 |
--------------------------------------------------------------------------------
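For intuition, the segment loss above is the Shannon entropy of the segment's label distribution: a segment split evenly between two classes costs -0.5*ln(0.5) - 0.5*ln(0.5) = ln(2) ≈ 0.693, while a class-pure segment costs 0, so the finder prefers cut points that produce class-pure bins.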
/src/main/scala/spark_ml/discretization/EqualFrequencyBinFinderFromSample.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | /**
21 | * Find bins that satisfy the 'roughly' equal frequency requirement for each
22 | * feature column. We calculate this from a sample of the data (the sample is
23 | * either at or less than maxSampleSize).
24 |  * @param maxSampleSize The maximum sample size.
25 | * @param seed The seed to use with random sampling.
26 | */
27 | class EqualFrequencyBinFinderFromSample(maxSampleSize: Int, seed: Int)
28 | extends NumericBinFinderFromSample("BinSize", maxSampleSize, seed) {
29 |
30 | /**
31 |    * The loss for a segment is simply its size. Minimizing the weighted sizes of
32 |    * bins leads to equal-frequency binning in the best case.
33 |    * @param labelValues A sample of label values.
34 |    * @param featureValues A sample of feature values.
35 | * @param labelSummary This is useful if there's a need to find the label
36 | * class cardinality.
37 | * @param s The segment starting index (inclusive).
38 | * @param e The segment ending index (exclusive).
39 | * @return The loss of the segment.
40 | */
41 | override def findSegmentLoss(
42 | labelValues: Array[Double],
43 | featureValues: Array[Double],
44 | labelSummary: LabelSummary,
45 | s: Int,
46 | e: Int): Double = {
47 | // Simply define the unsegmented section size as the loss.
48 | // It doesn't really have a big impact.
49 | (e - s).toDouble
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/discretization/EqualWidthBinFinder.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | import org.apache.spark.rdd.RDD
21 | import spark_ml.util.ProgressNotifiee
22 |
23 | /**
24 | * Compute equal width bins for each numeric column.
25 | */
26 | class EqualWidthBinFinder extends BinFinder {
27 | def findBins(
28 | data: RDD[(Double, Array[Double])],
29 | columnNames: (String, Array[String]),
30 | catIndices: Set[Int],
31 | maxNumBins: Int,
32 | expectedLabelCardinality: Option[Int],
33 | notifiee: ProgressNotifiee
34 | ): (LabelSummary, Seq[Bins]) = {
35 | val numFeatures = columnNames._2.length
36 |
37 |     // Find the label summary, whether each feature contains NaN values, and the
38 |     // per-feature min/max values.
39 | val (labelSummary, featureHasNan, minMaxValues) =
40 | this.getDataSummary(
41 | data,
42 | numFeatures,
43 | expectedLabelCardinality
44 | )
45 |
46 | val numericFeatureBins = featureHasNan.zipWithIndex.map {
47 | case (featHasNan, featIdx) if !catIndices.contains(featIdx) =>
48 | val minMax = minMaxValues(featIdx)
49 | if (minMax.minValue.isPosInfinity || (minMax.minValue == minMax.maxValue)) {
50 | NumericBins(
51 | Seq(
52 | NumericBin(lower = Double.NegativeInfinity, upper = Double.PositiveInfinity)
53 | ),
54 | if (featHasNan) Some(1) else None
55 | )
56 | } else {
57 | val nonNaNMaxNumBins = maxNumBins - (if (featHasNan) 1 else 0)
58 | val binWidth = (minMax.maxValue - minMax.minValue) / nonNaNMaxNumBins.toDouble
59 | NumericBins(
60 | (0 to (nonNaNMaxNumBins - 1)).map {
61 | binIdx => {
62 | binIdx match {
63 | case 0 => NumericBin(lower = Double.NegativeInfinity, upper = binWidth + minMax.minValue)
64 | case x if x == (nonNaNMaxNumBins - 1) => NumericBin(lower = minMax.maxValue - binWidth, upper = Double.PositiveInfinity)
65 | case _ => NumericBin(lower = minMax.minValue + binWidth * binIdx.toDouble, upper = minMax.minValue + binWidth * (binIdx + 1).toDouble)
66 | }
67 | }
68 | },
69 | if (featHasNan) Some(nonNaNMaxNumBins) else None
70 | )
71 | }
72 | case (featHasNan, featIdx) if catIndices.contains(featIdx) =>
73 | CategoricalBins((minMaxValues(featIdx).maxValue + 1.0).toInt)
74 | }
75 |
76 | (labelSummary, numericFeatureBins)
77 | }
78 | }
79 |
--------------------------------------------------------------------------------
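A worked example of the construction above (illustrative numbers: a feature observed between 0.0 and 9.0, maxNumBins = 4, with NaNs present, so three non-NaN bins of width 3.0 are produced and one bin is reserved for NaN):

    NumericBins(
      bins = Seq(
        NumericBin(Double.NegativeInfinity, 3.0), // the first bin's lower bound is opened to -inf
        NumericBin(3.0, 6.0),
        NumericBin(6.0, Double.PositiveInfinity)  // the last bin's upper bound is opened to +inf
      ),
      missingValueBinIdx = Some(3)                // NaNs map to the extra bin at index 3
    )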
/src/main/scala/spark_ml/discretization/VarianceMinimizingBinFinderFromSample.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | /**
21 | * Find bins that minimize the label variance. Essentially equivalent to
22 |  * building a decision tree on a single feature with the variance criterion.
23 |  * @param maxSampleSize The maximum sample size.
24 | * @param seed The seed to use with random sampling.
25 | */
26 | class VarianceMinimizingBinFinderFromSample(maxSampleSize: Int, seed: Int)
27 | extends NumericBinFinderFromSample("LabelVariance", maxSampleSize, seed) {
28 |
29 | /**
30 | * Calculate the label variance of the segment.
31 |    * @param labelValues A sample of label values.
32 |    * @param featureValues A sample of feature values.
33 | * @param labelSummary This is useful if there's a need to find the label
34 | * class cardinality.
35 | * @param s The segment starting index (inclusive).
36 | * @param e The segment ending index (exclusive).
37 | * @return The loss of the segment.
38 | */
39 | override def findSegmentLoss(
40 | labelValues: Array[Double],
41 | featureValues: Array[Double],
42 | labelSummary: LabelSummary,
43 | s: Int,
44 | e: Int): Double = {
45 | val segSize = (e - s).toDouble
46 | var sum = 0.0
47 | var sqrSum = 0.0
48 | var i = s
49 | while (i < e) {
50 | val labelValue = labelValues(i).toInt
51 | sum += labelValue
52 | sqrSum += labelValue * labelValue
53 | i += 1
54 | }
55 |
56 | val avg = sum / segSize
57 | sqrSum / segSize - avg * avg
58 | }
59 | }
60 |
--------------------------------------------------------------------------------
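For intuition, the segment loss above is the (biased) label variance E[y^2] - (E[y])^2 of the segment: labels {1, 3} give mean 2, mean square 5, and hence loss 1, while a constant-label segment costs 0, so cut points are chosen to produce label-homogeneous bins.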
/src/main/scala/spark_ml/gradient_boosting/GradientBoostingOptions.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting
19 |
20 | import spark_ml.gradient_boosting.loss.LossFunction
21 | import spark_ml.tree_ensembles.CatSplitType
22 | import spark_ml.util.BaggingType
23 |
24 | /**
25 | * Gradient boosting trainer options.
26 | * @param numTrees Number of trees to build. The algorithm can also determine
27 |  *                 the optimal number of trees if validation data is provided.
28 | * @param maxTreeDepth Maximum tree depth allowed per tree.
29 | * @param minSplitSize Min split size allowed per tree.
30 | * @param lossFunction The loss function object to use.
31 | * @param catSplitType How to split categorical features.
32 | * @param baggingRate Bagging rate.
33 | * @param baggingType Whether to bag with/without replacements.
34 | * @param shrinkage Shrinkage.
35 |  * @param fineTuneTerminalNodes Whether to fine-tune a tree's terminal nodes so
36 |  *                              that their values directly optimize the
37 |  *                              loss function.
38 | * @param checkpointDir Checkpoint directory.
39 | * @param predCheckpointInterval Intermediate prediction checkpointing interval.
40 | * @param idCacheCheckpointInterval Id cache checkpointing interval.
41 | * @param verbose If true, the algorithm will print as much information through
42 | * the notifiee as possible, including many intermediate
43 | * computation values, etc.
44 | */
45 | case class GradientBoostingOptions(
46 | numTrees: Int,
47 | maxTreeDepth: Int,
48 | minSplitSize: Int,
49 | lossFunction: LossFunction,
50 | catSplitType: CatSplitType.CatSplitType,
51 | baggingRate: Double,
52 | baggingType: BaggingType.BaggingType,
53 | shrinkage: Double,
54 | fineTuneTerminalNodes: Boolean,
55 | checkpointDir: Option[String],
56 | predCheckpointInterval: Int,
57 | idCacheCheckpointInterval: Int,
58 | verbose: Boolean) {
59 | override def toString: String = {
60 | "=========================" + "\n" +
61 | "Gradient Boosting Options" + "\n" +
62 | "=========================" + "\n" +
63 | "numTrees : " + numTrees + "\n" +
64 | "maxTreeDepth : " + maxTreeDepth + "\n" +
65 | "minSplitSize : " + minSplitSize + "\n" +
66 | "lossFunction : " + lossFunction.getClass.getSimpleName + "\n" +
67 | "catSplitType : " + catSplitType.toString + "\n" +
68 | "baggingRate : " + baggingRate + "\n" +
69 | "baggingType : " + baggingType.toString + "\n" +
70 | "shrinkage : " + shrinkage + "\n" +
71 | "fineTuneTerminalNodes : " + fineTuneTerminalNodes + "\n" +
72 | "checkpointDir : " + (checkpointDir match { case None => "None" case Some(dir) => dir }) + "\n" +
73 | "predCheckpointInterval : " + predCheckpointInterval + "\n" +
74 | "idCacheCheckpointInterval : " + idCacheCheckpointInterval + "\n" +
75 | "verbose : " + verbose + "\n"
76 | }
77 | }
78 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/PredCache.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting
19 |
20 | import scala.collection.mutable
21 |
22 | import org.apache.hadoop.fs.{FileSystem, Path}
23 | import org.apache.spark.rdd.RDD
24 | import org.apache.spark.storage.StorageLevel
25 | import spark_ml.model.gb.GBInternalTree
26 | import spark_ml.util.DiscretizedFeatureHandler
27 |
28 | object PredCache {
29 | def createPredCache(
30 | initPreds: RDD[Double],
31 | shrinkage: Double,
32 | storageLevel: StorageLevel,
33 | checkpointDir: Option[String],
34 | checkpointInterval: Int): PredCache = {
35 | new PredCache(
36 | curPreds = initPreds,
37 | shrinkage = shrinkage,
38 | storageLevel = storageLevel,
39 | checkpointDir = checkpointDir,
40 | checkpointInterval = checkpointInterval
41 | )
42 | }
43 | }
44 |
45 | class PredCache(
46 | var curPreds: RDD[Double],
47 | shrinkage: Double,
48 | storageLevel: StorageLevel,
49 | checkpointDir: Option[String],
50 | checkpointInterval: Int) {
51 | private var prevPreds: RDD[Double] = null
52 | private var updateCount: Int = 0
53 |
54 | private val checkpointQueue = new mutable.Queue[RDD[Double]]()
55 |
56 | // Persist the initial predictions.
57 | curPreds = curPreds.persist(storageLevel)
58 |
59 | if (checkpointDir.isDefined && curPreds.sparkContext.getCheckpointDir.isEmpty) {
60 | curPreds.sparkContext.setCheckpointDir(checkpointDir.get)
61 | }
62 |
63 | def getRdd: RDD[Double] = curPreds
64 |
65 | def updatePreds[@specialized(Byte, Short) T](
66 | discFeatData: RDD[Array[T]],
67 | tree: GBInternalTree,
68 | featureHandler: DiscretizedFeatureHandler[T]): Unit = {
69 | if (prevPreds != null) {
70 | // Unpersist the previous one if one exists.
71 | prevPreds.unpersist(blocking = true)
72 | }
73 |
74 | prevPreds = curPreds
75 |
76 | // Need to do this since we don't want to serialize this object.
77 | val shk = shrinkage
78 | curPreds = discFeatData.zip(curPreds).map {
79 | case (features, curPred) =>
80 | curPred + shk * tree.predict(features, featureHandler)
81 | }.persist(storageLevel)
82 |
83 | updateCount += 1
84 |
85 | // Handle checkpointing if the directory is not None.
86 | if (curPreds.sparkContext.getCheckpointDir.isDefined &&
87 | (updateCount % checkpointInterval) == 0) {
88 | // See if we can delete previous checkpoints.
89 | var canDelete = true
90 | while (checkpointQueue.size > 1 && canDelete) {
91 | // We can delete the oldest checkpoint iff the next checkpoint actually
92 | // exists in the file system.
93 | if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
94 | val old = checkpointQueue.dequeue()
95 |
96 | // Since the old checkpoint is not deleted by Spark, we'll manually
97 | // delete it here.
98 | val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
99 | println("Deleting a stale PredCache RDD checkpoint at " + old.getCheckpointFile.get)
100 | fs.delete(new Path(old.getCheckpointFile.get), true)
101 | } else {
102 | canDelete = false
103 | }
104 | }
105 |
106 | curPreds.checkpoint()
107 | checkpointQueue.enqueue(curPreds)
108 | }
109 | }
110 |
111 | def close(): Unit = {
112 | // Unpersist and delete all the checkpoints.
113 | curPreds.unpersist(blocking = true)
114 | if (prevPreds != null) {
115 | prevPreds.unpersist(blocking = true)
116 | }
117 |
118 | while (checkpointQueue.nonEmpty) {
119 | val old = checkpointQueue.dequeue()
120 | if (old.getCheckpointFile.isDefined) {
121 | val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
122 | println("Deleting a stale PredCache RDD checkpoint at " + old.getCheckpointFile.get)
123 | fs.delete(new Path(old.getCheckpointFile.get), true)
124 | }
125 | }
126 | }
127 | }
128 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/LossAggregator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss
19 |
20 | /**
21 | * Trait for aggregators used for loss calculations in gradient boosting
22 | * algorithms. The gradient boosting algorithm supports pluggable aggregators.
23 | * The implementations for different loss functions determine the type of
24 | * statistics to aggregate.
25 | */
26 | trait LossAggregator extends Serializable {
27 | /**
28 | * Add a sample to the aggregator.
29 | * @param label Label of the sample.
30 | * @param weight Weight of the sample.
31 | * @param curPred Current prediction
32 | * (should be computed using available trees).
33 | */
34 | def addSamplePoint(
35 | label: Double,
36 | weight: Double,
37 | curPred: Double): Unit
38 |
39 | /**
40 | * Compute the gradient of the sample at the current prediction.
41 | * @param label Label of the sample.
42 | * @param curPred Current prediction
43 | * (should be computed using available trees).
44 | */
45 | def computeGradient(
46 | label: Double,
47 | curPred: Double): Double
48 |
49 | /**
50 | * Merge the aggregated values with another aggregator.
51 | * @param b The other aggregator to merge with.
52 | * @return This.
53 | */
54 | def mergeInPlace(b: LossAggregator): this.type
55 |
56 | /**
57 | * Using the aggregated values, compute deviance.
58 | * @return Deviance.
59 | */
60 | def computeDeviance(): Double
61 |
62 | /**
63 | * Using the aggregated values, compute the initial value.
64 |    * @return Initial value.
65 | */
66 | def computeInitialValue(): Double
67 |
68 | /**
69 | * Using the aggregated values for a particular node,
70 | * compute the estimated node value.
71 | * @return Node estimate.
72 | */
73 | def computeNodeEstimate(): Double
74 | }
75 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/LossFunction.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss
19 |
20 | /**
21 | * Each loss function should implement this trait and pass it onto the gradient
22 | * boosting algorithm. The job is currently mainly to provide proper aggregators
23 | * for computing losses and gradients.
24 | */
25 | trait LossFunction extends Serializable {
26 | /**
27 | * String name of the loss function.
28 | * @return The string name of the loss function.
29 | */
30 | def lossFunctionName: String
31 |
32 | /**
33 | * Create the loss aggregator for this loss function.
34 | * @return
35 | */
36 | def createAggregator: LossAggregator
37 |
38 | /**
39 | * If this loss function is used for categorical labels, this function returns
40 |    * the expected label cardinality. E.g., loss functions like the AdaBoost and
41 |    * logistic losses are used for binary classification, so this should return
42 |    * Some(2). For regression losses like the Gaussian loss, this should return None.
43 | * @return either Some(cardinality) or None.
44 | */
45 | def getLabelCardinality: Option[Int]
46 |
47 | /**
48 | * Whether tree node estimate refinement is possible.
49 | * @return true if node estimates can be refined. false otherwise.
50 | */
51 | def canRefineNodeEstimate: Boolean
52 |
53 | /**
54 | * Convert the raw tree ensemble prediction into a usable form by applying
55 | * the mean function. E.g., this gives the actual regression prediction and/or
56 | * a probability of a class.
57 | * @param rawPred The raw tree ensemble prediction. E.g., this could be
58 | * unbounded negative or positive numbers for
59 | * AdaaBoost/LogLoss/PoissonLoss. We want to return bounded
60 | * numbers or actual count estimate for those losses, for
61 | * instance.
62 | * @return A mean function applied value.
63 | */
64 | def applyMeanFunction(rawPred: Double): Double
65 | }
66 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/AdaBoostLossAggregator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossAggregator
21 | import spark_ml.util.RobustMath
22 |
23 | /**
24 | * The AdaBoost loss aggregator. This is used for binary classifications.
25 |  * Gradient boosting's AdaBoost is an approximation of the original AdaBoost
26 |  * in that the exponential loss is reduced through gradient steps (the original
27 |  * AdaBoost has a different optimization routine).
28 | */
29 | class AdaBoostLossAggregator extends LossAggregator {
30 | private var weightSum: Double = 0.0
31 | private var weightedLabelSum: Double = 0.0
32 | private var weightedLossSum: Double = 0.0
33 | private var weightedNumeratorSum: Double = 0.0
34 |
35 | // If we want to show log loss.
36 | // private var weightedLogLossSum: Double = 0.0
37 |
38 | /**
39 | * Add a sample to the aggregator.
40 | * @param label Label of the sample.
41 | * @param weight Weight of the sample.
42 | * @param curPred Current prediction (e.g., using available trees).
43 | */
44 | def addSamplePoint(
45 | label: Double,
46 | weight: Double,
47 | curPred: Double): Unit = {
48 | val a = 2 * label - 1.0
49 | val sampleLoss = math.exp(-a * curPred)
50 |
51 | weightSum += weight
52 | weightedLabelSum += weight * label
53 | weightedLossSum += weight * sampleLoss
54 | weightedNumeratorSum += weight * a * sampleLoss
55 |
56 | // val prob = 1.0 / (1.0 + math.exp(-2.0 * curPred))
57 | // val logLoss = -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))
58 | // weightedLogLossSum += weight * logLoss
59 | }
60 |
61 | /**
62 | * Compute the gradient.
63 | * @param label Label of the sample.
64 | * @param curPred Current prediction (e.g., using available trees).
65 | * @return The gradient of the sample.
66 | */
67 | def computeGradient(
68 | label: Double,
69 | curPred: Double): Double = {
70 | val a = 2 * label - 1.0
71 | a * math.exp(-a * curPred)
72 | }
73 |
74 | /**
75 | * Merge the aggregated values with another aggregator.
76 | * @param b The other aggregator to merge with.
77 | * @return This.
78 | */
79 | def mergeInPlace(b: LossAggregator): this.type = {
80 | this.weightSum += b.asInstanceOf[AdaBoostLossAggregator].weightSum
81 | this.weightedLabelSum += b.asInstanceOf[AdaBoostLossAggregator].weightedLabelSum
82 | this.weightedLossSum += b.asInstanceOf[AdaBoostLossAggregator].weightedLossSum
83 | this.weightedNumeratorSum += b.asInstanceOf[AdaBoostLossAggregator].weightedNumeratorSum
84 | // this.weightedLogLossSum += b.asInstanceOf[AdaBoostLossAggregator].weightedLogLossSum
85 | this
86 | }
87 |
88 | /**
89 | * Using the aggregated values, compute deviance.
90 | * @return Deviance.
91 | */
92 | def computeDeviance(): Double = {
93 | weightedLossSum / weightSum
94 | // weightedLogLossSum / weightSum
95 | }
96 |
97 | /**
98 | * Using the aggregated values, compute the initial value.
99 |    * @return Initial value.
100 | */
101 | def computeInitialValue(): Double = {
102 | RobustMath.log(weightedLabelSum / (weightSum - weightedLabelSum)) / 2.0
103 | }
104 |
105 | /**
106 | * Using the aggregated values for a particular node, compute the estimated
107 | * node value.
108 | * @return Node estimate.
109 | */
110 | def computeNodeEstimate(): Double = {
111 | // For the adaboost loss (exponential loss), the node estimate is
112 | // approximated as one Newton-Raphson method step's result, as the optimal
113 | // value will be either negative or positive infinities.
114 | weightedNumeratorSum / weightedLossSum
115 | }
116 | }
117 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/AdaBoostLossFunction.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossFunction
21 |
22 | /**
23 | * The AdaBoost loss function. This is used for binary classifications.
24 |  * Gradient boosting's AdaBoost is an approximation of the original AdaBoost
25 |  * in that the exponential loss is reduced through gradient steps (the original
26 |  * AdaBoost has a different optimization routine).
27 | */
28 | class AdaBoostLossFunction extends LossFunction {
29 | private val eps = 1e-15
30 |
31 | def lossFunctionName = "AdaBoost(Exponential)"
32 | def createAggregator = new AdaBoostLossAggregator
33 | def getLabelCardinality: Option[Int] = Some(2)
34 | def canRefineNodeEstimate: Boolean = true
35 |
36 | def applyMeanFunction(rawPred: Double): Double = {
37 | math.min(math.max(1.0 / (1.0 + math.exp(-2.0 * rawPred)), eps), 1.0 - eps)
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
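A small usage sketch of the loss-function interface using the class above (values are illustrative):

    import spark_ml.gradient_boosting.loss.defaults.AdaBoostLossFunction

    val lossFn = new AdaBoostLossFunction
    lossFn.getLabelCardinality                // Some(2): binary classification
    lossFn.applyMeanFunction(0.0)             // 0.5: raw score 0 maps to probability 1 / (1 + exp(-2 * 0))
    val aggregator = lossFn.createAggregator  // an AdaBoostLossAggregator for loss/gradient statistics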
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/GaussianLossAggregator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossAggregator
21 |
22 | /**
23 | * Aggregator for Gaussian Losses.
24 | */
25 | class GaussianLossAggregator extends LossAggregator {
26 | private var weightSum: Double = 0.0
27 | private var weightedSqrLossSum = 0.0
28 | private var weightedLabelSum = 0.0
29 |
30 | /**
31 | * Add a sample to the aggregator.
32 | * @param label Label of the sample.
33 | * @param weight Weight of the sample.
34 | * @param curPred Current prediction (e.g., using available trees).
35 | */
36 | def addSamplePoint(
37 | label: Double,
38 | weight: Double,
39 | curPred: Double): Unit = {
40 | val gradient = computeGradient(label, curPred)
41 | val weightedGradient = weight * gradient
42 | weightSum += weight
43 | weightedSqrLossSum += weightedGradient * gradient
44 | weightedLabelSum += weight * label
45 | }
46 |
47 | /**
48 | * Compute the gaussian gradient.
49 | * @param label Label of the sample.
50 | * @param curPred Current prediction (e.g., using available trees).
51 | * @return The gradient of the sample.
52 | */
53 | def computeGradient(
54 | label: Double,
55 | curPred: Double): Double = {
56 | label - curPred
57 | }
58 |
59 | /**
60 | * Merge the aggregated values with another aggregator.
61 | * @param b The other aggregator to merge with.
62 | * @return This.
63 | */
64 | def mergeInPlace(b: LossAggregator): this.type = {
65 | this.weightSum += b.asInstanceOf[GaussianLossAggregator].weightSum
66 | this.weightedSqrLossSum += b.asInstanceOf[GaussianLossAggregator].weightedSqrLossSum
67 | this.weightedLabelSum += b.asInstanceOf[GaussianLossAggregator].weightedLabelSum
68 | this
69 | }
70 |
71 | /**
72 | * Using the aggregated values, compute deviance.
73 | * @return Deviance.
74 | */
75 | def computeDeviance(): Double = {
76 | weightedSqrLossSum / weightSum
77 | }
78 |
79 | /**
80 | * Using the aggregated values, compute the initial value.
81 |    * @return Initial value.
82 | */
83 | def computeInitialValue(): Double = {
84 | weightedLabelSum / weightSum
85 | }
86 |
87 | /**
88 | * Using the aggregated values for a particular node, compute the estimated
89 | * node value.
90 | * @return Node estimate.
91 | */
92 | def computeNodeEstimate(): Double = ???
93 | }
94 |
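A minimal sketch of the squared-error semantics with made-up values: the deviance is the weighted mean squared residual and the initial value is the weighted mean label. `computeNodeEstimate` is left unimplemented because `GaussianLossFunction.canRefineNodeEstimate` is `false`, so it is never called:

    import spark_ml.gradient_boosting.loss.defaults.GaussianLossAggregator

    object GaussianLossSketch {
      def main(args: Array[String]): Unit = {
        val agg = new GaussianLossAggregator
        // Three unit-weight points, all with a current prediction of 0.0.
        Seq(1.0, 2.0, 3.0).foreach(y => agg.addSamplePoint(label = y, weight = 1.0, curPred = 0.0))
        println(agg.computeDeviance())      // (1 + 4 + 9) / 3 ~= 4.667
        println(agg.computeInitialValue())  // (1 + 2 + 3) / 3 = 2.0
      }
    }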
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/GaussianLossFunction.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossFunction
21 |
22 | /**
23 |  * The Gaussian loss function.
24 | */
25 | class GaussianLossFunction extends LossFunction {
26 | def lossFunctionName = "Gaussian"
27 | def createAggregator = new GaussianLossAggregator
28 | def getLabelCardinality: Option[Int] = None
29 | def canRefineNodeEstimate: Boolean = false
30 |
31 | def applyMeanFunction(rawPred: Double): Double = rawPred
32 | }
33 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/LaplacianLossAggregator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import scala.util.Random
21 |
22 | import spark_ml.gradient_boosting.loss.LossAggregator
23 | import spark_ml.util.{Selection, ReservoirSample}
24 |
25 | /**
26 | * Aggregator for Laplacian Losses.
27 |  * It uses approximate medians (from reservoir samples) for the initial value
28 |  * and node estimates, but that is unlikely to affect accuracy drastically.
29 | */
30 | class LaplacianLossAggregator extends LossAggregator {
31 | private val maxSample = 10000
32 |
33 | private var labelSample = new ReservoirSample(maxSample)
34 | private var labelPredDiffSample = new ReservoirSample(maxSample)
35 |
36 | private var weightSum: Double = 0.0
37 | private var weightedAbsLabelPredDiffSum: Double = 0.0
38 |
39 | // For reservoir sampling.
40 | @transient private var rng: Random = null
41 |
42 | /**
43 | * Add a sample to the aggregator.
44 | * @param label Label of the sample.
45 | * @param weight Weight of the sample.
46 | * @param curPred Current prediction (e.g., using available trees).
47 | */
48 | def addSamplePoint(
49 | label: Double,
50 | weight: Double,
51 | curPred: Double): Unit = {
52 | if (weight > 0.0) {
53 | if (this.rng == null) {
54 | this.rng = new Random()
55 | }
56 |
57 | this.labelSample.doReservoirSampling(label, rng)
58 |
59 | val labelPredDiff = label - curPred
60 | this.labelPredDiffSample.doReservoirSampling(labelPredDiff, rng)
61 |
62 | this.weightSum += weight
63 | this.weightedAbsLabelPredDiffSum += weight * math.abs(labelPredDiff)
64 | }
65 | }
66 |
67 | /**
68 | * Compute the gradient.
69 | * @param label Label of the sample.
70 | * @param curPred Current prediction (e.g., using available trees).
71 | * @return The gradient of the sample.
72 | */
73 | def computeGradient(
74 | label: Double,
75 | curPred: Double): Double = {
76 | math.signum(label - curPred)
77 | }
78 |
79 | /**
80 | * Merge the aggregated values with another aggregator.
81 | * @param b The other aggregator to merge with.
82 | * @return This.
83 | */
84 | def mergeInPlace(b: LossAggregator): this.type = {
85 | if (this.rng == null) {
86 | this.rng = new Random()
87 | }
88 |
89 | this.weightSum += b.asInstanceOf[LaplacianLossAggregator].weightSum
90 | this.weightedAbsLabelPredDiffSum += b.asInstanceOf[LaplacianLossAggregator].weightedAbsLabelPredDiffSum
91 | // Now merge samples (distributed reservoir sampling).
92 | this.labelSample = ReservoirSample.mergeReservoirSamples(
93 | this.labelSample,
94 | b.asInstanceOf[LaplacianLossAggregator].labelSample,
95 | maxSample,
96 | this.rng
97 | )
98 | this.labelPredDiffSample = ReservoirSample.mergeReservoirSamples(
99 | this.labelPredDiffSample,
100 | b.asInstanceOf[LaplacianLossAggregator].labelPredDiffSample,
101 | maxSample,
102 | this.rng
103 | )
104 | this
105 | }
106 |
107 | /**
108 | * Using the aggregated values, compute deviance.
109 | * @return Deviance.
110 | */
111 | def computeDeviance(): Double = {
112 | weightedAbsLabelPredDiffSum / weightSum
113 | }
114 |
115 | /**
116 | * Using the aggregated values, compute the initial value.
117 |    * @return Initial value.
118 | */
119 | def computeInitialValue(): Double = {
120 | if (this.rng == null) {
121 | this.rng = new Random()
122 | }
123 |
124 | // Get the initial value by computing the label median.
125 | Selection.quickSelect(
126 | this.labelSample.sample,
127 | 0,
128 | this.labelSample.numSamplePoints,
129 | this.labelSample.numSamplePoints / 2,
130 | rng = this.rng
131 | )
132 | }
133 |
134 | /**
135 | * Using the aggregated values for a particular node, compute the estimated
136 | * node value.
137 | * @return Node estimate.
138 | */
139 | def computeNodeEstimate(): Double = {
140 | if (this.rng == null) {
141 | this.rng = new Random()
142 | }
143 |
144 | // Get the node estimate by computing the label pred diff median.
145 | Selection.quickSelect(
146 | this.labelPredDiffSample.sample,
147 | 0,
148 | this.labelPredDiffSample.numSamplePoints,
149 | this.labelPredDiffSample.numSamplePoints / 2,
150 | rng = this.rng
151 | )
152 | }
153 | }
154 |
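A small sketch of the median-based behaviour with made-up values; because far fewer than 10000 points are added, the reservoir holds every label and the "approximate" medians are computed over the full data here:

    import spark_ml.gradient_boosting.loss.defaults.LaplacianLossAggregator

    object LaplacianLossSketch {
      def main(args: Array[String]): Unit = {
        val agg = new LaplacianLossAggregator
        // Unit-weight labels 1..9, all with a current prediction of 0.0.
        (1 to 9).foreach(y => agg.addSamplePoint(label = y.toDouble, weight = 1.0, curPred = 0.0))
        println(agg.computeDeviance())      // mean absolute residual = 45 / 9 = 5.0
        println(agg.computeInitialValue())  // roughly the median of the labels
        println(agg.computeNodeEstimate())  // roughly the median of (label - curPred)
      }
    }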
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/LaplacianLossFunction.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossFunction
21 |
22 | /**
23 | * The Laplacian loss function.
24 | */
25 | class LaplacianLossFunction extends LossFunction {
26 | def lossFunctionName = "Laplacian"
27 | def createAggregator = new LaplacianLossAggregator
28 | def getLabelCardinality: Option[Int] = None
29 | def canRefineNodeEstimate: Boolean = true
30 |
31 | def applyMeanFunction(rawPred: Double): Double = rawPred
32 | }
33 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/LogLossAggregator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossAggregator
21 | import spark_ml.util.RobustMath
22 |
23 | /**
24 | * The log loss aggregator. This is a binary logistic regression aggregator.
25 | */
26 | class LogLossAggregator extends LossAggregator {
27 | private var weightSum: Double = 0.0
28 | private var weightedLabelSum: Double = 0.0
29 | private var weightedLossSum: Double = 0.0
30 | private var weightedProbSum: Double = 0.0
31 | private var weightedProbSquareSum: Double = 0.0
32 |
33 | private val eps = 1e-15
34 |
35 | /**
36 | * Add a sample to the aggregator.
37 | * @param label Label of the sample.
38 | * @param weight Weight of the sample.
39 | * @param curPred Current prediction (e.g., using available trees).
40 | */
41 | def addSamplePoint(
42 | label: Double,
43 | weight: Double,
44 | curPred: Double): Unit = {
45 | val prob = math.min(math.max(1.0 / (1.0 + math.exp(-curPred)), eps), 1.0 - eps)
46 | val logLoss = -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))
47 | weightSum += weight
48 | weightedLabelSum += weight * label
49 | weightedLossSum += weight * logLoss
50 | weightedProbSum += weight * prob
51 | weightedProbSquareSum += weight * prob * prob
52 | }
53 |
54 | /**
55 | * Compute the gradient.
56 | * @param label Label of the sample.
57 | * @param curPred Current prediction (e.g., using available trees).
58 | * @return The gradient of the sample.
59 | */
60 | def computeGradient(
61 | label: Double,
62 | curPred: Double): Double = {
63 | val prob = math.min(math.max(1.0 / (1.0 + math.exp(-curPred)), eps), 1.0 - eps)
64 | label - prob
65 | }
66 |
67 | /**
68 | * Merge the aggregated values with another aggregator.
69 | * @param b The other aggregator to merge with.
70 | * @return This.
71 | */
72 | def mergeInPlace(b: LossAggregator): this.type = {
73 | this.weightSum += b.asInstanceOf[LogLossAggregator].weightSum
74 | this.weightedLabelSum += b.asInstanceOf[LogLossAggregator].weightedLabelSum
75 | this.weightedLossSum += b.asInstanceOf[LogLossAggregator].weightedLossSum
76 | this.weightedProbSum += b.asInstanceOf[LogLossAggregator].weightedProbSum
77 | this.weightedProbSquareSum += b.asInstanceOf[LogLossAggregator].weightedProbSquareSum
78 | this
79 | }
80 |
81 | /**
82 | * Using the aggregated values, compute deviance.
83 | * @return Deviance.
84 | */
85 | def computeDeviance(): Double = {
86 | weightedLossSum / weightSum
87 | }
88 |
89 | /**
90 | * Using the aggregated values, compute the initial value.
91 |    * @return Initial value.
92 | */
93 | def computeInitialValue(): Double = {
94 | RobustMath.log(weightedLabelSum / (weightSum - weightedLabelSum))
95 | }
96 |
97 | /**
98 | * Using the aggregated values for a particular node, compute the estimated
99 | * node value.
100 | * @return Node estimate.
101 | */
102 | def computeNodeEstimate(): Double = {
103 |     // For the log loss, the node estimate is approximated by the result of a
104 |     // single Newton-Raphson step, since the exact optimum can be negative or
105 |     // positive infinity.
106 | (weightedLabelSum - weightedProbSum) / (weightedProbSum - weightedProbSquareSum)
107 | }
108 | }
109 |
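To make the Newton-Raphson step concrete: `weightedProbSum - weightedProbSquareSum` equals the weighted sum of p·(1 - p), so the node estimate is Σ w·(y - p) / Σ w·p·(1 - p). A tiny numeric check with made-up points:

    import spark_ml.gradient_boosting.loss.defaults.LogLossAggregator

    object LogLossNodeEstimateSketch {
      def main(args: Array[String]): Unit = {
        val agg = new LogLossAggregator
        // Two positives and one negative, all with a raw prediction of 0.0 (so p = 0.5).
        agg.addSamplePoint(label = 1.0, weight = 1.0, curPred = 0.0)
        agg.addSamplePoint(label = 1.0, weight = 1.0, curPred = 0.0)
        agg.addSamplePoint(label = 0.0, weight = 1.0, curPred = 0.0)
        // Numerator: 0.5 + 0.5 - 0.5 = 0.5; denominator: 3 * 0.5 * 0.5 = 0.75.
        println(agg.computeNodeEstimate())  // ~0.667
        println(agg.computeInitialValue())  // ln(2 / 1) ~= 0.693
      }
    }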
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/LogLossFunction.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossFunction
21 |
22 | /**
23 |  * The log loss function (also known as cross-entropy). This is equivalent
24 |  * to logistic regression.
25 | */
26 | class LogLossFunction extends LossFunction {
27 | private val eps = 1e-15
28 |
29 | def lossFunctionName = "LogLoss(LogisticRegression)"
30 | def createAggregator = new LogLossAggregator
31 | def getLabelCardinality: Option[Int] = Some(2)
32 | def canRefineNodeEstimate: Boolean = true
33 |
34 | def applyMeanFunction(rawPred: Double): Double = {
35 | math.min(math.max(1.0 / (1.0 + math.exp(-rawPred)), eps), 1.0 - eps)
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/PoissonLossAggregator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossAggregator
21 | import spark_ml.util.RobustMath
22 |
23 | /**
24 | * Aggregator for Poisson Losses.
25 | */
26 | class PoissonLossAggregator extends LossAggregator {
27 | private var weightSum: Double = 0.0
28 | private var weightedLabelPredMulSum = 0.0
29 | private var weightedExpPredSum: Double = 0.0
30 | private var weightedLabelSum: Double = 0.0
31 |
32 | /**
33 | * Add a sample to the aggregator.
34 | * @param label Label of the sample.
35 | * @param weight Weight of the sample.
36 | * @param curPred Current prediction (e.g., using available trees).
37 | */
38 | def addSamplePoint(
39 | label: Double,
40 | weight: Double,
41 | curPred: Double): Unit = {
42 | val labelPredMul = label * curPred
43 | val expPred = RobustMath.exp(curPred)
44 | weightSum += weight
45 | weightedLabelPredMulSum += weight * labelPredMul
46 | weightedExpPredSum += weight * expPred
47 | weightedLabelSum += weight * label
48 | }
49 |
50 | /**
51 | * Compute the gradient.
52 | * @param label Label of the sample.
53 | * @param curPred Current prediction (e.g., using available trees).
54 | * @return The gradient of the sample.
55 | */
56 | def computeGradient(
57 | label: Double,
58 | curPred: Double): Double = {
59 | val expPred = RobustMath.exp(curPred)
60 | label - expPred
61 | }
62 |
63 | /**
64 | * Merge the aggregated values with another aggregator.
65 | * @param b The other aggregator to merge with.
66 | * @return This.
67 | */
68 | def mergeInPlace(b: LossAggregator): this.type = {
69 | this.weightSum += b.asInstanceOf[PoissonLossAggregator].weightSum
70 | this.weightedLabelSum += b.asInstanceOf[PoissonLossAggregator].weightedLabelSum
71 | this.weightedLabelPredMulSum += b.asInstanceOf[PoissonLossAggregator].weightedLabelPredMulSum
72 | this.weightedExpPredSum += b.asInstanceOf[PoissonLossAggregator].weightedExpPredSum
73 | this
74 | }
75 |
76 | /**
77 | * Using the aggregated values, compute deviance.
78 | * @return Deviance.
79 | */
80 | def computeDeviance(): Double = {
81 | -2.0 * (weightedLabelPredMulSum - weightedExpPredSum) / weightSum
82 | }
83 |
84 | /**
85 | * Using the aggregated values, compute the initial value.
86 |    * @return Initial value.
87 | */
88 | def computeInitialValue(): Double = {
89 | RobustMath.log(weightedLabelSum / weightSum)
90 | }
91 |
92 | /**
93 | * Using the aggregated values for a particular node, compute the estimated
94 | * node value.
95 | * @return Node estimate.
96 | */
97 | def computeNodeEstimate(): Double = {
98 | RobustMath.log(weightedLabelSum / weightedExpPredSum)
99 | }
100 | }
101 |
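A brief sketch of the log-link behaviour with made-up counts: the initial value is the log of the weighted mean label, and the node estimate is the log of the ratio between the weighted label sum and the weighted sum of exp(prediction):

    import spark_ml.gradient_boosting.loss.defaults.PoissonLossAggregator

    object PoissonLossSketch {
      def main(args: Array[String]): Unit = {
        val agg = new PoissonLossAggregator
        // Count labels 2, 3 and 7 with unit weights; raw predictions of 0.0 mean a rate of exp(0) = 1.
        Seq(2.0, 3.0, 7.0).foreach(y => agg.addSamplePoint(label = y, weight = 1.0, curPred = 0.0))
        println(agg.computeInitialValue())  // ln(12 / 3) = ln(4) ~= 1.386
        println(agg.computeNodeEstimate())  // ln(12 / 3) as well, since exp(0) = 1 for every point
      }
    }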
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/PoissonLossFunction.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossFunction
21 |
22 | /**
23 |  * The Poisson loss function.
24 | */
25 | class PoissonLossFunction extends LossFunction {
26 | private val expPredLowerLimit = math.exp(-19.0)
27 | private val expPredUpperLimit = math.exp(19.0)
28 |
29 | def lossFunctionName = "Poisson"
30 | def createAggregator = new PoissonLossAggregator
31 | def getLabelCardinality: Option[Int] = None
32 | def canRefineNodeEstimate: Boolean = true
33 |
34 | def applyMeanFunction(rawPred: Double): Double = {
35 | math.min(math.max(math.exp(rawPred), expPredLowerLimit), expPredUpperLimit)
36 | }
37 | }
38 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/TruncHingeLossAggregator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossAggregator
21 | import spark_ml.util.RobustMath
22 |
23 | class TruncHingeLossAggregator(maxProb: Double, eps: Double) extends LossAggregator {
24 | private var weightSum: Double = 0.0
25 | private var weightedLabelSum: Double = 0.0
26 | private var weightedLossSum: Double = 0.0
27 |
28 | private val minProb = 1.0 - maxProb
29 | private val maxLinearValue = RobustMath.log(maxProb / minProb)
30 |
31 | /**
32 | * Add a sample to the aggregator.
33 | * @param label Label of the sample.
34 | * @param weight Weight of the sample.
35 | * @param curPred Current prediction (e.g., using available trees).
36 | */
37 | def addSamplePoint(
38 | label: Double,
39 | weight: Double,
40 | curPred: Double): Unit = {
41 | val positiveLoss = math.min(math.max(-(curPred - eps), 0.0), maxLinearValue)
42 | val negativeLoss = math.min(math.max(curPred + eps, 0.0), maxLinearValue)
43 | val loss = label * positiveLoss + (1.0 - label) * negativeLoss
44 | weightSum += weight
45 | weightedLabelSum += weight * label
46 | weightedLossSum += weight * loss
47 | }
48 |
49 | /**
50 | * Compute the gradient.
51 | * @param label Label of the sample.
52 | * @param curPred Current prediction (e.g., using available trees).
53 | * @return The gradient of the sample.
54 | */
55 | def computeGradient(
56 | label: Double,
57 | curPred: Double): Double = {
58 | val positiveLoss = math.max(-(curPred - eps), 0.0)
59 | val positiveLossGradient =
60 | if (positiveLoss == 0.0 || positiveLoss > maxLinearValue) {
61 | 0.0
62 | } else {
63 | 1.0
64 | }
65 | val negativeLoss = math.max(curPred + eps, 0.0)
66 | val negativeLossGradient =
67 | if (negativeLoss == 0.0 || negativeLoss > maxLinearValue) {
68 | 0.0
69 | } else {
70 | -1.0
71 | }
72 | label * positiveLossGradient + (1.0 - label) * negativeLossGradient
73 | }
74 |
75 | /**
76 | * Merge the aggregated values with another aggregator.
77 | * @param b The other aggregator to merge with.
78 | * @return This.
79 | */
80 | def mergeInPlace(b: LossAggregator): this.type = {
81 | this.weightSum += b.asInstanceOf[TruncHingeLossAggregator].weightSum
82 | this.weightedLabelSum += b.asInstanceOf[TruncHingeLossAggregator].weightedLabelSum
83 | this.weightedLossSum += b.asInstanceOf[TruncHingeLossAggregator].weightedLossSum
84 | this
85 | }
86 |
87 | /**
88 | * Using the aggregated values, compute deviance.
89 | * @return Deviance.
90 | */
91 | def computeDeviance(): Double = {
92 | weightedLossSum / weightSum
93 | }
94 |
95 | /**
96 | * Using the aggregated values, compute the initial value.
97 |    * @return Initial value.
98 | */
99 | def computeInitialValue(): Double = {
100 | RobustMath.log(weightedLabelSum / (weightSum - weightedLabelSum))
101 | }
102 |
103 | /**
104 | * Using the aggregated values for a particular node, compute the estimated
105 | * node value.
106 | * @return Node estimate.
107 | */
108 | def computeNodeEstimate(): Double = ???
109 | }
110 |
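An illustrative sketch of the truncated-hinge gradient, using the same maxProb = 0.9 and eps = 0.12 defaults that `TruncHingeLossFunction` sets: a positive-label point contributes a gradient of 1 only while its raw prediction sits inside the truncation band, and 0 once it is either confidently correct or past the truncation point:

    import spark_ml.gradient_boosting.loss.defaults.TruncHingeLossAggregator

    object TruncHingeGradientSketch {
      def main(args: Array[String]): Unit = {
        // maxLinearValue = ln(0.9 / 0.1) ~= 2.197, so for a positive label the gradient
        // is 1.0 roughly when curPred lies in (0.12 - 2.197, 0.12) and 0.0 otherwise.
        val agg = new TruncHingeLossAggregator(maxProb = 0.9, eps = 0.12)
        Seq(-5.0, -1.0, 0.0, 0.5).foreach { pred =>
          println(s"curPred=$pred -> gradient=${agg.computeGradient(label = 1.0, curPred = pred)}")
        }
        // Expected: 0.0 (beyond the truncation), 1.0, 1.0, 0.0 (already past the margin).
      }
    }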
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/gradient_boosting/loss/defaults/TruncHingeLossFunction.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.gradient_boosting.loss.defaults
19 |
20 | import spark_ml.gradient_boosting.loss.LossFunction
21 |
22 | class TruncHingeLossFunction extends LossFunction {
23 | private var maxProb = 0.0
24 | private var eps = 0.0
25 | private var minProb = 0.0
26 | private var maxLinearValue = 0.0
27 |
28 | // Default values.
29 | setMaxProb(0.9)
30 | setEps(0.12)
31 |
32 | def setMaxProb(maxp: Double): Unit = {
33 | maxProb = maxp
34 | minProb = 1.0 - maxProb
35 | maxLinearValue = math.log(maxProb / minProb)
36 | }
37 |
38 | def setEps(e: Double): Unit = {
39 | eps = e
40 | }
41 |
42 | def lossFunctionName = "TruncHingeLoss"
43 | def createAggregator = new TruncHingeLossAggregator(maxProb, eps)
44 | def getLabelCardinality: Option[Int] = Some(2)
45 | def canRefineNodeEstimate: Boolean = false
46 |
47 | def applyMeanFunction(rawPred: Double): Double = {
48 | math.min(math.max(1.0 / (1.0 + math.exp(-rawPred)), minProb), maxProb)
49 | }
50 | }
51 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/model/gb/GradientBoostedTrees.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.model.gb
19 |
20 | import spark_ml.gradient_boosting.loss.LossFunction
21 | import spark_ml.model.DecisionTree
22 | import spark_ml.transformation.ColumnTransformer
23 |
24 | case class GradientBoostedTreesDefault(
25 | lossFunctionClassName: String,
26 | labelTransformer: ColumnTransformer,
27 | featureTransformers: Array[ColumnTransformer],
28 | labelName: String,
29 | labelIsCat: Boolean,
30 | featureNames: Array[String],
31 | featureIsCat: Array[Boolean],
32 | sortedVarImportance: Seq[(String, java.lang.Double)],
33 | shrinkage: Double,
34 | initValue: Double,
35 | decisionTrees: Array[DecisionTree],
36 | optimalTreeCnt: Option[java.lang.Integer],
37 | trainingDevianceHistory: Seq[java.lang.Double],
38 | validationDevianceHistory: Option[Seq[java.lang.Double]]) extends GradientBoostedTrees {
39 |
40 | def lossFunction: LossFunction = {
41 | Class.forName(lossFunctionClassName).newInstance().asInstanceOf[LossFunction]
42 | }
43 |
44 | def numTrees = decisionTrees.length
45 |
46 | /**
47 | * The prediction is done on raw features. Internally, the model should
48 | * transform them to proper forms before predicting.
49 | * @param rawFeatures Raw features.
50 | * @param useOptimalTreeCnt A flag to indicate that we want to use optimal
51 | * number of trees.
52 | * @return The predicted value.
53 | */
54 | def predict(rawFeatures: Seq[Any], useOptimalTreeCnt: Boolean): Double = {
55 | // Transform the raw features first.
56 | val transformedFeatures =
57 | featureTransformers.zip(rawFeatures).map {
58 | case (featTransformer, rawFeatVal) =>
59 | if (rawFeatVal == null) {
60 | featTransformer.transform(null)
61 | } else {
62 | featTransformer.transform(rawFeatVal.toString)
63 | }
64 | }
65 | val predictedValue =
66 | (
67 | if (useOptimalTreeCnt && optimalTreeCnt.isDefined) {
68 | decisionTrees.slice(0, optimalTreeCnt.get)
69 | } else {
70 | decisionTrees
71 | }
72 | ).foldLeft(initValue) {
73 | case (curPred, tree) =>
74 | curPred + shrinkage * tree.predict(transformedFeatures)
75 | }
76 |
77 | lossFunction.applyMeanFunction(predictedValue)
78 | }
79 | }
80 |
81 | /**
82 |  * A public model for trained gradient boosted trees should extend this
83 |  * trait.
84 | */
85 | trait GradientBoostedTrees extends Serializable {
86 | def lossFunction: LossFunction
87 | def initValue: Double
88 | def numTrees: Int
89 | def optimalTreeCnt: Option[java.lang.Integer]
90 |
91 | /**
92 | * Sorted variable importance.
93 | * @return Sorted variable importance.
94 | */
95 | def sortedVarImportance: Seq[(String, java.lang.Double)]
96 |
97 | /**
98 | * The prediction is done on raw features. Internally, the model should
99 | * transform them to proper forms before predicting.
100 | * @param rawFeatures Raw features.
101 | * @param useOptimalTreeCnt A flag to indicate that we want to use optimal
102 | * number of trees.
103 | * @return The predicted value.
104 | */
105 | def predict(rawFeatures: Seq[Any], useOptimalTreeCnt: Boolean): Double
106 | }
107 |
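To make the prediction rule in `predict` concrete, here is a self-contained restatement of the fold with made-up numbers (they do not come from a real model): the raw score starts at the initial value, each tree's output is added after scaling by the shrinkage, and the loss function's mean function is applied at the end. For a binary-classification loss that last step is a sigmoid:

    object BoostedPredictionSketch {
      def main(args: Array[String]): Unit = {
        val initValue   = -0.4                  // hypothetical initial raw score
        val shrinkage   = 0.1                   // hypothetical learning rate
        val treeOutputs = Seq(0.8, 0.5, -0.2)   // hypothetical per-tree raw predictions

        val rawScore = treeOutputs.foldLeft(initValue) {
          case (curPred, treeOutput) => curPred + shrinkage * treeOutput
        }
        // A log-loss model's applyMeanFunction would turn this raw score into a probability.
        val probability = 1.0 / (1.0 + math.exp(-rawScore))
        println(s"raw=$rawScore, prob=$probability")
      }
    }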
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/model/gb/GradientBoostedTreesFactory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.model.gb
19 |
20 | import scala.collection.mutable
21 |
22 | import spark_ml.discretization.Bins
23 | import spark_ml.model.{DecisionTree, DecisionTreeNode, DecisionTreeUtil}
24 | import spark_ml.transformation.ColumnTransformer
25 |
26 | /**
27 | * The trainer will store the trained model using the internal types defined
28 | * in this file. If a developer wants to return a customized GBT type as a
29 |  * result, then he/she can implement this factory trait and pass it to the
30 | * trainer.
31 | */
32 | trait GradientBoostedTreesFactory {
33 | /**
34 | * The GB tree ensemble factory needs to know how to transform label and
35 | * features from potentially categorical/string values to enumerated numeric
36 | * values. E.g. the final model might use enumerated categorical values to
37 | * perform predictions.
38 | * @param labelTransformer Label transformer.
39 | * @param featureTransformers Feature transformers.
40 | */
41 | def setColumnTransformers(
42 | labelTransformer: ColumnTransformer,
43 | featureTransformers: Array[ColumnTransformer]
44 | ): Unit
45 |
46 | /**
47 | * The GB tree ensemble factory also needs to know label/feature names and
48 | * types.
49 | * @param labelName Label name.
50 | * @param labelIsCat Whether the label is categorical.
51 | * @param featureNames Feature names.
52 | * @param featureIsCat Whether the individual features are categorical.
53 | */
54 | def setColumnNamesAndTypes(
55 | labelName: String,
56 | labelIsCat: Boolean,
57 | featureNames: Array[String],
58 | featureIsCat: Array[Boolean]
59 | ): Unit
60 |
61 | /**
62 | * Set the optimal tree count, as determined through validations.
63 | * @param optimalTreeCnt The optimal tree count.
64 | */
65 | def setOptimalTreeCnt(optimalTreeCnt: Int): Unit
66 |
67 | /**
68 | * Set the training deviance history.
69 | * @param trainingDevianceHistory Training deviance history.
70 | */
71 | def setTrainingDevianceHistory(
72 | trainingDevianceHistory: mutable.ListBuffer[Double]
73 | ): Unit
74 |
75 | /**
76 | * Set the validation deviance history.
77 | * @param validationDevianceHistory Validation deviance history.
78 | */
79 | def setValidationDevianceHistory(
80 | validationDevianceHistory: mutable.ListBuffer[Double]
81 | ): Unit
82 |
83 | /**
84 | * Set the feature bins.
85 | * @param featureBins Feature bins.
86 | */
87 | def setFeatureBins(featureBins: Array[Bins]): Unit
88 |
89 | /**
90 | * Create a final model that incorporates all the transformations and
91 | * discretizations and can predict on the raw feature values.
92 | * @param store The store that contains internally trained models.
93 | * @return A final GBT model.
94 | */
95 | def createGradientBoostedTrees(
96 | store: GradientBoostedTreesStore
97 | ): GradientBoostedTrees
98 | }
99 |
100 | /**
101 | * A default implementation of the GBT factory.
102 | */
103 | class GradientBoostedTreesFactoryDefault extends GradientBoostedTreesFactory {
104 | var labelTransformer: Option[ColumnTransformer] = None
105 | var featureTransformers: Option[Array[ColumnTransformer]] = None
106 | var labelName: Option[String] = None
107 | var labelIsCat: Option[Boolean] = None
108 | var featureNames: Option[Array[String]] = None
109 | var featureIsCat: Option[Array[Boolean]] = None
110 | var optimalTreeCnt: Option[Int] = None
111 | var trainingDevianceHistory: Option[mutable.ListBuffer[Double]] = None
112 | var validationDevianceHistory: Option[mutable.ListBuffer[Double]] = None
113 | var featureBins: Option[Array[Bins]] = None
114 |
115 | def setColumnTransformers(
116 | labelTransformer: ColumnTransformer,
117 | featureTransformers: Array[ColumnTransformer]
118 | ): Unit = {
119 | this.labelTransformer = Some(labelTransformer)
120 | this.featureTransformers = Some(featureTransformers)
121 | }
122 |
123 | def setColumnNamesAndTypes(
124 | labelName: String,
125 | labelIsCat: Boolean,
126 | featureNames: Array[String],
127 | featureIsCat: Array[Boolean]
128 | ): Unit = {
129 | this.labelName = Some(labelName)
130 | this.labelIsCat = Some(labelIsCat)
131 | this.featureNames = Some(featureNames)
132 | this.featureIsCat = Some(featureIsCat)
133 | }
134 |
135 | def setOptimalTreeCnt(optimalTreeCnt: Int): Unit = {
136 | this.optimalTreeCnt = Some(optimalTreeCnt)
137 | }
138 |
139 | def setTrainingDevianceHistory(
140 | trainingDevianceHistory: mutable.ListBuffer[Double]
141 | ): Unit = {
142 | this.trainingDevianceHistory = Some(trainingDevianceHistory)
143 | }
144 |
145 | def setValidationDevianceHistory(
146 | validationDevianceHistory: mutable.ListBuffer[Double]
147 | ): Unit = {
148 | this.validationDevianceHistory = Some(validationDevianceHistory)
149 | }
150 |
151 | def setFeatureBins(featureBins: Array[Bins]): Unit = {
152 | this.featureBins = Some(featureBins)
153 | }
154 |
155 | def createGradientBoostedTrees(store: GradientBoostedTreesStore): GradientBoostedTrees = {
156 | val numFeatures = featureNames.get.length
157 | val featureImportance = Array.fill[Double](numFeatures)(0.0)
158 | val decisionTrees = store.trees.map {
159 | case (internalTree) =>
160 | val treeNodes = mutable.Map[java.lang.Integer, DecisionTreeNode]()
161 | val _ =
162 | DecisionTreeUtil.createDecisionTreeNode(
163 | internalTree.nodes(1),
164 | internalTree.nodes,
165 | featureImportance,
166 | this.featureBins.get,
167 | treeNodes
168 | )
169 | DecisionTree(treeNodes.toMap, internalTree.nodes.size)
170 | }.toArray
171 |
172 | GradientBoostedTreesDefault(
173 | lossFunctionClassName = store.lossFunction.getClass.getCanonicalName,
174 | labelTransformer = this.labelTransformer.get,
175 | featureTransformers = this.featureTransformers.get,
176 | labelName = this.labelName.get,
177 | labelIsCat = this.labelIsCat.get,
178 | featureNames = this.featureNames.get,
179 | featureIsCat = this.featureIsCat.get,
180 | sortedVarImportance =
181 | scala.util.Sorting.stableSort(
182 | featureNames.get.zip(featureImportance.map(new java.lang.Double(_))).toSeq,
183 | // We want to sort in a descending importance order.
184 | (e1: (String, java.lang.Double), e2: (String, java.lang.Double)) => e1._2 > e2._2
185 | ),
186 | shrinkage = store.shrinkage,
187 | initValue = store.initVal,
188 | decisionTrees = decisionTrees,
189 | optimalTreeCnt = this.optimalTreeCnt.map(new java.lang.Integer(_)),
190 | trainingDevianceHistory =
191 | this.trainingDevianceHistory.get.map(new java.lang.Double(_)).toSeq,
192 | validationDevianceHistory =
193 | this.validationDevianceHistory.map(vdh => vdh.map(new java.lang.Double(_)).toSeq)
194 | )
195 | }
196 | }
197 |
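A hedged sketch of how the default factory is populated before the final model is built. The label name and categorical flag are placeholders; in practice the trainer supplies the transformers, bins, deviance history and store itself:

    import scala.collection.mutable

    import spark_ml.discretization.Bins
    import spark_ml.model.gb.{GradientBoostedTrees, GradientBoostedTreesFactoryDefault, GradientBoostedTreesStore}
    import spark_ml.transformation.ColumnTransformer

    object GbtFactoryWiringSketch {
      def buildModel(
          labelTransformer: ColumnTransformer,
          featureTransformers: Array[ColumnTransformer],
          featureNames: Array[String],
          featureIsCat: Array[Boolean],
          featureBins: Array[Bins],
          trainingDeviance: mutable.ListBuffer[Double],
          store: GradientBoostedTreesStore): GradientBoostedTrees = {
        val factory = new GradientBoostedTreesFactoryDefault
        factory.setColumnTransformers(labelTransformer, featureTransformers)
        factory.setColumnNamesAndTypes(
          labelName = "label",   // placeholder name
          labelIsCat = true,     // placeholder type
          featureNames = featureNames,
          featureIsCat = featureIsCat
        )
        factory.setFeatureBins(featureBins)
        factory.setTrainingDevianceHistory(trainingDeviance)
        // The optimal tree count and validation deviance history are optional and omitted here.
        factory.createGradientBoostedTrees(store)
      }
    }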
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/model/gb/GradientBoostedTreesStore.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.model.gb
19 |
20 | import scala.collection.mutable
21 |
22 | import spark_ml.gradient_boosting.loss.{LossAggregator, LossFunction}
23 | import spark_ml.tree_ensembles._
24 | import spark_ml.util.{DiscretizedFeatureHandler, MapWithSequentialIntKeys}
25 | import spire.implicits._
26 |
27 | /**
28 | * Used internally to store tree data while GB is training. This is not the
29 | * final trained model.
30 | */
31 | class GBInternalTree extends Serializable {
32 | val nodes = mutable.Map[Int, NodeInfo]()
33 | val nodeAggregators = new MapWithSequentialIntKeys[LossAggregator](
34 | initCapacity = 256
35 | )
36 |
37 | /**
38 | * Add a new node to the tree.
39 | * @param nodeInfo node to add.
40 | */
41 | def addNode(nodeInfo: NodeInfo): Unit = {
42 | // Sanity check !
43 | assert(
44 | !nodes.contains(nodeInfo.nodeId),
45 | "A tree node with the Id " + nodeInfo.nodeId + " already exists."
46 | )
47 |
48 | nodes.put(nodeInfo.nodeId, nodeInfo)
49 | }
50 |
51 | /**
52 | * Initialize aggregators for existing nodes. The aggregators are used to
53 | * aggregate samples per node to update node predictions, instead of using
54 | * predictions from CART training.
55 | */
56 | def initNodeFinetuners(lossFunction: LossFunction): Unit = {
57 | val (startNodeId, endNodeId) = (nodes.keys.min, nodes.keys.max)
58 | // Sanity checks.
59 | assert(
60 | startNodeId == 1,
61 | "The starting root node must always have the Id 1. But we have " + startNodeId
62 | )
63 | assert(
64 | nodes.size == (endNodeId - startNodeId + 1),
65 | "The number of nodes should equal " + (endNodeId - startNodeId + 1).toString +
66 | " but instead, we have " + nodes.size + " nodes."
67 | )
68 |
69 | cfor(startNodeId)(_ <= endNodeId, _ + 1)(
70 | nodeId => nodeAggregators.put(nodeId, lossFunction.createAggregator)
71 | )
72 | }
73 |
74 | /**
75 | * Add the given sample point to all the nodes that match the point.
76 | * @param samplePoint Sample point to add. This includes the label, the
77 | * current prediction and the features.
78 | * @param weight Weight of the sample point.
79 | * @param featureHandler Feature handler.
80 | * @tparam T Type of the discretized feature.
81 | */
82 | def addSamplePointToMatchingNodes[@specialized(Byte, Short) T](
83 | samplePoint: ((Double, Double), Array[T]),
84 | weight: Double,
85 | featureHandler: DiscretizedFeatureHandler[T]): Unit = {
86 | val ((label, curPred), features) = samplePoint
87 | // First add the point to the root node.
88 | nodeAggregators.get(1).addSamplePoint(
89 | label = label,
90 | weight = weight,
91 | curPred = curPred
92 | )
93 | var curNode = nodes(1)
94 |
95 | // Then add the point to all the matching descendants.
96 | while (curNode.splitInfo.nonEmpty) {
97 | val splitInfo = curNode.splitInfo.get
98 | val featId = splitInfo.featureId
99 | val binId = featureHandler.convertToInt(features(featId))
100 | val childId = splitInfo.chooseChildNode(binId).nodeId
101 |
102 | // We can't directly use the child node contained in splitInfo.
103 | // That child node is not the same as the one contained in nodes and
104 | // is incomplete.
105 | curNode = nodes(childId)
106 | val aggregator = nodeAggregators.get(childId)
107 | aggregator.addSamplePoint(
108 | label = label,
109 | weight = weight,
110 | curPred = curPred
111 | )
112 | }
113 | }
114 |
115 | /**
116 | * Update a node's prediction with the given one.
117 | * @param nodeId Id of the node whose prediction we want to update.
118 | * @param newPrediction The new prediction for the node.
119 | */
120 | def updateNodePrediction(nodeId: Int, newPrediction: Double): Unit = {
121 | nodes(nodeId).prediction = newPrediction
122 | }
123 |
124 | /**
125 | * Predict on the given features.
126 | * @param features Features.
127 | * @tparam T Type of features.
128 | * @return Prediction result.
129 | */
130 | def predict[@specialized(Byte, Short) T](
131 | features: Array[T],
132 | featureHandler: DiscretizedFeatureHandler[T]
133 | ): Double = {
134 | var curNode = nodes(1)
135 | while (curNode.splitInfo.nonEmpty) {
136 | val splitInfo = curNode.splitInfo.get
137 | val featId = splitInfo.featureId
138 | val binId = featureHandler.convertToInt(features(featId))
139 |
140 | // We can't directly use the child node contained in splitInfo.
141 | // That child node is not the same as the one contained in nodes and
142 | // is incomplete.
143 | val childId = splitInfo.chooseChildNode(binId).nodeId
144 | curNode = nodes(childId)
145 | }
146 | curNode.prediction
147 | }
148 |
149 | /**
150 |    * Build an ASCII representation of the internal tree.
151 | * @return A string representation of the internal tree.
152 | */
153 | override def toString: String = {
154 | val treeStringBuilder = new mutable.StringBuilder()
155 | val queue = new mutable.Queue[Int]()
156 | queue.enqueue(1)
157 | while (queue.nonEmpty) {
158 | val nodeInfo = nodes(queue.dequeue())
159 | treeStringBuilder.
160 | append("nodeId:").
161 | append(nodeInfo.nodeId).
162 | append(",prediction:").
163 | append(nodeInfo.prediction)
164 | if (nodeInfo.splitInfo.isDefined) {
165 | treeStringBuilder.
166 | append(",splitFeatureId:").
167 | append(nodeInfo.splitInfo.get.featureId)
168 | nodeInfo.splitInfo.get match {
169 | case numericNodeSplitInfo: NumericNodeSplitInfo =>
170 | treeStringBuilder.
171 | append(",splitBinId:").
172 | append(numericNodeSplitInfo.splitBinId)
173 | if (numericNodeSplitInfo.nanChildNode.isDefined) {
174 | treeStringBuilder.
175 | append(",hasNanChild")
176 | }
177 | case catNodeSplitInfo: CatNodeSplitInfo =>
178 | treeStringBuilder.
179 | append(",splitMapping:")
180 | catNodeSplitInfo.binIdToChildNode.foreach {
181 | case (binId, childNodeId) =>
182 | treeStringBuilder.
183 | append(";").
184 | append(binId.toString + "->" + childNodeId.toString)
185 | }
186 | }
187 |
188 | queue.enqueue(
189 | nodeInfo.splitInfo.get.getOrderedChildNodes.map(_.nodeId).toSeq : _*
190 | )
191 | }
192 |
193 | if (queue.nonEmpty && (nodes(queue.front).depth > nodeInfo.depth)) {
194 | treeStringBuilder.append("\n")
195 | } else {
196 | treeStringBuilder.append(" ")
197 | }
198 | }
199 |
200 | treeStringBuilder.toString()
201 | }
202 | }
203 |
204 | /**
205 | * Gradient boosted trees writer.
206 | * @param store The gradient boosted trees store object.
207 | */
208 | class GradientBoostedTreesWriter(store: GradientBoostedTreesStore)
209 | extends TreeEnsembleWriter {
210 | var curTree: GBInternalTree = store.curTree
211 |
212 | /**
213 | * Write the node info to the currently active tree.
214 | * @param nodeInfo Node info to write.
215 | */
216 | def writeNodeInfo(nodeInfo: NodeInfo): Unit = {
217 | curTree.addNode(nodeInfo)
218 | }
219 | }
220 |
221 | /**
222 | * Gradient boosted trees store.
223 | * @param lossFunction Loss function with which the gradient boosted trees are
224 | * trained.
225 | * @param initVal Initial prediction value for the model.
226 | * @param shrinkage Shrinkage value.
227 | */
228 | class GradientBoostedTreesStore(
229 | val lossFunction: LossFunction,
230 | val initVal: Double,
231 | val shrinkage: Double
232 | ) extends TreeEnsembleStore {
233 | val trees: mutable.ArrayBuffer[GBInternalTree] = mutable.ArrayBuffer[GBInternalTree]()
234 | var curTree: GBInternalTree = null
235 |
236 | /**
237 | * Add a new tree. This tree becomes the new active tree.
238 | */
239 | def initNewTree(): Unit = {
240 | curTree = new GBInternalTree
241 | trees += curTree
242 | }
243 |
244 | /**
245 | * Get a tree ensemble writer.
246 | * @return A tree ensemble writer.
247 | */
248 | def getWriter: TreeEnsembleWriter = {
249 | new GradientBoostedTreesWriter(this)
250 | }
251 |
252 | /**
253 | * Get an internal tree.
254 | * @param idx Index of the tree to get.
255 | * @return An internal tree.
256 | */
257 | def getInternalTree(idx: Int): GBInternalTree = {
258 | trees(idx)
259 | }
260 | }
261 |
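A minimal sketch of the store's lifecycle (the loss function, initial value and shrinkage below are placeholders): `initNewTree` starts a fresh active tree for each boosting iteration, and the writer returned by `getWriter` appends the `NodeInfo` objects produced by the tree trainer to that active tree:

    import spark_ml.gradient_boosting.loss.defaults.GaussianLossFunction
    import spark_ml.model.gb.GradientBoostedTreesStore

    object GbtStoreLifecycleSketch {
      def main(args: Array[String]): Unit = {
        val store = new GradientBoostedTreesStore(
          lossFunction = new GaussianLossFunction,
          initVal = 2.0,    // placeholder: normally the loss aggregator's initial value
          shrinkage = 0.1   // placeholder: normally a configured learning rate
        )

        // One call per boosting iteration; the new tree becomes the active tree.
        store.initNewTree()
        val writer = store.getWriter
        // writer.writeNodeInfo(nodeInfo)  // nodeInfo instances come from the tree trainer

        println(store.trees.size)  // 1
      }
    }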
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/model/rf/RandomForestStore.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.model.rf
19 |
20 | import scala.collection.mutable
21 |
22 | import spark_ml.discretization.Bins
23 | import spark_ml.model._
24 | import spark_ml.tree_ensembles._
25 | import spark_ml.util.MapWithSequentialIntKeys
26 |
27 | /**
28 | * A forest is simply a collection of equal-weight trees.
29 | * @param trees Trees in the forest.
30 | * @param splitCriteriaStr Tree split criteria (e.g. infogain or variance).
31 | * @param sortedVarImportance Sorted variable importance.
32 | * @param sampleCounts The number of training samples per tree.
33 | */
34 | case class RandomForest(
35 | trees: Array[DecisionTree],
36 | splitCriteriaStr: String,
37 | sortedVarImportance: Seq[(String, java.lang.Double)],
38 | sampleCounts: Array[Long]
39 | ) {
40 | /**
41 | * Predict from the features.
42 | * @param features A double array of features.
43 | * @return Prediction(s) and corresponding weight(s) (e.g. probabilities or
44 | * variances of predictions, etc.)
45 | */
46 | def predict(features: Array[Double]): Array[(Double, Double)] = {
47 | SplitCriteria.withName(splitCriteriaStr) match {
48 | case SplitCriteria.Classification_InfoGain => predictClass(features)
49 | case SplitCriteria.Regression_Variance => predictRegression(features)
50 | }
51 | }
52 |
53 | /**
54 |    * Predict the class output from the given features - features should be in
55 |    * the same order as the ones that the trees were trained on.
56 |    * @param features Feature values.
57 |    * @return Predictions and their probabilities as an array of (Double, Double).
58 | */
59 | private def predictClass(features: Array[Double]): Array[(Double, Double)] = {
60 | val predictions = mutable.Map[Double, Double]() // Predicted label and its count.
61 | var treeId = 0
62 | while (treeId < trees.length) {
63 | val tree = trees(treeId)
64 | val prediction = tree.predict(features)
65 | predictions.getOrElseUpdate(prediction, 0)
66 | predictions(prediction) += 1.0
67 |
68 | treeId += 1
69 | }
70 |
71 | // Sort the predictions by the number of occurrences.
72 | // The first element has the highest number of occurrences.
73 | val sortedPredictions = predictions.toArray.sorted(
74 | Ordering.by[(Double, Double), Double](-_._2)
75 | )
76 | sortedPredictions.map(p => (p._1, p._2 / trees.length.toDouble))
77 | }
78 |
79 | /**
80 | * Predict a continuous output from the given features - features should be in
81 |    * the same order as the ones that the trees were trained on.
82 |    * @param features Feature values.
83 |    * @return Prediction and its variance (a single-element array of (Double, Double)).
84 | */
85 | private def predictRegression(features: Array[Double]): Array[(Double, Double)] = {
86 | var predictionSum = 0.0
87 | var predictionSqrSum = 0.0
88 | var treeId = 0
89 | while (treeId < trees.length) {
90 | val tree = trees(treeId)
91 | val prediction = tree.predict(features)
92 | predictionSum += prediction
93 | predictionSqrSum += prediction * prediction
94 |
95 | treeId += 1
96 | }
97 |
98 | val predAvg = predictionSum / trees.length.toDouble
99 | val predVar = predictionSqrSum / trees.length.toDouble - predAvg * predAvg
100 | Array[(Double, Double)]((predAvg, predVar))
101 | }
102 | }
103 |
104 | class RFInternalTree extends Serializable {
105 | val nodes = mutable.Map[Int, NodeInfo]()
106 |
107 | /**
108 | * Add a new node to the tree.
109 | * @param nodeInfo node to add.
110 | */
111 | def addNode(nodeInfo: NodeInfo): Unit = {
112 | // Sanity check !
113 | assert(
114 | !nodes.contains(nodeInfo.nodeId),
115 | "A tree node with the Id " + nodeInfo.nodeId + " already exists."
116 | )
117 |
118 | nodes.put(nodeInfo.nodeId, nodeInfo)
119 | }
120 | }
121 |
122 | class RandomForestWriter(store: RandomForestStore)
123 | extends TreeEnsembleWriter {
124 | def writeNodeInfo(nodeInfo: NodeInfo): Unit = {
125 | if (!store.trees.contains(nodeInfo.treeId)) {
126 | store.trees.put(nodeInfo.treeId, new RFInternalTree)
127 | }
128 | store.trees.get(nodeInfo.treeId).addNode(nodeInfo)
129 | }
130 | }
131 |
132 | /**
133 | * A default random forest store.
134 | * @param splitCriteria The split criteria for trees.
135 | */
136 | class RandomForestStore(
137 | splitCriteria: SplitCriteria.SplitCriteria,
138 | featureNames: Array[String],
139 | featureBins: Array[Bins]
140 | ) extends TreeEnsembleStore {
141 | val trees = new MapWithSequentialIntKeys[RFInternalTree](
142 | initCapacity = 100
143 | )
144 |
145 | def getWriter: TreeEnsembleWriter = {
146 | new RandomForestWriter(this)
147 | }
148 |
149 | def createRandomForest: RandomForest = {
150 | val featureImportance = Array.fill[Double](featureBins.length)(0.0)
151 | val (startTreeId, endTreeId) = this.trees.getKeyRange
152 | val decisionTrees = (startTreeId to endTreeId).map {
153 | case treeId =>
154 | val internalTree = this.trees.get(treeId)
155 | val treeNodes = mutable.Map[java.lang.Integer, DecisionTreeNode]()
156 | val _ =
157 | DecisionTreeUtil.createDecisionTreeNode(
158 | internalTree.nodes(1),
159 | internalTree.nodes,
160 | featureImportance,
161 | this.featureBins,
162 | treeNodes
163 | )
164 | DecisionTree(treeNodes.toMap, internalTree.nodes.size)
165 | }.toArray
166 |
167 | RandomForest(
168 | trees = decisionTrees,
169 | splitCriteriaStr = splitCriteria.toString,
170 | sortedVarImportance =
171 | scala.util.Sorting.stableSort(
172 | featureNames.zip(featureImportance.map(new java.lang.Double(_))).toSeq,
173 | // We want to sort in a descending importance order.
174 | (e1: (String, java.lang.Double), e2: (String, java.lang.Double)) => e1._2 > e2._2
175 | ),
176 | sampleCounts = decisionTrees.map{ _.nodes(1).nodeWeight.toLong }
177 | )
178 | }
179 | }
180 |
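A hedged usage sketch of the prediction output (the `forest` argument is assumed to come from `RandomForestStore.createRandomForest`): for a classification forest `predict` returns (label, vote fraction) pairs sorted by support, and for a regression forest a single (mean, variance) pair, so the first element is the headline prediction either way:

    import spark_ml.model.rf.RandomForest

    object RandomForestPredictionSketch {
      // `features` must be in the same column order the trees were trained on.
      def topPrediction(forest: RandomForest, features: Array[Double]): (Double, Double) = {
        // Classification: (majority label, fraction of trees that voted for it).
        // Regression: (mean prediction, variance of the per-tree predictions).
        forest.predict(features).head
      }
    }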
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/transformation/DataTransformationUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.transformation
19 |
20 | import org.apache.spark.rdd.RDD
21 | import org.apache.spark.sql.DataFrame
22 | import spark_ml.discretization.CardinalityOverLimitException
23 |
24 | /**
25 | * A class that encodes the transformation to a numeric value for a column.
26 | * @param distinctValToInt If this is defined, then the column value would be
27 | * transformed to a mapped integer.
28 | * @param maxCardinality If this is defined, then the column value would be
29 | * transformed into an integer value in
30 | * [0, maxCardinality) through hashing.
31 | * @return a transformed numeric value for a column value. If neither
32 | * distinctValToInt nor maxCardinality is defined, then the value
33 | * would be assumed to be a numeric value already and passed through.
34 | */
35 | case class ColumnTransformer(
36 | // We are using java.lang.Integer so that this can be easily serialized by
37 | // Gson.
38 | distinctValToInt: Option[Map[String, java.lang.Integer]],
39 | maxCardinality: Option[java.lang.Integer]) {
40 | def transform(value: String): Double = {
41 | if (distinctValToInt.isDefined) {
42 | val nonNullValue =
43 | if (value == null) {
44 | ""
45 | } else {
46 | value
47 | }
48 | distinctValToInt.get(nonNullValue).toDouble
49 | } else if (maxCardinality.isDefined) {
50 | val nonNullValue =
51 | if (value == null) {
52 | ""
53 | } else {
54 | value
55 | }
56 | DataTransformationUtils.getSimpleHashedValue(nonNullValue, maxCardinality.get)
57 | } else {
58 | if (value == null) {
59 | Double.NaN
60 | } else {
61 | value.toDouble
62 | }
63 | }
64 | }
65 | }
66 |
67 | object DataTransformationUtils {
68 | /**
69 | * Convert the given data frame into an RDD of label, feature vector pairs.
70 | * All label/feature values are also converted into Double.
71 | * @param dataFrame Spark Data Frame that we want to convert.
72 | * @param labelColIndex Label column index.
73 | * @param catDistinctValToInt Categorical column distinct value to int maps.
74 | * This is used to map distinct string values to
75 | * numeric values (doubles).
76 | * @param colsToIgnoreIndices Indices of columns in the data frame to be
77 | * ignored (not used as features or label).
78 | * @param maxCatCardinality Maximum categorical cardinality we allow. If the
79 |    *                          cardinality goes over this, either feature hashing
80 |    *                          is used or an exception is thrown.
81 | * @param useFeatureHashing Whether feature hashing should be used on
82 | * categorical columns whose unique value counts
83 | * exceed the maximum cardinality.
84 | * @return An RDD of label/feature-vector pairs and column transformer definitions.
85 | */
86 | def transformDataFrameToLabelFeatureRdd(
87 | dataFrame: DataFrame,
88 | labelColIndex: Int,
89 | catDistinctValToInt: Map[Int, Map[String, java.lang.Integer]],
90 | colsToIgnoreIndices: Set[Int],
91 | maxCatCardinality: Int,
92 | useFeatureHashing: Boolean): (RDD[(Double, Array[Double])], (ColumnTransformer, Array[ColumnTransformer])) = {
93 | val transformedRDD = dataFrame.map(row => {
94 | val labelValue = row.get(labelColIndex)
95 |
96 | Tuple2(
97 | if (catDistinctValToInt.contains(labelColIndex)) {
98 | val nonNullLabelValue =
99 | if (labelValue == null) {
100 | ""
101 | } else {
102 | labelValue.toString
103 | }
104 | mapCategoricalValueToNumericValue(
105 | labelColIndex,
106 | nonNullLabelValue,
107 | catDistinctValToInt(labelColIndex),
108 | maxCatCardinality,
109 | useFeatureHashing
110 | )
111 | } else {
112 | if (labelValue == null) {
113 | Double.NaN
114 | } else {
115 | labelValue.toString.toDouble
116 | }
117 | },
118 | row.toSeq.zipWithIndex.flatMap {
119 | case (colVal, idx) =>
120 | if (colsToIgnoreIndices.contains(idx) || (labelColIndex == idx)) {
121 | Array[Double]().iterator
122 | } else {
123 | if (catDistinctValToInt.contains(idx)) {
124 | val nonNullColVal =
125 | if (colVal == null) {
126 | ""
127 | } else {
128 | colVal.toString
129 | }
130 | Array(
131 | mapCategoricalValueToNumericValue(
132 | idx,
133 | nonNullColVal,
134 | catDistinctValToInt(idx),
135 | maxCatCardinality,
136 | useFeatureHashing
137 | )
138 | ).iterator
139 | } else {
140 | val nonNullColVal =
141 | if (colVal == null) {
142 | Double.NaN
143 | } else {
144 | colVal.toString.toDouble
145 | }
146 | Array(nonNullColVal).iterator
147 | }
148 | }
149 | }.toArray
150 | )
151 | })
152 |
153 | val numCols = dataFrame.columns.length
154 | val colTransformers = Tuple2(
155 | if (catDistinctValToInt.contains(labelColIndex)) {
156 | if (catDistinctValToInt(labelColIndex).size <= maxCatCardinality) {
157 | ColumnTransformer(Some(catDistinctValToInt(labelColIndex)), None)
158 | } else {
159 | ColumnTransformer(None, Some(maxCatCardinality))
160 | }
161 | } else {
162 | ColumnTransformer(None, None)
163 | },
164 | (0 to (numCols - 1)).flatMap {
165 | case (colIdx) =>
166 | if (colsToIgnoreIndices.contains(colIdx) || (labelColIndex == colIdx)) {
167 | Array[ColumnTransformer]().iterator
168 | } else {
169 | if (catDistinctValToInt.contains(colIdx)) {
170 | if (catDistinctValToInt(colIdx).size <= maxCatCardinality) {
171 | Array(ColumnTransformer(Some(catDistinctValToInt(colIdx)), None)).iterator
172 | } else {
173 | Array(ColumnTransformer(None, Some(maxCatCardinality))).iterator
174 | }
175 | } else {
176 | Array(ColumnTransformer(None, None)).iterator
177 | }
178 | }
179 | }.toArray
180 | )
181 |
182 | (transformedRDD, colTransformers)
183 | }
184 |
185 | /**
186 | * Map a categorical value (string) to a numeric value (double).
187 | * @param categoryId Equal to the column index.
188 | * @param catValue Categorical value that we want to map to a numeric value.
189 | * @param catValueToIntMap A map from categorical value to integers. If the
190 | * cardinality of the category is less than or equal
191 | * to maxCardinality, this is used.
192 | * @param maxCardinality The maximum allowed cardinality.
193 | * @param useFeatureHashing Whether feature hashing should be performed if the
194 | * cardinality of the category exceeds the maximum
195 | * cardinality.
196 | * @return A mapped double value.
197 | */
198 | def mapCategoricalValueToNumericValue(
199 | categoryId: Int,
200 | catValue: String,
201 | catValueToIntMap: Map[String, java.lang.Integer],
202 | maxCardinality: Int,
203 | useFeatureHashing: Boolean): Double = {
204 | if (catValueToIntMap.size <= maxCardinality) {
205 | catValueToIntMap(catValue).toDouble
206 | } else {
207 | if (useFeatureHashing) {
208 | getSimpleHashedValue(catValue, maxCardinality)
209 | } else {
210 | throw new CardinalityOverLimitException(
211 | "The categorical column with the index " + categoryId +
212 | " has a cardinality that exceeds the limit " + maxCardinality)
213 | }
214 | }
215 | }
216 |
217 | /**
218 | * Compute a simple hash of a categorical value.
219 | * @param catValue Categorical value that we want to compute a simple numeric
220 | * hash for. The hash value would be an integer between 0 and
221 | * (maxCardinality - 1).
222 | * @param maxCardinality The value of the hash is limited within
223 | * [0, maxCardinality).
224 | * @return A hashed value of double.
225 | */
226 | def getSimpleHashedValue(catValue: String, maxCardinality: Int): Double = {
227 | ((catValue.hashCode.toLong - Int.MinValue.toLong) % maxCardinality.toLong).toDouble
228 | }
229 | }
230 |
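As a usage illustration only (not part of the repository), this sketch exercises ColumnTransformer for the three cases it encodes: a mapped categorical column, a hashed high-cardinality column, and a pass-through numeric column. The maps and values are made up.

object ColumnTransformerSketch {
  def main(args: Array[String]): Unit = {
    // Categorical column with a known distinct-value-to-int map.
    val catTransformer = ColumnTransformer(
      distinctValToInt =
        Some(Map("Women" -> new java.lang.Integer(0), "Men" -> new java.lang.Integer(1))),
      maxCardinality = None
    )
    println(catTransformer.transform("Men"))  // 1.0

    // High-cardinality column that falls back to hashing into [0, 16).
    val hashedTransformer = ColumnTransformer(None, Some(new java.lang.Integer(16)))
    println(hashedTransformer.transform("some rare value"))  // a value in [0.0, 16.0)

    // Numeric column is passed through; null would become NaN.
    val numericTransformer = ColumnTransformer(None, None)
    println(numericTransformer.transform("3.14"))  // 3.14
  }
}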
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/transformation/DistinctValueCounter.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.transformation
19 |
20 | import scala.collection.mutable
21 |
22 | import org.apache.spark.sql.DataFrame
23 |
24 | object DistinctValueCounter {
25 | /**
26 | * Get distinct values of categorical columns from the given data frame.
27 | * @param dataFrame Data frame from whose columns we want to gather distinct
28 | * values from.
29 | * @param catColIndices Categorical column indices.
30 | * @param maxCardinality The maximum number of unique values per column.
31 | * @return A map of column index and distinct value sets.
32 | */
33 | def getDistinctValues(
34 | dataFrame: DataFrame,
35 | catColIndices: Set[Int],
36 | maxCardinality: Int): Map[Int, mutable.Set[String]] = {
37 | dataFrame.mapPartitions(
38 | rowItr => {
39 | val distinctVals = mutable.Map[Int, mutable.Set[String]]()
40 | while (rowItr.hasNext) {
41 | val row = rowItr.next()
42 | row.toSeq.zipWithIndex.map {
43 | case (colVal, colIdx) =>
44 | if (catColIndices.contains(colIdx)) {
45 | val colDistinctVals = distinctVals.getOrElseUpdate(colIdx, mutable.Set[String]())
46 |
47 |             // There's no need to keep adding unique values once the distinct
48 |             // count has gone over the given limit.
49 | if (colDistinctVals.size <= maxCardinality) {
50 | val nonNullColVal =
51 | if (colVal == null) {
52 | ""
53 | } else {
54 | colVal.toString
55 | }
56 | colDistinctVals.add(nonNullColVal)
57 | }
58 | }
59 | }
60 | }
61 |
62 | distinctVals.toIterator
63 | }
64 | ).reduceByKey {
65 | (colDistinctVals1, colDistinctVals2) =>
66 | (colDistinctVals1 ++ colDistinctVals2).splitAt(maxCardinality + 1)._1
67 | }.collect().toMap
68 | }
69 |
70 | /**
71 |    * Map a set of distinct values to increasing non-negative integers.
72 | * E.g., {'Women' -> 0, 'Men' -> 1}, etc.
73 | * @param distinctValues A set of distinct values for different columns (first
74 | * key is index to a column).
75 | * @param useEmptyString Whether an empty string should be used as a distinct
76 | * value.
77 | * @return A map of distinct values to integers for different columns (first
78 | * key is index to a column).
79 | */
80 | def mapDistinctValuesToIntegers(
81 | distinctValues: Map[Int, mutable.Set[String]],
82 | useEmptyString: Boolean
83 | ): Map[Int, Map[String, java.lang.Integer]] = {
84 | distinctValues.map {
85 | case (colIndex, values) =>
86 | colIndex -> values.filter(value => !(value == "" && !useEmptyString)).zipWithIndex.map {
87 | case (value, mappedVal) => value -> new java.lang.Integer(mappedVal)
88 | }.toMap
89 | }
90 | }
91 | }
92 |
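A small local sketch (not part of the repository) of mapDistinctValuesToIntegers; the column index and values are invented, and the exact integer assignment depends on the set's iteration order.

object DistinctValueMappingSketch {
  def main(args: Array[String]): Unit = {
    val distinct = Map(0 -> mutable.Set("Women", "Men", ""))
    // Drop the empty string and assign increasing integers to the remaining values.
    val mapped = DistinctValueCounter.mapDistinctValuesToIntegers(distinct, useEmptyString = false)
    println(mapped(0))  // e.g. Map(Women -> 0, Men -> 1)
  }
}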
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/IdCache.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | import scala.collection.mutable
21 |
22 | import org.apache.hadoop.fs.{FileSystem, Path}
23 | import org.apache.spark.rdd.RDD
24 | import org.apache.spark.storage.StorageLevel
25 | import spark_ml.util.DiscretizedFeatureHandler
26 | import spire.implicits._
27 |
28 | object IdCache {
29 | /**
30 | * Create a new Id cache object, filled with initial Ids (1's).
31 | * @param numTrees Number of trees we expect.
32 | * @param data A data RDD or an RDD that contains the same number of rows as
33 | * the data.
34 | * @param storageLevel Cache storage level.
35 | * @param checkpointDir Checkpoint directory for intermediate checkpointing.
36 | * @param checkpointInterval Checkpointing interval.
37 | * @tparam T Type of data rows.
38 | * @return A new IdCache object.
39 | */
40 | def createIdCache[T](
41 | numTrees: Int,
42 | data: RDD[T],
43 | storageLevel: StorageLevel,
44 | checkpointDir: Option[String],
45 | checkpointInterval: Int): IdCache = {
46 | new IdCache(
47 | // All the Ids start with '1', meaning that all the rows are
48 | // assigned to the root node.
49 | curIds = data.map(_ => Array.fill[Int](numTrees)(1)),
50 | storageLevel = storageLevel,
51 | checkpointDir = checkpointDir,
52 | checkpointInterval = checkpointInterval
53 | )
54 | }
55 | }
56 |
57 | /**
58 |  * Id cache that keeps track of row Ids, indicating which tree nodes the
59 |  * rows currently belong to.
60 | * @param curIds RDD of the current Ids per data row. Each row of the RDD
61 | * is an array of Ids, each element an Id for a tree.
62 | * @param storageLevel Cache storage level.
63 | * @param checkpointDir Checkpoint directory.
64 | * @param checkpointInterval Checkpoint interval.
65 | */
66 | class IdCache(
67 | var curIds: RDD[Array[Int]],
68 | storageLevel: StorageLevel,
69 | checkpointDir: Option[String],
70 | checkpointInterval: Int) {
71 | private var prevIds: RDD[Array[Int]] = null
72 | private var updateCount: Int = 0
73 |
74 | // To keep track of last checkpointed RDDs.
75 | private val checkpointQueue = new mutable.Queue[RDD[Array[Int]]]()
76 |
77 | // Persist the initial Ids.
78 | curIds = curIds.persist(storageLevel)
79 |
80 | // If a checkpoint directory is given, and there's no prior checkpoint
81 | // directory, then set the checkpoint directory with the given one.
82 | if (checkpointDir.isDefined && curIds.sparkContext.getCheckpointDir.isEmpty) {
83 | curIds.sparkContext.setCheckpointDir(checkpointDir.get)
84 | }
85 |
86 | /**
87 | * Get the current Id RDD.
88 | * @return curIds RDD.
89 | */
90 | def getRdd: RDD[Array[Int]] = curIds
91 |
92 | /**
93 | * Update Ids that are stored in the cache RDD.
94 | * @param data RDD of data rows needed to find the updated Ids.
95 | * @param idLookupForUpdaters Id updaters.
96 | * @param featureHandler Data row feature type handler.
97 | * @tparam T Type of feature.
98 | */
99 | def updateIds[@specialized(Byte, Short) T](
100 | data: RDD[((Double, Array[T]), Array[Byte])],
101 | idLookupForUpdaters: IdLookupForUpdaters,
102 | featureHandler: DiscretizedFeatureHandler[T]): Unit = {
103 | if (prevIds != null) {
104 | // Unpersist the previous one if one exists.
105 | prevIds.unpersist(blocking = true)
106 | }
107 |
108 | prevIds = curIds
109 |
110 | // Update Ids.
111 | curIds = data.zip(curIds).map {
112 | case (((label, features), baggedCounts), nodeIds) =>
113 | val numTrees = nodeIds.length
114 | cfor(0)(_ < numTrees, _ + 1)(
115 | treeId => {
116 | val curNodeId = nodeIds(treeId)
117 | val rowCnt = baggedCounts(treeId)
118 | if (rowCnt > 0 && curNodeId != 0) {
119 | val idUpdater = idLookupForUpdaters.get(
120 | treeId = treeId,
121 | id = curNodeId
122 | )
123 | if (idUpdater != null) {
124 | nodeIds(treeId) = idUpdater.updateId(features = features, featureHandler = featureHandler)
125 | }
126 | }
127 | }
128 | )
129 |
130 | nodeIds
131 | }.persist(storageLevel)
132 |
133 | updateCount += 1
134 |
135 | // Handle checkpointing if the directory is not None.
136 | if (curIds.sparkContext.getCheckpointDir.isDefined &&
137 | (updateCount % checkpointInterval) == 0) {
138 | // See if we can delete previous checkpoints.
139 | var canDelete = true
140 | while (checkpointQueue.size > 1 && canDelete) {
141 | // We can delete the oldest checkpoint iff the next checkpoint actually
142 | // exists in the file system.
143 | if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
144 | val old = checkpointQueue.dequeue()
145 |
146 | // Since the old checkpoint is not deleted by Spark, we'll manually
147 | // delete it here.
148 | val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
149 | println("Deleting a stale IdCache RDD checkpoint at " + old.getCheckpointFile.get)
150 | fs.delete(new Path(old.getCheckpointFile.get), true)
151 | } else {
152 | canDelete = false
153 | }
154 | }
155 |
156 | curIds.checkpoint()
157 | checkpointQueue.enqueue(curIds)
158 | }
159 | }
160 |
161 | /**
162 | * Unpersist all the RDDs stored internally.
163 | */
164 | def close(): Unit = {
165 | // Unpersist and delete all the checkpoints.
166 | curIds.unpersist(blocking = true)
167 | if (prevIds != null) {
168 | prevIds.unpersist(blocking = true)
169 | }
170 |
171 | while (checkpointQueue.nonEmpty) {
172 | val old = checkpointQueue.dequeue()
173 | if (old.getCheckpointFile.isDefined) {
174 | val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
175 | println("Deleting a stale IdCache RDD checkpoint at " + old.getCheckpointFile.get)
176 | fs.delete(new Path(old.getCheckpointFile.get), true)
177 | }
178 | }
179 | }
180 | }
181 |
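A hedged usage sketch, not part of the repository: it spins up a local SparkContext, creates an IdCache for three trees over a placeholder RDD, and shows that every row starts at the root node of every tree. The master URL, row contents, and checkpoint settings are placeholders.

import org.apache.spark.{SparkConf, SparkContext}

object IdCacheSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("IdCacheSketch"))
    val data = sc.parallelize(Seq.fill(100)(0))  // stands in for the training rows
    val idCache = IdCache.createIdCache(
      numTrees = 3,
      data = data,
      storageLevel = StorageLevel.MEMORY_ONLY,
      checkpointDir = None,  // no checkpointing in this sketch
      checkpointInterval = 10
    )
    // Every row starts at the root node (Id 1) of every tree.
    println(idCache.getRdd.first().mkString(","))  // 1,1,1
    idCache.close()
    sc.stop()
  }
}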
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/IdLookup.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | import scala.reflect.ClassTag
21 |
22 | import spire.implicits._
23 |
24 | /**
25 | * The Id lookup object will contain sequential Ids from startId to endId.
26 | * @param startId Start Id.
27 | * @param endId End Id.
28 | */
29 | case class IdRange(
30 | var startId: Int,
31 | var endId: Int)
32 |
33 | /**
34 | * Each data point belongs to a particular tree's node and is assigned the Id of
35 | * the node (the Id is only for training). This structure is used to find any
36 | * object (e.g. aggregator) that corresponds to the node that the point belongs
37 | * to.
38 | * @param idRanges Id ranges for each tree.
39 | * @tparam T The type of the look up objects.
40 | */
41 | abstract class IdLookup[T: ClassTag](idRanges: Array[IdRange])
42 | extends Serializable {
43 | // This is the array that contains the corresponding objects for each Id.
44 | protected var lookUpObjs: Array[Array[T]] = null
45 | var objCnt: Int = 0
46 |
47 | // This is used to initialize the look up obj for tree/node pairs.
48 | protected def initLookUpObjs(
49 | initLookUpObj: (Int, Int) => T): Unit = {
50 | val numTrees = idRanges.length
51 | lookUpObjs = Array.fill[Array[T]](numTrees)(null)
52 |
53 | // Initialize the look up objects for tree/nodes.
54 | // There's one array per tree.
55 | cfor(0)(_ < numTrees, _ + 1)(
56 | treeId => {
57 | val idRange = idRanges(treeId)
58 | if (idRange != null) {
59 | val numIds = idRange.endId - idRange.startId + 1
60 | lookUpObjs(treeId) = Array.fill[T](numIds)(null.asInstanceOf[T])
61 | cfor(idRange.startId)(_ <= idRange.endId, _ + 1)(
62 | id => {
63 | val lookUpObjIdx = id - idRange.startId
64 | lookUpObjs(treeId)(lookUpObjIdx) = initLookUpObj(treeId, id)
65 | objCnt += 1
66 | }
67 | )
68 | }
69 | }
70 | )
71 | }
72 |
73 | /**
74 | * Get Id ranges.
75 | * @return Id ranges.
76 | */
77 | def getIdRanges: Array[IdRange] = idRanges
78 |
79 | /**
80 | * Get a look up object for the given tree/node.
81 | * @param treeId Tree Id.
82 | * @param id Id representing the node.
83 | * During training, nodes can be assigned arbitrary Ids.
84 | * They are not necessarily the final model node Ids.
85 | * @return The corresponding look up object.
86 | */
87 | def get(treeId: Int, id: Int): T = {
88 | val idRange = idRanges(treeId)
89 | if (idRange != null) {
90 | val startId = idRange.startId
91 | val endId = idRange.endId
92 | if (id >= startId && id <= endId) {
93 | lookUpObjs(treeId)(id - startId)
94 | } else {
95 | null.asInstanceOf[T]
96 | }
97 | } else {
98 | null.asInstanceOf[T]
99 | }
100 | }
101 | }
102 |
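To illustrate the lookup semantics, here is a small sketch that is not part of the repository: a trivial concrete IdLookup that maps each (treeId, nodeId) pair to a String. The ranges and labels are invented.

class StringIdLookup(ranges: Array[IdRange]) extends IdLookup[String](ranges) {
  // Expose the protected initializer for this sketch.
  def init(): Unit = initLookUpObjs((treeId, id) => s"tree $treeId / node $id")
}

object IdLookupSketch {
  def main(args: Array[String]): Unit = {
    // Tree 0 currently holds Ids 1..3; tree 1 has no active Ids.
    val lookup = new StringIdLookup(Array(IdRange(1, 3), null))
    lookup.init()
    println(lookup.get(0, 2))  // tree 0 / node 2
    println(lookup.get(0, 5))  // null (outside the Id range)
    println(lookup.get(1, 1))  // null (no range for tree 1)
  }
}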
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/IdLookupForNodeStats.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | import scala.collection.mutable
21 | import scala.util.Random
22 |
23 | import spark_ml.discretization.Bins
24 | import spark_ml.util.{MapWithSequentialIntKeys, RandomSet}
25 | import spire.implicits._
26 |
27 | /**
28 | * Id lookup for aggregating node statistics during training.
29 | * @param idRanges Id range for nodes to look up.
30 | * @param nodeDepths A map containing depths of all the nodes in the ranges.
31 | */
32 | class IdLookupForNodeStats(
33 | idRanges: Array[IdRange],
34 | nodeDepths: Array[MapWithSequentialIntKeys[Int]]
35 | ) extends IdLookup[NodeStats](idRanges) {
36 |
37 | /**
38 | * Initialize this look up object.
39 | * @param treeType Tree type by the split criteria (e.g. classification based
40 | * on info-gain or regression based on variance.)
41 | * @param featureBinsInfo Feature discretization info.
42 | * @param treeSeeds Random seeds for trees.
43 | * @param mtry mtry.
44 | * @param numClasses Optional number of target classes (for classifications).
45 | */
46 | def initNodeStats(
47 | treeType: SplitCriteria.SplitCriteria,
48 | featureBinsInfo: Array[Bins],
49 | treeSeeds: Array[Int],
50 | mtry: Int,
51 | numClasses: Option[Int]): Unit = {
52 | val numFeatures = featureBinsInfo.length
53 | def createNodeStats(treeId: Int, id: Int): NodeStats = {
54 | val mtryFeatureIds = RandomSet.nChooseK(
55 | k = mtry,
56 | n = numFeatures,
57 | rnd = new Random(treeSeeds(treeId) + id)
58 | )
59 | NodeStats.createNodeStats(
60 | treeId = treeId,
61 | nodeId = id,
62 | nodeDepth = nodeDepths(treeId).get(id),
63 | treeType = treeType,
64 | featureBinsInfo = featureBinsInfo,
65 | mtryFeatureIds = mtryFeatureIds,
66 | numClasses = numClasses
67 | )
68 | }
69 |
70 | initLookUpObjs(createNodeStats)
71 | }
72 |
73 | /**
74 |    * Get an iterator of all the node stats. Each NodeStats object gets assigned
75 |    * an incrementing hash value. This is useful for evenly distributing
76 |    * aggregated node stats to different machines to perform distributed splits.
77 | * @return An iterator of pairs of (hashValue, nodestats) and the count of
78 | * node stats.
79 | */
80 | def toHashedNodeStatsIterator: (Iterator[(Int, NodeStats)], Int) = {
81 | val numTrees = idRanges.length
82 | var curHash = 0
83 | val output = new mutable.ListBuffer[(Int, NodeStats)]()
84 | cfor(0)(_ < numTrees, _ + 1)(
85 | treeId => {
86 | if (idRanges(treeId) != null) {
87 | val numIds = idRanges(treeId).endId - idRanges(treeId).startId + 1
88 | cfor(0)(_ < numIds, _ + 1)(
89 | i => {
90 | val nodeStats = lookUpObjs(treeId)(i)
91 | output += ((curHash, nodeStats))
92 | curHash += 1
93 | }
94 | )
95 | }
96 | }
97 | )
98 |
99 | (output.toIterator, curHash)
100 | }
101 | }
102 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/IdLookupForSubTreeInfo.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | import spark_ml.util.MapWithSequentialIntKeys
21 |
22 | /**
23 | * Sub tree info used during sub-tree training.
24 | * @param id Id of the subtree.
25 |  * @param hash Hash value for the sub-tree (used to push sub-tree data
26 |  *             evenly to different executors).
27 | * @param depth Depth of the sub-tree from the parent tree perspective.
28 | * @param parentTreeId Parent tree Id.
29 | */
30 | case class SubTreeInfo(
31 | id: Int,
32 | hash: Int,
33 | depth: Int,
34 | parentTreeId: Int) {
35 | /**
36 |    * Override hashCode to return the sub-tree's hash value.
37 |    * @return The hash value.
38 | */
39 | override def hashCode: Int = {
40 | hash
41 | }
42 | }
43 |
44 | /**
45 | * Sub tree info lookup used to find matching data points for each sub tree.
46 | * @param idRanges Id ranges for each tree.
47 | */
48 | class IdLookupForSubTreeInfo(
49 | idRanges: Array[IdRange]
50 | ) extends IdLookup[SubTreeInfo](idRanges)
51 |
52 | object IdLookupForSubTreeInfo {
53 | /**
54 | * Create a new Id lookup object for sub-trees.
55 | * @param idRanges Id ranges of sub trees for parent trees.
56 | * @param subTreeMaps Sub tree info maps for parent trees.
57 | * @return Id lookup object for sub-trees.
58 | */
59 | def createIdLookupForSubTreeInfo(
60 | idRanges: Array[IdRange],
61 | subTreeMaps: Array[MapWithSequentialIntKeys[SubTreeInfo]]): IdLookupForSubTreeInfo = {
62 | val lookup = new IdLookupForSubTreeInfo(idRanges)
63 | lookup.initLookUpObjs(
64 | (treeId: Int, id: Int) => subTreeMaps(treeId).get(id)
65 | )
66 |
67 | lookup
68 | }
69 | }
70 |
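A hedged sketch (not part of the repository): building a sub-tree lookup for a single parent tree with two pending sub-trees. The Ids, hashes and depths are invented.

object SubTreeLookupSketch {
  def main(args: Array[String]): Unit = {
    val subTrees = new MapWithSequentialIntKeys[SubTreeInfo](initCapacity = 4)
    subTrees.put(7, SubTreeInfo(id = 7, hash = 0, depth = 5, parentTreeId = 0))
    subTrees.put(8, SubTreeInfo(id = 8, hash = 1, depth = 5, parentTreeId = 0))

    val lookup = IdLookupForSubTreeInfo.createIdLookupForSubTreeInfo(
      idRanges = Array(IdRange(7, 8)),
      subTreeMaps = Array(subTrees)
    )
    println(lookup.get(0, 8))  // SubTreeInfo(8,1,5,0)
    println(lookup.get(0, 9))  // null (outside the Id range)
  }
}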
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/IdLookupForUpdaters.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | import spark_ml.util.{DiscretizedFeatureHandler, MapWithSequentialIntKeys}
21 |
22 | /**
23 | * Update the Id of the node that a row belongs to via the split information.
24 | * @param splitInfo split information that will determine the new Id for
25 | * the given features.
26 | */
27 | case class IdUpdater(splitInfo: NodeSplitInfo) {
28 | def updateId[@specialized(Byte, Short) T](
29 | features: Array[T],
30 | featureHandler: DiscretizedFeatureHandler[T]): Int = {
31 | if (splitInfo == null) {
32 | 0 // 0 indicates that the data point has reached a terminal node.
33 | } else {
34 | val binId = featureHandler.convertToInt(features(splitInfo.featureId))
35 |
36 | // The nodeId here is usually not the same as the final tree's node Id.
37 |       // This is a temporary value used to refer to child node Ids during training.
38 | splitInfo.chooseChildNode(binId).nodeId
39 | }
40 | }
41 | }
42 |
43 | /**
44 | * A look up for Id updaters. This is used to update Ids of nodes that
45 | * data points belong to during training.
46 | */
47 | class IdLookupForUpdaters(
48 | idRanges: Array[IdRange]
49 | ) extends IdLookup[IdUpdater](idRanges)
50 |
51 | object IdLookupForUpdaters {
52 | /**
53 | * Create a new Id lookup object for updaters.
54 | * @param idRanges Id ranges of updaters.
55 | * @param updaterMaps Maps of Id updaters.
56 | * @return Id lookup object for updaters.
57 | */
58 | def createIdLookupForUpdaters(
59 | idRanges: Array[IdRange],
60 | updaterMaps: Array[MapWithSequentialIntKeys[IdUpdater]]
61 | ): IdLookupForUpdaters = {
62 | val lookup = new IdLookupForUpdaters(idRanges)
63 | lookup.initLookUpObjs(
64 | (treeId: Int, id: Int) => updaterMaps(treeId).get(id)
65 | )
66 |
67 | lookup
68 | }
69 | }
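A small sketch, not part of the repository, showing the terminal-node convention of IdUpdater: when there is no split information, updateId returns 0, which marks the row as done. The feature values are placeholders.

import spark_ml.util.UnsignedByteHandler

object IdUpdaterSketch {
  def main(args: Array[String]): Unit = {
    val terminalUpdater = IdUpdater(splitInfo = null)
    val features = Array[Byte](0, 0, 0)
    // 0 indicates that this row has reached a terminal node.
    println(terminalUpdater.updateId(features, new UnsignedByteHandler))  // 0
  }
}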
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/InfoGainNodeStats.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | import breeze.numerics.log2
21 | import spark_ml.util.DiscretizedFeatureHandler
22 | import spire.implicits._
23 |
24 | /**
25 | * Information gain node statistics.
26 | * @param treeId Id of the tree that the node belongs to.
27 | * @param nodeId Id of the node.
28 | * @param nodeDepth Depth of the node.
29 | * @param statsArray The actual statistics array.
30 | * @param mtryFeatures Selected feature descriptions.
31 | * @param numElemsPerBin Number of statistical elements per feature bin.
32 | */
33 | class InfoGainNodeStats(
34 | treeId: Int,
35 | nodeId: Int,
36 | nodeDepth: Int,
37 | statsArray: Array[Double],
38 | mtryFeatures: Array[SelectedFeatureInfo],
39 | numElemsPerBin: Int) extends NodeStats(
40 | treeId = treeId,
41 | nodeId = nodeId,
42 | nodeDepth = nodeDepth,
43 | statsArray = statsArray,
44 | mtryFeatures = mtryFeatures,
45 | numElemsPerBin = numElemsPerBin) {
46 |
47 | /**
48 | * Add statistics related to a sample (label and features).
49 | * @param label Label of the sample.
50 | * @param features Features of the sample.
51 | * @param sampleCnt Sample count of the sample.
52 | * @param featureHandler Feature type handler.
53 | * @tparam T Feature type (Byte or Short).
54 | */
55 | override def addSample[@specialized(Byte, Short) T](
56 | label: Double,
57 | features: Array[T],
58 | sampleCnt: Int,
59 | featureHandler: DiscretizedFeatureHandler[T]): Unit = {
60 | val mtry = mtryFeatures.length
61 | cfor(0)(_ < mtry, _ + 1)(
62 | i => {
63 | val featId = mtryFeatures(i).featId
64 | val featOffset = mtryFeatures(i).offset
65 | val binId = featureHandler.convertToInt(features(featId))
66 | statsArray(featOffset + binId * numElemsPerBin + label.toInt) +=
67 | sampleCnt
68 | }
69 | )
70 | }
71 |
72 | /**
73 | * Calculate the node values from the given class distribution. This also
74 | * includes calculating the entropy.
75 | * @param classWeights The class weight distribution.
76 | * @param offset Starting offset.
77 | * @param output Where the output will be stored.
78 | * @return Node values.
79 | */
80 | override def calculateNodeValues(
81 | classWeights: Array[Double],
82 | offset: Int,
83 | output: PreallocatedNodeValues): PreallocatedNodeValues = {
84 | var weightSum: Double = 0.0
85 | var prediction: Double = 0.0
86 | var maxClassWeight: Double = 0.0
87 | var entropy: Double = 0.0
88 |
89 | // Determine weightSum and prediction.
90 | cfor(0)(_ < numElemsPerBin, _ + 1)(
91 | labelId => {
92 | val weight = classWeights(offset + labelId)
93 | output.sumStats(labelId) = weight
94 | if (maxClassWeight < weight) {
95 | maxClassWeight = weight
96 | prediction = labelId
97 | }
98 |
99 | weightSum += weight
100 | }
101 | )
102 |
103 | // Compute entropy.
104 | cfor(0)(_ < numElemsPerBin, _ + 1)(
105 | labelId => {
106 | val weight = classWeights(offset + labelId)
107 | if (weight > 0.0) {
108 | val prob = weight.toDouble / weightSum
109 | entropy -= prob * log2(prob)
110 | }
111 | }
112 | )
113 |
114 | output.prediction = prediction
115 | output.addendum = maxClassWeight / weightSum // Probability of the class.
116 | output.weight = weightSum
117 | output.impurity = entropy
118 |
119 | output
120 | }
121 |
122 | /**
123 | * Get bin weights.
124 | * @param statsArray Stats array.
125 | * @param offset Start offset.
126 | * @param numBins Number of bins.
127 | * @param output This is where the weights will be stored.
128 | * @return Returns the same output that was passed in.
129 | */
130 | override def getBinWeights(
131 | statsArray: Array[Double],
132 | offset: Int,
133 | numBins: Int,
134 | output: Array[Double]): Array[Double] = {
135 | cfor(0)(_ < numBins, _ + 1)(
136 | binId => {
137 | val binOffset = offset + binId * numElemsPerBin
138 | // Sum all the label weights per bin.
139 | cfor(0)(_ < numElemsPerBin, _ + 1)(
140 | labelId => output(binId) += statsArray(binOffset + labelId)
141 | )
142 | }
143 | )
144 |
145 | output
146 | }
147 |
148 | /**
149 | * Calculate the label average for the given bin. This is only meaningful for
150 | * binary classifications.
151 | * @param statsArray Stats array.
152 | * @param binOffset Offset to the bin.
153 | * @return The label average for the bin.
154 | */
155 | override def getBinLabelAverage(
156 | statsArray: Array[Double],
157 | binOffset: Int): Double = {
158 | var labelSum = 0.0
159 | var weightSum = 0.0
160 | cfor(0)(_ < numElemsPerBin, _ + 1)(
161 | labelId => {
162 | val labelWeight = statsArray(binOffset + labelId)
163 | labelSum += labelId.toDouble * labelWeight
164 | weightSum += labelWeight
165 | }
166 | )
167 |
168 | labelSum / weightSum
169 | }
170 |
171 | /**
172 | * This is for classification, so return true.
173 | * @return true
174 | */
175 | override def forClassification: Boolean = true
176 | }
177 |
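A standalone sketch of the prediction/entropy computation in calculateNodeValues above; it is not part of the repository, and the class weights are invented.

object EntropySketch {
  def nodeValues(classWeights: Array[Double]): (Double, Double, Double) = {
    val weightSum = classWeights.sum
    // Predict the heaviest class.
    val prediction = classWeights.indexOf(classWeights.max).toDouble
    // Entropy of the class-weight distribution.
    val entropy = classWeights.filter(_ > 0.0).map { w =>
      val prob = w / weightSum
      -prob * log2(prob)
    }.sum
    (prediction, weightSum, entropy)
  }

  def main(args: Array[String]): Unit = {
    println(nodeValues(Array(30.0, 10.0)))  // (0.0, 40.0, ~0.811)
  }
}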
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/SubTreeStore.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | import scala.collection.mutable
21 |
22 | import spark_ml.util.Sorting
23 | import spire.implicits._
24 |
25 | /**
26 | * This object contains the descriptions of all the nodes in this sub-tree.
27 | * @param parentTreeId Id of the parent tree.
28 | * @param subTreeId Sub tree Id.
29 | * @param subTreeDepth Depth of the sub-tree from the parent tree's perspective.
30 | */
31 | class SubTreeDesc(
32 | val parentTreeId: Int,
33 | val subTreeId: Int,
34 | val subTreeDepth: Int) extends Serializable {
35 | // Store the nodes in a mutable map.
36 | var nodes = mutable.Map[Int, NodeInfo]()
37 |
38 | /**
39 | * Add a new trained node info object.
40 |    * It's expected that the node Ids increment monotonically, one by one.
41 | * @param nodeInfo Info about the next trained node.
42 | */
43 | def addNodeInfo(nodeInfo: NodeInfo): Unit = {
44 | assert(
45 | !nodes.contains(nodeInfo.nodeId),
46 | "A node with the node Id " + nodeInfo.nodeId + " already exists in the " +
47 | "sub tree " + subTreeId + " that belongs to the parent tree " + parentTreeId
48 | )
49 |
50 | nodes.put(nodeInfo.nodeId, nodeInfo)
51 | }
52 |
53 | /**
54 | * Update the tree/node Ids and the node depth to reflect the parent tree's
55 | * reality.
56 | * @param rootId The root node Id will be updated to this.
57 | * @param startChildNodeId The descendant nodes will have updated Ids,
58 | * starting from this number.
59 | * @return The last descendant node Id that was used for this sub-tree.
60 | */
61 | def updateIdsAndDepths(
62 | rootId: Int,
63 | startChildNodeId: Int): Int = {
64 | val itr = nodes.values.iterator
65 | val updatedNodes = mutable.Map[Int, NodeInfo]()
66 | var largestUpdatedNodeId = 0
67 | while (itr.hasNext) {
68 | val nodeInfo = itr.next()
69 | val updatedNodeInfo = nodeInfo.copy
70 | updatedNodeInfo.treeId = parentTreeId
71 | if (updatedNodeInfo.nodeId == 1) {
72 | updatedNodeInfo.nodeId = rootId
73 | } else {
74 | // Start node Id is used from the first child node of the root node,
75 | // which should have the Id of 2. Therefore, update the node Ids by
76 | // subtracting 2 and then adding startChildNodeId.
77 | updatedNodeInfo.nodeId = updatedNodeInfo.nodeId - 2 + startChildNodeId
78 | }
79 | largestUpdatedNodeId = math.max(updatedNodeInfo.nodeId, largestUpdatedNodeId)
80 | updatedNodeInfo.depth = updatedNodeInfo.depth - 1 + subTreeDepth
81 | if (updatedNodeInfo.splitInfo.nonEmpty) {
82 | val si = updatedNodeInfo.splitInfo.get
83 | val children = si.getOrderedChildNodes
84 | val numChildren = children.length
85 | cfor(0)(_ < numChildren, _ + 1)(
86 | i => {
87 | val child = children(i)
88 | child.treeId = parentTreeId
89 | child.nodeId = child.nodeId - 2 + startChildNodeId
90 | child.depth = child.depth - 1 + subTreeDepth
91 | }
92 | )
93 | }
94 |
95 | updatedNodes.put(updatedNodeInfo.nodeId, updatedNodeInfo)
96 | }
97 |
98 | this.nodes = updatedNodes
99 |
100 | largestUpdatedNodeId
101 | }
102 |
103 | /**
104 | * Get a sequence of nodeInfo objects ordered by the node Id.
105 | * @return A sequence of nodeInfo objects ordered by the node Id.
106 | */
107 | def getOrderedNodeInfoSeq: Seq[NodeInfo] = {
108 | val nodeInfoArray = this.nodes.values.toArray
109 | Sorting.quickSort[NodeInfo](nodeInfoArray)(
110 | Ordering.by[NodeInfo, Int](_.nodeId)
111 | )
112 |
113 | nodeInfoArray
114 | }
115 | }
116 |
117 | /**
118 | * Local subtree store. Used for locally training sub-trees.
119 | * @param parentTreeId Id of the parent tree.
120 | * @param subTreeId Sub tree Id.
121 | * @param subTreeDepth Depth of the sub-tree from the parent tree's perspective.
122 | */
123 | class SubTreeStore(
124 | parentTreeId: Int,
125 | subTreeId: Int,
126 | subTreeDepth: Int) extends TreeEnsembleStore {
127 | val subTreeDesc = new SubTreeDesc(
128 | parentTreeId = parentTreeId,
129 | subTreeId = subTreeId,
130 | subTreeDepth = subTreeDepth
131 | )
132 |
133 | /**
134 | * Get a sub tree writer.
135 | * @return Sub tree writer.
136 | */
137 | def getWriter: TreeEnsembleWriter = new SubTreeWriter(this)
138 | }
139 |
140 | /**
141 | * Sub tree writer.
142 | * @param subTreeStore The sub tree store this belongs to.
143 | */
144 | class SubTreeWriter(subTreeStore: SubTreeStore) extends TreeEnsembleWriter {
145 | val subTreeDesc = subTreeStore.subTreeDesc
146 |
147 | /**
148 | * Write node info.
149 | * @param nodeInfo Node info to write.
150 | */
151 | def writeNodeInfo(nodeInfo: NodeInfo): Unit = {
152 | subTreeDesc.addNodeInfo(nodeInfo)
153 | }
154 | }
155 |
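The Id-remapping rule in updateIdsAndDepths is easy to miss, so here is a tiny arithmetic sketch of it; this is not part of the repository, and the node Ids are invented.

object SubTreeIdRemapSketch {
  // Local sub-tree Id 1 (the root) keeps the Id it already has in the parent tree;
  // descendants (local Ids 2, 3, ...) are shifted so that local Id 2 becomes
  // startChildNodeId, matching updateIdsAndDepths above.
  def remap(localNodeId: Int, rootId: Int, startChildNodeId: Int): Int =
    if (localNodeId == 1) rootId else localNodeId - 2 + startChildNodeId

  def main(args: Array[String]): Unit = {
    // Sub-tree rooted at parent-tree node 9, with new children appended from Id 17.
    println(Seq(1, 2, 3, 4).map(id => remap(id, rootId = 9, startChildNodeId = 17)))
    // List(9, 17, 18, 19)
  }
}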
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/TreeEnsembleStore.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | /**
21 | * Tree ensemble writer.
22 | */
23 | trait TreeEnsembleWriter {
24 | /**
25 | * Write node info.
26 | * @param nodeInfo Node info to write.
27 | */
28 | def writeNodeInfo(nodeInfo: NodeInfo): Unit
29 | }
30 |
31 | /**
32 | * Tree ensemble store.
33 | */
34 | trait TreeEnsembleStore {
35 | /**
36 | * Get a tree ensemble writer.
37 | * @return A tree ensemble writer.
38 | */
39 | def getWriter: TreeEnsembleWriter
40 | }
41 |
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/tree_ensembles/VarianceNodeStats.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.tree_ensembles
19 |
20 | import spark_ml.util.DiscretizedFeatureHandler
21 | import spire.implicits._
22 |
23 | /**
24 | * Variance node statistics.
25 | * @param treeId Id of the tree that the node belongs to.
26 | * @param nodeId Id of the node.
27 | * @param nodeDepth Depth of the node.
28 | * @param statsArray The actual statistics array.
29 | * @param mtryFeatures Selected feature descriptions.
30 | * @param numElemsPerBin Number of statistical elements per feature bin.
31 | */
32 | class VarianceNodeStats(
33 | treeId: Int,
34 | nodeId: Int,
35 | nodeDepth: Int,
36 | statsArray: Array[Double],
37 | mtryFeatures: Array[SelectedFeatureInfo],
38 | numElemsPerBin: Int) extends NodeStats(
39 | treeId = treeId,
40 | nodeId = nodeId,
41 | nodeDepth = nodeDepth,
42 | statsArray = statsArray,
43 | mtryFeatures = mtryFeatures,
44 | numElemsPerBin = numElemsPerBin) {
45 |
46 | /**
47 | * Add statistics related to a sample (label and features).
48 | * @param label Label of the sample.
49 | * @param features Features of the sample.
50 | * @param sampleCnt Sample count of the sample.
51 | * @param featureHandler Feature type handler.
52 | * @tparam T Feature type (Byte or Short).
53 | */
54 | override def addSample[@specialized(Byte, Short) T](
55 | label: Double,
56 | features: Array[T],
57 | sampleCnt: Int,
58 | featureHandler: DiscretizedFeatureHandler[T]): Unit = {
59 | val mtry = mtryFeatures.length
60 | cfor(0)(_ < mtry, _ + 1)(
61 | i => {
62 | // Add to the bins of all the selected features.
63 | val featId = mtryFeatures(i).featId
64 | val featOffset = mtryFeatures(i).offset
65 | val binId = featureHandler.convertToInt(features(featId))
66 |
67 | val binOffset = featOffset + binId * numElemsPerBin
68 |
69 | val sampleCntInDouble = sampleCnt.toDouble
70 | val labelSum = label * sampleCntInDouble
71 | val labelSqrSum = label * labelSum
72 |
73 | statsArray(binOffset) += labelSum
74 | statsArray(binOffset + 1) += labelSqrSum
75 | statsArray(binOffset + 2) += sampleCntInDouble
76 | }
77 | )
78 | }
79 |
80 | /**
81 | * Calculate the node values from the summary stats. Also computes variance.
82 | * @param sumStats Summary stats (label sum, label sqr sum, count).
83 | * @param offset Starting offset.
84 | * @param output Where the output will be stored.
85 | * @return Node values.
86 | */
87 | override def calculateNodeValues(
88 | sumStats: Array[Double],
89 | offset: Int,
90 | output: PreallocatedNodeValues): PreallocatedNodeValues = {
91 | val prediction = sumStats(offset) / sumStats(offset + 2)
92 | val variance =
93 | sumStats(offset + 1) / sumStats(offset + 2) - prediction * prediction
94 |
95 | output.prediction = prediction
96 | output.addendum = variance
97 | output.weight = sumStats(offset + 2)
98 | output.impurity = variance
99 | output.sumStats(0) = sumStats(offset)
100 | output.sumStats(1) = sumStats(offset + 1)
101 | output.sumStats(2) = sumStats(offset + 2)
102 |
103 | output
104 | }
105 |
106 | /**
107 | * Get bin weights.
108 | * @param statsArray Stats array.
109 | * @param offset Start offset.
110 | * @param numBins Number of bins.
111 | * @param output This is where the weights will be stored.
112 | * @return Returns the same output that was passed in.
113 | */
114 | override def getBinWeights(
115 | statsArray: Array[Double],
116 | offset: Int,
117 | numBins: Int,
118 | output: Array[Double]): Array[Double] = {
119 | cfor(0)(_ < numBins, _ + 1)(
120 | binId => output(binId) = statsArray(offset + binId * numElemsPerBin + 2)
121 | )
122 |
123 | output
124 | }
125 |
126 | /**
127 | * Calculate the label average for the given bin.
128 | * @param statsArray Stats array.
129 | * @param binOffset Offset to the bin.
130 | * @return The label average for the bin.
131 | */
132 | override def getBinLabelAverage(
133 | statsArray: Array[Double],
134 | binOffset: Int): Double = {
135 | statsArray(binOffset) / statsArray(binOffset + 2)
136 | }
137 |
138 | /**
139 | * This is for regression, so return false.
140 | * @return false
141 | */
142 | override def forClassification: Boolean = false
143 | }
144 |
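A standalone sketch, not part of the repository, of the sufficient statistics kept per bin by VarianceNodeStats: (label sum, label square sum, count), from which the prediction (mean) and variance are recovered. The labels are invented.

object VarianceStatsSketch {
  def main(args: Array[String]): Unit = {
    val labels = Array(1.0, 2.0, 3.0, 4.0)
    val labelSum = labels.sum
    val labelSqrSum = labels.map(l => l * l).sum
    val count = labels.length.toDouble

    val prediction = labelSum / count
    val variance = labelSqrSum / count - prediction * prediction  // E[y^2] - E[y]^2
    println((prediction, variance))  // (2.5, 1.25)
  }
}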
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/util/Bagger.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import scala.util.Random
21 |
22 | import org.apache.spark.rdd.RDD
23 | import spire.implicits._
24 |
25 | /**
26 | * Available bagging types.
27 | */
28 | object BaggingType extends Enumeration {
29 | type BaggingType = Value
30 | val WithReplacement = Value(0)
31 | val WithoutReplacement = Value(1)
32 | }
33 |
34 | /**
35 | * Bagger.
36 | */
37 | object Bagger {
38 | /**
39 | * Create an RDD of bagging info. Each row of the return value is an array of
40 | * sample counts for a corresponding data row.
41 | * The number of samples per row is set by numSamples.
42 | * @param data An RDD of data points.
43 | * @param numSamples Number of samples we want to get per row.
44 | * @param baggingType Bagging type.
45 | * @param baggingRate Bagging rate.
46 | * @param seed Random seed.
47 | * @tparam T Data row type.
48 | * @return An RDD of an array of sample counts.
49 | */
50 | def getBagRdd[T](
51 | data: RDD[T],
52 | numSamples: Int,
53 | baggingType: BaggingType.BaggingType,
54 | baggingRate: Double,
55 | seed: Int): RDD[Array[Byte]] = {
56 | data.mapPartitionsWithIndex(
57 | (index, rows) => {
58 | val poisson = Poisson(baggingRate, seed + index)
59 | val rng = new Random(seed + index)
60 | rows.map(
61 | row => {
62 | val counts = Array.fill[Byte](numSamples)(0)
63 | cfor(0)(_ < numSamples, _ + 1)(
64 | sampleId => {
65 | val sampleCount =
66 | if (baggingType == BaggingType.WithReplacement) {
67 | poisson.sample()
68 | } else {
69 | if (rng.nextDouble() <= baggingRate) 1 else 0
70 | }
71 |
72 |                 // Only allow a sample count value up to 127
73 | // to save space.
74 | counts(sampleId) = math.min(sampleCount, 127).toByte
75 | }
76 | )
77 |
78 | counts
79 | }
80 | )
81 | }
82 | )
83 | }
84 | }
85 |
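A hedged usage sketch that is not part of the repository: it draws bootstrap counts for three trees over a small local RDD. The master URL, data, and parameters are placeholders.

import org.apache.spark.{SparkConf, SparkContext}

object BaggerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("BaggerSketch"))
    val data = sc.parallelize(1 to 10)
    val bags = Bagger.getBagRdd(
      data = data,
      numSamples = 3,                             // one count per tree
      baggingType = BaggingType.WithReplacement,  // Poisson-distributed counts
      baggingRate = 1.0,
      seed = 42
    )
    // Each row is an Array[Byte] of per-tree sample counts for the matching data row.
    bags.take(3).foreach(counts => println(counts.mkString(",")))
    sc.stop()
  }
}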
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/util/DiscretizedFeatureHandler.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | /**
21 | * Discretized features can be either unsigned Byte or Short. This trait is
22 | * responsible for handling the actual data types, freeing the aggregation logic
23 | * from handling different types.
24 | * @tparam T Type of the features. Either Byte or Short.
25 | */
26 | trait DiscretizedFeatureHandler[@specialized(Byte, Short) T] extends Serializable {
27 | /**
28 | * Convert the given value of type T to an integer value.
29 | * @param value The value that we want to convert.
30 | * @return The integer value.
31 | */
32 | def convertToInt(value: T): Int
33 |
34 | /**
35 | * Convert the given integer value to the type.
36 | * @param value Integer value that we want to convert.
37 | * @return The converted value.
38 | */
39 | def convertToType(value: Int): T
40 |
41 | /**
42 | * Get the minimum value you can get for this type.
43 | * @return The minimum value for this type.
44 | */
45 | def getMinValue: Int
46 |
47 | /**
48 |    * Get the maximum value you can get for this type.
49 | * @return The maximum value for this type.
50 | */
51 | def getMaxValue: Int
52 | }
53 |
54 | /**
55 | * Handle unsigned byte features.
56 | */
57 | class UnsignedByteHandler extends DiscretizedFeatureHandler[Byte] {
58 | /**
59 | * Convert the given value of unsigned Byte to an integer value.
60 | * @param value The value that we want to convert.
61 | * @return The integer value.
62 | */
63 | def convertToInt(value: Byte): Int = {
64 | value.toInt + 128
65 | }
66 |
67 | /**
68 | * Convert the given integer value to the type.
69 | * @param value Integer value that we want to convert.
70 | * @return The converted value.
71 | */
72 | def convertToType(value: Int): Byte = {
73 | (value - 128).toByte
74 | }
75 |
76 | /**
77 | * Get the minimum value you can get for this type.
78 | * @return The minimum value for this type.
79 | */
80 | def getMinValue: Int = 0
81 |
82 | /**
83 |    * Get the maximum value you can get for this type.
84 | * @return The maximum value for this type.
85 | */
86 | def getMaxValue: Int = 255
87 | }
88 |
89 | /**
90 | * Handle unsigned short features.
91 | */
92 | class UnsignedShortHandler extends DiscretizedFeatureHandler[Short] {
93 | /**
94 | * Convert the given value of unsigned Short to an integer value.
95 | * @param value The value that we want to convert.
96 | * @return The integer value.
97 | */
98 | def convertToInt(value: Short): Int = {
99 | value.toInt + 32768
100 | }
101 |
102 | /**
103 | * Convert the given integer value to the type.
104 | * @param value Integer value that we want to convert.
105 | * @return The converted value.
106 | */
107 | def convertToType(value: Int): Short = {
108 | (value - 32768).toShort
109 | }
110 |
111 | /**
112 | * Get the minimum value you can get for this type.
113 | * @return The minimum value for this type.
114 | */
115 | def getMinValue: Int = 0
116 |
117 | /**
118 |    * Get the maximum value you can get for this type.
119 | * @return The maximum value for this type.
120 | */
121 | def getMaxValue: Int = 65535
122 | }
123 |
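A small round-trip sketch, not part of the repository: the unsigned handlers shift the signed JVM ranges so that bin Ids 0..255 (Byte) and 0..65535 (Short) survive the conversion. The bin Ids are arbitrary examples.

object FeatureHandlerSketch {
  def main(args: Array[String]): Unit = {
    val byteHandler = new UnsignedByteHandler
    val b: Byte = byteHandler.convertToType(200)      // stored as the signed byte 72
    println(byteHandler.convertToInt(b))              // 200

    val shortHandler = new UnsignedShortHandler
    val s: Short = shortHandler.convertToType(40000)  // stored as the signed short 7232
    println(shortHandler.convertToInt(s))             // 40000
  }
}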
--------------------------------------------------------------------------------
/src/main/scala/spark_ml/util/MapWithSequentialIntKeys.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import scala.reflect.ClassTag
21 |
22 | import spire.implicits._
23 |
24 | /**
25 | * Exception to be thrown in case of unexpected behavior of the following map
26 | * class.
27 | * @param msg Exception msg.
28 | */
29 | case class UnexpectedKeyException(msg: String) extends Exception(msg)
30 |
31 | /**
32 | * A map with integer keys that are guaranteed to increment one by one.
33 | * Each put is expected to use a key equal to 1 + the previously put key.
34 | * Each remove is expected to use the current first (oldest) key.
35 | * @param initCapacity Initial capacity.
36 | */
37 | class MapWithSequentialIntKeys[@specialized(Int) T: ClassTag](initCapacity: Int)
38 | extends Serializable {
39 | private var values = new Array[T](initCapacity)
40 | private var capacity = initCapacity
41 | private var size = 0
42 | private var putCursor = 0
43 | private var expectedPutKey = 0
44 | private var firstGetPos = 0
45 | private var firstGetKey = 0
46 |
47 | /**
48 | * Put the next key-value pair. The key is expected to be 1 + the previous
49 | * key unless it's the very first key.
50 | * @param key The key value.
51 | * @param value The value corresponding to the key.
52 | */
53 | def put(key: Int, value: T): Unit = {
54 | // Make sure that the key is as expected.
55 | if (size > 0 && key != expectedPutKey) {
56 | throw UnexpectedKeyException(
57 | "The put key " + key +
58 | " is different from the expected key " + expectedPutKey
59 | )
60 | }
61 |
62 | // If the array is full, we need to get a new array.
63 | if (size >= capacity) {
64 | capacity *= 2
65 | val newValues = new Array[T](capacity)
66 | cfor(0)(_ < size, _ + 1)(
67 | i => {
68 | newValues(i) = values(firstGetPos)
69 | firstGetPos += 1
70 | if (firstGetPos >= size) {
71 | firstGetPos = 0
72 | }
73 | }
74 | )
75 |
76 | values = newValues
77 | putCursor = size
78 | firstGetPos = 0
79 | }
80 |
81 | values(putCursor) = value
82 | putCursor += 1
83 | if (putCursor >= capacity) {
84 | putCursor = 0
85 | }
86 |
87 | if (size == 0) {
88 | firstGetKey = key
89 | }
90 |
91 | size += 1
92 | expectedPutKey = key + 1
93 | }
94 |
95 | /**
96 | * Get the value corresponding to the key.
97 | * @param key The integer key value.
98 | * @return The corresponding value.
99 | */
100 | def get(key: Int): T = {
101 | val getOffset = key - firstGetKey
102 |
103 | // Make sure that it's within the expected range.
104 | if (getOffset >= size || getOffset < 0) {
105 | throw UnexpectedKeyException(
106 | "The get key " + key +
107 | " is not within [" + firstGetKey + ", " + (firstGetKey + size - 1) + "]"
108 | )
109 | }
110 |
111 | var idx = firstGetPos + getOffset
112 | if (idx >= capacity) {
113 | idx -= capacity
114 | }
115 |
116 | values(idx)
117 | }
118 |
119 | /**
120 | * Remove a key.
121 | * @param key The key we want to remove.
122 | */
123 | def remove(key: Int): Unit = {
124 | // We expect the key to equal firstGetKey.
125 | // Otherwise, this is not being used as expected.
126 | if (key != firstGetKey) {
127 | throw UnexpectedKeyException(
128 | "The remove key " + key +
129 | " is different from the expected key " + firstGetKey
130 | )
131 | }
132 |
133 | size -= 1
134 | firstGetKey += 1
135 | firstGetPos += 1
136 | if (firstGetPos >= capacity) {
137 | firstGetPos = 0
138 | }
139 | }
140 |
141 | /**
142 | * Get key range in the object.
143 | * @return A pair of (startKey, endKey).
144 | */
145 | def getKeyRange: (Int, Int) = {
146 | (firstGetKey, firstGetKey + size - 1)
147 | }
148 |
149 | /**
150 | * Whether this map contains the key.
151 | * @param key The integer key that we want to check for.
152 | * @return true if the key is contained. false otherwise.
153 | */
154 | def contains(key: Int): Boolean = {
155 | val (startKey, endKey) = getKeyRange
156 | (key >= startKey) && (key <= endKey)
157 | }
158 | }
159 |
--------------------------------------------------------------------------------
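A minimal usage sketch for the map above, assuming keys are put sequentially and removed from the oldest end as the class expects; `SequentialMapExample` is illustrative only:

```scala
import spark_ml.util.MapWithSequentialIntKeys

object SequentialMapExample {
  def main(args: Array[String]): Unit = {
    // Keys must be inserted one by one; removals must start from the oldest key.
    val map = new MapWithSequentialIntKeys[String](initCapacity = 2)
    map.put(10, "a")
    map.put(11, "b")
    map.put(12, "c") // Triggers an internal resize of the circular buffer.

    assert(map.get(11) == "b")
    assert(map.getKeyRange == (10, 12))
    assert(map.contains(12) && !map.contains(13))

    map.remove(10) // Only the oldest key can be removed.
    assert(map.getKeyRange == (11, 12))
  }
}
```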
/src/main/scala/spark_ml/util/Poisson.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import scala.collection.mutable
21 | import scala.util.Random
22 |
23 | /**
24 | * Poisson sampler.
25 | * @param lambda The mean (expected count) of the Poisson distribution.
26 | * @param seed The random number generator seed.
27 | */
28 | case class Poisson(lambda: Double, seed: Int) {
29 | private val rng = new Random(seed)
30 | private val tolerance: Double = 0.00001
31 |
32 | private var pdf: Array[Double] = _
33 | private var cdf: Array[Double] = _
34 |
35 | {
36 | val pdfBuilder = new mutable.ArrayBuilder.ofDouble
37 | val cdfBuilder = new mutable.ArrayBuilder.ofDouble
38 |
39 | val expPart: Double = math.exp(-lambda)
40 | var curCDF: Double = expPart
41 | var curPDF: Double = expPart
42 |
43 | cdfBuilder += curCDF
44 | pdfBuilder += curPDF
45 |
46 | var i: Double = 1.0
47 | while (curCDF < (1.0 - tolerance)) {
48 | curPDF *= lambda / i
49 | curCDF += curPDF
50 |
51 | cdfBuilder += curCDF
52 | pdfBuilder += curPDF
53 |
54 | i += 1.0
55 | }
56 |
57 | pdf = pdfBuilder.result()
58 | cdf = cdfBuilder.result()
59 | }
60 |
61 | /**
62 | * Sample a Poisson-distributed value.
63 | * @return A sampled integer value.
64 | */
65 | def sample(): Int = {
66 | val rnd = rng.nextDouble()
67 | for (i <- 0 to cdf.length - 1) {
68 | if (rnd <= cdf(i)) {
69 | return i
70 | }
71 | }
72 |
73 | cdf.length - 1
74 | }
75 | }
76 |
--------------------------------------------------------------------------------
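A small sketch of how the sampler might be used, e.g. to draw Poisson(1) bagging counts; `PoissonExample` and the number of draws are illustrative assumptions:

```scala
import spark_ml.util.Poisson

object PoissonExample {
  def main(args: Array[String]): Unit = {
    val poisson = Poisson(lambda = 1.0, seed = 42)
    // Draw many samples; the empirical mean should be close to lambda.
    val draws = Array.fill(100000)(poisson.sample())
    val mean = draws.sum.toDouble / draws.length
    println(s"empirical mean = $mean (expected ~ 1.0)")
  }
}
```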
/src/main/scala/spark_ml/util/ProgressNotifiee.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import java.io.Serializable
21 | import java.util.Calendar
22 |
23 | /**
24 | * An object to funnel progress reports to.
25 | */
26 | trait ProgressNotifiee extends Serializable {
27 | def newProgressMessage(progress: String): Unit
28 | def newStatusMessage(status: String): Unit
29 | def newErrorMessage(error: String): Unit
30 | }
31 |
32 | /**
33 | * Simple console notifiee.
34 | * Prints messages to stdout.
35 | */
36 | class ConsoleNotifiee extends ProgressNotifiee {
37 | def newProgressMessage(progress: String): Unit = {
38 | println(
39 | "[Progress] [" + Calendar.getInstance().getTime.toString + "] " + progress
40 | )
41 | }
42 |
43 | def newStatusMessage(status: String): Unit = {
44 | println(
45 | "[Status] [" + Calendar.getInstance().getTime.toString + "] " + status
46 | )
47 | }
48 |
49 | def newErrorMessage(error: String): Unit = {
50 | println(
51 | "[Error] [" + Calendar.getInstance().getTime.toString + "] " + error
52 | )
53 | }
54 | }
55 |
--------------------------------------------------------------------------------
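A hypothetical alternative notifiee that buffers messages in memory instead of printing them (e.g. for tests or for forwarding to a UI); `BufferingNotifiee` is not part of the repository:

```scala
import scala.collection.mutable

import spark_ml.util.ProgressNotifiee

class BufferingNotifiee extends ProgressNotifiee {
  // Collected messages, in arrival order.
  val messages = mutable.ArrayBuffer.empty[String]

  def newProgressMessage(progress: String): Unit = messages += s"[Progress] $progress"
  def newStatusMessage(status: String): Unit = messages += s"[Status] $status"
  def newErrorMessage(error: String): Unit = messages += s"[Error] $error"
}
```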
/src/main/scala/spark_ml/util/RandomSet.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import scala.collection.mutable
21 |
22 | /**
23 | * Helper functions to select random samples.
24 | */
25 | object RandomSet {
26 | def nChooseK(k: Int, n: Int, rnd: scala.util.Random): Array[Int] = {
27 | val indices = new mutable.ArrayBuilder.ofInt
28 | var remains = k
29 |
30 | for (i <- 0 to n - 1) {
31 | if (rnd.nextInt(n - i) < remains) {
32 | indices += i
33 | remains -= 1
34 | }
35 | }
36 |
37 | indices.result()
38 | }
39 | }
40 |
--------------------------------------------------------------------------------
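A short usage sketch: `nChooseK` returns exactly k distinct indices from 0 until n, in increasing order, which is how per-tree or per-node feature subsampling can be done. `RandomSetExample` is illustrative only:

```scala
import scala.util.Random

import spark_ml.util.RandomSet

object RandomSetExample {
  def main(args: Array[String]): Unit = {
    val rnd = new Random(7)
    // Pick 5 of 20 feature indices uniformly at random.
    val chosen = RandomSet.nChooseK(k = 5, n = 20, rnd = rnd)
    assert(chosen.length == 5)
    assert(chosen.toSeq == chosen.toSeq.sorted) // Indices come out in increasing order.
    println(chosen.mkString(", "))
  }
}
```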
/src/main/scala/spark_ml/util/ReservoirSample.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import scala.util.Random
21 |
22 | /**
23 | * Reservoir sample.
24 | * @param maxSample The maximum sample size.
25 | */
26 | class ReservoirSample(val maxSample: Int) extends Serializable {
27 | var sample = Array.fill[Double](maxSample)(0.0)
28 | var numSamplePoints: Int = 0
29 | var numPointsSeen: Double = 0.0
30 |
31 | def doReservoirSampling(point: Double, rng: Random): Unit = {
32 | numPointsSeen += 1.0
33 | if (numSamplePoints < maxSample) {
34 | sample(numSamplePoints) = point
35 | numSamplePoints += 1
36 | } else {
37 | val randomNumber = math.floor(rng.nextDouble() * numPointsSeen)
38 | if (randomNumber < maxSample.toDouble) {
39 | sample(randomNumber.toInt) = point
40 | }
41 | }
42 | }
43 | }
44 |
45 | object ReservoirSample {
46 | /**
47 | * Merge two reservoir samples so that the merged result remains a uniform
48 | * sample of all the points seen by both.
49 | * @param a A reservoir sample.
50 | * @param b A reservoir sample.
51 | * @param maxSample Maximum number of reservoir samples we want.
52 | * @param rng A random number generator.
53 | * @return Merged sample.
54 | */
55 | def mergeReservoirSamples(
56 | a: ReservoirSample,
57 | b: ReservoirSample,
58 | maxSample: Int,
59 | rng: Random): ReservoirSample = {
60 |
61 | assert(maxSample == a.maxSample)
62 | assert(a.maxSample == b.maxSample)
63 |
64 | // Merged samples.
65 | val mergedSample = new ReservoirSample(maxSample)
66 |
67 | // Find out which one has seen more samples.
68 | val (largerSample, smallerSample) =
69 | if (a.numPointsSeen > b.numPointsSeen) {
70 | (a, b)
71 | } else {
72 | (b, a)
73 | }
74 |
75 | // First, fill in the merged samples with the samples that had 'lower' prob
76 | // of being selected. I.e., the sample that has seen more points.
77 | var i = 0
78 | while (i < largerSample.numSamplePoints) {
79 | mergedSample.sample(i) = largerSample.sample(i)
80 | i += 1
81 | }
82 | mergedSample.numSamplePoints = largerSample.numSamplePoints
83 | mergedSample.numPointsSeen = largerSample.numPointsSeen
84 |
85 | // Now, add smaller sample points with probabilities so that they become
86 | // uniform.
87 | var j = 0 // The smaller sample index.
88 | val probSmaller = smallerSample.numPointsSeen / (largerSample.numPointsSeen + smallerSample.numPointsSeen)
89 | while (j < smallerSample.numSamplePoints) {
90 | val samplePoint = smallerSample.sample(j)
91 | if (mergedSample.numSamplePoints > smallerSample.numSamplePoints) {
92 | mergedSample.doReservoirSampling(samplePoint, rng)
93 | } else {
94 | val rnd = rng.nextDouble()
95 | if (rnd < probSmaller) {
96 | mergedSample.sample(j) = samplePoint
97 | }
98 | }
99 |
100 | j += 1
101 | }
102 |
103 | mergedSample
104 | }
105 | }
--------------------------------------------------------------------------------
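A minimal sketch of building two reservoirs (e.g. one per partition) and merging them; `ReservoirExample` and the stream sizes are illustrative assumptions:

```scala
import scala.util.Random

import spark_ml.util.ReservoirSample

object ReservoirExample {
  def main(args: Array[String]): Unit = {
    val rng = new Random(13)

    // Keep at most 10 points from a stream of 1000.
    val sampleA = new ReservoirSample(maxSample = 10)
    (1 to 1000).foreach(i => sampleA.doReservoirSampling(i.toDouble, rng))
    assert(sampleA.numSamplePoints == 10)
    assert(sampleA.numPointsSeen == 1000.0)

    // Merge with a second, smaller reservoir (e.g. from another partition).
    val sampleB = new ReservoirSample(maxSample = 10)
    (1 to 50).foreach(i => sampleB.doReservoirSampling(-i.toDouble, rng))
    val merged = ReservoirSample.mergeReservoirSamples(sampleA, sampleB, maxSample = 10, rng = rng)
    assert(merged.numSamplePoints == 10)
  }
}
```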
/src/main/scala/spark_ml/util/RobustMath.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | /**
21 | * Math functions that don't produce NaNs for edge cases (e.g. log of 0) and
22 | * that clamp exp to avoid excessively large or small values.
23 | * These are used for loss function calculations.
24 | */
25 | object RobustMath {
26 | private val minExponent = -19.0
27 | private val maxExponent = 19.0
28 | private val expPredLowerLimit = math.exp(minExponent)
29 | private val expPredUpperLimit = math.exp(maxExponent)
30 |
31 | def log(value: Double): Double = {
32 | if (value == 0.0) {
33 | minExponent
34 | } else {
35 | math.min(math.max(math.log(value), minExponent), maxExponent)
36 | }
37 | }
38 |
39 | def exp(value: Double): Double = {
40 | math.min(math.max(math.exp(value), expPredLowerLimit), expPredUpperLimit)
41 | }
42 | }
43 |
--------------------------------------------------------------------------------
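A small sketch of the clamping behavior, which follows directly from the constants above; `RobustMathExample` is illustrative only:

```scala
import spark_ml.util.RobustMath

object RobustMathExample {
  def main(args: Array[String]): Unit = {
    // log(0) is clamped to the minimum exponent instead of returning -Infinity.
    assert(RobustMath.log(0.0) == -19.0)
    // exp is clamped so that downstream loss computations never blow up.
    assert(RobustMath.exp(100.0) == math.exp(19.0))
    assert(RobustMath.exp(-100.0) == math.exp(-19.0))
  }
}
```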
/src/main/scala/spark_ml/util/Selection.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import scala.util.Random
21 |
22 | /**
23 | * Selection algorithm utils.
24 | * Adapted with some modifications from the Java code at:
25 | * http://blog.teamleadnet.com/2012/07/quick-select-algorithm-find-kth-element.html
26 | */
27 | object Selection {
28 | /**
29 | * A quick select algorithm used to select the n'th number from an array.
30 | * @param array Array of Doubles.
31 | * @param s The starting index of the segment that we are looking at.
32 | * @param e The ending position (ending index + 1) of the segment that we are
33 | * looking at.
34 | * @param n The 0-based index of the number we want, if the array were sorted.
35 | * @param rng A random number generator.
36 | * @return The n'th number in the array.
37 | */
38 | def quickSelect(
39 | array: Array[Double],
40 | s: Int,
41 | e: Int,
42 | n: Int,
43 | rng: Random): Double = {
44 |
45 | var from = s
46 | var to = e - 1
47 | while (from < to) {
48 | var r = from
49 | var w = to
50 | val pivotIdx = rng.nextInt(to - from + 1) + from
51 | val pivotVal = array(pivotIdx)
52 | while (r < w) {
53 | if (array(r) >= pivotVal) {
54 | val tmp = array(w)
55 | array(w) = array(r)
56 | array(r) = tmp
57 | w -= 1
58 | } else {
59 | r += 1
60 | }
61 | }
62 |
63 | if (array(r) > pivotVal) {
64 | r -= 1
65 | }
66 |
67 | if (n <= r) {
68 | to = r
69 | } else {
70 | from = r + 1
71 | }
72 | }
73 |
74 | array(n)
75 | }
76 | }
77 |
--------------------------------------------------------------------------------
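A short usage sketch: selecting the median of a small array with `quickSelect`. Note that the array is partially reordered in place. `SelectionExample` and the values are illustrative only:

```scala
import scala.util.Random

import spark_ml.util.Selection

object SelectionExample {
  def main(args: Array[String]): Unit = {
    // The median of 7 values is the element at index 3 of the sorted array.
    val values = Array(9.0, 1.0, 4.0, 7.0, 3.0, 8.0, 2.0)
    val median = Selection.quickSelect(values, s = 0, e = values.length, n = 3, rng = new Random(5))
    assert(median == 4.0)
  }
}
```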
/src/test/scala/spark_ml/discretization/EqualFrequencyBinFinderSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | import scala.util.{Failure, Success, Try}
21 |
22 | import org.scalatest.FunSuite
23 | import spark_ml.util._
24 |
25 | /**
26 | * Test equal frequency bin finders.
27 | */
28 | class EqualFrequencyBinFinderSuite extends FunSuite with LocalSparkContext {
29 | test("Test the equal frequency bin finder 1") {
30 | val rawData1 = TestDataGenerator.labeledData1
31 | val testDataRDD1 = sc.parallelize(rawData1, 3).cache()
32 |
33 | val (labelSummary1, bins1) =
34 | new EqualFrequencyBinFinderFromSample(
35 | maxSampleSize = 1000,
36 | seed = 0
37 | ).findBins(
38 | data = testDataRDD1,
39 | columnNames = ("Label", Array("Col1", "Col2", "Col3")),
40 | catIndices = Set(1),
41 | maxNumBins = 8,
42 | expectedLabelCardinality = Some(4),
43 | notifiee = new ConsoleNotifiee
44 | )
45 |
46 | assert(labelSummary1.restCount === 0L)
47 | assert(labelSummary1.catCounts.get.length === 4)
48 | assert(labelSummary1.catCounts.get(0) === 7L)
49 | assert(labelSummary1.catCounts.get(1) === 8L)
50 | assert(labelSummary1.catCounts.get(2) === 11L)
51 | assert(labelSummary1.catCounts.get(3) === 4L)
52 | assert(bins1.length === 3)
53 | assert(bins1(0).getCardinality === 5)
54 | assert(bins1(1).getCardinality === 3)
55 | assert(bins1(2).getCardinality === 8)
56 |
57 | BinsTestUtil.validateNumericalBins(
58 | bins1(0).asInstanceOf[NumericBins],
59 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, Double.PositiveInfinity)),
60 | None
61 | )
62 |
63 | assert(bins1(0).findBinIdx(0.0) === 0)
64 | assert(bins1(0).findBinIdx(0.5) === 0)
65 | assert(bins1(0).findBinIdx(1.0) === 1)
66 | assert(bins1(0).findBinIdx(3.99999) === 3)
67 | assert(bins1(0).findBinIdx(4.0) === 4)
68 | assert(bins1(0).findBinIdx(10.0) === 4)
69 |
70 | assert(bins1(1).isInstanceOf[CategoricalBins])
71 |
72 | Try(bins1(1).findBinIdx(1.1)) match {
73 | case Success(idx) => fail("CategoricalBins findBinIdx should've thrown an exception.")
74 | case Failure(ex) => assert(ex.isInstanceOf[InvalidCategoricalValueException])
75 | }
76 |
77 | Try(bins1(1).findBinIdx(10.0)) match {
78 | case Success(idx) => fail("CategoricalBins findBinIdx should've thrown an exception.")
79 | case Failure(ex) => assert(ex.isInstanceOf[CardinalityOverLimitException])
80 | }
81 |
82 | BinsTestUtil.validateNumericalBins(
83 | bins1(2).asInstanceOf[NumericBins],
84 | Array(
85 | (Double.NegativeInfinity, -72.87),
86 | (-72.87, -52.28),
87 | (-52.28, -5.63),
88 | (-5.63, 20.88),
89 | (20.88, 25.89),
90 | (25.89, 59.07),
91 | (59.07, 81.67),
92 | (81.67, Double.PositiveInfinity)
93 | ),
94 | None
95 | )
96 |
97 | val rawData3 = TestDataGenerator.labeledData3
98 | val testDataRDD3 = sc.parallelize(rawData3, 3).cache()
99 |
100 | val (labelSummary3, bins3) =
101 | new EqualFrequencyBinFinderFromSample(
102 | maxSampleSize = 1000,
103 | seed = 0
104 | ).findBins(
105 | data = testDataRDD3,
106 | columnNames = ("Label", Array("Col1", "Col2")),
107 | catIndices = Set(),
108 | maxNumBins = 8,
109 | expectedLabelCardinality = None,
110 | notifiee = new ConsoleNotifiee
111 | )
112 |
113 | assert(labelSummary3.expectedCardinality.isEmpty)
114 | assert(labelSummary3.catCounts.isEmpty)
115 | assert(labelSummary3.restCount === 30L)
116 | assert(bins3.length === 2)
117 | assert(bins3(0).getCardinality === 5)
118 | assert(bins3(1).getCardinality === 3)
119 |
120 | BinsTestUtil.validateNumericalBins(
121 | bins3(0).asInstanceOf[NumericBins],
122 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, Double.PositiveInfinity)),
123 | None
124 | )
125 |
126 | BinsTestUtil.validateNumericalBins(
127 | bins3(1).asInstanceOf[NumericBins],
128 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, Double.PositiveInfinity)),
129 | None
130 | )
131 |
132 | val rawData6 = TestDataGenerator.labeledData6
133 | val testDataRDD6 = sc.parallelize(rawData6, 3).cache()
134 |
135 | val (labelSummary6, bins6) =
136 | new EqualFrequencyBinFinderFromSample(
137 | maxSampleSize = 1000,
138 | seed = 0
139 | ).findBins(
140 | data = testDataRDD6,
141 | columnNames = ("Label", Array("Col1", "Col2")),
142 | catIndices = Set(),
143 | maxNumBins = 8,
144 | expectedLabelCardinality = Some(3),
145 | notifiee = new ConsoleNotifiee
146 | )
147 |
148 | assert(labelSummary6.restCount === 4L)
149 | assert(labelSummary6.catCounts.get(0) === 7L)
150 | assert(labelSummary6.catCounts.get(1) === 8L)
151 | assert(labelSummary6.catCounts.get(2) === 11L)
152 | assert(bins6.length === 2)
153 | assert(bins6(0).getCardinality === 6)
154 | assert(bins6(1).getCardinality === 4)
155 |
156 | BinsTestUtil.validateNumericalBins(
157 | bins6(0).asInstanceOf[NumericBins],
158 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, Double.PositiveInfinity)),
159 | Option(5)
160 | )
161 |
162 | BinsTestUtil.validateNumericalBins(
163 | bins6(1).asInstanceOf[NumericBins],
164 | Array((Double.NegativeInfinity, 1.0), (1.0, 2.0), (2.0, Double.PositiveInfinity)),
165 | Option(3)
166 | )
167 | }
168 |
169 | test("Test the equal frequency RDD transformation 1") {
170 | val rawData1 = TestDataGenerator.labeledData1
171 | val testDataRDD1 = sc.parallelize(rawData1, 1).cache()
172 |
173 | val (labelSummary1, bins1) =
174 | new EqualFrequencyBinFinderFromSample(
175 | maxSampleSize = 1000,
176 | seed = 0
177 | ).findBins(
178 | data = testDataRDD1,
179 | columnNames = ("Label", Array("Col1", "Col2", "Col3")),
180 | catIndices = Set(1),
181 | maxNumBins = 8,
182 | expectedLabelCardinality = Some(4),
183 | notifiee = new ConsoleNotifiee
184 | )
185 |
186 | val featureHandler = new UnsignedByteHandler
187 | val transformedFeatures1 = Discretizer.transformFeatures(
188 | input = testDataRDD1,
189 | featureBins = bins1,
190 | featureHandler = featureHandler
191 | ).collect()
192 |
193 | assert(featureHandler.convertToInt(transformedFeatures1(0)(0)) === 0)
194 | assert(featureHandler.convertToInt(transformedFeatures1(0)(1)) === 0)
195 | assert(featureHandler.convertToInt(transformedFeatures1(0)(2)) === 2)
196 |
197 | assert(featureHandler.convertToInt(transformedFeatures1(17)(0)) === 2)
198 | assert(featureHandler.convertToInt(transformedFeatures1(17)(1)) === 2)
199 | assert(featureHandler.convertToInt(transformedFeatures1(17)(2)) === 4)
200 |
201 | val rawData6 = TestDataGenerator.labeledData6
202 | val testDataRDD6 = sc.parallelize(rawData6, 1).cache()
203 |
204 | val (labelSummary6, bins6) =
205 | new EqualFrequencyBinFinderFromSample(
206 | maxSampleSize = 1000,
207 | seed = 0
208 | ).findBins(
209 | data = testDataRDD6,
210 | columnNames = ("Label", Array("Col1", "Col2")),
211 | catIndices = Set(),
212 | maxNumBins = 8,
213 | expectedLabelCardinality = Some(4),
214 | notifiee = new ConsoleNotifiee
215 | )
216 |
217 | val transformedFeatures6 = Discretizer.transformFeatures(
218 | input = testDataRDD6,
219 | featureBins = bins6,
220 | featureHandler = featureHandler
221 | ).collect()
222 |
223 | assert(featureHandler.convertToInt(transformedFeatures6(2)(0)) === 2)
224 | assert(featureHandler.convertToInt(transformedFeatures6(2)(1)) === 3)
225 |
226 | assert(featureHandler.convertToInt(transformedFeatures6(7)(0)) === 5)
227 | assert(featureHandler.convertToInt(transformedFeatures6(7)(1)) === 1)
228 | }
229 | }
230 |
--------------------------------------------------------------------------------
/src/test/scala/spark_ml/discretization/EqualWidthBinFinderSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.discretization
19 |
20 | import org.apache.spark.SparkException
21 | import org.scalatest.FunSuite
22 | import spark_ml.util._
23 |
24 | /**
25 | * Test equal width discretization.
26 | */
27 | class EqualWidthBinFinderSuite extends FunSuite with LocalSparkContext {
28 | test("Test the equal width bin finder 1") {
29 | val rawData1 = TestDataGenerator.labeledData1
30 | val testDataRDD1 = sc.parallelize(rawData1, 3).cache()
31 |
32 | val (labelSummary1, bins1) =
33 | new EqualWidthBinFinder().findBins(
34 | data = testDataRDD1,
35 | columnNames = ("Label", Array("Col1", "Col2", "Col3")),
36 | catIndices = Set(1),
37 | maxNumBins = 8,
38 | expectedLabelCardinality = Some(4),
39 | notifiee = new ConsoleNotifiee
40 | )
41 |
42 | assert(labelSummary1.restCount === 0L)
43 | assert(labelSummary1.catCounts.get.length === 4)
44 | assert(labelSummary1.catCounts.get(0) === 7L)
45 | assert(labelSummary1.catCounts.get(1) === 8L)
46 | assert(labelSummary1.catCounts.get(2) === 11L)
47 | assert(labelSummary1.catCounts.get(3) === 4L)
48 | assert(bins1.length === 3)
49 | assert(bins1(0).getCardinality === 8)
50 | assert(bins1(1).getCardinality === 3)
51 | assert(bins1(2).getCardinality === 8)
52 |
53 | assert(bins1(0).findBinIdx(0.0) === 0)
54 | assert(bins1(0).findBinIdx(0.5) === 1)
55 | assert(bins1(0).findBinIdx(1.0) === 2)
56 | assert(bins1(0).findBinIdx(4.0) === 7)
57 |
58 | assert(bins1(1).isInstanceOf[CategoricalBins])
59 |
60 | assert(bins1(2).findBinIdx(-80.0) === 0)
61 | assert(bins1(2).findBinIdx(-60.0) === 0)
62 | assert(bins1(2).findBinIdx(-58.0) === 1)
63 |
64 | val rawData6 = TestDataGenerator.labeledData6
65 | val testDataRDD6 = sc.parallelize(rawData6, 3).cache()
66 |
67 | val (labelSummary6, bins6) =
68 | new EqualWidthBinFinder().findBins(
69 | data = testDataRDD6,
70 | columnNames = ("Label", Array("Col1", "Col2")),
71 | catIndices = Set(),
72 | maxNumBins = 8,
73 | expectedLabelCardinality = Some(3),
74 | notifiee = new ConsoleNotifiee
75 | )
76 |
77 | assert(labelSummary6.restCount === 4L)
78 | assert(labelSummary6.catCounts.get(0) === 7L)
79 | assert(labelSummary6.catCounts.get(1) === 8L)
80 | assert(labelSummary6.catCounts.get(2) === 11L)
81 | assert(bins6.length === 2)
82 | assert(bins6(0).getCardinality === 8)
83 | assert(bins6(1).getCardinality === 8)
84 | assert(bins6(0).asInstanceOf[NumericBins].missingValueBinIdx === Some(7))
85 | assert(bins6(1).asInstanceOf[NumericBins].missingValueBinIdx === Some(7))
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
/src/test/scala/spark_ml/util/BinsTestUtil.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import org.scalatest.Assertions._
21 | import spark_ml.discretization.NumericBins
22 |
23 | object BinsTestUtil {
24 | def validateNumericalBins(
25 | bins: NumericBins,
26 | boundaries: Array[(Double, Double)],
27 | missingBinId: Option[Int]): Unit = {
28 | assert(bins.getCardinality === (boundaries.length + (if (missingBinId.isDefined) 1 else 0)))
29 | assert(bins.missingValueBinIdx === missingBinId)
30 | bins.bins.zip(boundaries).foreach {
31 | case (numericBin, (l, r)) =>
32 | assert(numericBin.lower === l)
33 | assert(numericBin.upper === r)
34 | }
35 | }
36 | }
37 |
--------------------------------------------------------------------------------
/src/test/scala/spark_ml/util/LocalSparkContext.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | import org.apache.log4j.{Level, Logger}
21 | import org.apache.spark.{SparkConf, SparkContext}
22 | import org.scalatest.{BeforeAndAfterAll, Suite}
23 |
24 | /**
25 | * Start a local spark context for unit testing.
26 | */
27 | trait LocalSparkContext extends BeforeAndAfterAll { self: Suite =>
28 | @transient var sc: SparkContext = _
29 |
30 | override def beforeAll() {
31 | super.beforeAll()
32 | // http://stackoverflow.com/questions/27781187/how-to-stop-messages-displaying-on-spark-console
33 | Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
34 | Logger.getLogger("akka").setLevel(Level.WARN)
35 | Thread.sleep(100L)
36 | val conf = new SparkConf()
37 | .setMaster("local[3]")
38 | .setAppName("test")
39 | sc = new SparkContext(conf)
40 | }
41 |
42 | override def afterAll() {
43 | if (sc != null) {
44 | sc.stop()
45 | sc = null
46 | }
47 | super.afterAll()
48 | }
49 |
50 | def numbersAreEqual(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
51 | math.abs(x - y) / (math.abs(y) + 1e-15) < tol
52 | }
53 | }
54 |
--------------------------------------------------------------------------------
/src/test/scala/spark_ml/util/TestDataGenerator.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package spark_ml.util
19 |
20 | /**
21 | * Generate test data.
22 | */
23 | object TestDataGenerator {
24 | def labeledData1: Array[(Double, Array[Double])] = {
25 | Array(
26 | (0.0, Array(0.0, 0.0, -52.28)),
27 | (0.0, Array(1.0, 0.0, -32.16)),
28 | (1.0, Array(2.0, 0.0, -73.68)),
29 | (1.0, Array(3.0, 0.0, -26.38)),
30 | (2.0, Array(4.0, 0.0, 13.69)),
31 | (2.0, Array(0.0, 1.0, 42.07)),
32 | (2.0, Array(1.0, 1.0, 22.96)),
33 | (3.0, Array(2.0, 1.0, -33.43)),
34 | (3.0, Array(3.0, 1.0, -61.80)),
35 | (0.0, Array(4.0, 1.0, -81.34)),
36 | (2.0, Array(0.0, 1.0, -68.49)),
37 | (2.0, Array(1.0, 1.0, 64.17)),
38 | (3.0, Array(2.0, 1.0, 20.88)),
39 | (3.0, Array(3.0, 1.0, 27.75)),
40 | (0.0, Array(4.0, 1.0, 59.07)),
41 | (0.0, Array(0.0, 2.0, -53.55)),
42 | (1.0, Array(1.0, 2.0, 25.89)),
43 | (1.0, Array(2.0, 2.0, 22.62)),
44 | (2.0, Array(3.0, 2.0, -5.63)),
45 | (2.0, Array(4.0, 2.0, 81.67)),
46 | (0.0, Array(0.0, 2.0, -72.87)),
47 | (1.0, Array(1.0, 2.0, 25.51)),
48 | (1.0, Array(2.0, 2.0, 43.14)),
49 | (2.0, Array(3.0, 2.0, 60.53)),
50 | (2.0, Array(4.0, 2.0, 88.94)),
51 | (0.0, Array(0.0, 2.0, 17.08)),
52 | (1.0, Array(1.0, 2.0, 69.48)),
53 | (1.0, Array(2.0, 2.0, -76.47)),
54 | (2.0, Array(3.0, 2.0, 90.90)),
55 | (2.0, Array(4.0, 2.0, -79.67))
56 | )
57 | }
58 |
59 | def labeledData2: Array[(Double, Array[Double])] = {
60 | Array(
61 | (0.0, Array(0.0, 0.0)),
62 | (0.0, Array(1.0, 0.0)),
63 | (1.0, Array(2.0, 0.0)),
64 | (1.0, Array(3.0, 0.0)),
65 | (2.0, Array(4.0, 0.0)),
66 | (2.0, Array(0.0, 1.0)),
67 | (2.0, Array(1.0, 1.0)),
68 | (3.0, Array(2.0, 1.0)),
69 | (3.0, Array(3.0, 1.0)),
70 | (0.0, Array(4.0, 1.0)),
71 | (2.0, Array(0.0, 1.0)),
72 | (2.0, Array(1.0, 1.0)),
73 | (3.0, Array(2.0, 1.0)),
74 | (3.0, Array(3.0, 1.0)),
75 | (0.0, Array(4.0, 1.0)),
76 | (0.0, Array(0.0, 2.0)),
77 | (1.0, Array(1.0, 2.0)),
78 | (1.0, Array(2.0, 2.0)),
79 | (2.0, Array(3.0, 2.0)),
80 | (2.0, Array(4.0, 2.0)),
81 | (0.0, Array(0.0, 2.0)),
82 | (1.0, Array(1.0, 2.0)),
83 | (1.0, Array(2.0, 2.0)),
84 | (2.0, Array(3.0, 2.0)),
85 | (2.0, Array(4.0, 2.0)),
86 | (0.0, Array(0.0, 2.0)),
87 | (1.0, Array(1.0, 2.0)),
88 | (1.0, Array(2.0, 2.0)),
89 | (2.0, Array(3.0, 2.0)),
90 | (2.0, Array(4.0, 2.0))
91 | )
92 | }
93 |
94 | def labeledData3: Array[(Double, Array[Double])] = {
95 | Array(
96 | (0.0, Array(0.0, 0.0)),
97 | (0.0, Array(1.0, 0.0)),
98 | (1.0, Array(2.0, 0.0)),
99 | (1.1, Array(3.0, 0.0)),
100 | (2.0, Array(4.0, 0.0)),
101 | (2.3, Array(0.0, 1.0)),
102 | (2.0, Array(1.0, 1.0)),
103 | (3.0, Array(2.0, 1.0)),
104 | (3.5, Array(3.0, 1.0)),
105 | (0.0, Array(4.0, 1.0)),
106 | (2.0, Array(0.0, 1.0)),
107 | (2.0, Array(1.0, 1.0)),
108 | (3.0, Array(2.0, 1.0)),
109 | (3.2, Array(3.0, 1.0)),
110 | (0.0, Array(4.0, 1.0)),
111 | (0.0, Array(0.0, 2.0)),
112 | (1.0, Array(1.0, 2.0)),
113 | (1.0, Array(2.0, 2.0)),
114 | (2.0, Array(3.0, 2.0)),
115 | (2.0, Array(4.0, 2.0)),
116 | (0.0, Array(0.0, 2.0)),
117 | (1.0, Array(1.0, 2.0)),
118 | (1.0, Array(2.0, 2.0)),
119 | (2.0, Array(3.0, 2.0)),
120 | (2.0, Array(4.0, 2.0)),
121 | (0.0, Array(0.0, 2.0)),
122 | (1.0, Array(1.0, 2.0)),
123 | (1.0, Array(2.0, 2.0)),
124 | (2.0, Array(3.0, 2.0)),
125 | (2.0, Array(4.0, 2.0))
126 | )
127 | }
128 |
129 | def labeledData4: Array[(Double, Array[Double])] = {
130 | Array(
131 | (-1.0, Array(0.0, 0.0)),
132 | (0.0, Array(1.0, 0.0)),
133 | (1.0, Array(2.0, 0.0)),
134 | (1.1, Array(3.0, 0.0)),
135 | (2.0, Array(4.0, 0.0)),
136 | (2.3, Array(0.0, 1.0)),
137 | (2.0, Array(1.0, 1.0)),
138 | (3.0, Array(2.0, 1.0)),
139 | (3.5, Array(3.0, 1.0)),
140 | (0.0, Array(4.0, 1.0)),
141 | (2.0, Array(0.0, 1.0)),
142 | (2.0, Array(1.0, 1.0)),
143 | (3.0, Array(2.0, 1.0)),
144 | (3.2, Array(3.0, 1.0)),
145 | (0.0, Array(4.0, 1.0)),
146 | (0.0, Array(0.0, 2.0)),
147 | (1.0, Array(1.0, 2.0)),
148 | (1.0, Array(2.0, 2.0)),
149 | (2.0, Array(3.0, 2.0)),
150 | (2.0, Array(4.0, 2.0)),
151 | (0.0, Array(0.0, 2.0)),
152 | (1.0, Array(1.0, 2.0)),
153 | (1.0, Array(2.0, 2.0)),
154 | (2.0, Array(3.0, 2.0)),
155 | (2.0, Array(4.0, 2.0)),
156 | (0.0, Array(0.0, 2.0)),
157 | (1.0, Array(1.0, 2.0)),
158 | (1.0, Array(2.0, 2.0)),
159 | (2.0, Array(3.0, 2.0)),
160 | (2.0, Array(4.0, 2.0))
161 | )
162 | }
163 |
164 | def labeledData5: Array[(Double, Array[Double])] = {
165 | Array(
166 | (0.0, Array(0.0, 0.0, -52.28)),
167 | (0.0, Array(1.0, 0.0, -32.16)),
168 | (1.0, Array(2.0, 0.0, -73.68)),
169 | (1.0, Array(3.0, 0.0, -26.38)),
170 | (2.0, Array(4.0, 0.0, 13.69)),
171 | (2.0, Array(0.0, 1.0, 42.07)),
172 | (2.0, Array(1.0, 1.0, 22.96)),
173 | (3.0, Array(2.0, 1.0, -33.43)),
174 | (3.0, Array(3.0, 1.0, -61.80)),
175 | (0.0, Array(4.0, 1.0, -81.34)),
176 | (2.0, Array(0.0, 1.0, -68.49)),
177 | (2.0, Array(1.0, 1.0, 64.17)),
178 | (3.0, Array(2.0, Double.NaN, 20.88)),
179 | (3.0, Array(3.0, 1.0, 27.75)),
180 | (0.0, Array(4.0, 1.0, 59.07)),
181 | (0.0, Array(0.0, 2.0, -53.55)),
182 | (1.0, Array(1.0, 2.0, 25.89)),
183 | (1.0, Array(2.0, 2.0, 22.62)),
184 | (2.0, Array(3.0, 2.0, -5.63)),
185 | (2.0, Array(4.0, 2.0, 81.67)),
186 | (0.0, Array(0.0, 2.0, -72.87)),
187 | (1.0, Array(1.0, 2.0, 25.51)),
188 | (1.0, Array(2.0, 2.0, 43.14)),
189 | (2.0, Array(3.0, 2.0, 60.53)),
190 | (2.0, Array(4.0, 2.0, 88.94)),
191 | (0.0, Array(0.0, 2.0, 17.08)),
192 | (1.0, Array(1.0, 2.0, 69.48)),
193 | (1.0, Array(2.0, 2.0, -76.47)),
194 | (2.0, Array(3.0, 2.0, 90.90)),
195 | (2.0, Array(4.0, 2.0, -79.67))
196 | )
197 | }
198 |
199 | def labeledData6: Array[(Double, Array[Double])] = {
200 | Array(
201 | (0.0, Array(0.0, 0.0)),
202 | (0.0, Array(1.0, 0.0)),
203 | (1.0, Array(2.0, Double.NaN)),
204 | (1.0, Array(3.0, Double.NaN)),
205 | (2.0, Array(4.0, 0.0)),
206 | (2.0, Array(0.0, 1.0)),
207 | (2.0, Array(1.0, 1.0)),
208 | (3.0, Array(Double.NaN, 1.0)),
209 | (3.0, Array(Double.NaN, 1.0)),
210 | (0.0, Array(4.0, 1.0)),
211 | (2.0, Array(0.0, 1.0)),
212 | (2.0, Array(1.0, 1.0)),
213 | (3.0, Array(2.0, 1.0)),
214 | (3.0, Array(3.0, 1.0)),
215 | (0.0, Array(4.0, 1.0)),
216 | (0.0, Array(0.0, 2.0)),
217 | (1.0, Array(1.0, 2.0)),
218 | (1.0, Array(2.0, 2.0)),
219 | (2.0, Array(3.0, 2.0)),
220 | (2.0, Array(4.0, 2.0)),
221 | (0.0, Array(0.0, 2.0)),
222 | (1.0, Array(1.0, 2.0)),
223 | (1.0, Array(2.0, 2.0)),
224 | (2.0, Array(3.0, 2.0)),
225 | (2.0, Array(Double.NaN, 2.0)),
226 | (0.0, Array(0.0, 2.0)),
227 | (1.0, Array(1.0, 2.0)),
228 | (1.0, Array(2.0, 2.0)),
229 | (2.0, Array(3.0, 2.0)),
230 | (2.0, Array(4.0, 2.0))
231 | )
232 | }
233 | }
234 |
--------------------------------------------------------------------------------