├── stack.yaml ├── Makefile ├── typechecker ├── Utils.hs ├── Utils │ └── Monad.hs ├── Makefile ├── Main.hs ├── Syntax.cf ├── TypeChecker │ ├── Monad │ │ ├── Signature.hs │ │ ├── Context.hs │ │ └── Heap.hs │ ├── Force.hs │ ├── DeBruijn.hs │ ├── Print.hs │ ├── Monad.hs │ └── Reduce.hs ├── Syntax │ └── Internal.hs └── TypeChecker.hs ├── package.yaml ├── .gitignore ├── README ├── Context.agda ├── insane.cabal └── Sigma.agda /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: lts-22.44 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | default : 3 | make -C typechecker 4 | 5 | -------------------------------------------------------------------------------- /typechecker/Utils.hs: -------------------------------------------------------------------------------- 1 | 2 | module Utils 3 | ( module Utils.Monad 4 | , on 5 | ) where 6 | 7 | import Utils.Monad 8 | 9 | on f g x y = f (g x) (g y) 10 | 11 | -------------------------------------------------------------------------------- /package.yaml: -------------------------------------------------------------------------------- 1 | name: insane 2 | 3 | executable: 4 | main: Main.hs 5 | source-dirs: typechecker 6 | 7 | dependencies: 8 | - base 9 | - array 10 | - containers 11 | - mtl 12 | - pretty 13 | 14 | language: GHC2021 15 | -------------------------------------------------------------------------------- /typechecker/Utils/Monad.hs: -------------------------------------------------------------------------------- 1 | 2 | module Utils.Monad where 3 | 4 | import Control.Monad 5 | 6 | type Cont r a = (a -> r) -> r 7 | 8 | thread :: (a -> Cont r b) -> [a] -> Cont r [b] 9 | thread f [] ret = ret [] 10 | thread f (x:xs) ret = 11 | f x $ \x -> 12 | thread f xs $ \xs -> ret (x:xs) 13 | 14 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.o 2 | *.hi 3 | typechecker/tc 4 | .swp 5 | *.swp 6 | _darcs 7 | .darcsignore 8 | typechecker/Syntax/Abs.hs 9 | typechecker/Syntax/ErrM.hs 10 | typechecker/Syntax/Layout.hs 11 | typechecker/Syntax/Lex.hs 12 | typechecker/Syntax/Lex.x 13 | typechecker/Syntax/Par.hs 14 | typechecker/Syntax/Par.info 15 | typechecker/Syntax/Par.y 16 | typechecker/Syntax/Print.hs 17 | /.stack-work/ 18 | /stack*.yaml.lock 19 | *~ -------------------------------------------------------------------------------- /README: -------------------------------------------------------------------------------- 1 | 2 | Toy implementation of Insanely Dependent Types 3 | 4 | Features 5 | 6 | - Insane pi-types: 7 | 8 | [x1 : A1, x2 : A2, ..., xn : An] -> B 9 | 10 | All xi are in scope in the Ai's (and in B of course). Insane 11 | functions must always be fully applied. 12 | 13 | - Everything is mutually recursive 14 | 15 | - Simple Agda-like syntax 16 | 17 | Limitations 18 | 19 | - No implicit arguments 20 | - Function types and Set are not terms 21 | - No indexed datatypes 22 | -------------------------------------------------------------------------------- /typechecker/Makefile: -------------------------------------------------------------------------------- 1 | 2 | bnfc_output = $(patsubst %,Syntax/%,Abs.hs ErrM.hs Layout.hs Print.hs Lex.x Par.y) 3 | gen_hs_files = $(patsubst %,Syntax/%.hs,Par Lex Abs ErrM Layout Print) 4 | 5 | unwanted = $(patsubst %,Syntax/%.hs,Skel Test) 6 | 7 | src_files = $(filter-out $(unwanted),$(shell find . -name '*hs')) $(shell find . 
-name '*hs-boot') 8 | 9 | ghc_flags = -fwarn-incomplete-patterns 10 | # -Werror # Fails because of LANGUAGE OverlappingInstances 11 | 12 | default : tc 13 | 14 | $(bnfc_output) : Syntax.cf 15 | -@rm $(bnfc_output) 16 | bnfc -d $< 17 | -@rm Syntax/Skel* Syntax/Doc* Syntax/Test* 18 | 19 | %.hs : %.y 20 | happy $< -i 21 | 22 | %.hs : %.x 23 | alex $< 24 | 25 | tc : Main.hs $(gen_hs_files) $(src_files) 26 | ghc --make -o $@ $< $(ghc_flags) 27 | 28 | -------------------------------------------------------------------------------- /Context.agda: -------------------------------------------------------------------------------- 1 | 2 | data Cxt [ Ty : Cxt Ty -> Set ] : Set where 3 | nil : Cxt Ty 4 | snoc : (G : Cxt Ty) -> Ty G -> Cxt Ty 5 | 6 | -- Cons-based context extension 7 | data Ext [ Ty : Cxt Ty -> Set ] (G : Cxt Ty) : Set where 8 | nilE : Ext Ty G 9 | consE : (a : Ty G) -> Ext Ty (snoc Ty G a) -> Ext Ty G 10 | 11 | append : [ Ty : Cxt Ty -> Set ] -> (G : Cxt Ty) -> Ext Ty G -> Cxt Ty 12 | append Ty G nilE = G 13 | append Ty G (consE a D) = append Ty (snoc Ty G a) D 14 | 15 | -- Snoc-based context extension 16 | data ExtR [ Ty : Cxt Ty -> Set ] (G : Cxt Ty) : Set where 17 | nilR : ExtR Ty G 18 | snocR : (E : ExtR Ty G) -> Ty (appendR Ty G E) -> ExtR Ty G 19 | 20 | appendR : [ Ty : Cxt Ty -> Set ] -> (G : Cxt Ty) -> ExtR Ty G -> Cxt Ty 21 | appendR Ty G nilR = G 22 | appendR Ty G (snocR E a) = snoc Ty (appendR Ty G E) a 23 | 24 | -------------------------------------------------------------------------------- /typechecker/Main.hs: -------------------------------------------------------------------------------- 1 | 2 | module Main where 3 | 4 | import System.Environment 5 | 6 | import Syntax.Abs 7 | import Syntax.Par 8 | import Syntax.Layout 9 | import Syntax.ErrM 10 | import Syntax.Print 11 | import Syntax.Internal 12 | 13 | import TypeChecker 14 | import TypeChecker.Monad 15 | 16 | checkFile :: FilePath -> IO () 17 | checkFile file = do 18 | s <- readFile file 19 | 
case pProgram $ resolveLayout True $ myLexer s of 20 | Bad s -> putStrLn $ "Parse error: " ++ s 21 | Ok p -> do 22 | r <- runTC $ checkProgram p 23 | case r of 24 | Left err -> print err 25 | Right () -> putStrLn "OK" 26 | 27 | main = do 28 | args <- getArgs 29 | prog <- getProgName 30 | case args of 31 | [file] -> checkFile file 32 | _ -> putStrLn $ "Usage: " ++ prog ++ " FILE" 33 | 34 | -------------------------------------------------------------------------------- /insane.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.38.1. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: insane 8 | version: 0.0.0 9 | build-type: Simple 10 | 11 | executable insane 12 | main-is: Main.hs 13 | other-modules: 14 | Syntax.Abs 15 | Syntax.ErrM 16 | Syntax.Internal 17 | Syntax.Layout 18 | Syntax.Lex 19 | Syntax.Par 20 | Syntax.Print 21 | TypeChecker 22 | TypeChecker.DeBruijn 23 | TypeChecker.Force 24 | TypeChecker.Monad 25 | TypeChecker.Monad.Context 26 | TypeChecker.Monad.Heap 27 | TypeChecker.Monad.Signature 28 | TypeChecker.Print 29 | TypeChecker.Reduce 30 | Utils 31 | Utils.Monad 32 | Paths_insane 33 | hs-source-dirs: 34 | typechecker 35 | build-depends: 36 | array 37 | , base 38 | , containers 39 | , mtl 40 | , pretty 41 | default-language: GHC2021 42 | -------------------------------------------------------------------------------- /Sigma.agda: -------------------------------------------------------------------------------- 1 | 2 | data Bool : Set where 3 | false : Bool 4 | true : Bool 5 | 6 | If : Set -> Set -> Bool -> Set 7 | If x y true = x 8 | If x y false = y 9 | 10 | if : (P : Bool -> Set) -> P true -> P false -> (b : Bool) -> P b 11 | if P x y true = x 12 | if P x y false = y 13 | 14 | data Sigma (A : Set)(B : A -> Set) : Set where 15 | wrap : [f : (b : Bool) -> If A (B (f true)) b] -> Sigma A B 16 | 17 | fst : (A : 
Set) -> (B : A -> Set) -> Sigma A B -> A 18 | fst A B (wrap f) = f true 19 | 20 | snd : (A : Set) -> (B : A -> Set) -> (p : Sigma A B) -> B (fst A B p) 21 | snd A B (wrap f) = f false 22 | 23 | pair : (A : Set) -> (B : A -> Set) -> (x : A) -> (y : B x) -> Sigma A B 24 | pair A B x y = wrap A B (if (If A (B x)) x y) 25 | 26 | data Id (A : Set) (x : A) (y : A) : Set where 27 | eq : ((P : A -> Set) -> P x -> P y) -> Id A x y 28 | 29 | refl : (A : Set) -> (x : A) -> Id A x x 30 | refl A x = eq A x x (\P px -> px) 31 | 32 | lemfst : (A : Set) -> (B : A -> Set) -> (x : A) -> (y : B x) -> Id A (fst A B (pair A B x y)) x 33 | lemfst A B x y = refl A x 34 | 35 | lemsnd : (A : Set) -> (B : A -> Set) -> (x : A) -> (y : B x) -> Id (B x) (snd A B (pair A B x y)) y 36 | lemsnd A B x y = refl (B x) y 37 | 38 | -------------------------------------------------------------------------------- /typechecker/Syntax.cf: -------------------------------------------------------------------------------- 1 | 2 | entrypoints Program; 3 | 4 | layout toplevel; 5 | layout "where"; 6 | 7 | Prog. Program ::= [Decl]; 8 | 9 | separator Decl ";"; 10 | 11 | TypeSig. Decl ::= Ident ":" Expr; 12 | FunDef. Decl ::= Ident [Pattern1] "=" Expr; 13 | DataDecl. Decl ::= "data" Ident [TelBinding] ":" "Set" "where" "{" [Constr] "}"; 14 | 15 | Constr. Constr ::= Ident ":" Expr; 16 | separator Constr ";"; 17 | 18 | TelBind. TelBinding ::= Telescope; 19 | PiBind. TelBinding ::= Binding; 20 | separator TelBinding ""; 21 | 22 | Bind. Binding ::= "(" Ident ":" Expr ")"; 23 | separator Binding ""; 24 | 25 | Tel. Telescope ::= "[" [RBinding] "]"; 26 | RBind. RBinding ::= Ident ":" Expr; 27 | separator nonempty RBinding ","; 28 | 29 | Lam. Expr ::= "\\" [Ident] "->" Expr; 30 | Pi. Expr ::= "(" Ident ":" Expr ")" "->" Expr; 31 | RPi. Expr ::= Telescope "->" Expr; 32 | Fun. Expr ::= Expr1 "->" Expr; 33 | App. Expr1 ::= Expr1 Expr2; 34 | Set. Expr2 ::= "Set"; 35 | Name. 
Expr2 ::= Ident; 36 | coercions Expr 2; 37 | 38 | AppP. Pattern ::= Pattern Pattern1; 39 | VarP. Pattern1 ::= Ident; 40 | coercions Pattern 1; 41 | separator Pattern1 ""; 42 | 43 | separator Ident ""; 44 | 45 | separator Expr ","; 46 | 47 | comment "--"; 48 | comment "{-" "-}"; 49 | 50 | -------------------------------------------------------------------------------- /typechecker/TypeChecker/Monad/Signature.hs: -------------------------------------------------------------------------------- 1 | 2 | module TypeChecker.Monad.Signature where 3 | 4 | import Control.Monad.State 5 | import qualified Data.Map as Map 6 | 7 | import Syntax.Internal 8 | import TypeChecker.Monad 9 | import TypeChecker.Monad.Heap 10 | 11 | setSig :: Signature -> TC () 12 | setSig sig = modify $ \s -> s { stSig = sig } 13 | 14 | getSig :: TC Signature 15 | getSig = gets stSig 16 | 17 | getDefinition :: Name -> TC Definition 18 | getDefinition x = do 19 | sig <- getSig 20 | case Map.lookup x sig of 21 | Just d -> return d 22 | Nothing -> fail $ "not a defined name " ++ x 23 | 24 | withDefinition :: Name -> (Definition-> TC a) -> TC a 25 | withDefinition x f = f =<< getDefinition x 26 | 27 | addConstraint :: TC () -> TC () 28 | addConstraint c = do 29 | p <- suspend c 30 | modify $ \s -> s { stConstraints = p : stConstraints s } 31 | 32 | defType :: Name -> TC Type 33 | defType x = withDefinition x $ \d -> 34 | case d of 35 | Axiom _ t -> return t 36 | Defn _ t _ -> return t 37 | Data _ t _ -> return t 38 | Cons _ t -> return t 39 | 40 | isConstructor :: Name -> TC () 41 | isConstructor x = withDefinition x $ \d -> 42 | case d of 43 | Cons _ _ -> return () 44 | _ -> fail $ x ++ " should be a constructor" 45 | 46 | isData :: Name -> TC () 47 | isData x = withDefinition x $ \d -> 48 | case d of 49 | Data _ _ _ -> return () 50 | _ -> fail $ x ++ " should be a datatype" 51 | 52 | functionArity :: Name -> TC (Maybe Arity) 53 | functionArity x = withDefinition x $ \d -> 54 | case d of 55 | Defn _ _ 
(c:_) -> do 56 | Clause ps _ <- forceClosure c 57 | return $ Just $ length ps 58 | _ -> return Nothing 59 | 60 | -------------------------------------------------------------------------------- /typechecker/TypeChecker/Force.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances, UndecidableInstances, OverlappingInstances #-} 2 | 3 | module TypeChecker.Force where 4 | 5 | import Control.Applicative 6 | import Control.Monad.Trans 7 | import Control.Monad.Except 8 | import Data.Traversable 9 | import qualified Data.Map as Map 10 | 11 | import Syntax.Internal 12 | import TypeChecker.Monad 13 | import TypeChecker.Monad.Heap 14 | import TypeChecker.Monad.Signature 15 | import Utils 16 | 17 | traverse_ f x = () <$ traverse f x 18 | 19 | -- | Force the checking of everything in the signature 20 | forceSig :: TC () 21 | forceSig = do 22 | sig <- getSig 23 | force $ Map.elems sig 24 | cs <- getConstraints 25 | force cs 26 | 27 | class Force a where 28 | force :: a -> TC () 29 | 30 | instance Force Definition where 31 | force (Axiom _ t) = force t 32 | force (Defn _ t cs) = force (t,cs) 33 | force (Data _ t _) = force t 34 | force (Cons _ t) = force t 35 | 36 | instance (Pointer ptr a, Force a) => Force ptr where 37 | force p = do 38 | f <- howForceful 39 | let gloves = case f of 40 | Hard -> id 41 | Soft -> flip catchError $ \_ -> return () 42 | gloves $ force =<< forceClosure p 43 | 44 | instance Force () where 45 | force () = return () 46 | 47 | instance Force Clause' where 48 | force c = 49 | case c of 50 | Clause ps t -> force t 51 | 52 | instance Force Type' where 53 | force a = 54 | case a of 55 | Pi a b -> force (a,b) 56 | RPi tel a -> force (tel, a) 57 | Fun a b -> force (a,b) 58 | Set -> return () 59 | El t -> force t 60 | 61 | instance Force a => Force (RBind a) where 62 | force (RBind _ a) = force a 63 | 64 | instance Force Term' where 65 | force t = 66 | case t of 67 | Var n -> return () 68 | Def c 
-> return () 69 | App s t -> force (s,t) 70 | Lam t -> force t 71 | 72 | instance (Force a, Force b) => Force (a,b) where 73 | force (x,y) = force x >> force y 74 | 75 | instance Force a => Force (Abs a) where 76 | force = traverse_ force 77 | 78 | instance Force a => Force [a] where 79 | force = traverse_ force 80 | 81 | -------------------------------------------------------------------------------- /typechecker/TypeChecker/Monad/Context.hs: -------------------------------------------------------------------------------- 1 | 2 | module TypeChecker.Monad.Context where 3 | 4 | import Control.Applicative 5 | import Control.Monad.Reader 6 | import Data.List 7 | 8 | import Syntax.Internal 9 | import TypeChecker.Monad 10 | import TypeChecker.Monad.Heap 11 | import TypeChecker.DeBruijn 12 | import Utils 13 | import Control.Monad 14 | 15 | getContext :: TC Context 16 | getContext = asks envContext 17 | 18 | withContext :: (Context -> Context) -> TC a -> TC a 19 | withContext f = local $ \e -> e { envContext = f (envContext e) } 20 | 21 | extendContext :: Name -> Type -> TC a -> TC a 22 | extendContext x t = withContext (VBind x t :) 23 | 24 | extendContext_ :: Name -> TC a -> TC a 25 | extendContext_ x m = do 26 | set <- evaluated Set 27 | extendContext x set m 28 | 29 | extendContextTel :: Telescope -> TC a -> TC a 30 | extendContextTel tel = withContext (TBind (reverse tel) :) 31 | 32 | (!) :: Context -> Name -> Maybe (DeBruijnIndex, DeBruijnIndex, Type) 33 | ctx ! 
x = look 0 ctx 34 | where 35 | look n (VBind y t : ctx) 36 | | x == y = return (n, n + 1, t) 37 | | otherwise = look (n + 1) ctx 38 | look n (TBind tel : ctx) = 39 | lookTel n n tel `mplus` look n' ctx 40 | where n' = n + genericLength tel 41 | look _ [] = fail "" 42 | 43 | lookTel n m (RBind y t : tel) 44 | | x == y = return (n, m, t) 45 | | otherwise = lookTel (n + 1) m tel 46 | lookTel _ _ [] = fail "" 47 | 48 | lookupContext :: Name -> TC (DeBruijnIndex, Type) 49 | lookupContext x = do 50 | ctx <- getContext 51 | case ctx ! x of 52 | Just (n, m, t) -> (,) n <$> raiseBy m t 53 | Nothing -> fail $ "Unbound variable: " ++ x 54 | 55 | flattenContext :: Context -> [(Name, Type)] 56 | flattenContext = concatMap f 57 | where 58 | f (VBind x t) = [(x, t)] 59 | f (TBind tel) = [ (x, t) | RBind x t <- tel ] 60 | 61 | getVarName :: DeBruijnIndex -> TC String 62 | getVarName n = do 63 | ctx <- getContext 64 | fst <$> (ctx ! n) 65 | where 66 | cxt ! n 67 | | len <= n = fail $ "deBruijn index out of range " ++ show n ++ " in " ++ show xs 68 | | otherwise = return $ xs !! 
fromIntegral n 69 | where 70 | len = fromIntegral $ length xs 71 | xs = flattenContext cxt 72 | 73 | -------------------------------------------------------------------------------- /typechecker/TypeChecker/DeBruijn.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances, UndecidableInstances, OverlappingInstances #-} 2 | 3 | module TypeChecker.DeBruijn where 4 | 5 | import Control.Applicative 6 | import Data.Traversable 7 | 8 | import Syntax.Internal 9 | import TypeChecker.Monad 10 | import TypeChecker.Monad.Heap 11 | import Utils 12 | 13 | class DeBruijn a where 14 | transform :: (Integer -> DeBruijnIndex -> TC Term') -> Integer -> a -> TC a 15 | 16 | instance (Pointer ptr a, DeBruijn a) => DeBruijn ptr where 17 | transform f n = liftPtrM (transform f n) 18 | 19 | instance DeBruijn Type' where 20 | transform f n t = case t of 21 | Pi a b -> uncurry Pi <$> trf (a,b) 22 | RPi tel a -> uncurry RPi <$> transform f n' (tel, a) 23 | where n' = n + fromIntegral (length tel) 24 | Fun a b -> uncurry Fun <$> trf (a,b) 25 | El t -> El <$> trf t 26 | Set -> return Set 27 | where 28 | trf x = transform f n x 29 | 30 | instance DeBruijn a => DeBruijn (RBind a) where 31 | transform f n (RBind x a) = 32 | RBind x <$> transform f n a 33 | 34 | instance DeBruijn Term' where 35 | transform f n t = case t of 36 | Def f -> return $ Def f 37 | Var m -> f n m 38 | App s t -> uncurry App <$> trf (s,t) 39 | Lam t -> Lam <$> trf t 40 | where 41 | trf x = transform f n x 42 | 43 | instance (DeBruijn a, DeBruijn b) => DeBruijn (a,b) where 44 | transform f n (x,y) = (,) <$> transform f n x <*> transform f n y 45 | 46 | instance DeBruijn a => DeBruijn (Abs a) where 47 | transform f n (Abs x b) = Abs x <$> transform f (n + 1) b 48 | 49 | instance DeBruijn a => DeBruijn [a] where 50 | transform f n = traverse (transform f n) 51 | 52 | raiseByFrom :: DeBruijn a => Integer -> Integer -> a -> TC a 53 | raiseByFrom k = transform f 54 | 
where 55 | f n m | m < n = return $ Var m 56 | | otherwise = return $ Var (m + k) 57 | 58 | raiseBy :: DeBruijn a => Integer -> a -> TC a 59 | raiseBy k = raiseByFrom k 0 60 | 61 | raise :: DeBruijn a => a -> TC a 62 | raise = raiseBy 1 63 | 64 | substUnder :: DeBruijn a => Integer -> Term -> a -> TC a 65 | substUnder n0 t = transform f n0 66 | where 67 | f n m | m < n = return $ Var m 68 | | m == n = forceClosure =<< raiseByFrom (n - n0) n0 t 69 | | otherwise = return $ Var (m - 1) 70 | 71 | subst :: DeBruijn a => Term -> Abs a -> TC a 72 | subst t = substUnder 0 t . absBody 73 | 74 | substs :: DeBruijn a => [Term] -> a -> TC a 75 | substs [] a = return a 76 | substs (t:ts) a = substUnder 0 t =<< flip substs a =<< raise ts 77 | 78 | -------------------------------------------------------------------------------- /typechecker/Syntax/Internal.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies, 2 | DeriveDataTypeable, DeriveFunctor, DeriveFoldable, 3 | DeriveTraversable #-} 4 | module Syntax.Internal where 5 | 6 | import Control.Applicative 7 | import Data.Traversable 8 | import Data.Typeable 9 | import Data.Foldable 10 | 11 | import qualified Syntax.Abs as Abs 12 | import Utils 13 | 14 | -- Pointers --------------------------------------------------------------- 15 | 16 | newtype Ptr = Ptr Integer 17 | deriving (Eq, Ord) 18 | 19 | instance Show Ptr where 20 | show (Ptr n) = "*" ++ show n 21 | 22 | newtype Type = TypePtr Ptr deriving (Eq, Typeable) 23 | newtype Term = TermPtr Ptr deriving (Eq, Typeable) 24 | newtype Clause = ClausePtr Ptr deriving (Eq, Typeable) 25 | newtype Pair a b = PairPtr Ptr deriving (Eq, Typeable) 26 | newtype Unit = UnitPtr Ptr deriving (Eq, Typeable) 27 | 28 | class (Show ptr, Eq ptr, Show a, Typeable a) => Pointer ptr a | ptr -> a, a -> ptr where 29 | toRawPtr :: ptr -> Ptr 30 | fromRawPtr :: Ptr -> ptr 31 | 32 | instance Pointer Unit () where 
toRawPtr (UnitPtr p) = p; fromRawPtr = UnitPtr 33 | instance Pointer Type Type' where toRawPtr (TypePtr p) = p; fromRawPtr = TypePtr 34 | instance Pointer Term Term' where toRawPtr (TermPtr p) = p; fromRawPtr = TermPtr 35 | instance Pointer Clause Clause' where toRawPtr (ClausePtr p) = p; fromRawPtr = ClausePtr 36 | instance (Show a, Show b, Typeable a, Typeable b) => 37 | Pointer (Pair a b) (a,b) where 38 | toRawPtr (PairPtr p) = p 39 | fromRawPtr = PairPtr 40 | 41 | instance Show Type where show = show . toRawPtr 42 | instance Show Term where show = show . toRawPtr 43 | instance Show Clause where show = show . toRawPtr 44 | instance Show (Pair a b) where show (PairPtr p) = show p 45 | instance Show Unit where show = show . toRawPtr 46 | 47 | -- Syntax ----------------------------------------------------------------- 48 | 49 | type Arity = Int 50 | 51 | data Definition 52 | = Axiom Name Type 53 | | Defn Name Type [Clause] 54 | | Data Name Type [Constructor] 55 | | Cons Name Type 56 | deriving (Show, Typeable) 57 | 58 | data Clause' = Clause [Pattern] Term 59 | deriving (Show, Typeable) 60 | 61 | data Constructor = Constr Name Arity 62 | deriving (Show, Typeable) 63 | 64 | data Pattern = VarP Name 65 | | ConP Name [Pattern] 66 | deriving (Show, Typeable) 67 | 68 | type Name = String 69 | type DeBruijnIndex = Integer 70 | 71 | type Telescope = [RBind Type] 72 | data RBind a = RBind String a 73 | deriving (Show) 74 | 75 | data Type' = Pi Type (Abs Type) 76 | | RPi Telescope Type 77 | | Fun Type Type 78 | | El Term 79 | | Set 80 | deriving (Show, Typeable) 81 | 82 | data Term' = Lam (Abs Term) 83 | | App Term Term 84 | | Var DeBruijnIndex 85 | | Def Name 86 | deriving (Show, Typeable) 87 | 88 | data Abs a = Abs { absName :: Name, absBody :: a } 89 | deriving (Typeable, Functor, Foldable, Traversable) 90 | 91 | instance Show a => Show (Abs a) where 92 | show (Abs _ b) = show b 93 | 94 | -------------------------------------------------------------------------------- 
/typechecker/TypeChecker/Print.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances, UndecidableInstances, OverlappingInstances #-} 2 | 3 | module TypeChecker.Print where 4 | 5 | import Prelude hiding ((<>)) 6 | import Control.Monad.Except 7 | import Control.Applicative 8 | import qualified Text.PrettyPrint as PP 9 | 10 | import Syntax.Internal 11 | import qualified Syntax.Abs as Abs 12 | import TypeChecker.Monad 13 | import TypeChecker.Monad.Heap 14 | import TypeChecker.Monad.Context 15 | import Utils 16 | 17 | type Doc = PP.Doc 18 | 19 | text s = return $ PP.text s 20 | vcat ds = PP.vcat <$> sequence ds 21 | fsep ds = PP.fsep <$> sequence ds 22 | hsep ds = PP.hsep <$> sequence ds 23 | sep ds = PP.sep <$> sequence ds 24 | nest n d = PP.nest n <$> d 25 | d <+> d' = (PP.<+>) <$> d <*> d' 26 | d <> d' = (PP.<>) <$> d <*> d' 27 | parens d = PP.parens <$> d 28 | brackets d = PP.brackets <$> d 29 | comma = return PP.comma 30 | 31 | punctuate :: TC Doc -> [TC Doc] -> TC Doc 32 | punctuate d xs = (.) PP.fsep . 
PP.punctuate <$> d <*> sequence xs 33 | 34 | mparens True = parens 35 | mparens False = id 36 | 37 | class Pretty a where 38 | pretty :: a -> TC Doc 39 | prettyPrec :: Int -> a -> TC Doc 40 | 41 | pretty = prettyPrec 0 42 | prettyPrec _ = pretty 43 | 44 | instance (Pointer ptr a, Pretty a) => Pretty ptr where 45 | prettyPrec n p = do 46 | cl <- getClosure p 47 | case cl of 48 | Unevaluated _ -> text "_" 49 | Evaluated x -> prettyPrec n x 50 | 51 | instance Pretty Definition where 52 | pretty (Axiom x t) = 53 | hsep [ text x, text ":", pretty t ] 54 | pretty (Defn x t cs) = 55 | vcat $ pretty (Axiom x t) : [ text x <+> d | d <- map pretty cs ] 56 | pretty (Cons c t) = pretty (Axiom c t) 57 | pretty (Data d t cs) = 58 | vcat [ hsep [ text "data", text d, text ":", pretty t, text "where" ] 59 | , nest 2 $ vcat $ map pretty cs 60 | ] 61 | 62 | instance Pretty Constructor where 63 | pretty (Constr c ar) = text c <> text "/" <> text (show ar) 64 | 65 | instance Pretty Clause' where 66 | pretty (Clause ps t) = 67 | thread (prettyPat 1) ps $ \ds -> 68 | sep [ fsep (map return ds) <+> text "=" 69 | , nest 2 $ pretty t 70 | ] 71 | 72 | instance Pretty Type' where 73 | prettyPrec n t = case t of 74 | Pi a b -> 75 | mparens (n > 0) $ 76 | sep [ parens (text x <+> text ":" <+> pretty a) <+> text "->" 77 | , pretty b 78 | ] 79 | where x = absName b 80 | RPi tel a -> 81 | extendContextTel tel $ 82 | mparens (n > 0) $ 83 | sep [ brackets (punctuate comma $ map pretty tel) <+> text "->" 84 | , pretty a ] 85 | Fun a b -> 86 | mparens (n > 0) $ 87 | sep [ prettyPrec 1 a <+> text "->" 88 | , pretty b 89 | ] 90 | Set -> text "Set" 91 | El t -> prettyPrec n t 92 | 93 | instance Pretty a => Pretty (RBind a) where 94 | pretty (RBind x a) = 95 | hsep [ text x, text ":", pretty a ] 96 | 97 | instance Pretty Term' where 98 | prettyPrec n t = case t of 99 | Var m -> do 100 | x <- getVarName m 101 | text x 102 | `catchError` \_ -> 103 | text ("!" 
++ show m ++ "!") 104 | Def x -> text x 105 | App s t -> 106 | mparens (n > 5) $ 107 | sep [ prettyPrec 5 s, prettyPrec 6 t ] 108 | Lam t -> 109 | mparens (n > 0) $ 110 | sep [ text "\\" <> text (absName t) <+> text "->" 111 | , nest 2 $ pretty t 112 | ] 113 | 114 | prettyPat n p ret = case p of 115 | VarP x -> extendContext_ x $ ret $ PP.text x 116 | ConP c ps -> 117 | thread (prettyPat 1) ps $ \ds -> 118 | ret $ mparens' (n > 0 && not (null ps)) 119 | $ PP.fsep $ PP.text c : ds 120 | where 121 | mparens' True = PP.parens 122 | mparens' False = id 123 | 124 | instance Pretty a => Pretty (Abs a) where 125 | prettyPrec n (Abs x b) = extendContext_ x $ prettyPrec n b 126 | 127 | -------------------------------------------------------------------------------- /typechecker/TypeChecker/Monad.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ExistentialQuantification, DeriveFunctor, 2 | GeneralizedNewtypeDeriving #-} 3 | 4 | module TypeChecker.Monad where 5 | 6 | import Control.Monad 7 | import Control.Applicative 8 | import Control.Monad.State 9 | import Control.Monad.Reader 10 | import Control.Monad.Except 11 | import Data.Map as Map 12 | import Data.Dynamic 13 | 14 | import Syntax.Internal 15 | 16 | --------------------------------------------------------------------------- 17 | -- * Environment 18 | --------------------------------------------------------------------------- 19 | 20 | data TCEnv = TCEnv 21 | { envContext :: Context 22 | , envForce :: HowMuchForce 23 | } 24 | 25 | data HowMuchForce = Soft | Hard 26 | 27 | emptyEnv :: TCEnv 28 | emptyEnv = TCEnv 29 | { envContext = [] 30 | , envForce = Hard 31 | } 32 | 33 | gently :: TC a -> TC a 34 | gently = local $ \env -> env { envForce = Soft } 35 | 36 | howForceful :: TC HowMuchForce 37 | howForceful = asks envForce 38 | 39 | --------------------------------------------------------------------------- 40 | -- ** Context 41 | 
--------------------------------------------------------------------------- 42 | 43 | type Context = [ContextEntry] 44 | 45 | data ContextEntry 46 | = VBind Name Type 47 | | TBind Telescope 48 | 49 | --------------------------------------------------------------------------- 50 | -- * State 51 | --------------------------------------------------------------------------- 52 | 53 | data TCState = TCState 54 | { stHeap :: Heap 55 | , stSig :: Signature 56 | , stConstraints :: [Constraint] 57 | , stNextFree :: Ptr 58 | } 59 | 60 | initState :: TCState 61 | initState = TCState 62 | { stHeap = Map.empty 63 | , stSig = Map.empty 64 | , stConstraints = [] 65 | , stNextFree = Ptr 0 66 | } 67 | 68 | --------------------------------------------------------------------------- 69 | -- ** Signature 70 | --------------------------------------------------------------------------- 71 | 72 | type Signature = Map Name Definition 73 | 74 | type Constraint = Unit 75 | 76 | getConstraints :: TC [Constraint] 77 | getConstraints = gets stConstraints 78 | 79 | --------------------------------------------------------------------------- 80 | -- ** Heap 81 | --------------------------------------------------------------------------- 82 | 83 | type Heap = Map Ptr HeapObject 84 | 85 | data HeapObject = forall a. 
(Show a, Typeable a) => HpObj (Closure a) 86 | 87 | instance Show HeapObject where 88 | show (HpObj x) = show x 89 | 90 | data Closure a = Evaluated a 91 | | Unevaluated (TCClosure a) 92 | 93 | instance Show a => Show (Closure a) where 94 | show (Evaluated x) = show x 95 | show (Unevaluated _) = "_" 96 | 97 | data TCClosure a = TCCl 98 | { clEnv :: TCEnv 99 | , clActive :: Active 100 | , clAction :: TC a 101 | } 102 | 103 | data Active = Active | Inactive 104 | 105 | setActive :: Active -> Closure a -> Closure a 106 | setActive _ cl@Evaluated{} = cl 107 | setActive a (Unevaluated cl) = Unevaluated cl{ clActive = a } 108 | 109 | getActive :: Closure a -> Active 110 | getActive Evaluated{} = Inactive 111 | getActive (Unevaluated cl) = clActive cl 112 | 113 | runTCClosure :: TCClosure a -> TC a 114 | runTCClosure cl = local (const $ clEnv cl) $ clAction cl 115 | 116 | --------------------------------------------------------------------------- 117 | -- * Error 118 | --------------------------------------------------------------------------- 119 | 120 | data TCError = Fail String 121 | 122 | instance Show TCError where 123 | show (Fail s) = "fail: " ++ s 124 | 125 | --------------------------------------------------------------------------- 126 | -- * The Monad 127 | --------------------------------------------------------------------------- 128 | 129 | newtype TC a = TC { unTC :: ReaderT TCEnv (ExceptT TCError (StateT TCState IO)) a } 130 | deriving (Functor, MonadReader TCEnv, MonadState TCState, MonadError TCError, MonadIO, Monad) 131 | 132 | instance Applicative TC where 133 | pure = return 134 | (<*>) = ap 135 | 136 | instance MonadFail TC where 137 | fail = throwError . 
Fail 138 | 139 | runTC :: TC a -> IO (Either TCError a) 140 | runTC (TC m) = 141 | flip evalStateT initState $ 142 | runExceptT $ 143 | flip runReaderT emptyEnv $ 144 | m 145 | 146 | -------------------------------------------------------------------------------- /typechecker/TypeChecker/Monad/Heap.hs: -------------------------------------------------------------------------------- 1 | 2 | module TypeChecker.Monad.Heap where 3 | 4 | import Control.Applicative 5 | import Control.Monad 6 | import Control.Monad.Reader 7 | import Control.Monad.State 8 | import Control.Monad.Except 9 | import qualified Data.Map as Map 10 | import Data.Typeable 11 | 12 | import TypeChecker.Monad 13 | import Syntax.Internal 14 | import Utils.Monad 15 | 16 | --------------------------------------------------------------------------- 17 | -- * Helpers 18 | --------------------------------------------------------------------------- 19 | 20 | typeOfCl :: Typeable a => Closure a -> TypeRep 21 | typeOfCl = typeOf . unCl 22 | where 23 | unCl :: Closure a -> a 24 | unCl = undefined 25 | 26 | --------------------------------------------------------------------------- 27 | -- * Heap manipulation 28 | --------------------------------------------------------------------------- 29 | 30 | heapLookup :: Pointer ptr a => ptr -> Heap -> Closure a 31 | heapLookup = aux undefined 32 | where 33 | aux :: Pointer ptr a => a -> ptr -> Heap -> Closure a 34 | aux x ptr heap = either error id $ do 35 | HpObj cl <- Map.lookup p heap `err` ("bad pointer: " ++ show p) 36 | gcast cl `err` unlines 37 | [ "bad type in closure:" 38 | , "expected " ++ show (typeOf x) 39 | , "found " ++ show (typeOfCl cl) 40 | ] 41 | where 42 | p = toRawPtr ptr 43 | 44 | err Nothing s = Left s 45 | err (Just x) _ = Right x 46 | 47 | heapUpdate :: Pointer ptr a => ptr -> Closure a -> Heap -> Heap 48 | heapUpdate ptr cl = Map.insert p (HpObj cl) 49 | where 50 | p = toRawPtr ptr 51 | 52 | 
---------------------------------------------------------------------------
-- * Monadic functions
---------------------------------------------------------------------------

