├── project
│   ├── build.properties
│   └── plugins.sbt
├── .gitignore
├── test-src
│   └── abt
│       ├── examples
│       │   └── BidirTests.scala
│       └── AbtTests.scala
├── .travis.yml
├── src
│   └── abt
│       ├── examples
│       │   ├── BidirTypechecking.scala
│       │   ├── LambdaCalc.scala
│       │   └── MiniMLTermSyn.scala
│       └── abt.scala
├── README.md
└── abt.ml

/project/build.properties:
--------------------------------------------------------------------------------
sbt.version=0.13.9

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
/bin/
.cache-main
.classpath
.project

--------------------------------------------------------------------------------
/project/plugins.sbt:
--------------------------------------------------------------------------------
addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "3.0.0")

--------------------------------------------------------------------------------
/test-src/abt/examples/BidirTests.scala:
--------------------------------------------------------------------------------
package abt
package examples

import LambdaCalc._
import BidirTypechecking._
import org.scalatest.FunSuite

/**
 * Tests for the bidirectional type checker in BidirTypechecking: well-typed
 * terms must check silently, ill-typed ones must raise
 * IllegalArgumentException.
 *
 * @author pgiarrusso
 */
class BidirTests extends FunSuite {
  test("Simple tests for bidirectional type checking") {
    // A variable checks against the type the context assigns to it.
    check(Map(Name("x") -> Base), Var("x"), Base)

    val baseToBase = Arrow(Base, Base)
    val idFun = Lam("x", Var("x"))

    // The identity function checks at Base -> Base and at
    // (Base -> Base) -> (Base -> Base).
    check(Map.empty, idFun, baseToBase)
    check(Map.empty, idFun, Arrow(baseToBase, baseToBase))

    // \x. \y. x y : (Base -> Base) -> (Base -> Base)
    val applyFun = Lam("x", Lam("y", App(Var("x"), Var("y"))))
    check(Map.empty, applyFun, Arrow(baseToBase, baseToBase))

    // The identity function does not have type (Base -> Base) -> Base.
    intercept[IllegalArgumentException] {
      check(Map.empty, idFun, Arrow(baseToBase, Base))
    }
  }
}

--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: scala
# Keep in sync by hand with build.sbt.
scala:
  - 2.11.7
jdk:
  - openjdk7
sudo: false
# Caching setup from:
# http://www.scala-sbt.org/0.13/docs/Travis-CI-with-sbt.html
cache:
  directories:
    - $HOME/.sbt/boot
    - $HOME/.sbt/launchers
    - $HOME/.ivy2/cache
install:
  - travis_retry sbt ++$TRAVIS_SCALA_VERSION update
script:
  # Default Travis script.
  - sbt ++$TRAVIS_SCALA_VERSION test

# Tricks to avoid unnecessary cache updates, also from
# http://www.scala-sbt.org/0.13/docs/Travis-CI-with-sbt.html
# They run in before_cache, so that a cleanup problem cannot fail the tests.
# `-print -delete` is used instead of `| xargs rm`, because `xargs rm` runs
# `rm` with no operand (and fails) when nothing matches.
before_cache:
  - find $HOME/.sbt -name "*.lock" -print -delete
  - find $HOME/.ivy2 -name "ivydata-*.properties" -print -delete

--------------------------------------------------------------------------------
/test-src/abt/AbtTests.scala:
--------------------------------------------------------------------------------
package abt

import examples.LambdaCalc._

import org.scalatest.FunSuite

/**
 * Tests for alpha-equivalence and substitution on ABTs, using the lambda
 * calculus from LambdaCalc as the object language.
 *
 * @author pgiarrusso
 */
class AbtTests extends FunSuite {
  test("alpha-equiv distinguishes different terms") {
    assert(
      !(Lam("x", App(Var("x"), Var("x"))) alphaEquiv
        Lam("y", Var("y"))))
  }

  test("alpha-equiv works on closed terms") {
    assert(
      Lam("x", Var("x")) alphaEquiv Lam("y", Var("y")))
    assert(
      Lam("x", App(Var("x"), Var("x"))) alphaEquiv
        Lam("y", App(Var("y"), Var("y"))))
  }

  test("alpha-equiv works on open terms") {
    assert(!(Lam("x", Var("x")) alphaEquiv Lam("y", Var("x"))))
    assert(!(Lam("x", Var("x")) alphaEquiv Lam("x", Var("y"))))
    assert(!(Lam("y", Var("x")) alphaEquiv Lam("x", Var("y"))))
  }

  // These two tests exercise `subst`, not alpha-equivalence; the names say so.
  test("substitution is a no-op when substituting unused variables") {
    val term = Lam("x", App(Var("x"), Var("y")))
    assert((term subst ("z", Var("w"))) == term)
  }

