├── .github
└── workflows
│ └── scala.yml
├── .gitignore
├── .scala-steward.conf
├── LICENSE
├── README.md
├── bin
└── simple-algebraic-subtyping-opt.js
├── build.sbt
├── index.css
├── index.html
├── js
└── src
│ └── main
│ └── scala
│ └── Main.scala
├── notes
└── interpretation-of-as-types.txt
├── out
├── basic.check
├── booleans.check
├── isolated.check
├── let-poly.check
├── occurs-check.check
├── random.check
├── records.check
├── recursion.check
└── self-app.check
├── project
├── build.properties
└── plugins.sbt
└── shared
└── src
├── main
└── scala
│ └── simplesub
│ ├── Parser.scala
│ ├── Typer.scala
│ ├── TyperDebugging.scala
│ ├── helpers.scala
│ ├── package.scala
│ └── syntax.scala
└── test
└── scala
└── simplesub
├── IsolatedTests.scala
├── OtherTests.scala
├── ParserTests.scala
├── ProgramTests.scala
├── TypingTestHelpers.scala
└── TypingTests.scala
/.github/workflows/scala.yml:
--------------------------------------------------------------------------------
1 | name: Scala CI
2 |
3 | on:
4 | push:
5 | branches: [ master ]
6 | pull_request:
7 | branches: [ master ]
8 |
9 | jobs:
10 | build:
11 |
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v2
16 | - name: Set up JDK 1.8
17 | uses: actions/setup-java@v1
18 | with:
19 | java-version: 1.8
20 | - name: Run tests
21 | run: sbt test
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | target/
2 | .vscode/
3 | .bloop/
4 | .metals/
5 | .bsp/
6 | project/Dependencies.scala
7 | project/metals.sbt
8 | project/project/
9 | **.worksheet.sc
10 | mlsub/
11 |
--------------------------------------------------------------------------------
/.scala-steward.conf:
--------------------------------------------------------------------------------
1 | updates.ignore = [ { groupId = "org.wartremover" } ]
2 | pullRequests.frequency = "45 days"
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Lionel Parreaux
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Simpler-sub Algorithm for Type Inference with Subtyping
2 |
3 | This repository shows the implementation of **Simpler-sub**,
4 | an alternative algorithm to [Simple-sub](https://github.com/LPTK/simple-sub) which is much easier to understand but also much more limited,
5 | though it is probably enough for many practical use cases.
6 |
7 | An online demo is available here: https://lptk.github.io/simpler-sub/
8 |
9 |
10 | ## Simplifications
11 |
12 | By contrast to Simple-sub, Simpler-sub does not support:
13 |
14 | * (1) Recursive types: any recursive constraint will yield an error instead.
15 |
16 | If your language has externally-defined recursive data types (such as algebraic data types),
17 | you don't strictly need recursive types anyway.
18 | But in order to support field-recursing definitions of the form `let foo x = ... foo x.f ...`, you'll need a way to make it clear to the type checker which recursive data type's `f` field this `foo` recurses on
19 | (otherwise you'll get a recursive constraint error between inferred record types).
20 | One way to do that is to either reject the use of overloaded field names (like Haskell 98)
21 | or use a type class for field selection instead of subtyping (like recent Haskell)
22 | or use contextual information to disambiguate field selections when possible (like OCaml).
23 |
24 | The absence of recursive types makes some of the type inference algorithms simpler and more efficient,
25 | as they can now freely decompose types inductively without having to carry a cache around.
26 |
27 | * (2) Nested let polymorphism: in this prototype, a _local_ (i.e., nested) `let ... in ...` binding will never be assigned a polymorphic type.
28 |
29 | This simplifies the approach further in that we don't have to deal with levels.
30 | And there is precedent for it,
31 | for example see the paper [Let Should not be Generalised](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/tldi10-vytiniotis.pdf) by Dimitrios Vytiniotis, Simon Peyton Jones, and Tom Schrijvers.
32 |
33 | * (3) Precise type-variable-to-type-variable constraints: any constraint between two type variables immediately unifies the two variables.
34 |
35 | The consequence of this is not as dire as one may think,
36 | thanks to pervasive polymorphism (though also see Restriction 2).
37 | Programs that would typically exhibit problematic loss of precision with this approach are those similar to
38 | `let test f x y = if f x then x else y`,
39 | which is typed by Simple-sub as `('a -> bool) -> 'a ∧ 'b -> 'b -> 'b`,
40 | but is now typed as `('a -> bool) -> 'a -> 'a -> 'a`
41 | – notice that `y` is forced to be typed as `'a`, the parameter type taken by `f`,
42 | even though it never flows into `f` and has in fact nothing to do with it.
43 |
44 | In Simple-sub, and in algebraic subtyping in general,
45 | precise graphs of type variable inequalities are recorded,
46 | and then need to be simplified aggressively before being displayed to the user.
47 | Unifying type variables aggressively, on the other hand,
48 | forces inferred type graphs to be almost as simple as in traditional Hindley-Milner type inference,
49 | and makes inferred type simplification much easier.
50 |
51 | Restriction (3) destroys principal type inference:
52 | there may now be well-typed terms that the type inference approach rejects.
53 |
54 | Restriction (1) destroys the principal type property:
55 | there are now terms which cannot be ascribed a single most precise type
56 | – those which would have been typed by Simple-sub through a recursive type,
57 | but which can still be given less precise non-recursive types.
58 | Simpler-sub will in fact plainly reject any such term.
59 |
60 | Each of these simplifications could be made independently, starting from Simple-sub.
61 | I chose to implement them all together in this project to show how simple the result could look,
62 | while still being quite useful as a possible foundation for type inference in languages with subtyping.
63 |
64 |
65 |
--------------------------------------------------------------------------------
/build.sbt:
--------------------------------------------------------------------------------
1 | import Wart._
2 |
3 | enablePlugins(ScalaJSPlugin)
4 |
5 | ThisBuild / scalaVersion := "2.13.8"
6 | ThisBuild / version := "0.1.0-SNAPSHOT"
7 | ThisBuild / organization := "io.lptk"
8 | ThisBuild / organizationName := "LPTK"
9 |
10 | lazy val root = project.in(file("."))
11 | .aggregate(simplesubJS, simplesubJVM)
12 | .settings(
13 | publish := {},
14 | publishLocal := {}
15 | )
16 |
17 | lazy val simplesub = crossProject(JSPlatform, JVMPlatform).in(file("."))
18 | .settings(
19 | name := "simple-algebraic-subtyping",
20 | scalacOptions ++= Seq(
21 | "-deprecation",
22 | "-feature",
23 | "-unchecked",
24 | "-language:higherKinds",
25 | "-Ywarn-value-discard",
26 | ),
27 | wartremoverWarnings ++= Warts.allBut(
28 | Recursion, Throw, Nothing, Return, While,
29 | Var, MutableDataStructures, NonUnitStatements,
30 | DefaultArguments, ImplicitParameter, StringPlusAny,
31 | JavaSerializable, Serializable, Product,
32 | LeakingSealed,
33 | Option2Iterable, TraversableOps,
34 | Any,
35 | ),
36 | libraryDependencies += "org.scalatest" %% "scalatest" % "3.2.12" % Test,
37 | libraryDependencies += "com.lihaoyi" %%% "fastparse" % "2.3.3",
38 | libraryDependencies += "com.lihaoyi" %%% "sourcecode" % "0.2.8",
39 | libraryDependencies += "com.lihaoyi" %% "ammonite-ops" % "2.4.0",
40 | )
41 | .jsSettings(
42 | scalaJSUseMainModuleInitializer := true,
43 | libraryDependencies += "org.scala-js" %%% "scalajs-dom" % "2.1.0",
44 | libraryDependencies += "be.doeraene" %%% "scalajs-jquery" % "1.0.0",
45 | ThisBuild / evictionErrorLevel := Level.Info,
46 | )
47 |
48 | lazy val simplesubJVM = simplesub.jvm
49 | lazy val simplesubJS = simplesub.js
50 |
--------------------------------------------------------------------------------
/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | background-color: #eee8d5;
3 | max-width: 1200px;
4 | margin: 0px auto;
5 | padding: 5px;
6 | color: #073642;
7 | }
8 | h1, p{
9 | font-family: sans-serif;
10 | padding: 0px 10px;
11 | }
12 | #content {
13 | width: 100%;
14 | padding: 0px;
15 | height: 450px;
16 | }
17 | #content::after {
18 | content: ".";
19 | visibility: hidden;
20 | display: block;
21 | height: 0;
22 | clear: both;
23 | }
24 |
25 | #left, #right, textarea {
26 | width: 50%;
27 | height: 100%;
28 | padding: 10px;
29 | box-sizing: border-box;
30 | font-family: monospace;
31 | }
32 |
33 | #left {
34 | float: left;
35 | }
36 |
37 | #right {
38 | float: right;
39 | }
40 |
41 | #simple-sub-input, #simple-sub-output {
42 | width: 100%;
43 | height: 100%;
44 | box-sizing: border-box;
45 | border-radius: 10px;
46 | padding: 10px;
47 | background-color: #002b36;
48 | /* color: #93a1a1; */
49 | /* color: #cacaca; */
50 | color: #efefef;
51 | font-size: 14px;
52 | }
53 | #simple-sub-input {
54 | resize: none;
55 | outline: none;
56 | }
57 |
58 | #simple-sub-output {
59 | border: 1px solid black;
60 | }
61 |
62 | .binding {
63 | font-family: monospace;
64 | white-space: pre;
65 | }
66 |
67 | .name::before {
68 | content: "val ";
69 | color: #839496;
70 | }
71 |
72 | .name::after {
73 | content: " : ";
74 | color: #839496;
75 | }
76 |
77 | .name, .type {
78 | display: inline;
79 | }
80 | .name {
81 | color: #93a1a1;
82 | font-weight: bold;
83 | }
84 | .type {
85 | color: #859900;
86 | font-weight: bold;
87 | display: inline-block;
88 | }
89 |
--------------------------------------------------------------------------------
/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | Simpler-sub demonstration
7 |
8 |
9 | Simpler -sub demonstration
10 |
11 |
12 |
30 |
33 |
34 |
37 |
38 |
39 |
40 | The code is available on github .
41 |
42 | Credit: the CSS style sheet of this page was shamelessly stolen from the MLsub demo page .
43 |
44 |
--------------------------------------------------------------------------------
/js/src/main/scala/Main.scala:
--------------------------------------------------------------------------------
1 | import scala.util.Try
2 | import scala.scalajs.js.annotation.JSExportTopLevel
3 | import org.scalajs.dom
4 | import org.scalajs.dom.document
5 |
6 | object Main {
7 | def main(args: Array[String]): Unit = {
8 | val source = document.querySelector("#simple-sub-input")
9 | update(source.textContent)
10 | source.addEventListener("input", typecheck)
11 | }
12 | @JSExportTopLevel("typecheck")
13 | def typecheck(e: dom.UIEvent): Unit = {
14 | e.target match {
15 | case elt: dom.HTMLTextAreaElement =>
16 | update(elt.value)
17 | }
18 | }
19 | def update(str: String): Unit = {
20 | // println(s"Input: $str")
21 | val target = document.querySelector("#simple-sub-output")
22 | target.innerHTML = Try {
23 | import fastparse._
24 | import fastparse.Parsed.{Success, Failure}
25 | import simplesub.Parser.pgrm
26 | import simplesub.TypeError
27 | parse(str, pgrm(_), verboseFailures = false) match {
28 | case f: Failure =>
29 | val Failure(err, index, extra) = f
30 | // this line-parsing logic was copied from fastparse internals:
31 | val lineNumberLookup = fastparse.internal.Util.lineNumberLookup(str)
32 | val line = lineNumberLookup.indexWhere(_ > index) match {
33 | case -1 => lineNumberLookup.length - 1
34 | case n => math.max(0, n - 1)
35 | }
36 | val lines = str.split('\n')
37 | val lineStr = lines(line min lines.length - 1)
38 | "Parse error: " + extra.trace().msg +
39 | s" at line $line:$lineStr "
40 | case Success(p, index) =>
41 | // println(s"Parsed: $p")
42 | object Typer extends simplesub.Typer(dbg = false) {
43 | import simplesub._
44 | // Sadly, the original `inferTypes` version does not seem to work in JavaScript, as it raises a
45 | // "RangeError: Maximum call stack size exceeded"
46 | // So we have to go with this uglier one:
47 | def inferTypesJS(
48 | pgrm: Pgrm,
49 | ctx: Ctx = builtins,
50 | stopAtFirstError: Boolean = true,
51 | ): List[Either[TypeError, PolymorphicType]] = {
52 | var defs = pgrm.defs
53 | var curCtx = ctx
54 | var res = collection.mutable.ListBuffer.empty[Either[TypeError, PolymorphicType]]
55 | while (defs.nonEmpty) {
56 | val (isrec, nme, rhs) = defs.head
57 | defs = defs.tail
58 | val ty_sch = try Right(typeLetRhs(isrec, nme, rhs)(curCtx)) catch {
59 | case err: TypeError =>
60 | if (stopAtFirstError) defs = Nil
61 | Left(err)
62 | }
63 | res += ty_sch
64 | curCtx += (nme -> ty_sch.getOrElse(freshVar))
65 | }
66 | res.toList
67 | }
68 | }
69 | val tys = Typer.inferTypesJS(p)
70 | (p.defs.zipWithIndex lazyZip tys).map {
71 | case ((d, i), Right(ty)) =>
72 | val ity = ty.instantiate
73 | println(s"Typed `${d._2}` as: $ity")
74 | println(s" where: ${ity.showBounds}")
75 | /*
76 | val com = Typer.canonicalizeType(ty.instantiate)
77 | println(s"Compact type before simplification: ${com}")
78 | val sim = Typer.simplifyType(com)
79 | println(s"Compact type after simplification: ${sim}")
80 | val exp = Typer.coalesceCompactType(sim)
81 | */
82 | val sim = Typer.simplifyType(ity)
83 | println(s"Type after simplification: ${sim}")
84 | val exp = Typer.coalesceType(sim)
85 | s"""
86 | val
87 | ${d._2} :
88 | ${exp.show}
89 | """
90 | case ((d, i), Left(TypeError(msg))) =>
91 | s"""
92 | Type error in ${d._2} : $msg
93 | """
94 | }.mkString(" ")
95 | }
96 | }.fold(err => s"""
97 |
98 | Unexpected error: ${err}${
99 | err.printStackTrace
100 | err
101 | } """, identity)
102 | }
103 | }
104 |
--------------------------------------------------------------------------------
/notes/interpretation-of-as-types.txt:
--------------------------------------------------------------------------------
1 |
2 | Two interpretations of types of the form T['a] as 'a
3 | - if interpreted as a proper recursive (a.k.a., infinite) type,
4 | it means ('a -> Bot) as 'a is too restrictive a type for self-app
5 | (it's specialized for applying self-app to itself, resulting in non-term and the Bot type as expected),
6 | but the original type ('a -> 'b) as 'a works fine.
7 | explanation:
8 | 'b cannot be simplified away, because it appears in both pos and neg,
9 | since 'a itself does (through its own recursive occurrence!)
10 | - if interpreted as a type variable bounded with its body accordingly to its position,
11 | then ('a -> Bot) as 'a is STILL NOT an acceptable type for self-app
12 | For instance, let's see what happens when we pass it (lam x. 42) of type Top -> Int;
13 | we constrain ?a <: (Top -> Int) -> ?b for some fresh ?b, where ?a :> (?a -> Bot)
14 | this leads to (?a -> Bot) <: (Top -> Int) -> ?b
15 | decomposed as (Top -> Int) <: ?a and Bot <: ?b
16 | which leads to Top -> Int <: Top -> Int
17 | So the constraints work out, but we end up with result type ?b where ?b :> Bot,
18 | which is obviously wrong! (we should have gotten an Int)
19 | And this also happens for the original type ('a -> 'b) as 'a,
20 | which is however valid under the other interpretation!
21 | Here, we may need to think harder about the phrase "bounded accordingly to its position"
22 | Again, the position of 'a really is both pos and neg, so we should really interpret the type as
23 | ?a :> ('a -> 'b) <: ('a -> 'b)
24 | Under this revised interpretation, things work out when applying it to (lam x. 42):
25 | we constrain ?a <: (Top -> Int) -> ?c for some fresh ?c, where ?a :> (?a -> ?b) <: (?a -> ?b)
26 | this leads to ?a -> ?b <: (Top -> Int) -> ?c
27 | decomposed as (Top -> Int) <: ?a and ?b <: Int
28 | which leads to Top -> Int <: Top -> Int and Top -> Int <: ?a -> ?b
29 | decomposed as ?a <: Top and Int <: ?b
30 | which leads to Int <: Int
31 | So the constraints work out, and we do get the appropriate result type Int
32 |
33 | Conclusion:
34 | Both interpretations of "T['a] as 'a" as infinite types and as type variables seem valid and equivalent.
35 | We can treat "T['a] as 'a" as representing a type variable ?a upper and lower-bounded by T[?a].
36 | In the particular case where it appears only positively (resp. neg.) in the original type AS WELL
37 | AS in T[?a] itself, then we can forget about the lower bound (resp. upper bound).
38 |
39 | Note:
40 | The bounds of a recursive type that appears both positively and negatively have no reason to be the same;
41 | so our interpretation above is unlikely to capture the full generality of recursive type variables.
42 | I have changed the type expansion algorithm to only produce recursive type variables which occur only
43 | in positive or only in negative positions (the change was easy). Before that, the algorithm was probably
44 | wrong anyways, under our both-bounds interpretation.
45 | We'll no longer infer things like "('a -> 'b) as 'a", but rather "(('a -> 'b) -> 'b) as 'a",
46 | which is easier to co-occurrence-analyse using the current infrastructure (otherwise, we'd have to keep
47 | track of the fact some recursive variables may appear both posly and negly as we do the analysis!)
48 |
--------------------------------------------------------------------------------
/out/basic.check:
--------------------------------------------------------------------------------
1 | // 42
2 | int
3 |
4 | // fun x -> 42
5 | ⊤ -> int
6 |
7 | // fun x -> x
8 | 'a -> 'a
9 |
10 | // fun x -> x 42
11 | (int -> 'a) -> 'a
12 |
13 | // (fun x -> x) 42
14 | int
15 |
16 | // fun f -> fun x -> f (f x) // twice
17 | ('a -> 'a) -> 'a -> 'a
18 |
19 | // let twice = fun f -> fun x -> f (f x) in twice
20 | ('a -> 'a) -> 'a -> 'a
21 |
22 |
--------------------------------------------------------------------------------
/out/booleans.check:
--------------------------------------------------------------------------------
1 | // true
2 | bool
3 |
4 | // not true
5 | bool
6 |
7 | // fun x -> not x
8 | bool -> bool
9 |
10 | // (fun x -> not x) true
11 | bool
12 |
13 | // fun x -> fun y -> fun z -> if x then y else z
14 | bool -> 'a -> 'a -> 'a
15 |
16 | // fun x -> fun y -> if x then y else x
17 | 'a ∧ bool -> 'a ∧ bool -> 'a
18 |
19 | // fun x -> { u = not x; v = x }
20 | 'a ∧ bool -> {u: bool, v: 'a}
21 |
22 | // [wrong:] succ true
23 | // ERROR: cannot constrain bool <: int
24 |
25 | // [wrong:] fun x -> succ (not x)
26 | // ERROR: cannot constrain bool <: int
27 |
28 | // [wrong:] (fun x -> not x.f) { f = 123 }
29 | // ERROR: cannot constrain int <: bool
30 |
31 | // [wrong:] (fun f -> fun x -> not (f x.u)) false
32 | // ERROR: cannot constrain bool <: 'a -> 'b
33 |
34 |
--------------------------------------------------------------------------------
/out/isolated.check:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LPTK/simpler-sub/afe92450bd8803c0de13054f07eedb59f86f0373/out/isolated.check
--------------------------------------------------------------------------------
/out/let-poly.check:
--------------------------------------------------------------------------------
1 | // let f = fun x -> x in {a = f 0; b = f true}
2 | {a: ⊤, b: ⊤}
3 |
4 | // fun y -> let f = fun x -> x in {a = f y; b = f true}
5 | 'a -> {a: 'a ∨ bool, b: 'a ∨ bool}
6 |
7 | // fun y -> let f = fun x -> y x in {a = f 0; b = f true}
8 | (⊤ -> 'a) -> {a: 'a, b: 'a}
9 |
10 | // fun y -> let f = fun x -> x y in {a = f (fun z -> z); b = f (fun z -> true)}
11 | 'a -> {a: 'a ∨ bool, b: 'a ∨ bool}
12 |
13 | // fun y -> let f = fun x -> x y in {a = f (fun z -> z); b = f (fun z -> succ z)}
14 | int -> {a: int, b: int}
15 |
16 | // [wrong:] (fun k -> k (fun x -> let tmp = add x 1 in x)) (fun f -> f true)
17 | // ERROR: cannot constrain bool <: int
18 |
19 | // [wrong:] (fun k -> let test = k (fun x -> let tmp = add x 1 in x) in test) (fun f -> f true)
20 | // ERROR: cannot constrain bool <: int
21 |
22 |
--------------------------------------------------------------------------------
/out/occurs-check.check:
--------------------------------------------------------------------------------
1 | // [wrong:] fun x -> x.u x
2 | // ERROR: Illegal cyclic constraint: α1 <: (α0 -> α2)
3 | where: α0 <: {u: α1}
4 |
5 | // [wrong:] fun x -> x.u {v=x}
6 | // ERROR: Illegal cyclic constraint: α1 <: ({v: α0} -> α2)
7 | where: α0 <: {u: α1}
8 |
9 | // fun x -> x.u x.v
10 | {u: 'a -> 'b, v: 'a} -> 'b
11 |
12 | // [wrong:] fun x -> x.u.v x
13 | // ERROR: Illegal cyclic constraint: α2 <: (α0 -> α3)
14 | where: α0 <: {u: α1}, α1 <: {v: α2}
15 |
16 |
--------------------------------------------------------------------------------
/out/random.check:
--------------------------------------------------------------------------------
1 | // (let rec x = {a = x; b = x} in x)
2 | // ERROR: Illegal cyclic constraint: α0 :> {a: α0, b: α0}
3 |
4 | // (let rec x = fun v -> {a = x v; b = x v} in x)
5 | // ERROR: Illegal cyclic constraint: α3 :> {a: α3, b: α3}
6 |
7 | // [wrong:] let rec x = (let rec y = {u = y; v = (x y)} in 0) in 0
8 | // ERROR: Unsupported: local recursive let binding
9 |
10 | // (fun x -> (let y = (x x) in 0))
11 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2)
12 |
13 | // (let rec x = (fun y -> (y (x x))) in x)
14 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2)
15 |
16 | // fun next -> 0
17 | ⊤ -> int
18 |
19 | // ((fun x -> (x x)) (fun x -> x))
20 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1)
21 |
22 | // (let rec x = (fun y -> (x (y y))) in x)
23 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α2)
24 |
25 | // fun x -> (fun y -> (x (y y)))
26 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α2)
27 |
28 | // (let rec x = (let y = (x x) in (fun z -> z)) in x)
29 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α3)
30 |
31 | // (let rec x = (fun y -> (let z = (x x) in y)) in x)
32 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α3)
33 |
34 | // (let rec x = (fun y -> {u = y; v = (x x)}) in x)
35 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2)
36 |
37 | // (let rec x = (fun y -> {u = (x x); v = y}) in x)
38 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2)
39 |
40 | // (let rec x = (fun y -> (let z = (y x) in y)) in x)
41 | // ERROR: Illegal cyclic constraint: α0 :> (α4 -> α4)
42 | where: α4 <: (α0 -> α2)
43 |
44 | // (fun x -> (let y = (x x.v) in 0))
45 | ⊥ -> int
46 |
47 | // let rec x = (let y = (x x) in (fun z -> z)) in (x (fun y -> y.u))
48 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α3)
49 |
50 |
--------------------------------------------------------------------------------
/out/records.check:
--------------------------------------------------------------------------------
1 | // fun x -> x.f
2 | {f: 'a} -> 'a
3 |
4 | // {}
5 | {}
6 |
7 | // { f = 42 }
8 | {f: int}
9 |
10 | // { f = 42 }.f
11 | int
12 |
13 | // (fun x -> x.f) { f = 42 }
14 | int
15 |
16 | // fun f -> { x = f 42 }.x
17 | (int -> 'a) -> 'a
18 |
19 | // fun f -> { x = f 42; y = 123 }.y
20 | (int -> ⊤) -> int
21 |
22 | // if true then { a = 1; b = true } else { b = false; c = 42 }
23 | {b: bool}
24 |
25 | // if true then { u = 1; v = 2; w = 3 } else { u = true; v = 4; x = 5 }
26 | {u: ⊤, v: int}
27 |
28 | // if true then fun x -> { u = 1; v = x } else fun y -> { u = y; v = y }
29 | 'a -> {u: 'a ∨ int, v: 'a ∨ int}
30 |
31 | // [wrong:] { a = 123; b = true }.c
32 | // ERROR: missing field: c in {a: int, b: bool}
33 |
34 | // [wrong:] fun x -> { a = x }.b
35 | // ERROR: missing field: b in {a: 'a}
36 |
37 |
--------------------------------------------------------------------------------
/out/recursion.check:
--------------------------------------------------------------------------------
1 | // let rec f = fun x -> f x.u in f
2 | // ERROR: Illegal cyclic constraint: α2 <: {u: α2}
3 |
4 | // let rec consume = fun strm -> add strm.head (consume strm.tail) in consume
5 | // ERROR: Illegal cyclic constraint: α4 <: {head: α2, tail: α4}
6 | where: α2 <: int
7 |
8 | // let rec r = fun a -> r in if true then r else r
9 | // ERROR: Illegal cyclic constraint: α0 :> (α1 -> α0)
10 |
11 | // let rec l = fun a -> l in let rec r = fun a -> fun a -> r in if true then l else r
12 | // ERROR: Illegal cyclic constraint: α0 :> (α1 -> α0)
13 |
14 | // let rec l = fun a -> fun a -> fun a -> l in let rec r = fun a -> fun a -> r in if true then l else r
15 | // ERROR: Illegal cyclic constraint: α0 :> (α1 -> (α2 -> (α3 -> α0)))
16 |
17 | // let rec recursive_monster = fun x -> { thing = x; self = recursive_monster x } in recursive_monster
18 | // ERROR: Illegal cyclic constraint: α2 :> {thing: α1, self: α2}
19 |
20 |
--------------------------------------------------------------------------------
/out/self-app.check:
--------------------------------------------------------------------------------
1 | // fun x -> x x
2 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1)
3 |
4 | // fun x -> x x x
5 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1)
6 |
7 | // fun x -> fun y -> x y x
8 | // ERROR: Illegal cyclic constraint: α2 <: (α0 -> α3)
9 | where: α0 <: (α1 -> α2)
10 |
11 | // fun x -> fun y -> x x y
12 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2)
13 |
14 | // (fun x -> x x) (fun x -> x x)
15 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1)
16 |
17 | // fun x -> {l = x x; r = x }
18 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α1)
19 |
20 | // (fun f -> (fun x -> f (x x)) (fun x -> f (x x)))
21 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α2)
22 |
23 | // (fun f -> (fun x -> f (fun v -> (x x) v)) (fun x -> f (fun v -> (x x) v)))
24 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α3)
25 |
26 | // (fun f -> (fun x -> f (fun v -> (x x) v)) (fun x -> f (fun v -> (x x) v))) (fun f -> fun x -> f)
27 | // ERROR: Illegal cyclic constraint: α1 <: (α1 -> α3)
28 |
29 | // let rec trutru = fun g -> trutru (g true) in trutru
30 | // ERROR: Illegal cyclic constraint: α2 <: (bool -> α2)
31 |
32 | // fun i -> if ((i i) true) then true else true
33 | // ERROR: Illegal cyclic constraint: α0 <: (α0 -> α2)
34 |
35 |
--------------------------------------------------------------------------------
/project/build.properties:
--------------------------------------------------------------------------------
1 | sbt.version=1.6.2
2 |
--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
1 | addSbtPlugin("org.wartremover" % "sbt-wartremover" % "2.4.16")
2 | addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.10.0")
3 | addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "1.2.0")
4 |
--------------------------------------------------------------------------------
/shared/src/main/scala/simplesub/Parser.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import scala.util.chaining._
4 | import fastparse._, fastparse.ScalaWhitespace._
5 |
/** Parser for the term language, written with fastparse.
  *
  * Whitespace between tokens is handled by the `ScalaWhitespace` import;
  * the `~~` and `repX` combinators are used where no whitespace may be skipped.
  */
