├── .gitattributes ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── build.sbt ├── data └── mnist │ ├── benchmark.R │ ├── benchmark.Rmd │ ├── benchmark.md │ ├── benchmark.png │ ├── benchmark_files │ └── figure-html │ │ ├── cluster-plot-1.png │ │ ├── horizontal-plot-1.png │ │ ├── local-plot-1.png │ │ └── pre-run local-1.png │ ├── knn.rds │ ├── mnist.bz2 │ ├── mnist.csv.gz │ └── rann.rds ├── dev └── benchmark.sh ├── project ├── Common.scala ├── Dependencies.scala ├── SparkSubmit.scala ├── build.properties └── plugins.sbt ├── python ├── pyspark_knn │ ├── __init__.py │ └── ml │ │ ├── __init__.py │ │ ├── classification.py │ │ └── regression.py ├── setup.py └── test.py ├── scalastyle-config.xml ├── spark-knn-core └── src │ ├── main │ └── scala │ │ └── org │ │ └── apache │ │ └── spark │ │ ├── ml │ │ ├── classification │ │ │ └── KNNClassifier.scala │ │ ├── knn │ │ │ ├── DistanceMetric.scala │ │ │ ├── KNN.scala │ │ │ └── MetricTree.scala │ │ └── regression │ │ │ └── KNNRegression.scala │ │ └── mllib │ │ └── knn │ │ └── KNNUtils.scala │ └── test │ ├── resources │ └── log4j.properties │ └── scala │ └── org │ └── apache │ └── spark │ └── ml │ ├── knn │ ├── DistanceMetricSpec.scala │ ├── KNNSuite.scala │ ├── MetricTreeSpec.scala │ └── SpillTreeSpec.scala │ └── regression │ └── KNNRegressionSuite.scala ├── spark-knn-examples └── src │ └── main │ ├── resources │ └── log4j.properties │ └── scala │ ├── com │ └── github │ │ └── saurfang │ │ └── spark │ │ └── ml │ │ └── knn │ │ └── examples │ │ ├── MNIST.scala │ │ ├── MNISTBenchmark.scala │ │ └── MNISTCrossValidation.scala │ └── org │ └── apache │ └── spark │ └── ml │ ├── classification │ └── NaiveKNN.scala │ └── tuning │ └── Benchmarker.scala └── spark-knn.Rproj /.gitattributes: -------------------------------------------------------------------------------- 1 | *.gz filter=lfs diff=lfs merge=lfs -text 2 | *.bz2 filter=lfs diff=lfs merge=lfs -text 3 | *.png filter=lfs diff=lfs merge=lfs -text 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.class 2 | *.log 3 | 4 | # sbt specific 5 | .cache 6 | .history 7 | .lib/ 8 | dist/* 9 | target/ 10 | lib_managed/ 11 | src_managed/ 12 | project/boot/ 13 | project/plugins/project/ 14 | 15 | # Scala-IDE specific 16 | .scala_dependencies 17 | .worksheet 18 | .idea 19 | .Rproj.user 20 | .Rhistory 21 | 22 | # Python 23 | __pycache__/ 24 | *.py[cod] 25 | *$py.class 26 | python/build/ 27 | python/dist/ 28 | *.egg-info/ 29 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.11.12 4 | - 2.12.8 5 | jdk: 6 | - openjdk8 7 | - openjdk10 8 | sudo: false 9 | cache: 10 | directories: 11 | - $HOME/.ivy2/cache 12 | - $HOME/.sbt/boot/ 13 | script: 14 | - sbt -jvm-opts travis/jvmopts.compile compile 15 | - sbt -jvm-opts travis/jvmopts.test coverage core/test coverageReport 16 | - sbt -jvm-opts travis/jvmopts.test scalastyle 17 | - find $HOME/.sbt -name "*.lock" | xargs rm 18 | - find $HOME/.ivy2 -name "ivydata-*.properties" | xargs rm 19 | after_success: 20 | - bash <(curl -s https://codecov.io/bash) 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | 
http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2015 Forest Fang 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # spark-knn 2 | 3 | [![Join the chat at https://gitter.im/saurfang/spark-knn](https://badges.gitter.im/saurfang/spark-knn.svg)](https://gitter.im/saurfang/spark-knn?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 4 | 5 | [![Build Status](https://travis-ci.org/saurfang/spark-knn.svg)](https://travis-ci.org/saurfang/spark-knn) 6 | [![codecov.io](http://codecov.io/github/saurfang/spark-knn/coverage.svg?branch=master)](http://codecov.io/github/saurfang/spark-knn?branch=master) 7 | 8 | WIP... 9 | 10 | k-Nearest Neighbors algorithm (k-NN) implemented on Apache Spark. This uses a hybrid spill tree approach to 11 | achieve high accuracy and search efficiency. The simplicity of k-NN and lack of tuning parameters makes k-NN 12 | a useful baseline model for many machine learning problems. 13 | 14 | ## How to Use 15 | 16 | This package is published using [sbt-spark-package](https://github.com/databricks/sbt-spark-package) and 17 | linking information can be found at http://spark-packages.org/package/saurfang/spark-knn 18 | 19 | k-NN can be used for both classification and regression, which are exposed using the new [Spark ML](http://spark.apache.org/docs/latest/ml-guide.html) 20 | API based on DataFrame. Both models accept a weight column so predictions can be optionally weighted. 
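For example, a per-row weight can be supplied by pointing the estimator at a weight column via `setWeightCol` before fitting. This is an illustrative sketch only: it reuses the `training` DataFrame loaded in the KNNClassifier example below and assumes it also carries a numeric `weight` column, which is not part of the sample data. ```scala // sketch: "weight" is a hypothetical Double column in the training DataFrame val weightedKnn = new KNNClassifier() .setTopTreeSize(training.count().toInt / 500) .setK(10) .setWeightCol("weight") val weightedModel = weightedKnn.fit(training) val weightedPredictions = weightedModel.transform(training) ```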
21 | 22 | ### KNNClassifier 23 | 24 | ```scala 25 | //read in raw label and features 26 | val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() 27 | 28 | val knn = new KNNClassifier() 29 | .setTopTreeSize(training.count().toInt / 500) 30 | .setK(10) 31 | 32 | val knnModel = knn.fit(training) 33 | 34 | val predicted = knnModel.transform(training) 35 | ``` 36 | 37 | ### KNNRegression 38 | 39 | ```scala 40 | //read in raw label and features 41 | val training = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() 42 | 43 | val knn = new KNNRegression() 44 | .setTopTreeSize(training.count().toInt / 500) 45 | .setK(10) 46 | 47 | val knnModel = knn.fit(training) 48 | 49 | val predicted = knnModel.transform(training) 50 | ``` 51 | 52 | Furthermore, KNN itself is also exposed for advanced usage which returns arbitrary columns associated with found neighbors. 53 | For example, this can power the clustering use case described in the reference Google paper. 54 | 55 | When the model is trained, data points are repartitioned and within each partition a search tree is built to support 56 | efficient querying. When the model is used for prediction, the prediction vectors are repartitioned, searched, collected and 57 | joined back to the search DataFrame. Assuming the training set is much larger, subsequent prediction can be much quicker 58 | than training. Overall the algorithm displays an `O(m log n)` runtime, much better than the naive `O(m n)` 59 | runtime (for n training points, m prediction points and k = 1). See the [benchmark](#benchmark) section for more details. 60 | 61 | The number of neighbors can be set before and after training. Other parameters must be set before training; they control 62 | the number of partitions and the trade-off between accuracy and efficiency of the individual search trees. 63 | Please refer to the Scala doc for more information. 64 | 65 | ## Using the Python interface with spark-submit 66 | 67 | To run a Spark script in Python with `spark-submit`, use: 68 | 69 | ``` 70 | cd python 71 | python setup.py bdist_egg 72 | cd .. 73 | sbt package 74 | 75 | spark-submit --py-files python/dist/pyspark_knn-*.egg --driver-class-path spark-knn-core/target/scala-2.11/spark-knn_*.jar --jars spark-knn-core/target/scala-2.11/spark-knn_*.jar YOUR_SCRIPT 76 | ``` 77 | 78 | ## Benchmark 79 | 80 | Preliminary benchmark results can be found [here](data/mnist/benchmark.md). 81 | 82 | We have benchmarked our implementation against the MNIST dataset. For the canonical 60k training dataset, our implementation 83 | is able to get a reasonable cross-validated F1 score of 0.97 compared to the brute-force exact algorithm's *to be computed*. 84 | 85 | While the implementation is approximate, it doesn't suffer much even when the dimension is high (as many as the full MNIST 86 | raw dimension: 784). This can be a huge advantage over other approximate implementations such as KD-tree and LSH. *further 87 | benchmarking is required* 88 | 89 | The implementation also exhibits sub-linear runtime which can lead to huge savings for large datasets. 90 | 91 | ![](data/mnist/benchmark.png) 92 | 93 | Note: the duration in the above plot is total runtime; thus brute-force exhibits polynomial runtime while SpillTree shows 94 | close to linearithmic runtime. 95 | 96 | Finally, the implementation scales horizontally and has been successfully applied on datasets with low hundreds of millions of 97 | observations and low hundreds of dimensions.
We see no reason why it can't scale to billions of observations as described 98 | in the original Google paper. 99 | 100 | 101 | ## Progress 102 | 103 | - [x] implementation of MetricTree, SpillTree, HybridSpillTree 104 | - [x] distributed KNN based on HybridSpillTree 105 | - [x] \(weighted\) Classifier and Regression on ml API 106 | - [ ] benchmark against Brute Force and LSH based kNN in terms of model and runtime performance 107 | - [ ] benchmark against LSH based kNN 108 | - [ ] refactoring of Tree related code 109 | 110 | NB: currently trees are recursively constructed and contain some duplicated code. The data structure is also questionable. 111 | However, preliminary empirical testing shows each tree can comfortably handle tens to hundreds of thousands of high-dimensional data points. 112 | - [ ] upgrade ml implementation to use DataSet API (pending Spark 1.6) 113 | 114 | NB: the largest costs of this implementation are the disk I/O of repartitioning and the distance calculation. While the distance calculation 115 | cannot easily be optimized, with the DataSet API we might be able to drastically reduce the shuffle size during training 116 | and prediction. 117 | - [ ] explore use of random projection for dimension reduction 118 | 119 | ## Credits 120 | 121 | - Liu, Ting, et al. 122 | "An investigation of practical approximate nearest neighbor algorithms." 123 | Advances in neural information processing systems. 2004. 124 | 125 | - Liu, Ting, Charles Rosenberg, and Henry Rowley. 126 | "Clustering billions of images with large scale nearest neighbor search." 127 | Applications of Computer Vision, 2007. WACV'07. IEEE Workshop on. IEEE, 2007. 128 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import Common._ 2 | 3 | lazy val root = Project("spark-knn", file(".")). 4 | settings(commonSettings). 5 | settings(Dependencies.Versions). 6 | aggregate(core, examples) 7 | 8 | lazy val core = knnProject("spark-knn-core"). 9 | settings( 10 | name := "spark-knn", 11 | spName := "saurfang/spark-knn", 12 | credentials += Credentials(Path.userHome / ".ivy2" / ".sbtcredentials"), 13 | licenses += "Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0") 14 | ). 15 | settings(Dependencies.core). 16 | settings( 17 | scalafixDependencies in ThisBuild += "org.scalatest" %% "autofix" % "3.1.0.0", 18 | addCompilerPlugin(scalafixSemanticdb) // enable SemanticDB 19 | ) 20 | 21 | lazy val examples = knnProject("spark-knn-examples"). 22 | dependsOn(core). 23 | settings(fork in run := true, coverageExcludedPackages := ".*examples.*"). 24 | settings(Dependencies.examples).
25 | settings(SparkSubmit.settings: _*) 26 | -------------------------------------------------------------------------------- /data/mnist/benchmark.R: -------------------------------------------------------------------------------- 1 | ns <- seq(2500, 10000, 2500) 2 | 3 | spillTree_train <- c(7632.666666666666, 7165.333333333333, 7857.666666666666, 9299.333333333332) 4 | spillTree_predict <- c(34317.0, 50162.666666666664, 62151.0, 74604.0) 5 | 6 | bruteForce_train <- c(420.3333333333333, 714.6666666666666, 847.0, 876.6666666666666) 7 | bruteForce_predict <- c(18028.666666666664, 47125.0, 94626.66666666666, 156086.66666666666) 8 | 9 | library(ggplot2) 10 | library(tidyr) 11 | library(dplyr) 12 | 13 | plotDF <- data.frame(ns, spillTree_train, spillTree_predict, bruteForce_train, bruteForce_predict) %>% 14 | gather(type, time, -ns) %>% 15 | separate(type, c("impl", "type")) %>% 16 | mutate(time = time / 1e3, impl = factor(impl)) 17 | 18 | plot <- ggplot(plotDF) + 19 | aes(ns, time, shape = impl, color = impl) + 20 | geom_line() + 21 | geom_point(size = 5) + 22 | facet_grid(type ~ .) + 23 | theme_bw() + 24 | labs(x = "number of data points", y = "wall-clock duration (seconds)", title = "MNIST Data on local[3]") 25 | 26 | print(plot) 27 | ggsave("data/mnist/benchmark.png", plot) -------------------------------------------------------------------------------- /data/mnist/benchmark.Rmd: -------------------------------------------------------------------------------- 1 | --- 2 | title: "kNN Benchmark" 3 | output: 4 | html_document: 5 | keep_md: yes 6 | --- 7 | 8 | ```{r setup, include=FALSE, message=FALSE, warning=FALSE} 9 | knitr::opts_chunk$set(echo = FALSE, warning = FALSE) 10 | 11 | library(ggplot2) 12 | library(tidyr) 13 | library(dplyr) 14 | library(RANN) 15 | library(class) 16 | 17 | plot_runtimes <- function(df) { 18 | ggplot(df) + 19 | aes(ns, time, shape = impl, color = impl) + 20 | geom_line() + 21 | geom_point(size = 5) + 22 | facet_grid(type ~ .) + 23 | theme_bw() + 24 | labs(x = "number of data points", y = "wall-clock duration (seconds)") 25 | } 26 | ``` 27 | 28 | We use [MNIST data](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html) to benchmark our kNN implementation. Recall our kNN implementation is based on hybrid spill tree and more details can be found at the main 29 | README. 30 | 31 | For tests with less than 60k observations, we use the regular MNIST. For those with more than 60k 32 | observations, we opted for the *mnist8m* processed dataset. 33 | 34 | ## kNN on local Spark 35 | We first compare kNN runtime performace on Spark local mode. 
36 | ```{r pre-run local} 37 | ns_local <- seq(2500, 10000, 2500) 38 | 39 | spillTree_train <- c(7632.666666666666, 7165.333333333333, 7857.666666666666, 9299.333333333332) 40 | spillTree_predict <- c(34317.0, 50162.666666666664, 62151.0, 74604.0) 41 | 42 | bruteForce_train <- c(420.3333333333333, 714.6666666666666, 847.0, 876.6666666666666) 43 | bruteForce_predict <- c(18028.666666666664, 47125.0, 94626.66666666666, 156086.66666666666) 44 | 45 | local_runtimes <- data.frame(ns = ns_local, spillTree_train, spillTree_predict, bruteForce_train, bruteForce_predict) %>% 46 | gather(type, time, -ns) %>% 47 | separate(type, c("impl", "type")) %>% 48 | mutate(time = time / 1e3, impl = factor(impl)) 49 | 50 | plot_runtimes(local_runtimes) + ggtitle("MNIST Data on local[3]") 51 | ``` 52 | While the spill-tree implementation has much larger overhead, the savings on the search efficiency quickly 53 | trumps the naive brute-force approach when n gets larger. 54 | 55 | ## kNN on local R 56 | For perspective, we also ran the kNN using RANN in R which is based on KD-tree and knn in class package which is brute force based. 57 | 58 | Note: all Spark benchmark is average of three runs while all R local benchmark numbers are even less scientific with a single run instead. 59 | 60 | ```{r local-read} 61 | if(!file.exists("rann.rds") || !file.exists("knn.rds")) { 62 | mnist <- readr::read_csv("mnist.csv.gz", col_names = FALSE, progress = FALSE) 63 | ns_local[1] %>% 64 | sapply(function(n) { 65 | runtime <- select(mnist, -X1) %>% 66 | head(n) %>% 67 | as.matrix() %>% 68 | { system.time(nn2(., k = 2)) } 69 | runtime[1] 70 | }) %>% 71 | saveRDS("rann.rds") 72 | head(ns_local, 2) %>% 73 | sapply(function(n) { 74 | runtime <- select(mnist, -X1) %>% 75 | head(n) %>% 76 | as.matrix() %>% 77 | { system.time(knn1(., ., head(mnist, n)$X1)) } 78 | runtime[1] 79 | }) %>% 80 | saveRDS("knn.rds") 81 | } 82 | ``` 83 | ```{r local-rann} 84 | # due to RANN takes the shortcut when distance is zero and k = 1 it directly returns 85 | # we have to pick k = 2. experiments emprically show k = 2 ~ 10 has no significant effect on runtime 86 | rann_runtimes <- readRDS("rann.rds") 87 | ``` 88 | ```{r local-knn} 89 | r_knn_runtimes <- readRDS("knn.rds") 90 | ``` 91 | ```{r local-plot} 92 | local_runtimes <- data.frame(ns = ns_local, 93 | spillTree = spillTree_train + spillTree_predict, 94 | bruteForce = bruteForce_train + bruteForce_predict, 95 | kdtree = rann_runtimes * 1000, 96 | knn_r = c(r_knn_runtimes, NA, NA) * 1000) %>% 97 | gather(impl, time, -ns) %>% 98 | mutate(time = time / 1e3, impl = factor(impl)) 99 | 100 | ggplot(local_runtimes) + 101 | aes(ns, time, shape = impl, color = impl) + 102 | geom_line() + 103 | geom_point(size = 5) + 104 | theme_bw() + 105 | labs(x = "number of data points", y = "wall-clock duration (seconds)", title = "MNIST Data with R functions") 106 | ``` 107 | 108 | ## kNN on Spark Clsuter 109 | Next we test our kNN on AWS 10 c3.4xlarge nodes cluster (160 cores in total). 110 | 111 | Note for larger n, we only ran the algorithm using spill tree due to much longer runtime for naive approach. 
112 | 113 | ```{r parse-runtime} 114 | parse_runtime <- function(raw_lines, raw_ns) { 115 | lines <- strsplit(raw_lines, "\n")[[1]] 116 | algos <- gsub("^#|:.+", "", lines) 117 | runtimes <- gsub("^#.+:|\\s|WrappedArray\\(|\\)", "", lines) %>% strsplit("/") 118 | df <- as.data.frame(t(read.csv(text = unlist(c(raw_ns, runtimes)), header = FALSE))) 119 | colnames(df) <- c("ns", do.call(paste, c(expand.grid(c("train", "predict"), algos), sep = "_"))) 120 | df 121 | } 122 | ``` 123 | ```{r cluster-runtime} 124 | cluster_runtimes <- rbind_list( 125 | parse_runtime("#knn: WrappedArray(7342.333333333333, 4962.666666666666, 5370.0, 5151.333333333333, 6091.333333333333, 8506.666666666666) / WrappedArray(6017.333333333333, 7072.0, 8856.0, 9742.666666666666, 15817.0, 32105.0) 126 | #naive: WrappedArray(1023.6666666666666, 627.6666666666666, 786.3333333333333, 797.3333333333333, 1201.0, 1873.3333333333333) / WrappedArray(4116.333333333333, 4601.666666666666, 5782.0, 7866.666666666666, 19555.666666666664, 70148.33333333333)", "2500,5000,7500,10000,20000,40000"), 127 | parse_runtime("#knn: WrappedArray(19883.333333333332, 28345.0, 50083.0) / WrappedArray(63017.33333333333, 189466.66666666666, 641681.3333333333)", "80000,160000,320000") 128 | ) %>% 129 | gather(type, time, -ns) %>% 130 | separate(type, c("type", "impl")) %>% 131 | mutate(time = time / 1e3, impl = factor(impl)) 132 | ``` 133 | ```{r cluster-plot} 134 | plot_runtimes(cluster_runtimes) + 135 | # scale_x_log10() + 136 | scale_y_log10() + 137 | ggtitle("MNIST Data on c3.4xlarge * 10") 138 | ``` 139 | 140 | Notice the y-axis is on log scale. 141 | 142 | ## Horizontal Scalability 143 | 144 | Finally we will examine how the algorithm scales with the number of cores. Again this is using AWS c3.4xlarge nodes. 145 | 146 | ```{r horizontal-runtimes} 147 | horizontal_runtimes <- rbind_list( 148 | parse_runtime("#knn: WrappedArray(15343.666666666666) / WrappedArray(277521.6666666666) 149 | #naive: WrappedArray(3510.333333333333) / WrappedArray(777180.0)", "20"), 150 | parse_runtime("#knn: WrappedArray(15357.0) / WrappedArray(127080.33333333333) 151 | #naive: WrappedArray(3717.0) / WrappedArray(308656.66666666666)", "40"), 152 | parse_runtime("#knn: WrappedArray(13953.0) / WrappedArray(61519.0) 153 | #naive: WrappedArray(2677.333333333333) / WrappedArray(201512.3333333333)", "80"), 154 | parse_runtime("#knn: WrappedArray(14890.666666666666) / WrappedArray(40220.33333333333) 155 | #naive: WrappedArray(2776.0) / WrappedArray(175310.0)", "160") 156 | ) %>% 157 | gather(type, time, -ns) %>% 158 | separate(type, c("type", "impl")) %>% 159 | mutate(time = time / 1e3, impl = factor(impl)) 160 | ``` 161 | ```{r horizontal-plot} 162 | horizontal_runtimes %>% 163 | group_by(impl, ns) %>% 164 | summarise(time = sum(time)) %>% 165 | arrange(ns) %>% 166 | mutate(speedup = first(time) / time) %>% 167 | ungroup() %>% 168 | ggplot() + 169 | aes(ns, speedup, color = impl, shape = impl) + 170 | geom_line() + 171 | geom_point() + 172 | theme_bw() + 173 | labs(x = "number of cores", title = "MNIST 60k Data on c3.4xlarge cluster") 174 | ``` 175 | 176 | Ideally we want the algorithm to scale linearly and we can see our kNN implementation scales quite linearly up to 80 cores The diminishing returns is likely attributed to the low number of observations. For 160 cores, each core is merely responsible for `r 60e3/160` observations on average. 
In practice, we were able to scale the implementation much better on hundreds of millions of observations with thousands of cores. 177 | 178 | Note: The naive implementation scales much more poorly because some tasks randomly decide to read from the network. 179 | 180 | -------------------------------------------------------------------------------- /data/mnist/benchmark.md: -------------------------------------------------------------------------------- 1 | # kNN Benchmark 2 | 3 | 4 | 5 | We use [MNIST data](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html) to benchmark our kNN implementation. Recall that our kNN implementation is based on a hybrid spill tree; more details can be found in the main 6 | README. 7 | 8 | For tests with fewer than 60k observations, we use the regular MNIST. For those with more than 60k 9 | observations, we opted for the *mnist8m* processed dataset. 10 | 11 | ## kNN on local Spark 12 | We first compare kNN runtime performance in Spark local mode. 13 | ![](benchmark_files/figure-html/pre-run%20local-1.png) 14 | While the spill-tree implementation has a much larger overhead, the savings in search efficiency quickly 15 | trump the naive brute-force approach when n gets larger. 16 | 17 | ## kNN on local R 18 | For perspective, we also ran kNN in R using RANN, which is based on a KD-tree, and knn from the class package, which is brute-force based. 19 | 20 | Note: all Spark benchmarks are averages of three runs, while the local R benchmark numbers are even less scientific, coming from a single run. 21 | 22 | 23 | 24 | 25 | ![](benchmark_files/figure-html/local-plot-1.png) 26 | 27 | ## kNN on Spark Cluster 28 | Next we test our kNN on an AWS cluster of 10 c3.4xlarge nodes (160 cores in total). 29 | 30 | Note that for larger n, we only ran the spill-tree algorithm due to the much longer runtime of the naive approach. 31 | 32 | 33 | 34 | ![](benchmark_files/figure-html/cluster-plot-1.png) 35 | 36 | Notice the y-axis is on a log scale. 37 | 38 | ## Horizontal Scalability 39 | 40 | Finally, we examine how the algorithm scales with the number of cores. Again this is using AWS c3.4xlarge nodes. 41 | 42 | 43 | ![](benchmark_files/figure-html/horizontal-plot-1.png) 44 | 45 | Ideally we want the algorithm to scale linearly, and we can see our kNN implementation scales quite linearly up to 80 cores. The diminishing returns are likely attributable to the low number of observations. For 160 cores, each core is merely responsible for 375 observations on average. In practice, we were able to scale the implementation much better on hundreds of millions of observations with thousands of cores. 46 | 47 | Note: The naive implementation scales much more poorly because some tasks randomly decide to read from the network.
48 | 49 | -------------------------------------------------------------------------------- /data/mnist/benchmark.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b71287203aea8fd91ea80b1fba931ed7d6be2b2f0ff9386c477160a682faf5b6 3 | size 216467 4 | -------------------------------------------------------------------------------- /data/mnist/benchmark_files/figure-html/cluster-plot-1.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:db02263a3bbc76abf807ec73a456135c25be71199e849e6463f2caf0add8b3bb 3 | size 44208 4 | -------------------------------------------------------------------------------- /data/mnist/benchmark_files/figure-html/horizontal-plot-1.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1c9e74ae76ce8a0bf07e6292d91a78eb73a46e5c00a00c370479d089d2abfeb8 3 | size 39191 4 | -------------------------------------------------------------------------------- /data/mnist/benchmark_files/figure-html/local-plot-1.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2b0a04a74d8b0cbff3a7d27535498ef39d010b4f7fb01e2e216ab5f8937a41ec 3 | size 46682 4 | -------------------------------------------------------------------------------- /data/mnist/benchmark_files/figure-html/pre-run local-1.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b6b033ccc14ecd948face22fe6d681dcf1e649964b21a8e1dab26d55dde8029f 3 | size 46211 4 | -------------------------------------------------------------------------------- /data/mnist/knn.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saurfang/spark-knn/f87544dcd1f417bab32b42071aa3e4e3f13d2d5d/data/mnist/knn.rds -------------------------------------------------------------------------------- /data/mnist/mnist.bz2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:016b0b675ac6bff78ddb60444db7bef4cb42e3d8b307b12d7f62708160d044f7 3 | size 15179306 4 | -------------------------------------------------------------------------------- /data/mnist/mnist.csv.gz: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3c06183a4876a6364923d297797d0ed6cb84337e2061bf1b250211a0af323e37 3 | size 13266494 4 | -------------------------------------------------------------------------------- /data/mnist/rann.rds: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saurfang/spark-knn/f87544dcd1f417bab32b42071aa3e4e3f13d2d5d/data/mnist/rann.rds -------------------------------------------------------------------------------- /dev/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # install sbt and git 4 | curl https://bintray.com/sbt/rpm/rpm | sudo tee /etc/yum.repos.d/bintray-sbt-rpm.repo 5 | sudo yum install sbt git 6 | 7 | # download mnsit data (8MM observations) 8 | #wget 
https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist8m.bz2 9 | curl -vs https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist.bz2 | hadoop fs -put - mnist.bz2 10 | curl -vs https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/mnist8m.bz2 | bzcat | hadoop fs -put - mnist8m 11 | 12 | # clone spark-knn 13 | git clone -b benchmark https://github.com/saurfang/spark-knn.git 14 | cd spark-knn 15 | sbt examples/assembly 16 | 17 | # test all models for small n 18 | #knn: WrappedArray(7342.333333333333, 4962.666666666666, 5370.0, 5151.333333333333, 6091.333333333333, 8506.666666666666) / WrappedArray(6017.333333333333, 7072.0, 8856.0, 9742.666666666666, 15817.0, 32105.0) 19 | #naive: WrappedArray(1023.6666666666666, 627.6666666666666, 786.3333333333333, 797.3333333333333, 1201.0, 1873.3333333333333) / WrappedArray(4116.333333333333, 4601.666666666666, 5782.0, 7866.666666666666, 19555.666666666664, 70148.33333333333) 20 | spark-submit --master yarn --num-executors 20 --executor-cores 4 --executor-memory 5000m \ 21 | --class com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark \ 22 | spark-knn-examples/target/scala-2.11/spark-knn-examples-assembly-0.1-SNAPSHOT.jar \ 23 | 2500,5000,7500,10000,20000,40000 mnist.bz2 20 24 | 25 | # test tree only for large n 26 | 27 | #knn: WrappedArray(19883.333333333332, 28345.0, 50083.0) / WrappedArray(63017.33333333333, 189466.66666666666, 641681.3333333333) 28 | spark-submit --master yarn --num-executors 20 --executor-cores 8 --executor-memory 10000m \ 29 | --class com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark \ 30 | spark-knn-examples/target/scala-2.11/spark-knn-examples-assembly-0.1-SNAPSHOT.jar \ 31 | 80000,160000,320000 mnist8m 150 tree 32 | 33 | spark-submit --master yarn --num-executors 20 --executor-cores 8 --executor-memory 10000m \ 34 | --class com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark \ 35 | spark-knn-examples/target/scala-2.11/spark-knn-examples-assembly-0.1-SNAPSHOT.jar \ 36 | 640000,1280000 mnist8m 200 tree 37 | 38 | spark-submit --master yarn --num-executors 20 --executor-cores 8 --executor-memory 10000m \ 39 | --class com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark \ 40 | spark-knn-examples/target/scala-2.11/spark-knn-examples-assembly-0.1-SNAPSHOT.jar \ 41 | 2560000,5120000 mnist8m 200 tree 42 | 43 | # benchmark horizontal scalability 44 | 45 | #knn: WrappedArray(15343.666666666666) / WrappedArray(277521.6666666666) 46 | #naive: WrappedArray(3510.333333333333) / WrappedArray(777180.0) 47 | spark-submit --master yarn --num-executors 10 --executor-cores 2 --executor-memory 2500m \ 48 | --class com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark \ 49 | spark-knn-examples/target/scala-2.11/spark-knn-examples-assembly-0.1-SNAPSHOT.jar \ 50 | 100000 mnist.bz2 10 51 | 52 | #knn: WrappedArray(15357.0) / WrappedArray(127080.33333333333) 53 | #naive: WrappedArray(3717.0) / WrappedArray(308656.66666666666) 54 | spark-submit --master yarn --num-executors 20 --executor-cores 2 --executor-memory 2500m \ 55 | --class com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark \ 56 | spark-knn-examples/target/scala-2.11/spark-knn-examples-assembly-0.1-SNAPSHOT.jar \ 57 | 100000 mnist.bz2 25 58 | 59 | #knn: WrappedArray(13953.0) / WrappedArray(61519.0) 60 | #naive: WrappedArray(2677.333333333333) / WrappedArray(201512.3333333333) 61 | spark-submit --master yarn --num-executors 40 --executor-cores 2 --executor-memory 2500m \ 62 | --class 
com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark \ 63 | spark-knn-examples/target/scala-2.11/spark-knn-examples-assembly-0.1-SNAPSHOT.jar \ 64 | 100000 mnist.bz2 50 65 | 66 | #knn: WrappedArray(14890.666666666666) / WrappedArray(40220.33333333333) 67 | #naive: WrappedArray(2776.0) / WrappedArray(175310.0) 68 | spark-submit --master yarn --num-executors 20 --executor-cores 8 --executor-memory 10000m \ 69 | --class com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark \ 70 | spark-knn-examples/target/scala-2.11/spark-knn-examples-assembly-0.1-SNAPSHOT.jar \ 71 | 100000 mnist.bz2 100 72 | -------------------------------------------------------------------------------- /project/Common.scala: -------------------------------------------------------------------------------- 1 | import com.typesafe.sbt.GitVersioning 2 | import sbt._ 3 | import Keys._ 4 | import com.typesafe.sbt.GitPlugin.autoImport._ 5 | import sbtsparkpackage.SparkPackagePlugin.autoImport._ 6 | 7 | import scala.language.experimental.macros 8 | import scala.reflect.macros.Context 9 | 10 | object Common { 11 | val commonSettings = Seq( 12 | organization in ThisBuild := "com.github.saurfang", 13 | javacOptions ++= Seq("-source", "1.8", "-target", "1.8"), 14 | scalacOptions ++= Seq("-target:jvm-1.8", "-deprecation", "-feature"), 15 | //git.useGitDescribe := true, 16 | git.baseVersion := "0.0.1", 17 | parallelExecution in test := false, 18 | updateOptions := updateOptions.value.withCachedResolution(true), 19 | sparkVersion := "3.0.1", 20 | sparkComponents += "mllib", 21 | spIgnoreProvided := true 22 | ) 23 | 24 | def knnProject(path: String): Project = macro knnProjectMacroImpl 25 | 26 | def knnProjectMacroImpl(c: Context)(path: c.Expr[String]) = { 27 | import c.universe._ 28 | reify { 29 | (Project.projectMacroImpl(c).splice in file(path.splice)). 30 | enablePlugins(GitVersioning). 31 | settings(name := path.splice). 32 | settings(Dependencies.Versions). 
33 | settings(commonSettings) 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /project/Dependencies.scala: -------------------------------------------------------------------------------- 1 | import sbt._ 2 | import Keys._ 3 | 4 | object Dependencies { 5 | val Versions = Seq( 6 | crossScalaVersions := Seq("2.12.8", "2.11.12"), 7 | scalaVersion := crossScalaVersions.value.head 8 | ) 9 | 10 | object Compile { 11 | val breeze_natives = "org.scalanlp" %% "breeze-natives" % "1.0" % "provided" 12 | 13 | object Test { 14 | val scalatest = "org.scalatest" %% "scalatest" % "3.1.0" % "test" 15 | val sparktest = "org.apache.spark" %% "spark-core" % "3.0.1" % "test" classifier "tests" 16 | } 17 | } 18 | 19 | import Compile._ 20 | import Test._ 21 | val l = libraryDependencies 22 | 23 | val core = l ++= Seq(scalatest, sparktest) 24 | val examples = core +: (l ++= Seq(breeze_natives)) 25 | } 26 | -------------------------------------------------------------------------------- /project/SparkSubmit.scala: -------------------------------------------------------------------------------- 1 | import sbtsparksubmit.SparkSubmitPlugin.autoImport._ 2 | 3 | object SparkSubmit { 4 | lazy val settings = 5 | SparkSubmitSetting( 6 | SparkSubmitSetting("sparkMNIST", 7 | Seq( 8 | "--master", "local[3]", 9 | "--class", "com.github.saurfang.spark.ml.knn.examples.MNIST" 10 | ) 11 | ), 12 | SparkSubmitSetting("sparkMNISTCross", 13 | Seq( 14 | "--master", "local[3]", 15 | "--class", "com.github.saurfang.spark.ml.knn.examples.MNISTCrossValidation" 16 | ) 17 | ), 18 | SparkSubmitSetting("sparkMNISTBench", 19 | Seq( 20 | "--master", "local[3]", 21 | "--class", "com.github.saurfang.spark.ml.knn.examples.MNISTBenchmark" 22 | ) 23 | ) 24 | ) 25 | } 26 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 0.13.18 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.3") 2 | 3 | addSbtPlugin("me.lessis" % "bintray-sbt" % "0.3.0") 4 | 5 | addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.8.5") 6 | 7 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.3") 8 | 9 | addSbtPlugin("com.github.saurfang" % "sbt-spark-submit" % "0.0.4") 10 | 11 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0") 12 | 13 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0" 14 | excludeAll ExclusionRule(organization = "com.danieltrinh")) 15 | libraryDependencies += "org.scalariform" %% "scalariform" % "0.1.8" 16 | 17 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/" 18 | addSbtPlugin("org.spark-packages" % "sbt-spark-package" % "0.2.6") 19 | 20 | addSbtPlugin("ch.epfl.scala" % "sbt-scalafix" % "0.9.4") 21 | -------------------------------------------------------------------------------- /python/pyspark_knn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saurfang/spark-knn/f87544dcd1f417bab32b42071aa3e4e3f13d2d5d/python/pyspark_knn/__init__.py -------------------------------------------------------------------------------- /python/pyspark_knn/ml/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/saurfang/spark-knn/f87544dcd1f417bab32b42071aa3e4e3f13d2d5d/python/pyspark_knn/ml/__init__.py -------------------------------------------------------------------------------- /python/pyspark_knn/ml/classification.py: -------------------------------------------------------------------------------- 1 | from pyspark.ml.wrapper import JavaEstimator, JavaModel 2 | from pyspark.ml.param.shared import * 3 | from pyspark.mllib.common import inherit_doc 4 | from pyspark import keyword_only 5 | 6 | 7 | @inherit_doc 8 | class KNNClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, 9 | HasProbabilityCol, HasRawPredictionCol, HasInputCols, 10 | HasThresholds, HasSeed, HasWeightCol): 11 | @keyword_only 12 | def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", 13 | seed=None, topTreeSize=1000, topTreeLeafSize=10, subTreeLeafSize=30, bufferSize=-1.0, 14 | bufferSizeSampleSize=list(range(100, 1000 + 1, 100)), balanceThreshold=0.7, 15 | k=5, neighborsCol="neighbors", maxNeighbors=float("inf"), rawPredictionCol="rawPrediction", 16 | probabilityCol="probability"): 17 | super(KNNClassifier, self).__init__() 18 | self._java_obj = self._new_java_obj( 19 | "org.apache.spark.ml.classification.KNNClassifier", self.uid) 20 | 21 | self.topTreeSize = Param(self, "topTreeSize", "number of points to sample for top-level tree") 22 | self.topTreeLeafSize = Param(self, "topTreeLeafSize", 23 | "number of points at which to switch to brute-force for top-level tree") 24 | self.subTreeLeafSize = Param(self, "subTreeLeafSize", 25 | "number of points at which to switch to brute-force for distributed sub-trees") 26 | self.bufferSize = Param(self, "bufferSize", 27 | "size of buffer used to construct spill trees and top-level tree search") 28 | self.bufferSizeSampleSize = Param(self, "bufferSizeSampleSize", 29 | "number of sample sizes to take when estimating buffer size") 30 | self.balanceThreshold = Param(self, "balanceThreshold", 31 | "fraction of total points at which spill tree reverts back to metric tree if " 32 | "either child contains more points") 33 | self.k = Param(self, "k", "number of neighbors to find") 34 | self.neighborsCol = Param(self, "neighborsCol", "column names for returned neighbors") 35 | self.maxNeighbors = Param(self, "maxNeighbors", "maximum distance to find neighbors") 36 | 37 | self._setDefault(topTreeSize=1000, topTreeLeafSize=10, subTreeLeafSize=30, bufferSize=-1.0, 38 | bufferSizeSampleSize=list(range(100, 1000 + 1, 100)), balanceThreshold=0.7, 39 | k=5, neighborsCol="neighbors", maxNeighbors=float("inf")) 40 | 41 | kwargs = self._input_kwargs 42 | self.setParams(**kwargs) 43 | 44 | @keyword_only 45 | def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", 46 | seed=None, topTreeSize=1000, topTreeLeafSize=10, subTreeLeafSize=30, bufferSize=-1.0, 47 | bufferSizeSampleSize=list(range(100, 1000 + 1, 100)), balanceThreshold=0.7, 48 | k=5, neighborsCol="neighbors", maxNeighbors=float("inf"), rawPredictionCol="rawPrediction", 49 | probabilityCol="probability"): 50 | kwargs = self._input_kwargs 51 | return self._set(**kwargs) 52 | 53 | def _create_model(self, java_model): 54 | return KNNClassificationModel(java_model) 55 | 56 | 57 | class KNNClassificationModel(JavaModel): 58 | """ 59 | Model fitted by KNNClassifier. 
60 | """ 61 | def __init__(self, java_model): 62 | super(KNNClassificationModel, self).__init__(java_model) 63 | 64 | # note: look at https://issues.apache.org/jira/browse/SPARK-10931 in the future 65 | self.bufferSize = Param(self, "bufferSize", 66 | "size of buffer used to construct spill trees and top-level tree search") 67 | self.k = Param(self, "k", "number of neighbors to find") 68 | self.neighborsCol = Param(self, "neighborsCol", "column names for returned neighbors") 69 | self.maxNeighbors = Param(self, "maxNeighbors", "maximum distance to find neighbors") 70 | 71 | self._transfer_params_from_java() 72 | -------------------------------------------------------------------------------- /python/pyspark_knn/ml/regression.py: -------------------------------------------------------------------------------- 1 | from pyspark.ml.wrapper import JavaEstimator, JavaModel 2 | from pyspark.ml.param.shared import * 3 | from pyspark.mllib.common import inherit_doc 4 | from pyspark import keyword_only 5 | 6 | 7 | @inherit_doc 8 | class KNNRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, 9 | HasInputCols, HasThresholds, HasSeed, HasWeightCol): 10 | @keyword_only 11 | def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", 12 | seed=None, topTreeSize=1000, topTreeLeafSize=10, subTreeLeafSize=30, bufferSize=-1.0, 13 | bufferSizeSampleSize=list(range(100, 1000 + 1, 100)), balanceThreshold=0.7, 14 | k=5, neighborsCol="neighbors", maxNeighbors=float("inf")): 15 | super(KNNRegression, self).__init__() 16 | self._java_obj = self._new_java_obj( 17 | "org.apache.spark.ml.regression.KNNRegression", self.uid) 18 | 19 | self.topTreeSize = Param(self, "topTreeSize", "number of points to sample for top-level tree") 20 | self.topTreeLeafSize = Param(self, "topTreeLeafSize", 21 | "number of points at which to switch to brute-force for top-level tree") 22 | self.subTreeLeafSize = Param(self, "subTreeLeafSize", 23 | "number of points at which to switch to brute-force for distributed sub-trees") 24 | self.bufferSize = Param(self, "bufferSize", 25 | "size of buffer used to construct spill trees and top-level tree search") 26 | self.bufferSizeSampleSize = Param(self, "bufferSizeSampleSize", 27 | "number of sample sizes to take when estimating buffer size") 28 | self.balanceThreshold = Param(self, "balanceThreshold", 29 | "fraction of total points at which spill tree reverts back to metric tree if " 30 | "either child contains more points") 31 | self.k = Param(self, "k", "number of neighbors to find") 32 | self.neighborsCol = Param(self, "neighborsCol", "column names for returned neighbors") 33 | self.maxNeighbors = Param(self, "maxNeighbors", "maximum distance to find neighbors") 34 | 35 | self._setDefault(topTreeSize=1000, topTreeLeafSize=10, subTreeLeafSize=30, bufferSize=-1.0, 36 | bufferSizeSampleSize=list(range(100, 1000 + 1, 100)), balanceThreshold=0.7, 37 | k=5, neighborsCol="neighbors", maxNeighbors=float("inf")) 38 | 39 | kwargs = self._input_kwargs 40 | self.setParams(**kwargs) 41 | 42 | @keyword_only 43 | def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", 44 | seed=None, topTreeSize=1000, topTreeLeafSize=10, subTreeLeafSize=30, bufferSize=-1.0, 45 | bufferSizeSampleSize=list(range(100, 1000 + 1, 100)), balanceThreshold=0.7, 46 | k=5, neighborsCol="neighbors", maxNeighbors=float("inf")): 47 | kwargs = self._input_kwargs 48 | return self._set(**kwargs) 49 | 50 | def _create_model(self, java_model): 51 | return 
KNNRegressionModel(java_model) 52 | 53 | 54 | class KNNRegressionModel(JavaModel): 55 | """ 56 | Model fitted by KNNRegression. 57 | """ 58 | def __init__(self, java_model): 59 | super(KNNRegressionModel, self).__init__(java_model) 60 | 61 | # note: look at https://issues.apache.org/jira/browse/SPARK-10931 in the future 62 | self.bufferSize = Param(self, "bufferSize", 63 | "size of buffer used to construct spill trees and top-level tree search") 64 | self.k = Param(self, "k", "number of neighbors to find") 65 | self.neighborsCol = Param(self, "neighborsCol", "column names for returned neighbors") 66 | self.maxNeighbors = Param(self, "maxNeighbors", "maximum distance to find neighbors") 67 | 68 | self._transfer_params_from_java() 69 | -------------------------------------------------------------------------------- /python/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | setup( 3 | name="pyspark_knn", 4 | url="https://github.com/saurfang/spark-knn", 5 | version="0.1", 6 | zip_safe=True, 7 | packages=find_packages(), 8 | ) 9 | -------------------------------------------------------------------------------- /python/test.py: -------------------------------------------------------------------------------- 1 | from pyspark import SparkContext 2 | from pyspark.sql import SQLContext 3 | from pyspark.mllib.linalg import Vectors 4 | 5 | from pyspark_knn.ml.classification import KNNClassifier 6 | 7 | 8 | # This is a simple test app. Use the following command to run: 9 | # spark-submit --driver-class-path ../spark-knn-core/target/scala-2.11/spark-knn_*.jar test.py 10 | 11 | sc = SparkContext(appName='test') 12 | sqlContext = SQLContext(sc) 13 | 14 | print('Initializing') 15 | training = sqlContext.createDataFrame([ 16 | [Vectors.dense([0.2, 0.9]), 0.0], 17 | [Vectors.dense([0.2, 1.0]), 0.0], 18 | [Vectors.dense([0.2, 0.1]), 1.0], 19 | [Vectors.dense([0.2, 0.2]), 1.0], 20 | ], ['features', 'label']) 21 | 22 | test = sqlContext.createDataFrame([ 23 | [Vectors.dense([0.1, 0.0])], 24 | [Vectors.dense([0.3, 0.8])] 25 | ], ['features']) 26 | 27 | knn = KNNClassifier(k=1, topTreeSize=1, topTreeLeafSize=1, subTreeLeafSize=1, bufferSizeSampleSize=[1, 2, 3]) # bufferSize=-1.0, 28 | print('Params:', [p.name for p in knn.params]) 29 | print('Fitting') 30 | model = knn.fit(training) 31 | print('bufferSize:', model._java_obj.getBufferSize()) 32 | print('Predicting') 33 | predictions = model.transform(test) 34 | print('Predictions:') 35 | for row in predictions.collect(): 36 | print(row) 37 | -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- 1 | 2 | Scalastyle standard configuration 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /spark-knn-core/src/main/scala/org/apache/spark/ml/classification/KNNClassifier.scala: 
-------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.classification 2 | 3 | import org.apache.spark.broadcast.Broadcast 4 | import org.apache.spark.ml.knn._ 5 | import org.apache.spark.ml.param.ParamMap 6 | import org.apache.spark.ml.param.shared.HasWeightCol 7 | import org.apache.spark.ml.util.{Identifiable, SchemaUtils} 8 | import org.apache.spark.ml.linalg._ 9 | import org.apache.spark.ml.feature.LabeledPoint 10 | import org.apache.spark.rdd.RDD 11 | import org.apache.spark.sql.types.{DoubleType, StructType} 12 | import org.apache.spark.sql.{DataFrame, Dataset, Row} 13 | import org.apache.spark.storage.StorageLevel 14 | import org.apache.spark.SparkException 15 | 16 | import scala.collection.mutable.ArrayBuffer 17 | 18 | /** 19 | * [[https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm]] for classification. 20 | * An object is classified by a majority vote of its neighbors, with the object being assigned to 21 | * the class most common among its k nearest neighbors. 22 | */ 23 | class KNNClassifier(override val uid: String) extends ProbabilisticClassifier[Vector, KNNClassifier, KNNClassificationModel] 24 | with KNNParams with HasWeightCol { 25 | 26 | def this() = this(Identifiable.randomUID("knnc")) 27 | 28 | /** @group setParam */ 29 | override def setFeaturesCol(value: String): this.type = set(featuresCol, value) 30 | 31 | /** @group setParam */ 32 | override def setLabelCol(value: String): this.type = { 33 | set(labelCol, value) 34 | 35 | if ($(weightCol).isEmpty) { 36 | set(inputCols, Array(value)) 37 | } else { 38 | set(inputCols, Array(value, $(weightCol))) 39 | } 40 | } 41 | 42 | //fill in default label col 43 | setDefault(inputCols, Array($(labelCol))) 44 | 45 | /** @group setWeight */ 46 | def setWeightCol(value: String): this.type = { 47 | set(weightCol, value) 48 | 49 | if (value.isEmpty) { 50 | set(inputCols, Array($(labelCol))) 51 | } else { 52 | set(inputCols, Array($(labelCol), value)) 53 | } 54 | } 55 | 56 | setDefault(weightCol -> "") 57 | 58 | /** @group setParam */ 59 | def setK(value: Int): this.type = set(k, value) 60 | 61 | /** @group setParam */ 62 | def setTopTreeSize(value: Int): this.type = set(topTreeSize, value) 63 | 64 | /** @group setParam */ 65 | def setTopTreeLeafSize(value: Int): this.type = set(topTreeLeafSize, value) 66 | 67 | /** @group setParam */ 68 | def setSubTreeLeafSize(value: Int): this.type = set(subTreeLeafSize, value) 69 | 70 | /** @group setParam */ 71 | def setBufferSizeSampleSizes(value: Array[Int]): this.type = set(bufferSizeSampleSizes, value) 72 | 73 | /** @group setParam */ 74 | def setBalanceThreshold(value: Double): this.type = set(balanceThreshold, value) 75 | 76 | /** @group setParam */ 77 | def setSeed(value: Long): this.type = set(seed, value) 78 | 79 | override protected def train(dataset: Dataset[_]): KNNClassificationModel = { 80 | // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
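    // The label column (plus the optional weight column) was registered above as the auxiliary
    // inputCols, so the underlying KNN model returns it for every neighbor; here we only need
    // (label, features) pairs to validate the labels and count the number of classes.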
81 | val instances = extractLabeledPoints(dataset).map { 82 | case LabeledPoint(label: Double, features: Vector) => (label, features) 83 | } 84 | val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE 85 | if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) 86 | 87 | val labelSummarizer = instances.treeAggregate( 88 | new MultiClassSummarizer)( 89 | seqOp = (c, v) => (c, v) match { 90 | case (labelSummarizer: MultiClassSummarizer, (label: Double, features: Vector)) => 91 | labelSummarizer.add(label) 92 | }, 93 | combOp = (c1, c2) => (c1, c2) match { 94 | case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => 95 | classSummarizer1.merge(classSummarizer2) 96 | }) 97 | 98 | val histogram = labelSummarizer.histogram 99 | val numInvalid = labelSummarizer.countInvalid 100 | val numClasses = histogram.length 101 | 102 | if (numInvalid != 0) { 103 | val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + 104 | s"Found $numInvalid invalid labels." 105 | logError(msg) 106 | throw new SparkException(msg) 107 | } 108 | 109 | val knnModel = copyValues(new KNN()).fit(dataset) 110 | knnModel.toNewClassificationModel(uid, numClasses) 111 | } 112 | 113 | override def fit(dataset: Dataset[_]): KNNClassificationModel = { 114 | // Need to overwrite this method because we need to manually overwrite the buffer size 115 | // because it is not supposed to stay the same as the Classifier if user sets it to -1. 116 | transformSchema(dataset.schema, logging = true) 117 | val model = train(dataset) 118 | val bufferSize = model.getBufferSize 119 | copyValues(model.setParent(this)).setBufferSize(bufferSize) 120 | } 121 | 122 | override def copy(extra: ParamMap): KNNClassifier = defaultCopy(extra) 123 | } 124 | 125 | class KNNClassificationModel private[ml]( 126 | override val uid: String, 127 | val topTree: Broadcast[Tree], 128 | val subTrees: RDD[Tree], 129 | val _numClasses: Int 130 | ) extends ProbabilisticClassificationModel[Vector, KNNClassificationModel] 131 | with KNNModelParams with HasWeightCol with Serializable { 132 | require(subTrees.getStorageLevel != StorageLevel.NONE, 133 | "KNNModel is not designed to work with Trees that have not been cached") 134 | 135 | /** @group setParam */ 136 | def setK(value: Int): this.type = set(k, value) 137 | 138 | /** @group setParam */ 139 | def setBufferSize(value: Double): this.type = set(bufferSize, value) 140 | 141 | override def numClasses: Int = _numClasses 142 | 143 | //TODO: This can benefit from DataSet API 144 | override def transform(dataset: Dataset[_]): DataFrame = { 145 | val getWeight: Row => Double = { 146 | if($(weightCol).isEmpty) { 147 | r => 1.0 148 | } else { 149 | r => r.getDouble(1) 150 | } 151 | } 152 | 153 | val neighborRDD : RDD[(Long, Array[(Row, Double)])] = transform(dataset, topTree, subTrees) 154 | val merged = neighborRDD 155 | .map { 156 | case (id, labelsDists) => 157 | val (labels, _) = labelsDists.unzip 158 | val vector = new Array[Double](numClasses) 159 | var i = 0 160 | while (i < labels.length) { 161 | vector(labels(i).getDouble(0).toInt) += getWeight(labels(i)) 162 | i += 1 163 | } 164 | val rawPrediction = Vectors.dense(vector) 165 | lazy val probability = raw2probability(rawPrediction) 166 | lazy val prediction = probability2prediction(probability) 167 | 168 | val values = new ArrayBuffer[Any] 169 | if ($(rawPredictionCol).nonEmpty) { 170 | values.append(rawPrediction) 171 | } 172 | if ($(probabilityCol).nonEmpty) { 173 | 
values.append(probability) 174 | } 175 | if ($(predictionCol).nonEmpty) { 176 | values.append(prediction) 177 | } 178 | 179 | (id, values) 180 | } 181 | 182 | dataset.sqlContext.createDataFrame( 183 | dataset.toDF().rdd.zipWithIndex().map { case (row, i) => (i, row) } 184 | .leftOuterJoin(merged) //make sure we don't lose any observations 185 | .map { 186 | case (i, (row, values)) => Row.fromSeq(row.toSeq ++ values.get) 187 | }, 188 | transformSchema(dataset.schema) 189 | ) 190 | } 191 | 192 | override def transformSchema(schema: StructType): StructType = { 193 | var transformed = schema 194 | if ($(rawPredictionCol).nonEmpty) { 195 | transformed = SchemaUtils.appendColumn(transformed, $(rawPredictionCol), new VectorUDT) 196 | } 197 | if ($(probabilityCol).nonEmpty) { 198 | transformed = SchemaUtils.appendColumn(transformed, $(probabilityCol), new VectorUDT) 199 | } 200 | if ($(predictionCol).nonEmpty) { 201 | transformed = SchemaUtils.appendColumn(transformed, $(predictionCol), DoubleType) 202 | } 203 | transformed 204 | } 205 | 206 | override def copy(extra: ParamMap): KNNClassificationModel = { 207 | val copied = new KNNClassificationModel(uid, topTree, subTrees, numClasses) 208 | copyValues(copied, extra).setParent(parent) 209 | } 210 | 211 | override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { 212 | rawPrediction match { 213 | case dv: DenseVector => 214 | var i = 0 215 | val size = dv.size 216 | 217 | var sum = 0.0 218 | while (i < size) { 219 | sum += dv.values(i) 220 | i += 1 221 | } 222 | 223 | i = 0 224 | while (i < size) { 225 | dv.values(i) /= sum 226 | i += 1 227 | } 228 | 229 | dv 230 | case sv: SparseVector => 231 | throw new RuntimeException("Unexpected error in KNNClassificationModel:" + 232 | " raw2probabilitiesInPlace encountered SparseVector") 233 | } 234 | } 235 | 236 | override def predictRaw(features: Vector): Vector = { 237 | throw new SparkException("predictRaw function should not be called directly since kNN prediction is done in distributed fashion. 
Use transform instead.") 238 | } 239 | } 240 | -------------------------------------------------------------------------------- /spark-knn-core/src/main/scala/org/apache/spark/ml/knn/DistanceMetric.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.knn 2 | 3 | import org.apache.spark.ml.knn.KNN.VectorWithNorm 4 | import org.apache.spark.mllib.knn.KNNUtils 5 | 6 | 7 | object DistanceMetric { 8 | def apply(metric: String): DistanceMetric = { 9 | metric match { 10 | case "" | "euclidean" => EuclideanDistanceMetric 11 | case "nan_euclidean" => NaNEuclideanDistanceMetric 12 | case _ => throw new IllegalArgumentException(s"Unsupported distance metric: $metric") 13 | } 14 | } 15 | } 16 | trait DistanceMetric { 17 | def fastSquaredDistance(v1: VectorWithNorm, v2: VectorWithNorm): Double 18 | 19 | def fastDistance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { 20 | math.sqrt(fastSquaredDistance(v1, v2)) 21 | } 22 | } 23 | 24 | object EuclideanDistanceMetric extends DistanceMetric with Serializable { 25 | override def fastSquaredDistance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { 26 | KNNUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) 27 | } 28 | } 29 | 30 | /** 31 | * Calculate NaN-Euclidean distance by using only non NaN values in each vector 32 | */ 33 | object NaNEuclideanDistanceMetric extends DistanceMetric with Serializable { 34 | 35 | class InfiniteZeroIterator(step: Int = 1) extends Iterator[(Int, Double)] { 36 | var index = -1 * step 37 | override def hasNext: Boolean = true 38 | override def next(): (Int, Double) = { 39 | index += step 40 | (index, 0.0) 41 | } 42 | } 43 | override def fastSquaredDistance(v1: VectorWithNorm, v2: VectorWithNorm): Double = { 44 | var it1 = v1.vector.activeIterator 45 | var it2 = v2.vector.activeIterator 46 | if(!it1.hasNext && !it2.hasNext) return 0.0 47 | if(!it1.hasNext) { 48 | it1 = new InfiniteZeroIterator 49 | } else if(!it2.hasNext) { 50 | it2 = new InfiniteZeroIterator 51 | } 52 | var result = 0.0 53 | // initial case 54 | var (idx1, val1) = it1.next() 55 | var (idx2, val2) = it2.next() 56 | // iterator over the vectors 57 | while((it1.hasNext || it2.hasNext) && !(it1.isInstanceOf[InfiniteZeroIterator] && it2.isInstanceOf[InfiniteZeroIterator])) { 58 | var (advance1, advance2) = (false, false) 59 | val (left, right) = if(idx1 < idx2) { 60 | // advance iterator on first vector 61 | advance1 = true 62 | (val1, 0.0) 63 | } else if(idx1 > idx2) { 64 | // advance iterator on second vector 65 | advance2 = true 66 | (0.0, val2) 67 | } else { 68 | // indexes matches 69 | advance1 = true 70 | advance2 = true 71 | (val1, val2) 72 | } 73 | if(!left.isNaN && !right.isNaN) { 74 | result += Math.pow(left - right, 2) 75 | } 76 | if(advance1) { 77 | if(!it1.hasNext) it1 = new InfiniteZeroIterator 78 | val next1 = it1.next() 79 | idx1 = next1._1 80 | val1 = next1._2 81 | } 82 | if(advance2) { 83 | if(!it2.hasNext) it2 = new InfiniteZeroIterator 84 | val next2 = it2.next() 85 | idx2 = next2._1 86 | val2 = next2._2 87 | } 88 | } 89 | if(idx1 == idx2 && !val1.isNaN && !val2.isNaN) { 90 | result += Math.pow(val1 - val2, 2) 91 | } 92 | result 93 | } 94 | } -------------------------------------------------------------------------------- /spark-knn-core/src/main/scala/org/apache/spark/ml/knn/KNN.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.knn 2 | 3 | import breeze.linalg.{DenseVector, Vector => BV} 4 
| import breeze.stats._ 5 | import org.apache.spark.broadcast.Broadcast 6 | import org.apache.spark.ml.classification.KNNClassificationModel 7 | import org.apache.spark.ml.knn.KNN.{KNNPartitioner, RowWithVector, VectorWithNorm} 8 | import org.apache.spark.ml.param._ 9 | import org.apache.spark.ml.param.shared._ 10 | import org.apache.spark.ml.regression.KNNRegressionModel 11 | import org.apache.spark.ml.util._ 12 | import org.apache.spark.ml.{Estimator, Model} 13 | import org.apache.spark.ml.linalg.{Vector, VectorUDT, Vectors} 14 | import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ 15 | import org.apache.spark.rdd.{RDD, ShuffledRDD} 16 | import org.apache.spark.sql.types._ 17 | import org.apache.spark.sql.{DataFrame, Dataset, Row} 18 | import org.apache.spark.storage.StorageLevel 19 | import org.apache.spark.util.random.XORShiftRandom 20 | import org.apache.spark.{HashPartitioner, Partitioner} 21 | import org.apache.log4j 22 | import org.apache.spark.mllib.knn.KNNUtils 23 | 24 | import scala.annotation.tailrec 25 | import scala.collection.mutable.ArrayBuffer 26 | import scala.util.hashing.byteswap64 27 | 28 | // features column => vector, input columns => auxiliary columns to return by KNN model 29 | private[ml] trait KNNModelParams extends Params with HasFeaturesCol with HasInputCols { 30 | /** 31 | * Param for the column name for returned neighbors. 32 | * Default: "neighbors" 33 | * 34 | * @group param 35 | */ 36 | val neighborsCol = new Param[String](this, "neighborsCol", "column names for returned neighbors") 37 | 38 | /** @group getParam */ 39 | def getNeighborsCol: String = $(neighborsCol) 40 | 41 | /** 42 | * Param for distance column that will create a distance column of each nearest neighbor 43 | * Default: no distance column will be used 44 | * 45 | * @group param 46 | */ 47 | val distanceCol = new Param[String](this, "distanceCol", "column that includes each neighbors' distance as an additional column") 48 | 49 | /** @group getParam */ 50 | def getDistanceCol: String = $(distanceCol) 51 | 52 | /** 53 | * Param for number of neighbors to find (> 0). 54 | * Default: 5 55 | * 56 | * @group param 57 | */ 58 | val k = new IntParam(this, "k", "number of neighbors to find", ParamValidators.gt(0)) 59 | 60 | /** @group getParam */ 61 | def getK: Int = $(k) 62 | 63 | /** 64 | * Param for maximum distance to find neighbors 65 | * Default: Double.PositiveInfinity 66 | * 67 | * @group param 68 | */ 69 | val maxDistance = new DoubleParam(this, "maxNeighbors", "maximum distance to find neighbors", // todo: maxDistance or maxNeighbors? 70 | ParamValidators.gt(0)) 71 | 72 | /** @group getParam */ 73 | def getMaxDistance: Double = $(maxDistance) 74 | 75 | /** 76 | * Param for size of buffer used to construct spill trees and top-level tree search. 77 | * Note the buffer size is 2 * tau as described in the paper. 78 | * 79 | * When buffer size is 0.0, the tree itself reverts to a metric tree. 80 | * -1.0 triggers automatic effective nearest neighbor distance estimation. 81 | * 82 | * Default: -1.0 83 | * 84 | * @group param 85 | */ 86 | val bufferSize = new DoubleParam(this, "bufferSize", 87 | "size of buffer used to construct spill trees and top-level tree search", ParamValidators.gtEq(-1.0)) 88 | 89 | /** @group getParam */ 90 | def getBufferSize: Double = $(bufferSize) 91 | 92 | /** 93 | * Param for metric to use for distance calculation. 
94 | * Default: euclidean 95 | * 96 | * @group param 97 | */ 98 | val metric = new Param[String](this, "metric", 99 | "Distance metric for searching neighbors. Possible values: 'euclidean', 'nan_euclidean'") 100 | 101 | /** @group getParam */ 102 | def getMetric: String = $(metric) 103 | 104 | //fill in default distance metric 105 | setDefault(metric, "euclidean") 106 | 107 | private[ml] def transform(data: RDD[Vector], topTree: Broadcast[Tree], subTrees: RDD[Tree]): RDD[(Long, Array[(Row,Double)])] = { 108 | val searchData = data.zipWithIndex() 109 | .flatMap { 110 | case (vector, index) => 111 | val vectorWithNorm = new VectorWithNorm(vector) 112 | val idx = KNN.searchIndices(vectorWithNorm, topTree.value, $(bufferSize), distanceMetric=DistanceMetric($(metric))) 113 | .map(i => (i, (vectorWithNorm, index))) 114 | 115 | assert(idx.nonEmpty, s"indices must be non-empty: $vector ($index)") 116 | idx 117 | } 118 | .partitionBy(new HashPartitioner(subTrees.partitions.length)) 119 | 120 | // for each partition, search points within corresponding child tree 121 | val results = searchData.zipPartitions(subTrees) { 122 | (childData, trees) => 123 | val tree = trees.next() 124 | assert(!trees.hasNext) 125 | childData.flatMap { 126 | case (_, (point, i)) => 127 | tree.query(point, $(k)).collect { 128 | case (neighbor, distance) if distance <= $(maxDistance) => 129 | (i, (neighbor.row, distance)) 130 | } 131 | } 132 | } 133 | 134 | // merge results by point index together and keep topK results 135 | results.topByKey($(k))(Ordering.by(-_._2)) 136 | .map { case (i, seq) => (i, seq) } 137 | } 138 | 139 | private[ml] def transform(dataset: Dataset[_], topTree: Broadcast[Tree], subTrees: RDD[Tree]): RDD[(Long, Array[(Row, Double)])] = { 140 | transform(dataset.select($(featuresCol)).rdd.map(_.getAs[Vector](0)), topTree, subTrees) 141 | } 142 | 143 | } 144 | 145 | private[ml] trait KNNParams extends KNNModelParams with HasSeed { 146 | /** 147 | * Param for number of points to sample for top-level tree (> 0). 148 | * Default: 1000 149 | * 150 | * @group param 151 | */ 152 | val topTreeSize = new IntParam(this, "topTreeSize", "number of points to sample for top-level tree", ParamValidators.gt(0)) 153 | 154 | /** @group getParam */ 155 | def getTopTreeSize: Int = $(topTreeSize) 156 | 157 | /** 158 | * Param for number of points at which to switch to brute-force for top-level tree (> 0). 159 | * Default: 5 160 | * 161 | * @group param 162 | */ 163 | val topTreeLeafSize = new IntParam(this, "topTreeLeafSize", 164 | "number of points at which to switch to brute-force for top-level tree", ParamValidators.gt(0)) 165 | 166 | /** @group getParam */ 167 | def getTopTreeLeafSize: Int = $(topTreeLeafSize) 168 | 169 | /** 170 | * Param for number of points at which to switch to brute-force for distributed sub-trees (> 0). 171 | * Default: 20 172 | * 173 | * @group param 174 | */ 175 | val subTreeLeafSize = new IntParam(this, "subTreeLeafSize", 176 | "number of points at which to switch to brute-force for distributed sub-trees", ParamValidators.gt(0)) 177 | 178 | /** @group getParam */ 179 | def getSubTreeLeafSize: Int = $(subTreeLeafSize) 180 | 181 | /** 182 | * Param for number of sample sizes to take when estimating buffer size (at least two samples). 183 | * Default: 100 to 1000 by 100 184 | * 185 | * @group param 186 | */ 187 | val bufferSizeSampleSizes = new IntArrayParam(this, "bufferSizeSampleSize", // todo: should this have an 's' at the end? 
188 | "number of sample sizes to take when estimating buffer size", { arr: Array[Int] => arr.length > 1 && arr.forall(_ > 0) }) 189 | 190 | /** @group getParam */ 191 | def getBufferSizeSampleSizes: Array[Int] = $(bufferSizeSampleSizes) 192 | 193 | /** 194 | * Param for fraction of total points at which spill tree reverts back to metric tree 195 | * if either child contains more points (0 <= rho <= 1). 196 | * Default: 70% 197 | * 198 | * @group param 199 | */ 200 | val balanceThreshold = new DoubleParam(this, "balanceThreshold", 201 | "fraction of total points at which spill tree reverts back to metric tree if either child contains more points", 202 | ParamValidators.inRange(0, 1)) 203 | 204 | /** @group getParam */ 205 | def getBalanceThreshold: Double = $(balanceThreshold) 206 | 207 | setDefault(topTreeSize -> 1000, topTreeLeafSize -> 10, subTreeLeafSize -> 30, 208 | bufferSize -> -1.0, bufferSizeSampleSizes -> (100 to 1000 by 100).toArray, balanceThreshold -> 0.7, 209 | k -> 5, neighborsCol -> "neighbors", distanceCol -> "", maxDistance -> Double.PositiveInfinity) 210 | 211 | /** 212 | * Validates and transforms the input schema. 213 | * 214 | * @param schema input schema 215 | * @return output schema 216 | */ 217 | protected def validateAndTransformSchema(schema: StructType): StructType = { 218 | SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) 219 | val auxFeatures = $(inputCols).map(c => schema(c)) 220 | val schemaWithNeighbors = SchemaUtils.appendColumn(schema, $(neighborsCol), ArrayType(StructType(auxFeatures))) 221 | 222 | if ($(distanceCol).isEmpty) { 223 | schemaWithNeighbors 224 | } else { 225 | SchemaUtils.appendColumn(schemaWithNeighbors, $(distanceCol), ArrayType(DoubleType)) 226 | } 227 | } 228 | } 229 | 230 | /** 231 | * kNN Model facilitates k-Nestrest Neighbor search by storing distributed hybrid spill tree. 232 | * Top level tree is a MetricTree but instead of using back tracking, it searches all possible leaves in parallel 233 | * to avoid multiple iterations. It uses the same buffer size that is used in model training, when the search 234 | * vector falls into the buffer zone of the node, it dispatches search to both children. 235 | * 236 | * A high level overview of the search phases is as follows: 237 | * 238 | * 1. For each vector to search, go through the top level tree to output a pair of (index, point) 239 | * 1. Repartition search points by partition index 240 | * 1. Search each point through the hybrid spill tree in that particular partition 241 | * 1. For each point, merge results from different partitions and keep top k results. 
242 | * 243 | */ 244 | class KNNModel private[ml]( 245 | override val uid: String, 246 | val topTree: Broadcast[Tree], 247 | val subTrees: RDD[Tree] 248 | ) extends Model[KNNModel] with KNNModelParams { 249 | require(subTrees.getStorageLevel != StorageLevel.NONE, 250 | "KNNModel is not designed to work with Trees that have not been cached") 251 | 252 | /** @group setParam */ 253 | def setNeighborsCol(value: String): this.type = set(neighborsCol, value) 254 | 255 | /** @group setParam */ 256 | def setDistanceCol(value: String): this.type = set(distanceCol, value) 257 | 258 | /** @group setParam */ 259 | def setK(value: Int): this.type = set(k, value) 260 | 261 | /** @group setParam */ 262 | def setMaxDistance(value: Double): this.type = set(maxDistance, value) 263 | 264 | /** @group setParam */ 265 | def setBufferSize(value: Double): this.type = set(bufferSize, value) 266 | 267 | //TODO: All these can benefit from DataSet API 268 | override def transform(dataset: Dataset[_]): DataFrame = { 269 | val merged: RDD[(Long, Array[(Row,Double)])] = transform(dataset, topTree, subTrees) 270 | 271 | val withDistance = $(distanceCol).nonEmpty 272 | 273 | dataset.sqlContext.createDataFrame( 274 | dataset.toDF().rdd.zipWithIndex().map { case (row, i) => (i, row) } 275 | .leftOuterJoin(merged) 276 | .map { 277 | case (i, (row, neighborsAndDistances)) => 278 | val (neighbors, distances) = neighborsAndDistances.map(_.unzip).getOrElse((Array.empty[Row], Array.empty[Double])) 279 | if (withDistance) { 280 | Row.fromSeq(row.toSeq :+ neighbors :+ distances) 281 | } else { 282 | Row.fromSeq(row.toSeq :+ neighbors) 283 | } 284 | }, 285 | transformSchema(dataset.schema) 286 | ) 287 | } 288 | 289 | override def transformSchema(schema: StructType): StructType = { 290 | val auxFeatures = $(inputCols).map(c => schema(c)) 291 | val schemaWithNeighbors = SchemaUtils.appendColumn(schema, $(neighborsCol), ArrayType(StructType(auxFeatures))) 292 | if ($(distanceCol).isEmpty) { 293 | schemaWithNeighbors 294 | } else { 295 | SchemaUtils.appendColumn(schemaWithNeighbors, $(distanceCol), ArrayType(DoubleType)) 296 | } 297 | } 298 | 299 | override def copy(extra: ParamMap): KNNModel = { 300 | val copied = new KNNModel(uid, topTree, subTrees) 301 | copyValues(copied, extra).setParent(parent) 302 | } 303 | 304 | def toNewClassificationModel(uid: String, numClasses: Int): KNNClassificationModel = { 305 | copyValues(new KNNClassificationModel(uid, topTree, subTrees, numClasses)) 306 | } 307 | 308 | def toNewRegressionModel(uid: String): KNNRegressionModel = { 309 | copyValues(new KNNRegressionModel(uid, topTree, subTrees)) 310 | } 311 | } 312 | 313 | /** 314 | * k-Nearest Neighbors (kNN) algorithm 315 | * 316 | * kNN finds k closest observations in training dataset. It can be used for both classification and regression. 317 | * Furthermore it can also be used for other purposes such as input to clustering algorithm. 318 | * 319 | * While the brute-force approach requires no pre-training, each prediction requires going through the entire training 320 | * set resulting O(n log(k)) runtime per individual prediction using a heap keep track of neighbor candidates. 321 | * Many different implementations have been proposed such as Locality Sensitive Hashing (LSH), KD-Tree, Metric Tree and etc. 322 | * Each algorithm has its shortcomings that prevent them to be effective on large-scale and/or high-dimensional dataset. 
323 | * 324 | * This is an implementation of kNN based upon distributed Hybrid Spill-Trees where training points are organized into 325 | * distributed binary trees. The algorithm is designed to support accurate approximate kNN search but by tuning parameters 326 | * an exact search can also be performed with cost of additional runtime. 327 | * 328 | * Each binary tree node is either a 329 | * 330 | * '''Metric Node''': 331 | * Metric Node partition points exclusively into two children by finding two pivot points and divide by middle plane. 332 | * When searched, the child whose pivot is closer to query vector is searched first. Back tracking is required to 333 | * ensure accuracy in this case, where the other child should be searched if it can possibly contain better neighbor 334 | * based upon candidates picked during previous search. 335 | * 336 | * '''Spill Node''': 337 | * Spill Node also partitions points into two children however there are an overlapping buffer between the two pivot 338 | * points. The larger the buffer size, the less effective the node eliminates points thus could increase tree height. 339 | * When searched, defeatist search is used where only one child is searched and no back tracking happens in this 340 | * process. Because of the buffer between two children, we are likely to end up with good enough candidates without 341 | * searching the other part of the tree. 342 | * 343 | * While Spill Node promises O(h) runtime where h is the tree height, the tree is deeper than Metric Tree's O(log n) 344 | * height on average. Furthermore, when it comes down to leaves where points are more closer to each other, the static 345 | * buffer size means more points will end up in the buffer. Therefore a Balance Threshold (rho) is introduced: when 346 | * either child of Spill Node makes up more than rho fraction of the total points at this level, Spill Node is reverted 347 | * back to a Metric Node. 348 | * 349 | * A high level overview of the algorithm is as follows: 350 | * 351 | * 1. Sample M data points (M is relatively small and can be held in driver) 352 | * 1. Build the top level metric tree 353 | * 1. Repartition RDD by assigning each point to leaf node of the above tree 354 | * 1. Build a hybrid spill tree at each partition 355 | * 356 | * This concludes the training phase of kNN. 357 | * See [[KNNModel]] for details on prediction phase. 358 | * 359 | * 360 | * This algorithm is described in [[http://dx.doi.org/10.1109/WACV.2007.18]] where it was shown to scale well in terms of 361 | * number of observations and dimensions, bounded by the available memory across clusters (billions in paper's example). 362 | * This implementation adapts the MapReduce algorithm to work with Spark. 
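 *
 * A minimal training sketch (hypothetical DataFrame and column names; setters as defined below):
 * {{{
 *   val knn = new KNN()
 *     .setFeaturesCol("features")
 *     .setAuxCols(Array("id"))   // columns to return for each neighbor
 *     .setTopTreeSize(1000)
 *     .setK(10)
 *   val model = knn.fit(trainingDF)
 *   val withNeighbors = model.transform(queryDF)
 * }}}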
363 | * 364 | */ 365 | class KNN(override val uid: String) extends Estimator[KNNModel] with KNNParams { 366 | def this() = this(Identifiable.randomUID("knn")) 367 | 368 | /** @group setParam */ 369 | def setFeaturesCol(value: String): this.type = set(featuresCol, value) 370 | 371 | /** @group setParam */ 372 | def setK(value: Int): this.type = set(k, value) 373 | 374 | /** @group setParam */ 375 | def setAuxCols(value: Array[String]): this.type = set(inputCols, value) 376 | 377 | /** @group setParam */ 378 | def setTopTreeSize(value: Int): this.type = set(topTreeSize, value) 379 | 380 | /** @group setParam */ 381 | def setTopTreeLeafSize(value: Int): this.type = set(topTreeLeafSize, value) 382 | 383 | /** @group setParam */ 384 | def setSubTreeLeafSize(value: Int): this.type = set(subTreeLeafSize, value) 385 | 386 | /** @group setParam */ 387 | def setBufferSizeSampleSizes(value: Array[Int]): this.type = set(bufferSizeSampleSizes, value) 388 | 389 | /** @group setParam */ 390 | def setBalanceThreshold(value: Double): this.type = set(balanceThreshold, value) 391 | 392 | /** @group setParam */ 393 | def setSeed(value: Long): this.type = set(seed, value) 394 | 395 | override def fit(dataset: Dataset[_]): KNNModel = { 396 | val distanceMetric = DistanceMetric($(metric)) 397 | val rand = new XORShiftRandom($(seed)) 398 | //prepare data for model estimation 399 | val data = dataset.selectExpr($(featuresCol), $(inputCols).mkString("struct(", ",", ")")) 400 | .rdd 401 | .map(row => new RowWithVector(row.getAs[Vector](0), row.getStruct(1))) 402 | //sample data to build top-level tree 403 | val sampled = data.sample(withReplacement = false, $(topTreeSize).toDouble / dataset.count(), rand.nextLong()).collect() 404 | val topTree = MetricTree.build(sampled, $(topTreeLeafSize), rand.nextLong(), distanceMetric) 405 | //build partitioner using top-level tree 406 | val part = new KNNPartitioner(topTree, distanceMetric) 407 | //noinspection ScalaStyle 408 | val repartitioned = new ShuffledRDD[RowWithVector, Null, Null](data.map(v => (v, null)), part).keys 409 | 410 | val tau = 411 | if ($(balanceThreshold) > 0 && $(bufferSize) < 0) { 412 | val estimates = KNN.estimateTau(data, $(bufferSizeSampleSizes), rand.nextLong(), distanceMetric) 413 | math.max(0, estimates) 414 | } else { 415 | math.max(0, $(bufferSize)) 416 | } 417 | logInfo("Tau is: " + tau) 418 | 419 | val trees = repartitioned.mapPartitionsWithIndex { 420 | (partitionId, itr) => 421 | val rand = new XORShiftRandom(byteswap64($(seed) ^ partitionId)) 422 | val childTree = 423 | HybridTree.build(itr.toIndexedSeq, $(subTreeLeafSize), tau, $(balanceThreshold), rand.nextLong(), distanceMetric) 424 | 425 | Iterator(childTree) 426 | }.persist(StorageLevel.MEMORY_AND_DISK) 427 | // TODO: force persisting trees primarily for benchmark. any reason not to do this for regular runs? 
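    // count() below is an action that materializes the persisted sub-trees eagerly, so tree
    // construction happens exactly once here instead of lazily on the first query.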
428 | trees.count() 429 | 430 | val model = new KNNModel(uid, trees.context.broadcast(topTree), trees).setParent(this) 431 | copyValues(model).setBufferSize(tau) 432 | } 433 | 434 | override def transformSchema(schema: StructType): StructType = { 435 | validateAndTransformSchema(schema) 436 | } 437 | 438 | override def copy(extra: ParamMap): KNN = defaultCopy(extra) 439 | } 440 | 441 | 442 | object KNN { 443 | 444 | val logger = log4j.Logger.getLogger(classOf[KNN]) 445 | 446 | /** 447 | * VectorWithNorm can use more efficient algorithm to calculate distance 448 | */ 449 | case class VectorWithNorm(vector: Vector, norm: Double) { 450 | def this(vector: Vector) = this(vector, Vectors.norm(vector, 2)) 451 | 452 | def this(vector: BV[Double]) = this(Vectors.fromBreeze(vector)) 453 | } 454 | 455 | /** 456 | * VectorWithNorm plus auxiliary row information 457 | */ 458 | case class RowWithVector(vector: VectorWithNorm, row: Row) { 459 | def this(vector: Vector, row: Row) = this(new VectorWithNorm(vector), row) 460 | } 461 | 462 | /** 463 | * Estimate a suitable buffer size based on dataset 464 | * 465 | * A suitable buffer size is the minimum size such that nearest neighbors can be accurately found even at 466 | * boundary of splitting plane between pivot points. Therefore assuming points are uniformly distributed in 467 | * high dimensional space, it should be approximately the average distance between points. 468 | * 469 | * Specifically the number of points within a certain radius of a given point is proportionally to the density of 470 | * points raised to the effective number of dimensions, of which manifold data points exist on: 471 | * R_s = \frac{c}{N_s ** 1/d} 472 | * where R_s is the radius, N_s is the number of points, d is effective number of dimension, and c is a constant. 473 | * 474 | * To estimate R_s_all for entire dataset, we can take samples of the dataset of different size N_s to compute R_s. 475 | * We can estimate c and d using linear regression. Lastly we can calculate R_s_all using total number of observation 476 | * in dataset. 
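 * In log space the relationship is linear, log(R_s) = alpha + beta * log(N_s), where the fitted slope
 * beta approximates -1/d and the intercept alpha approximates log(c); the buffer size for the full
 * dataset then follows by evaluating the fitted line at the total number of observations.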
477 | * 478 | */ 479 | def estimateTau(data: RDD[RowWithVector], sampleSize: Array[Int], seed: Long, distanceMetric: DistanceMetric): Double = { 480 | val total = data.count() 481 | 482 | // take samples of points for estimation 483 | val samples = data.mapPartitionsWithIndex { 484 | case (partitionId, itr) => 485 | val rand = new XORShiftRandom(byteswap64(seed ^ partitionId)) 486 | itr.flatMap { 487 | p => sampleSize.zipWithIndex 488 | .filter { case (size, _) => rand.nextDouble() * total < size } 489 | .map { case (size, index) => (index, p) } 490 | } 491 | } 492 | // compute N_s and R_s pairs 493 | val estimators = samples 494 | .groupByKey() 495 | .map { 496 | case (index, points) => (points.size, computeAverageDistance(points, distanceMetric)) 497 | }.collect().distinct 498 | 499 | // collect x and y vectors 500 | val x = DenseVector(estimators.map { case (n, _) => math.log(n) }) 501 | val y = DenseVector(estimators.map { case (_, d) => math.log(d) }) 502 | 503 | // estimate log(R_s) = alpha + beta * log(N_s) 504 | val xMeanVariance = meanAndVariance(x) 505 | val xmean = xMeanVariance.mean 506 | val yMeanVariance = meanAndVariance(y) 507 | val ymean = yMeanVariance.mean 508 | 509 | val corr = (mean(x *:* y) - xmean * ymean) / math.sqrt((mean(x *:* x) - xmean * xmean) * (mean(y *:* y) - ymean * ymean)) 510 | 511 | val beta = corr * yMeanVariance.stdDev / xMeanVariance.stdDev 512 | val alpha = ymean - beta * xmean 513 | val rs = math.exp(alpha + beta * math.log(total)) 514 | 515 | if (beta > 0 || beta.isNaN || rs.isNaN) { 516 | val yMax = breeze.linalg.max(y) 517 | logger.error( 518 | s"""Unable to estimate Tau with positive beta: $beta. This maybe because data is too small. 519 | |Setting to $yMax which is the maximum average distance we found in the sample. 520 | |This may leads to poor accuracy. Consider manually set bufferSize instead. 
521 | |You can also try setting balanceThreshold to zero so only metric trees are built.""".stripMargin) 522 | yMax 523 | } else { 524 | // c = alpha, d = - 1 / beta 525 | rs / math.sqrt(-1 / beta) 526 | } 527 | } 528 | 529 | // compute the average distance of nearest neighbors within points using brute-force 530 | private[this] def computeAverageDistance(points: Iterable[RowWithVector], distanceMetric: DistanceMetric): Double = { 531 | val distances = points.map { 532 | point => points.map(p => distanceMetric.fastSquaredDistance(p.vector, point.vector)).filter(_ > 0).min 533 | }.map(math.sqrt) 534 | 535 | distances.sum / distances.size 536 | } 537 | 538 | /** 539 | * Search leaf index used by KNNPartitioner to partition training points 540 | * 541 | * @param v one training point to partition 542 | * @param tree top tree constructed using sampled points 543 | * @param acc accumulator used to help determining leaf index 544 | * @return leaf/partition index 545 | */ 546 | @tailrec 547 | private[knn] def searchIndex(v: RowWithVector, tree: Tree, acc: Int = 0, distanceMetric: DistanceMetric): Int = { 548 | tree match { 549 | case node: MetricTree => 550 | val leftDistance = distanceMetric.fastSquaredDistance(node.leftPivot, v.vector) 551 | val rightDistance = distanceMetric.fastSquaredDistance(node.rightPivot, v.vector) 552 | if (leftDistance < rightDistance) { 553 | searchIndex(v, node.leftChild, acc, distanceMetric) 554 | } else { 555 | searchIndex(v, node.rightChild, acc + node.leftChild.leafCount, distanceMetric) 556 | } 557 | case _ => acc // reached leaf 558 | } 559 | } 560 | 561 | //TODO: Might want to make this tail recursive 562 | private[ml] def searchIndices(v: VectorWithNorm, tree: Tree, tau: Double, acc: Int = 0, distanceMetric: DistanceMetric): Seq[Int] = { 563 | tree match { 564 | case node: MetricTree => 565 | val leftDistance = distanceMetric.fastDistance(node.leftPivot, v) 566 | val rightDistance = distanceMetric.fastDistance(node.rightPivot, v) 567 | 568 | val buffer = new ArrayBuffer[Int] 569 | if (leftDistance - rightDistance <= tau) { 570 | buffer ++= searchIndices(v, node.leftChild, tau, acc, distanceMetric) 571 | } 572 | 573 | if (rightDistance - leftDistance <= tau) { 574 | buffer ++= searchIndices(v, node.rightChild, tau, acc + node.leftChild.leafCount, distanceMetric) 575 | } 576 | 577 | buffer 578 | case _ => Seq(acc) // reached leaf 579 | } 580 | } 581 | 582 | /** 583 | * Partitioner used to map vector to leaf node which determines the partition it goes to 584 | * 585 | * @param tree `Tree` used to find leaf 586 | */ 587 | class KNNPartitioner[T <: RowWithVector](tree: Tree, distanceMetric: DistanceMetric) extends Partitioner { 588 | override def numPartitions: Int = tree.leafCount 589 | 590 | override def getPartition(key: Any): Int = { 591 | key match { 592 | case v: RowWithVector => searchIndex(v, tree, distanceMetric=distanceMetric) 593 | case _ => throw new IllegalArgumentException(s"Key must be of type Vector but got: $key") 594 | } 595 | } 596 | 597 | } 598 | 599 | } 600 | -------------------------------------------------------------------------------- /spark-knn-core/src/main/scala/org/apache/spark/ml/knn/MetricTree.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.knn 2 | 3 | import breeze.linalg._ 4 | import org.apache.spark.ml.knn.KNN._ 5 | import org.apache.spark.ml.linalg.{Vector, Vectors} 6 | import org.apache.spark.util.random.XORShiftRandom 7 | 8 | import scala.collection.mutable 
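// Overview of the structures in this file: Empty and Leaf are terminal nodes; MetricTree splits points
// exactly and backtracks during queries when the sibling may still hold a closer neighbor; SpillTree
// shares a buffer of width tau between its children and is searched defeatist-style without backtracking;
// HybridTree reverts a spill node to a metric node when either child would hold more than the
// balanceThreshold fraction of points; KNNCandidates keeps the current best k neighbors in a bounded
// max-heap keyed by distance.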
9 | 10 | /** 11 | * A [[Tree]] is used to store data points used in k-NN search. It represents 12 | * a binary tree node. It keeps track of the pivot vector which closely approximate 13 | * the center of all vectors within the node. All vectors are within the radius of 14 | * distance to the pivot vector. Finally it knows the number of leaves to help 15 | * determining partition index. 16 | */ 17 | private[ml] abstract class Tree extends Serializable { 18 | val leftChild: Tree 19 | val rightChild: Tree 20 | val size: Int 21 | val leafCount: Int 22 | val pivot: VectorWithNorm 23 | val radius: Double 24 | val distanceMetric: DistanceMetric 25 | 26 | def iterator: Iterator[RowWithVector] 27 | 28 | /** 29 | * k-NN query using pre-built [[Tree]] 30 | * @param v vector to query 31 | * @param k number of nearest neighbor 32 | * @return a list of neighbor that is nearest to the query vector 33 | */ 34 | def query(v: Vector, k: Int = 1): Iterable[(RowWithVector, Double)] = query(new VectorWithNorm(v), k) 35 | def query(v: VectorWithNorm, k: Int): Iterable[(RowWithVector, Double)] = query(new KNNCandidates(v, k, distanceMetric)).toIterable 36 | 37 | /** 38 | * Refine k-NN candidates using data in this [[Tree]] 39 | */ 40 | private[knn] def query(candidates: KNNCandidates): KNNCandidates 41 | 42 | /** 43 | * Compute QueryCost defined as || v.center - q || - r 44 | * when >= v.r node can be pruned 45 | * for MetricNode this can be used to determine which child does queryVector falls into 46 | */ 47 | private[knn] def distance(candidates: KNNCandidates): Double = distance(candidates.queryVector) 48 | 49 | private[knn] def distance(v: VectorWithNorm): Double = 50 | if(pivot.vector.size > 0) distanceMetric.fastDistance(pivot, v) else 0.0 51 | } 52 | 53 | private[knn] 54 | case class Empty(distanceMetric: DistanceMetric) extends Tree { 55 | override val leftChild = this 56 | override val rightChild = this 57 | override val size = 0 58 | override val leafCount = 0 59 | override val pivot = new VectorWithNorm(Vectors.dense(Array.empty[Double])) 60 | override val radius = 0.0 61 | 62 | override def iterator: Iterator[RowWithVector] = Iterator.empty 63 | override def query(candidates: KNNCandidates): KNNCandidates = candidates 64 | } 65 | 66 | private[knn] 67 | case class Leaf (data: IndexedSeq[RowWithVector], 68 | pivot: VectorWithNorm, 69 | radius: Double, 70 | distanceMetric: DistanceMetric) extends Tree { 71 | override val leftChild = Empty(distanceMetric) 72 | override val rightChild = Empty(distanceMetric) 73 | override val size = data.size 74 | override val leafCount = 1 75 | 76 | override def iterator: Iterator[RowWithVector] = data.iterator 77 | 78 | // brute force k-NN search at the leaf 79 | override def query(candidates: KNNCandidates): KNNCandidates = { 80 | val sorted = data 81 | .map{ v => (v, distanceMetric.fastDistance(candidates.queryVector, v.vector)) } 82 | .sortBy(_._2) 83 | 84 | for((v, d) <- sorted if candidates.notFull || d < candidates.maxDistance) 85 | candidates.insert(v, d) 86 | 87 | candidates 88 | } 89 | } 90 | 91 | private[knn] 92 | object Leaf { 93 | def apply(data: IndexedSeq[RowWithVector], distanceMetric: DistanceMetric): Leaf = { 94 | val vectors = data.map(_.vector.vector.asBreeze) 95 | val (minV, maxV) = vectors.foldLeft((vectors.head, vectors.head)) { 96 | case ((accMin, accMax), bv) => 97 | (min(accMin, bv), max(accMax, bv)) 98 | } 99 | val pivot = new VectorWithNorm((minV + maxV) / 2.0) 100 | val radius = math.sqrt(squaredDistance(minV, maxV)) / 2.0 101 | Leaf(data, 
pivot, radius, distanceMetric) 102 | } 103 | } 104 | 105 | /** 106 | * A [[MetricTree]] represents a MetricNode where data are split into two partitions: left and right. 107 | * There exists two pivot vectors: leftPivot and rightPivot to determine the partitioning. 108 | * Pivot vector should be the middle of leftPivot and rightPivot vectors. 109 | * Points that is closer to leftPivot than to rightPivot belongs to leftChild and rightChild otherwise. 110 | * 111 | * During search, because we have information about each child's pivot and radius, we can see if the 112 | * hyper-sphere intersects with current candidates sphere. If so, we search the child that has the 113 | * most potential (i.e. the child which has the closest pivot). 114 | * Once that child has been fully searched, we backtrack to the remaining child and search if necessary. 115 | * 116 | * This is much more efficient than naive brute force search. However backtracking can take a lot of time 117 | * when the number of dimension is high (due to longer time to compute distance and the volume growing much 118 | * faster than radius). 119 | */ 120 | private[knn] 121 | case class MetricTree(leftChild: Tree, 122 | leftPivot: VectorWithNorm, 123 | rightChild: Tree, 124 | rightPivot: VectorWithNorm, 125 | pivot: VectorWithNorm, 126 | radius: Double, 127 | distanceMetric: DistanceMetric 128 | ) extends Tree { 129 | override val size = leftChild.size + rightChild.size 130 | override val leafCount = leftChild.leafCount + rightChild.leafCount 131 | 132 | override def iterator: Iterator[RowWithVector] = leftChild.iterator ++ rightChild.iterator 133 | override def query(candidates: KNNCandidates): KNNCandidates = { 134 | lazy val leftQueryCost = leftChild.distance(candidates) 135 | lazy val rightQueryCost = rightChild.distance(candidates) 136 | // only query if at least one of the children is worth looking 137 | if(candidates.notFull || 138 | leftQueryCost - candidates.maxDistance < leftChild.radius || 139 | rightQueryCost - candidates.maxDistance < rightChild.radius ){ 140 | val remainingChild = { 141 | if (leftQueryCost <= rightQueryCost) { 142 | leftChild.query(candidates) 143 | rightChild 144 | } else { 145 | rightChild.query(candidates) 146 | leftChild 147 | } 148 | } 149 | // check again to see if the remaining child is still worth looking 150 | if (candidates.notFull || 151 | remainingChild.distance(candidates) - candidates.maxDistance < remainingChild.radius) { 152 | remainingChild.query(candidates) 153 | } 154 | } 155 | candidates 156 | } 157 | } 158 | 159 | object MetricTree { 160 | /** 161 | * Build a (metric)[[Tree]] that facilitate k-NN query 162 | * 163 | * @param data vectors that contain all training data 164 | * @param seed random number generator seed used in pivot point selecting 165 | * @return a [[Tree]] can be used to do k-NN query 166 | */ 167 | def build(data: IndexedSeq[RowWithVector], leafSize: Int = 1, seed: Long = 0L, distanceMetric: DistanceMetric): Tree = { 168 | val size = data.size 169 | if(size == 0) { 170 | Empty(distanceMetric) 171 | } else if(size <= leafSize) { 172 | Leaf(data, distanceMetric) 173 | } else { 174 | val rand = new XORShiftRandom(seed) 175 | val randomPivot = data(rand.nextInt(size)).vector 176 | val leftPivot = data.maxBy(v => distanceMetric.fastSquaredDistance(randomPivot, v.vector)).vector 177 | if(leftPivot == randomPivot) { 178 | // all points are identical (or only one point left) 179 | Leaf(data, randomPivot, 0.0, distanceMetric) 180 | } else { 181 | val rightPivot = data.maxBy(v => 
distanceMetric.fastSquaredDistance(leftPivot, v.vector)).vector 182 | val pivot = new VectorWithNorm(Vectors.fromBreeze((leftPivot.vector.asBreeze + rightPivot.vector.asBreeze) / 2.0)) 183 | val radius = math.sqrt(data.map(v => distanceMetric.fastSquaredDistance(pivot, v.vector)).max) 184 | val (leftPartition, rightPartition) = data.partition{v => 185 | val distanceToLeft = distanceMetric.fastSquaredDistance(leftPivot, v.vector) 186 | val distanceToRight = distanceMetric.fastSquaredDistance(rightPivot, v.vector) 187 | distanceToLeft < distanceToRight 188 | } 189 | 190 | MetricTree( 191 | build(leftPartition, leafSize, rand.nextLong(), distanceMetric), 192 | leftPivot, 193 | build(rightPartition, leafSize, rand.nextLong(), distanceMetric), 194 | rightPivot, 195 | pivot, 196 | radius, 197 | distanceMetric 198 | ) 199 | } 200 | } 201 | } 202 | } 203 | 204 | /** 205 | * A [[SpillTree]] represents a SpillNode. Just like [[MetricTree]], it splits data into two partitions. 206 | * However, instead of partition data into exactly two halves, it contains a buffer zone with size of tau. 207 | * Left child contains all data left to the center plane + tau (in the leftPivot -> rightPivot direction). 208 | * Right child contains all data right to the center plane - tau. 209 | * 210 | * Search doesn't do backtracking but rather adopt a defeatist search where it search the most prominent 211 | * child and that child only. The buffer ensures such strategy doesn't result in a poor outcome. 212 | */ 213 | private[knn] 214 | case class SpillTree(leftChild: Tree, 215 | leftPivot: VectorWithNorm, 216 | rightChild: Tree, 217 | rightPivot: VectorWithNorm, 218 | pivot: VectorWithNorm, 219 | radius: Double, 220 | tau: Double, 221 | bufferSize: Int, 222 | distanceMetric: DistanceMetric 223 | ) extends Tree { 224 | override val size = leftChild.size + rightChild.size - bufferSize 225 | override val leafCount = leftChild.leafCount + rightChild.leafCount 226 | 227 | override def iterator: Iterator[RowWithVector] = 228 | leftChild.iterator ++ rightChild.iterator.filter(childFilter(leftPivot, rightPivot)) 229 | 230 | override def query(candidates: KNNCandidates): KNNCandidates = { 231 | if (size <= candidates.k - candidates.candidates.size) { 232 | iterator.foreach(candidates.insert) 233 | } else { 234 | val leftQueryCost = distanceMetric.fastSquaredDistance(candidates.queryVector, leftPivot) 235 | val rightQueryCost = distanceMetric.fastSquaredDistance(candidates.queryVector, rightPivot) 236 | 237 | (if (leftQueryCost <= rightQueryCost) leftChild else rightChild).query(candidates) 238 | 239 | // fill candidates with points from other child excluding buffer so we don't double count. 
240 | // depending on K and how high we are in the tree, this can be very expensive and undesirable 241 | // TODO: revisit this idea when we do large scale testing 242 | if(candidates.notFull) { 243 | (if (leftQueryCost <= rightQueryCost) { 244 | rightChild.iterator.filter(childFilter(leftPivot, rightPivot)) 245 | } else { 246 | leftChild.iterator.filter(childFilter(rightPivot, leftPivot)) 247 | }).foreach(candidates.tryInsert) 248 | } 249 | } 250 | candidates 251 | } 252 | 253 | private[this] val childFilter: (VectorWithNorm, VectorWithNorm) => RowWithVector => Boolean = 254 | (p1, p2) => p => distanceMetric.fastDistance(p.vector, p1) - distanceMetric.fastDistance(p.vector, p2) > tau 255 | } 256 | 257 | 258 | object SpillTree { 259 | /** 260 | * Build a (spill)[[Tree]] that facilitate k-NN query 261 | * 262 | * @param data vectors that contain all training data 263 | * @param tau overlapping size 264 | * @param seed random number generators seed used in pivot point selecting 265 | * @return a [[Tree]] can be used to do k-NN query 266 | */ 267 | def build(data: IndexedSeq[RowWithVector], leafSize: Int = 1, tau: Double, seed: Long = 0L, distanceMetric: DistanceMetric): Tree = { 268 | val size = data.size 269 | if (size == 0) { 270 | Empty(distanceMetric) 271 | } else if (size <= leafSize) { 272 | Leaf(data, distanceMetric) 273 | } else { 274 | val rand = new XORShiftRandom(seed) 275 | val randomPivot = data(rand.nextInt(size)).vector 276 | val leftPivot = data.maxBy(v => distanceMetric.fastSquaredDistance(randomPivot, v.vector)).vector 277 | if (leftPivot == randomPivot) { 278 | // all points are identical (or only one point left) 279 | Leaf(data, randomPivot, 0.0, distanceMetric) 280 | } else { 281 | val rightPivot = data.maxBy(v => distanceMetric.fastSquaredDistance(leftPivot, v.vector)).vector 282 | val pivot = new VectorWithNorm(Vectors.fromBreeze((leftPivot.vector.asBreeze + rightPivot.vector.asBreeze) / 2.0)) 283 | val radius = math.sqrt(data.map(v => distanceMetric.fastSquaredDistance(pivot, v.vector)).max) 284 | val dataWithDistance = data.map(v => 285 | (v, distanceMetric.fastDistance(leftPivot, v.vector), distanceMetric.fastDistance(rightPivot, v.vector)) 286 | ) 287 | val leftPartition = dataWithDistance.filter { case (_, left, right) => left - right <= tau }.map(_._1) 288 | val rightPartition = dataWithDistance.filter { case (_, left, right) => right - left <= tau }.map(_._1) 289 | 290 | SpillTree( 291 | build(leftPartition, leafSize, tau, rand.nextLong(), distanceMetric), 292 | leftPivot, 293 | build(rightPartition, leafSize, tau, rand.nextLong(), distanceMetric), 294 | rightPivot, 295 | pivot, 296 | radius, 297 | tau, 298 | leftPartition.size + rightPartition.size - size, 299 | distanceMetric 300 | ) 301 | } 302 | } 303 | } 304 | } 305 | 306 | object HybridTree { 307 | /** 308 | * Build a (hybrid-spill) `Tree` that facilitate k-NN query 309 | * 310 | * @param data vectors that contain all training data 311 | * @param seed random number generator seed used in pivot point selecting 312 | * @param tau overlapping size 313 | * @param rho balance threshold 314 | * @return a `Tree` can be used to do k-NN query 315 | */ 316 | //noinspection ScalaStyle 317 | def build(data: IndexedSeq[RowWithVector], 318 | leafSize: Int = 1, 319 | tau: Double, 320 | rho: Double = 0.7, 321 | seed: Long = 0L, 322 | distanceMetric: DistanceMetric): Tree = { 323 | val size = data.size 324 | if (size == 0) { 325 | Empty(distanceMetric) 326 | } else if (size <= leafSize) { 327 | Leaf(data, distanceMetric) 
328 | } else { 329 | val rand = new XORShiftRandom(seed) 330 | val randomPivot = data(rand.nextInt(size)).vector 331 | val leftPivot = data.maxBy(v => distanceMetric.fastSquaredDistance(randomPivot, v.vector)).vector 332 | if (leftPivot == randomPivot) { 333 | // all points are identical (or only one point left) 334 | Leaf(data, randomPivot, 0.0, distanceMetric) 335 | } else { 336 | val rightPivot = data.maxBy(v => distanceMetric.fastSquaredDistance(leftPivot, v.vector)).vector 337 | val pivot = new VectorWithNorm(Vectors.fromBreeze((leftPivot.vector.asBreeze + rightPivot.vector.asBreeze) / 2.0)) 338 | val radius = math.sqrt(data.map(v => distanceMetric.fastSquaredDistance(pivot, v.vector)).max) 339 | lazy val dataWithDistance = data.map(v => 340 | (v, distanceMetric.fastDistance(leftPivot, v.vector), distanceMetric.fastDistance(rightPivot, v.vector)) 341 | ) 342 | // implemented boundary is parabola (rather than perpendicular plane described in the paper) 343 | lazy val leftPartition = dataWithDistance.filter { case (_, left, right) => left - right <= tau }.map(_._1) 344 | lazy val rightPartition = dataWithDistance.filter { case (_, left, right) => right - left <= tau }.map(_._1) 345 | 346 | if(rho <= 0.0 || leftPartition.size > size * rho || rightPartition.size > size * rho) { 347 | //revert back to metric node 348 | val (leftPartition, rightPartition) = data.partition{ 349 | v => distanceMetric.fastSquaredDistance(leftPivot, v.vector) < distanceMetric.fastSquaredDistance(rightPivot, v.vector) 350 | } 351 | MetricTree( 352 | build(leftPartition, leafSize, tau, rho, rand.nextLong(), distanceMetric), 353 | leftPivot, 354 | build(rightPartition, leafSize, tau, rho, rand.nextLong(), distanceMetric), 355 | rightPivot, 356 | pivot, 357 | radius, 358 | distanceMetric 359 | ) 360 | } else { 361 | SpillTree( 362 | build(leftPartition, leafSize, tau, rho, rand.nextLong(), distanceMetric), 363 | leftPivot, 364 | build(rightPartition, leafSize, tau, rho, rand.nextLong(), distanceMetric), 365 | rightPivot, 366 | pivot, 367 | radius, 368 | tau, 369 | leftPartition.size + rightPartition.size - size, 370 | distanceMetric 371 | ) 372 | } 373 | } 374 | } 375 | } 376 | } 377 | 378 | /** 379 | * Structure to maintain search progress/results for a single query vector. 380 | * Internally uses a PriorityQueue to maintain a max-heap to keep track of the 381 | * next neighbor to evict. 382 | * 383 | * @param queryVector vector being searched 384 | * @param k number of neighbors to return 385 | */ 386 | private[knn] 387 | class KNNCandidates(val queryVector: VectorWithNorm, val k: Int, distanceMetric: DistanceMetric) extends Serializable { 388 | private[knn] val candidates = mutable.PriorityQueue.empty[(RowWithVector, Double)] { 389 | Ordering.by(_._2) 390 | } 391 | 392 | // return the current maximum distance from neighbor to search vector 393 | def maxDistance: Double = if(candidates.isEmpty) 0.0 else candidates.head._2 394 | // insert evict neighbor if required. however it doesn't make sure the insert improves 395 | // search results. 
it is caller's responsibility to make sure either candidate list 396 | // is not full or the inserted neighbor brings the maxDistance down 397 | def insert(v: RowWithVector, d: Double): Unit = { 398 | while(candidates.size >= k) candidates.dequeue() 399 | candidates.enqueue((v, d)) 400 | } 401 | def insert(v: RowWithVector): Unit = insert(v, distanceMetric.fastDistance(v.vector, queryVector)) 402 | def tryInsert(v: RowWithVector): Unit = { 403 | val distance = distanceMetric.fastDistance(v.vector, queryVector) 404 | if(notFull || distance < maxDistance) insert(v, distance) 405 | } 406 | def toIterable: Iterable[(RowWithVector, Double)] = candidates 407 | def notFull: Boolean = candidates.size < k 408 | } 409 | -------------------------------------------------------------------------------- /spark-knn-core/src/main/scala/org/apache/spark/ml/regression/KNNRegression.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.regression 2 | 3 | import org.apache.spark.broadcast.Broadcast 4 | import org.apache.spark.ml.knn.{KNN, KNNModelParams, KNNParams, Tree} 5 | import org.apache.spark.ml.param.ParamMap 6 | import org.apache.spark.ml.param.shared.HasWeightCol 7 | import org.apache.spark.ml.util.Identifiable 8 | import org.apache.spark.ml.{PredictionModel, Predictor} 9 | import org.apache.spark.ml.linalg.Vector 10 | import org.apache.spark.rdd.RDD 11 | import org.apache.spark.sql.{DataFrame, Dataset, Row} 12 | import org.apache.spark.storage.StorageLevel 13 | 14 | /** 15 | * [[https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm]] for regression. 16 | * The output value is simply the average of the values of its k nearest neighbors. 17 | */ 18 | class KNNRegression(override val uid: String) extends Predictor[Vector, KNNRegression, KNNRegressionModel] 19 | with KNNParams with HasWeightCol { 20 | def this() = this(Identifiable.randomUID("knnr")) 21 | 22 | /** @group setParam */ 23 | override def setFeaturesCol(value: String): this.type = set(featuresCol, value) 24 | 25 | /** @group setParam */ 26 | override def setLabelCol(value: String): this.type = { 27 | set(labelCol, value) 28 | 29 | if ($(weightCol).isEmpty) { 30 | set(inputCols, Array(value)) 31 | } else { 32 | set(inputCols, Array(value, $(weightCol))) 33 | } 34 | } 35 | 36 | //fill in default label col 37 | setDefault(inputCols, Array($(labelCol))) 38 | 39 | /** @group setWeight */ 40 | def setWeightCol(value: String): this.type = { 41 | set(weightCol, value) 42 | 43 | if (value.isEmpty) { 44 | set(inputCols, Array($(labelCol))) 45 | } else { 46 | set(inputCols, Array($(labelCol), value)) 47 | } 48 | } 49 | 50 | setDefault(weightCol -> "") 51 | 52 | /** @group setParam */ 53 | def setK(value: Int): this.type = set(k, value) 54 | 55 | /** @group setParam */ 56 | def setTopTreeSize(value: Int): this.type = set(topTreeSize, value) 57 | 58 | /** @group setParam */ 59 | def setTopTreeLeafSize(value: Int): this.type = set(topTreeLeafSize, value) 60 | 61 | /** @group setParam */ 62 | def setSubTreeLeafSize(value: Int): this.type = set(subTreeLeafSize, value) 63 | 64 | /** @group setParam */ 65 | def setBufferSizeSampleSizes(value: Array[Int]): this.type = set(bufferSizeSampleSizes, value) 66 | 67 | /** @group setParam */ 68 | def setBalanceThreshold(value: Double): this.type = set(balanceThreshold, value) 69 | 70 | /** @group setParam */ 71 | def setSeed(value: Long): this.type = set(seed, value) 72 | 73 | override protected def train(dataset: Dataset[_]): KNNRegressionModel = 
{ 74 | val knnModel = copyValues(new KNN()).fit(dataset) 75 | knnModel.toNewRegressionModel(uid) 76 | } 77 | 78 | override def fit(dataset: Dataset[_]): KNNRegressionModel = { 79 | // Need to overwrite this method because we need to manually overwrite the buffer size 80 | // because it is not supposed to stay the same as the Regressor if user sets it to -1. 81 | transformSchema(dataset.schema, logging = true) 82 | val model = train(dataset) 83 | val bufferSize = model.getBufferSize 84 | copyValues(model.setParent(this)).setBufferSize(bufferSize) 85 | } 86 | 87 | override def copy(extra: ParamMap): KNNRegression = defaultCopy(extra) 88 | } 89 | 90 | class KNNRegressionModel private[ml]( 91 | override val uid: String, 92 | val topTree: Broadcast[Tree], 93 | val subTrees: RDD[Tree] 94 | ) extends PredictionModel[Vector, KNNRegressionModel] 95 | with KNNModelParams with HasWeightCol with Serializable { 96 | require(subTrees.getStorageLevel != StorageLevel.NONE, 97 | "KNNModel is not designed to work with Trees that have not been cached") 98 | 99 | /** @group setParam */ 100 | def setK(value: Int): this.type = set(k, value) 101 | 102 | /** @group setParam */ 103 | def setBufferSize(value: Double): this.type = set(bufferSize, value) 104 | 105 | //TODO: This can benefit from DataSet API in Spark 1.6 106 | override def transformImpl(dataset: Dataset[_]): DataFrame = { 107 | val getWeight: Row => Double = { 108 | if($(weightCol).isEmpty) { 109 | r => 1.0 110 | } else { 111 | r => r.getDouble(1) 112 | } 113 | } 114 | 115 | val neighborDataset : RDD[(Long, Array[(Row, Double)])] = transform(dataset, topTree, subTrees) 116 | val merged = neighborDataset 117 | .map { 118 | case (id, labelsDists) => 119 | val (labels, _) = labelsDists.unzip 120 | var i = 0 121 | var weight = 0.0 122 | var sum = 0.0 123 | val length = labels.length 124 | while (i < length) { 125 | val row = labels(i) 126 | val w = getWeight(row) 127 | sum += row.getDouble(0) * w 128 | weight += w 129 | i += 1 130 | } 131 | 132 | (id, sum / weight) 133 | } 134 | 135 | dataset.sqlContext.createDataFrame( 136 | dataset.toDF().rdd.zipWithIndex().map { case (row, i) => (i, row) } 137 | .leftOuterJoin(merged) //make sure we don't lose any observations 138 | .map { 139 | case (i, (row, value)) => Row.fromSeq(row.toSeq :+ value.get) 140 | }, 141 | transformSchema(dataset.schema) 142 | ) 143 | } 144 | 145 | override def copy(extra: ParamMap): KNNRegressionModel = { 146 | val copied = new KNNRegressionModel(uid, topTree, subTrees) 147 | copyValues(copied, extra).setParent(parent) 148 | } 149 | 150 | override def predict(features: Vector): Double = { 151 | val neighborDataset : RDD[(Long, Array[(Row, Double)])] = transform(subTrees.context.parallelize(Seq(features)), topTree, subTrees) 152 | val results = neighborDataset.first()._2 153 | val labels = results.map(_._1.getDouble(0)) 154 | labels.sum / labels.length 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /spark-knn-core/src/main/scala/org/apache/spark/mllib/knn/KNNUtils.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.mllib.knn 2 | 3 | import org.apache.spark.ml.{linalg => newlinalg} 4 | import org.apache.spark.mllib.{linalg => oldlinalg} 5 | import org.apache.spark.mllib.util.MLUtils 6 | 7 | object KNNUtils { 8 | 9 | import oldlinalg.VectorImplicits._ 10 | 11 | def fastSquaredDistance( 12 | v1: newlinalg.Vector, 13 | norm1: Double, 14 | v2: newlinalg.Vector, 15 | norm2: 
Double, 16 | precision: Double = 1e-6): Double = { 17 | MLUtils.fastSquaredDistance(v1, norm1, v2, norm2, precision) 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /spark-knn-core/src/test/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=INFO, console 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.target=System.err 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Settings to quiet third party logs that are too verbose 9 | log4j.logger.org.spark-project.jetty=WARN 10 | log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR 11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 13 | log4j.logger.org.apache.parquet=ERROR 14 | log4j.logger.parquet=ERROR 15 | 16 | log4j.logger.org.apache.spark=WARN 17 | log4j.logger.org.apache.spark.sql=INFO 18 | log4j.logger.org.apache.spark.ml=INFO 19 | log4j.logger.org.apache.spark.mllib=INFO 20 | 21 | # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support 22 | log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL 23 | log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR 24 | -------------------------------------------------------------------------------- /spark-knn-core/src/test/scala/org/apache/spark/ml/knn/DistanceMetricSpec.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.knn 2 | 3 | import org.apache.spark.ml.knn.KNN.{RowWithVector, VectorWithNorm} 4 | import org.apache.spark.ml.linalg.Vectors 5 | import org.scalatest.funspec.AnyFunSpec 6 | import org.scalatest.matchers.should.Matchers 7 | 8 | class DistanceMetricSpec extends AnyFunSpec with Matchers { 9 | 10 | describe("EuclideanDistanceMetric") { 11 | val distanceMetric = EuclideanDistanceMetric 12 | describe("calculate distance between two dense vectors") { 13 | val v1 = new VectorWithNorm(Vectors.dense(1, 1)) 14 | val v2 = new VectorWithNorm(Vectors.dense(-1, -1)) 15 | 16 | it("should return distance for vector and self") { 17 | distanceMetric.fastDistance(v1, v1) shouldBe 0.0 18 | } 19 | it("should return distance for two vectors") { 20 | distanceMetric.fastDistance(v1, v2) shouldBe Math.sqrt(8.0) 21 | } 22 | } 23 | } 24 | 25 | describe("NaNEuclideanDistanceMetric") { 26 | val distanceMetric = NaNEuclideanDistanceMetric 27 | describe("calculate distance between two dense vectors with valid values") { 28 | val v1 = new VectorWithNorm(Vectors.dense(1, 1)) 29 | val v2 = new VectorWithNorm(Vectors.dense(-1, -1)) 30 | 31 | it("should return distance for vector and self") { 32 | distanceMetric.fastDistance(v1, v1) shouldBe 0.0 33 | } 34 | it("should return distance for two vectors") { 35 | distanceMetric.fastDistance(v1, v2) shouldBe Math.sqrt(8.0) 36 | } 37 | } 38 | describe("calculate distance between two dense vectors with invalid values") { 39 | val v1 = new VectorWithNorm(Vectors.dense(1, 1, Double.NaN, Double.NaN, 1)) 40 | val v2 = new VectorWithNorm(Vectors.dense(-1, Double.NaN, -1, -1, -1)) 41 | 42 | it("should return distance for vector and self") { 43 | distanceMetric.fastDistance(v1, v1) 
shouldBe 0.0 44 | } 45 | it("should return distance for two vectors") { 46 | distanceMetric.fastDistance(v1, v2) shouldBe Math.sqrt(8.0) 47 | } 48 | } 49 | describe("calculate distance between two sparse vectors with invalid values") { 50 | val v1 = new VectorWithNorm(Vectors.sparse(5, Seq((1, 1.0), (2, Double.NaN), (3, Double.NaN), (4, 1.0)))) 51 | val v2 = new VectorWithNorm(Vectors.sparse(5, Seq((0, -1.0), (1, Double.NaN), (4, -1.0)))) 52 | 53 | it("should return distance for vector and self") { 54 | distanceMetric.fastDistance(v1, v1) shouldBe 0.0 55 | } 56 | it("should return distance for two vectors") { 57 | distanceMetric.fastDistance(v1, v2) shouldBe Math.sqrt(5.0) 58 | } 59 | } 60 | describe("calculate distance between a sparse vectors and all 0 vector") { 61 | val v1 = new VectorWithNorm(Vectors.sparse(5, Seq((0, 0.0), (1, 0.0), (3, 0.0), (4, 0.0)))) 62 | val v2 = new VectorWithNorm(Vectors.sparse(5, Seq((0, -1.0), (1, Double.NaN), (4, -1.0)))) 63 | 64 | it("should return distance for vector and self") { 65 | distanceMetric.fastDistance(v1, v1) shouldBe 0.0 66 | } 67 | it("should return distance for two vectors") { 68 | distanceMetric.fastDistance(v1, v2) shouldBe Math.sqrt(2.0) 69 | } 70 | } 71 | describe("calculate distance between a sparse vectors and empty vector") { 72 | val v1 = new VectorWithNorm(Vectors.sparse(5, Seq())) 73 | val v2 = new VectorWithNorm(Vectors.sparse(5, Seq((0, -1.0), (1, Double.NaN), (4, -1.0)))) 74 | 75 | it("should return distance for vector and self") { 76 | distanceMetric.fastDistance(v1, v1) shouldBe 0.0 77 | } 78 | it("should return distance for two vectors") { 79 | distanceMetric.fastDistance(v1, v2) shouldBe Math.sqrt(2.0) 80 | } 81 | } 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /spark-knn-core/src/test/scala/org/apache/spark/ml/knn/KNNSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.knn 2 | 3 | import org.apache.spark.ml.PredictionModel 4 | import org.apache.spark.ml.classification.KNNClassifier 5 | import org.apache.spark.ml.feature.VectorAssembler 6 | import org.apache.spark.ml.knn.KNN.VectorWithNorm 7 | import org.apache.spark.ml.regression.KNNRegression 8 | import org.apache.spark.ml.linalg.{Vector, Vectors} 9 | import org.apache.spark.sql.functions._ 10 | import org.apache.spark.sql.types._ 11 | import org.apache.spark.sql.{DataFrame, Row, SparkSession} 12 | import org.apache.log4j 13 | 14 | import scala.collection.mutable 15 | import org.scalatest.funsuite.AnyFunSuite 16 | import org.scalatest.matchers.should.Matchers 17 | 18 | 19 | class KNNSuite extends AnyFunSuite with Matchers { 20 | 21 | val logger = log4j.Logger.getLogger(getClass) 22 | 23 | val spark = SparkSession.builder() 24 | .master("local") 25 | .getOrCreate() 26 | 27 | val sc = spark.sparkContext 28 | 29 | private[this] val data = (-10 to 10).flatMap(i => (-10 to 10).map(j => Vectors.dense(i, j))) 30 | private[this] val leafSize = 5 31 | 32 | test("KNN can be fitted") { 33 | val distance = EuclideanDistanceMetric 34 | val knn = new KNN() 35 | .setTopTreeSize(data.size / 10) 36 | .setTopTreeLeafSize(leafSize) 37 | .setSubTreeLeafSize(leafSize) 38 | .setAuxCols(Array("features")) 39 | 40 | val df = createDataFrame() 41 | val model = knn.fit(df).setK(1) 42 | 43 | val results = model.transform(df).collect() 44 | results.length shouldBe data.size 45 | 46 | results.foreach { 47 | row => 48 | val vector = row.getAs[Vector](3) 49 | val neighbors = 
row.getAs[mutable.WrappedArray[Row]](4) 50 | if (neighbors.isEmpty) { 51 | logger.error(vector.toString) 52 | } 53 | neighbors.length shouldBe 1 54 | val neighbor = neighbors.head.getAs[Vector](0) 55 | distance.fastSquaredDistance(new VectorWithNorm(vector), new VectorWithNorm(neighbor)) shouldBe 0.0 56 | } 57 | } 58 | 59 | test("KNN fits correctly with maxDistance") { 60 | val distanceMetric = EuclideanDistanceMetric 61 | val knn = new KNN() 62 | .setTopTreeSize(data.size / 10) 63 | .setTopTreeLeafSize(leafSize) 64 | .setSubTreeLeafSize(leafSize) 65 | .setAuxCols(Array("features")) 66 | 67 | val df = createDataFrame() 68 | val model = knn.fit(df).setK(6).setMaxDistance(1) 69 | 70 | val results = model.transform(df).collect() 71 | results.length shouldBe data.size 72 | 73 | results.foreach { 74 | row => 75 | val vector = row.getAs[Vector](3) 76 | val neighbors = row.getAs[mutable.WrappedArray[Row]](4) 77 | if (neighbors.isEmpty) { 78 | logger.error(vector.toString) 79 | } 80 | 81 | val numEdges = vector.toArray.map(math.abs).count(_ == 10) 82 | if (neighbors.length > 5 - numEdges) { 83 | logger.error(vector.toString) 84 | logger.error(neighbors.toList.toString) 85 | } 86 | neighbors.length should be <= 5 - numEdges 87 | 88 | val closest = neighbors.head.getAs[Vector](0) 89 | distanceMetric.fastSquaredDistance(new VectorWithNorm(vector), new VectorWithNorm(closest)) shouldBe 0.0 90 | val rest = neighbors.tail.map(_.getAs[Vector](0)) 91 | rest.foreach { neighbor => 92 | val sqDist = distanceMetric.fastSquaredDistance(new VectorWithNorm(vector), new VectorWithNorm(neighbor)) 93 | sqDist shouldEqual 1.0 +- 1e-6 94 | } 95 | } 96 | } 97 | 98 | test("KNN returns correct distance column values") { 99 | val distanceMetric = EuclideanDistanceMetric 100 | val knn = new KNN() 101 | .setTopTreeSize(data.size / 10) 102 | .setTopTreeLeafSize(leafSize) 103 | .setSubTreeLeafSize(leafSize) 104 | .setAuxCols(Array("features")) 105 | 106 | val df = createDataFrame() 107 | val model = knn.fit(df).setK(6).setMaxDistance(1).setNeighborsCol("neighbors").setDistanceCol("distance") 108 | 109 | val results = model.transform(df).collect() 110 | results.length shouldBe data.size 111 | 112 | results.foreach { 113 | row => 114 | val vector = row.getAs[Vector](row.fieldIndex("features")) 115 | val neighbors = row.getAs[mutable.WrappedArray[Row]](row.fieldIndex("neighbors")) 116 | val distances = row.getAs[mutable.WrappedArray[Double]](row.fieldIndex("distance")) 117 | if (neighbors.isEmpty) { 118 | logger.error(vector.toString) 119 | } 120 | 121 | val numEdges = vector.toArray.map(math.abs).count(_ == 10) 122 | if (neighbors.length > 5 - numEdges) { 123 | logger.error(vector.toString) 124 | logger.error(neighbors.toList.toString) 125 | } 126 | neighbors.length should be <= 5 - numEdges 127 | 128 | if (distances.length != neighbors.length) { 129 | logger.error(vector.toString) 130 | logger.error(neighbors.toList.toString) 131 | logger.error(distances.toList.toString) 132 | } 133 | distances.length should be (neighbors.length) 134 | 135 | val closest = neighbors.head.getAs[Vector](0) 136 | val closestDist = distances.head 137 | val closestCalDist = distanceMetric.fastSquaredDistance(new VectorWithNorm(vector), new VectorWithNorm(closest)) 138 | closestCalDist shouldEqual closestDist 139 | 140 | val rest = neighbors.tail.map(_.getAs[Vector](0)).zip(distances.tail.toList) 141 | rest.foreach { 142 | case (neighbor, distance) => 143 | val sqDist = distanceMetric.fastSquaredDistance(new VectorWithNorm(vector), new 
VectorWithNorm(neighbor)) 144 | sqDist shouldEqual 1.0 +- 1e-6 145 | sqDist shouldEqual distance +- 1e-6 146 | } 147 | } 148 | } 149 | 150 | test("KNNClassifier can be fitted with/without weight column") { 151 | val knn = new KNNClassifier() 152 | .setTopTreeSize(data.size / 10) 153 | .setTopTreeLeafSize(leafSize) 154 | .setSubTreeLeafSize(leafSize) 155 | .setK(1) 156 | checkKNN(knn.fit) 157 | checkKNN(knn.setWeightCol("z").fit) 158 | } 159 | 160 | test("KNNRegressor can be fitted with/without weight column") { 161 | val knn = new KNNRegression() 162 | .setTopTreeSize(data.size / 10) 163 | .setTopTreeLeafSize(leafSize) 164 | .setSubTreeLeafSize(leafSize) 165 | .setK(1) 166 | checkKNN(knn.fit) 167 | checkKNN(knn.setWeightCol("z").fit) 168 | } 169 | 170 | test("KNNParmas are copied correctly") { 171 | val knn = new KNNClassifier() 172 | .setTopTreeSize(data.size / 10) 173 | .setTopTreeLeafSize(leafSize) 174 | .setSubTreeLeafSize(leafSize) 175 | .setK(2) 176 | val model = knn.fit(createDataFrame().withColumn("label", lit(1.0))) 177 | // check pre-set parameters are correctly copied 178 | model.getK shouldBe 2 179 | // check auto generated buffer size is correctly transferred 180 | model.getBufferSize should be > 0.0 181 | } 182 | 183 | test("BufferSize is not estimated if rho = 0") { 184 | val knn = new KNNClassifier() 185 | .setTopTreeSize(data.size / 10) 186 | .setTopTreeLeafSize(leafSize) 187 | .setSubTreeLeafSize(leafSize) 188 | .setBalanceThreshold(0) 189 | val model = knn.fit(createDataFrame().withColumn("label", lit(1.0))) 190 | model.getBufferSize shouldBe 0.0 191 | } 192 | 193 | private[this] def checkKNN(fit: DataFrame => PredictionModel[_, _]): Unit = { 194 | val df = createDataFrame() 195 | df.sqlContext.udf.register("label", { v: Vector => math.abs(v(0)) }) 196 | val training = df.selectExpr("*", "label(features) as label") 197 | val model = fit(training) 198 | 199 | val results = model.transform(training).select("label", "prediction").collect() 200 | results.length shouldBe data.size 201 | 202 | results foreach { 203 | row => row.getDouble(0) shouldBe row.getDouble(1) 204 | } 205 | } 206 | 207 | private[this] def createDataFrame(): DataFrame = { 208 | val rdd = sc.parallelize(data.map(v => Row(v.toArray: _*))) 209 | val assembler = new VectorAssembler() 210 | .setInputCols(Array("x", "y")) 211 | .setOutputCol("features") 212 | assembler.transform( 213 | spark.createDataFrame(rdd, 214 | StructType( 215 | Seq( 216 | StructField("x", DoubleType), 217 | StructField("y", DoubleType) 218 | ) 219 | ) 220 | ).withColumn("z", lit(1.0)) 221 | ) 222 | } 223 | } 224 | -------------------------------------------------------------------------------- /spark-knn-core/src/test/scala/org/apache/spark/ml/knn/MetricTreeSpec.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.knn 2 | 3 | import org.apache.spark.ml.knn.KNN.{RowWithVector, VectorWithNorm} 4 | import org.apache.spark.ml.linalg.Vectors 5 | import org.scalatest.funspec.AnyFunSpec 6 | import org.scalatest.matchers.should.Matchers 7 | 8 | class MetricTreeSpec extends AnyFunSpec with Matchers { 9 | 10 | describe("MetricTree") { 11 | val distanceMetric = EuclideanDistanceMetric 12 | val origin = Vectors.dense(0, 0) 13 | describe("can be constructed with empty data") { 14 | val tree = MetricTree.build(IndexedSeq.empty[RowWithVector], distanceMetric=distanceMetric) 15 | it("iterator should be empty") { 16 | tree.iterator shouldBe empty 17 | } 18 | it("should return empty when 
queried") { 19 | tree.query(origin).isEmpty shouldBe true 20 | } 21 | it("should have zero leaf") { 22 | tree.leafCount shouldBe 0 23 | } 24 | } 25 | 26 | describe("without duplicates") { 27 | val data = (-5 to 5).flatMap(i => (-5 to 5).map(j => new RowWithVector(Vectors.dense(i, j), null))) 28 | List(1, data.size / 2, data.size, data.size * 2).foreach { 29 | leafSize => 30 | describe(s"with leafSize of $leafSize") { 31 | val tree = MetricTree.build(data, leafSize, distanceMetric=distanceMetric) 32 | it("should have correct size") { 33 | tree.size shouldBe data.size 34 | } 35 | it("should return an iterator that goes through all data points") { 36 | tree.iterator.toIterable should contain theSameElementsAs data 37 | } 38 | it("should return vector itself for those in input set") { 39 | data.foreach(v => tree.query(v.vector, 1).head._1 shouldBe v) 40 | } 41 | it("should return nearest neighbors correctly") { 42 | tree.query(origin, 5).map(_._1.vector.vector) should contain theSameElementsAs Set( 43 | Vectors.dense(-1, 0), 44 | Vectors.dense(1, 0), 45 | Vectors.dense(0, -1), 46 | Vectors.dense(0, 1), 47 | Vectors.dense(0, 0) 48 | ) 49 | tree.query(origin, 9).map(_._1.vector.vector) should contain theSameElementsAs 50 | (-1 to 1).flatMap(i => (-1 to 1).map(j => Vectors.dense(i, j))) 51 | } 52 | it("should have correct number of leaves") { 53 | tree.leafCount shouldBe (tree.size / leafSize.toDouble).ceil 54 | } 55 | it("all points should fall with radius of pivot") { 56 | def check(tree: Tree): Unit = { 57 | tree.iterator.foreach(node=> distanceMetric.fastDistance(node.vector, tree.pivot) <= tree.radius) 58 | tree match { 59 | case t: MetricTree => 60 | check(t.leftChild) 61 | check(t.rightChild) 62 | case _ => 63 | } 64 | } 65 | check(tree) 66 | } 67 | } 68 | } 69 | } 70 | 71 | describe("with duplicates") { 72 | val data = (Vectors.dense(2.0, 0.0) +: Array.fill(5)(Vectors.dense(0.0, 1.0))).map(new RowWithVector(_, null)) 73 | val tree = MetricTree.build(data, distanceMetric=distanceMetric) 74 | it("should have 2 leaves") { 75 | tree.leafCount shouldBe 2 76 | } 77 | it("should return all available duplicated candidates") { 78 | val res = tree.query(origin, 5).map(_._1.vector.vector) 79 | res.size shouldBe 5 80 | res.toSet should contain theSameElementsAs Array(Vectors.dense(0.0, 1.0)) 81 | } 82 | } 83 | 84 | describe("for other corner cases") { 85 | it("queryCost should work on Empty") { 86 | Empty(distanceMetric).distance(new KNNCandidates(new VectorWithNorm(origin), 1, distanceMetric)) shouldBe 0 87 | Empty(distanceMetric).distance(new VectorWithNorm(origin)) shouldBe 0 88 | } 89 | } 90 | } 91 | } 92 | -------------------------------------------------------------------------------- /spark-knn-core/src/test/scala/org/apache/spark/ml/knn/SpillTreeSpec.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.knn 2 | 3 | import org.apache.spark.ml.knn.KNN.RowWithVector 4 | import org.apache.spark.ml.linalg.Vectors 5 | import org.scalatest.funspec.AnyFunSpec 6 | import org.scalatest.matchers.should.Matchers 7 | 8 | class SpillTreeSpec extends AnyFunSpec with Matchers { 9 | describe("SpillTree") { 10 | val distanceMetric = EuclideanDistanceMetric 11 | val origin = Vectors.dense(0, 0) 12 | describe("can be constructed with empty data") { 13 | val tree = SpillTree.build(IndexedSeq.empty[RowWithVector], tau = 0.0, distanceMetric=distanceMetric) 14 | it("iterator should be empty") { 15 | tree.iterator shouldBe empty 16 | } 17 | it("should 
return empty when queried") { 18 | tree.query(origin).isEmpty shouldBe true 19 | } 20 | it("should have zero leaf") { 21 | tree.leafCount shouldBe 0 22 | } 23 | } 24 | 25 | describe("with equidistant points on a circle") { 26 | val n = 12 27 | val points = (1 to n).map { 28 | i => new RowWithVector(Vectors.dense(math.sin(2 * math.Pi * i / n), math.cos(2 * math.Pi * i / n)), null) 29 | } 30 | val leafSize = n / 4 31 | describe("built with tau = 0.0") { 32 | val tree = SpillTree.build(points, leafSize = leafSize, tau = 0.0, distanceMetric=distanceMetric) 33 | it("should have correct size") { 34 | tree.size shouldBe points.size 35 | } 36 | it("should return an iterator that goes through all data points") { 37 | tree.iterator.toIterable should contain theSameElementsAs points 38 | } 39 | it("can return more than min leaf size") { 40 | val k = leafSize + 5 41 | points.foreach(v => tree.query(v.vector, k).size shouldBe k) 42 | } 43 | } 44 | describe("built with tau = 0.5") { 45 | val tree = SpillTree.build(points, leafSize = leafSize, tau = 0.5, distanceMetric=distanceMetric) 46 | it("should have correct size") { 47 | tree.size shouldBe points.size 48 | } 49 | it("should return an iterator that goes through all data points") { 50 | tree.iterator.toIterable should contain theSameElementsAs points 51 | } 52 | it("works for every point to identify itself") { 53 | points.foreach(v => tree.query(v.vector, 1).head._1 shouldBe v) 54 | } 55 | it("has consistent size and iterator") { 56 | def check(tree: Tree): Unit = { 57 | tree match { 58 | case t: SpillTree => 59 | t.iterator.size shouldBe t.size 60 | 61 | check(t.leftChild) 62 | check(t.rightChild) 63 | case _ => 64 | } 65 | } 66 | check(tree) 67 | } 68 | } 69 | } 70 | } 71 | 72 | describe("HybridTree") { 73 | val origin = Vectors.dense(0, 0) 74 | describe("can be constructed with empty data") { 75 | val tree = HybridTree.build(IndexedSeq.empty[RowWithVector], tau = 0.0, distanceMetric=EuclideanDistanceMetric) 76 | it("iterator should be empty") { 77 | tree.iterator shouldBe empty 78 | } 79 | it("should return empty when queried") { 80 | tree.query(origin).isEmpty shouldBe true 81 | } 82 | it("should have zero leaf") { 83 | tree.leafCount shouldBe 0 84 | } 85 | } 86 | 87 | describe("with equidistant points on a circle") { 88 | val n = 12 89 | val points = (1 to n).map { 90 | i => new RowWithVector(Vectors.dense(math.sin(2 * math.Pi * i / n), math.cos(2 * math.Pi * i / n)), null) 91 | } 92 | val leafSize = n / 4 93 | val tree = HybridTree.build(points, leafSize = leafSize, tau = 0.5, distanceMetric=EuclideanDistanceMetric) 94 | it("should have correct size") { 95 | tree.size shouldBe points.size 96 | } 97 | it("should return an iterator that goes through all data points") { 98 | tree.iterator.toIterable should contain theSameElementsAs points 99 | } 100 | } 101 | } 102 | } 103 | -------------------------------------------------------------------------------- /spark-knn-core/src/test/scala/org/apache/spark/ml/regression/KNNRegressionSuite.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.regression 2 | 3 | import org.apache.spark.ml.feature.VectorAssembler 4 | import org.apache.spark.sql.SparkSession 5 | import org.scalatest.funsuite.AnyFunSuite 6 | import org.scalatest.matchers.should.Matchers 7 | 8 | class KNNRegressionSuite extends AnyFunSuite with Matchers { 9 | 10 | val spark = SparkSession.builder() 11 | .master("local") 12 | .getOrCreate() 13 | 14 | import spark.implicits._ 15 
| val rawDF1 = Seq( 16 | (9.5,37.48596719,-122.2428196, 1.0), 17 | (9.5,37.49115273,-122.2319523, 2.0), 18 | (9.5,37.49099652,-122.2324551, 3.0), 19 | (9.5,37.4886712,-122.2348786, 1.0), 20 | (9.5,37.48696518,-122.2384678, 3.0), 21 | (9.5,37.48473396,-122.2345444, 3.0), 22 | (9.5,37.48565758,-122.2412995, 2.0), 23 | (9.5,37.48033504,-122.2364642, 2.0) 24 | ).toDF("col1", "col2", "col3", "label") 25 | val rawDF2 = Seq( 26 | (9.5,37.48495049,-122.2335112) 27 | ).toDF("col1", "col2", "col3") 28 | 29 | val assembler = new VectorAssembler() 30 | .setInputCols(Array("col1", "col2", "col3")) 31 | .setOutputCol("features") 32 | val trainDF = assembler.transform(rawDF1) 33 | val testDF = assembler.transform(rawDF2) 34 | 35 | test("KNNRegression can be fitted using euclidean distance") { 36 | val knnr = new KNNRegression() 37 | .setTopTreeSize(5) 38 | .setFeaturesCol("features") 39 | .setPredictionCol("prediction") 40 | .setLabelCol("label") 41 | .setSeed(31) 42 | .setK(5) 43 | val knnModel = knnr.fit(trainDF) 44 | val outputDF = knnModel.transform(testDF) 45 | val predictions = outputDF.collect().map(_.getAs[Double]("prediction")) 46 | predictions shouldEqual Array(2.4) 47 | } 48 | 49 | test("KNNRegression can be fitted using nan_euclidean distance") { 50 | val knnr = new KNNRegression() 51 | .setTopTreeSize(5) 52 | .setFeaturesCol("features") 53 | .setPredictionCol("prediction") 54 | .setLabelCol("label") 55 | .setSeed(31) 56 | .setK(5) 57 | knnr.set(knnr.metric, "nan_euclidean") 58 | val knnModel = knnr.fit(trainDF) 59 | val outputDF = knnModel.transform(testDF) 60 | val predictions = outputDF.collect().map(_.getAs[Double]("prediction")) 61 | predictions shouldEqual Array(2.4) 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /spark-knn-examples/src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # Set everything to be logged to the console 2 | log4j.rootCategory=INFO, console 3 | log4j.appender.console=org.apache.log4j.ConsoleAppender 4 | log4j.appender.console.target=System.err 5 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 6 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 7 | 8 | # Settings to quiet third party logs that are too verbose 9 | log4j.logger.org.spark-project.jetty=WARN 10 | log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR 11 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO 12 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO 13 | log4j.logger.org.apache.parquet=ERROR 14 | log4j.logger.parquet=ERROR 15 | 16 | log4j.logger.org.apache.spark=WARN 17 | log4j.logger.org.apache.spark.sql=INFO 18 | log4j.logger.org.apache.spark.ml=INFO 19 | log4j.logger.org.apache.spark.mllib=INFO 20 | 21 | # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support 22 | log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL 23 | log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR 24 | -------------------------------------------------------------------------------- /spark-knn-examples/src/main/scala/com/github/saurfang/spark/ml/knn/examples/MNIST.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.ml.knn.examples 2 | 3 | import org.apache.spark.ml.Pipeline 4 | import 
org.apache.spark.ml.classification.KNNClassifier 5 | import org.apache.spark.ml.feature.PCA 6 | import org.apache.spark.mllib.util.MLUtils 7 | import org.apache.spark.sql.{DataFrame, SparkSession} 8 | import org.apache.log4j 9 | 10 | object MNIST { 11 | 12 | val logger = log4j.Logger.getLogger(getClass) 13 | 14 | def main(args: Array[String]) { 15 | val spark = SparkSession.builder().getOrCreate() 16 | val sc = spark.sparkContext 17 | import spark.implicits._ 18 | 19 | //read in raw label and features 20 | val rawDataset = MLUtils.loadLibSVMFile(sc, "data/mnist/mnist.bz2") 21 | .toDF() 22 | // convert "features" from mllib.linalg.Vector to ml.linalg.Vector 23 | val dataset = MLUtils.convertVectorColumnsToML(rawDataset) 24 | 25 | //split training and testing 26 | val Array(train, test) = dataset 27 | .randomSplit(Array(0.7, 0.3), seed = 1234L) 28 | .map(_.cache()) 29 | 30 | //create PCA matrix to reduce feature dimensions 31 | val pca = new PCA() 32 | .setInputCol("features") 33 | .setK(50) 34 | .setOutputCol("pcaFeatures") 35 | val knn = new KNNClassifier() 36 | .setTopTreeSize(dataset.count().toInt / 500) 37 | .setFeaturesCol("pcaFeatures") 38 | .setPredictionCol("predicted") 39 | .setK(1) 40 | val pipeline = new Pipeline() 41 | .setStages(Array(pca, knn)) 42 | .fit(train) 43 | 44 | val insample = validate(pipeline.transform(train)) 45 | val outofsample = validate(pipeline.transform(test)) 46 | 47 | //reference accuracy: in-sample 95% out-of-sample 94% 48 | logger.info(s"In-sample: $insample, Out-of-sample: $outofsample") 49 | } 50 | 51 | private[this] def validate(results: DataFrame): Double = { 52 | results 53 | .selectExpr("SUM(CASE WHEN label = predicted THEN 1.0 ELSE 0.0 END) / COUNT(1)") 54 | .collect() 55 | .head 56 | .getDecimal(0) 57 | .doubleValue() 58 | } 59 | } 60 | -------------------------------------------------------------------------------- /spark-knn-examples/src/main/scala/com/github/saurfang/spark/ml/knn/examples/MNISTBenchmark.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.ml.knn.examples 2 | 3 | import org.apache.spark.annotation.DeveloperApi 4 | import org.apache.spark.ml.classification.{KNNClassifier, NaiveKNNClassifier} 5 | import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator 6 | import org.apache.spark.ml.param.{IntParam, ParamMap} 7 | import org.apache.spark.ml.tuning.{Benchmarker, ParamGridBuilder} 8 | import org.apache.spark.ml.util.Identifiable 9 | import org.apache.spark.ml.{Pipeline, Transformer} 10 | import org.apache.spark.mllib.util.MLUtils 11 | import org.apache.spark.sql.types.StructType 12 | import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} 13 | import org.apache.log4j 14 | 15 | import scala.collection.mutable 16 | 17 | /** 18 | * Benchmark KNN as a function of number of observations 19 | */ 20 | object MNISTBenchmark { 21 | 22 | val logger = log4j.Logger.getLogger(getClass) 23 | 24 | def main(args: Array[String]) { 25 | val ns = if(args.isEmpty) (2500 to 10000 by 2500).toArray else args(0).split(',').map(_.toInt) 26 | val path = if(args.length >= 2) args(1) else "data/mnist/mnist.bz2" 27 | val numPartitions = if(args.length >= 3) args(2).toInt else 10 28 | val models = if(args.length >=4) args(3).split(',') else Array("tree","naive") 29 | 30 | val spark = SparkSession.builder().getOrCreate() 31 | val sc = spark.sparkContext 32 | import spark.implicits._ 33 | 34 | //read in raw label and features 35 | val rawDataset = 
MLUtils.loadLibSVMFile(sc, path) 36 | .zipWithIndex() 37 | .filter(_._2 < ns.max) 38 | .sortBy(_._2, numPartitions = numPartitions) 39 | .keys 40 | .toDF() 41 | 42 | // convert "features" from mllib.linalg.Vector to ml.linalg.Vector 43 | val dataset = MLUtils.convertVectorColumnsToML(rawDataset) 44 | .cache() 45 | dataset.count() //force persist 46 | 47 | val limiter = new Limiter() 48 | val knn = new KNNClassifier() 49 | .setTopTreeSize(numPartitions * 10) 50 | .setFeaturesCol("features") 51 | .setPredictionCol("prediction") 52 | .setK(1) 53 | val naiveKNN = new NaiveKNNClassifier() 54 | 55 | val pipeline = new Pipeline() 56 | .setStages(Array(limiter, knn)) 57 | val naivePipeline = new Pipeline() 58 | .setStages(Array(limiter, naiveKNN)) 59 | 60 | val paramGrid = new ParamGridBuilder() 61 | .addGrid(limiter.n, ns) 62 | .build() 63 | 64 | val bm = new Benchmarker() 65 | .setEvaluator(new MulticlassClassificationEvaluator) 66 | .setEstimatorParamMaps(paramGrid) 67 | .setNumTimes(3) 68 | 69 | val metrics = mutable.ArrayBuffer[String]() 70 | if(models.contains("tree")) { 71 | val bmModel = bm.setEstimator(pipeline).fit(dataset) 72 | metrics += s"knn: ${bmModel.avgTrainingRuntimes.toSeq} / ${bmModel.avgEvaluationRuntimes.toSeq}" 73 | } 74 | if(models.contains("naive")) { 75 | val naiveBMModel = bm.setEstimator(naivePipeline).fit(dataset) 76 | metrics += s"naive: ${naiveBMModel.avgTrainingRuntimes.toSeq} / ${naiveBMModel.avgEvaluationRuntimes.toSeq}" 77 | } 78 | logger.info(metrics.mkString("\n")) 79 | } 80 | } 81 | 82 | class Limiter(override val uid: String) extends Transformer { 83 | def this() = this(Identifiable.randomUID("limiter")) 84 | 85 | val n: IntParam = new IntParam(this, "n", "number of rows to limit") 86 | 87 | def setN(value: Int): this.type = set(n, value) 88 | 89 | // hack to maintain number of partitions (otherwise it collapses to 1 which is unfair for naiveKNN) 90 | override def transform(dataset: Dataset[_]): DataFrame = dataset.limit($(n)).repartition(dataset.rdd.partitions.length).toDF() 91 | 92 | override def copy(extra: ParamMap): Transformer = defaultCopy(extra) 93 | 94 | @DeveloperApi 95 | override def transformSchema(schema: StructType): StructType = schema 96 | } 97 | -------------------------------------------------------------------------------- /spark-knn-examples/src/main/scala/com/github/saurfang/spark/ml/knn/examples/MNISTCrossValidation.scala: -------------------------------------------------------------------------------- 1 | package com.github.saurfang.spark.ml.knn.examples 2 | 3 | import org.apache.spark.ml.Pipeline 4 | import org.apache.spark.ml.classification.KNNClassifier 5 | import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator 6 | import org.apache.spark.ml.feature.PCA 7 | import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} 8 | import org.apache.spark.mllib.util.MLUtils 9 | import org.apache.spark.sql.{DataFrame, SparkSession} 10 | import org.apache.log4j 11 | 12 | object MNISTCrossValidation { 13 | 14 | val logger = log4j.Logger.getLogger(getClass) 15 | 16 | def main(args: Array[String]) { 17 | val spark = SparkSession.builder().getOrCreate() 18 | val sc = spark.sparkContext 19 | import spark.implicits._ 20 | 21 | //read in raw label and features 22 | val dataset = MLUtils.loadLibSVMFile(sc, "data/mnist/mnist.bz2") 23 | .toDF() 24 | //.limit(10000) 25 | 26 | //split traning and testing 27 | val Array(train, test) = dataset.randomSplit(Array(0.7, 0.3), seed = 1234L).map(_.cache()) 28 | 29 | //create PCA matrix to 
reduce feature dimensions 30 | val pca = new PCA() 31 | .setInputCol("features") 32 | .setK(50) 33 | .setOutputCol("pcaFeatures") 34 | val knn = new KNNClassifier() 35 | .setTopTreeSize(50) 36 | .setFeaturesCol("pcaFeatures") 37 | .setPredictionCol("prediction") 38 | .setK(1) 39 | 40 | val pipeline = new Pipeline() 41 | .setStages(Array(pca, knn)) 42 | 43 | val paramGrid = new ParamGridBuilder() 44 | // .addGrid(knn.k, 1 to 20) 45 | .addGrid(pca.k, 10 to 100 by 10) 46 | .build() 47 | 48 | val cv = new CrossValidator() 49 | .setEstimator(pipeline) 50 | .setEvaluator(new MulticlassClassificationEvaluator) 51 | .setEstimatorParamMaps(paramGrid) 52 | .setNumFolds(5) 53 | 54 | val cvModel = cv.fit(train) 55 | 56 | val insample = validate(cvModel.transform(train)) 57 | val outofsample = validate(cvModel.transform(test)) 58 | 59 | //reference accuracy: in-sample 95% out-of-sample 94% 60 | logger.info(s"In-sample: $insample, Out-of-sample: $outofsample") 61 | logger.info(s"Cross-validated: ${cvModel.avgMetrics.toSeq}") 62 | } 63 | 64 | private[this] def validate(results: DataFrame): Double = { 65 | results 66 | .selectExpr("SUM(CASE WHEN label = prediction THEN 1.0 ELSE 0.0 END) / COUNT(1)") 67 | .collect() 68 | .head 69 | .getDecimal(0) 70 | .doubleValue() 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /spark-knn-examples/src/main/scala/org/apache/spark/ml/classification/NaiveKNN.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.classification 2 | 3 | import org.apache.spark.SparkException 4 | import org.apache.spark.ml.knn.KNN.{RowWithVector, VectorWithNorm} 5 | import org.apache.spark.ml.knn.{DistanceMetric, EuclideanDistanceMetric, KNNModel, KNNParams} 6 | import org.apache.spark.ml.param.ParamMap 7 | import org.apache.spark.ml.util.{Identifiable, SchemaUtils} 8 | import org.apache.spark.ml.{Model, Predictor} 9 | import org.apache.spark.ml.linalg._ 10 | import org.apache.spark.ml.feature.LabeledPoint 11 | import org.apache.spark.rdd.RDD 12 | import org.apache.spark.sql.types.{ArrayType, DoubleType, StructType} 13 | import org.apache.spark.sql.{DataFrame, Dataset, Row} 14 | import org.apache.spark.storage.StorageLevel 15 | import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ 16 | 17 | import scala.collection.mutable.ArrayBuffer 18 | 19 | /** 20 | * Brute-force kNN with k = 1 21 | */ 22 | class NaiveKNNClassifier(override val uid: String, val distanceMetric: DistanceMetric) 23 | extends Predictor[Vector, NaiveKNNClassifier, NaiveKNNClassifierModel] { 24 | def this() = this(Identifiable.randomUID("naiveknnc"), EuclideanDistanceMetric) 25 | 26 | override def copy(extra: ParamMap): NaiveKNNClassifier = defaultCopy(extra) 27 | 28 | override protected def train(dataset: Dataset[_]): NaiveKNNClassifierModel = { 29 | // Extract columns from data. If dataset is persisted, do not persist oldDataset. 
30 | val instances = extractLabeledPoints(dataset).map { 31 | case LabeledPoint(label: Double, features: Vector) => (label, features) 32 | } 33 | val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE 34 | if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) 35 | 36 | val labelSummarizer = instances.treeAggregate(new MultiClassSummarizer)( 37 | seqOp = (c, v) => (c, v) match { 38 | case (labelSummarizer: MultiClassSummarizer, (label: Double, features: Vector)) => 39 | labelSummarizer.add(label) 40 | }, 41 | combOp = (c1, c2) => (c1, c2) match { 42 | case (classSummarizer1: MultiClassSummarizer, classSummarizer2: MultiClassSummarizer) => 43 | classSummarizer1.merge(classSummarizer2) 44 | }) 45 | 46 | val histogram = labelSummarizer.histogram 47 | val numInvalid = labelSummarizer.countInvalid 48 | val numClasses = histogram.length 49 | 50 | if (numInvalid != 0) { 51 | val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + 52 | s"Found $numInvalid invalid labels." 53 | logError(msg) 54 | throw new SparkException(msg) 55 | } 56 | 57 | val points = instances.map{ 58 | case (label, features) => (label, new VectorWithNorm(features)) 59 | } 60 | 61 | new NaiveKNNClassifierModel(uid, points, numClasses, distanceMetric) 62 | } 63 | 64 | } 65 | 66 | class NaiveKNNClassifierModel( 67 | override val uid: String, 68 | val points: RDD[(Double, VectorWithNorm)], 69 | val _numClasses: Int, 70 | val distanceMetric: DistanceMetric) extends ProbabilisticClassificationModel[Vector, NaiveKNNClassifierModel] { 71 | override def numClasses: Int = _numClasses 72 | 73 | override def transform(dataset: Dataset[_]): DataFrame = { 74 | import dataset.sparkSession.implicits._ 75 | 76 | val features = dataset.select($(featuresCol)) 77 | .map(r => new VectorWithNorm(r.getAs[Vector](0))) 78 | 79 | val merged = features.rdd.zipWithUniqueId() 80 | .cartesian(points) 81 | .map { 82 | case ((u, i), (label, v)) => 83 | val dist = distanceMetric.fastSquaredDistance(u, v) 84 | (i, (dist, label)) 85 | } 86 | .topByKey(1)(Ordering.by(e => -e._1)) 87 | .map{ 88 | case (id, labels) => 89 | val vector = new Array[Double](numClasses) 90 | var i = 0 91 | while (i < labels.length) { 92 | vector(labels(i)._2.toInt) += 1 93 | i += 1 94 | } 95 | val rawPrediction = Vectors.dense(vector) 96 | lazy val probability = raw2probability(rawPrediction) 97 | lazy val prediction = probability2prediction(probability) 98 | 99 | val values = new ArrayBuffer[Any] 100 | if ($(rawPredictionCol).nonEmpty) { 101 | values.append(rawPrediction) 102 | } 103 | if ($(probabilityCol).nonEmpty) { 104 | values.append(probability) 105 | } 106 | if ($(predictionCol).nonEmpty) { 107 | values.append(prediction) 108 | } 109 | 110 | (id, values.toSeq) 111 | } 112 | 113 | dataset.sqlContext.createDataFrame( 114 | dataset.toDF().rdd.zipWithUniqueId().map { case (row, i) => (i, row) } 115 | .leftOuterJoin(merged) //make sure we don't lose any observations 116 | .map { 117 | case (i, (row, values)) => Row.fromSeq(row.toSeq ++ values.get) 118 | }, 119 | transformSchema(dataset.schema) 120 | ) 121 | } 122 | 123 | override def transformSchema(schema: StructType): StructType = { 124 | var transformed = schema 125 | if ($(rawPredictionCol).nonEmpty) { 126 | transformed = SchemaUtils.appendColumn(transformed, $(rawPredictionCol), new VectorUDT) 127 | } 128 | if ($(probabilityCol).nonEmpty) { 129 | transformed = SchemaUtils.appendColumn(transformed, $(probabilityCol), new VectorUDT) 130 | } 131 | if ($(predictionCol).nonEmpty) 
{ 132 | transformed = SchemaUtils.appendColumn(transformed, $(predictionCol), DoubleType) 133 | } 134 | transformed 135 | } 136 | 137 | override def copy(extra: ParamMap): NaiveKNNClassifierModel = { 138 | val copied = new NaiveKNNClassifierModel(uid, points, numClasses, distanceMetric) 139 | copyValues(copied, extra).setParent(parent) 140 | } 141 | 142 | override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { 143 | rawPrediction match { 144 | case dv: DenseVector => 145 | var i = 0 146 | val size = dv.size 147 | 148 | var sum = 0.0 149 | while (i < size) { 150 | sum += dv.values(i) 151 | i += 1 152 | } 153 | 154 | i = 0 155 | while (i < size) { 156 | dv.values(i) /= sum 157 | i += 1 158 | } 159 | 160 | dv 161 | case sv: SparseVector => 162 | throw new RuntimeException("Unexpected error in KNNClassificationModel:" + 163 | " raw2probabilitiesInPlace encountered SparseVector") 164 | } 165 | } 166 | 167 | override def predictRaw(features: Vector): Vector = { 168 | throw new SparkException("predictRaw function should not be called directly since kNN prediction is done in distributed fashion. Use transform instead.") 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /spark-knn-examples/src/main/scala/org/apache/spark/ml/tuning/Benchmarker.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.ml.tuning 2 | 3 | import com.github.fommil.netlib.F2jBLAS 4 | import org.apache.spark.annotation.Experimental 5 | import org.apache.spark.ml.evaluation.Evaluator 6 | import org.apache.spark.ml.param._ 7 | import org.apache.spark.ml.util.Identifiable 8 | import org.apache.spark.ml.{Estimator, Model} 9 | import org.apache.spark.sql.{DataFrame, Dataset} 10 | import org.apache.spark.sql.types.StructType 11 | 12 | /** 13 | * Params for [[Benchmarker]] and [[BenchmarkModel]]. 14 | */ 15 | private[ml] trait BenchmarkerParams extends ValidatorParams { 16 | /** 17 | * Param for number of times for benchmark. Must be >= 1. 18 | * Default: 1 19 | * @group param 20 | */ 21 | val numTimes: IntParam = new IntParam(this, "numTimes", 22 | "number of times for benchmark (>= 1)", ParamValidators.gtEq(1)) 23 | 24 | /** @group getParam */ 25 | def getNumTimes: Int = $(numTimes) 26 | 27 | setDefault(numTimes -> 1) 28 | } 29 | 30 | /** 31 | * :: Experimental :: 32 | * Benchmark estimator pipelines. 
33 | */ 34 | @Experimental 35 | class Benchmarker(override val uid: String) extends Estimator[BenchmarkModel] 36 | with BenchmarkerParams { 37 | 38 | def this() = this(Identifiable.randomUID("benchmark")) 39 | 40 | private val f2jBLAS = new F2jBLAS 41 | 42 | /** @group setParam */ 43 | def setEstimator(value: Estimator[_]): this.type = set(estimator, value) 44 | 45 | /** @group setParam */ 46 | def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) 47 | 48 | /** @group setParam */ 49 | def setEvaluator(value: Evaluator): this.type = set(evaluator, value) 50 | 51 | /** @group setParam */ 52 | def setNumTimes(value: Int): this.type = set(numTimes, value) 53 | 54 | override def fit(dataset: Dataset[_]): BenchmarkModel = { 55 | val schema = dataset.schema 56 | transformSchema(schema, logging = true) 57 | val sqlCtx = dataset.sqlContext 58 | val est = $(estimator) 59 | val eval = $(evaluator) 60 | val epm = $(estimatorParamMaps) 61 | val numModels = epm.length 62 | val models = new Array[Model[_]](epm.length) 63 | val trainingRuntimes = new Array[Double](epm.length) 64 | val evaluationRuntimes = new Array[Double](epm.length) 65 | (1 to getNumTimes).foreach { index => 66 | // multi-model training 67 | logDebug(s"Train $index times with multiple sets of parameters.") 68 | var i = 0 69 | while (i < numModels) { 70 | var tic = System.currentTimeMillis() 71 | models(i) = est.fit(dataset, epm(i)).asInstanceOf[Model[_]] 72 | trainingRuntimes(i) += System.currentTimeMillis() - tic 73 | 74 | tic = System.currentTimeMillis() 75 | val metric = eval.evaluate(models(i).transform(dataset, epm(i))) 76 | evaluationRuntimes(i) += System.currentTimeMillis() - tic 77 | 78 | logDebug(s"Got metric $metric for model trained with ${epm(i)}.") 79 | i += 1 80 | } 81 | } 82 | 83 | f2jBLAS.dscal(numModels, 1.0 / $(numTimes), trainingRuntimes, 1) 84 | f2jBLAS.dscal(numModels, 1.0 / $(numTimes), evaluationRuntimes, 1) 85 | logInfo(s"Average training runtimes: ${trainingRuntimes.toSeq}") 86 | logInfo(s"Average evaluation runtimes: ${evaluationRuntimes.toSeq}") 87 | val (fastestRuntime, fastestIndex) = trainingRuntimes.zipWithIndex.minBy(_._1) 88 | logInfo(s"Fastest set of parameters:\n${epm(fastestIndex)}") 89 | logInfo(s"Fastest training runtime: $fastestRuntime.") 90 | 91 | copyValues(new BenchmarkModel(uid, models(fastestIndex), trainingRuntimes, evaluationRuntimes).setParent(this)) 92 | } 93 | 94 | override def transformSchema(schema: StructType): StructType = { 95 | validateParams() 96 | $(estimator).transformSchema(schema) 97 | } 98 | 99 | def validateParams(): Unit = { 100 | val est = $(estimator) 101 | for (paramMap <- $(estimatorParamMaps)) { 102 | est.copy(paramMap) 103 | } 104 | } 105 | 106 | override def copy(extra: ParamMap): Benchmarker = { 107 | val copied = defaultCopy(extra).asInstanceOf[Benchmarker] 108 | if (copied.isDefined(estimator)) { 109 | copied.setEstimator(copied.getEstimator.copy(extra)) 110 | } 111 | if (copied.isDefined(evaluator)) { 112 | copied.setEvaluator(copied.getEvaluator.copy(extra)) 113 | } 114 | copied 115 | } 116 | } 117 | 118 | /** 119 | * :: Experimental :: 120 | * Model from benchmark runs. 
121 | */ 122 | @Experimental 123 | class BenchmarkModel private[ml]( 124 | override val uid: String, 125 | val fastestModel: Model[_], 126 | val avgTrainingRuntimes: Array[Double], 127 | val avgEvaluationRuntimes: Array[Double]) 128 | extends Model[BenchmarkModel] with BenchmarkerParams { 129 | 130 | override def transform(dataset: Dataset[_]): DataFrame = { 131 | transformSchema(dataset.schema, logging = true) 132 | fastestModel.transform(dataset) 133 | } 134 | 135 | override def transformSchema(schema: StructType): StructType = { 136 | fastestModel.transformSchema(schema) 137 | } 138 | 139 | override def copy(extra: ParamMap): BenchmarkModel = { 140 | val copied = new BenchmarkModel( 141 | uid, 142 | fastestModel.copy(extra).asInstanceOf[Model[_]], 143 | avgTrainingRuntimes.clone(), 144 | avgEvaluationRuntimes.clone()) 145 | copyValues(copied, extra).setParent(parent) 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /spark-knn.Rproj: -------------------------------------------------------------------------------- 1 | Version: 1.0 2 | 3 | RestoreWorkspace: Default 4 | SaveWorkspace: Default 5 | AlwaysSaveHistory: Default 6 | 7 | EnableCodeIndexing: Yes 8 | UseSpacesForTab: Yes 9 | NumSpacesForTab: 2 10 | Encoding: UTF-8 11 | 12 | RnwWeave: Sweave 13 | LaTeX: pdfLaTeX 14 | --------------------------------------------------------------------------------
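A minimal usage sketch of the KNNClassifier API exercised in the sources above, assuming a local SparkSession and a toy grid of points similar to the one built in KNNSuite; the object name KNNQuickStart, the derived "label" column, and the parameter values are illustrative assumptions, and only setters that appear in the code above are used.

import org.apache.spark.ml.classification.KNNClassifier
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, when}

object KNNQuickStart {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    // Toy grid of points, mirroring the data used in KNNSuite.
    val raw = (-10 to 10).flatMap(i => (-10 to 10).map(j => (i.toDouble, j.toDouble))).toDF("x", "y")
    val assembled = new VectorAssembler()
      .setInputCols(Array("x", "y"))
      .setOutputCol("features")
      .transform(raw)

    // Illustrative binary label: 1.0 for non-negative x, 0.0 otherwise.
    val train = assembled.withColumn("label", when(col("x") >= 0, 1.0).otherwise(0.0))

    // Same setters as the MNIST example and KNNSuite; the values here are illustrative.
    val knn = new KNNClassifier()
      .setTopTreeSize(train.count().toInt / 10)
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
      .setK(5)

    val model = knn.fit(train)
    model.transform(train).select("x", "y", "prediction").show(5)

    spark.stop()
  }
}

As in the MNIST example, the classifier can also be placed after a feature transformer such as PCA inside a Pipeline; only the features column name needs to match.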