  test("substitution should not rename variables needlessly") {
    assert((Lam("x", App(Var("x"), Var("y"))) subst ("y", Var("z"))) ==
      Lam("x", App(Var("x"), Var("z"))))
  }

  test("substitution testcases") {
    assert(Lam("x", Var("y")) subst ("y", Var("x")) alphaEquiv Lam("y", Var("x")))
    assert(Lam("x", App(Var("y"), Lam("y", App(Var("y"), App(Var("x"), Var("z")))))) subst ("y", Var("x")) alphaEquiv
      Lam("w", App(Var("x"), Lam("y", App(Var("y"), App(Var("w"), Var("z")))))))
  }
}

--------------------------------------------------------------------------------
/src/abt/examples/BidirTypechecking.scala:
--------------------------------------------------------------------------------
package abt
package examples

import ABTs._

/**
 * Bidirectional type checking for the lambda calculus of LambdaCalc.
 *
 * Lambdas and lets are *checking* terms: they can only be checked against a
 * given type. Every other term *synthesizes* its type. Type errors are
 * reported by throwing IllegalArgumentException (see `fail`).
 */
object BidirTypechecking {
  import LambdaCalc._

  /** Typing context: the type assigned to each variable in scope. */
  type Ctx = Map[Name, SimpleType]

  /** Whether a term synthesizes its type, i.e. is neither a lambda nor a let. */
  def isSynth: Term => Boolean = {
    case Lam(_, _) => false
    case Let(_, _, _) => false
    case _ => true
  }

  /** Whether a term can only be checked against a given type. */
  def isCheck(bt: Term): Boolean = !isSynth(bt)

  /** Reports a type error by throwing IllegalArgumentException; never returns. */
  def fail(msg: String): Nothing =
    throw new IllegalArgumentException(msg)

  /**
   * Synthesizes the type of `e` in context `ctx`.
   *
   * @throws IllegalArgumentException if `e` is ill-typed or is a checking term.
   */
  def synth(ctx: Ctx, e: Term): SimpleType = {
    e match {
      case Var(x) =>
        ctx get x getOrElse fail("unbound variable")
      case Annot(e, tp) =>
        check(ctx, e, tp)
        tp
      case App(f, e) =>
        synth(ctx, f) match {
          case Arrow(s, t) =>
            check(ctx, e, s)
            t
          case _ => fail("Applying a non-function!")
        }
      case _ if isCheck(e) =>
        fail("Cannot synthesize type for checking term")
      case _ =>
        fail("Unexpected term")
    }
  }

