├── .gitignore ├── project ├── build.properties └── plugins.sbt ├── src └── main │ ├── scala │ └── com │ │ └── github │ │ └── mlnick │ │ ├── glintfm │ │ ├── package.scala │ │ ├── Config.scala │ │ └── GlintFM.scala │ │ └── RunTests.scala │ └── resources │ ├── glintfm.conf │ └── log4j.properties ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | target 3 | *.iml -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 0.13.8 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | logLevel := Level.Warn 2 | 3 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.13.0") -------------------------------------------------------------------------------- /src/main/scala/com/github/mlnick/glintfm/package.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.util.collection 2 | 3 | /** 4 | * Created by nick on 2016/10/11. 
5 | */ 6 | package object glintfm { 7 | 8 | type OHashMap[K, V] = OpenHashMap[K, V] 9 | } 10 | -------------------------------------------------------------------------------- /src/main/scala/com/github/mlnick/RunTests.scala: -------------------------------------------------------------------------------- 1 | package com.github.mlnick 2 | 3 | import com.github.mlnick.glintfm.GlintFM 4 | 5 | object RunTests extends App { 6 | 7 | val inputPath = "rcv1_train.binary" 8 | val format = "libsvm" 9 | val configPath = "src/main/resources/glintfm.conf" 10 | GlintFM.runTest(inputPath, configPath, format, parts = -1, models = 2, msgSize = 50000, timeout = 300, 11 | fitIntercept = true, fitLinear = true, k = 4, numIterations = 10, runML = false, runMLLIB = true, runGlint = true) 12 | } 13 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Glint FM 2 | 3 | ## Factorization Machines on Spark and Glint 4 | 5 | An implementation of distributed factorization machines on Spark using the [Glint parameter server](https://github.com/rjagerman/glint). 6 | 7 | To build, run `sbt package`. You will need to have Glint and spark-libFM installed locally, since they are not available on Maven central. 8 | 9 | You will need to clone the [glint repository](https://github.com/rjagerman/glint) and run `sbt publish-local` (master branch should be fine). 10 | 11 | For [spark-libFM](https://github.com/zhengruifeng/spark-libFM) you will need to build it with Spark 2.0 support - check out my branch [here](https://github.com/MLnick/spark-libFM/tree/spark20) which was used for performance comparisons. 12 | Again, run `sbt publish-local` to install locally before running glint-fm. 
13 | -------------------------------------------------------------------------------- /src/main/resources/glintfm.conf: -------------------------------------------------------------------------------- 1 | # Place custom configuration for your glint cluster in here 2 | glint.master.host = "127.0.0.1" 3 | glint.master.port = 13370 4 | glint { 5 | server.akka.loglevel = "DEBUG" 6 | server.akka.stdout-loglevel = "DEBUG" 7 | client.akka.loglevel = "INFO" 8 | client.akka.stdout-loglevel = "INFO" 9 | master.akka.loglevel = "INFO" 10 | master.akka.stdout-loglevel = "INFO" 11 | master.akka.remote.log-remote-lifecycle-events = on 12 | server.akka.remote.log-remote-lifecycle-events = off 13 | client.akka.remote.log-remote-lifecycle-events = on 14 | client.timeout = 30 s 15 | 16 | master.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s 17 | server.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s 18 | client.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s 19 | master.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s 20 | server.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s 21 | client.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s 22 | 23 | server.akka.remote.netty.tcp.maximum-frame-size = 32m 24 | client.akka.remote.netty.tcp.maximum-frame-size = 32m 25 | server.akka.remote.netty.tcp.send-buffer-size = 32m 26 | client.akka.remote.netty.tcp.send-buffer-size = 32m 27 | server.akka.remote.netty.tcp.receive-buffer-size = 32m 28 | client.akka.remote.netty.tcp.receive-buffer-size = 32m 29 | } 30 | 31 | -------------------------------------------------------------------------------- /src/main/scala/com/github/mlnick/glintfm/Config.scala: -------------------------------------------------------------------------------- 1 | package com.github.mlnick.glintfm 2 | 3 | /** 4 | * Config for testing 5 | */ 6 | object Config { 7 | 8 | val config = 9 | 
""" 10 | |glint.master.host = "127.0.0.1" 11 | |glint.master.port = 13370 12 | |glint { 13 | | server.akka.loglevel = "INFO" 14 | | server.akka.stdout-loglevel = "INFO" 15 | | client.akka.loglevel = "INFO" 16 | | client.akka.stdout-loglevel = "INFO" 17 | | master.akka.loglevel = "INFO" 18 | | master.akka.stdout-loglevel = "INFO" 19 | | master.akka.remote.log-remote-lifecycle-events = on 20 | | server.akka.remote.log-remote-lifecycle-events = off 21 | | client.akka.remote.log-remote-lifecycle-events = on 22 | | client.timeout = 30 s 23 | | 24 | | master.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s 25 | | server.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s 26 | | client.akka.remote.transport-failure-detector.acceptable-heartbeat-pause = 120 s 27 | | master.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s 28 | | server.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s 29 | | client.akka.remote.watch-failure-detector.acceptable-heartbeat-pause = 120 s 30 | | 31 | | server.akka.remote.netty.tcp.maximum-frame-size = 32m 32 | | client.akka.remote.netty.tcp.maximum-frame-size = 32m 33 | | server.akka.remote.netty.tcp.send-buffer-size = 32m 34 | | client.akka.remote.netty.tcp.send-buffer-size = 32m 35 | | server.akka.remote.netty.tcp.receive-buffer-size = 32m 36 | | client.akka.remote.netty.tcp.receive-buffer-size = 32m 37 | |} 38 | """.stripMargin 39 | } 40 | -------------------------------------------------------------------------------- /src/main/resources/log4j.properties: -------------------------------------------------------------------------------- 1 | # 2 | # Licensed to the Apache Software Foundation (ASF) under one or more 3 | # contributor license agreements. See the NOTICE file distributed with 4 | # this work for additional information regarding copyright ownership. 
5 | # The ASF licenses this file to You under the Apache License, Version 2.0 6 | # (the "License"); you may not use this file except in compliance with 7 | # the License. You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | # Set everything to be logged to the console 19 | log4j.rootCategory=INFO, console 20 | log4j.appender.console=org.apache.log4j.ConsoleAppender 21 | log4j.appender.console.target=System.err 22 | log4j.appender.console.layout=org.apache.log4j.PatternLayout 23 | log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n 24 | 25 | # Set the default spark-shell log level to WARN. When running the spark-shell, the 26 | # log level for this class is used to overwrite the root logger's log level, so that 27 | # the user can have different defaults for the shell and regular Spark apps. 
28 | log4j.logger.org.apache.spark.repl.Main=WARN 29 | 30 | # Settings to quiet third party logs that are too verbose 31 | log4j.logger.com.github.mlnick.glintfm=DEBUG 32 | log4j.logger.org.spark_project.jetty=WARN 33 | log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR 34 | log4j.logger.org.apache.spark=WARN 35 | log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=WARN 36 | log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=WARN 37 | log4j.logger.org.apache.parquet=ERROR 38 | log4j.logger.parquet=ERROR 39 | 40 | # SPARK-9183: Settings to avoid annoying messages when looking up nonexistent UDFs in SparkSQL with Hive support 41 | log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=FATAL 42 | log4j.logger.org.apache.hadoop.hive.ql.exec.FunctionRegistry=ERROR 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/main/scala/com/github/mlnick/glintfm/GlintFM.scala: -------------------------------------------------------------------------------- 1 | package com.github.mlnick.glintfm 2 | 3 | import java.io.File 4 | import java.util.Date 5 | 6 | import scala.concurrent.duration._ 7 | import scala.concurrent.{Await, ExecutionContext, Future} 8 | 9 | import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, Vector => BV} 10 | import breeze.stats.distributions.Rand 11 | import com.typesafe.config.ConfigFactory 12 | import com.typesafe.scalalogging.slf4j.LazyLogging 13 | import glint.Client 14 | import glint.models.client.granular.{GranularBigMatrix, GranularBigVector} 15 | 16 | import org.apache.spark.ml.classification.LogisticRegression 17 | import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator 18 | import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics 19 | import 
org.apache.spark.mllib.linalg.{DenseMatrix, Vector, Vectors} 20 | import org.apache.spark.mllib.regression._ 21 | import org.apache.spark.sql.functions._ 22 | import org.apache.spark.sql.types.DoubleType 23 | import org.apache.spark.sql.{Row, SparkSession} 24 | import org.apache.spark.util.collection.glintfm.OHashMap 25 | 26 | 27 | case class FMResults(model: FMModel, auc: Double, time: Double) 28 | case class GlintFMResults(auc: Double, time: Double) 29 | case class MLResults(iter: Int, model: org.apache.spark.ml.classification.LogisticRegressionModel, auc: Double, time: Double) 30 | 31 | case class PushStats(time: Double, size: Long) 32 | case class PullStats(time: Double, size: Long) 33 | case class ComputeStats(time: Double, size: Long) 34 | case class IterStats(push: PushStats, pull: PullStats, grad: ComputeStats, time: Double) 35 | 36 | object GlintFM extends LazyLogging { 37 | 38 | def runTest( 39 | inputPath: String, 40 | configPath: String, 41 | format: String = "parquet", 42 | parts: Int = -1, 43 | models: Int = -1, 44 | msgSize: Int = 100000, 45 | fitIntercept: Boolean = true, fitLinear: Boolean = true, k: Int = 2, 46 | interceptRegParam: Double = 0.0, wRegParam: Double = 0.1, vRegParam: Double = 0.1, initStd: Double = 0.1, 47 | mlRegParam: Double = 0.1, 48 | numIterations: Int = 10, 49 | mlStepSize: Double = 1.0, 50 | glintStepSize: Double = 1.0, 51 | treeDepth: Int = 2, 52 | timeout: Int = 30, 53 | runML: Boolean = true, 54 | runMLLIB: Boolean = true, 55 | runGlint: Boolean = true) = { 56 | 57 | val spark = SparkSession.builder() 58 | .master("local[*]") 59 | .appName("glint-fm") 60 | .getOrCreate() 61 | 62 | val raw = spark.read.format(format).load(inputPath).select(col("label").cast(DoubleType), col("features")) 63 | val df = if (parts < 0) { 64 | raw 65 | } else { 66 | raw.repartition(parts) 67 | } 68 | df.cache() 69 | val splits = df.randomSplit(Array(0.8, 0.2), seed = 42) 70 | val (train, test) = (splits(0), splits(1)) 71 | // map to -1/1 because 
spark-libFM uses this form 72 | val mllibTrain = train.rdd.map { case Row(l: Double, v: org.apache.spark.ml.linalg.Vector) => 73 | LabeledPoint(if (l <= 0) -1.0 else 1.0, Vectors.fromML(v)) 74 | } 75 | mllibTrain.cache() 76 | val mllibTest = test.rdd.map { case Row(l: Double, v: org.apache.spark.ml.linalg.Vector) => 77 | LabeledPoint(if (l <= 0) -1.0 else 1.0, Vectors.fromML(v)) 78 | } 79 | mllibTest.cache() 80 | val n = df.count() 81 | val dim = mllibTrain.first().features.size 82 | 83 | val mlResults = if (runML) { 84 | logger.warn(s"Starting ML LoR test run at ${new Date().toString}.") 85 | // ==== Spark ML LR 86 | val start = System.currentTimeMillis() 87 | val lr = new LogisticRegression() 88 | .setRegParam(mlRegParam) 89 | .setMaxIter(numIterations) 90 | .setFitIntercept(fitIntercept) 91 | .setStandardization(false) 92 | val model = lr.fit(train.select(when(train("label") <= 0, 0.0).otherwise(1.0).alias("label"), train("features"))) 93 | val elapsed = (System.currentTimeMillis() - start) / 1000.0 94 | val eval = new BinaryClassificationEvaluator() 95 | val auc = eval.evaluate(model.transform(test.select(when(test("label") <= 0, 0.0).otherwise(1.0).alias("label"), test("features")))) 96 | 97 | logger.warn(s"Completed ML LoR test run at ${new Date().toString}.") 98 | Some(MLResults(numIterations, model, auc, elapsed)) 99 | } else { 100 | None 101 | } 102 | 103 | val mllibResults = if (runMLLIB) { 104 | logger.warn(s"Starting MLlib FM test run at ${new Date().toString}.") 105 | // ==== Spark MLlib GradientDescent FM 106 | val mstart = System.currentTimeMillis() 107 | val model = FMWithSGD.train(mllibTrain, task = 1, numIterations = numIterations, 108 | stepSize = mlStepSize, miniBatchFraction = 1.0, 109 | dim = (fitIntercept, fitLinear, k), regParam = (interceptRegParam, wRegParam, vRegParam), initStd = initStd, treeDepth = treeDepth) 110 | val elapsed = (System.currentTimeMillis() - mstart) / 1000.0 111 | val scores = 
model.predict(mllibTest.map(_.features)).zip(mllibTest.map(_.label)) 112 | val auc = new BinaryClassificationMetrics(scores).areaUnderROC() 113 | logger.warn(s"MLlib FM predictions: ${scores.take(20).mkString(",")}") 114 | logger.warn(s"Completed MLlib FM test run at ${new Date().toString}.") 115 | Some(FMResults(model, auc, elapsed)) 116 | } else { 117 | None 118 | } 119 | 120 | val glintResults = if (runGlint) { 121 | logger.warn(s"Starting Glint FM test run at ${new Date().toString}.") 122 | // ==== Glint FM 123 | val config = ConfigFactory.parseFile(new File(configPath)) 124 | @transient val client = Client(config) 125 | @transient implicit val ec = ExecutionContext.Implicits.global 126 | val numParts = mllibTrain.getNumPartitions 127 | val min = Double.MinValue 128 | val max = Double.MaxValue 129 | 130 | // set up coefficients 131 | val wDim = if (fitLinear) { 132 | if (fitIntercept) dim + 1 else dim 133 | } else { 134 | if (fitIntercept) 1 else 0 135 | } 136 | val distW = if (wDim > 0) { 137 | Some(new GranularBigVector(client.vector[Double](wDim), msgSize)) 138 | } else { 139 | None 140 | } 141 | val distV = if (k > 0) { 142 | Some(new GranularBigMatrix(client.matrix[Double](dim, k), msgSize)) 143 | } else { 144 | None 145 | } 146 | 147 | val gstart = System.currentTimeMillis() 148 | val glintIterStats = mllibTrain.mapPartitions { iter => 149 | implicit val ec = ExecutionContext.Implicits.global 150 | val partitionData = iter.toIterable 151 | // TODO shuffle data per partition? 152 | // TODO mini-batch SGD per partition? 
153 | // pre-compute the local feature indices for this partition 154 | val localKeys = collection.mutable.HashSet[Long]() 155 | // add intercept to keyset if used 156 | if (fitIntercept) localKeys.add(dim) 157 | partitionData.foreach { case LabeledPoint(_, features) => 158 | features.foreachActive { case (idx, _) => localKeys.add(idx.toLong) } 159 | } 160 | // feature indices to pull/push from servers 161 | val keys = localKeys.toArray.sorted 162 | // int keys for mapping local to global feature index in arrays 163 | val idx = keys.map(_.toInt) 164 | val localWDim = if (fitLinear) { 165 | // keys is already correct for whatever value of k0 166 | keys.length 167 | } else if (fitIntercept) { 168 | 1 169 | } else { 170 | 0 171 | } 172 | val wKeys = if (localWDim == 1) Array(0L) else keys 173 | // we need to ignore intercept to get rows of V 174 | val localVDim = if (fitIntercept) keys.length - 1 else keys.length 175 | logger.info(s"Local unique keys: ${keys.length}") 176 | // stat holders 177 | val iterStats = new Array[IterStats](numIterations) 178 | for (iter <- 1 to numIterations) { 179 | logger.info(s"Starting iteration $iter") 180 | val iterStart = System.currentTimeMillis() 181 | // pull coefficients from param server 182 | val (result, pullStats) = if (iter == 1) { 183 | // if 1st iteration we don't pull coefficients 184 | // init w to zeros 185 | val zeroW = BDV.zeros[Double](localWDim) 186 | // init V to N(0, initStd) 187 | val randV = BDM.rand(localVDim, k, Rand.gaussian(0.0, initStd)) 188 | ((zeroW, randV), PullStats(0, 0)) 189 | } else { 190 | val pullStart = System.currentTimeMillis() 191 | // pull relevant keys of w 192 | val pullW = distW.map { w => 193 | w.pull(wKeys).map(values => BDV[Double](values)) 194 | }.getOrElse { 195 | Future { BDV.zeros[Double](0) } 196 | } 197 | // pull relevant rows of V 198 | val rows = if (fitIntercept) keys.init else keys 199 | val pullV = distV.map(_.pull(rows).map { vectors => 200 | // stack vectors to form the local 
V matrix
          BDV.horzcat[Double](vectors.map(_.toDenseVector): _*).t
        }).getOrElse { Future { BDM.zeros[Double](0, 0) } }
        // Both pulls are already running; the for-comprehension only waits for
        // the two server round-trips together.
        val pulls = for {
          wr <- pullW
          vr <- pullV
        } yield (wr, vr)

        val result = Await.result(pulls, timeout seconds)
        val pullElapsed = (System.currentTimeMillis() - pullStart) / 1000.0
        logger.info(f"Iteration $iter - pull time $pullElapsed%2.4f sec; w size=$localWDim, V size=($localVDim,$k)")
        // Value of the enclosing block: the pulled (w, V) pair plus pull timing
        // stats; size is bytes transferred (8 bytes per double pulled).
        (result, PullStats(pullElapsed, (localWDim + localVDim * k) * 8))
      }
      val w = result._1
      val V = result._2

      // gradient computation
      val gradStart = System.currentTimeMillis()
      // Fold the partition's examples into a fresh FMAggregator (task = 1,
      // classification); it accumulates loss and the w/V gradients in place.
      val agg = partitionData.foldLeft(new FMAggregator(1, localWDim, idx, fitIntercept, fitLinear, k, min, max)) {
        case (a, LabeledPoint(label, data)) =>
          a.add(data, label, w, V)
      }
      val count = agg.getNumExamples
      val loss = agg.getLossSum
      val gradElapsed = (System.currentTimeMillis() - gradStart) / 1000.0
      logger.info(f"Iteration $iter - gradient computation stats: elapsed=$gradElapsed%2.4f sec; pred=${agg._pelapsed}%2.4f sec, grad=${agg._gelapsed}%2.4f sec; loss=$loss; examples=$count")

      // NOTE(review): scale extrapolates the global example count from this
      // partition's local count, which assumes roughly balanced partitions -
      // confirm this is intended.
      val scale = count.toDouble * numParts
      val gradW = agg.getGradW
      val gradV = agg.getGradV
      // 1/sqrt(t) step-size decay.
      val step = glintStepSize / math.sqrt(iter)

      // compute updates
      val updateStart = System.currentTimeMillis()
      val updateW = new Array[Double](gradW.length)
      // update w
      if (fitIntercept) {
        // Intercept occupies the last slot of the local weight vector.
        updateW(localWDim - 1) = -step * (gradW(localWDim - 1) / scale + interceptRegParam * w(localWDim - 1))
      }
      if (fitLinear) {
        // NOTE(review): likely off-by-one. `until localWDim - 2` excludes index
        // localWDim - 2, so the linear weight with the largest local index is
        // never updated. Expected bound looks like localWDim - 1 when
        // fitIntercept (intercept in the last slot) and localWDim otherwise -
        // confirm and fix.
        for (i <- 0 until localWDim - 2) {
          updateW(i) = -step * (gradW(i) / scale + wRegParam * w(i))
        }
      }
      // update V
      // Sparse (row, col, value) triplets for pushing the V update.
      val rows = new Array[Long](localVDim * k)
      val cols = new Array[Int](localVDim * k)
      val values = new Array[Double](localVDim * k)

      var uk = 0
      var i = 0
      while (i < localVDim) {
        // Global row key for local row i.
        val idx = keys(i)
        var j = 0
        while (j < k) {
          values(uk) = -step * (gradV(i, j) / scale + vRegParam * V(i, j))
          rows(uk) = idx
          cols(uk) = j
          uk += 1
          j += 1
        }
        i += 1
      }
      val updateElapsed = (System.currentTimeMillis() - updateStart) / 1000.0
      logger.info(f"Iteration $iter - compute update time $updateElapsed%2.4f sec")
      val gradStats = ComputeStats(gradElapsed + updateElapsed, count)

      val pushStart = System.currentTimeMillis()
      // Push the w and V deltas; an absent matrix falls back to an immediate
      // no-op future.
      val pushes = for {
        pushW <- distW.map(_.push(wKeys, updateW)).getOrElse(Future { true } )
        pushV <- distV.map(_.push(rows, cols, values)).getOrElse(Future { true })
      } yield (pushW, pushV)
      Await.result(pushes, timeout seconds)
      val pushElapsed = (System.currentTimeMillis() - pushStart) / 1000.0
      logger.info(f"Iteration $iter - push time $pushElapsed%2.4f sec; w size=$localWDim, V size=($localVDim,$k)")
      val pushStats = PushStats(pushElapsed, (localWDim + localVDim * k) * 8)
      val iterElapsed = (System.currentTimeMillis() - iterStart) / 1000.0
      logger.info(f"Iteration $iter - total time $iterElapsed%2.4f sec")

      iterStats(iter - 1) = IterStats(pushStats, pullStats, gradStats, iterElapsed)
    }
    Iterator.single(iterStats)
  }.collect()

  val elapsed = (System.currentTimeMillis() - gstart) / 1000.0

  logger.warn(s"Glint FM elapsed training time: $elapsed")

  import spark.implicits._
  // Flatten the per-partition iteration stats into a DataFrame for reporting.
  val stats = glintIterStats.flatMap { partStats =>
    partStats.map { s =>
      (s.push.size, s.pull.time, s.push.time, s.grad.time, s.time)
    }
  }.toSeq.toDF("size", "pull", "push", "comp", "total")
  stats.groupBy().avg().show()
  stats.columns.foreach { c =>
    // Approximate 50th / 75th / 100th percentiles for each stats column.
    val m = stats.stat.approxQuantile(c, Array(0.5, 0.75, 1.0), 0.001)
    // NOTE(review): the value labelled "max" prints m(1) (the 75th percentile)
    // instead of m(2) (the 1.0 quantile) - looks like a copy-paste typo.
    logger.warn(f"Stats $c median=${m(0)}%4.2f; 75th=${m(1)}%4.2f; max=${m(1)}%4.2f")
  }

  // predict distributed
  val scores = mllibTest.mapPartitions { iter =>
    implicit val ec = ExecutionContext.Implicits.global
    val partitionData = iter.toIterable
    // pre-compute the local feature indices for this partition
    val localKeys = collection.mutable.HashSet[Long]()
    // add intercept to keyset if used
    if (fitIntercept) localKeys.add(dim)
    partitionData.foreach { case LabeledPoint(_, features) =>
      features.foreachActive { case (idx, _) => localKeys.add(idx.toLong) }
    }
    // feature indices to pull/push from servers
    val keys = localKeys.toArray.sorted
    // int keys for mapping local to global feature index in arrays
    val idx = keys.map(_.toInt)
    val localWDim = if (fitLinear) {
      // keys is already correct for whatever value of k0
      keys.length
    } else if (fitIntercept) {
      1
    } else {
      0
    }
    // localWDim == 1 means intercept-only w; presumably it is stored under
    // key 0 on the server - confirm against the training-side pull.
    val wKeys = if (localWDim == 1) Array(0L) else keys
    // V has one row per real feature; the intercept key is excluded.
    val localVDim = if (fitIntercept) keys.length - 1 else keys.length

    val pullStart = System.currentTimeMillis()
    // pull relevant keys of w
    val pullW = distW.map { w =>
      w.pull(wKeys).map(values => BDV[Double](values))
    }.getOrElse {
      Future { BDV.zeros[Double](0) }
    }
    // pull relevant rows of V
    // keys is sorted, so when fitIntercept the intercept key (= dim, the
    // largest) is last and init drops it.
    val rows = if (fitIntercept) keys.init else keys
    val pullV = distV.map(_.pull(rows).map { vectors =>
      // stack vectors to form the local V matrix
      BDV.horzcat[Double](vectors.map(_.toDenseVector): _*).t
    }).getOrElse { Future { BDM.zeros[Double](0, 0) } }
    val pulls = for {
      wr <- pullW
      vr <- pullV
    } yield (wr, vr)

    val result = Await.result(pulls, timeout seconds)
    val pullElapsed = (System.currentTimeMillis() - pullStart) / 1000.0
    logger.info(f"Model evaluation - pull time $pullElapsed%2.4f sec; w size=$localWDim, V size=($localVDim,$k)")
    val w = result._1
    val V = result._2
    // Reuse FMAggregator purely for its prediction logic (task 1 =
    // classification, so testPredict returns a probability).
    val model = new FMAggregator(1, localWDim, idx, fitIntercept, fitLinear, k, min, max)
    partitionData.map { case (LabeledPoint(label, features)) =>
      // (score, label) pairs for BinaryClassificationMetrics.
      (model.testPredict(features, w, V), label)
    }.toIterator
  }
  val auc = new BinaryClassificationMetrics(scores).areaUnderROC()
  logger.warn(s"Glint FM predictions: ${scores.take(20).mkString(",")}")
  logger.warn(s"Completed Glint FM test run at ${new Date().toString}.")

  // clean up distributed coefficients
  distV.foreach(_.destroy())
  distW.foreach(_.destroy())
  client.stop()

  Some(GlintFMResults(auc, elapsed))
} else {
  None
}

