├── .gitignore
├── NOTICE
├── project
├── build.properties
├── plugins.sbt
└── build.scala
├── .gitmodules
├── data
└── figures
│ ├── speedo_psgd_cpu.png
│ ├── speedo_psgd_gpu.png
│ ├── speedo_easgd_gpu.png
│ └── speedo_architecture.png
├── src
├── main
│ ├── resources
│ │ └── com
│ │ │ └── htc
│ │ │ └── speedo
│ │ │ └── yarn
│ │ │ └── AppConf.xml
│ └── scala
│ │ └── com
│ │ └── htc
│ │ └── speedo
│ │ ├── caffe
│ │ ├── NetParameterUtils.scala
│ │ ├── StorehausUtils.scala
│ │ ├── HDFSStore.scala
│ │ ├── ProtobufUtils.scala
│ │ └── CaffeWorker.scala
│ │ ├── yarn
│ │ ├── YarnApp.scala
│ │ ├── ReflectionUtils.scala
│ │ ├── package.scala
│ │ ├── AppClient.scala
│ │ └── AppContainers.scala
│ │ ├── akka
│ │ ├── HostSlaveActor.scala
│ │ ├── WorkerActor.scala
│ │ ├── SynchronousMasterActor.scala
│ │ ├── WeedOutMasterActor.scala
│ │ ├── ParameterActor.scala
│ │ ├── package.scala
│ │ ├── PSMasterActor.scala
│ │ ├── AkkaUtil.scala
│ │ ├── DBActor.scala
│ │ ├── HybridMasterActor.scala
│ │ ├── MasterActor.scala
│ │ └── HostMasterActor.scala
│ │ └── SpeeDOApp.scala
└── test
│ └── scala
│ ├── org
│ └── specs2
│ │ └── AkkaSpecification.scala
│ └── com
│ └── htc
│ └── speedo
│ ├── caffe
│ └── HDFSStoreSpec.scala
│ └── akka
│ ├── ActorSpec.scala
│ ├── MasterActorSpec.scala
│ ├── HybridMasterActorSpec.scala
│ ├── SynchronousMasterActorSpec.scala
│ ├── HostMasterActorSpec.scala
│ ├── PSMasterActorSpec.scala
│ └── WeedOutMasterActorSpec.scala
├── docker
├── Dockerfile
└── entrypoint.sh
├── README_YARN.md
├── LICENSE
└── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright 2016 HTC Corporation
2 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=0.13.9
2 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "caffe"]
2 | path = caffe
3 | url = https://github.com/obdg/caffe.git
4 | branch = speedo
5 |
--------------------------------------------------------------------------------
/data/figures/speedo_psgd_cpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openbigdatagroup/speedo/HEAD/data/figures/speedo_psgd_cpu.png
--------------------------------------------------------------------------------
/data/figures/speedo_psgd_gpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openbigdatagroup/speedo/HEAD/data/figures/speedo_psgd_gpu.png
--------------------------------------------------------------------------------
/data/figures/speedo_easgd_gpu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openbigdatagroup/speedo/HEAD/data/figures/speedo_easgd_gpu.png
--------------------------------------------------------------------------------
/data/figures/speedo_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openbigdatagroup/speedo/HEAD/data/figures/speedo_architecture.png
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
scalacOptions ++= Seq("-unchecked","-deprecation", "-feature")

// Resolve from the read-only Sonatype releases repository; the staging deploy endpoint
// (service/local/staging/deploy/maven2) is meant for publishing artifacts, not resolving them.
resolvers += Resolver.sonatypeRepo("releases")

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.1")

addSbtPlugin("org.scalariform" % "sbt-scalariform" % "1.6.0")

addSbtPlugin("de.heikoseeberger" % "sbt-header" % "1.5.1")
10 |
--------------------------------------------------------------------------------
/src/main/resources/com/htc/speedo/yarn/AppConf.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | 1
6 | 1
7 | 0
8 |
9 |
10 |
11 | 2
12 | 1
13 | 1
14 | yarnApp
15 | default
16 |
17 |
18 |
19 | -Xmx1024M
20 |
21 |
22 |
--------------------------------------------------------------------------------
/src/test/scala/org/specs2/AkkaSpecification.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package org.specs2
18 |
19 | import akka.testkit.TestKitBase
20 |
21 | import org.specs2.specification.core.mutable.SpecificationStructure
22 | import org.specs2.specification.create.SpecificationCreation
23 | import org.specs2.specification.dsl.mutable.MutableDsl
24 | import org.specs2.specification.mutable.SpecificationFeatures
25 |
/**
 * Base class that combines a mutable specs2 specification with the akka `TestKitBase`, so test
 * examples can talk to actors directly.
 *
 * It mixes in the internal specs2 trait `SpecificationStructure`, which is why it has to live in
 * the `org.specs2` package.
 */
abstract class AkkaSpecification
  extends TestKitBase
  with SpecificationStructure
  with SpecificationFeatures
  with SpecificationCreation
  with MutableDsl
32 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | # Copyright 2016 HTC Corporation
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
FROM ubuntu:14.04
LABEL maintainer="Wenrui Jiang"

# Install Redis, plus git and wget
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y redis-server git wget && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /root
RUN git clone --recursive https://github.com/obdg/speedo.git

WORKDIR /root/speedo/caffe

# `rm -rf` so the build does not fail if /tmp happens to be empty (an unmatched glob makes
# plain `rm -r` exit with an error).
RUN ./install_dependency && \
    rm -rf /tmp/* && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

ENV JAVA_LIBRARY_PATH=/usr/lib LD_LIBRARY_PATH=/usr/lib JAVA_HOME=/usr/lib/jvm/java-8-oracle

# The build already runs as root (no USER directive), so sudo is unnecessary; the stock ubuntu
# image does not ship it either.
RUN make all javainstall && make install

RUN ./data/cifar10/get_cifar10.sh && \
    ./examples/speedo/create_cifar10.sh

WORKDIR /root/speedo

# Build the stand-alone akka jar, keep only the jar and drop the build caches to slim the image.
RUN ./sbt akka:assembly && \
    cp target/scala-2.11/SpeeDO-akka-1.0.jar . && \
    rm -rf ~/.sbt target

ENTRYPOINT ["docker/entrypoint.sh"]
CMD ["master", "localhost", "3", "--test", "0", "--maxIter", "1000"]
49 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/caffe/NetParameterUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.caffe
18 |
19 | import scala.util.Random
20 |
21 | import caffe.Caffe.NetParameter
22 |
23 | import com.twitter.algebird.Semigroup
24 |
/**
 * Semigroups that merge caffe network parameters, either as protobuf messages or in their
 * serialized binary form. The actual merging is delegated to `NetParameterOperation.plus`.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
object NetParameterUtils {
  /**
   * Semigroup for caffe network in protobuf.
   * @param skip When non-zero it is forwarded to the three-argument
   *             `NetParameterOperation.plus`; when zero the two-argument overload is used.
   */
  def semigroupProto(skip: Float = 0f): Semigroup[NetParameter] =
    if (skip != 0) {
      new Semigroup[NetParameter] {
        override def plus(p1: NetParameter, p2: NetParameter) = NetParameterOperation.plus(p1, p2, skip)
      }
    } else {
      new Semigroup[NetParameter] {
        override def plus(p1: NetParameter, p2: NetParameter) = NetParameterOperation.plus(p1, p2)
      }
    }

  /**
   * Semigroup for caffe network in binary (serialized protobuf bytes).
   * @param skip When non-zero it is forwarded to the three-argument
   *             `NetParameterOperation.plus`; when zero the two-argument overload is used.
   */
  def semigroup(skip: Float = 0f): Semigroup[Array[Byte]] =
    if (skip != 0) {
      new Semigroup[Array[Byte]] {
        override def plus(p1: Array[Byte], p2: Array[Byte]) = NetParameterOperation.plus(p1, p2, skip)
      }
    } else {
      new Semigroup[Array[Byte]] {
        override def plus(p1: Array[Byte], p2: Array[Byte]) = NetParameterOperation.plus(p1, p2)
      }
    }
}
52 |
--------------------------------------------------------------------------------
/src/test/scala/com/htc/speedo/caffe/HDFSStoreSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.caffe
18 |
19 | import com.twitter.util.Await
20 |
21 | import org.specs2.mutable.SpecificationWithJUnit
22 |
/**
 * Unit test for [[HDFSStore]].
 *
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
class HDFSStoreSpec extends SpecificationWithJUnit {
  // The examples share one store and build on each other's writes, so they must run in order.
  sequential

  val key = "key"
  val value1 = 1
  val value2 = 2

  "HDFSStore" should {
    // Relative path under target/; with a default Configuration this presumably resolves to the
    // local file system (unless a core-site.xml on the test classpath says otherwise).
    val store = HDFSStore[Int]("target/hdfsstore-test")

    "get None for non-existance keys" in {
      Await.result(store.get(key)) must beNone
    }
    "put and get correctly" in {
      Await.result(store.put(key, Some(value1)))
      // beSome reports a readable failure instead of throwing from Option.get
      Await.result(store.get(key)) must beSome(value1)
    }
    "override existing keys with put correctly" in {
      Await.result(store.put(key, Some(value2)))
      Await.result(store.get(key)) must beSome(value2)
    }
    "delete existing keys correctly" in {
      Await.result(store.put(key, None))
      Await.result(store.get(key)) must beNone
    }
    "delete non-existing keys as no-op" in {
      Await.result(store.put(key, None))
      Await.result(store.get(key)) must beNone
    }
    step { store.close() }
  }
}
60 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/caffe/StorehausUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.caffe
18 |
19 | import org.jboss.netty.buffer.{ ChannelBuffer, ChannelBuffers }
20 |
21 | import com.twitter.bijection.Injection
22 | import com.twitter.bijection.netty.ChannelBufferBijection
23 | import com.twitter.finagle.redis.Client
24 | import com.twitter.scalding.Args
25 | import com.twitter.storehaus.{ JMapStore, Store }
26 | import com.twitter.storehaus.algebra.MergeableStore
27 | import com.twitter.storehaus.redis.RedisStore
28 |
/**
 * Utility of creating storehaus.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
object StorehausUtils {
  /**
   * Create mergeable store from command line arguments.
   *
   * The backing store is chosen in this order:
   *  - `--redis <host>`: a redis store on `host`, at port `--redisPort` (optional, default 6379);
   *  - `--hdfs <path>`: an [[HDFSStore]] rooted at `path`;
   *  - otherwise an in-memory store.
   *
   * Values are merged with [[NetParameterUtils.semigroup]], parameterized by `--skip`
   * (optional, default 0).
   */
  def createStore(args: Args): MergeableStore[String, Array[Byte]] = {
    val store = args.optional("redis").map { host => // create redis store
      val port = args.int("redisPort", 6379)
      val client = Client(s"$host:$port")
      implicit val inj = Injection.fromBijection(ChannelBufferBijection.inverse)
      // Encode keys explicitly as UTF-8: the platform default charset may differ between the
      // hosts sharing this store, which would map the same key to different redis keys.
      Store.convert(RedisStore(client)) { str: String =>
        ChannelBuffers.copiedBuffer(str.getBytes(java.nio.charset.StandardCharsets.UTF_8))
      }
    }.orElse(args.optional("hdfs").map { path => // create hdfs store
      HDFSStore[Array[Byte]](path)
    }).getOrElse(new JMapStore[String, Array[Byte]]) // create memory store by default
    implicit val semigroup = NetParameterUtils.semigroup(args.getOrElse("skip", "0").toFloat)
    MergeableStore.fromStore(store)
  }
}
47 |
--------------------------------------------------------------------------------
/docker/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash -e
2 |
3 | # Copyright 2016 HTC Corporation
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
MASTER_IP=$2
SOLVER=${SOLVER:-caffe/examples/speedo/solver.prototxt}
# Port of the master actor system; workers address the master through this port.
MASTER_PORT=56126

# Prints the usage for this script and exits.
function print_usage() {
  cat <<EOF
USAGE:
  For master:
    docker run -d --name=<name> --net=host obdg/speedo master <master_ip> [num_workers [strategy]]
  For worker:
    docker run -d --name=<name> --net=host obdg/speedo worker <master_ip> <worker_ip>

  num_workers (optional): number of workers, default to 3

  strategy (optional):
    "--sync" for sync strategy
    "--maxAdvance n" for psc strategy
    "--drop n" for weed-out strategy
    "--movingRate v" for easgd strategy

EXAMPLES:

  run master actor (in default Async model with 3 workers):
    docker run -d --name=speedo-master --net=host obdg/speedo
  run master actor in Easgd model with 3 workers:
    docker run -d --name=speedo-master --net=host obdg/speedo master localhost 3 --test 500 --maxIter 1000 --movingRate 0.5

  run worker actors:
    docker run -d --name=speedo-worker --net=host obdg/speedo worker master_ip worker_ip
EOF
  exit 1
}

