├── project
├── build.properties
├── plugins.sbt
└── release_on_tag.sh
├── plugin
└── src
│ └── main
│ ├── resources
│ └── scalac-plugin.xml
│ ├── scala
│ └── scala
│ │ └── tools
│ │ └── selectivecps
│ │ ├── SelectiveCPSPlugin.scala
│ │ ├── CPSUtils.scala
│ │ ├── SelectiveCPSTransform.scala
│ │ └── CPSAnnotationChecker.scala
│ ├── scala-2.11
│ └── scala
│ │ └── tools
│ │ └── selectivecps
│ │ └── SelectiveANFTransform.scala
│ └── scala-2.12
│ └── scala
│ └── tools
│ └── selectivecps
│ └── SelectiveANFTransform.scala
├── CODE_OF_CONDUCT.md
├── README.md
├── NOTICE
├── .gitignore
├── .travis.yml
├── library
└── src
│ ├── test
│ └── scala
│ │ └── scala
│ │ └── tools
│ │ └── selectivecps
│ │ ├── ShouldCompile.scala
│ │ ├── CompilerErrors.scala
│ │ └── TestSuite.scala
│ └── main
│ └── scala
│ └── scala
│ └── util
│ └── continuations
│ ├── package.scala
│ └── ControlContext.scala
└── LICENSE
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.3.12
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("org.scala-lang.modules" % "sbt-scala-module" % "2.2.0")
2 |
--------------------------------------------------------------------------------
/plugin/src/main/resources/scalac-plugin.xml:
--------------------------------------------------------------------------------
1 | <plugin>
2 | <name>continuations</name>
3 | <classname>scala.tools.selectivecps.SelectiveCPSPlugin</classname>
4 | </plugin>
5 |
6 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | All repositories in these organizations:
2 |
3 | * [scala](https://github.com/scala)
4 | * [scalacenter](https://github.com/scalacenter)
5 | * [lampepfl](https://github.com/lampepfl)
6 |
7 | are covered by the Scala Code of Conduct: https://scala-lang.org/conduct/
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | scala-continuations is a compiler plugin and library for Scala providing support for CPS (continuation-passing style) transformations.
2 |
3 | It is no longer maintained. Past releases (for Scala 2.12 and earlier) remain available on Maven Central.
4 |
5 | You might also be interested in https://github.com/scala/scala-async, which covers what was once the most common use case for the continuations plugin.
6 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Scala continuations
2 | Copyright (c) 2010-2020 EPFL
3 | Copyright (c) 2011-2020 Lightbend, Inc.
4 |
5 | Scala includes software developed at
6 | LAMP/EPFL (https://lamp.epfl.ch/) and
7 | Lightbend, Inc. (https://www.lightbend.com/).
8 |
9 | Licensed under the Apache License, Version 2.0 (the "License").
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | #
2 | # Are you tempted to edit this file?
3 | #
4 | # First consider if the changes make sense for all,
5 | # or if they are specific to your workflow/system.
6 | # If it is the latter, you can augment this list with
7 | # entries in .git/info/exclude
8 | #
9 |
10 | *.jar
11 | *~
12 |
13 | # eclipse, intellij
14 | /.classpath
15 | /.project
16 | /src/intellij/*.iml
17 | /src/intellij/*.ipr
18 | /src/intellij/*.iws
19 | /.cache
20 | /.idea
21 | /.settings
22 |
23 | # bak files produced by ./cleanup-commit
24 | *.bak
25 |
26 | # Mac specific, but that is common enough a dev platform to warrant inclusion.
27 | .DS_Store
28 |
29 | target/
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | version: ~> 1.0 # needed for imports
2 |
3 | import: scala/scala-dev:travis/default.yml
4 |
5 | language: scala
6 |
7 | scala:
8 | - 2.11.12
9 | - 2.12.8
10 | - 2.12.9
11 | - 2.12.10
12 |
13 | env:
14 | global:
15 | # The plugin needs to be released for every Scala minor version (e.g. 2.12.8, 2.12.9), but the library
16 | # is only released once per major version (2.11, 2.12). The library is only published from these Scala versions:
17 | - LIBRARY_PUBLISH_SCALA_VERSIONS="2.11.12 2.12.10"
18 | matrix:
19 | - ADOPTOPENJDK=8
20 | - ADOPTOPENJDK=11
21 |
22 | install:
23 | - git fetch --tags # get all tags for sbt-dynver
24 |
25 | script: ./build.sh
26 |
27 | notifications:
28 | email:
29 | - adriaan.moors@lightbend.com
30 | - seth.tisue@lightbend.com
31 |
--------------------------------------------------------------------------------
/project/release_on_tag.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# If the current commit has a tag named like v(\d+\.\d+\.\d+.*),
# and we're running on the right JDK/branch/Scala version,
# echo the sbt commands that publish a release with the version derived from the tag.
#
# NOTE(review): publishScalaVersion (2.12.0) is not among the Scala versions listed in
# .travis.yml, and .travis.yml selects JDKs via ADOPTOPENJDK rather than jdk_switcher --
# confirm this script is still wired into the build.
publishJdk=openjdk8
publishBranch=master
publishScalaVersion=2.12.0

unset tag version

# Exit without error when no (annotated) tag points at HEAD.
# The `exit` must be outside the command substitution: inside `$(...)` it would only
# terminate the subshell and the script would carry on with an empty tag.
tag=$(git describe HEAD --exact-match 2>/dev/null) || exit 0

# Accept only tags of the form v<major>.<minor>.<patch>[suffix]; any other tag yields an
# empty version instead of being passed through verbatim as the "version".
version=$(printf '%s\n' "$tag" | perl -ne 'print $1 if /^v(\d+\.\d+\.\d+.*)$/')

if [[ "$version" != "" &&\
      "${TRAVIS_PULL_REQUEST}" == "false" &&\
      "${JAVA_HOME}" == "$(jdk_switcher home $publishJdk)" &&\
      "${TRAVIS_BRANCH}" == "${publishBranch}" &&\
      "${TRAVIS_SCALA_VERSION}" == "${publishScalaVersion}" ]]; then
  # Quoted so the caller can pass the sbt command line through verbatim.
  echo \'"set every version := $version"\' publish-signed
fi
24 |
--------------------------------------------------------------------------------
/library/src/test/scala/scala/tools/selectivecps/ShouldCompile.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | import scala.collection.mutable.HashMap
14 | import scala.util.continuations._
15 |
16 | // https://issues.scala-lang.org/browse/SI-3620
/** Regression test for SI-3620: the continuations plugin must accept calling a
 *  method (`getOrElse`) directly on the CPS-annotated result of `store.option(p)`
 *  inside a `reset` block.
 *
 *  The exact shape of `test` is what exercises the bug (binding the result to a
 *  `val` first was the workaround), so do not "simplify" it.
 */
object t3620 extends App {

  /** A minimal key/value store whose lookups can suspend (via `shift`) until a
   *  value is added. Callbacks for keys that are not yet present are kept in
   *  `waiting` and fired by `add`.
   */
  class Store[K,V] {

    /** A pending callback for `key`, invoked with the value once it is added. */
    trait Waiting {
      def key: K
      def inform(value: V): Unit
    }

    private val map = new HashMap[K, V]
    private var waiting: List[Waiting] = Nil

    /** Calls `f` immediately if `k` is already present; otherwise registers a
     *  `Waiting` entry so that a later `add(k, _)` calls it.
     */
    def waitFor(k: K, f: (V => Unit)): Unit = {
      map.get(k) match {
        case Some(v) => f(v)
        case None => {
          val w = new Waiting {
            def key = k
            def inform(v: V) = f(v)
          }
          waiting = w :: waiting
        }
      }
    }


    /** Stores `value` under `key`, then fires and drops every callback waiting on `key`. */
    def add(key: K, value: V): Unit = {
      map(key) = value
      // p._1: callbacks waiting on this key; p._2: everyone else (kept waiting)
      val p = waiting.partition(_.key == key)
      waiting = p._2
      p._1.foreach(_.inform(value))
    }

    /** Captures the rest of the enclosing `reset` block and resumes it with the
     *  value for `key` once that value is available.
     */
    def required(key: K) = {
      shift {
        c: (V => Unit) => {
          waitFor(key, c)
        }
      }
    }

    /** Like `required`, but for an optional key: resumes with `None` right away
     *  when `key` is `None`, otherwise with `Some(value)` once it is available.
     */
    def option(key: Option[K]) = {
      shift {
        c: (Option[V] => Unit) => {
          // the pattern variable `key` shadows the method parameter of the same name
          key match {
            case Some(key) => waitFor(key, (v: V) => c(Some(v)))
            case None => c(None)
          }

        }
      }
    }

  }

  val store = new Store[String, Int]

  /** The SI-3620 reproduction: `getOrElse` applied directly to a CPS expression. */
  def test(p: Option[String]): Unit = {
    reset {
      // uncommenting the following two lines makes the compiler happy!
      // val o = store.option(p)
      // println(o)
      val i = store.option(p).getOrElse(1)
      println(i)
    }
  }

  // NOTE(review): nothing in this object ever calls `store.add("a", _)`, so the
  // continuation is only registered and never resumed; nothing is printed.
  test(Some("a"))

}
87 |
--------------------------------------------------------------------------------
/plugin/src/main/scala/scala/tools/selectivecps/SelectiveCPSPlugin.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.tools.selectivecps
14 |
15 | import scala.tools.nsc
16 | import nsc.Global
17 | import nsc.plugins.Plugin
18 | import nsc.plugins.PluginComponent
19 |
/** Entry point of the `continuations` scalac plugin.
 *
 *  Contributes two compiler phases: the selective ANF transform (running after
 *  `pickler`) and the selective CPS transform (running after `selectiveanf` and
 *  before `uncurry`). It also installs the CPS annotation checker and analyzer
 *  plugin on `global`. Both phases are enabled only when the plugin option
 *  `enable` is given (`-P:continuations:enable`).
 */
class SelectiveCPSPlugin(val global: Global) extends Plugin {
  val name = "continuations"
  val description = "applies selective cps conversion"

  /** True when the plugin option `enable` was supplied. */
  val pluginEnabled = options contains "enable"

  // Early-initializer syntax (`new { ... } with T`): the vals in the first braces
  // are initialized before the body of the mixed-in transform trait runs, so the
  // trait's own initialization can already see `global` and `cpsEnabled`.
  val anfPhase = new {
    val global = SelectiveCPSPlugin.this.global
    val cpsEnabled = pluginEnabled
    override val enabled = cpsEnabled
  } with SelectiveANFTransform {
    val runsAfter = List("pickler")
  }

  val cpsPhase = new {
    val global = SelectiveCPSPlugin.this.global
    val cpsEnabled = pluginEnabled
    override val enabled = cpsEnabled
  } with SelectiveCPSTransform {
    // NOTE(review): "selectiveanf" is presumably the phase name declared by
    // SelectiveANFTransform (not visible here) -- keep the two in sync.
    val runsAfter = List("selectiveanf")
    override val runsBefore = List("uncurry")
  }

  val components = List[PluginComponent](anfPhase, cpsPhase)

  // `global` is given its singleton type so that the checker's path-dependent
  // types line up with those of `SelectiveCPSPlugin.this.global`.
  val checker = new {
    val global: SelectiveCPSPlugin.this.global.type = SelectiveCPSPlugin.this.global
    val cpsEnabled = pluginEnabled
  } with CPSAnnotationChecker

  // Installed unconditionally, even when `pluginEnabled` is false.
  // TODO don't muck up global with unused checkers
  global.addAnnotationChecker(checker.checker)
  global.analyzer.addAnalyzerPlugin(checker.plugin)

  global.log("instantiated cps plugin: " + this)

  /** Validates the plugin options: `enable` is the only accepted option (it is
   *  already consumed by `pluginEnabled` above), and anything else is reported
   *  through `error`.
   *  @return `pluginEnabled`; scalac keeps the plugin only if `init` returns true.
   */
  override def init(options: List[String], error: String => Unit) = {
    options foreach {
      case "enable" => // in initializer
      case arg => error(s"Bad argument: $arg")
    }
    pluginEnabled
  }

  override val optionsHelp: Option[String] =
    Some(" -P:continuations:enable Enable continuations")
}
67 |
--------------------------------------------------------------------------------
/plugin/src/main/scala/scala/tools/selectivecps/CPSUtils.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.tools.selectivecps
14 |
15 | import scala.tools.nsc.Global
16 |
/** Helpers shared by the continuations plugin's phases and annotation checker.
 *
 *  Provides:
 *   - lookups of the `scala.util.continuations` library symbols;
 *   - construction and inspection of the CPS marker annotations;
 *   - composition ("linearization") of answer types.
 */
trait CPSUtils {
  val global: Global
  import global._

  /** Whether the CPS transformation is enabled; supplied by the concrete component. */
  val cpsEnabled: Boolean
  // Debug tracing, switched on with the JVM system property `-DcpsVerbose=true`.
  val verbose: Boolean = System.getProperty("cpsVerbose", "false") == "true"
  /** Prints `x` (evaluated only if needed) when `verbose` is set. */
  def vprintln(x: =>Any): Unit = if (verbose) println(x)

  /** Term names used by the transforms. In the interpolated strings, `$$`
   *  yields a literal `$`, so e.g. `catches` is the name `$catches`.
   */
  object cpsNames {
    val catches = newTermName(s"$$catches")
    val ex = newTermName(s"$$ex")
    val flatMapCatch = newTermName("flatMapCatch")
    val getTrivialValue = newTermName("getTrivialValue")
    val isTrivial = newTermName("isTrivial")
    val reify = newTermName("reify")
    val reifyR = newTermName("reifyR")
    val shift = newTermName("shift")
    val shiftR = newTermName("shiftR")
    val shiftSuffix = newTermName(s"$$shift")
    val shiftUnit0 = newTermName("shiftUnit0")
    val shiftUnit = newTermName("shiftUnit")
    val shiftUnitR = newTermName("shiftUnitR")
  }

  // Annotation classes from the continuations library (resolved lazily, on first use).
  lazy val MarkerCPSSym = rootMirror.getRequiredClass("scala.util.continuations.cpsSym")
  lazy val MarkerCPSTypes = rootMirror.getRequiredClass("scala.util.continuations.cpsParam")
  lazy val MarkerCPSSynth = rootMirror.getRequiredClass("scala.util.continuations.cpsSynth")
  lazy val MarkerCPSAdaptPlus = rootMirror.getRequiredClass("scala.util.continuations.cpsPlus")
  lazy val MarkerCPSAdaptMinus = rootMirror.getRequiredClass("scala.util.continuations.cpsMinus")

  lazy val Context = rootMirror.getRequiredClass("scala.util.continuations.ControlContext")
  // The `scala.util.continuations` package; its package object defines shift/reset/reify etc.
  lazy val ModCPS = rootMirror.getPackage(TermName("scala.util.continuations"))

  // Library methods looked up on that package. The `...R` variants are the ones that
  // take/return `ControlContext` directly.
  lazy val MethShiftUnit = definitions.getMember(ModCPS, cpsNames.shiftUnit)
  lazy val MethShiftUnit0 = definitions.getMember(ModCPS, cpsNames.shiftUnit0)
  lazy val MethShiftUnitR = definitions.getMember(ModCPS, cpsNames.shiftUnitR)
  lazy val MethShift = definitions.getMember(ModCPS, cpsNames.shift)
  lazy val MethShiftR = definitions.getMember(ModCPS, cpsNames.shiftR)
  lazy val MethReify = definitions.getMember(ModCPS, cpsNames.reify)
  lazy val MethReifyR = definitions.getMember(ModCPS, cpsNames.reifyR)

  /** Every CPS annotation class; all of them are stripped by `removeAllCPSAnnotations`. */
  lazy val allCPSAnnotations = List(MarkerCPSSym, MarkerCPSTypes, MarkerCPSSynth,
    MarkerCPSAdaptPlus, MarkerCPSAdaptMinus)

  // Factories for argument-less marker annotations.
  // TODO - needed? Can these all use the same annotation info?
  protected def newSynthMarker() = newMarker(MarkerCPSSynth)
  protected def newPlusMarker() = newMarker(MarkerCPSAdaptPlus)
  protected def newMinusMarker() = newMarker(MarkerCPSAdaptMinus)
  protected def newMarker(tpe: Type): AnnotationInfo = AnnotationInfo marker tpe
  protected def newMarker(sym: Symbol): AnnotationInfo = AnnotationInfo marker sym.tpe

  /** Builds the annotation `@cpsParam[tp1, tp2]`. */
  protected def newCpsParamsMarker(tp1: Type, tp2: Type) =
    newMarker(appliedType(MarkerCPSTypes, tp1, tp2))

  // annotation checker

  /** The two type arguments of a `@cpsParam[B, C]` annotation, as `(B, C)`.
   *  A MatchError results if the annotation type does not have exactly two type arguments.
   */
  protected def annTypes(ann: AnnotationInfo): (Type, Type) = {
    val tp0 :: tp1 :: Nil = ann.atp.dealiasWiden.typeArgs
    ((tp0, tp1))
  }
  protected def hasMinusMarker(tpe: Type) = tpe hasAnnotation MarkerCPSAdaptMinus
  protected def hasPlusMarker(tpe: Type) = tpe hasAnnotation MarkerCPSAdaptPlus
  protected def hasSynthMarker(tpe: Type) = tpe hasAnnotation MarkerCPSSynth
  protected def hasCpsParamTypes(tpe: Type) = tpe hasAnnotation MarkerCPSTypes
  /** The `(B, C)` arguments of the `@cpsParam` annotation on `tpe`, if it has one. */
  protected def cpsParamTypes(tpe: Type) = tpe getAnnotation MarkerCPSTypes map annTypes

  /** The annotations on `tpe` that match class `cls`. */
  def filterAttribs(tpe:Type, cls:Symbol) =
    tpe.annotations filter (_ matches cls)

  /** `tpe` without the annotations matching any of `classes`. */
  def removeAttribs(tpe: Type, classes: Symbol*) =
    tpe filterAnnotations (ann => !(classes exists (ann matches _)))

  def removeAllCPSAnnotations(tpe: Type) = removeAttribs(tpe, allCPSAnnotations:_*)

  def cpsParamAnnotation(tpe: Type) = filterAttribs(tpe, MarkerCPSTypes)

  /** Folds a list of `@cpsParam` annotations left to right.
   *  Combining `[u0, v0]` (accumulated so far) with the next `[u1, v1]` gives
   *  `[u1, v0]`; this requires `v1 <:< u0`.
   *  @throws TypeError "illegal answer type modification" when that check fails.
   *  NOTE(review): `reduceLeft` also throws on an empty list; callers are assumed
   *  to pass at least one annotation.
   */
  def linearize(ann: List[AnnotationInfo]): AnnotationInfo = {
    ann reduceLeft { (a, b) =>
      val (u0,v0) = annTypes(a)
      val (u1,v1) = annTypes(b)
      // vprintln("check lin " + a + " andThen " + b)

      if (v1 <:< u0)
        newCpsParamsMarker(u1, v0)
      else
        throw new TypeError("illegal answer type modification: " + a + " andThen " + b)
    }
  }

  // anf transform

  /** The `@cpsParam` answer types of `tp`. If there are none but `tp` carries the
   *  `@cpsPlus` marker, a warning is issued; the result is None in both cases
   *  where `@cpsParam` is absent.
   */
  def getExternalAnswerTypeAnn(tp: Type) = {
    cpsParamTypes(tp) orElse {
      if (hasPlusMarker(tp))
        global.warning("trying to instantiate type " + tp + " to unknown cps type")
      None
    }
  }

  /** The `@cpsParam` answer types of `tp`, treated as absent when `tp` also has `@cpsPlus`. */
  def getAnswerTypeAnn(tp: Type): Option[(Type, Type)] =
    cpsParamTypes(tp) filterNot (_ => hasPlusMarker(tp))

  /** Whether `getAnswerTypeAnn(tp)` would be defined. */
  def hasAnswerTypeAnn(tp: Type) =
    hasCpsParamTypes(tp) && !hasPlusMarker(tp)

  /** If `tree`'s type carries `@cpsSynth`, removes all CPS annotations from that
   *  type and returns the tree; otherwise returns `tree` untouched.
   */
  def updateSynthFlag(tree: Tree) = { // remove annotations if *we* added them (@synth present)
    if (hasSynthMarker(tree.tpe)) {
      log("removing annotation from " + tree)
      tree modifyType removeAllCPSAnnotations
    } else
      tree
  }

  /** Optional `(B, C)` answer types of an expression; None when it is not a CPS expression. */
  type CPSInfo = Option[(Type,Type)]

  /** Composes the CPS info of `a` followed by `b`.
   *  If both are present, requires `v1 <:< u0` and returns `(u1, v0)`; if only one
   *  is present, returns it; otherwise returns None.
   *  When the check fails, an error is reported at `pos` and then a plain
   *  `Exception` is thrown to abort the current transformation.
   */
  def linearize(a: CPSInfo, b: CPSInfo)(implicit pos: Position): CPSInfo = {
    (a,b) match {
      case (Some((u0,v0)), Some((u1,v1))) =>
        vprintln("check lin " + a + " andThen " + b)
        if (!(v1 <:< u0)) {
          reporter.error(pos,"cannot change answer type in composition of cps expressions " +
            "from " + u1 + " to " + v0 + " because " + v1 + " is not a subtype of " + u0 + ".")
          throw new Exception("check lin " + a + " andThen " + b)
        }
        Some((u1,v0))
      case (Some(_), _) => a
      case (_, Some(_)) => b
      case _ => None
    }
  }
}
148 |
--------------------------------------------------------------------------------
/library/src/main/scala/scala/util/continuations/package.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.util
14 |
15 | /* TODO: better documentation of return-type modification.
16 | * (Especially what "illegal answer type modification: ... andThen ..." means.)
17 | */
18 |
19 | /**
20 | * Delimited continuations are a feature for modifying the usual control flow
21 | * of a program. To use continuations, provide the option `-P:continuations:enable`
22 | * to the Scala compiler or REPL to activate the compiler plugin.
23 | *
24 | * Below is an example of using continuations to suspend execution while awaiting
25 | * user input. Similar facilities are used in so-called continuation-based web frameworks.
26 | *
27 | * {{{
28 | * def go =
29 | * reset {
30 | * println("Welcome!")
31 | * val first = ask("Please give me a number")
32 | * val second = ask("Please enter another number")
33 | * printf("The sum of your numbers is: %d\n", first + second)
34 | * }
35 | * }}}
36 | *
37 | * The `reset` is provided by this package and delimits the extent of the
38 | * transformation. The `ask` is a function that will be defined below. Its
39 | * effect is to issue a prompt and then suspend execution awaiting user input.
40 | * Once the user provides an input value, execution of the suspended block
41 | * resumes.
42 | *
43 | * {{{
44 | * val sessions = new HashMap[UUID, Int=>Unit]
45 | * def ask(prompt: String): Int @cps[Unit] =
46 | * shift {
47 | * k: (Int => Unit) => {
48 | * val id = uuidGen
49 | * printf("%s\nrespond with: submit(0x%x, ...)\n", prompt, id)
50 | * sessions += id -> k
51 | * }
52 | * }
53 | * }}}
54 | *
55 | * The type of `ask` includes a `@cps` annotation which drives the transformation.
56 | * The type signature `Int @cps[Unit]` means that `ask` should be used in a
57 | * context requiring an `Int`, but actually it will suspend and return `Unit`.
58 | *
59 | * The computation leading up to the first `ask` is executed normally. The
60 | * remainder of the reset block is wrapped into a closure that is passed as
61 | * the parameter `k` to the `shift` function, which can then decide whether
62 | * and how to execute the continuation. In this example, the continuation is
63 | * stored in a sessions map for later execution. This continuation includes a
64 | * second call to `ask`, which is treated likewise once the execution resumes.
65 | *
66 | * === CPS Annotation ===
67 | *
68 | * The aforementioned `@cps[A]` annotation is an alias for the more general
69 | * `@cpsParam[B,C]` where `B=C`. The type `A @cpsParam[B,C]` describes a term
70 | * which yields a value of type `A` within an evaluation context producing a
71 | * value of type `B`. After the CPS transformation, this return type is
72 | * modified to `C`.
73 | *
74 | * The `@cpsParam` annotations are introduced by `shift` blocks, and propagate
75 | * via the return types to the dynamically enclosing context. The propagation
76 | * stops upon reaching a `reset` block.
77 | */
78 |
package object continuations {

  /** An annotation that denotes a type is part of a continuation context.
   *  `@cps[A]` is shorthand for `cpsParam[A,A]`.
   *  @tparam A The return type of the continuation context.
   */
  type cps[A] = cpsParam[A,A]

  /** An annotation that denotes a type is part of a side effecting continuation context.
   *  `@suspendable` is shorthand notation for `@cpsParam[Unit,Unit]` or `@cps[Unit]`.
   */
  type suspendable = cps[Unit]

  /**
   * The `shift` function captures the remaining computation in a `reset` block
   * and passes it to a closure provided by the user.
   *
   * For example:
   * {{{
   *    reset {
   *       shift { (k: Int => Int) => k(5) } + 1
   *    }
   * }}}
   *
   * In this example, `shift` is used in the expression `shift ... + 1`.
   * The compiler will alter this expression so that the call
   * to `shift` becomes a parameter to a function, creating something like:
   * {{{
   *   { (k: Int => Int) => k(5) } apply { _ + 1 }
   * }}}
   * The result of this expression is 6.
   *
   * There can be more than one `shift` call in a `reset` block. Each call
   * to `shift` can alter the return type of the expression within the reset block,
   * but will not change the return type of the entire `reset { block }`
   * expression.
   *
   * This body only throws: calls to `shift` must be rewritten by the
   * continuations compiler plugin.
   *
   * @param fun A function where
   *      - The parameter is the remainder of computation within the current
   *        `reset` block. This is passed as a function `A => B`.
   *      - The return is the return value of the `ControlContext` which is
   *        generated from this inversion.
   * @note Must be invoked in the context of a call to `reset`. That `reset` need
   *       not be the immediate caller, but one is needed to eventually remove
   *       the `@cps` annotations from types.
   * @throws NoSuchMethodException if this code was compiled without the plugin.
   */
  def shift[A,B,C](fun: (A => B) => C): A @cpsParam[B,C] = {
    throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled")
  }
  /** Creates a context for continuations captured within the argument closure
   *  of this `reset` call and returns the result of the entire transformed
   *  computation. Within an expression of the form `reset { block }`,
   *  the closure expression (`block`) will be modified such that at each
   *  call to `shift` the remainder of the expression is transformed into a
   *  function to be passed into the shift.
   *
   *  If the reified block is trivial (no continuation was captured), its value
   *  is returned directly; otherwise the captured context is run with the
   *  identity continuation.
   *  @return The result of a block of code that uses `shift` to capture continuations.
   */
  def reset[A,C](ctx: =>(A @cpsParam[A,C])): C = {
    val ctxR = reify[A,A,C](ctx)
    if (ctxR.isTrivial)
      ctxR.getTrivialValue.asInstanceOf[C]
    else
      ctxR.foreach((x:A) => x)
  }

  /** `reset` for a block whose answer type is unchanged (`A @cpsParam[A,A]`). */
  def reset0[A](ctx: =>(A @cpsParam[A,A])): A = reset(ctx)

  /** Like `reset`, but the block's own value is discarded: a non-trivial context
   *  is run with the continuation `x => ()`, and the result of type `A` is what
   *  the captured `shift`s produce.
   */
  def run[A](ctx: =>(Any @cpsParam[Unit,A])): A = {
    val ctxR = reify[Any,Unit,A](ctx)
    if (ctxR.isTrivial)
      ctxR.getTrivialValue.asInstanceOf[A]
    else
      ctxR.foreach((x:Any) => ())
  }


  // methods below are primarily implementation details and are not
  // needed frequently in client code

  /** `shiftUnit` with an unchanged answer type (`B` = `C`). */
  def shiftUnit0[A,B](x: A): A @cpsParam[B,B] = {
    shiftUnit[A,B,B](x)
  }

  /** Lifts the plain value `x` into a CPS context.
   *  This body only throws; calls must be rewritten by the continuations plugin.
   */
  def shiftUnit[A,B,C>:B](x: A): A @cpsParam[B,C] = {
    throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled")
  }

  /** This method converts from the sugared `A @cpsParam[B,C]` type to the desugared
   *  `ControlContext[A,B,C]` type. The underlying data is not changed.
   *  This body only throws; calls must be rewritten by the continuations plugin.
   */
  def reify[A,B,C](ctx: =>(A @cpsParam[B,C])): ControlContext[A,B,C] = {
    throw new NoSuchMethodException("this code has to be compiled with the Scala continuations plugin enabled")
  }

  /** Wraps the plain value `x` in a `ControlContext` with a `null` function.
   *  NOTE(review): presumably a `null` function is what `isTrivial` detects; confirm
   *  in ControlContext.
   */
  def shiftUnitR[A,B](x: A): ControlContext[A,B,B] = { // called in code generated by SelectiveCPSTransform
    new ControlContext[A, B, B](null, x)
  }

  /**
   * Captures a computation into a `ControlContext`.
   * The context's function ignores the exception handler `g` and just applies
   * `fun` to the continuation `f`.
   * @param fun The function which accepts the inverted computation and returns
   * a final result.
   * @see shift
   */
  def shiftR[A,B,C](fun: (A => B) => C): ControlContext[A,B,C] = { // called in code generated by SelectiveCPSTransform
    new ControlContext((f:A=>B,g:Exception=>B) => fun(f), null.asInstanceOf[A])
  }

  /** Desugared counterpart of `reify`: evaluates the by-name `ctx` and returns it unchanged. */
  def reifyR[A,B,C](ctx: => ControlContext[A,B,C]): ControlContext[A,B,C] = { // called in code generated by SelectiveCPSTransform
    ctx
  }

}
192 |
--------------------------------------------------------------------------------
/library/src/test/scala/scala/tools/selectivecps/CompilerErrors.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.tools.selectivecps
14 |
15 | import org.junit.Test
16 |
/** Negative tests for the continuations plugin.
 *
 *  Each snippet must fail to compile (with the plugin enabled), reporting an
 *  error that contains the given text. `CompilerTesting.cpsErrorMessages`:
 *   - wraps each snippet in `object Test { ... }` with
 *     `scala.util.continuations._` imported;
 *   - strips the `|` margins of the code itself.
 *  Expected messages spanning several lines call `.stripMargin` explicitly.
 */
class CompilerErrors extends CompilerTesting {
  // Not annotated with @Test, so JUnit does not run this case.
  // @Test -- disabled
  def infer0 =
    expectCPSError("cannot cps-transform expression 8: type arguments [Int(8),String,Int] do not conform to method shiftUnit's type parameter bounds [A,B,C >: B]",
      """|def test(x: => Int @cpsParam[String,Int]) = 7
|
|def main(args: Array[String]) = {
| test(8)
|}""")

  // A CPS-annotated function value must not conform to a plain `() => Int`.
  @Test def function0 =
    expectCPSError("""|type mismatch;
| found : () => Int @scala.util.continuations.cpsParam[Int,Int]
| required: () => Int""".stripMargin,
      """|def main(args: Array[String]): Any = {
| val f = () => shift { k: (Int=>Int) => k(7) }
| val g: () => Int = f
|
| println(reset(g()))
|}""")

  // ...and a plain function value must not conform to a CPS-annotated one.
  @Test def function2 =
    expectCPSError(
      """|type mismatch;
| found : () => Int
| required: () => Int @scala.util.continuations.cpsParam[Int,Int]""".stripMargin,
      """|def main(args: Array[String]): Any = {
| val f = () => 7
| val g: () => Int @cps[Int] = f
|
| println(reset(g()))
|}""")

  @Test def function3 =
    expectCPSError(
      """|type mismatch;
| found : Int @scala.util.continuations.cpsParam[Int,Int]
| required: Int""".stripMargin,
      """|def main(args: Array[String]): Any = {
| val g: () => Int = () => shift { k: (Int=>Int) => k(7) }
|
| println(reset(g()))
|}""")

  // Two `cpsParam[String,Int]` computations in sequence cannot be composed.
  @Test def infer2 =
    expectCPSError("illegal answer type modification: scala.util.continuations.cpsParam[String,Int] andThen scala.util.continuations.cpsParam[String,Int]",
      """|def test(x: => Int @cpsParam[String,Int]) = 7
|
|def sym() = shift { k: (Int => String) => 9 }
|
|
|def main(args: Array[String]): Any = {
| test { sym(); sym() }
|}""")

  @Test def `lazy` =
    expectCPSError("implementation restriction: cps annotations not allowed on lazy value definitions",
      """|def foo() = {
| lazy val x = shift((k:Unit=>Unit)=>k())
| println(x)
|}
|
|def main(args: Array[String]) = {
| reset {
| foo()
| }
|}""")

  @Test def t1929 =
    expectCPSError(
      """|type mismatch;
| found : Int @scala.util.continuations.cpsParam[String,String] @scala.util.continuations.cpsSynth
| required: Int @scala.util.continuations.cpsParam[Int,String]""".stripMargin,
      """|def main(args : Array[String]) {
| reset {
| println("up")
| val x = shift((k:Int=>String) => k(8) + k(2))
| println("down " + x)
| val y = shift((k:Int=>String) => k(3))
| println("down2 " + y)
| y + x
| }
|}""")

  @Test def t2285 =
    expectCPSError(
      """|type mismatch;
| found : Int @scala.util.continuations.cpsParam[String,String] @scala.util.continuations.cpsSynth
| required: Int @scala.util.continuations.cpsParam[Int,String]""".stripMargin,
      """|def bar() = shift { k: (String => String) => k("1") }
|
|def foo() = reset { bar(); 7 }""")

  @Test def t2949 =
    expectCPSError(
      """|type mismatch;
| found : Int
| required: ? @scala.util.continuations.cpsParam[List[?],Any]""".stripMargin,
      """|def reflect[A,B](xs : List[A]) = shift{ xs.flatMap[B, List[B]] }
|def reify[A, B](x : A @cpsParam[List[A], B]) = reset{ List(x) }
|
|def main(args: Array[String]): Unit = println(reify {
| val x = reflect[Int, Int](List(1,2,3))
| val y = reflect[Int, Int](List(2,4,8))
| x * y
|})""")

  @Test def t3718 =
    expectCPSError(
      "cannot cps-transform malformed (possibly in shift/reset placement) expression",
      "scala.util.continuations.reset((_: Any).##)")

  @Test def t5314_missing_result_type =
    expectCPSError(
      "method bar has return statement; needs result type",
      """|def foo(x:Int): Int @cps[Int] = x
|
|def bar(x:Int) = return foo(x)
|
|reset {
| val res = bar(8)
| println(res)
| res
|}""")

  @Test def t5314_npe =
    expectCPSError(
      "method bar has return statement; needs result type",
      "def bar(x:Int) = { return x; x } // NPE")

  @Test def t5314_return_reset =
    expectCPSError(
      "return expression not allowed, since method calls CPS method",
      """|val rnd = new scala.util.Random
|
|def foo(x: Int): Int @cps[Int] = shift { k => k(x) }
|
|def bar(x: Int): Int @cps[Int] = return foo(x)
|
|def caller(): Int = {
| val v: Int = reset {
| val res: Int = bar(8)
| if (rnd.nextInt(100) > 50) return 5 // not allowed, since method is calling `reset`
| 42
| }
| v
|}
|
|caller()""")

  @Test def t5314_type_error =
    expectCPSError(
      """|type mismatch;
| found : Int @scala.util.continuations.cpsParam[Int,Int]
| required: Int @scala.util.continuations.cpsParam[String,String]""".stripMargin,
      """|def foo(x:Int): Int @cps[Int] = shift { k => k(x) }
|
|// should be a type error
|def bar(x:Int): Int @cps[String] = return foo(x)
|
|def caller(): Unit = {
| val v: String = reset {
| val res: Int = bar(8)
| "hello"
| }
|}
|
|caller()""")

  @Test def t5445 =
    expectCPSError(
      "cps annotations not allowed on by-value parameters or value definitions",
      "def foo(block: Unit @suspendable ): Unit @suspendable = {}")

  // foo1 and foo2 have identical bodies; the error is expected exactly once for each.
  @Test def trycatch2 =
    expectCPSErrors(2, "only simple cps types allowed in try/catch blocks (found: Int @scala.util.continuations.cpsParam[String,Int])",
      """|def fatal[T]: T = throw new Exception
|def cpsIntStringInt = shift { k:(Int=>String) => k(3); 7 }
|def cpsIntIntString = shift { k:(Int=>Int) => k(3); "7" }
|
|def foo1 = try {
| fatal[Int]
| cpsIntStringInt
|} catch {
| case ex: Throwable =>
| cpsIntStringInt
|}
|
|def foo2 = try {
| fatal[Int]
| cpsIntStringInt
|} catch {
| case ex: Throwable =>
| cpsIntStringInt
|}
|
|
|def main(args: Array[String]): Unit = {
| println(reset { foo1; "3" })
| println(reset { foo2; "3" })
|}""")
}
219 |
/** Compiles Scala snippets at run time (through a reflection ToolBox) with the
 *  continuations plugin loaded and enabled, and collects the reported errors.
 *
 *  The path of the built plugin jar is read from the system property
 *  `scala-continuations-plugin.jar`.
 */
class CompilerTesting {
  /** System property that must hold the path to the continuations plugin jar. */
  private val PluginJarProperty = "scala-continuations-plugin.jar"

  /** Path to the plugin jar.
   *  Fails with a descriptive AssertionError when the property is unset or the
   *  file does not exist. Previously an unset property surfaced as a bare
   *  `NoSuchElementException: key not found`.
   */
  private def pluginJar: String = {
    val f = sys.props.getOrElse(PluginJarProperty, "")
    assert(f.nonEmpty, s"system property '$PluginJarProperty' is not set; it must point to the continuations plugin jar")
    assert(new java.io.File(f).exists, s"continuations plugin jar not found: $f")
    f
  }

  /** Compiler options that load the plugin and enable it. */
  def loadPlugin = s"-Xplugin:${pluginJar} -P:continuations:enable"

  // note: `code` should have a | margin
  /** Error messages reported when compiling `code` (margin-stripped) inside
   *  `object Test { ... }` with `scala.util.continuations._` imported.
   */
  def cpsErrorMessages(msg: String, code: String) =
    errorMessages(msg, loadPlugin)(s"import scala.util.continuations._\nobject Test {\n${code.stripMargin}\n}")

  /** Asserts that at least one reported error contains `msg`. */
  def expectCPSError(msg: String, code: String) = {
    val errors = cpsErrorMessages(msg, code)
    assert(errors exists (_ contains msg), failureReport(s"no error containing '$msg'", errors))
  }

  /** Asserts that exactly `msgCount` reported errors contain `msg`. */
  def expectCPSErrors(msgCount: Int, msg: String, code: String) = {
    val errors = cpsErrorMessages(msg, code)
    val errorCount = errors.count(_ contains msg)
    assert(errorCount == msgCount, s"$errorCount occurrences of \'$msg\' found -- expected $msgCount in:\n${errors mkString "\n"}")
  }

  /** Builds an assertion message, making an empty error list explicit instead of blank. */
  private def failureReport(problem: String, errors: List[String]): String =
    if (errors.isEmpty) s"$problem: no errors were reported"
    else s"$problem; errors reported:\n${errors mkString "\n"}"

  // TODO: move to scala.tools.reflect.ToolboxFactory
  /** Compiles *and evaluates* `code` with a fresh ToolBox configured with `compileOptions`.
   *  @param errorSnippet unused; kept so existing callers still compile.
   *  @return Nil if evaluation succeeds; if the ToolBox throws a `ToolBoxError`,
   *          the messages of all ERROR-severity diagnostics it collected.
   */
  def errorMessages(errorSnippet: String, compileOptions: String)(code: String): List[String] = {
    import scala.tools.reflect._
    val m = scala.reflect.runtime.currentMirror
    val tb = m.mkToolBox(options = compileOptions) //: ToolBox[m.universe.type]
    val fe = tb.frontEnd

    try {
      tb.eval(tb.parse(code))
      Nil
    } catch {
      case _: ToolBoxError =>
        import fe._
        infos.toList collect { case Info(_, msg, ERROR) => msg }
    }
  }
}
--------------------------------------------------------------------------------
/library/src/main/scala/scala/util/continuations/ControlContext.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.util.continuations
14 |
15 | import scala.annotation.{ Annotation, StaticAnnotation, TypeConstraint }
16 |
17 | /** Type annotation marking a value as part of a delimited continuation
18 |  * context.
19 |  *
20 |  * The type `A @cpsParam[B,C]` is translated to `ControlContext[A,B,C]` at
21 |  * compile time by the continuations compiler plugin.
22 |  *
23 |  * @tparam B The type of the computation state after the computation has
24 |  *           executed, and before control is returned to the shift.
25 |  * @tparam C The eventual return type of this delimited computation.
26 |  * @see scala.util.continuations.ControlContext
27 |  */
28 | class cpsParam[-B,+C] extends StaticAnnotation with TypeConstraint
29 |
30 | // Internal marker annotations: implementation details, not part of the public API.
31 | private class cpsSym[B] extends Annotation {}
32 | private class cpsSynth extends Annotation {}
33 | private class cpsPlus extends StaticAnnotation with TypeConstraint {}
34 | private class cpsMinus extends Annotation {}
35 |
36 |
37 |
38 | /**
39 |  * This class represents a portion of computation that has a 'hole' in it. The
40 |  * class has the ability to compute state up until a certain point where the
41 |  * state has the `A` type. If this context is given a function of type
42 |  * `A => B` to move the state to the `B` type, then the entire computation can
43 |  * be completed resulting in a value of type `C`.
44 |  *
45 |  * An Example: {{{
46 |  *   val cc = new ControlContext[String, String, String](
47 |  *     fun = { (f: String => String, err: Exception => String) =>
48 |  *       val updatedState =
49 |  *         try f("State")
50 |  *         catch {
51 |  *           case e: Exception => err(e)
52 |  *         }
53 |  *       updatedState + "-Complete!"
54 |  *     },
55 |  *     x = null.asInstanceOf[String]
56 |  *   )
57 |  *   cc.foreach(_ + "-Continued") // Results in "State-Continued-Complete!"
58 |  * }}}
59 |  *
60 |  * This class is used to transform calls to `shift` in the `continuations`
61 |  * package. Direct use and instantiation is possible, but usually reserved
62 |  * for advanced cases.
63 |  *
64 |  * A context may either be ''trivial'' or ''non-trivial''. A ''trivial''
65 |  * context '''just''' has a state of type `A` (its `fun` is `null`). When
66 |  * completing the computation, it's only necessary to apply the function of
67 |  * type `A => B` directly to the trivial value. A ''non-trivial'' context
68 |  * stores a computation '''around''' the state transformation of type
69 |  * `A => B` and cannot be short-circuited.
70 |  *
71 |  * @param fun The captured computation so far, or `null` for a trivial
72 |  *            context. The type `(A => B, Exception => B) => C` is a
73 |  *            function where:
74 |  *  - The first parameter `A => B` represents the computation defined against
75 |  *    the current state held in the ControlContext.
76 |  *  - The second parameter `Exception => B` represents a computation to
77 |  *    perform if an exception is thrown from the first parameter's computation.
78 |  *  - The return value is the result of the entire computation contained in this
79 |  *    `ControlContext`.
80 |  * @param x The current state stored in this context. Allowed to be null if
81 |  *          the context is non-trivial.
82 |  * @tparam A The type of the state currently held in the context.
83 |  * @tparam B The type of the transformed state needed to complete this computation.
84 |  * @tparam C The return type of the entire computation stored in this context.
85 |  * @note `fun` and `x` are allowed to be `null`.
86 |  * @see scala.util.continuations.shiftR
87 |  */
88 | final class ControlContext[+A,-B,+C](val fun: (A => B, Exception => B) => C, val x: A) extends Serializable {
89 |
90 |   // Commented-out variants of map/flatMap written for a `fun` without an error continuation:
91 |   /*
92 |   final def map[A1](f: A => A1): ControlContext[A1,B,C] = {
93 |     new ControlContext((k:(A1 => B)) => fun((x:A) => k(f(x))), null.asInstanceOf[A1])
94 |   }
95 |
96 |   final def flatMap[A1,B1<:B](f: (A => ControlContext[A1,B1,B])): ControlContext[A1,B1,C] = {
97 |     new ControlContext((k:(A1 => B1)) => fun((x:A) => f(x).fun(k)))
98 |   }
99 |   */
100 |
101 |   /**
102 |    * Returns a context whose state is the current state transformed by `f`.
103 |    *
104 |    * For a trivial context `f` is applied right away; if it throws, the result is a
105 |    * non-trivial context that hands the exception to the error continuation. For a
106 |    * non-trivial context `f` is applied once the captured computation supplies its
107 |    * state; exceptions thrown by `f` (but not by the continuation `k`) go to the
108 |    * error continuation.
109 |    *
110 |    * @tparam A1 The new type of state in this context.
111 |    * @param f A transformation function on the current state of the `ControlContext`.
112 |    * @return The new `ControlContext`.
113 |    */
114 |   @noinline final def map[A1](f: A => A1): ControlContext[A1,B,C] = {
115 |     if (fun eq null)
116 |       try {
117 |         new ControlContext[A1,B,C](null, f(x)) // TODO: only alloc if f(x) != x
118 |       } catch {
119 |         case ex: Exception =>
120 |           new ControlContext((k: A1 => B, thr: Exception => B) => thr(ex).asInstanceOf[C], null.asInstanceOf[A1])
121 |       }
122 |     else
123 |       new ControlContext({ (k: A1 => B, thr: Exception => B) =>
124 |         fun( { (x:A) =>
125 |           var done = false
126 |           try {
127 |             val res = f(x)
128 |             done = true
129 |             k(res)
130 |           } catch {
131 |             // `done` ensures only exceptions from `f` are redirected; those from `k` propagate
132 |             case ex: Exception if !done =>
133 |               thr(ex)
134 |           }
135 |         }, thr)
136 |       }, null.asInstanceOf[A1])
137 |   }
138 |
139 |
140 |   // it would be nice if @inline would turn the trivial path into a tail call.
141 |   // unfortunately it doesn't, so we do it ourselves in SelectiveCPSTransform
142 |
143 |   /**
144 |    * Maps and flattens this `ControlContext` with another `ControlContext` generated from the current state.
145 |    * @note The resulting computation still has the result type `C`.
146 |    * @tparam A1 The new type of the contained state.
147 |    * @tparam B1 The new type of the state after the stored continuation has executed.
148 |    * @tparam C1 The result type of the nested `ControlContext`. Because the nested `ControlContext` is executed within
149 |    *            the outer `ControlContext`, this type must be a subtype of `B` (`C1 <: B`) so that the result of the
150 |    *            nested computation can be fed through the current continuation.
151 |    * @param f A transformation function from the current state to a nested `ControlContext`.
152 |    * @return The transformed `ControlContext`.
153 |    */
154 |   @noinline final def flatMap[A1,B1,C1<:B](f: (A => ControlContext[A1,B1,C1])): ControlContext[A1,B1,C] = {
155 |     if (fun eq null)
156 |       try {
157 |         f(x).asInstanceOf[ControlContext[A1,B1,C]]
158 |       } catch {
159 |         case ex: Exception =>
160 |           new ControlContext((k: A1 => B1, thr: Exception => B1) => thr(ex).asInstanceOf[C], null.asInstanceOf[A1])
161 |       }
162 |     else
163 |       new ControlContext({ (k: A1 => B1, thr: Exception => B1) =>
164 |         fun( { (x:A) =>
165 |           var done = false
166 |           try {
167 |             val ctxR = f(x)
168 |             done = true
169 |             val res: C1 = ctxR.foreachFull(k, thr) // => B1
170 |             res
171 |           } catch {
172 |             case ex: Exception if !done =>
173 |               thr(ex).asInstanceOf[B] // => B NOTE: in general this is unsafe!
174 |           }                           // However, the plugin will not generate offending code
175 |         }, thr.asInstanceOf[Exception=>B]) // => B
176 |       }, null.asInstanceOf[A1])
177 |   }
178 |
179 |   /**
180 |    * Runs the computation against the state stored in this `ControlContext`.
181 |    * Exceptions handed to the error continuation are rethrown.
182 |    * @param f the computation that modifies the current state of the context.
183 |    * @note This method could throw exceptions from the computations.
184 |    */
185 |   final def foreach(f: A => B) = foreachFull(f, throw _)
186 |
187 |   /**
188 |    * Runs the computation with `f` as the success continuation and `g` as the error continuation.
189 |    * For a trivial context `f` is applied directly to the stored state and `g` is not used.
190 |    * @return The result of the entire computation.
191 |    */
192 |   def foreachFull(f: A => B, g: Exception => B): C = {
193 |     if (fun eq null)
194 |       f(x).asInstanceOf[C]
195 |     else
196 |       fun(f, g)
197 |   }
198 |
199 |   /** @return true if this context only stores a state value and not any deferred computation. */
200 |   final def isTrivial = fun eq null
201 |   /** @return The current state value. */
202 |   final def getTrivialValue = x.asInstanceOf[A]
203 |
204 |   // need filter or other functions?
205 |
206 |   /**
207 |    * Wraps the deferred computation with the exception handler `pf`.
208 |    *
209 |    * A trivial context is returned unchanged, as it has no deferred computation to guard.
210 |    * Otherwise, an exception handed to the error continuation is handled by running the
211 |    * context `pf` produces for it (with the outer continuations) when `pf` is defined at it,
212 |    * and is forwarded to the outer error continuation when it is not; exceptions thrown while
213 |    * evaluating `pf` itself also go to the outer error continuation.
214 |    */
215 |   final def flatMapCatch[A1>:A,B1<:B,C1>:C<:B1](pf: PartialFunction[Exception, ControlContext[A1,B1,C1]]): ControlContext[A1,B1,C1] = { // called by codegen from SelectiveCPSTransform
216 |     if (fun eq null)
217 |       this
218 |     else {
219 |       val fun1 = (ret1: A1 => B1, thr1: Exception => B1) => {
220 |         val thr: Exception => B1 = { t: Exception =>
221 |           var captureExceptions = true
222 |           try {
223 |             if (pf.isDefinedAt(t)) {
224 |               val cc1 = pf(t)
225 |               captureExceptions = false
226 |               cc1.foreachFull(ret1, thr1) // Throw => B
227 |             } else {
228 |               captureExceptions = false
229 |               thr1(t) // Throw => B1
230 |             }
231 |           } catch {
232 |             case t1: Exception if captureExceptions => thr1(t1) // => E2
233 |           }
234 |         }
235 |         fun(ret1, thr) // => B
236 |       }
237 |       new ControlContext(fun1, null.asInstanceOf[A1])
238 |     }
239 |   }
240 |
241 |   /**
242 |    * Runs `f` once the deferred computation finishes, on the success path and on the error path.
243 |    * Intended for code generated by SelectiveCPSTransform.
244 |    *
245 |    * For a trivial context `f` runs immediately; if it throws, the result is a context that
246 |    * hands the exception to the error continuation. Otherwise `f` runs before the state is
247 |    * passed to the success continuation and before an exception is forwarded to the error
248 |    * continuation; an exception thrown by `f` itself goes to the error continuation.
249 |    */
250 |   final def mapFinally(f: () => Unit): ControlContext[A,B,C] = {
251 |     if (fun eq null) {
252 |       try {
253 |         f()
254 |         this
255 |       } catch {
256 |         case ex: Exception =>
257 |           new ControlContext((k: A => B, thr: Exception => B) => thr(ex).asInstanceOf[C], null.asInstanceOf[A])
258 |       }
259 |     } else {
260 |       val fun1 = (ret1: A => B, thr1: Exception => B) => {
261 |         val ret: A => B = { x: A =>
262 |           var captureExceptions = true
263 |           try {
264 |             f()
265 |             captureExceptions = false
266 |             ret1(x)
267 |           } catch {
268 |             case t1: Exception if captureExceptions => thr1(t1)
269 |           }
270 |         }
271 |         val thr: Exception => B = { t: Exception =>
272 |           var captureExceptions = true
273 |           try {
274 |             f()
275 |             captureExceptions = false
276 |             thr1(t)
277 |           } catch {
278 |             case t1: Exception if captureExceptions => thr1(t1)
279 |           }
280 |         }
281 |         // Pass the wrapped `thr` (not `thr1`) so that `f` also runs when the computation fails.
282 |         fun(ret, thr)
283 |       }
284 |       new ControlContext(fun1, null.asInstanceOf[A])
285 |     }
286 |   }
287 |
288 | }
254 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/plugin/src/main/scala/scala/tools/selectivecps/SelectiveCPSTransform.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.tools.selectivecps
14 |
15 | import scala.tools.nsc.transform._
16 | import scala.tools.nsc.plugins._
17 | import scala.tools.nsc.ast._
18 |
19 | /**
20 | * In methods marked @cps, CPS-transform assignments introduced by ANF-transform phase.
21 | */
22 | abstract class SelectiveCPSTransform extends PluginComponent with
23 | InfoTransform with TypingTransformers with CPSUtils with TreeDSL {
24 | // inherits abstract value `global` and class `Phase` from Transform
25 |
26 | import global._ // the global environment
27 | import definitions._ // standard classes and methods
28 | import typer.atOwner // methods to type trees
29 |
30 |   /** Short human-readable summary of this phase. */
31 |   override def description: String = "@cps-driven transform of selectiveanf assignments"
31 |
32 |   /** Name of this compiler phase; implements the abstract `phaseName` this component inherits. */
33 |   val phaseName: String = "selectivecps"
34 |
35 |   /** Creates the tree transformer applied to `unit`: a fresh `CPSTransformer`. */
36 |   protected def newTransformer(unit: CompilationUnit): Transformer = new CPSTransformer(unit)
37 |
38 |   /** The CPS transform leaves class linearization (base classes) untouched. */
39 |   override def changesBaseClasses: Boolean = false
40 |
41 |   /** Returns the CPS-transformed info of `sym`: `tp` rewritten by `transformCPSType`,
42 |    *  or `tp` unchanged when the CPS transform is disabled. */
43 |   def transformInfo(sym: Symbol, tp: Type): Type =
44 |     if (!cpsEnabled) tp
45 |     else {
46 |       val transformed = transformCPSType(tp)
47 |       if (transformed != tp)
48 |         debuglog("transformInfo changed type for " + sym + " to " + transformed)
49 |       if (sym == MethReifyR)
50 |         debuglog("transformInfo (not)changed type for " + sym + " to " + transformed)
51 |       transformed
52 |     }
56 |
57 |   /** Maps `tp` to its CPS-transformed form.
58 |    *
59 |    *  Polymorphic, nullary-method and method types are rebuilt around their transformed result
60 |    *  type, and type references around their transformed type arguments. Any other type loses
61 |    *  its CPS annotations; if it carries an external answer-type annotation `(res, outer)`, the
62 |    *  stripped type `T` is additionally wrapped as `Context[T, res, outer]`.
63 |    */
64 |   def transformCPSType(tp: Type): Type = tp match { // TODO: use a TypeMap? need to handle more cases?
65 |     case PolyType(tparams, result)    => PolyType(tparams, transformCPSType(result))
66 |     case NullaryMethodType(result)    => NullaryMethodType(transformCPSType(result))
67 |     case MethodType(vparams, result)  => MethodType(vparams, transformCPSType(result))
68 |     case TypeRef(prefix, tsym, targs) => TypeRef(prefix, tsym, targs map transformCPSType)
69 |     case other =>
70 |       getExternalAnswerTypeAnn(other) match {
71 |         case Some((answer, outerAnswer)) =>
72 |           appliedType(Context.tpeHK, List(removeAllCPSAnnotations(other), answer, outerAnswer))
73 |         case _ =>
74 |           removeAllCPSAnnotations(other)
75 |       }
76 |   }
72 |
73 |
74 | class CPSTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
75 | private val patmatTransformer = patmat.newTransformer(unit) // applied in mainTransform to the tree built for a CPS-typed try/catch
76 |
77 |     /** Rewrites `tree` with `mainTransform`, then updates its type via `postTransform`; returns `tree` as-is when the CPS transform is disabled. */
78 |     override def transform(tree: Tree): Tree =
79 |       if (cpsEnabled) postTransform(mainTransform(tree))
80 |       else tree
81 |
82 |     /** Sets the type of `tree` to the CPS-transformed version of its current type and returns `tree`. */
83 |     def postTransform(tree: Tree): Tree =
84 |       tree setType transformCPSType(tree.tpe)
85 |
86 |
87 | def mainTransform(tree: Tree): Tree = { // Rewrites one tree: shift/shiftUnit/reify calls are redirected to their *R counterparts, CPS-typed try/catch is rewritten, blocks go through transBlock.
88 | tree match {
89 |
90 | // TODO: can we generalize this?
91 |
92 | case Apply(TypeApply(fun, targs), args) // shift(...) => MethShiftR(...), typed with the CPS-transformed type of the call
93 | if (fun.symbol == MethShift) =>
94 | debuglog("found shift: " + tree)
95 | atPos(tree.pos) {
96 | val funR = gen.mkAttributedRef(MethShiftR) // TODO: correct?
97 | //gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedSelect(gen.mkAttributedIdent(ScalaPackage),
98 | //ScalaPackage.tpe.member("util")), ScalaPackage.tpe.member("util").tpe.member("continuations")), MethShiftR)
99 | //gen.mkAttributedRef(ModCPS.tpe, MethShiftR) // TODO: correct?
100 | debuglog("funR.tpe: " + funR.tpe)
101 | Apply(
102 | TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))),
103 | args.map(transform(_))
104 | ).setType(transformCPSType(tree.tpe))
105 | }
106 |
107 | case Apply(TypeApply(fun, targs), args) // shiftUnit(...) => MethShiftUnitR(...) using only the first two type args
108 | if (fun.symbol == MethShiftUnit) =>
109 | debuglog("found shiftUnit: " + tree)
110 | atPos(tree.pos) {
111 | val funR = gen.mkAttributedRef(MethShiftUnitR) // TODO: correct?
112 | debuglog("funR.tpe: " + funR.tpe)
113 | Apply(
114 | TypeApply(funR, List(targs(0), targs(1))).setType(appliedType(funR.tpe,
115 | List(targs(0).tpe, targs(1).tpe))),
116 | args.map(transform(_))
117 | ).setType(appliedType(Context.tpeHK, List(targs(0).tpe,targs(1).tpe,targs(1).tpe))) // i.e. Context[targs(0), targs(1), targs(1)]
118 | }
119 |
120 | case Apply(TypeApply(fun, targs), args) // reify(...) => MethReifyR(...), typed with the CPS-transformed type of the call
121 | if (fun.symbol == MethReify) =>
122 | log("found reify: " + tree)
123 | atPos(tree.pos) {
124 | val funR = gen.mkAttributedRef(MethReifyR) // TODO: correct?
125 | debuglog("funR.tpe: " + funR.tpe)
126 | Apply(
127 | TypeApply(funR, targs).setType(appliedType(funR.tpe, targs.map((t:Tree) => t.tpe))),
128 | args.map(transform(_))
129 | ).setType(transformCPSType(tree.tpe))
130 | }
131 |
132 | case Try(block, catches, finalizer) => // a CPS-typed try/catch is rewritten below; any other try is only rebuilt from its transformed parts
133 | // currently duplicates the catch block into a partial function.
134 | // this is kinda risky, but we don't expect there will be lots
135 | // of try/catches inside catch blocks (exp. blowup unlikely).
136 |
137 | // CAVEAT: finalizers are surprisingly tricky!
138 | // the problem is that they cannot easily be removed
139 | // from the regular control path and hence will
140 | // also be invoked after creating the Context object.
141 |
142 | /*
143 | object Test {
144 | def foo1 = {
145 | throw new Exception("in sub")
146 | shift((k:Int=>Int) => k(1))
147 | 10
148 | }
149 | def foo2 = {
150 | shift((k:Int=>Int) => k(2))
151 | 20
152 | }
153 | def foo3 = {
154 | shift((k:Int=>Int) => k(3))
155 | throw new Exception("in sub")
156 | 30
157 | }
158 | def foo4 = {
159 | shift((k:Int=>Int) => 4)
160 | throw new Exception("in sub")
161 | 40
162 | }
163 | def bar(x: Int) = try {
164 | if (x == 1)
165 | foo1
166 | else if (x == 2)
167 | foo2
168 | else if (x == 3)
169 | foo3
170 | else //if (x == 4)
171 | foo4
172 | } catch {
173 | case _ =>
174 | println("exception")
175 | 0
176 | } finally {
177 | println("done")
178 | }
179 | }
180 |
181 | reset(Test.bar(1)) // should print: exception,done,0
182 | reset(Test.bar(2)) // should print: done,20 <-- but prints: done,done,20
183 | reset(Test.bar(3)) // should print: exception,done,0 <-- but prints: done,exception,done,0
184 | reset(Test.bar(4)) // should print: 4 <-- but prints: done,4
185 | */
186 |
187 | val block1 = transform(block)
188 | val catches1 = transformCaseDefs(catches)
189 | val finalizer1 = transform(finalizer) // the finalizer stays on the regular Try in both branches below (see CAVEAT above)
190 |
191 | if (hasAnswerTypeAnn(tree.tpe)) { // only a try whose type carries a cps answer-type annotation needs the rewrite
192 | //vprintln("CPS Transform: " + tree + "/" + tree.tpe + "/" + block1.tpe)
193 |
194 | val (stms, expr1) = block1 match { // split block1 into its leading statements and its result expression
195 | case Block(stms, expr) => (stms, expr)
196 | case expr => (Nil, expr)
197 | }
198 |
199 | val targettp = transformCPSType(tree.tpe)
200 |
201 | val pos = catches.head.pos // NOTE(review): assumes at least one catch case; a CPS-typed try/finally without catches would throw here -- confirm this is ruled out earlier
202 | val funSym = currentOwner.newValueParameter(cpsNames.catches, pos).setInfo(appliedType(PartialFunctionClass, ThrowableTpe, targettp)) // holds the transformed catch cases as a PartialFunction[Throwable, targettp]
203 | val funDef = localTyper.typedPos(pos) {
204 | ValDef(funSym, Match(EmptyTree, catches1))
205 | }
206 | val expr2 = localTyper.typedPos(pos) {
207 | Apply(Select(expr1, expr1.tpe.member(cpsNames.flatMapCatch)), List(Ident(funSym))) // expr1.flatMapCatch(<catches>): exceptions reaching the CPS error continuation are handled by the same cases
208 | }
209 |
210 | val exSym = currentOwner.newValueParameter(cpsNames.ex, pos).setInfo(ThrowableTpe)
211 |
212 | import CODE._
213 | // generate a case that is supported directly by the back-end
214 | val catchIfDefined = CaseDef(
215 | Bind(exSym, Ident(nme.WILDCARD)),
216 | EmptyTree,
217 | IF ((REF(funSym) DOT nme.isDefinedAt)(REF(exSym))) THEN (REF(funSym) APPLY (REF(exSym))) ELSE Throw(REF(exSym)) // i.e. case ex => if (<catches>.isDefinedAt(ex)) <catches>(ex) else throw ex, for exceptions thrown synchronously in the try body
218 | )
219 |
220 | val catch2 = localTyper.typedCases(List(catchIfDefined), ThrowableTpe, targettp)
221 | //typedCases(tree, catches, ThrowableTpe, pt)
222 |
223 | patmatTransformer.transform(localTyper.typed(Block(List(funDef), treeCopy.Try(tree, treeCopy.Block(block1, stms, expr2), catch2, finalizer1)))) // { val <catches> = ...; try { stms; expr2 } catch <catch2> finally <finalizer1> }, then run through the pattern-matcher transform
224 |
225 |
226 | /*
227 | disabled for now - see notes above
228 |
229 | val expr3 = if (!finalizer.isEmpty) {
230 | val pos = finalizer.pos
231 | val finalizer2 = duplicateTree(finalizer1)
232 | val fun = Function(List(), finalizer2)
233 | val expr3 = localTyper.typedPos(pos) { Apply(Select(expr2, expr2.tpe.member("mapFinally")), List(fun)) }
234 |
235 | val chown = new ChangeOwnerTraverser(currentOwner, fun.symbol)
236 | chown.traverse(finalizer2)
237 |
238 | expr3
239 | } else
240 | expr2
241 | */
242 | } else {
243 | treeCopy.Try(tree, block1, catches1, finalizer1) // non-CPS try: rebuild from the transformed parts
244 | }
245 |
246 | case Block(stms, expr) => // blocks are delegated to transBlock, which rewrites ValDefs marked with MarkerCPSSym
247 |
248 | val (stms1, expr1) = transBlock(stms, expr)
249 | treeCopy.Block(tree, stms1, expr1)
250 |
251 | case _ => // anything else: default recursive transform
252 | super.transform(tree)
253 | }
254 | }
255 |
256 |
257 |
/** CPS transformation of a block given as its statements and result expression.
 *
 *  Walks `stms` front to back. A `ValDef` whose symbol carries the `MarkerCPSSym` annotation
 *  splits the block: its right-hand side is transformed, the remaining statements plus `expr`
 *  are transformed recursively, and that remainder becomes the body of a function passed to
 *  `map` -- or `flatMap`, when the body's type symbol is `Context` -- on the transformed
 *  right-hand side. The marked `ValDef` itself does not appear in the output. Every other
 *  statement is transformed on its own and kept in place.
 *
 *  The method recurses once per statement, so stack depth grows with the length of `stms`.
 *
 *  @param stms  statements of the block, in source order
 *  @param expr  the block's result expression
 *  @return      the transformed statements and the transformed result expression
 */
def transBlock(stms: List[Tree], expr: Tree): (List[Tree], Tree) = {

  stms match {
    case Nil =>
      // No statements left: only the result expression needs transforming.
      (Nil, transform(expr))

    case stm::rest =>

      stm match {
        case vd @ ValDef(mods, name, tpt, rhs)
        if (vd.symbol.hasAnnotation(MarkerCPSSym)) =>

          debuglog("found marked ValDef "+name+" of type " + vd.symbol.tpe)

          val tpe = vd.symbol.tpe
          // Transform the rhs as owned by the val, then move ownership to the current owner,
          // because the ValDef is not emitted and its rhs ends up directly in this scope.
          val rhs1 = atOwner(vd.symbol) { transform(rhs) }
          rhs1.changeOwner(vd.symbol -> currentOwner) // TODO: don't traverse twice

          debuglog("valdef symbol " + vd.symbol + " has type " + tpe)
          debuglog("right hand side " + rhs1 + " has type " + rhs1.tpe)

          debuglog("currentOwner: " + currentOwner)
          debuglog("currentMethod: " + currentMethod)

          // The rest of the block (everything after this val) is transformed first; it becomes
          // the continuation body below.
          val (bodyStms, bodyExpr) = transBlock(rest, expr)
          // FIXME: result will later be traversed again by TreeSymSubstituter and
          // ChangeOwnerTraverser => exp. running time.
          // Should be changed to fuse traversals into one.

          // True when the continuation ends in a call to the method currently being
          // transformed whose result type symbol is Context (an explicit self tail call).
          val specialCaseTrivial = bodyExpr match {
            case Apply(fun, args) =>
              // for now, look for explicit tail calls only.
              // are there other cases that could profit from specializing on
              // trivial contexts as well?
              (bodyExpr.tpe.typeSymbol == Context) && (currentMethod == fun.symbol)
            case _ => false
          }

          // Rebinds references to the val in `body` to `ctxValSym` and types the result.
          // Reports an error (but still returns the tree) unless the typed result's type
          // symbol is Context.
          def applyTrivial(ctxValSym: Symbol, body: Tree) = {

            val body1 = (new TreeSymSubstituter(List(vd.symbol), List(ctxValSym)))(body)

            val body2 = localTyper.typedPos(vd.symbol.pos) { body1 }

            // in theory it would be nicer to look for an @cps annotation instead
            // of testing for Context
            if ((body2.tpe == null) || !(body2.tpe.typeSymbol == Context)) {
              //println(body2 + "/" + body2.tpe)
              reporter.error(rhs.pos, "cannot compute type for CPS-transformed function result")
            }
            body2
          }

          // Builds and types `ctxR.map(x => body)` or `ctxR.flatMap(x => body)`. `x` is a fresh
          // parameter with the val's name and type that replaces the val inside `body`.
          // `flatMap` is chosen when the typed body's type symbol is Context.
          def applyCombinatorFun(ctxR: Tree, body: Tree) = {
            val arg = currentOwner.newValueParameter(name, ctxR.pos).setInfo(tpe)
            val body1 = (new TreeSymSubstituter(List(vd.symbol), List(arg)))(body)
            val fun = localTyper.typedPos(vd.symbol.pos) { Function(List(ValDef(arg)), body1) } // types body as well
            // The parameter and the body now live inside the new anonymous function.
            arg.owner = fun.symbol
            body1.changeOwner(currentOwner -> fun.symbol)

            // see note about multiple traversals above

            debuglog("fun.symbol: "+fun.symbol)
            debuglog("fun.symbol.owner: "+fun.symbol.owner)
            debuglog("arg.owner: "+arg.owner)

            debuglog("fun.tpe:"+fun.tpe)
            debuglog("return type of fun:"+body1.tpe)

            var methodName = nme.map

            if (body1.tpe != null) {
              if (body1.tpe.typeSymbol == Context)
                methodName = nme.flatMap
            }
            else
              reporter.error(rhs.pos, "cannot compute type for CPS-transformed function result")

            debuglog("will use method:"+methodName)

            localTyper.typedPos(vd.symbol.pos) {
              Apply(Select(ctxR, ctxR.tpe.member(methodName)), List(fun))
            }
          }

          // TODO use gen.mkBlock after 2.11.0-M6. Why wait? It allows us to still build in development
          // mode with `ant -DskipLocker=1`
          def mkBlock(stms: List[Tree], expr: Tree) = if (stms.nonEmpty) Block(stms, expr) else expr

          // A TypeError raised while building the combinator call is reported. Recovery then
          // continues with just the transformed rest of the block; this ValDef is dropped.
          try {
            if (specialCaseTrivial) {
              debuglog("will optimize possible tail call: " + bodyExpr)

              // FIXME: flatMap impl has become more complicated due to
              // exceptions. do we need to put a try/catch in the then part??

              // val ctx =
              // if (ctx.isTrivial)
              //   val <lhs> = ctx.getTrivialValue; ...    <--- TODO: try/catch ??? don't bother for the moment...
              // else
              //   ctx.flatMap { <lhs> => ... }
              val ctxSym = currentOwner.newValue(newTermName("" + vd.symbol.name + cpsNames.shiftSuffix)).setInfo(rhs1.tpe)
              val ctxDef = localTyper.typed(ValDef(ctxSym, rhs1))
              def ctxRef = localTyper.typed(Ident(ctxSym))
              val argSym = currentOwner.newValue(vd.symbol.name.toTermName).setInfo(tpe)
              val argDef = localTyper.typed(ValDef(argSym, Select(ctxRef, ctxRef.tpe.member(cpsNames.getTrivialValue))))
              val switchExpr = localTyper.typedPos(vd.symbol.pos) {
                // The continuation is used in both branches, so the non-trivial branch gets
                // its own copy, made before either branch is typed.
                val body2 = mkBlock(bodyStms, bodyExpr).duplicate // dup before typing!
                If(Select(ctxRef, ctxSym.tpe.member(cpsNames.isTrivial)),
                  applyTrivial(argSym, mkBlock(argDef::bodyStms, bodyExpr)),
                  applyCombinatorFun(ctxRef, body2))
              }
              (List(ctxDef), switchExpr)
            } else {
              // ctx.flatMap { <lhs> => ... }
              //   or
              // ctx.map { <lhs> => ... }
              (Nil, applyCombinatorFun(rhs1, mkBlock(bodyStms, bodyExpr)))
            }
          } catch {
            case ex:TypeError =>
              reporter.error(ex.pos, ex.msg)
              (bodyStms, bodyExpr)
          }

        case _ =>
          // Ordinary statement: transform it in place and continue with the rest.
          val stm1 = transform(stm)
          val (a, b) = transBlock(rest, expr)
          (stm1::a, b)
      }
  }
}
390 |
391 |
392 | }
393 | }
394 |
--------------------------------------------------------------------------------
/plugin/src/main/scala/scala/tools/selectivecps/CPSAnnotationChecker.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.tools.selectivecps
14 |
15 | import scala.tools.nsc.{ Global, Mode }
16 | import scala.tools.nsc.MissingRequirementError
17 |
/** Type-checking support for the continuations plugin.
 *
 *  Provides two compiler hooks:
 *   - `checker`, an `AnnotationChecker` deciding how `@cps` annotations take part in
 *     subtyping (`annotationsConform`), least upper bounds (`annotationsLub`) and
 *     type-parameter bounds (`adaptBoundsToAnnotations`);
 *   - `plugin`, an `AnalyzerPlugin` that adapts and propagates those annotations while the
 *     typer runs.
 *
 *  When `cpsEnabled` is false every hook falls back to the compiler's default answer.
 *  `pluginsTyped` additionally reports code whose type requires the plugin.
 */
abstract class CPSAnnotationChecker extends CPSUtils {
  val global: Global
  import global._
  import analyzer.{AnalyzerPlugin, Typer}
  import definitions._

  //override val verbose = true
  /** Prints `x` (evaluated only if needed) when `verbose` is set. */
  @inline override final def vprintln(x: =>Any): Unit = if (verbose) println(x)

  /**
   *  Checks whether @cps annotations conform
   */
  object checker extends AnnotationChecker {
    private[CPSAnnotationChecker] def addPlusMarker(tp: Type)  = tp withAnnotation newPlusMarker()
    private[CPSAnnotationChecker] def addMinusMarker(tp: Type) = tp withAnnotation newMinusMarker()

    /** Removes the plus marker (`MarkerCPSAdaptPlus`) and any `@cps` (`MarkerCPSTypes`) annotation from `tp`. */
    private[CPSAnnotationChecker] def cleanPlus(tp: Type) =
      removeAttribs(tp, MarkerCPSAdaptPlus, MarkerCPSTypes)

    /** `cleanPlus(tp)`, then annotated with `newAnnots`. */
    private[CPSAnnotationChecker] def cleanPlusWith(tp: Type)(newAnnots: AnnotationInfo*) =
      cleanPlus(tp) withAnnotations newAnnots.toList

    /** Check annotations to decide whether tpe1 <:< tpe2 */
    def annotationsConform(tpe1: Type, tpe2: Type): Boolean = {
      if (!cpsEnabled) return true

      vprintln("check annotations: " + tpe1 + " <:< " + tpe2)

      // Nothing is least element, but Any is not the greatest
      if (tpe1.typeSymbol eq NothingClass)
        return true

      val annots1 = cpsParamAnnotation(tpe1)
      val annots2 = cpsParamAnnotation(tpe2)

      // @plus and @minus should only occur at the left, and never together
      // TODO: insert check

      // @minus @cps is the same as no annotations
      if (hasMinusMarker(tpe1))
        return annots2.isEmpty

      // to handle answer type modification, we must make @plus <:< @cps
      if (hasPlusMarker(tpe1) && annots1.isEmpty)
        return true

      // @plus @cps will fall through and compare the @cps type args
      // @cps parameters must match exactly
      if ((annots1 corresponds annots2)(_.atp <:< _.atp))
        return true

      // Need to handle uninstantiated type vars specially:
      //
      //   g map (x => x)  with expected type List[Int] @cps
      //   results in comparison ?That <:< List[Int] @cps
      //
      // Instantiating ?That to an annotated type would fail during transformation.
      // Instead we force-compare tpe1 <:< tpe2.withoutAnnotations to trigger
      // instantiation of the TypeVar to the base type.
      //
      // This is a bit unorthodox (we're only supposed to look at annotations here)
      // but seems to work.
      if (!annots2.isEmpty && !tpe1.isGround)
        return tpe1 <:< tpe2.withoutAnnotations

      false
    }

    /** Refine the computed least upper bound of a list of types.
     *  All this should do is add annotations.
     *
     *  If any of `ts` carries `@cps` annotations, the result is `tpe` (with its own `@cps`
     *  annotation removed, if it had one) annotated with a marker built from the lub of all
     *  collected annotation types.
     */
    override def annotationsLub(tpe: Type, ts: List[Type]): Type = {
      if (!cpsEnabled) return tpe

      val annots1 = cpsParamAnnotation(tpe)
      val annots2 = ts flatMap cpsParamAnnotation

      if (annots2.nonEmpty) {
        val cpsLub = newMarker(global.lub(annots1:::annots2 map (_.atp)))
        val tpe1 = if (annots1.nonEmpty) removeAttribs(tpe, MarkerCPSTypes) else tpe
        tpe1.withAnnotation(cpsLub)
      }
      else tpe
    }

    /** Refine the bounds on type parameters to the given type arguments.
     *
     *  Two cases are adapted; every other call returns `bounds` unchanged:
     *   - the owner of the type parameters is a function or partial-function type and its last
     *     type argument has `@cps` types: the upper bound of the last type parameter gets the
     *     `newCpsParamsMarker(NothingTpe, AnyTpe)` annotation, unless it already has `@cps` types;
     *   - the owner is `ByNameParamClass` and the (single) type argument has `@cps` types:
     *     its upper bound is annotated the same way.
     */
    override def adaptBoundsToAnnotations(bounds: List[TypeBounds], tparams: List[Symbol], targs: List[Type]): List[TypeBounds] = {
      if (!cpsEnabled) return bounds
      // Both branches below rely on `tparams.head` and on `targs.last`/`targs.head`;
      // with nothing to inspect there is nothing to adapt.
      if (tparams.isEmpty || targs.isEmpty) return bounds

      val anyAtCPS = newCpsParamsMarker(NothingTpe, AnyTpe)
      val owner    = tparams.head.owner
      if (isFunctionType(owner.tpe_*) || isPartialFunctionType(owner.tpe_*)) {
        vprintln("function bound: " + owner.tpe + "/"+bounds+"/"+targs)
        if (hasCpsParamTypes(targs.last))
          bounds.reverse match {
            case res::b if !hasCpsParamTypes(res.hi) =>
              (TypeBounds(res.lo, res.hi.withAnnotation(anyAtCPS))::b).reverse
            case _ => bounds
          }
        else
          bounds
      }
      else if (owner == ByNameParamClass) {
        vprintln("byname bound: " + owner.tpe + "/"+bounds+"/"+targs)
        val TypeBounds(lo, hi) = bounds.head
        if (hasCpsParamTypes(targs.head) && !hasCpsParamTypes(hi))
          TypeBounds(lo, hi withAnnotation anyAtCPS) :: Nil
        else bounds
      } else
        bounds
    }
  }

  /** Typer hooks that adapt `@cps` annotations on trees and propagate them from children to parents. */
  object plugin extends AnalyzerPlugin {

    import checker._

    /** Decides whether `adaptAnnotations` should be given a chance to adapt `tree` to `pt`. */
    override def canAdaptAnnotations(tree: Tree, typer: Typer, mode: Mode, pt: Type): Boolean = {
      if (!cpsEnabled) return false
      vprintln("can adapt annotations? " + tree + " / " + tree.tpe + " / " + mode + " / " + pt)

      val annots1 = cpsParamAnnotation(tree.tpe)
      val annots2 = cpsParamAnnotation(pt)

      // annotated patterns: adaptAnnotations strips the annotations
      if (mode.inPatternMode && annots1.nonEmpty)
        return true

      /*
      // not precise enough -- still relying on addAnnotations to remove things from ValDef symbols
      if (mode.inAllModes(TYPEmode | BYVALmode)) {
        if (!annots1.isEmpty) {
          return true
        }
      }
      */

      /*
      this interferes with overloading resolution
      if (mode.inByValMode && tree.tpe <:< pt) {
        vprintln("already compatible, can't adapt further")
        return false
      }
      */
      if (mode.inExprMode) {
        if ((annots1 corresponds annots2)(_.atp <:< _.atp)) {
          vprintln("already same, can't adapt further")
          false
        } else if (annots1.isEmpty && !annots2.isEmpty && !mode.inByValMode) {
          if (!hasPlusMarker(tree.tpe)) {
            //  val base = tree.tpe <:< removeAllCPSAnnotations(pt)
            //  val known = global.analyzer.isFullyDefined(pt)
            //  println(same + "/" + base + "/" + known)
            //val same = annots2 forall { case AnnotationInfo(atp: TypeRef, _, _) => atp.typeArgs(0) =:= atp.typeArgs(1) }
            // TBD: use same or not?
            vprintln("yes we can!! (unit)")
            true
          } else false
        } else if (!hasPlusMarker(tree.tpe) && annots1.isEmpty && !annots2.isEmpty && typer.context.inReturnExpr) {
          vprintln("checking enclosing method's result type without annotations")
          tree.tpe <:< pt.withoutAnnotations
        } else if (!hasMinusMarker(tree.tpe) && !annots1.isEmpty && mode.inByValMode) {
          val optCpsTypes: Option[(Type, Type)]         = cpsParamTypes(tree.tpe)
          val optExpectedCpsTypes: Option[(Type, Type)] = cpsParamTypes(pt)
          if (optCpsTypes.isEmpty || optExpectedCpsTypes.isEmpty) {
            vprintln("yes we can!! (byval)")
            true
          } else { // check cps param types
            val cpsTpes = optCpsTypes.get
            val cpsPts  = optExpectedCpsTypes.get
            // class cpsParam[-B,+C], therefore:
            cpsPts._1 <:< cpsTpes._1 && cpsTpes._2 <:< cpsPts._2
          }
        } else false
      } else false
    }

    /** Adapts the annotations on `tree.tpe` towards `pt`:
     *   - in pattern mode, `@cps` annotations are removed;
     *   - an unannotated expression that is not typed by value, with a `@cps` expected type,
     *     gets a plus marker plus the expected annotations (shiftUnit);
     *   - an annotated expression typed by value gets a minus marker (annotation dropped);
     *   - an unannotated return expression in a method with a `@cps` result type gets a
     *     plus marker only.
     *  Anything else is returned unchanged.
     */
    override def adaptAnnotations(tree: Tree, typer: Typer, mode: Mode, pt: Type): Tree = {
      if (!cpsEnabled) return tree

      vprintln("adapt annotations " + tree + " / " + tree.tpe + " / " + mode + " / " + pt)

      val annotsTree     = cpsParamAnnotation(tree.tpe)
      val annotsExpected = cpsParamAnnotation(pt)
      def isMissingExpectedAnnots = annotsTree.isEmpty && annotsExpected.nonEmpty

      // not sure I rephrased this comment correctly:
      // replacing `mode.inPatternMode` in the condition below by `mode.inPatternMode || mode.inAllModes(TYPEmode | BYVALmode)`
      // doesn't work correctly -- still relying on addAnnotations to remove things from ValDef symbols
      if (mode.inPatternMode && annotsTree.nonEmpty) tree modifyType removeAllCPSAnnotations
      else if (mode.typingExprNotValue && !hasPlusMarker(tree.tpe) && isMissingExpectedAnnots) { // shiftUnit
        // add a marker annotation that will make tree.tpe behave as pt, subtyping wise
        // tree will look like having any possible annotation

        // CAVEAT:
        // for monomorphic answer types we want to have @plus @cps (for better checking)
        // for answer type modification we want to have only @plus (because actual answer type may differ from pt)

        val res = tree modifyType (_ withAnnotations newPlusMarker() :: annotsExpected) // needed for #1807
        vprintln("adapted annotations (not by val) of " + tree + " to " + res.tpe)
        res
      } else if (mode.typingExprByValue && !hasMinusMarker(tree.tpe) && annotsTree.nonEmpty) { // dropping annotation
        // add a marker annotation that will make tree.tpe behave as pt, subtyping wise
        // tree will look like having no annotation
        val res = tree modifyType addMinusMarker
        vprintln("adapted annotations (by val) of " + tree + " to " + res.tpe)
        res
      } else if (typer.context.inReturnExpr && !hasPlusMarker(tree.tpe) && isMissingExpectedAnnots) {
        // add a marker annotation that will make tree.tpe behave as pt, subtyping wise
        // tree will look like having any possible annotation

        // note 1: we are only adding a plus marker if the method's result type is a cps type
        //         (annotsExpected.nonEmpty == cpsParamAnnotation(pt).nonEmpty)
        // note 2: we are not adding the expected cps annotations, since they will be added
        //         by adaptTypeOfReturn (see below).
        val res = tree modifyType (_ withAnnotation newPlusMarker())
        vprintln("adapted annotations (return) of " + tree + " to " + res.tpe)
        res
      } else tree
    }

    /** Returns an adapted type for a return expression if the method's result type (pt) is a CPS type.
     *  Otherwise, it returns the `default` type (`typedReturn` passes `NothingTpe`).
     *
     *  A return expression in a method that has a CPS result type is an error unless the return
     *  is in tail position. Therefore, we are making sure that only the types of return expressions
     *  are adapted which will either be removed, or lead to an error.
     *
     *  Like every other hook, this one does nothing when the plugin is disabled.
     */
    override def pluginsTypedReturn(default: Type, typer: Typer, tree: Return, pt: Type): Type = {
      // Previously missing: without this guard the hook inspected annotations even with the
      // plugin disabled, which can force lookup of the cps annotation classes (a lookup
      // `pluginsTyped` explicitly guards with `MissingRequirementError`).
      if (!cpsEnabled) return default

      val expr = tree.expr
      // only adapt if method's result type (pt) is cps type
      val annots = cpsParamAnnotation(pt)
      if (annots.nonEmpty) {
        // return type of `expr` without plus marker, but only if it doesn't have other cps annots
        if (hasPlusMarker(expr.tpe) && !hasCpsParamTypes(expr.tpe))
          expr.setType(removeAttribs(expr.tpe, MarkerCPSAdaptPlus))
        expr.tpe
      } else default
    }

    /** Combines the cps annotations of `tpe` with those collected from its children.
     *
     *  @param tpe          the type inferred so far for the parent tree
     *  @param childAnnots  `@cps` annotations found on (non-by-name) children
     *  @param byName       children evaluated by name (branches, case bodies, ...); their plus
     *                      markers may be moved outward, which mutates their types
     *  @throws TypeError   when a child's linearized annotation does not conform to the one already
     *                      on `tpe` (plus and synth cases)
     */
    def updateAttributesFromChildren(tpe: Type, childAnnots: List[AnnotationInfo], byName: List[Tree]): Type = {
      tpe match {
        // Would need to push annots into each alternative of overloaded type
        // But we can't, since alternatives aren't types but symbols, which we
        // can't change (we'd be affecting symbols globally)
        /*
        case OverloadedType(pre, alts) =>
          OverloadedType(pre, alts.map((sym: Symbol) => updateAttributes(pre.memberType(sym), annots)))
        */
        case OverloadedType(pre, alts) => tpe //reconstruct correct annotations later
        case MethodType(params, restpe) => tpe
        case PolyType(params, restpe) => tpe
        case _ =>
          assert(childAnnots forall (_ matches MarkerCPSTypes), childAnnots)
          /*
            [] + [] = []
            plus + [] = plus
            cps + [] = cps
            plus cps + [] = plus cps
            minus cps + [] = minus cps
            synth cps + [] = synth cps // <- synth on left - does it happen?

            [] + cps = cps
            plus + cps = synth cps
            cps + cps = cps! <- lin
            plus cps + cps = synth cps! <- unify
            minus cps + cps = minus cps! <- lin
            synth cps + cps = synth cps! <- unify
          */

          val plus = hasPlusMarker(tpe) || (
              hasCpsParamTypes(tpe)
              && byName.nonEmpty
              && (byName forall (t => hasPlusMarker(t.tpe)))
          )

          // move @plus annotations outward from by-name children
          if (childAnnots.isEmpty) return {
            if (plus) { // @plus or @plus @cps
              byName foreach (_ modifyType cleanPlus)
              addPlusMarker(tpe)
            }
            else tpe
          }

          val annots1 = cpsParamAnnotation(tpe)

          if (annots1.isEmpty) { // nothing or @plus
            cleanPlusWith(tpe)(newSynthMarker(), linearize(childAnnots))
          }
          else {
            val annot1 = single(annots1)
            if (plus) { // @plus @cps
              val annot2 = linearize(childAnnots)

              if (annot2.atp <:< annot1.atp) {
                try cleanPlusWith(tpe)(newSynthMarker(), annot2)
                finally byName foreach (_ modifyType cleanPlus)
              }
              else throw new TypeError(annot2 + " is not a subtype of " + annot1)
            }
            else if (hasSynthMarker(tpe)) { // @synth @cps
              val annot2 = linearize(childAnnots)
              if (annot2.atp <:< annot1.atp)
                cleanPlusWith(tpe)(annot2)
              else
                throw new TypeError(annot2 + " is not a subtype of " + annot1)
            }
            else // @cps
              cleanPlusWith(tpe)(linearize(childAnnots:::annots1))
          }
      }
    }

    /** Pairs each argument with its formal parameter type: by-name arguments map to `Nil`,
     *  all others to `List(arg)`. Arguments beyond the formals are paired with `NoType` and kept.
     */
    def transArgList(fun: Tree, args: List[Tree]): List[List[Tree]] = {
      val formals = fun.tpe.paramTypes
      val overshoot = args.length - formals.length

      for ((a,tp) <- args.zip(formals ::: List.fill(overshoot)(NoType))) yield {
        tp match {
          case TypeRef(_, ByNameParamClass, List(elemtp)) =>
            Nil // TODO: check conformance??
          case _ =>
            List(a)
        }
      }
    }

    /** Replaces `ValDef` and `Assign` statements by their right-hand sides; other statements are kept. */
    def transStms(stms: List[Tree]): List[Tree] = stms match {
      case ValDef(mods, name, tpt, rhs)::xs =>
        rhs::transStms(xs)
      case Assign(lhs, rhs)::xs =>
        rhs::transStms(xs)
      case x::xs =>
        x::transStms(xs)
      case Nil =>
        Nil
    }

    /** The only element of `xs`; reports a global error if `xs` does not have exactly one element. */
    def single(xs: List[AnnotationInfo]) = xs match {
      case List(x) => x
      case _ =>
        global.globalError("not a single cps annotation: " + xs)
        xs(0)
    }

    def emptyOrSingleList(xs: List[AnnotationInfo]) = if (xs.isEmpty) Nil else List(single(xs))

    /** Collects the `@cps` annotations of `childTrees` (in order) and merges them into `tpe`
     *  via `updateAttributesFromChildren`.
     */
    def transChildrenInOrder(tree: Tree, tpe: Type, childTrees: List[Tree], byName: List[Tree]) = {
      def inspect(t: Tree): List[AnnotationInfo] = {
        if (t.tpe eq null) Nil else {
          val extra: List[AnnotationInfo] = t.tpe match {
            case _: MethodType | _: PolyType | _: OverloadedType =>
              // method types, poly types and overloaded types do not obtain cps annotations by propagation
              // need to reconstruct transitively from their children.
              t match {
                case Select(qual, name) => inspect(qual)
                case Apply(fun, args) => (fun::(transArgList(fun,args).flatten)) flatMap inspect
                case TypeApply(fun, args) => (fun::(transArgList(fun,args).flatten)) flatMap inspect
                case _ => Nil
              }
            case _ => Nil
          }

          val types = cpsParamAnnotation(t.tpe)
          // TODO: check that it has been adapted and if so correctly
          extra ++ emptyOrSingleList(types)
        }
      }
      val children = childTrees flatMap inspect

      val newtpe = updateAttributesFromChildren(tpe, children, byName)

      if (!newtpe.annotations.isEmpty)
        vprintln("[checker] inferred " + tree + " / " + tpe + " ===> "+ newtpe)

      newtpe
    }

    /** Modify the type that has thus far been inferred
     *  for a tree. All this should do is add annotations. */
    override def pluginsTyped(tpe: Type, typer: Typer, tree: Tree, mode: Mode, pt: Type): Type = {
      if (!cpsEnabled) {
        // the cps annotation classes may be unavailable when the plugin is disabled
        val report = try hasCpsParamTypes(tpe) catch { case _: MissingRequirementError => false }
        if (report)
          reporter.error(tree.pos, "this code must be compiled with the Scala continuations plugin enabled")

        return tpe
      }

      tree match {

        case Apply(fun @ Select(qual, name), args) if fun.isTyped =>

          // HACK: With overloaded methods, fun will never get annotated. This is because
          // the 'overloaded' type gets annotated, but not the alternatives (among which
          // fun's type is chosen)

          vprintln("[checker] checking select apply " + tree + "/" + tpe)

          transChildrenInOrder(tree, tpe, qual::(transArgList(fun, args).flatten), Nil)

        case Apply(TypeApply(fun @ Select(qual, name), targs), args) if fun.isTyped =>
          // NOTE(review): the original comment here read "not trigge", presumably
          // "not triggered" (in practice) -- confirm.

          vprintln("[checker] checking select apply type-apply " + tree + "/" + tpe)

          transChildrenInOrder(tree, tpe, qual::(transArgList(fun, args).flatten), Nil)

        case TypeApply(fun @ Select(qual, name), args) if fun.isTyped =>
          def stripNullaryMethodType(tp: Type) = tp match {
            case NullaryMethodType(restpe) => restpe
            case tp => tp
          }
          vprintln("[checker] checking select type-apply " + tree + "/" + tpe)

          transChildrenInOrder(tree, stripNullaryMethodType(tpe), List(qual, fun), Nil)

        case Apply(fun, args) if fun.isTyped =>

          vprintln("[checker] checking unknown apply " + tree + "/" + tpe)

          transChildrenInOrder(tree, tpe, fun::(transArgList(fun, args).flatten), Nil)

        case TypeApply(fun, args) =>

          vprintln("[checker] checking unknown type apply " + tree + "/" + tpe)

          transChildrenInOrder(tree, tpe, List(fun), Nil)

        case Select(qual, name) if qual.isTyped =>

          vprintln("[checker] checking select " + tree + "/" + tpe)

          // straightforward way is problematic (see select.scala and Test2.scala)
          // transChildrenInOrder(tree, tpe, List(qual), Nil)

          // the problem is that qual may be of type OverloadedType (or MethodType) and
          // we cannot safely annotate these. so we just ignore these cases and
          // clean up later in the Apply/TypeApply trees.

          if (hasCpsParamTypes(qual.tpe)) {
            // however there is one special case:
            // if it's a method without parameters, just apply it. normally done in adapt, but
            // we have to do it here so we don't lose the cps information (wouldn't trigger our
            // adapt and there is no Apply/TypeApply created)
            tpe match {
              case NullaryMethodType(restpe) =>
                transChildrenInOrder(tree, restpe, List(qual), Nil)
              case _ : PolyType => tpe
              case _ : MethodType => tpe
              case _ : OverloadedType => tpe
              case _ =>
                transChildrenInOrder(tree, tpe, List(qual), Nil)
            }
          } else
            tpe

        case If(cond, thenp, elsep) =>
          transChildrenInOrder(tree, tpe, List(cond), List(thenp, elsep))

        case Match(select, cases) =>
          transChildrenInOrder(tree, tpe, List(select), cases:::(cases map { case CaseDef(_, _, body) => body }))

        case Try(block, catches, finalizer) =>
          val tpe1 = transChildrenInOrder(tree, tpe, Nil, block::catches:::(catches map { case CaseDef(_, _, body) => body }))

          val annots = cpsParamAnnotation(tpe1)
          if (annots.nonEmpty) {
            val ann = single(annots)
            val (atp0, atp1) = annTypes(ann)
            if (!(atp0 =:= atp1))
              throw new TypeError("only simple cps types allowed in try/catch blocks (found: " + tpe1 + ")")
            if (!finalizer.isEmpty) // no finalizers allowed. see explanation in SelectiveCPSTransform
              typer.context.error(tree.pos, "try/catch blocks that use continuations cannot have finalizers")
          }
          tpe1

        case Block(stms, expr) =>
          // if any stm has annotation, so does block
          transChildrenInOrder(tree, tpe, transStms(stms), List(expr))

        case ValDef(mods, name, tpt, rhs) =>
          vprintln("[checker] checking valdef " + name + "/"+tpe+"/"+tpt+"/"+tree.symbol.tpe)
          // ValDef symbols must *not* have annotations!
          // lazy vals are currently not supported
          // but if we erase here all annotations, compiler will complain only
          // when generating bytecode.
          // This way lazy vals will be reported as unsupported feature later rather than weird type error.
          if (hasAnswerTypeAnn(tree.symbol.info) && !mods.isLazy) { // is it okay to modify sym here?
            vprintln("removing annotation from sym " + tree.symbol + "/" + tree.symbol.tpe + "/" + tpt)
            tpt modifyType removeAllCPSAnnotations
            tree.symbol modifyInfo removeAllCPSAnnotations
          }
          tpe

        case _ =>
          tpe
      }


    }
  }
}
531 |
--------------------------------------------------------------------------------
/plugin/src/main/scala-2.11/scala/tools/selectivecps/SelectiveANFTransform.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.tools.selectivecps
14 |
15 | import scala.tools.nsc.plugins._
16 | import scala.tools.nsc.symtab._
17 | import scala.tools.nsc.transform._
18 |
19 | /**
20 | * In methods marked @cps, explicitly name results of calls to other @cps methods
21 | */
22 | abstract class SelectiveANFTransform extends PluginComponent with Transform with
23 | TypingTransformers with CPSUtils {
24 | // inherits abstract value `global` and class `Phase` from Transform
25 |
26 | import global._
27 | import definitions._ // methods to type trees
28 |
/** Description string of this phase. */
override def description: String = "ANF pre-transform for @cps"
30 |
// Implementations of the two abstract members inherited through Transform.

/** Name under which this phase is registered. */
val phaseName = "selectiveanf"

/** Returns a fresh ANF transformer for the compilation unit `unit`. */
protected def newTransformer(unit: CompilationUnit): Transformer = {
  new ANFTransformer(unit)
}
36 |
37 | class ANFTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
38 |
// Flag used to detect cps code in places this transform does not handle (yet).
var cpsAllowed = false
40 |
/** Strips `return` from tail positions.
 *
 *  Tail positions followed: the result expression of a `Block`, both branches of an `If`
 *  (its condition is left untouched), the body, catch cases and finalizer of a `Try`, and the
 *  body of a `CaseDef`. A tail `return e` inside a block or a case body becomes `e` as-is. A
 *  `return e` in an `If` branch becomes `e`, which is then transformed further. A `Return`
 *  reached anywhere else is reported as an error and left in place.
 */
object RemoveTailReturnsTransformer extends Transformer {

  /** Unwraps a `return` in an `If` branch and keeps transforming what remains. */
  private def unwrapBranch(branch: Tree): Tree = branch match {
    case Return(inner) => transform(inner)
    case plain         => transform(plain)
  }

  override def transform(tree: Tree): Tree = tree match {
    case Block(stats, last) =>
      val newLast = last match {
        case Return(value) => value
        case other         => transform(other)
      }
      treeCopy.Block(tree, stats, newLast)

    case If(cond, thenp, elsep) =>
      val thenp1 = unwrapBranch(thenp)
      val elsep1 = unwrapBranch(elsep)
      treeCopy.If(tree, cond, thenp1, elsep1)

    case Try(body, handlers, finalizer) =>
      val body1      = transform(body)
      val handlers1  = transformCaseDefs(handlers)
      val finalizer1 = transform(finalizer)
      treeCopy.Try(tree, body1, handlers1, finalizer1)

    case CaseDef(pat, guard, rhs) =>
      val rhs1 = rhs match {
        case Return(value) => value
        case other         => transform(other)
      }
      treeCopy.CaseDef(tree, pat, guard, rhs1)

    case Return(_) =>
      reporter.error(tree.pos, "return expressions in CPS code must be in tail position")
      tree

    case _ =>
      super.transform(tree)
  }
}
81 |
/** Removes tail `return`s from a method body.
 *
 *  A body that is a single `return e` simply becomes `e`; any other body is handed to
 *  `RemoveTailReturnsTransformer`.
 */
def removeTailReturns(body: Tree): Tree = body match {
  case ret: Return => ret.expr
  case other       => RemoveTailReturnsTransformer.transform(other)
}
89 |
90 | override def transform(tree: Tree): Tree = {
91 | if (!cpsEnabled) return tree
92 |
93 | tree match {
94 |
95 | // Maybe we should further generalize the transform and move it over
96 | // to the regular Transformer facility. But then, actual and required cps
97 | // state would need more complicated (stateful!) tracking.
98 |
99 | // Making the default case use transExpr(tree, None, None) instead of
100 | // calling super.transform() would be a start, but at the moment,
101 | // this would cause infinite recursion. But we could remove the
102 | // ValDef case here.
103 |
104 | case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) =>
105 | debuglog("transforming " + dd.symbol)
106 |
107 | atOwner(dd.symbol) {
108 | val rhs =
109 | if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0)
110 | else rhs0
111 | val rhs1 = transExpr(rhs, None, getExternalAnswerTypeAnn(tpt.tpe))(getExternalAnswerTypeAnn(tpt.tpe).isDefined)
112 |
113 | debuglog("result "+rhs1)
114 | debuglog("result is of type "+rhs1.tpe)
115 |
116 | treeCopy.DefDef(dd, mods, name, transformTypeDefs(tparams), transformValDefss(vparamss),
117 | transform(tpt), rhs1)
118 | }
119 |
120 | case ff @ Function(vparams, body) =>
121 | debuglog("transforming anon function " + ff.symbol)
122 |
123 | atOwner(ff.symbol) {
124 |
125 | //val body1 = transExpr(body, None, getExternalAnswerTypeAnn(body.tpe))
126 |
127 | // need to special case partial functions: if expected type is @cps
128 | // but all cases are pure, then we would transform
129 | // { x => x match { case A => ... }} to
130 | // { x => shiftUnit(x match { case A => ... })}
131 | // which Uncurry cannot handle (see function6.scala)
132 | // thus, we push down the shiftUnit to each of the case bodies
133 |
134 | val ext = getExternalAnswerTypeAnn(body.tpe)
135 | val pureBody = getAnswerTypeAnn(body.tpe).isEmpty
136 | implicit val isParentImpure = ext.isDefined
137 |
138 | def transformPureMatch(tree: Tree, selector: Tree, cases: List[CaseDef]) = {
139 | val caseVals = cases map { case cd @ CaseDef(pat, guard, body) =>
140 | // if (!hasPlusMarker(body.tpe)) body modifyType (_ withAnnotation newPlusMarker()) // TODO: to avoid warning
141 | val bodyVal = transExpr(body, None, ext) // ??? triggers "cps-transformed unexpectedly" warning in transTailValue
142 | treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
143 | }
144 | treeCopy.Match(tree, transform(selector), caseVals)
145 | }
146 |
147 | def transformPureVirtMatch(body: Block, selDef: ValDef, cases: List[Tree], matchEnd: Tree) = {
148 | val stats = transform(selDef) :: (cases map (transExpr(_, None, ext)))
149 | treeCopy.Block(body, stats, transExpr(matchEnd, None, ext))
150 | }
151 |
152 | val body1 = body match {
153 | case Match(selector, cases) if ext.isDefined && pureBody =>
154 | transformPureMatch(body, selector, cases)
155 |
156 | // virtpatmat switch
157 | case Block(List(selDef: ValDef), mat@Match(selector, cases)) if ext.isDefined && pureBody =>
158 | treeCopy.Block(body, List(transform(selDef)), transformPureMatch(mat, selector, cases))
159 |
160 | // virtpatmat
161 | case b@Block(matchStats@((selDef: ValDef) :: cases), matchEnd) if ext.isDefined && pureBody && (matchStats forall treeInfo.hasSynthCaseSymbol) =>
162 | transformPureVirtMatch(b, selDef, cases, matchEnd)
163 |
164 | // virtpatmat that stores the scrut separately -- TODO: can we eliminate this case??
165 | case Block(List(selDef0: ValDef), mat@Block(matchStats@((selDef: ValDef) :: cases), matchEnd)) if ext.isDefined && pureBody && (matchStats forall treeInfo.hasSynthCaseSymbol)=>
166 | treeCopy.Block(body, List(transform(selDef0)), transformPureVirtMatch(mat, selDef, cases, matchEnd))
167 |
168 | case _ =>
169 | transExpr(body, None, ext)
170 | }
171 |
172 | debuglog("anf result "+body1+"\nresult is of type "+body1.tpe)
173 |
174 | treeCopy.Function(ff, transformValDefs(vparams), body1)
175 | }
176 |
177 | case vd @ ValDef(mods, name, tpt, rhs) => // object-level valdefs
178 | debuglog("transforming valdef " + vd.symbol)
179 |
180 | if (getExternalAnswerTypeAnn(tpt.tpe).isEmpty) {
181 |
182 | atOwner(vd.symbol) {
183 |
184 | val rhs1 = transExpr(rhs, None, None)
185 |
186 | treeCopy.ValDef(vd, mods, name, transform(tpt), rhs1)
187 | }
188 | } else {
189 | reporter.error(tree.pos, "cps annotations not allowed on by-value parameters or value definitions")
190 | super.transform(tree)
191 | }
192 |
193 | case TypeTree() =>
194 | // circumvent cpsAllowed here
195 | super.transform(tree)
196 |
197 | case Apply(_,_) =>
198 | // this allows reset { ... } in object constructors
199 | // it's kind of a hack to put it here (see note above)
200 | transExpr(tree, None, None)
201 |
202 | case _ =>
203 | if (hasAnswerTypeAnn(tree.tpe)) {
204 | if (!cpsAllowed) {
205 | if (tree.symbol.isLazy)
206 | reporter.error(tree.pos, "implementation restriction: cps annotations not allowed on lazy value definitions")
207 | else
208 | reporter.error(tree.pos, "cps code not allowed here / " + tree.getClass + " / " + tree)
209 | }
210 | log(tree)
211 | }
212 |
213 | cpsAllowed = false
214 | super.transform(tree)
215 | }
216 | }
217 |
218 |
/** ANF-transforms `tree` as a tail expression, starting from cps info `cpsA`
 *  and required to have cps info `cpsR`. Statements hoisted out by the
 *  transformation are wrapped, together with the resulting expression, in a
 *  Block copied from `tree`; when nothing was hoisted the expression is
 *  returned as is. The impurity flag handed down is also set whenever `cpsR`
 *  is defined.
 */
