├── release.go ├── .travis.yml ├── .gitignore ├── const.go ├── constraint.go ├── solver.go ├── constraint_test.go ├── Gopkg.toml ├── typeVariable.go ├── LICENCE ├── Gopkg.lock ├── env.go ├── debug.go ├── substitutables_test.go ├── expression.go ├── substitutables.go ├── solver_test.go ├── scheme_test.go ├── README.md ├── typeVarSet.go ├── typeVariable_test.go ├── scheme.go ├── env_test.go ├── typeVarSet_test.go ├── test_test.go ├── perf_test.go ├── substitutions.go ├── functionType.go ├── functionType_test.go ├── type.go ├── substitutions_test.go ├── perf.go ├── example_greenspun_test.go ├── hm_test.go └── hm.go /release.go: -------------------------------------------------------------------------------- 1 | // +build !debug 2 | 3 | package hm 4 | 5 | func enterLoggingContext() {} 6 | func leaveLoggingContext() {} 7 | func logf(format string, others ...interface{}) {} 8 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: go 3 | branches: 4 | only: 5 | - master 6 | 7 | go: 8 | - 1.5.x 9 | - 1.6.x 10 | - 1.7.x 11 | - 1.8.x 12 | - 1.9.x 13 | - tip 14 | 15 | env: 16 | global: 17 | - GOARCH=amd64 18 | - TRAVISTEST=true 19 | 20 | before_install: 21 | - go get github.com/mattn/goveralls 22 | 23 | script: 24 | - $HOME/gopath/bin/goveralls -service=travis-ci -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Object files, Static and Dynamic libs (Shared Objects) 2 | *.o 3 | *.a 4 | *.so 5 | 6 | # Folders 7 | _obj 8 | _test 9 | 10 | # Architecture specific extensions/prefixes 11 | *.[568vq] 12 | [568vq].out 13 | 14 | *.cgo1.go 15 | *.cgo2.c 16 | _cgo_defun.c 17 | _cgo_gotypes.go 18 | _cgo_export.* 19 | 20 | _testmain.go 21 | 22 | *.exe 23 | *.test 24 | *.prof 25 | 26 | # vendor 27 | 
/vendor -------------------------------------------------------------------------------- /const.go: -------------------------------------------------------------------------------- 1 | // Package hm provides a Hindley-Milner type and type inference system. 2 | // 3 | // If you are creating a new programming language and you'd like it to be 4 | // strongly typed with parametric polymorphism (or just have Haskell-envy), 5 | // this library provides the necessary types and functions for creating such a system. 6 | // 7 | // The key to the HM type inference system is in the Unify() function. 8 | 9 | package hm 10 | 11 | const letters = `abcdefghijklmnopqrstuvwxyz` 12 | -------------------------------------------------------------------------------- /constraint.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "fmt" 4 | 5 | // A Constraint is well.. a constraint that says a must equal to b. It's used mainly in the constraint generation process. 
6 | type Constraint struct { 7 | a, b Type 8 | } 9 | 10 | func (c Constraint) Apply(sub Subs) Substitutable { 11 | c.a = c.a.Apply(sub).(Type) 12 | c.b = c.b.Apply(sub).(Type) 13 | return c 14 | } 15 | 16 | func (c Constraint) FreeTypeVar() TypeVarSet { 17 | var retVal TypeVarSet 18 | retVal = c.a.FreeTypeVar().Union(retVal) 19 | retVal = c.b.FreeTypeVar().Union(retVal) 20 | return retVal 21 | } 22 | 23 | func (c Constraint) Format(state fmt.State, r rune) { 24 | fmt.Fprintf(state, "{%v = %v}", c.a, c.b) 25 | } 26 | -------------------------------------------------------------------------------- /solver.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | type solver struct { 4 | sub Subs 5 | err error 6 | } 7 | 8 | func newSolver() *solver { 9 | return new(solver) 10 | } 11 | 12 | func (s *solver) solve(cs Constraints) { 13 | logf("solving constraints: %d", len(cs)) 14 | enterLoggingContext() 15 | defer leaveLoggingContext() 16 | logf("starting sub %v", s.sub) 17 | if s.err != nil { 18 | return 19 | } 20 | 21 | switch len(cs) { 22 | case 0: 23 | return 24 | default: 25 | var sub Subs 26 | c := cs[0] 27 | sub, s.err = Unify(c.a, c.b) 28 | defer ReturnSubs(s.sub) 29 | 30 | s.sub = compose(sub, s.sub) 31 | cs = cs[1:].Apply(s.sub).(Constraints) 32 | s.solve(cs) 33 | 34 | } 35 | logf("Ending Sub %v", s.sub) 36 | return 37 | } 38 | -------------------------------------------------------------------------------- /constraint_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "testing" 4 | 5 | func TestConstraint(t *testing.T) { 6 | c := Constraint{ 7 | a: TypeVariable('a'), 8 | b: NewFnType(TypeVariable('b'), TypeVariable('c')), 9 | } 10 | 11 | ftv := c.FreeTypeVar() 12 | if !ftv.Equals(TypeVarSet{TypeVariable('a'), TypeVariable('b'), TypeVariable('c')}) { 13 | t.Error("the free type variables of a Constraint is not as expected") 14 | } 15 | 16 | 
subs := mSubs{
17 | 		'a': NewFnType(proton, proton),
18 | 		'b': proton,
19 | 		'c': neutron,
20 | 	}
21 |
22 | 	c = c.Apply(subs).(Constraint) // Apply returns a Substitutable, so assert it back to a Constraint
23 | 	if !c.a.Eq(NewFnType(proton, proton)) {
24 | 		t.Errorf("c.a: %v", c)
25 | 	}
26 |
27 | 	if !c.b.Eq(NewFnType(proton, neutron)) {
28 | 		t.Errorf("c.b: %v", c)
29 | 	}
30 | }
31 |

--------------------------------------------------------------------------------
/Gopkg.toml:
--------------------------------------------------------------------------------
1 |
2 | # Gopkg.toml example
3 | #
4 | # Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md
5 | # for detailed Gopkg.toml documentation.
6 | #
7 | # required = ["github.com/user/thing/cmd/thing"]
8 | # ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"]
9 | #
10 | # [[constraint]]
11 | #   name = "github.com/user/project"
12 | #   version = "1.0.0"
13 | #
14 | # [[constraint]]
15 | #   name = "github.com/user/project2"
16 | #   branch = "dev"
17 | #   source = "github.com/myfork/project2"
18 | #
19 | # [[override]]
20 | #   name = "github.com/x/y"
21 | #   version = "2.4.0"
22 |
23 | ignored = ["github.com/alecthomas/assert"]
24 |
25 | [[constraint]]
26 |   name = "github.com/pkg/errors"
27 |   version = "0.8.0"
28 |
29 | [[constraint]]
30 |   name = "github.com/stretchr/testify"
31 |   version = "1.1.4"
32 |
33 | [[constraint]]
34 |   branch = "master"
35 |   name = "github.com/xtgo/set" # sorted-set operations used by typeVarSet.go
36 |

--------------------------------------------------------------------------------
/typeVariable.go:
--------------------------------------------------------------------------------
1 | package hm
2 |
3 | import (
4 | 	"fmt"
5 |
6 | 	"github.com/pkg/errors"
7 | )
8 |
9 | // TypeVariable is a variable that ranges over types - that is to say, it can stand for any type.
10 | type TypeVariable rune 11 | 12 | func (t TypeVariable) Name() string { return string(t) } 13 | func (t TypeVariable) Apply(sub Subs) Substitutable { 14 | if sub == nil { 15 | return t 16 | } 17 | 18 | if retVal, ok := sub.Get(t); ok { 19 | return retVal 20 | } 21 | 22 | return t 23 | } 24 | 25 | func (t TypeVariable) FreeTypeVar() TypeVarSet { tvs := BorrowTypeVarSet(1); tvs[0] = t; return tvs } 26 | func (t TypeVariable) Normalize(k, v TypeVarSet) (Type, error) { 27 | if i := k.Index(t); i >= 0 { 28 | return v[i], nil 29 | } 30 | return nil, errors.Errorf("Type Variable %v not in signature", t) 31 | } 32 | 33 | func (t TypeVariable) Types() Types { return nil } 34 | func (t TypeVariable) String() string { return string(t) } 35 | func (t TypeVariable) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%c", rune(t)) } 36 | func (t TypeVariable) Eq(other Type) bool { return other == t } 37 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Xuanyi Chew 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |

--------------------------------------------------------------------------------
/Gopkg.lock:
--------------------------------------------------------------------------------
1 | # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
2 |
3 |
4 | [[projects]]
5 |   name = "github.com/davecgh/go-spew"
6 |   packages = ["spew"]
7 |   revision = "346938d642f2ec3594ed81d874461961cd0faa76"
8 |   version = "v1.1.0"
9 |
10 | [[projects]]
11 |   name = "github.com/pkg/errors"
12 |   packages = ["."]
13 |   revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
14 |   version = "v0.8.0"
15 |
16 | [[projects]]
17 |   name = "github.com/pmezard/go-difflib"
18 |   packages = ["difflib"]
19 |   revision = "792786c7400a136282c1664665ae0a8db921c6c2"
20 |   version = "v1.0.0"
21 |
22 | [[projects]]
23 |   name = "github.com/stretchr/testify"
24 |   packages = ["assert"]
25 |   revision = "69483b4bd14f5845b5a1e55bca19e954e827f1d0"
26 |   version = "v1.1.4"
27 |
28 | [[projects]]
29 |   branch = "master"
30 |   name = "github.com/xtgo/set"
31 |   packages = ["."]
32 |   revision = "4431f6b51265b1e0b76af4dafc09d6f12c2bdcd0"
33 |
34 | [solve-meta]
35 |   analyzer-name = "dep"
36 |   analyzer-version = 1
37 |   inputs-digest = "e010eb1269d6f582c0c9b306d2506a5bcfd99464e64a0f21c9be12478b262b37"
38 |   solver-name = "gps-cdcl"
39 |   solver-version = 1
40 |