// print results
logger.warn(s"Completed test run at ${new Date().toString}.")
logger.warn(s"Rows: $n, Cols: $dim")
logger.warn(s"Num iterations=$numIterations, parts=$parts, dim=${(fitIntercept, fitLinear, k)}, regParam=${(interceptRegParam, wRegParam, vRegParam)}")
mlResults.foreach { case MLResults(_, model, auc, elapsed) =>
  logger.warn(s"ML LoR results -- regParam=$mlRegParam")
  logger.warn(s"ML LoR model: w0 = ${model.intercept}; w = ${model.coefficients.toArray.take(10).mkString(",")}")
  logger.warn(s"ML LoR test AUC: $auc")
  logger.warn(s"ML LoR elapsed training time: $elapsed")
}
mllibResults.foreach { case FMResults(model, auc, elapsed) =>
  logger.warn(s"MLlib FM results -- stepSize=$mlStepSize")
  val w = model.weightVector.getOrElse(Vectors.zeros(0)).toArray.take(10).mkString(",")
  logger.warn(s"MLlib FM model: w0 = ${model.intercept}; w = $w; V = ${if (k > 0) model.factorMatrix else ""}")
  logger.warn(s"MLlib FM test AUC: $auc")
  logger.warn(s"MLlib FM elapsed training time: $elapsed")
}
glintResults.foreach { case GlintFMResults(auc, elapsed) =>
  logger.warn(s"Glint FM results -- stepSize=$glintStepSize, models=$models, msgSize=$msgSize, 
timeout=$timeout")
  logger.warn(s"Glint FM test AUC: $auc")
  logger.warn(s"Glint FM elapsed training time: $elapsed")
}

mllibTrain.unpersist()
mllibTest.unpersist()
df.unpersist()
}

}