if [[ $1 == "master" ]]; then
  if [[ -z ${MASTER_IP} ]]; then
    print_usage
  fi
  # Workers address the master as ${MASTER_IP}, and Akka remoting only accepts messages sent to
  # the exact host name it is bound to, so bind to it unless HOST_IP is set explicitly.
  HOST_IP=${HOST_IP:-${MASTER_IP}}
  NUM_WORKERS=${3:-3}
  redis-server --daemonize yes
  # Drop "master", <master_ip> and [num_workers]; everything left is passed on to AkkaUtil.
  # Never shift more than $# so missing optional arguments do not abort the script (bash -e).
  shift $(( $# < 3 ? $# : 3 ))
  java -cp SpeeDO-akka-1.0.jar -Xmx2G com.htc.speedo.akka.AkkaUtil --solver "${SOLVER}" \
    --worker "${NUM_WORKERS}" --redis "${MASTER_IP}" --host "${HOST_IP}" --port "${MASTER_PORT}" \
    "$@" 2> /dev/null
elif [[ $1 == "worker" ]]; then
  WORKER_IP=$3
  if [[ -z ${MASTER_IP} || -z ${WORKER_IP} ]]; then
    print_usage
  fi
  java -cp SpeeDO-akka-1.0.jar -Xmx2G com.htc.speedo.akka.AkkaUtil --host "${WORKER_IP}" \
    --master "akka.tcp://SpeeDO@${MASTER_IP}:${MASTER_PORT}/user/host" 2> /dev/null
else
  print_usage
fi
62 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/caffe/HDFSStore.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.caffe
18 |
19 | import com.twitter.bijection.Codec
20 | import com.twitter.storehaus.Store
21 | import com.twitter.util.{ Future, Time }
22 |
23 | import org.apache.commons.io.IOUtils
24 | import org.apache.hadoop.conf.Configuration
25 | import org.apache.hadoop.fs.Path
26 |
object HDFSStore {
  /**
   * Create a [[HDFSStore]] rooted at the directory named by a string path, using a freshly
   * created hadoop `Configuration`.
   */
  def apply[V: Codec](rootDir: String): HDFSStore[V] = {
    val root = new Path(rootDir)
    new HDFSStore[V](root, new Configuration)
  }
}
31 |
/**
 * A HDFS store for caffe workers. Each key value pair is stored as a file, with the key as filepath
 * and value as the file contents. Therefore, it's recommended to use this store for large values,
 * e.g. for caffe network with more than 100M snapshots.
 * TODO: solve the problem of concurrent read and write
 * @param rootDir Directory under which every key is stored as a file; created if missing.
 * @param conf Hadoop configuration used to resolve the file system of `rootDir`.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
case class HDFSStore[V](rootDir: Path, conf: Configuration = new Configuration)(implicit codec: Codec[V])
  extends Store[String, V] {
  /**
   * The file system for the root path. A dedicated (uncached) instance is used, because
   * `Path.getFileSystem` returns a JVM-wide cached instance and [[close]] would otherwise close
   * the file system for every other user in the same JVM.
   */
  val fs = org.apache.hadoop.fs.FileSystem.newInstance(rootDir.toUri, conf)

  // make sure the root directory exists
  fs.mkdirs(rootDir)

  /** Reads and decodes the file for `key`; values that fail to decode are reported as `None`. */
  override def get(key: String): Future[Option[V]] = Future {
    val path = new Path(rootDir, key)
    if (fs.exists(path)) {
      val stream = fs.open(path)
      // close the stream even if reading fails
      val bytes = try IOUtils.toByteArray(stream) finally stream.close()
      codec.invert(bytes).toOption
    } else None
  }

  /** Writes (overwriting) the file for a key, or deletes it when the value is `None`. */
  override def put(kv: (String, Option[V])): Future[Unit] = Future {
    val (key, value) = kv
    val path = new Path(rootDir, key)
    value match {
      case None => fs.delete(path, false)
      case Some(v) =>
        val bytes = codec(v)
        val stream = fs.create(path, true)
        // close the stream even if writing fails
        try stream.write(bytes) finally stream.close()
    }
  }

  override def close(time: Time): Future[Unit] = Future { fs.close() }
}
70 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/yarn/YarnApp.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.yarn
18 |
/**
 * Describes an application to be launched on Yarn, e.g. an Akka system.
 */
trait YarnApp {
  /**
   * The number of containers to launch.
   * @param args Command line arguments
   */
  def getSize(args: Array[String]): Int

  /**
   * Builds the application, i.e. what both the master and the slave roles will do.
   * @param args Command line arguments
   * @param master Host name of the master container
   */
  def getApp(args: Array[String], master: String): MasterRole
}
32 |
/** The master role of a Yarn application, which also defines how its slave roles are launched. */
trait MasterRole {
  /**
   * Whether to wait for all containers to stop after the master is finished (`true`, the
   * default), or to kill them immediately (`false`).
   */
  val waitAllContainers: Boolean = true

  /** Environment of this application, for the master and the slaves. Empty by default. */
  val appEnv: Map[String, String] = Map.empty

  /** Main class for launching the slave roles. */
  val slaveMain: String

  /**
   * Command line arguments for launching a slave role.
   * @param host Host name of the slave container
   */
  def slaveArgs(host: String): List[String]

  /**
   * Action of the master role after all the slaves are created. Anything that has to happen
   * before the slaves start (or before slaveArgs can be determined) belongs in the constructor.
   * @return A function that always returns a [[JobState]]: the running progress while the job
   *         is not completed, and the final state once it is.
   * @note called once per second
   */
  def action: () => JobState
}
55 |
/** Job running progress or last state. */
sealed trait JobState

/**
 * Job running progress. Final so the sealed [[JobState]] hierarchy cannot be extended elsewhere.
 * @param progress Fraction of the job that is completed, should be within [0, 1]
 */
final case class InProgress(progress: Float) extends JobState

/**
 * Job final completion state. Final so the sealed [[JobState]] hierarchy cannot be extended elsewhere.
 * @param success Whether the job completed successfully
 */
final case class Finished(success: Boolean) extends JobState
64 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/HostSlaveActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import akka.actor.{ Actor, ActorIdentity, ActorLogging, Identify }
20 |
21 | import com.twitter.scalding.Args
22 |
23 | /**
24 | * A host slave actor is used to connect master roles on yarn. It tells host master actor its
25 | * address, so the host master actor can create worker remotely on this slave host. The host slave
26 | * actor then waits until a stop message is received, which indicates the training is finished.
27 | *
28 | * Required parameters:
29 | * - `--master `: The host name or ip of the master actor system to connect to.
30 | * The host name or ip must be the same with that passed
31 | * - `--worker `: (Optional) How many worker actors will be started in the
32 | * system, default to 1. Each worker can utilize a different GPU.
33 | * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
34 | */
35 | case class HostSlaveActor(args: Args) extends Actor with ActorLogging {
36 | /**
37 | * The path of host master actor. Must be a full remote path.
38 | * For example, `akka.tcp://DeepLearning@cloud-master:40357/user/host`.
39 | */
40 | val masterPath = args.required("master")
41 |
42 | /** Number of workers to start in this system. */
43 | val workerCount = args.int("worker", 1)
44 |
45 | // Tries to identify if the master exists
46 | context.actorSelection(masterPath) ! Identify(masterPath)
47 |
48 | override def receive = {
49 | // If found master actor, join
50 | case ActorIdentity(`masterPath`, Some(master)) =>
51 | // Each join message will create a worker actor
52 | (1 to workerCount).foreach(_ => master ! Join)
53 | // If not found master actor, log and exit
54 | case ActorIdentity(`masterPath`, None) =>
55 | log.error(s"Cannot found master at $masterPath, stopping!")
56 | context.system.shutdown
57 | // stop slave akka system
58 | case StopAkka => context.system.shutdown
59 | }
60 | }
61 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/caffe/ProtobufUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.caffe
18 |
19 | import java.io.{ InputStream, InputStreamReader }
20 | import java.nio.charset.StandardCharsets.UTF_8
21 | import java.nio.file.{ Files, Paths }
22 |
23 | import com.google.protobuf.{ Message, MessageOrBuilder, TextFormat }
24 |
/**
 * Utility to serialize/deserialize protobuf messages in text and binary format. Compatible with caffe.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
object ProtobufUtils {
  /**
   * Loads a protobuf message in binary format from a local file.
   * @tparam M Generated message class; its static `parseFrom(Array[Byte])` is invoked reflectively.
   */
  def load[M <: Message](path: String)(implicit mf: Manifest[M]): M =
    mf.runtimeClass.getDeclaredMethod("parseFrom", classOf[Array[Byte]])
      .invoke(null, Files.readAllBytes(Paths.get(path))).asInstanceOf[M]

  /** Saves a protobuf message in binary format to a local file. */
  def save(path: String, model: Message): Unit = Files.write(Paths.get(path), model.toByteArray)

  /** Loads a protobuf message in text format from a local file. */
  def loadText[M <: Message: Manifest](path: String): M = loadText(Files.newInputStream(Paths.get(path)))

  /**
   * Loads a protobuf message in text format from an input stream. This is
   * useful for loading solver and model definitions from hdfs.
   * @tparam M Generated message class; its static `newBuilder()` is invoked reflectively.
   * @note The stream will be closed in this function, even if parsing fails.
   */
  def loadText[M <: Message](input: InputStream)(implicit mf: Manifest[M]): M = {
    val reader = new InputStreamReader(input, UTF_8)
    try {
      val builder = mf.runtimeClass.getDeclaredMethod("newBuilder").invoke(null).asInstanceOf[Message.Builder]
      TextFormat.merge(reader, builder)
      builder.build.asInstanceOf[M]
    } finally {
      // closing the reader also closes the underlying input stream
      reader.close()
    }
  }

  /** Saves a protobuf message in text format to a local file. */
  def saveText(path: String, model: MessageOrBuilder): Unit =
    Files.write(Paths.get(path), TextFormat.printToUnicodeString(model).getBytes(UTF_8))
}
59 |
--------------------------------------------------------------------------------
/README_YARN.md:
--------------------------------------------------------------------------------
1 | # Run SpeeDO on Yarn and HDFS with Cloudera
2 |
3 | ## Step.1 Install Cloudera on cluster
4 | Install Cloudera Manager
5 | ```bash
6 | wget https://archive.cloudera.com/cm5/installer/latest/cloudera-manager-installer.bin
7 | chmod u+x cloudera-manager-installer.bin && sudo ./cloudera-manager-installer.bin
8 | ```
9 |
10 | If you run into any issues during installation, please check the [Cloudera Manager Installation Guide](http://www.cloudera.com/documentation/manager/5-1-x/Cloudera-Manager-Installation-Guide/Cloudera-Manager-Installation-Guide.html).
11 |
12 | ## Step.2 Install caffe and its dependencies
13 | Install [speedo/caffe](https://github.com/obdg/caffe) and all its dependencies, see section **B. Automatic deployment by cloudera parcels** from [speedo/caffe install guide](https://github.com/obdg/caffe).
14 |
15 | ## Step.3 Upload training datasets and network definitions to HDFS
16 |
17 | Follow the same steps in section **Step.2 Prepare Input Data for each nodes** of [Deploy and run SpeeDO manually without cloudera](https://github.com/obdg/speedo). The only difference is that we do not need to prepare datasets on all hosts; we only upload them onto HDFS:
18 | ```bash
19 | sudo su hdfs
20 | # <user> denotes the user who will run SpeeDO
21 | hdfs dfs -mkdir -p /user/<user>/cifar10/
22 | hdfs dfs -chown -R <user>:supergroup /user/<user>/cifar10/
23 | # switch back to <user>
24 | exit
25 | hdfs dfs -put <cifar10 files> /user/<user>/cifar10/
26 | ```
27 |
28 | Finally, your HDFS should look like the following:
29 | ```bash
30 | hdfs dfs -ls
31 | -rw-r--r-- 3 <user> supergroup *** cifar10/cifar10_full_solver.prototxt
32 | -rw-r--r-- 3 <user> supergroup *** cifar10/cifar10_full_train_test.prototxt
33 | -rw-r--r-- 3 <user> supergroup *** cifar10/cifar10_test_datumfile
34 | -rw-r--r-- 3 <user> supergroup *** cifar10/cifar10_train_datumfile
35 | -rw-r--r-- 3 <user> supergroup *** cifar10/mean.binaryproto
36 | ```
37 |
38 | ## Step.4 Run directly with Yarn
39 |
40 | In this mode, Yarn is responsible for allocating containers for master and workers. You can manage applications and view logs in the usual Yarn ways.
41 |
42 | For example, to run 1000 iterations asynchronously using 3 workers:
43 | ```bash
44 | ./sbt assembly
45 | hadoop jar target/scala-2.11/SpeeDO-yarn-1.0.jar --appClass com.htc.speedo.SpeeDOApp --solver cifar10/cifar10_full_solver.prototxt --worker 3 --redis <redis_host> --test 500 --maxIter 1000
46 | ```
47 |
48 | **NOTE** The training data is distributed to each node from HDFS, so IO may become a bottleneck when the training data gets massive. In that case it's better to put the datasets at the same location on all the machines in advance (there is no need to upload them onto HDFS anymore). You also need to change the paths in the Caffe network definition to the absolute path of the data on each machine.
49 |
--------------------------------------------------------------------------------
/project/build.scala:
--------------------------------------------------------------------------------
1 | import sbt._
2 | import Keys._
3 | import sbtassembly.AssemblyPlugin.autoImport._
4 | import de.heikoseeberger.sbtheader._
5 |
object SpeeDOBuild extends Build {
  /** Configuration used to assemble a jar that runs SpeeDO directly on akka, without Yarn. */
  val AkkaConfig = config("akka") extend(Compile)

  /**
   * Settings of the default configurations; `sbt assembly` builds the Yarn client jar
   * `SpeeDO-yarn-<version>.jar`.
   */
  lazy val yarnSettings: Seq[Def.Setting[_]] = Seq(
    version := "1.0",
    scalaVersion := "2.11.7",
    crossScalaVersions := Seq("2.11.7", "2.10.6"),
    scalacOptions ++= Seq("-target:jvm-1.7", "-deprecation", "-unchecked", "-feature"),
    updateOptions := updateOptions.value.withCachedResolution(true).withLatestSnapshots(false),
    libraryDependencies ++= Seq(
      "com.htc.speedo" % "caffe-jni" % "0.1" % "compile,akka",
      "com.typesafe.akka" %% "akka-remote" % "2.3.14" % "compile,akka",
      "org.apache.hadoop" % "hadoop-client" % "2.6.0" % "akka,provided",
      "com.twitter" %% "scalding-args" % "0.15.0" % "compile,akka",
      "com.twitter" %% "storehaus-redis" % "0.13.0" % "compile,akka",

      "org.specs2" %% "specs2-junit" % "3.3.1" % "test",
      "com.typesafe.akka" %% "akka-testkit" % "2.3.14" % "test"
    ),
    // scala-xml is a separate module from scala 2.11 on, so only depend on it there
    libraryDependencies ++= {
      CrossVersion.partialVersion(scalaVersion.value) match {
        case Some((2, scalaMajor)) if scalaMajor >= 11 =>
          Seq("org.scala-lang.modules" %% "scala-xml" % "1.0.3" % "compile,akka")
        case _ => Nil
      }
    },
    resolvers ++= Seq(
      "scalaz-bintray" at "http://dl.bintray.com/scalaz/releases",
      Resolver.mavenLocal
    ),
    testFrameworks := Seq(sbt.TestFrameworks.Specs2),
    test in assembly := {},
    mainClass in (Compile, run) := Some("com.htc.speedo.akka.AkkaUtil"),
    mainClass in assembly := Some("com.htc.speedo.yarn.AppClient"),
    assemblyJarName in assembly := "SpeeDO-yarn-" + version.value + ".jar",
    HeaderPlugin.autoImport.headers := Map(
      "scala" -> license.Apache2_0("2016", "HTC Corporation")
    )
  )

  /**
   * Settings scoped to [[AkkaConfig]]: reuse the main compilation, skip tests, and let
   * `sbt akka:assembly` build `SpeeDO-akka-<version>.jar` with `com.htc.speedo.akka.AkkaUtil`
   * as entry point.
   */
  lazy val akkaSettings: Seq[Def.Setting[_]] = inConfig(AkkaConfig)(
    Classpaths.configSettings ++ Defaults.configTasks ++ baseAssemblySettings ++ Seq(
      compile := (compile in Compile).value,
      test := {},
      mainClass in assembly := Some("com.htc.speedo.akka.AkkaUtil"),
      assemblyJarName in assembly := "SpeeDO-akka-" + version.value + ".jar",
      assemblyMergeStrategy in assembly := {
        case PathList("org", "apache", xs @ _*) => MergeStrategy.first
        case x => (assemblyMergeStrategy in assembly).value(x)
      }
    ))

  lazy val speedo = (project in file("."))
    .enablePlugins(AutomateHeaderPlugin)
    .configs(AkkaConfig)
    .settings(yarnSettings: _*)
    .settings(akkaSettings: _*)
}
61 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/WorkerActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.util.Try
20 |
21 | import com.twitter.scalding.Args
22 |
23 | import com.htc.speedo.caffe.CaffeWorker
24 |
/**
 * A worker actor that invokes [[CaffeWorker]]. The arguments of this actor, minus `--snapshot`, are
 * passed to [[CaffeWorker]].
 *
 * Parameters:
 *  - `--test`: (optional) Tests are run only if this is positive; 0 (or a negative value) skips
 *    them and [[Test]] messages are just logged. Default is 1, i.e. tests are run.
 *  - `--suffix`: (optional flag) If present, the [[CaffeWorker]] is created eagerly when the actor
 *    is constructed instead of lazily on first use (presumably set for training workers).
 * @see [[ParameterActor]] for arguments required by the actor itself.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
case class WorkerActor(args: Args) extends ParameterActor(args) {
  /** If we run test or not (`--test` must be positive, default 1). */
  val runTest = args.int("test", 1) > 0

  // The function to close the lazy worker.
  // If the worker is never created, there is nothing to close.
  // NOTE: this must stay declared before the eager `worker` access below; otherwise its
  // initializer would reset it to None after the worker has already been created.
  private var closeWorker: Option[() => Unit] = None

  /**
   * The worker of this actor.
   * Marked as lazy so no worker is created if tests are skipped.
   */
  lazy val worker = {
    val worker = CaffeWorker(new Args(args.m - "snapshot"))
    closeWorker = Some(() => worker.close)
    worker
  }

  // Don't be lazy if the worker is used for training
  if (args.boolean("suffix")) worker

  override def receive = {
    case UpdateParameter(arg) =>
      // update training parameters in base class
      updateParameter(arg)
      // update solver parameters in caffe worker
      worker.updateParameter(arg)
    case Train(iteration) =>
      log.info("{} is training the {}th iteration.", self.path.name, iteration)
      worker.setIteration(iteration)
      // use training parameters from base class
      sender ! Trained(worker.train(!synchronous, weightUpdate))
    case Test if runTest =>
      // Handle errors here, since the host master actor has no fault tolerance for the test actor.
      // A failure is logged (instead of being dropped silently) and reported as an empty result.
      val attempt = Try(worker.test)
      attempt.failed.foreach(e => log.error(e, "{} failed to run test.", self.path.name))
      sender ! TestResult(attempt.toOption.flatten)
    case Test => log.info("Test is disabled, skipping tests.")
    // forward messages, used for stop system after all tests are finished
    case Forward(actor, message) => actor.tell(message, sender)
  }

  // close the caffe worker only if it was actually created
  override def postStop = closeWorker.foreach(_.apply)
}
78 |
--------------------------------------------------------------------------------
/src/test/scala/com/htc/speedo/akka/ActorSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import akka.actor.ActorSystem
20 | import akka.testkit.{ TestActorRef, TestKitBase, TestProbe }
21 |
22 | import com.twitter.scalding.Args
23 | import com.typesafe.config.ConfigFactory
24 |
25 | import org.specs2.AkkaSpecification
26 | import org.specs2.specification.core.Fragments
27 |
/**
 * Shared fixture for the actor specs. It provides an actor system with logging turned off, test
 * probes standing in for the db, test and worker actors, a factory for master actors, and message
 * assertion helpers. The actor system is shut down once all examples have run.
 * @author Wenrui Jiang (roy_jiang@htc.com)
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
abstract class ActorSpec extends AkkaSpecification {
  // examples share the same probes, so run them one after another
  sequential

  // Logging is disabled during tests; single expectations time out after 2 seconds
  override implicit lazy val system = {
    val testConfig = ConfigFactory.parseString("""
  akka.loglevel = "OFF"
  akka.test.single-expect-default = 2s
  """)
    ActorSystem(getClass.getSimpleName, testConfig)
  }

  /** How many worker probes are created. */
  val workerNumber = 3
  /** Training iterations passed as `--maxIter`. */
  val maxIter = 10
  /** Test interval passed as `--test`. */
  val testInterval = 2
  /** Command line shared by all master actors: `--maxIter <maxIter> --test <testInterval>`. */
  lazy val commandLine =
    List("--maxIter", maxIter.toString, "--test", testInterval.toString).mkString(" ")
  /** Probe standing in for the db actor; records the messages sent to it. */
  lazy val dbProb = TestProbe()
  /** One probe per worker actor. */
  lazy val workerProb = (1 to workerNumber).map(_ => TestProbe())
  /** Probe standing in for the test actor. */
  lazy val testProb = TestProbe()

  /**
   * Builds a master actor the same way a real run does, wired to the probes above.
   * @param addtionalArgs Extra command line arguments appended after [[commandLine]].
   */
  def createMasterActor[T <: MasterActor](addtionalArgs: String = ""): T = {
    val masterArgs = Args(commandLine + " " + addtionalArgs)
    val workerRefs = workerProb.map(_.ref)
    val props = AkkaUtil.createMasterActorProps(masterArgs, dbProb.ref, testProb.ref, workerRefs)
    TestActorRef[T](props).underlyingActor
  }

  /** specs2-friendly message assertions on any test kit (e.g. a [[TestProbe]]). */
  implicit class MessageAssertion(val testKit: TestKitBase) {
    /** Expects `message` as the next message and checks it with a specs2 matcher. */
    def assertMessage(message: Any) = {
      val received = testKit.expectMsg(message)
      received must_== message
    }

    /** @note Blocks for 2 full seconds to make sure no messages are received */
    def assertNoMessage = testKit.expectNoMsg must throwA[Throwable] not
  }

  // append a final step that shuts the actor system down after every example
  override def map(fs: => Fragments) = fs ^ step(system.shutdown)
}
76 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/SynchronousMasterActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.collection.mutable.{ Buffer, Set => MutableSet }
20 |
21 | import akka.actor.ActorRef
22 |
23 | import com.twitter.scalding.Args
24 |
25 | import MasterActor._
26 |
/**
 * The synchronous master actor strategy. Each round, the master waits until every worker has
 * finished training. It then averages their losses, merges all deltas into the weights and starts
 * the next round for all workers together.
 * Required parameter:
 *  - `--sync`: A flag to determine use synchronous master actor.
 * TODO: Support auto adjustment of batch size
 * @note Supports the `--gradientOnly` flag.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
trait SynchronousMasterActor extends MasterActor {
  /** Workers of the current round whose training result has not arrived yet. */
  val waitingSet = MutableSet(workers: _*)

  /** Workers that finished the current round; their deltas are merged when the round ends. */
  val mergeSet = MutableSet[ActorRef]()

  /**
   * The workers merged in the previous round, as an immutable set. This mirrors the worker list in
   * [[DBActor]]: whenever the merge set differs from it, [[DBActor]] receives an
   * [[UpdateWorkers]] message.
   */
  var lastUpdateWorkers = Set[ActorRef]()

  /** Losses reported during the current round. */
  val lossList = Buffer[Double]()

  override def strategyName = "synchronous"

  override def parseTrainResult(loss: Double) = {
    val finished = sender
    waitingSet -= finished
    // A negative loss is the fake result sent by [[workerTerminated]]; it is not merged.
    if (loss >= 0) {
      mergeSet += finished
      lossList += loss
    }
    if (waitingSet.nonEmpty) ParsedTrainResult(MergeResultWait, StartTrainNone)
    else {
      // The round is complete: wait for all workers again in the next one.
      waitingSet ++= workers
      // NOTE: NaN if no real loss was reported this round (0.0 / 0).
      val average = lossList.sum / lossList.size
      if (mergeSet != lastUpdateWorkers) {
        lastUpdateWorkers = mergeSet.toSet // immutable copy, taken before the clear below
        dbActor ! UpdateWorkers(lastUpdateWorkers.map(_.path.name))
      }
      mergeSet.clear
      lossList.clear
      ParsedTrainResult(MergeResultAll(average), StartTrainAll)
    }
  }

  // a new worker simply joins the next round
  override def workerCreated(worker: ActorRef) = {}

  // Stop waiting for the terminated worker. If it was the last one pending, send a fake training
  // result on its behalf so that the round completes normally.
  override def workerTerminated(worker: ActorRef) = {
    waitingSet -= worker
    if (waitingSet.isEmpty) self.tell(Trained(-1), worker)
  }
}
86 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/WeedOutMasterActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.collection.mutable.Buffer
20 |
21 | import akka.actor.ActorRef
22 |
23 | import com.twitter.scalding.Args
24 |
25 | import MasterActor._
26 |
/**
 * Weed-out strategy that discards deltas from delayed workers. Whether a worker is delayed is
 * decided by an interval. A fixed-size ring buffer records the workers that finished training in
 * the last `interval` results. When a worker finishes training, it is considered not delayed iff it
 * is already in that buffer. Delayed or not, the worker is then recorded in the buffer and starts a
 * new training (after the merge if it is not delayed).
 * Required parameter:
 *  - `--weedout`: The interval to determine if a worker is delayed or not. Values smaller than the
 *    number of workers are raised to the number of workers.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
trait WeedOutMasterActor extends MasterActor {
  /** The weed-out interval, i.e. the capacity of [[lastUpdates]]; starts at least at `workers.size`. */
  var maxInterval = Math.max(args.required("weedout").toInt, workers.size)

  /** Ring buffer of the most recently finished workers; fills up to [[maxInterval]] entries. */
  val lastUpdates = Buffer[ActorRef]()

  /** Next position to write in [[lastUpdates]], which is the oldest entry once it is full. */
  var updateIndex = 0

  override def strategyName = "weed-out"

  override def parseTrainResult(loss: Double) = {
    val finished = sender
    val fillingUp = lastUpdates.size < maxInterval
    // While the buffer is filling up, every delta is merged. Afterwards, only workers found in the
    // buffer (checked before it is overwritten below) are merged.
    val needMerge = fillingUp || lastUpdates.contains(finished)
    if (fillingUp) {
      lastUpdates += finished
    } else {
      lastUpdates(updateIndex) = finished
    }
    // advance the write position, wrapping around at the interval
    updateIndex = if (updateIndex + 1 == maxInterval) 0 else updateIndex + 1
    // the worker always starts training again
    ParsedTrainResult(if (needMerge) MergeResultSender else MergeResultNone)
  }

  override def workerCreated(worker: ActorRef) = {
    // Record the new worker as the oldest entry. The write position moves past it and the
    // interval grows by one.
    lastUpdates.insert(updateIndex, worker)
    updateIndex += 1
    maxInterval += 1
    super.workerCreated(worker) // start training
  }

  override def workerTerminated(worker: ActorRef) = {
    // Drop the oldest entry (if the buffer has one at the write position) and shrink the interval.
    if (updateIndex < lastUpdates.size) lastUpdates.remove(updateIndex)
    maxInterval -= 1
    if (updateIndex == maxInterval) updateIndex = 0
  }
}
85 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/ParameterActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import akka.actor.{ Actor, ActorLogging }
20 |
21 | import com.twitter.scalding.Args
22 |
23 | import com.htc.speedo.caffe.CaffeWorker._
24 |
/**
 * An abstract actor with training parameters. This is the base actor for [[DBActor]] and
 * [[WorkerActor]].
 *
 * Optional parameters:
 *  - `--sync`: The same flag used to determine synchronous master actor. With a synchronous master
 *    actor, all workers read the snapshot from the global key in storehaus; otherwise they read
 *    from a key with a suffix.
 *  - `--gradientOnly`: (flag) If provided, the workers only calculate the gradients, and the db
 *    actor is responsible for calculating velocity and weights based on the gradients. This also
 *    affects the behavior of the db actor. It works with all types of master actor strategies.
 *  - `--movingRate`: If provided, each worker has its own clock when it updates its local weights.
 *    The master performs an update whenever the local workers finish t steps of their gradient
 *    updates. The magnitude of movingRate/learningRate represents the amount of exploration
 *    allowed in the model. A smaller value allows more exploration, since workers may fluctuate
 *    further from the center. Because non-convex problems have local optima, more exploration is
 *    desirable.
 * @param args The command line arguments to create actors. Usually is provided by [[HostMasterActor]].
 * @note All the parameters can be updated with [[UpdateParameter]] message
 * @note If running with synchronous master actor, `--movingRate` does not select the weight
 *  update strategy. Otherwise `--movingRate` takes precedence over `--gradientOnly`.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 * @author Wenrui Jiang (roy_jiang@htc.com)
 */
abstract class ParameterActor(args: Args) extends Actor with ActorLogging {
  // Placeholder values only: the constructor call to [[updateParameter]] below overwrites them.
  // These declarations must therefore stay above that call.

  /** If running with synchronous master actor or not (`--sync` flag). */
  var synchronous: Boolean = false
  /** Moving rate for EASGD (i.e. [[SelfPaceFullUpdate]]) */
  var movingRate: Option[Float] = None
  /** The weight update strategy to use in [[CaffeWorker.train]]. */
  var weightUpdate: WeightUpdate = FullUpdate

  updateParameter(args)

  /**
   * Re-reads all training parameters from the given arguments. Called at construction and upon
   * receiving an [[UpdateParameter]] message.
   */
  def updateParameter(args: Args): Unit = {
    synchronous = args.boolean("sync")
    movingRate = args.optional("movingRate").map(_.toFloat)
    // EASGD is only used with asynchronous strategies; otherwise --gradientOnly decides.
    val selfPaced = movingRate.isDefined && !synchronous
    weightUpdate =
      if (selfPaced) SelfPaceFullUpdate
      else if (args.boolean("gradientOnly")) GradientOnly
      else FullUpdate
  }
}
76 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo
18 |
19 | import _root_.akka.actor.ActorRef
20 |
21 | import com.twitter.scalding.Args
22 |
/** Constants and the messages exchanged between the SpeeDO actors. */
package object akka {
  /** Name of the actor system. */
  val SystemName = "SpeeDO"

  /** The command line option of max train iteration in master actors. */
  val MaxIterFlag = "maxIter"
  /** The command line option of first iteration number in master actors. */
  val StartIterFlag = "startIter"

  /**
   * Id of the dispatcher for all actors that use [[CaffeWorker]]. It is intended to be a pinned
   * dispatcher, so that such an actor always runs on the same thread. GPU code requires this,
   * since the GPU device is set per thread.
   * NOTE(review): the dispatcher itself is configured outside this file; confirm it is pinned.
   */
  val WorkerDispatcher = "worker-dispatcher"

  /** The message for an akka system to join. Used in HostActor */
  case object Join
  /** The message sent when the join time out is triggered. */
  case object JoinTimeOut
  /** The message to stop. */
  case object StopAkka

  /**
   * The message to forward `message` to `worker`, used in db actor to keep message order.
   * Receivers such as [[WorkerActor]] relay it with the original sender.
   */
  case class Forward(worker: ActorRef, message: Any)
  /**
   * The message to query and return current progress. The companion object `Progress` itself is
   * sent as the query, and the reply is a `Progress(progress)` instance.
   */
  case class Progress(progress: Float)
  /** The message that indicates the training is finished. */
  case object TrainFinished
  /** Update training parameters of the caffe worker from command line. */
  case class UpdateParameter(arg: Args)
  /** The message for created worker actor after master actor started */
  case class WorkerCreated(worker: ActorRef)
  /** The message to inform DB actor about keys (actor path names) of all active worker actors */
  case class UpdateWorkers(keys: Set[String])

  /** The message to init caffe snapshot. */
  case class Init(resume: Boolean)
  /** The message to clear the corresponding suffix key of the given worker. */
  case class ClearWorkerKey(worker: ActorRef)
  /** The message to run one training call, after setting the worker's iteration to `iteration`. */
  case class Train(iteration: Int)
  /**
   * The message to merge snapshots written by the worker, used in db actor.
   * @param worker The worker to merge snapshot from.
   * @param silent If set to true, will not warn if the delta is empty. This should only be used
   * together with drop master actor. Default is false.
   */
  case class Merge(worker: ActorRef, silent: Boolean = false)
  /** The message to merge snapshots from all workers, used in db actor. */
  case object MergeAll
  /**
   * The message to represent one finished train, carrying the training loss. A negative loss is
   * used by [[SynchronousMasterActor]] as a fake result on behalf of terminated workers.
   */
  case class Trained(loss: Double)
  /** The message to test caffe once. */
  case object Test
  /** The message of test accuracy; `None` when no result is available (e.g. the test failed). */
  case class TestResult(accuracy: Option[Double])

  /**
   * The message to query current progress; should only be used between [[HybridMasterActor]] and
   * [[MasterActor]]. Progress is calculated as `(count + offset - maxIter) / base` in master actor.
   */
  private[akka] case class ProgressIter(offset: Int, base: Int)
}
87 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/yarn/ReflectionUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.yarn
18 |
19 | import scala.reflect.runtime.universe._
20 | import scala.util.Try
21 |
22 | import com.twitter.scalding.Args
23 |
/**
 * Reflection helpers built on the Scala runtime reflection API.
 *
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
object ReflectionUtils {
  /** Runtime mirror over the class loader that loaded this object; every lookup goes through it. */
  lazy val mirror = runtimeMirror(getClass.getClassLoader)

  /**
   * Looks up the companion object of a class. The class may be either the class itself or the
   * class of its companion object.
   * @example Suppose we have {{{
   * class A
   * object A
   * }}}
   * both `companionObjectOption[A]` and `companionObjectOption(A.getClass)` return the companion
   * object `A`. Any reflection failure yields None, which includes classes without companion
   * objects (such as java classes).
   */
  def companionObjectOption(clazz: Class[_]): Option[Any] = {
    val lookup = Try {
      val module = mirror.moduleSymbol(clazz)
      mirror.reflectModule(module).instance
    }
    lookup.toOption
  }

  /**
   * Collects the constructors declared by the given class whose parameter types satisfy `filter`.
   * @param filter Receives the full names of a constructor's parameter types (all parameter lists
   * flattened) and decides whether to keep that constructor.
   * @return One tuple per kept constructor. The first element is a method mirror that invokes the
   * constructor via `apply`; the second is the list of parameter class names. Objects, abstract
   * classes and traits yield Nil.
   * @note The class names may not exist if they are generic.
   */
  def getConstructors(clazz: Class[_], filter: List[String] => Boolean): List[(MethodMirror, List[String])] = {
    val classSym = mirror.classSymbol(clazz)
    val classMirror = mirror.reflectClass(classSym)
    val instantiable = !(classSym.isModuleClass || classSym.isAbstractClass)
    if (!instantiable) Nil
    else {
      val ownConstructors = classSym.toType.members.toList.collect {
        case ctor: MethodSymbol if ctor.isConstructor && ctor.owner == classSym => ctor
      }
      for {
        ctor <- ownConstructors
        paramTypes = ctor.paramss.flatten.map(_.typeSignature.typeSymbol.fullName)
        if filter(paramTypes)
      } yield (classMirror.reflectConstructor(ctor), paramTypes)
    }
  }

  /**
   * Loads the class called `fullName` and returns an instance of it, following the same rules as
   * the `Class` overload.
   * @throws ClassNotFoundException if no class called `fullName` can be loaded.
   */
  def getInstaceFrom(fullName: String, args: Args): Option[Any] = {
    val clazz = Class.forName(fullName)
    getInstaceFrom(clazz, args)
  }

  /**
   * Get an instance from the given java class. It tries, in order:
   *  1. Create instance with constructor with scalding [[Args]] parameter
   *  2. Create instance with constructor with no parameters
   *  3. Companion object
   *  4. Return None
   */
  def getInstaceFrom(clazz: Class[_], args: Args): Option[Any] = {
    // full class name for scalding Args
    val argsClassName = classOf[Args].getName
    // only constructors taking nothing or exactly one Args are usable
    val candidates = getConstructors(clazz, {
      case Nil => true
      case single :: Nil => single == argsClassName
      case _ => false
    })
    if (candidates.isEmpty) companionObjectOption(clazz)
    else {
      // prefer the constructor with the most parameters, i.e. the Args one
      val (ctor, paramNames) = candidates.maxBy(_._2.size)
      Some(ctor(paramNames.map(_ => args): _*))
    }
  }
}
91 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/SpeeDOApp.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo
18 |
19 | import scala.concurrent.Await
20 | import scala.concurrent.duration._
21 | import scala.util.Try
22 |
23 | import _root_.akka.pattern.ask
24 | import _root_.akka.util.Timeout
25 |
26 | import com.twitter.scalding.Args
27 |
28 | import com.htc.speedo.akka.{ AkkaUtil, Progress }
29 | import com.htc.speedo.yarn.{ Finished, InProgress, MasterRole, YarnApp }
30 |
/**
 * The yarn app that runs SpeeDO. It reports the requested number of workers and creates a
 * [[SpeeDOMasterRole]] for the master container.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
case class SpeeDOApp(args: Args) extends YarnApp {
  /**
   * Number of workers given by `--worker` (default 1). The `arg` array is not used; the value is
   * read from the constructor's `args`.
   * @throws IllegalArgumentException if `--worker` is less than 1
   */
  override def getSize(arg: Array[String]) = args.optional("worker") match {
    case None => 1
    case Some(value) =>
      val count = value.toInt
      if (count < 1) throw new IllegalArgumentException("--worker must be positive")
      count
  }

  override def getApp(arg: Array[String], master: String) = SpeeDOMasterRole(args, master)
}
45 |
/** Constants used by the [[SpeeDOMasterRole]] class. */
object SpeeDOMasterRole {
  /** Minimum time between two progress queries to the master actor, in milliseconds (15 s). */
  val ProgressQueryInterval: Long = 15 * 1000L
}
50 |
/**
 * The master role that defines how to run SpeeDO on yarn.
 *
 * Before the slaves start, it creates the master actor system and its host actor. Each slave
 * container then runs [[AkkaUtil]] with `--host <slave host> --master <host actor path>`.
 * Afterwards, [[action]] reports the training progress, querying the host actor at most once every
 * [[SpeeDOMasterRole.ProgressQueryInterval]] milliseconds.
 * @param args The command line arguments passed to the master.
 * @param master The host name or ip of the master container.
 */