@SuppressWarnings(Array("org.wartremover.warts.All"))
object Parser {

  /** Reserved words, which `ident` refuses to parse as identifiers. */
  val keywords = Set("let", "rec", "in", "fun", "if", "then", "else", "true", "false")

  /** Parses keyword `s`, requiring that it is not immediately followed by an identifier character. */
  def kw[__ : P](s: String) = s ~~ !(letter | digit | "_" | "'")

  def letter[__ : P]     = P( lowercase | uppercase )
  def lowercase[__ : P]  = P( CharIn("a-z") )
  def uppercase[__ : P]  = P( CharIn("A-Z") )
  def digit[__ : P]      = P( CharIn("0-9") )

  /** Parses a non-empty run of decimal digits as an `Int`.
    *
    * A digit run that does not fit in an `Int` makes the parser fail (yielding a regular
    * parse error) rather than throwing a `NumberFormatException` out of the parser.
    */
  def number[__ : P]: P[Int] =
    P( CharIn("0-9").repX(1).!.filter(_.toIntOption.isDefined).map(_.toInt) )

  /** Parses an identifier: a letter or underscore followed by letters, digits, `_` or `'`,
    * excluding the reserved `keywords`.
    */
  def ident[__ : P]: P[String] =
    P( (letter | "_") ~~ (letter | digit | "_" | "'").repX ).!.filter(!keywords(_))

