├── .gitignore ├── .ocamlformat ├── Makefile ├── README.md ├── algo_w ├── algo_w.ml ├── bin │ ├── dune │ └── main.ml ├── dune ├── expr.ml ├── infer.ml ├── lexer.mll ├── parser.mly └── test │ ├── dune │ └── test_infer.ml ├── dune-project ├── hmx ├── bin │ ├── dune │ └── main.ml ├── constraint.ml ├── debug.ml ├── dune ├── hmx.ml ├── infer.ml ├── lexer.mll ├── parser.mly ├── syntax.ml ├── test │ ├── dune │ └── test_infer.ml ├── type_error.ml ├── union_find.ml ├── union_find.mli ├── var.ml └── var.mli ├── hmx_tc ├── bin │ ├── dune │ └── main.ml ├── constraint.ml ├── debug.ml ├── dune ├── hmx_tc.ml ├── infer.ml ├── lexer.mll ├── parser.mly ├── syntax.ml ├── test │ ├── dune │ └── test_infer.ml ├── type_error.ml ├── union_find.ml ├── union_find.mli ├── var.ml └── var.mli └── type-systems.opam /.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | _opam 3 | -------------------------------------------------------------------------------- /.ocamlformat: -------------------------------------------------------------------------------- 1 | profile=conventional 2 | comment-check=true 3 | wrap-fun-args=true 4 | wrap-comments=false 5 | type-decl-indent=2 6 | type-decl=compact 7 | stritem-extension-indent=0 8 | space-around-variants=true 9 | space-around-records=true 10 | space-around-lists=true 11 | space-around-arrays=true 12 | single-case=compact 13 | sequence-style=terminator 14 | sequence-blank-line=preserve-one 15 | parse-docstrings=false 16 | parens-tuple-patterns=multi-line-only 17 | parens-tuple=always 18 | parens-ite=false 19 | ocp-indent-compat=false 20 | nested-match=wrap 21 | module-item-spacing=sparse 22 | max-indent=68 23 | match-indent-nested=never 24 | match-indent=0 25 | margin=80 26 | line-endings=lf 27 | let-module=compact 28 | let-binding-spacing=compact 29 | let-binding-indent=2 30 | let-and=sparse 31 | leading-nested-match-parens=false 32 | infix-precedence=indent 33 | indicate-nested-or-patterns=unsafe-no 
34 | indicate-multiline-delimiters=no 35 | indent-after-in=0 36 | if-then-else=compact 37 | function-indent-nested=never 38 | function-indent=2 39 | field-space=loose 40 | extension-indent=2 41 | exp-grouping=parens 42 | dock-collection-brackets=true 43 | doc-comments-tag-only=default 44 | doc-comments-padding=2 45 | doc-comments=after-when-possible 46 | disambiguate-non-breaking-match=false 47 | disable=false 48 | cases-matching-exp-indent=normal 49 | cases-exp-indent=2 50 | break-struct=force 51 | break-string-literals=auto 52 | break-sequences=true 53 | break-separators=after 54 | break-infix-before-func=false 55 | break-infix=fit-or-vertical 56 | break-fun-sig=smart 57 | break-fun-decl=smart 58 | break-collection-expressions=fit-or-vertical 59 | break-cases=all 60 | break-before-in=fit-or-vertical 61 | assignment-operator=end-line 62 | align-variants-decl=false 63 | align-constructors-decl=false 64 | align-cases=false 65 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build b 2 | build b: 3 | dune build 4 | # watch / w: run `dune build` in watch mode (rebuild on changes). 5 | .PHONY: watch w 6 | watch w: 7 | dune build --watch 8 | # clean: run `dune clean`. 9 | .PHONY: clean 10 | clean: 11 | dune clean 12 | # test / t: run all test suites via `dune runtest`. 13 | .PHONY: test t 14 | test t: 15 | dune runtest 16 | # init: create a local opam switch with OCaml 4.12.0 and install only this project's dependencies. 17 | .PHONY: init 18 | init: 19 | opam switch create . 4.12.0 20 | opam install . -y --deps-only 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # type-systems 2 | 3 | - `algo_w` - Algorithm W extended with "Extensible Records with Scoped Labels" 4 | and Multi Parameter Typeclasses. 5 | 6 | This follows closely [tomprimozic/type-systems] and then [THIH][] (Typing 7 | Haskell in Haskell).
Note that this tries to support multi-parameter typeclasses 8 | (MPTC) but generalization is currently buggy for typeclasses with multiple parameters. 9 | See the `hmx_tc` implementation for the bug-free HM + MPTC implementation. 10 | 11 | - `hmx` - [HM(X)][] style implementation of Hindley–Milner type inference. 12 | 13 | The main idea is to introduce a constraint language and split the algorithm into two 14 | phases — first generate constraints from terms, then solve those constraints. 15 | 16 | This implementation also does elaboration; the `infer` function has the 17 | following signature: 18 | ``` 19 | val infer : env:Env.t -> expr -> (expr, Error.t) result 20 | ``` 21 | That means that `infer` returns not just the type of the expression but an 22 | elaborated expression (the original expression annotated with types). 23 | 24 | The elaboration mechanism is shamelessly stolen from [inferno][]. 25 | 26 | - `hmx_tc` - extends [HM(X)][] with Multi-Parameter Typeclasses (MPTC). 27 | 28 | Type inference and elaboration are implemented, but the environment construction 29 | doesn't check for overlapping instances yet.
30 | 31 | # Development 32 | 33 | ``` 34 | make init 35 | make build 36 | make test 37 | ``` 38 | 39 | # References 40 | 41 | - https://en.wikipedia.org/wiki/Hindley–Milner_type_system 42 | - https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/scopedlabels.pdf 43 | - https://www.cs.tufts.edu/~nr/cs257/archive/martin-odersky/hmx.pdf 44 | - https://web.cecs.pdx.edu/~mpj/thih/thih.pdf 45 | - https://github.com/tomprimozic/type-systems 46 | - https://gitlab.inria.fr/fpottier/inferno/ 47 | - https://github.com/naominitel/hmx 48 | 49 | [HM(X)]: https://www.cs.tufts.edu/~nr/cs257/archive/martin-odersky/hmx.pdf 50 | [inferno]: https://gitlab.inria.fr/fpottier/inferno/ 51 | [THIH]: https://web.cecs.pdx.edu/~mpj/thih/thih.pdf 52 | [tomprimozic/type-systems]: https://github.com/tomprimozic/type-systems 53 | -------------------------------------------------------------------------------- /algo_w/algo_w.ml: -------------------------------------------------------------------------------- 1 | open!
Base 2 | module Expr = Expr 3 | module Infer = Infer 4 | (* Parsers built with [Nice_parser.Make] from the menhir-generated [Parser] and the ocamllex-generated [Lexer]; they differ only in the grammar entry point and the result type. *) 5 | module Expr_parser = Nice_parser.Make (struct 6 | type token = Parser.token 7 | 8 | type result = Expr.expr 9 | 10 | let parse = Parser.expr_eof 11 | 12 | let next_token = Lexer.token 13 | 14 | exception ParseError = Parser.Error 15 | 16 | exception LexError = Lexer.Error 17 | end) 18 | 19 | module Ty_parser = Nice_parser.Make (struct 20 | type token = Parser.token 21 | 22 | type result = Expr.qual_ty 23 | 24 | let parse = Parser.qual_ty_forall_eof 25 | 26 | let next_token = Lexer.token 27 | 28 | exception ParseError = Parser.Error 29 | 30 | exception LexError = Lexer.Error 31 | end) 32 | 33 | module Qual_pred_parser = Nice_parser.Make (struct 34 | type token = Parser.token 35 | 36 | type result = Expr.qual_pred 37 | 38 | let parse = Parser.qual_pred_eof 39 | 40 | let next_token = Lexer.token 41 | 42 | exception ParseError = Parser.Error 43 | 44 | exception LexError = Lexer.Error 45 | end) 46 | 47 | let parse_expr = Expr_parser.parse_string 48 | 49 | let parse_ty = Ty_parser.parse_string 50 | 51 | let parse_qual_pred = Qual_pred_parser.parse_string 52 | (* [infer_ty ?env expr] infers the qualified type of [expr] in [env] (empty by default) and returns [Error] instead of raising [Infer.Type_error]. *) 53 | let infer_ty ?(env = Infer.Env.empty) expr = 54 | try Ok (Infer.infer env expr) with 55 | | Infer.Type_error err -> Error err 56 | (* NOTE(review): presumably installs printers for the parse/lex exceptions (Nice_parser API) - confirm. *) 57 | let () = 58 | Expr_parser.pp_exceptions (); 59 | Ty_parser.pp_exceptions (); 60 | Qual_pred_parser.pp_exceptions () 61 | -------------------------------------------------------------------------------- /algo_w/bin/dune: -------------------------------------------------------------------------------- 1 | (executable 2 | (public_name algo_w) 3 | (name main) 4 | (libraries algo_w base stdio)) 5 | -------------------------------------------------------------------------------- /algo_w/bin/main.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | (* Demo driver: parse a program from stdin, infer its type in an environment of sample typeclasses, instances and primitives, and print it; on a type error print the error and exit with status 1. *) 3 | let () = 4 | let env = 5 | let assume name ty env = 6 | Algo_w.Infer.Env.add env name (Algo_w.parse_ty ty) 7 | in 8 | let
assume_typeclass qp env = 9 | let qp = Algo_w.parse_qual_pred qp in 10 | Algo_w.Infer.Env.add_typeclass env qp 11 | in 12 | let assume_instance qp witness env = 13 | let qp = Algo_w.parse_qual_pred qp in 14 | Algo_w.Infer.Env.add_instance env qp witness 15 | in 16 | Algo_w.Infer.Env.empty 17 | (* Show *) 18 | |> assume_typeclass "a . Show(a)" 19 | |> assume "show" "a . Show(a) => a -> string" 20 | |> assume "show_int" "int -> string" 21 | |> assume_instance "Show(int)" "show_int" 22 | |> assume "show_float" "float -> string" 23 | |> assume_instance "Show(float)" "show_float" 24 | (* Read *) 25 | |> assume_typeclass "a . Read(a)" 26 | |> assume "read" "a . Read(a) => string -> a" 27 | |> assume "read_int" "string -> int" 28 | |> assume_instance "Read(int)" "read_int" 29 | |> assume "read_float" "string -> float" 30 | |> assume_instance "Read(float)" "read_float" 31 | (* Eq *) 32 | |> assume_typeclass "a . Eq(a)" 33 | |> assume "eq" "a . Eq(a) => (a, a) -> bool" 34 | |> assume "eq_bool" "(bool, bool) -> bool" 35 | |> assume_instance "Eq(bool)" "eq_bool" 36 | |> assume "eq_int" "(int, int) -> bool" 37 | |> assume_instance "Eq(int)" "eq_int" 38 | |> assume "eq_list" "a . Eq(a) => (list[a], list[a]) -> bool" 39 | |> assume_instance "a . Eq(a) => Eq(list[a])" "eq_list" 40 | |> assume "eq_pair" "a, b . Eq(a), Eq(b) => (pair[a, b], pair[a, b]) -> bool" 41 | |> assume_instance "a, b . Eq(a), Eq(b) => Eq(pair[a, b])" "eq_pair" 42 | (* Ord *) 43 | |> assume_typeclass "a . Eq(a) => Ord(a)" 44 | |> assume "compare" "a . Ord(a) => (a, a) -> int" 45 | |> assume "compare_bool" "(bool, bool) -> int" 46 | |> assume_instance "Ord(bool)" "compare_bool" 47 | |> assume "compare_int" "(int, int) -> int" 48 | |> assume_instance "Ord(int)" "compare_int" 49 | |> assume "compare_list" "a . Ord(a) => (list[a], list[a]) -> int" 50 | |> assume_instance "a . Ord(a) => Ord(list[a])" "compare_list" 51 | |> assume "compare_pair" 52 | "a, b .
Ord(a), Ord(b) => (pair[a, b], pair[a, b]) -> int" 53 | |> assume_instance "a, b . Ord(a), Ord(b) => Ord(pair[a, b])" "compare_pair" 54 | (* Lists *) 55 | |> assume "head" "a . list[a] -> a" 56 | |> assume "tail" "a . list[a] -> list[a]" 57 | |> assume "nil" "a . list[a]" 58 | |> assume "cons" "a . (a, list[a]) -> list[a]" 59 | |> assume "cons_curry" "a . a -> list[a] -> list[a]" 60 | |> assume "map" "a, b . (a -> b, list[a]) -> list[b]" 61 | |> assume "map_curry" "a, b . (a -> b) -> list[a] -> list[b]" 62 | |> assume "fix" "a . (a -> a) -> a" 63 | |> assume "one" "int" 64 | |> assume "zero" "int" 65 | |> assume "succ" "int -> int" 66 | |> assume "plus" "(int, int) -> int" 67 | |> assume "eq_curry" "a . a -> a -> bool" 68 | |> assume "not" "bool -> bool" 69 | |> assume "true" "bool" 70 | |> assume "false" "bool" 71 | |> assume "pair" "a, b . (a, b) -> pair[a, b]" 72 | |> assume "pair_curry" "a, b . a -> b -> pair[a, b]" 73 | |> assume "first" "a, b . pair[a, b] -> a" 74 | |> assume "second" "a, b . pair[a, b] -> b" 75 | |> assume "id" "a . a -> a" 76 | |> assume "const" "a, b . a -> b -> a" 77 | |> assume "apply" "a, b . (a -> b, a) -> b" 78 | |> assume "apply_curry" "a, b . (a -> b) -> a -> b" 79 | |> assume "choose" "a . (a, a) -> a" 80 | |> assume "choose_curry" "a . a -> a -> a" 81 | |> assume "age" "int" 82 | |> assume "world" "string" 83 | |> assume "print" "string -> string" 84 | |> assume "print_user" "(string, int) -> string" 85 | in 86 | let prog = Algo_w.Expr_parser.parse_chan Stdio.stdin in 87 | match Algo_w.infer_ty ~env prog with 88 | | Ok qty -> Stdlib.Format.printf ": %s@." (Algo_w.Expr.show_qual_ty qty) 89 | | Error err -> 90 | Stdlib.Format.printf "ERROR: %s@."
(Algo_w.Infer.show_error err); 91 | Stdlib.exit 1 92 | -------------------------------------------------------------------------------- /algo_w/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name algo_w) 3 | (preprocess (pps ppx_sexp_conv)) 4 | (libraries base pprint nice_parser) 5 | ) 6 | 7 | (ocamllex lexer) 8 | 9 | (menhir 10 | (modules parser) 11 | (flags --explain --dump)) 12 | -------------------------------------------------------------------------------- /algo_w/expr.ml: -------------------------------------------------------------------------------- 1 | open! Base 2 | 3 | type name = string 4 | 5 | type lvl = int 6 | 7 | type id = int 8 | (* An unbound type variable: a unique [id] and a level [lvl]. Values are ordered by [id] only; [compare] ignores [lvl]. *) 9 | module Ty_var_unbound = struct 10 | type t = { id : id; lvl : lvl } 11 | 12 | include Comparator.Make (struct 13 | type nonrec t = t 14 | 15 | let sexp_of_t { id; lvl } = 16 | Sexp.List [ Sexp.Atom (Int.to_string id); Sexp.Atom (Int.to_string lvl) ] 17 | 18 | let compare a b = Int.compare a.id b.id 19 | end) 20 | end 21 | (* Surface syntax of expressions. *) 22 | type expr = 23 | | Expr_name of name 24 | | Expr_abs of name list * expr 25 | | Expr_app of expr * expr list 26 | | Expr_lit of lit 27 | | Expr_let of name * expr * expr 28 | | Expr_let_rec of name * expr * expr 29 | | Expr_record of (name * expr) list 30 | | Expr_record_proj of expr * name 31 | | Expr_record_extend of expr * (name * expr) list 32 | | Expr_record_update of expr * (name * expr) list 33 | 34 | and lit = Lit_string of string | Lit_int of int 35 | (* Types and rows. A variable is a mutable ref holding an unbound variable, a link to another type/row, or a generic (quantified) id. A predicate is a typeclass name applied to argument types; qualified types/predicates pair a context (list of predicates) with a type/predicate. *) 36 | type ty = 37 | | Ty_const of name 38 | | Ty_var of ty var ref 39 | | Ty_app of ty * ty list 40 | | Ty_arr of ty list * ty 41 | | Ty_record of ty_row 42 | 43 | and 'a var = 44 | | Ty_var_unbound of Ty_var_unbound.t 45 | | Ty_var_link of 'a 46 | | Ty_var_generic of id 47 | 48 | and ty_row = 49 | | Ty_row_field of string * ty * ty_row 50 | | Ty_row_empty 51 | | Ty_row_var of ty_row var ref 52 | | Ty_row_const of string 53 | 54 | and pred = string * ty list 55 | 56 | and qual_ty =
pred list * ty 57 | 58 | and qual_pred = pred list * pred 59 | (* Expressions that [layout_expr] prints without parentheses as the head of a record projection. *) 60 | let is_simple_expr = function 61 | | Expr_abs _ 62 | | Expr_let_rec _ 63 | | Expr_let _ 64 | | Expr_record _ 65 | | Expr_record_extend _ 66 | | Expr_record_update _ 67 | | Expr_lit _ -> 68 | false 69 | | Expr_app _ 70 | | Expr_name _ 71 | | Expr_record_proj _ -> 72 | true 73 | (* Ids of the generic type and row variables occurring in [ty]; links are followed. *) 74 | let generic_ty_vars ty = 75 | let rec of_ty set = function 76 | | Ty_const _ -> set 77 | | Ty_arr (args, ret) -> List.fold args ~init:(of_ty set ret) ~f:of_ty 78 | | Ty_app (f, args) -> List.fold args ~init:(of_ty set f) ~f:of_ty 79 | | Ty_var { contents = Ty_var_link ty } -> of_ty set ty 80 | | Ty_var { contents = Ty_var_unbound _ } -> set 81 | | Ty_var { contents = Ty_var_generic id } -> Set.add set id 82 | | Ty_record row -> of_ty_row set row 83 | and of_ty_row set = function 84 | | Ty_row_empty -> set 85 | | Ty_row_var { contents = Ty_var_link ty_row } -> of_ty_row set ty_row 86 | | Ty_row_var { contents = Ty_var_unbound _ } -> set 87 | | Ty_row_var { contents = Ty_var_generic id } -> Set.add set id 88 | | Ty_row_const _ -> set 89 | | Ty_row_field (_, _, ty_row) -> of_ty_row set ty_row 90 | in 91 | of_ty (Set.empty (module Int)) ty 92 | (* Pretty-print an expression as a PPrint document. *) 93 | let rec layout_expr = 94 | let open PPrint in 95 | function 96 | | Expr_name name -> string name 97 | | Expr_abs ([ arg ], body) -> 98 | string "fun " ^^ string arg ^^ string " -> " ^^ group (layout_expr body) 99 | | Expr_abs (args, body) -> 100 | let sep = comma ^^ blank 1 in 101 | string "fun " 102 | ^^ parens (separate sep (List.map args ~f:string)) 103 | ^^ string " -> " 104 | ^^ group (layout_expr body) 105 | | Expr_app (f, args) -> 106 | let sep = comma ^^ blank 1 in 107 | layout_expr f ^^ parens (separate sep (List.map args ~f:layout_expr)) 108 | | Expr_lit (Lit_string v) -> dquotes (string v) 109 | | Expr_lit (Lit_int v) -> (* printed bare so integers are not confused with string literals *) string (Int.to_string v) 110 | | Expr_let (n, e, b) -> 111 | surround 2 1 112 | (string "let " ^^ string n ^^ string " =") 113 |
(layout_expr e) 114 | (string "in " ^^ layout_expr b) 115 | | Expr_let_rec (n, e, b) -> 116 | surround 2 1 117 | (string "let rec " ^^ string n ^^ string " =") 118 | (layout_expr e) 119 | (string "in " ^^ layout_expr b) 120 | | Expr_record fields -> 121 | let sep = string ";" ^^ blank 1 in 122 | let f (n, e) = string n ^^ string " = " ^^ layout_expr e in 123 | surround 2 1 (string "{") (separate sep (List.map fields ~f)) (string "}") 124 | | Expr_record_proj (e, n) -> 125 | let head = 126 | if is_simple_expr e then layout_expr e else parens (layout_expr e) 127 | in 128 | head ^^ string "." ^^ string n 129 | | Expr_record_extend (e, fields) -> 130 | let sep = string ";" ^^ blank 1 in 131 | let f (n, e) = string n ^^ string " = " ^^ layout_expr e in 132 | surround 2 1 (string "{") 133 | (layout_expr e ^^ string " with " ^^ separate sep (List.map fields ~f)) 134 | (string "}") 135 | | Expr_record_update (e, fields) -> 136 | let sep = string ";" ^^ blank 1 in 137 | let f (n, e) = string n ^^ string " := " ^^ layout_expr e in 138 | surround 2 1 (string "{") 139 | (layout_expr e ^^ string " with " ^^ separate sep (List.map fields ~f)) 140 | (string "}") 141 | (* [make_names ()] returns a memoizing function that names variable ids in first-request order: "a" to "z", then "a1" to "z1", "a2", and so on. *) 142 | let make_names () = 143 | let names = Hashtbl.create (module Int) in 144 | let count = ref 0 in 145 | let genname () = 146 | let i = !count in 147 | Int.incr count; 148 | let name = 149 | String.make 1 (Char.of_int_exn (97 + Int.rem i 26)) 150 | ^ if i >= 26 then Int.to_string (i / 26) else "" 151 | in 152 | name 153 | in 154 | fun id -> Hashtbl.find_or_add names id ~default:genname 155 | (* Layout a type: generic variables are named via [lookup_name], unbound ones are printed as "_<id>", and links are followed. *) 156 | let layout_ty' ~lookup_name ty = 157 | let open PPrint in (* An arrow argument needs parentheses even when it is reached through a variable link. *) let rec is_arrow = function Ty_arr _ -> true | Ty_var { contents = Ty_var_link ty } -> is_arrow ty | _ -> false in 158 | let rec layout_ty = function 159 | | Ty_const name -> string name 160 | | Ty_arr ([ arg ], ret) when is_arrow arg -> 161 | parens (layout_ty arg) ^^ string " -> " ^^ layout_ty ret 162 | | Ty_arr ([ arg ], ret) -> layout_ty arg ^^ string " -> " ^^ layout_ty ret 163 | | Ty_arr (args, ret) -> 164 | let sep = comma ^^ blank 1 in 165 | parens (separate sep (List.map
args ~f:layout_ty)) 166 | ^^ string " -> " 167 | ^^ layout_ty ret 168 | | Ty_app (f, args) -> 169 | let sep = comma ^^ blank 1 in 170 | layout_ty f ^^ brackets (separate sep (List.map args ~f:layout_ty)) 171 | | Ty_var { contents } -> layout_ty_var layout_ty contents 172 | | Ty_record row -> 173 | let rec layout_ty_row = function 174 | | Ty_row_empty -> empty 175 | | Ty_row_field (name, ty, row) -> ( 176 | string name 177 | ^^ string ": " 178 | ^^ layout_ty ty 179 | ^^ 180 | match row with 181 | | Ty_row_empty -> string ";" 182 | | row -> string "; " ^^ layout_ty_row row) 183 | | Ty_row_var { contents } -> layout_ty_var layout_ty_row contents 184 | | Ty_row_const _ -> assert false 185 | in 186 | surround 2 1 (string "{") (layout_ty_row row) (string "}") 187 | and layout_ty_var : 'a. ('a -> document) -> 'a var -> document = 188 | fun layout_v -> function 189 | | Ty_var_link v -> layout_v v 190 | | Ty_var_generic id -> string (lookup_name id) 191 | | Ty_var_unbound { id; lvl = _ } -> string ("_" ^ Int.to_string id) 192 | in 193 | layout_ty ty 194 | (* Layout the quantifier prefix "a, b . " for the generic ids in [ty_vars]; empty when the set is empty. *) 195 | let layout_forall' ~lookup_name ty_vars = 196 | let open PPrint in 197 | if Set.is_empty ty_vars then empty 198 | else 199 | let layout_ty_var id = string (lookup_name id) in 200 | let sep = comma ^^ blank 1 in 201 | separate sep (Set.to_list ty_vars |> List.map ~f:layout_ty_var) 202 | ^^ blank 1 203 | ^^ dot 204 | ^^ blank 1 205 | (* Layout a predicate as [Name(arg, ...)]. *) 206 | let layout_pred' ~lookup_name (name, args) = 207 | let open PPrint in 208 | let sep = comma ^^ blank 1 in 209 | string name 210 | ^^ parens (separate sep (List.map args ~f:(layout_ty' ~lookup_name))) 211 | (* Layout a type prefixed by the quantifier over its generic variables, using fresh names. *) 212 | let layout_ty ty = 213 | let lookup_name = make_names () in 214 | let ty_vars = generic_ty_vars ty in 215 | let open PPrint in 216 | layout_forall' ~lookup_name ty_vars ^^ layout_ty' ~lookup_name ty 217 | (* Layout a single predicate, naming its generic variables afresh. *) 218 | let layout_pred p = 219 | let lookup_name = make_names () in 220 | layout_pred' ~lookup_name p 221 | (* Layout a qualified type as [vars . preds => ty], quantifying over the generic variables of both the predicates and the type. *) 222 | let layout_qual_ty qty = 223 | let open PPrint in 224 | let
lookup_name = make_names () in 225 | let ty_vars = 226 | let ps, ty = qty in 227 | let init = generic_ty_vars ty in 228 | List.fold ps ~init ~f:(fun init (_, args) -> 229 | List.fold args ~init ~f:(fun set ty -> 230 | Set.union set (generic_ty_vars ty))) 231 | in 232 | layout_forall' ~lookup_name ty_vars 233 | ^^ 234 | match qty with 235 | | [], ty -> layout_ty' ~lookup_name ty 236 | | cs, ty -> 237 | let sep = comma ^^ blank 1 in 238 | separate sep (List.map cs ~f:(layout_pred' ~lookup_name)) 239 | ^^ string " => " 240 | ^^ layout_ty' ~lookup_name ty 241 | (* Layout a qualified predicate [ps => p]; unlike [layout_qual_ty], no quantifier prefix is printed. *) 242 | let layout_qual_pred qp = 243 | let lookup_name = make_names () in 244 | let open PPrint in 245 | match qp with 246 | | [], pred -> layout_pred' ~lookup_name pred 247 | | cs, pred -> 248 | let sep = comma ^^ blank 1 in 249 | separate sep (List.map cs ~f:(layout_pred' ~lookup_name)) 250 | ^^ string " => " 251 | ^^ layout_pred' ~lookup_name pred 252 | (* Generic printers over a layout function: [pp'] renders to a formatter at width 80, [show'] to a string (width 80 by default), [print'] to stdout followed by a newline. *) 253 | let pp' doc ppf v = PPrint.ToFormatter.pretty 1. 80 ppf (doc v) 254 | 255 | let show' ?(width = 80) doc v = 256 | let buf = Buffer.create 100 in 257 | PPrint.ToBuffer.pretty 1.
width buf (doc v); 258 | Buffer.contents buf 259 | 260 | let print' doc v = Stdlib.print_endline (show' doc v) 261 | 262 | let pp_expr = pp' layout_expr 263 | 264 | let show_expr = show' layout_expr 265 | 266 | let print_expr = print' layout_expr 267 | 268 | let pp_qual_ty = pp' layout_qual_ty 269 | 270 | let show_qual_ty = show' layout_qual_ty 271 | 272 | let print_qual_ty = print' layout_qual_ty 273 | 274 | let pp_pred = pp' layout_pred 275 | 276 | let show_pred = show' layout_pred 277 | 278 | let print_pred = print' layout_pred 279 | 280 | let pp_qual_pred = pp' layout_qual_pred 281 | 282 | let show_qual_pred = show' layout_qual_pred 283 | 284 | let print_qual_pred = print' layout_qual_pred 285 | -------------------------------------------------------------------------------- /algo_w/infer.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Expr 3 | (* An instance declaration together with the name of the value that witnesses it. *) 4 | module Instance = struct 5 | type t = { instance : qual_pred; witness : String.t } 6 | end 7 | (* A typeclass declaration (superclasses => head) and its registered instances, most recently added first. *) 8 | module Typeclass = struct 9 | type t = { typeclass : qual_pred; instances : Instance.t list } 10 | end 11 | (* The typing environment: types of named values plus typeclass declarations, both keyed by name. *) 12 | module Env : sig 13 | type t 14 | 15 | (* Construction API. *) 16 | 17 | val empty : t 18 | 19 | val add : t -> String.t -> qual_ty -> t 20 | 21 | val add_typeclass : t -> qual_pred -> t 22 | (* Raises [Failure] if the instance's typeclass has not been added first. *) 23 | val add_instance : t -> qual_pred -> String.t -> t 24 | 25 | (* Query API.
*) 26 | 27 | val find : t -> String.t -> qual_ty option 28 | (* [dependencies env cls] returns the declarations of [cls]'s direct superclasses and [instances env cls] the heads of its instances; both raise if [cls] is not declared. *) 29 | val dependencies : t -> String.t -> qual_pred list 30 | 31 | val instances : t -> String.t -> qual_pred list 32 | end = struct 33 | type t = { 34 | env : (String.t, qual_ty, String.comparator_witness) Map.t; 35 | typeclasses : (String.t, Typeclass.t, String.comparator_witness) Map.t; 36 | } 37 | 38 | let empty = 39 | { env = Map.empty (module String); typeclasses = Map.empty (module String) } 40 | 41 | let add env name qty = { env with env = Map.set env.env ~key:name ~data:qty } 42 | (* NOTE(review): re-adding an existing class replaces it and drops its instances. *) 43 | let add_typeclass env (qp : Expr.qual_pred) = 44 | (* TODO: add checks *) 45 | let _, (name, _) = qp in 46 | { 47 | env with 48 | typeclasses = 49 | Map.set env.typeclasses ~key:name 50 | ~data:{ typeclass = qp; instances = [] }; 51 | } 52 | 53 | let add_instance env qp witness = 54 | (* TODO: add checks *) 55 | let _, (name, _) = qp in 56 | let cls = 57 | match Map.find env.typeclasses name with 58 | | None -> failwith (Printf.sprintf "no such typeclass %s" name) 59 | | Some cls -> cls 60 | in 61 | let cls = 62 | { cls with instances = { instance = qp; witness } :: cls.instances } 63 | in 64 | { env with typeclasses = Map.set env.typeclasses ~key:name ~data:cls } 65 | 66 | let dependencies env id = 67 | let cls = Map.find_exn env.typeclasses id in 68 | List.map (fst cls.typeclass) ~f:(fun (name, _) -> 69 | let dep = Map.find_exn env.typeclasses name in 70 | dep.typeclass) 71 | 72 | let instances env id = 73 | let cls = Map.find_exn env.typeclasses id in 74 | List.map cls.instances ~f:(fun instance -> instance.instance) 75 | 76 | let find env = Map.find env.env 77 | end 78 | (* Errors reported through the [Type_error] exception. *) 79 | type error = 80 | | Error_unification of ty * ty 81 | | Error_recursive_types 82 | | Error_recursive_row_types 83 | | Error_not_a_function of ty 84 | | Error_unknown_name of string 85 | | Error_arity_mismatch of ty * int * int 86 | | Error_missing_typeclass_instance of pred 87 | | Error_ambigious_predicate of pred 88 | 89 | exception
Type_error of error 90 | 91 | let type_error err = raise (Type_error err) 92 | (* Human-readable layout of a type error. *) 93 | let layout_error = 94 | PPrint.( 95 | function 96 | | Error_recursive_types -> string "recursive types" 97 | | Error_recursive_row_types -> string "recursive row types" 98 | | Error_not_a_function ty -> 99 | string "expected a function but got:" ^^ nest 2 (break 1 ^^ layout_ty ty) 100 | | Error_unknown_name name -> string "unknown name: " ^^ string name 101 | | Error_arity_mismatch (ty, expected, got) -> 102 | string "arity mismatch: expected " 103 | ^^ string (Int.to_string expected) 104 | ^^ string " arguments but got " 105 | ^^ string (Int.to_string got) 106 | ^^ nest 2 (break 1 ^^ layout_ty ty) 107 | | Error_unification (ty1, ty2) -> 108 | string "unification error of" 109 | ^^ nest 2 (break 1 ^^ layout_ty ty1) 110 | ^^ (break 1 ^^ string "with") 111 | ^^ nest 2 (break 1 ^^ layout_ty ty2) 112 | | Error_missing_typeclass_instance p -> 113 | string "missing typeclass instance: " ^^ layout_pred p 114 | | Error_ambigious_predicate p -> 115 | string "ambigious predicate: " ^^ layout_pred p) 116 | 117 | let pp_error = pp' layout_error 118 | 119 | let show_error = show' layout_error 120 | (* Fresh variables numbered from one global counter; after [reset_vars] the next id handed out is 1. *) 121 | module Vars : sig 122 | val newvar : lvl -> unit -> ty 123 | 124 | val newrowvar : lvl -> unit -> ty_row 125 | 126 | val reset_vars : unit -> unit 127 | 128 | val newgenvar : unit -> ty 129 | end = struct 130 | let currentid = ref 0 131 | 132 | let reset_vars () = currentid := 0 133 | 134 | let newid () = 135 | Int.incr currentid; 136 | !currentid 137 | 138 | let newvar lvl () = 139 | Ty_var { contents = Ty_var_unbound { id = newid (); lvl } } 140 | 141 | let newrowvar lvl () = 142 | Ty_row_var { contents = Ty_var_unbound { id = newid (); lvl } } 143 | 144 | let newgenvar () = Ty_var { contents = Ty_var_generic (newid ()) } 145 | end 146 | 147 | include Vars 148 | 149 | (** Instantiation of type schemas into types.
150 | 151 | This is done by replacing all generic type variables with fresh unbound type 152 | variables. 153 | 154 | *) 155 | module Instantiate : sig 156 | val instantiate_qual_ty : lvl -> qual_ty -> qual_ty 157 | 158 | val instantiate_qual_pred : lvl -> qual_pred -> qual_pred 159 | 160 | val instantiate_pred : lvl -> pred -> pred 161 | end = struct 162 | type ctx = { 163 | lvl : Int.t; 164 | vars : (Int.t, ty) Hashtbl.t; 165 | rowvars : (Int.t, ty_row) Hashtbl.t; 166 | } 167 | 168 | let make_ctx lvl = 169 | { 170 | lvl; 171 | vars = Hashtbl.create (module Int); 172 | rowvars = Hashtbl.create (module Int); 173 | } 174 | 175 | let rec instantiate_ty' ctx (ty : ty) : ty = 176 | match ty with 177 | | Ty_const _ -> ty 178 | | Ty_arr (ty_args, ty_ret) -> 179 | Ty_arr 180 | (List.map ty_args ~f:(instantiate_ty' ctx), instantiate_ty' ctx ty_ret) 181 | | Ty_app (ty, ty_args) -> 182 | Ty_app (instantiate_ty' ctx ty, List.map ty_args ~f:(instantiate_ty' ctx)) 183 | | Ty_var { contents = Ty_var_link ty } -> instantiate_ty' ctx ty 184 | | Ty_var { contents = Ty_var_unbound _ } -> ty 185 | | Ty_var { contents = Ty_var_generic id } -> 186 | Hashtbl.find_or_add ctx.vars id ~default:(newvar ctx.lvl) 187 | | Ty_record row -> Ty_record (instantiate_ty_row' ctx row) 188 | 189 | and instantiate_ty_row' ctx (ty_row : ty_row) = 190 | match ty_row with 191 | | Ty_row_field (name, ty, ty_row) -> 192 | Ty_row_field (name, instantiate_ty' ctx ty, instantiate_ty_row' ctx ty_row) 193 | | Ty_row_empty -> ty_row 194 | | Ty_row_var { contents = Ty_var_link ty_row } -> 195 | instantiate_ty_row' ctx ty_row 196 | | Ty_row_var { contents = Ty_var_unbound _ } -> ty_row 197 | | Ty_row_var { contents = Ty_var_generic id } -> 198 | Hashtbl.find_or_add ctx.rowvars id ~default:(newrowvar ctx.lvl) 199 | | Ty_row_const _ -> assert false 200 | 201 | and instantiate_pred' ctx (name, args) = 202 | (name, List.map args ~f:(instantiate_ty' ctx)) 203 | 204 | and instantiate_qual_ty' ctx qty = 205 | let preds, 
ty = qty in 206 | (List.map preds ~f:(instantiate_pred' ctx), instantiate_ty' ctx ty) 207 | 208 | let instantiate_pred lvl p = 209 | let ctx = make_ctx lvl in 210 | instantiate_pred' ctx p 211 | 212 | let instantiate_qual_ty lvl qty = 213 | let ctx = make_ctx lvl in 214 | instantiate_qual_ty' ctx qty 215 | 216 | let instantiate_qual_pred lvl (ps, p) = 217 | let ctx = make_ctx lvl in 218 | (List.map ps ~f:(instantiate_pred' ctx), instantiate_pred' ctx p) 219 | end 220 | 221 | include Instantiate 222 | 223 | module Pred_solver : sig 224 | val solve_preds : 225 | lvl -> 226 | Env.t -> 227 | (Ty_var_unbound.t, Ty_var_unbound.comparator_witness) Set.t -> 228 | pred list -> 229 | pred list * pred list 230 | (** Solve a set of predicates. 231 | 232 | This raises a [Type_error] in case it cannot find a suitable instance for 233 | a ground predicate or if a predicate is ambigious. 234 | 235 | The function returns a pair of predicate sets [deferred, retained] where 236 | [retained] should be generalized while [deferred] should be propagated 237 | upwards. 238 | *) 239 | end = struct 240 | let match_ty ty1 ty2 = 241 | (* invariant: this destructs only ty1 *) 242 | (* TODO: handle closed record types. 
*) 243 | let rec aux ty1 ty2 : bool = 244 | if phys_equal ty1 ty2 then true 245 | else 246 | match (ty1, ty2) with 247 | | ty1, Ty_var { contents = Ty_var_link ty2 } -> aux ty1 ty2 248 | | Ty_app (f1, args1), Ty_app (f2, args2) -> 249 | aux f1 f2 250 | && List.length args1 = List.length args2 251 | && List.for_all2_exn args1 args2 ~f:(fun ty1 ty2 -> aux ty1 ty2) 252 | | Ty_var { contents = Ty_var_link ty1 }, ty2 -> aux ty1 ty2 253 | | Ty_var ({ contents = Ty_var_unbound _ } as var), ty2 -> 254 | var := Ty_var_link ty2; 255 | true 256 | | Ty_var { contents = Ty_var_generic _ }, _ -> 257 | failwith "uninstantiated type variable" 258 | | Ty_const name1, Ty_const name2 -> String.(name1 = name2) 259 | | _, _ -> false 260 | in 261 | aux ty1 ty2 262 | 263 | let match_pred (name1, args1) (name2, args2) = 264 | if not String.(name1 = name2) then false 265 | else 266 | let rec aux args1 args2 = 267 | match (args1, args2) with 268 | | [], [] -> true 269 | | [], _ -> false 270 | | _, [] -> false 271 | | a1 :: args1, a2 :: args2 -> match_ty a1 a2 && aux args1 args2 272 | in 273 | aux args1 args2 274 | 275 | let entailments_of_dependencies _lvl env pred = 276 | (* TODO: need to return a list of all things here *) 277 | let rec aux entailments pred = 278 | let dependencies = Env.dependencies env (fst pred) in 279 | List.fold dependencies ~init:(pred :: entailments) 280 | ~f:(fun entailments dep -> aux entailments (snd dep)) 281 | in 282 | aux [] pred 283 | 284 | (* Try each instance of the class and on first match return the list of 285 | dependencies. 286 | 287 | We are looking for the first match becuase we are supposed to have 288 | non-overlapping instances in the environment (that's a TODO to enforce this 289 | invatiant on environment construction). 
*) 290 | let entailments_of_instances lvl env pred = 291 | let rec aux = function 292 | | [] -> None 293 | | q :: qs -> 294 | let deps', pred' = instantiate_qual_pred lvl q in 295 | if match_pred pred' pred then Some deps' else aux qs 296 | in 297 | aux (Env.instances env (fst pred)) 298 | 299 | (* Entailment relation between predicates. 300 | 301 | [entail lvl env ps p] returns [true] in case predicates [ps] are enough to 302 | establish [p] predicate. *) 303 | let rec entail lvl env ps p = 304 | let rec inspect_dependencies = function 305 | | [] -> false 306 | | q :: qs -> 307 | let deps = entailments_of_dependencies lvl env q in 308 | List.exists deps ~f:(fun dep -> 309 | let dep = instantiate_pred lvl dep in 310 | match_pred dep p) 311 | || inspect_dependencies qs 312 | in 313 | inspect_dependencies ps 314 | || 315 | match entailments_of_instances lvl env p with 316 | | None -> false 317 | | Some qs -> List.for_all qs ~f:(fun q -> entail lvl env ps q) 318 | 319 | (* Check that a predicate in a head normal form (HNF). 320 | 321 | A predicate is in HNF if all its arguments are type variables (this HNF 322 | definition is specific for languages with first order polymorphism only). *) 323 | let is_hnf (_name, args) = 324 | let rec aux = function 325 | | Ty_var { contents = Ty_var_link ty } -> aux ty 326 | | Ty_var { contents = Ty_var_generic _ } -> assert false 327 | | Ty_var { contents = Ty_var_unbound _ } -> true 328 | | Ty_app _ -> false 329 | | Ty_arr _ -> false 330 | | Ty_const _ -> false 331 | | Ty_record _ -> false 332 | in 333 | List.for_all args ~f:aux 334 | 335 | (* Try to convert a predicate into a HNF. 336 | 337 | Raises a type error if some instances are missing. 
*) 338 | let rec to_hnf lvl env p = 339 | if is_hnf p then [ p ] 340 | else 341 | match entailments_of_instances lvl env p with 342 | | None -> type_error (Error_missing_typeclass_instance p) 343 | | Some ps -> to_hnfs lvl env ps 344 | 345 | (* Convert every predicate of [ps] to HNF and concatenate the results, preserving order. *) and to_hnfs lvl env ps = List.concat_map ps ~f:(to_hnf lvl env) 346 | 347 | (* Simplify a list of predicates. 348 | 349 | Simplification is performed by removing those predicates which can be 350 | inferred from other predicates in the same list (for which an entailment 351 | relation holds). Kept predicates are accumulated by prepending, so they come back in reverse order. *) 352 | let simplify lvl env ps = 353 | let rec aux simplified = function 354 | | [] -> simplified 355 | | p :: ps -> 356 | if entail lvl env (simplified @ ps) p then aux simplified ps 357 | else aux (p :: simplified) ps 358 | in 359 | aux [] ps 360 | 361 | (* Reduce a list of predicates: convert them to HNF, then drop those entailed by the others. *) 362 | let reduce lvl env ps = 363 | let ps = to_hnfs lvl env ps in 364 | simplify lvl env ps 365 | (* Return the unbound type variables that make up the arguments of the HNF predicate [p], following links; fails if an argument is not an unbound variable. *) 366 | let ty_vars ((_name, args) as p) = 367 | let rec inspect = function 368 | | Ty_var { contents = Ty_var_unbound tv } -> tv 369 | | Ty_var { contents = Ty_var_link ty } -> inspect ty 370 | | _ -> failwith (Printf.sprintf "predicate not in HNF: %s" (show_pred p)) 371 | in 372 | List.map args ~f:inspect 373 | (* Reduce [ps] and split the result into [(deferred, retained)]. A predicate whose variables all have a level of at most [lvl] is deferred to an enclosing scope; any other predicate is retained and must only mention variables from [vars], otherwise [Error_ambigious_predicate] is raised. *) 374 | let solve_preds lvl env vars ps = 375 | let ps = reduce lvl env ps in 376 | let should_defer p = 377 | List.for_all (ty_vars p) ~f:(fun tv -> tv.lvl <= lvl) 378 | in 379 | let rec aux (deferred, retained) = function 380 | | [] -> (deferred, retained) 381 | | p :: ps -> 382 | if should_defer p then aux (p :: deferred, retained) ps 383 | else 384 | let not_in_vars tv = not (Set.mem vars tv) in 385 | if List.exists (ty_vars p) ~f:not_in_vars then 386 | type_error (Error_ambigious_predicate p); 387 | aux (deferred, p :: retained) ps 388 | in 389 | aux ([], []) ps 390 | end 391 | 392 | include Pred_solver 393 | (* Generalize the qualified type [qty] at level [lvl]: unbound variables with a level greater than [lvl] become generic, the predicates are reduced and split by [solve_preds], and only the retained ones are generalized; deferred predicates are returned unchanged in front. *) 394 | let generalize lvl env (qty : qual_ty) = 395 | let generalize_ty ty = 396 | (* Along with generalizing the type we also find
all unbound type variables 397 | which we later use to check predicates for ambiguity. *) 398 | let seen = ref (Set.empty (module Ty_var_unbound)) in 399 | let mark id = Ref.replace seen (fun seen -> Set.add seen id) in 400 | let rec generalize_ty ty = 401 | match ty with 402 | | Ty_const _ -> ty 403 | | Ty_arr (ty_args, ty_ret) -> 404 | Ty_arr (List.map ty_args ~f:generalize_ty, generalize_ty ty_ret) 405 | | Ty_app (ty, ty_args) -> 406 | Ty_app (generalize_ty ty, List.map ty_args ~f:generalize_ty) 407 | | Ty_var { contents = Ty_var_link ty } -> generalize_ty ty 408 | | Ty_var { contents = Ty_var_generic _ } -> ty 409 | | Ty_var { contents = Ty_var_unbound tv } -> 410 | (* Every unbound variable is recorded whatever its level; only those above [lvl] become generic. *) mark tv; 411 | if tv.lvl > lvl then Ty_var { contents = Ty_var_generic tv.id } else ty 412 | | Ty_record row -> Ty_record (generalize_ty_row row) 413 | and generalize_ty_row (ty_row : ty_row) = 414 | match ty_row with 415 | | Ty_row_field (name, ty, row) -> 416 | Ty_row_field (name, generalize_ty ty, generalize_ty_row row) 417 | | Ty_row_empty -> ty_row 418 | | Ty_row_var { contents = Ty_var_link ty_row } -> generalize_ty_row ty_row 419 | | Ty_row_var { contents = Ty_var_generic _ } -> ty_row 420 | | Ty_row_var { contents = Ty_var_unbound { id; lvl = var_lvl } } -> 421 | (* Row variables are generalized by level but are not recorded in [seen]. *) if var_lvl > lvl then Ty_row_var { contents = Ty_var_generic id } 422 | else ty_row 423 | | Ty_row_const _ -> assert false 424 | in 425 | let ty = generalize_ty ty in 426 | (ty, !seen) 427 | in 428 | (* Each call to [generalize_ty] starts with a fresh [seen] set, which is discarded here by [fst]. *) let generalize_pred (name, args) = 429 | let args = List.map args ~f:(fun ty -> fst (generalize_ty ty)) in 430 | (name, args) 431 | in 432 | let ps, ty = qty in 433 | let ty, vars = generalize_ty ty in 434 | let deferred, retained = solve_preds lvl env vars ps in 435 | (deferred @ List.map retained ~f:generalize_pred, ty) 436 | 437 | (* [occurs_check lvl id ty] walks [ty] (including record rows), raising [Error_recursive_types] if it meets the unbound variable [id], and lowering to [lvl] the level of any other unbound variable whose level is greater than [lvl]. Generic variables are ignored. *) let occurs_check lvl id ty = 438 | let rec occurs_check_ty (ty : ty) : unit = 439 | match ty with 440 | | Ty_const _ -> () 441 | | Ty_arr (args, ret) -> 442 | List.iter args ~f:occurs_check_ty; 443 |
occurs_check_ty ret 444 | | Ty_app (f, args) -> 445 | occurs_check_ty f; 446 | List.iter args ~f:occurs_check_ty 447 | | Ty_var { contents = Ty_var_link ty } -> occurs_check_ty ty 448 | | Ty_var { contents = Ty_var_generic _ } -> () 449 | | Ty_var ({ contents = Ty_var_unbound v } as var) -> 450 | if v.id = id then type_error Error_recursive_types 451 | else if lvl < v.lvl then var := Ty_var_unbound { id = v.id; lvl } 452 | else () 453 | | Ty_record ty_row -> occurs_check_ty_row ty_row 454 | and occurs_check_ty_row (ty_row : ty_row) : unit = 455 | match ty_row with 456 | | Ty_row_field (_name, ty, ty_row) -> 457 | occurs_check_ty ty; 458 | occurs_check_ty_row ty_row 459 | | Ty_row_empty -> () 460 | | Ty_row_var { contents = Ty_var_link ty_row } -> occurs_check_ty_row ty_row 461 | | Ty_row_var { contents = Ty_var_generic _ } -> () 462 | | Ty_row_var ({ contents = Ty_var_unbound v } as var) -> 463 | if v.id = id then type_error Error_recursive_types 464 | else if lvl < v.lvl then var := Ty_var_unbound { id = v.id; lvl } 465 | else () 466 | | Ty_row_const _ -> assert false 467 | in 468 | occurs_check_ty ty 469 | 470 | (* [unify ty1 ty2] makes [ty1] and [ty2] equal by destructively linking unbound type variables (after an occurs check), raising a type error when the two types cannot be made equal. *) let rec unify (ty1 : ty) (ty2 : ty) = 471 | if phys_equal ty1 ty2 then () 472 | else 473 | match (ty1, ty2) with 474 | | Ty_const name1, Ty_const name2 -> 475 | if not String.(name1 = name2) then 476 | type_error (Error_unification (ty1, ty2)) 477 | | Ty_arr (args1, ret1), Ty_arr (args2, ret2) -> ( 478 | match List.iter2 args1 args2 ~f:unify with 479 | | Unequal_lengths -> 480 | type_error 481 | (Error_arity_mismatch (ty1, List.length args2, List.length args1)) 482 | | Ok () -> unify ret1 ret2) 483 | | Ty_app (f1, args1), Ty_app (f2, args2) -> ( 484 | unify f1 f2; 485 | (* Report an argument-count mismatch as a type error instead of letting [List.iter2_exn] raise [Invalid_argument]. *) match List.iter2 args1 args2 ~f:unify with | Unequal_lengths -> type_error (Error_unification (ty1, ty2)) | Ok () -> ()) 486 | | Ty_record row1, Ty_record row2 -> unify_row row1 row2 487 | | Ty_var { contents = Ty_var_link ty1 }, ty2 488 | | ty1, Ty_var { contents = Ty_var_link ty2 } -> 489 | unify ty1 ty2 490 | | Ty_var ({ contents = Ty_var_unbound { id; lvl } } as var), ty
491 | | ty, Ty_var ({ contents = Ty_var_unbound { id; lvl } } as var) -> 492 | occurs_check lvl id ty; 493 | var := Ty_var_link ty 494 | | ty1, ty2 -> type_error (Error_unification (ty1, ty2)) 495 | (* [unify_row row1 row2] unifies two record rows. When both start with a field, [rewrite] looks that field up in [row2], unifies the two field types and returns [row2] without it; if it reaches an unbound row variable instead, that variable is extended with the field. The remaining rows are then unified recursively. *) 496 | and unify_row row1 row2 = 497 | if phys_equal row1 row2 then () 498 | else 499 | match (row1, row2) with 500 | | Ty_row_empty, Ty_row_empty -> () 501 | | Ty_row_field (name, ty, row1), Ty_row_field _ -> 502 | let exception Row_rewrite_error in 503 | let rec rewrite = function 504 | | Ty_row_empty -> raise Row_rewrite_error 505 | | Ty_row_field (name', ty', row') -> 506 | if String.(name = name') then ( 507 | unify ty ty'; 508 | row') 509 | else Ty_row_field (name', ty', rewrite row') 510 | | Ty_row_var { contents = Ty_var_link row' } -> rewrite row' 511 | | Ty_row_var ({ contents = Ty_var_unbound { id = _; lvl } } as var) -> 512 | let row' = newrowvar lvl () in 513 | var := Ty_var_link (Ty_row_field (name, ty, row')); 514 | row' 515 | | Ty_row_var { contents = Ty_var_generic _ } -> 516 | failwith "non instantiated row variable" 517 | | Ty_row_const _ -> assert false 518 | in 519 | (* If [rewrite] links the unbound tail of [row1], both rows shared that variable and the result would be a recursive row; this is checked right after the rewrite. *) let row1_unbound = 520 | match row1 with 521 | | Ty_row_var ({ contents = Ty_var_unbound _ } as var) -> Some var 522 | | _ -> None 523 | in 524 | let row2 = 525 | try rewrite row2 with 526 | | Row_rewrite_error -> 527 | (* NOTE(review): [row1] here is the tail bound by the pattern above rather than the whole first row, so the reported type is truncated. *) type_error (Error_unification (Ty_record row1, Ty_record row2)) 528 | in 529 | (match row1_unbound with 530 | | Some { contents = Ty_var_link _ } -> 531 | type_error Error_recursive_row_types 532 | | _ -> ()); 533 | unify_row row1 row2 534 | | Ty_row_var { contents = Ty_var_link row1 }, row2 535 | | row2, Ty_row_var { contents = Ty_var_link row1 } -> 536 | unify_row row1 row2 537 | | Ty_row_var ({ contents = Ty_var_unbound { id; lvl } } as var), row 538 | | row, Ty_row_var ({ contents = Ty_var_unbound { id; lvl } } as var) -> 539 | occurs_check lvl id (Ty_record row); 540 | var := Ty_var_link row 541 | | row1, row2 -> 542 | type_error (Error_unification (Ty_record row1,
Ty_record row2)) 543 | 544 | (* [unify_abs arity ty] views [ty] as a function of [arity] arguments and returns its argument and result types. An unbound variable is linked to a fresh arrow whose variables share its level; an arrow of another arity raises [Error_arity_mismatch]; any other type raises [Error_not_a_function]. *) let rec unify_abs arity ty = 545 | match ty with 546 | | Ty_arr (ty_args, ty_ret) -> 547 | if List.length ty_args <> arity then 548 | type_error (Error_arity_mismatch (ty, List.length ty_args, arity)); 549 | (ty_args, ty_ret) 550 | | Ty_var var -> ( 551 | match !var with 552 | | Ty_var_link ty -> unify_abs arity ty 553 | | Ty_var_unbound v -> 554 | let ty_ret = newvar v.lvl () in 555 | let ty_args = List.init arity ~f:(fun _ -> newvar v.lvl ()) in 556 | var := Ty_var_link (Ty_arr (ty_args, ty_ret)); 557 | (ty_args, ty_ret) 558 | | Ty_var_generic _ -> failwith "uninstantiated generic type") 559 | | Ty_app _ 560 | | Ty_const _ 561 | | Ty_record _ -> 562 | type_error (Error_not_a_function ty) 563 | 564 | (* [infer' lvl env e] infers [e] at level [lvl] and returns the predicates collected along the way together with the inferred type. *) let rec infer' lvl (env : Env.t) (e : expr) = 565 | match e with 566 | | Expr_name name -> 567 | let qty = 568 | match Env.find env name with 569 | | Some ty -> ty 570 | | None -> type_error (Error_unknown_name name) 571 | in 572 | instantiate_qual_ty lvl qty 573 | | Expr_abs (args, body) -> 574 | let ty_args = List.map args ~f:(fun _ -> newvar lvl ()) in 575 | let cs, ty_body = 576 | let env = 577 | List.fold_left (List.zip_exn args ty_args) ~init:env 578 | ~f:(fun env (arg, ty_arg) -> Env.add env arg ([], ty_arg)) 579 | in 580 | infer' lvl env body 581 | in 582 | (cs, Ty_arr (ty_args, ty_body)) 583 | | Expr_app (func, args) -> 584 | let cs, ty_func = infer' lvl env func in 585 | let ty_args, ty_ret = unify_abs (List.length args) ty_func in 586 | let cs = 587 | List.fold2_exn args ty_args ~init:cs ~f:(fun cs arg ty_arg -> 588 | let cs', ty = infer' lvl env arg in 589 | unify ty ty_arg; 590 | cs @ cs') 591 | in 592 | (cs, ty_ret) 593 | | Expr_let (name, e, b) -> 594 | (* The bound expression is inferred one level deeper so that its fresh variables can be generalized at [lvl]. *) let ty_e = infer' (lvl + 1) env e in 595 | let ty_e = generalize lvl env ty_e in 596 | let env = Env.add env name ty_e in 597 | infer' lvl env b 598 | | Expr_let_rec (name, e, b) -> 599 | let ty_e = 600 | (* fix : a .
(a -> a) -> a *) 601 | (* The result variable lives at [lvl + 1], like every other variable of the binding, so that [generalize lvl] below can quantify over it; creating it at [lvl] would drag the whole type down to [lvl] during unification and keep the binding monomorphic. *) let ty_ret = newvar (lvl + 1) () in 602 | let ty_fun = Ty_arr ([ ty_ret ], ty_ret) in 603 | let cs, ty_fun' = infer' (lvl + 1) env (Expr_abs ([ name ], e)) in 604 | unify ty_fun' ty_fun; 605 | (cs, ty_ret) 606 | in 607 | let ty_e = generalize lvl env ty_e in 608 | let env = Env.add env name ty_e in 609 | infer' lvl env b 610 | | Expr_record fields -> 611 | let cs, ty_row = 612 | List.fold_left fields ~init:([], Ty_row_empty) 613 | ~f:(fun (cs, row) (label, e) -> 614 | let cs', ty_e = infer' lvl env e in 615 | (cs @ cs', Ty_row_field (label, ty_e, row))) 616 | in 617 | (cs, Ty_record ty_row) 618 | | Expr_record_proj (e, label) -> 619 | let cs, ty_e = infer' lvl env e in 620 | let ty_proj = newvar lvl () in 621 | unify ty_e (Ty_record (Ty_row_field (label, ty_proj, newrowvar lvl ()))); 622 | (cs, ty_proj) 623 | | Expr_record_extend (e, fields) -> 624 | let ty_row = newrowvar lvl () in 625 | let cs, return_ty_row = 626 | List.fold_left fields ~init:([], ty_row) 627 | ~f:(fun (cs, ty_row) (label, e) -> 628 | let ty_e = newvar lvl () in 629 | let cs', ty_e' = infer' lvl env e in 630 | unify ty_e ty_e'; 631 | (cs @ cs', Ty_row_field (label, ty_e, ty_row))) 632 | in 633 | let cs', ty_e' = infer' lvl env e in 634 | unify (Ty_record ty_row) ty_e'; 635 | (cs @ cs', Ty_record return_ty_row) 636 | | Expr_record_update (e, fields) -> 637 | let ty_row = newrowvar lvl () in 638 | let return_ty_row, to_unify = 639 | List.fold fields ~init:(ty_row, []) 640 | ~f:(fun (ty_row, to_unify) (label, e) -> 641 | let ty_e = newvar lvl () in 642 | (Ty_row_field (label, ty_e, ty_row), (e, ty_e) :: to_unify)) 643 | in 644 | let cs, ty_e = infer' lvl env e in 645 | unify (Ty_record return_ty_row) ty_e; 646 | let cs = 647 | List.fold (List.rev to_unify) ~init:cs ~f:(fun cs (e, ty_e) -> 648 | let cs', ty_e' = infer' lvl env e in 649 | unify ty_e ty_e'; 650 | cs @ cs') 651 | in 652 | (cs, Ty_record return_ty_row) 653 | | Expr_lit (Lit_string _) -> ([], Ty_const "string") 654 | |
Expr_lit (Lit_int _) -> ([], Ty_const "int") 655 | 656 | (* Infer the type of [e] at level 0, then generalize at level -1 so that every variable above that level (presumably all variables created during inference) is quantified. *) let infer env e = 657 | let qty = infer' 0 env e in 658 | generalize (-1) env qty 659 | -------------------------------------------------------------------------------- /algo_w/lexer.mll: -------------------------------------------------------------------------------- 1 | { 2 | 3 | open Parser 4 | 5 | exception Error of string 6 | 7 | } 8 | 9 | 10 | let ident = ['_' 'A'-'Z' 'a'-'z'] ['_' 'A'-'Z' 'a'-'z' '0'-'9']* 11 | let integer = ['0'-'9']+ 12 | 13 | rule token = parse 14 | | [' ' '\t' '\r' '\n'] { token lexbuf } 15 | | "fun" { FUN } 16 | | "let" { LET } 17 | | "rec" { REC } 18 | | "in" { IN } 19 | | "with" { WITH } 20 | | ident { IDENT (Lexing.lexeme lexbuf) } 21 | | '(' { LPAREN } 22 | | ')' { RPAREN } 23 | | '[' { LBRACKET } 24 | | ']' { RBRACKET } 25 | | '{' { LBRACE } 26 | | '}' { RBRACE } 27 | | '=' { EQUALS } 28 | | ':' '=' { ASSIGN } 29 | | "->" { ARROW } 30 | | "=>" { GTE } 31 | | ',' { COMMA } 32 | | '.' { DOT } 33 | | ';' { SEMI } 34 | | ':' { COLON } 35 | | eof { EOF } 36 | | _ as c { raise (Error ("unexpected token: '" ^ Char.escaped c ^ "'")) } 37 | 38 | 39 | { 40 | 41 | (* Render a token as the source text the rules above produce it from (presumably for error messages), so each case must agree with those rules. *) let string_of_token = function 42 | | FUN -> "fun" 43 | | LET -> "let" 44 | | REC -> "rec" 45 | | IN -> "in" 46 | | WITH -> "with" 47 | | IDENT ident -> ident 48 | | LPAREN -> "(" 49 | | RPAREN -> ")" 50 | | LBRACKET -> "[" 51 | | RBRACKET -> "]" 52 | | LBRACE -> "{" 53 | | RBRACE -> "}" 54 | | EQUALS -> "=" 55 | | ASSIGN -> ":=" 56 | | ARROW -> "->" 57 | | COMMA -> "," 58 | | DOT -> "." 59 | | SEMI -> ";"
60 | | COLON -> ":" 61 | | GTE -> "=>" 62 | | EOF -> "" 63 | 64 | } 65 | -------------------------------------------------------------------------------- /algo_w/parser.mly: -------------------------------------------------------------------------------- 1 | %{ 2 | 3 | open Expr 4 | 5 | (* Reset the type variable supply with [Infer.reset_vars], then map every name in [vars] to a fresh generic variable from [Infer.newgenvar]. *) let makeenv vars = 6 | let open Base in 7 | Infer.reset_vars (); 8 | List.fold_left 9 | vars 10 | ~init:(Map.empty (module String)) 11 | ~f:(fun env var_name -> 12 | Map.set env ~key:var_name ~data:(Infer.newgenvar ())) 13 | 14 | (* Replace every type constant of [ty] that names a variable bound in [env] with that generic variable; row constants must be bound in [env] and become generic row variables. *) let build_ty env ty = 15 | let open Base in 16 | let rec aux ty = match ty with 17 | | Ty_const name -> ( 18 | match Map.find env name with 19 | | Some ty -> ty 20 | | None -> ty) 21 | | Ty_var _ -> ty 22 | | Ty_app (f, args) -> Ty_app (aux f, List.map args ~f:aux) 23 | | Ty_arr (args, ret) -> Ty_arr (List.map args ~f:aux, aux ret) 24 | | Ty_record ty_row -> Ty_record (build_ty_row ty_row) 25 | and build_ty_row ty_row = match ty_row with 26 | | Ty_row_empty 27 | | Ty_row_var _ -> ty_row 28 | | Ty_row_field (name, ty, ty_row) -> 29 | Ty_row_field (name, aux ty, build_ty_row ty_row) 30 | | Ty_row_const name -> ( 31 | match Map.find env name with 32 | | Some (Ty_var {contents = Ty_var_generic id}) -> 33 | (* "convert" it to generic row variable *) 34 | Ty_row_var {contents = Ty_var_generic id} 35 | | Some _ -> 36 | (* shouldn't happen as we only insert generic vars into env *) 37 | assert false 38 | | None -> 39 | (* TODO: we should report a syntax error here as we only allow generic 40 | row variables at surface syntax.
*) 41 | assert false) 42 | in 43 | aux ty 44 | 45 | (* Resolve the variable names of a qualified type (its predicates and its type) against [env] with [build_ty]. *) let build_qual_ty env (cs, ty : qual_ty) : qual_ty = 46 | let open Base in 47 | let cs = List.map cs ~f:(fun (n, tys) -> n, List.map tys ~f:(build_ty env)) in 48 | cs, build_ty env ty 49 | 50 | (* Resolve the variable names of a qualified predicate (its dependencies and its head) against [env]. *) let build_qual_pred env (deps, p) = 51 | let open Base in 52 | let build_pred (name, args) = 53 | name, List.map args ~f:(build_ty env) 54 | in 55 | List.map deps ~f:build_pred, build_pred p 56 | 57 | %} 58 | 59 | %token IDENT 60 | %token FUN LET REC IN WITH 61 | %token LPAREN RPAREN LBRACKET RBRACKET LBRACE RBRACE 62 | %token ARROW EQUALS COMMA DOT SEMI COLON ASSIGN GTE 63 | %token EOF 64 | 65 | %start expr_eof 66 | %type expr_eof 67 | %start qual_ty_forall_eof 68 | %type qual_ty_forall_eof 69 | %start qual_pred_eof 70 | %type qual_pred_eof 71 | 72 | %% 73 | 74 | expr_eof: 75 | e = expr EOF { e } 76 | 77 | qual_ty_forall_eof: 78 | t = qual_ty_forall EOF { t } 79 | 80 | qual_pred_eof: 81 | qp = qual_pred EOF { qp } 82 | (* [qual_pred] is inlined into [qual_pred_eof], which already consumes the final EOF, so its alternatives must not require another EOF. *) 83 | %inline qual_pred: 84 | p = pred { [], p } 85 | | vars = ident_list DOT p = pred { 86 | build_qual_pred (makeenv vars) ([], p) 87 | } 88 | | vars = ident_list DOT deps = flex_list(COMMA, pred) GTE p = pred { 89 | let env = makeenv vars in 90 | build_qual_pred env (deps, p) 91 | } 92 | 93 | expr: 94 | e = simple_expr { e } 95 | 96 | (* let-bindings *) 97 | | LET n = IDENT EQUALS e = expr IN b = expr { Expr_let (n, e, b) } 98 | | LET REC n = IDENT EQUALS e = expr IN b = expr { Expr_let_rec (n, e, b) } 99 | 100 | (* functions *) 101 | | FUN arg = IDENT ARROW body = expr 102 | { Expr_abs ([arg], body) } 103 | | FUN LPAREN args = flex_list(COMMA, IDENT) RPAREN ARROW body = expr 104 | { Expr_abs (args, body) } 105 | 106 | (* let-fun fused *) 107 | | LET n = IDENT arg = IDENT EQUALS e = expr IN b = expr 108 | { Expr_let (n, Expr_abs ([arg], e), b) } 109 | | LET n = IDENT LPAREN args = flex_list(COMMA, IDENT) RPAREN EQUALS e = expr IN b = expr 110 | { Expr_let (n, Expr_abs (args, e), b) } 111 | |
LET REC n = IDENT arg = IDENT EQUALS e = expr IN b = expr 112 | { Expr_let_rec (n, Expr_abs ([arg], e), b) } 113 | | LET REC n = IDENT LPAREN args = flex_list(COMMA, IDENT) RPAREN EQUALS e = expr IN b = expr 114 | { Expr_let_rec (n, Expr_abs (args, e), b) } 115 | 116 | (* records *) 117 | | LBRACE fs = flex_list(SEMI, field) RBRACE 118 | { Expr_record fs } 119 | | LBRACE e = expr WITH fs = nonempty_flex_list(SEMI, field) RBRACE 120 | { Expr_record_extend (e, fs) } 121 | | LBRACE e = expr WITH fs = nonempty_flex_list(SEMI, field_update) RBRACE 122 | { Expr_record_update (e, fs) } 123 | 124 | simple_expr: 125 | n = IDENT { Expr_name n } 126 | | LPAREN e = expr RPAREN { e } 127 | | f = simple_expr LPAREN args = flex_list(COMMA, expr) RPAREN 128 | { Expr_app (f, args) } 129 | | e = simple_expr DOT n = IDENT 130 | { Expr_record_proj (e, n) } 131 | 132 | field: 133 | n = IDENT EQUALS e = expr { (n, e) } 134 | 135 | field_update: 136 | n = IDENT ASSIGN e = expr { (n, e) } 137 | (* Non-empty, comma-separated names, used as the quantified variables of qual_ty_forall and qual_pred. *) 138 | ident_list: 139 | xs = nonempty_flex_list(COMMA, IDENT) { xs } 140 | 141 | qual_ty_forall: 142 | qt = qual_ty { qt } 143 | | vars = ident_list DOT qt = qual_ty 144 | { let env = makeenv vars in build_qual_ty env qt } 145 | 146 | qual_ty: 147 | t = ty { ([], t) } 148 | | cs = nonempty_flex_list(COMMA, pred) GTE t = ty 149 | { (cs, t) } 150 | 151 | pred: 152 | n = IDENT LPAREN args = nonempty_flex_list(COMMA, ty) RPAREN { (n, args) } 153 | 154 | ty: 155 | t = simple_ty 156 | { t } 157 | | LPAREN RPAREN ARROW ret = ty 158 | { Ty_arr ([], ret) } 159 | | arg = simple_ty ARROW ret = ty 160 | { Ty_arr ([arg], ret) } 161 | | LPAREN arg = ty COMMA args = flex_list(COMMA, ty) RPAREN ARROW ret = ty 162 | { Ty_arr (arg :: args, ret) } 163 | | LBRACE row = ty_row RBRACE 164 | { Ty_record row } 165 | (* Record row: empty, a row variable name, or SEMI-separated fields optionally ending in a row variable name. *) 166 | ty_row: 167 | { Ty_row_empty } 168 | | t = IDENT { Ty_row_const t } 169 | | n = IDENT COLON t = ty { Ty_row_field (n, t, Ty_row_empty) } 170 | | n = IDENT COLON t = ty SEMI r = ty_row { Ty_row_field (n,
t, r) } 171 | 172 | simple_ty: 173 | n = IDENT { Ty_const n } 174 | | LPAREN t = ty RPAREN { t } 175 | | f = simple_ty LBRACKET args = nonempty_flex_list(COMMA, ty) RBRACKET 176 | { Ty_app (f, args) } 177 | 178 | (* Utilities for flexible lists (and their non-empty version). 179 | 180 | A flexible list [flex_list(delim, X)] is a list of [X] items separated by 181 | [delim], where a trailing [delim] is allowed. 182 | 183 | The non-empty version [nonempty_flex_list(delim, X)] of flexible lists is 184 | provided as well. 185 | 186 | From http://gallium.inria.fr/blog/lr-lists/ 187 | 188 | *) 189 | 190 | flex_list(delim, X): 191 | { [] } 192 | | x = X { [x] } 193 | | x = X delim xs = flex_list(delim, X) { x::xs } 194 | 195 | nonempty_flex_list(delim, X): 196 | x = X { [x] } 197 | | x = X delim xs = flex_list(delim, X) { x::xs } 198 | -------------------------------------------------------------------------------- /algo_w/test/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name test_algo_w) 3 | (inline_tests) 4 | (preprocess 5 | (pps ppx_expect)) 6 | (libraries base algo_w)) 7 | -------------------------------------------------------------------------------- /dune-project: -------------------------------------------------------------------------------- 1 | (lang dune 2.9) 2 | (name type-systems) 3 | (using menhir 2.1) 4 | (generate_opam_files true) 5 | 6 | (package 7 | (name type-systems) 8 | (synopsis "Type Systems Toy Implementations") 9 | (depends 10 | (dune (>= 2.9)) 11 | menhir 12 | pprint 13 | nice_parser 14 | ppx_sexp_conv 15 | ppx_expect 16 | base)) 17 | -------------------------------------------------------------------------------- /hmx/bin/dune: -------------------------------------------------------------------------------- 1 | (executable 2 | (public_name hmx) 3 | (name main) 4 | (libraries hmx base stdio)) 5 |
-------------------------------------------------------------------------------- /hmx/bin/main.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Hmx 3 | (* Parse an expression from stdin, run inference in a small fixed environment, and print the resulting expression or the type error. *) 4 | let () = 5 | let env = 6 | Env.empty 7 | |> Env.assume_val "one" "int" 8 | |> Env.assume_val "hello" "string" 9 | |> Env.assume_val "pair" "a, b . (a, b) -> pair[a, b]" 10 | |> Env.assume_val "plus" "int -> int -> int" 11 | |> Env.assume_val "true" "bool" 12 | |> Env.assume_val "eq" "a . (a, a) -> bool" 13 | in 14 | let e = Expr.parse_chan Stdio.stdin in 15 | match infer ~env e with 16 | | Ok e -> Caml.Format.printf "%s@." (Expr.show e) 17 | | Error err -> Caml.Format.printf "ERROR: %s@." (Type_error.show err) 18 | -------------------------------------------------------------------------------- /hmx/constraint.ml: -------------------------------------------------------------------------------- 1 | (** 2 | 3 | This module defines the constraint language. 4 | 5 | The constraint language has an applicative structure (see the [C_map] constructor) 6 | which is used for elaboration (the constraint solver computes a value). 7 | 8 | *) 9 | 10 | open Base 11 | open Syntax 12 | 13 | type _ t = 14 | | C_trivial : unit t 15 | (** A trivial constraint which states nothing useful; it can always be solved. *) 16 | | C_eq : ty * ty -> unit t 17 | (** [C_eq (ty1, ty2)] states that the types [ty1] and [ty2] are equal. *) 18 | | C_inst : name * ty -> expr t 19 | (** [C_inst (name, ty)] states that [name] should be fetched from the 20 | environment, instantiated and equated to [ty]. *) 21 | | C_and : 'a t * 'b t -> ('a * 'b) t 22 | (** Conjunction of two constraints, possibly of different value types. *) 23 | | C_and_list : 'a t list -> 'a list t 24 | (** Conjunction of multiple constraints of the same value type. *) 25 | | C_exists : var list * 'a t -> 'a t 26 | (** [C_exists (vs, c)] existentially quantifies variables [vs] over [c].
*) 27 | | C_let : 28 | (name * var list * expr t * ty) list * 'b t 29 | -> ((expr * ty_sch) list * 'b) t 30 | (** [C_let (bindings, c)] is a constraint abstraction fused with 31 | existential quantification. It adds [bindings] to the environment of 32 | the following constraint [c]. 33 | 34 | It allows defining multiple names at once to support n-ary functions. *) 35 | | C_map : 'a t * ('a -> 'b) -> 'b t 36 | (** Map operation; this gives the applicative structure. *) 37 | 38 | let trivial = C_trivial 39 | 40 | (* A trivial constraint whose value is [v]. *) let return v = C_map (C_trivial, fun () -> v) 41 | 42 | let ( =~ ) x y = C_eq (x, y) 43 | 44 | let inst name cty = C_inst (name, cty) 45 | 46 | (* [exists tys c] quantifies [tys] over [c]: with no variables it returns [c] unchanged, and a directly nested [C_exists] is merged into a single one. *) let exists tys c = 47 | match tys with 48 | | [] -> c 49 | | tys -> ( 50 | match c with 51 | | C_exists (tys', c) -> C_exists (tys @ tys', c) 52 | | c -> C_exists (tys, c)) 53 | 54 | let let_in bindings c = C_let (bindings, c) 55 | 56 | let ( &~ ) x y = C_and (x, y) 57 | 58 | let ( >>| ) c f = C_map (c, f) 59 | 60 | let list cs = C_and_list cs 61 | (* Lay out a constraint for printing; [names] is threaded through to the type and variable layout functions. A [C_map] is printed as its inner constraint. *) 62 | let rec layout' : type a. names:Names.t -> a t -> PPrint.document = 63 | fun ~names c -> 64 | let open PPrint in 65 | match c with 66 | | C_trivial -> string "TRUE" 67 | | C_eq (ty1, ty2) -> 68 | layout_ty' ~names ty1 ^^ string " = " ^^ layout_ty' ~names ty2 69 | | C_and (a, b) -> layout' ~names a ^^ string " & " ^^ layout' ~names b 70 | | C_and_list cs -> 71 | let sep = string " & " in 72 | separate sep (List.map cs ~f:(layout' ~names)) 73 | | C_exists (vs, c) -> 74 | string "∃" 75 | ^^ separate comma (List.map vs ~f:(layout_con_var' ~names)) 76 | ^^ dot 77 | ^^ parens (layout' ~names c) 78 | | C_let (bindings, c) -> 79 | let layout_cty' : type a. a t * ty -> document = function 80 | | C_trivial, ty -> layout_ty' ~names ty 81 | | c, ty -> layout' ~names c ^^ string " => " ^^ layout_ty' ~names ty 82 | in 83 | let layout_binding : type a.
string * var list * a t * ty -> document = 84 | fun (name, vs, c, ty) -> 85 | string name 86 | ^^ string " : " 87 | ^^ 88 | match vs with 89 | | [] -> layout_cty' (c, ty) 90 | | vs -> 91 | let vs = layout_var_prenex' ~names vs in 92 | vs ^^ layout_cty' (c, ty) 93 | in 94 | let sep = comma ^^ blank 1 in 95 | string "let " 96 | ^^ separate sep (List.map bindings ~f:layout_binding) 97 | ^^ string " in " 98 | ^^ layout' ~names c 99 | | C_inst (name, ty) -> string name ^^ string " ≲ " ^^ layout_ty' ~names ty 100 | | C_map (c, _f) -> layout' ~names c 101 | 102 | (* [pack] hides the value type of a constraint so that a single [Showable] instance covers constraints of every value type; [layout], [show] and [print] below wrap their argument in [P] to expose that API on [_ t] directly. *) include ( 103 | struct 104 | type pack = P : _ t -> pack 105 | 106 | include Showable (struct 107 | type t = pack 108 | 109 | let layout (P c) = 110 | let names = Names.make () in 111 | layout' ~names c 112 | end) 113 | 114 | let layout c = layout (P c) 115 | 116 | let show c = show (P c) 117 | 118 | let print ?label c = print ?label (P c) 119 | end : 120 | sig 121 | val layout : _ t -> PPrint.document 122 | 123 | val show : _ t -> string 124 | 125 | val print : ?label:string -> _ t -> unit 126 | end) 127 | -------------------------------------------------------------------------------- /hmx/debug.ml: -------------------------------------------------------------------------------- 1 | open!
Base 2 | 3 | (* Debug logging is configured once, from the letters present in the HMX_DEBUG environment variable (empty when unset); each flag below is enabled by one letter. *) let flags = Caml.Sys.getenv_opt "HMX_DEBUG" |> Option.value ~default:"" 4 | 5 | let log_levels = String.mem flags 'l' 6 | 7 | let log_instantiate = String.mem flags 'i' 8 | 9 | let log_generalize = String.mem flags 'g' 10 | 11 | let log_unify = String.mem flags 'u' 12 | 13 | let log_define = String.mem flags 'd' 14 | -------------------------------------------------------------------------------- /hmx/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name hmx) 3 | (preprocess (pps ppx_sexp_conv)) 4 | (libraries base pprint nice_parser) 5 | ) 6 | 7 | (ocamllex lexer) 8 | 9 | (menhir 10 | (modules parser) 11 | (flags --explain --dump)) 12 | -------------------------------------------------------------------------------- /hmx/hmx.ml: -------------------------------------------------------------------------------- 1 | module Expr = struct 2 | include Syntax.Expr 3 | 4 | include Nice_parser.Make (struct 5 | type token = Parser.token 6 | 7 | type result = Syntax.expr 8 | 9 | let parse = Parser.expr_eof 10 | 11 | let next_token = Lexer.token 12 | 13 | exception ParseError = Parser.Error 14 | 15 | exception LexError = Lexer.Error 16 | end) 17 | end 18 | 19 | module Ty = struct 20 | include Syntax.Ty 21 | end 22 | 23 | module Ty_sch = struct 24 | include Syntax.Ty_sch 25 | 26 | include Nice_parser.Make (struct 27 | type token = Parser.token 28 | 29 | type result = Syntax.ty_sch 30 | 31 | let parse = Parser.ty_sch_eof 32 | 33 | let next_token = Lexer.token 34 | 35 | exception ParseError = Parser.Error 36 | 37 | exception LexError = Lexer.Error 38 | end) 39 | end 40 | 41 | module Type_error = Type_error 42 | module Var = Var 43 | 44 | module Env = struct 45 | include Infer.Env 46 | 47 | (* Parse [ty] as a type scheme and bind [name] to it; the environment comes last so that calls can be chained with the pipe operator. *) let assume_val name ty env = add_val env name (Ty_sch.parse_string ty) 48 | end 49 | 50 | let infer = Infer.infer 51 | -------------------------------------------------------------------------------- /hmx/infer.ml:
-------------------------------------------------------------------------------- 1 | open! Base 2 | open Syntax 3 | 4 | (** Constraint generation: [generate e ty] builds a constraint stating that [e] has type [ty]; the value of the constraint is the elaborated expression. *) 5 | let rec generate : expr -> ty -> expr Constraint.t = 6 | fun e ty -> 7 | match e with 8 | | E_lit (Lit_int _) -> Constraint.(ty =~ Ty_const "int" >>| fun () -> e) 9 | | E_lit (Lit_string _) -> Constraint.(ty =~ Ty_const "string" >>| fun () -> e) 10 | | E_var name -> Constraint.inst name ty 11 | | E_app (f, args) -> 12 | (* Each argument gets a fresh type variable; [f] must then have the arrow type from those variables to [ty]. *) let vs, arg_tys, args = 13 | List.fold args ~init:([], [], []) ~f:(fun (vs, arg_tys, args) arg -> 14 | let v = Var.fresh () in 15 | let arg_ty = Ty.var v in 16 | let arg = generate arg arg_ty in 17 | (v :: vs, arg_ty :: arg_tys, arg :: args)) 18 | in 19 | let f_ty = Ty.(arr (List.rev arg_tys) ty) in 20 | let f = generate f f_ty in 21 | let args = Constraint.list (List.rev args) in 22 | Constraint.(exists vs (f &~ args) >>| fun (f, args) -> E_app (f, args)) 23 | | E_abs (args, b) -> 24 | (* Each parameter is bound, with an empty variable list, to a fresh type variable, and the body gets another fresh variable as its result type. *) let vs, bindings, atys = 25 | List.fold args ~init:([], [], []) ~f:(fun (vs, bindings, atys) param -> 26 | let v = Var.fresh () in 27 | let ty = Ty.var v in 28 | ( v :: vs, 29 | (param, [], Constraint.return (E_var param), ty) :: bindings, 30 | Ty.var v :: atys )) 31 | in 32 | let v = Var.fresh () in 33 | let b = generate b (Ty.var v) in 34 | let ty' = Ty.arr (List.rev atys) (Ty.var v) in 35 | Constraint.( 36 | exists (v :: vs) (let_in (List.rev bindings) b &~ (ty' =~ ty)) 37 | >>| fun ((_tys, b), ()) -> E_abs (args, b)) 38 | | E_let ((name, e, _), b) -> ( 39 | (* [v] is listed in the [C_let] binding, presumably so that the solver may generalize over it; the elaborated let records the resulting scheme. *) let v = Var.fresh () in 40 | let e_ty = Ty.var v in 41 | let e = generate e e_ty in 42 | let b = generate b ty in 43 | Constraint.( 44 | let_in [ (name, [ v ], e, e_ty) ] b >>| fun (tys, b) -> 45 | match tys with 46 | | [ (e, ty_sch) ] -> E_let ((name, e, Some ty_sch), b) 47 | | _ -> failwith "impossible as we are supplying a single binding")) 48 | 49 | (* [unify ty1 ty2] unifies two types [ty1] and [ty2].
*) 50 | let rec unify ty1 ty2 = 51 | if Debug.log_unify then 52 | Caml.Format.printf "UNIFY %s ~ %s@." (Ty.show ty1) (Ty.show ty2); 53 | if phys_equal ty1 ty2 then () 54 | else 55 | match (ty1, ty2) with 56 | | Ty_const a, Ty_const b -> 57 | if not String.(a = b) then Type_error.raise (Error_unification (ty1, ty2)) 58 | | Ty_app (a1, b1), Ty_app (a2, b2) -> ( 59 | unify a1 a2; 60 | (* Argument lists of different lengths are reported as a unification error. *) match List.iter2 b1 b2 ~f:unify with 61 | | Unequal_lengths -> Type_error.raise (Error_unification (ty1, ty2)) 62 | | Ok () -> ()) 63 | | Ty_arr (a1, b1), Ty_arr (a2, b2) -> ( 64 | match List.iter2 a1 a2 ~f:unify with 65 | | Unequal_lengths -> Type_error.raise (Error_unification (ty1, ty2)) 66 | | Ok () -> unify b1 b2) 67 | | Ty_var var1, Ty_var var2 -> ( 68 | (* NOTE(review): presumably [Var.unify] merges the two variables and returns their types when both already have one, so that those are unified here; confirm against var.ml. *) match Var.unify var1 var2 with 69 | | Some (ty1, ty2) -> unify ty1 ty2 70 | | None -> ()) 71 | | Ty_var var, ty 72 | | ty, Ty_var var -> ( 73 | (* A variable without a type is bound to the other side after [Var.occurs_check_adjust_lvl]; otherwise its existing type is unified with the other side. *) match Var.ty var with 74 | | None -> 75 | Var.occurs_check_adjust_lvl var ty; 76 | Var.set_ty ty var 77 | | Some ty' -> unify ty' ty) 78 | | _, _ -> Type_error.raise (Error_unification (ty1, ty2)) 79 | 80 | (** A substitution over [ty] terms.
*) 81 | module Subst : sig 82 | type t 83 | 84 | val empty : t 85 | 86 | val make : (var * ty) list -> t 87 | 88 | val apply_ty : t -> ty -> ty 89 | end = struct 90 | type t = (var * ty) list 91 | 92 | let empty = [] 93 | 94 | let make pairs = pairs 95 | 96 | let find subst var = 97 | (* [List.Assoc.find] already returns an option, so it is returned as is. *) 98 | List.Assoc.find ~equal:Var.equal subst var 99 | 100 | (* Apply [subst] to [ty], looking through variables that already have a type; an unbound variable absent from [subst] is kept, and substituted types are not traversed further. *) 101 | let rec apply_ty subst ty = 102 | match ty with 103 | | Ty_const _ -> ty 104 | | Ty_app (a, args) -> 105 | Ty_app (apply_ty subst a, List.map args ~f:(apply_ty subst)) 106 | | Ty_arr (args, b) -> 107 | Ty_arr (List.map args ~f:(apply_ty subst), apply_ty subst b) 108 | | Ty_var v -> ( 109 | match Var.ty v with 110 | | Some ty -> apply_ty subst ty 111 | | None -> ( 112 | match find subst v with 113 | | Some ty -> ty 114 | | None -> ty)) 115 | end 116 | 117 | (** An applicative structure which is used to build a computation which 118 | elaborates terms. *) 119 | module Elab : sig 120 | type 'a t 121 | 122 | val run : 'a t -> 'a 123 | (** Compute a value. 124 | 125 | This should be run at the very end, when all holes are elaborated. *) 126 | 127 | val return : 'a -> 'a t 128 | 129 | val map : 'a t -> ('a -> 'b) -> 'b t 130 | 131 | val both : 'a t -> 'b t -> ('a * 'b) t 132 | 133 | val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t 134 | 135 | val ( and+ ) : 'a t -> 'b t -> ('a * 'b) t 136 | 137 | val list : 'a t list -> 'a list t 138 | end = struct 139 | type 'a t = unit -> 'a 140 | 141 | let run elab = elab () 142 | 143 | let return v () = v 144 | 145 | let map v f () = f (v ()) 146 | (* NOTE(review): OCaml leaves the evaluation order of tuple components unspecified, so [a ()] and [b ()] may run in either order; this only matters if elaboration thunks have side effects. *) 147 | let both a b () = (a (), b ()) 148 | 149 | let ( let+ ) = map 150 | 151 | let ( and+ ) = both 152 | 153 | let list es () = List.map es ~f:run 154 | end 155 | 156 | (** Instantiate type scheme into a constrained type.
*) 157 | let instantiate ~lvl (ty_sch : ty_sch) : ty = 158 | match ty_sch with 159 | | [], ty -> 160 | (* No ∀-quantified variables, return the type as-is *) 161 | ty 162 | | vars, cty -> 163 | let subst = (* Map every quantified var to a fresh var at the current [lvl]. *) 164 | Subst.make (List.map vars ~f:(fun v -> (v, Ty.var (Var.fresh ~lvl ())))) 165 | in 166 | Subst.apply_ty subst cty 167 | 168 | let instantiate ~lvl ty_sch = 169 | if Debug.log_instantiate then Ty_sch.print ~label:"I<" ty_sch; 170 | let cty = instantiate ~lvl ty_sch in 171 | if Debug.log_instantiate then Ty.print ~label:"I>" cty; 172 | cty 173 | 174 | module Env = struct 175 | type t = { 176 | values : (name, def, String.comparator_witness) Map.t; 177 | tclasses : (name, tclass, String.comparator_witness) Map.t; 178 | } 179 | 180 | and def = { name : name; ty_sch : ty_sch } 181 | 182 | and tclass = { def : def; method_def : def; instances : def list } 183 | 184 | let empty = 185 | { values = Map.empty (module String); tclasses = Map.empty (module String) } 186 | 187 | let find_val env name = Map.find env.values name 188 | 189 | let find_tclass env name = Map.find env.tclasses name 190 | 191 | let add_val env name ty_sch = 192 | if Debug.log_define then 193 | Caml.Format.printf "val %s : %s@." name (Ty_sch.show ty_sch); 194 | { env with values = Map.set env.values ~key:name ~data:{ name; ty_sch } } 195 | end 196 | 197 | let simple_vs vs = (* Keep only the unbound vars of [vs], dropping duplicates (the last occurrence wins); the result comes out in reverse order. *) 198 | let rec simple_vs vs' = function 199 | | [] -> vs' 200 | | v :: vs -> 201 | if Option.is_some (Var.ty v) then 202 | (* Skipping as the var is already bound. *) 203 | simple_vs vs' vs 204 | else if List.mem ~equal:Var.equal vs v then 205 | (* Skipping as the var is duplicated within [vs].
*) 206 | simple_vs vs' vs 207 | else simple_vs (v :: vs') vs 208 | in 209 | simple_vs [] vs 210 | 211 | let ty_vs ~lvl ty = (* Collect the unbound vars of [ty] whose level is above [lvl]; the result may contain duplicates. *) 212 | let rec aux vs = function 213 | | Ty_const _ -> vs 214 | | Ty_app (a, args) -> 215 | let vs = aux vs a in 216 | List.fold args ~init:vs ~f:aux 217 | | Ty_arr (args, b) -> 218 | let vs = aux vs b in 219 | List.fold args ~init:vs ~f:aux 220 | | Ty_var v -> ( 221 | match Var.ty v with 222 | | Some ty -> aux vs ty 223 | | None -> if Var.lvl v > lvl then v :: vs else vs) 224 | in 225 | aux [] ty 226 | 227 | let generalize ~lvl ty = 228 | let gvs = simple_vs (ty_vs ~lvl ty) in 229 | (simple_vs gvs, ty) (* [gvs] has no bound or duplicate vars, so this second pass only reverses it back to the order produced by [ty_vs]. *) 230 | 231 | let generalize ~lvl ty = 232 | let ty_sch = generalize ~lvl ty in 233 | if Debug.log_generalize then Ty_sch.print ~label:"G>" ty_sch; 234 | ty_sch 235 | 236 | let rec solve : type a. lvl:lvl -> env:Env.t -> a Constraint.t -> a Elab.t = 237 | fun ~lvl ~env c -> 238 | match c with 239 | | C_trivial -> Elab.return () 240 | | C_eq (a, b) -> 241 | unify a b; 242 | Elab.return () 243 | | C_map (c, f) -> 244 | let v = solve ~lvl ~env c in 245 | Elab.map v f 246 | | C_and (a, b) -> 247 | let a = solve ~lvl ~env a in 248 | let b = solve ~lvl ~env b in 249 | Elab.both a b 250 | | C_and_list cs -> 251 | let vs = 252 | List.fold cs ~init:[] ~f:(fun vs c -> 253 | let v = solve ~lvl ~env c in 254 | v :: vs) 255 | in 256 | Elab.( 257 | let+ vs = list vs in 258 | List.rev vs) 259 | | C_exists (vs, c) -> 260 | List.iter vs ~f:(Var.set_lvl lvl); 261 | solve ~lvl ~env c 262 | | C_let (bindings, c) -> 263 | let env, values = 264 | let env0 = env in (* Every binding is solved in [env0], so bindings cannot refer to one another. *) 265 | List.fold bindings ~init:(env, []) 266 | ~f:(fun (env, values) (name, vs, c, ty) -> 267 | (* Need to set levels here as [C_let] works as [C_exists] as well.
*) 268 | List.iter vs ~f:(Var.set_lvl (lvl + 1)); 269 | let e, ty_sch = solve_and_generalize ~lvl:(lvl + 1) ~env:env0 c ty in 270 | let env = Env.add_val env name ty_sch in 271 | (env, e :: values)) 272 | in 273 | let v = solve ~lvl ~env c in 274 | let values = Elab.list (List.rev values) in 275 | Elab.(both values v) 276 | | C_inst (name, ty) -> 277 | let ty_sch = 278 | match Env.find_val env name with 279 | | Some def -> def.ty_sch 280 | | None -> Type_error.raise (Error_unknown_name name) 281 | in 282 | let ty' = instantiate ~lvl ty_sch in 283 | unify ty ty'; 284 | Elab.return (E_var name) 285 | 286 | and solve_and_generalize ~lvl ~env c ty = (* Solve [c] at [lvl], then quantify the unbound vars of [ty] whose level is at least [lvl]. *) 287 | let e = solve ~lvl ~env c in 288 | let ty_sch = generalize ~lvl:(lvl - 1) ty in 289 | let e = 290 | Elab.( 291 | let+ e = e in 292 | (e, ty_sch)) 293 | in 294 | (e, ty_sch) 295 | 296 | (** [infer ~env e] infers the type scheme for expression [e]. 297 | 298 | It returns [Ok elaborated] where [elaborated] is [let _ : ty_sch = body in _]: 299 | [ty_sch] is the inferred type scheme and [body] is the elaborated expression 300 | corresponding to [e]. 301 | 302 | ... or in case of an error it returns [Error err]. 303 | *) 304 | let infer ~env e : (expr, Type_error.t) Result.t = 305 | (* To infer an expression type we first generate constraints *) 306 | let v = Var.fresh () in 307 | let ty = Ty.var v in 308 | let c = generate e ty in 309 | let c = Constraint.exists [ v ] c in 310 | try 311 | (* and then solve them and generalize.
*) 312 | let e, _ty_sch = solve_and_generalize ~lvl:1 ~env c ty in 313 | Ok 314 | Elab.( 315 | run 316 | (let+ e, ty_sch = e in 317 | E_let (("_", e, Some ty_sch), E_var "_"))) 318 | with 319 | | Type_error.Type_error error -> Error error 320 | -------------------------------------------------------------------------------- /hmx/lexer.mll: -------------------------------------------------------------------------------- 1 | { 2 | 3 | open Parser 4 | 5 | exception Error of string 6 | 7 | } 8 | 9 | 10 | let ident = ['_' 'A'-'Z' 'a'-'z'] ['_' 'A'-'Z' 'a'-'z' '0'-'9']* 11 | let integer = ['0'-'9']+ 12 | 13 | rule token = parse 14 | | [' ' '\t' '\r' '\n'] { token lexbuf } 15 | | "fun" { FUN } 16 | | "let" { LET } 17 | | "rec" { REC } 18 | | "in" { IN } 19 | | "with" { WITH } 20 | | ident { IDENT (Lexing.lexeme lexbuf) } 21 | | '(' { LPAREN } 22 | | ')' { RPAREN } 23 | | '[' { LBRACKET } 24 | | ']' { RBRACKET } 25 | | '{' { LBRACE } 26 | | '}' { RBRACE } 27 | | '=' { EQUALS } 28 | | ':' '=' { ASSIGN } 29 | | "->" { ARROW } 30 | | "=>" { GTE } 31 | | ',' { COMMA } 32 | | '.' { DOT } 33 | | ';' { SEMI } 34 | | ':' { COLON } 35 | | eof { EOF } 36 | | _ as c { raise (Error ("unexpected token: '" ^ Char.escaped c ^ "'")) } 37 | 38 | 39 | { 40 | 41 | let string_of_token = function (* Render a token as the source text the [token] rule reads it from ([EOF] renders as the empty string). *) 42 | | FUN -> "fun" 43 | | LET -> "let" 44 | | REC -> "rec" 45 | | IN -> "in" 46 | | WITH -> "with" 47 | | IDENT ident -> ident 48 | | LPAREN -> "(" 49 | | RPAREN -> ")" 50 | | LBRACKET -> "[" 51 | | RBRACKET -> "]" 52 | | LBRACE -> "{" 53 | | RBRACE -> "}" 54 | | EQUALS -> "=" 55 | | ASSIGN -> ":=" 56 | | ARROW -> "->" 57 | | COMMA -> "," 58 | | DOT -> "." 59 | | SEMI -> ";"
60 | | COLON -> ":" 61 | | GTE -> "=>" 62 | | EOF -> "" 63 | 64 | } 65 | -------------------------------------------------------------------------------- /hmx/parser.mly: -------------------------------------------------------------------------------- 1 | %{ 2 | 3 | open Syntax 4 | 5 | let makeenv vars = 6 | let open Base in 7 | Var.reset (); 8 | let vs, map = List.fold_left 9 | vars 10 | ~init:([], Map.empty (module String)) 11 | ~f:(fun (vs, env) name -> 12 | let v = Var.fresh () in 13 | v::vs, 14 | Map.set env ~key:name ~data:(Ty.var v)) in 15 | List.rev vs, map 16 | 17 | let build_ty_sch (vs, env) ty = 18 | let open Base in 19 | let rec build_ty ty = match ty with 20 | | Ty_const name -> ( 21 | match Map.find env name with 22 | | Some ty -> ty 23 | | None -> ty) 24 | | Ty_var _ -> ty 25 | | Ty_app (fty, atys) -> Ty_app (build_ty fty, List.map atys ~f:build_ty) 26 | | Ty_arr (atys, rty) -> Ty_arr (List.map atys ~f:build_ty, build_ty rty) 27 | in 28 | vs, build_ty ty 29 | %} 30 | 31 | %token IDENT (* NOTE(review): the payload type of IDENT (a string, per the lexer) and the result types declared for the start symbols below appear to be missing from this copy; confirm against the original parser.mly. *) 32 | %token FUN LET REC IN WITH 33 | %token LPAREN RPAREN LBRACKET RBRACKET LBRACE RBRACE 34 | %token ARROW EQUALS COMMA DOT SEMI COLON ASSIGN GTE 35 | %token EOF 36 | 37 | %start expr_eof 38 | %type expr_eof 39 | %start ty_sch_eof 40 | %type ty_sch_eof 41 | 42 | %% 43 | 44 | expr_eof: 45 | e = expr EOF { e } 46 | 47 | ty_sch_eof: 48 | t = ty_sch EOF { t } 49 | 50 | expr: 51 | e = simple_expr { e } 52 | 53 | (* let-bindings *) 54 | | LET n = IDENT EQUALS e = expr IN b = expr { E_let ((n, e, None), b) } 55 | 56 | (* functions *) 57 | | FUN arg = IDENT ARROW body = expr 58 | { E_abs ([arg], body) } 59 | | FUN LPAREN args = flex_list(COMMA, IDENT) RPAREN ARROW body = expr 60 | { E_abs (args, body) } 61 | 62 | | LET n = IDENT arg = IDENT EQUALS e = expr IN b = expr 63 | { E_let ((n, E_abs ([arg], e), None), b) } 64 | | LET n = IDENT LPAREN args = flex_list(COMMA, IDENT) RPAREN EQUALS e = expr IN b = expr 65 | { E_let ((n, E_abs (args, e), None), b) } 66 | 67 | simple_expr:
68 | n = IDENT { E_var n } 69 | | LPAREN e = expr RPAREN { e } 70 | | f = simple_expr LPAREN args = flex_list(COMMA, expr) RPAREN 71 | { E_app (f, args) } 72 | 73 | ident_list: 74 | xs = nonempty_flex_list(COMMA, IDENT) { xs } 75 | 76 | ty_sch: 77 | t = ty { [], t } 78 | | vars = ident_list DOT t = ty 79 | { let env = makeenv vars in build_ty_sch env t } 80 | 81 | ty: 82 | t = simple_ty 83 | { t } 84 | | LPAREN RPAREN ARROW ret = ty 85 | { Ty_arr ([], ret) } 86 | | arg = simple_ty ARROW ret = ty 87 | { Ty_arr ([arg], ret) } 88 | | LPAREN arg = ty COMMA args = flex_list(COMMA, ty) RPAREN ARROW ret = ty 89 | { Ty_arr (arg :: args, ret) } 90 | 91 | simple_ty: 92 | n = IDENT { Ty_const n } 93 | | LPAREN t = ty RPAREN { t } 94 | | f = simple_ty LBRACKET args = nonempty_flex_list(COMMA, ty) RBRACKET 95 | { Ty_app (f, args) } 96 | 97 | (* Utilities for flexible lists (and their non-empty version). 98 | 99 | A flexible list [flex_list(delim, X)] is a list of [X] items separated by 100 | [delim], where a trailing [delim] is allowed. 101 | 102 | A non-empty [nonempty_flex_list(delim, X)] version of flexible list is 103 | provided as well. 104 | 105 | From http://gallium.inria.fr/blog/lr-lists/ 106 | 107 | *) 108 | 109 | flex_list(delim, X): 110 | { [] } 111 | | x = X { [x] } 112 | | x = X delim xs = flex_list(delim, X) { x::xs } 113 | 114 | nonempty_flex_list(delim, X): 115 | x = X { [x] } 116 | | x = X delim xs = flex_list(delim, X) { x::xs } 117 | -------------------------------------------------------------------------------- /hmx/syntax.ml: -------------------------------------------------------------------------------- 1 | open!
Base 2 | 3 | type name = string [@@deriving sexp_of] 4 | 5 | and id = int 6 | 7 | and lvl = int 8 | 9 | type expr = 10 | | E_var of name 11 | | E_abs of name list * expr 12 | | E_app of expr * expr list 13 | | E_let of (name * expr * ty_sch option) * expr 14 | | E_lit of lit 15 | [@@deriving sexp_of] 16 | 17 | and lit = Lit_string of string | Lit_int of int 18 | 19 | and ty = 20 | | Ty_const of name 21 | | Ty_var of var 22 | | Ty_app of ty * ty list 23 | | Ty_arr of ty list * ty 24 | 25 | and var = var_data Union_find.t 26 | 27 | and var_data = { 28 | id : int; 29 | mutable lvl : lvl option; 30 | (** Levels are assigned when we enter [C_exists] or [C_let] constraints *) 31 | mutable ty : ty option; 32 | (** Types are discovered as a result of unification. *) 33 | } 34 | 35 | and ty_sch = var list * ty 36 | 37 | module Names : sig 38 | type t 39 | 40 | val make : unit -> t 41 | 42 | val alloc : t -> id -> string 43 | 44 | val lookup : t -> id -> string option 45 | end = struct 46 | type t = (Int.t, string) Hashtbl.t 47 | 48 | let make () = Hashtbl.create (module Int) 49 | 50 | let alloc names id = (* Names go a, b, ..., z, then a1, ..., z1, a2, and so on, indexed by how many names were allocated so far. *) 51 | let i = Hashtbl.length names in 52 | let name = 53 | String.make 1 (Char.of_int_exn (97 + Int.rem i 26)) 54 | ^ if i >= 26 then Int.to_string (i / 26) else "" 55 | in 56 | Hashtbl.set names ~key:id ~data:name; 57 | name 58 | 59 | let lookup names id = Hashtbl.find names id 60 | end 61 | 62 | module MakeId () = struct 63 | let c = ref 0 64 | 65 | let fresh () = 66 | Int.incr c; 67 | !c 68 | 69 | let reset () = c := 0 70 | end 71 | 72 | module type SHOWABLE = sig 73 | type t 74 | 75 | val layout : t -> PPrint.document 76 | 77 | val show : t -> string 78 | 79 | val print : ?label:string -> t -> unit 80 | end 81 | 82 | module Showable (S : sig 83 | type t 84 | 85 | val layout : t -> PPrint.document 86 | end) : SHOWABLE with type t = S.t = struct 87 | type t = S.t 88 | 89 | let layout = S.layout 90 | 91 | let show v = 92 | let width = 60 in 93 | let buf = Buffer.create 100
in 94 | PPrint.ToBuffer.pretty 1. width buf (S.layout v); 95 | Buffer.contents buf 96 | 97 | let print ?label v = 98 | match label with 99 | | Some label -> Caml.print_endline (label ^ ": " ^ show v) 100 | | None -> Caml.print_endline (show v) 101 | end 102 | 103 | module type DUMPABLE = sig 104 | type t 105 | 106 | val dump : ?label:string -> t -> unit 107 | 108 | val sdump : ?label:string -> t -> string 109 | end 110 | 111 | module Dumpable (S : sig 112 | type t 113 | 114 | val sexp_of_t : t -> Sexp.t 115 | end) : DUMPABLE with type t = S.t = struct 116 | type t = S.t 117 | 118 | let dump ?label v = 119 | let s = S.sexp_of_t v in 120 | match label with 121 | | None -> Caml.Format.printf "%a@." Sexp.pp_hum s 122 | | Some label -> Caml.Format.printf "%s %a@." label Sexp.pp_hum s 123 | 124 | let sdump ?label v = 125 | let s = S.sexp_of_t v in 126 | match label with 127 | | None -> Caml.Format.asprintf "%a@." Sexp.pp_hum s 128 | | Some label -> Caml.Format.asprintf "%s %a@." label Sexp.pp_hum s 129 | end 130 | 131 | let layout_var v = 132 | let open PPrint in 133 | let lvl = Option.(v.lvl |> map ~f:Int.to_string |> value ~default:"!") in 134 | if Debug.log_levels then string (Printf.sprintf "_%i@%s" v.id lvl) 135 | else string (Printf.sprintf "_%i" v.id) 136 | 137 | let rec layout_expr' ~names = 138 | let open PPrint in 139 | function 140 | | E_var name -> string name 141 | | E_abs (args, body) -> 142 | let sep = comma ^^ blank 1 in 143 | let newline = 144 | (* Always break on let inside the body.
*) 145 | match body with 146 | | E_let _ -> hardline 147 | | _ -> break 1 148 | in 149 | let args = 150 | match args with 151 | | [ arg ] -> string arg 152 | | args -> parens (separate sep (List.map args ~f:string)) 153 | in 154 | group 155 | (group (string "fun " ^^ args ^^ string " ->") 156 | ^^ nest 2 (group (newline ^^ group (layout_expr' ~names body)))) 157 | | E_app (f, args) -> 158 | let sep = comma ^^ break 1 in 159 | group 160 | (layout_expr' ~names f 161 | ^^ parens 162 | (nest 2 163 | (group 164 | (break 0 165 | ^^ separate sep (List.map args ~f:(layout_expr' ~names)))))) 166 | | E_let _ as e -> 167 | let es = 168 | (* We do not want to print multiple nested let-expressions with indents and 169 | therefore we linearize them first and print on the same indent instead. *) 170 | let rec linearize es e = 171 | match e with 172 | | E_let (_, b) -> linearize (e :: es) b 173 | | e -> e :: es 174 | in 175 | List.rev (linearize [] e) 176 | in 177 | let newline = 178 | (* If there's more than a single let-expression found (checking length > 2 179 | because es contains the body of the last let-expression too) we split 180 | them with a hardline. *) 181 | if List.length es > 2 then hardline else break 1 182 | in 183 | concat 184 | (List.map es ~f:(function 185 | | E_let ((name, expr, ty_sch), _) -> 186 | let ascription = 187 | (* We need to layout ty_sch first as it will allocate names for use 188 | down the road. *) 189 | match ty_sch with 190 | | None -> empty 191 | | Some ty_sch -> 192 | string " :" ^^ nest 4 (break 1 ^^ layout_ty_sch' ~names ty_sch) 193 | in 194 | let expr_newline = 195 | (* If there's [let x = let y = ... in ... in ...] then we want to 196 | force break.
*) 197 | match expr with 198 | | E_let _ -> hardline 199 | | _ -> break 1 200 | in 201 | group 202 | (group (string "let " ^^ string name ^^ ascription ^^ string " =") 203 | ^^ nest 2 (expr_newline ^^ layout_expr' ~names expr) 204 | ^^ expr_newline 205 | ^^ string "in") 206 | ^^ newline 207 | | e -> layout_expr' ~names e)) 208 | | E_lit (Lit_string v) -> dquotes (string v) 209 | | E_lit (Lit_int v) -> string (Int.to_string v) (* Printed bare so integer literals stay distinct from string literals. *) 210 | 211 | and layout_ty' ~names ty = 212 | let open PPrint in 213 | let rec is_ty_arr = function 214 | | Ty_var var -> ( 215 | match (Union_find.value var).ty with 216 | | None -> false 217 | | Some ty -> is_ty_arr ty) 218 | | Ty_arr _ -> true 219 | | _ -> false 220 | in 221 | let rec layout_ty = function 222 | | Ty_const name -> string name 223 | | Ty_arr ([ aty ], rty) -> 224 | (* Check if we can layout this as simply as [aty -> rty] in case of a 225 | single argument. *) 226 | (if is_ty_arr aty then 227 | (* If the single arg is the Ty_arr we need to wrap it in parens.
*) 228 | parens (layout_ty aty) 229 | else layout_ty aty) 230 | ^^ string " -> " 231 | ^^ layout_ty rty 232 | | Ty_arr (atys, rty) -> 233 | let sep = comma ^^ blank 1 in 234 | parens (separate sep (List.map atys ~f:layout_ty)) 235 | ^^ string " -> " 236 | ^^ layout_ty rty 237 | | Ty_app (fty, atys) -> 238 | let sep = comma ^^ blank 1 in 239 | layout_ty fty ^^ brackets (separate sep (List.map atys ~f:layout_ty)) 240 | | Ty_var var -> ( 241 | let data = Union_find.value var in 242 | match data.ty with 243 | | None -> layout_con_var' ~names var 244 | | Some ty -> layout_ty ty) 245 | in 246 | layout_ty ty 247 | 248 | and layout_con_var' ~names v = 249 | let open PPrint in 250 | let v = Union_find.value v in 251 | match v.ty with 252 | | Some ty -> layout_ty' ~names ty 253 | | None -> ( 254 | match Names.lookup names v.id with 255 | | Some name -> 256 | if Debug.log_levels then string name ^^ parens (layout_var v) 257 | else string name 258 | | None -> layout_var v) 259 | 260 | and layout_ty_sch' ~names ty_sch = 261 | let open PPrint in 262 | match ty_sch with 263 | | [], ty -> layout_ty' ~names ty 264 | | vs, ty -> 265 | let vs = layout_var_prenex' ~names vs in 266 | group (vs ^^ layout_ty' ~names ty) 267 | 268 | and layout_var_prenex' ~names vs = 269 | let open PPrint in 270 | let sep = comma ^^ blank 1 in 271 | let vs = 272 | List.map vs ~f:(fun v -> 273 | let v = Union_find.value v in 274 | string (Names.alloc names v.id)) 275 | in 276 | separate sep vs ^^ string " .
" 277 | 278 | module Expr = struct 279 | type t = expr 280 | 281 | include ( 282 | Showable (struct 283 | type t = expr 284 | 285 | let layout e = layout_expr' ~names:(Names.make ()) e 286 | end) : 287 | SHOWABLE with type t := t) 288 | 289 | include ( 290 | Dumpable (struct 291 | type t = expr 292 | 293 | let sexp_of_t = sexp_of_expr 294 | end) : 295 | DUMPABLE with type t := t) 296 | end 297 | 298 | module Ty = struct 299 | type t = ty 300 | 301 | let arr a b = Ty_arr (a, b) 302 | 303 | let var var = Ty_var var 304 | 305 | include ( 306 | Showable (struct 307 | type t = ty 308 | 309 | let layout ty = layout_ty' ~names:(Names.make ()) ty 310 | end) : 311 | SHOWABLE with type t := t) 312 | 313 | include ( 314 | Dumpable (struct 315 | type t = ty 316 | 317 | let sexp_of_t = sexp_of_ty 318 | end) : 319 | DUMPABLE with type t := t) 320 | end 321 | 322 | module Ty_sch = struct 323 | type t = ty_sch 324 | 325 | include ( 326 | Showable (struct 327 | type t = ty_sch 328 | 329 | let layout ty_sch = layout_ty_sch' ~names:(Names.make ()) ty_sch 330 | end) : 331 | SHOWABLE with type t := t) 332 | 333 | include ( 334 | Dumpable (struct 335 | type t = ty_sch 336 | 337 | let sexp_of_t = sexp_of_ty_sch 338 | end) : 339 | DUMPABLE with type t := t) 340 | end 341 | -------------------------------------------------------------------------------- /hmx/test/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name test_hmx) 3 | (inline_tests) 4 | (preprocess 5 | (pps ppx_expect)) 6 | (libraries base hmx)) 7 | -------------------------------------------------------------------------------- /hmx/test/test_infer.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Hmx 3 | 4 | let env = 5 | Env.empty 6 | |> Env.assume_val "fix" "a . (a -> a) -> a" 7 | |> Env.assume_val "head" "a . list[a] -> a" 8 | |> Env.assume_val "tail" "a . list[a] -> list[a]" 9 | |> Env.assume_val "nil" "a .
list[a]" 10 | |> Env.assume_val "cons" "a . (a, list[a]) -> list[a]" 11 | |> Env.assume_val "cons_curry" "a . a -> list[a] -> list[a]" 12 | |> Env.assume_val "map" "a, b . (a -> b, list[a]) -> list[b]" 13 | |> Env.assume_val "map_curry" "a, b . (a -> b) -> list[a] -> list[b]" 14 | |> Env.assume_val "one" "int" 15 | |> Env.assume_val "zero" "int" 16 | |> Env.assume_val "succ" "int -> int" 17 | |> Env.assume_val "plus" "(int, int) -> int" 18 | |> Env.assume_val "eq" "a . (a, a) -> bool" 19 | |> Env.assume_val "eq_curry" "a . a -> a -> bool" 20 | |> Env.assume_val "not" "bool -> bool" 21 | |> Env.assume_val "true" "bool" 22 | |> Env.assume_val "false" "bool" 23 | |> Env.assume_val "pair" "a, b . (a, b) -> pair[a, b]" 24 | |> Env.assume_val "pair_curry" "a, b . a -> b -> pair[a, b]" 25 | |> Env.assume_val "first" "a, b . pair[a, b] -> a" 26 | |> Env.assume_val "second" "a, b . pair[a, b] -> b" 27 | |> Env.assume_val "id" "a . a -> a" 28 | |> Env.assume_val "const" "a, b . a -> b -> a" 29 | |> Env.assume_val "apply" "a, b . (a -> b, a) -> b" 30 | |> Env.assume_val "apply_curry" "a, b . (a -> b) -> a -> b" 31 | |> Env.assume_val "choose" "a . (a, a) -> a" 32 | |> Env.assume_val "choose_curry" "a .
a -> a -> a" 33 | |> Env.assume_val "age" "int" 34 | |> Env.assume_val "world" "string" 35 | |> Env.assume_val "print" "string -> string" 36 | 37 | let infer ?(env = env) code = 38 | Var.reset (); 39 | let prog = Expr.parse_string code in 40 | match infer ~env prog with 41 | | Ok e -> Caml.Format.printf "%s@.|" (Expr.show e) 42 | | Error err -> Caml.Format.printf "ERROR: %s@.|" (Type_error.show err) 43 | 44 | let%expect_test "" = 45 | infer "world"; 46 | [%expect {| 47 | let _ : string = world in 48 | _ 49 | | |}] 50 | 51 | let%expect_test "" = 52 | infer "print"; 53 | [%expect {| 54 | let _ : string -> string = print in 55 | _ 56 | | |}] 57 | 58 | let%expect_test "" = 59 | infer "let x = world in x"; 60 | [%expect 61 | {| 62 | let _ : string = 63 | let x : string = world in 64 | x 65 | in 66 | _ 67 | | |}] 68 | 69 | let%expect_test "" = 70 | infer "fun () -> world"; 71 | [%expect {| 72 | let _ : () -> string = fun () -> world in 73 | _ 74 | | |}] 75 | 76 | let%expect_test "" = 77 | infer "let x = fun () -> world in world"; 78 | [%expect 79 | {| 80 | let _ : string = 81 | let x : () -> string = fun () -> world in 82 | world 83 | in 84 | _ 85 | | |}] 86 | 87 | let%expect_test "" = 88 | infer "let x = fun () -> world in x"; 89 | [%expect 90 | {| 91 | let _ : () -> string = 92 | let x : () -> string = fun () -> world in 93 | x 94 | in 95 | _ 96 | | |}] 97 | 98 | let%expect_test "" = 99 | infer "print(world)"; 100 | [%expect {| 101 | let _ : string = print(world) in 102 | _ 103 | | |}] 104 | 105 | let%expect_test "" = 106 | infer "let hello = fun msg -> print(msg) in hello(world)"; 107 | [%expect 108 | {| 109 | let _ : string = 110 | let hello : string -> string = fun msg -> print(msg) in 111 | hello(world) 112 | in 113 | _ 114 | | |}] 115 | 116 | let%expect_test "" = 117 | infer "fun x -> let y = fun z -> z in y"; 118 | [%expect 119 | {| 120 | let _ : a, b . a -> b -> b = 121 | fun x -> 122 | let y : c .
c -> c = fun z -> z in y 123 | in 124 | _ 125 | | |}] 126 | 127 | let%expect_test "" = 128 | infer "fun x -> let y = x in y"; 129 | [%expect 130 | {| 131 | let _ : a . a -> a = 132 | fun x -> 133 | let y : a = x in y 134 | in 135 | _ 136 | | |}] 137 | 138 | let%expect_test "" = 139 | infer "fun x -> let y = fun z -> x in y"; 140 | [%expect 141 | {| 142 | let _ : a, b . b -> a -> b = 143 | fun x -> 144 | let y : c . c -> b = fun z -> x in y 145 | in 146 | _ 147 | | |}] 148 | 149 | let%expect_test "" = 150 | infer "id"; 151 | [%expect {| 152 | let _ : a . a -> a = id in 153 | _ 154 | | |}] 155 | 156 | let%expect_test "" = 157 | infer "one"; 158 | [%expect {| 159 | let _ : int = one in 160 | _ 161 | | |}] 162 | 163 | let%expect_test "" = 164 | infer "x"; 165 | [%expect {| 166 | ERROR: unknown name: x 167 | | |}] 168 | 169 | let%expect_test "" = 170 | infer "let x = x in x"; 171 | [%expect {| 172 | ERROR: unknown name: x 173 | | |}] 174 | 175 | let%expect_test "" = 176 | infer "let x = id in x"; 177 | [%expect 178 | {| 179 | let _ : a . a -> a = 180 | let x : b . b -> b = id in 181 | x 182 | in 183 | _ 184 | | |}] 185 | 186 | let%expect_test "" = 187 | infer "let x = fun y -> y in x"; 188 | [%expect 189 | {| 190 | let _ : a . a -> a = 191 | let x : b . b -> b = fun y -> y in 192 | x 193 | in 194 | _ 195 | | |}] 196 | 197 | let%expect_test "" = 198 | infer "fun x -> x"; 199 | [%expect {| 200 | let _ : a . a -> a = fun x -> x in 201 | _ 202 | | |}] 203 | 204 | let%expect_test "" = 205 | infer "pair"; 206 | [%expect {| 207 | let _ : a, b . (b, a) -> pair[b, a] = pair in 208 | _ 209 | | |}] 210 | 211 | let%expect_test "" = 212 | infer "fun x -> let y = fun z -> z in y"; 213 | [%expect 214 | {| 215 | let _ : a, b . a -> b -> b = 216 | fun x -> 217 | let y : c .
c -> c = fun z -> z in y 218 | in 219 | _ 220 | | |}] 221 | 222 | let%expect_test "" = 223 | infer "let f = fun x -> x in let id = fun y -> y in eq(f, id)"; 224 | [%expect 225 | {| 226 | let _ : bool = 227 | let f : a . a -> a = fun x -> x in 228 | let id : b . b -> b = fun y -> y in 229 | eq(f, id) 230 | in 231 | _ 232 | | |}] 233 | 234 | let%expect_test "" = 235 | infer "let f = fun x -> x in let id = fun y -> y in eq_curry(f)(id)"; 236 | [%expect 237 | {| 238 | let _ : bool = 239 | let f : a . a -> a = fun x -> x in 240 | let id : b . b -> b = fun y -> y in 241 | eq_curry(f)(id) 242 | in 243 | _ 244 | | |}] 245 | 246 | let%expect_test "" = 247 | infer "let f = fun x -> x in eq(f, succ)"; 248 | [%expect 249 | {| 250 | let _ : bool = 251 | let f : a . a -> a = fun x -> x in 252 | eq(f, succ) 253 | in 254 | _ 255 | | |}] 256 | 257 | let%expect_test "" = 258 | infer "let f = fun x -> x in eq_curry(f)(succ)"; 259 | [%expect 260 | {| 261 | let _ : bool = 262 | let f : a . a -> a = fun x -> x in 263 | eq_curry(f)(succ) 264 | in 265 | _ 266 | | |}] 267 | 268 | let%expect_test "" = 269 | infer "let f = fun x -> x in pair(f(one), f(true))"; 270 | [%expect 271 | {| 272 | let _ : pair[int, bool] = 273 | let f : a . a -> a = fun x -> x in 274 | pair(f(one), f(true)) 275 | in 276 | _ 277 | | |}] 278 | 279 | let%expect_test "" = 280 | infer "fun f -> pair(f(one), f(true))"; 281 | [%expect 282 | {| 283 | ERROR: incompatible types: 284 | int 285 | and 286 | bool 287 | | |}] 288 | 289 | let%expect_test "" = 290 | infer "let f = fun (x, y) -> let a = eq(x, y) in eq(x, y) in f"; 291 | [%expect 292 | {| 293 | let _ : a . (a, a) -> bool = 294 | let f : b . (b, b) -> bool = 295 | fun (x, y) -> 296 | let a : bool = eq(x, y) in eq(x, y) 297 | in 298 | f 299 | in 300 | _ 301 | | |}] 302 | 303 | let%expect_test "" = 304 | infer "let f = fun (x, y) -> let a = eq_curry(x)(y) in eq_curry(x)(y) in f"; 305 | [%expect 306 | {| 307 | let _ : a . (a, a) -> bool = 308 | let f : b .
(b, b) -> bool = 309 | fun (x, y) -> 310 | let a : bool = eq_curry(x)(y) in eq_curry(x)(y) 311 | in 312 | f 313 | in 314 | _ 315 | | |}] 316 | 317 | let%expect_test "" = 318 | infer "id(id)"; 319 | [%expect {| 320 | let _ : a . a -> a = id(id) in 321 | _ 322 | | |}] 323 | 324 | let%expect_test "" = 325 | infer "choose(fun (x, y) -> x, fun (x, y) -> y)"; 326 | [%expect 327 | {| 328 | let _ : a . (a, a) -> a = 329 | choose(fun (x, y) -> x, fun (x, y) -> y) 330 | in 331 | _ 332 | | |}] 333 | 334 | let%expect_test "" = 335 | infer "choose_curry(fun (x, y) -> x)(fun (x, y) -> y)"; 336 | [%expect 337 | {| 338 | let _ : a . (a, a) -> a = 339 | choose_curry(fun (x, y) -> x)(fun (x, y) -> y) 340 | in 341 | _ 342 | | |}] 343 | 344 | let%expect_test "" = 345 | infer "let x = id in let y = let z = x(id) in z in y"; 346 | [%expect 347 | {| 348 | let _ : a . a -> a = 349 | let x : b . b -> b = id in 350 | let y : c . c -> c = 351 | let z : d . d -> d = x(id) in 352 | z 353 | in 354 | y 355 | in 356 | _ 357 | | |}] 358 | 359 | let%expect_test "" = 360 | infer "cons(id, nil)"; 361 | [%expect {| 362 | let _ : a . list[a -> a] = cons(id, nil) in 363 | _ 364 | | |}] 365 | 366 | let%expect_test "" = 367 | infer "cons_curry(id)(nil)"; 368 | [%expect 369 | {| 370 | let _ : a . list[a -> a] = cons_curry(id)(nil) in 371 | _ 372 | | |}] 373 | 374 | let%expect_test "" = 375 | infer "let lst1 = cons(id, nil) in let lst2 = cons(succ, lst1) in lst2"; 376 | [%expect 377 | {| 378 | let _ : list[int -> int] = 379 | let lst1 : a .
list[a -> a] = cons(id, nil) in 380 | let lst2 : list[int -> int] = cons(succ, lst1) in 381 | lst2 382 | in 383 | _ 384 | | |}] 385 | 386 | let%expect_test "" = 387 | infer "cons_curry(id)(cons_curry(succ)(cons_curry(id)(nil)))"; 388 | [%expect 389 | {| 390 | let _ : list[int -> int] = 391 | cons_curry(id)(cons_curry(succ)(cons_curry(id)(nil))) 392 | in 393 | _ 394 | | |}] 395 | 396 | let%expect_test "" = 397 | infer "plus(one, true)"; 398 | [%expect 399 | {| 400 | ERROR: incompatible types: 401 | int 402 | and 403 | bool 404 | | |}] 405 | 406 | let%expect_test "" = 407 | infer "plus(one)"; 408 | [%expect 409 | {| 410 | ERROR: incompatible types: 411 | _2 -> _1 412 | and 413 | (int, int) -> int 414 | | |}] 415 | 416 | let%expect_test "" = 417 | infer "fun x -> let y = x in y"; 418 | [%expect 419 | {| 420 | let _ : a . a -> a = 421 | fun x -> 422 | let y : a = x in y 423 | in 424 | _ 425 | | |}] 426 | 427 | let%expect_test "" = 428 | infer "fun x -> let y = let z = x(fun x -> x) in z in y"; 429 | [%expect 430 | {| 431 | let _ : a, b . ((a -> a) -> b) -> b = 432 | fun x -> 433 | let y : b = 434 | let z : b = x(fun x -> x) in 435 | z 436 | in 437 | y 438 | in 439 | _ 440 | | |}] 441 | 442 | let%expect_test "" = 443 | infer "fun x -> fun y -> let x = x(y) in x(y)"; 444 | [%expect 445 | {| 446 | let _ : a, b . (a -> a -> b) -> a -> b = 447 | fun x -> 448 | fun y -> 449 | let x : a -> b = x(y) in x(y) 450 | in 451 | _ 452 | | |}] 453 | 454 | let%expect_test "" = 455 | infer "fun x -> let y = fun z -> x(z) in y"; 456 | [%expect 457 | {| 458 | let _ : a, b . (a -> b) -> a -> b = 459 | fun x -> 460 | let y : a -> b = fun z -> x(z) in y 461 | in 462 | _ 463 | | |}] 464 | 465 | let%expect_test "" = 466 | infer "fun x -> let y = fun z -> x in y"; 467 | [%expect 468 | {| 469 | let _ : a, b . b -> a -> b = 470 | fun x -> 471 | let y : c .
c -> b = fun z -> x in y 472 | in 473 | _ 474 | | |}] 475 | 476 | let%expect_test "" = 477 | infer "fun x -> fun y -> let x = x(y) in fun x -> y(x)"; 478 | [%expect 479 | {| 480 | let _ : a, b, c . ((b -> c) -> a) -> (b -> c) -> b -> c = 481 | fun x -> 482 | fun y -> 483 | let x : a = x(y) in fun x -> y(x) 484 | in 485 | _ 486 | | |}] 487 | 488 | let%expect_test "" = 489 | infer "fun x -> let y = x in y(y)"; 490 | [%expect {| 491 | ERROR: recursive type 492 | | |}] 493 | 494 | let%expect_test "" = 495 | infer "fun x -> let y = fun z -> z in y(y)"; 496 | [%expect 497 | {| 498 | let _ : a, b . a -> b -> b = 499 | fun x -> 500 | let y : c . c -> c = fun z -> z in y(y) 501 | in 502 | _ 503 | | |}] 504 | 505 | let%expect_test "" = 506 | infer "fun x -> x(x)"; 507 | [%expect {| 508 | ERROR: recursive type 509 | | |}] 510 | 511 | let%expect_test "" = 512 | infer "one(id)"; 513 | [%expect 514 | {| 515 | ERROR: incompatible types: 516 | _2 -> _1 517 | and 518 | int 519 | | |}] 520 | 521 | let%expect_test "" = 522 | infer "fun f -> let x = fun (g, y) -> let _ = g(y) in eq(f, g) in x"; 523 | [%expect 524 | {| 525 | let _ : a, b . (a -> b) -> (a -> b, a) -> bool = 526 | fun f -> 527 | let x : (a -> b, a) -> bool = 528 | fun (g, y) -> 529 | let _ : b = g(y) in eq(f, g) 530 | in 531 | x 532 | in 533 | _ 534 | | |}] 535 | 536 | let%expect_test "" = 537 | infer "let const = fun x -> fun y -> x in const"; 538 | [%expect 539 | {| 540 | let _ : a, b . b -> a -> b = 541 | let const : c, d . d -> c -> d = fun x -> fun y -> x in 542 | const 543 | in 544 | _ 545 | | |}] 546 | 547 | let%expect_test "" = 548 | infer "let apply = fun (f, x) -> f(x) in apply"; 549 | [%expect 550 | {| 551 | let _ : a, b . (a -> b, a) -> b = 552 | let apply : c, d . (c -> d, c) -> d = 553 | fun (f, x) -> f(x) 554 | in 555 | apply 556 | in 557 | _ 558 | | |}] 559 | 560 | let%expect_test "" = 561 | infer "let apply_curry = fun f -> fun x -> f(x) in apply_curry"; 562 | [%expect 563 | {| 564 | let _ : a, b .
(a -> b) -> a -> b = 565 | let apply_curry : c, d . (c -> d) -> c -> d = 566 | fun f -> fun x -> f(x) 567 | in 568 | apply_curry 569 | in 570 | _ 571 | | |}] 572 | -------------------------------------------------------------------------------- /hmx/type_error.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Syntax 3 | 4 | type t = 5 | | Error_unification of ty * ty 6 | | Error_recursive_type 7 | | Error_unknown_name of string 8 | 9 | include ( 10 | Showable (struct 11 | type nonrec t = t 12 | 13 | let layout = 14 | let open PPrint in 15 | function 16 | | Error_unification (ty1, ty2) -> 17 | string "incompatible types:" 18 | ^^ nest 2 (break 1 ^^ Ty.layout ty1) 19 | ^^ break 1 20 | ^^ string "and" 21 | ^^ nest 2 (break 1 ^^ Ty.layout ty2) 22 | | Error_recursive_type -> string "recursive type" 23 | | Error_unknown_name name -> string "unknown name: " ^^ string name 24 | end) : 25 | SHOWABLE with type t := t) 26 | 27 | exception Type_error of t 28 | 29 | let raise error = raise (Type_error error) 30 | -------------------------------------------------------------------------------- /hmx/union_find.ml: -------------------------------------------------------------------------------- 1 | open! Base 2 | 3 | type 'a loc = Root of 'a | Link of 'a t 4 | 5 | and 'a t = 'a loc ref [@@deriving sexp_of] 6 | 7 | let make value = ref (Root value) 8 | 9 | let rec root p : _ t = 10 | match p.contents with 11 | | Root _ -> p 12 | | Link p' -> 13 | let p'' = root p' in 14 | (* Perform path compression.
*) 15 | if not (phys_equal p' p'') then p.contents <- p'.contents; (* After the recursive call the parent already links directly to the root, so copying its link makes [p] point at the root too. *) 16 | p'' 17 | 18 | let value p = 19 | match (root p).contents with 20 | | Root value -> value 21 | | Link _ -> assert false 22 | 23 | let union ~f a b = 24 | if phys_equal a b then () 25 | else 26 | let a = root a in 27 | let b = root b in 28 | if phys_equal a b then () 29 | else 30 | match (a.contents, b.contents) with 31 | | Root avalue, Root bvalue -> 32 | a.contents <- Link b; 33 | b.contents <- Root (f avalue bvalue) 34 | | Root _, Link _ 35 | | Link _, Root _ 36 | | Link _, Link _ -> 37 | assert false 38 | 39 | let equal a b = phys_equal a b || phys_equal (root a) (root b) 40 | -------------------------------------------------------------------------------- /hmx/union_find.mli: -------------------------------------------------------------------------------- 1 | (** Union find. *) 2 | 3 | open! Base 4 | 5 | type 'a t 6 | (** Represents a single element. 7 | 8 | Each element belongs to an equivalence class and each equivalence class has 9 | a value of type ['a] associated with it. *) 10 | 11 | val make : 'a -> 'a t 12 | (** [make v] creates a new equivalence class consisting of a single element 13 | which is returned to the caller. 14 | 15 | The value [v] is associated with the equivalence class being created. *) 16 | 17 | val value : 'a t -> 'a 18 | (** [value e] returns the value associated with the equivalence class that the 19 | element [e] belongs to. *) 20 | 21 | val union : f:('a -> 'a -> 'a) -> 'a t -> 'a t -> unit 22 | (** [union ~f a b] makes elements [a] and [b] belong to the same equivalence 23 | class so that [equal a b] returns [true] afterwards. 24 | 25 | The value of the merged class is [f va vb], where [va] and [vb] are the 26 | values previously associated with the classes of [a] and [b]. *) 27 | 28 | val equal : 'a t -> 'a t -> bool 29 | (** [equal a b] checks that both elements [a] and [b] belong to the same 30 | equivalence class.
*) 31 | 32 | val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t 33 | -------------------------------------------------------------------------------- /hmx/var.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Syntax 3 | 4 | type t = var 5 | 6 | module Id = MakeId () 7 | 8 | let fresh ?lvl () : var = 9 | let id = Id.fresh () in 10 | Union_find.make { ty = None; lvl; id } 11 | 12 | let reset = Id.reset 13 | 14 | let ty v = (Union_find.value v).ty 15 | 16 | let set_ty ty v = (Union_find.value v).ty <- Some ty 17 | 18 | let lvl var = 19 | let v = Union_find.value var in 20 | match v.lvl with 21 | | Some lvl -> lvl 22 | | None -> failwith (Printf.sprintf "%i has no lvl assigned" v.id) 23 | 24 | let set_lvl lvl v = (Union_find.value v).lvl <- Some lvl 25 | 26 | let equal = Union_find.equal 27 | 28 | let show v = 29 | let data = Union_find.value v in 30 | match data.ty with 31 | | None -> Printf.sprintf "_%i" data.id 32 | | Some ty -> Ty.show ty 33 | 34 | let merge_lvl lvl1 lvl2 = 35 | match (lvl1, lvl2) with 36 | | None, None 37 | | Some _, None 38 | | None, Some _ -> 39 | failwith "lvl is not assigned" 40 | | Some lvl1, Some lvl2 -> Some (min lvl1 lvl2) 41 | 42 | (** [occurs_check_adjust_lvl var ty] checks that variable [var] is not 43 | contained within type [ty] and adjust levels of all unbound vars within 44 | the [ty]. 
*) 45 | let occurs_check_adjust_lvl var = 46 | let rec occurs_check_ty ty' : unit = 47 | match ty' with 48 | | Ty_const _ -> () 49 | | Ty_arr (args, ret) -> 50 | List.iter args ~f:occurs_check_ty; 51 | occurs_check_ty ret 52 | | Ty_app (f, args) -> 53 | occurs_check_ty f; 54 | List.iter args ~f:occurs_check_ty 55 | | Ty_var other_var -> ( 56 | match ty other_var with 57 | | Some ty' -> occurs_check_ty ty' 58 | | None -> 59 | if equal other_var var then Type_error.raise Error_recursive_type 60 | else 61 | let data = Union_find.value var 62 | and odata = Union_find.value other_var in 63 | odata.lvl <- merge_lvl data.lvl odata.lvl) 64 | in 65 | occurs_check_ty 66 | 67 | let unify var1 var2 = 68 | let merge v1 v2 = 69 | let v = 70 | match (v1.ty, v2.ty) with 71 | | Some _, Some _ 72 | | Some _, None 73 | | None, None -> 74 | v1 75 | | None, Some _ -> v2 76 | in 77 | v.lvl <- merge_lvl v1.lvl v2.lvl; 78 | v 79 | in 80 | match (ty var1, ty var2) with 81 | | Some ty1, Some ty2 -> 82 | Union_find.union var1 var2 ~f:merge; 83 | Some (ty1, ty2) 84 | | Some ty1, None -> 85 | occurs_check_adjust_lvl var2 ty1; 86 | Union_find.union var1 var2 ~f:merge; 87 | None 88 | | None, Some ty2 -> 89 | occurs_check_adjust_lvl var1 ty2; 90 | Union_find.union var1 var2 ~f:merge; 91 | None 92 | | None, None -> 93 | Union_find.union var1 var2 ~f:merge; 94 | None 95 | -------------------------------------------------------------------------------- /hmx/var.mli: -------------------------------------------------------------------------------- 1 | open Syntax 2 | 3 | type t = var 4 | 5 | val fresh : ?lvl:lvl -> unit -> t 6 | 7 | val reset : unit -> unit 8 | 9 | val equal : t -> t -> bool 10 | 11 | val show : t -> string 12 | 13 | val lvl : t -> lvl 14 | 15 | val set_lvl : lvl -> t -> unit 16 | 17 | val ty : t -> ty option 18 | 19 | val set_ty : ty -> t -> unit 20 | 21 | val unify : t -> t -> (ty * ty) option 22 | 23 | val occurs_check_adjust_lvl : t -> ty -> unit 24 | 
-------------------------------------------------------------------------------- /hmx_tc/bin/dune: -------------------------------------------------------------------------------- 1 | (executable 2 | (public_name hmx_tc) 3 | (name main) 4 | (libraries hmx_tc base stdio)) 5 | -------------------------------------------------------------------------------- /hmx_tc/bin/main.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Hmx_tc 3 | 4 | let () = 5 | let env = 6 | Env.empty 7 | |> Env.assume_val "one" "int" 8 | |> Env.assume_val "hello" "string" 9 | |> Env.assume_val "world" "string" 10 | |> Env.assume_val "pair" "a, b . (a, b) -> pair[a, b]" 11 | |> Env.assume_val "triple" "a, b, c . (a, b, c) -> triple[a, b, c]" 12 | |> Env.assume_val "quadruple" 13 | "a, b, c, e . (a, b, c, e) -> quadruple[a, b, c, e]" 14 | |> Env.assume_val "plus" "(int, int) -> int" 15 | |> Env.assume_val "true" "bool" 16 | |> Env.assume_val "nil" "a . list[a]" 17 | |> Env.assume_val "cons" "a . (a, list[a]) -> list[a]" 18 | (* eq *) 19 | |> Env.assume_tclass "eq" "a . (a, a) -> bool" 20 | |> Env.assume_tclass_instance "eq_int" "eq[int]" 21 | |> Env.assume_tclass_instance "eq_bool" "eq[bool]" 22 | |> Env.assume_tclass_instance "eq_list" "a . eq[a] => eq[list[a]]" 23 | |> Env.assume_tclass_instance "eq_pair" 24 | "a, b . eq[a], eq[b] => eq[pair[a, b]]" 25 | (* compare[a] *) 26 | |> Env.assume_tclass "compare" "a . eq[a] => (a, a) -> bool" 27 | |> Env.assume_tclass_instance "compare_list" 28 | "a . compare[a] => compare[list[a]]" 29 | |> Env.assume_tclass_instance "compare_int" "compare[int]" 30 | |> Env.assume_tclass_instance "compare_bool" "compare[bool]" 31 | (* coerce[a, b] *) 32 | |> Env.assume_tclass "coerce" "a, b . a -> b" 33 | |> Env.assume_tclass_instance "coerce_id" "a . coerce[a, a]" 34 | |> Env.assume_tclass_instance "coerce_list" 35 | "a, b . 
coerce[a, b] => coerce[list[a], list[b]]" 36 | |> Env.assume_tclass_instance "coerce_bool_int" "coerce[bool, int]" 37 | (* show *) 38 | |> Env.assume_tclass "show" "a . a -> string" 39 | |> Env.assume_tclass_instance "show_int" "show[int]" 40 | |> Env.assume_tclass_instance "show_float" "show[float]" 41 | (* (1* read *1) *) 42 | |> Env.assume_tclass "read" "a . string -> a" 43 | |> Env.assume_tclass_instance "read_int" "read[int]" 44 | |> Env.assume_tclass_instance "read_float" "read[float]" 45 | (* |> Env.assume_tclass "eq2" "a . (a, a) -> bool" *) 46 | (* |> Env.assume_tclass_instance "eq2_int" "eq2[int]" *) 47 | (* |> Env.assume_tclass_instance "eq2_list" "a . eq2[a] => eq2[list[a]]" *) 48 | (* |> Env.assume_val "ch" "a . compare[a], eq2[a] => (a, a) -> bool" *) 49 | in 50 | let e = Expr.parse_chan Stdio.stdin in 51 | match infer ~env e with 52 | | Ok e -> Caml.Format.printf "%s@." (Expr.show e) 53 | | Error err -> Caml.Format.printf "ERROR: %s@." (Type_error.show err) 54 | -------------------------------------------------------------------------------- /hmx_tc/constraint.ml: -------------------------------------------------------------------------------- 1 | (** 2 | 3 | This module defines constraint language. 4 | 5 | The constraint language has applicative structure (see [C_map] constructor) 6 | which is used for elaboration (the constraint solving algo computes a value). 7 | 8 | *) 9 | 10 | open Base 11 | open Syntax 12 | 13 | type _ t = 14 | | C_trivial : unit t 15 | (** A trivial constraint, states nothing useful. Always can be solved. *) 16 | | C_eq : ty * ty -> unit t 17 | (** [C_eq (ty1, ty2)] states that the types [ty1] and [ty2] are equal. *) 18 | | C_inst : name * ty -> expr t 19 | (** [C_inst (name, ty)] states that [name] should be fetched from the 20 | environment, instantiated and equated to [ty]. *) 21 | | C_and : 'a t * 'b t -> ('a * 'b) t 22 | (** Conjuction of two constraints, possibly of different value type. 
*) 23 | | C_and_list : 'a t list -> 'a list t 24 | (** Conjuction of multiple constraints of the same value type. *) 25 | | C_exists : var list * 'a t -> 'a t 26 | (** [C_exists (vs, c)] existentially quantifies variables [vs] over [c]. *) 27 | | C_let : 28 | (name * var list * expr t * ty) list * 'b t 29 | -> ((expr * ty_sch) list * 'b) t 30 | (** [C_let (bindings, c)] works is a constraint abstraction fused with 31 | existential quantification. It adds [bindings] to the environment of 32 | the following constraint [c]. 33 | 34 | It allows to define multiple names at once to support n-ary functions. *) 35 | | C_map : 'a t * ('a -> 'b) -> 'b t 36 | (** Map operation, this gives an applicative structure. *) 37 | 38 | let trivial = C_trivial 39 | 40 | let return v = C_map (C_trivial, fun () -> v) 41 | 42 | let ( =~ ) x y = C_eq (x, y) 43 | 44 | let inst name cty = C_inst (name, cty) 45 | 46 | let exists tys c = 47 | match tys with 48 | | [] -> c 49 | | tys -> ( 50 | match c with 51 | | C_exists (tys', c) -> C_exists (tys @ tys', c) 52 | | c -> C_exists (tys, c)) 53 | 54 | let let_in bindings c = C_let (bindings, c) 55 | 56 | let ( &~ ) x y = C_and (x, y) 57 | 58 | let ( >>| ) c f = C_map (c, f) 59 | 60 | let list cs = C_and_list cs 61 | 62 | let rec layout' : type a. names:Names.t -> a t -> PPrint.document = 63 | fun ~names c -> 64 | let open PPrint in 65 | match c with 66 | | C_trivial -> string "TRUE" 67 | | C_eq (ty1, ty2) -> 68 | layout_ty' ~names ty1 ^^ string " = " ^^ layout_ty' ~names ty2 69 | | C_and (a, b) -> layout' ~names a ^^ string " & " ^^ layout' ~names b 70 | | C_and_list cs -> 71 | let sep = string " & " in 72 | separate sep (List.map cs ~f:(layout' ~names)) 73 | | C_exists (vs, c) -> 74 | string "∃" 75 | ^^ separate comma (List.map vs ~f:(layout_con_var' ~names)) 76 | ^^ dot 77 | ^^ parens (layout' ~names c) 78 | | C_let (bindings, c) -> 79 | let layout_cty' : type a. 
a t * ty -> document = function 80 | | C_trivial, ty -> layout_ty' ~names ty 81 | | c, ty -> layout' ~names c ^^ string " => " ^^ layout_ty' ~names ty 82 | in 83 | let layout_binding : type a. string * var list * a t * ty -> document = 84 | fun (name, vs, c, ty) -> 85 | string name 86 | ^^ string " : " 87 | ^^ 88 | match vs with 89 | | [] -> layout_cty' (c, ty) 90 | | vs -> 91 | let vs = layout_var_prenex' ~names vs in 92 | vs ^^ layout_cty' (c, ty) 93 | in 94 | let sep = comma ^^ blank 1 in 95 | string "let " 96 | ^^ separate sep (List.map bindings ~f:layout_binding) 97 | ^^ string " in " 98 | ^^ layout' ~names c 99 | | C_inst (name, ty) -> string name ^^ string " ≲ " ^^ layout_ty' ~names ty 100 | | C_map (c, _f) -> layout' ~names c 101 | 102 | include ( 103 | struct 104 | type pack = P : _ t -> pack 105 | 106 | include Showable (struct 107 | type t = pack 108 | 109 | let layout (P c) = 110 | let names = Names.make () in 111 | layout' ~names c 112 | end) 113 | 114 | let layout c = layout (P c) 115 | 116 | let show c = show (P c) 117 | 118 | let print ?label c = print ?label (P c) 119 | end : 120 | sig 121 | val layout : _ t -> PPrint.document 122 | 123 | val show : _ t -> string 124 | 125 | val print : ?label:string -> _ t -> unit 126 | end) 127 | -------------------------------------------------------------------------------- /hmx_tc/debug.ml: -------------------------------------------------------------------------------- 1 | open! 
Base 2 | 3 | let flags = Caml.Sys.getenv_opt "HMX_DEBUG" |> Option.value ~default:"" 4 | 5 | let log_levels = String.mem flags 'l' 6 | 7 | let log_solve = String.mem flags 's' 8 | 9 | let log_instantiate = String.mem flags 'i' 10 | 11 | let log_generalize = String.mem flags 'g' 12 | 13 | let log_unify = String.mem flags 'u' 14 | 15 | let log_match = String.mem flags 'm' 16 | 17 | let log_define = String.mem flags 'd' 18 | -------------------------------------------------------------------------------- /hmx_tc/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name hmx_tc) 3 | (preprocess (pps ppx_sexp_conv)) 4 | (libraries base pprint nice_parser) 5 | ) 6 | 7 | (ocamllex lexer) 8 | 9 | (menhir 10 | (modules parser) 11 | (flags --explain --dump)) 12 | -------------------------------------------------------------------------------- /hmx_tc/hmx_tc.ml: -------------------------------------------------------------------------------- 1 | module Expr = struct 2 | include Syntax.Expr 3 | 4 | include Nice_parser.Make (struct 5 | type token = Parser.token 6 | 7 | type result = Syntax.expr 8 | 9 | let parse = Parser.expr_eof 10 | 11 | let next_token = Lexer.token 12 | 13 | exception ParseError = Parser.Error 14 | 15 | exception LexError = Lexer.Error 16 | end) 17 | end 18 | 19 | module Ty = struct 20 | include Syntax.Ty 21 | end 22 | 23 | module Ty_sch = struct 24 | include Syntax.Ty_sch 25 | 26 | include Nice_parser.Make (struct 27 | type token = Parser.token 28 | 29 | type result = Syntax.ty_sch 30 | 31 | let parse = Parser.ty_sch_eof 32 | 33 | let next_token = Lexer.token 34 | 35 | exception ParseError = Parser.Error 36 | 37 | exception LexError = Lexer.Error 38 | end) 39 | end 40 | 41 | module Bound = struct 42 | include Syntax.Bound 43 | end 44 | 45 | module Var = Var 46 | module Type_error = Type_error 47 | 48 | module Env = struct 49 | include Infer.Env 50 | 51 | let assume_val name ty env = add_val env name 
(Ty_sch.parse_string ty) 52 | 53 | let assume_tclass name ty env = add_tclass env name (Ty_sch.parse_string ty) 54 | 55 | let assume_tclass_instance name ty env = 56 | add_tclass_instance env name (Ty_sch.parse_string ty) 57 | end 58 | 59 | let infer = Infer.infer 60 | -------------------------------------------------------------------------------- /hmx_tc/infer.ml: -------------------------------------------------------------------------------- 1 | open! Base 2 | open Syntax 3 | 4 | (** Constraints generation. *) 5 | let rec generate : expr -> ty -> expr Constraint.t = 6 | fun e ty -> 7 | match e with 8 | | E_lit (Lit_int _) -> Constraint.(ty =~ Ty_const "int" >>| fun () -> e) 9 | | E_lit (Lit_string _) -> Constraint.(ty =~ Ty_const "string" >>| fun () -> e) 10 | | E_var name -> Constraint.inst name ty 11 | | E_app (f, args) -> 12 | let vs, arg_tys, args = 13 | List.fold args ~init:([], [], []) ~f:(fun (vs, arg_tys, args) arg -> 14 | let v = Var.fresh () in 15 | let arg_ty = Ty.var v in 16 | let arg = generate arg arg_ty in 17 | (v :: vs, arg_ty :: arg_tys, arg :: args)) 18 | in 19 | let f_ty = Ty.(arr (List.rev arg_tys) ty) in 20 | let f = generate f f_ty in 21 | let args = Constraint.list (List.rev args) in 22 | Constraint.(exists vs (f &~ args) >>| fun (f, args) -> E_app (f, args)) 23 | | E_abs (args, b) -> 24 | let vs, bindings, atys = 25 | List.fold args ~init:([], [], []) ~f:(fun (vs, bindings, atys) param -> 26 | let v = Var.fresh () in 27 | let ty = Ty.var v in 28 | ( v :: vs, 29 | (param, [], Constraint.return (E_var param), ty) :: bindings, 30 | Ty.var v :: atys )) 31 | in 32 | let v = Var.fresh () in 33 | let b = generate b (Ty.var v) in 34 | let ty' = Ty.arr (List.rev atys) (Ty.var v) in 35 | Constraint.( 36 | exists (v :: vs) (let_in (List.rev bindings) b &~ (ty' =~ ty)) 37 | >>| fun ((_tys, b), ()) -> E_abs (args, b)) 38 | | E_let ((name, e, _), b) -> ( 39 | let v = Var.fresh () in 40 | let e_ty = Ty.var v in 41 | let e = generate e e_ty in 42 | let 
b = generate b ty in 43 | Constraint.( 44 | let_in [ (name, [ v ], e, e_ty) ] b >>| fun (tys, b) -> 45 | match tys with 46 | | [ (e, ty_sch) ] -> E_let ((name, e, Some ty_sch), b) 47 | | _ -> failwith "impossible as we are supplying a single binding")) 48 | | E_tclass_method _ -> 49 | (* This is a part of elaborated terms so we can't reach it here. Consider 50 | splitting source terms and elaborated terms. *) 51 | assert false 52 | 53 | (* [unify ty1 ty2] unifies two types [ty1] and [ty2]. *) 54 | let rec unify ty1 ty2 = 55 | if Debug.log_unify then 56 | Caml.Format.printf "UNIFY %s ~ %s@." (Ty.show ty1) (Ty.show ty2); 57 | if phys_equal ty1 ty2 then () 58 | else 59 | match (ty1, ty2) with 60 | | Ty_const a, Ty_const b -> 61 | if not String.(a = b) then Type_error.raise (Error_unification (ty1, ty2)) 62 | | Ty_app (a1, b1), Ty_app (a2, b2) -> ( 63 | unify a1 a2; 64 | match List.iter2 b1 b2 ~f:unify with 65 | | Unequal_lengths -> Type_error.raise (Error_unification (ty1, ty2)) 66 | | Ok () -> ()) 67 | | Ty_arr (a1, b1), Ty_arr (a2, b2) -> ( 68 | match List.iter2 a1 a2 ~f:unify with 69 | | Unequal_lengths -> Type_error.raise (Error_unification (ty1, ty2)) 70 | | Ok () -> unify b1 b2) 71 | | Ty_var var1, Ty_var var2 -> ( 72 | match Var.unify var1 var2 with 73 | | Some (ty1, ty2) -> unify ty1 ty2 74 | | None -> ()) 75 | | Ty_var var, ty 76 | | ty, Ty_var var -> ( 77 | match Var.ty var with 78 | | None -> 79 | Var.occurs_check_adjust_lvl var ty; 80 | Var.set_ty ty var 81 | | Some ty' -> unify ty' ty) 82 | | _, _ -> Type_error.raise (Error_unification (ty1, ty2)) 83 | 84 | (* [matches ty1 ty2] checks if [ty2] matches [ty1] shape. It does destructive 85 | substitution inside the [ty2] argument but keeps [ty1] intact and therefore 86 | [ty1] can be used in multiple [matches] calls. 
*) 87 | let rec matches ty1 ty2 = 88 | if phys_equal ty1 ty2 then true 89 | else 90 | match (ty1, ty2) with 91 | | Ty_const a, Ty_const b -> String.(a = b) 92 | | Ty_app (a1, b1), Ty_app (a2, b2) -> ( 93 | matches a1 a2 94 | && 95 | match 96 | List.fold2 b1 b2 ~init:true ~f:(fun ok b1 b2 -> ok && matches b1 b2) 97 | with 98 | | Unequal_lengths -> false 99 | | Ok ok -> ok) 100 | | Ty_arr (a1, b1), Ty_arr (a2, b2) -> ( 101 | match 102 | List.fold2 a1 a2 ~init:true ~f:(fun ok a1 a2 -> ok && matches a1 a2) 103 | with 104 | | Unequal_lengths -> false 105 | | Ok ok -> ok && matches b1 b2) 106 | | Ty_var var1, Ty_var var2 -> ( 107 | if Var.equal var1 var2 then true 108 | else 109 | match (Var.ty var1, Var.ty var2) with 110 | | (Some _ | None), None -> 111 | if Var.bound_lvl var1 < Var.bound_lvl var2 then ( 112 | if Debug.log_match then 113 | Caml.Format.printf "MATCH %s <-! %s@." (Var.show var1) 114 | (Var.show var2); 115 | Union_find.link ~target:var1 var2; 116 | true) 117 | else false 118 | | Some ty1, Some ty2 -> matches ty1 ty2 119 | | None, Some _ -> assert false) 120 | | Ty_var var1, ty2 -> ( 121 | match Var.ty var1 with 122 | | Some ty1 -> matches ty1 ty2 123 | | None -> false) 124 | | ty1, Ty_var var2 -> ( 125 | match Var.ty var2 with 126 | | None -> 127 | Var.set_ty ty1 var2; 128 | true 129 | | Some ty2 -> matches ty1 ty2) 130 | | _, _ -> false 131 | 132 | let matches ty1 ty2 = 133 | if Debug.log_match then 134 | Caml.Format.printf "MATCH %s <-? %s@." (Ty.show ty1) (Ty.show ty2); 135 | matches ty1 ty2 136 | 137 | let matches_bound b1 b2 = matches (Bound.to_ty b1) (Bound.to_ty b2) 138 | 139 | (** A substitution over [ty] terms. 
*) 140 | module Subst : sig 141 | type t 142 | 143 | val empty : t 144 | 145 | val make : (var * ty) list -> t 146 | 147 | val apply_cty : t -> cty -> cty 148 | 149 | val apply_ty : t -> ty -> ty 150 | 151 | val apply_bound : t -> bound -> bound 152 | end = struct 153 | type t = (var * ty) list 154 | 155 | let empty = [] 156 | 157 | let make pairs = pairs 158 | 159 | let find subst var = 160 | match List.Assoc.find ~equal:Var.equal subst var with 161 | | Some ty -> Some ty 162 | | None -> None 163 | 164 | let rec apply_cty subst (bs, ty) = 165 | (List.map bs ~f:(apply_bound subst), apply_ty subst ty) 166 | 167 | and apply_ty subst ty = 168 | match ty with 169 | | Ty_const _ -> ty 170 | | Ty_app (a, args) -> 171 | Ty_app (apply_ty subst a, List.map args ~f:(apply_ty subst)) 172 | | Ty_arr (args, b) -> 173 | Ty_arr (List.map args ~f:(apply_ty subst), apply_ty subst b) 174 | | Ty_var v -> ( 175 | match Var.ty v with 176 | | Some ty -> apply_ty subst ty 177 | | None -> ( 178 | match find subst v with 179 | | Some ty -> ty 180 | | None -> ty)) 181 | 182 | and apply_bound subst = function 183 | | B_class (name, tys) -> B_class (name, List.map tys ~f:(apply_ty subst)) 184 | end 185 | 186 | (** Implicit params represent function parameters to be filled in by typeclass 187 | instances. *) 188 | module Implicit_param : sig 189 | type t 190 | 191 | val make : lvl:lvl -> string -> t 192 | 193 | val link : target:t -> t -> unit 194 | 195 | val reset : unit -> unit 196 | 197 | val param : t -> name 198 | end = struct 199 | type t = name Union_find.t 200 | 201 | module Id = MakeId () 202 | 203 | let make ~lvl name = 204 | let name = Printf.sprintf "_%s_%i_%i" name lvl (Id.fresh ()) in 205 | Union_find.make name 206 | 207 | let link = Union_find.link 208 | 209 | let param = Union_find.value 210 | 211 | let reset = Id.reset 212 | end 213 | 214 | (** An applicative structure which is used to build a computation which 215 | elaborates terms. 
*) 216 | module Elab : sig 217 | type 'a t 218 | 219 | val run : 'a t -> 'a 220 | (** Compute a value. 221 | 222 | This should be run at the very end, when all holes are elaborated. *) 223 | 224 | val return : 'a -> 'a t 225 | 226 | val map : 'a t -> ('a -> 'b) -> 'b t 227 | 228 | val both : 'a t -> 'b t -> ('a * 'b) t 229 | 230 | val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t 231 | 232 | val ( and+ ) : 'a t -> 'b t -> ('a * 'b) t 233 | 234 | val list : 'a t list -> 'a list t 235 | 236 | module Hole : sig 237 | type 'a hole = private 238 | | Hole_expr of { bounds : bound list; mutable elab : 'a t } 239 | | Hole_method of { 240 | bound : bound; 241 | name : string; 242 | mutable elab : 'a t option; 243 | } 244 | | Hole_implicit_param of 'a implicit_param_hole 245 | 246 | and 'a implicit_param_hole = private { 247 | bound : bound; 248 | param : Implicit_param.t; 249 | mutable elab : 'a t; 250 | } 251 | 252 | val implicit_param : 253 | lvl:lvl -> bound -> name -> expr t * expr implicit_param_hole 254 | 255 | val make_implicit_param : expr implicit_param_hole -> expr t * expr hole 256 | 257 | val make_expr : bound list -> expr -> expr t * expr hole 258 | 259 | val make_method : bound -> name -> expr t * expr hole 260 | 261 | val set_elab : 'a t -> 'a hole -> unit 262 | end 263 | end = struct 264 | type 'a t = unit -> 'a 265 | 266 | let run elab = elab () 267 | 268 | let return v () = v 269 | 270 | let map v f () = f (v ()) 271 | 272 | let both a b () = (a (), b ()) 273 | 274 | let ( let+ ) = map 275 | 276 | let ( and+ ) = both 277 | 278 | let list es () = List.map es ~f:run 279 | 280 | module Hole = struct 281 | type 'a hole = 282 | | Hole_expr of { bounds : bound list; mutable elab : 'a t } 283 | | Hole_method of { 284 | bound : bound; 285 | name : string; 286 | mutable elab : 'a t option; 287 | } 288 | | Hole_implicit_param of 'a implicit_param_hole 289 | 290 | and 'a implicit_param_hole = { 291 | bound : bound; 292 | param : Implicit_param.t; 293 | mutable elab : 'a t; 
294 | } 295 | 296 | let of_hole hole () = 297 | match hole with 298 | | Hole_expr e -> e.elab () 299 | | Hole_method { elab = Some expr; name; _ } -> 300 | E_tclass_method (expr (), name) 301 | | Hole_method { elab = None; _ } -> failwith "unresolved typeclass method" 302 | | Hole_implicit_param { elab; _ } -> elab () 303 | 304 | let of_implicit_param param () = E_var (Implicit_param.param param) 305 | 306 | let implicit_param ~lvl bound name = 307 | let param = Implicit_param.make ~lvl name in 308 | let elab = of_implicit_param param in 309 | let hole = { elab; param; bound } in 310 | (of_hole (Hole_implicit_param hole), hole) 311 | 312 | let make_implicit_param param = 313 | let hole = Hole_implicit_param param in 314 | (of_hole hole, hole) 315 | 316 | let make_expr bounds expr = 317 | let hole = Hole_expr { bounds; elab = return expr } in 318 | (of_hole hole, hole) 319 | 320 | let make_method bound name = 321 | let hole = Hole_method { bound; name; elab = None } in 322 | (of_hole hole, hole) 323 | 324 | let set_elab elab hole = 325 | match hole with 326 | | Hole_expr hole -> hole.elab <- elab 327 | | Hole_method hole -> hole.elab <- Some elab 328 | | Hole_implicit_param hole -> hole.elab <- elab 329 | end 330 | end 331 | 332 | (** Instantiate type scheme into a constrained type. 
*) 333 | let instantiate ?bound_lvl ~lvl (ty_sch : ty_sch) : cty = 334 | match ty_sch with 335 | | [], ty -> 336 | (* No ∀-quantified variables, return the type as-is *) 337 | ty 338 | | vars, cty -> 339 | let subst = 340 | Subst.make 341 | (List.map vars ~f:(fun v -> (v, Ty.var (Var.fresh ?bound_lvl ~lvl ())))) 342 | in 343 | Subst.apply_cty subst cty 344 | 345 | let instantiate ?bound_lvl ~lvl ty_sch = 346 | if Debug.log_instantiate then Ty_sch.print ~label:"I<" ty_sch; 347 | let cty = instantiate ?bound_lvl ~lvl ty_sch in 348 | if Debug.log_instantiate then Cty.print ~label:"I>" cty; 349 | cty 350 | 351 | module Env0 = struct 352 | type t = { 353 | values : (name, def_kind * def, String.comparator_witness) Map.t; 354 | tclasses : (name, tclass, String.comparator_witness) Map.t; 355 | } 356 | 357 | and def = { name : name; ty_sch : ty_sch } 358 | 359 | and def_kind = Def_value | Def_method 360 | 361 | and tclass = { def : def; method_def : def; instances : def list } 362 | 363 | let empty = 364 | { values = Map.empty (module String); tclasses = Map.empty (module String) } 365 | 366 | let find_val env name = Map.find env.values name 367 | 368 | let find_tclass env name = Map.find env.tclasses name 369 | 370 | let add_val ?(kind = Def_value) env name ty_sch = 371 | if Debug.log_define then 372 | Caml.Format.printf "val %s : %s@." 
name (Ty_sch.show ty_sch); 373 | let () = 374 | match (kind, ty_sch) with 375 | | Def_method, (_, ([ _ ], _)) -> () 376 | | Def_method, (_, (_, _)) -> failwith "method with multiple constraints" 377 | | Def_value, _ -> () 378 | in 379 | { 380 | env with 381 | values = Map.set env.values ~key:name ~data:(kind, { name; ty_sch }); 382 | } 383 | end 384 | 385 | let rec ty_resolved = function 386 | | Ty_var v -> ( 387 | match Var.ty v with 388 | | Some ty -> ty_resolved ty 389 | | None -> false) 390 | | Ty_const _ -> true 391 | | Ty_arr (args, ret) -> List.for_all args ~f:ty_resolved && ty_resolved ret 392 | | Ty_app (head, args) -> ty_resolved head && List.for_all args ~f:ty_resolved 393 | 394 | let is_unresolved_v = function 395 | | Ty_var v -> Option.is_none (Var.ty v) 396 | | _ -> false 397 | 398 | (** Elaborates on all the holes and returns a list of params and bounds in the 399 | Head First Normal (HNF) form. 400 | 401 | Bounds in HNF are considered solved bounds (there's nothing to solve further 402 | about them). *) 403 | module Hole_elaborator : sig 404 | val elaborate : 405 | lvl:id -> 406 | env:Env0.t -> 407 | expr Elab.Hole.hole list -> 408 | expr Elab.Hole.implicit_param_hole list 409 | end = struct 410 | (* Compute a list of super bounds for the bound. *) 411 | let lineage ~bound_lvl ~lvl ~env = 412 | let rec aux entailments (B_class (name, _) as bound) = 413 | match Env0.find_tclass env name with 414 | | None -> assert false 415 | | Some tclass -> ( 416 | let c_ty = Bound.to_ty bound in 417 | let bounds, ty = instantiate ~bound_lvl ~lvl tclass.def.ty_sch in 418 | match bounds with 419 | | [] -> bound :: entailments 420 | | bounds -> 421 | (* This is not just an assert but also a way to associate [ty] 422 | variables with [c_ty], so [bounds] has all vars unified as well. 
*) 423 | assert (matches c_ty ty); 424 | List.fold bounds ~init:(bound :: entailments) ~f:aux) 425 | in 426 | aux [] 427 | 428 | (* [entails others c] checks if the [c] bound can be "inferred" from [others] 429 | bounds. *) 430 | let entails ~bound_lvl ~lvl ~env others hole = 431 | let rec aux = function 432 | | [] -> None 433 | | hole' :: others' -> 434 | if 435 | List.exists (lineage ~bound_lvl ~lvl ~env hole'.Elab.Hole.bound) 436 | ~f:(fun bound' -> matches_bound hole.Elab.Hole.bound bound') 437 | then Some hole' 438 | else aux others' 439 | in 440 | aux others 441 | 442 | (* Simple bounds. *) 443 | let simpl_bounds ~bound_lvl ~lvl ~env = 444 | let rec aux simplified = function 445 | | [] -> simplified 446 | | hole :: holes -> ( 447 | let others = simplified @ holes in 448 | match entails ~bound_lvl ~lvl ~env others hole with 449 | | Some hole' -> 450 | Implicit_param.link ~target:hole'.param hole.Elab.Hole.param; 451 | aux simplified holes 452 | | None -> aux (hole :: simplified) holes) 453 | in 454 | fun holes -> List.rev (aux [] holes) 455 | 456 | (* Check if bound is in HNF. 
*) 457 | let is_hnf (B_class (_name, tys)) = 458 | let resolved = List.for_all tys ~f:ty_resolved in 459 | (not resolved) 460 | && List.for_all tys ~f:(fun ty -> is_unresolved_v ty || ty_resolved ty) 461 | 462 | let find_tclass_instance ~bound_lvl ~lvl ~env bound : Env0.def * bound list = 463 | let (B_class (name, _)) = bound in 464 | let instances = 465 | match Env0.find_tclass env name with 466 | | None -> Type_error.raise (Error_unknown_tclass name) 467 | | Some tclass -> tclass.instances 468 | in 469 | let bound_ty = Bound.to_ty bound in 470 | match 471 | List.find_map instances ~f:(fun def -> 472 | let cty = instantiate ~bound_lvl ~lvl def.Env0.ty_sch in 473 | let bounds, ty = cty in 474 | let m = matches bound_ty ty in 475 | if m then Some (def, bounds) else None) 476 | with 477 | | None -> Type_error.raise (Error_no_tclass_instance bound_ty) 478 | | Some found -> found 479 | 480 | let holes_to_hnf ~bound_lvl ~lvl ~env (holes : expr Elab.Hole.hole list) = 481 | let rec solve_bound bound = 482 | if is_hnf bound then 483 | (* Bound is in HNF, so we allocate an implicit param for it. *) 484 | let (B_class (name, _)) = bound in 485 | let elab, hole = Elab.Hole.implicit_param ~lvl bound name in 486 | (elab, [ hole ]) 487 | else 488 | (* Not in HNF, so we find an instance and recurse into its bounds. 
*) 489 | let def, bounds' = find_tclass_instance ~bound_lvl ~lvl ~env bound in 490 | let elab' = Elab.return (E_var def.name) in 491 | solve_bounds elab' bounds' 492 | (* [solve_bounds elab bounds] solves each of [bounds] with [solve_bound] and applies [elab] to the resulting elaborations in the order of [bounds] (no application when [bounds] is empty); it also returns the holes collected on the way. *) and solve_bounds elab bounds = 493 | let es, solved = 494 | List.fold bounds ~init:([], []) ~f:(fun (es, solved) bound -> 495 | let e', solved' = solve_bound bound in 496 | (e' :: es, solved' @ solved)) 497 | in 498 | let elab = 499 | Elab.( 500 | let+ es = list es 501 | and+ e = elab in 502 | match List.rev es with 503 | | [] -> e 504 | | es -> E_app (e, es)) 505 | in 506 | (elab, solved) 507 | in 508 | (* Implicit-param holes already in HNF are kept as they are; every other hole gets its elaboration set from the reduction of its bound(s). *) List.concat 509 | (List.map holes ~f:(fun hole -> 510 | match hole with 511 | | Elab.Hole.Hole_implicit_param p -> 512 | if is_hnf p.bound then [ p ] 513 | else 514 | let elab, solved = solve_bound p.bound in 515 | Elab.Hole.set_elab elab hole; 516 | solved 517 | | Hole_expr p -> 518 | let elab, solved = solve_bounds p.elab p.bounds in 519 | Elab.Hole.set_elab elab hole; 520 | solved 521 | | Hole_method p -> 522 | let elab, solved = solve_bound p.bound in 523 | Elab.Hole.set_elab elab hole; 524 | solved)) 525 | 526 | module Epoch = MakeId () 527 | (* [elaborate ~lvl ~env holes] resets [Implicit_param], takes a fresh [Epoch] id as [bound_lvl], reduces [holes] with [holes_to_hnf] and then drops entailed bounds with [simpl_bounds]. *) 528 | let elaborate ~lvl ~env (holes : expr Elab.Hole.hole list) = 529 | Implicit_param.reset (); 530 | let bound_lvl = Epoch.fresh () in 531 | let holes = holes_to_hnf ~bound_lvl ~lvl ~env holes in 532 | simpl_bounds ~bound_lvl ~lvl ~env holes 533 | end 534 | (* [simple_vs vs] keeps the vars of [vs] that are not bound to a type yet, without duplicates (the last occurrence is the one kept); the result is in reverse order. *) 535 | let simple_vs vs = 536 | let rec simple_vs vs' = function 537 | | [] -> vs' 538 | | v :: vs -> 539 | if Option.is_some (Var.ty v) then 540 | (* Skipping as the var is already bound. *) 541 | simple_vs vs' vs 542 | else if List.mem ~equal:Var.equal vs v then 543 | (* Skipping as the var is duplicated within [vs].
*) 544 | simple_vs vs' vs 545 | else simple_vs (v :: vs') vs 546 | in 547 | simple_vs [] vs 548 | (* [bound_vs b] returns the arguments of the class bound [b] that are type variables satisfying [is_unresolved_v]. *) 549 | let bound_vs (B_class (_name, tys)) = 550 | List.filter_map tys ~f:(function 551 | | Ty_var v as ty -> if is_unresolved_v ty then Some v else None 552 | | _ -> None) 553 | (* [ty_vs ~lvl ty] collects the type variables of [ty] that are not bound to a type and have a level greater than [lvl], looking through variables already bound to a type; duplicates are not removed. *) 554 | let ty_vs ~lvl ty = 555 | let rec aux vs = function 556 | | Ty_const _ -> vs 557 | | Ty_app (a, args) -> 558 | let vs = aux vs a in 559 | List.fold args ~init:vs ~f:aux 560 | | Ty_arr (args, b) -> 561 | let vs = aux vs b in 562 | List.fold args ~init:vs ~f:aux 563 | | Ty_var v -> ( 564 | match Var.ty v with 565 | | Some ty -> aux vs ty 566 | | None -> if Var.lvl v > lvl then v :: vs else vs) 567 | in 568 | aux [] ty 569 | (* [partition_bounds ~lvl holes] returns [((retained, deferred), dvs)]: a hole is deferred when its bound mentions a variable of level <= [lvl] or a variable of the current deferred set; the not-yet-deferred variables of deferred bounds are used as the deferred set of another pass over the retained holes, and [dvs] is the deferred set of the final pass. *) 570 | let partition_bounds ~lvl = 571 | (* Given a set of "deferred" variables (variables we want to prevent from 572 | being generalized) we compute whether we need to defer the current bound and the set 573 | of variables we want to "restrict" *) 574 | let should_defer dvs bound = 575 | let is_deferred_v v = Var.lvl v <= lvl || List.mem dvs v ~equal:Var.equal in 576 | let vs = bound_vs bound in 577 | let dvs', defer = 578 | List.fold vs ~init:([], false) ~f:(fun (dvs', defer) v -> 579 | if is_deferred_v v then (dvs', true) else (v :: dvs', defer)) 580 | in 581 | if defer then 582 | (* if we are going to defer this bound, then all others not-yet deferred 583 | vars should be treated as deferred next. *) 584 | Some dvs' 585 | else None 586 | in 587 | let rec aux ~dvs ~next_dvs (rbs, dbs) holes = 588 | match holes with 589 | | [] -> ( 590 | match next_dvs with 591 | | [] -> 592 | (* No deferred vars found, return what's found. *) 593 | ((rbs, dbs), dvs) 594 | | next_dvs -> 595 | (* Check retained bounds once more with new deferred vars.
*) 596 | (* Union with the current [dvs]: variables deferred by earlier passes must stay deferred, otherwise [generalize] would quantify over variables occurring in deferred bounds. *) aux ~dvs:(simple_vs (next_dvs @ dvs)) ~next_dvs:[] ([], dbs) rbs) 597 | | hole :: bounds -> ( 598 | match should_defer dvs hole.Elab.Hole.bound with 599 | | None -> aux ~dvs ~next_dvs (hole :: rbs, dbs) bounds 600 | | Some next_dvs' -> 601 | (* Ok, we want to defer this bound, collect deferred vars. *) 602 | aux ~dvs ~next_dvs:(next_dvs' @ next_dvs) (rbs, hole :: dbs) bounds) 603 | in 604 | aux ~dvs:[] ~next_dvs:[] ([], []) 605 | (* [generalize ~lvl (holes, ty)] quantifies the type variables of [ty] with level > [lvl] that are not deferred by [partition_bounds]. Retained holes supply the implicit params and bounds of the scheme (raising [Error_ambigious_tclass_application] if a retained bound mentions a variable that is not quantified); deferred holes are turned into implicit-param holes and returned for the enclosing level. *) 606 | let generalize ~lvl (holes, ty) = 607 | let gvs = simple_vs (ty_vs ~lvl ty) in 608 | let (retained, holes), dvs = partition_bounds ~lvl holes in 609 | let gvs = 610 | List.filter gvs ~f:(fun v -> not (List.mem dvs v ~equal:Var.equal)) 611 | in 612 | let holes = 613 | List.map holes ~f:(fun hole -> 614 | let _, hole = Elab.Hole.make_implicit_param hole in 615 | hole) 616 | in 617 | let params, bounds = 618 | let not_in_vs v = not (List.mem gvs v ~equal:Var.equal) in 619 | List.fold retained ~init:([], []) ~f:(fun (params, bounds) hole -> 620 | if List.exists (bound_vs hole.Elab.Hole.bound) ~f:not_in_vs then 621 | Type_error.raise 622 | (Error_ambigious_tclass_application hole.Elab.Hole.bound); 623 | (hole.param :: params, hole.bound :: bounds)) 624 | in 625 | (List.rev params, (simple_vs gvs, (List.rev bounds, ty)), holes) 626 | (* Same as [generalize] above, additionally logging its input and the resulting scheme when [Debug.log_generalize] is set. *) 627 | let generalize ~lvl (holes, ty) = 628 | (if Debug.log_generalize then 629 | let bounds = List.map holes ~f:(fun hole -> hole.Elab.Hole.bound) in 630 | Cty.print ~label:"G<" (bounds, ty)); 631 | let params, ty_sch, holes = generalize ~lvl (holes, ty) in 632 | if Debug.log_generalize then Ty_sch.print ~label:"G>" ty_sch; 633 | (params, ty_sch, holes) 634 | 635 | let rec solve : 636 | type a.
lvl:lvl -> 638 | env:Env0.t -> 639 | a Constraint.t -> 640 | a Elab.t * expr Elab.Hole.hole list = 641 | fun ~lvl ~env c -> 642 | (* Returns the elaboration of [c] together with the holes it produced. *) match c with 643 | | C_trivial -> (Elab.return (), []) 644 | | C_eq (a, b) -> 645 | unify a b; 646 | (Elab.return (), []) 647 | | C_map (c, f) -> 648 | let v, holes = solve ~lvl ~env c in 649 | (Elab.map v f, holes) 650 | | C_and (a, b) -> 651 | let a, aholes = solve ~lvl ~env a in 652 | let b, bholes = solve ~lvl ~env b in 653 | (Elab.both a b, aholes @ bholes) 654 | | C_and_list cs -> 655 | (* [vs] is accumulated in reverse and reversed back in [elab] below. *) let vs, holes = 656 | List.fold cs ~init:([], []) ~f:(fun (vs, holes) c -> 657 | let v, holes' = solve ~lvl ~env c in 658 | (v :: vs, holes' @ holes)) 659 | in 660 | let elab = 661 | Elab.( 662 | let+ vs = list vs in 663 | List.rev vs) 664 | in 665 | (elab, holes) 666 | | C_exists (vs, c) -> 667 | (* Existential variables are assigned the current level. *) List.iter vs ~f:(Var.set_lvl lvl); 668 | solve ~lvl ~env c 669 | | C_let (bindings, c) -> 670 | let env, values, holes = 671 | (* Every binding is solved in [env0], the environment outside the [C_let], so bindings do not see each other; only the body [c] sees them all. *) let env0 = env in 672 | List.fold bindings ~init:(env, [], []) 673 | ~f:(fun (env, values, holes) (name, vs, c, ty) -> 674 | (* Need to set levels here as [C_let] works as [C_exists] as well.
*) 675 | List.iter vs ~f:(Var.set_lvl (lvl + 1)); 676 | let e, ty_sch, holes' = 677 | solve_and_generalize ~lvl:(lvl + 1) ~env:env0 c ty 678 | in 679 | let env = Env0.add_val env name ty_sch in 680 | (env, e :: values, holes' @ holes)) 681 | in 682 | let v, holes' = solve ~lvl ~env c in 683 | let values = Elab.list (List.rev values) in 684 | (Elab.(both values v), holes' @ holes) 685 | | C_inst (name, ty) -> ( 686 | let kind, ty_sch = 687 | match Env0.find_val env name with 688 | | Some (kind, def) -> (kind, def.ty_sch) 689 | | None -> Type_error.raise (Error_unknown_name name) 690 | in 691 | let bounds', ty' = instantiate ~lvl ty_sch in 692 | unify ty ty'; 693 | match (kind, bounds') with 694 | | Def_value, [] -> (Elab.return (E_var name), []) 695 | | Def_value, bounds -> 696 | let elab, hole = Elab.Hole.make_expr bounds (E_var name) in 697 | (elab, [ hole ]) 698 | | Def_method, [ c ] -> 699 | let elab, hole = Elab.Hole.make_method c name in 700 | (elab, [ hole ]) 701 | (* Methods registered by [Env.add_tclass] carry exactly one class bound. *) | Def_method, _ -> assert false) 702 | (* [solve_and_generalize ~lvl ~env c ty] solves [c] at [lvl], elaborates the resulting holes and generalizes [ty] relative to [lvl - 1]; the elaboration abstracts over the generalized implicit params (if any) and is paired with the type scheme, and deferred holes are returned to the caller. *) 703 | and solve_and_generalize ~lvl ~env c ty = 704 | let e, holes = solve ~lvl ~env c in 705 | let holes = Hole_elaborator.elaborate ~lvl:(lvl - 1) ~env holes in 706 | let params, ty_sch, holes = generalize ~lvl:(lvl - 1) (holes, ty) in 707 | let e = 708 | Elab.( 709 | let+ e = e in 710 | let e = 711 | match params with 712 | | [] -> e 713 | | params -> E_abs (List.map params ~f:Implicit_param.param, e) 714 | in 715 | (e, ty_sch)) 716 | in 717 | (e, ty_sch, holes) 718 | 719 | (** [infer ~env e] infers the type scheme for expression [e]. 720 | 721 | It returns either an [Ok elaborated] where [elaborated] is an elaborated 722 | expression corresponding to [e], wrapped as [let _ : ty_sch = ... in _] so 723 | that it carries the inferred type scheme [ty_sch]. 724 | 725 | ... or in case of a type error it returns [Error err] (it fails with [Failure] if holes are left unelaborated at the top level).
*) 727 | let infer ~env e : (expr, Type_error.t) Result.t = 728 | (* To infer an expression type we first generate constraints *) 729 | let v = Var.fresh () in 730 | let ty = Ty.var v in 731 | let c = generate e ty in 732 | let c = Constraint.exists [ v ] c in 733 | try 734 | (* and then solve them and generalize! *) 735 | let e, _ty_sch, holes = solve_and_generalize ~lvl:1 ~env c ty in 736 | if List.length holes > 0 then 737 | failwith "unelaborated expressions at the top level"; 738 | Ok 739 | Elab.( 740 | run 741 | (let+ e, ty_sch = e in 742 | E_let (("_", e, Some ty_sch), E_var "_"))) 743 | with 744 | | Type_error.Type_error error -> Error error 745 | 746 | (** Now that we have all needed machinery we can extend [Env0] with methods for 747 | adding new definitions to the environment. *) 748 | module Env = struct 749 | include Env0 750 | (* [add_tclass env name ty_sch] declares the single-method typeclass [name] over the variables [vs] of [ty_sch]: the method, also named [name], is added as a [Def_method] value whose scheme has the single class bound [name[vs]] in place of the declared bounds, and the class is registered with no instances. Fails if [name] is already a typeclass. *) 751 | let add_tclass env name ty_sch = 752 | if Debug.log_define then 753 | Caml.Format.printf "typeclass %s = %s@." name (Ty_sch.show ty_sch); 754 | let def = 755 | let vs, (bounds, _) = ty_sch in 756 | let ty_sch = 757 | (vs, (bounds, Ty_app (Ty_const name, List.map vs ~f:Ty.var))) 758 | in 759 | { name; ty_sch } 760 | in 761 | let method_def = { name; ty_sch } in 762 | let env = 763 | let ty_sch = 764 | let vs, (_, ty) = ty_sch in 765 | let c = B_class (name, List.map vs ~f:Ty.var) in 766 | (vs, ([ c ], ty)) 767 | in 768 | add_val ~kind:Def_method env name ty_sch 769 | in 770 | { 771 | env with 772 | tclasses = 773 | (match 774 | Map.add env.tclasses ~key:name 775 | ~data:{ def; method_def; instances = [] } 776 | with 777 | | `Ok tclass -> tclass 778 | | `Duplicate -> 779 | failwith (Printf.sprintf "typeclass %s is already defined" name)); 780 | } 781 | (* [add_tclass_instance env name ty_sch] registers [name] as an instance whose scheme [ty_sch] must be of the form [vs . bounds => C[args]] for a known typeclass [C]: the instance is added to the instances of [C], and [name] is bound to the method type of [C] instantiated at [args], with its bounds reduced by [Hole_elaborator.elaborate]. Fails with [Failure] when [ty_sch] is not a type application, [C] is unknown, or the number of arguments does not match. *) 782 | let add_tclass_instance env name ty_sch = 783 | if Debug.log_define then 784 | Caml.Format.printf "instance %s : %s@."
name (Ty_sch.show ty_sch); 785 | let vs, bounds, cls_name, cls_args = 786 | let vs, (bounds, ty) = ty_sch in 787 | match ty with 788 | | Ty_app (Ty_const cls, args) -> (vs, bounds, cls, args) 789 | | _ -> 790 | failwith 791 | "typeclass instance should be a type application" 792 | in 793 | let instance = { name; ty_sch } in 794 | let tclass = 795 | match Map.find env.tclasses cls_name with 796 | | Some tclass -> tclass 797 | | None -> failwith (Printf.sprintf "no typeclass found: %s" cls_name) 798 | in 799 | let tclass = { tclass with instances = instance :: tclass.instances } in 800 | let env = 801 | let vs', cty = tclass.method_def.ty_sch in 802 | (* Map the typeclass variables to the instance arguments. *) let subst = 803 | match List.zip vs' cls_args with 804 | | Unequal_lengths -> 805 | (* Show the instance scheme next to the scheme of the typeclass it instantiates. *) failwith 806 | (Printf.sprintf "invalid number of arguments: %s for %s" 807 | (Ty_sch.show ty_sch) 808 | (Ty_sch.show tclass.def.ty_sch)) 809 | | Ok items -> Subst.make items 810 | in 811 | let bounds', ty' = Subst.apply_cty subst cty in 812 | (* Reduce the substituted class bounds together with the instance's own bounds. *) let bounds = 813 | let _, hole = 814 | Elab.Hole.make_expr (bounds' @ bounds) (E_var instance.name) 815 | in 816 | Hole_elaborator.elaborate ~lvl:1 ~env [ hole ] 817 | |> List.map ~f:(fun hole -> hole.Elab.Hole.bound) 818 | in 819 | add_val env name (vs, (bounds, ty')) 820 | in 821 | { env with tclasses = Map.set env.tclasses ~key:cls_name ~data:tclass } 822 | end 823 | -------------------------------------------------------------------------------- /hmx_tc/lexer.mll: -------------------------------------------------------------------------------- 1 | { 2 | 3 | open Parser 4 | 5 | exception Error of string 6 | 7 | } 8 | 9 | 10 | let ident = ['_' 'A'-'Z' 'a'-'z'] ['_' 'A'-'Z' 'a'-'z' '0'-'9']* 11 | let integer = ['0'-'9']+ 12 | 13 | rule token = parse 14 | | [' ' '\t' '\r' '\n'] { token lexbuf } 15 | | "fun" { FUN } 16 | | "let" { LET } 17 | | "rec" { REC } 18 | | "in" { IN } 19 | | "with" { WITH } 20 | | ident { IDENT (Lexing.lexeme lexbuf) } 21 | | '(' { LPAREN } 22 | | ')' {
RPAREN } 23 | | '[' { LBRACKET } 24 | | ']' { RBRACKET } 25 | | '{' { LBRACE } 26 | | '}' { RBRACE } 27 | | '=' { EQUALS } 28 | | ':' '=' { ASSIGN } 29 | | "->" { ARROW } 30 | | "=>" { GTE } 31 | | ',' { COMMA } 32 | | '.' { DOT } 33 | | ';' { SEMI } 34 | | ':' { COLON } 35 | | eof { EOF } 36 | | _ as c { raise (Error ("unexpected token: '" ^ Char.escaped c ^ "'")) } 37 | 38 | 39 | { 40 | 41 | (* Printable form of each token: the lexeme [token] recognizes for it, the identifier itself for [IDENT], and the empty string for [EOF]. *) let string_of_token = function 42 | | FUN -> "fun" 43 | | LET -> "let" 44 | | REC -> "rec" 45 | | IN -> "in" 46 | | WITH -> "with" 47 | | IDENT ident -> ident 48 | | LPAREN -> "(" 49 | | RPAREN -> ")" 50 | | LBRACKET -> "[" 51 | | RBRACKET -> "]" 52 | | LBRACE -> "{" 53 | | RBRACE -> "}" 54 | | EQUALS -> "=" 55 | | ASSIGN -> ":=" 56 | | ARROW -> "->" 57 | | COMMA -> "," 58 | | DOT -> "." 59 | | SEMI -> ";" 60 | | COLON -> ":" 61 | | GTE -> "=>" 62 | | EOF -> "" 63 | 64 | } 65 | -------------------------------------------------------------------------------- /hmx_tc/parser.mly: -------------------------------------------------------------------------------- 1 | %{ 2 | 3 | open Syntax 4 | (* [makeenv vars] resets the [Var] counter and allocates a fresh type variable for each name in [vars]; it returns the variables in the order of [vars] together with a map from each name to its variable type. *) 5 | let makeenv vars = 6 | let open Base in 7 | Var.reset (); 8 | let vs, map = List.fold_left 9 | vars 10 | ~init:([], Map.empty (module String)) 11 | ~f:(fun (vs, env) name -> 12 | let v = Var.fresh () in 13 | v::vs, 14 | Map.set env ~key:name ~data:(Ty.var v)) in 15 | List.rev vs, map 16 | (* [build_ty_sch (vs, env) (cs, ty)] builds the scheme [(vs, (cs, ty))], replacing in [cs] and [ty] every [Ty_const name] that is bound in [env]. *) 17 | let build_ty_sch (vs, env) (cs, ty) = 18 | let open Base in 19 | let rec build_ty ty = match ty with 20 | | Ty_const name -> ( 21 | match Map.find env name with 22 | | Some ty -> ty 23 | | None -> ty) 24 | | Ty_var _ -> ty 25 | | Ty_app (fty, atys) -> Ty_app (build_ty fty, List.map atys ~f:build_ty) 26 | | Ty_arr (atys, rty) -> Ty_arr (List.map atys ~f:build_ty, build_ty rty) 27 | and build_bound c = 28 | match c with 29 | | B_class (name, tys) -> 30 | let tys = List.map tys ~f:build_ty in 31 | B_class (name, tys) 32 | in 33 | vs, (List.map cs ~f:build_bound, build_ty ty) 34 | %} 35 | 36 | %token
IDENT 37 | %token FUN LET REC IN WITH 38 | %token LPAREN RPAREN LBRACKET RBRACKET LBRACE RBRACE 39 | %token ARROW EQUALS COMMA DOT SEMI COLON ASSIGN GTE 40 | %token EOF 41 | 42 | %start expr_eof 43 | %type expr_eof 44 | %start ty_sch_eof 45 | %type ty_sch_eof 46 | 47 | %% 48 | (* NOTE(review): in this copy the [%token IDENT] and [%type] declarations above show no <...> OCaml types, although the lexer builds [IDENT] from [Lexing.lexeme] (a string); presumably they were lost when the file was rendered, so confirm against the original parser.mly. *) 49 | expr_eof: 50 | e = expr EOF { e } 51 | 52 | ty_sch_eof: 53 | t = ty_sch EOF { t } 54 | 55 | expr: 56 | e = simple_expr { e } 57 | 58 | (* let-bindings *) 59 | | LET n = IDENT EQUALS e = expr IN b = expr { E_let ((n, e, None), b) } 60 | 61 | (* functions *) 62 | | FUN arg = IDENT ARROW body = expr 63 | { E_abs ([arg], body) } 64 | | FUN LPAREN args = flex_list(COMMA, IDENT) RPAREN ARROW body = expr 65 | { E_abs (args, body) } 66 | 67 | | LET n = IDENT arg = IDENT EQUALS e = expr IN b = expr 68 | { E_let ((n, E_abs ([arg], e), None), b) } 69 | | LET n = IDENT LPAREN args = flex_list(COMMA, IDENT) RPAREN EQUALS e = expr IN b = expr 70 | { E_let ((n, E_abs (args, e), None), b) } 71 | 72 | simple_expr: 73 | n = IDENT { E_var n } 74 | | LPAREN e = expr RPAREN { e } 75 | | f = simple_expr LPAREN args = flex_list(COMMA, expr) RPAREN 76 | { E_app (f, args) } 77 | 78 | ident_list: 79 | xs = nonempty_flex_list(COMMA, IDENT) { xs } 80 | (* A type scheme: an optional prefix [a, b .] of names bound as type variables, followed by a constrained type. *) 81 | ty_sch: 82 | t = cty { [], t } 83 | | vars = ident_list DOT t = cty 84 | { let env = makeenv vars in build_ty_sch env t } 85 | (* A constrained type: optional class bounds before [=>], then a type. *) 86 | cty: 87 | t = ty { [], t } 88 | | bs = nonempty_flex_list(COMMA, bound) GTE ty = ty 89 | { bs, ty } 90 | 91 | ty: 92 | t = simple_ty 93 | { t } 94 | | LPAREN RPAREN ARROW ret = ty 95 | { Ty_arr ([], ret) } 96 | | arg = simple_ty ARROW ret = ty 97 | { Ty_arr ([arg], ret) } 98 | | LPAREN arg = ty COMMA args = flex_list(COMMA, ty) RPAREN ARROW ret = ty 99 | { Ty_arr (arg :: args, ret) } 100 | 101 | simple_ty: 102 | n = IDENT { Ty_const n } 103 | | LPAREN t = ty RPAREN { t } 104 | | t = ty_app { t } 105 | 106 | ty_app: 107 | f = simple_ty LBRACKET args = nonempty_flex_list(COMMA, ty) RBRACKET 108 | { Ty_app (f, args) } 109 | 110 |
bound: 111 | t = ty_app 112 | { match t with 113 | | Ty_app (Ty_const name, args) -> B_class (name, args) 114 | (* NOTE(review): an application whose head is not a plain name, e.g. [C[a][b]] or [(a -> b)[c]], reaches [assert false] and raises [Assert_failure] rather than a syntax error. *) | _ -> assert false 115 | } 116 | 117 | (* Utilities for flexible lists (and their non-empty versions). 118 | 119 | A flexible list [flex_list(delim, X)] is a list of [X] items separated by 120 | [delim], optionally followed by a trailing [delim]. 121 | 122 | A non-empty version [nonempty_flex_list(delim, X)] of flexible list is 123 | provided as well. 124 | 125 | From http://gallium.inria.fr/blog/lr-lists/ 126 | 127 | *) 128 | 129 | flex_list(delim, X): 130 | { [] } 131 | | x = X { [x] } 132 | | x = X delim xs = flex_list(delim, X) { x::xs } 133 | 134 | nonempty_flex_list(delim, X): 135 | x = X { [x] } 136 | | x = X delim xs = flex_list(delim, X) { x::xs } 137 | -------------------------------------------------------------------------------- /hmx_tc/syntax.ml: -------------------------------------------------------------------------------- 1 | open! Base 2 | 3 | type name = string [@@deriving sexp_of] 4 | 5 | and id = int 6 | 7 | and lvl = int 8 | 9 | type expr = 10 | | E_var of name 11 | | E_abs of name list * expr 12 | | E_app of expr * expr list 13 | | E_let of (name * expr * ty_sch option) * expr 14 | | E_lit of lit 15 | | E_tclass_method of expr * string 16 | [@@deriving sexp_of] 17 | 18 | and lit = Lit_string of string | Lit_int of int 19 | 20 | and ty = 21 | | Ty_const of name 22 | | Ty_var of var 23 | | Ty_app of ty * ty list 24 | | Ty_arr of ty list * ty 25 | 26 | and var = var_data Union_find.t 27 | 28 | and var_data = { 29 | id : int; 30 | bound_lvl : int; 31 | mutable lvl : lvl option; 32 | (** Levels are assigned when we enter [C_exists] or [C_let] constraints *) 33 | mutable ty : ty option; 34 | (** Types are discovered as a result of unification.
*) 35 | } 36 | 37 | and bound = B_class of name * ty list 38 | 39 | and cty = bound list * ty 40 | 41 | and ty_sch = var list * cty 42 | (* [Names] assigns printable names to type variable ids: the name is picked from the number of ids already in the table, as a letter from "a" to "z" suffixed with the quotient by 26 once that number reaches 26 (so "a", ..., "z", "a1", ...). *) 43 | module Names : sig 44 | type t 45 | 46 | val make : unit -> t 47 | 48 | val alloc : t -> id -> string 49 | 50 | val lookup : t -> id -> string option 51 | end = struct 52 | type t = (Int.t, string) Hashtbl.t 53 | 54 | let make () = Hashtbl.create (module Int) 55 | 56 | let alloc names id = 57 | let i = Hashtbl.length names in 58 | let name = 59 | String.make 1 (Char.of_int_exn (97 + Int.rem i 26)) 60 | ^ if i >= 26 then Int.to_string (i / 26) else "" 61 | in 62 | Hashtbl.set names ~key:id ~data:name; 63 | name 64 | 65 | let lookup names id = Hashtbl.find names id 66 | end 67 | (* Generative functor giving each application its own counter: [fresh] returns 1, 2, ... and [reset] restarts the count. *) 68 | module MakeId () = struct 69 | let c = ref 0 70 | 71 | let fresh () = 72 | Int.incr c; 73 | !c 74 | 75 | let reset () = c := 0 76 | end 77 | 78 | module type SHOWABLE = sig 79 | type t 80 | 81 | val layout : t -> PPrint.document 82 | 83 | val show : t -> string 84 | 85 | val print : ?label:string -> t -> unit 86 | end 87 | (* Derives [show] (rendered with [PPrint] at width 60) and [print] (to stdout, with an optional "label: " prefix) from [layout]. *) 88 | module Showable (S : sig 89 | type t 90 | 91 | val layout : t -> PPrint.document 92 | end) : SHOWABLE with type t = S.t = struct 93 | type t = S.t 94 | 95 | let layout = S.layout 96 | 97 | let show v = 98 | let width = 60 in 99 | let buf = Buffer.create 100 in 100 | PPrint.ToBuffer.pretty 1.
width buf (S.layout v); 101 | Buffer.contents buf 102 | 103 | let print ?label v = 104 | match label with 105 | | Some label -> Caml.print_endline (label ^ ": " ^ show v) 106 | | None -> Caml.print_endline (show v) 107 | end 108 | 109 | module type DUMPABLE = sig 110 | type t 111 | 112 | val dump : ?label:string -> t -> unit 113 | 114 | val sdump : ?label:string -> t -> string 115 | end 116 | (* Derives [dump] (to stdout) and [sdump] (to a string) from [sexp_of_t], rendering with [Sexp.pp_hum] after an optional label and ending with a newline. *) 117 | module Dumpable (S : sig 118 | type t 119 | 120 | val sexp_of_t : t -> Sexp.t 121 | end) : DUMPABLE with type t = S.t = struct 122 | type t = S.t 123 | 124 | let dump ?label v = 125 | let s = S.sexp_of_t v in 126 | match label with 127 | | None -> Caml.Format.printf "%a@." Sexp.pp_hum s 128 | | Some label -> Caml.Format.printf "%s %a@." label Sexp.pp_hum s 129 | 130 | let sdump ?label v = 131 | let s = S.sexp_of_t v in 132 | match label with 133 | | None -> Caml.Format.asprintf "%a@." Sexp.pp_hum s 134 | | Some label -> Caml.Format.asprintf "%s %a@." label Sexp.pp_hum s 135 | end 136 | (* Lays out a variable as [_<id>], appending [@<lvl>] (with "!" when no level is set) if [Debug.log_levels] is on. *) 137 | let layout_var v = 138 | let open PPrint in 139 | let lvl = Option.(v.lvl |> map ~f:Int.to_string |> value ~default:"!") in 140 | if Debug.log_levels then string (Printf.sprintf "_%i@%s" v.id lvl) 141 | else string (Printf.sprintf "_%i" v.id) 142 | (* [layout_expr' ~names e] lays out [e] as a [PPrint] document; [names] is shared with the type layouts so that names allocated for type schemes are reused further down. *) 143 | let rec layout_expr' ~names = 144 | let open PPrint in 145 | function 146 | | E_var name -> string name 147 | | E_abs (args, body) -> 148 | let sep = comma ^^ blank 1 in 149 | let newline = 150 | (* Always break on let inside the body.
*) 151 | match body with 152 | | E_let _ -> hardline 153 | | _ -> break 1 154 | in 155 | let args = 156 | match args with 157 | | [ arg ] -> string arg 158 | | args -> parens (separate sep (List.map args ~f:string)) 159 | in 160 | group 161 | (group (string "fun " ^^ args ^^ string " ->") 162 | ^^ nest 2 (group (newline ^^ group (layout_expr' ~names body)))) 163 | | E_app (f, args) -> 164 | let sep = comma ^^ break 1 in 165 | group 166 | (layout_expr' ~names f 167 | ^^ parens 168 | (nest 2 169 | (group 170 | (break 0 171 | ^^ separate sep (List.map args ~f:(layout_expr' ~names)))))) 172 | | E_let _ as e -> 173 | let es = 174 | (* We do not want to print multiple nested let-expressions with indents and 175 | therefore we linearize them first and print them on the same indent instead. *) 176 | let rec linearize es e = 177 | match e with 178 | | E_let (_, b) -> linearize (e :: es) b 179 | | e -> e :: es 180 | in 181 | List.rev (linearize [] e) 182 | in 183 | let newline = 184 | (* If there's more than a single let-expression found (checking length > 2 185 | because es contains the body of the last let-expression too) we split 186 | them with a hardline. *) 187 | if List.length es > 2 then hardline else break 1 188 | in 189 | concat 190 | (List.map es ~f:(function 191 | | E_let ((name, expr, ty_sch), _) -> 192 | let ascription = 193 | (* We need to layout ty_sch first as it will allocate names for use 194 | down the road. *) 195 | match ty_sch with 196 | | None -> empty 197 | | Some ty_sch -> 198 | string " :" ^^ nest 4 (break 1 ^^ layout_ty_sch' ~names ty_sch) 199 | in 200 | let expr_newline = 201 | (* If there's [let x = let y = ... in ... in ...] then we want to 202 | force break.
*) 203 | match expr with 204 | | E_let _ -> hardline 205 | | _ -> break 1 206 | in 207 | group 208 | (group (string "let " ^^ string name ^^ ascription ^^ string " =") 209 | ^^ nest 2 (expr_newline ^^ layout_expr' ~names expr) 210 | ^^ expr_newline 211 | ^^ string "in") 212 | ^^ newline 213 | | e -> layout_expr' ~names e)) 214 | | E_lit (Lit_string v) -> dquotes (string v) 215 | (* Integer literals are printed bare so they are not confused with string literals. *) | E_lit (Lit_int v) -> string (Int.to_string v) 216 | | E_tclass_method (e, name) -> layout_expr' ~names e ^^ dot ^^ string name 217 | (* [layout_ty' ~names ty] lays out [ty], looking through variables bound to a type; unbound variables are laid out with [layout_con_var']. *) 218 | and layout_ty' ~names ty = 219 | let open PPrint in 220 | let rec is_ty_arr = function 221 | | Ty_var var -> ( 222 | match (Union_find.value var).ty with 223 | | None -> false 224 | | Some ty -> is_ty_arr ty) 225 | | Ty_arr _ -> true 226 | | _ -> false 227 | in 228 | let rec layout_ty = function 229 | | Ty_const name -> string name 230 | | Ty_arr ([ aty ], rty) -> 231 | (* Check if we can layout this as simply as [aty -> rty] in case of a 232 | single argument. *) 233 | (if is_ty_arr aty then 234 | (* If the single arg is the Ty_arr we need to wrap it in parens.
*) 235 | parens (layout_ty aty) 236 | else layout_ty aty) 237 | ^^ string " -> " 238 | ^^ layout_ty rty 239 | | Ty_arr (atys, rty) -> 240 | let sep = comma ^^ blank 1 in 241 | parens (separate sep (List.map atys ~f:layout_ty)) 242 | ^^ string " -> " 243 | ^^ layout_ty rty 244 | | Ty_app (fty, atys) -> 245 | let sep = comma ^^ blank 1 in 246 | layout_ty fty ^^ brackets (separate sep (List.map atys ~f:layout_ty)) 247 | | Ty_var var -> ( 248 | let data = Union_find.value var in 249 | match data.ty with 250 | | None -> layout_con_var' ~names var 251 | | Some ty -> layout_ty ty) 252 | in 253 | layout_ty ty 254 | (* Lays out a variable: its type when it is bound to one, else its name from [names] (followed by the raw variable in parens when [Debug.log_levels] is on), else the raw [_<id>] form. *) 255 | and layout_con_var' ~names v = 256 | let open PPrint in 257 | let v = Union_find.value v in 258 | match v.ty with 259 | | Some ty -> layout_ty' ~names ty 260 | | None -> ( 261 | match Names.lookup names v.id with 262 | | Some name -> 263 | if Debug.log_levels then string name ^^ parens (layout_var v) 264 | else string name 265 | | None -> layout_var v) 266 | (* Lays out a class bound as [C[t1, t2]]. *) 267 | and layout_bound' ~names = 268 | let open PPrint in 269 | let sep = comma ^^ blank 1 in 270 | function 271 | | B_class (name, vs) -> 272 | string name ^^ brackets (separate sep (List.map vs ~f:(layout_ty' ~names))) 273 | (* Lays out a constrained type as [b1, b2 => ty], or just [ty] when there are no bounds. *) 274 | and layout_cty' ~names = 275 | let open PPrint in 276 | function 277 | | [], ty -> layout_ty' ~names ty 278 | | bounds, ty -> 279 | let sep = comma ^^ blank 1 in 280 | group (separate sep (List.map bounds ~f:(layout_bound' ~names))) 281 | ^^ string " =>" 282 | ^^ break 1 283 | ^^ group (layout_ty' ~names ty) 284 | (* Lays out a type scheme; quantified variables get names allocated in [names] and are printed as a prenex before the constrained type. *) 285 | and layout_ty_sch' ~names ty_sch = 286 | let open PPrint in 287 | match ty_sch with 288 | | [], cty -> layout_cty' ~names cty 289 | | vs, cty -> 290 | let vs = layout_var_prenex' ~names vs in 291 | group (vs ^^ layout_cty' ~names cty) 292 | (* Allocates a name in [names] for each quantified variable and lays them out as [a, b . ]. *) 293 | and layout_var_prenex' ~names vs = 294 | let open PPrint in 295 | let sep = comma ^^ blank 1 in 296 | let vs = 297 | List.map vs ~f:(fun v -> 298 | let v = Union_find.value v in 299 | string (Names.alloc names
v.id)) 300 | in 301 | separate sep vs ^^ string " . " 302 | (* Per-category modules exposing [show]/[print] and [dump]/[sdump]; each [layout] starts from a fresh [Names] table. *) 303 | module Expr = struct 304 | type t = expr 305 | 306 | include ( 307 | Showable (struct 308 | type t = expr 309 | 310 | let layout e = layout_expr' ~names:(Names.make ()) e 311 | end) : 312 | SHOWABLE with type t := t) 313 | 314 | include ( 315 | Dumpable (struct 316 | type t = expr 317 | 318 | let sexp_of_t = sexp_of_expr 319 | end) : 320 | DUMPABLE with type t := t) 321 | end 322 | 323 | module Ty = struct 324 | type t = ty 325 | (* Constructors for arrow types and type variables. *) 326 | let arr a b = Ty_arr (a, b) 327 | 328 | let var var = Ty_var var 329 | 330 | include ( 331 | Showable (struct 332 | type t = ty 333 | 334 | let layout ty = layout_ty' ~names:(Names.make ()) ty 335 | end) : 336 | SHOWABLE with type t := t) 337 | 338 | include ( 339 | Dumpable (struct 340 | type t = ty 341 | 342 | let sexp_of_t = sexp_of_ty 343 | end) : 344 | DUMPABLE with type t := t) 345 | end 346 | 347 | module Bound = struct 348 | type t = bound 349 | (* Views the class bound [C[tys]] as the type application [C[tys]]. *) 350 | let to_ty (B_class (name, tys)) = Ty_app (Ty_const name, tys) 351 | 352 | include ( 353 | Showable (struct 354 | type t = bound 355 | 356 | let layout b = layout_bound' ~names:(Names.make ()) b 357 | end) : 358 | SHOWABLE with type t := t) 359 | 360 | include ( 361 | Dumpable (struct 362 | type t = bound 363 | 364 | let sexp_of_t = sexp_of_bound 365 | end) : 366 | DUMPABLE with type t := t) 367 | end 368 | 369 | module Cty = struct 370 | type t = cty 371 | 372 | include ( 373 | Showable (struct 374 | type t = cty 375 | 376 | let layout cty = layout_cty' ~names:(Names.make ()) cty 377 | end) : 378 | SHOWABLE with type t := t) 379 | 380 | include ( 381 | Dumpable (struct 382 | type t = cty 383 | 384 | let sexp_of_t = sexp_of_cty 385 | end) : 386 | DUMPABLE with type t := t) 387 | end 388 | 389 | module Ty_sch = struct 390 | type t = ty_sch 391 | 392 | include ( 393 | Showable (struct 394 | type t = ty_sch 395 | 396 | let layout ty_sch = layout_ty_sch' ~names:(Names.make ()) ty_sch 397 | end) : 398 | SHOWABLE with
type t := t) 399 | 400 | include ( 401 | Dumpable (struct 402 | type t = ty_sch 403 | 404 | let sexp_of_t = sexp_of_ty_sch 405 | end) : 406 | DUMPABLE with type t := t) 407 | end 408 | -------------------------------------------------------------------------------- /hmx_tc/test/dune: -------------------------------------------------------------------------------- 1 | (library 2 | (name test_hmx_tc) 3 | (inline_tests) 4 | (preprocess 5 | (pps ppx_expect)) 6 | (libraries base hmx_tc)) 7 | -------------------------------------------------------------------------------- /hmx_tc/test/test_infer.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Hmx_tc 3 | 4 | let env = 5 | Env.empty 6 | |> Env.assume_val "fix" "a . (a -> a) -> a" 7 | |> Env.assume_val "head" "a . list[a] -> a" 8 | |> Env.assume_val "tail" "a . list[a] -> list[a]" 9 | |> Env.assume_val "nil" "a . list[a]" 10 | |> Env.assume_val "cons" "a . (a, list[a]) -> list[a]" 11 | |> Env.assume_val "cons_curry" "a . a -> list[a] -> list[a]" 12 | |> Env.assume_val "map" "a, b . (a -> b, list[a]) -> list[b]" 13 | |> Env.assume_val "map_curry" "a, b . (a -> b) -> list[a] -> list[b]" 14 | |> Env.assume_val "one" "int" 15 | |> Env.assume_val "zero" "int" 16 | |> Env.assume_val "succ" "int -> int" 17 | |> Env.assume_val "plus" "(int, int) -> int" 18 | |> Env.assume_val "eq" "a . (a, a) -> bool" 19 | |> Env.assume_val "eq_curry" "a . a -> a -> bool" 20 | |> Env.assume_val "not" "bool -> bool" 21 | |> Env.assume_val "true" "bool" 22 | |> Env.assume_val "false" "bool" 23 | |> Env.assume_val "pair" "a, b . (a, b) -> pair[a, b]" 24 | |> Env.assume_val "pair_curry" "a, b . a -> b -> pair[a, b]" 25 | |> Env.assume_val "triple" "a, b, c . (a, b, c) -> triple[a, b, c]" 26 | |> Env.assume_val "quadruple" 27 | "a, b, c, e . (a, b, c, e) -> quadruple[a, b, c, e]" 28 | |> Env.assume_val "first" "a, b . pair[a, b] -> a" 29 | |> Env.assume_val "second" "a, b . 
pair[a, b] -> b" 30 | |> Env.assume_val "id" "a . a -> a" 31 | |> Env.assume_val "const" "a, b . a -> b -> a" 32 | |> Env.assume_val "apply" "a, b . (a -> b, a) -> b" 33 | |> Env.assume_val "apply_curry" "a, b . (a -> b) -> a -> b" 34 | |> Env.assume_val "choose" "a . (a, a) -> a" 35 | |> Env.assume_val "choose_curry" "a . a -> a -> a" 36 | |> Env.assume_val "age" "int" 37 | |> Env.assume_val "world" "string" 38 | |> Env.assume_val "print" "string -> string" 39 | 40 | let infer ~env code = 41 | Var.reset (); 42 | let prog = Expr.parse_string code in 43 | match infer ~env prog with 44 | | Ok e -> Caml.Format.printf "%s@.|" (Expr.show e) 45 | | Error err -> Caml.Format.printf "ERROR: %s@.|" (Type_error.show err) 46 | 47 | let%expect_test "" = 48 | infer ~env "world"; 49 | [%expect {| 50 | let _ : string = world in 51 | _ 52 | | |}] 53 | 54 | let%expect_test "" = 55 | infer ~env "print"; 56 | [%expect {| 57 | let _ : string -> string = print in 58 | _ 59 | | |}] 60 | 61 | let%expect_test "" = 62 | infer ~env "let x = world in x"; 63 | [%expect 64 | {| 65 | let _ : string = 66 | let x : string = world in 67 | x 68 | in 69 | _ 70 | | |}] 71 | 72 | let%expect_test "" = 73 | infer ~env "fun () -> world"; 74 | [%expect {| 75 | let _ : () -> string = fun () -> world in 76 | _ 77 | | |}] 78 | 79 | let%expect_test "" = 80 | infer ~env "let x = fun () -> world in world"; 81 | [%expect 82 | {| 83 | let _ : string = 84 | let x : () -> string = fun () -> world in 85 | world 86 | in 87 | _ 88 | | |}] 89 | 90 | let%expect_test "" = 91 | infer ~env "let x = fun () -> world in x"; 92 | [%expect 93 | {| 94 | let _ : () -> string = 95 | let x : () -> string = fun () -> world in 96 | x 97 | in 98 | _ 99 | | |}] 100 | 101 | let%expect_test "" = 102 | infer ~env "print(world)"; 103 | [%expect {| 104 | let _ : string = print(world) in 105 | _ 106 | | |}] 107 | 108 | let%expect_test "" = 109 | infer ~env "let hello = fun msg -> print(msg) in hello(world)"; 110 | [%expect 111 | {| 112 | let 
_ : string = 113 | let hello : string -> string = fun msg -> print(msg) in 114 | hello(world) 115 | in 116 | _ 117 | | |}] 118 | 119 | let%expect_test "" = 120 | infer ~env "fun x -> let y = fun z -> z in y"; 121 | [%expect 122 | {| 123 | let _ : a, b . a -> b -> b = 124 | fun x -> 125 | let y : c . c -> c = fun z -> z in y 126 | in 127 | _ 128 | | |}] 129 | 130 | let%expect_test "" = 131 | infer ~env "fun x -> let y = x in y"; 132 | [%expect 133 | {| 134 | let _ : a . a -> a = 135 | fun x -> 136 | let y : a = x in y 137 | in 138 | _ 139 | | |}] 140 | 141 | let%expect_test "" = 142 | infer ~env "fun x -> let y = fun z -> x in y"; 143 | [%expect 144 | {| 145 | let _ : a, b . b -> a -> b = 146 | fun x -> 147 | let y : c . c -> b = fun z -> x in y 148 | in 149 | _ 150 | | |}] 151 | 152 | let%expect_test "" = 153 | infer ~env "id"; 154 | [%expect {| 155 | let _ : a . a -> a = id in 156 | _ 157 | | |}] 158 | 159 | let%expect_test "" = 160 | infer ~env "one"; 161 | [%expect {| 162 | let _ : int = one in 163 | _ 164 | | |}] 165 | 166 | let%expect_test "" = 167 | infer ~env "x"; 168 | [%expect {| 169 | ERROR: unknown name: x 170 | | |}] 171 | 172 | let%expect_test "" = 173 | infer ~env "let x = x in x"; 174 | [%expect {| 175 | ERROR: unknown name: x 176 | | |}] 177 | 178 | let%expect_test "" = 179 | infer ~env "let x = id in x"; 180 | [%expect 181 | {| 182 | let _ : a . a -> a = 183 | let x : b . b -> b = id in 184 | x 185 | in 186 | _ 187 | | |}] 188 | 189 | let%expect_test "" = 190 | infer ~env "let x = fun y -> y in x"; 191 | [%expect 192 | {| 193 | let _ : a . a -> a = 194 | let x : b . b -> b = fun y -> y in 195 | x 196 | in 197 | _ 198 | | |}] 199 | 200 | let%expect_test "" = 201 | infer ~env "fun x -> x"; 202 | [%expect {| 203 | let _ : a . a -> a = fun x -> x in 204 | _ 205 | | |}] 206 | 207 | let%expect_test "" = 208 | infer ~env "pair"; 209 | [%expect {| 210 | let _ : a, b . 
(b, a) -> pair[b, a] = pair in 211 | _ 212 | | |}] 213 | 214 | let%expect_test "" = 215 | infer ~env "fun x -> let y = fun z -> z in y"; 216 | [%expect 217 | {| 218 | let _ : a, b . a -> b -> b = 219 | fun x -> 220 | let y : c . c -> c = fun z -> z in y 221 | in 222 | _ 223 | | |}] 224 | 225 | let%expect_test "" = 226 | infer ~env "let f = fun x -> x in let id = fun y -> y in eq(f, id)"; 227 | [%expect 228 | {| 229 | let _ : bool = 230 | let f : a . a -> a = fun x -> x in 231 | let id : b . b -> b = fun y -> y in 232 | eq(f, id) 233 | in 234 | _ 235 | | |}] 236 | 237 | let%expect_test "" = 238 | infer ~env "let f = fun x -> x in let id = fun y -> y in eq_curry(f)(id)"; 239 | [%expect 240 | {| 241 | let _ : bool = 242 | let f : a . a -> a = fun x -> x in 243 | let id : b . b -> b = fun y -> y in 244 | eq_curry(f)(id) 245 | in 246 | _ 247 | | |}] 248 | 249 | let%expect_test "" = 250 | infer ~env "let f = fun x -> x in eq(f, succ)"; 251 | [%expect 252 | {| 253 | let _ : bool = 254 | let f : a . a -> a = fun x -> x in 255 | eq(f, succ) 256 | in 257 | _ 258 | | |}] 259 | 260 | let%expect_test "" = 261 | infer ~env "let f = fun x -> x in eq_curry(f)(succ)"; 262 | [%expect 263 | {| 264 | let _ : bool = 265 | let f : a . a -> a = fun x -> x in 266 | eq_curry(f)(succ) 267 | in 268 | _ 269 | | |}] 270 | 271 | let%expect_test "" = 272 | infer ~env "let f = fun x -> x in pair(f(one), f(true))"; 273 | [%expect 274 | {| 275 | let _ : pair[int, bool] = 276 | let f : a . a -> a = fun x -> x in 277 | pair(f(one), f(true)) 278 | in 279 | _ 280 | | |}] 281 | 282 | let%expect_test "" = 283 | infer ~env "fun f -> pair(f(one), f(true))"; 284 | [%expect 285 | {| 286 | ERROR: incompatible types: 287 | int 288 | and 289 | bool 290 | | |}] 291 | 292 | let%expect_test "" = 293 | infer ~env "let f = fun (x, y) -> let a = eq(x, y) in eq(x, y) in f"; 294 | [%expect 295 | {| 296 | let _ : a . (a, a) -> bool = 297 | let f : b . 
(b, b) -> bool = 298 | fun (x, y) -> 299 | let a : bool = eq(x, y) in eq(x, y) 300 | in 301 | f 302 | in 303 | _ 304 | | |}] 305 | 306 | let%expect_test "" = 307 | infer ~env 308 | "let f = fun (x, y) -> let a = eq_curry(x)(y) in eq_curry(x)(y) in f"; 309 | [%expect 310 | {| 311 | let _ : a . (a, a) -> bool = 312 | let f : b . (b, b) -> bool = 313 | fun (x, y) -> 314 | let a : bool = eq_curry(x)(y) in eq_curry(x)(y) 315 | in 316 | f 317 | in 318 | _ 319 | | |}] 320 | 321 | let%expect_test "" = 322 | infer ~env "id(id)"; 323 | [%expect {| 324 | let _ : a . a -> a = id(id) in 325 | _ 326 | | |}] 327 | 328 | let%expect_test "" = 329 | infer ~env "choose(fun (x, y) -> x, fun (x, y) -> y)"; 330 | [%expect 331 | {| 332 | let _ : a . (a, a) -> a = 333 | choose(fun (x, y) -> x, fun (x, y) -> y) 334 | in 335 | _ 336 | | |}] 337 | 338 | let%expect_test "" = 339 | infer ~env "choose_curry(fun (x, y) -> x)(fun (x, y) -> y)"; 340 | [%expect 341 | {| 342 | let _ : a . (a, a) -> a = 343 | choose_curry(fun (x, y) -> x)(fun (x, y) -> y) 344 | in 345 | _ 346 | | |}] 347 | 348 | let%expect_test "" = 349 | infer ~env "let x = id in let y = let z = x(id) in z in y"; 350 | [%expect 351 | {| 352 | let _ : a . a -> a = 353 | let x : b . b -> b = id in 354 | let y : c . c -> c = 355 | let z : d . d -> d = x(id) in 356 | z 357 | in 358 | y 359 | in 360 | _ 361 | | |}] 362 | 363 | let%expect_test "" = 364 | infer ~env "cons(id, nil)"; 365 | [%expect {| 366 | let _ : a . list[a -> a] = cons(id, nil) in 367 | _ 368 | | |}] 369 | 370 | let%expect_test "" = 371 | infer ~env "cons_curry(id)(nil)"; 372 | [%expect 373 | {| 374 | let _ : a . list[a -> a] = cons_curry(id)(nil) in 375 | _ 376 | | |}] 377 | 378 | let%expect_test "" = 379 | infer ~env "let lst1 = cons(id, nil) in let lst2 = cons(succ, lst1) in lst2"; 380 | [%expect 381 | {| 382 | let _ : list[int -> int] = 383 | let lst1 : a . 
list[a -> a] = cons(id, nil) in 384 | let lst2 : list[int -> int] = cons(succ, lst1) in 385 | lst2 386 | in 387 | _ 388 | | |}] 389 | 390 | let%expect_test "" = 391 | infer ~env "cons_curry(id)(cons_curry(succ)(cons_curry(id)(nil)))"; 392 | [%expect 393 | {| 394 | let _ : list[int -> int] = 395 | cons_curry(id)(cons_curry(succ)(cons_curry(id)(nil))) 396 | in 397 | _ 398 | | |}] 399 | 400 | let%expect_test "" = 401 | infer ~env "plus(one, true)"; 402 | [%expect 403 | {| 404 | ERROR: incompatible types: 405 | int 406 | and 407 | bool 408 | | |}] 409 | 410 | let%expect_test "" = 411 | infer ~env "plus(one)"; 412 | [%expect 413 | {| 414 | ERROR: incompatible types: 415 | _2 -> _1 416 | and 417 | (int, int) -> int 418 | | |}] 419 | 420 | let%expect_test "" = 421 | infer ~env "fun x -> let y = x in y"; 422 | [%expect 423 | {| 424 | let _ : a . a -> a = 425 | fun x -> 426 | let y : a = x in y 427 | in 428 | _ 429 | | |}] 430 | 431 | let%expect_test "" = 432 | infer ~env "fun x -> let y = let z = x(fun x -> x) in z in y"; 433 | [%expect 434 | {| 435 | let _ : a, b . ((a -> a) -> b) -> b = 436 | fun x -> 437 | let y : b = 438 | let z : b = x(fun x -> x) in 439 | z 440 | in 441 | y 442 | in 443 | _ 444 | | |}] 445 | 446 | let%expect_test "" = 447 | infer ~env "fun x -> fun y -> let x = x(y) in x(y)"; 448 | [%expect 449 | {| 450 | let _ : a, b . (a -> a -> b) -> a -> b = 451 | fun x -> 452 | fun y -> 453 | let x : a -> b = x(y) in x(y) 454 | in 455 | _ 456 | | |}] 457 | 458 | let%expect_test "" = 459 | infer ~env "fun x -> let y = fun z -> x(z) in y"; 460 | [%expect 461 | {| 462 | let _ : a, b . (a -> b) -> a -> b = 463 | fun x -> 464 | let y : a -> b = fun z -> x(z) in y 465 | in 466 | _ 467 | | |}] 468 | 469 | let%expect_test "" = 470 | infer ~env "fun x -> let y = fun z -> x in y"; 471 | [%expect 472 | {| 473 | let _ : a, b . b -> a -> b = 474 | fun x -> 475 | let y : c . 
c -> b = fun z -> x in y 476 | in 477 | _ 478 | | |}] 479 | 480 | let%expect_test "" = 481 | infer ~env "fun x -> fun y -> let x = x(y) in fun x -> y(x)"; 482 | [%expect 483 | {| 484 | let _ : a, b, c . ((b -> c) -> a) -> (b -> c) -> b -> c = 485 | fun x -> 486 | fun y -> 487 | let x : a = x(y) in fun x -> y(x) 488 | in 489 | _ 490 | | |}] 491 | 492 | let%expect_test "" = 493 | infer ~env "fun x -> let y = x in y(y)"; 494 | [%expect {| 495 | ERROR: recursive type 496 | | |}] 497 | 498 | let%expect_test "" = 499 | infer ~env "fun x -> let y = fun z -> z in y(y)"; 500 | [%expect 501 | {| 502 | let _ : a, b . a -> b -> b = 503 | fun x -> 504 | let y : c . c -> c = fun z -> z in y(y) 505 | in 506 | _ 507 | | |}] 508 | 509 | let%expect_test "" = 510 | infer ~env "fun x -> x(x)"; 511 | [%expect {| 512 | ERROR: recursive type 513 | | |}] 514 | 515 | let%expect_test "" = 516 | infer ~env "one(id)"; 517 | [%expect 518 | {| 519 | ERROR: incompatible types: 520 | _2 -> _1 521 | and 522 | int 523 | | |}] 524 | 525 | let%expect_test "" = 526 | infer ~env "fun f -> let x = fun (g, y) -> let _ = g(y) in eq(f, g) in x"; 527 | [%expect 528 | {| 529 | let _ : a, b . (a -> b) -> (a -> b, a) -> bool = 530 | fun f -> 531 | let x : (a -> b, a) -> bool = 532 | fun (g, y) -> 533 | let _ : b = g(y) in eq(f, g) 534 | in 535 | x 536 | in 537 | _ 538 | | |}] 539 | 540 | let%expect_test "" = 541 | infer ~env "let const = fun x -> fun y -> x in const"; 542 | [%expect 543 | {| 544 | let _ : a, b . b -> a -> b = 545 | let const : c, d . d -> c -> d = fun x -> fun y -> x in 546 | const 547 | in 548 | _ 549 | | |}] 550 | 551 | let%expect_test "" = 552 | infer ~env "let apply = fun (f, x) -> f(x) in apply"; 553 | [%expect 554 | {| 555 | let _ : a, b . (a -> b, a) -> b = 556 | let apply : c, d . 
(c -> d, c) -> d = 557 | fun (f, x) -> f(x) 558 | in 559 | apply 560 | in 561 | _ 562 | | |}] 563 | 564 | let%expect_test "" = 565 | infer ~env "let apply_curry = fun f -> fun x -> f(x) in apply_curry"; 566 | [%expect 567 | {| 568 | let _ : a, b . (a -> b) -> a -> b = 569 | let apply_curry : c, d . (c -> d) -> c -> d = 570 | fun f -> fun x -> f(x) 571 | in 572 | apply_curry 573 | in 574 | _ 575 | | |}] 576 | 577 | (* typeclasses *) 578 | 579 | let env = 580 | env 581 | (* eq *) 582 | |> Env.assume_tclass "eq" "a . (a, a) -> bool" 583 | |> Env.assume_tclass_instance "eq_int" "eq[int]" 584 | |> Env.assume_tclass_instance "eq_bool" "eq[bool]" 585 | |> Env.assume_tclass_instance "eq_list" "a . eq[a] => eq[list[a]]" 586 | (* compare[a] *) 587 | |> Env.assume_tclass "compare" "a . eq[a] => (a, a) -> bool" 588 | |> Env.assume_tclass_instance "compare_int" "compare[int]" 589 | |> Env.assume_tclass_instance "compare_list" 590 | "a . compare[a] => compare[list[a]]" 591 | (* show *) 592 | |> Env.assume_tclass "show" "a . a -> string" 593 | |> Env.assume_tclass_instance "show_int" "show[int]" 594 | |> Env.assume_tclass_instance "show_float" "show[float]" 595 | (* read *) 596 | |> Env.assume_tclass "read" "a . string -> a" 597 | |> Env.assume_tclass_instance "read_int" "read[int]" 598 | |> Env.assume_tclass_instance "read_float" "read[float]" 599 | (* coerce *) 600 | |> Env.assume_tclass "coerce" "a, b . a -> b" 601 | |> Env.assume_tclass_instance "coerce_id" "a . coerce[a, a]" 602 | |> Env.assume_tclass_instance "coerce_list" 603 | "a, b . coerce[a, b] => coerce[list[a], list[b]]" 604 | |> Env.assume_tclass_instance "coerce_bool_int" "coerce[bool, int]" 605 | 606 | let%expect_test "just a sanity check" = 607 | infer ~env "eq"; 608 | [%expect 609 | {| 610 | let _ : a . 
eq[a] => (a, a) -> bool = 611 | fun _eq_0_1 -> _eq_0_1.eq 612 | in 613 | _ 614 | | |}] 615 | 616 | let%expect_test "just a sanity check" = 617 | infer ~env "let f = eq in f"; 618 | [%expect 619 | {| 620 | let _ : a . eq[a] => (a, a) -> bool = 621 | fun _eq_0_1 -> 622 | let f : b . eq[b] => (b, b) -> bool = 623 | fun _eq_1_1 -> _eq_1_1.eq 624 | in 625 | f(_eq_0_1) 626 | in 627 | _ 628 | | |}] 629 | 630 | let%expect_test "just a sanity check" = 631 | infer ~env "fun (x, y) -> eq(x, y)"; 632 | [%expect 633 | {| 634 | let _ : a . eq[a] => (a, a) -> bool = 635 | fun _eq_0_1 -> fun (x, y) -> _eq_0_1.eq(x, y) 636 | in 637 | _ 638 | | |}] 639 | 640 | let%expect_test "eq[int] should be completely resolved" = 641 | infer ~env "fun (x) -> eq(x, one)"; 642 | [%expect 643 | {| 644 | let _ : int -> bool = fun x -> eq_int.eq(x, one) in 645 | _ 646 | | |}] 647 | 648 | let%expect_test "eq[list[a]] should be resolved to eq[a]" = 649 | infer ~env "fun (x) -> eq(cons(x, nil), cons(x, nil))"; 650 | [%expect 651 | {| 652 | let _ : a . 
eq[a] => a -> bool = 653 | fun _eq_0_1 -> 654 | fun x -> eq_list(_eq_0_1).eq(cons(x, nil), cons(x, nil)) 655 | in 656 | _ 657 | | |}] 658 | 659 | let%expect_test "eq[list[int]] should be completely resolved" = 660 | infer ~env "fun (x) -> eq(cons(x, nil), cons(one, nil))"; 661 | [%expect 662 | {| 663 | let _ : int -> bool = 664 | fun x -> eq_list(eq_int).eq(cons(x, nil), cons(one, nil)) 665 | in 666 | _ 667 | | |}] 668 | 669 | let%expect_test "eq[list[int]] should be completely resolved" = 670 | infer ~env "fun (x) -> eq(x, cons(one, nil))"; 671 | [%expect 672 | {| 673 | let _ : list[int] -> bool = 674 | fun x -> eq_list(eq_int).eq(x, cons(one, nil)) 675 | in 676 | _ 677 | | |}] 678 | 679 | let%expect_test "eq[c] constraint is retained at f while eq[list[a]] is \ 680 | deferred till the top" = 681 | infer ~env 682 | {| 683 | fun y -> 684 | let f x = pair(eq(cons(x, nil), nil), eq(y, nil)) in 685 | f 686 | |}; 687 | [%expect 688 | {| 689 | let _ : 690 | a, b . eq[a], eq[b] => list[a] -> b -> pair[bool, bool] = 691 | fun (_eq_1_1, _eq_0_1) -> 692 | fun y -> 693 | let f : c . 
eq[c] => c -> pair[bool, bool] = 694 | fun _eq_1_2 -> 695 | fun x -> 696 | pair( 697 | eq_list(_eq_1_2).eq(cons(x, nil), nil), 698 | eq_list(_eq_1_1).eq(y, nil)) 699 | in 700 | f(_eq_0_1) 701 | in 702 | _ 703 | | |}] 704 | 705 | let%expect_test "should be ambigious" = 706 | infer ~env "show(read(world))"; 707 | [%expect {| 708 | ERROR: ambigious typeclass application: read[_2] 709 | | |}] 710 | 711 | let%expect_test "should be ambigious" = 712 | infer ~env "fun x -> show(read(x))"; 713 | [%expect {| 714 | ERROR: ambigious typeclass application: read[_4] 715 | | |}] 716 | 717 | let%expect_test "usage of plus resolves ambiguity" = 718 | infer ~env "show(plus(read(world), one))"; 719 | [%expect 720 | {| 721 | let _ : string = 722 | show_int.show(plus(read_int.read(world), one)) 723 | in 724 | _ 725 | | |}] 726 | 727 | let%expect_test "usage of plus resolves ambiguity" = 728 | infer ~env "fun x -> show(plus(read(x), one))"; 729 | [%expect 730 | {| 731 | let _ : string -> string = 732 | fun x -> show_int.show(plus(read_int.read(x), one)) 733 | in 734 | _ 735 | | |}] 736 | 737 | let%expect_test "just a sanity check" = 738 | infer ~env "fun (x, y) -> pair(eq(x, x), eq(y, y))"; 739 | [%expect 740 | {| 741 | let _ : a, b . eq[b], eq[a] => (b, a) -> pair[bool, bool] = 742 | fun (_eq_0_2, _eq_0_1) -> 743 | fun (x, y) -> pair(_eq_0_2.eq(x, x), _eq_0_1.eq(y, y)) 744 | in 745 | _ 746 | | |}] 747 | 748 | let%expect_test "just a sanity check" = 749 | infer ~env "fun (x, y) -> pair(eq(cons(x, nil), nil), eq(y, nil))"; 750 | [%expect 751 | {| 752 | let _ : 753 | a, b . eq[b], eq[a] => (b, list[a]) -> pair[bool, bool] = 754 | fun (_eq_0_2, _eq_0_1) -> 755 | fun (x, y) -> 756 | pair( 757 | eq_list(_eq_0_2).eq(cons(x, nil), nil), 758 | eq_list(_eq_0_1).eq(y, nil)) 759 | in 760 | _ 761 | | |}] 762 | 763 | let%expect_test "just a sanity check" = 764 | infer ~env "fun (x, y) -> pair(eq(cons(x, nil), nil), eq(cons(y, nil), nil))"; 765 | [%expect 766 | {| 767 | let _ : a, b . 
eq[b], eq[a] => (b, a) -> pair[bool, bool] = 768 | fun (_eq_0_2, _eq_0_1) -> 769 | fun (x, y) -> 770 | pair( 771 | eq_list(_eq_0_2).eq(cons(x, nil), nil), 772 | eq_list(_eq_0_1).eq(cons(y, nil), nil)) 773 | in 774 | _ 775 | | |}] 776 | 777 | let%expect_test "just a sanity check" = 778 | infer ~env "fun (x, y) -> compare(x, y)"; 779 | [%expect 780 | {| 781 | let _ : a . compare[a] => (a, a) -> bool = 782 | fun _compare_0_1 -> 783 | fun (x, y) -> _compare_0_1.compare(x, y) 784 | in 785 | _ 786 | | |}] 787 | 788 | let%expect_test "compare[a] subsumes eq[a]" = 789 | infer ~env "fun (x, y) -> pair(compare(x, y), eq(x, y))"; 790 | [%expect 791 | {| 792 | let _ : a . compare[a] => (a, a) -> pair[bool, bool] = 793 | fun _compare_0_2 -> 794 | fun (x, y) -> 795 | pair(_compare_0_2.compare(x, y), _compare_0_2.eq(x, y)) 796 | in 797 | _ 798 | | |}] 799 | 800 | let%expect_test "no compare[bool] defined" = 801 | infer ~env "fun (x, y) -> compare(x, true)"; 802 | [%expect {| 803 | ERROR: no typeclass instance found: compare[bool] 804 | | |}] 805 | 806 | let%expect_test "just a sanity check" = 807 | infer ~env "fun (x, y) -> compare(x, nil)"; 808 | [%expect 809 | {| 810 | let _ : a, b . compare[b] => (list[b], a) -> bool = 811 | fun _compare_0_1 -> 812 | fun (x, y) -> compare_list(_compare_0_1).compare(x, nil) 813 | in 814 | _ 815 | | |}] 816 | 817 | let%expect_test "just a sanity check" = 818 | infer ~env 819 | {| 820 | let f (x, y) = 821 | quadruple( 822 | eq(x, x), 823 | compare(x, x), 824 | eq(nil, y), 825 | compare(nil, y) 826 | ) 827 | in 828 | f(cons(one, nil), cons(one, nil)) 829 | |}; 830 | [%expect 831 | {| 832 | let _ : quadruple[bool, bool, bool, bool] = 833 | let f : 834 | a, b . 
compare[b], compare[a] => 835 | (b, list[a]) -> quadruple[bool, bool, bool, bool] = 836 | fun (_compare_1_3, _compare_1_1) -> 837 | fun (x, y) -> 838 | quadruple( 839 | _compare_1_3.eq(x, x), 840 | _compare_1_3.compare(x, x), 841 | eq_list(_compare_1_1).eq(nil, y), 842 | compare_list(_compare_1_1).compare(nil, y)) 843 | in 844 | f(compare_list(compare_int), compare_int)( 845 | cons(one, nil), 846 | cons(one, nil)) 847 | in 848 | _ 849 | | |}] 850 | 851 | let%expect_test "check deferred constraints" = 852 | infer ~env 853 | {| 854 | let g y = 855 | let f x = pair(eq(cons(x, nil), nil), eq(y, nil)) in 856 | one 857 | in 858 | g(cons(one, nil)) 859 | |}; 860 | [%expect 861 | {| 862 | let _ : int = 863 | let g : a . eq[a] => list[a] -> int = 864 | fun _eq_2_1 -> 865 | fun y -> 866 | let f : b . eq[b] => b -> pair[bool, bool] = 867 | fun _eq_2_2 -> 868 | fun x -> 869 | pair( 870 | eq_list(_eq_2_2).eq(cons(x, nil), nil), 871 | eq_list(_eq_2_1).eq(y, nil)) 872 | in 873 | one 874 | in 875 | g(eq_int)(cons(one, nil)) 876 | in 877 | _ 878 | | |}] 879 | 880 | let%expect_test "check deferred constraints" = 881 | infer ~env 882 | {| 883 | let g y = 884 | let f x = pair(eq(cons(x, nil), nil), eq(y, nil)) in 885 | pair(one, eq(one, one)) 886 | in 887 | g(cons(one, nil)) 888 | |}; 889 | [%expect 890 | {| 891 | let _ : pair[int, bool] = 892 | let g : a . eq[a] => list[a] -> pair[int, bool] = 893 | fun _eq_2_1 -> 894 | fun y -> 895 | let f : b . 
eq[b] => b -> pair[bool, bool] = 896 | fun _eq_2_2 -> 897 | fun x -> 898 | pair( 899 | eq_list(_eq_2_2).eq(cons(x, nil), nil), 900 | eq_list(_eq_2_1).eq(y, nil)) 901 | in 902 | pair(one, eq_int.eq(one, one)) 903 | in 904 | g(eq_int)(cons(one, nil)) 905 | in 906 | _ 907 | | |}] 908 | 909 | let%expect_test "check deferred constraints" = 910 | infer ~env 911 | {| 912 | let g y = 913 | let f x = pair(eq(cons(x, nil), nil), eq(y, nil)) in 914 | pair(one, eq(y, nil)) 915 | in 916 | g(cons(one, nil)) 917 | |}; 918 | [%expect 919 | {| 920 | let _ : pair[int, bool] = 921 | let g : a . eq[a] => list[a] -> pair[int, bool] = 922 | fun _eq_2_1 -> 923 | fun y -> 924 | let f : b . eq[b] => b -> pair[bool, bool] = 925 | fun _eq_2_2 -> 926 | fun x -> 927 | pair( 928 | eq_list(_eq_2_2).eq(cons(x, nil), nil), 929 | eq_list(_eq_2_1).eq(y, nil)) 930 | in 931 | pair(one, eq_list(_eq_2_1).eq(y, nil)) 932 | in 933 | g(eq_int)(cons(one, nil)) 934 | in 935 | _ 936 | | |}] 937 | 938 | let%expect_test "check deferred constraints" = 939 | infer ~env 940 | {| 941 | let g y = 942 | let f x = pair(eq(cons(x, nil), nil), eq(y, nil)) in 943 | pair(f(y), compare(y, nil)) 944 | in 945 | g(cons(one, nil)) 946 | |}; 947 | [%expect 948 | {| 949 | let _ : pair[pair[bool, bool], bool] = 950 | let g : 951 | a . compare[a] => 952 | list[a] -> pair[pair[bool, bool], bool] = 953 | fun _compare_1_1 -> 954 | fun y -> 955 | let f : b . eq[b] => b -> pair[bool, bool] = 956 | fun _eq_2_2 -> 957 | fun x -> 958 | pair( 959 | eq_list(_eq_2_2).eq(cons(x, nil), nil), 960 | eq_list(_compare_1_1).eq(y, nil)) 961 | in 962 | pair( 963 | f(eq_list(_compare_1_1))(y), 964 | compare_list(_compare_1_1).compare(y, nil)) 965 | in 966 | g(compare_int)(cons(one, nil)) 967 | in 968 | _ 969 | | |}] 970 | 971 | let%expect_test "multi parameter type classes" = 972 | infer ~env 973 | {| 974 | let g = 975 | fun x -> 976 | let f y = eq(coerce(x), y) in f 977 | in 978 | g 979 | |}; 980 | [%expect 981 | {| 982 | let _ : a, b . 
coerce[a, b], eq[b] => a -> b -> bool = 983 | fun (_coerce_0_1, _eq_0_2) -> 984 | let g : c, d . coerce[c, d], eq[d] => c -> d -> bool = 985 | fun (_coerce_2_2, _eq_2_1) -> 986 | fun x -> 987 | let f : d -> bool = 988 | fun y -> _eq_2_1.eq(_coerce_2_2.coerce(x), y) 989 | in 990 | f 991 | in 992 | g(_coerce_0_1, _eq_0_2) 993 | in 994 | _ 995 | | |}] 996 | 997 | let%expect_test "multi parameter type classes" = 998 | infer ~env 999 | {| 1000 | let g = 1001 | fun x -> 1002 | let f y = eq(coerce(x), y) in f 1003 | in 1004 | g(true) 1005 | |}; 1006 | [%expect 1007 | {| 1008 | let _ : a . coerce[bool, a], eq[a] => a -> bool = 1009 | fun (_coerce_0_1, _eq_0_2) -> 1010 | let g : b, c . coerce[b, c], eq[c] => b -> c -> bool = 1011 | fun (_coerce_2_2, _eq_2_1) -> 1012 | fun x -> 1013 | let f : c -> bool = 1014 | fun y -> _eq_2_1.eq(_coerce_2_2.coerce(x), y) 1015 | in 1016 | f 1017 | in 1018 | g(_coerce_0_1, _eq_0_2)(true) 1019 | in 1020 | _ 1021 | | |}] 1022 | 1023 | let%expect_test "multi parameter type classes" = 1024 | infer ~env 1025 | {| 1026 | let g = 1027 | fun x -> 1028 | let f y = eq(coerce(x), y) in f 1029 | in 1030 | g(true)(one) 1031 | |}; 1032 | [%expect 1033 | {| 1034 | let _ : bool = 1035 | let g : a, b . 
coerce[a, b], eq[b] => a -> b -> bool = 1036 | fun (_coerce_2_2, _eq_2_1) -> 1037 | fun x -> 1038 | let f : b -> bool = 1039 | fun y -> _eq_2_1.eq(_coerce_2_2.coerce(x), y) 1040 | in 1041 | f 1042 | in 1043 | g(coerce_bool_int, eq_int)(true)(one) 1044 | in 1045 | _ 1046 | | |}] 1047 | 1048 | let%expect_test "should eliminate eq[int]" = 1049 | infer ~env {| 1050 | fun x -> 1051 | let f y = eq(x, y) in f(one) 1052 | |}; 1053 | [%expect 1054 | {| 1055 | let _ : int -> bool = 1056 | fun x -> 1057 | let f : int -> bool = fun y -> eq_int.eq(x, y) in f(one) 1058 | in 1059 | _ 1060 | | |}] 1061 | 1062 | let%expect_test "should eliminate eq[int]" = 1063 | infer ~env 1064 | {| 1065 | let equal_to_one x = eq(one, coerce(x)) in 1066 | equal_to_one 1067 | |}; 1068 | [%expect 1069 | {| 1070 | let _ : a . coerce[a, int] => a -> bool = 1071 | fun _coerce_0_1 -> 1072 | let equal_to_one : b . coerce[b, int] => b -> bool = 1073 | fun _coerce_1_1 -> 1074 | fun x -> eq_int.eq(one, _coerce_1_1.coerce(x)) 1075 | in 1076 | equal_to_one(_coerce_0_1) 1077 | in 1078 | _ 1079 | | |}] 1080 | -------------------------------------------------------------------------------- /hmx_tc/type_error.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Syntax 3 | 4 | type t = 5 | | Error_unification of ty * ty 6 | | Error_recursive_type 7 | | Error_unknown_name of string 8 | | Error_ambigious_tclass_application of bound 9 | | Error_no_tclass_instance of ty 10 | | Error_unknown_tclass of name 11 | 12 | include ( 13 | Showable (struct 14 | type nonrec t = t 15 | 16 | let layout = 17 | let open PPrint in 18 | function 19 | | Error_unification (ty1, ty2) -> 20 | string "incompatible types:" 21 | ^^ nest 2 (break 1 ^^ Ty.layout ty1) 22 | ^^ break 1 23 | ^^ string "and" 24 | ^^ nest 2 (break 1 ^^ Ty.layout ty2) 25 | | Error_recursive_type -> string "recursive type" 26 | | Error_unknown_name name -> string "unknown name: " ^^ string name 27 | | 
Error_ambigious_tclass_application (B_class (name, tys)) -> 28 | string "ambigious typeclass application: " 29 | ^^ Ty.layout (Ty_app (Ty_const name, tys)) 30 | | Error_unknown_tclass name -> string "unknown typeclass: " ^^ string name 31 | | Error_no_tclass_instance ty -> 32 | string "no typeclass instance found: " ^^ Ty.layout ty 33 | end) : 34 | SHOWABLE with type t := t) 35 | 36 | exception Type_error of t 37 | 38 | let raise error = raise (Type_error error) 39 | -------------------------------------------------------------------------------- /hmx_tc/union_find.ml: -------------------------------------------------------------------------------- 1 | open! Base 2 | 3 | type 'a loc = Root of 'a | Link of 'a t 4 | 5 | and 'a t = 'a loc ref [@@deriving sexp_of] 6 | 7 | let make value = ref (Root value) 8 | 9 | let rec root p : _ t = 10 | match p.contents with 11 | | Root _ -> p 12 | | Link p' -> 13 | let p'' = root p' in 14 | (* Perform path compression. *) 15 | if not (phys_equal p' p'') then p.contents <- p'.contents; 16 | p'' 17 | 18 | let value p = 19 | match (root p).contents with 20 | | Root value -> value 21 | | Link _ -> assert false 22 | 23 | let union ~f a b = 24 | if phys_equal a b then () 25 | else 26 | let a = root a in 27 | let b = root b in 28 | if phys_equal a b then () 29 | else 30 | match (a.contents, b.contents) with 31 | | Root avalue, Root bvalue -> 32 | a.contents <- Link b; 33 | b.contents <- Root (f avalue bvalue) 34 | | Root _, Link _ 35 | | Link _, Root _ 36 | | Link _, Link _ -> 37 | assert false 38 | 39 | let link ~target p = union ~f:(fun _b target -> target) p target 40 | 41 | let equal a b = phys_equal a b || phys_equal (root a) (root b) 42 | -------------------------------------------------------------------------------- /hmx_tc/union_find.mli: -------------------------------------------------------------------------------- 1 | (** Union find. *) 2 | 3 | open! Base 4 | 5 | type 'a t 6 | (** Represents a single element. 
7 | 8 | Each element belongs to an equivalence class and each equivalence class has 9 | a value of type ['a] associated with it. *) 10 | 11 | val make : 'a -> 'a t 12 | (** [make v] creates a new equivalence class consisting of a single element 13 | which is returned to the caller. 14 | 15 | The value [v] is associated with the equivalence class being created. *) 16 | 17 | val value : 'a t -> 'a 18 | (** [value e] returns the value associated with the equivalence class the 19 | element [e] belongs to. *) 20 | 21 | val union : f:('a -> 'a -> 'a) -> 'a t -> 'a t -> unit 22 | (** [union ~f a b] makes elements [a] and [b] belong to the same equivalence 23 | class so that [equal a b] returns [true] afterwards. 24 | 25 | If [a] and [b] are in different classes, the merged class gets the value 26 | [f va vb], where [va] and [vb] are the values of the classes of [a] and [b]. *) 27 | 28 | val link : target:'a t -> 'a t -> unit 29 | (** [link ~target:a b] is [union ~f:(fun _ va -> va) b a]: it guarantees to link 30 | [b] to [a] (not vice versa) and the merged class keeps the value of [a]. *) 31 | 32 | val equal : 'a t -> 'a t -> bool 33 | (** [equal a b] checks that both elements [a] and [b] belong to the same 34 | equivalence class.
*) 35 | 36 | val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t 37 | -------------------------------------------------------------------------------- /hmx_tc/var.ml: -------------------------------------------------------------------------------- 1 | open Base 2 | open Syntax 3 | (* A type variable is a union-find element; its equivalence class carries the variable's data: bound type [ty], [lvl], [bound_lvl] and [id]. *) 4 | type t = var 5 | 6 | module Id = MakeId () 7 | (* [fresh ?bound_lvl ?lvl ()] creates a new unbound variable (its [ty] is [None]) with an id taken from [Id.fresh]; [bound_lvl] defaults to 0 and [lvl] stays unassigned unless given. *) 8 | let fresh ?(bound_lvl = 0) ?lvl () : var = 9 | let id = Id.fresh () in 10 | Union_find.make { ty = None; lvl; bound_lvl; id } 11 | (* [reset ()] resets the id generator used by [fresh]. *) 12 | let reset = Id.reset 13 | (* [bound_lvl v] is the bound level stored in the equivalence class of [v]. *) 14 | let bound_lvl v = (Union_find.value v).bound_lvl 15 | (* [ty v] is the type [v] is bound to, or [None] if [v] is unbound. *) 16 | let ty v = (Union_find.value v).ty 17 | (* [set_ty ty v] binds the equivalence class of [v] to [ty]. *) 18 | let set_ty ty v = (Union_find.value v).ty <- Some ty 19 | (* [lvl var] is the level of [var]; fails if no level has been assigned. *) 20 | let lvl var = 21 | let v = Union_find.value var in 22 | match v.lvl with 23 | | Some lvl -> lvl 24 | | None -> failwith (Printf.sprintf "%i has no lvl assigned" v.id) 25 | (* [set_lvl lvl v] assigns level [lvl] to the equivalence class of [v]. *) 26 | let set_lvl lvl v = (Union_find.value v).lvl <- Some lvl 27 | (* Two variables are equal when they belong to the same equivalence class. *) 28 | let equal = Union_find.equal 29 | (* [show v] renders the type [v] is bound to, or [_<id>] when [v] is unbound. *) 30 | let show v = 31 | let data = Union_find.value v in 32 | match data.ty with 33 | | None -> Printf.sprintf "_%i" data.id 34 | | Some ty -> Ty.show ty 35 | (* [merge_lvl lvl1 lvl2] keeps the smaller of two levels; fails if either one is unassigned. *) 36 | let merge_lvl lvl1 lvl2 = 37 | match (lvl1, lvl2) with 38 | | None, None 39 | | Some _, None 40 | | None, Some _ -> 41 | failwith "lvl is not assigned" 42 | | Some lvl1, Some lvl2 -> Some (min lvl1 lvl2) 43 | 44 | (** [occurs_check_adjust_lvl var ty] checks that variable [var] is not 45 | contained within type [ty], raising [Error_recursive_type] if it is, and 46 | lowers the level of every unbound variable in [ty] to at most the level of [var].
*) 47 | let occurs_check_adjust_lvl var = 48 | let rec occurs_check_ty ty' : unit = 49 | match ty' with 50 | | Ty_const _ -> () 51 | | Ty_arr (args, ret) -> 52 | List.iter args ~f:occurs_check_ty; 53 | occurs_check_ty ret 54 | | Ty_app (f, args) -> 55 | occurs_check_ty f; 56 | List.iter args ~f:occurs_check_ty 57 | | Ty_var other_var -> ( 58 | match ty other_var with 59 | | Some ty' -> occurs_check_ty ty' 60 | | None -> 61 | if equal other_var var then Type_error.raise Error_recursive_type 62 | else 63 | let data = Union_find.value var 64 | and odata = Union_find.value other_var in 65 | odata.lvl <- merge_lvl data.lvl odata.lvl) 66 | in 67 | occurs_check_ty 68 | (* [unify var1 var2] puts [var1] and [var2] into one equivalence class. If both are bound, their types are returned as [Some (ty1, ty2)] without being compared here; if exactly one is bound, the unbound one is first occurs-checked against that type and [None] is returned; if neither is bound, [None] is returned. The merged class keeps a record that carries a bound type when either side has one, and its level becomes the smaller of the two levels. *) 69 | let unify var1 var2 = 70 | let merge v1 v2 = 71 | let v = 72 | match (v1.ty, v2.ty) with 73 | | Some _, Some _ 74 | | Some _, None 75 | | None, None -> 76 | v1 77 | | None, Some _ -> v2 78 | in 79 | v.lvl <- merge_lvl v1.lvl v2.lvl; 80 | v 81 | in 82 | match (ty var1, ty var2) with 83 | | Some ty1, Some ty2 -> 84 | Union_find.union var1 var2 ~f:merge; 85 | Some (ty1, ty2) 86 | | Some ty1, None -> 87 | occurs_check_adjust_lvl var2 ty1; 88 | Union_find.union var1 var2 ~f:merge; 89 | None 90 | | None, Some ty2 -> 91 | occurs_check_adjust_lvl var1 ty2; 92 | Union_find.union var1 var2 ~f:merge; 93 | None 94 | | None, None -> 95 | Union_find.union var1 var2 ~f:merge; 96 | None 97 | -------------------------------------------------------------------------------- /hmx_tc/var.mli: -------------------------------------------------------------------------------- 1 | open Syntax 2 | 3 | type t = var 4 | 5 | val fresh : ?bound_lvl:lvl -> ?lvl:lvl -> unit -> t 6 | 7 | val reset : unit -> unit 8 | 9 | val equal : t -> t -> bool 10 | 11 | val show : t -> string 12 | 13 | val lvl : t -> lvl 14 | 15 | val set_lvl : lvl -> t -> unit 16 | 17 | val ty : t -> ty option 18 | 19 | val set_ty : ty -> t -> unit 20 | 21 | val bound_lvl : t -> lvl 22 | 23 | val unify : t -> t -> (ty * ty) option 24 | 25 | val
occurs_check_adjust_lvl : t -> ty -> unit 26 | -------------------------------------------------------------------------------- /type-systems.opam: -------------------------------------------------------------------------------- 1 | # This file is generated by dune, edit dune-project instead 2 | opam-version: "2.0" 3 | synopsis: "Type Systems Toy Implementations" 4 | depends: [ 5 | "dune" {>= "2.9" & >= "2.8"} 6 | "menhir" 7 | "pprint" 8 | "nice_parser" 9 | "ppx_sexp_conv" 10 | "ppx_expect" 11 | "base" 12 | "odoc" {with-doc} 13 | ] 14 | build: [ 15 | ["dune" "subst"] {dev} 16 | [ 17 | "dune" 18 | "build" 19 | "-p" 20 | name 21 | "-j" 22 | jobs 23 | "--promote-install-files=false" 24 | "@install" 25 | "@runtest" {with-test} 26 | "@doc" {with-doc} 27 | ] 28 | ["dune" "install" "-p" name "--create-install-files" name] 29 | ] 30 | --------------------------------------------------------------------------------