--------------------------------------------------------------------------------
/env.go:
--------------------------------------------------------------------------------
1 | package hm
2 |
3 | // An Env maps names to type schemes; SimpleEnv is the map-backed implementation in this package.
4 | type Env interface {
5 | 	Substitutable
6 | 	SchemeOf(string) (*Scheme, bool) // the scheme bound to a name; the bool reports whether it was found
7 | 	Clone() Env
8 |
9 | 	Add(string, *Scheme) Env
10 | Remove(string) Env 11 | } 12 | 13 | type SimpleEnv map[string]*Scheme 14 | 15 | func (e SimpleEnv) Apply(sub Subs) Substitutable { 16 | logf("Applying %v to env", sub) 17 | if sub == nil { 18 | return e 19 | } 20 | 21 | for _, v := range e { 22 | v.Apply(sub) // apply mutates Scheme, so no need to set 23 | } 24 | return e 25 | } 26 | 27 | func (e SimpleEnv) FreeTypeVar() TypeVarSet { 28 | var retVal TypeVarSet 29 | for _, v := range e { 30 | retVal = v.FreeTypeVar().Union(retVal) 31 | } 32 | return retVal 33 | } 34 | 35 | func (e SimpleEnv) SchemeOf(name string) (retVal *Scheme, ok bool) { retVal, ok = e[name]; return } 36 | func (e SimpleEnv) Clone() Env { 37 | retVal := make(SimpleEnv) 38 | for k, v := range e { 39 | retVal[k] = v.Clone() 40 | } 41 | return retVal 42 | } 43 | 44 | func (e SimpleEnv) Add(name string, s *Scheme) Env { 45 | e[name] = s 46 | return e 47 | } 48 | 49 | func (e SimpleEnv) Remove(name string) Env { 50 | delete(e, name) 51 | return e 52 | } 53 | -------------------------------------------------------------------------------- /debug.go: -------------------------------------------------------------------------------- 1 | // +build debug 2 | 3 | package hm 4 | 5 | import ( 6 | "fmt" 7 | "log" 8 | "os" 9 | "strings" 10 | "sync/atomic" 11 | ) 12 | 13 | // DEBUG returns true when it's in debug mode 14 | const DEBUG = false 15 | 16 | var tabcount uint32 17 | 18 | var _logger_ = log.New(os.Stderr, "", 0) 19 | var replacement = "\n" 20 | 21 | func tc() int { 22 | return int(atomic.LoadUint32(&tabcount)) 23 | } 24 | 25 | func enterLoggingContext() { 26 | atomic.AddUint32(&tabcount, 1) 27 | tabs := tc() 28 | _logger_.SetPrefix(strings.Repeat("\t", tabs)) 29 | replacement = "\n" + strings.Repeat("\t", tabs) 30 | } 31 | 32 | func leaveLoggingContext() { 33 | tabs := tc() 34 | tabs-- 35 | 36 | if tabs < 0 { 37 | atomic.StoreUint32(&tabcount, 0) 38 | tabs = 0 39 | } else { 40 | atomic.StoreUint32(&tabcount, uint32(tabs)) 41 | } 42 | 
_logger_.SetPrefix(strings.Repeat("\t", tabs)) 43 | replacement = "\n" + strings.Repeat("\t", tabs) 44 | } 45 | 46 | func logf(format string, others ...interface{}) { 47 | if DEBUG { 48 | // format = strings.Replace(format, "\n", replacement, -1) 49 | s := fmt.Sprintf(format, others...) 50 | s = strings.Replace(s, "\n", replacement, -1) 51 | _logger_.Println(s) 52 | // _logger_.Printf(format, others...) 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /substitutables_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | func TestConstraints(t *testing.T) { 9 | cs := Constraints{ 10 | {TypeVariable('a'), proton}, 11 | {TypeVariable('b'), proton}, 12 | } 13 | correct := TypeVarSet{'a', 'b'} 14 | 15 | ftv := cs.FreeTypeVar() 16 | for _, v := range correct { 17 | if !ftv.Contains(v) { 18 | t.Errorf("Expected free type vars to contain %v", v) 19 | break 20 | } 21 | } 22 | 23 | sub := mSubs{ 24 | 'a': neutron, 25 | } 26 | 27 | cs = cs.Apply(sub).(Constraints) 28 | if cs[0].a != neutron { 29 | t.Error("Expected neutron") 30 | } 31 | if cs[0].b != proton { 32 | t.Error("Expected proton") 33 | } 34 | 35 | if cs[1].a != TypeVariable('b') { 36 | t.Error("There was nothing to substitute b with") 37 | } 38 | if cs[1].b != proton { 39 | t.Error("Expected proton") 40 | } 41 | 42 | if fmt.Sprintf("%v", cs) != "Constraints[{neutron = proton}, {b = proton}]" { 43 | t.Errorf("Error in formatting cs") 44 | } 45 | 46 | } 47 | 48 | func TestTypes_Contains(t *testing.T) { 49 | ts := Types{TypeVariable('a'), proton} 50 | 51 | if !ts.Contains(TypeVariable('a')) { 52 | t.Error("Expected ts to contain 'a'") 53 | } 54 | 55 | if !ts.Contains(proton) { 56 | t.Error("Expected ts to contain proton") 57 | } 58 | 59 | if ts.Contains(neutron) { 60 | t.Error("ts shouldn't contain neutron") 61 | } 62 | } 63 | 
-------------------------------------------------------------------------------- /expression.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | // A Namer is anything that knows its own name 4 | type Namer interface { 5 | Name() string 6 | } 7 | 8 | // A Typer is an Expression node that knows its own Type 9 | type Typer interface { 10 | Type() Type 11 | } 12 | 13 | // An Inferer is an Expression that can infer its own Type given an Env 14 | type Inferer interface { 15 | Infer(Env, Fresher) (Type, error) 16 | } 17 | 18 | // An Expression is basically an AST node. In its simplest form, it's lambda calculus 19 | type Expression interface { 20 | Body() Expression 21 | } 22 | 23 | // Var is an expression representing a variable 24 | type Var interface { 25 | Expression 26 | Namer 27 | Typer 28 | } 29 | 30 | // Literal is an Expression/AST Node representing a literal 31 | type Literal interface { 32 | Var 33 | IsLit() bool 34 | } 35 | 36 | // Apply is an Expression/AST node that represents a function application 37 | type Apply interface { 38 | Expression 39 | Fn() Expression 40 | } 41 | 42 | // LetRec is an Expression/AST node that represents a recursive let 43 | type LetRec interface { 44 | Let 45 | IsRecursive() bool 46 | } 47 | 48 | // Let is an Expression/AST node that represents the standard let polymorphism found in functional languages 49 | type Let interface { 50 | // let name = def in body 51 | Expression 52 | Namer 53 | Def() Expression 54 | } 55 | 56 | // Lambda is an Expression/AST node that represents a function definiton 57 | type Lambda interface { 58 | Expression 59 | Namer 60 | IsLambda() bool 61 | } 62 | -------------------------------------------------------------------------------- /substitutables.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "fmt" 4 | 5 | // Constraints is a slice of Constraint. 
Like a Constraint, it is also a Substitutable 6 | type Constraints []Constraint 7 | 8 | func (cs Constraints) Apply(sub Subs) Substitutable { 9 | // an optimization 10 | if sub == nil { 11 | return cs 12 | } 13 | 14 | if len(cs) == 0 { 15 | return cs 16 | } 17 | 18 | logf("Constraints: %d", len(cs)) 19 | logf("Applying %v to %v", sub, cs) 20 | for i, c := range cs { 21 | cs[i] = c.Apply(sub).(Constraint) 22 | } 23 | logf("Constraints %v", cs) 24 | return cs 25 | } 26 | 27 | func (cs Constraints) FreeTypeVar() TypeVarSet { 28 | var retVal TypeVarSet 29 | for _, v := range cs { 30 | retVal = v.FreeTypeVar().Union(retVal) 31 | } 32 | return retVal 33 | } 34 | 35 | func (cs Constraints) Format(state fmt.State, c rune) { 36 | state.Write([]byte("Constraints[")) 37 | for i, c := range cs { 38 | if i < len(cs)-1 { 39 | fmt.Fprintf(state, "%v, ", c) 40 | } else { 41 | fmt.Fprintf(state, "%v", c) 42 | } 43 | } 44 | state.Write([]byte{']'}) 45 | } 46 | 47 | // Types is a slice of Type. Future additions to the methods of this type may be possible 48 | type Types []Type 49 | 50 | func (ts Types) Contains(t Type) bool { 51 | for _, T := range ts { 52 | if t.Eq(T) { 53 | return true 54 | } 55 | } 56 | return false 57 | } 58 | 59 | // func (ts Types) Apply(sub Subs) Substitutable { 60 | // for i, t := range ts { 61 | // ts[i] = t.Apply(sub).(Type) 62 | // } 63 | // return ts 64 | // } 65 | 66 | // func (ts Types) FreeTypeVar() TypeVarSet { 67 | // var retVal TypeVarSet 68 | // for _, v := range ts { 69 | // retVal = v.FreeTypeVar().Union(retVal) 70 | // } 71 | // return retVal 72 | // } 73 | -------------------------------------------------------------------------------- /solver_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "testing" 4 | 5 | var solverTest = []struct { 6 | cs Constraints 7 | 8 | expected Subs 9 | err bool 10 | }{ 11 | {Constraints{{TypeVariable('a'), proton}}, mSubs{'a': proton}, false}, 12 
{Constraints{{NewFnType(TypeVariable('a'), proton), neutron}}, nil, true}, // expected to fail: a function type cannot unify with neutron
13 | 	{Constraints{{NewFnType(TypeVariable('a'), proton), NewFnType(proton, proton)}}, mSubs{'a': proton}, false},
14 |
15 | 	{Constraints{
16 | 		{
17 | 			NewFnType(TypeVariable('a'), TypeVariable('a'), list{TypeVariable('a')}),
18 | 			NewFnType(proton, proton, TypeVariable('b')),
19 | 		},
20 | 	},
21 | 		mSubs{'a': proton, 'b': list{proton}}, false,
22 | 	},
23 |
24 | 	{
25 | 		Constraints{
26 | 			{TypeVariable('a'), TypeVariable('b')},
27 | 			{TypeVariable('a'), proton},
28 | 		},
29 | 		mSubs{'a': proton}, false, // only the binding for 'a' is checked
30 | 	},
31 |
32 | 	{
33 | 		Constraints{
34 | 			{
35 | 				NewRecordType("", TypeVariable('a'), TypeVariable('a'), TypeVariable('b')),
36 | 				NewRecordType("", neutron, neutron, proton),
37 | 			},
38 | 		},
39 | 		mSubs{'a': neutron, 'b': proton}, false,
40 | 	},
41 | }
42 |
43 | func TestSolver(t *testing.T) {
44 | 	for i, sts := range solverTest {
45 | 		solver := newSolver()
46 | 		solver.solve(sts.cs)
47 |
48 | 		if sts.err {
49 | 			if solver.err == nil {
50 | 				t.Errorf("Test %d Expected an error", i)
51 | 			}
52 | 			continue
53 | 		} else if solver.err != nil {
54 | 			t.Error(solver.err)
55 | 		}
56 |
57 | 		for _, v := range sts.expected.Iter() { // every expected binding must appear, with the same type, in solver.sub
58 | 			if T, ok := solver.sub.Get(v.Tv); !ok {
59 | 				t.Errorf("Test %d: Expected type variable %v in subs: %v", i, v.Tv, solver.sub)
60 | 				break
61 | 			} else if T != v.T {
62 | 				t.Errorf("Test %d: Expected replacement to be %v.
Got %v instead", i, v.T, T)
63 | 			}
64 | 		}
65 | 	}
66 | }
67 |

--------------------------------------------------------------------------------
/scheme_test.go:
--------------------------------------------------------------------------------
1 | package hm
2 |
3 | import (
4 | 	"fmt"
5 | 	"testing"
6 | )
7 |
8 | func TestSchemeBasics(t *testing.T) {
9 | 	s := new(Scheme)
10 | 	s.tvs = TypeVarSet{'a', 'b'}
11 | 	s.t = NewFnType(TypeVariable('c'), proton)
12 |
13 | 	sub := mSubs{
14 | 		'a': proton,
15 | 		'b': neutron,
16 | 		'c': electron,
17 | 	}
18 |
19 | 	s2 := s.Apply(nil).(*Scheme) // Apply returns its receiver
20 | 	if s2 != s {
21 | 		t.Errorf("Different pointers")
22 | 	}
23 |
24 | 	s2 = s.Apply(sub).(*Scheme) // 'a' and 'b' are quantified, so only 'c' is substituted
25 | 	if s2 != s {
26 | 		t.Errorf("Different pointers")
27 | 	}
28 |
29 | 	if !s.tvs.Equals(TypeVarSet{'a', 'b'}) {
30 | 		t.Error("TypeVarSet mutated")
31 | 	}
32 |
33 | 	if !s.t.Eq(NewFnType(electron, proton)) {
34 | 		t.Error("Application failed")
35 | 	}
36 |
37 | 	s = new(Scheme)
38 | 	s.tvs = TypeVarSet{'a', 'b'}
39 | 	s.t = NewFnType(TypeVariable('c'), proton)
40 |
41 | 	ftv := s.FreeTypeVar()
42 |
43 | 	if !ftv.Equals(TypeVarSet{'c'}) {
44 | 		t.Errorf("Expected ftv: {'c'}. Got %v instead", ftv)
45 | 	}
46 |
47 | 	// format
48 | 	if fmt.Sprintf("%v", s) != "∀[a, b]: c → proton" {
49 | 		t.Errorf("Scheme format is wrong.: Got %q", fmt.Sprintf("%v", s))
50 | 	}
51 |
52 | 	// Polytype scheme.Type
53 | 	T, isMono := s.Type()
54 | 	if isMono {
55 | 		t.Errorf("%v is supposed to be a polytype. It shouldn't return true", s)
56 | 	}
57 | 	if !T.Eq(NewFnType(TypeVariable('c'), proton)) {
58 | 		t.Error("Wrong type returned by scheme")
59 | 	}
60 | }
61 |
62 | func TestSchemeNormalize(t *testing.T) {
63 | 	s := new(Scheme)
64 | 	s.tvs = TypeVarSet{'c', 'z', 'd'} // replaced by Normalize, which renames the free vars of s.t
65 | 	s.t = NewFnType(TypeVariable('a'), TypeVariable('c'))
66 |
67 | 	err := s.Normalize()
68 | 	if err != nil {
69 | 		t.Error(err)
70 | 	}
71 |
72 | 	if !s.tvs.Equals(TypeVarSet{'a', 'b'}) {
73 | 		t.Errorf("Expected: TypeVarSet{'a','b'}.
Got: %v", s.tvs)
74 | 	}
75 | }
76 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # hm [![GoDoc](https://godoc.org/github.com/chewxy/hm?status.svg)](https://godoc.org/github.com/chewxy/hm) [![Build Status](https://travis-ci.org/chewxy/hm.svg?branch=master)](https://travis-ci.org/chewxy/hm) [![Coverage Status](https://coveralls.io/repos/github/chewxy/hm/badge.png)](https://coveralls.io/github/chewxy/hm)
2 |
3 | Package hm is a simple Hindley-Milner type inference system in Go. It provides the necessary data structures and functions for creating such a system.
4 |
5 | # Installation #
6 |
7 | This package is go-gettable: `go get -u github.com/chewxy/hm`
8 |
9 | There are very few dependencies that this package uses. Therefore there isn't a need for vendoring tools. However, package hm DOES provide a `Gopkg.toml` and `Gopkg.lock` for any potential users of the [dep](https://github.com/golang/dep) tool.
10 |
11 | Here is a listing of the dependencies of `hm`:
12 |
13 | |Package|Used For|Vitality|Notes|Licence|
14 | |-------|--------|--------|-----|-------|
15 | |[errors](https://github.com/pkg/errors)|Error wrapping|Can do without, but this is by far the superior error solution out there|Stable API for the past 6 months|[errors licence](https://github.com/pkg/errors/blob/master/LICENSE) (MIT/BSD-like)|
16 | |[testify/assert](https://github.com/stretchr/testify)|Testing|Can do without but will be a massive pain in the ass to test||[testify licence](https://github.com/stretchr/testify/blob/master/LICENSE) (MIT/BSD-like)|
17 | |[set](https://github.com/xtgo/set)|Sorted-set operations on `TypeVarSet`|Required; `TypeVarSet` is built on it|Tracks the `master` branch (see `Gopkg.toml`)|[set licence](https://github.com/xtgo/set/blob/master/LICENSE)|
18 |
19 | # Usage
20 |
21 | TODO: Write this
22 |
23 | # Notes
24 |
25 | This package is used by [Gorgonia](https://github.com/chewxy/gorgonia) as part of the graph building process. It is also used by several other internal projects of this author, all sharing a similar theme of requiring a type system, which is why this was abstracted out.
26 |
27 |
28 | # Contributing
29 |
30 | This library is developed using GitHub, so the workflow is very GitHub-centric.
31 |
32 | # Licence
33 |
34 | Package `hm` is licenced under the MIT licence.
35 |

--------------------------------------------------------------------------------
/typeVarSet.go:
--------------------------------------------------------------------------------
1 | package hm
2 |
3 | import (
4 | 	"sort"
5 |
6 | 	"github.com/xtgo/set"
7 | )
8 |
9 | // TypeVarSet is a set of TypeVariable. The set operations (Set, Union, Intersect, Difference) sort their operands in place.
10 | type TypeVarSet []TypeVariable
11 |
12 | // TypeVariables are orderable, so TypeVarSet implements sort.Interface.
13 |
14 | func (s TypeVarSet) Len() int           { return len(s) }
15 | func (s TypeVarSet) Less(i, j int) bool { return s[i] < s[j] }
16 | func (s TypeVarSet) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
17 | // Set sorts s in place and returns it with duplicates removed.
18 | func (s TypeVarSet) Set() TypeVarSet {
19 | 	sort.Sort(s)
20 | 	n := set.Uniq(s)
21 | 	s = s[:n]
22 | 	return s
23 | }
24 | // Union returns the union of s and other, sorting both in place. If other is nil, s is returned as is.
25 | func (s TypeVarSet) Union(other TypeVarSet) TypeVarSet {
26 | 	if other == nil {
27 | 		return s
28 | 	}
29 |
30 | 	sort.Sort(s)
31 | 	sort.Sort(other)
32 | 	s2 := append(s[:len(s):len(s)], other...) // cap at len(s) so append never writes into memory beyond s
33 | 	n := set.Union(s2, len(s))
34 | 	return s2[:n]
35 | }
36 | // Intersect returns the intersection of s and other, sorting both in place. It returns nil if either is empty.
37 | func (s TypeVarSet) Intersect(other TypeVarSet) TypeVarSet {
38 | 	if len(s) == 0 || len(other) == 0 {
39 | 		return nil
40 | 	}
41 |
42 | 	sort.Sort(s)
43 | 	sort.Sort(other)
44 | 	s2 := append(s[:len(s):len(s)], other...) // cap at len(s) so append never writes into memory beyond s
45 | 	n := set.Inter(s2, len(s))
46 | 	return s2[:n]
47 | }
48 | // Difference returns the elements of s that are not in other, sorting both in place.
49 | func (s TypeVarSet) Difference(other TypeVarSet) TypeVarSet {
50 | 	sort.Sort(s)
51 | 	sort.Sort(other)
52 | 	s2 := append(s[:len(s):len(s)], other...) // cap at len(s) so append never writes into memory beyond s
53 | 	n := set.Diff(s2, len(s))
54 | 	return s2[:n]
55 | }
56 |
57 | // Contains reports whether tv is in s.
58 | func (s TypeVarSet) Contains(tv TypeVariable) bool {
59 | 	for _, v := range s {
60 | 		if v == tv {
61 | 			return true
62 | 		}
63 | 	}
64 | 	return false
65 | }
66 |
67 | // Index returns the position of tv in s, or -1 if tv is not present.
68 | func (s TypeVarSet) Index(tv TypeVariable) int {
69 | 	for i, v := range s {
70 | 		if v == tv {
71 | 			return i
72 | 		}
73 | 	}
74 | 	return -1
75 | }
76 |
77 | // Equals reports whether s and other hold the same type variables, in any order.
78 | // Membership is checked in both directions so the result is symmetric even when
79 | // one side contains duplicates.
80 | func (s TypeVarSet) Equals(other TypeVarSet) bool {
81 | 	if len(s) != len(other) {
82 | 		return false
83 | 	}
84 |
85 | 	if len(s) == 0 {
86 | 		return true
87 | 	}
88 |
89 | 	// Same backing array and same length: this is the very same set.
90 | 	if &s[0] == &other[0] {
91 | 		return true
92 | 	}
93 |
94 | 	for _, v := range s {
95 | 		if !other.Contains(v) {
96 | 			return false
97 | 		}
98 | 	}
99 | 	for _, v := range other {
100 | 		if !s.Contains(v) {
101 | 			return false
102 | 		}
103 | 	}
104 | 	return true
105 | }
106 |

--------------------------------------------------------------------------------
/typeVariable_test.go:
--------------------------------------------------------------------------------
1 | package hm
2 |
3 | import (
4 | 	"fmt"
5 | 	"testing"
6 | )
7 |
8 | func TestTypeVariableBasics(t *testing.T) {
9 | 	tv := TypeVariable('a')
10 | 	if name := tv.Name(); name != "a" {
11 | 		t.Errorf("Expected name to be \"a\". Got %q instead", name)
12 | 	}
13 |
14 | 	if str := tv.String(); str != "a" {
15 | 		t.Errorf("Expected String() of 'a'.
Got %q instead", str)
16 | 	}
17 |
18 | 	if tv.Types() != nil {
19 | 		t.Errorf("Expected Types() of TypeVariable to be nil")
20 | 	}
21 |
22 | 	ftv := tv.FreeTypeVar()
23 | 	if len(ftv) != 1 {
24 | 		t.Errorf("Expected a type variable to be free when FreeTypeVar() is called")
25 | 	}
26 |
27 | 	if ftv[0] != tv {
28 | 		t.Errorf("Expected ...")
29 | 	}
30 |
31 | 	sub := mSubs{
32 | 		'a': proton,
33 | 	}
34 |
35 | 	if tv.Apply(sub) != proton {
36 | 		t.Error("Expected proton")
37 | 	}
38 |
39 | 	sub = mSubs{
40 | 		'b': proton,
41 | 	}
42 |
43 | 	if tv.Apply(sub) != tv { // no binding for 'a', so Apply returns tv itself
44 | 		t.Error("Expected unchanged")
45 | 	}
46 | }
47 |
48 | func TestTypeVariableNormalize(t *testing.T) {
49 | 	original := TypeVarSet{'c', 'a', 'd'}
50 | 	normalized := TypeVarSet{'a', 'b', 'c'}
51 |
52 | 	tv := TypeVariable('a') // 'a' is original[1], so it should normalize to normalized[1] == 'b'
53 | 	norm, err := tv.Normalize(original, normalized)
54 | 	if err != nil {
55 | 		t.Error(err)
56 | 	}
57 |
58 | 	if norm != TypeVariable('b') {
59 | 		t.Errorf("Expected 'b'. Got %v", norm)
60 | 	}
61 |
62 | 	tv = TypeVariable('e') // not in original, so Normalize must return an error
63 | 	if _, err = tv.Normalize(original, normalized); err == nil {
64 | 		t.Error("Expected an error")
65 | 	}
66 | }
67 |
68 | func TestTypeConst(t *testing.T) {
69 | 	T := proton
70 | 	if T.Name() != "proton" {
71 | 		t.Error("Expected name to be proton")
72 | 	}
73 |
74 | 	if fmt.Sprintf("%v", T) != "proton" {
75 | 		t.Error("Expected name to be proton")
76 | 	}
77 |
78 | 	if T.String() != "proton" {
79 | 		t.Error("Expected name to be proton")
80 | 	}
81 |
82 | 	if T2, err := T.Normalize(nil, nil); err != nil {
83 | 		t.Error(err)
84 | 	} else if T2 != T {
85 | 		t.Error("Const types should return itself")
86 | 	}
87 | }
88 |

--------------------------------------------------------------------------------
/scheme.go:
--------------------------------------------------------------------------------
1 | package hm
2 |
3 | import "fmt"
4 |
5 | // Scheme represents a polytype.
6 | // It basically says this:
7 | // ∀TypeVariables.Type.
8 | // That is, each quantified TypeVariable occurring in Type may be instantiated with any Type.
9 | type Scheme struct {
10 | 	tvs TypeVarSet
11 | 	t   Type
12 | }
13 |
14 | func NewScheme(tvs TypeVarSet, t Type) *Scheme { // NewScheme returns ∀tvs.t; tvs is stored as given, not copied
15 | 	return &Scheme{
16 | 		tvs: tvs,
17 | 		t:   t,
18 | 	}
19 | }
20 |
21 | func (s *Scheme) Apply(sub Subs) Substitutable {
22 | 	logf("s: %v, sub: %v", s, sub)
23 | 	if sub == nil {
24 | 		return s
25 | 	}
26 | 	sub = sub.Clone() // clone so removing the quantified variables below doesn't affect the caller's sub
27 | 	defer ReturnSubs(sub) // NOTE(review): releases the clone; if Remove returns a different Subs, that one is never released — confirm
28 |
29 | 	for _, tv := range s.tvs {
30 | 		sub = sub.Remove(tv) // quantified variables are bound by the scheme, so they are not substituted
31 | 	}
32 |
33 | 	s.t = s.t.Apply(sub).(Type) // mutates s in place; the same pointer is returned
34 | 	return s
35 | }
36 |
37 | func (s *Scheme) FreeTypeVar() TypeVarSet {
38 | 	ftvs := s.t.FreeTypeVar()
39 | 	tvs := s.tvs.Set() // NOTE(review): Set sorts s.tvs in place, so this query reorders the scheme's variables
40 | 	return ftvs.Difference(tvs)
41 | }
42 |
43 | func (s *Scheme) Clone() *Scheme {
44 | 	tvs := make(TypeVarSet, len(s.tvs))
45 | 	for i, v := range s.tvs {
46 | 		tvs[i] = v
47 | 	}
48 | 	return &Scheme{
49 | 		tvs: tvs,
50 | 		t:   s.t, // the Type value is shared, not deep-copied
51 | 	}
52 | }
53 |
54 | func (s *Scheme) Format(state fmt.State, c rune) { // writes ∀[tvs]: t; the verb is ignored
55 | 	state.Write([]byte("∀["))
56 | 	for i, tv := range s.tvs {
57 | 		if i < len(s.tvs)-1 {
58 | 			fmt.Fprintf(state, "%v, ", tv)
59 | 		} else {
60 | 			fmt.Fprintf(state, "%v", tv)
61 | 		}
62 | 	}
63 | 	fmt.Fprintf(state, "]: %v", s.t)
64 | }
65 |
66 | // Type returns the type of the scheme, as well as a boolean indicating if *Scheme represents a monotype.
If it's a polytype, it'll return false
67 | func (s *Scheme) Type() (t Type, isMonoType bool) {
68 | 	return s.t, len(s.tvs) == 0
69 | }
70 |
71 | // Normalize renames the free type variables of the scheme's type to a, b, c, ...
72 | // (the i-th variable reported by FreeTypeVar gets the i-th letter) and replaces
73 | // the quantified variables with exactly those names.
74 | //
75 | // It returns an error, leaving the scheme unchanged, if the type has more free
76 | // variables than there are letters or if renaming the type fails.
77 | func (s *Scheme) Normalize() (err error) {
78 | 	tfv := s.t.FreeTypeVar()
79 |
80 | 	if len(tfv) == 0 {
81 | 		return nil
82 | 	}
83 |
84 | 	defer ReturnTypeVarSet(tfv)
85 | 	if len(tfv) > len(letters) {
86 | 		// letters[i] below would index out of range and panic.
87 | 		return fmt.Errorf("cannot normalize %v: %d free type variables but only %d names", s, len(tfv), len(letters))
88 | 	}
89 |
90 | 	ord := BorrowTypeVarSet(len(tfv))
91 | 	for i := range tfv {
92 | 		ord[i] = TypeVariable(letters[i])
93 | 	}
94 |
95 | 	// Normalize into a temporary so that a failure leaves s.t and s.tvs intact.
96 | 	normalized, err := s.t.Normalize(tfv, ord)
97 | 	if err != nil {
98 | 		ReturnTypeVarSet(ord)
99 | 		return err
100 | 	}
101 | 	s.t = normalized
102 | 	s.tvs = ord.Set()
103 | 	return nil
104 | }
105 |

--------------------------------------------------------------------------------
/env_test.go:
--------------------------------------------------------------------------------
1 | package hm
2 |
3 | import (
4 | 	"testing"
5 |
6 | 	"github.com/stretchr/testify/assert"
7 | )
8 |
9 | func TestSimpleEnv(t *testing.T) {
10 | 	assert := assert.New(t)
11 | 	var orig, env Env
12 | 	var expected SimpleEnv
13 |
14 | 	// Add
15 | 	orig = make(SimpleEnv)
16 | 	orig = orig.Add("foo", NewScheme(
17 | 		TypeVarSet{'a', 'b', 'c'},
18 | 		TypeVariable('a'),
19 | 	))
20 | 	orig = orig.Add("bar", NewScheme(
21 | 		TypeVarSet{'b', 'c', 'd'},
22 | 		TypeVariable('a'),
23 | 	))
24 | 	orig = orig.Add("baz", NewScheme(
25 | 		TypeVarSet{'a', 'b', 'c'},
26 | 		neutron,
27 | 	))
28 | 	qs := NewScheme(
29 | 		TypeVarSet{'a', 'b'},
30 | 		proton,
31 | 	)
32 | 	orig = orig.Add("qux", qs)
33 |
34 | 	expected = SimpleEnv{
35 | 		"foo": NewScheme(
36 | 			TypeVarSet{'a', 'b', 'c'},
37 | 			TypeVariable('a'),
38 | 		),
39 | 		"bar": NewScheme(
40 | 			TypeVarSet{'b', 'c', 'd'},
41 | 			TypeVariable('a'),
42 | 		),
43 | 		"baz": NewScheme(
44 | 			TypeVarSet{'a', 'b', 'c'},
45 | 			neutron,
46 | 		),
47 | 		"qux": NewScheme(
48 | 			TypeVarSet{'a', 'b'},
49 | 			proton,
50 | 		),
51 | 	}
52 | 	assert.Equal(expected, orig)
53 |
54 | 	// Get
55 | 	s, ok := orig.SchemeOf("qux")
56 | 	if s != qs || !ok {
57 |
t.Error("Expected to get scheme of \"qux\"") 58 | } 59 | 60 | // Remove 61 | orig = orig.Remove("qux") 62 | delete(expected, "qux") 63 | assert.Equal(expected, orig) 64 | 65 | // Clone 66 | env = orig.Clone() 67 | assert.Equal(orig, env) 68 | 69 | subs := mSubs{ 70 | 'a': proton, 71 | 'b': neutron, 72 | 'd': electron, 73 | 'e': proton, 74 | } 75 | 76 | env = env.Apply(subs).(Env) 77 | expected = SimpleEnv{ 78 | "foo": &Scheme{ 79 | tvs: TypeVarSet{'a', 'b', 'c'}, 80 | t: TypeVariable('a'), 81 | }, 82 | "bar": &Scheme{ 83 | tvs: TypeVarSet{'b', 'c', 'd'}, 84 | t: proton, 85 | }, 86 | "baz": &Scheme{ 87 | tvs: TypeVarSet{'a', 'b', 'c'}, 88 | t: neutron, 89 | }, 90 | } 91 | assert.Equal(expected, env) 92 | 93 | env = orig.Clone() 94 | ftv := env.FreeTypeVar() 95 | correctFTV := TypeVarSet{'a'} 96 | 97 | if !correctFTV.Equals(ftv) { 98 | t.Errorf("Expected freetypevars to be equal. Got %v instead", ftv) 99 | } 100 | } 101 | -------------------------------------------------------------------------------- /typeVarSet_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "testing" 4 | 5 | var tvSetTests = []struct { 6 | op string 7 | tvs0 TypeVarSet 8 | tvs1 TypeVarSet 9 | 10 | expected TypeVarSet 11 | ind int 12 | eq bool 13 | }{ 14 | {"set", TypeVarSet{'a', 'a', 'a'}, nil, TypeVarSet{'a'}, 0, false}, 15 | {"set", TypeVarSet{'c', 'b', 'a'}, nil, TypeVarSet{'a', 'b', 'c'}, 0, false}, 16 | {"intersect", TypeVarSet{'a', 'b', 'c'}, TypeVarSet{'d', 'e', 'f'}, TypeVarSet{}, -1, false}, 17 | {"intersect", TypeVarSet{'a', 'b', 'c'}, TypeVarSet{'b', 'c', 'd'}, TypeVarSet{'b', 'c'}, -1, false}, 18 | {"intersect", TypeVarSet{'a', 'b', 'c'}, nil, nil, -1, false}, 19 | {"intersect", TypeVarSet{'a', 'b', 'c'}, TypeVarSet{'c', 'b', 'a'}, TypeVarSet{'a', 'b', 'c'}, 0, true}, 20 | {"union", TypeVarSet{'a', 'b'}, TypeVarSet{'c', 'd'}, TypeVarSet{'a', 'b', 'c', 'd'}, 0, false}, 21 | {"union", TypeVarSet{'a', 'c', 'b'}, 
TypeVarSet{'c', 'd'}, TypeVarSet{'a', 'b', 'c', 'd'}, 0, false}, 22 | {"union", TypeVarSet{'a', 'b'}, nil, TypeVarSet{'a', 'b'}, 0, false}, 23 | {"diff", TypeVarSet{'a', 'b', 'c'}, TypeVarSet{'d', 'e', 'c'}, TypeVarSet{'a', 'b'}, 0, false}, 24 | {"diff", TypeVarSet{'a', 'b', 'c'}, TypeVarSet{'c', 'd', 'e'}, TypeVarSet{'a', 'b'}, 0, false}, 25 | {"diff", TypeVarSet{'a', 'b', 'c'}, TypeVarSet{'d', 'e', 'f'}, TypeVarSet{'a', 'b', 'c'}, 0, false}, 26 | } 27 | 28 | func TestTypeVarSet(t *testing.T) { 29 | for i, tst := range tvSetTests { 30 | var s TypeVarSet 31 | switch tst.op { 32 | case "set": 33 | s = tst.tvs0.Set() 34 | if !s.Equals(tst.expected) { 35 | t.Errorf("%s op (%d): expected: %v, got %v", tst.op, i, tst.expected, s) 36 | } 37 | case "intersect": 38 | s = tst.tvs0.Intersect(tst.tvs1) 39 | if !s.Equals(tst.expected) { 40 | t.Errorf("%s op (%d): expected: %v, got %v", tst.op, i, tst.expected, s) 41 | } 42 | case "union": 43 | s = tst.tvs0.Union(tst.tvs1) 44 | if !s.Equals(tst.expected) { 45 | t.Errorf("%s op (%d): expected: %v, got %v", tst.op, i, tst.expected, s) 46 | } 47 | case "diff": 48 | s = tst.tvs0.Difference(tst.tvs1) 49 | if !s.Equals(tst.expected) { 50 | t.Errorf("%s op (%d): expected: %v, got %v", tst.op, i, tst.expected, s) 51 | } 52 | } 53 | 54 | if ind := s.Index('a'); ind != tst.ind { 55 | t.Errorf("%s op %d index : expected %d got %v", tst.op, i, tst.ind, ind) 56 | } 57 | 58 | if eq := tst.tvs0.Equals(tst.tvs1); eq != tst.eq { 59 | t.Errorf("%s op %d eq: expected %t got %v", tst.op, i, tst.eq, eq) 60 | } 61 | } 62 | 63 | tvs := TypeVarSet{'a'} 64 | if !tvs.Equals(tvs) { 65 | t.Error("A set should be equal to itself") 66 | } 67 | 68 | } 69 | -------------------------------------------------------------------------------- /test_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/pkg/errors" 7 | ) 8 | 9 | const ( 10 | proton TypeConst = "proton" 
11 | neutron TypeConst = "neutron" 12 | quark TypeConst = "quark" 13 | 14 | electron TypeConst = "electron" 15 | positron TypeConst = "positron" 16 | muon TypeConst = "muon" 17 | 18 | photon TypeConst = "photon" 19 | higgs TypeConst = "higgs" 20 | ) 21 | 22 | type list struct { 23 | t Type 24 | } 25 | 26 | func (l list) Name() string { return "List" } 27 | func (l list) Apply(subs Subs) Substitutable { l.t = l.t.Apply(subs).(Type); return l } 28 | func (l list) FreeTypeVar() TypeVarSet { return l.t.FreeTypeVar() } 29 | func (l list) Format(s fmt.State, c rune) { fmt.Fprintf(s, "List %v", l.t) } 30 | func (l list) String() string { return fmt.Sprintf("%v", l) } 31 | func (l list) Normalize(k, v TypeVarSet) (Type, error) { 32 | var t Type 33 | var err error 34 | if t, err = l.t.Normalize(k, v); err != nil { 35 | return nil, err 36 | } 37 | l.t = t 38 | return l, nil 39 | } 40 | func (l list) Types() Types { return Types{l.t} } 41 | func (l list) Eq(other Type) bool { 42 | if ot, ok := other.(list); ok { 43 | return ot.t.Eq(l.t) 44 | } 45 | return false 46 | } 47 | 48 | type mirrorUniverseList struct { 49 | t Type 50 | } 51 | 52 | func (l mirrorUniverseList) Name() string { return "GoateeList" } 53 | func (l mirrorUniverseList) Apply(subs Subs) Substitutable { l.t = l.t.Apply(subs).(Type); return l } 54 | func (l mirrorUniverseList) FreeTypeVar() TypeVarSet { return l.t.FreeTypeVar() } 55 | func (l mirrorUniverseList) Format(s fmt.State, c rune) { fmt.Fprintf(s, "List %v", l.t) } 56 | func (l mirrorUniverseList) String() string { return fmt.Sprintf("%v", l) } 57 | func (l mirrorUniverseList) Normalize(k, v TypeVarSet) (Type, error) { 58 | var t Type 59 | var err error 60 | if t, err = l.t.Normalize(k, v); err != nil { 61 | return nil, err 62 | } 63 | l.t = t 64 | return l, nil 65 | } 66 | func (l mirrorUniverseList) Types() Types { return Types{l.t} } 67 | func (l mirrorUniverseList) Eq(other Type) bool { 68 | if ot, ok := other.(list); ok { 69 | return ot.t.Eq(l.t) 
70 | } 71 | return false 72 | } 73 | 74 | // satisfies the Inferer interface for testing 75 | type selfInferer bool 76 | 77 | func (t selfInferer) Infer(Env, Fresher) (Type, error) { 78 | if bool(t) { 79 | return proton, nil 80 | } 81 | return nil, errors.Errorf("fail") 82 | } 83 | func (t selfInferer) Body() Expression { panic("not implemented") } 84 | 85 | // satisfies the Var interface for testing. It also doesn't know its own type 86 | type variable string 87 | 88 | func (t variable) Body() Expression { return nil } 89 | func (t variable) Name() string { return string(t) } 90 | func (t variable) Type() Type { return nil } 91 | -------------------------------------------------------------------------------- /perf_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "testing" 4 | 5 | func TestSubsPool(t *testing.T) { 6 | var def TypeVariable 7 | for i := 0; i < poolSize; i++ { 8 | s := BorrowSSubs(i + 1) 9 | if cap(s.s) != (i+1)+extraCap { 10 | t.Errorf("Expected s to have cap of %d", i+1+extraCap) 11 | goto mSubTest 12 | } 13 | if len(s.s) != (i + 1) { 14 | t.Errorf("Expected s to have a len of %d", i+1) 15 | goto mSubTest 16 | } 17 | 18 | s.s[0] = Substitution{TypeVariable('a'), electron} 19 | ReturnSubs(s) 20 | s = BorrowSSubs(i + 1) 21 | 22 | for _, subst := range s.s { 23 | if subst.T != nil { 24 | t.Errorf("sSubsPool %d error: not clean: %v", i, subst) 25 | break 26 | } 27 | 28 | if subst.Tv != def { 29 | t.Errorf("sSubsPool %d error: not clean: %v", i, subst) 30 | break 31 | } 32 | } 33 | 34 | mSubTest: 35 | m := BorrowMSubs() 36 | if len(m) != 0 { 37 | t.Errorf("Expected borrowed mSubs to have 0 length") 38 | } 39 | 40 | m['a'] = electron 41 | ReturnSubs(m) 42 | 43 | m = BorrowMSubs() 44 | if len(m) != 0 { 45 | t.Errorf("Expected borrowed mSubs to have 0 length") 46 | } 47 | 48 | } 49 | 50 | // oob tests 51 | s := BorrowSSubs(10) 52 | if cap(s.s) != 10 { 53 | t.Error("Expected a cap of 
10") 54 | } 55 | ReturnSubs(s) 56 | } 57 | 58 | func TestTypesPool(t *testing.T) { 59 | for i := 0; i < poolSize; i++ { 60 | ts := BorrowTypes(i + 1) 61 | if cap(ts) != i+1 { 62 | t.Errorf("Expected ts to have a cap of %v", i+1) 63 | } 64 | 65 | ts[0] = proton 66 | ReturnTypes(ts) 67 | ts = BorrowTypes(i + 1) 68 | for _, v := range ts { 69 | if v != nil { 70 | t.Errorf("Expected reshly borrowed Types to be nil") 71 | } 72 | } 73 | } 74 | 75 | // oob 76 | ts := BorrowTypes(10) 77 | if cap(ts) != 10 { 78 | t.Errorf("Expected a cap to 10") 79 | } 80 | 81 | } 82 | 83 | func TestTypeVarSetPool(t *testing.T) { 84 | var def TypeVariable 85 | for i := 0; i < poolSize; i++ { 86 | ts := BorrowTypeVarSet(i + 1) 87 | if cap(ts) != i+1 { 88 | t.Errorf("Expected ts to have a cap of %v", i+1) 89 | } 90 | 91 | ts[0] = 'z' 92 | ReturnTypeVarSet(ts) 93 | ts = BorrowTypeVarSet(i + 1) 94 | for _, v := range ts { 95 | if v != def { 96 | t.Errorf("Expected reshly borrowed Types to be def") 97 | } 98 | } 99 | } 100 | 101 | // oob 102 | tvs := BorrowTypeVarSet(10) 103 | if cap(tvs) != 10 { 104 | t.Error("Expected a cap of 10") 105 | } 106 | } 107 | 108 | func TestFnTypeOol(t *testing.T) { 109 | f := borrowFnType() 110 | f.a = NewFnType(proton, electron) 111 | f.b = NewFnType(proton, neutron) 112 | 113 | ReturnFnType(f) 114 | f = borrowFnType() 115 | if f.a != nil { 116 | t.Error("FunctionType not cleaned up: a is not nil") 117 | } 118 | if f.b != nil { 119 | t.Error("FunctionType not cleaned up: b is not nil") 120 | } 121 | 122 | } 123 | -------------------------------------------------------------------------------- /substitutions.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "fmt" 4 | 5 | // Subs is a list of substitution. 
Internally there are two very basic substitutions - one backed by map and the other a normal slice 6 | type Subs interface { 7 | Get(TypeVariable) (Type, bool) 8 | Add(TypeVariable, Type) Subs 9 | Remove(TypeVariable) Subs 10 | 11 | // Iter() <-chan Substitution 12 | Iter() []Substitution 13 | Size() int 14 | Clone() Subs 15 | } 16 | 17 | // A Substitution is a tuple representing the TypeVariable and the replacement Type 18 | type Substitution struct { 19 | Tv TypeVariable 20 | T Type 21 | } 22 | 23 | type sSubs struct { 24 | s []Substitution 25 | } 26 | 27 | func newSliceSubs(maybeSize ...int) *sSubs { 28 | var size int 29 | if len(maybeSize) > 0 && maybeSize[0] > 0 { 30 | size = maybeSize[0] 31 | } 32 | retVal := BorrowSSubs(size) 33 | retVal.s = retVal.s[:0] 34 | return retVal 35 | } 36 | 37 | func (s *sSubs) Get(tv TypeVariable) (Type, bool) { 38 | if i := s.index(tv); i >= 0 { 39 | return s.s[i].T, true 40 | } 41 | return nil, false 42 | } 43 | 44 | func (s *sSubs) Add(tv TypeVariable, t Type) Subs { 45 | if i := s.index(tv); i >= 0 { 46 | s.s[i].T = t 47 | return s 48 | } 49 | s.s = append(s.s, Substitution{tv, t}) 50 | return s 51 | } 52 | 53 | func (s *sSubs) Remove(tv TypeVariable) Subs { 54 | if i := s.index(tv); i >= 0 { 55 | // for now we keep the order 56 | copy(s.s[i:], s.s[i+1:]) 57 | s.s[len(s.s)-1].T = nil 58 | s.s = s.s[:len(s.s)-1] 59 | } 60 | 61 | return s 62 | } 63 | 64 | func (s *sSubs) Iter() []Substitution { return s.s } 65 | func (s *sSubs) Size() int { return len(s.s) } 66 | func (s *sSubs) Clone() Subs { 67 | retVal := BorrowSSubs(len(s.s)) 68 | copy(retVal.s, s.s) 69 | return retVal 70 | } 71 | 72 | func (s *sSubs) index(tv TypeVariable) int { 73 | for i, sub := range s.s { 74 | if sub.Tv == tv { 75 | return i 76 | } 77 | } 78 | return -1 79 | } 80 | 81 | func (s *sSubs) Format(state fmt.State, c rune) { 82 | state.Write([]byte{'{'}) 83 | for i, v := range s.s { 84 | if i < len(s.s)-1 { 85 | fmt.Fprintf(state, "%v: %v, ", v.Tv, v.T) 86 | 
87 | } else { 88 | fmt.Fprintf(state, "%v: %v", v.Tv, v.T) 89 | } 90 | } 91 | state.Write([]byte{'}'}) 92 | } 93 | 94 | type mSubs map[TypeVariable]Type 95 | 96 | func (s mSubs) Get(tv TypeVariable) (Type, bool) { retVal, ok := s[tv]; return retVal, ok } 97 | func (s mSubs) Add(tv TypeVariable, t Type) Subs { s[tv] = t; return s } 98 | func (s mSubs) Remove(tv TypeVariable) Subs { delete(s, tv); return s } 99 | 100 | func (s mSubs) Iter() []Substitution { 101 | retVal := make([]Substitution, len(s)) 102 | var i int 103 | for k, v := range s { 104 | retVal[i] = Substitution{k, v} 105 | i++ 106 | } 107 | return retVal 108 | } 109 | 110 | func (s mSubs) Size() int { return len(s) } 111 | func (s mSubs) Clone() Subs { 112 | retVal := make(mSubs) 113 | for k, v := range s { 114 | retVal[k] = v 115 | } 116 | return retVal 117 | } 118 | 119 | func compose(a, b Subs) (retVal Subs) { 120 | if b == nil { 121 | return a 122 | } 123 | 124 | retVal = b.Clone() 125 | 126 | if a == nil { 127 | return 128 | } 129 | 130 | for _, v := range a.Iter() { 131 | retVal = retVal.Add(v.Tv, v.T) 132 | } 133 | 134 | for _, v := range retVal.Iter() { 135 | retVal = retVal.Add(v.Tv, v.T.Apply(a).(Type)) 136 | } 137 | return retVal 138 | } 139 | -------------------------------------------------------------------------------- /functionType.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "fmt" 4 | 5 | // FunctionType is a type constructor that builds function types. 6 | type FunctionType struct { 7 | a, b Type 8 | } 9 | 10 | // NewFnType creates a new FunctionType. Functions are by default right associative. 
This: 11 | // NewFnType(a, a, a) 12 | // is short hand for this: 13 | // NewFnType(a, NewFnType(a, a)) 14 | func NewFnType(ts ...Type) *FunctionType { 15 | if len(ts) < 2 { 16 | panic("Expected at least 2 input types") 17 | } 18 | 19 | retVal := borrowFnType() 20 | retVal.a = ts[0] 21 | 22 | if len(ts) > 2 { 23 | retVal.b = NewFnType(ts[1:]...) 24 | } else { 25 | retVal.b = ts[1] 26 | } 27 | return retVal 28 | } 29 | 30 | func (t *FunctionType) Name() string { return "→" } 31 | func (t *FunctionType) Apply(sub Subs) Substitutable { 32 | t.a = t.a.Apply(sub).(Type) 33 | t.b = t.b.Apply(sub).(Type) 34 | return t 35 | } 36 | 37 | func (t *FunctionType) FreeTypeVar() TypeVarSet { return t.a.FreeTypeVar().Union(t.b.FreeTypeVar()) } 38 | func (t *FunctionType) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%v → %v", t.a, t.b) } 39 | func (t *FunctionType) String() string { return fmt.Sprintf("%v", t) } 40 | func (t *FunctionType) Normalize(k, v TypeVarSet) (Type, error) { 41 | var a, b Type 42 | var err error 43 | if a, err = t.a.Normalize(k, v); err != nil { 44 | return nil, err 45 | } 46 | 47 | if b, err = t.b.Normalize(k, v); err != nil { 48 | return nil, err 49 | } 50 | 51 | return NewFnType(a, b), nil 52 | } 53 | func (t *FunctionType) Types() Types { 54 | retVal := BorrowTypes(2) 55 | retVal[0] = t.a 56 | retVal[1] = t.b 57 | return retVal 58 | } 59 | 60 | func (t *FunctionType) Eq(other Type) bool { 61 | if ot, ok := other.(*FunctionType); ok { 62 | return ot.a.Eq(t.a) && ot.b.Eq(t.b) 63 | } 64 | return false 65 | } 66 | 67 | // Other methods (accessors mainly) 68 | 69 | // Arg returns the type of the function argument 70 | func (t *FunctionType) Arg() Type { return t.a } 71 | 72 | // Ret returns the return type of a function. 
If recursive is true, it will get the final return type 73 | func (t *FunctionType) Ret(recursive bool) Type { 74 | if !recursive { 75 | return t.b 76 | } 77 | 78 | if fnt, ok := t.b.(*FunctionType); ok { 79 | return fnt.Ret(recursive) 80 | } 81 | 82 | return t.b 83 | } 84 | 85 | // FlatTypes returns the types in FunctionTypes as a flat slice of types. This allows for easier iteration in some applications 86 | func (t *FunctionType) FlatTypes() Types { 87 | retVal := BorrowTypes(8) // start with 8. Can always grow 88 | retVal = retVal[:0] 89 | 90 | if a, ok := t.a.(*FunctionType); ok { 91 | ft := a.FlatTypes() 92 | retVal = append(retVal, ft...) 93 | ReturnTypes(ft) 94 | } else { 95 | retVal = append(retVal, t.a) 96 | } 97 | 98 | if b, ok := t.b.(*FunctionType); ok { 99 | ft := b.FlatTypes() 100 | retVal = append(retVal, ft...) 101 | ReturnTypes(ft) 102 | } else { 103 | retVal = append(retVal, t.b) 104 | } 105 | return retVal 106 | } 107 | 108 | // Clone implements Cloner 109 | func (t *FunctionType) Clone() interface{} { 110 | retVal := new(FunctionType) 111 | 112 | if ac, ok := t.a.(Cloner); ok { 113 | retVal.a = ac.Clone().(Type) 114 | } else { 115 | retVal.a = t.a 116 | } 117 | 118 | if bc, ok := t.b.(Cloner); ok { 119 | retVal.b = bc.Clone().(Type) 120 | } else { 121 | retVal.b = t.b 122 | } 123 | return retVal 124 | } 125 | -------------------------------------------------------------------------------- /functionType_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import ( 4 | "testing" 5 | 6 | "github.com/stretchr/testify/assert" 7 | ) 8 | 9 | func TestFunctionTypeBasics(t *testing.T) { 10 | fnType := NewFnType(TypeVariable('a'), TypeVariable('a'), TypeVariable('a')) 11 | if fnType.Name() != "→" { 12 | t.Errorf("FunctionType should have \"→\" as a name. Got %q instead", fnType.Name()) 13 | } 14 | 15 | if fnType.String() != "a → a → a" { 16 | t.Errorf("Expected \"a → a → a\". 
Got %q instead", fnType.String()) 17 | } 18 | 19 | if !fnType.Arg().Eq(TypeVariable('a')) { 20 | t.Error("Expected arg of function to be 'a'") 21 | } 22 | 23 | if !fnType.Ret(false).Eq(NewFnType(TypeVariable('a'), TypeVariable('a'))) { 24 | t.Error("Expected ret(false) to be a → a") 25 | } 26 | 27 | if !fnType.Ret(true).Eq(TypeVariable('a')) { 28 | t.Error("Expected final return type to be 'a'") 29 | } 30 | 31 | // a very simple fn 32 | fnType = NewFnType(TypeVariable('a'), TypeVariable('a')) 33 | if !fnType.Ret(true).Eq(TypeVariable('a')) { 34 | t.Error("Expected final return type to be 'a'") 35 | } 36 | 37 | ftv := fnType.FreeTypeVar() 38 | if len(ftv) != 1 { 39 | t.Errorf("Expected only one free type var") 40 | } 41 | 42 | for _, fas := range fnApplyTests { 43 | fn := fas.fn.Apply(fas.sub).(*FunctionType) 44 | if !fn.Eq(fas.expected) { 45 | t.Errorf("Expected %v. Got %v instead", fas.expected, fn) 46 | } 47 | } 48 | 49 | // bad shit 50 | f := func() { 51 | NewFnType(TypeVariable('a')) 52 | } 53 | assert.Panics(t, f) 54 | } 55 | 56 | var fnApplyTests = []struct { 57 | fn *FunctionType 58 | sub Subs 59 | 60 | expected *FunctionType 61 | }{ 62 | {NewFnType(TypeVariable('a'), TypeVariable('a')), mSubs{'a': proton, 'b': neutron}, NewFnType(proton, proton)}, 63 | {NewFnType(TypeVariable('a'), TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFnType(proton, neutron)}, 64 | {NewFnType(TypeVariable('a'), TypeVariable('b')), mSubs{'c': proton, 'd': neutron}, NewFnType(TypeVariable('a'), TypeVariable('b'))}, 65 | {NewFnType(TypeVariable('a'), TypeVariable('b')), mSubs{'a': proton, 'c': neutron}, NewFnType(proton, TypeVariable('b'))}, 66 | {NewFnType(TypeVariable('a'), TypeVariable('b')), mSubs{'c': proton, 'b': neutron}, NewFnType(TypeVariable('a'), neutron)}, 67 | {NewFnType(electron, proton), mSubs{'a': proton, 'b': neutron}, NewFnType(electron, proton)}, 68 | 69 | // a -> (b -> c) 70 | {NewFnType(TypeVariable('a'), TypeVariable('b'), TypeVariable('a')), 
mSubs{'a': proton, 'b': neutron}, NewFnType(proton, neutron, proton)}, 71 | {NewFnType(TypeVariable('a'), TypeVariable('a'), TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFnType(proton, proton, neutron)}, 72 | {NewFnType(TypeVariable('a'), TypeVariable('b'), TypeVariable('c')), mSubs{'a': proton, 'b': neutron}, NewFnType(proton, neutron, TypeVariable('c'))}, 73 | {NewFnType(TypeVariable('a'), TypeVariable('c'), TypeVariable('b')), mSubs{'a': proton, 'b': neutron}, NewFnType(proton, TypeVariable('c'), neutron)}, 74 | 75 | // (a -> b) -> c 76 | {NewFnType(NewFnType(TypeVariable('a'), TypeVariable('b')), TypeVariable('a')), mSubs{'a': proton, 'b': neutron}, NewFnType(NewFnType(proton, neutron), proton)}, 77 | } 78 | 79 | func TestFunctionType_FlatTypes(t *testing.T) { 80 | fnType := NewFnType(TypeVariable('a'), TypeVariable('b'), TypeVariable('c')) 81 | ts := fnType.FlatTypes() 82 | correct := Types{TypeVariable('a'), TypeVariable('b'), TypeVariable('c')} 83 | assert.Equal(t, ts, correct) 84 | 85 | fnType2 := NewFnType(fnType, TypeVariable('d')) 86 | correct = append(correct, TypeVariable('d')) 87 | ts = fnType2.FlatTypes() 88 | assert.Equal(t, ts, correct) 89 | } 90 | 91 | func TestFunctionType_Clone(t *testing.T) { 92 | fnType := NewFnType(TypeVariable('a'), TypeVariable('b'), TypeVariable('c')) 93 | assert.Equal(t, fnType.Clone(), fnType) 94 | 95 | rec := NewRecordType("", TypeVariable('a'), NewFnType(TypeVariable('a'), TypeVariable('b')), TypeVariable('c')) 96 | fnType = NewFnType(rec, rec) 97 | assert.Equal(t, fnType.Clone(), fnType) 98 | } 99 | -------------------------------------------------------------------------------- /type.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import ( 4 | "fmt" 5 | ) 6 | 7 | // Type represents all the possible type constructors. 
8 | type Type interface { 9 | Substitutable 10 | Name() string // Name is the name of the constructor 11 | Normalize(TypeVarSet, TypeVarSet) (Type, error) // Normalize normalizes all the type variable names in the type 12 | Types() Types // If the type is made up of smaller types, then it will return them 13 | Eq(Type) bool // equality operation 14 | 15 | fmt.Formatter 16 | fmt.Stringer 17 | } 18 | 19 | // Substitutable is any type that can have a set of substitutions applied on it, as well as being able to know what its free type variables are 20 | type Substitutable interface { 21 | Apply(Subs) Substitutable 22 | FreeTypeVar() TypeVarSet 23 | } 24 | 25 | // TypeConst are the default implementation of a constant type. Feel free to implement your own. TypeConsts should be immutable (so no pointer types plz) 26 | type TypeConst string 27 | 28 | func (t TypeConst) Name() string { return string(t) } 29 | func (t TypeConst) Apply(Subs) Substitutable { return t } 30 | func (t TypeConst) FreeTypeVar() TypeVarSet { return nil } 31 | func (t TypeConst) Normalize(k, v TypeVarSet) (Type, error) { return t, nil } 32 | func (t TypeConst) Types() Types { return nil } 33 | func (t TypeConst) String() string { return string(t) } 34 | func (t TypeConst) Format(s fmt.State, c rune) { fmt.Fprintf(s, "%s", string(t)) } 35 | func (t TypeConst) Eq(other Type) bool { return other == t } 36 | 37 | // Record is a basic record/tuple type. It takes an optional name. 38 | type Record struct { 39 | ts []Type 40 | name string 41 | } 42 | 43 | // NewRecordType creates a new Record Type 44 | func NewRecordType(name string, ts ...Type) *Record { 45 | return &Record{ 46 | ts: ts, 47 | name: name, 48 | } 49 | } 50 | 51 | func (t *Record) Apply(subs Subs) Substitutable { 52 | ts := make([]Type, len(t.ts)) 53 | for i, v := range t.ts { 54 | ts[i] = v.Apply(subs).(Type) 55 | } 56 | return NewRecordType(t.name, ts...) 
57 | } 58 | 59 | func (t *Record) FreeTypeVar() TypeVarSet { 60 | var tvs TypeVarSet 61 | for _, v := range t.ts { 62 | tvs = v.FreeTypeVar().Union(tvs) 63 | } 64 | return tvs 65 | } 66 | 67 | func (t *Record) Name() string { 68 | if t.name != "" { 69 | return t.name 70 | } 71 | return t.String() 72 | } 73 | 74 | func (t *Record) Normalize(k, v TypeVarSet) (Type, error) { 75 | ts := make([]Type, len(t.ts)) 76 | var err error 77 | for i, tt := range t.ts { 78 | if ts[i], err = tt.Normalize(k, v); err != nil { 79 | return nil, err 80 | } 81 | } 82 | return NewRecordType(t.name, ts...), nil 83 | } 84 | 85 | func (t *Record) Types() Types { 86 | ts := BorrowTypes(len(t.ts)) 87 | copy(ts, t.ts) 88 | return ts 89 | } 90 | 91 | func (t *Record) Eq(other Type) bool { 92 | if ot, ok := other.(*Record); ok { 93 | if len(ot.ts) != len(t.ts) { 94 | return false 95 | } 96 | for i, v := range t.ts { 97 | if !v.Eq(ot.ts[i]) { 98 | return false 99 | } 100 | } 101 | return true 102 | } 103 | return false 104 | } 105 | 106 | func (t *Record) Format(f fmt.State, c rune) { 107 | f.Write([]byte("(")) 108 | for i, v := range t.ts { 109 | if i < len(t.ts)-1 { 110 | fmt.Fprintf(f, "%v, ", v) 111 | } else { 112 | fmt.Fprintf(f, "%v)", v) 113 | } 114 | } 115 | 116 | } 117 | 118 | func (t *Record) String() string { return fmt.Sprintf("%v", t) } 119 | 120 | // Clone implements Cloner 121 | func (t *Record) Clone() interface{} { 122 | retVal := new(Record) 123 | ts := BorrowTypes(len(t.ts)) 124 | for i, tt := range t.ts { 125 | if c, ok := tt.(Cloner); ok { 126 | ts[i] = c.Clone().(Type) 127 | } else { 128 | ts[i] = tt 129 | } 130 | } 131 | retVal.ts = ts 132 | retVal.name = t.name 133 | 134 | return retVal 135 | } 136 | -------------------------------------------------------------------------------- /substitutions_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import ( 4 | "fmt" 5 | "testing" 6 | ) 7 | 8 | var subsTests = 
[]struct { 9 | op string 10 | tv TypeVariable 11 | t Type 12 | 13 | ok bool 14 | size int 15 | }{ 16 | {"get", TypeVariable('a'), nil, false, 0}, 17 | {"add", TypeVariable('a'), proton, true, 1}, 18 | {"get", TypeVariable('a'), proton, true, 1}, 19 | {"add", TypeVariable('a'), neutron, true, 1}, 20 | {"get", TypeVariable('a'), neutron, true, 1}, 21 | {"rem", TypeVariable('b'), nil, false, 1}, 22 | {"rem", TypeVariable('a'), nil, false, 0}, 23 | {"add", TypeVariable('a'), proton, true, 1}, 24 | {"add", TypeVariable('b'), proton, true, 2}, 25 | {"add", TypeVariable('c'), proton, true, 3}, 26 | } 27 | 28 | func testSubs(t *testing.T, sub Subs) { 29 | var T Type 30 | var ok bool 31 | for _, sts := range subsTests { 32 | switch sts.op { 33 | case "get": 34 | if T, ok = sub.Get(sts.tv); ok != sts.ok { 35 | t.Errorf("Expected Get to return %t. Got a value of %v instead", sts.ok, T) 36 | } 37 | case "add": 38 | sub = sub.Add(sts.tv, sts.t) 39 | case "rem": 40 | sub = sub.Remove(sts.tv) 41 | } 42 | 43 | if sub.Size() != sts.size { 44 | t.Errorf("Inconsistent size. Want %d. 
Got %d", sts.size, sub.Size()) 45 | } 46 | } 47 | 48 | // Iter 49 | correct := []Substitution{ 50 | {TypeVariable('a'), proton}, 51 | {TypeVariable('b'), proton}, 52 | {TypeVariable('c'), proton}, 53 | } 54 | 55 | for _, s := range sub.Iter() { 56 | var found bool 57 | for _, c := range correct { 58 | if s.T == c.T && s.Tv == c.Tv { 59 | found = true 60 | break 61 | } 62 | } 63 | if !found { 64 | t.Errorf("Testing of %T: cannot find %v in Range", sub, s) 65 | } 66 | } 67 | 68 | // Clone 69 | cloned := sub.Clone() 70 | cloned = cloned.Add(TypeVariable('a'), photon) 71 | gt, ok := sub.Get(TypeVariable('a')) 72 | if !ok { 73 | t.Errorf("Expected the key 'a' to be found") 74 | } 75 | if gt == photon { 76 | t.Errorf("Mutable cloning found") 77 | } 78 | } 79 | 80 | func TestSliceSubs(t *testing.T) { 81 | var sub Subs 82 | 83 | sub = newSliceSubs() 84 | if sub.Size() != 0 { 85 | t.Error("Expected a size of 0") 86 | } 87 | 88 | sub = newSliceSubs(5) 89 | if cap(sub.(*sSubs).s) != 5 { 90 | t.Error("Expected a cap of 5") 91 | } 92 | if sub.Size() != 0 { 93 | t.Error("Expected a size of 0") 94 | } 95 | 96 | testSubs(t, sub) 97 | 98 | // Format for completeness sake 99 | sub = newSliceSubs(2) 100 | sub = sub.Add('a', proton) 101 | sub = sub.Add('b', neutron) 102 | if fmt.Sprintf("%v", sub) != "{a: proton, b: neutron}" { 103 | t.Errorf("Format of sub is wrong. 
Got %q instead", sub) 104 | } 105 | } 106 | 107 | func TestMapSubs(t *testing.T) { 108 | var sub Subs 109 | 110 | sub = make(mSubs) 111 | if sub.Size() != 0 { 112 | t.Error("Expected a size of 0") 113 | } 114 | 115 | testSubs(t, sub) 116 | } 117 | 118 | var composeTests = []struct { 119 | a Subs 120 | b Subs 121 | 122 | expected Subs 123 | }{ 124 | {mSubs{'a': proton}, &sSubs{[]Substitution{{'b', neutron}}}, &sSubs{[]Substitution{{'a', proton}, {'b', neutron}}}}, 125 | {&sSubs{[]Substitution{{'b', neutron}}}, mSubs{'a': proton}, mSubs{'a': proton, 'b': neutron}}, 126 | 127 | {mSubs{'a': proton, 'b': neutron}, &sSubs{[]Substitution{{'b', neutron}}}, &sSubs{[]Substitution{{'a', proton}, {'b', neutron}}}}, 128 | {mSubs{'a': proton, 'b': TypeVariable('a')}, &sSubs{[]Substitution{{'b', neutron}}}, &sSubs{[]Substitution{{'a', proton}, {'b', proton}}}}, 129 | {mSubs{'a': proton}, &sSubs{[]Substitution{{'b', TypeVariable('a')}}}, &sSubs{[]Substitution{{'a', proton}, {'b', proton}}}}, 130 | } 131 | 132 | func TestCompose(t *testing.T) { 133 | for i, cts := range composeTests { 134 | subs := compose(cts.a, cts.b) 135 | 136 | for _, v := range cts.expected.Iter() { 137 | if T, ok := subs.Get(v.Tv); !ok { 138 | t.Errorf("Test %d: Expected TypeVariable %v to be in subs", i, v.Tv) 139 | } else if T != v.T { 140 | t.Errorf("Test %d: Expected replacement to be %v. 
Got %v instead", i, v.T, T) 141 | } 142 | } 143 | } 144 | } 145 | -------------------------------------------------------------------------------- /perf.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "sync" 4 | 5 | const ( 6 | poolSize = 4 7 | extraCap = 2 8 | ) 9 | 10 | var sSubPool = [poolSize]*sync.Pool{ 11 | &sync.Pool{ 12 | New: func() interface{} { return &sSubs{s: make([]Substitution, 1, 1+extraCap)} }, 13 | }, 14 | &sync.Pool{ 15 | New: func() interface{} { return &sSubs{s: make([]Substitution, 2, 2+extraCap)} }, 16 | }, 17 | &sync.Pool{ 18 | New: func() interface{} { return &sSubs{s: make([]Substitution, 3, 3+extraCap)} }, 19 | }, 20 | &sync.Pool{ 21 | New: func() interface{} { return &sSubs{s: make([]Substitution, 4, 4+extraCap)} }, 22 | }, 23 | } 24 | 25 | var mSubPool = &sync.Pool{ 26 | New: func() interface{} { return make(mSubs) }, 27 | } 28 | 29 | // ReturnSubs returns substitutions to the pool. USE WITH CAUTION. 30 | func ReturnSubs(sub Subs) { 31 | switch s := sub.(type) { 32 | case mSubs: 33 | for k := range s { 34 | delete(s, k) 35 | } 36 | mSubPool.Put(sub) 37 | case *sSubs: 38 | size := cap(s.s) - 2 39 | if size > 0 && size < poolSize+1 { 40 | // reset to empty 41 | for i := range s.s { 42 | s.s[i] = Substitution{} 43 | } 44 | 45 | s.s = s.s[:size] 46 | sSubPool[size-1].Put(sub) 47 | } 48 | } 49 | } 50 | 51 | // BorrowMSubs gets a map based substitution from a shared pool. USE WITH CAUTION 52 | func BorrowMSubs() mSubs { 53 | return mSubPool.Get().(mSubs) 54 | } 55 | 56 | // BorrowSSubs gets a slice based substituiton from a shared pool. 
USE WITH CAUTION 57 | func BorrowSSubs(size int) *sSubs { 58 | if size > 0 && size < 5 { 59 | retVal := sSubPool[size-1].Get().(*sSubs) 60 | return retVal 61 | } 62 | s := make([]Substitution, size) 63 | return &sSubs{s: s} 64 | } 65 | 66 | var typesPool = [poolSize]*sync.Pool{ 67 | &sync.Pool{ 68 | New: func() interface{} { return make(Types, 1) }, 69 | }, 70 | 71 | &sync.Pool{ 72 | New: func() interface{} { return make(Types, 2) }, 73 | }, 74 | 75 | &sync.Pool{ 76 | New: func() interface{} { return make(Types, 3) }, 77 | }, 78 | 79 | &sync.Pool{ 80 | New: func() interface{} { return make(Types, 4) }, 81 | }, 82 | } 83 | 84 | // BorrowTypes gets a slice of Types with size. USE WITH CAUTION. 85 | func BorrowTypes(size int) Types { 86 | if size > 0 && size < poolSize+1 { 87 | return typesPool[size-1].Get().(Types) 88 | } 89 | return make(Types, size) 90 | } 91 | 92 | // ReturnTypes returns the slice of types into the pool. USE WITH CAUTION 93 | func ReturnTypes(ts Types) { 94 | if size := cap(ts); size > 0 && size < poolSize+1 { 95 | ts = ts[:cap(ts)] 96 | for i := range ts { 97 | ts[i] = nil 98 | } 99 | typesPool[size-1].Put(ts) 100 | } 101 | } 102 | 103 | var typeVarSetPool = [poolSize]*sync.Pool{ 104 | &sync.Pool{ 105 | New: func() interface{} { return make(TypeVarSet, 1) }, 106 | }, 107 | 108 | &sync.Pool{ 109 | New: func() interface{} { return make(TypeVarSet, 2) }, 110 | }, 111 | 112 | &sync.Pool{ 113 | New: func() interface{} { return make(TypeVarSet, 3) }, 114 | }, 115 | 116 | &sync.Pool{ 117 | New: func() interface{} { return make(TypeVarSet, 4) }, 118 | }, 119 | } 120 | 121 | // BorrowTypeVarSet gets a TypeVarSet of size from pool. USE WITH CAUTION 122 | func BorrowTypeVarSet(size int) TypeVarSet { 123 | if size > 0 && size < poolSize+1 { 124 | return typeVarSetPool[size-1].Get().(TypeVarSet) 125 | } 126 | return make(TypeVarSet, size) 127 | } 128 | 129 | // ReturnTypeVarSet returns the TypeVarSet to pool. 
USE WITH CAUTION 130 | func ReturnTypeVarSet(ts TypeVarSet) { 131 | var def TypeVariable 132 | if size := cap(ts); size > 0 && size < poolSize+1 { 133 | ts = ts[:cap(ts)] 134 | for i := range ts { 135 | ts[i] = def 136 | } 137 | typeVarSetPool[size-1].Put(ts) 138 | } 139 | } 140 | 141 | var fnTypePool = &sync.Pool{ 142 | New: func() interface{} { return new(FunctionType) }, 143 | } 144 | 145 | func borrowFnType() *FunctionType { 146 | return fnTypePool.Get().(*FunctionType) 147 | } 148 | 149 | // ReturnFnType returns a *FunctionType to the pool. NewFnType automatically borrows from the pool. USE WITH CAUTION 150 | func ReturnFnType(fnt *FunctionType) { 151 | if a, ok := fnt.a.(*FunctionType); ok { 152 | ReturnFnType(a) 153 | } 154 | 155 | if b, ok := fnt.b.(*FunctionType); ok { 156 | ReturnFnType(b) 157 | } 158 | 159 | fnt.a = nil 160 | fnt.b = nil 161 | fnTypePool.Put(fnt) 162 | } 163 | -------------------------------------------------------------------------------- /example_greenspun_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import ( 4 | "fmt" 5 | "log" 6 | "strings" 7 | 8 | "github.com/pkg/errors" 9 | ) 10 | 11 | const digits = "0123456789" 12 | 13 | type TyperExpression interface { 14 | Expression 15 | Typer 16 | } 17 | 18 | type λ struct { 19 | name string 20 | body Expression 21 | } 22 | 23 | func (n λ) Name() string { return n.name } 24 | func (n λ) Body() Expression { return n.body } 25 | func (n λ) IsLambda() bool { return true } 26 | 27 | type lit string 28 | 29 | func (n lit) Name() string { return string(n) } 30 | func (n lit) Body() Expression { return n } 31 | func (n lit) Type() Type { 32 | switch { 33 | case strings.ContainsAny(digits, string(n)) && strings.ContainsAny(digits, string(n[0])): 34 | return Float 35 | case string(n) == "true" || string(n) == "false": 36 | return Bool 37 | default: 38 | return nil 39 | } 40 | } 41 | func (n lit) IsLit() bool { return true } 42 | func 
(n lit) IsLambda() bool { return true } 43 | 44 | type app struct { 45 | f Expression 46 | arg Expression 47 | } 48 | 49 | func (n app) Fn() Expression { return n.f } 50 | func (n app) Body() Expression { return n.arg } 51 | func (n app) Arg() Expression { return n.arg } 52 | 53 | type let struct { 54 | name string 55 | def Expression 56 | in Expression 57 | } 58 | 59 | func (n let) Name() string { return n.name } 60 | func (n let) Def() Expression { return n.def } 61 | func (n let) Body() Expression { return n.in } 62 | 63 | type letrec struct { 64 | name string 65 | def Expression 66 | in Expression 67 | } 68 | 69 | func (n letrec) Name() string { return n.name } 70 | func (n letrec) Def() Expression { return n.def } 71 | func (n letrec) Body() Expression { return n.in } 72 | func (n letrec) Children() []Expression { return []Expression{n.def, n.in} } 73 | func (n letrec) IsRecursive() bool { return true } 74 | 75 | type prim byte 76 | 77 | const ( 78 | Float prim = iota 79 | Bool 80 | ) 81 | 82 | // implement Type 83 | func (t prim) Name() string { return t.String() } 84 | func (t prim) Apply(Subs) Substitutable { return t } 85 | func (t prim) FreeTypeVar() TypeVarSet { return nil } 86 | func (t prim) Normalize(TypeVarSet, TypeVarSet) (Type, error) { return t, nil } 87 | func (t prim) Types() Types { return nil } 88 | func (t prim) Eq(other Type) bool { 89 | if ot, ok := other.(prim); ok { 90 | return ot == t 91 | } 92 | return false 93 | } 94 | 95 | func (t prim) Format(s fmt.State, c rune) { fmt.Fprintf(s, t.String()) } 96 | func (t prim) String() string { 97 | switch t { 98 | case Float: 99 | return "Float" 100 | case Bool: 101 | return "Bool" 102 | } 103 | return "HELP" 104 | } 105 | 106 | //Phillip Greenspun's tenth law says: 107 | // "Any sufficiently complicated C or Fortran program contains an ad hoc, informally-specified, bug-ridden, slow implementation of half of Common Lisp." 
108 | // 109 | // So let's implement a half-arsed lisp (Or rather, an AST that can optionally be executed upon if you write the correct interpreter)! 110 | func Example_greenspun() { 111 | // haskell envy in a greenspun's tenth law example function! 112 | // 113 | // We'll assume the following is the "input" code 114 | // let fac n = if n == 0 then 1 else n * fac (n - 1) in fac 5 115 | // and what we have is the AST 116 | 117 | fac := letrec{ 118 | "fac", 119 | λ{ 120 | "n", 121 | app{ 122 | app{ 123 | app{ 124 | lit("if"), 125 | app{ 126 | lit("isZero"), 127 | lit("n"), 128 | }, 129 | }, 130 | lit("1"), 131 | }, 132 | app{ 133 | app{lit("mul"), lit("n")}, 134 | app{lit("fac"), app{lit("--"), lit("n")}}, 135 | }, 136 | }, 137 | }, 138 | app{lit("fac"), lit("5")}, 139 | } 140 | 141 | // but first, let's start with something simple: 142 | // let x = 3 in x+5 143 | simple := let{ 144 | "x", 145 | lit("3"), 146 | app{ 147 | app{ 148 | lit("+"), 149 | lit("5"), 150 | }, 151 | lit("x"), 152 | }, 153 | } 154 | 155 | env := SimpleEnv{ 156 | "--": &Scheme{tvs: TypeVarSet{'a'}, t: NewFnType(TypeVariable('a'), TypeVariable('a'))}, 157 | "if": &Scheme{tvs: TypeVarSet{'a'}, t: NewFnType(Bool, TypeVariable('a'), TypeVariable('a'), TypeVariable('a'))}, 158 | "isZero": &Scheme{t: NewFnType(Float, Bool)}, 159 | "mul": &Scheme{t: NewFnType(Float, Float, Float)}, 160 | "+": &Scheme{tvs: TypeVarSet{'a'}, t: NewFnType(TypeVariable('a'), TypeVariable('a'), TypeVariable('a'))}, 161 | } 162 | 163 | var scheme *Scheme 164 | var err error 165 | scheme, err = Infer(env, simple) 166 | if err != nil { 167 | log.Printf("%+v", errors.Cause(err)) 168 | } 169 | simpleType, ok := scheme.Type() 170 | fmt.Printf("simple Type: %v | isMonoType: %v | err: %v\n", simpleType, ok, err) 171 | 172 | scheme, err = Infer(env, fac) 173 | if err != nil { 174 | log.Printf("%+v", errors.Cause(err)) 175 | } 176 | 177 | facType, ok := scheme.Type() 178 | fmt.Printf("fac Type: %v | isMonoType: %v | err: %v", facType, 
ok, err) 179 | 180 | // Output: 181 | // simple Type: Float | isMonoType: true | err: 182 | // fac Type: Float | isMonoType: true | err: 183 | 184 | } 185 | -------------------------------------------------------------------------------- /hm_test.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "testing" 4 | 5 | var unifyTests = []struct { 6 | name string 7 | a Type 8 | b Type 9 | 10 | subs Subs 11 | err bool // does it error? 12 | }{ 13 | {"a ~ a (recursive unification)", TypeVariable('a'), TypeVariable('a'), nil, true}, 14 | {"a ~ b", TypeVariable('a'), TypeVariable('b'), mSubs{'a': TypeVariable('b')}, false}, 15 | {"a ~ proton", TypeVariable('a'), proton, mSubs{'a': proton}, false}, 16 | {"proton ~ a", proton, TypeVariable('a'), mSubs{'a': proton}, false}, 17 | 18 | // typeconst ~ typeconst 19 | {"proton ~ proton", proton, proton, nil, false}, 20 | {"proton ~ neutron", proton, neutron, nil, true}, 21 | {"List a ~ List proton", list{TypeVariable('a')}, list{proton}, mSubs{'a': proton}, false}, 22 | 23 | // function types 24 | {"List a → List a ~ List proton → List proton", 25 | NewFnType(list{TypeVariable('a')}, list{TypeVariable('a')}), 26 | NewFnType(list{proton}, list{proton}), 27 | mSubs{'a': proton}, false}, 28 | {"List proton → List proton ~ List a → List a", 29 | NewFnType(list{proton}, list{proton}), 30 | NewFnType(list{TypeVariable('a')}, list{TypeVariable('a')}), 31 | mSubs{'a': proton}, false}, 32 | {"List a → a ~ List proton → proton", 33 | NewFnType(list{TypeVariable('a')}, TypeVariable('a')), 34 | NewFnType(list{proton}, proton), 35 | mSubs{'a': proton}, false}, 36 | {"List proton → proton ~ List a → a ", 37 | NewFnType(list{proton}, proton), 38 | NewFnType(list{TypeVariable('a')}, TypeVariable('a')), 39 | mSubs{'a': proton}, false}, 40 | {"List a → a → List a ~ List proton → proton → b", 41 | NewFnType(list{TypeVariable('a')}, TypeVariable('a'), list{TypeVariable('a')}), 42 | 
NewFnType(list{proton}, proton, TypeVariable('b')), 43 | mSubs{'a': proton, 'b': list{proton}}, false}, 44 | {"(a, a, b) ~ (proton, proton, neutron)", 45 | NewRecordType("", TypeVariable('a'), TypeVariable('a'), TypeVariable('b')), 46 | NewRecordType("", proton, proton, neutron), 47 | mSubs{'a': proton, 'b': neutron}, false}, 48 | } 49 | 50 | func TestUnify(t *testing.T) { 51 | // assert := assert.New(t) 52 | var t0, t1 Type 53 | var u0, u1 Type 54 | var sub Subs 55 | var err error 56 | 57 | for _, uts := range unifyTests { 58 | // logf("unifying %v", uts.name) 59 | t0 = uts.a 60 | t1 = uts.b 61 | sub, err = Unify(t0, t1) 62 | 63 | switch { 64 | case err == nil && uts.err: 65 | t.Errorf("Test %q - Expected an error: %v | u0: %#v, u1: %#v", uts.name, err, u0, u1) 66 | case err != nil && !uts.err: 67 | t.Errorf("Test %q errored: %v ", uts.name, err) 68 | } 69 | 70 | if uts.err { 71 | continue 72 | } 73 | 74 | if uts.subs == nil { 75 | if sub != nil { 76 | t.Errorf("Test: %q Expected no substitution. Got %v instead", uts.name, sub) 77 | } 78 | continue 79 | } 80 | 81 | for _, s := range uts.subs.Iter() { 82 | if T, ok := sub.Get(s.Tv); !ok { 83 | t.Errorf("Test: %q TypeVariable %v expected in result", uts.name, s.Tv) 84 | } else if T != s.T { 85 | t.Errorf("Test: %q Expected TypeVariable %v to be substituted by %v. 
Got %v instead", uts.name, s.Tv, s.T, T) 86 | } 87 | } 88 | 89 | if uts.subs.Size() != sub.Size() { 90 | t.Errorf("Test: %q Expected subs to be the same size", uts.name) 91 | } 92 | 93 | sub = nil 94 | } 95 | } 96 | 97 | var inferTests = []struct { 98 | name string 99 | 100 | expr Expression 101 | correct Type 102 | correctTVS TypeVarSet 103 | err bool 104 | }{ 105 | {"Lit", lit("1"), Float, nil, false}, 106 | {"Undefined Lit", lit("a"), nil, nil, true}, 107 | {"App", app{lit("+"), lit("1")}, NewFnType(Float, Float), nil, false}, 108 | 109 | {"Lambda", λ{"n", app{lit("+"), lit("1")}}, NewFnType(TypeVariable('a'), Float, Float), TypeVarSet{'a'}, false}, 110 | {"Lambda (+1)", λ{"a", app{lit("+1"), lit("a")}}, NewFnType(TypeVariable('a'), TypeVariable('a')), TypeVarSet{'a'}, false}, 111 | 112 | {"Var - found", variable("x"), proton, nil, false}, 113 | {"Var - notfound", variable("y"), nil, nil, true}, 114 | 115 | {"Self Infer - no err", selfInferer(true), proton, nil, false}, 116 | {"Self Infer - err", selfInferer(false), nil, nil, true}, 117 | 118 | {"nil expr", nil, nil, nil, true}, 119 | } 120 | 121 | func TestInfer(t *testing.T) { 122 | env := SimpleEnv{ 123 | "+": &Scheme{tvs: TypeVarSet{'a'}, t: NewFnType(TypeVariable('a'), TypeVariable('a'), TypeVariable('a'))}, 124 | "+1": &Scheme{tvs: TypeVarSet{'a'}, t: NewFnType(TypeVariable('a'), TypeVariable('a'))}, 125 | "x": NewScheme(nil, proton), 126 | } 127 | 128 | for _, its := range inferTests { 129 | sch, err := Infer(env, its.expr) 130 | 131 | if its.err { 132 | if err == nil { 133 | t.Errorf("Test %q : Expected error. %v", its.name, sch) 134 | } 135 | continue 136 | } else { 137 | if err != nil { 138 | t.Errorf("Test %q Error: %v", its.name, err) 139 | } 140 | } 141 | 142 | if !sch.t.Eq(its.correct) { 143 | t.Errorf("Test %q: Expected %v. 
Got %v", its.name, its.correct, sch.t) 144 | } 145 | 146 | for _, tv := range its.correctTVS { 147 | if !sch.tvs.Contains(tv) { 148 | t.Errorf("Test %q: Expected %v to be in the scheme.", its.name, tv) 149 | break 150 | } 151 | } 152 | 153 | if len(its.correctTVS) != len(sch.tvs) { 154 | t.Errorf("Test %q: Expected scheme to have %v. Got %v instead", its.name, its.correctTVS, sch.tvs) 155 | } 156 | } 157 | 158 | // test without env 159 | its := inferTests[0] 160 | sch, err := Infer(nil, its.expr) 161 | if err != nil { 162 | t.Errorf("Testing a nil Env. Shouldn't have errored. Got err: %v", err) 163 | } 164 | if !sch.t.Eq(its.correct) { 165 | t.Errorf("Testing nil Env. Expected %v to be in the scheme. Got scheme %v instead", its.correct, sch) 166 | } 167 | 168 | } 169 | -------------------------------------------------------------------------------- /hm.go: -------------------------------------------------------------------------------- 1 | package hm 2 | 3 | import "github.com/pkg/errors" 4 | 5 | // Cloner is any type that can clone 6 | type Cloner interface { 7 | Clone() interface{} 8 | } 9 | 10 | // Fresher keeps track of all the TypeVariables that has been generated so far. 
It has one method - Fresh(), which is to create a new TypeVariable 11 | type Fresher interface { 12 | Fresh() TypeVariable 13 | } 14 | 15 | type inferer struct { 16 | env Env 17 | cs Constraints 18 | t Type 19 | 20 | count int 21 | } 22 | 23 | func newInferer(env Env) *inferer { 24 | return &inferer{ 25 | env: env, 26 | } 27 | } 28 | 29 | func (infer *inferer) Fresh() TypeVariable { 30 | retVal := letters[infer.count] 31 | infer.count++ 32 | return TypeVariable(retVal) 33 | } 34 | 35 | func (infer *inferer) lookup(name string) error { 36 | s, ok := infer.env.SchemeOf(name) 37 | if !ok { 38 | return errors.Errorf("Undefined %v", name) 39 | } 40 | infer.t = Instantiate(infer, s) 41 | return nil 42 | } 43 | 44 | func (infer *inferer) consGen(expr Expression) (err error) { 45 | 46 | // explicit types/inferers - can fail 47 | switch et := expr.(type) { 48 | case Typer: 49 | if infer.t = et.Type(); infer.t != nil { 50 | return nil 51 | } 52 | case Inferer: 53 | if infer.t, err = et.Infer(infer.env, infer); err == nil && infer.t != nil { 54 | return nil 55 | } 56 | 57 | err = nil // reset errors 58 | } 59 | 60 | // fallbacks 61 | 62 | switch et := expr.(type) { 63 | case Literal: 64 | return infer.lookup(et.Name()) 65 | 66 | case Var: 67 | if err = infer.lookup(et.Name()); err != nil { 68 | infer.env.Add(et.Name(), &Scheme{t: et.Type()}) 69 | err = nil 70 | } 71 | 72 | case Lambda: 73 | tv := infer.Fresh() 74 | env := infer.env // backup 75 | 76 | infer.env = infer.env.Clone() 77 | infer.env.Remove(et.Name()) 78 | sc := new(Scheme) 79 | sc.t = tv 80 | infer.env.Add(et.Name(), sc) 81 | 82 | if err = infer.consGen(et.Body()); err != nil { 83 | return errors.Wrapf(err, "Unable to infer body of %v. Body: %v", et, et.Body()) 84 | } 85 | 86 | infer.t = NewFnType(tv, infer.t) 87 | infer.env = env // restore backup 88 | 89 | case Apply: 90 | if err = infer.consGen(et.Fn()); err != nil { 91 | return errors.Wrapf(err, "Unable to infer Fn of Apply: %v. 
Fn: %v", et, et.Fn()) 92 | } 93 | fnType, fnCs := infer.t, infer.cs 94 | 95 | if err = infer.consGen(et.Body()); err != nil { 96 | return errors.Wrapf(err, "Unable to infer body of Apply: %v. Body: %v", et, et.Body()) 97 | } 98 | bodyType, bodyCs := infer.t, infer.cs 99 | 100 | tv := infer.Fresh() 101 | cs := append(fnCs, bodyCs...) 102 | cs = append(cs, Constraint{fnType, NewFnType(bodyType, tv)}) 103 | 104 | infer.t = tv 105 | infer.cs = cs 106 | 107 | case LetRec: 108 | tv := infer.Fresh() 109 | // env := infer.env // backup 110 | 111 | infer.env = infer.env.Clone() 112 | infer.env.Remove(et.Name()) 113 | infer.env.Add(et.Name(), &Scheme{tvs: TypeVarSet{tv}, t: tv}) 114 | 115 | if err = infer.consGen(et.Def()); err != nil { 116 | return errors.Wrapf(err, "Unable to infer the definition of a letRec %v. Def: %v", et, et.Def()) 117 | } 118 | defType, defCs := infer.t, infer.cs 119 | 120 | s := newSolver() 121 | s.solve(defCs) 122 | if s.err != nil { 123 | return errors.Wrapf(s.err, "Unable to solve constraints of def: %v", defCs) 124 | } 125 | 126 | sc := Generalize(infer.env.Apply(s.sub).(Env), defType.Apply(s.sub).(Type)) 127 | 128 | infer.env.Remove(et.Name()) 129 | infer.env.Add(et.Name(), sc) 130 | 131 | if err = infer.consGen(et.Body()); err != nil { 132 | return errors.Wrapf(err, "Unable to infer body of letRec %v. Body: %v", et, et.Body()) 133 | } 134 | 135 | infer.t = infer.t.Apply(s.sub).(Type) 136 | infer.cs = infer.cs.Apply(s.sub).(Constraints) 137 | infer.cs = append(infer.cs, defCs...) 138 | 139 | case Let: 140 | env := infer.env 141 | 142 | if err = infer.consGen(et.Def()); err != nil { 143 | return errors.Wrapf(err, "Unable to infer the definition of a let %v. 
Def: %v", et, et.Def()) 144 | } 145 | defType, defCs := infer.t, infer.cs 146 | 147 | s := newSolver() 148 | s.solve(defCs) 149 | if s.err != nil { 150 | return errors.Wrapf(s.err, "Unable to solve for the constraints of a def %v", defCs) 151 | } 152 | 153 | sc := Generalize(env.Apply(s.sub).(Env), defType.Apply(s.sub).(Type)) 154 | infer.env = infer.env.Clone() 155 | infer.env.Remove(et.Name()) 156 | infer.env.Add(et.Name(), sc) 157 | 158 | if err = infer.consGen(et.Body()); err != nil { 159 | return errors.Wrapf(err, "Unable to infer body of let %v. Body: %v", et, et.Body()) 160 | } 161 | 162 | infer.t = infer.t.Apply(s.sub).(Type) 163 | infer.cs = infer.cs.Apply(s.sub).(Constraints) 164 | infer.cs = append(infer.cs, defCs...) 165 | 166 | default: 167 | return errors.Errorf("Expression of %T is unhandled", expr) 168 | } 169 | 170 | return nil 171 | } 172 | 173 | // Instantiate takes a fresh name generator, an a polytype and makes a concrete type out of it. 174 | // 175 | // If ... 176 | // Γ ⊢ e: T1 T1 ⊑ T 177 | // ---------------------- 178 | // Γ ⊢ e: T 179 | // 180 | func Instantiate(f Fresher, s *Scheme) Type { 181 | l := len(s.tvs) 182 | tvs := make(TypeVarSet, l) 183 | 184 | var sub Subs 185 | if l > 30 { 186 | sub = make(mSubs) 187 | } else { 188 | sub = newSliceSubs(l) 189 | } 190 | 191 | for i, tv := range s.tvs { 192 | fr := f.Fresh() 193 | tvs[i] = fr 194 | sub = sub.Add(tv, fr) 195 | } 196 | 197 | return s.t.Apply(sub).(Type) 198 | } 199 | 200 | // Generalize takes an env and a type and creates the most general possible type - which is a polytype 201 | // 202 | // Generalization 203 | // 204 | // If ... 
205 | // Γ ⊢ e: T1 T1 ∉ free(Γ) 206 | // --------------------------- 207 | // Γ ⊢ e: ∀ α.T1 208 | func Generalize(env Env, t Type) *Scheme { 209 | logf("generalizing %v over %v", t, env) 210 | enterLoggingContext() 211 | defer leaveLoggingContext() 212 | var envFree, tFree, diff TypeVarSet 213 | 214 | if env != nil { 215 | envFree = env.FreeTypeVar() 216 | } 217 | 218 | tFree = t.FreeTypeVar() 219 | 220 | switch { 221 | case envFree == nil && tFree == nil: 222 | goto ret 223 | case len(envFree) > 0 && len(tFree) > 0: 224 | defer ReturnTypeVarSet(envFree) 225 | defer ReturnTypeVarSet(tFree) 226 | case len(envFree) > 0 && len(tFree) == 0: 227 | // cannot return envFree because envFree will just be sorted and set 228 | case len(envFree) == 0 && len(tFree) > 0: 229 | // return ? 230 | } 231 | 232 | diff = tFree.Difference(envFree) 233 | 234 | ret: 235 | return &Scheme{ 236 | tvs: diff, 237 | t: t, 238 | } 239 | } 240 | 241 | // Infer takes an env, and an expression, and returns a scheme. 242 | // 243 | // The Infer function is the core of the HM type inference system. This is a reference implementation and is completely servicable, but not quite performant. 244 | // You should use this as a reference and write your own infer function. 
245 | // 246 | // Very briefly, these rules are implemented: 247 | // 248 | // Var 249 | // 250 | // If x is of type T, in a collection of statements Γ, then we can infer that x has type T when we come to a new instance of x 251 | // x: T ∈ Γ 252 | // ----------- 253 | // Γ ⊢ x: T 254 | // 255 | // Apply 256 | // 257 | // If f is a function that takes T1 and returns T2; and if x is of type T1; 258 | // then we can infer that the result of applying f on x will yield a result has type T2 259 | // Γ ⊢ f: T1→T2 Γ ⊢ x: T1 260 | // ------------------------- 261 | // Γ ⊢ f(x): T2 262 | // 263 | // 264 | // Lambda Abstraction 265 | // 266 | // If we assume x has type T1, and because of that we were able to infer e has type T2 267 | // then we can infer that the lambda abstraction of e with respect to the variable x, λx.e, 268 | // will be a function with type T1→T2 269 | // Γ, x: T1 ⊢ e: T2 270 | // ------------------- 271 | // Γ ⊢ λx.e: T1→T2 272 | // 273 | // Let 274 | // 275 | // If we can infer that e1 has type T1 and if we take x to have type T1 such that we could infer that e2 has type T2, 276 | // then we can infer that the result of letting x = e1 and substituting it into e2 has type T2 277 | // Γ, e1: T1 Γ, x: T1 ⊢ e2: T2 278 | // -------------------------------- 279 | // Γ ⊢ let x = e1 in e2: T2 280 | // 281 | func Infer(env Env, expr Expression) (*Scheme, error) { 282 | if expr == nil { 283 | return nil, errors.Errorf("Cannot infer a nil expression") 284 | } 285 | 286 | if env == nil { 287 | env = make(SimpleEnv) 288 | } 289 | 290 | infer := newInferer(env) 291 | if err := infer.consGen(expr); err != nil { 292 | return nil, err 293 | } 294 | 295 | s := newSolver() 296 | s.solve(infer.cs) 297 | 298 | if s.err != nil { 299 | return nil, s.err 300 | } 301 | 302 | if infer.t == nil { 303 | return nil, errors.Errorf("infer.t is nil") 304 | } 305 | 306 | t := infer.t.Apply(s.sub).(Type) 307 | return closeOver(t) 308 | } 309 | 310 | // Unify unifies the two types and 
returns a list of substitutions. 311 | // These are the rules: 312 | // 313 | // Type Constants and Type Constants 314 | // 315 | // Type constants (atomic types) have no substitution 316 | // c ~ c : [] 317 | // 318 | // Type Variables and Type Variables 319 | // 320 | // Type variables have no substitutions if there are no instances: 321 | // a ~ a : [] 322 | // 323 | // Default Unification 324 | // 325 | // if type variable 'a' is not in 'T', then unification is simple: replace all instances of 'a' with 'T' 326 | // a ∉ T 327 | // --------------- 328 | // a ~ T : [a/T] 329 | // 330 | func Unify(a, b Type) (sub Subs, err error) { 331 | logf("%v ~ %v", a, b) 332 | enterLoggingContext() 333 | defer leaveLoggingContext() 334 | 335 | switch at := a.(type) { 336 | case TypeVariable: 337 | return bind(at, b) 338 | default: 339 | if a.Eq(b) { 340 | return nil, nil 341 | } 342 | 343 | if btv, ok := b.(TypeVariable); ok { 344 | return bind(btv, a) 345 | } 346 | atypes := a.Types() 347 | btypes := b.Types() 348 | defer ReturnTypes(atypes) 349 | defer ReturnTypes(btypes) 350 | 351 | if len(atypes) == 0 && len(btypes) == 0 { 352 | goto e 353 | } 354 | 355 | return unifyMany(atypes, btypes) 356 | 357 | e: 358 | } 359 | err = errors.Errorf("Unification Fail: %v ~ %v cannot be unified", a, b) 360 | return 361 | } 362 | 363 | func unifyMany(a, b Types) (sub Subs, err error) { 364 | logf("UnifyMany %v %v", a, b) 365 | enterLoggingContext() 366 | defer leaveLoggingContext() 367 | 368 | if len(a) != len(b) { 369 | return nil, errors.Errorf("Unequal length. 
a: %v b %v", a, b) 370 | } 371 | 372 | for i, at := range a { 373 | bt := b[i] 374 | 375 | if sub != nil { 376 | at = at.Apply(sub).(Type) 377 | bt = bt.Apply(sub).(Type) 378 | } 379 | 380 | var s2 Subs 381 | if s2, err = Unify(at, bt); err != nil { 382 | return nil, err 383 | } 384 | 385 | if sub == nil { 386 | sub = s2 387 | } else { 388 | sub2 := compose(sub, s2) 389 | defer ReturnSubs(s2) 390 | if sub2 != sub { 391 | defer ReturnSubs(sub) 392 | } 393 | sub = sub2 394 | } 395 | } 396 | return 397 | } 398 | 399 | func bind(tv TypeVariable, t Type) (sub Subs, err error) { 400 | logf("Binding %v to %v", tv, t) 401 | switch { 402 | // case tv == t: 403 | case occurs(tv, t): 404 | err = errors.Errorf("recursive unification") 405 | default: 406 | ssub := BorrowSSubs(1) 407 | ssub.s[0] = Substitution{tv, t} 408 | sub = ssub 409 | } 410 | logf("Sub %v", sub) 411 | return 412 | } 413 | 414 | func occurs(tv TypeVariable, s Substitutable) bool { 415 | ftv := s.FreeTypeVar() 416 | defer ReturnTypeVarSet(ftv) 417 | 418 | return ftv.Contains(tv) 419 | } 420 | 421 | func closeOver(t Type) (sch *Scheme, err error) { 422 | sch = Generalize(nil, t) 423 | err = sch.Normalize() 424 | logf("closeoversch: %v", sch) 425 | return 426 | } 427 | --------------------------------------------------------------------------------