  /** Parses a full term: a `let`, a `fun`, an `if`, or a sequence of applications. */
  def term[__ : P]: P[Term] = P( let | fun | ite | apps )
  def const[__ : P]: P[Term] = number.map(Lit)
  /** Parses a variable reference; `true` and `false` are represented as variables too. */
  def variable[__ : P]: P[Term] = (ident | "true".! | "false".!).map(Var)
  def parens[__ : P]: P[Term] = P( "(" ~/ term ~ ")" )
  def subtermNoSel[__ : P]: P[Term] = P( parens | record | const | variable )
  /** Parses an atomic term followed by any number of `.field` selections (left-nested). */
  def subterm[__ : P]: P[Term] = P( subtermNoSel ~ ("." ~/ ident).rep ).map {
    case (st, sels) => sels.foldLeft(st)(Sel) }
  /** Parses a record literal `{ f = t; ... }`, rejecting records with duplicate field names. */
  def record[__ : P]: P[Term] = P( "{" ~/ (ident ~ "=" ~ term).rep(sep = ";") ~ "}" )
    .filter(xs => xs.map(_._1).toSet.size === xs.size).map(_.toList pipe Rcd)
  def fun[__ : P]: P[Term] = P( kw("fun") ~/ ident ~ "->" ~ term ).map(Lam.tupled)
  def let[__ : P]: P[Term] =
    P( kw("let") ~/ kw("rec").!.?.map(_.isDefined) ~ ident ~ "=" ~ term ~ kw("in") ~ term )
    .map(Let.tupled)
  /** Parses `if c then t else e` into nested applications of the `if` variable. */
  def ite[__ : P]: P[Term] = P( kw("if") ~/ term ~ kw("then") ~ term ~ kw("else") ~ term ).map(ite =>
    App(App(App(Var("if"), ite._1), ite._2), ite._3))
  /** Parses one or more juxtaposed subterms as left-associative applications. */
  def apps[__ : P]: P[Term] = P( subterm.rep(1).map(_.reduce(App)) )

  /** Parses a single term that must span the entire input. */
  def expr[__ : P]: P[Term] = P( term ~ End )

