├── .gitignore ├── .gitlab-ci.yml ├── .travis.yml ├── LICENSE ├── README.md ├── coroutines-128-xmas.png ├── coroutines-128.png ├── coroutines-512-xmas.png ├── coroutines-512.png ├── coroutines-64.png ├── coroutines-common └── src │ └── main │ └── scala │ └── scala │ └── coroutines │ └── common │ ├── ByTreeTyper.scala │ ├── ByTreeUntyper.scala │ ├── Cache.scala │ ├── Stack.scala │ └── TraverserUtil.scala ├── coroutines-extra └── src │ ├── main │ └── scala │ │ └── org │ │ └── coroutines │ │ └── extra │ │ ├── AsyncAwait.scala │ │ └── Enumerator.scala │ └── test │ └── scala │ └── org │ └── coroutines │ └── extra │ ├── async-await-tests.scala │ ├── enumerator-tests.scala │ └── enumerators-boxing-tests.scala ├── coroutines.svg ├── cross.conf ├── dependencies.conf ├── project ├── Build.scala ├── plugins.sbt └── project │ └── Build.scala ├── src ├── bench │ └── scala │ │ ├── org │ │ └── coroutines │ │ │ ├── AsyncAwaitBench.scala │ │ │ ├── DataflowVariableBench.scala │ │ │ ├── GraphIteratorBench.scala │ │ │ ├── HashSetIteratorBench.scala │ │ │ ├── RedBlackIteratorBench.scala │ │ │ ├── ScalaCheckBench.scala │ │ │ ├── StreamBench.scala │ │ │ ├── TreeIteratorBench.scala │ │ │ └── data-structures.scala │ │ └── scala │ │ └── collection │ │ └── Backdoor.scala ├── main │ └── scala │ │ └── org │ │ └── coroutines │ │ ├── Analyzer.scala │ │ ├── AstCanonicalization.scala │ │ ├── CfgGenerator.scala │ │ ├── Coroutine.scala │ │ ├── Synthesizer.scala │ │ ├── package.scala │ │ └── specializations.scala └── test │ └── scala │ └── org │ ├── coroutines │ ├── ast-canonicalization-tests.scala │ ├── async-await-tests.scala │ ├── boxing-tests.scala │ ├── coroutine-syntax-tests.scala │ ├── coroutine-tests.scala │ ├── pattern-match-tests.scala │ ├── regression-tests.scala │ ├── snapshot-tests.scala │ ├── try-catch-tests.scala │ └── yieldto-tests.scala │ ├── examples │ ├── AsyncAwait.scala │ ├── Composition.scala │ ├── CompositionCall.scala │ ├── ControlTransfer.scala │ ├── 
ControlTransferWithPull.scala │ ├── Datatypes.scala │ ├── Exceptions.scala │ ├── FaqSimpleExample.scala │ ├── Identity.scala │ ├── Lifecycle.scala │ ├── MockSnapshot.scala │ ├── Snapshot.scala │ ├── VowelCount.scala │ └── examples-tests.scala │ └── separatepackage │ └── SeparatePackageTest.scala └── version.conf /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | tmp 3 | -------------------------------------------------------------------------------- /.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | image: stormenroute/storm-enroute-build:0.15 2 | 3 | before_script: 4 | - | 5 | eval `ssh-agent -s` 6 | echo "exec cat" > ap-cat.sh 7 | chmod a+x ap-cat.sh 8 | export DISPLAY=1 9 | echo "$SSH_KEY_PASS" | SSH_ASKPASS=./ap-cat.sh ssh-add ~/.ssh/id_rsa 10 | rm ap-cat.sh 11 | export SSH_KEY_PASS="" 12 | export DISPLAY="" 13 | - | 14 | mkdir -p ~/.sbt/0.13/ 15 | echo 'scalacOptions ++= Seq("-Xmax-classfile-name", "90")' > ~/.sbt/0.13/local.sbt 16 | - git clone git@ci.storm-enroute.com:storm-enroute/super-storm-enroute.git ~/.super-storm-enroute 17 | 18 | after_script: 19 | - rm -rf ~/.super-storm-enroute 20 | 21 | 22 | gate-ubuntu: 23 | tags: 24 | - ubuntu 25 | - gate 26 | script: 27 | - ~/.super-storm-enroute/tools/ci coroutines $(pwd) 28 | 29 | 30 | gate-osx: 31 | tags: 32 | - osx 33 | - gate 34 | script: 35 | - ~/.super-storm-enroute/tools/ci coroutines $(pwd) 36 | cache: 37 | paths: 38 | - /root/.ivy2/ 39 | - /root/.sbt/ 40 | 41 | 42 | bench-ubuntu: 43 | tags: 44 | - ubuntu 45 | - bench 46 | script: 47 | - ~/.super-storm-enroute/tools/ci --bench=True . 
$(pwd) 48 | cache: 49 | paths: 50 | - /root/.ivy2/ 51 | - /root/.sbt/ 52 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: scala 2 | scala: 3 | - "2.11.7" 4 | jdk: 5 | - oraclejdk8 6 | before_script: 7 | - sudo chmod +x /usr/local/bin/sbt 8 | script: 9 | - sbt ++${TRAVIS_SCALA_VERSION} test 10 | - git clone https://github.com/storm-enroute/dev-tools.git ~/.dev-tools 11 | - ~/.dev-tools/lint -p . 12 | env: 13 | global: 14 | - secure: "adngy4wHe+DLnkLW0K7S8KFe+GR2mTMHx3VPs4YFooXSkJlRZMzhMzd2nJN0VNb2U7zkIrtLvsykLGIaeKF3u02iheHt3RCpRoKmxOjAkFXSRm6V7Z1J+EMVHqAG/72L2P2KkjJEaXrQqE3yG6e6elRk+qp2V3zKpQ6E5sS/g3c=" 15 | - secure: "fjMyfWi+UndcsT+Voqxt1NVvIbqPKwDzipxK18zd+eEgASam+L4fgtmDsIXjbgdBaTX59w+Q1DNnAOT7x34XpneU+GDASDmoNdj6oCoZOHiQb/odu2WOBWf/iINCTpJtPMMr8cLQaQ3CsnSyOojaJiFERDMcO9i58kDfU1gXa/4=" 16 | branches: 17 | only: 18 | - master 19 | notifications: 20 | slack: storm-enroute:GnbA8DEy3mL3Pyp3cbptr7F2 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015-2016, Scala Coroutines 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without modification, 5 | are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, this 11 | list of conditions and the following disclaimer in the documentation and/or 12 | other materials provided with the distribution. 
13 | 14 | * Neither the name of the {organization} nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 19 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 20 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 22 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 23 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 24 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 25 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 26 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 27 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | ![Coroutines](/coroutines-128-xmas.png) 3 | 4 | # Scala Coroutines 5 | 6 | [![Join the chat at https://gitter.im/storm-enroute/coroutines](https://badges.gitter.im/storm-enroute/coroutines.svg)](https://gitter.im/storm-enroute/coroutines?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) 7 | 8 | [Coroutines](http://storm-enroute.com/coroutines) 9 | is a library-level extension for the Scala programming language 10 | that introduces first-class coroutines. 11 | 12 | Check out the [Scala Coroutines website](http://storm-enroute.com/coroutines) for more info! 
13 | 14 | Service | Status | Description 15 | -------------------|--------|------------ 16 | Travis | [![Build Status](https://travis-ci.org/storm-enroute/coroutines.png?branch=master)](https://travis-ci.org/storm-enroute/coroutines) | Testing only 17 | Maven | [![Maven Artifact](https://img.shields.io/maven-central/v/com.storm-enroute/coroutines_2.11.svg)](http://mvnrepository.com/artifact/com.storm-enroute/coroutines_2.11) | Coroutines artifact on Maven 18 | -------------------------------------------------------------------------------- /coroutines-128-xmas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/storm-enroute/coroutines/c72f3fdee2a2dd0b139b0a59e4fc350e53f7610a/coroutines-128-xmas.png -------------------------------------------------------------------------------- /coroutines-128.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/storm-enroute/coroutines/c72f3fdee2a2dd0b139b0a59e4fc350e53f7610a/coroutines-128.png -------------------------------------------------------------------------------- /coroutines-512-xmas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/storm-enroute/coroutines/c72f3fdee2a2dd0b139b0a59e4fc350e53f7610a/coroutines-512-xmas.png -------------------------------------------------------------------------------- /coroutines-512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/storm-enroute/coroutines/c72f3fdee2a2dd0b139b0a59e4fc350e53f7610a/coroutines-512.png -------------------------------------------------------------------------------- /coroutines-64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/storm-enroute/coroutines/c72f3fdee2a2dd0b139b0a59e4fc350e53f7610a/coroutines-64.png 
--------------------------------------------------------------------------------
/coroutines-common/src/main/scala/scala/coroutines/common/ByTreeTyper.scala:
--------------------------------------------------------------------------------
package org.coroutines.common



import scala.collection._
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context



/** Looks up types for subtrees of an untypechecked copy of a typed tree.
 *
 *  On construction, `treeValue` is untypechecked into `untypedTree`, and every
 *  subtree of `untypedTree` is mapped to the subtree with the same shape in the
 *  original typed tree, so that its type can still be queried via `typeOf`.
 *
 *  @param c          the macro context
 *  @param treeValue  a `c.Tree` (declared as `Any` and cast internally)
 */
private[coroutines] class ByTreeTyper[C <: Context](val c: C)(val treeValue: Any) {
  import c.universe._
  private val tree = treeValue.asInstanceOf[Tree]
  private val treeMapping = mutable.Map[Tree, Tree]()
  private val traverser = new TraverserUtil[c.type](c)
  val untypedTree = c.untypecheck(tree)
  traverser.traverseByShape(untypedTree, tree)((t, pt) => treeMapping(t) = pt)

  /** Type table for trees.
   *
   *  `typeOf(t)` returns, in order of preference: a type explicitly assigned with
   *  `typeOf(t) = tpe`, the type of the typed subtree that `t` was mapped to, or
   *  `t.tpe` itself.
   */
  object typeOf {
    private val augmentedTypes = mutable.Map[Tree, Type]()
    def apply(t: Tree) = {
      if (augmentedTypes.contains(t)) augmentedTypes(t)
      else if (treeMapping.contains(t)) treeMapping(t).tpe
      else t.tpe
    }
    def update(t: Tree, tpe: Type) = augmentedTypes(t) = tpe
  }
}
--------------------------------------------------------------------------------
/coroutines-common/src/main/scala/scala/coroutines/common/ByTreeUntyper.scala:
--------------------------------------------------------------------------------
package org.coroutines.common



import scala.collection._
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context



/** Maps subtrees of a typed tree to their counterparts in an untypechecked copy.
 *
 *  @param c          the macro context
 *  @param treeValue  a `c.Tree` (declared as `Any` and cast internally)
 */
private[coroutines] class ByTreeUntyper[C <: Context](val c: C)(val treeValue: Any) {
  import c.universe._
  private val tree = treeValue.asInstanceOf[Tree]
  private val untypedTree = c.untypecheck(tree)
  private val treeMapping = mutable.Map[Tree, Tree]()
  private val traverser = new TraverserUtil[c.type](c)
  traverser.traverseByShape(tree, untypedTree)((t, pt) => treeMapping(t) = pt)

  /** Returns the untyped counterpart of `t`, or `t` itself when `t` was not
   *  recorded as a subtree of the original tree.
   */
  def untypecheck(t: Tree) = treeMapping.getOrElse(t, t)
}
--------------------------------------------------------------------------------
/coroutines-common/src/main/scala/scala/coroutines/common/Cache.scala:
--------------------------------------------------------------------------------
package org.coroutines.common



import scala.collection._



/** Memoizing wrappers for functions of one, two and three arguments.
 *
 *  Results are kept in an unbounded `mutable.Map` keyed by the argument (or the
 *  tuple of arguments) and are never evicted. No synchronization is performed.
 */
object Cache {
  /** Memoizes a unary function: each distinct argument is computed once. */
  class _1[T, S](val function: T => S) {
    val cache = mutable.Map[T, S]()
    def apply(t: T): S = cache.getOrElseUpdate(t, function(t))
  }

  def cached[T, S](f: T => S): _1[T, S] = new _1(f)

  /** Memoizes a binary function, keyed by the argument pair. */
  class _2[T1, T2, S](val function: (T1, T2) => S) {
    val cache = mutable.Map[(T1, T2), S]()
    def apply(t1: T1, t2: T2): S = cache.getOrElseUpdate((t1, t2), function(t1, t2))
  }

  def cached[T1, T2, S](f: (T1, T2) => S): _2[T1, T2, S] = new _2(f)

  /** Memoizes a ternary function, keyed by the argument triple. */
  class _3[T1, T2, T3, S](val function: (T1, T2, T3) => S) {
    val cache = mutable.Map[(T1, T2, T3), S]()
    def apply(t1: T1, t2: T2, t3: T3): S =
      cache.getOrElseUpdate((t1, t2, t3), function(t1, t2, t3))
  }

  def cached[T1, T2, T3, S](f: (T1, T2, T3) => S): _3[T1, T2, T3, S] = new _3(f)
}
--------------------------------------------------------------------------------
/coroutines-common/src/main/scala/scala/coroutines/common/Stack.scala:
--------------------------------------------------------------------------------
package org.coroutines.common



import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context



/** Macros implementing an array-backed stack stored in a pair of fields.
 *
 *  Apart from `init`, every operation expects its stack argument to be a selection
 *  `path.name` and also reads or writes the companion field `path.nameptr`, which
 *  holds the index one past the topmost element (i.e. the current stack size).
 *  All operations are expanded inline at the call site.
 */
object Stack {
  /** Allocates `stack` with capacity `size` if it is currently `null`.
   *
   *  If `size` is the literal `-1`, the call expands to nothing.
   */
  def init[T](stack: Array[T], size: Int): Unit = macro initMacro[T]

  def initMacro[T: c.WeakTypeTag](c: Context)(stack: c.Tree, size: c.Tree): c.Tree = {
    import c.universe._

    val tpe = implicitly[c.WeakTypeTag[T]]
    size match {
      // Trees compare by reference, so the `-1` sentinel must be recognized
      // structurally; `size == q"-1"` against a freshly built tree is never true.
      case Literal(Constant(-1)) =>
        q""
      case _ =>
        q"""
          if ($stack == null) $stack = new _root_.scala.Array[$tpe]($size)
        """
    }
  }

  /** Copies stack `src` into stack `dest`, including the pointer field.
   *
   *  `dest`'s pointer is set to `src`'s pointer. If `src` is non-null, `dest` is
   *  replaced with a fresh array of the same length holding `src`'s first
   *  `srcptr` elements; if `src` is `null`, the `dest` array is left untouched.
   */
  def copy[T](src: Array[T], dest: Array[T]): Unit = macro copyMacro[T]

  def copyMacro[T: c.WeakTypeTag](c: Context)(src: c.Tree, dest: c.Tree): c.Tree = {
    import c.universe._

    val q"$srcpath.${srcname: TermName}" = src
    val srcptrname = TermName(s"${srcname}ptr")
    val srcptr = q"$srcpath.$srcptrname"
    val q"$destpath.${destname: TermName}" = dest
    val destptrname = TermName(s"${destname}ptr")
    val destptr = q"$destpath.$destptrname"
    val tpe = implicitly[WeakTypeTag[T]]

    q"""
      $destptr = $srcptr
      if ($src != null) {
        $dest = new _root_.scala.Array[$tpe]($src.length)
        _root_.java.lang.System.arraycopy($src, 0, $dest, 0, $srcptr)
      }
    """
  }

  /** Pushes `x` onto `stack`.
   *
   *  The stack is first initialized via `init` with capacity `size`; if it is
   *  full, its capacity is doubled before the element is written.
   */
  def push[T](stack: Array[T], x: T, size: Int): Unit = macro pushMacro[T]

  def pushMacro[T: c.WeakTypeTag](c: Context)(
    stack: c.Tree, x: c.Tree, size: c.Tree
  ): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    val tpe = implicitly[WeakTypeTag[T]]
    q"""
      _root_.org.coroutines.common.Stack.init[$tpe]($stack, $size)
      if ($stackptr >= $stack.length) {
        val nstack = new _root_.scala.Array[$tpe]($stack.length * 2)
        _root_.java.lang.System.arraycopy($stack, 0, nstack, 0, $stack.length)
        $stack = nstack
      }
      $stack($stackptr) = $x
      $stackptr += 1
    """
  }

  /** Reserves `n` slots on top of `stack` by advancing its pointer by `n`.
   *
   *  The stack is initialized via `init` with capacity `size`, and its capacity
   *  is doubled until the pointer is below the array length. The reserved slots
   *  are not written.
   */
  def bulkPush[T](stack: Array[T], n: Int, size: Int): Unit = macro bulkPushMacro[T]

  def bulkPushMacro[T: c.WeakTypeTag](c: Context)(
    stack: c.Tree, n: c.Tree, size: c.Tree
  ): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    val tpe = implicitly[WeakTypeTag[T]]
    q"""
      _root_.org.coroutines.common.Stack.init[$tpe]($stack, $size)
      $stackptr += $n
      while ($stackptr >= $stack.length) {
        val nstack = new _root_.scala.Array[$tpe]($stack.length * 2)
        _root_.java.lang.System.arraycopy($stack, 0, nstack, 0, $stack.length)
        $stack = nstack
      }
    """
  }

  /** Removes and returns the top element, clearing its slot with
   *  `null.asInstanceOf[T]`.
   */
  def pop[T](stack: Array[T]): T = macro popMacro[T]

  def popMacro[T: c.WeakTypeTag](c: Context)(stack: c.Tree): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    val tpe = implicitly[WeakTypeTag[T]]
    val valnme = TermName(c.freshName())
    q"""
      $stackptr -= 1
      val $valnme = $stack($stackptr)
      $stack($stackptr) = null.asInstanceOf[$tpe]
      $valnme
    """
  }

  /** Drops the top `n` elements by moving the pointer down; the vacated slots
   *  are not cleared.
   */
  def bulkPop[T](stack: Array[T], n: Int): Unit = macro bulkPopMacro[T]

  def bulkPopMacro[T: c.WeakTypeTag](c: Context)(stack: c.Tree, n: c.Tree): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    q"""
      $stackptr -= $n
    """
  }

  /** Returns the top element without removing it. */
  def top[T](stack: Array[T]): T = macro topMacro[T]

  def topMacro[T: c.WeakTypeTag](c: Context)(stack: c.Tree): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    q"""
      $stack($stackptr - 1)
    """
  }

  /** Returns the element `n` positions below the top (`n == 0` is the top). */
  def get[T](stack: Array[T], n: Int): T = macro getMacro[T]

  def getMacro[T: c.WeakTypeTag](c: Context)(stack: c.Tree, n: c.Tree): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    q"""
      $stack($stackptr - 1 - $n)
    """
  }

  /** Overwrites the element `n` positions below the top (`n == 0` is the top). */
  def set[T](stack: Array[T], n: Int, x: T): Unit = macro setMacro[T]

  def setMacro[T: c.WeakTypeTag](c: Context)(
    stack: c.Tree, n: c.Tree, x: c.Tree
  ): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    q"""
      $stack($stackptr - 1 - $n) = $x
    """
  }

  /** Replaces the top element with `x` and returns the previous top element. */
  def update[T](stack: Array[T], x: T): T = macro updateMacro[T]

  def updateMacro[T: c.WeakTypeTag](c: Context)(stack: c.Tree, x: c.Tree): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    val valnme = TermName(c.freshName())
    q"""
      val $valnme = $stack($stackptr - 1)
      $stack($stackptr - 1) = $x
      $valnme
    """
  }

  /** Returns `true` when the stack pointer is zero or negative. */
  def isEmpty[T](stack: Array[T]): Boolean = macro isEmptyMacro[T]

  def isEmptyMacro[T: c.WeakTypeTag](c: Context)(stack: c.Tree): c.Tree = {
    import c.universe._

    val q"$path.${name: TermName}" = stack
    val stackptrname = TermName(s"${name}ptr")
    val stackptr = q"$path.$stackptrname"
    q"""
      $stackptr <= 0
    """
  }
}
--------------------------------------------------------------------------------
/coroutines-common/src/main/scala/scala/coroutines/common/TraverserUtil.scala:
--------------------------------------------------------------------------------
package org.coroutines.common



import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context



/** Contains extra tree traversal methods.
 */
class TraverserUtil[C <: Context](val c: C) {
  import c.universe._