getHeap :: TC Heap
getHeap = gets stHeap

setHeap :: Heap -> TC ()
setHeap = modHeap . const

modHeap :: (Heap -> Heap) -> TC ()
modHeap f = modify $ \s -> s { stHeap = f $ stHeap s}

getClosure :: Pointer ptr a => ptr -> TC (Closure a)
getClosure p = heapLookup p <$> getHeap

setClosure :: Pointer ptr a => ptr -> Closure a -> TC ()
setClosure p cl = modHeap $ heapUpdate p cl

-- | Replace the closure at @p@ by the result of @f@ and return the new
--   closure.
--
--   The closure is flagged 'Active' while @f@ runs, so a re-entrant access
--   to the same pointer fails with @"<>"@.  The flag is cleared both when
--   @f@ succeeds and when it throws: the state layer sits below the error
--   layer in 'TC', so without the explicit reset a caught error would leave
--   the cell 'Active' forever and every later access would be misreported
--   as a loop.
modClosure :: Pointer ptr a => (Closure a -> TC (Closure a)) -> ptr -> TC (Closure a)
modClosure f p = do
  cl <- getClosure p
  case getActive cl of
    Active -> fail "<>"
    _      -> do
      setClosure p $ setActive Active cl
      cl' <- f cl `catchError` \e -> do
        -- Restore the original, inactive closure before propagating.
        setClosure p $ setActive Inactive cl
        throwError e
      setClosure p $ setActive Inactive cl'
      return cl'

