├── .gitignore
├── LICENSE
├── README.md
├── conf
├── default.conf.template
├── env.sh.template
└── servers.template
├── dependency-reduced-pom.xml
├── docs
├── doc
└── problem
├── pom.xml
├── spark-asyspark_2.11.iml
└── src
├── META-INF
└── MANIFEST.MF
├── main
└── scala
│ └── org
│ └── apache
│ └── spark
│ ├── asyspark
│ ├── asyml
│ │ └── asysgd
│ │ │ ├── AsyGradientDescent.scala
│ │ │ └── GlobalWeight.scala
│ └── core
│ │ ├── Client.scala
│ │ ├── Exceptios
│ │ ├── PullFailedException.scala
│ │ └── PushFailedException.scala
│ │ ├── Main.scala
│ │ ├── Master.scala
│ │ ├── Server.scala
│ │ ├── messages
│ │ ├── master
│ │ │ ├── ClientList.scala
│ │ │ ├── RegisterClient.scala
│ │ │ ├── RegisterServer.scala
│ │ │ └── ServerList.scala
│ │ └── server
│ │ │ ├── logic
│ │ │ ├── AcknowledgeReceipt.scala
│ │ │ ├── Forget.scala
│ │ │ ├── GenerateUniqueID.scala
│ │ │ ├── NotAcknowledgeReceipt.scala
│ │ │ └── UniqueID.scala
│ │ │ ├── request
│ │ │ ├── PullVector.scala
│ │ │ ├── PushVectorDouble.scala
│ │ │ ├── PushVectorFloat.scala
│ │ │ ├── PushVectorInt.scala
│ │ │ ├── PushVectorLong.scala
│ │ │ └── Request.scala
│ │ │ └── response
│ │ │ ├── Response.scala
│ │ │ ├── ResponseDouble.scala
│ │ │ ├── ResponseFloat.scala
│ │ │ ├── ResponseInt.scala
│ │ │ └── ResponseLong.scala
│ │ ├── models
│ │ ├── client
│ │ │ ├── BigVector.scala
│ │ │ └── asyImp
│ │ │ │ ├── AsyBigVector.scala
│ │ │ │ ├── AsyBigVectorDouble.scala
│ │ │ │ ├── AsyBigVectorFloat.scala
│ │ │ │ ├── AsyBigVectorInt.scala
│ │ │ │ ├── AsyBigVectorLong.scala
│ │ │ │ ├── DeserializationHelper.scala
│ │ │ │ ├── PullFSM.scala
│ │ │ │ └── PushFSM.scala
│ │ └── server
│ │ │ ├── PartialVector.scala
│ │ │ ├── PartialVectorDouble.scala
│ │ │ ├── PartialVectorFloat.scala
│ │ │ ├── PartialVectorInt.scala
│ │ │ ├── PartialVectorLong.scala
│ │ │ └── PushLogic.scala
│ │ ├── partitions
│ │ ├── Partition.scala
│ │ ├── Partitioner.scala
│ │ └── range
│ │ │ ├── RangePartition.scala
│ │ │ └── RangePartitioner.scala
│ │ └── serialization
│ │ ├── FastPrimitiveDeserializer.scala
│ │ ├── FastPrimitiveSerializer.scala
│ │ ├── RequestSerializer.scala
│ │ ├── ResponseSerializer.scala
│ │ └── SerializationConstants.scala
│ └── examples
│ ├── AsySGDExample.scala
│ ├── TestBroadCast.scala
│ ├── TestClient.scala
│ └── TestRemote.scala
└── test
└── scala
└── Test.scala
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | target/
3 | log/
4 | src/main/scala/org/apache/spark/myExamples
5 | src/main/scala/org/apache/spark/mySql
6 | .idea/
7 | *.log
8 | classes/
9 | *.class
10 | *.jar
11 |
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2016 codlife
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # asyspark
2 | ## Spark
3 | Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, Python, and R, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and DataFrames, MLlib for machine learning, GraphX for graph processing, and Spark Streaming for stream processing.
4 | ## asySpark
5 | AsySpark is a component of Spark that makes machine learning workloads more efficient by using an asynchronous computing model, such as asynchronous stochastic gradient descent.
6 | ## Tips
7 | If you would like to contribute or collaborate, please contact us.
8 | ## Further reading
9 | ### Web resources
10 | 1: [Dean, NIPS '13; Li, OSDI '14](http://ps-lite.readthedocs.io/en/latest/overview.html#further-reads) The parameter server architecture
11 | 2: [Taobao's parameter server architecture](http://www.36dsj.com/archives/60938) (in Chinese)
12 | ### Papers
13 | 1: [Langford, NIPS '09; Agarwal, NIPS '11](http://arxiv.org/pdf/1104.5525.pdf) Theoretical convergence of asynchronous SGD
14 | 2: [Li, WSDM '16](http://www.cs.cmu.edu/~yuxiangw/docs/fm.pdf) Practical considerations for asynchronous SGD with the parameter server
15 |
16 |
17 |
--------------------------------------------------------------------------------
/conf/default.conf.template:
--------------------------------------------------------------------------------
1 | # Place custom configuration for your glint cluster in here
2 | glint.master.host = "127.0.0.1"
3 | glint.master.port = 13370
4 |
5 |
--------------------------------------------------------------------------------
/conf/env.sh.template:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Change these values depending on your setup
3 | GLINT_JAR_PATH=./bin/Glint-0.1.jar # Path to the assembled glint jar file; this should be the same on all cluster machines
4 | GLINT_MASTER_OPTS="-Xmx2048m" # Java options to pass to the JVM when starting a master
5 | GLINT_SERVER_OPTS="-Xmx2048m" # Java options to pass to the JVM when starting a server
6 |
--------------------------------------------------------------------------------
/conf/servers.template:
--------------------------------------------------------------------------------
1 | # Here you can place the hostnames or IP addresses of the machines on which you wish to start parameter servers
2 | localhost
3 |
--------------------------------------------------------------------------------
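As an illustration (the hostnames below are placeholders, not part of the repository), a multi-machine setup would simply list one host per line:

    ps-node-1.example.com
    ps-node-2.example.com
    192.168.0.12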
/dependency-reduced-pom.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | 4.0.0
4 | otcaix
5 | asyspark
6 | 1.0-SNAPSHOT
7 | 2016
8 |
9 | src/main/scala
10 | src/test/scala
11 |
12 |
13 | org.scala-tools
14 | maven-scala-plugin
15 |
16 |
17 |
18 | compile
19 | testCompile
20 |
21 |
22 |
23 |
24 | ${scala.version}
25 |
26 | -target:jvm-1.5
27 |
28 |
29 |
30 |
31 | maven-shade-plugin
32 |
33 |
34 | package
35 |
36 | shade
37 |
38 |
39 |
40 |
41 |
42 |
43 | *:*
44 |
45 | META-INF/*.SF
46 | META-INF/*.DSA
47 | META-INF/*.RSA
48 |
49 |
50 |
51 |
52 |
53 | org.apache.spark.asyspark.core.Main
54 |
55 |
56 | reference.conf
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 | scala-tools.org
66 | Scala-Tools Maven2 Repository
67 | http://scala-tools.org/repo-releases
68 |
69 |
70 |
71 |
72 | scala-tools.org
73 | Scala-Tools Maven2 Repository
74 | http://scala-tools.org/repo-releases
75 |
76 |
77 |
78 |
79 | junit
80 | junit
81 | 4.4
82 | test
83 |
84 |
85 | org.specs
86 | specs
87 | 1.2.5
88 | test
89 |
90 |
91 | scalatest
92 | org.scalatest
93 |
94 |
95 | scalacheck
96 | org.scalacheck
97 |
98 |
99 | jmock
100 | org.jmock
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 | org.scala-tools
109 | maven-scala-plugin
110 |
111 | ${scala.version}
112 |
113 |
114 |
115 |
116 |
117 | 2.11.8
118 |
119 |
120 |
121 |
--------------------------------------------------------------------------------
/docs/doc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASISCAS/asyspark/cff60c9d4ae6eb01691a8a8b86f3deca8b67a025/docs/doc
--------------------------------------------------------------------------------
/docs/problem:
--------------------------------------------------------------------------------
1 | 1: About the project package name:
2 | Some Spark internals cannot be accessed from outside the org.apache.spark package, so for now the
3 | packages are placed under org.apache.spark.asyspark...
4 | 2:
--------------------------------------------------------------------------------
/pom.xml:
--------------------------------------------------------------------------------
1 |
3 | 4.0.0
4 | otcaix
5 | asyspark
6 | 1.0-SNAPSHOT
7 | 2016
8 |
9 | 2.11.8
10 |
11 |
12 |
13 |
14 | scala-tools.org
15 | Scala-Tools Maven2 Repository
16 | http://scala-tools.org/repo-releases
17 |
18 |
19 |
20 |
21 |
22 | scala-tools.org
23 | Scala-Tools Maven2 Repository
24 | http://scala-tools.org/repo-releases
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | org.scala-lang
33 | scala-library
34 | ${scala.version}
35 |
36 |
37 | junit
38 | junit
39 | 4.4
40 | test
41 |
42 |
43 | org.specs
44 | specs
45 | 1.2.5
46 | test
47 |
48 |
49 | org.apache.hadoop
50 | hadoop-client
51 | 2.7.1
52 |
53 |
54 | javax.servlet
55 | *
56 |
57 |
58 |
59 |
60 | org.apache.hadoop
61 | hadoop-hdfs
62 | 2.7.1
63 |
64 |
65 | org.apache.spark
66 | spark-core_2.11
67 | 2.0.0
68 |
69 |
70 |
71 | org.apache.spark
72 | spark-sql_2.11
73 | 2.0.0
74 |
75 |
76 |
77 | org.apache.spark
78 | spark-streaming_2.11
79 | 2.0.0
80 |
81 |
82 |
83 | org.apache.spark
84 | spark-mllib_2.11
85 | 2.0.0
86 |
87 |
88 |
89 | org.apache.spark
90 | spark-hive_2.11
91 | 2.0.0
92 |
93 |
94 | net.alchim31.maven
95 | scala-maven-plugin
96 | 3.2.2
97 |
98 |
99 | com.typesafe.scala-logging
100 | scala-logging-slf4j_2.11
101 | 2.1.2
102 |
103 |
104 | com.typesafe.akka
105 | akka-actor_2.11
106 | 2.4.10
107 |
108 |
109 | com.github.scopt
110 | scopt_2.10
111 | 3.5.0
112 |
113 |
114 |
115 | com.typesafe.akka
116 | akka-remote_2.11
117 | 2.4.10
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 | src/main/scala
127 | src/test/scala
128 |
129 |
130 | org.scala-tools
131 | maven-scala-plugin
132 |
133 |
134 |
135 | compile
136 | testCompile
137 |
138 |
139 |
140 |
141 | ${scala.version}
142 |
143 | -target:jvm-1.5
144 |
145 |
146 |
147 |
148 |
152 |
153 |
154 | org.apache.maven.plugins
155 | maven-shade-plugin
156 |
157 |
158 | package
159 |
160 | shade
161 |
162 |
163 |
164 |
165 |
166 |
167 | *:*
168 |
169 | META-INF/*.SF
170 | META-INF/*.DSA
171 | META-INF/*.RSA
172 |
173 |
174 |
175 |
176 |
177 | org.apache.spark.asyspark.core.Main
178 |
179 |
181 | reference.conf
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 | org.scala-tools
206 | maven-scala-plugin
207 |
208 | ${scala.version}
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
--------------------------------------------------------------------------------
/spark-asyspark_2.11.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/src/META-INF/MANIFEST.MF:
--------------------------------------------------------------------------------
1 | Manifest-Version: 1.0
2 | Main-Class: org.apache.spark.asyspark.core.Main
3 |
4 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/asyml/asysgd/AsyGradientDescent.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.asyml.asysgd
2 |
3 | import breeze.linalg.{DenseVector => BDV}
4 | import com.typesafe.scalalogging.slf4j.StrictLogging
5 | import org.apache.spark.mllib.linalg.{Vector, Vectors}
6 | import org.apache.spark.mllib.optimization.{Gradient, Updater}
7 | import org.apache.spark.rdd.RDD
8 |
9 | import scala.collection.mutable.ArrayBuffer
10 |
11 | /**
12 | * Created by wjf on 16-9-19.
13 | */
14 | object AsyGradientDescent extends StrictLogging {
15 | /**
16 | * Run asynchronous stochastic gradient descent (SGD) in parallel using mini batches.
17 | * In each iteration, we sample a subset (fraction miniBatchFraction) of the total data
18 | * in order to compute a gradient estimate.
19 | * Unlike the standard mini-batch SGD in MLlib, each partition runs its own iteration loop and
20 | * pushes its averaged subgradient into the shared global weight asynchronously.
21 | *
22 | * @param data Input data for SGD. RDD of the set of data examples, each of
23 | * the form (label, [feature values]).
24 | * @param gradient Gradient object (used to compute the gradient of the loss function of
25 | * one single data example)
26 | * @param updater Updater function to actually perform a gradient step in a given direction.
27 | * @param stepSize initial step size for the first step
28 | * @param numIterations number of iterations that SGD should be run.
29 | * @param regParam regularization parameter
30 | * @param miniBatchFraction fraction of the input data set that should be used for
31 | * one iteration of SGD. Default value 1.0.
32 | * @param convergenceTol Minibatch iteration will end before numIterations if the relative
33 | * difference between the current weight and the previous weight is less
34 | * than this value. In measuring convergence, L2 norm is calculated.
35 | * Default value 0.001. Must be between 0.0 and 1.0 inclusively.
36 | * @return A tuple containing two elements. The first element is a column matrix containing
37 | * weights for every feature, and the second element is an array containing the
38 | * stochastic loss computed for every iteration.
39 | */
40 | def runAsySGD(
41 | data: RDD[(Double, Vector)],
42 | gradient: Gradient,
43 | updater: Updater,
44 | stepSize: Double,
45 | numIterations: Int,
46 | regParam: Double,
47 | miniBatchFraction: Double,
48 | initialWeights: Vector,
49 | convergenceTol: Double): (Vector, Array[Double]) = {
50 |
51 | // convergenceTol should be set with non minibatch settings
52 | if (miniBatchFraction < 1.0 && convergenceTol > 0.0) {
53 | logger.warn("Testing against a convergenceTol when using miniBatchFraction " +
54 | "< 1.0 can be unstable because of the stochasticity in sampling.")
55 | }
56 |
57 | if (numIterations * miniBatchFraction < 1.0) {
58 | logger.warn("Not all examples will be used if numIterations * miniBatchFraction < 1.0: " +
59 | s"numIterations=$numIterations and miniBatchFraction=$miniBatchFraction")
60 | }
61 |
62 | val stochasticLossHistory = new ArrayBuffer[Double](numIterations)
63 |
64 | val numExamples = data.count()
65 |
66 | // if no data, return initial weights to avoid NaNs
67 | if (numExamples == 0) {
68 | logger.warn("GradientDescent.runMiniBatchSGD returning initial weights, no data found")
69 | return (initialWeights, stochasticLossHistory.toArray)
70 | }
71 |
72 | if (numExamples * miniBatchFraction < 1) {
73 | logger.warn("The miniBatchFraction is too small")
74 | }
75 |
76 | // Initialize weights as a column vector
77 | val n = data.first()._2.size
78 | GlobalWeight.setN(n)
79 | var weight: Vector = null
80 | if (initialWeights != null) {
81 | val weights = Vectors.dense(initialWeights.toArray)
82 | GlobalWeight.setWeight(weights)
83 | weight = weights
84 | } else {
85 | GlobalWeight.initWeight()
86 | weight = GlobalWeight.getWeight()
87 | }
88 |
89 |
90 | // todo add regval
91 | /**
92 | * For the first iteration, the regVal will be initialized as sum of weight squares
93 | * if it's L2 updater; for L1 updater, the same logic is followed.
94 | */
95 |
96 | // var regVal = updater.compute(weights, Vectors.zeros(weights.size), 0, 1, regParam)._2
97 |
98 | GlobalWeight.updateWeight(weight, Vectors.zeros(weight.size), 0, 1, regParam, convergenceTol)
99 | var regVal = GlobalWeight.getRegVal()
100 |
101 | data.foreachPartition {
102 | partition =>
103 | val array = new ArrayBuffer[(Double, Vector)]()
104 | while(partition.hasNext) {
105 | array += partition.next()
106 | }
107 | var convergence = false
108 | var i = 1
109 | val elementNum = array.size
110 | if (elementNum <= 0) {
111 | logger.warn("This partition has no elements, so this worker will stop")
112 | convergence = true
113 | }
114 | while (i <= numIterations && !convergence) {
115 | // todo can do some optimization
116 | // val timesPerIter =10
117 | // for(j <- 0 until timesPerIter) {
118 | //
119 | // }
120 | val bcWeight = GlobalWeight.getWeight()
121 |
122 |
123 |
124 | // todo we can do a sample to avoid use all the data
125 |
126 | // compute gradient
127 | val (gradientSum, lossSum) = array.aggregate((BDV.zeros[Double](n), 0.0))(
128 | seqop = (c, v) => {
129 | val l = gradient.compute(v._2, v._1, bcWeight, Vectors.fromBreeze(c._1))
130 | (c._1, c._2 + l)
131 | },
132 | combop = (c1, c2) => {
133 | (c1._1 += c2._1, c1._2 + c2._2)
134 | }
135 | )
136 | // update gradient
137 |
138 | stochasticLossHistory += lossSum / elementNum + regVal
139 | // todo check whether update success
140 |
141 | val (success, conFlag) = GlobalWeight.updateWeight(bcWeight, Vectors.fromBreeze(gradientSum / elementNum.toDouble), stepSize, i, regParam, convergenceTol)
142 | regVal = GlobalWeight.getRegVal()
143 | if (conFlag) {
144 | convergence = true
145 | }
146 |
147 | i += 1
148 | }
149 | println(s"partition finished after ${i - 1} iterations")
150 | }
151 | (GlobalWeight.getWeight(), stochasticLossHistory.toArray)
152 | }
153 |
154 | }
155 |
--------------------------------------------------------------------------------
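A rough driver-side usage sketch (not part of the repository): it assumes an existing SparkContext `sc` and uses MLlib's LogisticGradient and SquaredL2Updater; the two-point data set and all hyperparameters are purely illustrative.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.optimization.{LogisticGradient, SquaredL2Updater}
    import org.apache.spark.asyspark.asyml.asysgd.AsyGradientDescent

    // `sc` is an existing SparkContext; the data set is a toy example.
    val data = sc.parallelize(Seq(
      (1.0, Vectors.dense(1.0, 0.5)),
      (0.0, Vectors.dense(-1.0, -0.3))
    ))

    val (weights, losses) = AsyGradientDescent.runAsySGD(
      data,
      new LogisticGradient(),
      new SquaredL2Updater(),
      stepSize = 1.0,
      numIterations = 100,
      regParam = 0.1,
      miniBatchFraction = 1.0,
      initialWeights = Vectors.dense(0.0, 0.0),
      convergenceTol = 0.001)

Because GlobalWeight is a JVM-local singleton in this demo, such a call is only meaningful when running in local mode.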
/src/main/scala/org/apache/spark/asyspark/asyml/asysgd/GlobalWeight.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.asyml.asysgd
2 |
3 | import breeze.linalg.{Vector => BV, axpy => brzAxpy, norm => brzNorm}
4 | import org.apache.spark.mllib.linalg.{Vector, Vectors}
5 |
6 | /**
7 | * Created by wjf on 16-9-19.
8 | */
9 |
10 | object GlobalWeight extends Serializable {
11 | // TODO this is just a demo; a cluster-backed weight store will replace it
12 | // TODO the weight initialization can be optimized
13 | // TODO concurrency control should be improved
14 | private var n: Int = 0
15 | private var weight: Vector = _
16 | private var regVal: Double = 0.0
17 |
18 | // setN must be called before GlobalWeight is used
19 | def setN(n: Int): Unit = {
20 | this.n = n
21 | }
22 |
23 | def getRegVal(): Double = {
24 | this.regVal
25 | }
26 |
27 | def initWeight(): Vector = {
28 | GlobalWeight.synchronized {
29 | if (this.weight == null) {
30 | this.weight = Vectors.dense(new Array[Double](n))
31 | }
32 | this.weight
33 | }
34 | }
35 |
36 | def setWeight(weight: Vector): Boolean = {
37 | this.weight = weight
38 | true
39 | }
40 |
41 | def getWeight(): Vector = {
42 | require(n > 0, s"you must set n before using GlobalWeight, current n is ${n}")
43 | this.weight
44 | }
45 |
46 | def updateWeight(
47 | weightsOld: Vector,
48 | gradient: Vector,
49 | stepSize: Double,
50 | iter: Int,
51 | regParam: Double, convergenceTol: Double): (Boolean, Boolean) = {
52 | // add up both updates from the gradient of the loss (= step) as well as
53 | // the gradient of the regularizer (= regParam * weightsOld)
54 | // w' = w - thisIterStepSize * (gradient + regParam * w)
55 | // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient
56 |
57 | val thisIterStepSize = stepSize / math.sqrt(iter)
58 | val brzWeights: BV[Double] = weightsOld.asBreeze.toDenseVector
59 | brzWeights :*= (1.0 - thisIterStepSize * regParam)
60 | brzAxpy(-thisIterStepSize, gradient.asBreeze, brzWeights)
61 | val norm = brzNorm(brzWeights, 2.0)
62 | val currentWeight = Vectors.fromBreeze(brzWeights)
63 | val flag = isConverged(weightsOld, currentWeight, convergenceTol)
64 | this.weight = Vectors.fromBreeze(brzWeights)
65 | this.regVal = 0.5 * regParam * norm * norm
66 | (true, flag)
67 | }
68 |
69 | private def isConverged(
70 | previousWeights: Vector,
71 | currentWeights: Vector,
72 | convergenceTol: Double): Boolean = {
73 | // To compare with convergence tolerance.
74 |
75 | /**
76 | * Note: previousWeights.toDense may be problematic here, so the original
77 | * previousWeights.asBreeze.toDenseVector conversion is kept.
78 | */
79 | val previousBDV = previousWeights.asBreeze.toDenseVector
80 | val currentBDV = currentWeights.asBreeze.toDenseVector
81 |
82 | // This represents the difference of updated weights in the iteration.
83 | val solutionVecDiff: Double = brzNorm(previousBDV - currentBDV)
84 |
85 | solutionVecDiff < convergenceTol * Math.max(brzNorm(currentBDV), 1.0)
86 | }
87 | }
88 |
--------------------------------------------------------------------------------
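A minimal sketch of the update cycle (illustrative, not from the repository), assuming a two-feature model; it follows the w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient rule implemented above.

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.asyspark.asyml.asysgd.GlobalWeight

    GlobalWeight.setN(2)                       // n must be set before the weight is used
    GlobalWeight.initWeight()                  // starts from the zero vector

    val grad = Vectors.dense(0.2, -0.1)        // an illustrative gradient estimate
    val (ok, converged) = GlobalWeight.updateWeight(
      GlobalWeight.getWeight(), grad,
      stepSize = 1.0, iter = 1, regParam = 0.0, convergenceTol = 0.001)

    // With zero initial weights and regParam = 0 the new weight is simply -stepSize * grad.
    println(GlobalWeight.getWeight())          // [-0.2, 0.1]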
/src/main/scala/org/apache/spark/asyspark/core/Client.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core
2 |
3 | import java.io.File
4 | import java.util.concurrent.TimeUnit
5 |
6 | import akka.actor.{Actor, ActorLogging, ActorSystem, Props, _}
7 | import akka.pattern.ask
8 | import akka.remote.RemoteScope
9 | import akka.util.Timeout
10 | import com.typesafe.config.{Config, ConfigFactory}
11 | import com.typesafe.scalalogging.slf4j.StrictLogging
12 | import org.apache.spark.asyspark.core.messages.master.{RegisterClient, ServerList}
13 | import org.apache.spark.asyspark.core.models.client.BigVector
14 | import org.apache.spark.asyspark.core.models.client.asyImp.{AsyBigVectorDouble, AsyBigVectorFloat, AsyBigVectorInt, AsyBigVectorLong}
15 | import org.apache.spark.asyspark.core.models.server.{PartialVectorDouble, PartialVectorFloat, PartialVectorInt, PartialVectorLong}
16 | import org.apache.spark.asyspark.core.partitions.range.RangePartitioner
17 | import org.apache.spark.asyspark.core.partitions.{Partition, Partitioner}
18 |
19 | import scala.concurrent.duration._
20 | import scala.concurrent.{Await, ExecutionContext, Future}
21 | import scala.reflect.runtime.universe.{TypeTag, typeOf}
22 | /**
23 | * The client provides the functions needed to spawn large distributed matrices and vectors on the parameter servers.
24 | * Use the companion object to construct a Client object from a configuration file.
25 | * Created by wjf on 16-9-24.
26 | */
27 | class Client(val config: Config, private[asyspark] val system: ActorSystem,
28 | private[asyspark] val master: ActorRef) {
29 | private implicit val timeout = Timeout(config.getDuration("asyspark.client.timeout", TimeUnit.MILLISECONDS) milliseconds)
30 | private implicit val ec = ExecutionContext.Implicits.global
31 | private[asyspark] val actor = system.actorOf(Props[ClientActor])
32 | // use ask to get a reply
33 | private[asyspark] val registration = master ? RegisterClient(actor)
34 |
35 |
36 |
37 |
38 | /**
39 | * Creates a distributed model on the parameter servers
40 | *
41 | * @param keys The total number of keys
42 | * @param modelsPerServer The number of models to spawn per parameter server
43 | * @param createPartitioner A function that creates a partitioner based on a number of keys and partitions
44 | * @param generateServerProp A function that generates a server prop of a partial model for a particular partition
45 | * @param generateClientObject A function that generates a client object based on the partitioner and spawned models
46 | * @tparam M The final model type to generate
47 | * @return The generated model
48 | */
49 | private def create[M](keys: Long,
50 | modelsPerServer: Int,
51 | createPartitioner: (Int, Long) => Partitioner,
52 | generateServerProp: Partition => Props,
53 | generateClientObject: (Partitioner, Array[ActorRef], Config) => M): M = {
54 |
55 | // Get a list of servers
56 | val listOfServers = serverList()
57 | println(listOfServers.isCompleted)
58 |
59 | // Construct a big model based on the list of servers
60 | val bigModelFuture = listOfServers.map { servers =>
61 |
62 | // Check if there are servers online
63 | if (servers.isEmpty) {
64 | throw new Exception("Cannot create a model without active parameter servers")
65 | }
66 |
67 | // Construct a partitioner
68 | val numberOfPartitions = Math.min(keys, modelsPerServer * servers.length).toInt
69 | val partitioner = createPartitioner(numberOfPartitions, keys)
70 | val partitions = partitioner.all()
71 |
72 | // Spawn models that are deployed on the parameter servers according to the partitioner
73 | val models = new Array[ActorRef](numberOfPartitions)
74 | var partitionIndex = 0
75 | while (partitionIndex < numberOfPartitions) {
76 | val serverIndex = partitionIndex % servers.length
77 | val server = servers(serverIndex)
78 | val partition = partitions(partitionIndex)
79 | val prop = generateServerProp(partition)
80 | models(partitionIndex) = system.actorOf(prop.withDeploy(Deploy(scope = RemoteScope(server.path.address))))
81 | partitionIndex += 1
82 | }
83 |
84 | // Construct a big model client object
85 | generateClientObject(partitioner, models, config)
86 | }
87 |
88 | // Wait for the big model to finish
89 | Await.result(bigModelFuture, config.getDuration("asyspark.client.timeout", TimeUnit.MILLISECONDS) milliseconds)
90 |
91 | }
92 | def bigVector[V: breeze.math.Semiring: TypeTag](keys: Long, modelsPerServer: Int = 1): BigVector[V] = {
93 | bigVector[V](keys, modelsPerServer, (partitions: Int, keys: Long) => RangePartitioner(partitions, keys))
94 | }
95 | def bigVector[V: breeze.math.Semiring: TypeTag](keys: Long, modelsPerServer: Int,
96 | createPartitioner: (Int, Long) => Partitioner): BigVector[V] = {
97 |
98 | val propFunction = numberType[V] match {
99 | case "Int" => (partition: Partition) => Props(classOf[PartialVectorInt], partition)
100 | case "Long" => (partition: Partition) => Props(classOf[PartialVectorLong], partition)
101 | case "Float" => (partition: Partition) => Props(classOf[PartialVectorFloat], partition)
102 | case "Double" => (partition: Partition) => Props(classOf[PartialVectorDouble], partition)
103 | case x =>
104 | throw new Exception(s"Cannot create model for unsupported value type $x")
105 |
106 | }
107 |
108 | val objFunction = numberType[V] match {
109 | case "Int" => (partitioner: Partitioner, models: Array[ActorRef], config: Config) =>
110 | new AsyBigVectorInt(partitioner, models, config, keys).asInstanceOf[BigVector[V]]
111 | case "Long" => (partitioner: Partitioner, models: Array[ActorRef], config: Config) =>
112 | new AsyBigVectorLong(partitioner, models, config, keys).asInstanceOf[BigVector[V]]
113 | case "Float" => (partitioner: Partitioner, models: Array[ActorRef], config: Config) =>
114 | new AsyBigVectorFloat(partitioner, models, config, keys).asInstanceOf[BigVector[V]]
115 | case "Double" => (partitioner: Partitioner, models: Array[ActorRef], config: Config) =>
116 | new AsyBigVectorDouble(partitioner, models, config, keys).asInstanceOf[BigVector[V]]
117 | case x => throw new Exception(s"Cannot create model for unsupported value type $x")
118 | }
119 | create[BigVector[V]](keys, modelsPerServer, createPartitioner, propFunction, objFunction)
120 |
121 |
122 |
123 | }
124 |
125 | def serverList(): Future[Array[ActorRef]] = {
126 | (master ? ServerList()).mapTo[Array[ActorRef]]
127 | }
128 |
129 | def numberType[V: TypeTag]: String = {
130 | implicitly[TypeTag[V]].tpe match {
131 | case x if x <:< typeOf[Int] => "Int"
132 | case x if x <:< typeOf[Long] => "Long"
133 | case x if x <:< typeOf[Float] => "Float"
134 | case x if x <:< typeOf[Double] => "Double"
135 | case x => s"${x.toString}"
136 | }
137 | }
138 |
139 | /**
140 | * Stops the client
141 | */
142 | def stop(): Unit ={
143 | system.terminate()
144 | }
145 |
146 | }
147 |
148 | /**
149 | * Contains functions to easily create a client object that is connected to the asyspark cluster.
150 | *
151 | * You can construct a client with a specific configuration:
152 | * {{{
153 | * import org.apache.spark.asyspark.core.Client
154 | *
155 | * import java.io.File
156 | * import com.typesafe.config.ConfigFactory
157 | *
158 | * val config = ConfigFactory.parseFile(new File("/your/file.conf"))
159 | * val client = Client(config)
160 | * }}}
161 | *
162 | * The resulting client object can then be used to create distributed vectors on the available parameter
163 | * servers:
164 | * {{{
165 | * val vector = client.bigVector[Double](10000)
166 | * }}}
167 | */
168 | object Client extends StrictLogging {
169 |
170 | /**
171 | * Constructs a client with the default configuration
172 | *
173 | * @return The client
174 | */
175 | def apply(): Client = {
176 | this(ConfigFactory.empty())
177 | }
178 |
179 | /**
180 | * Constructs a client
181 | *
182 | * @param config The configuration
183 | * @return A future Client
184 | */
185 | def apply(config: Config): Client = {
186 | val default = ConfigFactory.parseResourcesAnySyntax("asyspark")
187 | val conf = config.withFallback(default).resolve()
188 | Await.result(start(conf), conf.getDuration("asyspark.client.timeout", TimeUnit.MILLISECONDS) milliseconds)
189 | }
190 |
191 | /**
192 | * Implementation to start a client by constructing an ActorSystem and establishing a connection to a master. It
193 | * creates the Client object and checks if its registration actually succeeds
194 | *
195 | * @param config The configuration
196 | * @return The future client
197 | */
198 | private def start(config: Config): Future[Client] = {
199 |
200 | // Get information from config
201 | logger.debug("start client")
202 | val masterHost = config.getString("asyspark.master.host")
203 | val masterPort = config.getInt("asyspark.master.port")
204 | println(masterPort)
205 | val masterName = config.getString("asyspark.master.name")
206 | val masterSystem = config.getString("asyspark.master.system")
207 |
208 | // Construct system and reference to master
209 | val system = ActorSystem(config.getString("asyspark.client.system"), config.getConfig("asyspark.client"))
210 | val master = system.actorSelection(s"akka.tcp://${masterSystem}@${masterHost}:${masterPort}/user/${masterName}")
211 |
212 | // Set up implicit values for concurrency
213 | implicit val ec = ExecutionContext.Implicits.global
214 | implicit val timeout = Timeout(config.getDuration("asyspark.client.timeout", TimeUnit.MILLISECONDS) milliseconds)
215 |
216 | // Resolve master node asynchronously
217 | val masterFuture = master.resolveOne()
218 |
219 | // Construct client based on resolved master asynchronously
220 | masterFuture.flatMap {
221 | case m =>
222 | val client = new Client(config, system, m)
223 | logger.debug("construct a client")
224 | client.registration.map {
225 | case true => client
226 | case _ => throw new RuntimeException("Invalid client registration response from master")
227 | }
228 | }
229 | }
230 |
231 |
232 | }
233 |
234 | /**
235 | * The client actor class. The master keeps a death watch on this actor and knows when it is terminated.
236 | *
237 | * This actor either gets terminated when the system shuts down (e.g. when the Client object is destroyed) or when it
238 | * crashes unexpectedly.
239 | */
240 | private class ClientActor extends Actor with ActorLogging {
241 | override def receive: Receive = {
242 | case x => log.info(s"Client actor received message ${x}")
243 | }
244 | }
245 |
246 | object TestClient {
247 | def main(args: Array[String]): Unit = {
248 |
249 | val default = ConfigFactory.parseResourcesAnySyntax("asyspark")
250 | val config = ConfigFactory.parseFile(new File(getClass.getClassLoader.getResource("asyspark.conf").getFile)).withFallback(default).resolve()
251 | val client =Client(config)
252 | val vector = client.bigVector[Long](2)
253 | val keys = Array[Long](1,2)
254 | val values = Array[Long](1,2)
255 | // vector.push()
256 | println(vector.size)
257 |
258 | }
259 |
260 | }
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/Exceptios/PullFailedException.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.Exceptios
2 |
3 | /**
4 | * Created by wjf on 16-9-25.
5 | */
6 | private[asyspark] class PullFailedException(message: String) extends Exception(message){
7 |
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/Exceptios/PushFailedException.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.Exceptios
2 |
3 | /**
4 | * Created by wjf on 16-9-25.
5 | */
6 | class PushFailedException(message: String) extends Exception(message) {
7 |
8 | }
9 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/Main.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core
2 |
3 | import java.io.File
4 |
5 | import com.typesafe.config.ConfigFactory
6 | import com.typesafe.scalalogging.slf4j.StrictLogging
7 | import org.apache.log4j.BasicConfigurator
8 |
9 | import scala.concurrent.ExecutionContext
10 |
11 | /**
12 | * This is the main class that runs when you start asyspark. By manually specifying additional command-line options it is
13 | * possible to start a master node or a parameter server.
14 | *
15 | * To start a master node:
16 | * {{{
17 | * java -jar /path/to/compiled/asyspark.jar master -c /path/to/asyspark.conf
18 | * }}}
19 | *
20 | * To start a parameter server node:
21 | * {{{
22 | * java -jar /path/to/compiled/asyspark.jar server -c /path/to/asyspark.conf
23 | * }}}
24 | *
25 | * Alternatively you can use the scripts provided in the ./sbin/ folder of the project to automatically construct a
26 | * master and servers over passwordless ssh.
27 | */
28 | object Main extends StrictLogging {
29 |
30 | /**
31 | * Main entry point of the application
32 | *
33 | * @param args The command-line arguments
34 | */
35 | def main(args: Array[String]): Unit = {
36 | // BasicConfigurator.configure()
37 |
38 | val parser = new scopt.OptionParser[Options]("asyspark") {
39 | head("asyspark", "0.1")
40 | opt[File]('c', "config") valueName "" action { (x, c) =>
41 | c.copy(config = x)
42 | } text "The .conf file for asyspark"
43 | cmd("master") action { (_, c) =>
44 | c.copy(mode = "master")
45 | } text "Starts a master node."
46 | cmd("server") action { (_, c) =>
47 | c.copy(mode = "server")
48 | } text "Starts a server node."
49 | }
50 |
51 | parser.parse(args, Options()) match {
52 | case Some(options) =>
53 |
54 | // Read configuration
55 | logger.debug("Parsing configuration file")
56 | val default = ConfigFactory.parseResourcesAnySyntax("asyspark")
57 | val config = ConfigFactory.parseFile(options.config).withFallback(default).resolve()
58 |
59 |
60 | // Start specified mode of operation
61 | implicit val ec = ExecutionContext.Implicits.global
62 | options.mode match {
63 | case "server" => Server.run(config).onSuccess {
64 | case (system, ref) => sys.addShutdownHook {
65 | logger.info("Shutting down")
66 | system.shutdown()
67 | system.awaitTermination()
68 | }
69 | }
70 | case "master" => Master.run(config).onSuccess {
71 | case (system, ref) => sys.addShutdownHook {
72 | logger.info("Shutting down")
73 | system.shutdown()
74 | system.awaitTermination()
75 | }
76 | }
77 | case _ =>
78 | parser.showUsageAsError
79 | System.exit(1)
80 | }
81 | case None => System.exit(1)
82 | }
83 |
84 | }
85 |
86 | /**
87 | * Command-line options
88 | *
89 | * @param mode The mode of operation (either "master" or "server")
90 | * @param config The configuration file to load (defaults to the included asyspark.conf)
91 | * @param host The host of the parameter server (only when mode of operation is "server")
92 | * @param port The port of the parameter server (only when mode of operation is "server")
93 | */
94 | private case class Options(mode: String = "",
95 | config: File = new File(getClass.getClassLoader.getResource("asyspark.conf").getFile),
96 | host: String = "localhost",
97 | port: Int = 0)
98 |
99 | }
100 |
--------------------------------------------------------------------------------
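The bundled asyspark.conf that Options falls back to is not included in this listing. Based on the keys read in Main, Master, Server and Client, a minimal override file passed with -c could look like this (values are illustrative and mirror conf/default.conf.template):

    # Illustrative override; only the master host and port are changed here.
    asyspark.master.host = "127.0.0.1"
    asyspark.master.port = 13370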
/src/main/scala/org/apache/spark/asyspark/core/Master.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core
2 |
3 | import java.util.concurrent.TimeUnit
4 |
5 | import akka.actor.{Actor, ActorLogging, ActorRef, ActorSystem, Address, Props, Terminated}
6 | import akka.util.Timeout
7 | import com.typesafe.config.Config
8 | import com.typesafe.scalalogging.slf4j.StrictLogging
9 | import org.apache.spark.asyspark.core.messages.master.{ClientList, RegisterClient, RegisterServer, ServerList}
10 |
11 | import scala.concurrent.duration._
12 | import scala.concurrent.{ExecutionContext, Future}
13 |
14 |
15 | /**
16 | * Created by wjf on 16-9-22.
17 | * the master that registers the spark workers
18 | */
19 | class Master() extends Actor with ActorLogging {
20 |
21 | /**
22 | * collection of servers available
23 | */
24 | var servers = Set.empty[ActorRef]
25 |
26 | /**
27 | * collection of clients
28 | */
29 | var clients = Set.empty[ActorRef]
30 |
31 | override def receive: Receive = {
32 | case RegisterServer(server) =>
33 | log.info(s"Registering server ${server.path.toString}")
34 | println("register server")
35 | servers += server
36 | context.watch(server)
37 | sender ! true
38 |
39 | case RegisterClient(client) =>
40 | log.info(s"Registering client ${sender.path.toString}")
41 | clients += client
42 | context.watch(client)
43 | sender ! true
44 |
45 | case ServerList() =>
46 | log.info(s"Sending current server list to ${sender.path.toString}")
47 | sender ! servers.toArray
48 |
49 | case ClientList() =>
50 | log.info(s"Sending current client list to ${sender.path.toString}")
51 | sender ! clients.toArray
52 |
53 |
54 | case Terminated(actor) =>
55 | actor match {
56 | case server: ActorRef if servers contains server =>
57 | log.info(s"Removing server ${server.path.toString}")
58 | servers -= server
59 | case client: ActorRef if clients contains client =>
60 | log.info(s"Removing client ${client.path.toString}")
61 | clients -= client
62 | case actor: ActorRef =>
63 | log.warning(s"Actor ${actor.path.toString} terminated, but it is neither a registered server nor a registered client")
64 | }
65 | }
66 |
67 | }
68 |
69 | object Master extends StrictLogging {
70 | def run(config: Config): Future[(ActorSystem, ActorRef)] = {
71 | logger.debug("Starting master actor system")
72 | val system = ActorSystem(config.getString("asyspark.master.system"), config.getConfig("asyspark.master"))
73 | logger.debug("Starting master")
74 | val master = system.actorOf(Props[Master], config.getString("asyspark.master.name"))
75 | implicit val timeout = Timeout(config.getDuration("asyspark.master.startup-timeout", TimeUnit.MILLISECONDS) milliseconds)
76 | implicit val ec = ExecutionContext.Implicits.global
77 | val address = Address("akka.tcp", config.getString("asyspark.master.system"), config.getString("asyspark.master.host"),
78 | config.getString("asyspark.master.port").toInt)
79 | system.actorSelection(master.path.toSerializationFormat).resolveOne().map {
80 | case actor: ActorRef =>
81 | logger.debug("Master successfully started")
82 | (system, master)
83 |
84 | }
85 | }
86 |
87 | }
88 |
--------------------------------------------------------------------------------
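For reference, the sketch below shows roughly how code inside the org.apache.spark.asyspark package can ask the master for the registered servers; it mirrors Client.serverList and assumes `master` is an already-resolved ActorRef for the Master actor.

    import akka.actor.ActorRef
    import akka.pattern.ask
    import akka.util.Timeout
    import org.apache.spark.asyspark.core.messages.master.ServerList

    import scala.concurrent.Future
    import scala.concurrent.duration._

    // ServerList is private[asyspark], so this code must live inside that package.
    implicit val timeout = Timeout(30.seconds)

    def registeredServers(master: ActorRef): Future[Array[ActorRef]] =
      (master ? ServerList()).mapTo[Array[ActorRef]]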
/src/main/scala/org/apache/spark/asyspark/core/Server.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core
2 |
3 | /**
4 | * Created by wjf on 16-9-22.
5 | */
6 |
7 | import java.util.concurrent.TimeUnit
8 |
9 | import akka.actor._
10 | import akka.pattern.ask
11 | import akka.util.Timeout
12 | import com.typesafe.config.Config
13 | import com.typesafe.scalalogging.slf4j.StrictLogging
14 | import org.apache.spark.asyspark.core.messages.master.RegisterServer
15 |
16 | import scala.concurrent.duration._
17 | import scala.concurrent.{ExecutionContext, Future}
18 |
19 | /**
20 | * A parameter server
21 | */
22 | private class Server extends Actor with ActorLogging {
23 |
24 | override def receive: Receive = {
25 | case x =>
26 | log.warning(s"Received unknown message of type ${x.getClass}")
27 | }
28 | }
29 | /**
30 | * The parameter server object
31 | */
32 | private object Server extends StrictLogging {
33 |
34 | /**
35 | * Starts a parameter server ready to receive commands
36 | *
37 | * @param config The configuration
38 | * @return A future containing the started actor system and reference to the server actor
39 | */
40 | def run(config: Config): Future[(ActorSystem, ActorRef)] = {
41 | println("server run")
42 |
43 |
44 | logger.debug(s"Starting actor system ${config.getString("asyspark.server.system")}")
45 | val system = ActorSystem(config.getString("asyspark.server.system"), config.getConfig("asyspark.server"))
46 |
47 | logger.debug("Starting server actor")
48 | val server = system.actorOf(Props[Server], config.getString("asyspark.server.name"))
49 | println(server.path)
50 |
51 | logger.debug("Reading master information from config")
52 | val masterHost = config.getString("asyspark.master.host")
53 | val masterPort = config.getInt("asyspark.master.port")
54 | val masterName = config.getString("asyspark.master.name")
55 | val masterSystem = config.getString("asyspark.master.system")
56 |
57 | logger.info(s"Registering with master ${masterSystem}@${masterHost}:${masterPort}/user/${masterName}")
58 | implicit val ec = ExecutionContext.Implicits.global
59 | implicit val timeout = Timeout(config.getDuration("asyspark.server.registration-timeout", TimeUnit.MILLISECONDS) milliseconds)
60 | val master = system.actorSelection(s"akka.tcp://${masterSystem}@${masterHost}:${masterPort}/user/${masterName}")
61 | val registration = master ? RegisterServer(server)
62 |
63 |
64 |
65 | registration.map {
66 | case a =>
67 | logger.info("Server successfully registered with master")
68 | (system, server)
69 | }
70 |
71 | }
72 | }
73 |
74 |
--------------------------------------------------------------------------------
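A minimal sketch (not part of the repository) of starting a master and one parameter server programmatically rather than through the Main CLI. It assumes the bundled asyspark.conf resource provides the required asyspark.master.* and asyspark.server.* settings; the object name LocalClusterSketch is hypothetical.

    package org.apache.spark.asyspark.core

    import com.typesafe.config.ConfigFactory

    import scala.concurrent.ExecutionContext.Implicits.global

    // The Server object is package-private, so this sketch lives in the same package.
    object LocalClusterSketch {
      def main(args: Array[String]): Unit = {
        val config = ConfigFactory.parseResourcesAnySyntax("asyspark").resolve()
        for {
          (_, masterRef) <- Master.run(config)   // start the master first
          (_, serverRef) <- Server.run(config)   // then register one server with it
        } println(s"master: ${masterRef.path}, server: ${serverRef.path}")
      }
    }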
/src/main/scala/org/apache/spark/asyspark/core/messages/master/ClientList.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.master
2 |
3 | /**
4 | * Created by wjf on 16-9-24.
5 | */
6 | private[asyspark] case class ClientList()
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/master/RegisterClient.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.master
2 |
3 | import akka.actor.ActorRef
4 |
5 | /**
6 | * Created by wjf on 16-9-24.
7 | */
8 | private[asyspark] case class RegisterClient(client: ActorRef)
9 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/master/RegisterServer.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.master
2 |
3 | import akka.actor.ActorRef
4 |
5 | /**
6 | * Created by wjf on 16-9-22.
7 | */
8 | private[asyspark] case class RegisterServer(server: ActorRef)
9 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/master/ServerList.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.master
2 |
3 | /**
4 | * Created by wjf on 16-9-22.
5 | */
6 | private[asyspark] case class ServerList()
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/AcknowledgeReceipt.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.logic
2 |
3 | /**
4 | * Created by wjf on 16-9-25.
5 | */
6 | private[asyspark] case class AcknowledgeReceipt(id: Int)
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/Forget.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.logic
2 |
3 | /**
4 | * Created by wjf on 16-9-25.
5 | */
6 | private[asyspark] case class Forget(id: Int)
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/GenerateUniqueID.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.logic
2 |
3 | /**
4 | * Created by wjf on 16-9-25.
5 | */
6 | private[asyspark] case class GenerateUniqueID()
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/NotAcknowledgeReceipt.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.logic
2 |
3 | /**
4 | * Created by wjf on 16-9-25.
5 | */
6 | private[asyspark] case class NotAcknowledgeReceipt(id: Long)
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/logic/UniqueID.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.logic
2 |
3 | /**
4 | * Created by wjf on 16-9-25.
5 | */
6 | private[asyspark] case class UniqueID(id: Int)
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PullVector.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.request
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class PullVector(keys: Array[Long]) extends Request
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PushVectorDouble.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.request
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class PushVectorDouble(id: Int, keys: Array[Long], values: Array[Double]) extends Request
7 |
8 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PushVectorFloat.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.request
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class PushVectorFloat(id: Int, keys:Array[Long], values:Array[Float]) extends Request
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PushVectorInt.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.request
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class PushVectorInt(id: Int, keys: Array[Long], values: Array[Int]) extends Request
7 |
8 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/request/PushVectorLong.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.request
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class PushVectorLong(id: Int, keys: Array[Long], values:Array[Long]) extends Request
7 |
8 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/request/Request.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.request
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] trait Request
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/response/Response.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.response
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] trait Response
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/response/ResponseDouble.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.response
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class ResponseDouble(values: Array[Double]) extends Response
7 |
8 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/response/ResponseFloat.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.response
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class ResponseFloat(values: Array[Float]) extends Response
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/response/ResponseInt.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.response
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class ResponseInt(values: Array[Int]) extends Response
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/messages/server/response/ResponseLong.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.messages.server.response
2 |
3 | /**
4 | * Created by wjf on 16-9-23.
5 | */
6 | private[asyspark] case class ResponseLong(values: Array[Long]) extends Response
7 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/client/BigVector.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client
2 |
3 | import scala.concurrent.{ExecutionContext, Future}
4 |
5 | /**
6 | * A big vector supporting basic parameter server element-wise operations
7 | * {{{
8 | * val vector:BigVector[Int] = ...
9 | * vector.pull()
10 | * vector.push()
11 | * vector.destroy()
12 | * }}}
13 | * Created by wjf on 16-9-25.
14 | */
15 | trait BigVector[V] extends Serializable {
16 | val size: Long
17 |
18 |
19 | /**
20 | * Pull a set of elements
21 | * @param keys The Indices of the elements
22 | * @param ec The implicit execution context in which to execute the request
23 | * @return A future containing the values of the elements at the given keys
24 | */
25 | //TODO it's convenient to use scala.concurrent.ExecutionContext.Implicits.global,
26 | // but we should do some optimization in a production environment
27 | def pull(keys: Array[Long])(implicit ec:ExecutionContext): Future[Array[V]]
28 |
29 | /**
30 | * Push a set of elements
31 | * @param keys The indices of the rows
32 | * @param values The values to update
33 | * @param ec The Implicit execution context
34 | * @return A future
35 | */
36 | def push(keys: Array[Long], values: Array[V])(implicit ec: ExecutionContext): Future[Boolean]
37 |
38 | /**
39 | * Destroy the big vector and its resources on the parameter servers
40 | * @param ec The implicit execution context in which to execute the request
41 | * @return A future
42 | */
43 | def destroy()(implicit ec: ExecutionContext): Future[Boolean]
44 | }
45 |
--------------------------------------------------------------------------------
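A short pull/push round-trip sketch (illustrative, not from the repository), assuming `client` is a connected Client as constructed in the Client companion object documentation:

    import scala.concurrent.Await
    import scala.concurrent.ExecutionContext.Implicits.global
    import scala.concurrent.duration._

    // `client` is a connected Client; see the Client companion object docs.
    val vector = client.bigVector[Double](10000)

    val keys   = Array(0L, 1L, 2L)
    val values = Array(0.5, 1.5, 2.5)

    val roundTrip = for {
      _      <- vector.push(keys, values)   // send the updates
      pulled <- vector.pull(keys)           // read the elements back
    } yield pulled

    println(Await.result(roundTrip, 30.seconds).mkString(", "))
    vector.destroy()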
/src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVector.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client.asyImp
2 |
3 | import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
4 |
5 | import akka.actor.{ActorRef, ExtendedActorSystem}
6 | import akka.pattern.Patterns.gracefulStop
7 | import akka.serialization.JavaSerializer
8 | import breeze.linalg.DenseVector
9 | import breeze.math.Semiring
10 | import com.typesafe.config.Config
11 | import org.apache.spark.asyspark.core.messages.server.request.PullVector
12 | import org.apache.spark.asyspark.core.models.client.BigVector
13 | import org.apache.spark.asyspark.core.partitions.{Partition, Partitioner}
14 | import spire.implicits._
15 |
16 | import scala.concurrent.duration._
17 | import scala.concurrent.{ExecutionContext, Future}
18 | import scala.reflect.ClassTag
19 |
20 | /**
21 | * An asynchronous implementation of [[org.apache.spark.asyspark.core.models.client.BigVector BigVector]]. You don't want to construct this
22 | * object manually but instead use the methods provided in [[org.apache.spark.asyspark.core.Client Client]], as so
23 | * {{{
24 | * client.bigVector[Double](keys)
25 | * }}}
26 | *
27 | * @param partitioner A partitioner to map keys to parameter servers
28 | * @param models The partial models on the parameter servers
29 | * @param size The number of keys
30 | * @tparam V The type of values to store
31 | * @tparam R The type of responses we expect to get from the parameter servers
32 | * @tparam P The type of push requests we should send to the parameter servers
33 | */
34 | abstract class AsyBigVector[@specialized V: Semiring : ClassTag, R: ClassTag, P: ClassTag](partitioner: Partitioner,
35 | models: Array[ActorRef],
36 | config: Config,
37 | val size: Long)
38 | extends BigVector[V] {
39 |
40 | PullFSM.initialize(config)
41 | PushFSM.initialize(config)
42 |
43 | /**
44 | * Pulls a set of elements
45 | *
46 | * @param keys The indices of the keys
47 | * @return A future containing the values of the elements at the given keys
48 | */
49 | override def pull(keys: Array[Long])(implicit ec: ExecutionContext): Future[Array[V]] = {
50 |
51 | // Send pull request of the list of keys
52 | val pulls = mapPartitions(keys) {
53 | case (partition, indices) =>
54 | val pullMessage = PullVector(indices.map(keys).toArray)
55 | val fsm = PullFSM[PullVector, R](pullMessage, models(partition.index))
56 | fsm.run()
57 | //(models(partition.index) ? pullMessage).mapTo[R]
58 | }
59 |
60 | // Obtain key indices after partitioning so we can place the results in a correctly ordered array
61 | val indices = keys.zipWithIndex.groupBy {
62 | case (k, i) => partitioner.partition(k)
63 | }.map {
64 | case (_, arr) => arr.map(_._2)
65 | }.toArray
66 |
67 | // Define aggregator for successful responses
68 | def aggregateSuccess(responses: Iterable[R]): Array[V] = {
69 | val responsesArray = responses.toArray
70 | val result = DenseVector.zeros[V](keys.length)
71 | cforRange(0 until responsesArray.length)(i => {
72 | val response = responsesArray(i)
73 | val idx = indices(i)
74 | cforRange(0 until idx.length)(j => {
75 | result(idx(j)) = toValue(response, j)
76 | })
77 | })
78 | result.toArray
79 | }
80 |
81 | // Combine and aggregate futures
82 | Future.sequence(pulls).transform(aggregateSuccess, err => err)
83 |
84 | }
85 |
86 | /**
87 |    * Groups indices of given keys into partitions according to this model's partitioner and maps each partition to a
88 | * type T
89 | *
90 | * @param keys The keys to partition
91 | * @param func The function that takes a partition and corresponding indices and creates something of type T
92 | * @tparam T The type to map to
93 | * @return An iterable over the partitioned results
94 | */
95 | @inline
96 | private def mapPartitions[T](keys: Seq[Long])(func: (Partition, Seq[Int]) => T): Iterable[T] = {
97 | keys.indices.groupBy(i => partitioner.partition(keys(i))).map { case (a, b) => func(a, b) }
98 | }
99 |
100 | /**
101 | * Pushes a set of values
102 | *
103 | * @param keys The indices of the keys
104 | * @param values The values to update
105 | * @return A future containing either the success or failure of the operation
106 | */
107 | override def push(keys: Array[Long], values: Array[V])(implicit ec: ExecutionContext): Future[Boolean] = {
108 |
109 | // Send push requests
110 | val pushes = mapPartitions(keys) {
111 | case (partition, indices) =>
112 | val ks = indices.map(keys).toArray
113 | val vs = indices.map(values).toArray
114 | val fsm = PushFSM[P]((id) => toPushMessage(id, ks, vs), models(partition.index))
115 | fsm.run()
116 | }
117 |
118 | // Combine and aggregate futures
119 | Future.sequence(pushes).transform(results => true, err => err)
120 |
121 | }
122 |
123 | /**
124 | * @return The number of partitions this big vector's data is spread across
125 | */
126 | def nrOfPartitions: Int = {
127 | partitioner.all().length
128 | }
129 |
130 | /**
131 |    * Destroys the vector on the parameter servers
132 |    *
133 |    * @return A future indicating whether the vector was successfully destroyed
134 | */
135 | override def destroy()(implicit ec: ExecutionContext): Future[Boolean] = {
136 | val partitionFutures = partitioner.all().map {
137 | case partition => gracefulStop(models(partition.index), 60 seconds)
138 | }.toIterator
139 | Future.sequence(partitionFutures).transform(successes => successes.forall(success => success), err => err)
140 | }
141 |
142 | /**
143 | * Extracts a value from a given response at given index
144 | *
145 | * @param response The response
146 | * @param index The index
147 | * @return The value
148 | */
149 | @inline
150 | protected def toValue(response: R, index: Int): V
151 |
152 | /**
153 | * Creates a push message from given sequence of keys and values
154 | *
155 | * @param id The identifier
156 | * @param keys The rows
157 | * @param values The values
158 |    * @return A push message of type P
159 | */
160 | @inline
161 | protected def toPushMessage(id: Int, keys: Array[Long], values: Array[V]): P
162 |
163 | /**
164 | * Deserializes this instance. This starts an ActorSystem with appropriate configuration before attempting to
165 | * deserialize ActorRefs
166 | *
167 | * @param in The object input stream
168 | * @throws java.io.IOException A possible Input/Output exception
169 | */
170 | @throws(classOf[IOException])
171 | private def readObject(in: ObjectInputStream): Unit = {
172 | val config = in.readObject().asInstanceOf[Config]
173 |     val as = DeserializationHelper.getActorSystem(config.getConfig("asyspark.client"))
174 | JavaSerializer.currentSystem.withValue(as.asInstanceOf[ExtendedActorSystem]) {
175 | in.defaultReadObject()
176 | }
177 | }
178 |
179 | /**
180 | * Serializes this instance. This first writes the config before the entire object to ensure we can deserialize with
181 | * an ActorSystem in place
182 | *
183 | * @param out The object output stream
184 | * @throws java.io.IOException A possible Input/Output exception
185 | */
186 | @throws(classOf[IOException])
187 | private def writeObject(out: ObjectOutputStream): Unit = {
188 | out.writeObject(config)
189 | out.defaultWriteObject()
190 | }
191 |
192 | }
193 |
194 |
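The pull path above groups keys by partition, sends one request per partition and then scatters each partition's response back into an array that follows the caller's original key order. A self-contained sketch of that reordering, with a toy stand-in for the partitioner and for the per-partition round trip (both hypothetical):

// Toy partitioner: key % 2. `fetch` stands in for a PullVector round trip to one partition.
object PullReorderSketch extends App {
  val keys = Array(10L, 3L, 8L, 5L, 6L)
  def partitionOf(key: Long): Int = (key % 2).toInt
  def fetch(partition: Int, ks: Seq[Long]): Seq[Double] = ks.map(k => k * 1.0)

  val result = new Array[Double](keys.length)
  // Same grouping as mapPartitions: original positions grouped by partition index
  for ((partition, positions) <- keys.indices.groupBy(i => partitionOf(keys(i)))) {
    val values = fetch(partition, positions.map(i => keys(i)))
    positions.zip(values).foreach { case (pos, v) => result(pos) = v }
  }
  println(result.mkString(", ")) // values appear in the same order as `keys`
}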
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVectorDouble.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client.asyImp
2 | import akka.actor.ActorRef
3 | import com.typesafe.config.Config
4 | import org.apache.spark.asyspark.core.messages.server.request.PushVectorDouble
5 | import org.apache.spark.asyspark.core.messages.server.response.ResponseDouble
6 | import org.apache.spark.asyspark.core.partitions.Partitioner
7 |
8 | /**
9 | * Asynchronous implementation of a BigVector for doubles
10 | */
11 | private[asyspark] class AsyBigVectorDouble(partitioner: Partitioner,
12 | models: Array[ActorRef],
13 | config: Config,
14 | keys: Long)
15 | extends AsyBigVector[Double, ResponseDouble, PushVectorDouble](partitioner, models, config, keys) {
16 |
17 | /**
18 | * Creates a push message from given sequence of keys and values
19 | *
20 | * @param id The identifier
21 | * @param keys The keys
22 | * @param values The values
23 | * @return A PushVectorDouble message for type V
24 | */
25 | @inline
26 | override protected def toPushMessage(id: Int, keys: Array[Long], values: Array[Double]): PushVectorDouble = {
27 | PushVectorDouble(id, keys, values)
28 | }
29 |
30 | /**
31 | * Extracts a value from a given response at given index
32 | *
33 | * @param response The response
34 | * @param index The index
35 | * @return The value
36 | */
37 | @inline
38 | override protected def toValue(response: ResponseDouble, index: Int): Double = response.values(index)
39 |
40 | }
41 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVectorFloat.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client.asyImp
2 |
3 |
4 | import akka.actor.ActorRef
5 | import com.typesafe.config.Config
6 | import org.apache.spark.asyspark.core.messages.server.request.PushVectorFloat
7 | import org.apache.spark.asyspark.core.messages.server.response.ResponseFloat
8 | import org.apache.spark.asyspark.core.partitions.Partitioner
9 |
10 | /**
11 | * Asynchronous implementation of a BigVector for floats
12 | */
13 | private[asyspark] class AsyBigVectorFloat(partitioner: Partitioner,
14 | models: Array[ActorRef],
15 | config: Config,
16 | keys: Long)
17 | extends AsyBigVector[Float, ResponseFloat, PushVectorFloat](partitioner, models, config, keys) {
18 |
19 | /**
20 | * Creates a push message from given sequence of keys and values
21 | *
22 | * @param id The identifier
23 | * @param keys The keys
24 | * @param values The values
25 | * @return A PushVectorFloat message for type V
26 | */
27 | @inline
28 | override protected def toPushMessage(id: Int, keys: Array[Long], values: Array[Float]): PushVectorFloat = {
29 | PushVectorFloat(id, keys, values)
30 | }
31 |
32 | /**
33 | * Extracts a value from a given response at given index
34 | *
35 | * @param response The response
36 | * @param index The index
37 | * @return The value
38 | */
39 | @inline
40 | override protected def toValue(response: ResponseFloat, index: Int): Float = response.values(index)
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVectorInt.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client.asyImp
2 |
3 |
4 | import akka.actor.ActorRef
5 | import com.typesafe.config.Config
6 | import org.apache.spark.asyspark.core.messages.server.request.PushVectorInt
7 | import org.apache.spark.asyspark.core.messages.server.response.ResponseInt
8 | import org.apache.spark.asyspark.core.partitions.Partitioner
9 |
10 | /**
11 | * Asynchronous implementation of a BigVector for integers
12 | */
13 | private[asyspark] class AsyBigVectorInt(partitioner: Partitioner,
14 | models: Array[ActorRef],
15 | config: Config,
16 | keys: Long)
17 | extends AsyBigVector[Int, ResponseInt, PushVectorInt](partitioner, models, config, keys) {
18 |
19 | /**
20 | * Creates a push message from given sequence of keys and values
21 | *
22 | * @param id The identifier
23 | * @param keys The keys
24 | * @param values The values
25 | * @return A PushVectorInt message for type V
26 | */
27 | @inline
28 | override protected def toPushMessage(id: Int, keys: Array[Long], values: Array[Int]): PushVectorInt = {
29 | PushVectorInt(id, keys, values)
30 | }
31 |
32 | /**
33 | * Extracts a value from a given response at given index
34 | *
35 | * @param response The response
36 | * @param index The index
37 | * @return The value
38 | */
39 | @inline
40 | override protected def toValue(response: ResponseInt, index: Int): Int = response.values(index)
41 |
42 | }
43 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/AsyBigVectorLong.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client.asyImp
2 |
3 | import org.apache.spark.asyspark.core.partitions.Partitioner
4 | import akka.actor.ActorRef
5 | import com.typesafe.config.Config
6 | import org.apache.spark.asyspark.core.messages.server.request.PushVectorLong
7 | import org.apache.spark.asyspark.core.messages.server.response.ResponseLong
8 | /**
9 |  * Asynchronous implementation of a BigVector for longs. Created by wjf on 16-9-26.
10 | */
11 | private[asyspark] class AsyBigVectorLong(partitioner: Partitioner,
12 | models: Array[ActorRef],
13 | config: Config,
14 | keys: Long)
15 | extends AsyBigVector[Long, ResponseLong, PushVectorLong](partitioner, models, config, keys) {
16 | /**
17 | * Creates a push message from given sequence of keys and values
18 | *
19 | * @param id The identifier
20 | * @param keys The keys
21 | * @param values The values
22 | * @return A PushVectorLong message for type V
23 | */
24 | @inline
25 | override protected def toPushMessage(id: Int, keys: Array[Long], values: Array[Long]): PushVectorLong = {
26 | PushVectorLong(id, keys, values)
27 | }
28 |
29 | /**
30 | * Extracts a value from a given response at given index
31 | *
32 | * @param response The response
33 | * @param index The index
34 | * @return The value
35 | */
36 | @inline
37 | override protected def toValue(response: ResponseLong, index: Int): Long = response.values(index)
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/DeserializationHelper.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client.asyImp
2 |
3 | import akka.actor.ActorSystem
4 | import com.typesafe.config.Config
5 |
6 | /** Singleton pattern
7 | * Created by wjf on 16-9-26.
8 | */
9 | object DeserializationHelper {
10 | @volatile private var as: ActorSystem = null
11 |
12 | def getActorSystem(config: Config): ActorSystem = {
13 | if(as == null) {
14 | DeserializationHelper.synchronized {
15 | if (as == null) {
16 | as = ActorSystem("asysparkClient", config)
17 | }
18 | }
19 | }
20 | as
21 | }
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/PullFSM.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client.asyImp
2 |
3 |
4 | import java.util.concurrent.TimeUnit
5 |
6 | import akka.actor.ActorRef
7 | import akka.pattern.{AskTimeoutException, ask}
8 | import akka.util.Timeout
9 | import com.typesafe.config.Config
10 | import org.apache.spark.asyspark.core.Exceptios.PullFailedException
11 |
12 | import scala.concurrent.{ExecutionContext, Future, Promise}
13 | import scala.concurrent.duration._
14 | import scala.reflect.ClassTag
15 |
16 | /**
17 |  * A pull mechanism that retries requests with an exponentially increasing timeout. Created by wjf on 16-9-25.
18 | */
19 | class PullFSM[T, R: ClassTag](message: T,
20 | actorRef: ActorRef,
21 | maxAttempts: Int,
22 | initialTimeout: FiniteDuration,
23 | maxTimeout: FiniteDuration,
24 | backoffFactor: Double)(implicit ec: ExecutionContext) {
25 | private implicit var timeout: Timeout = new Timeout(initialTimeout)
26 |
27 | private var attempts = 0
28 |   private var runflag = false
29 | private val promise: Promise[R] = Promise[R]()
30 |
31 |   def run(): Future[R] = {
32 | if(!runflag) {
33 | execute()
34 | runflag = true
35 | }
36 | promise.future
37 | }
38 |
39 |   private def execute(): Unit = {
40 |     if(attempts < maxAttempts) {
41 |       attempts += 1
42 |       request()
43 |     } else {
44 |       promise.failure(new PullFailedException(s"Pull failed after $attempts attempts (maximum is $maxAttempts)"))
45 | }
46 | }
47 |
48 |   private def request(): Unit = {
49 | val request = actorRef ? message
50 | request.onFailure {
51 | case ex: AskTimeoutException =>
52 | timeBackoff()
53 | execute()
54 | case _ =>
55 | execute()
56 | }
57 | request.onSuccess {
58 | case response: R =>
59 | promise.success(response)
60 | case _ =>
61 | execute()
62 | }
63 | }
64 |
65 |   private def timeBackoff(): Unit = {
66 |
67 | if(timeout.duration.toMillis * backoffFactor > maxTimeout.toMillis) {
68 | timeout = new Timeout(maxTimeout)
69 | } else {
70 | timeout = new Timeout(timeout.duration.toMillis * backoffFactor millis)
71 | }
72 | }
73 |
74 |
75 |
76 | }
77 |
78 | object PullFSM {
79 |
80 | private var maxAttempts: Int = 10
81 | private var initialTimeout: FiniteDuration = 5 seconds
82 | private var maxTimeout: FiniteDuration = 5 minutes
83 | private var backoffFactor: Double = 1.6
84 |
85 | /**
86 | * Initializes the FSM default parameters with those specified in given config
87 | * @param config The configuration to use
88 | */
89 | def initialize(config: Config): Unit = {
90 | maxAttempts = config.getInt("asyspark.pull.maximum-attempts")
91 | initialTimeout = new FiniteDuration(config.getDuration("asyspark.pull.initial-timeout", TimeUnit.MILLISECONDS),
92 | TimeUnit.MILLISECONDS)
93 | maxTimeout = new FiniteDuration(config.getDuration("asyspark.pull.maximum-timeout", TimeUnit.MILLISECONDS),
94 | TimeUnit.MILLISECONDS)
95 |     backoffFactor = config.getDouble("asyspark.pull.backoff-multiplier")
96 | }
97 |
98 | /**
99 | * Constructs a new FSM for given message and actor
100 | *
101 | * @param message The pull message to send
102 | * @param actorRef The actor to send to
103 | * @param ec The execution context
104 | * @tparam T The type of message to send
105 |    * @return A new and initialized PullFSM
106 | */
107 | def apply[T, R : ClassTag](message: T, actorRef: ActorRef)(implicit ec: ExecutionContext): PullFSM[T, R] = {
108 | new PullFSM[T, R](message, actorRef, maxAttempts, initialTimeout, maxTimeout, backoffFactor)
109 | }
110 |
111 | }
112 |
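timeBackoff() grows the ask timeout geometrically and caps it at the maximum. A small sketch that approximately reproduces the timeout schedule under the defaults above (5 seconds initial, factor 1.6, 5 minute cap, 10 attempts):

import scala.concurrent.duration._

object BackoffScheduleSketch extends App {
  val initial = 5.seconds
  val max = 5.minutes
  val factor = 1.6
  // Multiply the timeout after each timed-out attempt, capping it at `max`
  val schedule = Iterator.iterate(initial) { current =>
    val next = (current.toMillis * factor).millis
    if (next > max) max else next
  }.take(10).toList
  println(schedule.map(_.toSeconds + "s").mkString(" -> "))
  // 5s -> 8s -> 12s -> 20s -> 32s -> 52s -> 83s -> 134s -> 214s -> 300s (seconds truncated)
}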
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/client/asyImp/PushFSM.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.client.asyImp
2 |
3 | import java.util.concurrent.TimeUnit
4 |
5 | import akka.actor.ActorRef
6 | import akka.pattern.{AskTimeoutException, ask}
7 | import akka.util.Timeout
8 | import com.typesafe.config.Config
9 | import org.apache.spark.asyspark.core.Exceptios.PushFailedException
10 | import org.apache.spark.asyspark.core.messages.server.logic._
11 |
12 | import scala.concurrent.duration.{FiniteDuration, _}
13 | import scala.concurrent.{ExecutionContext, Future, Promise}
14 |
15 |
16 | /**
17 | * A push-mechanism using a finite state machine to guarantee exactly-once delivery with multiple attempts
18 | *
19 | * @param message A function that takes an identifier and generates a message of type T
20 | * @param actorRef The actor to send to
21 | * @param maxAttempts The maximum number of attempts
22 | * @param maxLogicAttempts The maximum number of attempts to establish communication through logic channels
23 | * @param initialTimeout The initial timeout for the request
24 | * @param maxTimeout The maximum timeout for the request
25 | * @param backoffFactor The backoff multiplier
26 | * @param ec The execution context
27 | * @tparam T The type of message to send
28 | * @author wjf created 16-9-25
29 | */
30 | class PushFSM[T](message: Int => T,
31 | actorRef: ActorRef,
32 | maxAttempts: Int,
33 | maxLogicAttempts: Int,
34 | initialTimeout: FiniteDuration,
35 | maxTimeout: FiniteDuration,
36 | backoffFactor: Double)(implicit ec: ExecutionContext) {
37 |
38 | private implicit var timeout: Timeout = new Timeout(initialTimeout)
39 |
40 | /**
41 | * Unique identifier for this push
42 | */
43 |   private var id = 0
44 |
45 | /**
46 | * Counter for the number of push attempts
47 | */
48 | private var attempts = 0
49 |
50 | private var logicAttempts = 0
51 |
52 | /**
53 |    * Flag to make sure the push is only started once
54 | */
55 | private var runFlag = false
56 |
57 | private val promise: Promise[Boolean] = Promise[Boolean]()
58 |
59 | /**
60 | * Run the push request
61 | * @return
62 | */
63 | def run(): Future[Boolean] = {
64 | if(!runFlag) {
65 | prepare()
66 | runFlag = true
67 | }
68 | promise.future
69 | }
70 |
71 | /**
72 | * Prepare state
73 | * obtain a unique and available identifier from the parameter server for the next push request
74 | */
75 |   private def prepare(): Unit = {
76 | val prepareFuture = actorRef ? GenerateUniqueID()
77 | prepareFuture.onSuccess {
78 | case UniqueID(identifier) =>
79 | id = identifier
80 | execute()
81 | case _ =>
82 | retry(prepare)
83 | }
84 | prepareFuture.onFailure {
85 | case ex: AskTimeoutException =>
86 | timeBackoff()
87 | retry(prepare)
88 | case _ =>
89 | retry(prepare)
90 | }
91 | }
92 |
93 | /**
94 | * Execute the push request performing a single push
95 | */
96 | private def execute(): Unit = {
97 | if(attempts >= maxAttempts) {
98 |       promise.failure(new PushFailedException(s"Push failed after $attempts attempts (maximum is $maxAttempts)"))
99 | } else {
100 | attempts += 1
101 | actorRef ! message(id)
102 | acknowledge()
103 | }
104 | }
105 |
106 | /**
107 | * Acknowledge state
108 |    * We keep sending acknowledge messages until we receive either an AcknowledgeReceipt or a NotAcknowledgeReceipt
109 | */
110 |   private def acknowledge(): Unit = {
111 | val ackFuture = actorRef ? AcknowledgeReceipt(id)
112 | ackFuture.onSuccess {
113 | case NotAcknowledgeReceipt(identifier) if identifier == id =>
114 | execute()
115 | case AcknowledgeReceipt(identifier) if identifier == id =>
116 | promise.success(true)
117 | forget()
118 | case _ =>
119 | retry(acknowledge)
120 | }
121 | ackFuture.onFailure {
122 | case ex:AskTimeoutException =>
123 | timeBackoff()
124 | retry(acknowledge)
125 | case _ =>
126 | retry(acknowledge)
127 | }
128 |
129 | }
130 |
131 | /**
132 | * Forget state
133 | * We keep sending forget messages until we receive a successful reply
134 | */
135 |   private def forget(): Unit = {
136 | val forgetFuture = actorRef ? Forget(id)
137 | forgetFuture.onSuccess {
138 | case Forget(identifier) if identifier == id =>
139 | ()
140 | case _ =>
141 | forget()
142 | }
143 | forgetFuture.onFailure {
144 | case ex: AskTimeoutException =>
145 | timeBackoff()
146 | forget()
147 | case _ =>
148 | forget()
149 | }
150 |
151 | }
152 |
153 | /**
154 | * Increase the timeout with an exponential backoff
155 | */
156 |   private def timeBackoff(): Unit = {
157 | if(timeout.duration.toMillis * backoffFactor > maxTimeout.toMillis) {
158 | timeout = new Timeout(maxTimeout)
159 | } else {
160 | timeout = new Timeout(timeout.duration.toMillis * backoffFactor millis)
161 | }
162 | }
163 |
164 | /**
165 |    * Retries a function while keeping track of a logic attempts counter and fails when the logic attempts counter is
166 | * too large
167 | * @param func The function to execute again
168 | */
169 |   private def retry(func: () => Unit): Unit = {
170 | logicAttempts += 1
171 | if (logicAttempts < maxLogicAttempts) {
172 | func()
173 | } else {
174 |       promise.failure(new PushFailedException(s"Failed $logicAttempts logic attempts (maximum is $maxLogicAttempts)"))
175 | }
176 | }
177 | }
178 | object PushFSM {
179 | private var maxAttempts: Int = 10
180 | private var maxLogicAttempts: Int = 100
181 | private var initialTimeout: FiniteDuration = 5 seconds
182 | private var maxTimeout: FiniteDuration = 5 minutes
183 | private var backoffFactor: Double = 1.6
184 |
185 | /**
186 |    * Initializes the FSM default parameters with those specified in the given config
187 |    * @param config The configuration to use
188 |    */
189 |   def initialize(config: Config): Unit = {
190 | maxAttempts = config.getInt("asyspark.push.maximum-attempts")
191 | maxLogicAttempts = config.getInt("asyspark.push.maximum-logic-attempts")
192 | initialTimeout = new FiniteDuration(config.getDuration("asyspark.push.initial-timeout", TimeUnit.MILLISECONDS),
193 | TimeUnit.MILLISECONDS
194 | )
195 |     maxTimeout = new FiniteDuration(config.getDuration("asyspark.push.maximum-timeout", TimeUnit.MILLISECONDS),
196 |       TimeUnit.MILLISECONDS)
197 | backoffFactor = config.getDouble("asyspark.push.backoff-multiplier")
198 | }
199 |
200 | /**
201 |    * Constructs a new FSM for given message factory and actor
202 |    * @param message A function that takes an identifier and generates a message of type T
203 |    * @param actorRef The actor to send to
204 |    * @param ec The execution context
205 |    * @tparam T The type of message to send
206 |    * @return A new and initialized PushFSM
207 |    */
208 | def apply[T](message: Int => T, actorRef: ActorRef)(implicit ec: ExecutionContext): PushFSM[T] =
209 | new PushFSM(message, actorRef, maxAttempts, maxLogicAttempts, initialTimeout, maxTimeout, backoffFactor)
210 |
211 | }
212 |
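Both FSM companions read their defaults from a Typesafe Config. A sketch of the keys they expect, with illustrative values only; the project's actual defaults live in conf/default.conf.template:

import com.typesafe.config.ConfigFactory
import org.apache.spark.asyspark.core.models.client.asyImp.{PullFSM, PushFSM}

// Illustrative values only, not the project's real defaults.
object FsmConfigSketch extends App {
  val config = ConfigFactory.parseString(
    """
      |asyspark.pull.maximum-attempts = 10
      |asyspark.pull.initial-timeout = 5 s
      |asyspark.pull.maximum-timeout = 5 m
      |asyspark.pull.backoff-multiplier = 1.6
      |asyspark.push.maximum-attempts = 10
      |asyspark.push.maximum-logic-attempts = 100
      |asyspark.push.initial-timeout = 5 s
      |asyspark.push.maximum-timeout = 5 m
      |asyspark.push.backoff-multiplier = 1.6
    """.stripMargin)

  PullFSM.initialize(config)
  PushFSM.initialize(config)
}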
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVector.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.server
2 |
3 | import akka.actor.{Actor, ActorLogging}
4 | import spire.algebra.Semiring
5 | import spire.implicits._
6 | import org.apache.spark.asyspark.core.partitions.Partition
7 | import scala.reflect.ClassTag
8 |
9 | /**
10 | * Created by wjf on 16-9-25.
11 | */
12 | private[asyspark] abstract class PartialVector[@specialized V:Semiring : ClassTag](partition: Partition) extends Actor
13 | with ActorLogging with PushLogic {
14 |
15 | /**
16 | * the size of the partial vector
17 | */
18 | val size: Int = partition.size
19 |
20 | /**
21 | * the data array contains the elements
22 | */
23 | val data: Array[V]
24 |
25 | /**
26 |    * Updates the data of the partial model by aggregating given keys and values into it
27 |    * @param keys The keys
28 |    * @param values The values
29 |    * @return True if the update succeeded, false otherwise
30 | */
31 |   // TODO: I think this implementation can be optimized
32 | def update(keys: Array[Long], values: Array[V]): Boolean = {
33 | var i = 0
34 | try {
35 | while (i < keys.length) {
36 | val key = partition.globalToLocal(keys(i))
37 |         // += works here thanks to the Semiring operations imported via spire.implicits._
38 | data(key) += values(i)
39 | i += 1
40 | }
41 | true
42 | } catch {
43 | case e: Exception => false
44 | }
45 | }
46 |
47 | def get(keys: Array[Long]): Array[V] = {
48 |     var i = 0
49 | val a = new Array[V](keys.length)
50 | while(i < keys.length) {
51 | val key = partition.globalToLocal(keys(i))
52 | a(i) = data(key)
53 | i += 1
54 | }
55 | a
56 | }
57 |
58 | log.info(s"Constructed PartialVector[${implicitly[ClassTag[V]]}] of size $size (partition id: ${partition.index})")
59 |
60 | }
61 |
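update() aggregates pushed values into the local array rather than overwriting them, using the partition's globalToLocal mapping. An actor-free sketch of that aggregation, assuming a hypothetical partial vector that covers keys [3, 6):

object PartialVectorUpdateSketch extends App {
  val start = 3L                              // this sketch pretends to cover keys [3, 6)
  val data = new Array[Double](3)
  // Mirror of update(): values pushed for the same key accumulate into the local slot
  def update(keys: Array[Long], values: Array[Double]): Unit =
    keys.zip(values).foreach { case (k, v) => data((k - start).toInt) += v }

  update(Array(3L, 5L), Array(1.0, 2.0))
  update(Array(5L), Array(0.5))
  println(data.mkString(", "))                // 1.0, 0.0, 2.5
}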
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVectorDouble.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.server
2 |
3 | import org.apache.spark.asyspark.core.messages.server.request.{PullVector, PushVectorDouble}
4 | import org.apache.spark.asyspark.core.messages.server.response.ResponseDouble
5 | import org.apache.spark.asyspark.core.partitions.Partition
6 | import spire.implicits._
7 | /**
8 | * Created by wjf on 16-9-25.
9 | */
10 | private[asyspark] class PartialVectorDouble(partition: Partition) extends PartialVector[Double](partition) {
11 | override val data: Array[Double] = new Array[Double](size)
12 | override def receive: Receive = {
13 | case pull: PullVector => sender ! ResponseDouble(get(pull.keys))
14 | case push: PushVectorDouble =>
15 | update(push.keys, push.values)
16 | updateFinished(push.id)
17 | case x => handleLogic(x, sender)
18 | }
19 |
20 | }
21 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVectorFloat.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.server
2 |
3 | import org.apache.spark.asyspark.core.messages.server.request.{PullVector, PushVectorFloat}
4 | import org.apache.spark.asyspark.core.messages.server.response.ResponseFloat
5 | import org.apache.spark.asyspark.core.partitions.Partition
6 | import spire.implicits._
7 | /**
8 | * Created by wjf on 16-9-25.
9 | */
10 | private[asyspark] class PartialVectorFloat(partition: Partition) extends PartialVector[Float](partition) {
11 |
12 | override val data: Array[Float] = new Array[Float](size)
13 |
14 | override def receive: Receive = {
15 | case pull: PullVector => sender ! ResponseFloat(get(pull.keys))
16 | case push: PushVectorFloat =>
17 | update(push.keys, push.values)
18 | updateFinished(push.id)
19 | case x => handleLogic(x, sender)
20 | }
21 |
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVectorInt.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.server
2 |
3 | import org.apache.spark.asyspark.core.messages.server.request.{PullVector, PushVectorInt}
4 | import org.apache.spark.asyspark.core.messages.server.response.ResponseInt
5 | import org.apache.spark.asyspark.core.partitions.Partition
6 | import spire.implicits._
7 |
8 | /**
9 | * Created by wjf on 16-9-25.
10 | */
11 | private[asyspark] class PartialVectorInt(partition: Partition) extends PartialVector[Int](partition) {
12 |
13 | override val data: Array[Int] = new Array[Int](size)
14 |
15 | override def receive: Receive = {
16 | case pull: PullVector => sender ! ResponseInt(get(pull.keys))
17 | case push: PushVectorInt =>
18 | update(push.keys, push.values)
19 | updateFinished(push.id)
20 | case x => handleLogic(x, sender)
21 | }
22 |
23 | }
24 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/server/PartialVectorLong.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.server
2 |
3 | import org.apache.spark.asyspark.core.messages.server.request.{PullVector, PushVectorLong}
4 | import org.apache.spark.asyspark.core.messages.server.response.ResponseLong
5 | import org.apache.spark.asyspark.core.partitions.Partition
6 | import spire.implicits._
7 | /**
8 | * Created by wjf on 16-9-25.
9 | */
10 | private[asyspark] class PartialVectorLong(partition: Partition) extends PartialVector[Long](partition) {
11 |
12 | override val data: Array[Long] = new Array[Long](size)
13 |
14 | override def receive: Receive = {
15 | case pull: PullVector => sender ! ResponseLong(get(pull.keys))
16 | case push: PushVectorLong =>
17 | update(push.keys, push.values)
18 | updateFinished(push.id)
19 | case x => handleLogic(x, sender)
20 | }
21 |
22 | }
23 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/models/server/PushLogic.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.models.server
2 |
3 | import akka.actor.ActorRef
4 | import org.apache.spark.asyspark.core.messages.server.logic._
5 |
6 | import scala.collection.mutable
7 |
8 | /** Some common push logic behavior
9 | * Created by wjf on 16-9-25.
10 | */
11 | trait PushLogic {
12 | /**
13 | * A set of received message ids
14 | */
15 |   val receipt: mutable.HashSet[Int] = mutable.HashSet[Int]()
16 |
17 | /**
18 | * Unique identifier counter
19 | */
20 | var UID: Int = 0
21 |
22 | /**
23 | * Increases the unique id and returns the next unique id
24 | * @return The next id
25 | */
26 | private def nextId(): Int = {
27 | UID += 1
28 | UID
29 | }
30 |
31 | /**
32 | * Handle push message receipt logic
33 | * @param message The message
34 | * @param sender The sender
35 | */
36 |   def handleLogic(message: Any, sender: ActorRef): Unit = message match {
37 | case GenerateUniqueID() =>
38 | sender ! UniqueID(nextId())
39 |
40 | case AcknowledgeReceipt(id) =>
41 | if(receipt.contains(id)) {
42 | sender ! AcknowledgeReceipt(id)
43 | } else {
44 | sender ! NotAcknowledgeReceipt(id)
45 | }
46 |
47 | case Forget(id) =>
48 | if(receipt.contains(id)) {
49 | receipt.remove(id)
50 | }
51 | sender ! Forget(id)
52 | }
53 |
54 | /**
55 | * Adds the message id to the received set
56 | * @param id The message id
57 | */
58 |   def updateFinished(id: Int): Unit = {
59 |     require(id >= 0, s"id must be non-negative but got $id")
60 | receipt.add(id)
61 | }
62 |
63 |
64 | }
65 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/partitions/Partition.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.partitions
2 |
3 | /**
4 | * Created by wjf on 16-9-24.
5 | */
6 | abstract class Partition(val index: Int) extends Serializable {
7 |
8 | /**
9 | * check whether this partition contains some key
10 | * @param key the key
11 | * @return whether this partition contains some key
12 | */
13 | def contains(key: Long): Boolean
14 |
15 | /**
16 | * converts given global key to a continuous local array index
17 | * @param key the global key
18 | * @return the local index
19 | */
20 | def globalToLocal(key: Long): Int
21 |
22 | /**
23 | * the size of this partition
24 | * @return the size
25 | */
26 |   def size: Int
27 | }
28 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/partitions/Partitioner.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.partitions
2 |
3 | /**
4 | * Created by wjf on 16-9-24.
5 | */
6 | trait Partitioner extends Serializable {
7 | /**
8 | * Assign a server to the given key
9 | * @param key The key to partition
10 | * @return The partition
11 | */
12 | def partition(key: Long): Partition
13 |
14 | /**
15 | * Returns all partitions
16 | * @return The array of partitions
17 | */
18 | def all(): Array[Partition]
19 |
20 | }
21 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/partitions/range/RangePartition.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.partitions.range
2 |
3 | import org.apache.spark.asyspark.core.partitions.Partition
4 |
5 | /**
6 | * A range partition
7 | * Created by wjf on 16-9-24.
8 | */
9 | class RangePartition(index: Int, val start: Long, val end: Long) extends Partition(index) {
10 |
11 | @inline
12 |   override def contains(key: Long): Boolean = key >= start && key < end // the range is half-open: [start, end)
13 |
14 | @inline
15 | override def size: Int = (end - start).toInt
16 |
17 | /**
18 | * Converts given key to a continuous local array index[0,1,2...]
19 | * @param key the global key
20 | * @return the local index
21 | */
22 | @inline
23 | def globalToLocal(key: Long): Int = (key - start).toInt
24 |
25 | }
26 |
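A quick sketch of the half-open [start, end) range semantics; the concrete numbers are illustrative only:

import org.apache.spark.asyspark.core.partitions.range.RangePartition

object RangePartitionSketch extends App {
  val partition = new RangePartition(index = 1, start = 3L, end = 6L)
  println(partition.size)              // 3 (keys 3, 4 and 5)
  println(partition.contains(5L))      // true
  println(partition.contains(6L))      // false: key 6 belongs to the next partition
  println(partition.globalToLocal(5L)) // 2
}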
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/partitions/range/RangePartitioner.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.partitions.range
2 |
3 | import org.apache.spark.asyspark.core.partitions.{Partition, Partitioner}
4 |
5 | /**
6 | * Created by wjf on 16-9-24.
7 | */
8 | class RangePartitioner(val partitions: Array[Partition], val numOfSmallPartitions: Int, val keysSmallPartition: Int, val keysSize: Long) extends Partitioner {
9 |   val numAllSmallPartitions: Long = numOfSmallPartitions.toLong * keysSmallPartition.toLong
10 |   val sizeOfLargePartitions = keysSmallPartition + 1
11 |
12 |   @inline
13 |   override def partition(key: Long): Partition = {
14 |     require(key >= 0 && key < keysSize, s"Key $key is out of bounds for size $keysSize")
15 |     val index = if(key < numAllSmallPartitions) {
16 |       (key / keysSmallPartition).toInt
17 |     } else {
18 |       (numOfSmallPartitions + (key - numAllSmallPartitions) / sizeOfLargePartitions).toInt
19 |     }
20 |     partitions(index)
21 |   }
22 |
23 | override def all(): Array[Partition] = partitions
24 |
25 | }
26 | object RangePartitioner {
27 |
28 | /**
29 | * Create a RangePartitioner for given number of partition and keys
30 |    * @param numOfPartitions The number of partitions
31 |    * @param numberOfKeys The number of keys
32 |    * @return A RangePartitioner
33 | */
34 | def apply(numOfPartitions: Int, numberOfKeys: Long): RangePartitioner = {
35 | val partitions = new Array[Partition](numOfPartitions)
36 |     // a "large" partition holds exactly one more key than a "small" partition
37 | val numLargePartition = numberOfKeys % numOfPartitions
38 | val numSmallPartition = numOfPartitions - numLargePartition
39 | val keysSmallPartition = ((numberOfKeys - numLargePartition) / numOfPartitions).toInt
40 | var i = 0
41 | var start: Long = 0L
42 | var end: Long = 0L
43 | while(i < numOfPartitions) {
44 | if(i < numSmallPartition) {
45 | end += keysSmallPartition
46 | partitions(i) = new RangePartition(i, start, end)
47 | start += keysSmallPartition
48 | } else {
49 | end += keysSmallPartition + 1
50 | partitions(i) = new RangePartition(i, start, end)
51 | start += keysSmallPartition + 1
52 | }
53 | i += 1
54 | }
55 |
56 |
57 |     new RangePartitioner(partitions, numSmallPartition.toInt, keysSmallPartition, numberOfKeys)
58 |
59 | }
60 | }
61 |
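A small worked example: ten keys over three partitions give two small partitions of three keys and one large partition of four keys, i.e. the ranges [0, 3), [3, 6) and [6, 10):

import org.apache.spark.asyspark.core.partitions.range.RangePartitioner

object RangePartitionerSketch extends App {
  val partitioner = RangePartitioner(3, 10L)
  partitioner.all().foreach(p => println(s"partition ${p.index} has size ${p.size}"))
  // partition 0 has size 3, partition 1 has size 3, partition 2 has size 4
  val p = partitioner.partition(7L)
  println(p.index)             // 2: key 7 falls in the large partition [6, 10)
  println(p.globalToLocal(7L)) // 1: it is the second key of that partition
}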
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/serialization/FastPrimitiveDeserializer.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.serialization
2 |
3 | /**
4 | * A very fast primitive deserializer using sun's Unsafe to directly read/write memory regions in the JVM
5 | *
6 | * @param bytes The serialized data
7 | */
8 | private class FastPrimitiveDeserializer(bytes: Array[Byte]) {
9 |
10 | private val unsafe = SerializationConstants.unsafe
11 | private val offset = unsafe.arrayBaseOffset(classOf[Array[Byte]])
12 | private var position: Long = 0
13 |
14 | @inline
15 | def readFloat(): Float = {
16 | position += SerializationConstants.sizeOfFloat
17 | unsafe.getFloat(bytes, offset + position - SerializationConstants.sizeOfFloat)
18 | }
19 |
20 | @inline
21 | def readDouble(): Double = {
22 | position += SerializationConstants.sizeOfDouble
23 | unsafe.getDouble(bytes, offset + position - SerializationConstants.sizeOfDouble)
24 | }
25 |
26 | @inline
27 | def readInt(): Int = {
28 | position += SerializationConstants.sizeOfInt
29 | unsafe.getInt(bytes, offset + position - SerializationConstants.sizeOfInt)
30 | }
31 |
32 | @inline
33 | def readLong(): Long = {
34 | position += SerializationConstants.sizeOfLong
35 | unsafe.getLong(bytes, offset + position - SerializationConstants.sizeOfLong)
36 | }
37 |
38 | @inline
39 | def readByte(): Byte = {
40 | position += SerializationConstants.sizeOfByte
41 | unsafe.getByte(bytes, offset + position - SerializationConstants.sizeOfByte)
42 | }
43 |
44 | @inline
45 | def readArrayInt(size: Int): Array[Int] = {
46 | val array = new Array[Int](size)
47 | unsafe.copyMemory(bytes, offset + position, array, unsafe.arrayBaseOffset(classOf[Array[Int]]), size * SerializationConstants.sizeOfInt)
48 | position += size * SerializationConstants.sizeOfInt
49 | array
50 | }
51 |
52 | @inline
53 | def readArrayLong(size: Int): Array[Long] = {
54 | val array = new Array[Long](size)
55 | unsafe.copyMemory(bytes, offset + position, array, unsafe.arrayBaseOffset(classOf[Array[Long]]), size * SerializationConstants.sizeOfLong)
56 | position += size * SerializationConstants.sizeOfLong
57 | array
58 | }
59 |
60 | @inline
61 | def readArrayFloat(size: Int): Array[Float] = {
62 | val array = new Array[Float](size)
63 | unsafe.copyMemory(bytes, offset + position, array, unsafe.arrayBaseOffset(classOf[Array[Float]]), size * SerializationConstants.sizeOfFloat)
64 | position += size * SerializationConstants.sizeOfFloat
65 | array
66 | }
67 |
68 | @inline
69 | def readArrayDouble(size: Int): Array[Double] = {
70 | val array = new Array[Double](size)
71 | unsafe.copyMemory(bytes, offset + position, array, unsafe.arrayBaseOffset(classOf[Array[Double]]), size * SerializationConstants.sizeOfDouble)
72 | position += size * SerializationConstants.sizeOfDouble
73 | array
74 | }
75 |
76 | }
77 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/serialization/FastPrimitiveSerializer.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.serialization
2 |
3 | /**
4 | * A very fast primitive serializer using sun's Unsafe to directly read/write memory regions in the JVM
5 | *
6 | * @param size The size of the serialized output (in bytes)
7 | */
8 | private class FastPrimitiveSerializer(size: Int) {
9 |
10 | val bytes = new Array[Byte](size)
11 |
12 | private val unsafe = SerializationConstants.unsafe
13 | private val offset = unsafe.arrayBaseOffset(classOf[Array[Byte]])
14 | private var position: Long = 0
15 |
16 | @inline
17 | def reset(): Unit = position = 0L
18 |
19 | @inline
20 | def writeFloat(value: Float): Unit = {
21 | unsafe.putFloat(bytes, offset + position, value)
22 | position += SerializationConstants.sizeOfFloat
23 | }
24 |
25 | @inline
26 | def writeInt(value: Int): Unit = {
27 | unsafe.putInt(bytes, offset + position, value)
28 | position += SerializationConstants.sizeOfInt
29 | }
30 |
31 | @inline
32 | def writeByte(value: Byte): Unit = {
33 | unsafe.putByte(bytes, offset + position, value)
34 | position += SerializationConstants.sizeOfByte
35 | }
36 |
37 | @inline
38 | def writeLong(value: Long): Unit = {
39 | unsafe.putLong(bytes, offset + position, value)
40 | position += SerializationConstants.sizeOfLong
41 | }
42 |
43 | @inline
44 | def writeDouble(value: Double): Unit = {
45 | unsafe.putDouble(bytes, offset + position, value)
46 | position += SerializationConstants.sizeOfDouble
47 | }
48 |
49 | @inline
50 | def writeArrayInt(value: Array[Int]): Unit = {
51 | unsafe.copyMemory(value, unsafe.arrayBaseOffset(classOf[Array[Int]]), bytes, offset + position, value.length * SerializationConstants.sizeOfInt)
52 | position += value.length * SerializationConstants.sizeOfInt
53 | }
54 |
55 | @inline
56 | def writeArrayLong(value: Array[Long]): Unit = {
57 | unsafe.copyMemory(value, unsafe.arrayBaseOffset(classOf[Array[Long]]), bytes, offset + position, value.length * SerializationConstants.sizeOfLong)
58 | position += value.length * SerializationConstants.sizeOfLong
59 | }
60 |
61 | @inline
62 | def writeArrayFloat(value: Array[Float]): Unit = {
63 | unsafe.copyMemory(value, unsafe.arrayBaseOffset(classOf[Array[Float]]), bytes, offset + position, value.length * SerializationConstants.sizeOfFloat)
64 | position += value.length * SerializationConstants.sizeOfFloat
65 | }
66 |
67 | @inline
68 | def writeArrayDouble(value: Array[Double]): Unit = {
69 | unsafe.copyMemory(value, unsafe.arrayBaseOffset(classOf[Array[Double]]), bytes, offset + position, value.length * SerializationConstants.sizeOfDouble)
70 | position += value.length * SerializationConstants.sizeOfDouble
71 | }
72 |
73 | }
74 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/serialization/RequestSerializer.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.serialization
2 | import akka.serialization._
3 | import org.apache.spark.asyspark.core.messages.server.request._
4 |
5 |
6 | /**
7 | * A fast serializer for requests
8 | *
9 | * Internally this uses a very fast primitive serialization/deserialization routine using sun's Unsafe class for direct
10 | * read/write access to JVM memory. This might not be portable across different JVMs. If serialization causes problems
11 | * you can default to JavaSerialization by removing the serialization-bindings in the configuration.
12 | * Created by wjf on 16-9-23.
13 | */
14 | class RequestSerializer extends Serializer {
15 |
16 | override def identifier: Int = 13370
17 |
18 | override def includeManifest: Boolean = false
19 |
20 | override def fromBinary(bytes: Array[Byte], manifest: Option[Class[_]]): AnyRef = {
21 | val fpd = new FastPrimitiveDeserializer(bytes)
22 | val objectType = fpd.readByte()
23 | val objectSize = fpd.readInt()
24 |
25 | objectType match {
26 | // case SerializationConstants.pullMatrixByte =>
27 | // val rows = fpd.readArrayLong(objectSize)
28 | // val cols = fpd.readArrayInt(objectSize)
29 | // PullMatrix(rows, cols)
30 | //
31 | // case SerializationConstants.pullMatrixRowsByte =>
32 | // val rows = fpd.readArrayLong(objectSize)
33 | // PullMatrixRows(rows)
34 |
35 | case SerializationConstants.pullVectorByte =>
36 | val keys = fpd.readArrayLong(objectSize)
37 | PullVector(keys)
38 |
39 | // case SerializationConstants.pushMatrixDoubleByte =>
40 | // val id = fpd.readInt()
41 | // val rows = fpd.readArrayLong(objectSize)
42 | // val cols = fpd.readArrayInt(objectSize)
43 | // val values = fpd.readArrayDouble(objectSize)
44 | // PushMatrixDouble(id, rows, cols, values)
45 | //
46 | // case SerializationConstants.pushMatrixFloatByte =>
47 | // val id = fpd.readInt()
48 | // val rows = fpd.readArrayLong(objectSize)
49 | // val cols = fpd.readArrayInt(objectSize)
50 | // val values = fpd.readArrayFloat(objectSize)
51 | // PushMatrixFloat(id, rows, cols, values)
52 | //
53 | // case SerializationConstants.pushMatrixIntByte =>
54 | // val id = fpd.readInt()
55 | // val rows = fpd.readArrayLong(objectSize)
56 | // val cols = fpd.readArrayInt(objectSize)
57 | // val values = fpd.readArrayInt(objectSize)
58 | // PushMatrixInt(id, rows, cols, values)
59 | //
60 | // case SerializationConstants.pushMatrixLongByte =>
61 | // val id = fpd.readInt()
62 | // val rows = fpd.readArrayLong(objectSize)
63 | // val cols = fpd.readArrayInt(objectSize)
64 | // val values = fpd.readArrayLong(objectSize)
65 | // PushMatrixLong(id, rows, cols, values)
66 |
67 | case SerializationConstants.pushVectorDoubleByte =>
68 | val id = fpd.readInt()
69 | val keys = fpd.readArrayLong(objectSize)
70 | val values = fpd.readArrayDouble(objectSize)
71 | PushVectorDouble(id, keys, values)
72 |
73 | case SerializationConstants.pushVectorFloatByte =>
74 | val id = fpd.readInt()
75 | val keys = fpd.readArrayLong(objectSize)
76 | val values = fpd.readArrayFloat(objectSize)
77 | PushVectorFloat(id, keys, values)
78 |
79 | case SerializationConstants.pushVectorIntByte =>
80 | val id = fpd.readInt()
81 | val keys = fpd.readArrayLong(objectSize)
82 | val values = fpd.readArrayInt(objectSize)
83 | PushVectorInt(id, keys, values)
84 |
85 | case SerializationConstants.pushVectorLongByte =>
86 | val id = fpd.readInt()
87 | val keys = fpd.readArrayLong(objectSize)
88 | val values = fpd.readArrayLong(objectSize)
89 | PushVectorLong(id, keys, values)
90 | }
91 | }
92 |
93 | override def toBinary(o: AnyRef): Array[Byte] = {
94 | o match {
95 | // case x: PullMatrix =>
96 | // val fps = new FastPrimitiveSerializer(5 + x.rows.length * SerializationConstants.sizeOfLong +
97 | // x.rows.length * SerializationConstants.sizeOfInt)
98 | // fps.writeByte(SerializationConstants.pullMatrixByte)
99 | // fps.writeInt(x.rows.length)
100 | // fps.writeArrayLong(x.rows)
101 | // fps.writeArrayInt(x.cols)
102 | // fps.bytes
103 | //
104 | // case x: PullMatrixRows =>
105 | // val fps = new FastPrimitiveSerializer(5 + x.rows.length * SerializationConstants.sizeOfLong)
106 | // fps.writeByte(SerializationConstants.pullMatrixRowsByte)
107 | // fps.writeInt(x.rows.length)
108 | // fps.writeArrayLong(x.rows)
109 | // fps.bytes
110 |
111 | case x: PullVector =>
112 | val fps = new FastPrimitiveSerializer(5 + x.keys.length * SerializationConstants.sizeOfLong)
113 | fps.writeByte(SerializationConstants.pullVectorByte)
114 | fps.writeInt(x.keys.length)
115 | fps.writeArrayLong(x.keys)
116 | fps.bytes
117 |
118 | // case x: PushMatrixDouble =>
119 | // val fps = new FastPrimitiveSerializer(9 + x.rows.length * SerializationConstants.sizeOfLong +
120 | // x.rows.length * SerializationConstants.sizeOfInt +
121 | // x.rows.length * SerializationConstants.sizeOfDouble)
122 | // fps.writeByte(SerializationConstants.pushMatrixDoubleByte)
123 | // fps.writeInt(x.rows.length)
124 | // fps.writeInt(x.id)
125 | // fps.writeArrayLong(x.rows)
126 | // fps.writeArrayInt(x.cols)
127 | // fps.writeArrayDouble(x.values)
128 | // fps.bytes
129 | //
130 | // case x: PushMatrixFloat =>
131 | // val fps = new FastPrimitiveSerializer(9 + x.rows.length * SerializationConstants.sizeOfLong +
132 | // x.rows.length * SerializationConstants.sizeOfInt +
133 | // x.rows.length * SerializationConstants.sizeOfFloat)
134 | // fps.writeByte(SerializationConstants.pushMatrixFloatByte)
135 | // fps.writeInt(x.rows.length)
136 | // fps.writeInt(x.id)
137 | // fps.writeArrayLong(x.rows)
138 | // fps.writeArrayInt(x.cols)
139 | // fps.writeArrayFloat(x.values)
140 | // fps.bytes
141 | //
142 | // case x: PushMatrixInt =>
143 | // val fps = new FastPrimitiveSerializer(9 + x.rows.length * SerializationConstants.sizeOfLong +
144 | // x.rows.length * SerializationConstants.sizeOfInt +
145 | // x.rows.length * SerializationConstants.sizeOfInt)
146 | // fps.writeByte(SerializationConstants.pushMatrixIntByte)
147 | // fps.writeInt(x.rows.length)
148 | // fps.writeInt(x.id)
149 | // fps.writeArrayLong(x.rows)
150 | // fps.writeArrayInt(x.cols)
151 | // fps.writeArrayInt(x.values)
152 | // fps.bytes
153 | //
154 | // case x: PushMatrixLong =>
155 | // val fps = new FastPrimitiveSerializer(9 + x.rows.length * SerializationConstants.sizeOfLong +
156 | // x.rows.length * SerializationConstants.sizeOfInt +
157 | // x.rows.length * SerializationConstants.sizeOfLong)
158 | // fps.writeByte(SerializationConstants.pushMatrixLongByte)
159 | // fps.writeInt(x.rows.length)
160 | // fps.writeInt(x.id)
161 | // fps.writeArrayLong(x.rows)
162 | // fps.writeArrayInt(x.cols)
163 | // fps.writeArrayLong(x.values)
164 | // fps.bytes
165 |
166 | case x: PushVectorDouble =>
167 | val fps = new FastPrimitiveSerializer(9 + x.keys.length * SerializationConstants.sizeOfLong +
168 | x.keys.length * SerializationConstants.sizeOfDouble)
169 | fps.writeByte(SerializationConstants.pushVectorDoubleByte)
170 | fps.writeInt(x.keys.length)
171 | fps.writeInt(x.id)
172 | fps.writeArrayLong(x.keys)
173 | fps.writeArrayDouble(x.values)
174 | fps.bytes
175 |
176 | case x: PushVectorFloat =>
177 | val fps = new FastPrimitiveSerializer(9 + x.keys.length * SerializationConstants.sizeOfLong +
178 | x.keys.length * SerializationConstants.sizeOfFloat)
179 | fps.writeByte(SerializationConstants.pushVectorFloatByte)
180 | fps.writeInt(x.keys.length)
181 | fps.writeInt(x.id)
182 | fps.writeArrayLong(x.keys)
183 | fps.writeArrayFloat(x.values)
184 | fps.bytes
185 |
186 | case x: PushVectorInt =>
187 | val fps = new FastPrimitiveSerializer(9 + x.keys.length * SerializationConstants.sizeOfLong +
188 | x.keys.length * SerializationConstants.sizeOfInt)
189 | fps.writeByte(SerializationConstants.pushVectorIntByte)
190 | fps.writeInt(x.keys.length)
191 | fps.writeInt(x.id)
192 | fps.writeArrayLong(x.keys)
193 | fps.writeArrayInt(x.values)
194 | fps.bytes
195 |
196 | case x: PushVectorLong =>
197 | val fps = new FastPrimitiveSerializer(9 + x.keys.length * SerializationConstants.sizeOfLong +
198 | x.keys.length * SerializationConstants.sizeOfLong)
199 | fps.writeByte(SerializationConstants.pushVectorLongByte)
200 | fps.writeInt(x.keys.length)
201 | fps.writeInt(x.id)
202 | fps.writeArrayLong(x.keys)
203 | fps.writeArrayLong(x.values)
204 | fps.bytes
205 | }
206 | }
207 | }
208 |
209 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/serialization/ResponseSerializer.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.serialization
2 |
3 | import akka.serialization._
4 | import org.apache.spark.asyspark.core.messages.server.response.{ResponseDouble, ResponseFloat, ResponseInt, ResponseLong}
5 |
6 | /**
7 | * A fast serializer for responses
8 | *
9 | * Internally this uses a very fast primitive serialization/deserialization routine using sun's Unsafe class for direct
10 | * read/write access to JVM memory. This might not be portable across different JVMs. If serialization causes problems
11 | * you can default to JavaSerialization by removing the serialization-bindings in the configuration.
12 | */
13 | class ResponseSerializer extends Serializer {
14 |
15 | override def identifier: Int = 13371
16 |
17 | override def includeManifest: Boolean = false
18 |
19 | override def fromBinary(bytes: Array[Byte], manifest: Option[Class[_]]): AnyRef = {
20 | val fpd = new FastPrimitiveDeserializer(bytes)
21 | val objectType = fpd.readByte()
22 | val objectSize = fpd.readInt()
23 |
24 | objectType match {
25 | case SerializationConstants.responseDoubleByte =>
26 | val values = fpd.readArrayDouble(objectSize)
27 | ResponseDouble(values)
28 |
29 | case SerializationConstants.responseFloatByte =>
30 | val values = fpd.readArrayFloat(objectSize)
31 | ResponseFloat(values)
32 |
33 | case SerializationConstants.responseIntByte =>
34 | val values = fpd.readArrayInt(objectSize)
35 | ResponseInt(values)
36 |
37 | case SerializationConstants.responseLongByte =>
38 | val values = fpd.readArrayLong(objectSize)
39 | ResponseLong(values)
40 | }
41 | }
42 |
43 | override def toBinary(o: AnyRef): Array[Byte] = {
44 | o match {
45 | case x: ResponseDouble =>
46 | val fps = new FastPrimitiveSerializer(5 + x.values.length * SerializationConstants.sizeOfDouble)
47 | fps.writeByte(SerializationConstants.responseDoubleByte)
48 | fps.writeInt(x.values.length)
49 | fps.writeArrayDouble(x.values)
50 | fps.bytes
51 |
52 | // case x: ResponseRowsDouble =>
53 | // val fps = new FastPrimitiveSerializer(5 + x.values.length * x.columns * SerializationConstants.sizeOfDouble)
54 | // fps.writeByte(SerializationConstants.responseDoubleByte)
55 | // fps.writeInt(x.values.length * x.columns)
56 | // var i = 0
57 | // while (i < x.values.length) {
58 | // fps.writeArrayDouble(x.values(i))
59 | // i += 1
60 | // }
61 | // fps.bytes
62 | //
63 | case x: ResponseFloat =>
64 | val fps = new FastPrimitiveSerializer(5 + x.values.length * SerializationConstants.sizeOfFloat)
65 | fps.writeByte(SerializationConstants.responseFloatByte)
66 | fps.writeInt(x.values.length)
67 | fps.writeArrayFloat(x.values)
68 | fps.bytes
69 | //
70 | // case x: ResponseRowsFloat =>
71 | // val fps = new FastPrimitiveSerializer(5 + x.values.length * x.columns * SerializationConstants.sizeOfFloat)
72 | // fps.writeByte(SerializationConstants.responseFloatByte)
73 | // fps.writeInt(x.values.length * x.columns)
74 | // var i = 0
75 | // while (i < x.values.length) {
76 | // fps.writeArrayFloat(x.values(i))
77 | // i += 1
78 | // }
79 | // fps.bytes
80 |
81 | case x: ResponseInt =>
82 | val fps = new FastPrimitiveSerializer(5 + x.values.length * SerializationConstants.sizeOfInt)
83 | fps.writeByte(SerializationConstants.responseIntByte)
84 | fps.writeInt(x.values.length)
85 | fps.writeArrayInt(x.values)
86 | fps.bytes
87 |
88 | // case x: ResponseRowsInt =>
89 | // val fps = new FastPrimitiveSerializer(5 + x.values.length * x.columns * SerializationConstants.sizeOfInt)
90 | // fps.writeByte(SerializationConstants.responseIntByte)
91 | // fps.writeInt(x.values.length * x.columns)
92 | // var i = 0
93 | // while (i < x.values.length) {
94 | // fps.writeArrayInt(x.values(i))
95 | // i += 1
96 | // }
97 | // fps.bytes
98 |
99 | case x: ResponseLong =>
100 | val fps = new FastPrimitiveSerializer(5 + x.values.length * SerializationConstants.sizeOfLong)
101 | fps.writeByte(SerializationConstants.responseLongByte)
102 | fps.writeInt(x.values.length)
103 | fps.writeArrayLong(x.values)
104 | fps.bytes
105 |
106 | // case x: ResponseRowsLong =>
107 | // val fps = new FastPrimitiveSerializer(5 + x.values.length * x.columns * SerializationConstants.sizeOfLong)
108 | // fps.writeByte(SerializationConstants.responseLongByte)
109 | // fps.writeInt(x.values.length * x.columns)
110 | // var i = 0
111 | // while (i < x.values.length) {
112 | // fps.writeArrayLong(x.values(i))
113 | // i += 1
114 | // }
115 | // fps.bytes
116 | }
117 | }
118 | }
119 |
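A quick round-trip sketch through this serializer: the wire format is one type byte, a four-byte length and then the raw primitive payload, so three doubles serialize to 29 bytes. As noted above, this relies on sun.misc.Unsafe and is therefore JVM dependent:

import org.apache.spark.asyspark.core.messages.server.response.ResponseDouble
import org.apache.spark.asyspark.core.serialization.ResponseSerializer

object ResponseRoundTripSketch extends App {
  val serializer = new ResponseSerializer
  val bytes = serializer.toBinary(ResponseDouble(Array(1.0, 2.0, 3.0)))
  println(bytes.length)                               // 29 = 1 type byte + 4 length bytes + 3 * 8 payload bytes
  val back = serializer.fromBinary(bytes, None).asInstanceOf[ResponseDouble]
  println(back.values.mkString(", "))                 // 1.0, 2.0, 3.0
}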
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/asyspark/core/serialization/SerializationConstants.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.asyspark.core.serialization
2 |
3 | import sun.misc.Unsafe
4 |
5 | /**
6 | * Some constants used for serialization
7 | */
8 | private object SerializationConstants {
9 |
10 | // Unsafe field for direct memory access
11 | private val field = classOf[Unsafe].getDeclaredField("theUnsafe")
12 | field.setAccessible(true)
13 | val unsafe = field.get(null).asInstanceOf[Unsafe]
14 |
15 | // Size of different java primitives to perform direct read/write to java memory
16 | val sizeOfByte = 1
17 | val sizeOfShort = 2
18 | val sizeOfInt = 4
19 | val sizeOfLong = 8
20 | val sizeOfFloat = 4
21 | val sizeOfDouble = 8
22 |
23 | // Byte identifiers for different message types
24 | val pullMatrixByte: Byte = 0x00
25 | val pullMatrixRowsByte: Byte = 0x01
26 | val pullVectorByte: Byte = 0x02
27 | val pushMatrixDoubleByte: Byte = 0x03
28 | val pushMatrixFloatByte: Byte = 0x04
29 | val pushMatrixIntByte: Byte = 0x05
30 | val pushMatrixLongByte: Byte = 0x06
31 | val pushVectorDoubleByte: Byte = 0x07
32 | val pushVectorFloatByte: Byte = 0x08
33 | val pushVectorIntByte: Byte = 0x09
34 | val pushVectorLongByte: Byte = 0x0A
35 | val responseDoubleByte: Byte = 0x10
36 | val responseFloatByte: Byte = 0x11
37 | val responseIntByte: Byte = 0x12
38 | val responseLongByte: Byte = 0x13
39 |
40 | }
41 |
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/examples/AsySGDExample.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.examples
2 |
3 | import org.apache.spark.asyspark.asyml.asysgd.AsyGradientDescent
4 | import org.apache.spark.mllib.linalg.Vectors
5 | import org.apache.spark.mllib.optimization.{GradientDescent, LogisticGradient, SimpleUpdater}
6 | import org.apache.spark.mllib.regression.LabeledPoint
7 | import org.apache.spark.mllib.util.MLUtils
8 | import org.apache.spark.{SparkConf, SparkContext}
9 |
10 | import scala.collection.JavaConverters._
11 | import scala.util.Random
12 | /**
13 | * Created by wjf on 16-9-19.
14 | */
15 | object AsySGDExample {
16 |
17 | def main(args: Array[String]): Unit = {
18 | val sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local[*]"))
19 | val nPoints = 4000000
20 | val A = 2.0
21 | val B = -1.5
22 |
23 | val initialB = -1.0
24 | val initialWeights = Array(initialB)
25 |
26 | val gradient = new LogisticGradient()
27 | val updater = new SimpleUpdater()
28 | val stepSize = 1.0
29 | val numIterations = 1000
30 | val regParam = 0
31 | val miniBatchFrac = 1.0
32 | val convergenceTolerance = 5.0e-1
33 |
34 |     // Append an extra feature of all 1.0s to each point for the intercept.
35 | val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
36 |
37 | val data = testData.map { case LabeledPoint(label, features) =>
38 | label -> MLUtils.appendBias(features)
39 | }
40 |
41 | val dataRDD = sc.parallelize(data, 8).cache()
42 | val initialWeightsWithIntercept = Vectors.dense(initialWeights.toArray :+ 1.0)
43 |
44 |     // our asynchronous implementation
45 | var start = System.nanoTime()
46 | val (weights, weightHistory) = AsyGradientDescent.runAsySGD(
47 | dataRDD,
48 | gradient,
49 | updater,
50 | stepSize,
51 | numIterations,
52 | regParam,
53 | miniBatchFrac,
54 | initialWeightsWithIntercept,
55 | convergenceTolerance)
56 | var end = System.nanoTime()
57 |     println((end - start) / 1e6 + "ms")
58 |     weights.toArray.foreach(println)
59 |     // Spark's built-in mini-batch SGD, for comparison
60 |     start = System.nanoTime()
61 |     val (weight, weightHistorys) = GradientDescent.runMiniBatchSGD(dataRDD,
62 | gradient,
63 | updater,
64 | stepSize,
65 | numIterations,
66 | regParam,
67 | miniBatchFrac,
68 | initialWeightsWithIntercept,
69 | convergenceTolerance)
70 | end = System.nanoTime()
71 |     println((end - start) / 1e6 + "ms")
72 | weight.toArray.foreach(println)
73 | }
74 |
75 | }
76 | object GradientDescentSuite {
77 |
78 | def generateLogisticInputAsList(
79 | offset: Double,
80 | scale: Double,
81 | nPoints: Int,
82 | seed: Int): java.util.List[LabeledPoint] = {
83 | generateGDInput(offset, scale, nPoints, seed).asJava
84 | }
85 | // Generate input of the form Y = logistic(offset + scale * X)
86 | def generateGDInput(
87 | offset: Double,
88 | scale: Double,
89 | nPoints: Int,
90 | seed: Int): Seq[LabeledPoint] = {
91 | val rnd = new Random(seed)
92 | val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
93 |
94 | val unifRand = new Random(45)
95 | val rLogis = (0 until nPoints).map { i =>
96 | val u = unifRand.nextDouble()
97 | math.log(u) - math.log(1.0-u)
98 | }
99 |
100 | val y: Seq[Int] = (0 until nPoints).map { i =>
101 | val yVal = offset + scale * x1(i) + rLogis(i)
102 | if (yVal > 0) 1 else 0
103 | }
104 | (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(x1(i))))
105 | }
106 | }
--------------------------------------------------------------------------------
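A note on GradientDescentSuite.generateGDInput: adding a standard-logistic noise term rLogis(i) and thresholding at zero is the same as drawing each label from a Bernoulli whose success probability is logistic(offset + scale * x), since for L ~ Logistic(0, 1) we have P(offset + scale * x + L > 0) = 1 / (1 + exp(-(offset + scale * x))). The equivalent direct sampler below is only meant to make that intent explicit; it is not used by the example.

    import scala.util.Random

    // Draw a label directly from P(y = 1 | x) = logistic(offset + scale * x);
    // statistically equivalent to the rLogis-plus-threshold construction above.
    def sampleLabel(offset: Double, scale: Double, x: Double, rnd: Random): Int = {
      val p = 1.0 / (1.0 + math.exp(-(offset + scale * x)))
      if (rnd.nextDouble() < p) 1 else 0
    }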
/src/main/scala/org/apache/spark/examples/TestBroadCast.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.examples
2 |
3 | import org.apache.spark.internal.Logging
4 | import org.apache.spark.sql.SparkSession
5 |
6 | import scala.collection.mutable
7 |
8 | /**
9 | * Created by wjf on 16-9-24.
10 | */
11 | object TestBroadCast extends Logging {
12 |   val sparkSession = SparkSession.builder().appName("test BroadCast").getOrCreate()
13 | val sc = sparkSession.sparkContext
14 | def main(args: Array[String]): Unit = {
15 |
16 | // val data = sc.parallelize(Seq(1 until 10000000))
17 | val num = args(args.length - 2).toInt
18 |     val times = args(args.length - 1).toInt
19 | println(num)
20 | val start = System.nanoTime()
21 |     val seq = (1 until num).toVector // materialize so the broadcast actually ships num elements (a bare Range serializes as just its bounds)
22 |     for (i <- 0 until times) {
23 |       val start2 = System.nanoTime()
24 |       val bc = sc.broadcast(seq)
25 |       val rdd = sc.parallelize(1 until 10, 5)
26 |       rdd.map(_ => bc.value.take(1)).collect()
27 |       println((System.nanoTime() - start2) / 1e6 + "ms")
28 | }
29 | logInfo((System.nanoTime() - start) / 1e6 + "ms")
30 | }
31 |
32 |   def testMap(): Unit = {
33 | 
34 |     val smallRDD = sc.parallelize(Seq(1, 2, 3))
35 |     val bigRDD = sc.parallelize(1 until 20)
36 |     // an RDD cannot be referenced inside another RDD's transformation,
37 |     // so collect the small RDD to the driver and ship it with the task closure
38 |     val smallData = smallRDD.collect()
39 |     bigRDD.mapPartitions { partition =>
40 |       val hashMap = new mutable.HashMap[Int, Int]()
41 |       for (ele <- smallData) {
42 |         hashMap(ele) = ele
43 |       }
44 |       // some operation here
45 |       partition
46 |     }
47 | }
48 | }
49 |
--------------------------------------------------------------------------------
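testMap above collects the small RDD to the driver and lets Spark ship the data inside every task closure. The broadcast-based variant of the same lookup, which is what this example file is really exercising, ships the data once per executor instead and is usually the better pattern for a small dimension table. A sketch reusing sc, smallRDD and bigRDD from the method above (the filter is just a placeholder operation):

    // Build the lookup table on the driver and broadcast it once.
    val smallMap = smallRDD.collect().map(e => e -> e).toMap
    val bcMap = sc.broadcast(smallMap)

    val joined = bigRDD.mapPartitions { partition =>
      val lookup = bcMap.value          // one deserialized copy per executor, not per task
      partition.filter(lookup.contains) // keep only elements present in the small table
    }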
/src/main/scala/org/apache/spark/examples/TestClient.scala:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CASISCAS/asyspark/cff60c9d4ae6eb01691a8a8b86f3deca8b67a025/src/main/scala/org/apache/spark/examples/TestClient.scala
--------------------------------------------------------------------------------
/src/main/scala/org/apache/spark/examples/TestRemote.scala:
--------------------------------------------------------------------------------
1 | package org.apache.spark.examples
2 |
3 | import java.io.File
4 |
5 | import akka.actor.{ActorSystem, Props}
6 | import com.typesafe.config.ConfigFactory
7 | import com.typesafe.scalalogging.slf4j.StrictLogging
8 | import org.apache.spark.asyspark.core.Main.{logger => _, _}
9 | import org.apache.spark.asyspark.core.messages.master.ServerList
10 |
11 | /**
12 | * Created by wjf on 16-9-27.
13 | */
14 | object TestRemote extends StrictLogging {
15 | def main(args: Array[String]): Unit = {
16 |
17 | val default = ConfigFactory.parseResourcesAnySyntax("asyspark")
18 | val config = ConfigFactory.parseFile(new File(getClass.getClassLoader.getResource("asyspark.conf").getFile)).withFallback(default).resolve()
19 | val system = ActorSystem(config.getString("asyspark.server.system"), config.getConfig("asyspark.server"))
20 |
21 | val serverHost = config.getString("asyspark.master.host")
22 | val serverPort = config.getInt("asyspark.master.port")
23 | val serverName = config.getString("asyspark.master.name")
24 | val serverSystem = config.getString("asyspark.master.system")
25 |     logger.debug("Resolving master actor")
26 |
27 | val master = system.actorSelection(s"akka.tcp://${serverSystem}@${serverHost}:${serverPort}/user/${serverName}")
28 |
29 |
30 | println(master.anchorPath.address)
31 |
32 | }
33 |
34 | }
35 |
--------------------------------------------------------------------------------
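actorSelection above only constructs an actor path; it does not check that a Master is actually listening at asyspark.master.host:port. A hedged sketch of how the selection could be resolved into a live ActorRef with Akka's resolveOne, assuming the same config and a running Master:

    import scala.concurrent.Await
    import scala.concurrent.duration._
    import akka.util.Timeout

    // Resolve the selection into a concrete ActorRef, failing fast if no Master answers.
    implicit val resolveTimeout: Timeout = Timeout(5.seconds)
    val masterRef = Await.result(master.resolveOne(), 10.seconds)
    logger.info(s"Resolved master at ${masterRef.path.address}")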
/src/test/scala/Test.scala:
--------------------------------------------------------------------------------
1 | /**
2 | * Created by wjf on 16-9-19.
3 | */
4 | class Test {
5 |
6 | }
--------------------------------------------------------------------------------