├── .gitignore
├── LICENSE
├── README.md
├── bin
│   └── spark-featureselection_2.11-1.0.0.jar
├── build.sbt
├── project
│   ├── build.properties
│   └── plugins.sbt
├── src
│   ├── main
│   │   └── scala
│   │       └── org
│   │           └── apache
│   │               └── spark
│   │                   └── ml
│   │                       └── feature
│   │                           └── selection
│   │                               ├── FeatureSelector.scala
│   │                               ├── FeatureSelectorParams.scala
│   │                               ├── embedded
│   │                               │   ├── ImportanceSelector.scala
│   │                               │   └── LRSelector.scala
│   │                               ├── filter
│   │                               │   ├── CorrelationSelector.scala
│   │                               │   ├── GiniSelector.scala
│   │                               │   └── InfoGainSelector.scala
│   │                               └── util
│   │                                   └── VectorMerger.scala
│   └── test
│       ├── resources
│       │   └── iris.data
│       └── scala
│           └── org
│               └── apache
│                   └── spark
│                       └── ml
│                           └── feature
│                               └── selection
│                                   ├── FeatureSelectionTestBase.scala
│                                   ├── embedded
│                                   │   ├── ImportanceSelectorSuite.scala
│                                   │   └── LRSelectorSuite.scala
│                                   ├── filter
│                                   │   ├── CorrelationSelectorSuite.scala
│                                   │   ├── GiniSelectorSuite.scala
│                                   │   └── InfoGainSelectorSuite.scala
│                                   ├── test_util
│                                   │   ├── DefaultReadWriteTest.scala
│                                   │   └── TempDirectory.scala
│                                   └── util
│                                       └── VectorMergerSuite.scala
└── version.sbt
/.gitignore:
--------------------------------------------------------------------------------
1 | build/sbt-launch*.jar
2 | target/
3 | .idea/
4 | lib/
5 | .idea_modules/
6 | *.iml
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright {yyyy} {name of copyright owner}
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Feature Selection for Apache Spark
2 | Different feature selection methods (3 filters and 2 selectors based on scores from embedded methods) are provided as Spark ML `PipelineStage`s.
3 | These are:
4 |
5 | ## Filters:
6 | 1) CorrelationSelector: calculates the correlation ("spearman" or "pearson", adjustable through `.setCorrelationType`) between each feature and the label.
7 | 2) GiniSelector: measures the impurity difference before and after a feature value is known.
8 | 3) InfoGainSelector: measures the information gain of a feature with respect to the class.
9 |
10 | ## Embedded:
11 | 1) ImportanceSelector: takes FeatureImportances from any embedded method, e.g. Random Forest.
12 | 2) LRSelector: takes feature weights from an (L1) logistic regression. The weights are in a matrix W with dimensions #Labels x #Features; the absolute value is taken of all entries, summed column-wise and scaled by the max value. A sketch of wiring both embedded selectors is shown below.
13 |
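A hedged sketch of feeding scores from already-fitted models into these selectors. It assumes a DataFrame `df` with a `features` vector column and two models fitted on it beforehand: `rfModel` (a `RandomForestClassificationModel`) and `lrModel` (a `LogisticRegressionModel`); names and parameter values are illustrative:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.selection.embedded.{ImportanceSelector, LRSelector}

// rfModel / lrModel are assumed to be models that were already fitted on df
val impSel = new ImportanceSelector()
  .setFeatureWeights(rfModel.featureImportances)   // importance vector of the random forest
  .setFeaturesCol("features")
  .setOutputCol("impSelectedFeatures")
  .setSelectorType("numTopFeatures")
  .setNumTopFeatures(2)

val lrSel = new LRSelector()
  .setCoefficientMatrix(lrModel.coefficientMatrix) // #Labels x #Features weight matrix
  .setScaleCoefficients(true)                      // scale weights, since no StandardScaler was used
  .setFeaturesCol("features")
  .setOutputCol("lrSelectedFeatures")
  .setSelectorType("percentile")

val reduced = new Pipeline().setStages(Array(impSel, lrSel)).fit(df).transform(df)
```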
14 | ## Util
15 | 1) VectorMerger: takes several VectorColumns (e.g. the result of different feature selection methods) and merges them into one VectorColumn. Unlike the VectorAssembler, VectorMerger uses the metadata of the VectorColumn to remove duplicates. It supports two modes:
16 | - useFeaturesCol true and featuresCol set: the output column will contain the columns from featuresCol (matched by name) whose names appear in one of the inputCols. Use this if feature importances were calculated on (e.g.) discretized columns, but selection shall use the original values (see the sketch below).
17 | - useFeaturesCol false: the output column will contain the columns from the inputCols, with duplicates dropped.
18 |
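A minimal sketch of the first mode. The `setUseFeaturesCol` and `setFeaturesCol` setter names are assumed here from the usual `set<ParamName>` convention; the column names are illustrative:

```scala
import org.apache.spark.ml.feature.selection.util.VectorMerger

// Merge selections made on discretized columns, but emit the original (non-discretized) values
val merger = new VectorMerger()
  .setInputCols(Array("igSelectedFeatures", "giniSelectedFeatures")) // selections on discretized features
  .setUseFeaturesCol(true)              // assumed setter for the useFeaturesCol param
  .setFeaturesCol("originalFeatures")   // column holding the original feature values
  .setOutputCol("selectedOriginalFeatures")
```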
19 | Formulas for metrics:
20 |
21 | General:
22 | - $p(x_i)$ - prior probability of feature X having value $x_i$
23 | - $p(y_j \mid x_i)$ - conditional probability that a sample is of class $y_j$, given that feature X has value $x_i$
24 | - $p(y_j)$ - prior probability that the label Y has value $y_j$
25 | 
26 | 1) Correlation: Calculated through `org.apache.spark.mllib.stat`
27 | 2) Gini:
28 | 
29 |    $Gini(X) = \left(1 - \sum_j p(y_j)^2\right) - \sum_i p(x_i)\left(1 - \sum_j p(y_j \mid x_i)^2\right)$
30 | 3) Information gain:
31 | 
32 |    $IG(X) = -\sum_j p(y_j)\log p(y_j) + \sum_i p(x_i)\sum_j p(y_j \mid x_i)\log p(y_j \mid x_i)$
33 | 
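As a quick illustration of the Gini formula (the numbers here are made up, not taken from the repository): for four samples where $X=0$ always carries label $0$ and $X=1$ is split evenly between the two labels, $1 - \sum_j p(y_j)^2 = 1 - (0.75^2 + 0.25^2) = 0.375$ and $\sum_i p(x_i)\left(1 - \sum_j p(y_j \mid x_i)^2\right) = 0.5 \cdot 0 + 0.5 \cdot 0.5 = 0.25$, so the impurity difference for this feature is $0.375 - 0.25 = 0.125$.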
34 | ## Usage
35 |
36 | All selection methods share a common API, similar to `ChiSqSelector`.
37 |
38 | ```scala
39 | import org.apache.spark.ml.feature.selection.filter._
40 | import org.apache.spark.ml.feature.selection.util._
41 | import org.apache.spark.ml.linalg.Vectors
42 | import org.apache.spark.ml.Pipeline
43 |
44 | val data = Seq(
45 | (Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),
46 | (Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),
47 | (Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)
48 | )
49 |
50 | val df = spark.createDataset(data).toDF("features", "label")
51 |
52 | val igSel = new InfoGainSelector()
53 | .setFeaturesCol("features")
54 | .setLabelCol("label")
55 | .setOutputCol("igSelectedFeatures")
56 | .setSelectorType("percentile")
57 |
58 | val corSel = new CorrelationSelector()
59 | .setFeaturesCol("features")
60 | .setLabelCol("label")
61 | .setOutputCol("corrSelectedFeatures")
62 | .setSelectorType("percentile")
63 |
64 | val giniSel = new GiniSelector()
65 | .setFeaturesCol("features")
66 | .setLabelCol("label")
67 | .setOutputCol("giniSelectedFeatures")
68 | .setSelectorType("percentile")
69 |
70 | val merger = new VectorMerger()
71 | .setInputCols(Array("igSelectedFeatures", "corrSelectedFeatures", "giniSelectedFeatures"))
72 | .setOutputCol("filtered")
73 |
74 | val plm = new Pipeline().setStages(Array(igSel, corSel, giniSel, merger)).fit(df)
75 |
76 | plm.transform(df).select("filtered").show()
77 |
78 | ```
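Besides `percentile`, every selector also supports the `numTopFeatures` and `randomCutOff` strategies of the common `FeatureSelector` API. A minimal sketch of switching the strategy (parameter values are illustrative):

```scala
// Keep a fixed number of top-ranked features
val igTopK = new InfoGainSelector()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("igTopKFeatures")
  .setSelectorType("numTopFeatures")
  .setNumTopFeatures(2)

// Keep every feature that ranks above a random probe column,
// plus 5% of the remaining, lower-ranked features
val giniRandom = new GiniSelector()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("giniRandomFeatures")
  .setSelectorType("randomCutOff")
  .setRandomCutOff(0.05)
```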
79 |
80 |
81 |
--------------------------------------------------------------------------------
/bin/spark-featureselection_2.11-1.0.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcKaminski/spark-FeatureSelection/edccceec6e4dfaf44533a71d59050e66c3c8ba04/bin/spark-featureselection_2.11-1.0.0.jar
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | name := "spark-FeatureSelection"
2 |
3 | organization := "MarcKaminski"
4 |
5 | version := "1.0.0"
6 |
7 | scalaVersion := "2.11.8"
8 |
9 | val sparkVersion = "2.2.0"
10 |
11 |
12 | // spark version to be used
13 | libraryDependencies ++= Seq(
14 | "org.apache.spark" %% "spark-core" % sparkVersion % "provided",
15 | "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
16 | "org.apache.spark" %% "spark-mllib" % sparkVersion % "provided"
17 | )
18 |
19 | // For tests
20 | parallelExecution in Test := false
21 | fork in Test := false // true -> Spark during tests; false -> debug during tests (for debug run sbt with: sbt -jvm-debug 5005)
22 | libraryDependencies += "org.scalatest" % "scalatest_2.11" % "3.0.1" % "test"
23 | libraryDependencies += "org.scalactic" %% "scalactic" % "3.0.1" % "test"
24 |
25 |
26 | /********************
27 | * Release settings *
28 | ********************/
29 |
30 | publishMavenStyle := true
31 |
32 | licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))
33 |
34 | pomExtra :=
35 |   <url>https://github.com/MarcKaminski/spark-FeatureSelection</url>
36 |   <scm>
37 |     <url>git@github.com:MarcKaminski/spark-FeatureSelection.git</url>
38 |     <connection>scm:git:git@github.com:MarcKaminski/spark-FeatureSelection.git</connection>
39 |   </scm>
40 |   <developers>
41 |     <developer>
42 |       <id>MarcKaminski</id>
43 |       <name>Marc Kaminski</name>
44 |       <url>https://github.com/MarcKaminski</url>
45 |     </developer>
46 |   </developers>
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 0.13.16
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.4" exclude("org.apache.maven", "maven-plugin-api"))
2 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.6")
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/ml/feature/selection/FeatureSelector.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection
19 |
20 | import org.apache.hadoop.fs.Path
21 | import org.apache.spark.annotation.Since
22 | import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute}
23 | import org.apache.spark.ml.feature.VectorAssembler
24 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, _}
25 | import org.apache.spark.ml.param.ParamMap
26 | import org.apache.spark.ml.util._
27 | import org.apache.spark.ml.{Estimator, Model}
28 | import org.apache.spark.sql.functions.{rand, udf}
29 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
30 | import org.apache.spark.sql.{DataFrame, Dataset}
31 |
32 | /**
33 | * Abstraction for FeatureSelectors, which select features to use for predicting a categorical label.
34 | * The selector supports three selection methods: `numTopFeatures`, `percentile` and `randomCutOff`.
35 | * - `numTopFeatures` chooses a fixed number of top features according to the feature importance.
36 | * - `percentile` is similar but chooses a fraction of all features instead of a fixed number.
37 | * - `randomCutOff` appends a random probe column and keeps all features ranked above it, plus the `randomCutOff` fraction of the remaining ones.
38 | * By default, the selection method is `percentile`, selecting the top 50% of features.
39 | * @tparam Learner Specialization of this class. If you subclass this type, use this type
40 | * parameter to specify the concrete type.
41 | * @tparam M Specialization of [[FeatureSelectorModel]]. If you subclass this type, use this type
42 | * parameter to specify the concrete type for the corresponding model.
43 | */
44 | abstract class FeatureSelector[
45 | Learner <: FeatureSelector[Learner, M],
46 | M <: FeatureSelectorModel[M]] @Since("2.1.1")
47 | extends Estimator[M] with FeatureSelectorParams with DefaultParamsWritable {
48 | /** @group setParam */
49 | @Since("2.1.1")
50 | def setNumTopFeatures(value: Int): Learner = set(numTopFeatures, value).asInstanceOf[Learner]
51 |
52 | /** @group setParam */
53 | @Since("2.1.1")
54 | def setPercentile(value: Double): Learner = set(percentile, value).asInstanceOf[Learner]
55 |
56 | /** @group setParam */
57 | @Since("2.1.1")
58 | def setSelectorType(value: String): Learner = set(selectorType, value).asInstanceOf[Learner]
59 |
60 | /** @group setParam */
61 | @Since("2.1.1")
62 | def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
63 |
64 | /** @group setParam */
65 | @Since("2.1.1")
66 | def setOutputCol(value: String): Learner = set(outputCol, value).asInstanceOf[Learner]
67 |
68 | /** @group setParam */
69 | @Since("2.1.1")
70 | def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
71 |
72 | /** @group setParam */
73 | @Since("2.1.1")
74 | def setRandomCutOff(value: Double): Learner = set(randomCutOff, value).asInstanceOf[Learner]
75 |
76 | override def fit(dataset: Dataset[_]): M = {
77 | // This handles a few items such as schema validation.
78 | // Developers only need to implement train() and make().
79 | transformSchema(dataset.schema, logging = true)
80 |
81 | val randomColMaxCategories = 10000
82 |
83 | // Get num features for percentile calculation
84 | val attrGroup = AttributeGroup.fromStructField(dataset.schema($ {
85 | featuresCol
86 | }))
87 | val numFeatures = attrGroup.size
88 |
89 | val (featureImportances, features) =
90 | if ($(selectorType) == FeatureSelector.Random) {
91 | // Append column with random values to dataframe
92 | val withRandom = dataset.withColumn("random", (rand * randomColMaxCategories).cast(IntegerType))
93 | val featureVectorWithRandom = new VectorAssembler()
94 | .setInputCols(Array($(featuresCol), "random"))
95 | .setOutputCol("FeaturesAndRandom")
96 | .transform(withRandom)
97 |
98 | // Cache and change features column name, calculate importances and reset.
99 | val realFeaturesCol = $(featuresCol)
100 | setFeaturesCol("FeaturesAndRandom")
101 | val featureImportances = train(featureVectorWithRandom)
102 | setFeaturesCol(realFeaturesCol)
103 | val idFromRandomCol = featureImportances.map(_._1).max
104 |
105 | // Take features until reaching random feature. Take overlap from remaining depending on randomCutOff percentage
106 | val sortedFeatureImportances = featureImportances
107 | .sortBy { case (_, imp) => -imp } // minus for descending direction!
108 | .zipWithIndex
109 |
110 | val randomColPos = sortedFeatureImportances.find { case ((fId, fImp), sortId) => fId == idFromRandomCol }.get._2
111 | val overlap = math.max(0, math.round((featureImportances.length - randomColPos - 1) * $(randomCutOff))).toInt
112 |
113 | (featureImportances.filterNot(_._1 == idFromRandomCol), sortedFeatureImportances
114 | .take(randomColPos + overlap + 1)
115 | .map(_._1)
116 | .filterNot(_._1 == idFromRandomCol))
117 | } else {
118 | val featureImportances = train(dataset)
119 |
120 | // Select features depending on selection method
121 | val features = $(selectorType) match {
122 | case FeatureSelector.NumTopFeatures => featureImportances
123 | .sortBy { case (_, imp) => -imp } // minus for descending direction!
124 | .take($(numTopFeatures))
125 | case FeatureSelector.Percentile => featureImportances
126 | .sortBy { case (_, imp) => -imp }
127 | .take((numFeatures * $(percentile)).toInt) // Take is safe, even if (numFeatures * percentile) > featureImportances.length
128 | case errorType =>
129 | throw new IllegalStateException(s"Unknown FeatureSelector Type: $errorType")
130 | }
131 | (featureImportances, features)
132 | }
133 |
134 | if (featureImportances.length < numFeatures)
135 | logWarning(s"Some features were dropped while calculating importance values, " +
136 | s"since numFeatureImportances < numFeatures (${featureImportances.length} < $numFeatures). This happens " +
137 | s"e.g. for constant features.")
138 |
139 |
140 | // Save name of columns and corresponding importance value
141 | val nameImportances = featureImportances.map { case (idx, imp) => (
142 | if (attrGroup.attributes.isDefined && attrGroup.getAttr(idx).name.isDefined)
143 | attrGroup.getAttr(idx).name.get
144 | else {
145 | logWarning(s"The metadata of ${$(featuresCol)} is empty or does not contain a name for col index: $idx")
146 | idx.toString
147 | }
148 | , imp)
149 | }
150 |
151 | val indices = features.map { case (idx, _) => idx }
152 |
153 | copyValues(make(uid, indices, nameImportances.toMap).setParent(this))
154 | }
155 |
156 | override def copy(extra: ParamMap): Learner
157 |
158 | /**
159 | * Calculate the featureImportances. These shall be sortable in descending direction to select the best features.
160 | * FeatureSelectors implement this to avoid dealing with schema validation
161 | * and copying parameters into the model.
162 | *
163 | * @param dataset Training dataset
164 | * @return Array of (feature index, feature importance)
165 | */
166 | protected def train(dataset: Dataset[_]): Array[(Int, Double)]
167 |
168 | /**
169 | * Abstract instantiation of the Model.
170 | * FeatureSelectors implement this as a constructor for FeatureSelectorModels
171 | * @param uid of Model
172 | * @param selectedFeatures list of indices to select
173 | * @param featureImportances Map that stores each feature importance
174 | * @return Fitted model
175 | */
176 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): M
177 |
178 | @Since("2.1.1")
179 | def transformSchema(schema: StructType): StructType = {
180 | val otherPairs = FeatureSelector.supportedSelectorTypes.filter(_ != $(selectorType))
181 | otherPairs.foreach { paramName: String =>
182 | if (isSet(getParam(paramName))) {
183 | logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
184 | }
185 | }
186 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
187 | SchemaUtils.checkNumericType(schema, $(labelCol))
188 | SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
189 | }
190 | }
191 |
192 | /**
193 | * Abstraction for a model for selecting features.
194 | * @param uid of Model
195 | * @param selectedFeatures list of indices to select
196 | * @param featureImportances Map that stores each feature importance
197 | * @tparam M Specialization of [[FeatureSelectorModel]]. If you subclass this type, use this type
198 | * parameter to specify the concrete type for the corresponding model.
199 | */
200 | @Since("2.1.1")
201 | abstract class FeatureSelectorModel[M <: FeatureSelectorModel[M]] private[ml] (@Since("2.1.1") override val uid: String,
202 | @Since("2.1.1") val selectedFeatures: Array[Int],
203 | @Since("2.1.1") val featureImportances: Map[String, Double])
204 | extends Model[M] with FeatureSelectorParams with MLWritable {
205 | /** @group setParam */
206 | @Since("2.1.1")
207 | def setFeaturesCol(value: String): this.type = set(featuresCol, value)
208 |
209 | /** @group setParam */
210 | @Since("2.1.1")
211 | def setOutputCol(value: String): this.type = set(outputCol, value)
212 |
213 | @Since("2.1.1")
214 | override def transform(dataset: Dataset[_]): DataFrame = {
215 | transformSchema(dataset.schema, logging = true)
216 |
217 | // Validity checks
218 | val inputAttr = AttributeGroup.fromStructField(dataset.schema($(featuresCol)))
219 | inputAttr.numAttributes.foreach { numFeatures =>
220 | val maxIndex = selectedFeatures.max
221 | require(maxIndex < numFeatures,
222 | s"Selected feature index $maxIndex invalid for only $numFeatures input features.")
223 | }
224 |
225 | // Prepare output attributes
226 | val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs =>
227 | selectedFeatures.map(index => attrs(index))
228 | }
229 | val outputAttr = selectedAttrs match {
230 | case Some(attrs) => new AttributeGroup($(outputCol), attrs)
231 | case None => new AttributeGroup($(outputCol), selectedFeatures.length)
232 | }
233 |
234 | // Select features
235 | val slicer = udf { vec: Vector =>
236 | vec match {
237 | case features: DenseVector => Vectors.dense(selectedFeatures.map(features.apply))
238 | case features: SparseVector => features.slice(selectedFeatures)
239 | }
240 | }
241 | dataset.withColumn($(outputCol), slicer(dataset($(featuresCol))), outputAttr.toMetadata())
242 | }
243 |
244 | @Since("2.1.1")
245 | override def transformSchema(schema: StructType): StructType = {
246 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
247 | val newField = prepOutputField(schema)
248 | val outputFields = schema.fields :+ newField
249 | StructType(outputFields)
250 | }
251 |
252 | /**
253 | * Prepare the output column field, including per-feature metadata.
254 | */
255 | private def prepOutputField(schema: StructType): StructField = {
256 | val selector = selectedFeatures.toSet
257 | val origAttrGroup = AttributeGroup.fromStructField(schema($(featuresCol)))
258 | val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) {
259 | origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1)
260 | } else {
261 | Array.fill[Attribute](selector.size)(NominalAttribute.defaultAttr)
262 | }
263 | val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes)
264 | newAttributeGroup.toStructField()
265 | }
266 | }
267 |
268 |
269 | @Since("2.1.1")
270 | protected [selection] object FeatureSelectorModel {
271 |
272 | class FeatureSelectorModelWriter[M <: FeatureSelectorModel[M]](instance: M) extends MLWriter {
273 |
274 | private case class Data(selectedFeatures: Seq[Int], featureImportances: Map[String, Double])
275 |
276 | override protected def saveImpl(path: String): Unit = {
277 | DefaultParamsWriter.saveMetadata(instance, path, sc)
278 | val data = Data(instance.selectedFeatures.toSeq, instance.featureImportances)
279 | val dataPath = new Path(path, "data").toString
280 | sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
281 | }
282 | }
283 |
284 | abstract class FeatureSelectorModelReader[M <: FeatureSelectorModel[M]] extends MLReader[M] {
285 |
286 | protected val className: String
287 |
288 | override def load(path: String): M = {
289 | val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
290 | val dataPath = new Path(path, "data").toString
291 | val data = sparkSession.read.parquet(dataPath)
292 | val selectedFeatures = data.select("selectedFeatures").head().getAs[Seq[Int]](0).toArray
293 | val featureImportances = data.select("featureImportances").head().getAs[Map[String, Double]](0)
294 | val model = make(metadata.uid, selectedFeatures, featureImportances)
295 | DefaultParamsReader.getAndSetParams(model, metadata)
296 | model
297 | }
298 |
299 | /**
300 | * Abstract instantiation of the Model.
301 | * FeatureSelectors implement this as a constructor for FeatureSelectorModels
302 | * @param uid of Model
303 | * @param selectedFeatures list of indices to select
304 | * @param featureImportances Map that stores each feature importance
305 | * @return Fitted model
306 | */
307 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): M
308 | }
309 | }
310 |
311 | private[selection] object FeatureSelector {
312 | /**
313 | * String name for `numTopFeatures` selector type.
314 | */
315 | val NumTopFeatures: String = "numTopFeatures"
316 |
317 | /**
318 | * String name for `percentile` selector type.
319 | */
320 | val Percentile: String = "percentile"
321 |
322 | /**
323 | * String name for `random` selector type.
324 | */
325 | val Random: String = "randomCutOff"
326 |
327 | /** Set of selector types that FeatureSelector supports. */
328 | val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, Random)
329 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/ml/feature/selection/FeatureSelectorParams.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection
19 |
20 | import org.apache.spark.annotation.Since
21 | import org.apache.spark.ml.param._
22 | import org.apache.spark.ml.param.shared._
23 |
24 | /**
25 | * Params for [[FeatureSelector]] and [[FeatureSelectorModel]].
26 | */
27 | private[selection] trait FeatureSelectorParams extends Params
28 | with HasFeaturesCol with HasOutputCol with HasLabelCol {
29 |
30 | /**
31 | * Number of features that selector will select. If the
32 | * number of features is less than numTopFeatures, then this will select all features.
33 | * Only applicable when selectorType = "numTopFeatures".
34 | * The default value of numTopFeatures is 50.
35 | *
36 | * @group param
37 | */
38 | @Since("2.1.1")
39 | final val numTopFeatures = new IntParam(this, "numTopFeatures",
40 | "Number of features that selector will select. If the" +
41 | " number of features is < numTopFeatures, then this will select all features.",
42 | ParamValidators.gtEq(1))
43 | setDefault(numTopFeatures -> 50)
44 |
45 | /** @group getParam */
46 | @Since("2.1.1")
47 | def getNumTopFeatures: Int = $(numTopFeatures)
48 |
49 | /**
50 | * Percentile of features that selector will select.
51 | * Only applicable when selectorType = "percentile".
52 | * Default value is 0.5.
53 | *
54 | * @group param
55 | */
56 | @Since("2.1.1")
57 | final val percentile = new DoubleParam(this, "percentile",
58 | "Percentile of features that selector will select.",
59 | ParamValidators.inRange(0, 1))
60 | setDefault(percentile -> 0.5)
61 |
62 | /** @group getParam */
63 | @Since("2.1.1")
64 | def getPercentile: Double = $(percentile)
65 |
66 | /**
67 | * Fraction of the features ranked below the random probe column that the selector will still select (in addition to all features ranked above it).
68 | * Only applicable when selectorType = "randomCutOff".
69 | * Default value is 0.05.
70 | *
71 | * @group param
72 | */
73 | @Since("2.1.1")
74 | final val randomCutOff = new DoubleParam(this, "randomCutOff",
75 | "Percentile of features that selector will select after the random column threshold (of number of remaining features).",
76 | ParamValidators.inRange(0, 1))
77 | setDefault(randomCutOff -> 0.05)
78 |
79 | /** @group getParam */
80 | @Since("2.1.1")
81 | def getRandomCutOff: Double = $(randomCutOff)
82 |
83 | /**
84 | * The selector type of the FeatureSelector.
85 | * Supported options: "numTopFeatures", "percentile" (default), "randomCutOff".
86 | *
87 | * @group param
88 | */
89 | @Since("2.1.1")
90 | final val selectorType = new Param[String](this, "selectorType",
91 | "The selector type of the FeatureSelector. " +
92 | "Supported options: " + FeatureSelector.supportedSelectorTypes.mkString(", "),
93 | ParamValidators.inArray[String](FeatureSelector.supportedSelectorTypes))
94 | setDefault(selectorType -> FeatureSelector.Percentile)
95 |
96 | /** @group getParam */
97 | @Since("2.1.1")
98 | def getSelectorType: String = $(selectorType)
99 | }
100 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/ml/feature/selection/embedded/ImportanceSelector.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.embedded
19 |
20 | import org.apache.spark.annotation.Since
21 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel}
22 | import org.apache.spark.ml.linalg.{Vector, _}
23 | import org.apache.spark.ml.param._
24 | import org.apache.spark.ml.util._
25 | import org.apache.spark.sql._
26 | import org.apache.spark.sql.types.StructType
27 |
28 | import scala.collection.mutable
29 |
30 | /**
31 | * Params for [[ImportanceSelector]] and [[ImportanceSelectorModel]].
32 | */
33 |
34 | private[embedded] trait ImportanceSelectorParams extends Params {
35 | /**
36 | * Param for featureWeights (e.g. the featureImportances of a fitted tree ensemble).
37 | *
38 | * @group param
39 | */
40 | final val featureWeights: Param[Vector] = new Param[Vector](this, "featureWeights",
41 | "featureWeights to rank features and select from a column")
42 |
43 | /** @group getParam */
44 | final def getFeatureWeights: Vector = $(featureWeights)
45 | }
46 |
47 | /**
48 | * Feature selection based on featureImportances (e.g. from RandomForest).
49 | */
50 | @Since("2.1.1")
51 | final class ImportanceSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String)
52 | extends FeatureSelector[ImportanceSelector, ImportanceSelectorModel] with ImportanceSelectorParams {
53 |
54 | @Since("2.1.1")
55 | def this() = this(Identifiable.randomUID("importanceSelector"))
56 |
57 | /** @group setParam */
58 | @Since("2.1.1")
59 | def setFeatureWeights(value: Vector): this.type = set(featureWeights, value)
60 |
61 | @Since("2.1.1")
62 | override protected def train(dataset: Dataset[_]): Array[(Int, Double)] = {
63 | val arrBuilder = new mutable.ArrayBuffer[(Int, Double)]()
64 | $(featureWeights).foreachActive((idx, value) => arrBuilder.append((idx, value)))
65 | val featureImportancesLocal = arrBuilder.toArray
66 |
67 | if ($(selectorType) == FeatureSelector.Random) {
68 | val mean = featureImportancesLocal.map(_._2).sum / featureImportancesLocal.length
69 | featureImportancesLocal :+ (featureImportancesLocal.map(_._1).max + 1, mean)
70 | } else
71 | featureImportancesLocal
72 | }
73 |
74 | @Since("2.1.1")
75 | override def transformSchema(schema: StructType): StructType = {
76 | val otherPairs = FeatureSelector.supportedSelectorTypes.filter(_ != $(selectorType))
77 | otherPairs.foreach { paramName: String =>
78 | if (isSet(getParam(paramName))) {
79 | logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
80 | }
81 | }
82 |
83 | require(isDefined(featureWeights))
84 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
85 | SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
86 | }
87 |
88 | @Since("2.1.1")
89 | override def copy(extra: ParamMap): ImportanceSelector = defaultCopy(extra)
90 |
91 | @Since("2.1.1")
92 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): ImportanceSelectorModel = {
93 | new ImportanceSelectorModel(uid, selectedFeatures, featureImportances)
94 | }
95 | }
96 |
97 | object ImportanceSelector extends DefaultParamsReadable[ImportanceSelector] {
98 | @Since("2.1.1")
99 | override def load(path: String): ImportanceSelector = super.load(path)
100 | }
101 |
102 | /**
103 | * Model fitted by [[ImportanceSelector]].
104 | * @param uid of Model
105 | * @param selectedFeatures list of indices to select
106 | * @param featureImportances Map that stores each feature importance
107 | */
108 | @Since("2.1.1")
109 | final class ImportanceSelectorModel private[selection] (@Since("2.1.1") override val uid: String,
110 | @Since("2.1.1") override val selectedFeatures: Array[Int],
111 | @Since("2.1.1") override val featureImportances: Map[String, Double])
112 | extends FeatureSelectorModel[ImportanceSelectorModel](uid, selectedFeatures, featureImportances) with ImportanceSelectorParams {
113 | @Since("2.1.1")
114 | override def copy(extra: ParamMap): ImportanceSelectorModel = {
115 | val copied = new ImportanceSelectorModel(uid, selectedFeatures, featureImportances)
116 | copyValues(copied, extra).setParent(parent)
117 | }
118 |
119 | @Since("2.1.1")
120 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter[ImportanceSelectorModel](this)
121 | }
122 |
123 | @Since("2.1.1")
124 | object ImportanceSelectorModel extends MLReadable[ImportanceSelectorModel] {
125 | @Since("2.1.1")
126 | override def read: MLReader[ImportanceSelectorModel] = new ImportanceSelectorModelReader
127 |
128 | @Since("2.1.1")
129 | override def load(path: String): ImportanceSelectorModel = super.load(path)
130 | }
131 |
132 | @Since("2.1.1")
133 | final class ImportanceSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[ImportanceSelectorModel] {
134 | @Since("2.1.1")
135 | override protected val className: String = classOf[ImportanceSelectorModel].getName
136 |
137 | @Since("2.1.1")
138 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): ImportanceSelectorModel = {
139 | new ImportanceSelectorModel(uid, selectedFeatures, featureImportances)
140 | }
141 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/ml/feature/selection/embedded/LRSelector.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 |
19 | package org.apache.spark.ml.feature.selection.embedded
20 |
21 | import org.apache.spark.annotation.{DeveloperApi, Since}
22 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel}
23 | import org.apache.spark.ml.linalg._
24 | import org.apache.spark.ml.param.{Param, _}
25 | import org.apache.spark.ml.util._
26 | import org.apache.spark.mllib.linalg.{Vectors => MllibVectors}
27 | import org.apache.spark.mllib.stat.Statistics
28 | import org.apache.spark.sql._
29 | import org.apache.spark.sql.types.StructType
30 | import org.json4s.jackson.JsonMethods.{compact, parse, render}
31 | import org.json4s.{JArray, JDouble, JInt, JObject, JValue}
32 |
33 | /**
34 | * Specialized version of `Param[Matrix]` for Java.
35 | */
36 | @DeveloperApi
37 | private[embedded] class MatrixParam(parent: String, name: String, doc: String, isValid: Matrix => Boolean)
38 | extends Param[Matrix](parent, name, doc, isValid) {
39 |
40 | def this(parent: String, name: String, doc: String) =
41 | this(parent, name, doc, (mat: Matrix) => true)
42 |
43 | def this(parent: Identifiable, name: String, doc: String, isValid: Matrix => Boolean) =
44 | this(parent.uid, name, doc, isValid)
45 |
46 | def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
47 |
48 | /** Creates a param pair with the given value (for Java). */
49 | override def w(value: Matrix): ParamPair[Matrix] = super.w(value)
50 |
51 | override def jsonEncode(value: Matrix): String = {
52 | compact(render(MatrixParam.jValueEncode(value)))
53 | }
54 |
55 | override def jsonDecode(json: String): Matrix = {
56 | MatrixParam.jValueDecode(parse(json))
57 | }
58 | }
59 |
60 | private[embedded] object MatrixParam {
61 | /** Encodes a param value into JValue. */
62 | def jValueEncode(value: Matrix): JValue = {
63 | val rows = JInt(value.numRows)
64 | val cols = JInt(value.numCols)
65 | val vals = JArray(for (v <- value.toArray.toList) yield JDouble(v))
66 |
67 | JObject(List(("rows", rows), ("cols", cols), ("vals", vals)))
68 | }
69 |
70 | /** Decodes a param value from JValue. */
71 | def jValueDecode(jValue: JValue): Matrix = {
72 | var rows, cols: Int = 0
73 | var vals: Array[Double] = Array.empty[Double]
74 |
75 | var rowsSet, colsSet, valsSet = false
76 |
77 | jValue match {
78 | case obj: JObject => for (kvPair <- obj.values) {
79 | kvPair._2 match {
80 | case x: BigInt =>
81 | if (kvPair._1 == "rows") {
82 | rows = x.toInt
83 | rowsSet = true
84 | }
85 | else if (kvPair._1 == "cols") {
86 | cols = x.toInt
87 | colsSet = true
88 | }
89 | else
90 | throw new IllegalArgumentException(s"Cannot recognize unexpected key ${kvPair._1}. (Value is BigInt)")
91 | case arr: List[_] =>
92 | if (arr.forall { case _: Double => true; case _ => false })
93 | if (kvPair._1 == "vals") {
94 | vals = arr.asInstanceOf[List[Double]].toArray
95 | valsSet = true
96 | }
97 | else
98 | throw new IllegalArgumentException(s"Cannot decode unexpected key ${kvPair._1} to Matrix.")
99 | else
100 | throw new IllegalArgumentException(s"Cannot decode unexpected key ${kvPair._1} with value: ${kvPair._2}.")
101 | case _ => throw new IllegalArgumentException(s"Cannot decode unexpected key ${kvPair._1} with value: ${kvPair._2}.")
102 | }
103 | }
104 | case _ =>
105 | throw new IllegalArgumentException(s"Cannot decode $jValue to Matrix.")
106 | }
107 |
108 | if (colsSet && rowsSet && valsSet)
109 | Matrices.dense(rows, cols, vals)
110 | else
111 | throw new IllegalArgumentException(s"Cannot decode $jValue. " +
112 | s"Missing values to create Matrix: colsSet: $colsSet; rowsSet: $rowsSet; valsSet: $valsSet")
113 | }
114 | }
115 |
116 | /**
117 | * Params for [[LRSelector]] and [[LRSelectorModel]].
118 | */
119 |
120 | private[embedded] trait LRSelectorParams extends Params {
121 | /**
122 | * Param for coefficientMatrix.
123 | *
124 | * @group param
125 | */
126 | final val coefficientMatrix: MatrixParam = new MatrixParam(this, "coefficientMatrix", "coefficientMatrix of LR model")
127 |
128 | /** @group getParam */
129 | final def getCoefficientMatrix: Matrix = $(coefficientMatrix)
130 |
131 | /**
132 | * Whether coefficients shall be scaled using the maximum absolute value of the corresponding feature.
133 | * Use this if no StandardScaler was applied prior to LR training.
134 | * @group param
135 | */
136 | @Since("2.1.1")
137 | final val scaleCoefficients = new BooleanParam(this, "scaleCoefficients",
138 | "Scale the coefficients using the maximum values of the corresponding feature.")
139 | setDefault(scaleCoefficients, false)
140 |
141 | /** @group getParam */
142 | @Since("2.1.1")
143 | def getScaleCoefficients: Boolean = $(scaleCoefficients)
144 | }
145 |
146 | /**
147 | * Feature selection based on LR weights (absolute value).
148 | * The selector can scale the coefficients using the corresponding maximum feature value. To activate, set `scaleCoefficients` to true.
149 | * Default: false
150 | */
151 | @Since("2.1.1")
152 | final class LRSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String)
153 | extends FeatureSelector[LRSelector, LRSelectorModel] with LRSelectorParams {
154 |
155 | @Since("2.1.1")
156 | def this() = this(Identifiable.randomUID("lrSelector"))
157 |
158 | /** @group setParam */
159 | @Since("2.1.1")
160 | def setCoefficientMatrix(value: Matrix): this.type = set(coefficientMatrix, value)
161 |
162 | /** @group setParam */
163 | @Since("2.1.1")
164 | def setScaleCoefficients(value: Boolean): this.type = set(scaleCoefficients, value)
165 |
166 | @Since("2.1.1")
167 | override protected def train(dataset: Dataset[_]): Array[(Int, Double)] = {
168 | val input = dataset.select($(featuresCol)).rdd.map { case Row(features: Vector) => MllibVectors.fromML(features) }
169 | val inputStats = Statistics.colStats(input)
170 | // Calculate maxValues = max(abs(min(feature)), abs(max(feature)))
171 | val absMaxValues = inputStats.max.toArray.map(elem => math.abs(elem))
172 | val absMinValues = inputStats.min.toArray.map(elem => math.abs(elem))
173 |
174 | val maxValues = if (!$(scaleCoefficients))
175 | Array.fill(absMinValues.length)(1.0)
176 | else
177 | absMaxValues.zip(absMinValues).map { case (max, min) => math.max(max, min) }
178 |
179 | // Calculate normalized and absolute sum of LR weights for each feature
180 | val coeffVectors = $(coefficientMatrix).toArray.grouped($(coefficientMatrix).numRows).toArray
181 |
182 | val coeffFw = coeffVectors.map(col => col.map(elem => math.abs(elem)).sum)
183 |
184 | val scaled = coeffFw.zip(maxValues).map { case (coeff, max) => coeff * max }
185 |
186 | val coeffSum = scaled.sum
187 |
188 | val toScale = if ($(selectorType) == FeatureSelector.Random) {
189 | val coeffMean = if (scaled.length == 0) 0.0 else coeffSum / scaled.length
190 | scaled :+ coeffMean
191 | } else
192 | scaled
193 |
194 | val featureImportances = if (coeffSum == 0) toScale else toScale.map(elem => elem / coeffSum)
195 |
196 | featureImportances.zipWithIndex.map { case (value, index) => (index, value) }
197 | }
198 |
199 | @Since("2.1.1")
200 | override def transformSchema(schema: StructType): StructType = {
201 | val otherPairs = FeatureSelector.supportedSelectorTypes.filter(_ != $(selectorType))
202 | otherPairs.foreach { paramName: String =>
203 | if (isSet(getParam(paramName))) {
204 | logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.")
205 | }
206 | }
207 |
208 | require(isDefined(coefficientMatrix))
209 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
210 | SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
211 | }
212 |
213 | @Since("2.1.1")
214 | override def copy(extra: ParamMap): LRSelector = defaultCopy(extra)
215 |
216 | @Since("2.1.1")
217 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): LRSelectorModel = {
218 | new LRSelectorModel(uid, selectedFeatures, featureImportances)
219 | }
220 | }
221 |
222 | object LRSelector extends DefaultParamsReadable[LRSelector] {
223 | @Since("2.1.1")
224 | override def load(path: String): LRSelector = super.load(path)
225 | }
226 |
227 | /**
228 | * Model fitted by [[LRSelector]].
229 | * @param uid of Model
230 | * @param selectedFeatures list of indices to select
231 | * @param featureImportances Map that stores each feature importance
232 | */
233 | @Since("2.1.1")
234 | final class LRSelectorModel private[selection] (@Since("2.1.1") override val uid: String,
235 | @Since("2.1.1") override val selectedFeatures: Array[Int],
236 | @Since("2.1.1") override val featureImportances: Map[String, Double])
237 | extends FeatureSelectorModel[LRSelectorModel](uid, selectedFeatures, featureImportances) with LRSelectorParams {
238 |
239 | @Since("2.1.1")
240 | override def copy(extra: ParamMap): LRSelectorModel = {
241 | val copied = new LRSelectorModel(uid, selectedFeatures, featureImportances)
242 | copyValues(copied, extra).setParent(parent)
243 | }
244 |
245 | @Since("2.1.1")
246 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter[LRSelectorModel](this)
247 | }
248 |
249 | @Since("2.1.1")
250 | object LRSelectorModel extends MLReadable[LRSelectorModel] {
251 | @Since("2.1.1")
252 | override def read: MLReader[LRSelectorModel] = new LRSelectorModelReader
253 |
254 | @Since("2.1.1")
255 | override def load(path: String): LRSelectorModel = super.load(path)
256 | }
257 |
258 | @Since("2.1.1")
259 | final class LRSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[LRSelectorModel] {
260 | @Since("2.1.1")
261 | override protected val className: String = classOf[LRSelectorModel].getName
262 |
263 | @Since("2.1.1")
264 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): LRSelectorModel = {
265 | new LRSelectorModel(uid, selectedFeatures, featureImportances)
266 | }
267 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/ml/feature/selection/filter/CorrelationSelector.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.filter
19 |
20 | import org.apache.spark.SparkException
21 | import org.apache.spark.annotation.Since
22 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel}
23 | import org.apache.spark.ml.linalg.{DenseVector, _}
24 | import org.apache.spark.ml.param._
25 | import org.apache.spark.ml.util._
26 | import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
27 | import org.apache.spark.mllib.stat.Statistics
28 | import org.apache.spark.sql._
29 | import org.apache.spark.sql.functions._
30 |
31 | import scala.collection.mutable
32 |
33 | /**
34 | * Feature selection based on Correlation (absolute value).
35 | */
36 | @Since("2.1.1")
37 | final class CorrelationSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String)
38 | extends FeatureSelector[CorrelationSelector, CorrelationSelectorModel] {
39 |
40 | /**
41 | * The correlation type of the CorrelationSelector.
42 | * Supported options: "spearman" (default), "pearson".
43 | *
44 | * @group param
45 | */
46 | @Since("2.1.1")
47 | val correlationType = new Param[String](this, "correlationType",
48 | "The correlation type used for correlation calculation. " +
49 | "Supported options: " + CorrelationSelector.supportedCorrelationTypes.mkString(", "),
50 | ParamValidators.inArray[String](CorrelationSelector.supportedCorrelationTypes))
51 | setDefault(correlationType -> CorrelationSelector.Spearman)
52 |
53 | /** @group getParam */
54 | @Since("2.1.1")
55 | def getCorrelationType: String = $(correlationType)
56 |
57 | /** @group setParam */
58 | @Since("2.1.1")
59 | def setCorrelationType(value: String): this.type = set(correlationType, value)
60 |
61 | @Since("2.1.1")
62 | def this() = this(Identifiable.randomUID("correlationSelector"))
63 |
64 | @Since("2.1.1")
65 | override protected def train(dataset: Dataset[_]): Array[(Int, Double)] = {
66 | val assembleFunc = udf { r: Row =>
67 | CorrelationSelector.assemble(r.toSeq: _*)
68 | }
69 |
70 | // Merge featuresCol and labelCol into one vector laid out as [featuresCol, labelCol]
71 | val tmpDf = dataset.select(assembleFunc(struct(dataset($(featuresCol)), dataset($(labelCol)))).as("input"))
72 | val input = tmpDf.select(col("input")).rdd.map { case Row(features: Vector) => OldVectors.fromML(features) }
73 |
74 | // Calculate correlation between all columns in input
75 | // We're only interested in the last row/column of the correlation matrix (which is symmetric),
76 | // i.e. the correlation between each feature column and the label column
77 | val correlations = Statistics.corr(input, $(correlationType))
78 |
79 | // Extract the information we're interested in
80 | val columns = correlations.toArray.grouped(correlations.numRows)
81 | val corrList = columns.map(column => new DenseVector(column)).toList
82 | val targetCorr = corrList.last.toArray
83 |
84 | // targetCorr.length - 1, because last element is correlation between label column and itself
85 | // Also take abs(), because important features are also negatively correlated
86 | val targetCorrs = Array.tabulate(targetCorr.length - 1) { i => (i, targetCorr(i)) }
87 | .map(elem => (elem._1, math.abs(elem._2)))
88 |
89 | targetCorrs.filter(!_._2.isNaN)
90 | }
91 |
92 | @Since("2.1.1")
93 | override def copy(extra: ParamMap): CorrelationSelector = defaultCopy(extra)
94 |
95 | @Since("2.1.1")
96 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): CorrelationSelectorModel = {
97 | new CorrelationSelectorModel(uid, selectedFeatures, featureImportances)
98 | }
99 | }
100 |
101 | /**
102 | * Model fitted by [[CorrelationSelector]].
103 | * @param selectedFeatures list of indices to select (filter)
104 | */
105 | @Since("2.1.1")
106 | final class CorrelationSelectorModel private[selection] (@Since("2.1.1") override val uid: String,
107 | @Since("2.1.1") override val selectedFeatures: Array[Int],
108 | @Since("2.1.1") override val featureImportances: Map[String, Double])
109 | extends FeatureSelectorModel[CorrelationSelectorModel](uid, selectedFeatures, featureImportances) {
110 | /**
111 | * The correlation type of the CorrelationSelector.
112 | * Supported options: "spearman" (default), "pearson".
113 | *
114 | * @group param
115 | */
116 | @Since("2.1.1")
117 | val correlationType = new Param[String](this, "correlationType",
118 | "The correlation type used for correlation calculation. " +
119 | "Supported options: " + CorrelationSelector.supportedCorrelationTypes.mkString(", "),
120 | ParamValidators.inArray[String](CorrelationSelector.supportedCorrelationTypes))
121 | setDefault(correlationType -> CorrelationSelector.Spearman)
122 |
123 | /** @group getParam */
124 | @Since("2.1.1")
125 | def getCorrelationType: String = $(correlationType)
126 |
127 | @Since("2.1.1")
128 | override def copy(extra: ParamMap): CorrelationSelectorModel = {
129 | val copied = new CorrelationSelectorModel(uid, selectedFeatures, featureImportances)
130 | copyValues(copied, extra).setParent(parent)
131 | }
132 |
133 | @Since("2.1.1")
134 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter[CorrelationSelectorModel](this)
135 | }
136 |
137 | @Since("2.1.1")
138 | object CorrelationSelectorModel extends MLReadable[CorrelationSelectorModel] {
139 | @Since("2.1.1")
140 | override def read: MLReader[CorrelationSelectorModel] = new CorrelationSelectorModelReader
141 |
142 | @Since("2.1.1")
143 | override def load(path: String): CorrelationSelectorModel = super.load(path)
144 | }
145 |
146 | @Since("2.1.1")
147 | final class CorrelationSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[CorrelationSelectorModel]{
148 | @Since("2.1.1")
149 | override protected val className: String = classOf[CorrelationSelectorModel].getName
150 |
151 | @Since("2.1.1")
152 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): CorrelationSelectorModel = {
153 | new CorrelationSelectorModel(uid, selectedFeatures, featureImportances)
154 | }
155 | }
156 |
157 | private[filter] object CorrelationSelector extends DefaultParamsReadable[CorrelationSelector] {
158 | @Since("2.1.1")
159 | override def load(path: String): CorrelationSelector = super.load(path)
160 |
161 | /**
162 | * String name for `pearson` correlation type.
163 | */
164 | val Pearson: String = "pearson"
165 |
166 | /**
167 | * String name for `spearman` correlation type.
168 | */
169 | val Spearman: String = "spearman"
170 |
171 | /** Set of correlation types that CorrelationSelector supports. */
172 | val supportedCorrelationTypes: Array[String] = Array(Pearson, Spearman)
173 |
174 | // From VectorAssembler
175 | private def assemble(vv: Any*): Vector = {
176 | val indices = mutable.ArrayBuilder.make[Int]
177 | val values = mutable.ArrayBuilder.make[Double]
178 | var cur = 0
179 | vv.foreach {
180 | case v: Double =>
181 | if (v != 0.0) {
182 | indices += cur
183 | values += v
184 | }
185 | cur += 1
186 | case vec: Vector =>
187 | vec.foreachActive { case (i, v) =>
188 | if (v != 0.0) {
189 | indices += cur + i
190 | values += v
191 | }
192 | }
193 | cur += vec.size
194 | case null =>
195 | throw new SparkException("Values to assemble cannot be null.")
196 | case o =>
197 | throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
198 | }
199 | Vectors.sparse(cur, indices.result(), values.result()).compressed
200 | }
201 | }
--------------------------------------------------------------------------------
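To make the train() computation above concrete, here is a standalone sketch of the same quantity on a toy RDD: the absolute Pearson correlation between each feature and the label, read off the last column of the correlation matrix computed over vectors laid out as [features..., label]. The toy data and feature layout are illustrative:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.sql.SparkSession

object CorrelationScoreSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("correlation-score-sketch").getOrCreate()

    // Each vector is [feature0, feature1, label]; feature1 is negatively correlated with the label.
    val input = spark.sparkContext.parallelize(Seq(
      Vectors.dense(1.0, 10.0, 0.0),
      Vectors.dense(2.0,  8.0, 0.0),
      Vectors.dense(3.0,  6.0, 1.0),
      Vectors.dense(4.0,  4.0, 1.0)
    ))

    val corr = Statistics.corr(input, "pearson")
    // Last column, excluding the label's correlation with itself; take abs() as in train().
    val scores = (0 until corr.numRows - 1).map(i => (i, math.abs(corr(i, corr.numCols - 1))))
    scores.foreach(println) // (featureIndex, |correlation with label|)

    spark.stop()
  }
}
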
/src/main/scala/org/apache/spark/ml/feature/selection/filter/GiniSelector.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.filter
19 |
20 | import org.apache.spark.annotation.Since
21 | import org.apache.spark.ml.feature.LabeledPoint
22 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel}
23 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, _}
24 | import org.apache.spark.ml.param._
25 | import org.apache.spark.ml.util._
26 | import org.apache.spark.rdd.RDD
27 | import org.apache.spark.sql._
28 | import org.apache.spark.sql.functions._
29 | import org.apache.spark.sql.types.DoubleType
30 |
31 | /**
32 | * Feature selection based on Gini index.
33 | */
34 | @Since("2.1.1")
35 | final class GiniSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String)
36 | extends FeatureSelector[GiniSelector, GiniSelectorModel] {
37 |
38 | @Since("2.1.1")
39 | def this() = this(Identifiable.randomUID("giniSelector"))
40 |
41 | @Since("2.1.1")
42 | override def train(dataset: Dataset[_]): Array[(Int, Double)] = {
43 | val input: RDD[LabeledPoint] =
44 | dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
45 | case Row(label: Double, features: Vector) =>
46 | LabeledPoint(label, features)
47 | }
48 |
49 | // Calculate gini indices of all features
50 | new GiniCalculator(input).calculateGini().collect()
51 | }
52 |
53 | @Since("2.1.1")
54 | override def copy(extra: ParamMap): GiniSelector = defaultCopy(extra)
55 |
56 | @Since("2.1.1")
57 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): GiniSelectorModel = {
58 | new GiniSelectorModel(uid, selectedFeatures, featureImportances)
59 | }
60 | }
61 |
62 | object GiniSelector extends DefaultParamsReadable[GiniSelector] {
63 | @Since("2.1.1")
64 | override def load(path: String): GiniSelector = super.load(path)
65 | }
66 |
67 | /**
68 | * Model fitted by [[GiniSelector]].
69 | * @param selectedFeatures list of indices to select (filter)
70 | */
71 | @Since("2.1.1")
72 | final class GiniSelectorModel private[selection] (@Since("2.1.1") override val uid: String,
73 | @Since("2.1.1") override val selectedFeatures: Array[Int],
74 | @Since("2.1.1") override val featureImportances: Map[String, Double])
75 | extends FeatureSelectorModel[GiniSelectorModel](uid, selectedFeatures, featureImportances) {
76 |
77 | @Since("2.1.1")
78 | override def copy(extra: ParamMap): GiniSelectorModel = {
79 | val copied = new GiniSelectorModel(uid, selectedFeatures, featureImportances)
80 | copyValues(copied, extra).setParent(parent)
81 | }
82 |
83 | @Since("2.1.1")
84 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter(this)
85 | }
86 |
87 | @Since("2.1.1")
88 | object GiniSelectorModel extends MLReadable[GiniSelectorModel] {
89 | @Since("2.1.1")
90 | override def read: MLReader[GiniSelectorModel] = new GiniSelectorModelReader
91 |
92 | @Since("2.1.1")
93 | override def load(path: String): GiniSelectorModel = super.load(path)
94 | }
95 |
96 | @Since("2.1.1")
97 | final class GiniSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[GiniSelectorModel]{
98 | @Since("2.1.1")
99 | override protected val className: String = classOf[GiniSelectorModel].getName
100 |
101 | @Since("2.1.1")
102 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): GiniSelectorModel = {
103 | new GiniSelectorModel(uid, selectedFeatures, featureImportances)
104 | }
105 | }
106 |
107 | private[filter] class GiniCalculator(val data: RDD[LabeledPoint]) {
108 | def calculateGini(): RDD[(Int, Double)] = {
109 | val labels2Int = data.map(_.label).distinct.collect.zipWithIndex.toMap
110 | val nLabels = labels2Int.size
111 |
112 | // Basic info. about the dataset
113 | val classDistrib = data.map(d => labels2Int(d.label)).countByValue().toMap
114 | val count = data.count
115 |
116 | // Generate pairs ((featureID, featureVal), (Hot encoded) targetVal)
117 | val featureValues =
118 | data.flatMap({
119 | case LabeledPoint(label, dv: DenseVector) =>
120 | val c = Array.fill[Long](nLabels)(0L)
121 | c(labels2Int(label)) = 1L
122 | for (i <- dv.values.indices) yield ((i, dv(i).toFloat), c)
123 | case LabeledPoint(label, sv: SparseVector) =>
124 | val c = Array.fill[Long](nLabels)(0L)
125 | c(labels2Int(label)) = 1L
126 | for (i <- sv.indices.indices) yield ((sv.indices(i), sv.values(i).toFloat), c)
127 | })
128 |
129 | // Calculate Gini Indices for all features
130 | // P(X_j) - prior probability of feature X having value X_j ([[featurePrior]])
131 | // P(Y_c | X_j) - cond. probability that a sample is of class Y_c, given that feature X has value X_j ([[condProbabClassGivenFeatureValue]])
132 | // P(Y_c) - prior probability that the label Y has value Y_c ([[classPrior]])
133 | // Gini(X) = Sum_j{P(X_j) * Sum_c{P(Y_c | X_j)²}} - Sum_c{P(Y_c)²}
134 | // Step 1: (featureNumber, List[featureValue, Array[classCount]])
135 | val featureClassDistrib = getSortedDistinctValues(classDistrib, featureValues)
136 | .map { case ((fk, fv), cd) => (fk, (fv, cd)) }
137 | .groupBy { case (fk, (_, _)) => fk }
138 | .map { case (fk, arr) =>
139 | (fk, arr.map { case (_, value) => value })
140 | }
141 |
142 | // Step 2: Calculate class priors and right sum
143 | val classPrior = classDistrib.map { case (k, v) => (k, v.toDouble / count) }
144 | val rightPart = classPrior.foldLeft(0.0) { case (agg, (_, v)) => agg + v * v }
145 |
146 | // Step 3: Calculate left sum
147 | val condProbabClassGivenFeatureValue = featureClassDistrib
148 | .map { case (fk, arr) => (fk, arr
149 | .map { case (fv, cd) => (fv, cd
150 | .map(cVal => cVal.toDouble / cd.sum))
151 | })
152 | }
153 | val condProbabClassGivenFeatureValueSum = condProbabClassGivenFeatureValue
154 | .map { case (fk, arr) => (fk, arr
155 | .map { case (fv, cc) => (fv, cc
156 | .map(c => c * c).sum)
157 | })
158 | }
159 | .map { case (fk, arr) => (fk, arr.toList
160 | .sortBy(_._1))
161 | }
162 | .sortBy(_._1)
163 |
164 | val featurePrior = featureClassDistrib
165 | .map { case (fk, arr) => (fk, arr
166 | .map { case (fv, cc) => (fv, cc.sum.toDouble / count) })
167 | }
168 | .map { case (fk, arr) => (fk, arr.toList.sortBy(_._1))
169 | }
170 | .sortBy(_._1)
171 |
172 | // Parallelize again, so zip won't fail
173 | val spark = condProbabClassGivenFeatureValueSum.sparkContext
174 | val tmp = spark.parallelize(condProbabClassGivenFeatureValueSum.collect())
175 | .zip(spark.parallelize(featurePrior.collect()))
176 | .map { case ((fk, arr), (fk2, arr2)) => if (fk == fk2) {
177 | (fk, arr.zip(arr2))
178 | } else {
179 | throw new IllegalStateException("Featurekeys don't match! This should never happen.")
180 | }
181 | }
182 |
183 | val leftPart = tmp.map { case (fk, calc) => (fk, calc.foldLeft(0.0) {
184 | case (agg, (x, y)) => if (x._1 == y._1) {
185 | agg + x._2 * y._2
186 | } else {
187 | throw new IllegalStateException("Featurevalues don't match! This should never happen.")
188 | }
189 | })
190 | }
191 |
192 | // Step 4: Calculate Gini indices and return
193 | leftPart.map { case (fk, value) => (fk, value - rightPart) }
194 | }
195 |
196 | /**
197 | * Group elements by feature and point (get distinct points).
198 | * Since values like (0, Float.NaN) are not considered unique when calling reduceByKey,
199 | * use the serialized version of the tuple.
200 | *
201 | * @return sorted list of unique feature values
202 | */
203 | private def getSortedDistinctValues(classDistrib: Map[Int, Long],
204 | featureValues: RDD[((Int, Float), Array[Long])]): RDD[((Int, Float), Array[Long])] = {
205 |
206 | val nonZeros: RDD[((Int, Float), Array[Long])] =
207 | featureValues.map(y => (y._1._1 + "," + y._1._2, y._2)).reduceByKey { case (v1, v2) =>
208 | (v1, v2).zipped.map(_ + _)
209 | }.map(y => {
210 | val s = y._1.split(",")
211 | ((s(0).toInt, s(1).toFloat), y._2)
212 | })
213 |
214 | val zeros = addZerosIfNeeded(nonZeros, classDistrib)
215 | val distinctValues = nonZeros.union(zeros)
216 |
217 | // Sort these values to perform the boundary points evaluation
218 | distinctValues.sortByKey()
219 | }
220 |
221 | /**
222 | * Add zeros if dealing with sparse data
223 | *
224 | * @return rdd with 0's filled in
225 | */
226 | private def addZerosIfNeeded(nonZeros: RDD[((Int, Float), Array[Long])],
227 | classDistrib: Map[Int, Long]): RDD[((Int, Float), Array[Long])] = {
228 | nonZeros.map { case ((k, _), v) => (k, v) }
229 | .reduceByKey { case (v1, v2) => (v1, v2).zipped.map(_ + _) }
230 | .map { case (k, v) =>
231 | val v2 = for (i <- v.indices) yield classDistrib(i) - v(i)
232 | ((k, 0.0F), v2.toArray)
233 | }.filter { case (_, v) => v.sum > 0 }
234 | }
235 |
236 | }
--------------------------------------------------------------------------------
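For reference, a minimal plain-Scala sketch of the per-feature score GiniCalculator produces, applying the formula from the comment above to a toy contingency table where counts(j)(c) is the number of samples with feature value j and class c. The toy counts are illustrative:

object GiniScoreSketch {
  // Gini(X) = Sum_j{P(X_j) * Sum_c{P(Y_c | X_j)²}} - Sum_c{P(Y_c)²}
  def giniScore(counts: Array[Array[Long]]): Double = {
    val total = counts.map(_.sum).sum.toDouble
    // "rightPart": Sum_c P(Y_c)^2, from the per-class totals
    val rightPart = counts.transpose.map(_.sum).map(c => math.pow(c / total, 2)).sum
    // "leftPart": Sum_j P(X_j) * Sum_c P(Y_c | X_j)^2
    val leftPart = counts.map { row =>
      val rowSum = row.sum.toDouble
      (rowSum / total) * row.map(c => math.pow(c / rowSum, 2)).sum
    }.sum
    leftPart - rightPart
  }

  def main(args: Array[String]): Unit = {
    // Two feature values x two classes: a feature that separates the classes fairly well.
    val counts = Array(Array(4L, 1L), Array(1L, 4L))
    println(giniScore(counts)) // 0.68 - 0.5 = 0.18
  }
}
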
/src/main/scala/org/apache/spark/ml/feature/selection/filter/InfoGainSelector.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.filter
19 |
20 | import org.apache.spark.annotation.Since
21 | import org.apache.spark.ml.feature.LabeledPoint
22 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel}
23 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, _}
24 | import org.apache.spark.ml.param._
25 | import org.apache.spark.ml.util._
26 | import org.apache.spark.rdd.RDD
27 | import org.apache.spark.sql._
28 | import org.apache.spark.sql.functions._
29 | import org.apache.spark.sql.types.DoubleType
30 |
31 | /**
32 | * Feature selection based on Information Gain.
33 | */
34 | @Since("2.1.1")
35 | final class InfoGainSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String)
36 | extends FeatureSelector[InfoGainSelector, InfoGainSelectorModel] {
37 |
38 | @Since("2.1.1")
39 | def this() = this(Identifiable.randomUID("igSelector"))
40 |
41 | @Since("2.1.1")
42 | override def train(dataset: Dataset[_]): Array[(Int, Double)] = {
43 | val input: RDD[LabeledPoint] =
44 | dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
45 | case Row(label: Double, features: Vector) =>
46 | LabeledPoint(label, features)
47 | }
48 |
49 | // Calculate gains of all features (features that are always zero will be dropped)
50 | new InfoGainCalculator(input).calculateIG().collect()
51 | }
52 |
53 | @Since("2.1.1")
54 | override def copy(extra: ParamMap): InfoGainSelector = defaultCopy(extra)
55 |
56 | @Since("2.1.1")
57 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): InfoGainSelectorModel = {
58 | new InfoGainSelectorModel(uid, selectedFeatures, featureImportances)
59 | }
60 | }
61 |
62 | object InfoGainSelector extends DefaultParamsReadable[InfoGainSelector] {
63 | @Since("2.1.1")
64 | override def load(path: String): InfoGainSelector = super.load(path)
65 | }
66 |
67 | /**
68 | * Model fitted by [[InfoGainSelector]].
69 | * @param selectedFeatures list of indices to select (filter)
70 | */
71 | @Since("2.1.1")
72 | final class InfoGainSelectorModel private[filter] (@Since("2.1.1") override val uid: String,
73 | @Since("2.1.1") override val selectedFeatures: Array[Int],
74 | @Since("2.1.1") override val featureImportances: Map[String, Double])
75 | extends FeatureSelectorModel[InfoGainSelectorModel](uid, selectedFeatures, featureImportances) {
76 |
77 | @Since("2.1.1")
78 | override def copy(extra: ParamMap): InfoGainSelectorModel = {
79 | val copied = new InfoGainSelectorModel(uid, selectedFeatures, featureImportances)
80 | copyValues(copied, extra).setParent(parent)
81 | }
82 |
83 | @Since("2.1.1")
84 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter[InfoGainSelectorModel](this)
85 | }
86 |
87 | @Since("2.1.1")
88 | object InfoGainSelectorModel extends MLReadable[InfoGainSelectorModel] {
89 | @Since("2.1.1")
90 | override def read: MLReader[InfoGainSelectorModel] = new InfoGainSelectorModelReader
91 |
92 | @Since("2.1.1")
93 | override def load(path: String): InfoGainSelectorModel = super.load(path)
94 | }
95 |
96 | @Since("2.1.1")
97 | final class InfoGainSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[InfoGainSelectorModel]{
98 | @Since("2.1.1")
99 | override protected val className: String = classOf[InfoGainSelectorModel].getName
100 |
101 | @Since("2.1.1")
102 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): InfoGainSelectorModel = {
103 | new InfoGainSelectorModel(uid, selectedFeatures, featureImportances)
104 | }
105 | }
106 |
107 | private[filter] class InfoGainCalculator(val data: RDD[LabeledPoint]) {
108 | def calculateIG(): RDD[(Int, Double)] = {
109 | val LOG2 = math.log(2)
110 |
111 | /** log base 2 of x
112 | * @return log base 2 of x */
113 | val log2 = { x: Double => math.log(x) / LOG2 }
114 | /** entropy of x
115 | * @return entropy of x */
116 | val entropy = { x: Double => if (x == 0) 0 else -x * log2(x) }
117 |
118 | val labels2Int = data.map(_.label).distinct.collect.zipWithIndex.toMap
119 | val nLabels = labels2Int.size
120 |
121 | // Basic info. about the dataset
122 | val classDistrib = data.map(d => labels2Int(d.label)).countByValue().toMap
123 |
124 | // Generate pairs ((featureID, featureVal), (Hot encoded) targetVal)
125 | val featureValues =
126 | data.flatMap({
127 | case LabeledPoint(label, dv: DenseVector) =>
128 | val c = Array.fill[Long](nLabels)(0L)
129 | c(labels2Int(label)) = 1L
130 | for (i <- dv.values.indices) yield ((i, dv(i).toFloat), c)
131 | case LabeledPoint(label, sv: SparseVector) =>
132 | val c = Array.fill[Long](nLabels)(0L)
133 | c(labels2Int(label)) = 1L
134 | for (i <- sv.indices.indices) yield ((sv.indices(i), sv.values(i).toFloat), c)
135 | })
136 |
137 | val sortedValues = getSortedDistinctValues(classDistrib, featureValues)
138 |
139 | val numSamples = classDistrib.values.sum
140 |
141 | // Calculate Probabilities
142 | val classDistribProb = classDistrib.map { case (k, v) => (k, v.toDouble / numSamples) }
143 | val featureProbs = sortedValues.map { case ((k, v), a) => ((k, v), a.sum.toDouble / numSamples) }
144 | val jointProbab = sortedValues.map { case ((k, v), a) => ((k, v), a.map(elem => elem.toDouble / numSamples)) }
145 |
146 | val jpTable = jointProbab.groupBy { case ((k, v), a) => k }
147 | val fpTable = featureProbs.groupBy { case ((k, v), a) => k }
148 |
149 | // Calculate entropies
150 | val featureEntropies = fpTable.map { case (k, v) => (k, v.map { case ((k1, v1), a) => entropy(a) }.sum) }.sortByKey(ascending = true)
151 | val jointEntropies = jpTable.map { case (k, v) => (k, v.map { case ((k1, v1), a) => a.map(v2 => entropy(v2)).sum }.sum) }.sortByKey(ascending = true)
152 | val targetEntropy = classDistribProb.foldLeft(0.0) { case (acc, (k, v)) => acc + entropy(v) }
153 |
154 | // Calculate information gain: targetEntropy + featureEntropy - jointEntropy
155 | // Format: RDD[(featureID->InformationGain)]
156 |
157 | val spark = featureEntropies.sparkContext
158 |
159 | spark.parallelize(featureEntropies.collect()).zip(spark.parallelize(jointEntropies.collect())).map { case ((k1, v1), (k2, v2)) => k1 -> (targetEntropy + v1 - v2) }
160 |
161 | // customZip(featureEntropies, jointEntropies).map { case ((k1, v1), (k2, v2)) => k1->(targetEntropy + v1 - v2)}
162 | }
163 |
164 | /**
165 | * Group elements by feature and point (get distinct points).
166 | * Since values like (0, Float.NaN) are not considered unique when calling reduceByKey,
167 | * use the serialized version of the tuple.
168 | *
169 | * @return sorted list of unique feature values
170 | */
171 | private def getSortedDistinctValues(classDistrib: Map[Int, Long],
172 | featureValues: RDD[((Int, Float), Array[Long])]): RDD[((Int, Float), Array[Long])] = {
173 |
174 | val nonZeros: RDD[((Int, Float), Array[Long])] =
175 | featureValues.map(y => (y._1._1 + "," + y._1._2, y._2)).reduceByKey { case (v1, v2) =>
176 | (v1, v2).zipped.map(_ + _)
177 | }.map(y => {
178 | val s = y._1.split(",")
179 | ((s(0).toInt, s(1).toFloat), y._2)
180 | })
181 |
182 | val zeros = addZerosIfNeeded(nonZeros, classDistrib)
183 | val distinctValues = nonZeros.union(zeros)
184 |
185 | // Sort these values to perform the boundary points evaluation
186 | distinctValues.sortByKey()
187 | }
188 |
189 | /**
190 | * Add zeros if dealing with sparse data
191 | * Features that do not have any non-zero value will not be added
192 | *
193 | * @return rdd with 0's filled in
194 | */
195 | private def addZerosIfNeeded(nonZeros: RDD[((Int, Float), Array[Long])],
196 | classDistrib: Map[Int, Long]): RDD[((Int, Float), Array[Long])] = {
197 | nonZeros.map { case ((k, p), v) => (k, v) }
198 | .reduceByKey { case (v1, v2) => (v1, v2).zipped.map(_ + _) }
199 | .map { case (k, v) =>
200 | val v2 = for (i <- v.indices) yield classDistrib(i) - v(i)
201 | ((k, 0.0F), v2.toArray)
202 | }.filter { case (_, v) => v.sum > 0 }
203 | }
204 | }
--------------------------------------------------------------------------------
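Analogously, a minimal plain-Scala sketch of the per-feature score InfoGainCalculator produces: IG(X; Y) = H(Y) + H(X) - H(X, Y) with base-2 entropies, matching the "targetEntropy + featureEntropy - jointEntropy" comment above. The toy contingency table is illustrative:

object InfoGainScoreSketch {
  private val LOG2 = math.log(2)
  private def entropyTerm(p: Double): Double = if (p == 0) 0.0 else -p * math.log(p) / LOG2

  // counts(j)(c) = number of samples with feature value j and class c
  def infoGain(counts: Array[Array[Long]]): Double = {
    val total = counts.map(_.sum).sum.toDouble
    val targetEntropy  = counts.transpose.map(_.sum / total).map(entropyTerm).sum // H(Y)
    val featureEntropy = counts.map(_.sum / total).map(entropyTerm).sum           // H(X)
    val jointEntropy   = counts.flatten.map(_ / total).map(entropyTerm).sum       // H(X, Y)
    targetEntropy + featureEntropy - jointEntropy
  }

  def main(args: Array[String]): Unit = {
    val counts = Array(Array(4L, 1L), Array(1L, 4L)) // same toy table as the Gini sketch
    println(infoGain(counts)) // ≈ 0.278 bits
  }
}
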
/src/main/scala/org/apache/spark/ml/feature/selection/util/VectorMerger.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.util
19 |
20 | import org.apache.spark.SparkException
21 | import org.apache.spark.annotation.Since
22 | import org.apache.spark.ml.Transformer
23 | import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
24 | import org.apache.spark.ml.feature.VectorSlicer
25 | import org.apache.spark.ml.linalg.{Vector, VectorUDT, Vectors}
26 | import org.apache.spark.ml.param.shared._
27 | import org.apache.spark.ml.param.{BooleanParam, ParamMap}
28 | import org.apache.spark.ml.util._
29 | import org.apache.spark.sql.functions._
30 | import org.apache.spark.sql.types._
31 | import org.apache.spark.sql.{DataFrame, Dataset, Row}
32 |
33 | import scala.collection.mutable
34 |
35 | /**
36 | * A feature transformer that merges multiple vector columns into a single vector column without duplicating features.
37 | * The Transformer has two modes, toggled by useFeaturesCol:
38 | * 1) useFeaturesCol true and featuresCol set: the output column contains those columns of featuresCol whose
39 | * names appear in one of the inputCols (all of type vector)
40 | * 2) useFeaturesCol false: the output column contains the columns from the inputCols, with duplicates dropped
41 | */
42 | @Since("2.1.1")
43 | class VectorMerger @Since("2.1.1") (@Since("2.1.1") override val uid: String)
44 | extends Transformer with HasFeaturesCol with HasInputCols with HasOutputCol with DefaultParamsWritable {
45 |
46 | @Since("2.1.1")
47 | def this() = this(Identifiable.randomUID("vecAssemblerMerger"))
48 |
49 | /** @group setParam */
50 | @Since("2.1.1")
51 | def setInputCols(value: Array[String]): this.type = set(inputCols, value)
52 |
53 | /** @group setParam */
54 | @Since("2.1.1")
55 | def setOutputCol(value: String): this.type = set(outputCol, value)
56 |
57 | /** @group setParam */
58 | @Since("2.1.1")
59 | def setFeatureCol(value: String): this.type = set(featuresCol, value)
60 |
61 | /** @group param */
62 | @Since("2.1.1")
63 | final val useFeaturesCol = new BooleanParam(this, "useFeaturesCol",
64 | "The output column will contain columns from featuresCol that have names appearing in one of the inputCols (type vector)")
65 | setDefault(useFeaturesCol, true)
66 |
67 | /** @group getParam */
68 | @Since("2.1.1")
69 | def getUseFeaturesCol: Boolean = $(useFeaturesCol)
70 |
71 | /** @group setParam */
72 | @Since("2.1.1")
73 | def setUseFeaturesCol(value: Boolean): this.type = set(useFeaturesCol, value)
74 |
75 | @Since("2.1.1")
76 | override def transform(dataset: Dataset[_]): DataFrame = {
77 | transformSchema(dataset.schema, logging = true)
78 |
79 | if ($(useFeaturesCol))
80 | transformUsingFeaturesColumn(dataset)
81 | else
82 | transformUsingInputColumns(dataset)
83 | }
84 |
85 | @Since("2.1.1")
86 | private def transformUsingInputColumns(dataset: Dataset[_]): DataFrame = {
87 | // Schema transformation.
88 | val schema = dataset.schema
89 | lazy val first = dataset.toDF.first()
90 |
91 | val uniqueNames = mutable.ArrayBuffer[String]()
92 |
93 | val indicesToKeep = mutable.ArrayBuilder.make[Int]
94 | var cur = 0
95 |
96 | def doNotAdd(): Option[Nothing] = {
97 | cur += 1
98 | None
99 | }
100 |
101 | def addAndIncrementIndex() {
102 | indicesToKeep += cur
103 | cur += 1
104 | }
105 |
106 | val attrs: Array[Attribute] = $(inputCols).flatMap { c =>
107 | val field = schema(c)
108 | val index = schema.fieldIndex(c)
109 | field.dataType match {
110 | case _: VectorUDT =>
111 | val group = AttributeGroup.fromStructField(field)
112 | if (group.attributes.isDefined) {
113 | // If attributes are defined, copy them, checking for duplicates and preserving name.
114 | group.attributes.get.zipWithIndex.flatMap { case (attr, i) =>
115 | if (attr.name.isDefined) {
116 | val name = attr.name.get
117 | if (!uniqueNames.contains(name)) {
118 | addAndIncrementIndex()
119 | uniqueNames.append(name)
120 | Some(attr)
121 | } else
122 | doNotAdd()
123 | } else {
124 | addAndIncrementIndex()
125 | Some(attr.withName(c + "_" + i))
126 | }
127 | }.toList
128 | } else {
129 | // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes
130 | // from metadata, check the first row.
131 | val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size)
132 | Array.tabulate(numAttrs)(i => {
133 | addAndIncrementIndex()
134 | NumericAttribute.defaultAttr.withName(c + "_" + i)
135 | })
136 | }
137 | case otherType =>
138 | throw new SparkException(s"VectorMerger does not support the $otherType type")
139 | }
140 | }
141 | val metadata = new AttributeGroup($(outputCol), attrs).toMetadata()
142 |
143 | // Data transformation.
144 | val assembleFunc = udf { r: Row =>
145 | VectorMerger.assemble(indicesToKeep.result(), r.toSeq: _*)
146 | }
147 |
148 | val args = $(inputCols).map { c =>
149 | schema(c).dataType match {
150 | case DoubleType => dataset(c)
151 | case _: VectorUDT => dataset(c)
152 | case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid")
153 | }
154 | }
155 |
156 | dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata))
157 | }
158 |
159 | @Since("2.1.1")
160 | private def transformUsingFeaturesColumn(dataset: Dataset[_]): DataFrame = {
161 | // Schema transformation.
162 | val schema = dataset.schema
163 |
164 | val notUniqueNames = mutable.ArrayBuffer[String]()
165 | val featuresColName = $(featuresCol)
166 | val featureColAttrs = AttributeGroup.fromStructField(schema(featuresColName)).attributes.get.zipWithIndex
167 | val featuresColNames = featureColAttrs.flatMap { case (attr, _) => attr.name }
168 |
169 | $(inputCols).foreach { c =>
170 | val field = schema(c)
171 | field.dataType match {
172 | case _: VectorUDT =>
173 | val group = AttributeGroup.fromStructField(field)
174 | if (group.attributes.isDefined) {
175 | // If attributes are defined, remember name to get column from $featureCol.
176 | group.attributes.get.zipWithIndex.foreach { case (attr, _) =>
177 | if (attr.name.isDefined) {
178 | val name = attr.name.get
179 | if (featuresColNames.contains(name))
180 | notUniqueNames.append(name)
181 | else
182 | throw new IllegalArgumentException(s"Features column $featuresColName does not contain column with name $name!")
183 | } else
184 | throw new IllegalArgumentException(s"Input column $c contains column without name attribute!")
185 | }
186 | } else {
187 | // Otherwise, merging not possible
188 | throw new IllegalArgumentException(s"Input column $c does not contain attributes!")
189 | }
190 | }
191 | }
192 |
193 | val uniqueNames = notUniqueNames.toSet
194 |
195 | new VectorSlicer()
196 | .setInputCol(featuresColName)
197 | .setNames(uniqueNames.toArray)
198 | .setOutputCol($(outputCol))
199 | .transform(dataset)
200 | }
201 |
202 | @Since("2.1.1")
203 | override def transformSchema(schema: StructType): StructType = {
204 | if ($(useFeaturesCol)) {
205 | val featuresColName = $(featuresCol)
206 |
207 | if (!schema(featuresColName).dataType.isInstanceOf[VectorUDT])
208 | throw new IllegalArgumentException(s"Features column $featuresColName is not of type VectorUDT!")
209 | }
210 |
211 | val inputColNames = $(inputCols)
212 | val outputColName = $(outputCol)
213 |
214 | inputColNames.foreach(name => if (!schema(name).dataType.isInstanceOf[VectorUDT])
215 | throw new IllegalArgumentException(s"Input column $name is not of type VectorUDT!")
216 | )
217 |
218 | if (schema.fieldNames.contains(outputColName)) {
219 | throw new IllegalArgumentException(s"Output column $outputColName already exists.")
220 | }
221 |
222 | StructType(schema.fields :+ StructField(outputColName, new VectorUDT, nullable = true))
223 | }
224 |
225 | @Since("2.1.1")
226 | override def copy(extra: ParamMap): VectorMerger = defaultCopy(extra)
227 | }
228 |
229 | @Since("2.1.1")
230 | object VectorMerger extends DefaultParamsReadable[VectorMerger] {
231 |
232 | @Since("2.1.1")
233 | override def load(path: String): VectorMerger = super.load(path)
234 |
235 | private[feature] def assemble(indicesToKeep: Array[Int], vv: Any*): Vector = {
236 | val indices = mutable.ArrayBuilder.make[Int]
237 | val values = mutable.ArrayBuilder.make[Double]
238 |
239 | var returnCur = 0
240 | var globalCur = 0
241 | vv.foreach {
242 | case v: Double =>
243 | if (indicesToKeep.contains(globalCur)) {
244 | if (v != 0.0) {
245 | indices += returnCur
246 | values += v
247 | }
248 | returnCur += 1
249 | }
250 | globalCur += 1
251 | case vec: Vector =>
252 | vec.toDense.foreachActive { case (_, v) =>
253 | if (indicesToKeep.contains(globalCur)) {
254 | if (v != 0.0) {
255 | indices += returnCur
256 | values += v
257 | }
258 | returnCur += 1
259 | }
260 | globalCur += 1
261 | }
262 | case null =>
263 | throw new SparkException("Values to assemble cannot be null.")
264 | case o =>
265 | throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.")
266 | }
267 | Vectors.sparse(returnCur, indices.result(), values.result()).compressed
268 | }
269 | }
270 |
--------------------------------------------------------------------------------
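A minimal sketch of the intended use of VectorMerger in mode 1 (useFeaturesCol = true): run two selectors on the same assembled features column, then merge their output columns back into one deduplicated vector sliced from featuresCol. It assumes the selectors' output columns carry the feature-name attributes (as the test suites below check); the relative path to iris.data assumes the project root as working directory, and the column names are illustrative:

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.selection.filter.{CorrelationSelector, GiniSelector}
import org.apache.spark.ml.feature.selection.util.VectorMerger
import org.apache.spark.sql.SparkSession

object VectorMergerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("vector-merger-sketch").getOrCreate()

    // Load the bundled iris data and assemble the raw columns, as in FeatureSelectionTestBase.
    val raw = spark.read.option("inferSchema", true).option("header", true).csv("src/test/resources/iris.data")
    val df = new VectorAssembler()
      .setInputCols(Array("sLength", "sWidth", "pLength", "pWidth"))
      .setOutputCol("features")
      .transform(raw)

    // Two selectors writing to different output columns.
    val byCorr = new CorrelationSelector().setFeaturesCol("features").setLabelCol("Species")
      .setOutputCol("corrFiltered").setSelectorType("numTopFeatures").setNumTopFeatures(2)
      .fit(df).transform(df)
    val byGini = new GiniSelector().setFeaturesCol("features").setLabelCol("Species")
      .setOutputCol("giniFiltered").setSelectorType("numTopFeatures").setNumTopFeatures(2)
      .fit(byCorr).transform(byCorr)

    // Merge both selections into a single column, keeping each original feature at most once.
    val merged = new VectorMerger()
      .setInputCols(Array("corrFiltered", "giniFiltered"))
      .setFeatureCol("features")
      .setUseFeaturesCol(true)
      .setOutputCol("mergedFeatures")
      .transform(byGini)

    merged.select("mergedFeatures").show(truncate = false)
    spark.stop()
  }
}
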
/src/test/resources/iris.data:
--------------------------------------------------------------------------------
1 | sLength,sWidth,pLength,pWidth,Species
2 | 5.1,3.5,1.4,0.2,0.0
3 | 4.9,3.0,1.4,0.2,0.0
4 | 4.7,3.2,1.3,0.2,0.0
5 | 4.6,3.1,1.5,0.2,0.0
6 | 5.0,3.6,1.4,0.2,0.0
7 | 5.4,3.9,1.7,0.4,0.0
8 | 4.6,3.4,1.4,0.3,0.0
9 | 5.0,3.4,1.5,0.2,0.0
10 | 4.4,2.9,1.4,0.2,0.0
11 | 4.9,3.1,1.5,0.1,0.0
12 | 5.4,3.7,1.5,0.2,0.0
13 | 4.8,3.4,1.6,0.2,0.0
14 | 4.8,3.0,1.4,0.1,0.0
15 | 4.3,3.0,1.1,0.1,0.0
16 | 5.8,4.0,1.2,0.2,0.0
17 | 5.7,4.4,1.5,0.4,0.0
18 | 5.4,3.9,1.3,0.4,0.0
19 | 5.1,3.5,1.4,0.3,0.0
20 | 5.7,3.8,1.7,0.3,0.0
21 | 5.1,3.8,1.5,0.3,0.0
22 | 5.4,3.4,1.7,0.2,0.0
23 | 5.1,3.7,1.5,0.4,0.0
24 | 4.6,3.6,1.0,0.2,0.0
25 | 5.1,3.3,1.7,0.5,0.0
26 | 4.8,3.4,1.9,0.2,0.0
27 | 5.0,3.0,1.6,0.2,0.0
28 | 5.0,3.4,1.6,0.4,0.0
29 | 5.2,3.5,1.5,0.2,0.0
30 | 5.2,3.4,1.4,0.2,0.0
31 | 4.7,3.2,1.6,0.2,0.0
32 | 4.8,3.1,1.6,0.2,0.0
33 | 5.4,3.4,1.5,0.4,0.0
34 | 5.2,4.1,1.5,0.1,0.0
35 | 5.5,4.2,1.4,0.2,0.0
36 | 4.9,3.1,1.5,0.1,0.0
37 | 5.0,3.2,1.2,0.2,0.0
38 | 5.5,3.5,1.3,0.2,0.0
39 | 4.9,3.1,1.5,0.1,0.0
40 | 4.4,3.0,1.3,0.2,0.0
41 | 5.1,3.4,1.5,0.2,0.0
42 | 5.0,3.5,1.3,0.3,0.0
43 | 4.5,2.3,1.3,0.3,0.0
44 | 4.4,3.2,1.3,0.2,0.0
45 | 5.0,3.5,1.6,0.6,0.0
46 | 5.1,3.8,1.9,0.4,0.0
47 | 4.8,3.0,1.4,0.3,0.0
48 | 5.1,3.8,1.6,0.2,0.0
49 | 4.6,3.2,1.4,0.2,0.0
50 | 5.3,3.7,1.5,0.2,0.0
51 | 5.0,3.3,1.4,0.2,0.0
52 | 7.0,3.2,4.7,1.4,1.0
53 | 6.4,3.2,4.5,1.5,1.0
54 | 6.9,3.1,4.9,1.5,1.0
55 | 5.5,2.3,4.0,1.3,1.0
56 | 6.5,2.8,4.6,1.5,1.0
57 | 5.7,2.8,4.5,1.3,1.0
58 | 6.3,3.3,4.7,1.6,1.0
59 | 4.9,2.4,3.3,1.0,1.0
60 | 6.6,2.9,4.6,1.3,1.0
61 | 5.2,2.7,3.9,1.4,1.0
62 | 5.0,2.0,3.5,1.0,1.0
63 | 5.9,3.0,4.2,1.5,1.0
64 | 6.0,2.2,4.0,1.0,1.0
65 | 6.1,2.9,4.7,1.4,1.0
66 | 5.6,2.9,3.6,1.3,1.0
67 | 6.7,3.1,4.4,1.4,1.0
68 | 5.6,3.0,4.5,1.5,1.0
69 | 5.8,2.7,4.1,1.0,1.0
70 | 6.2,2.2,4.5,1.5,1.0
71 | 5.6,2.5,3.9,1.1,1.0
72 | 5.9,3.2,4.8,1.8,1.0
73 | 6.1,2.8,4.0,1.3,1.0
74 | 6.3,2.5,4.9,1.5,1.0
75 | 6.1,2.8,4.7,1.2,1.0
76 | 6.4,2.9,4.3,1.3,1.0
77 | 6.6,3.0,4.4,1.4,1.0
78 | 6.8,2.8,4.8,1.4,1.0
79 | 6.7,3.0,5.0,1.7,1.0
80 | 6.0,2.9,4.5,1.5,1.0
81 | 5.7,2.6,3.5,1.0,1.0
82 | 5.5,2.4,3.8,1.1,1.0
83 | 5.5,2.4,3.7,1.0,1.0
84 | 5.8,2.7,3.9,1.2,1.0
85 | 6.0,2.7,5.1,1.6,1.0
86 | 5.4,3.0,4.5,1.5,1.0
87 | 6.0,3.4,4.5,1.6,1.0
88 | 6.7,3.1,4.7,1.5,1.0
89 | 6.3,2.3,4.4,1.3,1.0
90 | 5.6,3.0,4.1,1.3,1.0
91 | 5.5,2.5,4.0,1.3,1.0
92 | 5.5,2.6,4.4,1.2,1.0
93 | 6.1,3.0,4.6,1.4,1.0
94 | 5.8,2.6,4.0,1.2,1.0
95 | 5.0,2.3,3.3,1.0,1.0
96 | 5.6,2.7,4.2,1.3,1.0
97 | 5.7,3.0,4.2,1.2,1.0
98 | 5.7,2.9,4.2,1.3,1.0
99 | 6.2,2.9,4.3,1.3,1.0
100 | 5.1,2.5,3.0,1.1,1.0
101 | 5.7,2.8,4.1,1.3,1.0
102 | 6.3,3.3,6.0,2.5,2.0
103 | 5.8,2.7,5.1,1.9,2.0
104 | 7.1,3.0,5.9,2.1,2.0
105 | 6.3,2.9,5.6,1.8,2.0
106 | 6.5,3.0,5.8,2.2,2.0
107 | 7.6,3.0,6.6,2.1,2.0
108 | 4.9,2.5,4.5,1.7,2.0
109 | 7.3,2.9,6.3,1.8,2.0
110 | 6.7,2.5,5.8,1.8,2.0
111 | 7.2,3.6,6.1,2.5,2.0
112 | 6.5,3.2,5.1,2.0,2.0
113 | 6.4,2.7,5.3,1.9,2.0
114 | 6.8,3.0,5.5,2.1,2.0
115 | 5.7,2.5,5.0,2.0,2.0
116 | 5.8,2.8,5.1,2.4,2.0
117 | 6.4,3.2,5.3,2.3,2.0
118 | 6.5,3.0,5.5,1.8,2.0
119 | 7.7,3.8,6.7,2.2,2.0
120 | 7.7,2.6,6.9,2.3,2.0
121 | 6.0,2.2,5.0,1.5,2.0
122 | 6.9,3.2,5.7,2.3,2.0
123 | 5.6,2.8,4.9,2.0,2.0
124 | 7.7,2.8,6.7,2.0,2.0
125 | 6.3,2.7,4.9,1.8,2.0
126 | 6.7,3.3,5.7,2.1,2.0
127 | 7.2,3.2,6.0,1.8,2.0
128 | 6.2,2.8,4.8,1.8,2.0
129 | 6.1,3.0,4.9,1.8,2.0
130 | 6.4,2.8,5.6,2.1,2.0
131 | 7.2,3.0,5.8,1.6,2.0
132 | 7.4,2.8,6.1,1.9,2.0
133 | 7.9,3.8,6.4,2.0,2.0
134 | 6.4,2.8,5.6,2.2,2.0
135 | 6.3,2.8,5.1,1.5,2.0
136 | 6.1,2.6,5.6,1.4,2.0
137 | 7.7,3.0,6.1,2.3,2.0
138 | 6.3,3.4,5.6,2.4,2.0
139 | 6.4,3.1,5.5,1.8,2.0
140 | 6.0,3.0,4.8,1.8,2.0
141 | 6.9,3.1,5.4,2.1,2.0
142 | 6.7,3.1,5.6,2.4,2.0
143 | 6.9,3.1,5.1,2.3,2.0
144 | 5.8,2.7,5.1,1.9,2.0
145 | 6.8,3.2,5.9,2.3,2.0
146 | 6.7,3.3,5.7,2.5,2.0
147 | 6.7,3.0,5.2,2.3,2.0
148 | 6.3,2.5,5.0,1.9,2.0
149 | 6.5,3.0,5.2,2.0,2.0
150 | 6.2,3.4,5.4,2.3,2.0
151 | 5.9,3.0,5.1,1.8,2.0
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/FeatureSelectionTestBase.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection
19 |
20 | import org.apache.spark.ml.attribute.AttributeGroup
21 | import org.apache.spark.ml.feature.VectorAssembler
22 | import org.apache.spark.ml.feature.selection.test_util._
23 | import org.apache.spark.ml.linalg.Vector
24 | import org.apache.spark.sql.{Dataset, Row, SparkSession}
25 | import org.scalactic.TolerantNumerics
26 | import org.scalatest._
27 |
28 | abstract class FeatureSelectionTestBase extends FunSuite with SharedSparkSession with BeforeAndAfter with DefaultReadWriteTest
29 |
30 | trait SharedSparkSession extends BeforeAndAfterAll with BeforeAndAfterEach {
31 | self: Suite =>
32 |
33 | @transient private var _sc: SparkSession = _
34 | @transient var dataset: Dataset[_] = _
35 |
36 | private val testPath = getClass.getResource("/iris.data").getPath
37 | protected val featuresColName = "features"
38 | protected val labelColName = "Species"
39 |
40 | def sc: SparkSession = _sc
41 |
42 | override def beforeAll() {
43 | super.beforeAll()
44 | _sc = SparkSession
45 | .builder()
46 | .master("local[*]")
47 | .appName("spark test base")
48 | .getOrCreate()
49 | _sc.sparkContext.setLogLevel("ERROR")
50 |
51 | val df = sc.read.option("inferSchema", true).option("header", true).csv(testPath)
52 | dataset = new VectorAssembler()
53 | .setInputCols(Array("sLength", "sWidth", "pLength", "pWidth"))
54 | .setOutputCol(featuresColName)
55 | .transform(df)
56 | }
57 |
58 | override def afterAll() {
59 | try {
60 | if (_sc != null) {
61 | _sc.stop()
62 | }
63 | _sc = null
64 | } finally {
65 | super.afterAll()
66 | }
67 | }
68 | }
69 |
70 | object FeatureSelectorTestBase {
71 | private val epsilon = 1e-4f
72 | private implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(epsilon)
73 |
74 | def tolerantVectorEquality(a: Vector, b: Vector): Boolean = {
75 | a.toArray.zip(b.toArray).map {
76 | case (a: Double, b: Double) => doubleEq.areEqual(a, b)
77 | case (a: Any, b: Any) => a == b
78 | }.forall(value => value)
79 | }
80 |
81 | def testSelector[
82 | Learner <: FeatureSelector[Learner, M],
83 | M <: FeatureSelectorModel[M]](selector: FeatureSelector[Learner, M],
84 | dataset: Dataset[_],
85 | importantColNames: Array[String],
86 | groundTruthColname: String): Unit = {
87 | val selectorModel = selector.fit(dataset)
88 | val transformed = selectorModel.transform(dataset)
89 |
90 | val inputCols = AttributeGroup.fromStructField(transformed.schema(selector.getFeaturesCol))
91 | .attributes.get.map(attr => attr.name.get)
92 |
93 | assert(selectorModel.featureImportances.size == inputCols.length,
94 | "Length of featureImportances array is not equal to number of input columns!")
95 |
96 | val selectedColNames = AttributeGroup.fromStructField(transformed.schema(selector.getOutputCol))
97 | .attributes.get.map(attr => attr.name.get)
98 |
99 | val importantColsSelected = importantColNames.sorted.zip(selectedColNames.sorted).map(elem => elem._1 == elem._2).forall(elem => elem)
100 |
101 | assert(importantColsSelected, "Selected and important column names do not match!")
102 |
103 | transformed.select(selectorModel.getOutputCol, groundTruthColname).collect()
104 | .foreach { case Row(vec1: Vector, vec2: Vector) =>
105 | assert(tolerantVectorEquality(vec1, vec2))
106 | }
107 | }
108 |
109 | def checkModelData[M <: FeatureSelectorModel[M]](model1: FeatureSelectorModel[M], model2: FeatureSelectorModel[M]): Unit = {
110 |
111 | assert(model1.selectedFeatures sameElements model2.selectedFeatures
112 | , "Persisted model has different selectedFeatures.")
113 | assert(model1.featureImportances.toArray.sortBy(elem => elem._1) sameElements model2.featureImportances.toArray.sortBy(elem => elem._1),
114 | "Persisted model has different featureImportances.")
115 | }
116 |
117 | /**
118 | * Mapping from all Params to valid settings which differ from the defaults.
119 | * This is useful for tests which need to exercise all Params, such as save/load.
120 | * This excludes input columns to simplify some tests.
121 | */
122 | val allParamSettings: Map[String, Any] = Map(
123 | "selectorType" -> "percentile",
124 | "numTopFeatures" -> 1,
125 | "percentile" -> 0.12,
126 | "randomCutOff" -> 0.1,
127 | "featuresCol" -> "features",
128 | "labelCol" -> "Species",
129 | "outputCol" -> "myOutput"
130 | )
131 | }
132 |
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/embedded/ImportanceSelectorSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.embedded
19 |
20 | import org.apache.spark.ml.feature.VectorAssembler
21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase}
22 | import org.apache.spark.ml.linalg.Vectors
23 |
24 | class ImportanceSelectorSuite extends FeatureSelectionTestBase {
25 | // Order of feature importances must be: f4 > f3 > f2 > f1
26 | private val featureWeights = Vectors.dense(Array(0.3, 0.5, 0.7, 0.8))
27 |
28 | test("Test ImportanceSelector: numTopFeatures") {
29 | val selector = new ImportanceSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
30 | .setFeatureWeights(featureWeights)
31 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(2)
32 |
33 | val importantColNames = Array("pWidth", "pLength")
34 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
35 |
36 | FeatureSelectorTestBase.testSelector[ImportanceSelector, ImportanceSelectorModel](selector, df, importantColNames, "ImportantFeatures")
37 | }
38 |
39 | test("Test ImportanceSelector: percentile") {
40 | val selector = new ImportanceSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
41 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51).setFeatureWeights(featureWeights)
42 |
43 | val importantColNames = Array("pWidth", "pLength")
44 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
45 |
46 | FeatureSelectorTestBase.testSelector[ImportanceSelector, ImportanceSelectorModel](selector, df, importantColNames, "ImportantFeatures")
47 | }
48 |
49 | test("Test ImportanceSelector: randomCutOff") {
50 | val selector = new ImportanceSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
51 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0).setFeatureWeights(featureWeights)
52 |
53 | val importantColNames = Array("pWidth", "pLength", "sWidth", "sLength")
54 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
55 |
56 | FeatureSelectorTestBase.testSelector[ImportanceSelector, ImportanceSelectorModel](selector, df, importantColNames, "ImportantFeatures")
57 | }
58 |
59 | test("ImportanceSelector read/write") {
60 | val nb = new ImportanceSelector
61 | testEstimatorAndModelReadWrite[ImportanceSelector, ImportanceSelectorModel](nb, dataset,
62 | FeatureSelectorTestBase.allParamSettings.+("featureWeights" -> featureWeights), FeatureSelectorTestBase.checkModelData)
63 | }
64 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/embedded/LRSelectorSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.embedded
19 |
20 | import org.apache.spark.ml.feature.VectorAssembler
21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase}
22 | import org.apache.spark.ml.linalg.Matrices
23 |
24 | class LRSelectorSuite extends FeatureSelectionTestBase {
25 | // Order of feature importances must be: f4 > f3 > f2 > f1
26 | private val lrWeights = Matrices.dense(3, 4, Array(0.1, 0.1, 0.1, 0.2, 0.2, 0.2, -0.8, -0.8, -0.8, 0.9, 0.9, 0.9))
27 |
28 | test("Test LRSelector: numTopFeatures") {
29 | val selector = new LRSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCoefficientMatrix(lrWeights)
30 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(2)
31 |
32 | val importantColNames = Array("pWidth", "pLength")
33 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
34 |
35 | FeatureSelectorTestBase.testSelector[LRSelector, LRSelectorModel](selector, df, importantColNames, "ImportantFeatures")
36 | }
37 |
38 | test("Test LRSelector: percentile") {
39 | val selector = new LRSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
40 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51).setCoefficientMatrix(lrWeights)
41 |
42 | val importantColNames = Array("pWidth", "pLength")
43 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
44 |
45 | FeatureSelectorTestBase.testSelector[LRSelector, LRSelectorModel](selector, df, importantColNames, "ImportantFeatures")
46 | }
47 |
48 | test("Test LRSelector: randomCutOff") {
49 | val selector = new LRSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
50 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0).setCoefficientMatrix(lrWeights)
51 |
52 | val importantColNames = Array("pWidth", "pLength", "sWidth", "sLength")
53 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
54 |
55 | FeatureSelectorTestBase.testSelector[LRSelector, LRSelectorModel](selector, df, importantColNames, "ImportantFeatures")
56 | }
57 |
58 | test("LRSelector read/write") {
59 | val nb = new LRSelector
60 | testEstimatorAndModelReadWrite[LRSelector, LRSelectorModel](nb, dataset,
61 | FeatureSelectorTestBase.allParamSettings.+("coefficientMatrix" -> lrWeights), FeatureSelectorTestBase.checkModelData)
62 | }
63 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/filter/CorrelationSelectorSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.filter
19 |
20 | import org.apache.spark.ml.feature.VectorAssembler
21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase}
22 |
23 | /* To verify the results with R, run:
24 | library(dplyr)
25 | data(iris)
26 | df <- iris %>%
27 | dplyr::mutate(label = ifelse(Species == "setosa", 0.0, ifelse(Species == "versicolor", 1.0, 2.0))) %>%
28 | dplyr::select("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "label")
29 | print(cor(df, method = "pearson"))
30 | print(cor(df, method = "spearman"))
31 | */
32 |
33 | class CorrelationSelectorSuite extends FeatureSelectionTestBase {
34 | test("Test CorrelationSelector - pearson: numTopFeatures") {
35 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("pearson")
36 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(3)
37 |
38 | val importantColNames = Array("pWidth", "pLength", "sLength")
39 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
40 |
41 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures")
42 | }
43 |
44 | test("Test CorrelationSelector - pearson: percentile") {
45 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("pearson")
46 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51)
47 |
48 | val importantColNames = Array("pWidth", "pLength", "sLength")
49 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
50 |
51 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures")
52 | }
53 |
54 | test("Test CorrelationSelector - pearson: randomCutOff") {
55 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("pearson")
56 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0)
57 |
58 | val importantColNames = Array("pWidth", "pLength", "sLength", "sWidth")
59 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
60 |
61 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures")
62 | }
63 |
64 | test("Test CorrelationSelector - spearman: numTopFeatures") {
65 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("spearman")
66 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(3)
67 |
68 | val importantColNames = Array("pWidth", "pLength", "sLength")
69 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
70 |
71 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures")
72 | }
73 |
74 | test("Test CorrelationSelector - spearman: percentile") {
75 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("spearman")
76 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51)
77 |
78 | val importantColNames = Array("pWidth", "pLength", "sLength")
79 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
80 |
81 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures")
82 | }
83 |
84 | test("Test CorrelationSelector - spearman: randomCutOff") {
85 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("spearman")
86 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0)
87 |
88 | val importantColNames = Array("pWidth", "pLength", "sLength", "sWidth")
89 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
90 |
91 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures")
92 | }
93 |
94 | test("CorrelationSelector read/write") {
95 | val nb = new CorrelationSelector
96 | testEstimatorAndModelReadWrite[CorrelationSelector, CorrelationSelectorModel](nb, dataset,
97 | FeatureSelectorTestBase.allParamSettings.+("correlationType" -> "pearson"),
98 | FeatureSelectorTestBase.checkModelData)
99 | }
100 | }
101 |
102 |
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/filter/GiniSelectorSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.filter
19 |
20 | import org.apache.spark.ml.feature.VectorAssembler
21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase}
22 |
23 | /* To verify the results with R, run:
24 | library(CORElearn)
25 | data(iris)
26 | weights <- attrEval(Species ~ ., data=iris, estimator = "Gini")
27 | print(weights)
28 | */
29 | class GiniSelectorSuite extends FeatureSelectionTestBase {
30 | test("Test GiniSelector: numTopFeatures") {
31 | val selector = new GiniSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
32 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(2)
33 |
34 | val importantColNames = Array("pLength", "pWidth")
35 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
36 |
37 | FeatureSelectorTestBase.testSelector[GiniSelector, GiniSelectorModel](selector, df, importantColNames, "ImportantFeatures")
38 | }
39 |
40 | test("Test GiniSelector: percentile") {
41 | val selector = new GiniSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
42 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51)
43 |
44 | val importantColNames = Array("pLength", "pWidth")
45 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
46 |
47 | FeatureSelectorTestBase.testSelector[GiniSelector, GiniSelectorModel](selector, df, importantColNames, "ImportantFeatures")
48 | }
49 |
50 | test("Test GiniSelector: randomCutOff") {
51 | val selector = new GiniSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
52 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0)
53 |
54 | val importantColNames = Array("pLength", "pWidth", "sLength", "sWidth")
55 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
56 |
57 | FeatureSelectorTestBase.testSelector[GiniSelector, GiniSelectorModel](selector, df, importantColNames, "ImportantFeatures")
58 | }
59 |
60 | test("GiniSelector read/write") {
61 | val nb = new GiniSelector
62 | testEstimatorAndModelReadWrite[GiniSelector, GiniSelectorModel](nb, dataset, FeatureSelectorTestBase.allParamSettings, FeatureSelectorTestBase.checkModelData)
63 | }
64 | }
65 |
66 |
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/filter/InfoGainSelectorSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.filter
19 |
20 | import org.apache.spark.ml.feature.VectorAssembler
21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase}
22 |
23 | /* To verify the results with R, run:
24 | library(RWeka)
25 | data(iris)
26 | weights <- InfoGainAttributeEval(Species ~ ., data = iris)
27 | print(weights)
28 | */
29 | class InfoGainSelectorSuite extends FeatureSelectionTestBase {
30 | test("Test InfoGainSelector: numTopFeatures") {
31 | val selector = new InfoGainSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
32 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(2)
33 |
34 | val importantColNames = Array("pLength", "pWidth")
35 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
36 |
37 | FeatureSelectorTestBase.testSelector[InfoGainSelector, InfoGainSelectorModel](selector, df, importantColNames, "ImportantFeatures")
38 | }
39 |
40 | test("Test InfoGainSelector: percentile") {
41 | val selector = new InfoGainSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
42 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51)
43 |
44 | val importantColNames = Array("pLength", "pWidth")
45 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
46 |
47 | FeatureSelectorTestBase.testSelector[InfoGainSelector, InfoGainSelectorModel](selector, df, importantColNames, "ImportantFeatures")
48 | }
49 |
50 | test("Test InfoGainSelector: randomCutOff") {
51 | val selector = new InfoGainSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName)
52 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0)
53 |
54 | val importantColNames = Array("pLength", "pWidth", "sLength", "sWidth")
55 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset)
56 |
57 | FeatureSelectorTestBase.testSelector[InfoGainSelector, InfoGainSelectorModel](selector, df, importantColNames, "ImportantFeatures")
58 | }
59 |
60 | test("InfoGainSelector read/write") {
61 | val nb = new InfoGainSelector
62 | testEstimatorAndModelReadWrite[InfoGainSelector, InfoGainSelectorModel](nb, dataset, FeatureSelectorTestBase.allParamSettings, FeatureSelectorTestBase.checkModelData)
63 | }
64 | }
65 |
66 |
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/test_util/DefaultReadWriteTest.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.test_util
19 |
20 | import java.io.{File, IOException}
21 |
22 | import org.apache.spark.ml.feature.selection.FeatureSelectionTestBase
23 | import org.apache.spark.ml.param._
24 | import org.apache.spark.ml.util._
25 | import org.apache.spark.ml.{Estimator, Model}
26 | import org.apache.spark.sql.Dataset
27 | import org.scalatest.Suite
28 |
29 | trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
30 |
31 | /**
32 | * Checks "overwrite" option and params.
33 | * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name
34 | * in order to avoid conflicts from multiple calls to this method.
35 | *
36 | * @param instance ML instance to test saving/loading
37 | * @param testParams If true, then test values of Params. Otherwise, just test overwrite option.
38 | * @tparam T ML instance type
39 | * @return Instance loaded from file
40 | */
41 | def testDefaultReadWrite[T <: Params with MLWritable](
42 | instance: T,
43 | testParams: Boolean = true): T = {
44 | val uid = instance.uid
45 | val subdirName = Identifiable.randomUID("test")
46 |
47 | val subdir = new File(tempDir, subdirName)
48 | val path = new File(subdir, uid).getPath
49 |
50 | instance.save(path)
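// Saving again to the same path without overwrite() should fail.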
51 | intercept[IOException] {
52 | instance.save(path)
53 | }
54 | instance.write.overwrite().save(path)
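// Fetch the companion object's read method via reflection and use it to load the saved instance back.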
55 | val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
56 | val newInstance = loader.load(path)
57 |
58 | assert(newInstance.uid === instance.uid)
59 | if (testParams) {
60 | instance.params.foreach { p =>
61 | if (instance.isDefined(p)) {
62 | (instance.getOrDefault(p), newInstance.getOrDefault(p)) match {
63 | case (Array(values), Array(newValues)) =>
64 | assert(values === newValues, s"Values do not match on param ${p.name}.")
65 | case (value, newValue) =>
66 | assert(value === newValue, s"Values do not match on param ${p.name}.")
67 | }
68 | } else {
69 | assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.")
70 | }
71 | }
72 | }
73 |
74 | val load = instance.getClass.getMethod("load", classOf[String])
75 | val another = load.invoke(instance, path).asInstanceOf[T]
76 | assert(another.uid === instance.uid)
77 | another
78 | }
79 |
80 | /**
81 | * Default test for Estimator, Model pairs:
82 | * - Explicitly set Params, and train model
83 | * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model
84 | * - Check Params on Estimator and Model
85 | * - Compare model data
86 | *
87 | * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s.
88 | *
89 | * @param estimator Estimator to test
90 | * @param dataset Dataset to pass to [[Estimator.fit()]]
91 | * @param testParams Set of [[Param]] values to set in estimator
92 | * @param checkModelData Method which takes the original and loaded [[Model]] and compares their
93 | * data. This method does not need to check [[Param]] values.
94 | * @tparam E Type of [[Estimator]]
95 | * @tparam M Type of [[Model]] produced by estimator
96 | */
97 | def testEstimatorAndModelReadWrite[
98 | E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
99 | estimator: E,
100 | dataset: Dataset[_],
101 | testParams: Map[String, Any],
102 | checkModelData: (M, M) => Unit): Unit = {
103 |     // Explicitly set some Params to verify that set Params survive save/load.
104 | testParams.foreach { case (p, v) =>
105 | estimator.set(estimator.getParam(p), v)
106 | }
107 | val model = estimator.fit(dataset)
108 |
109 | // Test Estimator save/load
110 | val estimator2 = testDefaultReadWrite(estimator)
111 | testParams.foreach { case (p, v) =>
112 | val param = estimator.getParam(p)
113 | assert(estimator.get(param).get === estimator2.get(param).get)
114 | }
115 |
116 | // Test Model save/load
117 | val model2 = testDefaultReadWrite(model)
118 | testParams.foreach { case (p, v) =>
119 | val param = model.getParam(p)
120 | assert(model.get(param).get === model2.get(param).get)
121 | }
122 |
123 | checkModelData(model, model2)
124 | }
125 | }
126 |
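// Test-only Params implementation that sets one value of each common Param type; used by the "default read/write" test below.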
127 | class MyParams(override val uid: String) extends Params with MLWritable {
128 |
129 | final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc")
130 | final val intParam: IntParam = new IntParam(this, "intParam", "doc")
131 | final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc")
132 | final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc")
133 | final val longParam: LongParam = new LongParam(this, "longParam", "doc")
134 | final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc")
135 | final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc")
136 | final val doubleArrayParam: DoubleArrayParam =
137 | new DoubleArrayParam(this, "doubleArrayParam", "doc")
138 | final val stringArrayParam: StringArrayParam =
139 | new StringArrayParam(this, "stringArrayParam", "doc")
140 |
141 | setDefault(intParamWithDefault -> 0)
142 | set(intParam -> 1)
143 | set(floatParam -> 2.0f)
144 | set(doubleParam -> 3.0)
145 | set(longParam -> 4L)
146 | set(stringParam -> "5")
147 | set(intArrayParam -> Array(6, 7))
148 | set(doubleArrayParam -> Array(8.0, 9.0))
149 | set(stringArrayParam -> Array("10", "11"))
150 |
151 | override def copy(extra: ParamMap): Params = defaultCopy(extra)
152 |
153 | override def write: MLWriter = new DefaultParamsWriter(this)
154 | }
155 |
156 | object MyParams extends MLReadable[MyParams] {
157 |
158 | override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams]
159 |
160 | override def load(path: String): MyParams = super.load(path)
161 | }
162 |
163 | class DefaultReadWriteSuite extends FeatureSelectionTestBase
164 | with DefaultReadWriteTest {
165 |
166 | test("default read/write") {
167 | val myParams = new MyParams("my_params")
168 | testDefaultReadWrite(myParams)
169 | }
170 | }
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/test_util/TempDirectory.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.test_util
19 |
20 | import java.io.File
21 |
22 | import org.apache.spark.util.Utils
23 | import org.scalatest.{BeforeAndAfterAll, Suite}
24 |
25 | /**
26 | * Trait that creates a temporary directory before all tests and deletes it after all.
27 | */
28 | trait TempDirectory extends BeforeAndAfterAll { self: Suite =>
29 |
30 | private var _tempDir: File = _
31 |
32 | /** Returns the temporary directory as a [[File]] instance. */
33 | protected def tempDir: File = _tempDir
34 |
35 | override def beforeAll(): Unit = {
36 | super.beforeAll()
37 | _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName)
38 | }
39 |
40 | override def afterAll(): Unit = {
41 | try {
42 | Utils.deleteRecursively(_tempDir)
43 | } finally {
44 | super.afterAll()
45 | }
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/src/test/scala/org/apache/spark/ml/feature/selection/util/VectorMergerSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
3 | * contributor license agreements. See the NOTICE file distributed with
4 | * this work for additional information regarding copyright ownership.
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
6 | * (the "License"); you may not use this file except in compliance with
7 | * the License. You may obtain a copy of the License at
8 | *
9 | * http://www.apache.org/licenses/LICENSE-2.0
10 | *
11 | * Unless required by applicable law or agreed to in writing, software
12 | * distributed under the License is distributed on an "AS IS" BASIS,
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and
15 | * limitations under the License.
16 | */
17 |
18 | package org.apache.spark.ml.feature.selection.util
19 |
20 | import org.apache.spark.ml.attribute.AttributeGroup
21 | import org.apache.spark.ml.feature.VectorAssembler
22 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase}
23 | import org.apache.spark.ml.linalg.Vector
24 | import org.apache.spark.sql.Row
25 |
26 | class VectorMergerSuite extends FeatureSelectionTestBase {
27 |
28 | test("VectorMerger: merges two VectorColumns with different names") {
29 | val dfTmp = new VectorAssembler().setInputCols(Array("pWidth", "pLength")).setOutputCol("vector1").transform(dataset)
30 | val dfTmp2 = new VectorAssembler().setInputCols(Array("sWidth", "sLength")).setOutputCol("vector2").transform(dfTmp)
31 | val df = new VectorAssembler().setInputCols(Array("pWidth", "pLength", "sWidth", "sLength")).setOutputCol("expected").transform(dfTmp2)
32 |
33 | val dfT = new VectorMerger().setInputCols(Array("vector1", "vector2")).setOutputCol("merged").transform(df)
34 |
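// Recover the merged column's feature names from its ML attribute metadata.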
35 | val outCols = AttributeGroup.fromStructField(dfT.schema("merged")).attributes.get.map(attr => attr.name.get)
36 |
37 | assert(outCols.length == 4, "Length of merged column is not equal to 4!")
38 |
39 | assert(outCols.sorted sameElements Array("pWidth", "pLength", "sWidth", "sLength").sorted,
40 | "Input and output column names do not match!")
41 |
42 | dfT.select("merged", "expected").collect()
43 | .foreach { case Row(vec1: Vector, vec2: Vector) =>
44 |         assert(FeatureSelectorTestBase.tolerantVectorEquality(vec1, vec2), "Vectors in the 'merged' and 'expected' columns do not match!")
45 | }
46 | }
47 |
48 | test("VectorMerger: merges two VectorColumns with duplicate names") {
49 | val dfTmp = new VectorAssembler().setInputCols(Array("pWidth", "pLength")).setOutputCol("vector1").transform(dataset)
50 | val dfTmp2 = new VectorAssembler().setInputCols(Array("pLength", "sLength")).setOutputCol("vector2").transform(dfTmp)
51 | val df = new VectorAssembler().setInputCols(Array("pWidth", "pLength", "sLength")).setOutputCol("expected").transform(dfTmp2)
52 |
53 | val dfT = new VectorMerger().setInputCols(Array("vector1", "vector2")).setOutputCol("merged").transform(df)
54 |
55 | val outCols = AttributeGroup.fromStructField(dfT.schema("merged")).attributes.get.map(attr => attr.name.get)
56 |
57 | assert(outCols.length == 3, "Length of merged column is not equal to 3!")
58 |
59 | assert(outCols.sorted sameElements Array("pWidth", "pLength", "sLength").sorted,
60 | "Input and output column names do not match!")
61 |
62 | dfT.select("merged", "expected").collect()
63 | .foreach { case Row(vec1: Vector, vec2: Vector) =>
64 |         assert(FeatureSelectorTestBase.tolerantVectorEquality(vec1, vec2), "Vectors in the 'merged' and 'expected' columns do not match!")
65 | }
66 | }
67 |
68 | test("VectorMerger - useFeatureCol: merges two VectorColumns with different names") {
69 | val dfTmp = new VectorAssembler().setInputCols(Array("pWidth", "pLength")).setOutputCol("vector1").transform(dataset)
70 | val dfTmp2 = new VectorAssembler().setInputCols(Array("sWidth", "sLength")).setOutputCol("vector2").transform(dfTmp)
71 |     // The features column uses a different ordering to verify that the correct value is picked for each feature
72 | val dfTp3 = new VectorAssembler().setInputCols(Array("pWidth", "sLength", "sWidth", "pLength")).setOutputCol("formerging").transform(dfTmp2)
73 | val df = new VectorAssembler().setInputCols(Array("pWidth", "pLength", "sWidth", "sLength")).setOutputCol("expected").transform(dfTp3)
74 |
75 | val dfT = new VectorMerger()
76 | .setInputCols(Array("vector1", "vector2"))
77 | .setFeatureCol("formerging")
78 | .setUseFeaturesCol(true)
79 | .setOutputCol("merged").transform(df)
80 |
81 | val outCols = AttributeGroup.fromStructField(dfT.schema("merged")).attributes.get.map(attr => attr.name.get)
82 |
83 | assert(outCols.length == 4, "Length of merged column is not equal to 4!")
84 |
85 | assert(outCols.sorted sameElements Array("pWidth", "pLength", "sWidth", "sLength").sorted,
86 | "Input and output column names do not match!")
87 |
88 | dfT.select("merged", "expected").collect()
89 | .foreach { case Row(vec1: Vector, vec2: Vector) =>
90 |         assert(FeatureSelectorTestBase.tolerantVectorEquality(vec1, vec2), "Vectors in the 'merged' and 'expected' columns do not match!")
91 | }
92 | }
93 |
94 | test("VectorMerger - useFeatureCol: merges two VectorColumns with duplicate names") {
95 | val dfTmp = new VectorAssembler().setInputCols(Array("pWidth", "pLength")).setOutputCol("vector1").transform(dataset)
96 | val dfTmp2 = new VectorAssembler().setInputCols(Array("pLength", "sLength")).setOutputCol("vector2").transform(dfTmp)
97 |     // The features column uses a different ordering to verify that the correct value is picked for each feature
98 | val dfTp3 = new VectorAssembler().setInputCols(Array("pWidth", "sLength", "sWidth", "pLength")).setOutputCol("formerging").transform(dfTmp2)
99 | val df = new VectorAssembler().setInputCols(Array("pWidth", "pLength", "sLength")).setOutputCol("expected").transform(dfTp3)
100 |
101 | val dfT = new VectorMerger()
102 | .setInputCols(Array("vector1", "vector2"))
103 | .setFeatureCol("formerging")
104 | .setUseFeaturesCol(true)
105 | .setOutputCol("merged").transform(df)
106 |
107 | val outCols = AttributeGroup.fromStructField(dfT.schema("merged")).attributes.get.map(attr => attr.name.get)
108 |
109 | assert(outCols.length == 3, "Length of merged column is not equal to 3!")
110 |
111 | assert(outCols.sorted sameElements Array("pWidth", "pLength", "sLength").sorted,
112 | "Input and output column names do not match!")
113 |
114 | dfT.select("merged", "expected").collect()
115 | .foreach { case Row(vec1: Vector, vec2: Vector) =>
116 |         assert(FeatureSelectorTestBase.tolerantVectorEquality(vec1, vec2), "Vectors in the 'merged' and 'expected' columns do not match!")
117 | }
118 | }
119 |
120 | test("VectorMerger read/write") {
121 | val nb = new VectorMerger
122 | testDefaultReadWrite[VectorMerger](nb, testParams = true)
123 | }
124 | }
125 |
--------------------------------------------------------------------------------
/version.sbt:
--------------------------------------------------------------------------------
1 | version in ThisBuild := "1.0.0-SNAPSHOT"
--------------------------------------------------------------------------------