-- | Evaluate the closure at @p@ (if it is not evaluated already), store the
--   value back in the heap and return it.
forceClosure :: Pointer ptr a => ptr -> TC a
forceClosure p = do
  Evaluated x <- modClosure eval p
  return x
  where
    eval cl@(Evaluated _) = return cl
    eval (Unevaluated m)  = Evaluated <$> runTCClosure m

-- | Hand out the next free pointer and advance the counter.
freshPtr :: Pointer ptr a => TC ptr
freshPtr = do
  Ptr n <- gets stNextFree
  modify $ \s -> s { stNextFree = Ptr (n + 1) }
  return $ fromRawPtr $ Ptr n

-- | Store a closure in a fresh heap cell.
alloc :: Pointer ptr a => Closure a -> TC ptr
alloc cl = do
  p <- freshPtr
  setClosure p cl
  return p

-- | Package a computation with the current environment as an inactive,
--   unevaluated closure.
buildClosure :: TC a -> TC (Closure a)
buildClosure m = do
  env <- ask
  return $ Unevaluated $ TCCl env Inactive m

-- | Allocate a cell holding the suspended computation.
suspend :: Pointer ptr a => TC a -> TC ptr
suspend m = alloc =<< buildClosure m

-- | Allocate a cell holding an already computed value.
evaluated :: Pointer ptr a => a -> TC ptr
evaluated = alloc . Evaluated