  /** Traverses equivalent parts of two trees in parallel, applying `f` to each
   *  pair of subtrees that occupy the same position in both trees.
   *
   *  `f` is applied to `(t0, t1)` first, and the traversal then descends into the
   *  children of constructs that both trees share. Pairs matching none of the
   *  listed constructs in both trees (identifiers, literals, or mismatched shapes)
   *  are not descended into. Child lists are paired with `zip`, so surplus
   *  children on either side are ignored.
   */
  def traverseByShape(t0: Tree, t1: Tree)(f: (Tree, Tree) => Unit): Unit = {
    def traverse(t0: Tree, t1: Tree): Unit = {
      f(t0, t1)
      (t0, t1) match {
        case (q"(..$params0) => $body0", q"(..$params1) => $body1") =>
          // function
          for ((p0, p1) <- params0 zip params1) traverse(p0, p1)
          traverse(body0, body1)
        case (q"{ case ..$cs0 }", q"{ case ..$cs1 }") =>
          // partial function
          for ((c0, c1) <- cs0 zip cs1) traverse(c0, c1)
        case (q"if ($c0) $t0 else $e0", q"if ($c1) $t1 else $e1") =>
          // if
          traverse(c0, c1)
          traverse(t0, t1)
          traverse(e0, e1)
        case (q"while ($c0) $b0", q"while ($c1) $b1") =>
          // while loop
          traverse(c0, c1)
          traverse(b0, b1)
        case (q"do $b0 while ($c0)", q"do $b1 while ($c1)") =>
          // do-while loop
          traverse(b0, b1)
          traverse(c0, c1)
        case (q"for (..$enums0) $b0", q"for (..$enums1) $b1") =>
          // for loop
          for ((e0, e1) <- enums0 zip enums1) traverse(e0, e1)
          traverse(b0, b1)
        case (q"for (..$enums0) yield $b0", q"for (..$enums1) yield $b1") =>
          // for-yield loop
          for ((e0, e1) <- enums0 zip enums1) traverse(e0, e1)
          traverse(b0, b1)
        case (
          q"new { ..$eds0 } with ..$ps0 { $self0 => ..$stats0 }",
          q"new { ..$eds1 } with ..$ps1 { $self1 => ..$stats1 }"
        ) =>
          // new
          for ((e0, e1) <- eds0 zip eds1) traverse(e0, e1)
          for ((p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          traverse(self0, self1)
          for ((s0, s1) <- stats0 zip stats1) traverse(s0, s1)
        case (q"$a0[$t0]", q"$a1[$t1]") =>
          // type application
          traverse(a0, a1)
          traverse(t0, t1)
        case (q"$lhs0 = $rhs0", q"$lhs1 = $rhs1") =>
          // update
          traverse(lhs0, lhs1)
          traverse(rhs0, rhs1)
        case (q"return $r0", q"return $r1") =>
          // return
          traverse(r0, r1)
        case (q"throw $e0", q"throw $e1") =>
          // throw
          traverse(e0, e1)
        case (q"$e0: $tpt0", q"$e1: $tpt1") =>
          // ascription
          traverse(e0, e1)
          traverse(tpt0, tpt1)
        case (q"$e0: @$a0", q"$e1: @$a1") =>
          // annotated
          traverse(e0, e1)
          traverse(a0, a1)
        case (q"(..$exprs0)", q"(..$exprs1)") if exprs1.length > 1 =>
          // tuple
          for ((e0, e1) <- exprs0 zip exprs1) traverse(e0, e1)
        case (q"$e0 match { case ..$cs0 }", q"$e1 match { case ..$cs1 }") =>
          // pattern match
          traverse(e0, e1)
          for ((c0, c1) <- cs0 zip cs1) traverse(c0, c1)
        case (
          q"try $b0 catch { case ..$cs0 } finally $f0",
          q"try $b1 catch { case ..$cs1 } finally $f1"
        ) =>
          // try
          traverse(b0, b1)
          for ((c0, c1) <- cs0 zip cs1) traverse(c0, c1)
          traverse(f0, f1)
        case (q"$a0[..$tpts0](...$paramss0)", q"$a1[..$tpts1](...$paramss1)")
        if tpts0.length > 0 || paramss0.length > 0 =>
          // application
          traverse(a0, a1)
          for ((t0, t1) <- tpts0 zip tpts1) traverse(t0, t1)
          for ((ps0, ps1) <- paramss0 zip paramss1; (p0, p1) <- ps0 zip ps1) {
            traverse(p0, p1)
          }
        case (q"$r0.$m0", q"$r1.$m1") =>
          // selection
          traverse(r0, r1)
        case (q"$q0.super[$s0].$n0", q"$q1.super[$s1].$n1") =>
          // super selection
        case (q"$q0.this", q"$q1.this") =>
          // this
        case (q"{ ..$ss0 }", q"{ ..$ss1 }") if ss0.length > 1 && ss1.length > 1 =>
          // stats
          for ((a, b) <- ss0 zip ss1) traverse(a, b)
        case (Block(List(s0), e0), Block(List(s1), e1)) =>
          // stats, single
          traverse(s0, s1)
          traverse(e0, e1)
        case (tq"$tpt0.type", tq"$tpt1.type") =>
          // singleton type
          traverse(tpt0, tpt1)
        case (tq"$r0#$nme0", tq"$r1#$nme1") =>
          // type projection
          traverse(r0, r1)
        case (tq"$r0.$nme0", tq"$r1.$nme1") =>
          // type selection (previously a duplicate of the `#` projection pattern,
          // which made this case unreachable)
          traverse(r0, r1)
        case (tq"$p0.super[$s0].$q0", tq"$p1.super[$s1].$q1") =>
          // super type selection
        case (tq"this.$n0", tq"this.$n1") =>
          // this type projection
        case (tq"$tpt0[..$tps0]", tq"$tpt1[..$tps1]") if tps0.length > 0 =>
          // applied type
          traverse(tpt0, tpt1)
          for ((tp0, tp1) <- tps0 zip tps1) traverse(tp0, tp1)
        case (tq"$tpt0 @$annots0", tq"$tpt1 @$annots1") =>
          // annotated type
          traverse(tpt0, tpt1)
          traverse(annots0, annots1)
        case (tq"..$ps0 { ..$defs0 }", tq"..$ps1 { ..$defs1 }") =>
          // compound type
          for ((p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          for ((d0, d1) <- defs0 zip defs1) traverse(d0, d1)
        case (tq"$tp0 forSome { ..$defs0 }", tq"$tp1 forSome { ..$defs1 }") =>
          // existential type
          traverse(tp0, tp1)
          for ((d0, d1) <- defs0 zip defs1) traverse(d0, d1)
        case (tq"(..$tps0)", tq"(..$tps1)") if tps0.length > 1 =>
          // tuple type
          for ((tp0, tp1) <- tps0 zip tps1) traverse(tp0, tp1)
        case (tq"(..$tps0) => $rt0", tq"(..$tps1) => $rt1") =>
          // function type
          for ((tp0, tp1) <- tps0 zip tps1) traverse(tp0, tp1)
          traverse(rt0, rt1)
        case (pq"_", pq"_") =>
          // wildcard pattern
        case (pq"$n0 @ $p0", pq"$n1 @ $p1") =>
          // binding pattern
          traverse(p0, p1)
        case (pq"$e0(..$pats0)", pq"$e1(..$pats1)") =>
          // extractor pattern
          traverse(e0, e1)
          for ((p0, p1) <- pats0 zip pats1) traverse(p0, p1)
        case (pq"$p0: $tp0", pq"$p1: $tp1") =>
          // type pattern
          traverse(tp0, tp1)
        case (pq"$first0 | ..$rest0", pq"$first1 | ..$rest1") =>
          // alternative pattern
          traverse(first0, first1)
          for ((r0, r1) <- rest0 zip rest1) traverse(r0, r1)
        case (pq"(..$pats0)", pq"(..$pats1)") if pats0.length > 1 =>
          // tuple pattern
          for ((p0, p1) <- pats0 zip pats1) traverse(p0, p1)
        case (q"$_ val $_: $tp0 = $rhs0", q"$_ val $_: $tp1 = $rhs1") =>
          // val
          traverse(tp0, tp1)
          traverse(rhs0, rhs1)
        case (q"$_ var $_: $tp0 = $rhs0", q"$_ var $_: $tp1 = $rhs1") =>
          // var
          traverse(tp0, tp1)
          traverse(rhs0, rhs1)
        case (
          q"$_ def $_[..$tps0](...$pss0): $tp0 = $b0",
          q"$_ def $_[..$tps1](...$pss1): $tp1 = $b1"
        ) =>
          // method
          for ((tp0, tp1) <- tps0 zip tps1) traverse(tp0, tp1)
          for ((ps0, ps1) <- pss0 zip pss1; (p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          traverse(tp0, tp1)
          traverse(b0, b1)
        case (
          q"$_ def this(...$pss0) = this(..$as0)",
          q"$_ def this(...$pss1) = this(..$as1)"
        ) =>
          // secondary constructor
          for ((ps0, ps1) <- pss0 zip pss1; (p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          for ((a0, a1) <- as0 zip as1) traverse(a0, a1)
        case (q"$_ type $_[..$tps0] = $tp0", q"$_ type $_[..$tps1] = $tp1") =>
          // type
          for ((tp0, tp1) <- tps0 zip tps1) traverse(tp0, tp1)
          traverse(tp0, tp1)
        case (
          q"""
            $_ class $_[..$tps0] $_(...$pss0)
            extends { ..$eds0 } with ..$ps0 { $self0 => ..$ss0 }
          """,
          q"""
            $_ class $_[..$tps1] $_(...$pss1)
            extends { ..$eds1 } with ..$ps1 { $self1 => ..$ss1 }
          """
        ) =>
          // class
          for ((tp0, tp1) <- tps0 zip tps1) traverse(tp0, tp1)
          for ((ps0, ps1) <- pss0 zip pss1; (p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          for ((e0, e1) <- eds0 zip eds1) traverse(e0, e1)
          for ((p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          traverse(self0, self1)
          for ((a, b) <- ss0 zip ss1) traverse(a, b)
        case (
          q"""
            $_ trait $_[..$tps0] extends { ..$eds0 } with ..$ps0 { $self0 => ..$ss0 }
          """,
          q"""
            $_ trait $_[..$tps1] extends { ..$eds1 } with ..$ps1 { $self1 => ..$ss1 }
          """
        ) =>
          // trait
          for ((tp0, tp1) <- tps0 zip tps1) traverse(tp0, tp1)
          for ((e0, e1) <- eds0 zip eds1) traverse(e0, e1)
          for ((p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          traverse(self0, self1)
          for ((a, b) <- ss0 zip ss1) traverse(a, b)
        case (
          q"""
            $_ object $_ extends { ..$eds0 } with ..$ps0 { $self0 => ..$ss0 }
          """,
          q"""
            $_ object $_ extends { ..$eds1 } with ..$ps1 { $self1 => ..$ss1 }
          """
        ) =>
          // object
          for ((e0, e1) <- eds0 zip eds1) traverse(e0, e1)
          for ((p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          traverse(self0, self1)
          for ((a, b) <- ss0 zip ss1) traverse(a, b)
        case (
          q"""
            package object $_ extends { ..$eds0 } with ..$ps0 { $self0 => ..$ss0 }
          """,
          q"""
            package object $_ extends { ..$eds1 } with ..$ps1 { $self1 => ..$ss1 }
          """
        ) =>
          // package object
          for ((e0, e1) <- eds0 zip eds1) traverse(e0, e1)
          for ((p0, p1) <- ps0 zip ps1) traverse(p0, p1)
          traverse(self0, self1)
          for ((a, b) <- ss0 zip ss1) traverse(a, b)
        case (q"package $r0 { ..$tops0 }", q"package $r1 { ..$tops1 }") =>
          // package
          traverse(r0, r1)
          for ((t0, t1) <- tops0 zip tops1) traverse(t0, t1)
        case (q"import $r0.{..$ss0}", q"import $r1.{..$ss1}") =>
          // import
          traverse(r0, r1)
          for ((s0, s1) <- ss0 zip ss1) traverse(s0, s1)
        case (cq"$p0 if $c0 => $b0", cq"$p1 if $c1 => $b1") =>
          // case clause
          traverse(p0, p1)
          traverse(c0, c1)
          traverse(b0, b1)
        case (fq"$p0 <- $e0", fq"$p1 <- $e1") =>
          // generator enumerator
          traverse(e0, e1)
          traverse(p0, p1)
        case (fq"$p0 = $e0", fq"$p1 = $e1") =>
          // value definition enumerator
          traverse(e0, e1)
          traverse(p0, p1)
        case (fq"if $e0", fq"if $e1") =>
          // guard enumerator
          traverse(e0, e1)
        case _ =>
          // leaves (identifiers, literals, literal patterns, type identifiers)
          // and pairs whose shapes differ: nothing to descend into
      }
    }
    traverse(t0, t1)
  }
}
--------------------------------------------------------------------------------
/coroutines-extra/src/main/scala/org/coroutines/extra/AsyncAwait.scala:
--------------------------------------------------------------------------------
package org.coroutines.extra



import org.coroutines._
import scala.annotation.unchecked.uncheckedVariance
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent._
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context
import scala.util.{ Success, Failure }



object AsyncAwait {
  /** Await the result of a future.
   *
   *  When called inside an `async` body, this coroutine yields `awaitedFuture`,
   *  which suspends the enclosing `async` body (without blocking a thread) until
   *  `asyncCall` resumes it after the future has completed.
   *
   *  @return A coroutine that yields the awaited `Future` once and then returns
   *          its result. If the future failed, the failure is rethrown; if the
   *          coroutine is resumed before the future completed, it fails via
   *          `sys.error("Future was not completed")`.
   */
  def await[R]: Future[R] ~~> (Future[R], R) =
    coroutine { (awaitedFuture: Future[R]) =>
      yieldval(awaitedFuture)
      var result: R = null.asInstanceOf[R]
      awaitedFuture.value match {
        case Some(Success(x)) => result = x
        case Some(Failure(error)) => throw error
        case None => sys.error("Future was not completed")
      }
      result
    }

  /** Runs the coroutine `body`, suspending it whenever it yields a future that
   *  has not completed yet and resuming it once that future completes.
   *
   *  The first step is started with `Future { ... }` on the global execution
   *  context; later steps run inside the awaited futures' `onComplete` callbacks.
   *
   *  @param body A coroutine to be invoked.
   *  @return A `Future` wrapping the result of the coroutine. The future fails
   *          if `body` throws an exception or one of the `await`s takes a failed
   *          future.
   */
  def asyncCall[Y, R](body: ~~~>[Future[Y], R]): Future[R] = {
    val c = call(body())
    val p = Promise[R]
    def loop(): Unit = {
      if (!c.resume) {
        // The coroutine finished: propagate its result or its exception.
        p.complete(c.tryResult)
      } else {
        val awaitedFuture = c.value
        if (awaitedFuture.isCompleted) {
          loop()
        } else {
          awaitedFuture.onComplete(_ => loop())
        }
      }
    }
    Future { loop() }
    p.future
  }

  /** Wraps `body` inside a coroutine and runs it with `asyncCall`; the wrapping
   *  is done at compile time by `asyncMacro`.
   *
   *  @param body The block of code to wrap inside an asynchronous coroutine.
   *  @return A `Future` wrapping the result of `body`.
   */
  def async[Y, R](body: =>R): Future[R] = macro asyncMacro[Y, R]

  /** Implements `async`.
   *
   *  Rejects `yieldval`/`yieldto` calls in `body`, then wraps `body` inside a
   *  coroutine and calls `asyncCall`.
   *
   *  @param body The function to be wrapped in a coroutine.
   *  @return A tree that contains an invocation of `asyncCall` on a coroutine
   *          with `body` as its body.
   */
  def asyncMacro[Y, R](c: Context)(body: c.Tree): c.Tree = {
    import c.universe._

    /** Ensures that no values are yielded inside the async block.
     *
     *  It is similar to and shares functionality with
     *  [[org.coroutines.AstCanonicalization.NestedContextValidator]].
     */
    class NoYieldsValidator extends Traverser {
      // Recognizes references to the `org.coroutines` package object.
      def isCoroutinesPkg(q: Tree) = q match {
        case q"org.coroutines.`package`" => true
        case q"coroutines.this.`package`" => true
        case _ => false
      }

      override def traverse(tree: Tree): Unit = tree match {
        case q"$qual.yieldval[$_]($_)" if isCoroutinesPkg(qual) =>
          c.abort(tree.pos,
            "The yieldval statement can only be invoked directly inside the coroutine. " +
            "Nested classes, functions or for-comprehensions, should either use the " +
            "call statement or declare another coroutine.")
        case q"$qual.yieldto[$_]($_)" if isCoroutinesPkg(qual) =>
          c.abort(tree.pos,
            "The yieldto statement can only be invoked directly inside the coroutine. " +
            "Nested classes, functions or for-comprehensions, should either use the " +
            "call statement or declare another coroutine.")
        case q"$qual.call($co.apply(..$args))" if isCoroutinesPkg(qual) =>
          // no need to check further, the call macro will validate the coroutine type
        case _ =>
          super.traverse(tree)
      }
    }

    new NoYieldsValidator().traverse(body)

    // A fresh name prevents `body` from capturing the generated val (a fixed name
    // such as `c` would turn any user reference to `c` into a recursive value),
    // and fully qualified references make the expansion independent of the
    // imports at the call site.
    val co = TermName(c.freshName("co"))
    q"""
      val $co = _root_.org.coroutines.coroutine { () =>
        $body
      }
      _root_.org.coroutines.extra.AsyncAwait.asyncCall($co)
    """
  }
}
--------------------------------------------------------------------------------
/coroutines-extra/src/main/scala/org/coroutines/extra/Enumerator.scala:
--------------------------------------------------------------------------------
package org.coroutines.extra



import org.coroutines._
import scala.collection._
import scala.language.experimental.macros
import scala.reflect.macros.whitebox.Context



/** An iterator-like view over the values yielded by a coroutine instance.
 *
 *  Ignores and does nothing with the return value of the coroutine. This makes
 *  specialization simpler and also makes it more straightforward for the user to
 *  create an `Enumerator` from a `Coroutine`.
 *
 *  Takes a `Coroutine.Instance` over a `Coroutine` so that both the constructor is
 *  more general and so that an enumerator can be built from an in-progress coroutine.
 *
 *  Note that construction already calls `instance.pull` once.
 */
class Enumerator[@specialized(Int, Long, Double) Y](instance: Coroutine.Instance[Y, _]) {
  private var _hasNext = instance.pull

  /** Return whether or not the enumerator has a next value.
   *
   *  Internally, this variable is set via calls to `instance.pull`.
   *
   *  @return true iff `next` can be called again without error
   */
  def hasNext(): Boolean = _hasNext

  /** Returns the next value in the enumerator.
   *
   *  Also advances the enumerator to the next return point.
   *
   *  @return The result of `instance.value` after the previous call to `instance.pull`
   */
  def next(): Y = {
    val result = instance.value
    _hasNext = instance.pull
    result
  }
}


object Enumerator {
  /** Creates an enumerator over `c.snapshot` rather than over `c` itself. */
  def apply[Y](c: Coroutine.Instance[Y, _]) = new Enumerator(c.snapshot)

  /** Starts a new instance of the nullary coroutine `c` and enumerates it. */
  def apply[Y](c: Coroutine._0[Y, _]) = new Enumerator(call(c()))

  /** Wraps `body` in a nullary coroutine and enumerates the values it yields. */
  def apply[Y, R](body: =>R): Enumerator[Y] = macro applyMacro[Y, R]

  def applyMacro[Y, R](c: Context)(body: c.Tree): c.Tree = {
    import c.universe._

    // Fully qualified so the expansion does not depend on the names in scope at
    // the call site (e.g. a missing `org.coroutines._` import or another
    // `Enumerator`).
    q"""
      _root_.org.coroutines.extra.Enumerator(_root_.org.coroutines.coroutine { () =>
        $body
      })
    """
  }
}
--------------------------------------------------------------------------------
/coroutines-extra/src/test/scala/org/coroutines/extra/async-await-tests.scala:
--------------------------------------------------------------------------------
1 | package org.coroutines.extra 2 | 3 | 4 | 5 | import org.coroutines._ 6 | import org.scalatest._ 7 | import scala.annotation.unchecked.uncheckedVariance 8 | import scala.concurrent.ExecutionContext.Implicits.global 9 | import scala.concurrent._ 10 | import scala.concurrent.duration._ 11 | import scala.language.{ reflectiveCalls, postfixOps } 12 | import scala.util.Success 13 | 14 | 15 | 16 | class TestException(msg: String = "") extends Throwable(msg) 17 | 18 | 19 | class AsyncAwaitTest extends FunSuite with Matchers { 20 | import AsyncAwait._ 21 | 22 | /** Source: https://git.io/vorXv 23 | * The use of Async/Await as opposed to pure futures allows this control flow 24 | * to be written more easily.
25 | * The execution blocks when awaiting for the result of `f1`. `f2` only blocks 26 | * after `AsyncAwait.await(f1)` evaluates to `true`. 27 | */ 28 | test("simple test") { 29 | val future = async { 30 | val f1 = Future(true) 31 | val f2 = Future(42) 32 | if (await(f1)) { 33 | await(f2) 34 | } else { 35 | 0 36 | } 37 | } 38 | assert(Await.result(future, 1 seconds) == 42) 39 | } 40 | 41 | /** Asynchronous blocks of code can be defined either outside of or within any 42 | * part of an `async` block. This allows the user to avoid triggering the 43 | * computation of slow futures until it is necessary. 44 | * For instance, computation will not begin on `innerFuture` until 45 | * `await(trueFuture)` evaluates to true. 46 | */ 47 | test("nested async blocks") { 48 | val outerFuture = async { 49 | val trueFuture = Future { true } 50 | if (await(trueFuture)) { 51 | val innerFuture = async { 52 | await(Future { 100 } ) 53 | } 54 | await(innerFuture) 55 | } else { 56 | 200 57 | } 58 | } 59 | assert(Await.result(outerFuture, 1 seconds) == 100) 60 | } 61 | 62 | /** Uncaught exceptions thrown inside async blocks cause the associated futures 63 | * to fail. 64 | */ 65 | test("error handling test 1") { 66 | val errorMessage = "System error!" 
67 | val exception = intercept[RuntimeException] { 68 | val future = async { 69 | sys.error(errorMessage) 70 | await(Future("dog")) 71 | } 72 | val result = Await.result(future, 1 seconds) 73 | } 74 | assert(exception.getMessage == errorMessage) 75 | } 76 | 77 | test("error handling test 2") { 78 | val errorMessage = "Internal await error" 79 | val exception = intercept[RuntimeException] { 80 | val future = async { 81 | await(Future { 82 | sys.error(errorMessage) 83 | "Here ya go" 84 | }) 85 | } 86 | val result = Await.result(future, 1 seconds) 87 | } 88 | assert(exception.getMessage == errorMessage) 89 | } 90 | 91 | test("no yields allowed inside async statements 1") { 92 | """val future = AsyncAwait.async { 93 | yieldval("hubba") 94 | Future(1) 95 | }""" shouldNot compile 96 | } 97 | 98 | test("no yields allowed inside async statements 2") { 99 | val c = coroutine { () => 100 | yieldval(0) 101 | } 102 | val instance = call(c()) 103 | 104 | """val future = AsyncAwait.async { 105 | yieldto(instance) 106 | Future(1) 107 | }""" shouldNot compile 108 | } 109 | 110 | /** Source: https://git.io/vowde 111 | * Without the closing `()`, the compiler complains about expecting return 112 | * type `Future[Unit]` but finding `Future[Nothing]`. 113 | */ 114 | test("uncaught exception within async after await") { 115 | val future = async { 116 | await(Future(())) 117 | throw new TestException 118 | () 119 | } 120 | intercept[TestException] { 121 | Await.result(future, 1 seconds) 122 | } 123 | } 124 | 125 | // Source: https://git.io/vowdk 126 | test("await failing future within async") { 127 | val base = Future[Int] { throw new TestException } 128 | val future = async { 129 | val x = await(base) 130 | x * 2 131 | } 132 | intercept[TestException] { Await.result(future, 1 seconds) } 133 | } 134 | 135 | /** Source: https://git.io/vowdY 136 | * Exceptions thrown inside `await` calls are properly bubbled up. They cause 137 | * the async block's future to fail. 
138 | */ 139 | test("await failing future within async after await") { 140 | val base = Future[Any] { "five!".length } 141 | val future = async { 142 | val a = await(base.mapTo[Int]) 143 | val b = await(Future { (a * 2).toString }.mapTo[Int]) 144 | val c = await(Future { (7 * 2).toString }) 145 | b + "-" + c 146 | } 147 | intercept[ClassCastException] { 148 | Await.result(future, 1 seconds) 149 | } 150 | } 151 | 152 | test("nested failing future within async after await") { 153 | val base = Future[Any] { "five!".length } 154 | val future = async { 155 | val a = await(base.mapTo[Int]) 156 | val b = await( 157 | await(Future((Future { (a * 2).toString }).mapTo[Int]))) 158 | val c = await(Future { (7 * 2).toString }) 159 | b + "-" + c 160 | } 161 | intercept[ClassCastException] { 162 | Await.result(future, 1 seconds) 163 | } 164 | } 165 | 166 | test("await should bubble up exceptions") { 167 | def thrower() = { 168 | throw new TestException 169 | Future(1) 170 | } 171 | 172 | var exceptionFound = false 173 | val future = async { 174 | try { 175 | await(thrower()) 176 | () 177 | } catch { 178 | case _: TestException => exceptionFound = true 179 | } 180 | } 181 | val r = Await.result(future, 1 seconds) 182 | assert(exceptionFound) 183 | } 184 | 185 | test("await should bubble up exceptions from failed futures") { 186 | def failer(): Future[Int] = { 187 | Future.failed(new TestException("kaboom")) 188 | } 189 | 190 | var exceptionFound = false 191 | val future = async { 192 | try { 193 | await(failer()) 194 | () 195 | } catch { 196 | case _: TestException => exceptionFound = true 197 | } 198 | } 199 | val r = Await.result(future, 1 seconds) 200 | assert(exceptionFound) 201 | } 202 | } 203 | -------------------------------------------------------------------------------- /coroutines-extra/src/test/scala/org/coroutines/extra/enumerator-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines.extra 2 | 3 | 4 | 
5 | import org.coroutines._ 6 | import org.scalatest._ 7 | import scala.collection._ 8 | 9 | 10 | 11 | class EnumeratorsTest extends FunSuite with Matchers { 12 | val rube = coroutine { () => 13 | yieldval(1) 14 | yieldval(2) 15 | yieldval(3) 16 | } 17 | 18 | // Asserts that `apply` takes a `snapshot` of the instance, so the same instance can be enumerated twice. 19 | test("enumerator creation from coroutine instance") { 20 | val instance = call(rube()) 21 | 22 | val enumerator1 = Enumerator(instance) 23 | assert(enumerator1.hasNext()) 24 | assert(enumerator1.next == 1) 25 | assert(enumerator1.next == 2) 26 | assert(enumerator1.next == 3) 27 | assert(!enumerator1.hasNext) 28 | 29 | val enumerator2 = Enumerator(instance) 30 | assert(enumerator2.hasNext()) 31 | assert(enumerator2.next == 1) 32 | assert(enumerator2.next == 2) 33 | assert(enumerator2.next == 3) 34 | assert(!enumerator2.hasNext) 35 | } 36 | 37 | /** Asserts that more than one `Enumerator` can be created from the same 38 | * `Coroutine._0`. 39 | */ 40 | test("enumerator creation from coroutine_0") { 41 | val enumerator1 = Enumerator(rube) 42 | assert(enumerator1.hasNext()) 43 | assert(enumerator1.next == 1) 44 | assert(enumerator1.next == 2) 45 | assert(enumerator1.next == 3) 46 | assert(!enumerator1.hasNext) 47 | 48 | val enumerator2 = Enumerator(rube) 49 | assert(enumerator2.hasNext()) 50 | assert(enumerator2.next == 1) 51 | assert(enumerator2.next == 2) 52 | assert(enumerator2.next == 3) 53 | assert(!enumerator2.hasNext) 54 | } 55 | 56 | test("enumerator creation from code block") { 57 | val enumerator = Enumerator { 58 | var i = 0 59 | while (i < 5) { 60 | yieldval(i) 61 | i += 1 62 | } 63 | } 64 | assert(enumerator.hasNext()) 65 | for (i <- 0 until 5) { 66 | assert(enumerator.next == i) 67 | } 68 | assert(!enumerator.hasNext) 69 | } 70 | 71 | test("enumerator should ignore return value of coroutine") { 72 | val rubeWithReturn = coroutine { () => 73 | yieldval(1) 74 | yieldval(2) 75 | yieldval(3) 76 | "foo" 77 | }
78 | val enumerator = Enumerator(rubeWithReturn) 79 | assert(enumerator.hasNext()) 80 | assert(enumerator.next == 1) 81 | assert(enumerator.next == 2) 82 | assert(enumerator.next == 3) 83 | assert(!enumerator.hasNext) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /coroutines-extra/src/test/scala/org/coroutines/extra/enumerators-boxing-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines.extra 2 | 3 | 4 | 5 | import org.coroutines._ 6 | import org.scalameter.api._ 7 | import org.scalameter.japi.JBench 8 | import org.scalameter.picklers.noPickler._ 9 | import org.scalameter.execution.invocation._ 10 | 11 | // Checks that creating `Enumerator`s over `Int`-yielding coroutines does not box: 12 | // `noBoxingContext` requires the summed boxing count of each benchmark to be 0. 13 | class EnumeratorsBoxingBench extends JBench.Forked[Long] { 14 | override def defaultConfig = Context( 15 | exec.minWarmupRuns -> 2, 16 | exec.maxWarmupRuns -> 5, 17 | exec.independentSamples -> 1, 18 | verbose -> false 19 | ) 20 | 21 | def measurer = 22 | for (table <- Measurer.BoxingCount.allWithoutBoolean()) yield { 23 | table.copy(value = table.value.valuesIterator.sum) 24 | } 25 | 26 | def aggregator = Aggregator.median 27 | 28 | override def reporter = Reporter.Composite( 29 | LoggingReporter(), 30 | ValidationReporter() 31 | ) 32 | 33 | val sizes = Gen.single("size")(1000) 34 | 35 | val noBoxingContext = Context( 36 | reports.validation.predicate -> { (n: Any) => n == 0 } 37 | ) 38 | 39 | @gen("sizes") 40 | @benchmark("coroutines.extra.boxing.apply.instance.noReturn") 41 | @curve("coroutine") 42 | @ctx("noBoxingContext") 43 | def applyInstanceTestNoReturn(size: Int) { 44 | val id = coroutine { (n: Int) => 45 | var i = 0 46 | while (i < n) { 47 | yieldval(i) 48 | i += 1 49 | } 50 | } 51 | 52 | val instance = call(id(size)) 53 | val enumerator = Enumerator(instance) 54 | } 55 | 56 | @gen("sizes") 57 | @benchmark("coroutines.extra.boxing.apply.coroutine_0.noReturn") 58 | @curve("coroutine") 59 | @ctx("noBoxingContext")
60 | def applyCoroutine_0TestNoReturn(size: Int) { 61 | val rube = coroutine { () => 62 | yieldval(1) 63 | yieldval(2) 64 | yieldval(3) 65 | } 66 | val enumerator = Enumerator(rube) 67 | } 68 | 69 | @gen("sizes") 70 | @benchmark("coroutines.extra.boxing.apply.instance.return") 71 | @curve("coroutine") 72 | @ctx("noBoxingContext") 73 | def applyInstanceTestReturn(size: Int) { 74 | val idDifferentReturnType = coroutine { (n: Int) => 75 | var i = 0 76 | while (i < n) { 77 | yieldval(i) 78 | i += 1 79 | } 80 | "foo" 81 | } 82 | 83 | val idSameReturnType = coroutine { (n: Int) => 84 | var i = 0 85 | while (i < n) { 86 | yieldval(i) 87 | i += 1 88 | } 89 | 5 90 | } 91 | 92 | 93 | val foo = Enumerator(call(idDifferentReturnType(size))) 94 | val bar = Enumerator(call(idSameReturnType(size))) 95 | } 96 | 97 | @gen("sizes") 98 | @benchmark("coroutines.extra.boxing.apply.coroutine_0.return") 99 | @curve("coroutine") 100 | @ctx("noBoxingContext") 101 | def applyCoroutine_0TestReturn(size: Int) { 102 | val rubeDifferentReturnType = coroutine { () => 103 | yieldval(1) 104 | yieldval(2) 105 | yieldval(3) 106 | "bar" 107 | } 108 | val rubeSameReturnType = coroutine { () => 109 | yieldval(1) 110 | yieldval(2) 111 | yieldval(3) 112 | 1 113 | } 114 | val foo = Enumerator(rubeDifferentReturnType) 115 | val bar = Enumerator(rubeSameReturnType) 116 | } 117 | } 118 | -------------------------------------------------------------------------------- /coroutines.svg: -------------------------------------------------------------------------------- 1 | 2 | 22 | 24 | 25 | 27 | image/svg+xml 28 | 30 | 31 | 32 | 33 | 34 | 55 | 57 | 63 | 67 | 71 | 72 | 78 | 83 | 88 | 89 | 99 | 109 | 119 | 120 | 125 | 134 | 143 | 144 | -------------------------------------------------------------------------------- /cross.conf: -------------------------------------------------------------------------------- 1 | 2.11.4 2 | --------------------------------------------------------------------------------
/dependencies.conf: -------------------------------------------------------------------------------- 1 | 2 | coroutines-common = [] 3 | 4 | coroutines = [ 5 | { 6 | repo = "scalameter" 7 | project = "scalameter" 8 | artifact = ["com.storm-enroute", "scalameter", "0.9-SNAPSHOT", "test;bench"] 9 | } 10 | ] 11 | 12 | coroutines-extra = [] 13 | -------------------------------------------------------------------------------- /project/Build.scala: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import java.io._ 5 | import org.stormenroute.mecha._ 6 | import sbt._ 7 | import sbt.Keys._ 8 | import sbt.Process._ 9 | 10 | 11 | 12 | object CoroutinesBuild extends MechaRepoBuild { 13 | 14 | def repoName = "coroutines" 15 | 16 | /* coroutines */ 17 | 18 | val frameworkVersion = Def.setting { 19 | ConfigParsers.versionFromFile( 20 | (baseDirectory in coroutines).value / "version.conf", 21 | List("coroutines_major", "coroutines_minor")) 22 | } 23 | 24 | val coroutinesCrossScalaVersions = Def.setting { 25 | val dir = (baseDirectory in coroutines).value 26 | val path = dir + File.separator + "cross.conf" 27 | scala.io.Source.fromFile(path).getLines.filter(_.trim != "").toSeq 28 | } 29 | 30 | val coroutinesScalaVersion = Def.setting { 31 | coroutinesCrossScalaVersions.value.head 32 | } 33 | 34 | val coroutinesSettings = MechaRepoPlugin.defaultSettings ++ Seq( 35 | name := "coroutines", 36 | organization := "com.storm-enroute", 37 | version <<= frameworkVersion, 38 | scalaVersion <<= coroutinesScalaVersion, 39 | crossScalaVersions <<= coroutinesCrossScalaVersions, 40 | libraryDependencies <++= (scalaVersion)(sv => dependencies(sv)), 41 | libraryDependencies ++= superRepoDependencies("coroutines"), 42 | testFrameworks += new TestFramework("org.scalameter.ScalaMeterFramework"), 43 | scalacOptions ++= Seq( 44 | "-deprecation", 45 | "-unchecked", 46 | "-optimise", 47 | "-Yinline-warnings" 48 | ), 49 | resolvers ++= Seq(
50 | "Sonatype OSS Snapshots" at 51 | "https://oss.sonatype.org/content/repositories/snapshots", 52 | "Sonatype OSS Releases" at 53 | "https://oss.sonatype.org/content/repositories/releases" 54 | ), 55 | ivyLoggingLevel in ThisBuild := UpdateLogging.Quiet, 56 | publishMavenStyle := true, 57 | publishTo <<= version { (v: String) => 58 | val nexus = "https://oss.sonatype.org/" 59 | if (v.trim.endsWith("SNAPSHOT")) 60 | Some("snapshots" at nexus + "content/repositories/snapshots") 61 | else 62 | Some("releases" at nexus + "service/local/staging/deploy/maven2") 63 | }, 64 | publishArtifact in Test := false, 65 | pomIncludeRepository := { _ => false }, 66 | pomExtra := 67 | http://storm-enroute.com/ 68 | 69 | 70 | BSD-style 71 | http://opensource.org/licenses/BSD-3-Clause 72 | repo 73 | 74 | 75 | 76 | git@github.com:storm-enroute/coroutines.git 77 | scm:git:git@github.com:storm-enroute/coroutines.git 78 | 79 | 80 | 81 | axel22 82 | Aleksandar Prokopec 83 | http://axel22.github.com/ 84 | 85 | , 86 | mechaPublishKey <<= mechaPublishKey.dependsOn(publish), 87 | mechaDocsRepoKey := "git@github.com:storm-enroute/apidocs.git", 88 | mechaDocsBranchKey := "gh-pages", 89 | mechaDocsPathKey := "coroutines" 90 | ) 91 | 92 | def dependencies(scalaVersion: String) = 93 | CrossVersion.partialVersion(scalaVersion) match { 94 | case Some((2, major)) if major >= 11 => Seq( 95 | "org.scalatest" % "scalatest_2.11" % "2.2.6" % "test", 96 | "org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.2", 97 | "org.scala-lang" % "scala-reflect" % "2.11.4", 98 | "org.scala-lang.modules" % "scala-async_2.11" % "0.9.5" % "test;bench" 99 | ) 100 | case _ => Nil 101 | } 102 | 103 | val coroutinesCommonSettings = MechaRepoPlugin.defaultSettings ++ Seq( 104 | name := "coroutines-common", 105 | organization := "com.storm-enroute", 106 | version <<= frameworkVersion, 107 | scalaVersion <<= coroutinesScalaVersion, 108 | crossScalaVersions <<= coroutinesCrossScalaVersions,
109 | libraryDependencies <++= (scalaVersion)(sv => commonDependencies(sv)), 110 | libraryDependencies ++= superRepoDependencies("coroutines-common"), 111 | scalacOptions ++= Seq( 112 | "-deprecation", 113 | "-unchecked", 114 | "-optimise", 115 | "-Yinline-warnings" 116 | ), 117 | resolvers ++= Seq( 118 | "Sonatype OSS Snapshots" at 119 | "https://oss.sonatype.org/content/repositories/snapshots", 120 | "Sonatype OSS Releases" at 121 | "https://oss.sonatype.org/content/repositories/releases" 122 | ), 123 | ivyLoggingLevel in ThisBuild := UpdateLogging.Quiet, 124 | publishMavenStyle := true, 125 | publishTo <<= version { (v: String) => 126 | val nexus = "https://oss.sonatype.org/" 127 | if (v.trim.endsWith("SNAPSHOT")) 128 | Some("snapshots" at nexus + "content/repositories/snapshots") 129 | else 130 | Some("releases" at nexus + "service/local/staging/deploy/maven2") 131 | }, 132 | publishArtifact in Test := false, 133 | pomIncludeRepository := { _ => false }, 134 | pomExtra := 135 | http://storm-enroute.com/ 136 | 137 | 138 | BSD-style 139 | http://opensource.org/licenses/BSD-3-Clause 140 | repo 141 | 142 | 143 | 144 | git@github.com:storm-enroute/coroutines.git 145 | scm:git:git@github.com:storm-enroute/coroutines.git 146 | 147 | 148 | 149 | axel22 150 | Aleksandar Prokopec 151 | http://axel22.github.com/ 152 | 153 | , 154 | mechaPublishKey <<= mechaPublishKey.dependsOn(publish), 155 | mechaDocsRepoKey := "git@github.com:storm-enroute/apidocs.git", 156 | mechaDocsBranchKey := "gh-pages", 157 | mechaDocsPathKey := "coroutines-common" 158 | ) 159 | 160 | val coroutinesExtraSettings = MechaRepoPlugin.defaultSettings ++ Seq( 161 | name := "coroutines-extra", 162 | organization := "com.storm-enroute", 163 | version <<= frameworkVersion, 164 | scalaVersion <<= coroutinesScalaVersion, 165 | crossScalaVersions <<= coroutinesCrossScalaVersions, 166 | libraryDependencies <++= (scalaVersion)(sv => extraDependencies(sv)), 167 | testFrameworks += new TestFramework("org.scalameter.ScalaMeterFramework"),
168 | scalacOptions ++= Seq( 169 | "-deprecation", 170 | "-unchecked", 171 | "-optimise", 172 | "-Yinline-warnings" 173 | ), 174 | resolvers ++= Seq( 175 | "Sonatype OSS Snapshots" at 176 | "https://oss.sonatype.org/content/repositories/snapshots", 177 | "Sonatype OSS Releases" at 178 | "https://oss.sonatype.org/content/repositories/releases" 179 | ), 180 | ivyLoggingLevel in ThisBuild := UpdateLogging.Quiet, 181 | publishMavenStyle := true, 182 | publishTo <<= version { (v: String) => 183 | val nexus = "https://oss.sonatype.org/" 184 | if (v.trim.endsWith("SNAPSHOT")) 185 | Some("snapshots" at nexus + "content/repositories/snapshots") 186 | else 187 | Some("releases" at nexus + "service/local/staging/deploy/maven2") 188 | }, 189 | publishArtifact in Test := false, 190 | pomIncludeRepository := { _ => false }, 191 | pomExtra := 192 | http://storm-enroute.com/ 193 | 194 | 195 | BSD-style 196 | http://opensource.org/licenses/BSD-3-Clause 197 | repo 198 | 199 | 200 | 201 | git@github.com:storm-enroute/coroutines.git 202 | scm:git:git@github.com:storm-enroute/coroutines.git 203 | 204 | 205 | 206 | axel22 207 | Aleksandar Prokopec 208 | http://axel22.github.com/ 209 | 210 | , 211 | mechaPublishKey <<= mechaPublishKey.dependsOn(publish), 212 | mechaDocsRepoKey := "git@github.com:storm-enroute/apidocs.git", 213 | mechaDocsBranchKey := "gh-pages", 214 | mechaDocsPathKey := "coroutines-extra" 215 | ) 216 | 217 | def commonDependencies(scalaVersion: String) = 218 | CrossVersion.partialVersion(scalaVersion) match { 219 | case Some((2, major)) if major >= 11 => Seq( 220 | "org.scalatest" % "scalatest_2.11" % "2.2.6" % "test", 221 | "org.scala-lang.modules" %% "scala-parser-combinators" % "1.0.2", 222 | "org.scala-lang" % "scala-reflect" % "2.11.4" 223 | ) 224 | case _ => Nil 225 | } 226 | 227 | def extraDependencies(scalaVersion: String) = 228 | CrossVersion.partialVersion(scalaVersion) match { 229 | case Some((2, major)) if major >= 11 => Seq(
230 | "org.scalatest" % "scalatest_2.11" % "2.2.6" % "test" 231 | ) 232 | case _ => Nil 233 | } 234 | 235 | lazy val Benchmarks = config("bench") extend (Test) 236 | 237 | lazy val coroutines: Project = Project( 238 | "coroutines", 239 | file("."), 240 | settings = coroutinesSettings 241 | ) configs( 242 | Benchmarks 243 | ) settings( 244 | inConfig(Benchmarks)(Defaults.testSettings): _* 245 | ) aggregate( 246 | coroutinesCommon 247 | ) dependsOn( 248 | coroutinesCommon % "compile->compile;test->test" 249 | ) dependsOnSuperRepo 250 | 251 | lazy val coroutinesCommon: Project = Project( 252 | "coroutines-common", 253 | file("coroutines-common"), 254 | settings = coroutinesCommonSettings 255 | ) dependsOnSuperRepo 256 | 257 | // NOTE(review): not listed in the `coroutines` aggregate, so root-level tasks skip this project; confirm intended. 258 | lazy val coroutinesExtra: Project = Project( 259 | "coroutines-extra", 260 | file("coroutines-extra"), 261 | settings = coroutinesExtraSettings 262 | ) dependsOn( 263 | coroutines % "compile->compile;test->test" 264 | ) dependsOnSuperRepo 265 | } 266 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | 2 | addSbtPlugin("com.typesafe.sbt" % "sbt-pgp" % "0.8.1") 3 | 4 | addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") 5 | -------------------------------------------------------------------------------- /project/project/Build.scala: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import java.io.File 5 | import sbt._ 6 | import sbt.Keys._ 7 | 8 | 9 | 10 | object Plugins extends Build { 11 | val mechadir = new File(s"mecha") 12 | val mechaPlugin = { 13 | if (mechadir.exists) ProjectRef(file("../../mecha"), "mecha") 14 | else ProjectRef(uri("git://github.com/storm-enroute/mecha.git"), "mecha") 15 | } 16 | 17 | lazy val build = Project( 18 | "coroutines-build", 19 | file(".") 20 | ).dependsOn(mechaPlugin) 21 | 22 | // boilerplate due to:
23 | // https://github.com/sbt/sbt/issues/895 24 | 25 | // Return our new resolver by default 26 | override def buildLoaders = 27 | BuildLoader.resolve(gitResolver) +: super.buildLoaders 28 | 29 | // Define a new build resolver to wrap the original git one 30 | def gitResolver(info: BuildLoader.ResolveInfo): Option[() => File] = 31 | if (info.uri.getScheme != "git") 32 | None 33 | else { 34 | // Use a subdirectory of the staging directory for the new plugin build. 35 | // The subdirectory name is derived from a hash of the URI, 36 | // and so identical URIs will resolve to the same directory. 37 | val hashDir = new File(info.staging, 38 | Hash.halfHashString(info.uri.normalize.toASCIIString)) 39 | hashDir.mkdirs() 40 | 41 | // Return the original git resolver that will do the actual work. 42 | Resolvers.git(info) 43 | } 44 | } -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/AsyncAwaitBench.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalameter.api._ 6 | import org.scalameter.japi.JBench 7 | import scala.async.Async.async 8 | import scala.async.Async.await 9 | import scala.collection._ 10 | import scala.concurrent._ 11 | import scala.concurrent.duration._ 12 | import scala.concurrent.ExecutionContext.Implicits.global 13 | 14 | 15 | 16 | class AsyncAwaitBench extends JBench.OfflineReport { 17 | override def defaultConfig = Context( 18 | exec.minWarmupRuns -> 100, 19 | exec.maxWarmupRuns -> 100, 20 | exec.benchRuns -> 36, 21 | exec.independentSamples -> 4, 22 | verbose -> true 23 | ) 24 | 25 | val sizes = Gen.range("size")(5000, 25000, 5000) 26 | 27 | val delayedSizes = Gen.range("size")(5, 25, 5) 28 | 29 | private def request(i: Int): Future[Unit] = Future { () } 30 | 31 | private def delayedRequest(i: Int): Future[Unit] = Future { Thread.sleep(1) } 32 | 33 | @gen("sizes")
34 | @benchmark("coroutines.async.request-reply") 35 | @curve("async") 36 | def asyncAwait(sz: Int) = { 37 | val done = async { 38 | var i = 0 39 | while (i < sz) { 40 | val reply = await(request(i)) 41 | i += 1 42 | } 43 | } 44 | Await.result(done, 10.seconds) 45 | } 46 | 47 | @gen("delayedSizes") 48 | @benchmark("coroutines.async.delayed-request-reply") 49 | @curve("async") 50 | def delayedAsyncAwait(sz: Int) = { 51 | val done = async { 52 | var i = 0 53 | while (i < sz) { 54 | val reply = await(delayedRequest(i)) 55 | i += 1 56 | } 57 | } 58 | Await.result(done, 10.seconds) 59 | } 60 | // Drives `f`, resuming it whenever its yielded future completes; the result or failure of `f` completes the returned future. 61 | def coroutineAsync[Y, T](f: Coroutine._0[Future[Y], T]): Future[T] = { 62 | val c = call(f()) 63 | val p = Promise[T]() 64 | def loop() { 65 | if (!c.resume) p.complete(c.tryResult) 66 | else c.value.onComplete { 67 | case _ => loop() 68 | } 69 | } 70 | Future { loop() } 71 | p.future 72 | } 73 | // Yields `f` to the driver, then returns its value (rethrowing if `f` failed) once resumed. 74 | def coroutineAwait[T]: Coroutine._1[Future[T], Future[T], T] = coroutine { 75 | (f: Future[T]) => 76 | yieldval(f) 77 | f.value.get.get 78 | } 79 | 80 | @gen("sizes") 81 | @benchmark("coroutines.async.request-reply") 82 | @curve("coroutine") 83 | def coroutineAsyncAwait(sz: Int) = { 84 | val done = coroutineAsync { 85 | coroutine { () => 86 | var i = 0 87 | while (i < sz) { 88 | val reply = coroutineAwait(request(i)) 89 | i += 1 90 | } 91 | } 92 | } 93 | Await.result(done, 10.seconds) 94 | } 95 | 96 | @gen("delayedSizes") 97 | @benchmark("coroutines.async.delayed-request-reply") 98 | @curve("coroutine") 99 | def delayedCoroutineAsyncAwait(sz: Int) = { 100 | val done = coroutineAsync { 101 | coroutine { () => 102 | var i = 0 103 | while (i < sz) { 104 | val reply = coroutineAwait(delayedRequest(i)) 105 | i += 1 106 | } 107 | } 108 | } 109 | Await.result(done, 10.seconds) 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/DataflowVariableBench.scala:
-------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import java.util.concurrent.ForkJoinPool 6 | import java.util.concurrent.atomic._ 7 | import org.scalameter.api._ 8 | import org.scalameter.japi.JBench 9 | import scala.annotation.tailrec 10 | import scala.collection._ 11 | import scala.concurrent._ 12 | import scala.concurrent.duration._ 13 | import scala.concurrent.ExecutionContext.Implicits.global 14 | import scala.util.Success 15 | import scala.util.Failure 16 | 17 | 18 | 19 | class DataflowVariableBench extends JBench.OfflineReport { 20 | override def defaultConfig = Context( 21 | exec.minWarmupRuns -> 100, 22 | exec.maxWarmupRuns -> 200, 23 | exec.benchRuns -> 36, 24 | exec.independentSamples -> 4, 25 | verbose -> true 26 | ) 27 | 28 | val sizes = Gen.range("size")(50000, 250000, 50000) 29 | 30 | val TOKENS = 100 31 | 32 | class FutureDataflowVar[T] { 33 | private val p = Promise[T]() 34 | def apply(cont: T => Unit): Unit = p.future.onComplete { 35 | case Success(x) => cont(x) 36 | case Failure(t) => throw t 37 | } 38 | def :=(x: T): Unit = p.success(x) 39 | } 40 | 41 | class FutureDataflowStream[T](val head: T) { 42 | val tail = new FutureDataflowVar[FutureDataflowStream[T]] 43 | } 44 | 45 | @gen("sizes") 46 | @benchmark("coroutines.dataflow.producer-consumer") 47 | @curve("future") 48 | def futureProducerConsumer(sz: Int) = { 49 | val root = new FutureDataflowVar[FutureDataflowStream[String]]() 50 | def producer(left: Int, tail: FutureDataflowVar[FutureDataflowStream[String]]) { 51 | val s = new FutureDataflowStream("") 52 | tail := s 53 | if (left > 0) producer(left - 1, s.tail) 54 | } 55 | val done = Promise[Boolean]() 56 | def consumer(left: Int, tail: FutureDataflowVar[FutureDataflowStream[String]]) { 57 | if (left == 0) done.success(true) 58 | else tail(s => consumer(left - 1, s.tail)) 59 | } 60 | 61 | val p = Future { 62 | producer(sz, root) 63 | } 64 | val c = Future { 65 | 
consumer(sz, root) 66 | } 67 | Await.result(done.future, 10.seconds) 68 | } 69 | 70 | @gen("sizes") 71 | @benchmark("coroutines.dataflow.bounded-producer-consumer") 72 | @curve("future") 73 | def futureBoundedProducerConsumer(sz: Int) = { 74 | val root = new FutureDataflowVar[FutureDataflowStream[String]]() 75 | val startTokens = new FutureDataflowVar[FutureDataflowStream[String]]() 76 | def producer( 77 | left: Int, 78 | tail: FutureDataflowVar[FutureDataflowStream[String]], 79 | tokenTail: FutureDataflowVar[FutureDataflowStream[String]] 80 | ) { 81 | val s = new FutureDataflowStream("") 82 | tail := s 83 | if (left > 0) tokenTail(toks => producer(left - 1, s.tail, toks.tail)) 84 | } 85 | val done = Promise[Boolean]() 86 | def consumer( 87 | left: Int, 88 | tail: FutureDataflowVar[FutureDataflowStream[String]], 89 | tokenTail: FutureDataflowVar[FutureDataflowStream[String]] 90 | ) { 91 | if (left == 0) done.success(true) 92 | else { 93 | val toks = new FutureDataflowStream("") 94 | tokenTail := toks 95 | tail(s => consumer(left - 1, s.tail, toks.tail)) 96 | } 97 | } 98 | 99 | def fill(t: FutureDataflowVar[FutureDataflowStream[String]], left: Int): 100 | FutureDataflowVar[FutureDataflowStream[String]] = { 101 | val s = new FutureDataflowStream("") 102 | t := s 103 | if (left > 0) fill(s.tail, left - 1) 104 | else s.tail 105 | } 106 | val tokens = fill(startTokens, TOKENS) 107 | val p = Future { 108 | producer(sz, root, startTokens) 109 | } 110 | val c = Future { 111 | consumer(sz, root, tokens) 112 | } 113 | Await.result(done.future, 10.seconds) 114 | } 115 | 116 | @gen("sizes") 117 | @benchmark("coroutines.dataflow.producer-consumer") 118 | @curve("ltq") 119 | def ltqProducerConsumer(sz: Int) = { 120 | val q = new java.util.concurrent.LinkedTransferQueue[String]() 121 | val p = new Thread { 122 | override def run() { 123 | var i = 0 124 | while (i < sz) { 125 | q.add("") 126 | i += 1 127 | } 128 | } 129 | } 130 | val c = new Thread { 131 | override def run() { 132 
| var i = 0 133 | while (i < sz) { 134 | q.take() 135 | i += 1 136 | } 137 | } 138 | } 139 | c.start() 140 | p.start() 141 | c.join() 142 | p.join() 143 | } 144 | 145 | @gen("sizes") 146 | @benchmark("coroutines.dataflow.bounded-producer-consumer") 147 | @curve("ltq") 148 | def ltqBoundedProducerConsumer(sz: Int) = { 149 | val q = new java.util.concurrent.LinkedTransferQueue[String]() 150 | val tokens = new java.util.concurrent.LinkedTransferQueue[String]() 151 | for (i <- 0 until TOKENS) tokens.add("") 152 | val p = new Thread { 153 | override def run() { 154 | var i = 0 155 | while (i < sz) { 156 | q.add("") 157 | tokens.take() 158 | i += 1 159 | } 160 | } 161 | } 162 | val c = new Thread { 163 | override def run() { 164 | var i = 0 165 | while (i < sz) { 166 | q.take() 167 | tokens.add("") 168 | i += 1 169 | } 170 | } 171 | } 172 | c.start() 173 | p.start() 174 | c.join() 175 | p.join() 176 | } 177 | 178 | @transient lazy val forkJoinPool = new ForkJoinPool 179 | 180 | def task[T](body: ~~~>[DataflowVar[T], Unit]) { 181 | val c = call(body()) 182 | schedule(c) 183 | } 184 | 185 | def schedule[T](c: DataflowVar[T] <~> Unit) { 186 | forkJoinPool.execute(new Runnable { 187 | @tailrec final def run() { 188 | if (c.resume) { 189 | val dvar = c.value 190 | @tailrec def subscribe(): Boolean = { 191 | val state = dvar.get 192 | if (state.isInstanceOf[List[_]]) { 193 | if (dvar.compareAndSet(state, c :: state.asInstanceOf[List[_]])) true 194 | else subscribe() 195 | } else false 196 | } 197 | if (!subscribe()) run() 198 | } 199 | } 200 | }) 201 | } 202 | 203 | class DataflowVar[T] extends AtomicReference[AnyRef](Nil) { 204 | val apply = coroutine { () => 205 | if (this.get.isInstanceOf[List[_]]) yieldval(this) 206 | this.get.asInstanceOf[T] 207 | } 208 | @tailrec final def :=(x: T): Unit = { 209 | val state = this.get 210 | if (state.isInstanceOf[List[_]]) { 211 | if (this.compareAndSet(state, x.asInstanceOf[AnyRef])) { 212 | var cs = 
state.asInstanceOf[List[DataflowVar[T] <~> Unit]] 213 | while (cs != Nil) { 214 | schedule(cs.head) 215 | cs = cs.tail 216 | } 217 | } else this := x 218 | } else { 219 | sys.error("Already assigned!") 220 | } 221 | } 222 | override def toString = s"DataflowVar${this.get}" 223 | } 224 | 225 | class DataflowStream[T](val head: T) { 226 | val tail = new DataflowVar[DataflowStream[T]] 227 | } 228 | 229 | @gen("sizes") 230 | @benchmark("coroutines.dataflow.producer-consumer") 231 | @curve("coroutine") 232 | def coroutineProducerConsumer(sz: Int) = { 233 | val root = new DataflowVar[DataflowStream[String]] 234 | val done = Promise[Boolean]() 235 | val producer = coroutine { () => 236 | var left = sz 237 | var tail = root 238 | while (left > 0) { 239 | tail := new DataflowStream("") 240 | tail = tail.apply().tail 241 | left -= 1 242 | } 243 | } 244 | val consumer = coroutine { () => 245 | var left = sz 246 | var tail = root 247 | while (left > 0) { 248 | tail = tail.apply().tail 249 | left -= 1 250 | } 251 | done.success(true) 252 | () 253 | } 254 | 255 | task(producer) 256 | task(consumer) 257 | 258 | Await.result(done.future, 10.seconds) 259 | } 260 | 261 | @gen("sizes") 262 | @benchmark("coroutines.dataflow.bounded-producer-consumer") 263 | @curve("coroutine") 264 | def coroutineBoundedProducerConsumer(sz: Int) = { 265 | val root = new DataflowVar[DataflowStream[String]] 266 | val tokens = new DataflowVar[DataflowStream[String]] 267 | val done = Promise[Boolean]() 268 | val producer = coroutine { () => 269 | var left = sz 270 | var tail = root 271 | var tokenTail = tokens 272 | while (left > 0) { 273 | tail := new DataflowStream("") 274 | tail = tail.apply().tail 275 | tokenTail = tokenTail.apply().tail 276 | left -= 1 277 | } 278 | } 279 | val startTokens = { 280 | var t = tokens 281 | var left = TOKENS 282 | while (left > 0) { 283 | val s = new DataflowStream("") 284 | t := s 285 | t = s.tail 286 | left -= 1 287 | } 288 | t 289 | } 290 | val consumer = coroutine { 
() => 291 | var left = sz 292 | var tail = root 293 | var tokenTail = startTokens 294 | while (left > 0) { 295 | tail = tail.apply().tail 296 | val s = new DataflowStream("") 297 | tokenTail := s 298 | tokenTail = s.tail 299 | left -= 1 300 | } 301 | done.success(true) 302 | () 303 | } 304 | 305 | task(producer) 306 | task(consumer) 307 | 308 | Await.result(done.future, 10.seconds) 309 | } 310 | } 311 | -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/GraphIteratorBench.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalameter.api._ 6 | import org.scalameter.japi.JBench 7 | import scala.collection._ 8 | 9 | 10 | 11 | class GraphIteratorBench extends JBench.OfflineReport { 12 | 13 | override def defaultConfig = Context( 14 | exec.minWarmupRuns -> 50, 15 | exec.maxWarmupRuns -> 100, 16 | exec.benchRuns -> 36, 17 | exec.independentSamples -> 4, 18 | verbose -> true 19 | ) 20 | 21 | val SPARSE_DEG = 3 22 | 23 | val DENSE_DEG = 16 24 | 25 | val sparseSizes = Gen.range("size")(50000, 250000, 50000) 26 | 27 | val denseSizes = Gen.range("size")(50000, 250000, 50000) 28 | 29 | def graphs(sizes: Gen[Int], density: Int) = for (sz <- sizes) yield { 30 | var totalNeighbours = 0 31 | val nodes = mutable.Buffer[Node[String]]() 32 | var g = new Graph[String]() 33 | var n = g.add("root") 34 | for (i <- 0 until sz) { 35 | n = n.add(i.toString) 36 | 37 | val neighbours = 38 | if (nodes.length > 0) 39 | (0 until density).map(j => math.abs(i ^ (i + j)) % nodes.length).distinct 40 | else Nil 41 | totalNeighbours += neighbours.length 42 | for (index <- neighbours) { 43 | n.neighbours += nodes(index) 44 | } 45 | 46 | nodes += n 47 | } 48 | 49 | g 50 | } 51 | 52 | val sparseGraphs = graphs(sparseSizes, SPARSE_DEG) 53 | 54 | val denseGraphs = graphs(denseSizes, DENSE_DEG) 55 | 56 | var dfsEnumerator: Coroutine._1[Graph[String], 
String, Unit] = null 57 | 58 | def initDfsEnumerator() { 59 | def addNeighbours( 60 | stack: mutable.ArrayBuffer[Node[String]], visited: Array[Boolean], n: Node[String] 61 | ) { 62 | var i = 0 63 | while (i < n.neighbours.length) { 64 | val m = n.neighbours(i) 65 | if (!visited(m.index)) { 66 | stack += m 67 | visited(m.index) = true 68 | } 69 | i += 1 70 | } 71 | } 72 | dfsEnumerator = coroutine { (g: Graph[String]) => 73 | val visited = new Array[Boolean](g.indexCount) 74 | val stack = mutable.ArrayBuffer[Node[String]]() 75 | for (n <- g.roots) { 76 | stack += n 77 | visited(n.index) = true 78 | } 79 | while (stack.length > 0) { 80 | val n = stack.remove(stack.length - 1) 81 | addNeighbours(stack, visited, n) 82 | yieldval(n.elem) 83 | } 84 | } 85 | } 86 | 87 | var bfsEnumerator: Coroutine._1[Graph[String], String, Unit] = null 88 | 89 | def initBfsEnumerator() { 90 | def addNeighbours( 91 | queue: mutable.Queue[Node[String]], visited: Array[Boolean], n: Node[String] 92 | ) { 93 | var i = 0 94 | while (i < n.neighbours.length) { 95 | val m = n.neighbours(i) 96 | if (!visited(m.index)) { 97 | queue.enqueue(m) 98 | visited(m.index) = true 99 | } 100 | i += 1 101 | } 102 | } 103 | bfsEnumerator = coroutine { (g: Graph[String]) => 104 | val visited = new Array[Boolean](g.indexCount) 105 | val queue = mutable.Queue[Node[String]]() 106 | for (n <- g.roots) { 107 | queue += n 108 | visited(n.index) = true 109 | } 110 | while (queue.length > 0) { 111 | val n = queue.dequeue() 112 | addNeighbours(queue, visited, n) 113 | yieldval(n.elem) 114 | } 115 | } 116 | } 117 | 118 | /* to buffer */ 119 | 120 | @gen("sparseGraphs") 121 | @benchmark("coroutines.sparse-graph-dfs.to-buffer") 122 | @curve("coroutine") 123 | def coroutineSparseDfs(g: Graph[String]) = { 124 | val buffer = mutable.Buffer[String]() 125 | initDfsEnumerator() 126 | val c = call(dfsEnumerator(g)) 127 | while (c.resume) { 128 | val s = c.value 129 | buffer += s 130 | } 131 | buffer 132 | } 133 | 134 | 
@gen("sparseGraphs") 135 | @benchmark("coroutines.sparse-graph-dfs.to-buffer") 136 | @curve("iterator") 137 | def iteratorSparseDfs(g: Graph[String]) = { 138 | val buffer = mutable.Buffer[String]() 139 | val i = new GraphDfsIterator(g) 140 | while (i.hasNext) { 141 | val s = i.next() 142 | buffer += s 143 | } 144 | buffer 145 | } 146 | 147 | @gen("denseGraphs") 148 | @benchmark("coroutines.dense-graph.dfs.to-buffer") 149 | @curve("coroutine") 150 | def coroutineDenseDfs(g: Graph[String]) = { 151 | val buffer = mutable.Buffer[String]() 152 | initDfsEnumerator() 153 | val c = call(dfsEnumerator(g)) 154 | while (c.resume) { 155 | val s = c.value 156 | buffer += s 157 | } 158 | buffer 159 | } 160 | 161 | @gen("denseGraphs") 162 | @benchmark("coroutines.dense-graph.dfs.to-buffer") 163 | @curve("iterator") 164 | def iteratorDenseDfs(g: Graph[String]) = { 165 | val buffer = mutable.Buffer[String]() 166 | val i = new GraphDfsIterator(g) 167 | while (i.hasNext) { 168 | val s = i.next() 169 | buffer += s 170 | } 171 | buffer 172 | } 173 | 174 | @gen("sparseGraphs") 175 | @benchmark("coroutines.sparse-graph.bfs.to-buffer") 176 | @curve("coroutine") 177 | def coroutineSparseBfs(g: Graph[String]) = { 178 | val buffer = mutable.Buffer[String]() 179 | initBfsEnumerator() 180 | val c = call(bfsEnumerator(g)) 181 | while (c.resume) { 182 | val s = c.value 183 | buffer += s 184 | } 185 | buffer 186 | } 187 | 188 | @gen("sparseGraphs") 189 | @benchmark("coroutines.sparse-graph.bfs.to-buffer") 190 | @curve("iterator") 191 | def iteratorSparseBfs(g: Graph[String]) = { 192 | val buffer = mutable.Buffer[String]() 193 | val i = new GraphBfsIterator(g) 194 | while (i.hasNext) { 195 | val s = i.next() 196 | buffer += s 197 | } 198 | buffer 199 | } 200 | 201 | } 202 | -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/HashSetIteratorBench.scala: -------------------------------------------------------------------------------- 1 | 
package org.coroutines 2 | 3 | 4 | 5 | import org.scalameter.api._ 6 | import org.scalameter.japi.JBench 7 | import scala.collection._ 8 | 9 | 10 | 11 | class HashSetIteratorBench extends JBench.OfflineReport { 12 | 13 | override def defaultConfig = Context( 14 | exec.minWarmupRuns -> 40, 15 | exec.maxWarmupRuns -> 80, 16 | exec.benchRuns -> 60, 17 | exec.independentSamples -> 6, 18 | exec.reinstantiation.frequency -> 1, 19 | verbose -> false 20 | ) 21 | 22 | val sizes = Gen.range("size")(50000, 250000, 50000) 23 | 24 | val hashsets = for (sz <- sizes) yield { 25 | var hs = mutable.HashSet[String]() 26 | for (i <- 0 until sz) hs += i.toString 27 | hs 28 | } 29 | 30 | /* longest string */ 31 | 32 | @gen("hashsets") 33 | @benchmark("coroutines.hash-set-iterator.longest") 34 | @curve("coroutine") 35 | def coroutineLongest(set: mutable.HashSet[String]) = { 36 | var longest = "" 37 | val hashIterator = Backdoor.hashSetEnumerator 38 | val table = Backdoor.hashSet(set) 39 | val c = call(hashIterator(table)) 40 | while (c.pull) { 41 | val s = c.value 42 | if (longest.length < s.length) longest = s 43 | } 44 | longest 45 | } 46 | 47 | @gen("hashsets") 48 | @benchmark("coroutines.hash-set-iterator.longest") 49 | @curve("iterator") 50 | def iteratorLongest(set: mutable.HashSet[String]) = { 51 | var longest = "" 52 | val i = set.iterator 53 | while (i.hasNext) { 54 | val s = i.next() 55 | if (longest.length < s.length) longest = s 56 | } 57 | longest 58 | } 59 | 60 | @gen("hashsets") 61 | @benchmark("coroutines.hash-set-iterator.longest") 62 | @curve("foreach") 63 | def foreachLongest(set: mutable.HashSet[String]) = { 64 | var longest = "" 65 | set.foreach { s => 66 | if (longest.length < s.length) longest = s 67 | } 68 | longest 69 | } 70 | 71 | } 72 | -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/RedBlackIteratorBench.scala: -------------------------------------------------------------------------------- 1 | 
package org.coroutines 2 | 3 | 4 | 5 | import org.scalameter.api._ 6 | import org.scalameter.japi.JBench 7 | import scala.collection._ 8 | 9 | 10 | 11 | class RedBlackIteratorBench extends JBench.OfflineReport { 12 | 13 | override def defaultConfig = Context( 14 | exec.minWarmupRuns -> 40, 15 | exec.maxWarmupRuns -> 80, 16 | exec.benchRuns -> 60, 17 | exec.independentSamples -> 6, 18 | exec.reinstantiation.frequency -> 1, 19 | verbose -> false 20 | ) 21 | 22 | val sizes = Gen.range("size")(50000, 250000, 50000) 23 | 24 | val trees = for (sz <- sizes) yield { 25 | var tree = immutable.TreeSet[String]() 26 | for (i <- 0 until sz) tree += i.toString 27 | tree 28 | } 29 | 30 | /* longest string */ 31 | 32 | @gen("trees") 33 | @benchmark("coroutines.red-black-iterator.longest") 34 | @curve("coroutine") 35 | def coroutineLongest(set: immutable.TreeSet[String]) = { 36 | var longest = "" 37 | val treeIterator = Backdoor.redBlackEnumerator 38 | val tree = Backdoor.redBlack(set) 39 | val c = call(treeIterator(tree)) 40 | while (c.pull) { 41 | val s = c.value 42 | if (longest.length < s.length) longest = s 43 | } 44 | longest 45 | } 46 | 47 | @gen("trees") 48 | @benchmark("coroutines.red-black-iterator.longest") 49 | @curve("iterator") 50 | def iteratorLongest(set: immutable.TreeSet[String]) = { 51 | var longest = "" 52 | val i = set.iterator 53 | while (i.hasNext) { 54 | val s = i.next() 55 | if (longest.length < s.length) longest = s 56 | } 57 | longest 58 | } 59 | 60 | } 61 | -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/ScalaCheckBench.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalameter.api._ 6 | import org.scalameter.japi.JBench 7 | import scala.collection._ 8 | import scala.concurrent._ 9 | import scala.concurrent.duration._ 10 | import scala.concurrent.ExecutionContext.Implicits.global 11 | import 
scala.util.Random 12 | 13 | 14 | 15 | class ScalaCheckBench extends JBench.OfflineReport { 16 | 17 | override def defaultConfig = Context( 18 | exec.minWarmupRuns -> 100, 19 | exec.maxWarmupRuns -> 200, 20 | exec.benchRuns -> 36, 21 | exec.independentSamples -> 1, 22 | verbose -> true 23 | ) 24 | 25 | val fractNumTests = Gen.range("size")(5000, 25000, 5000) 26 | 27 | val listNumTests = Gen.range("size")(100, 500, 100) 28 | 29 | val max = 1000 30 | 31 | case class Fract(num: Int, den: Int) 32 | 33 | object Fract { 34 | def normalize(num: Int, den: Int) = { 35 | val d = gcd(num, den) 36 | Fract(num / d, den / d) 37 | } 38 | } 39 | 40 | def gcd(x: Int, y: Int) = { 41 | var a = x 42 | var b = y 43 | while (b != 0) { 44 | val t = b 45 | b = a % b 46 | a = t 47 | } 48 | a 49 | } 50 | 51 | def add(a: Fract, b: Fract) = Fract.normalize( 52 | a.num * b.den + a.den * b.num, a.den * b.den) 53 | 54 | def mult(a: Fract, b: Fract) = Fract.normalize( 55 | a.num * b.num, a.den * b.den) 56 | 57 | def inv(a: Fract) = Fract(a.den, a.num) 58 | 59 | trait Gen[T] { 60 | self => 61 | def sample: T 62 | def map[S](f: T => S): Gen[S] = new Gen[S] { 63 | def sample = f(self.sample) 64 | } 65 | def flatMap[S](f: T => Gen[S]): Gen[S] = new Gen[S] { 66 | def sample = f(self.sample).sample 67 | } 68 | } 69 | 70 | def ints(from: Int, until: Int) = new Gen[Int] { 71 | val random = new Random(111) 72 | def sample = from + random.nextInt(until - from) 73 | } 74 | 75 | @gen("fractNumTests") 76 | @benchmark("coroutines.scalacheck.fractions") 77 | @curve("scalacheck") 78 | def scalacheckTestFraction(numTests: Int) = { 79 | val fracts = for { 80 | den <- ints(1, max) 81 | num <- ints(0, den) 82 | } yield Fract(num, den) 83 | val pairs = for { 84 | a <- fracts 85 | b <- fracts 86 | } yield (a, b) 87 | for (i <- 0 until numTests) { 88 | val (a, b) = pairs.sample 89 | val c = add(a, b) 90 | assert(c.num < 2 * c.den) 91 | val d = add(b, a) 92 | assert(d == c) 93 | } 94 | } 95 | 96 | @gen("listNumTests") 97 
| @benchmark("coroutines.scalacheck.lists") 98 | @curve("scalacheck") 99 | def scalacheckTestList(numTests: Int) = { 100 | val lists = for { 101 | x <- ints(0, max) 102 | } yield List.fill(max)(x) 103 | val pairs = for { 104 | a <- lists 105 | b <- lists 106 | } yield (a, b) 107 | for (i <- 0 until numTests) { 108 | val (xs, ys) = pairs.sample 109 | assert(xs.size + ys.size == (xs ::: ys).size, (xs, ys)) 110 | } 111 | } 112 | 113 | class Backtracker { 114 | val random = new Random(111) 115 | 116 | val recurse: (Unit <~> Unit) ~~> (Unit, Unit) = coroutine { (c: Unit <~> Unit) => 117 | if (c.resume) { 118 | val saved = c.snapshot 119 | recurse(c) 120 | recurse(saved) 121 | } else { 122 | yieldval(()) 123 | } 124 | } 125 | 126 | val traverse = coroutine { (snippet: ~~~>[Unit, Unit]) => 127 | while (true) { 128 | val c = call(snippet()) 129 | recurse(c) 130 | } 131 | } 132 | 133 | def backtrack(snippet: ~~~>[Unit, Unit], numTests: Int): Unit = { 134 | var testsLeft = numTests 135 | val t = call(traverse(snippet)) 136 | for (i <- 0 until numTests) t.resume 137 | } 138 | 139 | val int = coroutine { (from: Int, until: Int) => 140 | yieldval(()) 141 | from + random.nextInt(until - from) 142 | } 143 | } 144 | 145 | @gen("fractNumTests") 146 | @benchmark("coroutines.scalacheck.fractions") 147 | @curve("coroutine") 148 | def coroutineTestFraction(numTests: Int) = { 149 | val b = new Backtracker 150 | val fract = coroutine { () => 151 | val den = b.int(1, max) 152 | val num = b.int(0, den) 153 | Fract(num, den) 154 | } 155 | val test = coroutine { () => 156 | val a = fract() 157 | val b = fract() 158 | val c = add(a, b) 159 | assert(c.num < 2 * c.den) 160 | val d = add(b, a) 161 | assert(d == c) 162 | } 163 | b.backtrack(test, numTests) 164 | } 165 | 166 | @gen("listNumTests") 167 | @benchmark("coroutines.scalacheck.lists") 168 | @curve("coroutine") 169 | def coroutineTestList(numTests: Int) = { 170 | val b = new Backtracker 171 | val list = coroutine { () => 172 | val x = 
b.int(1, max) 173 | List.fill(max)(x) 174 | } 175 | val test = coroutine { () => 176 | val a = list() 177 | val b = list() 178 | val c = a ::: b 179 | assert(a.size + b.size == c.size) 180 | } 181 | b.backtrack(test, numTests) 182 | } 183 | } 184 | -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/StreamBench.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalameter.api._ 6 | import org.scalameter.japi.JBench 7 | import scala.collection._ 8 | 9 | 10 | 11 | class StreamBench extends JBench.OfflineReport { 12 | 13 | override def defaultConfig = Context( 14 | exec.minWarmupRuns -> 50, 15 | exec.maxWarmupRuns -> 100, 16 | exec.benchRuns -> 36, 17 | exec.independentSamples -> 4, 18 | verbose -> false 19 | ) 20 | 21 | val fibSizes = Gen.range("size")(5000, 25000, 5000) 22 | 23 | val taylorSizes = Gen.range("size")(50000, 250000, 50000) 24 | 25 | @gen("fibSizes") 26 | @benchmark("coroutines.stream.fibonacci.to-buffer") 27 | @curve("stream") 28 | def streamFibonacciToBuffer(sz: Int) = { 29 | val buffer = mutable.Buffer[BigInt]() 30 | object Fibs { 31 | lazy val values: Stream[BigInt] = 32 | BigInt(0) #:: BigInt(1) #:: values.zip(values.tail).map(t => t._1 + t._2) 33 | } 34 | var i = 0 35 | var s = Fibs.values 36 | while (i < sz) { 37 | buffer += s.head 38 | s = s.tail 39 | i += 1 40 | } 41 | buffer 42 | } 43 | 44 | @gen("fibSizes") 45 | @benchmark("coroutines.stream.fibonacci.to-buffer") 46 | @curve("coroutine") 47 | def coroutineFibonacciToBuffer(sz: Int) = { 48 | val buffer = mutable.Buffer[BigInt]() 49 | val fibs = coroutine { () => 50 | var prev = BigInt(0) 51 | var curr = BigInt(1) 52 | yieldval(prev) 53 | yieldval(curr) 54 | while (true) { 55 | val x = curr + prev 56 | yieldval(x) 57 | prev = curr 58 | curr = x 59 | } 60 | } 61 | var i = 0 62 | val c = call(fibs()) 63 | while (i < sz) { 64 | c.resume 65 | buffer 
+= c.value 66 | i += 1 67 | } 68 | buffer 69 | } 70 | 71 | @gen("taylorSizes") 72 | @benchmark("coroutines.stream.taylor.sum") 73 | @curve("stream") 74 | def streamTaylorSum(sz: Int) = { 75 | var sum = 0.0 76 | class TaylorInvX(x: Double) { 77 | lazy val values: Stream[Double] = 78 | 1.0 #:: values.map(_ * (x - 1) * -1) 79 | } 80 | var i = 0 81 | var s = new TaylorInvX(0.5).values 82 | while (i < sz) { 83 | sum += s.head 84 | s = s.tail 85 | i += 1 86 | } 87 | sum 88 | } 89 | 90 | @gen("taylorSizes") 91 | @benchmark("coroutines.stream.taylor.sum") 92 | @curve("coroutine") 93 | def coroutineTaylorSum(sz: Int) = { 94 | var sum = 0.0 95 | val taylor = coroutine { (x: Double) => 96 | var last = 1.0 97 | yieldval(last) 98 | while (true) { 99 | last *= -1.0 * (x - 1) 100 | yieldval(last) 101 | } 102 | } 103 | var i = 0 104 | val c = call(taylor(0.5)) 105 | while (i < sz) { 106 | c.resume 107 | sum += c.value 108 | i += 1 109 | } 110 | sum 111 | } 112 | 113 | } 114 | -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/TreeIteratorBench.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalameter.api._ 6 | import org.scalameter.japi.JBench 7 | import scala.collection._ 8 | 9 | 10 | 11 | class TreeIteratorBench extends JBench.OfflineReport { 12 | override def defaultConfig = Context( 13 | exec.minWarmupRuns -> 40, 14 | exec.maxWarmupRuns -> 80, 15 | exec.benchRuns -> 30, 16 | exec.independentSamples -> 1, 17 | verbose -> true 18 | ) 19 | 20 | sealed trait Tree 21 | case class Node(x: Int, left: Tree, right: Tree) extends Tree 22 | case object Empty extends Tree 23 | 24 | class TreeIterator(val tree: Tree) { 25 | var stack = new Array[Tree](30) 26 | var stackpos = -1 27 | var current: Int = _ 28 | 29 | def goLeft(tree: Tree) { 30 | stackpos += 1 31 | stack(stackpos) = tree 32 | tree match { 33 | case Empty => 34 | case 
Node(_, left, _) => goLeft(left) 35 | } 36 | } 37 | 38 | goLeft(tree) 39 | moveToNext() 40 | 41 | def moveToNext() { 42 | if (stackpos != -1) stack(stackpos) match { 43 | case Empty => 44 | stack(stackpos) = null 45 | stackpos -= 1 46 | if (stackpos > -1) assert(stack(stackpos) != Empty) 47 | moveToNext() 48 | case Node(x, _, right) => 49 | stack(stackpos) = null 50 | stackpos -= 1 51 | current = x 52 | goLeft(right) 53 | } 54 | } 55 | 56 | def hasNext: Boolean = { 57 | stackpos != -1 58 | } 59 | def next(): Int = { 60 | if (!hasNext) throw new NoSuchElementException 61 | val x = current 62 | moveToNext() 63 | x 64 | } 65 | } 66 | 67 | val sizes = Gen.range("size")(50000, 250000, 50000) 68 | 69 | def genTree(sz: Int): Tree = { 70 | if (sz == 0) Empty 71 | else { 72 | val rem = sz - 1 73 | val left = genTree(rem / 2) 74 | val right = genTree(rem - rem / 2) 75 | Node(sz, left, right) 76 | } 77 | } 78 | 79 | val trees = for (sz <- sizes) yield { 80 | genTree(sz) 81 | } 82 | 83 | val treePairs = for (sz <- sizes) yield { 84 | (genTree(sz), genTree(sz)) 85 | } 86 | 87 | var treeEnumerator: Coroutine._1[Tree, Int, Unit] = null 88 | 89 | /* max int */ 90 | 91 | @gen("trees") 92 | @benchmark("coroutines.tree-iterator.max") 93 | @curve("coroutine") 94 | def coroutineMax(tree: Tree) { 95 | var max = Int.MinValue 96 | treeEnumerator = coroutine { (t: Tree) => 97 | t match { 98 | case n: Node => 99 | if (n.left != Empty) treeEnumerator(n.left) 100 | yieldval(n.x) 101 | if (n.right != Empty) treeEnumerator(n.right) 102 | case Empty => 103 | } 104 | } 105 | val c = call(treeEnumerator(tree)) 106 | while (c.pull) { 107 | val x = c.value 108 | if (x > max) max = x 109 | } 110 | } 111 | 112 | @gen("trees") 113 | @benchmark("coroutines.tree-iterator.max") 114 | @curve("iterator") 115 | def iteratorMax(tree: Tree) { 116 | var max = Int.MinValue 117 | val iter = new TreeIterator(tree) 118 | while (iter.hasNext) { 119 | val x = iter.next() 120 | if (x > max) max = x 121 | } 122 | } 123 
| 124 | @gen("trees") 125 | @benchmark("coroutines.tree-iterator.max") 126 | @curve("recursion") 127 | def recursiveMax(tree: Tree) { 128 | var max = Int.MinValue 129 | def recurse(tree: Tree) { 130 | tree match { 131 | case Node(x, left, right) => 132 | recurse(left) 133 | if (x > max) max = x 134 | recurse(right) 135 | case Empty => 136 | } 137 | } 138 | recurse(tree) 139 | } 140 | 141 | /* growing array */ 142 | 143 | @gen("trees") 144 | @benchmark("coroutines.tree-iterator.to-array") 145 | @curve("coroutine") 146 | def coroutineToArray(tree: Tree) { 147 | val a = new IntArray 148 | treeEnumerator = coroutine { (t: Tree) => 149 | t match { 150 | case n: Node => 151 | if (n.left != Empty) treeEnumerator(n.left) 152 | yieldval(n.x) 153 | if (n.right != Empty) treeEnumerator(n.right) 154 | case Empty => 155 | } 156 | } 157 | val c = call(treeEnumerator(tree)) 158 | while (c.pull) { 159 | val x = c.value 160 | a.add(x) 161 | } 162 | } 163 | 164 | @gen("trees") 165 | @benchmark("coroutines.tree-iterator.to-array") 166 | @curve("iterator") 167 | def iteratorToArray(tree: Tree) { 168 | val a = new IntArray 169 | val iter = new TreeIterator(tree) 170 | while (iter.hasNext) { 171 | val x = iter.next() 172 | a.add(x) 173 | } 174 | } 175 | 176 | @gen("trees") 177 | @benchmark("coroutines.tree-iterator.to-array") 178 | @curve("recursion") 179 | def recursiveToArray(tree: Tree) { 180 | val a = new IntArray 181 | def recurse(tree: Tree) { 182 | tree match { 183 | case Node(x, left, right) => 184 | recurse(left) 185 | a.add(x) 186 | recurse(right) 187 | case Empty => 188 | } 189 | } 190 | recurse(tree) 191 | } 192 | 193 | /* samefringe */ 194 | 195 | @volatile var isSame = true 196 | 197 | @gen("treePairs") 198 | @benchmark("coroutines.tree-iterator.same-fringe") 199 | @curve("coroutine") 200 | def coroutineSameFringe(p: (Tree, Tree)) { 201 | val (t1, t2) = p 202 | treeEnumerator = coroutine { (t: Tree) => 203 | t match { 204 | case n: Node => 205 | if (n.left != Empty) 
treeEnumerator(n.left) 206 | yieldval(n.x) 207 | if (n.right != Empty) treeEnumerator(n.right) 208 | case Empty => 209 | } 210 | } 211 | val c1 = call(treeEnumerator(t1)) 212 | val c2 = call(treeEnumerator(t2)) 213 | var same = true 214 | while (c1.pull && c2.pull) { 215 | val x = c1.value 216 | val y = c2.value 217 | if (x != y) same = false 218 | } 219 | isSame = same 220 | } 221 | 222 | @gen("treePairs") 223 | @benchmark("coroutines.tree-iterator.same-fringe") 224 | @curve("iterator") 225 | def iteratorSameFringe(p: (Tree, Tree)) { 226 | val (t1, t2) = p 227 | val iter1 = new TreeIterator(t1) 228 | val iter2 = new TreeIterator(t2) 229 | var same = true 230 | while (iter1.hasNext && iter2.hasNext) { 231 | val x = iter1.next() 232 | val y = iter2.next() 233 | if (x != y) same = false 234 | } 235 | if (iter1.hasNext != iter2.hasNext) same = false 236 | isSame = same 237 | } 238 | 239 | def treeStream(tree: Tree): Stream[Int] = { 240 | tree match { 241 | case Empty => Stream() 242 | case Node(x, left, right) => treeStream(left) #::: (x #:: treeStream(right)) 243 | } 244 | } 245 | 246 | @gen("treePairs") 247 | @benchmark("coroutines.tree-iterator.same-fringe") 248 | @curve("stream") 249 | def streamSameFringe(p: (Tree, Tree)) { 250 | val (t1, t2) = p 251 | var s1 = treeStream(t1) 252 | var s2 = treeStream(t2) 253 | var same = true 254 | while (s1.nonEmpty && s2.nonEmpty) { 255 | val x = s1.head 256 | val y = s2.head 257 | if (x != y) same = false 258 | s1 = s1.tail 259 | s2 = s2.tail 260 | } 261 | if (s1.nonEmpty != s2.nonEmpty) same = false 262 | isSame = same 263 | } 264 | 265 | /* tests */ 266 | 267 | assert({ 268 | def leaf(x: Int) = Node(x, Empty, Empty) 269 | val tree = Node(1, 270 | Node(19, leaf(21), leaf(23)), 271 | Node(3, 272 | leaf(11), 273 | Node(9, 274 | leaf(5), 275 | leaf(17)))) 276 | val a = mutable.Buffer[Int]() 277 | def rec(tree: Tree): Unit = tree match { 278 | case Empty => 279 | case Node(x, l, r) => 280 | rec(l) 281 | a += x 282 | rec(r) 283 | 
} 284 | rec(tree) 285 | val b = mutable.Buffer[Int]() 286 | val it = new TreeIterator(tree) 287 | while (it.hasNext) b += it.next() 288 | a == b 289 | }) 290 | 291 | } 292 | -------------------------------------------------------------------------------- /src/bench/scala/org/coroutines/data-structures.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import scala.collection._ 6 | 7 | 8 | 9 | class IntArray { 10 | private var array = new Array[Int](8) 11 | private var size = 0 12 | def length = size 13 | def apply(idx: Int) = { 14 | assert(idx >= 0 && idx < size) 15 | array(idx) 16 | } 17 | def add(x: Int) = { 18 | if (size == array.length) { 19 | val narray = new Array[Int](size * 2) 20 | System.arraycopy(array, 0, narray, 0, size) 21 | array = narray 22 | } 23 | array(size) = x 24 | size += 1 25 | } 26 | def push(x: Int) = add(x) 27 | def top = array(size - 1) 28 | def pop() = { 29 | val x = array(size - 1) 30 | size -= 1 31 | x 32 | } 33 | } 34 | 35 | 36 | class Graph[T] { 37 | var indexCount = 0 38 | val roots = mutable.Buffer[Node[T]]() 39 | 40 | def add(elem: T): Node[T] = { 41 | val n = new Node(this, indexCount, elem) 42 | indexCount += 1 43 | roots += n 44 | n 45 | } 46 | } 47 | 48 | 49 | class Node[T](val graph: Graph[T], val index: Int, val elem: T) { 50 | val neighbours = mutable.Buffer[Node[T]]() 51 | 52 | def add(elem: T): Node[T] = { 53 | val n = new Node(graph, graph.indexCount, elem) 54 | graph.indexCount += 1 55 | neighbours += n 56 | n 57 | } 58 | } 59 | 60 | 61 | class GraphDfsIterator[T](val graph: Graph[T]) { 62 | val visited = new Array[Boolean](graph.indexCount) 63 | val stack = mutable.ArrayBuffer[Node[T]]() 64 | def enq(n: Node[T]) = stack += n 65 | def deq(): Node[T] = stack.remove(stack.length - 1) 66 | for (n <- graph.roots) { 67 | enq(n) 68 | visited(n.index) = true 69 | } 70 | def hasNext = stack.length > 0 71 | def next(): T = { 72 | val n = deq() 73 | 
var i = 0 74 | while (i < n.neighbours.length) { 75 | val m = n.neighbours(i) 76 | if (!visited(m.index)) { 77 | enq(m) 78 | visited(m.index) = true 79 | } 80 | i += 1 81 | } 82 | n.elem 83 | } 84 | } 85 | 86 | 87 | class GraphBfsIterator[T](val graph: Graph[T]) { 88 | val visited = new Array[Boolean](graph.indexCount) 89 | val queue = mutable.Queue[Node[T]]() 90 | def enq(n: Node[T]) = queue.enqueue(n) 91 | def deq(): Node[T] = queue.dequeue() 92 | for (n <- graph.roots) { 93 | enq(n) 94 | visited(n.index) = true 95 | } 96 | def hasNext = queue.length > 0 97 | def next(): T = { 98 | val n = deq() 99 | var i = 0 100 | while (i < n.neighbours.length) { 101 | val m = n.neighbours(i) 102 | if (!visited(m.index)) { 103 | enq(m) 104 | visited(m.index) = true 105 | } 106 | i += 1 107 | } 108 | n.elem 109 | } 110 | } 111 | -------------------------------------------------------------------------------- /src/bench/scala/scala/collection/Backdoor.scala: -------------------------------------------------------------------------------- 1 | package scala.collection 2 | 3 | 4 | 5 | import org.coroutines._ 6 | 7 | 8 | 9 | object Backdoor { 10 | 11 | type RBTree[T] = immutable.RedBlackTree.Tree[T, Unit] 12 | 13 | def redBlack[T](set: immutable.TreeSet[T]): RBTree[T] = { 14 | val f = set.getClass.getDeclaredField("tree") 15 | f.setAccessible(true) 16 | f.get(set).asInstanceOf[immutable.RedBlackTree.Tree[T, Unit]] 17 | } 18 | 19 | val redBlackEnumeratorInlined: Coroutine._1[RBTree[String], String, Unit] = 20 | coroutine { (tree: RBTree[String]) => 21 | if (tree.left != null) { 22 | if (tree.left.left != null) redBlackEnumerator(tree.left.left) 23 | yieldval(tree.left.key) 24 | if (tree.left.right != null) redBlackEnumerator(tree.left.right) 25 | } 26 | yieldval(tree.key) 27 | if (tree.right != null) { 28 | if (tree.right.left != null) redBlackEnumerator(tree.right.left) 29 | yieldval(tree.right.key) 30 | if (tree.right.right != null) redBlackEnumerator(tree.right.right) 31 | } 32 | 
} 33 | 34 | val redBlackEnumerator: Coroutine._1[RBTree[String], String, Unit] = 35 | coroutine { (tree: RBTree[String]) => 36 | if (tree.left != null) redBlackEnumerator(tree.left) 37 | yieldval(tree.key) 38 | if (tree.right != null) redBlackEnumerator(tree.right) 39 | } 40 | 41 | def hashSet[T](set: mutable.HashSet[T]): Array[AnyRef] = { 42 | val f = set.getClass.getDeclaredField("table") 43 | f.setAccessible(true) 44 | f.get(set).asInstanceOf[Array[AnyRef]] 45 | } 46 | 47 | val hashSetEnumerator: Coroutine._1[Array[AnyRef], String, Unit] = 48 | coroutine { (table: Array[AnyRef]) => 49 | var i = 0 50 | while (i < table.length) { 51 | val x = table(i) 52 | if (x != null) yieldval(x.asInstanceOf[String]) 53 | i += 1 54 | } 55 | } 56 | 57 | } 58 | -------------------------------------------------------------------------------- /src/main/scala/org/coroutines/AstCanonicalization.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.coroutines.common._ 6 | import scala.collection._ 7 | import scala.language.experimental.macros 8 | import scala.reflect.macros.whitebox.Context 9 | 10 | 11 | 12 | /** Transforms the coroutine body into three address form with restricted control flow 13 | * that contains only try-catch statements, while loops, if-statements, value and 14 | * variable declarations, pattern matches, nested blocks and function calls. 15 | * 16 | * Newly synthesized variables get mangled fresh names, and existing variable names are 17 | * preserved. 18 | * 19 | * Coroutine operations usages are checked for correctness, and nested contexts, such 20 | * as function and class declarations, are checked, but not transformed. 
21 | */ 22 | trait AstCanonicalization[C <: Context] { 23 | self: Analyzer[C] => 24 | 25 | val c: C 26 | 27 | import c.universe._ 28 | 29 | class NestedContextValidator(implicit typer: ByTreeTyper[c.type]) 30 | extends Traverser { 31 | override def traverse(tree: Tree): Unit = tree match { 32 | case q"$qual.coroutine[$_]($_)" if isCoroutinesPkg(qual) => 33 | // no need to check further, this is checked in a different expansion 34 | case q"$qual.yieldval[$_]($_)" if isCoroutinesPkg(qual) => 35 | c.abort( 36 | tree.pos, 37 | "The yieldval statement can only be invoked directly inside the coroutine. " + 38 | "Nested classes, functions or for-comprehensions, should either use the " + 39 | "call statement or declare another coroutine.") 40 | case q"$qual.yieldto[$_]($_)" if isCoroutinesPkg(qual) => 41 | c.abort( 42 | tree.pos, 43 | "The yieldto statement can only be invoked directly inside the coroutine. " + 44 | "Nested classes, functions or for-comprehensions, should either use the " + 45 | "call statement or declare another coroutine.") 46 | case q"$qual.call($co.apply(..$args))" if isCoroutinesPkg(qual) => 47 | // no need to check further, the call macro will validate the coroutine type 48 | case q"$co.apply(..$args)" if isCoroutineDefMarker(typer.typeOf(co)) => 49 | c.abort( 50 | tree.pos, 51 | "Coroutine blueprints can only be invoked directly inside the coroutine. " + 52 | "Nested classes, functions or for-comprehensions, should either use the " + 53 | "call statement or declare another coroutine.") 54 | case q"$co.apply[..$_](..$args)(..$_)" 55 | if isCoroutineDefMarker(typer.typeOf(co)) => 56 | c.abort( 57 | tree.pos, 58 | "Coroutine blueprints can only be invoked directly inside the coroutine.
" + 59 | "Nested classes, functions or for-comprehensions, should either use the " + 60 | "call statement or declare another coroutine.") 61 | case _ => 62 | super.traverse(tree) 63 | } 64 | } 65 | 66 | def disallowCoroutinesIn(tree: Tree): Unit = { 67 | for (t <- tree) t match { 68 | case CoroutineOp(t) => c.abort(t.pos, s"Coroutines disallowed in:\n$tree.") 69 | case _ => // fine 70 | } 71 | } 72 | 73 | private def canonicalize(tree: Tree)( 74 | implicit typer: ByTreeTyper[c.type] 75 | ): (List[Tree], Tree) = tree match { 76 | case q"$r.`package`" => 77 | // package selection 78 | (Nil, tree) 79 | case q"$r.$member" if !tree.symbol.isPackage => 80 | // selection 81 | val (rdecls, rident) = canonicalize(r) 82 | val localvarname = TermName(c.freshName("x")) 83 | val localvartree = q"val $localvarname = $rident.$member" 84 | (rdecls ++ List(localvartree), q"$localvarname") 85 | case q"$r.&&($arg)" 86 | if typer.typeOf(r) =:= typeOf[Boolean] && typer.typeOf(arg) =:= typeOf[Boolean] => 87 | // short-circuit boolean and 88 | val (conddecls, condident) = canonicalize(r) 89 | val (thendecls, thenident) = canonicalize(arg) 90 | val localvarname = TermName(c.freshName("x")) 91 | val decls = List( 92 | q"var $localvarname = null.asInstanceOf[Boolean]", 93 | q""" 94 | ..$conddecls 95 | if ($condident) { 96 | ..$thendecls 97 | $localvarname = $thenident 98 | } else { 99 | $localvarname = false 100 | } 101 | """ 102 | ) 103 | (decls, q"$localvarname") 104 | case q"$r.||($arg)" 105 | if typer.typeOf(r) =:= typeOf[Boolean] && typer.typeOf(arg) =:= typeOf[Boolean] => 106 | // short-circuit boolean or 107 | val (conddecls, condident) = canonicalize(r) 108 | val (elsedecls, elseident) = canonicalize(arg) 109 | val localvarname = TermName(c.freshName("x")) 110 | val decls = List( 111 | q"var $localvarname = null.asInstanceOf[Boolean]", 112 | q""" 113 | ..$conddecls 114 | if ($condident) { 115 | $localvarname = true 116 | } else { 117 | ..$elsedecls 118 | $localvarname = $elseident
119 | } 120 | """ 121 | ) 122 | (decls, q"$localvarname") 123 | case q"$selector[..$tpts](...$paramss)" if tpts.length > 0 || paramss.length > 0 => 124 | // application 125 | val (rdecls, newselector) = selector match { 126 | case q"$r.$method" => 127 | val (rdecls, rident) = canonicalize(r) 128 | (rdecls, q"$rident.$method") 129 | case q"${method: TermName}" => 130 | (Nil, q"$method") 131 | } 132 | for (tpt <- tpts) disallowCoroutinesIn(tpt) 133 | val (pdeclss, pidents) = paramss.map(_.map(canonicalize).unzip).unzip 134 | val localvarname = TermName(c.freshName("x")) 135 | val localvartree = q"val $localvarname = $newselector[..$tpts](...$pidents)" 136 | (rdecls ++ pdeclss.flatten.flatten ++ List(localvartree), q"$localvarname") 137 | case q"$r[..$tpts]" if tpts.length > 0 => 138 | // type application 139 | for (tpt <- tpts) disallowCoroutinesIn(tpt) 140 | val (rdecls, rident) = canonicalize(r) 141 | (rdecls, q"$rident[..$tpts]") 142 | case q"$x = $v" => 143 | // assignment 144 | val (xdecls, xident) = canonicalize(x) 145 | val (vdecls, vident) = canonicalize(v) 146 | (xdecls ++ vdecls ++ List(q"$xident = $vident"), q"()") 147 | case q"$x(..$args) = $v" => 148 | // update 149 | val (xdecls, xident) = canonicalize(x) 150 | val (argdecls, argidents) = args.map(canonicalize).unzip 151 | val (vdecls, vident) = canonicalize(v) 152 | (xdecls ++ argdecls.flatten ++ vdecls, q"$xident(..$argidents) = $vident") 153 | case q"return $_" => 154 | // return 155 | c.abort(tree.pos, "The return statement is not allowed inside coroutines.") 156 | case q"$x: $tpt" => 157 | // ascription 158 | disallowCoroutinesIn(tpt) 159 | val (xdecls, xident) = canonicalize(x) 160 | (xdecls, q"$xident: $tpt") 161 | case q"$x: @$annot" => 162 | // annotation 163 | val (xdecls, xident) = canonicalize(x) 164 | (xdecls, q"$xident: @$annot") 165 | case q"(..$xs)" if xs.length > 1 => 166 | // tuples 167 | val (xsdecls, xsidents) = xs.map(canonicalize).unzip 168 | (xsdecls.flatten, q"(..$xsidents)") 169
| case q"throw $expr" => 170 | // throw 171 | val (decls, ident) = canonicalize(expr) 172 | val ndecls = decls ++ List(q"throw $ident") 173 | (ndecls, q"throw $ident") 174 | case q"try $body catch { case ..$cases } finally $expr" => 175 | // try 176 | val tpe = typer.typeOf(tree) 177 | val localvarname = TermName(c.freshName("x")) 178 | val exceptionvarname = TermName(c.freshName("e")) 179 | val bindingname = TermName(c.freshName("t")) 180 | val (bodydecls, bodyident) = canonicalize(body) 181 | val (exprdecls, exprident) = canonicalize(expr) 182 | val matchcases = 183 | cases :+ cq"${pq"null"} => null" :+ cq"${pq"_"} => throw $exceptionvarname" 184 | val exceptionident = q"$exceptionvarname" 185 | val matchbody = q"$exceptionident match { case ..$matchcases }" 186 | typer.typeOf(matchbody) = typer.typeOf(tree) 187 | typer.typeOf(exceptionident) = typeOf[Throwable] 188 | val (matchdecls, matchident) = canonicalize(matchbody) 189 | val ndecls = List( 190 | q"var $localvarname = null.asInstanceOf[$tpe]", 191 | q"var $exceptionvarname: Throwable = null", 192 | q""" 193 | try { 194 | ..$bodydecls 195 | 196 | $localvarname = $bodyident 197 | } catch { 198 | case $bindingname: Throwable => $exceptionvarname = $bindingname 199 | } 200 | """ 201 | ) ++ List(if (expr == q"") q""" 202 | ..$matchdecls 203 | 204 | $localvarname = $matchident 205 | """ else q""" 206 | try { 207 | ..$matchdecls 208 | 209 | $localvarname = $matchident 210 | } finally { 211 | $expr 212 | } 213 | """ 214 | ) 215 | (ndecls, q"$localvarname") 216 | case q"if ($cond) $thenbranch else $elsebranch" => 217 | // if 218 | val (conddecls, condident) = canonicalize(cond) 219 | val (thendecls, thenident) = canonicalize(thenbranch) 220 | val (elsedecls, elseident) = canonicalize(elsebranch) 221 | val localvarname = TermName(c.freshName("x")) 222 | val tpe = typer.typeOf(tree) 223 | val decls = List( 224 | q"var $localvarname = null.asInstanceOf[$tpe]", 225 | q""" 226 | ..$conddecls 227 | if ($condident) { 228 | 
..$thendecls 229 | $localvarname = $thenident 230 | } else { 231 | ..$elsedecls 232 | $localvarname = $elseident 233 | } 234 | """ 235 | ) 236 | (decls, q"$localvarname") 237 | case q"$expr match { case ..$cases }" => 238 | // pattern match 239 | val localvarname = TermName(c.freshName("x")) 240 | val (exdecls, exident) = canonicalize(expr) 241 | val tpe = typer.typeOf(tree) 242 | val extpe = typer.typeOf(expr) 243 | val ncases = for (cq"$pat => $branch" <- cases) yield { 244 | disallowCoroutinesIn(pat) 245 | val (branchdecls, branchident) = canonicalize(branch) 246 | val isWildcard = pat match { 247 | case pq"_" => true 248 | case _ => false 249 | } 250 | val checkcases = 251 | if (isWildcard) List(cq"$pat => true") 252 | else List(cq"$pat => true", cq"_ => false") 253 | val patdecl = q"val $pat: $exident @scala.unchecked = $exident" 254 | val body = q""" 255 | ..$patdecl 256 | 257 | ..$branchdecls 258 | 259 | $localvarname = $branchident 260 | """ 261 | (q"$exident match { case ..$checkcases }", body) 262 | } 263 | val patternmatch = 264 | ncases.foldRight(q"throw new scala.MatchError($exident)": Tree) { 265 | case ((patternmatch, ifbranch), elsebranch) => 266 | q"if ($patternmatch) $ifbranch else $elsebranch" 267 | } 268 | val decls = 269 | List(q"var $localvarname = null.asInstanceOf[${tpe.widen}]") ++ 270 | exdecls ++ 271 | List(patternmatch) 272 | (decls, q"$localvarname") 273 | case q"(..$params) => $body" => 274 | // function 275 | new NestedContextValidator().traverse(tree) 276 | (Nil, tree) 277 | case q"{ case ..$cases }" => 278 | // partial function 279 | new NestedContextValidator().traverse(tree) 280 | (Nil, tree) 281 | case q"while ($cond) $body" => 282 | // while 283 | val (xdecls0, xident0) = canonicalize(cond) 284 | // TODO: This is a temporary fix. It is very dangerous, since it makes the 285 | // transformation take O(2^n) time in the depth of the tree. 
286 | // 287 | // The correct solution is to duplicate the trees so that duplicate value decls in 288 | // the two trees get fresh names. 289 | val (xdecls1, xident1) = canonicalize(cond) 290 | val localvarname = TermName(c.freshName("x")) 291 | val decls = if (xdecls0 != Nil) { 292 | xdecls0 ++ List( 293 | q"var $localvarname = $xident0", 294 | q""" 295 | while ($localvarname) { 296 | ${transform(body)} 297 | 298 | ..$xdecls1 299 | $localvarname = $xident1 300 | } 301 | """) 302 | } else List(q""" 303 | while ($cond) { 304 | ${transform(body)} 305 | } 306 | """) 307 | (decls, q"()") 308 | case q"do $body while ($cond)" => 309 | // do-while 310 | // TODO: This translation is a temporary fix, and can result in O(2^n) time. The 311 | // correct solution is to transform the subtree once, duplicate the transformed 312 | // trees and rename the variables. 313 | val (xdecls0, xident0) = canonicalize(cond) 314 | val (xdecls1, xident1) = canonicalize(cond) 315 | val localvarname = TermName(c.freshName("x")) 316 | val decls = if (xdecls0 != Nil) List( 317 | q""" 318 | { 319 | ${transform(body)} 320 | } 321 | """ 322 | ) ++ xdecls0 ++ List( 323 | q"var $localvarname = $xident0", 324 | q""" 325 | while ($localvarname) { 326 | ${transform(body)} 327 | 328 | ..$xdecls1 329 | 330 | $localvarname = $xident1 331 | } 332 | """ 333 | ) else List( 334 | q""" 335 | { 336 | ${transform(body)} 337 | } 338 | 339 | while ($cond) { 340 | ${transform(body)} 341 | } 342 | """ 343 | ) 344 | (decls, q"()") 345 | case q"for (..$enums) $body" => 346 | // for loop 347 | for (e <- enums) new NestedContextValidator().traverse(e) 348 | new NestedContextValidator().traverse(body) 349 | (Nil, tree) 350 | case q"for (..$enums) yield $body" => 351 | // for-yield loop 352 | for (e <- enums) new NestedContextValidator().traverse(e) 353 | new NestedContextValidator().traverse(body) 354 | (Nil, tree) 355 | case q"new { ..$edefs } with ..$bases { $self => ..$stats }" => 356 | // new 357 | if 
(!isCoroutineDef(typer.typeOf(tree))) { 358 | // if this class was not generated from a coroutine declaration, then validate 359 | // the nested context 360 | new NestedContextValidator().traverse(tree) 361 | } 362 | (Nil, tree) 363 | case Block(stats, expr) => 364 | // block 365 | val localvarname = TermName(c.freshName("x")) 366 | val (statdecls, statidents) = stats.map(canonicalize).unzip 367 | val (exprdecls, exprident) = canonicalize(q"$localvarname = $expr") 368 | val tpe = typer.typeOf(expr) 369 | val decls = 370 | List(q"var $localvarname = null.asInstanceOf[${tpe.widen}]") ++ 371 | statdecls.flatten ++ 372 | exprdecls 373 | (decls, q"$localvarname") 374 | case tpt: TypeTree => 375 | // type trees 376 | disallowCoroutinesIn(tpt) 377 | (Nil, tree) 378 | case q"$mods val $v: $tpt = $rhs" => 379 | // val 380 | val (rhsdecls, rhsident) = canonicalize(rhs) 381 | val decls = rhsdecls ++ List(q"$mods val $v: $tpt = $rhsident") 382 | (decls, q"") 383 | case q"$mods var $v: $tpt = $rhs" => 384 | // var 385 | val (rhsdecls, rhsident) = canonicalize(rhs) 386 | val decls = rhsdecls ++ List(q"$mods var $v: $tpt = $rhsident") 387 | (decls, q"") 388 | case q"$mods def $tname[..$tparams](...$paramss): $tpt = $rhs" => 389 | // method 390 | val decls = List( 391 | q""" 392 | $mods def $tname[..$tparams](...$paramss): $tpt = $rhs 393 | """) 394 | new NestedContextValidator().traverse(rhs) 395 | (decls, q"") 396 | case q"$mods type $tpname[..$tparams] = $tpt" => 397 | // type 398 | new NestedContextValidator().traverse(tree) 399 | (Nil, tree) 400 | case q"$_ class $_[..$_] $_(...$_) extends { ..$_ } with ..$_ { $_ => ..$_ }" => 401 | // class 402 | new NestedContextValidator().traverse(tree) 403 | (Nil, tree) 404 | case q"$_ trait $_[..$_] extends { ..$_ } with ..$_ { $_ => ..$_ }" => 405 | // trait 406 | new NestedContextValidator().traverse(tree) 407 | (Nil, tree) 408 | case q"$_ object $_ extends { ..$_ } with ..$_ { $_ => ..$_ }" => 409 | // object 410 | new 
NestedContextValidator().traverse(tree) 411 | (Nil, tree) 412 | case _ => 413 | // empty 414 | // literal 415 | // identifier 416 | // super selection 417 | // this selection 418 | (Nil, tree) 419 | } 420 | 421 | private def transform(tree: Tree)( 422 | implicit typer: ByTreeTyper[c.type] 423 | ): Tree = tree match { 424 | case Block(stats, expr) => 425 | val (statdecls, statidents) = stats.map(canonicalize).unzip 426 | val (exprdecls, exprident) = canonicalize(expr) 427 | q""" 428 | ..${statdecls.flatten} 429 | 430 | ..$exprdecls 431 | 432 | $exprident 433 | """ 434 | case t => 435 | val (decls, ident) = canonicalize(t) 436 | q""" 437 | ..$decls 438 | 439 | $ident 440 | """ 441 | } 442 | 443 | def canonicalizeTree(rawlambda: Tree): Tree = { 444 | val typer = new ByTreeTyper[c.type](c)(rawlambda) 445 | val untypedrawlambda = typer.untypedTree 446 | 447 | // separate to arguments and body 448 | val (args, body) = untypedrawlambda match { 449 | case q"(..$args) => $body" => (args, body) 450 | case t => c.abort(t.pos, "The coroutine takes a single function literal.") 451 | } 452 | 453 | // recursive transform of the body code 454 | val transformedBody = transform(body)(typer) 455 | val untypedtaflambda = q"(..$args) => $transformedBody" 456 | // println(untypedtaflambda) 457 | c.typecheck(untypedtaflambda) 458 | } 459 | } 460 | -------------------------------------------------------------------------------- /src/main/scala/org/coroutines/Coroutine.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.coroutines.common._ 6 | import scala.annotation.tailrec 7 | import scala.annotation.unchecked.uncheckedVariance 8 | import scala.language.experimental.macros 9 | import scala.reflect.macros.whitebox.Context 10 | import scala.util.Failure 11 | import scala.util.Success 12 | import scala.util.Try 13 | 14 | 15 | 16 | trait Coroutine[@specialized Y, R] extends Coroutine.DefMarker[(Y, R)] { 
17 | def $enter(c: Coroutine.Instance[Y, R]): Unit 18 | def $assignyield(c: Coroutine.Instance[Y, R], v: Y): Unit = { 19 | c.$hasYield = true 20 | c.$yield = v 21 | } 22 | def $assignresult(c: Coroutine.Instance[Y, R], v: R): Unit = c.$result = v 23 | def $returnvalue$Z(c: Coroutine.Instance[Y, R], v: Boolean): Unit 24 | def $returnvalue$B(c: Coroutine.Instance[Y, R], v: Byte): Unit 25 | def $returnvalue$S(c: Coroutine.Instance[Y, R], v: Short): Unit 26 | def $returnvalue$C(c: Coroutine.Instance[Y, R], v: Char): Unit 27 | def $returnvalue$I(c: Coroutine.Instance[Y, R], v: Int): Unit 28 | def $returnvalue$F(c: Coroutine.Instance[Y, R], v: Float): Unit 29 | def $returnvalue$J(c: Coroutine.Instance[Y, R], v: Long): Unit 30 | def $returnvalue$D(c: Coroutine.Instance[Y, R], v: Double): Unit 31 | def $returnvalue$L(c: Coroutine.Instance[Y, R], v: Any): Unit 32 | def $ep0(c: Coroutine.Instance[Y, R]): Unit = {} 33 | def $ep1(c: Coroutine.Instance[Y, R]): Unit = {} 34 | def $ep2(c: Coroutine.Instance[Y, R]): Unit = {} 35 | def $ep3(c: Coroutine.Instance[Y, R]): Unit = {} 36 | def $ep4(c: Coroutine.Instance[Y, R]): Unit = {} 37 | def $ep5(c: Coroutine.Instance[Y, R]): Unit = {} 38 | def $ep6(c: Coroutine.Instance[Y, R]): Unit = {} 39 | def $ep7(c: Coroutine.Instance[Y, R]): Unit = {} 40 | def $ep8(c: Coroutine.Instance[Y, R]): Unit = {} 41 | def $ep9(c: Coroutine.Instance[Y, R]): Unit = {} 42 | def $ep10(c: Coroutine.Instance[Y, R]): Unit = {} 43 | def $ep11(c: Coroutine.Instance[Y, R]): Unit = {} 44 | def $ep12(c: Coroutine.Instance[Y, R]): Unit = {} 45 | def $ep13(c: Coroutine.Instance[Y, R]): Unit = {} 46 | def $ep14(c: Coroutine.Instance[Y, R]): Unit = {} 47 | def $ep15(c: Coroutine.Instance[Y, R]): Unit = {} 48 | def $ep16(c: Coroutine.Instance[Y, R]): Unit = {} 49 | def $ep17(c: Coroutine.Instance[Y, R]): Unit = {} 50 | def $ep18(c: Coroutine.Instance[Y, R]): Unit = {} 51 | def $ep19(c: Coroutine.Instance[Y, R]): Unit = {} 52 | def $ep20(c: Coroutine.Instance[Y, R]): 
Unit = {} 53 | def $ep21(c: Coroutine.Instance[Y, R]): Unit = {} 54 | def $ep22(c: Coroutine.Instance[Y, R]): Unit = {} 55 | def $ep23(c: Coroutine.Instance[Y, R]): Unit = {} 56 | def $ep24(c: Coroutine.Instance[Y, R]): Unit = {} 57 | def $ep25(c: Coroutine.Instance[Y, R]): Unit = {} 58 | def $ep26(c: Coroutine.Instance[Y, R]): Unit = {} 59 | def $ep27(c: Coroutine.Instance[Y, R]): Unit = {} 60 | def $ep28(c: Coroutine.Instance[Y, R]): Unit = {} 61 | def $ep29(c: Coroutine.Instance[Y, R]): Unit = {} 62 | } 63 | 64 | 65 | object Coroutine { 66 | private[coroutines] val INITIAL_COSTACK_SIZE = 4 67 | 68 | type SomeY 69 | 70 | type SomeR 71 | 72 | @tailrec 73 | private[coroutines] final def resume[Y, R]( 74 | callsite: Instance[Y, R], actual: Instance[_, _] 75 | ): Boolean = { 76 | val cd = Stack.top(actual.$costack).asInstanceOf[Coroutine[SomeY, SomeR]] 77 | cd.$enter(actual.asInstanceOf[Instance[SomeY, SomeR]]) 78 | if (actual.$target ne null) { 79 | val newactual = actual.$target 80 | actual.$target = null 81 | resume(callsite, newactual) 82 | } else if (actual.$exception ne null) { 83 | callsite.isLive 84 | } else { 85 | callsite.isLive 86 | } 87 | } 88 | 89 | class Instance[@specialized Y, R] { 90 | var $costackptr = 0 91 | var $costack: Array[Coroutine[Y, R]] = 92 | new Array[Coroutine[Y, R]](INITIAL_COSTACK_SIZE) 93 | var $pcstackptr = 0 94 | var $pcstack = new Array[Short](INITIAL_COSTACK_SIZE) 95 | var $refstackptr = 0 96 | var $refstack: Array[AnyRef] = _ 97 | var $valstackptr = 0 98 | var $valstack: Array[Int] = _ 99 | var $target: Instance[Y, _] = null 100 | var $exception: Throwable = null 101 | var $hasYield: Boolean = false 102 | var $yield: Y = null.asInstanceOf[Y] 103 | var $result: R = null.asInstanceOf[R] 104 | 105 | /** Clones the coroutine that this instance is a part of. 106 | * 107 | * @return A new coroutine instance with exactly the same execution state. 
108 | */ 109 | final def snapshot: Instance[Y, R] = { 110 | val frame = new Instance[Y, R] 111 | Stack.copy(this.$costack, frame.$costack) 112 | Stack.copy(this.$pcstack, frame.$pcstack) 113 | Stack.copy(this.$refstack, frame.$refstack) 114 | Stack.copy(this.$valstack, frame.$valstack) 115 | frame.$exception = this.$exception 116 | frame.$hasYield = this.$hasYield 117 | frame.$yield = this.$yield 118 | frame.$result = this.$result 119 | frame 120 | } 121 | 122 | /** Advances the coroutine to the next yield point. 123 | * 124 | * @return `true` if resume can be called again, `false` otherwise. 125 | * @throws CoroutineStoppedException If the coroutine is not live. 126 | */ 127 | final def resume: Boolean = { 128 | if (isLive) { 129 | $hasYield = false 130 | $yield = null.asInstanceOf[Y] 131 | Coroutine.resume[Y, R](this, this) 132 | } else throw new CoroutineStoppedException 133 | } 134 | 135 | /** Calls `resume` until either the coroutine yields a value or returns. 136 | * 137 | * If `pull` returns `true`, then the coroutine has suspended by yielding 138 | * a value and there are more elements to traverse. 139 | * 140 | * Usage: 141 | * 142 | * {{{ 143 | * while (c.pull) c.value 144 | * }}} 145 | * 146 | * @return `false` if the coroutine stopped, `true` otherwise. 147 | * @throws CoroutineStoppedException If the coroutine is not live. 148 | */ 149 | @tailrec 150 | final def pull: Boolean = { 151 | if (isLive) { 152 | if (!resume) false 153 | else if (hasValue) true 154 | else pull 155 | } else throw new CoroutineStoppedException 156 | } 157 | 158 | /** Returns the value yielded by the coroutine. 159 | * 160 | * This method will thrown an exception if the value cannot be accessed. 161 | * 162 | * @return The value yielded by the coroutine, if there is one. 163 | * @throws RuntimeException If the coroutine doesn't have a value or if it 164 | * is not live. 
165 | */ 166 | final def value: Y = { 167 | if (!hasValue) 168 | sys.error("Coroutine has no value, because it did not yield.") 169 | if (!isLive) 170 | sys.error("Coroutine has no value, because it is completed.") 171 | $yield 172 | } 173 | 174 | /** Returns whether or not the coroutine yielded a value. 175 | * 176 | * This value can be accessed via `getValue`. 177 | * 178 | * @return `true` if the coroutine yielded a value, `false` otherwise. 179 | */ 180 | final def hasValue: Boolean = $hasYield 181 | 182 | /** Returns an `Option` instance wrapping the current value of the coroutine, if 183 | * any. 184 | * 185 | * @return `Some(value)` if `hasValue`, `None` otherwise. 186 | */ 187 | final def getValue: Option[Y] = if (hasValue) Some(value) else None 188 | 189 | /** Returns a `Try` instance wrapping this coroutine's value, if any. 190 | * 191 | * The `Try` wraps either the current value of this coroutine or any exceptions 192 | * thrown when trying to get the value. 193 | * 194 | * @return `Success(value)` if `value` does not throw an exception, or 195 | * a `Failure` instance if it does. 196 | */ 197 | final def tryValue: Try[Y] = 198 | try { Success(value) } catch { case t: Throwable => Failure(t) } 199 | 200 | /** The value returned by the coroutine, if the coroutine is completed. 201 | * 202 | * This method will throw an exception if the result cannot be accessed. 203 | * 204 | * '''Note:''' the returned value is not the same as the value yielded 205 | * by the coroutine. The coroutine may yield any number of values during its 206 | * lifetime, but it returns only a single value after it terminates. 207 | * 208 | * @return The return value of the coroutine, if the coroutine is completed. 209 | * @throws RuntimeException If `!isCompleted`. 210 | * @throws Exception If `hasException`. 
211 | */ 212 | final def result: R = { 213 | if (!isCompleted) 214 | sys.error("Coroutine has no result, because it is not completed.") 215 | if ($exception != null) throw $exception 216 | $result 217 | } 218 | 219 | /** Returns whether or not the coroutine completed without an exception. 220 | * 221 | * @return `true` if the coroutine completed without an exception, `false` 222 | * otherwise. 223 | */ 224 | final def hasResult: Boolean = isCompleted && $exception == null 225 | 226 | /** Returns an `Option` wrapping this coroutine's non-exception result, if any. 227 | * 228 | * @return `Some(result)` if `hasResult`, `None` otherwise. 229 | */ 230 | final def getResult: Option[R] = if (hasResult) Some(result) else None 231 | 232 | /** Returns a `Try` object wrapping either the successful result of this 233 | * coroutine or the exception that the coroutine threw. 234 | * 235 | * @return A `Failure` instance if the coroutine has an exception, 236 | * `Try(result)` otherwise. 237 | */ 238 | final def tryResult: Try[R] = { 239 | if ($exception != null) Failure($exception) 240 | else Try(result) 241 | } 242 | 243 | /** Returns whether or not the coroutine completed with an exception. 244 | * 245 | * @return `true` iff `isCompleted` and the coroutine has a non-null 246 | * exception, `false` otherwise. 247 | */ 248 | final def hasException: Boolean = isCompleted && $exception != null 249 | 250 | /** Returns an `Option` object wrapping the exception thrown by this coroutine. 251 | * 252 | * @return If `hasException`, a `Some` instance wrapping the exception thrown by 253 | * this coroutine. Otherwise, `None`. 254 | */ 255 | final def getException: Option[Throwable] = { 256 | if (hasException) Some($exception) 257 | else None 258 | } 259 | 260 | /** Returns `false` iff the coroutine instance completed execution. 261 | * 262 | * This is true if there are either more yield statements or if the 263 | * coroutine has not yet returned its result. 
264 | * 265 | @return `true` if `resume` can be called without an exception being 266 | thrown, `false` otherwise. 267 | */ 268 | final def isLive: Boolean = $costackptr > 0 269 | 270 | /** Returns `true` iff the coroutine instance completed execution. 271 | * 272 | See the documentation for `isLive`. 273 | * 274 | @return `!isLive`. 275 | */ 276 | final def isCompleted: Boolean = !isLive 277 | 278 | /** Returns a string representation of the coroutine's state. 279 | * 280 | Contains less information than `debugString`. 281 | * 282 | @return A string describing the coroutine state. 283 | */ 284 | override def toString = s"Coroutine.Instance" 285 | 286 | /** Returns a string that describes the internal state of the coroutine. 287 | * 288 | Contains more information than `toString`. 289 | * 290 | @return A string containing information about the internal state of the 291 | coroutine. 292 | */ 293 | final def debugString: String = { 294 | def toStackLength[T](stack: Array[T]) = 295 | if (stack != null) stack.length.toString else "" 296 | def toStackString[T](stack: Array[T]) = 297 | if (stack != null) stack.mkString("[", ", ", "]") else "" 298 | s"Coroutine.Instance <\n" + 299 | s" costackptr: ${$costackptr}\n" + 300 | s" costack sz: ${toStackLength($costack)}\n" + 301 | s" pcstackptr: ${$pcstackptr}\n" + 302 | s" pcstack: ${toStackString($pcstack)}\n" + 303 | s" exception: ${$exception}\n" + 304 | s" yield: ${$yield}\n" + 305 | s" result: ${$result}\n" + 306 | s" refstackptr: ${$refstackptr}\n" + 307 | s" refstack: ${toStackString($refstack)}\n" + 308 | s" valstackptr: ${$valstackptr}\n" + 309 | s" valstack: ${toStackString($valstack)}\n" + 310 | s">" 311 | } 312 | } 313 | 314 | trait DefMarker[YR] 315 | 316 | def synthesize(c: Context)(f: c.Tree): c.Tree = { 317 | new Synthesizer[c.type](c).synthesize(f) 318 | } 319 | 320 | def call[T: c.WeakTypeTag](c: Context)(f: c.Tree): c.Tree = { 321 | new Synthesizer[c.type](c).call(f) 322 | } 323 |
abstract class _0[@specialized T, R] extends Coroutine[T, R] { 325 | def apply(): R 326 | def $call(): Instance[T, R] 327 | def $push(c: Instance[T, R]): Unit 328 | override def toString = s"Coroutine._0@${System.identityHashCode(this)}" 329 | } 330 | 331 | abstract class _1[A0, @specialized T, R] extends Coroutine[T, R] { 332 | def apply(a0: A0): R 333 | def $call(a0: A0): Instance[T, R] 334 | def $push(c: Instance[T, R], a0: A0): Unit 335 | override def toString = s"Coroutine._1@${System.identityHashCode(this)}" 336 | } 337 | 338 | abstract class _2[A0, A1, @specialized T, R] extends Coroutine[T, R] { 339 | def apply(a0: A0, a1: A1): R 340 | def $call(a0: A0, a1: A1): Instance[T, R] 341 | def $push(c: Instance[T, R], a0: A0, a1: A1): Unit 342 | override def toString = s"Coroutine._2@${System.identityHashCode(this)}" 343 | } 344 | 345 | abstract class _3[A0, A1, A2, @specialized T, R] extends Coroutine[T, R] { 346 | def apply(a0: A0, a1: A1, a2: A2): R 347 | def $call(a0: A0, a1: A1, a2: A2): Instance[T, R] 348 | def $push(c: Instance[T, R], a0: A0, a1: A1, a2: A2): Unit 349 | override def toString = s"Coroutine._3@${System.identityHashCode(this)}" 350 | } 351 | } 352 | -------------------------------------------------------------------------------- /src/main/scala/org/coroutines/Synthesizer.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.coroutines.common._ 6 | import scala.annotation.tailrec 7 | import scala.collection._ 8 | import scala.language.experimental.macros 9 | import scala.reflect.macros.whitebox.Context 10 | 11 | 12 | 13 | /** Synthesizes all coroutine-related functionality. 
14 | */ 15 | private[coroutines] class Synthesizer[C <: Context](val c: C) 16 | extends Analyzer[C] 17 | with CfgGenerator[C] 18 | with AstCanonicalization[C] { 19 | import c.universe._ 20 | 21 | /* Entry points whose subgraph uid is below this bound are emitted with `override` (they implement a member that already exists in a supertype); the others are emitted as new methods. */ val NUM_PREDEFINED_ENTRY_STUBS = 30 22 | 23 | /** Emits the `$ep<uid>` method that executes `subgraph` against the coroutine instance parameter. */ private def genEntryPoint(cfg: Cfg, subgraph: SubCfg)( 24 | implicit t: Table 25 | ): Tree = { 26 | val body = subgraph.emit(cfg) 27 | val defname = TermName(s"$$ep${subgraph.uid}") 28 | val defdef = if (subgraph.uid < NUM_PREDEFINED_ENTRY_STUBS) q""" 29 | override def $defname( 30 | ${t.names.coroutineParam}: 31 | _root_.org.coroutines.Coroutine.Instance[${t.yieldType}, ${t.returnType}] 32 | ): _root_.scala.Unit = { 33 | $body 34 | } 35 | """ else q""" 36 | def $defname( 37 | ${t.names.coroutineParam}: 38 | _root_.org.coroutines.Coroutine.Instance[${t.yieldType}, ${t.returnType}] 39 | ): _root_.scala.Unit = { 40 | $body 41 | } 42 | """ 43 | defdef 44 | } 45 | 46 | /** Emits one entry point per CFG subgraph, keyed by subgraph uid and ordered by ascending uid. */ private def genEntryPoints(cfg: Cfg)(implicit table: Table): Map[Long, Tree] = { 47 | val entrypoints = for ((orignode, subgraph) <- cfg.subgraphs) yield { 48 | (subgraph.uid, genEntryPoint(cfg, subgraph)) 49 | } 50 | mutable.LinkedHashMap() ++= entrypoints.toSeq.sortBy(_._1) 51 | } 52 | 53 | /** Emits `$enter(c)`, which runs the entry point selected by the pc on top of `c`'s pc stack. With a single entry point it is a direct call, with two it is an `if` on pc 0, and otherwise it is a `@switch` match on the pc. */ private def genEnterMethod(entrypoints: Map[Long, Tree])( 54 | implicit table: Table 55 | ): Tree = { 56 | val rettpt = table.returnType 57 | val yldtpt = table.yieldType 58 | if (entrypoints.size == 1) { 59 | val q"$_ def $ep0($_): _root_.scala.Unit = $_" = entrypoints(0) 60 | 61 | q""" 62 | def $$enter( 63 | c: _root_.org.coroutines.Coroutine.Instance[$yldtpt, $rettpt] 64 | ): _root_.scala.Unit = $ep0(c) 65 | """ 66 | } else if (entrypoints.size == 2) { 67 | val q"$_ def $ep0($_): _root_.scala.Unit = $_" = entrypoints(0) 68 | val q"$_ def $ep1($_): _root_.scala.Unit = $_" = entrypoints(1) 69 | 70 | q""" 71 | def $$enter( 72 | c: _root_.org.coroutines.Coroutine.Instance[$yldtpt, $rettpt] 73 | ): _root_.scala.Unit = { 74 | val pc =
_root_.org.coroutines.common.Stack.top(c.$$pcstack) 75 | if (pc == 0) $ep0(c) else $ep1(c) 76 | } 77 | """ 78 | } else { 79 | val cases = for ((index, defdef) <- entrypoints) yield { 80 | val q"$_ def $ep($_): _root_.scala.Unit = $rhs" = defdef 81 | cq"${index.toShort} => $ep(c)" 82 | } 83 | 84 | q""" 85 | def $$enter( 86 | c: _root_.org.coroutines.Coroutine.Instance[$yldtpt, $rettpt] 87 | ): _root_.scala.Unit = { 88 | val pc: Short = _root_.org.coroutines.common.Stack.top(c.$$pcstack) 89 | (pc: @_root_.scala.annotation.switch) match { 90 | case ..$cases 91 | } 92 | } 93 | """ 94 | } 95 | } 96 | 97 | /** Emits the return-value methods for each primitive type (Boolean through Double) and for `Any`. */ private def genReturnValueMethods(cfg: Cfg)(implicit table: Table): List[Tree] = { 98 | List( 99 | genReturnValueMethod(cfg, typeOf[Boolean]), 100 | genReturnValueMethod(cfg, typeOf[Byte]), 101 | genReturnValueMethod(cfg, typeOf[Short]), 102 | genReturnValueMethod(cfg, typeOf[Char]), 103 | genReturnValueMethod(cfg, typeOf[Int]), 104 | genReturnValueMethod(cfg, typeOf[Float]), 105 | genReturnValueMethod(cfg, typeOf[Long]), 106 | genReturnValueMethod(cfg, typeOf[Double]), 107 | genReturnValueMethod(cfg, typeOf[Any]) 108 | ) 109 | } 110 | 111 | /** Emits the return-value method for `tpe`, taking the instance `c` and a value `v`. For every coroutine application in the CFG it pairs the uid of the subgraph that follows the call (used as the pc) with a tree that stores `v` into the call's variable. That tree is used when the variable's value type is `tpe` or `tpe` is `Any`; otherwise an error-raising tree is used. */ private def genReturnValueMethod(cfg: Cfg, tpe: Type)(implicit table: Table): Tree = { 112 | /* Returns (pc of the subgraph following call `n`, tree handling `v` for that call). */ def genReturnValueStore(n: Node) = { 113 | val sub = cfg.subgraphs(n.successors.head) 114 | val pcvalue = sub.uid 115 | val info = table(n.tree.symbol) 116 | val eligible = 117 | (isValType(info.tpe) && (info.tpe =:= tpe)) || 118 | (tpe =:= typeOf[Any]) 119 | if (eligible) { 120 | if (info.tpe =:= typeOf[Unit]) { 121 | (pcvalue, q"()") 122 | } else { 123 | val valuetree = 124 | if (tpe =:= typeOf[Any]) q"v.asInstanceOf[${info.tpe}]" else q"v" 125 | val rvset = info.storeTree(q"c", valuetree) 126 | (pcvalue, q"$rvset") 127 | } 128 | } else { 129 | (pcvalue, 130 | q"""_root_.scala.sys.error("Return method called for incorrect type.")""") 131 | } 132 | } 133 | val returnstores = cfg.start.dfs.collect { 134 | case n @ Node.ApplyCoroutine(_, _,
_) => genReturnValueStore(n) 135 | } 136 | 137 | val returnvaluemethod = returnValueMethodName(tpe) 138 | /* Dispatch on the current pc: no call sites give a no-op, one is stored unconditionally, two use an `if` on the pc, and more use a `@switch` match on the pc. */ val body = { 139 | if (returnstores.size == 0) { 140 | q"()" 141 | } else if (returnstores.size == 1) { 142 | returnstores(0)._2 143 | } else if (returnstores.size == 2) { 144 | q""" 145 | val pc = _root_.org.coroutines.common.Stack.top(c.$$pcstack) 146 | if (pc == ${returnstores(0)._1.toShort}) { 147 | ${returnstores(0)._2} 148 | } else { 149 | ${returnstores(1)._2} 150 | } 151 | """ 152 | } else { 153 | val cases = for ((pcvalue, rvset) <- returnstores) yield { 154 | cq"${pcvalue.toShort} => $rvset" 155 | } 156 | q""" 157 | val pc = _root_.org.coroutines.common.Stack.top(c.$$pcstack) 158 | (pc: @_root_.scala.annotation.switch) match { 159 | case ..$cases 160 | } 161 | """ 162 | } 163 | } 164 | 165 | q""" 166 | def $returnvaluemethod( 167 | c: _root_.org.coroutines.Coroutine.Instance[ 168 | ${table.yieldType}, ${table.returnType}], 169 | v: $tpe 170 | ): _root_.scala.Unit = { 171 | $body 172 | } 173 | """ 174 | } 175 | 176 | /** Generates the statements that push (spliced into `$push`) and pop (spliced into `$pop`) the coroutine's stored variables on the instance's ref and val stacks; returns `(pushes, pops)`. */ def genVarPushesAndPops(cfg: Cfg)(implicit table: Table): (List[Tree], List[Tree]) = { 177 | val stackVars = cfg.stackVars 178 | val storedValVars = cfg.storedValVars 179 | val storedRefVars = cfg.storedRefVars 180 | /* Total number of stack slots occupied by `vs`. */ def stackSize(vs: Map[Symbol, VarInfo]) = vs.map(_._2.stackpos._2).sum 181 | /* For the variables of `allvars` that are also stack variables, emits one `bulkPush` on `stack` followed by stores of the argument values into their slots. The third `bulkPush` argument is max(table.initialStackSize, varsize), presumably an allocation size hint (TODO confirm against Stack.bulkPush). */ def genVarPushes(allvars: Map[Symbol, VarInfo], stack: Tree): List[Tree] = { 182 | val vars = allvars.filter(kv => stackVars.contains(kv._1)) 183 | val varsize = stackSize(vars) 184 | val stacksize = math.max(table.initialStackSize, varsize) 185 | val bulkpushes = if (vars.size == 0) Nil else List(q""" 186 | _root_.org.coroutines.common.Stack.bulkPush($stack, $varsize, $stacksize) 187 | """) 188 | val args = vars.values.filter(_.isArg).toList 189 | val argstores = for (a <- args) yield a.storeTree(q"$$c", q"${a.name}") 190 | bulkpushes ::: argstores 191 | } 192 | val varpushes = { 193 | genVarPushes(storedRefVars, q"$$c.$$refstack") ++
genVarPushes(storedValVars, q"$$c.$$valstack") 195 | } 196 | /* Ref-stack variables are popped individually via their `popTree`; the val stack is released with a single `bulkPop`. NOTE(review): pushes only reserve slots for variables that are also in `cfg.stackVars`, whereas this bulk pop uses the unfiltered size of `storedValVars` -- confirm the two sets always coincide. */ val varpops = (for ((sym, info) <- storedRefVars.toList) yield { 197 | info.popTree 198 | }) ++ (if (storedValVars.size == 0) Nil else List( 199 | q""" 200 | _root_.org.coroutines.common.Stack.bulkPop( 201 | $$c.$$valstack, ${stackSize(storedValVars)}) 202 | """ 203 | )) 204 | (varpushes, varpops) 205 | } 206 | 207 | /** Chooses the blueprint class for a one-argument coroutine and returns its type tree together with its type arguments. Boolean and Byte arguments use the generic `Coroutine._1`; Short, Char, Int, Float, Long and Double use the matching `_1$spec$X` trait, which fixes the argument type; any other type uses `_1$spec$L`. Every reference is `_root_`-qualified, as in the rest of the generated code, so that an `org` visible at the expansion site cannot capture it. */ def specArity1( 208 | argtpts: List[Tree], yldtpt: Tree, rettpt: Tree 209 | ): (Tree, List[Tree]) = { 210 | val tpe = argtpts(0).tpe 211 | if (tpe == typeOf[scala.Boolean]) { 212 | (tq"_root_.org.coroutines.Coroutine._1", argtpts :+ yldtpt :+ rettpt) 213 | } else if (tpe == typeOf[scala.Byte]) { 214 | (tq"_root_.org.coroutines.Coroutine._1", argtpts :+ yldtpt :+ rettpt) 215 | } else if (tpe == typeOf[scala.Short]) { 216 | val nme = TypeName(s"_1$$spec$$S") 217 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 218 | } else if (tpe == typeOf[scala.Char]) { 219 | val nme = TypeName(s"_1$$spec$$C") 220 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 221 | } else if (tpe == typeOf[scala.Int]) { 222 | val nme = TypeName(s"_1$$spec$$I") 223 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 224 | } else if (tpe == typeOf[scala.Float]) { 225 | val nme = TypeName(s"_1$$spec$$F") 226 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 227 | } else if (tpe == typeOf[scala.Long]) { 228 | val nme = TypeName(s"_1$$spec$$J") 229 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 230 | } else if (tpe == typeOf[scala.Double]) { 231 | val nme = TypeName(s"_1$$spec$$D") 232 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 233 | } else { 234 | val nme = TypeName(s"_1$$spec$$L") 235 | (tq"_root_.org.coroutines.$nme", argtpts :+ yldtpt :+ rettpt) 236 | } 237 | } 238 | 239 | /** Chooses the blueprint class for a two-argument coroutine: a `_2$spec$XY` trait, where each of X and Y is I, J or D for an Int, Long or Double argument, and L (kept as a type argument) for any other type. */ def specArity2( 240 | argtpts: List[Tree], yldtpt: Tree, rettpt: Tree 241 | ): (Tree, List[Tree]) = { 242 | val (tp0, tp1) = (argtpts(0).tpe, argtpts(1).tpe) 243 | if (tp0 == typeOf[scala.Int] && tp1 == typeOf[scala.Int]) { 244 | val
nme = TypeName(s"_2$$spec$$II") 245 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 246 | } else if (tp0 == typeOf[Long] && tp1 == typeOf[Int]) { 247 | val nme = TypeName(s"_2$$spec$$JI") 248 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 249 | } else if (tp0 == typeOf[Double] && tp1 == typeOf[Int]) { 250 | val nme = TypeName(s"_2$$spec$$DI") 251 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 252 | } else if (tp1 == typeOf[Int]) { /* first argument of any other type: kept generic and passed as the T0 type argument */ 253 | val nme = TypeName(s"_2$$spec$$LI") 254 | (tq"_root_.org.coroutines.$nme", argtpts(0) :: yldtpt :: rettpt :: Nil) 255 | } else if (tp0 == typeOf[Int] && tp1 == typeOf[Long]) { 256 | val nme = TypeName(s"_2$$spec$$IJ") 257 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 258 | } else if (tp0 == typeOf[Long] && tp1 == typeOf[Long]) { 259 | val nme = TypeName(s"_2$$spec$$JJ") 260 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 261 | } else if (tp0 == typeOf[Double] && tp1 == typeOf[Long]) { 262 | val nme = TypeName(s"_2$$spec$$DJ") 263 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 264 | } else if (tp1 == typeOf[Long]) { /* first argument of any other type: kept generic as T0 */ 265 | val nme = TypeName(s"_2$$spec$$LJ") 266 | (tq"_root_.org.coroutines.$nme", argtpts(0) :: yldtpt :: rettpt :: Nil) 267 | } else if (tp0 == typeOf[Int] && tp1 == typeOf[Double]) { 268 | val nme = TypeName(s"_2$$spec$$ID") 269 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 270 | } else if (tp0 == typeOf[Long] && tp1 == typeOf[Double]) { 271 | val nme = TypeName(s"_2$$spec$$JD") 272 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 273 | } else if (tp0 == typeOf[Double] && tp1 == typeOf[Double]) { 274 | val nme = TypeName(s"_2$$spec$$DD") 275 | (tq"_root_.org.coroutines.$nme", yldtpt :: rettpt :: Nil) 276 | } else if (tp1 == typeOf[Double]) { /* first argument of any other type: kept generic as T0 */ 277 | val nme = TypeName(s"_2$$spec$$LD") 278 | (tq"_root_.org.coroutines.$nme", argtpts(0) :: yldtpt :: rettpt :: Nil) 279 | } else if (tp0 == typeOf[Int]) { /* second argument is neither Int, Long nor Double: kept generic as T1 */ 280 | val nme =
TypeName(s"_2$$spec$$IL") 281 | (tq"_root_.org.coroutines.$nme", argtpts(1) :: yldtpt :: rettpt :: Nil) 282 | } else if (tp0 == typeOf[Long]) { 283 | val nme = TypeName(s"_2$$spec$$JL") 284 | (tq"_root_.org.coroutines.$nme", argtpts(1) :: yldtpt :: rettpt :: Nil) 285 | } else if (tp0 == typeOf[Double]) { 286 | val nme = TypeName(s"_2$$spec$$DL") 287 | (tq"_root_.org.coroutines.$nme", argtpts(1) :: yldtpt :: rettpt :: Nil) 288 | } else { 289 | val nme = TypeName(s"_2$$spec$$LL") 290 | val tpes = argtpts(0) :: argtpts(1) :: yldtpt :: rettpt :: Nil 291 | (tq"_root_.org.coroutines.$nme", tpes) 292 | } 293 | } 294 | 295 | /** Returns the blueprint class type tree and its type arguments for a coroutine with the given argument types. One and two arguments use the specialized traits; zero or three or more use `Coroutine._N`. NOTE(review): only `Coroutine._0` to `_3` are declared, so four or more arguments would reference a non-existent class -- confirm this is rejected earlier. The final `sys.error` branch is unreachable. */ def genCoroutineTpe( 296 | argtpts: List[Tree], yldtpt: Tree, rettpt: Tree 297 | ): (Tree, List[Tree]) = { 298 | if (argtpts.length == 1) { 299 | specArity1(argtpts, yldtpt, rettpt) 300 | } else if (argtpts.length == 2) { 301 | specArity2(argtpts, yldtpt, rettpt) 302 | } else if (argtpts.length == 0 || argtpts.length > 2) { 303 | val nme = TypeName(s"_${argtpts.size}") 304 | (tq"_root_.org.coroutines.Coroutine.$nme", argtpts :+ yldtpt :+ rettpt) 305 | } else sys.error("Unreachable case.") 306 | } 307 | 308 | /** Expands `coroutine { (args) => body }`. It canonicalizes the lambda, builds its control-flow graph, and emits an anonymous blueprint subclass with `$call`, `apply`, `$push`, `$pop`, `$enter`, the entry points and the return-value methods. */ def synthesize(rawlambda: Tree): Tree = { 309 | // transform to two operand assignment form 310 | val typedtaflambda = canonicalizeTree(rawlambda) 311 | // println(typedtaflambda) 312 | // println(typedtaflambda.tpe) 313 | 314 | implicit val table = new Table(typedtaflambda) 315 | 316 | // ensure that argument is a function literal 317 | val q"(..$args) => $body" = typedtaflambda 318 | val argidents = for (arg <- args) yield { 319 | val q"$_ val $argname: $_ = $_" = arg 320 | q"$argname" 321 | } 322 | 323 | // extract argument names and types 324 | val (argnames, argtpts) = (for (arg <- args) yield { 325 | val q"$_ val $name: $tpt = $_" = arg 326 | (name, tpt) 327 | }).unzip 328 | 329 | // infer coroutine return type 330 | val rettpt = table.returnType 331 | val yldtpt = table.yieldType 332 | 333 | // generate control flow graph 334 | val cfg =
genControlFlowGraph(args, body, rettpt) 335 | 336 | // generate entry points from yields and coroutine applications 337 | val entrypoints = genEntryPoints(cfg) 338 | 339 | // generate entry method 340 | val entermethod = genEnterMethod(entrypoints) 341 | 342 | // generate return value method 343 | val returnvaluemethods = genReturnValueMethods(cfg) 344 | 345 | // generate variable pushes and pops for stack variables 346 | val (varpushes, varpops) = genVarPushesAndPops(cfg) 347 | 348 | // emit coroutine instantiation 349 | val (coroutinequal, tparams) = genCoroutineTpe(argtpts, yldtpt, rettpt) 350 | val entrypointmethods = entrypoints.map(_._2) 351 | /* fresh name for the new instance inside the generated `$call`, so it cannot clash with user identifiers */ val valnme = TermName(c.freshName("c")) 352 | val co = q""" 353 | new $coroutinequal[..$tparams] { 354 | def $$call( 355 | ..$args 356 | ): _root_.org.coroutines.Coroutine.Instance[$yldtpt, $rettpt] = { 357 | val $valnme = new _root_.org.coroutines.Coroutine.Instance[$yldtpt, $rettpt] 358 | $$push($valnme, ..$argidents) 359 | $valnme 360 | } 361 | def apply(..$args): $rettpt = { 362 | _root_.scala.sys.error( 363 | _root_.org.coroutines.COROUTINE_DIRECT_APPLY_ERROR_MESSAGE) 364 | } 365 | def $$push( 366 | $$c: _root_.org.coroutines.Coroutine.Instance[$yldtpt, $rettpt], ..$args 367 | ): _root_.scala.Unit = { 368 | _root_.org.coroutines.common.Stack.push($$c.$$costack, this, -1) 369 | _root_.org.coroutines.common.Stack.push($$c.$$pcstack, 0.toShort, -1) 370 | ..$varpushes 371 | } 372 | def $$pop( 373 | $$c: _root_.org.coroutines.Coroutine.Instance[$yldtpt, $rettpt] 374 | ): _root_.scala.Unit = { 375 | _root_.org.coroutines.common.Stack.pop($$c.$$pcstack) 376 | _root_.org.coroutines.common.Stack.pop($$c.$$costack) 377 | ..$varpops 378 | } 379 | $entermethod 380 | ..$entrypointmethods 381 | ..$returnvaluemethods 382 | } 383 | """ 384 | // println(co) 385 | co 386 | } 387 | 388 | /** Expands `call(co.apply(args))` into `co.$call[..](args)`, taking the type arguments from `coroutineMethodArgs`. It aborts compilation if the receiver is not a coroutine or the argument is not an invocation. */ def call[R: WeakTypeTag](tree: Tree): Tree = { 389 | val (receiver, args) = tree match { 390 | case q"$r.apply(..$args)" => 391 | if
(!isCoroutineDefMarker(r.tpe)) 392 | c.abort(r.pos, 393 | s"Receiver must be a coroutine.\n" + 394 | s"required: Coroutine[_, ${implicitly[WeakTypeTag[R]]}]\n" + 395 | s"found: ${r.tpe} (with underlying type ${r.tpe.widen})") 396 | (r, args) 397 | case q"$r.apply[..$_](..$args)(..$_)" => 398 | if (!isCoroutineDefSugar(r.tpe)) 399 | c.abort(r.pos, 400 | s"Receiver must be a coroutine.\n" + 401 | s"required: Coroutine[_, ${implicitly[WeakTypeTag[R]]}]\n" + 402 | s"found: ${r.tpe} (with underlying type ${r.tpe.widen})") 403 | (r, args) 404 | case _ => /* NOTE(review): the usage example in this message appears to have lost its `<...>` placeholders -- confirm against upstream. */ 405 | c.abort( 406 | tree.pos, 407 | "The call statement must take a coroutine invocation expression:\n" + 408 | " call(.apply(, ..., ))") 409 | } 410 | val tpargs = coroutineMethodArgs(receiver.tpe) 411 | 412 | val t = q""" 413 | $receiver.$$call[..$tpargs](..$args) 414 | """ 415 | t 416 | } 417 | } 418 | -------------------------------------------------------------------------------- /src/main/scala/org/coroutines/package.scala: -------------------------------------------------------------------------------- 1 | package org 2 | 3 | 4 | 5 | import scala.annotation.implicitNotFound 6 | import scala.language.experimental.macros 7 | import scala.reflect.macros.whitebox.Context 8 | 9 | 10 | 11 | /** Public coroutine API: the `coroutine` and `call` macros, the `yieldval`/`yieldto` markers, and the `~~~>`, `~~>`, `~>` syntax-sugar wrappers with their implicit conversions. */ package object coroutines { 12 | 13 | /** Message raised by the `apply` stubs of coroutine blueprints and sugar wrappers, which cannot run a coroutine directly. */ val COROUTINE_DIRECT_APPLY_ERROR_MESSAGE = 14 | "Coroutines can only be invoked directly from within other coroutines. " + 15 | "Use `call((, ..., ))` instead if you want to " + 16 | "start a new coroutine."
17 | 18 | case class CoroutineStoppedException() extends Exception 19 | 20 | /** Yields `x` from the enclosing coroutine body. Called outside a coroutine, this stub throws via `sys.error`. */ def yieldval[T](x: T): Unit = { 21 | sys.error("Yield allowed only inside coroutines.") 22 | } 23 | 24 | /** Yields control to the coroutine instance `f` from the enclosing coroutine body. Called outside a coroutine, this stub throws via `sys.error`. */ def yieldto[T](f: Coroutine.Instance[T, _]): Unit = { 25 | sys.error("Yield allowed only inside coroutines.") 26 | } 27 | 28 | /** Creates a new coroutine instance from an invocation such as `call(co(args))`; expanded by the `Coroutine.call` macro. */ def call[R](f: R): Any = macro Coroutine.call[R] 29 | 30 | /** Turns a function literal into a coroutine blueprint; expanded by the `Coroutine.synthesize` macro. */ def coroutine[Y, R](f: Any): Any = macro Coroutine.synthesize 31 | 32 | /* syntax sugar */ 33 | 34 | type <~>[Y, R] = Coroutine.Instance[Y, R] 35 | 36 | /** Sugar wrapper for zero-argument coroutines; `$call` and `$push` forward to the wrapped blueprint, cast to `Coroutine._0`. */ class ~~~>[@specialized S, R] private[coroutines] ( 37 | val blueprint: Coroutine[S, R] 38 | ) extends Coroutine.DefMarker[(S, R)] { 39 | def apply(): R = 40 | sys.error(COROUTINE_DIRECT_APPLY_ERROR_MESSAGE) 41 | def $call(): Coroutine.Instance[S, R] = 42 | blueprint.asInstanceOf[Coroutine._0[S, R]].$call() 43 | def $push(co: Coroutine.Instance[S, R]): Unit = 44 | blueprint.asInstanceOf[Coroutine._0[S, R]].$push(co) 45 | } 46 | 47 | /** Sugar wrapper for one-argument coroutines. `T` is the argument type and `YR` the (yield, return) pair; calls forward to the blueprint cast to `Coroutine._1`. */ class ~~>[T, YR] private[coroutines] ( 48 | val blueprint: Coroutine.DefMarker[YR] 49 | ) extends Coroutine.DefMarker[YR] { 50 | def apply[@specialized S, R](t: T)(implicit e: (S, R) =:= YR): R = 51 | sys.error(COROUTINE_DIRECT_APPLY_ERROR_MESSAGE) 52 | def $call[@specialized S, R](t: T)( 53 | implicit e: (S, R) =:= YR 54 | ): Coroutine.Instance[S, R] = { 55 | blueprint.asInstanceOf[Coroutine._1[T, S, R]].$call(t) 56 | } 57 | def $push[@specialized S, R](co: Coroutine.Instance[S, R], t: T)( 58 | implicit e: (S, R) =:= YR 59 | ): Unit = { 60 | blueprint.asInstanceOf[Coroutine._1[T, S, R]].$push(co, t) 61 | } 62 | } 63 | 64 | /** Sugar wrapper for two- and three-argument coroutines. `PS` is the tuple of argument types, and the `PS =:= TupleN` evidence selects the matching overload. */ class ~>[PS, YR] private[coroutines] ( 65 | val blueprint: Coroutine.DefMarker[YR] 66 | ) extends Coroutine.DefMarker[YR] { 67 | def apply[T1, T2, @specialized S, R](t1: T1, t2: T2)( 68 | implicit ps: PS =:= Tuple2[T1, T2], yr: (S, R) =:= YR 69 | ): R = { 70 | sys.error(COROUTINE_DIRECT_APPLY_ERROR_MESSAGE) 71 | } 72 | def $call[T1, T2, @specialized S, R](t1: T1, t2: T2)( 73 | implicit ps: PS =:=
Tuple2[T1, T2], yr: (S, R) =:= YR 74 | ): Coroutine.Instance[S, R] = { 75 | blueprint.asInstanceOf[Coroutine._2[T1, T2, S, R]].$call(t1, t2) 76 | } 77 | def $push[T1, T2, @specialized S, R](co: Coroutine.Instance[S, R], t1: T1, t2: T2)( 78 | implicit ps: PS =:= Tuple2[T1, T2], yr: (S, R) =:= YR 79 | ): Unit = { 80 | blueprint.asInstanceOf[Coroutine._2[T1, T2, S, R]].$push(co, t1, t2) 81 | } 82 | /* Three-argument overloads: selected when `PS` is a `Tuple3`. */ def apply[T1, T2, T3, @specialized S, R](t1: T1, t2: T2, t3: T3)( 83 | implicit ps: PS =:= Tuple3[T1, T2, T3], yr: (S, R) =:= YR 84 | ): R = { 85 | sys.error(COROUTINE_DIRECT_APPLY_ERROR_MESSAGE) 86 | } 87 | def $call[T1, T2, T3, @specialized S, R](t1: T1, t2: T2, t3: T3)( 88 | implicit ps: PS =:= Tuple3[T1, T2, T3], yr: (S, R) =:= YR 89 | ): Coroutine.Instance[S, R] = { 90 | blueprint.asInstanceOf[Coroutine._3[T1, T2, T3, S, R]].$call(t1, t2, t3) 91 | } 92 | def $push[T1, T2, T3, @specialized S, R]( 93 | co: Coroutine.Instance[S, R], t1: T1, t2: T2, t3: T3 94 | )( 95 | implicit ps: PS =:= Tuple3[T1, T2, T3], yr: (S, R) =:= YR 96 | ): Unit = { 97 | blueprint.asInstanceOf[Coroutine._3[T1, T2, T3, S, R]].$push(co, t1, t2, t3) 98 | } 99 | } 100 | 101 | /** Implicit conversions from blueprints to the sugar wrappers. Each arity has a separate `...nothing` variant for coroutines whose yield type is `Nothing` -- NOTE(review): presumably to guide type inference; confirm. */ implicit def coroutine0nothing[R](b: Coroutine._0[Nothing, R]) = 102 | new ~~~>[Nothing, R](b) 103 | 104 | implicit def coroutine0[@specialized S, R](b: Coroutine._0[S, R]) = 105 | new ~~~>[S, R](b) 106 | 107 | implicit def coroutine1nothing[T, R](b: Coroutine._1[T, Nothing, R]) = 108 | new ~~>[T, (Nothing, R)](b) 109 | 110 | implicit def coroutine1[T, @specialized S, R](b: Coroutine._1[T, S, R]) = 111 | new ~~>[T, (S, R)](b) 112 | 113 | implicit def coroutine2nothing[T1, T2, R]( 114 | b: Coroutine._2[T1, T2, Nothing, R] 115 | ) = { 116 | new ~>[Tuple2[T1, T2], (Nothing, R)](b) 117 | } 118 | 119 | implicit def coroutine2[T1, T2, @specialized S, R](b: Coroutine._2[T1, T2, S, R]) = 120 | new ~>[Tuple2[T1, T2], (S, R)](b) 121 | 122 | implicit def coroutine3nothing[T1, T2, T3, R]( 123 | b: Coroutine._3[T1, T2, T3, Nothing, R] 124 |
) = { 125 | new ~>[Tuple3[T1, T2, T3], (Nothing, R)](b) 126 | } 127 | 128 | implicit def coroutine3[T1, T2, T3, @specialized S, R]( 129 | b: Coroutine._3[T1, T2, T3, S, R] 130 | ) = { 131 | new ~>[Tuple3[T1, T2, T3], (S, R)](b) 132 | } 133 | } 134 | -------------------------------------------------------------------------------- /src/main/scala/org/coroutines/specializations.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | 6 | 7 | 8 | /* Coroutine._1 specializations: one trait per primitive argument type (S, C, I, F, J, D for Short, Char, Int, Float, Long, Double) and `_1$spec$L` for any other type. Each re-declares the abstract methods at the concrete argument type, while the yield type `S` is itself specialized. The trait is selected by `Synthesizer.specArity1`. */ 9 | 10 | trait _1$spec$S[@specialized(Short, Char, Int, Float, Long, Double) S, R] 11 | extends Coroutine._1[Short, S, R] { 12 | def apply(a0: Short): R 13 | def $call(a0: Short): Coroutine.Instance[S, R] 14 | def $push(c: Coroutine.Instance[S, R], a0: Short): Unit 15 | } 16 | 17 | 18 | trait _1$spec$C[@specialized(Short, Char, Int, Float, Long, Double) S, R] 19 | extends Coroutine._1[Char, S, R] { 20 | def apply(a0: Char): R 21 | def $call(a0: Char): Coroutine.Instance[S, R] 22 | def $push(c: Coroutine.Instance[S, R], a0: Char): Unit 23 | } 24 | 25 | 26 | trait _1$spec$I[@specialized(Short, Char, Int, Float, Long, Double) S, R] 27 | extends Coroutine._1[Int, S, R] { 28 | def apply(a0: Int): R 29 | def $call(a0: Int): Coroutine.Instance[S, R] 30 | def $push(c: Coroutine.Instance[S, R], a0: Int): Unit 31 | } 32 | 33 | 34 | trait _1$spec$F[@specialized(Short, Char, Int, Float, Long, Double) S, R] 35 | extends Coroutine._1[Float, S, R] { 36 | def apply(a0: Float): R 37 | def $call(a0: Float): Coroutine.Instance[S, R] 38 | def $push(c: Coroutine.Instance[S, R], a0: Float): Unit 39 | } 40 | 41 | 42 | trait _1$spec$J[@specialized(Short, Char, Int, Float, Long, Double) S, R] 43 | extends Coroutine._1[Long, S, R] { 44 | def apply(a0: Long): R 45 | def $call(a0: Long): Coroutine.Instance[S, R] 46 | def $push(c: Coroutine.Instance[S, R], a0: Long): Unit 47 | } 48 | 49 | 50 | trait _1$spec$D[@specialized(Short, Char, Int, Float,
Long, Double) S, R] 51 | extends Coroutine._1[Double, S, R] { 52 | def apply(a0: Double): R 53 | def $call(a0: Double): Coroutine.Instance[S, R] 54 | def $push(c: Coroutine.Instance[S, R], a0: Double): Unit 55 | } 56 | 57 | /** Fallback for argument types without a dedicated trait: the argument type stays generic as `T0`. */ trait _1$spec$L[T0, @specialized(Short, Char, Int, Float, Long, Double) S, R] 58 | extends Coroutine._1[T0, S, R] { 59 | def apply(a0: T0): R 60 | def $call(a0: T0): Coroutine.Instance[S, R] 61 | def $push(c: Coroutine.Instance[S, R], a0: T0): Unit 62 | } 63 | 64 | /* Coroutine._2 specializations: one trait per (first, second) argument pair over Int (I), Long (J), Double (D) and other types (L, kept generic as T0/T1). The yield type `S` is specialized for Int, Long and Double. The trait is selected by `Synthesizer.specArity2`. */ 65 | 66 | trait _2$spec$II[@specialized(Int, Long, Double) S, R] 67 | extends Coroutine._2[Int, Int, S, R] { 68 | def apply(a0: Int, a1: Int): R 69 | def $call(a0: Int, a1: Int): Coroutine.Instance[S, R] 70 | def $push(c: Coroutine.Instance[S, R], a0: Int, a1: Int): Unit 71 | } 72 | 73 | trait _2$spec$JI[@specialized(Int, Long, Double) S, R] 74 | extends Coroutine._2[Long, Int, S, R] { 75 | def apply(a0: Long, a1: Int): R 76 | def $call(a0: Long, a1: Int): Coroutine.Instance[S, R] 77 | def $push(c: Coroutine.Instance[S, R], a0: Long, a1: Int): Unit 78 | } 79 | 80 | trait _2$spec$DI[@specialized(Int, Long, Double) S, R] 81 | extends Coroutine._2[Double, Int, S, R] { 82 | def apply(a0: Double, a1: Int): R 83 | def $call(a0: Double, a1: Int): Coroutine.Instance[S, R] 84 | def $push(c: Coroutine.Instance[S, R], a0: Double, a1: Int): Unit 85 | } 86 | 87 | /** Generic first argument (T0), Int second argument. */ trait _2$spec$LI[T0, @specialized(Int, Long, Double) S, R] 88 | extends Coroutine._2[T0, Int, S, R] { 89 | def apply(a0: T0, a1: Int): R 90 | def $call(a0: T0, a1: Int): Coroutine.Instance[S, R] 91 | def $push(c: Coroutine.Instance[S, R], a0: T0, a1: Int): Unit 92 | } 93 | 94 | trait _2$spec$IJ[@specialized(Int, Long, Double) S, R] 95 | extends Coroutine._2[Int, Long, S, R] { 96 | def apply(a0: Int, a1: Long): R 97 | def $call(a0: Int, a1: Long): Coroutine.Instance[S, R] 98 | def $push(c: Coroutine.Instance[S, R], a0: Int, a1: Long): Unit 99 | } 100 | 101 | trait _2$spec$JJ[@specialized(Int, Long,
Double) S, R] 102 | extends Coroutine._2[Long, Long, S, R] { 103 | def apply(a0: Long, a1: Long): R 104 | def $call(a0: Long, a1: Long): Coroutine.Instance[S, R] 105 | def $push(c: Coroutine.Instance[S, R], a0: Long, a1: Long): Unit 106 | } 107 | 108 | trait _2$spec$DJ[@specialized(Int, Long, Double) S, R] 109 | extends Coroutine._2[Double, Long, S, R] { 110 | def apply(a0: Double, a1: Long): R 111 | def $call(a0: Double, a1: Long): Coroutine.Instance[S, R] 112 | def $push(c: Coroutine.Instance[S, R], a0: Double, a1: Long): Unit 113 | } 114 | 115 | /** Generic first argument (T0), Long second argument. */ trait _2$spec$LJ[T0, @specialized(Int, Long, Double) S, R] 116 | extends Coroutine._2[T0, Long, S, R] { 117 | def apply(a0: T0, a1: Long): R 118 | def $call(a0: T0, a1: Long): Coroutine.Instance[S, R] 119 | def $push(c: Coroutine.Instance[S, R], a0: T0, a1: Long): Unit 120 | } 121 | 122 | trait _2$spec$ID[@specialized(Int, Long, Double) S, R] 123 | extends Coroutine._2[Int, Double, S, R] { 124 | def apply(a0: Int, a1: Double): R 125 | def $call(a0: Int, a1: Double): Coroutine.Instance[S, R] 126 | def $push(c: Coroutine.Instance[S, R], a0: Int, a1: Double): Unit 127 | } 128 | 129 | trait _2$spec$JD[@specialized(Int, Long, Double) S, R] 130 | extends Coroutine._2[Long, Double, S, R] { 131 | def apply(a0: Long, a1: Double): R 132 | def $call(a0: Long, a1: Double): Coroutine.Instance[S, R] 133 | def $push(c: Coroutine.Instance[S, R], a0: Long, a1: Double): Unit 134 | } 135 | 136 | trait _2$spec$DD[@specialized(Int, Long, Double) S, R] 137 | extends Coroutine._2[Double, Double, S, R] { 138 | def apply(a0: Double, a1: Double): R 139 | def $call(a0: Double, a1: Double): Coroutine.Instance[S, R] 140 | def $push(c: Coroutine.Instance[S, R], a0: Double, a1: Double): Unit 141 | } 142 | 143 | /** Generic first argument (T0), Double second argument. */ trait _2$spec$LD[T0, @specialized(Int, Long, Double) S, R] 144 | extends Coroutine._2[T0, Double, S, R] { 145 | def apply(a0: T0, a1: Double): R 146 | def $call(a0: T0, a1: Double): Coroutine.Instance[S, R] 147 | def $push(c:
Coroutine.Instance[S, R], a0: T0, a1: Double): Unit 148 | } 149 | 150 | /** Int first argument, generic second argument (T1). */ trait _2$spec$IL[T1, @specialized(Int, Long, Double) S, R] 151 | extends Coroutine._2[Int, T1, S, R] { 152 | def apply(a0: Int, a1: T1): R 153 | def $call(a0: Int, a1: T1): Coroutine.Instance[S, R] 154 | def $push(c: Coroutine.Instance[S, R], a0: Int, a1: T1): Unit 155 | } 156 | 157 | /** Long first argument, generic second argument (T1). */ trait _2$spec$JL[T1, @specialized(Int, Long, Double) S, R] 158 | extends Coroutine._2[Long, T1, S, R] { 159 | def apply(a0: Long, a1: T1): R 160 | def $call(a0: Long, a1: T1): Coroutine.Instance[S, R] 161 | def $push(c: Coroutine.Instance[S, R], a0: Long, a1: T1): Unit 162 | } 163 | 164 | /** Double first argument, generic second argument (T1). */ trait _2$spec$DL[T1, @specialized(Int, Long, Double) S, R] 165 | extends Coroutine._2[Double, T1, S, R] { 166 | def apply(a0: Double, a1: T1): R 167 | def $call(a0: Double, a1: T1): Coroutine.Instance[S, R] 168 | def $push(c: Coroutine.Instance[S, R], a0: Double, a1: T1): Unit 169 | } 170 | 171 | /** Fallback when neither argument is Int, Long or Double: both argument types stay generic. */ trait _2$spec$LL[T0, T1, @specialized(Int, Long, Double) S, R] 172 | extends Coroutine._2[T0, T1, S, R] { 173 | def apply(a0: T0, a1: T1): R 174 | def $call(a0: T0, a1: T1): Coroutine.Instance[S, R] 175 | def $push(c: Coroutine.Instance[S, R], a0: T0, a1: T1): Unit 176 | } 177 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/ast-canonicalization-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.util.Failure 7 | 8 | 9 | 10 | /** Tests for coroutine bodies that AST canonicalization must rewrite: blocks and coroutine calls inside conditions, short-circuit operators, do-while loops, and local defs. */ class ASTCanonicalizationTest extends FunSuite with Matchers { 11 | test("if statements with applications") { 12 | val rube = coroutine { () => 13 | if (0 < { math.abs(-1); math.max(1, 2) }) 2 else 1 14 | } 15 | val c = call(rube()) 16 | assert(!c.resume) 17 | assert(c.result == 2) 18 | } 19 | 20 | test("if statements with applications and yield") { 21 | val rube = coroutine { () => 22 | val x = if (0 < { math.abs(-1);
math.max(1, 2) }) 2 else 1 23 | yieldval(x) 24 | -x 25 | } 26 | val c = call(rube()) 27 | assert(c.resume) 28 | assert(c.value == 2) 29 | assert(!c.resume) 30 | assert(c.result == -2) 31 | assert(c.isCompleted) 32 | } 33 | 34 | test("if statements with selections") { 35 | val rube = coroutine { () => 36 | if (0 < { math.abs(math.Pi) }) 2 else 1 37 | } 38 | val c = call(rube()) 39 | assert(!c.resume) 40 | assert(c.result == 2) 41 | } 42 | 43 | test("if statements with selections and yield") { 44 | val rube = coroutine { () => 45 | val x = if (0 < { math.abs(math.Pi) }) 2 else 1 46 | yieldval(x) 47 | -x 48 | } 49 | val c = call(rube()) 50 | assert(c.resume) 51 | assert(c.value == 2) 52 | assert(!c.resume) 53 | assert(c.result == -2) 54 | assert(c.isCompleted) 55 | } 56 | 57 | test("if statements with updates") { 58 | val rube = coroutine { () => 59 | val xs = new Array[Int](2) 60 | if (0 < { xs(0) = 1; xs(0) }) 2 else 1 61 | } 62 | val c = call(rube()) 63 | assert(!c.resume) 64 | assert(c.result == 2) 65 | } 66 | 67 | test("if statements with block in tuple") { 68 | val rube = coroutine { () => 69 | if (0 < ({ math.abs(1); math.abs(3) + 2 }, 2)._1) 2 else 1 70 | } 71 | val c = call(rube()) 72 | assert(!c.resume) 73 | assert(c.result == 2) 74 | } 75 | 76 | test("if statement with another if statement in condition") { 77 | val rube = coroutine { () => 78 | if (0 < (if (math.abs(-1) > 5) 1 else 2)) 2 else 1 79 | } 80 | val c = call(rube()) 81 | assert(!c.resume) 82 | assert(c.result == 2) 83 | } 84 | 85 | /* The body ends in a pattern value declaration; the coroutine must still complete, and reading `result` must leave no exception recorded. */ test("value declaration should be the last statement") { 86 | val unit = coroutine { () => 87 | val t = (2, 3) 88 | val (y, z) = t 89 | } 90 | 91 | val c = call(unit()) 92 | assert(!c.resume) 93 | assert(!c.isLive) 94 | c.result 95 | assert(!c.hasException) 96 | } 97 | 98 | test("coroutine should be callable outside value declaration") { 99 | var y = 0 100 | val setY = coroutine { (x: Int) => y = x } 101 | val setTo5 = coroutine { () => 102 | setY(5) 103 | } 104 | val
c = call(setTo5()) 105 | assert(!c.resume) 106 | assert(y == 5) 107 | } 108 | 109 | test("coroutine should be callable outside value declaration and yield") { 110 | var y = 0 111 | val setY = coroutine { (x: Int) => y = x } 112 | val setTo5 = coroutine { () => 113 | yieldval(setY(5)) 114 | setY(-5) 115 | } 116 | val c = call(setTo5()) 117 | assert(c.resume) 118 | assert(y == 5) 119 | assert(!c.resume) 120 | assert(y == -5) 121 | } 122 | 123 | test("coroutine should yield in while loop with complex condition") { 124 | val rube = coroutine { (x: Int) => 125 | var i = 0 126 | while (i < x && x < math.abs(-15)) { 127 | yieldval(i) 128 | i += 1 129 | } 130 | i 131 | } 132 | val c1 = call(rube(10)) 133 | for (i <- 0 until 10) { 134 | assert(c1.resume) 135 | assert(c1.value == i) 136 | } 137 | assert(!c1.resume) 138 | assert(c1.result == 10) 139 | assert(c1.isCompleted) 140 | /* x = 20 fails `x < math.abs(-15)`, so the loop body never runs */ val c2 = call(rube(20)) 141 | assert(!c2.resume) 142 | assert(c2.result == 0) 143 | assert(c2.isCompleted) 144 | } 145 | 146 | test("coroutine should yield every second element or just zero") { 147 | val rube = coroutine { (x: Int) => 148 | var i = 0 149 | while (i < x && x < math.abs(-15)) { 150 | if (i % 2 == 0) yieldval(i) 151 | i += 1 152 | } 153 | i 154 | } 155 | 156 | val c1 = call(rube(10)) 157 | for (i <- 0 until 10; if i % 2 == 0) { 158 | assert(c1.resume) 159 | assert(c1.value == i) 160 | } 161 | assert(!c1.resume) 162 | assert(c1.result == 10) 163 | assert(c1.isCompleted) 164 | /* x = 20 fails `x < math.abs(-15)`, so nothing is yielded */ val c2 = call(rube(20)) 165 | assert(!c2.resume) 166 | assert(c2.result == 0) 167 | assert(c2.isCompleted) 168 | } 169 | 170 | test("coroutine should yield 1 or yield 10 elements, and then 117") { 171 | val rube = coroutine { (x: Int) => 172 | var i = 1 173 | if (x > math.abs(0)) { 174 | while (i < x) { 175 | yieldval(i) 176 | i += 1 177 | } 178 | } else { 179 | yieldval(i) 180 | } 181 | 117 182 | } 183 | 184 | /* x = 10 yields 1 to 9 (nine values) before returning 117 */ val c1 = call(rube(10)) 185 | for (i <- 1 until 10) { 186 | assert(c1.resume) 187 | assert(c1.value == i)
188 | } 189 | assert(!c1.resume) 190 | assert(c1.result == 117) 191 | assert(c1.isCompleted) 192 | /* x = -10 takes the else branch, which yields exactly once */ val c2 = call(rube(-10)) 193 | assert(c2.resume) 194 | assert(c2.value == 1) 195 | assert(!c2.resume) 196 | assert(c2.result == 117) 197 | assert(c2.isCompleted) 198 | } 199 | 200 | test("yield absolute and original value") { 201 | val rube = coroutine { (x: Int) => 202 | yieldval(math.abs(x)) 203 | x 204 | } 205 | 206 | val c = call(rube(-5)) 207 | assert(c.resume) 208 | assert(c.value == 5) 209 | assert(!c.resume) 210 | assert(c.result == -5) 211 | assert(c.isCompleted) 212 | } 213 | 214 | test("short-circuiting should work for and") { 215 | var state = "untouched" 216 | val rube = coroutine { (x: Int) => 217 | if (x < 0 && { state = "touched"; true }) x 218 | else -x 219 | } 220 | 221 | /* x = 5: `x < 0` is false, so the side-effecting block must be skipped */ val c0 = call(rube(5)) 222 | assert(!c0.resume) 223 | assert(c0.result == -5) 224 | assert(c0.isCompleted) 225 | assert(state == "untouched") 226 | 227 | val c1 = call(rube(-5)) 228 | assert(!c1.resume) 229 | assert(c1.result == -5) 230 | assert(c1.isCompleted) 231 | assert(state == "touched") 232 | } 233 | 234 | test("short-circuiting should work for or") { 235 | var state = "untouched" 236 | val rube = coroutine { (x: Int) => 237 | if (x > 0 || { state = "touched"; false }) x 238 | else -x 239 | } 240 | 241 | /* x = 5: `x > 0` is true, so the side-effecting block must be skipped */ val c0 = call(rube(5)) 242 | assert(!c0.resume) 243 | assert(c0.result == 5) 244 | assert(c0.isCompleted) 245 | assert(state == "untouched") 246 | 247 | val c1 = call(rube(-5)) 248 | assert(!c1.resume) 249 | assert(c1.result == 5) 250 | assert(c1.isCompleted) 251 | assert(state == "touched") 252 | } 253 | 254 | test("do-while should be simplified into a while loop") { 255 | val rube = coroutine { (x: Int) => 256 | var i = 0 257 | do { 258 | yieldval(i) 259 | 260 | i += 1 261 | } while (i < x) 262 | i 263 | } 264 | 265 | val c0 = call(rube(5)) 266 | assert(c0.resume) 267 | assert(c0.value == 0) 268 | assert(c0.resume) 269 | assert(c0.value == 1) 270 | assert(c0.resume) 271 |
assert(c0.value == 2) 272 | assert(c0.resume) 273 | assert(c0.value == 3) 274 | assert(c0.resume) 275 | assert(c0.value == 4) 276 | assert(!c0.resume) 277 | assert(c0.result == 5) 278 | assert(c0.isCompleted) 279 | 280 | val c1 = call(rube(0)) 281 | assert(c1.resume) 282 | assert(c1.value == 0) 283 | assert(!c1.resume) 284 | assert(c1.result == 1) 285 | assert(c1.isCompleted) 286 | } 287 | 288 | test("should be able to define uncalled function inside coroutine") { 289 | val oy = coroutine { () => 290 | def foo(): String = "bar" 291 | val bar = "bar" 292 | 1 293 | } 294 | val c = call(oy()) 295 | assert(!c.resume) 296 | assert(c.hasResult) 297 | assert(c.result == 1) 298 | assert(c.isCompleted) 299 | } 300 | 301 | test("should be able to define called function inside coroutine") { 302 | val oy = coroutine { () => 303 | def foo(): String = "bar" 304 | val bar = foo() 305 | 1 306 | } 307 | val c = call(oy()) 308 | assert(!c.resume) 309 | assert(c.hasResult) 310 | assert(c.result == 1) 311 | assert(c.isCompleted) 312 | } 313 | } 314 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/async-await-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.annotation.unchecked.uncheckedVariance 7 | import scala.concurrent._ 8 | import scala.concurrent.duration._ 9 | import scala.concurrent.ExecutionContext.Implicits.global 10 | import scala.util.Success 11 | 12 | 13 | 14 | object AsyncAwaitTest { 15 | class Cell[+T] { 16 | var x: T @uncheckedVariance = _ 17 | } 18 | 19 | object ToughTypeObject { 20 | class Inner 21 | 22 | def m2 = async(coroutine { () => 23 | val y = await { Future[List[_]] { Nil } } 24 | val z = await { Future[Inner] { new Inner } } 25 | (y, z) 26 | }) 27 | } 28 | 29 | // Doubly defined for ToughTypeObject 30 | def await[R]: Future[R] ~~> ((Future[R], Cell[R]), R) = 31 | 
coroutine { (f: Future[R]) => 32 | val cell = new Cell[R] 33 | yieldval((f, cell)) 34 | cell.x 35 | } 36 | 37 | // Doubly defined for ToughTypeObject 38 | def async[Y, R](body: ~~~>[(Future[Y], Cell[Y]), R]): Future[R] = { 39 | val c = call(body()) 40 | val p = Promise[R] 41 | def loop() { 42 | if (!c.resume) p.success(c.result) 43 | else { 44 | val (future, cell) = c.value 45 | for (x <- future) { 46 | cell.x = x 47 | loop() 48 | } 49 | } 50 | } 51 | Future { loop() } 52 | p.future 53 | } 54 | } 55 | 56 | 57 | class IntWrapper(val value: String) extends AnyVal { 58 | def plusStr = Future.successful(value + "!") 59 | } 60 | 61 | 62 | class ParamWrapper[T](val value: T) extends AnyVal 63 | 64 | 65 | class PrivateWrapper private (private val value: String) extends AnyVal 66 | 67 | 68 | object PrivateWrapper { 69 | def Instance = new PrivateWrapper("") 70 | } 71 | 72 | 73 | class AsyncAwaitTest extends FunSuite with Matchers { 74 | def await[R]: Future[R] ~~> ((Future[R], AsyncAwaitTest.Cell[R]), R) = 75 | coroutine { (f: Future[R]) => 76 | val cell = new AsyncAwaitTest.Cell[R] 77 | yieldval((f, cell)) 78 | cell.x 79 | } 80 | 81 | def async[Y, R](body: ~~~>[(Future[Y], AsyncAwaitTest.Cell[Y]), R]): Future[R] = { 82 | val c = call(body()) 83 | val p = Promise[R] 84 | def loop() { 85 | if (!c.resume) p.success(c.result) 86 | else { 87 | val (future, cell) = c.value 88 | for (x <- future) { 89 | cell.x = x 90 | loop() 91 | } 92 | } 93 | } 94 | Future { loop() } 95 | p.future 96 | } 97 | 98 | // Source: https://git.io/vrHtj 99 | test("propagates tough types") { 100 | val fut = org.coroutines.AsyncAwaitTest.ToughTypeObject.m2 101 | val result: (List[_], org.coroutines.AsyncAwaitTest.ToughTypeObject.Inner) = 102 | Await.result(fut, 2 seconds) 103 | assert(result._1 == Nil) 104 | } 105 | 106 | // Source: https://git.io/vr7H9 107 | test("pattern matching function") { 108 | val c = async(coroutine { () => 109 | await(Future(1)) 110 | val a = await(Future(1)) 111 | val f = { 
case x => x + a }: Function[Int, Int] 112 | await(Future(f(2))) 113 | }) 114 | val res = Await.result(c, 2 seconds) 115 | assert(res == 3) 116 | } 117 | 118 | // Source: https://git.io/vr7HA 119 | test("existential bind 1") { 120 | def m(a: Any) = async(coroutine { () => 121 | a match { 122 | case s: Seq[_] => 123 | val x = s.size 124 | var ss = s 125 | ss = s 126 | await(Future(x)) 127 | } 128 | }) 129 | val res = Await.result(m(Nil), 2 seconds) 130 | assert(res == 0) 131 | } 132 | 133 | // Source: https://git.io/vr7Qm 134 | test("existential bind 2") { 135 | def conjure[T]: T = null.asInstanceOf[T] 136 | 137 | def m1 = AsyncAwaitTest.async(coroutine { () => 138 | val p: List[Option[_]] = conjure[List[Option[_]]] 139 | AsyncAwaitTest.await(Future(1)) 140 | }) 141 | 142 | def m2 = AsyncAwaitTest.async(coroutine { () => 143 | AsyncAwaitTest.await(Future[List[_]](Nil)) 144 | }) 145 | } 146 | 147 | // Source: https://git.io/vr7Fx 148 | test("existential if/else") { 149 | trait Container[+A] 150 | case class ContainerImpl[A](value: A) extends Container[A] 151 | def foo: Future[Container[_]] = AsyncAwaitTest.async(coroutine { () => 152 | val a: Any = List(1) 153 | if (true) { 154 | val buf: Seq[_] = List(1) 155 | val foo = AsyncAwaitTest.await(Future(5)) 156 | val e0 = buf(0) 157 | ContainerImpl(e0) 158 | } else ??? 159 | }) 160 | foo 161 | } 162 | 163 | // Source: https://git.io/vr7ba 164 | test("ticket 63 in scala/async") { 165 | object SomeExecutionContext extends ExecutionContext { 166 | def reportFailure(t: Throwable): Unit = ??? 167 | def execute(runnable: Runnable): Unit = ??? 
168 | } 169 | 170 | trait FunDep[W, S, R] { 171 | def method(w: W, s: S): Future[R] 172 | } 173 | 174 | object FunDep { 175 | implicit def `Something to do with List`[W, S, R] 176 | (implicit funDep: FunDep[W, S, R]) = 177 | new FunDep[W, List[S], W] { 178 | def method(w: W, l: List[S]) = AsyncAwaitTest.async(coroutine { () => 179 | val it = l.iterator 180 | while (it.hasNext) { 181 | AsyncAwaitTest.await(Future(funDep.method(w, it.next())) 182 | (SomeExecutionContext)) 183 | } 184 | w 185 | }) 186 | } 187 | } 188 | } 189 | 190 | // Source: https://git.io/vr7bX 191 | test("ticket 66 in scala/async") { 192 | val e = new Exception() 193 | val f: Future[Nothing] = Future.failed(e) 194 | val f1 = AsyncAwaitTest.async(coroutine { () => 195 | AsyncAwaitTest.await(Future(f)) 196 | }) 197 | try { 198 | Await.result(f1, 5.seconds) 199 | } catch { 200 | case `e` => 201 | } 202 | } 203 | 204 | // Source: https://git.io/vr7Nf 205 | test("ticket 83 in scala/async-- using value class") { 206 | val f = AsyncAwaitTest.async(coroutine { () => 207 | val uid = new IntWrapper("foo") 208 | AsyncAwaitTest.await(Future(Future(uid))) 209 | }) 210 | val outer = Await.result(f, 5.seconds) 211 | val inner = Await.result(outer, 5 seconds) 212 | assert(inner == new IntWrapper("foo")) 213 | } 214 | 215 | // Source: https://git.io/vr7Nk 216 | // test("ticket 86 in scala/async-- using matched value class") { 217 | // def doAThing(param: IntWrapper) = Future(None) 218 | 219 | // val fut = AsyncAwaitTest.async(coroutine { () => 220 | // Option(new IntWrapper("value!")) match { 221 | // case Some(valueHolder) => 222 | // AsyncAwaitTest.await(Future(doAThing(valueHolder))) 223 | // case None => 224 | // None 225 | // } 226 | // }) 227 | 228 | // val result = Await.result(fut, 5 seconds) 229 | // assert(result.asInstanceOf[Future[IntWrapper]].value == Some(Success(None))) 230 | // } 231 | 232 | // // Source: https://git.io/vr7NZ 233 | 234 | // Source: https://git.io/vr7NZ 235 | // TODO: Fix flakiness 
and uncomment. 236 | // test("ticket 86 in scala/async-- using matched parameterized value class") { 237 | // def doAThing(param: ParamWrapper[String]) = Future(None) 238 | 239 | // val fut = AsyncAwaitTest.async(coroutine { () => 240 | // Option(new ParamWrapper("value!")) match { 241 | // case Some(valueHolder) => 242 | // AsyncAwaitTest.await(Future(doAThing(valueHolder))) 243 | // case None => 244 | // None 245 | // } 246 | // }) 247 | 248 | // val result = Await.result(fut, 5 seconds) 249 | // assert(result.asInstanceOf[Future[ParamWrapper[String]]].value == 250 | // Some(Success(None))) 251 | // } 252 | 253 | // // Source: https://git.io/vr7NW 254 | // test("ticket 86 in scala/async-- using private value class") { 255 | // def doAThing(param: PrivateWrapper) = Future(None) 256 | 257 | // val fut = AsyncAwaitTest.async(coroutine { () => 258 | // Option(PrivateWrapper.Instance) match { 259 | // case Some(valueHolder) => 260 | // AsyncAwaitTest.await(doAThing(valueHolder)) 261 | // case None => 262 | // None 263 | // } 264 | // }) 265 | 266 | // val result = Await.result(fut, 5 seconds) 267 | // assert(result == None) 268 | // } 269 | 270 | // Source: https://git.io/vr7N8 271 | test("await of abstract type") { 272 | def combine[A](a1: A, a2: A): A = a1 273 | 274 | def combineAsync[A](a1: Future[A], a2: Future[A]) = 275 | async(coroutine { () => 276 | combine(await(Future(a1)), await(Future(a2))) 277 | }) 278 | 279 | val fut = combineAsync(Future(1), Future(2)) 280 | 281 | val outer = Await.result(fut, 5 seconds) 282 | val inner = Await.result(outer, 5 seconds) 283 | assert(inner == 1) 284 | } 285 | 286 | // Source: https://git.io/vrFp5 287 | test("match as expression 1") { 288 | val c = AsyncAwaitTest.async(coroutine { () => 289 | val x = "" match { 290 | case _ => AsyncAwaitTest.await(Future(1)) + 1 291 | } 292 | x 293 | }) 294 | val result = Await.result(c, 5 seconds) 295 | assert(result == 2) 296 | } 297 | 298 | // Source: https://git.io/vrFhh 299 | 
test("match as expression 2") { 300 | val c = AsyncAwaitTest.async(coroutine { () => 301 | val x = "" match { 302 | case "" if false => await(Future(1)) + 1 303 | case _ => 2 + await(Future(1)) 304 | } 305 | val y = x 306 | "" match { 307 | case _ => await(Future(y)) + 100 308 | } 309 | }) 310 | val result = Await.result(c, 5 seconds) 311 | assert(result == 103) 312 | } 313 | 314 | // Source: https://git.io/vrhTe 315 | test("named and default arguments respect evaluation order") { 316 | var i = 0 317 | def next() = { 318 | i += 1; 319 | i 320 | } 321 | def foo(a: Int = next(), b: Int = next()) = (a, b) 322 | val c1 = async(coroutine { () => 323 | foo(b = await(Future(next()))) 324 | }) 325 | assert(Await.result(c1, 5 seconds) == (2, 1)) 326 | i = 0 327 | val c2 = async(coroutine { () => 328 | foo(a = await(Future(next()))) 329 | }) 330 | assert(Await.result(c2, 5 seconds) == (1, 2)) 331 | } 332 | 333 | // Source: https://git.io/vrhTT 334 | test("repeated params 1") { 335 | var i = 0 336 | def foo(a: Int, b: Int*) = b.toList 337 | def id(i: Int) = i 338 | val c = async(coroutine { () => 339 | foo(await(Future(0)), id(1), id(2), id(3), await(Future(4))) 340 | }) 341 | assert(Await.result(c, 5 seconds) == List(1, 2, 3, 4)) 342 | } 343 | 344 | // Source: https://git.io/vrhTY 345 | test("repeated params 2") { 346 | var i = 0 347 | def foo(a: Int, b: Int*) = b.toList 348 | def id(i: Int) = i 349 | val c = async(coroutine { () => 350 | foo(await(Future(0)), List(id(1), id(2), id(3)): _*) 351 | }) 352 | assert(Await.result(c, 5 seconds) == List(1, 2, 3)) 353 | } 354 | 355 | // Source: https://git.io/vrhT0 356 | test("await in typed") { 357 | val c = async(coroutine { () => 358 | (("msg: " + await(Future(0))): String).toString 359 | }) 360 | assert(Await.result(c, 5 seconds) == "msg: 0") 361 | } 362 | 363 | // Source: https://git.io/vrhTz 364 | test("await in assign") { 365 | val c = async(coroutine { () => 366 | var x = 0 367 | x = await(Future(1)) 368 | x 369 | }) 370 | 
assert(Await.result(c, 5 seconds) == 1) 371 | } 372 | 373 | // Source: https://git.io/vrhTr 374 | test("case body must be typed as unit") { 375 | val Up = 1 376 | val Down = 2 377 | val sign = async(coroutine { () => 378 | await(Future(1)) match { 379 | case Up => 1.0 380 | case Down => -1.0 381 | } 382 | }) 383 | assert(Await.result(sign, 5 seconds) == 1.0) 384 | } 385 | 386 | test("compilation error in partial function") { 387 | val c = coroutine { () => 388 | try { 389 | sys.error("error") 390 | await(Future("ho")) 391 | } catch { 392 | case e: RuntimeException => await(Future("oh")) 393 | case _ => await(Future("ho")) 394 | } 395 | await(Future("oh")) 396 | } 397 | val future = async(c) 398 | assert(Await.result(future, 1 seconds) == "oh") 399 | } 400 | } 401 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/boxing-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalameter.api._ 6 | import org.scalameter.japi.JBench 7 | import org.scalameter.japi.annotation._ 8 | import org.scalameter.picklers.noPickler._ 9 | import org.scalameter.execution.invocation._ 10 | 11 | 12 | 13 | class CoroutineBoxingBench extends JBench.Forked[Long] { 14 | override def defaultConfig: Context = Context( 15 | exec.minWarmupRuns -> 2, 16 | exec.maxWarmupRuns -> 5, 17 | exec.independentSamples -> 1, 18 | verbose -> false 19 | ) 20 | 21 | def measurer: Measurer[Long] = 22 | for (table <- Measurer.BoxingCount.allWithoutBoolean()) yield { 23 | table.copy(value = table.value.valuesIterator.sum) 24 | } 25 | 26 | def aggregator: Aggregator[Long] = Aggregator.median 27 | 28 | override def reporter = Reporter.Composite( 29 | LoggingReporter(), 30 | ValidationReporter() 31 | ) 32 | 33 | val sizes = Gen.single("size")(1000) 34 | 35 | /* range iterator */ 36 | 37 | val rangeCtx = Context( 38 | reports.validation.predicate -> { (n: Any) 
=> n == 0 } 39 | ) 40 | 41 | @gen("sizes") 42 | @benchmark("coroutines.boxing.range") 43 | @curve("coroutine") 44 | @ctx("rangeCtx") 45 | def range(sz: Int) { 46 | val id = coroutine { (n: Int) => 47 | var i = 0 48 | while (i < n) { 49 | yieldval(i) 50 | i += 1 51 | } 52 | } 53 | 54 | var i = 0 55 | val c = call(id(sz)) 56 | while (i < sz) { 57 | c.resume 58 | c.value 59 | i += 1 60 | } 61 | } 62 | 63 | /* tree iterator */ 64 | 65 | val treeCtx = Context( 66 | reports.validation.predicate -> { (n: Any) => n == 0 } 67 | ) 68 | 69 | sealed trait Tree 70 | case class Node(x: Int, left: Tree, right: Tree) extends Tree 71 | case object Empty extends Tree 72 | 73 | var iterator: Coroutine._1[Tree, Int, Unit] = _ 74 | 75 | @gen("sizes") 76 | @benchmark("coroutines.boxing.tree-iterator") 77 | @curve("coroutine") 78 | @ctx("treeCtx") 79 | def tree(sz: Int) { 80 | def gen(sz: Int): Tree = { 81 | if (sz == 0) Empty 82 | else { 83 | val rem = sz - 1 84 | val left = gen(rem / 2) 85 | val right = gen(rem - rem / 2) 86 | Node(sz, left, right) 87 | } 88 | } 89 | val tree = gen(sz) 90 | 91 | iterator = coroutine { (t: Tree) => 92 | t match { 93 | case n: Node => 94 | iterator(n.left) 95 | yieldval(n.x) 96 | iterator(n.right) 97 | case Empty => 98 | } 99 | } 100 | 101 | val c = call(iterator(tree)) 102 | while (c.pull) c.value 103 | } 104 | 105 | /* Fibonacci */ 106 | 107 | val fibCtx = Context( 108 | reports.validation.predicate -> { (n: Any) => n == 1 } 109 | ) 110 | 111 | val fibSizes = Gen.single("size")(10) 112 | 113 | @gen("fibSizes") 114 | @benchmark("coroutines.boxing.fibonacci") 115 | @curve("coroutine") 116 | @ctx("fibCtx") 117 | def fibonacci(sz: Int) { 118 | var fib: _1$spec$I[Unit, Int] = null 119 | fib = coroutine { (n: Int) => 120 | if (n <= 1) 1 121 | else fib(n - 1) + fib(n - 2) 122 | } 123 | val c = call(fib(sz)) 124 | while (c.pull) c.value 125 | } 126 | 127 | val fibSugarCtx = Context( 128 | reports.validation.predicate -> { (n: Any) => n == 178 } 129 | ) 130 | 
131 | @gen("fibSizes") 132 | @benchmark("coroutines.boxing.fibonacci") 133 | @curve("coroutine-sugar") 134 | @ctx("fibSugarCtx") 135 | def fibonacciSugar(sz: Int) { 136 | var fibsugar: Int ~~> (Unit, Int) = null 137 | fibsugar = coroutine { (n: Int) => 138 | if (n <= 1) 1 139 | else fibsugar(n - 1) + fibsugar(n - 2) 140 | } 141 | val cs = call(fibsugar(sz)) 142 | while (cs.pull) cs.value 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/coroutine-syntax-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.util.Failure 7 | 8 | 9 | 10 | class CoroutineSyntaxTest extends FunSuite with Matchers { 11 | test("coroutine instance must have nicer syntax") { 12 | val id: Int ~~> (Nothing, Int) = coroutine { (x: Int) => x } 13 | val c: Nothing <~> Int = call(id(5)) 14 | assert(!c.resume) 15 | assert(c.result == 5) 16 | assert(c.isCompleted) 17 | } 18 | 19 | test("Coroutine._0 must be invoked") { 20 | val rube = coroutine { () => 21 | yieldval(5) 22 | yieldval(-5) 23 | "ok" 24 | } 25 | val co: ~~~>[Int, String] = rube 26 | val c = call(co()) 27 | assert(c.resume) 28 | assert(c.value == 5) 29 | assert(c.resume) 30 | assert(c.value == -5) 31 | assert(!c.resume) 32 | assert(c.result == "ok") 33 | assert(c.isCompleted) 34 | } 35 | 36 | test("Coroutine._1 must be invoked") { 37 | val rube: Coroutine._1[Int, Int, String] = coroutine { (x: Int) => 38 | yieldval(x + x) 39 | yieldval(x - 2 * x) 40 | "ok" * x 41 | } 42 | 43 | val co: Int ~~> (Int, String) = rube 44 | val c = call(co(7)) 45 | assert(c.resume) 46 | assert(c.value == 14) 47 | assert(c.resume) 48 | assert(c.value == -7) 49 | assert(!c.resume) 50 | assert(c.result == "ok" * 7) 51 | assert(c.isCompleted) 52 | } 53 | 54 | test("Coroutine._1 must be invoked for a tuple argument") { 55 | val rube = coroutine { (t: (Int, 
String)) => 56 | yieldval(t._1) 57 | t._2 58 | } 59 | 60 | val co: (Int, String) ~~> (Int, String) = rube 61 | val c = call(co((7, "ok"))) 62 | assert(c.resume) 63 | assert(c.value == 7) 64 | assert(!c.resume) 65 | assert(c.result == "ok") 66 | assert(c.isCompleted) 67 | } 68 | 69 | test("Coroutine._2 must be invoked") { 70 | val rube = coroutine { (x: Int, y: Int) => 71 | yieldval(x + y) 72 | yieldval(x - y) 73 | (x * y).toString 74 | } 75 | 76 | val co: (Int, Int) ~> (Int, String) = rube 77 | val c = call(co(7, 4)) 78 | assert(c.resume) 79 | assert(c.value == 11) 80 | assert(c.resume) 81 | assert(c.value == 3) 82 | assert(!c.resume) 83 | assert(c.result == "28") 84 | assert(c.isCompleted) 85 | } 86 | 87 | test("Coroutine._3 must be invoked") { 88 | val rube = coroutine { (x: Int, y: Int, z: Int) => 89 | yieldval(x) 90 | yieldval(y) 91 | z.toString 92 | } 93 | 94 | val co: (Int, Int, Int) ~> (Int, String) = rube 95 | val c = call(co(3, 5, 8)) 96 | assert(c.resume) 97 | assert(c.value == 3) 98 | assert(c.resume) 99 | assert(c.value == 5) 100 | assert(!c.resume) 101 | assert(c.result == "8") 102 | assert(c.isCompleted) 103 | } 104 | 105 | test("Another coroutine must be invoked without syntax sugar") { 106 | val gimmeFive = coroutine { () => 5 } 107 | val rube: ~~~>[Nothing, Int] = coroutine { () => 108 | gimmeFive() 109 | } 110 | 111 | val c = call(rube()) 112 | assert(!c.resume) 113 | assert(c.result == 5) 114 | assert(c.isCompleted) 115 | } 116 | 117 | test("Another arity-0 coroutine must be invoked with syntax sugar") { 118 | val gimmeFive: ~~~>[Nothing, Int] = coroutine { () => 5 } 119 | val rube = coroutine { () => 120 | gimmeFive() 121 | } 122 | 123 | val c = call(rube()) 124 | assert(!c.resume) 125 | assert(c.result == 5) 126 | assert(c.isCompleted) 127 | } 128 | 129 | test("Another arity-1 coroutine must be invoked with syntax sugar") { 130 | val neg: Int ~~> (Nothing, Int) = coroutine { (x: Int) => -x } 131 | val rube = coroutine { () => 132 | neg(17) 133 
| } 134 | 135 | val c = call(rube()) 136 | assert(!c.resume) 137 | assert(c.result == -17) 138 | assert(c.isCompleted) 139 | } 140 | 141 | test("Another arity-2 coroutine must be invoked with syntax sugar") { 142 | val mult: (Int, Int) ~> (Nothing, Int) = coroutine { (x: Int, y: Int) => x * y } 143 | val rube = coroutine { () => 144 | mult(3, 4) 145 | } 146 | 147 | val c = call(rube()) 148 | assert(!c.resume) 149 | assert(c.result == 12) 150 | assert(c.isCompleted) 151 | } 152 | 153 | test("Another arity-3 coroutine must be invoked with syntax sugar") { 154 | val mult: (Int, Int, Int) ~> (Nothing, Int) = coroutine { 155 | (x: Int, y: Int, z: Int) => x * y * z 156 | } 157 | val rube = coroutine { () => 158 | mult(3, 4, 5) 159 | } 160 | 161 | val c = call(rube()) 162 | assert(!c.resume) 163 | assert(c.result == 60) 164 | assert(c.isCompleted) 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/pattern-match-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.util.Failure 7 | 8 | 9 | 10 | class PatternMatchTest extends FunSuite with Matchers { 11 | test("simple pattern match") { 12 | val rube = coroutine { (x: AnyRef) => 13 | x match { 14 | case s: String => s.length 15 | case xs: List[_] => xs.size 16 | } 17 | } 18 | 19 | val c1 = call(rube("ok")) 20 | assert(!c1.resume) 21 | assert(c1.result == 2) 22 | assert(c1.isCompleted) 23 | val c2 = call(rube(1 :: 2 :: 3 :: Nil)) 24 | assert(!c2.resume) 25 | assert(c2.result == 3) 26 | assert(c2.isCompleted) 27 | } 28 | 29 | test("pattern match with yields") { 30 | val rube = coroutine { (x: AnyRef) => 31 | x match { 32 | case s: String => yieldval(s.length) 33 | case xs: List[_] => yieldval(xs.size) 34 | } 35 | 17 36 | } 37 | 38 | val c1 = call(rube("ok")) 39 | assert(c1.resume) 40 | assert(c1.value == 2) 41 | 
assert(!c1.resume) 42 | assert(c1.result == 17) 43 | assert(c1.isCompleted) 44 | val c2 = call(rube(1 :: 2 :: 3 :: Nil)) 45 | assert(c2.resume) 46 | assert(c2.value == 3) 47 | assert(!c2.resume) 48 | assert(c2.result == 17) 49 | assert(c2.isCompleted) 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/regression-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.util.Failure 7 | 8 | 9 | 10 | class RegressionTest extends FunSuite with Matchers { 11 | test("should declare body with if statement") { 12 | val xOrY = coroutine { (x: Int, y: Int) => 13 | if (x > 0) { 14 | yieldval(x) 15 | } else { 16 | yieldval(y) 17 | } 18 | } 19 | val c1 = call(xOrY(5, 2)) 20 | assert(c1.resume) 21 | assert(c1.value == 5) 22 | assert(!c1.resume) 23 | assert(c1.isCompleted) 24 | c1.result 25 | assert(!c1.hasException) 26 | val c2 = call(xOrY(-2, 7)) 27 | assert(c2.resume) 28 | assert(c2.value == 7) 29 | assert(!c2.resume) 30 | assert(c2.isCompleted) 31 | c2.result 32 | assert(!c2.hasException) 33 | } 34 | 35 | test("coroutine should have a nested if statement") { 36 | val numbers = coroutine { () => 37 | var z = 1 38 | var i = 1 39 | while (i < 5) { 40 | if (z > 0) { 41 | yieldval(z * i) 42 | z = -1 43 | } else { 44 | yieldval(z * i) 45 | z = 1 46 | i += 1 47 | } 48 | } 49 | } 50 | val c = call(numbers()) 51 | for (i <- 1 until 5) { 52 | assert(c.resume) 53 | assert(c.value == i) 54 | assert(c.resume) 55 | assert(c.value == -i) 56 | } 57 | } 58 | 59 | test("coroutine should call a coroutine with a different return type") { 60 | val stringer = coroutine { (x: Int) => x.toString } 61 | val caller = coroutine { (x: Int) => 62 | val s = stringer(2 * x) 63 | yieldval(s) 64 | x * 3 65 | } 66 | 67 | val c = call(caller(5)) 68 | assert(c.resume) 69 | assert(c.value == "10") 70 | 
assert(!c.resume) 71 | assert(c.result == 15) 72 | } 73 | 74 | test("issue #14 -- simple case") { 75 | object Test { 76 | val foo: Int ~~> (Int, Unit) = coroutine { (i: Int) => 77 | yieldval(i) 78 | if (i > 0) { 79 | foo(i - 1) 80 | foo(i - 1) 81 | } 82 | } 83 | } 84 | 85 | val c = call(Test.foo(2)) 86 | assert(c.resume) 87 | assert(c.value == 2) 88 | assert(c.resume) 89 | assert(c.value == 1) 90 | assert(c.resume) 91 | assert(c.value == 0) 92 | assert(c.resume) 93 | assert(c.value == 0) 94 | assert(c.resume) 95 | assert(c.value == 1) 96 | assert(c.resume) 97 | assert(c.value == 0) 98 | assert(c.resume) 99 | assert(c.value == 0) 100 | assert(!c.resume) 101 | } 102 | 103 | test("issue #14 -- complex case") { 104 | object Test { 105 | val foo: Int ~~> (Int, Unit) = coroutine { (i: Int) => 106 | yieldval(i) 107 | if (i > 0) { 108 | foo(i - 1) 109 | foo(i - 1) 110 | } 111 | } 112 | } 113 | 114 | val bar = coroutine { () => 115 | Test.foo(2) 116 | Test.foo(2) 117 | } 118 | 119 | val c = call(bar()) 120 | assert(c.resume) 121 | assert(c.value == 2) 122 | assert(c.resume) 123 | assert(c.value == 1) 124 | assert(c.resume) 125 | assert(c.value == 0) 126 | assert(c.resume) 127 | assert(c.value == 0) 128 | assert(c.resume) 129 | assert(c.value == 1) 130 | assert(c.resume) 131 | assert(c.value == 0) 132 | assert(c.resume) 133 | assert(c.value == 0) 134 | assert(c.resume) 135 | assert(c.value == 2) 136 | assert(c.resume) 137 | assert(c.value == 1) 138 | assert(c.resume) 139 | assert(c.value == 0) 140 | assert(c.resume) 141 | assert(c.value == 0) 142 | assert(c.resume) 143 | assert(c.value == 1) 144 | assert(c.resume) 145 | assert(c.value == 0) 146 | assert(c.resume) 147 | assert(c.value == 0) 148 | assert(!c.resume) 149 | } 150 | 151 | test("issue #15 -- hygiene") { 152 | val scala, Any, String, TypeTag, Unit = () 153 | trait scala; trait Any; trait String; trait TypeTag; trait Unit 154 | 155 | val id = coroutine { (x: Int) => 156 | x 157 | } 158 | } 159 | 160 | test("issue #15 
-- more hygiene") { 161 | val org, coroutines, Coroutine = () 162 | trait org; trait coroutines; trait Coroutine 163 | 164 | val id = coroutine { () => } 165 | } 166 | 167 | test("should use c as an argument name") { 168 | val nuthin = coroutine { () => } 169 | val resumer = coroutine { (c: Nothing <~> Unit) => 170 | c.resume 171 | } 172 | val c = call(nuthin()) 173 | val r = call(resumer(c)) 174 | assert(!r.resume) 175 | assert(!r.hasException) 176 | assert(r.hasResult) 177 | } 178 | 179 | test("issue #21") { 180 | val test = coroutine { () => {} } 181 | val foo = coroutine { () => { 182 | test() 183 | test() 184 | test() 185 | test() 186 | test() 187 | test() 188 | test() 189 | test() 190 | test() 191 | test() 192 | test() 193 | test() 194 | test() 195 | test() 196 | test() 197 | test() 198 | test() 199 | test() 200 | test() 201 | test() 202 | test() 203 | test() 204 | test() 205 | test() 206 | test() 207 | test() 208 | test() 209 | test() 210 | test() 211 | test() // Lines after this did not previously compile. 212 | test() 213 | test() 214 | } 215 | } 216 | } 217 | 218 | test("must catch exception passed from a direct call") { 219 | val buggy = coroutine { () => 220 | throw new Exception 221 | } 222 | val catchy = coroutine { () => 223 | var result = "initial value" 224 | try { 225 | buggy() 226 | "not ok..." 227 | } catch { 228 | case e: Exception => 229 | result = "caught!" 
230 | } 231 | result 232 | } 233 | 234 | val c = call(catchy()) 235 | assert(!c.resume) 236 | assert(c.result == "caught!") 237 | } 238 | } 239 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/snapshot-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.collection._ 7 | import scala.util.Failure 8 | 9 | 10 | 11 | class SnapshotTest extends FunSuite with Matchers { 12 | test("coroutine instance should be cloned and resumed as needed") { 13 | val countdown = coroutine { (n: Int) => 14 | var i = n 15 | while (i >= 0) { 16 | yieldval(i) 17 | i -= 1 18 | } 19 | } 20 | 21 | val c = call(countdown(10)) 22 | for (i <- 0 until 5) { 23 | assert(c.resume) 24 | assert(c.value == (10 - i)) 25 | } 26 | val c2 = c.snapshot 27 | for (i <- 5 to 10) { 28 | assert(c2.resume) 29 | assert(c2.value == (10 - i)) 30 | } 31 | for (i <- 5 to 10) { 32 | assert(c.resume) 33 | assert(c.value == (10 - i)) 34 | } 35 | } 36 | } 37 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/try-catch-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.concurrent._ 7 | import scala.concurrent.duration._ 8 | import scala.concurrent.ExecutionContext.Implicits.global 9 | import scala.util.Failure 10 | 11 | 12 | 13 | class TryCatchTest extends FunSuite with Matchers { 14 | test("try-catch block") { 15 | val rube = coroutine { () => 16 | try { 17 | throw new Exception 18 | } catch { 19 | case e: Exception => 20 | } 21 | } 22 | 23 | val c0 = call(rube()) 24 | assert(!c0.resume) 25 | assert(c0.isCompleted) 26 | c0.result 27 | assert(!c0.hasException) 28 | } 29 | 30 | test("try-catch-finally block") { 31 | val rube = coroutine { () => 32 | try 
{ 33 | throw new Error 34 | } catch { 35 | case e: Error => 36 | } finally { 37 | sys.error("done") 38 | } 39 | } 40 | 41 | val c0 = call(rube()) 42 | assert(!c0.resume) 43 | assert(c0.isCompleted) 44 | c0.tryResult match { 45 | case Failure(re: RuntimeException) => assert(re.getMessage == "done") 46 | case _ => assert(false) 47 | } 48 | } 49 | 50 | /* An `Error` thrown inside `try` is handled by its matching `case`, and `finally` still runs; nothing in the body executes before the first `resume`. */ test("try-catch-finally and several exception types") { 51 | var completed = false 52 | var runtime = false 53 | var error = false 54 | val rube = coroutine { (t: Throwable) => 55 | try { 56 | throw t 57 | } catch { 58 | case e: RuntimeException => 59 | runtime = true 60 | case e: Error => 61 | error = true 62 | } finally { 63 | completed = true 64 | } 65 | } 66 | 67 | val c0 = call(rube(new Error)) 68 | assert(!runtime) 69 | assert(!error) 70 | assert(!completed) 71 | assert(!c0.resume) 72 | c0.result 73 | assert(!c0.hasException) 74 | assert(!runtime) 75 | assert(error) 76 | assert(completed) 77 | assert(c0.isCompleted) 78 | } 79 | 80 | /* A thrown exception is captured by the instance: `resume` returns false and `tryResult` holds the failure. */ test("coroutine with a throw statement") { 81 | val rube = coroutine { () => 82 | throw { 83 | val str = "boom" 84 | new Exception(str) 85 | } 86 | } 87 | 88 | val c = call(rube()) 89 | assert(!c.resume) 90 | c.tryResult match { 91 | case Failure(e: Exception) => assert(e.getMessage == "boom") 92 | case _ => assert(false) 93 | } 94 | } 95 | 96 | /* An exception thrown by a directly invoked coroutine propagates into the caller's `tryResult`. */ test("invoke another coroutine that throws") { 97 | val boom = coroutine { () => throw new Exception("kaboom") } 98 | val rube = coroutine { () => 99 | boom() 100 | } 101 | 102 | val c = call(rube()) 103 | assert(!c.resume) 104 | c.tryResult match { 105 | case Failure(e: Exception) => assert(e.getMessage == "kaboom") 106 | case _ => assert(false) 107 | } 108 | } 109 | 110 | /* Despite its name, this test yields from inside a `try` block; after the value is delivered, the coroutine completes with "done". */ test("yield inside throw") { 111 | val rube = coroutine { () => 112 | try { 113 | yieldval("inside") 114 | } catch { 115 | case r: RuntimeException => "runtime" 116 | case e: Exception => "generic" 117 | } 118 | "done" 119 | } 120 | 121 | val c = call(rube()) 122 |
assert(c.resume) 123 | assert(c.value == "inside") 124 | assert(!c.resume) 125 | assert(c.result == "done") 126 | assert(c.isCompleted) 127 | } 128 | 129 | /* Statements after an uncaught `throw` never run, so no value is yielded and the instance ends with an exception. */ test("throw and then yield") { 130 | val rube = coroutine { () => 131 | throw new Exception("au revoir") 132 | yieldval("bonjour") 133 | } 134 | 135 | val c = call(rube()) 136 | assert(!c.resume) 137 | assert(c.hasException) 138 | assert(c.getValue == None) 139 | c.tryResult match { 140 | case Failure(e: Exception) => assert(e.getMessage == "au revoir") 141 | case _ => assert(false) 142 | } 143 | } 144 | 145 | /* The value of the try/catch expression is discarded, so its branches may have unrelated types; the coroutine's result is its final expression. */ test("try/catch with different return types") { 146 | val c = coroutine { () => 147 | try { 148 | () 149 | } catch { 150 | case _: Throwable => Future("ho") 151 | } 152 | Future("oh") 153 | } 154 | val instance = call(c()) 155 | assert(!instance.resume) 156 | assert(Await.result(instance.result, 1 seconds) == "oh") 157 | } 158 | 159 | /* The value of the try/catch expression is discarded; the coroutine's result is its final expression. */ test("try/catch with same return type") { 160 | val c = coroutine { () => 161 | try { 162 | "ho_1" 163 | } catch { 164 | case _: Throwable => "ho_2" 165 | } 166 | "ho_3" 167 | } 168 | val instance = call(c()) 169 | assert(!instance.resume) 170 | assert(instance.result == "ho_3") 171 | } 172 | } 173 | -------------------------------------------------------------------------------- /src/test/scala/org/coroutines/yieldto-tests.scala: -------------------------------------------------------------------------------- 1 | package org.coroutines 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.collection._ 7 | import scala.util.Failure 8 | 9 | 10 | 11 | /** Tests for `yieldto`, which hands control from the running coroutine instance to another instance. */ class YieldToTest extends FunSuite with Matchers { 12 | /* After the `yieldto` step the caller has no value, while the target instance holds the value it yielded. */ test("after resuming to another coroutine, there should be no value") { 13 | val another = coroutine { () => 14 | yieldval("Yohaha") 15 | } 16 | val anotherInstance = call(another()) 17 | 18 | val rube = coroutine { () => 19 | yieldval("started") 20 | yieldto(anotherInstance) 21 | } 22 | 23 | val c = call(rube()) 24 | assert(c.resume) 25 | assert(c.hasValue) 26 |
assert(c.value == "started") 27 | assert(c.resume) 28 | assert(!c.hasValue) 29 | assert(!c.resume) 30 | assert(!c.hasValue) 31 | assert(c.isCompleted) 32 | assert(anotherInstance.hasValue) 33 | assert(anotherInstance.value == "Yohaha") 34 | } 35 | 36 | /* Transferring control to an instance that has already completed fails the caller with a `CoroutineStoppedException`. */ test("yielding to a completed coroutine raises an error") { 37 | val another = coroutine { () => "in and out" } 38 | val anotherInstance = call(another()) 39 | assert(!anotherInstance.resume) 40 | 41 | val rube = coroutine { () => 42 | yieldto(anotherInstance) 43 | yieldval("some more") 44 | } 45 | val c = call(rube()) 46 | assert(!c.resume) 47 | c.tryResult match { 48 | case Failure(e: CoroutineStoppedException) => 49 | case _ => assert(false, "Should have thrown an exception.") 50 | } 51 | } 52 | 53 | /* The target's yield type (`String`) does not have to match the caller's (`Int`). */ test("should be able to yield to a differently typed coroutine") { 54 | val another: ~~~>[String, Unit] = coroutine { () => 55 | yieldval("hohoho") 56 | } 57 | val anotherInstance = call(another()) 58 | 59 | val rube: Int ~~> (Int, Int) = coroutine { (x: Int) => 60 | yieldval(-x) 61 | yieldto(anotherInstance) 62 | x 63 | } 64 | val c = call(rube(5)) 65 | 66 | assert(c.resume) 67 | assert(c.value == -5) 68 | assert(c.resume) 69 | assert(!c.hasValue) 70 | assert(!c.resume) 71 | assert(c.result == 5) 72 | assert(anotherInstance.hasValue) 73 | assert(anotherInstance.value == "hohoho") 74 | } 75 | 76 | /* The `yieldto` step carries no value, so the drain loop checks `hasValue` before reading `value`. */ test("should drain the coroutine instance that yields to another coroutine") { 77 | val another: ~~~>[String, Unit] = coroutine { () => 78 | yieldval("uh-la-la") 79 | } 80 | val anotherInstance = call(another()) 81 | 82 | val rube: (Int, Int) ~> (Int, Unit) = coroutine { (x: Int, y: Int) => 83 | yieldval(x) 84 | yieldval(y) 85 | yieldto(anotherInstance) 86 | yieldval(x * y) 87 | } 88 | val c = call(rube(5, 4)) 89 | 90 | val b = mutable.Buffer[Int]() 91 | while (c.resume) if (c.hasValue) b += c.value 92 | 93 | assert(b == Seq(5, 4, 20)) 94 | } 95 | } 96 |
-------------------------------------------------------------------------------- /src/test/scala/org/examples/AsyncAwait.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | import scala.annotation.unchecked.uncheckedVariance 7 | import scala.concurrent._ 8 | import scala.concurrent.duration._ 9 | import scala.concurrent.ExecutionContext.Implicits.global 10 | import scala.util.{Failure, Success} 11 | 12 | 13 | object AsyncAwait { 14 | class Cell[+T] { 15 | var x: T @uncheckedVariance = _ 16 | } 17 | 18 | /** The future should be computed after the pair is yielded. The result of 19 | * this future can be used to assign a value to `cell.x`. 20 | * Note that `Cell` is used in order to give users the option to not directly 21 | * return the result of the future. 22 | */ 23 | def await[R]: Future[R] ~~> ((Future[R], Cell[R]), R) = 24 | coroutine { (f: Future[R]) => 25 | val cell = new Cell[R] 26 | yieldval((f, cell)) 27 | cell.x 28 | } 29 | /** Drives `body` to completion: after each yielded `(future, cell)` pair, the coroutine is resumed once `future` completes, with the future's value stored in `cell.x`. The returned future completes with the coroutine's result, or fails if either the coroutine or an awaited future fails. */ 30 | def async[Y, R](body: ~~~>[(Future[Y], Cell[Y]), R]): Future[R] = { 31 | val c = call(body()) 32 | val p = Promise[R] 33 | def loop() { 34 | if (!c.resume) p.complete(c.tryResult) /* tryResult carries the coroutine's exception too; c.result would throw here and leave p incomplete */ 35 | else { 36 | val (future, cell) = c.value 37 | future.onComplete { 38 | case Success(x) => cell.x = x; loop() 39 | case Failure(t) => p.failure(t) /* without this, a failed awaited future would leave p incomplete forever */ 40 | } 41 | } 42 | } 43 | Future { loop() } 44 | p.future 45 | } 46 | 47 | def main(args: Array[String]) { 48 | val f = Future { math.sqrt(121) } 49 | val g = Future { math.abs(-15) } 50 | /** Calls to yieldval inside an inner coroutine are yield points inside the 51 | * outer coroutine.
52 | */ 53 | val h = async(coroutine { () => 54 | val x = await { f } 55 | val y = await { g } 56 | x + y 57 | }) 58 | 59 | val res = scala.concurrent.Await.result(h, 5.seconds) 60 | assert(res == 26.0) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/Composition.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | 7 | 8 | 9 | object Composition { 10 | private val optionElems = coroutine { (opt: Option[Int]) => 11 | opt match { 12 | case Some(x) => yieldval(x) 13 | case None => // do nothing 14 | } 15 | } 16 | 17 | private val optionListElems = coroutine { (xs: List[Option[Int]]) => 18 | var curr = xs 19 | while (curr != Nil) { 20 | optionElems(curr.head) 21 | curr = curr.tail 22 | } 23 | } 24 | 25 | def main(args: Array[String]) { 26 | val xs = Some(1) :: None :: Some(3) :: Nil 27 | val c = call(optionListElems(xs)) 28 | assert(c.resume) 29 | assert(c.value == 1) 30 | assert(c.resume) 31 | assert(c.value == 3) 32 | assert(!c.resume) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/CompositionCall.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | 7 | 8 | 9 | object CompositionCall { 10 | private val optionElems = coroutine { (opt: Option[Int]) => 11 | opt match { 12 | case Some(x) => yieldval(x) 13 | case None => // do nothing 14 | } 15 | } 16 | 17 | private val optionListElems = coroutine { (xs: List[Option[Int]]) => 18 | var curr = xs 19 | while (curr != Nil) { 20 | val c = call(optionElems(curr.head)) 21 | while (c.resume) yieldval(c.value) 22 | curr = curr.tail 23 | } 24 | } 25 | 26 | def main(args: Array[String]) { 27 | val xs = Some(1) :: None :: Some(3) :: Nil 28 | val c =
call(optionListElems(xs)) 29 | assert(c.resume) 30 | assert(c.value == 1) 31 | assert(c.resume) 32 | assert(c.value == 3) 33 | assert(!c.resume) 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/ControlTransfer.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | import scala.collection._ 7 | import scala.util.Random 8 | 9 | 10 | 11 | object ControlTransfer { 12 | var error: String = "" 13 | val check: ~~~>[Boolean, Unit] = coroutine { () => 14 | yieldval(true) 15 | error = "Total failure." 16 | yieldval(false) 17 | } 18 | val checker = call(check()) 19 | 20 | /** When an instance of `random` (such as `r0` or `r1` in `main`) reaches 21 | * `yieldto(checker)`, `checker` is resumed until it releases control, and 22 | * then the instance itself releases control. After that step the instance's 23 | * `hasValue` is false: values yielded by `checker` do not propagate upwards.
24 | */ 25 | val random: ~~~>[Double, Unit] = coroutine { () => 26 | yieldval(Random.nextDouble()) 27 | yieldto(checker) 28 | yieldval(Random.nextDouble()) 29 | } 30 | 31 | def main(args: Array[String]) { 32 | val r0 = call(random()) 33 | assert(r0.resume) 34 | assert(r0.hasValue) 35 | assert(r0.resume) 36 | assert(!r0.hasValue) 37 | assert(r0.resume) 38 | assert(r0.hasValue) 39 | assert(!r0.resume) 40 | assert(!r0.hasValue) 41 | 42 | val r1 = call(random()) 43 | val values = mutable.Buffer[Double]() 44 | while (r1.resume) if (r1.hasValue) values += r1.value 45 | assert(values.length == 2) 46 | assert(error == "Total failure.") 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/ControlTransferWithPull.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | import scala.collection._ 7 | import scala.util.Random 8 | 9 | 10 | 11 | object ControlTransferWithPull { 12 | var error: String = "" 13 | val check: ~~~>[Boolean, Unit] = coroutine { () => 14 | yieldval(true) 15 | error = "Total failure."
16 | yieldval(false) 17 | } 18 | val checker = call(check()) 19 | 20 | val random: ~~~>[Double, Unit] = coroutine { () => 21 | yieldval(Random.nextDouble()) 22 | yieldto(checker) 23 | yieldval(Random.nextDouble()) 24 | } 25 | 26 | def main(args: Array[String]) { 27 | val r0 = call(random()) 28 | assert(r0.resume) 29 | assert(r0.hasValue) 30 | assert(r0.resume) 31 | assert(!r0.hasValue) 32 | assert(r0.resume) 33 | assert(r0.hasValue) 34 | assert(!r0.resume) 35 | assert(!r0.hasValue) 36 | 37 | val r1 = call(random()) 38 | val values = mutable.Buffer[Double]() 39 | while (r1.pull) values += r1.value 40 | assert(values.length == 2) 41 | assert(error == "Total failure.") 42 | } 43 | } 44 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/Datatypes.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | 7 | 8 | 9 | object Datatypes { 10 | val whileRange = coroutine { (n: Int) => 11 | var i = 0 12 | while (i < n) { 13 | yieldval(i) 14 | i += 1 15 | } 16 | } 17 | 18 | val doWhileRange = coroutine { (n: Int) => 19 | var i = 0 20 | do { 21 | yieldval(i) 22 | i += 1 23 | } while (i < n) 24 | } 25 | 26 | /** Asserts that `co(n)` yields exactly `0, 1, ..., n - 1`, in order, and then completes. */ 27 | def assertEqualsRange(n: Int, co: Int ~~> (Int, Unit)) { 28 | val c = call(co(n)) 29 | for (i <- 0 until n) { 30 | assert(c.resume) 31 | assert(c.value == i) 32 | } 33 | assert(!c.resume) 34 | } 35 | 36 | def main(args: Array[String]) { 37 | assertEqualsRange(5, whileRange) 38 | assertEqualsRange(5, doWhileRange) 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/Exceptions.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | import scala.collection._ 7 | import scala.util.Failure 8 | 9 | 10 | 11 | object Exceptions { 12 | case class TestException() extends Throwable 13
| 14 | val kaboom = coroutine { (x: Int) => 15 | yieldval(x) 16 | try { 17 | sys.error("will be caught") 18 | } catch { 19 | case e: RuntimeException => yieldval("oops") 20 | } 21 | throw TestException() 22 | } 23 | 24 | def main(args: Array[String]) { 25 | val c = call(kaboom(5)) 26 | assert(c.resume) 27 | assert(c.value == 5) 28 | assert(c.resume) 29 | assert(c.value == "oops") 30 | assert(!c.resume) 31 | assert(c.tryResult == Failure(TestException())) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/FaqSimpleExample.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | 7 | 8 | 9 | object FaqSimpleExample { 10 | val range = coroutine { (n: Int) => 11 | var i = 0 12 | while (i < n) { 13 | yieldval(i) 14 | i += 1 15 | } 16 | } 17 | 18 | def extract(c: Int <~> Unit): Seq[Int] = { 19 | var xs: List[Int] = Nil 20 | while (c.resume) if (c.hasValue) xs ::= c.value 21 | xs.reverse 22 | } 23 | 24 | def main(args: Array[String]) { 25 | val instance = call(range(10)) 26 | val elems = extract(instance) 27 | assert(elems == (0 until 10)) 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/Identity.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | 7 | 8 | 9 | object Identity { 10 | val id = coroutine { (x: Int) => x } 11 | 12 | def main(args: Array[String]) { 13 | val c = call(id(7)) 14 | assert(!c.resume) 15 | assert(c.result == 7) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/Lifecycle.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import
org.coroutines._ 6 | import scala.collection._ 7 | import scala.util.Success 8 | 9 | 10 | 11 | object Lifecycle { 12 | val katamari: Int ~~> (String, Int) = coroutine { (n: Int) => 13 | var i = 1 14 | yieldval("naaaa") 15 | while (i < n) { 16 | yieldval("na") 17 | i += 1 18 | } 19 | yieldval("Katamari Damacy!") 20 | i + 2 21 | } 22 | 23 | def main(args: Array[String]) { 24 | val c = call(katamari(9)) 25 | assert(c.resume) 26 | assert(c.hasValue) 27 | assert(c.value == "naaaa") 28 | for (i <- 1 until 9) { 29 | assert(c.resume) 30 | assert(c.getValue == Some("na")) 31 | } 32 | assert(c.resume) 33 | assert(c.tryValue == Success("Katamari Damacy!")) 34 | assert(!c.resume) 35 | assert(c.getValue == None) 36 | assert(c.result == 11) 37 | assert(c.isCompleted) 38 | assert(!c.isLive) 39 | 40 | val theme = "naaaa na na na na na na na na Katamari Damacy!" 41 | assert(drain(call(katamari(9))) == theme) 42 | } 43 | 44 | def drain(f: String <~> Int): String = { 45 | val buffer = mutable.Buffer[String]() 46 | while (f.resume) buffer += f.value 47 | buffer.mkString(" ") 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/MockSnapshot.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | import scala.util._ 7 | 8 | 9 | 10 | object MockSnapshot { 11 | abstract class TestSuite { 12 | class Cell { 13 | var value = false 14 | } 15 | 16 | val mock: ~~~>[Cell, Boolean] = coroutine { () => 17 | val cell = new Cell 18 | yieldval(cell) 19 | cell.value 20 | } 21 | 22 | /** Returns true if `c` completes with a result (not an exception) for every 23 | * combination of values of the `Cell`s it yields: each yielded cell is set 24 | * to `true` and explored on a `snapshot` of `c`, then set to `false` on `c`.
25 | */ 26 | def test[R](c: Cell <~> R): Boolean = { 27 | if (c.resume) { 28 | val cell = c.value 29 | cell.value = true 30 | val res0 = test(c.snapshot) 31 | cell.value = false 32 | val res1 = test(c) 33 | res0 && res1 34 | } else c.hasResult 35 | } 36 | } 37 | 38 | class MyTestSuite extends TestSuite { 39 | val myAlgorithm = coroutine { (x: Int) => 40 | if (mock()) { 41 | assert(2 * x == x + x) 42 | } else { 43 | assert(x * x / x == x) 44 | } 45 | } 46 | 47 | assert(test(call(myAlgorithm(5)))) 48 | 49 | // False because there is division by zero. 50 | assert(!test(call(myAlgorithm(0)))) 51 | } 52 | 53 | def main(args: Array[String]) { 54 | new MyTestSuite 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/Snapshot.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | 7 | 8 | 9 | object Snapshot { 10 | val values = coroutine { () => 11 | yieldval(1) 12 | yieldval(2) 13 | yieldval(3) 14 | } 15 | 16 | def main(args: Array[String]) { 17 | val c = call(values()) 18 | assert(c.resume) 19 | assert(c.value == 1) 20 | val c2 = c.snapshot 21 | assert(c.resume) 22 | assert(c.value == 2) 23 | assert(c.resume) 24 | assert(c.value == 3) 25 | assert(c2.resume) 26 | assert(c2.value == 2) 27 | assert(c2.resume) 28 | assert(c2.value == 3) 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/VowelCount.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.coroutines._ 6 | 7 | 8 | 9 | object VowelCounts { 10 | val vowelcounts = coroutine { (s: String) => 11 | yieldval(s.count(_ == 'a')) 12 | yieldval(s.count(_ == 'e')) 13 | yieldval(s.count(_ == 'i')) 14 | yieldval(s.count(_ == 'o')) 15 | yieldval(s.count(_ == 'u')) 16 | } 17 | 18 | def
main(args: Array[String]) { 19 | val c = call(vowelcounts("this the season to be jolie")) 20 | assert(c.resume) 21 | assert(c.value == 1) 22 | assert(c.resume) 23 | assert(c.value == 4) 24 | assert(c.resume) 25 | assert(c.value == 2) 26 | assert(c.resume) 27 | assert(c.value == 3) 28 | assert(c.resume) 29 | assert(c.value == 0) 30 | assert(!c.resume) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /src/test/scala/org/examples/examples-tests.scala: -------------------------------------------------------------------------------- 1 | package org.examples 2 | 3 | 4 | 5 | import org.scalatest._ 6 | import scala.util.Failure 7 | 8 | 9 | 10 | class ExamplesTest extends FunSuite with Matchers { 11 | test("identity coroutine") { 12 | Identity.main(Array()) 13 | } 14 | 15 | test("vowel counts") { 16 | VowelCounts.main(Array()) 17 | } 18 | 19 | test("datatypes") { 20 | Datatypes.main(Array()) 21 | } 22 | 23 | test("lifecycle") { 24 | Lifecycle.main(Array()) 25 | } 26 | 27 | test("exceptions") { 28 | Exceptions.main(Array()) 29 | } 30 | 31 | test("composition") { 32 | Composition.main(Array()) 33 | } 34 | 35 | test("composition call") { 36 | CompositionCall.main(Array()) 37 | } 38 | 39 | test("faq simple example") { 40 | FaqSimpleExample.main(Array()) 41 | } 42 | 43 | test("control transfer") { 44 | ControlTransfer.main(Array()) 45 | } 46 | 47 | test("control transfer with pull") { 48 | ControlTransferWithPull.main(Array()) 49 | } 50 | 51 | test("snapshot") { 52 | Snapshot.main(Array()) 53 | } 54 | 55 | test("mock snapshot") { 56 | MockSnapshot.main(Array()) 57 | } 58 | 59 | test("async/await") { 60 | AsyncAwait.main(Array()) 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /src/test/scala/org/separatepackage/SeparatePackageTest.scala: -------------------------------------------------------------------------------- 1 | package org.separatepackage 2 | 3 | 4 | 5 | import org.coroutines._ 6 | import org.scalatest._ 7 | import
scala.util.Failure 8 | 9 | 10 | 11 | class SeparatePackageTest extends FunSuite with Matchers { 12 | test("should declare and run a coroutine") { 13 | val rube = coroutine { (x: Int) => 14 | yieldval(x * 2) 15 | if (x > 0) yieldval(x) 16 | else yieldval(-x) 17 | x + 1 18 | } 19 | 20 | val c0 = call(rube(2)) 21 | assert(c0.resume) 22 | assert(c0.value == 4) 23 | assert(c0.resume) 24 | assert(c0.value == 2) 25 | assert(!c0.resume) 26 | assert(c0.result == 3) 27 | assert(c0.isCompleted) 28 | 29 | val c1 = call(rube(-2)) 30 | assert(c1.resume) 31 | assert(c1.value == -4) 32 | assert(c1.resume) 33 | assert(c1.value == 2) 34 | assert(!c1.resume) 35 | assert(c1.result == -1) 36 | assert(c1.isCompleted) 37 | } 38 | 39 | test("Another coroutine must be invoked without syntax sugar") { 40 | val inc = coroutine { (x: Int) => x + 1 } 41 | val rube = coroutine { () => 42 | inc(3) 43 | } 44 | 45 | val c = call(rube()) 46 | assert(!c.resume) 47 | assert(c.result == 4) 48 | assert(c.isCompleted) 49 | } 50 | } 51 | -------------------------------------------------------------------------------- /version.conf: -------------------------------------------------------------------------------- 1 | coroutines_major=0 2 | coroutines_minor=8-SNAPSHOT 3 | --------------------------------------------------------------------------------