├── .github
│   └── workflows
│       └── scala.yml
├── .gitignore
├── .scalafmt.conf
├── LICENSE
├── README.md
├── build.sbt
├── project
│   ├── build.properties
│   └── plugins.sbt
└── src
    ├── main
    │   ├── resources
    │   │   └── reference.conf
    │   └── scala
    │       └── io
    │           └── bernhardt
    │               └── akka
    │                   └── locality
    │                       ├── Locality.scala
    │                       ├── package.scala
    │                       └── router
    │                           ├── ShardLocationAwareRouter.scala
    │                           └── ShardStateMonitor.scala
    ├── multi-jvm
    │   └── scala
    │       ├── akka
    │       │   ├── README.md
    │       │   ├── cluster
    │       │   │   ├── FailureDetectorPuppet.scala
    │       │   │   ├── MultiNodeClusterSpec.scala
    │       │   │   ├── sharding
    │       │   │   │   ├── MultiNodeClusterShardingConfig.scala
    │       │   │   │   └── MultiNodeClusterShardingSpec.scala
    │       │   │   └── testkit
    │       │   │       └── AutoDowning.scala
    │       │   ├── remote
    │       │   │   └── testkit
    │       │   │       └── STMultiNodeSpec.scala
    │       │   ├── serialization
    │       │   │   └── jackson
    │       │   │       └── CborSerializable.scala
    │       │   └── testkit
    │       │       └── AkkaSpec.scala
    │       └── io
    │           └── bernhardt
    │               └── akka
    │                   └── locality
    │                       └── router
    │                           ├── ShardLocationAwareRouterNewShardsSpec.scala
    │                           ├── ShardLocationAwareRouterSpec.scala
    │                           └── ShardLocationAwareRouterWithProxySpec.scala
    └── test
        └── scala
            └── io
                └── bernhardt
                    └── akka
                        └── locality
                            └── router
                                └── ShardLocationAwareRoutingLogicSpec.scala
/.github/workflows/scala.yml:
--------------------------------------------------------------------------------
1 | name: Scala CI
2 |
3 | on: [push]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 |
10 | steps:
11 | - uses: actions/checkout@v1
12 | - name: Set up JDK 1.8
13 | uses: actions/setup-java@v1
14 | with:
15 | java-version: 1.8
16 | - name: Run tests
17 | run: sbt test
18 | - name: Run multi-jvm tests
19 | run: sbt multi-jvm:test
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | bin/
2 | target/
3 | build/
4 |
5 | *.log
6 | *.iml
7 | *.ipr
8 | *.iws
9 | .idea
10 |
11 | .DS_Store
12 |
13 | .history
14 | .scala_dependencies
15 | .cache
16 | .cache-main
17 |
18 | *.class
19 |
--------------------------------------------------------------------------------
/.scalafmt.conf:
--------------------------------------------------------------------------------
1 | version = 2.2.2
2 |
3 | style = defaultWithAlign
4 |
5 | docstrings = JavaDoc
6 | indentOperator = spray
7 | maxColumn = 120
8 | rewrite.rules = [RedundantParens, SortImports, AvoidInfix]
9 | unindentTopLevelOperators = true
10 | align.tokens = [{code = "=>", owner = "Case"}]
11 | align.openParenDefnSite = false
12 | align.openParenCallSite = false
13 | optIn.breakChainOnFirstMethodDot = false
14 | optIn.configStyleArguments = false
15 | danglingParentheses = false
16 | spaces.inImportCurlyBraces = true
17 | rewrite.neverInfix.excludeFilters = [
18 | and
19 | min
20 | max
21 | until
22 | to
23 | by
24 | eq
25 | ne
26 | "should.*"
27 | "contain.*"
28 | "must.*"
29 | in
30 | ignore
31 | be
32 | taggedAs
33 | thrownBy
34 | synchronized
35 | have
36 | when
37 | size
38 | only
39 | noneOf
40 | oneElementOf
41 | noElementsOf
42 | atLeastOneElementOf
43 | atMostOneElementOf
44 | allElementsOf
45 | inOrderElementsOf
46 | theSameElementsAs
47 | ]
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2019 Manuel Bernhardt
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # akka-locality
2 |
3 | This module provides constructs that help to make better use of the locality of actors within a clustered Akka system.
4 | For a full explanation of the problem it addresses, [check out this article](https://manuel.bernhardt.io/2019/10/28/one-step-closer-exploiting-locality-in-akka-cluster-based-systems/).
5 |
6 | ### SBT
7 |
8 | ```sbt
9 | libraryDependencies += "io.bernhardt" %% "akka-locality" % "1.1.0"
10 | ```
11 |
12 | ### Maven
13 |
14 | ```xml
15 | <dependency>
16 |   <groupId>io.bernhardt</groupId>
17 |   <artifactId>akka-locality_2.12</artifactId>
18 |   <version>1.1.0</version>
19 | </dependency>
20 | ```
21 |
22 | ## Shard location aware routers
23 |
24 | This type of router is useful for systems in which the routees of cluster-aware routers need to communicate with sharded
25 | entities.
26 |
27 | With a common routing logic (random, round-robin) there may be an extra network hop (or two when considering replies)
28 | between a routee and the sharded entities it needs to talk to. Shard location aware routers optimize this by routing
29 | to the routee closest to the sharded entity. They do so by using the same rules for extracting the `shardId` from a
30 | message as the shard regions themselves.
31 |
32 | When the router has not yet retrieved sharding state, it falls back to random routing.
33 | When there is more than one candidate routee close to a sharded entity, one of them is picked at random.
34 |
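For illustration, a minimal sketch of what such extraction rules can look like; the `EntityEnvelope` message and the modulo-based shard id are assumptions for this example, not part of the module:

```scala
import akka.cluster.sharding.ShardRegion

// hypothetical envelope wrapping messages sent to sharded entities
final case class EntityEnvelope(entityId: Long, payload: Any)

val extractEntityId: ShardRegion.ExtractEntityId = {
  case EntityEnvelope(id, payload) => (id.toString, payload)
}

// illustrative shard id derivation; the router must apply the very same rule as the shard region
val extractShardId: ShardRegion.ExtractShardId = {
  case EntityEnvelope(id, _) => (id % 10).toString
}
```
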
35 | In order to use these routers, the `Locality` extension must be started:
36 |
37 | ### Scala
38 |
39 | ```scala
40 | import io.bernhardt.akka.locality._
41 | import akka.actor.ActorSystem
42 |
43 | val system: ActorSystem = ActorSystem("system")
44 | val locality = Locality(system)
45 | ```
46 |
47 | ### Java
48 |
49 | ```java
50 | import io.bernhardt.akka.locality.Locality;
51 | import akka.actor.ActorSystem;
52 |
53 | ActorSystem system = ActorSystem.create("system");
54 | Locality locality = Locality.get(system);
55 | ```
56 |
57 | You can then use the group or pool routers as cluster-aware routers. These routers must be declared in code, as they
58 | need to be passed elements of the sharding setup:
59 |
60 | ### Scala
61 |
62 | ```scala
63 | import akka.actor.{ActorSystem, ActorRef}
64 | import akka.cluster.sharding.ShardRegion
65 | import akka.cluster.routing._
66 |
67 | import io.bernhardt.akka.locality.Locality
68 |
69 | val system: ActorSystem = ActorSystem("system")
70 | val locality: Locality = Locality(system)
71 | val extractEntityId: ShardRegion.ExtractEntityId = ???
72 | val extractShardId: ShardRegion.ExtractShardId = ???
73 | val region: ActorRef = ???
74 |
75 | val router = system.actorOf(ClusterRouterGroup(locality.shardLocationAwareGroup(
76 | routeePaths = Nil,
77 | shardRegion = region,
78 | extractEntityId = extractEntityId,
79 | extractShardId = extractShardId
80 | ), ClusterRouterGroupSettings(
81 | totalInstances = 5,
82 | routeesPaths = List("/user/routee"),
83 | allowLocalRoutees = true
84 | )).props(), "shard-location-aware-router")
85 | ```
86 |
87 | ### Java
88 |
89 |
90 | ```java
91 | import java.util.Arrays; import java.util.Collections;
92 | import java.util.HashSet; import java.util.Set;
93 | import akka.actor.ActorSystem; import akka.actor.ActorRef;
94 | import akka.cluster.sharding.ShardRegion;
95 | import akka.cluster.routing.ClusterRouterGroup; import akka.cluster.routing.ClusterRouterGroupSettings;
96 |
97 | ActorRef region = ...;
98 | ShardRegion.MessageExtractor messageExtractor = ...;
99 | int totalInstances = 5;
100 | Iterable<String> routeesPaths = Collections.singletonList("/user/routee");
101 | boolean allowLocalRoutees = true;
102 | Set<String> useRoles = new HashSet<>(Arrays.asList("role"));
103 |
104 | ActorRef router = system.actorOf(
105 | new ClusterRouterGroup(
106 | locality.shardLocationAwareGroup(
107 | routeesPaths,
108 | region,
109 | messageExtractor
110 | ),
111 | new ClusterRouterGroupSettings(
112 | totalInstances,
113 | routeesPaths,
114 | allowLocalRoutees,
115 | useRoles
116 | )
117 | ).props(), "shard-location-aware-router");
118 | ```
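
A pool variant is available as well through `shardLocationAwarePool`. A minimal Scala sketch, reusing `system`, `locality`, `region` and the extractor functions from the Scala example above; the `Routee` actor is a placeholder:

```scala
import akka.actor.{ Actor, ActorRef, Props }

// placeholder routee actor
class Routee extends Actor {
  def receive: Receive = { case _ => }
}

val poolRouter: ActorRef = system.actorOf(
  locality.shardLocationAwarePool(
    nrOfInstances = 5,
    shardRegion = region,
    extractEntityId = extractEntityId,
    extractShardId = extractShardId
  ).props(Props[Routee]), "shard-location-aware-pool")
```

For a cluster-aware deployment, the pool can be wrapped in a `ClusterRouterPool` analogously to the group example above.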
119 |
120 | Always make sure that:
121 |
122 | - you use exactly the same logic for the routers as you use for sharding (see the sketch below)
123 | - you deploy the routers on all the nodes on which sharding is enabled
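
One way to guarantee the first point is to define the extraction functions in a single place and pass the very same values both to `ClusterSharding.start` and to the router. A sketch under that assumption; the `Entity` actor and the `"entity"` type name are placeholders:

```scala
import akka.actor.{ ActorRef, Props }
import akka.cluster.sharding.{ ClusterSharding, ClusterShardingSettings, ShardRegion }

// single definition of the extraction rules, shared by sharding and routing
object EntitySharding {
  val extractEntityId: ShardRegion.ExtractEntityId = ???
  val extractShardId: ShardRegion.ExtractShardId = ???
}

val region: ActorRef = ClusterSharding(system).start(
  typeName = "entity",
  entityProps = Props[Entity], // placeholder sharded entity actor
  settings = ClusterShardingSettings(system),
  extractEntityId = EntitySharding.extractEntityId,
  extractShardId = EntitySharding.extractShardId)

// the router is then created with exactly the same extraction rules
val group = locality.shardLocationAwareGroup(
  routeePaths = Nil,
  shardRegion = region,
  extractEntityId = EntitySharding.extractEntityId,
  extractShardId = EntitySharding.extractShardId)
```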
124 |
125 | ### Configuration
126 |
127 | See [reference.conf](https://github.com/manuelbernhardt/akka-locality/blob/master/src/main/resources/reference.conf) for more information about the configuration of the routing mechanism.
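
These settings can also be overridden programmatically when creating the `ActorSystem`; a sketch with illustrative values:

```scala
import akka.actor.ActorSystem
import com.typesafe.config.ConfigFactory

// illustrative overrides of the akka-locality defaults from reference.conf
val config = ConfigFactory.parseString(
  """
    |akka.locality {
    |  retrieve-shard-state-timeout = 10 seconds
    |  shard-state-update-margin = 20 seconds
    |  shard-state-polling-interval = 30 seconds
    |}
  """.stripMargin).withFallback(ConfigFactory.load())

val system = ActorSystem("system", config)
```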
128 |
129 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | import com.typesafe.sbt.MultiJvmPlugin.multiJvmSettings
2 |
3 | lazy val akkaVersion = "2.5.26"
4 |
5 | lazy val `akka-locality` = project
6 | .in(file("."))
7 | .enablePlugins(MultiJvmPlugin)
8 | .configs(MultiJvm)
9 | .settings(multiJvmSettings: _*)
10 | .settings(publishingSettings: _*)
11 | .settings(
12 | name := "akka-locality",
13 | version := "1.1.0",
14 | startYear := Some(2019),
15 | scalaVersion := "2.12.10",
16 | crossScalaVersions := Seq("2.12.10", "2.13.1"),
17 | scalacOptions ++= Seq(
18 | "-unchecked",
19 | "-deprecation",
20 | "-language:_",
21 | "-target:jvm-1.8",
22 | "-encoding", "UTF-8"
23 | ),
24 | parallelExecution in Test := false,
25 | libraryDependencies ++= Seq(
26 | "com.typesafe.akka" %% "akka-actor" % akkaVersion % "provided;multi-jvm;test",
27 | "com.typesafe.akka" %% "akka-cluster-sharding" % akkaVersion % "provided;multi-jvm;test",
28 | "com.typesafe.akka" %% "akka-persistence" % akkaVersion % Test,
29 | "com.typesafe.akka" %% "akka-testkit" % akkaVersion % Test,
30 | "com.typesafe.akka" %% "akka-multi-node-testkit" % akkaVersion % Test,
31 | "org.iq80.leveldb" % "leveldb" % "0.12" % "optional;provided;multi-jvm;test",
32 | "commons-io" % "commons-io" % "2.6" % Test,
33 | "org.scalatest" %% "scalatest" % "3.0.8" % Test
34 | ),
35 | credentials += Credentials(Path.userHome / ".sbt" / "sonatype_credential")
36 | )
37 |
38 | val publishingSettings = Seq(
39 | ThisBuild / organization := "io.bernhardt",
40 | ThisBuild / organizationName := "manuel.bernhardt.io",
41 | ThisBuild / organizationHomepage := Some(url("https://manuel.bernhardt.io")),
42 |
43 | ThisBuild / scmInfo := Some(
44 | ScmInfo(
45 | url("https://github.com/manuelbernhardt/akka-locality"),
46 | "scm:git@github.com:manuelbernhardt/akka-locality.git"
47 | )
48 | ),
49 | ThisBuild / developers := List(
50 | Developer(
51 | id = "manuel",
52 | name = "Manuel Bernhardt",
53 | email = "manuel@bernhardt.io",
54 | url = url("https://manuel.bernhardt.io")
55 | )
56 | ),
57 | ThisBuild / description := "Akka extension to make better use of locality of actors in clustered systems",
58 | ThisBuild / licenses := List("Apache 2" -> new URL("http://www.apache.org/licenses/LICENSE-2.0.txt")),
59 | ThisBuild / homepage := Some(url("https://github.com/manuelbernhardt/akka-locality")),
60 |
61 | // Remove all additional repository other than Maven Central from POM
62 | ThisBuild / pomIncludeRepository := { _ => false },
63 | ThisBuild / publishTo := {
64 | val nexus = "https://oss.sonatype.org/"
65 | if (isSnapshot.value) Some("snapshots" at nexus + "content/repositories/snapshots")
66 | else Some("releases" at nexus + "service/local/staging/deploy/maven2")
67 | },
68 | ThisBuild / publishMavenStyle := true
69 | )
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.3.3
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("com.typesafe.sbt" % "sbt-multi-jvm" % "0.4.0")
2 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.0")
3 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.2.0")
--------------------------------------------------------------------------------
/src/main/resources/reference.conf:
--------------------------------------------------------------------------------
1 | akka {
2 | locality {
3 |
4 | # The timeout to use when attempting to retrieve shard states from the cluster sharding system.
5 | # For large clusters this value may need to be increased.
6 | retrieve-shard-state-timeout = 5 seconds
7 |
8 | # The margin to keep before requesting an update from the cluster sharding system when a rebalance is detected.
9 | # This is necessary because shards need some time to be rebalanced (in case of normal rebalancing, or in case
10 | # of a topology change).
11 | # If you use a [[akka.cluster.DowningProvider]], you should take into account the `downRemovalMargin` since shards
12 | # will only be re-allocated after this margin has elapsed.
13 | shard-state-update-margin = 10 seconds
14 |
15 | # The interval between periodic polls of the shard state.
16 | # The shard state is queried periodically in order to become aware of newly created shards that have not yet
17 | # caused a rebalance to occur.
18 | shard-state-polling-interval = 15 seconds
19 | }
20 | }
--------------------------------------------------------------------------------
/src/main/scala/io/bernhardt/akka/locality/Locality.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality
2 |
3 | import java.util.concurrent.TimeUnit
4 |
5 | import akka.actor.{
6 | Actor,
7 | ActorLogging,
8 | ActorRef,
9 | ActorSystem,
10 | ExtendedActorSystem,
11 | Extension,
12 | ExtensionId,
13 | ExtensionIdProvider,
14 | Props
15 | }
16 | import akka.cluster.sharding.ShardRegion
17 | import akka.cluster.sharding.ShardRegion.MessageExtractor
18 | import com.typesafe.config.Config
19 | import io.bernhardt.akka.locality.router.{ ShardLocationAwareGroup, ShardLocationAwarePool, ShardStateMonitor }
20 |
21 | import scala.collection.immutable
22 | import scala.concurrent.duration.FiniteDuration
23 |
24 | object Locality extends ExtensionId[Locality] with ExtensionIdProvider {
25 | override def get(system: ActorSystem): Locality = super.get(system)
26 |
27 | override def createExtension(system: ExtendedActorSystem): Locality = new Locality(system)
28 |
29 | override def lookup(): ExtensionId[_ <: Extension] = Locality
30 | }
31 |
32 | /**
33 | * This module provides constructs that help to make better use of the locality of actors within a clustered Akka system.
34 | */
35 | class Locality(system: ExtendedActorSystem) extends Extension {
36 | private val settings = LocalitySettings(system.settings.config)
37 |
38 | system.systemActorOf(Props(new LocalitySupervisor(settings)), "locality")
39 |
40 | /**
41 | * Scala API: Create a shard location aware group
42 | *
43 | * @param routeePaths string representation of the actor paths of the routees, messages are
44 | * sent with [[akka.actor.ActorSelection]] to these paths
45 | * @param shardRegion the reference to the shard region
46 | * @param extractEntityId the [[akka.cluster.sharding.ShardRegion.ExtractEntityId]] function used to extract the entity id from a message
47 | * @param extractShardId the [[akka.cluster.sharding.ShardRegion.ExtractShardId]] function used to extract the shard id from a message
48 | */
49 | def shardLocationAwareGroup(
50 | routeePaths: immutable.Iterable[String],
51 | shardRegion: ActorRef,
52 | extractEntityId: ShardRegion.ExtractEntityId,
53 | extractShardId: ShardRegion.ExtractShardId): ShardLocationAwareGroup =
54 | ShardLocationAwareGroup(routeePaths, shardRegion, extractEntityId, extractShardId)
55 |
56 | /**
57 | * Java API: Create a shard location aware group
58 | *
59 | * @param routeePaths string representation of the actor paths of the routees, messages are
60 | * sent with [[akka.actor.ActorSelection]] to these paths
61 | * @param shardRegion the reference to the shard region
62 | * @param messageExtractor the [[akka.cluster.sharding.ShardRegion.MessageExtractor]] used for the sharding
63 | * of the entities this router should optimize routing for
64 | */
65 | def shardLocationAwareGroup(
66 | routeePaths: java.lang.Iterable[String],
67 | shardRegion: ActorRef,
68 | messageExtractor: MessageExtractor): ShardLocationAwareGroup =
69 | new ShardLocationAwareGroup(routeePaths, shardRegion, messageExtractor)
70 |
71 | /**
72 | * Scala API: Create a shard location aware pool
73 | *
74 | * @param nrOfInstances how many routees this pool router should have
75 | * @param shardRegion the reference to the shard region
76 | * @param extractEntityId the [[akka.cluster.sharding.ShardRegion.ExtractEntityId]] function used to extract the entity id from a message
77 | * @param extractShardId the [[akka.cluster.sharding.ShardRegion.ExtractShardId]] function used to extract the shard id from a message
78 | */
79 | def shardLocationAwarePool(
80 | nrOfInstances: Int,
81 | shardRegion: ActorRef,
82 | extractEntityId: ShardRegion.ExtractEntityId,
83 | extractShardId: ShardRegion.ExtractShardId): ShardLocationAwarePool =
84 | ShardLocationAwarePool(
85 | nrOfInstances = nrOfInstances,
86 | shardRegion = shardRegion,
87 | extractEntityId = extractEntityId,
88 | extractShardId = extractShardId)
89 |
90 | /**
91 | * Java API: Create a shard location aware pool
92 | *
93 | * @param nrOfInstances how many routees this pool router should have
94 | * @param shardRegion the reference to the shard region
95 | * @param messageExtractor the [[akka.cluster.sharding.ShardRegion.MessageExtractor]] used for the sharding
96 | * of the entities this router should optimize routing for
97 | */
98 | def shardLocationAwarePool(nrOfInstances: Int, shardRegion: ActorRef, messageExtractor: MessageExtractor): ShardLocationAwarePool =
99 | new ShardLocationAwarePool(nrOfInstances, shardRegion, messageExtractor)
100 | }
101 |
102 | private[locality] final class LocalitySupervisor(settings: LocalitySettings) extends Actor with ActorLogging {
103 | import LocalitySupervisor._
104 |
105 | def receive: Receive = {
106 | case m @ MonitorShards(region) =>
107 | val regionName = encodeRegionName(region)
108 | context
109 | .child(regionName)
110 | .map { monitor =>
111 | monitor.forward(m)
112 | }
113 | .getOrElse {
114 | log.info("Starting to monitor shards of region {}", regionName)
115 | context.actorOf(ShardStateMonitor.props(region, regionName, settings), regionName).forward(m)
116 | }
117 | }
118 | }
119 |
120 | object LocalitySupervisor {
121 | private[locality] final case class MonitorShards(region: ActorRef)
122 | }
123 |
124 | final case class LocalitySettings(config: Config) {
125 | private val localityConfig = config.getConfig("akka.locality")
126 |
127 | val RetrieveShardStateTimeout: FiniteDuration =
128 | FiniteDuration(
129 | localityConfig.getDuration("retrieve-shard-state-timeout", TimeUnit.MILLISECONDS),
130 | TimeUnit.MILLISECONDS)
131 |
132 | val ShardStateUpdateMargin =
133 | FiniteDuration(
134 | localityConfig.getDuration("shard-state-update-margin", TimeUnit.MILLISECONDS),
135 | TimeUnit.MILLISECONDS)
136 |
137 | val ShardStatePollingInterval =
138 | FiniteDuration(
139 | localityConfig.getDuration("shard-state-polling-interval", TimeUnit.MILLISECONDS),
140 | TimeUnit.MILLISECONDS)
141 | }
142 |
--------------------------------------------------------------------------------
/src/main/scala/io/bernhardt/akka/locality/package.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka
2 |
3 | import java.net.URLEncoder
4 |
5 | import akka.actor.ActorRef
6 | import akka.cluster.sharding.ShardRegion.ShardId
7 |
8 | package object locality {
9 | private[locality] def encodeRegionName(region: ActorRef): String =
10 | URLEncoder.encode(region.path.name.replaceAll("Proxy", ""), "utf-8")
11 |
12 | private[locality] def encodeShardId(id: ShardId): String = URLEncoder.encode(id, "utf-8")
13 | }
14 |
--------------------------------------------------------------------------------
/src/main/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRouter.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality.router
2 |
3 | import java.util.concurrent.atomic.AtomicReference
4 | import java.util.concurrent.{ ThreadLocalRandom, TimeUnit }
5 |
6 | import akka.actor._
7 | import akka.cluster.sharding.ShardRegion
8 | import akka.cluster.sharding.ShardRegion._
9 | import akka.dispatch.Dispatchers
10 | import akka.event.Logging
11 | import akka.japi.Util.immutableSeq
12 | import akka.pattern.{ ask, AskTimeoutException }
13 | import akka.routing._
14 | import akka.util.Timeout
15 | import io.bernhardt.akka.locality.LocalitySupervisor
16 |
17 | import scala.collection.immutable
18 | import scala.collection.immutable.IndexedSeq
19 | import scala.concurrent.Future
20 | import scala.util.control.NonFatal
21 |
22 | object ShardLocationAwareRouter {
23 | def extractEntityIdFrom(messageExtractor: MessageExtractor): ShardRegion.ExtractEntityId = {
24 | case msg if messageExtractor.entityId(msg) ne null =>
25 | (messageExtractor.entityId(msg), messageExtractor.entityMessage(msg))
26 | }
27 | }
28 |
29 | /**
30 | * A group router that will route to the routees deployed the closest to the sharded entity they need to interact with
31 | */
32 | @SerialVersionUID(1L)
33 | final case class ShardLocationAwareGroup(
34 | routeePaths: immutable.Iterable[String],
35 | shardRegion: ActorRef,
36 | extractEntityId: ShardRegion.ExtractEntityId,
37 | extractShardId: ShardRegion.ExtractShardId,
38 | override val routerDispatcher: String = Dispatchers.DefaultDispatcherId)
39 | extends Group {
40 | /**
41 | * Java API
42 | *
43 | * @param routeePaths string representation of the actor paths of the routees, messages are
44 | * sent with [[akka.actor.ActorSelection]] to these paths
45 | * @param shardRegion the reference to the shard region
46 | * @param messageExtractor the [[akka.cluster.sharding.ShardRegion.MessageExtractor]] used for the sharding
47 | * of the entities this router should optimize routing for
48 | */
49 | def this(routeePaths: java.lang.Iterable[String], shardRegion: ActorRef, messageExtractor: MessageExtractor) =
50 | this(
51 | immutableSeq(routeePaths),
52 | shardRegion,
53 | ShardLocationAwareRouter.extractEntityIdFrom(messageExtractor),
54 | extractShardId = msg => messageExtractor.shardId(msg))
55 |
56 | /**
57 | * Setting the dispatcher to be used for the router head actor, which handles
58 | * supervision, death watch and router management messages.
59 | */
60 | def withDispatcher(dispatcherId: String): ShardLocationAwareGroup = copy(routerDispatcher = dispatcherId)
61 |
62 | override def paths(system: ActorSystem): immutable.Iterable[String] = routeePaths
63 |
64 | override def createRouter(system: ActorSystem): Router =
65 | new Router(ShardLocationAwareRoutingLogic(system, shardRegion, extractEntityId, extractShardId))
66 | }
67 |
68 | /**
69 | * A pool router that will route to the routees deployed the closest to the sharded entity they need to interact with
70 | */
71 | @SerialVersionUID(1L)
72 | final case class ShardLocationAwarePool(
73 | nrOfInstances: Int,
74 | override val resizer: Option[Resizer] = None,
75 | shardRegion: ActorRef,
76 | extractEntityId: ShardRegion.ExtractEntityId,
77 | extractShardId: ShardRegion.ExtractShardId,
78 | override val supervisorStrategy: SupervisorStrategy = Pool.defaultSupervisorStrategy,
79 | override val routerDispatcher: String = Dispatchers.DefaultDispatcherId,
80 | override val usePoolDispatcher: Boolean = false)
81 | extends Pool {
82 | /**
83 | * Java API
84 | *
85 | * @param nrOfInstances how many routees this pool router should have
86 | * @param shardRegion the reference to the shard region
87 | * @param messageExtractor the [[akka.cluster.sharding.ShardRegion.MessageExtractor]] used for the sharding
88 | * of the entities this router should optimize routing for
89 | */
90 | def this(nrOfInstances: Int, shardRegion: ActorRef, messageExtractor: MessageExtractor) =
91 | this(
92 | nrOfInstances = nrOfInstances,
93 | shardRegion = shardRegion,
94 | extractEntityId = ShardLocationAwareRouter.extractEntityIdFrom(messageExtractor),
95 | extractShardId = msg => messageExtractor.shardId(msg))
96 |
97 | /**
98 | * Setting the supervisor strategy to be used for the “head” Router actor.
99 | */
100 | def withSupervisorStrategy(strategy: SupervisorStrategy): ShardLocationAwarePool = copy(supervisorStrategy = strategy)
101 |
102 | /**
103 | * Setting the resizer to be used.
104 | */
105 | def withResizer(resizer: Resizer): ShardLocationAwarePool = copy(resizer = Some(resizer))
106 |
107 | /**
108 | * Setting the dispatcher to be used for the router head actor, which handles
109 | * supervision, death watch and router management messages.
110 | */
111 | def withDispatcher(dispatcherId: String): ShardLocationAwarePool = copy(routerDispatcher = dispatcherId)
112 |
113 | /**
114 | * Setting whether to use a dedicated dispatcher for the routees of the pool.
115 | * The dispatcher is defined in 'pool-dispatcher' configuration property in the
116 | * deployment section of the router.
117 | */
118 | def usePoolDispatcher(usePoolDispatcher: Boolean): ShardLocationAwarePool =
119 | copy(usePoolDispatcher = usePoolDispatcher)
120 |
121 | override def createRouter(system: ActorSystem): Router =
122 | new Router(ShardLocationAwareRoutingLogic(system, shardRegion, extractEntityId, extractShardId))
123 |
124 | override def nrOfInstances(sys: ActorSystem): Int = this.nrOfInstances
125 | }
126 |
127 | /**
128 | * Router logic that makes its routing decision based on the relative location of routees to the shards they will
129 | * communicate with, on a best-effort basis.
130 | * When no shard state information is available this logic falls back to random routing.
131 | * When there are multiple candidate routees on the same node, one of them is selected at random.
132 | *
133 | * @param system the [[akka.actor.ActorSystem]]
134 | * @param shardRegion the reference to the [[akka.cluster.sharding.ShardRegion]] the local routees will communicate with
135 | * @param extractEntityId partial function to extract the entity id from a message, should be the same as used for sharding
136 | * @param extractShardId partial function to extract the shard id based on a message, should be the same as used for sharding
137 | */
138 | final case class ShardLocationAwareRoutingLogic(
139 | system: ActorSystem,
140 | shardRegion: ActorRef,
141 | extractEntityId: ShardRegion.ExtractEntityId,
142 | extractShardId: ShardRegion.ExtractShardId)
143 | extends RoutingLogic {
144 | import io.bernhardt.akka.locality.router.ShardStateMonitor._
145 | import system.dispatcher
146 |
147 | private lazy val log = Logging(system, getClass)
148 | private lazy val selfAddress = system.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress
149 | private val clusterShardingStateRef = new AtomicReference[Map[ShardId, Address]](Map.empty)
150 | private val shardLocationAwareRouteeRef =
151 | new AtomicReference[(IndexedSeq[Routee], Map[Address, IndexedSeq[ShardLocationAwareRoutee]])](
152 | (IndexedSeq.empty, Map.empty))
153 |
154 | watchShardStateChanges()
155 |
156 | override def select(message: Any, routees: IndexedSeq[Routee]): Routee = {
157 | if (routees.isEmpty) {
158 | NoRoutee
159 | } else {
160 | // avoid re-creating routees for each message by checking if they have changed
161 | def updateShardLocationAwareRoutees(): Map[Address, IndexedSeq[ShardLocationAwareRoutee]] = {
162 | val oldShardingRouteeTuple = shardLocationAwareRouteeRef.get()
163 | val (oldRoutees, oldShardLocationAwareRoutees) = oldShardingRouteeTuple
164 | if ((routees ne oldRoutees) && routees != oldRoutees) {
165 | val allRoutees = routees.map(ShardLocationAwareRoutee(_, selfAddress))
166 | val newShardLocationAwareRoutees = allRoutees.groupBy(_.address)
167 | // don't act on failure of compare and set, as the next call to select will update anyway if the routees remain different
168 | shardLocationAwareRouteeRef.compareAndSet(oldShardingRouteeTuple, (routees, newShardLocationAwareRoutees))
169 | newShardLocationAwareRoutees
170 | } else {
171 | oldShardLocationAwareRoutees
172 | }
173 | }
174 |
175 | val shardId: ShardId = extractShardId(message)
176 | val shardLocationAwareRoutees = updateShardLocationAwareRoutees()
177 |
178 | val candidateRoutees = for {
179 | location <- clusterShardingStateRef.get().get(shardId)
180 | locationAwareRoutees <- shardLocationAwareRoutees.get(location)
181 | } yield {
182 | val closeRoutees = locationAwareRoutees.map(_.routee)
183 |
184 | // pick one of the local routees at random
185 | closeRoutees(ThreadLocalRandom.current.nextInt(closeRoutees.size))
186 | }
187 |
188 | candidateRoutees.getOrElse {
189 | log.debug("Falling back to random routing for message in shard {}", shardId)
190 | // if we couldn't figure out the location of the shard, fall back to random routing
191 | routees(ThreadLocalRandom.current.nextInt(routees.size))
192 | }
193 | }
194 | }
195 |
196 | private def watchShardStateChanges(): Unit = {
197 | implicit val timeout: Timeout = Timeout(66, TimeUnit.DAYS) // effectively infinite, while staying below the scheduler's maximum supported delay (`^` is XOR in Scala, so `2 ^ 64` also evaluated to 66)
198 | val localitySel = system.actorSelection("/system/locality")
199 | val change: Future[ShardStateChanged] =
200 | (localitySel ? LocalitySupervisor.MonitorShards(shardRegion)).mapTo[ShardStateChanged]
201 | change
202 | .map { stateChanged =>
203 | if (stateChanged.newState.nonEmpty) {
204 | log.info("Updating cluster sharding state for {} shards", stateChanged.newState.keys.size)
205 | clusterShardingStateRef.set(stateChanged.newState)
206 | }
207 | watchShardStateChanges()
208 | }
209 | .recover {
210 | case _: AskTimeoutException =>
211 | // we were shutting down, ignore
212 | case NonFatal(t) =>
213 | log.warning("Could not monitor cluster sharding state: {}", t.getMessage)
214 | }
215 | }
216 | }
217 |
218 | private[locality] final case class ShardLocationAwareRoutee(routee: Routee, selfAddress: Address) {
219 | // extract the address of the routee. In case of a LocalActorRef, host and port are not provided
220 | // therefore we fall back to the address of the local node
221 | val address = {
222 | val routeeAddress = routee match {
223 | case ActorRefRoutee(ref) => ref.path.address
224 | case ActorSelectionRoutee(sel) => sel.anchorPath.address
225 | }
226 |
227 | routeeAddress match {
228 | case Address(_, system, None, None) => selfAddress.copy(system = system)
229 | case fullAddress => fullAddress
230 | }
231 | }
232 | }
233 |
--------------------------------------------------------------------------------
/src/main/scala/io/bernhardt/akka/locality/router/ShardStateMonitor.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality.router
2 |
3 | import akka.actor.{
4 | Actor,
5 | ActorIdentity,
6 | ActorLogging,
7 | ActorRef,
8 | Address,
9 | DeadLetterSuppression,
10 | Identify,
11 | Props,
12 | RootActorPath,
13 | Terminated,
14 | Timers
15 | }
16 | import akka.cluster.sharding.ShardRegion.{ ClusterShardingStats, GetClusterShardingStats, ShardId, ShardRegionStats }
17 | import io.bernhardt.akka.locality._
18 | import io.bernhardt.akka.locality.LocalitySupervisor.MonitorShards
19 |
20 | /**
21 | * Internal: watches shard actors in order to trigger an update. Only trigger the update when the system is stable for a while.
22 | */
23 | private[locality] class ShardStateMonitor(shardRegion: ActorRef, encodedRegionName: String, settings: LocalitySettings)
24 | extends Actor
25 | with ActorLogging
26 | with Timers {
27 | import ShardStateMonitor._
28 |
29 | val ClusterGuardianName: String =
30 | context.system.settings.config.getString("akka.cluster.sharding.guardian-name")
31 |
32 | var watchedShards = Set.empty[ShardId]
33 |
34 | var routerLogic: ActorRef = context.system.deadLetters
35 |
36 | var latestClusterState: Option[ClusterShardingStats] = None
37 |
38 | // technically, this may also just be a node that was terminated
39 | // but if the coordinator does its job, it will rebalance / reallocate the terminated shards
40 | // either way, this flag signals that the topology is currently changing
41 | var rebalanceInProgress: Boolean = false
42 |
43 | def receive: Receive = {
44 | case _: MonitorShards =>
45 | routerLogic = sender()
46 | log.debug("Starting to monitor shards for logic {}", routerLogic.path)
47 | requestClusterShardingState()
48 | timers.startPeriodicTimer(UpdateClusterState, UpdateClusterState, settings.ShardStatePollingInterval)
49 | context.become(watchingChanges)
50 | }
51 |
52 | def watchingChanges: Receive = {
53 | case _: MonitorShards =>
54 | routerLogic = sender()
55 | case UpdateClusterStateOnRebalance =>
56 | rebalanceInProgress = false
57 | requestClusterShardingState()
58 | case UpdateClusterState =>
59 | if (!rebalanceInProgress) {
60 | requestClusterShardingState()
61 | }
62 | case ActorIdentity(shardId: ShardId, Some(ref)) =>
63 | log.debug("Now watching shard {}", ref.path)
64 | context.watch(ref)
65 | watchedShards += shardId
66 | case ActorIdentity(shardId, None) => // couldn't get shard ref, not much we can do
67 | log.warning("Could not watch shard {}, shard location aware routing may not work", shardId)
68 | case Terminated(ref) =>
69 | log.debug("Watched shard actor {} terminated", ref.path)
70 | rebalanceInProgress = true
71 | watchedShards -= encodeShardId(ref.path.name)
72 | // reset the timer - we only want to request state once things are stable
73 | timers.cancel(UpdateClusterStateOnRebalance)
74 | timers.startSingleTimer(
75 | UpdateClusterStateOnRebalance,
76 | UpdateClusterStateOnRebalance,
77 | settings.ShardStateUpdateMargin)
78 | case stats @ ClusterShardingStats(regions) =>
79 | log.debug("Received cluster sharding stats for {} regions", regions.size)
80 | if (!latestClusterState.contains(stats)) {
81 | log.debug("Cluster sharding state changed, notifying subscriber")
82 | latestClusterState = Some(stats)
83 | if (regions.isEmpty) {
84 | log.warning("Cluster Sharding Stats empty - locality-aware routing will not function correctly")
85 | } else {
86 | notifyShardStateChanged(regions)
87 | watchShards(regions)
88 | }
89 | }
90 | }
91 |
92 | def requestClusterShardingState(): Unit = {
93 | log.debug("Requesting cluster state update")
94 | shardRegion ! GetClusterShardingStats(settings.RetrieveShardStateTimeout)
95 | }
96 |
97 | def watchShards(regions: Map[Address, ShardRegionStats]): Unit = {
98 | regions.foreach {
99 | case (address, regionStats) =>
100 | val regionPath = RootActorPath(address) / "system" / ClusterGuardianName / encodedRegionName
101 | regionStats.stats.keys.filterNot(watchedShards).foreach { shardId =>
102 | val shardPath = regionPath / encodeShardId(shardId)
103 | context.actorSelection(shardPath) ! Identify(shardId)
104 | }
105 | }
106 | }
107 |
108 | def notifyShardStateChanged(regions: Map[Address, ShardRegionStats]): Unit = {
109 | val shardsByAddress = regions.flatMap {
110 | case (address, ShardRegionStats(shards)) =>
111 | shards.map {
112 | case (shardId, _) =>
113 | shardId -> address
114 | }
115 | }
116 | routerLogic ! ShardStateChanged(shardsByAddress)
117 | }
118 |
119 | override def postStop(): Unit = {
120 | routerLogic ! ShardStateChanged(Map.empty)
121 | }
122 | }
123 |
124 | object ShardStateMonitor {
125 | final case class ShardStateChanged(newState: Map[ShardId, Address]) extends DeadLetterSuppression
126 | final case object UpdateClusterState extends DeadLetterSuppression
127 | final case object UpdateClusterStateOnRebalance extends DeadLetterSuppression
128 |
129 | private[locality] def props(shardRegion: ActorRef, entityName: String, settings: LocalitySettings) =
130 | Props(new ShardStateMonitor(shardRegion, entityName, settings))
131 | }
132 |
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/README.md:
--------------------------------------------------------------------------------
1 | Files in this package are copied verbatim (with minor adaptations) from Akka
2 | so as to re-use the existing multi-node cluster sharding spec.
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/cluster/FailureDetectorPuppet.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2009-2019 Lightbend Inc.
3 | */
4 |
5 | package akka.cluster
6 |
7 | import java.util.concurrent.atomic.AtomicReference
8 |
9 | import akka.remote.FailureDetector
10 | import com.typesafe.config.Config
11 | import akka.event.EventStream
12 | import akka.util.unused
13 |
14 | /**
15 | * User controllable "puppet" failure detector.
16 | */
17 | class FailureDetectorPuppet(@unused config: Config, @unused ev: EventStream) extends FailureDetector {
18 |
19 | trait Status
20 | object Up extends Status
21 | object Down extends Status
22 | object Unknown extends Status
23 |
24 | private val status: AtomicReference[Status] = new AtomicReference(Unknown)
25 |
26 | def markNodeAsUnavailable(): Unit = status.set(Down)
27 |
28 | def markNodeAsAvailable(): Unit = status.set(Up)
29 |
30 | override def isAvailable: Boolean = status.get match {
31 | case Unknown | Up => true
32 | case Down => false
33 | }
34 |
35 | override def isMonitoring: Boolean = status.get != Unknown
36 |
37 | override def heartbeat(): Unit = status.compareAndSet(Unknown, Up)
38 |
39 | }
40 |
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/cluster/MultiNodeClusterSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2009-2019 Lightbend Inc.
3 | */
4 |
5 | package akka.cluster
6 |
7 | import java.util.UUID
8 | import java.util.concurrent.ConcurrentHashMap
9 |
10 | import akka.actor.{ Actor, ActorRef, ActorSystem, Address, Deploy, PoisonPill, Props, RootActorPath }
11 | import akka.cluster.ClusterEvent.{ MemberEvent, MemberRemoved }
12 | import akka.event.Logging.ErrorLevel
13 | import akka.remote.DefaultFailureDetectorRegistry
14 | import akka.remote.testconductor.RoleName
15 | import akka.remote.testkit.{ MultiNodeSpec, STMultiNodeSpec }
16 | import akka.serialization.jackson.CborSerializable
17 | import akka.testkit.TestEvent._
18 | import akka.testkit._
19 | import akka.util.ccompat._
20 | import com.typesafe.config.{ Config, ConfigFactory }
21 | import org.scalatest.exceptions.TestCanceledException
22 | import org.scalatest.{ Canceled, Outcome, Suite }
23 |
24 | import scala.collection.immutable
25 | import scala.concurrent.Await
26 | import scala.concurrent.duration._
27 | import scala.language.implicitConversions
28 |
29 | @ccompatUsedUntil213
30 | object MultiNodeClusterSpec {
31 |
32 | def clusterConfigWithFailureDetectorPuppet: Config =
33 | ConfigFactory
34 | .parseString("akka.cluster.failure-detector.implementation-class = akka.cluster.FailureDetectorPuppet")
35 | .withFallback(clusterConfig)
36 |
37 | def clusterConfig(failureDetectorPuppet: Boolean): Config =
38 | if (failureDetectorPuppet) clusterConfigWithFailureDetectorPuppet else clusterConfig
39 |
40 | def clusterConfig: Config = ConfigFactory.parseString(s"""
41 | akka.actor.provider = cluster
42 | akka.actor.warn-about-java-serializer-usage = off
43 | akka.cluster {
44 | jmx.enabled = off
45 | gossip-interval = 200 ms
46 | leader-actions-interval = 200 ms
47 | unreachable-nodes-reaper-interval = 500 ms
48 | periodic-tasks-initial-delay = 300 ms
49 | publish-stats-interval = 0 s # always, when it happens
50 | failure-detector.heartbeat-interval = 500 ms
51 | run-coordinated-shutdown-when-down = off
52 |
53 | sharding {
54 | retry-interval = 200ms
55 | waiting-for-state-timeout = 200ms
56 | }
57 | }
58 | akka.loglevel = INFO
59 | akka.log-dead-letters = off
60 | akka.log-dead-letters-during-shutdown = off
61 | akka.remote {
62 | log-remote-lifecycle-events = off
63 | artery.advanced.flight-recorder {
64 | enabled=on
65 | destination=target/flight-recorder-${UUID.randomUUID().toString}.afr
66 | }
67 | }
68 | akka.loggers = ["akka.testkit.TestEventListener"]
69 | akka.test {
70 | single-expect-default = 5 s
71 | }
72 |
73 | """)
74 |
75 | // sometimes we need to coordinate test shutdown with messages instead of barriers
76 | object EndActor {
77 | case object SendEnd extends CborSerializable
78 | case object End extends CborSerializable
79 | case object EndAck extends CborSerializable
80 | }
81 |
82 | class EndActor(testActor: ActorRef, target: Option[Address]) extends Actor {
83 | import EndActor._
84 | def receive: Receive = {
85 | case SendEnd =>
86 | target.foreach { t =>
87 | context.actorSelection(RootActorPath(t) / self.path.elements) ! End
88 | }
89 | case End =>
90 | testActor.forward(End)
91 | sender() ! EndAck
92 | case EndAck =>
93 | testActor.forward(EndAck)
94 | }
95 | }
96 | }
97 |
98 | trait MultiNodeClusterSpec extends Suite with STMultiNodeSpec {
99 | self: MultiNodeSpec =>
100 |
101 | override def initialParticipants = roles.size
102 |
103 | private val cachedAddresses = new ConcurrentHashMap[RoleName, Address]
104 |
105 | override protected def atStartup(): Unit = {
106 | muteLog()
107 | self.atStartup()
108 | }
109 |
110 | override protected def afterTermination(): Unit = {
111 | self.afterTermination()
112 | }
113 |
114 | def muteLog(sys: ActorSystem = system): Unit = {
115 | if (!sys.log.isDebugEnabled) {
116 | Seq(
117 | ".*Cluster Node.* - registered cluster JMX MBean.*",
118 | ".*Cluster Node.* - is starting up.*",
119 | ".*Shutting down cluster Node.*",
120 | ".*Cluster node successfully shut down.*",
121 | ".*Using a dedicated scheduler for cluster.*").foreach { s =>
122 | sys.eventStream.publish(Mute(EventFilter.info(pattern = s)))
123 | }
124 |
125 | muteDeadLetters(
126 | classOf[ClusterHeartbeatSender.Heartbeat],
127 | classOf[ClusterHeartbeatSender.HeartbeatRsp],
128 | classOf[GossipEnvelope],
129 | classOf[GossipStatus],
130 | classOf[InternalClusterAction.Tick],
131 | classOf[akka.actor.PoisonPill],
132 | classOf[akka.dispatch.sysmsg.DeathWatchNotification],
133 | classOf[akka.remote.transport.AssociationHandle.Disassociated],
134 | // akka.remote.transport.AssociationHandle.Disassociated.getClass,
135 | classOf[akka.remote.transport.ActorTransportAdapter.DisassociateUnderlying],
136 | // akka.remote.transport.ActorTransportAdapter.DisassociateUnderlying.getClass,
137 | classOf[akka.remote.transport.AssociationHandle.InboundPayload])(sys)
138 |
139 | }
140 | }
141 |
142 | def muteMarkingAsUnreachable(sys: ActorSystem = system): Unit =
143 | if (!sys.log.isDebugEnabled)
144 | sys.eventStream.publish(Mute(EventFilter.error(pattern = ".*Marking.* as UNREACHABLE.*")))
145 |
146 | def muteMarkingAsReachable(sys: ActorSystem = system): Unit =
147 | if (!sys.log.isDebugEnabled)
148 | sys.eventStream.publish(Mute(EventFilter.info(pattern = ".*Marking.* as REACHABLE.*")))
149 |
150 | override def afterAll(): Unit = {
151 | if (!log.isDebugEnabled) {
152 | muteDeadLetters()()
153 | system.eventStream.setLogLevel(ErrorLevel)
154 | }
155 | super.afterAll()
156 | }
157 |
158 | /**
159 | * Lookup the Address for the role.
160 | *
161 | * Implicit conversion from RoleName to Address.
162 | *
163 | * It is cached, which has the implication that stopping
164 | * and then restarting a role (jvm) with another address is not
165 | * supported.
166 | */
167 | implicit def address(role: RoleName): Address = {
168 | cachedAddresses.get(role) match {
169 | case null =>
170 | val address = node(role).address
171 | cachedAddresses.put(role, address)
172 | address
173 | case address => address
174 | }
175 | }
176 |
177 | // Cluster tests are written so that if previous step (test method) failed
178 | // it will most likely not be possible to run next step. This ensures
179 | // fail fast of steps after the first failure.
180 | private var failed = false
181 | override protected def withFixture(test: NoArgTest): Outcome =
182 | if (failed) {
183 | Canceled(new TestCanceledException("Previous step failed", 0))
184 | } else {
185 | val out = super.withFixture(test)
186 | if (!out.isSucceeded)
187 | failed = true
188 | out
189 | }
190 |
191 | def clusterView: ClusterReadView = cluster.readView
192 |
193 | /**
194 | * Get the cluster node to use.
195 | */
196 | def cluster: Cluster = Cluster(system)
197 |
198 | /**
199 | * Use this method for the initial startup of the cluster node.
200 | */
201 | def startClusterNode(): Unit = {
202 | if (clusterView.members.isEmpty) {
203 | cluster.join(myself)
204 | awaitAssert(clusterView.members.map(_.address) should contain(address(myself)))
205 | } else
206 | clusterView.self
207 | }
208 |
209 | /**
210 | * Initialize the cluster of the specified member
211 | * nodes (roles) and wait until all joined and `Up`.
212 | * First node will be started first and others will join
213 | * the first.
214 | */
215 | def awaitClusterUp(roles: RoleName*): Unit = {
216 | runOn(roles.head) {
217 | // make sure that the node-to-join is started before other join
218 | startClusterNode()
219 | }
220 | enterBarrier(roles.head.name + "-started")
221 | if (roles.tail.contains(myself)) {
222 | cluster.join(roles.head)
223 | }
224 | if (roles.contains(myself)) {
225 | awaitMembersUp(numberOfMembers = roles.length)
226 | }
227 | enterBarrier(roles.map(_.name).mkString("-") + "-joined")
228 | }
229 |
230 | /**
231 | * Join the specific node within the given period by sending repeated join
232 | * requests at periodic intervals until we succeed.
233 | */
234 | def joinWithin(joinNode: RoleName, max: Duration = remainingOrDefault, interval: Duration = 1.second): Unit = {
235 | def memberInState(member: Address, status: Seq[MemberStatus]): Boolean =
236 | clusterView.members.exists { m =>
237 | (m.address == member) && status.contains(m.status)
238 | }
239 |
240 | cluster.join(joinNode)
241 | awaitCond(
242 | {
243 | if (memberInState(joinNode, List(MemberStatus.Up)) &&
244 | memberInState(myself, List(MemberStatus.Joining, MemberStatus.Up)))
245 | true
246 | else {
247 | cluster.join(joinNode)
248 | false
249 | }
250 | },
251 | max,
252 | interval)
253 | }
254 |
255 | /**
256 | * Assert that the member addresses match the expected addresses in the
257 | * sort order used by the cluster.
258 | */
259 | def assertMembers(gotMembers: Iterable[Member], expectedAddresses: Address*): Unit = {
260 | import Member.addressOrdering
261 | val members = gotMembers.toIndexedSeq
262 | members.size should ===(expectedAddresses.length)
263 | expectedAddresses.sorted.zipWithIndex.foreach { case (a, i) => members(i).address should ===(a) }
264 | }
265 |
266 | /**
267 | * Note that this can only be used for a cluster with all members
268 | * in Up status, i.e. use `awaitMembersUp` before using this method.
269 | * The reason for that is that the cluster leader is preferably a
270 | * member with status Up or Leaving and that information can't
271 | * be determined from the `RoleName`.
272 | */
273 | def assertLeader(nodesInCluster: RoleName*): Unit =
274 | if (nodesInCluster.contains(myself)) assertLeaderIn(nodesInCluster.to(immutable.Seq))
275 |
276 | /**
277 | * Assert that the cluster has elected the correct leader
278 | * out of all nodes in the cluster. First
279 | * member in the cluster ring is expected leader.
280 | *
281 | * Note that this can only be used for a cluster with all members
282 | * in Up status, i.e. use `awaitMembersUp` before using this method.
283 | * The reason for that is that the cluster leader is preferably a
284 | * member with status Up or Leaving and that information can't
285 | * be determined from the `RoleName`.
286 | */
287 | def assertLeaderIn(nodesInCluster: immutable.Seq[RoleName]): Unit =
288 | if (nodesInCluster.contains(myself)) {
289 | nodesInCluster.length should not be (0)
290 | val expectedLeader = roleOfLeader(nodesInCluster)
291 | val leader = clusterView.leader
292 | val isLeader = leader == Some(clusterView.selfAddress)
293 | assert(
294 | isLeader == isNode(expectedLeader),
295 | "expectedLeader [%s], got leader [%s], members [%s]".format(expectedLeader, leader, clusterView.members))
296 | clusterView.status should (be(MemberStatus.Up).or(be(MemberStatus.Leaving)))
297 | }
298 |
299 | /**
300 | * Wait until the expected number of members with status Up has been reached.
301 | * Also asserts that nodes in the 'canNotBePartOfMemberRing' are *not* part of the cluster ring.
302 | */
303 | def awaitMembersUp(
304 | numberOfMembers: Int,
305 | canNotBePartOfMemberRing: Set[Address] = Set.empty,
306 | timeout: FiniteDuration = 25.seconds): Unit = {
307 | within(timeout) {
308 | if (!canNotBePartOfMemberRing.isEmpty) // don't run this on an empty set
309 | awaitAssert(canNotBePartOfMemberRing.foreach(a => clusterView.members.map(_.address) should not contain (a)))
310 | awaitAssert(clusterView.members.size should ===(numberOfMembers))
311 | awaitAssert(clusterView.members.unsorted.map(_.status) should ===(Set(MemberStatus.Up)))
312 | // clusterView.leader is updated by LeaderChanged, await that to be updated also
313 | val expectedLeader = clusterView.members.collectFirst {
314 | case m if m.dataCenter == cluster.settings.SelfDataCenter => m.address
315 | }
316 | awaitAssert(clusterView.leader should ===(expectedLeader))
317 | }
318 | }
319 |
320 | def awaitMemberRemoved(toBeRemovedAddress: Address, timeout: FiniteDuration = 25.seconds): Unit = within(timeout) {
321 | if (toBeRemovedAddress == cluster.selfAddress) {
322 | enterBarrier("registered-listener")
323 |
324 | cluster.leave(toBeRemovedAddress)
325 | enterBarrier("member-left")
326 |
327 | awaitCond(cluster.isTerminated, remaining)
328 | enterBarrier("member-shutdown")
329 | } else {
330 | val exitingLatch = TestLatch()
331 |
332 | val awaiter = system.actorOf(Props(new Actor {
333 | def receive = {
334 | case MemberRemoved(m, _) if m.address == toBeRemovedAddress =>
335 | exitingLatch.countDown()
336 | case _ =>
337 | // ignore
338 | }
339 | }).withDeploy(Deploy.local))
340 | cluster.subscribe(awaiter, classOf[MemberEvent])
341 | enterBarrier("registered-listener")
342 |
343 | // in the meantime member issues leave
344 | enterBarrier("member-left")
345 |
346 | // verify that the member has been removed
347 | try Await.result(exitingLatch, timeout)
348 | catch {
349 | case cause: Exception =>
350 | throw new AssertionError(s"Member ${toBeRemovedAddress} was not removed within ${timeout}!", cause)
351 | }
352 | awaiter ! PoisonPill // you've done your job, now die
353 |
354 | enterBarrier("member-shutdown")
355 | markNodeAsUnavailable(toBeRemovedAddress)
356 | }
357 |
358 | enterBarrier("member-totally-shutdown")
359 | }
360 |
361 | def awaitAllReachable(): Unit =
362 | awaitAssert(clusterView.unreachableMembers should ===(Set.empty))
363 |
364 | /**
365 | * Wait until the specified nodes have seen the same gossip overview.
366 | */
367 | def awaitSeenSameState(addresses: Address*): Unit =
368 | awaitAssert((addresses.toSet.diff(clusterView.seenBy)) should ===(Set.empty))
369 |
370 | /**
371 | * Leader according to the address ordering of the roles.
372 | * Note that this can only be used for a cluster with all members
373 | * in Up status, i.e. use `awaitMembersUp` before using this method.
374 | * The reason for that is that the cluster leader is preferably a
375 | * member with status Up or Leaving and that information can't
376 | * be determined from the `RoleName`.
377 | */
378 | def roleOfLeader(nodesInCluster: immutable.Seq[RoleName] = roles): RoleName = {
379 | nodesInCluster.length should not be (0)
380 | nodesInCluster.sorted.head
381 | }
382 |
383 | /**
384 | * Sort the roles in the address order used by the cluster node ring.
385 | */
386 | implicit val clusterOrdering: Ordering[RoleName] = new Ordering[RoleName] {
387 | import Member.addressOrdering
388 | def compare(x: RoleName, y: RoleName) = addressOrdering.compare(address(x), address(y))
389 | }
390 |
391 | def roleName(addr: Address): Option[RoleName] = roles.find(address(_) == addr)
392 |
393 | /**
394 | * Marks a node as available in the failure detector if
395 | * [[akka.cluster.FailureDetectorPuppet]] is used as
396 | * failure detector.
397 | */
398 | def markNodeAsAvailable(address: Address): Unit =
399 | failureDetectorPuppet(address).foreach(_.markNodeAsAvailable())
400 |
401 | /**
402 | * Marks a node as unavailable in the failure detector if
403 | * [[akka.cluster.FailureDetectorPuppet]] is used as
404 | * failure detector.
405 | */
406 | def markNodeAsUnavailable(address: Address): Unit = {
407 | if (isFailureDetectorPuppet) {
408 | // before marking it as unavailable there should be at least one heartbeat
409 | // to create the FailureDetectorPuppet in the FailureDetectorRegistry
410 | cluster.failureDetector.heartbeat(address)
411 | failureDetectorPuppet(address).foreach(_.markNodeAsUnavailable())
412 | }
413 | }
414 |
415 | private def isFailureDetectorPuppet: Boolean =
416 | cluster.settings.FailureDetectorImplementationClass == classOf[FailureDetectorPuppet].getName
417 |
418 | private def failureDetectorPuppet(address: Address): Option[FailureDetectorPuppet] =
419 | cluster.failureDetector match {
420 | case reg: DefaultFailureDetectorRegistry[Address] =>
421 | reg.failureDetector(address).collect { case p: FailureDetectorPuppet => p }
422 | case _ => None
423 | }
424 |
425 | }
426 |
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/cluster/sharding/MultiNodeClusterShardingConfig.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 Lightbend Inc.
3 | */
4 |
5 | package akka.cluster.sharding
6 |
7 | import akka.cluster.MultiNodeClusterSpec
8 | import akka.remote.testkit.MultiNodeConfig
9 | import akka.testkit.AkkaSpec
10 | import com.typesafe.config.{Config, ConfigFactory}
11 |
12 | /**
13 | * A MultiNodeConfig for ClusterSharding. Implement the roles, etc. and create with the following:
14 | *
15 | * @param mode the state store mode
16 | * @param rememberEntities defaults to off
17 | * @param overrides additional config
18 | * @param loglevel defaults to INFO
19 | */
20 | abstract class MultiNodeClusterShardingConfig(
21 | val mode: String = ClusterShardingSettings.StateStoreModeDData,
22 | val rememberEntities: Boolean = false,
23 | overrides: Config = ConfigFactory.empty,
24 | loglevel: String = "INFO")
25 | extends MultiNodeConfig {
26 |
27 | val targetDir = s"target/ClusterSharding${AkkaSpec.getCallerName(getClass)}Spec-$mode-remember-$rememberEntities"
28 |
29 | val modeConfig =
30 | if (mode == ClusterShardingSettings.StateStoreModeDData) ConfigFactory.empty
31 | else ConfigFactory.parseString(s"""
32 | akka.persistence.journal.plugin = "akka.persistence.journal.leveldb-shared"
33 | akka.persistence.journal.leveldb-shared.timeout = 5s
34 | akka.persistence.journal.leveldb-shared.store.native = off
35 | akka.persistence.journal.leveldb-shared.store.dir = "$targetDir/journal"
36 | akka.persistence.snapshot-store.plugin = "akka.persistence.snapshot-store.local"
37 | akka.persistence.snapshot-store.local.dir = "$targetDir/snapshots"
38 | """)
39 |
40 | commonConfig(
41 | overrides
42 | .withFallback(modeConfig)
43 | .withFallback(ConfigFactory.parseString(s"""
44 | akka.loglevel = $loglevel
45 | akka.actor.provider = "cluster"
46 | akka.cluster.downing-provider-class = akka.cluster.testkit.AutoDowning
47 | akka.cluster.testkit.auto-down-unreachable-after = 0s
48 | akka.remote.log-remote-lifecycle-events = off
49 | akka.cluster.sharding.state-store-mode = "$mode"
50 | akka.cluster.sharding.distributed-data.durable.lmdb {
51 | dir = $targetDir/sharding-ddata
52 | map-size = 10 MiB
53 | }
54 | akka.actor.allow-java-serialization = on
55 | """))
56 | // .withFallback(SharedLeveldbJournal.configToEnableJavaSerializationForTest)
57 | .withFallback(MultiNodeClusterSpec.clusterConfig))
58 |
59 | }
60 |
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/cluster/sharding/MultiNodeClusterShardingSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2019 Lightbend Inc.
3 | */
4 |
5 | package akka.cluster.sharding
6 |
7 | import java.io.File
8 |
9 | import scala.concurrent.duration._
10 | import akka.actor.{Actor, ActorIdentity, ActorRef, ActorSystem, Identify, PoisonPill, Props}
11 | import akka.cluster.{Cluster, MemberStatus, MultiNodeClusterSpec}
12 | import akka.persistence.Persistence
13 | import akka.persistence.journal.leveldb.{SharedLeveldbJournal, SharedLeveldbStore}
14 | import akka.remote.testconductor.RoleName
15 | import akka.remote.testkit.MultiNodeSpec
16 | import akka.testkit.TestProbe
17 | import org.apache.commons.io.FileUtils
18 |
19 | object MultiNodeClusterShardingSpec {
20 |
21 | final case class EntityStarted(ref: ActorRef)
22 |
23 | def props(probe: ActorRef): Props = Props(new EntityActor(probe))
24 |
25 | class EntityActor(probe: ActorRef) extends Actor {
26 | probe ! EntityStarted(self)
27 |
28 | def receive: Receive = {
29 | case m => sender() ! m
30 | }
31 | }
32 |
33 | val defaultExtractEntityId: ShardRegion.ExtractEntityId = {
34 | case id: Int => (id.toString, id)
35 | }
36 |
37 | val defaultExtractShardId: ShardRegion.ExtractShardId = msg =>
38 | msg match {
39 | case id: Int => id.toString
40 | case ShardRegion.StartEntity(id) => id
41 | }
42 |
43 | }
44 |
45 | abstract class MultiNodeClusterShardingSpec(val config: MultiNodeClusterShardingConfig)
46 | extends MultiNodeSpec(config)
47 | with MultiNodeClusterSpec {
48 |
49 | import MultiNodeClusterShardingSpec._
50 | import config._
51 |
52 | override def initialParticipants: Int = roles.size
53 |
54 | protected val storageLocations = List(
55 | new File(system.settings.config.getString("akka.cluster.sharding.distributed-data.durable.lmdb.dir")).getParentFile)
56 |
57 | override protected def atStartup(): Unit = {
58 | storageLocations.foreach(dir => if (dir.exists) FileUtils.deleteQuietly(dir))
59 | enterBarrier("startup")
60 | super.atStartup()
61 | }
62 |
63 | override protected def afterTermination(): Unit = {
64 | storageLocations.foreach(dir => if (dir.exists) FileUtils.deleteQuietly(dir))
65 | super.afterTermination()
66 | }
67 |
68 | protected def join(from: RoleName, to: RoleName): Unit = {
69 | runOn(from) {
70 | Cluster(system).join(node(to).address)
71 | awaitCond {
72 | Cluster(system).state.members.exists(m => m.address == node(from).address && m.status == MemberStatus.Up)
73 | }
74 | }
75 | enterBarrier(from.name + "-joined")
76 | }
77 |
78 | protected def startSharding(
79 | sys: ActorSystem,
80 | entityProps: Props,
81 | dataType: String,
82 | extractEntityId: ShardRegion.ExtractEntityId = defaultExtractEntityId,
83 | extractShardId: ShardRegion.ExtractShardId = defaultExtractShardId,
84 | handOffStopMessage: Any = PoisonPill): ActorRef = {
85 |
86 | ClusterSharding(sys).start(
87 | typeName = dataType,
88 | entityProps = entityProps,
89 | settings = ClusterShardingSettings(sys).withRememberEntities(rememberEntities),
90 | extractEntityId = extractEntityId,
91 | extractShardId = extractShardId,
92 | ClusterSharding(sys).defaultShardAllocationStrategy(ClusterShardingSettings(sys)),
93 | handOffStopMessage)
94 | }
95 |
96 | protected def isDdataMode: Boolean = mode == ClusterShardingSettings.StateStoreModeDData
97 |
98 | private def setStoreIfNotDdataMode(sys: ActorSystem, storeOn: RoleName): Unit =
99 | if (!isDdataMode) {
100 | val probe = TestProbe()(sys)
101 | sys.actorSelection(node(storeOn) / "user" / "store").tell(Identify(None), probe.ref)
102 | val sharedStore = probe.expectMsgType[ActorIdentity](20.seconds).ref.get
103 | SharedLeveldbJournal.setStore(sharedStore, sys)
104 | }
105 |
106 | /**
107 | * {{{
108 | * startPersistenceIfNotDdataMode(startOn = first, setStoreOn = Seq(first, second, third))
109 | * }}}
110 | *
111 | * @param startOn the node to start the `SharedLeveldbStore` store on
112 | * @param setStoreOn the nodes to `SharedLeveldbJournal.setStore` on
113 | */
114 | protected def startPersistenceIfNotDdataMode(startOn: RoleName, setStoreOn: Seq[RoleName]): Unit =
115 | if (!isDdataMode) {
116 |
117 | Persistence(system)
118 | runOn(startOn) {
119 | system.actorOf(Props[SharedLeveldbStore], "store")
120 | }
121 | enterBarrier("persistence-started")
122 |
123 | runOn(setStoreOn: _*) {
124 | setStoreIfNotDdataMode(system, startOn)
125 | }
126 |
127 | enterBarrier(s"after-${startOn.name}")
128 |
129 | }
130 |
131 | }
132 |
133 |
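134 | // A minimal usage sketch (hypothetical spec, mirroring the router specs in this repository):
135 | //
136 | //   class MyShardingSpec extends MultiNodeClusterShardingSpec(MyShardingSpecConfig) with STMultiNodeSpec {
137 | //     import MultiNodeClusterShardingSpec._
138 | //     import MyShardingSpecConfig._
139 | //
140 | //     "Cluster sharding" must {
141 | //       "start on the first node" in {
142 | //         join(first, first) // the first node joins itself to form the cluster
143 | //         runOn(first) {
144 | //           startSharding(system, entityProps = props(testActor), dataType = "TestEntity")
145 | //         }
146 | //         enterBarrier("sharding-started")
147 | //       }
148 | //     }
149 | //   }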
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/cluster/testkit/AutoDowning.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2009-2019 Lightbend Inc.
3 | */
4 |
5 | package akka.cluster.testkit
6 |
7 | import scala.concurrent.duration.Duration
8 | import scala.concurrent.duration.FiniteDuration
9 |
10 | import akka.actor.Actor
11 | import akka.actor.ActorLogging
12 | import akka.actor.ActorSystem
13 | import akka.actor.Address
14 | import akka.actor.Cancellable
15 | import akka.actor.Props
16 | import akka.actor.Scheduler
17 | import akka.cluster.Cluster
18 | import akka.cluster.ClusterEvent._
19 | import akka.cluster.DowningProvider
20 | import akka.cluster.Member
21 | import akka.cluster.MembershipState
22 | import akka.cluster.UniqueAddress
23 | import akka.util.Helpers.ConfigOps
24 | import akka.util.Helpers.Requiring
25 | import akka.util.Helpers.toRootLowerCase
26 |
27 | /**
28 | * Downing provider used for testing.
29 | *
30 | * Auto-downing is a naïve approach to remove unreachable nodes from the cluster membership.
31 | * In a production environment it will eventually break down the cluster.
32 | * When a network partition occurs, both sides of the partition will see the other side as unreachable
33 | * and remove it from the cluster. This results in the formation of two separate, disconnected, clusters
34 | * (known as *Split Brain*).
35 | *
36 | * This behavior is not limited to network partitions. It can also occur if a node in the cluster is
37 | * overloaded, or experiences a long GC pause.
38 | *
39 | * When using Cluster Singleton or Cluster Sharding it can break the contract provided by those features.
40 | * Both provide a guarantee that an actor will be unique in a cluster.
41 | * With the auto-down feature enabled, it is possible for multiple independent clusters to form (*Split Brain*).
42 | * When this happens the guaranteed uniqueness will no longer be true resulting in undesirable behavior
43 | * in the system.
44 | *
45 | * This is even more severe when Akka Persistence is used in conjunction with Cluster Sharding.
46 | * In this case, the lack of unique actors can cause multiple actors to write to the same journal.
47 | * Akka Persistence operates on a single writer principle. Having multiple writers will corrupt
48 | * the journal and make it unusable.
49 | *
50 | * Finally, even if you don't use features such as Persistence, Sharding, or Singletons, auto-downing can lead the
51 | * system to form multiple small clusters. These small clusters will be independent of each other. They will be
52 | * unable to communicate and as a result you may experience performance degradation. Once this condition occurs,
53 | * it will require manual intervention in order to reform the cluster.
54 | *
55 | * Because of these issues, auto-downing should never be used in a production environment.
56 | */
57 | final class AutoDowning(system: ActorSystem) extends DowningProvider {
58 |
59 | private def clusterSettings = Cluster(system).settings
60 |
61 | private val AutoDownUnreachableAfter: Duration = {
62 | val key = "akka.cluster.testkit.auto-down-unreachable-after"
63 | // it's not in reference.conf since it is only used in tests
64 | if (clusterSettings.config.hasPath(key)) {
65 | toRootLowerCase(clusterSettings.config.getString(key)) match {
66 | case "off" => Duration.Undefined
67 | case _ => clusterSettings.config.getMillisDuration(key).requiring(_ >= Duration.Zero, key + " >= 0s, or off")
68 | }
69 | } else
70 | Duration.Undefined
71 | }
72 |
73 | override def downRemovalMargin: FiniteDuration = clusterSettings.DownRemovalMargin
74 |
75 | override def downingActorProps: Option[Props] =
76 | AutoDownUnreachableAfter match {
77 | case d: FiniteDuration => Some(AutoDown.props(d))
78 | case _ => None // auto-down-unreachable-after = off
79 | }
80 | }
81 |
82 | /**
83 | * INTERNAL API
84 | */
85 | private[cluster] object AutoDown {
86 |
87 | def props(autoDownUnreachableAfter: FiniteDuration): Props =
88 | Props(classOf[AutoDown], autoDownUnreachableAfter)
89 |
90 | final case class UnreachableTimeout(node: UniqueAddress)
91 | }
92 |
93 | /**
94 | * INTERNAL API
95 | *
96 | * An unreachable member will be downed by this actor if it remains unreachable
97 | * for the specified duration and this actor is running on the leader node in the
98 | * cluster.
99 | *
100 | * The implementation is split into two classes AutoDown and AutoDownBase to be
101 | * able to unit test the logic without running cluster.
102 | */
103 | private[cluster] class AutoDown(autoDownUnreachableAfter: FiniteDuration)
104 | extends AutoDownBase(autoDownUnreachableAfter)
105 | with ActorLogging {
106 |
107 | val cluster = Cluster(context.system)
108 | import cluster.ClusterLogger._
109 |
110 | override def selfAddress = cluster.selfAddress
111 |
112 | override def scheduler: Scheduler = cluster.scheduler
113 |
114 | // re-subscribe on restart
115 | override def preStart(): Unit = {
116 | log.debug("Auto-down is enabled in test.")
117 | cluster.subscribe(self, classOf[ClusterDomainEvent])
118 | super.preStart()
119 | }
120 | override def postStop(): Unit = {
121 | cluster.unsubscribe(self)
122 | super.postStop()
123 | }
124 |
125 | override def down(node: Address): Unit = {
126 | require(leader)
127 | logInfo("Leader is auto-downing unreachable node [{}].", node)
128 | cluster.down(node)
129 | }
130 |
131 | }
132 |
133 | /**
134 | * INTERNAL API
135 | *
136 | * The implementation is split into two classes AutoDown and AutoDownBase to be
137 | * able to unit test the logic without running cluster.
138 | */
139 | private[cluster] abstract class AutoDownBase(autoDownUnreachableAfter: FiniteDuration) extends Actor {
140 |
141 | import AutoDown._
142 |
143 | def selfAddress: Address
144 |
145 | def down(node: Address): Unit
146 |
147 | def scheduler: Scheduler
148 |
149 | import context.dispatcher
150 |
151 | val skipMemberStatus = MembershipState.convergenceSkipUnreachableWithMemberStatus
152 |
153 | var scheduledUnreachable: Map[UniqueAddress, Cancellable] = Map.empty
154 | var pendingUnreachable: Set[UniqueAddress] = Set.empty
155 | var leader = false
156 |
157 | override def postStop(): Unit = {
158 | scheduledUnreachable.values.foreach { _.cancel() }
159 | }
160 |
161 | def receive = {
162 | case state: CurrentClusterState =>
163 | leader = state.leader.exists(_ == selfAddress)
164 | state.unreachable.foreach(unreachableMember)
165 |
166 | case UnreachableMember(m) => unreachableMember(m)
167 |
168 | case ReachableMember(m) => remove(m.uniqueAddress)
169 | case MemberRemoved(m, _) => remove(m.uniqueAddress)
170 |
171 | case LeaderChanged(leaderOption) =>
172 | leader = leaderOption.exists(_ == selfAddress)
173 | if (leader) {
174 | pendingUnreachable.foreach(node => down(node.address))
175 | pendingUnreachable = Set.empty
176 | }
177 |
178 | case UnreachableTimeout(node) =>
179 | if (scheduledUnreachable contains node) {
180 | scheduledUnreachable -= node
181 | downOrAddPending(node)
182 | }
183 |
184 | case _: ClusterDomainEvent => // not interested in other events
185 |
186 | }
187 |
188 | def unreachableMember(m: Member): Unit =
189 | if (!skipMemberStatus(m.status) && !scheduledUnreachable.contains(m.uniqueAddress))
190 | scheduleUnreachable(m.uniqueAddress)
191 |
192 | def scheduleUnreachable(node: UniqueAddress): Unit = {
193 | if (autoDownUnreachableAfter == Duration.Zero) {
194 | downOrAddPending(node)
195 | } else {
196 | val task = scheduler.scheduleOnce(autoDownUnreachableAfter, self, UnreachableTimeout(node))
197 | scheduledUnreachable += (node -> task)
198 | }
199 | }
200 |
201 | def downOrAddPending(node: UniqueAddress): Unit = {
202 | if (leader) {
203 | down(node.address)
204 | } else {
205 | // it's supposed to be downed by another node (the current leader), but if that node crashes,
206 | // a new leader must pick these up
207 | pendingUnreachable += node
208 | }
209 | }
210 |
211 | def remove(node: UniqueAddress): Unit = {
212 | scheduledUnreachable.get(node).foreach { _.cancel() }
213 | scheduledUnreachable -= node
214 | pendingUnreachable -= node
215 | }
216 |
217 | }
218 |
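219 | // This provider is wired up in the test configuration of this repository (see
220 | // MultiNodeClusterShardingConfig above) via:
221 | //
222 | //   akka.cluster.downing-provider-class = akka.cluster.testkit.AutoDowning
223 | //   akka.cluster.testkit.auto-down-unreachable-after = 0s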
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/remote/testkit/STMultiNodeSpec.scala:
--------------------------------------------------------------------------------
1 | package akka.remote.testkit
2 |
3 | import org.scalatest.{BeforeAndAfterAll, WordSpecLike}
4 | import org.scalatest.Matchers
5 |
6 | /**
7 | * Hooks up MultiNodeSpec with ScalaTest
8 | */
9 | trait STMultiNodeSpec extends MultiNodeSpecCallbacks with WordSpecLike with Matchers with BeforeAndAfterAll {
10 |
11 | override def beforeAll(): Unit = multiNodeSpecBeforeAll()
12 |
13 | override def afterAll(): Unit = multiNodeSpecAfterAll()
14 | }
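15 |
16 | // A minimal usage sketch (hypothetical): concrete multi-JVM specs mix this trait in, e.g.
17 | //
18 | //   class MySpecMultiJvmNode1 extends MultiNodeSpec(MySpecConfig) with STMultiNodeSpec {
19 | //     override def initialParticipants: Int = roles.size
20 | //   }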
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/serialization/jackson/CborSerializable.scala:
--------------------------------------------------------------------------------
1 | package akka.serialization.jackson
2 |
3 | /**
4 | * Marker trait for test messages intended for Jackson CBOR serialization.
5 | */
6 | trait CborSerializable
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/testkit/AkkaSpec.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (C) 2009-2019 Lightbend Inc.
3 | */
4 |
5 | package akka.testkit
6 |
7 | import org.scalactic.{ CanEqual, TypeCheckedTripleEquals }
8 |
9 | import language.postfixOps
10 | import org.scalatest.{ BeforeAndAfterAll, WordSpecLike }
11 | import org.scalatest.Matchers
12 | import akka.actor.ActorSystem
13 | import akka.event.{ Logging, LoggingAdapter }
14 |
15 | import scala.concurrent.Future
16 | import com.typesafe.config.{ Config, ConfigFactory }
17 | import akka.dispatch.Dispatchers
18 | import akka.testkit.TestEvent._
19 | import org.scalatest.concurrent.ScalaFutures
20 | import org.scalatest.time.{ Millis, Span }
21 |
22 | object AkkaSpec {
23 | val testConf: Config = ConfigFactory.parseString("""
24 | akka {
25 | loggers = ["akka.testkit.TestEventListener"]
26 | loglevel = "WARNING"
27 | stdout-loglevel = "WARNING"
28 | actor {
29 | default-dispatcher {
30 | executor = "fork-join-executor"
31 | fork-join-executor {
32 | parallelism-min = 8
33 | parallelism-factor = 2.0
34 | parallelism-max = 8
35 | }
36 | }
37 | }
38 | }
39 | """)
40 |
41 | def mapToConfig(map: Map[String, Any]): Config = {
42 | import akka.util.ccompat.JavaConverters._
43 | ConfigFactory.parseMap(map.asJava)
44 | }
45 | /** Derives a sanitized test-system name from the first caller on the stack that is not test infrastructure. */
46 | def getCallerName(clazz: Class[_]): String = {
47 | val s = Thread.currentThread.getStackTrace
48 | .map(_.getClassName)
49 | .drop(1)
50 | .dropWhile(_.matches("(java.lang.Thread|.*AkkaSpec.*|.*\\.StreamSpec.*|.*MultiNodeSpec.*|.*\\.Abstract.*)"))
51 | val reduced = s.lastIndexWhere(_ == clazz.getName) match {
52 | case -1 => s
53 | case z => s.drop(z + 1)
54 | }
55 | reduced.head.replaceFirst(""".*\.""", "").replaceAll("[^a-zA-Z_0-9]", "_")
56 | }
57 |
58 | }
59 |
60 | abstract class AkkaSpec(_system: ActorSystem)
61 | extends TestKit(_system)
62 | with WordSpecLike
63 | with Matchers
64 | with BeforeAndAfterAll
65 | with TypeCheckedTripleEquals
66 | with ScalaFutures {
67 |
68 | implicit val patience: PatienceConfig = PatienceConfig(testKitSettings.DefaultTimeout.duration, Span(100, Millis))
69 |
70 | def this(config: Config) =
71 | this(ActorSystem(AkkaSpec.getCallerName(getClass), ConfigFactory.load(config.withFallback(AkkaSpec.testConf))))
72 |
73 | def this(s: String) = this(ConfigFactory.parseString(s))
74 |
75 | def this(configMap: Map[String, _]) = this(AkkaSpec.mapToConfig(configMap))
76 |
77 | def this() = this(ActorSystem(AkkaSpec.getCallerName(getClass), AkkaSpec.testConf))
78 |
79 | val log: LoggingAdapter = Logging(system, this.getClass)
80 |
81 | override val invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected = true
82 |
83 | final override def beforeAll(): Unit = {
84 | atStartup()
85 | }
86 |
87 | final override def afterAll(): Unit = {
88 | beforeTermination()
89 | shutdown()
90 | afterTermination()
91 | }
92 |
93 | protected def atStartup(): Unit = {}
94 |
95 | protected def beforeTermination(): Unit = {}
96 |
97 | protected def afterTermination(): Unit = {}
98 |
99 | def spawn(dispatcherId: String = Dispatchers.DefaultDispatcherId)(body: => Unit): Unit =
100 | Future(body)(system.dispatchers.lookup(dispatcherId))
101 |
102 | def muteDeadLetters(messageClasses: Class[_]*)(sys: ActorSystem = system): Unit =
103 | if (!sys.log.isDebugEnabled) {
104 | def mute(clazz: Class[_]): Unit =
105 | sys.eventStream.publish(Mute(DeadLettersFilter(clazz)(occurrences = Int.MaxValue)))
106 | if (messageClasses.isEmpty) mute(classOf[AnyRef])
107 | else messageClasses.foreach(mute)
108 | }
109 |
110 | // for ScalaTest === compare of Class objects
111 | implicit def classEqualityConstraint[A, B]: CanEqual[Class[A], Class[B]] =
112 | new CanEqual[Class[A], Class[B]] {
113 | def areEqual(a: Class[A], b: Class[B]) = a == b
114 | }
115 |
116 | implicit def setEqualityConstraint[A, T <: Set[_ <: A]]: CanEqual[Set[A], T] =
117 | new CanEqual[Set[A], T] {
118 | def areEqual(a: Set[A], b: T) = a == b
119 | }
120 | }
121 |
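122 | // A minimal usage sketch (hypothetical): the auxiliary constructors let specs pass config
123 | // overrides directly, e.g.
124 | //
125 | //   class MySpec extends AkkaSpec("akka.loglevel = DEBUG") {
126 | //     "my actor" must { ... }
127 | //   }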
--------------------------------------------------------------------------------
/src/multi-jvm/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRouterNewShardsSpec.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality.router
2 |
3 | import akka.actor.{ActorRef, Props}
4 | import akka.cluster.routing.{ClusterRouterGroup, ClusterRouterGroupSettings}
5 | import akka.cluster.sharding.{MultiNodeClusterShardingConfig, MultiNodeClusterShardingSpec}
6 | import akka.pattern.ask
7 | import akka.remote.testconductor.RoleName
8 | import akka.remote.testkit.STMultiNodeSpec
9 | import akka.routing.{GetRoutees, Routees}
10 | import akka.testkit.{DefaultTimeout, ImplicitSender, TestProbe}
11 | import com.typesafe.config.ConfigFactory
12 | import io.bernhardt.akka.locality.Locality
13 |
14 | import scala.concurrent.Await
15 | import scala.concurrent.duration._
16 |
17 | object ShardLocationAwareRouterNewShardsSpecConfig extends MultiNodeClusterShardingConfig(
18 | overrides = ConfigFactory.parseString(
19 | """akka.cluster.sharding.distributed-data.majority-min-cap = 2
20 | |akka.cluster.distributed-data.gossip-interval = 200 ms
21 | |akka.cluster.distributed-data.notify-subscribers-interval = 200 ms
22 | |akka.cluster.sharding.updating-state-timeout = 500 ms
23 | |akka.locality.shard-state-update-margin = 1000 ms
24 | |akka.locality.shard-state-polling-interval = 2000 ms
25 | |""".stripMargin)) {
26 | val first = role("first")
27 | val second = role("second")
28 | val third = role("third")
29 | val fourth = role("fourth")
30 | val fifth = role("fifth")
31 | }
32 |
33 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode1 extends ShardLocationAwareRouterNewShardsSpec
34 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode2 extends ShardLocationAwareRouterNewShardsSpec
35 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode3 extends ShardLocationAwareRouterNewShardsSpec
36 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode4 extends ShardLocationAwareRouterNewShardsSpec
37 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode5 extends ShardLocationAwareRouterNewShardsSpec
38 |
39 | class ShardLocationAwareRouterNewShardsSpec extends MultiNodeClusterShardingSpec(ShardLocationAwareRouterNewShardsSpecConfig)
40 | with STMultiNodeSpec
41 | with DefaultTimeout
42 | with ImplicitSender {
43 |
44 | import ShardLocationAwareRouterNewShardsSpecConfig._
45 | import ShardLocationAwareRouterSpec._
46 |
47 | var region: Option[ActorRef] = None
48 |
49 | var router: ActorRef = ActorRef.noSender
50 |
51 | Locality(system)
52 |
53 | def joinAndAllocate(node: RoleName, entityIds: Range): Unit = {
54 | within(10.seconds) {
55 | join(node, first)
56 | runOn(node) {
57 | val region = startSharding(
58 | sys = system,
59 | entityProps = Props[TestEntity],
60 | dataType = "TestEntity",
61 | extractEntityId = extractEntityId,
62 | extractShardId = extractShardId)
63 |
64 | this.region = Some(region)
65 |
66 | entityIds.foreach { entityId =>
67 | val probe = TestProbe("test")
68 | val msg = Ping(entityId, ActorRef.noSender)
69 | probe.send(region, msg)
70 | probe.expectMsgType[Pong]
71 | probe.lastSender.path should be(region.path / s"$entityId" / s"$entityId")
72 | }
73 | }
74 | }
75 | enterBarrier(s"started")
76 | }
77 |
78 | "allocate shards" in {
79 |
80 | joinAndAllocate(first, (1 to 10))
81 | joinAndAllocate(second, (11 to 20))
82 | joinAndAllocate(third, (21 to 30))
83 | joinAndAllocate(fourth, (31 to 40))
84 | joinAndAllocate(fifth, (41 to 50))
85 |
86 | enterBarrier("shards-allocated")
87 |
88 | }
89 |
90 | "route by taking into account shard location" in {
91 | within(20.seconds) {
92 |
93 | region.map { r =>
94 | system.actorOf(Props(new TestRoutee(r)), "routee")
95 | enterBarrier("routee-started")
96 |
97 | router = system.actorOf(ClusterRouterGroup(ShardLocationAwareGroup(
98 | routeePaths = Nil,
99 | shardRegion = r,
100 | extractEntityId = extractEntityId,
101 | extractShardId = extractShardId
102 | ), ClusterRouterGroupSettings(
103 | totalInstances = 5,
104 | routeesPaths = List("/user/routee"),
105 | allowLocalRoutees = true
106 | )).props(), "sharding-aware-router")
107 |
108 | awaitAssert {
109 | currentRoutees(router).size shouldBe 5
110 | }
111 |
112 | enterBarrier("router-started")
113 |
114 | runOn(first) {
115 | val probe = TestProbe("probe")
116 | for (i <- 1 to 50) {
117 | probe.send(router, Ping(i, probe.ref))
118 | }
119 |
120 | val msgs: Seq[Pong] = probe.receiveN(50).collect { case p: Pong => p }
121 |
122 | val (_, different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) =>
123 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
124 | }
125 |
126 | different.isEmpty shouldBe true
127 | }
128 | enterBarrier("test-done")
129 |
130 | } getOrElse {
131 | fail("Region not set")
132 | }
133 | }
134 |
135 | }
136 |
137 | "route randomly when new shards are created" in {
138 | runOn(first) {
139 | val probe = TestProbe("probe")
140 | for (i <- 51 to 100) {
141 | probe.send(router, Ping(i, probe.ref))
142 | }
143 |
144 | val msgs: Seq[Pong] = probe.receiveN(50).collect { case p: Pong => p }
145 |
146 | val (_, different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) =>
147 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
148 | }
149 |
150 | different.isEmpty shouldBe false
151 | }
152 | enterBarrier("test-done")
153 | }
154 |
155 | "route with location awareness after the update margin has elapsed" in {
156 |
157 | Thread.sleep(4000) // wait out the shard-state-update-margin (1s) and shard-state-polling-interval (2s) configured above, with slack
158 |
159 | runOn(first) {
160 | val probe = TestProbe("probe")
161 | for (i <- 51 to 100) {
162 | probe.send(router, Ping(i, probe.ref))
163 | }
164 |
165 | val msgs: Seq[Pong] = probe.receiveN(50).collect { case p: Pong => p }
166 |
167 | val (_, different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) =>
168 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
169 | }
170 |
171 | different.isEmpty shouldBe true
172 | }
173 | enterBarrier("test-done")
174 | }
175 |
176 |
177 |
178 | def currentRoutees(router: ActorRef) =
179 | Await.result(router ? GetRoutees, timeout.duration).asInstanceOf[Routees].routees
180 |
181 | def partitionByAddress(msgs: Seq[Pong]) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) =>
182 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
183 | }
184 |
185 |
186 | }
187 |
--------------------------------------------------------------------------------
/src/multi-jvm/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRouterSpec.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality.router
2 |
3 | import akka.actor.{Actor, ActorRef, Address, Props}
4 | import akka.cluster.Cluster
5 | import akka.cluster.routing.{ClusterRouterGroup, ClusterRouterGroupSettings}
6 | import akka.cluster.sharding.{MultiNodeClusterShardingConfig, MultiNodeClusterShardingSpec, ShardRegion}
7 | import akka.pattern.ask
8 | import akka.remote.testconductor.RoleName
9 | import akka.remote.testkit.STMultiNodeSpec
10 | import akka.routing.{GetRoutees, Routees}
11 | import akka.serialization.jackson.CborSerializable
12 | import akka.testkit.{DefaultTimeout, ImplicitSender, TestProbe}
13 | import com.typesafe.config.ConfigFactory
14 | import io.bernhardt.akka.locality.Locality
15 |
16 | import scala.concurrent.Await
17 | import scala.concurrent.duration._
18 |
19 | object ShardLocationAwareRouterSpec {
20 |
21 | class TestEntity extends Actor {
22 | val cluster = Cluster(context.system)
23 | def receive: Receive = {
24 |
25 | case ping: Ping =>
26 | sender() ! Pong(ping.id, ping.sender, cluster.selfAddress, cluster.selfAddress)
27 | }
28 | }
29 |
30 | final case class Ping(id: Int, sender: ActorRef) extends CborSerializable
31 | final case class Pong(id: Int, sender: ActorRef, routeeAddress: Address, entityAddress: Address) extends CborSerializable
32 |
33 | val extractEntityId: ShardRegion.ExtractEntityId = {
34 | case msg@Ping(id, _) => (id.toString, msg)
35 | case msg@Pong(id, _, _, _) => (id.toString, msg)
36 | }
37 |
38 | val extractShardId: ShardRegion.ExtractShardId = {
39 | // deliberately use the simplest possible mapping
40 | case Ping(id, _) => id.toString
41 | case Pong(id, _, _, _) => id.toString
42 | }
43 |
44 | class TestRoutee(region: ActorRef) extends Actor {
45 | val cluster = Cluster(context.system)
46 | def receive: Receive = {
47 | case msg: Ping =>
48 | region ! msg
49 | case pong: Pong =>
50 | pong.sender ! pong.copy(routeeAddress = cluster.selfAddress)
51 | }
52 | }
53 |
54 | }
55 |
56 |
57 | object ShardLocationAwareRouterSpecConfig extends MultiNodeClusterShardingConfig(
58 | overrides = ConfigFactory.parseString(
59 | """akka.cluster.sharding.distributed-data.majority-min-cap = 2
60 | |akka.cluster.distributed-data.gossip-interval = 200 ms
61 | |akka.cluster.distributed-data.notify-subscribers-interval = 200 ms
62 | |akka.cluster.sharding.updating-state-timeout = 500 ms
63 | |akka.locality.shard-state-update-margin = 1000 ms
64 | |""".stripMargin)) {
65 | val first = role("first")
66 | val second = role("second")
67 | val third = role("third")
68 | val fourth = role("fourth")
69 | val fifth = role("fifth")
70 | }
71 |
72 | class ShardLocationAwareRouterSpecMultiJvmNode1 extends ShardLocationAwareRouterSpec
73 | class ShardLocationAwareRouterSpecMultiJvmNode2 extends ShardLocationAwareRouterSpec
74 | class ShardLocationAwareRouterSpecMultiJvmNode3 extends ShardLocationAwareRouterSpec
75 | class ShardLocationAwareRouterSpecMultiJvmNode4 extends ShardLocationAwareRouterSpec
76 | class ShardLocationAwareRouterSpecMultiJvmNode5 extends ShardLocationAwareRouterSpec
77 |
78 | class ShardLocationAwareRouterSpec extends MultiNodeClusterShardingSpec(ShardLocationAwareRouterSpecConfig)
79 | with STMultiNodeSpec
80 | with DefaultTimeout
81 | with ImplicitSender {
82 |
83 | import ShardLocationAwareRouterSpec._
84 | import ShardLocationAwareRouterSpecConfig._
85 |
86 | var region: Option[ActorRef] = None
87 |
88 | var router: ActorRef = ActorRef.noSender
89 |
90 | Locality(system)
91 |
92 | def joinAndAllocate(node: RoleName, entityIds: Range): Unit = {
93 | within(10.seconds) {
94 | join(node, first)
95 | runOn(node) {
96 | val region = startSharding(
97 | sys = system,
98 | entityProps = Props[TestEntity],
99 | dataType = "TestEntity",
100 | extractEntityId = extractEntityId,
101 | extractShardId = extractShardId)
102 |
103 | this.region = Some(region)
104 |
105 | entityIds.foreach { entityId =>
106 | val probe = TestProbe("test")
107 | val msg = Ping(entityId, ActorRef.noSender)
108 | probe.send(region, msg)
109 | probe.expectMsgType[Pong]
110 | probe.lastSender.path should be(region.path / s"$entityId" / s"$entityId")
111 | }
112 | }
113 | }
114 | enterBarrier(s"started")
115 | }
116 |
117 | "allocate shards" in {
118 |
119 | joinAndAllocate(first, (1 to 10))
120 | joinAndAllocate(second, (11 to 20))
121 | joinAndAllocate(third, (21 to 30))
122 | joinAndAllocate(fourth, (31 to 40))
123 | joinAndAllocate(fifth, (41 to 50))
124 |
125 | enterBarrier("shards-allocated")
126 |
127 | }
128 |
129 | "route by taking into account shard location" in {
130 | within(20.seconds) {
131 |
132 | region.map { r =>
133 | system.actorOf(Props(new TestRoutee(r)), "routee")
134 | enterBarrier("routee-started")
135 |
136 | router = system.actorOf(ClusterRouterGroup(ShardLocationAwareGroup(
137 | routeePaths = Nil,
138 | shardRegion = r,
139 | extractEntityId = extractEntityId,
140 | extractShardId = extractShardId
141 | ), ClusterRouterGroupSettings(
142 | totalInstances = 5,
143 | routeesPaths = List("/user/routee"),
144 | allowLocalRoutees = true
145 | )).props(), "sharding-aware-router")
146 |
147 | awaitAssert {
148 | currentRoutees(router).size shouldBe 5
149 | }
150 |
151 | enterBarrier("router-started")
152 |
153 | runOn(first) {
154 | val probe = TestProbe("probe")
155 | for (i <- 1 to 50) {
156 | probe.send(router, Ping(i, probe.ref))
157 | }
158 |
159 | val msgs: Seq[Pong] = probe.receiveN(50).collect { case p: Pong => p }
160 |
161 | val (same: Seq[Pong], different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) =>
162 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
163 | }
164 |
165 | different.isEmpty shouldBe true
166 |
167 | val byAddress = same.groupBy(_.routeeAddress)
168 |
169 | byAddress(first).size shouldBe 10
170 | byAddress(first).map(_.id).toSet shouldEqual (1 to 10).toSet
171 | byAddress(second).size shouldBe 10
172 | byAddress(second).map(_.id).toSet shouldEqual (11 to 20).toSet
173 | byAddress(third).size shouldBe 10
174 | byAddress(third).map(_.id).toSet shouldEqual (21 to 30).toSet
175 | byAddress(fourth).size shouldBe 10
176 | byAddress(fourth).map(_.id).toSet shouldEqual (31 to 40).toSet
177 | byAddress(fifth).size shouldBe 10
178 | byAddress(fifth).map(_.id).toSet shouldEqual (41 to 50).toSet
179 | }
180 | enterBarrier("test-done")
181 |
182 | } getOrElse {
183 | fail("Region not set")
184 | }
185 | }
186 |
187 | }
188 |
189 | "adjust routing after a topology change" in {
190 | awaitMemberRemoved(fourth)
191 | awaitAllReachable()
192 |
193 | runOn(first) {
194 | // trigger rebalancing the shards of the removed node
195 | val rebalanceProbe = TestProbe("rebalance")
196 | for (i <- 31 to 40) {
197 | rebalanceProbe.send(router, Ping(i, rebalanceProbe.ref))
198 | }
199 |
200 | // we should be receiving messages even in the absence of the updated shard location information
201 | // random routing should kick in, i.e. we won't have perfect matches
202 | val randomRoutedMessages: Seq[Pong] = rebalanceProbe.receiveN(10, 15.seconds).collect { case p: Pong => p }
203 | val (_, differentMsgs) = partitionByAddress(randomRoutedMessages)
204 | differentMsgs.nonEmpty shouldBe true
205 |
206 | // now give the new shards time to be allocated and the router time to retrieve the new information:
207 | // one second for the updated shard information to be sent out, plus one second of safety margin
208 | Thread.sleep(2000)
209 |
210 | val probe = TestProbe("probe")
211 | for (i <- 1 to 50) {
212 | probe.send(router, Ping(i, probe.ref))
213 | }
214 |
215 | val msgs: Seq[Pong] = probe.receiveN(50, 15.seconds).collect { case p: Pong => p }
216 |
217 | val (_, different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) =>
218 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
219 | }
220 |
221 | different.isEmpty shouldBe true
222 |
223 | }
224 |
225 | enterBarrier("finished")
226 | }
227 |
228 |
229 | def currentRoutees(router: ActorRef) =
230 | Await.result(router ? GetRoutees, timeout.duration).asInstanceOf[Routees].routees
231 |
232 | def partitionByAddress(msgs: Seq[Pong]) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) =>
233 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
234 | }
235 |
236 | }
237 |
--------------------------------------------------------------------------------
/src/multi-jvm/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRouterWithProxySpec.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality.router
2 |
3 | import akka.actor.{ ActorRef, Props }
4 | import akka.cluster.routing.{ ClusterRouterGroup, ClusterRouterGroupSettings }
5 | import akka.cluster.sharding.{ ClusterSharding, MultiNodeClusterShardingConfig, MultiNodeClusterShardingSpec }
6 | import akka.pattern.ask
7 | import akka.remote.testconductor.RoleName
8 | import akka.remote.testkit.STMultiNodeSpec
9 | import akka.routing.{ GetRoutees, Routees }
10 | import akka.testkit.{ DefaultTimeout, ImplicitSender, TestProbe }
11 | import com.typesafe.config.ConfigFactory
12 | import io.bernhardt.akka.locality.Locality
13 |
14 | import scala.concurrent.Await
15 | import scala.concurrent.duration._
16 |
17 | object ShardLocationAwareRouterWithProxySpecConfig
18 | extends MultiNodeClusterShardingConfig(
19 | loglevel = "INFO",
20 | overrides = ConfigFactory.parseString(
21 | """akka.cluster.sharding.distributed-data.majority-min-cap = 2
22 | |akka.cluster.distributed-data.gossip-interval = 200 ms
23 | |akka.cluster.distributed-data.notify-subscribers-interval = 200 ms
24 | |akka.cluster.sharding.updating-state-timeout = 500 ms
25 | |akka.locality.shard-state-update-margin = 1000 ms
26 | |""".stripMargin)) {
27 | val first = role("first")
28 | val second = role("second")
29 | val third = role("third")
30 | val fourth = role("fourth")
31 | val fifth = role("fifth")
32 |
33 | }
34 |
35 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode1 extends ShardLocationAwareRouterWithProxySpec
36 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode2 extends ShardLocationAwareRouterWithProxySpec
37 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode3 extends ShardLocationAwareRouterWithProxySpec
38 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode4 extends ShardLocationAwareRouterWithProxySpec
39 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode5 extends ShardLocationAwareRouterWithProxySpec
40 |
41 | class ShardLocationAwareRouterWithProxySpec
42 | extends MultiNodeClusterShardingSpec(ShardLocationAwareRouterWithProxySpecConfig)
43 | with STMultiNodeSpec
44 | with DefaultTimeout
45 | with ImplicitSender {
46 |
47 | import ShardLocationAwareRouterSpec._
48 | import ShardLocationAwareRouterWithProxySpecConfig._
49 |
50 | var region: Option[ActorRef] = None
51 | var proxyRegion: Option[ActorRef] = None
52 |
53 | var router: ActorRef = ActorRef.noSender
54 | var proxyRouter: ActorRef = ActorRef.noSender
55 |
56 | Locality(system)
57 |
58 | def joinAndAllocate(node: RoleName, entityIds: Range): Unit = {
59 | within(10.seconds) {
60 | join(node, first)
61 | runOn(node) {
62 | val region = startSharding(
63 | sys = system,
64 | entityProps = Props[TestEntity],
65 | dataType = "TestEntity",
66 | extractEntityId = extractEntityId,
67 | extractShardId = extractShardId)
68 |
69 | this.region = Some(region)
70 |
71 | entityIds.foreach { entityId =>
72 | val probe = TestProbe("test")
73 | val msg = Ping(entityId, ActorRef.noSender)
74 | probe.send(region, msg)
75 | probe.expectMsgType[Pong]
76 | probe.lastSender.path should be(region.path / s"$entityId" / s"$entityId")
77 | }
78 | }
79 | }
80 | enterBarrier(s"started")
81 | }
82 |
83 | "allocate shards" in {
84 |
85 | joinAndAllocate(first, (1 to 10))
86 | joinAndAllocate(second, (11 to 20))
87 | joinAndAllocate(third, (21 to 30))
88 | joinAndAllocate(fourth, (31 to 40))
89 |
90 | join(fifth, first)
91 |
92 | runOn(fifth) {
93 | val proxy = ClusterSharding(system).startProxy(
94 | typeName = "TestEntity",
95 | role = None,
96 | dataCenter = None,
97 | extractEntityId = extractEntityId,
98 | extractShardId = extractShardId)
99 | proxyRegion = Some(proxy)
100 | }
101 |
102 | awaitClusterUp(roles: _*)
103 |
104 | enterBarrier("shards-allocated")
105 |
106 | }
107 |
108 | "route by taking into account shard location" in {
109 | within(20.seconds) {
110 |
111 | runOn(first, second, third, fourth) {
112 | region
113 | .map { r =>
114 | system.actorOf(Props(new TestRoutee(r)), "routee")
115 | enterBarrier("routee-started")
116 |
117 | router = system.actorOf(
118 | ClusterRouterGroup(
119 | ShardLocationAwareGroup(
120 | routeePaths = Nil,
121 | shardRegion = r,
122 | extractEntityId = extractEntityId,
123 | extractShardId = extractShardId),
124 | ClusterRouterGroupSettings(
125 | totalInstances = 4,
126 | routeesPaths = List("/user/routee"),
127 | allowLocalRoutees = true)).props(),
128 | "sharding-aware-router")
129 |
130 | awaitAssert {
131 | currentRoutees(router).size shouldBe 4
132 | }
133 | enterBarrier("router-started")
134 |
135 | }
136 | .getOrElse {
137 | fail("Region not set")
138 | }
139 |
140 | }
141 | runOn(fifth) {
142 |
143 | // no routees here
144 | enterBarrier("routee-started")
145 |
146 | proxyRegion
147 | .map { proxy =>
148 | proxyRouter = system.actorOf(
149 | ClusterRouterGroup(
150 | ShardLocationAwareGroup(
151 | routeePaths = Nil,
152 | shardRegion = proxy,
153 | extractEntityId = extractEntityId,
154 | extractShardId = extractShardId),
155 | ClusterRouterGroupSettings(
156 | totalInstances = 4,
157 | routeesPaths = List("/user/routee"),
158 | allowLocalRoutees = false)).props(),
159 | "sharding-aware-router")
160 | }
161 | .getOrElse {
162 | fail("Proxy region not set")
163 | }
164 |
165 | awaitAssert {
166 | currentRoutees(proxyRouter).size shouldBe 4
167 | }
168 | enterBarrier("router-started")
169 | }
170 |
171 | // now give the new shards time to be allocated and the router time to retrieve the new information:
172 | // one second for the updated shard information to be sent out, plus one second of safety margin
173 | Thread.sleep(2000)
174 |
175 | runOn(fifth) {
176 | val probe = TestProbe("probe")
177 | for (i <- 1 to 40) {
178 | probe.send(proxyRouter, Ping(i, probe.ref))
179 | }
180 |
181 | val msgs: Seq[Pong] = probe.receiveN(40).collect { case p: Pong => p }
182 |
183 | val (same: Seq[Pong], different) = msgs.partition {
184 | case Pong(_, _, routeeAddress, entityAddress) =>
185 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
186 | }
187 |
188 | different.isEmpty shouldBe true
189 |
190 | val byAddress = same.groupBy(_.routeeAddress)
191 |
192 | byAddress(first).size shouldBe 10
193 | byAddress(first).map(_.id).toSet shouldEqual (1 to 10).toSet
194 | byAddress(second).size shouldBe 10
195 | byAddress(second).map(_.id).toSet shouldEqual (11 to 20).toSet
196 | byAddress(third).size shouldBe 10
197 | byAddress(third).map(_.id).toSet shouldEqual (21 to 30).toSet
198 | byAddress(fourth).size shouldBe 10
199 | byAddress(fourth).map(_.id).toSet shouldEqual (31 to 40).toSet
200 | }
201 | enterBarrier("test-done")
202 |
203 | }
204 |
205 | }
206 |
207 | "adjust routing after a topology change" in {
208 | awaitMemberRemoved(fourth)
209 | awaitAllReachable()
210 | runOn(first) {
211 | testConductor.removeNode(fourth)
212 | }
213 |
214 | runOn(fifth) {
215 | // trigger rebalancing the shards of the removed node
216 | val rebalanceProbe = TestProbe("rebalance")
217 | for (i <- 31 to 40) {
218 | rebalanceProbe.send(proxyRouter, Ping(i, rebalanceProbe.ref))
219 | }
220 |
221 | // we should be receiving messages even in the absence of the updated shard location information
222 | // random routing should kick in, i.e. we won't have perfect matches
223 | val randomRoutedMessages: Seq[Pong] = rebalanceProbe.receiveN(10, 20.seconds).collect { case p: Pong => p }
224 | val (_, differentMsgs) = partitionByAddress(randomRoutedMessages)
225 | differentMsgs.nonEmpty shouldBe true
226 |
227 | // now give the new shards time to be allocated and the router time to retrieve the new information:
228 | // one second for the updated shard information to be sent out, plus one second of safety margin
229 | Thread.sleep(2000)
230 |
231 | val probe = TestProbe("probe")
232 | for (i <- 1 to 40) {
233 | probe.send(proxyRouter, Ping(i, probe.ref))
234 | }
235 |
236 | val msgs: Seq[Pong] = probe.receiveN(40, 10.seconds).collect { case p: Pong => p }
237 |
238 | val (_, different) = msgs.partition {
239 | case Pong(_, _, routeeAddress, entityAddress) =>
240 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
241 | }
242 |
243 | different.isEmpty shouldBe true
244 |
245 | }
246 |
247 | runOn(first, second, third, fifth) {
248 | enterBarrier("finished")
249 |
250 | }
251 |
252 | }
253 |
254 | def currentRoutees(router: ActorRef) =
255 | Await.result(router ? GetRoutees, timeout.duration).asInstanceOf[Routees].routees
256 |
257 | def partitionByAddress(msgs: Seq[Pong]) = msgs.partition {
258 | case Pong(_, _, routeeAddress, entityAddress) =>
259 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
260 | }
261 |
262 | }
263 |
--------------------------------------------------------------------------------
/src/test/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRoutingLogicSpec.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality.router
2 |
3 | import akka.actor.ActorSystem
4 | import akka.cluster.sharding.ShardRegion
5 | import akka.cluster.sharding.ShardRegion.{ ClusterShardingStats, GetClusterShardingStats, ShardRegionStats }
6 | import akka.routing.ActorRefRoutee
7 | import akka.testkit.{ TestKit, TestProbe }
8 | import io.bernhardt.akka.locality.Locality
9 | import org.scalatest.{ BeforeAndAfterAll, Matchers, WordSpecLike }
10 |
11 | import scala.collection.immutable.IndexedSeq
12 |
13 | class ShardLocationAwareRoutingLogicSpec
14 | extends TestKit(ActorSystem("ShardLocalityAwareRoutingLogicSpec"))
15 | with WordSpecLike
16 | with Matchers
17 | with BeforeAndAfterAll {
18 | import ShardLocationAwareRoutingLogicSpec._
19 |
20 | "The ShardLocalityAwareRoutingLogic" should {
21 | "fall back to random routing when no sharding state is available" in {
22 | Locality(system)
23 |
24 | val shardRegion = TestProbe("region")
25 | val routee1 = TestProbe("routee1")
26 | val routee2 = TestProbe("routee2")
27 | val allRoutees = IndexedSeq(routee1, routee2).map(r => ActorRefRoutee(r.ref))
28 |
29 | val logic = ShardLocationAwareRoutingLogic(system, shardRegion.ref, extractEntityId, extractShardId)
30 |
31 | val runs = 1000
32 | val selections = for (_ <- 1 to runs) yield {
33 | logic.select(TestMessage(1), allRoutees)
34 | }
35 | val count = selections.count(_ == allRoutees(0))
36 | val expectedCount = runs * 50 / 100
37 | val variation = expectedCount / 10
38 |
39 | count shouldEqual expectedCount +- variation
40 | }
41 |
42 | "retrieve shard state on startup and use it in order to route messages" in {
43 | // use system identifier in order to simulate running on multiple nodes
44 | val system1 = ActorSystem("node1")
45 | val system2 = ActorSystem("node2")
46 | val system3 = ActorSystem("node3")
47 | Locality(system1)
48 | Locality(system2)
49 | Locality(system3)
50 |
51 | val routee1 = TestProbe("routee1")(system1)
52 | val routee2 = TestProbe("routee2")(system2)
53 | val routee3 = TestProbe("routee3")(system3)
54 |
55 | val allRoutees = IndexedSeq(routee1, routee2, routee3).map(r => ActorRefRoutee(r.ref))
56 |
57 | val region1 = TestProbe("region1")(system1)
58 | val shards1 = for (i <- 1 to 10) yield i
59 | val region2 = TestProbe("region2")(system2)
60 | val shards2 = for (i <- 11 to 20) yield i
61 | val region3 = TestProbe("region3")(system3)
62 | val shards3 = for (i <- 21 to 30) yield i
63 |
64 | val logic = ShardLocationAwareRoutingLogic(system1, region1.ref, extractEntityId, extractShardId)
65 |
66 | // logic tries to get shard state
67 | import scala.concurrent.duration._
68 | region1.expectMsgType[GetClusterShardingStats](5.seconds)
69 | val monitor1 = region1.sender()
70 |
71 | monitor1 ! ClusterShardingStats(
72 |   regions = Map(
73 |     region1.ref.path.address -> ShardRegionStats(shards1.map(id => id.toString -> 0).toMap),
74 |     region2.ref.path.address -> ShardRegionStats(shards2.map(id => id.toString -> 0).toMap),
75 |     region3.ref.path.address -> ShardRegionStats(shards3.map(id => id.toString -> 0).toMap)
76 |   )
77 | )
78 |
79 | // TODO find a better way to determine that the router is ready
80 | Thread.sleep(500)
81 |
82 | for (i <- 1 to 10) {
83 | logic.select(TestMessage(i), allRoutees) shouldEqual ActorRefRoutee(routee1.ref)
84 | }
85 | for (i <- 11 to 20) {
86 | logic.select(TestMessage(i), allRoutees) shouldEqual ActorRefRoutee(routee2.ref)
87 | }
88 | for (i <- 21 to 30) {
89 | logic.select(TestMessage(i), allRoutees) shouldEqual ActorRefRoutee(routee3.ref)
90 | }
91 |
92 | TestKit.shutdownActorSystem(system1)
93 | TestKit.shutdownActorSystem(system2)
94 | TestKit.shutdownActorSystem(system3)
95 | }
96 | }
97 |
98 | override protected def afterAll(): Unit =
99 | TestKit.shutdownActorSystem(system)
100 | }
101 |
102 | object ShardLocationAwareRoutingLogicSpec {
103 | final case class TestMessage(id: Int)
104 |
105 | val extractEntityId: ShardRegion.ExtractEntityId = {
106 | case msg @ TestMessage(id) => (id.toString, msg)
107 | }
108 |
109 | val extractShardId: ShardRegion.ExtractShardId = {
110 | // deliberately use the simplest possible mapping
111 | case TestMessage(id) => id.toString
112 | }
113 | }
114 |
--------------------------------------------------------------------------------