-- | Overwrite the cell at @p@ with a suspension of @m@ (captured in the
--   current environment).
poke :: Pointer ptr a => ptr -> TC a -> TC ()
poke p m = do
  cl <- buildClosure m
  setClosure p cl

-- | Tie a knot: allocate a placeholder cell (which fails if forced), pass
--   its pointer to @f@, and make the cell evaluate to whatever the pointer
--   returned by @f@ evaluates to.
recursive :: Pointer ptr a => (ptr -> TC ptr) -> TC ptr
recursive f = do
  p <- suspend (fail "")
  poke p (forceClosure =<< f p)
  return p

-- | Mutually recursive version of 'recursive': one cell per element of
--   @xs@, each computed by @f@ with access to all the pointers.
recursives :: Pointer ptr a => [b] -> ([ptr] -> b -> TC a) -> TC [ptr]
recursives xs f = do
  ps <- replicateM (length xs) $ suspend (fail "")
  sequence_ [ poke p (f ps x) | (p, x) <- zip ps xs ]
  return ps

updatePtr :: Pointer ptr a => (a -> a) -> ptr -> TC ()
updatePtr f = updatePtrM (return . f)

-- | Post-compose @f@ onto the contents of the cell at @p@ without forcing
--   it: an evaluated cell becomes a suspension of @f x@, an unevaluated one
--   has @f@ appended to its action.
updatePtrM :: Pointer ptr a => (a -> TC a) -> ptr -> TC ()
updatePtrM f p = void $ modClosure (apply f) p
  where
    apply f (Evaluated x)   = buildClosure (f x)
    apply f (Unevaluated m) =
      return $ Unevaluated m{ clAction = action }
      where
        action = f =<< clAction m