  /** Parses a top-level definition `let [rec] name = term` (without a trailing `in`). */
  def toplvl[__ : P]: P[(Boolean, String, Term)] =
    P( kw("let") ~/ kw("rec").!.?.map(_.isDefined) ~ ident ~ "=" ~ term )
  /** Parses a whole program: a sequence of top-level definitions spanning the entire input. */
  def pgrm[__ : P]: P[Pgrm] = P( ("" ~ toplvl).rep.map(_.toList) ~ End ).map(Pgrm)

}
44 |
--------------------------------------------------------------------------------
/shared/src/main/scala/simplesub/Typer.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import scala.collection.mutable
4 | import scala.collection.mutable.{Map => MutMap, Set => MutSet}
5 | import scala.collection.immutable.{SortedSet, SortedMap}
6 | import scala.util.chaining._
7 | import scala.annotation.tailrec
8 |
9 | final case class TypeError(val msg: String) extends Exception(msg)
10 |
11 | /** A class encapsulating type inference state.
12 | * It uses its own internal representation of types and type variables, using mutable data structures.
13 | * Inferred SimpleType values are then turned into immutable simplesub.Type values.
14 | */
15 | class Typer(protected val dbg: Boolean) extends TyperDebugging {
16 |
17 | type Ctx = Map[String, TypeScheme]
18 |
19 | val BoolType: Primitive = Primitive("bool")
20 | val IntType: Primitive = Primitive("int")
21 |
22 | val builtins: Ctx = Map(
23 | "true" -> BoolType,
24 | "false" -> BoolType,
25 | "not" -> Function(BoolType, BoolType),
26 | "succ" -> Function(IntType, IntType),
27 | "add" -> Function(IntType, Function(IntType, IntType)),
28 | "if" -> {
29 | val v = freshVar
30 | PolymorphicType(Function(BoolType, Function(v, Function(v, v))))
31 | }
32 | )
33 | private val builtinsSize = builtins.size
34 |
35 | /** Types the definitions of a program in order; a definition that fails is bound to a fresh variable for the rest. */
36 | def inferTypes(pgrm: Pgrm, ctx: Ctx = builtins): List[Either[TypeError, PolymorphicType]] =
37 |   pgrm.defs match {
38 |     case Nil => Nil
39 |     case (isrec, nme, rhs) :: rest =>
40 |       val outcome =
41 |         try Right(typeLetRhs(isrec, nme, rhs)(ctx)) catch { case e: TypeError => Left(e) }
42 |       outcome :: inferTypes(Pgrm(rest), ctx + (nme -> outcome.getOrElse(freshVar)))
43 |   }
44 |
45 | def inferType(term: Term, ctx: Ctx = builtins): SimpleType = typeTerm(term)(ctx) // types a single term in `ctx` (the builtins by default)
46 |
47 | /** Types the right-hand side of a let binding and wraps the result as a PolymorphicType. */
48 | def typeLetRhs(isrec: Boolean, nme: String, rhs: Term)(implicit ctx: Ctx): PolymorphicType = {
49 |   val bodyType: SimpleType = if (!isrec) typeTerm(rhs)(ctx) else {
50 |     val selfVar = freshVar // stands for `nme` while its own definition is being typed
51 |     val rhsType = typeTerm(rhs)(ctx + (nme -> selfVar))
52 |     constrain(rhsType, selfVar)
53 |     selfVar
54 |   }
55 |   PolymorphicType(bodyType)
56 | }
57 |
58 | /** Infers the type of `term` in context `ctx`, recording constraints as a side effect; every call is logged through `trace`. */
59 | def typeTerm(term: Term)(implicit ctx: Ctx): SimpleType = trace(s"T $term") {
60 | lazy val res = freshVar // result variable; lazy, so it is only created by the cases that use it (App, Sel)
61 | term match {
62 | case Var(name) =>
63 | ctx.getOrElse(name, err("identifier not found: " + name)).instantiate // unknown names raise a TypeError
64 | case Lam(name, body) =>
65 | val param = freshVar
66 | val body_ty = typeTerm(body)(ctx + (name -> param)) // the parameter is bound to a plain variable, so it is not generalised
67 | Function(param, body_ty)
68 | case App(f, a) =>
69 | val f_ty = typeTerm(f)
70 | val a_ty = typeTerm(a)
71 | constrain(f_ty, Function(a_ty, res)) // f must accept a_ty and produce res
72 | res
73 | case Lit(n) =>
74 | IntType
75 | case Sel(obj, name) =>
76 | val obj_ty = typeTerm(obj)
77 | constrain(obj_ty, Record((name, res) :: Nil)) // obj must be a record that has a field `name`, of type res
78 | res
79 | case Rcd(fs) =>
80 | Record(fs.map { case (n, t) => (n, typeTerm(t)) })
81 | case Let(isrec, nme, rhs, bod) =>
82 | if (isrec) if (ctx.sizeCompare(builtinsSize) <= 0) { // recursive lets are only accepted while ctx has no more entries than the builtins
83 | val n_ty = typeLetRhs(isrec, nme, rhs)
84 | typeTerm(bod)(ctx + (nme -> n_ty))
85 | } else err("Unsupported: local recursive let binding")
86 | else typeTerm(App(Lam(nme, bod), rhs)) // a non-recursive let is typed as ((fun nme -> bod) rhs)
87 | }
88 | }(res => ": " + res)
89 |
90 | /** Constrains the types to enforce the subtyping relationship `lhs` <: `rhs`; raises a TypeError when that is impossible. */
91 | def constrain(lhs: SimpleType, rhs: SimpleType): Unit = trace(s"C $lhs <: $rhs") {
92 |   // The very same object on both sides needs no work. (A plain `if` avoids a non-local `return` out of the by-name thunk.)
93 |   if (lhs isnt rhs) (lhs, rhs) match {
94 |     case (Bot, _) | (_, Top) => ()
95 |     case (Function(l0, r0), Function(l1, r1)) =>
96 |       constrain(l1, l0) // parameters are compared in the opposite direction
97 |       constrain(r0, r1)
98 |     case (Record(fs0), Record(fs1)) =>
99 |       fs1.foreach { case (n1, t1) => // every field required on the right must exist on the left
100 |         fs0.find(_._1 === n1).fold(
101 |           err(s"missing field: $n1 in ${lhs.show}")
102 |         ) { case (n0, t0) => constrain(t0, t1) }
103 |       }
104 |     case (lhs: Variable, rhs: Variable) =>
105 |       lhs.unifyWith(rhs)
106 |     case (lhs: Variable, rhs: ConcreteType) =>
107 |       lhs.newUpperBound(rhs)
108 |     case (lhs: ConcreteType, rhs: Variable) =>
109 |       rhs.newLowerBound(lhs)
110 |     case _ => err(s"cannot constrain ${lhs.show} <: ${rhs.show}")
111 |   }
112 | }()
113 |
114 | def err(msg: String): Nothing = throw TypeError(msg) // never returns: always throws a TypeError carrying `msg`
115 |
116 | private var freshCount = 0 // next uid; read and incremented by the constructor of Variable
117 | def freshVar: Variable = new Variable(Bot, Top) // a new variable with lower bound Bot and upper bound Top
118 |
119 | def freshenType(ty: SimpleType): SimpleType = {
120 |   val renamed = MutMap.empty[Variable, Variable] // old variable -> its copy, so that sharing is preserved
121 |   def copyConcrete(c: ConcreteType): ConcreteType = c match {
122 |     case Primitive(_) | Top | Bot => c
123 |     case Function(arg, res) => Function(copy(arg), copy(res))
124 |     case Record(fields) => Record(fields.map(field => field._1 -> copy(field._2)))
125 |   }
126 |   def copy(t: SimpleType): SimpleType = t match {
127 |     case c: ConcreteType => copyConcrete(c)
128 |     case old: Variable =>
129 |       renamed.getOrElse(old, {
130 |         // The bounds are copied first (lower, then upper);
131 |         // the copy is only registered in `renamed` once it has been built.
132 |         val lo = copyConcrete(old.lowerBound)
133 |         val hi = copyConcrete(old.upperBound)
134 |         val fresh = new Variable(lo, hi)
135 |         renamed += old -> fresh
136 |         fresh
137 |       })
138 |   }
139 |   copy(ty)
140 | }
141 |
142 | def glbConcrete(lhs: ConcreteType, rhs: ConcreteType): ConcreteType = (lhs, rhs) match { // greatest lower bound of two concrete types
143 | case (Top, _) => rhs
144 | case (_, Top) => lhs
145 | case (Bot, _) | (_, Bot) => Bot
146 | case (Function(l0, r0), Function(l1, r1)) => Function(lub(l0, l1), glb(r0, r1)) // lub on the parameter side, glb on the result side
147 | case (Record(fs0), Record(fs1)) => Record(mergeMap(fs0, fs1)(glb(_, _)).toList) // fields of both records; a field present twice is combined with glb; order is that of the merged Map
148 | case (Primitive(n0), Primitive(n1)) if n0 === n1 => Primitive(n0)
149 | case _ => Bot // any other combination (including different primitives) gives Bot
150 | }
151 | def lubConcrete(lhs: ConcreteType, rhs: ConcreteType): ConcreteType = (lhs, rhs) match { // least upper bound of two concrete types
152 |   case (Bot, other) => other
153 |   case (other, Bot) => other
154 |   case (Top, _) | (_, Top) => Top
155 |   case (Function(a0, b0), Function(a1, b1)) => Function(glb(a0, a1), lub(b0, b1))
156 |   case (Record(left), Record(right)) =>
157 |     val rightByName = right.toMap
158 |     // Only the fields present on both sides survive, in the order of the left record.
159 |     Record(left.flatMap { case (nme, lt) =>
160 |       rightByName.get(nme).map(rt => nme -> lub(lt, rt)).toList })
161 |   case (Primitive(a), Primitive(b)) if a === b => Primitive(a)
162 |   case _ => Top
163 | }
164 | def glb(lhs: SimpleType, rhs: SimpleType): SimpleType = (lhs, rhs) match { // glb on SimpleTypes; the Variable cases mutate the variables involved
165 | case (c0: ConcreteType, c1: ConcreteType) => glbConcrete(c0, c1)
166 | case (v0: Variable, v1: Variable) => v0.unifyWith(v1); v1 // the two variables are unified and the second one is returned
167 | case (c0: ConcreteType, v1: Variable) => v1.newUpperBound(c0); v1 // c0 is added as an upper bound of the variable
168 | case (v0: Variable, c1: ConcreteType) => v0.newUpperBound(c1); v0
169 | }
170 | def lub(lhs: SimpleType, rhs: SimpleType): SimpleType = (lhs, rhs) match { // lub on SimpleTypes; the Variable cases mutate the variables involved
171 | case (c0: ConcreteType, c1: ConcreteType) => lubConcrete(c0, c1)
172 | case (v0: Variable, v1: Variable) => v0.unifyWith(v1); v1 // the two variables are unified and the second one is returned
173 | case (c0: ConcreteType, v1: Variable) => v1.newLowerBound(c0); v1 // c0 is added as a lower bound of the variable
174 | case (v0: Variable, c1: ConcreteType) => v0.newLowerBound(c1); v0
175 | }
176 |
177 |
178 | // The data types used for type inference:
179 |
180 | /** A type that potentially contains universally quantified type variables,
181 | * and which can be instantiated into a SimpleType. */
182 | sealed abstract class TypeScheme {
183 | def instantiate: SimpleType
184 | }
185 | /** A type with universally quantified type variables:
186 | * each instantiation returns a copy of `body` produced by `freshenType`. */
187 | case class PolymorphicType(body: SimpleType) extends TypeScheme {
188 | def instantiate = freshenType(body)
189 | }
190 | /** A type without universally quantified type variables. */
191 | sealed abstract class SimpleType extends TypeScheme with SimpleTypeImpl {
192 | def instantiate = this // instantiating a SimpleType returns it unchanged
193 | }
194 |
195 | sealed abstract class ConcreteType extends SimpleType // every SimpleType that is not a Variable
196 |
197 | case object Top extends ConcreteType // upper bound given to fresh variables
198 | case object Bot extends ConcreteType // lower bound given to fresh variables
199 |
200 | case class Function(lhs: SimpleType, rhs: SimpleType) extends ConcreteType { // lhs: parameter type, rhs: result type
201 | override def toString = s"($lhs -> $rhs)"
202 | }
203 | case class Record(fields: List[(String, SimpleType)]) extends ConcreteType { // (field name, field type) pairs
204 | override def toString = s"{${fields.map(f => s"${f._1}: ${f._2}").mkString(", ")}}"
205 | }
206 | case class Primitive(name: String) extends ConcreteType { // a named base type such as "int" or "bool"
207 | override def toString = name
208 | }
209 |
210 | /** A type variable with mutable bounds; unified variables forward to a shared representative. */
211 | final class Variable(
212 | private var _lowerBound: ConcreteType,
213 | private var _upperBound: ConcreteType,
214 | ) extends SimpleType {
215 | private val _uid: Int = { val n = freshCount; freshCount += 1; n } // takes the next value of the Typer-wide counter
216 | def uid: Int = representative._uid // the uid of the representative, so it changes when this variable is unified into another
217 | private var _representative: Option[Variable] = None // None while this variable is its own representative
218 | def representative: Variable =
219 | _representative match {
220 | case Some(v) =>
221 | val rep = v.representative
222 | if (rep isnt v) _representative = Some(rep) // path compression: point directly at the final representative
223 | rep
224 | case None => this
225 | }
226 | def lowerBound: ConcreteType = representative._lowerBound
227 | def upperBound: ConcreteType = representative._upperBound
228 | def newUpperBound(ub: ConcreteType): Unit = { // adds `ub` as an upper bound; may raise a TypeError
229 | occursCheck(ub, true)
230 | val rep = representative
231 | rep._upperBound = glbConcrete(rep._upperBound, ub) // the stored upper bound is the glb of all upper bounds seen
232 | constrain(rep._lowerBound, ub) // the current lower bound must be below the new upper bound
233 | }
234 | def newLowerBound(lb: ConcreteType): Unit = { // adds `lb` as a lower bound; may raise a TypeError
235 | occursCheck(lb, false)
236 | val rep = representative
237 | rep._lowerBound = lubConcrete(rep._lowerBound, lb) // the stored lower bound is the lub of all lower bounds seen
238 | constrain(lb, rep._upperBound) // the new lower bound must be below the current upper bound
239 | }
240 | private def occursCheck(ty: ConcreteType, dir: Boolean): Unit = { // `dir` only selects "<:" or ":>" in the error message
241 | if (ty.getVars.contains(representative)) {
242 | val boundsStr = ty.showBounds
243 | err(s"Illegal cyclic constraint: $this ${if (dir) "<:" else ":>"} $ty"
244 | + (if (boundsStr.isEmpty) "" else "\n\t\twhere: " + boundsStr))
245 | }
246 | }
247 | lazy val asTypeVar = new TypeVariable("α", uid) // created once, with the uid current at the time of first access
248 | def unifyWith(that: Variable): Unit = { // merges this variable into `that`'s representative
249 | val rep0 = representative
250 | val rep1 = that.representative
251 | if (rep0 isnt rep1) {
252 | // Note: these occursCheck calls (and the following ones from newLowerBound/newUpperBound) are pretty
253 | // inefficient as they will incur repeated computation of type variables through getVars:
254 | occursCheck(rep1._lowerBound, false)
255 | occursCheck(rep1._upperBound, true)
256 | rep1.newLowerBound(rep0._lowerBound) // rep1 inherits both bounds of rep0 ...
257 | rep1.newUpperBound(rep0._upperBound)
258 | rep0._representative = Some(rep1) // ... and only then becomes rep0's representative
259 | }
260 | }
261 | override def toString: String =
262 | // _representative.fold("α" + uid)(_.toString + "<~" + uid)
263 | "α" + representative.uid
264 | override def hashCode: Int = representative.uid // NOTE(review): the hash changes after unification — confirm no hashed collection holds variables across unifyWith
265 | override def equals(that: Any): Boolean = that match {
266 | case that: Typer#Variable => representative is that.representative // equal when both share the same representative object
267 | case _ => false
268 | }
269 | }
270 |
271 |
272 |
273 |
274 | def simplifyType(st: SimpleType): SimpleType = { // rebuilds `st`, replacing variables by one of their bounds in the cases listed in `transform`
275 |
276 | val pos, neg = mutable.Set.empty[Variable] // variables reached at positive / negative polarity
277 |
278 | def analyze(st: SimpleType, pol: Boolean): Unit = st match { // fills `pos` and `neg`
279 | case Record(fs) => fs.foreach(f => analyze(f._2, pol))
280 | case Function(l, r) => analyze(l, !pol); analyze(r, pol) // polarity flips on the parameter side
281 | case v: Variable =>
282 | (if (pol) pos else neg) += v
283 | analyze(if (pol) v.lowerBound else v.upperBound, pol) // only the bound matching the polarity is explored
284 | case Primitive(_) | Top | Bot => ()
285 | }
286 | analyze(st, true)
287 |
288 | val mapping = mutable.Map.empty[Variable, SimpleType] // cache: a variable is transformed once, at the polarity of its first visit
289 |
290 | def transformConcrete(st: ConcreteType, pol: Boolean): ConcreteType = st match {
291 | case Record(fs) => Record(fs.map(f => f._1 -> transform(f._2, pol)))
292 | case Function(l, r) => Function(transform(l, !pol), transform(r, pol))
293 | case Primitive(_) | Top | Bot => st
294 | }
295 | def transform(st: SimpleType, pol: Boolean): SimpleType = st match {
296 | case v: Variable =>
297 | mapping.getOrElseUpdate(v,
298 | if (v.lowerBound === v.upperBound) transformConcrete(v.lowerBound, pol) // equal bounds: replaced by that bound
299 | else if (pol && !neg(v)) transformConcrete(v.lowerBound, pol) // positive and never seen negatively: replaced by its lower bound
300 | else if (!pol && !pos(v)) transformConcrete(v.upperBound, pol) // negative and never seen positively: replaced by its upper bound
301 | else new Variable(transformConcrete(v.lowerBound, true), transformConcrete(v.upperBound, false)) // otherwise: a new variable with transformed bounds
302 | )
303 | case c: ConcreteType => transformConcrete(c, pol)
304 | }
305 | transform(st, true)
306 |
307 | }
308 |
309 |
310 | type PolarVariable = (Variable, Boolean)
311 |
312 | /** Converts an inferred SimpleType into the immutable Type representation, starting at positive polarity. */
313 | def coalesceType(st: SimpleType): Type = {
314 |   def go(st: SimpleType, polarity: Boolean): Type = st match {
315 |     case tv: Variable =>
316 |       val bound = if (polarity) tv.lowerBound else tv.upperBound
317 |       val mrg = if (polarity) Union else Inter
318 |       // A bound is only omitted when it is the neutral element of the merge: ⊥ for ∨ (positive), ⊤ for ∧ (negative).
319 |       val trivial = if (polarity) bound === Bot else bound === Top
320 |       if (trivial) tv.asTypeVar else mrg(tv.asTypeVar, go(bound, polarity))
321 |     case Function(l, r) => FunctionType(go(l, !polarity), go(r, polarity)) // polarity flips on the parameter side
322 |     case Record(fs) => RecordType(fs.map(nt => nt._1 -> go(nt._2, polarity)))
323 |     case Primitive(n) => PrimitiveType(n)
324 |     case Top => PrimitiveType("⊤")
325 |     case Bot => PrimitiveType("⊥")
326 |   }
327 |   go(st, true)
328 | }
329 |
330 |
331 | }
332 |
--------------------------------------------------------------------------------
/shared/src/main/scala/simplesub/TyperDebugging.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import scala.collection.mutable.{Map => MutMap, Set => MutSet, LinkedHashMap, LinkedHashSet}
4 | import scala.collection.immutable.{SortedMap, SortedSet}
5 | import scala.annotation.tailrec
6 |
7 | /** Inessential methods used to help debugging; output is only produced when the Typer's `dbg` flag is set. */
8 | abstract class TyperDebugging { self: Typer =>
9 |
10 | // Shadow Predef functions with debugging-flag-enabled ones:
11 | def println(msg: => Any): Unit = if (dbg) emitDbg(" " * indent + msg) // `msg` is by-name: it is not evaluated when dbg is off
12 | def assert(assertion: => Boolean): Unit = if (dbg) scala.Predef.assert(assertion)
13 |
14 | private val noPostTrace: Any => String = _ => "" // sentinel default of `trace`, recognised by reference
15 |
16 | protected var indent = 0 // current nesting depth of `trace` calls
17 | def trace[T](pre: String)(thunk: => T)(post: T => String = noPostTrace): T = { // logs `pre`, runs `thunk` one level deeper, then logs `post(result)` if given
18 | println(pre)
19 | indent += 1
20 | val res = try thunk finally indent -= 1 // the depth is restored even when `thunk` throws
21 | if (post isnt noPostTrace) println(post(res))
22 | res
23 | }
24 | def emitDbg(str: String): Unit = scala.Predef.println(str) // where debug lines end up; standard output here
25 |
26 | trait SimpleTypeImpl { self: SimpleType =>
27 |
28 | def children: List[SimpleType] = this match { // direct sub-terms; for a variable, its two bounds
29 | case tv: Variable => tv.lowerBound :: tv.upperBound :: Nil
30 | case Function(l, r) => l :: r :: Nil
31 | case Record(fs) => fs.map(_._2)
32 | case Primitive(_) => Nil
33 | case Top | Bot => Nil
34 | }
35 | def getVars: Set[Variable] = { // all variables reachable from this type, including through bounds
36 | val res = MutSet.empty[Variable]
37 | @tailrec def rec(queue: List[SimpleType]): Unit = queue match {
38 | case (tv: Variable) :: tys =>
39 | if (res(tv)) rec(tys) // already visited
40 | else { res += tv; rec(tv.children ::: tys) }
41 | case ty :: tys => rec(ty.children ::: tys)
42 | case Nil => ()
43 | }
44 | rec(this :: Nil)
45 | SortedSet.from(res)(Ordering.by(_.uid)) // ordered by uid; two variables with the same uid would collapse into one entry
46 | }
47 | def show: String = coalesceType(this).show
48 | def showBounds: String = // lists the reachable variables that have a bound other than Bot/Top
49 | getVars.iterator.filter(tv => tv.upperBound =/= Top || tv.lowerBound =/= Bot).map(tv =>
50 | tv.toString
51 | + (if (tv.lowerBound === Bot) "" else " :> " + tv.lowerBound)
52 | + (if (tv.upperBound === Top) "" else " <: " + tv.upperBound)
53 | ).mkString(", ")
54 |
55 | }
56 |
57 | }
58 |
--------------------------------------------------------------------------------
/shared/src/main/scala/simplesub/helpers.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import scala.util.chaining._
4 | import scala.collection.mutable.{Map => MutMap, SortedMap => SortedMutMap, Set => MutSet}
5 | import scala.collection.immutable.SortedSet
6 |
7 |
8 | // Helper methods for types
9 |
10 | abstract class TypeImpl { self: Type =>
11 |
12 |   lazy val typeVarsList: List[TypeVariable] = this match { // type variables in order of occurrence, duplicates included
13 |     case uv: TypeVariable => uv :: Nil
14 |     case _ => children.flatMap(_.typeVarsList)
15 |   }
16 |
17 |   def show: String = { // renders the type, naming its variables 'a, 'b, ... in order of first occurrence
18 |     val vars = typeVarsList.distinct
19 |     val ctx = vars.zipWithIndex.map {
20 |       case (tv, idx) =>
21 |         // Beyond 26 variables the letters wrap around and a numeric suffix
22 |         // keeps the names distinct: 'a .. 'z, then 'a1 .. 'z1, 'a2 ...
23 |         val letter = ('a' + idx % 26).toChar.toString
24 |         val suffix = if (idx < 26) "" else (idx / 26).toString
25 |         tv -> ("'" + letter + suffix)
26 |     }.toMap
27 |     showIn(ctx, 0)
28 |   }
29 |
30 |   private def parensIf(str: String, cnd: Boolean): String = if (cnd) "(" + str + ")" else str
31 |   def showIn(ctx: Map[TypeVariable, String], outerPrec: Int): String = this match { // parenthesised when `outerPrec` exceeds the construct's own precedence
32 |     case Top => "⊤"
33 |     case Bot => "⊥"
34 |     case PrimitiveType(name) => name
35 |     case uv: TypeVariable => ctx(uv) // every variable of the type must have a name in `ctx`
36 |     case FunctionType(l, r) => parensIf(l.showIn(ctx, 11) + " -> " + r.showIn(ctx, 10), outerPrec > 10) // a function on the left of an arrow gets parentheses
37 |     case RecordType(fs) => fs.map(nt => s"${nt._1}: ${nt._2.showIn(ctx, 0)}").mkString("{", ", ", "}")
38 |     case Union(l, r) => parensIf(l.showIn(ctx, 20) + " ∨ " + r.showIn(ctx, 20), outerPrec > 20)
39 |     case Inter(l, r) => parensIf(l.showIn(ctx, 25) + " ∧ " + r.showIn(ctx, 25), outerPrec > 25)
40 |   }
41 |
42 |   def children: List[Type] = this match { // direct sub-types
43 |     case _: PrimitiveType | _: TypeVariable | Top | Bot => Nil
44 |     case FunctionType(l, r) => l :: r :: Nil
45 |     case RecordType(fs) => fs.map(_._2)
46 |     case Union(l, r) => l :: r :: Nil
47 |     case Inter(l, r) => l :: r :: Nil
48 |   }
49 |
50 | }
51 |
--------------------------------------------------------------------------------
/shared/src/main/scala/simplesub/package.scala:
--------------------------------------------------------------------------------
1 | package object simplesub { // small helpers shared by the whole package
2 |
3 | import scala.collection.mutable
4 | import scala.collection.immutable.SortedMap
5 |
6 | @SuppressWarnings(Array(
7 | "org.wartremover.warts.Equals",
8 | "org.wartremover.warts.AsInstanceOf"))
9 | implicit final class AnyOps[A](self: A) {
10 | def ===(other: A): Boolean = self == other // `==` restricted to an argument of the receiver's type
11 | def =/=(other: A): Boolean = self != other
12 | def is(other: AnyRef): Boolean = self.asInstanceOf[AnyRef] eq other // reference equality
13 | def isnt(other: AnyRef): Boolean = !(self.asInstanceOf[AnyRef] eq other)
14 | /** An alternative to === when in ScalaTest, which shadows our === */
15 | def =:=(other: A): Boolean = self == other
16 | }
17 |
18 | implicit class IterableOps[A](private val self: IterableOnce[A]) extends AnyVal {
19 | def mkStringOr( // like mkString, but returns `els` when there is no element
20 | sep: String = "", start: String = "", end: String = "", els: String = ""
21 | ): String =
22 | if (self.iterator.nonEmpty) self.iterator.mkString(start, sep, end) else els
23 | }
24 |
25 | def mergeOptions[A](lhs: Option[A], rhs: Option[A])(f: (A, A) => A): Option[A] = (lhs, rhs) match { // `f` is only used when both sides are defined
26 | case (Some(l), Some(r)) => Some(f(l, r))
27 | case (lhs @ Some(_), _) => lhs
28 | case (_, rhs @ Some(_)) => rhs
29 | case (None, None) => None
30 | }
31 |
32 | def mergeMap[A, B](lhs: Iterable[(A, B)], rhs: Iterable[(A, B)])(f: (B, B) => B): Map[A,B] = // all pairs of both sides grouped by key; values of a repeated key are reduced with `f`
33 | new mutable.ArrayBuffer(lhs.knownSize + rhs.knownSize max 8) // initial capacity of at least 8 (knownSize is -1 when unknown)
34 | .addAll(lhs).addAll(rhs).groupMapReduce(_._1)(_._2)(f)
35 |
36 | def mergeSortedMap[A: Ordering, B](lhs: Iterable[(A, B)], rhs: Iterable[(A, B)])(f: (B, B) => B): SortedMap[A,B] = // same as mergeMap, returned as a SortedMap
37 | SortedMap.from(mergeMap(lhs, rhs)(f))
38 |
39 | def closeOver[A](xs: Set[A])(f: A => Set[A]): Set[A] = // everything reachable from `xs` by applying `f` repeatedly, `xs` included
40 | closeOverCached(Set.empty, xs)(f)
41 | def closeOverCached[A](done: Set[A], todo: Set[A])(f: A => Set[A]): Set[A] = // `done`: already collected; `todo`: still to expand
42 | if (todo.isEmpty) done else {
43 | val newDone = done ++ todo
44 | closeOverCached(newDone, todo.flatMap(f) -- newDone)(f) // only elements not collected yet are expanded next
45 | }
46 |
47 | }
48 |
--------------------------------------------------------------------------------
/shared/src/main/scala/simplesub/syntax.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 |
4 | // Terms
5 |
6 | final case class Pgrm(defs: List[(Boolean, String, Term)]) // top-level definitions: (is recursive, name, right-hand side)
7 |
8 | sealed abstract class Term
9 | final case class Lit(value: Int) extends Term // integer literal
10 | final case class Var(name: String) extends Term // reference to an identifier
11 | final case class Lam(name: String, rhs: Term) extends Term // function with parameter `name` and body `rhs`
12 | final case class App(lhs: Term, rhs: Term) extends Term // application of `lhs` to `rhs`
13 | final case class Rcd(fields: List[(String, Term)]) extends Term // record: (field name, field value) pairs
14 | final case class Sel(receiver: Term, fieldName: String) extends Term // field selection
15 | final case class Let(isRec: Boolean, name: String, rhs: Term, body: Term) extends Term // local binding of `name` to `rhs` within `body`
16 |
17 |
18 | // Types
19 |
20 | sealed abstract class Type extends TypeImpl // immutable types; printing helpers come from TypeImpl
21 | case object Top extends Type // printed as ⊤
22 | case object Bot extends Type // printed as ⊥
23 | final case class Union(lhs: Type, rhs: Type) extends Type // printed with ∨
24 | final case class Inter(lhs: Type, rhs: Type) extends Type // printed with ∧
25 | final case class FunctionType(lhs: Type, rhs: Type) extends Type // lhs: parameter type, rhs: result type
26 | final case class RecordType(fields: List[(String, Type)]) extends Type // (field name, field type) pairs
27 | final case class PrimitiveType(name: String) extends Type
28 | final class TypeVariable(val nameHint: String, val hash: Int) extends Type { // not a case class: instances compare by reference
29 | override def toString: String = s"$nameHint:$hash"
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/shared/src/test/scala/simplesub/IsolatedTests.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | @SuppressWarnings(Array("org.wartremover.warts.Equals"))
4 | class IsolatedTests extends TypingTestHelpers {
5 |
6 | // This test class is for isolating single tests and running them alone
7 | // with the sbt command `~testOnly simplesub.IsolatedTests`
8 |
9 | test("isolated") { // intentionally empty placeholder
10 |
11 | // put your test here
12 |
13 |
14 | }
15 |
16 | }
17 |
--------------------------------------------------------------------------------
/shared/src/test/scala/simplesub/OtherTests.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import org.scalatest.funsuite.AnyFunSuite
4 |
5 | @SuppressWarnings(Array("org.wartremover.warts.Equals"))
6 | class OtherTests extends AnyFunSuite { // currently defines no test: the only one is commented out below
7 | /*
8 | test("canonicalization produces LCD") {
9 |
10 | val typer = new Typer(false) with TypeSimplifier
11 | import typer.{assert => _, _}
12 | val tv0, tv1, tv3 = freshVar(0)
13 |
14 | // {f: {B: int, f: 'a}} as 'a – cycle length 2
15 | val st0 = Record("f"->Record("f"->tv0::"B"->IntType::Nil)::Nil)
16 | tv0.lowerBounds ::= st0
17 |
18 | // {f: {B: int, f: {A: int, f: 'a}}} as 'a – cycle length 3
19 | val st1 = Record("f"->Record("f"->Record("f"->tv1::"A"->IntType::Nil)::"B"->IntType::Nil)::Nil)
20 | tv1.lowerBounds ::= st1
21 | tv3.lowerBounds = tv0 :: tv1 :: Nil
22 |
23 | // println(tv3.showBounds)
24 |
25 | val ct = canonicalizeType(tv3)
26 | val sct = simplifyType(ct)
27 | val csct = coalesceCompactType(sct).show
28 |
29 | assert(csct == "{f: {B: int, f: {f: {f: {f: {f: 'a}}}}}} as 'a") // cycle length 6
30 |
31 | }
32 | */
33 | }
34 |
--------------------------------------------------------------------------------
/shared/src/test/scala/simplesub/ParserTests.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import org.scalatest._
4 | import fastparse._
5 | import Parser.expr
6 | import fastparse.Parsed.Failure
7 | import fastparse.Parsed.Success
8 | import org.scalatest.funsuite.AnyFunSuite
9 |
10 | class ParserTests extends AnyFunSuite {
11 |
12 |   /** Parses `str` as a complete expression; on failure prints the parser's trace and fails the test. */
13 |   def doTest(str: String): Unit =
14 |     parse(str, expr(_), verboseFailures = true) match {
15 |       case Success(_, _) => () // parsed: nothing to report
16 |       case Failure(_, _, extra) =>
17 |         // Print the trace, then its long aggregate message, before failing.
18 |         println(extra.trace())
19 |         println(extra.trace().longAggregateMsg)
20 |         assert(false)
21 |         ()
22 |     }
23 |
24 |
25 |   test("basics") {
26 |     doTest("1")
27 |     doTest("a")
28 |     doTest("1 2 3")
29 |     doTest("a b c")
30 |     doTest("true")
31 |   }
32 |
33 |   test("let") {
34 |     doTest("let a = b in c")
35 |     doTest("let a = 1 in 1")
36 |     doTest("let a = (1) in 1")
37 |     assert(!parse("let true = 0 in true", expr(_)).isSuccess) // this input is expected to be rejected
38 |   }
39 |
40 |   test("records") {
41 |     doTest("{ a = 1; b = 2 }")
42 |     assert(!parse("{ a = 1; b = 2; a = 3 }", expr(_)).isSuccess) // this input (field `a` given twice) is expected to be rejected
43 |   }
44 |
45 | }
46 |
--------------------------------------------------------------------------------
/shared/src/test/scala/simplesub/ProgramTests.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import org.scalatest._
4 | import fastparse._
5 | import Parser.pgrm
6 | import fastparse.Parsed.Failure
7 | import fastparse.Parsed.Success
8 | import sourcecode.Line
9 | import org.scalatest.funsuite.AnyFunSuite
10 |
11 | @SuppressWarnings(Array("org.wartremover.warts.Equals"))
12 | class ProgramTests extends AnyFunSuite {
13 |
14 | // TODO port these tests
15 | /*
16 | implicit class ExpectedStr(val str: String)(implicit val line: Line)
17 |
18 | def doTest(str: String)(expected: ExpectedStr*): Unit = {
19 | val dbg = expected.exists(_.str.isEmpty)
20 | val Success(p, index) = parse(str, pgrm(_), verboseFailures = true)
21 | val typer = new Typer(dbg) with TypeSimplifier
22 | val tys = typer.inferTypes(p)
23 | var toPrint: List[String] = Nil
24 | (p.defs lazyZip tys lazyZip expected).foreach { (str, pty, exp) =>
25 | if (exp.str.isEmpty) println(s">>> $str")
26 | val ty = pty.fold(err => throw err, _.instantiate)
27 | // val cty = typer.canonicalizeType(ty)
28 | // val sty = typer.simplifyType(cty)
29 | // val res = typer.coalesceCompactType(sty).show
30 | val res = typer.coalesceType(ty).show
31 | if (exp.str.nonEmpty) { assert(res == exp.str, "at line " + exp.line.value); () }
32 | else {
33 | toPrint ::= res
34 | println("inferred: " + ty)
35 | println(" where " + ty.showBounds)
36 | println(res)
37 | println("---")
38 | }
39 | }
40 | if (toPrint.nonEmpty) toPrint.reverseIterator.foreach(s => println("Inferred: " + s))
41 | assert(tys.size == expected.size); ()
42 | }
43 |
44 | test("mlsub") { // from https://www.cl.cam.ac.uk/~sd601/mlsub/
45 | doTest("""
46 | let id = fun x -> x
47 | let twice = fun f -> fun x -> f (f x)
48 | let object1 = { x = 42; y = id }
49 | let object2 = { x = 17; y = false }
50 | let pick_an_object = fun b ->
51 | if b then object1 else object2
52 | let rec recursive_monster = fun x ->
53 | { thing = x;
54 | self = recursive_monster x }
55 | """)(
56 | "'a -> 'a",
57 | "('a ∨ 'b -> 'a) -> 'b -> 'a",
58 | "{x: int, y: 'a -> 'a}",
59 | "{x: int, y: bool}",
60 | "bool -> {x: int, y: bool ∨ ('a -> 'a)}",
61 | "'a -> {self: 'b, thing: 'a} as 'b",
62 | )
63 | }
64 |
65 | test("top-level-polymorphism") {
66 | doTest("""
67 | let id = fun x -> x
68 | let ab = {u = id 0; v = id true}
69 | """)(
70 | "'a -> 'a",
71 | "{u: int, v: bool}",
72 | )
73 | }
74 |
75 | test("rec-producer-consumer") {
76 | doTest("""
77 | let rec produce = fun arg -> { head = arg; tail = produce (succ arg) }
78 | let rec consume = fun strm -> add strm.head (consume strm.tail)
79 |
80 | let codata = produce 42
81 | let res = consume codata
82 |
83 | let rec codata2 = { head = 0; tail = { head = 1; tail = codata2 } }
84 | let res = consume codata2
85 |
86 | let rec produce3 = fun b -> { head = 123; tail = if b then codata else codata2 }
87 | let res = fun x -> consume (produce3 x)
88 |
89 | let consume2 =
90 | let rec go = fun strm -> add strm.head (add strm.tail.head (go strm.tail.tail))
91 | in fun strm -> add strm.head (go strm.tail)
92 | // in go
93 | // let rec consume2 = fun strm -> add strm.head (add strm.tail.head (consume2 strm.tail.tail))
94 | let res = consume2 codata2
95 | """)(
96 | "int -> {head: int, tail: 'a} as 'a",
97 | "{head: int, tail: 'a} as 'a -> int",
98 | "{head: int, tail: 'a} as 'a",
99 | "int",
100 | "{head: int, tail: {head: int, tail: 'a}} as 'a",
101 | "int",
102 | "bool -> {head: int, tail: {head: int, tail: 'a}} as 'a",
103 | // ^ simplifying this would probably require more advanced
104 | // automata-based techniques such as the one proposed by Dolan
105 | "bool -> int",
106 | "{head: int, tail: {head: int, tail: 'a}} as 'a -> int",
107 | "int",
108 | )
109 | }
110 |
111 | test("misc") {
112 | doTest("""
113 | //
114 | // From a comment on the blog post:
115 | //
116 | let rec r = fun a -> r
117 | let join = fun a -> fun b -> if true then a else b
118 | let s = join r r
119 | //
120 | // Inspired by [Pottier 98, chap 13.4]
121 | //
122 | let rec f = fun x -> fun y -> add (f x.tail y) (f x y)
123 | let rec f = fun x -> fun y -> add (f x.tail y) (f y x)
124 | let rec f = fun x -> fun y -> add (f x.tail y) (f x y.tail)
125 | let rec f = fun x -> fun y -> add (f x.tail y.tail) (f x.tail y.tail)
126 | let rec f = fun x -> fun y -> add (f x.tail x.tail) (f y.tail y.tail)
127 | let rec f = fun x -> fun y -> add (f x.tail x) (f y.tail y)
128 | let rec f = fun x -> fun y -> add (f x.tail y) (f y.tail x)
129 | //
130 | let f = fun x -> fun y -> if true then { l = x; r = y } else { l = y; r = x } // 2-crown
131 | //
132 | // Inspired by [Pottier 98, chap 13.5]
133 | //
134 | let rec f = fun x -> fun y -> if true then x else { t = f x.t y.t }
135 | """)(
136 | "(⊤ -> 'a) as 'a",
137 | "'a -> 'a -> 'a",
138 | "(⊤ -> 'a) as 'a",
139 |
140 | "{tail: 'a} as 'a -> ⊤ -> int",
141 | "{tail: 'a} as 'a -> {tail: 'b} as 'b -> int",
142 | "{tail: 'a} as 'a -> {tail: 'b} as 'b -> int",
143 | "{tail: 'a} as 'a -> {tail: 'b} as 'b -> int",
144 | "{tail: {tail: 'a} as 'a} -> {tail: {tail: 'b} as 'b} -> int",
145 | // ^ Could simplify more `{tail: {tail: 'a} as 'a}` to `{tail: 'a} as 'a`
146 | // This would likely require another hash-consing pass.
147 | // Indeed, currently, we coalesce {tail: ‹{tail: ‹α25›}›} and it's hash-consing
148 | // which introduces the 'a to stand for {tail: ‹α25›}
149 | // ^ Note: MLsub says:
150 | // let rec f = fun x -> fun y -> (f x.tail x) + (f y.tail y)
151 | // val f : ({tail : (rec b = {tail : b})} -> ({tail : {tail : (rec a = {tail : a})}} -> int))
152 | "{tail: 'a} as 'a -> {tail: {tail: 'b} as 'b} -> int",
153 | // ^ Note: MLsub says:
154 | // let rec f = fun x -> fun y -> (f x.tail x.tail) + (f y.tail y.tail)
155 | // val f : ({tail : {tail : (rec b = {tail : b})}} -> ({tail : {tail : (rec a = {tail : a})}} -> int))
156 | "{tail: 'a} as 'a -> {tail: {tail: 'b} as 'b} -> int",
157 |
158 | "'a -> 'a -> {l: 'a, r: 'a}",
159 |
160 | "('b ∧ {t: 'a}) as 'a -> {t: 'c} as 'c -> ('b ∨ {t: 'd}) as 'd",
161 | // ^ Note: MLsub says:
162 | // let rec f = fun x -> fun y -> if true then x else { t = f x.t y.t }
163 | // val f : (({t : (rec d = ({t : d} & a))} & a) -> ({t : (rec c = {t : c})} -> ({t : (rec b = ({t : b} | a))} | a)))
164 | // ^ Pottier says a simplified version would essentially be, once translated to MLsub types:
165 | // {t: 'a} as 'a -> 'a -> {t: 'd} as 'd
166 | // but even he does not infer that.
167 | // Notice the loss of connection between the first parameter and the result, in his proposed type,
168 | // which he says is not necessary as it is actually implied.
169 | // He argues that if 'a <: F 'a and F 'b <: 'b then 'a <: 'b, for a type operator F,
170 | // which does indeed seem true (even in MLsub),
171 | // though leveraging such facts for simplification would require much more advanced reasoning.
172 | )
173 | }
174 | */
175 |
176 | }
177 |
--------------------------------------------------------------------------------
/shared/src/test/scala/simplesub/TypingTestHelpers.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import org.scalatest._
4 | import fastparse._
5 | import Parser.expr
6 | import fastparse.Parsed.Failure
7 | import fastparse.Parsed.Success
8 | import sourcecode.Line
9 | import org.scalatest.funsuite.AnyFunSuite
10 | import ammonite.ops._
11 |
12 | @SuppressWarnings(Array("org.wartremover.warts.Equals"))
13 | class TypingTestHelpers extends AnyFunSuite { // base class of the typing tests: results are written to out/<test name>.check
14 |
15 | private var outFile = Option.empty[os.Path] // check file of the test currently running, if any
16 |
17 | override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit pos: org.scalactic.source.Position): Unit = {
18 | super.test(testName, testTags: _*) {
19 | assert(outFile.isEmpty)
20 | val f = pwd/"out"/(testName+".check")
21 | outFile = Some(f)
22 | write.over(f, "") // the check file is emptied at the start of each test
23 | try testFun
24 | finally outFile = None
25 | }
26 | }
27 |
28 | def doTest(str: String, expected: String = "", expectError: Boolean = false)(implicit line: Line): Unit = { // types `str` and appends the outcome to the check file
29 | val dbg = expected.isEmpty // an empty `expected` turns on debug printing; `expected` is not compared (the assertion below is commented out)
30 |
31 | if (dbg) println(s">>> $str")
32 | val Success(term, index) = parse(str, expr(_), verboseFailures = true) // a parse failure surfaces as a MatchError here
33 |
34 | val typer = new Typer(dbg)
35 | val res = try {
36 | val tyv = typer.inferType(term)
37 |
38 | if (dbg) {
39 | println("inferred: " + tyv)
40 | println(" where " + tyv.showBounds)
41 | }
42 |
43 | val res0 = typer.simplifyType(tyv)
44 | if (dbg) {
45 | println("simplified: " + res0)
46 | println(" where " + res0.showBounds)
47 | }
48 |
49 | val res = typer.coalesceType(res0).show
50 |
51 | if (dbg) {
52 | println("typed: " + res)
53 | println("---")
54 | } else {
55 | // assert(res == expected, "at line " + line.value); ()
56 | }
57 |
58 | res
59 | } catch {
60 | case e: TypeError => // a typing error is recorded in the check file rather than failing the test
61 | if (dbg) {
62 | println("ERROR: " + e.msg)
63 | println("---")
64 | }
65 | "// ERROR: " + e.msg
66 | }
67 | write.append(outFile.getOrElse(fail()), // fails when called outside of a `test` body
68 | "// " + (if (expectError) "[wrong:] " else "") + str + "\n" + res + "\n\n", createFolders = true)
69 | }
70 | def error(str: String, msg: String): Unit = { // `msg` is currently unused: the assertion on the message is commented out
71 | // assert(intercept[TypeError](doTest(str, "")).msg == msg); ()
72 | doTest(str, "", true)
73 | }
74 |
75 | }
76 |
--------------------------------------------------------------------------------
/shared/src/test/scala/simplesub/TypingTests.scala:
--------------------------------------------------------------------------------
1 | package simplesub
2 |
3 | import fastparse._
4 | import Parser.expr
5 | import fastparse.Parsed.Failure
6 | import fastparse.Parsed.Success
7 |
/** Type-inference tests: each `test` block feeds a series of programs to
  * `doTest` / `error`, in a fixed order.
  */