case class SpeeDOMasterRole(args: Args, master: String) extends MasterRole {
  // operations before starting slaves
  /** The actor system for master container. */
  val system = AkkaUtil.createSystem(args + ("host" -> Seq(master)))
  /** The remote address of the [[system]]. */
  val address = AkkaUtil.addressOf(system)
  /**
   * The host actor of [[system]]. By default, the join time out is 90 seconds. When tests are
   * disabled, `--sleepAfterFinish 15` is passed so that it waits 15 seconds after training for a
   * correct exit state.
   */
  val hostActor = {
    // Same rule as WorkerActor.runTest: tests run only if `--test` is positive (default 1). So e.g.
    // `--test -1` also counts as disabled, not just the literal `--test 0`.
    val testDisabled = args.int("test", 1) <= 0
    AkkaUtil.createHostActor(
      args + ("timeout" -> Seq(args.getOrElse("timeout", "90"))) +
        ("sleepAfterFinish" -> (if (testDisabled) Seq("15") else Nil)),
      system
    )
  }
  /** The full external uri of [[hostActor]]. */
  val path = hostActor.path.toSerializationFormatWithAddress(address)

  override val slaveMain = AkkaUtil.getClass.getName.stripSuffix("$")

  override def slaveArgs(host: String) = List("--host", host, "--master", path)

  /** Current progress of the akka system, as last reported by [[hostActor]]. */
  var progress = 0f

  /** The time of the last progress query, in milliseconds since the epoch. */
  var lastProgressTime = 0L

  // operations after starting slaves
  override def action = () => {
    val current = System.currentTimeMillis
    if (current - lastProgressTime > SpeeDOMasterRole.ProgressQueryInterval) {
      lastProgressTime = current
      implicit val timeout = Timeout(5.seconds)
      // training succeeded only if the last known progress reached 100%
      if (system.isTerminated) Finished(progress >= 1f)
      else {
        val future = hostActor ? Progress
        Try(Await.result(future.mapTo[Progress], timeout.duration))
          .toOption.foreach(p => progress = p.progress)
        // if the actor is stopped, but system not yet, return last progress
        InProgress(progress)
      }
    } else InProgress(progress)
  }

  override val waitAllContainers = false
}
100 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/PSMasterActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.collection.mutable.{ Map => MutableMap, Set => MutableSet }
20 |
21 | import akka.actor.ActorRef
22 |
23 | import com.twitter.scalding.Args
24 |
25 | import MasterActor._
26 |
27 | /**
28 | * The Partial Synchronous Actor, which is responsible to start worker actors and collect
29 | * information. In PS, there are two possible running states for caffe workers: the normal
30 | * running state and the catch-up state. The fastest and slowest workers are not allowed to
31 | * drift maxAdvance or more iterations apart; otherwise, the fastest workers will suspend.
32 | *
33 | * Required parameter:
34 | * - `--maxAdvance`: The max iteration interval between different workers
35 | *
36 | * @author Wenrui Jiang (roy_jiang@htc.com)
37 | */
38 | trait PSMasterActor extends MasterActor {
39 | /** The max iteration interval allowed between workers */
40 | val maxAdvance = args.required("maxAdvance").toInt
41 |
42 | /** all caffe workers and their running iterations */
43 | val workerIters = MutableMap(workers.map(_ -> 1): _*)
44 |
45 | /** works needing to catch up the fastest worker */
46 | val catchupWorkers = MutableSet[ActorRef]()
47 |
48 | /**
49 | * a catch-up flag indicates the running status, false means a normal running status,
50 | * while true means catch-up.
51 | */
52 | var catchup = false
53 |
54 | override def strategyName = "psc"
55 |
56 | override def parseTrainResult(loss: Double) = {
57 | if (catchup) {
58 | catchupWorkers -= sender
59 | // if in catch-up status and the catchupWorkers is empty
60 | if (catchupWorkers.isEmpty) {
61 | catchup = false
62 | log.info("Back to normal running status...")
63 | workers.foreach(workerIters.update(_, 1))
64 | ParsedTrainResult(train = StartTrainAll)
65 | } else ParsedTrainResult(train = StartTrainNone)
66 | } else { // if in normal running status
67 | val lastIter = workerIters.get(sender).get
68 | workerIters.update(sender, lastIter + 1)
69 | val values = workerIters.values
70 | catchup = values.max - values.min >= maxAdvance
71 | if (catchup) {
72 | catchupWorkers.clear
73 | catchupWorkers ++= workers
74 | catchupWorkers -= sender
75 | log.info("Change to catchup status, advanced actor: {}", sender.path.name)
76 | }
77 | ParsedTrainResult(train = if (catchup) StartTrainNone else StartTrainSender)
78 | }
79 | }
80 |
81 | override def workerCreated(worker: ActorRef) = {
82 | // if we are catching up, we do nothing and wait until catch up is over
83 | if (!catchup) {
84 | // set current iteration to quickest worker
85 | workerIters.update(worker, workerIters.values.max)
86 | super.workerCreated(worker) // start training
87 | }
88 | }
89 |
90 | override def workerTerminated(worker: ActorRef) = {
91 | // clean-up for the worker
92 | workerIters -= worker
93 | if (catchup) {
94 | // if we are catching up, we need to check if we are waiting for worker
95 | catchupWorkers -= worker
96 | if (catchupWorkers.isEmpty) { // all other workers are finished
97 | catchup = false
98 | log.info("Back to normal running status...")
99 | workers.foreach { w =>
100 | workerIters.update(w, 1)
101 | dbActor ! Forward(w, trainMessage)
102 | }
103 | }
104 | }
105 | }
106 | }
107 |
--------------------------------------------------------------------------------
/src/test/scala/com/htc/speedo/akka/MasterActorSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.util.Random
20 |
21 | import akka.testkit.ImplicitSender
22 |
/**
 * Test for basic master actor
 * @author Wenrui Jiang (roy_jiang@htc.com)
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
class MasterActorSpec extends ActorSpec with ImplicitSender {
  "Master Actor" should {
    "parse configuration correctly" in {
      val actor = createMasterActor[MasterActor]()
      actor.count must_== 0
      actor.maxIter must_== maxIter
      actor.testInterval must_== testInterval
      actor.startIter must_== 0
    }
    "handle init state correctly" in {
      val actor = createMasterActor[MasterActor]()
      // should not react to Train message before init is finished
      actor.self.tell(Trained(0f), actor.workers(0))
      dbProb.assertNoMessage
      // simulate init is finished (NOTE: sender is not important)
      actor.self ! Init
      // should start test after init
      testProb.assertMessage(Test)
      // should start train on all workers after init
      workerProb.foreach(_.assertMessage(Train(0)))
      ok
    }
    "handle startIter correctly" in {
      val startIter = 10
      val actor = createMasterActor[MasterActor](s"--startIter $startIter")
      // should not react to Train message before init is finished
      actor.self.tell(Trained(0f), actor.workers(0))
      dbProb.assertNoMessage
      // simulate init is finished (NOTE: sender is not important)
      actor.self ! Init
      // should start test after init
      testProb.assertMessage(Test)
      // should start train on all workers after init
      workerProb.foreach(_.assertMessage(Train(startIter)))
      // progress should not be affected by startIter
      actor.self ! Progress
      this.assertMessage(Progress(0f))
    }
    "work correctly as normal asynchronous strategy" in {
      val actor = createMasterActor[MasterActor]()
      // Skip init and change to training state
      actor.context.become(actor.trainState)
      // Simulate training
      (1 to maxIter).foreach { i =>
        // Select a random worker
        val sender = actor.workers(Random.nextInt(workerNumber))
        // Tell master actor that it has finished training
        actor.self.tell(Trained(0f), sender)
        // DB actor should receive a merge message first
        dbProb.assertMessage(Merge(sender))
        // If not enough iterations, trigger next train
        if (i != maxIter) {
          dbProb.assertMessage(Forward(sender, Train(i)))
          // If triggered test correctly
          if (i % 2 == 0) dbProb.assertMessage(Forward(actor.tester, Test))
        } else {
          // If finalizing correctly
          dbProb.assertMessage(Forward(actor.context.parent, TrainFinished))
        }
      }
      ok
    }
    "handle common states correctly" in {
      "handle progress query correctly" in {
        val actor = createMasterActor[MasterActor]()
        actor.self ! Progress
        this.assertMessage(Progress(0f))
        // Move count to a random but reachable iteration: never negative, never past maxIter
        actor.count = Random.nextInt(maxIter + 1)
        actor.self ! Progress
        this.assertMessage(Progress(actor.count.toFloat / maxIter))
      }
    }
  }
}
103 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/AkkaUtil.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.collection.mutable.Buffer
20 |
21 | import akka.actor._
22 |
23 | import com.twitter.scalding.Args
24 | import com.typesafe.config.ConfigFactory
25 |
/**
 * Provides utilities for creating akka system and host actors.
 *
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
object AkkaUtil {
  /** The command line flag of host. */
  val hostFlag = "host"

  /** An akka extension to get system address. See http://stackoverflow.com/questions/14288068 */
  class AddressExtension(system: ExtendedActorSystem) extends Extension {
    // The address of the akka system (contains remote information)
    val address = system.provider.getDefaultAddress
  }
  /** A companion object for easier creating extensions from actor system. */
  object AddressExtension extends ExtensionKey[AddressExtension]

  /**
   * Return the external address of the actor system, including protocol, hostname and port if
   * remoting is enabled.
   */
  def addressOf(system: ActorSystem): Address = AddressExtension(system).address

  /**
   * Create an akka system which supports tcp remoting from the given arguments.
   *
   * Arguments:
   *  - `--host <hostname>`: The host name or ip to listen to.
   *  - `--port <port>`: (optional) The port number to start akka system. Default is 0, i.e. a
   *    random available port. To get the port after the system started, use [[addressOf]].
   */
  def createSystem(args: Args): ActorSystem = {
    val hostname = args.required(hostFlag)
    val port = args.int("port", 0)
    val configString = s"""akka {
      actor { provider = "akka.remote.RemoteActorRefProvider" }
      remote {
        enabled-transports = ["akka.remote.netty.tcp"]
        netty.tcp {
          hostname = "$hostname"
          port = $port
        }
      }
    }
    $WorkerDispatcher {
      executor = "thread-pool-executor"
      type = PinnedDispatcher
    }"""
    ActorSystem(SystemName, ConfigFactory.parseString(configString).withFallback(ConfigFactory.load))
  }

  /**
   * Create the host actor in the given actor system. If a `--master` parameter is given, a
   * [[HostSlaveActor]] is created, otherwise [[HostMasterActor]] is created.
   *
   * Arguments for slave system:
   *  - `--master <hostname>`: The host name or ip of the master actor system to connect to.
   *    It must be the same as the one passed to [[createSystem]] on the master machine.
   *  - `--worker <n>`: (Optional) If provided, start multiple worker actors in one slave system.
   */
  def createHostActor(args: Args, system: ActorSystem): ActorRef = {
    // Slave host actors need to specify which master to connect to
    val hostActorClass =
      if (args.boolean("master")) classOf[HostSlaveActor] else classOf[HostMasterActor]
    // exclude `--host` argument
    system.actorOf(Props(hostActorClass, new Args(args.m - hostFlag)), hostFlag)
  }

  /**
   * Create the master actor props. According to akka document, the props is recommended to be
   * created outside of an actor class, if use `Props(new XXXActor)` syntax. We have to use this
   * syntax other than `Props(classOf[XXXActor], args...)` since we need to mix-in traits.
   *
   * The strategy is chosen by flag, checked in order: `--hybrid`, `--weedout`, `--maxAdvance`,
   * `--sync`, falling back to the plain asynchronous [[MasterActor]].
   */
  def createMasterActorProps(args: Args, db: ActorRef, tester: ActorRef, workers: Seq[ActorRef]): Props = {
    // NOTE(review): the clone is taken once per Props, so every actor instance created from
    // these props (e.g. after a restart) shares and sees earlier modifications of this buffer.
    val buffer = workers match {
      // use a clone of workers buffer, since it's modified in master actor
      case b: Buffer[ActorRef] => b.clone
      // otherwise just create a new buffer
      case _ => workers.toBuffer
    }
    Props(
      if (args.boolean("hybrid")) // hybrid must be first
        new HybridMasterActor(args, db, tester, buffer)
      else if (args.boolean("weedout"))
        new MasterActor(args, db, tester, buffer) with WeedOutMasterActor
      else if (args.boolean("maxAdvance"))
        new MasterActor(args, db, tester, buffer) with PSMasterActor
      else if (args.boolean("sync"))
        new MasterActor(args, db, tester, buffer) with SynchronousMasterActor
      else new MasterActor(args, db, tester, buffer)
    )
  }

  /**
   * A main utility to start master/slave akka system and host actors.
   *
   * Common Parameters:
   *  - `--host <hostname>`: The host name or ip to listen to.
   *  - `--port <port>`: The port to listen to. Default is random.
   *
   * Parameters for worker:
   *  - `--master <hostname>`: The host name or ip of the master actor system to connect to.
   *  - `--worker <n>`: (Optional) If provided, start multiple worker actors in one slave system.
   *    Useful for multi-GPU machine, as each worker can utilize a different GPU.
   * @note Multiple workers in one akka system is only enabled when running slaves manually. When
   * running on yarn, always start one worker per system.
   * @see [[HostMasterActor]] for other parameters needed for master.
   */
  def main(arg: Array[String]): Unit = {
    val args = Args(arg)
    createHostActor(args, createSystem(args))
  }
}
142 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/yarn/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo
18 |
19 | import java.io.File
20 | import java.net.InetAddress
21 | import java.nio.ByteBuffer
22 |
23 | import scala.collection.JavaConverters._
24 | import scala.collection.mutable.{ Map => MutableMap }
25 | import scala.util.Try
26 | import scala.xml.XML
27 |
28 | import org.apache.commons.lang.StringUtils
29 | import org.apache.hadoop.conf.Configuration
30 | import org.apache.hadoop.fs.{ FileSystem, Path }
31 | import org.apache.hadoop.yarn.api.ApplicationConstants.{ Environment, LOG_DIR_EXPANSION_VAR }
32 | import org.apache.hadoop.yarn.api.records._
33 | import org.apache.hadoop.yarn.util.{ Apps, ConverterUtils, Records }
34 |
/**
 * Helper functions for building Container/AppMaster context info.
 *
 * @author Wenrui Jiang (roy_jiang@htc.com)
 */
package object yarn {
  /** Size unit: number of MB in one GB. */
  val gB = 1024
  /**
   * JVM heap size proportion of yarn container memory size.
   * YARN only pays attention to the amount of physical memory used by a
   * process. With Java, you can set the heap, but there's also permgen, JVM,
   * JNI libraries, and off-heap memory usage. All of these contribute to the
   * physical memory usage that YARN cares about, but are outside the JVM heap
   */
  val heapProportion = .6f
  /** check waiting time (NOTE(review): presumably milliseconds — confirm at call sites) */
  val waitTime = 1000

  /** Application configuration, loaded from the `AppConf.xml` class-path resource. */
  lazy val AppConfig = XML.load(getClass.getResource("/com/htc/speedo/yarn/AppConf.xml"))
  /** Text content of all elements named `label` in [[AppConfig]]. */
  private def getEle(label: String) = AppConfig.\\(label).text

  /** get AppMaster configuration from AppConf.xml */
  lazy val appMem = getEle("appMemory").toInt
  lazy val appCores = getEle("appCores").toInt
  lazy val appName = getEle("appName")
  lazy val queueType = getEle("queueType")

  /** get Container configuration from AppConf.xml */
  lazy val containerMem = getEle("containerMemory").toInt
  lazy val containerCores = getEle("containerCores").toInt
  lazy val priorityLevel = getEle("priorityLevel").toInt
  lazy val javaOpts = getEle("javaOpts")

  /** The YarnApp class name for the application. */
  val AppClassFlag = "appClass"
  /** The Yarn application id for the application. */
  val AppIdFlag = "appId"
  /** resource flag indicating vCore requirement for each container */
  val coreFlag = "core"
  /** resource flag indicating memory requirement for each container */
  val memFlag = "mem"
  /** resource flag indicating heap size requirement for the job */
  val heapFlag = "heap"
  /** resource flag indicating host launch the application master container */
  val amFlag = "appMaster"

  /**
   * Get the local host name. Falls back to the `HOSTNAME` (Unix) and `COMPUTERNAME` (Windows)
   * environment variables, ignoring empty values, and finally to `localhost`.
   */
  def getHostName: String =
    Try(InetAddress.getLocalHost().getHostName).toOption
      .filter(StringUtils.isNotEmpty)
      .orElse(Option(System.getenv("HOSTNAME")).filter(StringUtils.isNotEmpty)) // Unix
      .orElse(Option(System.getenv("COMPUTERNAME")).filter(StringUtils.isNotEmpty)) // Windows
      .getOrElse("localhost")

  /**
   * Build the command to launch a JVM through the `hadoop` script, redirecting stdout and stderr
   * to files in the container log directory.
   */
  def launchJVM(mainClass: String, arguments: List[String]): List[String] =
    List("hadoop", mainClass) ++ arguments ++ List(
      s"1>$LOG_DIR_EXPANSION_VAR/stdout", s"2>$LOG_DIR_EXPANSION_VAR/stderr"
    )

  /**
   * Build a [[ContainerLaunchContext]], which defines all information needed to launch a
   * container, including:
   *  - the command to be executed
   *  - the local resources (binaries, jars, files etc.)
   *  - security tokens
   *  - environment settings (CLASSPATH etc.)
   *
   * @param cmd The command to run in the container.
   * @param hdfsPaths Files to localize into the container, keyed by their file names.
   * @param env Environment variables of the container.
   * @param tokens Optional serialized security tokens.
   */
  def buildContainerContext(cmd: List[String], hdfsPaths: List[Path],
    env: MutableMap[String, String], tokens: Option[ByteBuffer] = None)(
    implicit
    conf: Configuration
  ): ContainerLaunchContext = {
    val localResources = hdfsPaths.map { hdfsPath =>
      val resource = Records.newRecord(classOf[LocalResource])
      setUpLocalResource(hdfsPath, resource)
      hdfsPath.getName -> resource
    }.toMap.asJava
    val container = Records.newRecord(classOf[ContainerLaunchContext])
    // duplicate shares the token bytes but has its own position, leaving the caller's buffer intact
    tokens.foreach(buffer => container.setTokens(buffer.duplicate))
    container.setCommands(cmd.asJava)
    container.setLocalResources(localResources)
    container.setEnvironment(env.asJava)
    container
  }

  /**
   * Describe the file at `resourcePath` as a public FILE local resource of the container.
   *
   * @note using the LocalResource to add resources to our application request will cause YARN to
   * distribute application's jars to all of the nodes in the YARN cluster that need it
   */
  def setUpLocalResource(resourcePath: Path, res: LocalResource)(
    implicit
    conf: Configuration
  ): Unit = {
    // Resolve the file system from the path itself: FileSystem.get(conf) only returns the default
    // file system, which fails for fully-qualified paths on any other file system.
    val status = resourcePath.getFileSystem(conf).getFileStatus(resourcePath)
    res.setResource(ConverterUtils.getYarnUrlFromPath(resourcePath))
    res.setSize(status.getLen)
    res.setTimestamp(status.getModificationTime)
    res.setType(LocalResourceType.FILE)
    res.setVisibility(LocalResourceVisibility.PUBLIC)
  }

  /**
   * Add Environment
   * @param env The mutable map that contains environment variables
   * @param jarNames Additional jars appended to HADOOP_CLASSPATH
   * @param heapSize Heap size in GB, e.g. 2 (fractions allowed); written to JAVA_HEAP_MAX
   */
  def setUpEnv(env: MutableMap[String, String], jarNames: List[String],
    heapSize: Float): Unit = {
    val javaEnv = env.asJava // live view: updates go straight into `env`
    jarNames.foreach(Apps.addToEnvironment(javaEnv, "HADOOP_CLASSPATH", _, File.pathSeparator))
    val heapInMB = (heapSize * gB).toInt
    // A JVM option is not a path list: overwrite rather than append with a path separator, which
    // would produce an invalid value such as "-Xmx1024M:-Xmx2048M" if the key was already set.
    env("JAVA_HEAP_MAX") = "-Xmx" + heapInMB + "M"
  }
}
154 |
--------------------------------------------------------------------------------
/src/test/scala/com/htc/speedo/akka/HybridMasterActorSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import java.util.UUID
20 |
21 | import akka.actor.{ ActorRef, PoisonPill, Props }
22 | import akka.testkit.{ ImplicitSender, TestActorRef, TestProbe }
23 |
24 | import com.twitter.scalding.Args
25 |
object HybridMasterActorSpec {
  /**
   * Props for a [[HybridMasterActor]] whose child masters are stubbed: a child configuration with
   * `--sync` maps to `master1`, any other configuration maps to `master2`.
   *
   * Actor creation lives in this companion object so that the props pass serialization.
   */
  def createHybridMasterActorProps(args: Args, db: ActorRef, tester: ActorRef,
    workers: Seq[ActorRef], master1: ActorRef, master2: ActorRef): Props =
    Props(new HybridMasterActor(args, db, tester, workers.toBuffer) {
      override def createMasterActor(args: Args) = args.boolean("sync") match {
        case true => master1
        case false => master2
      }
    })
}
34 |
/**
 * Test for hybrid master actor
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
class HybridMasterActorSpec extends ActorSpec with ImplicitSender {
  // stand-ins for the two child master actors
  lazy val masterProb1 = TestProbe()
  lazy val masterProb2 = TestProbe()

  /** argument for creating [[HybridMasterActor]] in test */
  val hybridArgs = "--hybrid =sync =maxIter 2 = =maxIter 1"
  /** The argument for first master actor. */
  val masterArgs1 = "--sync --maxIter 2"
  /** The argument for second master actor. */
  val masterArgs2 = "--maxIter 1"

  // drop maxIter from the shared command line; examples pass it explicitly when needed
  override lazy val commandLine = s"--test $testInterval"

  /** Create master actor for test. The command line is same as real run. */
  def createHybridMasterActor(addtionalArgs: String = ""): HybridMasterActor = {
    val fullArgs = Args(s"$commandLine $addtionalArgs")
    val props = HybridMasterActorSpec.createHybridMasterActorProps(
      fullArgs, dbProb.ref, testProb.ref, workerProb.map(_.ref), masterProb1.ref, masterProb2.ref
    )
    TestActorRef(props, testActor, UUID.randomUUID.toString).underlyingActor
  }

  /** checks the messages received by db actor when a master actor started */
  def checkStartMaster(master: TestProbe, args: String): Unit = {
    val expectedArgs = Args(commandLine + " " + args)
    // every worker first gets the child master's parameters
    workerProb.foreach { worker =>
      dbProb.assertMessage(Forward(worker.ref, UpdateParameter(expectedArgs)))
    }
    // then the child master is told that everything is ready
    dbProb.assertMessage(Forward(master.ref, Init))
  }

  /**
   * Simulates that the child master `child` finished training, and checks that the tester is
   * asked to forward a PoisonPill to it.
   */
  def finishMaster(hybrid: HybridMasterActor, child: TestProbe): Unit = {
    hybrid.self ! TrainFinished
    testProb.assertMessage(Forward(child.ref, PoisonPill))
  }

  "Hybrid Master Actor" should {
    "parse --hybrid argument correctly" in {
      val arg1 = List("=key1", "value1=", "value=1")
      val arg2 = List("=key2", "value2", "value3")
      for (i <- 1 to 10) {
        // separator made of i consecutive '=' characters
        val separator = List("=" * i)
        val args = HybridMasterActor.parseArgs(separator ::: arg1 ::: separator ::: arg2 ::: separator)
        args.size must_== 4
        args(0).toString must beEmpty
        args(1).toString must_== "--key1 value1= value=1"
        args(2).toString must_== "--key2 value2 value3"
        args(3).toString must beEmpty
      }
      ok
    }
    "create master actors correctly" in {
      "if maxIter is not given" in {
        val master = createHybridMasterActor(hybridArgs)
        master.start must_== 0
        master.self ! Init // simulate DB actor finished init
        checkStartMaster(masterProb1, masterArgs1)
        master.start must_== 2
        finishMaster(master, masterProb1)
        checkStartMaster(masterProb2, masterArgs2)
        master.start must_== master.maxIter
        finishMaster(master, masterProb2)
        // hybrid master actor should have finished train
        this.assertMessage(TrainFinished)
      }
      "if maxIter is smaller than sum of masters' maxIter" in {
        val master = createHybridMasterActor("--maxIter 1 " + hybridArgs)
        master.start must_== 0
        master.self ! Init // simulate DB actor finished init
        checkStartMaster(masterProb1, masterArgs1)
        master.start must_== master.maxIter
        finishMaster(master, masterProb1)
        // hybrid master actor should have finished train
        this.assertMessage(TrainFinished)
      }
      "if maxIter is larger than sum of master's maxIter" in {
        val master = createHybridMasterActor("--maxIter 4 " + hybridArgs)
        master.start must_== 0
        master.self ! Init // simulate DB actor finished init
        checkStartMaster(masterProb1, masterArgs1)
        master.start must_== 2
        finishMaster(master, masterProb1)
        checkStartMaster(masterProb2, masterArgs2)
        master.start must_== 3
        finishMaster(master, masterProb2)
        // the first master is started again for the remaining iteration
        checkStartMaster(masterProb1, masterArgs1)
        master.start must_== master.maxIter
        finishMaster(master, masterProb1)
        // hybrid master actor should have finished train
        this.assertMessage(TrainFinished)
      }
    }
    "handle Progress message correctly" in {
      val master = createHybridMasterActor(hybridArgs)
      // before any child master started, progress is 0
      master.self ! Progress
      this.assertMessage(Progress(0f))
      master.self ! Init
      checkStartMaster(masterProb1, masterArgs1)
      // once a child master runs, the query is forwarded to it
      master.self ! Progress
      masterProb1.assertMessage(ProgressIter(2, 3))
    }
  }
}
166 |
--------------------------------------------------------------------------------
/src/test/scala/com/htc/speedo/akka/SynchronousMasterActorSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.util.Random
20 |
/**
 * Test Synchronous Master Actor
 * @author Wenrui Jiang (roy_jiang@htc.com)
 */
