├── .github
│   └── workflows
│       └── scala.yml
├── .gitignore
├── .scalafmt.conf
├── LICENSE
├── README.md
├── build.sbt
├── project
│   ├── build.properties
│   └── plugins.sbt
└── src
    ├── main
    │   ├── resources
    │   │   └── reference.conf
    │   └── scala
    │       └── io
    │           └── bernhardt
    │               └── akka
    │                   └── locality
    │                       ├── Locality.scala
    │                       ├── package.scala
    │                       └── router
    │                           ├── ShardLocationAwareRouter.scala
    │                           └── ShardStateMonitor.scala
    ├── multi-jvm
    │   └── scala
    │       ├── akka
    │       │   ├── README.md
    │       │   ├── cluster
    │       │   │   ├── FailureDetectorPuppet.scala
    │       │   │   ├── MultiNodeClusterSpec.scala
    │       │   │   ├── sharding
    │       │   │   │   ├── MultiNodeClusterShardingConfig.scala
    │       │   │   │   └── MultiNodeClusterShardingSpec.scala
    │       │   │   └── testkit
    │       │   │       └── AutoDowning.scala
    │       │   ├── remote
    │       │   │   └── testkit
    │       │   │       └── STMultiNodeSpec.scala
    │       │   ├── serialization
    │       │   │   └── jackson
    │       │   │       └── CborSerializable.scala
    │       │   └── testkit
    │       │       └── AkkaSpec.scala
    │       └── io
    │           └── bernhardt
    │               └── akka
    │                   └── locality
    │                       └── router
    │                           ├── ShardLocationAwareRouterNewShardsSpec.scala
    │                           ├── ShardLocationAwareRouterSpec.scala
    │                           └── ShardLocationAwareRouterWithProxySpec.scala
    └── test
        └── scala
            └── io
                └── bernhardt
                    └── akka
                        └── locality
                            └── router
                                └── ShardLocationAwareRoutingLogicSpec.scala

--------------------------------------------------------------------------------
/.github/workflows/scala.yml:
--------------------------------------------------------------------------------
 1 | name: Scala CI
 2 | 
 3 | on: [push]
 4 | 
 5 | jobs:
 6 |   build:
 7 | 
 8 |     runs-on: ubuntu-latest
 9 | 
10 |     steps:
11 |     - uses: actions/checkout@v1
12 |     - name: Set up JDK 1.8
13 |       uses: actions/setup-java@v1
14 |       with:
15 |         java-version: 1.8
16 |     - name: Run tests
17 |       run: sbt test
18 |     - name: Run multi-jvm tests
19 |       run: sbt multi-jvm:test

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | bin/
 2 | target/
 3 | build/
 4 | 
 5 | *.log
 6 | *.iml
 7 | *.ipr
 8 | *.iws
 9 | .idea
10 | 
11 | .DS_Store
12 | 
13 | .history
14 | .scala_dependencies
15 | .cache
16 | .cache-main
17 | 
18 | *.class

--------------------------------------------------------------------------------
/.scalafmt.conf:
--------------------------------------------------------------------------------
 1 | version = 2.2.2
 2 | 
 3 | style = defaultWithAlign
 4 | 
 5 | docstrings = JavaDoc
 6 | indentOperator = spray
 7 | maxColumn = 120
 8 | rewrite.rules = [RedundantParens, SortImports, AvoidInfix]
 9 | unindentTopLevelOperators = true
10 | align.tokens = [{code = "=>", owner = "Case"}]
11 | align.openParenDefnSite = false
12 | align.openParenCallSite = false
13 | optIn.breakChainOnFirstMethodDot = false
14 | optIn.configStyleArguments = false
15 | danglingParentheses = false
16 | spaces.inImportCurlyBraces = true
17 | rewrite.neverInfix.excludeFilters = [
18 |   and
19 |   min
20 |   max
21 |   until
22 |   to
23 |   by
24 |   eq
25 |   ne
26 |   "should.*"
27 |   "contain.*"
28 |   "must.*"
29 |   in
30 |   ignore
31 |   be
32 |   taggedAs
33 |   thrownBy
34 |   synchronized
35 |   have
36 |   when
37 |   size
38 |   only
39 |   noneOf
40 |   oneElementOf
41 |   noElementsOf
42 |   atLeastOneElementOf
43 |   atMostOneElementOf
44 |   allElementsOf
45 |   inOrderElementsOf
46 |   theSameElementsAs
47 | ]

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | Copyright 2019 Manuel Bernhardt
 2 | 
 3 | Licensed under the Apache License, Version 2.0 (the "License");
 4 | you may not use this file except in compliance with the License.
 5 | You may obtain a copy of the License at
 6 | 
 7 |     http://www.apache.org/licenses/LICENSE-2.0
 8 | 
 9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # akka-locality
  2 | 
  3 | This module provides constructs that help to make better use of the locality of actors within a clustered Akka system.
  4 | For a full explanation of the problem it addresses, [check out this article](https://manuel.bernhardt.io/2019/10/28/one-step-closer-exploiting-locality-in-akka-cluster-based-systems/).
  5 | 
  6 | ### SBT
  7 | 
  8 | ```sbt
  9 | libraryDependencies += "io.bernhardt" %% "akka-locality" % "1.1.0"
 10 | ```
 11 | 
 12 | ### Maven
 13 | 
 14 | ```xml
 15 | <dependency>
 16 |   <groupId>io.bernhardt</groupId>
 17 |   <artifactId>akka-locality_2.12</artifactId>
 18 |   <version>1.1.0</version>
 19 | </dependency>
 20 | ```
 21 | 
 22 | ## Shard location aware routers
 23 | 
 24 | This type of router is useful for systems in which the routees of cluster-aware routers need to communicate with sharded
 25 | entities.
 26 | 
 27 | With a common routing logic (random, round-robin) there may be an extra network hop (or two when considering replies)
 28 | between a routee and the sharded entities it needs to talk to. Shard location aware routers optimize this by routing
 29 | to the routee closest to the sharded entity. They do so by using the same rules for extracting the `shardId` from a
 30 | message as the shard regions themselves.
 31 | 
 32 | When the router has not yet retrieved the sharding state, it falls back to random routing.
 33 | When there is more than one candidate routee close to a sharded entity, one of them is picked at random.
 34 | 
 35 | In order to use these routers, the `Locality` extension must be started:
 36 | 
 37 | ### Scala
 38 | 
 39 | ```scala
 40 | import io.bernhardt.akka.locality._
 41 | import akka.actor.ActorSystem
 42 | 
 43 | val system: ActorSystem = ActorSystem("system")
 44 | val locality = Locality(system)
 45 | ```
 46 | 
 47 | ### Java
 48 | 
 49 | ```java
 50 | import io.bernhardt.akka.locality.Locality;
 51 | import akka.actor.ActorSystem;
 52 | 
 53 | ActorSystem system = ActorSystem.create("system");
 54 | Locality locality = Locality.get(system);
 55 | ```
 56 | 
 57 | You can then use the group or pool routers as cluster-aware routers. These routers must be declared in code, as they
 58 | need to be passed elements of the sharding setup:
 59 | 
 60 | ### Scala
 61 | 
 62 | ```scala
 63 | import akka.actor.{ActorSystem, ActorRef}
 64 | import akka.cluster.sharding.ShardRegion
 65 | import akka.cluster.routing._
 66 | 
 67 | import io.bernhardt.akka.locality.Locality
 68 | 
 69 | val system: ActorSystem = ActorSystem("system")
 70 | val locality: Locality = Locality(system)
 71 | val extractEntityId: ShardRegion.ExtractEntityId = ???
 72 | val extractShardId: ShardRegion.ExtractShardId = ???
 73 | val region: ActorRef = ???
 74 | 
 75 | val router = system.actorOf(ClusterRouterGroup(locality.shardLocationAwareGroup(
 76 |   routeePaths = Nil,
 77 |   shardRegion = region,
 78 |   extractEntityId = extractEntityId,
 79 |   extractShardId = extractShardId
 80 | ), ClusterRouterGroupSettings(
 81 |   totalInstances = 5,
 82 |   routeesPaths = List("/user/routee"),
 83 |   allowLocalRoutees = true
 84 | )).props(), "shard-location-aware-router")
 85 | ```
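
The pool variant is created in the same fashion. A minimal sketch, assuming the same `region` and extractor functions as above, and a hypothetical `routeeProps` for the routee actors:

```scala
import akka.actor.Props

// Props of the worker actors the pool will deploy (placeholder, as above)
val routeeProps: Props = ???

val poolRouter = system.actorOf(ClusterRouterPool(locality.shardLocationAwarePool(
  nrOfInstances = 5,
  shardRegion = region,
  extractEntityId = extractEntityId,
  extractShardId = extractShardId
), ClusterRouterPoolSettings(
  totalInstances = 5,
  maxInstancesPerNode = 1,
  allowLocalRoutees = true,
  useRoles = Set.empty[String]
)).props(routeeProps), "shard-location-aware-pool-router")
```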

### Java

```java
import akka.actor.ActorSystem;
import akka.actor.ActorRef;
import akka.cluster.sharding.ShardRegion;
import akka.cluster.routing.ClusterRouterGroup;
import akka.cluster.routing.ClusterRouterGroupSettings;
import io.bernhardt.akka.locality.Locality;

import java.util.*;

ActorSystem system = ActorSystem.create("system");
Locality locality = Locality.get(system);
ActorRef region = ...;
ShardRegion.MessageExtractor messageExtractor = ...;
int totalInstances = 5;
Iterable<String> routeesPaths = Collections.singletonList("/user/routee");
boolean allowLocalRoutees = true;
Set<String> useRoles = new HashSet<>(Arrays.asList("role"));

ActorRef router = system.actorOf(
    new ClusterRouterGroup(
        locality.shardLocationAwareGroup(
            routeesPaths,
            region,
            messageExtractor
        ),
        new ClusterRouterGroupSettings(
            totalInstances,
            routeesPaths,
            allowLocalRoutees,
            useRoles
        )
    ).props(), "shard-location-aware-router");
```

Always make sure that:

- you use exactly the same logic for the routers as you use for sharding
- you deploy the routers on all the nodes on which sharding is enabled

### Configuration

See [reference.conf](https://github.com/manuelbernhardt/akka-locality/blob/master/src/main/resources/reference.conf) for more information about the configuration of the routing mechanism.
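
The defaults can be overridden in your `application.conf`. A minimal sketch, with illustrative values rather than recommendations:

```hocon
akka.locality {
  retrieve-shard-state-timeout = 10 seconds
  shard-state-update-margin = 30 seconds
  shard-state-polling-interval = 30 seconds
}
```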

--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
 1 | import com.typesafe.sbt.MultiJvmPlugin.multiJvmSettings
 2 | 
 3 | lazy val akkaVersion = "2.5.26"
 4 | 
 5 | lazy val `akka-locality` = project
 6 |   .in(file("."))
 7 |   .enablePlugins(MultiJvmPlugin)
 8 |   .configs(MultiJvm)
 9 |   .settings(multiJvmSettings: _*)
10 |   .settings(publishingSettings: _*)
11 |   .settings(
12 |     name := "akka-locality",
13 |     version := "1.1.0",
14 |     startYear := Some(2019),
15 |     scalaVersion := "2.12.10",
16 |     crossScalaVersions := Seq("2.12.10", "2.13.1"),
17 |     scalacOptions ++= Seq(
18 |       "-unchecked",
19 |       "-deprecation",
20 |       "-language:_",
21 |       "-target:jvm-1.8",
22 |       "-encoding", "UTF-8"
23 |     ),
24 |     parallelExecution in Test := false,
25 |     libraryDependencies ++= Seq(
26 |       "com.typesafe.akka" %% "akka-actor" % akkaVersion % "provided;multi-jvm;test",
27 |       "com.typesafe.akka" %% "akka-cluster-sharding" % akkaVersion % "provided;multi-jvm;test",
28 |       "com.typesafe.akka" %% "akka-persistence" % akkaVersion % Test,
29 |       "com.typesafe.akka" %% "akka-testkit" % akkaVersion % Test,
30 |       "com.typesafe.akka" %% "akka-multi-node-testkit" % akkaVersion % Test,
31 |       "org.iq80.leveldb" % "leveldb" % "0.12" % "optional;provided;multi-jvm;test",
32 |       "commons-io" % "commons-io" % "2.6" % Test,
33 |       "org.scalatest" %% "scalatest" % "3.0.8" % Test
34 |     ),
35 |     credentials += Credentials(Path.userHome / ".sbt" / "sonatype_credential")
36 |   )
37 | 
38 | val publishingSettings = Seq(
39 |   ThisBuild / organization := "io.bernhardt",
40 |   ThisBuild / organizationName := "manuel.bernhardt.io",
41 |   ThisBuild / organizationHomepage := Some(url("https://manuel.bernhardt.io")),
42 | 
43 |   ThisBuild / scmInfo := Some(
44 |     ScmInfo(
45 |       url("https://github.com/manuelbernhardt/akka-locality"),
46 |       "scm:git@github.com:manuelbernhardt/akka-locality.git"
47 |     )
48 |   ),
49 |   ThisBuild / developers := List(
50 |     Developer(
51 |       id = "manuel",
52 |       name = "Manuel Bernhardt",
53 |       email = "manuel@bernhardt.io",
54 |       url = url("https://manuel.bernhardt.io")
55 |     )
56 |   ),
57 |   ThisBuild / description := "Akka extension to make better use of locality of actors in clustered systems",
58 |   ThisBuild / licenses := List("Apache 2" -> new URL("http://www.apache.org/licenses/LICENSE-2.0.txt")),
59 |   ThisBuild / homepage := Some(url("https://github.com/manuelbernhardt/akka-locality")),
60 | 
61 |   // Remove all additional repositories other than Maven Central from the POM
62 |   ThisBuild / pomIncludeRepository := { _ => false },
63 |   ThisBuild / publishTo := {
64 |     val nexus = "https://oss.sonatype.org/"
65 |     if (isSnapshot.value) Some("snapshots" at nexus + "content/repositories/snapshots")
66 |     else Some("releases" at nexus + "service/local/staging/deploy/maven2")
67 |   },
68 |   ThisBuild / publishMavenStyle := true
69 | )

--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
 1 | sbt.version=1.3.3

--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
 1 | addSbtPlugin("com.typesafe.sbt" % "sbt-multi-jvm" % "0.4.0")
 2 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.0")
 3 | addSbtPlugin("org.scalameta" % "sbt-scalafmt" % "2.2.0")

--------------------------------------------------------------------------------
/src/main/resources/reference.conf:
--------------------------------------------------------------------------------
 1 | akka {
 2 |   locality {
 3 | 
 4 |     # The timeout to use when attempting to retrieve shard states from the cluster sharding system.
 5 |     # For large clusters this value may need to be increased.
 6 |     retrieve-shard-state-timeout = 5 seconds
 7 | 
 8 |     # The margin to keep before requesting an update from the cluster sharding system when a rebalance is detected.
 9 |     # This is necessary because shards need some time to be rebalanced (in case of normal rebalancing, or in case
10 |     # of a topology change).
11 |     # If you use a [[akka.cluster.DowningProvider]], you should take into account the `downRemovalMargin` since shards
12 |     # will only be re-allocated after this margin has elapsed.
13 |     shard-state-update-margin = 10 seconds
14 | 
15 |     # The interval between periodic polls of the shard state.
16 |     # The shard state is queried periodically in order to become aware of newly created shards that have not yet
17 |     # caused a rebalancing to occur.
18 | shard-state-polling-interval = 15 seconds 19 | } 20 | } -------------------------------------------------------------------------------- /src/main/scala/io/bernhardt/akka/locality/Locality.scala: -------------------------------------------------------------------------------- 1 | package io.bernhardt.akka.locality 2 | 3 | import java.util.concurrent.TimeUnit 4 | 5 | import akka.actor.{ 6 | Actor, 7 | ActorLogging, 8 | ActorRef, 9 | ActorSystem, 10 | ExtendedActorSystem, 11 | Extension, 12 | ExtensionId, 13 | ExtensionIdProvider, 14 | Props 15 | } 16 | import akka.cluster.sharding.ShardRegion 17 | import akka.cluster.sharding.ShardRegion.MessageExtractor 18 | import com.typesafe.config.Config 19 | import io.bernhardt.akka.locality.router.{ ShardLocationAwareGroup, ShardLocationAwarePool, ShardStateMonitor } 20 | 21 | import scala.collection.immutable 22 | import scala.concurrent.duration.FiniteDuration 23 | 24 | object Locality extends ExtensionId[Locality] with ExtensionIdProvider { 25 | override def get(system: ActorSystem): Locality = super.get(system) 26 | 27 | override def createExtension(system: ExtendedActorSystem): Locality = new Locality(system) 28 | 29 | override def lookup(): ExtensionId[_ <: Extension] = Locality 30 | } 31 | 32 | /** 33 | * This module provides constructs that help to make better use of the locality of actors within a clustered Akka system. 34 | */ 35 | class Locality(system: ExtendedActorSystem) extends Extension { 36 | private val settings = LocalitySettings(system.settings.config) 37 | 38 | system.systemActorOf(Props(new LocalitySupervisor(settings)), "locality") 39 | 40 | /** 41 | * Scala API: Create a shard location aware group 42 | * 43 | * @param routeePaths string representation of the actor paths of the routees, messages are 44 | * sent with [[akka.actor.ActorSelection]] to these paths 45 | * @param shardRegion the reference to the shard region 46 | * @param extractEntityId the [[akka.cluster.sharding.ShardRegion.ExtractEntityId]] function used to extract the entity id from a message 47 | * @param extractShardId the [[akka.cluster.sharding.ShardRegion.ExtractShardId]] function used to extract the shard id from a message 48 | */ 49 | def shardLocationAwareGroup( 50 | routeePaths: immutable.Iterable[String], 51 | shardRegion: ActorRef, 52 | extractEntityId: ShardRegion.ExtractEntityId, 53 | extractShardId: ShardRegion.ExtractShardId): ShardLocationAwareGroup = 54 | ShardLocationAwareGroup(routeePaths, shardRegion, extractEntityId, extractShardId) 55 | 56 | /** 57 | * Java API: Create a shard location aware group 58 | * 59 | * @param routeePaths string representation of the actor paths of the routees, messages are 60 | * sent with [[akka.actor.ActorSelection]] to these paths 61 | * @param shardRegion the reference to the shard region 62 | * @param messageExtractor the [[akka.cluster.sharding.ShardRegion.MessageExtractor]] used for the sharding 63 | * of the entities this router should optimize routing for 64 | */ 65 | def shardLocationAwareGroup( 66 | routeePaths: java.lang.Iterable[String], 67 | shardRegion: ActorRef, 68 | messageExtractor: MessageExtractor): ShardLocationAwareGroup = 69 | new ShardLocationAwareGroup(routeePaths, shardRegion, messageExtractor) 70 | 71 | /** 72 | * Scala API: Create a shard location aware pool 73 | * 74 | * @param nrOfInstances how many routees this pool router should have 75 | * @param shardRegion the reference to the shard region 76 | * @param extractEntityId the [[akka.cluster.sharding.ShardRegion.ExtractEntityId]] 
function used to extract the entity id from a message 77 | * @param extractShardId the [[akka.cluster.sharding.ShardRegion.ExtractShardId]] function used to extract the shard id from a message 78 | */ 79 | def shardLocationAwarePool( 80 | nrOfInstances: Int, 81 | shardRegion: ActorRef, 82 | extractEntityId: ShardRegion.ExtractEntityId, 83 | extractShardId: ShardRegion.ExtractShardId): ShardLocationAwarePool = 84 | ShardLocationAwarePool( 85 | nrOfInstances = nrOfInstances, 86 | shardRegion = shardRegion, 87 | extractEntityId = extractEntityId, 88 | extractShardId = extractShardId) 89 | 90 | /** 91 | * Java API: Create a shard location aware pool 92 | * 93 | * @param nrOfInstances how many routees this pool router should have 94 | * @param shardRegion the reference to the shard region 95 | * @param messageExtractor the [[akka.cluster.sharding.ShardRegion.MessageExtractor]] used for the sharding 96 | * of the entities this router should optimize routing for 97 | */ 98 | def shardLocationAwarePool(nrOfInstances: Int, shardRegion: ActorRef, messageExtractor: MessageExtractor) = 99 | new ShardLocationAwarePool(nrOfInstances, shardRegion, messageExtractor) 100 | } 101 | 102 | private[locality] final class LocalitySupervisor(settings: LocalitySettings) extends Actor with ActorLogging { 103 | import LocalitySupervisor._ 104 | 105 | def receive: Receive = { 106 | case m @ MonitorShards(region) => 107 | val regionName = encodeRegionName(region) 108 | context 109 | .child(regionName) 110 | .map { monitor => 111 | monitor.forward(m) 112 | } 113 | .getOrElse { 114 | log.info("Starting to monitor shards of region {}", regionName) 115 | context.actorOf(ShardStateMonitor.props(region, regionName, settings), regionName).forward(m) 116 | } 117 | } 118 | } 119 | 120 | object LocalitySupervisor { 121 | private[locality] final case class MonitorShards(region: ActorRef) 122 | } 123 | 124 | final case class LocalitySettings(config: Config) { 125 | private val localityConfig = config.getConfig("akka.locality") 126 | 127 | val RetrieveShardStateTimeout: FiniteDuration = 128 | FiniteDuration( 129 | localityConfig.getDuration("retrieve-shard-state-timeout", TimeUnit.MILLISECONDS), 130 | TimeUnit.MILLISECONDS) 131 | 132 | val ShardStateUpdateMargin = 133 | FiniteDuration( 134 | localityConfig.getDuration("shard-state-update-margin", TimeUnit.MILLISECONDS), 135 | TimeUnit.MILLISECONDS) 136 | 137 | val ShardStatePollingInterval = 138 | FiniteDuration( 139 | localityConfig.getDuration("shard-state-polling-interval", TimeUnit.MILLISECONDS), 140 | TimeUnit.MILLISECONDS) 141 | } 142 | -------------------------------------------------------------------------------- /src/main/scala/io/bernhardt/akka/locality/package.scala: -------------------------------------------------------------------------------- 1 | package io.bernhardt.akka 2 | 3 | import java.net.URLEncoder 4 | 5 | import akka.actor.ActorRef 6 | import akka.cluster.sharding.ShardRegion.ShardId 7 | 8 | package object locality { 9 | private[locality] def encodeRegionName(region: ActorRef): String = 10 | URLEncoder.encode(region.path.name.replaceAll("Proxy", ""), "utf-8") 11 | 12 | private[locality] def encodeShardId(id: ShardId): String = URLEncoder.encode(id, "utf-8") 13 | } 14 | -------------------------------------------------------------------------------- /src/main/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRouter.scala: -------------------------------------------------------------------------------- 1 | package io.bernhardt.akka.locality.router 2 
| 3 | import java.util.concurrent.atomic.AtomicReference 4 | import java.util.concurrent.{ ThreadLocalRandom, TimeUnit } 5 | 6 | import akka.actor._ 7 | import akka.cluster.sharding.ShardRegion 8 | import akka.cluster.sharding.ShardRegion._ 9 | import akka.dispatch.Dispatchers 10 | import akka.event.Logging 11 | import akka.japi.Util.immutableSeq 12 | import akka.pattern.{ ask, AskTimeoutException } 13 | import akka.routing._ 14 | import akka.util.Timeout 15 | import io.bernhardt.akka.locality.LocalitySupervisor 16 | 17 | import scala.collection.immutable 18 | import scala.collection.immutable.IndexedSeq 19 | import scala.concurrent.Future 20 | import scala.util.control.NonFatal 21 | 22 | object ShardLocationAwareRouter { 23 | def extractEntityIdFrom(messageExtractor: MessageExtractor): ShardRegion.ExtractEntityId = { 24 | case msg if messageExtractor.entityId(msg) ne null => 25 | (messageExtractor.entityId(msg), messageExtractor.entityMessage(msg)) 26 | } 27 | } 28 | 29 | /** 30 | * A group router that will route to the routees deployed the closest to the sharded entity they need to interact with 31 | */ 32 | @SerialVersionUID(1L) 33 | final case class ShardLocationAwareGroup( 34 | routeePaths: immutable.Iterable[String], 35 | shardRegion: ActorRef, 36 | extractEntityId: ShardRegion.ExtractEntityId, 37 | extractShardId: ShardRegion.ExtractShardId, 38 | override val routerDispatcher: String = Dispatchers.DefaultDispatcherId) 39 | extends Group { 40 | /** 41 | * Java API 42 | * 43 | * @param routeePaths string representation of the actor paths of the routees, messages are 44 | * sent with [[akka.actor.ActorSelection]] to these paths 45 | * @param shardRegion the reference to the shard region 46 | * @param messageExtractor the [[akka.cluster.sharding.ShardRegion.MessageExtractor]] used for the sharding 47 | * of the entities this router should optimize routing for 48 | */ 49 | def this(routeePaths: java.lang.Iterable[String], shardRegion: ActorRef, messageExtractor: MessageExtractor) = 50 | this( 51 | immutableSeq(routeePaths), 52 | shardRegion, 53 | ShardLocationAwareRouter.extractEntityIdFrom(messageExtractor), 54 | extractShardId = msg => messageExtractor.shardId(msg)) 55 | 56 | /** 57 | * Setting the dispatcher to be used for the router head actor, which handles 58 | * supervision, death watch and router management messages. 
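 * (For example, `group.withDispatcher("my-dispatcher")`, assuming a dispatcher
 * configured under that name.)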
59 | */ 60 | def withDispatcher(dispatcherId: String): ShardLocationAwareGroup = copy(routerDispatcher = dispatcherId) 61 | 62 | override def paths(system: ActorSystem): immutable.Iterable[String] = routeePaths 63 | 64 | override def createRouter(system: ActorSystem): Router = 65 | new Router(ShardLocationAwareRoutingLogic(system, shardRegion, extractEntityId, extractShardId)) 66 | } 67 | 68 | /** 69 | * A pool router that will route to the routees deployed the closest to the sharded entity they need to interact with 70 | */ 71 | @SerialVersionUID(1L) 72 | final case class ShardLocationAwarePool( 73 | nrOfInstances: Int, 74 | override val resizer: Option[Resizer] = None, 75 | shardRegion: ActorRef, 76 | extractEntityId: ShardRegion.ExtractEntityId, 77 | extractShardId: ShardRegion.ExtractShardId, 78 | override val supervisorStrategy: SupervisorStrategy = Pool.defaultSupervisorStrategy, 79 | override val routerDispatcher: String = Dispatchers.DefaultDispatcherId, 80 | override val usePoolDispatcher: Boolean = false) 81 | extends Pool { 82 | /** 83 | * Java API 84 | * 85 | * @param nrOfInstances how many routees this pool router should have 86 | * @param shardRegion the reference to the shard region 87 | * @param messageExtractor the [[akka.cluster.sharding.ShardRegion.MessageExtractor]] used for the sharding 88 | * of the entities this router should optimize routing for 89 | */ 90 | def this(nrOfInstances: Int, shardRegion: ActorRef, messageExtractor: MessageExtractor) = 91 | this( 92 | nrOfInstances = nrOfInstances, 93 | shardRegion = shardRegion, 94 | extractEntityId = ShardLocationAwareRouter.extractEntityIdFrom(messageExtractor), 95 | extractShardId = msg => messageExtractor.shardId(msg)) 96 | 97 | /** 98 | * Setting the supervisor strategy to be used for the “head” Router actor. 99 | */ 100 | def withSupervisorStrategy(strategy: SupervisorStrategy): ShardLocationAwarePool = copy(supervisorStrategy = strategy) 101 | 102 | /** 103 | * Setting the resizer to be used. 104 | */ 105 | def withResizer(resizer: Resizer): ShardLocationAwarePool = copy(resizer = Some(resizer)) 106 | 107 | /** 108 | * Setting the dispatcher to be used for the router head actor, which handles 109 | * supervision, death watch and router management messages. 110 | */ 111 | def withDispatcher(dispatcherId: String): ShardLocationAwarePool = copy(routerDispatcher = dispatcherId) 112 | 113 | /** 114 | * Setting whether to use a dedicated dispatcher for the routees of the pool. 115 | * The dispatcher is defined in 'pool-dispatcher' configuration property in the 116 | * deployment section of the router. 117 | */ 118 | def usePoolDispatcher(usePoolDispatcher: Boolean): ShardLocationAwarePool = 119 | copy(usePoolDispatcher = usePoolDispatcher) 120 | 121 | override def createRouter(system: ActorSystem): Router = 122 | new Router(ShardLocationAwareRoutingLogic(system, shardRegion, extractEntityId, extractShardId)) 123 | 124 | override def nrOfInstances(sys: ActorSystem): Int = this.nrOfInstances 125 | } 126 | 127 | /** 128 | * Router logic that makes its routing decision based on the relative location of routees to the shards they will 129 | * communicate with, on a best-effort basis. 130 | * When no shard state information is available this logic falls back to random routing. 131 | * When there are multiple candidate routees on the same node, one of them is selected at random. 
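 *
 * A minimal sketch of using the logic directly with a plain [[akka.routing.Router]],
 * assuming a `region` reference and the same extractor functions that were passed to
 * `ClusterSharding.start`:
 * {{{
 * val logic = ShardLocationAwareRoutingLogic(system, region, extractEntityId, extractShardId)
 * val router = Router(logic)
 * }}}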
132 |  *
133 |  * @param system the [[akka.actor.ActorSystem]]
134 |  * @param shardRegion the reference to the [[akka.cluster.sharding.ShardRegion]] the local routees will communicate with
135 |  * @param extractEntityId partial function to extract the entity id from a message, should be the same as used for sharding
136 |  * @param extractShardId partial function to extract the shard id based on a message, should be the same as used for sharding
137 |  */
138 | final case class ShardLocationAwareRoutingLogic(
139 |     system: ActorSystem,
140 |     shardRegion: ActorRef,
141 |     extractEntityId: ShardRegion.ExtractEntityId,
142 |     extractShardId: ShardRegion.ExtractShardId)
143 |     extends RoutingLogic {
144 |   import io.bernhardt.akka.locality.router.ShardStateMonitor._
145 |   import system.dispatcher
146 | 
147 |   private lazy val log = Logging(system, getClass)
148 |   private lazy val selfAddress = system.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress
149 |   private val clusterShardingStateRef = new AtomicReference[Map[ShardId, Address]](Map.empty)
150 |   private val shardLocationAwareRouteeRef =
151 |     new AtomicReference[(IndexedSeq[Routee], Map[Address, IndexedSeq[ShardLocationAwareRoutee]])](
152 |       (IndexedSeq.empty, Map.empty))
153 | 
154 |   watchShardStateChanges()
155 | 
156 |   override def select(message: Any, routees: IndexedSeq[Routee]): Routee = {
157 |     if (routees.isEmpty) {
158 |       NoRoutee
159 |     } else {
160 |       // avoid re-creating routees for each message by checking if they have changed
161 |       def updateShardLocationAwareRoutees(): Map[Address, IndexedSeq[ShardLocationAwareRoutee]] = {
162 |         val oldShardingRouteeTuple = shardLocationAwareRouteeRef.get()
163 |         val (oldRoutees, oldShardLocationAwareRoutees) = oldShardingRouteeTuple
164 |         if ((routees ne oldRoutees) && routees != oldRoutees) {
165 |           val allRoutees = routees.map(ShardLocationAwareRoutee(_, selfAddress))
166 |           val newShardLocationAwareRoutees = allRoutees.groupBy(_.address)
167 |           // don't act on failure of compare and set, as the next call to select will update anyway if the routees remain different
168 |           shardLocationAwareRouteeRef.compareAndSet(oldShardingRouteeTuple, (routees, newShardLocationAwareRoutees))
169 |           newShardLocationAwareRoutees
170 |         } else {
171 |           oldShardLocationAwareRoutees
172 |         }
173 |       }
174 | 
175 |       val shardId: ShardId = extractShardId(message)
176 |       val shardLocationAwareRoutees = updateShardLocationAwareRoutees()
177 | 
178 |       val candidateRoutees = for {
179 |         location <- clusterShardingStateRef.get().get(shardId)
180 |         locationAwareRoutees <- shardLocationAwareRoutees.get(location)
181 |       } yield {
182 |         val closeRoutees = locationAwareRoutees.map(_.routee)
183 | 
184 |         // pick one of the local routees at random
185 |         closeRoutees(ThreadLocalRandom.current.nextInt(closeRoutees.size))
186 |       }
187 | 
188 |       candidateRoutees.getOrElse {
189 |         log.debug("Falling back to random routing for message in shard {}", shardId)
190 |         // if we couldn't figure out the location of the shard, fall back to random routing
191 |         routees(ThreadLocalRandom.current.nextInt(routees.size))
192 |       }
193 |     }
194 |   }
195 | 
196 |   private def watchShardStateChanges(): Unit = {
197 |     implicit val timeout: Timeout = Timeout(66, TimeUnit.DAYS) // effectively unbounded; note that `^` is XOR in Scala, so the original "2 ^ 64" also evaluated to 66
198 |     val localitySel = system.actorSelection("/system/locality")
199 |     val change: Future[ShardStateChanged] =
200 |       (localitySel ? 
LocalitySupervisor.MonitorShards(shardRegion)).mapTo[ShardStateChanged] 201 | change 202 | .map { stateChanged => 203 | if (stateChanged.newState.nonEmpty) { 204 | log.info("Updating cluster sharding state for {} shards", stateChanged.newState.keys.size) 205 | clusterShardingStateRef.set(stateChanged.newState) 206 | } 207 | watchShardStateChanges() 208 | } 209 | .recover { 210 | case _: AskTimeoutException => 211 | // we were shutting down, ignore 212 | case NonFatal(t) => 213 | log.warning("Could not monitor cluster sharding state: {}", t.getMessage) 214 | } 215 | } 216 | } 217 | 218 | private[locality] final case class ShardLocationAwareRoutee(routee: Routee, selfAddress: Address) { 219 | // extract the address of the routee. In case of a LocalActorRef, host and port are not provided 220 | // therefore we fall back to the address of the local node 221 | val address = { 222 | val routeeAddress = routee match { 223 | case ActorRefRoutee(ref) => ref.path.address 224 | case ActorSelectionRoutee(sel) => sel.anchorPath.address 225 | } 226 | 227 | routeeAddress match { 228 | case Address(_, system, None, None) => selfAddress.copy(system = system) 229 | case fullAddress => fullAddress 230 | } 231 | } 232 | } 233 | -------------------------------------------------------------------------------- /src/main/scala/io/bernhardt/akka/locality/router/ShardStateMonitor.scala: -------------------------------------------------------------------------------- 1 | package io.bernhardt.akka.locality.router 2 | 3 | import akka.actor.{ 4 | Actor, 5 | ActorIdentity, 6 | ActorLogging, 7 | ActorRef, 8 | Address, 9 | DeadLetterSuppression, 10 | Identify, 11 | Props, 12 | RootActorPath, 13 | Terminated, 14 | Timers 15 | } 16 | import akka.cluster.sharding.ShardRegion.{ ClusterShardingStats, GetClusterShardingStats, ShardId, ShardRegionStats } 17 | import io.bernhardt.akka.locality._ 18 | import io.bernhardt.akka.locality.LocalitySupervisor.MonitorShards 19 | 20 | /** 21 | * Internal: watches shard actors in order to trigger an update. Only trigger the update when the system is stable for a while. 
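 *
 * A sketch of the flow, derived from the implementation below: a terminated shard
 * flags a rebalance as being in progress and (re-)arms a single timer with the
 * configured `shard-state-update-margin`, so that the sharding stats are only
 * requested again once no further shard terminations have been observed within
 * that margin.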
22 | */ 23 | private[locality] class ShardStateMonitor(shardRegion: ActorRef, encodedRegionName: String, settings: LocalitySettings) 24 | extends Actor 25 | with ActorLogging 26 | with Timers { 27 | import ShardStateMonitor._ 28 | 29 | val ClusterGuardianName: String = 30 | context.system.settings.config.getString("akka.cluster.sharding.guardian-name") 31 | 32 | var watchedShards = Set.empty[ShardId] 33 | 34 | var routerLogic: ActorRef = context.system.deadLetters 35 | 36 | var latestClusterState: Option[ClusterShardingStats] = None 37 | 38 | // technically, this may also just be a node that was terminated 39 | // but if the coordinator does its job, it will rebalance / reallocate the terminated shards 40 | // either way, this flag signals that the topology is currently changing 41 | var rebalanceInProgress: Boolean = false 42 | 43 | def receive: Receive = { 44 | case _: MonitorShards => 45 | log.debug("Starting to monitor shards for logic {}", routerLogic.path) 46 | routerLogic = sender() 47 | requestClusterShardingState() 48 | timers.startPeriodicTimer(UpdateClusterState, UpdateClusterState, settings.ShardStatePollingInterval) 49 | context.become(watchingChanges) 50 | } 51 | 52 | def watchingChanges: Receive = { 53 | case _: MonitorShards => 54 | routerLogic = sender() 55 | case UpdateClusterStateOnRebalance => 56 | rebalanceInProgress = false 57 | requestClusterShardingState() 58 | case UpdateClusterState => 59 | if (!rebalanceInProgress) { 60 | requestClusterShardingState() 61 | } 62 | case ActorIdentity(shardId: ShardId, Some(ref)) => 63 | log.debug("Now watching shard {}", ref.path) 64 | context.watch(ref) 65 | watchedShards += shardId 66 | case ActorIdentity(shardId, None) => // couldn't get shard ref, not much we can do 67 | log.warning("Could not watch shard {}, shard location aware routing may not work", shardId) 68 | case Terminated(ref) => 69 | log.debug("Watched shard actor {} terminated", ref.path) 70 | rebalanceInProgress = true 71 | watchedShards -= encodeShardId(ref.path.name) 72 | // reset the timer - we only want to request state once things are stable 73 | timers.cancel(UpdateClusterStateOnRebalance) 74 | timers.startSingleTimer( 75 | UpdateClusterStateOnRebalance, 76 | UpdateClusterStateOnRebalance, 77 | settings.ShardStateUpdateMargin) 78 | case stats @ ClusterShardingStats(regions) => 79 | log.debug("Received cluster sharding stats for {} regions", regions.size) 80 | if (!latestClusterState.contains(stats)) { 81 | log.debug("Cluster sharding state changed, notifying subscriber") 82 | latestClusterState = Some(stats) 83 | if (regions.isEmpty) { 84 | log.warning("Cluster Sharding Stats empty - locality-aware routing will not function correctly") 85 | } else { 86 | notifyShardStateChanged(regions) 87 | watchShards(regions) 88 | } 89 | } 90 | } 91 | 92 | def requestClusterShardingState(): Unit = { 93 | log.debug("Requesting cluster state update") 94 | shardRegion ! GetClusterShardingStats(settings.RetrieveShardStateTimeout) 95 | } 96 | 97 | def watchShards(regions: Map[Address, ShardRegionStats]): Unit = { 98 | regions.foreach { 99 | case (address, regionStats) => 100 | val regionPath = RootActorPath(address) / "system" / ClusterGuardianName / encodedRegionName 101 | regionStats.stats.keys.filterNot(watchedShards).foreach { shardId => 102 | val shardPath = regionPath / encodeShardId(shardId) 103 | context.actorSelection(shardPath) ! 
Identify(shardId) 104 | } 105 | } 106 | } 107 | 108 | def notifyShardStateChanged(regions: Map[Address, ShardRegionStats]): Unit = { 109 | val shardsByAddress = regions.flatMap { 110 | case (address, ShardRegionStats(shards)) => 111 | shards.map { 112 | case (shardId, _) => 113 | shardId -> address 114 | } 115 | } 116 | routerLogic ! ShardStateChanged(shardsByAddress) 117 | } 118 | 119 | override def postStop(): Unit = { 120 | routerLogic ! ShardStateChanged(Map.empty) 121 | } 122 | } 123 | 124 | object ShardStateMonitor { 125 | final case class ShardStateChanged(newState: Map[ShardId, Address]) extends DeadLetterSuppression 126 | final case object UpdateClusterState extends DeadLetterSuppression 127 | final case object UpdateClusterStateOnRebalance extends DeadLetterSuppression 128 | 129 | private[locality] def props(shardRegion: ActorRef, entityName: String, settings: LocalitySettings) = 130 | Props(new ShardStateMonitor(shardRegion, entityName, settings)) 131 | } 132 | -------------------------------------------------------------------------------- /src/multi-jvm/scala/akka/README.md: -------------------------------------------------------------------------------- 1 | Files in this package are copied verbatim (with minor adaptations) from Akka 2 | so as to re-use the existing multi-node cluster sharding spec. -------------------------------------------------------------------------------- /src/multi-jvm/scala/akka/cluster/FailureDetectorPuppet.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2009-2019 Lightbend Inc. 3 | */ 4 | 5 | package akka.cluster 6 | 7 | import java.util.concurrent.atomic.AtomicReference 8 | 9 | import akka.remote.FailureDetector 10 | import com.typesafe.config.Config 11 | import akka.event.EventStream 12 | import akka.util.unused 13 | 14 | /** 15 | * User controllable "puppet" failure detector. 16 | */ 17 | class FailureDetectorPuppet(@unused config: Config, @unused ev: EventStream) extends FailureDetector { 18 | 19 | trait Status 20 | object Up extends Status 21 | object Down extends Status 22 | object Unknown extends Status 23 | 24 | private val status: AtomicReference[Status] = new AtomicReference(Unknown) 25 | 26 | def markNodeAsUnavailable(): Unit = status.set(Down) 27 | 28 | def markNodeAsAvailable(): Unit = status.set(Up) 29 | 30 | override def isAvailable: Boolean = status.get match { 31 | case Unknown | Up => true 32 | case Down => false 33 | } 34 | 35 | override def isMonitoring: Boolean = status.get != Unknown 36 | 37 | override def heartbeat(): Unit = status.compareAndSet(Unknown, Up) 38 | 39 | } 40 | -------------------------------------------------------------------------------- /src/multi-jvm/scala/akka/cluster/MultiNodeClusterSpec.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2009-2019 Lightbend Inc. 
3 | */ 4 | 5 | package akka.cluster 6 | 7 | import java.util.UUID 8 | import java.util.concurrent.ConcurrentHashMap 9 | 10 | import akka.actor.{ Actor, ActorRef, ActorSystem, Address, Deploy, PoisonPill, Props, RootActorPath } 11 | import akka.cluster.ClusterEvent.{ MemberEvent, MemberRemoved } 12 | import akka.event.Logging.ErrorLevel 13 | import akka.remote.DefaultFailureDetectorRegistry 14 | import akka.remote.testconductor.RoleName 15 | import akka.remote.testkit.{ MultiNodeSpec, STMultiNodeSpec } 16 | import akka.serialization.jackson.CborSerializable 17 | import akka.testkit.TestEvent._ 18 | import akka.testkit._ 19 | import akka.util.ccompat._ 20 | import com.typesafe.config.{ Config, ConfigFactory } 21 | import org.scalatest.exceptions.TestCanceledException 22 | import org.scalatest.{ Canceled, Outcome, Suite } 23 | 24 | import scala.collection.immutable 25 | import scala.concurrent.Await 26 | import scala.concurrent.duration._ 27 | import scala.language.implicitConversions 28 | 29 | @ccompatUsedUntil213 30 | object MultiNodeClusterSpec { 31 | 32 | def clusterConfigWithFailureDetectorPuppet: Config = 33 | ConfigFactory 34 | .parseString("akka.cluster.failure-detector.implementation-class = akka.cluster.FailureDetectorPuppet") 35 | .withFallback(clusterConfig) 36 | 37 | def clusterConfig(failureDetectorPuppet: Boolean): Config = 38 | if (failureDetectorPuppet) clusterConfigWithFailureDetectorPuppet else clusterConfig 39 | 40 | def clusterConfig: Config = ConfigFactory.parseString(s""" 41 | akka.actor.provider = cluster 42 | akka.actor.warn-about-java-serializer-usage = off 43 | akka.cluster { 44 | jmx.enabled = off 45 | gossip-interval = 200 ms 46 | leader-actions-interval = 200 ms 47 | unreachable-nodes-reaper-interval = 500 ms 48 | periodic-tasks-initial-delay = 300 ms 49 | publish-stats-interval = 0 s # always, when it happens 50 | failure-detector.heartbeat-interval = 500 ms 51 | run-coordinated-shutdown-when-down = off 52 | 53 | sharding { 54 | retry-interval = 200ms 55 | waiting-for-state-timeout = 200ms 56 | } 57 | } 58 | akka.loglevel = INFO 59 | akka.log-dead-letters = off 60 | akka.log-dead-letters-during-shutdown = off 61 | akka.remote { 62 | log-remote-lifecycle-events = off 63 | artery.advanced.flight-recorder { 64 | enabled=on 65 | destination=target/flight-recorder-${UUID.randomUUID().toString}.afr 66 | } 67 | } 68 | akka.loggers = ["akka.testkit.TestEventListener"] 69 | akka.test { 70 | single-expect-default = 5 s 71 | } 72 | 73 | """) 74 | 75 | // sometimes we need to coordinate test shutdown with messages instead of barriers 76 | object EndActor { 77 | case object SendEnd extends CborSerializable 78 | case object End extends CborSerializable 79 | case object EndAck extends CborSerializable 80 | } 81 | 82 | class EndActor(testActor: ActorRef, target: Option[Address]) extends Actor { 83 | import EndActor._ 84 | def receive: Receive = { 85 | case SendEnd => 86 | target.foreach { t => 87 | context.actorSelection(RootActorPath(t) / self.path.elements) ! End 88 | } 89 | case End => 90 | testActor.forward(End) 91 | sender() ! 
EndAck 92 | case EndAck => 93 | testActor.forward(EndAck) 94 | } 95 | } 96 | } 97 | 98 | trait MultiNodeClusterSpec extends Suite with STMultiNodeSpec { 99 | self: MultiNodeSpec => 100 | 101 | override def initialParticipants = roles.size 102 | 103 | private val cachedAddresses = new ConcurrentHashMap[RoleName, Address] 104 | 105 | override protected def atStartup(): Unit = { 106 | muteLog() 107 | self.atStartup() 108 | } 109 | 110 | override protected def afterTermination(): Unit = { 111 | self.afterTermination() 112 | } 113 | 114 | def muteLog(sys: ActorSystem = system): Unit = { 115 | if (!sys.log.isDebugEnabled) { 116 | Seq( 117 | ".*Cluster Node.* - registered cluster JMX MBean.*", 118 | ".*Cluster Node.* - is starting up.*", 119 | ".*Shutting down cluster Node.*", 120 | ".*Cluster node successfully shut down.*", 121 | ".*Using a dedicated scheduler for cluster.*").foreach { s => 122 | sys.eventStream.publish(Mute(EventFilter.info(pattern = s))) 123 | } 124 | 125 | muteDeadLetters( 126 | classOf[ClusterHeartbeatSender.Heartbeat], 127 | classOf[ClusterHeartbeatSender.HeartbeatRsp], 128 | classOf[GossipEnvelope], 129 | classOf[GossipStatus], 130 | classOf[InternalClusterAction.Tick], 131 | classOf[akka.actor.PoisonPill], 132 | classOf[akka.dispatch.sysmsg.DeathWatchNotification], 133 | classOf[akka.remote.transport.AssociationHandle.Disassociated], 134 | // akka.remote.transport.AssociationHandle.Disassociated.getClass, 135 | classOf[akka.remote.transport.ActorTransportAdapter.DisassociateUnderlying], 136 | // akka.remote.transport.ActorTransportAdapter.DisassociateUnderlying.getClass, 137 | classOf[akka.remote.transport.AssociationHandle.InboundPayload])(sys) 138 | 139 | } 140 | } 141 | 142 | def muteMarkingAsUnreachable(sys: ActorSystem = system): Unit = 143 | if (!sys.log.isDebugEnabled) 144 | sys.eventStream.publish(Mute(EventFilter.error(pattern = ".*Marking.* as UNREACHABLE.*"))) 145 | 146 | def muteMarkingAsReachable(sys: ActorSystem = system): Unit = 147 | if (!sys.log.isDebugEnabled) 148 | sys.eventStream.publish(Mute(EventFilter.info(pattern = ".*Marking.* as REACHABLE.*"))) 149 | 150 | override def afterAll(): Unit = { 151 | if (!log.isDebugEnabled) { 152 | muteDeadLetters()() 153 | system.eventStream.setLogLevel(ErrorLevel) 154 | } 155 | super.afterAll() 156 | } 157 | 158 | /** 159 | * Lookup the Address for the role. 160 | * 161 | * Implicit conversion from RoleName to Address. 162 | * 163 | * It is cached, which has the implication that stopping 164 | * and then restarting a role (jvm) with another address is not 165 | * supported. 166 | */ 167 | implicit def address(role: RoleName): Address = { 168 | cachedAddresses.get(role) match { 169 | case null => 170 | val address = node(role).address 171 | cachedAddresses.put(role, address) 172 | address 173 | case address => address 174 | } 175 | } 176 | 177 | // Cluster tests are written so that if previous step (test method) failed 178 | // it will most likely not be possible to run next step. This ensures 179 | // fail fast of steps after the first failure. 180 | private var failed = false 181 | override protected def withFixture(test: NoArgTest): Outcome = 182 | if (failed) { 183 | Canceled(new TestCanceledException("Previous step failed", 0)) 184 | } else { 185 | val out = super.withFixture(test) 186 | if (!out.isSucceeded) 187 | failed = true 188 | out 189 | } 190 | 191 | def clusterView: ClusterReadView = cluster.readView 192 | 193 | /** 194 | * Get the cluster node to use. 
195 | */ 196 | def cluster: Cluster = Cluster(system) 197 | 198 | /** 199 | * Use this method for the initial startup of the cluster node. 200 | */ 201 | def startClusterNode(): Unit = { 202 | if (clusterView.members.isEmpty) { 203 | cluster.join(myself) 204 | awaitAssert(clusterView.members.map(_.address) should contain(address(myself))) 205 | } else 206 | clusterView.self 207 | } 208 | 209 | /** 210 | * Initialize the cluster of the specified member 211 | * nodes (roles) and wait until all joined and `Up`. 212 | * First node will be started first and others will join 213 | * the first. 214 | */ 215 | def awaitClusterUp(roles: RoleName*): Unit = { 216 | runOn(roles.head) { 217 | // make sure that the node-to-join is started before other join 218 | startClusterNode() 219 | } 220 | enterBarrier(roles.head.name + "-started") 221 | if (roles.tail.contains(myself)) { 222 | cluster.join(roles.head) 223 | } 224 | if (roles.contains(myself)) { 225 | awaitMembersUp(numberOfMembers = roles.length) 226 | } 227 | enterBarrier(roles.map(_.name).mkString("-") + "-joined") 228 | } 229 | 230 | /** 231 | * Join the specific node within the given period by sending repeated join 232 | * requests at periodic intervals until we succeed. 233 | */ 234 | def joinWithin(joinNode: RoleName, max: Duration = remainingOrDefault, interval: Duration = 1.second): Unit = { 235 | def memberInState(member: Address, status: Seq[MemberStatus]): Boolean = 236 | clusterView.members.exists { m => 237 | (m.address == member) && status.contains(m.status) 238 | } 239 | 240 | cluster.join(joinNode) 241 | awaitCond( 242 | { 243 | if (memberInState(joinNode, List(MemberStatus.Up)) && 244 | memberInState(myself, List(MemberStatus.Joining, MemberStatus.Up))) 245 | true 246 | else { 247 | cluster.join(joinNode) 248 | false 249 | } 250 | }, 251 | max, 252 | interval) 253 | } 254 | 255 | /** 256 | * Assert that the member addresses match the expected addresses in the 257 | * sort order used by the cluster. 258 | */ 259 | def assertMembers(gotMembers: Iterable[Member], expectedAddresses: Address*): Unit = { 260 | import Member.addressOrdering 261 | val members = gotMembers.toIndexedSeq 262 | members.size should ===(expectedAddresses.length) 263 | expectedAddresses.sorted.zipWithIndex.foreach { case (a, i) => members(i).address should ===(a) } 264 | } 265 | 266 | /** 267 | * Note that this can only be used for a cluster with all members 268 | * in Up status, i.e. use `awaitMembersUp` before using this method. 269 | * The reason for that is that the cluster leader is preferably a 270 | * member with status Up or Leaving and that information can't 271 | * be determined from the `RoleName`. 272 | */ 273 | def assertLeader(nodesInCluster: RoleName*): Unit = 274 | if (nodesInCluster.contains(myself)) assertLeaderIn(nodesInCluster.to(immutable.Seq)) 275 | 276 | /** 277 | * Assert that the cluster has elected the correct leader 278 | * out of all nodes in the cluster. First 279 | * member in the cluster ring is expected leader. 280 | * 281 | * Note that this can only be used for a cluster with all members 282 | * in Up status, i.e. use `awaitMembersUp` before using this method. 283 | * The reason for that is that the cluster leader is preferably a 284 | * member with status Up or Leaving and that information can't 285 | * be determined from the `RoleName`. 
286 | */ 287 | def assertLeaderIn(nodesInCluster: immutable.Seq[RoleName]): Unit = 288 | if (nodesInCluster.contains(myself)) { 289 | nodesInCluster.length should not be (0) 290 | val expectedLeader = roleOfLeader(nodesInCluster) 291 | val leader = clusterView.leader 292 | val isLeader = leader == Some(clusterView.selfAddress) 293 | assert( 294 | isLeader == isNode(expectedLeader), 295 | "expectedLeader [%s], got leader [%s], members [%s]".format(expectedLeader, leader, clusterView.members)) 296 | clusterView.status should (be(MemberStatus.Up).or(be(MemberStatus.Leaving))) 297 | } 298 | 299 | /** 300 | * Wait until the expected number of members has status Up has been reached. 301 | * Also asserts that nodes in the 'canNotBePartOfMemberRing' are *not* part of the cluster ring. 302 | */ 303 | def awaitMembersUp( 304 | numberOfMembers: Int, 305 | canNotBePartOfMemberRing: Set[Address] = Set.empty, 306 | timeout: FiniteDuration = 25.seconds): Unit = { 307 | within(timeout) { 308 | if (!canNotBePartOfMemberRing.isEmpty) // don't run this on an empty set 309 | awaitAssert(canNotBePartOfMemberRing.foreach(a => clusterView.members.map(_.address) should not contain (a))) 310 | awaitAssert(clusterView.members.size should ===(numberOfMembers)) 311 | awaitAssert(clusterView.members.unsorted.map(_.status) should ===(Set(MemberStatus.Up))) 312 | // clusterView.leader is updated by LeaderChanged, await that to be updated also 313 | val expectedLeader = clusterView.members.collectFirst { 314 | case m if m.dataCenter == cluster.settings.SelfDataCenter => m.address 315 | } 316 | awaitAssert(clusterView.leader should ===(expectedLeader)) 317 | } 318 | } 319 | 320 | def awaitMemberRemoved(toBeRemovedAddress: Address, timeout: FiniteDuration = 25.seconds): Unit = within(timeout) { 321 | if (toBeRemovedAddress == cluster.selfAddress) { 322 | enterBarrier("registered-listener") 323 | 324 | cluster.leave(toBeRemovedAddress) 325 | enterBarrier("member-left") 326 | 327 | awaitCond(cluster.isTerminated, remaining) 328 | enterBarrier("member-shutdown") 329 | } else { 330 | val exitingLatch = TestLatch() 331 | 332 | val awaiter = system.actorOf(Props(new Actor { 333 | def receive = { 334 | case MemberRemoved(m, _) if m.address == toBeRemovedAddress => 335 | exitingLatch.countDown() 336 | case _ => 337 | // ignore 338 | } 339 | }).withDeploy(Deploy.local)) 340 | cluster.subscribe(awaiter, classOf[MemberEvent]) 341 | enterBarrier("registered-listener") 342 | 343 | // in the meantime member issues leave 344 | enterBarrier("member-left") 345 | 346 | // verify that the member is EXITING 347 | try Await.result(exitingLatch, timeout) 348 | catch { 349 | case cause: Exception => 350 | throw new AssertionError(s"Member ${toBeRemovedAddress} was not removed within ${timeout}!", cause) 351 | } 352 | awaiter ! PoisonPill // you've done your job, now die 353 | 354 | enterBarrier("member-shutdown") 355 | markNodeAsUnavailable(toBeRemovedAddress) 356 | } 357 | 358 | enterBarrier("member-totally-shutdown") 359 | } 360 | 361 | def awaitAllReachable(): Unit = 362 | awaitAssert(clusterView.unreachableMembers should ===(Set.empty)) 363 | 364 | /** 365 | * Wait until the specified nodes have seen the same gossip overview. 366 | */ 367 | def awaitSeenSameState(addresses: Address*): Unit = 368 | awaitAssert((addresses.toSet.diff(clusterView.seenBy)) should ===(Set.empty)) 369 | 370 | /** 371 | * Leader according to the address ordering of the roles. 
372 | * Note that this can only be used for a cluster with all members 373 | * in Up status, i.e. use `awaitMembersUp` before using this method. 374 | * The reason for that is that the cluster leader is preferably a 375 | * member with status Up or Leaving and that information can't 376 | * be determined from the `RoleName`. 377 | */ 378 | def roleOfLeader(nodesInCluster: immutable.Seq[RoleName] = roles): RoleName = { 379 | nodesInCluster.length should not be (0) 380 | nodesInCluster.sorted.head 381 | } 382 | 383 | /** 384 | * Sort the roles in the address order used by the cluster node ring. 385 | */ 386 | implicit val clusterOrdering: Ordering[RoleName] = new Ordering[RoleName] { 387 | import Member.addressOrdering 388 | def compare(x: RoleName, y: RoleName) = addressOrdering.compare(address(x), address(y)) 389 | } 390 | 391 | def roleName(addr: Address): Option[RoleName] = roles.find(address(_) == addr) 392 | 393 | /** 394 | * Marks a node as available in the failure detector if 395 | * [[akka.cluster.FailureDetectorPuppet]] is used as 396 | * failure detector. 397 | */ 398 | def markNodeAsAvailable(address: Address): Unit = 399 | failureDetectorPuppet(address).foreach(_.markNodeAsAvailable()) 400 | 401 | /** 402 | * Marks a node as unavailable in the failure detector if 403 | * [[akka.cluster.FailureDetectorPuppet]] is used as 404 | * failure detector. 405 | */ 406 | def markNodeAsUnavailable(address: Address): Unit = { 407 | if (isFailureDetectorPuppet) { 408 | // before marking it as unavailable there should be at least one heartbeat 409 | // to create the FailureDetectorPuppet in the FailureDetectorRegistry 410 | cluster.failureDetector.heartbeat(address) 411 | failureDetectorPuppet(address).foreach(_.markNodeAsUnavailable()) 412 | } 413 | } 414 | 415 | private def isFailureDetectorPuppet: Boolean = 416 | cluster.settings.FailureDetectorImplementationClass == classOf[FailureDetectorPuppet].getName 417 | 418 | private def failureDetectorPuppet(address: Address): Option[FailureDetectorPuppet] = 419 | cluster.failureDetector match { 420 | case reg: DefaultFailureDetectorRegistry[Address] => 421 | reg.failureDetector(address).collect { case p: FailureDetectorPuppet => p } 422 | case _ => None 423 | } 424 | 425 | } 426 | -------------------------------------------------------------------------------- /src/multi-jvm/scala/akka/cluster/sharding/MultiNodeClusterShardingConfig.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2019 Lightbend Inc. 3 | */ 4 | 5 | package akka.cluster.sharding 6 | 7 | import akka.cluster.MultiNodeClusterSpec 8 | import akka.remote.testkit.MultiNodeConfig 9 | import akka.testkit.AkkaSpec 10 | import com.typesafe.config.{Config, ConfigFactory} 11 | 12 | /** 13 | * A MultiNodeConfig for ClusterSharding. Implement the roles, etc. 
and create with the following: 14 | * 15 | * @param mode the state store mode 16 | * @param rememberEntities defaults to off 17 | * @param overrides additional config 18 | * @param loglevel defaults to INFO 19 | */ 20 | abstract class MultiNodeClusterShardingConfig( 21 | val mode: String = ClusterShardingSettings.StateStoreModeDData, 22 | val rememberEntities: Boolean = false, 23 | overrides: Config = ConfigFactory.empty, 24 | loglevel: String = "INFO") 25 | extends MultiNodeConfig { 26 | 27 | val targetDir = s"target/ClusterSharding${AkkaSpec.getCallerName(getClass)}Spec-$mode-remember-$rememberEntities" 28 | 29 | val modeConfig = 30 | if (mode == ClusterShardingSettings.StateStoreModeDData) ConfigFactory.empty 31 | else ConfigFactory.parseString(s""" 32 | akka.persistence.journal.plugin = "akka.persistence.journal.leveldb-shared" 33 | akka.persistence.journal.leveldb-shared.timeout = 5s 34 | akka.persistence.journal.leveldb-shared.store.native = off 35 | akka.persistence.journal.leveldb-shared.store.dir = "$targetDir/journal" 36 | akka.persistence.snapshot-store.plugin = "akka.persistence.snapshot-store.local" 37 | akka.persistence.snapshot-store.local.dir = "$targetDir/snapshots" 38 | """) 39 | 40 | commonConfig( 41 | overrides 42 | .withFallback(modeConfig) 43 | .withFallback(ConfigFactory.parseString(s""" 44 | akka.loglevel = $loglevel 45 | akka.actor.provider = "cluster" 46 | akka.cluster.downing-provider-class = akka.cluster.testkit.AutoDowning 47 | akka.cluster.testkit.auto-down-unreachable-after = 0s 48 | akka.remote.log-remote-lifecycle-events = off 49 | akka.cluster.sharding.state-store-mode = "$mode" 50 | akka.cluster.sharding.distributed-data.durable.lmdb { 51 | dir = $targetDir/sharding-ddata 52 | map-size = 10 MiB 53 | } 54 | akka.actor.allow-java-serialization = on 55 | """)) 56 | // .withFallback(SharedLeveldbJournal.configToEnableJavaSerializationForTest) 57 | .withFallback(MultiNodeClusterSpec.clusterConfig)) 58 | 59 | } 60 | -------------------------------------------------------------------------------- /src/multi-jvm/scala/akka/cluster/sharding/MultiNodeClusterShardingSpec.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2019 Lightbend Inc. 3 | */ 4 | 5 | package akka.cluster.sharding 6 | 7 | import java.io.File 8 | 9 | import scala.concurrent.duration._ 10 | import akka.actor.{Actor, ActorIdentity, ActorRef, ActorSystem, Identify, PoisonPill, Props} 11 | import akka.cluster.{Cluster, MemberStatus, MultiNodeClusterSpec} 12 | import akka.persistence.Persistence 13 | import akka.persistence.journal.leveldb.{SharedLeveldbJournal, SharedLeveldbStore} 14 | import akka.remote.testconductor.RoleName 15 | import akka.remote.testkit.MultiNodeSpec 16 | import akka.testkit.TestProbe 17 | import org.apache.commons.io.FileUtils 18 | 19 | object MultiNodeClusterShardingSpec { 20 | 21 | final case class EntityStarted(ref: ActorRef) 22 | 23 | def props(probe: ActorRef): Props = Props(new EntityActor(probe)) 24 | 25 | class EntityActor(probe: ActorRef) extends Actor { 26 | probe ! EntityStarted(self) 27 | 28 | def receive: Receive = { 29 | case m => sender() ! 
30 |     }
31 |   }
32 | 
33 |   val defaultExtractEntityId: ShardRegion.ExtractEntityId = {
34 |     case id: Int => (id.toString, id)
35 |   }
36 | 
37 |   val defaultExtractShardId: ShardRegion.ExtractShardId = msg =>
38 |     msg match {
39 |       case id: Int => id.toString
40 |       case ShardRegion.StartEntity(id) => id
41 |     }
42 | 
43 | }
44 | 
45 | abstract class MultiNodeClusterShardingSpec(val config: MultiNodeClusterShardingConfig)
46 |     extends MultiNodeSpec(config)
47 |     with MultiNodeClusterSpec {
48 | 
49 |   import MultiNodeClusterShardingSpec._
50 |   import config._
51 | 
52 |   override def initialParticipants: Int = roles.size
53 | 
54 |   protected val storageLocations = List(
55 |     new File(system.settings.config.getString("akka.cluster.sharding.distributed-data.durable.lmdb.dir")).getParentFile)
56 | 
57 |   override protected def atStartup(): Unit = {
58 |     storageLocations.foreach(dir => if (dir.exists) FileUtils.deleteQuietly(dir))
59 |     enterBarrier("startup")
60 |     super.atStartup()
61 |   }
62 | 
63 |   override protected def afterTermination(): Unit = {
64 |     storageLocations.foreach(dir => if (dir.exists) FileUtils.deleteQuietly(dir))
65 |     super.afterTermination()
66 |   }
67 | 
68 |   protected def join(from: RoleName, to: RoleName): Unit = {
69 |     runOn(from) {
70 |       Cluster(system).join(node(to).address)
71 |       awaitAssert {
72 |         Cluster(system).state.members.exists(m => m.address == node(from).address && m.status == MemberStatus.Up) should be(true)
73 |       }
74 |     }
75 |     enterBarrier(from.name + "-joined")
76 |   }
77 | 
78 |   protected def startSharding(
79 |       sys: ActorSystem,
80 |       entityProps: Props,
81 |       dataType: String,
82 |       extractEntityId: ShardRegion.ExtractEntityId = defaultExtractEntityId,
83 |       extractShardId: ShardRegion.ExtractShardId = defaultExtractShardId,
84 |       handOffStopMessage: Any = PoisonPill): ActorRef = {
85 | 
86 |     ClusterSharding(sys).start(
87 |       typeName = dataType,
88 |       entityProps = entityProps,
89 |       settings = ClusterShardingSettings(sys).withRememberEntities(rememberEntities),
90 |       extractEntityId = extractEntityId,
91 |       extractShardId = extractShardId,
92 |       ClusterSharding(sys).defaultShardAllocationStrategy(ClusterShardingSettings(sys)),
93 |       handOffStopMessage)
94 |   }
95 | 
96 |   protected def isDdataMode: Boolean = mode == ClusterShardingSettings.StateStoreModeDData
97 | 
98 |   private def setStoreIfNotDdataMode(sys: ActorSystem, storeOn: RoleName): Unit =
99 |     if (!isDdataMode) {
100 |       val probe = TestProbe()(sys)
101 |       sys.actorSelection(node(storeOn) / "user" / "store").tell(Identify(None), probe.ref)
102 |       val sharedStore = probe.expectMsgType[ActorIdentity](20.seconds).ref.get
103 |       SharedLeveldbJournal.setStore(sharedStore, sys)
104 |     }
105 | 
106 |   /**
107 |    * {{{
108 |    *   startPersistenceIfNotDdataMode(startOn = first, setStoreOn = Seq(first, second, third))
109 |    * }}}
110 |    *
111 |    * @param startOn the node to start the `SharedLeveldbStore` store on
112 |    * @param setStoreOn the nodes to `SharedLeveldbJournal.setStore` on
113 |    */
114 |   protected def startPersistenceIfNotDdataMode(startOn: RoleName, setStoreOn: Seq[RoleName]): Unit =
115 |     if (!isDdataMode) {
116 | 
117 |       Persistence(system)
118 |       runOn(startOn) {
119 |         system.actorOf(Props[SharedLeveldbStore], "store")
120 |       }
121 |       enterBarrier("persistence-started")
122 | 
123 |       runOn(setStoreOn: _*) {
124 |         setStoreIfNotDdataMode(system, startOn)
125 |       }
126 | 
127 |       enterBarrier(s"after-${startOn.name}")
128 | 
129 |     }
130 | 
131 | }
132 | 
133 | 
--------------------------------------------------------------------------------
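For orientation, here is a minimal sketch of how a concrete multi-JVM test is meant to combine the two base classes above. All names (`MyShardingSpecConfig`, `MyShardingSpec`) are illustrative only and not part of this repository:

```scala
import akka.cluster.sharding.{ MultiNodeClusterShardingConfig, MultiNodeClusterShardingSpec }
import akka.remote.testkit.STMultiNodeSpec

// Hypothetical two-node configuration using the default ddata state store mode
object MyShardingSpecConfig extends MultiNodeClusterShardingConfig() {
  val first = role("first")
  val second = role("second")
}

// The multi-jvm plugin requires one concrete class per participating JVM
class MyShardingSpecMultiJvmNode1 extends MyShardingSpec
class MyShardingSpecMultiJvmNode2 extends MyShardingSpec

class MyShardingSpec extends MultiNodeClusterShardingSpec(MyShardingSpecConfig) with STMultiNodeSpec {
  import MyShardingSpecConfig._

  "Cluster sharding" must {
    "form a cluster and start a shard region" in {
      join(first, first) // the first node joins itself to bootstrap the cluster
      join(second, first)
      // EntityActor reports its startup to the probe and echoes messages back
      startSharding(system, entityProps = MultiNodeClusterShardingSpec.props(testActor), dataType = "TestEntity")
      enterBarrier("regions-started")
    }
  }
}
```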
/src/multi-jvm/scala/akka/cluster/testkit/AutoDowning.scala:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (C) 2009-2019 Lightbend Inc.
3 |  */
4 | 
5 | package akka.cluster.testkit
6 | 
7 | import scala.concurrent.duration.Duration
8 | import scala.concurrent.duration.FiniteDuration
9 | 
10 | import akka.actor.Actor
11 | import akka.actor.ActorLogging
12 | import akka.actor.ActorSystem
13 | import akka.actor.Address
14 | import akka.actor.Cancellable
15 | import akka.actor.Props
16 | import akka.actor.Scheduler
17 | import akka.cluster.Cluster
18 | import akka.cluster.ClusterEvent._
19 | import akka.cluster.DowningProvider
20 | import akka.cluster.Member
21 | import akka.cluster.MembershipState
22 | import akka.cluster.UniqueAddress
23 | import akka.util.Helpers.ConfigOps
24 | import akka.util.Helpers.Requiring
25 | import akka.util.Helpers.toRootLowerCase
26 | 
27 | /**
28 |  * Downing provider used for testing.
29 |  *
30 |  * Auto-downing is a naïve approach to removing unreachable nodes from the cluster membership.
31 |  * In a production environment it will eventually break down the cluster.
32 |  * When a network partition occurs, both sides of the partition will see the other side as unreachable
33 |  * and remove it from the cluster. This results in the formation of two separate, disconnected clusters
34 |  * (known as *Split Brain*).
35 |  *
36 |  * This behavior is not limited to network partitions. It can also occur if a node in the cluster is
37 |  * overloaded or experiences a long GC pause.
38 |  *
39 |  * When using Cluster Singleton or Cluster Sharding it can break the contract provided by those features.
40 |  * Both provide a guarantee that an actor will be unique in a cluster.
41 |  * With the auto-down feature enabled, it is possible for multiple independent clusters to form (*Split Brain*).
42 |  * When this happens, the guaranteed uniqueness no longer holds, resulting in undesirable behavior
43 |  * in the system.
44 |  *
45 |  * This is even more severe when Akka Persistence is used in conjunction with Cluster Sharding.
46 |  * In this case, the lack of unique actors can cause multiple actors to write to the same journal.
47 |  * Akka Persistence operates on a single-writer principle. Having multiple writers will corrupt
48 |  * the journal and make it unusable.
49 |  *
50 |  * Finally, even if you don't use features such as Persistence, Sharding, or Singletons, auto-downing can lead the
51 |  * system to form multiple small clusters. These small clusters will be independent of each other. They will be
52 |  * unable to communicate and as a result you may experience performance degradation. Once this condition occurs,
53 |  * it will require manual intervention to re-form the cluster.
54 |  *
55 |  * Because of these issues, auto-downing should never be used in a production environment.
56 |  */
57 | final class AutoDowning(system: ActorSystem) extends DowningProvider {
58 | 
59 |   private def clusterSettings = Cluster(system).settings
60 | 
61 |   private val AutoDownUnreachableAfter: Duration = {
62 |     val key = "akka.cluster.testkit.auto-down-unreachable-after"
63 |     // not defined in reference.conf since it is only used in tests
64 |     if (clusterSettings.config.hasPath(key)) {
65 |       toRootLowerCase(clusterSettings.config.getString(key)) match {
66 |         case "off" => Duration.Undefined
67 |         case _     => clusterSettings.config.getMillisDuration(key).requiring(_ >= Duration.Zero, key + " >= 0s, or off")
68 |       }
69 |     } else
70 |       Duration.Undefined
71 |   }
72 | 
73 |   override def downRemovalMargin: FiniteDuration = clusterSettings.DownRemovalMargin
74 | 
75 |   override def downingActorProps: Option[Props] =
76 |     AutoDownUnreachableAfter match {
77 |       case d: FiniteDuration => Some(AutoDown.props(d))
78 |       case _                 => None // auto-down-unreachable-after = off
79 |     }
80 | }
81 | 
82 | /**
83 |  * INTERNAL API
84 |  */
85 | private[cluster] object AutoDown {
86 | 
87 |   def props(autoDownUnreachableAfter: FiniteDuration): Props =
88 |     Props(classOf[AutoDown], autoDownUnreachableAfter)
89 | 
90 |   final case class UnreachableTimeout(node: UniqueAddress)
91 | }
92 | 
93 | /**
94 |  * INTERNAL API
95 |  *
96 |  * An unreachable member will be downed by this actor if it remains unreachable
97 |  * for the specified duration and this actor is running on the leader node in the
98 |  * cluster.
99 |  *
100 |  * The implementation is split into two classes AutoDown and AutoDownBase to be
101 |  * able to unit test the logic without running a cluster.
102 |  */
103 | private[cluster] class AutoDown(autoDownUnreachableAfter: FiniteDuration)
104 |     extends AutoDownBase(autoDownUnreachableAfter)
105 |     with ActorLogging {
106 | 
107 |   val cluster = Cluster(context.system)
108 |   import cluster.ClusterLogger._
109 | 
110 |   override def selfAddress = cluster.selfAddress
111 | 
112 |   override def scheduler: Scheduler = cluster.scheduler
113 | 
114 |   // re-subscribe on restart
115 |   override def preStart(): Unit = {
116 |     log.debug("Auto-down is enabled in test.")
117 |     cluster.subscribe(self, classOf[ClusterDomainEvent])
118 |     super.preStart()
119 |   }
120 |   override def postStop(): Unit = {
121 |     cluster.unsubscribe(self)
122 |     super.postStop()
123 |   }
124 | 
125 |   override def down(node: Address): Unit = {
126 |     require(leader)
127 |     logInfo("Leader is auto-downing unreachable node [{}].", node)
128 |     cluster.down(node)
129 |   }
130 | 
131 | }
132 | 
133 | /**
134 |  * INTERNAL API
135 |  *
136 |  * The implementation is split into two classes AutoDown and AutoDownBase to be
137 |  * able to unit test the logic without running a cluster.
138 |  */
139 | private[cluster] abstract class AutoDownBase(autoDownUnreachableAfter: FiniteDuration) extends Actor {
140 | 
141 |   import AutoDown._
142 | 
143 |   def selfAddress: Address
144 | 
145 |   def down(node: Address): Unit
146 | 
147 |   def scheduler: Scheduler
148 | 
149 |   import context.dispatcher
150 | 
151 |   val skipMemberStatus = MembershipState.convergenceSkipUnreachableWithMemberStatus
152 | 
153 |   var scheduledUnreachable: Map[UniqueAddress, Cancellable] = Map.empty
154 |   var pendingUnreachable: Set[UniqueAddress] = Set.empty
155 |   var leader = false
156 | 
157 |   override def postStop(): Unit = {
158 |     scheduledUnreachable.values.foreach { _.cancel }
159 |   }
160 | 
161 |   def receive = {
162 |     case state: CurrentClusterState =>
163 |       leader = state.leader.exists(_ == selfAddress)
164 |       state.unreachable.foreach(unreachableMember)
165 | 
166 |     case UnreachableMember(m) => unreachableMember(m)
167 | 
168 |     case ReachableMember(m)  => remove(m.uniqueAddress)
169 |     case MemberRemoved(m, _) => remove(m.uniqueAddress)
170 | 
171 |     case LeaderChanged(leaderOption) =>
172 |       leader = leaderOption.exists(_ == selfAddress)
173 |       if (leader) {
174 |         pendingUnreachable.foreach(node => down(node.address))
175 |         pendingUnreachable = Set.empty
176 |       }
177 | 
178 |     case UnreachableTimeout(node) =>
179 |       if (scheduledUnreachable contains node) {
180 |         scheduledUnreachable -= node
181 |         downOrAddPending(node)
182 |       }
183 | 
184 |     case _: ClusterDomainEvent => // not interested in other events
185 | 
186 |   }
187 | 
188 |   def unreachableMember(m: Member): Unit =
189 |     if (!skipMemberStatus(m.status) && !scheduledUnreachable.contains(m.uniqueAddress))
190 |       scheduleUnreachable(m.uniqueAddress)
191 | 
192 |   def scheduleUnreachable(node: UniqueAddress): Unit = {
193 |     if (autoDownUnreachableAfter == Duration.Zero) {
194 |       downOrAddPending(node)
195 |     } else {
196 |       val task = scheduler.scheduleOnce(autoDownUnreachableAfter, self, UnreachableTimeout(node))
197 |       scheduledUnreachable += (node -> task)
198 |     }
199 |   }
200 | 
201 |   def downOrAddPending(node: UniqueAddress): Unit = {
202 |     if (leader) {
203 |       down(node.address)
204 |     } else {
205 |       // the node is supposed to be downed by the current leader, but if that leader
206 |       // crashes, a new leader must pick up the pending nodes
207 |       pendingUnreachable += node
208 |     }
209 |   }
210 | 
211 |   def remove(node: UniqueAddress): Unit = {
212 |     scheduledUnreachable.get(node).foreach { _.cancel }
213 |     scheduledUnreachable -= node
214 |     pendingUnreachable -= node
215 |   }
216 | 
217 | }
218 | 
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/remote/testkit/STMultiNodeSpec.scala:
--------------------------------------------------------------------------------
1 | package akka.remote.testkit
2 | 
3 | import org.scalatest.{BeforeAndAfterAll, WordSpecLike}
4 | import org.scalatest.Matchers
5 | 
6 | /**
7 |  * Hooks up MultiNodeSpec with ScalaTest
8 |  */
9 | trait STMultiNodeSpec extends MultiNodeSpecCallbacks with WordSpecLike with Matchers with BeforeAndAfterAll {
10 | 
11 |   override def beforeAll() = multiNodeSpecBeforeAll()
12 | 
13 |   override def afterAll() = multiNodeSpecAfterAll()
14 | }
--------------------------------------------------------------------------------
/src/multi-jvm/scala/akka/serialization/jackson/CborSerializable.scala:
--------------------------------------------------------------------------------
1 | package akka.serialization.jackson
2 | 
3 | trait CborSerializable
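// Note (added for clarity, not part of the original file): this marker trait
// only has an effect if the test configuration binds it to the Jackson CBOR
// serializer. A sketch of such a binding, assumed here rather than taken from
// this repository's configuration files:
//
//   akka.actor.serialization-bindings {
//     "akka.serialization.jackson.CborSerializable" = jackson-cbor
//   }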
-------------------------------------------------------------------------------- /src/multi-jvm/scala/akka/testkit/AkkaSpec.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (C) 2009-2019 Lightbend Inc. 3 | */ 4 | 5 | package akka.testkit 6 | 7 | import org.scalactic.{ CanEqual, TypeCheckedTripleEquals } 8 | 9 | import language.postfixOps 10 | import org.scalatest.{ BeforeAndAfterAll, WordSpecLike } 11 | import org.scalatest.Matchers 12 | import akka.actor.ActorSystem 13 | import akka.event.{ Logging, LoggingAdapter } 14 | 15 | import scala.concurrent.Future 16 | import com.typesafe.config.{ Config, ConfigFactory } 17 | import akka.dispatch.Dispatchers 18 | import akka.testkit.TestEvent._ 19 | import org.scalatest.concurrent.ScalaFutures 20 | import org.scalatest.time.{ Millis, Span } 21 | 22 | object AkkaSpec { 23 | val testConf: Config = ConfigFactory.parseString(""" 24 | akka { 25 | loggers = ["akka.testkit.TestEventListener"] 26 | loglevel = "WARNING" 27 | stdout-loglevel = "WARNING" 28 | actor { 29 | default-dispatcher { 30 | executor = "fork-join-executor" 31 | fork-join-executor { 32 | parallelism-min = 8 33 | parallelism-factor = 2.0 34 | parallelism-max = 8 35 | } 36 | } 37 | } 38 | } 39 | """) 40 | 41 | def mapToConfig(map: Map[String, Any]): Config = { 42 | import akka.util.ccompat.JavaConverters._ 43 | ConfigFactory.parseMap(map.asJava) 44 | } 45 | 46 | def getCallerName(clazz: Class[_]): String = { 47 | val s = Thread.currentThread.getStackTrace 48 | .map(_.getClassName) 49 | .drop(1) 50 | .dropWhile(_.matches("(java.lang.Thread|.*AkkaSpec.*|.*\\.StreamSpec.*|.*MultiNodeSpec.*|.*\\.Abstract.*)")) 51 | val reduced = s.lastIndexWhere(_ == clazz.getName) match { 52 | case -1 => s 53 | case z => s.drop(z + 1) 54 | } 55 | reduced.head.replaceFirst(""".*\.""", "").replaceAll("[^a-zA-Z_0-9]", "_") 56 | } 57 | 58 | } 59 | 60 | abstract class AkkaSpec(_system: ActorSystem) 61 | extends TestKit(_system) 62 | with WordSpecLike 63 | with Matchers 64 | with BeforeAndAfterAll 65 | with TypeCheckedTripleEquals 66 | with ScalaFutures { 67 | 68 | implicit val patience = PatienceConfig(testKitSettings.DefaultTimeout.duration, Span(100, Millis)) 69 | 70 | def this(config: Config) = 71 | this(ActorSystem(AkkaSpec.getCallerName(getClass), ConfigFactory.load(config.withFallback(AkkaSpec.testConf)))) 72 | 73 | def this(s: String) = this(ConfigFactory.parseString(s)) 74 | 75 | def this(configMap: Map[String, _]) = this(AkkaSpec.mapToConfig(configMap)) 76 | 77 | def this() = this(ActorSystem(AkkaSpec.getCallerName(getClass), AkkaSpec.testConf)) 78 | 79 | val log: LoggingAdapter = Logging(system, this.getClass) 80 | 81 | override val invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected = true 82 | 83 | final override def beforeAll: Unit = { 84 | atStartup() 85 | } 86 | 87 | final override def afterAll: Unit = { 88 | beforeTermination() 89 | shutdown() 90 | afterTermination() 91 | } 92 | 93 | protected def atStartup(): Unit = {} 94 | 95 | protected def beforeTermination(): Unit = {} 96 | 97 | protected def afterTermination(): Unit = {} 98 | 99 | def spawn(dispatcherId: String = Dispatchers.DefaultDispatcherId)(body: => Unit): Unit = 100 | Future(body)(system.dispatchers.lookup(dispatcherId)) 101 | 102 | def muteDeadLetters(messageClasses: Class[_]*)(sys: ActorSystem = system): Unit = 103 | if (!sys.log.isDebugEnabled) { 104 | def mute(clazz: Class[_]): Unit = 105 | sys.eventStream.publish(Mute(DeadLettersFilter(clazz)(occurrences = 
Int.MaxValue))) 106 | if (messageClasses.isEmpty) mute(classOf[AnyRef]) 107 | else messageClasses.foreach(mute) 108 | } 109 | 110 | // for ScalaTest === compare of Class objects 111 | implicit def classEqualityConstraint[A, B]: CanEqual[Class[A], Class[B]] = 112 | new CanEqual[Class[A], Class[B]] { 113 | def areEqual(a: Class[A], b: Class[B]) = a == b 114 | } 115 | 116 | implicit def setEqualityConstraint[A, T <: Set[_ <: A]]: CanEqual[Set[A], T] = 117 | new CanEqual[Set[A], T] { 118 | def areEqual(a: Set[A], b: T) = a == b 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /src/multi-jvm/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRouterNewShardsSpec.scala: -------------------------------------------------------------------------------- 1 | package io.bernhardt.akka.locality.router 2 | 3 | import akka.actor.{ActorRef, Props} 4 | import akka.cluster.routing.{ClusterRouterGroup, ClusterRouterGroupSettings} 5 | import akka.cluster.sharding.{MultiNodeClusterShardingConfig, MultiNodeClusterShardingSpec} 6 | import akka.pattern.ask 7 | import akka.remote.testconductor.RoleName 8 | import akka.remote.testkit.STMultiNodeSpec 9 | import akka.routing.{GetRoutees, Routees} 10 | import akka.testkit.{DefaultTimeout, ImplicitSender, TestProbe} 11 | import com.typesafe.config.ConfigFactory 12 | import io.bernhardt.akka.locality.Locality 13 | 14 | import scala.concurrent.Await 15 | import scala.concurrent.duration._ 16 | 17 | object ShardLocationAwareRouterNewShardsSpecConfig extends MultiNodeClusterShardingConfig( 18 | overrides = ConfigFactory.parseString( 19 | """akka.cluster.sharding.distributed-data.majority-min-cap = 2 20 | |akka.cluster.distributed-data.gossip-interval = 200 ms 21 | |akka.cluster.distributed-data.notify-subscribers-interval = 200 ms 22 | |akka.cluster.sharding.updating-state-timeout = 500 ms 23 | |akka.locality.shard-state-update-margin = 1000 ms 24 | |akka.locality.shard-state-polling-interval = 2000 ms 25 | |""".stripMargin)) { 26 | val first = role("first") 27 | val second = role("second") 28 | val third = role("third") 29 | val fourth = role("fourth") 30 | val fifth = role("fifth") 31 | } 32 | 33 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode1 extends ShardLocationAwareRouterNewShardsSpec 34 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode2 extends ShardLocationAwareRouterNewShardsSpec 35 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode3 extends ShardLocationAwareRouterNewShardsSpec 36 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode4 extends ShardLocationAwareRouterNewShardsSpec 37 | class ShardLocationAwareRouterNewShardsSpecMultiJvmNode5 extends ShardLocationAwareRouterNewShardsSpec 38 | 39 | class ShardLocationAwareRouterNewShardsSpec extends MultiNodeClusterShardingSpec(ShardLocationAwareRouterNewShardsSpecConfig) 40 | with STMultiNodeSpec 41 | with DefaultTimeout 42 | with ImplicitSender { 43 | 44 | import ShardLocationAwareRouterNewShardsSpecConfig._ 45 | import ShardLocationAwareRouterSpec._ 46 | 47 | var region: Option[ActorRef] = None 48 | 49 | var router: ActorRef = ActorRef.noSender 50 | 51 | Locality(system) 52 | 53 | def joinAndAllocate(node: RoleName, entityIds: Range): Unit = { 54 | within(10.seconds) { 55 | join(node, first) 56 | runOn(node) { 57 | val region = startSharding( 58 | sys = system, 59 | entityProps = Props[TestEntity], 60 | dataType = "TestEntity", 61 | extractEntityId = extractEntityId, 62 | extractShardId = extractShardId) 63 | 64 | 
this.region = Some(region) 65 | 66 | entityIds.map { entityId => 67 | val probe = TestProbe("test") 68 | val msg = Ping(entityId, ActorRef.noSender) 69 | probe.send(region, msg) 70 | probe.expectMsgType[Pong] 71 | probe.lastSender.path should be(region.path / s"$entityId" / s"$entityId") 72 | } 73 | } 74 | } 75 | enterBarrier(s"started") 76 | } 77 | 78 | "allocate shards" in { 79 | 80 | joinAndAllocate(first, (1 to 10)) 81 | joinAndAllocate(second, (11 to 20)) 82 | joinAndAllocate(third, (21 to 30)) 83 | joinAndAllocate(fourth, (31 to 40)) 84 | joinAndAllocate(fifth, (41 to 50)) 85 | 86 | enterBarrier("shards-allocated") 87 | 88 | } 89 | 90 | "route by taking into account shard location" in { 91 | within(20.seconds) { 92 | 93 | region.map { r => 94 | system.actorOf(Props(new TestRoutee(r)), "routee") 95 | enterBarrier("routee-started") 96 | 97 | router = system.actorOf(ClusterRouterGroup(ShardLocationAwareGroup( 98 | routeePaths = Nil, 99 | shardRegion = r, 100 | extractEntityId = extractEntityId, 101 | extractShardId = extractShardId 102 | ), ClusterRouterGroupSettings( 103 | totalInstances = 5, 104 | routeesPaths = List("/user/routee"), 105 | allowLocalRoutees = true 106 | )).props(), "sharding-aware-router") 107 | 108 | awaitAssert { 109 | currentRoutees(router).size shouldBe 5 110 | } 111 | 112 | enterBarrier("router-started") 113 | 114 | runOn(first) { 115 | val probe = TestProbe("probe") 116 | for (i <- 1 to 50) { 117 | probe.send(router, Ping(i, probe.ref)) 118 | } 119 | 120 | val msgs: Seq[Pong] = probe.receiveN(50).collect { case p: Pong => p } 121 | 122 | val (same: Seq[Pong], different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) => 123 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty 124 | } 125 | 126 | different.isEmpty shouldBe true 127 | } 128 | enterBarrier("test-done") 129 | 130 | } getOrElse { 131 | fail("Region not set") 132 | } 133 | } 134 | 135 | } 136 | 137 | "route randomly when new shards are created" in { 138 | runOn(first) { 139 | val probe = TestProbe("probe") 140 | for (i <- 51 to 100) { 141 | probe.send(router, Ping(i, probe.ref)) 142 | } 143 | 144 | val msgs: Seq[Pong] = probe.receiveN(50).collect { case p: Pong => p } 145 | 146 | val (_, different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) => 147 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty 148 | } 149 | 150 | different.isEmpty shouldBe false 151 | } 152 | enterBarrier("test-done") 153 | } 154 | 155 | "route with location awareness after the update margin has elapsed" in { 156 | 157 | Thread.sleep(4000) 158 | 159 | runOn(first) { 160 | val probe = TestProbe("probe") 161 | for (i <- 51 to 100) { 162 | probe.send(router, Ping(i, probe.ref)) 163 | } 164 | 165 | val msgs: Seq[Pong] = probe.receiveN(50).collect { case p: Pong => p } 166 | 167 | val (_, different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) => 168 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty 169 | } 170 | 171 | different.isEmpty shouldBe true 172 | } 173 | enterBarrier("test-done") 174 | } 175 | 176 | 177 | 178 | def currentRoutees(router: ActorRef) = 179 | Await.result(router ? 
GetRoutees, timeout.duration).asInstanceOf[Routees].routees
180 | 
181 |   def partitionByAddress(msgs: Seq[Pong]) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) =>
182 |     routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
183 |   }
184 | 
185 | 
186 | }
187 | 
--------------------------------------------------------------------------------
/src/multi-jvm/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRouterSpec.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality.router
2 | 
3 | import akka.actor.{Actor, ActorRef, Address, Props}
4 | import akka.cluster.Cluster
5 | import akka.cluster.routing.{ClusterRouterGroup, ClusterRouterGroupSettings}
6 | import akka.cluster.sharding.{MultiNodeClusterShardingConfig, MultiNodeClusterShardingSpec, ShardRegion}
7 | import akka.pattern.ask
8 | import akka.remote.testconductor.RoleName
9 | import akka.remote.testkit.STMultiNodeSpec
10 | import akka.routing.{GetRoutees, Routees}
11 | import akka.serialization.jackson.CborSerializable
12 | import akka.testkit.{DefaultTimeout, ImplicitSender, TestProbe}
13 | import com.typesafe.config.ConfigFactory
14 | import io.bernhardt.akka.locality.Locality
15 | 
16 | import scala.concurrent.Await
17 | import scala.concurrent.duration._
18 | 
19 | object ShardLocationAwareRouterSpec {
20 | 
21 |   class TestEntity extends Actor {
22 |     val cluster = Cluster(context.system)
23 |     def receive: Receive = {
24 | 
25 |       case ping: Ping =>
26 |         sender() ! Pong(ping.id, ping.sender, cluster.selfAddress, cluster.selfAddress)
27 |     }
28 |   }
29 | 
30 |   final case class Ping(id: Int, sender: ActorRef) extends CborSerializable
31 |   final case class Pong(id: Int, sender: ActorRef, routeeAddress: Address, entityAddress: Address) extends CborSerializable
32 | 
33 |   val extractEntityId: ShardRegion.ExtractEntityId = {
34 |     case msg@Ping(id, _) => (id.toString, msg)
35 |     case msg@Pong(id, _, _, _) => (id.toString, msg)
36 |   }
37 | 
38 |   val extractShardId: ShardRegion.ExtractShardId = {
39 |     // deliberately use the simplest possible entity-to-shard mapping here
40 |     case Ping(id, _) => id.toString
41 |     case Pong(id, _, _, _) => id.toString
42 |   }
43 | 
44 |   class TestRoutee(region: ActorRef) extends Actor {
45 |     val cluster = Cluster(context.system)
46 |     def receive: Receive = {
47 |       case msg: Ping =>
48 |         region ! msg
49 |       case pong: Pong =>
50 |         pong.sender ! 
pong.copy(routeeAddress = cluster.selfAddress) 51 | } 52 | } 53 | 54 | } 55 | 56 | 57 | object ShardLocationAwareRouterSpecConfig extends MultiNodeClusterShardingConfig( 58 | overrides = ConfigFactory.parseString( 59 | """akka.cluster.sharding.distributed-data.majority-min-cap = 2 60 | |akka.cluster.distributed-data.gossip-interval = 200 ms 61 | |akka.cluster.distributed-data.notify-subscribers-interval = 200 ms 62 | |akka.cluster.sharding.updating-state-timeout = 500 ms 63 | |akka.locality.shard-state-update-margin = 1000 ms 64 | |""".stripMargin)) { 65 | val first = role("first") 66 | val second = role("second") 67 | val third = role("third") 68 | val fourth = role("fourth") 69 | val fifth = role("fifth") 70 | } 71 | 72 | class ShardLocationAwareRouterSpecMultiJvmNode1 extends ShardLocationAwareRouterSpec 73 | class ShardLocationAwareRouterSpecMultiJvmNode2 extends ShardLocationAwareRouterSpec 74 | class ShardLocationAwareRouterSpecMultiJvmNode3 extends ShardLocationAwareRouterSpec 75 | class ShardLocationAwareRouterSpecMultiJvmNode4 extends ShardLocationAwareRouterSpec 76 | class ShardLocationAwareRouterSpecMultiJvmNode5 extends ShardLocationAwareRouterSpec 77 | 78 | class ShardLocationAwareRouterSpec extends MultiNodeClusterShardingSpec(ShardLocationAwareRouterSpecConfig) 79 | with STMultiNodeSpec 80 | with DefaultTimeout 81 | with ImplicitSender { 82 | 83 | import ShardLocationAwareRouterSpec._ 84 | import ShardLocationAwareRouterSpecConfig._ 85 | 86 | var region: Option[ActorRef] = None 87 | 88 | var router: ActorRef = ActorRef.noSender 89 | 90 | Locality(system) 91 | 92 | def joinAndAllocate(node: RoleName, entityIds: Range): Unit = { 93 | within(10.seconds) { 94 | join(node, first) 95 | runOn(node) { 96 | val region = startSharding( 97 | sys = system, 98 | entityProps = Props[TestEntity], 99 | dataType = "TestEntity", 100 | extractEntityId = extractEntityId, 101 | extractShardId = extractShardId) 102 | 103 | this.region = Some(region) 104 | 105 | entityIds.map { entityId => 106 | val probe = TestProbe("test") 107 | val msg = Ping(entityId, ActorRef.noSender) 108 | probe.send(region, msg) 109 | probe.expectMsgType[Pong] 110 | probe.lastSender.path should be(region.path / s"$entityId" / s"$entityId") 111 | } 112 | } 113 | } 114 | enterBarrier(s"started") 115 | } 116 | 117 | "allocate shards" in { 118 | 119 | joinAndAllocate(first, (1 to 10)) 120 | joinAndAllocate(second, (11 to 20)) 121 | joinAndAllocate(third, (21 to 30)) 122 | joinAndAllocate(fourth, (31 to 40)) 123 | joinAndAllocate(fifth, (41 to 50)) 124 | 125 | enterBarrier("shards-allocated") 126 | 127 | } 128 | 129 | "route by taking into account shard location" in { 130 | within(20.seconds) { 131 | 132 | region.map { r => 133 | system.actorOf(Props(new TestRoutee(r)), "routee") 134 | enterBarrier("routee-started") 135 | 136 | router = system.actorOf(ClusterRouterGroup(ShardLocationAwareGroup( 137 | routeePaths = Nil, 138 | shardRegion = r, 139 | extractEntityId = extractEntityId, 140 | extractShardId = extractShardId 141 | ), ClusterRouterGroupSettings( 142 | totalInstances = 5, 143 | routeesPaths = List("/user/routee"), 144 | allowLocalRoutees = true 145 | )).props(), "sharding-aware-router") 146 | 147 | awaitAssert { 148 | currentRoutees(router).size shouldBe 5 149 | } 150 | 151 | enterBarrier("router-started") 152 | 153 | runOn(first) { 154 | val probe = TestProbe("probe") 155 | for (i <- 1 to 50) { 156 | probe.send(router, Ping(i, probe.ref)) 157 | } 158 | 159 | val msgs: Seq[Pong] = probe.receiveN(50).collect { case p: 
Pong => p } 160 | 161 | val (same: Seq[Pong], different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) => 162 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty 163 | } 164 | 165 | different.isEmpty shouldBe true 166 | 167 | val byAddress = same.groupBy(_.routeeAddress) 168 | 169 | byAddress(first).size shouldBe 10 170 | byAddress(first).map(_.id).toSet shouldEqual (1 to 10).toSet 171 | byAddress(second).size shouldBe 10 172 | byAddress(second).map(_.id).toSet shouldEqual (11 to 20).toSet 173 | byAddress(third).size shouldBe 10 174 | byAddress(third).map(_.id).toSet shouldEqual (21 to 30).toSet 175 | byAddress(fourth).size shouldBe 10 176 | byAddress(fourth).map(_.id).toSet shouldEqual (31 to 40).toSet 177 | byAddress(fifth).size shouldBe 10 178 | byAddress(fifth).map(_.id).toSet shouldEqual (41 to 50).toSet 179 | } 180 | enterBarrier("test-done") 181 | 182 | } getOrElse { 183 | fail("Region not set") 184 | } 185 | } 186 | 187 | } 188 | 189 | "adjust routing after a topology change" in { 190 | awaitMemberRemoved(fourth) 191 | awaitAllReachable() 192 | 193 | runOn(first) { 194 | // trigger rebalancing the shards of the removed node 195 | val rebalanceProbe = TestProbe("rebalance") 196 | for (i <- 31 to 40) { 197 | rebalanceProbe.send(router, Ping(i, rebalanceProbe.ref)) 198 | } 199 | 200 | // we should be receiving messages even in the absence of the updated shard location information 201 | // random routing should kick in, i.e. we won't have perfect matches 202 | val randomRoutedMessages: Seq[Pong] = rebalanceProbe.receiveN(10, 15.seconds).collect { case p: Pong => p } 203 | val (_, differentMsgs) = partitionByAddress(randomRoutedMessages) 204 | differentMsgs.nonEmpty shouldBe true 205 | 206 | // now give time to the new shards to be allocated and time to the router to retrieve new information 207 | // one second to send out the update shard information and one second of safety margin 208 | Thread.sleep(2000) 209 | 210 | val probe = TestProbe("probe") 211 | for (i <- 1 to 50) { 212 | probe.send(router, Ping(i, probe.ref)) 213 | } 214 | 215 | val msgs: Seq[Pong] = probe.receiveN(50, 15.seconds).collect { case p: Pong => p } 216 | 217 | val (_, different) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) => 218 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty 219 | } 220 | 221 | different.isEmpty shouldBe true 222 | 223 | } 224 | 225 | enterBarrier("finished") 226 | } 227 | 228 | 229 | def currentRoutees(router: ActorRef) = 230 | Await.result(router ? 
GetRoutees, timeout.duration).asInstanceOf[Routees].routees 231 | 232 | def partitionByAddress(msgs: Seq[Pong]) = msgs.partition { case Pong(_, _, routeeAddress, entityAddress) => 233 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty 234 | } 235 | 236 | } 237 | -------------------------------------------------------------------------------- /src/multi-jvm/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRouterWithProxySpec.scala: -------------------------------------------------------------------------------- 1 | package io.bernhardt.akka.locality.router 2 | 3 | import akka.actor.{ ActorRef, Props } 4 | import akka.cluster.routing.{ ClusterRouterGroup, ClusterRouterGroupSettings } 5 | import akka.cluster.sharding.{ ClusterSharding, MultiNodeClusterShardingConfig, MultiNodeClusterShardingSpec } 6 | import akka.pattern.ask 7 | import akka.remote.testconductor.RoleName 8 | import akka.remote.testkit.STMultiNodeSpec 9 | import akka.routing.{ GetRoutees, Routees } 10 | import akka.testkit.{ DefaultTimeout, ImplicitSender, TestProbe } 11 | import com.typesafe.config.ConfigFactory 12 | import io.bernhardt.akka.locality.Locality 13 | 14 | import scala.concurrent.Await 15 | import scala.concurrent.duration._ 16 | 17 | object ShardLocationAwareRouterWithProxySpecConfig 18 | extends MultiNodeClusterShardingConfig( 19 | loglevel = "INFO", 20 | overrides = ConfigFactory.parseString( 21 | """akka.cluster.sharding.distributed-data.majority-min-cap = 2 22 | |akka.cluster.distributed-data.gossip-interval = 200 ms 23 | |akka.cluster.distributed-data.notify-subscribers-interval = 200 ms 24 | |akka.cluster.sharding.updating-state-timeout = 500 ms 25 | |akka.locality.shard-state-update-margin = 1000 ms 26 | |""".stripMargin)) { 27 | val first = role("first") 28 | val second = role("second") 29 | val third = role("third") 30 | val fourth = role("fourth") 31 | val fifth = role("fifth") 32 | 33 | } 34 | 35 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode1 extends ShardLocationAwareRouterWithProxySpec 36 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode2 extends ShardLocationAwareRouterWithProxySpec 37 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode3 extends ShardLocationAwareRouterWithProxySpec 38 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode4 extends ShardLocationAwareRouterWithProxySpec 39 | class ShardLocationAwareRouterWithProxySpecMultiJvmNode5 extends ShardLocationAwareRouterWithProxySpec 40 | 41 | class ShardLocationAwareRouterWithProxySpec 42 | extends MultiNodeClusterShardingSpec(ShardLocationAwareRouterWithProxySpecConfig) 43 | with STMultiNodeSpec 44 | with DefaultTimeout 45 | with ImplicitSender { 46 | 47 | import ShardLocationAwareRouterSpec._ 48 | import ShardLocationAwareRouterWithProxySpecConfig._ 49 | 50 | var region: Option[ActorRef] = None 51 | var proxyRegion: Option[ActorRef] = None 52 | 53 | var router: ActorRef = ActorRef.noSender 54 | var proxyRouter: ActorRef = ActorRef.noSender 55 | 56 | Locality(system) 57 | 58 | def joinAndAllocate(node: RoleName, entityIds: Range): Unit = { 59 | within(10.seconds) { 60 | join(node, first) 61 | runOn(node) { 62 | val region = startSharding( 63 | sys = system, 64 | entityProps = Props[TestEntity], 65 | dataType = "TestEntity", 66 | extractEntityId = extractEntityId, 67 | extractShardId = extractShardId) 68 | 69 | this.region = Some(region) 70 | 71 | entityIds.map { entityId => 72 | val probe = TestProbe("test") 73 | val msg = Ping(entityId, ActorRef.noSender) 74 | 
probe.send(region, msg) 75 | probe.expectMsgType[Pong] 76 | probe.lastSender.path should be(region.path / s"$entityId" / s"$entityId") 77 | } 78 | } 79 | } 80 | enterBarrier(s"started") 81 | } 82 | 83 | "allocate shards" in { 84 | 85 | joinAndAllocate(first, (1 to 10)) 86 | joinAndAllocate(second, (11 to 20)) 87 | joinAndAllocate(third, (21 to 30)) 88 | joinAndAllocate(fourth, (31 to 40)) 89 | 90 | join(fifth, first) 91 | 92 | runOn(fifth) { 93 | val proxy = ClusterSharding(system).startProxy( 94 | typeName = "TestEntity", 95 | role = None, 96 | dataCenter = None, 97 | extractEntityId = extractEntityId, 98 | extractShardId = extractShardId) 99 | proxyRegion = Some(proxy) 100 | } 101 | 102 | awaitClusterUp(roles: _*) 103 | 104 | enterBarrier("shards-allocated") 105 | 106 | } 107 | 108 | "route by taking into account shard location" in { 109 | within(20.seconds) { 110 | 111 | runOn(first, second, third, fourth) { 112 | region 113 | .map { r => 114 | system.actorOf(Props(new TestRoutee(r)), "routee") 115 | enterBarrier("routee-started") 116 | 117 | router = system.actorOf( 118 | ClusterRouterGroup( 119 | ShardLocationAwareGroup( 120 | routeePaths = Nil, 121 | shardRegion = r, 122 | extractEntityId = extractEntityId, 123 | extractShardId = extractShardId), 124 | ClusterRouterGroupSettings( 125 | totalInstances = 4, 126 | routeesPaths = List("/user/routee"), 127 | allowLocalRoutees = true)).props(), 128 | "sharding-aware-router") 129 | 130 | awaitAssert { 131 | currentRoutees(router).size shouldBe 4 132 | } 133 | enterBarrier("router-started") 134 | 135 | } 136 | .getOrElse { 137 | fail("Region not set") 138 | } 139 | 140 | } 141 | runOn(fifth) { 142 | 143 | // no routees here 144 | enterBarrier("routee-started") 145 | 146 | proxyRegion 147 | .map { proxy => 148 | proxyRouter = system.actorOf( 149 | ClusterRouterGroup( 150 | ShardLocationAwareGroup( 151 | routeePaths = Nil, 152 | shardRegion = proxy, 153 | extractEntityId = extractEntityId, 154 | extractShardId = extractShardId), 155 | ClusterRouterGroupSettings( 156 | totalInstances = 4, 157 | routeesPaths = List("/user/routee"), 158 | allowLocalRoutees = false)).props(), 159 | "sharding-aware-router") 160 | } 161 | .getOrElse { 162 | fail("Proxy region not set") 163 | } 164 | 165 | awaitAssert { 166 | currentRoutees(proxyRouter).size shouldBe 4 167 | } 168 | enterBarrier("router-started") 169 | } 170 | 171 | // now give time to the new shards to be allocated and time to the router to retrieve new information 172 | // one second to send out the update shard information and one second of safety margin 173 | Thread.sleep(2000) 174 | 175 | runOn(fifth) { 176 | val probe = TestProbe("probe") 177 | for (i <- 1 to 40) { 178 | probe.send(proxyRouter, Ping(i, probe.ref)) 179 | } 180 | 181 | val msgs: Seq[Pong] = probe.receiveN(40).collect { case p: Pong => p } 182 | 183 | val (same: Seq[Pong], different) = msgs.partition { 184 | case Pong(_, _, routeeAddress, entityAddress) => 185 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty 186 | } 187 | 188 | different.isEmpty shouldBe true 189 | 190 | val byAddress = same.groupBy(_.routeeAddress) 191 | 192 | byAddress(first).size shouldBe 10 193 | byAddress(first).map(_.id).toSet shouldEqual (1 to 10).toSet 194 | byAddress(second).size shouldBe 10 195 | byAddress(second).map(_.id).toSet shouldEqual (11 to 20).toSet 196 | byAddress(third).size shouldBe 10 197 | byAddress(third).map(_.id).toSet shouldEqual (21 to 30).toSet 198 | byAddress(fourth).size shouldBe 10 199 | 
byAddress(fourth).map(_.id).toSet shouldEqual (31 to 40).toSet 200 | } 201 | enterBarrier("test-done") 202 | 203 | } 204 | 205 | } 206 | 207 | "adjust routing after a topology change" in { 208 | awaitMemberRemoved(fourth) 209 | awaitAllReachable() 210 | runOn(first) { 211 | testConductor.removeNode(fourth) 212 | } 213 | 214 | runOn(fifth) { 215 | // trigger rebalancing the shards of the removed node 216 | val rebalanceProbe = TestProbe("rebalance") 217 | for (i <- 31 to 40) { 218 | rebalanceProbe.send(proxyRouter, Ping(i, rebalanceProbe.ref)) 219 | } 220 | 221 | // we should be receiving messages even in the absence of the updated shard location information 222 | // random routing should kick in, i.e. we won't have perfect matches 223 | val randomRoutedMessages: Seq[Pong] = rebalanceProbe.receiveN(10, 20.seconds).collect { case p: Pong => p } 224 | val (_, differentMsgs) = partitionByAddress(randomRoutedMessages) 225 | differentMsgs.nonEmpty shouldBe true 226 | 227 | // now give time to the new shards to be allocated and time to the router to retrieve new information 228 | // one second to send out the update shard information and one second of safety margin 229 | Thread.sleep(2000) 230 | 231 | val probe = TestProbe("probe") 232 | for (i <- 1 to 40) { 233 | probe.send(proxyRouter, Ping(i, probe.ref)) 234 | } 235 | 236 | val msgs: Seq[Pong] = probe.receiveN(40, 10.seconds).collect { case p: Pong => p } 237 | 238 | val (_, different) = msgs.partition { 239 | case Pong(_, _, routeeAddress, entityAddress) => 240 | routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty 241 | } 242 | 243 | different.isEmpty shouldBe true 244 | 245 | } 246 | 247 | runOn(first, second, third, fifth) { 248 | enterBarrier("finished") 249 | 250 | } 251 | 252 | } 253 | 254 | def currentRoutees(router: ActorRef) = 255 | Await.result(router ? 
GetRoutees, timeout.duration).asInstanceOf[Routees].routees
256 | 
257 |   def partitionByAddress(msgs: Seq[Pong]) = msgs.partition {
258 |     case Pong(_, _, routeeAddress, entityAddress) =>
259 |       routeeAddress.hostPort == entityAddress.hostPort && routeeAddress.hostPort.nonEmpty
260 |   }
261 | 
262 | }
263 | 
--------------------------------------------------------------------------------
/src/test/scala/io/bernhardt/akka/locality/router/ShardLocationAwareRoutingLogicSpec.scala:
--------------------------------------------------------------------------------
1 | package io.bernhardt.akka.locality.router
2 | 
3 | import akka.actor.ActorSystem
4 | import akka.cluster.sharding.ShardRegion
5 | import akka.cluster.sharding.ShardRegion.{ ClusterShardingStats, GetClusterShardingStats, ShardRegionStats }
6 | import akka.routing.ActorRefRoutee
7 | import akka.testkit.{ TestKit, TestProbe }
8 | import io.bernhardt.akka.locality.Locality
9 | import org.scalatest.{ BeforeAndAfterAll, Matchers, WordSpecLike }
10 | 
11 | import scala.collection.immutable.IndexedSeq
12 | 
13 | class ShardLocationAwareRoutingLogicSpec
14 |     extends TestKit(ActorSystem("ShardLocationAwareRoutingLogicSpec"))
15 |     with WordSpecLike
16 |     with Matchers
17 |     with BeforeAndAfterAll {
18 |   import ShardLocationAwareRoutingLogicSpec._
19 | 
20 |   "The ShardLocationAwareRoutingLogic" should {
21 |     "fall back to random routing when no sharding state is available" in {
22 |       Locality(system)
23 | 
24 |       val shardRegion = TestProbe("region")
25 |       val routee1 = TestProbe("routee1")
26 |       val routee2 = TestProbe("routee2")
27 |       val allRoutees = IndexedSeq(routee1, routee2).map(r => ActorRefRoutee(r.ref))
28 | 
29 |       val logic = ShardLocationAwareRoutingLogic(system, shardRegion.ref, extractEntityId, extractShardId)
30 | 
31 |       val runs = 1000
32 |       val selections = for (_ <- 1 to runs) yield {
33 |         logic.select(TestMessage(1), allRoutees)
34 |       }
35 |       val count = selections.count(_ == allRoutees(0))
36 |       val expectedCount = runs * 50 / 100
37 |       val variation = expectedCount / 10
38 | 
39 |       count shouldEqual expectedCount +- variation
40 |     }
41 | 
42 |     "retrieve shard state on startup and use it in order to route messages" in {
43 |       // use separate actor systems in order to simulate running on multiple nodes
44 |       val system1 = ActorSystem("node1")
45 |       val system2 = ActorSystem("node2")
46 |       val system3 = ActorSystem("node3")
47 |       Locality(system1)
48 |       Locality(system2)
49 |       Locality(system3)
50 | 
51 |       val routee1 = TestProbe("routee1")(system1)
52 |       val routee2 = TestProbe("routee2")(system2)
53 |       val routee3 = TestProbe("routee3")(system3)
54 | 
55 |       val allRoutees = IndexedSeq(routee1, routee2, routee3).map(r => ActorRefRoutee(r.ref))
56 | 
57 |       val region1 = TestProbe("region1")(system1)
58 |       val shards1 = for (i <- 1 to 10) yield i
59 |       val region2 = TestProbe("region2")(system2)
60 |       val shards2 = for (i <- 11 to 20) yield i
61 |       val region3 = TestProbe("region3")(system3)
62 |       val shards3 = for (i <- 21 to 30) yield i
63 | 
64 |       val logic = ShardLocationAwareRoutingLogic(system1, region1.ref, extractEntityId, extractShardId)
65 | 
66 |       // the logic requests the sharding state upon creation
67 |       import scala.concurrent.duration._
68 |       region1.expectMsgType[GetClusterShardingStats](5.seconds)
69 |       val monitor1 = region1.sender()
70 | 
71 |       monitor1 ! ClusterShardingStats(regions = Map(region1.ref.path.address -> ShardRegionStats(shards1.map { id =>
72 |         id.toString -> 0
73 |       }.toMap), region2.ref.path.address -> ShardRegionStats(shards2.map { id =>
74 |         id.toString -> 0
75 |       }.toMap), region3.ref.path.address -> ShardRegionStats(shards3.map { id =>
76 |         id.toString -> 0
77 |       }.toMap)))
78 | 
79 |       // TODO find a better way to determine that the router is ready
80 |       Thread.sleep(500)
81 | 
82 |       for (i <- 1 to 10) {
83 |         logic.select(TestMessage(i), allRoutees) shouldEqual ActorRefRoutee(routee1.ref)
84 |       }
85 |       for (i <- 11 to 20) {
86 |         logic.select(TestMessage(i), allRoutees) shouldEqual ActorRefRoutee(routee2.ref)
87 |       }
88 |       for (i <- 21 to 30) {
89 |         logic.select(TestMessage(i), allRoutees) shouldEqual ActorRefRoutee(routee3.ref)
90 |       }
91 | 
92 |       TestKit.shutdownActorSystem(system1)
93 |       TestKit.shutdownActorSystem(system2)
94 |       TestKit.shutdownActorSystem(system3)
95 |     }
96 |   }
97 | 
98 |   override protected def afterAll(): Unit =
99 |     TestKit.shutdownActorSystem(system)
100 | }
101 | 
102 | object ShardLocationAwareRoutingLogicSpec {
103 |   final case class TestMessage(id: Int)
104 | 
105 |   val extractEntityId: ShardRegion.ExtractEntityId = {
106 |     case msg @ TestMessage(id) => (id.toString, msg)
107 |   }
108 | 
109 |   val extractShardId: ShardRegion.ExtractShardId = {
110 |     // deliberately use the simplest possible entity-to-shard mapping here
111 |     case TestMessage(id) => id.toString
112 |   }
113 | }
114 | 
--------------------------------------------------------------------------------
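A closing note on the extractors used throughout these specs: the 1:1 entity-to-shard mapping keeps the assertions deterministic, but location-aware routing only works when the router and the shard region are built from the same `ExtractShardId` definition. Below is a minimal sketch of the more common hash-based extractor; the message type and shard count are assumptions for illustration, not part of this repository:

```scala
import akka.cluster.sharding.ShardRegion

final case class Envelope(id: Long, payload: String)

val numberOfShards = 30 // rule of thumb: about ten times the expected maximum number of nodes

// Pass the very same extractor to ClusterSharding.start and to the
// ShardLocationAwareGroup/Pool so that router and region agree on shard placement.
val hashingExtractShardId: ShardRegion.ExtractShardId = {
  case Envelope(id, _)             => (math.abs(id.hashCode) % numberOfShards).toString
  case ShardRegion.StartEntity(id) => (math.abs(id.hashCode) % numberOfShards).toString
}
```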