├── README.md
└── src
    ├── Type.scala
    ├── Term.scala
    ├── Lambda.scala
    ├── test.sc
    └── Inferencer.scala
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
hindley-milner
==============

A concise Hindley-Milner type inferencer (Algorithm W) implemented in Scala.

- Term.scala: syntax tree definition
- Lambda.scala: simple parser producing lambda terms
- Type.scala: type definitions
- Inferencer.scala: the Hindley-Milner type inferencer itself

References:
- http://okmij.org/ftp/ML/generalization.html
- http://dysphoria.net/2009/06/28/hindley-milner-type-inference-in-scala/
--------------------------------------------------------------------------------
/src/Type.scala:
--------------------------------------------------------------------------------
abstract class Type

// Primitive type such as Bool, Double, String
case class Prim(name: String) extends Type {
  override def toString = name
}

// Function type: t1 -> t2 (the right side is parenthesized when it is itself an arrow)
case class Arrow(t1: Type, t2: Type) extends Type {
  override def toString = t1.toString + "->" + (t2 match {
    case Arrow(_, _) => "(" + t2.toString + ")"
    case _           => t2.toString
  })
}

// Type variable; `pointTo` is the union-find link used during unification.
// A variable that points to itself is unbound.
case class TVar(name: String) extends Type {
  var pointTo: Type = this
  override def toString = if (pointTo == this) name else pointTo.toString
}
--------------------------------------------------------------------------------
/src/Term.scala:
--------------------------------------------------------------------------------
sealed abstract class Term {
  // Type of the term, filled in by the inferencer (or by a literal's annotation)
  var t: Option[Type] = None

  def typeInfo = t match { case None => "" case Some(ty) => ": " + ty.toString }

  def toString(n: Int): String
  override def toString = toString(0) + typeInfo
}

case class Const(v: Any, ty: Type) extends Term {
  t = Some(ty)
  override def toString(n: Int) = (" " * n) + v.toString
}

case class Var(name: String) extends Term {
  override def toString(n: Int) = (" " * n) + name
}

case class Abs(x: Var, e: Term) extends Term {
  override def toString(n: Int) = (" " * n) + "(\\" + x.toString + " => " + e.toString + ")"
}

case class App(e1: Term, e2: Term) extends Term {
  override def toString(n: Int) = (" " * n) + "(" + e1.toString + ", " + e2.toString + ")"
}

case class Let(x: Var, e1: Term, e2: Term) extends Term {
  override def toString(n: Int) =
    (" " * n) + "Let " + x.toString + " = " + e1.toString + " in\n" + e2.toString(n + 4)
}
--------------------------------------------------------------------------------
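Aside (not a file in this repo): as a quick sketch of how these constructors compose, the hypothetical `TermDemo` below builds the application of the identity function to a literal by hand and prints it with the formatters above. Note that the printer renders application as a comma pair and parenthesizes right-nested arrows.

object TermDemo {
  def main(args: Array[String]): Unit = {
    val x  = Var("x")
    val id = Abs(x, x)                            // \x => x
    val e  = App(id, Const(3.0, Prim("Double")))  // apply id to the literal 3.0
    println(e)                                    // ((\x => x), 3.0: Double)
    println(Arrow(TVar("a"), Arrow(TVar("a"), Prim("Bool"))))  // a->(a->Bool)
  }
}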
/src/Lambda.scala:
--------------------------------------------------------------------------------
import scala.util.parsing.combinator._
import scala.util.parsing.combinator.syntactical._

object Lambda extends StandardTokenParsers with ImplicitConversions {
  lexical.delimiters ++= ("\\ => + - * / ( ) , == = ;" split ' ')
  lexical.reserved ++= ("let in true false" split ' ')

  def boolean: Parser[Term] = ("true" | "false") ^^ { s => Const(s.toBoolean, Prim("Bool")) }
  def string: Parser[Term]  = stringLit ^^ { s => Const(s, Prim("String")) }
  def double: Parser[Term]  = numericLit ^^ { s => Const(s.toDouble, Prim("Double")) }

  def literal: Parser[Term] = boolean | string | double

  def variable: Parser[Var] = ident ^^ Var

  def let: Parser[Term] = ("let" ~> variable) ~ ("=" ~> expr) ~ ("in" ~> expr) ^^ Let

  def lam: Parser[Term] = ("\\" ~> variable) ~ ("=>" ~> expr) ^^ Abs

  def app: Parser[Term] = "(" ~> expr ~ expr <~ ")" ^^ App

  def operator: Parser[Term] = ("+" | "-" | "*" | ident) ^^ Var
  def op: Parser[Term] = "(" ~> expr ~ operator ~ expr <~ ")" ^^ { case e1 ~ o ~ e2 => App(App(o, e1), e2) }

  def expr: Parser[Term] = variable | literal | lam | let | app | op | "(" ~> expr <~ ")"

  def parse(input: String): Term =
    phrase(expr)(new lexical.Scanner(input)) match {
      case Success(e, _) => e
      case e: NoSuccess  => throw new Exception(e.msg)
    }
}
--------------------------------------------------------------------------------
/src/test.sc:
--------------------------------------------------------------------------------
// TODO: turn these worksheet checks into a JUnit test

object test {
  List(2, 3).apply(0)                            //> res0: Int = 2
  val v1 = Var("1")                              //> v1 : Var = 1
  Inferencer.MakeFreshTVar()                     //> res1: TVar = t1

  Lambda.parse("x")                              //> res2: Term = x
  Lambda.parse("\"string\"")                     //> res3: Term = string: String
  Lambda.parse("3")                              //> res4: Term = 3.0: Double
  Lambda.parse("true")                           //> res5: Term = true: Bool
  Lambda.parse("""\x => x""")                    //> res6: Term = (\x => x)
  Lambda.parse("""\x => 1""")                    //> res7: Term = (\x => 1.0: Double)
  Lambda.parse("""\x => (x x)""")                //> res8: Term = (\x => (x, x))
  Lambda.parse("""\x => (x + x )""")             //> res9: Term = (\x => ((+, x), x))
  Lambda.parse("let x = x in x")                 //> res10: Term = Let x = x in
                                                 //|     x
  Lambda.parse("let x = 3 in x")                 //> res11: Term = Let x = 3.0: Double in
                                                 //|     x
  Lambda.parse("let x = 3 in let y = 2 in (x + y)")
                                                 //> res12: Term = Let x = 3.0: Double in
                                                 //|     Let y = 2.0: Double in
                                                 //|         ((+, x), y)
  Lambda.parse("""let f = \x => x in let _ = (f true) in (f false)""")
                                                 //> res13: Term = Let f = (\x => x) in
                                                 //|     Let _ = (f, true: Bool) in
                                                 //|         (f, false: Bool)
  // infer
  Inferencer Infer Lambda.parse("""let f = \x => x in let _ = (f 3) in (f false)""")
                                                 //> res14: Type = Bool
  Inferencer Infer Lambda.parse("""let f = \x => (1 + x) in let _ = (f 3) in (f 1)""")
                                                 //> res15: Type = Double
  Inferencer Infer Lambda.parse("""let f = \x => (x + 1) in let _ = (f 3) in (f 1)""")
                                                 //> res16: Type = Double

  Inferencer Infer Lambda.parse("""let f = (\x => (\x => x)) in let _ = (f 3) in (f "str")""")
}
--------------------------------------------------------------------------------
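Aside (not a file in this repo): the `res14` line above works because `let` generalizes — each use of `f` is instantiated with fresh type variables. The hypothetical `PolyDemo` sketch below contrasts this with a lambda-bound `f`, which stays monomorphic, so the same program fails to unify. The names `PolyDemo` and `u` are illustrative only.

object PolyDemo {
  def main(args: Array[String]): Unit = {
    // let-bound f: each use is instantiated fresh, so f works at Double and Bool
    println(Inferencer.Infer(Lambda.parse(
      """let f = \x => x in let _ = (f 3) in (f false)""")))  // Bool

    // lambda-bound f stays monomorphic: applying the same f to a Double and a
    // Bool forces Double = Bool, so Unify throws "unify error, type mismatch"
    val f = Var("f")
    val body = Let(Var("u"), App(f, Const(3.0, Prim("Double"))),
                             App(f, Const(false, Prim("Bool"))))
    try Inferencer.Infer(App(Abs(f, body), Abs(Var("x"), Var("x"))))
    catch { case e: Exception => println(e.getMessage) }
  }
}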
/src/Inferencer.scala:
--------------------------------------------------------------------------------
import scala.collection.immutable.{ Map, HashMap }

object Inferencer {
  type Context = Map[Var, Type]

  private var count = 0
  def MakeFreshTVar(): TVar = {
    count = count + 1
    TVar("t" + count)
  }

  // Follow union-find links to the representative of `t`, compressing the path on the way
  def Find(t: Type): Type = t match {
    case tv: TVar =>
      if (tv.pointTo != tv) tv.pointTo = Find(tv.pointTo)
      tv.pointTo
    case _ => t
  }

  // Bind the variable `ta` to the representative of `tb`
  def Union(ta: TVar, tb: Type): Unit = {
    // The system doesn't support recursive types, so run an occurs check first:
    // does `ta` appear anywhere inside `t`?
    def occurs(ta: TVar, t: Type): Boolean = Find(t) match {
      case tv: TVar      => ta == tv
      case Arrow(t1, t2) => occurs(ta, t1) || occurs(ta, t2)
      case _             => false
    }

    val rb = Find(tb)
    if (ta != rb) {
      if (occurs(ta, rb))
        throw new Exception("occurs check failed: recursive types are not supported")
      ta.pointTo = rb
    }
  }

  def Unify(ta: Type, tb: Type): Unit = {
    val (t1, t2) = (Find(ta), Find(tb))
    (t1, t2) match {
      case (Arrow(a1, b1), Arrow(a2, b2)) => { Unify(a1, a2); Unify(b1, b2) }
      case (tv: TVar, _) => Union(tv, t2)
      case (_, tv: TVar) => Union(tv, t1)
      case _ => if (t1 != t2) throw new Exception("unify error, type mismatch")
    }
  }

  // Instantiate a (possibly generic) type: every type variable not bound by the
  // context is replaced with a fresh variable, reusing the same fresh variable
  // for repeated occurrences
  def Inst(t: Type, ctx: Context): Type = {
    val bindVar = ctx.values.toSet
    val gen = new collection.mutable.HashMap[TVar, TVar]
    def reify(t: Type): Type = Find(t) match {
      case tv: TVar =>
        if (!bindVar.contains(tv)) gen.getOrElseUpdate(tv, MakeFreshTVar())
        else tv
      case Arrow(t1, t2) => Arrow(reify(t1), reify(t2))
      case other => other
    }
    reify(t)
  }

  // Given a type annotation, synthesize the type info and recursively check the term
  // def Synthesize(ctx: Context, expr: Term, t: Type)

  // Type inference (Algorithm W): return the type of `expr` under the typing
  // context `ctx`, recording it on the term as a side effect
  def Analyze(ctx: Context, expr: Term): Type = {
    // TODO: if there is a type annotation, apply the synthesize rule
    val res: Type = expr match {
      case Const(v, t) => t
      case v: Var =>
        ctx.get(v) match {
          case Some(t) => Inst(t, ctx)
          case None    => throw new Exception("unbound variable: " + v.name)
        }
      case Abs(x, e) =>
        val tx = ctx.getOrElse(x, MakeFreshTVar())
        val te = Analyze(ctx + (x -> tx), e)
        Arrow(tx, te)
      case Let(x, e1, e2) =>
        val te1 = Analyze(ctx, e1)
        Analyze(ctx + (x -> te1), e2)
      case App(e1, e2) =>
        val te1 = Analyze(ctx, e1)
        val te2 = Analyze(ctx, e2)
        val t = MakeFreshTVar()
        Unify(te1, Arrow(te2, t))
        t
    }
    expr.t = Some(res)
    res
  }

  def Infer(expr: Term): Type = {
    val ctx: Context = HashMap(Var("+") -> Arrow(TVar("t0"), Arrow(TVar("t0"), TVar("t0"))))
    Analyze(ctx, expr)
  }

  def main(args: Array[String]): Unit = {
    val expr = Lambda.parse("""let f = (\x => (\x => x)) in let _ = (f 3) in (f "str")""")
    println(expr)
    Inferencer.Infer(expr)
    println(expr)
  }
}
--------------------------------------------------------------------------------
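Aside, following the TODO at the top of test.sc: a minimal JUnit 4 test might look like the sketch below (hypothetical — it assumes JUnit 4 on the classpath and the sources above compiled together).

import org.junit.Test
import org.junit.Assert.assertEquals

class InferencerTest {
  // Let-bound f is polymorphic: used at Double and at Bool, the result is Bool
  @Test def letPolymorphism(): Unit = {
    val t = Inferencer.Infer(
      Lambda.parse("""let f = \x => x in let _ = (f 3) in (f false)"""))
    assertEquals("Bool", t.toString)
  }

  // Self-application would need the recursive type t = t -> a,
  // which the occurs check rejects
  @Test(expected = classOf[Exception])
  def occursCheckRejectsSelfApplication(): Unit = {
    Inferencer.Infer(Lambda.parse("""\x => (x x)"""))
  }
}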