├── .gitignore ├── LICENSE ├── README.md ├── bin └── spark-featureselection_2.11-1.0.0.jar ├── build.sbt ├── project ├── build.properties └── plugins.sbt ├── src ├── main │ └── scala │ │ └── org │ │ └── apache │ │ └── spark │ │ └── ml │ │ └── feature │ │ └── selection │ │ ├── FeatureSelector.scala │ │ ├── FeatureSelectorParams.scala │ │ ├── embedded │ │ ├── ImportanceSelector.scala │ │ └── LRSelector.scala │ │ ├── filter │ │ ├── CorrelationSelector.scala │ │ ├── GiniSelector.scala │ │ └── InfoGainSelector.scala │ │ └── util │ │ └── VectorMerger.scala └── test │ ├── resources │ └── iris.data │ └── scala │ └── org │ └── apache │ └── spark │ └── ml │ └── feature │ └── selection │ ├── FeatureSelectionTestBase.scala │ ├── embedded │ ├── ImportanceSelectorSuite.scala │ └── LRSelectorSuite.scala │ ├── filter │ ├── CorrelationSelectorSuite.scala │ ├── GiniSelectorSuite.scala │ └── InfoGainSelectorSuite.scala │ ├── test_util │ ├── DefaultReadWriteTest.scala │ └── TempDirectory.scala │ └── util │ └── VectorMergerSuite.scala └── version.sbt /.gitignore: -------------------------------------------------------------------------------- 1 | build/sbt-launch*.jar 2 | target/ 3 | .idea/ 4 | lib/ 5 | .idea_modules/ 6 | *.iml 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | APPENDIX: How to apply the Apache License to your work.
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "{}"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright {yyyy} {name of copyright owner}
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Feature Selection for Apache Spark
2 | Different feature selection methods (3 filters / 2 selectors based on scores from embedded methods) are provided as Spark MLlib `PipelineStage`s.
3 | These are:
4 | 
5 | ## Filters:
6 | 1) CorrelationSelector: calculates the correlation ("spearman" or "pearson", adjustable through `.setCorrelationType`) between each feature and the label.
7 | 2) GiniSelector: measures the impurity difference before and after a feature value is known.
8 | 3) InfoGainSelector: measures the information gain of a feature with respect to the class.
9 | 
10 | ## Embedded:
11 | 1) ImportanceSelector: takes feature importances from any embedded method, e.g. a random forest.
12 | 2) LRSelector: takes feature weights from an (L1-regularized) logistic regression. The weights form a matrix W with dimensions #labels x #features; the absolute value of every entry is taken, the entries are summed column-wise, and the sums are scaled by the maximum value. (A hedged usage sketch for both embedded selectors follows further below.)
13 | 
14 | ## Util
15 | 1) VectorMerger: takes several vector columns (e.g. the results of different feature selection methods) and merges them into one vector column. Unlike VectorAssembler, VectorMerger uses the metadata of the vector columns to remove duplicates. It supports two modes:
16 | - useFeaturesCol true and featuresCol set: the output column will contain the corresponding columns from featuresCol (matched by name) whose names appear in one of the inputCols. Use this if feature importances were calculated on (e.g.) discretized columns, but selection should use the original values.
17 | - useFeaturesCol false: the output column will contain the columns from the inputCols, with duplicates dropped.
18 | 
19 | Formulas for metrics:
20 | 
21 | General:
22 | - $P(x_j)$ - prior probability of feature X having value $x_j$
23 | - $P(c_i \mid x_j)$ - conditional probability that a sample is of class $c_i$, given that feature X has value $x_j$
24 | - $P(c_i)$ - prior probability that the label Y has value $c_i$
25 | 
26 | 1) Correlation: Calculated through ``org.apache.spark.mllib.stat``
27 | 2) Gini (impurity before minus impurity after the feature value is known):
28 | 
29 | $Gini(X) = \left(1 - \sum_i P(c_i)^2\right) - \sum_j P(x_j)\left(1 - \sum_i P(c_i \mid x_j)^2\right)$
30 | 3) Information gain:
31 | 
32 | $IG(X) = -\sum_i P(c_i)\log P(c_i) + \sum_j P(x_j)\sum_i P(c_i \mid x_j)\log P(c_i \mid x_j)$
33 | 
34 | ## Usage
35 | 
36 | All selection methods share a common API, similar to `ChiSqSelector`.
37 | 
38 | ```scala
39 | import org.apache.spark.ml.feature.selection.filter._
40 | import org.apache.spark.ml.feature.selection.util._
41 | import org.apache.spark.ml.linalg.Vectors
42 | import org.apache.spark.ml.Pipeline
43 | 
44 | val data = Seq(
45 |   (Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),
46 |   (Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),
47 |   (Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)
48 | )
49 | 
50 | val df = spark.createDataset(data).toDF("features", "label")
51 | 
52 | val igSel = new InfoGainSelector()
53 |   .setFeaturesCol("features")
54 |   .setLabelCol("label")
55 |   .setOutputCol("igSelectedFeatures")
56 |   .setSelectorType("percentile")
57 | 
58 | val corSel = new CorrelationSelector()
59 |   .setFeaturesCol("features")
60 |   .setLabelCol("label")
61 |   .setOutputCol("corrSelectedFeatures")
62 |   .setSelectorType("percentile")
63 | 
64 | val giniSel = new GiniSelector()
65 |   .setFeaturesCol("features")
66 |   .setLabelCol("label")
67 |   .setOutputCol("giniSelectedFeatures")
68 |   .setSelectorType("percentile")
69 | 
70 | val merger = new VectorMerger()
71 |   .setInputCols(Array("igSelectedFeatures", "corrSelectedFeatures", "giniSelectedFeatures"))
72 |   .setOutputCol("filtered")
73 | 
74 | val plm = new Pipeline().setStages(Array(igSel, corSel, giniSel, merger)).fit(df)
75 | 
76 | plm.transform(df).select("filtered").show()
77 | 
78 | ```
79 | 
80 | 
81 | 
--------------------------------------------------------------------------------
/bin/spark-featureselection_2.11-1.0.0.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcKaminski/spark-FeatureSelection/edccceec6e4dfaf44533a71d59050e66c3c8ba04/bin/spark-featureselection_2.11-1.0.0.jar
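The README above describes the embedded selectors but only shows the filters in action. The sketch below illustrates, under stated assumptions, how feature importances from a random forest and the coefficient matrix of an L1-regularized logistic regression could be handed to `ImportanceSelector` and `LRSelector`. It is not part of the repository: it reuses the `df` from the Usage example, and the hyperparameters, output column names (`rfSelectedFeatures`, `lrSelectedFeatures`) and the choice of `numTopFeatures = 2` are illustrative assumptions.

```scala
// Hedged sketch (not from the repository): wiring model weights into the embedded selectors.
// Assumes a DataFrame `df` with a "features" vector column and a numeric "label" column,
// as built in the Usage example of the README.
import org.apache.spark.ml.classification.{LogisticRegression, RandomForestClassifier}
import org.apache.spark.ml.feature.selection.embedded.{ImportanceSelector, LRSelector}

// 1) Feature importances from a tree ensemble drive the ImportanceSelector.
val rfModel = new RandomForestClassifier()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .fit(df)

val impSel = new ImportanceSelector()
  .setFeatureWeights(rfModel.featureImportances) // Vector of per-feature importances
  .setFeaturesCol("features")
  .setOutputCol("rfSelectedFeatures")
  .setSelectorType("numTopFeatures")
  .setNumTopFeatures(2)

// 2) The coefficient matrix of an (L1-regularized) logistic regression drives the LRSelector.
val lrModel = new LogisticRegression()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setElasticNetParam(1.0) // pure L1 penalty
  .setRegParam(0.01)       // illustrative regularization strength
  .fit(df)

val lrSel = new LRSelector()
  .setCoefficientMatrix(lrModel.coefficientMatrix) // #labels x #features weight matrix
  .setScaleCoefficients(true)                      // scale by max absolute feature value, since no StandardScaler was used
  .setFeaturesCol("features")
  .setOutputCol("lrSelectedFeatures")
  .setSelectorType("numTopFeatures")
  .setNumTopFeatures(2)

impSel.fit(df).transform(df).select("rfSelectedFeatures").show()
lrSel.fit(df).transform(df).select("lrSelectedFeatures").show()
```

Both selectors are `PipelineStage`s, so they can equally be combined with `VectorMerger` in a `Pipeline`, as in the filter example above.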
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | name := "spark-FeatureSelection"
2 | 
3 | organization := "MarcKaminski"
4 | 
5 | version := "1.0.0"
6 | 
7 | scalaVersion := "2.11.8"
8 | 
9 | val sparkVersion = "2.2.0"
10 | 
11 | 
12 | // spark version to be used
13 | libraryDependencies ++= Seq(
14 |   "org.apache.spark" %% "spark-core" % sparkVersion % "provided",
15 |   "org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
16 |   "org.apache.spark" %% "spark-mllib" % sparkVersion % "provided"
17 | )
18 | 
19 | // For tests
20 | parallelExecution in Test := false
21 | fork in Test := false // true -> Spark during tests; false -> debug during tests (for debug run sbt with: sbt -jvm-debug 5005)
22 | libraryDependencies += "org.scalatest" % "scalatest_2.11" % "3.0.1" % "test"
23 | libraryDependencies += "org.scalactic" %% "scalactic" % "3.0.1" % "test"
24 | 
25 | 
26 | /********************
27 |  * Release settings *
28 |  ********************/
29 | 
30 | publishMavenStyle := true
31 | 
32 | licenses += ("Apache-2.0", url("http://www.apache.org/licenses/LICENSE-2.0"))
33 | 
34 | pomExtra :=
35 |   <url>https://github.com/MarcKaminski/spark-FeatureSelection</url>
36 |   <scm>
37 |     <url>git@github.com:MarcKaminski/spark-FeatureSelection.git</url>
38 |     <connection>scm:git:git@github.com:MarcKaminski/spark-FeatureSelection.git</connection>
39 |   </scm>
40 |   <developers>
41 |     <developer>
42 |       <id>MarcKaminski</id>
43 |       <name>Marc Kaminski</name>
44 |       <url>https://github.com/MarcKaminski</url>
45 |     </developer>
46 |   </developers>
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version = 0.13.16
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.4" exclude("org.apache.maven", "maven-plugin-api"))
2 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.6")
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/ml/feature/selection/FeatureSelector.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements. See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License. You may obtain a copy of the License at
8 |  *
9 |  *    http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection 19 | 20 | import org.apache.hadoop.fs.Path 21 | import org.apache.spark.annotation.Since 22 | import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} 23 | import org.apache.spark.ml.feature.VectorAssembler 24 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, _} 25 | import org.apache.spark.ml.param.ParamMap 26 | import org.apache.spark.ml.util._ 27 | import org.apache.spark.ml.{Estimator, Model} 28 | import org.apache.spark.sql.functions.{rand, udf} 29 | import org.apache.spark.sql.types.{IntegerType, StructField, StructType} 30 | import org.apache.spark.sql.{DataFrame, Dataset} 31 | 32 | /** 33 | * Abstraction for FeatureSelectors, which selects features to use for predicting a categorical label. 34 | * The selector supports two selection methods: `numTopFeatures` and `percentile`. 35 | * - `numTopFeatures` chooses a fixed number of top features according to the feature importance. 36 | * - `percentile` is similar but chooses a fraction of all features instead of a fixed number. 37 | * By default, the selection method is `numTopFeatures`, with the default number of top features set to 50. 38 | * 39 | * @tparam Learner Specialization of this class. If you subclass this type, use this type 40 | * parameter to specify the concrete type. 41 | * @tparam M Specialization of [[FeatureSelectorModel]]. If you subclass this type, use this type 42 | * parameter to specify the concrete type for the corresponding model. 43 | */ 44 | abstract class FeatureSelector[ 45 | Learner <: FeatureSelector[Learner, M], 46 | M <: FeatureSelectorModel[M]] @Since("2.1.1") 47 | extends Estimator[M] with FeatureSelectorParams with DefaultParamsWritable { 48 | /** @group setParam */ 49 | @Since("2.1.1") 50 | def setNumTopFeatures(value: Int): Learner = set(numTopFeatures, value).asInstanceOf[Learner] 51 | 52 | /** @group setParam */ 53 | @Since("2.1.1") 54 | def setPercentile(value: Double): Learner = set(percentile, value).asInstanceOf[Learner] 55 | 56 | /** @group setParam */ 57 | @Since("2.1.1") 58 | def setSelectorType(value: String): Learner = set(selectorType, value).asInstanceOf[Learner] 59 | 60 | /** @group setParam */ 61 | @Since("2.1.1") 62 | def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] 63 | 64 | /** @group setParam */ 65 | @Since("2.1.1") 66 | def setOutputCol(value: String): Learner = set(outputCol, value).asInstanceOf[Learner] 67 | 68 | /** @group setParam */ 69 | @Since("2.1.1") 70 | def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] 71 | 72 | /** @group setParam */ 73 | @Since("2.1.1") 74 | def setRandomCutOff(value: Double): Learner = set(randomCutOff, value).asInstanceOf[Learner] 75 | 76 | override def fit(dataset: Dataset[_]): M = { 77 | // This handles a few items such as schema validation. 78 | // Developers only need to implement train() and make(). 
79 | transformSchema(dataset.schema, logging = true) 80 | 81 | val randomColMaxCategories = 10000 82 | 83 | // Get num features for percentile calculation 84 | val attrGroup = AttributeGroup.fromStructField(dataset.schema($ { 85 | featuresCol 86 | })) 87 | val numFeatures = attrGroup.size 88 | 89 | val (featureImportances, features) = 90 | if ($(selectorType) == FeatureSelector.Random) { 91 | // Append column with random values to dataframe 92 | val withRandom = dataset.withColumn("random", (rand * randomColMaxCategories).cast(IntegerType)) 93 | val featureVectorWithRandom = new VectorAssembler() 94 | .setInputCols(Array($(featuresCol), "random")) 95 | .setOutputCol("FeaturesAndRandom") 96 | .transform(withRandom) 97 | 98 | // Cache and change features column name, calculate importances and reset. 99 | val realFeaturesCol = $(featuresCol) 100 | setFeaturesCol("FeaturesAndRandom") 101 | val featureImportances = train(featureVectorWithRandom) 102 | setFeaturesCol(realFeaturesCol) 103 | val idFromRandomCol = featureImportances.map(_._1).max 104 | 105 | // Take features until reaching random feature. Take overlap from remaining depending on randomCutOff percentage 106 | val sortedFeatureImportances = featureImportances 107 | .sortBy { case (_, imp) => -imp } // minus for descending direction! 108 | .zipWithIndex 109 | 110 | val randomColPos = sortedFeatureImportances.find { case ((fId, fImp), sortId) => fId == idFromRandomCol }.get._2 111 | val overlap = math.max(0, math.round((featureImportances.length - randomColPos - 1) * $(randomCutOff))).toInt 112 | 113 | (featureImportances.filterNot(_._1 == idFromRandomCol), sortedFeatureImportances 114 | .take(randomColPos + overlap + 1) 115 | .map(_._1) 116 | .filterNot(_._1 == idFromRandomCol)) 117 | } else { 118 | val featureImportances = train(dataset) 119 | 120 | // Select features depending on selection method 121 | val features = $(selectorType) match { 122 | case FeatureSelector.NumTopFeatures => featureImportances 123 | .sortBy { case (_, imp) => -imp } // minus for descending direction! 124 | .take($(numTopFeatures)) 125 | case FeatureSelector.Percentile => featureImportances 126 | .sortBy { case (_, imp) => -imp } 127 | .take((numFeatures * $(percentile)).toInt) // Take is save, even if numFeatures > featureImportances.length 128 | case errorType => 129 | throw new IllegalStateException(s"Unknown FeatureSelector Type: $errorType") 130 | } 131 | (featureImportances, features) 132 | } 133 | 134 | if (featureImportances.length < numFeatures) 135 | logWarning(s"Some features were dropped while calculating importance values, " + 136 | s"since numFeatureImportances < numFeatures (${featureImportances.length} < $numFeatures). This happens " + 137 | s"e.g. for constant features.") 138 | 139 | 140 | // Save name of columns and corresponding importance value 141 | val nameImportances = featureImportances.map { case (idx, imp) => ( 142 | if (attrGroup.attributes.isDefined && attrGroup.getAttr(idx).name.isDefined) 143 | attrGroup.getAttr(idx).name.get 144 | else { 145 | logWarning(s"The metadata of $featuresCol is empty or does not contain a name for col index: $idx") 146 | idx.toString 147 | } 148 | , imp) 149 | } 150 | 151 | val indices = features.map { case (idx, _) => idx } 152 | 153 | copyValues(make(uid, indices, nameImportances.toMap).setParent(this)) 154 | } 155 | 156 | override def copy(extra: ParamMap): Learner 157 | 158 | /** 159 | * Calculate the featureImportances. These shall be sortable in descending direction to select the best features. 
160 | * FeatureSelectors implement this to avoid dealing with schema validation 161 | * and copying parameters into the model. 162 | * 163 | * @param dataset Training dataset 164 | * @return Array of (feature index, feature importance) 165 | */ 166 | protected def train(dataset: Dataset[_]): Array[(Int, Double)] 167 | 168 | /** 169 | * Abstract intantiation of the Model. 170 | * FeatureSelectors implement this as a constructor for FeatureSelectorModels 171 | * @param uid of Model 172 | * @param selectedFeatures list of indices to select 173 | * @param featureImportances Map that stores each feature importance 174 | * @return Fitted model 175 | */ 176 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): M 177 | 178 | @Since("2.1.1") 179 | def transformSchema(schema: StructType): StructType = { 180 | val otherPairs = FeatureSelector.supportedSelectorTypes.filter(_ != $(selectorType)) 181 | otherPairs.foreach { paramName: String => 182 | if (isSet(getParam(paramName))) { 183 | logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") 184 | } 185 | } 186 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) 187 | SchemaUtils.checkNumericType(schema, $(labelCol)) 188 | SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) 189 | } 190 | } 191 | 192 | /** 193 | * Abstraction for a model for selecting features. 194 | * @param uid of Model 195 | * @param selectedFeatures list of indices to select 196 | * @param featureImportances Map that stores each feature importance 197 | * @tparam M Specialization of [[FeatureSelectorModel]]. If you subclass this type, use this type 198 | * parameter to specify the concrete type for the corresponding model. 199 | */ 200 | @Since("2.1.1") 201 | abstract class FeatureSelectorModel[M <: FeatureSelectorModel[M]] private[ml] (@Since("2.1.1") override val uid: String, 202 | @Since("2.1.1") val selectedFeatures: Array[Int], 203 | @Since("2.1.1") val featureImportances: Map[String, Double]) 204 | extends Model[M] with FeatureSelectorParams with MLWritable { 205 | /** @group setParam */ 206 | @Since("2.1.1") 207 | def setFeaturesCol(value: String): this.type = set(featuresCol, value) 208 | 209 | /** @group setParam */ 210 | @Since("2.1.1") 211 | def setOutputCol(value: String): this.type = set(outputCol, value) 212 | 213 | @Since("2.1.1") 214 | override def transform(dataset: Dataset[_]): DataFrame = { 215 | transformSchema(dataset.schema, logging = true) 216 | 217 | // Validity checks 218 | val inputAttr = AttributeGroup.fromStructField(dataset.schema($(featuresCol))) 219 | inputAttr.numAttributes.foreach { numFeatures => 220 | val maxIndex = selectedFeatures.max 221 | require(maxIndex < numFeatures, 222 | s"Selected feature index $maxIndex invalid for only $numFeatures input features.") 223 | } 224 | 225 | // Prepare output attributes 226 | val selectedAttrs: Option[Array[Attribute]] = inputAttr.attributes.map { attrs => 227 | selectedFeatures.map(index => attrs(index)) 228 | } 229 | val outputAttr = selectedAttrs match { 230 | case Some(attrs) => new AttributeGroup($(outputCol), attrs) 231 | case None => new AttributeGroup($(outputCol), selectedFeatures.length) 232 | } 233 | 234 | // Select features 235 | val slicer = udf { vec: Vector => 236 | vec match { 237 | case features: DenseVector => Vectors.dense(selectedFeatures.map(features.apply)) 238 | case features: SparseVector => features.slice(selectedFeatures) 239 | } 240 | } 241 | 
dataset.withColumn($(outputCol), slicer(dataset($(featuresCol))), outputAttr.toMetadata()) 242 | } 243 | 244 | @Since("2.1.1") 245 | override def transformSchema(schema: StructType): StructType = { 246 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) 247 | val newField = prepOutputField(schema) 248 | val outputFields = schema.fields :+ newField 249 | StructType(outputFields) 250 | } 251 | 252 | /** 253 | * Prepare the output column field, including per-feature metadata. 254 | */ 255 | private def prepOutputField(schema: StructType): StructField = { 256 | val selector = selectedFeatures.toSet 257 | val origAttrGroup = AttributeGroup.fromStructField(schema($(featuresCol))) 258 | val featureAttributes: Array[Attribute] = if (origAttrGroup.attributes.nonEmpty) { 259 | origAttrGroup.attributes.get.zipWithIndex.filter(x => selector.contains(x._2)).map(_._1) 260 | } else { 261 | Array.fill[Attribute](selector.size)(NominalAttribute.defaultAttr) 262 | } 263 | val newAttributeGroup = new AttributeGroup($(outputCol), featureAttributes) 264 | newAttributeGroup.toStructField() 265 | } 266 | } 267 | 268 | 269 | @Since("2.1.1") 270 | protected [selection] object FeatureSelectorModel { 271 | 272 | class FeatureSelectorModelWriter[M <: FeatureSelectorModel[M]](instance: M) extends MLWriter { 273 | 274 | private case class Data(selectedFeatures: Seq[Int], featureImportances: Map[String, Double]) 275 | 276 | override protected def saveImpl(path: String): Unit = { 277 | DefaultParamsWriter.saveMetadata(instance, path, sc) 278 | val data = Data(instance.selectedFeatures.toSeq, instance.featureImportances) 279 | val dataPath = new Path(path, "data").toString 280 | sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) 281 | } 282 | } 283 | 284 | abstract class FeatureSelectorModelReader[M <: FeatureSelectorModel[M]] extends MLReader[M] { 285 | 286 | protected val className: String 287 | 288 | override def load(path: String): M = { 289 | val metadata = DefaultParamsReader.loadMetadata(path, sc, className) 290 | val dataPath = new Path(path, "data").toString 291 | val data = sparkSession.read.parquet(dataPath) 292 | val selectedFeatures = data.select("selectedFeatures").head().getAs[Seq[Int]](0).toArray 293 | val featureImportances = data.select("featureImportances").head().getAs[Map[String, Double]](0) 294 | val model = make(metadata.uid, selectedFeatures, featureImportances) 295 | DefaultParamsReader.getAndSetParams(model, metadata) 296 | model 297 | } 298 | 299 | /** 300 | * Abstract intantiation of the Model. 301 | * FeatureSelectors implement this as a constructor for FeatureSelectorModels 302 | * @param uid of Model 303 | * @param selectedFeatures list of indices to select 304 | * @param featureImportances Map that stores each feature importance 305 | * @return Fitted model 306 | */ 307 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): M 308 | } 309 | } 310 | 311 | private[selection] object FeatureSelector { 312 | /** 313 | * String name for `numTopFeatures` selector type. 314 | */ 315 | val NumTopFeatures: String = "numTopFeatures" 316 | 317 | /** 318 | * String name for `percentile` selector type. 319 | */ 320 | val Percentile: String = "percentile" 321 | 322 | /** 323 | * String name for `random` selector type. 324 | */ 325 | val Random: String = "randomCutOff" 326 | 327 | /** Set of selector types that FeatureSelector supports. 
*/ 328 | val supportedSelectorTypes: Array[String] = Array(NumTopFeatures, Percentile, Random) 329 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/feature/selection/FeatureSelectorParams.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection 19 | 20 | import org.apache.spark.annotation.Since 21 | import org.apache.spark.ml.param._ 22 | import org.apache.spark.ml.param.shared._ 23 | 24 | /** 25 | * Params for [[FeatureSelector]] and [[FeatureSelectorModel]]. 26 | */ 27 | private[selection] trait FeatureSelectorParams extends Params 28 | with HasFeaturesCol with HasOutputCol with HasLabelCol { 29 | 30 | /** 31 | * Number of features that selector will select. If the 32 | * number of features is less than numTopFeatures, then this will select all features. 33 | * Only applicable when selectorType = "numTopFeatures". 34 | * The default value of numTopFeatures is 50. 35 | * 36 | * @group param 37 | */ 38 | @Since("2.1.1") 39 | final val numTopFeatures = new IntParam(this, "numTopFeatures", 40 | "Number of features that selector will select. If the" + 41 | " number of features is < numTopFeatures, then this will select all features.", 42 | ParamValidators.gtEq(1)) 43 | setDefault(numTopFeatures -> 50) 44 | 45 | /** @group getParam */ 46 | @Since("2.1.1") 47 | def getNumTopFeatures: Int = $(numTopFeatures) 48 | 49 | /** 50 | * Percentile of features that selector will select. 51 | * Only applicable when selectorType = "percentile". 52 | * Default value is 0.5. 53 | * 54 | * @group param 55 | */ 56 | @Since("2.1.1") 57 | final val percentile = new DoubleParam(this, "percentile", 58 | "Percentile of features that selector will select.", 59 | ParamValidators.inRange(0, 1)) 60 | setDefault(percentile -> 0.5) 61 | 62 | /** @group getParam */ 63 | @Since("2.1.1") 64 | def getPercentile: Double = $(percentile) 65 | 66 | /** 67 | * Percentile of features that selector will select after the random column threshold (of number of remaining features). 68 | * Only applicable when selectorType = "randomCutOff". 69 | * Default value is 0.05. 
70 |    *
71 |    * @group param
72 |    */
73 |   @Since("2.1.1")
74 |   final val randomCutOff = new DoubleParam(this, "randomCutOff",
75 |     "Percentile of features that selector will select after the random column threshold (of number of remaining features).",
76 |     ParamValidators.inRange(0, 1))
77 |   setDefault(randomCutOff -> 0.05)
78 | 
79 |   /** @group getParam */
80 |   @Since("2.1.1")
81 |   def getRandomCutOff: Double = $(randomCutOff)
82 | 
83 |   /**
84 |    * The selector type of the FeatureSelector.
85 |    * Supported options: "numTopFeatures", "percentile" (default), "randomCutOff".
86 |    *
87 |    * @group param
88 |    */
89 |   @Since("2.1.1")
90 |   final val selectorType = new Param[String](this, "selectorType",
91 |     "The selector type of the FeatureSelector. " +
92 |       "Supported options: " + FeatureSelector.supportedSelectorTypes.mkString(", "),
93 |     ParamValidators.inArray[String](FeatureSelector.supportedSelectorTypes))
94 |   setDefault(selectorType -> FeatureSelector.Percentile)
95 | 
96 |   /** @group getParam */
97 |   @Since("2.1.1")
98 |   def getSelectorType: String = $(selectorType)
99 | }
100 | 
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/ml/feature/selection/embedded/ImportanceSelector.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements. See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License. You may obtain a copy of the License at
8 |  *
9 |  *    http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 | 
18 | package org.apache.spark.ml.feature.selection.embedded
19 | 
20 | import org.apache.spark.annotation.Since
21 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel}
22 | import org.apache.spark.ml.linalg.{Vector, _}
23 | import org.apache.spark.ml.param._
24 | import org.apache.spark.ml.util._
25 | import org.apache.spark.sql._
26 | import org.apache.spark.sql.types.StructType
27 | 
28 | import scala.collection.mutable
29 | 
30 | /**
31 |  * Params for [[ImportanceSelector]] and [[ImportanceSelectorModel]].
32 |  */
33 | 
34 | private[embedded] trait ImportanceSelectorParams extends Params {
35 |   /**
36 |    * Param for featureImportances.
37 |    *
38 |    * @group param
39 |    */
40 |   final val featureWeights: Param[Vector] = new Param[Vector](this, "featureWeights",
41 |     "featureWeights to rank features and select from a column")
42 | 
43 |   /** @group getParam */
44 |   final def getFeatureWeights: Vector = $(featureWeights)
45 | }
46 | 
47 | /**
48 |  * Feature selection based on featureImportances (e.g. from RandomForest).
49 | */ 50 | @Since("2.1.1") 51 | final class ImportanceSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String) 52 | extends FeatureSelector[ImportanceSelector, ImportanceSelectorModel] with ImportanceSelectorParams { 53 | 54 | @Since("2.1.1") 55 | def this() = this(Identifiable.randomUID("importanceSelector")) 56 | 57 | /** @group setParam */ 58 | @Since("2.1.1") 59 | def setFeatureWeights(value: Vector): this.type = set(featureWeights, value) 60 | 61 | @Since("2.1.1") 62 | override protected def train(dataset: Dataset[_]): Array[(Int, Double)] = { 63 | val arrBuilder = new mutable.ArrayBuffer[(Int, Double)]() 64 | $(featureWeights).foreachActive((idx, value) => arrBuilder.append((idx, value))) 65 | val featureImportancesLocal = arrBuilder.toArray 66 | 67 | if ($(selectorType) == FeatureSelector.Random) { 68 | val mean = featureImportancesLocal.map(_._2).sum / featureImportancesLocal.length 69 | featureImportancesLocal :+ (featureImportancesLocal.map(_._1).max + 1, mean) 70 | } else 71 | featureImportancesLocal 72 | } 73 | 74 | @Since("2.1.1") 75 | override def transformSchema(schema: StructType): StructType = { 76 | val otherPairs = FeatureSelector.supportedSelectorTypes.filter(_ != $(selectorType)) 77 | otherPairs.foreach { paramName: String => 78 | if (isSet(getParam(paramName))) { 79 | logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") 80 | } 81 | } 82 | 83 | require(isDefined(featureWeights)) 84 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) 85 | SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) 86 | } 87 | 88 | @Since("2.1.1") 89 | override def copy(extra: ParamMap): ImportanceSelector = defaultCopy(extra) 90 | 91 | @Since("2.1.1") 92 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): ImportanceSelectorModel = { 93 | new ImportanceSelectorModel(uid, selectedFeatures, featureImportances) 94 | } 95 | } 96 | 97 | object ImportanceSelector extends DefaultParamsReadable[ImportanceSelector] { 98 | @Since("2.1.1") 99 | override def load(path: String): ImportanceSelector = super.load(path) 100 | } 101 | 102 | /** 103 | * Model fitted by [[ImportanceSelector]]. 
104 | * @param uid of Model 105 | * @param selectedFeatures list of indices to select 106 | * @param featureImportances Map that stores each feature importance 107 | */ 108 | @Since("2.1.1") 109 | final class ImportanceSelectorModel private[selection] (@Since("2.1.1") override val uid: String, 110 | @Since("2.1.1") override val selectedFeatures: Array[Int], 111 | @Since("2.1.1") override val featureImportances: Map[String, Double]) 112 | extends FeatureSelectorModel[ImportanceSelectorModel](uid, selectedFeatures, featureImportances) with ImportanceSelectorParams { 113 | @Since("2.1.1") 114 | override def copy(extra: ParamMap): ImportanceSelectorModel = { 115 | val copied = new ImportanceSelectorModel(uid, selectedFeatures, featureImportances) 116 | copyValues(copied, extra).setParent(parent) 117 | } 118 | 119 | @Since("2.1.1") 120 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter[ImportanceSelectorModel](this) 121 | } 122 | 123 | @Since("2.1.1") 124 | object ImportanceSelectorModel extends MLReadable[ImportanceSelectorModel] { 125 | @Since("2.1.1") 126 | override def read: MLReader[ImportanceSelectorModel] = new ImportanceSelectorModelReader 127 | 128 | @Since("2.1.1") 129 | override def load(path: String): ImportanceSelectorModel = super.load(path) 130 | } 131 | 132 | @Since("2.1.1") 133 | final class ImportanceSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[ImportanceSelectorModel] { 134 | @Since("2.1.1") 135 | override protected val className: String = classOf[ImportanceSelectorModel].getName 136 | 137 | @Since("2.1.1") 138 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): ImportanceSelectorModel = { 139 | new ImportanceSelectorModel(uid, selectedFeatures, featureImportances) 140 | } 141 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/feature/selection/embedded/LRSelector.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | 19 | package org.apache.spark.ml.feature.selection.embedded 20 | 21 | import org.apache.spark.annotation.{DeveloperApi, Since} 22 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel} 23 | import org.apache.spark.ml.linalg._ 24 | import org.apache.spark.ml.param.{Param, _} 25 | import org.apache.spark.ml.util._ 26 | import org.apache.spark.mllib.linalg.{Vectors => MllibVectors} 27 | import org.apache.spark.mllib.stat.Statistics 28 | import org.apache.spark.sql._ 29 | import org.apache.spark.sql.types.StructType 30 | import org.json4s.jackson.JsonMethods.{compact, parse, render} 31 | import org.json4s.{JArray, JDouble, JInt, JObject, JValue} 32 | 33 | /** 34 | * Specialized version of `Param[Matrix]` for Java. 35 | */ 36 | @DeveloperApi 37 | private[embedded] class MatrixParam(parent: String, name: String, doc: String, isValid: Matrix => Boolean) 38 | extends Param[Matrix](parent, name, doc, isValid) { 39 | 40 | def this(parent: String, name: String, doc: String) = 41 | this(parent, name, doc, (mat: Matrix) => true) 42 | 43 | def this(parent: Identifiable, name: String, doc: String, isValid: Matrix => Boolean) = 44 | this(parent.uid, name, doc, isValid) 45 | 46 | def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc) 47 | 48 | /** Creates a param pair with the given value (for Java). */ 49 | override def w(value: Matrix): ParamPair[Matrix] = super.w(value) 50 | 51 | override def jsonEncode(value: Matrix): String = { 52 | compact(render(MatrixParam.jValueEncode(value))) 53 | } 54 | 55 | override def jsonDecode(json: String): Matrix = { 56 | MatrixParam.jValueDecode(parse(json)) 57 | } 58 | } 59 | 60 | private[embedded] object MatrixParam { 61 | /** Encodes a param value into JValue. */ 62 | def jValueEncode(value: Matrix): JValue = { 63 | val rows = JInt(value.numRows) 64 | val cols = JInt(value.numCols) 65 | val vals = JArray(for (v <- value.toArray.toList) yield JDouble(v)) 66 | 67 | JObject(List(("rows", rows), ("cols", cols), ("vals", vals))) 68 | } 69 | 70 | /** Decodes a param value from JValue. */ 71 | def jValueDecode(jValue: JValue): Matrix = { 72 | var rows, cols: Int = 0 73 | var vals: Array[Double] = Array.empty[Double] 74 | 75 | var rowsSet, colsSet, valsSet = false 76 | 77 | jValue match { 78 | case obj: JObject => for (kvPair <- obj.values) { 79 | kvPair._2 match { 80 | case x: BigInt => 81 | if (kvPair._1 == "rows") { 82 | rows = x.toInt 83 | rowsSet = true 84 | } 85 | else if (kvPair._1 == "cols") { 86 | cols = x.toInt 87 | colsSet = true 88 | } 89 | else 90 | throw new IllegalArgumentException(s"Cannot recognize unexpected key ${kvPair._1}. (Value is BigInt)") 91 | case arr: List[_] => 92 | if (arr.forall { case _: Double => true; case _ => false }) 93 | if (kvPair._1 == "vals") { 94 | vals = arr.asInstanceOf[List[Double]].toArray 95 | valsSet = true 96 | } 97 | else 98 | throw new IllegalArgumentException(s"Cannot decode unexpected key ${kvPair._1} to Matrix.") 99 | else 100 | throw new IllegalArgumentException(s"Cannot decode unexpected key ${kvPair._1} with value: ${kvPair._2}.") 101 | case _ => throw new IllegalArgumentException(s"Cannot decode unexpected key ${kvPair._1} with value: ${kvPair._2}.") 102 | } 103 | } 104 | case _ => 105 | throw new IllegalArgumentException(s"Cannot decode $jValue to Matrix.") 106 | } 107 | 108 | if (colsSet && rowsSet && valsSet) 109 | Matrices.dense(rows, cols, vals) 110 | else 111 | throw new IllegalArgumentException(s"Cannot decode $jValue. 
" + 112 | s"Missing values to create Matrix: colsSet: $colsSet; rowsSet: $rowsSet; valsSet: $valsSet") 113 | } 114 | } 115 | 116 | /** 117 | * Params for [[LRSelector]] and [[LRSelectorModel]]. 118 | */ 119 | 120 | private[embedded] trait LRSelectorParams extends Params { 121 | /** 122 | * Param for coefficientMatrix. 123 | * 124 | * @group param 125 | */ 126 | final val coefficientMatrix: MatrixParam = new MatrixParam(this, "coefficientMatrix", "coefficientMatrix of LR model") 127 | 128 | /** @group getParam */ 129 | final def getCoefficientMatrix: Matrix = $(coefficientMatrix) 130 | 131 | /** 132 | * Choose, if coefficients shall be scaled using the maximum value of the corresponding feature. 133 | * Use, if no StandardScaler was used prior LR training 134 | * @group param 135 | */ 136 | @Since("2.1.1") 137 | final val scaleCoefficients = new BooleanParam(this, "scaleCoefficients", 138 | "Scale the coefficients using the maximum values of the corresponding feature.") 139 | setDefault(scaleCoefficients, false) 140 | 141 | /** @group getParam */ 142 | @Since("2.1.1") 143 | def getScaleCoefficients: Boolean = $(scaleCoefficients) 144 | } 145 | 146 | /** 147 | * Feature selection based on LR weights (absolute value). 148 | * The selector can scale the coefficients using the corresponding maximum feature value. To activate, set `scaleCoefficients` to true. 149 | * Default: false 150 | */ 151 | @Since("2.1.1") 152 | final class LRSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String) 153 | extends FeatureSelector[LRSelector, LRSelectorModel] with LRSelectorParams { 154 | 155 | @Since("2.1.1") 156 | def this() = this(Identifiable.randomUID("lrSelector")) 157 | 158 | /** @group setParam */ 159 | @Since("2.1.1") 160 | def setCoefficientMatrix(value: Matrix): this.type = set(coefficientMatrix, value) 161 | 162 | /** @group setParam */ 163 | @Since("2.1.1") 164 | def setScaleCoefficients(value: Boolean): this.type = set(scaleCoefficients, value) 165 | 166 | @Since("2.1.1") 167 | override protected def train(dataset: Dataset[_]): Array[(Int, Double)] = { 168 | val input = dataset.select($(featuresCol)).rdd.map { case Row(features: Vector) => MllibVectors.fromML(features) } 169 | val inputStats = Statistics.colStats(input) 170 | // Calculate maxValues = max(abs(min(feature)), abs(max(feature))) 171 | val absMaxValues = inputStats.max.toArray.map(elem => math.abs(elem)) 172 | val absMinValues = inputStats.min.toArray.map(elem => math.abs(elem)) 173 | 174 | val maxValues = if (!$(scaleCoefficients)) 175 | Array.fill(absMinValues.length)(1.0) 176 | else 177 | absMaxValues.zip(absMinValues).map { case (max, min) => math.max(max, min) } 178 | 179 | // Calculate normalized and absolute sum of LR weights for each feature 180 | val coeffVectors = $(coefficientMatrix).toArray.grouped($(coefficientMatrix).numRows).toArray 181 | 182 | val coeffFw = coeffVectors.map(col => col.map(elem => math.abs(elem)).sum) 183 | 184 | val scaled = coeffFw.zip(maxValues).map { case (coeff, max) => coeff * max } 185 | 186 | val coeffSum = scaled.sum 187 | 188 | val toScale = if ($(selectorType) == FeatureSelector.Random) { 189 | val coeffMean = if (scaled.length == 0) 0.0 else coeffSum / scaled.length 190 | scaled :+ coeffMean 191 | } else 192 | scaled 193 | 194 | val featureImportances = if (coeffSum == 0) toScale else toScale.map(elem => elem / coeffSum) 195 | 196 | featureImportances.zipWithIndex.map { case (value, index) => (index, value) } 197 | } 198 | 199 | @Since("2.1.1") 200 | override def 
transformSchema(schema: StructType): StructType = { 201 | val otherPairs = FeatureSelector.supportedSelectorTypes.filter(_ != $(selectorType)) 202 | otherPairs.foreach { paramName: String => 203 | if (isSet(getParam(paramName))) { 204 | logWarning(s"Param $paramName will take no effect when selector type = ${$(selectorType)}.") 205 | } 206 | } 207 | 208 | require(isDefined(coefficientMatrix)) 209 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) 210 | SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) 211 | } 212 | 213 | @Since("2.1.1") 214 | override def copy(extra: ParamMap): LRSelector = defaultCopy(extra) 215 | 216 | @Since("2.1.1") 217 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): LRSelectorModel = { 218 | new LRSelectorModel(uid, selectedFeatures, featureImportances) 219 | } 220 | } 221 | 222 | object LRSelector extends DefaultParamsReadable[LRSelector] { 223 | @Since("2.1.1") 224 | override def load(path: String): LRSelector = super.load(path) 225 | } 226 | 227 | /** 228 | * Model fitted by [[LRSelector]]. 229 | * @param uid of Model 230 | * @param selectedFeatures list of indices to select 231 | * @param featureImportances Map that stores each feature importance 232 | */ 233 | @Since("2.1.1") 234 | final class LRSelectorModel private[selection] (@Since("2.1.1") override val uid: String, 235 | @Since("2.1.1") override val selectedFeatures: Array[Int], 236 | @Since("2.1.1") override val featureImportances: Map[String, Double]) 237 | extends FeatureSelectorModel[LRSelectorModel](uid, selectedFeatures, featureImportances) with LRSelectorParams { 238 | 239 | @Since("2.1.1") 240 | override def copy(extra: ParamMap): LRSelectorModel = { 241 | val copied = new LRSelectorModel(uid, selectedFeatures, featureImportances) 242 | copyValues(copied, extra).setParent(parent) 243 | } 244 | 245 | @Since("2.1.1") 246 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter[LRSelectorModel](this) 247 | } 248 | 249 | @Since("2.1.1") 250 | object LRSelectorModel extends MLReadable[LRSelectorModel] { 251 | @Since("2.1.1") 252 | override def read: MLReader[LRSelectorModel] = new LRSelectorModelReader 253 | 254 | @Since("2.1.1") 255 | override def load(path: String): LRSelectorModel = super.load(path) 256 | } 257 | 258 | @Since("2.1.1") 259 | final class LRSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[LRSelectorModel] { 260 | @Since("2.1.1") 261 | override protected val className: String = classOf[LRSelectorModel].getName 262 | 263 | @Since("2.1.1") 264 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): LRSelectorModel = { 265 | new LRSelectorModel(uid, selectedFeatures, featureImportances) 266 | } 267 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/feature/selection/filter/CorrelationSelector.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.filter 19 | 20 | import org.apache.spark.SparkException 21 | import org.apache.spark.annotation.Since 22 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel} 23 | import org.apache.spark.ml.linalg.{DenseVector, _} 24 | import org.apache.spark.ml.param._ 25 | import org.apache.spark.ml.util._ 26 | import org.apache.spark.mllib.linalg.{Vectors => OldVectors} 27 | import org.apache.spark.mllib.stat.Statistics 28 | import org.apache.spark.sql._ 29 | import org.apache.spark.sql.functions._ 30 | 31 | import scala.collection.mutable 32 | 33 | /** 34 | * Feature selection based on Correlation (absolute value). 35 | */ 36 | @Since("2.1.1") 37 | final class CorrelationSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String) 38 | extends FeatureSelector[CorrelationSelector, CorrelationSelectorModel] { 39 | 40 | /** 41 | * The correlation type of the CorrelationSelector. 42 | * Supported options: "spearman" (default), "pearson". 43 | * 44 | * @group param 45 | */ 46 | @Since("2.1.1") 47 | val correlationType = new Param[String](this, "correlationType", 48 | "The correlation type used for correlation calculation. " + 49 | "Supported options: " + CorrelationSelector.supportedCorrelationTypes.mkString(", "), 50 | ParamValidators.inArray[String](CorrelationSelector.supportedCorrelationTypes)) 51 | setDefault(correlationType -> CorrelationSelector.Spearman) 52 | 53 | /** @group getParam */ 54 | @Since("2.1.1") 55 | def getCorrelationType: String = $(correlationType) 56 | 57 | /** @group setParam */ 58 | @Since("2.1.1") 59 | def setCorrelationType(value: String): this.type = set(correlationType, value) 60 | 61 | @Since("2.1.1") 62 | def this() = this(Identifiable.randomUID("correlationSelector")) 63 | 64 | @Since("2.1.1") 65 | override protected def train(dataset: Dataset[_]): Array[(Int, Double)] = { 66 | val assembleFunc = udf { r: Row => 67 | CorrelationSelector.assemble(r.toSeq: _*) 68 | } 69 | 70 | // Merge featuresCol and labelCol into one OldVector like this: [featuresCol, labelCol] 71 | val tmpDf = dataset.select(assembleFunc(struct(dataset($(featuresCol)), dataset($(labelCol)))).as("input")) 72 | val input = tmpDf.select(col("input")).rdd.map { case Row(features: Vector) => OldVectors.fromML(features) } 73 | 74 | // Calculate correlation between all columns in input 75 | // We're only interested in the last row/ column (correlationmatrix is symmetric) of correlation, 76 | // which stands for the correlation between all feature columns and the label column 77 | val correlations = Statistics.corr(input, ${correlationType}) 78 | 79 | // Extract the information we're interested in 80 | val columns = correlations.toArray.grouped(correlations.numRows) 81 | val corrList = columns.map(column => new DenseVector(column)).toList 82 | val targetCorr = corrList.last.toArray 83 | 84 | // targetCorr.length - 1, because last element is correlation between label column and itself 85 | // Also take abs(), because important features are also negatively correlated 86 | val 
targetCorrs = Array.tabulate(targetCorr.length - 1) { i => (i, targetCorr(i)) } 87 | .map(elem => (elem._1, math.abs(elem._2))) 88 | 89 | targetCorrs.filter(!_._2.isNaN) 90 | } 91 | 92 | @Since("2.1.1") 93 | override def copy(extra: ParamMap): CorrelationSelector = defaultCopy(extra) 94 | 95 | @Since("2.1.1") 96 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): CorrelationSelectorModel = { 97 | new CorrelationSelectorModel(uid, selectedFeatures, featureImportances) 98 | } 99 | } 100 | 101 | /** 102 | * Model fitted by [[CorrelationSelector]]. 103 | * @param selectedFeatures list of indices to select (filter) 104 | */ 105 | @Since("2.1.1") 106 | final class CorrelationSelectorModel private[selection] (@Since("2.1.1") override val uid: String, 107 | @Since("2.1.1") override val selectedFeatures: Array[Int], 108 | @Since("2.1.1") override val featureImportances: Map[String, Double]) 109 | extends FeatureSelectorModel[CorrelationSelectorModel](uid, selectedFeatures, featureImportances) { 110 | /** 111 | * The correlation type of the CorrelationSelector. 112 | * Supported options: "spearman" (default), "pearson". 113 | * 114 | * @group param 115 | */ 116 | @Since("2.1.1") 117 | val correlationType = new Param[String](this, "correlationType", 118 | "The correlation type used for correlation calculation. " + 119 | "Supported options: " + CorrelationSelector.supportedCorrelationTypes.mkString(", "), 120 | ParamValidators.inArray[String](CorrelationSelector.supportedCorrelationTypes)) 121 | setDefault(correlationType -> CorrelationSelector.Spearman) 122 | 123 | /** @group getParam */ 124 | @Since("2.1.1") 125 | def getCorrelationType: String = $(correlationType) 126 | 127 | @Since("2.1.1") 128 | override def copy(extra: ParamMap): CorrelationSelectorModel = { 129 | val copied = new CorrelationSelectorModel(uid, selectedFeatures, featureImportances) 130 | copyValues(copied, extra).setParent(parent) 131 | } 132 | 133 | @Since("2.1.1") 134 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter[CorrelationSelectorModel](this) 135 | } 136 | 137 | @Since("2.1.1") 138 | object CorrelationSelectorModel extends MLReadable[CorrelationSelectorModel] { 139 | @Since("2.1.1") 140 | override def read: MLReader[CorrelationSelectorModel] = new CorrelationSelectorModelReader 141 | 142 | @Since("2.1.1") 143 | override def load(path: String): CorrelationSelectorModel = super.load(path) 144 | } 145 | 146 | @Since("2.1.1") 147 | final class CorrelationSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[CorrelationSelectorModel]{ 148 | @Since("2.1.1") 149 | override protected val className: String = classOf[CorrelationSelectorModel].getName 150 | 151 | @Since("2.1.1") 152 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): CorrelationSelectorModel = { 153 | new CorrelationSelectorModel(uid, selectedFeatures, featureImportances) 154 | } 155 | } 156 | 157 | private[filter] object CorrelationSelector extends DefaultParamsReadable[CorrelationSelector] { 158 | @Since("2.1.1") 159 | override def load(path: String): CorrelationSelector = super.load(path) 160 | 161 | /** 162 | * String name for `pearson` correlation type. 163 | */ 164 | val Pearson: String = "pearson" 165 | 166 | /** 167 | * String name for `spearman` correlation type. 168 | */ 169 | val Spearman: String = "spearman" 170 | 171 | /** Set of correlation types that CorrelationSelector supports. 
*/ 172 | val supportedCorrelationTypes: Array[String] = Array(Pearson, Spearman) 173 | 174 | // From VectorAssembler 175 | private def assemble(vv: Any*): Vector = { 176 | val indices = mutable.ArrayBuilder.make[Int] 177 | val values = mutable.ArrayBuilder.make[Double] 178 | var cur = 0 179 | vv.foreach { 180 | case v: Double => 181 | if (v != 0.0) { 182 | indices += cur 183 | values += v 184 | } 185 | cur += 1 186 | case vec: Vector => 187 | vec.foreachActive { case (i, v) => 188 | if (v != 0.0) { 189 | indices += cur + i 190 | values += v 191 | } 192 | } 193 | cur += vec.size 194 | case null => 195 | throw new SparkException("Values to assemble cannot be null.") 196 | case o => 197 | throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") 198 | } 199 | Vectors.sparse(cur, indices.result(), values.result()).compressed 200 | } 201 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/feature/selection/filter/GiniSelector.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.filter 19 | 20 | import org.apache.spark.annotation.Since 21 | import org.apache.spark.ml.feature.LabeledPoint 22 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel} 23 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, _} 24 | import org.apache.spark.ml.param._ 25 | import org.apache.spark.ml.util._ 26 | import org.apache.spark.rdd.RDD 27 | import org.apache.spark.sql._ 28 | import org.apache.spark.sql.functions._ 29 | import org.apache.spark.sql.types.DoubleType 30 | 31 | /** 32 | * Feature selection based on Gini index. 
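 * A minimal usage sketch (hedged example; the column names "features", "label", "selected" and the DataFrame `df` are assumptions about the caller's data, not defaults of this class):
 * {{{
 *   import org.apache.spark.ml.feature.selection.filter.GiniSelector
 *
 *   val selector = new GiniSelector()
 *     .setFeaturesCol("features")
 *     .setLabelCol("label")
 *     .setOutputCol("selected")
 *     .setSelectorType("numTopFeatures")
 *     .setNumTopFeatures(2)
 *
 *   // df is a hypothetical DataFrame with a vector column "features" and a numeric label column
 *   val model = selector.fit(df)
 *   val reduced = model.transform(df)
 * }}}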
33 | */ 34 | @Since("2.1.1") 35 | final class GiniSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String) 36 | extends FeatureSelector[GiniSelector, GiniSelectorModel] { 37 | 38 | @Since("2.1.1") 39 | def this() = this(Identifiable.randomUID("giniSelector")) 40 | 41 | @Since("2.1.1") 42 | override def train(dataset: Dataset[_]): Array[(Int, Double)] = { 43 | val input: RDD[LabeledPoint] = 44 | dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { 45 | case Row(label: Double, features: Vector) => 46 | LabeledPoint(label, features) 47 | } 48 | 49 | // Calculate gini indices of all features 50 | new GiniCalculator(input).calculateGini().collect() 51 | } 52 | 53 | @Since("2.1.1") 54 | override def copy(extra: ParamMap): GiniSelector = defaultCopy(extra) 55 | 56 | @Since("2.1.1") 57 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): GiniSelectorModel = { 58 | new GiniSelectorModel(uid, selectedFeatures, featureImportances) 59 | } 60 | } 61 | 62 | object GiniSelector extends DefaultParamsReadable[GiniSelector] { 63 | @Since("2.1.1") 64 | override def load(path: String): GiniSelector = super.load(path) 65 | } 66 | 67 | /** 68 | * Model fitted by [[GiniSelector]]. 69 | * @param selectedFeatures list of indices to select (filter) 70 | */ 71 | @Since("2.1.1") 72 | final class GiniSelectorModel private[selection] (@Since("2.1.1") override val uid: String, 73 | @Since("2.1.1") override val selectedFeatures: Array[Int], 74 | @Since("2.1.1") override val featureImportances: Map[String, Double]) 75 | extends FeatureSelectorModel[GiniSelectorModel](uid, selectedFeatures, featureImportances) { 76 | 77 | @Since("2.1.1") 78 | override def copy(extra: ParamMap): GiniSelectorModel = { 79 | val copied = new GiniSelectorModel(uid, selectedFeatures, featureImportances) 80 | copyValues(copied, extra).setParent(parent) 81 | } 82 | 83 | @Since("2.1.1") 84 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter(this) 85 | } 86 | 87 | @Since("2.1.1") 88 | object GiniSelectorModel extends MLReadable[GiniSelectorModel] { 89 | @Since("2.1.1") 90 | override def read: MLReader[GiniSelectorModel] = new GiniSelectorModelReader 91 | 92 | @Since("2.1.1") 93 | override def load(path: String): GiniSelectorModel = super.load(path) 94 | } 95 | 96 | @Since("2.1.1") 97 | final class GiniSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[GiniSelectorModel]{ 98 | @Since("2.1.1") 99 | override protected val className: String = classOf[GiniSelectorModel].getName 100 | 101 | @Since("2.1.1") 102 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): GiniSelectorModel = { 103 | new GiniSelectorModel(uid, selectedFeatures, featureImportances) 104 | } 105 | } 106 | 107 | private [filter] class GiniCalculator (val data: RDD[LabeledPoint]) { 108 | def calculateGini(): RDD[(Int, Double)] = { 109 | val labels2Int = data.map(_.label).distinct.collect.zipWithIndex.toMap 110 | val nLabels = labels2Int.size 111 | 112 | // Basic info. 
about the dataset 113 | val classDistrib = data.map(d => labels2Int(d.label)).countByValue().toMap 114 | val count = data.count 115 | 116 | // Generate pairs ((featureID, featureVal), (Hot encoded) targetVal) 117 | val featureValues = 118 | data.flatMap({ 119 | case LabeledPoint(label, dv: DenseVector) => 120 | val c = Array.fill[Long](nLabels)(0L) 121 | c(labels2Int(label)) = 1L 122 | for (i <- dv.values.indices) yield ((i, dv(i).toFloat), c) 123 | case LabeledPoint(label, sv: SparseVector) => 124 | val c = Array.fill[Long](nLabels)(0L) 125 | c(labels2Int(label)) = 1L 126 | for (i <- sv.indices.indices) yield ((sv.indices(i), sv.values(i).toFloat), c) 127 | }) 128 | 129 | // Calculate Gini Indices for all features 130 | // P(X_j) - prior probability of feature X having value X_j ([[featurePrior]]) 131 | // P(Y_c | X_j) - cond. probability that a sample is of class Y_c, given that feature X has value X_j ([[condProbabClassGivenFeatureValue]]) 132 | // P(Y_c) - prior probability that the label Y has value Y_c ([[classPrior]]) 133 | // Gini(X) = Sum_j{P(X_j) * Sum_c{P(Y_c | X_j)²}} - Sum_c{P(Y_c)²} 134 | // Step 1: (featureNumber, List[featureValue, Array[classCount]]) 135 | val featureClassDistrib = getSortedDistinctValues(classDistrib, featureValues) 136 | .map { case ((fk, fv), cd) => (fk, (fv, cd)) } 137 | .groupBy { case (fk, (_, _)) => fk } 138 | .map { case (fk, arr) => 139 | (fk, arr.map { case (_, value) => value }) 140 | } 141 | 142 | // Step 2: Calculate class priors and right sum 143 | val classPrior = classDistrib.map { case (k, v) => (k, v.toDouble / count) } 144 | val rightPart = classPrior.foldLeft(0.0) { case (agg, (_, v)) => agg + v * v } 145 | 146 | // Step 3: Calculate left sum 147 | val condProbabClassGivenFeatureValue = featureClassDistrib 148 | .map { case (fk, arr) => (fk, arr 149 | .map { case (fv, cd) => (fv, cd 150 | .map(cVal => cVal.toDouble / cd.sum)) 151 | }) 152 | } 153 | val condProbabClassGivenFeatureValueSum = condProbabClassGivenFeatureValue 154 | .map { case (fk, arr) => (fk, arr 155 | .map { case (fv, cc) => (fv, cc 156 | .map(c => c * c).sum) 157 | }) 158 | } 159 | .map { case (fk, arr) => (fk, arr.toList 160 | .sortBy(_._1)) 161 | } 162 | .sortBy(_._1) 163 | 164 | val featurePrior = featureClassDistrib 165 | .map { case (fk, arr) => (fk, arr 166 | .map { case (fv, cc) => (fv, cc.sum.toDouble / count) }) 167 | } 168 | .map { case (fk, arr) => (fk, arr.toList.sortBy(_._1)) 169 | } 170 | .sortBy(_._1) 171 | 172 | // Parallelize again, so zip won't fail 173 | val spark = condProbabClassGivenFeatureValueSum.sparkContext 174 | val tmp = spark.parallelize(condProbabClassGivenFeatureValueSum.collect()) 175 | .zip(spark.parallelize(featurePrior.collect())) 176 | .map { case ((fk, arr), (fk2, arr2)) => if (fk == fk2) { 177 | (fk, arr.zip(arr2)) 178 | } else { 179 | throw new IllegalStateException("Feature keys don't match! This should never happen.") 180 | } 181 | } 182 | 183 | val leftPart = tmp.map { case (fk, calc) => (fk, calc.foldLeft(0.0) { 184 | case (agg, (x, y)) => if (x._1 == y._1) { 185 | agg + x._2 * y._2 186 | } else { 187 | throw new IllegalStateException("Feature values don't match! This should never happen.") 188 | } 189 | }) 190 | } 191 | 192 | // Step 4: Calculate Gini indices and return 193 | leftPart.map { case (fk, value) => (fk, value - rightPart) } 194 | } 195 | 196 | /** 197 | * Group elements by feature and point (get distinct points). 
198 | * Since values like (0, Float.NaN) are not considered unique when calling reduceByKey, 199 | * use the serialized version of the tuple. 200 | * 201 | * @return sorted list of unique feature values 202 | */ 203 | private def getSortedDistinctValues(classDistrib: Map[Int, Long], 204 | featureValues: RDD[((Int, Float), Array[Long])]): RDD[((Int, Float), Array[Long])] = { 205 | 206 | val nonZeros: RDD[((Int, Float), Array[Long])] = 207 | featureValues.map(y => (y._1._1 + "," + y._1._2, y._2)).reduceByKey { case (v1, v2) => 208 | (v1, v2).zipped.map(_ + _) 209 | }.map(y => { 210 | val s = y._1.split(",") 211 | ((s(0).toInt, s(1).toFloat), y._2) 212 | }) 213 | 214 | val zeros = addZerosIfNeeded(nonZeros, classDistrib) 215 | val distinctValues = nonZeros.union(zeros) 216 | 217 | // Sort these values to perform the boundary points evaluation 218 | distinctValues.sortByKey() 219 | } 220 | 221 | /** 222 | * Add zeros if dealing with sparse data 223 | * 224 | * @return rdd with 0's filled in 225 | */ 226 | private def addZerosIfNeeded(nonZeros: RDD[((Int, Float), Array[Long])], 227 | classDistrib: Map[Int, Long]): RDD[((Int, Float), Array[Long])] = { 228 | nonZeros.map { case ((k, _), v) => (k, v) } 229 | .reduceByKey { case (v1, v2) => (v1, v2).zipped.map(_ + _) } 230 | .map { case (k, v) => 231 | val v2 = for (i <- v.indices) yield classDistrib(i) - v(i) 232 | ((k, 0.0F), v2.toArray) 233 | }.filter { case (_, v) => v.sum > 0 } 234 | } 235 | 236 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/feature/selection/filter/InfoGainSelector.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.filter 19 | 20 | import org.apache.spark.annotation.Since 21 | import org.apache.spark.ml.feature.LabeledPoint 22 | import org.apache.spark.ml.feature.selection.{FeatureSelector, FeatureSelectorModel} 23 | import org.apache.spark.ml.linalg.{DenseVector, SparseVector, _} 24 | import org.apache.spark.ml.param._ 25 | import org.apache.spark.ml.util._ 26 | import org.apache.spark.rdd.RDD 27 | import org.apache.spark.sql._ 28 | import org.apache.spark.sql.functions._ 29 | import org.apache.spark.sql.types.DoubleType 30 | 31 | /** 32 | * Feature selection based on Information Gain. 
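 * A minimal usage sketch (hedged example; the output column name and the percentile value are illustrative assumptions, as is the DataFrame `df`):
 * {{{
 *   import org.apache.spark.ml.feature.selection.filter.InfoGainSelector
 *
 *   val selector = new InfoGainSelector()
 *     .setFeaturesCol("features")
 *     .setLabelCol("label")
 *     .setOutputCol("selected")
 *     .setSelectorType("percentile")
 *     .setPercentile(0.5)
 *
 *   // df is a hypothetical DataFrame with a vector column "features" and a numeric label column
 *   val reduced = selector.fit(df).transform(df)
 * }}}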
33 | */ 34 | @Since("2.1.1") 35 | final class InfoGainSelector @Since("2.1.1") (@Since("2.1.1") override val uid: String) 36 | extends FeatureSelector[InfoGainSelector, InfoGainSelectorModel] { 37 | 38 | @Since("2.1.1") 39 | def this() = this(Identifiable.randomUID("igSelector")) 40 | 41 | @Since("2.1.1") 42 | override def train(dataset: Dataset[_]): Array[(Int, Double)] = { 43 | val input: RDD[LabeledPoint] = 44 | dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { 45 | case Row(label: Double, features: Vector) => 46 | LabeledPoint(label, features) 47 | } 48 | 49 | // Calculate gains of all features (features that are always zero will be dropped) 50 | new InfoGainCalculator(input).calculateIG().collect() 51 | } 52 | 53 | @Since("2.1.1") 54 | override def copy(extra: ParamMap): InfoGainSelector = defaultCopy(extra) 55 | 56 | @Since("2.1.1") 57 | protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): InfoGainSelectorModel = { 58 | new InfoGainSelectorModel(uid, selectedFeatures, featureImportances) 59 | } 60 | } 61 | 62 | object InfoGainSelector extends DefaultParamsReadable[InfoGainSelector] { 63 | @Since("2.1.1") 64 | override def load(path: String): InfoGainSelector = super.load(path) 65 | } 66 | 67 | /** 68 | * Model fitted by [[InfoGainSelector]]. 69 | * @param selectedFeatures list of indices to select (filter) 70 | */ 71 | @Since("2.1.1") 72 | final class InfoGainSelectorModel private[filter] (@Since("2.1.1") override val uid: String, 73 | @Since("2.1.1") override val selectedFeatures: Array[Int], 74 | @Since("2.1.1") override val featureImportances: Map[String, Double]) 75 | extends FeatureSelectorModel[InfoGainSelectorModel](uid, selectedFeatures, featureImportances) { 76 | 77 | @Since("2.1.1") 78 | override def copy(extra: ParamMap): InfoGainSelectorModel = { 79 | val copied = new InfoGainSelectorModel(uid, selectedFeatures, featureImportances) 80 | copyValues(copied, extra).setParent(parent) 81 | } 82 | 83 | @Since("2.1.1") 84 | override def write: MLWriter = new FeatureSelectorModel.FeatureSelectorModelWriter[InfoGainSelectorModel](this) 85 | } 86 | 87 | @Since("2.1.1") 88 | object InfoGainSelectorModel extends MLReadable[InfoGainSelectorModel] { 89 | @Since("2.1.1") 90 | override def read: MLReader[InfoGainSelectorModel] = new InfoGainSelectorModelReader 91 | 92 | @Since("2.1.1") 93 | override def load(path: String): InfoGainSelectorModel = super.load(path) 94 | } 95 | 96 | @Since("2.1.1") 97 | final class InfoGainSelectorModelReader extends FeatureSelectorModel.FeatureSelectorModelReader[InfoGainSelectorModel]{ 98 | @Since("2.1.1") 99 | override protected val className: String = classOf[InfoGainSelectorModel].getName 100 | 101 | @Since("2.1.1") 102 | override protected def make(uid: String, selectedFeatures: Array[Int], featureImportances: Map[String, Double]): InfoGainSelectorModel = { 103 | new InfoGainSelectorModel(uid, selectedFeatures, featureImportances) 104 | } 105 | } 106 | 107 | private [filter] class InfoGainCalculator (val data: RDD[LabeledPoint]) { 108 | def calculateIG(): RDD[(Int, Double)] = { 109 | val LOG2 = math.log(2) 110 | 111 | /** log base 2 of x 112 | * @return log base 2 of x */ 113 | val log2 = { x: Double => math.log(x) / LOG2 } 114 | /** entropy of x 115 | * @return entropy of x */ 116 | val entropy = { x: Double => if (x == 0) 0 else -x * log2(x) } 117 | 118 | val labels2Int = data.map(_.label).distinct.collect.zipWithIndex.toMap 119 | val nLabels = labels2Int.size 120 | 
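    // What the steps below compute: IG(X; Y) = H(Y) + H(X) - H(X, Y),
    // i.e. targetEntropy + featureEntropy - jointEntropy, with entropies in bits (log base 2).
    // Illustration: for a binary label with P(Y=0) = P(Y=1) = 0.5, H(Y) = 1 bit; a feature that
    // always equals the label has H(X) = 1 and H(X, Y) = 1, so IG = 1 + 1 - 1 = 1 bit, whereas a
    // feature independent of the label has H(X, Y) = H(X) + H(Y) and therefore IG = 0.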
121 | // Basic info. about the dataset 122 | val classDistrib = data.map(d => labels2Int(d.label)).countByValue().toMap 123 | 124 | // Generate pairs ((featureID, featureVal), (Hot encoded) targetVal) 125 | val featureValues = 126 | data.flatMap({ 127 | case LabeledPoint(label, dv: DenseVector) => 128 | val c = Array.fill[Long](nLabels)(0L) 129 | c(labels2Int(label)) = 1L 130 | for (i <- dv.values.indices) yield ((i, dv(i).toFloat), c) 131 | case LabeledPoint(label, sv: SparseVector) => 132 | val c = Array.fill[Long](nLabels)(0L) 133 | c(labels2Int(label)) = 1L 134 | for (i <- sv.indices.indices) yield ((sv.indices(i), sv.values(i).toFloat), c) 135 | }) 136 | 137 | val sortedValues = getSortedDistinctValues(classDistrib, featureValues) 138 | 139 | val numSamples = classDistrib.values.sum 140 | 141 | // Calculate Probabilities 142 | val classDistribProb = classDistrib.map { case (k, v) => (k, v.toDouble / numSamples) } 143 | val featureProbs = sortedValues.map { case ((k, v), a) => ((k, v), a.sum.toDouble / numSamples) } 144 | val jointProbab = sortedValues.map { case ((k, v), a) => ((k, v), a.map(elem => elem.toDouble / numSamples)) } 145 | 146 | val jpTable = jointProbab.groupBy { case ((k, v), a) => k } 147 | val fpTable = featureProbs.groupBy { case ((k, v), a) => k } 148 | 149 | // Calculate entropies 150 | val featureEntropies = fpTable.map { case (k, v) => (k, v.map { case ((k1, v1), a) => entropy(a) }.sum) }.sortByKey(ascending = true) 151 | val jointEntropies = jpTable.map { case (k, v) => (k, v.map { case ((k1, v1), a) => a.map(v2 => entropy(v2)).sum }.sum) }.sortByKey(ascending = true) 152 | val targetEntropy = classDistribProb.foldLeft(0.0) { case (acc, (k, v)) => acc + entropy(v) } 153 | 154 | // Calculate information gain: targetEntropy + featureEntropy - jointEntropy 155 | // Format: RDD[(featureID->InformationGain)] 156 | 157 | val spark = featureEntropies.sparkContext 158 | 159 | spark.parallelize(featureEntropies.collect()).zip(spark.parallelize(jointEntropies.collect())).map { case ((k1, v1), (k2, v2)) => k1 -> (targetEntropy + v1 - v2) } 160 | 161 | // customZip(featureEntropies, jointEntropies).map { case ((k1, v1), (k2, v2)) => k1->(targetEntropy + v1 - v2)} 162 | } 163 | 164 | /** 165 | * Group elements by feature and point (get distinct points). 166 | * Since values like (0, Float.NaN) are not considered unique when calling reduceByKey, 167 | * use the serialized version of the tuple. 
168 | * 169 | * @return sorted list of unique feature values 170 | */ 171 | private def getSortedDistinctValues(classDistrib: Map[Int, Long], 172 | featureValues: RDD[((Int, Float), Array[Long])]): RDD[((Int, Float), Array[Long])] = { 173 | 174 | val nonZeros: RDD[((Int, Float), Array[Long])] = 175 | featureValues.map(y => (y._1._1 + "," + y._1._2, y._2)).reduceByKey { case (v1, v2) => 176 | (v1, v2).zipped.map(_ + _) 177 | }.map(y => { 178 | val s = y._1.split(",") 179 | ((s(0).toInt, s(1).toFloat), y._2) 180 | }) 181 | 182 | val zeros = addZerosIfNeeded(nonZeros, classDistrib) 183 | val distinctValues = nonZeros.union(zeros) 184 | 185 | // Sort these values to perform the boundary points evaluation 186 | distinctValues.sortByKey() 187 | } 188 | 189 | /** 190 | * Add zeros if dealing with sparse data 191 | * Features that do not have any non-zero value will not be added 192 | * 193 | * @return rdd with 0's filled in 194 | */ 195 | private def addZerosIfNeeded(nonZeros: RDD[((Int, Float), Array[Long])], 196 | classDistrib: Map[Int, Long]): RDD[((Int, Float), Array[Long])] = { 197 | nonZeros.map { case ((k, p), v) => (k, v) } 198 | .reduceByKey { case (v1, v2) => (v1, v2).zipped.map(_ + _) } 199 | .map { case (k, v) => 200 | val v2 = for (i <- v.indices) yield classDistrib(i) - v(i) 201 | ((k, 0.0F), v2.toArray) 202 | }.filter { case (_, v) => v.sum > 0 } 203 | } 204 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/ml/feature/selection/util/VectorMerger.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.util 19 | 20 | import org.apache.spark.SparkException 21 | import org.apache.spark.annotation.Since 22 | import org.apache.spark.ml.Transformer 23 | import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} 24 | import org.apache.spark.ml.feature.VectorSlicer 25 | import org.apache.spark.ml.linalg.{Vector, VectorUDT, Vectors} 26 | import org.apache.spark.ml.param.shared._ 27 | import org.apache.spark.ml.param.{BooleanParam, ParamMap} 28 | import org.apache.spark.ml.util._ 29 | import org.apache.spark.sql.functions._ 30 | import org.apache.spark.sql.types._ 31 | import org.apache.spark.sql.{DataFrame, Dataset, Row} 32 | 33 | import scala.collection.mutable 34 | 35 | /** 36 | * A feature transformer that merges multiple columns into a vector column, without keeping duplicates. 
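 * A minimal usage sketch (hedged example; the input/output column names and the DataFrame `selectedDf` are assumptions, with the two input columns expected to be vector columns produced by earlier selectors):
 * {{{
 *   import org.apache.spark.ml.feature.selection.util.VectorMerger
 *
 *   val merger = new VectorMerger()
 *     .setInputCols(Array("giniSelected", "corrSelected"))
 *     .setFeatureCol("features")
 *     .setUseFeaturesCol(true)
 *     .setOutputCol("merged")
 *
 *   // selectedDf is a hypothetical DataFrame containing "features", "giniSelected" and "corrSelected"
 *   val merged = merger.transform(selectedDf)
 * }}}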
37 | * The Transformer has two modes, triggered by useFeaturesCol: 38 | * 1) useFeaturesCol true and featuresCol set: the output column will contain columns from featuresCol that have 39 | * names appearing in one of the inputCols (type vector) 40 | * 2) useFeaturesCol false: the output column will contain the columns from the inputColumns, but dropping duplicates 41 | */ 42 | @Since("2.1.1") 43 | class VectorMerger @Since("2.1.1") (@Since("2.1.1") override val uid: String) 44 | extends Transformer with HasFeaturesCol with HasInputCols with HasOutputCol with DefaultParamsWritable { 45 | 46 | @Since("2.1.1") 47 | def this() = this(Identifiable.randomUID("vecAssemblerMerger")) 48 | 49 | /** @group setParam */ 50 | @Since("2.1.1") 51 | def setInputCols(value: Array[String]): this.type = set(inputCols, value) 52 | 53 | /** @group setParam */ 54 | @Since("2.1.1") 55 | def setOutputCol(value: String): this.type = set(outputCol, value) 56 | 57 | /** @group setParam */ 58 | @Since("2.1.1") 59 | def setFeatureCol(value: String): this.type = set(featuresCol, value) 60 | 61 | /** @group param */ 62 | @Since("2.1.1") 63 | final val useFeaturesCol = new BooleanParam(this, "useFeaturesCol", 64 | "The output column will contain columns from featuresCol that have names appearing in one of the inputCols (type vector)") 65 | setDefault(useFeaturesCol, true) 66 | 67 | /** @group getParam */ 68 | @Since("2.1.1") 69 | def getUseFeaturesCol: Boolean = $(useFeaturesCol) 70 | 71 | /** @group setParam */ 72 | @Since("2.1.1") 73 | def setUseFeaturesCol(value: Boolean): this.type = set(useFeaturesCol, value) 74 | 75 | @Since("2.1.1") 76 | override def transform(dataset: Dataset[_]): DataFrame = { 77 | transformSchema(dataset.schema, logging = true) 78 | 79 | if ($(useFeaturesCol)) 80 | transformUsingFeaturesColumn(dataset) 81 | else 82 | transformUsingInputColumns(dataset) 83 | } 84 | 85 | @Since("2.1.1") 86 | private def transformUsingInputColumns(dataset: Dataset[_]): DataFrame = { 87 | // Schema transformation. 88 | val schema = dataset.schema 89 | lazy val first = dataset.toDF.first() 90 | 91 | val uniqueNames = mutable.ArrayBuffer[String]() 92 | 93 | val indicesToKeep = mutable.ArrayBuilder.make[Int] 94 | var cur = 0 95 | 96 | def doNotAdd(): Option[Nothing] = { 97 | cur += 1 98 | None 99 | } 100 | 101 | def addAndIncrementIndex() { 102 | indicesToKeep += cur 103 | cur += 1 104 | } 105 | 106 | val attrs: Array[Attribute] = $(inputCols).flatMap { c => 107 | val field = schema(c) 108 | val index = schema.fieldIndex(c) 109 | field.dataType match { 110 | case _: VectorUDT => 111 | val group = AttributeGroup.fromStructField(field) 112 | if (group.attributes.isDefined) { 113 | // If attributes are defined, copy them, checking for duplicates and preserving name. 114 | group.attributes.get.zipWithIndex.flatMap { case (attr, i) => 115 | if (attr.name.isDefined) { 116 | val name = attr.name.get 117 | if (!uniqueNames.contains(name)) { 118 | addAndIncrementIndex() 119 | uniqueNames.append(name) 120 | Some(attr) 121 | } else 122 | doNotAdd() 123 | } else { 124 | addAndIncrementIndex() 125 | Some(attr.withName(c + "_" + i)) 126 | } 127 | }.toList 128 | } else { 129 | // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes 130 | // from metadata, check the first row. 
131 | val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) 132 | Array.tabulate(numAttrs)(i => { 133 | addAndIncrementIndex() 134 | NumericAttribute.defaultAttr.withName(c + "_" + i) 135 | }) 136 | } 137 | case otherType => 138 | throw new SparkException(s"VectorMerger does not support the $otherType type") 139 | } 140 | } 141 | val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() 142 | 143 | // Data transformation. 144 | val assembleFunc = udf { r: Row => 145 | VectorMerger.assemble(indicesToKeep.result(), r.toSeq: _*) 146 | } 147 | 148 | val args = $(inputCols).map { c => 149 | schema(c).dataType match { 150 | case DoubleType => dataset(c) 151 | case _: VectorUDT => dataset(c) 152 | case _: NumericType | BooleanType => dataset(c).cast(DoubleType).as(s"${c}_double_$uid") 153 | } 154 | } 155 | 156 | dataset.select(col("*"), assembleFunc(struct(args: _*)).as($(outputCol), metadata)) 157 | } 158 | 159 | @Since("2.1.1") 160 | private def transformUsingFeaturesColumn(dataset: Dataset[_]): DataFrame = { 161 | // Schema transformation. 162 | val schema = dataset.schema 163 | 164 | val notUniqueNames = mutable.ArrayBuffer[String]() 165 | val featuresColName = $(featuresCol) 166 | val featureColAttrs = AttributeGroup.fromStructField(schema(featuresColName)).attributes.get.zipWithIndex 167 | val featuresColNames = featureColAttrs.flatMap { case (attr, _) => attr.name } 168 | 169 | $(inputCols).foreach { c => 170 | val field = schema(c) 171 | field.dataType match { 172 | case _: VectorUDT => 173 | val group = AttributeGroup.fromStructField(field) 174 | if (group.attributes.isDefined) { 175 | // If attributes are defined, remember name to get column from $featureCol. 176 | group.attributes.get.zipWithIndex.foreach { case (attr, _) => 177 | if (attr.name.isDefined) { 178 | val name = attr.name.get 179 | if (featuresColNames.contains(name)) 180 | notUniqueNames.append(name) 181 | else 182 | throw new IllegalArgumentException(s"Features column $featuresColName does not contain column with name $name!") 183 | } else 184 | throw new IllegalArgumentException(s"Input column $c contains column without name attribute!") 185 | } 186 | } else { 187 | // Otherwise, merging not possible 188 | throw new IllegalArgumentException(s"Input column $c does not contain attributes!") 189 | } 190 | } 191 | } 192 | 193 | val uniqueNames = notUniqueNames.toSet 194 | 195 | new VectorSlicer() 196 | .setInputCol(featuresColName) 197 | .setNames(uniqueNames.toArray) 198 | .setOutputCol($(outputCol)) 199 | .transform(dataset) 200 | } 201 | 202 | @Since("2.1.1") 203 | override def transformSchema(schema: StructType): StructType = { 204 | if ($(useFeaturesCol)) { 205 | val featuresColName = $(featuresCol) 206 | 207 | if (!schema(featuresColName).dataType.isInstanceOf[VectorUDT]) 208 | throw new IllegalArgumentException(s"Features column $featuresColName is not of type VectorUDT!") 209 | } 210 | 211 | val inputColNames = $(inputCols) 212 | val outputColName = $(outputCol) 213 | 214 | inputColNames.foreach(name => if (!schema(name).dataType.isInstanceOf[VectorUDT]) 215 | throw new IllegalArgumentException(s"Input column $name is not of type VectorUDT!") 216 | ) 217 | 218 | if (schema.fieldNames.contains(outputColName)) { 219 | throw new IllegalArgumentException(s"Output column $outputColName already exists.") 220 | } 221 | 222 | StructType(schema.fields :+ StructField(outputColName, new VectorUDT, nullable = true)) 223 | } 224 | 225 | @Since("2.1.1") 226 | override def copy(extra: ParamMap): 
VectorMerger = defaultCopy(extra) 227 | } 228 | 229 | @Since("2.1.1") 230 | object VectorMerger extends DefaultParamsReadable[VectorMerger] { 231 | 232 | @Since("2.1.1") 233 | override def load(path: String): VectorMerger = super.load(path) 234 | 235 | private[feature] def assemble(indicesToKeep: Array[Int], vv: Any*): Vector = { 236 | val indices = mutable.ArrayBuilder.make[Int] 237 | val values = mutable.ArrayBuilder.make[Double] 238 | 239 | var returnCur = 0 240 | var globalCur = 0 241 | vv.foreach { 242 | case v: Double => 243 | if (indicesToKeep.contains(globalCur)) { 244 | if (v != 0.0) { 245 | indices += returnCur 246 | values += v 247 | } 248 | returnCur += 1 249 | } 250 | globalCur += 1 251 | case vec: Vector => 252 | vec.toDense.foreachActive { case (_, v) => 253 | if (indicesToKeep.contains(globalCur)) { 254 | if (v != 0.0) { 255 | indices += returnCur 256 | values += v 257 | } 258 | returnCur += 1 259 | } 260 | globalCur += 1 261 | } 262 | case null => 263 | throw new SparkException("Values to assemble cannot be null.") 264 | case o => 265 | throw new SparkException(s"$o of type ${o.getClass.getName} is not supported.") 266 | } 267 | Vectors.sparse(returnCur, indices.result(), values.result()).compressed 268 | } 269 | } 270 | -------------------------------------------------------------------------------- /src/test/resources/iris.data: -------------------------------------------------------------------------------- 1 | sLength,sWidth,pLength,pWidth,Species 2 | 5.1,3.5,1.4,0.2,0.0 3 | 4.9,3.0,1.4,0.2,0.0 4 | 4.7,3.2,1.3,0.2,0.0 5 | 4.6,3.1,1.5,0.2,0.0 6 | 5.0,3.6,1.4,0.2,0.0 7 | 5.4,3.9,1.7,0.4,0.0 8 | 4.6,3.4,1.4,0.3,0.0 9 | 5.0,3.4,1.5,0.2,0.0 10 | 4.4,2.9,1.4,0.2,0.0 11 | 4.9,3.1,1.5,0.1,0.0 12 | 5.4,3.7,1.5,0.2,0.0 13 | 4.8,3.4,1.6,0.2,0.0 14 | 4.8,3.0,1.4,0.1,0.0 15 | 4.3,3.0,1.1,0.1,0.0 16 | 5.8,4.0,1.2,0.2,0.0 17 | 5.7,4.4,1.5,0.4,0.0 18 | 5.4,3.9,1.3,0.4,0.0 19 | 5.1,3.5,1.4,0.3,0.0 20 | 5.7,3.8,1.7,0.3,0.0 21 | 5.1,3.8,1.5,0.3,0.0 22 | 5.4,3.4,1.7,0.2,0.0 23 | 5.1,3.7,1.5,0.4,0.0 24 | 4.6,3.6,1.0,0.2,0.0 25 | 5.1,3.3,1.7,0.5,0.0 26 | 4.8,3.4,1.9,0.2,0.0 27 | 5.0,3.0,1.6,0.2,0.0 28 | 5.0,3.4,1.6,0.4,0.0 29 | 5.2,3.5,1.5,0.2,0.0 30 | 5.2,3.4,1.4,0.2,0.0 31 | 4.7,3.2,1.6,0.2,0.0 32 | 4.8,3.1,1.6,0.2,0.0 33 | 5.4,3.4,1.5,0.4,0.0 34 | 5.2,4.1,1.5,0.1,0.0 35 | 5.5,4.2,1.4,0.2,0.0 36 | 4.9,3.1,1.5,0.1,0.0 37 | 5.0,3.2,1.2,0.2,0.0 38 | 5.5,3.5,1.3,0.2,0.0 39 | 4.9,3.1,1.5,0.1,0.0 40 | 4.4,3.0,1.3,0.2,0.0 41 | 5.1,3.4,1.5,0.2,0.0 42 | 5.0,3.5,1.3,0.3,0.0 43 | 4.5,2.3,1.3,0.3,0.0 44 | 4.4,3.2,1.3,0.2,0.0 45 | 5.0,3.5,1.6,0.6,0.0 46 | 5.1,3.8,1.9,0.4,0.0 47 | 4.8,3.0,1.4,0.3,0.0 48 | 5.1,3.8,1.6,0.2,0.0 49 | 4.6,3.2,1.4,0.2,0.0 50 | 5.3,3.7,1.5,0.2,0.0 51 | 5.0,3.3,1.4,0.2,0.0 52 | 7.0,3.2,4.7,1.4,1.0 53 | 6.4,3.2,4.5,1.5,1.0 54 | 6.9,3.1,4.9,1.5,1.0 55 | 5.5,2.3,4.0,1.3,1.0 56 | 6.5,2.8,4.6,1.5,1.0 57 | 5.7,2.8,4.5,1.3,1.0 58 | 6.3,3.3,4.7,1.6,1.0 59 | 4.9,2.4,3.3,1.0,1.0 60 | 6.6,2.9,4.6,1.3,1.0 61 | 5.2,2.7,3.9,1.4,1.0 62 | 5.0,2.0,3.5,1.0,1.0 63 | 5.9,3.0,4.2,1.5,1.0 64 | 6.0,2.2,4.0,1.0,1.0 65 | 6.1,2.9,4.7,1.4,1.0 66 | 5.6,2.9,3.6,1.3,1.0 67 | 6.7,3.1,4.4,1.4,1.0 68 | 5.6,3.0,4.5,1.5,1.0 69 | 5.8,2.7,4.1,1.0,1.0 70 | 6.2,2.2,4.5,1.5,1.0 71 | 5.6,2.5,3.9,1.1,1.0 72 | 5.9,3.2,4.8,1.8,1.0 73 | 6.1,2.8,4.0,1.3,1.0 74 | 6.3,2.5,4.9,1.5,1.0 75 | 6.1,2.8,4.7,1.2,1.0 76 | 6.4,2.9,4.3,1.3,1.0 77 | 6.6,3.0,4.4,1.4,1.0 78 | 6.8,2.8,4.8,1.4,1.0 79 | 6.7,3.0,5.0,1.7,1.0 80 | 6.0,2.9,4.5,1.5,1.0 81 | 5.7,2.6,3.5,1.0,1.0 82 | 5.5,2.4,3.8,1.1,1.0 83 | 5.5,2.4,3.7,1.0,1.0 84 | 
5.8,2.7,3.9,1.2,1.0 85 | 6.0,2.7,5.1,1.6,1.0 86 | 5.4,3.0,4.5,1.5,1.0 87 | 6.0,3.4,4.5,1.6,1.0 88 | 6.7,3.1,4.7,1.5,1.0 89 | 6.3,2.3,4.4,1.3,1.0 90 | 5.6,3.0,4.1,1.3,1.0 91 | 5.5,2.5,4.0,1.3,1.0 92 | 5.5,2.6,4.4,1.2,1.0 93 | 6.1,3.0,4.6,1.4,1.0 94 | 5.8,2.6,4.0,1.2,1.0 95 | 5.0,2.3,3.3,1.0,1.0 96 | 5.6,2.7,4.2,1.3,1.0 97 | 5.7,3.0,4.2,1.2,1.0 98 | 5.7,2.9,4.2,1.3,1.0 99 | 6.2,2.9,4.3,1.3,1.0 100 | 5.1,2.5,3.0,1.1,1.0 101 | 5.7,2.8,4.1,1.3,1.0 102 | 6.3,3.3,6.0,2.5,2.0 103 | 5.8,2.7,5.1,1.9,2.0 104 | 7.1,3.0,5.9,2.1,2.0 105 | 6.3,2.9,5.6,1.8,2.0 106 | 6.5,3.0,5.8,2.2,2.0 107 | 7.6,3.0,6.6,2.1,2.0 108 | 4.9,2.5,4.5,1.7,2.0 109 | 7.3,2.9,6.3,1.8,2.0 110 | 6.7,2.5,5.8,1.8,2.0 111 | 7.2,3.6,6.1,2.5,2.0 112 | 6.5,3.2,5.1,2.0,2.0 113 | 6.4,2.7,5.3,1.9,2.0 114 | 6.8,3.0,5.5,2.1,2.0 115 | 5.7,2.5,5.0,2.0,2.0 116 | 5.8,2.8,5.1,2.4,2.0 117 | 6.4,3.2,5.3,2.3,2.0 118 | 6.5,3.0,5.5,1.8,2.0 119 | 7.7,3.8,6.7,2.2,2.0 120 | 7.7,2.6,6.9,2.3,2.0 121 | 6.0,2.2,5.0,1.5,2.0 122 | 6.9,3.2,5.7,2.3,2.0 123 | 5.6,2.8,4.9,2.0,2.0 124 | 7.7,2.8,6.7,2.0,2.0 125 | 6.3,2.7,4.9,1.8,2.0 126 | 6.7,3.3,5.7,2.1,2.0 127 | 7.2,3.2,6.0,1.8,2.0 128 | 6.2,2.8,4.8,1.8,2.0 129 | 6.1,3.0,4.9,1.8,2.0 130 | 6.4,2.8,5.6,2.1,2.0 131 | 7.2,3.0,5.8,1.6,2.0 132 | 7.4,2.8,6.1,1.9,2.0 133 | 7.9,3.8,6.4,2.0,2.0 134 | 6.4,2.8,5.6,2.2,2.0 135 | 6.3,2.8,5.1,1.5,2.0 136 | 6.1,2.6,5.6,1.4,2.0 137 | 7.7,3.0,6.1,2.3,2.0 138 | 6.3,3.4,5.6,2.4,2.0 139 | 6.4,3.1,5.5,1.8,2.0 140 | 6.0,3.0,4.8,1.8,2.0 141 | 6.9,3.1,5.4,2.1,2.0 142 | 6.7,3.1,5.6,2.4,2.0 143 | 6.9,3.1,5.1,2.3,2.0 144 | 5.8,2.7,5.1,1.9,2.0 145 | 6.8,3.2,5.9,2.3,2.0 146 | 6.7,3.3,5.7,2.5,2.0 147 | 6.7,3.0,5.2,2.3,2.0 148 | 6.3,2.5,5.0,1.9,2.0 149 | 6.5,3.0,5.2,2.0,2.0 150 | 6.2,3.4,5.4,2.3,2.0 151 | 5.9,3.0,5.1,1.8,2.0 -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/FeatureSelectionTestBase.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection 19 | 20 | import org.apache.spark.ml.attribute.AttributeGroup 21 | import org.apache.spark.ml.feature.VectorAssembler 22 | import org.apache.spark.ml.feature.selection.test_util._ 23 | import org.apache.spark.ml.linalg.Vector 24 | import org.apache.spark.sql.{Dataset, Row, SparkSession} 25 | import org.scalactic.TolerantNumerics 26 | import org.scalatest._ 27 | 28 | abstract class FeatureSelectionTestBase extends FunSuite with SharedSparkSession with BeforeAndAfter with DefaultReadWriteTest 29 | 30 | trait SharedSparkSession extends BeforeAndAfterAll with BeforeAndAfterEach { 31 | self: Suite => 32 | 33 | @transient private var _sc: SparkSession = _ 34 | @transient var dataset: Dataset[_] = _ 35 | 36 | private val testPath = getClass.getResource("/iris.data").getPath 37 | protected val featuresColName = "features" 38 | protected val labelColName = "Species" 39 | 40 | def sc: SparkSession = _sc 41 | 42 | override def beforeAll() { 43 | super.beforeAll() 44 | _sc = SparkSession 45 | .builder() 46 | .master("local[*]") 47 | .appName("spark test base") 48 | .getOrCreate() 49 | _sc.sparkContext.setLogLevel("ERROR") 50 | 51 | val df = sc.read.option("inferSchema", true).option("header", true).csv(testPath) 52 | dataset = new VectorAssembler() 53 | .setInputCols(Array("sLength", "sWidth", "pLength", "pWidth")) 54 | .setOutputCol(featuresColName) 55 | .transform(df) 56 | } 57 | 58 | override def afterAll() { 59 | try { 60 | if (_sc != null) { 61 | _sc.stop() 62 | } 63 | _sc = null 64 | } finally { 65 | super.afterAll() 66 | } 67 | } 68 | } 69 | 70 | object FeatureSelectorTestBase { 71 | private val epsilon = 1e-4f 72 | private implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(epsilon) 73 | 74 | def tolerantVectorEquality(a: Vector, b: Vector): Boolean = { 75 | a.toArray.zip(b.toArray).map { 76 | case (a: Double, b: Double) => doubleEq.areEqual(a, b) 77 | case (a: Any, b: Any) => a == b 78 | }.forall(value => value) 79 | } 80 | 81 | def testSelector[ 82 | Learner <: FeatureSelector[Learner, M], 83 | M <: FeatureSelectorModel[M]](selector: FeatureSelector[Learner, M], 84 | dataset: Dataset[_], 85 | importantColNames: Array[String], 86 | groundTruthColname: String): Unit = { 87 | val selectorModel = selector.fit(dataset) 88 | val transformed = selectorModel.transform(dataset) 89 | 90 | val inputCols = AttributeGroup.fromStructField(transformed.schema(selector.getFeaturesCol)) 91 | .attributes.get.map(attr => attr.name.get) 92 | 93 | assert(selectorModel.featureImportances.size == inputCols.length, 94 | "Length of featureImportances array is not equal to number of input columns!") 95 | 96 | val selectedColNames = AttributeGroup.fromStructField(transformed.schema(selector.getOutputCol)) 97 | .attributes.get.map(attr => attr.name.get) 98 | 99 | val importantColsSelected = importantColNames.sorted.zip(selectedColNames.sorted).map(elem => elem._1 == elem._2).forall(elem => elem) 100 | 101 | assert(importantColsSelected, "Selected and important column names do not match!") 102 | 103 | transformed.select(selectorModel.getOutputCol, groundTruthColname).collect() 104 | .foreach { case Row(vec1: Vector, vec2: Vector) => 105 | assert(tolerantVectorEquality(vec1, vec2)) 106 | } 107 | } 108 | 109 | def checkModelData[M <: FeatureSelectorModel[M]](model1: FeatureSelectorModel[M], model2: FeatureSelectorModel[M]): Unit = { 110 | 111 | assert(model1.selectedFeatures sameElements model2.selectedFeatures 112 | , "Persisted model has 
different selectedFeatures.") 113 | assert(model1.featureImportances.toArray.sortBy(elem => elem._1) sameElements model2.featureImportances.toArray.sortBy(elem => elem._1), 114 | "Persisted model has different featureImportances.") 115 | } 116 | 117 | /** 118 | * Mapping from all Params to valid settings which differ from the defaults. 119 | * This is useful for tests which need to exercise all Params, such as save/load. 120 | * This excludes input columns to simplify some tests. 121 | */ 122 | val allParamSettings: Map[String, Any] = Map( 123 | "selectorType" -> "percentile", 124 | "numTopFeatures" -> 1, 125 | "percentile" -> 0.12, 126 | "randomCutOff" -> 0.1, 127 | "featuresCol" -> "features", 128 | "labelCol" -> "Species", 129 | "outputCol" -> "myOutput" 130 | ) 131 | } 132 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/embedded/ImportanceSelectorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.embedded 19 | 20 | import org.apache.spark.ml.feature.VectorAssembler 21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase} 22 | import org.apache.spark.ml.linalg.Vectors 23 | 24 | class ImportanceSelectorSuite extends FeatureSelectionTestBase { 25 | // Order of feature importances must be: f4 > f3 > f2 > f1 26 | private val featureWeights = Vectors.dense(Array(0.3, 0.5, 0.7, 0.8)) 27 | 28 | test("Test ImportanceSelector: numTopFeatures") { 29 | val selector = new ImportanceSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 30 | .setFeatureWeights(featureWeights) 31 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(2) 32 | 33 | val importantColNames = Array("pWidth", "pLength") 34 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 35 | 36 | FeatureSelectorTestBase.testSelector[ImportanceSelector, ImportanceSelectorModel](selector, df, importantColNames, "ImportantFeatures") 37 | } 38 | 39 | test("Test ImportanceSelector: percentile") { 40 | val selector = new ImportanceSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 41 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51).setFeatureWeights(featureWeights) 42 | 43 | val importantColNames = Array("pWidth", "pLength") 44 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 45 | 46 | FeatureSelectorTestBase.testSelector[ImportanceSelector, ImportanceSelectorModel](selector, df, importantColNames, "ImportantFeatures") 47 | } 48 | 49 | test("Test ImportanceSelector: randomCutOff") { 50 | val selector = new ImportanceSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 51 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0).setFeatureWeights(featureWeights) 52 | 53 | val importantColNames = Array("pWidth", "pLength", "sWidth", "sLength") 54 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 55 | 56 | FeatureSelectorTestBase.testSelector[ImportanceSelector, ImportanceSelectorModel](selector, df, importantColNames, "ImportantFeatures") 57 | } 58 | 59 | test("ImportanceSelector read/write") { 60 | val nb = new ImportanceSelector 61 | testEstimatorAndModelReadWrite[ImportanceSelector, ImportanceSelectorModel](nb, dataset, 62 | FeatureSelectorTestBase.allParamSettings.+("featureWeights" -> featureWeights), FeatureSelectorTestBase.checkModelData) 63 | } 64 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/embedded/LRSelectorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. 
You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.embedded 19 | 20 | import org.apache.spark.ml.feature.VectorAssembler 21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase} 22 | import org.apache.spark.ml.linalg.Matrices 23 | 24 | class LRSelectorSuite extends FeatureSelectionTestBase { 25 | // Order of feature importances must be: f4 > f3 > f2 > f1 26 | private val lrWeights = Matrices.dense(3, 4, Array(0.1, 0.1, 0.1, 0.2, 0.2, 0.2, -0.8, -0.8, -0.8, 0.9, 0.9, 0.9)) 27 | 28 | test("Test LRSelector: numTopFeatures") { 29 | val selector = new LRSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCoefficientMatrix(lrWeights) 30 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(2) 31 | 32 | val importantColNames = Array("pWidth", "pLength") 33 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 34 | 35 | FeatureSelectorTestBase.testSelector[LRSelector, LRSelectorModel](selector, df, importantColNames, "ImportantFeatures") 36 | } 37 | 38 | test("Test LRSelector: percentile") { 39 | val selector = new LRSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 40 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51).setCoefficientMatrix(lrWeights) 41 | 42 | val importantColNames = Array("pWidth", "pLength") 43 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 44 | 45 | FeatureSelectorTestBase.testSelector[LRSelector, LRSelectorModel](selector, df, importantColNames, "ImportantFeatures") 46 | } 47 | 48 | test("Test LRSelector: randomCutOff") { 49 | val selector = new LRSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 50 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0).setCoefficientMatrix(lrWeights) 51 | 52 | val importantColNames = Array("pWidth", "pLength", "sWidth", "sLength") 53 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 54 | 55 | FeatureSelectorTestBase.testSelector[LRSelector, LRSelectorModel](selector, df, importantColNames, "ImportantFeatures") 56 | } 57 | 58 | test("LRSelector read/write") { 59 | val nb = new LRSelector 60 | testEstimatorAndModelReadWrite[LRSelector, LRSelectorModel](nb, dataset, 61 | FeatureSelectorTestBase.allParamSettings.+("coefficientMatrix" -> lrWeights), FeatureSelectorTestBase.checkModelData) 62 | } 63 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/filter/CorrelationSelectorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 
5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.filter 19 | 20 | import org.apache.spark.ml.feature.VectorAssembler 21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase} 22 | 23 | /* To verify the results with R, run: 24 | library(dplyr) 25 | data(iris) 26 | df <- iris %>% 27 | dplyr::mutate(label = ifelse(Species == "setosa", 0.0, ifelse(Species == "versicolor", 1.0, 2.0))) %>% 28 | dplyr::select("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width", "label") 29 | print(cor(df, method = "pearson")) 30 | print(cor(df, method = "spearman")) 31 | */ 32 | 33 | class CorrelationSelectorSuite extends FeatureSelectionTestBase { 34 | test("Test CorrelationSelector - pearson: numTopFeatures") { 35 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("pearson") 36 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(3) 37 | 38 | val importantColNames = Array("pWidth", "pLength", "sLength") 39 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 40 | 41 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures") 42 | } 43 | 44 | test("Test CorrelationSelector - pearson: percentile") { 45 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("pearson") 46 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51) 47 | 48 | val importantColNames = Array("pWidth", "pLength", "sLength") 49 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 50 | 51 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures") 52 | } 53 | 54 | test("Test CorrelationSelector - pearson: randomCutOff") { 55 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("pearson") 56 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0) 57 | 58 | val importantColNames = Array("pWidth", "pLength", "sLength", "sWidth") 59 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 60 | 61 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures") 62 | } 63 | 64 | test("Test CorrelationSelector - spearman: numTopFeatures") { 65 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("spearman") 66 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(3) 67 | 68 | val importantColNames = 
Array("pWidth", "pLength", "sLength") 69 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 70 | 71 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures") 72 | } 73 | 74 | test("Test CorrelationSelector - spearman: percentile") { 75 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("spearman") 76 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51) 77 | 78 | val importantColNames = Array("pWidth", "pLength", "sLength") 79 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 80 | 81 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures") 82 | } 83 | 84 | test("Test CorrelationSelector - spearman: randomCutOff") { 85 | val selector = new CorrelationSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName).setCorrelationType("spearman") 86 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0) 87 | 88 | val importantColNames = Array("pWidth", "pLength", "sLength", "sWidth") 89 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 90 | 91 | FeatureSelectorTestBase.testSelector[CorrelationSelector, CorrelationSelectorModel](selector, df, importantColNames, "ImportantFeatures") 92 | } 93 | 94 | test("CorrelationSelector read/write") { 95 | val nb = new CorrelationSelector 96 | testEstimatorAndModelReadWrite[CorrelationSelector, CorrelationSelectorModel](nb, dataset, 97 | FeatureSelectorTestBase.allParamSettings.+("correlationType" -> "pearson"), 98 | FeatureSelectorTestBase.checkModelData) 99 | } 100 | } 101 | 102 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/filter/GiniSelectorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.filter 19 | 20 | import org.apache.spark.ml.feature.VectorAssembler 21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase} 22 | 23 | /* To verify the results with R, run: 24 | library(CORElearn) 25 | data(iris) 26 | weights <- attrEval(Species ~ ., data=iris, estimator = "Gini") 27 | print(weights) 28 | */ 29 | class GiniSelectorSuite extends FeatureSelectionTestBase { 30 | test("Test GiniSelector: numTopFeatures") { 31 | val selector = new GiniSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 32 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(2) 33 | 34 | val importantColNames = Array("pLength", "pWidth") 35 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 36 | 37 | FeatureSelectorTestBase.testSelector[GiniSelector, GiniSelectorModel](selector, df, importantColNames, "ImportantFeatures") 38 | } 39 | 40 | test("Test GiniSelector: percentile") { 41 | val selector = new GiniSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 42 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51) 43 | 44 | val importantColNames = Array("pLength", "pWidth") 45 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 46 | 47 | FeatureSelectorTestBase.testSelector[GiniSelector, GiniSelectorModel](selector, df, importantColNames, "ImportantFeatures") 48 | } 49 | 50 | test("Test GiniSelector: randomCutOff") { 51 | val selector = new GiniSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 52 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0) 53 | 54 | val importantColNames = Array("pLength", "pWidth", "sLength", "sWidth") 55 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 56 | 57 | FeatureSelectorTestBase.testSelector[GiniSelector, GiniSelectorModel](selector, df, importantColNames, "ImportantFeatures") 58 | } 59 | 60 | test("GiniSelector read/write") { 61 | val nb = new GiniSelector 62 | testEstimatorAndModelReadWrite[GiniSelector, GiniSelectorModel](nb, dataset, FeatureSelectorTestBase.allParamSettings, FeatureSelectorTestBase.checkModelData) 63 | } 64 | } 65 | 66 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/filter/InfoGainSelectorSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.filter 19 | 20 | import org.apache.spark.ml.feature.VectorAssembler 21 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase} 22 | 23 | /* To verify the results with R, run: 24 | library(RWeka) 25 | data(iris) 26 | weights <- InfoGainAttributeEval(Species ~ ., data=iris) 27 | print(weights) 28 | */ 29 | class InfoGainSelectorSuite extends FeatureSelectionTestBase { 30 | test("Test InfoGainSelector: numTopFeatures") { 31 | val selector = new InfoGainSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 32 | .setOutputCol("filtered").setSelectorType("numTopFeatures").setNumTopFeatures(2) 33 | 34 | val importantColNames = Array("pLength", "pWidth") 35 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 36 | 37 | FeatureSelectorTestBase.testSelector[InfoGainSelector, InfoGainSelectorModel](selector, df, importantColNames, "ImportantFeatures") 38 | } 39 | 40 | test("Test InfoGainSelector: percentile") { 41 | val selector = new InfoGainSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 42 | .setOutputCol("filtered").setSelectorType("percentile").setPercentile(0.51) 43 | 44 | val importantColNames = Array("pLength", "pWidth") 45 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 46 | 47 | FeatureSelectorTestBase.testSelector[InfoGainSelector, InfoGainSelectorModel](selector, df, importantColNames, "ImportantFeatures") 48 | } 49 | 50 | test("Test InfoGainSelector: randomCutOff") { 51 | val selector = new InfoGainSelector().setFeaturesCol(featuresColName).setLabelCol(labelColName) 52 | .setOutputCol("filtered").setSelectorType("randomCutOff").setRandomCutOff(1.0) 53 | 54 | val importantColNames = Array("pLength", "pWidth", "sLength", "sWidth") 55 | val df = new VectorAssembler().setInputCols(importantColNames).setOutputCol("ImportantFeatures").transform(dataset) 56 | 57 | FeatureSelectorTestBase.testSelector[InfoGainSelector, InfoGainSelectorModel](selector, df, importantColNames, "ImportantFeatures") 58 | } 59 | 60 | test("InfoGainSelector read/write") { 61 | val nb = new InfoGainSelector 62 | testEstimatorAndModelReadWrite[InfoGainSelector, InfoGainSelectorModel](nb, dataset, FeatureSelectorTestBase.allParamSettings, FeatureSelectorTestBase.checkModelData) 63 | } 64 | } 65 | 66 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/test_util/DefaultReadWriteTest.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.test_util 19 | 20 | import java.io.{File, IOException} 21 | 22 | import org.apache.spark.ml.feature.selection.FeatureSelectionTestBase 23 | import org.apache.spark.ml.param._ 24 | import org.apache.spark.ml.util._ 25 | import org.apache.spark.ml.{Estimator, Model} 26 | import org.apache.spark.sql.Dataset 27 | import org.scalatest.Suite 28 | 29 | trait DefaultReadWriteTest extends TempDirectory { self: Suite => 30 | 31 | /** 32 | * Checks "overwrite" option and params. 33 | * This saves to and loads from [[tempDir]], but creates a subdirectory with a random name 34 | * in order to avoid conflicts from multiple calls to this method. 35 | * 36 | * @param instance ML instance to test saving/loading 37 | * @param testParams If true, then test values of Params. Otherwise, just test overwrite option. 38 | * @tparam T ML instance type 39 | * @return Instance loaded from file 40 | */ 41 | def testDefaultReadWrite[T <: Params with MLWritable]( 42 | instance: T, 43 | testParams: Boolean = true): T = { 44 | val uid = instance.uid 45 | val subdirName = Identifiable.randomUID("test") 46 | 47 | val subdir = new File(tempDir, subdirName) 48 | val path = new File(subdir, uid).getPath 49 | 50 | instance.save(path) 51 | intercept[IOException] { 52 | instance.save(path) 53 | } 54 | instance.write.overwrite().save(path) 55 | val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]] 56 | val newInstance = loader.load(path) 57 | 58 | assert(newInstance.uid === instance.uid) 59 | if (testParams) { 60 | instance.params.foreach { p => 61 | if (instance.isDefined(p)) { 62 | (instance.getOrDefault(p), newInstance.getOrDefault(p)) match { 63 | case (Array(values), Array(newValues)) => 64 | assert(values === newValues, s"Values do not match on param ${p.name}.") 65 | case (value, newValue) => 66 | assert(value === newValue, s"Values do not match on param ${p.name}.") 67 | } 68 | } else { 69 | assert(!newInstance.isDefined(p), s"Param ${p.name} shouldn't be defined.") 70 | } 71 | } 72 | } 73 | 74 | val load = instance.getClass.getMethod("load", classOf[String]) 75 | val another = load.invoke(instance, path).asInstanceOf[T] 76 | assert(another.uid === instance.uid) 77 | another 78 | } 79 | 80 | /** 81 | * Default test for Estimator, Model pairs: 82 | * - Explicitly set Params, and train model 83 | * - Test save/load using [[testDefaultReadWrite()]] on Estimator and Model 84 | * - Check Params on Estimator and Model 85 | * - Compare model data 86 | * 87 | * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. 88 | * 89 | * @param estimator Estimator to test 90 | * @param dataset Dataset to pass to [[Estimator.fit()]] 91 | * @param testParams Set of [[Param]] values to set in estimator 92 | * @param checkModelData Method which takes the original and loaded [[Model]] and compares their 93 | * data. This method does not need to check [[Param]] values. 94 | * @tparam E Type of [[Estimator]] 95 | * @tparam M Type of [[Model]] produced by estimator 96 | */ 97 | def testEstimatorAndModelReadWrite[ 98 | E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( 99 | estimator: E, 100 | dataset: Dataset[_], 101 | testParams: Map[String, Any], 102 | checkModelData: (M, M) => Unit): Unit = { 103 | // Set some Params to make sure set Params are serialized. 
104 | testParams.foreach { case (p, v) => 105 | estimator.set(estimator.getParam(p), v) 106 | } 107 | val model = estimator.fit(dataset) 108 | 109 | // Test Estimator save/load 110 | val estimator2 = testDefaultReadWrite(estimator) 111 | testParams.foreach { case (p, v) => 112 | val param = estimator.getParam(p) 113 | assert(estimator.get(param).get === estimator2.get(param).get) 114 | } 115 | 116 | // Test Model save/load 117 | val model2 = testDefaultReadWrite(model) 118 | testParams.foreach { case (p, v) => 119 | val param = model.getParam(p) 120 | assert(model.get(param).get === model2.get(param).get) 121 | } 122 | 123 | checkModelData(model, model2) 124 | } 125 | } 126 | 127 | class MyParams(override val uid: String) extends Params with MLWritable { 128 | 129 | final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc") 130 | final val intParam: IntParam = new IntParam(this, "intParam", "doc") 131 | final val floatParam: FloatParam = new FloatParam(this, "floatParam", "doc") 132 | final val doubleParam: DoubleParam = new DoubleParam(this, "doubleParam", "doc") 133 | final val longParam: LongParam = new LongParam(this, "longParam", "doc") 134 | final val stringParam: Param[String] = new Param[String](this, "stringParam", "doc") 135 | final val intArrayParam: IntArrayParam = new IntArrayParam(this, "intArrayParam", "doc") 136 | final val doubleArrayParam: DoubleArrayParam = 137 | new DoubleArrayParam(this, "doubleArrayParam", "doc") 138 | final val stringArrayParam: StringArrayParam = 139 | new StringArrayParam(this, "stringArrayParam", "doc") 140 | 141 | setDefault(intParamWithDefault -> 0) 142 | set(intParam -> 1) 143 | set(floatParam -> 2.0f) 144 | set(doubleParam -> 3.0) 145 | set(longParam -> 4L) 146 | set(stringParam -> "5") 147 | set(intArrayParam -> Array(6, 7)) 148 | set(doubleArrayParam -> Array(8.0, 9.0)) 149 | set(stringArrayParam -> Array("10", "11")) 150 | 151 | override def copy(extra: ParamMap): Params = defaultCopy(extra) 152 | 153 | override def write: MLWriter = new DefaultParamsWriter(this) 154 | } 155 | 156 | object MyParams extends MLReadable[MyParams] { 157 | 158 | override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams] 159 | 160 | override def load(path: String): MyParams = super.load(path) 161 | } 162 | 163 | class DefaultReadWriteSuite extends FeatureSelectionTestBase 164 | with DefaultReadWriteTest { 165 | 166 | test("default read/write") { 167 | val myParams = new MyParams("my_params") 168 | testDefaultReadWrite(myParams) 169 | } 170 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/test_util/TempDirectory.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.test_util 19 | 20 | import java.io.File 21 | 22 | import org.apache.spark.util.Utils 23 | import org.scalatest.{BeforeAndAfterAll, Suite} 24 | 25 | /** 26 | * Trait that creates a temporary directory before all tests and deletes it after all. 27 | */ 28 | trait TempDirectory extends BeforeAndAfterAll { self: Suite => 29 | 30 | private var _tempDir: File = _ 31 | 32 | /** Returns the temporary directory as a [[File]] instance. */ 33 | protected def tempDir: File = _tempDir 34 | 35 | override def beforeAll(): Unit = { 36 | super.beforeAll() 37 | _tempDir = Utils.createTempDir(namePrefix = this.getClass.getName) 38 | } 39 | 40 | override def afterAll(): Unit = { 41 | try { 42 | Utils.deleteRecursively(_tempDir) 43 | } finally { 44 | super.afterAll() 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/ml/feature/selection/util/VectorMergerSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.ml.feature.selection.util 19 | 20 | import org.apache.spark.ml.attribute.AttributeGroup 21 | import org.apache.spark.ml.feature.VectorAssembler 22 | import org.apache.spark.ml.feature.selection.{FeatureSelectionTestBase, FeatureSelectorTestBase} 23 | import org.apache.spark.ml.linalg.Vector 24 | import org.apache.spark.sql.Row 25 | 26 | class VectorMergerSuite extends FeatureSelectionTestBase { 27 | 28 | test("VectorMerger: merges two VectorColumns with different names") { 29 | val dfTmp = new VectorAssembler().setInputCols(Array("pWidth", "pLength")).setOutputCol("vector1").transform(dataset) 30 | val dfTmp2 = new VectorAssembler().setInputCols(Array("sWidth", "sLength")).setOutputCol("vector2").transform(dfTmp) 31 | val df = new VectorAssembler().setInputCols(Array("pWidth", "pLength", "sWidth", "sLength")).setOutputCol("expected").transform(dfTmp2) 32 | 33 | val dfT = new VectorMerger().setInputCols(Array("vector1", "vector2")).setOutputCol("merged").transform(df) 34 | 35 | val outCols = AttributeGroup.fromStructField(dfT.schema("merged")).attributes.get.map(attr => attr.name.get) 36 | 37 | assert(outCols.length == 4, "Length of merged column is not equal to 4!") 38 | 39 | assert(outCols.sorted sameElements Array("pWidth", "pLength", "sWidth", "sLength").sorted, 40 | "Input and output column names do not match!") 41 | 42 | dfT.select("merged", "expected").collect() 43 | .foreach { case Row(vec1: Vector, vec2: Vector) => 44 | assert(FeatureSelectorTestBase.tolerantVectorEquality(vec1, vec2), "column in merged and expected do not match!") 45 | } 46 | } 47 | 48 | test("VectorMerger: merges two VectorColumns with duplicate names") { 49 | val dfTmp = new VectorAssembler().setInputCols(Array("pWidth", "pLength")).setOutputCol("vector1").transform(dataset) 50 | val dfTmp2 = new VectorAssembler().setInputCols(Array("pLength", "sLength")).setOutputCol("vector2").transform(dfTmp) 51 | val df = new VectorAssembler().setInputCols(Array("pWidth", "pLength", "sLength")).setOutputCol("expected").transform(dfTmp2) 52 | 53 | val dfT = new VectorMerger().setInputCols(Array("vector1", "vector2")).setOutputCol("merged").transform(df) 54 | 55 | val outCols = AttributeGroup.fromStructField(dfT.schema("merged")).attributes.get.map(attr => attr.name.get) 56 | 57 | assert(outCols.length == 3, "Length of merged column is not equal to 3!") 58 | 59 | assert(outCols.sorted sameElements Array("pWidth", "pLength", "sLength").sorted, 60 | "Input and output column names do not match!") 61 | 62 | dfT.select("merged", "expected").collect() 63 | .foreach { case Row(vec1: Vector, vec2: Vector) => 64 | assert(FeatureSelectorTestBase.tolerantVectorEquality(vec1, vec2), "column in merged and expected do not match!") 65 | } 66 | } 67 | 68 | test("VectorMerger - useFeatureCol: merges two VectorColumns with different names") { 69 | val dfTmp = new VectorAssembler().setInputCols(Array("pWidth", "pLength")).setOutputCol("vector1").transform(dataset) 70 | val dfTmp2 = new VectorAssembler().setInputCols(Array("sWidth", "sLength")).setOutputCol("vector2").transform(dfTmp) 71 | // The features column has a different ordering to test if the correct values are taken for each column 72 | val dfTp3 = new VectorAssembler().setInputCols(Array("pWidth", "sLength", "sWidth", "pLength")).setOutputCol("formerging").transform(dfTmp2) 73 | val df = new VectorAssembler().setInputCols(Array("pWidth", "pLength", "sWidth", "sLength")).setOutputCol("expected").transform(dfTp3) 74 | 75 | val dfT = new 
VectorMerger() 76 | .setInputCols(Array("vector1", "vector2")) 77 | .setFeatureCol("formerging") 78 | .setUseFeaturesCol(true) 79 | .setOutputCol("merged").transform(df) 80 | 81 | val outCols = AttributeGroup.fromStructField(dfT.schema("merged")).attributes.get.map(attr => attr.name.get) 82 | 83 | assert(outCols.length == 4, "Length of merged column is not equal to 4!") 84 | 85 | assert(outCols.sorted sameElements Array("pWidth", "pLength", "sWidth", "sLength").sorted, 86 | "Input and output column names do not match!") 87 | 88 | dfT.select("merged", "expected").collect() 89 | .foreach { case Row(vec1: Vector, vec2: Vector) => 90 | assert(FeatureSelectorTestBase.tolerantVectorEquality(vec1, vec2), "column in merged and expected do not match!") 91 | } 92 | } 93 | 94 | test("VectorMerger - useFeatureCol: merges two VectorColumns with duplicate names") { 95 | val dfTmp = new VectorAssembler().setInputCols(Array("pWidth", "pLength")).setOutputCol("vector1").transform(dataset) 96 | val dfTmp2 = new VectorAssembler().setInputCols(Array("pLength", "sLength")).setOutputCol("vector2").transform(dfTmp) 97 | // The features column has a different ordering to test if the correct values are taken for each column 98 | val dfTp3 = new VectorAssembler().setInputCols(Array("pWidth", "sLength", "sWidth", "pLength")).setOutputCol("formerging").transform(dfTmp2) 99 | val df = new VectorAssembler().setInputCols(Array("pWidth", "pLength", "sLength")).setOutputCol("expected").transform(dfTp3) 100 | 101 | val dfT = new VectorMerger() 102 | .setInputCols(Array("vector1", "vector2")) 103 | .setFeatureCol("formerging") 104 | .setUseFeaturesCol(true) 105 | .setOutputCol("merged").transform(df) 106 | 107 | val outCols = AttributeGroup.fromStructField(dfT.schema("merged")).attributes.get.map(attr => attr.name.get) 108 | 109 | assert(outCols.length == 3, "Length of merged column is not equal to 3!") 110 | 111 | assert(outCols.sorted sameElements Array("pWidth", "pLength", "sLength").sorted, 112 | "Input and output column names do not match!") 113 | 114 | dfT.select("merged", "expected").collect() 115 | .foreach { case Row(vec1: Vector, vec2: Vector) => 116 | assert(FeatureSelectorTestBase.tolerantVectorEquality(vec1, vec2), "column in merged and expected do not match!") 117 | } 118 | } 119 | 120 | test("VectorMerger read/write") { 121 | val nb = new VectorMerger 122 | testDefaultReadWrite[VectorMerger](nb, testParams = true) 123 | } 124 | } 125 | -------------------------------------------------------------------------------- /version.sbt: -------------------------------------------------------------------------------- 1 | version in ThisBuild := "1.0.0-SNAPSHOT" --------------------------------------------------------------------------------
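A minimal usage sketch, assuming a DataFrame named irisDf with iris-style measurement columns ("sLength", "sWidth", "pLength", "pWidth") and a numeric "label" column, mirroring the test fixtures above: it wires the filter-package GiniSelector into the usual assemble / fit / transform flow. Every setter shown (setFeaturesCol, setLabelCol, setOutputCol, setSelectorType, setNumTopFeatures) is exercised by the suites; the column and variable names are illustrative only, not part of the repository.

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.selection.filter.GiniSelector

// Assemble the raw measurement columns into a single feature vector.
// Column names are assumed to match the test fixtures; adjust them to your data.
val assembled = new VectorAssembler()
  .setInputCols(Array("sLength", "sWidth", "pLength", "pWidth"))
  .setOutputCol("features")
  .transform(irisDf) // irisDf: a hypothetical DataFrame holding the iris data

// Rank features by their Gini-based importance and keep the two highest-ranked ones.
val selector = new GiniSelector()
  .setFeaturesCol("features")
  .setLabelCol("label")
  .setOutputCol("selectedFeatures")
  .setSelectorType("numTopFeatures")
  .setNumTopFeatures(2)

// Fitting yields a GiniSelectorModel; its transform is expected to append the
// reduced feature vector in the configured output column ("selectedFeatures").
val selected = selector.fit(assembled).transform(assembled)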