├── .travis.yaml ├── LICENSE ├── README.md ├── bin └── click-through-rate-predictino.sh ├── build.sbt ├── build ├── sbt └── sbt-launch-lib.bash ├── dev ├── lint-scala └── sbt-assembly-skip-test.sh ├── project ├── build.properties └── plugins.sbt ├── scalastyle-config.xml └── src ├── main └── scala │ └── org │ └── apache │ └── spark │ └── examples │ └── kaggle │ └── ClickThroughRatePrediction.scala └── test ├── resources ├── test.part-10000 └── train.part-10000 └── scala └── org └── apache └── spark ├── SparkFunSuite.scala ├── examples └── kaggle │ └── ClickThroughRatePrediction.scala └── util ├── LocalClusterSparkContext.scala └── MLlibTestSparkContext.scala /.travis.yaml: -------------------------------------------------------------------------------- 1 | language: scala 2 | sudo: false 3 | cache: 4 | directories: 5 | - $HOME/.ivy2 6 | matrix: 7 | include: 8 | - jdk: openjdk7 9 | scala: 2.10.5 10 | env: TEST_SPARK_VERSION="1.6.0" 11 | - jdk: openjdk8 12 | scala: 2.10.5 13 | env: TEST_SPARK_VERSION="1.6.0" 14 | - jdk: openjdk7 15 | scala: 2.11.7 16 | env: TEST_SPARK_VERSION="1.6.0" 17 | - jdk: openjdk8 18 | scala: 2.11.7 19 | env: TEST_SPARK_VERSION="1.6.0" 20 | script: 21 | - sbt -Dspark.testVersion=$TEST_SPARK_VERSION ++$TRAVIS_SCALA_VERSION coverage test 22 | - sbt ++$TRAVIS_SCALA_VERSION scalastyle 23 | after_success: 24 | - bash <(curl -s https://codecov.io/bash) 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License, Version 2.0 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 13 | 14 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 15 | 16 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 17 | 18 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 19 | 20 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 21 | 22 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 
23 | 24 | "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 25 | 26 | "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 27 | 28 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 29 | 30 | 2. Grant of Copyright License. 31 | 32 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 33 | 34 | 3. Grant of Patent License. 35 | 36 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 37 | 38 | 4. Redistribution. 
39 | 40 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 41 | 42 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 43 | You must cause any modified files to carry prominent notices stating that You changed the files; and 44 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 45 | If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 46 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 47 | 48 | 5. Submission of Contributions. 49 | 50 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 51 | 52 | 6. Trademarks. 53 | 54 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 55 | 56 | 7. Disclaimer of Warranty. 57 | 58 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 59 | 60 | 8. Limitation of Liability. 
61 |
62 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
63 |
64 | 9. Accepting Warranty or Additional Liability.
65 |
66 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
67 |
68 | END OF TERMS AND CONDITIONS
69 |
70 | APPENDIX: How to apply the Apache License to your work
71 |
72 | To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
73 |
74 | Copyright [yyyy] [name of copyright owner]
75 |
76 | Licensed under the Apache License, Version 2.0 (the "License");
77 | you may not use this file except in compliance with the License.
78 | You may obtain a copy of the License at
79 |
80 | http://www.apache.org/licenses/LICENSE-2.0
81 |
82 | Unless required by applicable law or agreed to in writing, software
83 | distributed under the License is distributed on an "AS IS" BASIS,
84 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
85 | See the License for the specific language governing permissions and
86 | limitations under the License.
87 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Try Kaggle's Click Through Rate Prediction with Spark Pipeline API
2 |
3 | The purpose of this Spark application is to test the Spark Pipeline API with real data for [SPARK-13239](https://issues.apache.org/jira/browse/SPARK-13239).
4 | So we tested the ML Pipeline API with Kaggle's click-through rate prediction.
5 |
6 |
7 | ## Build & Run
8 | You can build this Spark application with `sbt clean assembly`.
9 | And you can run it with the following command.
10 |
11 | ```
12 | $SPARK_HOME/bin/spark-submit \
13 |   --class org.apache.spark.examples.kaggle.ClickThroughRatePrediction \
14 |   /path/to/click-through-rate-prediction-assembly-1.1.jar \
15 |   --train=/path/to/train \
16 |   --test=/path/to/test \
17 |   --result=/path/to/result.csv
18 | ```
19 |
20 | - `--train`: the training data you downloaded
21 | - `--test`: the test data you downloaded
22 | - `--result`: the result file
23 |
24 | Note that Spark can't write the result out as a single file directly.
25 | However, this application repartitions the result DataFrame into a single partition, so the result is aggregated into one file.
26 | You can therefore get the result CSV file as `part-00000` under the path you set with the `--result` option.
27 |
28 | ## The Kaggle Contest
29 |
30 | > Predict whether a mobile ad will be clicked
31 | > In online advertising, click-through rate (CTR) is a very important metric for evaluating ad performance. As a result, click prediction systems are essential and widely used for sponsored search and real-time bidding.
32 |
33 | https://www.kaggle.com/c/avazu-ctr-prediction
34 |
35 |
36 | ## Approach
37 |
38 | 1. Extract features from the categorical variables with `StringIndexer` followed by `OneHotEncoder`
39 | 2. Train a model with `LogisticRegression` tuned by `CrossValidator`
40 |   - The `Evaluator` of the `CrossValidator` is the default `BinaryClassificationEvaluator`.
41 |
42 | We merged the training data with the test data in the feature extraction phase, since the test data includes values which don't exist in the training data.
43 | Merging the two data sets lets us avoid errors caused by such unseen values of each variable when extracting features from the test data.
44 | The sketch below illustrates this feature extraction flow.
45 |
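The following is a minimal, self-contained sketch of that flow (Spark 1.6 APIs, e.g. in `spark-shell`, where a `SQLContext` named `sqlContext` is in scope); the toy rows are made up for illustration, and only `site_category` is a real column from the contest schema:

```
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}

// Toy data standing in for the union of the training and test data.
val df = sqlContext.createDataFrame(Seq(
  (0, "news"), (1, "sports"), (2, "news"), (3, "shopping")
)).toDF("id", "site_category")

// Index the string values, one-hot encode the indices, and assemble
// the encoded vectors into a single `features` column.
val indexer = new StringIndexer()
  .setInputCol("site_category")
  .setOutputCol("site_category_indexed")
val encoder = new OneHotEncoder()
  .setInputCol("site_category_indexed")
  .setOutputCol("site_category_vec")
val assembler = new VectorAssembler()
  .setInputCols(Array("site_category_vec"))
  .setOutputCol("features")

val pipeline = new Pipeline().setStages(Array(indexer, encoder, assembler))
pipeline.fit(df).transform(df).select("site_category", "features").show()
```

The application builds exactly this kind of `StringIndexer`/`OneHotEncoder` pair for each categorical column and fits the pipeline on the union, so the indexers see every value that appears in either data set.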
46 | ## Result
47 |
48 | I got a score of `0.3998684` with the following parameter set.
49 |
50 | - Logistic Regression
51 |   - `featuresCol`: features
52 |   - `fitIntercept`: true
53 |   - `labelCol`: label
54 |   - `maxIter`: 100
55 |   - `predictionCol`: prediction
56 |   - `probabilityCol`: probability
57 |   - `rawPredictionCol`: rawPrediction
58 |   - `regParam`: 0.001
59 |   - `standardization`: true
60 |   - `threshold`: 0.22
61 |   - `tol`: 1.0E-6
62 |   - `weightCol`:
63 |
64 | ## TODO
65 |
66 | We should offer more `Evaluator`s, such as log-loss.
67 | Since `spark.ml` doesn't offer a log-loss (logistic loss) evaluator as of Spark 1.6, we might get a better score with one; a hand-rolled sketch follows.
68 |
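As a stopgap, the log-loss can be computed by hand from the fitted model's output. The following is a minimal sketch (Spark 1.6 APIs): the `label` and `probability` column names match the ones this application produces, but the `logLoss` helper itself is our own illustration, not a `spark.ml` API, and it would still need to be wrapped in a custom `Evaluator` to drive `CrossValidator`:

```
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{avg, col, udf}

// Mean logistic loss: -mean(y * log(p) + (1 - y) * log(1 - p)), where p is
// the predicted probability of a click. Probabilities are clipped away from
// 0 and 1 to keep the logarithms finite.
def logLoss(predictions: DataFrame): Double = {
  val eps = 1e-15
  val loss = udf { (label: Double, probability: Vector) =>
    val p = math.max(eps, math.min(1.0 - eps, probability(1)))
    -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))
  }
  predictions
    .select(avg(loss(col("label"), col("probability"))).as("logLoss"))
    .first().getDouble(0)
}
```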
18 | # 19 | 20 | THIS_PROJECT_HOME="$( cd $(dirname $( dirname $0)) && pwd )" 21 | 22 | PACKAGED_JAR=$(find ${THIS_PROJECT_HOME}/target -name "click-through-rate-prediction-assembly*.jar") 23 | 24 | ${SPARK_HOME}/bin/spark-submit \ 25 | --class "org.apache.spark.examples.kaggle.ClickThroughRatePrediction" \ 26 | "$PACKAGED_JAR" \ 27 | --train "s3n://s3-yu-ishikawa/test-data/click-through-rate-prediction/train" \ 28 | --test "s3n://s3-yu-ishikawa/test-data/click-through-rate-prediction/test" \ 29 | --result "s3n://s3-yu-ishikawa/test-data/click-through-rate-prediction/result/" 30 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | // Your sbt build file. Guides on how to write one can be found at 2 | // http://www.scala-sbt.org/0.13/docs/index.html 3 | 4 | scalaVersion := "2.10.5" 5 | 6 | sparkVersion := "1.6.0" 7 | 8 | crossScalaVersions := Seq("2.10.5", "2.11.7") 9 | 10 | spName := "yu-iskw/click-through-rate-prediction" 11 | 12 | // Don't forget to set the version 13 | version := "1.1" 14 | 15 | spAppendScalaVersion := true 16 | 17 | spIncludeMaven := true 18 | 19 | spIgnoreProvided := true 20 | 21 | // Can't parallelly execute in test 22 | parallelExecution in Test := false 23 | 24 | fork in Test := true 25 | 26 | javaOptions ++= Seq("-Xmx2G", "-XX:MaxPermSize=256m") 27 | 28 | // All Spark Packages need a license 29 | licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0")) 30 | 31 | // Add Spark components this package depends on, e.g, "mllib", .... 32 | sparkComponents ++= Seq("sql", "mllib") 33 | 34 | libraryDependencies ++= Seq( 35 | "org.scalatest" %% "scalatest" % "2.1.5" % "test", 36 | "com.github.scopt" % "scopt_2.10" % "3.3.0" 37 | ) 38 | 39 | 40 | // uncomment and change the value below to change the directory where your zip artifact will be created 41 | // spDistDirectory := target.value 42 | 43 | // add any Spark Package dependencies using spDependencies. 44 | // e.g. spDependencies += "databricks/spark-avro:0.1" 45 | spDependencies += "databricks/spark-csv:1.3.0-s_2.10" 46 | -------------------------------------------------------------------------------- /build/sbt: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so 4 | # that we can run Hive to generate the golden answer. This is not required for normal development 5 | # or testing. 6 | for i in $HIVE_HOME/lib/* 7 | do HADOOP_CLASSPATH=$HADOOP_CLASSPATH:$i 8 | done 9 | export HADOOP_CLASSPATH 10 | 11 | realpath () { 12 | ( 13 | TARGET_FILE=$1 14 | 15 | cd $(dirname $TARGET_FILE) 16 | TARGET_FILE=$(basename $TARGET_FILE) 17 | 18 | COUNT=0 19 | while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] 20 | do 21 | TARGET_FILE=$(readlink $TARGET_FILE) 22 | cd $(dirname $TARGET_FILE) 23 | TARGET_FILE=$(basename $TARGET_FILE) 24 | COUNT=$(($COUNT + 1)) 25 | done 26 | 27 | echo $(pwd -P)/$TARGET_FILE 28 | ) 29 | } 30 | 31 | . 
31 | . $(dirname $(realpath $0))/sbt-launch-lib.bash
32 |
33 |
34 | declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy"
35 | declare -r sbt_opts_file=".sbtopts"
36 | declare -r etc_sbt_opts_file="/etc/sbt/sbtopts"
37 |
38 | usage() {
39 |  cat <<EOM
40 | Usage: $script_name [options]
41 |
42 |   -h | -help         print this message
43 |   -v | -verbose      this runner is chattier
44 |   -d | -debug        set sbt log level to debug
45 |   -no-colors         disable ANSI color codes
46 |   -sbt-dir   <path>  path to global settings/plugins directory (default: ~/.sbt)
47 |   -sbt-boot  <path>  path to shared boot directory (default: ~/.sbt/boot in 0.11 series)
48 |   -ivy       <path>  path to local Ivy repository (default: ~/.ivy2)
49 |   -mem    <integer>  set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem))
50 |   -no-share          use all local caches; no sharing
51 |   -no-global         uses global caches, but does not use global ~/.sbt directory.
52 |   -jvm-debug <port>  Turn on JVM debugging, open at the given port.
53 |   -batch             Disable interactive mode
54 |   # sbt version (default: from project/build.properties if present, else latest release)
55 |   -sbt-version <version>  use the specified version of sbt
56 |   -sbt-jar     <path>     use the specified jar as the sbt launcher
57 |   -sbt-rc                 use an RC version of sbt
58 |   -sbt-snapshot           use a snapshot version of sbt
59 |   # java version (default: java from PATH, currently $(java -version 2>&1 | grep version))
60 |   -java-home <path>       alternate JAVA_HOME
61 |   # jvm options and output control
62 |   JAVA_OPTS          environment variable, if unset uses "$java_opts"
63 |   SBT_OPTS           environment variable, if unset uses "$default_sbt_opts"
64 |   .sbtopts           if this file exists in the current directory, it is
65 |                      prepended to the runner args
66 |   /etc/sbt/sbtopts   if this file exists, it is prepended to the runner args
67 |   -Dkey=val          pass -Dkey=val directly to the java runtime
68 |   -J-X               pass option -X directly to the java runtime
69 |                      (-J is stripped)
70 |   -S-X               add -X to sbt's scalacOptions (-J is stripped)
71 |   -PmavenProfiles    Enable a maven profile for the build.
72 | In the case of duplicated or conflicting options, the order above
73 | shows precedence: JAVA_OPTS lowest, command line options highest.
74 | EOM
75 | }
76 |
77 | process_my_args () {
78 |   while [[ $# -gt 0 ]]; do
79 |     case "$1" in
80 |       -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;;
81 |       -no-share) addJava "$noshare_opts" && shift ;;
82 |       -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;;
83 |       -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;;
84 |       -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;;
85 |       -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;;
86 |       -batch) exec </dev/null && shift ;;
87 |
88 |       *) addResidual "$1" && shift ;;
89 |     esac
90 |   done
91 |
92 |   # Now, ensure sbt version is used.
93 |   [[ "${sbt_version}XXX" != "XXX" ]] && addJava "-Dsbt.version=$sbt_version"
94 | }
95 |
96 | loadConfigFile() {
97 |   cat "$1" | sed '/^\#/d'
98 | }
99 |
100 | # if sbtopts files exist, prepend their contents to $@ so it can be processed by this runner
101 | [[ -f "$etc_sbt_opts_file" ]] && set -- $(loadConfigFile "$etc_sbt_opts_file") "$@"
102 | [[ -f "$sbt_opts_file" ]] && set -- $(loadConfigFile "$sbt_opts_file") "$@"
103 |
104 | run "$@"
105 |
-------------------------------------------------------------------------------- /build/sbt-launch-lib.bash: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 |
4 | # A library to simplify using the SBT launcher from other packages.
5 | # Note: This should be used by tools like giter8/conscript etc.
6 |
7 | # TODO - Should we merge the main SBT script with this library?
8 |
9 | if test -z "$HOME"; then
10 |   declare -r script_dir="$(pwd)/build"
11 | else
12 |   declare -r script_dir="$HOME/.sbt"
13 | fi
14 |
15 | declare -a residual_args
16 | declare -a java_args
17 | declare -a scalac_args
18 | declare -a sbt_commands
19 | declare -a maven_profiles
20 |
21 | if test -x "$JAVA_HOME/bin/java"; then
22 |   echo -e "Using $JAVA_HOME as default JAVA_HOME."
23 |   echo "Note, this will be overridden by -java-home if it is set."
24 |   declare java_cmd="$JAVA_HOME/bin/java"
25 | else
26 |   declare java_cmd=java
27 | fi
28 |
29 | echoerr () {
30 |   echo 1>&2 "$@"
31 | }
32 | vlog () {
33 |   [[ $verbose || $debug ]] && echoerr "$@"
34 | }
35 | dlog () {
36 |   [[ $debug ]] && echoerr "$@"
37 | }
38 |
39 | acquire_sbt_jar () {
40 |   SBT_VERSION=`awk -F "=" '/sbt\\.version/ {print $2}' ./project/build.properties`
41 |   URL1=https://dl.bintray.com/typesafe/ivy-releases/org.scala-sbt/sbt-launch/${SBT_VERSION}/sbt-launch.jar
42 |   JAR=build/sbt-launch-${SBT_VERSION}.jar
43 |
44 |   sbt_jar=$JAR
45 |
46 |   if [[ ! -f "$sbt_jar" ]]; then
47 |     # Download sbt launch jar if it hasn't been downloaded yet
48 |     if [ ! -f ${JAR} ]; then
49 |       # Download
50 |       printf "Attempting to fetch sbt\n"
51 |       JAR_DL=${JAR}.part
52 |       if hash curl 2>/dev/null; then
53 |         curl --fail --location --silent ${URL1} > "${JAR_DL}" &&\
54 |           mv "${JAR_DL}" "${JAR}"
55 |       elif hash wget 2>/dev/null; then
56 |         wget --quiet ${URL1} -O "${JAR_DL}" &&\
57 |           mv "${JAR_DL}" "${JAR}"
58 |       else
59 |         printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
60 |         exit -1
61 |       fi
62 |     fi
63 |     if [ ! -f ${JAR} ]; then
64 |       # We failed to download
65 |       printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n"
66 |       exit -1
67 |     fi
68 |     printf "Launching sbt from ${JAR}\n"
69 |   fi
70 | }
71 |
72 | execRunner () {
73 |   # print the arguments one to a line, quoting any containing spaces
74 |   [[ $verbose || $debug ]] && echo "# Executing command line:" && {
75 |     for arg; do
76 |       if printf "%s\n" "$arg" | grep -q ' '; then
77 |         printf "\"%s\"\n" "$arg"
78 |       else
79 |         printf "%s\n" "$arg"
80 |       fi
81 |     done
82 |     echo ""
83 |   }
84 |
85 |   exec "$@"
86 | }
87 |
88 | addJava () {
89 |   dlog "[addJava] arg = '$1'"
90 |   java_args=( "${java_args[@]}" "$1" )
91 | }
92 |
93 | enableProfile () {
94 |   dlog "[enableProfile] arg = '$1'"
95 |   maven_profiles=( "${maven_profiles[@]}" "$1" )
96 |   export SBT_MAVEN_PROFILES="${maven_profiles[@]}"
97 | }
98 |
99 | addSbt () {
100 |   dlog "[addSbt] arg = '$1'"
101 |   sbt_commands=( "${sbt_commands[@]}" "$1" )
102 | }
103 | addResidual () {
104 |   dlog "[residual] arg = '$1'"
105 |   residual_args=( "${residual_args[@]}" "$1" )
106 | }
107 | addDebugger () {
108 |   addJava "-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1"
109 | }
110 |
111 | # a ham-fisted attempt to move some memory settings in concert
112 | # so they need not be dicked around with individually.
113 | get_mem_opts () {
114 |   local mem=${1:-2048}
115 |   local perm=$(( $mem / 4 ))
116 |   (( $perm > 256 )) || perm=256
117 |   (( $perm < 4096 )) || perm=4096
118 |   local codecache=$(( $perm / 2 ))
119 |
120 |   echo "-Xms${mem}m -Xmx${mem}m -XX:MaxPermSize=${perm}m -XX:ReservedCodeCacheSize=${codecache}m"
121 | }
122 |
123 | require_arg () {
124 |   local type="$1"
125 |   local opt="$2"
126 |   local arg="$3"
127 |   if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then
128 |     die "$opt requires <$type> argument"
129 |   fi
130 | }
131 |
132 | is_function_defined() {
133 |   declare -f "$1" > /dev/null
134 | }
135 |
136 | process_args () {
137 |   while [[ $# -gt 0 ]]; do
138 |     case "$1" in
139 |       -h|-help) usage; exit 1 ;;
140 |       -v|-verbose) verbose=1 && shift ;;
141 |       -d|-debug) debug=1 && shift ;;
142 |
143 |       -ivy) require_arg path "$1" "$2" && addJava "-Dsbt.ivy.home=$2" && shift 2 ;;
144 |       -mem) require_arg integer "$1" "$2" && sbt_mem="$2" && shift 2 ;;
145 |       -jvm-debug) require_arg port "$1" "$2" && addDebugger $2 && shift 2 ;;
146 |       -batch) exec </dev/null && shift ;;
147 |
148 |       -sbt-jar) require_arg path "$1" "$2" && sbt_jar="$2" && shift 2 ;;
149 |       -sbt-version) require_arg version "$1" "$2" && sbt_version="$2" && shift 2 ;;
150 |       -java-home) require_arg path "$1" "$2" && java_cmd="$2/bin/java" && export JAVA_HOME=$2 && shift 2 ;;
151 |
152 |       -D*) addJava "$1" && shift ;;
153 |       -J*) addJava "${1:2}" && shift ;;
154 |       -P*) enableProfile "$1" && shift ;;
155 |       *) addResidual "$1" && shift ;;
156 |     esac
157 |   done
158 |
159 |   is_function_defined process_my_args && {
160 |     myargs=("${residual_args[@]}")
161 |     residual_args=()
162 |     process_my_args "${myargs[@]}"
163 |   }
164 | }
165 |
166 | run() {
167 |   # no jar? download it.
168 |   [[ -f "$sbt_jar" ]] || acquire_sbt_jar || {
169 |     # still no jar? uh-oh.
170 |     echo "Download failed. Obtain the jar manually and place it at $sbt_jar"
171 |     exit 1
172 |   }
173 |
174 |   # process the combined args, then process in this runner
175 |   process_args "$@"
176 |
177 |   # run sbt
178 |   execRunner "$java_cmd" \
179 |     ${SBT_OPTS:-$default_sbt_opts} \
180 |     $(get_mem_opts $sbt_mem) \
181 |     ${java_opts} \
182 |     ${java_args[@]} \
183 |     -jar "$sbt_jar" \
184 |     "${sbt_commands[@]}" \
185 |     "${residual_args[@]}"
186 | }
-------------------------------------------------------------------------------- /dev/lint-scala: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #
4 | # Licensed to the Apache Software Foundation (ASF) under one or more
5 | # contributor license agreements. See the NOTICE file distributed with
6 | # this work for additional information regarding copyright ownership.
7 | # The ASF licenses this file to You under the Apache License, Version 2.0
8 | # (the "License"); you may not use this file except in compliance with
9 | # the License. You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 | #
19 |
20 | SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
21 | PROJECT_HOME="$(dirname $SCRIPT_DIR)"
22 |
23 | # Run Scalastyle and write its report to scalastyle.txt
24 | sbt scalastyle > scalastyle.txt
25 |
26 | ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}')
27 | rm scalastyle.txt
28 |
29 | if test ! -z "$ERRORS"; then
30 |     echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS"
31 |     exit 1
32 | else
33 |     echo -e "Scalastyle checks passed."
34 | fi
35 |
-------------------------------------------------------------------------------- /dev/sbt-assembly-skip-test.sh: --------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 |
4 | #
5 | # Licensed to the Apache Software Foundation (ASF) under one or more
6 | # contributor license agreements. See the NOTICE file distributed with
7 | # this work for additional information regarding copyright ownership.
8 | # The ASF licenses this file to You under the Apache License, Version 2.0
9 | # (the "License"); you may not use this file except in compliance with
10 | # the License. You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | #
20 |
21 | SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
22 | PROJECT_HOME="$(dirname $SCRIPT_DIR)"
23 |
24 | sbt "set test in assembly := {}" clean assembly
25 |
-------------------------------------------------------------------------------- /project/build.properties: --------------------------------------------------------------------------------
1 | # This file should only contain the version of sbt to use.
2 | sbt.version=0.13.6
3 |
-------------------------------------------------------------------------------- /project/plugins.sbt: --------------------------------------------------------------------------------
1 | // You may use this file to add plugin dependencies for sbt.
2 | resolvers += "bintray-spark-packages" at "https://dl.bintray.com/spark-packages/maven/"
3 |
4 | resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/"
5 |
6 | addSbtPlugin("org.spark-packages" %% "sbt-spark-package" % "0.2.3")
7 |
8 | addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0")
9 |
10 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.1.0")
11 |
-------------------------------------------------------------------------------- /scalastyle-config.xml: --------------------------------------------------------------------------------
1 | <!-- Scalastyle standard configuration. Notable checks:
2 |      * whitespace is enforced around tokens such as ARROW, EQUALS, COMMA, COLON, IF, ELSE, DO, WHILE, FOR, MATCH, TRY, CATCH, FINALLY, LARROW, RARROW
3 |      * test suites must not extend FunSuite directly ("Tests must extend org.apache.spark.SparkFunSuite instead.")
4 |      * println and Class.forName are banned
5 |      * at most 800 lines per file, 30 methods per type, 10 parameters per method, and 50 lines per method
6 |      * magic numbers are checked with the ignore list -1,0,1,2,3
7 | -->
-------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/kaggle/ClickThroughRatePrediction.scala: --------------------------------------------------------------------------------
1 | /*
2 |  * Licensed to the Apache Software Foundation (ASF) under one or more
3 |  * contributor license agreements. See the NOTICE file distributed with
4 |  * this work for additional information regarding copyright ownership.
5 |  * The ASF licenses this file to You under the Apache License, Version 2.0
6 |  * (the "License"); you may not use this file except in compliance with
7 |  * the License. You may obtain a copy of the License at
8 |  *
9 |  *    http://www.apache.org/licenses/LICENSE-2.0
10 |  *
11 |  * Unless required by applicable law or agreed to in writing, software
12 |  * distributed under the License is distributed on an "AS IS" BASIS,
13 |  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 |  * See the License for the specific language governing permissions and
15 |  * limitations under the License.
16 |  */
17 |
18 | package org.apache.spark.examples.kaggle
19 |
20 | import scala.collection.mutable.ArrayBuffer
21 |
22 | import scopt.OptionParser
23 |
24 | import org.apache.spark.{SparkConf, SparkContext}
25 | import org.apache.spark.ml.{Pipeline, PipelineStage}
26 | import org.apache.spark.ml.classification.LogisticRegression
27 | import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
28 | import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
29 | import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
30 | import org.apache.spark.mllib.linalg.Vector
31 | import org.apache.spark.sql.{Row, SQLContext, SaveMode}
32 | import org.apache.spark.sql.functions.col
33 | import org.apache.spark.sql.types._
34 |
35 | // scalastyle:off println
36 | object ClickThroughRatePrediction {
37 |
38 |   private val trainSchema = StructType(Array(
39 |     StructField("id", StringType, false),
40 |     StructField("click", IntegerType, true),
41 |     StructField("hour", IntegerType, true),
42 |     StructField("C1", IntegerType, true),
43 |     StructField("banner_pos", IntegerType, true),
44 |     StructField("site_id", StringType, true),
45 |     StructField("site_domain", StringType, true),
46 |     StructField("site_category", StringType, true),
47 |     StructField("app_id", StringType, true),
48 |     StructField("app_domain", StringType, true),
49 |     StructField("app_category", StringType, true),
50 |     StructField("device_id", StringType, true),
51 |     StructField("device_ip", StringType, true),
52 |     StructField("device_model", StringType, true),
53 |     StructField("device_type", IntegerType, true),
54 |     StructField("device_conn_type", IntegerType, true),
55 |     StructField("C14", IntegerType, true),
56 |     StructField("C15", IntegerType, true),
57 |     StructField("C16", IntegerType, true),
58 |     StructField("C17", IntegerType, true),
59 |     StructField("C18", IntegerType, true),
StructField("C18", IntegerType, true), 60 | StructField("C19", IntegerType, true), 61 | StructField("C20", IntegerType, true), 62 | StructField("C21", IntegerType, true) 63 | )) 64 | 65 | private val testSchema = StructType(Array( 66 | StructField("id", StringType, false), 67 | StructField("hour", IntegerType, true), 68 | StructField("C1", IntegerType, true), 69 | StructField("banner_pos", IntegerType, true), 70 | StructField("site_id", StringType, true), 71 | StructField("site_domain", StringType, true), 72 | StructField("site_category", StringType, true), 73 | StructField("app_id", StringType, true), 74 | StructField("app_domain", StringType, true), 75 | StructField("app_category", StringType, true), 76 | StructField("device_id", StringType, true), 77 | StructField("device_ip", StringType, true), 78 | StructField("device_model", StringType, true), 79 | StructField("device_type", IntegerType, true), 80 | StructField("device_conn_type", IntegerType, true), 81 | StructField("C14", IntegerType, true), 82 | StructField("C15", IntegerType, true), 83 | StructField("C16", IntegerType, true), 84 | StructField("C17", IntegerType, true), 85 | StructField("C18", IntegerType, true), 86 | StructField("C19", IntegerType, true), 87 | StructField("C20", IntegerType, true), 88 | StructField("C21", IntegerType, true) 89 | )) 90 | 91 | case class ClickThroughRatePredictionParams( 92 | trainInput: String = null, 93 | testInput: String = null, 94 | resultOutput: String = null 95 | ) 96 | 97 | /** 98 | * Try Kaggle's Click-Through Rate Prediction with Logistic Regression Classification 99 | * Run with 100 | * {{ 101 | * $SPARK_HOME/bin/spark-submit \ 102 | * --class org.apache.spark.examples.kaggle.ClickThroughRatePredictionWitLogisticRegression \ 103 | * /path/to/click-through-rate-prediction-assembly-1.1.jar \ 104 | * --train=/path/to/train \ 105 | * --test=/path/to/test \ 106 | * --result=/path/to/result.csv 107 | * }} 108 | * SEE ALSO: https://www.kaggle.com/c/avazu-ctr-prediction 109 | */ 110 | def main(args: Array[String]): Unit = { 111 | val conf = new SparkConf().setAppName(this.getClass.getSimpleName) 112 | val sc = new SparkContext(conf) 113 | val sqlContext = new SQLContext(sc) 114 | 115 | val defaultParam = new ClickThroughRatePredictionParams() 116 | val parser = new OptionParser[ClickThroughRatePredictionParams](this.getClass.getSimpleName) { 117 | head(s"${this.getClass.getSimpleName}: Try a Kaggle competition.") 118 | opt[String]("train") 119 | .text("train input") 120 | .action((x, c) => c.copy(trainInput = x)) 121 | .required() 122 | opt[String]("test") 123 | .text("test input") 124 | .action((x, c) => c.copy(testInput = x)) 125 | .required() 126 | opt[String]("result") 127 | .text("result output path") 128 | .action((x, c) => c.copy(resultOutput = x)) 129 | .required() 130 | } 131 | parser.parse(args, defaultParam).map { params => 132 | run(sc, sqlContext, params.trainInput, params.testInput, params.resultOutput) 133 | } getOrElse { 134 | sys.exit(1) 135 | } 136 | sc.stop() 137 | } 138 | 139 | def run(sc: SparkContext, sqlContext: SQLContext, 140 | trainPath: String, testPath: String, resultPath: String): Unit = { 141 | import sqlContext.implicits._ 142 | 143 | // Sets the target variables 144 | val targetVariables = Array( 145 | "banner_pos", "site_id", "site_domain", "site_category", 146 | "app_domain", "app_category", "device_model", "device_type", "device_conn_type", 147 | "C1", "C14", "C15", "C16", "C17", "C18", "C19", "C20", "C21" 148 | ) 149 | 150 | // Loads training data and testing 
151 |     val train = sqlContext.read.format("com.databricks.spark.csv")
152 |       .option("header", "true")
153 |       .schema(trainSchema)
154 |       .load(trainPath).cache()
155 |     val test = sqlContext.read.format("com.databricks.spark.csv")
156 |       .option("header", "true")
157 |       .schema(testSchema)
158 |       .load(testPath).cache()
159 |
160 |     // Union data for one-hot encoding
161 |     // To extract features thoroughly, union the training and test data,
162 |     // since the test data includes values which don't exist in the training data.
163 |     val train4union = train.select(targetVariables.map(col): _*)
164 |     val test4union = test.select(targetVariables.map(col): _*)
165 |     val union = train4union.unionAll(test4union).cache()
166 |
167 |     // Extracts features with one-hot encoding
168 |     def getIndexedColumn(clm: String): String = s"${clm}_indexed"
169 |     def getColumnVec(clm: String): String = s"${clm}_vec"
170 |     val feStages = ArrayBuffer.empty[PipelineStage]
171 |     targetVariables.foreach { clm =>
172 |       val stringIndexer = new StringIndexer()
173 |         .setInputCol(clm)
174 |         .setOutputCol(getIndexedColumn(clm))
175 |         .setHandleInvalid("error")
176 |       val oneHotEncoder = new OneHotEncoder()
177 |         .setInputCol(getIndexedColumn(clm))
178 |         .setOutputCol(getColumnVec(clm))
179 |
180 |       feStages.append(stringIndexer)
181 |       feStages.append(oneHotEncoder)
182 |     }
183 |     val va = new VectorAssembler()
184 |       .setInputCols(targetVariables.map(getColumnVec))
185 |       .setOutputCol("features")
186 |     feStages.append(va)
187 |     val fePipeline = new Pipeline().setStages(feStages.toArray)
188 |     val feModel = fePipeline.fit(union)
189 |     val trainDF = feModel.transform(train).select('click, 'features).cache()
190 |     val testDF = feModel.transform(test).select('id, 'features).cache()
191 |     union.unpersist()
192 |     train.unpersist()
193 |     test.unpersist()
194 |
195 |     // Trains a model with CrossValidator
196 |     val si4click = new StringIndexer()
197 |       .setInputCol("click")
198 |       .setOutputCol("label")
199 |     val lr = new LogisticRegression()
200 |     val pipeline = new Pipeline().setStages(Array(si4click, lr))
201 |     val paramGrid = new ParamGridBuilder()
202 |       .addGrid(lr.threshold, Array(0.22))
203 |       .addGrid(lr.elasticNetParam, Array(0.0))
204 |       .addGrid(lr.regParam, Array(0.001))
205 |       .addGrid(lr.maxIter, Array(100))
206 |       .build()
207 |     val cv = new CrossValidator()
208 |       .setEstimator(pipeline)
209 |       .setEvaluator(new BinaryClassificationEvaluator())
210 |       .setEstimatorParamMaps(paramGrid)
211 |       .setNumFolds(3)
212 |     val cvModel = cv.fit(trainDF)
213 |
214 |     // Shows the best parameters
215 |     cvModel.bestModel.parent match {
216 |       case pipeline: Pipeline =>
217 |         pipeline.getStages.zipWithIndex.foreach { case (stage, index) =>
218 |           println(s"Stage[${index + 1}]: ${stage.getClass.getSimpleName}")
219 |           println(stage.extractParamMap())
220 |         }
221 |     }
222 |
223 |     // Predicts with the trained best model
224 |     val resultDF = cvModel.transform(testDF).select('id, 'probability).map {
225 |       case Row(id: String, probability: Vector) => (id, probability(1))
226 |     }.toDF("id", "click")
227 |
228 |     // Saves the result as a single CSV file
229 |     resultDF.repartition(1).write.mode(SaveMode.Overwrite)
230 |       .format("com.databricks.spark.csv")
231 |       .option("header", "true")
232 |       .save(resultPath)
233 |   }
234 | }
235 |
236 | // scalastyle:on println
237 |
-------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/SparkFunSuite.scala:
-------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark 19 | 20 | // scalastyle:off 21 | import org.scalatest.{FunSuite, Outcome} 22 | 23 | /** 24 | * Base abstract class for all unit tests in Spark for handling common functionality. 25 | */ 26 | private[spark] abstract class SparkFunSuite extends FunSuite with Logging { 27 | // scalastyle:on 28 | 29 | /** 30 | * Log the suite name and the test name before and after each test. 31 | * 32 | * Subclasses should never override this method. If they wish to run 33 | * custom code before and after each test, they should mix in the 34 | * {{org.scalatest.BeforeAndAfter}} trait instead. 35 | */ 36 | final protected override def withFixture(test: NoArgTest): Outcome = { 37 | val testName = test.text 38 | val suiteName = this.getClass.getName 39 | val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s") 40 | try { 41 | logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n") 42 | test() 43 | } finally { 44 | logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n") 45 | } 46 | } 47 | 48 | } 49 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/examples/kaggle/ClickThroughRatePrediction.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 
16 | */ 17 | 18 | package org.apache.spark.examples.kaggle 19 | 20 | import org.apache.spark.SparkFunSuite 21 | import org.apache.spark.util.MLlibTestSparkContext 22 | 23 | class ClickThroughRatePredictionSuite extends SparkFunSuite with MLlibTestSparkContext { 24 | 25 | test("run") { 26 | // Logger.getLogger("org").setLevel(Level.OFF) 27 | // Logger.getLogger("akka").setLevel(Level.OFF) 28 | 29 | val trainPath = this.getClass.getResource("/train.part-10000").getPath 30 | val testPath = this.getClass.getResource("/test.part-10000").getPath 31 | val resultPath = "./tmp/result/" 32 | 33 | ClickThroughRatePrediction.run(sc, sqlContext, trainPath, testPath, resultPath) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/util/LocalClusterSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.util 19 | 20 | import org.scalatest.{Suite, BeforeAndAfterAll} 21 | 22 | import org.apache.spark.{SparkConf, SparkContext} 23 | 24 | trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => 25 | @transient var sc: SparkContext = _ 26 | 27 | override def beforeAll() { 28 | super.beforeAll() 29 | val conf = new SparkConf() 30 | .setMaster("local-cluster[2, 1, 1024]") 31 | .setAppName("test-cluster") 32 | .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data 33 | sc = new SparkContext(conf) 34 | } 35 | 36 | override def afterAll() { 37 | try { 38 | if (sc != null) { 39 | sc.stop() 40 | } 41 | } finally { 42 | super.afterAll() 43 | } 44 | } 45 | } 46 | -------------------------------------------------------------------------------- /src/test/scala/org/apache/spark/util/MLlibTestSparkContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Licensed to the Apache Software Foundation (ASF) under one or more 3 | * contributor license agreements. See the NOTICE file distributed with 4 | * this work for additional information regarding copyright ownership. 5 | * The ASF licenses this file to You under the Apache License, Version 2.0 6 | * (the "License"); you may not use this file except in compliance with 7 | * the License. You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | 18 | package org.apache.spark.util 19 | 20 | import org.scalatest.{BeforeAndAfterAll, Suite} 21 | 22 | import org.apache.spark.sql.SQLContext 23 | import org.apache.spark.{SparkConf, SparkContext} 24 | 25 | trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => 26 | @transient var sc: SparkContext = _ 27 | @transient var sqlContext: SQLContext = _ 28 | 29 | override def beforeAll() { 30 | super.beforeAll() 31 | val conf = new SparkConf() 32 | .setMaster("local[2]") 33 | .setAppName("MLlibUnitTest") 34 | sc = new SparkContext(conf) 35 | SQLContext.clearActive() 36 | sqlContext = new SQLContext(sc) 37 | SQLContext.setActive(sqlContext) 38 | } 39 | 40 | override def afterAll() { 41 | try { 42 | sqlContext = null 43 | SQLContext.clearActive() 44 | if (sc != null) { 45 | sc.stop() 46 | } 47 | sc = null 48 | } finally { 49 | super.afterAll() 50 | } 51 | } 52 | } 53 | --------------------------------------------------------------------------------