├── .gitignore ├── .travis.yaml ├── LICENSE ├── README.md ├── bin └── click-through-rate-predictino.sh ├── build.sbt ├── build ├── sbt └── sbt-launch-lib.bash ├── data └── mllib │ ├── sample_libsvm_data.txt │ └── sample_linear_regression_data.txt ├── dev ├── lint-scala └── sbt-assembly-skip-test.sh ├── project ├── assembly.sbt ├── build.properties └── plugins.sbt ├── scalastyle-config.xml └── src ├── main └── scala │ └── org │ └── apache │ └── spark │ └── examples │ ├── kaggle │ ├── AvazuClickThroughRatePrediction.scala │ ├── AvitoContextAdClicks.scala │ ├── CriteoCtrPrediction.scala │ ├── SanFranciscoCrimeClassification.scala │ └── SpringleafMarketingResponse.scala │ └── ml │ ├── AbstractParams.scala │ ├── DataFrameExample.scala │ ├── GradientBoostedTreeClassifierExample.scala │ ├── GradientBoostedTreeRegressorExample.scala │ ├── LogisticRegressionExample.scala │ ├── LogisticRegressionSummaryExample.scala │ ├── LogisticRegressionWithElasticNetExample.scala │ ├── ModelSelectionViaCrossValidationExample.scala │ ├── ModelSelectionViaTrainValidationSplitExample.scala │ ├── OneHotEncoderExample.scala │ ├── PipelineExample.scala │ ├── StringIndexerExample.scala │ └── VectorIndexerExample.scala └── test ├── resources ├── test.part-10000 └── train.part-10000 └── scala └── org └── apache └── spark ├── SparkFunSuite.scala ├── examples └── kaggle │ └── ClickThroughRatePrediction.scala └── util ├── LocalClusterSparkContext.scala └── MLlibTestSparkContext.scala /.gitignore: -------------------------------------------------------------------------------- 1 | /bin/ 2 | .cache-main 3 | .cache-tests 4 | .classpath 5 | .project 6 | .settings 7 | project/project 8 | project/target 9 | target 10 | *.swp -------------------------------------------------------------------------------- /.travis.yaml: -------------------------------------------------------------------------------- 1 | language: scala 2 | sudo: false 3 | cache: 4 | directories: 5 | - $HOME/.ivy2 6 | matrix: 7 | include: 8 | - jdk: openjdk7 9 | scala: 2.10.5 10 | env: TEST_SPARK_VERSION="1.6.0" 11 | - jdk: openjdk8 12 | scala: 2.10.5 13 | env: TEST_SPARK_VERSION="1.6.0" 14 | - jdk: openjdk7 15 | scala: 2.11.7 16 | env: TEST_SPARK_VERSION="1.6.0" 17 | - jdk: openjdk8 18 | scala: 2.11.7 19 | env: TEST_SPARK_VERSION="1.6.0" 20 | script: 21 | - sbt -Dspark.testVersion=$TEST_SPARK_VERSION ++$TRAVIS_SCALA_VERSION coverage test 22 | - sbt ++$TRAVIS_SCALA_VERSION scalastyle 23 | after_success: 24 | - bash <(curl -s https://codecov.io/bash) 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License, Version 2.0 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 13 | 14 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 15 | 16 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 17 | 18 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 19 | 20 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 21 | 22 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 23 | 24 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 25 | 26 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 27 | 28 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 29 | 30 | 2. Grant of Copyright License. 31 | 32 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 33 | 34 | 3. Grant of Patent License. 
35 | 36 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 37 | 38 | 4. Redistribution. 39 | 40 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 41 | 42 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 43 | You must cause any modified files to carry prominent notices stating that You changed the files; and 44 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 45 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 46 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 47 | 48 | 5. Submission of Contributions. 49 | 50 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 51 | 52 | 6. Trademarks. 
53 | 54 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 55 | 56 | 7. Disclaimer of Warranty. 57 | 58 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 59 | 60 | 8. Limitation of Liability. 61 | 62 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 63 | 64 | 9. Accepting Warranty or Additional Liability. 65 | 66 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 67 | 68 | END OF TERMS AND CONDITIONS 69 | 70 | APPENDIX: How to apply the Apache License to your work 71 | 72 | To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 73 | 74 | Copyright [yyyy] [name of copyright owner] 75 | 76 | Licensed under the Apache License, Version 2.0 (the "License"); 77 | you may not use this file except in compliance with the License. 78 | You may obtain a copy of the License at 79 | 80 | http://www.apache.org/licenses/LICENSE-2.0 81 | 82 | Unless required by applicable law or agreed to in writing, software 83 | distributed under the License is distributed on an "AS IS" BASIS, 84 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 85 | See the License for the specific language governing permissions and 86 | limitations under the License. 
87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Try Kaggle's Click Through Rate Prediction with Spark Pipeline API 2 | 3 | https://issues.apache.org/jira/browse/SPARK-9941 4 | 5 | The purpose of this Spark application is to exercise the Spark ML Pipeline API with real data for [SPARK-13239](https://issues.apache.org/jira/browse/SPARK-13239). 6 | So we tested the ML Pipeline API with Kaggle's click-through rate prediction competition. 7 | 8 | 9 | ## Build & Run 10 | You can build this Spark application with `sbt clean assembly`. 11 | And you can run it with the following command. 12 | 13 | ``` 14 | $SPARK_HOME/bin/spark-submit \ 15 | --class org.apache.spark.examples.kaggle.AvazuClickThroughRatePrediction \ 16 | /path/to/click-through-rate-prediction-assembly-1.1.jar \ 17 | --train=/path/to/train \ 18 | --test=/path/to/test \ 19 | --result=/path/to/result.csv 20 | ``` 21 | 22 | - `--train`: the training data you downloaded 23 | - `--test`: the test data you downloaded 24 | - `--result`: the output path for the result 25 | 26 | Note that Spark can't write a single CSV file directly. 27 | However, by repartitioning the result DataFrame into a single partition, this application writes the result as one file. 28 | So you can get the result CSV file as `part-00000` under the path you set with the `--result` option. 29 | 30 | ## The Kaggle Contest 31 | 32 | > Predict whether a mobile ad will be clicked 33 | > In online advertising, click-through rate (CTR) is a very important metric for evaluating ad performance. As a result, click prediction systems are essential and widely used for sponsored search and real-time bidding. 34 | 35 | https://www.kaggle.com/c/avazu-ctr-prediction 36 | 37 | 38 | ## Approach 39 | 40 | 1. Extract features from the categorical variables with `StringIndexer` followed by `OneHotEncoder` 41 | 2. Train a `LogisticRegression` model with `CrossValidator` 42 | - The `Evaluator` of the `CrossValidator` is the default `BinaryClassificationEvaluator`. 43 | 44 | We merged the training data with the test data in the feature-extraction phase, 45 | because the test data includes values that don't exist in the training data. 46 | Merging them avoids errors about unseen values in each variable when extracting features from the test data. 47 | 48 | ## Result 49 | 50 | We got a score of `0.3998684` with the following parameter set. 51 | 52 | - Logistic Regression 53 | - `featuresCol`: features 54 | - `fitIntercept`: true 55 | - `labelCol`: label 56 | - `maxIter`: 100 57 | - `predictionCol`: prediction 58 | - `probabilityCol`: probability 59 | - `rawPredictionCol`: rawPrediction 60 | - `regParam`: 0.001 61 | - `standardization`: true 62 | - `threshold`: 0.22 63 | - `tol`: 1.0E-6 64 | - `weightCol`: 65 | 66 | ## TODO 67 | 68 | We should offer more `Evaluator`s, such as log-loss. 69 | Since `spark.ml` doesn't offer a logistic-loss (log-loss) evaluator as of Spark 1.6, we might get a better score with one. 70 | -------------------------------------------------------------------------------- /bin/click-through-rate-predictino.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # 4 | # Licensed to the Apache Software Foundation (ASF) under one or more 5 | # contributor license agreements. See the NOTICE file distributed with 6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0 8 | # (the "License"); you may not use this file except in compliance with 9 | # the License. You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | 20 | THIS_PROJECT_HOME="$( cd $(dirname $( dirname $0)) && pwd )" 21 | 22 | PACKAGED_JAR=$(find ${THIS_PROJECT_HOME}/target -name "click-through-rate-prediction-assembly*.jar") 23 | 24 | ${SPARK_HOME}/bin/spark-submit \ 25 | --class "org.apache.spark.examples.kaggle.ClickThroughRatePrediction" \ 26 | "$PACKAGED_JAR" \ 27 | --train "s3n://s3-yu-ishikawa/test-data/click-through-rate-prediction/train" \ 28 | --test "s3n://s3-yu-ishikawa/test-data/click-through-rate-prediction/test" \ 29 | --result "s3n://s3-yu-ishikawa/test-data/click-through-rate-prediction/result/" 30 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | // Your sbt build file. Guides on how to write one can be found at 2 | // http://www.scala-sbt.org/0.13/docs/index.html 3 | 4 | scalaVersion := "2.10.5" 5 | 6 | sparkVersion := "1.6.0" 7 | 8 | crossScalaVersions := Seq("2.10.5", "2.11.7") 9 | 10 | spName := "yu-iskw/click-through-rate-prediction" 11 | 12 | // Don't forget to set the version 13 | version := "1.1" 14 | 15 | spAppendScalaVersion := true 16 | 17 | spIncludeMaven := true 18 | 19 | spIgnoreProvided := true 20 | 21 | // Can't parallelly execute in test 22 | parallelExecution in Test := false 23 | 24 | fork in Test := true 25 | 26 | javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=256m") 27 | 28 | // All Spark Packages need a license 29 | licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")) 30 | 31 | // Add Spark components this package depends on, e.g, "mllib", .... 32 | sparkComponents ++= Seq("sql", "mllib") 33 | 34 | libraryDependencies ++= Seq( 35 | "org.scalatest" %% "scalatest" % "2.1.5" % "test", 36 | "com.github.scopt" % "scopt_2.10" % "3.3.0" 37 | ) 38 | 39 | 40 | // uncomment and change the value below to change the directory where your zip artifact will be created 41 | // spDistDirectory := target.value 42 | 43 | // add any Spark Package dependencies using spDependencies. 44 | // e.g. spDependencies += "databricks/spark-avro:0.1" 45 | spDependencies += "databricks/spark-csv:1.3.0-s_2.10" 46 | 47 | // addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.1.0-RC1") 48 | 49 | // EclipseKeys.createSrc := EclipseCreateSrc.Default + EclipseCreateSrc.Resource -------------------------------------------------------------------------------- /build/sbt: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so 4 | # that we can run Hive to generate the golden answer. This is not required for normal development 5 | # or testing. 
6 | for i in $HIVE_HOME/lib/* 7 | do HADOOP_CLASSPATH=$HADOOP_CLASSPATH:$i 8 | done 9 | export HADOOP_CLASSPATH 10 | 11 | realpath () { 12 | ( 13 | TARGET_FILE=$1 14 | 15 | cd $(dirname $TARGET_FILE) 16 | TARGET_FILE=$(basename $TARGET_FILE) 17 | 18 | COUNT=0 19 | while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] 20 | do 21 | TARGET_FILE=$(readlink $TARGET_FILE) 22 | cd $(dirname $TARGET_FILE) 23 | TARGET_FILE=$(basename $TARGET_FILE) 24 | COUNT=$(($COUNT + 1)) 25 | done 26 | 27 | echo $(pwd -P)/$TARGET_FILE 28 | ) 29 | } 30 | 31 | . $(dirname $(realpath $0))/sbt-launch-lib.bash 32 | 33 | 34 | declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" 35 | declare -r sbt_opts_file=".sbtopts" 36 | declare -r etc_sbt_opts_file="/etc/sbt/sbtopts" 37 | 38 | usage() { 39 | cat < path to global settings/plugins directory (default: ~/.sbt) 47 | -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11 series) 48 | -ivy path to local Ivy repository (default: ~/.ivy2) 49 | -mem set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem)) 50 | -no-share use all local caches; no sharing 51 | -no-global uses global caches, but does not use global ~/.sbt directory. 52 | -jvm-debug Turn on JVM debugging, open at the given port. 53 | -batch Disable interactive mode 54 | # sbt version (default: from project/build.properties if present, else latest release) 55 | -sbt-version use the specified version of sbt 56 | -sbt-jar use the specified jar as the sbt launcher 57 | -sbt-rc use an RC version of sbt 58 | -sbt-snapshot use a snapshot version of sbt 59 | # java version (default: java from PATH, currently $(java -version 2>&1 | grep version)) 60 | -java-home alternate JAVA_HOME 61 | # jvm options and output control 62 | JAVA_OPTS environment variable, if unset uses "$java_opts" 63 | SBT_OPTS environment variable, if unset uses "$default_sbt_opts" 64 | .sbtopts if this file exists in the current directory, it is 65 | prepended to the runner args 66 | /etc/sbt/sbtopts if this file exists, it is prepended to the runner args 67 | -Dkey=val pass -Dkey=val directly to the java runtime 68 | -J-X pass option -X directly to the java runtime 69 | (-J is stripped) 70 | -S-X add -X to sbt's scalacOptions (-J is stripped) 71 | -PmavenProfiles Enable a maven profile for the build. 72 | In the case of duplicated or conflicting options, the order above 73 | shows precedence: JAVA_OPTS lowest, command line options highest. 74 | EOM 75 | } 76 | 77 | process_my_args () { 78 | while [[ $# -gt 0 ]]; do 79 | case "$1" in 80 | -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;; 81 | -no-share) addJava "$noshare_opts" && shift ;; 82 | -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;; 83 | -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;; 84 | -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;; 85 | -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;; 86 | -batch) exec &2 "$@" 31 | } 32 | vlog () { 33 | [[ $verbose || $debug ]] && echoerr "$@" 34 | } 35 | dlog () { 36 | [[ $debug ]] && echoerr "$@" 37 | } 38 | 39 | acquire_sbt_jar () { 40 | SBT_VERSION=`awk -F "=" '/sbt\\.version/ {print $2}' ./project/build.properties` 41 | URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar 42 | JAR=build/sbt-launch-${SBT_VERSION}.jar 43 | 44 | sbt_jar=$JAR 45 | 46 | if [[ ! 
-f "$sbt_jar" ]]; then 47 | # Download sbt launch jar if it hasn't been downloaded yet 48 | if [ ! -f ${JAR} ]; then 49 | # Download 50 | printf "Attempting to fetch sbt\n" 51 | JAR_DL=${JAR}.part 52 | if hash curl 2>/dev/null; then 53 | curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\ 54 | mv "${JAR_DL}" "${JAR}" 55 | elif hash wget 2>/dev/null; then 56 | wget --quiet ${URL1} -O "${JAR_DL}" &&\ 57 | mv "${JAR_DL}" "${JAR}" 58 | else 59 | printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" 60 | exit -1 61 | fi 62 | fi 63 | if [ ! -f ${JAR} ]; then 64 | # We failed to download 65 | printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n" 66 | exit -1 67 | fi 68 | printf "Launching sbt from ${JAR}\n" 69 | fi 70 | } 71 | 72 | execRunner () { 73 | # print the arguments one to a line, quoting any containing spaces 74 | [[ $verbose || $debug ]] && echo "# Executing command line:" && { 75 | for arg; do 76 | if printf "%s\n" "$arg" | grep -q ' '; then 77 | printf "\"%s\"\n" "$arg" 78 | else 79 | printf "%s\n" "$arg" 80 | fi 81 | done 82 | echo "" 83 | } 84 | 85 | exec "$@" 86 | } 87 | 88 | addJava () { 89 | dlog "[addJava] arg = '$1'" 90 | java_args=( "${java_args[@]}" "$1" ) 91 | } 92 | 93 | enableProfile () { 94 | dlog "[enableProfile] arg = '$1'" 95 | maven_profiles=( "${maven_profiles[@]}" "$1" ) 96 | export SBT_MAVEN_PROFILES="${maven_profiles[@]}" 97 | } 98 | 99 | addSbt () { 100 | dlog "[addSbt] arg = '$1'" 101 | sbt_commands=( "${sbt_commands[@]}" "$1" ) 102 | } 103 | addResidual () { 104 | dlog "[residual] arg = '$1'" 105 | residual_args=( "${residual_args[@]}" "$1" ) 106 | } 107 | addDebugger () { 108 | addJava "-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1" 109 | } 110 | 111 | # a ham-fisted attempt to move some memory settings in concert 112 | # so they need not be dicked around with individually. 113 | get_mem_opts () { 114 | local mem=${1:-2048} 115 | local perm=$(( $mem / 4 )) 116 | (( $perm > 256 )) || perm=256 117 | (( $perm < 4096 )) || perm=4096 118 | local codecache=$(( $perm / 2 )) 119 | 120 | echo "-Xms${mem}m -Xmx${mem}m -XX:MaxPermSize=${perm}m -XX:ReservedCodeCacheSize=${codecache}m" 121 | } 122 | 123 | require_arg () { 124 | local type="$1" 125 | local opt="$2" 126 | local arg="$3" 127 | if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then 128 | die "$opt requires <$type> argument" 129 | fi 130 | } 131 | 132 | is_function_defined() { 133 | declare -f "$1" > /dev/null 134 | } 135 | 136 | process_args () { 137 | while [[ $# -gt 0 ]]; do 138 | case "$1" in 139 | -h|-help) usage; exit 1 ;; 140 | -v|-verbose) verbose=1 && shift ;; 141 | -d|-debug) debug=1 && shift ;; 142 | 143 | -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;; 144 | -mem) require_arg integer "$1" "$2" && sbt_mem="$2" && shift 2 ;; 145 | -jvm-debug) require_arg port "$1" "$2" && addDebugger $2 && shift 2 ;; 146 | -batch) exec scalastyle.txt 25 | 26 | ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') 27 | rm scalastyle.txt 28 | 29 | if test ! -z "$ERRORS"; then 30 | echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" 31 | exit 1 32 | else 33 | echo -e "Scalastyle checks passed." 
34 | fi 35 | -------------------------------------------------------------------------------- /dev/sbt-assembly-skip-test.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env bash 3 | 4 | # 5 | # Licensed to the Apache Software Foundation (ASF) under one or more 6 | # contributor license agreements. See the NOTICE file distributed with 7 | # this work for additional information regarding copyright ownership. 8 | # The ASF licenses this file to You under the Apache License, Version 2.0 9 | # (the "License"); you may not use this file except in compliance with 10 | # the License. You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | # 20 | 21 | SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" 22 | PROJECT_HOME="$(dirname $SCRIPT_DIR)" 23 | 24 | sbt "set test in assembly := {}" clean assembly 25 | -------------------------------------------------------------------------------- /project/assembly.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.2") -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | // This file should only contain the version of sbt to use. 2 | sbt.version=0.13.6 3 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | // You may use this file to add plugin dependencies for sbt. 2 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/" 3 | 4 | resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" 5 | 6 | addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.3") 7 | 8 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") 9 | 10 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.1.0") 11 | 12 | -------------------------------------------------------------------------------- /scalastyle-config.xml: -------------------------------------------------------------------------------- 1 | 17 | 39 | 40 | 41 | Scalastyle standard configuration 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | true 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | ARROW, EQUALS, ELSE, TRY, CATCH, FINALLY, LARROW, RARROW 126 | 127 | 128 | 129 | 130 | 131 | ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | ^FunSuite[A-Za-z]*$ 141 | Tests must extend org.apache.spark.SparkFunSuite instead. 
142 | 143 | 144 | 145 | 146 | ^println$ 147 | 151 | 152 | 153 | 154 | Class\.forName 155 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 800> 210 | 211 | 212 | 213 | 214 | 30 215 | 216 | 217 | 218 | 219 | 10 220 | 221 | 222 | 223 | 224 | 50 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | -1,0,1,2,3 236 | 237 | 238 | 239 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/kaggle/AvazuClickThroughRatePrediction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.examples.kaggle 19 | 20 | import scala.collection.mutable.ArrayBuffer 21 | 22 | import scopt.OptionParser 23 | 24 | import org.apache.spark.{SparkConf, SparkContext} 25 | import org.apache.spark.ml.{Pipeline, PipelineStage} 26 | import org.apache.spark.ml.classification.LogisticRegression 27 | import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator 28 | import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler} 29 | import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} 30 | import org.apache.spark.mllib.linalg.Vector 31 | import org.apache.spark.sql.{Row, SQLContext, SaveMode} 32 | import org.apache.spark.sql.functions.col 33 | import org.apache.spark.sql.types._ 34 | 35 | /* 36 | * https://www.kaggle.com/c/avazu-ctr-prediction 37 | * https://issues.apache.org/jira/browse/SPARK-13239 38 | * 39 | * Code is from: 40 | * https://github.com/yu-iskw/click-through-rate-prediction 41 | */ 42 | // scalastyle:off println 43 | object AvazuClickThroughRatePrediction { 44 | 45 | private val trainSchema = StructType(Array( 46 | StructField("id", StringType, false), 47 | StructField("click", IntegerType, true), 48 | StructField("hour", IntegerType, true), 49 | StructField("C1", IntegerType, true), 50 | StructField("banner_pos", IntegerType, true), 51 | StructField("site_id", StringType, true), 52 | StructField("site_domain", StringType, true), 53 | StructField("site_category", StringType, true), 54 | StructField("app_id", StringType, true), 55 | StructField("app_domain", StringType, true), 56 | StructField("app_category", StringType, true), 57 | StructField("device_id", StringType, true), 58 | StructField("device_ip", StringType, true), 59 | StructField("device_model", StringType, true), 60 | StructField("device_type", IntegerType, true), 61 | 
StructField("device_conn_type", IntegerType, true), 62 | StructField("C14", IntegerType, true), 63 | StructField("C15", IntegerType, true), 64 | StructField("C16", IntegerType, true), 65 | StructField("C17", IntegerType, true), 66 | StructField("C18", IntegerType, true), 67 | StructField("C19", IntegerType, true), 68 | StructField("C20", IntegerType, true), 69 | StructField("C21", IntegerType, true) 70 | )) 71 | 72 | private val testSchema = StructType(Array( 73 | StructField("id", StringType, false), 74 | StructField("hour", IntegerType, true), 75 | StructField("C1", IntegerType, true), 76 | StructField("banner_pos", IntegerType, true), 77 | StructField("site_id", StringType, true), 78 | StructField("site_domain", StringType, true), 79 | StructField("site_category", StringType, true), 80 | StructField("app_id", StringType, true), 81 | StructField("app_domain", StringType, true), 82 | StructField("app_category", StringType, true), 83 | StructField("device_id", StringType, true), 84 | StructField("device_ip", StringType, true), 85 | StructField("device_model", StringType, true), 86 | StructField("device_type", IntegerType, true), 87 | StructField("device_conn_type", IntegerType, true), 88 | StructField("C14", IntegerType, true), 89 | StructField("C15", IntegerType, true), 90 | StructField("C16", IntegerType, true), 91 | StructField("C17", IntegerType, true), 92 | StructField("C18", IntegerType, true), 93 | StructField("C19", IntegerType, true), 94 | StructField("C20", IntegerType, true), 95 | StructField("C21", IntegerType, true) 96 | )) 97 | 98 | case class ClickThroughRatePredictionParams( 99 | trainInput: String = null, 100 | testInput: String = null, 101 | resultOutput: String = null 102 | ) 103 | 104 | /** 105 | * Try Kaggle's Click-Through Rate Prediction with Logistic Regression Classification 106 | * Run with 107 | * {{ 108 | * $SPARK_HOME/bin/spark-submit \ 109 | * --class org.apache.spark.examples.kaggle.ClickThroughRatePredictionWitLogisticRegression \ 110 | * /path/to/click-through-rate-prediction-assembly-1.1.jar \ 111 | * --train=/path/to/train \ 112 | * --test=/path/to/test \ 113 | * --result=/path/to/result.csv 114 | * }} 115 | * SEE ALSO: https://www.kaggle.com/c/avazu-ctr-prediction 116 | */ 117 | def main(args: Array[String]): Unit = { 118 | val conf = new SparkConf().setAppName(this.getClass.getSimpleName) 119 | val sc = new SparkContext(conf) 120 | val sqlContext = new SQLContext(sc) 121 | 122 | val defaultParam = new ClickThroughRatePredictionParams() 123 | val parser = new OptionParser[ClickThroughRatePredictionParams](this.getClass.getSimpleName) { 124 | head(s"${this.getClass.getSimpleName}: Try a Kaggle competition.") 125 | opt[String]("train") 126 | .text("train input") 127 | .action((x, c) => c.copy(trainInput = x)) 128 | .required() 129 | opt[String]("test") 130 | .text("test input") 131 | .action((x, c) => c.copy(testInput = x)) 132 | .required() 133 | opt[String]("result") 134 | .text("result output path") 135 | .action((x, c) => c.copy(resultOutput = x)) 136 | .required() 137 | } 138 | parser.parse(args, defaultParam).map { params => 139 | run(sc, sqlContext, params.trainInput, params.testInput, params.resultOutput) 140 | } getOrElse { 141 | sys.exit(1) 142 | } 143 | sc.stop() 144 | } 145 | 146 | def run(sc: SparkContext, sqlContext: SQLContext, 147 | trainPath: String, testPath: String, resultPath: String): Unit = { 148 | import sqlContext.implicits._ 149 | 150 | // Sets the target variables 151 | val targetVariables = Array( 152 | "banner_pos", "site_id", 
"site_domain", "site_category", 153 | "app_domain", "app_category", "device_model", "device_type", "device_conn_type", 154 | "C1", "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21" 155 | ) 156 | 157 | // Loads training data and testing data from CSV files 158 | val train = sqlContext.read.format("com.databricks.spark.csv") 159 | .option("header", "true") 160 | .schema(trainSchema) 161 | .load(trainPath).cache() 162 | val test = sqlContext.read.format("com.databricks.spark.csv") 163 | .option("header", "true") 164 | .schema(testSchema) 165 | .load(testPath).cache() 166 | 167 | // Union data for one-hot encoding 168 | // To extract features throughly, union the training and test data. 169 | // Since the test data includes values which doesn't exists in the training data. 170 | val train4union = train.select(targetVariables.map(col): _*) 171 | val test4union = test.select(targetVariables.map(col): _*) 172 | val union = train4union.unionAll(test4union).cache() 173 | 174 | // Extracts features with one-hot encoding 175 | def getIndexedColumn(clm: String): String = s"${clm}_indexed" 176 | def getColumnVec(clm: String): String = s"${clm}_vec" 177 | val feStages = ArrayBuffer.empty[PipelineStage] 178 | targetVariables.foreach { clm => 179 | val stringIndexer = new StringIndexer() 180 | .setInputCol(clm) 181 | .setOutputCol(getIndexedColumn(clm)) 182 | .setHandleInvalid("error") 183 | val oneHotEncoder = new OneHotEncoder() 184 | .setInputCol(getIndexedColumn(clm)) 185 | .setOutputCol(getColumnVec(clm)) 186 | Array(stringIndexer, oneHotEncoder) 187 | feStages.append(stringIndexer) 188 | feStages.append(oneHotEncoder) 189 | } 190 | val va = new VectorAssembler() 191 | .setInputCols(targetVariables.map(getColumnVec)) 192 | .setOutputCol("features") 193 | feStages.append(va) 194 | val fePipeline = new Pipeline().setStages(feStages.toArray) 195 | val feModel = fePipeline.fit(union) 196 | val trainDF = feModel.transform(train).select('click, 'features).cache() 197 | val testDF = feModel.transform(test).select('id, 'features).cache() 198 | union.unpersist() 199 | train.unpersist() 200 | test.unpersist() 201 | 202 | // Trains a model with CrossValidator 203 | val si4click = new StringIndexer() 204 | .setInputCol("click") 205 | .setOutputCol("label") 206 | val lr = new LogisticRegression() 207 | val pipeline = new Pipeline().setStages(Array(si4click, lr)) 208 | val paramGrid = new ParamGridBuilder() 209 | .addGrid(lr.threshold, Array(0.22)) 210 | .addGrid(lr.elasticNetParam, Array(0.0)) 211 | .addGrid(lr.regParam, Array(0.001)) 212 | .addGrid(lr.maxIter, Array(100)) 213 | .build() 214 | val cv = new CrossValidator() 215 | .setEstimator(pipeline) 216 | .setEvaluator(new BinaryClassificationEvaluator()) 217 | .setEstimatorParamMaps(paramGrid) 218 | .setNumFolds(3) 219 | val cvModel = cv.fit(trainDF) 220 | 221 | // Shows the best parameters 222 | cvModel.bestModel.parent match { 223 | case pipeline: Pipeline => 224 | pipeline.getStages.zipWithIndex.foreach { case (stage, index) => 225 | println(s"Stage[${index + 1}]: ${stage.getClass.getSimpleName}") 226 | println(stage.extractParamMap()) 227 | } 228 | } 229 | 230 | // Predicts with the trained best model 231 | val resultDF = cvModel.transform(testDF).select('id, 'probability).map { 232 | case Row(id: String, probability: Vector) => (id, probability(1)) 233 | }.toDF("id", "click") 234 | 235 | // Save the result 236 | resultDF.repartition(1).write.mode(SaveMode.Overwrite) 237 | .format("com.databricks.spark.csv") 238 | .option("header", 
"true").option("inferSchema", "true") 239 | .save(resultPath) 240 | } 241 | } 242 | 243 | // scalastyle:on println 244 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/kaggle/AvitoContextAdClicks.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.kaggle 2 | 3 | /** 4 | * https://www.kaggle.com/c/avito-context-ad-clicks 5 | * https://issues.apache.org/jira/browse/SPARK-10935 6 | * 7 | * Code copied from: 8 | * https://github.com/yinxusen/incubator-project/blob/b332de87606b4599d96a6cc41aadb934f2f30577/avito/src/main/scala/org/apache/spark/examples/main.scala 9 | * 10 | * Other useful repos: 11 | * https://github.com/bluebytes60/SparkML/tree/master/src/main/scala/avito 12 | * https://github.com/Sirorezka/SNA_Hackaton/blob/a03db56b329fdea116febfeb946e2db6f33ae610/My_Model_scala/src/main/scala/Baseline.scala 13 | */ 14 | import org.apache.spark.ml.classification.LogisticRegression 15 | import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator 16 | import org.apache.spark.ml.feature.VectorAssembler 17 | import org.apache.spark.sql.{ Row, SQLContext } 18 | import org.apache.spark.{ SparkConf, SparkContext } 19 | 20 | case class SearchStream( 21 | searchId: Int, 22 | adId: Int, 23 | position: Int, 24 | objectType: Int, 25 | histCTR: Double, 26 | isClick: Double) 27 | 28 | object AvitoContextAdClicks { 29 | def main(args: Array[String]): Unit = { 30 | val conf = new SparkConf().setMaster("local[4]").setAppName("Avito") 31 | val sc = new SparkContext(conf) 32 | val sqlCtx = new SQLContext(sc) 33 | import sqlCtx.implicits._ 34 | 35 | // here we need lots of work to load data 36 | val trainSearchStream = sc.textFile("/Users/panda/data/ads/trainSearchStream.tsv") 37 | .map(_.split('\t')) 38 | .filter(!_.contains("SearchID")) 39 | .filter(_.length == 6) 40 | .map { record => 41 | SearchStream( 42 | record(0).toInt, 43 | record(1).toInt, 44 | record(2).toInt, 45 | record(3).toInt, 46 | record(4).toDouble, 47 | record(5).toDouble) 48 | }.toDF() 49 | 50 | val valueForClicks = trainSearchStream.select("isClick").distinct() 51 | .map { case Row(isClick: Double) => isClick }.collect() 52 | 53 | val assembler = new VectorAssembler() 54 | .setInputCols(Array("adId", "position", "objectType", "histCTR")).setOutputCol("feature") 55 | 56 | val dataSet = assembler.transform(trainSearchStream).select("feature", "isClick") 57 | 58 | val splits = dataSet.randomSplit(Array(0.7, 0.3)) 59 | val trainingSet = splits(0) 60 | val testSet = splits(1) 61 | 62 | val lr = new LogisticRegression() 63 | .setMaxIter(20) 64 | .setRegParam(0.03) 65 | .setElasticNetParam(0.1) 66 | .setFeaturesCol("feature") 67 | .setLabelCol("isClick") 68 | .setRawPredictionCol("result") 69 | 70 | val lrModel = lr.fit(trainingSet) 71 | 72 | val evaluator = new MulticlassClassificationEvaluator() 73 | .setLabelCol("isClick").setPredictionCol("result") 74 | 75 | val eval = evaluator.evaluate(lrModel.transform(testSet)) 76 | println(eval) 77 | } 78 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/kaggle/CriteoCtrPrediction.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.kaggle 2 | 3 | import scala.collection.mutable.ArrayBuffer 4 | import org.apache.spark.SparkConf 5 | import org.apache.spark.SparkContext 6 | import 
org.apache.spark.ml.Pipeline 7 | import org.apache.spark.ml.PipelineStage 8 | import org.apache.spark.ml.Transformer 9 | import org.apache.spark.ml.classification.LogisticRegression 10 | import org.apache.spark.ml.feature.OneHotEncoder 11 | import org.apache.spark.ml.feature.StringIndexer 12 | import org.apache.spark.ml.feature.VectorAssembler 13 | import org.apache.spark.ml.util.MetadataUtils 14 | import org.apache.spark.mllib.evaluation.MulticlassMetrics 15 | import org.apache.spark.rdd.RDD 16 | import org.apache.spark.sql.DataFrame 17 | import org.apache.spark.sql.Row 18 | import org.apache.spark.sql.SQLContext 19 | import org.apache.spark.sql.functions.col 20 | import org.apache.spark.sql.types.IntegerType 21 | import org.apache.spark.sql.types.StringType 22 | import org.apache.spark.sql.types.StructField 23 | import org.apache.spark.sql.types.StructType 24 | import org.apache.spark.sql.types.DoubleType 25 | import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator 26 | import org.apache.spark.ml.tuning.CrossValidator 27 | import org.apache.spark.ml.tuning.ParamGridBuilder 28 | 29 | /** 30 | * https://www.kaggle.com/c/criteo-display-ad-challenge 31 | * https://issues.apache.org/jira/browse/SPARK-10870 32 | * 33 | * Rewrite the Python pipeline proposed by Manisha S in Scala 34 | * https://developer.ibm.com/spark/blog/2016/02/22/predictive-model-for-online-advertising-using-spark-machine-learning-pipelines/ 35 | * 36 | * Use some code from: 37 | * https://github.com/yu-iskw/click-through-rate-prediction 38 | */ 39 | object CriteoCtrPrediction { 40 | private val schema = StructType(Array( 41 | StructField("Label", DoubleType, false), 42 | StructField("I1", IntegerType, true), 43 | StructField("I2", IntegerType, true), 44 | StructField("I3", IntegerType, true), 45 | StructField("I4", IntegerType, true), 46 | StructField("I5", IntegerType, true), 47 | StructField("I6", IntegerType, true), 48 | StructField("I7", IntegerType, true), 49 | StructField("I8", IntegerType, true), 50 | StructField("I9", IntegerType, true), 51 | StructField("I10", IntegerType, true), 52 | StructField("I11", IntegerType, true), 53 | StructField("I12", IntegerType, true), 54 | StructField("I13", IntegerType, true), 55 | StructField("C1", StringType, true), 56 | StructField("C2", StringType, true), 57 | StructField("C3", StringType, true), 58 | StructField("C4", StringType, true), 59 | StructField("C5", StringType, true), 60 | StructField("C6", StringType, true), 61 | StructField("C7", StringType, true), 62 | StructField("C8", StringType, true), 63 | StructField("C9", StringType, true), 64 | StructField("C10", StringType, true), 65 | StructField("C11", StringType, true), 66 | StructField("C12", StringType, true), 67 | StructField("C13", StringType, true), 68 | StructField("C14", StringType, true), 69 | StructField("C15", StringType, true), 70 | StructField("C16", StringType, true), 71 | StructField("C17", StringType, true), 72 | StructField("C18", StringType, true), 73 | StructField("C19", StringType, true), 74 | StructField("C20", StringType, true), 75 | StructField("C21", StringType, true), 76 | StructField("C22", StringType, true), 77 | StructField("C23", StringType, true), 78 | StructField("C24", StringType, true), 79 | StructField("C25", StringType, true), 80 | StructField("C26", StringType, true))) 81 | 82 | def toInt(s: String): Option[Int] = { 83 | try { 84 | Some(s.toInt) 85 | } catch { 86 | case e: Exception => None 87 | } 88 | } 89 | 90 | def parseData(data: RDD[String], sqlContext: SQLContext): 
DataFrame = { 91 | // Split the csv file by comma and convert each line to a tuple. 92 | val parts = data.map(line => line.split("\t", -1)) 93 | parts.take(10).foreach(arr => println(arr.size)) 94 | val features = parts.map(p => Row(p(0).toDouble, toInt(p(1)), toInt(p(2)), toInt(p(3)), toInt(p(4)), toInt(p(5)), 95 | toInt(p(6)), toInt(p(7)), toInt(p(8)), toInt(p(9)), toInt(p(10)), toInt(p(11)), toInt(p(12)), toInt(p(13)), 96 | p(14), p(15), p(16), p(17), p(18), p(19), 97 | p(20), p(21), p(22), p(23), p(24), p(25), p(26), p(27), p(28), p(29), 98 | p(30), p(31), p(32), p(33), p(34), p(35), p(36), p(37), p(38), p(39))) 99 | 100 | // Apply the schema to the RDD. 101 | return sqlContext.createDataFrame(features, schema) 102 | } 103 | 104 | def main(args: Array[String]): Unit = { 105 | val conf = new SparkConf().setMaster("local").setAppName("CriteoCtrPrediction") 106 | val sc = new SparkContext(conf) 107 | 108 | // $example on$ 109 | // val sampleData = sc.textFile("D:\\Datasets\\Criteo\\dac_sample\\dac_sample.txt.mini", 2) 110 | val sampleData = sc.textFile("D:\\Datasets\\Criteo\\dac_sample\\dac_sample.txt", 2) 111 | 112 | println(s"Data size is ${sampleData.count}") 113 | sampleData.take(2).foreach(println) 114 | 115 | val sqlContext = new SQLContext(sc) 116 | // Register the DataFrame as a table. 117 | val schemaClicks = parseData(sampleData, sqlContext) 118 | schemaClicks.registerTempTable("clicks") 119 | schemaClicks.printSchema() 120 | 121 | val cols = Seq("C1", "C2", "C3", "C4", "C5", "C6", "C7", "C8", "C9", "C10", "C11", "C12", "C13", 122 | "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21", "C22", "C23", "C24", "C25", "C26") 123 | // Replace empty values in Categorical features by "NA" 124 | val schemaClicksNA = schemaClicks.na.replace(cols, Map("" -> "NA")) 125 | // Drop rows containing null values in the DataFrame 126 | val schemaClicksCleaned = schemaClicksNA.na.drop() 127 | 128 | val Array(trainData, testData) = schemaClicksCleaned.randomSplit(Array(0.9, 0.1), seed = 42) 129 | trainData.cache() 130 | testData.cache() 131 | 132 | // val indexer = new StringIndexer() 133 | // .setInputCol("category") 134 | // .setOutputCol("categoryIndex") 135 | // 136 | // val encoder = new OneHotEncoder() 137 | // .setInputCol("categoryIndex") 138 | // .setOutputCol("categoryVec") 139 | // 140 | // val assembler = new VectorAssembler() 141 | // .setInputCols(Array("hour", "mobile", "userFeatures")) 142 | // .setOutputCol("features") 143 | // 144 | // val pipeline = new Pipeline() 145 | // .setStages(Array(indexer, encoder, assembler)) 146 | // 147 | // // Fit the pipeline to training documents. 148 | // val model = pipeline.fit(trainData) 149 | 150 | // Union data for one-hot encoding 151 | // To extract features throughly, union the training and test data. 152 | // Since the test data includes values which doesn't exists in the training data. 
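    // This matters because each StringIndexer below is created with
    // setHandleInvalid("error"), so a StringIndexerModel fitted only on the
    // training split would throw on any category value that appears only in
    // the test split. Fitting the feature pipeline on the union of both splits
    // assigns every observed category an index; only the set of categorical
    // values is shared this way, never the labels.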
153 | val train4union = trainData.select(cols.map(col): _*) 154 | val test4union = testData.select(cols.map(col): _*) 155 | val union = train4union.unionAll(test4union).cache() 156 | 157 | // Extracts features with one-hot encoding 158 | def getIndexedColumn(column: String): String = s"${column}_indexed" 159 | def getColumnVec(column: String): String = s"${column}_vec" 160 | val feStages = ArrayBuffer.empty[PipelineStage] 161 | cols.foreach { clm => 162 | val stringIndexer = new StringIndexer() 163 | .setInputCol(clm) 164 | .setOutputCol(getIndexedColumn(clm)) 165 | .setHandleInvalid("error") 166 | val oneHotEncoder = new OneHotEncoder() 167 | .setInputCol(getIndexedColumn(clm)) 168 | .setOutputCol(getColumnVec(clm)) 169 | .setDropLast(false) 170 | Array(stringIndexer, oneHotEncoder) 171 | feStages.append(stringIndexer) 172 | feStages.append(oneHotEncoder) 173 | } 174 | val va = new VectorAssembler() 175 | .setInputCols(cols.toArray.map(getColumnVec)) 176 | .setOutputCol("features") 177 | feStages.append(va) 178 | val fePipeline = new Pipeline().setStages(feStages.toArray) 179 | val feModel = fePipeline.fit(union) 180 | val trainDF = feModel.transform(trainData).select("Label", "features").cache() 181 | val testDF = feModel.transform(testData).select("Label", "features").cache() 182 | union.unpersist() 183 | trainData.unpersist() 184 | testData.unpersist() 185 | trainDF.show(5) 186 | testDF.show(5) 187 | 188 | val lr = new LogisticRegression() 189 | .setFeaturesCol("features") 190 | .setLabelCol("Label") 191 | // .setRegParam(0.3) 192 | // .setElasticNetParam(0.8) 193 | // .setMaxIter(10) 194 | // .setTol(1E-6) 195 | // .setFitIntercept(true) 196 | 197 | // Fit the Pipeline 198 | val startTime = System.nanoTime() 199 | // val lrModel = lr.fit(trainDF) 200 | 201 | val paramGrid = new ParamGridBuilder() 202 | // .addGrid(lr.regParam, Array(1, 0.3, 0.2, 0.1, 0.01, 0.001)) 203 | // .addGrid(lr.elasticNetParam, Array(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) 204 | .addGrid(lr.regParam, Array(1, 0.1, 0.01)) 205 | .addGrid(lr.elasticNetParam, Array(0.2, 0.8)) 206 | .build() 207 | 208 | // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. 209 | // This will allow us to jointly choose parameters for all Pipeline stages. 210 | // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. 211 | // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric 212 | // is areaUnderROC. 213 | val evaluator = new BinaryClassificationEvaluator() 214 | .setLabelCol("Label") 215 | val cv = new CrossValidator() 216 | .setEstimator(lr) 217 | .setEvaluator(evaluator) 218 | .setEstimatorParamMaps(paramGrid) 219 | .setNumFolds(3) // Use 3+ in practice 220 | 221 | // Run cross-validation, and choose the best set of parameters. 222 | val cvModel = cv.fit(trainDF) 223 | 224 | val elapsedTime = (System.nanoTime() - startTime) / 1e9 225 | println(s"Training time: $elapsedTime seconds") 226 | 227 | // Print the weights and intercept for logistic regression. 228 | // println(s"Weights: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") 229 | 230 | println("Training data results:") 231 | evaluateClassificationModel(cvModel, trainDF, "Label") 232 | println("Test data results:") 233 | evaluateClassificationModel(cvModel, testDF, "Label") 234 | 235 | // $example off$ 236 | sc.stop() 237 | } 238 | 239 | /** 240 | * Evaluate the given ClassificationModel on data. Print the results. 
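   * The printed "Accuracy" is computed with MulticlassMetrics.precision, which,
   * when called without a per-label argument, is the overall precision and is
   * numerically identical to accuracy.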
241 | * @param model Must fit ClassificationModel abstraction 242 | * @param data DataFrame with "prediction" and labelColName columns 243 | * @param labelColName Name of the labelCol parameter for the model 244 | * 245 | * TODO: Change model type to ClassificationModel once that API is public. SPARK-5995 246 | */ 247 | def evaluateClassificationModel( 248 | model: Transformer, 249 | data: DataFrame, 250 | labelColName: String): Unit = { 251 | val fullPredictions = model.transform(data).cache() 252 | val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) 253 | val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) 254 | // Print number of classes for reference 255 | // val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { 256 | // case Some(n) => n 257 | // case None => throw new RuntimeException( 258 | // "Unknown failure when indexing labels for classification.") 259 | // } 260 | val numClasses = 2 261 | val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision 262 | println(s" Accuracy ($numClasses classes): $accuracy") 263 | } 264 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/kaggle/SanFranciscoCrimeClassification.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.kaggle 2 | 3 | import java.io._ 4 | 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.SparkConf 7 | import org.apache.spark.ml.Pipeline 8 | import org.apache.spark.ml.feature.{ StringIndexerModel, StandardScaler, StringIndexer, VectorAssembler } 9 | import org.apache.spark.sql.{ SQLContext, Row } 10 | import org.apache.spark.ml.classification.{ DecisionTreeClassifier, LogisticRegression } 11 | 12 | /** 13 | * https://www.kaggle.com/c/sf-crime 14 | * https://issues.apache.org/jira/browse/SPARK-10055 15 | * Copied from: 16 | * https://github.com/Lewuathe/spark-kaggle-examples 17 | * 18 | * Created by sasakikai on 8/24/15. 19 | */ 20 | object SanFranciscoCrimeClassification { 21 | 22 | def labelToVec(label: Int, labels: Array[String], sortedLabels: Array[String]): Array[Int] = { 23 | require(labels.length == sortedLabels.length) 24 | val stringLabel = labels(label) 25 | val sortedIndex = sortedLabels.indexOf(stringLabel) 26 | val ret = new Array[Int](labels.length) 27 | ret(sortedIndex) = 1 28 | ret 29 | } 30 | 31 | def main(args: Array[String]) { 32 | if (args.length < 3) { 33 | println("File path must be passed. 
" + args.length) 34 | System.exit(-1) 35 | } 36 | val trainFilePath = args(0) 37 | val testFilePath = args(1) 38 | val outputFilePath = args(2) 39 | val conf = new SparkConf().setAppName("SfCrimeClassification") 40 | val sc = new SparkContext(conf) 41 | val sqlContext = new SQLContext(sc) 42 | 43 | /** 44 | * Training Phase 45 | */ 46 | val trainData = sqlContext.read.format("com.databricks.spark.csv") 47 | .option("header", "true").option("inferSchema", "true").load(trainFilePath) 48 | 49 | val categoryIndexer = new StringIndexer().setInputCol("Category") 50 | .setOutputCol("label") 51 | val dayOfWeekIndexer = new StringIndexer().setInputCol("DayOfWeek") 52 | .setOutputCol("DayOfWeekIndex") 53 | val pdDistrictIndexer = new StringIndexer().setInputCol("PdDistrict") 54 | .setOutputCol("PdDistrictIndex") 55 | val vectorAssembler = new VectorAssembler().setInputCols(Array("DayOfWeekIndex", 56 | "PdDistrictIndex", "X", "Y")).setOutputCol("rowFeatures") 57 | val featureScaler = new StandardScaler().setInputCol("rowFeatures") 58 | .setOutputCol("features") 59 | val classifier = new DecisionTreeClassifier() 60 | 61 | val trainPipeline = new Pipeline().setStages(Array(categoryIndexer, dayOfWeekIndexer, 62 | pdDistrictIndexer, vectorAssembler, featureScaler, classifier)) 63 | 64 | val model = trainPipeline.fit(trainData) 65 | 66 | /** 67 | * Test Phase 68 | */ 69 | val testData = sqlContext.read.format("com.databricks.spark.csv") 70 | .option("header", "true").option("inferSchema", "true").load(testFilePath) 71 | 72 | val labels = Array("LARCENY/THEFT", "OTHER OFFENSES", "NON-CRIMINAL", "ASSAULT", "DRUG/NARCOTIC", 73 | "VEHICLE THEFT", "VANDALISM", "WARRANTS", "BURGLARY", "SUSPICIOUS OCC", "MISSING PERSON", 74 | "ROBBERY", "FRAUD", "FORGERY/COUNTERFEITING", "SECONDARY CODES", "WEAPON LAWS", "PROSTITUTION", 75 | "TRESPASS", "STOLEN PROPERTY", "SEX OFFENSES FORCIBLE", "DISORDERLY CONDUCT", "DRUNKENNESS", 76 | "RECOVERED VEHICLE", "KIDNAPPING", "DRIVING UNDER THE INFLUENCE", "RUNAWAY", "LIQUOR LAWS", 77 | "ARSON", "LOITERING", "EMBEZZLEMENT", "SUICIDE", "FAMILY OFFENSES", "BAD CHECKS", "BRIBERY", 78 | "EXTORTION", "SEX OFFENSES NON FORCIBLE", "GAMBLING", "PORNOGRAPHY/OBSCENE MAT", "TREA") 79 | 80 | val writer = new PrintWriter(new File(outputFilePath)) 81 | writer.write("Id,ARSON,ASSAULT,BAD CHECKS,BRIBERY,BURGLARY,DISORDERLY CONDUCT,DRIVING UNDER THE INFLUENCE,DRUG/NARCOTIC,DRUNKENNESS,EMBEZZLEMENT,EXTORTION,FAMILY OFFENSES,FORGERY/COUNTERFEITING,FRAUD,GAMBLING,KIDNAPPING,LARCENY/THEFT,LIQUOR LAWS,LOITERING,MISSING PERSON,NON-CRIMINAL,OTHER OFFENSES,PORNOGRAPHY/OBSCENE MAT,PROSTITUTION,RECOVERED VEHICLE,ROBBERY,RUNAWAY,SECONDARY CODES,SEX OFFENSES FORCIBLE,SEX OFFENSES NON FORCIBLE,STOLEN PROPERTY,SUICIDE,SUSPICIOUS OCC,TREA,TRESPASS,VANDALISM,VEHICLE THEFT,WARRANTS,WEAPON LAWS\n") 82 | model.transform(testData).select("Id", "prediction").collect().foreach { 83 | case Row(id: Int, prediction: Double) => { 84 | val labelVec = labelToVec(prediction.toInt, labels, labels.sortWith((s1, s2) => s1 < s2)) 85 | writer.write(s"$id,${labelVec.mkString(",")}\n") 86 | } 87 | } 88 | writer.close() 89 | 90 | } 91 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/kaggle/SpringleafMarketingResponse.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.kaggle 2 | 3 | import org.apache.spark.SparkContext 4 | import org.apache.spark.ml.classification.LogisticRegression 5 | 
import org.apache.spark.ml.feature._ 6 | import org.apache.spark.sql.DataFrame 7 | import org.apache.spark.sql.functions._ 8 | import org.apache.spark.sql.types._ 9 | 10 | /** 11 | * https://www.kaggle.com/c/springleaf-marketing-response 12 | * https://issues.apache.org/jira/browse/SPARK-10513 13 | * Copied from: 14 | * https://github.com/yanboliang/Springleaf 15 | * 16 | * Created by yanboliang on 9/30/15. 17 | */ 18 | object Springleaf { 19 | 20 | val trainFile = "src/main/resources/kaggle/small-train.csv" 21 | val testFile = "src/main/resources/kaggle/test.csv" 22 | 23 | def main(args: Array[String]): Unit = { 24 | 25 | println("Springleaf start") 26 | 27 | val sc = new SparkContext("local", "Springleaf") 28 | val sqlContext = new org.apache.spark.sql.SQLContext(sc) 29 | val training = sqlContext.read 30 | .format("com.databricks.spark.csv") 31 | .option("header", "true") 32 | .option("inferSchema", "true") 33 | .load(trainFile) 34 | 35 | val numericColumnNames = training.schema.fields.filter(_.dataType != StringType).map(_.name).filter(_ != "target").toSeq 36 | 37 | val categoryColumnIndex = Seq("0001", "0005", "0008", "0009", "0010", "0011", "0012", "0043", "0196", "0200", 38 | "0202", "0216", "0222", "0226", "0229", "0230", "0232", "0236", "0237", "0239" 39 | /*, "0274", "0283", "0305", "0325", "0342", "0353", "0467"*/ ) 40 | 41 | val categoryColumnNames = categoryColumnIndex.map("VAR_" + _) 42 | 43 | val allFeatureColumns = numericColumnNames ++ categoryColumnNames 44 | 45 | val training2 = training.select("target", allFeatureColumns: _*) 46 | //training2.show() 47 | 48 | var oldTraining: DataFrame = training2 49 | var newTraining: DataFrame = training2 50 | 51 | categoryColumnIndex.foreach { 52 | x => 53 | { 54 | val colName = "VAR_" + x 55 | //println(colName) 56 | val indexer = new StringIndexer() 57 | .setInputCol(colName) 58 | .setOutputCol(colName + "_indexed") 59 | .fit(oldTraining) 60 | val indexed = indexer.transform(oldTraining) 61 | 62 | val encoder = new OneHotEncoder() 63 | .setDropLast(false) 64 | .setInputCol(colName + "_indexed") 65 | .setOutputCol(colName + "_encoded") 66 | newTraining = encoder.transform(indexed) 67 | 68 | oldTraining = newTraining 69 | } 70 | } 71 | 72 | val assemblerNames = (numericColumnNames ++ categoryColumnNames.map(_ + "_encoded")).filter(_ != "ID").filter(_ != "target").toArray 73 | //println(assemblerNames.mkString(",")) 74 | 75 | val assembler = new VectorAssembler() 76 | .setInputCols(assemblerNames) 77 | .setOutputCol("features") 78 | 79 | val training3 = assembler.transform(newTraining) 80 | val training4 = training3.withColumn("label", col("target").cast(DoubleType)) 81 | 82 | val lr = new LogisticRegression() 83 | .setMaxIter(100) 84 | .setLabelCol("label") 85 | .setFeaturesCol("features") 86 | 87 | val model = lr.fit(training4) 88 | model.transform(training4).select("ID", "label", "prediction").show(200, false) 89 | } 90 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/AbstractParams.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import scala.reflect.runtime.universe._ 4 | 5 | abstract class AbstractParams[T: TypeTag] { 6 | 7 | private def tag: TypeTag[T] = typeTag[T] 8 | 9 | /** 10 | * Finds all case class fields in concrete class instance, and outputs them in JSON-style format: 11 | * { 12 | * [field name]:\t[field value]\n 13 | * [field name]:\t[field value]\n 
14 | * ... 15 | * } 16 | */ 17 | override def toString: String = { 18 | val tpe = tag.tpe 19 | val allAccessors = tpe.declarations.collect { 20 | case m: MethodSymbol if m.isCaseAccessor => m 21 | } 22 | val mirror = runtimeMirror(getClass.getClassLoader) 23 | val instanceMirror = mirror.reflect(this) 24 | allAccessors.map { f => 25 | val paramName = f.name.toString 26 | val fieldMirror = instanceMirror.reflectField(f) 27 | val paramValue = fieldMirror.get 28 | s" $paramName:\t$paramValue" 29 | }.mkString("{\n", ",\n", "\n}") 30 | } 31 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import java.io.File 4 | 5 | import org.apache.spark.SparkConf 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.mllib.linalg.Vector 8 | import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer 9 | import org.apache.spark.sql.DataFrame 10 | import org.apache.spark.sql.Row 11 | import org.apache.spark.sql.SQLContext 12 | 13 | import com.google.common.io.Files 14 | 15 | import scopt.OptionParser 16 | 17 | object DataFrameExample { 18 | 19 | case class Params(input: String = "data/mllib/sample_libsvm_data.txt") 20 | extends AbstractParams[Params] 21 | 22 | def main(args: Array[String]) { 23 | val defaultParams = Params() 24 | 25 | val parser = new OptionParser[Params]("DataFrameExample") { 26 | head("DataFrameExample: an example app using DataFrame for ML.") 27 | opt[String]("input") 28 | .text(s"input path to dataframe") 29 | .action((x, c) => c.copy(input = x)) 30 | checkConfig { params => 31 | success 32 | } 33 | } 34 | 35 | parser.parse(args, defaultParams).map { params => 36 | run(params) 37 | }.getOrElse { 38 | sys.exit(1) 39 | } 40 | } 41 | 42 | def run(params: Params) { 43 | 44 | val conf = new SparkConf().setMaster("local").setAppName(s"DataFrameExample with $params") 45 | val sc = new SparkContext(conf) 46 | val sqlContext = new SQLContext(sc) 47 | 48 | // Load input data 49 | println(s"Loading LIBSVM file with UDT from ${params.input}.") 50 | val df: DataFrame = sqlContext.read.format("libsvm").load(params.input).cache() 51 | println("Schema from LIBSVM:") 52 | df.printSchema() 53 | println(s"Loaded training data as a DataFrame with ${df.count()} records.") 54 | 55 | // Show statistical summary of labels. 56 | val labelSummary = df.describe("label") 57 | labelSummary.show() 58 | 59 | // Convert features column to an RDD of vectors. 60 | val features = df.select("features").rdd.map { case Row(v: Vector) => v } 61 | val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( 62 | (summary, feat) => summary.add(feat), 63 | (sum1, sum2) => sum1.merge(sum2)) 64 | println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") 65 | 66 | // Save the records in a parquet file. 67 | val tmpDir = Files.createTempDir() 68 | tmpDir.deleteOnExit() 69 | val outputDir = new File(tmpDir, "dataframe").toString 70 | println(s"Saving to $outputDir as Parquet file.") 71 | df.write.parquet(outputDir) 72 | 73 | // Load the records back. 
74 | println(s"Loading Parquet file with UDT from $outputDir.") 75 | val newDF = sqlContext.read.parquet(outputDir) 76 | println(s"Schema from Parquet:") 77 | newDF.printSchema() 78 | 79 | sc.stop() 80 | } 81 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator 4 | import org.apache.spark.SparkConf 5 | import org.apache.spark.ml.Pipeline 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.ml.feature.StringIndexer 8 | import org.apache.spark.ml.classification.GBTClassifier 9 | import org.apache.spark.ml.classification.GBTClassificationModel 10 | import org.apache.spark.sql.SQLContext 11 | import org.apache.spark.ml.feature.IndexToString 12 | import org.apache.spark.ml.feature.VectorIndexer 13 | 14 | object GradientBoostedTreeClassifierExample { 15 | def main(args: Array[String]): Unit = { 16 | val conf = new SparkConf().setMaster("local").setAppName("GradientBoostedTreeClassifierExample") 17 | val sc = new SparkContext(conf) 18 | val sqlContext = new SQLContext(sc) 19 | 20 | // $example on$ 21 | // Load and parse the data file, converting it to a DataFrame. 22 | val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") 23 | 24 | // Index labels, adding metadata to the label column. 25 | // Fit on whole dataset to include all labels in index. 26 | val labelIndexer = new StringIndexer() 27 | .setInputCol("label") 28 | .setOutputCol("indexedLabel") 29 | .fit(data) 30 | // Automatically identify categorical features, and index them. 31 | // Set maxCategories so features with > 4 distinct values are treated as continuous. 32 | val featureIndexer = new VectorIndexer() 33 | .setInputCol("features") 34 | .setOutputCol("indexedFeatures") 35 | .setMaxCategories(4) 36 | .fit(data) 37 | 38 | // Split the data into training and test sets (30% held out for testing) 39 | val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) 40 | 41 | // Train a GBT model. 42 | val gbt = new GBTClassifier() 43 | .setLabelCol("indexedLabel") 44 | .setFeaturesCol("indexedFeatures") 45 | .setMaxIter(10) 46 | 47 | // Convert indexed labels back to original labels. 48 | val labelConverter = new IndexToString() 49 | .setInputCol("prediction") 50 | .setOutputCol("predictedLabel") 51 | .setLabels(labelIndexer.labels) 52 | 53 | // Chain indexers and GBT in a Pipeline 54 | val pipeline = new Pipeline() 55 | .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter)) 56 | 57 | // Train model. This also runs the indexers. 58 | val model = pipeline.fit(trainingData) 59 | 60 | // Make predictions. 61 | val predictions = model.transform(testData) 62 | 63 | // Select example rows to display. 
64 | predictions.select("predictedLabel", "label", "features").show(5) 65 | 66 | // Select (prediction, true label) and compute test error 67 | val evaluator = new MulticlassClassificationEvaluator() 68 | .setLabelCol("indexedLabel") 69 | .setPredictionCol("prediction") 70 | .setMetricName("precision") 71 | val accuracy = evaluator.evaluate(predictions) 72 | println("Test Error = " + (1.0 - accuracy)) 73 | 74 | val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] 75 | println("Learned classification GBT model:\n" + gbtModel.toDebugString) 76 | // $example off$ 77 | 78 | sc.stop() 79 | } 80 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.ml.Pipeline 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.ml.regression.GBTRegressionModel 7 | import org.apache.spark.ml.evaluation.RegressionEvaluator 8 | import org.apache.spark.sql.SQLContext 9 | import org.apache.spark.ml.regression.GBTRegressor 10 | import org.apache.spark.ml.feature.VectorIndexer 11 | 12 | object GradientBoostedTreeRegressorExample { 13 | def main(args: Array[String]): Unit = { 14 | val conf = new SparkConf().setMaster("local").setAppName("GradientBoostedTreeRegressorExample") 15 | val sc = new SparkContext(conf) 16 | val sqlContext = new SQLContext(sc) 17 | 18 | // $example on$ 19 | // Load and parse the data file, converting it to a DataFrame. 20 | val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") 21 | 22 | // Automatically identify categorical features, and index them. 23 | // Set maxCategories so features with > 4 distinct values are treated as continuous. 24 | val featureIndexer = new VectorIndexer() 25 | .setInputCol("features") 26 | .setOutputCol("indexedFeatures") 27 | .setMaxCategories(4) 28 | .fit(data) 29 | 30 | // Split the data into training and test sets (30% held out for testing) 31 | val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) 32 | 33 | // Train a GBT model. 34 | val gbt = new GBTRegressor() 35 | .setLabelCol("label") 36 | .setFeaturesCol("indexedFeatures") 37 | .setMaxIter(10) 38 | 39 | // Chain indexer and GBT in a Pipeline 40 | val pipeline = new Pipeline() 41 | .setStages(Array(featureIndexer, gbt)) 42 | 43 | // Train model. This also runs the indexer. 44 | val model = pipeline.fit(trainingData) 45 | 46 | // Make predictions. 47 | val predictions = model.transform(testData) 48 | 49 | // Select example rows to display. 
50 | predictions.select("prediction", "label", "features").show(5) 51 | 52 | // Select (prediction, true label) and compute test error 53 | val evaluator = new RegressionEvaluator() 54 | .setLabelCol("label") 55 | .setPredictionCol("prediction") 56 | .setMetricName("rmse") 57 | val rmse = evaluator.evaluate(predictions) 58 | println("Root Mean Squared Error (RMSE) on test data = " + rmse) 59 | 60 | val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] 61 | println("Learned regression GBT model:\n" + gbtModel.toDebugString) 62 | // $example off$ 63 | 64 | sc.stop() 65 | } 66 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/LogisticRegressionExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import scala.collection.mutable 4 | import scala.reflect.runtime.universe 5 | 6 | import org.apache.spark.SparkConf 7 | import org.apache.spark.SparkContext 8 | import org.apache.spark.ml.Pipeline 9 | import org.apache.spark.ml.PipelineStage 10 | import org.apache.spark.ml.Transformer 11 | import org.apache.spark.ml.classification.LogisticRegression 12 | import org.apache.spark.ml.classification.LogisticRegressionModel 13 | import org.apache.spark.ml.feature.StringIndexer 14 | import org.apache.spark.ml.util.MetadataUtils 15 | import org.apache.spark.mllib.evaluation.MulticlassMetrics 16 | import org.apache.spark.mllib.linalg.Vector 17 | import org.apache.spark.mllib.util.MLUtils 18 | import org.apache.spark.sql.DataFrame 19 | import org.apache.spark.sql.SQLContext 20 | 21 | import scopt.OptionParser 22 | 23 | object LogisticRegressionExample { 24 | 25 | case class Params( 26 | input: String = "data/mllib/sample_libsvm_data.txt", 27 | testInput: String = "", 28 | dataFormat: String = "libsvm", 29 | regParam: Double = 0.0, 30 | elasticNetParam: Double = 0.0, 31 | maxIter: Int = 100, 32 | fitIntercept: Boolean = true, 33 | tol: Double = 1E-6, 34 | fracTest: Double = 0.2) extends AbstractParams[Params] 35 | 36 | def main(args: Array[String]) { 37 | val defaultParams = Params() 38 | 39 | val parser = new OptionParser[Params]("LogisticRegressionExample") { 40 | head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.") 41 | opt[Double]("regParam") 42 | .text(s"regularization parameter, default: ${defaultParams.regParam}") 43 | .action((x, c) => c.copy(regParam = x)) 44 | opt[Double]("elasticNetParam") 45 | .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " + 46 | s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " + 47 | s"L1 and L2, default: ${defaultParams.elasticNetParam}") 48 | .action((x, c) => c.copy(elasticNetParam = x)) 49 | opt[Int]("maxIter") 50 | .text(s"maximum number of iterations, default: ${defaultParams.maxIter}") 51 | .action((x, c) => c.copy(maxIter = x)) 52 | opt[Boolean]("fitIntercept") 53 | .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}") 54 | .action((x, c) => c.copy(fitIntercept = x)) 55 | opt[Double]("tol") 56 | .text(s"the convergence tolerance of iterations, Smaller value will lead " + 57 | s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}") 58 | .action((x, c) => c.copy(tol = x)) 59 | opt[Double]("fracTest") 60 | .text(s"fraction of data to hold out for testing. If given option testInput, " + 61 | s"this option is ignored. 
default: ${defaultParams.fracTest}") 62 | .action((x, c) => c.copy(fracTest = x)) 63 | opt[String]("testInput") 64 | .text(s"input path to test dataset. If given, option fracTest is ignored." + 65 | s" default: ${defaultParams.testInput}") 66 | .action((x, c) => c.copy(testInput = x)) 67 | opt[String]("dataFormat") 68 | .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") 69 | .action((x, c) => c.copy(dataFormat = x)) 70 | arg[String]("") 71 | .text("input path to labeled examples") 72 | .required() 73 | .action((x, c) => c.copy(input = x)) 74 | checkConfig { params => 75 | if (params.fracTest < 0 || params.fracTest >= 1) { 76 | failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).") 77 | } else { 78 | success 79 | } 80 | } 81 | } 82 | 83 | // parser.parse(args, defaultParams).map { params => 84 | // run(params) 85 | // }.getOrElse { 86 | // sys.exit(1) 87 | // } 88 | run(defaultParams); 89 | } 90 | 91 | def run(params: Params) { 92 | val conf = new SparkConf().setMaster("local").setAppName(s"LogisticRegressionExample with $params") 93 | val sc = new SparkContext(conf) 94 | 95 | println(s"LogisticRegressionExample with parameters:\n$params") 96 | 97 | // Load training and test data and cache it. 98 | val (training: DataFrame, test: DataFrame) = loadDatasets(sc, 99 | /* params.input */ "data/mllib/sample_libsvm_data.txt", 100 | params.dataFormat, params.testInput, "classification", params.fracTest) 101 | 102 | // Set up Pipeline 103 | val stages = new mutable.ArrayBuffer[PipelineStage]() 104 | 105 | val labelIndexer = new StringIndexer() 106 | .setInputCol("label") 107 | .setOutputCol("indexedLabel") 108 | stages += labelIndexer 109 | 110 | val lor = new LogisticRegression() 111 | .setFeaturesCol("features") 112 | .setLabelCol("indexedLabel") 113 | .setRegParam(params.regParam) 114 | .setElasticNetParam(params.elasticNetParam) 115 | .setMaxIter(params.maxIter) 116 | .setTol(params.tol) 117 | .setFitIntercept(params.fitIntercept) 118 | 119 | stages += lor 120 | val pipeline = new Pipeline().setStages(stages.toArray) 121 | 122 | // Fit the Pipeline 123 | val startTime = System.nanoTime() 124 | val pipelineModel = pipeline.fit(training) 125 | val elapsedTime = (System.nanoTime() - startTime) / 1e9 126 | println(s"Training time: $elapsedTime seconds") 127 | 128 | val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel] 129 | // Print the weights and intercept for logistic regression. 
130 | println(s"Weights: ${lorModel.coefficients} Intercept: ${lorModel.intercept}") 131 | 132 | println("Training data results:") 133 | evaluateClassificationModel(pipelineModel, training, "indexedLabel") 134 | println("Test data results:") 135 | evaluateClassificationModel(pipelineModel, test, "indexedLabel") 136 | 137 | sc.stop() 138 | } 139 | 140 | /** Load a dataset from the given path, using the given format */ 141 | private[ml] def loadData( 142 | sqlContext: SQLContext, 143 | path: String, 144 | format: String, 145 | expectedNumFeatures: Option[Int] = None): DataFrame = { 146 | import sqlContext.implicits._ 147 | 148 | format match { 149 | case "dense" => MLUtils.loadLabeledPoints(sqlContext.sparkContext, path).toDF() 150 | case "libsvm" => expectedNumFeatures match { 151 | case Some(numFeatures) => sqlContext.read.option("numFeatures", numFeatures.toString) 152 | .format("libsvm").load(path) 153 | case None => sqlContext.read.format("libsvm").load(path) 154 | } 155 | case _ => throw new IllegalArgumentException(s"Bad data format: $format") 156 | } 157 | } 158 | 159 | /** 160 | * Load training and test data from files. 161 | * @param input Path to input dataset. 162 | * @param dataFormat "libsvm" or "dense" 163 | * @param testInput Path to test dataset. 164 | * @param algo Classification or Regression 165 | * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. 166 | * @return (training dataset, test dataset) 167 | */ 168 | private[ml] def loadDatasets( 169 | sc: SparkContext, 170 | input: String, 171 | dataFormat: String, 172 | testInput: String, 173 | algo: String, 174 | fracTest: Double): (DataFrame, DataFrame) = { 175 | val sqlContext = new SQLContext(sc) 176 | 177 | // Load training data 178 | val origExamples: DataFrame = loadData(sqlContext, input, dataFormat) 179 | 180 | // Load or create test set 181 | val dataframes: Array[DataFrame] = if (testInput != "") { 182 | // Load testInput. 183 | val numFeatures = origExamples.first().getAs[Vector](1).size 184 | val origTestExamples: DataFrame = 185 | loadData(sqlContext, testInput, dataFormat, Some(numFeatures)) 186 | Array(origExamples, origTestExamples) 187 | } else { 188 | // Split input into training, test. 189 | origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345) 190 | } 191 | 192 | val training = dataframes(0).cache() 193 | val test = dataframes(1).cache() 194 | 195 | val numTraining = training.count() 196 | val numTest = test.count() 197 | val numFeatures = training.select("features").first().getAs[Vector](0).size 198 | println("Loaded data:") 199 | println(s" numTraining = $numTraining, numTest = $numTest") 200 | println(s" numFeatures = $numFeatures") 201 | 202 | (training, test) 203 | } 204 | 205 | /** 206 | * Evaluate the given ClassificationModel on data. Print the results. 207 | * @param model Must fit ClassificationModel abstraction 208 | * @param data DataFrame with "prediction" and labelColName columns 209 | * @param labelColName Name of the labelCol parameter for the model 210 | * 211 | * TODO: Change model type to ClassificationModel once that API is public. 
SPARK-5995 212 | */ 213 | private[ml] def evaluateClassificationModel( 214 | model: Transformer, 215 | data: DataFrame, 216 | labelColName: String): Unit = { 217 | val fullPredictions = model.transform(data).cache() 218 | val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0)) 219 | val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0)) 220 | // Print number of classes for reference 221 | val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match { 222 | case Some(n) => n 223 | case None => throw new RuntimeException( 224 | "Unknown failure when indexing labels for classification.") 225 | } 226 | val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision 227 | println(s" Accuracy ($numClasses classes): $accuracy") 228 | } 229 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.SparkContext 5 | import org.apache.spark.ml.classification.BinaryLogisticRegressionSummary 6 | import org.apache.spark.ml.classification.LogisticRegression 7 | import org.apache.spark.sql.SQLContext 8 | import org.apache.spark.sql.functions.max 9 | 10 | object LogisticRegressionSummaryExample { 11 | 12 | def main(args: Array[String]): Unit = { 13 | val conf = new SparkConf().setMaster("local").setAppName("LogisticRegressionSummaryExample") 14 | val sc = new SparkContext(conf) 15 | val sqlCtx = new SQLContext(sc) 16 | import sqlCtx.implicits._ 17 | 18 | // Load training data 19 | val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") 20 | 21 | val lr = new LogisticRegression() 22 | .setMaxIter(10) 23 | .setRegParam(0.3) 24 | .setElasticNetParam(0.8) 25 | 26 | // Fit the model 27 | val lrModel = lr.fit(training) 28 | 29 | // $example on$ 30 | // Extract the summary from the returned LogisticRegressionModel instance trained in the earlier 31 | // example 32 | val trainingSummary = lrModel.summary 33 | 34 | // Obtain the objective per iteration. 35 | val objectiveHistory = trainingSummary.objectiveHistory 36 | objectiveHistory.foreach(loss => println(loss)) 37 | 38 | // Obtain the metrics useful to judge performance on test data. 39 | // We cast the summary to a BinaryLogisticRegressionSummary since the problem is a 40 | // binary classification problem. 41 | val binarySummary = trainingSummary.asInstanceOf[BinaryLogisticRegressionSummary] 42 | 43 | // Obtain the receiver-operating characteristic as a dataframe and areaUnderROC. 
44 | val roc = binarySummary.roc 45 | roc.show() 46 | println(binarySummary.areaUnderROC) 47 | 48 | // Set the model threshold to maximize F-Measure 49 | val fMeasure = binarySummary.fMeasureByThreshold 50 | val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0) 51 | val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure) 52 | .select("threshold").head().getDouble(0) 53 | lrModel.setThreshold(bestThreshold) 54 | // $example off$ 55 | 56 | sc.stop() 57 | } 58 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.ml.classification.LogisticRegression 4 | import org.apache.spark.sql.SQLContext 5 | import org.apache.spark.SparkConf 6 | import org.apache.spark.SparkContext 7 | 8 | object LogisticRegressionWithElasticNetExample { 9 | 10 | def main(args: Array[String]): Unit = { 11 | val conf = new SparkConf().setMaster("local").setAppName("LogisticRegressionWithElasticNetExample") 12 | val sc = new SparkContext(conf) 13 | val sqlCtx = new SQLContext(sc) 14 | 15 | // $example on$ 16 | // Load training data 17 | val training = sqlCtx.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") 18 | 19 | val lr = new LogisticRegression() 20 | .setMaxIter(10) 21 | .setRegParam(0.3) 22 | .setElasticNetParam(0.8) 23 | 24 | // Fit the model 25 | val lrModel = lr.fit(training) 26 | 27 | // Print the coefficients and intercept for logistic regression 28 | println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}") 29 | // $example off$ 30 | 31 | sc.stop() 32 | } 33 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.ml.tuning.ParamGridBuilder 5 | import org.apache.spark.ml.Pipeline 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.ml.tuning.CrossValidator 8 | import org.apache.spark.sql.SQLContext 9 | import org.apache.spark.ml.classification.LogisticRegression 10 | import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator 11 | import org.apache.spark.ml.feature.Tokenizer 12 | import org.apache.spark.ml.feature.HashingTF 13 | import org.apache.spark.sql.Row 14 | import org.apache.spark.mllib.linalg.Vector 15 | 16 | object ModelSelectionViaCrossValidationExample { 17 | 18 | def main(args: Array[String]): Unit = { 19 | val conf = new SparkConf().setMaster("local").setAppName("ModelSelectionViaCrossValidationExample") 20 | val sc = new SparkContext(conf) 21 | val sqlContext = new SQLContext(sc) 22 | 23 | // $example on$ 24 | // Prepare training data from a list of (id, text, label) tuples. 
25 | val training = sqlContext.createDataFrame(Seq( 26 | (0L, "a b c d e spark", 1.0), 27 | (1L, "b d", 0.0), 28 | (2L, "spark f g h", 1.0), 29 | (3L, "hadoop mapreduce", 0.0), 30 | (4L, "b spark who", 1.0), 31 | (5L, "g d a y", 0.0), 32 | (6L, "spark fly", 1.0), 33 | (7L, "was mapreduce", 0.0), 34 | (8L, "e spark program", 1.0), 35 | (9L, "a e c l", 0.0), 36 | (10L, "spark compile", 1.0), 37 | (11L, "hadoop software", 0.0) 38 | )).toDF("id", "text", "label") 39 | 40 | // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. 41 | val tokenizer = new Tokenizer() 42 | .setInputCol("text") 43 | .setOutputCol("words") 44 | val hashingTF = new HashingTF() 45 | .setInputCol(tokenizer.getOutputCol) 46 | .setOutputCol("features") 47 | val lr = new LogisticRegression() 48 | .setMaxIter(10) 49 | val pipeline = new Pipeline() 50 | .setStages(Array(tokenizer, hashingTF, lr)) 51 | 52 | // We use a ParamGridBuilder to construct a grid of parameters to search over. 53 | // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, 54 | // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. 55 | val paramGrid = new ParamGridBuilder() 56 | .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) 57 | .addGrid(lr.regParam, Array(0.1, 0.01)) 58 | .build() 59 | 60 | // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. 61 | // This will allow us to jointly choose parameters for all Pipeline stages. 62 | // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. 63 | // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric 64 | // is areaUnderROC. 65 | val cv = new CrossValidator() 66 | .setEstimator(pipeline) 67 | .setEvaluator(new BinaryClassificationEvaluator) 68 | .setEstimatorParamMaps(paramGrid) 69 | .setNumFolds(2) // Use 3+ in practice 70 | 71 | // Run cross-validation, and choose the best set of parameters. 72 | val cvModel = cv.fit(training) 73 | 74 | // Prepare test documents, which are unlabeled (id, text) tuples. 75 | val test = sqlContext.createDataFrame(Seq( 76 | (4L, "spark i j k"), 77 | (5L, "l m n"), 78 | (6L, "mapreduce spark"), 79 | (7L, "apache hadoop") 80 | )).toDF("id", "text") 81 | 82 | // Make predictions on test documents. cvModel uses the best model found (lrModel). 
83 | cvModel.transform(test) 84 | .select("id", "text", "probability", "prediction") 85 | .collect() 86 | .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => 87 | println(s"($id, $text) --> prob=$prob, prediction=$prediction") 88 | } 89 | // $example off$ 90 | 91 | sc.stop() 92 | } 93 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.ml.tuning.ParamGridBuilder 5 | import org.apache.spark.SparkContext 6 | import org.apache.spark.ml.evaluation.RegressionEvaluator 7 | import org.apache.spark.sql.SQLContext 8 | import org.apache.spark.ml.tuning.TrainValidationSplit 9 | import org.apache.spark.ml.regression.LinearRegression 10 | 11 | object ModelSelectionViaTrainValidationSplitExample { 12 | 13 | def main(args: Array[String]): Unit = { 14 | val conf = new SparkConf().setMaster("local").setAppName("ModelSelectionViaTrainValidationSplitExample") 15 | val sc = new SparkContext(conf) 16 | val sqlContext = new SQLContext(sc) 17 | 18 | // $example on$ 19 | // Prepare training and test data. 20 | val data = sqlContext.read.format("libsvm").load("data/mllib/sample_linear_regression_data.txt") 21 | val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) 22 | 23 | val lr = new LinearRegression() 24 | 25 | // We use a ParamGridBuilder to construct a grid of parameters to search over. 26 | // TrainValidationSplit will try all combinations of values and determine best model using 27 | // the evaluator. 28 | val paramGrid = new ParamGridBuilder() 29 | .addGrid(lr.regParam, Array(0.1, 0.01)) 30 | .addGrid(lr.fitIntercept) 31 | .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) 32 | .build() 33 | 34 | // In this case the estimator is simply the linear regression. 35 | // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. 36 | val trainValidationSplit = new TrainValidationSplit() 37 | .setEstimator(lr) 38 | .setEvaluator(new RegressionEvaluator) 39 | .setEstimatorParamMaps(paramGrid) 40 | // 80% of the data will be used for training and the remaining 20% for validation. 41 | .setTrainRatio(0.8) 42 | 43 | // Run train validation split, and choose the best set of parameters. 44 | val model = trainValidationSplit.fit(training) 45 | 46 | // Make predictions on test data. model is the model with combination of parameters 47 | // that performed best. 
48 | model.transform(test) 49 | .select("features", "label", "prediction") 50 | .show() 51 | // $example off$ 52 | 53 | sc.stop() 54 | } 55 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.SparkConf 4 | import org.apache.spark.SparkContext 5 | import org.apache.spark.ml.feature.OneHotEncoder 6 | import org.apache.spark.ml.feature.StringIndexer 7 | import org.apache.spark.sql.SQLContext 8 | import scala.reflect.runtime.universe 9 | 10 | object OneHotEncoderExample { 11 | def main(args: Array[String]): Unit = { 12 | val conf = new SparkConf().setMaster("local").setAppName("OneHotEncoderExample") 13 | val sc = new SparkContext(conf) 14 | val sqlContext = new SQLContext(sc) 15 | 16 | // $example on$ 17 | val df = sqlContext.createDataFrame(Seq( 18 | (0, "a"), 19 | (1, "b"), 20 | (2, "c"), 21 | (3, "a"), 22 | (4, "a"), 23 | (5, "c") 24 | )).toDF("id", "category") 25 | 26 | val indexer = new StringIndexer() 27 | .setInputCol("category") 28 | .setOutputCol("categoryIndex") 29 | .fit(df) 30 | val indexed = indexer.transform(df) 31 | 32 | val encoder = new OneHotEncoder() 33 | .setInputCol("categoryIndex") 34 | .setOutputCol("categoryVec") 35 | val encoded = encoder.transform(indexed) 36 | encoded.select("id", "category", "categoryIndex", "categoryVec").show() 37 | // $example off$ 38 | sc.stop() 39 | } 40 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/PipelineExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import scala.reflect.runtime.universe 4 | 5 | import org.apache.spark.SparkConf 6 | import org.apache.spark.SparkContext 7 | import org.apache.spark.ml.Pipeline 8 | import org.apache.spark.ml.PipelineModel 9 | import org.apache.spark.ml.classification.LogisticRegression 10 | import org.apache.spark.ml.feature.HashingTF 11 | import org.apache.spark.ml.feature.Tokenizer 12 | import org.apache.spark.mllib.linalg.Vector 13 | import org.apache.spark.sql.Row 14 | import org.apache.spark.sql.SQLContext 15 | 16 | object PipelineExample { 17 | 18 | def main(args: Array[String]): Unit = { 19 | val conf = new SparkConf().setMaster("local").setAppName("PipelineExample") 20 | val sc = new SparkContext(conf) 21 | val sqlContext = new SQLContext(sc) 22 | 23 | // $example on$ 24 | // Prepare training documents from a list of (id, text, label) tuples. 25 | val training = sqlContext.createDataFrame(Seq( 26 | (0L, "a b c d e spark", 1.0), 27 | (1L, "b d", 0.0), 28 | (2L, "spark f g h", 1.0), 29 | (3L, "hadoop mapreduce", 0.0) 30 | )).toDF("id", "text", "label") 31 | 32 | // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. 33 | val tokenizer = new Tokenizer() 34 | .setInputCol("text") 35 | .setOutputCol("words") 36 | val hashingTF = new HashingTF() 37 | .setNumFeatures(1000) 38 | .setInputCol(tokenizer.getOutputCol) 39 | .setOutputCol("features") 40 | val lr = new LogisticRegression() 41 | .setMaxIter(10) 42 | .setRegParam(0.01) 43 | val pipeline = new Pipeline() 44 | .setStages(Array(tokenizer, hashingTF, lr)) 45 | 46 | // Fit the pipeline to training documents. 
47 | val model = pipeline.fit(training) 48 | 49 | // Now we can optionally save the fitted pipeline to disk 50 | model.write.overwrite().save("/tmp/spark-logistic-regression-model") 51 | 52 | // We can also save this unfit pipeline to disk 53 | pipeline.write.overwrite().save("/tmp/unfit-lr-model") 54 | 55 | // And load it back in during production 56 | val sameModel = PipelineModel.load("/tmp/spark-logistic-regression-model") 57 | 58 | // Prepare test documents, which are unlabeled (id, text) tuples. 59 | val test = sqlContext.createDataFrame(Seq( 60 | (4L, "spark i j k"), 61 | (5L, "l m n"), 62 | (6L, "mapreduce spark"), 63 | (7L, "apache hadoop") 64 | )).toDF("id", "text") 65 | 66 | // Make predictions on test documents. 67 | model.transform(test) 68 | .select("id", "text", "probability", "prediction") 69 | .collect() 70 | .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => 71 | println(s"($id, $text) --> prob=$prob, prediction=$prediction") 72 | } 73 | // $example off$ 74 | 75 | sc.stop() 76 | } 77 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.ml.feature.StringIndexer 4 | import org.apache.spark.sql.SQLContext 5 | import org.apache.spark.SparkConf 6 | import org.apache.spark.SparkContext 7 | import scala.reflect.runtime.universe 8 | 9 | object StringIndexerExample { 10 | def main(args: Array[String]): Unit = { 11 | val conf = new SparkConf().setMaster("local").setAppName("StringIndexerExample") 12 | val sc = new SparkContext(conf) 13 | val sqlContext = new SQLContext(sc) 14 | 15 | // $example on$ 16 | val df = sqlContext.createDataFrame( 17 | Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) 18 | ).toDF("id", "category") 19 | 20 | val indexer = new StringIndexer() 21 | .setInputCol("category") 22 | .setOutputCol("categoryIndex") 23 | 24 | val indexed = indexer.fit(df).transform(df) 25 | indexed.show() 26 | // $example off$ 27 | sc.stop() 28 | } 29 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples.ml 2 | 3 | import org.apache.spark.ml.feature.VectorIndexer 4 | import org.apache.spark.sql.SQLContext 5 | import org.apache.spark.SparkConf 6 | import org.apache.spark.SparkContext 7 | 8 | object VectorIndexerExample { 9 | def main(args: Array[String]): Unit = { 10 | val conf = new SparkConf().setMaster("local").setAppName("VectorIndexerExample") 11 | val sc = new SparkContext(conf) 12 | val sqlContext = new SQLContext(sc) 13 | 14 | // $example on$ 15 | val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") 16 | 17 | val indexer = new VectorIndexer() 18 | .setInputCol("features") 19 | .setOutputCol("indexed") 20 | .setMaxCategories(10) 21 | 22 | val indexerModel = indexer.fit(data) 23 | 24 | val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet 25 | println(s"Chose ${categoricalFeatures.size} categorical features: " + 26 | categoricalFeatures.mkString(", ")) 27 | 28 | // Create new column "indexed" with categorical values transformed to indices 29 | val indexedData = indexerModel.transform(data) 30 | 
indexedData.show() 31 | // $example off$ 32 | sc.stop() 33 | } 34 | } -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/SparkFunSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark 19 | 20 | // scalastyle:off 21 | import org.scalatest.{FunSuite, Outcome} 22 | 23 | /** 24 | * Base abstract class for all unit tests in Spark for handling common functionality. 25 | */ 26 | private[spark] abstract class SparkFunSuite extends FunSuite with Logging { 27 | // scalastyle:on 28 | 29 | /** 30 | * Log the suite name and the test name before and after each test. 31 | * 32 | * Subclasses should never override this method. If they wish to run 33 | * custom code before and after each test, they should mix in the 34 | * {{org.scalatest.BeforeAndAfter}} trait instead. 35 | */ 36 | final protected override def withFixture(test: NoArgTest): Outcome = { 37 | val testName = test.text 38 | val suiteName = this.getClass.getName 39 | val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") 40 | try { 41 | logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") 42 | test() 43 | } finally { 44 | logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") 45 | } 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/examples/kaggle/ClickThroughRatePrediction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.examples.kaggle 19 | 20 | import org.apache.spark.SparkFunSuite 21 | import org.apache.spark.util.MLlibTestSparkContext 22 | 23 | class ClickThroughRatePredictionSuite extends SparkFunSuite with MLlibTestSparkContext { 24 | 25 | test("run") { 26 | // Logger.getLogger("org").setLevel(Level.OFF) 27 | // Logger.getLogger("akka").setLevel(Level.OFF) 28 | 29 | val trainPath = this.getClass.getResource("/train.part-10000").getPath 30 | val testPath = this.getClass.getResource("/test.part-10000").getPath 31 | val resultPath = "./tmp/result/" 32 | 33 | ClickThroughRatePrediction.run(sc, sqlContext, trainPath, testPath, resultPath) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/util/LocalClusterSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.util 19 | 20 | import org.scalatest.{Suite, BeforeAndAfterAll} 21 | 22 | import org.apache.spark.{SparkConf, SparkContext} 23 | 24 | trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => 25 | @transient var sc: SparkContext = _ 26 | 27 | override def beforeAll() { 28 | super.beforeAll() 29 | val conf = new SparkConf() 30 | .setMaster("local-cluster[2, 1, 1024]") 31 | .setAppName("test-cluster") 32 | .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data 33 | sc = new SparkContext(conf) 34 | } 35 | 36 | override def afterAll() { 37 | try { 38 | if (sc != null) { 39 | sc.stop() 40 | } 41 | } finally { 42 | super.afterAll() 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/util/MLlibTestSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.util 19 | 20 | import org.scalatest.{BeforeAndAfterAll, Suite} 21 | 22 | import org.apache.spark.sql.SQLContext 23 | import org.apache.spark.{SparkConf, SparkContext} 24 | 25 | trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => 26 | @transient var sc: SparkContext = _ 27 | @transient var sqlContext: SQLContext = _ 28 | 29 | override def beforeAll() { 30 | super.beforeAll() 31 | val conf = new SparkConf() 32 | .setMaster("local[2]") 33 | .setAppName("MLlibUnitTest") 34 | sc = new SparkContext(conf) 35 | SQLContext.clearActive() 36 | sqlContext = new SQLContext(sc) 37 | SQLContext.setActive(sqlContext) 38 | } 39 | 40 | override def afterAll() { 41 | try { 42 | sqlContext = null 43 | SQLContext.clearActive() 44 | if (sc != null) { 45 | sc.stop() 46 | } 47 | sc = null 48 | } finally { 49 | super.afterAll() 50 | } 51 | } 52 | } 53 | --------------------------------------------------------------------------------
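The evaluateClassificationModel helpers above report only accuracy via MulticlassMetrics. For a binary task such as click-through-rate prediction, the same scored DataFrame could also be summarized with areaUnderROC, the metric the CrossValidator in the click-through-rate example optimizes. The following is a minimal sketch of that idea against the Spark 1.6 APIs used throughout this repository; it assumes the column names from the examples above ("probability", "prediction", "Label"), and the object and method names are illustrative rather than part of the repository.

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame

// Illustrative helper (not a file in this repository): computes areaUnderROC
// from a DataFrame that has already been scored by a fitted model, i.e. the
// output of model.transform(testDF) in the examples above.
object BinaryEvaluationSketch {
  def areaUnderROC(scored: DataFrame, labelColName: String = "Label"): Double = {
    // Pair P(label = 1), taken from the "probability" vector column, with the true label.
    val scoreAndLabels = scored
      .select("probability", labelColName)
      .rdd
      .map(row => (row.getAs[Vector](0)(1), row.getDouble(1)))
    new BinaryClassificationMetrics(scoreAndLabels).areaUnderROC()
  }
}

For instance, after val cvModel = cv.fit(trainDF) in the click-through-rate example, BinaryEvaluationSketch.areaUnderROC(cvModel.transform(testDF)) would report the held-out areaUnderROC alongside the accuracy printed by evaluateClassificationModel.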