  /**
   * Checks that `e` has type `tp` in context `ctx`.
   *
   * @throws IllegalArgumentException if the check fails.
   */
  def check(ctx: Ctx, e: Term, tp: SimpleType): Unit = {
    (e, tp) match {
      //Lambda
      case (Lam(x, e1), Arrow(tp1, tp2)) =>
        check(ctx updated (x, tp1), e1, tp2)
      case (Lam(_, _), _) =>
        fail("Expected arrow type")
      //Let: the bound term must synthesize; the body is checked against tp.
      case (Let(x, e1, e2), _) =>
        val tp1 = synth(ctx, e1)
        check(ctx updated (x, tp1), e2, tp)
      //Mode switch: synthesize, then compare with the expected type.
      case _ if isSynth(e) =>
        if (tp == synth(ctx, e))
          ()
        else fail("Type mismatch")
      case _ =>
        fail("Unexpected term")
    }
  }
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Prototyping Abstract Binding Trees

While this might grow into more (say, a usable library), currently this is a
playground for me to learn about Abstract Binding Trees.

I started off from the code in
http://semantic-domain.blogspot.de/2015/03/abstract-binding-trees.html.

Scala-specific goodies:
- simplify manipulating ABTs through extractors, so you can write `Lam("x", Var("x"))` instead of the underlying
  `TermInt(Set(),_Tm(_Lam(TermInt(Set(),__Abs(x,TermInt(Set(x),_Var(x)))))))`.
  Unfortunately, the latter is still visible through `.toString`, and calls for an
  implementation of pretty-printing.

There are many possible TODOs:

1. Use better algorithms -- the blog post declares it's using simple ones
   - [x] For instance, use parallel substitution to avoid quadratic complexity (done in 1db5ee73a1cec39c6674a04b417b30bd212a0626).
   - [x] Avoid concatenating strings to build names, that's slow (see https://github.com/Blaisorblade/abt/compare/topic/faster-names).
   - [x] Replace the freshness generator with something faster.
2. Review those fancier algorithms to ensure they're correct --- with ABTs
   you need to get binding right only once, but you still have to.
3. Try out whether the implementation can use other techniques: Must one use
   names to implement ABTs, or would it be possible to use, say, de Bruijn
   indexes? ABTs relate to HOAS, and are in fact probably a more awkward
   approach to achieve the same benefits, but is it possible to implement the
   ABT interface *using* HOAS?
   What's the design space?

# Motivation

Usually, when implementing languages with binding, you have lots of
language-specific boilerplate, whose size grows with the size of each
language's grammar and with the number of languages. This becomes worse when a
project contains multiple languages with binding.

Abstract Binding Trees promise to minimize the language-specific overhead; in
particular, one needs only to take the language's syntax, remove binding
information, turn the algebraic data types into a *pattern functor*, and
implement Functor and Foldable instances for this functor; the latter step is
mechanical enough that it can even be automated (and is, in many languages).

--------------------------------------------------------------------------------
/src/abt/examples/LambdaCalc.scala:
--------------------------------------------------------------------------------
package abt
package examples

import ABTs._
import scalaz.Functor
import scalaz.Foldable
import scalaz.Monoid

/**
 * A Curry-style typed lambda calculus, implemented using ABTs.
11 | */ 12 | object LambdaCalc { 13 | trait SimpleType 14 | case object Base extends SimpleType 15 | case class Arrow(t1: SimpleType, t2: SimpleType) extends SimpleType 16 | 17 | protected sealed trait _TLambda[T] 18 | private case class _Lam[T](t: T) extends _TLambda[T] 19 | private case class _App[T](t1: T, t2: T) extends _TLambda[T] 20 | private case class _Let[T](t1: T, t2: T) extends _TLambda[T] 21 | private case class _Annot[T](t: T, tp: SimpleType) extends _TLambda[T] 22 | 23 | implicit val lambdaSig: Functor[_TLambda] with Foldable[_TLambda] = 24 | new Functor[_TLambda] with Foldable[_TLambda] with Foldable.FromFoldMap[_TLambda] { 25 | def map[A, B](fa: _TLambda[A])(f: A => B): _TLambda[B] = fa match { 26 | case _Lam(t) => _Lam(f(t)) 27 | case _Annot(t, tp) => _Annot(f(t), tp) 28 | case _App(t1, t2) => _App(f(t1), f(t2)) 29 | case _Let(t1, t2) => _Let(f(t1), f(t2)) 30 | } 31 | 32 | def foldMap[A,B](fa: _TLambda[A])(f: A => B)(implicit F: Monoid[B]): B = 33 | fa match { 34 | case _Lam(t) => f(t) 35 | case _Annot(t, _) => f(t) 36 | case _App(t1, t2) => F.append(f(t1), f(t2)) 37 | case _Let(t1, t2) => F.append(f(t1), f(t2)) 38 | } 39 | } 40 | 41 | val lambdaAbt: IAbt[_TLambda] = new Abt 42 | 43 | //Needed reexports 44 | type Term = lambdaAbt.Term 45 | val Var = lambdaAbt.Var 46 | type Name = ABTs.Name 47 | val Name = ABTs.Name 48 | implicit val TermOps = lambdaAbt.TermOps _ 49 | 50 | // 51 | import lambdaAbt._ 52 | //Smart constructors/extractors 53 | object Lam { 54 | def apply(name: Name, body: Term): Term = 55 | _Tm(_Lam(_Abs(name, body))) 56 | def unapply(t: Term): Option[(Name, Term)] = t match { 57 | case _Tm(_Lam(_Abs(name, body))) => Some((name, body)) 58 | case _ => None 59 | } 60 | } 61 | 62 | object App { 63 | def apply(f: Term, arg: Term): Term = 64 | _Tm(_App(f, arg)) 65 | def unapply(t: Term): Option[(Term, Term)] = t match { 66 | case _Tm(_App(f, arg)) => Some((f, arg)) 67 | case _ => None 68 | } 69 | } 70 | 71 | object Annot { 72 | def apply(t: 
Term, tp: SimpleType): Term = 73 | _Tm(_Annot(t, tp)) 74 | def unapply(t: Term): Option[(Term, SimpleType)] = t match { 75 | case _Tm(_Annot(t, tp)) => Some((t, tp)) 76 | case _ => None 77 | } 78 | } 79 | 80 | object Let { 81 | def apply(name: Name, t1: Term, t2: Term): Term = 82 | _Tm(_Let(t1, _Abs(name, t2))) 83 | def unapply(t: Term): Option[(Name, Term, Term)] = t match { 84 | case _Tm(_Let(t1, _Abs(name, t2))) => Some((name, t1, t2)) 85 | case _ => None 86 | } 87 | } 88 | } 89 | -------------------------------------------------------------------------------- /abt.ml: -------------------------------------------------------------------------------- 1 | (* -*- mode: ocaml; -*- *) 2 | (* A copy of https://gist.github.com/neel-krishnaswami/834b892327271e348f79 *) 3 | 4 | module type FUNCTOR = sig 5 | type 'a t 6 | val map : ('a -> 'b) -> 'a t -> 'b t 7 | end 8 | 9 | type 'a monoid = {unit : 'a ; join : 'a -> 'a -> 'a} 10 | 11 | type var = string 12 | module V = Set.Make(struct type t = var let compare = compare end) 13 | 14 | module type ABT = sig 15 | type 'a signature 16 | type 'a f = Var of var | Abs of var * 'a | Tm of 'a signature 17 | val map : ('a -> 'b) -> 'a f -> 'b f 18 | 19 | type t 20 | val into : t f -> t 21 | val out : t -> t f 22 | 23 | val freevars : t -> V.t 24 | val var : V.elt -> t 25 | val abs : V.elt * t -> t 26 | val tm : t signature -> t 27 | 28 | val subst : t -> var -> t -> t 29 | end 30 | 31 | module type SIGNATURE = sig 32 | include FUNCTOR 33 | 34 | val join : 'a monoid -> 'a t -> 'a 35 | end 36 | 37 | module Abt(F : SIGNATURE) : ABT with type 'a signature := 'a F.t = 38 | struct 39 | type 'a f = 40 | | Var of var 41 | | Abs of var * 'a 42 | | Tm of 'a F.t 43 | 44 | let map f = function 45 | | Var x -> Var x 46 | | Abs(x, e) -> Abs(x, f e) 47 | | Tm t -> Tm (F.map f t) 48 | 49 | type t = In of V.t * t f 50 | 51 | let freevars (In(vs, _)) = vs 52 | let out (In(_, t)) = t 53 | 54 | let m = {unit = V.empty; join = V.union} 55 | 56 | let var 
x = In(V.singleton x, Var x) 57 | let abs(z, e) = In(V.remove z (freevars e), Abs(z, e)) 58 | let tm t = In(F.join m (F.map freevars t), Tm t) 59 | 60 | let into = function 61 | | Var x -> var x 62 | | Abs(x, e) -> abs(x, e) 63 | | Tm t -> tm t 64 | 65 | let rec fresh vs v = 66 | if V.mem v vs then fresh vs (v ^ "'") else v 67 | 68 | let rec rename x y (In(fvs, t)) = 69 | match t with 70 | | Var z -> if x = z then var y else var z 71 | | Abs(z, e) -> if x = z then abs(z, e) else abs(z, rename x y e) 72 | | Tm v -> tm (F.map (rename x y) v) 73 | 74 | let rec subst t x body = 75 | match out body with 76 | | Var z when x = z -> t 77 | | Var z -> var z 78 | | Abs(x, e) -> 79 | let x' = fresh (V.union (freevars t) (freevars body)) x in 80 | let e' = subst t x (rename x x' e) in 81 | abs(x', e') 82 | | Tm body -> tm (F.map (subst t x) body) 83 | end 84 | 85 | module Lambda = 86 | struct 87 | type tp = Base | Arrow of tp * tp 88 | type 'a t = Lam of 'a | App of 'a * 'a | Let of 'a * 'a | Annot of tp * 'a 89 | let map f = function 90 | | Lam x -> Lam (f x) 91 | | App (x, y) -> App(f x, f y) 92 | | Let (x, y) -> Let(f x, f y) 93 | | Annot(t, x) -> Annot(t, f x) 94 | 95 | let join m = function 96 | | Lam x -> x 97 | | App(x, y) -> m.join x y 98 | | Let(x, y) -> m.join x y 99 | | Annot(_, x) -> x 100 | end 101 | 102 | module Syntax = Abt(Lambda) 103 | 104 | module Bidir = struct 105 | open Lambda 106 | open Syntax 107 | type ctx = (var * tp) list 108 | 109 | let is_synth = function 110 | | Tm (Lam _) | Tm (Let (_, _)) -> false 111 | | _ -> true 112 | let is_check e = not(is_synth e) 113 | 114 | let unabs e = 115 | match out e with 116 | | Abs(x, e) -> (x, e) 117 | | _ -> assert false 118 | 119 | let rec check ctx e tp = 120 | match out e, tp with 121 | | Tm (Lam t), Arrow(tp1, tp') -> 122 | let (x, e') = unabs t in 123 | check ((x, tp1) :: ctx) e' tp' 124 | | Tm (Lam _), _ -> failwith "expected arrow type" 125 | | Tm (Let(e', t)), _ -> 126 | let (x, e'') = unabs t in 127 | 
let tp1 = synth ctx e' in 128 | check ((x, tp1) :: ctx) e'' tp 129 | | body, _ when is_synth body -> 130 | if tp = synth ctx e then () else failwith "Type mismatch" 131 | | _ -> assert false 132 | 133 | and synth ctx e = 134 | match out e with 135 | | Var x -> (try List.assoc x ctx with Not_found -> failwith "unbound variable") 136 | | Tm(Annot(tp, e)) -> let () = check ctx e tp in tp 137 | | Tm(App(f, e)) -> 138 | (match synth ctx f with 139 | | Arrow(tp, tp') -> let () = check ctx e tp in tp' 140 | | _ -> failwith "Applying a non-function!") 141 | | body when is_check body -> failwith "Cannot synthesize type for checking term" 142 | | _ -> assert false 143 | end 144 | -------------------------------------------------------------------------------- /src/abt/examples/MiniMLTermSyn.scala: -------------------------------------------------------------------------------- 1 | package abt 2 | package examples 3 | 4 | import ABTs._ 5 | import scalaz.Functor 6 | import scalaz.Foldable 7 | import scalaz.Monoid 8 | 9 | /** 10 | * @author pgiarrusso 11 | */ 12 | object MiniMLTypeSyn { 13 | protected sealed trait _Type[T] 14 | private case class _Base[T]() extends _Type[T] 15 | private case class _Arrow[T](t1: T, t2: T) extends _Type[T] 16 | 17 | private case class _Poly[T](t: T) extends _Type[T] 18 | 19 | implicit val typeSig: Functor[_Type] with Foldable[_Type] = 20 | new Functor[_Type] with Foldable[_Type] with Foldable.FromFoldMap[_Type] { 21 | def map[A, B](fa: _Type[A])(f: A => B): _Type[B] = 22 | fa match { 23 | case _Base() => _Base() 24 | case _Arrow(t1, t2) => _Arrow(f(t1), f(t2)) 25 | case _Poly(t) => _Poly(f(t)) 26 | } 27 | 28 | def foldMap[A,B](fa: _Type[A])(f: A => B)(implicit F: Monoid[B]): B = 29 | fa match { 30 | case _Base() => F.zero 31 | case _Arrow(t1, t2) => F.append(f(t1), f(t2)) 32 | case _Poly(t) => f(t) 33 | } 34 | } 35 | 36 | val typeAbt: IAbt[_Type] = new Abt 37 | 38 | //Needed reexports 39 | type Type = typeAbt.Term 40 | val TVar = typeAbt.Var 41 | 
implicit val TermOps = typeAbt.TermOps _ 42 | import typeAbt._ 43 | 44 | val Base: Type = _Tm(_Base()) 45 | 46 | object Arrow { 47 | def apply(t1: Type, t2: Type) = 48 | _Tm(_Arrow(t1, t2)) 49 | def unapply(t: Type): Option[(Type, Type)] = t match { 50 | case _Tm(_Arrow(t1, t2)) => Some((t1, t2)) 51 | case _ => None 52 | } 53 | } 54 | 55 | object Poly { 56 | def apply(name: Name, t: Type) = 57 | _Tm(_Poly(_Abs(name, t))) 58 | def unapply(t: Type): Option[(Name, Type)] = t match { 59 | case _Tm(_Poly(_Abs(name, t))) => Some((name, t)) 60 | case _ => None 61 | } 62 | } 63 | } 64 | 65 | object MiniMLTypes { 66 | import MiniMLTypeSyn._ 67 | 68 | def isMono(t: Type): Boolean = t match { 69 | case Base => true 70 | case Arrow(t1, t2) => isMono(t1) && isMono(t2) 71 | case Poly(_, _) => false 72 | } 73 | 74 | //Following MiniML, we don't bind type variables in the context, we just use 75 | //them free. 76 | case class Context(asMap: Map[Name, Type]) { 77 | require(asMap forall { 78 | case (_, typ) => isMono(typ) 79 | }) 80 | } 81 | 82 | def generalize(ctx: Context, t: Type): Type = { 83 | assert(isMono(t)) 84 | val freeVars = (t.freeVars -- ctx.asMap.values.flatMap(_.freeVars)).toList.sorted 85 | freeVars.foldRight(t)(Poly(_, _)) 86 | } 87 | 88 | def instantiate(t: Type, ctx: Context): Type = { 89 | t.subst(ctx.asMap) 90 | } 91 | 92 | } 93 | 94 | object MiniMLTermSyn { 95 | import MiniMLTypeSyn._ 96 | 97 | protected sealed trait _TLambda[T] 98 | private case class _Lam[T](t: T) extends _TLambda[T] 99 | private case class _App[T](t1: T, t2: T) extends _TLambda[T] 100 | private case class _Let[T](t1: T, t2: T) extends _TLambda[T] 101 | private case class _Annot[T](t: T, tp: Type) extends _TLambda[T] 102 | 103 | implicit val lambdaSig: Functor[_TLambda] with Foldable[_TLambda] = 104 | new Functor[_TLambda] with Foldable[_TLambda] with Foldable.FromFoldMap[_TLambda] { 105 | def map[A, B](fa: _TLambda[A])(f: A => B): _TLambda[B] = fa match { 106 | case _Lam(t) => _Lam(f(t)) 107 
| case _Annot(t, tp) => _Annot(f(t), tp) 108 | case _App(t1, t2) => _App(f(t1), f(t2)) 109 | case _Let(t1, t2) => _Let(f(t1), f(t2)) 110 | } 111 | 112 | def foldMap[A,B](fa: _TLambda[A])(f: A => B)(implicit F: Monoid[B]): B = 113 | fa match { 114 | case _Lam(t) => f(t) 115 | case _Annot(t, _) => f(t) 116 | case _App(t1, t2) => F.append(f(t1), f(t2)) 117 | case _Let(t1, t2) => F.append(f(t1), f(t2)) 118 | } 119 | } 120 | 121 | val lambdaAbt: IAbt[_TLambda] = new Abt 122 | 123 | //Needed reexports 124 | type Term = lambdaAbt.Term 125 | val Var = lambdaAbt.Var 126 | implicit val TermOps = lambdaAbt.TermOps _ 127 | 128 | // 129 | import lambdaAbt._ 130 | //Smart constructors/extractors 131 | object Lam { 132 | def apply(name: Name, body: Term): Term = 133 | _Tm(_Lam(_Abs(name, body))) 134 | def unapply(t: Term): Option[(Name, Term)] = t match { 135 | case _Tm(_Lam(_Abs(name, body))) => Some((name, body)) 136 | case _ => None 137 | } 138 | } 139 | 140 | object App { 141 | def apply(f: Term, arg: Term): Term = 142 | _Tm(_App(f, arg)) 143 | def unapply(t: Term): Option[(Term, Term)] = t match { 144 | case _Tm(_App(f, arg)) => Some((f, arg)) 145 | case _ => None 146 | } 147 | } 148 | 149 | object Annot { 150 | def apply(t: Term, tp: Type): Term = 151 | _Tm(_Annot(t, tp)) 152 | def unapply(t: Term): Option[(Term, Type)] = t match { 153 | case _Tm(_Annot(t, tp)) => Some((t, tp)) 154 | case _ => None 155 | } 156 | } 157 | 158 | object Let { 159 | def apply(name: Name, t1: Term, t2: Term): Term = 160 | _Tm(_Let(t1, _Abs(name, t2))) 161 | def unapply(t: Term): Option[(Name, Term, Term)] = t match { 162 | case _Tm(_Let(t1, _Abs(name, t2))) => Some((name, t1, t2)) 163 | case _ => None 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /src/abt/abt.scala: -------------------------------------------------------------------------------- 1 | // Translate https://gist.github.com/neel-krishnaswami/834b892327271e348f79 2 | /** 
3 | * Support for *Abstract Binding Trees*, or ABTs. 4 | * 5 | * To implement a language using Abstract Binding Trees: 6 | * 1. First, you provide an underlying algebric data type 7 | * Signature representing the abstract syntax of your language, and provide an 8 | * implementation of Functor[Signature] and Foldable[Signature]. This abstract 9 | * syntax definition contains no binding information; for instance, 10 | * a lambda abstraction is defined as a node _Lam(t: Term). 11 | * 2. Then, you provide smart constructors/extractors exposing the intended 12 | * interface to the client. These smart constructors will assemble together AST 13 | * nodes with binding nodes. For instance, you can provide constructors/extractors 14 | * for Lam, that assemble together a _Lam node with an abstractor. 15 | * 16 | * See Lambda for an example. 17 | */ 18 | package abt 19 | 20 | import language.higherKinds 21 | import language.implicitConversions 22 | import scalaz.Functor 23 | import scalaz.Foldable 24 | import scalaz.Monoid 25 | 26 | object ABTs { 27 | //Not part of Scalaz, unless we switch to Scalaz's Set. 28 | implicit def setMonoid[T] = 29 | new Monoid[Set[T]] { 30 | def zero = Set.empty 31 | def append(a: Set[T], b: => Set[T]): Set[T] = a ++ b 32 | } 33 | implicit def seqMonoid[T] = 34 | new Monoid[Seq[T]] { 35 | def zero = Seq.empty 36 | def append(a: Seq[T], b: => Seq[T]): Seq[T] = a ++ b 37 | } 38 | implicit def vectorMonoid[T] = 39 | new Monoid[Vector[T]] { 40 | def zero = Vector.empty 41 | def append(a: Vector[T], b: => Vector[T]): Vector[T] = a ++ b 42 | } 43 | case class Name(s: String, n: Int = 0) { 44 | override def toString = s"${s}_${n}" 45 | } 46 | 47 | object Name { 48 | implicit val nOrdering: Ordering[Name] = Ordering.by { name => (name.s, name.n) } 49 | } 50 | 51 | implicit def toName(s: String) = Name(s) 52 | type Names = Set[Name] 53 | } 54 | 55 | import ABTs._ 56 | 57 | /** 58 | * Interface for Abstract Binding Trees. 
To specialize them to a specific 59 | * datatype, you need to define extra constructors, and smart constructors 60 | * using Abs in the right places (the ones where, in HOAS, you'd use an open 61 | * term Term => Term). 62 | * 63 | * Members are divided in three groups: 64 | * - Some members can be re-exported by the implementation of the client 65 | * language; their name starts without underscores. 66 | * - Some members are to be used by the implementation of the client 67 | * language. Since that is no subclass, they cannot be made protected; 68 | * instead their names start with underscores. 69 | * - Some members are just to be implemented by concrete ABT implementations; 70 | * they are marked as protected and their name start with one or more 71 | * underscores. 72 | * 73 | * @param Signature the signature of the underlying AST. 74 | */ 75 | /* XXX: Don't use I for interfaces. What's the right convention? */ 76 | trait IAbt[Signature[_]] { 77 | outer => 78 | 79 | /* 80 | * Members to reexport to clients. 81 | */ 82 | //Abstract type of terms 83 | type Term 84 | 85 | //Smart constructors/extractors. 86 | 87 | //This one will typically be part of your language. 88 | object Var { 89 | def apply(n: Name): Term = _mkVar(n) 90 | def unapply(t: Term): Option[Name] = t match { 91 | case _Term(_Var(n)) => Some(n) 92 | case _ => None 93 | } 94 | } 95 | 96 | implicit class TermOps(t: Term) { 97 | @inline def freeVars: Names = _freeVars(t) 98 | @inline def subst(v: Name, inner: Term): Term = subst(Map(v -> inner)) 99 | @inline def subst(map: Map[Name, Term]): Term = _subst(t, map) 100 | @inline def alphaEquiv(other: Term): Boolean = _alphaEquiv(t, other) 101 | } 102 | 103 | /* 104 | * Members for use by language implementations. 105 | */ 106 | /** 107 | * Term is isomorphic to __Term[Term], and this is the witness. 108 | * This isomorphism is not exposed to language clients. 
109 | */ 110 | object _Term { 111 | def apply(t: __Term[Term]): Term = _into(t) 112 | def unapply(t: Term): Some[__Term[Term]] = Some(_out(t)) 113 | } 114 | 115 | object _Tm { 116 | def apply(t: Signature[Term]): Term = _mkTm(t) 117 | def unapply(t: Term): Option[Signature[Term]] = t match { 118 | case _Term(__Tm(t1)) => Some(t1) 119 | case _ => None 120 | } 121 | } 122 | 123 | object _Abs { 124 | def apply(n: Name, body: Term): Term = _mkAbs(n, body) 125 | def unapply(t: Term): Option[(Name, Term)] = t match { 126 | case _Term(__Abs(n, body)) => Some((n, body)) 127 | case _ => None 128 | } 129 | } 130 | 131 | protected def _freeVars(t: Term): Names 132 | protected def _subst(t: Term, map: Map[Name, Term]): Term 133 | protected def _alphaEquiv(t1: Term, t2: Term): Boolean 134 | 135 | //Methods for internal usage, outside of the interface. 136 | //Concrete type used to build terms. 137 | protected sealed trait __Term[T] 138 | //This is used by smart constructors of all binders. 139 | protected case class __Abs[T](n: Name, t: T) extends __Term[T] 140 | 141 | protected case class _Var[A](n: Name) extends __Term[A] 142 | //This is used by smart constructors of all other operations. 
143 | protected case class __Tm[T](t: Signature[T]) extends __Term[T] 144 | 145 | protected def map[A, B](bt: __Term[A])(f: A => B): __Term[B] 146 | 147 | protected object __Term { 148 | implicit val isFunctor = 149 | new Functor[__Term] { 150 | def map[A, B](bt: __Term[A])(f: A => B): __Term[B] = outer.map(bt)(f) 151 | } 152 | } 153 | 154 | protected def _into(t: __Term[Term]): Term 155 | protected def _out(t: Term): __Term[Term] 156 | 157 | protected def _mkVar(n: Name): Term 158 | protected def _mkAbs(n: Name, body: Term): Term 159 | protected def _mkTm(t: Signature[Term]): Term 160 | } 161 | 162 | class Abt[Signature[_]: Functor: Foldable] extends IAbt[Signature] { 163 | def map[A, B](bt: __Term[A])(f: A => B): __Term[B] = bt match { 164 | case _Var(n) => _Var(n) 165 | case __Abs(n, body) => __Abs(n, f(body)) 166 | case __Tm(t) => __Tm(Functor[Signature].map(t)(f)) 167 | } 168 | 169 | type Term = TermInt 170 | case class TermInt(vars: Names, t: __Term[Term]) 171 | 172 | def _into(t: __Term[Term]): Term = 173 | t match { 174 | case _Var(n) => _mkVar(n) 175 | case __Abs(n, body) => _mkAbs(n, body) 176 | case __Tm(t) => _mkTm(t) 177 | } 178 | 179 | def _out(t: Term): __Term[Term] = t.t 180 | 181 | def _mkVar(n: Name): Term = TermInt(Set(n), _Var(n)) 182 | def _mkAbs(n: Name, body: Term): Term = 183 | TermInt(_freeVars(body) - n, __Abs(n, body)) 184 | 185 | def _mkTm(t: Signature[Term]): Term = 186 | TermInt(Foldable[Signature].fold(Functor[Signature].map(t)(_freeVars)), __Tm(t)) 187 | 188 | def _freeVars(t: Term): Names = t.vars 189 | 190 | var index = 0 191 | def fresh(): Name = { 192 | index += 1 193 | Name("x", index) 194 | } 195 | def fresh(baseName: Name, vars: Names): Name = 196 | if (vars contains baseName) { 197 | val Name(name, idx) = baseName 198 | fresh(Name(name, idx + 1), vars) 199 | } else 200 | baseName 201 | 202 | def _substQuadratic(outer: Term, v: Name, inner: Term): Term = 203 | substQuadratic(outer, v, inner, true) 204 | 205 | def 
children[T](s: Signature[T]): Vector[T] = 206 | Foldable[Signature].fold(Functor[Signature].map(s)(Vector(_))) 207 | 208 | //This is a basic substitution with quadratic complexity. 209 | def substQuadratic(outer: Term, v: Name, inner: Term, preRename: Boolean): Term = 210 | _out(outer) match { 211 | case _Var(name) if v == name => inner 212 | case __Tm(body) => 213 | _Tm(Functor[Signature].map[Term, Term](body)(x => substQuadratic(x, v, inner, preRename))) 214 | case __Abs(name, body) if v != name => 215 | val (name1, body1) = 216 | if (preRename) { 217 | val newName = fresh(name, _freeVars(body) ++ _freeVars(inner)) 218 | val newBody = substQuadratic(body, name, Var(newName), false) 219 | (newName, newBody) 220 | } else { 221 | //We're replacing 222 | (name, body) 223 | } 224 | val body2 = substQuadratic(body1, v, inner, preRename) 225 | _Abs(name1, body2) 226 | case _ => //For when guards fail 227 | outer 228 | } 229 | 230 | def _subst(outer: Term, map: Map[Name, Term]): Term = 231 | subst(outer, map.values.flatMap(_freeVars).toSet, map) 232 | 233 | /** 234 | * Parallel substitution. 
235 | * Precondition: 236 | * map.values.flatMap(_freeVars).toSet == fvInners 237 | */ 238 | def subst(outer: Term, fvInners: Set[Name], map: Map[Name, Term]): Term = { 239 | assert(map.values.flatMap(_freeVars).toSet.subsetOf(fvInners), 240 | s"!${map.values.flatMap(_freeVars).toSet}.subsetOf($fvInners)") 241 | //Even stronger assertion 242 | assert(map.values.flatMap(_freeVars).toSet == fvInners, 243 | s"${map.values.flatMap(_freeVars).toSet} != $fvInners") 244 | _out(outer) match { 245 | case _Var(name) => 246 | map get name getOrElse outer 247 | case __Tm(body) => 248 | _Tm(Functor[Signature].map[Term, Term](body)(x => subst(x, fvInners, map))) 249 | case __Abs(name, body) => 250 | val newName = fresh(name, (_freeVars(body) - name) ++ fvInners) 251 | val varsToRemove = map get name map _freeVars getOrElse Set() 252 | val newBody = subst(body, fvInners -- varsToRemove + newName, map + (name -> Var(newName))) 253 | _Abs(newName, newBody) 254 | case _ => //For when guards fail 255 | outer 256 | } 257 | } 258 | 259 | def _alphaEquiv(t1: Term, t2: Term): Boolean = { 260 | alphaEquivLoop(t1, t2, Map(), Map()) 261 | } 262 | 263 | /** 264 | * A linear-time implementation of alpha-equivalence checking. 
265 | */ 266 | def alphaEquivLoop(t1: Term, t2: Term, map1: Map[Name, Name], map2: Map[Name, Name]): Boolean = { 267 | (_out(t1), _out(t2)) match { 268 | case (_Var(n1), _Var(n2)) => 269 | (map1 get n1, map2 get n2) match { 270 | case (Some(renamedN1), Some(renamedN2)) => renamedN1 == renamedN2 271 | case (None, None) => n1 == n2 272 | case _ => false 273 | } 274 | case (__Abs(n1, b1), __Abs(n2, b2)) => 275 | val freshName = fresh() 276 | alphaEquivLoop(b1, b2, map1 + (n1 -> freshName), map2 + (n2 -> freshName)) 277 | case (__Tm(s1), __Tm(s2)) => 278 | val c1 = children(s1) 279 | val c2 = children(s2) 280 | c1.length == c2.length && (c1 zip c2).forall { 281 | case (t1, t2) => alphaEquivLoop(t1, t2, map1, map2) 282 | } 283 | case _ => false 284 | } 285 | } 286 | } 287 | --------------------------------------------------------------------------------