class SynchronousMasterActorSpec extends ActorSpec {
  override val maxIter = 10 // runs long enough
  override val testInterval = 20 // no tests triggered

  "Synchronous Master Actor" should {
    val actor = createMasterActor[SynchronousMasterActor]("--sync")
    val master = actor.self
    // Skip init and change to training state
    actor.context.become(actor.trainState)
    val sender_0 = actor.workers(0)
    val sender_1 = actor.workers(1)
    val sender_2 = actor.workers(2)

    /** Expects the db actor to receive the names of the current workers as the merge list. */
    def expectUpdateWorkers(): Unit =
      dbProb.assertMessage(UpdateWorkers(actor.workers.map(_.path.name).toSet))

    /** Expects a merge of all deltas, then a new training iteration on every current worker. */
    def expectNextIteration(): Unit = {
      dbProb.assertMessage(MergeAll)
      actor.workers.foreach(worker => dbProb.assertMessage(Forward(worker, actor.trainMessage)))
    }

    /** Re-registers worker 2 after an example simulated its termination. */
    def restoreWorker2(): Unit = {
      actor.workers += sender_2
      actor.lastUpdateWorkers = Set(actor.workers: _*)
      actor.waitingSet += sender_2
    }

    "work correctly for synchronous strategy" in {
      val sender = actor.workers(Random.nextInt(workerNumber))
      actor.waitingSet must containTheSameElementsAs(actor.workers)
      // Not every worker has finished yet, so the master keeps waiting.
      master.tell(Trained(.1f), sender)
      actor.waitingSet must containTheSameElementsAs(actor.workers - sender)
      dbProb.expectNoMsg
      workerProb.foreach(_.expectNoMsg)
      // Once all workers have finished, the master merges all deltas.
      (actor.workers - sender).foreach(master.tell(Trained(.1f), _))
      expectUpdateWorkers()
      expectNextIteration()
      actor.waitingSet must containTheSameElementsAs(actor.workers)
    }
    "handle worker hot join correctly" in {
      // hot join one worker (must differ from the existing workers)
      master ! WorkerCreated(testActor) // use this test kit as new worker
      actor.workers must_== Seq(sender_0, sender_1, sender_2, testActor)
      // nothing should happen
      dbProb.expectNoMsg

      // finish one iteration with the original workers
      (actor.workers - testActor).foreach(master.tell(Trained(.1f), _))
      // all workers including the new one should start a new iteration
      expectNextIteration()

      // finish another iteration, now new worker should be used
      actor.workers.foreach(master.tell(Trained(.1f), _))
      expectUpdateWorkers()
      expectNextIteration()

      // clean up
      actor.workers -= testActor
      actor.lastUpdateWorkers = Set(actor.workers: _*)
      actor.waitingSet -= testActor
      ok
    }
    "handle terminated worker correctly" in {
      "worker terminate before train finished" in {
        "worker is the last worker in waiting list" in {
          // workers 0 and 1 finished
          master.tell(Trained(.1f), sender_0)
          master.tell(Trained(.1f), sender_1)

          // simulate worker 2 is terminated
          actor.workers -= sender_2
          actor.workerTerminated(sender_2)
          expectUpdateWorkers()
          expectNextIteration()
          // stopped worker should not start training
          workerProb(2).msgAvailable must beFalse

          restoreWorker2()
          ok
        }
        "worker is not the last worker in waiting list" in {
          // worker 0 finished
          master.tell(Trained(.1f), sender_0)

          // simulate worker 2 is terminated
          actor.workers -= sender_2
          actor.workerTerminated(sender_2)
          // nothing should happen yet
          dbProb.expectNoMsg

          // worker 1 finished
          master.tell(Trained(.1f), sender_1)
          expectUpdateWorkers()
          expectNextIteration()
          // stopped worker should not start training
          workerProb(2).msgAvailable must beFalse

          restoreWorker2()
          ok
        }
      }
      "worker terminate after train finished" in {
        // workers 2 and 1 finished
        master.tell(Trained(.1f), sender_2)
        master.tell(Trained(.1f), sender_1)
        actor.waitingSet must_== Set(sender_0)
        actor.mergeSet must containTheSameElementsAs(Seq(sender_1, sender_2))

        // simulate worker 2 is terminated
        actor.workers -= sender_2
        actor.workerTerminated(sender_2)
        // waiting set and merge set are not affected
        actor.waitingSet must_== Set(sender_0)
        actor.mergeSet must containTheSameElementsAs(Seq(sender_1, sender_2))

        // worker 0 finished: merge workers 0, 1, 2 for this iteration and restart live workers
        master.tell(Trained(.1f), sender_0)
        expectNextIteration()
        // stopped worker should not start training
        workerProb(2).msgAvailable must beFalse

        // workers 0 and 1 finished the 2nd iteration: now the merge list is updated
        master.tell(Trained(.1f), sender_0)
        master.tell(Trained(.1f), sender_1)
        expectUpdateWorkers()
        expectNextIteration()

        restoreWorker2()
        ok
      }
    }
  }
}
166 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/DBActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import com.twitter.algebird.Semigroup
20 | import com.twitter.scalding.Args
21 | import com.twitter.storehaus.FutureOps
22 | import com.twitter.util.Await
23 |
24 | import com.htc.speedo.caffe.{ CaffeWorker, NetParameterOperation }
25 |
26 | import CaffeWorker._
27 |
/**
 * A database actor that stays in the master machine. It holds a copy of the weights in memory,
 * merges snapshots (or deltas) from workers and forwards messages (e.g. Train or Test) to workers.
 *
 * Parameters:
 * - `--factor `: (Optional) If provided, divide the delta trained by each worker by the given
 * factor before adding it to the weights. It is only applied when merging a single worker with
 * full update in asynchronous mode. Only works if none of the parameters in [[ParameterActor]]
 * is defined.
 *
 * @note This actor should be in the same jvm as the master actor, so forwarding messages should not
 * be an overhead.
 * @see [[ParameterActor]] for other arguments required by the actor.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 * @author Wenrui Jiang (roy_jiang@htc.com)
 */
