├── project ├── build.properties ├── plugins.sbt └── release_on_tag.sh ├── plugin └── src │ └── main │ ├── resources │ └── scalac-plugin.xml │ ├── scala │ └── scala │ │ └── tools │ │ └── selectivecps │ │ ├── SelectiveCPSPlugin.scala │ │ ├── CPSUtils.scala │ │ ├── SelectiveCPSTransform.scala │ │ └── CPSAnnotationChecker.scala │ ├── scala-2.11 │ └── scala │ │ └── tools │ │ └── selectivecps │ │ └── SelectiveANFTransform.scala │ └── scala-2.12 │ └── scala │ └── tools │ └── selectivecps │ └── SelectiveANFTransform.scala ├── CODE_OF_CONDUCT.md ├── README.md ├── NOTICE ├── .gitignore ├── .travis.yml ├── library └── src │ ├── test │ └── scala │ │ └── scala │ │ └── tools │ │ └── selectivecps │ │ ├── ShouldCompile.scala │ │ ├── CompilerErrors.scala │ │ └── TestSuite.scala │ └── main │ └── scala │ └── scala │ └── util │ └── continuations │ ├── package.scala │ └── ControlContext.scala └── LICENSE /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.3.12 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.scala-lang.modules" % "sbt-scala-module" % "2.2.0") 2 | -------------------------------------------------------------------------------- /plugin/src/main/resources/scalac-plugin.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | continuations 4 | scala.tools.selectivecps.SelectiveCPSPlugin 5 | 6 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | all repositories in these organizations: 2 | 3 | * [scala](https://github.com/scala) 4 | * [scalacenter](https://github.com/scalacenter) 5 | * [lampepfl](https://github.com/lampepfl) 6 | 7 | are covered by the Scala Code of 
Conduct: https://scala-lang.org/conduct/ 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | scala-continuations is a compiler plugin and library for Scala providing support for CPS (continuation-passing style) transformations. 2 | 3 | It is no longer maintained. Past releases (for Scala 2.12 and earlier) remain available on Maven Central. 4 | 5 | You might also be interested in https://github.com/scala/scala-async, which covers what was once the most common use case for the continuations plugin. 6 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Scala continuations 2 | Copyright (c) 2010-2020 EPFL 3 | Copyright (c) 2011-2020 Lightbend, Inc. 4 | 5 | Scala includes software developed at 6 | LAMP/EPFL (https://lamp.epfl.ch/) and 7 | Lightbend, Inc. (https://www.lightbend.com/). 8 | 9 | Licensed under the Apache License, Version 2.0 (the "License"). 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | # Are you tempted to edit this file? 3 | # 4 | # First consider if the changes make sense for all, 5 | # or if they are specific to your workflow/system. 
#!/bin/bash

# Release helper for Travis CI builds.
#
# If the current commit has an (annotated) tag named like v(\d+\.\d+\.\d+.*),
# and we're running on the right jdk/branch/Scala version,
# echo the sbt commands that publish a release with the version derived from the tag.
# In every other case nothing is printed and the script finishes with status 0.
publishJdk=openjdk8
publishBranch=master
# NOTE(review): this must equal one of the Scala versions in the CI build matrix
# for a release to ever be emitted -- confirm it is still current.
publishScalaVersion=2.12.0

unset tag version

# `git describe --exact-match` fails when HEAD carries no (annotated) tag.
# That is not an error for us: `tag` simply stays empty and nothing is published.
# (An `exit` placed inside $(...) would only leave the substitution subshell,
# so the "no tag" case is handled by the checks below instead.)
tag=$(git describe HEAD --exact-match 2>/dev/null || true)

# Only tags that really look like v<major>.<minor>.<patch>[suffix] yield a version.
# Any other tag leaves `version` unset instead of leaking the raw tag name through.
tagPattern='^v([0-9]+\.[0-9]+\.[0-9]+.*)$'
if [[ "$tag" =~ $tagPattern ]]; then
  version="${BASH_REMATCH[1]}"
fi

if [[ -n "$version" &&\
      "${TRAVIS_PULL_REQUEST}" == "false" &&\
      "${JAVA_HOME}" == "$(jdk_switcher home "$publishJdk")" &&\
      "${TRAVIS_BRANCH}" == "${publishBranch}" &&\
      "${TRAVIS_SCALA_VERSION}" == "${publishScalaVersion}" ]]; then
  echo \'"set every version := $version"\' publish-signed
fi
// Regression test for https://issues.scala-lang.org/browse/SI-3620
object t3620 extends App {

  /** A key/value store whose readers can be parked until the key they need is added. */
  class Store[K,V] {

    /** A parked reader: the key it is waiting on and the callback to feed the value to. */
    trait Waiting {
      def key: K
      def inform(value: V): Unit
    }

    private val map = new HashMap[K, V]
    private var waiting: List[Waiting] = Nil

    /** Runs `f` right away when `k` is already stored; otherwise parks `f` until `add(k, ...)`. */
    def waitFor(k: K, f: (V => Unit)): Unit =
      map.get(k) match {
        case None =>
          val listener = new Waiting {
            def key = k
            def inform(v: V) = f(v)
          }
          waiting = listener :: waiting
        case Some(present) =>
          f(present)
      }

    /** Stores the binding, then informs (and un-parks) every reader waiting on `key`. */
    def add(key: K, value: V): Unit = {
      map.update(key, value)
      val (ready, stillParked) = waiting.partition(_.key == key)
      waiting = stillParked
      ready.foreach(_.inform(value))
    }

    /** CPS read: captures the continuation and hands it to `waitFor` for `key`. */
    def required(key: K) = {
      shift {
        k: (V => Unit) => {
          waitFor(key, k)
        }
      }
    }

    /** CPS read of an optional key: resumes with `None` immediately when no key is given. */
    def option(key: Option[K]) = {
      shift {
        k: (Option[V] => Unit) => {
          key match {
            case None           => k(None)
            case Some(wanted)   => waitFor(wanted, (v: V) => k(Some(v)))
          }
        }
      }
    }

  }

  val store = new Store[String, Int]

  def test(p: Option[String]): Unit = {
    reset {
      // uncommenting the following two lines makes the compiler happy!
      // val o = store.option(p)
      // println(o)
      val i = store.option(p).getOrElse(1)
      println(i)
    }
  }

  test(Some("a"))

}
/** Compiler plugin that wires the selective CPS phases and the CPS annotation
 *  checker into the given `Global`.
 *
 *  The only accepted plugin option is `enable`; every other option is reported
 *  through the error callback passed to `init`.
 */
class SelectiveCPSPlugin(val global: Global) extends Plugin {
  val name = "continuations"
  val description = "applies selective cps conversion"

  // Read from `options` during construction, so the phases below see the final value.
  val pluginEnabled = options.contains("enable")

  // Early initializers: `global` and `cpsEnabled` must be set before the
  // component trait's own initialization runs.
  val anfPhase = new {
    val global = SelectiveCPSPlugin.this.global
    val cpsEnabled = pluginEnabled
    override val enabled = cpsEnabled
  } with SelectiveANFTransform {
    val runsAfter = "pickler" :: Nil
  }

  val cpsPhase = new {
    val global = SelectiveCPSPlugin.this.global
    val cpsEnabled = pluginEnabled
    override val enabled = cpsEnabled
  } with SelectiveCPSTransform {
    val runsAfter = "selectiveanf" :: Nil
    override val runsBefore = "uncurry" :: Nil
  }

  val components: List[PluginComponent] = List(anfPhase, cpsPhase)

  val checker = new {
    val global: SelectiveCPSPlugin.this.global.type = SelectiveCPSPlugin.this.global
    val cpsEnabled = pluginEnabled
  } with CPSAnnotationChecker

  // TODO don't muck up global with unused checkers
  global.addAnnotationChecker(checker.checker)
  global.analyzer.addAnalyzerPlugin(checker.plugin)

  global.log("instantiated cps plugin: " + this)

  /** Validates the plugin options and reports whether the plugin is enabled.
   *
   *  `enable` itself needs no handling here (it was already consumed when
   *  `pluginEnabled` was initialized); anything else is passed to `error`.
   */
  override def init(options: List[String], error: String => Unit) = {
    for (arg <- options if arg != "enable")
      error(s"Bad argument: $arg")
    pluginEnabled
  }

  override val optionsHelp: Option[String] =
    Some("  -P:continuations:enable        Enable continuations")
}
/** Helpers shared by the CPS plugin components: well-known names and symbols
 *  of the continuations library, construction and inspection of the CPS marker
 *  annotations, and composition ("linearization") of answer-type annotations.
 */
trait CPSUtils {
  val global: Global
  import global._

  val cpsEnabled: Boolean

  /** Debug output is switched on with the system property `cpsVerbose=true`. */
  val verbose: Boolean = "true" == System.getProperty("cpsVerbose", "false")

  /** Prints `x` only in verbose mode; the argument is by-name, so it is not evaluated otherwise. */
  def vprintln(x: =>Any): Unit = {
    if (verbose) println(x)
  }

  /** Term names the plugin looks up or generates. */
  object cpsNames {
    val catches         = newTermName(s"$$catches")
    val ex              = newTermName(s"$$ex")
    val flatMapCatch    = newTermName("flatMapCatch")
    val getTrivialValue = newTermName("getTrivialValue")
    val isTrivial       = newTermName("isTrivial")
    val reify           = newTermName("reify")
    val reifyR          = newTermName("reifyR")
    val shift           = newTermName("shift")
    val shiftR          = newTermName("shiftR")
    val shiftSuffix     = newTermName(s"$$shift")
    val shiftUnit0      = newTermName("shiftUnit0")
    val shiftUnit       = newTermName("shiftUnit")
    val shiftUnitR      = newTermName("shiftUnitR")
  }

  // Marker annotation classes, resolved lazily from the root mirror.
  lazy val MarkerCPSSym        = rootMirror.getRequiredClass("scala.util.continuations.cpsSym")
  lazy val MarkerCPSTypes      = rootMirror.getRequiredClass("scala.util.continuations.cpsParam")
  lazy val MarkerCPSSynth      = rootMirror.getRequiredClass("scala.util.continuations.cpsSynth")
  lazy val MarkerCPSAdaptPlus  = rootMirror.getRequiredClass("scala.util.continuations.cpsPlus")
  lazy val MarkerCPSAdaptMinus = rootMirror.getRequiredClass("scala.util.continuations.cpsMinus")

  lazy val Context = rootMirror.getRequiredClass("scala.util.continuations.ControlContext")
  lazy val ModCPS = rootMirror.getPackage(TermName("scala.util.continuations"))

  // Members of the `scala.util.continuations` package, looked up by name.
  lazy val MethShiftUnit  = definitions.getMember(ModCPS, cpsNames.shiftUnit)
  lazy val MethShiftUnit0 = definitions.getMember(ModCPS, cpsNames.shiftUnit0)
  lazy val MethShiftUnitR = definitions.getMember(ModCPS, cpsNames.shiftUnitR)
  lazy val MethShift      = definitions.getMember(ModCPS, cpsNames.shift)
  lazy val MethShiftR     = definitions.getMember(ModCPS, cpsNames.shiftR)
  lazy val MethReify      = definitions.getMember(ModCPS, cpsNames.reify)
  lazy val MethReifyR     = definitions.getMember(ModCPS, cpsNames.reifyR)

  lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth,
    MarkerCPSAdaptPlus, MarkerCPSAdaptMinus)

  // TODO - needed? Can these all use the same annotation info?
  protected def newSynthMarker() = newMarker(MarkerCPSSynth)
  protected def newPlusMarker()  = newMarker(MarkerCPSAdaptPlus)
  protected def newMinusMarker() = newMarker(MarkerCPSAdaptMinus)
  protected def newMarker(tpe: Type): AnnotationInfo = AnnotationInfo.marker(tpe)
  protected def newMarker(sym: Symbol): AnnotationInfo = AnnotationInfo.marker(sym.tpe)

  /** Builds a `cpsParam[tp1, tp2]` marker annotation. */
  protected def newCpsParamsMarker(tp1: Type, tp2: Type) =
    newMarker(appliedType(MarkerCPSTypes, tp1, tp2))

  // annotation checker

  /** Extracts the two type arguments of an annotation's (dealiased, widened) type.
   *  Fails with a `MatchError` if the annotation does not have exactly two.
   */
  protected def annTypes(ann: AnnotationInfo): (Type, Type) = {
    val List(first, second) = ann.atp.dealiasWiden.typeArgs
    (first, second)
  }
  protected def hasMinusMarker(tpe: Type)   = tpe.hasAnnotation(MarkerCPSAdaptMinus)
  protected def hasPlusMarker(tpe: Type)    = tpe.hasAnnotation(MarkerCPSAdaptPlus)
  protected def hasSynthMarker(tpe: Type)   = tpe.hasAnnotation(MarkerCPSSynth)
  protected def hasCpsParamTypes(tpe: Type) = tpe.hasAnnotation(MarkerCPSTypes)
  protected def cpsParamTypes(tpe: Type)    = tpe.getAnnotation(MarkerCPSTypes).map(annTypes)

  /** Annotations of `tpe` that match the class `cls`. */
  def filterAttribs(tpe:Type, cls:Symbol) =
    tpe.annotations.filter(_.matches(cls))

  /** `tpe` with every annotation matching any of `classes` dropped. */
  def removeAttribs(tpe: Type, classes: Symbol*) =
    tpe.filterAnnotations(ann => classes.forall(cls => !ann.matches(cls)))

  def removeAllCPSAnnotations(tpe: Type) = removeAttribs(tpe, allCPSAnnotations:_*)

  def cpsParamAnnotation(tpe: Type) = filterAttribs(tpe, MarkerCPSTypes)

  /** Composes a non-empty list of `cpsParam` annotations from left to right.
   *
   *  For a step `[u0,v0] andThen [u1,v1]` the result is `[u1,v0]`, provided
   *  `v1 <:< u0`; otherwise a `TypeError` is thrown.
   */
  def linearize(ann: List[AnnotationInfo]): AnnotationInfo = {
    ann.reduceLeft { (first, second) =>
      val (u0, v0) = annTypes(first)
      val (u1, v1) = annTypes(second)

      if (!(v1 <:< u0))
        throw new TypeError("illegal answer type modification: " + first + " andThen " + second)
      newCpsParamsMarker(u1, v0)
    }
  }

  // anf transform

  /** The `cpsParam` type pair of `tp`, if any. When there is none but `tp`
   *  carries the plus marker, a warning is issued before returning `None`.
   */
  def getExternalAnswerTypeAnn(tp: Type) = {
    cpsParamTypes(tp) match {
      case found @ Some(_) =>
        found
      case None =>
        if (hasPlusMarker(tp))
          global.warning("trying to instantiate type " + tp + " to unknown cps type")
        None
    }
  }

  /** The `cpsParam` type pair of `tp`, unless `tp` also carries the plus marker. */
  def getAnswerTypeAnn(tp: Type): Option[(Type, Type)] =
    if (hasPlusMarker(tp)) None else cpsParamTypes(tp)

  def hasAnswerTypeAnn(tp: Type) =
    !hasPlusMarker(tp) && hasCpsParamTypes(tp)

  /** Strips all CPS annotations from the tree's type if *we* added them (@synth present). */
  def updateSynthFlag(tree: Tree) = {
    if (!hasSynthMarker(tree.tpe))
      tree
    else {
      log("removing annotation from " + tree)
      tree.modifyType(removeAllCPSAnnotations)
    }
  }

  type CPSInfo = Option[(Type,Type)]

  /** Composes two optional answer-type pairs.
   *
   *  When both are present, `[u0,v0] andThen [u1,v1]` yields `[u1,v0]` if
   *  `v1 <:< u0`; otherwise an error is reported at `pos` and an exception is
   *  thrown. When only one side is present it is returned unchanged.
   */
  def linearize(a: CPSInfo, b: CPSInfo)(implicit pos: Position): CPSInfo = {
    (a, b) match {
      case (Some((u0, v0)), Some((u1, v1))) =>
        vprintln("check lin " + a + " andThen " + b)
        if (v1 <:< u0)
          Some((u1, v0))
        else {
          reporter.error(pos,"cannot change answer type in composition of cps expressions " +
            "from " + u1 + " to " + v0 + " because " + v1 + " is not a subtype of " + u0 + ".")
          throw new Exception("check lin " + a + " andThen " + b)
        }
      case (Some(_), _) => a
      case (_, Some(_)) => b
      case _            => None
    }
  }
}
42 | * 43 | * {{{ 44 | * val sessions = new HashMap[UUID, Int=>Unit] 45 | * def ask(prompt: String): Int @cps[Unit] = 46 | * shift { 47 | * k: (Int => Unit) => { 48 | * val id = uuidGen 49 | * printf("%s\nrespond with: submit(0x%x, ...)\n", prompt, id) 50 | * sessions += id -> k 51 | * } 52 | * } 53 | * }}} 54 | * 55 | * The type of `ask` includes a `@cps` annotation which drives the transformation. 56 | * The type signature `Int @cps[Unit]` means that `ask` should be used in a 57 | * context requiring an `Int`, but actually it will suspend and return `Unit`. 58 | * 59 | * The computation leading up to the first `ask` is executed normally. The 60 | * remainder of the reset block is wrapped into a closure that is passed as 61 | * the parameter `k` to the `shift` function, which can then decide whether 62 | * and how to execute the continuation. In this example, the continuation is 63 | * stored in a sessions map for later execution. This continuation includes a 64 | * second call to `ask`, which is treated likewise once the execution resumes. 65 | * 66 | *

<h2>CPS Annotation</h2>