@SuppressWarnings(Array("org.wartremover.warts.Equals"))
class TypingTests extends TypingTestHelpers {

  // TODO remove the expected strings; the check files in the `out/` folder are used instead

  // Leaving an expected string empty makes the inferred type get printed to the console,
  // from where it can be copied once it has been verified to be correct.

  test("basic") {
    doTest("42", expected = "int")
    doTest("fun x -> 42", expected = "⊤ -> int")
    doTest("fun x -> x", expected = "'a -> 'a")
    doTest("fun x -> x 42", expected = "(int -> 'a) -> 'a")
    doTest("(fun x -> x) 42", expected = "int")
    doTest("fun f -> fun x -> f (f x) // twice", expected = "('a ∨ 'b -> 'a) -> 'b -> 'a")
    doTest("let twice = fun f -> fun x -> f (f x) in twice", expected = "('a ∨ 'b -> 'a) -> 'b -> 'a")
  }

  test("booleans") {
    doTest("true", expected = "bool")
    doTest("not true", expected = "bool")
    doTest("fun x -> not x", expected = "bool -> bool")
    doTest("(fun x -> not x) true", expected = "bool")
    doTest(
      "fun x -> fun y -> fun z -> if x then y else z",
      expected = "bool -> 'a -> 'a -> 'a")
    doTest(
      "fun x -> fun y -> if x then y else x",
      expected = "'a ∧ bool -> 'a -> 'a")
    doTest("fun x -> { u = not x; v = x }", expected = " ")

    error(
      "succ true",
      msg = "cannot constrain bool <: int")
    error(
      "fun x -> succ (not x)",
      msg = "cannot constrain bool <: int")
    error(
      "(fun x -> not x.f) { f = 123 }",
      msg = "cannot constrain int <: bool")
    error(
      "(fun f -> fun x -> not (f x.u)) false",
      msg = "cannot constrain bool <: 'a -> 'b")
  }