case class DBActor(args: Args) extends ParameterActor(args) {
  /** working directory for caffe worker */
  val workingDir = "caffe.db"
  /** The worker for this actor. Used for init and merge delta. */
  val caffeWorker = CaffeWorker(args + ("baseDir" -> Seq(workingDir)))
  /** The name of the key in store. Worker-specific keys are this name suffixed by the worker name. */
  val name = caffeWorker.getKey
  /** Create the same store as in [[CaffeWorker]] */
  val store = caffeWorker.snapshotStore
  /** The same semigroup as in [[CaffeWorker]], but for Option[Array[Byte]]. */
  val semigroup = Semigroup.optionSemigroup(store.semigroup)
  /**
   * Store keys of all active workers, used by the synchronous master actor; empty by default.
   * The synchronous master updates it (via [[UpdateWorkers]]) before the first iteration is merged.
   */
  var workerKeys = Set[String]()
  /** If provided, divide the delta by the factor. */
  val factor = args.optional("factor").map(_.toFloat)
  /** The global weights cached in memory; None until an Init message has been handled. */
  var snapshot: Option[Array[Byte]] = None

  override def receive = {
    // Init weights
    // TODO: support to reuse previous history
    case Init(resume) =>
      // init weights and save to local copy (same as storehaus)
      snapshot = Some(caffeWorker.init(resume))
      // tell the master that initialization is finished
      sender ! Init
    // deletes the key of worker's suffix in storehaus
    case ClearWorkerKey(worker) =>
      Await.result(store.put((name + worker.path.name) -> None))
    case UpdateParameter(arg) =>
      // update training parameters in base class
      updateParameter(arg)
      // update solver parameters in caffe worker
      caffeWorker.updateParameter(arg)
    // Forward test message if running asynchronously
    case Forward(worker, Test) if !synchronous && weightUpdate != GradientOnly =>
      // put snapshot to global key first since snapshot is written to worker's
      // suffix key when handling Merge message
      Await.result(store.put(name, snapshot))
      worker.tell(Test, sender)
    // Forward message, keeping the original sender
    case Forward(worker, message) => worker.tell(message, sender)
    // Merge snapshot of the worker, only meaningful for asynchronous master
    case Merge(worker, silent) if !synchronous =>
      Await.result(store.get(name + worker.path.name)) match {
        case Some(delta) => (weightUpdate, movingRate) match {
          // each worker updates the weight in their own pace
          case (SelfPaceFullUpdate, Some(rate)) => snapshot match {
            case Some(global) =>
              // movingRate * (local - global)
              val elasticDiff = NetParameterOperation.multiply(
                NetParameterOperation.minus(delta, global), rate
              )
              // update local weight
              val weight = NetParameterOperation.minus(delta, elasticDiff)
              // update global weight
              snapshot = semigroup.plus(snapshot, Some(elasticDiff))
              // put to worker's suffix key (worker will read it in next train)
              Await.result(store.put(name + worker.path.name, Some(weight)))
            // TODO: save snapshot to global key (not only before Test)
            // Global weights are missing (no Init handled yet): log instead of crashing the
            // actor with NoSuchElementException, as the original `snapshot.get` did.
            case None =>
              log.error("Can't merge self pace update as global weights are not initialized")
          }
          // This should not happen, but just in case something is going wrong
          case (SelfPaceFullUpdate, None) =>
            log.error("Can't merge self pace update as moving rate is not set")
          // Add the deltas to weights and put to store
          case (FullUpdate, _) =>
            snapshot = semigroup.plus(snapshot, factor.map(f =>
              NetParameterOperation.divide(delta, f)).orElse(Some(delta)))
            // put to worker's suffix key (worker will read it in next train)
            Await.result(store.put(name + worker.path.name, snapshot))
          // TODO: save snapshot to global key (not only before Test)
          // Merge deltas and write to suffix key
          case (GradientOnly, _) =>
            snapshot = Some(caffeWorker.mergeDelta(delta, worker.path.name))
        }
        case None => // this may happen in drop master actor
          if (!silent) log.warning("Delta from {} is empty!", worker.path.name)
          // put snapshot to worker's suffix key (will read it in next train)
          Await.result(store.put(name + worker.path.name, snapshot))
      }
    // update keys of active workers
    case UpdateWorkers(keys) => workerKeys = keys.map(name + _)
    // Merge snapshot of all workers, only meaningful for synchronous master
    case MergeAll if synchronous =>
      // Get deltas for all workers
      val futures = store.multiGet(workerKeys)
      val deltas = Await.result(FutureOps.mapCollect(futures)).values
      // TODO: Shall we fail here?
      if (deltas.exists(_.isEmpty)) log.error("Delta contains empty!")
      // average of all deltas, ignores None, but always divide by #workers
      // (sumOption of no deltas is None, so the division never runs with zero workers)
      semigroup.sumOption(deltas).flatten.map(NetParameterOperation.divide(_, workerKeys.size)) match {
        case Some(averaged) => weightUpdate match {
          // Add the deltas to weights and put to store
          case FullUpdate =>
            snapshot = semigroup.plus(snapshot, Some(averaged))
            Await.result(store.put(name, snapshot))
          // Merge deltas and write to global key
          case GradientOnly => snapshot = Some(caffeWorker.mergeDelta(averaged))
          // This should not happen, but just in case something is going wrong
          case SelfPaceFullUpdate =>
            log.error("Does not support self pace update in synchronous master")
        }
        case None => log.error("All deltas are empty!")
      }
  }

  /** Releases the resources held by the caffe worker when the actor stops. */
  override def postStop = caffeWorker.close
}
151 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/HybridMasterActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.collection.mutable.Buffer
20 |
21 | import akka.actor.{ Actor, ActorLogging, ActorRef, PoisonPill, Terminated }
22 |
23 | import com.twitter.scalding.Args
24 |
object HybridMasterActor {
  /**
   * Splits the `--hybrid` argument list into one [[Args]] per master actor, keeping their order.
   *
   * Inside the list, options are written with a leading `=` instead of `--`. The argument groups
   * of consecutive master actors are separated by a token consisting only of `=` characters
   * (conventionally `==`). The leading run of `=` in every other token is rewritten to `--`.
   */
  def parseArgs(args: List[String]): List[Args] = {
    /** True when the token is a group separator, i.e. it consists only of `=` characters. */
    def isSeparator(token: String): Boolean = token.matches("^=+$")

    // Parsed groups are prepended to `parsed`, so the final result must be reversed.
    @annotation.tailrec
    def loop(remaining: List[String], parsed: List[Args]): List[Args] = {
      val (group, rest) = remaining.span(token => !isSeparator(token))
      val withGroup = Args(group.map(_.replaceAll("^=+", "--"))) :: parsed
      rest match {
        case Nil => withGroup
        // drop the separator itself and continue with the next group
        case _ :: afterSeparator => loop(afterSeparator, withGroup)
      }
    }

    loop(args, Nil).reverse
  }
}
53 |
/**
 * Runs a sequence of master actors one after another. When the current master finishes, the
 * next one in the sequence is started, until the total iteration budget is used up.
 *
 * Required Arguments:
 * - `--hybrid`: The arguments passed to the underlying master actors. Within them, every leading
 * `--` is written as `=`, and the arguments of different master actors are separated by `==`.
 *
 * Optional Arguments:
 * - `--maxIter`: The total number of caffe runs. Defaults to the sum of `maxIter` over all masters.
 * If it is smaller than that sum, only part of the master sequence is executed; if larger, the
 * sequence is repeated from the beginning. See the example below.
 *
 * Common arguments are not used by [[HybridMasterActor]] itself, but are passed to every master
 * actor together with that master's own `--hybrid` arguments. When an option appears in both
 * places, the value from `--hybrid` wins.
 * - `--test`: The test interval.
 *
 * Arguments specific to one master actor should be given inside `--hybrid`, since they are not
 * shared by all master actors.
 * @note Do not set `--startIter` in the `--hybrid` argument, as it is always overridden by
 * [[HybridMasterActor]].
 * @example The input arguments
 * `--maxIter 800 --test 50 --hybrid =maxIter 100 =sync == =maxIter 200 =test 100 == =maxIter 300 =drop 3`
 * describe a sequence of 3 master actors. Since the given `--maxIter` (800) is larger than the
 * sum for all master actors (600), there will be a total of 5 master actors, run in the following
 * order for 800 iterations in total:
 * - `--maxIter 100 --sync --test 50`: first master, synchronous for 100 iterations, test interval is 50
 * - `--maxIter 200 --test 100`: second master, asynchronous for 200 iterations, test interval is 100
 * - `--maxIter 300 --drop 3 --test 50`: third master, drop for 300 iterations, test interval is 50
 * - `--maxIter 100 --sync --test 50`: back to the first master, synchronous for 100
 * iterations, test interval is 50
 * - `--maxIter 100 --test 100`: second master again, asynchronous for the last 100 iterations,
 * test interval is 100. It runs only 100 iterations instead of 200, since 200 would exceed `--maxIter`.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
case class HybridMasterActor(args: Args, db: ActorRef, tester: ActorRef, workers: Buffer[ActorRef])
  extends Actor with ActorLogging {
  /** Parsed arguments of each underlying master actor, in the order they are run. */
  val hybridArgs = HybridMasterActor.parseArgs(args.list("hybrid"))
  /** Total number of iterations; defaults to the sum of `--maxIter` over [[hybridArgs]]. */
  val maxIter = args.int(MaxIterFlag, hybridArgs.map(_.int(MaxIterFlag)).sum)

  /** The most recently started master actor, if any. */
  var masterActor: Option[ActorRef] = None
  /** The first iteration of the next master actor to start. */
  var start = 0
  /** Position in [[hybridArgs]] of the next master actor to start. */
  var index = 0

  /** Create a master actor from given args, with an automatically generated name. */
  def createMasterActor(args: Args): ActorRef =
    context.actorOf(AkkaUtil.createMasterActorProps(args, db, tester, workers))

  /**
   * Starts the master actor described by `hybridArgs(index)`. The new master never runs past
   * the overall iteration budget. Afterwards, [[start]] and [[index]] are advanced.
   */
  private def startNextMaster(): Unit = {
    // per-master options override the common ones; the `hybrid` option itself is dropped
    val merged = new Args(args.m ++ hybridArgs(index).m - "hybrid")
    // route the new parameters through the DB actor to keep message order
    workers.foreach { worker => db ! Forward(worker, UpdateParameter(merged)) }
    val iterations = Math.min(merged.int(MaxIterFlag), maxIter - start)
    log.info("Start {}th master actor for {} iterations", index + 1, iterations)
    val master = createMasterActor(
      merged + (MaxIterFlag -> Seq(iterations.toString)) + (StartIterFlag -> List(start.toString))
    )
    masterActor = Some(master)
    // forwarded via the DB actor, so training starts only after it handled the messages above
    db ! Forward(master, Init)
    start += iterations
    index = (index + 1) % hybridArgs.size
  }

  override def receive = {
    // Init: the DB actor finished initialization before training.
    // TrainFinished: the current master actor finished its share of iterations.
    // Either way, start the next master actor or report that training is done.
    case Init | TrainFinished =>
      // stop the previous master through the tester, so all its test results are presented
      masterActor.foreach(master => tester ! Forward(master, PoisonPill))
      if (start != maxIter) startNextMaster()
      else context.parent ! TrainFinished
    case Progress =>
      // report zero progress until a master actor exists, otherwise let the master answer
      masterActor.fold(sender ! Progress(0)) { master =>
        master.tell(ProgressIter(start, maxIter), sender)
      }
    case WorkerCreated(worker) =>
      workers += worker
      context.watch(worker)
      masterActor.foreach(_ ! WorkerCreated(worker))
    case Terminated(worker) if workers.contains(worker) => workers -= worker
  }
}
152 |
--------------------------------------------------------------------------------
/src/test/scala/com/htc/speedo/akka/HostMasterActorSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.util.Random
20 |
21 | import akka.actor.{ ActorIdentity, Address, PoisonPill }
22 | import akka.testkit.{ ImplicitSender, TestActorRef, TestProbe }
23 |
24 | import com.twitter.scalding.Args
25 |
26 | import HostMasterActor.IdentifyWorker
27 |
/**
 * Test for host master actor.
 *
 * The host master actor is created through `TestActorRef`, with its factory methods overridden so
 * that the master, DB, worker and tester actors it would create are replaced by test probes
 * (`masterProb`, `dbProb`, `workerProb(index)` and `testProb`).
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
class HostMasterActorSpec extends ActorSpec with ImplicitSender {
  /** Test probe for master actor. */
  lazy val masterProb = TestProbe()

  /**
   * Creates a host master actor whose child actors are test probes. Returns the underlying actor
   * instance, so tests can inspect its state directly.
   * @param worker The value passed as `--worker`.
   * @param resume If true, `--resume` is passed as well.
   */
  def createHostMasterActor(worker: Int = workerNumber, resume: Boolean = false): HostMasterActor = {
    val args = Args(s"--worker $worker" + (if (resume) " --resume" else ""))
    TestActorRef(new HostMasterActor(args) {
      override def createMasterActor = masterProb.ref
      override def createDbActor = dbProb.ref
      override def createWorker(index: Int, address: Address) = workerProb(index).ref
      override def createTester = testProb.ref

      // Do not stop this akka system, but do stop host slave actors
      def stopAkka: Receive = {
        // forward StopAkka to every joined host (the test actor in these tests) instead
        case StopAkka => hostActors.foreach(_ ! StopAkka)
        // override messages that change the actor's behavior
        case message @ ActorIdentity(IdentifyWorker(_), _) =>
          // handle the message as usual
          initOrTrainingState(message)
          // if master is created, actor's behavior has changed.
          // We need to make sure StopAkka message is handled as we wanted.
          if (masterActor.isDefined)
            context.become(stopAkka orElse commonState orElse initOrTrainingState orElse trainingState)
      }
      override def receive: Receive = stopAkka orElse super.receive
    }).underlyingActor
  }

  "Host Master Actor" should {
    "parse configuration correctly" in {
      val actor = createHostMasterActor()
      actor.numWorkers must_== workerNumber
      // nothing is created until hosts join
      actor.hostActors must beEmpty
      actor.workerActors must beEmpty
      actor.masterActor must beNone
      actor.dbActor must_== dbProb.ref
      actor.testActor must_== testProb.ref
      // stop the actor, so it does not affect later tests
      actor.self ! PoisonPill
      ok
    }
    "handle start up correctly" in {
      // shared by the nested examples below, which rely on running in declaration order
      val actor = createHostMasterActor()
      "handle worker join when not enough workers" in {
        (1 to (workerNumber - 1)).foreach { i =>
          // simulate an akka system joins, the address is not used in test
          actor.self ! Join
          // Host actors should contain i senders (i.e. ref)
          actor.hostActors must containTheSameElementsAs(Seq.fill(i)(testActor))
          // workers created
          actor.workerActors must containTheSameElementsAs(workerProb.take(i).map(_.ref))
          // master actor is not created
          actor.masterActor must beNone
          // the key of the newly created worker is cleared in the db
          dbProb.assertMessage(ClearWorkerKey(workerProb(i - 1).ref))
        }
        ok
      }
      "handle last worker join correctly" in {
        // the last worker joins
        actor.self ! Join
        actor.hostActors must containTheSameElementsAs(Seq.fill(workerNumber)(testActor))
        actor.workerActors must containTheSameElementsAs(workerProb.map(_.ref))
        // the master is created once all workers have joined
        actor.masterActor must beSome(masterProb.ref)
        dbProb.assertMessage(ClearWorkerKey(workerProb.last.ref))
        // without --resume, the db is initialized from scratch
        dbProb.assertMessage(Init(false))
      }
      "send init message correctly when --resume is given" in {
        // a separate single-worker actor (shadows the shared one) created with --resume
        val actor = createHostMasterActor(1, true).self
        actor ! Join
        dbProb.assertMessage(Init(true))
        actor ! PoisonPill
        ok
      }
      step { actor.self ! PoisonPill }
    }
    "handle progress query correctly" in {
      val actor = createHostMasterActor(1, true).self
      "when master actor is not started" in {
        actor ! Progress
        // return 0 progress
        this.assertMessage(Progress(0))
      }
      "when master actor is started" in {
        // join one worker to start master
        actor ! Join
        // consume the message in db, so following tests are not affected
        dbProb.assertMessage(Init(true))
        actor ! Progress
        // forward Progress message to master probe, keeping the original sender
        masterProb.assertMessage(Progress)
        masterProb.lastSender must_== testActor
      }
      step { actor ! PoisonPill }
    }
    // The following tests terminate the actors behind the probes, so they must run last
    "handle exceptions correctly" in {
      // shadows the spec-level workerNumber: this group always uses two workers
      val workerNumber = 2
      val actor = createHostMasterActor(workerNumber, true)
      // join two workers to start master
      actor.self ! Join
      actor.self ! Join
      // consume the message in db, so following tests are not affected
      dbProb.assertMessage(Init(true))
      "handle failure to create worker correctly" in {
        // if the Identify of the created worker failed
        actor.self ! ActorIdentity(IdentifyWorker(actor.self), None)
        // stop the whole system (i.e. send messages to all host slaves)
        (1 to workerNumber).foreach(_ => this.assertMessage(StopAkka))
        this.msgAvailable must beFalse
        // if the Identify returned different actor (should NOT happen)
        actor.self ! ActorIdentity(IdentifyWorker(actor.self), Some(masterProb.ref))
        // stop the whole system (i.e. send messages to all host slaves)
        (1 to workerNumber).foreach(_ => this.assertMessage(StopAkka))
        this.msgAvailable must beFalse
      }
      "handle termination of master actor correctly" in {
        masterProb.ref ! PoisonPill
        // stop the whole system (i.e. send messages to all host slaves)
        (1 to workerNumber).foreach(_ => this.assertMessage(StopAkka))
        this.msgAvailable must beFalse
      }
      "handle termination of db actor correctly" in {
        dbProb.ref ! PoisonPill
        // stop the whole system (i.e. send messages to all host slaves)
        (1 to workerNumber).foreach(_ => this.assertMessage(StopAkka))
        this.msgAvailable must beFalse
      }
      "handle termination of worker actor correctly" in {
        // stop the first worker
        workerProb(0).ref ! PoisonPill
        // remove the stopped worker from list (only one worker remaining now)
        actor.workerActors must containTheSameElementsAs(Seq(workerProb(1).ref))
        // should not stop whole system
        this.msgAvailable must beFalse

        // stop the second (last) worker
        workerProb(1).ref ! PoisonPill
        // worker list should be empty
        actor.workerActors must beEmpty
        this.msgAvailable must beFalse
      }
      "handle termination of test actor correctly" in {
        // stop actor before train is finished
        testProb.ref ! PoisonPill
        // nothing should happen
        this.expectNoMsg
        // simulate the train is finished
        actor.self ! TrainFinished
        // stop the whole system (i.e. send messages to all host slaves)
        (1 to workerNumber).foreach(_ => this.assertMessage(StopAkka))
        this.msgAvailable must beFalse
      }
    }
  }
}
187 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/yarn/AppClient.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.yarn
18 |
19 | import java.io.{ File, FileFilter }
20 | import java.net.URLClassLoader
21 | import java.text.SimpleDateFormat
22 | import java.util.{ Date, TimeZone }
23 |
24 | import scala.sys.process._
25 |
26 | import com.twitter.scalding.Args
27 |
28 | import org.apache.hadoop.fs.{ FileSystem, Path }
29 | import org.apache.hadoop.fs.permission.{ FsAction, FsPermission }
30 | import org.apache.hadoop.yarn.api.records.{ Resource, YarnApplicationState }
31 | import org.apache.hadoop.yarn.client.api.YarnClient
32 | import org.apache.hadoop.yarn.conf.YarnConfiguration
33 | import org.apache.hadoop.yarn.util.Records
34 |
/**
 * YARN client which launches an application and its AppMaster. The jars on the user classpath
 * are added to the AppMaster container's local resources.
 *
 * To submit an application, the client needs to provide sufficient information to the
 * ResourceManager to launch the application's first container, the ApplicationMaster. This
 * information is called the [[ApplicationSubmissionContext]], which includes:
 * - Application info: id, name
 * - Queue info: the queue to which the application will be submitted.
 * - Priority info: the priority to be assigned to the application.
 * - User: the user submitting the application
 * - ContainerLaunchContext, see [[buildContainerContext]]
 *
 * @note This is an entry point for the Yarn application.
 *
 * Ways to run the yarn app:
 * - hadoop jar executable.jar --appClass [options]
 * - HADOOP_CLASSPATH=executable.jar hadoop [options]
 *
 * TODO: Support multiple jars
 * @author Wenrui Jiang (roy_jiang@htc.com)
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
object AppClient {

  /** File filter that accepts only regular files whose names end with `.jar`. */
  val jarFileFilter = new FileFilter {
    override def accept(file: File) = file.isFile && file.getName.endsWith(".jar")
  }

  /**
   * Remove arguments from command line.
   *
   * For each key, the first occurrence of `--key` is removed together with the element right
   * after it (assumed to be its single value). Keys that are not present are ignored.
   */
  def removeArg(args: Array[String], keys: String*): Array[String] = {
    keys.foldLeft(args) { (array, key) =>
      val ind = array.indexOf(s"--$key")
      if (ind > -1)
        array.patch(ind, Nil, 2)
      else
        array
    }
  }

  /**
   * Expands one classpath entry into files.
   *
   * A regular file is returned as-is. Otherwise the entry, with any trailing `*` wildcard
   * stripped, is treated as a directory, and the jar files directly inside it are returned.
   * `File.listFiles` returns null when the path is not a directory (e.g. an entry that does
   * not exist on this machine) or on an I/O error. Such entries yield no files instead of
   * failing with a NullPointerException.
   */
  private def expandClasspathEntry(entry: String): Array[File] = {
    val file = new File(entry.stripSuffix("*")).getCanonicalFile
    if (file.isFile) Array(file)
    else Option(file.listFiles(jarFileFilter)).getOrElse(Array.empty[File]).map(_.getCanonicalFile)
  }

  /** launches the AppMaster with the Application */
  def main(args: Array[String]): Unit = {
    // scalding arguments
    val sArgs = Args(args)
    // resources for the container which will launch the ApplicationMaster
    val mMem = sArgs.float(memFlag, appMem)
    val mCores = sArgs.int(coreFlag, appCores)
    val mHeapSize = sArgs.float(heapFlag, mMem * heapProportion)

    // jars on the hadoop classpath, reported by the `hadoop classpath` command (run with an
    // empty HADOOP_CLASSPATH, presumably so user jars passed that way are not counted here)
    val hadoop_classpath = Process(Seq("hadoop", "classpath"), None,
      "HADOOP_CLASSPATH" -> "").!!.split(":").map(_.trim).filterNot(_.isEmpty)
      .flatMap(path => expandClasspathEntry(path))

    // classpath of the running JVM
    val classpath = Thread.currentThread.getContextClassLoader match {
      case url: URLClassLoader => url.getURLs.map(url => url.getFile).toList
      case _ => sArgs.list("classPath")
    }

    // user jars: jars on the runtime classpath that are not already provided by hadoop
    val user_classpath = classpath.map(path => new File(path).getCanonicalFile)
      .filter(jarFileFilter.accept).diff(hadoop_classpath)

    // create yarn configuration
    implicit val conf = new YarnConfiguration()

    // start a yarn client
    val yarnClient = YarnClient.createYarnClient
    yarnClient.init(conf)
    yarnClient.start

    // create yarn application
    val app = yarnClient.createApplication
    val appResponse = app.getNewApplicationResponse
    val appId = appResponse.getApplicationId

    val fs = FileSystem.get(conf)
    // hdfs classpath should contain URI whose scheme and authority
    // identify this FileSystem
    val stagingRoot = new Path(fs.getUri.toString, fs.getHomeDirectory + "/yarnapp.staging")
    val hdfs_classpath_root = new Path(stagingRoot, appId.toString)

    // upload user classpath package to hdfs
    fs.mkdirs(hdfs_classpath_root)
    // make sure the directory can be deleted on exit
    fs.setPermission(stagingRoot, new FsPermission(FsAction.ALL, FsAction.ALL, FsAction.ALL))
    fs.deleteOnExit(hdfs_classpath_root)
    fs.copyFromLocalFile(false, true, user_classpath.map(file =>
      new Path(file.getCanonicalPath)).toArray, hdfs_classpath_root)

    // TODO: handle duplicate jar names
    val hdfsPaths = user_classpath.map(file => new Path(hdfs_classpath_root, file.getName))

    // setup env to get all yarn and hadoop classes in classpath
    val env = collection.mutable.Map[String, String]()
    setUpEnv(env, user_classpath.map(_.getName), mHeapSize)

    // build the [[ContainerLaunchContext]] for the container which will launch the
    // [[ApplicationMaster]]; the application id is appended to the original arguments
    val cmd = launchJVM(
      classOf[AppContainers].getName,
      (args :+ (s" --$AppIdFlag ${appId.getClusterTimestamp} ${appId.getId}")).toList
    )
    val ctx = buildContainerContext(cmd, hdfsPaths, env)

    // Set the resource required by the [[ApplicationMaster]] for this application
    val resource = Records.newRecord(classOf[Resource])
    resource.setMemory((mMem * gB).toInt)
    resource.setVirtualCores(mCores)

    // setup the [[ApplicationSubmissionContext]] which defines all the information needed by
    // the ResourceManager to launch the [[ApplicationMaster]].
    val appContext = app.getApplicationSubmissionContext
    appContext.setApplicationName(appName)
    appContext.setAMContainerSpec(ctx)
    appContext.setResource(resource)
    appContext.setQueue(queueType)

    val clockF = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
    val elapsedF = new SimpleDateFormat("HH:mm:ss.SSS")
    elapsedF.setTimeZone(TimeZone.getTimeZone("GMT+0"))
    // submit the application
    println("submitting application " + yarnClient.submitApplication(appContext) +
      " Clock: " + clockF.format(new Date(System.currentTimeMillis)))

    // init yarn application report
    var appReport = yarnClient.getApplicationReport(appId)
    val start_time = appReport.getStartTime
    var appState = appReport.getYarnApplicationState

    // on JVM exit, kill the application if it is still active and print its last known state
    sys.addShutdownHook {
      if (appState == YarnApplicationState.RUNNING ||
        appState == YarnApplicationState.ACCEPTED)
        yarnClient.killApplication(appId)

      val fTime = appReport.getFinishTime
      println(appId + " last state: " + appState + " Clock: " +
        clockF.format(new Date(if (fTime > 0) fTime else System.currentTimeMillis)))

      yarnClient.stop
    }

    // poll until the application reaches a terminal state
    while (appState != YarnApplicationState.FINISHED &&
      appState != YarnApplicationState.KILLED &&
      appState != YarnApplicationState.FAILED) {
      try {
        // get running status every minute
        Thread.sleep(60000)
        appReport = yarnClient.getApplicationReport(appId)
        appState = appReport.getYarnApplicationState
        println("%s running : %4.2f%%, time elapse(ms): %s".format(
          appId, appReport.getProgress * 100, elapsedF.format(new Date(System.currentTimeMillis - start_time))
        ))
      } catch { case _: InterruptedException => }
    }
  }
}
203 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/MasterActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.collection.mutable.Buffer
20 |
21 | import akka.actor.{ Actor, ActorLogging, ActorRef, Terminated }
22 |
23 | import com.twitter.scalding.Args
24 |
25 | import MasterActor._
26 |
/**
 * The master actor, which is responsible for organizing how workers cooperate. By default, it acts
 * as an asynchronous master actor: every trained result is merged right away and the worker that
 * sent it is started on its next iteration.
 *
 * Subclasses implement other parallel strategies by overriding [[parseTrainResult]],
 * [[workerCreated]], [[workerTerminated]] and [[strategyName]].
 *
 * @param args Command line arguments used to initialize master actor and caffe worker.
 * Parameters required by master actor:
 *  - '--maxIter': The maximum number of caffe runs, must be provided.
 *  - '--startIter': The number of first iteration, default is 0. This is useful if the training is
 *    continued from previous snapshots. Start iteration is only used in logging and setting
 *    iteration and learning rate for caffe worker.
 *  - '--test': The iteration interval to trigger tests. Default is maxIter. If set to 0 (or a
 *    negative value) this actor never triggers periodic tests; the test requested right after
 *    `Init` is still sent, and (per the original documentation) the test actor then skips it.
 * @param dbActor The DB actor that initializes the snapshot and merges worker deltas into it. This
 *  actor sends it `ClearWorkerKey`, `Merge` and `MergeAll`, and routes follow-up messages (the next
 *  `Train`, `Test`, `TrainFinished`) through it wrapped in `Forward`.
 * @param tester The test actor; it is sent `Test` to trigger a test. `TestResult` replies are
 *  handled in [[commonState]] (presumably sent by this actor — confirm).
 * @param workers The worker actors. This buffer is mutated in place when workers join
 *  (`WorkerCreated`) or terminate (`Terminated`).
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
case class MasterActor(args: Args, dbActor: ActorRef, tester: ActorRef, workers: Buffer[ActorRef])
  extends Actor with ActorLogging {
  // output arguments (note: calls the overridable `strategyName` from the constructor)
  log.info("Running {} master actor with args: {}", strategyName, args.toList.mkString(" "))

  // watch all workers, so a `Terminated` message arrives when one of them stops
  workers.foreach(context.watch)

  /** The start iteration, i.e. iterations trained before this master actor. */
  val startIter = args.int(StartIterFlag, 0)
  /** The maximum iterations to invoke workers for training (required flag). */
  val maxIter = args.int(MaxIterFlag)
  /** The iteration interval to trigger test repeatedly; non-positive disables periodic tests. */
  val testInterval = args.int("test", maxIter)

  /**
   * The number of iterations merged by this actor so far (excluding [[startIter]]). Only
   * incremented for `MergeResultSender` and `MergeResultAll` results.
   */
  var count = 0

  /** The name of parallel strategy in this master actor, used in logging. */
  def strategyName: String = "asynchrounous"

  /**
   * Initial behavior: wait for `Init` (logged as the DB actor having initialized the snapshot),
   * then request an initial test, send a `Train` message to every worker and switch to
   * [[trainState]].
   *
   * NOTE(review): `Terminated` and `WorkerCreated` are only handled in [[trainState]]; if either
   * arrives before `Init` it is unhandled — confirm this is intended.
   */
  override def receive = commonState orElse {
    case Init =>
      log.info("DB actor initialized snapshot")
      // run test after initialization
      tester ! Test
      // start training
      workers.foreach(_ ! trainMessage)
      // Change behavior of this actor (Support train and test results)
      context.become(trainState)
  }

  /**
   * The messages that are safe to handle in all states:
   *  - `Progress`: replies with `Progress(count / maxIter)`.
   *  - `ProgressIter(offset, base)`: replies with `Progress((count + offset - maxIter) / base)`;
   *    the numerator is computed in integer arithmetic, only the division is in Float. Presumably
   *    used by a parent to report progress over a different iteration window — TODO confirm.
   *  - `TestResult`: logs the accuracy, or a warning when the test failed.
   */
  val commonState: Receive = {
    case Progress => sender ! Progress(count.toFloat / maxIter)
    case ProgressIter(offset, base) => sender ! Progress((count + offset - maxIter) / base.toFloat)
    case TestResult(Some(accuracy)) => log.info("Current accuracy is {}", accuracy)
    case TestResult(None) => log.warning("Test failed!")
  }

  /**
   * The training behavior: handles `Trained` results (using [[parseTrainResult]] to decide how to
   * merge and whom to start next), worker hot join (`WorkerCreated`) and worker termination, plus
   * everything in [[commonState]].
   */
  def trainState: Receive = commonState orElse {
    case Trained(loss) =>
      // ask the (possibly overridden) strategy how to treat this result
      val ParsedTrainResult(needMerge, needTrain) = parseTrainResult(loss)
      needMerge match {
        case MergeResultNone =>
          log.info("Dropped {} with loss = {}.", sender.path.name, loss)
          // clear worker's suffix key, so we merge nothing
          dbActor ! ClearWorkerKey(sender)
          // we still need to merge here, since we need to make sure the next
          // training can read the latest snapshot, skip warning
          dbActor ! Merge(sender, true)
        case MergeResultSender =>
          count = count + 1
          log.info("{} finished {}th run with loss = {}.", sender.path.name, count + startIter, loss)
          // merge the delta into snapshot
          dbActor ! Merge(sender)
        case MergeResultWait =>
          // nothing is merged and count is not increased; the result is merged later by MergeAll
          log.info("{} finished part of batch with loss = {}, waiting to merge", sender.path.name, loss)
        case MergeResultAll(aveloss) =>
          count = count + 1
          log.info("{} finished {}th run with loss = {} (averaged), {} (raw)", sender.path.name, count + startIter, aveloss, loss)
          dbActor ! MergeAll
      }
      if (count >= maxIter) {
        // notify the parent through dbActor (presumably so it is delivered after the merge queued
        // above has been processed — confirm in DBActor)
        dbActor ! Forward(context.parent, TrainFinished)
        // unwatch workers, so as to avoid DeathPactException (commonState has no Terminated case)
        workers.foreach(context.unwatch)
        // Do not merge further training results
        context.become(commonState)
      } else {
        // start the next training through dbActor, presumably so workers read the merged snapshot
        needTrain match {
          case StartTrainNone => // do nothing
          case StartTrainSender =>
            dbActor ! Forward(sender, trainMessage)
          case StartTrainAll =>
            workers.foreach(dbActor ! Forward(_, trainMessage))
        }
        if (testInterval > 0 && count % testInterval == 0 &&
          // Don't trigger test if we didn't merge deltas
          needMerge != MergeResultWait && needMerge != MergeResultNone) {
          log.info("Start testing")
          dbActor ! Forward(tester, Test)
        }
      }
    case WorkerCreated(worker) =>
      // hot join: register and watch the new worker before the strategy hook runs
      workers += worker
      context.watch(worker)
      workerCreated(worker)
    case Terminated(worker) if workers.contains(worker) =>
      // remove the worker before the strategy hook runs; Terminated for an actor that is not in
      // `workers` is not matched here and is left unhandled
      workers -= worker
      workerTerminated(worker)
  }

  /**
   * Override this function in the subclasses for implementing different parallel strategies.
   * @param loss The loss reported by the worker that sent the `Trained` message.
   * @return [[ParsedTrainResult]] with two options: [[MergeResult]] and [[StartTrain]]. Default is
   * [[MergeResultSender]] and [[StartTrainSender]].
   * @note [[count]] will be increased if merge.
   */
  def parseTrainResult(loss: Double): ParsedTrainResult = ParsedTrainResult()

  /**
   * A handler for new worker actors created after the master actor is started.
   *
   * The default implementation (for asynchronous master actor) just sends a Train message to the
   * new worker directly (not through dbActor).
   * @note The worker is already added to the [[workers]] buffer and watched by the master actor.
   */
  def workerCreated(worker: ActorRef): Unit = worker ! trainMessage

  /**
   * A handler for terminated worker actors during master actor's run.
   *
   * The default implementation (for asynchronous master actor) does nothing.
   * @note The worker is already removed from the [[workers]] buffer.
   */
  def workerTerminated(worker: ActorRef): Unit = {}

  /**
   * A helper function to create a [[Train]] message for the current absolute iteration, i.e.
   * [[count]] plus [[startIter]].
   */
  final def trainMessage: Train = Train(count + startIter)
}
165 |
object MasterActor {
  // ---------------------------------------------------------------------------
  // Merge decision: what to do with the delta trained by the sending worker.
  // ---------------------------------------------------------------------------

  /** Sealed enumeration of the ways a worker's trained result can be merged. */
  sealed trait MergeResult

  /** Drop the sender's result: nothing trained by the sender is merged. */
  case object MergeResultNone extends MergeResult

  /** Merge the sender's result into the snapshot right away. */
  case object MergeResultSender extends MergeResult

  /**
   * Hold the sender's result back for now. It is merged later, together with the results of the
   * other workers, once [[MergeResultAll]] is returned.
   * @note Only meant for the synchronous master actor.
   */
  case object MergeResultWait extends MergeResult

  /**
   * Average the held-back results of all workers and merge them in a single update.
   * @param aveloss The loss averaged over all the workers (reported in the log).
   * @note Only meant for the synchronous master actor.
   */
  case class MergeResultAll(aveloss: Double) extends MergeResult

  // ---------------------------------------------------------------------------
  // Train decision: which workers start their next iteration.
  // ---------------------------------------------------------------------------

  /** Sealed enumeration of the ways workers are started on their next iteration. */
  sealed trait StartTrain

  /** Do not start the sender on a next iteration. */
  case object StartTrainNone extends StartTrain

  /** Start the sender on its next iteration (after the merge, if any). */
  case object StartTrainSender extends StartTrain

  /**
   * Start every worker on its next iteration (after the merge, if any).
   * @note The results preceding this one should use [[StartTrainNone]], so that no worker ends
   * up with more than one queued train message.
   */
  case object StartTrainAll extends StartTrain

