├── README.md ├── SConstruct ├── demo.cpp ├── inference.hpp ├── syntax.hpp └── unification.hpp /README.md: -------------------------------------------------------------------------------- 1 | hindley_milner 2 | ============== 3 | 4 | An implementation of the Hindley-Milner type checking algorithm based on the [Python code by Robert Smallshire](http://www.smallshire.org.uk/sufficientlysmall/2010/04/11/a-hindley-milner-type-inference-implementation-in-python/), the [Scala code by Andrew Forrest](http://dysphoria.net/2009/06/28/hindley-milner-type-inference-in-scala/), the Perl code by Nikita Borisov, and [the paper "Basic Polymorphic Typechecking" by Cardelli](http://lucacardelli.name/Papers/BasicTypechecking.pdf). 5 | 6 | Interestingly, this implementation makes extensive use of ```boost::variant``` and ```boost::apply_visitor```. 7 | 8 | Build the demo program with: 9 | 10 | ``` 11 | $ scons 12 | ``` 13 | 14 | Run the demo program with: 15 | 16 | ``` 17 | $ ./demo 18 | (letrec factorial = (fn n => (((cond (zero n)) 1) ((times n) (factorial (pred n))))) in (factorial 5)) : int 19 | (fn x => ((pair (x 3)) (x true))) : type mismatch: bool != int 20 | ((pair (f 4)) (f true)) : Undefined symbol f 21 | (let f = (fn x => x) in ((pair (f 4)) (f true))) : (int * bool) 22 | (fn f => (f f)) : recursive unification: a in (a -> b) 23 | (let g = (fn f => 5) in (g g)) : int 24 | (fn g => (let f = (fn x => g) in ((pair (f 3)) (f true)))) : (a -> (a * a)) 25 | (fn f => (fn g => (fn arg => (g (f arg))))) : ((a -> b) -> ((b -> c) -> (a -> c))) 26 | (fn f => (f 5)) : ((int -> a) -> a) 27 | ((fn y => (y 1)) (fn x => 1)) : int 28 | ``` 29 | -------------------------------------------------------------------------------- /SConstruct: -------------------------------------------------------------------------------- 1 | env = Environment(CCFLAGS = "-std=c++0x -Wall -g") 2 | 3 | if env['PLATFORM'] == 'darwin': 4 | env['CXX'] = '/opt/local/bin/g++-mp-4.5' 5 | 6 | sources = Glob('*.cpp') 7 | 8 | env.Program('demo', "demo.cpp") 9 | 10 | -------------------------------------------------------------------------------- /demo.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "unification.hpp" 3 | #include "syntax.hpp" 4 | #include "inference.hpp" 5 | 6 | namespace types 7 | { 8 | static const int integer = 0; 9 | static const int boolean = 1; 10 | static const int function = 2; 11 | static const int pair = 3; 12 | } 13 | 14 | class pretty_printer 15 | : public boost::static_visitor 16 | { 17 | public: 18 | inline pretty_printer(std::ostream &os) 19 | : m_os(os), 20 | m_next_name('a') 21 | {} 22 | 23 | inline pretty_printer &operator()(const unification::type_variable &x) 24 | { 25 | if(!m_names.count(x)) 26 | { 27 | std::ostringstream os; 28 | os << m_next_name++; 29 | m_names[x] = os.str(); 30 | } // end if 31 | 32 | m_os << m_names[x]; 33 | return *this; 34 | } 35 | 36 | inline pretty_printer &operator()(const unification::type_operator &x) 37 | { 38 | switch(x.kind()) 39 | { 40 | case types::integer: 41 | { 42 | m_os << "int"; 43 | break; 44 | } // end case 45 | case types::boolean: 46 | { 47 | m_os << "bool"; 48 | break; 49 | } // end case 50 | case types::function: 51 | { 52 | m_os << "("; 53 | *this << x[0]; 54 | m_os << " -> "; 55 | *this << x[1]; 56 | m_os << ")"; 57 | break; 58 | } // end case 59 | case types::pair: 60 | { 61 | m_os << "("; 62 | *this << x[0]; 63 | m_os << " * "; 64 | *this << x[1]; 65 | m_os << ")"; 66 | break; 67 | } // end case 68 | default: 69 | { 70 | } // end default 71 | } // end switch 72 | 73 | return *this; 74 | } 75 | 76 | inline pretty_printer &operator<<(const unification::type &x) 77 | { 78 | return boost::apply_visitor(*this, x); 79 | } 80 | 81 | inline pretty_printer &operator<<(std::ostream & (*fp)(std::ostream &)) 82 | { 83 | fp(m_os); 84 | return *this; 85 | } 86 | 87 | template 88 | inline pretty_printer &operator<<(const T &x) 89 | { 90 | m_os << x; 91 | return *this; 92 | } 93 | 94 | private: 95 | std::ostream &m_os; 96 | 97 | std::map m_names; 98 | char m_next_name; 99 | }; 100 | 101 | namespace unification 102 | { 103 | 104 | inline std::ostream &operator<<(std::ostream &os, const type_variable &x) 105 | { 106 | pretty_printer pp(os); 107 | pp << type(x); 108 | return os; 109 | } 110 | 111 | inline std::ostream &operator<<(std::ostream &os, const type_operator &x) 112 | { 113 | pretty_printer pp(os); 114 | pp << type(x); 115 | return os; 116 | } 117 | 118 | } 119 | 120 | struct try_to_infer 121 | { 122 | inline try_to_infer(const inference::environment &e) 123 | : env(e) 124 | {} 125 | 126 | inline void operator()(const syntax::node &n) const 127 | { 128 | try 129 | { 130 | auto result = inference::infer_type(n, env); 131 | 132 | std::cout << n << " : "; 133 | pretty_printer pp(std::cout); 134 | pp << result << std::endl; 135 | } // end try 136 | catch(const unification::recursive_unification &e) 137 | { 138 | std::cerr << n << " : "; 139 | pretty_printer pp(std::cerr); 140 | pp << e.what() << ": " << e.x << " in " << e.y << std::endl; 141 | } // end catch 142 | catch(const unification::type_mismatch &e) 143 | { 144 | std::cerr << n << " : "; 145 | pretty_printer pp(std::cerr); 146 | pp << e.what() << ": " << e.x << " != " << e.y << std::endl; 147 | } // end catch 148 | catch(const std::runtime_error &e) 149 | { 150 | std::cerr << n << " : " << e.what() << std::endl; 151 | } // end catch 152 | } // end operator() 153 | 154 | const inference::environment &env; 155 | }; 156 | 157 | int main() 158 | { 159 | using namespace unification; 160 | using namespace syntax; 161 | using namespace inference; 162 | 163 | environment env; 164 | std::vector examples; 165 | 166 | auto var1 = type_variable(env.unique_id()); 167 | auto var2 = type_variable(env.unique_id()); 168 | auto var3 = type_variable(env.unique_id()); 169 | 170 | env["pair"] = make_function(var1, inference::make_function(var2, pair(var1, var2))); 171 | env["true"] = boolean(); 172 | env["cond"] = make_function( 173 | boolean(), 174 | make_function( 175 | var3, make_function( 176 | var3, var3 177 | ) 178 | ) 179 | ); 180 | env["zero"] = make_function(integer(), boolean()); 181 | env["pred"] = make_function(integer(), integer()); 182 | env["times"] = make_function( 183 | integer(), make_function( 184 | integer(), integer() 185 | ) 186 | ); 187 | 188 | auto pair = apply(apply(identifier("pair"), apply(identifier("f"), integer_literal(4))), apply(identifier("f"), identifier("true"))); 189 | 190 | // factorial 191 | { 192 | auto example = 193 | letrec("factorial", 194 | lambda("n", 195 | apply( 196 | apply( 197 | apply(identifier("cond"), 198 | apply(identifier("zero"), identifier("n")) 199 | ), 200 | integer_literal(1) 201 | ), 202 | apply( 203 | apply(identifier("times"), identifier("n")), 204 | apply(identifier("factorial"), 205 | apply(identifier("pred"), identifier("n")) 206 | ) 207 | ) 208 | ) 209 | ), 210 | apply(identifier("factorial"), integer_literal(5)) 211 | ); 212 | examples.push_back(example); 213 | } 214 | 215 | // fn x => (pair(x(3) (x(true))) 216 | { 217 | auto example = lambda("x", 218 | apply( 219 | apply(identifier("pair"), 220 | apply(identifier("x"), integer_literal(3))), 221 | apply(identifier("x"), identifier("true")))); 222 | examples.push_back(example); 223 | } 224 | 225 | // pair(f(3), f(true)) 226 | { 227 | auto example = 228 | apply( 229 | apply(identifier("pair"), apply(identifier("f"), integer_literal(4))), 230 | apply(identifier("f"), identifier("true")) 231 | ); 232 | examples.push_back(example); 233 | } 234 | 235 | // let f = (fn x => x) in ((pair (f 4)) (f true)) 236 | { 237 | auto example = let("f", lambda("x", identifier("x")), pair); 238 | examples.push_back(example); 239 | } 240 | 241 | // fn f => f f (fail) 242 | { 243 | auto example = lambda("f", apply(identifier("f"), identifier("f"))); 244 | examples.push_back(example); 245 | } 246 | 247 | // let g = fn f => 5 in g g 248 | { 249 | auto example = let("g", 250 | lambda("f", integer_literal(5)), 251 | apply(identifier("g"), identifier("g"))); 252 | examples.push_back(example); 253 | } 254 | 255 | // example that demonstrates generic and non-generic variables 256 | // fn g => let f = fn x => g in pair (f 3, f true) 257 | { 258 | auto example = 259 | lambda("g", 260 | let("f", 261 | lambda("x", identifier("g")), 262 | apply( 263 | apply(identifier("pair"), 264 | apply(identifier("f"), integer_literal(3)) 265 | ), 266 | apply(identifier("f"), identifier("true")) 267 | ) 268 | ) 269 | ); 270 | examples.push_back(example); 271 | } 272 | 273 | // function composition 274 | // fn f (fn g (fn arg (f g arg))) 275 | { 276 | auto example = lambda("f", lambda("g", lambda("arg", apply(identifier("g"), apply(identifier("f"), identifier("arg")))))); 277 | examples.push_back(example); 278 | } 279 | 280 | // fn f => f 5 281 | { 282 | auto example = lambda("f", apply(identifier("f"), integer_literal(5))); 283 | examples.push_back(example); 284 | } 285 | 286 | // f = fn x => 1 287 | // g = fn y => y 1 288 | // (g f) 289 | { 290 | auto return_one = lambda("x", integer_literal(1)); 291 | auto apply_one = lambda("y", apply(identifier("y"), integer_literal(1))); 292 | auto example = apply(apply_one, return_one); 293 | examples.push_back(example); 294 | } 295 | 296 | auto f = try_to_infer(env); 297 | std::for_each(examples.begin(), examples.end(), f); 298 | 299 | return 0; 300 | } 301 | 302 | -------------------------------------------------------------------------------- /inference.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "unification.hpp" 10 | #include "syntax.hpp" 11 | 12 | namespace inference 13 | { 14 | 15 | using unification::type; 16 | using unification::type_variable; 17 | using unification::type_operator; 18 | 19 | std::ostream &operator<<(std::ostream &os, const std::set &x) 20 | { 21 | for(auto i = x.begin(); 22 | i != x.end(); 23 | ++i) 24 | { 25 | os << *i << " "; 26 | } 27 | 28 | return os; 29 | } 30 | 31 | std::ostream &operator<<(std::ostream &os, const std::map &x) 32 | { 33 | os << "{"; 34 | for(auto i = x.begin(); 35 | i != x.end(); 36 | ++i) 37 | { 38 | os << i->first << " : " << i->second << ", "; 39 | } 40 | os << "}"; 41 | 42 | return os; 43 | } 44 | 45 | namespace types 46 | { 47 | 48 | static const int integer = 0; 49 | static const int boolean = 1; 50 | static const int function = 2; 51 | static const int pair = 3; 52 | 53 | } 54 | 55 | inline type make_function(const type &arg, 56 | const type &result) 57 | { 58 | return type_operator(types::function, {arg, result}); 59 | } 60 | 61 | inline type integer(void) 62 | { 63 | return type_operator(types::integer); 64 | } 65 | 66 | inline type boolean(void) 67 | { 68 | return type_operator(types::boolean); 69 | } 70 | 71 | inline type pair(const type &first, 72 | const type &second) 73 | { 74 | return type_operator(types::pair, {first, second}); 75 | } 76 | 77 | inline type definitive(const std::map &substitution, const type_variable &x) 78 | { 79 | type result = x; 80 | 81 | // iteratively follow type_variables in the substitution until we can't go any further 82 | type_variable *ptr = 0; 83 | while((ptr = boost::get(&result)) && substitution.count(*ptr)) 84 | { 85 | result = substitution.find(*ptr)->second; 86 | } // end while 87 | 88 | return result; 89 | } 90 | 91 | class environment 92 | : public std::map 93 | { 94 | public: 95 | inline environment() 96 | : m_next_id(0) 97 | {} 98 | 99 | inline std::size_t unique_id() 100 | { 101 | return m_next_id++; 102 | } 103 | 104 | private: 105 | std::size_t m_next_id; 106 | }; 107 | 108 | struct fresh_maker 109 | : boost::static_visitor 110 | { 111 | inline fresh_maker(environment &env, 112 | const std::set &non_generic, 113 | const std::map &substitution) 114 | : m_env(env), 115 | m_non_generic(non_generic), 116 | m_substitution(substitution) 117 | {} 118 | 119 | inline result_type operator()(const type_variable &var) 120 | { 121 | if(is_generic(var)) 122 | { 123 | std::clog << var << " is generic" << std::endl; 124 | std::clog << "mappings: " << m_mappings << std::endl; 125 | if(!m_mappings.count(var)) 126 | { 127 | std::clog << var << " is not in mappings" << std::endl; 128 | m_mappings[var] = type_variable(m_env.unique_id()); 129 | } // end if 130 | 131 | return m_mappings[var]; 132 | } // end if 133 | 134 | std::clog << var << " is not generic" << std::endl; 135 | 136 | return var; 137 | } // end operator()() 138 | 139 | inline result_type operator()(const type_operator &op) 140 | { 141 | std::vector types(op.size()); 142 | // make sure to pass a reference to this to maintain our state 143 | std::transform(op.begin(), op.end(), types.begin(), std::ref(*this)); 144 | return type_operator(op.kind(), types); 145 | } // end operator()() 146 | 147 | inline result_type operator()(const type &x) 148 | { 149 | result_type result; 150 | // XXX hard-coded 0 here sucks 151 | if(x.which() == 0) 152 | { 153 | auto definitive_type = definitive(m_substitution, boost::get(x)); 154 | result = boost::apply_visitor(*this, definitive_type); 155 | } // end if 156 | else 157 | { 158 | result = boost::apply_visitor(*this, x); 159 | } // end else 160 | 161 | return result; 162 | } // end operator() 163 | 164 | private: 165 | inline bool is_generic(const type_variable &var) const 166 | { 167 | bool occurs = false; 168 | 169 | std::clog << "is_generic: checking for " << var << std::endl; 170 | 171 | for(auto i = m_non_generic.begin(); 172 | i != m_non_generic.end(); 173 | ++i) 174 | { 175 | std::clog << "is_generic: checking in " << *i << std::endl; 176 | occurs = unification::detail::occurs(definitive(m_substitution, *i), var); 177 | 178 | std::clog << "is_generic: occurs: " << occurs << std::endl; 179 | 180 | if(occurs) break; 181 | } // end for i 182 | 183 | return !occurs; 184 | } // end is_generic() 185 | 186 | environment &m_env; 187 | const std::set &m_non_generic; 188 | const std::map &m_substitution; 189 | std::map m_mappings; 190 | }; // end fresh_maker 191 | 192 | struct inferencer 193 | : boost::static_visitor 194 | { 195 | inline inferencer(const environment &env) 196 | : m_environment(env) 197 | {} 198 | 199 | inline result_type operator()(const syntax::integer_literal) 200 | { 201 | return integer(); 202 | } // end operator()() 203 | 204 | inline result_type operator()(const syntax::identifier &id) 205 | { 206 | if(!m_environment.count(id.name())) 207 | { 208 | auto what = std::string("Undefined symbol ") + id.name(); 209 | throw std::runtime_error(what); 210 | } // end if 211 | 212 | // create a fresh type 213 | std::clog << "inferencer(identifier): m_non_generic_variables: " << m_non_generic_variables << std::endl; 214 | std::clog << "inferencer(identifier): calling fresh_maker on " << id.name() << std::endl; 215 | auto freshen_me = m_environment[id.name()]; 216 | auto v = fresh_maker(m_environment, m_non_generic_variables, m_substitution); 217 | return v(freshen_me); 218 | } // end operator()() 219 | 220 | inline result_type operator()(const syntax::apply &app) 221 | { 222 | std::clog << "inferencer(apply): m_non_generic_variables: " << std::endl; 223 | std::clog << m_non_generic_variables << std::endl; 224 | 225 | auto fun_type = boost::apply_visitor(*this, app.function()); 226 | auto arg_type = boost::apply_visitor(*this, app.argument()); 227 | 228 | std::clog << "inferencer(apply): calling unique_id" << std::endl; 229 | auto x = type_variable(m_environment.unique_id()); 230 | auto lhs = make_function(arg_type, x); 231 | 232 | unification::unify(lhs, fun_type, m_substitution); 233 | 234 | return definitive(m_substitution,x); 235 | } // end operator()() 236 | 237 | inline result_type operator()(const syntax::lambda &lambda) 238 | { 239 | std::clog << "inferencer(lambda): calling unique_id" << std::endl; 240 | auto arg_type = type_variable(m_environment.unique_id()); 241 | 242 | // introduce a scope with a non-generic variable 243 | auto s = scoped_non_generic_variable(this, lambda.parameter(), arg_type); 244 | 245 | // get the type of the body of the lambda 246 | std::clog << "inferencer(lambda): m_non_generic_variables: " << m_non_generic_variables << std::endl; 247 | auto body_type = boost::apply_visitor(*this, lambda.body()); 248 | 249 | // x = (arg_type -> body_type) 250 | std::clog << "inferencer(lambda): calling unique_id" << std::endl; 251 | auto x = type_variable(m_environment.unique_id()); 252 | unification::unify(x, make_function(arg_type, body_type), m_substitution); 253 | 254 | return definitive(m_substitution,x); 255 | } // end operator()() 256 | 257 | inline result_type operator()(const syntax::let &let) 258 | { 259 | auto defn_type = boost::apply_visitor(*this, let.definition()); 260 | 261 | // introduce a scope with a generic variable 262 | auto s = scoped_generic(this, let.name(), defn_type); 263 | 264 | auto result = boost::apply_visitor(*this, let.body()); 265 | 266 | return result; 267 | } // end operator()() 268 | 269 | inline result_type operator()(const syntax::letrec &letrec) 270 | { 271 | std::clog << "inferencer(letrec): calling unique_id" << std::endl; 272 | auto new_type = type_variable(m_environment.unique_id()); 273 | 274 | // introduce a scope with a non generic variable 275 | auto s = scoped_non_generic_variable(this, letrec.name(), new_type); 276 | 277 | auto definition_type = boost::apply_visitor(*this, letrec.definition()); 278 | 279 | // new_type = definition_type 280 | unification::unify(new_type, definition_type, m_substitution); 281 | 282 | auto result = boost::apply_visitor(*this, letrec.body()); 283 | 284 | return result; 285 | } 286 | 287 | struct scoped_generic 288 | { 289 | inline scoped_generic(inferencer *inf, 290 | const std::string &name, 291 | const type &t) 292 | : m_environment(inf->m_environment) 293 | { 294 | auto iter = m_environment.find(name); 295 | 296 | if(iter != m_environment.end()) 297 | { 298 | // the key already exists 299 | m_restore = std::make_tuple(true, iter, iter->second); 300 | iter->second = t; 301 | } // end if 302 | else 303 | { 304 | // the key does not exist 305 | auto kv = std::make_pair(name,t); 306 | m_restore = std::make_tuple(false, m_environment.insert(kv).first, type()); 307 | } // end else 308 | } // end scoped_generic() 309 | 310 | inline ~scoped_generic() 311 | { 312 | using namespace std; 313 | 314 | if(get<0>(m_restore)) 315 | { 316 | auto iter = get<1>(m_restore); 317 | auto val = get<2>(m_restore); 318 | 319 | iter->second = val; 320 | } // end if 321 | else 322 | { 323 | auto iter = get<1>(m_restore); 324 | m_environment.erase(iter); 325 | } // end else 326 | } // end ~scoped_generic() 327 | 328 | environment &m_environment; 329 | std::tuple m_restore; 330 | }; 331 | 332 | struct scoped_non_generic_variable 333 | : scoped_generic 334 | { 335 | inline scoped_non_generic_variable(inferencer *inf, 336 | const std::string &name, 337 | const type_variable &var) 338 | : scoped_generic(inf, name, var), 339 | m_non_generic(inf->m_non_generic_variables), 340 | m_erase_me(m_non_generic.insert(var)) 341 | {} 342 | 343 | inline ~scoped_non_generic_variable() 344 | { 345 | if(m_erase_me.second) 346 | { 347 | m_non_generic.erase(m_erase_me.first); 348 | } // end if 349 | } // end ~scoped_non_generic_variable() 350 | 351 | std::set &m_non_generic; 352 | std::pair::iterator, bool> m_erase_me; 353 | }; 354 | 355 | environment m_environment; 356 | std::set m_non_generic_variables; 357 | std::map m_substitution; 358 | }; 359 | 360 | type infer_type(const syntax::node &node, 361 | const environment &env) 362 | { 363 | auto v = inferencer(env); 364 | auto old = std::clog.rdbuf(0); 365 | auto result = boost::apply_visitor(v, node); 366 | std::clog.rdbuf(old); 367 | return result; 368 | } 369 | 370 | } // end inference 371 | 372 | -------------------------------------------------------------------------------- /syntax.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace syntax 8 | { 9 | 10 | class integer_literal 11 | { 12 | public: 13 | inline integer_literal(const int v) 14 | : m_value(v) 15 | {} 16 | 17 | int value(void) const 18 | { 19 | return m_value; 20 | } 21 | 22 | private: 23 | int m_value; 24 | }; 25 | 26 | inline std::ostream &operator<<(std::ostream &os, const integer_literal &il) 27 | { 28 | return os << il.value(); 29 | } // end operator<<() 30 | 31 | 32 | class identifier 33 | { 34 | public: 35 | inline identifier(const std::string &name) 36 | : m_name(name) 37 | {} 38 | 39 | inline const std::string &name(void) const 40 | { 41 | return m_name; 42 | } 43 | 44 | private: 45 | std::string m_name; 46 | }; 47 | 48 | inline std::ostream &operator<<(std::ostream &os, const identifier &i) 49 | { 50 | return os << i.name(); 51 | } 52 | 53 | class apply; 54 | class lambda; 55 | class let; 56 | class letrec; 57 | 58 | typedef boost::variant< 59 | integer_literal, 60 | identifier, 61 | boost::recursive_wrapper, 62 | boost::recursive_wrapper, 63 | boost::recursive_wrapper, 64 | boost::recursive_wrapper 65 | > node; 66 | 67 | class apply 68 | { 69 | public: 70 | inline apply(node &&fn, 71 | node &&arg) 72 | : m_fn(std::move(fn)), 73 | m_arg(std::move(arg)) 74 | {} 75 | 76 | inline apply(const node &fn, 77 | const node &arg) 78 | : m_fn(fn), 79 | m_arg(arg) 80 | {} 81 | 82 | inline const node &function(void) const 83 | { 84 | return m_fn; 85 | } 86 | 87 | inline const node &argument(void) const 88 | { 89 | return m_arg; 90 | } 91 | 92 | private: 93 | node m_fn, m_arg; 94 | }; 95 | 96 | inline std::ostream &operator<<(std::ostream &os, const apply &a) 97 | { 98 | return os << "(" << a.function() << " " << a.argument() << ")"; 99 | } 100 | 101 | class lambda 102 | { 103 | public: 104 | inline lambda(const std::string ¶m, 105 | node &&body) 106 | : m_param(param), 107 | m_body(std::move(body)) 108 | {} 109 | 110 | inline lambda(const std::string ¶m, 111 | const node &body) 112 | : m_param(param), 113 | m_body(body) 114 | {} 115 | 116 | inline const std::string ¶meter(void) const 117 | { 118 | return m_param; 119 | } 120 | 121 | inline const node &body(void) const 122 | { 123 | return m_body; 124 | } 125 | 126 | private: 127 | std::string m_param; 128 | node m_body; 129 | }; 130 | 131 | inline std::ostream &operator<<(std::ostream &os, const lambda &l) 132 | { 133 | return os << "(fn " << l.parameter() << " => " << l.body() << ")"; 134 | } 135 | 136 | class let 137 | { 138 | public: 139 | inline let(const std::string &name, 140 | node &&def, 141 | node &&body) 142 | : m_name(name), 143 | m_definition(std::move(def)), 144 | m_body(std::move(body)) 145 | {} 146 | 147 | inline let(const std::string &name, 148 | const node &def, 149 | const node &body) 150 | : m_name(name), 151 | m_definition(def), 152 | m_body(body) 153 | {} 154 | 155 | inline const std::string &name() const 156 | { 157 | return m_name; 158 | } 159 | 160 | inline const node &definition() const 161 | { 162 | return m_definition; 163 | } 164 | 165 | inline const node &body() const 166 | { 167 | return m_body; 168 | } 169 | 170 | private: 171 | std::string m_name; 172 | node m_definition, m_body; 173 | }; 174 | 175 | inline std::ostream &operator<<(std::ostream &os, const let &l) 176 | { 177 | return os << "(let " << l.name() << " = " << l.definition() << " in " << l.body() << ")"; 178 | } 179 | 180 | class letrec 181 | { 182 | public: 183 | inline letrec(const std::string &name, 184 | node &&def, 185 | node &&body) 186 | : m_name(name), 187 | m_definition(std::move(def)), 188 | m_body(std::move(body)) 189 | {} 190 | 191 | inline letrec(const std::string &name, 192 | const node &def, 193 | const node &body) 194 | : m_name(name), 195 | m_definition(def), 196 | m_body(body) 197 | {} 198 | 199 | inline const std::string &name() const 200 | { 201 | return m_name; 202 | } 203 | 204 | inline const node &definition() const 205 | { 206 | return m_definition; 207 | } 208 | 209 | inline const node &body() const 210 | { 211 | return m_body; 212 | } 213 | 214 | private: 215 | std::string m_name; 216 | node m_definition, m_body; 217 | }; 218 | 219 | inline std::ostream &operator<<(std::ostream &os, const letrec &l) 220 | { 221 | return os << "(letrec " << l.name() << " = " << l.definition() << " in " << l.body() << ")"; 222 | } 223 | 224 | } // end syntax 225 | 226 | -------------------------------------------------------------------------------- /unification.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | namespace unification 15 | { 16 | 17 | class type_variable; 18 | class type_operator; 19 | 20 | typedef boost::variant< 21 | type_variable, 22 | boost::recursive_wrapper 23 | > type; 24 | 25 | class type_variable 26 | { 27 | public: 28 | inline type_variable() 29 | : m_id() 30 | {} 31 | 32 | inline type_variable(const std::size_t i) 33 | : m_id(i) 34 | {} 35 | 36 | inline std::size_t id() const 37 | { 38 | return m_id; 39 | } // end id() 40 | 41 | inline bool operator==(const type_variable &other) const 42 | { 43 | return id() == other.id(); 44 | } // end operator==() 45 | 46 | inline bool operator!=(const type_variable &other) const 47 | { 48 | return !(*this == other); 49 | } // end operator!=() 50 | 51 | inline bool operator<(const type_variable &other) const 52 | { 53 | return id() < other.id(); 54 | } // end operator<() 55 | 56 | inline operator std::size_t (void) const 57 | { 58 | return id(); 59 | } // end operator size_t 60 | 61 | private: 62 | std::size_t m_id; 63 | }; // end type_variable 64 | 65 | class type_operator 66 | : private std::vector 67 | { 68 | public: 69 | typedef std::size_t kind_type; 70 | 71 | private: 72 | typedef std::vector super_t; 73 | std::vector m_types; 74 | 75 | kind_type m_kind; 76 | 77 | public: 78 | using super_t::begin; 79 | using super_t::end; 80 | using super_t::size; 81 | using super_t::operator[]; 82 | 83 | inline type_operator(const type_operator &other) 84 | : super_t(other), 85 | m_types(other.m_types), 86 | m_kind(other.m_kind) 87 | {} 88 | 89 | inline type_operator(const kind_type &kind) 90 | : m_kind(kind) 91 | {} 92 | 93 | template 94 | type_operator(const kind_type &kind, 95 | Iterator first, 96 | Iterator last) 97 | : super_t(first, last), 98 | m_kind(kind) 99 | {} 100 | 101 | template 102 | inline type_operator(const kind_type &kind, 103 | const Range &rng) 104 | : super_t(rng.begin(), rng.end()), 105 | m_kind(kind) 106 | {} 107 | 108 | inline type_operator(const kind_type &kind, 109 | std::initializer_list &&types) 110 | : super_t(types), 111 | m_kind(kind) 112 | {} 113 | 114 | inline type_operator(type_operator &&other) 115 | : super_t(std::move(other)), 116 | m_kind(std::move(other.m_kind)) 117 | {} 118 | 119 | inline type_operator &operator=(const type_operator &other) 120 | { 121 | super_t::operator=(other); 122 | m_types = other.m_types; 123 | m_kind = other.m_kind; 124 | return *this; 125 | } 126 | 127 | inline const kind_type &kind(void) const 128 | { 129 | return m_kind; 130 | } // end kind() 131 | 132 | inline bool compare_kind(const type_operator &other) const 133 | { 134 | return kind() == other.kind() && size() == other.size(); 135 | } // end operator==() 136 | 137 | inline bool operator==(const type_operator &other) const 138 | { 139 | return compare_kind(other) & std::equal(begin(), end(), other.begin()); 140 | } // end operator==() 141 | }; // end type_operator 142 | 143 | typedef std::pair constraint; 144 | 145 | struct type_mismatch 146 | : std::runtime_error 147 | { 148 | inline type_mismatch(const type &xx, const type &yy) 149 | : std::runtime_error("type mismatch"), 150 | x(xx), 151 | y(yy) 152 | {} 153 | 154 | inline virtual ~type_mismatch(void) throw() 155 | {} 156 | 157 | type x; 158 | type y; 159 | }; 160 | 161 | struct recursive_unification 162 | : std::runtime_error 163 | { 164 | inline recursive_unification(const type &xx, const type &yy) 165 | : std::runtime_error("recursive unification"), 166 | x(xx), 167 | y(yy) 168 | {} 169 | 170 | inline virtual ~recursive_unification(void) throw() 171 | {} 172 | 173 | type x, y; 174 | }; 175 | 176 | namespace detail 177 | { 178 | 179 | inline void replace(type &x, const type_variable &replace_me, const type &replacement) 180 | { 181 | if(x.which()) 182 | { 183 | auto &op = boost::get(x); 184 | auto f = std::bind(replace, std::placeholders::_1, replace_me, replacement); 185 | std::for_each(op.begin(), op.end(), f); 186 | } // end if 187 | else 188 | { 189 | auto &var = boost::get(x); 190 | if(var == replace_me) 191 | { 192 | x = replacement; 193 | } // end if 194 | } // end else 195 | } // end replace() 196 | 197 | inline bool occurs(const type &haystack, const type_variable &needle) 198 | { 199 | bool result = false; 200 | if(haystack.which()) 201 | { 202 | auto &op = boost::get(haystack); 203 | auto f = std::bind(occurs, std::placeholders::_1, needle); 204 | result = std::any_of(op.begin(), op.end(), f); 205 | } // end end if 206 | else 207 | { 208 | auto &var = boost::get(haystack); 209 | result = (var == needle); 210 | } // end else 211 | 212 | return result; 213 | } // end occurs() 214 | 215 | struct equals_variable 216 | : boost::static_visitor 217 | { 218 | inline equals_variable(const type_variable &xx) 219 | : m_x(xx) 220 | {} 221 | 222 | inline bool operator()(const type_variable &y) 223 | { 224 | return m_x == y; 225 | } // end operator()() 226 | 227 | inline bool operator()(const type_operator &y) 228 | { 229 | return false; 230 | } // end operator()() 231 | 232 | const type_variable &m_x; 233 | }; // end equals_variable 234 | 235 | struct replacer 236 | : boost::static_visitor<> 237 | { 238 | inline replacer(const type_variable &replace_me) 239 | : m_replace_me(replace_me) 240 | {} 241 | 242 | inline void operator()(type_variable &var, const type_variable &replacement) 243 | { 244 | if(var == m_replace_me) 245 | { 246 | var = replacement; 247 | } // end if 248 | } // end operator()() 249 | 250 | template 251 | inline void operator()(type_operator &op, const T &replacement) 252 | { 253 | auto v = boost::apply_visitor(*this); 254 | auto f = std::bind(v, std::placeholders::_2, replacement); 255 | std::for_each(op.begin(), op.end(), f); 256 | } // end operator()() 257 | 258 | const type_variable &m_replace_me; 259 | }; // end replacer 260 | 261 | class unifier 262 | : public boost::static_visitor<> 263 | { 264 | inline void eliminate(const type_variable &x, const type &y) 265 | { 266 | // replace all occurrances of x with y in the stack and the substitution 267 | for(auto i = m_stack.begin(); 268 | i != m_stack.end(); 269 | ++i) 270 | { 271 | replace(i->first, x, y); 272 | replace(i->second, x, y); 273 | } // end for i 274 | 275 | for(auto i = m_substitution.begin(); 276 | i != m_substitution.end(); 277 | ++i) 278 | { 279 | replace(i->second, x, y); 280 | } // end for i 281 | 282 | // add x = y to the substitution 283 | m_substitution[x] = y; 284 | } // end eliminate() 285 | 286 | std::vector m_stack; 287 | std::map &m_substitution; 288 | 289 | public: 290 | // apply_visitor requires that these functions be public 291 | inline void operator()(const type_variable &x, const type_variable &y) 292 | { 293 | if(x != y) 294 | { 295 | eliminate(x,y); 296 | } // end if 297 | } // end operator()() 298 | 299 | inline void operator()(const type_variable &x, const type_operator &y) 300 | { 301 | if(occurs(y,x)) 302 | { 303 | throw recursive_unification(x,y); 304 | } // end if 305 | 306 | eliminate(x,y); 307 | } // end operator()() 308 | 309 | inline void operator()(const type_operator &x, const type_variable &y) 310 | { 311 | if(occurs(x,y)) 312 | { 313 | throw recursive_unification(y,x); 314 | } // end if 315 | 316 | eliminate(y,x); 317 | } // end operator()() 318 | 319 | inline void operator()(const type_operator &x, const type_operator &y) 320 | { 321 | if(!x.compare_kind(y)) 322 | { 323 | throw type_mismatch(x,y); 324 | } // end if 325 | 326 | // push (xi,yi) onto the stack 327 | for(auto xi = x.begin(), yi = y.begin(); 328 | xi != x.end(); 329 | ++xi, ++yi) 330 | { 331 | m_stack.push_back(std::make_pair(*xi, *yi)); 332 | } // end for xi, yi 333 | } // end operator()() 334 | 335 | template 336 | inline unifier(Iterator first_constraint, Iterator last_constraint, std::map &substitution) 337 | : m_stack(first_constraint, last_constraint), 338 | m_substitution(substitution) 339 | { 340 | // add the current substitution to the stack 341 | // XXX this step might be unnecessary 342 | m_stack.insert(m_stack.end(), m_substitution.begin(), m_substitution.end()); 343 | m_substitution.clear(); 344 | } // end unifier() 345 | 346 | inline void operator()(void) 347 | { 348 | while(!m_stack.empty()) 349 | { 350 | type x = std::move(m_stack.back().first); 351 | type y = std::move(m_stack.back().second); 352 | m_stack.pop_back(); 353 | 354 | boost::apply_visitor(*this, x, y); 355 | } // end while 356 | } // end operator()() 357 | }; // unifier() 358 | 359 | } // end detail 360 | 361 | template 362 | void unify(Iterator first_constraint, Iterator last_constraint, std::map &substitution) 363 | { 364 | detail::unifier u(first_constraint, last_constraint, substitution); 365 | u(); 366 | } // end unify() 367 | 368 | template 369 | void unify(const Range &rng, std::map &substitution) 370 | { 371 | return unify(rng.begin(), rng.end(), substitution); 372 | } // end unify() 373 | 374 | // often our system has only a single constraint 375 | void unify(const type &x, const type &y, std::map &substitution) 376 | { 377 | auto c = constraint(x,y); 378 | return unify(&c, &c + 1, substitution); 379 | } // end unify() 380 | 381 | template 382 | std::map 383 | unify(const Range &rng) 384 | { 385 | std::map solutions; 386 | unify(rng, solutions); 387 | return std::move(solutions); 388 | } // end unify() 389 | 390 | } // end unification 391 | 392 | --------------------------------------------------------------------------------