def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean = false): Tree = {
  val impure = cpsR.isDefined || isAnyParentImpure
  val (hoisted, result) = transTailValue(tree, cpsA, cpsR)(impure)
  if (hoisted.isEmpty) result
  else treeCopy.Block(tree, hoisted, result)
}
226 |
227 |
/** ANF-transforms the arguments `args` of an application of `fun`.
 *
 *  Arguments for by-name parameters are transformed as standalone expressions
 *  (their required answer type is taken from the by-name element type) and
 *  contribute no hoisted statements. All other arguments are inlined with
 *  `transInlineValue`, threading the cps info from left to right starting at
 *  `cpsA`. Arguments beyond the formal parameter list are paired with `NoType`.
 *
 *  @return (hoisted statements per argument, transformed arguments,
 *           cps info after the last inlined argument)
 */
def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[List[Tree]], List[Tree], CPSInfo) = {
  val formals = fun.tpe.paramTypes
  val paddedFormals = formals ::: List.fill(args.length - formals.length)(NoType)

  var current: CPSInfo = cpsA
  val perArg: List[(List[Tree], Tree)] = args.zip(paddedFormals) map { case (arg, formal) =>
    formal match {
      case TypeRef(_, ByNameParamClass, List(elemtp)) =>
        // the impurity flag passed here is not just isAnyParentImpure:
        // it is also set when the by-name type itself is cps
        val byNameAns = getAnswerTypeAnn(elemtp)
        (Nil, transExpr(arg, None, byNameAns)(byNameAns.isDefined || isAnyParentImpure))
      case _ =>
        val (argStms, argExpr, argSpc) = transInlineValue(arg, current)
        current = argSpc
        (argStms, argExpr)
    }
  }

  val (stms, exprs) = perArg.unzip
  (stms, exprs, current)
}
248 |
249 |
// precondition: cpsR.isDefined "implies" isAnyParentImpure
/** ANF-transforms `tree` as a value.
 *
 *  `cpsA` is the cps info in effect before `tree`. `cpsR` is the cps info
 *  required of it; in this method it is consulted directly only by the `If`
 *  sanity check and passed on to the `Try` block and catch bodies.
 *  Sub-expressions are inlined via `transInlineValue` / `transArgList`, which
 *  may hoist them into statements.
 */
