├── .github └── workflows │ └── scala.yml ├── .gitignore ├── .scala-steward.conf ├── LICENSE ├── README.md ├── bin └── simple-algebraic-subtyping-opt.js ├── build.sbt ├── index.css ├── index.html ├── js └── src │ └── main │ └── scala │ └── Main.scala ├── notes └── interpretation-of-as-types.txt ├── out ├── basic.check ├── booleans.check ├── isolated.check ├── let-poly.check ├── occurs-check.check ├── random.check ├── records.check ├── recursion.check └── self-app.check ├── project ├── build.properties └── plugins.sbt └── shared └── src ├── main └── scala │ └── simplesub │ ├── Parser.scala │ ├── Typer.scala │ ├── TyperDebugging.scala │ ├── helpers.scala │ ├── package.scala │ └── syntax.scala └── test └── scala └── simplesub ├── IsolatedTests.scala ├── OtherTests.scala ├── ParserTests.scala ├── ProgramTests.scala ├── TypingTestHelpers.scala └── TypingTests.scala /.github/workflows/scala.yml: -------------------------------------------------------------------------------- 1 | name: Scala CI 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up JDK 1.8 17 | uses: actions/setup-java@v1 18 | with: 19 | java-version: 1.8 20 | - name: Run tests 21 | run: sbt test 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | .vscode/ 3 | .bloop/ 4 | .metals/ 5 | .bsp/ 6 | project/Dependencies.scala 7 | project/metals.sbt 8 | project/project/ 9 | **.worksheet.sc 10 | mlsub/ 11 | -------------------------------------------------------------------------------- /.scala-steward.conf: -------------------------------------------------------------------------------- 1 | updates.ignore = [ { groupId = "org.wartremover" } ] 2 | pullRequests.frequency = "45 days" 3 | 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Lionel Parreaux 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simpler-sub Algorithm for Type Inference with Subtyping 2 | 3 | This repository shows the implementation of **Simpler-sub**, 4 | an alternative algorithm to [Simple-sub](https://github.com/LPTK/simple-sub) which is much easier to understand but also much more limited, 5 | though it is probably enough for many practical use cases. 
6 | 7 | An online demo is available here: https://lptk.github.io/simpler-sub/ 8 | 9 | 10 | ## Simplifications 11 | 12 | By contrast to Simple-sub, Simpler-sub does not support: 13 | 14 | * (1) Recursive types: any recursive constraint will yield an error instead. 15 | 16 | If your language has externally-defined recursive data types (such as algebraic data types), 17 | you don't strictly need recursive types anyway. 18 | But in order to support field-recursing definitions of the form `let foo x = ... foo x.f ...`, you'll need a way to make it clear to the type checker which recursive data type's `f` field this `foo` recurses on 19 | (otherwise you'll get a recursive constraint error between inferred record types). 20 | One way to do that is to either reject the use of overloaded field names (like Haskell 98) 21 | or use a type class for field selection instead of subtyping (like recent Haskell) 22 | or use contextual information to disambiguate field selections when possible (like OCaml). 23 | 24 | The absence of recursive types makes some of the type inference algorithms simpler and more efficient, 25 | as they can now freely decompose types inductively without having to carry a cache around. 26 | 27 | * (2) Nested let polymorphism: in this prototype, a _local_ (i.e., nested) `let ... in ...` binding will never be assigned a polymorphic type. 28 | 29 | This simplifies the approach further in that we don't have to deal with levels. 30 | And there is precedent for it, 31 | for example see the paper [Let Should not be Generalised](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tldi10-vytiniotis.pdf) by Dimitrios Vytiniotis, Simon Peyton Jones, and Tom Schrijvers. 32 | 33 | * (3) Precise type-variable-to-type-variable constraints: any constraint between two type variables immediately unifies the two variables. 34 | 35 | The consequence of this is not as dire as one may think, 36 | thanks to pervasive polymorphism (though also see Restriction 2). 
37 | Programs that would typically exhibit problematic loss of precision with this approach are those similar to 38 | `let test f x y = if f x then x else y`, 39 | which is typed by Simple-sub as `('a -> bool) -> 'a ∧ 'b -> 'b -> 'b`, 40 | but is now typed as `('a -> bool) -> 'a -> 'a -> 'a` 41 | – notice that `y` is forced to be typed as `'a`, the parameter type taken by `f`, 42 | even though it never flows into `f` and has in fact nothing to do with it. 43 | 44 | In Simple-sub, and in algebraic subtyping in general, 45 | precise graphs of type variable inequalities are recorded, 46 | and then need to be simplified aggressively before being displayed to the user. 47 | Unifying type variables aggressively, on the other hand, 48 | forces inferred type graphs to be almost as simple as in traditional Hindley-Milner type inference, 49 | and makes inferred type simplification much easier. 50 | 51 | Restriction (3) destroys the completeness of type inference: 52 | there may now be well-typed terms that the type inference approach rejects, such as `test (fun z -> not z) true 1` with the `test` function above (`1` is forced to the type `bool` expected by `f`). 53 | 54 | Restriction (1) destroys the principal type property: 55 | there are now terms which cannot be ascribed a single most precise type 56 | – those which would have been typed by Simple-sub through a recursive type, 57 | but which can still be given less precise non-recursive types. 58 | Simpler-sub will in fact plainly reject any such term. 59 | 60 | Each of these simplifications could be made independently, starting from Simple-sub. 61 | I chose to implement them all together in this project to show how simple the result can be, 62 | while still being quite useful as a possible foundation for type inference in languages with subtyping.
63 | 64 | 65 | -------------------------------------------------------------------------------- /build.sbt: -------------------------------------------------------------------------------- 1 | import Wart._ 2 | 3 | enablePlugins(ScalaJSPlugin) 4 | 5 | ThisBuild / scalaVersion := "2.13.8" 6 | ThisBuild / version := "0.1.0-SNAPSHOT" 7 | ThisBuild / organization := "io.lptk" 8 | ThisBuild / organizationName := "LPTK" 9 | 10 | lazy val root = project.in(file(".")) 11 | .aggregate(simplesubJS, simplesubJVM) 12 | .settings( 13 | publish := {}, 14 | publishLocal := {} 15 | ) 16 | 17 | lazy val simplesub = crossProject(JSPlatform, JVMPlatform).in(file(".")) 18 | .settings( 19 | name := "simple-algebraic-subtyping", 20 | scalacOptions ++= Seq( 21 | "-deprecation", 22 | "-feature", 23 | "-unchecked", 24 | "-language:higherKinds", 25 | "-Ywarn-value-discard", 26 | ), 27 | wartremoverWarnings ++= Warts.allBut( 28 | Recursion, Throw, Nothing, Return, While, 29 | Var, MutableDataStructures, NonUnitStatements, 30 | DefaultArguments, ImplicitParameter, StringPlusAny, 31 | JavaSerializable, Serializable, Product, 32 | LeakingSealed, 33 | Option2Iterable, TraversableOps, 34 | Any, 35 | ), 36 | libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.12" % Test, 37 | libraryDependencies += "com.lihaoyi" %%% "fastparse" % "2.3.3", 38 | libraryDependencies += "com.lihaoyi" %%% "sourcecode" % "0.2.8", 39 | libraryDependencies += "com.lihaoyi" %% "ammonite-ops" % "2.4.0", 40 | ) 41 | .jsSettings( 42 | scalaJSUseMainModuleInitializer := true, 43 | libraryDependencies += "org.scala-js" %%% "scalajs-dom" % "2.1.0", 44 | libraryDependencies += "be.doeraene" %%% "scalajs-jquery" % "1.0.0", 45 | ThisBuild / evictionErrorLevel := Level.Info, 46 | ) 47 | 48 | lazy val simplesubJVM = simplesub.jvm 49 | lazy val simplesubJS = simplesub.js 50 | -------------------------------------------------------------------------------- /index.css: 
-------------------------------------------------------------------------------- 1 | body { 2 | background-color: #eee8d5; 3 | max-width: 1200px; 4 | margin: 0px auto; 5 | padding: 5px; 6 | color: #073642; 7 | } 8 | h1, p{ 9 | font-family: sans-serif; 10 | padding: 0px 10px; 11 | } 12 | #content { 13 | width: 100%; 14 | padding: 0px; 15 | height: 450px; 16 | } 17 | #content::after { 18 | content: "."; 19 | visibility: hidden; 20 | display: block; 21 | height: 0; 22 | clear: both; 23 | } 24 | 25 | #left, #right, textarea { 26 | width: 50%; 27 | height: 100%; 28 | padding: 10px; 29 | box-sizing: border-box; 30 | font-family: monospace; 31 | } 32 | 33 | #left { 34 | float: left; 35 | } 36 | 37 | #right { 38 | float: right; 39 | } 40 | 41 | #simple-sub-input, #simple-sub-output { 42 | width: 100%; 43 | height: 100%; 44 | box-sizing: border-box; 45 | border-radius: 10px; 46 | padding: 10px; 47 | background-color: #002b36; 48 | /* color: #93a1a1; */ 49 | /* color: #cacaca; */ 50 | color: #efefef; 51 | font-size: 14px; 52 | } 53 | #simple-sub-input { 54 | resize: none; 55 | outline: none; 56 | } 57 | 58 | #simple-sub-output { 59 | border: 1px solid black; 60 | } 61 | 62 | .binding { 63 | font-family: monospace; 64 | white-space: pre; 65 | } 66 | 67 | .name::before { 68 | content: "val "; 69 | color: #839496; 70 | } 71 | 72 | .name::after { 73 | content: " : "; 74 | color: #839496; 75 | } 76 | 77 | .name, .type { 78 | display: inline; 79 | } 80 | .name { 81 | color: #93a1a1; 82 | font-weight: bold; 83 | } 84 | .type { 85 | color: #859900; 86 | font-weight: bold; 87 | display: inline-block; 88 | } 89 | -------------------------------------------------------------------------------- /index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Simpler-sub demonstration 7 | 8 | 9 |

Simpler-sub demonstration

10 |
11 |
12 | 30 | 33 |
34 | 37 |
38 | 39 |
40 |

The code is available on github.

41 |
42 |

Credit: the CSS style sheet of this page was shamelessly stolen from the MLsub demo page.

43 | 44 | -------------------------------------------------------------------------------- /js/src/main/scala/Main.scala: -------------------------------------------------------------------------------- 1 | import scala.util.Try 2 | import scala.scalajs.js.annotation.JSExportTopLevel 3 | import org.scalajs.dom 4 | import org.scalajs.dom.document 5 | /** Scala.js entry point of the online demo: typechecks the program found in `#simple-sub-input` and renders the results as HTML into `#simple-sub-output`. */ 6 | object Main { 7 | /** Typechecks the initial text content of `#simple-sub-input`, then registers `typecheck` as that element's `input` event listener. */ def main(args: Array[String]): Unit = { 8 | val source = document.querySelector("#simple-sub-input") 9 | update(source.textContent) 10 | source.addEventListener("input", typecheck) 11 | } 12 | /** `input` event handler (also exported to JavaScript as the global `typecheck`): re-typechecks the current value of the event's target textarea. NOTE(review): the match below has no fallback case, so an event whose target is not a `dom.HTMLTextAreaElement` throws a `MatchError`. */ @JSExportTopLevel("typecheck") 13 | def typecheck(e: dom.UIEvent): Unit = { 14 | e.target match { 15 | case elt: dom.HTMLTextAreaElement => 16 | update(elt.value) 17 | } 18 | } 19 | /** Parses and typechecks `str`, then replaces the inner HTML of `#simple-sub-output` with either a parse error message, one entry per typed top-level definition (its simplified type or its type error), or an "Unexpected error" message when the `Try` block fails (its stack trace is also printed). */ def update(str: String): Unit = { 20 | // println(s"Input: $str") 21 | val target = document.querySelector("#simple-sub-output") 22 | target.innerHTML = Try { 23 | import fastparse._ 24 | import fastparse.Parsed.{Success, Failure} 25 | import simplesub.Parser.pgrm 26 | import simplesub.TypeError 27 | parse(str, pgrm(_), verboseFailures = false) match { 28 | case f: Failure => 29 | val Failure(err, index, extra) = f 30 | // this line-parsing logic was copied from fastparse internals: 31 | val lineNumberLookup = fastparse.internal.Util.lineNumberLookup(str) 32 | /* Index of the line containing the failure offset. NOTE(review): this looks 0-based and is displayed unchanged in the message below; confirm whether a 1-based line number was intended. */ val line = lineNumberLookup.indexWhere(_ > index) match { 33 | case -1 => lineNumberLookup.length - 1 34 | case n => math.max(0, n - 1) 35 | } 36 | val lines = str.split('\n') 37 | /* `min` binds more loosely than `-`, so this is `line min (lines.length - 1)`: the index is clamped to the last line of the input. */ val lineStr = lines(line min lines.length - 1) 38 | "Parse error: " + extra.trace().msg + 39 | s" at line $line:
$lineStr
" 40 | case Success(p, index) => 41 | // println(s"Parsed: $p") 42 | /* Local typer instance (constructed with `dbg = false`) extended with a loop-based variant of `inferTypes`. */ object Typer extends simplesub.Typer(dbg = false) { 43 | import simplesub._ 44 | // Sadly, the original `inferTypes` version does not seem to work in JavaScript, as it raises a 45 | // "RangeError: Maximum call stack size exceeded" 46 | // So we have to go with this uglier one: 47 | /** Types each top-level definition of `pgrm` in order, adding each result (or a fresh type variable, if typing failed) to the context used for the following definitions. If `stopAtFirstError` is true, the remaining definitions are dropped after the first `TypeError`, so the result can be shorter than `pgrm.defs`. */ def inferTypesJS( 48 | pgrm: Pgrm, 49 | ctx: Ctx = builtins, 50 | stopAtFirstError: Boolean = true, 51 | ): List[Either[TypeError, PolymorphicType]] = { 52 | var defs = pgrm.defs 53 | var curCtx = ctx 54 | var res = collection.mutable.ListBuffer.empty[Either[TypeError, PolymorphicType]] 55 | while (defs.nonEmpty) { 56 | val (isrec, nme, rhs) = defs.head 57 | defs = defs.tail 58 | val ty_sch = try Right(typeLetRhs(isrec, nme, rhs)(curCtx)) catch { 59 | case err: TypeError => 60 | if (stopAtFirstError) defs = Nil 61 | Left(err) 62 | } 63 | res += ty_sch 64 | curCtx += (nme -> ty_sch.getOrElse(freshVar)) 65 | } 66 | res.toList 67 | } 68 | } 69 | /* `lazyZip` below stops at the shorter side, so definitions skipped after a type error (with `stopAtFirstError`) produce no output. */ val tys = Typer.inferTypesJS(p) 70 | (p.defs.zipWithIndex lazyZip tys).map { 71 | /* `d` is a definition `(isrec, name, rhs)`, so `d._2` is its name. */ case ((d, i), Right(ty)) => 72 | val ity = ty.instantiate 73 | println(s"Typed `${d._2}` as: $ity") 74 | println(s" where: ${ity.showBounds}") 75 | /* 76 | val com = Typer.canonicalizeType(ty.instantiate) 77 | println(s"Compact type before simplification: ${com}") 78 | val sim = Typer.simplifyType(com) 79 | println(s"Compact type after simplification: ${sim}") 80 | val exp = Typer.coalesceCompactType(sim) 81 | */ 82 | val sim = Typer.simplifyType(ity) 83 | println(s"Type after simplification: ${sim}") 84 | val exp = Typer.coalesceType(sim) 85 | s""" 86 | val 87 | ${d._2}: 88 | ${exp.show} 89 | """ 90 | case ((d, i), Left(TypeError(msg))) => 91 | s""" 92 | Type error in ${d._2}: $msg 93 | """ 94 | }.mkString("
") 95 | } 96 | }.fold(err => s""" 97 | 98 | Unexpected error: ${err}${ 99 | err.printStackTrace 100 | err 101 | }""", identity) 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /notes/interpretation-of-as-types.txt: -------------------------------------------------------------------------------- 1 | 2 | Two interpretations of types of the form T['a] as 'a 3 | - if interereted as proper recursive (a.k.a., infinite) type, 4 | it means ('a -> Bot) as 'a is too restrictive a type for self-app 5 | (it's specialized for applying self-app to itself, resulting in non-term and the Bot type as expected), 6 | but the original type ('a -> 'b) as 'a works fine. 7 | explanation: 8 | 'b cannot be simplified away, because it appears in both pos and neg, 9 | since 'a itself does (through its own recursive occurrence!) 10 | - if interpreted as a type variable bounded with its body accordingly to its position, 11 | then ('a -> Bot) as 'a is STILL NOT an acceptable type for self-app 12 | For instance, let's see what happens when we pass it (lam x. 42) of type Top -> Int; 13 | we constrain ?a <: (Top -> Int) -> ?b for some fresh ?b, where ?a :> (?a -> Bot) 14 | this leads to (?a -> Bot) <: (Top -> Int) -> ?b 15 | decomposed as (Top -> Int) <: ?a and Bot <: ?b 16 | which leads to Top -> Int <: Top -> Int 17 | So the constraints work out, but we end up with result type ?b where ?b :> Bot, 18 | which is obviously wrong! (we should have gotten an Int) 19 | And this also happens for the original type ('a -> 'b) as 'a, 20 | which is however valid under the other interpretation! 21 | Here, we may need to think harder about the phrase "bounded accordingly to its position" 22 | Again, the position of 'a really is both pos and neg, so we should really interpret the type as 23 | ?a :> ('a -> 'b) <: ('a -> 'b) 24 | Under this revised interpretation, things work out when applying it to (lam x.
42): 25 | we constrain ?a <: (Top -> Int) -> ?c for some fresh ?c, where ?a :> (?a -> ?b) <: (?a -> ?b) 26 | this leads to ?a -> ?b <: (Top -> Int) -> ?c 27 | decomposed as (Top -> Int) <: ?a and ?b <: Int 28 | which leads to Top -> Int <: Top -> Int and Top -> Int <: ?a -> ?b 29 | decomposed as ?a <: Top and Int <: ?b 30 | which leads to Int <: Int 31 | So the constraints work out, and we do get the appropriate result type Int 32 | 33 | Conclusion: 34 | Both interpretations of "T['a] as 'a" as infinite types and as type variables seem valid and equivalent. 35 | We can treat "T['a] as 'a" as representing a type variable ?a upper and lower-bounded by T[?a]. 36 | In the particular case where it appears only positively (resp. neg.) in the original type AS WELL 37 | AS in T[?a] itself, then we can forget about the lower bound (resp. upper bound). 38 | 39 | Note: 40 | The bounds of a recursive type variable that appears both positively and negatively have no reason to be the same; 41 | so our interpretation above is unlikely to capture the full generality of recursive type variables. 42 | I have changed the type expansion algorithm to only produce recursive type variables which occur only 43 | in positive or only in negative positions (the change was easy). Before that, the algorithm was probably 44 | wrong anyway, under our both-bounds interpretation. 45 | We'll no longer infer things like "('a -> 'b) as 'a", but rather "(('a -> 'b) -> 'b) as 'a", 46 | which is easier to co-occurrence-analyse using the current infrastructure (otherwise, we'd have to keep 47 | track of the fact that some recursive variables may appear both positively and negatively as we do the analysis!)
48 | -------------------------------------------------------------------------------- /out/basic.check: -------------------------------------------------------------------------------- 1 | // 42 2 | int 3 | 4 | // fun x -> 42 5 | ⊤ -> int 6 | 7 | // fun x -> x 8 | 'a -> 'a 9 | 10 | // fun x -> x 42 11 | (int -> 'a) -> 'a 12 | 13 | // (fun x -> x) 42 14 | int 15 | 16 | // fun f -> fun x -> f (f x) // twice 17 | ('a -> 'a) -> 'a -> 'a 18 | 19 | // let twice = fun f -> fun x -> f (f x) in twice 20 | ('a -> 'a) -> 'a -> 'a 21 | 22 | -------------------------------------------------------------------------------- /out/booleans.check: -------------------------------------------------------------------------------- 1 | // true 2 | bool 3 | 4 | // not true 5 | bool 6 | 7 | // fun x -> not x 8 | bool -> bool 9 | 10 | // (fun x -> not x) true 11 | bool 12 | 13 | // fun x -> fun y -> fun z -> if x then y else z 14 | bool -> 'a -> 'a -> 'a 15 | 16 | // fun x -> fun y -> if x then y else x 17 | 'a ∧ bool -> 'a ∧ bool -> 'a 18 | 19 | // fun x -> { u = not x; v = x } 20 | 'a ∧ bool -> {u: bool, v: 'a} 21 | 22 | // [wrong:] succ true 23 | // ERROR: cannot constrain bool <: int 24 | 25 | // [wrong:] fun x -> succ (not x) 26 | // ERROR: cannot constrain bool <: int 27 | 28 | // [wrong:] (fun x -> not x.f) { f = 123 } 29 | // ERROR: cannot constrain int <: bool 30 | 31 | // [wrong:] (fun f -> fun x -> not (f x.u)) false 32 | // ERROR: cannot constrain bool <: 'a -> 'b 33 | 34 | -------------------------------------------------------------------------------- /out/isolated.check: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LPTK/simpler-sub/afe92450bd8803c0de13054f07eedb59f86f0373/out/isolated.check -------------------------------------------------------------------------------- /out/let-poly.check: -------------------------------------------------------------------------------- 1 | // let f = fun x -> x in {a = f 0; 
b = f true} 2 | {a: ⊤, b: ⊤} 3 | 4 | // fun y -> let f = fun x -> x in {a = f y; b = f true} 5 | 'a -> {a: 'a ∨ bool, b: 'a ∨ bool} 6 | 7 | // fun y -> let f = fun x -> y x in {a = f 0; b = f true} 8 | (⊤ -> 'a) -> {a: 'a, b: 'a} 9 | 10 | // fun y -> let f = fun x -> x y in {a = f (fun z -> z); b = f (fun z -> true)} 11 | 'a -> {a: 'a ∨ bool, b: 'a ∨ bool} 12 | 13 | // fun y -> let f = fun x -> x y in {a = f (fun z -> z); b = f (fun z -> succ z)} 14 | int -> {a: int, b: int} 15 | 16 | // [wrong:] (fun k -> k (fun x -> let tmp = add x 1 in x)) (fun f -> f true) 17 | // ERROR: cannot constrain bool <: int 18 | 19 | // [wrong:] (fun k -> let test = k (fun x -> let tmp = add x 1 in x) in test) (fun f -> f true) 20 | // ERROR: cannot constrain bool <: int 21 | 22 | -------------------------------------------------------------------------------- /out/occurs-check.check: -------------------------------------------------------------------------------- 1 | // [wrong:] fun x -> x.u x 2 | // ERROR: Illegal cyclic constraint: α1 <: (α0 -> α2) 3 | where: α0 <: {u: α1} 4 | 5 | // [wrong:] fun x -> x.u {v=x} 6 | // ERROR: Illegal cyclic constraint: α1 <: ({v: α0} -> α2) 7 | where: α0 <: {u: α1} 8 | 9 | // fun x -> x.u x.v 10 | {u: 'a -> 'b, v: 'a} -> 'b 11 | 12 | // [wrong:] fun x -> x.u.v x 13 | // ERROR: Illegal cyclic constraint: α2 <: (α0 -> α3) 14 | where: α0 <: {u: α1}, α1 <: {v: α2} 15 | 16 | -------------------------------------------------------------------------------- /out/random.check: -------------------------------------------------------------------------------- 1 | // (let rec x = {a = x; b = x} in x) 2 | // ERROR: Illegal cyclic constraint: α0 :> {a: α0, b: α0} 3 | 4 | // (let rec x = fun v -> {a = x v; b = x v} in x) 5 | // ERROR: Illegal cyclic constraint: α3 :> {a: α3, b: α3} 6 | 7 | // [wrong:] let rec x = (let rec y = {u = y; v = (x y)} in 0) in 0 8 | // ERROR: Unsupported: local recursive let binding 9 | 10 | // (fun x -> (let y = (x x) in 0)) 11 | // 
ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2) 12 | 13 | // (let rec x = (fun y -> (y (x x))) in x) 14 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2) 15 | 16 | // fun next -> 0 17 | ⊤ -> int 18 | 19 | // ((fun x -> (x x)) (fun x -> x)) 20 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1) 21 | 22 | // (let rec x = (fun y -> (x (y y))) in x) 23 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α2) 24 | 25 | // fun x -> (fun y -> (x (y y))) 26 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α2) 27 | 28 | // (let rec x = (let y = (x x) in (fun z -> z)) in x) 29 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α3) 30 | 31 | // (let rec x = (fun y -> (let z = (x x) in y)) in x) 32 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α3) 33 | 34 | // (let rec x = (fun y -> {u = y; v = (x x)}) in x) 35 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2) 36 | 37 | // (let rec x = (fun y -> {u = (x x); v = y}) in x) 38 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2) 39 | 40 | // (let rec x = (fun y -> (let z = (y x) in y)) in x) 41 | // ERROR: Illegal cyclic constraint: α0 :> (α4 -> α4) 42 | where: α4 <: (α0 -> α2) 43 | 44 | // (fun x -> (let y = (x x.v) in 0)) 45 | ⊥ -> int 46 | 47 | // let rec x = (let y = (x x) in (fun z -> z)) in (x (fun y -> y.u)) 48 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α3) 49 | 50 | -------------------------------------------------------------------------------- /out/records.check: -------------------------------------------------------------------------------- 1 | // fun x -> x.f 2 | {f: 'a} -> 'a 3 | 4 | // {} 5 | {} 6 | 7 | // { f = 42 } 8 | {f: int} 9 | 10 | // { f = 42 }.f 11 | int 12 | 13 | // (fun x -> x.f) { f = 42 } 14 | int 15 | 16 | // fun f -> { x = f 42 }.x 17 | (int -> 'a) -> 'a 18 | 19 | // fun f -> { x = f 42; y = 123 }.y 20 | (int -> ⊤) -> int 21 | 22 | // if true then { a = 1; b = true } else { b = false; c = 42 } 23 | {b: bool} 24 | 25 | // if true then { u = 1; v = 2; w = 3 } 
else { u = true; v = 4; x = 5 } 26 | {u: ⊤, v: int} 27 | 28 | // if true then fun x -> { u = 1; v = x } else fun y -> { u = y; v = y } 29 | 'a -> {u: 'a ∨ int, v: 'a ∨ int} 30 | 31 | // [wrong:] { a = 123; b = true }.c 32 | // ERROR: missing field: c in {a: int, b: bool} 33 | 34 | // [wrong:] fun x -> { a = x }.b 35 | // ERROR: missing field: b in {a: 'a} 36 | 37 | -------------------------------------------------------------------------------- /out/recursion.check: -------------------------------------------------------------------------------- 1 | // let rec f = fun x -> f x.u in f 2 | // ERROR: Illegal cyclic constraint: α2 <: {u: α2} 3 | 4 | // let rec consume = fun strm -> add strm.head (consume strm.tail) in consume 5 | // ERROR: Illegal cyclic constraint: α4 <: {head: α2, tail: α4} 6 | where: α2 <: int 7 | 8 | // let rec r = fun a -> r in if true then r else r 9 | // ERROR: Illegal cyclic constraint: α0 :> (α1 -> α0) 10 | 11 | // let rec l = fun a -> l in let rec r = fun a -> fun a -> r in if true then l else r 12 | // ERROR: Illegal cyclic constraint: α0 :> (α1 -> α0) 13 | 14 | // let rec l = fun a -> fun a -> fun a -> l in let rec r = fun a -> fun a -> r in if true then l else r 15 | // ERROR: Illegal cyclic constraint: α0 :> (α1 -> (α2 -> (α3 -> α0))) 16 | 17 | // let rec recursive_monster = fun x -> { thing = x; self = recursive_monster x } in recursive_monster 18 | // ERROR: Illegal cyclic constraint: α2 :> {thing: α1, self: α2} 19 | 20 | -------------------------------------------------------------------------------- /out/self-app.check: -------------------------------------------------------------------------------- 1 | // fun x -> x x 2 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1) 3 | 4 | // fun x -> x x x 5 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1) 6 | 7 | // fun x -> fun y -> x y x 8 | // ERROR: Illegal cyclic constraint: α2 <: (α0 -> α3) 9 | where: α0 <: (α1 -> α2) 10 | 11 | // fun x -> fun y -> x x y 12 | // ERROR: 
Illegal cyclic constraint: α0 <: (α0 -> α2) 13 | 14 | // (fun x -> x x) (fun x -> x x) 15 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1) 16 | 17 | // fun x -> {l = x x; r = x } 18 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1) 19 | 20 | // (fun f -> (fun x -> f (x x)) (fun x -> f (x x))) 21 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α2) 22 | 23 | // (fun f -> (fun x -> f (fun v -> (x x) v)) (fun x -> f (fun v -> (x x) v))) 24 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α3) 25 | 26 | // (fun f -> (fun x -> f (fun v -> (x x) v)) (fun x -> f (fun v -> (x x) v))) (fun f -> fun x -> f) 27 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α3) 28 | 29 | // let rec trutru = fun g -> trutru (g true) in trutru 30 | // ERROR: Illegal cyclic constraint: α2 <: (bool -> α2) 31 | 32 | // fun i -> if ((i i) true) then true else true 33 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2) 34 | 35 | -------------------------------------------------------------------------------- /project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version=1.6.2 2 | -------------------------------------------------------------------------------- /project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("org.wartremover" % "sbt-wartremover" % "2.4.16") 2 | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.10.0") 3 | addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "1.2.0") 4 | -------------------------------------------------------------------------------- /shared/src/main/scala/simplesub/Parser.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | import scala.util.chaining._ 4 | import fastparse._, fastparse.ScalaWhitespace._ 5 | /** Parser for the Simpler-sub surface language (terms and programs of top-level `let` definitions), built with fastparse using its `ScalaWhitespace` whitespace handling. */ 6 | @SuppressWarnings(Array("org.wartremover.warts.All")) 7 | object Parser { 8 | /** Reserved words; `ident` refuses to parse any of them as an identifier. */ 9 | val keywords = Set("let", "rec", "in", "fun", 
"if", "then", "else", "true", "false") 10 | /** Matches the keyword `s` only when it is not directly followed by an identifier character, so that e.g. `letx` is not read as `let`. */ def kw[__ : P](s: String) = s ~~ !(letter | digit | "_" | "'") 11 | 12 | def letter[__ : P] = P( lowercase | uppercase ) 13 | def lowercase[__ : P] = P( CharIn("a-z") ) 14 | def uppercase[__ : P] = P( CharIn("A-Z") ) 15 | def digit[__ : P] = P( CharIn("0-9") ) 16 | /** A non-empty run of decimal digits read as an `Int`. Digit strings that do not fit in an `Int` make this parser fail, instead of letting `toInt` throw a `NumberFormatException` out of `parse`. */ def number[__ : P]: P[Int] = P( CharIn("0-9").repX(1).!.filter(_.toIntOption.isDefined).map(_.toInt) ) 17 | /** A letter or `_`, followed (without whitespace) by letters, digits, `_` or `'`; keywords are rejected. */ def ident[__ : P]: P[String] = 18 | P( (letter | "_") ~~ (letter | digit | "_" | "'").repX ).!.filter(!keywords(_)) 19 | 20 | /** Any term: a `let`, `fun` or `if` form, or a chain of applications. */ def term[__ : P]: P[Term] = P( let | fun | ite | apps ) 21 | def const[__ : P]: P[Term] = number.map(Lit) 22 | /** `true` and `false` are keywords, so they are accepted here explicitly and parsed as ordinary variables. */ def variable[__ : P]: P[Term] = (ident | "true".! | "false".!).map(Var) 23 | def parens[__ : P]: P[Term] = P( "(" ~/ term ~ ")" ) 24 | /** An atomic term, before any `.field` selections. */ def subtermNoSel[__ : P]: P[Term] = P( parens | record | const | variable ) 25 | /** An atomic term followed by zero or more `.field` selections, folded left into nested `Sel` nodes. */ def subterm[__ : P]: P[Term] = P( subtermNoSel ~ ("." ~/ ident).rep ).map { 26 | case (st, sels) => sels.foldLeft(st)(Sel) } 27 | /** A record literal `{ f1 = t1; ...; fn = tn }`; parsing fails if a field name is repeated. */ def record[__ : P]: P[Term] = P( "{" ~/ (ident ~ "=" ~ term).rep(sep = ";") ~ "}" ) 28 | .filter(xs => xs.map(_._1).toSet.size === xs.size).map(_.toList pipe Rcd) 29 | /** `fun x -> body` */ def fun[__ : P]: P[Term] = P( kw("fun") ~/ ident ~ "->" ~ term ).map(Lam.tupled) 30 | /** `let [rec] x = rhs in body` */ def let[__ : P]: P[Term] = 31 | P( kw("let") ~/ kw("rec").!.?.map(_.isDefined) ~ ident ~ "=" ~ term ~ kw("in") ~ term ) 32 | .map(Let.tupled) 33 | /** `if c then t else e`, desugared to the curried application `if c t e` of the variable `if`. */ def ite[__ : P]: P[Term] = P( kw("if") ~/ term ~ kw("then") ~ term ~ kw("else") ~ term ).map(ite => 34 | App(App(App(Var("if"), ite._1), ite._2), ite._3)) 35 | /** One or more subterms, combined by left-associative application. */ def apps[__ : P]: P[Term] = P( subterm.rep(1).map(_.reduce(App)) ) 36 | 37 | /** A single term that must span the whole input. */ def expr[__ : P]: P[Term] = P( term ~ End ) 38 | 39 | /** A top-level definition `let [rec] x = rhs` (without `in`), as `(isRec, name, rhs)`. */ def toplvl[__ : P]: P[(Boolean, String, Term)] = 40 | P( kw("let") ~/ kw("rec").!.?.map(_.isDefined) ~ ident ~ "=" ~ term ) 41 | /** A program: a sequence of top-level definitions that must span the whole input. */ def pgrm[__ : P]: P[Pgrm] = P( ("" ~ toplvl).rep.map(_.toList) ~ End ).map(Pgrm) 42 | 43 | } 44 | --------------------------------------------------------------------------------
/shared/src/main/scala/simplesub/Typer.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | import scala.collection.mutable 4 | import scala.collection.mutable.{Map => MutMap, Set => MutSet} 5 | import scala.collection.immutable.{SortedSet, SortedMap} 6 | import scala.util.chaining._ 7 | import scala.annotation.tailrec 8 | 9 | final case class TypeError(val msg: String) extends Exception(msg) 10 | 11 | /** A class encapsulating type inference state. 12 | * It uses its own internal representation of types and type variables, using mutable data structures. 13 | * Inferred SimpleType values are then turned into immutable simplesub.Type values. 14 | */ 15 | class Typer(protected val dbg: Boolean) extends TyperDebugging { 16 | 17 | type Ctx = Map[String, TypeScheme] 18 | 19 | val BoolType: Primitive = Primitive("bool") 20 | val IntType: Primitive = Primitive("int") 21 | 22 | val builtins: Ctx = Map( 23 | "true" -> BoolType, 24 | "false" -> BoolType, 25 | "not" -> Function(BoolType, BoolType), 26 | "succ" -> Function(IntType, IntType), 27 | "add" -> Function(IntType, Function(IntType, IntType)), 28 | "if" -> { 29 | val v = freshVar 30 | PolymorphicType(Function(BoolType, Function(v, Function(v, v)))) 31 | } 32 | ) 33 | private val builtinsSize = builtins.size 34 | 35 | /** The main type inference function */ 36 | def inferTypes(pgrm: Pgrm, ctx: Ctx = builtins): List[Either[TypeError, PolymorphicType]] = 37 | pgrm.defs match { 38 | case (isrec, nme, rhs) :: defs => 39 | val ty_sch = try Right(typeLetRhs(isrec, nme, rhs)(ctx)) catch { 40 | case err: TypeError => Left(err) } 41 | ty_sch :: inferTypes(Pgrm(defs), ctx + (nme -> ty_sch.getOrElse(freshVar))) 42 | case Nil => Nil 43 | } 44 | 45 | def inferType(term: Term, ctx: Ctx = builtins): SimpleType = typeTerm(term)(ctx) 46 | 47 | /** Infer the type of a let binding right-hand side. 
*/ 48 | def typeLetRhs(isrec: Boolean, nme: String, rhs: Term)(implicit ctx: Ctx): PolymorphicType = { 49 | val res = if (isrec) { 50 | val e_ty = freshVar 51 | val ty = typeTerm(rhs)(ctx + (nme -> e_ty)) 52 | constrain(ty, e_ty) 53 | e_ty 54 | } else typeTerm(rhs)(ctx) 55 | PolymorphicType(res) 56 | } 57 | 58 | /** Infer the type of a term. */ 59 | def typeTerm(term: Term)(implicit ctx: Ctx): SimpleType = trace(s"T $term") { 60 | lazy val res = freshVar 61 | term match { 62 | case Var(name) => 63 | ctx.getOrElse(name, err("identifier not found: " + name)).instantiate 64 | case Lam(name, body) => 65 | val param = freshVar 66 | val body_ty = typeTerm(body)(ctx + (name -> param)) 67 | Function(param, body_ty) 68 | case App(f, a) => 69 | val f_ty = typeTerm(f) 70 | val a_ty = typeTerm(a) 71 | constrain(f_ty, Function(a_ty, res)) 72 | res 73 | case Lit(n) => 74 | IntType 75 | case Sel(obj, name) => 76 | val obj_ty = typeTerm(obj) 77 | constrain(obj_ty, Record((name, res) :: Nil)) 78 | res 79 | case Rcd(fs) => 80 | Record(fs.map { case (n, t) => (n, typeTerm(t)) }) 81 | case Let(isrec, nme, rhs, bod) => 82 | if (isrec) if (ctx.sizeCompare(builtinsSize) <= 0 && ctx.valuesIterator.forall { case _: PolymorphicType => true; case st: SimpleType => st.getVars.isEmpty }) { // typeLetRhs generalizes every variable of the rhs, so no monomorphic type variable may be in scope; the size check alone is fooled by a lambda parameter shadowing a builtin (e.g. `fun not -> let rec ...`) 83 | val n_ty = typeLetRhs(isrec, nme, rhs) 84 | typeTerm(bod)(ctx + (nme -> n_ty)) 85 | } else err("Unsupported: local recursive let binding") 86 | else typeTerm(App(Lam(nme, bod), rhs)) // a non-recursive let is typed as an immediately-applied lambda, so its binding stays monomorphic 87 | } 88 | }(res => ": " + res) 89 | 90 | /** Constrains the types to enforce a subtyping relationship `lhs` <: `rhs`.
*/ 91 | def constrain(lhs: SimpleType, rhs: SimpleType): Unit = trace(s"C $lhs <: $rhs") { 92 | if (lhs is rhs) return () 93 | (lhs, rhs) match { 94 | case (Bot, _) | (_, Top) => () 95 | case (Function(l0, r0), Function(l1, r1)) => 96 | constrain(l1, l0) 97 | constrain(r0, r1) 98 | case (Record(fs0), Record(fs1)) => 99 | fs1.foreach { case (n1, t1) => 100 | fs0.find(_._1 === n1).fold( 101 | err(s"missing field: $n1 in ${lhs.show}") 102 | ) { case (n0, t0) => constrain(t0, t1) } 103 | } 104 | case (lhs: Variable, rhs: Variable) => 105 | lhs.unifyWith(rhs) 106 | case (lhs: Variable, rhs: ConcreteType) => 107 | lhs.newUpperBound(rhs) 108 | case (lhs: ConcreteType, rhs: Variable) => 109 | rhs.newLowerBound(lhs) 110 | case _ => err(s"cannot constrain ${lhs.show} <: ${rhs.show}") 111 | } 112 | }() 113 | 114 | def err(msg: String): Nothing = throw TypeError(msg) 115 | 116 | private var freshCount: Int = _ // deliberately no initializer: `builtins` (declared above) already calls freshVar while the class is being constructed, and an `= 0` here would run afterwards, reset the counter and hand out duplicate uids 117 | def freshVar: Variable = new Variable(Bot, Top) 118 | 119 | def freshenType(ty: SimpleType): SimpleType = { 120 | val freshened = MutMap.empty[Variable, Variable] 121 | def freshen(ty: SimpleType): SimpleType = ty match { 122 | case tv: Variable => 123 | freshened.get(tv) match { 124 | case Some(tv) => tv 125 | case None => 126 | val v = new Variable( 127 | freshenConcrete(tv.lowerBound), 128 | freshenConcrete(tv.upperBound)) 129 | freshened += tv -> v 130 | v 131 | } 132 | case c: ConcreteType => freshenConcrete(c) 133 | } 134 | def freshenConcrete(ty: ConcreteType): ConcreteType = ty match { 135 | case Function(l, r) => Function(freshen(l), freshen(r)) 136 | case Record(fs) => Record(fs.map(ft => ft._1 -> freshen(ft._2))) 137 | case Primitive(_) | Top | Bot => ty 138 | } 139 | freshen(ty) 140 | } 141 | 142 | def glbConcrete(lhs: ConcreteType, rhs: ConcreteType): ConcreteType = (lhs, rhs) match { 143 | case (Top, _) => rhs 144 | case (_, Top) => lhs 145 | case (Bot, _) | (_, Bot) => Bot 146 | case (Function(l0, r0), Function(l1, r1)) => Function(lub(l0, l1), glb(r0,
r1)) 147 | case (Record(fs0), Record(fs1)) => Record(mergeMap(fs0, fs1)(glb(_, _)).toList) // meet of records: all fields of both sides, shared ones combined with glb 148 | case (Primitive(n0), Primitive(n1)) if n0 === n1 => Primitive(n0) 149 | case _ => Bot 150 | } 151 | def lubConcrete(lhs: ConcreteType, rhs: ConcreteType): ConcreteType = (lhs, rhs) match { // join: records keep only their common fields; incompatible shapes widen to Top 152 | case (Bot, _) => rhs 153 | case (_, Bot) => lhs 154 | case (Top, _) | (_, Top) => Top 155 | case (Function(l0, r0), Function(l1, r1)) => Function(glb(l0, l1), lub(r0, r1)) 156 | case (Record(fs0), Record(fs1)) => 157 | val fs1m = fs1.toMap 158 | Record(fs0.flatMap { 159 | case (n, t0) => fs1m.get(n) match { case Some(t1) => n -> lub(t0, t1) :: Nil; case None => Nil } 160 | }) 161 | case (Primitive(n0), Primitive(n1)) if n0 === n1 => Primitive(n0) 162 | case _ => Top 163 | } 164 | def glb(lhs: SimpleType, rhs: SimpleType): SimpleType = (lhs, rhs) match { // when a variable is involved, the meet is approximated by bounding it from above (two variables are unified) 165 | case (c0: ConcreteType, c1: ConcreteType) => glbConcrete(c0, c1) 166 | case (v0: Variable, v1: Variable) => v0.unifyWith(v1); v1 167 | case (c0: ConcreteType, v1: Variable) => v1.newUpperBound(c0); v1 168 | case (v0: Variable, c1: ConcreteType) => v0.newUpperBound(c1); v0 169 | } 170 | def lub(lhs: SimpleType, rhs: SimpleType): SimpleType = (lhs, rhs) match { // when a variable is involved, the join is approximated by bounding it from below (two variables are unified) 171 | case (c0: ConcreteType, c1: ConcreteType) => lubConcrete(c0, c1) 172 | case (v0: Variable, v1: Variable) => v0.unifyWith(v1); v1 173 | case (c0: ConcreteType, v1: Variable) => v1.newLowerBound(c0); v1 174 | case (v0: Variable, c1: ConcreteType) => v0.newLowerBound(c1); v0 175 | } 176 | 177 | 178 | // The data types used for type inference: 179 | 180 | /** A type that potentially contains universally quantified type variables, 181 | * and which can be instantiated into a SimpleType. */ 182 | sealed abstract class TypeScheme { 183 | def instantiate: SimpleType 184 | } 185 | /** A type with universally quantified type variables 186 | * (every type variable reachable from `body`, including through bounds, is considered quantified and is copied by `instantiate`).
*/ 187 | case class PolymorphicType(body: SimpleType) extends TypeScheme { 188 | def instantiate = freshenType(body) 189 | } 190 | /** A type without universally quantified type variables. */ 191 | sealed abstract class SimpleType extends TypeScheme with SimpleTypeImpl { 192 | def instantiate = this 193 | } 194 | 195 | sealed abstract class ConcreteType extends SimpleType 196 | 197 | case object Top extends ConcreteType 198 | case object Bot extends ConcreteType 199 | 200 | case class Function(lhs: SimpleType, rhs: SimpleType) extends ConcreteType { 201 | override def toString = s"($lhs -> $rhs)" 202 | } 203 | case class Record(fields: List[(String, SimpleType)]) extends ConcreteType { 204 | override def toString = s"{${fields.map(f => s"${f._1}: ${f._2}").mkString(", ")}}" 205 | } 206 | case class Primitive(name: String) extends ConcreteType { 207 | override def toString = name 208 | } 209 | 210 | /** A type variable with mutable bounds. */ 211 | final class Variable( 212 | private var _lowerBound: ConcreteType, 213 | private var _upperBound: ConcreteType, 214 | ) extends SimpleType { 215 | private val _uid: Int = { val n = freshCount; freshCount += 1; n } 216 | def uid: Int = representative._uid 217 | private var _representative: Option[Variable] = None 218 | def representative: Variable = 219 | _representative match { 220 | case Some(v) => 221 | val rep = v.representative 222 | if (rep isnt v) _representative = Some(rep) 223 | rep 224 | case None => this 225 | } 226 | def lowerBound: ConcreteType = representative._lowerBound 227 | def upperBound: ConcreteType = representative._upperBound 228 | def newUpperBound(ub: ConcreteType): Unit = { 229 | occursCheck(ub, true) 230 | val rep = representative 231 | rep._upperBound = glbConcrete(rep._upperBound, ub) 232 | constrain(rep._lowerBound, ub) 233 | } 234 | def newLowerBound(lb: ConcreteType): Unit = { 235 | occursCheck(lb, false) 236 | val rep = representative 237 | rep._lowerBound = lubConcrete(rep._lowerBound, lb) 
238 | constrain(lb, rep._upperBound) 239 | } 240 | private def occursCheck(ty: ConcreteType, dir: Boolean): Unit = { 241 | if (ty.getVars.contains(representative)) { 242 | val boundsStr = ty.showBounds 243 | err(s"Illegal cyclic constraint: $this ${if (dir) "<:" else ":>"} $ty" 244 | + (if (boundsStr.isEmpty) "" else "\n\t\twhere: " + boundsStr)) 245 | } 246 | } 247 | lazy val asTypeVar = new TypeVariable("α", uid) 248 | def unifyWith(that: Variable): Unit = { 249 | val rep0 = representative 250 | val rep1 = that.representative 251 | if (rep0 isnt rep1) { 252 | // Note: these occursCheck calls (and the following ones from newLowerBound/newUpperBound) are pretty 253 | // inefficient, as they incur repeated computation of type variables through getVars: 254 | occursCheck(rep1._lowerBound, false) 255 | occursCheck(rep1._upperBound, true) 256 | rep1.newLowerBound(rep0._lowerBound) 257 | rep1.newUpperBound(rep0._upperBound) 258 | rep0._representative = Some(rep1) 259 | } 260 | } 261 | override def toString: String = 262 | // _representative.fold("α" + uid)(_.toString + "<~" + uid) 263 | "α" + representative.uid 264 | override def hashCode: Int = representative.uid 265 | override def equals(that: Any): Boolean = that match { 266 | case that: Typer#Variable => representative is that.representative 267 | case _ => false 268 | } 269 | } 270 | 271 | 272 | 273 | 274 | def simplifyType(st: SimpleType): SimpleType = { 275 | 276 | val pos, neg = mutable.Set.empty[Variable] 277 | 278 | def analyze(st: SimpleType, pol: Boolean): Unit = st match { 279 | case Record(fs) => fs.foreach(f => analyze(f._2, pol)) 280 | case Function(l, r) => analyze(l, !pol); analyze(r, pol) 281 | case v: Variable => 282 | if ((if (pol) pos else neg).add(v)) // only walk the bound the first time (v, pol) is seen: shared sub-terms would otherwise be re-walked exponentially often, with no new effect 283 | analyze(if (pol) v.lowerBound else v.upperBound, pol) 284 | case Primitive(_) | Top | Bot => () 285 | } 286 | analyze(st, true) 287 | 288 | val mapping = mutable.Map.empty[Variable, SimpleType] 289 | 290 | def transformConcrete(st: ConcreteType, pol: Boolean):
ConcreteType = st match { 291 | case Record(fs) => Record(fs.map(f => f._1 -> transform(f._2, pol))) 292 | case Function(l, r) => Function(transform(l, !pol), transform(r, pol)) 293 | case Primitive(_) | Top | Bot => st 294 | } 295 | def transform(st: SimpleType, pol: Boolean): SimpleType = st match { 296 | case v: Variable => 297 | mapping.getOrElseUpdate(v, 298 | if (v.lowerBound === v.upperBound) transformConcrete(v.lowerBound, pol) 299 | else if (pol && !neg(v)) transformConcrete(v.lowerBound, pol) 300 | else if (!pol && !pos(v)) transformConcrete(v.upperBound, pol) 301 | else new Variable(transformConcrete(v.lowerBound, true), transformConcrete(v.upperBound, false)) 302 | ) 303 | case c: ConcreteType => transformConcrete(c, pol) 304 | } 305 | transform(st, true) 306 | 307 | } 308 | 309 | 310 | type PolarVariable = (Variable, Boolean) 311 | 312 | /** Convert an inferred SimpleType into the immutable Type representation: each variable is unioned (positive position) or intersected (negative position) with its relevant bound. */ 313 | def coalesceType(st: SimpleType): Type = { 314 | def go(st: SimpleType, polarity: Boolean): Type = st match { 315 | case tv: Variable => 316 | val bound = if (polarity) tv.lowerBound else tv.upperBound 317 | val tyVar = tv.representative.asTypeVar // unified variables must share one TypeVariable, otherwise `show` gives them different names 318 | val mrg = if (polarity) Union else Inter 319 | if (bound === (if (polarity) Bot else Top)) tyVar // only a trivial bound can be dropped: ⊥ as a lower bound, ⊤ as an upper bound 320 | else mrg(tyVar, go(bound, polarity)) 321 | case Function(l, r) => FunctionType(go(l, !polarity), go(r, polarity)) 322 | case Record(fs) => RecordType(fs.map(nt => nt._1 -> go(nt._2, polarity))) 323 | case Primitive(n) => PrimitiveType(n) 324 | case Top => PrimitiveType("⊤") 325 | case Bot => PrimitiveType("⊥") 326 | } 327 | go(st, true) 328 | } 329 | 330 | 331 | } 332 | -------------------------------------------------------------------------------- /shared/src/main/scala/simplesub/TyperDebugging.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | import scala.collection.mutable.{Map => MutMap, Set => MutSet,
LinkedHashMap, LinkedHashSet} 4 | import scala.collection.immutable.{SortedMap, SortedSet} 5 | import scala.annotation.tailrec 6 | 7 | /** Inessential methods used to help debugging. */ 8 | abstract class TyperDebugging { self: Typer => 9 | 10 | // Shadow Predef functions with debugging-flag-enabled ones: 11 | def println(msg: => Any): Unit = if (dbg) emitDbg(" " * indent + msg) 12 | def assert(assertion: => Boolean): Unit = if (dbg) scala.Predef.assert(assertion) 13 | 14 | private val noPostTrace: Any => String = _ => "" 15 | 16 | protected var indent = 0 17 | def trace[T](pre: String)(thunk: => T)(post: T => String = noPostTrace): T = { 18 | println(pre) 19 | indent += 1 20 | val res = try thunk finally indent -= 1 21 | if (post isnt noPostTrace) println(post(res)) 22 | res 23 | } 24 | def emitDbg(str: String): Unit = scala.Predef.println(str) 25 | 26 | trait SimpleTypeImpl { self: SimpleType => 27 | 28 | def children: List[SimpleType] = this match { 29 | case tv: Variable => tv.lowerBound :: tv.upperBound :: Nil 30 | case Function(l, r) => l :: r :: Nil 31 | case Record(fs) => fs.map(_._2) 32 | case Primitive(_) => Nil 33 | case Top | Bot => Nil 34 | } 35 | def getVars: Set[Variable] = { 36 | val res = MutSet.empty[Variable] 37 | @tailrec def rec(queue: List[SimpleType]): Unit = queue match { 38 | case (tv: Variable) :: tys => 39 | if (res(tv)) rec(tys) 40 | else { res += tv; rec(tv.children ::: tys) } 41 | case ty :: tys => rec(ty.children ::: tys) 42 | case Nil => () 43 | } 44 | rec(this :: Nil) 45 | SortedSet.from(res)(Ordering.by(_.uid)) 46 | } 47 | def show: String = coalesceType(this).show 48 | def showBounds: String = 49 | getVars.iterator.filter(tv => tv.upperBound =/= Top || tv.lowerBound =/= Bot).map(tv => 50 | tv.toString 51 | + (if (tv.lowerBound === Bot) "" else " :> " + tv.lowerBound) 52 | + (if (tv.upperBound === Top) "" else " <: " + tv.upperBound) 53 | ).mkString(", ") 54 | 55 | } 56 | 57 | } 58 | 
-------------------------------------------------------------------------------- /shared/src/main/scala/simplesub/helpers.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | import scala.util.chaining._ 4 | import scala.collection.mutable.{Map => MutMap, SortedMap => SortedMutMap, Set => MutSet} 5 | import scala.collection.immutable.SortedSet 6 | 7 | 8 | // Helper methods for types 9 | 10 | abstract class TypeImpl { self: Type => 11 | 12 | lazy val typeVarsList: List[TypeVariable] = this match { 13 | case uv: TypeVariable => uv :: Nil 14 | case _ => children.flatMap(_.typeVarsList) 15 | } 16 | 17 | def show: String = { // names the type variables 'a, 'b, ... in order of first occurrence 18 | val vars = typeVarsList.distinct 19 | val ctx = vars.zipWithIndex.map { 20 | case (tv, idx) => 21 | def nme = { // 'a to 'z, then wrap around with a numeric suffix: 'a1 to 'z1, 'a2, ... 22 | val letter = ('a' + idx % 26).toChar.toString 23 | if (idx < 26) letter else letter + (idx / 26) 24 | } 25 | tv -> ("'" + nme) 26 | }.toMap 27 | showIn(ctx, 0) 28 | } 29 | 30 | private def parensIf(str: String, cnd: Boolean): String = if (cnd) "(" + str + ")" else str 31 | def showIn(ctx: Map[TypeVariable, String], outerPrec: Int): String = this match { 32 | case Top => "⊤" 33 | case Bot => "⊥" 34 | case PrimitiveType(name) => name 35 | case uv: TypeVariable => ctx(uv) 36 | case FunctionType(l, r) => parensIf(l.showIn(ctx, 11) + " -> " + r.showIn(ctx, 10), outerPrec > 10) 37 | case RecordType(fs) => fs.map(nt => s"${nt._1}: ${nt._2.showIn(ctx, 0)}").mkString("{", ", ", "}") 38 | case Union(l, r) => parensIf(l.showIn(ctx, 20) + " ∨ " + r.showIn(ctx, 20), outerPrec > 20) 39 | case Inter(l, r) => parensIf(l.showIn(ctx, 25) + " ∧ " + r.showIn(ctx, 25), outerPrec > 25) 40 | } 41 | 42 | def children: List[Type] = this match { 43 | case _: PrimitiveType | _: TypeVariable | Top | Bot => Nil 44 | case FunctionType(l, r) => l :: r :: Nil 45 | case RecordType(fs) => fs.map(_._2) 46 | case Union(l, r) => l :: r :: Nil 47 | case Inter(l, r) => l :: r :: Nil 48 | } 49 | 50 | }
51 | -------------------------------------------------------------------------------- /shared/src/main/scala/simplesub/package.scala: -------------------------------------------------------------------------------- 1 | package object simplesub { 2 | 3 | import scala.collection.mutable 4 | import scala.collection.immutable.SortedMap 5 | 6 | @SuppressWarnings(Array( 7 | "org.wartremover.warts.Equals", 8 | "org.wartremover.warts.AsInstanceOf")) 9 | implicit final class AnyOps[A](self: A) { 10 | def ===(other: A): Boolean = self == other 11 | def =/=(other: A): Boolean = self != other 12 | def is(other: AnyRef): Boolean = self.asInstanceOf[AnyRef] eq other 13 | def isnt(other: AnyRef): Boolean = !(self.asInstanceOf[AnyRef] eq other) 14 | /** An alternative to === when in ScalaTest, which shadows our === */ 15 | def =:=(other: A): Boolean = self == other 16 | } 17 | 18 | implicit class IterableOps[A](private val self: IterableOnce[A]) extends AnyVal { 19 | def mkStringOr( 20 | sep: String = "", start: String = "", end: String = "", els: String = "" 21 | ): String = 22 | if (self.iterator.nonEmpty) self.iterator.mkString(start, sep, end) else els 23 | } 24 | 25 | def mergeOptions[A](lhs: Option[A], rhs: Option[A])(f: (A, A) => A): Option[A] = (lhs, rhs) match { 26 | case (Some(l), Some(r)) => Some(f(l, r)) 27 | case (lhs @ Some(_), _) => lhs 28 | case (_, rhs @ Some(_)) => rhs 29 | case (None, None) => None 30 | } 31 | 32 | def mergeMap[A, B](lhs: Iterable[(A, B)], rhs: Iterable[(A, B)])(f: (B, B) => B): Map[A,B] = 33 | new mutable.ArrayBuffer(lhs.knownSize + rhs.knownSize max 8) 34 | .addAll(lhs).addAll(rhs).groupMapReduce(_._1)(_._2)(f) 35 | 36 | def mergeSortedMap[A: Ordering, B](lhs: Iterable[(A, B)], rhs: Iterable[(A, B)])(f: (B, B) => B): SortedMap[A,B] = 37 | SortedMap.from(mergeMap(lhs, rhs)(f)) 38 | 39 | def closeOver[A](xs: Set[A])(f: A => Set[A]): Set[A] = 40 | closeOverCached(Set.empty, xs)(f) 41 | def closeOverCached[A](done: Set[A], todo: Set[A])(f: A => 
Set[A]): Set[A] = 42 | if (todo.isEmpty) done else { 43 | val newDone = done ++ todo 44 | closeOverCached(newDone, todo.flatMap(f) -- newDone)(f) 45 | } 46 | 47 | } 48 | -------------------------------------------------------------------------------- /shared/src/main/scala/simplesub/syntax.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | 4 | // Terms 5 | 6 | final case class Pgrm(defs: List[(Boolean, String, Term)]) 7 | 8 | sealed abstract class Term 9 | final case class Lit(value: Int) extends Term 10 | final case class Var(name: String) extends Term 11 | final case class Lam(name: String, rhs: Term) extends Term 12 | final case class App(lhs: Term, rhs: Term) extends Term 13 | final case class Rcd(fields: List[(String, Term)]) extends Term 14 | final case class Sel(receiver: Term, fieldName: String) extends Term 15 | final case class Let(isRec: Boolean, name: String, rhs: Term, body: Term) extends Term 16 | 17 | 18 | // Types 19 | 20 | sealed abstract class Type extends TypeImpl 21 | case object Top extends Type 22 | case object Bot extends Type 23 | final case class Union(lhs: Type, rhs: Type) extends Type 24 | final case class Inter(lhs: Type, rhs: Type) extends Type 25 | final case class FunctionType(lhs: Type, rhs: Type) extends Type 26 | final case class RecordType(fields: List[(String, Type)]) extends Type 27 | final case class PrimitiveType(name: String) extends Type 28 | final class TypeVariable(val nameHint: String, val hash: Int) extends Type { 29 | override def toString: String = s"$nameHint:$hash" 30 | } 31 | 32 | -------------------------------------------------------------------------------- /shared/src/test/scala/simplesub/IsolatedTests.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | @SuppressWarnings(Array("org.wartremover.warts.Equals")) 4 | class IsolatedTests extends TypingTestHelpers { 5 | 6 | // This test 
class is for isolating single tests and running them alone 7 | // with sbt command `~testOnly simplesub.IsolatedTests` 8 | 9 | test("isolated") { 10 | 11 | // put your test here 12 | 13 | 14 | } 15 | 16 | } 17 | -------------------------------------------------------------------------------- /shared/src/test/scala/simplesub/OtherTests.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | import org.scalatest.funsuite.AnyFunSuite 4 | 5 | @SuppressWarnings(Array("org.wartremover.warts.Equals")) 6 | class OtherTests extends AnyFunSuite { 7 | /* 8 | test("canonicalization produces LCD") { 9 | 10 | val typer = new Typer(false) with TypeSimplifier 11 | import typer.{assert => _, _} 12 | val tv0, tv1, tv3 = freshVar(0) 13 | 14 | // {f: {B: int, f: 'a}} as 'a – cycle length 2 15 | val st0 = Record("f"->Record("f"->tv0::"B"->IntType::Nil)::Nil) 16 | tv0.lowerBounds ::= st0 17 | 18 | // {f: {B: int, f: {A: int, f: 'a}}} as 'a – cycle length 3 19 | val st1 = Record("f"->Record("f"->Record("f"->tv1::"A"->IntType::Nil)::"B"->IntType::Nil)::Nil) 20 | tv1.lowerBounds ::= st1 21 | tv3.lowerBounds = tv0 :: tv1 :: Nil 22 | 23 | // println(tv3.showBounds) 24 | 25 | val ct = canonicalizeType(tv3) 26 | val sct = simplifyType(ct) 27 | val csct = coalesceCompactType(sct).show 28 | 29 | assert(csct == "{f: {B: int, f: {f: {f: {f: {f: 'a}}}}}} as 'a") // cycle length 6 30 | 31 | } 32 | */ 33 | } 34 | -------------------------------------------------------------------------------- /shared/src/test/scala/simplesub/ParserTests.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | import org.scalatest._ 4 | import fastparse._ 5 | import Parser.expr 6 | import fastparse.Parsed.Failure 7 | import fastparse.Parsed.Success 8 | import org.scalatest.funsuite.AnyFunSuite 9 | 10 | class ParserTests extends AnyFunSuite { 11 | 12 | def doTest(str: String): Unit = { 13 | 
parse(str, expr(_), verboseFailures = true) match { 14 | case Success(value, index) => 15 | // println("OK: " + value) 16 | case f: Failure => 17 | val Failure(expected, failIndex, extra) = f 18 | println(extra.trace()) 19 | println(extra.trace().longAggregateMsg) 20 | assert(false) 21 | } 22 | () 23 | } 24 | 25 | test("basics") { 26 | doTest("1") 27 | doTest("a") 28 | doTest("1 2 3") 29 | doTest("a b c") 30 | doTest("true") 31 | } 32 | 33 | test("let") { 34 | doTest("let a = b in c") 35 | doTest("let a = 1 in 1") 36 | doTest("let a = (1) in 1") 37 | assert(!parse("let true = 0 in true", expr(_)).isSuccess) 38 | } 39 | 40 | test("records") { 41 | doTest("{ a = 1; b = 2 }") 42 | assert(!parse("{ a = 1; b = 2; a = 3 }", expr(_)).isSuccess) 43 | } 44 | 45 | } 46 | -------------------------------------------------------------------------------- /shared/src/test/scala/simplesub/ProgramTests.scala: -------------------------------------------------------------------------------- 1 | package simplesub 2 | 3 | import org.scalatest._ 4 | import fastparse._ 5 | import Parser.pgrm 6 | import fastparse.Parsed.Failure 7 | import fastparse.Parsed.Success 8 | import sourcecode.Line 9 | import org.scalatest.funsuite.AnyFunSuite 10 | 11 | @SuppressWarnings(Array("org.wartremover.warts.Equals")) 12 | class ProgramTests extends AnyFunSuite { 13 | 14 | // TODO port these tests 15 | /* 16 | implicit class ExpectedStr(val str: String)(implicit val line: Line) 17 | 18 | def doTest(str: String)(expected: ExpectedStr*): Unit = { 19 | val dbg = expected.exists(_.str.isEmpty) 20 | val Success(p, index) = parse(str, pgrm(_), verboseFailures = true) 21 | val typer = new Typer(dbg) with TypeSimplifier 22 | val tys = typer.inferTypes(p) 23 | var toPrint: List[String] = Nil 24 | (p.defs lazyZip tys lazyZip expected).foreach { (str, pty, exp) => 25 | if (exp.str.isEmpty) println(s">>> $str") 26 | val ty = pty.fold(err => throw err, _.instantiate) 27 | // val cty = typer.canonicalizeType(ty) 28 | // 
val sty = typer.simplifyType(cty) 29 | // val res = typer.coalesceCompactType(sty).show 30 | val res = typer.coalesceType(ty).show 31 | if (exp.str.nonEmpty) { assert(res == exp.str, "at line " + exp.line.value); () } 32 | else { 33 | toPrint ::= res 34 | println("inferred: " + ty) 35 | println(" where " + ty.showBounds) 36 | println(res) 37 | println("---") 38 | } 39 | } 40 | if (toPrint.nonEmpty) toPrint.reverseIterator.foreach(s => println("Inferred: " + s)) 41 | assert(tys.size == expected.size); () 42 | } 43 | 44 | test("mlsub") { // from https://www.cl.cam.ac.uk/~sd601/mlsub/ 45 | doTest(""" 46 | let id = fun x -> x 47 | let twice = fun f -> fun x -> f (f x) 48 | let object1 = { x = 42; y = id } 49 | let object2 = { x = 17; y = false } 50 | let pick_an_object = fun b -> 51 | if b then object1 else object2 52 | let rec recursive_monster = fun x -> 53 | { thing = x; 54 | self = recursive_monster x } 55 | """)( 56 | "'a -> 'a", 57 | "('a ∨ 'b -> 'a) -> 'b -> 'a", 58 | "{x: int, y: 'a -> 'a}", 59 | "{x: int, y: bool}", 60 | "bool -> {x: int, y: bool ∨ ('a -> 'a)}", 61 | "'a -> {self: 'b, thing: 'a} as 'b", 62 | ) 63 | } 64 | 65 | test("top-level-polymorphism") { 66 | doTest(""" 67 | let id = fun x -> x 68 | let ab = {u = id 0; v = id true} 69 | """)( 70 | "'a -> 'a", 71 | "{u: int, v: bool}", 72 | ) 73 | } 74 | 75 | test("rec-producer-consumer") { 76 | doTest(""" 77 | let rec produce = fun arg -> { head = arg; tail = produce (succ arg) } 78 | let rec consume = fun strm -> add strm.head (consume strm.tail) 79 | 80 | let codata = produce 42 81 | let res = consume codata 82 | 83 | let rec codata2 = { head = 0; tail = { head = 1; tail = codata2 } } 84 | let res = consume codata2 85 | 86 | let rec produce3 = fun b -> { head = 123; tail = if b then codata else codata2 } 87 | let res = fun x -> consume (produce3 x) 88 | 89 | let consume2 = 90 | let rec go = fun strm -> add strm.head (add strm.tail.head (go strm.tail.tail)) 91 | in fun strm -> add strm.head (go 
strm.tail) 92 | // in go 93 | // let rec consume2 = fun strm -> add strm.head (add strm.tail.head (consume2 strm.tail.tail)) 94 | let res = consume2 codata2 95 | """)( 96 | "int -> {head: int, tail: 'a} as 'a", 97 | "{head: int, tail: 'a} as 'a -> int", 98 | "{head: int, tail: 'a} as 'a", 99 | "int", 100 | "{head: int, tail: {head: int, tail: 'a}} as 'a", 101 | "int", 102 | "bool -> {head: int, tail: {head: int, tail: 'a}} as 'a", 103 | // ^ simplifying this would probably require more advanced 104 | // automata-based techniques such as the one proposed by Dolan 105 | "bool -> int", 106 | "{head: int, tail: {head: int, tail: 'a}} as 'a -> int", 107 | "int", 108 | ) 109 | } 110 | 111 | test("misc") { 112 | doTest(""" 113 | // 114 | // From a comment on the blog post: 115 | // 116 | let rec r = fun a -> r 117 | let join = fun a -> fun b -> if true then a else b 118 | let s = join r r 119 | // 120 | // Inspired by [Pottier 98, chap 13.4] 121 | // 122 | let rec f = fun x -> fun y -> add (f x.tail y) (f x y) 123 | let rec f = fun x -> fun y -> add (f x.tail y) (f y x) 124 | let rec f = fun x -> fun y -> add (f x.tail y) (f x y.tail) 125 | let rec f = fun x -> fun y -> add (f x.tail y.tail) (f x.tail y.tail) 126 | let rec f = fun x -> fun y -> add (f x.tail x.tail) (f y.tail y.tail) 127 | let rec f = fun x -> fun y -> add (f x.tail x) (f y.tail y) 128 | let rec f = fun x -> fun y -> add (f x.tail y) (f y.tail x) 129 | // 130 | let f = fun x -> fun y -> if true then { l = x; r = y } else { l = y; r = x } // 2-crown 131 | // 132 | // Inspired by [Pottier 98, chap 13.5] 133 | // 134 | let rec f = fun x -> fun y -> if true then x else { t = f x.t y.t } 135 | """)( 136 | "(⊤ -> 'a) as 'a", 137 | "'a -> 'a -> 'a", 138 | "(⊤ -> 'a) as 'a", 139 | 140 | "{tail: 'a} as 'a -> ⊤ -> int", 141 | "{tail: 'a} as 'a -> {tail: 'b} as 'b -> int", 142 | "{tail: 'a} as 'a -> {tail: 'b} as 'b -> int", 143 | "{tail: 'a} as 'a -> {tail: 'b} as 'b -> int", 144 | "{tail: {tail: 'a} as 'a} -> 
{tail: {tail: 'b} as 'b} -> int", 145 | // ^ Could simplify more `{tail: {tail: 'a} as 'a}` to `{tail: 'a} as 'a` 146 | // This would likely require another hash-consing pass. 147 | // Indeed, currently, we coalesce {tail: ‹{tail: ‹α25›}›} and it's hash-consing 148 | // which introduces the 'a to stand for {tail: ‹α25›} 149 | // ^ Note: MLsub says: 150 | // let rec f = fun x -> fun y -> (f x.tail x) + (f y.tail y) 151 | // val f : ({tail : (rec b = {tail : b})} -> ({tail : {tail : (rec a = {tail : a})}} -> int)) 152 | "{tail: 'a} as 'a -> {tail: {tail: 'b} as 'b} -> int", 153 | // ^ Note: MLsub says: 154 | // let rec f = fun x -> fun y -> (f x.tail x.tail) + (f y.tail y.tail) 155 | // val f : ({tail : {tail : (rec b = {tail : b})}} -> ({tail : {tail : (rec a = {tail : a})}} -> int)) 156 | "{tail: 'a} as 'a -> {tail: {tail: 'b} as 'b} -> int", 157 | 158 | "'a -> 'a -> {l: 'a, r: 'a}", 159 | 160 | "('b ∧ {t: 'a}) as 'a -> {t: 'c} as 'c -> ('b ∨ {t: 'd}) as 'd", 161 | // ^ Note: MLsub says: 162 | // let rec f = fun x -> fun y -> if true then x else { t = f x.t y.t } 163 | // val f : (({t : (rec d = ({t : d} & a))} & a) -> ({t : (rec c = {t : c})} -> ({t : (rec b = ({t : b} | a))} | a))) 164 | // ^ Pottier says a simplified version would essentially be, once translated to MLsub types: 165 | // {t: 'a} as 'a -> 'a -> {t: 'd} as 'd 166 | // but even he does not infer that. 167 | // Notice the loss of connection between the first parameetr and the result, in his proposed type, 168 | // which he says is not necessary as it is actually implied. 169 | // He argues that if 'a <: F 'a and F 'b <: 'b then 'a <: 'b, for a type operator F, 170 | // which does indeed seem true (even in MLsub), 171 | // though leveraging such facts for simplification would require much more advanced reasoning. 
172 | ) 173 | } 174 | */ 175 | 176 | } 177 |
--------------------------------------------------------------------------------
/shared/src/test/scala/simplesub/TypingTestHelpers.scala:
--------------------------------------------------------------------------------
package simplesub

import org.scalatest._
import fastparse._
import Parser.expr
import fastparse.Parsed.Failure
import fastparse.Parsed.Success
import sourcecode.Line
import org.scalatest.funsuite.AnyFunSuite
import ammonite.ops._

/** Base class for the typing test suites.
  *
  * Every test registered through [[test]] gets its own check file `out/<testName>.check`
  * (resolved against the working directory `pwd`), which is truncated when the test starts.
  * Each call to [[doTest]] / [[error]] inside the test body appends the program and its
  * inferred type (or type error) to that file.
  */
@SuppressWarnings(Array("org.wartremover.warts.Equals"))
class TypingTestHelpers extends AnyFunSuite {

  /** Check file of the test currently running; `None` outside of a test body. */
  private var outFile = Option.empty[os.Path]

  /** Registers a test whose `doTest`/`error` calls are recorded in `out/<testName>.check`. */
  override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit pos: org.scalactic.source.Position): Unit = {
    super.test(testName, testTags: _*) {
      // Only one check file may be active at a time: a test body must not start another test.
      assert(outFile.isEmpty)
      val f = pwd/"out"/(testName+".check")
      outFile = Some(f)
      write.over(f, "")
      try testFun
      finally outFile = None
    }
  }

  /** Type-checks `str` and appends the program and its result to the current check file.
    *
    * The program is parsed, its type inferred, simplified and coalesced; the shown type, or
    * `// ERROR: <msg>` when a [[TypeError]] is thrown, is what gets recorded.
    *
    * @param str         source of the program under test
    * @param expected    legacy expected type. It is NOT compared against the result (the check
    *                    files are what gets reviewed now); an empty string turns on debug output
    *                    (console printing and a debugging [[Typer]]).
    * @param expectError when true, the check-file entry is prefixed with `[wrong:] `
    * @param line        call-site line, reported when the program fails to parse
    */
  def doTest(str: String, expected: String = "", expectError: Boolean = false)(implicit line: Line): Unit = {
    val dbg = expected.isEmpty

    if (dbg) println(s">>> $str")
    // Report parse failures explicitly instead of letting a partial pattern throw a MatchError.
    val term = parse(str, expr(_), verboseFailures = true) match {
      case Success(parsed, _) => parsed
      case f: Failure => fail(s"cannot parse test program (at line ${line.value}): ${f.msg}")
    }

    val typer = new Typer(dbg)
    val res = try {
      val tyv = typer.inferType(term)

      if (dbg) {
        println("inferred: " + tyv)
        println(" where " + tyv.showBounds)
      }

      val res0 = typer.simplifyType(tyv)
      if (dbg) {
        println("simplified: " + res0)
        println(" where " + res0.showBounds)
      }

      val shown = typer.coalesceType(res0).show

      if (dbg) {
        println("typed: " + shown)
        println("---")
      }
      // `expected` is intentionally not asserted here; the check file records `shown` instead.
      // assert(shown == expected, "at line " + line.value); ()

      shown
    } catch {
      case e: TypeError =>
        if (dbg) {
          println("ERROR: " + e.msg)
          println("---")
        }
        "// ERROR: " + e.msg
    }
    val out = outFile.getOrElse(fail("doTest/error must be called from within a test body"))
    write.append(out,
      "// " + (if (expectError) "[wrong:] " else "") + str + "\n" + res + "\n\n", createFolders = true)
  }

  /** Records a program that is expected NOT to type-check.
    *
    * The entry is written with the `[wrong:] ` marker and debug output enabled. `msg` is
    * currently not compared: the check file records whatever error (or type) is produced.
    * The implicit `line` lets parse-failure messages point at the caller's line.
    */
  def error(str: String, msg: String)(implicit line: Line): Unit = {
    // assert(intercept[TypeError](doTest(str, "")).msg == msg); ()
    doTest(str, "", expectError = true)
  }

}

--------------------------------------------------------------------------------
/shared/src/test/scala/simplesub/TypingTests.scala:
--------------------------------------------------------------------------------
package simplesub

import fastparse._
import Parser.expr
import fastparse.Parsed.Failure
import fastparse.Parsed.Success

@SuppressWarnings(Array("org.wartremover.warts.Equals"))
class TypingTests extends TypingTestHelpers {

  // TODO remove expected strings: results are now recorded in check files (out/<test name>.check,
  // resolved against the working directory), and the expected strings are no longer compared.

  // In the tests, leave the expected string empty so the inferred type is printed in the console
  // and you can copy and paste it after making sure it is correct.

  test("basic") {
    doTest("42", "int")
    doTest("fun x -> 42", "⊤ -> int")
    doTest("fun x -> x", "'a -> 'a")
    doTest("fun x -> x 42", "(int -> 'a) -> 'a")
    doTest("(fun x -> x) 42", "int")
    doTest("fun f -> fun x -> f (f x) // twice", "('a ∨ 'b -> 'a) -> 'b -> 'a")
    doTest("let twice = fun f -> fun x -> f (f x) in twice", "('a ∨ 'b -> 'a) -> 'b -> 'a")
  }

  test("booleans") {
    doTest("true", "bool")
    doTest("not true", "bool")
    doTest("fun x -> not x", "bool -> bool")
    doTest("(fun x -> not x) true", "bool")
    doTest("fun x -> fun y -> fun z -> if x then y else z",
      "bool -> 'a -> 'a -> 'a")
    doTest("fun x -> fun y -> if x then y else x",
      "'a ∧ bool -> 'a -> 'a")
    doTest("fun x -> { u = not x; v = x }", " ")

    error("succ true",
      "cannot constrain bool <: int")
    error("fun x -> succ (not x)",
      "cannot constrain bool <: int")
    error("(fun x -> not x.f) { f = 123 }",
      "cannot constrain int <: bool")
    error("(fun f -> fun x -> not (f x.u)) false",
      "cannot constrain bool <: 'a -> 'b")
  }

  test("records") {
    doTest("fun x -> x.f", "{f: 'a} -> 'a")
    doTest("{}", "{}") // note: MLsub returns "⊤" (equivalent)
    doTest("{ f = 42 }", "{f: int}")
    doTest("{ f = 42 }.f", "int")
    doTest("(fun x -> x.f) { f = 42 }", "int")
    doTest("fun f -> { x = f 42 }.x", "(int -> 'a) -> 'a")
    doTest("fun f -> { x = f 42; y = 123 }.y", "(int -> ⊤) -> int")
    doTest("if true then { a = 1; b = true } else { b = false; c = 42 }", "{b: bool}")

    doTest("if true then { u = 1; v = 2; w = 3 } else { u = true; v = 4; x = 5 }", " ")
    doTest("if true then fun x -> { u = 1; v = x } else fun y -> { u = y; v = y }", " ")

    error("{ a = 123; b = true }.c",
      "missing field: c in {a: int, b: bool}")
    error("fun x -> { a = x }.b",
      "missing field: b in {a: 'a}")
  }

  test("self-app") {
    doTest("fun x -> x x", "'a ∧ ('a -> 'b) -> 'b")

    doTest("fun x -> x x x", "'a ∧ ('a -> 'a -> 'b) -> 'b")
    doTest("fun x -> fun y -> x y x", "'a ∧ ('b -> 'a -> 'c) -> 'b -> 'c")
    doTest("fun x -> fun y -> x x y", "'a ∧ ('a -> 'b -> 'c) -> 'b -> 'c")
    doTest("(fun x -> x x) (fun x -> x x)", "⊥")

    doTest("fun x -> {l = x x; r = x }",
      "'a ∧ ('a -> 'b) -> {l: 'b, r: 'a}")

    // From https://github.com/stedolan/mlsub
    // Y combinator:
    doTest("(fun f -> (fun x -> f (x x)) (fun x -> f (x x)))",
      "('a -> 'a) -> 'a")
    // Z combinator:
    doTest("(fun f -> (fun x -> f (fun v -> (x x) v)) (fun x -> f (fun v -> (x x) v)))",
      "(('a -> 'b) -> 'c ∧ ('a -> 'b)) -> 'c")
    // Function that takes arbitrarily many arguments:
    doTest("(fun f -> (fun x -> f (fun v -> (x x) v)) (fun x -> f (fun v -> (x x) v))) (fun f -> fun x -> f)",
      "⊤ -> (⊤ -> 'a) as 'a")

    doTest("let rec trutru = fun g -> trutru (g true) in trutru",
      "(bool -> 'a) as 'a -> ⊥")
    doTest("fun i -> if ((i i) true) then true else true",
      "'a ∧ ('a -> bool -> bool) -> bool")
    // ^ for: λi. if ((i i) true) then true else true,
    // Dolan's thesis says MLsub infers: (α → ((bool → bool) ⊓ α)) → bool
    // which does seem equivalent, despite being quite syntactically different
  }

  test("let-poly") {
    doTest("let f = fun x -> x in {a = f 0; b = f true}", "{a: int, b: bool}")
    doTest("fun y -> let f = fun x -> x in {a = f y; b = f true}",
      "'a -> {a: 'a, b: bool}")
    doTest("fun y -> let f = fun x -> y x in {a = f 0; b = f true}",
      "(bool ∨ int -> 'a) -> {a: 'a, b: 'a}")
    doTest("fun y -> let f = fun x -> x y in {a = f (fun z -> z); b = f (fun z -> true)}",
      "'a -> {a: 'a, b: bool}")
    doTest("fun y -> let f = fun x -> x y in {a = f (fun z -> z); b = f (fun z -> succ z)}",
      "'a ∧ int -> {a: 'a, b: int}")

    error("(fun k -> k (fun x -> let tmp = add x 1 in x)) (fun f -> f true)",
      "cannot constrain bool <: int")
    // Let-binding a part in the above test:
    error("(fun k -> let test = k (fun x -> let tmp = add x 1 in x) in test) (fun f -> f true)",
      "cannot constrain bool <: int")
  }

  test("recursion") {
    doTest("let rec f = fun x -> f x.u in f",
      "{u: 'a} as 'a -> ⊥")

    doTest("let rec consume = fun strm -> add strm.head (consume strm.tail) in consume",
      " ")

    // [test:T2]:
    doTest("let rec r = fun a -> r in if true then r else r",
      "(⊤ -> 'a) as 'a")
    // ^ without canonicalization, we get the type:
    // ⊤ -> (⊤ -> 'a) as 'a ∨ (⊤ -> 'b) as 'b
    doTest("let rec l = fun a -> l in let rec r = fun a -> fun a -> r in if true then l else r",
      "(⊤ -> ⊤ -> 'a) as 'a")
    // ^ without canonicalization, we get the type:
    // ⊤ -> (⊤ -> 'a) as 'a ∨ (⊤ -> (⊤ -> ⊤ -> 'b) as 'b)
    doTest("let rec l = fun a -> fun a -> fun a -> l in let rec r = fun a -> fun a -> r in if true then l else r",
      "(⊤ -> ⊤ -> ⊤ -> ⊤ -> ⊤ -> ⊤ -> 'a) as 'a") // 6 is the LCM (least common multiple) of 3 and 2
    // ^ without canonicalization, we get the type:
    // ⊤ -> ⊤ -> (⊤ -> ⊤ -> 'a) as 'a ∨ (⊤ -> (⊤ -> ⊤ -> ⊤ -> 'b) as 'b)

    // from https://www.cl.cam.ac.uk/~sd601/mlsub/
    doTest("let rec recursive_monster = fun x -> { thing = x; self = recursive_monster x } in recursive_monster",
      "'a -> {self: 'b, thing: 'a} as 'b")
  }

  test("random") {
    doTest("(let rec x = {a = x; b = x} in x)", "{a: 'a, b: 'a} as 'a")
    doTest("(let rec x = fun v -> {a = x v; b = x v} in x)", "⊤ -> {a: 'a, b: 'a} as 'a")
    error("let rec x = (let rec y = {u = y; v = (x y)} in 0) in 0", "cannot constrain int <: 'a -> 'b")
    doTest("(fun x -> (let y = (x x) in 0))", "'a ∧ ('a -> ⊤) -> int")
    doTest("(let rec x = (fun y -> (y (x x))) in x)", "('a -> ('a ∧ ('a -> 'b)) as 'b) -> 'a")
    // ^ Note: without canonicalization, we get the simpler: ('b -> 'b ∧ 'a) as 'a -> 'b
    doTest("fun next -> 0", "⊤ -> int")
    doTest("((fun x -> (x x)) (fun x -> x))", "('b ∨ ('b -> 'a)) as 'a")
    doTest("(let rec x = (fun y -> (x (y y))) in x)", "('b ∧ ('b -> 'a)) as 'a -> ⊥")
    doTest("fun x -> (fun y -> (x (y y)))", "('a -> 'b) -> 'c ∧ ('c -> 'a) -> 'b")
    doTest("(let rec x = (let y = (x x) in (fun z -> z)) in x)", "'a -> ('a ∨ ('a -> 'b)) as 'b")
    doTest("(let rec x = (fun y -> (let z = (x x) in y)) in x)", "'a -> ('a ∨ ('a -> 'b)) as 'b")
    doTest("(let rec x = (fun y -> {u = y; v = (x x)}) in x)",
      "'a -> {u: 'a ∨ ('a -> 'b), v: 'c} as 'c as 'b")
    doTest("(let rec x = (fun y -> {u = (x x); v = y}) in x)",
      "'a -> {u: 'c, v: 'a ∨ ('a -> 'b)} as 'c as 'b")
    doTest("(let rec x = (fun y -> (let z = (y x) in y)) in x)", "('b ∧ ('a -> ⊤) -> 'b) as 'a")
    doTest("(fun x -> (let y = (x x.v) in 0))", "{v: 'a} ∧ ('a -> ⊤) -> int")
    doTest("let rec x = (let y = (x x) in (fun z -> z)) in (x (fun y -> y.u))", // [test:T1]
      "'a ∨ ('a ∧ {u: 'b} -> ('a ∨ 'b ∨ ('a ∧ {u: 'b} -> 'c)) as 'c)")
    // ^ Note: without canonicalization, we get the simpler:
    // ('b ∨ ('b ∧ {u: 'c} -> 'a ∨ 'c)) as 'a
  }

  test("occurs-check") {

    error("fun x -> x.u x", "")
    error("fun x -> x.u {v=x}", "")
    doTest("fun x -> x.u x.v", " ")
    error("fun x -> x.u.v x", "")

  }


}
--------------------------------------------------------------------------------