---------------------------------------------------------------------------
-- * Lifting functions to pointers
---------------------------------------------------------------------------

liftPtr :: (Pointer ptr a, Pointer ptr' b) =>
           (a -> b) -> ptr -> TC ptr'
liftPtr f = liftPtrM (return . f)

-- | A new suspended cell whose value is @f@ applied to the forced contents
--   of @p@.
liftPtrM :: (Pointer ptr a, Pointer ptr' b) =>
            (a -> TC b) -> ptr -> TC ptr'
liftPtrM f p = suspend $ f =<< forceClosure p

--------------------------------------------------------------------------------
/typechecker/TypeChecker/Reduce.hs:
--------------------------------------------------------------------------------

module TypeChecker.Reduce where

import Control.Applicative
import Control.Monad.Trans
import Control.Monad
import Data.Monoid
import Data.Foldable hiding (foldr1)
import Data.Traversable hiding (mapM)

import Syntax.Internal
import TypeChecker.Monad
import TypeChecker.Monad.Heap
import TypeChecker.Monad.Signature
import TypeChecker.Monad.Context
import TypeChecker.DeBruijn
import TypeChecker.Print
import TypeChecker.Force
import Utils

data RedexView
  = Iota Name [Term]      -- ^ The number of arguments should be the same as the arity.
  | Beta (Abs Term) Term
  | NonRedex Term'
  deriving (Show)

-- | @spine (f a b c) = [(f, f), (a, f a), (b, f a b), (c, f a b c)]@
spine :: Term -> TC [(Term, Term)]
spine p = do
  t <- forceClosure p
  case t of
    App s t -> do
      sp <- spine s
      return $ sp ++ [(t, p)]
    _ -> return [(p,p)]

-- | Classify a term by the head of its spine: a defined name applied to
--   exactly as many arguments as 'functionArity' reports is an 'Iota'
--   redex, a lambda applied to exactly one argument is a 'Beta' redex, and
--   everything else is a 'NonRedex' carrying the forced top-level term.
redexView :: Term -> TC RedexView
redexView t = do
  sp <- spine t
  case sp of
    (h,_) : args -> do
      t <- forceClosure h
      case t of
        Var x -> other sp
        Def c -> do
          ar <- functionArity c
          case ar of
            Just n | n == length args
              -> return $ Iota c $ map fst args
            _ -> other sp
        Lam s -> case args of
          [(t,_)] -> return $ Beta s t
          _       -> other sp
        App s t -> fail "redexView: impossible App"
    _ -> fail "redexView: impossibly empty spine"
  where
    -- The last spine entry pairs the final argument with the whole term.
    top = snd . last
    other sp = NonRedex <$> forceClosure (top sp)

data ConView = ConApp Name [Term]
             | NonCon Term

-- | View a term as a defined name applied to arguments, if its spine head
--   is a 'Def'.
conView :: Term -> TC ConView
conView t = do
  sp <- spine t
  case sp of
    (c,_):args -> do
      s <- forceClosure c
      case s of
        Def c -> return $ ConApp c $ map fst args
        _     -> return $ NonCon t
    _ -> fail "conView: impossibly empty spine"

-- | Whether a reduction step rewrote anything.
data Progress = NoProgress | YesProgress

-- | 'YesProgress' is absorbing, 'NoProgress' is the unit.
instance Semigroup Progress where
  NoProgress  <> p           = p
  p           <> NoProgress  = p
  YesProgress <> YesProgress = YesProgress

instance Monoid Progress where
  mempty = NoProgress

-- | Run the action (discarding its result) only if progress was made.
whenProgress :: Monad m => Progress -> m a -> m ()
whenProgress YesProgress m = void m
whenProgress NoProgress  m = return ()

class WHNF a where
  whnf :: a -> TC Progress

instance (WHNF a, WHNF b) => WHNF (a,b) where
  whnf (x,y) = (<>) <$> whnf x <*> whnf y

instance WHNF Type where
  whnf p = do
    a <- forceClosure p
    case a of
      RPi _ _ -> return NoProgress
      Pi _ _  -> return NoProgress
      Fun _ _ -> return NoProgress
      Set     -> return NoProgress
      El t    -> whnf t

-- | Reduce the term stored at the pointer to weak head normal form,
--   overwriting the cell in place at every step.
--
--   The 'Beta' and 'Iota' branches report 'YesProgress' unconditionally
--   once they have rewritten the cell.  Returning the result of the
--   recursive call instead would yield 'NoProgress' whenever a single step
--   already reaches a normal form, and the 'App' branch, which only
--   re-examines the application when its head made progress, would then
--   leave a freshly exposed redex unreduced.
instance WHNF Term where
  whnf p = do
    v <- redexView p
    case v of
      NonRedex t -> case t of
        App s t -> do
          pr <- whnf s
          whenProgress pr $ whnf p
          return pr
        Lam _ -> return NoProgress
        Var _ -> return NoProgress
        Def _ -> return NoProgress
      Beta s t -> do
        poke p (forceClosure =<< subst t s)
        _ <- whnf p
        return YesProgress
      Iota f ts -> do
        Defn _ _ cs <- getDefinition f
        m <- match cs ts
        case m of
          YesMatch t -> do
            poke p (forceClosure t)
            _ <- whnf p
            return YesProgress
          MaybeMatch -> return NoProgress
          NoMatch    -> fail "Incomplete pattern matching"

-- | Result of matching: success with a payload, undecided (blocked on a
--   non-constructor term), or definite failure.
data Match a = YesMatch a | MaybeMatch | NoMatch

-- | Conjunction of matches: 'NoMatch' dominates, then 'MaybeMatch';
--   payloads of two successes are combined.
instance Semigroup a => Semigroup (Match a) where
  NoMatch     <> _           = NoMatch
  _           <> NoMatch     = NoMatch
  MaybeMatch  <> _           = MaybeMatch
  _           <> MaybeMatch  = MaybeMatch
  YesMatch ts <> YesMatch ss = YesMatch $ ts <> ss

instance Monoid a => Monoid (Match a) where
  mempty = YesMatch mempty

instance Functor Match where
  fmap f (YesMatch x) = YesMatch $ f x
  fmap f NoMatch      = NoMatch
  fmap f MaybeMatch   = MaybeMatch

instance Foldable Match where
  foldMap f (YesMatch x) = f x
  foldMap f NoMatch      = mempty
  foldMap f MaybeMatch   = mempty

instance Traversable Match where
  traverse f (YesMatch x) = YesMatch <$> f x
  traverse f NoMatch      = pure NoMatch
  traverse f MaybeMatch   = pure MaybeMatch

-- | Left-biased alternative: the first result that is not 'NoMatch' wins.
choice :: Match a -> Match a -> Match a
choice NoMatch m = m
choice m       _ = m

-- | Invariant: there are the same number of terms as there are patterns in the clauses
--
--   Clauses are tried in order with 'choice'.  An empty clause list yields
--   'NoMatch' instead of crashing (the fold is seeded with 'NoMatch', which
--   is a right unit for 'choice').
match :: [Clause] -> [Term] -> TC (Match Term)
match cs ts = foldr choice NoMatch <$> mapM (flip matchClause ts) cs

-- | Match one clause; on success substitute the matched terms into the
--   clause body.
matchClause :: Clause -> [Term] -> TC (Match Term)
matchClause c ts = do
  Clause ps t <- forceClosure c
  m <- matchPatterns ps ts
  traverse (\ss -> substs ss t) m

matchPatterns :: [Pattern] -> [Term] -> TC (Match [Term])
matchPatterns ps ts = mconcat <$> zipWithM matchPattern ps ts

-- | A variable pattern matches anything and binds the term.  A constructor
--   pattern first reduces the term to weak head normal form; only the last
--   @length ps@ arguments of the constructor application are matched
--   against the sub-patterns.
matchPattern :: Pattern -> Term -> TC (Match [Term])
matchPattern (VarP _) t = return $ YesMatch [t]
matchPattern (ConP c ps) t = do
  whnf t
  v <- conView t
  case v of
    ConApp c' ts
      | c == c'   -> matchPatterns ps (dropPars ts)
      | otherwise -> return NoMatch
    _ -> return MaybeMatch
  where
    dropPars ts = drop (length ts - length ps) ts

--------------------------------------------------------------------------------
/typechecker/TypeChecker.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE UndecidableInstances, FlexibleInstances #-}

module TypeChecker where

import Control.Applicative
import Control.Monad
import Control.Monad.Except
import Control.Monad.IO.Class

import qualified Data.Map as Map
import Data.List

import qualified Syntax.Abs as Abs
import Syntax.Print
import Syntax.Internal

import TypeChecker.Monad
import TypeChecker.Monad.Heap
import TypeChecker.Monad.Context
import TypeChecker.Monad.Signature
import TypeChecker.Force
import TypeChecker.DeBruijn
import TypeChecker.Print
import TypeChecker.Reduce

import Utils

identToName :: Abs.Ident -> Name
identToName (Abs.Ident x) = x

-- | Pretty-print every entry of the current signature to stdout.
debug :: TC ()
debug = do
  sig  <- getSig
  heap <- getHeap   -- only used by the commented-out heap dump below
  -- liftIO $ putStrLn "Heap:"
  -- mapM_ (pr show) $ Map.assocs heap
  mapM_ pr $ Map.elems sig
  return ()
  where
    pr e = do
      d <- pretty e
      liftIO $ print d

-- | Debug printing hook; currently disabled (the argument is ignored).
dbg :: Show a => TC a -> TC ()
dbg d = return () -- (liftIO . putStrLn) . show =<< d

-- | Build the signature for a program, force it, and dump it.  The
--   signature is also dumped when checking fails, before the error is
--   rethrown.
checkProgram :: Abs.Program -> TC ()
checkProgram (Abs.Prog ds) = do
    setSig =<< buildSignature ds
    forceSig
    debug
  `catchError` \e -> do
    debug
    throwError e

-- | The name a declaration declares.  Shared by 'buildSignature' and
--   'buildDef'; it used to be local to 'buildSignature' although 'buildDef'
--   refers to it as well.
declName :: Abs.Decl -> Name
declName (Abs.TypeSig i _)    = identToName i
declName (Abs.FunDef i _ _)   = identToName i
declName (Abs.DataDecl i _ _) = identToName i

-- | Group the declarations by declared name (the sort is stable, so the
--   source order inside a group is kept) and build one signature fragment
--   per group.
buildSignature :: [Abs.Decl] -> TC Signature
buildSignature ds =
  Map.unions
    <$> mapM buildDef
        ( groupBy ((==) `on` declName)
        $ sortBy (compare `on` declName)
        $ ds
        )

-- | Wrap an expression in the binders of a telescope, outermost first.
telePi :: [Abs.TelBinding] -> Abs.Expr -> Abs.Expr
telePi tel e = foldr bind e tel -- (\ (Abs.Bind x t) s -> Abs.Pi x t s) e tel
  where
    bind (Abs.PiBind (Abs.Bind x t)) s = Abs.Pi x t s
    bind (Abs.TelBind tel)           s = Abs.RPi tel s

-- | A constructor with its arity, counted syntactically from the leading
--   binders of its declared type.
constrWithArity :: Abs.Constr -> Constructor
constrWithArity (Abs.Constr c t) = Constr (identToName c) $ arity t
  where
    arity (Abs.Pi _ _ b)            = 1 + arity b
    arity (Abs.Fun _ b)             = 1 + arity b
    arity (Abs.RPi (Abs.Tel tel) b) = length tel + arity b
    arity _                         = 0

-- | Signature entry for a constructor; its type is the declared type
--   abstracted over the datatype's parameter telescope.
buildConstr :: [Abs.TelBinding] -> Abs.Constr -> TC Signature
buildConstr tel (Abs.Constr x e) = do
  let e' = telePi tel e
      c  = identToName x
  t <- isType e'
  return $ Map.singleton c $ Cons c t

-- | Signature fragment for one group of declarations sharing a name:
--   either a single data declaration, or a type signature together with
--   zero or more function clauses ('Axiom' when there are no clauses).
buildDef :: [Abs.Decl] -> TC Signature
buildDef [Abs.DataDecl x tel cs] = do
  t   <- isType (telePi tel Abs.Set)
  tcs <- mapM (buildConstr tel) cs
  let d    = identToName x
      decl = Map.singleton d $ Data d t $ map constrWithArity cs
  return $ Map.unions $ decl : tcs

buildDef ds = do
  t  <- getType
  cs <- getClauses t
  let d = case cs of
            [] -> Axiom x t
            _  -> Defn x t cs
  return $ Map.singleton x d
  where
    -- Groups produced by 'groupBy' are never empty; the second branch only
    -- guards against other callers.
    x = case ds of
          d : _ -> declName d
          []    -> error "buildDef: empty declaration group"

    getType = case [ t | Abs.TypeSig _ t <- ds ] of
      [t] -> isType t
      []  -> fail $ "No type signature for " ++ x
      ts  -> fail $ "Multiple type signatures for " ++ x

    getClauses t = mapM (getClause t) [ d | d@(Abs.FunDef _ _ _) <- ds ]
    getClause t (Abs.FunDef _ ps e) =
      checkClause ps e t
    getClause _ _ = error "getClause: __IMPOSSIBLE__"

-- | Suspended version of 'checkClause''.
checkClause :: [Abs.Pattern] -> Abs.Expr -> Type -> TC Clause
checkClause ps e t = suspend (checkClause' ps e t)

-- | Check the patterns against the function type, then check the body
--   against the remaining result type.
checkClause' :: [Abs.Pattern] -> Abs.Expr -> Type -> TC Clause'
checkClause' ps e a = do
  ps <- mapM buildPattern ps
  checkPatterns ps a $ \_ _ b -> do
    t <- checkType e b
    return $ Clause ps t

-- | Translate a concrete pattern.  The head identifier is tried as a
--   constructor first; if that (or translating the sub-patterns) fails, a
--   bare identifier becomes a variable pattern and an applied one is
--   reported as an undefined constructor.
buildPattern :: Abs.Pattern -> TC Pattern
buildPattern p = case appView p of
    Abs.VarP i : ps -> do
        isConstructor x
        ConP x <$> mapM buildPattern ps
      `catchError` \_ ->
        case ps of
          [] -> return $ VarP x
          _  -> fail $ "Undefined constructor " ++ x
      where
        x = identToName i
    _ -> fail "__IMPOSSIBLE__"
  where
    appView (Abs.AppP p q) = appView p ++ [q]
    appView p              = [p]

-- | Reduce a type to weak head normal form and split it into the name at
--   the head of its spine and the arguments.  Fails if the type is not an
--   'El' whose head is a 'Def'.
isDatatype :: Type -> TC (Name, [Term])
isDatatype p = do
  whnf p
  a <- forceClosure p
  case a of
    El t -> do
      (d,_):args <- spine t
      d <- forceClosure d
      case d of
        Def d -> return (d, map fst args)
        _     -> badData
    _ -> badData
  where
    badData = fail . show =<< pretty p <+> text "is not a datatype"

-- | The (suspended) result type of applying something of the given type to
--   the given arguments.  A telescope type consumes as many arguments as
--   the telescope is long and fails if fewer are supplied.
piApply :: Type -> [Term] -> TC Type
piApply a [] = return a
piApply a ts0@(t:ts) = suspend $ do
  whnf a
  a <- forceClosure a
  forceClosure =<<
    case a of
      Pi _ b  -> flip piApply ts =<< subst t b
      Fun _ b -> piApply b ts
      RPi tel b
        | length ts0 < length tel -> fail $ "piApply: too few arguments to " ++ show a ++ ": " ++ show ts0
        | otherwise -> do
            let (ts1, ts2) = splitAt (length tel) ts0
            b <- substs ts1 b
            piApply b ts2
      _ -> fail . show =<< text "piApply: not a function type" <+> pretty a

data ArgType = SimpleArg Type
             | TelArg Telescope

-- | The argument should be a function type.
argumentType :: Type -> TC ArgType
argumentType p = do
  a <- forceClosure p
  case a of
    Pi a _    -> return (SimpleArg a)
    Fun a _   -> return (SimpleArg a)
    RPi tel _ -> return (TelArg tel)
    _         -> fail . show =<< text "expected function type, found" <+> pretty p

-- | Check patterns against a function type in continuation-passing style.
--   The continuation runs in the context extended by the pattern variables
--   and receives the number of variables bound, the patterns as terms, and
--   the remaining result type.  Patterns matching a telescope must all be
--   variables.
checkPatterns :: [Pattern] -> Type -> (Integer -> [Term] -> Type -> TC a) -> TC a
checkPatterns [] a ret = ret 0 [] a
checkPatterns ps0@(p:ps) a ret = do
  arg <- argumentType a
  case arg of
    SimpleArg arg ->
      checkPattern p arg $ \n t -> do
        a <- raiseBy n a
        b <- piApply a [t]
        checkPatterns ps b $ \m ts b -> do
          t <- raiseBy m t
          ret (n + m) (t:ts) b
    TelArg tel -> do
      unless (length ps0 >= length tel) $ fail $ "Not enough arguments to constructor of type " ++ show a
      let (ps1, ps2) = splitAt (length tel) ps0
          allVars = mapM isVar
          isVar (VarP x) = Just x
          isVar _        = Nothing
      case allVars ps1 of
        Nothing -> fail $ "Recursive telescope patterns must be variables: " ++ show ps1
        Just xs -> do
          let n = genericLength xs
          a  <- raiseBy n a
          vs <- mapM evaluated [ Var i | i <- reverse [0..n - 1] ]
          b  <- piApply a vs
          extendContextTel [ RBind x a | (x, RBind _ a) <- zip xs tel ] $
            checkPatterns ps2 b $ \m us b -> do
              vs <- raiseBy m vs
              ret (n + m) (vs ++ us) b

-- | Check a single pattern against a type.  A variable extends the context
--   by one binding.  For a constructor pattern the expected type must be a
--   datatype; the constructor's type is instantiated with the datatype
--   arguments, the sub-patterns are checked against it, and the resulting
--   type must be convertible with the expected one.
checkPattern :: Pattern -> Type -> (Integer -> Term -> TC a) -> TC a
checkPattern p a ret =
  case p of
    VarP x    -> extendContext x a $ ret 1 =<< evaluated (Var 0)
    ConP x ps -> do
      b       <- defType x
      (d, us) <- isDatatype a
      b'      <- piApply b us
      checkPatterns ps b' $ \n ts a' -> do
        a  <- raiseBy n a
        us <- raiseBy n us
        a === a' `catchError` \ (Fail s) -> do
          s' <- show <$> vcat [ text "when checking the type of", nest 2 $ text (show p)
                              , nest 2 $ text "expected:" <+> pretty a
                              , nest 2 $ text "inferred:" <+> pretty a' ]
          fail $ s ++ "\n" ++ s'
        h <- evaluated (Def x)
        ret n =<< apps h (us ++ ts)
  where
    -- Left-nested application of a head to a list of arguments.
    apps :: Term -> [Term] -> TC Term
    apps s []     = return s
    apps s (t:ts) = do
      st <- evaluated (App s t)
      apps st ts


-- | Suspended version of 'isType''.
isType :: Abs.Expr -> TC Type
isType e = suspend (isType' e)

-- | Suspended version of 'checkType''.
checkType :: Abs.Expr -> Type -> TC Term
checkType e t = suspend (checkType' e t)

-- | Suspended version of 'inferType'': a single suspended inference whose
--   two components are exposed as separate suspended pointers.
inferType :: Abs.Expr -> TC (Term, Type)
inferType e = do
  p   <- suspend (inferType' e)
  ptm <- suspend $ fst <$> forceClosure p
  ptp <- (forceClosure . snd) `liftPtrM` p
  return (ptm, ptp)

-- | Check that an expression is a type.  Anything that is not a function
--   type or @Set@ is checked as a term of type @Set@ and wrapped in 'El'.
isType' :: Abs.Expr -> TC Type'
isType' e = do
  case e of
    Abs.Pi x e1 e2 -> do
      x <- return $ identToName x
      a <- isType e1
      b <- extendContext x a $ isType e2
      return $ Pi a (Abs x b)
    Abs.RPi tel e -> do
      tel <- checkTel tel
      a   <- extendContextTel tel $ isType e
      return $ RPi tel a
    Abs.Fun e1 e2 -> Fun <$> isType e1 <*> isType e2
    Abs.Set -> return Set
    e -> do
      set <- evaluated Set
      El <$> checkType e set

-- | Check a telescope.  Every binding's type is checked in the context
--   extended with the whole telescope, which is possible because the types
--   are allocated as mutually recursive suspended cells.
checkTel :: Abs.Telescope -> TC Telescope
checkTel (Abs.Tel tel) = do
  let mkTel as = [ RBind x a | (Abs.RBind (Abs.Ident x) _, a) <- zip tel as ]
      check as (Abs.RBind _ e) = extendContextTel (mkTel as) $ isType' e
  as <- recursives tel check :: TC [Type]
  return $ mkTel as

-- | Check an expression against a type.  Lambdas are checked against
--   function types one binder (or one whole telescope) at a time; for any
--   other expression the type is inferred and a conversion constraint with
--   the expected type is added.
--
--   NOTE(review): the flattened source does not show whether the trailing
--   @sayWhen@ wraps the whole @case@ or only its last alternative; wrapping
--   the whole function body is assumed here -- confirm against the
--   original layout.
checkType' :: Abs.Expr -> Type -> TC Term'
checkType' e a = do
    case e of
      Abs.Lam [] e -> checkType' e a
      Abs.Lam xs0@(x:xs) e -> do
        let e' = Abs.Lam xs e
        x <- return $ identToName x
        a <- forceClosure a
        s <- case a of
          Pi a b  -> extendContext x a (checkType e' (absBody b))
          Fun a b -> extendContext x a (checkType e' =<< raise b)
          RPi tel b -> do
            unless (length xs0 >= length tel) $ fail $ "Too few arguments in insane lambda: " ++ printTree xs0 ++ " < " ++ show (length tel)
            let (ys, zs) = splitAt (length tel) xs0
                tel'     = [ RBind (identToName y) a | (y, RBind _ a) <- zip ys tel ]
            extendContextTel tel' $ checkType (Abs.Lam zs e) b
          _ -> fail $ "expected function type, found " ++ show a
        return $ Lam (Abs x s)
      e -> do
        (v, b) <- inferType' e
        addConstraint $
          a === b `sayWhen` do
            gently $ force (v, (a, b))
            vcat [ sep [text "when checking the type of", nest 4 $ pretty v]
                 , nest 2 $ text "expected:" <+> pretty a
                 , nest 2 $ text "inferred:" <+> pretty b ]
        return v
  `sayWhen` sep [ text "when checking that"
                , nest 4 $ text (printTree e)
                , nest 2 $ text "has type"
                , nest 4 $ pretty a ]

-- | Run an action; if it fails, append the rendered context description to
--   the error message (on a new line) and fail again.
sayWhen :: Show d => TC a -> TC d -> TC a
sayWhen m wh = m `catchError` \(Fail s) -> do
  s' <- show <$> wh
  fail $ s ++ "\n" ++ s'

-- | Flatten a left-nested application into head and arguments.
appView :: Abs.Expr -> [Abs.Expr]
appView (Abs.App e1 e2) = appView e1 ++ [e2]
appView e               = [e]

-- | Check arguments against a function type, returning the checked
--   arguments and the result type.  The second parameter accumulates the
--   expressions processed so far (head first) and is only used in error
--   messages.  Telescope arguments are checked as mutually recursive cells
--   and must all be supplied.
checkArgs :: Type -> [Abs.Expr] -> [Abs.Expr] -> TC ([Term], Type)
checkArgs a _ [] = return ([], a)
checkArgs a es0 es@(e:es1) = do
  a <- forceClosure a
  case a of
    Fun a b -> do
      v <- checkType e a
      (vs, c) <- checkArgs b (es0 ++ [e]) es1
      return (v : vs, c)
    Pi a b -> do
      v <- checkType e a
      b <- subst v b
      (vs, c) <- checkArgs b (es0 ++ [e]) es1
      return (v : vs, c)
    RPi tel b -> do
      unless (length es >= length tel) $ fail $ "Insanely dependent functions must be fully applied: " ++ show es ++ " < " ++ show tel
      let (es1, es2) = splitAt (length tel) es
      vs <- recursives (zip tel es1) $ \vs (RBind _ a, e) -> do
        a <- substs vs a
        v <- checkType' e a
        return v
      b <- substs vs b
      (us, c) <- checkArgs b (es0 ++ es1) es2
      return (vs ++ us, c)
    _ -> fail $ unlines [ "Expected function type, found " ++ show a
                        , "in the application of " ++ printTree (foldl1 Abs.App es0)
                        , "to " ++ printTree (head es)
                        ]

-- | Infer the type of a name or an application.  A name is looked up in
--   the signature first and in the context if that fails.  Other
--   expressions are not supported.
inferType' :: Abs.Expr -> TC (Term', Type)
inferType' e0 = do
  case e0 of
    Abs.Name i -> do
        t <- defType x
        return (Def x, t)
      `catchError` \_ -> do
        (n,t) <- lookupContext x
        return (Var n, t)
      where
        x = identToName i
    Abs.App{} -> do
        let e : es@(_:_) = appView e0
        (f, a)    <- inferType e
        (args, b) <- checkArgs a [e] es
        v <- buildApp f args
        return (v, b)
      where
        buildApp :: Term -> [Term] -> TC Term'
        buildApp f []       = forceClosure f
        buildApp f (e : es) = do
          fe <- evaluated $ App f e
          buildApp fe es
    e -> fail $ "inferType not implemented for " ++ printTree e

-- | Conversion checking: succeeds silently or fails with a description of
--   the mismatch.
class Convert a where
  (===) :: a -> a -> TC ()

-- | Pointers: equal pointers are trivially convertible; otherwise both
--   sides are reduced to weak head normal form and their contents compared.
--
--   Marked @OVERLAPPABLE@ so that the specific instances below take
--   precedence; this replaces the deprecated module-wide
--   @OverlappingInstances@ extension.
instance {-# OVERLAPPABLE #-} (Pretty a, Pointer ptr a, Convert a, Force a, WHNF ptr) => Convert ptr where
  p === q
    | p == q    = return ()
    | otherwise = do
        whnf (p, q)
        x <- forceClosure p
        y <- forceClosure q
        x === y `sayWhen` do
          gently $ force (x, y)
          sep [ text "when checking that"
              , nest 2 $ pretty x <+> text "=="
              , nest 2 $ pretty y ]

-- | NOTE(review): there is no case for two 'RPi' types, so they fall
--   through to the failure case even when identical -- confirm whether this
--   is intended.
instance Convert Type' where
  Pi a b  === Pi a' b'  = (a,b) === (a',b')
  Fun a b === Fun a' b' = (a,b) === (a',b')
  Pi a b  === Fun a' b' = do
    a === a'
    b' <- raise b'
    b === Abs "x" b'
  Fun a b === Pi a' b'  = do
    a === a'
    b <- raise b
    Abs "x" b === b'
  Set     === Set       = return ()
  El t    === El t'     = t === t'
  a       === b         = fail . show =<< pretty a <+> text "!=" <+> pretty b

instance Convert Term' where
  a === b = test a b
    where
      -- Lambda bodies are compared under a binder whose type is never
      -- supplied (it is an 'error' thunk).
      test (Lam s) (Lam t) =
        extendContext (absName s) (error "") $ s === t
      test (App s s') (App t t') = do
        s === t
        s' === t'
      test (Var n) (Var m)
        | n == m    = return ()
        | otherwise = do
            x <- getVarName n
            y <- getVarName m
            fail $ x ++ " (" ++ show n ++ ") != " ++ y ++ " (" ++ show m ++ ")"
      test (Def f) (Def g)
        | f == g    = return ()
        | otherwise = fail $ f ++ " != " ++ g
      test s t = do
        -- Best effort: force both sides for a readable message, ignoring
        -- any failure while doing so.
        force (s, t) `catchError` \_ -> return ()
        d <- pretty s <+> text " != " <+> pretty t
        fail $ show d

instance (Convert a, Convert b) => Convert (a,b) where
  (x,y) === (x',y') = do x === x'; y === y'

-- | Abstractions are compared by their bodies; binder names are ignored.
instance Convert a => Convert (Abs a) where
  (===) = (===) `on` absBody

--------------------------------------------------------------------------------