def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = {
  // return value: (stms, expr, spc), where spc is CPSInfo after stms but *before* expr
  // pos is presumably consumed by the implicit Position parameter of the linearize calls below
  implicit val pos = tree.pos
  tree match {
    case Block(stms, expr) =>
      // the required info for the block body is linearize(cpsA, <block's own answer-type annotation>)
      val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd
      // val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe))

      val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
      val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!!

      // the block is returned whole (no hoisted statements); spc stays cpsA
      (Nil, tree1, cpsA)

    case If(cond, thenp, elsep) =>
      /* possible situations:
        cps before (cpsA)
        cps in condition (spc)  <-- synth flag set if *only* here!
        cps in (one or both) branches */
      val (condStats, condVal, spc) = transInlineValue(cond, cpsA)
      val (cpsA2, cpsR2) = if (hasSynthMarker(tree.tpe))
        (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else
        (None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly
      val thenVal = transExpr(thenp, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
      val elseVal = transExpr(elsep, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)

      // check that then and else parts agree (not necessary any more, but left as sanity check)
      if (cpsR.isDefined) {
        if (elsep == EmptyTree)
          reporter.error(tree.pos, "always need else part in cps code")
      }
      if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) {
        reporter.error(tree.pos, "then and else parts must both be cps code or neither of them")
      }

      (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc)

    case Match(selector, cases) =>
      // same scheme as If: the selector is inlined, and each case body is
      // transformed against the required info derived from the synth marker
      val (selStats, selVal, spc) = transInlineValue(selector, cpsA)
      val (cpsA2, cpsR2) =
        if (hasSynthMarker(tree.tpe)) (spc, linearize(spc, getAnswerTypeAnn(tree.tpe)))
        else (None, getAnswerTypeAnn(tree.tpe))

      val caseVals = cases map { case cd @ CaseDef(pat, guard, body) =>
        val bodyVal = transExpr(body, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
        treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
      }

      (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc)

    // this is utterly broken: LabelDefs need to be considered together when transforming them to DefDefs:
    // suppose a Block {L1; ... ; LN}
    // this should become {D1def ; ... ; DNdef ; D1()}
    // where D$idef = def L$i(..) = {L$i.body; L${i+1}(..)}

    case ldef @ LabelDef(name, params, rhs) =>
      // println("trans LABELDEF "+(name, params, tree.tpe, hasAnswerTypeAnn(tree.tpe)))
      // TODO why does the labeldef's type have a cpsMinus annotation, whereas the rhs does not? (BYVALmode missing/too much somewhere?)
      if (hasAnswerTypeAnn(tree.tpe)) {
        // cps label: the label symbol is turned into a method symbol (LABEL flag
        // cleared, in place) and the label is emitted as a DefDef
        // currentOwner.newMethod(name, tree.pos, Flags.SYNTHETIC) setInfo ldef.symbol.info
        val sym    = ldef.symbol resetFlag Flags.LABEL
        val rhs1   = rhs //new TreeSymSubstituter(List(ldef.symbol), List(sym)).transform(rhs)
        val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe))(getAnswerTypeAnn(tree.tpe).isDefined || isAnyParentImpure) changeOwner (currentOwner -> sym)

        val stm1 = localTyper.typed(DefDef(sym, rhsVal))
        // since virtpatmat does not rely on fall-through, don't call the labels it emits
        // transBlock will take care of calling the first label
        // calling each labeldef is wrong, since some labels may be jumped over
        // we can get away with this for now since the only other labels we emit are for tailcalls/while loops,
        // which do not have consecutive labeldefs (and thus fall-through is irrelevant)
        if (treeInfo.hasSynthCaseSymbol(ldef)) (List(stm1), localTyper.typed{Literal(Constant(()))}, cpsA)
        else {
          assert(params.isEmpty, "problem in ANF transforming label with non-empty params "+ ldef)
          (List(stm1), localTyper.typed{Apply(Ident(sym), List())}, cpsA)
        }
      } else {
        // non-cps label: only its body is transformed, with no cps required
        val rhsVal = transExpr(rhs, None, None)
        (Nil, updateSynthFlag(treeCopy.LabelDef(tree, name, params, rhsVal)), cpsA)
      }


    case Try(block, catches, finalizer) =>
      // block and catch bodies see the same incoming/required info as the whole try
      val blockVal = transExpr(block, cpsA, cpsR)

      val catchVals = for {
        cd @ CaseDef(pat, guard, body) <- catches
        bodyVal = transExpr(body, cpsA, cpsR)
      } yield {
        treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
      }

      val finallyVal = transExpr(finalizer, None, None) // for now, no cps in finally

      (Nil, updateSynthFlag(treeCopy.Try(tree, blockVal, catchVals, finallyVal)), cpsA)

    case Assign(lhs, rhs) =>
      // allow cps code in rhs only
      val (stms, expr, spc) = transInlineValue(rhs, cpsA)
      (stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc)

    case Return(expr0) =>
      // `return` is reported as an error whenever an enclosing expression is impure
      if (isAnyParentImpure)
        reporter.error(tree.pos, "return expression not allowed, since method calls CPS method")
      val (stms, expr, spc) = transInlineValue(expr0, cpsA)
      (stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc)

    case Throw(expr0) =>
      val (stms, expr, spc) = transInlineValue(expr0, cpsA)
      (stms, updateSynthFlag(treeCopy.Throw(tree, expr)), spc)

    case Typed(expr0, tpt) =>
      // TODO: should x: A @cps[B,C] have a special meaning?
      // type casts used in different ways (see match2.scala, #3199)
      val (stms, expr, spc) = transInlineValue(expr0, cpsA)
      // cps annotations are stripped from the ascribed type (except for `x: _*`)
      // and from the Typed node's own type
      val tpt1 = if (treeInfo.isWildcardStarArg(tree)) tpt else
        treeCopy.TypeTree(tpt).setType(removeAllCPSAnnotations(tpt.tpe))
//        (stms, updateSynthFlag(treeCopy.Typed(tree, expr, tpt1)), spc)
      (stms, treeCopy.Typed(tree, expr, tpt1).setType(removeAllCPSAnnotations(tree.tpe)), spc)

    case TypeApply(fun, args) =>
      val (stms, expr, spc) = transInlineValue(fun, cpsA)
      (stms, updateSynthFlag(treeCopy.TypeApply(tree, expr, args)), spc)

    case Select(qual, name) =>
      val (stms, expr, spc) = transInlineValue(qual, cpsA)
      (stms, updateSynthFlag(treeCopy.Select(tree, expr, name)), spc)

    case Apply(fun, args) =>
      // the function part is inlined first, then the arguments left to right
      val (funStm, funExpr, funSpc) = transInlineValue(fun, cpsA)
      val (argStm, argExpr, argSpc) = transArgList(fun, args, funSpc)

      (funStm ::: (argStm.flatten), updateSynthFlag(treeCopy.Apply(tree, funExpr, argExpr)),
        argSpc)

    case _ =>
      // anything else goes through `transform`; cpsAllowed = true keeps its
      // default case from reporting a cps-typed tree reached from here
      cpsAllowed = true
      (Nil, transform(tree), cpsA)
  }
}
389 |
// precondition: cpsR.isDefined "implies" isAnyParentImpure
/** ANF-transforms `tree` in tail position and reconciles the cps info of the
 *  result with the required cps info `cpsR`.
 *
 *  `bot` is linearize(spc, <answer-type annotation of the transformed expression>),
 *  where spc is the info after the hoisted statements. Three outcomes:
 *   - `cpsR` defined, `bot` undefined: a non-empty, non-Nothing result is lifted
 *     with `shiftUnit` (or `shiftUnit0` when the answer types do not conform).
 *     A warning is issued if the result type lacks a plus marker.
 *   - `cpsR` undefined, `bot` defined: "found cps expression in non-cps position"
 *     is reported.
 *   - otherwise: if the result type has a plus marker, a warning is issued and
 *     all cps annotations are removed from it.
 *  On error paths the untransformed `(stms, expr)` is returned.
 *
 *  @return (hoisted statements, tail expression)
 */
def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = {

  val (stms, expr, spc) = transValue(tree, cpsA, cpsR)

  // cps info of the tail expression itself
  val bot = linearize(spc, getAnswerTypeAnn(expr.tpe))(tree.pos)

  val plainTpe = removeAllCPSAnnotations(expr.tpe)

  if (cpsR.isDefined && !bot.isDefined) {

    if (!expr.isEmpty && (expr.tpe.typeSymbol ne NothingClass)) {
      // must convert!
      debuglog("cps type conversion (has: " + cpsA + "/" + spc + "/" + expr.tpe + ")")
      debuglog("cps type conversion (expected: " + cpsR.get + "): " + expr)

      if (!hasPlusMarker(expr.tpe))
        reporter.warning(tree.pos, "expression " + tree + " is cps-transformed unexpectedly")

      try {
        // cpsR.isDefined was checked above, so this match cannot fail
        val Some((a, b)) = cpsR
        /* Since shiftUnit is bounded [A,B,C>:B] this may not typecheck
         * if C is overly specific. So if !(B <:< C), call shiftUnit0
         * instead, which takes only two type arguments.
         */
        val conforms = a <:< b
        val call = localTyper.typedPos(tree.pos)(
          Apply(
            TypeApply(
              gen.mkAttributedRef( if (conforms) MethShiftUnit else MethShiftUnit0 ),
              List(TypeTree(plainTpe), TypeTree(a)) ++ ( if (conforms) List(TypeTree(b)) else Nil )
            ),
            List(expr)
          )
        )
        // This is today's sick/meaningless heuristic for spotting breakdown so
        // we don't proceed until stack traces start draping themselves over everything.
        // If there are wildcard types in the tree and B == Nothing, something went wrong.
        // (I thought WildcardTypes would be enough, but nope.  'reset0 { 0 }' has them.)
        //
        // Code as simple as    reset((_: String).length)
        // will crash meaninglessly without this check.  See SI-3718.
        //
        // TODO - obviously this should be done earlier, differently, or with
        // a more skilled hand.  Most likely, all three.
        if ((b.typeSymbol eq NothingClass) && call.tpe.exists(_ eq WildcardType))
          reporter.error(tree.pos, "cannot cps-transform malformed (possibly in shift/reset placement) expression")
        else
          return ((stms, call))  // success: the lifted call replaces expr
      }
      catch {
        // typing the shiftUnit call failed: report and fall through to the untransformed result
        case ex:TypeError =>
          reporter.error(ex.pos, "cannot cps-transform expression " + tree + ": " + ex.msg)
      }
    }

  } else if (!cpsR.isDefined && bot.isDefined) {
    // error!
    debuglog("cps type error: " + expr)
    //println("cps type error: " + expr + "/" + expr.tpe + "/" + getAnswerTypeAnn(expr.tpe))

    //println(cpsR + "/" + spc + "/" + bot)

    reporter.error(tree.pos, "found cps expression in non-cps position")
  } else {
    // all is well

    if (hasPlusMarker(expr.tpe)) {
      reporter.warning(tree.pos, "expression " + expr + " of type " + expr.tpe + " is not expected to have a cps type")
      expr modifyType removeAllCPSAnnotations
    }

    // TODO: sanity check that types agree
  }

  (stms, expr)
}
467 |
/** ANF-transforms `tree` as a value that is never itself required to be cps.
 *
 *  If the transformed expression still carries an answer-type annotation, it
 *  is bound to a fresh synthetic `tmp` value owned by `currentOwner`. That
 *  symbol has the cps annotations removed from its info and the
 *  `MarkerCPSSym` annotation attached. The binding is appended to the hoisted
 *  statements, and a reference to the temp becomes the value. The returned
 *  cps info is then linearize(spc, <that annotation>). Otherwise the result
 *  of `transValue` is returned unchanged.
 */
def transInlineValue(tree: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = {
  val (prefix, valueExpr, before) = transValue(tree, cpsA, None)
  val answerAnn = getAnswerTypeAnn(valueExpr.tpe)

  if (answerAnn.isEmpty) (prefix, valueExpr, before)
  else {
    val plain = removeAllCPSAnnotations(valueExpr.tpe)
    val tmpSym: Symbol =
      currentOwner
        .newValue(newTermName(unit.fresh.newName("tmp")), tree.pos, Flags.SYNTHETIC)
        .setInfo(plain)
        .setAnnotations(List(AnnotationInfo(MarkerCPSSym.tpe_*, Nil, Nil)))
    // the expression now lives under the new temp's symbol
    valueExpr.changeOwner(currentOwner -> tmpSym)

    val binding = ValDef(tmpSym, valueExpr) setType NoType
    val ref     = Ident(tmpSym) setType plain setPos tree.pos
    (prefix ::: List(binding), ref, linearize(before, answerAnn)(tree.pos))
  }
}
492 |
493 |
494 |
/** ANF-transforms a single block statement, starting from cps info `cpsA`.
 *
 *  A local `ValDef` keeps its definition. Its right-hand side is transformed
 *  under the val's owner, and the hoisted statements are then re-owned to
 *  `currentOwner`. If the new right-hand side is cps, the val's symbol gets
 *  the `MarkerCPSSym` annotation. Any other statement is inlined as a value.
 *
 *  @return (resulting statements, cps info after them)
 */
def transInlineStm(stm: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], CPSInfo) = stm match {
  // TODO: what about DefDefs?
  // TODO: relation to top-level val def?
  // TODO: what about lazy vals?
  case vd @ ValDef(mods, name, tpt, rhs) =>
    val (prefix, newRhs, before) = atOwner(vd.symbol) { transValue(rhs, cpsA, None) }

    // hoisted statements are placed next to the val, so they belong to the enclosing owner
    val reowner = new ChangeOwnerTraverser(vd.symbol, currentOwner)
    prefix foreach reowner.traverse

    // TODO: symbol might already have annotation. Should check conformance
    // TODO: better yet: do without annotations on symbols
    val rhsAnn = getAnswerTypeAnn(newRhs.tpe)
    if (rhsAnn.isDefined)
      vd.symbol setAnnotations List(AnnotationInfo(MarkerCPSSym.tpe_*, Nil, Nil))

    (prefix ::: List(treeCopy.ValDef(vd, mods, name, tpt, newRhs)), linearize(before, rhsAnn)(vd.pos))

  case other =>
    val (prefix, value, before) = transInlineValue(other, cpsA)
    (prefix ::: List(value), linearize(before, getAnswerTypeAnn(value.tpe))(other.pos))
}
522 |
/** ANF-transforms the block `{ stms; expr }`.
 *
 *  Precondition: cpsR.isDefined "implies" isAnyParentImpure.
 *
 *  Each statement is inlined in order with `transInlineStm`, threading the
 *  cps info from `cpsA` onwards. `expr` is then transformed as the tail value
 *  against the required info `cpsR`. Virtpatmat-style matches, whose labels
 *  have already become DefDefs, are re-arranged so that the first label
 *  method is called as the block's result.
 *
 *  @return (transformed statements, transformed result expression)
 */
def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = {
  // Statements are appended to a buffer so the pass stays linear in the
  // number of statements (re-concatenating an accumulator per statement is quadratic).
  val collected = new _root_.scala.collection.mutable.ListBuffer[Tree]
  var currAns: CPSInfo = cpsA
  for (stat <- stms) {
    val (stats, nextAns) = transInlineStm(stat, currAns)
    collected ++= stats
    currAns = nextAns
  }
  val (tailStats, anfExpr) = transTailValue(expr, currAns, cpsR)
  collected ++= tailStats
  val anfStats = collected.toList

  // println("\nanf-block:\n"+ ((stms :+ expr) mkString ("{", "\n", "}")) +"\nBECAME\n"+ ((anfStats :+ anfExpr) mkString ("{", "\n", "}")))
  // println("synth case? "+ (anfStats map (t => (t, t.isDef, treeInfo.hasSynthCaseSymbol(t)))))
  // SUPER UGLY HACK: handle virtpatmat-style matches, whose labels have already been turned into DefDefs
  if (anfStats.nonEmpty && (anfStats forall (t => !t.isDef || treeInfo.hasSynthCaseSymbol(t)))) {
    val (prologue, rest) = (anfStats :+ anfExpr) span (s => !s.isInstanceOf[DefDef]) // find first case
    // println("rest: "+ rest)
    // val (defs, calls) = rest partition (_.isInstanceOf[DefDef])
    if (rest.nonEmpty) {
      // the filter drops the ()'s emitted when transValue encountered a LabelDef
      val stats = prologue ++ (rest filter (_.isInstanceOf[DefDef])).reverse // ++ calls
      // println("REVERSED "+ (stats mkString ("{", "\n", "}")))
      (stats, localTyper.typed{Apply(Ident(rest.head.symbol), List())}) // call first label to kick-start the match
    } else (anfStats, anfExpr)
  } else (anfStats, anfExpr)
}
552 | }
553 | }
554 |
--------------------------------------------------------------------------------
/plugin/src/main/scala-2.12/scala/tools/selectivecps/SelectiveANFTransform.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.tools.selectivecps
14 |
15 | import scala.tools.nsc.transform._
16 | import scala.tools.nsc.symtab._
17 | import scala.tools.nsc.plugins._
18 |
19 | /**
20 | * In methods marked @cps, explicitly name results of calls to other @cps methods
21 | */
22 | abstract class SelectiveANFTransform extends PluginComponent with Transform with
23 | TypingTransformers with CPSUtils {
24 | // inherits abstract value `global` and class `Phase` from Transform
25 |
26 | import global._ // the global environment
27 | import definitions._ // standard classes and methods
28 | import typer.atOwner // methods to type trees
29 |
// Human-readable description of this phase.
override def description = "ANF pre-transform for @cps"

/** the following two members override abstract members in Transform */
// Name under which this phase is registered.
val phaseName: String = "selectiveanf"

// Creates the ANFTransformer that rewrites `unit`.
protected def newTransformer(unit: CompilationUnit): Transformer =
  new ANFTransformer(unit)
37 |
38 | class ANFTransformer(unit: CompilationUnit) extends TypingTransformer(unit) {
39 |
// Set to true by transValue's fallback case right before it re-enters `transform`,
// and reset to false by `transform`'s default case.
var cpsAllowed: Boolean = false // detect cps code in places we do not handle (yet)
41 |
/** Rewrites `return e` in tail position to plain `e`.
 *
 *  It descends through the last expression of blocks, both branches of
 *  conditionals, try blocks, catch cases and finalizers, and case bodies.
 *  A `return` reached any other way (e.g. via `super.transform` into a
 *  non-tail subexpression) is reported as an error and left in place.
 */
object RemoveTailReturnsTransformer extends Transformer {

  // One branch of an `if`: strip a tail `return` and keep transforming what is inside.
  private def stripBranchReturn(branch: Tree): Tree = branch match {
    case Return(inner) => transform(inner)
    case other         => transform(other)
  }

  override def transform(tree: Tree): Tree = tree match {
    case Block(stats, Return(result)) =>
      treeCopy.Block(tree, stats, result)

    case Block(stats, last) =>
      treeCopy.Block(tree, stats, transform(last))

    case If(cond, thenp, elsep) =>
      treeCopy.If(tree, cond, stripBranchReturn(thenp), stripBranchReturn(elsep))

    case Try(block, catches, finalizer) =>
      val newCatches = catches.map(c => transform(c)).asInstanceOf[List[CaseDef]]
      treeCopy.Try(tree, transform(block), newCatches, transform(finalizer))

    case CaseDef(pat, guard, Return(result)) =>
      treeCopy.CaseDef(tree, pat, guard, result)

    case CaseDef(pat, guard, body) =>
      treeCopy.CaseDef(tree, pat, guard, transform(body))

    case Return(_) =>
      reporter.error(tree.pos, "return expressions in CPS code must be in tail position")
      tree

    case _ =>
      super.transform(tree)
  }
}
82 |
/** Removes tail-position `return`s from a method body. A body that is itself
 *  a single `return e` becomes `e`; anything else is handed to
 *  `RemoveTailReturnsTransformer`.
 */
def removeTailReturns(body: Tree): Tree = body match {
  case Return(result) => result
  case other          => RemoveTailReturnsTransformer.transform(other)
}
90 |
/** Entry point of the ANF pre-transform; returns `tree` unchanged when cps is not enabled.
 *
 *  Method bodies, anonymous-function bodies, object-level vals and
 *  applications are routed through `transExpr`. TypeTrees and all other
 *  trees use the default traversal. The default case reports cps-typed trees
 *  found where cps is not handled.
 */
override def transform(tree: Tree): Tree = {
  if (!cpsEnabled) return tree

  tree match {

    // Maybe we should further generalize the transform and move it over
    // to the regular Transformer facility. But then, actual and required cps
    // state would need more complicated (stateful!) tracking.

    // Making the default case use transExpr(tree, None, None) instead of
    // calling super.transform() would be a start, but at the moment,
    // this would cause infinite recursion. But we could remove the
    // ValDef case here.

    case dd @ DefDef(mods, name, tparams, vparamss, tpt, rhs0) =>
      debuglog("transforming " + dd.symbol)

      atOwner(dd.symbol) {
        // tail `return`s are stripped when the declared result type has cps parameter types
        val rhs =
          if (cpsParamTypes(tpt.tpe).nonEmpty) removeTailReturns(rhs0)
          else rhs0
        // the body is required to have the method's external answer type, if any
        val ext = getExternalAnswerTypeAnn(tpt.tpe)
        val rhs1 = transExpr(rhs, None, ext)(ext.isDefined)

        debuglog("result "+rhs1)
        debuglog("result is of type "+rhs1.tpe)

        treeCopy.DefDef(dd, mods, name, transformTypeDefs(tparams), transformValDefss(vparamss),
                    transform(tpt), rhs1)
      }

    case ff @ Function(vparams, body) =>
      debuglog("transforming anon function " + ff.symbol)

      atOwner(ff.symbol) {

        //val body1 = transExpr(body, None, getExternalAnswerTypeAnn(body.tpe))

        // need to special case partial functions: if expected type is @cps
        // but all cases are pure, then we would transform
        // { x => x match { case A => ... }} to
        // { x => shiftUnit(x match { case A => ... })}
        // which Uncurry cannot handle (see function6.scala)
        // thus, we push down the shiftUnit to each of the case bodies

        val ext = getExternalAnswerTypeAnn(body.tpe)
        val pureBody = getAnswerTypeAnn(body.tpe).isEmpty
        implicit val isParentImpure = ext.isDefined

        def transformPureMatch(tree: Tree, selector: Tree, cases: List[CaseDef]) = {
          val caseVals = cases map { case cd @ CaseDef(pat, guard, body) =>
            // if (!hasPlusMarker(body.tpe)) body modifyType (_ withAnnotation newPlusMarker()) // TODO: to avoid warning
            val bodyVal = transExpr(body, None, ext) // ??? triggers "cps-transformed unexpectedly" warning in transTailValue
            treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
          }
          treeCopy.Match(tree, transform(selector), caseVals)
        }

        def transformPureVirtMatch(body: Block, selDef: ValDef, cases: List[Tree], matchEnd: Tree) = {
          val stats = transform(selDef) :: (cases map (transExpr(_, None, ext)))
          treeCopy.Block(body, stats, transExpr(matchEnd, None, ext))
        }

        val body1 = body match {
          case Match(selector, cases) if ext.isDefined && pureBody =>
            transformPureMatch(body, selector, cases)

          // virtpatmat switch
          case Block(List(selDef: ValDef), mat@Match(selector, cases)) if ext.isDefined && pureBody =>
            treeCopy.Block(body, List(transform(selDef)), transformPureMatch(mat, selector, cases))

          // virtpatmat
          case b@Block(matchStats@((selDef: ValDef) :: cases), matchEnd) if ext.isDefined && pureBody && (matchStats forall treeInfo.hasSynthCaseSymbol) =>
            transformPureVirtMatch(b, selDef, cases, matchEnd)

          // virtpatmat that stores the scrut separately -- TODO: can we eliminate this case??
          case Block(List(selDef0: ValDef), mat@Block(matchStats@((selDef: ValDef) :: cases), matchEnd)) if ext.isDefined && pureBody && (matchStats forall treeInfo.hasSynthCaseSymbol)=>
            treeCopy.Block(body, List(transform(selDef0)), transformPureVirtMatch(mat, selDef, cases, matchEnd))

          case _ =>
            transExpr(body, None, ext)
        }

        debuglog("anf result "+body1+"\nresult is of type "+body1.tpe)

        treeCopy.Function(ff, transformValDefs(vparams), body1)
      }

    case vd @ ValDef(mods, name, tpt, rhs) => // object-level valdefs
      debuglog("transforming valdef " + vd.symbol)

      if (getExternalAnswerTypeAnn(tpt.tpe).isEmpty) {

        atOwner(vd.symbol) {

          val rhs1 = transExpr(rhs, None, None)

          treeCopy.ValDef(vd, mods, name, transform(tpt), rhs1)
        }
      } else {
        reporter.error(tree.pos, "cps annotations not allowed on by-value parameters or value definitions")
        super.transform(tree)
      }

    case TypeTree() =>
      // circumvent cpsAllowed here
      super.transform(tree)

    case Apply(_,_) =>
      // this allows reset { ... } in object constructors
      // it's kind of a hack to put it here (see note above)
      transExpr(tree, None, None)

    case _ =>
      if (hasAnswerTypeAnn(tree.tpe)) {
        // Tree#symbol is null for trees that have no symbol field (If, Block,
        // Literal, Typed, ...). Guard it so such trees get the regular
        // diagnostic below instead of crashing the compiler with an NPE.
        val sym = tree.symbol
        if ((sym ne null) && sym.isLazy) {
          reporter.error(tree.pos, "implementation restriction: cps annotations not allowed on lazy value definitions")
          cpsAllowed = false
        } else if (!cpsAllowed)
          reporter.error(tree.pos, "cps code not allowed here / " + tree.getClass + " / " + tree)

        log(tree)
      }

      cpsAllowed = false
      super.transform(tree)
  }
}
218 |
219 |
/** ANF-transforms `tree` as a tail expression, starting from cps info `cpsA`
 *  and required to have cps info `cpsR`. Statements hoisted out by the
 *  transformation are wrapped, together with the resulting expression, in a
 *  Block copied from `tree`; when nothing was hoisted the expression is
 *  returned as is. The impurity flag handed down is also set whenever `cpsR`
 *  is defined.
 */
def transExpr(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean = false): Tree = {
  val impure = cpsR.isDefined || isAnyParentImpure
  val (hoisted, result) = transTailValue(tree, cpsA, cpsR)(impure)
  if (hoisted.isEmpty) result
  else treeCopy.Block(tree, hoisted, result)
}
227 |
228 |
/** ANF-transforms the arguments `args` of an application of `fun`.
 *
 *  Arguments for by-name parameters are transformed as standalone expressions
 *  (their required answer type is taken from the by-name element type) and
 *  contribute no hoisted statements. All other arguments are inlined with
 *  `transInlineValue`, threading the cps info from left to right starting at
 *  `cpsA`. Arguments beyond the formal parameter list are paired with `NoType`.
 *
 *  @return (hoisted statements per argument, transformed arguments,
 *           cps info after the last inlined argument)
 */
def transArgList(fun: Tree, args: List[Tree], cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[List[Tree]], List[Tree], CPSInfo) = {
  val formals = fun.tpe.paramTypes
  val paddedFormals = formals ::: List.fill(args.length - formals.length)(NoType)

  var current: CPSInfo = cpsA
  val perArg: List[(List[Tree], Tree)] = args.zip(paddedFormals) map { case (arg, formal) =>
    formal match {
      case TypeRef(_, ByNameParamClass, List(elemtp)) =>
        // the impurity flag passed here is not just isAnyParentImpure:
        // it is also set when the by-name type itself is cps
        val byNameAns = getAnswerTypeAnn(elemtp)
        (Nil, transExpr(arg, None, byNameAns)(byNameAns.isDefined || isAnyParentImpure))
      case _ =>
        val (argStms, argExpr, argSpc) = transInlineValue(arg, current)
        current = argSpc
        (argStms, argExpr)
    }
  }

  val (stms, exprs) = perArg.unzip
  (stms, exprs, current)
}
249 |
250 |
// precondition: cpsR.isDefined "implies" isAnyParentImpure
/** ANF-transforms `tree` as a value.
 *
 *  `cpsA` is the cps info in effect before `tree`. `cpsR` is the cps info
 *  required of it; in this method it is consulted directly only by the `If`
 *  sanity check and passed on to the `Try` block and catch bodies.
 *  Sub-expressions are inlined via `transInlineValue` / `transArgList`, which
 *  may hoist them into statements.
 */
def transValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = {
  // return value: (stms, expr, spc), where spc is CPSInfo after stms but *before* expr
  // pos is presumably consumed by the implicit Position parameter of the linearize calls below
  implicit val pos = tree.pos
  tree match {
    case Block(stms, expr) =>
      // the required info for the block body is linearize(cpsA, <block's own answer-type annotation>)
      val (cpsA2, cpsR2) = (cpsA, linearize(cpsA, getAnswerTypeAnn(tree.tpe))) // tbd
      // val (cpsA2, cpsR2) = (None, getAnswerTypeAnn(tree.tpe))

      val (a, b) = transBlock(stms, expr, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
      val tree1 = (treeCopy.Block(tree, a, b)) // no updateSynthFlag here!!!

      // the block is returned whole (no hoisted statements); spc stays cpsA
      (Nil, tree1, cpsA)

    case If(cond, thenp, elsep) =>
      /* possible situations:
        cps before (cpsA)
        cps in condition (spc)  <-- synth flag set if *only* here!
        cps in (one or both) branches */
      val (condStats, condVal, spc) = transInlineValue(cond, cpsA)
      val (cpsA2, cpsR2) = if (hasSynthMarker(tree.tpe))
        (spc, linearize(spc, getAnswerTypeAnn(tree.tpe))) else
        (None, getAnswerTypeAnn(tree.tpe)) // if no cps in condition, branches must conform to tree.tpe directly
      val thenVal = transExpr(thenp, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
      val elseVal = transExpr(elsep, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)

      // check that then and else parts agree (not necessary any more, but left as sanity check)
      if (cpsR.isDefined) {
        if (elsep == EmptyTree)
          reporter.error(tree.pos, "always need else part in cps code")
      }
      if (hasAnswerTypeAnn(thenVal.tpe) != hasAnswerTypeAnn(elseVal.tpe)) {
        reporter.error(tree.pos, "then and else parts must both be cps code or neither of them")
      }

      (condStats, updateSynthFlag(treeCopy.If(tree, condVal, thenVal, elseVal)), spc)

    case Match(selector, cases) =>
      // same scheme as If: the selector is inlined, and each case body is
      // transformed against the required info derived from the synth marker
      val (selStats, selVal, spc) = transInlineValue(selector, cpsA)
      val (cpsA2, cpsR2) =
        if (hasSynthMarker(tree.tpe)) (spc, linearize(spc, getAnswerTypeAnn(tree.tpe)))
        else (None, getAnswerTypeAnn(tree.tpe))

      val caseVals = cases map { case cd @ CaseDef(pat, guard, body) =>
        val bodyVal = transExpr(body, cpsA2, cpsR2)(cpsR2.isDefined || isAnyParentImpure)
        treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
      }

      (selStats, updateSynthFlag(treeCopy.Match(tree, selVal, caseVals)), spc)

    // this is utterly broken: LabelDefs need to be considered together when transforming them to DefDefs:
    // suppose a Block {L1; ... ; LN}
    // this should become {D1def ; ... ; DNdef ; D1()}
    // where D$idef = def L$i(..) = {L$i.body; L${i+1}(..)}

    case ldef @ LabelDef(name, params, rhs) =>
      // println("trans LABELDEF "+(name, params, tree.tpe, hasAnswerTypeAnn(tree.tpe)))
      // TODO why does the labeldef's type have a cpsMinus annotation, whereas the rhs does not? (BYVALmode missing/too much somewhere?)
      if (hasAnswerTypeAnn(tree.tpe)) {
        // cps label: the label symbol is turned into a method symbol (LABEL flag
        // cleared, in place) and the label is emitted as a DefDef
        // currentOwner.newMethod(name, tree.pos, Flags.SYNTHETIC) setInfo ldef.symbol.info
        val sym    = ldef.symbol resetFlag Flags.LABEL
        val rhs1   = rhs //new TreeSymSubstituter(List(ldef.symbol), List(sym)).transform(rhs)
        val rhsVal = transExpr(rhs1, None, getAnswerTypeAnn(tree.tpe))(getAnswerTypeAnn(tree.tpe).isDefined || isAnyParentImpure) changeOwner (currentOwner -> sym)

        val stm1 = localTyper.typed(DefDef(sym, rhsVal))
        // since virtpatmat does not rely on fall-through, don't call the labels it emits
        // transBlock will take care of calling the first label
        // calling each labeldef is wrong, since some labels may be jumped over
        // we can get away with this for now since the only other labels we emit are for tailcalls/while loops,
        // which do not have consecutive labeldefs (and thus fall-through is irrelevant)
        if (treeInfo.hasSynthCaseSymbol(ldef)) (List(stm1), localTyper.typed{Literal(Constant(()))}, cpsA)
        else {
          assert(params.isEmpty, "problem in ANF transforming label with non-empty params "+ ldef)
          (List(stm1), localTyper.typed{Apply(Ident(sym), List())}, cpsA)
        }
      } else {
        // non-cps label: only its body is transformed, with no cps required
        val rhsVal = transExpr(rhs, None, None)
        (Nil, updateSynthFlag(treeCopy.LabelDef(tree, name, params, rhsVal)), cpsA)
      }


    case Try(block, catches, finalizer) =>
      // block and catch bodies see the same incoming/required info as the whole try
      val blockVal = transExpr(block, cpsA, cpsR)

      val catchVals = for {
        cd @ CaseDef(pat, guard, body) <- catches
        bodyVal = transExpr(body, cpsA, cpsR)
      } yield {
        treeCopy.CaseDef(cd, transform(pat), transform(guard), bodyVal)
      }

      val finallyVal = transExpr(finalizer, None, None) // for now, no cps in finally

      (Nil, updateSynthFlag(treeCopy.Try(tree, blockVal, catchVals, finallyVal)), cpsA)

    case Assign(lhs, rhs) =>
      // allow cps code in rhs only
      val (stms, expr, spc) = transInlineValue(rhs, cpsA)
      (stms, updateSynthFlag(treeCopy.Assign(tree, transform(lhs), expr)), spc)

    case Return(expr0) =>
      // `return` is reported as an error whenever an enclosing expression is impure
      if (isAnyParentImpure)
        reporter.error(tree.pos, "return expression not allowed, since method calls CPS method")
      val (stms, expr, spc) = transInlineValue(expr0, cpsA)
      (stms, updateSynthFlag(treeCopy.Return(tree, expr)), spc)

    case Throw(expr0) =>
      val (stms, expr, spc) = transInlineValue(expr0, cpsA)
      (stms, updateSynthFlag(treeCopy.Throw(tree, expr)), spc)

    case Typed(expr0, tpt) =>
      // TODO: should x: A @cps[B,C] have a special meaning?
      // type casts used in different ways (see match2.scala, #3199)
      val (stms, expr, spc) = transInlineValue(expr0, cpsA)
      // cps annotations are stripped from the ascribed type (except for `x: _*`)
      // and from the Typed node's own type
      val tpt1 = if (treeInfo.isWildcardStarArg(tree)) tpt else
        treeCopy.TypeTree(tpt).setType(removeAllCPSAnnotations(tpt.tpe))
//        (stms, updateSynthFlag(treeCopy.Typed(tree, expr, tpt1)), spc)
      (stms, treeCopy.Typed(tree, expr, tpt1).setType(removeAllCPSAnnotations(tree.tpe)), spc)

    case TypeApply(fun, args) =>
      val (stms, expr, spc) = transInlineValue(fun, cpsA)
      (stms, updateSynthFlag(treeCopy.TypeApply(tree, expr, args)), spc)

    case Select(qual, name) =>
      val (stms, expr, spc) = transInlineValue(qual, cpsA)
      (stms, updateSynthFlag(treeCopy.Select(tree, expr, name)), spc)

    case Apply(fun, args) =>
      // the function part is inlined first, then the arguments left to right
      val (funStm, funExpr, funSpc) = transInlineValue(fun, cpsA)
      val (argStm, argExpr, argSpc) = transArgList(fun, args, funSpc)

      (funStm ::: (argStm.flatten), updateSynthFlag(treeCopy.Apply(tree, funExpr, argExpr)),
        argSpc)

    case _ =>
      // anything else goes through `transform`; cpsAllowed = true marks that the
      // cps-typed tree reached from here is in a position this pass handles
      cpsAllowed = true
      (Nil, transform(tree), cpsA)
  }
}
390 |
// Transforms an expression that sits in tail (result) position of a block.
// Returns the statements hoisted out of `tree` together with the final expression.
// - cpsR defined, but the linearized answer type `bot` empty: a non-empty expression whose
//   type is not Nothing is wrapped in a call to shiftUnit (or shiftUnit0 when `a <:< b`
//   does not hold) so that it acquires the expected CPS type.
// - cpsR empty, but `bot` defined: reports "found cps expression in non-cps position".
// - otherwise: a plus-marked expression type triggers a warning and its CPS annotations
//   are stripped in place.
// When a conversion fails or an error is reported, the untransformed (stms, expr) is returned.
391 | // precondition: cpsR.isDefined "implies" isAnyParentImpure
392 | def transTailValue(tree: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = {
393 |
394 | val (stms, expr, spc) = transValue(tree, cpsA, cpsR)
395 |
// Combine the answer-type info accumulated so far with the one carried by expr's own type.
396 | val bot = linearize(spc, getAnswerTypeAnn(expr.tpe))(tree.pos)
397 |
398 | val plainTpe = removeAllCPSAnnotations(expr.tpe)
399 |
// Case 1: a CPS result is required but the expression does not carry one.
400 | if (cpsR.isDefined && !bot.isDefined) {
401 |
402 | if (!expr.isEmpty && (expr.tpe.typeSymbol ne NothingClass)) {
403 | // must convert!
404 | debuglog("cps type conversion (has: " + cpsA + "/" + spc + "/" + expr.tpe + ")")
405 | debuglog("cps type conversion (expected: " + cpsR.get + "): " + expr)
406 |
407 | if (!hasPlusMarker(expr.tpe))
408 | reporter.warning(tree.pos, "expression " + tree + " is cps-transformed unexpectedly")
409 |
410 | try {
411 | val Some((a, b)) = cpsR
412 | /* Since shiftUnit is bounded [A,B,C>:B] this may not typecheck
413 | * if C is overly specific. So if !(B <:< C), call shiftUnit0
414 | * instead, which takes only two type arguments.
415 | */
416 | val conforms = a <:< b
417 | val call = localTyper.typedPos(tree.pos)(
418 | Apply(
419 | TypeApply(
420 | gen.mkAttributedRef( if (conforms) MethShiftUnit else MethShiftUnit0 ),
421 | List(TypeTree(plainTpe), TypeTree(a)) ++ ( if (conforms) List(TypeTree(b)) else Nil )
422 | ),
423 | List(expr)
424 | )
425 | )
426 | // This is today's sick/meaningless heuristic for spotting breakdown so
427 | // we don't proceed until stack traces start draping themselves over everything.
428 | // If there are wildcard types in the tree and B == Nothing, something went wrong.
429 | // (I thought WildcardTypes would be enough, but nope. 'reset0 { 0 }' has them.)
430 | //
431 | // Code as simple as reset((_: String).length)
432 | // will crash meaninglessly without this check. See SI-3718.
433 | //
434 | // TODO - obviously this should be done earlier, differently, or with
435 | // a more skilled hand. Most likely, all three.
// On success the converted call is returned immediately; after an error, control falls
// through to the final (stms, expr) below.
436 | if ((b.typeSymbol eq NothingClass) && call.tpe.exists(_ eq WildcardType))
437 | reporter.error(tree.pos, "cannot cps-transform malformed (possibly in shift/reset placement) expression")
438 | else
439 | return ((stms, call))
440 | }
441 | catch {
442 | case ex:TypeError =>
443 | reporter.error(ex.pos, "cannot cps-transform expression " + tree + ": " + ex.msg)
444 | }
445 | }
446 |
// Case 2: the expression carries a CPS type where none is allowed.
447 | } else if (!cpsR.isDefined && bot.isDefined) {
448 | // error!
449 | debuglog("cps type error: " + expr)
450 | //println("cps type error: " + expr + "/" + expr.tpe + "/" + getAnswerTypeAnn(expr.tpe))
451 |
452 | //println(cpsR + "/" + spc + "/" + bot)
453 |
454 | reporter.error(tree.pos, "found cps expression in non-cps position")
455 | } else {
456 | // all is well
457 |
// A plus marker is not expected here: warn, then strip the CPS annotations from expr's type in place.
458 | if (hasPlusMarker(expr.tpe)) {
459 | reporter.warning(tree.pos, "expression " + expr + " of type " + expr.tpe + " is not expected to have a cps type")
460 | expr modifyType removeAllCPSAnnotations
461 | }
462 |
463 | // TODO: sanity check that types agree
464 | }
465 |
466 | (stms, expr)
467 | }
468 |
/** Transforms an expression that appears in a non-tail ("inline") position.
 *
 *  The expression itself is never required to be CPS (`cpsR` is `None`). If its transformed
 *  type still carries an answer-type annotation, the value is bound to a fresh synthetic
 *  `tmp` val, marked with the CPS marker annotation and owned by `currentOwner`. The caller
 *  then receives a plain reference to that val, and the answer-type info is linearized.
 *
 *  @return (statements to emit before the value, the value expression, resulting CPS info)
 */
def transInlineValue(tree: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree, CPSInfo) = {
  val (prefix, valueExpr, incomingSpc) = transValue(tree, cpsA, None) // never required to be cps
  val answerAnn = getAnswerTypeAnn(valueExpr.tpe)

  if (answerAnn.isEmpty) {
    (prefix, valueExpr, incomingSpc)
  } else {
    val plainTpe = removeAllCPSAnnotations(valueExpr.tpe)
    val tmpName = newTermName(unit.fresh.newName("tmp"))
    val tmpSym: Symbol =
      currentOwner.newValue(tmpName, tree.pos, Flags.SYNTHETIC)
        .setInfo(plainTpe)
        .setAnnotations(List(AnnotationInfo(MarkerCPSSym.tpe_*, Nil, Nil)))
    // Symbols defined inside the value now live under the temporary val.
    valueExpr.changeOwner(currentOwner -> tmpSym)

    val binding = ValDef(tmpSym, valueExpr) setType NoType
    val reference = Ident(tmpSym) setType plainTpe setPos tree.pos
    (prefix :+ binding, reference, linearize(incomingSpc, answerAnn)(tree.pos))
  }
}
493 |
494 |
495 |
/** Transforms one non-final statement of a block.
 *
 *  A `ValDef` has its right-hand side transformed under the val's own symbol. Any statements
 *  hoisted out of it are re-owned to `currentOwner`, because they are emitted next to the
 *  val rather than inside it. The val symbol gets the CPS marker annotation when the
 *  right-hand side still has an answer type. Any other statement is handled like an inline value.
 *
 *  @return (statements replacing `stm`, CPS info after this statement)
 */
def transInlineStm(stm: Tree, cpsA: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], CPSInfo) = stm match {

  // TODO: what about DefDefs?
  // TODO: relation to top-level val def?
  // TODO: what about lazy vals?

  case vd @ ValDef(mods, name, tpt, rhs) =>
    val valSym = vd.symbol
    val (prelude, rhsExpr, rhsSpc) = atOwner(valSym) { transValue(rhs, cpsA, None) }

    val reowner = new ChangeOwnerTraverser(valSym, currentOwner)
    for (hoisted <- prelude) reowner.traverse(hoisted)

    // TODO: symbol might already have annotation. Should check conformance
    // TODO: better yet: do without annotations on symbols
    val rhsAnn = getAnswerTypeAnn(rhsExpr.tpe)
    if (rhsAnn.isDefined)
      valSym setAnnotations List(AnnotationInfo(MarkerCPSSym.tpe_*, Nil, Nil))

    (prelude :+ treeCopy.ValDef(vd, mods, name, tpt, rhsExpr), linearize(rhsSpc, rhsAnn)(vd.pos))

  case other =>
    val (prelude, valueExpr, valueSpc) = transInlineValue(other, cpsA)
    val exprAnn = getAnswerTypeAnn(valueExpr.tpe)
    (prelude :+ valueExpr, linearize(valueSpc, exprAnn)(other.pos))
}
523 |
// precondition: cpsR.isDefined "implies" isAnyParentImpure
/** ANF-transforms the block `{ stms; expr }`.
 *
 *  Statements are transformed left to right with `transInlineStm`. The CPS info each one
 *  yields is threaded into the next, and `expr` is transformed last with `transTailValue`
 *  against the expected result `cpsR`.
 *
 *  @param stms the block's statements
 *  @param expr the block's result expression
 *  @param cpsA CPS info in effect on entry to the block
 *  @param cpsR expected CPS info of the block's result (None when the result must not be CPS)
 *  @return the flattened transformed statements and the transformed result expression
 */
def transBlock(stms: List[Tree], expr: Tree, cpsA: CPSInfo, cpsR: CPSInfo)(implicit isAnyParentImpure: Boolean): (List[Tree], Tree) = {
  // `revDone` holds the statements emitted so far in REVERSE order. Each step then costs only
  // the size of that step's output, instead of re-copying everything accumulated so far
  // (the former `accum ++ stats` made long blocks quadratic). It is reversed exactly once,
  // when the tail expression is reached.
  def rec(currStats: List[Tree], currAns: CPSInfo, revDone: List[Tree]): (List[Tree], Tree) =
    currStats match {
      case Nil =>
        val (anfStats, anfExpr) = transTailValue(expr, currAns, cpsR)
        (revDone reverse_::: anfStats, anfExpr)

      case stat :: rest =>
        val (stats, nextAns) = transInlineStm(stat, currAns)
        rec(rest, nextAns, stats reverse_::: revDone)
    }

  val (anfStats, anfExpr) = rec(stms, cpsA, Nil)
  // SUPER UGLY HACK: handle virtpatmat-style matches, whose labels have already been turned into DefDefs.
  // This applies only when every definition among the statements carries a synthetic case symbol.
  // The statements up to the first DefDef are kept as a prologue. They are followed by the DefDefs
  // from the remainder in reverse order, and the result becomes a call to the first DefDef.
  if (anfStats.nonEmpty && (anfStats forall (t => !t.isDef || treeInfo.hasSynthCaseSymbol(t)))) {
    val (prologue, rest) = (anfStats :+ anfExpr) span (s => !s.isInstanceOf[DefDef]) // find first case
    if (rest.nonEmpty) {
      // the filter drops the ()'s emitted when transValue encountered a LabelDef
      val stats = prologue ++ (rest filter (_.isInstanceOf[DefDef])).reverse
      (stats, localTyper.typed{Apply(Ident(rest.head.symbol), List())}) // call first label to kick-start the match
    } else (anfStats, anfExpr)
  } else (anfStats, anfExpr)
}
553 | }
554 | }
555 |
--------------------------------------------------------------------------------
/library/src/test/scala/scala/tools/selectivecps/TestSuite.scala:
--------------------------------------------------------------------------------
1 | /*
2 | * Scala (https://www.scala-lang.org)
3 | *
4 | * Copyright EPFL and Lightbend, Inc.
5 | *
6 | * Licensed under Apache License 2.0
7 | * (http://www.apache.org/licenses/LICENSE-2.0).
8 | *
9 | * See the NOTICE file distributed with this work for
10 | * additional information regarding copyright ownership.
11 | */
12 |
13 | package scala.tools.selectivecps
14 |
15 | import org.junit.Test
16 | import org.junit.Assert.assertEquals
17 |
18 | import scala.annotation._
19 | import scala.collection.Seq
20 | import scala.collection.generic.CanBuildFrom
21 | import scala.language.{ implicitConversions, higherKinds }
22 | import scala.util.continuations._
23 | import scala.util.control.Exception
24 |
// Shift/reset through plain methods, operator operands, and function values typed `() => Int @cps[Int]`.
25 | class Functions {
// Shift as the LEFT operand of `*`: the captured continuation is `_ * 2`, so k(k(7)) == 28.
26 | def m0() = {
27 | shift((k: Int => Int) => k(k(7))) * 2
28 | }
29 |
// Same as m0 with the shift as the RIGHT operand: the continuation is `2 * _`, again 28.
30 | def m1() = {
31 | 2 * shift((k: Int => Int) => k(k(7)))
32 | }
33 |
34 | @Test def basics = {
35 |
36 | assertEquals(28, reset(m0()))
37 | assertEquals(28, reset(m1()))
38 | }
39 |
// A shifting lambda whose type is inferred, then assigned to an explicitly CPS-typed function val.
40 | @Test def function1 = {
41 |
42 | val f = () => shift { k: (Int => Int) => k(7) }
43 | val g: () => Int @cps[Int] = f
44 |
45 | assertEquals(7, reset(g()))
46 | }
47 |
// The same shifting lambda, typed directly against its expected CPS function type.
48 | @Test def function4 = {
49 |
50 | val g: () => Int @cps[Int] = () => shift { k: (Int => Int) => k(7) }
51 |
52 | assertEquals(7, reset(g()))
53 | }
54 |
// A lambda that does not shift, accepted where a CPS function type is expected.
55 | @Test def function5 = {
56 |
57 | val g: () => Int @cps[Int] = () => 7
58 |
59 | assertEquals(7, reset(g()))
60 | }
61 |
// A PartialFunction with a CPS result type, built from a case literal that does not shift.
62 | @Test def function6 = {
63 |
64 | val g: PartialFunction[Int, Int @cps[Int]] = { case x => 7 }
65 |
66 | assertEquals(7, reset(g(2)))
67 |
68 | }
69 |
70 | }
71 |
// CPS transformation of if/else: one branch shifts, both branches shift, or the condition itself shifts.
72 | class IfThenElse {
// Output sink for ifelse2; cleared at the start of that test.
73 | val out = new StringBuilder; def printOut(x: Any): Unit = out ++= x.toString
74 |
// Both branches shift. With the continuation `1 + _`, test(7) gives 10 and test(8) gives 9.
75 | def test(x: Int) = if (x <= 7)
76 | shift { k: (Int => Int) => k(k(k(x))) }
77 | else
78 | shift { k: (Int => Int) => k(x) }
79 |
80 | @Test def ifelse0 = {
81 | assertEquals(10, reset(1 + test(7)))
82 | assertEquals(9, reset(1 + test(8)))
83 | }
84 |
// Only the then-branch shifts; the else-branch is a plain Int.
85 | def test1(x: Int) = if (x <= 7)
86 | shift { k: (Int => Int) => k(k(k(x))) }
87 | else
88 | x
89 |
// Mirror of test1: only the else-branch shifts.
90 | def test2(x: Int) = if (x <= 7)
91 | x
92 | else
93 | shift { k: (Int => Int) => k(k(k(x))) }
94 |
95 | @Test def ifelse1 = {
96 | assertEquals(10, reset(1 + test1(7)))
97 | assertEquals(9, reset(1 + test1(8)))
98 | assertEquals(8, reset(1 + test2(7)))
99 | assertEquals(11, reset(1 + test2(8)))
100 | }
101 |
// One-armed `if` whose shift never calls `k`, so for x <= 7 the rest of the enclosing reset is skipped.
102 | def test3(x: Int) = if (x <= 7)
103 | shift { k: (Unit => Unit) => printOut("abort") }
104 |
// Each reset yields (), which printOut renders as "()": "abort" + "()" then "alive" + "()".
105 | @Test def ifelse2 = {
106 | out.clear()
107 | printOut(reset { test3(7); printOut("alive") })
108 | printOut(reset { test3(8); printOut("alive") })
109 | assertEquals("abort()alive()", out.toString)
110 | }
111 |
// Shifts inside the if-condition.
112 | def util(x: Boolean) = shift { k: (Boolean => Int) => k(x) }
113 |
114 | def test4(x: Int) = if (util(x <= 7))
115 | x - 1
116 | else
117 | x + 1
118 |
119 | @Test def ifelse3 = {
120 | assertEquals(6, reset(test4(7)))
121 | assertEquals(9, reset(test4(8)))
122 | }
123 |
// Shift that applies its continuation three times; used by testA..testD below.
124 | def sh(x1: Int) = shift((k: Int => Int) => k(k(k(x1))))
125 |
// testA..testD: a shifting statement before a final `if`, with shifts in the condition and/or branches.
126 | def testA(x1: Int): Int @cps[Int] = {
127 | sh(x1)
128 | if (x1 == 42) x1 else sh(x1)
129 | }
130 |
131 | def testB(x1: Int): Int @cps[Int] = {
132 | if (sh(x1) == 43) x1 else x1
133 | }
134 |
135 | def testC(x1: Int): Int @cps[Int] = {
136 | sh(x1)
137 | if (sh(x1) == 44) x1 else x1
138 | }
139 |
140 | def testD(x1: Int): Int @cps[Int] = {
141 | sh(x1)
142 | if (sh(x1) == 45) x1 else sh(x1)
143 | }
144 |
145 | @Test def ifelse4 = {
146 | assertEquals(10, reset(1 + testA(7)))
147 | assertEquals(10, reset(1 + testB(9)))
148 | assertEquals(10, reset(1 + testC(9)))
149 | assertEquals(10, reset(1 + testD(7)))
150 | }
151 | }
152 |
// Inference of CPS answer-type annotations at various expression shapes (mostly compile-level checks).
153 | class Inference {
154 |
// Local re-declaration of shift/reset over a custom TypeConstraint annotation. The bodies are `???`,
// so this object only has to type-check; nothing in this class calls A.m1.
155 | object A {
156 | class foo[-B, +C] extends StaticAnnotation with TypeConstraint
157 |
158 | def shift[A, B, C](fun: (A => B) => C): A @foo[B, C] = ???
159 | def reset[A, C](ctx: => (A @foo[A, C])): C = ???
160 |
161 | def m1 = reset { shift { f: (Int => Range) => f(5) }.to(10) }
162 | }
163 |
// The same shapes using the real shift/reset. x1..x5 are asserted in implicit_infer_annotations;
// m1 and m2 are not asserted on.
164 | object B {
165 |
166 | def m1 = reset { shift { f: (Int => Range) => f(5) }.to(10) }
167 | def m2 = reset { val a = shift { f: (Int => Range) => f(5) }; a.to(10) }
168 |
169 | val x1 = reset {
170 | shift { cont: (Int => Range) =>
171 | cont(5)
172 | }.to(10)
173 | }
174 |
175 | val x2 = reset {
176 | val a = shift { cont: (Int => Range) =>
177 | cont(5)
178 | }
179 | a.to(10)
180 | } // x2 is now Range(5, 6, 7, 8, 9, 10)
181 |
182 | val x3 = reset {
183 | shift { cont: (Int => Int) =>
184 | cont(5)
185 | } + 10
186 | } // x3 is now 15
187 |
188 | val x4 = reset {
189 | 10 :: shift { cont: (List[Int] => List[Int]) =>
190 | cont(List(1, 2, 3))
191 | }
192 | } // x4 is List(10, 1, 2, 3)
193 |
194 | val x5 = reset {
195 | new scala.runtime.RichInt(shift { cont: (Int => Range) =>
196 | cont(5)
197 | }) to 10
198 | }
199 | }
200 |
201 | @Test def implicit_infer_annotations = {
202 | import B._
203 | assertEquals(5 to 10, x1)
204 | assertEquals(5 to 10, x2)
205 | assertEquals(15, x3)
206 | assertEquals(List(10, 1, 2, 3), x4)
207 | assertEquals(5 to 10, x5)
208 | }
209 |
// By-name CPS parameter that is never evaluated (the body is the constant 7), so the calls in
// infer1 only check that their arguments' annotations are inferred and type-check.
210 | def test(x: => Int @cpsParam[String, Int]) = 7
211 |
212 | def test2() = {
213 | val x = shift { k: (Int => String) => 9 }
214 | x
215 | }
216 |
// Like `test`, but with an answer type of Int; also never evaluates its argument.
217 | def test3(x: => Int @cpsParam[Int, Int]) = 7
218 |
219 | def util() = shift { k: (String => String) => "7" }
220 |
// No assertions: the test succeeds if these blocks type-check and the calls run.
221 | @Test def infer1: Unit = {
222 | test { shift { k: (Int => String) => 9 } }
223 | test { shift { k: (Int => String) => 9 }; 2 }
224 | // test { shift { k: (Int => String) => 9 }; util() } <-- doesn't work
225 | test { shift { k: (Int => String) => 9 }; util(); 2 }
226 |
227 | test { shift { k: (Int => String) => 9 }; { test3(0); 2 } }
228 |
229 | test3 { { test3(0); 2 } }
230 |
231 | }
232 |
233 | }
234 |
/** CPS transformation of `match` expressions and of pattern-bound vals. */
class PatternMatching {

  // Every case shifts, so the match as a whole has a CPS type.
  def test(x: Int) = x match {
    case 7 => shift { k: (Int => Int) => k(k(k(x))) }
    case 8 => shift { k: (Int => Int) => k(x) }
  }

  @Test def match0 = {
    assertEquals(10, reset(1 + test(7)))
    assertEquals(9, reset(1 + test(8)))
  }

  // Only the first case shifts; the wildcard case yields a plain Int.
  def test1(x: Int) = x match {
    case 7 => shift { k: (Int => Int) => k(k(k(x))) }
    case _ => x
  }

  // Exercises test1, a match that mixes CPS and plain cases. This test used to call `test`
  // again, which duplicated match0 and left test1 unexercised. The expected values do not
  // change: for 7 the continuation `1 + _` is applied three times (10), and for 8 the plain
  // value 8 flows into `1 + _` (9).
  @Test def match1 = {
    assertEquals(10, reset(1 + test1(7)))
    assertEquals(9, reset(1 + test1(8)))
  }

  // Pattern-bound vals whose right-hand side shifts: a tuple pattern ...
  def test2() = {
    val (a, b) = shift { k: (((String, String)) => String) => k("A", "B") }
    b
  }

  case class Elem[T, U](a: T, b: U)

  // ... and a case-class extractor pattern.
  def test3() = {
    val Elem(a, b) = shift { k: (Elem[String, String] => String) => k(Elem("A", "B")) }
    b
  }

  @Test def match2 = {
    assertEquals("B", reset(test2()))
    assertEquals("B", reset(test3()))
  }

  // Shift that applies its continuation three times.
  def sh(x1: Int) = shift((k: Int => Int) => k(k(k(x1))))

  // Option-based nested ifs with shifts in the success paths; presumably a hand-written
  // mimic of the code emitted by the virtualized pattern matcher (see the test name `patvirt`).
  def testv(x1: Int) = {
    val o7 = {
      val o6 = {
        val o3 =
          if (7 == x1) Some(x1)
          else None

        if (o3.isEmpty) None
        else Some(sh(x1))
      }
      if (o6.isEmpty) {
        val o5 =
          if (8 == x1) Some(x1)
          else None

        if (o5.isEmpty) None
        else Some(sh(x1))
      } else o6
    }
    o7.get
  }

  @Test def patvirt = {
    assertEquals(10, reset(1 + testv(7)))
    assertEquals(11, reset(1 + testv(8)))
  }

  // Compile-level repro: matches on CPS-typed scrutinees. z1673 only instantiates it.
  class MatchRepro {
    def s: String @cps[Any] = shift { k => k("foo") }

    def p = {
      val k = s
      s match { case lit0 => }
    }

    def q = {
      val k = s
      k match { case lit1 => }
    }

    def r = {
      s match { case "FOO" => }
    }

    def t = {
      val k = s
      k match { case "FOO" => }
    }
  }

  @Test def z1673 = {
    val m = new MatchRepro
    ()
  }
}
331 |
// A CPS `if` inside an abstract CPS method driven through `run`, plus regressions for a lambda inside reset and `reset0`.
332 | class IfReturn {
// Expected trace of the original print-based version of shift_pct. It matches the format of the
// commented-out print calls below; it is kept for reference only and nothing checks it.
333 | // d = 1, d2 = 1.0, pct = 1.000
334 | // d = 2, d2 = 4.0, pct = 0.500
335 | // d = 3, d2 = 9.0, pct = 0.333
336 | // d = 4, d2 = 16.0, pct = 0.250
337 | // d = 5, d2 = 25.0, pct = 0.200
338 | // d = 6, d2 = 36.0, pct = 0.167
339 | // d = 7, d2 = 49.0, pct = 0.143
340 | // d = 8, d2 = 64.0, pct = 0.125
341 | // d = 9, d2 = 81.0, pct = 0.111
342 | // d = 10, d2 = 100.0, pct = 0.100
343 | // d = 11, d2 = 121.0, pct = 0.091
344 | // d = 12, d2 = 144.0, pct = 0.083
345 | // d = 13, d2 = 169.0, pct = 0.077
346 | // d = 14, d2 = 196.0, pct = 0.071
347 | // d = 15, d2 = 225.0, pct = 0.067
348 | // d = 16, d2 = 256.0, pct = 0.063
349 | // d = 17, d2 = 289.0, pct = 0.059
350 | // d = 18, d2 = 324.0, pct = 0.056
351 | // d = 19, d2 = 361.0, pct = 0.053
352 | // d = 20, d2 = 400.0, pct = 0.050
353 | // d = 21, d2 = 441.0, pct = 0.048
354 | // d = 22, d2 = 484.0, pct = 0.045
355 | // d = 23, d2 = 529.0, pct = 0.043
356 | // d = 24, d2 = 576.0, pct = 0.042
357 | // d = 25, d2 = 625.0, pct = 0.040
358 |
// s1 and s2 are abstract CPS values; p computes a ratio through a CPS `if`.
359 | abstract class IfReturnRepro {
360 | def s1: Double @cpsParam[Any, Unit]
361 | def s2: Double @cpsParam[Any, Unit]
362 |
// The assertion evaluates s1 and s2 again (each is a shift) to recompute the expected ratio.
363 | def p(i: Int): Double @cpsParam[Unit, Any] = {
364 | val px = s1
365 | val pct = if (px > 100) px else px / s2
366 | //printOut("pct = %.3f".format(pct))
367 | assertEquals(s1 / s2, pct, 0.0001)
368 | pct
369 | }
370 | }
371 |
// Feeds d and d * d through the shifts for d = 1..25 and runs p via `run`.
372 | @Test def shift_pct = {
373 | var d: Double = 0d
374 | def d2 = d * d
375 |
376 | val irr = new IfReturnRepro {
377 | def s1 = shift(f => f(d))
378 | def s2 = shift(f => f(d2))
379 | }
380 | 1 to 25 foreach { i =>
381 | d = i
382 | // print("d = " + i + ", d2 = " + d2 + ", ")
383 | assertEquals(i.toDouble * i, d2, 0.1)
384 | run(irr p i)
385 | }
386 | }
387 |
// A non-shifting lambda with a CPS function type, called inside reset.
388 | @Test def t1807 = {
389 | val z = reset {
390 | val f: (() => Int @cps[Int]) = () => 1
391 | f()
392 | }
393 | assertEquals(1, z)
394 | }
395 |
// Only checks that `reset0 { 0 }` compiles and runs.
396 | @Test def t1808: Unit = {
397 | reset0 { 0 }
398 | }
399 | }
400 |
// Tests using the `@suspendable` shorthand annotation.
401 | class Suspendable {
// A suspendable Unit whose shift never calls its continuation.
402 | def shifted: Unit @suspendable = shift { (k: Unit => Unit) => () }
// One-armed `if` on a by-name condition whose then-branch is suspendable.
403 | def test1(b: => Boolean) = {
404 | reset {
405 | if (b) shifted
406 | }
407 | }
408 | @Test def t1820 = test1(true)
409 |
// Identity with a suspendable result type (no shift).
410 | def suspended[A](x: A): A @suspendable = x
// Overloads test1 above. `test1(())` in t1821 selects this generic version because `()` is not a Boolean.
411 | def test1[A](x: A): A @suspendable = suspended(x) match { case x => x }
412 | def test2[A](x: List[A]): A @suspendable = suspended(x) match { case List(x) => x }
413 |
// Same matches, but on a plain scrutinee while the result type is suspendable.
414 | def test3[A](x: A): A @suspendable = x match { case x => x }
415 | def test4[A](x: List[A]): A @suspendable = x match { case List(x) => x }
416 |
417 | @Test def t1821: Unit = {
418 | assertEquals((), reset(test1(())))
419 | assertEquals((), reset(test2(List(()))))
420 | assertEquals((), reset(test3(())))
421 | assertEquals((), reset(test4(List(()))))
422 | }
423 | }
424 |
/** Assorted continuations regressions (test methods are named after the original issue numbers). */
class Misc {

  /** Passes `n * 2` to the continuation `k`; used as a shift body. */
  def double[B](n: Int)(k: Int => B): B = k(n * 2)

  // Two chained shifts: 100 -> 200 -> 400.
  @Test def t2864: Unit = {
    reset {
      val result1 = shift(double[Unit](100))
      val result2 = shift(double[Unit](result1))
      assertEquals(400, result2)
    }
  }

  // Not called by any test here; it only has to compile. The second shift ignores its continuation.
  def foo: Int @cps[Int] = {
    val a0 = shift((k: Int => Int) => k(0))
    val x0 = 2
    val a1 = shift((k: Int => Int) => x0)
    0
  }

  /*
  def bar: ControlContext[Int,Int,Int] = {
    shiftR((k:Int=>Int) => k(0)).flatMap { a0 =>
    val x0 = 2
    shiftR((k:Int=>Int) => x0).map { a1 =>
    0
    }}
  }
  */

  // The shift body is List.flatMap itself: the continuation `x => List(x + 2)` is mapped over 1, 2, 3.
  @Test def t2934 = {
    assertEquals(List(3, 4, 5), reset {
      val x = shift(List(1, 2, 3).flatMap[Int, List[Int]])
      List(x + 2)
    })
  }

  class Bla {
    val x = 8
    def y[T] = 9
  }

  /*
  def bla[A] = shift { k:(Bla=>A) => k(new Bla) }
  */

  def bla1 = shift { k: (Bla => Bla) => k(new Bla) }
  def bla2 = shift { k: (Bla => Int) => k(new Bla) }

  def fooA = bla2.x
  def fooB[T] = bla2.y[T]

  // TODO: check whether this also applies to a::shift { k => ... }

  // Member selection (a val and a polymorphic def, with and without type args) on the result of a shift.
  @Test def t3225Mono(): Unit = {
    assertEquals(8, reset(bla1).x)
    assertEquals(8, reset(bla2.x))
    assertEquals(9, reset(bla2.y[Int]))
    assertEquals(9, reset(bla2.y))
    assertEquals(8, reset(fooA))
    assertEquals(9, reset(fooB))
  }

  def blaX[A] = shift { k: (Bla => A) => k(new Bla) }

  def fooX[A] = blaX[A].x
  def fooY[A] = blaX[A].y[A]

  // As t3225Mono, but the shift's answer type is a type parameter.
  @Test def t3225Poly(): Unit = {
    assertEquals(8, reset(blaX[Bla]).x)
    assertEquals(8, reset(blaX[Int].x))
    assertEquals(9, reset(blaX[Int].y[Int]))
    assertEquals(9, reset(blaX[Int].y))
    assertEquals(8, reset(fooX[Int]))
    assertEquals(9, reset(fooY[Int]))
  }

  def capture(): Int @suspendable = 42

  // A suspendable call inside a `while` loop inside reset.
  @Test def t3501: Unit = reset {
    var i = 0
    while (i < 5) {
      i += 1
      val y = capture()
      val s = y
      assertEquals(42, s)
    }
    assertEquals(5, i)
  }
}
516 |
// `return` statements inside methods with CPS result types.
517 | class Return {
// Output sink; each test clears it first.
518 | val out = new StringBuilder; def printOut(x: Any): Unit = out ++= x.toString
519 |
// s1 passes 5 to its continuation, so v = s1 + 3 is 8 in every method below.
520 | class ReturnRepro {
521 | def s1: Int @cps[Any] = shift { k => k(5) }
522 | def caller = reset { printOut(p(3)) }
523 | def caller2 = reset { printOut(p2(3)) }
524 | def caller3 = reset { printOut(p3(3)) }
525 |
526 | def p(i: Int): Int @cps[Any] = {
527 | val v = s1 + 3
528 | return v
529 | }
530 |
// `return` in both branches of an if/else.
531 | def p2(i: Int): Int @cps[Any] = {
532 | val v = s1 + 3
533 | if (v > 0) {
534 | printOut("hi")
535 | return v
536 | } else {
537 | printOut("hi")
538 | return 8
539 | }
540 | }
541 |
// `return` inside try and inside catch.
542 | def p3(i: Int): Int @cps[Any] = {
543 | val v = s1 + 3
544 | try {
545 | printOut("from try")
546 | return v
547 | } catch {
548 | case e: Exception =>
549 | printOut("from catch")
550 | return 7
551 | }
552 | }
553 | }
554 |
// Expected: "8" (p), "hi8" (p2), "from try8" (p3).
555 | @Test def t5314_2 = {
556 | out.clear()
557 | val repro = new ReturnRepro
558 | repro.caller
559 | repro.caller2
560 | repro.caller3
561 | assertEquals("8hi8from try8", out.toString)
562 | }
563 |
// Like ReturnRepro, but with differing answer types and `return` of a side-effecting block.
564 | class ReturnRepro2 {
565 |
566 | def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
567 | def caller = reset { printOut(p(3)) }
568 | def caller2 = reset { printOut(p2(3)) }
569 |
570 | def p(i: Int): Int @cpsParam[Unit, Any] = {
571 | val v = s1 + 3
572 | return { printOut("enter return expr"); v }
573 | }
574 |
575 | def p2(i: Int): Int @cpsParam[Unit, Any] = {
576 | val v = s1 + 3
577 | if (v > 0) {
578 | return { printOut("hi"); v }
579 | } else {
580 | return { printOut("hi"); 8 }
581 | }
582 | }
583 | }
584 |
585 | @Test def t5314_3 = {
586 | out.clear()
587 | val repro = new ReturnRepro2
588 | repro.caller
589 | repro.caller2
590 | assertEquals("enter return expr8hi8", out.toString)
591 | }
592 |
// A CPS-typed method that does not shift; it simply yields 7.
593 | def foo(x: Int): Int @cps[Int] = 7
594 |
// Since foo yields 7, v > 0 and bar returns 7.
595 | def bar(x: Int): Int @cps[Int] = {
596 | val v = foo(x)
597 | if (v > 0)
598 | return v
599 | else
600 | return 10
601 | }
602 |
603 | @Test def t5314_with_if =
604 | assertEquals(7, reset { bar(10) })
605 |
606 | }
607 |
// More `return` shapes in CPS methods: returning a shifting call, early returns after a CPS call.
608 | class t5314 {
// Output sink; cleared at the start of the test.
609 | val out = new StringBuilder; def printOut(x: Any): Unit = out ++= x.toString
610 |
// s1 passes 5 to its continuation, so v = 8 in p and p2.
611 | class ReturnRepro3 {
612 | def s1: Int @cpsParam[Any, Unit] = shift { k => k(5) }
613 | def caller = reset { printOut(p(3)) }
614 | def caller2 = reset { printOut(p2(3)) }
615 |
616 | def p(i: Int): Int @cpsParam[Unit, Any] = {
617 | val v = s1 + 3
618 | return v
619 | }
620 |
621 | def p2(i: Int): Int @cpsParam[Unit, Any] = {
622 | val v = s1 + 3
623 | if (v > 0) {
624 | printOut("hi")
625 | return v
626 | } else {
627 | printOut("hi")
628 | return 8
629 | }
630 | }
631 | }
632 |
633 | def foo(x: Int): Int @cps[Int] = shift { k => k(x) }
634 |
// `return` applied directly to a shifting call.
635 | def bar(x: Int): Int @cps[Int] = return foo(x)
636 |
// Plain (non-CPS) method with unreachable code after `return`; not used by the test below.
637 | def nocps(x: Int): Int = { return x; x }
638 |
// foo2 is CPS-typed but does not shift. bar2..bar5 call it and then return via different shapes.
639 | def foo2(x: Int): Int @cps[Int] = 7
640 | def bar2(x: Int): Int @cps[Int] = { foo2(x); return 7 }
641 | def bar3(x: Int): Int @cps[Int] = { foo2(x); if (x == 7) return 7 else return foo2(x) }
642 | def bar4(x: Int): Int @cps[Int] = { foo2(x); if (x == 7) return 7 else foo2(x) }
643 | def bar5(x: Int): Int @cps[Int] = { foo2(x); if (x == 7) return 7 else 8 }
644 |
// Expected: "7", "7", "7", "8" from bar2..bar5; "8" and "hi8" from ReturnRepro3; "8" from bar(8).
645 | @Test def t5314 = {
646 | out.clear()
647 |
648 | printOut(reset { bar2(10) })
649 | printOut(reset { bar3(10) })
650 | printOut(reset { bar4(10) })
651 | printOut(reset { bar5(10) })
652 |
653 | /* original test case */
654 | val repro = new ReturnRepro3
655 | repro.caller
656 | repro.caller2
657 |
658 | reset {
659 | val res = bar(8)
660 | printOut(res)
661 | res
662 | }
663 |
664 | assertEquals("77788hi88", out.toString)
665 | }
666 |
667 | }
668 |
// Higher-order functions with CPS callbacks: a continuation-based iterable, List.map with
// @suspendable results, and a PromiseStream/Future-style CPS API.
669 | class HigherOrder {
670 |
671 | import java.util.concurrent.atomic._
672 |
// The assertion runs inside reset, so it executes only if map's continuation is invoked.
673 | @Test def t5472 = {
674 | val map = Map("foo" -> 1, "bar" -> 2)
675 | reset {
676 | val mapped = new ContinuationizedParallelIterable(map)
677 | .map(strIntTpl => shiftUnit0[Int, Unit](23))
678 | assertEquals(List(23, 23), mapped.toList)
679 | }
680 | }
681 |
// Iterable wrapper whose filter/foreach/map take suspendable callbacks. Each operation is a shift
// whose body is an AtomicInteger countdown that decides when the outer continuation is invoked.
// NOTE(review): the @deprecated annotation appears to exist only to silence deprecation warnings
// inside the class (per its message); confirm before removing.
682 | @deprecated("Suppress warnings", since = "2.11")
683 | final class ContinuationizedParallelIterable[+A](protected val underline: Iterable[A]) {
// Sorted by toString so that test comparisons do not depend on iteration order.
684 | def toList = underline.toList.sortBy(_.toString)
685 |
// The counter starts at 1 and is incremented before each element's reset. The decrement after the
// loop releases that initial 1, so `continue` runs only after the initial count and every
// per-element count have been released. Accepted elements are added with a CAS loop.
686 | final def filter(p: A => Boolean @suspendable): ContinuationizedParallelIterable[A] @suspendable =
687 | shift(
688 | new AtomicInteger(1) with ((ContinuationizedParallelIterable[A] => Unit) => Unit) {
689 | private val results = new AtomicReference[List[A]](Nil)
690 |
691 | @tailrec
692 | private def add(element: A): Unit = {
693 | val old = results.get
694 | if (!results.compareAndSet(old, element :: old)) {
695 | add(element)
696 | }
697 | }
698 |
699 | override final def apply(continue: ContinuationizedParallelIterable[A] => Unit): Unit = {
700 | for (element <- underline) {
701 | super.incrementAndGet()
702 | reset {
703 | val pass = p(element)
704 | if (pass) {
705 | add(element)
706 | }
707 | if (super.decrementAndGet() == 0) {
708 | continue(new ContinuationizedParallelIterable(results.get))
709 | }
710 | }
711 | }
712 | if (super.decrementAndGet() == 0) {
713 | continue(new ContinuationizedParallelIterable(results.get))
714 | }
715 | }
716 | })
717 |
// Same countdown scheme as filter; the continuation receives ().
718 | final def foreach[U](f: A => U @suspendable): Unit @suspendable =
719 | shift(
720 | new AtomicInteger(1) with ((Unit => Unit) => Unit) {
721 | override final def apply(continue: Unit => Unit): Unit = {
722 | for (element <- underline) {
723 | super.incrementAndGet()
724 | reset {
725 | f(element)
726 | if (super.decrementAndGet() == 0) {
727 | continue(())
728 | }
729 | }
730 | }
731 | if (super.decrementAndGet() == 0) {
732 | continue(())
733 | }
734 | }
735 | })
736 |
// The counter starts at underline.size, and the element whose decrement reaches 0 invokes `continue`
// with the results array (element i stored at index i).
// NOTE(review): for an empty `underline` the counter starts at 0 and the loop never runs, so
// `continue` is never called, unlike filter/foreach. Not exercised by t5472.
737 | final def map[B: Manifest](f: A => B @suspendable): ContinuationizedParallelIterable[B] @suspendable =
738 | shift(
739 | new AtomicInteger(underline.size) with ((ContinuationizedParallelIterable[B] => Unit) => Unit) {
740 | override final def apply(continue: ContinuationizedParallelIterable[B] => Unit): Unit = {
741 | val results = new Array[B](super.get)
742 | for ((element, i) <- underline.view.zipWithIndex) {
743 | reset {
744 | val result = f(element)
745 | results(i) = result
746 | if (super.decrementAndGet() == 0) {
747 | continue(new ContinuationizedParallelIterable(results))
748 | }
749 | }
750 | }
751 | }
752 | })
753 | }
754 |
// fp10..fp23: List.map on a @suspendable receiver or inside a @suspendable method, with and
// without explicit type arguments and an explicit CanBuildFrom.
755 | def g: List[Int] @suspendable = List(1, 2, 3)
756 |
757 | def fp10: List[Int] @suspendable = {
758 | g.map(x => x)
759 | }
760 |
761 | def fp11: List[Int] @suspendable = {
762 | val z = g.map(x => x)
763 | z
764 | }
765 |
766 | def fp12: List[Int] @suspendable = {
767 | val z = List(1, 2, 3)
768 | z.map(x => x)
769 | }
770 |
771 | def fp20: List[Int] @suspendable = {
772 | g.map[Int, List[Int]](x => x)
773 | }
774 |
775 | def fp21: List[Int] @suspendable = {
776 | val z = g.map[Int, List[Int]](x => x)
777 | z
778 | }
779 |
780 | def fp22: List[Int] @suspendable = {
781 | val z = g.map[Int, List[Int]](x => x)(List.canBuildFrom[Int])
782 | z
783 | }
784 |
785 | def fp23: List[Int] @suspendable = {
786 | val z = g.map(x => x)(List.canBuildFrom[Int])
787 | z
788 | }
789 |
790 | @Test def t5506 = {
791 | reset {
792 | assertEquals(List(1, 2, 3), fp10)
793 | assertEquals(List(1, 2, 3), fp11)
794 | assertEquals(List(1, 2, 3), fp12)
795 | assertEquals(List(1, 2, 3), fp20)
796 | assertEquals(List(1, 2, 3), fp21)
797 | assertEquals(List(1, 2, 3), fp22)
798 | assertEquals(List(1, 2, 3), fp23)
799 | }
800 | }
// Minimal stand-ins for ExecutionContext/Future/PromiseStream used by t5538. This Future is
// synchronous: map and flatMap apply `f` immediately.
801 | class ExecutionContext
802 |
803 | implicit def defaultExecutionContext: ExecutionContext = new ExecutionContext
804 |
805 | case class Future[+T](x: T) {
806 | final def map[A](f: T => A): Future[A] = new Future[A](f(x))
807 | final def flatMap[A](f: T => Future[A]): Future[A] = f(x)
808 | }
809 |
// Accumulates elements in `xs`; the `<<`/`<<<` operators are CPS methods with answer type Future[Any].
810 | class PromiseStream[A] {
811 | override def toString = xs.toString
812 |
813 | var xs: List[A] = Nil
814 |
815 | final def +=(elem: A): this.type = { xs :+= elem; this }
816 |
817 | final def ++=(elem: Traversable[A]): this.type = { xs ++= elem; this }
818 |
819 | final def <<(elem: Future[A]): PromiseStream[A] @cps[Future[Any]] =
820 | shift { cont: (PromiseStream[A] => Future[Any]) => elem map (a => cont(this += a)) }
821 |
822 | final def <<(elem1: Future[A], elem2: Future[A], elems: Future[A]*): PromiseStream[A] @cps[Future[Any]] =
823 | shift { cont: (PromiseStream[A] => Future[Any]) => Future.flow(this << elem1 << elem2 <<< Future.sequence(elems.toSeq)) map cont }
824 |
825 | final def <<<(elems: Traversable[A]): PromiseStream[A] @cps[Future[Any]] =
826 | shift { cont: (PromiseStream[A] => Future[Any]) => cont(this ++= elems) }
827 |
828 | final def <<<(elems: Future[Traversable[A]]): PromiseStream[A] @cps[Future[Any]] =
829 | shift { cont: (PromiseStream[A] => Future[Any]) => elems map (as => cont(this ++= as)) }
830 | }
831 |
832 | object Future {
833 |
// Unwraps each (already completed) Future's value and rebuilds the collection via the given CanBuildFrom.
834 | def sequence[A, M[_] <: Traversable[_]](in: M[Future[A]])(implicit cbf: CanBuildFrom[M[Future[A]], A, M[A]], executor: ExecutionContext): Future[M[A]] =
835 | new Future(in.asInstanceOf[Traversable[Future[A]]].map((f: Future[A]) => f.x)(cbf.asInstanceOf[CanBuildFrom[Traversable[Future[A]], A, M[A]]]))
836 |
// Runs a CPS body inside reset and casts the Future[Any] answer back to Future[A].
837 | def flow[A](body: => A @cps[Future[Any]])(implicit executor: ExecutionContext): Future[A] = reset(Future(body)).asInstanceOf[Future[A]]
838 |
839 | }
840 |
// Compares string renderings, since the nested Future structure is what the test pins down.
841 | @Test def t5538 = {
842 | val p = new PromiseStream[Int]
843 | assertEquals(Future(Future(Future(Future(Future(List(1, 2, 3, 4, 5)))))).toString, Future.flow(p << (Future(1), Future(2), Future(3), Future(4), Future(5))).toString)
844 | }
845 | }
846 |
// CPS code inside try blocks and catch handlers.
847 | class TryCatch {
848 |
// The try block shifts; the handler yields a plain 9.
849 | def foo = try {
850 | shift((k: Int => Int) => k(7))
851 | } catch {
852 | case ex: Throwable =>
853 | 9
854 | }
855 |
// The try block is plain; the handler carries the CPS type (via shiftUnit0).
856 | def bar = try {
857 | 7
858 | } catch {
859 | case ex: Throwable =>
860 | shiftUnit0[Int, Int](9)
861 | }
862 |
// No exception is thrown: 7 + 3 in both cases.
863 | @Test def trycatch0 = {
864 | assertEquals(10, reset { foo + 3 })
865 | assertEquals(10, reset { bar + 3 })
866 | }
867 |
// Always throws; used to force the catch handlers below.
868 | def fatal: Int = throw new Exception()
869 |
// foo1/foo2: `fatal` before or after the shift; bar1/bar2: `fatal` before or after the plain value.
870 | def foo1 = try {
871 | fatal
872 | shift((k: Int => Int) => k(7))
873 | } catch {
874 | case ex: Throwable =>
875 | 9
876 | }
877 |
878 | def foo2 = try {
879 | shift((k: Int => Int) => k(7))
880 | fatal
881 | } catch {
882 | case ex: Throwable =>
883 | 9
884 | }
885 |
886 | def bar1 = try {
887 | fatal
888 | 7
889 | } catch {
890 | case ex: Throwable =>
891 | shiftUnit0[Int, Int](9) // regular shift causes no-symbol doesn't have owner
892 | }
893 |
894 | def bar2 = try {
895 | 7
896 | fatal
897 | } catch {
898 | case ex: Throwable =>
899 | shiftUnit0[Int, Int](9) // regular shift causes no-symbol doesn't have owner
900 | }
901 |
// Every variant ends in its handler: 9 + 3.
902 | @Test def trycatch1 = {
903 | assertEquals(12, reset { foo1 + 3 })
904 | assertEquals(12, reset { foo2 + 3 })
905 | assertEquals(12, reset { bar1 + 3 })
906 | assertEquals(12, reset { bar2 + 3 })
907 | }
908 |
// `reflect` (a shift over acquireFor) only has to compile; t3199 calls acquireFor directly.
909 | trait AbstractResource[+R <: AnyRef] {
910 | def reflect[B]: R @cpsParam[B, Either[Throwable, B]] = shift(acquireFor)
911 | def acquireFor[B](f: R => B): Either[Throwable, B] = {
912 | import Exception._
913 | catching(List(classOf[Throwable]): _*) either (f(null.asInstanceOf[R]))
914 | }
915 | }
916 |
917 | @Test def t3199 = {
918 | val x = new AbstractResource[String] {}
919 | val result = x.acquireFor(x => 7)
920 | assertEquals(Right(7), result)
921 | }
922 |
923 | }
924 |
/** Holds the method under test for `T3233`, which imports it from here rather than
  * defining it in the test class itself.
  */
object AvoidClassT3233 {
  /** Evaluates to 9 under `reset`: the `try` body always throws, so the handler's value is used.
    * The parameter `x` is not used.
    *
    * NOTE(review): the `shiftUnit0` after the `throw` is unreachable at runtime. Presumably it
    * gives the `try` body a CPS type so that `reset(foo(0))` type-checks. The unused closure
    * `g` in the handler is presumably part of the original bug reproduction.
    * Confirm both points before removing either.
    */
  def foo(x: Int) = {
    try {
      throw new Exception
      shiftUnit0[Int, Int](7)
    } catch {
      case ex: Throwable =>
        val g = (a: Int) => a
        9
    }
  }
}
937 |
/** Regression test: `reset` over a CPS method whose `try` throws and whose handler builds a closure.
  *
  * NOTE(review): the class is named T3233 but the test method is t3223; one of the two numbers
  * is presumably a typo. Confirm which issue is meant before renaming anything.
  */
class T3233 {
  // Workaround for a scalac bug ("Trying to access the this of another class:
  // tree.symbol = class TryCatch, ctx.clazz.symbol = <$anon: Function1>"):
  // the method under test lives in AvoidClassT3233 instead of in this class.
  import AvoidClassT3233._

  @Test def t3223: Unit = {
    val result = reset(foo(0))
    assertEquals(9, result)
  }
}
945 |
/** CPS tests for `while` loops whose bodies call `@cps[Unit]` methods.
  *
  * `printOut` accumulates output in `out`, so tests can assert on the order in which code
  * runs before and after a continuation is resumed.
  */
class While {

  /** Buffer that `printOut` appends to; cleared by the tests that inspect it. */
  val out = new StringBuilder

  /** Appends the string form of `x` to `out`. */
  def printOut(x: Any): Unit = out ++= x.toString

  /** Returns 2 without capturing the continuation. */
  def foo0(): Int @cps[Unit] = 2

  /** Adds 2 per iteration through a `@cps` call until 9000 and checks the total. */
  def test0(): Unit @cps[Unit] = {
    var x = 0
    while (x < 9000) { // pick number large enough to require tail-call opt
      x += foo0()
    }
    assertEquals(9000, x)
  }

  @Test def while0: Unit = {
    reset(test0())
  }

  /** Prints "up", resumes the continuation with 2, then prints "down" once it returns. */
  def foo3(): Int @cps[Unit] = shift { k => printOut("up"); k(2); printOut("down") }

  def test3(): Unit @cps[Unit] = {
    var x = 0
    while (x < 9) {
      x += foo3()
    }
    printOut(x)
  }

  /** The remaining iterations run inside each `k(2)` call. So all five "up"s print before the
    * final value 10, and all five "down"s print after it.
    */
  @Test def while1: Unit = {
    out.clear()
    reset(test3())
    assertEquals("upupupupup10downdowndowndowndown", out.toString)
  }

  /** Returns 2 without capturing the continuation. */
  def foo1(): Int @cps[Unit] = 2

  /** Prints "up", resumes the continuation with 2, then prints "down" once it returns. */
  def foo2(): Int @cps[Unit] = shift { k => printOut("up"); k(2); printOut("down") }

  /** Calls `foo2` when `x` is a multiple of 1000 and `foo1` otherwise. */
  def test2(): Unit @cps[Unit] = {
    var x = 0
    while (x < 9000) { // pick number large enough to require tail-call opt
      x += (if (x % 1000 != 0) foo1() else foo2())
    }
    printOut(x)
  }

  /** `foo2` runs at x = 0, 1000, ..., 8000, giving nine "up"s and nine "down"s around 9000. */
  @Test def while2: Unit = {
    out.clear()
    reset(test2())
    assertEquals("upupupupupupupupup9000downdowndowndowndowndowndowndowndown", out.toString)
  }
}
--------------------------------------------------------------------------------