// Logic taken mostly from spark-libFM
/**
 * Per-partition accumulator of factorization machine loss and gradients over a
 * locally pulled slice of the model (linear weights `w` and factor matrix `V`).
 *
 * Examples are folded in via `add`; gradients accumulate in place in `wGrad`
 * and `vGrad` and are read back through the getters. Instances are mutable and
 * not thread-safe.
 *
 * @param task         0 = regression (predictions clamped to [min, max]),
 *                     1 = classification (logistic)
 * @param wDim         length of the local weight vector; the intercept, when
 *                     fit, occupies the last slot
 * @param idx          global feature indices present locally, used to map a
 *                     global index to a local row in w/V/gradients
 * @param fitIntercept whether an intercept term is fit
 * @param fitLinear    whether first-order (linear) weights are fit
 * @param k            number of latent factors per feature
 * @param min          lower clamp for regression predictions
 * @param max          upper clamp for regression predictions
 */
private class FMAggregator(
    task: Int, // 0 = regression, 1 = classification
    wDim: Int,
    idx: Array[Int],
    fitIntercept: Boolean,
    fitLinear: Boolean,
    k: Int,
    min: Double,
    max: Double) {

  // Running counters over the examples folded in via add().
  private var examples = 0
  private var lossSum = 0.0

  private val activeDim = idx.length
  // Gradient accumulators, updated in place for every example.
  private val wGrad = new Array[Double](wDim)
  private val vGrad = BDM.zeros[Double](activeDim, k)

  // Global feature index -> local position (row in V/vGrad, slot in w/wGrad).
  private val idxMapping = {
    val m = new OHashMap[Int, Int](activeDim)
    var i = 0
    while (i < activeDim) {
      m.update(idx(i), i)
      i += 1
    }
    m
  }

  def getNumExamples = examples
  def getLossSum = lossSum
  def getGradW = wGrad
  def getGradV = vGrad

  /**
   * Prediction for evaluation. Unlike the private `predict`, classification
   * output (task 1) is passed through the logistic function, so the result is
   * a probability rather than a raw margin; regression output (task 0) is
   * clamped to [min, max].
   */
  def testPredict(data: Vector, weights: BDV[Double], V: BDM[Double]): Double = {

    // Intercept is stored in the last slot of the pulled weight vector.
    var pred = if (fitIntercept) weights(weights.length - 1) else 0.0

    if (fitLinear) {
      data.foreachActive {
        case (i, v) =>
          pred += weights(idxMapping(i)) * v
      }
    }

    // Pairwise interaction term, per factor f:
    //   ((sum_i V(i,f) x_i)^2 - sum_i (V(i,f) x_i)^2) / 2
    for (f <- 0 until k) {
      var sum = 0.0
      var sumSqr = 0.0
      data.foreachActive {
        case (i, v) =>
          val d = V(idxMapping(i), f) * v
          sum += d
          sumSqr += d * d
      }
      pred += (sum * sum - sumSqr) / 2
    }

    task match {
      case 0 =>
        Math.min(Math.max(pred, min), max)
      case 1 =>
        1.0 / (1.0 + Math.exp(-pred))
    }
  }

  /**
   * Training-time prediction. Returns the prediction (clamped only for task 0;
   * for task 1 this is the raw margin, no sigmoid) together with the
   * per-factor sums (sum_i V(i,f) x_i), which `updateGradientInPlace` reuses
   * so the factor gradients need not recompute them.
   */
  private def predict(data: Vector, weights: BDV[Double], V: BDM[Double]): (Double, Array[Double]) = {

    var pred = if (fitIntercept) weights(weights.length - 1) else 0.0

    if (fitLinear) {
      data.foreachActive {
        case (i, v) =>
          pred += weights(idxMapping(i)) * v
      }
    }

    // sum(f) is retained and returned for gradient reuse.
    val sum = Array.fill(k)(0.0)
    for (f <- 0 until k) {
      var sumSqr = 0.0
      data.foreachActive {
        case (i, v) =>
          val d = V(idxMapping(i), f) * v
          sum(f) += d
          sumSqr += d * d
      }
      pred += (sum(f) * sum(f) - sumSqr) / 2
    }

    if (task == 0) {
      pred = Math.min(Math.max(pred, min), max)
    }
    (pred, sum)
  }

  /**
   * Accumulates one example's gradient contributions into wGrad / vGrad.
   *
   * `mult` is dLoss/dPred: (pred - label) for squared error (task 0), and the
   * logistic-loss derivative for task 1 (assumes labels are encoded as +/-1 -
   * NOTE(review): confirm the label encoding upstream).
   */
  private def updateGradientInPlace(
      data: Vector,
      label: Double,
      pred: Double,
      sum: Array[Double],
      V: BDM[Double]): Unit = {

    val mult = task match {
      case 0 =>
        pred - label
      case 1 =>
        -label * (1.0 - 1.0 / (1.0 + Math.exp(-label * pred)))
    }

    if (fitIntercept) {
      // d pred / d intercept = 1.
      wGrad(wGrad.length - 1) += mult
    }
    if (fitLinear) {
      // d pred / d w_i = x_i.
      data.foreachActive { case (i, v) =>
        wGrad(idxMapping(i)) += v * mult
      }
    }

    // Factor gradients: d pred / d V(i,f) = x_i * sum(f) - V(i,f) * x_i^2,
    // using the per-factor sums precomputed by predict().
    data.foreachActive { case (i, v) =>
      val idx = idxMapping(i)
      for (f <- 0 until k) {
        val g = (sum(f) * v - V(idx, f) * v * v) * mult
        vGrad(idx, f) += g
      }
    }

  }

  // Cumulative seconds spent in predict / gradient update, exposed for logging.
  var _pelapsed = 0.0
  var _gelapsed = 0.0

  /**
   * Folds one labelled example into the accumulator: predicts, updates the
   * gradients in place, and adds to the loss sum. For task 0 the loss is the
   * squared error; for task 1 it is 1 - signum(pred * label) (0 for a
   * correct-sign margin, 2 for a wrong-sign one) - a scaled misclassification
   * count, which differs from the logistic loss whose gradient is accumulated.
   * Returns this for use with foldLeft.
   */
  def add(data: Vector, label: Double, w: BDV[Double], V: BDM[Double]): this.type = {

    val ps = System.currentTimeMillis()
    val (pred, sum) = predict(data, w, V)
    _pelapsed += (System.currentTimeMillis() - ps) / 1000.0

    val gs = System.currentTimeMillis()
    updateGradientInPlace(data, label, pred, sum, V)
    _gelapsed += (System.currentTimeMillis() - gs) / 1000.0

    val loss = task match {
      case 0 =>
        (pred - label) * (pred - label)
      case 1 =>
        1 - Math.signum(pred * label)
    }
    lossSum += loss
    examples += 1
    this
  }

}

--------------------------------------------------------------------------------