  /**
   * Returned by [[MasterActor.parseTrainResult]] to tell the master how to act on the result
   * sent by a worker (the sender). Unless told otherwise, the sender's result is merged and the
   * sender is started on its next iteration.
   * @param merge How the sender's result is merged.
   * @param train Which workers are started next.
   */
  case class ParsedTrainResult(
    merge: MergeResult = MergeResultSender,
    train: StartTrain = StartTrainSender)
}
206 |
--------------------------------------------------------------------------------
/src/test/scala/com/htc/speedo/akka/PSMasterActorSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
/**
 * Test Partial Synchronous Master Actor.
 * Simulate workers[0,1,2] finished the iterations in such a sequence: 0, 0, 1, 2
 * Status after each worker finishing the iteration:
 * worker 0: normal
 * worker 0: normal => catchup
 * worker 1: catchup
 * worker 2: catchup => normal
 *
 * The master is created with `--maxAdvance 2`; with that setting the second result in a row from
 * worker 0 switches the master to catch up mode (presumably maxAdvance bounds how many iterations
 * a worker may run ahead of the others — confirm in PSMasterActor).
 *
 * All examples share one master actor instance (`actor`) and mutate its state, so they rely on
 * running in declaration order (assumed to be configured by `ActorSpec` — TODO confirm). Every
 * example ends with a "clean up" section that restores `workers` / `workerIters` for the next one.
 *
 * @author Wenrui Jiang (roy_jiang@htc.com)
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
class PSMasterActorSpec extends ActorSpec {
  override val maxIter = 30 // runs long enough
  override val testInterval = 50 // no tests triggered

  "Partial Synchronous Master Actor" should {
    // the single master shared by every example below
    val actor = createMasterActor[PSMasterActor]("--maxAdvance 2")
    val master = actor.self
    // iterations merged so far in the first example; the next Train message is expected to carry
    // this value (assumes startIter is 0)
    var iter = 0
    // Skip init and change to training state
    actor.context.become(actor.trainState)
    val sender_0 = actor.workers(0)
    val sender_1 = actor.workers(1)
    val sender_2 = actor.workers(2)
    "work correctly for normal and catch up strategy" in {
      // worker 0 finished train for iteration 0: merged and restarted right away
      master.tell(Trained(.1f), sender_0)
      iter += 1
      dbProb.assertMessage(Merge(sender_0))
      dbProb.assertMessage(Forward(sender_0, Train(iter)))
      actor.catchup must_== false
      actor.catchupWorkers must beEmpty

      // worker 0 finished train for iteration 1: merged, but now too far ahead, so the master
      // waits for the two other workers to catch up
      master.tell(Trained(.1f), sender_0)
      iter += 1
      dbProb.assertMessage(Merge(sender_0))
      workerProb(0).expectNoMsg
      actor.catchup must_== true
      actor.catchupWorkers must containTheSameElementsAs(actor.workers.toBuffer - sender_0)

      // worker 1 finished train for iteration 2: it leaves the catch up list, worker 2 remains
      master.tell(Trained(.1f), sender_1)
      iter += 1
      dbProb.assertMessage(Merge(sender_1))
      workerProb(1).expectNoMsg
      actor.catchup must_== true
      actor.catchupWorkers must containTheSameElementsAs(Seq(sender_2))

      // worker 2 finished train for iteration 3: everyone caught up, all workers are restarted
      master.tell(Trained(.1f), sender_2)
      iter += 1
      dbProb.assertMessage(Merge(sender_2))
      actor.workers.foreach { sender => dbProb.assertMessage(Forward(sender, Train(iter))) }
      actor.catchup must_== false
      actor.catchupWorkers must beEmpty
    }
    "handle worker hot join correctly" in {
      "worker hot join when catch up" in {
        // worker 0 finished train for iteration 0
        master.tell(Trained(.1f), sender_0)
        dbProb.assertMessage(Merge(sender_0))
        dbProb.assertMessage(Forward(sender_0, actor.trainMessage))
        // worker 0 finished train for iteration 1 and it goes to catch up
        master.tell(Trained(.1f), sender_0)
        dbProb.assertMessage(Merge(sender_0))
        actor.catchup must_== true
        val catchupWorkers = actor.workers.toBuffer - sender_0
        actor.catchupWorkers must containTheSameElementsAs(catchupWorkers)

        // hot join one worker (should be different with existing workers); sent as a message so
        // the master's full WorkerCreated handling (add, watch, workerCreated) runs
        master ! WorkerCreated(testActor) // use this test kit as new worker
        // catchup workers are not changed
        actor.catchupWorkers must containTheSameElementsAs(catchupWorkers)
        // total workers contains newly added workers
        actor.workers must containTheSameElementsAs(workerProb.map(_.ref) :+ testActor)

        // worker 1 finished train for iteration 2
        master.tell(Trained(.1f), sender_1)
        dbProb.assertMessage(Merge(sender_1))
        // worker 2 finished train for iteration 3 and back to normal; the new worker is restarted
        // together with the others
        master.tell(Trained(.1f), sender_2)
        dbProb.assertMessage(Merge(sender_2))
        actor.workers.foreach { sender => dbProb.assertMessage(Forward(sender, actor.trainMessage)) }
        actor.catchup must_== false
        actor.catchupWorkers must beEmpty

        // clean up: undo everything the WorkerCreated handling added for the test kit
        actor.workers -= testActor
        actor.workerIters -= testActor
        actor.context.unwatch(testActor)
        ok
      }
      "worker hot join when not catch up" in {
        // just check we have all worker iters set to 1
        actor.workerIters.toSeq must containTheSameElementsAs(actor.workers.map(_ -> 1).toSeq)
        // hot join one worker (should be different with existing workers); the hook is called
        // directly, so `workers` is not modified and is not cleaned up below
        actor.workerCreated(testActor) // use this test kit as new worker
        // the new worker is started right away with the current train message
        this.assertMessage(actor.trainMessage)
        actor.workerIters(testActor) must_== 1

        // clean up
        actor.workerIters -= testActor
        ok
      }
    }
    "handle terminated worker correctly" in {
      "worker terminate when catch up" in {
        "worker terminate before train finished" in {
          // worker 0 finished train for iteration 0
          master.tell(Trained(.1f), sender_0)
          dbProb.assertMessage(Merge(sender_0))
          dbProb.assertMessage(Forward(sender_0, actor.trainMessage))
          // worker 0 finished train for iteration 1 and it goes to catch up
          master.tell(Trained(.1f), sender_0)
          dbProb.assertMessage(Merge(sender_0))
          actor.catchup must_== true
          val catchupWorkers = actor.workers.toBuffer - sender_0
          actor.catchupWorkers must containTheSameElementsAs(catchupWorkers)

          // simulate worker0 is terminated, in the same order as MasterActor's Terminated
          // handler: remove from `workers` first, then call the hook
          actor.workers -= sender_0
          actor.workerTerminated(sender_0)
          actor.workerIters.get(sender_0) must beNone
          // catchup workers are not changed
          actor.catchupWorkers must containTheSameElementsAs(catchupWorkers)
          // worker0 is removed from total workers
          actor.workers must containTheSameElementsAs(Seq(sender_1, sender_2))

          // worker 1 finished train for iteration 2
          master.tell(Trained(.1f), sender_1)
          dbProb.assertMessage(Merge(sender_1))
          // worker 2 finished train for iteration 3 and back to normal
          master.tell(Trained(.1f), sender_2)
          dbProb.assertMessage(Merge(sender_2))
          actor.workers.foreach { sender => dbProb.assertMessage(Forward(sender, actor.trainMessage)) }
          actor.catchup must_== false
          actor.catchupWorkers must beEmpty

          // clean up
          actor.workerIters += sender_0 -> 1
          sender_0 +=: actor.workers // prepend worker0 back to total workers
          ok
        }
        "worker terminate after train finished" in {
          // worker 0 finished train for iteration 0
          master.tell(Trained(.1f), sender_0)
          dbProb.assertMessage(Merge(sender_0))
          dbProb.assertMessage(Forward(sender_0, actor.trainMessage))
          // worker 0 finished train for iteration 1 and it goes to catch up
          master.tell(Trained(.1f), sender_0)
          dbProb.assertMessage(Merge(sender_0))
          actor.catchup must_== true
          val catchupWorkers = actor.workers.toBuffer - sender_0
          actor.catchupWorkers must containTheSameElementsAs(catchupWorkers)

          // simulate worker2 is terminated
          actor.workers -= sender_2
          actor.workerTerminated(sender_2)
          actor.workerIters.get(sender_2) must beNone
          // worker2 should be removed from catchup workers
          actor.catchupWorkers must containTheSameElementsAs(Seq(sender_1))
          // worker2 is removed from total workers
          actor.workers must containTheSameElementsAs(Seq(sender_0, sender_1))
          // we are still catching up (waiting for worker1)
          actor.catchup must_== true

          // simulate worker1 is terminated
          actor.workers -= sender_1
          actor.workerTerminated(sender_1)
          actor.workerIters.get(sender_1) must beNone
          // worker1 is removed from total workers
          actor.workers must_== Seq(sender_0)
          // we should now back to normal, and the only remaining worker is restarted
          actor.catchup must_== false
          actor.catchupWorkers must beEmpty
          dbProb.assertMessage(Forward(sender_0, actor.trainMessage))

          // clean up
          actor.workerIters ++= Seq(sender_1 -> 1, sender_2 -> 1)
          actor.workers ++= Seq(sender_1, sender_2)
          ok
        }
      }
      "worker terminate when not catch up" in {
        // terminate a worker
        // NOTE(review): unlike the examples above, `workers` is not updated before the hook is
        // called; only `workerIters` is checked and restored
        actor.workerTerminated(sender_0)
        actor.workerIters.get(sender_0) must beNone

        // clean up
        actor.workerIters += sender_0 -> 1
        ok
      }
    }
  }
}
216 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SpeeDo - Parallelizing Stochastic Gradient Descent for Deep Convolutional Neural Network
2 |
3 | ## Introduction
4 |
5 | Convolutional Neural Networks (CNNs) have achieved breakthrough results on many machine learning tasks. However, training CNNs is computationally intensive. When the size of training data is large and the depth of CNNs is high, as typically required for attaining high classification accuracy, training a model can take days and even weeks. So we propose [SpeeDO](http://learningsys.org/papers/LearningSys_2015_paper_13.pdf) (for Open DEEP learning System in backward order), a deep learning system designed for off-the-shelf hardware. SpeeDO can be easily deployed, scaled and maintained in a cloud environment, such as AWS EC2 cloud, Google GCE, and Microsoft Azure.
6 |
7 | In our implementation, we support five distributed SGD models to speed up training:
8 |
9 | * Synchronous SGD
10 | * Asynchronous SGD
11 | * Partially Synchronous SGD
12 | * Weed-Out SGD
13 | * Elastic Averaging SGD
14 |
15 | Please cite [SpeeDO](http://learningsys.org/papers/LearningSys_2015_paper_13.pdf) in your publications if it helps your research:
16 |
17 | @article{zhengspeedo,
18 | title={SpeeDO: Parallelizing Stochastic Gradient Descent for Deep Convolutional Neural Network},
19 | author={Zheng, Zhongyang and Jiang, Wenrui and Wu, Gang and Chang, Edward Y}
20 | }
21 |
22 | ## Architecture
23 |
24 | SpeeDO takes advantage of many existing solutions from the open-source community. The data flow of SpeeDO is shown below:
25 |
26 | 
27 |
28 | SpeeDO mainly contains these components:
29 |
30 | * [Caffe](http://caffe.berkeleyvision.org/) (required)
31 | * [Redis](http://redis.io) (required)
32 | * [Akka](http://akka.io) (required)
33 | * [Yarn](https://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/YARN.html) [optional]
34 | * [HDFS](https://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-hdfs/HdfsUserGuide.html) [optional]
35 |
36 | These components must be deployed before running distributed training.
37 |
38 | # Deploy and Run
39 |
40 | ## Pre-requisite
41 |
42 | * JDK 1.7+
43 | * [Redis Server](http://redis.io/download)
44 | * Ubuntu 12.04+ (not tested on other Linux distributions)
45 | * Network connection: You need to connect to maven and ivy repositories when compiling the demo
46 |
47 | SpeeDO runs in a **Master-Slaves (Worker)** architecture. To avoid launching processes manually, and to distribute input data to the master and worker nodes, we can use YARN and HDFS. Below, we provide instructions to run SpeeDO in both scenarios.
48 |
49 | | Configuration | YARN present | HDFS present |
50 | | ------------- | ------------- | ------------- |
51 | | A | N | N |
52 | | B | Y | Y |
53 |
54 | **NOTE**
55 |
56 | i. YARN is used for scheduling node resources. If YARN is not present, we can run the Master and Worker processes manually.
57 |
58 | ii. HDFS is used for storing training data and the Caffe network definition. If HDFS is not present, we can use a shared file system (like NFS) or manually copy these files to each node.
59 |
60 | We provide the steps to run configuration **A** and **B** here:
61 | * For configuration **A**, we manually deploy and run on all nodes.
62 | * For configuration **B**, we use [cloudera](http://www.cloudera.com/documentation/manager/5-1-x/Cloudera-Manager-Installation-Guide/Cloudera-Manager-Installation-Guide.html) (offering us both YARN and HDFS) to deploy and run SpeeDO.
63 |
64 | ## A. Deploy and run SpeeDO without YARN and HDFS
65 |
66 | We provide TWO methods here: 1) Docker, 2) Manual (step by step)
67 |
68 | ## 1. Quick Start (via Docker)
69 |
70 | ### Step.0 Pull image
71 | Pull the speedo image (bundled with Caffe and all its dependency libraries):
72 | ```bash
73 | docker pull obdg/speedo:latest
74 | ```
75 |
76 | ### Step.1 Run containers on cluster
77 | The following example runs 1000 iterations asynchronously using 1 Master with 3 Workers (4 cluster nodes).
78 |
79 | #### Master
80 | Launch the master container on your master node (default: asynchronous mode with 3 workers):
81 | ```bash
82 | docker run -d --name=speedo-master --net=host obdg/speedo
83 | ```
84 |
85 | **Or** run the master actor in EASGD mode with 3 workers:
86 | ```bash
87 | docker run -d --name=speedo-master --net=host obdg/speedo master 3 --test 0 --maxIter 1000 --movingRate 0.5
88 | ```
89 |
90 | Please replace `master-address` with the master node's IP.
91 |
92 | **NOTE**
93 | The Redis service is started automatically when the master container is launched.
94 |
95 | #### Worker
96 | Launch 3 worker containers on different worker nodes:
97 | ```bash
98 | docker run -d --name=speedo-worker --net=host obdg/speedo worker
99 | ```
100 |
101 | Please replace `master-address` with the master node's IP, and `worker-address` with the current worker node's IP.
102 |
103 | ## 2. Manually (Step by Step)
104 |
105 | ### Step.0 Prerequisites
106 | Install on each node (Master and Workers):
107 | 1. JDK 1.7+
108 | 2. Redis Server
109 | 3. Clone SpeeDO and Caffe source from our github repo
110 |
111 | Please use
112 | ```
113 | git clone --recursive git@github.com:obdg/speedo.git # SpeeDO and caffe
114 | ```
115 |
116 | ### Step.1 Install caffe and its dependencies
117 | Install [speedo/caffe](https://github.com/obdg/caffe) and all its dependencies on each node; please refer to section **A. Manually install on all cluster nodes** in the [speedo/caffe install guide](https://github.com/obdg/caffe).
118 |
119 | ### Step.2 Prepare Input Data to run under Caffe
120 |
121 | **NOTE**: We prefer to use **datumfile** format for SpeeDO ( see [caffe-pullrequest-2193](https://github.com/BVLC/caffe/pull/2193) ) instead of the default leveldb/lmdb format during training in Caffe to solve the memory usage problem ( refer to [caffe-issues-1377](https://github.com/BVLC/caffe/issues/1377)).
122 |
123 | The input data required by Caffe includes:
124 | * solver definition
125 | * network definition
126 | * training datasets
127 | * testing datasets
128 | * mean values
129 |
130 | In this example, we train on the cifar10 dataset and generate the `training datasets` and `testing datasets` in datumfile format:
131 | ```bash
132 | cd caffe
133 | ./data/cifar10/get_cifar10.sh # download cifar dataset
134 | ./examples/speedo/create_cifar10.sh # create protobuf file - in datumfile instead of leveldb/lmdb format
135 | ```
136 |
137 | The `solver definition`, `network definition` and `mean values` for cifar10, using the datumfile format, are provided in examples/speedo.
138 |
139 | > If you want to produce these files manually, please follow the steps below (modify all paths in the network definitions if needed):
140 | ```bash
141 | sed -i "s/examples\/cifar10\/mean.binaryproto/mean.binaryproto/g" cifar10_full_train_test.prototxt
142 | sed -i "s/examples\/cifar10\/cifar10_train_lmdb/cifar10_train_datumfile/g" cifar10_full_train_test.prototxt
143 | sed -i "s/examples\/cifar10\/cifar10_test_lmdb/cifar10_test_datumfile/g" cifar10_full_train_test.prototxt
144 | sed -i "s/backend: LMDB/backend: DATUMFILE/g" cifar10_full_train_test.prototxt
145 | sed -i "17i\ rand_skip: 50000" cifar10_full_train_test.prototxt
146 | sed -i "s/examples\/cifar10\/cifar10_full_train_test.prototxt/cifar10_full_train_test.prototxt/g" cifar10_full_solver.prototxt
147 | ```
148 |
149 | Finally, put the data in the same location (like /tmp/caffe/cifar10) on **all Master and Worker nodes**. You can do that with [Ansible](https://www.ansible.com/) or simply scp the files to the right location.
150 |
151 | ### Step.3 Training under SpeeDO
152 |
153 | SpeeDO uses a Master + Worker architecture for distributed training (please refer to our paper for detailed information). We need to start the Master node and the Worker nodes as shown below.
154 |
155 | #### Compile bundle jar
156 | On each master and worker node, run:
157 | ```bash
158 | git clone git@github.com:obdg/speedo.git # if not done yet
159 | cd speedo
160 | ./sbt akka:assembly
161 | ```
162 |
163 | #### Run Master and Worker process
164 |
165 | The following example will run 1000 iterations asynchronously using 1 Master with 3 workers ( 4 cluster nodes ).
166 |
167 | ##### Master
168 | Launch master process on your master node:
169 | ```bash
170 | JAVA_LIBRARY_PATH=$JAVA_LIBRARY_PATH:/usr/lib java -cp target/scala-2.11/SpeeDO-akka-1.0.jar -Xmx2G com.htc.speedo.akka.AkkaUtil --solver /absolute_path/to/cifar10_full_solver.prototxt --worker 3 --redis <redis-address> --test 500 --maxIter 1000 --host <master-address> 2> /dev/null
171 | ```
172 |
173 | Please replace `redis-address` with the Redis server location, and `master-address` with the master node's IP/hostname.
174 |
175 | This should output something like:
176 |
177 | [INFO] [03/03/2016 15:07:41.626] [main] [Remoting] Starting remoting
178 | [INFO] [03/03/2016 15:07:41.761] [main] [Remoting] Remoting started; listening on addresses :[akka.tcp://SpeeDO@cloud-master:56126]
179 | [INFO] [03/03/2016 15:07:41.763] [main] [Remoting] Remoting now listens on addresses: [akka.tcp://SpeeDO@cloud-master:56126]
180 | [INFO] [03/03/2016 15:07:41.777] [SpeeDO-akka.actor.default-dispatcher-3] [akka.tcp://SpeeDO@cloud-master:56126/user/host] Waiting for 3 workers to join.
181 |
182 | ##### Worker
183 | Launch 3 worker processes on the worker nodes:
184 | ```bash
185 | JAVA_LIBRARY_PATH=$JAVA_LIBRARY_PATH:/usr/lib java -cp target/scala-2.11/SpeeDO-akka-1.0.jar -Xmx2G com.htc.speedo.akka.AkkaUtil --host <worker-address> --master <masteractor-addr> 2> /dev/null
186 | ```
187 |
188 | Please replace `worker-address` with the worker's IP/hostname, and `masteractor-addr` with the master actor address.
189 |
190 | The format of the master actor address is **`akka.tcp://SpeeDO@cloud-master:56126/user/host`**, where cloud-master is the hostname of the master node, and 56126 is the TCP port the Akka actor system listens on. Since the port is random by default, the address can vary between runs. You can also use a fixed port by passing a `--port <port>` command line argument when starting the Master.
191 |
192 |
193 |
194 | ## B. Deploy and run SpeeDO by cloudera
195 | To try a Cloudera-based solution for SpeeDO, please refer to [Run SpeeDO on Yarn & HDFS Cluster](https://github.com/obdg/speedo/blob/master/README_YARN.md).
196 |
197 | ## Experiments on AWS
198 |
199 | The Cifar10 dataset is used to validate all parallel implementations on a CPU cluster with four 8-core instances:
200 |
201 | 
202 |
203 | Training [GoogLeNet](http://www.cs.unc.edu/~wliu/papers/GoogLeNet.pdf) on a GPU cluster with different parallel implementations:
204 |
205 | 
206 |
207 | EASGD achieves the best speedup among our parallel implementations, and its parameters have a great impact on the speedup.
208 |
209 | 
210 |
211 | ## Authors
212 |
213 | * [Zhongyang Zheng](https://github.com/zyzheng)
214 | * [Wenrui Jiang](https://github.com/wenruij)
215 | * [Gang Wu](https://github.com/simonandluna)
216 |
217 | ## Supervisor
218 | * [Edward Y. Chang](http://infolab.stanford.edu/~echang/)
219 |
220 | ## License
221 |
222 | Copyright 2016 HTC Corporation
223 |
224 | Licensed under the Apache License, Version 2.0: http://www.apache.org/licenses/LICENSE-2.0
225 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/yarn/AppContainers.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.yarn
18 |
19 | import java.util.{ List => jList }
20 | import java.util.concurrent.atomic.AtomicInteger
21 |
22 | import scala.collection.JavaConverters._
23 | import scala.collection.mutable.{ Map => MutableMap }
24 |
25 | import com.twitter.scalding.Args
26 |
27 | import org.apache.hadoop.fs.{ FileSystem, Path }
28 | import org.apache.hadoop.yarn.api.records._
29 | import org.apache.hadoop.yarn.client.api.YarnClient
30 | import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
31 | import org.apache.hadoop.yarn.client.api.async.{ AMRMClientAsync, NMClientAsync }
32 | import org.apache.hadoop.yarn.conf.YarnConfiguration
33 | import org.apache.hadoop.yarn.server.utils.BuilderUtils
34 | import org.apache.hadoop.yarn.util.Records
35 | import org.slf4j.LoggerFactory
36 |
object AppContainers {
  /**
   * Entry point of the application master process.
   *
   * Builds an [[AppContainers]] from the command line, runs it to completion and terminates the
   * JVM with exit code 0 when the application succeeded, or 2 when it failed.
   */
  def main(args: Array[String]): Unit = {
    val succeeded = new AppContainers(args).run
    System.exit(if (succeeded) 0 else 2)
  }
}
44 |
/**
 * Allocates and launches the containers that run the application job.
 *
 * For each allocated container, the ApplicationMaster sets up the launch context via
 * ContainerLaunchContext (the allocated container id, local resources required by the
 * executable, the environment for the executable, the commands to execute, etc.) and submits a
 * StartContainerRequest to the ContainerManagementProtocol to launch and execute the defined
 * commands on the allocated container.
 *
 * Constructing an instance already has side effects: it starts the YARN client, the AM-RM async
 * client and the NM async client, lists the HDFS staging directory and instantiates the
 * application class by reflection. [[run]] then requests the containers and blocks until the
 * application finishes.
 *
 * @param gArgs command line arguments. Must contain the application id flag (cluster timestamp
 * followed by the id) and the application class flag; the memory, cores and heap flags are
 * optional.
 * @author Wenrui Jiang (roy_jiang@htc.com)
 */
class AppContainers(gArgs: Array[String]) {
  @transient private val logger = LoggerFactory.getLogger(classOf[AppContainers])
  val sArgs = Args(gArgs)
  /** Combination of clusterTimestamp and id identifies the application. */
  val (clusterTimestamp, id) = sArgs.list(AppIdFlag) match {
    case List(timestamp, appIndex) => (timestamp.toLong, appIndex.toInt)
    // `case _` instead of the erased (unchecked) `List[String]` type pattern; same coverage.
    case _ => throw new IllegalArgumentException("wrong args for application id...")
  }
  /** Arguments forwarded to the application: the flags consumed by this class are removed. */
  val newArgs = new Args(sArgs.m.-(AppClassFlag, coreFlag, memFlag, AppIdFlag, heapFlag))
  val appId = BuilderUtils.newApplicationId(clusterTimestamp, id)

  /** Class name of the [[YarnApp]] to instantiate by reflection. */
  val appClass = sArgs.required(AppClassFlag)
  /** Memory for each container; scaled by `gB` when building the request (presumably GB). */
  val cMem = sArgs.float(memFlag, containerMem)
  /** Virtual cores for each container. */
  val cCores = sArgs.int(coreFlag, containerCores)
  /** Heap size passed to the container environment; defaults to `heapProportion` of `cMem`. */
  val cHeapSize = sArgs.float(heapFlag, cMem * heapProportion)

  implicit val conf = new YarnConfiguration()
  /** start a yarn Client */
  val yarnClient = YarnClient.createYarnClient
  yarnClient.init(conf)
  yarnClient.start

  /**
   * Count of running containers requested from the RM.
   * Needed as once requested, we should not request for containers again.
   * Only request for more if the original requirement changes.
   */
  private val numRunningContainers: AtomicInteger = new AtomicInteger()
  /** Count of total containers already requested from the RM. */
  private val totalRequestedContainers: AtomicInteger = new AtomicInteger()
  /**
   * Completion state: `None` while the application is running, `Some(success)` once it is done.
   * It is written from the [[RMCallbackHandler]] callbacks and polled by [[run]] on the calling
   * thread, so it is volatile to guarantee the polling loop observes the update.
   */
  @volatile private var done: Option[Boolean] = None

  /** get application report */
  val appReport = yarnClient.getApplicationReport(appId)

  /** HDFS staging directory of this application; every file in it is shipped to the containers. */
  val hdfs_classpath_root = new Path("/user/" + appReport.getUser + "/yarnapp.staging/" + appId.toString)
  val hdfsPaths = FileSystem.get(conf).listStatus(hdfs_classpath_root).map(_.getPath).toList

