├── .gitignore ├── .jvmopts ├── .travis.yml ├── LICENSE ├── README.md ├── build.sbt ├── js └── src │ └── main │ └── scala │ └── com │ └── olegpy │ └── stm │ └── internal │ └── StorePlatform.scala ├── jvm └── src │ ├── main │ └── scala │ │ └── com │ │ └── olegpy │ │ └── stm │ │ └── internal │ │ └── StorePlatform.scala │ └── test │ └── scala │ └── com │ └── olegpy │ └── stm │ └── ConcurrentTests.scala ├── project ├── .keyring.asc.enc ├── build.properties └── plugins.sbt └── shared └── src ├── main └── scala │ └── com │ └── olegpy │ └── stm │ ├── PotentialDeadlockException.scala │ ├── TRef.scala │ ├── UnexpectedRetryInSyncException.scala │ ├── internal │ ├── Monitor.scala │ ├── Retry.scala │ ├── Store.scala │ └── TRefImpl.scala │ ├── misc │ ├── TDeferred.scala │ ├── TMVar.scala │ └── TQueue.scala │ └── package.scala └── test └── scala └── com └── olegpy └── stm ├── APITests.scala ├── BaseIOSuite.scala ├── LawsTests.scala ├── RetryTests.scala ├── RollbackTests.scala ├── StoreTests.scala ├── TRefTests.scala ├── misc ├── TDeferredTests.scala ├── TMVarTests.scala └── TQueueTests.scala ├── problems ├── CigaretteSmokersProblem.scala └── DiningPhilosophersProblem.scala └── results.scala /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | .idea/ -------------------------------------------------------------------------------- /.jvmopts: -------------------------------------------------------------------------------- 1 | -Xms1G 2 | -Xmx3G 3 | -XX:+UseG1GC 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - 2.11.12 4 | - 2.12.8 5 | jdk: 6 | - openjdk8 7 | sudo: required 8 | dist: xenial 9 | cache: 10 | directories: 11 | - "$HOME/.ivy2" 12 | - "$HOME/.coursier" 13 | - "$HOME/.sbt" 14 | script: 15 | - sbt ++$TRAVIS_SCALA_VERSION coverage test coverageReport 16 | 
- sbt coverageAggregate 17 | after_success: 18 | - sbt coveralls 19 | - test $PUBLISH == "true" && test $TRAVIS_PULL_REQUEST == "false" && test $TRAVIS_BRANCH 20 | == "master" && script -c 'gpg -a --import project/.keyring.asc' > /dev/null && sbt +publishSigned 21 | before_install: 22 | - test $TRAVIS_PULL_REQUEST == "false" && openssl aes-256-cbc -K $encrypted_8d7abf919097_key -iv $encrypted_8d7abf919097_iv 23 | -in project/.keyring.asc.enc -out project/.keyring.asc -d 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Oleg Pyzhcov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stm4cats 2 | ![Maven Central](https://img.shields.io/maven-central/v/com.olegpy/stm4cats_2.12.svg?color=06C) 3 | [![Build Status](https://travis-ci.org/oleg-py/stm4cats.svg?branch=master)](https://travis-ci.org/oleg-py/stm4cats) 4 | [![Coverage Status](https://coveralls.io/repos/github/oleg-py/stm4cats/badge.svg?branch=master)](https://coveralls.io/github/oleg-py/stm4cats?branch=master) 5 | 6 | 7 | An implementation of STM for any cats-effect compatible effect type. 8 | 9 | Current stable version is `0.1.0-M1`, available for Scala 2.11 and 2.12 and Scala.JS 0.6: 10 | ```scala 11 | // Use %%% for Scala.JS 12 | libraryDependencies += "com.olegpy" %% "stm4cats" % "0.1.0-M1" 13 | ``` 14 | 15 | Or, if you're feeling adventurous, a snapshot is built from `master` on each commit. 16 | ```scala 17 | resolvers += Resolver.sonatypeRepo("snapshots") 18 | libraryDependencies += "com.olegpy" %% "stm4cats" % "0.1.0-SNAPSHOT" 19 | ``` 20 | 21 | ### Try it 22 | ```scala 23 | import cats.implicits._ 24 | import cats.effect.IO 25 | import com.olegpy.stm._ 26 | import scala.concurrent.ExecutionContext.global 27 | import scala.concurrent.duration._ 28 | 29 | implicit val cs = IO.contextShift(global) 30 | implicit val timer = IO.timer(global) 31 | 32 | 33 | def transfer(fromR: TRef[Int], toR: TRef[Int], amt: Int): STM[Unit] = 34 | for { 35 | from <- fromR.get 36 | if from >= amt // Or STM.check(from >= amt) 37 | _ <- fromR.update(_ - amt) 38 | _ <- toR.update(_ + amt) 39 | } yield () 40 | 41 | def freeMoney(toR: TRef[Int]): IO[Unit] = STM.atomically[IO] { 42 | toR.update(_ + 10) 43 | } 44 | 45 | val io = for { 46 | fromR <- TRef(0).commit[IO] 47 | // Or shorter syntax: 48 | toR <- TRef.in[IO](0) 49 | amt = 100 50 | 51 | f1 <- transfer(fromR, toR, amt).commit[IO].start 52 | f2 <- (freeMoney(fromR) >>
IO.sleep(1.second)).foreverM.start 53 | // In 10 seconds, the transfer succeeds 54 | _ <- f1.join 55 | _ <- f2.cancel 56 | res <- toR.get.commit[IO] 57 | _ <- IO(assert(res == amt)) 58 | } yield () 59 | 60 | io.unsafeRunSync() // Well, not on JS 61 | ``` 62 | 63 | ### Acknowledgements 64 | My interest in STM, as well as some of the API in stm4cats, was influenced by: 65 | - [Talk](https://www.youtube.com/watch?v=d6WWmia0BPM) by @jdegoes and @wi101 on STM in ZIO 66 | - An [alternative implementation](https://github.com/TimWSpence/cats-stm) by @TimWSpence 67 | - And last, but not least, 68 | [Beautiful concurrency paper](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/beautiful.pdf) 69 | by Simon Peyton Jones 70 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import xerial.sbt.Sonatype._ 2 | import sbtcrossproject.CrossPlugin.autoImport.{crossProject, CrossType} 3 | 4 | inThisBuild(Seq( 5 | organization := "com.olegpy", 6 | scalaVersion := "2.12.8", 7 | version := "0.1.0-SNAPSHOT", 8 | crossScalaVersions := Seq("2.11.12", "2.12.8"), 9 | pgpPassphrase := sys.env.get("PGP_PASS").map(_.toArray), 10 | licenses += ("MIT", url("http://opensource.org/licenses/MIT")) 11 | )) 12 | 13 | lazy val root = project.in(file(".")) 14 | .aggregate(stm4cats.js, stm4cats.jvm) 15 | .settings( 16 | name := "stm4cats", 17 | publish := {}, 18 | publishLocal := {}, 19 | publishArtifact := false, 20 | publishTo := sonatypePublishTo.value, 21 | ) 22 | 23 | lazy val stm4cats = crossProject(JSPlatform, JVMPlatform) 24 | .crossType(CrossType.Full) 25 | .in(file(".")) 26 | .settings( 27 | name := "stm4cats", 28 | fork in test := true, 29 | libraryDependencies ++= Seq( 30 | "org.typelevel" %%% "cats-effect" % "1.2.0", 31 | "com.lihaoyi" %%% "utest" % "0.6.7" % Test, 32 | "org.typelevel" %%% "cats-laws" % "1.5.0" % Test, 33 | "org.typelevel" %%%
"cats-effect-laws" % "1.2.0" % Test, 34 | ), 35 | 36 | testFrameworks += new TestFramework("utest.runner.Framework"), 37 | 38 | scalacOptions --= Seq( 39 | "-Xfatal-warnings", 40 | "-Ywarn-unused:params", 41 | "-Ywarn-unused:implicits", 42 | ), 43 | publishTo := sonatypePublishTo.value, 44 | publishMavenStyle := true, 45 | sonatypeProjectHosting := 46 | Some(GitHubHosting("oleg-py", "stm4cats", "oleg.pyzhcov@gmail.com")), 47 | ) 48 | -------------------------------------------------------------------------------- /js/src/main/scala/com/olegpy/stm/internal/StorePlatform.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.internal 2 | 3 | import scala.collection.mutable 4 | import scala.scalajs.js 5 | import scala.util.control.NonFatal 6 | 7 | 8 | private[stm] trait StorePlatform { 9 | def forPlatform(): Store = new Store { 10 | import scala.scalajs.js.DynamicImplicits._ 11 | @inline implicit def any2JSAny(x: Any): js.Any = x.asInstanceOf[js.Any] 12 | private[this] val committed = 13 | js.Dynamic.newInstance(js.Dynamic.global.WeakMap)() 14 | 15 | class Journal( 16 | val uncommitted : js.Dynamic = js.Dynamic.newInstance(js.Dynamic.global.Map)(), 17 | val readKeys: mutable.AnyRefMap[AnyRef, Long] = mutable.AnyRefMap.empty 18 | ) extends Store.Journal { 19 | 20 | def writtenKeys: mutable.AnyRefMap[AnyRef, Long] = { 21 | val map = mutable.AnyRefMap.empty[AnyRef, Long] 22 | val it = uncommitted.keys().asInstanceOf[js.Iterable[AnyRef]].jsIterator() 23 | var entry = it.next() 24 | while (!entry.done) { 25 | map.update(entry.value, version) 26 | entry = it.next() 27 | } 28 | map 29 | } 30 | 31 | 32 | def read(k: AnyRef): Any = { 33 | if (uncommitted.has(k)) uncommitted.get(k) 34 | else { 35 | readKeys.update(k, version) 36 | committed.get(k) 37 | } 38 | } 39 | 40 | def update(k: AnyRef, v: Any): Unit = { 41 | uncommitted.set(k, v) 42 | () 43 | } 44 | 45 | def copy(): Journal = new Journal( 46 | 
js.Dynamic.newInstance(js.Dynamic.global.Map)(uncommitted), 47 | readKeys 48 | ) 49 | } 50 | 51 | private[this] var version = 0L 52 | private[this] var theLog: Journal = _ 53 | def current(): Store.Journal = theLog 54 | def transact[A](f: => A): A = { // theLog is cleared in `finally`, so a run of `f` that throws (including Retry) never leaves a dead journal visible through current() 55 | version += 1 56 | theLog = new Journal 57 | try { val result = f 58 | val it = theLog.uncommitted.asInstanceOf[js.Iterable[js.Array[js.Any]]].jsIterator() 59 | var entry = it.next() 60 | while (!entry.done) { 61 | val arr = entry.value 62 | committed.set(arr(0), arr(1)) 63 | entry = it.next() 64 | } 65 | result 66 | } finally { theLog = null } 67 | } 68 | 69 | def attempt[A](f: => A): A = { 70 | val j = theLog 71 | try { 72 | theLog = j.copy() 73 | f 74 | } catch { case NonFatal(ex) => 75 | theLog = j 76 | throw ex 77 | } 78 | } 79 | 80 | def unsafeReadCommitted(k: AnyRef): Any = committed.get(k) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /jvm/src/main/scala/com/olegpy/stm/internal/StorePlatform.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.internal 2 | 3 | import scala.annotation.tailrec 4 | import scala.collection.JavaConverters._ 5 | import scala.collection.mutable 6 | import scala.util.control.NonFatal 7 | 8 | import java.{util => ju} 9 | import java.util.concurrent.atomic.{AtomicLong, AtomicReference} 10 | 11 | 12 | private[stm] trait StorePlatform { 13 | def forPlatform(): Store = new Store { 14 | private[this] val committed = 15 | new AtomicReference(new ju.WeakHashMap[AnyRef, (Any, Long)]()) 16 | private[this] val mkId = new AtomicLong() 17 | private[this] val journal = new ThreadLocal[Journal] 18 | 19 | class Journal( 20 | start: ju.WeakHashMap[AnyRef, (Any, Long)], 21 | val id: Long = mkId.getAndIncrement(), 22 | val uncommitted: mutable.AnyRefMap[AnyRef, (Any, Long)] = mutable.AnyRefMap.empty, 23 | val reads: mutable.AnyRefMap[AnyRef, Long] = mutable.AnyRefMap.empty 24 | ) extends Store.Journal {
25 | def writtenKeys: collection.Map[AnyRef, Long] = uncommitted.mapValues(_._2) // key -> id of the journal that wrote it 26 | def readKeys: mutable.AnyRefMap[AnyRef, Long] = reads 27 | 28 | def read(k: AnyRef): Any = { // read-your-writes first; otherwise read the snapshot and record the version seen (Long.MinValue when absent) for conflict checks 29 | if (uncommitted contains k) uncommitted(k)._1 30 | else { 31 | start.get(k) match { 32 | case null => 33 | reads.update(k, Long.MinValue) 34 | null 35 | case (value, version) => 36 | reads.update(k, version) 37 | value 38 | } 39 | } 40 | } 41 | 42 | def update(k: AnyRef, v: Any): Unit = { 43 | uncommitted.update(k, (v, id)) 44 | } 45 | 46 | def copy() = // copies the write set (so attempt() can restore it) but shares `reads`, so reads made in a rolled-back branch stay tracked 47 | new Journal(start, id, uncommitted ++ Map(), reads) 48 | } 49 | 50 | final def current(): Journal = journal.get() 51 | 52 | final def transact[A](f: => A): A = { // optimistic: run f against a snapshot of `committed`, then CAS in snapshot + writes; re-run f when a key it read was changed by another commit 53 | @tailrec def reevaluate(): A = { 54 | val start = committed.get() 55 | journal.set(new Journal(start)) 56 | val result = f 57 | @tailrec def tryConsolidate(): Boolean = { 58 | val preCommit = committed.get() 59 | var hasConflict = start ne preCommit // false only if no commit happened since the snapshot was taken 60 | val j = journal.get() 61 | if (hasConflict) { 62 | hasConflict = false 63 | val ksi = j.reads.keysIterator 64 | while (ksi.hasNext && !hasConflict) { 65 | val key = ksi.next() 66 | hasConflict = start.get(key) ne preCommit.get(key) // reference check: every write stores a fresh (value, id) tuple 67 | } 68 | } 69 | if (hasConflict) { 70 | // This might not be hit in a single test run, avoid fluctuating coverage 71 | // $COVERAGE-OFF$ 72 | false 73 | // $COVERAGE-ON$ 74 | } else { 75 | val end = new ju.WeakHashMap[AnyRef, (Any, Long)](preCommit) // NOTE(review): WeakHashMap is unsynchronized and expunges stale entries even on get(); these snapshots are read from several threads - confirm this is safe 76 | end.putAll(j.uncommitted.asJava) 77 | committed.compareAndSet(preCommit, end) || tryConsolidate() // CAS lost to another commit: re-check against the newer map 78 | } 79 | } 80 | if (tryConsolidate()) { 81 | result 82 | } else { 83 | // Same as above 84 | // $COVERAGE-OFF$ 85 | reevaluate() 86 | // $COVERAGE-ON$ 87 | } 88 | } 89 | try { 90 | reevaluate() 91 | } finally { 92 | journal.remove() 93 | } 94 | } 95 | 96 | def attempt[A](f: => A): A = { // runs f on a copy of the journal; on a non-fatal exception the previous journal is reinstated and the exception rethrown 97 | val j = current() 98 | try { 99 | journal.set(j.copy()) 100 | f 101 | } catch { case NonFatal(ex) => 102 | journal.set(j) 103 | throw ex 104 | } 105
| } 106 | 107 | def unsafeReadCommitted(k: AnyRef): Any = committed.get().get(k) match { 108 | case null => 109 | // $COVERAGE-OFF$ 110 | null 111 | // $COVERAGE-ON$ 112 | case t => t._1 113 | } 114 | } 115 | } 116 | -------------------------------------------------------------------------------- /jvm/src/test/scala/com/olegpy/stm/ConcurrentTests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import scala.concurrent.ExecutionContext 4 | 5 | import cats.effect.{ContextShift, IO} 6 | import utest._ 7 | import cats.implicits._ 8 | 9 | import java.util.concurrent.Executors 10 | import java.util.concurrent.atomic.AtomicInteger 11 | 12 | 13 | object ConcurrentTests extends NondetIOSuite { 14 | val tests = Tests { 15 | "concurrent transactions can complete w/o reevaluation" - disabled { 16 | val mkEc = IO { 17 | ExecutionContext.fromExecutor(Executors.newSingleThreadExecutor()) 18 | } 19 | val attempts = new AtomicInteger() 20 | def test(ref: TRef[Int], ec: ExecutionContext)(implicit 21 | cs: ContextShift[IO]) = cs.evalOn(ec) { 22 | ref.update(_ + 1).map { _ => 23 | Thread.sleep(100) 24 | attempts.incrementAndGet() 25 | }.commit[IO] 26 | }.start 27 | 28 | val ecs = 3 29 | 30 | for { 31 | refs <- TRef.in[IO](0).product(mkEc).replicateA(ecs) 32 | fibers <- refs.traverse(Function.tupled(test(_, _))) 33 | _ <- fibers.sequence.join 34 | } yield assert(attempts.get() == ecs) 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /project/.keyring.asc.enc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oleg-py/stm4cats/dd8686d01b1ad4ab8e414de1ec1ed17b1fbfa1ca/project/.keyring.asc.enc -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 
1.2.8 -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("io.github.davidgregory084" % "sbt-tpolecat" % "0.1.3") 2 | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.26") 3 | addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "0.6.0") 4 | addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "2.5") 5 | addSbtPlugin("com.jsuereth" % "sbt-pgp" % "1.1.2") 6 | addSbtPlugin("com.timushev.sbt" % "sbt-updates" % "0.4.0") 7 | addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.1") 8 | addSbtPlugin("org.scoverage" % "sbt-coveralls" % "1.2.7") 9 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/PotentialDeadlockException.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | class PotentialDeadlockException extends RuntimeException( 4 | "Potential STM deadlock: `retry` is used before any TRef was read" 5 | ) 6 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/TRef.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import cats.InvariantMonoidal 4 | import cats.data.State 5 | import cats.effect.Sync 6 | import internal.TRefImpl 7 | import cats.implicits._ 8 | import cats.effect.concurrent.Ref 9 | 10 | trait TRef[A] extends Ref[STM, A] { 11 | def get: STM[A] 12 | def set(a: A): STM[Unit] 13 | def update(f: A => A): STM[Unit] = get >>= (f >>> set) 14 | 15 | def updateF(f: A => STM[A]): STM[Unit] = get >>= f >>= set 16 | 17 | def updOrRetry(f: PartialFunction[A, A]): STM[Unit] = 18 | get.collect(f) >>= set 19 | 20 | def getAndSet(a: A): STM[A] = get <* set(a) 21 | 22 | def access: STM[(A, A => STM[Boolean])] = get.tupleRight(set(_).as(true)) 
23 | 24 | def tryUpdate(f: A => A): STM[Boolean] = update(f).as(true) 25 | 26 | def tryModify[B](f: A => (A, B)): STM[Option[B]] = modify(f).map(_.some) 27 | 28 | def modify[B](f: A => (A, B)): STM[B] = get.map(f).flatMap { case (a, b) => set(a) as b } 29 | 30 | def tryModifyState[B](state: State[A, B]): STM[Option[B]] = modifyState(state).map(_.some) 31 | 32 | def modifyState[B](state: State[A, B]): STM[B] = modify(state.run(_).value) 33 | 34 | override def toString: String = s"TRef($unsafeLastValue)" 35 | 36 | def unsafeLastValue(): A 37 | } 38 | 39 | object TRef { 40 | def apply[A](initial: A): STM[TRef[A]] = STM.delay(new TRefImpl(initial)) 41 | def in[F[_]]: InPartiallyApplied[F] = new InPartiallyApplied[F] 42 | 43 | final class InPartiallyApplied[F[_]](private val dummy: Boolean = false) extends AnyVal { 44 | def apply[A](initial: A)(implicit F: Sync[F]): F[TRef[A]] = 45 | STM.tryCommitSync(TRef(initial)) 46 | } 47 | 48 | implicit val invariantMonoidal: InvariantMonoidal[TRef] = new InvariantMonoidal[TRef] { 49 | val unit: TRef[Unit] = new TRef[Unit] { 50 | def get: STM[Unit] = STM.unit 51 | def set(a: Unit): STM[Unit] = STM.unit 52 | 53 | def unsafeLastValue(): Unit = () 54 | } 55 | 56 | def imap[A, B](fa: TRef[A])(f: A => B)(g: B => A): TRef[B] = new TRef[B] { 57 | def get: STM[B] = fa.get map f 58 | def set(a: B): STM[Unit] = fa.set(g(a)) 59 | 60 | def unsafeLastValue(): B = f(fa.unsafeLastValue) 61 | } 62 | 63 | def product[A, B](fa: TRef[A], fb: TRef[B]): TRef[(A, B)] = new TRef[(A, B)] { 64 | def get: STM[(A, B)] = fa.get product fb.get 65 | def set(a: (A, B)): STM[Unit] = fa.set(a._1) *> fb.set(a._2) 66 | 67 | def unsafeLastValue(): (A, B) = (fa.unsafeLastValue, fb.unsafeLastValue) 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/UnexpectedRetryInSyncException.scala: -------------------------------------------------------------------------------- 1 | package 
com.olegpy.stm 2 | 3 | class UnexpectedRetryInSyncException extends RuntimeException( 4 | "Attempt to use STM.retry with STM.unsafeCommitSync" 5 | ) 6 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/internal/Monitor.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.internal 2 | 3 | import scala.collection.immutable.Queue 4 | 5 | import cats.effect.implicits._ 6 | import cats.effect.Concurrent 7 | import cats.implicits._ 8 | 9 | 10 | private[stm] class Monitor { 11 | type Callback = Either[Throwable, Unit] => Unit 12 | private[this] val store: Store = /*_*/Store.forPlatform()/*_*/ 13 | private[this] val rightUnit = Right(()) 14 | 15 | case class PendingUpdate(lastNotifyVersion: Long, cbs: Queue[Callback] = Queue.empty) 16 | private[this] val emptyPU = PendingUpdate(Long.MinValue) 17 | private[this] def read(k: AnyRef): PendingUpdate = store.current().read(k) match { 18 | case p: PendingUpdate => p 19 | case _ => emptyPU 20 | } 21 | 22 | private[this] def register(k: AnyRef, cb: Callback): Unit = { 23 | val PendingUpdate(lnv, cbs) = read(k) 24 | store.current().update(k, PendingUpdate(lnv, cbs enqueue cb)) 25 | } 26 | 27 | private[this] def unsub(cb: Callback): Unit = { 28 | val j = store.current() 29 | val keys = j.read(cb).asInstanceOf[collection.Set[AnyRef @unchecked]] 30 | j.update(cb, null) // TODO - wipe? 
31 | val kit = 32 | // $COVERAGE-OFF$ 33 | if (keys eq null) Iterator.empty else keys.iterator 34 | // $COVERAGE-ON$ 35 | while (kit.hasNext) { 36 | val k = kit.next() 37 | val PendingUpdate(lnv, cbs) = read(k) // read() yields an empty PendingUpdate rather than a MatchError for a key holding no PendingUpdate 38 | j.update(k, PendingUpdate(lnv, cbs.diff(List(cb)))) 39 | } 40 | } 41 | // Suspends until some key in `lastSeen` is notified with a version newer than the one recorded there; completes immediately if that already happened, and unregisters `cb` on cancellation. 42 | def waitOn[F[_]](lastSeen: collection.Map[AnyRef, Long])(implicit F: Concurrent[F]): F[Unit] = 43 | F.cancelable { cb => 44 | store.transact { 45 | var abort = false 46 | val it = lastSeen.iterator 47 | while (it.hasNext && !abort) { 48 | val (k, ver) = it.next() 49 | if (ver < read(k).lastNotifyVersion) abort = true 50 | } 51 | if (abort) () => { 52 | cb(rightUnit); F.unit 53 | } 54 | else { 55 | store.current().update(cb, lastSeen.keySet) 56 | val it = lastSeen.iterator 57 | while (it.hasNext) register(it.next()._1, cb) 58 | () => F.delay { 59 | store.transact { unsub(cb) } 60 | } 61 | } 62 | }.apply() 63 | } 64 | // Wakes each callback waiting on a key whose version in `versions` is newer than the last notified one; the callbacks are invoked on a separately started fiber. 65 | def notifyOn[F[_]](versions: collection.Map[AnyRef, Long])(implicit F: Concurrent[F]): F[Unit] = 66 | F.suspend { 67 | store.transact { 68 | val qb = Queue.newBuilder[Callback] 69 | val it = versions.iterator 70 | val j = store.current() 71 | while (it.hasNext) { 72 | val (k, ver) = it.next() 73 | val PendingUpdate(v2, cbs) = read(k) 74 | if (ver > v2) { 75 | qb ++= cbs 76 | j.update(k, PendingUpdate(ver)) 77 | } 78 | } 79 | val callbacks = qb.result().distinct // a waiter registered on several written keys is queued once per key; wake and unsubscribe it only once 80 | val qit = callbacks.iterator 81 | while (qit.hasNext) { 82 | unsub(qit.next()) 83 | } 84 | 85 | if (callbacks.isEmpty) F.unit 86 | else F.delay(callbacks.foreach(_(rightUnit))).start.void 87 | } 88 | } 89 | } 90 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/internal/Retry.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.internal 2 | 3 | import scala.util.control.NoStackTrace 4 | 5 | 6 | private[stm] case object Retry extends Throwable with NoStackTrace 7 |
-------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/internal/Store.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.internal 2 | 3 | 4 | private[stm] trait Store { 5 | def current(): Store.Journal 6 | def transact[A](f: => A): A 7 | def attempt[A](f: => A): A 8 | 9 | def unsafeReadCommitted(k: AnyRef): Any 10 | } 11 | 12 | private[stm] object Store extends StorePlatform { 13 | trait Journal { 14 | def writtenKeys: collection.Map[AnyRef, Long] 15 | def readKeys: collection.Map[AnyRef, Long] 16 | def read(k: AnyRef): Any 17 | def update(k: AnyRef, v: Any): Unit 18 | } 19 | } -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/internal/TRefImpl.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.internal 2 | 3 | import com.olegpy.stm.{STM, TRef} 4 | 5 | 6 | private[stm] class TRefImpl[A](initial: A) extends TRef[A] { 7 | STM.store.current().update(this, initial) 8 | def get: STM[A] = STM.delay { 9 | STM.store.current().read(this).asInstanceOf[A] 10 | } 11 | def set(a: A): STM[Unit] = STM.delay { 12 | STM.store.current().update(this, a) 13 | } 14 | 15 | def unsafeLastValue(): A = 16 | STM.store.unsafeReadCommitted(this).asInstanceOf[A] 17 | } 18 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/misc/TDeferred.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.misc 2 | 3 | import cats.{Functor, Invariant} 4 | import cats.effect.{Concurrent, Sync} 5 | import cats.effect.concurrent.TryableDeferred 6 | import com.olegpy.stm.{STM, TRef} 7 | import cats.implicits._ 8 | 9 | class TDeferred[A] (private[stm] val state: TRef[Option[A]]) extends TryableDeferred[STM, A] { outer => 10 | def 
tryGet: STM[Option[A]] = state.get 11 | def get: STM[A] = tryGet.unNone 12 | def complete(a: A): STM[Unit] = state.updateF { 13 | case Some(_) => STM.abort(new IllegalStateException("Attempting to complete deferred twice")) 14 | case None => STM.pure(Some(a)) 15 | } 16 | 17 | // N.B: cannot use this.mapK as that doesn't return TryableDeferred 18 | def in[F[_]: Concurrent]: TryableDeferred[F, A] = new TryableDeferred[F, A] { 19 | def tryGet: F[Option[A]] = outer.tryGet.commit[F] 20 | def get: F[A] = outer.get.commit[F] 21 | def complete(a: A): F[Unit] = outer.complete(a).commit[F] 22 | } 23 | // Debug rendering from the last committed state (unsafeLastValue), not a transactional read 24 | override def toString: String = state.unsafeLastValue match { 25 | case Some(value) => s"TDeferred(complete: $value)" 26 | case None => "TDeferred(not completed)" 27 | } 28 | } 29 | 30 | object TDeferred { 31 | def apply[A]: STM[TDeferred[A]] = TRef(Option.empty[A]).map(new TDeferred(_)) 32 | def in[F[_]: Sync, A]: F[TDeferred[A]] = STM.tryCommitSync(apply) 33 | 34 | implicit val invariant: Invariant[TDeferred] = new Invariant[TDeferred] { 35 | def imap[A, B](fa: TDeferred[A])(f: A => B)(g: B => A): TDeferred[B] = { 36 | val fo = Functor[Option] 37 | new TDeferred[B](fa.state.imap(fo.lift(f))(fo.lift(g))) 38 | } 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/misc/TMVar.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.misc 2 | 3 | import cats.{Functor, Invariant} 4 | import cats.effect.{Concurrent, Sync} 5 | import cats.effect.concurrent.MVar 6 | import com.olegpy.stm.{STM, TRef} 7 | import cats.implicits._ 8 | 9 | class TMVar[A] (private[stm] val state: TRef[Option[A]]) extends MVar[STM, A] { 10 | def isEmpty: STM[Boolean] = state.get.map(_.isEmpty) 11 | def put(a: A): STM[Unit] = state.updOrRetry { case None => a.some } 12 | def tryPut(a: A): STM[Boolean] = put(a).as(true) orElse STM.pure(false) 13 | def take: STM[A] = tryTake.unNone 14 | def
tryTake: STM[Option[A]] = isEmpty.ifM(STM.pure(none[A]), state.getAndSet(None)) 15 | def read: STM[A] = state.get.unNone 16 | 17 | def to[F[_]: Concurrent]: MVar[F, A] = mapK(STM.atomicallyK[F]) 18 | // Debug rendering from the last committed state (unsafeLastValue), not a transactional read 19 | override def toString: String = state.unsafeLastValue match { 20 | case Some(value) => s"TMVar(full: $value)" 21 | case None => "TMVar(empty)" 22 | } 23 | } 24 | 25 | object TMVar { 26 | def empty[A]: STM[TMVar[A]] = TRef(none[A]).map(new TMVar(_)) 27 | def apply[A](initial: A): STM[TMVar[A]] = TRef(initial.some).map(new TMVar(_)) 28 | 29 | def in[F[_]: Sync, A](initial: A): F[TMVar[A]] = STM.tryCommitSync(TMVar(initial)) 30 | def emptyIn[F[_]: Sync, A]: F[TMVar[A]] = STM.tryCommitSync(TMVar.empty) 31 | 32 | implicit val invariantInstance: Invariant[TMVar] = new Invariant[TMVar] { 33 | def imap[A, B](fa: TMVar[A])(f: A => B)(g: B => A): TMVar[B] = { 34 | val fo = Functor[Option] 35 | new TMVar[B](fa.state.imap(fo.lift(f))(fo.lift(g))) 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/misc/TQueue.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.misc 2 | 3 | import scala.collection.immutable.Queue 4 | 5 | import cats.{Foldable, Functor, Invariant} 6 | import cats.effect.Sync 7 | import com.olegpy.stm.{STM, TRef} 8 | import cats.syntax.all._ 9 | import cats.instances.option._ 10 | 11 | trait TQueue[A] { 12 | def offer(a: A): STM[Boolean] 13 | def tryPeek: STM[Option[A]] 14 | protected def drop1 : STM[Unit] 15 | 16 | def tryDequeue: STM[Option[A]] = tryPeek.flatTap { 17 | case Some(_) => drop1 18 | case _ => STM.unit 19 | } 20 | 21 | def isEmpty: STM[Boolean] = tryPeek.map(_.isEmpty) 22 | 23 | def enqueue(a: A): STM[Unit] = offer(a).flatMap(STM.check) 24 | def enqueueAll[F[_]: Foldable](fa: F[A]): STM[Unit] = fa.traverse_(enqueue) 25 | 26 | def peek: STM[A] = tryPeek.unNone 27 | def dequeue: STM[A] = tryDequeue.unNone 28 | def
dequeueUpTo(n: Int): STM[List[A]] = STM.suspend { 29 | val b = List.newBuilder[A] 30 | def loop(n: Int): STM[List[A]] = 31 | if (n > 0) (dequeue.map(b += _) >> loop(n - 1)) orElse STM.pure(b.result()) 32 | else STM.pure(b.result()) 33 | 34 | if (n < 0) STM.abort(new IllegalArgumentException(s"Cannot dequeue $n elements")) 35 | else loop(n) 36 | } 37 | 38 | protected def debugValues: Seq[A] 39 | protected def debugType: String 40 | 41 | override def toString: String = s"TQueue($debugType)(${debugValues.mkString(",")})" 42 | } 43 | 44 | object TQueue { 45 | def synchronous[A]: STM[TQueue[A]] = TRef(none[A]).map { slot => 46 | new TQueue[A] { 47 | def offer(a: A): STM[Boolean] = slot.get 48 | .flatMap(_.fold(slot.set(a.some).as(true))(_ => STM.pure(false))) 49 | def tryPeek: STM[Option[A]] = slot.get 50 | 51 | protected def drop1: STM[Unit] = slot.set(None) 52 | 53 | protected def debugValues: Seq[A] = slot.unsafeLastValue().toSeq 54 | protected def debugType: String = "synchronous" 55 | } 56 | } 57 | 58 | def synchronousIn[F[_]: Sync, A]: F[TQueue[A]] = STM.tryCommitSync(synchronous) 59 | 60 | def unbounded[A]: STM[TQueue[A]] = TRef(Queue.empty[A]).map { state => 61 | new TQueue[A] { 62 | def offer(a: A): STM[Boolean] = state.update(_.enqueue(a)).as(true) 63 | def tryPeek: STM[Option[A]] = state.get.map(_.headOption) 64 | protected def drop1: STM[Unit] = state.update(_.drop(1)) 65 | 66 | protected def debugValues: Seq[A] = state.unsafeLastValue() 67 | protected def debugType: String = "unbounded" 68 | } 69 | } 70 | 71 | def unboundedIn[F[_]: Sync, A]: F[TQueue[A]] = STM.tryCommitSync(unbounded) 72 | 73 | def bounded[A](max: Int): STM[TQueue[A]] = TRef(Vector.empty[A]).map { state => 74 | new TQueue[A] { 75 | def offer(a: A): STM[Boolean] = state.modify { 76 | case vec if vec.length < max => (vec :+ a, true) 77 | case vec => (vec, false) 78 | } 79 | 80 | def tryPeek: STM[Option[A]] = state.get.map(_.headOption) 81 | protected def drop1: STM[Unit] = 
state.update(_.drop(1)) 82 | 83 | protected def debugValues: Seq[A] = state.unsafeLastValue() 84 | protected def debugType: String = s"bounded($max)" 85 | } 86 | } 87 | 88 | def boundedIn[F[_]: Sync, A](max: Int): F[TQueue[A]] = STM.tryCommitSync(bounded(max)) 89 | 90 | def circularBuffer[A](max: Int): STM[TQueue[A]] = TRef(Vector.empty[A]).map { state => 91 | new TQueue[A] { 92 | def offer(a: A): STM[Boolean] = state.update { 93 | case v if v.length < max => v :+ a 94 | case v => v.drop(1) :+ a 95 | } as true 96 | def tryPeek: STM[Option[A]] = state.get.map(_.headOption) 97 | protected def drop1: STM[Unit] = state.update(_.drop(1)) 98 | 99 | protected def debugValues: Seq[A] = state.unsafeLastValue() 100 | protected def debugType: String = s"circularBuffer($max)" 101 | 102 | override def toString: String = 103 | s"TQueue(circularBuffer($max))(${state.unsafeLastValue().mkString(", ")})" 104 | } 105 | } 106 | 107 | def circularBufferIn[F[_]: Sync, A](max: Int): F[TQueue[A]] = 108 | STM.tryCommitSync(circularBuffer(max)) 109 | 110 | implicit val invariant: Invariant[TQueue] = new Invariant[TQueue] { 111 | def imap[A, B](fa: TQueue[A])(f: A => B)(g: B => A): TQueue[B] = new TQueue[B] { 112 | private[this] val liftedF = Functor[Option].lift(f) 113 | def offer(a: B): STM[Boolean] = fa.offer(g(a)) 114 | def tryPeek: STM[Option[B]] = fa.tryPeek.map(liftedF) 115 | 116 | protected def drop1: STM[Unit] = fa.drop1 117 | 118 | protected def debugValues: Seq[B] = fa.debugValues.map(f) 119 | protected def debugType: String = fa.debugType 120 | } 121 | } 122 | } -------------------------------------------------------------------------------- /shared/src/main/scala/com/olegpy/stm/package.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy 2 | 3 | import scala.language.implicitConversions 4 | 5 | import cats.effect.{Concurrent, IO, Sync} 6 | import cats.{Defer, FunctorFilter, Monad, Monoid, MonoidK, StackSafeMonad, ~>} 7 | 
import cats.implicits._ 8 | import com.olegpy.stm.internal.{Monitor, Retry, Store} 9 | 10 | /** Root of the STM API. `STM[A]` is a newtype over cats-effect `IO`: values are converted with the `wrap`/`expose` casts defined below, and several type-class instances are IO's own, cast to STM. */ package object stm { 11 | type STM[+A] = STM.Of[A] 12 | 13 | object STM { 14 | /* Newtype encoding: `Of[A]` is an abstract subtype of `Base with Tag`, so STM is a type distinct from IO even though its runtime values are IO actions. */ type Base = Any { type STM$newtype$ } 15 | trait Tag extends Any 16 | type Of[+A] <: Base with Tag 17 | 18 | def pure[A](a: A): STM[A] = wrap(IO.pure(a)) 19 | def suspend[A](stm: => STM[A]): STM[A] = wrap(IO.suspend(expose[A](stm))) 20 | 21 | val unit: STM[Unit] = wrap(IO.unit) 22 | /* Throws the Retry marker. A committing transaction then waits on the keys it has read and reruns, or fails with PotentialDeadlockException if it read nothing (see atomicallyImpl). Inside `orElse`, the alternative branch runs instead. */ val retry: STM[Nothing] = delay { throw Retry } 23 | /* Retries unless `c` holds. */ def check(c: Boolean): STM[Unit] = retry.unlessA(c) 24 | /* Fails the transaction with `ex`; unlike retry, this is not intercepted by `orElse`. */ def abort[A](ex: Throwable): STM[A] = wrap(IO.raiseError(ex)) 25 | 26 | /* Partially applied so callers only spell out F, as in STM.atomically[IO](stm). */ def atomically[F[_]] = new AtomicallyFn[F] 27 | 28 | /** Runs `stm` once inside `store.transact`, synchronously, in `F`. It never waits: if the transaction retries, the returned action fails with UnexpectedRetryInSyncException. */ def tryCommitSync[F[_], A](stm: STM[A])(implicit F: Sync[F]): F[A] = 29 | F.delay(store.transact { expose[A](stm).unsafeRunSync() }) 30 | .adaptError { case Retry => new UnexpectedRetryInSyncException } // Add a stack trace, basically 31 | 32 | /* The unused `dummy` field only exists so that this can be a value class. */ final class AtomicallyFn[F[_]](private val dummy: Boolean = false) extends AnyVal { 33 | def apply[A](stm: STM[A])(implicit F: Concurrent[F]): F[A] = 34 | atomicallyImpl[F, A](stm) 35 | } 36 | 37 | /* Natural transformation that commits every STM action in F. */ def atomicallyK[F[_]: Concurrent]: STM ~> F = new (STM ~> F) { 38 | def apply[A](fa: STM[A]): F[A] = atomicallyImpl[F, A](fa) 39 | } 40 | 41 | implicit class STMOps[A](private val self: STM[A]) extends AnyVal { 42 | /* Same as STM.atomically[F](self). */ def commit[F[_] : Concurrent]: F[A] = atomicallyImpl[F, A](self) 43 | 44 | /** Runs `self` inside `store.attempt`; if it signals Retry, evaluates `other` instead. Any other failure propagates. */ def orElse[B >: A](other: STM[B]): STM[B] = suspend { 45 | try { 46 | STM.pure { store.attempt { expose[B](self).unsafeRunSync() } } 47 | } catch { case Retry => 48 | other 49 | } 50 | } 51 | 52 | /* Enables `if` guards in for-comprehensions; a false guard retries via the FunctorFilter instance. */ def withFilter(f: A => Boolean): STM[A] = self.filter(f) 53 | 54 | def filterNot(f: A => Boolean): STM[A] = self.filter(!f(_)) 55 | 56 | def unNone[B](implicit ev: A <:< Option[B]): STM[B] = 57 | functorFilter.mapFilter(self)(ev) 58 | 59 | /** Repeats `self`, collecting its results in order, until an iteration retries; the `orElse` wrapped around that step then yields the results collected so far. The builder is created inside `STM.suspend`, so each run of the transaction starts from an empty list. Every iteration nests one more `orElse`. */ def iterateUntilRetry: STM[List[A]] = STM.suspend { 60 | val b = List.newBuilder[A] 61 | def loop: STM[List[A]] = 62 | (self.map(b += _)
>> loop).orElse(STM.pure(b.result())) 63 | loop 64 | } 65 | } 66 | 67 | /* IO's instance is reused directly; this is only valid because STM values are IO actions at runtime. */ implicit val monad: StackSafeMonad[STM] with Defer[STM] = 68 | IO.ioEffect.asInstanceOf[StackSafeMonad[STM] with Defer[STM]] 69 | 70 | implicit def stmToAllMonadOps[A](stm: STM[A]): Monad.AllOps[STM, A] = 71 | Monad.ops.toAllMonadOps(stm) 72 | 73 | /* Filtering out a value makes the transaction retry. */ implicit val functorFilter: FunctorFilter[STM] = new FunctorFilter[STM] { 74 | def functor: cats.Functor[STM] = monad 75 | def mapFilter[A, B](fa: STM[A])(f: A => Option[B]): STM[B] = 76 | fa.flatMap(f(_).fold[STM[B]](STM.retry)(_.pure[STM])) 77 | } 78 | implicit def stmToFunctorFilterOps[A](stm: STM[A]): FunctorFilter.AllOps[STM, A] = 79 | FunctorFilter.ops.toAllFunctorFilterOps(stm) 80 | 81 | /* `empty` is retry and `combineK` is orElse, so combining picks the first branch that does not retry. */ implicit val monoidK: MonoidK[STM] = new MonoidK[STM] { 82 | def empty[A]: STM[A] = STM.retry 83 | def combineK[A](a: STM[A], b: STM[A]): STM[A] = a orElse b 84 | } 85 | 86 | implicit def stmToMonoidKOps[A](stm: STM[A]): MonoidK.AllOps[STM, A] = 87 | MonoidK.ops.toAllMonoidKOps(stm) 88 | 89 | /* Reuses IO's Monoid, relying on the same cast. */ implicit def monoid[A: Monoid]: Monoid[STM[A]] = 90 | IO.ioMonoid[A].asInstanceOf[Monoid[STM[A]]] 91 | 92 | 93 | /* Zero-cost conversions between STM and the underlying IO. */ private[this] def wrap[A](io: IO[A]): STM[A] = io.asInstanceOf[STM[A]] 94 | private[this] def expose[A](stm: STM[_ >: A]): IO[A] = stm.asInstanceOf[IO[A]] 95 | 96 | /* Single store shared by every transaction in this process, obtained from Store.forPlatform(). */ private[stm] val store: Store = /*_*/Store.forPlatform()/*_*/ 97 | private[stm] def delay[A](a: => A): STM[A] = wrap(IO(a)) 98 | 99 | /* Retrying transactions wait here on the keys they read (waitOn); successful commits signal the keys they wrote (notifyOn). */ private[this] val globalLock = new Monitor 100 | 101 | /** Commit loop shared by atomically, commit and atomicallyK. The journal is captured in a `finally` so that it is available even when the body throws Retry. On success, `globalLock` is notified with the written keys. On Retry, this waits on the keys that were read and then reruns the whole transaction; a retry that read nothing could never be woken, hence PotentialDeadlockException. Failures other than Retry propagate as errors in F. */ private[this] def atomicallyImpl[F[_]: Concurrent, A](stm: STM[A]): F[A] = 102 | Concurrent[F].suspend { 103 | var journal: Store.Journal = null 104 | try { 105 | val result = store.transact { 106 | try { 107 | expose[A](stm).unsafeRunSync() 108 | } finally { 109 | journal = store.current() 110 | } 111 | } 112 | globalLock.notifyOn[F](journal.writtenKeys) as result 113 | } catch { case Retry => 114 | val rk = journal.readKeys 115 | if (rk.isEmpty) throw new PotentialDeadlockException 116 | globalLock.waitOn[F](rk)
>> atomicallyImpl[F, A](stm) 117 | } 118 | } 119 | } 120 | } 121 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/APITests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import cats.effect.IO 4 | import utest._ 5 | import cats.syntax.apply._ 6 | 7 | /** Smoke tests for the public STM syntax and constructors. */ object APITests extends NondetIOSuite { 8 | val tests = Tests { 9 | "STM.atomically" - { 10 | STM.atomically[IO](STM.pure(number)) 11 | .map { _ ==> number } 12 | } 13 | 14 | "for-comprehension with guards" - { 15 | def transfer(from: TRef[Int], to: TRef[Int], amt: Int): STM[Unit] = 16 | for { 17 | balance <- from.get 18 | if balance >= amt 19 | _ <- from.update(_ - amt) 20 | _ <- to.update(_ + amt) 21 | } yield () 22 | 23 | val acc = TRef(number) 24 | /* NOTE(review): mapN produces STM[STM[Unit]] and commit only runs the outer action, which allocates the refs, so `transfer` is never executed; as written this only checks that guards compile. Flatten before committing if the transfer should actually run. */ (acc, acc).mapN(transfer(_, _, 10)).commit[IO] 25 | } 26 | 27 | "SemigroupK syntax" - { 28 | val ref = TRef.in[IO](10).unsafeRunSync() 29 | (ref.get <+> STM.retry).commit[IO] 30 | } 31 | 32 | "STM#unNone" - { 33 | for { 34 | ref <- TRef.in[IO](Option(number)) 35 | x <- ref.get.unNone.commit[IO] 36 | } yield assert(x == number) 37 | } 38 | 39 | "STM#filterNot" - { 40 | STM.pure(number).filterNot(_ != number).commit[IO].map { _ ==> number } 41 | } 42 | 43 | "STM.atomicallyK" - { 44 | val fk = STM.atomicallyK[IO] 45 | fk(STM.pure(number)).map(_ ==> number) 46 | } 47 | 48 | "STM#iterateUntilRetry" - { 49 | for { 50 | r <- TRef.in[IO](4) 51 | x <- r.modify(x => (x - 1, x)).filter(_ > 0).iterateUntilRetry.commit[IO] 52 | } yield x ==> List(4, 3, 2, 1) 53 | } 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/BaseIOSuite.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import scala.concurrent.duration._ 4 | import scala.concurrent.{ExecutionContext, Future} 5 | 6 | import
cats.effect.laws.util.TestContext 7 | import cats.implicits._ 8 | import cats.effect.{ContextShift, IO, Timer} 9 | import utest._ 10 | 11 | /** Helpers shared by the IO-based test suites. */ trait IOSuiteUtils { 12 | def timer: Timer[IO] 13 | /* Short sleep on the suite's timer: real time in NondetIOSuite, simulated time in DeterministicIOSuite. */ def nap: IO[Unit] = timer.sleep(10.millis) 14 | 15 | def longNap: IO[Unit] = nap.replicateA(10).void 16 | 17 | /* Fails the running test with an assertion error when evaluated. */ def fail[A]: IO[A] = IO.suspend { 18 | assert(false) 19 | IO.never // unreachable, but above has type Unit 20 | } 21 | 22 | /* Discards a test body without evaluating it, to switch a test off. */ def disabled(a: => Any): Unit = () 23 | 24 | val number = 42 25 | } 26 | 27 | /** Suite whose IO tests run on the global ExecutionContext in real time, each bounded by `ioTimeout`. */ abstract class NondetIOSuite extends TestSuite with IOSuiteUtils { 28 | def ec: ExecutionContext = ExecutionContext.global 29 | implicit def cs: ContextShift[IO] = IO.contextShift(ec) 30 | implicit def timer: Timer[IO] = IO.timer(ec) 31 | 32 | 33 | def ioTimeout: FiniteDuration = 1.second 34 | 35 | override def utestWrap(path: Seq[String], runBody: => Future[Any])(implicit ec: ExecutionContext): Future[Any] = { 36 | super.utestWrap(path, runBody.flatMap { 37 | case io: IO[_] => io.timeout(ioTimeout).unsafeToFuture() 38 | case other => Future.successful(other) 39 | })(ec) 40 | } 41 | } 42 | 43 | /** Suite whose IO tests run on a TestContext. The IO is started and the simulated clock is advanced by 365 days; the test then fails if tasks are still pending or if the IO has not produced a result. */ abstract class DeterministicIOSuite extends TestSuite with IOSuiteUtils { 44 | private[this] val tc = TestContext() 45 | implicit def cs: ContextShift[IO] = tc.contextShift[IO](IO.ioEffect) 46 | implicit def timer: Timer[IO] = tc.timer[IO](IO.ioEffect) 47 | 48 | override def utestWrap(path: Seq[String], runBody: => Future[Any])(implicit ec: ExecutionContext): Future[Any] = { 49 | super.utestWrap(path, runBody.flatMap { 50 | case io: IO[_] => 51 | val f = io.unsafeToFuture() 52 | tc.tick(365.days) 53 | assert(tc.state.tasks.isEmpty) 54 | /* f.value is None when the IO is still suspended with nothing scheduled (e.g. blocked forever); report that explicitly instead of throwing NoSuchElementException from Option.get. */ Future.fromTry(f.value.getOrElse[scala.util.Try[Any]](scala.util.Failure(new IllegalStateException(s"IO in test ${path.mkString(".")} did not complete after advancing the TestContext clock")))) 55 | case other => Future.successful(other) 56 | })(new ExecutionContext { 57 | def execute(runnable: Runnable): Unit = runnable.run() 58 | def reportFailure(cause: Throwable): Unit = throw cause 59 | }) 60 | } 61 | } --------------------------------------------------------------------------------
/shared/src/test/scala/com/olegpy/stm/LawsTests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import cats.effect.{IO, SyncIO} 4 | import org.typelevel.discipline.Laws 5 | import cats.laws.discipline._ 6 | import cats.laws.discipline.arbitrary._ 7 | import cats.effect.laws.discipline.arbitrary._ 8 | import cats.effect.laws.util.TestContext 9 | import cats.effect.laws.util.TestInstances._ 10 | import org.scalacheck.{Arbitrary, Gen} 11 | import org.scalacheck.util.{ConsoleReporter, Pretty} 12 | import utest._ 13 | import cats.implicits._ 14 | import cats.kernel.Eq 15 | import cats.kernel.laws.discipline.MonoidTests 16 | import utest.framework.TestPath 17 | import com.olegpy.stm.misc.{TDeferred, TMVar, TQueue} 18 | 19 | /** Discipline law checks for the STM type-class instances and for the Invariant instances of TRef, TDeferred, TMVar and TQueue, reported through utest. */ object LawsTests extends NondetIOSuite { 20 | val tests = Tests { 21 | "Monad[STM]" - 22 | uCheckAll(MonadTests[STM].monad[Int, String, Long]) 23 | 24 | "Defer[STM]" - 25 | uCheckAll(DeferTests[STM].defer[Long]) 26 | 27 | "MonoidK[STM]" - 28 | uCheckAll(MonoidKTests[STM].monoidK[Int]) 29 | 30 | "Monoid[STM[A]]" - 31 | uCheckAll(MonoidTests[STM[Int]].monoid) 32 | 33 | "FunctorFilter[STM]" - 34 | uCheckAll(FunctorFilterTests[STM].functorFilter[Int, String, Long]) 35 | 36 | "InvariantMonoidal[TRef]" - 37 | uCheckAll(InvariantMonoidalTests[TRef].invariantMonoidal[Int, Int, Int]) 38 | 39 | "Invariant[TDeferred]" - 40 | uCheckAll(InvariantTests[TDeferred].invariant[Int, String, Long]) 41 | 42 | "Invariant[TMVar]" - 43 | uCheckAll(InvariantTests[TMVar].invariant[Int, String, Long]) 44 | 45 | "Invariant[TQueue]" - 46 | uCheckAll(InvariantTests[TQueue].invariant[Int, String, Long]) 47 | } 48 | 49 | implicit val tc: TestContext = TestContext() 50 | /* STM actions are generated from SyncIO and cast, relying on STM values being IO at runtime. */ implicit def arb[A](implicit a: Arbitrary[SyncIO[A]]): Arbitrary[STM[A]] = 51 | Arbitrary(a.arbitrary.map(_.toIO.asInstanceOf[STM[A]])) 52 | 53 | /* Each generated TRef is allocated eagerly with unsafeRunSync. */ implicit def arbRef[A](implicit a: Arbitrary[A]): Arbitrary[TRef[A]] = 54 |
Arbitrary(a.arbitrary.map(TRef.in[IO](_).unsafeRunSync())) 55 | 56 | implicit def arbTMVar[A](implicit a: Arbitrary[Option[A]]): Arbitrary[TMVar[A]] = 57 | Arbitrary(arbRef[Option[A]].arbitrary.map(new TMVar(_))) 58 | 59 | implicit def arbDef[A](implicit a: Arbitrary[Option[A]]): Arbitrary[TDeferred[A]] = 60 | Arbitrary(arbRef[Option[A]].arbitrary.map(new TDeferred(_))) 61 | 62 | /* Picks a random queue flavour whose capacity equals the number of generated elements (a synchronous queue is only an option for at most one element), then enqueues them. */ implicit def arbQueue[A](implicit a: Arbitrary[List[A]]): Arbitrary[TQueue[A]] = 63 | Arbitrary { 64 | a.arbitrary.flatMap { initial => 65 | val len = initial.length 66 | val allowSync = len <= 1 67 | val options = List( 68 | TQueue.bounded[A](len), 69 | TQueue.unbounded[A], 70 | TQueue.circularBuffer[A](len) 71 | ) ::: (if (allowSync) List(TQueue.synchronous[A]) else Nil) 72 | 73 | Gen.oneOf(options).map(q => 74 | STM.tryCommitSync[IO, TQueue[A]](q.flatTap(_.enqueueAll(initial))).unsafeRunSync() 75 | ) 76 | } 77 | } 78 | 79 | /* STM actions are compared by committing them as IO. */ implicit def eq[A: Eq]: Eq[STM[A]] = Eq[IO[A]].contramap(STM.atomically[IO](_)) 80 | 81 | /* Two refs are equal if they agree now and still agree after a fresh value is written through `l` only, i.e. they share state. NOTE(review): values are compared with == rather than Eq[A], and sample.get throws if the generator yields no sample. */ implicit def eqRef[A: Eq: Arbitrary]: Eq[TRef[A]] = Eq.instance[TRef[A]] { (l, r) => 82 | val next = implicitly[Arbitrary[A]].arbitrary.sample.get 83 | val check = (l.get, r.get).mapN(_ == _) 84 | STM.tryCommitSync[IO, Boolean] { 85 | (check, l.set(next), check).mapN((a, _, b) => a && b) 86 | }.unsafeRunSync() 87 | } 88 | 89 | implicit def eqTDeferred[A: Eq: Arbitrary]: Eq[TDeferred[A]] = Eq.by(_.state) 90 | implicit def eqTMVar[A: Eq: Arbitrary]: Eq[TMVar[A]] = Eq.by(_.state) 91 | /* Two queues are equal if draining q1 leaves q2 empty, i.e. both observe the same underlying storage. */ implicit def eqTQueue[A: Eq: Arbitrary]: Eq[TQueue[A]] = Eq.instance { (q1, q2) => 92 | STM.tryCommitSync[IO, Boolean]{ 93 | for { 94 | _ <- q1.dequeueUpTo(Int.MaxValue) 95 | emptied <- q2.isEmpty 96 | } yield emptied 97 | }.unsafeRunSync() 98 | } 99 | 100 | /* Turns a failed ScalaCheck result into a utest assertion failure labelled with the property name. */ class UTestReporter(prop: String) extends ConsoleReporter(0) { 101 | override def onTestResult(name: String, res: org.scalacheck.Test.Result): Unit = { 102 | val scalaCheckResult = if (res.passed) "" else prop + " " + Pretty.pretty(res) 103 |
assert(scalaCheckResult == "") 104 | } 105 | } 106 | 107 | /* Checks every property of the rule set, labelling failures with the current utest test name (tp.value.last) followed by the property name. */ private def uCheckAll(set: Laws#RuleSet)(implicit tp: TestPath): Unit = 108 | for ((name, prop) <- set.all.properties) { 109 | prop.check(_.withTestCallback(new UTestReporter(tp.value.last + name))) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/RetryTests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import cats.effect.{ExitCase, IO, SyncIO} 4 | import utest._ 5 | import cats.implicits._ 6 | 7 | /** Retry semantics. Runs on DeterministicIOSuite, so nap and longNap advance a simulated clock. */ object RetryTests extends DeterministicIOSuite { 8 | val tests = Tests { 9 | "Retry with no reads throws an exception" - { 10 | (STM.retry.commit[IO] >> fail[Unit]).recover { 11 | case _: PotentialDeadlockException => () 12 | } 13 | } 14 | 15 | "Unconditional retry from local reads only throws an exception" - { 16 | (TRef(0).flatMap(_.get).filter(_ => false).commit[IO] >> fail[Unit]).recover { 17 | case _: PotentialDeadlockException => () 18 | } 19 | } 20 | 21 | "Unconditional retry with nonlocal reads doesn't terminate" - { 22 | val ref = TRef.in[SyncIO](0).unsafeRunSync() 23 | val txn = ref.get >> STM.retry 24 | IO.race(txn.commit[IO], longNap) map { _ ==> Right(()) } 25 | } 26 | 27 | "Retrying eventually completes, if possible" - { 28 | for { 29 | ref <- TRef.in[IO](0) 30 | inc <- (ref.update(_ + 1).commit[IO] *> nap).replicateA(5).start 31 | x <- ref.get.filter(_ >= 5).commit[IO] 32 | _ <- inc.cancel 33 | } yield x ==> 5 34 | } 35 | 36 | "orElse falls back to first successful" - { 37 | for { 38 | x <- STM.retry.orElse(STM.pure(number).orElse(STM.pure(0))).commit[IO] 39 | } yield x ==> number 40 | } 41 | 42 | "orElse rolls the left back on retry" - { 43 | for { 44 | ref <- TRef.in[IO](0) 45 | _ <- STM.atomically[IO] { 46 | for { 47 | _ <- ref.set(number) // That should complete 48 | _ <- (ref.set(-1) >> STM.retry) orElse STM.unit 49 | } yield
() 50 | } 51 | x <- ref.get.commit[IO] 52 | } yield x ==> number 53 | } 54 | 55 | /* f1 waits for x == 5 but is cancelled after one nap; its cancellation hook stops the updater and writes -1, which must then be the final value. */ "retries are actually cancellable" - { 56 | for { 57 | x <- TRef.in[IO](0) 58 | upd <- (x.update(_ + 1).commit[IO] >> nap).replicateA(10).start 59 | f1 <- x.get.filter(_ == 5).commit[IO] 60 | .guaranteeCase { 61 | case ExitCase.Canceled => upd.cancel >> x.set(-1).commit[IO] 62 | case _ => fail 63 | } 64 | .start 65 | _ <- nap 66 | _ <- f1.cancel 67 | _ <- longNap 68 | res <- x.get.commit[IO] 69 | } yield res ==> -1 70 | } 71 | 72 | "retries are not triggered by writes to independent variables" - { 73 | /* Incremented on every run of the transaction body, so it counts attempts including retries. */ @volatile var count = 0 74 | val r1, r2, r3 = TRef.in[SyncIO](0).unsafeRunSync() 75 | val txn: STM[Unit] = for { 76 | i1 <- r1.get 77 | i2 <- r2.get 78 | _ = { count += 1 } // side effects to actually track retries 79 | if i1 < i2 80 | _ <- r3.get // after-check gets should not affect anything 81 | } yield () 82 | 83 | /* Presumably detects Scala.js, where the string form of () differs from "()". */ val isJS = ().toString != "()" 84 | 85 | def later(expect: Int): IO[Unit] = nap >> { 86 | if (isJS) IO(assert(count == expect)) 87 | else IO { 88 | // Use fairly lax checking for JVM, where CPU black magic is more prominent 89 | assert((expect - 2).to(expect + 10) contains count) 90 | count = expect 91 | } 92 | } 93 | 94 | for { 95 | f <- txn.commit[IO].start(cs) 96 | _ <- later(1) // Tried once, but failed 97 | _ <- r1.set(number).commit[IO] 98 | _ <- later(2) // Tried twice, as we modified r1 99 | _ <- r3.set(number).commit[IO] 100 | _ <- later(2) // Didn't try again, as we didn't touch r1 or r2 101 | _ <- r2.set(number + 1).commit[IO] 102 | _ <- later(3) // Tried again, and should complete at this point 103 | _ <- f.join 104 | } yield () 105 | } 106 | } 107 | } 108 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/RollbackTests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import cats.effect.IO 4 | import utest._ 5 | import
cats.implicits._ 6 | import com.olegpy.stm.results._ 7 | 8 | /** Failed or aborted transactions must leave TRefs unchanged. */ object RollbackTests extends NondetIOSuite { 9 | val tests = Tests { 10 | "commit fails on exceptions" - { 11 | val Dummy = new Exception() 12 | def crash(): Unit = throw Dummy 13 | TRef.in[IO](0) 14 | .mproduct(_.set(5).map(_ => crash()).commit[IO].attempt) 15 | .flatMap { 16 | case (tref, Left(Dummy)) => tref.get.commit[IO] 17 | case _ => fail 18 | } 19 | .map { _ ==> 0 } 20 | } 21 | 22 | "commit fails on aborts" - { 23 | val ex = new Exception("Transaction aborted") 24 | for { 25 | s <- TRef.in[IO](0) 26 | res <- (s.set(number) >> STM.abort(ex)).result 27 | r <- s.get.commit[IO] 28 | } yield { 29 | r ==> 0 30 | res ==> STMAbort(ex) 31 | } 32 | } 33 | 34 | /* Aborts are failures, not retries, so orElse must not swallow them. */ "orElse doesn't fall back for aborted computations" - { 35 | val ex = new Exception("Transaction aborted") 36 | for { 37 | s <- TRef.in[IO](0) 38 | res <- (s.set(number) >> (STM.abort(ex) orElse STM.unit)).result 39 | r <- s.get.commit[IO] 40 | } yield { 41 | r ==> 0 42 | res ==> STMAbort(ex) 43 | } 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/StoreTests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import scala.concurrent.Future 4 | import scala.concurrent.ExecutionContext.Implicits.global 5 | 6 | import com.olegpy.stm.internal.Store 7 | import utest._ 8 | 9 | /** Stress test for Store.transact: 2 * execs concurrent increments, split over two keys, must not lose updates. */ object StoreTests extends TestSuite { 10 | val tests = Tests { 11 | "Store resolves conflicting updates" - { 12 | val store = Store.forPlatform() 13 | val key1, key2 = new Object 14 | store.transact(store.current().update(key1, 0)) 15 | store.transact(store.current().update(key2, 0)) 16 | def increment(key: Object): Unit = store.transact { 17 | val j = store.current() 18 | j.update(key, j.read(key).asInstanceOf[Int] + 1) 19 | } 20 | val execs = 10000 21 | Future.sequence { List.tabulate(execs * 2)(i => Future { 22 | if
(i % 2 == 0) increment(key1) else increment(key2) 23 | }) } 24 | .map { _ => 25 | val (r1, r2) = store.transact { 26 | val j = store.current() 27 | (j.read(key1), j.read(key2)) 28 | } 29 | assert(r1 == execs) 30 | assert(r2 == execs) 31 | } 32 | } 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/TRefTests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import cats.data.State 4 | import cats.implicits._ 5 | import cats.effect.{IO, SyncIO} 6 | import com.olegpy.stm.results._ 7 | import utest._ 8 | 9 | 10 | /** TRef constructors, modify/access helpers and toString. */ object TRefTests extends NondetIOSuite { 11 | val tests = Tests { 12 | 13 | "TRef.apply" - { 14 | TRef(number).flatMap(_.get).result 15 | .map { _ ==> STMSuccess(number) } 16 | } 17 | 18 | "TRef.in" - { 19 | for { 20 | tr <- TRef.in[SyncIO](number).toIO 21 | x <- tr.get.commit[IO] 22 | } yield assert(x == number) 23 | } 24 | 25 | "TRef#modifyState" - { 26 | for { 27 | tr <- TRef.in[IO](number) 28 | stateOp = State.get[Int] <* State.set(0) 29 | n <- tr.modifyState(stateOp).commit[IO] 30 | _ = n ==> number 31 | x <- tr.get.commit[IO] 32 | } yield x ==> 0 33 | } 34 | 35 | "TRef#access" - { 36 | for { 37 | tr <- TRef.in[IO](0) 38 | _ <- tr.access 39 | .filter { case (i, _) => i == 0 } 40 | .flatMap { case (_, set) => set(number) >> tr.get } 41 | .result.map(_ ==> STMSuccess(number)) 42 | } yield () 43 | } 44 | 45 | "TRef#tryModifyState, TRef#tryModify and TRef#tryUpdate never fail" - { 46 | for { 47 | tr <- TRef.in[IO](number) 48 | stateOp = State.get[Int] <* State.set(0) 49 | _ <- ( 50 | tr.tryModifyState(stateOp), 51 | tr.tryModify(_ => (0, 'a')), 52 | tr.tryUpdate(_ + 1) 53 | ).tupled.commit[IO].flatMap { 54 | case (Some(_), Some(_), true) => IO.unit 55 | case _ => fail[Unit] 56 | } 57 | } yield () 58 | } 59 | 60 | "TRef#toString prints commited value" - { 61 | for { 62 | tr <- TRef.in[IO](0) 63 | _ <-
tr.set(number).commit[IO] 64 | _ = tr.toString ==> s"TRef($number)" 65 | unitLike = TRef.invariantMonoidal.unit.imap(_ => "()")(_ => ()) 66 | _ = unitLike.toString ==> "TRef(())" 67 | _ = tr.imap(_ + 1)(_ - 1).toString ==> s"TRef(${number + 1})" 68 | _ = (tr, unitLike).tupled.toString ==> s"TRef(($number,()))" 69 | } yield () 70 | } 71 | } 72 | } 73 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/misc/TDeferredTests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.misc 2 | 3 | import cats.effect.IO 4 | import cats.implicits._ 5 | import com.olegpy.stm.NondetIOSuite 6 | import com.olegpy.stm.results._ 7 | import utest._ 8 | 9 | 10 | /** TDeferred behaviour, partly exercised through its IO-facing view. */ object TDeferredTests extends NondetIOSuite { 11 | /* A TDeferred[Unit] viewed via in[IO]; presumably an IO-level Deferred-like interface (complete/get/tryGet without explicit commits). */ val ofUnit = TDeferred.in[IO, Unit].map(_.in[IO]) 12 | val tests = Tests { 13 | "TDeferred.tryGet works as with regular TryableDeferred" - { 14 | for { 15 | d1 <- ofUnit 16 | r1 <- d1.tryGet 17 | _ = r1 ==> None 18 | _ <- d1.complete(()) 19 | r2 <- d1.tryGet 20 | _ = r2 ==> Some(()) 21 | } yield () 22 | } 23 | 24 | "TDeferred.get semantically blocks" - { 25 | for { 26 | d1 <- ofUnit 27 | d2 <- ofUnit 28 | _ <- (d1.get >> d2.complete(())).start 29 | _ <- nap 30 | _ <- d2.tryGet.map(_ ==> None) 31 | _ <- d1.complete(()) 32 | _ <- d2.get 33 | } yield () 34 | } 35 | 36 | "TDeferred.complete twice fails" - { 37 | for { 38 | d <- TDeferred.in[IO, Unit] 39 | _ <- d.complete(()).commit[IO] 40 | r <- d.complete(()).result 41 | } yield assert(r.is[STMAbort]) 42 | } 43 | 44 | "TDeferred#toString shows a state" - { 45 | for { 46 | d <- TDeferred.in[IO, Int] 47 | _ = d.toString ==> "TDeferred()" 48 | _ <- d.complete(number).commit[IO] 49 | /* NOTE(review): this s-interpolator has nothing to interpolate; if toString is meant to show the completed value, s"TDeferred($number)" was probably intended. Confirm against TDeferred.toString. */ _ = d.toString ==> s"TDeferred()" 50 | } yield () 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/misc/TMVarTests.scala:
-------------------------------------------------------------------------------- 1 | package com.olegpy.stm.misc 2 | 3 | import cats.effect.IO 4 | import cats.effect.concurrent.MVar 5 | import cats.implicits._ 6 | import com.olegpy.stm.NondetIOSuite 7 | import com.olegpy.stm.results._ 8 | import utest._ 9 | 10 | /** TMVar behaviour, checked both through STM and through its IO-facing MVar view (`to[IO]`). */ object TMVarTests extends NondetIOSuite { 11 | val tests = Tests { 12 | "TMVar#isEmpty" - { 13 | for { 14 | mv1 <- TMVar.emptyIn[IO, Unit] 15 | mv2 <- TMVar.in[IO, Unit](()) 16 | _ <- mv1.isEmpty.commit[IO].map(_ ==> true) 17 | _ <- mv2.isEmpty.commit[IO].map(_ ==> false) 18 | } yield () 19 | } 20 | 21 | "TMVar#read" - { 22 | for { 23 | mv1 <- TMVar.emptyIn[IO, Unit] 24 | mv2 <- TMVar.in[IO, Unit](()) 25 | /* Reading an empty TMVar must retry; the outcome has to be asserted, not merely computed. Reading a full one succeeds without emptying it, so mv2 can be read twice. */ _ <- mv1.read.result.map(_ ==> STMRetry) 26 | _ <- mv1.read.result.map(_ ==> STMRetry) 27 | _ <- mv2.read.result.map(_ ==> STMSuccess(())) 28 | _ <- mv2.read.result.map(_ ==> STMSuccess(())) 29 | } yield () 30 | } 31 | 32 | "TMVar#tryTake" - { 33 | for { 34 | mv1 <- TMVar.emptyIn[IO, Unit].map(_.to[IO]) 35 | mv2 <- TMVar.in[IO, Unit](()).map(_.to[IO]) 36 | _ <- mv1.tryTake.map(_ ==> None) 37 | _ <- mv1.tryTake.map(_ ==> None) 38 | _ <- mv2.tryTake.map(_ ==> Some(())) 39 | _ <- mv2.tryTake.map(_ ==> None) 40 | } yield () 41 | } 42 | 43 | "TMVar#tryPut" - { 44 | for { 45 | mv1 <- TMVar.emptyIn[IO, Unit].map(_.to[IO]) 46 | mv2 <- TMVar.in[IO, Unit](()).map(_.to[IO]) 47 | _ <- mv1.tryPut(()).map(_ ==> true) 48 | _ <- mv1.tryPut(()).map(_ ==> false) 49 | _ <- mv2.tryPut(()).map(_ ==> false) 50 | _ <- mv2.tryPut(()).map(_ ==> false) 51 | } yield () 52 | } 53 | 54 | /* The producer puts 1, 2, 3; the consumer naps before each take and must receive the values in order. */ "TMVar#take" - { 55 | def producer(mv: MVar[IO, Int]): IO[Unit] = mv.put(1) >> mv.put(2) >> mv.put(3) 56 | def consumer(mv: MVar[IO, Int]): IO[List[Int]] = { 57 | def loop(list: List[Int]): IO[List[Int]] = { 58 | if (list.length == 3) IO.pure(list.reverse) 59 | else nap >> mv.take.map(_ :: list) >>= loop 60 | } 61 | loop(Nil) 62 | } 63 | 64 | for { 65 | mv <- TMVar.emptyIn[IO,
Int].map(_.to[IO]) 66 | _ <- producer(mv).start 67 | list <- consumer(mv) 68 | } yield list ==> List(1, 2, 3) 69 | } 70 | 71 | "TMVar#toString shows a state" - { 72 | for { 73 | mv <- TMVar.emptyIn[IO, Int] 74 | _ = mv.toString ==> "TMVar()" 75 | _ <- mv.put(number).commit[IO] 76 | /* NOTE(review): this s-interpolator has nothing to interpolate; if toString is meant to show the stored value, s"TMVar($number)" was probably intended. Confirm against TMVar.toString. */ _ = mv.toString ==> s"TMVar()" 77 | } yield () 78 | } 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/misc/TQueueTests.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.misc 2 | 3 | import scala.util.Random 4 | 5 | import cats.implicits._ 6 | import cats.effect.IO 7 | import cats.effect.concurrent.Deferred 8 | import com.olegpy.stm.results._ 9 | import com.olegpy.stm.{DeterministicIOSuite, STM} 10 | import utest._ 11 | 12 | 13 | /** Behavioural checks shared by every TQueue flavour, run on DeterministicIOSuite's simulated clock. */ object TQueueTests extends DeterministicIOSuite { 14 | val tests = Tests { 15 | "TQueue.bounded" - { 16 | val mkQ = TQueue.boundedIn[IO, Int](3) 17 | "shows element with toString" - mkQ.flatMap(checkToString) 18 | "works in FIFO fashion" - mkQ.flatMap(isFifo) 19 | "doesn't block in inspect methods" - mkQ.flatMap(testNoBlocking) 20 | "blocks on dequeue and peek when empty" - mkQ.flatMap(blocksOnDequeue) 21 | "blocks on enqueue when full" - mkQ.flatMap(blocksOnEnqueue) 22 | } 23 | 24 | "TQueue.unbounded" - { 25 | val mkQ = TQueue.unboundedIn[IO, Int] 26 | "shows element with toString" - mkQ.flatMap(checkToString) 27 | "works in FIFO fashion" - mkQ.flatMap(isFifo) 28 | "doesn't block in inspect methods" - mkQ.flatMap(testNoBlocking) 29 | "blocks on dequeue and peek when empty" - mkQ.flatMap(blocksOnDequeue) 30 | "never blocks on enqueue" - mkQ.flatMap(doesntBlockOnEnqueue) 31 | } 32 | 33 | "TQueue.synchronous" - { 34 | val mkQ = TQueue.synchronousIn[IO, Int] 35 | "shows element with toString" - mkQ.flatMap(checkToString) 36 | "works in FIFO fashion" - mkQ.flatMap(isFifo) 37 | "doesn't block in inspect methods" -
mkQ.flatMap(testNoBlocking) 38 | "blocks on dequeue and peek when empty" - mkQ.flatMap(blocksOnDequeue) 39 | "blocks on enqueue when full" - mkQ.flatMap(blocksOnEnqueue) 40 | /* Enqueuing two elements in a single transaction must retry. */ "allows single element only" - mkQ.flatMap { queue => 41 | queue.enqueueAll(List(number, number)).result 42 | }.map(_ ==> STMRetry) 43 | } 44 | 45 | "TQueue.circularBuffer" - { 46 | val mkQ = TQueue.circularBufferIn[IO, Int](10) 47 | "shows element with toString" - mkQ.flatMap(checkToString) 48 | "works in FIFO fashion" - mkQ.flatMap(isFifo) 49 | "doesn't block in inspect methods" - mkQ.flatMap(testNoBlocking) 50 | "blocks on dequeue and peek when empty" - mkQ.flatMap(blocksOnDequeue) 51 | "never blocks on enqueue" - mkQ.flatMap(doesntBlockOnEnqueue) 52 | "drops oldest elements on enqueue" - mkQ.flatMap { queue => 53 | queue.enqueueAll(List.range(0, 20)).commit[IO] >> 54 | queue.dequeueUpTo(Int.MaxValue).result.map(_ ==> STMSuccess(List.range(10, 20))) 55 | } 56 | } 57 | } 58 | 59 | 60 | /* Each use creates a fresh gate, completed once the action under test has committed. */ private[this] val mkGate = Deferred.tryableUncancelable[IO, Unit] 61 | 62 | /* Asserts that committing `io` finishes within a longNap. */ private def noBlock(io: STM[Any]): IO[Unit] = 63 | for { 64 | gate <- mkGate 65 | _ <- (io.commit[IO] >> gate.complete(())).start 66 | _ <- longNap 67 | rs <- gate.tryGet 68 | } yield assert(rs.nonEmpty) 69 | 70 | /* Asserts that `io` is still blocked after a longNap, then commits `unblock` and expects `io` to finish within one more nap. */ private def blockUnblock(io: STM[Any], unblock: STM[Any]): IO[Unit] = 71 | for { 72 | gate <- mkGate 73 | _ <- (io.commit[IO] >> gate.complete(())).start 74 | _ <- longNap 75 | rs <- gate.tryGet 76 | _ = assert(rs.isEmpty) 77 | _ <- unblock.commit[IO] 78 | _ <- nap 79 | rs2 <- gate.tryGet 80 | } yield assert(rs2.nonEmpty) 81 | 82 | /* offer, the try* operations, isEmpty and dequeueUpTo must never block. */ private def testNoBlocking(q: TQueue[Int]) = 83 | noBlock(q.offer(number)) >> 84 | noBlock(q.tryPeek) >> 85 | noBlock(q.tryDequeue) >> 86 | noBlock(q.isEmpty) >> 87 | noBlock(q.dequeueUpTo(Int.MaxValue)) 88 | 89 | /* Empties the queue, then checks that dequeue and peek block until something is enqueued. */ private def blocksOnDequeue(q: TQueue[Int]): IO[Unit] = 90 | q.dequeueUpTo(Int.MaxValue).commit[IO] >> 91 | blockUnblock(q.dequeue, q.enqueue(number)) >> 92 | blockUnblock(q.peek,
q.enqueue(number)) 93 | 94 | /* Fills the queue until enqueue retries, then checks that the next enqueue blocks until a dequeue. */ private def blocksOnEnqueue(q: TQueue[Int]): IO[Unit] = 95 | q.enqueue(number).iterateUntilRetry.commit[IO] >> 96 | blockUnblock(q.enqueue(number), q.dequeue) 97 | 98 | private def doesntBlockOnEnqueue(q: TQueue[Int]): IO[Unit] = 99 | noBlock(q.enqueueAll(List.range(1, 100))) 100 | 101 | /* Runs the producer and the consumer in parallel (&>) and expects the elements out in insertion order. */ private def isFifo(q: TQueue[Int]): IO[Unit] = { 102 | val expect = List.range(1, 10) 103 | expect.traverse_(q.enqueue(_).commit[IO]) &> 104 | q.dequeue.commit[IO].replicateA(expect.length).map { got => 105 | assert(got == expect) 106 | } 107 | } 108 | 109 | /* Enqueues a random element and expects it to appear in toString. */ private def checkToString(q: TQueue[Int]): IO[Unit] = { 110 | val n = Random.nextInt() 111 | q.enqueue(n).commit[IO] >> IO(assert(q.toString contains n.toString)) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/problems/CigaretteSmokersProblem.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.problems 2 | 3 | import scala.concurrent.duration._ 4 | import scala.util.Random 5 | 6 | import cats.effect.IO 7 | import com.olegpy.stm.misc.TQueue 8 | import com.olegpy.stm._ 9 | import utest._ 10 | import cats.implicits._ 11 | 12 | 13 | /** Cigarette smokers: a dealer puts two random ingredients on a two-slot table, and each smoker only takes a pair that lacks the smoker's own ingredient. */ object CigaretteSmokersProblem extends NondetIOSuite { 14 | override def ioTimeout: FiniteDuration = 2.seconds 15 | 16 | val tests = Tests { 17 | "Cigarette smokers problem" - { 18 | val attempts = 50 19 | for { 20 | table <- mkTable 21 | deal <- new Dealer(table).dealRandom.replicateA(attempts).start 22 | counter <- TRef.in[IO](0) 23 | /* Counts smoked cigarettes. */ puff = counter.update(_ + 1).commit[IO]// >> nap 24 | smoke <- allIngredients.foldMapM { 25 | new Smoker(_, table).buildACig(puff).foreverM[Unit].start 26 | } 27 | _ <- longNap 28 | /* Waits until every dealt pair has been smoked. */ _ <- counter.get.filter(_ == attempts).commit[IO] 29 | _ <- deal.cancel 30 | _ <- smoke.cancel 31 | } yield () 32 | } 33 | } 34 | 35 | sealed trait Ingredient extends Product with Serializable 36 | case object
Tobacco extends Ingredient 37 | case object Paper extends Ingredient 38 | case object Matches extends Ingredient 39 | 40 | def allIngredients: List[Ingredient] = List(Tobacco, Paper, Matches) 41 | 42 | class Table(queue: TQueue[Ingredient]) { 43 | def put(ingredient: Ingredient): STM[Unit] = queue.enqueue(ingredient) 44 | /* Atomically removes two items from the table. */ def takeThings: STM[Set[Ingredient]] = queue.dequeue.replicateA(2).map(_.toSet) 45 | override def toString: String = s"Table($queue)" 46 | } 47 | 48 | def mkTable: IO[Table] = TQueue.boundedIn[IO, Ingredient](2).map(new Table(_)) 49 | 50 | class Smoker (ingredient: Ingredient, table: Table) { 51 | /* Retries until the two items taken do not include this smoker's own ingredient, then puffs. */ def buildACig(puff: IO[Unit]): IO[Unit] = 52 | table.takeThings.filterNot(_ contains ingredient).commit[IO] >> puff 53 | } 54 | 55 | class Dealer(table: Table) { 56 | private val randomIngredients = IO { Random.shuffle(allIngredients).take(2) } 57 | def dealRandom: IO[Unit] = 58 | randomIngredients.flatMap { _.traverse_(table.put).commit[IO] } 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/problems/DiningPhilosophersProblem.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm.problems 2 | 3 | import cats.effect.IO 4 | import cats.effect.concurrent.Ref 5 | import cats.implicits._ 6 | import com.olegpy.stm.misc.TMVar 7 | import com.olegpy.stm.{NondetIOSuite, STM} 8 | import utest._ 9 | 10 | /** Dining philosophers with forks as TMVar[Unit]: each philosopher takes both forks in a single transaction. */ object DiningPhilosophersProblem extends NondetIOSuite { 11 | val tests = Tests { 12 | "Dining philosophers problem" - { 13 | val philCount = 8 14 | val iterCount = 40 15 | 16 | val mkCycle = for { 17 | philosophers <- Ref[IO].of(0).map(new Philosopher(_)).replicateA(philCount) 18 | leftForks <- TMVar.in[IO, Unit](()).replicateA(philCount) 19 | rightForks = leftForks.tail :+ leftForks.head 20 | } yield (philosophers, leftForks, rightForks) 21 | .parMapN(_.eat(_, _)) // Parallel for List is ZipList, so we zip all forks
together 22 | .parSequence_ // This is IO's Parallel now, do everything concurrently 23 | .>>(philosophers.traverse(_.timesEaten)) // get values out for assertions 24 | 25 | mkCycle.flatMap { cycle => 26 | /* NOTE(review): loop starts at 1 and stops when n == iterCount, so the cycle runs iterCount - 1 times; confirm whether iterCount rounds were intended. */ def loop(n: Int): IO[Unit] = 27 | if (n == iterCount) IO.unit 28 | else cycle.map { list => 29 | assert(list.length == philCount) 30 | assert(list.forall(_ == n)) 31 | } >> loop(n + 1) 32 | loop(1) 33 | } 34 | } 35 | 36 | type Fork = TMVar[Unit] 37 | 38 | class Philosopher (timesEatenRef: Ref[IO, Int]) { 39 | val timesEaten = timesEatenRef.get 40 | def eat(left: Fork, right: Fork): IO[Unit] = for { 41 | // It is ABSOLUTELY essential to the solution to do the takes atomically: we either take 42 | // both forks, or take nothing 43 | _ <- STM.atomically[IO] { left.take >> right.take } 44 | // Pretend we're eating 45 | _ <- IO.shift 46 | _ <- timesEatenRef.update(_ + 1) 47 | // Putting them back atomically isn't strictly necessary; it's just a micro- 48 | // optimization that reduces the number of transactions and reads more concisely 49 | _ <- STM.atomically[IO] { left.put(()) >> right.put(()) } 50 | } yield () 51 | } 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /shared/src/test/scala/com/olegpy/stm/results.scala: -------------------------------------------------------------------------------- 1 | package com.olegpy.stm 2 | 3 | import scala.reflect.ClassTag 4 | 5 | import cats.effect.{Concurrent, IO} 6 | import cats.implicits._ 7 | 8 | /** Test helpers that reify the outcome of a transaction as a value. */ object results { 9 | sealed trait STMResult[+A] { 10 | /* Runtime class check, e.g. r.is[STMAbort]. */ def is[T](implicit ct: ClassTag[T]): Boolean = 11 | ct.runtimeClass.isInstance(this) 12 | } 13 | 14 | case class STMSuccess[+A](value: A) extends STMResult[A] 15 | case class STMAbort(reason: Throwable) extends STMResult[Nothing] 16 | case object STMRetry extends STMResult[Nothing] 17 | 18 | implicit class STMOps[+A](private val self: STM[A]) extends AnyVal { 19 | /* Commits self once and reports the outcome: success becomes STMSuccess, a retry becomes STMRetry (handled by orElse, so this never waits), and any failure becomes STMAbort. */ def result(implicit c: Concurrent[IO]): IO[STMResult[A]] = 20 |
self.map(STMSuccess(_)).orElse(STMRetry.pure[STM]) 21 | .commit[IO].handleError(STMAbort) 22 | } 23 | } 24 | --------------------------------------------------------------------------------