├── .gitignore ├── LICENSE ├── README.md ├── conf ├── default.conf.template ├── env.sh.template └── servers.template ├── dependency-reduced-pom.xml ├── docs ├── doc └── problem ├── pom.xml ├── spark-asyspark_2.11.iml └── src ├── META-INF └── MANIFEST.MF ├── main └── scala │ └── org │ └── apache │ └── spark │ ├── asyspark │ ├── asyml │ │ └── asysgd │ │ │ ├── AsyGradientDescent.scala │ │ │ └── GlobalWeight.scala │ └── core │ │ ├── Client.scala │ │ ├── Exceptios │ │ ├── PullFailedException.scala │ │ └── PushFailedException.scala │ │ ├── Main.scala │ │ ├── Master.scala │ │ ├── Server.scala │ │ ├── messages │ │ ├── master │ │ │ ├── ClientList.scala │ │ │ ├── RegisterClient.scala │ │ │ ├── RegisterServer.scala │ │ │ └── ServerList.scala │ │ └── server │ │ │ ├── logic │ │ │ ├── AcknowledgeReceipt.scala │ │ │ ├── Forget.scala │ │ │ ├── GenerateUniqueID.scala │ │ │ ├── NotAcknowledgeReceipt.scala │ │ │ └── UniqueID.scala │ │ │ ├── request │ │ │ ├── PullVector.scala │ │ │ ├── PushVectorDouble.scala │ │ │ ├── PushVectorFloat.scala │ │ │ ├── PushVectorInt.scala │ │ │ ├── PushVectorLong.scala │ │ │ └── Request.scala │ │ │ └── response │ │ │ ├── Response.scala │ │ │ ├── ResponseDouble.scala │ │ │ ├── ResponseFloat.scala │ │ │ ├── ResponseInt.scala │ │ │ └── ResponseLong.scala │ │ ├── models │ │ ├── client │ │ │ ├── BigVector.scala │ │ │ └── asyImp │ │ │ │ ├── AsyBigVector.scala │ │ │ │ ├── AsyBigVectorDouble.scala │ │ │ │ ├── AsyBigVectorFloat.scala │ │ │ │ ├── AsyBigVectorInt.scala │ │ │ │ ├── AsyBigVectorLong.scala │ │ │ │ ├── DeserializationHelper.scala │ │ │ │ ├── PullFSM.scala │ │ │ │ └── PushFSM.scala │ │ └── server │ │ │ ├── PartialVector.scala │ │ │ ├── PartialVectorDouble.scala │ │ │ ├── PartialVectorFloat.scala │ │ │ ├── PartialVectorInt.scala │ │ │ ├── PartialVectorLong.scala │ │ │ └── PushLogic.scala │ │ ├── partitions │ │ ├── Partition.scala │ │ ├── Partitioner.scala │ │ └── range │ │ │ ├── RangePartition.scala │ │ │ └── RangePartitioner.scala │ │ └── serialization │ │ ├── FastPrimitiveDeserializer.scala │ │ ├── FastPrimitiveSerializer.scala │ │ ├── RequestSerializer.scala │ │ ├── ResponseSerializer.scala │ │ └── SerializationConstants.scala │ └── examples │ ├── AsySGDExample.scala │ ├── TestBroadCast.scala │ ├── TestClient.scala │ └── TestRemote.scala └── test └── scala └── Test.scala /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | target/ 3 | log/ 4 | src/main/scala/org/apache/spark/myExamples 5 | src/main/scala/org/apache/spark/mySql 6 | .idea/ 7 | *.log 8 | classes/ 9 | *.class 10 | *.jar 11 | 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2016 codlife 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # asyspark 2 | ## Spark 3 | Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, Python, and R, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and DataFrames, MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing. 4 | ## asySpark 5 | AsySpark is a component of Spark that makes machine learning workloads more efficient by using an asynchronous computing model, such as asynchronous stochastic gradient descent. 6 | ## Tips 7 | If you would like to contribute or collaborate, please contact us. 8 | ## Further reading 9 | ### Web resources 10 | 1: [Dean, NIPS'13; Li, OSDI'14](http://ps-lite.readthedocs.io/en/latest/overview.html#further-reads) The parameter server architecture
11 | 2: [Taobao's parameter server architecture](http://www.36dsj.com/archives/60938) (in Chinese)
12 | ### Papers 13 | 1: [Langford, NIPS'09; Agarwal, NIPS'11](http://arxiv.org/pdf/1104.5525.pdf) Theoretical convergence of asynchronous SGD
14 | 2: [Li, WSDM'16](http://www.cs.cmu.edu/~yuxiangw/docs/fm.pdf) Practical considerations for asynchronous SGD with the parameter server
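## Usage sketch
A minimal sketch of how the client API in this repository is intended to be used, based on `org.apache.spark.asyspark.core.Client` and `models.client.BigVector`. It assumes a master and at least one parameter server are already running; the configuration file path and the vector size/keys below are placeholders, not part of the project:
```scala
import java.io.File

import com.typesafe.config.ConfigFactory
import org.apache.spark.asyspark.core.Client

import scala.concurrent.ExecutionContext.Implicits.global

// Load a configuration that points at the running master (placeholder path)
val config = ConfigFactory.parseFile(new File("/path/to/asyspark.conf"))
val client = Client(config)

// Spawn a distributed vector of 10000 doubles across the parameter servers
val vector = client.bigVector[Double](10000L)

// Push and pull a few elements; both calls return Futures and complete asynchronously
val keys = Array[Long](0L, 1L, 2L)
val values = Array[Double](0.1, 0.2, 0.3)
vector.push(keys, values)
vector.pull(keys).foreach(v => println(v.mkString(", ")))

// Release the models on the parameter servers and shut the client down
vector.destroy()
client.stop()
```
Because `push` and `pull` return `Future`s, workers can update the shared parameters without blocking one another; this is the asynchronous model that `AsyGradientDescent` builds on.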
15 | 16 | 17 | -------------------------------------------------------------------------------- /conf/default.conf.template: -------------------------------------------------------------------------------- 1 | # Place custom configuration for your glint cluster in here 2 | glint.master.host = "127.0.0.1" 3 | glint.master.port = 13370 4 | 5 | -------------------------------------------------------------------------------- /conf/env.sh.template: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Change these values depending on your setup 3 | GLINT_JAR_PATH=./bin/Glint-0.1.jar # Path to the assembled glint jar file, this should be the same on all cluster machines 4 | GLINT_MASTER_OPTS="-Xmx2048m" # Java options to pass the JVM when starting a master 5 | GLINT_SERVER_OPTS="-Xmx2048m" # Java options to pass the JVM when starting a server 6 | -------------------------------------------------------------------------------- /conf/servers.template: -------------------------------------------------------------------------------- 1 | # Here you can place the hostnames or ip addresses of machines on which you wish to start parameter servers 2 | localhost 3 | -------------------------------------------------------------------------------- /dependency-reduced-pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4.0.0 4 | otcaix 5 | asyspark 6 | 1.0-SNAPSHOT 7 | 2016 8 | 9 | src/main/scala 10 | src/test/scala 11 | 12 | 13 | org.scala-tools 14 | maven-scala-plugin 15 | 16 | 17 | 18 | compile 19 | testCompile 20 | 21 | 22 | 23 | 24 | ${scala.version} 25 | 26 | -target:jvm-1.5 27 | 28 | 29 | 30 | 31 | maven-shade-plugin 32 | 33 | 34 | package 35 | 36 | shade 37 | 38 | 39 | 40 | 41 | 42 | 43 | *:* 44 | 45 | META-INF/*.SF 46 | META-INF/*.DSA 47 | META-INF/*.RSA 48 | 49 | 50 | 51 | 52 | 53 | org.apache.spark.asyspark.core.Main 54 | 55 | 56 | reference.conf 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | scala-tools.org 66 | Scala-Tools Maven2 Repository 67 | http://scala-tools.org/repo-releases 68 | 69 | 70 | 71 | 72 | scala-tools.org 73 | Scala-Tools Maven2 Repository 74 | http://scala-tools.org/repo-releases 75 | 76 | 77 | 78 | 79 | junit 80 | junit 81 | 4.4 82 | test 83 | 84 | 85 | org.specs 86 | specs 87 | 1.2.5 88 | test 89 | 90 | 91 | scalatest 92 | org.scalatest 93 | 94 | 95 | scalacheck 96 | org.scalacheck 97 | 98 | 99 | jmock 100 | org.jmock 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | org.scala-tools 109 | maven-scala-plugin 110 | 111 | ${scala.version} 112 | 113 | 114 | 115 | 116 | 117 | 2.11.8 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /docs/doc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASISCAS/asyspark/cff60c9d4ae6eb01691a8a8b86f3deca8b67a025/docs/doc -------------------------------------------------------------------------------- /docs/problem: -------------------------------------------------------------------------------- 1 | 1: about the project package name: 2 | Because something can't be accessed outside org.apache.spark, so currently we just name 3 | the package as: org.apache.spark.asyspark... 
4 | 2: -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 3 | 4.0.0 4 | otcaix 5 | asyspark 6 | 1.0-SNAPSHOT 7 | 2016 8 | 9 | 2.11.8 10 | 11 | 12 | 13 | 14 | scala-tools.org 15 | Scala-Tools Maven2 Repository 16 | http://scala-tools.org/repo-releases 17 | 18 | 19 | 20 | 21 | 22 | scala-tools.org 23 | Scala-Tools Maven2 Repository 24 | http://scala-tools.org/repo-releases 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | org.scala-lang 33 | scala-library 34 | ${scala.version} 35 | 36 | 37 | junit 38 | junit 39 | 4.4 40 | test 41 | 42 | 43 | org.specs 44 | specs 45 | 1.2.5 46 | test 47 | 48 | 49 | org.apache.hadoop 50 | hadoop-client 51 | 2.7.1 52 | 53 | 54 | javax.servlet 55 | * 56 | 57 | 58 | 59 | 60 | org.apache.hadoop 61 | hadoop-hdfs 62 | 2.7.1 63 | 64 | 65 | org.apache.spark 66 | spark-core_2.11 67 | 2.0.0 68 | 69 | 70 | 71 | org.apache.spark 72 | spark-sql_2.11 73 | 2.0.0 74 | 75 | 76 | 77 | org.apache.spark 78 | spark-streaming_2.11 79 | 2.0.0 80 | 81 | 82 | 83 | org.apache.spark 84 | spark-mllib_2.11 85 | 2.0.0 86 | 87 | 88 | 89 | org.apache.spark 90 | spark-hive_2.11 91 | 2.0.0 92 | 93 | 94 | net.alchim31.maven 95 | scala-maven-plugin 96 | 3.2.2 97 | 98 | 99 | com.typesafe.scala-logging 100 | scala-logging-slf4j_2.11 101 | 2.1.2 102 | 103 | 104 | com.typesafe.akka 105 | akka-actor_2.11 106 | 2.4.10 107 | 108 | 109 | com.github.scopt 110 | scopt_2.10 111 | 3.5.0 112 | 113 | 114 | 115 | com.typesafe.akka 116 | akka-remote_2.11 117 | 2.4.10 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | src/main/scala 127 | src/test/scala 128 | 129 | 130 | org.scala-tools 131 | maven-scala-plugin 132 | 133 | 134 | 135 | compile 136 | testCompile 137 | 138 | 139 | 140 | 141 | ${scala.version} 142 | 143 | -target:jvm-1.5 144 | 145 | 146 | 147 | 148 | 152 | 153 | 154 | org.apache.maven.plugins 155 | maven-shade-plugin 156 | 157 | 158 | package 159 | 160 | shade 161 | 162 | 163 | 164 | 165 | 166 | 167 | *:* 168 | 169 | META-INF/*.SF 170 | META-INF/*.DSA 171 | META-INF/*.RSA 172 | 173 | 174 | 175 | 176 | 177 | org.apache.spark.asyspark.core.Main 178 | 179 | 181 | reference.conf 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | org.scala-tools 206 | maven-scala-plugin 207 | 208 | ${scala.version} 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | -------------------------------------------------------------------------------- /spark-asyspark_2.11.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 
163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | -------------------------------------------------------------------------------- /src/META-INF/MANIFEST.MF: -------------------------------------------------------------------------------- 1 | Manifest-Version: 1.0 2 | Main-Class: org.apache.spark.asyspark.core.Main 3 | 4 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/asyml/asysgd/AsyGradientDescent.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.asyml.asysgd 2 | 3 | import breeze.linalg.{DenseVector => BDV} 4 | import com.typesafe.scalalogging.slf4j.StrictLogging 5 | import org.apache.spark.mllib.linalg.{Vector, Vectors} 6 | import org.apache.spark.mllib.optimization.{Gradient, Updater} 7 | import org.apache.spark.rdd.RDD 8 | 9 | import scala.collection.mutable.ArrayBuffer 10 | 11 | /** 12 | * Created by wjf on 16-9-19. 13 | */ 14 | object AsyGradientDescent extends StrictLogging { 15 | /** 16 | * Run asynchronous stochastic gradient descent (SGD) in parallel using mini batches. 17 | * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data 18 | * in order to compute a gradient estimate. 19 | * Sampling, and averaging the subgradients over this subset is performed using one standard 20 | * spark map-reduce in each iteration. 21 | * 22 | * @param data Input data for SGD. RDD of the set of data examples, each of 23 | * the form (label, [feature values]). 24 | * @param gradient Gradient object (used to compute the gradient of the loss function of 25 | * one single data example) 26 | * @param updater Updater function to actually perform a gradient step in a given direction. 27 | * @param stepSize initial step size for the first step 28 | * @param numIterations number of iterations that SGD should be run. 29 | * @param regParam regularization parameter 30 | * @param miniBatchFraction fraction of the input data set that should be used for 31 | * one iteration of SGD. Default value 1.0. 32 | * @param convergenceTol Minibatch iteration will end before numIterations if the relative 33 | * difference between the current weight and the previous weight is less 34 | * than this value. In measuring convergence, L2 norm is calculated. 35 | * Default value 0.001. Must be between 0.0 and 1.0 inclusively. 36 | * @return A tuple containing two elements. The first element is a column matrix containing 37 | * weights for every feature, and the second element is an array containing the 38 | * stochastic loss computed for every iteration. 
39 | */ 40 | def runAsySGD( 41 | data: RDD[(Double, Vector)], 42 | gradient: Gradient, 43 | updater: Updater, 44 | stepSize: Double, 45 | numIterations: Int, 46 | regParam: Double, 47 | miniBatchFraction: Double, 48 | initialWeights: Vector, 49 | convergenceTol: Double): (Vector, Array[Double]) = { 50 | 51 | // convergenceTol should be set with non minibatch settings 52 | if (miniBatchFraction < 1.0 && convergenceTol > 0.0) { 53 | logger.warn("Testing against a convergenceTol when using miniBatchFraction " + 54 | "< 1.0 can be unstable because of the stochasticity in sampling.") 55 | } 56 | 57 | if (numIterations * miniBatchFraction < 1.0) { 58 | logger.warn("Not all examples will be used if numIterations * miniBatchFraction < 1.0: " + 59 | s"numIterations=$numIterations and miniBatchFraction=$miniBatchFraction") 60 | } 61 | 62 | val stochasticLossHistory = new ArrayBuffer[Double](numIterations) 63 | 64 | val numExamples = data.count() 65 | 66 | // if no data, return initial weights to avoid NaNs 67 | if (numExamples == 0) { 68 | logger.warn("GradientDescent.runMiniBatchSGD returning initial weights, no data found") 69 | return (initialWeights, stochasticLossHistory.toArray) 70 | } 71 | 72 | if (numExamples * miniBatchFraction < 1) { 73 | logger.warn("The miniBatchFraction is too small") 74 | } 75 | 76 | // Initialize weights as a column vector 77 | val n = data.first()._2.size 78 | GlobalWeight.setN(n) 79 | var weight: Vector = null 80 | if (initialWeights != null) { 81 | val weights = Vectors.dense(initialWeights.toArray) 82 | GlobalWeight.setWeight(weights) 83 | weight = weights 84 | } else { 85 | GlobalWeight.initWeight() 86 | weight = GlobalWeight.getWeight() 87 | } 88 | 89 | 90 | // todo add regval 91 | /** 92 | * For the first iteration, the regVal will be initialized as sum of weight squares 93 | * if it's L2 updater; for L1 updater, the same logic is followed. 
94 | */ 95 | 96 | // var regVal = updater.compute(weights, Vectors.zeros(weights.size), 0, 1, regParam)._2 97 | 98 | GlobalWeight.updateWeight(weight, Vectors.zeros(weight.size), 0, 1, regParam, convergenceTol) 99 | var regVal = GlobalWeight.getRegVal() 100 | 101 | data.foreachPartition { 102 | partition => 103 | val array = new ArrayBuffer[(Double, Vector)]() 104 | while(partition.hasNext) { 105 | array += partition.next() 106 | } 107 | var convergence = false 108 | var i = 1 109 | val elementNum = array.size 110 | if (elementNum <= 0) { 111 | logger.warn(s" sorry, this partition has no elements, this worker will stop") 112 | convergence = true 113 | } 114 | while (i <= numIterations && !convergence) { 115 | // todo can do some optimization 116 | // val timesPerIter =10 117 | // for(j <- 0 until timesPerIter) { 118 | // 119 | // } 120 | val bcWeight = GlobalWeight.getWeight() 121 | 122 | 123 | 124 | // todo we can do a sample to avoid use all the data 125 | 126 | // compute gradient 127 | val (gradientSum, lossSum) = array.aggregate((BDV.zeros[Double](n), 0.0))( 128 | seqop = (c, v) => { 129 | val l = gradient.compute(v._2, v._1, bcWeight, Vectors.fromBreeze(c._1)) 130 | (c._1, c._2 + l) 131 | }, 132 | combop = (c1, c2) => { 133 | (c1._1 += c2._1, c1._2 + c2._2) 134 | } 135 | ) 136 | // update gradient 137 | 138 | stochasticLossHistory += lossSum / elementNum + regVal 139 | // todo check whether update success 140 | 141 | val (success, conFlag) = GlobalWeight.updateWeight(bcWeight, Vectors.fromBreeze(gradientSum / elementNum.toDouble), stepSize, i, regParam, convergenceTol) 142 | regVal = GlobalWeight.getRegVal() 143 | if (conFlag) { 144 | convergence = true 145 | } 146 | 147 | i += 1 148 | } 149 | println("this is iii"+ i ) 150 | } 151 | (GlobalWeight.getWeight(), stochasticLossHistory.toArray) 152 | } 153 | 154 | } 155 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/asyml/asysgd/GlobalWeight.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.asyml.asysgd 2 | 3 | import breeze.linalg.{Vector => BV, axpy => brzAxpy, norm => brzNorm} 4 | import org.apache.spark.mllib.linalg.{Vector, Vectors} 5 | 6 | /** 7 | * Created by wjf on 16-9-19. 
8 | */ 9 | 10 | object GlobalWeight extends Serializable { 11 | //TODO this is just a demo, will use a cluster to replace this 12 | //TODO the init weight can be optimized 13 | //TODO concurrent control should be optimized 14 | private var n: Int = 0 15 | private var weight: Vector = _ 16 | private var regVal: Double = 0.0 17 | 18 | // must setN first before use GloabalWeight 19 | def setN(n: Int): Unit = { 20 | this.n = n 21 | } 22 | 23 | def getRegVal(): Double = { 24 | this.regVal 25 | } 26 | 27 | def initWeight(): Vector = { 28 | GlobalWeight.synchronized { 29 | if (this.weight == null) { 30 | this.weight = Vectors.dense(new Array[Double](n)) 31 | } 32 | this.weight 33 | } 34 | } 35 | 36 | def setWeight(weight: Vector): Boolean = { 37 | this.weight = weight 38 | true 39 | } 40 | 41 | def getWeight(): Vector = { 42 | require(n > 0, s"you must set n before useing GlabelWeight, temp n is ${n}") 43 | this.weight 44 | } 45 | 46 | def updateWeight( 47 | weightsOld: Vector, 48 | gradient: Vector, 49 | stepSize: Double, 50 | iter: Int, 51 | regParam: Double, convergenceTol: Double): (Boolean, Boolean) = { 52 | // add up both updates from the gradient of the loss (= step) as well as 53 | // the gradient of the regularizer (= regParam * weightsOld) 54 | // w' = w - thisIterStepSize * (gradient + regParam * w) 55 | // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient 56 | 57 | val thisIterStepSize = stepSize / math.sqrt(iter) 58 | val brzWeights: BV[Double] = weightsOld.asBreeze.toDenseVector 59 | brzWeights :*= (1.0 - thisIterStepSize * regParam) 60 | brzAxpy(-thisIterStepSize, gradient.asBreeze, brzWeights) 61 | val norm = brzNorm(brzWeights, 2.0) 62 | val currentWeight = Vectors.fromBreeze(brzWeights) 63 | val flag = isConverged(weightsOld, currentWeight, convergenceTol) 64 | this.weight = Vectors.fromBreeze(brzWeights) 65 | this.regVal = 0.5 * regParam * norm * norm 66 | (true, flag) 67 | } 68 | 69 | private def isConverged( 70 | previousWeights: Vector, 71 | currentWeights: Vector, 72 | convergenceTol: Double): Boolean = { 73 | // To compare with convergence tolerance. 74 | 75 | /** 76 | * maybe there a problems with previousWeights.toDense 77 | * the before version is previousWeights.asBreeze.toDenseVector 78 | */ 79 | val previousBDV = previousWeights.asBreeze.toDenseVector 80 | val currentBDV = currentWeights.asBreeze.toDenseVector 81 | 82 | // This represents the difference of updated weights in the iteration. 
83 | val solutionVecDiff: Double = brzNorm(previousBDV - currentBDV) 84 | 85 | solutionVecDiff < convergenceTol * Math.max(brzNorm(currentBDV), 1.0) 86 | } 87 | } 88 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/Client.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core 2 | 3 | import java.io.File 4 | import java.util.concurrent.TimeUnit 5 | 6 | import akka.actor.{Actor, ActorLogging, ActorSystem, Props, _} 7 | import akka.pattern.ask 8 | import akka.remote.RemoteScope 9 | import akka.util.Timeout 10 | import com.typesafe.config.{Config, ConfigFactory} 11 | import com.typesafe.scalalogging.slf4j.StrictLogging 12 | import org.apache.spark.asyspark.core.messages.master.{RegisterClient, ServerList} 13 | import org.apache.spark.asyspark.core.models.client.BigVector 14 | import org.apache.spark.asyspark.core.models.client.asyImp.{AsyBigVectorDouble, AsyBigVectorFloat, AsyBigVectorInt, AsyBigVectorLong} 15 | import org.apache.spark.asyspark.core.models.server.{PartialVectorDouble, PartialVectorFloat, PartialVectorInt, PartialVectorLong} 16 | import org.apache.spark.asyspark.core.partitions.range.RangePartitioner 17 | import org.apache.spark.asyspark.core.partitions.{Partition, Partitioner} 18 | 19 | import scala.concurrent.duration._ 20 | import scala.concurrent.{Await, ExecutionContext, Future} 21 | import scala.reflect.runtime.universe.{TypeTag, typeOf} 22 | /** 23 | * The client provides the functions needed to spawn large distributed matrices and vectors on the parameter servers. 24 | * Use the companion object to construct a Client object from a configuration file. 25 | * Created by wjf on 16-9-24. 26 | */ 27 | class Client(val config: Config, private[asyspark] val system: ActorSystem, 28 | private[asyspark] val master: ActorRef) { 29 | private implicit val timeout = Timeout(config.getDuration("asyspark.client.timeout", TimeUnit.MILLISECONDS) milliseconds) 30 | private implicit val ec = ExecutionContext.Implicits.global 31 | private[asyspark] val actor = system.actorOf(Props[ClientActor]) 32 | // use ask to get a reply 33 | private[asyspark] val registration = master ? 
RegisterClient(actor) 34 | 35 | 36 | 37 | 38 | /** 39 | * Creates a distributed model on the parameter servers 40 | * 41 | * @param keys The total number of keys 42 | * @param modelsPerServer The number of models to spawn per parameter server 43 | * @param createPartitioner A function that creates a partitioner based on a number of keys and partitions 44 | * @param generateServerProp A function that generates a server prop of a partial model for a particular partition 45 | * @param generateClientObject A function that generates a client object based on the partitioner and spawned models 46 | * @tparam M The final model type to generate 47 | * @return The generated model 48 | */ 49 | private def create[M](keys: Long, 50 | modelsPerServer: Int, 51 | createPartitioner: (Int, Long) => Partitioner, 52 | generateServerProp: Partition => Props, 53 | generateClientObject: (Partitioner, Array[ActorRef], Config) => M): M = { 54 | 55 | // Get a list of servers 56 | val listOfServers = serverList() 57 | println(listOfServers.isCompleted) 58 | 59 | // Construct a big model based on the list of servers 60 | val bigModelFuture = listOfServers.map { servers => 61 | 62 | // Check if there are servers online 63 | if (servers.isEmpty) { 64 | throw new Exception("Cannot create a model without active parameter servers") 65 | } 66 | 67 | // Construct a partitioner 68 | val numberOfPartitions = Math.min(keys, modelsPerServer * servers.length).toInt 69 | val partitioner = createPartitioner(numberOfPartitions, keys) 70 | val partitions = partitioner.all() 71 | 72 | // Spawn models that are deployed on the parameter servers according to the partitioner 73 | val models = new Array[ActorRef](numberOfPartitions) 74 | var partitionIndex = 0 75 | while (partitionIndex < numberOfPartitions) { 76 | val serverIndex = partitionIndex % servers.length 77 | val server = servers(serverIndex) 78 | val partition = partitions(partitionIndex) 79 | val prop = generateServerProp(partition) 80 | models(partitionIndex) = system.actorOf(prop.withDeploy(Deploy(scope = RemoteScope(server.path.address)))) 81 | partitionIndex += 1 82 | } 83 | 84 | // Construct a big model client object 85 | generateClientObject(partitioner, models, config) 86 | } 87 | 88 | // Wait for the big model to finish 89 | Await.result(bigModelFuture, config.getDuration("asyspark.client.timeout", TimeUnit.MILLISECONDS) milliseconds) 90 | 91 | } 92 | def bigVector[V: breeze.math.Semiring: TypeTag](keys: Long, modelsPerServer: Int = 1): BigVector[V] = { 93 | bigVector[V](keys, modelsPerServer, (partitions: Int, keys: Long) => RangePartitioner(partitions, keys)) 94 | } 95 | def bigVector[V: breeze.math.Semiring: TypeTag](keys: Long, modelsPerServer: Int, 96 | createPartitioner: (Int, Long) => Partitioner): BigVector[V] = { 97 | 98 | val propFunction = numberType[V] match { 99 | case "Int" => (partition: Partition) => Props(classOf[PartialVectorInt], partition) 100 | case "Long" => (partition: Partition) => Props(classOf[PartialVectorLong], partition) 101 | case "Flaot" => (partition: Partition) => Props(classOf[PartialVectorFloat], partition) 102 | case "Double" => (partition: Partition) => Props(classOf[PartialVectorDouble], partition) 103 | case x => 104 | throw new Exception(s"cannot create model for unsupported value tupe") 105 | 106 | } 107 | 108 | val objFunction = numberType[V] match { 109 | case "Int" => (partitioner: Partitioner, models: Array[ActorRef], config: Config) => 110 | new AsyBigVectorInt(partitioner, models, config, keys).asInstanceOf[BigVector[V]] 
111 | case "Long" => (partitioner: Partitioner, models: Array[ActorRef], config: Config) => 112 | new AsyBigVectorLong(partitioner, models, config, keys).asInstanceOf[BigVector[V]] 113 | case "Float" => (partitioner: Partitioner, models: Array[ActorRef], config: Config) => 114 | new AsyBigVectorFloat(partitioner, models, config, keys).asInstanceOf[BigVector[V]] 115 | case "Double" => (partitioner: Partitioner, models: Array[ActorRef], config: Config) => 116 | new AsyBigVectorDouble(partitioner, models, config, keys).asInstanceOf[BigVector[V]] 117 | case x => throw new Exception(s"Cannot create model for unsupported value type $x") 118 | } 119 | create[BigVector[V]](keys, modelsPerServer, createPartitioner, propFunction, objFunction) 120 | 121 | 122 | 123 | } 124 | 125 | def serverList(): Future[Array[ActorRef]] = { 126 | (master ? ServerList()).mapTo[Array[ActorRef]] 127 | } 128 | 129 | def numberType[V: TypeTag]: String = { 130 | implicitly[TypeTag[V]].tpe match { 131 | case x if x <:< typeOf[Int] => "Int" 132 | case x if x <:< typeOf[Long] => "Long" 133 | case x if x <:< typeOf[Float] => "Float" 134 | case x if x <:< typeOf[Double] => "Double" 135 | case x => s"${x.toString}" 136 | } 137 | } 138 | 139 | /** 140 | * Stops the client 141 | */ 142 | def stop(): Unit ={ 143 | system.terminate() 144 | } 145 | 146 | } 147 | 148 | /** 149 | * Contains functions to easily create a client object that is connected to the asyspark cluster. 150 | * 151 | * You can construct a client with a specific configuration: 152 | * {{{ 153 | * import asyspark.Client 154 | * 155 | * import java.io.File 156 | * import com.typesafe.config.ConfigFactory 157 | * 158 | * val config = ConfigFactory.parseFile(new File("/your/file.conf")) 159 | * val client = Client(config) 160 | * }}} 161 | * 162 | * The resulting client object can then be used to create distributed matrices or vectors on the available parameter 163 | * servers: 164 | * {{{ 165 | * val matrix = client.matrix[Double](10000, 50) 166 | * }}} 167 | */ 168 | object Client extends StrictLogging { 169 | 170 | /** 171 | * Constructs a client with the default configuration 172 | * 173 | * @return The client 174 | */ 175 | def apply(): Client = { 176 | this(ConfigFactory.empty()) 177 | } 178 | 179 | /** 180 | * Constructs a client 181 | * 182 | * @param config The configuration 183 | * @return A future Client 184 | */ 185 | def apply(config: Config): Client = { 186 | val default = ConfigFactory.parseResourcesAnySyntax("asyspark") 187 | val conf = config.withFallback(default).resolve() 188 | Await.result(start(conf), conf.getDuration("asyspark.client.timeout", TimeUnit.MILLISECONDS) milliseconds) 189 | } 190 | 191 | /** 192 | * Implementation to start a client by constructing an ActorSystem and establishing a connection to a master. 
It 193 | * creates the Client object and checks if its registration actually succeeds 194 | * 195 | * @param config The configuration 196 | * @return The future client 197 | */ 198 | private def start(config: Config): Future[Client] = { 199 | 200 | // Get information from config 201 | logger.debug("start client") 202 | val masterHost = config.getString("asyspark.master.host") 203 | val masterPort = config.getInt("asyspark.master.port") 204 | println(masterPort) 205 | val masterName = config.getString("asyspark.master.name") 206 | val masterSystem = config.getString("asyspark.master.system") 207 | 208 | // Construct system and reference to master 209 | val system = ActorSystem(config.getString("asyspark.client.system"), config.getConfig("asyspark.client")) 210 | val master = system.actorSelection(s"akka.tcp://${masterSystem}@${masterHost}:${masterPort}/user/${masterName}") 211 | 212 | // Set up implicit values for concurrency 213 | implicit val ec = ExecutionContext.Implicits.global 214 | implicit val timeout = Timeout(config.getDuration("asyspark.client.timeout", TimeUnit.MILLISECONDS) milliseconds) 215 | 216 | // Resolve master node asynchronously 217 | val masterFuture = master.resolveOne() 218 | 219 | // Construct client based on resolved master asynchronously 220 | masterFuture.flatMap { 221 | case m => 222 | val client = new Client(config, system, m) 223 | logger.debug("construct a client") 224 | client.registration.map { 225 | case true => client 226 | case _ => throw new RuntimeException("Invalid client registration response from master") 227 | } 228 | } 229 | } 230 | 231 | 232 | } 233 | 234 | /** 235 | * The client actor class. The master keeps a death watch on this actor and knows when it is terminated. 236 | * 237 | * This actor either gets terminated when the system shuts down (e.g. when the Client object is destroyed) or when it 238 | * crashes unexpectedly. 239 | */ 240 | private class ClientActor extends Actor with ActorLogging { 241 | override def receive: Receive = { 242 | case x => log.info(s"Client actor received message ${x}") 243 | } 244 | } 245 | 246 | object TestClient { 247 | def main(args: Array[String]): Unit = { 248 | 249 | val default = ConfigFactory.parseResourcesAnySyntax("asyspark") 250 | val config = ConfigFactory.parseFile(new File(getClass.getClassLoader.getResource("asyspark.conf").getFile)).withFallback(default).resolve() 251 | val client =Client(config) 252 | val vector = client.bigVector[Long](2) 253 | val keys = Array[Long](1,2) 254 | val values = Array[Long](1,2) 255 | // vector.push() 256 | println(vector.size) 257 | 258 | } 259 | 260 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/Exceptios/PullFailedException.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.Exceptios 2 | 3 | /** 4 | * Created by wjf on 16-9-25. 5 | */ 6 | private[asyspark] class PullFailedException(message: String) extends Exception(message){ 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/Exceptios/PushFailedException.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.Exceptios 2 | 3 | /** 4 | * Created by wjf on 16-9-25. 
5 | */ 6 | class PushFailedException(message: String) extends Exception(message) { 7 | 8 | } 9 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/Main.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core 2 | 3 | import java.io.File 4 | 5 | import com.typesafe.config.ConfigFactory 6 | import com.typesafe.scalalogging.slf4j.StrictLogging 7 | import org.apache.log4j.BasicConfigurator 8 | 9 | import scala.concurrent.ExecutionContext 10 | 11 | /** 12 | * This is the main class that runs when you start asyspark. By manually specifying additional command-line options it is 13 | * possible to start a master node or a parameter server. 14 | * 15 | * To start a master node: 16 | * {{{ 17 | * java -jar /path/to/compiled/asyspark.jar master -c /path/to/asyspark.conf 18 | * }}} 19 | * 20 | * To start a parameter server node: 21 | * {{{ 22 | * java -jar /path/to/compiled/asyspark.jar server -c /path/to/asyspark.conf 23 | * }}} 24 | * 25 | * Alternatively you can use the scripts provided in the ./sbin/ folder of the project to automatically construct a 26 | * master and servers over passwordless ssh. 27 | */ 28 | object Main extends StrictLogging { 29 | 30 | /** 31 | * Main entry point of the application 32 | * 33 | * @param args The command-line arguments 34 | */ 35 | def main(args: Array[String]): Unit = { 36 | // BasicConfigurator.configure() 37 | 38 | val parser = new scopt.OptionParser[Options]("asyspark") { 39 | head("asyspark", "0.1") 40 | opt[File]('c', "config") valueName "" action { (x, c) => 41 | c.copy(config = x) 42 | } text "The .conf file for asyspark" 43 | cmd("master") action { (_, c) => 44 | c.copy(mode = "master") 45 | } text "Starts a master node." 46 | cmd("server") action { (_, c) => 47 | c.copy(mode = "server") 48 | } text "Starts a server node." 
49 | } 50 | 51 | parser.parse(args, Options()) match { 52 | case Some(options) => 53 | 54 | // Read configuration 55 | logger.debug("Parsing configuration file") 56 | val default = ConfigFactory.parseResourcesAnySyntax("asyspark") 57 | val config = ConfigFactory.parseFile(options.config).withFallback(default).resolve() 58 | 59 | 60 | // Start specified mode of operation 61 | implicit val ec = ExecutionContext.Implicits.global 62 | options.mode match { 63 | case "server" => Server.run(config).onSuccess { 64 | case (system, ref) => sys.addShutdownHook { 65 | logger.info("Shutting down") 66 | system.shutdown() 67 | system.awaitTermination() 68 | } 69 | } 70 | case "master" => Master.run(config).onSuccess { 71 | case (system, ref) => sys.addShutdownHook { 72 | logger.info("Shutting down") 73 | system.shutdown() 74 | system.awaitTermination() 75 | } 76 | } 77 | case _ => 78 | parser.showUsageAsError 79 | System.exit(1) 80 | } 81 | case None => System.exit(1) 82 | } 83 | 84 | } 85 | 86 | /** 87 | * Command-line options 88 | * 89 | * @param mode The mode of operation (either "master" or "server") 90 | * @param config The configuration file to load (defaults to the included asyspark.conf) 91 | * @param host The host of the parameter server (only when mode of operation is "server") 92 | * @param port The port of the parameter server (only when mode of operation is "server") 93 | */ 94 | private case class Options(mode: String = "", 95 | config: File = new File(getClass.getClassLoader.getResource("asyspark.conf").getFile), 96 | host: String = "localhost", 97 | port: Int = 0) 98 | 99 | } 100 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/Master.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core 2 | 3 | import java.util.concurrent.TimeUnit 4 | 5 | import akka.actor.{Actor, ActorLogging, ActorRef, ActorSystem, Address, Props, Terminated} 6 | import akka.util.Timeout 7 | import com.typesafe.config.Config 8 | import com.typesafe.scalalogging.slf4j.StrictLogging 9 | import org.apache.spark.asyspark.core.messages.master.{ClientList, RegisterClient, RegisterServer, ServerList} 10 | 11 | import scala.concurrent.duration._ 12 | import scala.concurrent.{ExecutionContext, Future} 13 | 14 | 15 | /** 16 | * Created by wjf on 16-9-22. 17 | * the master that registers the spark workers 18 | */ 19 | class Master() extends Actor with ActorLogging { 20 | 21 | /** 22 | * collection of servers available 23 | */ 24 | var servers = Set.empty[ActorRef] 25 | 26 | /** 27 | * collection of clients 28 | */ 29 | var clients = Set.empty[ActorRef] 30 | 31 | override def receive: Receive = { 32 | case RegisterServer(server) => 33 | log.info(s"Registering server ${server.path.toString}") 34 | println("register server") 35 | servers += server 36 | context.watch(server) 37 | sender ! true 38 | 39 | case RegisterClient(client) => 40 | log.info(s"Registering client ${sender.path.toString}") 41 | clients += client 42 | context.watch(client) 43 | sender ! true 44 | 45 | case ServerList() => 46 | log.info(s"Sending current server list to ${sender.path.toString}") 47 | sender ! servers.toArray 48 | 49 | case ClientList() => 50 | log.info(s"Sending current client list to ${sender.path.toString}") 51 | sender ! 
clients.toArray 52 | 53 | 54 | case Terminated(actor) => 55 | actor match { 56 | case server: ActorRef if servers contains server => 57 | log.info(s"Removing server ${server.path.toString}") 58 | servers -= server 59 | case client: ActorRef if clients contains client => 60 | log.info(s"Removing client ${client.path.toString}") 61 | clients -= client 62 | case actor: ActorRef => 63 | log.warning(s"Actor ${actor.path.toString} will be terminated for some unknown reason") 64 | } 65 | } 66 | 67 | } 68 | 69 | object Master extends StrictLogging { 70 | def run(config: Config): Future[(ActorSystem, ActorRef)] = { 71 | logger.debug("Starting master actor system") 72 | val system = ActorSystem(config.getString("asyspark.master.system"), config.getConfig("asyspark.master")) 73 | logger.debug("Starting master") 74 | val master = system.actorOf(Props[Master], config.getString("asyspark.master.name")) 75 | implicit val timeout = Timeout(config.getDuration("asyspark.master.startup-timeout", TimeUnit.MILLISECONDS) milliseconds) 76 | implicit val ec = ExecutionContext.Implicits.global 77 | val address = Address("akka.tcp", config.getString("asyspark.master.system"), config.getString("asyspark.master.host"), 78 | config.getString("asyspark.master.port").toInt) 79 | system.actorSelection(master.path.toSerializationFormat).resolveOne().map { 80 | case actor: ActorRef => 81 | logger.debug("Master successfully started") 82 | (system, master) 83 | 84 | } 85 | } 86 | 87 | } 88 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/Server.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core 2 | 3 | /** 4 | * Created by wjf on 16-9-22. 
5 | */ 6 | 7 | import java.util.concurrent.TimeUnit 8 | 9 | import akka.actor._ 10 | import akka.pattern.ask 11 | import akka.util.Timeout 12 | import com.typesafe.config.Config 13 | import com.typesafe.scalalogging.slf4j.StrictLogging 14 | import org.apache.spark.asyspark.core.messages.master.RegisterServer 15 | 16 | import scala.concurrent.duration._ 17 | import scala.concurrent.{ExecutionContext, Future} 18 | 19 | /** 20 | * A parameter server 21 | */ 22 | private class Server extends Actor with ActorLogging { 23 | 24 | override def receive: Receive = { 25 | case x => 26 | log.warning(s"Received unknown message of type ${x.getClass}") 27 | } 28 | } 29 | /** 30 | * The parameter server object 31 | */ 32 | private object Server extends StrictLogging { 33 | 34 | /** 35 | * Starts a parameter server ready to receive commands 36 | * 37 | * @param config The configuration 38 | * @return A future containing the started actor system and reference to the server actor 39 | */ 40 | def run(config: Config): Future[(ActorSystem, ActorRef)] = { 41 | println("server run") 42 | 43 | 44 | logger.debug(s"Starting actor system ${config.getString("asyspark.server.system")}") 45 | val system = ActorSystem(config.getString("asyspark.server.system"), config.getConfig("asyspark.server")) 46 | 47 | logger.debug("Starting server actor") 48 | val server = system.actorOf(Props[Server], config.getString("asyspark.server.name")) 49 | println(server.path) 50 | 51 | logger.debug("Reading master information from config") 52 | val masterHost = config.getString("asyspark.master.host") 53 | val masterPort = config.getInt("asyspark.master.port") 54 | val masterName = config.getString("asyspark.master.name") 55 | val masterSystem = config.getString("asyspark.master.system") 56 | 57 | logger.info(s"Registering with master ${masterSystem}@${masterHost}:${masterPort}/user/${masterName}") 58 | implicit val ec = ExecutionContext.Implicits.global 59 | implicit val timeout = Timeout(config.getDuration("asyspark.server.registration-timeout", TimeUnit.MILLISECONDS) milliseconds) 60 | val master = system.actorSelection(s"akka.tcp://${masterSystem}@${masterHost}:${masterPort}/user/${masterName}") 61 | val registration = master ? RegisterServer(server) 62 | 63 | 64 | 65 | registration.map { 66 | case a => 67 | logger.info("Server successfully registered with master") 68 | (system, server) 69 | } 70 | 71 | } 72 | } 73 | 74 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/master/ClientList.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.master 2 | 3 | /** 4 | * Created by wjf on 16-9-24. 5 | */ 6 | private[asyspark] case class ClientList() 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/master/RegisterClient.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.master 2 | 3 | import akka.actor.ActorRef 4 | 5 | /** 6 | * Created by wjf on 16-9-24. 
7 | */ 8 | private[asyspark] case class RegisterClient(client: ActorRef) 9 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/master/RegisterServer.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.master 2 | 3 | import akka.actor.ActorRef 4 | 5 | /** 6 | * Created by wjf on 16-9-22. 7 | */ 8 | private[asyspark] case class RegisterServer(server: ActorRef) 9 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/master/ServerList.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.master 2 | 3 | /** 4 | * Created by wjf on 16-9-22. 5 | */ 6 | private[asyspark] case class ServerList() 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/AcknowledgeReceipt.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.logic 2 | 3 | /** 4 | * Created by wjf on 16-9-25. 5 | */ 6 | private[asyspark]case class AcknowledgeReceipt(id: Int) 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/Forget.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.logic 2 | 3 | /** 4 | * Created by wjf on 16-9-25. 5 | */ 6 | private[asyspark] case class Forget(id: Int) 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/GenerateUniqueID.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.logic 2 | 3 | /** 4 | * Created by wjf on 16-9-25. 5 | */ 6 | private[asyspark] case class GenerateUniqueID() -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/NotAcknowledgeReceipt.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.logic 2 | 3 | /** 4 | * Created by wjf on 16-9-25. 5 | */ 6 | private[asyspark] case class NotAcknowledgeReceipt(id: Long) 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/UniqueID.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.logic 2 | 3 | /** 4 | * Created by wjf on 16-9-25. 5 | */ 6 | private[asyspark] case class UniqueID(id: Int) 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PullVector.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.request 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 
5 | */ 6 | private[asyspark] case class PullVector(keys: Array[Long]) extends Request 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PushVectorDouble.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.request 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] case class PushVectorDouble(id: Int, keys: Array[Long], values: Array[Double]) extends Request 7 | 8 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PushVectorFloat.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.request 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] case class PushVectorFloat(id: Int, keys:Array[Long], values:Array[Float]) extends Request 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PushVectorInt.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.request 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] case class PushVectorInt(id: Int, keys: Array[Long], values: Array[Int]) extends Request 7 | 8 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PushVectorLong.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.request 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] case class PushVectorLong(id: Int, keys: Array[Long], values:Array[Long]) extends Request 7 | 8 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/request/Request.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.request 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] trait Request 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/response/Response.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.response 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] trait Response 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/response/ResponseDouble.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.response 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 
5 | */ 6 | private[asyspark] case class ResponseDouble(values: Array[Double]) extends Response 7 | 8 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/response/ResponseFloat.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.response 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] case class ResponseFloat(values: Array[Float]) extends Response 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/response/ResponseInt.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.response 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] case class ResponseInt(values: Array[Int]) extends Response 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/messages/server/response/ResponseLong.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.messages.server.response 2 | 3 | /** 4 | * Created by wjf on 16-9-23. 5 | */ 6 | private[asyspark] case class ResponseLong(values: Array[Long]) extends Response 7 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/BigVector.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client 2 | 3 | import scala.concurrent.{ExecutionContext, Future} 4 | 5 | /** 6 | * A big vector supporting basic parameter server element-wise operations 7 | * {{{ 8 | * val vector:BigVector[Int] = ... 9 | * vector.pull() 10 | * vector.push() 11 | * vector.destroy() 12 | * }}} 13 | * Created by wjf on 16-9-25. 
14 | */ 15 | trait BigVector[V] extends Serializable { 16 | val size: Long 17 | 18 | 19 | /** 20 | * Pull a set of elements 21 | * @param keys The Indices of the elements 22 | * @param ec The implicit execution context in which to execute the request 23 | * @return A Future containing the values of the elements at given rows, columns 24 | */ 25 | //TODO it's convenient fo use scala.concurrent.ExecutionContext.Implicits.global 26 | // but we should do some optimization in production env 27 | def pull(keys: Array[Long])(implicit ec:ExecutionContext): Future[Array[V]] 28 | 29 | /** 30 | * Push a set of elements 31 | * @param keys The indices of the rows 32 | * @param values The values to update 33 | * @param ec The Implicit execution context 34 | * @return A future 35 | */ 36 | def push(keys: Array[Long], values: Array[V])(implicit ec: ExecutionContext): Future[Boolean] 37 | 38 | /** 39 | * Destroy the big vectors and its resources on the parameter server 40 | * @param ec The implicit execution context in which to execute the request 41 | * @return A future 42 | */ 43 | def destroy()(implicit ec: ExecutionContext): Future[Boolean] 44 | } 45 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVector.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client.asyImp 2 | 3 | import java.io.{IOException, ObjectInputStream, ObjectOutputStream} 4 | 5 | import akka.actor.{ActorRef, ExtendedActorSystem} 6 | import akka.pattern.Patterns.gracefulStop 7 | import akka.serialization.JavaSerializer 8 | import breeze.linalg.DenseVector 9 | import breeze.math.Semiring 10 | import com.typesafe.config.Config 11 | import org.apache.spark.asyspark.core.messages.server.request.PullVector 12 | import org.apache.spark.asyspark.core.models.client.BigVector 13 | import org.apache.spark.asyspark.core.partitions.{Partition, Partitioner} 14 | import spire.implicits._ 15 | 16 | import scala.concurrent.duration._ 17 | import scala.concurrent.{ExecutionContext, Future} 18 | import scala.reflect.ClassTag 19 | 20 | /** 21 | * An asynchronous implementation of the [[glint.models.client.BigVector BigVector]]. 
You don't want to construct this 22 | * object manually but instead use the methods provided in [[glint.Client Client]], as so 23 | * {{{ 24 | * client.vector[Double](keys) 25 | * }}} 26 | * 27 | * @param partitioner A partitioner to map keys to parameter servers 28 | * @param models The partial models on the parameter servers 29 | * @param size The number of keys 30 | * @tparam V The type of values to store 31 | * @tparam R The type of responses we expect to get from the parameter servers 32 | * @tparam P The type of push requests we should send to the parameter servers 33 | */ 34 | abstract class AsyBigVector[@specialized V: Semiring : ClassTag, R: ClassTag, P: ClassTag](partitioner: Partitioner, 35 | models: Array[ActorRef], 36 | config: Config, 37 | val size: Long) 38 | extends BigVector[V] { 39 | 40 | PullFSM.initialize(config) 41 | PushFSM.initialize(config) 42 | 43 | /** 44 | * Pulls a set of elements 45 | * 46 | * @param keys The indices of the keys 47 | * @return A future containing the values of the elements at given rows, columns 48 | */ 49 | override def pull(keys: Array[Long])(implicit ec: ExecutionContext): Future[Array[V]] = { 50 | 51 | // Send pull request of the list of keys 52 | val pulls = mapPartitions(keys) { 53 | case (partition, indices) => 54 | val pullMessage = PullVector(indices.map(keys).toArray) 55 | val fsm = PullFSM[PullVector, R](pullMessage, models(partition.index)) 56 | fsm.run() 57 | //(models(partition.index) ? pullMessage).mapTo[R] 58 | } 59 | 60 | // Obtain key indices after partitioning so we can place the results in a correctly ordered array 61 | val indices = keys.zipWithIndex.groupBy { 62 | case (k, i) => partitioner.partition(k) 63 | }.map { 64 | case (_, arr) => arr.map(_._2) 65 | }.toArray 66 | 67 | // Define aggregator for successful responses 68 | def aggregateSuccess(responses: Iterable[R]): Array[V] = { 69 | val responsesArray = responses.toArray 70 | val result = DenseVector.zeros[V](keys.length) 71 | cforRange(0 until responsesArray.length)(i => { 72 | val response = responsesArray(i) 73 | val idx = indices(i) 74 | cforRange(0 until idx.length)(j => { 75 | result(idx(j)) = toValue(response, j) 76 | }) 77 | }) 78 | result.toArray 79 | } 80 | 81 | // Combine and aggregate futures 82 | Future.sequence(pulls).transform(aggregateSuccess, err => err) 83 | 84 | } 85 | 86 | /** 87 | * Groups indices of given keys into partitions according to this models partitioner and maps each partition to a 88 | * type T 89 | * 90 | * @param keys The keys to partition 91 | * @param func The function that takes a partition and corresponding indices and creates something of type T 92 | * @tparam T The type to map to 93 | * @return An iterable over the partitioned results 94 | */ 95 | @inline 96 | private def mapPartitions[T](keys: Seq[Long])(func: (Partition, Seq[Int]) => T): Iterable[T] = { 97 | keys.indices.groupBy(i => partitioner.partition(keys(i))).map { case (a, b) => func(a, b) } 98 | } 99 | 100 | /** 101 | * Pushes a set of values 102 | * 103 | * @param keys The indices of the keys 104 | * @param values The values to update 105 | * @return A future containing either the success or failure of the operation 106 | */ 107 | override def push(keys: Array[Long], values: Array[V])(implicit ec: ExecutionContext): Future[Boolean] = { 108 | 109 | // Send push requests 110 | val pushes = mapPartitions(keys) { 111 | case (partition, indices) => 112 | val ks = indices.map(keys).toArray 113 | val vs = indices.map(values).toArray 114 | val fsm = PushFSM[P]((id) => 
toPushMessage(id, ks, vs), models(partition.index)) 115 | fsm.run() 116 | } 117 | 118 | // Combine and aggregate futures 119 | Future.sequence(pushes).transform(results => true, err => err) 120 | 121 | } 122 | 123 | /** 124 | * @return The number of partitions this big vector's data is spread across 125 | */ 126 | def nrOfPartitions: Int = { 127 | partitioner.all().length 128 | } 129 | 130 | /** 131 | * Destroys the matrix on the parameter servers 132 | * 133 | * @return A future whether the matrix was successfully destroyed 134 | */ 135 | override def destroy()(implicit ec: ExecutionContext): Future[Boolean] = { 136 | val partitionFutures = partitioner.all().map { 137 | case partition => gracefulStop(models(partition.index), 60 seconds) 138 | }.toIterator 139 | Future.sequence(partitionFutures).transform(successes => successes.forall(success => success), err => err) 140 | } 141 | 142 | /** 143 | * Extracts a value from a given response at given index 144 | * 145 | * @param response The response 146 | * @param index The index 147 | * @return The value 148 | */ 149 | @inline 150 | protected def toValue(response: R, index: Int): V 151 | 152 | /** 153 | * Creates a push message from given sequence of keys and values 154 | * 155 | * @param id The identifier 156 | * @param keys The rows 157 | * @param values The values 158 | * @return A PushMatrix message for type V 159 | */ 160 | @inline 161 | protected def toPushMessage(id: Int, keys: Array[Long], values: Array[V]): P 162 | 163 | /** 164 | * Deserializes this instance. This starts an ActorSystem with appropriate configuration before attempting to 165 | * deserialize ActorRefs 166 | * 167 | * @param in The object input stream 168 | * @throws java.io.IOException A possible Input/Output exception 169 | */ 170 | @throws(classOf[IOException]) 171 | private def readObject(in: ObjectInputStream): Unit = { 172 | val config = in.readObject().asInstanceOf[Config] 173 | val as = DeserializationHelper.getActorSystem(config.getConfig("glint.client")) 174 | JavaSerializer.currentSystem.withValue(as.asInstanceOf[ExtendedActorSystem]) { 175 | in.defaultReadObject() 176 | } 177 | } 178 | 179 | /** 180 | * Serializes this instance. 
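The `readObject`/`writeObject` pair defined here is what allows a big vector handle to travel inside a Spark closure: each executor re-creates a client `ActorSystem` on deserialization before the `ActorRef`s are restored. A hedged sketch of that usage pattern (the RDD shape, key and gradient names are hypothetical, and blocking with `Await` is only for illustration):

```scala
import scala.concurrent.Await
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

import org.apache.spark.rdd.RDD
import org.apache.spark.asyspark.core.models.client.BigVector

object PushFromExecutors {
  // Hypothetical: push per-record gradient updates from executors into the
  // parameter servers. `vector` is captured in the closure and deserialized
  // on each executor via the readObject hook above.
  def pushGradients(data: RDD[(Array[Long], Array[Double])],
                    vector: BigVector[Double]): Unit = {
    data.foreachPartition { records =>
      records.foreach { case (keys, gradients) =>
        Await.result(vector.push(keys, gradients), 30.seconds)
      }
    }
  }
}
```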
This first writes the config before the entire object to ensure we can deserialize with 181 | * an ActorSystem in place 182 | * 183 | * @param out The object output stream 184 | * @throws java.io.IOException A possible Input/Output exception 185 | */ 186 | @throws(classOf[IOException]) 187 | private def writeObject(out: ObjectOutputStream): Unit = { 188 | out.writeObject(config) 189 | out.defaultWriteObject() 190 | } 191 | 192 | } 193 | 194 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVectorDouble.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client.asyImp 2 | import akka.actor.ActorRef 3 | import com.typesafe.config.Config 4 | import org.apache.spark.asyspark.core.messages.server.request.PushVectorDouble 5 | import org.apache.spark.asyspark.core.messages.server.response.ResponseDouble 6 | import org.apache.spark.asyspark.core.partitions.Partitioner 7 | 8 | /** 9 | * Asynchronous implementation of a BigVector for doubles 10 | */ 11 | private[asyspark] class AsyBigVectorDouble(partitioner: Partitioner, 12 | models: Array[ActorRef], 13 | config: Config, 14 | keys: Long) 15 | extends AsyBigVector[Double, ResponseDouble, PushVectorDouble](partitioner, models, config, keys) { 16 | 17 | /** 18 | * Creates a push message from given sequence of keys and values 19 | * 20 | * @param id The identifier 21 | * @param keys The keys 22 | * @param values The values 23 | * @return A PushVectorDouble message for type V 24 | */ 25 | @inline 26 | override protected def toPushMessage(id: Int, keys: Array[Long], values: Array[Double]): PushVectorDouble = { 27 | PushVectorDouble(id, keys, values) 28 | } 29 | 30 | /** 31 | * Extracts a value from a given response at given index 32 | * 33 | * @param response The response 34 | * @param index The index 35 | * @return The value 36 | */ 37 | @inline 38 | override protected def toValue(response: ResponseDouble, index: Int): Double = response.values(index) 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVectorFloat.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client.asyImp 2 | 3 | 4 | import akka.actor.ActorRef 5 | import com.typesafe.config.Config 6 | import org.apache.spark.asyspark.core.messages.server.request.PushVectorFloat 7 | import org.apache.spark.asyspark.core.messages.server.response.ResponseFloat 8 | import org.apache.spark.asyspark.core.partitions.Partitioner 9 | 10 | /** 11 | * Asynchronous implementation of a BigVector for floats 12 | */ 13 | private[asyspark] class AsyBigVectorFloat(partitioner: Partitioner, 14 | models: Array[ActorRef], 15 | config: Config, 16 | keys: Long) 17 | extends AsyBigVector[Float, ResponseFloat, PushVectorFloat](partitioner, models, config, keys) { 18 | 19 | /** 20 | * Creates a push message from given sequence of keys and values 21 | * 22 | * @param id The identifier 23 | * @param keys The keys 24 | * @param values The values 25 | * @return A PushVectorFloat message for type V 26 | */ 27 | @inline 28 | override protected def toPushMessage(id: Int, keys: Array[Long], values: Array[Float]): PushVectorFloat = { 29 | PushVectorFloat(id, keys, values) 30 | } 31 | 32 | /** 33 | * Extracts a value from a given response 
at given index 34 | * 35 | * @param response The response 36 | * @param index The index 37 | * @return The value 38 | */ 39 | @inline 40 | override protected def toValue(response: ResponseFloat, index: Int): Float = response.values(index) 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVectorInt.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client.asyImp 2 | 3 | 4 | import akka.actor.ActorRef 5 | import com.typesafe.config.Config 6 | import org.apache.spark.asyspark.core.messages.server.request.PushVectorInt 7 | import org.apache.spark.asyspark.core.messages.server.response.ResponseInt 8 | import org.apache.spark.asyspark.core.partitions.Partitioner 9 | 10 | /** 11 | * Asynchronous implementation of a BigVector for integers 12 | */ 13 | private[asyspark] class AsyBigVectorInt(partitioner: Partitioner, 14 | models: Array[ActorRef], 15 | config: Config, 16 | keys: Long) 17 | extends AsyBigVector[Int, ResponseInt, PushVectorInt](partitioner, models, config, keys) { 18 | 19 | /** 20 | * Creates a push message from given sequence of keys and values 21 | * 22 | * @param id The identifier 23 | * @param keys The keys 24 | * @param values The values 25 | * @return A PushVectorInt message for type V 26 | */ 27 | @inline 28 | override protected def toPushMessage(id: Int, keys: Array[Long], values: Array[Int]): PushVectorInt = { 29 | PushVectorInt(id, keys, values) 30 | } 31 | 32 | /** 33 | * Extracts a value from a given response at given index 34 | * 35 | * @param response The response 36 | * @param index The index 37 | * @return The value 38 | */ 39 | @inline 40 | override protected def toValue(response: ResponseInt, index: Int): Int = response.values(index) 41 | 42 | } 43 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVectorLong.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client.asyImp 2 | 3 | import org.apache.spark.asyspark.core.partitions.Partitioner 4 | import akka.actor.ActorRef 5 | import com.typesafe.config.Config 6 | import org.apache.spark.asyspark.core.messages.server.request.PushVectorLong 7 | import org.apache.spark.asyspark.core.messages.server.response.ResponseLong 8 | /** 9 | * Created by wjf on 16-9-26. 
10 | */ 11 | private[asyspark] class AsyBigVectorLong(partitioner: Partitioner, 12 | models: Array[ActorRef], 13 | config: Config, 14 | keys: Long) 15 | extends AsyBigVector[Long, ResponseLong, PushVectorLong](partitioner, models, config, keys) { 16 | /** 17 | * Creates a push message from given sequence of keys and values 18 | * 19 | * @param id The identifier 20 | * @param keys The keys 21 | * @param values The values 22 | * @return A PushVectorLong message for type V 23 | */ 24 | @inline 25 | override protected def toPushMessage(id: Int, keys: Array[Long], values: Array[Long]): PushVectorLong = { 26 | PushVectorLong(id, keys, values) 27 | } 28 | 29 | /** 30 | * Extracts a value from a given response at given index 31 | * 32 | * @param response The response 33 | * @param index The index 34 | * @return The value 35 | */ 36 | @inline 37 | override protected def toValue(response: ResponseLong, index: Int): Long = response.values(index) 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/DeserializationHelper.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client.asyImp 2 | 3 | import akka.actor.ActorSystem 4 | import com.typesafe.config.Config 5 | 6 | /** Singleton pattern 7 | * Created by wjf on 16-9-26. 8 | */ 9 | object DeserializationHelper { 10 | @volatile private var as: ActorSystem = null 11 | 12 | def getActorSystem(config: Config): ActorSystem = { 13 | if(as == null) { 14 | DeserializationHelper.synchronized { 15 | if (as == null) { 16 | as = ActorSystem("asysparkClient", config) 17 | } 18 | } 19 | } 20 | as 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/PullFSM.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client.asyImp 2 | 3 | 4 | import java.util.concurrent.TimeUnit 5 | 6 | import akka.actor.ActorRef 7 | import akka.pattern.{AskTimeoutException, ask} 8 | import akka.util.Timeout 9 | import com.typesafe.config.Config 10 | import org.apache.spark.asyspark.core.Exceptios.PullFailedException 11 | 12 | import scala.concurrent.{ExecutionContext, Future, Promise} 13 | import scala.concurrent.duration._ 14 | import scala.reflect.ClassTag 15 | 16 | /** 17 | * Created by wjf on 16-9-25. 18 | */ 19 | class PullFSM[T, R: ClassTag](message: T, 20 | actorRef: ActorRef, 21 | maxAttempts: Int, 22 | initialTimeout: FiniteDuration, 23 | maxTimeout: FiniteDuration, 24 | backoffFactor: Double)(implicit ec: ExecutionContext) { 25 | private implicit var timeout: Timeout = new Timeout(initialTimeout) 26 | 27 | private var attempts = 0 28 | private var runflag =false 29 | private val promise: Promise[R] = Promise[R]() 30 | 31 | def run(): Future[R] ={ 32 | if(!runflag) { 33 | execute() 34 | runflag = true 35 | } 36 | promise.future 37 | } 38 | 39 | private def execute(): Unit ={ 40 | if(attempts < maxAttempts) { 41 | attempts += 1 42 | request() 43 | } else { 44 | promise.failure(new PullFailedException(s"Failed $attempts while the maxAttempts is $maxAttempts to pull data")) 45 | } 46 | } 47 | 48 | private def request(): Unit ={ 49 | val request = actorRef ? 
message 50 | request.onFailure { 51 | case ex: AskTimeoutException => 52 | timeBackoff() 53 | execute() 54 | case _ => 55 | execute() 56 | } 57 | request.onSuccess { 58 | case response: R => 59 | promise.success(response) 60 | case _ => 61 | execute() 62 | } 63 | } 64 | 65 | private def timeBackoff(): Unit ={ 66 | 67 | if(timeout.duration.toMillis * backoffFactor > maxTimeout.toMillis) { 68 | timeout = new Timeout(maxTimeout) 69 | } else { 70 | timeout = new Timeout(timeout.duration.toMillis * backoffFactor millis) 71 | } 72 | } 73 | 74 | 75 | 76 | } 77 | 78 | object PullFSM { 79 | 80 | private var maxAttempts: Int = 10 81 | private var initialTimeout: FiniteDuration = 5 seconds 82 | private var maxTimeout: FiniteDuration = 5 minutes 83 | private var backoffFactor: Double = 1.6 84 | 85 | /** 86 | * Initializes the FSM default parameters with those specified in given config 87 | * @param config The configuration to use 88 | */ 89 | def initialize(config: Config): Unit = { 90 | maxAttempts = config.getInt("asyspark.pull.maximum-attempts") 91 | initialTimeout = new FiniteDuration(config.getDuration("asyspark.pull.initial-timeout", TimeUnit.MILLISECONDS), 92 | TimeUnit.MILLISECONDS) 93 | maxTimeout = new FiniteDuration(config.getDuration("asyspark.pull.maximum-timeout", TimeUnit.MILLISECONDS), 94 | TimeUnit.MILLISECONDS) 95 | backoffFactor = config.getInt("asyspark.pull.backoff-multiplier") 96 | } 97 | 98 | /** 99 | * Constructs a new FSM for given message and actor 100 | * 101 | * @param message The pull message to send 102 | * @param actorRef The actor to send to 103 | * @param ec The execution context 104 | * @tparam T The type of message to send 105 | * @return An new and initialized PullFSM 106 | */ 107 | def apply[T, R : ClassTag](message: T, actorRef: ActorRef)(implicit ec: ExecutionContext): PullFSM[T, R] = { 108 | new PullFSM[T, R](message, actorRef, maxAttempts, initialTimeout, maxTimeout, backoffFactor) 109 | } 110 | 111 | } 112 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/PushFSM.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.client.asyImp 2 | 3 | import java.util.concurrent.TimeUnit 4 | 5 | import akka.actor.ActorRef 6 | import akka.pattern.{AskTimeoutException, ask} 7 | import akka.util.Timeout 8 | import com.typesafe.config.Config 9 | import org.apache.spark.asyspark.core.Exceptios.PushFailedException 10 | import org.apache.spark.asyspark.core.messages.server.logic._ 11 | 12 | import scala.concurrent.duration.{FiniteDuration, _} 13 | import scala.concurrent.{ExecutionContext, Future, Promise} 14 | 15 | 16 | /** 17 | * A push-mechanism using a finite state machine to guarantee exactly-once delivery with multiple attempts 18 | * 19 | * @param message A function that takes an identifier and generates a message of type T 20 | * @param actorRef The actor to send to 21 | * @param maxAttempts The maximum number of attempts 22 | * @param maxLogicAttempts The maximum number of attempts to establish communication through logic channels 23 | * @param initialTimeout The initial timeout for the request 24 | * @param maxTimeout The maximum timeout for the request 25 | * @param backoffFactor The backoff multiplier 26 | * @param ec The execution context 27 | * @tparam T The type of message to send 28 | * @author wjf created 16-9-25 29 | */ 30 | class PushFSM[T](message: Int => T, 31 | actorRef: 
ActorRef, 32 | maxAttempts: Int, 33 | maxLogicAttempts: Int, 34 | initialTimeout: FiniteDuration, 35 | maxTimeout: FiniteDuration, 36 | backoffFactor: Double)(implicit ec: ExecutionContext) { 37 | 38 | private implicit var timeout: Timeout = new Timeout(initialTimeout) 39 | 40 | /** 41 | * Unique identifier for this push 42 | */ 43 | private var id =0 44 | 45 | /** 46 | * Counter for the number of push attempts 47 | */ 48 | private var attempts = 0 49 | 50 | private var logicAttempts = 0 51 | 52 | /** 53 | * Flag to make sure only one request at some moment 54 | */ 55 | private var runFlag = false 56 | 57 | private val promise: Promise[Boolean] = Promise[Boolean]() 58 | 59 | /** 60 | * Run the push request 61 | * @return 62 | */ 63 | def run(): Future[Boolean] = { 64 | if(!runFlag) { 65 | prepare() 66 | runFlag = true 67 | } 68 | promise.future 69 | } 70 | 71 | /** 72 | * Prepare state 73 | * obtain a unique and available identifier from the parameter server for the next push request 74 | */ 75 | private def prepare(): Unit ={ 76 | val prepareFuture = actorRef ? GenerateUniqueID() 77 | prepareFuture.onSuccess { 78 | case UniqueID(identifier) => 79 | id = identifier 80 | execute() 81 | case _ => 82 | retry(prepare) 83 | } 84 | prepareFuture.onFailure { 85 | case ex: AskTimeoutException => 86 | timeBackoff() 87 | retry(prepare) 88 | case _ => 89 | retry(prepare) 90 | } 91 | } 92 | 93 | /** 94 | * Execute the push request performing a single push 95 | */ 96 | private def execute(): Unit = { 97 | if(attempts >= maxAttempts) { 98 | promise.failure(new PushFailedException(s"Failed ${attempts} is equal to maxAttempts ${maxAttempts} to push data")) 99 | } else { 100 | attempts += 1 101 | actorRef ! message(id) 102 | acknowledge() 103 | } 104 | } 105 | 106 | /** 107 | * Acknowledge state 108 | * We keep sending acknowledge messages until we either receive a acknowledge or notAcknowledge 109 | */ 110 | private def acknowledge(): Unit ={ 111 | val ackFuture = actorRef ? AcknowledgeReceipt(id) 112 | ackFuture.onSuccess { 113 | case NotAcknowledgeReceipt(identifier) if identifier == id => 114 | execute() 115 | case AcknowledgeReceipt(identifier) if identifier == id => 116 | promise.success(true) 117 | forget() 118 | case _ => 119 | retry(acknowledge) 120 | } 121 | ackFuture.onFailure { 122 | case ex:AskTimeoutException => 123 | timeBackoff() 124 | retry(acknowledge) 125 | case _ => 126 | retry(acknowledge) 127 | } 128 | 129 | } 130 | 131 | /** 132 | * Forget state 133 | * We keep sending forget messages until we receive a successful reply 134 | */ 135 | private def forget(): Unit ={ 136 | val forgetFuture = actorRef ? 
Forget(id) 137 | forgetFuture.onSuccess { 138 | case Forget(identifier) if identifier == id => 139 | () 140 | case _ => 141 | forget() 142 | } 143 | forgetFuture.onFailure { 144 | case ex: AskTimeoutException => 145 | timeBackoff() 146 | forget() 147 | case _ => 148 | forget() 149 | } 150 | 151 | } 152 | 153 | /** 154 | * Increase the timeout with an exponential backoff 155 | */ 156 | private def timeBackoff(): Unit ={ 157 | if(timeout.duration.toMillis * backoffFactor > maxTimeout.toMillis) { 158 | timeout = new Timeout(maxTimeout) 159 | } else { 160 | timeout = new Timeout(timeout.duration.toMillis * backoffFactor millis) 161 | } 162 | } 163 | 164 | /** 165 | * Retries a function while keeping track of a logic attempts counter and fails when the logic attemps counter is 166 | * too large 167 | * @param func The function to execute again 168 | */ 169 | private def retry(func: () => Unit): Unit ={ 170 | logicAttempts += 1 171 | if (logicAttempts < maxLogicAttempts) { 172 | func() 173 | } else { 174 | promise.failure(new PushFailedException(s"Failed $logicAttempts time while the maxLogicAttemps is $maxLogicAttempts ")) 175 | } 176 | } 177 | } 178 | object PushFSM { 179 | private var maxAttempts: Int = 10 180 | private var maxLogicAttempts: Int = 100 181 | private var initialTimeout: FiniteDuration = 5 seconds 182 | private var maxTimeout: FiniteDuration = 5 minutes 183 | private var backoffFactor: Double = 1.6 184 | 185 | /** 186 | * Initialize the FSM from the default config file 187 | * @param config 188 | */ 189 | def initialize(config: Config): Unit ={ 190 | maxAttempts = config.getInt("asyspark.push.maximum-attempts") 191 | maxLogicAttempts = config.getInt("asyspark.push.maximum-logic-attempts") 192 | initialTimeout = new FiniteDuration(config.getDuration("asyspark.push.initial-timeout", TimeUnit.MILLISECONDS), 193 | TimeUnit.MILLISECONDS 194 | ) 195 | maxTimeout = new FiniteDuration(config.getDuration("asyspark.push.maximum-timeout", TimeUnit.MICROSECONDS), 196 | TimeUnit.MICROSECONDS) 197 | backoffFactor = config.getDouble("asyspark.push.backoff-multiplier") 198 | } 199 | 200 | /** 201 | * construct a fsm 202 | * @param message 203 | * @param actorRef 204 | * @param ec 205 | * @tparam T 206 | * @return 207 | */ 208 | def apply[T](message: Int => T, actorRef: ActorRef)(implicit ec: ExecutionContext): PushFSM[T] = 209 | new PushFSM(message, actorRef, maxAttempts, maxLogicAttempts, initialTimeout, maxTimeout, backoffFactor) 210 | 211 | } 212 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVector.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.server 2 | 3 | import akka.actor.{Actor, ActorLogging} 4 | import spire.algebra.Semiring 5 | import spire.implicits._ 6 | import org.apache.spark.asyspark.core.partitions.Partition 7 | import scala.reflect.ClassTag 8 | 9 | /** 10 | * Created by wjf on 16-9-25. 
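Both FSMs are tuned through the `asyspark.pull.*` and `asyspark.push.*` keys read in their `initialize` methods. A minimal sketch of those settings as a parsed HOCON string; the values mirror the hard-coded defaults, except that the pull backoff multiplier is given as a whole number because `PullFSM.initialize` currently reads it with `getInt`:

```scala
import com.typesafe.config.ConfigFactory

import org.apache.spark.asyspark.core.models.client.asyImp.{PullFSM, PushFSM}

object FsmConfigExample {
  // Exactly the keys read by PullFSM.initialize and PushFSM.initialize.
  val fsmConfig = ConfigFactory.parseString(
    """
      |asyspark.pull.maximum-attempts = 10
      |asyspark.pull.initial-timeout = 5 seconds
      |asyspark.pull.maximum-timeout = 5 minutes
      |asyspark.pull.backoff-multiplier = 2
      |asyspark.push.maximum-attempts = 10
      |asyspark.push.maximum-logic-attempts = 100
      |asyspark.push.initial-timeout = 5 seconds
      |asyspark.push.maximum-timeout = 5 minutes
      |asyspark.push.backoff-multiplier = 1.6
      |""".stripMargin)

  def main(args: Array[String]): Unit = {
    PullFSM.initialize(fsmConfig)
    PushFSM.initialize(fsmConfig)
  }
}
```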
11 | */ 12 | private[asyspark] abstract class PartialVector[@specialized V:Semiring : ClassTag](partition: Partition) extends Actor 13 | with ActorLogging with PushLogic { 14 | 15 | /** 16 | * the size of the partial vector 17 | */ 18 | val size: Int = partition.size 19 | 20 | /** 21 | * the data array contains the elements 22 | */ 23 | val data: Array[V] 24 | 25 | /** 26 | * update the data of the partial model by aggregating given keys and values into it 27 | * @param keys The keys 28 | * @param values The values 29 | * @return 30 | */ 31 | //todo I thinks this imp can be optimized 32 | def update(keys: Array[Long], values: Array[V]): Boolean = { 33 | var i = 0 34 | try { 35 | while (i < keys.length) { 36 | val key = partition.globalToLocal(keys(i)) 37 | // this is imp with the help of spire.implicits._ 38 | data(key) += values(i) 39 | i += 1 40 | } 41 | true 42 | } catch { 43 | case e: Exception => false 44 | } 45 | } 46 | 47 | def get(keys: Array[Long]): Array[V] = { 48 | var i =0 49 | val a = new Array[V](keys.length) 50 | while(i < keys.length) { 51 | val key = partition.globalToLocal(keys(i)) 52 | a(i) = data(key) 53 | i += 1 54 | } 55 | a 56 | } 57 | 58 | log.info(s"Constructed PartialVector[${implicitly[ClassTag[V]]}] of size $size (partition id: ${partition.index})") 59 | 60 | } 61 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVectorDouble.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.server 2 | 3 | import org.apache.spark.asyspark.core.messages.server.request.{PullVector, PushVectorDouble} 4 | import org.apache.spark.asyspark.core.messages.server.response.ResponseDouble 5 | import org.apache.spark.asyspark.core.partitions.Partition 6 | import spire.implicits._ 7 | /** 8 | * Created by wjf on 16-9-25. 9 | */ 10 | private[asyspark] class PartialVectorDouble(partition: Partition) extends PartialVector[Double](partition) { 11 | override val data: Array[Double] = new Array[Double](size) 12 | override def receive: Receive = { 13 | case pull: PullVector => sender ! ResponseDouble(get(pull.keys)) 14 | case push: PushVectorDouble => 15 | update(push.keys, push.values) 16 | updateFinished(push.id) 17 | case x => handleLogic(x, sender) 18 | } 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVectorFloat.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.server 2 | 3 | import org.apache.spark.asyspark.core.messages.server.request.{PullVector, PushVectorFloat} 4 | import org.apache.spark.asyspark.core.messages.server.response.ResponseFloat 5 | import org.apache.spark.asyspark.core.partitions.Partition 6 | import spire.implicits._ 7 | /** 8 | * Created by wjf on 16-9-25. 9 | */ 10 | private[asyspark] class PartialVectorFloat(partition: Partition) extends PartialVector[Float](partition) { 11 | 12 | override val data: Array[Float] = new Array[Float](size) 13 | 14 | override def receive: Receive = { 15 | case pull: PullVector => sender ! 
ResponseFloat(get(pull.keys)) 16 | case push: PushVectorFloat => 17 | update(push.keys, push.values) 18 | updateFinished(push.id) 19 | case x => handleLogic(x, sender) 20 | } 21 | 22 | } 23 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVectorInt.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.server 2 | 3 | import org.apache.spark.asyspark.core.messages.server.request.{PullVector, PushVectorInt} 4 | import org.apache.spark.asyspark.core.messages.server.response.ResponseInt 5 | import org.apache.spark.asyspark.core.partitions.Partition 6 | import spire.implicits._ 7 | 8 | /** 9 | * Created by wjf on 16-9-25. 10 | */ 11 | private[asyspark] class PartialVectorInt(partition: Partition) extends PartialVector[Int](partition) { 12 | 13 | override val data: Array[Int] = new Array[Int](size) 14 | 15 | override def receive: Receive = { 16 | case pull: PullVector => sender ! ResponseInt(get(pull.keys)) 17 | case push: PushVectorInt => 18 | update(push.keys, push.values) 19 | updateFinished(push.id) 20 | case x => handleLogic(x, sender) 21 | } 22 | 23 | } 24 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVectorLong.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.server 2 | 3 | import org.apache.spark.asyspark.core.messages.server.request.{PullVector, PushVectorLong} 4 | import org.apache.spark.asyspark.core.messages.server.response.ResponseLong 5 | import org.apache.spark.asyspark.core.partitions.Partition 6 | import spire.implicits._ 7 | /** 8 | * Created by wjf on 16-9-25. 9 | */ 10 | private[asyspark] class PartialVectorLong(partition: Partition) extends PartialVector[Long](partition) { 11 | 12 | override val data: Array[Long] = new Array[Long](size) 13 | 14 | override def receive: Receive = { 15 | case pull: PullVector => sender ! ResponseLong(get(pull.keys)) 16 | case push: PushVectorLong => 17 | update(push.keys, push.values) 18 | updateFinished(push.id) 19 | case x => handleLogic(x, sender) 20 | } 21 | 22 | } 23 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/models/server/PushLogic.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.models.server 2 | 3 | import akka.actor.ActorRef 4 | import org.apache.spark.asyspark.core.messages.server.logic._ 5 | 6 | import scala.collection.mutable 7 | 8 | /** Some common push logic behavior 9 | * Created by wjf on 16-9-25. 10 | */ 11 | trait PushLogic { 12 | /** 13 | * A set of received message ids 14 | */ 15 | val receipt : mutable.HashSet[Int] = mutable.HashSet[Int]() 16 | 17 | /** 18 | * Unique identifier counter 19 | */ 20 | var UID: Int = 0 21 | 22 | /** 23 | * Increases the unique id and returns the next unique id 24 | * @return The next id 25 | */ 26 | private def nextId(): Int = { 27 | UID += 1 28 | UID 29 | } 30 | 31 | /** 32 | * Handle push message receipt logic 33 | * @param message The message 34 | * @param sender The sender 35 | */ 36 | def handleLogic(message: Any, sender: ActorRef) = message match { 37 | case GenerateUniqueID() => 38 | sender ! 
UniqueID(nextId()) 39 | 40 | case AcknowledgeReceipt(id) => 41 | if(receipt.contains(id)) { 42 | sender ! AcknowledgeReceipt(id) 43 | } else { 44 | sender ! NotAcknowledgeReceipt(id) 45 | } 46 | 47 | case Forget(id) => 48 | if(receipt.contains(id)) { 49 | receipt.remove(id) 50 | } 51 | sender ! Forget(id) 52 | } 53 | 54 | /** 55 | * Adds the message id to the received set 56 | * @param id The message id 57 | */ 58 | def updateFinished(id: Int): Unit ={ 59 | require(id >= 0, s"id must be positive but got ${id}") 60 | receipt.add(id) 61 | } 62 | 63 | 64 | } 65 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/partitions/Partition.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.partitions 2 | 3 | /** 4 | * Created by wjf on 16-9-24. 5 | */ 6 | abstract class Partition(val index: Int) extends Serializable { 7 | 8 | /** 9 | * check whether this partition contains some key 10 | * @param key the key 11 | * @return whether this partition contains some key 12 | */ 13 | def contains(key: Long): Boolean 14 | 15 | /** 16 | * converts given global key to a continuous local array index 17 | * @param key the global key 18 | * @return the local index 19 | */ 20 | def globalToLocal(key: Long): Int 21 | 22 | /** 23 | * the size of this partition 24 | * @return the size 25 | */ 26 | def size:Int 27 | } 28 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/partitions/Partitioner.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.partitions 2 | 3 | /** 4 | * Created by wjf on 16-9-24. 5 | */ 6 | trait Partitioner extends Serializable { 7 | /** 8 | * Assign a server to the given key 9 | * @param key The key to partition 10 | * @return The partition 11 | */ 12 | def partition(key: Long): Partition 13 | 14 | /** 15 | * Returns all partitions 16 | * @return The array of partitions 17 | */ 18 | def all(): Array[Partition] 19 | 20 | } 21 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/partitions/range/RangePartition.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.partitions.range 2 | 3 | import org.apache.spark.asyspark.core.partitions.Partition 4 | 5 | /** 6 | * A range partition 7 | * Created by wjf on 16-9-24. 8 | */ 9 | class RangePartition(index: Int, val start: Long, val end: Long) extends Partition(index) { 10 | 11 | @inline 12 | override def contains(key: Long): Boolean = key >= start && key <= end 13 | 14 | @inline 15 | override def size: Int = (end - start).toInt 16 | 17 | /** 18 | * Converts given key to a continuous local array index[0,1,2...] 
19 | * @param key the global key 20 | * @return the local index 21 | */ 22 | @inline 23 | def globalToLocal(key: Long): Int = (key - start).toInt 24 | 25 | } 26 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/partitions/range/RangePartitioner.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.partitions.range 2 | 3 | import org.apache.spark.asyspark.core.partitions.{Partition, Partitioner} 4 | 5 | /** 6 | * Created by wjf on 16-9-24. 7 | */ 8 | class RangePartitioner(val partitions: Array[Partition], val numOfSmallPartitions: Int, val keysSmallPartiiton: Int, val keysSize: Long) extends Partitioner { 9 | val numAllSamllPartitions: Long = numOfSmallPartitions.toLong * keysSmallPartiiton.toLong 10 | val sizeOflargePartitions = keysSmallPartiiton + 1 11 | 12 | @inline 13 | override def partition(key: Long): Partition = { 14 | require(key >= 0 && key <= keysSize, s"IndexOutOfBoundsException ${key} while size = ${keysSize}") 15 | val index = if(key < numOfSmallPartitions) { 16 | (key / numOfSmallPartitions).toInt 17 | } else { 18 | (numOfSmallPartitions + (keysSize - numAllSamllPartitions) / sizeOflargePartitions ).toInt 19 | } 20 | partitions(index) 21 | } 22 | 23 | override def all(): Array[Partition] = partitions 24 | 25 | } 26 | object RangePartitioner { 27 | 28 | /** 29 | * Create a RangePartitioner for given number of partition and keys 30 | * @param numOfPartitions The number of partiitons 31 | * @param numberOfKeys The number of keys 32 | * @return A rangePartitioner 33 | */ 34 | def apply(numOfPartitions: Int, numberOfKeys: Long): RangePartitioner = { 35 | val partitions = new Array[Partition](numOfPartitions) 36 | // this largePartition just has more than one element then smalllPartition 37 | val numLargePartition = numberOfKeys % numOfPartitions 38 | val numSmallPartition = numOfPartitions - numLargePartition 39 | val keysSmallPartition = ((numberOfKeys - numLargePartition) / numOfPartitions).toInt 40 | var i = 0 41 | var start: Long = 0L 42 | var end: Long = 0L 43 | while(i < numOfPartitions) { 44 | if(i < numSmallPartition) { 45 | end += keysSmallPartition 46 | partitions(i) = new RangePartition(i, start, end) 47 | start += keysSmallPartition 48 | } else { 49 | end += keysSmallPartition + 1 50 | partitions(i) = new RangePartition(i, start, end) 51 | start += keysSmallPartition + 1 52 | } 53 | i += 1 54 | } 55 | 56 | 57 | new RangePartitioner(partitions, numOfPartitions, keysSmallPartition, numberOfKeys) 58 | 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/serialization/FastPrimitiveDeserializer.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.serialization 2 | 3 | /** 4 | * A very fast primitive deserializer using sun's Unsafe to directly read/write memory regions in the JVM 5 | * 6 | * @param bytes The serialized data 7 | */ 8 | private class FastPrimitiveDeserializer(bytes: Array[Byte]) { 9 | 10 | private val unsafe = SerializationConstants.unsafe 11 | private val offset = unsafe.arrayBaseOffset(classOf[Array[Byte]]) 12 | private var position: Long = 0 13 | 14 | @inline 15 | def readFloat(): Float = { 16 | position += SerializationConstants.sizeOfFloat 17 | unsafe.getFloat(bytes, offset + position - SerializationConstants.sizeOfFloat) 18 | } 
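Returning briefly to the `RangePartitioner` factory above: a worked example of how `apply` splits keys over partitions. With 10 keys and 3 partitions, `10 % 3 = 1` partition receives one extra key and the rest hold `(10 - 1) / 3 = 3` keys each:

```scala
import org.apache.spark.asyspark.core.partitions.range.{RangePartition, RangePartitioner}

object RangePartitionerExample {
  def main(args: Array[String]): Unit = {
    val partitioner = RangePartitioner(numOfPartitions = 3, numberOfKeys = 10L)
    // Prints the [start, end] range of each partition:
    //   partition 0: [0, 3]
    //   partition 1: [3, 6]
    //   partition 2: [6, 10]
    partitioner.all().foreach {
      case p: RangePartition => println(s"partition ${p.index}: [${p.start}, ${p.end}]")
    }
  }
}
```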
19 | 20 | @inline 21 | def readDouble(): Double = { 22 | position += SerializationConstants.sizeOfDouble 23 | unsafe.getDouble(bytes, offset + position - SerializationConstants.sizeOfDouble) 24 | } 25 | 26 | @inline 27 | def readInt(): Int = { 28 | position += SerializationConstants.sizeOfInt 29 | unsafe.getInt(bytes, offset + position - SerializationConstants.sizeOfInt) 30 | } 31 | 32 | @inline 33 | def readLong(): Long = { 34 | position += SerializationConstants.sizeOfLong 35 | unsafe.getLong(bytes, offset + position - SerializationConstants.sizeOfLong) 36 | } 37 | 38 | @inline 39 | def readByte(): Byte = { 40 | position += SerializationConstants.sizeOfByte 41 | unsafe.getByte(bytes, offset + position - SerializationConstants.sizeOfByte) 42 | } 43 | 44 | @inline 45 | def readArrayInt(size: Int): Array[Int] = { 46 | val array = new Array[Int](size) 47 | unsafe.copyMemory(bytes, offset + position, array, unsafe.arrayBaseOffset(classOf[Array[Int]]), size * SerializationConstants.sizeOfInt) 48 | position += size * SerializationConstants.sizeOfInt 49 | array 50 | } 51 | 52 | @inline 53 | def readArrayLong(size: Int): Array[Long] = { 54 | val array = new Array[Long](size) 55 | unsafe.copyMemory(bytes, offset + position, array, unsafe.arrayBaseOffset(classOf[Array[Long]]), size * SerializationConstants.sizeOfLong) 56 | position += size * SerializationConstants.sizeOfLong 57 | array 58 | } 59 | 60 | @inline 61 | def readArrayFloat(size: Int): Array[Float] = { 62 | val array = new Array[Float](size) 63 | unsafe.copyMemory(bytes, offset + position, array, unsafe.arrayBaseOffset(classOf[Array[Float]]), size * SerializationConstants.sizeOfFloat) 64 | position += size * SerializationConstants.sizeOfFloat 65 | array 66 | } 67 | 68 | @inline 69 | def readArrayDouble(size: Int): Array[Double] = { 70 | val array = new Array[Double](size) 71 | unsafe.copyMemory(bytes, offset + position, array, unsafe.arrayBaseOffset(classOf[Array[Double]]), size * SerializationConstants.sizeOfDouble) 72 | position += size * SerializationConstants.sizeOfDouble 73 | array 74 | } 75 | 76 | } 77 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/serialization/FastPrimitiveSerializer.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.serialization 2 | 3 | /** 4 | * A very fast primitive serializer using sun's Unsafe to directly read/write memory regions in the JVM 5 | * 6 | * @param size The size of the serialized output (in bytes) 7 | */ 8 | private class FastPrimitiveSerializer(size: Int) { 9 | 10 | val bytes = new Array[Byte](size) 11 | 12 | private val unsafe = SerializationConstants.unsafe 13 | private val offset = unsafe.arrayBaseOffset(classOf[Array[Byte]]) 14 | private var position: Long = 0 15 | 16 | @inline 17 | def reset(): Unit = position = 0L 18 | 19 | @inline 20 | def writeFloat(value: Float): Unit = { 21 | unsafe.putFloat(bytes, offset + position, value) 22 | position += SerializationConstants.sizeOfFloat 23 | } 24 | 25 | @inline 26 | def writeInt(value: Int): Unit = { 27 | unsafe.putInt(bytes, offset + position, value) 28 | position += SerializationConstants.sizeOfInt 29 | } 30 | 31 | @inline 32 | def writeByte(value: Byte): Unit = { 33 | unsafe.putByte(bytes, offset + position, value) 34 | position += SerializationConstants.sizeOfByte 35 | } 36 | 37 | @inline 38 | def writeLong(value: Long): Unit = { 39 | unsafe.putLong(bytes, offset + position, value) 
40 | position += SerializationConstants.sizeOfLong 41 | } 42 | 43 | @inline 44 | def writeDouble(value: Double): Unit = { 45 | unsafe.putDouble(bytes, offset + position, value) 46 | position += SerializationConstants.sizeOfDouble 47 | } 48 | 49 | @inline 50 | def writeArrayInt(value: Array[Int]): Unit = { 51 | unsafe.copyMemory(value, unsafe.arrayBaseOffset(classOf[Array[Int]]), bytes, offset + position, value.length * SerializationConstants.sizeOfInt) 52 | position += value.length * SerializationConstants.sizeOfInt 53 | } 54 | 55 | @inline 56 | def writeArrayLong(value: Array[Long]): Unit = { 57 | unsafe.copyMemory(value, unsafe.arrayBaseOffset(classOf[Array[Long]]), bytes, offset + position, value.length * SerializationConstants.sizeOfLong) 58 | position += value.length * SerializationConstants.sizeOfLong 59 | } 60 | 61 | @inline 62 | def writeArrayFloat(value: Array[Float]): Unit = { 63 | unsafe.copyMemory(value, unsafe.arrayBaseOffset(classOf[Array[Float]]), bytes, offset + position, value.length * SerializationConstants.sizeOfFloat) 64 | position += value.length * SerializationConstants.sizeOfFloat 65 | } 66 | 67 | @inline 68 | def writeArrayDouble(value: Array[Double]): Unit = { 69 | unsafe.copyMemory(value, unsafe.arrayBaseOffset(classOf[Array[Double]]), bytes, offset + position, value.length * SerializationConstants.sizeOfDouble) 70 | position += value.length * SerializationConstants.sizeOfDouble 71 | } 72 | 73 | } 74 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/serialization/RequestSerializer.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.serialization 2 | import akka.serialization._ 3 | import org.apache.spark.asyspark.core.messages.server.request._ 4 | 5 | 6 | /** 7 | * A fast serializer for requests 8 | * 9 | * Internally this uses a very fast primitive serialization/deserialization routine using sun's Unsafe class for direct 10 | * read/write access to JVM memory. This might not be portable across different JVMs. If serialization causes problems 11 | * you can default to JavaSerialization by removing the serialization-bindings in the configuration. 12 | * Created by wjf on 16-9-23. 
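The two Unsafe-backed helpers above write and read a flat wire format of a 1-byte type tag, a 4-byte length and the payload arrays, which the request and response serializers below build on. A small round-trip sketch, placed in the same package because the helpers are package-private; the layout mirrors what `RequestSerializer.toBinary` emits for a `PushVectorDouble`:

```scala
package org.apache.spark.asyspark.core.serialization

// Round-trips a PushVectorDouble-style payload through the primitive helpers.
object SerializationRoundTrip {
  def main(args: Array[String]): Unit = {
    val keys = Array(1L, 2L, 3L)
    val values = Array(0.5, 1.5, 2.5)

    // 9 header bytes: type tag (1) + length (4) + push id (4).
    val fps = new FastPrimitiveSerializer(
      9 + keys.length * SerializationConstants.sizeOfLong +
        values.length * SerializationConstants.sizeOfDouble)
    fps.writeByte(SerializationConstants.pushVectorDoubleByte)
    fps.writeInt(keys.length)
    fps.writeInt(42) // push id
    fps.writeArrayLong(keys)
    fps.writeArrayDouble(values)

    val fpd = new FastPrimitiveDeserializer(fps.bytes)
    assert(fpd.readByte() == SerializationConstants.pushVectorDoubleByte)
    val size = fpd.readInt()
    val id = fpd.readInt()
    println(s"id=$id keys=${fpd.readArrayLong(size).mkString(",")} " +
      s"values=${fpd.readArrayDouble(size).mkString(",")}")
  }
}
```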
13 | */ 14 | class RequestSerializer extends Serializer { 15 | 16 | override def identifier: Int = 13370 17 | 18 | override def includeManifest: Boolean = false 19 | 20 | override def fromBinary(bytes: Array[Byte], manifest: Option[Class[_]]): AnyRef = { 21 | val fpd = new FastPrimitiveDeserializer(bytes) 22 | val objectType = fpd.readByte() 23 | val objectSize = fpd.readInt() 24 | 25 | objectType match { 26 | // case SerializationConstants.pullMatrixByte => 27 | // val rows = fpd.readArrayLong(objectSize) 28 | // val cols = fpd.readArrayInt(objectSize) 29 | // PullMatrix(rows, cols) 30 | // 31 | // case SerializationConstants.pullMatrixRowsByte => 32 | // val rows = fpd.readArrayLong(objectSize) 33 | // PullMatrixRows(rows) 34 | 35 | case SerializationConstants.pullVectorByte => 36 | val keys = fpd.readArrayLong(objectSize) 37 | PullVector(keys) 38 | 39 | // case SerializationConstants.pushMatrixDoubleByte => 40 | // val id = fpd.readInt() 41 | // val rows = fpd.readArrayLong(objectSize) 42 | // val cols = fpd.readArrayInt(objectSize) 43 | // val values = fpd.readArrayDouble(objectSize) 44 | // PushMatrixDouble(id, rows, cols, values) 45 | // 46 | // case SerializationConstants.pushMatrixFloatByte => 47 | // val id = fpd.readInt() 48 | // val rows = fpd.readArrayLong(objectSize) 49 | // val cols = fpd.readArrayInt(objectSize) 50 | // val values = fpd.readArrayFloat(objectSize) 51 | // PushMatrixFloat(id, rows, cols, values) 52 | // 53 | // case SerializationConstants.pushMatrixIntByte => 54 | // val id = fpd.readInt() 55 | // val rows = fpd.readArrayLong(objectSize) 56 | // val cols = fpd.readArrayInt(objectSize) 57 | // val values = fpd.readArrayInt(objectSize) 58 | // PushMatrixInt(id, rows, cols, values) 59 | // 60 | // case SerializationConstants.pushMatrixLongByte => 61 | // val id = fpd.readInt() 62 | // val rows = fpd.readArrayLong(objectSize) 63 | // val cols = fpd.readArrayInt(objectSize) 64 | // val values = fpd.readArrayLong(objectSize) 65 | // PushMatrixLong(id, rows, cols, values) 66 | 67 | case SerializationConstants.pushVectorDoubleByte => 68 | val id = fpd.readInt() 69 | val keys = fpd.readArrayLong(objectSize) 70 | val values = fpd.readArrayDouble(objectSize) 71 | PushVectorDouble(id, keys, values) 72 | 73 | case SerializationConstants.pushVectorFloatByte => 74 | val id = fpd.readInt() 75 | val keys = fpd.readArrayLong(objectSize) 76 | val values = fpd.readArrayFloat(objectSize) 77 | PushVectorFloat(id, keys, values) 78 | 79 | case SerializationConstants.pushVectorIntByte => 80 | val id = fpd.readInt() 81 | val keys = fpd.readArrayLong(objectSize) 82 | val values = fpd.readArrayInt(objectSize) 83 | PushVectorInt(id, keys, values) 84 | 85 | case SerializationConstants.pushVectorLongByte => 86 | val id = fpd.readInt() 87 | val keys = fpd.readArrayLong(objectSize) 88 | val values = fpd.readArrayLong(objectSize) 89 | PushVectorLong(id, keys, values) 90 | } 91 | } 92 | 93 | override def toBinary(o: AnyRef): Array[Byte] = { 94 | o match { 95 | // case x: PullMatrix => 96 | // val fps = new FastPrimitiveSerializer(5 + x.rows.length * SerializationConstants.sizeOfLong + 97 | // x.rows.length * SerializationConstants.sizeOfInt) 98 | // fps.writeByte(SerializationConstants.pullMatrixByte) 99 | // fps.writeInt(x.rows.length) 100 | // fps.writeArrayLong(x.rows) 101 | // fps.writeArrayInt(x.cols) 102 | // fps.bytes 103 | // 104 | // case x: PullMatrixRows => 105 | // val fps = new FastPrimitiveSerializer(5 + x.rows.length * SerializationConstants.sizeOfLong) 106 | // 
fps.writeByte(SerializationConstants.pullMatrixRowsByte) 107 | // fps.writeInt(x.rows.length) 108 | // fps.writeArrayLong(x.rows) 109 | // fps.bytes 110 | 111 | case x: PullVector => 112 | val fps = new FastPrimitiveSerializer(5 + x.keys.length * SerializationConstants.sizeOfLong) 113 | fps.writeByte(SerializationConstants.pullVectorByte) 114 | fps.writeInt(x.keys.length) 115 | fps.writeArrayLong(x.keys) 116 | fps.bytes 117 | 118 | // case x: PushMatrixDouble => 119 | // val fps = new FastPrimitiveSerializer(9 + x.rows.length * SerializationConstants.sizeOfLong + 120 | // x.rows.length * SerializationConstants.sizeOfInt + 121 | // x.rows.length * SerializationConstants.sizeOfDouble) 122 | // fps.writeByte(SerializationConstants.pushMatrixDoubleByte) 123 | // fps.writeInt(x.rows.length) 124 | // fps.writeInt(x.id) 125 | // fps.writeArrayLong(x.rows) 126 | // fps.writeArrayInt(x.cols) 127 | // fps.writeArrayDouble(x.values) 128 | // fps.bytes 129 | // 130 | // case x: PushMatrixFloat => 131 | // val fps = new FastPrimitiveSerializer(9 + x.rows.length * SerializationConstants.sizeOfLong + 132 | // x.rows.length * SerializationConstants.sizeOfInt + 133 | // x.rows.length * SerializationConstants.sizeOfFloat) 134 | // fps.writeByte(SerializationConstants.pushMatrixFloatByte) 135 | // fps.writeInt(x.rows.length) 136 | // fps.writeInt(x.id) 137 | // fps.writeArrayLong(x.rows) 138 | // fps.writeArrayInt(x.cols) 139 | // fps.writeArrayFloat(x.values) 140 | // fps.bytes 141 | // 142 | // case x: PushMatrixInt => 143 | // val fps = new FastPrimitiveSerializer(9 + x.rows.length * SerializationConstants.sizeOfLong + 144 | // x.rows.length * SerializationConstants.sizeOfInt + 145 | // x.rows.length * SerializationConstants.sizeOfInt) 146 | // fps.writeByte(SerializationConstants.pushMatrixIntByte) 147 | // fps.writeInt(x.rows.length) 148 | // fps.writeInt(x.id) 149 | // fps.writeArrayLong(x.rows) 150 | // fps.writeArrayInt(x.cols) 151 | // fps.writeArrayInt(x.values) 152 | // fps.bytes 153 | // 154 | // case x: PushMatrixLong => 155 | // val fps = new FastPrimitiveSerializer(9 + x.rows.length * SerializationConstants.sizeOfLong + 156 | // x.rows.length * SerializationConstants.sizeOfInt + 157 | // x.rows.length * SerializationConstants.sizeOfLong) 158 | // fps.writeByte(SerializationConstants.pushMatrixLongByte) 159 | // fps.writeInt(x.rows.length) 160 | // fps.writeInt(x.id) 161 | // fps.writeArrayLong(x.rows) 162 | // fps.writeArrayInt(x.cols) 163 | // fps.writeArrayLong(x.values) 164 | // fps.bytes 165 | 166 | case x: PushVectorDouble => 167 | val fps = new FastPrimitiveSerializer(9 + x.keys.length * SerializationConstants.sizeOfLong + 168 | x.keys.length * SerializationConstants.sizeOfDouble) 169 | fps.writeByte(SerializationConstants.pushVectorDoubleByte) 170 | fps.writeInt(x.keys.length) 171 | fps.writeInt(x.id) 172 | fps.writeArrayLong(x.keys) 173 | fps.writeArrayDouble(x.values) 174 | fps.bytes 175 | 176 | case x: PushVectorFloat => 177 | val fps = new FastPrimitiveSerializer(9 + x.keys.length * SerializationConstants.sizeOfLong + 178 | x.keys.length * SerializationConstants.sizeOfFloat) 179 | fps.writeByte(SerializationConstants.pushVectorFloatByte) 180 | fps.writeInt(x.keys.length) 181 | fps.writeInt(x.id) 182 | fps.writeArrayLong(x.keys) 183 | fps.writeArrayFloat(x.values) 184 | fps.bytes 185 | 186 | case x: PushVectorInt => 187 | val fps = new FastPrimitiveSerializer(9 + x.keys.length * SerializationConstants.sizeOfLong + 188 | x.keys.length * SerializationConstants.sizeOfInt) 189 | 
fps.writeByte(SerializationConstants.pushVectorIntByte) 190 | fps.writeInt(x.keys.length) 191 | fps.writeInt(x.id) 192 | fps.writeArrayLong(x.keys) 193 | fps.writeArrayInt(x.values) 194 | fps.bytes 195 | 196 | case x: PushVectorLong => 197 | val fps = new FastPrimitiveSerializer(9 + x.keys.length * SerializationConstants.sizeOfLong + 198 | x.keys.length * SerializationConstants.sizeOfLong) 199 | fps.writeByte(SerializationConstants.pushVectorLongByte) 200 | fps.writeInt(x.keys.length) 201 | fps.writeInt(x.id) 202 | fps.writeArrayLong(x.keys) 203 | fps.writeArrayLong(x.values) 204 | fps.bytes 205 | } 206 | } 207 | } 208 | 209 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/serialization/ResponseSerializer.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.serialization 2 | 3 | import akka.serialization._ 4 | import org.apache.spark.asyspark.core.messages.server.response.{ResponseDouble, ResponseFloat, ResponseInt, ResponseLong} 5 | 6 | /** 7 | * A fast serializer for responses 8 | * 9 | * Internally this uses a very fast primitive serialization/deserialization routine using sun's Unsafe class for direct 10 | * read/write access to JVM memory. This might not be portable across different JVMs. If serialization causes problems 11 | * you can default to JavaSerialization by removing the serialization-bindings in the configuration. 12 | */ 13 | class ResponseSerializer extends Serializer { 14 | 15 | override def identifier: Int = 13371 16 | 17 | override def includeManifest: Boolean = false 18 | 19 | override def fromBinary(bytes: Array[Byte], manifest: Option[Class[_]]): AnyRef = { 20 | val fpd = new FastPrimitiveDeserializer(bytes) 21 | val objectType = fpd.readByte() 22 | val objectSize = fpd.readInt() 23 | 24 | objectType match { 25 | case SerializationConstants.responseDoubleByte => 26 | val values = fpd.readArrayDouble(objectSize) 27 | ResponseDouble(values) 28 | 29 | case SerializationConstants.responseFloatByte => 30 | val values = fpd.readArrayFloat(objectSize) 31 | ResponseFloat(values) 32 | 33 | case SerializationConstants.responseIntByte => 34 | val values = fpd.readArrayInt(objectSize) 35 | ResponseInt(values) 36 | 37 | case SerializationConstants.responseLongByte => 38 | val values = fpd.readArrayLong(objectSize) 39 | ResponseLong(values) 40 | } 41 | } 42 | 43 | override def toBinary(o: AnyRef): Array[Byte] = { 44 | o match { 45 | case x: ResponseDouble => 46 | val fps = new FastPrimitiveSerializer(5 + x.values.length * SerializationConstants.sizeOfDouble) 47 | fps.writeByte(SerializationConstants.responseDoubleByte) 48 | fps.writeInt(x.values.length) 49 | fps.writeArrayDouble(x.values) 50 | fps.bytes 51 | 52 | // case x: ResponseRowsDouble => 53 | // val fps = new FastPrimitiveSerializer(5 + x.values.length * x.columns * SerializationConstants.sizeOfDouble) 54 | // fps.writeByte(SerializationConstants.responseDoubleByte) 55 | // fps.writeInt(x.values.length * x.columns) 56 | // var i = 0 57 | // while (i < x.values.length) { 58 | // fps.writeArrayDouble(x.values(i)) 59 | // i += 1 60 | // } 61 | // fps.bytes 62 | // 63 | case x: ResponseFloat => 64 | val fps = new FastPrimitiveSerializer(5 + x.values.length * SerializationConstants.sizeOfFloat) 65 | fps.writeByte(SerializationConstants.responseFloatByte) 66 | fps.writeInt(x.values.length) 67 | fps.writeArrayFloat(x.values) 68 | fps.bytes 69 | // 70 | // case x: 
ResponseRowsFloat => 71 | // val fps = new FastPrimitiveSerializer(5 + x.values.length * x.columns * SerializationConstants.sizeOfFloat) 72 | // fps.writeByte(SerializationConstants.responseFloatByte) 73 | // fps.writeInt(x.values.length * x.columns) 74 | // var i = 0 75 | // while (i < x.values.length) { 76 | // fps.writeArrayFloat(x.values(i)) 77 | // i += 1 78 | // } 79 | // fps.bytes 80 | 81 | case x: ResponseInt => 82 | val fps = new FastPrimitiveSerializer(5 + x.values.length * SerializationConstants.sizeOfInt) 83 | fps.writeByte(SerializationConstants.responseIntByte) 84 | fps.writeInt(x.values.length) 85 | fps.writeArrayInt(x.values) 86 | fps.bytes 87 | 88 | // case x: ResponseRowsInt => 89 | // val fps = new FastPrimitiveSerializer(5 + x.values.length * x.columns * SerializationConstants.sizeOfInt) 90 | // fps.writeByte(SerializationConstants.responseIntByte) 91 | // fps.writeInt(x.values.length * x.columns) 92 | // var i = 0 93 | // while (i < x.values.length) { 94 | // fps.writeArrayInt(x.values(i)) 95 | // i += 1 96 | // } 97 | // fps.bytes 98 | 99 | case x: ResponseLong => 100 | val fps = new FastPrimitiveSerializer(5 + x.values.length * SerializationConstants.sizeOfLong) 101 | fps.writeByte(SerializationConstants.responseLongByte) 102 | fps.writeInt(x.values.length) 103 | fps.writeArrayLong(x.values) 104 | fps.bytes 105 | 106 | // case x: ResponseRowsLong => 107 | // val fps = new FastPrimitiveSerializer(5 + x.values.length * x.columns * SerializationConstants.sizeOfLong) 108 | // fps.writeByte(SerializationConstants.responseLongByte) 109 | // fps.writeInt(x.values.length * x.columns) 110 | // var i = 0 111 | // while (i < x.values.length) { 112 | // fps.writeArrayLong(x.values(i)) 113 | // i += 1 114 | // } 115 | // fps.bytes 116 | } 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/asyspark/core/serialization/SerializationConstants.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.asyspark.core.serialization 2 | 3 | import sun.misc.Unsafe 4 | 5 | /** 6 | * Some constants used for serialization 7 | */ 8 | private object SerializationConstants { 9 | 10 | // Unsafe field for direct memory access 11 | private val field = classOf[Unsafe].getDeclaredField("theUnsafe") 12 | field.setAccessible(true) 13 | val unsafe = field.get(null).asInstanceOf[Unsafe] 14 | 15 | // Size of different java primitives to perform direct read/write to java memory 16 | val sizeOfByte = 1 17 | val sizeOfShort = 2 18 | val sizeOfInt = 4 19 | val sizeOfLong = 8 20 | val sizeOfFloat = 4 21 | val sizeOfDouble = 8 22 | 23 | // Byte identifiers for different message types 24 | val pullMatrixByte: Byte = 0x00 25 | val pullMatrixRowsByte: Byte = 0x01 26 | val pullVectorByte: Byte = 0x02 27 | val pushMatrixDoubleByte: Byte = 0x03 28 | val pushMatrixFloatByte: Byte = 0x04 29 | val pushMatrixIntByte: Byte = 0x05 30 | val pushMatrixLongByte: Byte = 0x06 31 | val pushVectorDoubleByte: Byte = 0x07 32 | val pushVectorFloatByte: Byte = 0x08 33 | val pushVectorIntByte: Byte = 0x09 34 | val pushVectorLongByte: Byte = 0x0A 35 | val responseDoubleByte: Byte = 0x10 36 | val responseFloatByte: Byte = 0x11 37 | val responseIntByte: Byte = 0x12 38 | val responseLongByte: Byte = 0x13 39 | 40 | } 41 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/AsySGDExample.scala: 
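Before the example programs that follow, one note on wiring: the scaladoc of both serializers above mentions the serialization-bindings in the configuration. A purely hypothetical sketch of what such Akka settings could look like; the real bindings live in the project's conf templates, which are not shown in this dump:

```scala
import com.typesafe.config.ConfigFactory

object SerializerBindingsSketch {
  // Hypothetical Akka settings registering the custom serializers. The classes
  // bound here (PullVector, PushVectorDouble, ResponseDouble, ...) are the
  // message types handled by RequestSerializer and ResponseSerializer.
  val bindings = ConfigFactory.parseString(
    """
      |akka.actor.serializers {
      |  request-serializer  = "org.apache.spark.asyspark.core.serialization.RequestSerializer"
      |  response-serializer = "org.apache.spark.asyspark.core.serialization.ResponseSerializer"
      |}
      |akka.actor.serialization-bindings {
      |  "org.apache.spark.asyspark.core.messages.server.request.PullVector"       = request-serializer
      |  "org.apache.spark.asyspark.core.messages.server.request.PushVectorDouble" = request-serializer
      |  "org.apache.spark.asyspark.core.messages.server.response.ResponseDouble"  = response-serializer
      |}
      |""".stripMargin)
}
```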
-------------------------------------------------------------------------------- 1 | package org.apache.spark.examples 2 | 3 | import org.apache.spark.asyspark.asyml.asysgd.AsyGradientDescent 4 | import org.apache.spark.mllib.linalg.Vectors 5 | import org.apache.spark.mllib.optimization.{GradientDescent, LogisticGradient, SimpleUpdater} 6 | import org.apache.spark.mllib.regression.LabeledPoint 7 | import org.apache.spark.mllib.util.MLUtils 8 | import org.apache.spark.{SparkConf, SparkContext} 9 | 10 | import scala.collection.JavaConverters._ 11 | import scala.util.Random 12 | /** 13 | * Created by wjf on 16-9-19. 14 | */ 15 | object AsySGDExample { 16 | 17 | def main(args: Array[String]): Unit = { 18 | val sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local[*]")) 19 | val nPoints = 4000000 20 | val A = 2.0 21 | val B = -1.5 22 | 23 | val initialB = -1.0 24 | val initialWeights = Array(initialB) 25 | 26 | val gradient = new LogisticGradient() 27 | val updater = new SimpleUpdater() 28 | val stepSize = 1.0 29 | val numIterations = 1000 30 | val regParam = 0 31 | val miniBatchFrac = 1.0 32 | val convergenceTolerance = 5.0e-1 33 | 34 | // Add a extra variable consisting of all 1.0's for the intercept. 35 | val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) 36 | 37 | val data = testData.map { case LabeledPoint(label, features) => 38 | label -> MLUtils.appendBias(features) 39 | } 40 | 41 | val dataRDD = sc.parallelize(data, 8).cache() 42 | val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0) 43 | 44 | // our asychronous implement 45 | var start = System.nanoTime() 46 | val (weights, weightHistory) = AsyGradientDescent.runAsySGD( 47 | dataRDD, 48 | gradient, 49 | updater, 50 | stepSize, 51 | numIterations, 52 | regParam, 53 | miniBatchFrac, 54 | initialWeightsWithIntercept, 55 | convergenceTolerance) 56 | var end = System.nanoTime() 57 | println((end - start) / 1e6 +"ms") 58 | weights.toArray.foreach(println) 59 | // use spark implement 60 | start = System.nanoTime() 61 | val (weight, weightHistorys) = GradientDescent.runMiniBatchSGD( dataRDD, 62 | gradient, 63 | updater, 64 | stepSize, 65 | numIterations, 66 | regParam, 67 | miniBatchFrac, 68 | initialWeightsWithIntercept, 69 | convergenceTolerance) 70 | end = System.nanoTime() 71 | println((end - start)/ 1e6 + "ms") 72 | weight.toArray.foreach(println) 73 | } 74 | 75 | } 76 | object GradientDescentSuite { 77 | 78 | def generateLogisticInputAsList( 79 | offset: Double, 80 | scale: Double, 81 | nPoints: Int, 82 | seed: Int): java.util.List[LabeledPoint] = { 83 | generateGDInput(offset, scale, nPoints, seed).asJava 84 | } 85 | // Generate input of the form Y = logistic(offset + scale * X) 86 | def generateGDInput( 87 | offset: Double, 88 | scale: Double, 89 | nPoints: Int, 90 | seed: Int): Seq[LabeledPoint] = { 91 | val rnd = new Random(seed) 92 | val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian()) 93 | 94 | val unifRand = new Random(45) 95 | val rLogis = (0 until nPoints).map { i => 96 | val u = unifRand.nextDouble() 97 | math.log(u) - math.log(1.0-u) 98 | } 99 | 100 | val y: Seq[Int] = (0 until nPoints).map { i => 101 | val yVal = offset + scale * x1(i) + rLogis(i) 102 | if (yVal > 0) 1 else 0 103 | } 104 | (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(x1(i)))) 105 | } 106 | } -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/TestBroadCast.scala: 
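A note on the data generator in `AsySGDExample` above: `generateGDInput` draws `x ~ N(0, 1)`, adds logistic noise `log(u) - log(1 - u)` with `u ~ Uniform(0, 1)`, and labels a point positive exactly when the noisy score is above zero. That is a logistic model, since

$$
\varepsilon = \log u - \log(1-u) \sim \mathrm{Logistic}(0,1)
\quad\Longrightarrow\quad
P(y = 1 \mid x) = P(\text{offset} + \text{scale}\cdot x + \varepsilon > 0)
= \frac{1}{1 + e^{-(\text{offset} + \text{scale}\cdot x)}},
$$

which is the model the `LogisticGradient` used in the example is fitting.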
-------------------------------------------------------------------------------- 1 | package org.apache.spark.examples 2 | 3 | import org.apache.spark.internal.Logging 4 | import org.apache.spark.sql.SparkSession 5 | 6 | import scala.collection.mutable 7 | 8 | /** 9 | * Created by wjf on 16-9-24. 10 | */ 11 | object TestBroadCast extends Logging { 12 | val sparkSession = SparkSession.builder().appName("TestBroadCast").getOrCreate() 13 | val sc = sparkSession.sparkContext 14 | def main(args: Array[String]): Unit = { 15 | 16 | // val data = sc.parallelize(Seq(1 until 10000000)) 17 | val num = args(args.length - 2).toInt 18 | val times = args(args.length - 1).toInt 19 | println(num) 20 | val start = System.nanoTime() 21 | val seq = (1 until num).toArray // materialize the range so the broadcast payload actually grows with num 22 | for (i <- 0 until times) { 23 | val start2 = System.nanoTime() 24 | val bc = sc.broadcast(seq) 25 | val rdd = sc.parallelize(1 until 10, 5) 26 | rdd.map(_ => bc.value.take(1)).collect() 27 | println((System.nanoTime() - start2) / 1e6 + "ms") 28 | } 29 | logInfo((System.nanoTime() - start) / 1e6 + "ms") 30 | } 31 | 32 | def testMap(): Unit = { 33 | 34 | val smallRDD = sc.parallelize(Seq(1, 2, 3)) 35 | val bigRDD = sc.parallelize(1 until 20) 36 | 37 | // An RDD cannot be used inside another RDD's closure, so collect the small side 38 | // on the driver and broadcast it to the executors instead 39 | val smallMap = sc.broadcast(smallRDD.collect().map(ele => ele -> ele).toMap) 40 | bigRDD.mapPartitions { 41 | partition => 42 | val hashMap = mutable.HashMap(smallMap.value.toSeq: _*) 43 | // some operation here 44 | partition 45 | 46 | } 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/TestClient.scala: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CASISCAS/asyspark/cff60c9d4ae6eb01691a8a8b86f3deca8b67a025/src/main/scala/org/apache/spark/examples/TestClient.scala -------------------------------------------------------------------------------- /src/main/scala/org/apache/spark/examples/TestRemote.scala: -------------------------------------------------------------------------------- 1 | package org.apache.spark.examples 2 | 3 | import java.io.File 4 | 5 | import akka.actor.{ActorSystem, Props} 6 | import com.typesafe.config.ConfigFactory 7 | import com.typesafe.scalalogging.slf4j.StrictLogging 8 | import org.apache.spark.asyspark.core.Main.{logger => _, _} 9 | import org.apache.spark.asyspark.core.messages.master.ServerList 10 | 11 | /** 12 | * Created by wjf on 16-9-27. 
13 | */ 14 | object TestRemote extends StrictLogging { 15 | def main(args: Array[String]): Unit = { 16 | 17 | val default = ConfigFactory.parseResourcesAnySyntax("asyspark") 18 | val config = ConfigFactory.parseFile(new File(getClass.getClassLoader.getResource("asyspark.conf").getFile)).withFallback(default).resolve() 19 | val system = ActorSystem(config.getString("asyspark.server.system"), config.getConfig("asyspark.server")) 20 | 21 | val masterHost = config.getString("asyspark.master.host") 22 | val masterPort = config.getInt("asyspark.master.port") 23 | val masterName = config.getString("asyspark.master.name") 24 | val masterSystem = config.getString("asyspark.master.system") 25 | logger.debug("Resolving master actor reference") 26 | 27 | val master = system.actorSelection(s"akka.tcp://${masterSystem}@${masterHost}:${masterPort}/user/${masterName}") 28 | 29 | 30 | println(master.anchorPath.address) 31 | 32 | } 33 | 34 | } 35 | -------------------------------------------------------------------------------- /src/test/scala/Test.scala: -------------------------------------------------------------------------------- 1 | /** 2 | * Created by wjf on 16-9-19. 3 | */ 4 | class Test { 5 | 6 | } --------------------------------------------------------------------------------
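
A note on the response wire format: in ResponseSerializer.scala and SerializationConstants.scala above, every vector response is written as a 1-byte type tag, a 4-byte element count, and the raw primitive payload, which is where buffer sizes such as 5 + x.values.length * SerializationConstants.sizeOfInt come from. The sketch below reproduces that framing for the ResponseInt case using plain java.nio.ByteBuffer; the project's FastPrimitiveSerializer is only partially visible here, so the ByteBuffer usage (and its default big-endian byte order) is an illustrative assumption, not the project's actual implementation.

import java.nio.ByteBuffer

// Sketch of the inferred frame layout: [1-byte tag][4-byte count][count * 4-byte ints].
// The tag mirrors SerializationConstants.responseIntByte (0x12); everything else is illustrative.
object ResponseIntFramingSketch {
  val responseIntByte: Byte = 0x12
  val sizeOfInt = 4

  def frame(values: Array[Int]): Array[Byte] = {
    val buf = ByteBuffer.allocate(5 + values.length * sizeOfInt) // same size arithmetic as ResponseSerializer
    buf.put(responseIntByte)            // type tag
    buf.putInt(values.length)           // element count
    values.foreach(v => buf.putInt(v))  // payload
    buf.array()
  }

  def main(args: Array[String]): Unit = {
    println(frame(Array(1, 2, 3)).length) // 1 + 4 + 3 * 4 = 17
  }
}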
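
TestBroadCast.testMap above touches a common pattern: joining a small dataset against a large RDD inside mapPartitions. Because an RDD cannot be referenced from within another RDD's closure, the small side has to be collected on the driver and broadcast to the executors. The following self-contained sketch shows that broadcast map-side join; the object name, data, and session settings are illustrative only, not part of the project.

import org.apache.spark.sql.SparkSession

// Illustrative broadcast map-side join: ship the small lookup table to every executor
// once and probe it while streaming over the partitions of the large RDD.
object BroadcastJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("BroadcastJoinSketch").master("local[*]").getOrCreate()
    val sc = spark.sparkContext

    val smallBc = sc.broadcast(Map(1 -> "a", 2 -> "b", 3 -> "c")) // small side, built on the driver
    val big = sc.parallelize(1 until 20)

    val joined = big.mapPartitions { partition =>
      val lookup = smallBc.value                             // read the broadcast value once per partition
      partition.flatMap(k => lookup.get(k).map(v => (k, v))) // keep only keys present in the small side
    }
    joined.collect().foreach(println)                        // (1,a), (2,b), (3,c)

    spark.stop()
  }
}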