  /** get instance of this application */
  val yarnApp = ReflectionUtils.getInstaceFrom(appClass, newArgs).get.asInstanceOf[YarnApp]
  /** total number of containers needed by the application */
  val numTotalContainers: Int = yarnApp.getSize(gArgs)
  /** max number of containers allowed to be requested over the application's lifetime */
  val maxRequest: Int = numTotalContainers * 2

  /** Handle to communicate with the Resource Manager (heartbeat interval 1000 ms). */
  val amRMClientAsync: AMRMClientAsync[ContainerRequest] = AMRMClientAsync.createAMRMClientAsync(1000, RMCallbackHandler)
  amRMClientAsync.init(conf)
  amRMClientAsync.start

  /** Handle to communicate with the Node Manager. NOTE(review): no callback handler is given. */
  val nmClientAsync: NMClientAsync = NMClientAsync.createNMClientAsync(null)
  nmClientAsync.init(conf)
  nmClientAsync.start

  /**
   * Registers with the RM, requests the containers the application needs and blocks until the
   * application completes. Containers are launched from [[RMCallbackHandler]] once allocated.
   *
   * @return true if the application finished successfully, false otherwise. The final status is
   * also reported to the RM when unregistering.
   */
  def run: Boolean = {
    // Register this application master client with the resource manager.
    val response = amRMClientAsync.registerApplicationMaster("", 0, "")

    val previousRunningContainers = response.getContainersFromPreviousAttempts
    logger.debug(appReport.getCurrentApplicationAttemptId.toString + " received " + previousRunningContainers.size() +
      " previous attempts' running containers on AM registration.")

    val containerAsk = setupContainerAskForRM(cMem, cCores)
    val numTotalContainersToRequest = numTotalContainers - previousRunningContainers.size
    // Request the remaining containers from the ResourceManager (empty range if none are needed).
    (1 to numTotalContainersToRequest).foreach { _ => amRMClientAsync.addContainerRequest(containerAsk) }
    numRunningContainers.set(numTotalContainers)
    totalRequestedContainers.set(numTotalContainers)

    // Wait for completion; `done` is set by the RM callbacks.
    while (done.isEmpty) {
      try { Thread.sleep(waitTime) } catch { case _: InterruptedException => }
    }

    // When the application completes, it should stop all running containers
    logger.info("Application completed. Stopping running containers")
    nmClientAsync.stop

    // When the application completes, it should send a finish application signal to the RM
    val success = done.getOrElse(false)
    val appStatus = if (success) FinalApplicationStatus.SUCCEEDED else FinalApplicationStatus.FAILED

    amRMClientAsync.unregisterApplicationMaster(appStatus, "", null)
    amRMClientAsync.stop
    yarnClient.stop
    success
  }

  /**
   * Setup the request that will be sent to the RM for the container ask.
   *
   * @param cMem memory per container, multiplied by `gB` for the YARN resource
   * @param cCores virtual cores per container
   * @return the container request (no node/rack constraints) at priority `priorityLevel`
   */
  def setupContainerAskForRM(cMem: Float, cCores: Int): ContainerRequest = {
    val priority = Records.newRecord(classOf[Priority])
    priority.setPriority(priorityLevel)

    // resources needed for each container
    val resource = Records.newRecord(classOf[Resource])
    resource.setMemory((cMem * gB).toInt)
    resource.setVirtualCores(cCores)

    new ContainerRequest(resource, null, null, priority)
  }

  /** Handles RM callbacks: container completion/allocation, shutdown, errors and progress. */
  object RMCallbackHandler extends AMRMClientAsync.CallbackHandler {
    /** Master role of the application; supplies the slave main class, args and environment. */
    val masterRole = yarnApp.getApp(gArgs, getHostName)

    /**
     * Logs every completed container and re-requests containers that were lost (aborted, killed
     * for exceeding memory, or exited with a positive status). If the total number of requested
     * containers would exceed [[maxRequest]], the application is shut down instead.
     */
    override def onContainersCompleted(statuses: jList[ContainerStatus]) = {
      statuses.asScala.foreach { containerStatus =>
        logger.debug("Get container status for containerID=" +
          containerStatus.getContainerId + ", state=" +
          containerStatus.getState + ", exitStatus=" +
          containerStatus.getExitStatus + ", diagnostics=" +
          containerStatus.getDiagnostics)

        containerStatus.getExitStatus match {
          case ContainerExitStatus.SUCCESS =>
            // Container completed successfully
            logger.info("Container completed successfully." + ", containerId=" + containerStatus.getContainerId)
          case ContainerExitStatus.ABORTED | ContainerExitStatus.KILLED_EXCEEDED_PMEM | ContainerExitStatus.KILLED_EXCEEDED_VMEM =>
            logger.warn("container aborted or killed(like oom problem)," + "will be recovered later!")
            // The container was killed by the framework (possibly preempted), so retry; it does
            // not need to be released as that is done by the RM.
            numRunningContainers.decrementAndGet
          // TODO: test for system.exit
          case status: Int if status > 0 =>
            logger.warn("container manually killed, will be recovered later!")
            numRunningContainers.decrementAndGet
          case _ =>
            logger.warn("container failed for unkown reason")
        }
      }
      // ask for more containers if any failed
      val askCount = numTotalContainers - numRunningContainers.get
      numRunningContainers.addAndGet(askCount)
      totalRequestedContainers.addAndGet(askCount)
      if (totalRequestedContainers.get > maxRequest) {
        logger.warn("number of total containers requesting exceeds the threshold")
        onShutdownRequest
      } else if (askCount > 0) {
        // The request is identical for every missing container, so build it once.
        val containerAsk = setupContainerAskForRM(cMem, cCores)
        (1 to askCount).foreach { _ =>
          logger.debug("Append new container request!!")
          amRMClientAsync.addContainerRequest(containerAsk)
        }
      }
    }

    /**
     * Removes the satisfied requests and starts the slave role of the application on every
     * newly allocated container, shipping the HDFS staging files as local resources.
     */
    override def onContainersAllocated(containers: jList[Container]) = {
      logger.debug("Allocated containers " + containers.asScala.map(
        c => c.getId.toString + " on " + c.getNodeId.getHost
      ).mkString(" "))
      // remove previous container requesting cache
      containers.asScala.foreach { c =>
        val ask = new ContainerRequest(c.getResource, null, null, c.getPriority)
        amRMClientAsync.removeContainerRequest(ask)
      }

      // setup application environment
      val env = MutableMap(masterRole.appEnv.toSeq: _*)
      setUpEnv(env, hdfsPaths.map(_.getName), cHeapSize)
      // start slave role on corresponding container
      containers.asScala.foreach { container =>
        val cmd = launchJVM(masterRole.slaveMain, masterRole.slaveArgs(container.getNodeId.getHost))
        val ctx = buildContainerContext(cmd, hdfsPaths, env)
        nmClientAsync.startContainerAsync(container, ctx)
      }
    }

    /** Marks the application as failed (e.g. the master is out of sync) and stops the AM-RM client. */
    override def onShutdownRequest() = {
      done = Some(false)
      amRMClientAsync.stop
    }

    override def onNodesUpdated(updatedNodes: jList[NodeReport]) = {}

    /**
     * Reports the master role's progress, clamped to [0, 1]. When the role reports it has
     * finished, the completion state is recorded so that [[run]] stops waiting.
     */
    override def getProgress(): Float = masterRole.action() match {
      case InProgress(progress) => Math.min(Math.max(0, progress), 1)
      case Finished(success) =>
        done = Some(success)
        1f
    }

    /** Logs the error through the class logger and shuts the application down as failed. */
    override def onError(e: Throwable) = {
      logger.error("Error reported by the resource manager client, shutting down", e)
      onShutdownRequest
    }
  }
}
254 |
--------------------------------------------------------------------------------
/src/test/scala/com/htc/speedo/akka/WeedOutMasterActorSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | /**
20 | * Test Weed-Out Master Actor: weeding out stale deltas, worker hot join and worker termination.
21 | * Simulates workers [0, 1, 2] finishing iterations in the sequence: 0, 1, 1, 0, 2.
22 | * The weed-out window is 3 (`--weedout 3`), so:
23 | * - when worker 0 finishes at iteration 4, its delta is kept (merged normally);
24 | * - when worker 2 finishes at iteration 5, its delta is dropped and the iteration count does not advance.
25 | *
26 | * @author Wenrui Jiang (roy_jiang@htc.com)
27 | * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
28 | */
29 | class WeedOutMasterActorSpec extends ActorSpec {
30 | override val testInterval = 6
31 |
32 | "Drop Master Actor" should {
33 | val actor = createMasterActor[WeedOutMasterActor]("--weedout 3")
34 | val master = actor.self
35 | var iter = 0
36 | // Skip initialization and switch directly to the training state
37 | actor.context.become(actor.trainState)
38 | val sender_0 = actor.workers(0)
39 | val sender_1 = actor.workers(1)
40 | val sender_2 = actor.workers(2)
41 | "discard deltas for delayed workers" in {
42 | // worker 0 finishes training at iteration 1: delta merged, next Train forwarded
43 | master.tell(Trained(.1f), sender_0)
44 | iter += 1
45 | dbProb.assertMessage(Merge(sender_0))
46 | dbProb.assertMessage(Forward(sender_0, Train(iter)))
47 | actor.lastUpdates must_== Seq(sender_0)
48 | actor.count must_== iter
49 |
50 | // worker 1 finishes training at iteration 2
51 | master.tell(Trained(.1f), sender_1)
52 | iter += 1
53 | dbProb.assertMessage(Merge(sender_1))
54 | dbProb.assertMessage(Forward(sender_1, Train(iter)))
55 | actor.lastUpdates must_== Seq(sender_0, sender_1)
56 | actor.count must_== iter
57 |
58 | // worker 1 finishes training again at iteration 3
59 | master.tell(Trained(.1f), sender_1)
60 | iter += 1
61 | dbProb.assertMessage(Merge(sender_1))
62 | dbProb.assertMessage(Forward(sender_1, Train(iter)))
63 | actor.lastUpdates must_== Seq(sender_0, sender_1, sender_1)
64 | actor.count must_== iter
65 |
66 | // worker 0 finishes training at iteration 4: still within the window of 3, so its delta is kept
67 | master.tell(Trained(.1f), sender_0)
68 | iter += 1
69 | dbProb.assertMessage(Merge(sender_0))
70 | dbProb.assertMessage(Forward(sender_0, Train(iter)))
71 | actor.lastUpdates must_== Seq(sender_0, sender_1, sender_1)
72 | actor.count must_== iter
73 |
74 | // worker 2 finishes training at iteration 5: its delta is dropped, so iter is not incremented
75 | master.tell(Trained(.1f), sender_2)
76 | dbProb.assertMessage(ClearWorkerKey(sender_2))
77 | dbProb.assertMessage(Merge(sender_2, true))
78 | dbProb.assertMessage(Forward(sender_2, Train(iter)))
79 | actor.lastUpdates must_== Seq(sender_0, sender_2, sender_1)
80 | actor.count must_== iter
81 | }
82 | "handle worker hot join correctly" in {
83 | "worker hot join at first few iterations" in {
84 | "worker hot join when lastUpdates is empty" in {
85 | actor.updateIndex = 0
86 | actor.lastUpdates.clear
87 | actor.maxInterval = workerNumber
88 | // hot join one worker (it must differ from the existing workers)
89 | actor.workerCreated(testActor) // use this test kit as new worker
90 | // the joined actor should be added as the newest updater
91 | actor.lastUpdates must_== Seq(testActor)
92 | actor.updateIndex must_== actor.lastUpdates.size
93 | actor.maxInterval must_== workerNumber + 1
94 | // start training on new worker
95 | this.assertMessage(actor.trainMessage)
96 | }
97 | "worker hot join when lastUpdates has 1 updater" in {
98 | actor.updateIndex = 1
99 | actor.lastUpdates.clear
100 | actor.lastUpdates += sender_0
101 | actor.maxInterval = workerNumber
102 | // hot join one worker (it must differ from the existing workers)
103 | actor.workerCreated(testActor) // use this test kit as new worker
104 | // the joined actor should be added as the newest updater
105 | actor.lastUpdates must_== Seq(sender_0, testActor)
106 | actor.updateIndex must_== actor.lastUpdates.size
107 | actor.maxInterval must_== workerNumber + 1
108 | // start training on new worker
109 | this.assertMessage(actor.trainMessage)
110 | }
111 | "worker hot join when lastUpdates has 2 updater" in {
112 | actor.updateIndex = 2
113 | actor.lastUpdates.clear
114 | actor.lastUpdates ++= Seq(sender_0, sender_1)
115 | actor.maxInterval = workerNumber
116 | // hot join one worker (it must differ from the existing workers)
117 | actor.workerCreated(testActor) // use this test kit as new worker
118 | // the joined actor should be added as the newest updater
119 | actor.lastUpdates must_== Seq(sender_0, sender_1, testActor)
120 | actor.updateIndex must_== actor.lastUpdates.size
121 | actor.maxInterval must_== workerNumber + 1
122 | // start training on new worker
123 | this.assertMessage(actor.trainMessage)
124 | }
125 | }
126 | "worker hot join when updateIndex = 0" in {
127 | // order of old updaters (oldest first, starting at updateIndex) is 0, 1, 2
128 | actor.updateIndex = 0
129 | actor.lastUpdates.clear
130 | actor.lastUpdates ++= Seq(sender_0, sender_1, sender_2)
131 | actor.maxInterval = actor.lastUpdates.size
132 | // hot join one worker (it must differ from the existing workers)
133 | actor.workerCreated(testActor) // use this test kit as new worker
134 | // the joined actor should be added as the newest updater
135 | actor.lastUpdates must_== Seq(testActor, sender_0, sender_1, sender_2)
136 | // the oldest updater is not affected
137 | actor.lastUpdates(actor.updateIndex) must_== sender_0
138 | actor.maxInterval must_== actor.lastUpdates.size
139 | // start training on new worker
140 | this.assertMessage(actor.trainMessage)
141 | }
142 | "worker hot join when updateIndex = 1" in {
143 | // order of old updaters (oldest first, starting at updateIndex) is 1, 2, 0
144 | actor.updateIndex = 1
145 | actor.lastUpdates.clear
146 | actor.lastUpdates ++= Seq(sender_0, sender_1, sender_2)
147 | actor.maxInterval = actor.lastUpdates.size
148 | // hot join one worker (it must differ from the existing workers)
149 | actor.workerCreated(testActor) // use this test kit as new worker
150 | // the joined actor should be added as the newest updater
151 | actor.lastUpdates must_== Seq(sender_0, testActor, sender_1, sender_2)
152 | // the oldest updater is not affected
153 | actor.lastUpdates(actor.updateIndex) must_== sender_1
154 | actor.maxInterval must_== actor.lastUpdates.size
155 | // start training on new worker
156 | this.assertMessage(actor.trainMessage)
157 | }
158 | "worker hot join when updateIndex = 2" in {
159 | // order of old updaters (oldest first, starting at updateIndex) is 2, 0, 1
160 | actor.updateIndex = 2
161 | actor.lastUpdates.clear
162 | actor.lastUpdates ++= Seq(sender_0, sender_1, sender_2)
163 | actor.maxInterval = actor.lastUpdates.size
164 | // hot join one worker (it must differ from the existing workers)
165 | actor.workerCreated(testActor) // use this test kit as new worker
166 | // the joined actor should be added as the newest updater
167 | actor.lastUpdates must_== Seq(sender_0, sender_1, testActor, sender_2)
168 | // the oldest updater is not affected
169 | actor.lastUpdates(actor.updateIndex) must_== sender_2
170 | actor.maxInterval must_== actor.lastUpdates.size
171 | // start training on new worker
172 | this.assertMessage(actor.trainMessage)
173 | }
174 | }
175 | "handle terminated worker correctly" in {
176 | "worker terminate at first few iterations" in {
177 | "worker terminate when lastUpdates is empty" in {
178 | actor.updateIndex = 0
179 | actor.lastUpdates.clear
180 | actor.maxInterval = workerNumber
181 | // terminate a worker (which one is not important)
182 | actor.workerTerminated(sender_0)
183 | // nothing is recorded yet: lastUpdates stays empty and only maxInterval shrinks
184 | actor.lastUpdates must beEmpty
185 | actor.updateIndex must_== actor.lastUpdates.size
186 | actor.maxInterval must_== workerNumber - 1
187 | }
188 | "worker terminate when lastUpdates has 1 updater" in {
189 | actor.updateIndex = 1
190 | actor.lastUpdates.clear
191 | actor.lastUpdates += sender_0
192 | actor.maxInterval = workerNumber
193 | // terminate a worker (which one is not important)
194 | actor.workerTerminated(sender_0)
195 | // still in the first few iterations: lastUpdates is unchanged and only maxInterval shrinks
196 | actor.lastUpdates must_== Seq(sender_0)
197 | actor.updateIndex must_== actor.lastUpdates.size
198 | actor.maxInterval must_== workerNumber - 1
199 | }
200 | "worker terminate when lastUpdates has 2 updater" in {
201 | actor.updateIndex = 2
202 | actor.lastUpdates.clear
203 | actor.lastUpdates ++= Seq(sender_0, sender_1)
204 | actor.maxInterval = workerNumber
205 | // terminate a worker (which one is not important)
206 | actor.workerTerminated(sender_0)
207 | // lastUpdates is unchanged; presumably (workerNumber == 3) it now fills the shrunken window
208 | actor.lastUpdates must_== Seq(sender_0, sender_1)
209 | // no longer in the first few iterations, so updateIndex is expected to reset to 0
210 | actor.updateIndex must_== 0
211 | actor.maxInterval must_== workerNumber - 1
212 | }
213 | }
214 | "worker terminate when updateIndex = 0" in {
215 | // order of old updaters (oldest first, starting at updateIndex) is 0, 1, 2
216 | actor.updateIndex = 0
217 | actor.lastUpdates.clear
218 | actor.lastUpdates ++= Seq(sender_0, sender_1, sender_2)
219 | actor.maxInterval = actor.lastUpdates.size
220 | // terminate a worker (which one is not important)
221 | actor.workerTerminated(sender_0)
222 | // oldest updater must be removed and second oldest becomes oldest
223 | actor.lastUpdates must_== Seq(sender_1, sender_2)
224 | actor.lastUpdates(actor.updateIndex) must_== sender_1
225 | actor.maxInterval must_== actor.lastUpdates.size
226 | }
227 | "worker terminate when updateIndex = 1" in {
228 | // order of old updaters (oldest first, starting at updateIndex) is 1, 2, 0
229 | actor.updateIndex = 1
230 | actor.lastUpdates.clear
231 | actor.lastUpdates ++= Seq(sender_0, sender_1, sender_2)
232 | actor.maxInterval = actor.lastUpdates.size
233 | // terminate a worker (which one is not important)
234 | actor.workerTerminated(sender_0)
235 | // oldest updater must be removed and second oldest becomes oldest
236 | actor.lastUpdates must_== Seq(sender_0, sender_2)
237 | actor.lastUpdates(actor.updateIndex) must_== sender_2
238 | actor.maxInterval must_== actor.lastUpdates.size
239 | }
240 | "worker terminate when updateIndex = 2" in {
241 | // order of old updaters (oldest first, starting at updateIndex) is 2, 0, 1
242 | actor.updateIndex = 2
243 | actor.lastUpdates.clear
244 | actor.lastUpdates ++= Seq(sender_0, sender_1, sender_2)
245 | actor.maxInterval = actor.lastUpdates.size
246 | // terminate a worker (which one is not important)
247 | actor.workerTerminated(sender_0)
248 | // oldest updater must be removed and second oldest becomes oldest
249 | actor.lastUpdates must_== Seq(sender_0, sender_1)
250 | actor.lastUpdates(actor.updateIndex) must_== sender_0
251 | actor.maxInterval must_== actor.lastUpdates.size
252 | }
253 | }
254 | }
255 | }
256 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/akka/HostMasterActor.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.akka
18 |
19 | import scala.collection.mutable.{ Buffer, Map => MutableMap }
20 | import scala.concurrent.duration.{ Duration, DurationInt }
21 | import scala.util.Try
22 |
23 | import akka.AkkaException
24 | import akka.actor._
25 | import akka.actor.SupervisorStrategy.{ Restart, Resume, Stop }
26 | import akka.remote.RemoteScope
27 |
28 | import com.twitter.scalding.Args
29 |
30 | import HostMasterActor._
31 |
/**
 * A host master actor is used to wait for all slaves to join. It's the creator of all actors
 * (worker, db, test and master).
 *
 * Arguments:
 *  - `--worker <#workers>`: The number of workers.
 *  - `--resume`: (Optional) Snapshot in the store is used as initial weight.
 *    Otherwise current weights in the caffe solver is used as initial weight.
 *  - `--snapshot <path>`: (Optional) Use a snapshot in file system as initial weights. The
 *    snapshot will be loaded during initialization, `resume` is set to false. See
 *    [[com.htc.speedo.caffe.CaffeWorker CaffeWorker]].
 *  - `--CPUDBActor`: (Flag) If set, forces db actor and test actor to use CPU not GPU.
 *
 * Arguments for Yarn (normally not needed when started manually):
 *  - `--timeout <seconds>`: (Optional) Set the waiting time out for all the slave systems to
 *    join. Default is no time out. Usually not needed when running the system manually. It's
 *    set to 30 seconds if started by Yarn.
 *  - `--sleepAfterFinish <seconds>`: (Optional) Sleeps given seconds before shutting down the
 *    system. This is set by Yarn to 15 seconds if `--test` is 0.
 *
 * Parameters used by different type of master actor:
 *  - `--drop <value>`: The strategy that drops slow training iterations.
 *  - `--maxAdvance <value>`: The strategy that waits for slow iterations.
 *  - `--sync`: The completely synchronous strategy.
 *  - None: The completely asynchronous strategy.
 * @see Other parameters required by [[MasterActor]] for master actor.
 * @see Other parameters required by [[com.htc.speedo.caffe.CaffeWorker CaffeWorker]] for caffe parameters.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
case class HostMasterActor(args: Args) extends Actor with ActorLogging {
  /** Number of workers. */
  val numWorkers = args.int("worker")
  /** Time out in seconds for waiting all the host actors to join; 0 (default) means no time out. */
  val timeOut = args.int("timeout", 0)
  /**
   * The flag to resume or not. If `--snapshot` is provided, the snapshot is loaded in the caffe
   * worker when DBActor is created. So resume is false.
   */
  val resume = args.boolean("resume") && !args.boolean("snapshot")
  /** The device index for DB actor. Equals to -1 if CPUDBActor flag is given, otherwise 0. */
  val dbDevice = if (args.boolean("CPUDBActor")) -1 else 0

  if (timeOut > 0) {
    log.info("Waiting {} seconds for {} workers to join.", timeOut, numWorkers)
    implicit val dispatcher = context.dispatcher
    // JoinTimeOut is only handled by initState, so it has no effect once all workers joined
    context.system.scheduler.scheduleOnce(timeOut.seconds, self, JoinTimeOut)
  } else log.info("Waiting for {} workers to join.", numWorkers)

  /**
   * The buffer for all the host slave actors on remote machines. The position of a host in this
   * buffer is the index of the worker created on it.
   */
  val hostActors: Buffer[ActorRef] = Buffer()
  /**
   * Next device index to assign for each host name/ip. On the local host the counter starts at
   * `dbDevice + 1`, so local workers are assigned devices after the db actor's one.
   */
  val hostCounts = MutableMap(getHost(AkkaUtil.addressOf(context.system)) -> (dbDevice + 1))

  /**
   * The buffer for all the worker actors on remote machines. A worker actor is created once a
   * remote system joins to this actor.
   */
  val workerActors: Buffer[ActorRef] = Buffer()
  /** The master actor. Only created after all worker actors are created. */
  var masterActor: Option[ActorRef] = None
  /** The db actor. */
  val dbActor = createDbActor
  context watch dbActor
  /**
   * The tester actor. Mark as lazy so that test actor is not created at the
   * same time with dbActor. This avoids creating two caffe solver at the same
   * time, which might cause seg fault on GPU.
   * @note The test actor is not watched during init and train, since it's not
   * a must in these two phases. But after training, we start watch the actor
   * since it's required in the shutdown process. If the actor is terminated
   * beforehand, we can still receive the Terminated message by then.
   */
  lazy val testActor = createTester

  /** Create the master actor. */
  def createMasterActor: ActorRef =
    context.actorOf(AkkaUtil.createMasterActorProps(args, dbActor, testActor, workerActors), "master")

  /** Create the database actor. */
  def createDbActor: ActorRef =
    context.actorOf(Props(classOf[DBActor], argsWithDevice(dbDevice)).withDispatcher(WorkerDispatcher), "db")

  /** Create `index`th worker actor in the remote system at given address. */
  def createWorker(index: Int, address: Address): ActorRef = {
    val name = workerPrefix + index
    // The start proportion of the input data
    val start = (index.toFloat / numWorkers).toString
    // get device index for this worker
    val device = hostCounts.getOrElse(getHost(address), 0)
    // update the host count
    hostCounts.put(getHost(address), device + 1)
    // args for worker actor, use different base dir for each worker in case
    // multiple workers start in the same jvm
    val newArgs = argsWithDevice(device) + ("suffix" -> Seq(name)) +
      ("start" -> Seq(start)) + ("baseDir" -> Seq("caffe." + name))
    // remote deployment
    val props = Props(classOf[WorkerActor], newArgs)
      .withDeploy(Deploy(scope = RemoteScope(address)))
      .withDispatcher(WorkerDispatcher)
    context.actorOf(props, name)
  }

  /** Create the test actor. No need to set device as it's set by db actor. */
  def createTester: ActorRef =
    context.actorOf(Props(classOf[WorkerActor], args).withDispatcher(WorkerDispatcher), "tester")

  /** Create a new scalding args from [[args]] with given device flag. */
  def argsWithDevice(device: Int): Args = args + ("device" -> Seq(device.toString))

  /** Get host name/ip from given akka address. Return localhost if empty. */
  def getHost(address: Address): String = address.host.getOrElse("localhost")

  /** Message handling during initialization, training and post-train. */
  def commonState: Receive = {
    case Progress => masterActor match {
      // forward message to master actor
      case Some(master) => master.tell(Progress, sender)
      // master actor is not started (still waiting for remote systems to join)
      case None => sender ! Progress(0)
    }
    case StopAkka => // shutdown the whole system
      // stop all slaves first
      hostActors.foreach(_ ! StopAkka)
      args.int("sleepAfterFinish", 0) match {
        // stop master akka system immediately
        case 0 => context.system.shutdown
        // schedule the shutdown after i seconds
        case i: Int =>
          log.info("System will be shutdown in {} seconds.", i)
          context.system.scheduler.scheduleOnce(i.seconds)(context.system.shutdown)(context.dispatcher)
      }
    case Terminated(`dbActor`) =>
      log.warning("Db actor stopped! Shutdown the system!")
      self ! StopAkka
    // NOTE(review): this stable-identifier pattern reads the lazy `testActor`, so the first
    // Terminated message that is not from dbActor creates the test actor if it does not exist yet.
    case Terminated(`testActor`) =>
      log.warning("Test actor stopped! Shutdown the system!")
      self ! StopAkka
  }

  /** Message handling during initialization, not training or post-train. */
  def initState: Receive = {
    // This happens when not all the workers start within given timeout
    case JoinTimeOut =>
      log.error("Failed to start all workers within {} seconds!" + " Shutdown the system!", timeOut)
      self ! StopAkka
  }

  /** Message handling during initialization and training, not post-train. */
  def initOrTrainingState: Receive = {
    case Join =>
      // A remote system joins with its address
      val address = sender.path.address
      log.info("Creating worker{} at {}", hostActors.size, address)
      // creates worker at the remote system
      val worker = createWorker(hostActors.size, address)
      // make sure the worker is created (will receive ActorIdentity message)
      worker ! Identify(IdentifyWorker(worker))
      hostActors += sender
    case ActorIdentity(IdentifyWorker(ref), Some(worker)) if ref == worker =>
      log.info("{} created!", worker.path.name)
      // If a worker is created successfully
      context watch worker
      workerActors += worker
      // clears worker's suffix key
      if (!resume) dbActor ! ClearWorkerKey(worker)
      masterActor match {
        case None if workerActors.size == numWorkers => // all workers joined
          // create and watch master actor
          masterActor = Some(createMasterActor)
          masterActor.foreach { actor =>
            context watch actor
            // The master actor should start after dbActor finished init
            dbActor.tell(Init(resume), actor)
          }
          // Change message handling
          context.become(commonState orElse initOrTrainingState orElse trainingState)
        case Some(master) => // re-created workers
          // inform master about newly created workers
          // go through db actor to make sure the suffix key is cleared first
          dbActor ! Forward(master, WorkerCreated(worker))
        case _ => // do nothing if we are still waiting for workers to join
      }
    case ActorIdentity(IdentifyWorker(worker), _) =>
      // If a worker is not created successfully
      // TODO: continue without this worker
      log.error("Failed to create {}, stop akka!", worker)
      self ! StopAkka
    case Terminated(w) if workerActors.contains(w) => // worker stopped
      log.info("{} stopped!", w.path.name)
      workerActors -= w
      // Check if the host actor is still alive, so the worker can be re-created there.
      // The host actor is found by the index in the worker's name. Both parsing the index and
      // looking up the host may fail in unit tests (names without the prefix, or no host
      // actors registered), in which case nothing is done instead of crashing this actor.
      Try(w.path.name.stripPrefix(workerPrefix).toInt).toOption
        .flatMap(hostActors.lift)
        .foreach(host => host ! Identify(IdentifyHost(host)))
    case ActorIdentity(IdentifyHost(ref), Some(host)) if ref == host =>
      // If the host is still running after worker is terminated, restart worker
      val address = host.path.address
      val index = hostActors.indexOf(host)
      log.info("Re-creating worker{} at {}", index, address)
      // creates worker at the remote system again
      val worker = createWorker(index, address)
      worker ! Identify(IdentifyWorker(worker))
    case ActorIdentity(IdentifyHost(host), _) =>
      // If both worker and host actor are stopped, not restart worker
      // prepare shutdown if all workers stopped
      if (workerActors.size == 0) self ! StopAkka
  }

  /** Message handling during training, not initialization or post-train. */
  def trainingState: Receive = {
    case TrainFinished =>
      log.info("Train finished, starting test.")
      // start to watch test actor
      context.watch(testActor)
      // run test after master finished training and log test result in master
      testActor.tell(Test, sender)
      // stop the system after test finished
      testActor ! Forward(self, StopAkka)
      // unwatch workers to make log clean
      workerActors.foreach(context.unwatch)
      // unwatch master
      masterActor.foreach(context.unwatch)
      // do not react to certain messages
      context.become(commonState)
    case Terminated(actor) if masterActor == Some(actor) =>
      log.warning("Master actor stopped! Shutdown the system!")
      self ! StopAkka
  }

  override def receive = commonState orElse initState orElse initOrTrainingState

  // always stop the child actors when error occured
  override val supervisorStrategy = OneForOneStrategy() { case _ => Stop }
}
268 |
object HostMasterActor {
  /** Name prefix of worker actors; the worker index is appended to it (e.g. `worker0`). */
  val workerPrefix = "worker"