  test("records") {
    doTest("fun x -> x.f", expected = "{f: 'a} -> 'a")
    doTest("{}", expected = "{}") // note: MLsub returns "⊤" (equivalent)
    doTest("{ f = 42 }", expected = "{f: int}")
    doTest("{ f = 42 }.f", expected = "int")
    doTest("(fun x -> x.f) { f = 42 }", expected = "int")
    doTest("fun f -> { x = f 42 }.x", expected = "(int -> 'a) -> 'a")
    doTest("fun f -> { x = f 42; y = 123 }.y", expected = "(int -> ⊤) -> int")
    doTest("if true then { a = 1; b = true } else { b = false; c = 42 }", expected = "{b: bool}")

    doTest("if true then { u = 1; v = 2; w = 3 } else { u = true; v = 4; x = 5 }", expected = " ")
    doTest("if true then fun x -> { u = 1; v = x } else fun y -> { u = y; v = y }", expected = " ")

    error(
      "{ a = 123; b = true }.c",
      msg = "missing field: c in {a: int, b: bool}")
    error(
      "fun x -> { a = x }.b",
      msg = "missing field: b in {a: 'a}")
  }

  test("self-app") {
    doTest("fun x -> x x", expected = "'a ∧ ('a -> 'b) -> 'b")

    doTest("fun x -> x x x", expected = "'a ∧ ('a -> 'a -> 'b) -> 'b")
    doTest("fun x -> fun y -> x y x", expected = "'a ∧ ('b -> 'a -> 'c) -> 'b -> 'c")
    doTest("fun x -> fun y -> x x y", expected = "'a ∧ ('a -> 'b -> 'c) -> 'b -> 'c")
    doTest("(fun x -> x x) (fun x -> x x)", expected = "⊥")

    doTest(
      "fun x -> {l = x x; r = x }",
      expected = "'a ∧ ('a -> 'b) -> {l: 'b, r: 'a}")

    // Programs taken from https://github.com/stedolan/mlsub
    // The Y combinator:
    doTest(
      "(fun f -> (fun x -> f (x x)) (fun x -> f (x x)))",
      expected = "('a -> 'a) -> 'a")
    // The Z combinator:
    doTest(
      "(fun f -> (fun x -> f (fun v -> (x x) v)) (fun x -> f (fun v -> (x x) v)))",
      expected = "(('a -> 'b) -> 'c ∧ ('a -> 'b)) -> 'c")
    // A function accepting arbitrarily many arguments:
    doTest(
      "(fun f -> (fun x -> f (fun v -> (x x) v)) (fun x -> f (fun v -> (x x) v))) (fun f -> fun x -> f)",
      expected = "⊤ -> (⊤ -> 'a) as 'a")

    doTest(
      "let rec trutru = fun g -> trutru (g true) in trutru",
      expected = "(bool -> 'a) as 'a -> ⊥")
    doTest(
      "fun i -> if ((i i) true) then true else true",
      expected = "'a ∧ ('a -> bool -> bool) -> bool")
    // ^ for: λi. if ((i i) true) then true else true,
    //   Dolan's thesis says MLsub infers: (α → ((bool → bool) ⊓ α)) → bool
    //   which does seem equivalent, despite being syntactically quite different
  }