package object continuations {

  /** Annotation alias for a type that lives in a continuation context whose
   *  answer type is not modified: `@cps[A]` stands for `@cpsParam[A,A]`.
   *  @tparam A  The answer type of the continuation context.
   */
  type cps[A] = cpsParam[A,A]

  /** Annotation alias for a side-effecting continuation context.
   *  `@suspendable` is the same as `@cps[Unit]`, i.e. `@cpsParam[Unit,Unit]`.
   */
  type suspendable = cps[Unit]

  /**
   * Captures the rest of the computation inside the enclosing `reset` block
   * and hands it, as a function, to the user-supplied closure.
   *
   * For example:
   * {{{
   *   reset {
   *       shift { (k: Int => Int) => k(5) } + 1
   *   }
   * }}}
   *
   * Here `shift` appears in the expression `shift ... + 1`. The compiler
   * rewrites the expression so that what follows the `shift` becomes the
   * argument of the closure, roughly:
   * {{{
   *   { (k: Int => Int) => k(5) } apply { _ + 1 }
   * }}}
   * which evaluates to 6.
   *
   * A `reset` block may contain several `shift` calls. Each of them can alter
   * the return type of expressions inside the block, but not the return type
   * of the whole `reset { block }` expression.
   *
   * @param fun  A function where
   *   - the parameter is the remainder of the computation within the current
   *     `reset` block, passed as a function `A => B`;
   *   - the result is the return value of the `ControlContext` generated
   *     from this inversion.
   * @note Must be invoked in the context of a call to `reset`. That call need
   *       not be the immediate caller, but a `reset` is required to eventually
   *       remove the `@cps` annotations from types.
   */
  def shift[A,B,C](fun: (A => B) => C): A @cpsParam[B,C] = {
    throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled")
  }

  /** Delimits the continuations captured inside the argument block and returns
   *  the result of the whole transformed computation. Within `reset { block }`,
   *  `block` is rewritten so that at every `shift` the remainder of the
   *  expression becomes a function passed into that `shift`.
   *  @return The result of a block of code that uses `shift` to capture continuations.
   */
  def reset[A,C](ctx: =>(A @cpsParam[A,C])): C = {
    val reified = reify[A,A,C](ctx)
    if (!reified.isTrivial)
      reified.foreach((a: A) => a)
    else
      reified.getTrivialValue.asInstanceOf[C]
  }

  /** `reset` for the case where the answer type is not modified. */
  def reset0[A](ctx: =>(A @cpsParam[A,A])): A = reset(ctx)

  /** Reifies `ctx` and runs it with a continuation that discards its argument
   *  and returns `()`; a trivial context yields its trivial value cast to `A`.
   */
  def run[A](ctx: =>(Any @cpsParam[Unit,A])): A = {
    val reified = reify[Any,Unit,A](ctx)
    if (!reified.isTrivial)
      reified.foreach((_: Any) => ())
    else
      reified.getTrivialValue.asInstanceOf[A]
  }


  // methods below are primarily implementation details and are not
  // needed frequently in client code

  /** `shiftUnit` specialised to an unmodified answer type `B`. */
  def shiftUnit0[A,B](x: A): A @cpsParam[B,B] = {
    shiftUnit[A,B,B](x)
  }

  def shiftUnit[A,B,C>:B](x: A): A @cpsParam[B,C] = {
    throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled")
  }

  /** Converts from the sugared `A @cpsParam[B,C]` type to the desugared
   *  `ControlContext[A,B,C]` type. The underlying data is not changed.
   */
  def reify[A,B,C](ctx: =>(A @cpsParam[B,C])): ControlContext[A,B,C] = {
    throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled")
  }

  /** Wraps a plain value in a `ControlContext` with no function part. */
  def shiftUnitR[A,B](x: A): ControlContext[A,B,B] = { // called in code generated by SelectiveCPSTransform
    new ControlContext[A, B, B](null, x)
  }

  /**
   * Captures a computation into a `ControlContext`.
   * @param fun  The function which accepts the inverted computation and returns
   *             a final result.
   * @see shift
   */
  def shiftR[A,B,C](fun: (A => B) => C): ControlContext[A,B,C] = { // called in code generated by SelectiveCPSTransform
    new ControlContext((cont: A => B, handler: Exception => B) => fun(cont), null.asInstanceOf[A])
  }

  /** Evaluates the by-name `ControlContext` argument and returns it unchanged. */
  def reifyR[A,B,C](ctx: => ControlContext[A,B,C]): ControlContext[A,B,C] = { // called in code generated by SelectiveCPSTransform
    ctx
  }

}
| """|def main(args: Array[String]): Any = { 44 | | val f = () => 7 45 | | val g: () => Int @cps[Int] = f 46 | | 47 | | println(reset(g())) 48 | |}""") 49 | 50 | @Test def function3 = 51 | expectCPSError( 52 | """|type mismatch; 53 | | found : Int @scala.util.continuations.cpsParam[Int,Int] 54 | | required: Int""".stripMargin, 55 | """|def main(args: Array[String]): Any = { 56 | | val g: () => Int = () => shift { k: (Int=>Int) => k(7) } 57 | | 58 | | println(reset(g())) 59 | |}""") 60 | 61 | @Test def infer2 = 62 | expectCPSError("illegal answer type modification: scala.util.continuations.cpsParam[String,Int] andThen scala.util.continuations.cpsParam[String,Int]", 63 | """|def test(x: => Int @cpsParam[String,Int]) = 7 64 | | 65 | |def sym() = shift { k: (Int => String) => 9 } 66 | | 67 | | 68 | |def main(args: Array[String]): Any = { 69 | | test { sym(); sym() } 70 | |}""") 71 | 72 | @Test def `lazy` = 73 | expectCPSError("implementation restriction: cps annotations not allowed on lazy value definitions", 74 | """|def foo() = { 75 | | lazy val x = shift((k:Unit=>Unit)=>k()) 76 | | println(x) 77 | |} 78 | | 79 | |def main(args: Array[String]) = { 80 | | reset { 81 | | foo() 82 | | } 83 | |}""") 84 | 85 | @Test def t1929 = 86 | expectCPSError( 87 | """|type mismatch; 88 | | found : Int @scala.util.continuations.cpsParam[String,String] @scala.util.continuations.cpsSynth 89 | | required: Int @scala.util.continuations.cpsParam[Int,String]""".stripMargin, 90 | """|def main(args : Array[String]) { 91 | | reset { 92 | | println("up") 93 | | val x = shift((k:Int=>String) => k(8) + k(2)) 94 | | println("down " + x) 95 | | val y = shift((k:Int=>String) => k(3)) 96 | | println("down2 " + y) 97 | | y + x 98 | | } 99 | |}""") 100 | 101 | @Test def t2285 = 102 | expectCPSError( 103 | """|type mismatch; 104 | | found : Int @scala.util.continuations.cpsParam[String,String] @scala.util.continuations.cpsSynth 105 | | required: Int 
@scala.util.continuations.cpsParam[Int,String]""".stripMargin, 106 | """|def bar() = shift { k: (String => String) => k("1") } 107 | | 108 | |def foo() = reset { bar(); 7 }""") 109 | 110 | @Test def t2949 = 111 | expectCPSError( 112 | """|type mismatch; 113 | | found : Int 114 | | required: ? @scala.util.continuations.cpsParam[List[?],Any]""".stripMargin, 115 | """|def reflect[A,B](xs : List[A]) = shift{ xs.flatMap[B, List[B]] } 116 | |def reify[A, B](x : A @cpsParam[List[A], B]) = reset{ List(x) } 117 | | 118 | |def main(args: Array[String]): Unit = println(reify { 119 | | val x = reflect[Int, Int](List(1,2,3)) 120 | | val y = reflect[Int, Int](List(2,4,8)) 121 | | x * y 122 | |})""") 123 | 124 | @Test def t3718 = 125 | expectCPSError( 126 | "cannot cps-transform malformed (possibly in shift/reset placement) expression", 127 | "scala.util.continuations.reset((_: Any).##)") 128 | 129 | @Test def t5314_missing_result_type = 130 | expectCPSError( 131 | "method bar has return statement; needs result type", 132 | """|def foo(x:Int): Int @cps[Int] = x 133 | | 134 | |def bar(x:Int) = return foo(x) 135 | | 136 | |reset { 137 | | val res = bar(8) 138 | | println(res) 139 | | res 140 | |}""") 141 | 142 | @Test def t5314_npe = 143 | expectCPSError( 144 | "method bar has return statement; needs result type", 145 | "def bar(x:Int) = { return x; x } // NPE") 146 | 147 | @Test def t5314_return_reset = 148 | expectCPSError( 149 | "return expression not allowed, since method calls CPS method", 150 | """|val rnd = new scala.util.Random 151 | | 152 | |def foo(x: Int): Int @cps[Int] = shift { k => k(x) } 153 | | 154 | |def bar(x: Int): Int @cps[Int] = return foo(x) 155 | | 156 | |def caller(): Int = { 157 | | val v: Int = reset { 158 | | val res: Int = bar(8) 159 | | if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset` 160 | | 42 161 | | } 162 | | v 163 | |} 164 | | 165 | |caller()""") 166 | 167 | @Test def t5314_type_error = 168 | expectCPSError( 169 | 
"""|type mismatch; 170 | | found : Int @scala.util.continuations.cpsParam[Int,Int] 171 | | required: Int @scala.util.continuations.cpsParam[String,String]""".stripMargin, 172 | """|def foo(x:Int): Int @cps[Int] = shift { k => k(x) } 173 | | 174 | |// should be a type error 175 | |def bar(x:Int): Int @cps[String] = return foo(x) 176 | | 177 | |def caller(): Unit = { 178 | | val v: String = reset { 179 | | val res: Int = bar(8) 180 | | "hello" 181 | | } 182 | |} 183 | | 184 | |caller()""") 185 | 186 | @Test def t5445 = 187 | expectCPSError( 188 | "cps annotations not allowed on by-value parameters or value definitions", 189 | "def foo(block: Unit @suspendable ): Unit @suspendable = {}") 190 | 191 | @Test def trycatch2 = 192 | expectCPSErrors(2, "only simple cps types allowed in try/catch blocks (found: Int @scala.util.continuations.cpsParam[String,Int])", 193 | """|def fatal[T]: T = throw new Exception 194 | |def cpsIntStringInt = shift { k:(Int=>String) => k(3); 7 } 195 | |def cpsIntIntString = shift { k:(Int=>Int) => k(3); "7" } 196 | | 197 | |def foo1 = try { 198 | | fatal[Int] 199 | | cpsIntStringInt 200 | |} catch { 201 | | case ex: Throwable => 202 | | cpsIntStringInt 203 | |} 204 | | 205 | |def foo2 = try { 206 | | fatal[Int] 207 | | cpsIntStringInt 208 | |} catch { 209 | | case ex: Throwable => 210 | | cpsIntStringInt 211 | |} 212 | | 213 | | 214 | |def main(args: Array[String]): Unit = { 215 | | println(reset { foo1; "3" }) 216 | | println(reset { foo2; "3" }) 217 | |}""") 218 | } 219 | 220 | class CompilerTesting { 221 | private def pluginJar: String = { 222 | val f = sys.props("scala-continuations-plugin.jar") 223 | assert(new java.io.File(f).exists, f) 224 | f 225 | } 226 | def loadPlugin = s"-Xplugin:${pluginJar} -P:continuations:enable" 227 | 228 | // note: `code` should have a | margin 229 | def cpsErrorMessages(msg: String, code: String) = 230 | errorMessages(msg, loadPlugin)(s"import scala.util.continuations._\nobject Test {\n${code.stripMargin}\n}") 
231 | 232 | def expectCPSError(msg: String, code: String) = { 233 | val errors = cpsErrorMessages(msg, code) 234 | assert(errors exists (_ contains msg), errors mkString "\n") 235 | } 236 | 237 | def expectCPSErrors(msgCount: Int, msg: String, code: String) = { 238 | val errors = cpsErrorMessages(msg, code) 239 | val errorCount = errors.filter(_ contains msg).length 240 | assert(errorCount == msgCount, s"$errorCount occurrences of \'$msg\' found -- expected $msgCount in:\n${errors mkString "\n"}") 241 | } 242 | 243 | // TODO: move to scala.tools.reflect.ToolboxFactory 244 | def errorMessages(errorSnippet: String, compileOptions: String)(code: String): List[String] = { 245 | import scala.tools.reflect._ 246 | val m = scala.reflect.runtime.currentMirror 247 | val tb = m.mkToolBox(options = compileOptions) //: ToolBox[m.universe.type] 248 | val fe = tb.frontEnd 249 | 250 | try { 251 | tb.eval(tb.parse(code)) 252 | Nil 253 | } catch { 254 | case _: ToolBoxError => 255 | import fe._ 256 | infos.toList collect { case Info(_, msg, ERROR) => msg } 257 | } 258 | } 259 | } -------------------------------------------------------------------------------- /library/src/main/scala/scala/util/continuations/ControlContext.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Scala (https://www.scala-lang.org) 3 | * 4 | * Copyright EPFL and Lightbend, Inc. 5 | * 6 | * Licensed under Apache License 2.0 7 | * (http://www.apache.org/licenses/LICENSE-2.0). 8 | * 9 | * See the NOTICE file distributed with this work for 10 | * additional information regarding copyright ownership. 11 | */ 12 | 13 | package scala.util.continuations 14 | 15 | import scala.annotation.{ Annotation, StaticAnnotation, TypeConstraint } 16 | 17 | /** This annotation is used to mark a parameter as part of a continuation 18 | * context. 19 | * 20 | * The type `A @cpsParam[B,C]` is desugared to `ControlContext[A,B,C]` at compile 21 | * time. 
/** This annotation is used to mark a parameter as part of a continuation
 * context.
 *
 * The type `A @cpsParam[B,C]` is desugared to `ControlContext[A,B,C]` at compile
 * time.
 *
 * @tparam B  The type of computation state after computation has executed, and
 *            before control is returned to the shift.
 * @tparam C  The eventual return type of this delimited computation.
 * @see scala.util.continuations.ControlContext
 */
class cpsParam[-B,+C] extends StaticAnnotation with TypeConstraint

private class cpsSym[B] extends Annotation // implementation detail

// NOTE(review): the name suggests this marks synthesized (inferred) cps types; confirm against the plugin.
private class cpsSynth extends Annotation // implementation detail

private class cpsPlus extends StaticAnnotation with TypeConstraint // implementation detail
private class cpsMinus extends Annotation // implementation detail
/**
 * This class represents a portion of computation that has a 'hole' in it. The
 * class has the ability to compute state up until a certain point where the
 * state has the `A` type. If this context is given a function of type
 * `A => B` to move the state to the `B` type, then the entire computation can
 * be completed resulting in a value of type `C`.
 *
 * An Example: {{{
 *   val cc = new ControlContext[String, String, String](
 *      fun = { (f: String=>String, err: Exception => String) =>
 *        val updatedState =
 *          try f("State")
 *          catch {
 *            case e: Exception => err(e)
 *          }
 *        updatedState + "-Complete!"
 *      },
 *      x = null.asInstanceOf[String]
 *   )
 *   cc.foreach(_ + "-Continued")  // Results in "State-Continued-Complete!"
 * }}}
 *
 * This class is used to transform calls to `shift` in the `continuations`
 * package. Direct use and instantiation is possible, but usually reserved
 * for advanced cases.
 *
 *
 * A context may either be ''trivial'' or ''non-trivial''. A ''trivial''
 * context '''just''' has a state of type `A`. When completing the computation,
 * it's only necessary to use the function of type `A => B` directly against
 * the trivial value. A ''non-trivial'' value stores a computation '''around'''
 * the state transformation of type `A => B` and cannot be short-circuited.
 *
 * @param fun The captured computation so far. The type
 *            `(A => B, Exception => B) => C` is a function where:
 *            - The first parameter `A=>B` represents the computation defined against
 *              the current state held in the ControlContext.
 *            - The second parameter `Exception => B` represents a computation to
 *              perform if an exception is thrown from the first parameter's computation.
 *            - The return value is the result of the entire computation contained in this
 *              `ControlContext`.
 * @param x   The current state stored in this context. Allowed to be null if
 *            the context is non-trivial.
 * @tparam A  The type of the state currently held in the context.
 * @tparam B  The type of the transformed state needed to complete this computation.
 * @tparam C  The return type of the entire computation stored in this context.
 * @note `fun` and `x` are allowed to be `null`.
 * @see scala.util.continuations.shiftR
 */
final class ControlContext[+A,-B,+C](val fun: (A => B, Exception => B) => C, val x: A) extends Serializable {

  /*
    final def map[A1](f: A => A1): ControlContext[A1,B,C] = {
      new ControlContext((k:(A1 => B)) => fun((x:A) => k(f(x))), null.asInstanceOf[A1])
    }

    final def flatMap[A1,B1<:B](f: (A => ControlContext[A1,B1,B])): ControlContext[A1,B1,C] = {
      new ControlContext((k:(A1 => B1)) => fun((x:A) => f(x).fun(k)))
    }
  */

  /**
   * Modifies the currently captured state in this `ControlContext`.
   *
   * An `Exception` thrown by `f` itself is not propagated; it is handed to the
   * exception handler of the resulting context instead. Exceptions thrown by the
   * downstream continuation (after `f` has completed) are not intercepted here.
   *
   * @tparam A1 The new type of state in this context.
   * @param f A transformation function on the current state of the `ControlContext`.
   * @return The new `ControlContext`.
   */
  @noinline final def map[A1](f: A => A1): ControlContext[A1,B,C] = {
    if (fun eq null)
      try {
        new ControlContext[A1,B,C](null, f(x)) // TODO: only alloc if f(x) != x
      } catch {
        case ex: Exception =>
          // f failed on the trivial value: produce a non-trivial context that only reports ex
          new ControlContext((k: A1 => B, thr: Exception => B) => thr(ex).asInstanceOf[C], null.asInstanceOf[A1])
      }
    else
      new ControlContext({ (k: A1 => B, thr: Exception => B) =>
        fun( { (x:A) =>
          // `done` separates failures of f (reported to thr) from failures of k (left to propagate)
          var done = false
          try {
            val res = f(x)
            done = true
            k(res)
          } catch {
            case ex: Exception if !done =>
              thr(ex)
          }
        }, thr)
      }, null.asInstanceOf[A1])
  }


  // it would be nice if @inline would turn the trivial path into a tail call.
  // unfortunately it doesn't, so we do it ourselves in SelectiveCPSTransform

  /**
   * Maps and flattens this `ControlContext` with another `ControlContext` generated from the current state.
   * @note The resulting computation is still the type `C`.
   * @tparam A1 The new type of the contained state.
   * @tparam B1 The new type of the state after the stored continuation has executed.
   * @tparam C1 The result type of the nested `ControlContext`.  Because the nested `ControlContext` is executed within
   *            the outer `ControlContext`, this type must `<: B` so that the resulting nested computation can be fed through
   *            the current continuation.
   * @param f A transformation function from the current state to a nested `ControlContext`.
   * @return The transformed `ControlContext`.
   */
  @noinline final def flatMap[A1,B1,C1<:B](f: (A => ControlContext[A1,B1,C1])): ControlContext[A1,B1,C] = {
    if (fun eq null)
      try {
        f(x).asInstanceOf[ControlContext[A1,B1,C]]
      } catch {
        case ex: Exception =>
          new ControlContext((k: A1 => B1, thr: Exception => B1) => thr(ex).asInstanceOf[C], null.asInstanceOf[A1])
      }
    else
      new ControlContext({ (k: A1 => B1, thr: Exception => B1) =>
        fun( { (x:A) =>
          // only failures of f are reported to thr; failures while running the nested context propagate
          var done = false
          try {
            val ctxR = f(x)
            done = true
            val res: C1 = ctxR.foreachFull(k, thr) // => B1
            res
          } catch {
            case ex: Exception if !done =>
              thr(ex).asInstanceOf[B] // => B NOTE: in general this is unsafe!
          }                           // However, the plugin will not generate offending code
        }, thr.asInstanceOf[Exception=>B]) // => B
      }, null.asInstanceOf[A1])
  }

  /**
   * Runs the computation against the state stored in this `ControlContext`.
   * @param f the computation that modifies the current state of the context.
   * @note This method could throw exceptions from the computations.
   */
  final def foreach(f: A => B) = foreachFull(f, throw _)

  /**
   * Runs the computation against the state stored in this `ControlContext`,
   * with an explicit exception handler.
   *
   * For a trivial context `f` is applied directly to `x` and `g` is not used;
   * otherwise both functions are handed to the captured computation `fun`.
   *
   * @param f the computation that modifies the current state of the context.
   * @param g the handler passed to the captured computation for reported exceptions.
   * @return the result of the completed computation.
   */
  def foreachFull(f: A => B, g: Exception => B): C = {
    if (fun eq null)
      f(x).asInstanceOf[C]
    else
      fun(f, g)
  }

  /** @return true if this context only stores a state value and not any deferred computation. */
  final def isTrivial = fun eq null
  /** @return The current state value. */
  final def getTrivialValue = x.asInstanceOf[A]

  // need filter or other functions?

  /**
   * Installs `pf` as a recovery handler around the captured computation.
   *
   * A trivial context is returned unchanged. Otherwise, when the captured
   * computation reports an exception `t`:
   *  - if `pf` is defined at `t`, the context `pf(t)` is run with the outer
   *    continuation and the outer exception handler;
   *  - otherwise `t` is forwarded to the outer exception handler.
   *
   * An `Exception` thrown by `pf.isDefinedAt` or by `pf` itself is forwarded to
   * the outer exception handler; exceptions thrown after that point propagate.
   *
   * @param pf the recovery function producing a replacement context.
   * @return this context if it is trivial, otherwise a new non-trivial context.
   */
  final def flatMapCatch[A1>:A,B1<:B,C1>:C<:B1](pf: PartialFunction[Exception, ControlContext[A1,B1,C1]]): ControlContext[A1,B1,C1] = { // called by codegen from SelectiveCPSTransform
    if (fun eq null)
      this
    else {
      val fun1 = (ret1: A1 => B1, thr1: Exception => B1) => {
        val thr: Exception => B1 = { t: Exception =>
          var captureExceptions = true
          try {
            if (pf.isDefinedAt(t)) {
              val cc1 = pf(t)
              captureExceptions = false
              cc1.foreachFull(ret1, thr1) // Throw => B
            } else {
              captureExceptions = false
              thr1(t) // Throw => B1
            }
          } catch {
            case t1: Exception if captureExceptions => thr1(t1) // => E2
          }
        }
        fun(ret1, thr)// fun(ret1, thr)  // => B
      }
      new ControlContext(fun1, null.asInstanceOf[A1])
    }
  }

  /**
   * Attaches a finalizer `f` to this context.
   *
   * For a trivial context `f` is run immediately; if it throws an `Exception`,
   * the result is a non-trivial context that only reports that exception.
   *
   * For a non-trivial context `f` is run
   *  - before the outer continuation, when the captured computation invokes its
   *    continuation, and
   *  - before the outer exception handler, when the captured computation reports
   *    an exception.
   * In both cases an `Exception` thrown by `f` itself is forwarded to the outer
   * exception handler in place of continuing.
   *
   * @note If the captured computation catches an exception thrown by the outer
   *       continuation and then reports it through its exception handler, `f`
   *       runs on both paths.
   * @param f the finalizer to run.
   * @return this context (trivial case, finalizer succeeded) or a new context.
   */
  final def mapFinally(f: () => Unit): ControlContext[A,B,C] = { // called in code generated by SelectiveCPSTransform
    if (fun eq null) {
      try {
        f()
        this
      } catch {
        case ex: Exception =>
          new ControlContext((k: A => B, thr: Exception => B) => thr(ex).asInstanceOf[C], null.asInstanceOf[A])
      }
    } else {
      val fun1 = (ret1: A => B, thr1: Exception => B) => {
        // normal path: finalizer first, then the outer continuation
        val ret: A => B = { x: A =>
          var captureExceptions = true
          try {
            f()
            captureExceptions = false
            ret1(x)
          } catch {
            case t1: Exception if captureExceptions => thr1(t1)
          }
        }
        // exceptional path: finalizer first, then the outer exception handler
        val thr: Exception => B = { t: Exception =>
          var captureExceptions = true
          try {
            f()
            captureExceptions = false
            thr1(t)
          } catch {
            case t1: Exception if captureExceptions => thr1(t1)
          }
        }
        // pass the wrapping handler `thr` (not the raw `thr1`) so the finalizer is
        // not skipped when the captured computation reports an exception
        fun(ret, thr)
      }
      new ControlContext(fun1, null.asInstanceOf[A])
    }
  }

}
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /plugin/src/main/scala/scala/tools/selectivecps/SelectiveCPSTransform.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Scala (https://www.scala-lang.org) 3 | * 4 | * Copyright EPFL and Lightbend, Inc. 5 | * 6 | * Licensed under Apache License 2.0 7 | * (http://www.apache.org/licenses/LICENSE-2.0). 8 | * 9 | * See the NOTICE file distributed with this work for 10 | * additional information regarding copyright ownership. 11 | */ 12 | 13 | package scala.tools.selectivecps 14 | 15 | import scala.tools.nsc.transform._ 16 | import scala.tools.nsc.plugins._ 17 | import scala.tools.nsc.ast._ 18 | 19 | /** 20 | * In methods marked @cps, CPS-transform assignments introduced by ANF-transform phase. 
21 | */ 22 | abstract class SelectiveCPSTransform extends PluginComponent with 23 | InfoTransform with TypingTransformers with CPSUtils with TreeDSL { 24 | // inherits abstract value `global` and class `Phase` from Transform 25 | 26 | import global._ // the global environment 27 | import definitions._ // standard classes and methods 28 | import typer.atOwner // methods to type trees 29 | 30 | override def description = "@cps-driven transform of selectiveanf assignments" 31 | 32 | /** the following two members override abstract members in Transform */ 33 | val phaseName: String = "selectivecps" 34 | 35 | protected def newTransformer(unit: CompilationUnit): Transformer = 36 | new CPSTransformer(unit) 37 | 38 | /** This class does not change linearization */ 39 | override def changesBaseClasses = false 40 | 41 | /** - return symbol's transformed type, 42 | */ 43 | def transformInfo(sym: Symbol, tp: Type): Type = { 44 | if (!cpsEnabled) return tp 45 | 46 | val newtp = transformCPSType(tp) 47 | 48 | if (newtp != tp) 49 | debuglog("transformInfo changed type for " + sym + " to " + newtp); 50 | 51 | if (sym == MethReifyR) 52 | debuglog("transformInfo (not)changed type for " + sym + " to " + newtp); 53 | 54 | newtp 55 | } 56 | 57 | def transformCPSType(tp: Type): Type = { // TODO: use a TypeMap? need to handle more cases? 
58 | tp match { 59 | case PolyType(params,res) => PolyType(params, transformCPSType(res)) 60 | case NullaryMethodType(res) => NullaryMethodType(transformCPSType(res)) 61 | case MethodType(params,res) => MethodType(params, transformCPSType(res)) 62 | case TypeRef(pre, sym, args) => TypeRef(pre, sym, args.map(transformCPSType(_))) 63 | case _ => 64 | getExternalAnswerTypeAnn(tp) match { 65 | case Some((res, outer)) => 66 | appliedType(Context.tpeHK, List(removeAllCPSAnnotations(tp), res, outer)) 67 | case _ => 68 | removeAllCPSAnnotations(tp) 69 | } 70 | } 71 | } 72 | 73 | 74 | class CPSTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { 75 | private val patmatTransformer = patmat.newTransformer(unit) 76 | 77 | override def transform(tree: Tree): Tree = { 78 | if (!cpsEnabled) return tree 79 | postTransform(mainTransform(tree)) 80 | } 81 | 82 | def postTransform(tree: Tree): Tree = { 83 | tree.setType(transformCPSType(tree.tpe)) 84 | } 85 | 86 | 87 | def mainTransform(tree: Tree): Tree = { 88 | tree match { 89 | 90 | // TODO: can we generalize this? 91 | 92 | case Apply(TypeApply(fun, targs), args) 93 | if (fun.symbol == MethShift) => 94 | debuglog("found shift: " + tree) 95 | atPos(tree.pos) { 96 | val funR = gen.mkAttributedRef(MethShiftR) // TODO: correct? 97 | //gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedIdent(ScalaPackage), 98 | //ScalaPackage.tpe.member("util")), ScalaPackage.tpe.member("util").tpe.member("continuations")), MethShiftR) 99 | //gen.mkAttributedRef(ModCPS.tpe, MethShiftR) // TODO: correct? 
100 | debuglog("funR.tpe: " + funR.tpe) 101 | Apply( 102 | TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))), 103 | args.map(transform(_)) 104 | ).setType(transformCPSType(tree.tpe)) 105 | } 106 | 107 | case Apply(TypeApply(fun, targs), args) 108 | if (fun.symbol == MethShiftUnit) => 109 | debuglog("found shiftUnit: " + tree) 110 | atPos(tree.pos) { 111 | val funR = gen.mkAttributedRef(MethShiftUnitR) // TODO: correct? 112 | debuglog("funR.tpe: " + funR.tpe) 113 | Apply( 114 | TypeApply(funR, List(targs(0), targs(1))).setType(appliedType(funR.tpe, 115 | List(targs(0).tpe, targs(1).tpe))), 116 | args.map(transform(_)) 117 | ).setType(appliedType(Context.tpeHK, List(targs(0).tpe,targs(1).tpe,targs(1).tpe))) 118 | } 119 | 120 | case Apply(TypeApply(fun, targs), args) 121 | if (fun.symbol == MethReify) => 122 | log("found reify: " + tree) 123 | atPos(tree.pos) { 124 | val funR = gen.mkAttributedRef(MethReifyR) // TODO: correct? 125 | debuglog("funR.tpe: " + funR.tpe) 126 | Apply( 127 | TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))), 128 | args.map(transform(_)) 129 | ).setType(transformCPSType(tree.tpe)) 130 | } 131 | 132 | case Try(block, catches, finalizer) => 133 | // currently duplicates the catch block into a partial function. 134 | // this is kinda risky, but we don't expect there will be lots 135 | // of try/catches inside catch blocks (exp. blowup unlikely). 136 | 137 | // CAVEAT: finalizers are surprisingly tricky! 138 | // the problem is that they cannot easily be removed 139 | // from the regular control path and hence will 140 | // also be invoked after creating the Context object. 
141 | 142 | /* 143 | object Test { 144 | def foo1 = { 145 | throw new Exception("in sub") 146 | shift((k:Int=>Int) => k(1)) 147 | 10 148 | } 149 | def foo2 = { 150 | shift((k:Int=>Int) => k(2)) 151 | 20 152 | } 153 | def foo3 = { 154 | shift((k:Int=>Int) => k(3)) 155 | throw new Exception("in sub") 156 | 30 157 | } 158 | def foo4 = { 159 | shift((k:Int=>Int) => 4) 160 | throw new Exception("in sub") 161 | 40 162 | } 163 | def bar(x: Int) = try { 164 | if (x == 1) 165 | foo1 166 | else if (x == 2) 167 | foo2 168 | else if (x == 3) 169 | foo3 170 | else //if (x == 4) 171 | foo4 172 | } catch { 173 | case _ => 174 | println("exception") 175 | 0 176 | } finally { 177 | println("done") 178 | } 179 | } 180 | 181 | reset(Test.bar(1)) // should print: exception,done,0 182 | reset(Test.bar(2)) // should print: done,20 <-- but prints: done,done,20 183 | reset(Test.bar(3)) // should print: exception,done,0 <-- but prints: done,exception,done,0 184 | reset(Test.bar(4)) // should print: 4 <-- but prints: done,4 185 | */ 186 | 187 | val block1 = transform(block) 188 | val catches1 = transformCaseDefs(catches) 189 | val finalizer1 = transform(finalizer) 190 | 191 | if (hasAnswerTypeAnn(tree.tpe)) { 192 | //vprintln("CPS Transform: " + tree + "/" + tree.tpe + "/" + block1.tpe) 193 | 194 | val (stms, expr1) = block1 match { 195 | case Block(stms, expr) => (stms, expr) 196 | case expr => (Nil, expr) 197 | } 198 | 199 | val targettp = transformCPSType(tree.tpe) 200 | 201 | val pos = catches.head.pos 202 | val funSym = currentOwner.newValueParameter(cpsNames.catches, pos).setInfo(appliedType(PartialFunctionClass, ThrowableTpe, targettp)) 203 | val funDef = localTyper.typedPos(pos) { 204 | ValDef(funSym, Match(EmptyTree, catches1)) 205 | } 206 | val expr2 = localTyper.typedPos(pos) { 207 | Apply(Select(expr1, expr1.tpe.member(cpsNames.flatMapCatch)), List(Ident(funSym))) 208 | } 209 | 210 | val exSym = currentOwner.newValueParameter(cpsNames.ex, pos).setInfo(ThrowableTpe) 211 | 212 | 
import CODE._ 213 | // generate a case that is supported directly by the back-end 214 | val catchIfDefined = CaseDef( 215 | Bind(exSym, Ident(nme.WILDCARD)), 216 | EmptyTree, 217 | IF ((REF(funSym) DOT nme.isDefinedAt)(REF(exSym))) THEN (REF(funSym) APPLY (REF(exSym))) ELSE Throw(REF(exSym)) 218 | ) 219 | 220 | val catch2 = localTyper.typedCases(List(catchIfDefined), ThrowableTpe, targettp) 221 | //typedCases(tree, catches, ThrowableTpe, pt) 222 | 223 | patmatTransformer.transform(localTyper.typed(Block(List(funDef), treeCopy.Try(tree, treeCopy.Block(block1, stms, expr2), catch2, finalizer1)))) 224 | 225 | 226 | /* 227 | disabled for now - see notes above 228 | 229 | val expr3 = if (!finalizer.isEmpty) { 230 | val pos = finalizer.pos 231 | val finalizer2 = duplicateTree(finalizer1) 232 | val fun = Function(List(), finalizer2) 233 | val expr3 = localTyper.typedPos(pos) { Apply(Select(expr2, expr2.tpe.member("mapFinally")), List(fun)) } 234 | 235 | val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol) 236 | chown.traverse(finalizer2) 237 | 238 | expr3 239 | } else 240 | expr2 241 | */ 242 | } else { 243 | treeCopy.Try(tree, block1, catches1, finalizer1) 244 | } 245 | 246 | case Block(stms, expr) => 247 | 248 | val (stms1, expr1) = transBlock(stms, expr) 249 | treeCopy.Block(tree, stms1, expr1) 250 | 251 | case _ => 252 | super.transform(tree) 253 | } 254 | } 255 | 256 | 257 | /** Transforms the statements and the result expression of a block. A ValDef whose symbol carries MarkerCPSSym has the remainder of the block wrapped into a function that is passed to map/flatMap on its transformed right-hand side; every other statement is transformed on its own and kept in order. Returns the new statement list together with the new result expression. */ 258 | def transBlock(stms: List[Tree], expr: Tree): (List[Tree], Tree) = { 259 | 260 | stms match { 261 | case Nil => 262 | (Nil, transform(expr)) 263 | 264 | case stm::rest => 265 | 266 | stm match { 267 | case vd @ ValDef(mods, name, tpt, rhs) 268 | if (vd.symbol.hasAnnotation(MarkerCPSSym)) => 269 | 270 | debuglog("found marked ValDef "+name+" of type " + vd.symbol.tpe) 271 | 272 | val tpe = vd.symbol.tpe 273 | val rhs1 = atOwner(vd.symbol) { transform(rhs) } 274 | rhs1.changeOwner(vd.symbol -> currentOwner) // TODO: don't traverse twice 275 | 276 | debuglog("valdef symbol " +
vd.symbol + " has type " + tpe) 277 | debuglog("right hand side " + rhs1 + " has type " + rhs1.tpe) 278 | 279 | debuglog("currentOwner: " + currentOwner) 280 | debuglog("currentMethod: " + currentMethod) 281 | 282 | val (bodyStms, bodyExpr) = transBlock(rest, expr) 283 | // FIXME: result will later be traversed again by TreeSymSubstituter and 284 | // ChangeOwnerTraverser => exp. running time. 285 | // Should be changed to fuse traversals into one. 286 | 287 | val specialCaseTrivial = bodyExpr match { 288 | case Apply(fun, args) => 289 | // for now, look for explicit tail calls only. 290 | // are there other cases that could profit from specializing on 291 | // trivial contexts as well? 292 | (bodyExpr.tpe.typeSymbol == Context) && (currentMethod == fun.symbol) 293 | case _ => false 294 | } 295 | 296 | def applyTrivial(ctxValSym: Symbol, body: Tree) = { 297 | 298 | val body1 = (new TreeSymSubstituter(List(vd.symbol), List(ctxValSym)))(body) 299 | 300 | val body2 = localTyper.typedPos(vd.symbol.pos) { body1 } 301 | 302 | // in theory it would be nicer to look for an @cps annotation instead 303 | // of testing for Context 304 | if ((body2.tpe == null) || !(body2.tpe.typeSymbol == Context)) { 305 | //println(body2 + "/" + body2.tpe) 306 | reporter.error(rhs.pos, "cannot compute type for CPS-transformed function result") 307 | } 308 | body2 309 | } 310 | 311 | def applyCombinatorFun(ctxR: Tree, body: Tree) = { 312 | val arg = currentOwner.newValueParameter(name, ctxR.pos).setInfo(tpe) 313 | val body1 = (new TreeSymSubstituter(List(vd.symbol), List(arg)))(body) 314 | val fun = localTyper.typedPos(vd.symbol.pos) { Function(List(ValDef(arg)), body1) } // types body as well 315 | arg.owner = fun.symbol 316 | body1.changeOwner(currentOwner -> fun.symbol) 317 | 318 | // see note about multiple traversals above 319 | 320 | debuglog("fun.symbol: "+fun.symbol) 321 | debuglog("fun.symbol.owner: "+fun.symbol.owner) 322 | debuglog("arg.owner: "+arg.owner) 323 | 324 | 
debuglog("fun.tpe:"+fun.tpe) 325 | debuglog("return type of fun:"+body1.tpe) 326 | 327 | var methodName = nme.map 328 | 329 | if (body1.tpe != null) { 330 | if (body1.tpe.typeSymbol == Context) 331 | methodName = nme.flatMap 332 | } 333 | else 334 | reporter.error(rhs.pos, "cannot compute type for CPS-transformed function result") 335 | 336 | debuglog("will use method:"+methodName) 337 | 338 | localTyper.typedPos(vd.symbol.pos) { 339 | Apply(Select(ctxR, ctxR.tpe.member(methodName)), List(fun)) 340 | } 341 | } 342 | 343 | // TODO use gen.mkBlock after 2.11.0-M6. Why wait? It allows us to still build in development 344 | // mode with `ant -DskipLocker=1` 345 | def mkBlock(stms: List[Tree], expr: Tree) = if (stms.nonEmpty) Block(stms, expr) else expr 346 | 347 | try { 348 | if (specialCaseTrivial) { 349 | debuglog("will optimize possible tail call: " + bodyExpr) 350 | 351 | // FIXME: flatMap impl has become more complicated due to 352 | // exceptions. do we need to put a try/catch in the then part?? 353 | 354 | // val ctx = 355 | // if (ctx.isTrivial) 356 | // val = ctx.getTrivialValue; ... <--- TODO: try/catch ??? don't bother for the moment... 357 | // else 358 | // ctx.flatMap { => ... } 359 | val ctxSym = currentOwner.newValue(newTermName("" + vd.symbol.name + cpsNames.shiftSuffix)).setInfo(rhs1.tpe) 360 | val ctxDef = localTyper.typed(ValDef(ctxSym, rhs1)) 361 | def ctxRef = localTyper.typed(Ident(ctxSym)) 362 | val argSym = currentOwner.newValue(vd.symbol.name.toTermName).setInfo(tpe) 363 | val argDef = localTyper.typed(ValDef(argSym, Select(ctxRef, ctxRef.tpe.member(cpsNames.getTrivialValue)))) 364 | val switchExpr = localTyper.typedPos(vd.symbol.pos) { 365 | val body2 = mkBlock(bodyStms, bodyExpr).duplicate // dup before typing! 
366 | If(Select(ctxRef, ctxSym.tpe.member(cpsNames.isTrivial)), 367 | applyTrivial(argSym, mkBlock(argDef::bodyStms, bodyExpr)), 368 | applyCombinatorFun(ctxRef, body2)) 369 | } 370 | (List(ctxDef), switchExpr) 371 | } else { 372 | // ctx.flatMap { => ... } 373 | // or 374 | // ctx.map { => ... } 375 | (Nil, applyCombinatorFun(rhs1, mkBlock(bodyStms, bodyExpr))) 376 | } 377 | } catch { 378 | case ex:TypeError => 379 | reporter.error(ex.pos, ex.msg) 380 | (bodyStms, bodyExpr) 381 | } 382 | 383 | case _ => 384 | val stm1 = transform(stm) 385 | val (a, b) = transBlock(rest, expr) 386 | (stm1::a, b) 387 | } 388 | } 389 | } 390 | 391 | 392 | } 393 | } 394 | -------------------------------------------------------------------------------- /plugin/src/main/scala/scala/tools/selectivecps/CPSAnnotationChecker.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Scala (https://www.scala-lang.org) 3 | * 4 | * Copyright EPFL and Lightbend, Inc. 5 | * 6 | * Licensed under Apache License 2.0 7 | * (http://www.apache.org/licenses/LICENSE-2.0). 8 | * 9 | * See the NOTICE file distributed with this work for 10 | * additional information regarding copyright ownership. 
11 | */ 12 | 13 | package scala.tools.selectivecps 14 | 15 | import scala.tools.nsc.{ Global, Mode } 16 | import scala.tools.nsc.MissingRequirementError 17 | 18 | abstract class CPSAnnotationChecker extends CPSUtils { 19 | val global: Global 20 | import global._ 21 | import analyzer.{AnalyzerPlugin, Typer} 22 | import definitions._ 23 | 24 | //override val verbose = true 25 | @inline override final def vprintln(x: =>Any): Unit = if (verbose) println(x) 26 | 27 | /** 28 | * Checks whether @cps annotations conform 29 | */ 30 | object checker extends AnnotationChecker { 31 | private[CPSAnnotationChecker] def addPlusMarker(tp: Type) = tp withAnnotation newPlusMarker() 32 | private[CPSAnnotationChecker] def addMinusMarker(tp: Type) = tp withAnnotation newMinusMarker() 33 | 34 | private[CPSAnnotationChecker] def cleanPlus(tp: Type) = 35 | removeAttribs(tp, MarkerCPSAdaptPlus, MarkerCPSTypes) 36 | private[CPSAnnotationChecker] def cleanPlusWith(tp: Type)(newAnnots: AnnotationInfo*) = 37 | cleanPlus(tp) withAnnotations newAnnots.toList 38 | 39 | /** Check annotations to decide whether tpe1 <:< tpe2 */ 40 | def annotationsConform(tpe1: Type, tpe2: Type): Boolean = { 41 | if (!cpsEnabled) return true 42 | 43 | vprintln("check annotations: " + tpe1 + " <:< " + tpe2) 44 | 45 | // Nothing is least element, but Any is not the greatest 46 | if (tpe1.typeSymbol eq NothingClass) 47 | return true 48 | 49 | val annots1 = cpsParamAnnotation(tpe1) 50 | val annots2 = cpsParamAnnotation(tpe2) 51 | 52 | // @plus and @minus should only occur at the left, and never together 53 | // TODO: insert check 54 | 55 | // @minus @cps is the same as no annotations 56 | if (hasMinusMarker(tpe1)) 57 | return annots2.isEmpty 58 | 59 | // to handle answer type modification, we must make @plus <:< @cps 60 | if (hasPlusMarker(tpe1) && annots1.isEmpty) 61 | return true 62 | 63 | // @plus @cps will fall through and compare the @cps type args 64 | // @cps parameters must match exactly 65 | if ((annots1 
corresponds annots2)(_.atp <:< _.atp)) 66 | return true 67 | 68 | // Need to handle uninstantiated type vars specially: 69 | 70 | // g map (x => x) with expected type List[Int] @cps 71 | // results in comparison ?That <:< List[Int] @cps 72 | 73 | // Instantiating ?That to an annotated type would fail during 74 | // transformation. 75 | 76 | // Instead we force-compare tpe1 <:< tpe2.withoutAnnotations 77 | // to trigger instantiation of the TypeVar to the base type 78 | 79 | // This is a bit unorthodox (we're only supposed to look at 80 | // annotations here) but seems to work. 81 | 82 | if (!annots2.isEmpty && !tpe1.isGround) 83 | return tpe1 <:< tpe2.withoutAnnotations 84 | 85 | false 86 | } 87 | 88 | /** Refine the computed least upper bound of a list of types. 89 | * All this should do is add annotations. */ 90 | override def annotationsLub(tpe: Type, ts: List[Type]): Type = { 91 | if (!cpsEnabled) return tpe 92 | 93 | val annots1 = cpsParamAnnotation(tpe) 94 | val annots2 = ts flatMap cpsParamAnnotation 95 | 96 | if (annots2.nonEmpty) { 97 | val cpsLub = newMarker(global.lub(annots1:::annots2 map (_.atp))) 98 | val tpe1 = if (annots1.nonEmpty) removeAttribs(tpe, MarkerCPSTypes) else tpe 99 | tpe1.withAnnotation(cpsLub) 100 | } 101 | else tpe 102 | } 103 | 104 | /** Refine the bounds on type parameters to the given type arguments. 
*/ 105 | override def adaptBoundsToAnnotations(bounds: List[TypeBounds], tparams: List[Symbol], targs: List[Type]): List[TypeBounds] = { 106 | if (!cpsEnabled) return bounds 107 | 108 | val anyAtCPS = newCpsParamsMarker(NothingTpe, AnyTpe) 109 | if (isFunctionType(tparams.head.owner.tpe_*) || isPartialFunctionType(tparams.head.owner.tpe_*)) { 110 | vprintln("function bound: " + tparams.head.owner.tpe + "/"+bounds+"/"+targs) 111 | if (hasCpsParamTypes(targs.last)) 112 | bounds.reverse match { 113 | case res::b if !hasCpsParamTypes(res.hi) => 114 | (TypeBounds(res.lo, res.hi.withAnnotation(anyAtCPS))::b).reverse 115 | case _ => bounds 116 | } 117 | else 118 | bounds 119 | } 120 | else if (tparams.head.owner == ByNameParamClass) { 121 | vprintln("byname bound: " + tparams.head.owner.tpe + "/"+bounds+"/"+targs) 122 | val TypeBounds(lo, hi) = bounds.head 123 | if (hasCpsParamTypes(targs.head) && !hasCpsParamTypes(hi)) 124 | TypeBounds(lo, hi withAnnotation anyAtCPS) :: Nil 125 | else bounds 126 | } else 127 | bounds 128 | } 129 | } 130 | 131 | object plugin extends AnalyzerPlugin { 132 | 133 | import checker._ 134 | /** Reports whether the cps annotations on `tree.tpe` can be adapted towards the expected type `pt` in the given typer mode. Always false when the plugin is disabled (`cpsEnabled` is false). */ 135 | override def canAdaptAnnotations(tree: Tree, typer: Typer, mode: Mode, pt: Type): Boolean = { 136 | if (!cpsEnabled) return false 137 | vprintln("can adapt annotations? " + tree + " / " + tree.tpe + " / " + mode + " / " + pt) 138 | 139 | val annots1 = cpsParamAnnotation(tree.tpe) 140 | val annots2 = cpsParamAnnotation(pt) 141 | 142 | if (mode.inPatternMode) { 143 | //println("can adapt pattern annotations?
" + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) 144 | if (!annots1.isEmpty) { 145 | return true 146 | } 147 | } 148 | 149 | /* 150 | // not precise enough -- still relying on addAnnotations to remove things from ValDef symbols 151 | if (mode.inAllModes(TYPEmode | BYVALmode)) { 152 | if (!annots1.isEmpty) { 153 | return true 154 | } 155 | } 156 | */ 157 | 158 | /* 159 | this interferes with overloading resolution 160 | if (mode.inByValMode && tree.tpe <:< pt) { 161 | vprintln("already compatible, can't adapt further") 162 | return false 163 | } 164 | */ 165 | if (mode.inExprMode) { 166 | if ((annots1 corresponds annots2)(_.atp <:< _.atp)) { 167 | vprintln("already same, can't adapt further") 168 | false 169 | } else if (annots1.isEmpty && !annots2.isEmpty && !mode.inByValMode) { 170 | //println("can adapt annotations? " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) 171 | if (!hasPlusMarker(tree.tpe)) { 172 | // val base = tree.tpe <:< removeAllCPSAnnotations(pt) 173 | // val known = global.analyzer.isFullyDefined(pt) 174 | // println(same + "/" + base + "/" + known) 175 | //val same = annots2 forall { case AnnotationInfo(atp: TypeRef, _, _) => atp.typeArgs(0) =:= atp.typeArgs(1) } 176 | // TBD: use same or not? 177 | //if (same) { 178 | vprintln("yes we can!! (unit)") 179 | true 180 | //} 181 | } else false 182 | } else if (!hasPlusMarker(tree.tpe) && annots1.isEmpty && !annots2.isEmpty && typer.context.inReturnExpr) { 183 | vprintln("checking enclosing method's result type without annotations") 184 | tree.tpe <:< pt.withoutAnnotations 185 | } else if (!hasMinusMarker(tree.tpe) && !annots1.isEmpty && mode.inByValMode) { 186 | val optCpsTypes: Option[(Type, Type)] = cpsParamTypes(tree.tpe) 187 | val optExpectedCpsTypes: Option[(Type, Type)] = cpsParamTypes(pt) 188 | if (optCpsTypes.isEmpty || optExpectedCpsTypes.isEmpty) { 189 | vprintln("yes we can!! 
(byval)") 190 | true 191 | } else { // check cps param types 192 | val cpsTpes = optCpsTypes.get 193 | val cpsPts = optExpectedCpsTypes.get 194 | // class cpsParam[-B,+C], therefore: 195 | cpsPts._1 <:< cpsTpes._1 && cpsTpes._2 <:< cpsPts._2 196 | } 197 | } else false 198 | } else false 199 | } 200 | /** Returns `tree` with the cps annotations of its type adjusted for the current mode: pattern mode strips them; an expression not typed by value that lacks the expected annotations gets a plus marker plus the expected annotations; an annotated by-value expression gets a minus marker; a return expression lacking the expected annotations gets a bare plus marker; in every other case (and when the plugin is disabled) `tree` is returned unchanged. */ 201 | override def adaptAnnotations(tree: Tree, typer: Typer, mode: Mode, pt: Type): Tree = { 202 | if (!cpsEnabled) return tree 203 | 204 | vprintln("adapt annotations " + tree + " / " + tree.tpe + " / " + mode + " / " + pt) 205 | 206 | val annotsTree = cpsParamAnnotation(tree.tpe) 207 | val annotsExpected = cpsParamAnnotation(pt) 208 | def isMissingExpectedAnnots = annotsTree.isEmpty && annotsExpected.nonEmpty 209 | 210 | // not sure I rephrased this comment correctly: 211 | // replacing `mode.inPatternMode` in the condition below by `mode.inPatternMode || mode.inAllModes(TYPEmode | BYVALmode)` 212 | // doesn't work correctly -- still relying on addAnnotations to remove things from ValDef symbols 213 | if (mode.inPatternMode && annotsTree.nonEmpty) tree modifyType removeAllCPSAnnotations 214 | else if (mode.typingExprNotValue && !hasPlusMarker(tree.tpe) && isMissingExpectedAnnots) { // shiftUnit 215 | // add a marker annotation that will make tree.tpe behave as pt, subtyping wise 216 | // tree will look like having any possible annotation 217 | //println("adapt annotations " + tree + " / " + tree.tpe + " / " + Integer.toHexString(mode) + " / " + pt) 218 | 219 | // CAVEAT: 220 | // for monomorphic answer types we want to have @plus @cps (for better checking) 221 | // for answer type modification we want to have only @plus (because actual answer type may differ from pt) 222 | 223 | val res = tree modifyType (_ withAnnotations newPlusMarker() :: annotsExpected) // needed for #1807 224 | vprintln("adapted annotations (not by val) of " + tree + " to " + res.tpe) 225 | res 226 | } else if (mode.typingExprByValue && !hasMinusMarker(tree.tpe) && annotsTree.nonEmpty) { //
dropping annotation 227 | // add a marker annotation that will make tree.tpe behave as pt, subtyping wise 228 | // tree will look like having no annotation 229 | val res = tree modifyType addMinusMarker 230 | vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe) 231 | res 232 | } else if (typer.context.inReturnExpr && !hasPlusMarker(tree.tpe) && isMissingExpectedAnnots) { 233 | // add a marker annotation that will make tree.tpe behave as pt, subtyping wise 234 | // tree will look like having any possible annotation 235 | 236 | // note 1: we are only adding a plus marker if the method's result type is a cps type 237 | // (annotsExpected.nonEmpty == cpsParamAnnotation(pt).nonEmpty) 238 | // note 2: we are not adding the expected cps annotations, since they will be added 239 | // by adaptTypeOfReturn (see below). 240 | val res = tree modifyType (_ withAnnotation newPlusMarker()) 241 | vprintln("adapted annotations (return) of " + tree + " to " + res.tpe) 242 | res 243 | } else tree 244 | } 245 | 246 | /** Returns an adapted type for a return expression if the method's result type (pt) is a CPS type. 247 | * Otherwise, it returns the `default` type (`typedReturn` passes `NothingTpe`). 248 | * 249 | * A return expression in a method that has a CPS result type is an error unless the return 250 | * is in tail position. Therefore, we are making sure that only the types of return expressions 251 | * are adapted which will either be removed, or lead to an error. 
252 | */ 253 | override def pluginsTypedReturn(default: Type, typer: Typer, tree: Return, pt: Type): Type = { 254 | val expr = tree.expr 255 | // only adapt if method's result type (pt) is cps type 256 | val annots = cpsParamAnnotation(pt) 257 | if (annots.nonEmpty) { 258 | // return type of `expr` without plus marker, but only if it doesn't have other cps annots 259 | if (hasPlusMarker(expr.tpe) && !hasCpsParamTypes(expr.tpe)) 260 | expr.setType(removeAttribs(expr.tpe, MarkerCPSAdaptPlus)) 261 | expr.tpe 262 | } else default 263 | } 264 | /** Combines the cps annotations already on `tpe` with those collected from its children (`childAnnots`) and returns the resulting type. Overloaded, method and poly types are returned unchanged. `byName` lists the by-name children whose plus markers may be moved outward onto the result. Throws a TypeError when the children's linearized annotation is not a subtype of the existing one. */ 265 | def updateAttributesFromChildren(tpe: Type, childAnnots: List[AnnotationInfo], byName: List[Tree]): Type = { 266 | tpe match { 267 | // Would need to push annots into each alternative of overloaded type 268 | // But we can't, since alternatives aren't types but symbols, which we 269 | // can't change (we'd be affecting symbols globally) 270 | /* 271 | case OverloadedType(pre, alts) => 272 | OverloadedType(pre, alts.map((sym: Symbol) => updateAttributes(pre.memberType(sym), annots))) 273 | */ 274 | case OverloadedType(pre, alts) => tpe //reconstruct correct annotations later 275 | case MethodType(params, restpe) => tpe 276 | case PolyType(params, restpe) => tpe 277 | case _ => 278 | assert(childAnnots forall (_ matches MarkerCPSTypes), childAnnots) 279 | /* 280 | [] + [] = [] 281 | plus + [] = plus 282 | cps + [] = cps 283 | plus cps + [] = plus cps 284 | minus cps + [] = minus cps 285 | synth cps + [] = synth cps // <- synth on left - does it happen? 286 | 287 | [] + cps = cps 288 | plus + cps = synth cps 289 | cps + cps = cps! <- lin 290 | plus cps + cps = synth cps! <- unify 291 | minus cps + cps = minus cps! <- lin 292 | synth cps + cps = synth cps!
<- unify 293 | */ 294 | 295 | val plus = hasPlusMarker(tpe) || ( 296 | hasCpsParamTypes(tpe) 297 | && byName.nonEmpty 298 | && (byName forall (t => hasPlusMarker(t.tpe))) 299 | ) 300 | 301 | // move @plus annotations outward from by-name children 302 | if (childAnnots.isEmpty) return { 303 | if (plus) { // @plus or @plus @cps 304 | byName foreach (_ modifyType cleanPlus) 305 | addPlusMarker(tpe) 306 | } 307 | else tpe 308 | } 309 | 310 | val annots1 = cpsParamAnnotation(tpe) 311 | 312 | if (annots1.isEmpty) { // nothing or @plus 313 | cleanPlusWith(tpe)(newSynthMarker(), linearize(childAnnots)) 314 | } 315 | else { 316 | val annot1 = single(annots1) 317 | if (plus) { // @plus @cps 318 | val annot2 = linearize(childAnnots) 319 | 320 | if (annot2.atp <:< annot1.atp) { 321 | try cleanPlusWith(tpe)(newSynthMarker(), annot2) 322 | finally byName foreach (_ modifyType cleanPlus) 323 | } 324 | else throw new TypeError(annot2 + " is not a subtype of " + annot1) 325 | } 326 | else if (hasSynthMarker(tpe)) { // @synth @cps 327 | val annot2 = linearize(childAnnots) 328 | if (annot2.atp <:< annot1.atp) 329 | cleanPlusWith(tpe)(annot2) 330 | else 331 | throw new TypeError(annot2 + " is not a subtype of " + annot1) 332 | } 333 | else // @cps 334 | cleanPlusWith(tpe)(linearize(childAnnots:::annots1)) 335 | } 336 | } 337 | } 338 | 339 | def transArgList(fun: Tree, args: List[Tree]): List[List[Tree]] = { 340 | val formals = fun.tpe.paramTypes 341 | val overshoot = args.length - formals.length 342 | 343 | for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield { 344 | tp match { 345 | case TypeRef(_, ByNameParamClass, List(elemtp)) => 346 | Nil // TODO: check conformance?? 
347 | case _ => 348 | List(a) 349 | } 350 | } 351 | } 352 | 353 | /** Replaces every ValDef or Assign statement by its right-hand side and keeps all other statements as they are, preserving order. */ 354 | def transStms(stms: List[Tree]): List[Tree] = stms match { 355 | case ValDef(mods, name, tpt, rhs)::xs => 356 | rhs::transStms(xs) 357 | case Assign(lhs, rhs)::xs => 358 | rhs::transStms(xs) 359 | case x::xs => 360 | x::transStms(xs) 361 | case Nil => 362 | Nil 363 | } 364 | /** Returns the only element of `xs`; for any other length a global error is reported and `xs(0)` is evaluated afterwards. NOTE(review): `xs(0)` looks unsafe for an empty list unless globalError does not return -- confirm. */ 365 | def single(xs: List[AnnotationInfo]) = xs match { 366 | case List(x) => x 367 | case _ => 368 | global.globalError("not a single cps annotation: " + xs) 369 | xs(0) 370 | } 371 | /** Nil for an empty list, otherwise a one-element list produced through `single`. */ 372 | def emptyOrSingleList(xs: List[AnnotationInfo]) = if (xs.isEmpty) Nil else List(single(xs)) 373 | /** Gathers the cps annotations of `childTrees` (descending into the sub-trees of method-, poly- and overloaded-typed children) and merges them into `tpe` through updateAttributesFromChildren; `byName` is passed along unchanged. */ 374 | def transChildrenInOrder(tree: Tree, tpe: Type, childTrees: List[Tree], byName: List[Tree]) = { 375 | def inspect(t: Tree): List[AnnotationInfo] = { 376 | if (t.tpe eq null) Nil else { 377 | val extra: List[AnnotationInfo] = t.tpe match { 378 | case _: MethodType | _: PolyType | _: OverloadedType => 379 | // method types, poly types and overloaded types do not obtain cps annotations by propagation 380 | // need to reconstruct transitively from their children. 381 | t match { 382 | case Select(qual, name) => inspect(qual) 383 | case Apply(fun, args) => (fun::(transArgList(fun,args).flatten)) flatMap inspect 384 | case TypeApply(fun, args) => (fun::(transArgList(fun,args).flatten)) flatMap inspect 385 | case _ => Nil 386 | } 387 | case _ => Nil 388 | } 389 | 390 | val types = cpsParamAnnotation(t.tpe) 391 | // TODO: check that it has been adapted and if so correctly 392 | extra ++ emptyOrSingleList(types) 393 | } 394 | } 395 | val children = childTrees flatMap inspect 396 | 397 | val newtpe = updateAttributesFromChildren(tpe, children, byName) 398 | 399 | if (!newtpe.annotations.isEmpty) 400 | vprintln("[checker] inferred " + tree + " / " + tpe + " ===> "+ newtpe) 401 | 402 | newtpe 403 | } 404 | 405 | /** Modify the type that has thus far been inferred 406 | * for a tree. All this should do is add annotations.
*/ 407 | 408 | override def pluginsTyped(tpe: Type, typer: Typer, tree: Tree, mode: Mode, pt: Type): Type = { 409 | if (!cpsEnabled) { 410 | val report = try hasCpsParamTypes(tpe) catch { case _: MissingRequirementError => false } 411 | if (report) 412 | reporter.error(tree.pos, "this code must be compiled with the Scala continuations plugin enabled") 413 | // with the plugin disabled the inferred type is handed back untouched 414 | return tpe 415 | } 416 | 417 | // if (tree.tpe.hasAnnotation(MarkerCPSAdaptPlus)) 418 | // println("addAnnotation " + tree + "/" + tpe) 419 | 420 | tree match { 421 | 422 | case Apply(fun @ Select(qual, name), args) if fun.isTyped => 423 | 424 | // HACK: With overloaded methods, fun will never get annotated. This is because 425 | // the 'overloaded' type gets annotated, but not the alternatives (among which 426 | // fun's type is chosen) 427 | 428 | vprintln("[checker] checking select apply " + tree + "/" + tpe) 429 | 430 | transChildrenInOrder(tree, tpe, qual::(transArgList(fun, args).flatten), Nil) 431 | 432 | case Apply(TypeApply(fun @ Select(qual, name), targs), args) if fun.isTyped => // NOTE(review): the original remark here was cut off at "not trigge" (presumably "not triggered") -- confirm whether this case is ever hit 433 | 434 | vprintln("[checker] checking select apply type-apply " + tree + "/" + tpe) 435 | 436 | transChildrenInOrder(tree, tpe, qual::(transArgList(fun, args).flatten), Nil) 437 | 438 | case TypeApply(fun @ Select(qual, name), args) if fun.isTyped => 439 | def stripNullaryMethodType(tp: Type) = tp match { case NullaryMethodType(restpe) => restpe case tp => tp } 440 | vprintln("[checker] checking select type-apply " + tree + "/" + tpe) 441 | 442 | transChildrenInOrder(tree, stripNullaryMethodType(tpe), List(qual, fun), Nil) 443 | 444 | case Apply(fun, args) if fun.isTyped => 445 | 446 | vprintln("[checker] checking unknown apply " + tree + "/" + tpe) 447 | 448 | transChildrenInOrder(tree, tpe, fun::(transArgList(fun, args).flatten), Nil) 449 | 450 | case TypeApply(fun, args) => 451 | 452 | vprintln("[checker] checking unknown type apply " + tree + "/" + tpe) 453 | 454 | transChildrenInOrder(tree, tpe,
List(fun), Nil) 455 | 456 | case Select(qual, name) if qual.isTyped => 457 | 458 | vprintln("[checker] checking select " + tree + "/" + tpe) 459 | 460 | // straightforward way is problematic (see select.scala and Test2.scala) 461 | // transChildrenInOrder(tree, tpe, List(qual), Nil) 462 | 463 | // the problem is that qual may be of type OverloadedType (or MethodType) and 464 | // we cannot safely annotate these. so we just ignore these cases and 465 | // clean up later in the Apply/TypeApply trees. 466 | 467 | if (hasCpsParamTypes(qual.tpe)) { 468 | // however there is one special case: 469 | // if it's a method without parameters, just apply it. normally done in adapt, but 470 | // we have to do it here so we don't lose the cps information (wouldn't trigger our 471 | // adapt and there is no Apply/TypeApply created) 472 | tpe match { 473 | case NullaryMethodType(restpe) => 474 | //println("yep: " + restpe + "," + restpe.getClass) 475 | transChildrenInOrder(tree, restpe, List(qual), Nil) 476 | case _ : PolyType => tpe 477 | case _ : MethodType => tpe 478 | case _ : OverloadedType => tpe 479 | case _ => 480 | transChildrenInOrder(tree, tpe, List(qual), Nil) 481 | } 482 | } else 483 | tpe 484 | 485 | case If(cond, thenp, elsep) => 486 | transChildrenInOrder(tree, tpe, List(cond), List(thenp, elsep)) 487 | 488 | case Match(select, cases) => 489 | transChildrenInOrder(tree, tpe, List(select), cases:::(cases map { case CaseDef(_, _, body) => body })) 490 | 491 | case Try(block, catches, finalizer) => 492 | val tpe1 = transChildrenInOrder(tree, tpe, Nil, block::catches:::(catches map { case CaseDef(_, _, body) => body })) 493 | 494 | val annots = cpsParamAnnotation(tpe1) 495 | if (annots.nonEmpty) { 496 | val ann = single(annots) 497 | val (atp0, atp1) = annTypes(ann) 498 | if (!(atp0 =:= atp1)) 499 | throw new TypeError("only simple cps types allowed in try/catch blocks (found: " + tpe1 + ")") 500 | if (!finalizer.isEmpty) // no finalizers allowed. 
see explanation in SelectiveCPSTransform 501 | typer.context.error(tree.pos, "try/catch blocks that use continuations cannot have finalizers") 502 | } 503 | tpe1 504 | 505 | case Block(stms, expr) => 506 | // if any stm has annotation, so does block 507 | transChildrenInOrder(tree, tpe, transStms(stms), List(expr)) 508 | 509 | case ValDef(mods, name, tpt, rhs) => 510 | vprintln("[checker] checking valdef " + name + "/"+tpe+"/"+tpt+"/"+tree.symbol.tpe) 511 | // ValDef symbols must *not* have annotations! 512 | // lazy vals are currently not supported 513 | // but if we erase here all annotations, compiler will complain only 514 | // when generating bytecode. 515 | // This way lazy vals will be reported as unsupported feature later rather than weird type error. 516 | if (hasAnswerTypeAnn(tree.symbol.info) && !mods.isLazy) { // is it okay to modify sym here? 517 | vprintln("removing annotation from sym " + tree.symbol + "/" + tree.symbol.tpe + "/" + tpt) 518 | tpt modifyType removeAllCPSAnnotations 519 | tree.symbol modifyInfo removeAllCPSAnnotations 520 | } 521 | tpe 522 | 523 | case _ => 524 | tpe 525 | } 526 | 527 | 528 | } 529 | } 530 | } 531 | -------------------------------------------------------------------------------- /plugin/src/main/scala-2.11/scala/tools/selectivecps/SelectiveANFTransform.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Scala (https://www.scala-lang.org) 3 | * 4 | * Copyright EPFL and Lightbend, Inc. 5 | * 6 | * Licensed under Apache License 2.0 7 | * (http://www.apache.org/licenses/LICENSE-2.0). 8 | * 9 | * See the NOTICE file distributed with this work for 10 | * additional information regarding copyright ownership. 
11 | */ 12 | 13 | package scala.tools.selectivecps 14 | 15 | import scala.tools.nsc.plugins._ 16 | import scala.tools.nsc.symtab._ 17 | import scala.tools.nsc.transform._ 18 | 19 | /** 20 | * In methods marked @cps, explicitly name results of calls to other @cps methods 21 | */ 22 | abstract class SelectiveANFTransform extends PluginComponent with Transform with 23 | TypingTransformers with CPSUtils { 24 | // inherits abstract value `global` and class `Phase` from Transform 25 | 26 | import global._ 27 | import definitions._ // methods to type trees 28 | 29 | override def description = "ANF pre-transform for @cps" 30 | 31 | /** the following two members override abstract members in Transform */ 32 | val phaseName: String = "selectiveanf" 33 | 34 | protected def newTransformer(unit: CompilationUnit): Transformer = 35 | new ANFTransformer(unit) 36 | 37 | class ANFTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { 38 | 39 | var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet) 40 | 41 | object RemoveTailReturnsTransformer extends Transformer { 42 | override def transform(tree: Tree): Tree = tree match { 43 | case Block(stms, r @ Return(expr)) => 44 | treeCopy.Block(tree, stms, expr) 45 | 46 | case Block(stms, expr) => 47 | treeCopy.Block(tree, stms, transform(expr)) 48 | 49 | case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) => 50 | treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) 51 | 52 | case If(cond, r1 @ Return(thenExpr), elseExpr) => 53 | treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) 54 | 55 | case If(cond, thenExpr, r2 @ Return(elseExpr)) => 56 | treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) 57 | 58 | case If(cond, thenExpr, elseExpr) => 59 | treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) 60 | 61 | case Try(block, catches, finalizer) => 62 | treeCopy.Try(tree, 63 | transform(block), 64 | (catches map (t => 
transform(t))).asInstanceOf[List[CaseDef]], 65 | transform(finalizer)) 66 | 67 | case CaseDef(pat, guard, r @ Return(expr)) => 68 | treeCopy.CaseDef(tree, pat, guard, expr) 69 | 70 | case CaseDef(pat, guard, body) => 71 | treeCopy.CaseDef(tree, pat, guard, transform(body)) 72 | // a Return that reaches this case was not unwrapped by one of the more specific cases of this match, so it is reported 73 | case Return(_) => 74 | reporter.error(tree.pos, "return expressions in CPS code must be in tail position") 75 | tree 76 | 77 | case _ => 78 | super.transform(tree) 79 | } 80 | } 81 | /** Removes `return` from `body`: a body that is itself a single Return is replaced by the returned expression, anything else is rewritten by RemoveTailReturnsTransformer. */ 82 | def removeTailReturns(body: Tree): Tree = { 83 | // support body with single return expression 84 | body match { 85 | case Return(expr) => expr 86 | case _ => RemoveTailReturnsTransformer.transform(body) 87 | } 88 | } 89 | 90 | override def transform(tree: Tree): Tree = { 91 | if (!cpsEnabled) return tree 92 | 93 | tree match { 94 | 95 | // Maybe we should further generalize the transform and move it over 96 | // to the regular Transformer facility. But then, actual and required cps 97 | // state would need more complicated (stateful!) tracking. 98 | 99 | // Making the default case use transExpr(tree, None, None) instead of 100 | // calling super.transform() would be a start, but at the moment, 101 | // this would cause infinite recursion. But we could remove the 102 | // ValDef case here.
103 | 104 | case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => 105 | debuglog("transforming " + dd.symbol) 106 | 107 | atOwner(dd.symbol) { 108 | val rhs = 109 | if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0) 110 | else rhs0 111 | val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))(getExternalAnswerTypeAnn(tpt.tpe).isDefined) 112 | 113 | debuglog("result "+rhs1) 114 | debuglog("result is of type "+rhs1.tpe) 115 | 116 | treeCopy.DefDef(dd, mods, name, transformTypeDefs(tparams), transformValDefss(vparamss), 117 | transform(tpt), rhs1) 118 | } 119 | 120 | case ff @ Function(vparams, body) => 121 | debuglog("transforming anon function " + ff.symbol) 122 | 123 | atOwner(ff.symbol) { 124 | 125 | //val body1 = transExpr(body, None, getExternalAnswerTypeAnn(body.tpe)) 126 | 127 | // need to special case partial functions: if expected type is @cps 128 | // but all cases are pure, then we would transform 129 | // { x => x match { case A => ... }} to 130 | // { x => shiftUnit(x match { case A => ... })} 131 | // which Uncurry cannot handle (see function6.scala) 132 | // thus, we push down the shiftUnit to each of the case bodies 133 | 134 | val ext = getExternalAnswerTypeAnn(body.tpe) 135 | val pureBody = getAnswerTypeAnn(body.tpe).isEmpty 136 | implicit val isParentImpure = ext.isDefined 137 | 138 | def transformPureMatch(tree: Tree, selector: Tree, cases: List[CaseDef]) = { 139 | val caseVals = cases map { case cd @ CaseDef(pat, guard, body) => 140 | // if (!hasPlusMarker(body.tpe)) body modifyType (_ withAnnotation newPlusMarker()) // TODO: to avoid warning 141 | val bodyVal = transExpr(body, None, ext) // ??? 
triggers "cps-transformed unexpectedly" warning in transTailValue 142 | treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) 143 | } 144 | treeCopy.Match(tree, transform(selector), caseVals) 145 | } 146 | 147 | def transformPureVirtMatch(body: Block, selDef: ValDef, cases: List[Tree], matchEnd: Tree) = { 148 | val stats = transform(selDef) :: (cases map (transExpr(_, None, ext))) 149 | treeCopy.Block(body, stats, transExpr(matchEnd, None, ext)) 150 | } 151 | 152 | val body1 = body match { 153 | case Match(selector, cases) if ext.isDefined && pureBody => 154 | transformPureMatch(body, selector, cases) 155 | 156 | // virtpatmat switch 157 | case Block(List(selDef: ValDef), mat@Match(selector, cases)) if ext.isDefined && pureBody => 158 | treeCopy.Block(body, List(transform(selDef)), transformPureMatch(mat, selector, cases)) 159 | 160 | // virtpatmat 161 | case b@Block(matchStats@((selDef: ValDef) :: cases), matchEnd) if ext.isDefined && pureBody && (matchStats forall treeInfo.hasSynthCaseSymbol) => 162 | transformPureVirtMatch(b, selDef, cases, matchEnd) 163 | 164 | // virtpatmat that stores the scrut separately -- TODO: can we eliminate this case?? 
165 | case Block(List(selDef0: ValDef), mat@Block(matchStats@((selDef: ValDef) :: cases), matchEnd)) if ext.isDefined && pureBody && (matchStats forall treeInfo.hasSynthCaseSymbol)=> 166 | treeCopy.Block(body, List(transform(selDef0)), transformPureVirtMatch(mat, selDef, cases, matchEnd)) 167 | 168 | case _ => 169 | transExpr(body, None, ext) 170 | } 171 | 172 | debuglog("anf result "+body1+"\nresult is of type "+body1.tpe) 173 | 174 | treeCopy.Function(ff, transformValDefs(vparams), body1) 175 | } 176 | 177 | case vd @ ValDef(mods, name, tpt, rhs) => // object-level valdefs 178 | debuglog("transforming valdef " + vd.symbol) 179 | 180 | if (getExternalAnswerTypeAnn(tpt.tpe).isEmpty) { 181 | 182 | atOwner(vd.symbol) { 183 | 184 | val rhs1 = transExpr(rhs, None, None) 185 | 186 | treeCopy.ValDef(vd, mods, name, transform(tpt), rhs1) 187 | } 188 | } else { 189 | reporter.error(tree.pos, "cps annotations not allowed on by-value parameters or value definitions") 190 | super.transform(tree) 191 | } 192 | 193 | case TypeTree() => 194 | // circumvent cpsAllowed here 195 | super.transform(tree) 196 | 197 | case Apply(_,_) => 198 | // this allows reset { ... 
} in object constructors 199 | // it's kind of a hack to put it here (see note above) 200 | transExpr(tree, None, None) 201 | 202 | case _ => 203 | if (hasAnswerTypeAnn(tree.tpe)) { 204 | if (!cpsAllowed) { 205 | if ((tree.symbol ne null) && tree.symbol.isLazy) // symbol is null for trees without a symbol field; report the generic error instead of crashing 206 | reporter.error(tree.pos, "implementation restriction: cps annotations not allowed on lazy value definitions") 207 | else 208 | reporter.error(tree.pos, "cps code not allowed here / " + tree.getClass + " / " + tree) 209 | } 210 | log(tree) 211 | } 212 | 213 | cpsAllowed = false 214 | super.transform(tree) 215 | } 216 | } 217 | 218 | 219 | def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean = false): Tree = { 220 | transTailValue(tree, cpsA, cpsR)(cpsR.isDefined || isAnyParentImpure) match { 221 | case (Nil, b) => b 222 | case (a, b) => 223 | treeCopy.Block(tree, a,b) 224 | } 225 | } 226 | 227 | 228 | def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[List[Tree]], List[Tree], CPSInfo) = { 229 | val formals = fun.tpe.paramTypes 230 | val overshoot = args.length - formals.length 231 | 232 | var spc: CPSInfo = cpsA 233 | 234 | val (stm,expr) = (for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield { 235 | tp match { 236 | case TypeRef(_, ByNameParamClass, List(elemtp)) => 237 | // note that we're not passing just isAnyParentImpure 238 | (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))(getAnswerTypeAnn(elemtp).isDefined || isAnyParentImpure)) 239 | case _ => 240 | val (valStm, valExpr, valSpc) = transInlineValue(a, spc) 241 | spc = valSpc 242 | (valStm, valExpr) 243 | } 244 | }).unzip 245 | 246 | (stm,expr,spc) 247 | } 248 | 249 | 250 | // precondition: cpsR.isDefined "implies" isAnyParentImpure 251 | def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = { 252 | // return value: (stms, expr, spc), where spc is CPSInfo after stms 253
| implicit val pos = tree.pos 254 | tree match { 255 | case Block(stms, expr) => 256 | val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd 257 | // val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) 258 | 259 | val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) 260 | val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!! 261 | 262 | (Nil, tree1, cpsA) 263 | 264 | case If(cond, thenp, elsep) => 265 | /* possible situations: 266 | cps before (cpsA) 267 | cps in condition (spc) <-- synth flag set if *only* here! 268 | cps in (one or both) branches */ 269 | val (condStats, condVal, spc) = transInlineValue(cond, cpsA) 270 | val (cpsA2, cpsR2) = if (hasSynthMarker(tree.tpe)) 271 | (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else 272 | (None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly 273 | val thenVal = transExpr(thenp, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) 274 | val elseVal = transExpr(elsep, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) 275 | 276 | // check that then and else parts agree (not necessary any more, but left as sanity check) 277 | if (cpsR.isDefined) { 278 | if (elsep == EmptyTree) 279 | reporter.error(tree.pos, "always need else part in cps code") 280 | } 281 | if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) { 282 | reporter.error(tree.pos, "then and else parts must both be cps code or neither of them") 283 | } 284 | 285 | (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc) 286 | 287 | case Match(selector, cases) => 288 | val (selStats, selVal, spc) = transInlineValue(selector, cpsA) 289 | val (cpsA2, cpsR2) = 290 | if (hasSynthMarker(tree.tpe)) (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) 291 | else (None, getAnswerTypeAnn(tree.tpe)) 292 | 293 | val caseVals = cases map { case cd @ CaseDef(pat, guard, body) => 294 | val bodyVal = 
transExpr(body, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) 295 | treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) 296 | } 297 | 298 | (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc) 299 | 300 | // this is utterly broken: LabelDefs need to be considered together when transforming them to DefDefs: 301 | // suppose a Block {L1; ... ; LN} 302 | // this should become {D1def ; ... ; DNdef ; D1()} 303 | // where D$idef = def L$i(..) = {L$i.body; L${i+1}(..)} 304 | 305 | case ldef @ LabelDef(name, params, rhs) => 306 | // println("trans LABELDEF "+(name, params, tree.tpe, hasAnswerTypeAnn(tree.tpe))) 307 | // TODO why does the labeldef's type have a cpsMinus annotation, whereas the rhs does not? (BYVALmode missing/too much somewhere?) 308 | if (hasAnswerTypeAnn(tree.tpe)) { 309 | // currentOwner.newMethod(name, tree.pos, Flags.SYNTHETIC) setInfo ldef.symbol.info 310 | val sym = ldef.symbol resetFlag Flags.LABEL 311 | val rhs1 = rhs //new TreeSymSubstituter(List(ldef.symbol), List(sym)).transform(rhs) 312 | val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe))(getAnswerTypeAnn(tree.tpe).isDefined || isAnyParentImpure) changeOwner (currentOwner -> sym) 313 | 314 | val stm1 = localTyper.typed(DefDef(sym, rhsVal)) 315 | // since virtpatmat does not rely on fall-through, don't call the labels it emits 316 | // transBlock will take care of calling the first label 317 | // calling each labeldef is wrong, since some labels may be jumped over 318 | // we can get away with this for now since the only other labels we emit are for tailcalls/while loops, 319 | // which do not have consecutive labeldefs (and thus fall-through is irrelevant) 320 | if (treeInfo.hasSynthCaseSymbol(ldef)) (List(stm1), localTyper.typed{Literal(Constant(()))}, cpsA) 321 | else { 322 | assert(params.isEmpty, "problem in ANF transforming label with non-empty params "+ ldef) 323 | (List(stm1), localTyper.typed{Apply(Ident(sym), List())}, cpsA) 324 | } 325 
| } else { 326 | val rhsVal = transExpr(rhs, None, None) 327 | (Nil, updateSynthFlag(treeCopy.LabelDef(tree, name, params, rhsVal)), cpsA) 328 | } 329 | 330 | 331 | case Try(block, catches, finalizer) => 332 | val blockVal = transExpr(block, cpsA, cpsR) 333 | 334 | val catchVals = for { 335 | cd @ CaseDef(pat, guard, body) <- catches 336 | bodyVal = transExpr(body, cpsA, cpsR) 337 | } yield { 338 | treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) 339 | } 340 | 341 | val finallyVal = transExpr(finalizer, None, None) // for now, no cps in finally 342 | 343 | (Nil, updateSynthFlag(treeCopy.Try(tree, blockVal, catchVals, finallyVal)), cpsA) 344 | 345 | case Assign(lhs, rhs) => 346 | // allow cps code in rhs only 347 | val (stms, expr, spc) = transInlineValue(rhs, cpsA) 348 | (stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc) 349 | 350 | case Return(expr0) => 351 | if (isAnyParentImpure) 352 | reporter.error(tree.pos, "return expression not allowed, since method calls CPS method") 353 | val (stms, expr, spc) = transInlineValue(expr0, cpsA) 354 | (stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc) 355 | 356 | case Throw(expr0) => 357 | val (stms, expr, spc) = transInlineValue(expr0, cpsA) 358 | (stms, updateSynthFlag(treeCopy.Throw(tree, expr)), spc) 359 | 360 | case Typed(expr0, tpt) => 361 | // TODO: should x: A @cps[B,C] have a special meaning? 
362 | // type casts used in different ways (see match2.scala, #3199) 363 | val (stms, expr, spc) = transInlineValue(expr0, cpsA) 364 | val tpt1 = if (treeInfo.isWildcardStarArg(tree)) tpt else 365 | treeCopy.TypeTree(tpt).setType(removeAllCPSAnnotations(tpt.tpe)) 366 | // (stms, updateSynthFlag(treeCopy.Typed(tree, expr, tpt1)), spc) 367 | (stms, treeCopy.Typed(tree, expr, tpt1).setType(removeAllCPSAnnotations(tree.tpe)), spc) 368 | 369 | case TypeApply(fun, args) => 370 | val (stms, expr, spc) = transInlineValue(fun, cpsA) 371 | (stms, updateSynthFlag(treeCopy.TypeApply(tree, expr, args)), spc) 372 | 373 | case Select(qual, name) => 374 | val (stms, expr, spc) = transInlineValue(qual, cpsA) 375 | (stms, updateSynthFlag(treeCopy.Select(tree, expr, name)), spc) 376 | 377 | case Apply(fun, args) => 378 | val (funStm, funExpr, funSpc) = transInlineValue(fun, cpsA) 379 | val (argStm, argExpr, argSpc) = transArgList(fun, args, funSpc) 380 | 381 | (funStm ::: (argStm.flatten), updateSynthFlag(treeCopy.Apply(tree, funExpr, argExpr)), 382 | argSpc) 383 | 384 | case _ => 385 | cpsAllowed = true 386 | (Nil, transform(tree), cpsA) 387 | } 388 | } 389 | 390 | // precondition: cpsR.isDefined "implies" isAnyParentImpure 391 | def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = { 392 | 393 | val (stms, expr, spc) = transValue(tree, cpsA, cpsR) 394 | 395 | val bot = linearize(spc, getAnswerTypeAnn(expr.tpe))(tree.pos) 396 | 397 | val plainTpe = removeAllCPSAnnotations(expr.tpe) 398 | 399 | if (cpsR.isDefined && !bot.isDefined) { 400 | 401 | if (!expr.isEmpty && (expr.tpe.typeSymbol ne NothingClass)) { 402 | // must convert! 
403 | debuglog("cps type conversion (has: " + cpsA + "/" + spc + "/" + expr.tpe + ")") 404 | debuglog("cps type conversion (expected: " + cpsR.get + "): " + expr) 405 | 406 | if (!hasPlusMarker(expr.tpe)) 407 | reporter.warning(tree.pos, "expression " + tree + " is cps-transformed unexpectedly") 408 | 409 | try { 410 | val Some((a, b)) = cpsR 411 | /* Since shiftUnit is bounded [A,B,C>:B] this may not typecheck 412 | * if C is overly specific. So if !(B <:< C), call shiftUnit0 413 | * instead, which takes only two type arguments. 414 | */ 415 | val conforms = a <:< b 416 | val call = localTyper.typedPos(tree.pos)( 417 | Apply( 418 | TypeApply( 419 | gen.mkAttributedRef( if (conforms) MethShiftUnit else MethShiftUnit0 ), 420 | List(TypeTree(plainTpe), TypeTree(a)) ++ ( if (conforms) List(TypeTree(b)) else Nil ) 421 | ), 422 | List(expr) 423 | ) 424 | ) 425 | // This is today's sick/meaningless heuristic for spotting breakdown so 426 | // we don't proceed until stack traces start draping themselves over everything. 427 | // If there are wildcard types in the tree and B == Nothing, something went wrong. 428 | // (I thought WildcardTypes would be enough, but nope. 'reset0 { 0 }' has them.) 429 | // 430 | // Code as simple as reset((_: String).length) 431 | // will crash meaninglessly without this check. See SI-3718. 432 | // 433 | // TODO - obviously this should be done earlier, differently, or with 434 | // a more skilled hand. Most likely, all three. 435 | if ((b.typeSymbol eq NothingClass) && call.tpe.exists(_ eq WildcardType)) 436 | reporter.error(tree.pos, "cannot cps-transform malformed (possibly in shift/reset placement) expression") 437 | else 438 | return ((stms, call)) 439 | } 440 | catch { 441 | case ex:TypeError => 442 | reporter.error(ex.pos, "cannot cps-transform expression " + tree + ": " + ex.msg) 443 | } 444 | } 445 | 446 | } else if (!cpsR.isDefined && bot.isDefined) { 447 | // error! 
448 | debuglog("cps type error: " + expr) 449 | //println("cps type error: " + expr + "/" + expr.tpe + "/" + getAnswerTypeAnn(expr.tpe)) 450 | 451 | //println(cpsR + "/" + spc + "/" + bot) 452 | 453 | reporter.error(tree.pos, "found cps expression in non-cps position") 454 | } else { 455 | // all is well 456 | 457 | if (hasPlusMarker(expr.tpe)) { 458 | reporter.warning(tree.pos, "expression " + expr + " of type " + expr.tpe + " is not expected to have a cps type") 459 | expr modifyType removeAllCPSAnnotations 460 | } 461 | 462 | // TODO: sanity check that types agree 463 | } 464 | 465 | (stms, expr) 466 | } 467 | 468 | def transInlineValue(tree: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = { 469 | 470 | val (stms, expr, spc) = transValue(tree, cpsA, None) // never required to be cps 471 | 472 | getAnswerTypeAnn(expr.tpe) match { 473 | case spcVal @ Some(_) => 474 | 475 | val valueTpe = removeAllCPSAnnotations(expr.tpe) 476 | 477 | val sym: Symbol = ( 478 | currentOwner.newValue(newTermName(unit.fresh.newName("tmp")), tree.pos, Flags.SYNTHETIC) 479 | setInfo valueTpe 480 | setAnnotations List(AnnotationInfo(MarkerCPSSym.tpe_*, Nil, Nil)) 481 | ) 482 | expr.changeOwner(currentOwner -> sym) 483 | 484 | (stms ::: List(ValDef(sym, expr) setType(NoType)), 485 | Ident(sym) setType(valueTpe) setPos(tree.pos), linearize(spc, spcVal)(tree.pos)) 486 | 487 | case _ => 488 | (stms, expr, spc) 489 | } 490 | 491 | } 492 | 493 | 494 | 495 | def transInlineStm(stm: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], CPSInfo) = { 496 | stm match { 497 | 498 | // TODO: what about DefDefs? 499 | // TODO: relation to top-level val def? 500 | // TODO: what about lazy vals? 
501 | 502 | case tree @ ValDef(mods, name, tpt, rhs) => 503 | val (stms, anfRhs, spc) = atOwner(tree.symbol) { transValue(rhs, cpsA, None) } 504 | 505 | val tv = new ChangeOwnerTraverser(tree.symbol, currentOwner) 506 | stms.foreach(tv.traverse(_)) 507 | 508 | // TODO: symbol might already have annotation. Should check conformance 509 | // TODO: better yet: do without annotations on symbols 510 | 511 | val spcVal = getAnswerTypeAnn(anfRhs.tpe) 512 | spcVal foreach (_ => tree.symbol setAnnotations List(AnnotationInfo(MarkerCPSSym.tpe_*, Nil, Nil))) 513 | 514 | (stms:::List(treeCopy.ValDef(tree, mods, name, tpt, anfRhs)), linearize(spc, spcVal)(tree.pos)) 515 | 516 | case _ => 517 | val (headStms, headExpr, headSpc) = transInlineValue(stm, cpsA) 518 | val valSpc = getAnswerTypeAnn(headExpr.tpe) 519 | (headStms:::List(headExpr), linearize(headSpc, valSpc)(stm.pos)) 520 | } 521 | } 522 | 523 | // precondition: cpsR.isDefined "implies" isAnyParentImpure 524 | def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = { 525 | def rec(currStats: List[Tree], currAns: CPSInfo, accum: List[Tree]): (List[Tree], Tree) = 526 | currStats match { 527 | case Nil => 528 | val (anfStats, anfExpr) = transTailValue(expr, currAns, cpsR) 529 | (accum ++ anfStats, anfExpr) 530 | 531 | case stat :: rest => 532 | val (stats, nextAns) = transInlineStm(stat, currAns) 533 | rec(rest, nextAns, accum ++ stats) 534 | } 535 | 536 | val (anfStats, anfExpr) = rec(stms, cpsA, List()) 537 | // println("\nanf-block:\n"+ ((stms :+ expr) mkString ("{", "\n", "}")) +"\nBECAME\n"+ ((anfStats :+ anfExpr) mkString ("{", "\n", "}"))) 538 | // println("synth case? 
"+ (anfStats map (t => (t, t.isDef, treeInfo.hasSynthCaseSymbol(t))))) 539 | // SUPER UGLY HACK: handle virtpatmat-style matches, whose labels have already been turned into DefDefs 540 | if (anfStats.nonEmpty && (anfStats forall (t => !t.isDef || treeInfo.hasSynthCaseSymbol(t)))) { 541 | val (prologue, rest) = (anfStats :+ anfExpr) span (s => !s.isInstanceOf[DefDef]) // find first case 542 | // println("rest: "+ rest) 543 | // val (defs, calls) = rest partition (_.isInstanceOf[DefDef]) 544 | if (rest.nonEmpty) { 545 | // the filter drops the ()'s emitted when transValue encountered a LabelDef 546 | val stats = prologue ++ (rest filter (_.isInstanceOf[DefDef])).reverse // ++ calls 547 | // println("REVERSED "+ (stats mkString ("{", "\n", "}"))) 548 | (stats, localTyper.typed{Apply(Ident(rest.head.symbol), List())}) // call first label to kick-start the match 549 | } else (anfStats, anfExpr) 550 | } else (anfStats, anfExpr) 551 | } 552 | } 553 | } 554 | -------------------------------------------------------------------------------- /plugin/src/main/scala-2.12/scala/tools/selectivecps/SelectiveANFTransform.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Scala (https://www.scala-lang.org) 3 | * 4 | * Copyright EPFL and Lightbend, Inc. 5 | * 6 | * Licensed under Apache License 2.0 7 | * (http://www.apache.org/licenses/LICENSE-2.0). 8 | * 9 | * See the NOTICE file distributed with this work for 10 | * additional information regarding copyright ownership. 
11 | */ 12 | 13 | package scala.tools.selectivecps 14 | 15 | import scala.tools.nsc.transform._ 16 | import scala.tools.nsc.symtab._ 17 | import scala.tools.nsc.plugins._ 18 | 19 | /** 20 | * In methods marked @cps, explicitly name results of calls to other @cps methods 21 | */ 22 | abstract class SelectiveANFTransform extends PluginComponent with Transform with 23 | TypingTransformers with CPSUtils { 24 | // inherits abstract value `global` and class `Phase` from Transform 25 | 26 | import global._ // the global environment 27 | import definitions._ // standard classes and methods 28 | import typer.atOwner // methods to type trees 29 | 30 | override def description = "ANF pre-transform for @cps" 31 | 32 | /** the following two members override abstract members in Transform */ 33 | val phaseName: String = "selectiveanf" 34 | 35 | protected def newTransformer(unit: CompilationUnit): Transformer = 36 | new ANFTransformer(unit) 37 | 38 | class ANFTransformer(unit: CompilationUnit) extends TypingTransformer(unit) { 39 | 40 | var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet) 41 | 42 | object RemoveTailReturnsTransformer extends Transformer { 43 | override def transform(tree: Tree): Tree = tree match { 44 | case Block(stms, r @ Return(expr)) => 45 | treeCopy.Block(tree, stms, expr) 46 | 47 | case Block(stms, expr) => 48 | treeCopy.Block(tree, stms, transform(expr)) 49 | 50 | case If(cond, r1 @ Return(thenExpr), r2 @ Return(elseExpr)) => 51 | treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) 52 | 53 | case If(cond, r1 @ Return(thenExpr), elseExpr) => 54 | treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) 55 | 56 | case If(cond, thenExpr, r2 @ Return(elseExpr)) => 57 | treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) 58 | 59 | case If(cond, thenExpr, elseExpr) => 60 | treeCopy.If(tree, cond, transform(thenExpr), transform(elseExpr)) 61 | 62 | case Try(block, catches, finalizer) => 63 | 
treeCopy.Try(tree, 64 | transform(block), 65 | (catches map (t => transform(t))).asInstanceOf[List[CaseDef]], 66 | transform(finalizer)) 67 | 68 | case CaseDef(pat, guard, r @ Return(expr)) => 69 | treeCopy.CaseDef(tree, pat, guard, expr) 70 | 71 | case CaseDef(pat, guard, body) => 72 | treeCopy.CaseDef(tree, pat, guard, transform(body)) 73 | 74 | case Return(_) => 75 | reporter.error(tree.pos, "return expressions in CPS code must be in tail position") 76 | tree 77 | 78 | case _ => 79 | super.transform(tree) 80 | } 81 | } 82 | 83 | def removeTailReturns(body: Tree): Tree = { 84 | // support body with single return expression 85 | body match { 86 | case Return(expr) => expr 87 | case _ => RemoveTailReturnsTransformer.transform(body) 88 | } 89 | } 90 | 91 | override def transform(tree: Tree): Tree = { 92 | if (!cpsEnabled) return tree 93 | 94 | tree match { 95 | 96 | // Maybe we should further generalize the transform and move it over 97 | // to the regular Transformer facility. But then, actual and required cps 98 | // state would need more complicated (stateful!) tracking. 99 | 100 | // Making the default case use transExpr(tree, None, None) instead of 101 | // calling super.transform() would be a start, but at the moment, 102 | // this would cause infinite recursion. But we could remove the 103 | // ValDef case here. 
104 | 105 | case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) => 106 | debuglog("transforming " + dd.symbol) 107 | 108 | atOwner(dd.symbol) { 109 | val rhs = 110 | if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0) 111 | else rhs0 112 | val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))(getExternalAnswerTypeAnn(tpt.tpe).isDefined) 113 | 114 | debuglog("result "+rhs1) 115 | debuglog("result is of type "+rhs1.tpe) 116 | 117 | treeCopy.DefDef(dd, mods, name, transformTypeDefs(tparams), transformValDefss(vparamss), 118 | transform(tpt), rhs1) 119 | } 120 | 121 | case ff @ Function(vparams, body) => 122 | debuglog("transforming anon function " + ff.symbol) 123 | 124 | atOwner(ff.symbol) { 125 | 126 | //val body1 = transExpr(body, None, getExternalAnswerTypeAnn(body.tpe)) 127 | 128 | // need to special case partial functions: if expected type is @cps 129 | // but all cases are pure, then we would transform 130 | // { x => x match { case A => ... }} to 131 | // { x => shiftUnit(x match { case A => ... })} 132 | // which Uncurry cannot handle (see function6.scala) 133 | // thus, we push down the shiftUnit to each of the case bodies 134 | 135 | val ext = getExternalAnswerTypeAnn(body.tpe) 136 | val pureBody = getAnswerTypeAnn(body.tpe).isEmpty 137 | implicit val isParentImpure = ext.isDefined 138 | 139 | def transformPureMatch(tree: Tree, selector: Tree, cases: List[CaseDef]) = { 140 | val caseVals = cases map { case cd @ CaseDef(pat, guard, body) => 141 | // if (!hasPlusMarker(body.tpe)) body modifyType (_ withAnnotation newPlusMarker()) // TODO: to avoid warning 142 | val bodyVal = transExpr(body, None, ext) // ??? 
triggers "cps-transformed unexpectedly" warning in transTailValue 143 | treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) 144 | } 145 | treeCopy.Match(tree, transform(selector), caseVals) 146 | } 147 | 148 | def transformPureVirtMatch(body: Block, selDef: ValDef, cases: List[Tree], matchEnd: Tree) = { 149 | val stats = transform(selDef) :: (cases map (transExpr(_, None, ext))) 150 | treeCopy.Block(body, stats, transExpr(matchEnd, None, ext)) 151 | } 152 | 153 | val body1 = body match { 154 | case Match(selector, cases) if ext.isDefined && pureBody => 155 | transformPureMatch(body, selector, cases) 156 | 157 | // virtpatmat switch 158 | case Block(List(selDef: ValDef), mat@Match(selector, cases)) if ext.isDefined && pureBody => 159 | treeCopy.Block(body, List(transform(selDef)), transformPureMatch(mat, selector, cases)) 160 | 161 | // virtpatmat 162 | case b@Block(matchStats@((selDef: ValDef) :: cases), matchEnd) if ext.isDefined && pureBody && (matchStats forall treeInfo.hasSynthCaseSymbol) => 163 | transformPureVirtMatch(b, selDef, cases, matchEnd) 164 | 165 | // virtpatmat that stores the scrut separately -- TODO: can we eliminate this case?? 
166 | case Block(List(selDef0: ValDef), mat@Block(matchStats@((selDef: ValDef) :: cases), matchEnd)) if ext.isDefined && pureBody && (matchStats forall treeInfo.hasSynthCaseSymbol)=> 167 | treeCopy.Block(body, List(transform(selDef0)), transformPureVirtMatch(mat, selDef, cases, matchEnd)) 168 | 169 | case _ => 170 | transExpr(body, None, ext) 171 | } 172 | 173 | debuglog("anf result "+body1+"\nresult is of type "+body1.tpe) 174 | 175 | treeCopy.Function(ff, transformValDefs(vparams), body1) 176 | } 177 | 178 | case vd @ ValDef(mods, name, tpt, rhs) => // object-level valdefs 179 | debuglog("transforming valdef " + vd.symbol) 180 | 181 | if (getExternalAnswerTypeAnn(tpt.tpe).isEmpty) { 182 | 183 | atOwner(vd.symbol) { 184 | 185 | val rhs1 = transExpr(rhs, None, None) 186 | 187 | treeCopy.ValDef(vd, mods, name, transform(tpt), rhs1) 188 | } 189 | } else { 190 | reporter.error(tree.pos, "cps annotations not allowed on by-value parameters or value definitions") 191 | super.transform(tree) 192 | } 193 | 194 | case TypeTree() => 195 | // circumvent cpsAllowed here 196 | super.transform(tree) 197 | 198 | case Apply(_,_) => 199 | // this allows reset { ... 
} in object constructors 200 | // it's kind of a hack to put it here (see note above) 201 | transExpr(tree, None, None) 202 | 203 | case _ => 204 | if (hasAnswerTypeAnn(tree.tpe)) { 205 | if ((tree.symbol ne null) && tree.symbol.isLazy) { // symbol is null for trees without a symbol field; fall through to the generic check instead of crashing 206 | reporter.error(tree.pos, "implementation restriction: cps annotations not allowed on lazy value definitions") 207 | cpsAllowed = false 208 | } else if (!cpsAllowed) 209 | reporter.error(tree.pos, "cps code not allowed here / " + tree.getClass + " / " + tree) 210 | 211 | log(tree) 212 | } 213 | 214 | cpsAllowed = false 215 | super.transform(tree) 216 | } 217 | } 218 | 219 | 220 | def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean = false): Tree = { 221 | transTailValue(tree, cpsA, cpsR)(cpsR.isDefined || isAnyParentImpure) match { 222 | case (Nil, b) => b 223 | case (a, b) => 224 | treeCopy.Block(tree, a,b) 225 | } 226 | } 227 | 228 | 229 | def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[List[Tree]], List[Tree], CPSInfo) = { 230 | val formals = fun.tpe.paramTypes 231 | val overshoot = args.length - formals.length 232 | 233 | var spc: CPSInfo = cpsA 234 | 235 | val (stm,expr) = (for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield { 236 | tp match { 237 | case TypeRef(_, ByNameParamClass, List(elemtp)) => 238 | // note that we're not passing just isAnyParentImpure 239 | (Nil, transExpr(a, None, getAnswerTypeAnn(elemtp))(getAnswerTypeAnn(elemtp).isDefined || isAnyParentImpure)) 240 | case _ => 241 | val (valStm, valExpr, valSpc) = transInlineValue(a, spc) 242 | spc = valSpc 243 | (valStm, valExpr) 244 | } 245 | }).unzip 246 | 247 | (stm,expr,spc) 248 | } 249 | 250 | 251 | // precondition: cpsR.isDefined "implies" isAnyParentImpure 252 | def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = { 253 | // return value: (stms, expr, spc), where spc is CPSInfo after stms 
but *before* expr 254 | implicit val pos = tree.pos 255 | tree match { 256 | case Block(stms, expr) => 257 | val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd 258 | // val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe)) 259 | 260 | val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) 261 | val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!! 262 | 263 | (Nil, tree1, cpsA) 264 | 265 | case If(cond, thenp, elsep) => 266 | /* possible situations: 267 | cps before (cpsA) 268 | cps in condition (spc) <-- synth flag set if *only* here! 269 | cps in (one or both) branches */ 270 | val (condStats, condVal, spc) = transInlineValue(cond, cpsA) 271 | val (cpsA2, cpsR2) = if (hasSynthMarker(tree.tpe)) 272 | (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else 273 | (None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly 274 | val thenVal = transExpr(thenp, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) 275 | val elseVal = transExpr(elsep, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) 276 | 277 | // check that then and else parts agree (not necessary any more, but left as sanity check) 278 | if (cpsR.isDefined) { 279 | if (elsep == EmptyTree) 280 | reporter.error(tree.pos, "always need else part in cps code") 281 | } 282 | if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) { 283 | reporter.error(tree.pos, "then and else parts must both be cps code or neither of them") 284 | } 285 | 286 | (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc) 287 | 288 | case Match(selector, cases) => 289 | val (selStats, selVal, spc) = transInlineValue(selector, cpsA) 290 | val (cpsA2, cpsR2) = 291 | if (hasSynthMarker(tree.tpe)) (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) 292 | else (None, getAnswerTypeAnn(tree.tpe)) 293 | 294 | val caseVals = cases map { case cd @ CaseDef(pat, guard, body) => 295 | 
val bodyVal = transExpr(body, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure) 296 | treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) 297 | } 298 | 299 | (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc) 300 | 301 | // this is utterly broken: LabelDefs need to be considered together when transforming them to DefDefs: 302 | // suppose a Block {L1; ... ; LN} 303 | // this should become {D1def ; ... ; DNdef ; D1()} 304 | // where D$idef = def L$i(..) = {L$i.body; L${i+1}(..)} 305 | 306 | case ldef @ LabelDef(name, params, rhs) => 307 | // println("trans LABELDEF "+(name, params, tree.tpe, hasAnswerTypeAnn(tree.tpe))) 308 | // TODO why does the labeldef's type have a cpsMinus annotation, whereas the rhs does not? (BYVALmode missing/too much somewhere?) 309 | if (hasAnswerTypeAnn(tree.tpe)) { 310 | // currentOwner.newMethod(name, tree.pos, Flags.SYNTHETIC) setInfo ldef.symbol.info 311 | val sym = ldef.symbol resetFlag Flags.LABEL 312 | val rhs1 = rhs //new TreeSymSubstituter(List(ldef.symbol), List(sym)).transform(rhs) 313 | val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe))(getAnswerTypeAnn(tree.tpe).isDefined || isAnyParentImpure) changeOwner (currentOwner -> sym) 314 | 315 | val stm1 = localTyper.typed(DefDef(sym, rhsVal)) 316 | // since virtpatmat does not rely on fall-through, don't call the labels it emits 317 | // transBlock will take care of calling the first label 318 | // calling each labeldef is wrong, since some labels may be jumped over 319 | // we can get away with this for now since the only other labels we emit are for tailcalls/while loops, 320 | // which do not have consecutive labeldefs (and thus fall-through is irrelevant) 321 | if (treeInfo.hasSynthCaseSymbol(ldef)) (List(stm1), localTyper.typed{Literal(Constant(()))}, cpsA) 322 | else { 323 | assert(params.isEmpty, "problem in ANF transforming label with non-empty params "+ ldef) 324 | (List(stm1), localTyper.typed{Apply(Ident(sym), List())}, 
cpsA) 325 | } 326 | } else { 327 | val rhsVal = transExpr(rhs, None, None) 328 | (Nil, updateSynthFlag(treeCopy.LabelDef(tree, name, params, rhsVal)), cpsA) 329 | } 330 | 331 | 332 | case Try(block, catches, finalizer) => 333 | val blockVal = transExpr(block, cpsA, cpsR) 334 | 335 | val catchVals = for { 336 | cd @ CaseDef(pat, guard, body) <- catches 337 | bodyVal = transExpr(body, cpsA, cpsR) 338 | } yield { 339 | treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal) 340 | } 341 | 342 | val finallyVal = transExpr(finalizer, None, None) // for now, no cps in finally 343 | 344 | (Nil, updateSynthFlag(treeCopy.Try(tree, blockVal, catchVals, finallyVal)), cpsA) 345 | 346 | case Assign(lhs, rhs) => 347 | // allow cps code in rhs only 348 | val (stms, expr, spc) = transInlineValue(rhs, cpsA) 349 | (stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc) 350 | 351 | case Return(expr0) => 352 | if (isAnyParentImpure) 353 | reporter.error(tree.pos, "return expression not allowed, since method calls CPS method") 354 | val (stms, expr, spc) = transInlineValue(expr0, cpsA) 355 | (stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc) 356 | 357 | case Throw(expr0) => 358 | val (stms, expr, spc) = transInlineValue(expr0, cpsA) 359 | (stms, updateSynthFlag(treeCopy.Throw(tree, expr)), spc) 360 | 361 | case Typed(expr0, tpt) => 362 | // TODO: should x: A @cps[B,C] have a special meaning? 
363 | // type casts used in different ways (see match2.scala, #3199) 364 | val (stms, expr, spc) = transInlineValue(expr0, cpsA) 365 | val tpt1 = if (treeInfo.isWildcardStarArg(tree)) tpt else 366 | treeCopy.TypeTree(tpt).setType(removeAllCPSAnnotations(tpt.tpe)) 367 | // (stms, updateSynthFlag(treeCopy.Typed(tree, expr, tpt1)), spc) 368 | (stms, treeCopy.Typed(tree, expr, tpt1).setType(removeAllCPSAnnotations(tree.tpe)), spc) 369 | 370 | case TypeApply(fun, args) => 371 | val (stms, expr, spc) = transInlineValue(fun, cpsA) 372 | (stms, updateSynthFlag(treeCopy.TypeApply(tree, expr, args)), spc) 373 | 374 | case Select(qual, name) => 375 | val (stms, expr, spc) = transInlineValue(qual, cpsA) 376 | (stms, updateSynthFlag(treeCopy.Select(tree, expr, name)), spc) 377 | 378 | case Apply(fun, args) => 379 | val (funStm, funExpr, funSpc) = transInlineValue(fun, cpsA) 380 | val (argStm, argExpr, argSpc) = transArgList(fun, args, funSpc) 381 | 382 | (funStm ::: (argStm.flatten), updateSynthFlag(treeCopy.Apply(tree, funExpr, argExpr)), 383 | argSpc) 384 | 385 | case _ => 386 | cpsAllowed = true 387 | (Nil, transform(tree), cpsA) 388 | } 389 | } 390 | 391 | // precondition: cpsR.isDefined "implies" isAnyParentImpure 392 | def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = { 393 | 394 | val (stms, expr, spc) = transValue(tree, cpsA, cpsR) 395 | 396 | val bot = linearize(spc, getAnswerTypeAnn(expr.tpe))(tree.pos) 397 | 398 | val plainTpe = removeAllCPSAnnotations(expr.tpe) 399 | 400 | if (cpsR.isDefined && !bot.isDefined) { 401 | 402 | if (!expr.isEmpty && (expr.tpe.typeSymbol ne NothingClass)) { 403 | // must convert! 
404 | debuglog("cps type conversion (has: " + cpsA + "/" + spc + "/" + expr.tpe + ")") 405 | debuglog("cps type conversion (expected: " + cpsR.get + "): " + expr) 406 | 407 | if (!hasPlusMarker(expr.tpe)) 408 | reporter.warning(tree.pos, "expression " + tree + " is cps-transformed unexpectedly") 409 | 410 | try { 411 | val Some((a, b)) = cpsR 412 | /* Since shiftUnit is bounded [A,B,C>:B] this may not typecheck 413 | * if C is overly specific. So if !(B <:< C), call shiftUnit0 414 | * instead, which takes only two type arguments. 415 | */ 416 | val conforms = a <:< b 417 | val call = localTyper.typedPos(tree.pos)( 418 | Apply( 419 | TypeApply( 420 | gen.mkAttributedRef( if (conforms) MethShiftUnit else MethShiftUnit0 ), 421 | List(TypeTree(plainTpe), TypeTree(a)) ++ ( if (conforms) List(TypeTree(b)) else Nil ) 422 | ), 423 | List(expr) 424 | ) 425 | ) 426 | // This is today's sick/meaningless heuristic for spotting breakdown so 427 | // we don't proceed until stack traces start draping themselves over everything. 428 | // If there are wildcard types in the tree and B == Nothing, something went wrong. 429 | // (I thought WildcardTypes would be enough, but nope. 'reset0 { 0 }' has them.) 430 | // 431 | // Code as simple as reset((_: String).length) 432 | // will crash meaninglessly without this check. See SI-3718. 433 | // 434 | // TODO - obviously this should be done earlier, differently, or with 435 | // a more skilled hand. Most likely, all three. 436 | if ((b.typeSymbol eq NothingClass) && call.tpe.exists(_ eq WildcardType)) 437 | reporter.error(tree.pos, "cannot cps-transform malformed (possibly in shift/reset placement) expression") 438 | else 439 | return ((stms, call)) 440 | } 441 | catch { 442 | case ex:TypeError => 443 | reporter.error(ex.pos, "cannot cps-transform expression " + tree + ": " + ex.msg) 444 | } 445 | } 446 | 447 | } else if (!cpsR.isDefined && bot.isDefined) { 448 | // error! 
449 | debuglog("cps type error: " + expr) 450 | //println("cps type error: " + expr + "/" + expr.tpe + "/" + getAnswerTypeAnn(expr.tpe)) 451 | 452 | //println(cpsR + "/" + spc + "/" + bot) 453 | 454 | reporter.error(tree.pos, "found cps expression in non-cps position") 455 | } else { 456 | // all is well 457 | 458 | if (hasPlusMarker(expr.tpe)) { 459 | reporter.warning(tree.pos, "expression " + expr + " of type " + expr.tpe + " is not expected to have a cps type") 460 | expr modifyType removeAllCPSAnnotations 461 | } 462 | 463 | // TODO: sanity check that types agree 464 | } 465 | 466 | (stms, expr) 467 | } 468 | 469 | def transInlineValue(tree: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = { 470 | 471 | val (stms, expr, spc) = transValue(tree, cpsA, None) // never required to be cps 472 | 473 | getAnswerTypeAnn(expr.tpe) match { 474 | case spcVal @ Some(_) => 475 | 476 | val valueTpe = removeAllCPSAnnotations(expr.tpe) 477 | 478 | val sym: Symbol = ( 479 | currentOwner.newValue(newTermName(unit.fresh.newName("tmp")), tree.pos, Flags.SYNTHETIC) 480 | setInfo valueTpe 481 | setAnnotations List(AnnotationInfo(MarkerCPSSym.tpe_*, Nil, Nil)) 482 | ) 483 | expr.changeOwner(currentOwner -> sym) 484 | 485 | (stms ::: List(ValDef(sym, expr) setType(NoType)), 486 | Ident(sym) setType(valueTpe) setPos(tree.pos), linearize(spc, spcVal)(tree.pos)) 487 | 488 | case _ => 489 | (stms, expr, spc) 490 | } 491 | 492 | } 493 | 494 | 495 | 496 | def transInlineStm(stm: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], CPSInfo) = { 497 | stm match { 498 | 499 | // TODO: what about DefDefs? 500 | // TODO: relation to top-level val def? 501 | // TODO: what about lazy vals? 
502 | 503 | case tree @ ValDef(mods, name, tpt, rhs) => 504 | val (stms, anfRhs, spc) = atOwner(tree.symbol) { transValue(rhs, cpsA, None) } 505 | 506 | val tv = new ChangeOwnerTraverser(tree.symbol, currentOwner) 507 | stms.foreach(tv.traverse(_)) 508 | 509 | // TODO: symbol might already have annotation. Should check conformance 510 | // TODO: better yet: do without annotations on symbols 511 | 512 | val spcVal = getAnswerTypeAnn(anfRhs.tpe) 513 | spcVal foreach (_ => tree.symbol setAnnotations List(AnnotationInfo(MarkerCPSSym.tpe_*, Nil, Nil))) 514 | 515 | (stms:::List(treeCopy.ValDef(tree, mods, name, tpt, anfRhs)), linearize(spc, spcVal)(tree.pos)) 516 | 517 | case _ => 518 | val (headStms, headExpr, headSpc) = transInlineValue(stm, cpsA) 519 | val valSpc = getAnswerTypeAnn(headExpr.tpe) 520 | (headStms:::List(headExpr), linearize(headSpc, valSpc)(stm.pos)) 521 | } 522 | } 523 | 524 | // precondition: cpsR.isDefined "implies" isAnyParentImpure 525 | def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = { 526 | def rec(currStats: List[Tree], currAns: CPSInfo, accum: List[Tree]): (List[Tree], Tree) = 527 | currStats match { 528 | case Nil => 529 | val (anfStats, anfExpr) = transTailValue(expr, currAns, cpsR) 530 | (accum ++ anfStats, anfExpr) 531 | 532 | case stat :: rest => 533 | val (stats, nextAns) = transInlineStm(stat, currAns) 534 | rec(rest, nextAns, accum ++ stats) 535 | } 536 | 537 | val (anfStats, anfExpr) = rec(stms, cpsA, List()) 538 | // println("\nanf-block:\n"+ ((stms :+ expr) mkString ("{", "\n", "}")) +"\nBECAME\n"+ ((anfStats :+ anfExpr) mkString ("{", "\n", "}"))) 539 | // println("synth case? 
"+ (anfStats map (t => (t, t.isDef, treeInfo.hasSynthCaseSymbol(t))))) 540 | // SUPER UGLY HACK: handle virtpatmat-style matches, whose labels have already been turned into DefDefs 541 | if (anfStats.nonEmpty && (anfStats forall (t => !t.isDef || treeInfo.hasSynthCaseSymbol(t)))) { 542 | val (prologue, rest) = (anfStats :+ anfExpr) span (s => !s.isInstanceOf[DefDef]) // find first case 543 | // println("rest: "+ rest) 544 | // val (defs, calls) = rest partition (_.isInstanceOf[DefDef]) 545 | if (rest.nonEmpty) { 546 | // the filter drops the ()'s emitted when transValue encountered a LabelDef 547 | val stats = prologue ++ (rest filter (_.isInstanceOf[DefDef])).reverse // ++ calls 548 | // println("REVERSED "+ (stats mkString ("{", "\n", "}"))) 549 | (stats, localTyper.typed{Apply(Ident(rest.head.symbol), List())}) // call first label to kick-start the match 550 | } else (anfStats, anfExpr) 551 | } else (anfStats, anfExpr) 552 | } 553 | } 554 | } 555 | -------------------------------------------------------------------------------- /library/src/test/scala/scala/tools/selectivecps/TestSuite.scala: -------------------------------------------------------------------------------- 1 | /* 2 | * Scala (https://www.scala-lang.org) 3 | * 4 | * Copyright EPFL and Lightbend, Inc. 5 | * 6 | * Licensed under Apache License 2.0 7 | * (http://www.apache.org/licenses/LICENSE-2.0). 8 | * 9 | * See the NOTICE file distributed with this work for 10 | * additional information regarding copyright ownership. 
11 | */ 12 | 13 | package scala.tools.selectivecps 14 | 15 | import org.junit.Test 16 | import org.junit.Assert.assertEquals 17 | 18 | import scala.annotation._ 19 | import scala.collection.Seq 20 | import scala.collection.generic.CanBuildFrom 21 | import scala.language.{ implicitConversions, higherKinds } 22 | import scala.util.continuations._ 23 | import scala.util.control.Exception 24 | 25 | class Functions { 26 | def m0() = { 27 | shift((k: Int => Int) => k(k(7))) * 2 28 | } 29 | 30 | def m1() = { 31 | 2 * shift((k: Int => Int) => k(k(7))) 32 | } 33 | 34 | @Test def basics = { 35 | 36 | assertEquals(28, reset(m0())) 37 | assertEquals(28, reset(m1())) 38 | } 39 | 40 | @Test def function1 = { 41 | 42 | val f = () => shift { k: (Int => Int) => k(7) } 43 | val g: () => Int @cps[Int] = f 44 | 45 | assertEquals(7, reset(g())) 46 | } 47 | 48 | @Test def function4 = { 49 | 50 | val g: () => Int @cps[Int] = () => shift { k: (Int => Int) => k(7) } 51 | 52 | assertEquals(7, reset(g())) 53 | } 54 | 55 | @Test def function5 = { 56 | 57 | val g: () => Int @cps[Int] = () => 7 58 | 59 | assertEquals(7, reset(g())) 60 | } 61 | 62 | @Test def function6 = { 63 | 64 | val g: PartialFunction[Int, Int @cps[Int]] = { case x => 7 } 65 | 66 | assertEquals(7, reset(g(2))) 67 | 68 | } 69 | 70 | } 71 | 72 | class IfThenElse { 73 | val out = new StringBuilder; def printOut(x: Any): Unit = out ++= x.toString 74 | 75 | def test(x: Int) = if (x <= 7) 76 | shift { k: (Int => Int) => k(k(k(x))) } 77 | else 78 | shift { k: (Int => Int) => k(x) } 79 | 80 | @Test def ifelse0 = { 81 | assertEquals(10, reset(1 + test(7))) 82 | assertEquals(9, reset(1 + test(8))) 83 | } 84 | 85 | def test1(x: Int) = if (x <= 7) 86 | shift { k: (Int => Int) => k(k(k(x))) } 87 | else 88 | x 89 | 90 | def test2(x: Int) = if (x <= 7) 91 | x 92 | else 93 | shift { k: (Int => Int) => k(k(k(x))) } 94 | 95 | @Test def ifelse1 = { 96 | assertEquals(10, reset(1 + test1(7))) 97 | assertEquals(9, reset(1 + test1(8))) 98 | 
assertEquals(8, reset(1 + test2(7))) 99 | assertEquals(11, reset(1 + test2(8))) 100 | } 101 | 102 | def test3(x: Int) = if (x <= 7) 103 | shift { k: (Unit => Unit) => printOut("abort") } 104 | 105 | @Test def ifelse2 = { 106 | out.clear() 107 | printOut(reset { test3(7); printOut("alive") }) 108 | printOut(reset { test3(8); printOut("alive") }) 109 | assertEquals("abort()alive()", out.toString) 110 | } 111 | 112 | def util(x: Boolean) = shift { k: (Boolean => Int) => k(x) } 113 | 114 | def test4(x: Int) = if (util(x <= 7)) 115 | x - 1 116 | else 117 | x + 1 118 | 119 | @Test def ifelse3 = { 120 | assertEquals(6, reset(test4(7))) 121 | assertEquals(9, reset(test4(8))) 122 | } 123 | 124 | def sh(x1: Int) = shift((k: Int => Int) => k(k(k(x1)))) 125 | 126 | def testA(x1: Int): Int @cps[Int] = { 127 | sh(x1) 128 | if (x1 == 42) x1 else sh(x1) 129 | } 130 | 131 | def testB(x1: Int): Int @cps[Int] = { 132 | if (sh(x1) == 43) x1 else x1 133 | } 134 | 135 | def testC(x1: Int): Int @cps[Int] = { 136 | sh(x1) 137 | if (sh(x1) == 44) x1 else x1 138 | } 139 | 140 | def testD(x1: Int): Int @cps[Int] = { 141 | sh(x1) 142 | if (sh(x1) == 45) x1 else sh(x1) 143 | } 144 | 145 | @Test def ifelse4 = { 146 | assertEquals(10, reset(1 + testA(7))) 147 | assertEquals(10, reset(1 + testB(9))) 148 | assertEquals(10, reset(1 + testC(9))) 149 | assertEquals(10, reset(1 + testD(7))) 150 | } 151 | } 152 | 153 | class Inference { 154 | 155 | object A { 156 | class foo[-B, +C] extends StaticAnnotation with TypeConstraint 157 | 158 | def shift[A, B, C](fun: (A => B) => C): A @foo[B, C] = ??? 159 | def reset[A, C](ctx: => (A @foo[A, C])): C = ??? 
160 | 161 | def m1 = reset { shift { f: (Int => Range) => f(5) }.to(10) } 162 | } 163 | 164 | object B { 165 | 166 | def m1 = reset { shift { f: (Int => Range) => f(5) }.to(10) } 167 | def m2 = reset { val a = shift { f: (Int => Range) => f(5) }; a.to(10) } 168 | 169 | val x1 = reset { 170 | shift { cont: (Int => Range) => 171 | cont(5) 172 | }.to(10) 173 | } 174 | 175 | val x2 = reset { 176 | val a = shift { cont: (Int => Range) => 177 | cont(5) 178 | } 179 | a.to(10) 180 | } // x is now Range(5, 6, 7, 8, 9, 10) 181 | 182 | val x3 = reset { 183 | shift { cont: (Int => Int) => 184 | cont(5) 185 | } + 10 186 | } // x is now 15 187 | 188 | val x4 = reset { 189 | 10 :: shift { cont: (List[Int] => List[Int]) => 190 | cont(List(1, 2, 3)) 191 | } 192 | } // x is List(10, 1, 2, 3) 193 | 194 | val x5 = reset { 195 | new scala.runtime.RichInt(shift { cont: (Int => Range) => 196 | cont(5) 197 | }) to 10 198 | } 199 | } 200 | 201 | @Test def implicit_infer_annotations = { 202 | import B._ 203 | assertEquals(5 to 10, x1) 204 | assertEquals(5 to 10, x2) 205 | assertEquals(15, x3) 206 | assertEquals(List(10, 1, 2, 3), x4) 207 | assertEquals(5 to 10, x5) 208 | } 209 | 210 | def test(x: => Int @cpsParam[String, Int]) = 7 211 | 212 | def test2() = { 213 | val x = shift { k: (Int => String) => 9 } 214 | x 215 | } 216 | 217 | def test3(x: => Int @cpsParam[Int, Int]) = 7 218 | 219 | def util() = shift { k: (String => String) => "7" } 220 | 221 | @Test def infer1: Unit = { 222 | test { shift { k: (Int => String) => 9 } } 223 | test { shift { k: (Int => String) => 9 }; 2 } 224 | // test { shift { k: (Int => String) => 9 }; util() } <-- doesn't work 225 | test { shift { k: (Int => String) => 9 }; util(); 2 } 226 | 227 | test { shift { k: (Int => String) => 9 }; { test3(0); 2 } } 228 | 229 | test3 { { test3(0); 2 } } 230 | 231 | } 232 | 233 | } 234 | 235 | class PatternMatching { 236 | 237 | def test(x: Int) = x match { 238 | case 7 => shift { k: (Int => Int) => k(k(k(x))) } 239 | case 8 
=> shift { k: (Int => Int) => k(x) } 240 | } 241 | 242 | @Test def match0 = { 243 | assertEquals(10, reset(1 + test(7))) 244 | assertEquals(9, reset(1 + test(8))) 245 | } 246 | 247 | def test1(x: Int) = x match { 248 | case 7 => shift { k: (Int => Int) => k(k(k(x))) } 249 | case _ => x 250 | } 251 | 252 | @Test def match1 = { /* exercises test1: a shift in `case 7` next to a pure `case _` default */ 253 | assertEquals(10, reset(1 + test1(7))) 254 | assertEquals(9, reset(1 + test1(8))) 255 | } 256 | 257 | def test2() = { 258 | val (a, b) = shift { k: (((String, String)) => String) => k("A", "B") } 259 | b 260 | } 261 | 262 | case class Elem[T, U](a: T, b: U) 263 | 264 | def test3() = { 265 | val Elem(a, b) = shift { k: (Elem[String, String] => String) => k(Elem("A", "B")) } 266 | b 267 | } 268 | 269 | @Test def match2 = { 270 | assertEquals("B", reset(test2())) 271 | assertEquals("B", reset(test3())) 272 | } 273 | 274 | def sh(x1: Int) = shift((k: Int => Int) => k(k(k(x1)))) 275 | 276 | def testv(x1: Int) = { 277 | val o7 = { 278 | val o6 = { 279 | val o3 = 280 | if (7 == x1) Some(x1) 281 | else None 282 | 283 | if (o3.isEmpty) None 284 | else Some(sh(x1)) 285 | } 286 | if (o6.isEmpty) { 287 | val o5 = 288 | if (8 == x1) Some(x1) 289 | else None 290 | 291 | if (o5.isEmpty) None 292 | else Some(sh(x1)) 293 | } else o6 294 | } 295 | o7.get 296 | } 297 | 298 | @Test def patvirt = { 299 | assertEquals(10, reset(1 + testv(7))) 300 | assertEquals(11, reset(1 + testv(8))) 301 | } 302 | 303 | class MatchRepro { 304 | def s: String @cps[Any] = shift { k => k("foo") } 305 | 306 | def p = { 307 | val k = s 308 | s match { case lit0 => } 309 | } 310 | 311 | def q = { 312 | val k = s 313 | k match { case lit1 => } 314 | } 315 | 316 | def r = { 317 | s match { case "FOO" => } 318 | } 319 | 320 | def t = { 321 | val k = s 322 | k match { case "FOO" => } 323 | } 324 | } 325 | 326 | @Test def z1673 = { 327 | val m = new MatchRepro 328 | () 329 | } 330 | } 331 | 332 | class IfReturn { 333 | // d = 1, d2 = 1.0, pct = 1.000 334 | // d = 2, d2 = 4.0, pct = 0.500
335 | // d = 3, d2 = 9.0, pct = 0.333 336 | // d = 4, d2 = 16.0, pct = 0.250 337 | // d = 5, d2 = 25.0, pct = 0.200 338 | // d = 6, d2 = 36.0, pct = 0.167 339 | // d = 7, d2 = 49.0, pct = 0.143 340 | // d = 8, d2 = 64.0, pct = 0.125 341 | // d = 9, d2 = 81.0, pct = 0.111 342 | // d = 10, d2 = 100.0, pct = 0.100 343 | // d = 11, d2 = 121.0, pct = 0.091 344 | // d = 12, d2 = 144.0, pct = 0.083 345 | // d = 13, d2 = 169.0, pct = 0.077 346 | // d = 14, d2 = 196.0, pct = 0.071 347 | // d = 15, d2 = 225.0, pct = 0.067 348 | // d = 16, d2 = 256.0, pct = 0.063 349 | // d = 17, d2 = 289.0, pct = 0.059 350 | // d = 18, d2 = 324.0, pct = 0.056 351 | // d = 19, d2 = 361.0, pct = 0.053 352 | // d = 20, d2 = 400.0, pct = 0.050 353 | // d = 21, d2 = 441.0, pct = 0.048 354 | // d = 22, d2 = 484.0, pct = 0.045 355 | // d = 23, d2 = 529.0, pct = 0.043 356 | // d = 24, d2 = 576.0, pct = 0.042 357 | // d = 25, d2 = 625.0, pct = 0.040 358 | 359 | abstract class IfReturnRepro { 360 | def s1: Double @cpsParam[Any, Unit] 361 | def s2: Double @cpsParam[Any, Unit] 362 | 363 | def p(i: Int): Double @cpsParam[Unit, Any] = { 364 | val px = s1 365 | val pct = if (px > 100) px else px / s2 366 | //printOut("pct = %.3f".format(pct)) 367 | assertEquals(s1 / s2, pct, 0.0001) 368 | pct 369 | } 370 | } 371 | 372 | @Test def shift_pct = { 373 | var d: Double = 0d 374 | def d2 = d * d 375 | 376 | val irr = new IfReturnRepro { 377 | def s1 = shift(f => f(d)) 378 | def s2 = shift(f => f(d2)) 379 | } 380 | 1 to 25 foreach { i => 381 | d = i 382 | // print("d = " + i + ", d2 = " + d2 + ", ") 383 | assertEquals(i.toDouble * i, d2, 0.1) 384 | run(irr p i) 385 | } 386 | } 387 | 388 | @Test def t1807 = { 389 | val z = reset { 390 | val f: (() => Int @cps[Int]) = () => 1 391 | f() 392 | } 393 | assertEquals(1, z) 394 | } 395 | 396 | @Test def t1808: Unit = { 397 | reset0 { 0 } 398 | } 399 | } 400 | 401 | class Suspendable { 402 | def shifted: Unit @suspendable = shift { (k: Unit => Unit) => () } 403 | def 
test1(b: => Boolean) = { 404 | reset { 405 | if (b) shifted 406 | } 407 | } 408 | @Test def t1820 = test1(true) 409 | 410 | def suspended[A](x: A): A @suspendable = x 411 | def test1[A](x: A): A @suspendable = suspended(x) match { case x => x } 412 | def test2[A](x: List[A]): A @suspendable = suspended(x) match { case List(x) => x } 413 | 414 | def test3[A](x: A): A @suspendable = x match { case x => x } 415 | def test4[A](x: List[A]): A @suspendable = x match { case List(x) => x } 416 | 417 | @Test def t1821: Unit = { 418 | assertEquals((), reset(test1(()))) 419 | assertEquals((), reset(test2(List(())))) 420 | assertEquals((), reset(test3(()))) 421 | assertEquals((), reset(test4(List(())))) 422 | } 423 | } 424 | 425 | class Misc { 426 | 427 | def double[B](n: Int)(k: Int => B): B = k(n * 2) 428 | 429 | @Test def t2864: Unit = { 430 | reset { 431 | val result1 = shift(double[Unit](100)) 432 | val result2 = shift(double[Unit](result1)) 433 | assertEquals(400, result2) 434 | } 435 | } 436 | 437 | def foo: Int @cps[Int] = { 438 | val a0 = shift((k: Int => Int) => k(0)) 439 | val x0 = 2 440 | val a1 = shift((k: Int => Int) => x0) 441 | 0 442 | } 443 | 444 | /* 445 | def bar: ControlContext[Int,Int,Int] = { 446 | shiftR((k:Int=>Int) => k(0)).flatMap { a0 => 447 | val x0 = 2 448 | shiftR((k:Int=>Int) => x0).map { a1 => 449 | 0 450 | }} 451 | } 452 | */ 453 | 454 | @Test def t2934 = { 455 | assertEquals(List(3, 4, 5), reset { 456 | val x = shift(List(1, 2, 3).flatMap[Int, List[Int]]) 457 | List(x + 2) 458 | }) 459 | } 460 | 461 | class Bla { 462 | val x = 8 463 | def y[T] = 9 464 | } 465 | 466 | /* 467 | def bla[A] = shift { k:(Bla=>A) => k(new Bla) } 468 | */ 469 | 470 | def bla1 = shift { k: (Bla => Bla) => k(new Bla) } 471 | def bla2 = shift { k: (Bla => Int) => k(new Bla) } 472 | 473 | def fooA = bla2.x 474 | def fooB[T] = bla2.y[T] 475 | 476 | // TODO: check whether this also applies to a::shift { k => ... 
} 477 | 478 | @Test def t3225Mono(): Unit = { 479 | assertEquals(8, reset(bla1).x) 480 | assertEquals(8, reset(bla2.x)) 481 | assertEquals(9, reset(bla2.y[Int])) 482 | assertEquals(9, reset(bla2.y)) 483 | assertEquals(8, reset(fooA)) 484 | assertEquals(9, reset(fooB)) 485 | 0 486 | } 487 | 488 | def blaX[A] = shift { k: (Bla => A) => k(new Bla) } 489 | 490 | def fooX[A] = blaX[A].x 491 | def fooY[A] = blaX[A].y[A] 492 | 493 | @Test def t3225Poly(): Unit = { 494 | assertEquals(8, reset(blaX[Bla]).x) 495 | assertEquals(8, reset(blaX[Int].x)) 496 | assertEquals(9, reset(blaX[Int].y[Int])) 497 | assertEquals(9, reset(blaX[Int].y)) 498 | assertEquals(8, reset(fooX[Int])) 499 | assertEquals(9, reset(fooY[Int])) 500 | 0 501 | } 502 | 503 | def capture(): Int @suspendable = 42 504 | 505 | @Test def t3501: Unit = reset { 506 | var i = 0 507 | while (i < 5) { 508 | i += 1 509 | val y = capture() 510 | val s = y 511 | assertEquals(42, s) 512 | } 513 | assertEquals(5, i) 514 | } 515 | } 516 | 517 | class Return { 518 | val out = new StringBuilder; def printOut(x: Any): Unit = out ++= x.toString 519 | 520 | class ReturnRepro { 521 | def s1: Int @cps[Any] = shift { k => k(5) } 522 | def caller = reset { printOut(p(3)) } 523 | def caller2 = reset { printOut(p2(3)) } 524 | def caller3 = reset { printOut(p3(3)) } 525 | 526 | def p(i: Int): Int @cps[Any] = { 527 | val v = s1 + 3 528 | return v 529 | } 530 | 531 | def p2(i: Int): Int @cps[Any] = { 532 | val v = s1 + 3 533 | if (v > 0) { 534 | printOut("hi") 535 | return v 536 | } else { 537 | printOut("hi") 538 | return 8 539 | } 540 | } 541 | 542 | def p3(i: Int): Int @cps[Any] = { 543 | val v = s1 + 3 544 | try { 545 | printOut("from try") 546 | return v 547 | } catch { 548 | case e: Exception => 549 | printOut("from catch") 550 | return 7 551 | } 552 | } 553 | } 554 | 555 | @Test def t5314_2 = { 556 | out.clear() 557 | val repro = new ReturnRepro 558 | repro.caller 559 | repro.caller2 560 | repro.caller3 561 | 
assertEquals("8hi8from try8", out.toString) 562 | } 563 | 564 | class ReturnRepro2 { 565 | 566 | def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) } 567 | def caller = reset { printOut(p(3)) } 568 | def caller2 = reset { printOut(p2(3)) } 569 | 570 | def p(i: Int): Int @cpsParam[Unit, Any] = { 571 | val v = s1 + 3 572 | return { printOut("enter return expr"); v } 573 | } 574 | 575 | def p2(i: Int): Int @cpsParam[Unit, Any] = { 576 | val v = s1 + 3 577 | if (v > 0) { 578 | return { printOut("hi"); v } 579 | } else { 580 | return { printOut("hi"); 8 } 581 | } 582 | } 583 | } 584 | 585 | @Test def t5314_3 = { 586 | out.clear() 587 | val repro = new ReturnRepro2 588 | repro.caller 589 | repro.caller2 590 | assertEquals("enter return expr8hi8", out.toString) 591 | } 592 | 593 | def foo(x: Int): Int @cps[Int] = 7 594 | 595 | def bar(x: Int): Int @cps[Int] = { 596 | val v = foo(x) 597 | if (v > 0) 598 | return v 599 | else 600 | return 10 601 | } 602 | 603 | @Test def t5314_with_if = 604 | assertEquals(7, reset { bar(10) }) 605 | 606 | } 607 | 608 | class t5314 { 609 | val out = new StringBuilder; def printOut(x: Any): Unit = out ++= x.toString 610 | 611 | class ReturnRepro3 { 612 | def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) } 613 | def caller = reset { printOut(p(3)) } 614 | def caller2 = reset { printOut(p2(3)) } 615 | 616 | def p(i: Int): Int @cpsParam[Unit, Any] = { 617 | val v = s1 + 3 618 | return v 619 | } 620 | 621 | def p2(i: Int): Int @cpsParam[Unit, Any] = { 622 | val v = s1 + 3 623 | if (v > 0) { 624 | printOut("hi") 625 | return v 626 | } else { 627 | printOut("hi") 628 | return 8 629 | } 630 | } 631 | } 632 | 633 | def foo(x: Int): Int @cps[Int] = shift { k => k(x) } 634 | 635 | def bar(x: Int): Int @cps[Int] = return foo(x) 636 | 637 | def nocps(x: Int): Int = { return x; x } 638 | 639 | def foo2(x: Int): Int @cps[Int] = 7 640 | def bar2(x: Int): Int @cps[Int] = { foo2(x); return 7 } 641 | def bar3(x: Int): Int @cps[Int] = { foo2(x); if (x == 
7) return 7 else return foo2(x) } 642 | def bar4(x: Int): Int @cps[Int] = { foo2(x); if (x == 7) return 7 else foo2(x) } 643 | def bar5(x: Int): Int @cps[Int] = { foo2(x); if (x == 7) return 7 else 8 } 644 | 645 | @Test def t5314 = { 646 | out.clear() 647 | 648 | printOut(reset { bar2(10) }) 649 | printOut(reset { bar3(10) }) 650 | printOut(reset { bar4(10) }) 651 | printOut(reset { bar5(10) }) 652 | 653 | /* original test case */ 654 | val repro = new ReturnRepro3 655 | repro.caller 656 | repro.caller2 657 | 658 | reset { 659 | val res = bar(8) 660 | printOut(res) 661 | res 662 | } 663 | 664 | assertEquals("77788hi88", out.toString) 665 | } 666 | 667 | } 668 | 669 | class HigherOrder { 670 | 671 | import java.util.concurrent.atomic._ 672 | 673 | @Test def t5472 = { 674 | val map = Map("foo" -> 1, "bar" -> 2) 675 | reset { 676 | val mapped = new ContinuationizedParallelIterable(map) 677 | .map(strIntTpl => shiftUnit0[Int, Unit](23)) 678 | assertEquals(List(23, 23), mapped.toList) 679 | } 680 | } 681 | 682 | @deprecated("Suppress warnings", since = "2.11") 683 | final class ContinuationizedParallelIterable[+A](protected val underline: Iterable[A]) { 684 | def toList = underline.toList.sortBy(_.toString) 685 | 686 | final def filter(p: A => Boolean @suspendable): ContinuationizedParallelIterable[A] @suspendable = 687 | shift( 688 | new AtomicInteger(1) with ((ContinuationizedParallelIterable[A] => Unit) => Unit) { 689 | private val results = new AtomicReference[List[A]](Nil) 690 | 691 | @tailrec 692 | private def add(element: A): Unit = { 693 | val old = results.get 694 | if (!results.compareAndSet(old, element :: old)) { 695 | add(element) 696 | } 697 | } 698 | 699 | override final def apply(continue: ContinuationizedParallelIterable[A] => Unit): Unit = { 700 | for (element <- underline) { 701 | super.incrementAndGet() 702 | reset { 703 | val pass = p(element) 704 | if (pass) { 705 | add(element) 706 | } 707 | if (super.decrementAndGet() == 0) { 708 | continue(new 
ContinuationizedParallelIterable(results.get)) 709 | } 710 | } 711 | } 712 | if (super.decrementAndGet() == 0) { 713 | continue(new ContinuationizedParallelIterable(results.get)) 714 | } 715 | } 716 | }) 717 | 718 | final def foreach[U](f: A => U @suspendable): Unit @suspendable = 719 | shift( 720 | new AtomicInteger(1) with ((Unit => Unit) => Unit) { 721 | override final def apply(continue: Unit => Unit): Unit = { 722 | for (element <- underline) { 723 | super.incrementAndGet() 724 | reset { 725 | f(element) 726 | if (super.decrementAndGet() == 0) { 727 | continue(()) 728 | } 729 | } 730 | } 731 | if (super.decrementAndGet() == 0) { 732 | continue(()) 733 | } 734 | } 735 | }) 736 | 737 | final def map[B: Manifest](f: A => B @suspendable): ContinuationizedParallelIterable[B] @suspendable = 738 | shift( 739 | new AtomicInteger(underline.size) with ((ContinuationizedParallelIterable[B] => Unit) => Unit) { 740 | override final def apply(continue: ContinuationizedParallelIterable[B] => Unit): Unit = { 741 | val results = new Array[B](super.get) 742 | for ((element, i) <- underline.view.zipWithIndex) { 743 | reset { 744 | val result = f(element) 745 | results(i) = result 746 | if (super.decrementAndGet() == 0) { 747 | continue(new ContinuationizedParallelIterable(results)) 748 | } 749 | } 750 | } 751 | } 752 | }) 753 | } 754 | 755 | def g: List[Int] @suspendable = List(1, 2, 3) 756 | 757 | def fp10: List[Int] @suspendable = { 758 | g.map(x => x) 759 | } 760 | 761 | def fp11: List[Int] @suspendable = { 762 | val z = g.map(x => x) 763 | z 764 | } 765 | 766 | def fp12: List[Int] @suspendable = { 767 | val z = List(1, 2, 3) 768 | z.map(x => x) 769 | } 770 | 771 | def fp20: List[Int] @suspendable = { 772 | g.map[Int, List[Int]](x => x) 773 | } 774 | 775 | def fp21: List[Int] @suspendable = { 776 | val z = g.map[Int, List[Int]](x => x) 777 | z 778 | } 779 | 780 | def fp22: List[Int] @suspendable = { 781 | val z = g.map[Int, List[Int]](x => x)(List.canBuildFrom[Int]) 782 | z 
783 | } 784 | 785 | /* Maps the enclosing scope's `g` with the identity function, passing List.canBuildFrom[Int] explicitly, binds the result to a local val and returns it. */ def fp23: List[Int] @suspendable = { 786 | val z = g.map(x => x)(List.canBuildFrom[Int]) 787 | z 788 | } 789 | 790 | /* Within one reset block, asserts that fp10, fp11, fp12, fp20, fp21, fp22 and fp23 each evaluate to List(1, 2, 3). */ @Test def t5506 = { 791 | reset { 792 | assertEquals(List(1, 2, 3), fp10) 793 | assertEquals(List(1, 2, 3), fp11) 794 | assertEquals(List(1, 2, 3), fp12) 795 | assertEquals(List(1, 2, 3), fp20) 796 | assertEquals(List(1, 2, 3), fp21) 797 | assertEquals(List(1, 2, 3), fp22) 798 | assertEquals(List(1, 2, 3), fp23) 799 | } 800 | } 801 | /* Empty class; in the code below it appears only as the type of the implicit executor parameters of Future.sequence and Future.flow. */ class ExecutionContext 802 | 803 | /* Implicit provider; being a def, it constructs a new ExecutionContext each time it is resolved. */ implicit def defaultExecutionContext: ExecutionContext = new ExecutionContext 804 | 805 | /* Test-local Future: a case class wrapping a value x. map applies f to x right away and wraps the result; flatMap returns f(x) unchanged. */ case class Future[+T](x: T) { 806 | final def map[A](f: T => A): Future[A] = new Future[A](f(x)) 807 | final def flatMap[A](f: T => Future[A]): Future[A] = f(x) 808 | } 809 | 810 | /* Mutable stream backed by the list xs. += and ++= append to xs and return this; the CPS methods << and <<< (answer type Future[Any]) append the given element(s) and pass this stream to the captured continuation. */ class PromiseStream[A] { 811 | override def toString = xs.toString 812 | 813 | var xs: List[A] = Nil 814 | 815 | final def +=(elem: A): this.type = { xs :+= elem; this } 816 | 817 | final def ++=(elem: Traversable[A]): this.type = { xs ++= elem; this } 818 | 819 | final def <<(elem: Future[A]): PromiseStream[A] @cps[Future[Any]] = 820 | shift { cont: (PromiseStream[A] => Future[Any]) => elem map (a => cont(this += a)) } 821 | 822 | /* Varargs form: inside Future.flow, chains << for elem1 and elem2 and <<< for Future.sequence of the remaining elems, then maps the resulting Future with cont. */ final def <<(elem1: Future[A], elem2: Future[A], elems: Future[A]*): PromiseStream[A] @cps[Future[Any]] = 823 | shift { cont: (PromiseStream[A] => Future[Any]) => Future.flow(this << elem1 << elem2 <<< Future.sequence(elems.toSeq)) map cont } 824 | 825 | final def <<<(elems: Traversable[A]): PromiseStream[A] @cps[Future[Any]] = 826 | shift { cont: (PromiseStream[A] => Future[Any]) => cont(this ++= elems) } 827 | 828 | final def <<<(elems: Future[Traversable[A]]): PromiseStream[A] @cps[Future[Any]] = 829 | shift { cont: (PromiseStream[A] => Future[Any]) => elems map (as => cont(this ++= as)) } 830 | } 831 | 832 | /* Companion object holding the sequence and flow helpers used by PromiseStream and by the t5538 test. */ object Future { 833 | 834 | /* Returns a Future wrapping the collection built from the x value of every future in `in`; both `in` and cbf are cast with asInstanceOf to Traversable-based types before map is called, and executor is not referenced in the body. */ def sequence[A, M[_] <: Traversable[_]](in: M[Future[A]])(implicit cbf: CanBuildFrom[M[Future[A]], A, M[A]], executor:
ExecutionContext): Future[M[A]] = 835 | new Future(in.asInstanceOf[Traversable[Future[A]]].map((f: Future[A]) => f.x)(cbf.asInstanceOf[CanBuildFrom[Traversable[Future[A]], A, M[A]]])) 836 | 837 | /* Evaluates the CPS body as Future(body) inside reset and casts the result to Future[A]; executor is not referenced in the body. */ def flow[A](body: => A @cps[Future[Any]])(implicit executor: ExecutionContext): Future[A] = reset(Future(body)).asInstanceOf[Future[A]] 838 | 839 | } 840 | 841 | /* Feeds five futures through the varargs << inside Future.flow and compares the toString of the result with that of List(1, 2, 3, 4, 5) wrapped in five nested Futures. */ @Test def t5538 = { 842 | val p = new PromiseStream[Int] 843 | assertEquals(Future(Future(Future(Future(Future(List(1, 2, 3, 4, 5)))))).toString, Future.flow(p << (Future(1), Future(2), Future(3), Future(4), Future(5))).toString) 844 | } 845 | } 846 | 847 | /* Tests that combine try/catch with shift and shiftUnit0, placing the CPS call either in the try body or in the handler. */ class TryCatch { 848 | 849 | /* shift (resuming the continuation with 7) in the try body, plain 9 in the handler. */ def foo = try { 850 | shift((k: Int => Int) => k(7)) 851 | } catch { 852 | case ex: Throwable => 853 | 9 854 | } 855 | 856 | /* Plain 7 in the try body, shiftUnit0 of 9 in the handler. */ def bar = try { 857 | 7 858 | } catch { 859 | case ex: Throwable => 860 | shiftUnit0[Int, Int](9) 861 | } 862 | 863 | /* Expects 10 (7 + 3) from both foo and bar under reset. */ @Test def trycatch0 = { 864 | assertEquals(10, reset { foo + 3 }) 865 | assertEquals(10, reset { bar + 3 }) 866 | } 867 | 868 | /* Always throws a new Exception. */ def fatal: Int = throw new Exception() 869 | 870 | /* Like foo, but fatal is called before the shift in the try body. */ def foo1 = try { 871 | fatal 872 | shift((k: Int => Int) => k(7)) 873 | } catch { 874 | case ex: Throwable => 875 | 9 876 | } 877 | 878 | /* Like foo, but fatal is called after the shift in the try body. */ def foo2 = try { 879 | shift((k: Int => Int) => k(7)) 880 | fatal 881 | } catch { 882 | case ex: Throwable => 883 | 9 884 | } 885 | 886 | /* Like bar, but fatal is called before the plain 7 in the try body. */ def bar1 = try { 887 | fatal 888 | 7 889 | } catch { 890 | case ex: Throwable => 891 | shiftUnit0[Int, Int](9) // regular shift causes no-symbol doesn't have owner 892 | } 893 | 894 | /* Like bar, but fatal is called after the plain 7 in the try body. */ def bar2 = try { 895 | 7 896 | fatal 897 | } catch { 898 | case ex: Throwable => 899 | shiftUnit0[Int, Int](9) // regular shift causes no-symbol doesn't have owner 900 | } 901 | 902 | /* Expects 12 (handler value 9 + 3) from foo1, foo2, bar1 and bar2 under reset. */ @Test def trycatch1 = { 903 | assertEquals(12, reset { foo1 + 3 }) 904 | assertEquals(12, reset { foo2 + 3 }) 905 | assertEquals(12, reset { bar1 + 3 }) 906 | assertEquals(12, reset { bar2 + 3 }) 907 | } 908 | 909 | /* reflect shifts with acquireFor as the shift function; acquireFor applies f to null cast to R and, through catching(...) either, returns the outcome as an Either[Throwable, B]. NOTE(review): Exception._ presumably refers to scala.util.control.Exception; the import is not visible here, confirm. */ trait AbstractResource[+R <: AnyRef] { 910 | def reflect[B]: R
@cpsParam[B, Either[Throwable, B]] = shift(acquireFor) 911 | def acquireFor[B](f: R => B): Either[Throwable, B] = { 912 | import Exception._ 913 | catching(List(classOf[Throwable]): _*) either (f(null.asInstanceOf[R])) 914 | } 915 | } 916 | 917 | /* Calls acquireFor directly on an anonymous AbstractResource[String] with a function that ignores its argument and returns 7; expects Right(7). */ @Test def t3199 = { 918 | val x = new AbstractResource[String] {} 919 | val result = x.acquireFor(x => 7) 920 | assertEquals(Right(7), result) 921 | } 922 | 923 | } 924 | 925 | /* Holds foo for the T3233 test class. foo's try body throws before reaching shiftUnit0; the handler defines a local function value g and yields 9. */ object AvoidClassT3233 { 926 | def foo(x: Int) = { 927 | try { 928 | throw new Exception 929 | shiftUnit0[Int, Int](7) 930 | } catch { 931 | case ex: Throwable => 932 | val g = (a: Int) => a 933 | 9 934 | } 935 | } 936 | } 937 | 938 | /* Imports foo from AvoidClassT3233 and expects reset(foo(0)) to be 9. */ class T3233 { 939 | // work around scalac bug: Trying to access the this of another class: tree.symbol = class TryCatch, ctx.clazz.symbol = <$anon: Function1> 940 | import AvoidClassT3233._ 941 | @Test def t3223 = { 942 | assertEquals(9, reset(foo(0))) 943 | } 944 | } 945 | 946 | /* Tests for while loops whose bodies call CPS methods. out collects the text passed to printOut so the tests can assert on the order of output. */ class While { 947 | val out = new StringBuilder; def printOut(x: Any): Unit = out ++= x.toString 948 | 949 | /* CPS-typed method that returns 2 without using shift. */ def foo0(): Int @cps[Unit] = 2 950 | 951 | /* Adds foo0() to x while x is below 9000, then asserts that x equals 9000. */ def test0(): Unit @cps[Unit] = { 952 | var x = 0 953 | while (x < 9000) { // pick number large enough to require tail-call opt 954 | x += foo0() 955 | } 956 | assertEquals(9000, x) 957 | } 958 | 959 | /* Runs test0 under reset. */ @Test def while0 = { 960 | reset(test0()) 961 | } 962 | 963 | /* shift that prints "up", invokes the continuation with 2, then prints "down". */ def foo3(): Int @cps[Unit] = shift { k => printOut("up"); k(2); printOut("down") } 964 | 965 | /* Adds foo3() to x while x is below 9, then prints x. */ def test3(): Unit @cps[Unit] = { 966 | var x = 0 967 | while (x < 9) { 968 | x += foo3() 969 | } 970 | printOut(x) 971 | } 972 | 973 | /* Clears out, runs test3 under reset and expects five "up", then 10, then five "down". */ @Test def while1 = { 974 | out.clear 975 | reset(test3()) 976 | assertEquals("upupupupup10downdowndowndowndown", out.toString) 977 | } 978 | 979 | /* foo1 returns 2 without shift (same body as foo0); foo2 has the same shift body as foo3. */ def foo1(): Int @cps[Unit] = 2 980 | def foo2(): Int @cps[Unit] = shift { k => printOut("up"); k(2); printOut("down") } 981 | 982 | /* Loops while x is below 9000, adding foo2() when x is a multiple of 1000 and foo1() otherwise, then prints x. */ def test2(): Unit @cps[Unit] = { 983 | var x = 0 984 | while (x < 9000) { // pick number large enough to require tail-call opt 985 | x
+= (if (x % 1000 != 0) foo1() else foo2()) 986 | } 987 | printOut(x) 988 | } 989 | 990 | /* Clears out, runs test2 under reset and expects nine "up", then 9000, then nine "down". */ @Test def while2 = { 991 | out.clear 992 | reset(test2()) 993 | assertEquals("upupupupupupupupup9000downdowndowndowndowndowndowndowndown", out.toString) 994 | } 995 | } --------------------------------------------------------------------------------