  /** Message id attached to an `Identify` request sent to a host slave actor. */
  case class IdentifyHost(host: ActorRef)

  /** Message id attached to an `Identify` request sent to a newly created worker actor. */
  case class IdentifyWorker(worker: ActorRef)
}
274 |
--------------------------------------------------------------------------------
/src/main/scala/com/htc/speedo/caffe/CaffeWorker.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright 2016 HTC Corporation
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | package com.htc.speedo.caffe
18 |
19 | import java.io.File
20 |
21 | import scala.collection.JavaConverters.asScalaBufferConverter
22 | import scala.collection.mutable.{ Buffer, Set => MutableSet }
23 |
24 | import caffe.Caffe.{ NetParameter, Phase, SolverParameter }
25 |
26 | import com.google.protobuf.Message
27 | import com.twitter.scalding.Args
28 | import com.twitter.storehaus.algebra.MergeableStore
29 | import com.twitter.util.Await
30 |
31 | import org.apache.commons.io.FileUtils
32 | import org.apache.hadoop.conf.Configuration
33 | import org.apache.hadoop.fs.Path
34 |
object CaffeWorker {
  /**
   * How a training step publishes its result, i.e. what [[CaffeWorker.train]]
   * writes to the storehaus store.
   */
  sealed trait WeightUpdate

  /**
   * Write the full weight difference against the snapshot to storehaus,
   * including gradients, momentum and weight decay, as produced by the
   * solver update function.
   */
  case object FullUpdate extends WeightUpdate

  /**
   * Let each worker update the weights at its own pace, keeping a local copy
   * of the weights. See the EASGD paper for details.
   */
  case object SelfPaceFullUpdate extends WeightUpdate

  /**
   * Write only the gradients to the storehaus store. The gradients can be
   * merged separately by calling [[CaffeWorker.mergeDelta]].
   */
  case object GradientOnly extends WeightUpdate

  /** The default work directory for caffe. */
  val BaseDir = new File("caffe.worker")

  /**
   * Build a caffe worker from command line arguments. The argument names are
   * documented on the [[CaffeWorker]] constructor.
   * @note The storehaus store is created via [[StorehausUtils]]; remember to
   * close the returned caffe worker.
   */
  def apply(args: Args): CaffeWorker = {
    // the store is created before any other argument is read
    val store = StorehausUtils.createStore(args)
    def intArg(key: String): Option[Int] = args.optional(key).map(_.toInt)
    def floatArg(key: String): Option[Float] = args.optional(key).map(_.toFloat)
    new CaffeWorker(
      snapshotStore = store,
      solverPath = args.required("solver"),
      snapshotPath = args.optional("snapshot"),
      keyName = args.optional("keyName"),
      suffix = args.optional("suffix"),
      baseDirectory = args.optional("baseDir").map(dir => new File(dir)),
      debug = args.boolean("debug"),
      device = intArg("device"),
      doublePrecision = args.boolean("doublePrecision"),
      trainIteration = intArg("train"),
      baseLR = floatArg("lr"),
      weightDecay = floatArg("weightDecay"),
      momentum = floatArg("momentum"),
      batchSize = intArg("batch"),
      start = floatArg("start")
    )
  }
}
84 |
/**
 * Caffe worker to invoke caffe training.
 * Training inputs come from hdfs and the network parameters are merged into
 * a storehaus store.
 *
 * TODO: Instead of passing in all the parameters, use configuration.
 *
 * @note All parameters (except snapshotStore) can be set through scalding args
 * using [[CaffeWorker.apply]] function in companion object. The argument is same
 * with the variable name, except stated explicitly.
 * @param snapshotStore The snapshot store used to merge network parameter.
 * Should use [[NetParameterSemigroup]] as semigroup for merge.
 * @param solverPath The path to a text format of SolverParameter. The
 * definition of model must be specified in a file in `net` field, relative to
 * the folder of the solver file. The command line argument is `--solver`.
 * @param snapshotPath The path to a binary format of NetParameter used as a
 * snapshot to resume train from. A relative path is resolved against the base
 * directory. The command line argument is `--snapshot`.
 * @param keyName The name used as key in the database, overriding the name in
 * the model file.
 * @param suffix The worker will read/write its snapshot/delta using a separate
 * key, the new key will be `<key name> + suffix`.
 * @param baseDirectory Overrides working directory for caffe. Relative to
 * current directory. Any existing content is deleted and the folder of the
 * solver file is copied into it. The command line argument is `--baseDir`.
 * @param debug If set to true, display the loss every iteration in caffe
 * @param device The GPU device index to use for this caffe worker. A negative
 * index indicates using CPU; positive device index is MOD by total available
 * GPU counts. E.g. non-negative GPU indexes are identical on one-GPU machine.
 * On a machine without GPU, device is ignored.
 * @param doublePrecision Use double precision floating number in the solver.
 * This may be more accurate but consumes double memory. Default is false.
 * @param trainIteration Overrides the number of mini-batch to run in training.
 * The command line argument is `--train`. This defaults to 1.
 * @param baseLR Overrides the base learning rate. The command line argument is
 * `--lr`.
 * @param weightDecay Overrides the weight decay.
 * @param momentum Overrides the momentum.
 * @param batchSize Overrides the batch size of all training data layers. The
 * command line argument is `--batch`. For back compatibility, this option is
 * named as batch size, but it actually operates on the `iter_size` field in
 * solver, since changing batch size of data layers during runtime requires a
 * lot of hacks and is not stable. For easier usage, if all the data layers has
 * the same batch size, their batch size are changed to 1 and iter_size is
 * multiplied by the original batch size during initialization.
 * @param start Overrides start position of input data for the worker. This
 * should be a float value, equal to `<worker index> / <number of workers>`.
 * You should also set the `rand_skip` field in the protobuf to be the total
 * number of data.
 * @note Call [[close]] when the worker is no longer needed; it disposes the
 * solver and closes `snapshotStore`.
 * @author Zhongyang Zheng (zhongyang_zheng@htc.com)
 */
case class CaffeWorker(
  snapshotStore: MergeableStore[String, Array[Byte]],
  solverPath: String,
  snapshotPath: Option[String] = None,
  keyName: Option[String] = None,
  suffix: Option[String] = None,
  baseDirectory: Option[File] = None,
  debug: Boolean = false,
  device: Option[Int] = None,
  doublePrecision: Boolean = false,
  // override default config from command line
  trainIteration: Option[Int] = None,
  baseLR: Option[Float] = None,
  weightDecay: Option[Float] = None,
  momentum: Option[Float] = None,
  batchSize: Option[Int] = None,
  start: Option[Float] = None
) {
  /** The local working directory for caffe. */
  val baseDir = baseDirectory.getOrElse(CaffeWorker.BaseDir)

  // Get the solver parameter, jni solver and the name of the network
  private val (param, solver, name) = {
    // hdfs path of the solver file
    val path = new Path(solverPath)
    // hdfs folder of the solver file
    val pathDir = path.getParent
    val fs = path.getFileSystem(new Configuration)
    // load solver parameter from hdfs
    val solverBuilder = ProtobufUtils.loadText[SolverParameter](fs.open(path)).toBuilder
    // load net parameter from hdfs, relative to solver parameter path
    val modelBuilder = ProtobufUtils.loadText[NetParameter](fs.open(new Path(pathDir, solverBuilder.getNet))).toBuilder
    // start from a clean local directory: remove anything left from a previous run
    if (baseDir.exists) FileUtils.deleteDirectory(baseDir)
    // copy the whole solver folder to local
    val localPath = new Path(baseDir.toURI)
    fs.copyToLocalFile(false, pathDir, localPath, true)
    // change the iteration settings in the solver
    solverBuilder
      // if debug, display loss information every iteration
      .setDisplay(if (debug) 1 else 0)
      // don't run tests
      .setTestInitialization(false)
      .setTestInterval(Int.MaxValue)
      // don't save snapshots during training
      .clearSnapshot
      // don't save snapshot after training
      .setSnapshotAfterTrain(false)
      // clear networks definitions in the solver
      .clearTrainNet
      .clearTestNet
      .clearTrainNetParam
      .clearTestNetParam
      .clearNetParam
      // set training iteration before merge, always override value in solver
      .setMaxIter(trainIteration.getOrElse(1))
    baseLR.foreach(solverBuilder.setBaseLr)
    weightDecay.foreach(solverBuilder.setWeightDecay)
    momentum.foreach(solverBuilder.setMomentum)

    // Resolve a path from the net definition. Absolute paths are kept as is,
    // which is useful for large datasets that should not be copied from hdfs
    // every time. Relative paths are resolved against the full path of the
    // local base directory (not only its last segment), where the solver
    // folder was copied to.
    def pathPrefix(path: String): String =
      if (path.startsWith("/")) path
      else baseDir.getPath + "/" + path

    // Raised when a start position is requested but rand_skip is not set
    val missingRandSkip = "rand_skip must be set as total number of " +
      "training data in data layers for multiple workers"

    // change the input file path, batch size and start pos in the data layers
    // TODO: Add unit test
    val batchSizeSet = MutableSet[Int]()
    val batchSetter = Buffer[Int => Message.Builder]()
    modelBuilder.getLayerBuilderList.asScala.foreach { l =>
      // Only change batch size for train net
      val includes = l.getIncludeList.asScala.filter(_.hasPhase)
        .map(_.getPhase) ++ Some(l).filter(_.hasPhase).map(_.getPhase)
      val excludes = l.getExcludeList.asScala.filter(_.hasPhase).map(_.getPhase)
      val isTrainNet = (includes.size == 0 || includes.contains(Phase.TRAIN)) &&
        (excludes.size == 0 || !excludes.contains(Phase.TRAIN))
      if (l.hasTransformParam) {
        val transformParam = l.getTransformParamBuilder
        if (transformParam.hasMeanFile)
          transformParam.setMeanFile(pathPrefix(transformParam.getMeanFile))
      }
      if (l.hasDataParam) {
        val dataParam = l.getDataParamBuilder
        if (dataParam.hasSource)
          dataParam.setSource(pathPrefix(dataParam.getSource))
        if (dataParam.hasMeanFile)
          dataParam.setMeanFile(pathPrefix(dataParam.getMeanFile))
        // Only set start position for train net
        start.filter(_ => isTrainNet).foreach { s =>
          if (dataParam.hasRandSkip)
            dataParam.setRandSkip((dataParam.getRandSkip * s).toInt)
          else if (s > 0)
            throw new Exception(missingRandSkip)
        }
        if (isTrainNet) {
          batchSizeSet += dataParam.getBatchSize
          batchSetter += dataParam.setBatchSize
        }
      }
      if (l.hasHdf5DataParam) {
        val hdf5DataParam = l.getHdf5DataParamBuilder
        if (hdf5DataParam.hasSource)
          hdf5DataParam.setSource(pathPrefix(hdf5DataParam.getSource))
        if (isTrainNet) {
          batchSizeSet += hdf5DataParam.getBatchSize
          batchSetter += hdf5DataParam.setBatchSize
        }
      }
      if (l.hasImageDataParam) {
        val imageDataParam = l.getImageDataParamBuilder
        if (imageDataParam.hasSource)
          imageDataParam.setSource(pathPrefix(imageDataParam.getSource))
        if (imageDataParam.hasMeanFile)
          imageDataParam.setMeanFile(pathPrefix(imageDataParam.getMeanFile))
        // Only set start position for train net
        start.filter(_ => isTrainNet).foreach { s =>
          if (imageDataParam.hasRandSkip)
            imageDataParam.setRandSkip((imageDataParam.getRandSkip * s).toInt)
          else if (s > 0)
            throw new Exception(missingRandSkip)
        }
        if (isTrainNet) {
          batchSizeSet += imageDataParam.getBatchSize
          batchSetter += imageDataParam.setBatchSize
        }
      }
      if (l.hasMemoryDataParam && isTrainNet) {
        val memoryDataParam = l.getMemoryDataParamBuilder
        batchSizeSet += memoryDataParam.getBatchSize
        batchSetter += memoryDataParam.setBatchSize
      }
      if (l.hasWindowDataParam) {
        val windowDataParam = l.getWindowDataParamBuilder
        if (windowDataParam.hasSource)
          windowDataParam.setSource(pathPrefix(windowDataParam.getSource))
        if (windowDataParam.hasMeanFile)
          windowDataParam.setMeanFile(pathPrefix(windowDataParam.getMeanFile))
        if (isTrainNet) {
          batchSizeSet += windowDataParam.getBatchSize
          batchSetter += windowDataParam.setBatchSize
        }
      }
    }
    // update batch sizes if explicit override or all batch sizes are same
    batchSize.orElse(batchSizeSet.size match {
      // multiplies by the iter_size in case it's not 1
      case 1 => batchSizeSet.headOption.map(_ * solverBuilder.getIterSize)
      case _ => None
    }).foreach { batch =>
      // set batch size to 1
      batchSetter.foreach(_(1))
      // multiply iter_size by original batch size
      solverBuilder.setIterSize(batch)
    }
    // set device (Only set device if we have GPUs)
    if (Solver.deviceCount > 0) device match {
      case Some(d) if d >= 0 => Solver.setDevice(d % Solver.deviceCount) // GPU
      case Some(_) => Solver.setDevice(-1) // CPU
      case _ => // Do nothing if device is not explicitly set
    }
    val jniSolver = new Solver
    val solverParam = solverBuilder.build
    jniSolver.init(solverParam, modelBuilder.build, doublePrecision)
    // load snapshot, resolving a relative path against the base directory
    snapshotPath.foreach { path =>
      val snapshot = ProtobufUtils.load[NetParameter](pathPrefix(path))
      jniSolver.setWeight(snapshot.toByteArray)
    }
    // return solver parameter, jni solver and network name
    (solverParam, jniSolver, keyName.getOrElse(modelBuilder.getName))
  }

  /** Returns the name of storehaus key. */
  def getKey: String = name

  /**
   * Put the initial weight to store if the key is not available.
   * @param resume If set to true, the snapshot in storehaus store is used.
   * @return The weights after initialization (or the one already exists)
   */
  def init(resume: Boolean = false): Array[Byte] = {
    // without resume, the store content is always overridden
    val snapshot = if (resume) Await.result(snapshotStore.get(name)) else None
    snapshot.getOrElse {
      val weights = solver.getWeight
      Await.result(snapshotStore.put(name, Some(weights)))
      weights
    }
  }

  /**
   * Trains the caffe model for a small amount of iterations.
   * @param readSuffix Whether read snapshot from the key with suffix or not.
   * Default is true.
   * @param WeightUpdate The enum of how to update the weights and update
   * storehaus. See [[CaffeWorker.WeightUpdate]] and its sub case objects.
   * Default is [[CaffeWorker.FullUpdate]].
   * @return The loss reported by the solver.
   * @throws Exception If no snapshot is found in the store.
   */
  def train(readSuffix: Boolean = true, WeightUpdate: CaffeWorker.WeightUpdate = CaffeWorker.FullUpdate): Double = {
    // If suffix is provided and readSuffix is enabled, fetch weight snapshot
    // from the key with suffix; fall back to the key without suffix
    val snapshot = suffix.filter(_ => readSuffix)
      .flatMap(suf => Await.result(snapshotStore.get(name + suf)))
      .orElse(Await.result(snapshotStore.get(name)))
      // snapshot must not be empty
      .getOrElse(throw new Exception("Snapshot is empty!"))
    // set snapshot
    solver.setWeight(snapshot)
    // run caffe train and get the delta
    val (loss, delta) = WeightUpdate match {
      case CaffeWorker.FullUpdate =>
        (solver.train(param.getMaxIter, true), NetParameterOperation.minus(solver.getWeight, snapshot))
      case CaffeWorker.SelfPaceFullUpdate =>
        (solver.train(param.getMaxIter, true), solver.getWeight)
      case CaffeWorker.GradientOnly => (solver.train(1, false), solver.getDelta)
    }
    suffix match {
      case None =>
        // merges the delta to store if no suffix given
        Await.result(snapshotStore.merge(name, delta))
      case Some(suf) =>
        // otherwise overwrite this worker's own key
        Await.result(snapshotStore.put(name + suf, Some(delta)))
    }
    // return loss
    loss
  }

  /**
   * Tests the caffe model for all test data using the weights in the store.
   * @return The solver's test result, or None if no weights are stored.
   */
  def test: Option[Double] =
    // fetch a snapshot
    Await.result(snapshotStore.get(name)).map { weights =>
      // set the snapshot in solver
      solver.setWeight(weights)
      // run caffe test, with the iteration count of the first test net
      solver.test(param.getTestIter(0))
    }

  /** Set current global iteration for the worker. */
  def setIteration(iteration: Int): Unit = solver.setIteration(iteration)

  /** Updates weights according to the given gradients and write to store. */
  def mergeDelta(delta: Array[Byte], suffix: String = ""): Array[Byte] = {
    val weights = solver.mergeDelta(delta)
    Await.result(snapshotStore.put(name + suffix, Some(weights)))
    weights
  }

  /**
   * Update solver parameters using command line arguments. Shares same
   * arguments as the [[CaffeWorker.apply]] function.
   *
   * Supported options:
   *  - `--momentum`: Momentum
   *  - `--weightDecay`: Weight decay
   *  - `--lr`: Base learning rate
   *  - `--batch`: Update the iter_size in the solver.
   * @note If any option is not provided in the command line, it's reset to
   * the original one defined in solver protobuf or model protobuf (batch size).
   */
  def updateParameter(args: Args): Unit = {
    val builder = SolverParameter.newBuilder
    args.optional("lr").map(_.toFloat).foreach(builder.setBaseLr)
    args.optional("weightDecay").map(_.toFloat).foreach(builder.setWeightDecay)
    args.optional("momentum").map(_.toFloat).foreach(builder.setMomentum)
    args.optional("batch").map(_.toInt).foreach(builder.setIterSize)
    solver.updateParameter(builder.build)
  }

  /**
   * Finalize resources and close storehaus store. The store is closed even if
   * disposing the solver throws.
   */
  def close: Unit =
    try solver.dispose
    finally Await.result(snapshotStore.close())
}
419 |
--------------------------------------------------------------------------------