  test("let-poly") {
    doTest("let f = fun x -> x in {a = f 0; b = f true}", expected = "{a: int, b: bool}")
    doTest(
      "fun y -> let f = fun x -> x in {a = f y; b = f true}",
      expected = "'a -> {a: 'a, b: bool}")
    doTest(
      "fun y -> let f = fun x -> y x in {a = f 0; b = f true}",
      expected = "(bool ∨ int -> 'a) -> {a: 'a, b: 'a}")
    doTest(
      "fun y -> let f = fun x -> x y in {a = f (fun z -> z); b = f (fun z -> true)}",
      expected = "'a -> {a: 'a, b: bool}")
    doTest(
      "fun y -> let f = fun x -> x y in {a = f (fun z -> z); b = f (fun z -> succ z)}",
      expected = "'a ∧ int -> {a: 'a, b: int}")

    error(
      "(fun k -> k (fun x -> let tmp = add x 1 in x)) (fun f -> f true)",
      msg = "cannot constrain bool <: int")
    // Same program as above, with one part let-bound:
    error(
      "(fun k -> let test = k (fun x -> let tmp = add x 1 in x) in test) (fun f -> f true)",
      msg = "cannot constrain bool <: int")
  }

  test("recursion") {
    doTest(
      "let rec f = fun x -> f x.u in f",
      expected = "{u: 'a} as 'a -> ⊥")

    doTest(
      "let rec consume = fun strm -> add strm.head (consume strm.tail) in consume",
      expected = " ")

    // [test:T2]:
    doTest(
      "let rec r = fun a -> r in if true then r else r",
      expected = "(⊤ -> 'a) as 'a")
    // ^ without canonicalization, we get the type:
    //   ⊤ -> (⊤ -> 'a) as 'a ∨ (⊤ -> 'b) as 'b
    doTest(
      "let rec l = fun a -> l in let rec r = fun a -> fun a -> r in if true then l else r",
      expected = "(⊤ -> ⊤ -> 'a) as 'a")
    // ^ without canonicalization, we get the type:
    //   ⊤ -> (⊤ -> 'a) as 'a ∨ (⊤ -> (⊤ -> ⊤ -> 'b) as 'b)
    doTest(
      "let rec l = fun a -> fun a -> fun a -> l in let rec r = fun a -> fun a -> r in if true then l else r",
      expected = "(⊤ -> ⊤ -> ⊤ -> ⊤ -> ⊤ -> ⊤ -> 'a) as 'a") // 6 is the least common multiple of 3 and 2
    // ^ without canonicalization, we get the type:
    //   ⊤ -> ⊤ -> (⊤ -> ⊤ -> 'a) as 'a ∨ (⊤ -> (⊤ -> ⊤ -> ⊤ -> 'b) as 'b)

    // Program taken from https://www.cl.cam.ac.uk/~sd601/mlsub/
    doTest(
      "let rec recursive_monster = fun x -> { thing = x; self = recursive_monster x } in recursive_monster",
      expected = "'a -> {self: 'b, thing: 'a} as 'b")
  }

  test("random") {
    doTest("(let rec x = {a = x; b = x} in x)", expected = "{a: 'a, b: 'a} as 'a")
    doTest("(let rec x = fun v -> {a = x v; b = x v} in x)", expected = "⊤ -> {a: 'a, b: 'a} as 'a")
    error("let rec x = (let rec y = {u = y; v = (x y)} in 0) in 0", msg = "cannot constrain int <: 'a -> 'b")
    doTest("(fun x -> (let y = (x x) in 0))", expected = "'a ∧ ('a -> ⊤) -> int")
    doTest("(let rec x = (fun y -> (y (x x))) in x)", expected = "('a -> ('a ∧ ('a -> 'b)) as 'b) -> 'a")
    // ^ Note: without canonicalization, we get the simpler: ('b -> 'b ∧ 'a) as 'a -> 'b
    doTest("fun next -> 0", expected = "⊤ -> int")
    doTest("((fun x -> (x x)) (fun x -> x))", expected = "('b ∨ ('b -> 'a)) as 'a")
    doTest("(let rec x = (fun y -> (x (y y))) in x)", expected = "('b ∧ ('b -> 'a)) as 'a -> ⊥")
    doTest("fun x -> (fun y -> (x (y y)))", expected = "('a -> 'b) -> 'c ∧ ('c -> 'a) -> 'b")
    doTest("(let rec x = (let y = (x x) in (fun z -> z)) in x)", expected = "'a -> ('a ∨ ('a -> 'b)) as 'b")
    doTest("(let rec x = (fun y -> (let z = (x x) in y)) in x)", expected = "'a -> ('a ∨ ('a -> 'b)) as 'b")
    doTest(
      "(let rec x = (fun y -> {u = y; v = (x x)}) in x)",
      expected = "'a -> {u: 'a ∨ ('a -> 'b), v: 'c} as 'c as 'b")
    doTest(
      "(let rec x = (fun y -> {u = (x x); v = y}) in x)",
      expected = "'a -> {u: 'c, v: 'a ∨ ('a -> 'b)} as 'c as 'b")
    doTest("(let rec x = (fun y -> (let z = (y x) in y)) in x)", expected = "('b ∧ ('a -> ⊤) -> 'b) as 'a")
    doTest("(fun x -> (let y = (x x.v) in 0))", expected = "{v: 'a} ∧ ('a -> ⊤) -> int")
    doTest(
      "let rec x = (let y = (x x) in (fun z -> z)) in (x (fun y -> y.u))", // [test:T1]
      expected = "'a ∨ ('a ∧ {u: 'b} -> ('a ∨ 'b ∨ ('a ∧ {u: 'b} -> 'c)) as 'c)")
    // ^ Note: without canonicalization, we get the simpler:
    //   ('b ∨ ('b ∧ {u: 'c} -> 'a ∨ 'c)) as 'a
  }

  test("occurs-check") {

    error("fun x -> x.u x", msg = "")
    error("fun x -> x.u {v=x}", msg = "")
    doTest("fun x -> x.u x.v", expected = " ")
    error("fun x -> x.u.v x", msg = "")

  }


}
177 |
--------------------------------------------------------------------------------