├── .dir-locals.el ├── Setup.hs ├── stack.yaml ├── hie.yaml ├── .gitignore ├── .ghci ├── src └── Lamdu │ └── Calc │ ├── Internal │ └── Prelude.hs │ ├── Definition.hs │ ├── Term │ └── Eq.hs │ ├── Identifier.hs │ ├── Infer.hs │ ├── Lens.hs │ ├── Term.hs │ └── Type.hs ├── test ├── test.hs ├── benchmark.hs └── TestVals.hs ├── LICENSE ├── doc ├── Optimization.md └── ExceptionMonad.md ├── package.yaml ├── tools └── core-type-apps.py ├── lamdu-calculus.cabal └── README.md /.dir-locals.el: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | 3 | -- | Setup script for @build-type: Simple@: everything is delegated to Cabal's 'defaultMain'. 4 | main :: IO () 5 | main = defaultMain 6 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: lts-20.23 2 | packages: 3 | - '.'
4 | extra-deps: 5 | - github: lamdu/hypertypes 6 | commit: 354ac7be8dc1b15a8df9b94ff010054ec2b31da8 7 | -------------------------------------------------------------------------------- /hie.yaml: -------------------------------------------------------------------------------- 1 | cradle: 2 | stack: 3 | - path: "./test" 4 | component: "lamdu-calculus:test:lamdu-calculus-test" 5 | 6 | - path: "./src" 7 | component: "lamdu-calculus:lib" 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /dist/ 2 | .stack-work/ 3 | .stack-work-profile/ 4 | /tags 5 | /ghci-out 6 | .ghcid 7 | .vscode/ 8 | *.o 9 | *.hi 10 | *.dyn_o 11 | *.dyn_hi 12 | *.prof 13 | stack.yaml.lock 14 | /dumps/ 15 | .DS_Store 16 | -------------------------------------------------------------------------------- /.ghci: -------------------------------------------------------------------------------- 1 | :set -isrc -itest 2 | :set -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints 3 | :set -odir=ghci-out 4 | :set -outputdir=ghci-out 5 | :set -hide-package base-compat-batteries 6 | -------------------------------------------------------------------------------- /src/Lamdu/Calc/Internal/Prelude.hs: -------------------------------------------------------------------------------- 1 | -- | Internal prelude of the Lamdu.Calc modules: "Prelude.Compat" plus a few 2 | -- frequently used names, all re-exported through the single @X@ alias. 3 | module Lamdu.Calc.Internal.Prelude 4 | ( module X 5 | ) where 6 | 7 | import Control.DeepSeq as X (NFData(..)) 8 | import Control.Lens.Operators as X 9 | import Control.Monad as X (guard, join, void) 10 | import Data.Binary as X (Binary) 11 | import Data.ByteString as X (ByteString) 12 | import Data.Hashable as X (Hashable) 13 | import Data.Map as X (Map) 14 | import Data.Set as X (Set) 15 | import Data.String as X (IsString(..)) 16 | import Prelude.Compat as X 17 |
-------------------------------------------------------------------------------- /test/test.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings, TypeFamilyDependencies #-} 2 | 3 | import Control.Lens (at) 4 | import Control.Lens.Operators 5 | import Data.Set (fromList) 6 | import Hyper 7 | import Hyper.Syntax.Scheme 8 | import Lamdu.Calc.Infer 9 | import Lamdu.Calc.Type 10 | import Test.Framework 11 | import Test.Framework.Providers.HUnit (testCase) 12 | import Test.HUnit (assertBool) 13 | 14 | -- | Two schemes that differ only in the payload of the "t" variant alternative 15 | -- (the quantified variable "a" vs. the empty record) must NOT be alpha-equivalent. 16 | alphaEqTest :: Test 17 | alphaEqTest = 18 | not (alphaEq (f (TVarP "a")) (f (TRecordP REmptyP))) 19 | & assertBool "should not alpha eq" 20 | & testCase "alpha-eq" 21 | where 22 | f x = 23 | TFunP (TVarP "a") (TVariantP (RExtendP "t" x (RVarP "c"))) ^. hPlain 24 | & Scheme 25 | ( Types 26 | (QVars (mempty & at "a" ?~ mempty)) 27 | (QVars (mempty & at "c" ?~ RowConstraints (fromList ["t"]) mempty)) 28 | ) 29 | & Pure 30 | 31 | main :: IO () 32 | main = defaultMain [alphaEqTest] 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright Author name here (c) 2016 2 | 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution.
15 | 16 | * Neither the name of Author name here nor the names of other 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /src/Lamdu/Calc/Definition.hs: -------------------------------------------------------------------------------- 1 | -- | The dependencies ('Deps') of a lamdu-calculus definition: the types of 2 | -- the global variables and the declarations of the nominals that its value uses 3 | {-# LANGUAGE NoImplicitPrelude, TemplateHaskell, DeriveGeneric, TypeOperators #-} 4 | 5 | module Lamdu.Calc.Definition 6 | ( Deps(..), depsGlobalTypes, depsNominals 7 | , pruneDeps 8 | ) where 9 | 10 | import Hyper (Ann, Pure, Const(..), Generic, HFunctor(..), hflipped, type (#)) 11 | import Hyper.Syntax.Nominal (NominalDecl) 12 | import qualified Control.Lens as Lens 13 | import qualified Data.Map as Map 14 | import qualified Data.Set as Set 15 | import Lamdu.Calc.Lens (valGlobals, valNominals) 16 | import qualified Lamdu.Calc.Term as V 17 | import Lamdu.Calc.Type (Type) 18 | import qualified Lamdu.Calc.Type as T 19 | 20 | import Lamdu.Calc.Internal.Prelude 21 | 22 | data Deps = Deps 23 | { _depsGlobalTypes :: !(Map V.Var (Pure # T.Scheme)) 24 | , _depsNominals :: !(Map T.NominalId (Pure # NominalDecl Type)) 25 | } deriving (Generic, Show, Eq, Ord) 26 | instance NFData Deps 27 | instance Binary Deps 28 | 29 | Lens.makeLenses ''Deps 30 | 31 | instance Semigroup Deps where 32 | Deps t0 n0 <> Deps t1 n1 = Deps (t0 <> t1) (n0 <> n1) 33 | instance Monoid Deps where 34 | mempty = Deps mempty mempty 35 | mappend = (<>) 36 | 37 | -- | Drop every dependency the given term does not reference: keep only the 38 | -- global types whose variables 'valGlobals' finds in the term, and the 39 | -- nominals whose ids 'valNominals' finds. The term's annotations are ignored. 40 | pruneDeps :: 41 | Ann a # V.Term -> Deps -> Deps 42 | pruneDeps e deps = 43 | deps 44 | & depsGlobalTypes %~ prune (valGlobals mempty) 45 | & depsNominals %~ prune valNominals 46 | where 47 | -- Replace every annotation with Const () before folding over the term 48 | -- (NOTE(review): presumably what the folds in Lamdu.Calc.Lens expect). 49 | ev = e & hflipped %~ hmap (\_ _ -> Const ()) 50 | -- Collect the referenced keys once and intersect the map with them; 51 | -- same result as filtering each key by set membership. 52 | prune f m = Map.restrictKeys m (Set.fromList (ev ^.. f)) 53 |
-------------------------------------------------------------------------------- /doc/Optimization.md: -------------------------------------------------------------------------------- 1 | # The process for optimizing the inference with SPECIALIZATION 2 | 3 | This document is work-in-progress! I need help! 4 | 5 | Run the benchmark while dumping GHC core: 6 | 7 | * If the code has not changed since the last build, force a rebuild via `stack clean lamdu-calculus` 8 | * `stack bench --ghc-options "-dumpdir dumps -ddump-simpl -dsuppress-coercions -dsuppress-idinfo -dsuppress-module-prefixes -dsuppress-timestamps"` 9 | * This generates `.dump-simpl` files in the `dumps` folder 10 | * The top-level file of interest is `dumps/test/benchmark.dump-simpl` 11 | * In `dumps/src/Lamdu/Calc/` there are dump files for `Infer` and `Term` which are also of interest 12 | 13 | ## Searching for type applications 14 | 15 | The type applications may look like: 16 | 17 | * `pruneDeps1 @ ()` 18 | * `emptyScope @ ('AHyperType UVar)` 19 | * `$fNFDataTypeError_$crnf @ ('AHyperType Pure)` 20 | 21 | Now we need to see which of those are benign. 22 | If the type of the definition does not have class constraints on these variables, such as 23 | 24 | pruneDeps :: Ann a # V.Term -> Deps -> Deps 25 | 26 | Then this is a benign type application (i.e. no benefit from adding a `SPECIALIZE` pragma for it). 27 | Other applications may not be benign but are not significant, 28 | for example if they are only called one time when the inference results in an error, but not in inner loops, 29 | such as the `NFData` instance above. 30 | 31 | We want to add `SPECIALIZE` pragmas to significant unspecialized (using type applications) calls. 32 | 33 | We can find such type applications using the `tools/core-type-apps.py` script and manually searching for important functions in the code (the wip script isn't fully functional..)
34 | 35 | Note that apparently sometimes specializing one function causes GHC to not use a specialized version of an inner call due to type family confusion. 36 | -------------------------------------------------------------------------------- /package.yaml: -------------------------------------------------------------------------------- 1 | name: lamdu-calculus 2 | version: 0.2.0.1 3 | github: "lamdu/lamdu-calculus" 4 | license: BSD3 5 | author: "Yair Chuchem, Eyal Lotem" 6 | maintainer: "yairchu@gmail.com" 7 | copyright: "2021 Yair Chuchem, Eyal Lotem" 8 | 9 | extra-source-files: 10 | - README.md 11 | 12 | synopsis: The Lamdu Calculus programming language 13 | category: Language 14 | 15 | description: Please see README.md 16 | 17 | dependencies: 18 | - base >= 4.7 19 | - base-compat >= 0.8.2 20 | - bytestring 21 | - containers 22 | - hypertypes >= 0.2 23 | - lens >= 4.1 24 | 25 | ghc-options: 26 | - -fexpose-all-unfoldings 27 | - -Wall 28 | - -Wnoncanonical-monad-instances 29 | - -Wcompat 30 | - -Wincomplete-record-updates 31 | - -Wincomplete-uni-patterns 32 | - -Wredundant-constraints 33 | - -Wunused-packages 34 | - -fdicts-cheap 35 | - -O2 36 | - -fspecialise-aggressively 37 | 38 | ghc-prof-options: 39 | - -O2 40 | 41 | library: 42 | source-dirs: src 43 | other-modules: 44 | - Lamdu.Calc.Internal.Prelude 45 | dependencies: 46 | - base16-bytestring 47 | - binary 48 | - deepseq 49 | - generic-constraints 50 | - generic-data 51 | - hashable 52 | - lattices 53 | - monad-st 54 | - mtl 55 | - pretty >= 1.1.2 56 | - transformers 57 | 58 | tests: 59 | lamdu-calculus-test: 60 | main: test.hs 61 | source-dirs: test 62 | dependencies: 63 | - HUnit 64 | - lamdu-calculus 65 | - test-framework 66 | - test-framework-hunit 67 | 68 | benchmarks: 69 | lamdu-calculus-bench: 70 | main: benchmark.hs 71 | source-dirs: test 72 | ghc-options: 73 | - -O2 74 | - -Wall 75 | - -Wnoncanonical-monad-instances 76 | - -Wcompat 77 | - -Wincomplete-record-updates 78 | - 
-Wincomplete-uni-patterns 79 | - -Wredundant-constraints 80 | dependencies: 81 | - criterion 82 | - deepseq 83 | - monad-st 84 | - mtl 85 | - lamdu-calculus 86 | - transformers 87 | -------------------------------------------------------------------------------- /src/Lamdu/Calc/Term/Eq.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE NoImplicitPrelude, TypeApplications, TypeOperators, FlexibleInstances #-} 2 | 3 | module Lamdu.Calc.Term.Eq 4 | ( couldEq 5 | ) where 6 | 7 | import Hyper 8 | import Hyper.Class.ZipMatch 9 | import Hyper.Type.Prune 10 | import qualified Control.Lens as Lens 11 | import Data.List (find) 12 | import Lamdu.Calc.Term 13 | import qualified Lamdu.Calc.Type as T 14 | 15 | import Lamdu.Calc.Internal.Prelude 16 | 17 | -- | Lambda binders in scope, innermost first. Each pair holds the binder of 18 | -- the left term and the corresponding binder of the right term. 19 | type Binders = [(Var, Var)] 20 | 21 | class CouldEq e where 22 | go :: Binders -> Pure # e -> Pure # e -> Bool 23 | 24 | instance CouldEq Term where 25 | go binders (Pure (BLam (TypedLam xvar xtyp xresult))) (Pure (BLam (TypedLam yvar ytyp yresult))) = 26 | go binders xtyp ytyp && 27 | go ((xvar, yvar) : binders) xresult yresult 28 | go binders (Pure (BLeaf (LVar x))) (Pure (BLeaf (LVar y))) = 29 | -- The innermost binder named x (left) or y (right) decides: both names must 30 | -- refer to that same binder pair. Comparing names alone would wrongly equate 31 | -- e.g. \a -> \b -> a with \b -> \a -> a. If neither name is bound, both are 32 | -- free variables and must simply be equal. 33 | maybe (x == y) (== (x, y)) (find (\(bx, by) -> bx == x || by == y) binders) 34 | go binders (Pure xBody) (Pure yBody) = 35 | case join (zipMatch_ (Proxy @CouldEq #> \x -> guard . go binders x) xBody yBody) of 36 | Just () -> True 37 | Nothing -> 38 | case (xBody, yBody) of 39 | (BLeaf LHole, _) -> True 40 | (_, BLeaf LHole) -> True 41 | _ -> False 42 | 43 | instance CouldEq (HCompose Prune T.Type) where 44 | go _ (Pure (HCompose Pruned)) _ = True 45 | go _ _ (Pure (HCompose Pruned)) = True 46 | go binders (Pure xBody) (Pure yBody) = 47 | Lens.has Lens._Just 48 | (join (zipMatch_ (Proxy @CouldEq #> \x -> guard . go binders x) xBody yBody)) 49 | 50 | instance CouldEq (HCompose Prune T.Row) where 51 | go _ (Pure (HCompose Pruned)) _ = True 52 | go _ _ (Pure (HCompose Pruned)) = True 53 | go binders (Pure xBody) (Pure yBody) = 54 | Lens.has Lens._Just 55 | (join (zipMatch_ (Proxy @CouldEq #> \x -> guard . go binders x) xBody yBody)) 56 | 57 | -- | Whether two terms could be equal: structurally equal up to renaming of 58 | -- lambda-bound variables, where a hole (or a pruned type) matches anything. 59 | couldEq :: Pure # Term -> Pure # Term -> Bool 60 | couldEq = go [] 61 |
-------------------------------------------------------------------------------- /src/Lamdu/Calc/Identifier.hs: -------------------------------------------------------------------------------- 1 | -- | This module defines the 'Identifier' type and is meant to be 2 | -- cheaply importable without creating import cycles. 3 | {-# LANGUAGE NoImplicitPrelude, DeriveGeneric, GeneralizedNewtypeDeriving, DerivingStrategies #-} 4 | module Lamdu.Calc.Identifier 5 | ( 6 | -- * Identifier type 7 | Identifier(..) 8 | -- * Hex representation of identifier bytes 9 | , identHex, identFromHex 10 | -- ** Laws 11 | 12 | -- | 13 | -- > identFromHex . identHex == Right 14 | ) where 15 | 16 | import qualified Data.ByteString.Base16 as Hex 17 | import qualified Data.ByteString.Char8 as BS 18 | import qualified Data.Char as Char 19 | import GHC.Generics (Generic) 20 | import qualified Text.PrettyPrint as PP 21 | import Text.PrettyPrint.HughesPJClass (Pretty(..)) 22 | 23 | import Lamdu.Calc.Internal.Prelude 24 | 25 | -- | A low-level identifier data-type. This is used to identify 26 | -- variables, type variables, tags and more. 27 | newtype Identifier = Identifier ByteString 28 | deriving stock (Generic, Show) 29 | deriving newtype (Eq, Ord, Binary, Hashable) 30 | instance NFData Identifier 31 | instance Pretty Identifier where 32 | pPrint (Identifier x) 33 | | all Char.isPrint (BS.unpack x) = PP.text $ BS.unpack x 34 | | otherwise = PP.text $ identHex $ Identifier x 35 | instance IsString Identifier where fromString = Identifier . fromString 36 | -- ^ IsString uses the underlying `ByteString.Char8` instance, use 37 | -- only with Latin1 strings 38 | 39 | -- | Convert the identifier bytes to a hex string 40 | -- 41 | -- > > identHex (Identifier "a1") 42 | -- > "6131" 43 | identHex :: Identifier -> String 44 | identHex (Identifier bs) = Hex.encode bs & BS.unpack 45 | 46 | -- | Convert a hex string (e.g. one generated by 'identHex') to an 47 | -- Identifier 48 | -- 49 | -- > > identFromHex "6131" 50 | -- > Right (Identifier "a1") 51 | identFromHex :: String -> Either String Identifier 52 | identFromHex str = BS.pack str & Hex.decode <&> Identifier 53 |
-------------------------------------------------------------------------------- /test/benchmark.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE NoImplicitPrelude, OverloadedStrings, ScopedTypeVariables, FlexibleContexts, TypeOperators #-} 2 | 3 | import Hyper 4 | import Hyper.Infer (infer, inferResult) 5 | import Hyper.Recurse (wrap) 6 | import Hyper.Unify (UVarOf, UnifyGen, applyBindings) 7 | import Control.DeepSeq (rnf) 8 | import Control.Exception (evaluate) 9 | import Control.Lens (ASetter') 10 | import qualified Control.Lens as Lens 11 | import Control.Lens.Operators 12 | import Control.Monad.Reader 13 | import Control.Monad.ST.Class (MonadST(..)) 14 | import Control.Monad.Trans.Maybe (MaybeT(..)) 15 | import Criterion (Benchmarkable, whnfIO) 16 | import Criterion.Main (bench, defaultMain) 17 | import Data.STRef (newSTRef) 18 | import Lamdu.Calc.Definition (pruneDeps) 19 | import Lamdu.Calc.Infer 20 | import Lamdu.Calc.Term (Term, Scope, emptyScope) 21 | import qualified Lamdu.Calc.Type as T 22 | import TestVals 23 | 24 | import Prelude.Compat 25 | 26 | localInitEnv :: 27 | ( MonadReader env m 28 | , UnifyGen m T.Type 29 | , UnifyGen m T.Row 30 | ) => 31 | ASetter' env (Scope # UVarOf m) -> Ann z # Term -> m a -> m a 32 | localInitEnv inferEnv e action = 33 | do 34 | addScope <- loadDeps (pruneDeps e allDeps) 35 | Lens.locally inferEnv addScope action 36 | 37 | toAnn :: HPlain Term -> Ann (Const ()) # Term 38 | toAnn = wrap (\_ x -> Ann (Const ()) x) . (^. hPlain) 39 | 40 | benchInferPure :: HPlain Term -> Benchmarkable 41 | benchInferPure e = 42 | infer x 43 | <&> (^. hAnn . Lens._2 . inferResult) 44 | >>= applyBindings 45 | & localInitEnv id x 46 | & runPureInfer emptyScope (InferState emptyPureInferState varGen) 47 | & Lens._Right %~ (^. Lens._1) 48 | & rnf 49 | & evaluate 50 | & whnfIO 51 | where 52 | x = toAnn e 53 | 54 | benchInferST :: HPlain Term -> Benchmarkable 55 | benchInferST e = 56 | do 57 | vg <- newSTRef varGen 58 | localInitEnv Lens._1 x 59 | (infer x <&> (^. hAnn . Lens._2 . inferResult) >>= applyBindings) ^. _STInfer 60 | & (`runReaderT` (emptyScope, vg)) 61 | & runMaybeT 62 | & liftST >>= evaluate . rnf & whnfIO 63 | where 64 | x = toAnn e 65 | 66 | benches :: [(String, Benchmarkable)] 67 | benches = 68 | [ ("S_factorial", benchInferST factorialVal) 69 | , ("S_euler1", benchInferST euler1Val) 70 | , ("S_solveDepressedQuartic", benchInferST solveDepressedQuarticVal) 71 | , ("S_factors", benchInferST factorsVal) 72 | , ("P_factorial", benchInferPure factorialVal) 73 | , ("P_euler1", benchInferPure euler1Val) 74 | , ("P_solveDepressedQuartic", benchInferPure solveDepressedQuarticVal) 75 | , ("P_factors", benchInferPure factorsVal) 76 | ] 77 | 78 | main :: IO () 79 | main = benches <&> uncurry bench & defaultMain 80 |
-------------------------------------------------------------------------------- /tools/core-type-apps.py: -------------------------------------------------------------------------------- 1 | def normalize(x): 2 | x = x.strip().replace('\n', ' ') 3 | while '  ' in x: 4 | x = x.replace('  ', ' ') 5 | return x 6 | 7 | def take_word(): 8 | global c 9 | stack = [] 10 | for i, x in enumerate(c + ' '): # trailing sentinel: a final word is kept whole, and an empty c no longer leaves i unbound 11 | if stack and x == stack[-1]: 12 | stack.pop() 13 | continue 14 | if x == ')' or x == ',': 15 | break 16 | if x == '(': 17 | stack.append(')') 18 | continue 19 | if x == '[': 20 | stack.append(']') 21 | continue 22 | if x == ' ' and not stack: 23 | break 24 | r = normalize(c[:i]) 25 | c = c[i:].strip() 26 | return r 27 | 28 | def params(): 29 | global c 30 | r = [take_word()] 31 | while c.startswith('@'): 32 | c = c[1:].strip() 33 | r.append (take_word()) 34 | return r 35 | 36 | def last_word(d): 37 | stack = [] 38 | for i, x in enumerate(d[::-1] + ' '): # sentinel (visited last): a word spanning all of d is kept whole 39 | if stack and x == stack[-1]: 40 | stack.pop() 41 | continue 42 | if x == '(': 43 | break 44 | if x == ')': 45 | stack.append('(') 46 | continue 47 | if x == ']': 48 | stack.append('[') 49 | continue 50 | if x == ' ' and not stack: 51 | break 52 | return normalize(d[len(d) - i:]) # not d[-i:], which would be all of d when i == 0 53 | 54 | cant = set(''' 55 | >>= 56 | <*> 57 | <$ 58 | fmap 59 | pure 60 | leq 61 | liftA2 62 | '''.split()) 63 | 64 | ok = set(''' 65 | $fMonadPureInfer_$s$fMonadRWST_$c>>= 66 | $fNFDataTypeError_$crnf 67 | $fNFDataType_$crnf 68 | $fOrdVar 69 | $fUnifyPureInferType_$cunifyError 70 | $fRecursiveUnify_$crecurse 71 | [] 72 | : 73 | ++ 74 | absentError 75 | emptyScope 76 | error 77 | heq_sel 78 | map 79 | newMutVar# 80 | pruneDeps1 81 | readMutVar# 82 | runMainIO1 83 | rwhnf 84 | seq# 85 | unifyError 86 | whnfIO' 87 | writeMutVar# 88 | '''.split()) 89 | ok.update(cant) 90 | 91 | apps = set() 92 | for x in ['test/benchmark', 'src/Lamdu/Calc/Infer', 'src/Lamdu/Calc/Term']: 93 | c = open('dumps/'+x+'.dump-simpl').read().split('------ Local rules for imported ids --------', 1)[0] 94 | while '@' in c: 95 | pre, post = c.split('@', 1) 96 | c = post.strip() 97 | if pre.strip().endswith('\\'): 98 | # Lambda. Skip type parameters 99 | params() 100 | continue 101 | if post.startswith('~'): 102 | # Not a type application 103 | continue 104 | var = last_word(pre.strip()) 105 | while var.startswith('('): 106 | var = var[1:] 107 | p = params() 108 | if var[:1].isupper(): 109 | # Skip data constructors 110 | continue 111 | if var in ok: 112 | continue 113 | w = [] 114 | for x in p: 115 | w += x.replace('(', ' ').replace('[', ' ').split() 116 | if [x for x in w if x[:1].islower() or x in ['RealWorld', 'Any']]: 117 | # Skip specializations with type variables, 118 | # find only the top-level specializations 119 | continue 120 | apps.add((var, tuple(p))) 121 | 122 | for var, p in sorted(apps): 123 | print(' '.join([var] + ['@'+x for x in p])) 124 |
-------------------------------------------------------------------------------- /doc/ExceptionMonad.md: -------------------------------------------------------------------------------- 1 | # Exception monad in Lamdu Calculus 2 | 3 | In Haskell, exception monads are not used for most code that can throw 4 | exceptions. Usually, their use is localized where the domain of known 5 | errors is fixed. 6 | 7 | The reason for this, we believe, is the type of bind in Haskell's exception 8 | monad: 9 | 10 | `(>>=) :: Either err a -> (a -> Either err b) -> Either err b` 11 | 12 | Note that `err` must be the same type on both sides of the bind - even 13 | if the kinds of errors thrown by the two bound actions are vastly 14 | different. 15 | 16 | Haskell's sum (variant) types are also fully nominal, meaning that one 17 | must manually declare all possible errors in all bound actions in a 18 | single sum type, to use that as the error type. 19 | 20 | This inhibits use of Haskell exception monads, and instead, much more 21 | code uses dynamically typed exceptions (`SomeException` and the 22 | `Exception` class). 23 | 24 | However, with extensible variant types and Lamdu Calculus's powerful 25 | case expressions, we can do better.
26 | 27 | ## Throwing errors 28 | 29 | ``` 30 | div x y = 31 | case y == 0 of 32 | #True () -> #Error (#DivByZero ()) 33 | #False () -> #Success (x / y) 34 | ``` 35 | 36 | The inferred type of `div` is: 37 | ``` 38 | forall err. #DivByZero ∉ err => 39 | Int -> Int -> +{ #Error : +{ #DivByZero : () | err }, #Success : Int } 40 | ``` 41 | 42 | Note that the only possible error is inferred, and we need not 43 | manually declare an error type. 44 | 45 | We also declare `fromJust` to extract a value from a `Maybe` type: 46 | 47 | ``` 48 | fromJust = \case 49 | #Just x -> #Success x 50 | #Nothing () -> #Error (#NotJust ()) 51 | ``` 52 | which has type: 53 | ``` 54 | forall a err. #NotJust ∉ err => 55 | +{ #Just : a, #Nothing : () } -> +{ #Success : a, #Error : +{ #NotJust : () | err } } 56 | ``` 57 | 58 | If we bind the two actions together (using an ordinary exception monad): 59 | 60 | `\v -> fromJust v >>= div 1000` 61 | 62 | The resulting type is automatically inferred to be: 63 | ``` 64 | forall err. {#NotJust, #DivByZero} ∉ err => 65 | +{ #Nothing : (), #Just : Int } -> 66 | +{ #Error : +{#NotJust : (), #DivByZero : () | err} 67 | , #Success : Int 68 | } 69 | ``` 70 | 71 | The automatically inferred set of errors is visibly available in the type. 72 | 73 | ## Catching errors 74 | 75 | We can declare an error catcher: 76 | 77 | ``` 78 | catch handler = 79 | \case 80 | #Success v -> #Success v 81 | #Error -> handler 82 | ``` 83 | 84 | And then we can catch just the `#DivByZero` error above: 85 | 86 | ``` 87 | \v -> 88 | catch (\case 89 | #DivByZero -> #Success 0 90 | err -> #Error err) 91 | (fromJust v >>= div 1000) 92 | ``` 93 | 94 | The inferred type then becomes: 95 | ``` 96 | forall err. {#NotJust} ∉ err => 97 | +{ #Nothing : (), #Just : Int } -> 98 | +{ #Error : +{#NotJust : () | err} 99 | , #Success : Int 100 | } 101 | ``` 102 | 103 | Once all errors are caught, the result type becomes: 104 | 105 | `forall err. +{ #Error : +{| err}, #Success : ...
}` 106 | 107 | We can then eliminate the variant wrapper with: 108 | 109 | ``` 110 | successful : forall err a. +{ Error : +{| err}, #Success : a } -> a 111 | successful = \case 112 | #Error -> absurd 113 | #Success -> id 114 | ``` 115 | 116 | This instantiates `err` to be the empty variant type (Void), and is thus 117 | allowed to eliminate the #Error/#Success wrapper to yield an `a`. 118 | -------------------------------------------------------------------------------- /lamdu-calculus.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 1.12 2 | 3 | -- This file has been generated from package.yaml by hpack version 0.35.2. 4 | -- 5 | -- see: https://github.com/sol/hpack 6 | 7 | name: lamdu-calculus 8 | version: 0.2.0.1 9 | synopsis: The Lamdu Calculus programming language 10 | description: Please see README.md 11 | category: Language 12 | homepage: https://github.com/lamdu/lamdu-calculus#readme 13 | bug-reports: https://github.com/lamdu/lamdu-calculus/issues 14 | author: Yair Chuchem, Eyal Lotem 15 | maintainer: yairchu@gmail.com 16 | copyright: 2021 Yair Chuchem, Eyal Lotem 17 | license: BSD3 18 | license-file: LICENSE 19 | build-type: Simple 20 | extra-source-files: 21 | README.md 22 | 23 | source-repository head 24 | type: git 25 | location: https://github.com/lamdu/lamdu-calculus 26 | 27 | library 28 | exposed-modules: 29 | Lamdu.Calc.Definition 30 | Lamdu.Calc.Identifier 31 | Lamdu.Calc.Infer 32 | Lamdu.Calc.Lens 33 | Lamdu.Calc.Term 34 | Lamdu.Calc.Term.Eq 35 | Lamdu.Calc.Type 36 | other-modules: 37 | Lamdu.Calc.Internal.Prelude 38 | hs-source-dirs: 39 | src 40 | ghc-options: -fexpose-all-unfoldings -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints -Wunused-packages -fdicts-cheap -O2 -fspecialise-aggressively 41 | ghc-prof-options: -O2 42 | build-depends: 43 | base >=4.7 44 | , base-compat >=0.8.2 45 | , base16-bytestring 46 | , 
binary 47 | , bytestring 48 | , containers 49 | , deepseq 50 | , generic-constraints 51 | , generic-data 52 | , hashable 53 | , hypertypes >=0.2 54 | , lattices 55 | , lens >=4.1 56 | , monad-st 57 | , mtl 58 | , pretty >=1.1.2 59 | , transformers 60 | default-language: Haskell2010 61 | 62 | test-suite lamdu-calculus-test 63 | type: exitcode-stdio-1.0 64 | main-is: test.hs 65 | other-modules: 66 | TestVals 67 | Paths_lamdu_calculus 68 | hs-source-dirs: 69 | test 70 | ghc-options: -fexpose-all-unfoldings -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints -Wunused-packages -fdicts-cheap -O2 -fspecialise-aggressively 71 | ghc-prof-options: -O2 72 | build-depends: 73 | HUnit 74 | , base >=4.7 75 | , base-compat >=0.8.2 76 | , bytestring 77 | , containers 78 | , hypertypes >=0.2 79 | , lamdu-calculus 80 | , lens >=4.1 81 | , test-framework 82 | , test-framework-hunit 83 | default-language: Haskell2010 84 | 85 | benchmark lamdu-calculus-bench 86 | type: exitcode-stdio-1.0 87 | main-is: benchmark.hs 88 | other-modules: 89 | TestVals 90 | Paths_lamdu_calculus 91 | hs-source-dirs: 92 | test 93 | ghc-options: -fexpose-all-unfoldings -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints -Wunused-packages -fdicts-cheap -O2 -fspecialise-aggressively -O2 -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints 94 | ghc-prof-options: -O2 95 | build-depends: 96 | base >=4.7 97 | , base-compat >=0.8.2 98 | , bytestring 99 | , containers 100 | , criterion 101 | , deepseq 102 | , hypertypes >=0.2 103 | , lamdu-calculus 104 | , lens >=4.1 105 | , monad-st 106 | , mtl 107 | , transformers 108 | default-language: Haskell2010 109 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # lamdu-calculus 2 | 3 | An extended typed Lambda Calculus AST. 4 | 5 | Used by [Lamdu](http://www.lamdu.org/) as the underlying unsugared language. 6 | 7 | The Lamdu Calculus language is a statically typed polymorphic 8 | non-dependent lambda calculus extension. 9 | 10 | The language does not include a parser, as it is meant to be used 11 | structurally via its AST directly. It does include a pretty-printer, 12 | for ease of debugging. 13 | 14 | Type inference for Lamdu Calculus is implemented in a [separate 15 | package](https://github.com/lamdu/Algorithm-W-Step-By-Step). 16 | 17 | ## Values 18 | 19 | The term language supports lambda abstractions and variables, literal 20 | values of custom types, holes, nominal type wrapping/unwrapping, 21 | extensible variant types and extensible records. 22 | 23 | ## Types 24 | 25 | The type language includes nominal types with polymorphic components, 26 | extensible record types (row types) and extensible variant types 27 | (polymorphic variants / column types). 28 | 29 | ## Composite types 30 | 31 | In Lamdu Calculus, both records and variants are "composite" types and 32 | share an underlying AST, using a phantom type to distinguish them. 33 | 34 | Records are sets of typed and named fields. 35 | 36 | A variant value is one of a set of possible typed data constructors. 37 | 38 | Both records and variants are order-agnostic, and disallow repeated 39 | appearance of the same names for fields or data constructors. 40 | 41 | Both records and variants are extensible. This is usually called "row 42 | types" for records and "polymorphic variants" for variants. We use the 43 | terminology of "row" and "column" types instead. 44 | 45 | ## Records 46 | 47 | In the term language, records are constructed using these AST constructions: 48 | 49 | `BLeaf LRecEmpty` denotes the empty record (denoted in pseudo-syntax as `()`).
50 | Its type is `TRecord CEmpty`. 51 | 52 | `BRecExtend (RowExtend tag (value : T) (rest : R))`1 denotes a record extension 53 | of an existing record `rest`. 54 | 55 | Its type is `TRecord (CExtend tag T R)`. 56 | 57 | A sequence of `BRecExtend` followed by a `BLeaf LRecEmpty` is denoted 58 | in pseudo-syntax as `{ x : T, y : U, ... }`. 59 | 60 | To deconstruct a record, we use the `GetField` operation (denoted in 61 | pseudo-syntax as `.`). 62 | 63 | 1. The `(value : Type)` notation is used to denote values' types 64 | 65 | ### Row types 66 | 67 | Row types are denoted using a record type variable. 68 | 69 | For example, let's examine a lambda abstraction. Assume `(+) : Int → 70 | Int → Int`. 71 | 72 | ``` 73 | vector → vector.x + vector.y 74 | ``` 75 | 76 | We can infer the type: 77 | 78 | ```Haskell 79 | { x : Int, y : Int } → Int 80 | ``` 81 | 82 | But the most general type is: 83 | 84 | ```Haskell 85 | forall r1. {x, y} ∉ r1 => { x : Int, y : Int | r1 } → Int 86 | ``` 87 | 2 88 | 89 | The lambda may be applied to a record that contains more fields 90 | besides `x` and `y`, and that is what the `| r1` denotes. Note that to 91 | enforce the no-duplication requirement Lamdu Calculus uses a 92 | constraint on composite type variables, for each field that they may 93 | not duplicate. 94 | 95 | 2. The `|` symbol in the pseudo-syntax is used to indicate 96 | that some set of fields denoted by the variable that follows (in this 97 | case, `r1`) is concatenated to the set. 98 | 99 | ## Variants 100 | 101 | In the term language, variants are constructed using value *injection*. 102 | 103 | For example, let's look at the type: 104 | 105 | ```Haskell 106 | +{ Nothing : (), Just : Int } 107 | ``` 108 | 109 | This pseudo syntax means the structural variant type isomorphic to 110 | Haskell's `Maybe Int`. 111 | 112 | The value `5 : Int` is not of the type: `+{ Nothing : (), Just : Int }`. 113 | 114 | In Haskell, we use `Just` to "lift" 5 from `Int` to `Maybe Int`.
In 115 | Lamdu Calculus, we inject via `BInject (Inject "Just" 5)`, which we denote in 116 | pseudo-syntax as `Just: 5`. 117 | 118 | `BInject (Inject tag (value : T))` *injects* a given value into a 119 | variant type. Its type is `forall alts. tag ∉ alts => TVariant (CExtend tag T alts)`. 120 | 121 | This type means that the injected value allows any set of typed 122 | alternatives to exist in the larger variant. 123 | 124 | ### Case expressions 125 | 126 | Deconstructing a variant requires a case statement. Lamdu Calculus 127 | case statements *peel* one alternative at a time, so they need to be 128 | composed to create an ordinary full case. 129 | 130 | `BCase (Case tag handler rest)` is denoted in pseudo-syntax as 131 | ```Haskell 132 | \case 133 | tag: handler 134 | rest 135 | ``` 136 | 137 | The above case expression creates a function with a variant 138 | parameter. It analyzes its argument, and if it is a `tag`, the given 139 | `handler` is invoked with the typed content of the `tag`. 140 | 141 | If the argument is not a `tag`, then the `rest` handler is 142 | invoked. The `rest` handler is given a smaller variant type as an 143 | argument: a variant type which no longer has the `tag` case inside it, as 144 | that was ruled out. 145 | 146 | ```Haskell 147 | \case 148 | Just: x → x + 1 149 | \case 150 | Nothing: () → 0 151 | ? 152 | ``` 153 | 154 | The above composed case expression will match against `Just`: 155 | 156 | * Match: it will evaluate to the content of the `Just` added to 1. 157 | 158 | * Mismatch: it will match against `Nothing`: 159 | 160 | * Match: it will ignore the empty record contained in the 161 | `Nothing` case, and evaluate to 0. 162 | 163 | * Mismatch: it will evaluate to a hole (denoted by `?`) which is like 164 | Haskell's `undefined`, and takes on any type. 165 | 166 | The type of the above case statement would be inferred as: 167 | 168 | ```Haskell 169 | forall v1.
{Just, Nothing} ∉ v1 => 170 | +{ Just : Int, Nothing : () | v1 } → Int 171 | ``` 172 | 173 | Note that the case statement allows *any* structural variant type that 174 | has the proper `Nothing` and `Just` cases (and it would reach a hole if the 175 | value happens to be neither `Nothing` nor `Just`). This is not 176 | typically what we want. We'd like to *close* the variant type so it is *not* extensible. 177 | 178 | To do that, we must also support the empty case statement that matches 179 | no possible alternatives, allowing us to *close* the composition. The 180 | empty case statement AST construction is: 181 | 182 | `BLeaf LAbsurd`, and it is denoted as `absurd` (as it is the analogue of 183 | the `absurd` function in Haskell and Agda). The type of `absurd` is: 184 | `forall r. +{} → r` (`+{}` is the empty variant type, aka `Void`). 185 | 186 | We can now close the above case expression: 187 | 188 | ```Haskell 189 | \case 190 | Just: x → x + 1 191 | \case 192 | Nothing: () → 0 193 | absurd 194 | ``` 195 | 196 | And now our inferred type will be simpler: 197 | 198 | ```Haskell 199 | +{ Just : Int, Nothing : () } → Int 200 | ``` 201 | 202 | As a shorthand for nested/composed cases, we'll use Haskell-like 203 | syntax for cases as well, meaning the expanded, composed case. 204 | 205 | ```Haskell 206 | \case 207 | Just: x → x + 1 208 | Nothing: () → 0 209 | ``` 210 | 211 | ### Example use-case: Composing interpreters 212 | 213 | We can define single-command interpreters: 214 | 215 | ```Haskell 216 | interpretAdd default = 217 | \case 218 | Add: {x, y} → x + y 219 | default 220 | ``` 221 | 222 | ```Haskell 223 | interpretMul default = 224 | \case 225 | Mul: {x, y} → x * y 226 | default 227 | ``` 228 | 229 | And then we can compose them: 230 | 231 | ```Haskell 232 | interpreterArithmetic = interpretAdd . interpretMul 233 | 234 | interpreter = (interpreterArithmetic .
interpreterConditionals) absurd 235 | ``` 236 | 237 | This yields a single interpreter that can handle the full set of cases in 238 | its composition. 239 | 240 | ### Example use-case: Exception Monads 241 | 242 | See the [Exception Monad](doc/ExceptionMonad.md) use-case for a more 243 | compelling example use-case of extensible variants. 244 | 245 | ## Nominal types 246 | 247 | ### Purposes 248 | 249 | #### Recursive Types 250 | 251 | Infinite structural types are not allowed, so you can use a nominal type to “tie the knot”. 252 | For example: 253 | 254 | ```Haskell 255 | Stream a = +{ Empty : (), NonEmpty : { head: a, tail: () -> Stream a } } 256 | ``` 257 | 258 | A `Stream a` may contain a `Stream a` within it; therefore the type is recursive. 259 | 260 | #### Safety 261 | 262 | Often one wants to distinguish types whose representations/structure are the same, but whose meanings are very different. 263 | 264 | #### Rank-N types 265 | 266 | Nominal types may bind type variables (i.e. the structural type within a nominal type is wrapped in “forall”s). 267 | 268 | This is used by the `Mut` type (equivalent to Haskell's `ST`). 269 | 270 | # Future plans 271 | 272 | * Unifying with the underlying AST in 273 | [AlgoWMutable](https://github.com/Peaker/AlgoWMutable) which 274 | implements type inference more efficiently than 275 | [Algorithm-W-Step-By-Step](https://github.com/lamdu/Algorithm-W-Step-By-Step) 276 | which is compatible with this package.
277 | 278 | * Support for type-classes 279 | -------------------------------------------------------------------------------- /src/Lamdu/Calc/Infer.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE NoImplicitPrelude, TemplateHaskell, GeneralizedNewtypeDeriving #-} 2 | {-# LANGUAGE UndecidableInstances, MultiParamTypeClasses, TypeFamilies #-} 3 | {-# LANGUAGE FlexibleInstances, LambdaCase, DataKinds, FlexibleContexts #-} 4 | {-# LANGUAGE TypeApplications, RankNTypes, DerivingStrategies, TypeOperators #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | 7 | module Lamdu.Calc.Infer 8 | ( InferState(..), isBinding, isQVarGen 9 | , emptyPureInferState 10 | , PureInfer(..), _PureInfer, runPureInfer 11 | , STInfer(..), _STInfer 12 | , loadDeps 13 | , varGen 14 | , alphaEq 15 | ) where 16 | 17 | import Hyper 18 | import Hyper.Infer 19 | import Hyper.Syntax.Nominal 20 | import qualified Hyper.Syntax.Var as TermVar 21 | import qualified Hyper.Syntax.Scheme as S 22 | import qualified Hyper.Syntax.Scheme.AlphaEq as S 23 | import Hyper.Unify 24 | import Hyper.Unify.Binding 25 | import Hyper.Unify.Binding.ST 26 | import Hyper.Unify.Generalize 27 | import Hyper.Unify.QuantifiedVar 28 | import Hyper.Unify.Term (UTerm, UTermBody) 29 | import Control.Applicative (Alternative(..)) 30 | import qualified Control.Lens as Lens 31 | import Control.Lens (LensLike') 32 | import Control.Monad.Except 33 | import Control.Monad.Reader.Class 34 | import Control.Monad.ST 35 | import Control.Monad.ST.Class (MonadST(..)) 36 | import Control.Monad.State 37 | import Control.Monad.Trans.Maybe 38 | import Control.Monad.Trans.RWS (RWST(..)) 39 | import Control.Monad.Trans.Reader (ReaderT(..)) 40 | import Control.Monad.Trans.Writer (WriterT) 41 | import Data.STRef 42 | import Lamdu.Calc.Definition (Deps, depsNominals, depsGlobalTypes) 43 | import Lamdu.Calc.Term 44 | import qualified Lamdu.Calc.Type as T 45 | 46 | import Lamdu.Calc.Internal.Prelude 47 | 48 
-- | Counters used to generate fresh quantified-variable names: one for
-- type variables and one for row variables. 'nextTVNamePure' and
-- 'nextTVNameST' render a counter value @n@ as the prefix character
-- followed by @n@ (e.g. @t0@, @r0@).
data QVarGen = QVarGen
    { _nextTV :: !Int -- ^ next type-variable number
    , _nextRV :: !Int -- ^ next row-variable number
    } deriving (Eq, Ord, Show)
Lens.makeLenses ''QVarGen

-- | A name generator whose counters both start at 0.
varGen :: QVarGen
varGen = QVarGen 0 0

-- | State threaded through 'PureInfer': the unification bindings for
-- types and rows, and the fresh-name generator.
data InferState = InferState
    { _isBinding :: T.Types # Binding
    , _isQVarGen :: QVarGen
    } deriving (Eq, Ord, Show)
Lens.makeLenses ''InferState

-- | Pure inference monad: reads an environment @env@, threads an
-- 'InferState', and may fail with a 'T.TypeError' over 'UVar's.
-- The 'RWST' writer output is unused (@()@).
newtype PureInfer env a =
    PureInfer
    (RWST env () InferState (Either (T.TypeError # UVar)) a)
    deriving newtype
    ( Functor, Applicative, Monad
    , MonadReader env
    , MonadError (T.TypeError # UVar)
    , MonadState InferState
    )
Lens.makePrisms ''PureInfer

-- | Run a 'PureInfer' computation with the given environment and initial
-- state, returning the result paired with the final state, or the type
-- error that aborted it. The (unit) writer output is dropped.
runPureInfer ::
    env -> InferState -> PureInfer env a ->
    Either (T.TypeError # UVar) (a, InferState)
runPureInfer env st (PureInfer act) =
    runRWST act env st <&> \(x, s, ~()) -> (x, s)

-- Unification variables in 'PureInfer' are the pure 'UVar's.
type instance UVarOf (PureInfer _) = UVar

-- | Load every nominal declaration and every global type scheme from the
-- given 'Deps' (in monad @m@), and return a function that adds them to a
-- 'Scope'.
--
-- Both additions use '<>~', i.e. @existing <> loaded@.
-- NOTE(review): if these scope fields are 'Map's, that union is
-- left-biased, so entries already in the scope win over loaded ones on a
-- key collision — confirm this is intended.
loadDeps ::
    (UnifyGen m T.Row, S.HasScheme T.Types m T.Type) =>
    Deps -> m (Scope # UVarOf m -> Scope # UVarOf m)
loadDeps deps =
    do
        loadedNoms <- deps ^. depsNominals & traverse loadNominalDecl
        loadedSchemes <- deps ^. depsGlobalTypes & traverse S.loadScheme
        pure $ \env ->
            env
            & scopeVarTypes <>~ (loadedSchemes <&> MkHFlip)
            & scopeNominals <>~ loadedNoms

-- Entering a nested level increments the scope level held in the reader
-- environment for the duration of the local computation.
instance MonadScopeLevel (PureInfer (Scope # UVar)) where
    localLevel = local (scopeLevel . _ScopeLevel +~ 1)

-- Nominals are looked up in the scope; a missing one is reported as
-- 'T.NominalNotFound'.
instance MonadNominals T.NominalId T.Type (PureInfer (Scope # UVar)) where
    {-# INLINE getNominalDecl #-}
    getNominalDecl n =
        Lens.view (scopeNominals . Lens.at n)
        >>= maybe (throwError (T.NominalNotFound n)) pure

-- The whole reader environment is the scope.
instance TermVar.HasScope (PureInfer (Scope # UVar)) Scope where
    {-# INLINE getScope #-}
    getScope = Lens.view id

-- Bind a term variable to a monomorphic type ('GMono') within the local
-- computation.
instance LocalScopeType Var (UVar # T.Type) (PureInfer (Scope # UVar)) where
    {-# INLINE localScopeType #-}
    localScopeType k v = local (scopeVarTypes . Lens.at k ?~ MkHFlip (GMono v))

-- A new type variable's constraints are the current scope level.
instance UnifyGen (PureInfer (Scope # UVar)) T.Type where
    {-# INLINE scopeConstraints #-}
    scopeConstraints _ = Lens.view scopeLevel

-- | Produce a fresh name in a pure state monad: modify the counter
-- focused by @lens@ (inside 'isQVarGen') by adding 1, and return @prefix@
-- followed by the counter's /previous/ value ('<<%=' returns the old
-- value), so successive calls yield e.g. @t0@, @t1@, ...
nextTVNamePure ::
    (MonadState InferState m, IsString a) =>
    Char -> LensLike' ((,) Int) QVarGen Int -> m a
nextTVNamePure prefix lens =
    isQVarGen . lens <<%= (+1) <&> show <&> (prefix :) <&> fromString

-- Fresh type variables are named @t0@, @t1@, ...
instance MonadQuantify ScopeLevel T.TypeVar (PureInfer env) where
    newQuantifiedVariable _ = nextTVNamePure 't' nextTV

-- A new row variable's constraints wrap the type-level constraints (the
-- scope level) with 'mempty' as the other component (presumably the set
-- of forbidden fields — NOTE(review): confirm against 'T.RowConstraints').
instance UnifyGen (PureInfer (Scope # UVar)) T.Row where
    {-# INLINE scopeConstraints #-}
    scopeConstraints _ = scopeConstraints (Proxy @T.Type) <&> T.RowConstraints mempty

-- Fresh row variables are named @r0@, @r1@, ...
instance MonadQuantify T.RConstraints T.RowVar (PureInfer env) where
    newQuantifiedVariable _ = nextTVNamePure 'r' nextRV

-- Type bindings live in the 'T.tType' part of 'isBinding'; unification
-- failures are thrown as 'T.TypeError'.
instance Unify (PureInfer env) T.Type where
    {-# INLINE binding #-}
    binding = bindingDict (isBinding . T.tType)
    unifyError = throwError . T.TypeError

-- Row bindings live in the 'T.tRow' part of 'isBinding'; unification
-- failures are thrown as 'T.RowError'.
instance Unify (PureInfer env) T.Row where
    {-# INLINE binding #-}
    binding = bindingDict (isBinding . T.tRow)
    unifyError = throwError . T.RowError
    {-# INLINE structureMismatch #-}
    structureMismatch = T.rStructureMismatch

-- | Empty unification bindings for both types and rows.
-- NOTE(review): despite its name this is a @T.Types # Binding@ (the
-- '_isBinding' component), not a full 'InferState'.
emptyPureInferState :: T.Types # Binding
emptyPureInferState = T.Types emptyBinding emptyBinding

-- | Inference monad over 'ST' (used by 'alphaEq' below). It reads the scope
-- and an 'STRef' holding the name generator. Every failure (unification
-- errors, missing nominals) is a bare 'MaybeT' failure carrying no error
-- value.
newtype STInfer s a = STInfer
    (ReaderT (Scope # STUVar s, STRef s QVarGen) (MaybeT (ST s)) a)
    deriving newtype
    ( Functor, Alternative, Applicative, Monad, MonadST
    , MonadReader (Scope # STUVar s, STRef s QVarGen)
    )
Lens.makePrisms ''STInfer

-- Unification variables in 'STInfer' are mutable 'STUVar's.
type instance UVarOf (STInfer s) = STUVar s

-- Entering a nested level increments the scope level (first component of
-- the reader environment).
instance MonadScopeLevel (STInfer s) where
    localLevel = local (Lens._1 . scopeLevel . _ScopeLevel +~ 1)

-- A missing nominal makes the computation fail ('empty').
instance MonadNominals T.NominalId T.Type (STInfer s) where
    {-# INLINE getNominalDecl #-}
    getNominalDecl n =
        Lens.view (Lens._1 . scopeNominals . Lens.at n) >>= maybe empty pure

-- The scope is the first component of the reader environment.
instance TermVar.HasScope (STInfer s) Scope where
    {-# INLINE getScope #-}
    getScope = Lens.view Lens._1

-- Bind a term variable to a monomorphic type within the local computation.
instance LocalScopeType Var (STUVar s # T.Type) (STInfer s) where
    {-# INLINE localScopeType #-}
    localScopeType k v =
        local (Lens._1 . scopeVarTypes . Lens.at k ?~ MkHFlip (GMono v))

-- A new type variable's constraints are the current scope level.
instance UnifyGen (STInfer s) T.Type where
    {-# INLINE scopeConstraints #-}
    scopeConstraints _ = Lens.view (Lens._1 . scopeLevel)

-- | 'ST' counterpart of 'nextTVNamePure': read the generator from the
-- 'STRef' (second reader component), return @prefix@ followed by the
-- current value of the counter at @lens@, and write back the generator
-- with that counter incremented.
nextTVNameST :: IsString a => Char -> Lens.ALens' QVarGen Int -> STInfer s a
nextTVNameST prefix lens =
    do
        genRef <- Lens.view Lens._2
        gen <- readSTRef genRef & liftST
        let res = prefix : show (gen ^# lens) & fromString
        let newGen = gen & lens #%~ (+1)
        res <$ writeSTRef genRef newGen & liftST

-- Fresh type variables are named @t0@, @t1@, ...
instance MonadQuantify ScopeLevel T.TypeVar (STInfer s) where
    newQuantifiedVariable _ = nextTVNameST 't' nextTV

-- Row constraints: see the matching 'PureInfer' instance above.
instance UnifyGen (STInfer s) T.Row where
    {-# INLINE scopeConstraints #-}
    scopeConstraints _ = scopeConstraints (Proxy @T.Type) <&> T.RowConstraints mempty

-- Fresh row variables are named @r0@, @r1@, ...
instance MonadQuantify T.RConstraints T.RowVar (STInfer s) where
    newQuantifiedVariable _ = nextTVNameST 'r' nextRV

-- Bindings are mutable ST references; a unification error discards its
-- details and fails the computation.
instance Unify (STInfer s) T.Type where
    {-# INLINE binding #-}
    binding = stBinding
    unifyError _ = empty

instance Unify (STInfer s) T.Row where
    {-# INLINE binding #-}
    binding = stBinding
    unifyError _ = empty
    {-# INLINE structureMismatch #-}
    structureMismatch = T.rStructureMismatch

-- | Whether two type schemes are alpha-equivalent, as decided by
-- 'S.alphaEq' run in 'STInfer' with an empty scope and a fresh 'varGen'.
-- Any failure of that computation yields 'False'.
alphaEq ::
    Pure # S.Scheme T.Types T.Type ->
    Pure # S.Scheme T.Types T.Type ->
    Bool
alphaEq x y =
    runST $
    do
        vg <- newSTRef varGen
        S.alphaEq x y
            ^. _STInfer
            & (`runReaderT` (emptyScope, vg))
            & runMaybeT
            <&> Lens.has Lens._Just

-- Ask GHC to specialise these polymorphic hypertypes unification
-- functions to 'STInfer' (a performance hint; semantics are unchanged).
{-# SPECIALIZE unify :: STUVar s # T.Row -> STUVar s # T.Row -> STInfer s (STUVar s # T.Row) #-}
{-# SPECIALIZE updateConstraints :: ScopeLevel -> STUVar s # T.Type -> UTerm (STUVar s) # T.Type -> STInfer s () #-}
{-# SPECIALIZE updateConstraints :: T.RConstraints -> STUVar s # T.Row -> UTerm (STUVar s) # T.Row -> STInfer s () #-}
{-# SPECIALIZE updateTermConstraints :: STUVar s # T.Row -> UTermBody (STUVar s) # T.Row -> T.RConstraints -> STInfer s () #-}
{-# SPECIALIZE instantiateH :: (forall n. TypeConstraintsOf n -> UTerm (STUVar s) # n) -> GTerm (STUVar s) # T.Row -> WriterT [STInfer s ()] (STInfer s) (STUVar s # T.Row) #-}
{-# SPECIALIZE semiPruneLookup :: STUVar s # T.Type -> STInfer s (STUVar s # T.Type, UTerm (STUVar s) # T.Type) #-}
{-# SPECIALIZE semiPruneLookup :: STUVar s # T.Row -> STInfer s (STUVar s # T.Row, UTerm (STUVar s) # T.Row) #-}
{-# SPECIALIZE unifyUTerms :: STUVar s # T.Row -> UTerm (STUVar s) # T.Row -> STUVar s # T.Row -> UTerm (STUVar s) # T.Row -> STInfer s (STUVar s # T.Row) #-}

-------------------------------------------------------------------------------- /test/TestVals.hs: --------------------------------------------------------------------------------
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
{-# LANGUAGE NoImplicitPrelude, OverloadedStrings, TypeOperators #-}

module TestVals
    ( allDeps
    , factorialVal, euler1Val, solveDepressedQuarticVal
    , factorsVal
    , recordType
    , intType
    , listTypePair, boolTypePair
    ) where

import Hyper
import Hyper.Syntax.Nominal
import Hyper.Syntax.Row
import Hyper.Syntax.Scheme
import Hyper.Type.Prune
import qualified Control.Lens as Lens
import Control.Lens.Operators
import qualified Data.ByteString.Char8 as BS8
import qualified Data.Map as Map
import Lamdu.Calc.Definition (Deps(..))
import Lamdu.Calc.Term
import Lamdu.Calc.Type (Type, (~>))
import qualified Lamdu.Calc.Type as T

import Prelude.Compat

{-# ANN module ("HLint: ignore Redundant $" :: String) #-}

-- TODO: $$ to be type-classed for TApp vs BApp
-- TODO: TCon "->" instead of TFun

-- | A record type whose fields are prepended, in list order, to the row
-- @recTail@.
recExtends :: Pure # T.Row -> [(T.Tag, Pure # Type)] -> Pure # Type
recExtends recTail fields =
    foldr
    (\(tag, typ) -> Pure . T.RExtend . RowExtend tag typ) recTail fields
    & T.TRecord & Pure

-- | A closed record type: its row ends in 'T.REmpty'.
recordType :: [(T.Tag, Pure # Type)] -> Pure # Type
recordType = recExtends (_Pure # T.REmpty)

-- | A type scheme quantifying over the given type variables (each with
-- 'mempty' constraints) and over no row variables. @mkType@ receives the
-- variables as types, in the same order as @tvs@.
forAll ::
    [T.TypeVar] -> ([Pure # Type] -> Pure # Type) -> Pure # T.Scheme
forAll tvs mkType =
    tvs <&> T.TVar <&> (_Pure #) & mkType
    & Scheme T.Types
    { T._tType = QVars (Map.fromList [(tv, mempty) | tv <- tvs])
    , T._tRow = mempty
    } & Pure

-- | The nominal @ST@ instantiated with @res := a@ and @s := s@.
-- NOTE(review): no @ST@ declaration is listed in 'allDeps'.
stOf :: Pure # Type -> Pure # Type -> Pure # Type
stOf s a =
    T.Types (QVarInstances (mempty & Lens.at "res" ?~ a & Lens.at "s" ?~ s)) (QVarInstances mempty)
    & NominalInst "ST" & T.TInst & Pure

-- | The @List@ nominal: a single type parameter @elem@, whose body is a
-- closed variant with a @[]@ alternative (empty record) and a @:@
-- alternative (record of @head@ and @tail@).
--
-- NOTE(review): the declared parameter is @elem@ (see '_nParams' and
-- 'listOf'), but the body refers to the type variable @a@ ('tv'), which
-- is not bound by the empty scheme quantifiers — confirm this mismatch is
-- intentional.
listTypePair :: (T.NominalId, Pure # NominalDecl Type)
listTypePair =
    ( "List"
    , _Pure # NominalDecl
        { _nParams =
            T.Types
            { T._tType = mempty & Lens.at "elem" ?~ mempty
            , T._tRow = mempty
            }
        , _nScheme =
            _Pure # T.REmpty
            & RowExtend "[]" (recordType []) & T.RExtend & Pure
            & RowExtend ":" (recordType [("head", tv), ("tail", listOf tv)])
            & T.RExtend & Pure
            & T.TVariant & Pure
            & Scheme (T.Types mempty mempty)
        }
    )
    where
        tv = _Pure # T.TVar "a"

-- | The @List@ nominal instantiated with @elem := x@.
listOf :: Pure # Type -> Pure # Type
listOf x =
    T.Types (QVarInstances (mempty & Lens.at "elem" ?~ x)) (QVarInstances mempty)
    & NominalInst (fst listTypePair) & T.TInst & Pure

-- | The @Bool@ nominal (no type arguments).
boolType :: Pure # Type
boolType =
    T.Types (QVarInstances mempty) (QVarInstances mempty)
    & NominalInst (fst boolTypePair) & T.TInst & Pure

-- | The @Int@ nominal (no type arguments).
-- NOTE(review): no @Int@ declaration is listed in 'allDeps'.
intType :: Pure # Type
intType =
    T.Types (QVarInstances mempty) (QVarInstances mempty)
    & NominalInst "Int" & T.TInst & Pure

-- | The @Bool@ nominal: no parameters; its body is a closed variant of the
-- empty-record alternatives @True@ and @False@.
boolTypePair :: (T.NominalId, Pure # NominalDecl Type)
boolTypePair =
    ( "Bool"
    , _Pure # NominalDecl
        { _nParams = T.Types mempty mempty
        , _nScheme =
            _Pure # T.REmpty
            & RowExtend "True" (recordType []) & T.RExtend & Pure
            & RowExtend "False" (recordType []) & T.RExtend & Pure
            & T.TVariant & Pure
            & Scheme (T.Types mempty mempty)
        }
    )

-- | A structural (non-nominal) closed variant with alternatives @Just@
-- (carrying @t@) and @Nothing@ (carrying an empty record).
maybeOf :: Pure # Type -> Pure # Type
maybeOf t =
    _Pure # T.REmpty
    & RowExtend "Just" t & T.RExtend & Pure
    & RowExtend "Nothing" (recordType []) & T.RExtend & Pure
    & T.TVariant & Pure

-- | Type of an infix operator: a function from the record @{ l : a, r : b }@
-- to @c@.
infixType :: Pure # Type -> Pure # Type -> Pure # Type -> Pure # Type
infixType a b c = recordType [("l", a), ("r", b)] ~> c

-- | Nominal declarations and global type schemes used by the test terms.
-- The lambdas pattern-match on lists of exactly the length of the
-- quantifier list, hence the incomplete-uni-patterns warning is disabled
-- for this module.
allDeps :: Deps
allDeps =
    Deps
    { _depsNominals = Map.fromList [boolTypePair, listTypePair]
    , _depsGlobalTypes =
        Map.fromList
        [ ("fix", forAll ["a"] $ \ [a] -> (a ~> a) ~> a)
        , ("if", forAll ["a"] $ \ [a] -> recordType [("condition", boolType), ("then", a), ("else", a)] ~> a)
        , ("==", forAll ["a"] $ \ [a] -> infixType a a boolType)
        , (">", forAll ["a"] $ \ [a] -> infixType a a boolType)
        , ("%", forAll ["a"] $ \ [a] -> infixType a a a)
        , ("*", forAll ["a"] $ \ [a] -> infixType a a a)
        , ("-", forAll ["a"] $ \ [a] -> infixType a a a)
        , ("+", forAll ["a"] $ \ [a] -> infixType a a a)
        , ("/", forAll ["a"] $ \ [a] -> infixType a a a)
        , ("//", forAll [] $ \ [] -> infixType intType intType intType)
        , ("sum", forAll ["a"] $ \ [a] -> listOf a ~> a)
        , ("filter", forAll ["a"] $ \ [a] -> recordType [("from", listOf a), ("predicate", a ~> boolType)] ~> listOf a)
        , (":", forAll ["a"] $ \ [a] -> recordType [("head", a), ("tail", listOf a)] ~> listOf a)
        , ("[]", forAll ["a"] $ \ [a] -> listOf a)
        , ("concat", forAll ["a"] $ \ [a] -> listOf (listOf a) ~> listOf a)
        , ("map", forAll ["a", "b"] $ \ [a, b] -> recordType [("list", listOf a), ("mapping", a ~> b)] ~> listOf b)
        , ("..", forAll [] $ \ [] -> infixType intType intType (listOf intType))
        , ("||", forAll [] $ \ [] -> infixType boolType boolType boolType)
        , ("head", forAll ["a"] $ \ [a] -> listOf a ~> a)
        , ("negate", forAll ["a"] $ \ [a] -> a ~> a)
        , ("sqrt", forAll ["a"] $ \ [a] -> a ~> a)
        , ("id", forAll ["a"] $ \ [a] -> a ~> a)
        , ("zipWith",forAll ["a","b","c"] $ \ [a,b,c] ->
                        (a ~> b ~> c) ~> listOf a ~> listOf b ~> listOf c )
        , ("Just", forAll ["a"] $ \ [a] -> a ~> maybeOf a)
        , ("Nothing",forAll ["a"] $ \ [a] -> maybeOf a)
        , ("maybe", forAll ["a", "b"] $ \ [a, b] -> b ~> (a ~> b) ~> maybeOf a ~> b)
        , ("plus1", forAll [] $ \ [] -> intType ~> intType)
        , ("True", forAll [] $ \ [] -> boolType)
        , ("False", forAll [] $ \ [] -> boolType)

        , ("stBind", forAll ["s", "a", "b"] $ \ [s, a, b] -> infixType (stOf s a) (a ~> stOf s b) (stOf s b))
        ]
    }

-- | An @Int@ literal whose payload is the decimal text of the number.
litInt :: Integer -> HPlain Term
litInt = BLeafP . LLiteral . PrimVal "Int" . BS8.pack . show

-- | A record literal built from (tag, value) pairs, ending in the empty
-- record.
record :: [(T.Tag, HPlain Term)] -> HPlain Term
record = foldr (uncurry BRecExtendP) (BLeafP LRecEmpty)

-- | Apply a function to a record of named arguments.
($$:) :: HPlain Term -> [(T.Tag, HPlain Term)] -> HPlain Term
f $$: args = BAppP f (record args)

-- | Apply an infix operator @f@ to the operands @l@ and @r@.
inf :: HPlain Term -> HPlain Term -> HPlain Term -> HPlain Term
inf l f r = f $$: [("l", l), ("r", r)]

-- | Factorial via the global @fix@: if @x == 0@ then 1, otherwise
-- @x * loop (x - 1)@.
factorialVal :: HPlain Term
factorialVal =
    BAppP "fix" $
    BLamP "loop" Pruned $
    BLamP "x" Pruned $
    "if" $$:
    [ ("condition", inf "x" "==" (litInt 0))
    , ("then", litInt 1)
    , ("else", inf "x" "*" (BAppP "loop" $ inf "x" "-" (litInt 1)))
    ]

-- | The sum of the elements of @1 .. 1000@ for which
-- @0 == x % 3 || 0 == x % 5@.
euler1Val :: HPlain Term
euler1Val =
    BAppP "sum" $
    "filter" $$:
    [ ("from", inf (litInt 1) ".." (litInt 1000))
    , ("predicate",
        BLamP "x" Pruned $
        inf
        (inf (litInt 0) "==" (inf "x" "%" (litInt 3)))
        "||"
        (inf (litInt 0) "==" (inf "x" "%" (litInt 5)))
      )
    ]

-- | Encode @let k = v in r@ as the application of @\\k -> r@ to @v@.
let_ :: Var -> HPlain Term -> HPlain Term -> HPlain Term
let_ k v r = BAppP (BLamP k Pruned r) v

-- | A @List@ cons cell: inject a @{ head, tail }@ record into the @:@
-- alternative and wrap it in the @List@ nominal.
cons :: HPlain Term -> HPlain Term -> HPlain Term
cons h t = BLeafP (LInject ":") `BAppP` record [("head", h), ("tail", t)] & BToNomP "List"

-- | A @List@ literal: 'cons' folded over the elements, ending in the @[]@
-- alternative wrapped in the @List@ nominal.
list :: [HPlain Term] -> HPlain Term
list = foldr cons (BToNomP "List" (BLeafP (LInject "[]") `BAppP` BLeafP LRecEmpty))

-- | A function of a record @params@ with fields @c@, @d@ and @e@. Judging by
-- its name it solves a depressed quartic (NOTE(review): inferred from the
-- name only). Within the term, @solvePoly@ is bound to the global @id@ and
-- @sqrts x@ is the list @[sqrt x, negate (sqrt x)]@.
solveDepressedQuarticVal :: HPlain Term
solveDepressedQuarticVal =
    BLamP "params" Pruned $
    let_ "solvePoly" "id" $
    let_ "sqrts"
    ( BLamP "x" Pruned $
        let_ "r" (BAppP "sqrt" "x") $
        list ["r", BAppP "negate" "r"]
    ) $
    "if" $$:
    [ ("condition", inf d "==" (litInt 0))
    , ( "then",
        BAppP "concat" $
        "map" $$:
        [ ("list", BAppP "solvePoly" $ list [e, c, litInt 1])
        , ("mapping", "sqrts")
        ])
    , ( "else",
        BAppP "concat" $
        "map" $$:
        [ ( "list",
            BAppP "sqrts" $ BAppP "head" $ BAppP "solvePoly" $
            list
            [ BAppP "negate" (d %* d)
            , (c %* c) %- (litInt 4 %* e)
            , litInt 2 %* c
            , litInt 1
            ])
        , ( "mapping",
            BLamP "x" Pruned $
            BAppP "solvePoly" $
            list
            [ (c %+ ("x" %* "x")) %- (d %/ "x")
            , litInt 2 %* "x"
            , litInt 2
            ])
        ])
    ]
    where
        c = BLeafP (LGetField "c") `BAppP` "params"
        d = BLeafP (LGetField "d") `BAppP` "params"
        e = BLeafP (LGetField "e") `BAppP` "params"

-- | Binary-operator shorthands built with 'inf'. They have no fixity
-- declarations (so default to infixl 9); call sites parenthesize
-- explicitly.
(%+), (%-), (%*), (%/), (%//), (%>), (%%), (%==) :: HPlain Term -> HPlain Term -> HPlain Term
x %+ y = inf x "+" y
x %- y = inf x "-" y
x %* y = inf x "*" y
x %/ y = inf x "/" y
x %// y = inf x "//" y
x %> y = inf x ">" y
x %% y = inf x "%" y
x %== y = inf x "==" y

-- | A recursive (via @fix@) function of a record with fields @n@ and @min@:
-- if @min * min > n@ the result is @[n]@; if @min@ divides @n@, @min@ is
-- consed onto the recursive result for @n // min@; otherwise it recurses
-- with @min + 1@.
factorsVal :: HPlain Term
factorsVal =
    BAppP "fix" $
    BLamP "loop" Pruned $
    BLamP "params" Pruned $
    if_ ((m %* m) %> n) (list [n]) $
    if_ ((n %% m) %== litInt 0)
        (cons m $ "loop" $$: [("n", n %// m), ("min", m)]) $
    "loop" $$: [("n", n), ("min", m %+ litInt 1)]
    where
        n = BLeafP (LGetField "n") `BAppP` "params"
        m = BLeafP (LGetField "min") `BAppP` "params"
        -- Encodes if-then-else as a closed case over the alternatives of the
        -- unwrapped @Bool@ nominal.
        if_ b t f =
            BCaseP "False" (BLamP "_" Pruned f) (BCaseP "True" (BLamP "_" Pruned t) (BLeafP LAbsurd))
            `BAppP` (BLeafP (LFromNom "Bool") `BAppP` b)

-------------------------------------------------------------------------------- /src/Lamdu/Calc/Lens.hs: --------------------------------------------------------------------------------
{-# LANGUAGE NoImplicitPrelude, RankNTypes, NoMonomorphismRestriction #-}
{-# LANGUAGE FlexibleContexts, TypeFamilies, TypeApplications, FlexibleInstances #-}
{-# LANGUAGE DataKinds, ScopedTypeVariables, LambdaCase, TypeOperators #-}
module Lamdu.Calc.Lens
    ( -- Leafs
      valHole , valBodyHole
    , valVar
    , valLiteral , valBodyLiteral
    ,
valLeafs
    -- Non-leafs
    , valApply
    , valTags
    , valGlobals
    , valNominals
    -- Subexpressions:
    , subExprPayloads
    , payloadsOf
    , HasTIds(..), tIds, qVarIds
    , HasQVarIds(..)
    , HasQVar(..), qvarsQVarIds
    ) where

import Control.Lens (Traversal', Prism')
import qualified Control.Lens as Lens
import qualified Data.Map as Map
import qualified Data.Set as Set
import Hyper
import Hyper.Recurse
import Hyper.Syntax.Nominal (ToNom(..), NominalInst(..), NominalDecl, nScheme)
import Hyper.Syntax.Row (RowExtend(..))
import Hyper.Syntax.Scheme (Scheme, QVars, _QVars, QVarInstances, _QVarInstances, sTyp)
import Hyper.Type.Prune (Prune, _Unpruned)
import Lamdu.Calc.Identifier (Identifier)
import Lamdu.Calc.Term (Val)
import qualified Lamdu.Calc.Term as V
import qualified Lamdu.Calc.Type as T
import Hyper.Unify.QuantifiedVar (HasQuantifiedVar(..))

import Lamdu.Calc.Internal.Prelude

-- | Traverse the nominal type ids mentioned in the node(s) wrapped by @k@,
-- delegating each node to 'bodyTIds'.
tIds ::
    forall k expr.
    (RTraversable k, HasTIds expr) =>
    Traversal' (k # expr) T.NominalId
tIds f =
    withDict (recurse (Proxy @(RTraversable k))) $
    htraverse (Proxy @RTraversable #> bodyTIds f)

-- | Traverse the quantified-variable identifiers in the node(s) wrapped by
-- @k@, delegating each node to 'bodyQVarIds'.
qVarIds ::
    forall k expr.
    (RTraversable k, HasQVarIds expr) =>
    Traversal' (k # expr) Identifier
qVarIds f =
    withDict (recurse (Proxy @(RTraversable k))) $
    htraverse (Proxy @RTraversable #> bodyQVarIds f)

-- | Node types whose bodies may mention nominal type ids.
class HasTIds expr where
    bodyTIds :: RTraversable k => Traversal' (expr # k) T.NominalId

-- | Node types whose bodies may mention quantified-variable identifiers.
class HasQVarIds expr where
    bodyQVarIds :: RTraversable k => Traversal' (expr # k) Identifier -- only a legal traversal if keys not modified to be duplicates!

-- | View a 'Map' as its association list. Writing back goes through
-- 'Map.fromList', so keys rewritten into duplicates collapse to one entry
-- (the last one wins); hence "unsafe" as an 'Lens.Iso'.
unsafeMapList :: Ord k1 => Lens.Iso (Map k0 v0) (Map k1 v1) [(k0, v0)] [(k1, v1)]
unsafeMapList = Lens.iso (^@.. Lens.itraversed) Map.fromList

-- | Node types whose quantified variables ('QVar') can be viewed as an
-- 'Identifier'.
class (HasQVarIds expr, Ord (QVar expr)) => HasQVar expr where
    qvarId :: Proxy expr -> Lens.Iso' (QVar expr) Identifier

instance HasQVar T.Type where qvarId _ = T._Var
instance HasQVar T.Row where qvarId _ = T._Var

-- A nominal instance contributes its own id and the ids inside its type
-- arguments; every other constructor just recurses into its children.
instance HasTIds T.Type where
    {-# INLINE bodyTIds #-}
    bodyTIds f (T.TInst (NominalInst tId args)) =
        NominalInst
        <$> f tId
        <*> htraverse (Proxy @HasTIds #> (_QVarInstances . traverse . tIds) f)
            args
        <&> T.TInst
    bodyTIds f x = htraverse (Proxy @HasTIds #> tIds f) x

-- Type variables are visited directly. For a nominal instance, both the
-- argument-map keys and the argument types are visited (via
-- 'qvarInstancesQVarIds'). NOTE(review): those keys are the nominal's
-- parameter names, so they get renamed too — confirm that is intended.
instance HasQVarIds T.Type where
    bodyQVarIds f (T.TInst (NominalInst tId args)) =
        NominalInst tId
        <$> htraverse (Proxy @HasQVar #> qvarInstancesQVarIds f) args
        <&> T.TInst
    bodyQVarIds f (T.TVar v) = T._Var f v <&> T.TVar
    bodyQVarIds f x = htraverse (Proxy @HasQVar #> qVarIds f) x

-- | Traverse the identifiers of a 'QVarInstances' map: each key (through
-- 'qvarId') and the quantified-variable ids inside each value. The map is
-- rebuilt with 'unsafeMapList', so renaming two keys to the same
-- identifier drops an entry.
qvarInstancesQVarIds :: forall expr h. (HasQVar expr, RTraversable h) => Traversal' (QVarInstances h # expr) Identifier
qvarInstancesQVarIds f =
    (_QVarInstances . unsafeMapList . traverse)
    (\(k, typ) ->
        (,) <$> qvarId (Proxy @expr) f k <*> qVarIds f typ)

instance HasTIds T.Row where
    {-# INLINE bodyTIds #-}
    bodyTIds f = htraverse (Proxy @HasTIds #> tIds f)

-- Row variables are visited directly, mirroring the 'T.TVar' case of the
-- 'T.Type' instance. Previously 'T.RVar' fell through to 'htraverse', which
-- has no children to visit for it, so row-variable identifiers were never
-- reached (while 'qvarsQVarIds' still renamed the row quantifiers).
-- Other constructors recurse into their children.
instance HasQVarIds T.Row where
    bodyQVarIds f (T.RVar v) = T._Var f v <&> T.RVar
    bodyQVarIds f x = htraverse (Proxy @HasQVarIds #> qVarIds f) x

-- The nominal ids of a scheme are those of its body type.
instance HasTIds (Scheme T.Types T.Type) where
    bodyTIds = sTyp . tIds

-- | Traverse the quantified variables bound by a 'QVars' (its map keys) as
-- identifiers. Rebuilt with 'unsafeMapList'; see its caveat on duplicates.
qvarsQVarIds ::
    forall expr.
    HasQVar expr =>
    Traversal' (QVars # expr) Identifier
qvarsQVarIds =
    _QVars . unsafeMapList . traverse . Lens._1 . qvarId (Proxy @expr)

-- The nominal ids of a nominal declaration are those of its scheme.
instance HasTIds (NominalDecl T.Type) where
    bodyTIds = nScheme . bodyTIds

-- | The application at the root of an annotated term, if it is one.
{-# INLINE valApply #-}
valApply :: Traversal' (Ann a # V.Term) (V.App V.Term # Ann a)
valApply = hVal . V._BApp

-- | Matches when the root of an annotated term is a hole.
{-# INLINE valHole #-}
valHole :: Traversal' (Ann a # V.Term) ()
valHole = hVal . valBodyHole

-- | The variable at the root of an annotated term, if it is one.
{-# INLINE valVar #-}
valVar :: Traversal' (Ann a # V.Term) V.Var
valVar = hVal . valBodyVar

-- | The literal at the root of an annotated term, if it is one.
{-# INLINE valLiteral #-}
valLiteral :: Traversal' (Ann a # V.Term) V.PrimVal
valLiteral = hVal . valBodyLiteral

-- | A term body that is a hole leaf.
{-# INLINE valBodyHole #-}
valBodyHole :: Prism' (V.Term expr) ()
valBodyHole = V._BLeaf . V._LHole

-- | A term body that is a variable leaf.
{-# INLINE valBodyVar #-}
valBodyVar :: Prism' (V.Term expr) V.Var
valBodyVar = V._BLeaf . V._LVar

-- | A term body that is a literal leaf.
{-# INLINE valBodyLiteral #-}
valBodyLiteral :: Prism' (V.Term expr) V.PrimVal
valBodyLiteral = V._BLeaf . V._LLiteral

-- | The direct sub-terms of a term body; embedded (pruned) types are left
-- untouched.
subTerms :: Lens.Traversal' (V.Term # k) (k # V.Term)
subTerms f =
    htraverse
    ( \case
        HWitness V.W_Term_Term -> f
        HWitness V.W_Term_HCompose_Prune_Type -> pure
    )

-- | Every leaf of an annotated term, indexed by that leaf's payload.
{-# INLINE valLeafs #-}
valLeafs :: Lens.IndexedTraversal' (a # V.Term) (Ann a # V.Term) V.Leaf
valLeafs f (Ann pl body) =
    case body of
    V.BLeaf l -> Lens.indexed f pl l <&> V.BLeaf
    _ -> subTerms (valLeafs f) body
    <&> Ann pl

-- | Every payload of a 'Val', root first, then sub-terms. Each payload is
-- indexed by its sub-expression with the annotations stripped.
{-# INLINE subExprPayloads #-}
subExprPayloads :: Lens.IndexedTraversal' (Pure # V.Term) (Val a) a
subExprPayloads f x@(Ann (Const pl) body) =
    Ann
    <$> (Lens.indexed f (unwrap (const (^. hVal)) x) pl <&> Const)
    <*> (subTerms .> subExprPayloads) f body

-- | 'subExprPayloads' restricted to sub-expressions that match the fold @l@.
{-# INLINE payloadsOf #-}
payloadsOf ::
    Lens.Fold (Pure # V.Term) a -> Lens.IndexedTraversal' (Pure # V.Term) (Val b) b
payloadsOf l =
    subExprPayloads . Lens.ifiltered predicate
    where
        predicate idx _ = Lens.has l idx

-- | The tag of an inject or get-field leaf; other leaves have none.
leafTags :: Lens.Traversal' V.Leaf T.Tag
leafTags f (V.LInject t) = f t <&> V.LInject
leafTags f (V.LGetField t) = f t <&> V.LGetField
leafTags _ x = pure x

-- | All tags in an annotated term: in leaves, case alternatives, record
-- extensions and (unpruned) embedded types.
{-# INLINE valTags #-}
valTags :: Lens.Traversal' (Ann a # V.Term) T.Tag
valTags =
    hVal .
    \f ->
    \case
    V.BLeaf l -> leafTags f l <&> V.BLeaf
    V.BCase (RowExtend t v r) ->
        RowExtend <$> f t <*> valTags f v <*> valTags f r <&> V.BCase
    V.BRecExtend (RowExtend t v r) ->
        RowExtend <$> f t <*> valTags f v <*> valTags f r <&> V.BRecExtend
    body ->
        htraverse
        ( \case
            HWitness V.W_Term_Term -> valTags f
            HWitness V.W_Term_HCompose_Prune_Type -> typeTags f
        ) body

-- | All tags in an unpruned type, including in nominal type arguments;
-- pruned nodes are skipped.
typeTags :: Lens.Traversal' (Ann a # HCompose Prune T.Type) T.Tag
typeTags f =
    (hVal . hcomposed _Unpruned)
    ( htraverse
        ( \case
            HWitness T.W_Type_Type -> (_HCompose . typeTags) f
            HWitness T.W_Type_Row -> (_HCompose . rowTags) f
            HWitness (T.E_Type_NominalInst_NominalId_Types (HWitness T.W_Types_Type)) -> (_HCompose . typeTags) f
            HWitness (T.E_Type_NominalInst_NominalId_Types (HWitness T.W_Types_Row)) -> (_HCompose . rowTags) f
        )
    )

-- | All tags in an unpruned row: each extension's tag, plus the tags in its
-- field type and in the rest of the row.
rowTags :: Lens.Traversal' (Ann a # HCompose Prune T.Row) T.Tag
rowTags =
    hVal . hcomposed _Unpruned . T._RExtend . onRExtend
    where
        onRExtend f (RowExtend tag val rest) =
            RowExtend
            <$> f tag
            <*> (_HCompose . typeTags) f val
            <*> (_HCompose . rowTags) f rest

-- | Variables of an annotated term that are bound neither in @scope@ nor by
-- an enclosing lambda, each indexed by its leaf's payload.
{-# INLINE valGlobals #-}
valGlobals ::
    Set V.Var ->
    Lens.IndexedFold (a # V.Term) (Ann a # V.Term) V.Var
valGlobals scope f (Ann pl body) =
    case body of
    V.BLeaf (V.LVar v)
        | scope ^. Lens.contains v -> V.LVar v & V.BLeaf & pure
        | otherwise -> Lens.indexed f pl v <&> V.LVar <&> V.BLeaf
    V.BLam (V.TypedLam var typ lamBody) ->
        valGlobals (Set.insert var scope) f lamBody
        <&> V.TypedLam var typ <&> V.BLam
    _ ->
        htraverse
        ( \case
            HWitness V.W_Term_Term -> valGlobals scope f
            HWitness V.W_Term_HCompose_Prune_Type -> pure
        ) body
    <&> Ann pl

-- | All nominal ids in an annotated term: from-nominal leaves, to-nominal
-- nodes and (unpruned) embedded types.
{-# INLINE valNominals #-}
valNominals :: Lens.Traversal' (Ann a # V.Term) T.NominalId
valNominals =
    hVal .
    \f ->
    \case
    V.BLeaf (V.LFromNom nomId) -> f nomId <&> V.LFromNom <&> V.BLeaf
    V.BToNom (ToNom nomId x) ->
        ToNom
        <$> f nomId
        <*> valNominals f x
        <&> V.BToNom
    body ->
        htraverse
        ( \case
            HWitness V.W_Term_Term -> valNominals f
            HWitness V.W_Term_HCompose_Prune_Type -> typeNominals f
        ) body

-- | All nominal ids in an unpruned type, including each nominal instance's
-- own id and the ids inside its type arguments.
{-# INLINE typeNominals #-}
typeNominals :: Lens.Traversal' (Ann a # HCompose Prune T.Type) T.NominalId
typeNominals =
    hVal . hcomposed _Unpruned .
    \f ->
    \case
    T.TInst (NominalInst nomId args) ->
        NominalInst
        <$> f nomId
        <*> htraverse
            ( \case
                HWitness T.W_Types_Type -> (_QVarInstances . traverse . _HCompose . typeNominals) f
                HWitness T.W_Types_Row -> (_QVarInstances . traverse . _HCompose . rowNominals) f
            ) args
        <&> T.TInst
    body ->
        htraverse
        ( \case
            HWitness T.W_Type_Type -> (_HCompose . typeNominals) f
            HWitness T.W_Type_Row -> (_HCompose . rowNominals) f
            HWitness (T.E_Type_NominalInst_NominalId_Types (HWitness T.W_Types_Type)) -> (_HCompose . typeNominals) f
            HWitness (T.E_Type_NominalInst_NominalId_Types (HWitness T.W_Types_Row)) -> (_HCompose . rowNominals) f
        ) body

{-# INLINE rowNominals #-}
rowNominals :: Lens.Traversal' (Ann a # HCompose Prune T.Row) T.NominalId
rowNominals =
    hVal .
hcomposed _Unpruned . 287 | \f -> 288 | htraverse 289 | ( \case 290 | HWitness T.W_Row_Type -> (_HCompose . typeNominals) f 291 | HWitness T.W_Row_Row -> (_HCompose . rowNominals) f 292 | ) 293 | -------------------------------------------------------------------------------- /src/Lamdu/Calc/Term.hs: -------------------------------------------------------------------------------- 1 | -- | Val AST 2 | {-# LANGUAGE NoImplicitPrelude, DeriveGeneric, DeriveTraversable, GADTs #-} 3 | {-# LANGUAGE GeneralizedNewtypeDeriving, TemplateHaskell, FlexibleContexts #-} 4 | {-# LANGUAGE UndecidableInstances, StandaloneDeriving, TypeFamilies, TypeOperators #-} 5 | {-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, ConstraintKinds #-} 6 | {-# LANGUAGE TupleSections, ScopedTypeVariables, DerivingStrategies, DataKinds #-} 7 | {-# LANGUAGE PolyKinds, TypeApplications #-} 8 | 9 | module Lamdu.Calc.Term 10 | ( Val 11 | , Leaf(..), _LVar, _LHole, _LLiteral, _LRecEmpty, _LAbsurd, _LFromNom, _LGetField, _LInject 12 | , PrimVal(..), primType, primData 13 | , Term(..), _BApp, _BLam, _BRecExtend, _BCase, _BToNom, _BLeaf, W_Term(..) 14 | , App(..), appFunc, appArg 15 | , TypedLam(..), tlIn, tlInType, tlOut 16 | , Var(..) 17 | , Scope(..), scopeNominals, scopeVarTypes, scopeLevel 18 | , emptyScope 19 | , ToNom(..), FromNom(..), RowExtend(..) 20 | , HPlain(..) 
21 | ) where 22 | 23 | import qualified Control.Lens as Lens 24 | import qualified Data.ByteString.Char8 as BS8 25 | import Data.Maybe (fromMaybe) 26 | import Generics.Constraints (makeDerivings, makeInstances) 27 | import Hyper 28 | import Hyper.Infer 29 | import Hyper.Infer.Blame (Blame(..)) 30 | import Hyper.Type.Prune (Prune) 31 | import Hyper.Syntax hiding (Var) 32 | import Hyper.Syntax.Nominal (ToNom(..), FromNom(..), NominalInst(..), MonadNominals, LoadedNominalDecl) 33 | import Hyper.Syntax.Row (RowExtend(..), rowElementInfer) 34 | import Hyper.Syntax.Scheme (QVarInstances(..)) 35 | import qualified Hyper.Syntax.Var as TermVar 36 | import Hyper.Unify 37 | import qualified Hyper.Unify.Generalize as G 38 | import Hyper.Unify.New (newTerm, newUnbound) 39 | import Hyper.Unify.Term (UTerm(..)) 40 | import Lamdu.Calc.Identifier (Identifier) 41 | import qualified Lamdu.Calc.Type as T 42 | import Text.PrettyPrint ((<+>)) 43 | import qualified Text.PrettyPrint as PP 44 | import Text.PrettyPrint.HughesPJClass (Pretty(..), maybeParens) 45 | 46 | import Lamdu.Calc.Internal.Prelude 47 | 48 | {-# ANN module ("HLint: ignore Use const"::String) #-} 49 | 50 | newtype Var = Var { vvName :: Identifier } 51 | deriving stock Show 52 | deriving newtype (Eq, Ord, NFData, IsString, Pretty, Binary, Hashable) 53 | 54 | data PrimVal = PrimVal 55 | { _primType :: {-# UNPACK #-} !T.NominalId 56 | , _primData :: {-# UNPACK #-} !ByteString 57 | } deriving (Generic, Show, Eq, Ord) 58 | instance NFData PrimVal 59 | instance Binary PrimVal 60 | instance Hashable PrimVal 61 | 62 | Lens.makeLenses ''PrimVal 63 | 64 | data Leaf 65 | = LVar {-# UNPACK #-}!Var 66 | | LHole 67 | | LLiteral {-# UNPACK #-}!PrimVal 68 | | LRecEmpty 69 | | LAbsurd 70 | | LFromNom {-# UNPACK #-}!T.NominalId 71 | | LGetField {-# UNPACK #-}!T.Tag 72 | | LInject {-# UNPACK #-}!T.Tag 73 | deriving (Generic, Show, Eq, Ord) 74 | instance NFData Leaf 75 | instance Binary Leaf 76 | instance Hashable Leaf 77 | 78 | 
Lens.makePrisms ''Leaf

-- | A term body, parameterized by the hyper-type @k@ of its children
data Term k
    = BApp {-# UNPACK #-}!(App Term k)
      -- ^ Function application
    | BLam {-# UNPACK #-}!(TypedLam Var (HCompose Prune T.Type) Term k)
      -- ^ Lambda with an (optionally pruned) parameter type annotation
    | BRecExtend {-# UNPACK #-}!(RowExtend T.Tag Term Term k)
      -- ^ Extend a record with a field
    | BCase {-# UNPACK #-}!(RowExtend T.Tag Term Term k)
      -- ^ Handle one variant alternative, deferring the rest to another case
    | BToNom {-# UNPACK #-}!(ToNom T.NominalId Term k)
      -- ^ Convert to Nominal type
    | BLeaf Leaf
      -- ^ A node without child terms
    deriving Generic

Lens.makePrisms ''Term
makeHTraversableAndBases ''Term
makeZipMatch ''Term
makeHContext ''Term
makeHasHPlain [''Term]

instance RNodes Term
instance
    (c Term, c (HCompose Prune T.Type), c (HCompose Prune T.Row)) =>
    Recursively c Term
instance RTraversable Term

-- | String literals in plain terms denote variable references
instance IsString (HPlain Term) where
    fromString = BLeafP . LVar . fromString

instance
    (Pretty (f :# Term), Pretty (f :# HCompose Prune T.Type)) =>
    Pretty (Term f) where
    pPrintPrec lvl prec b =
        case b of
        BLeaf (LVar var) -> pPrint var
        BLeaf (LLiteral (PrimVal _p d)) -> PP.text (BS8.unpack d)
        BLeaf LHole -> PP.text "?"
        BLeaf LAbsurd -> PP.text "absurd"
        BLeaf (LFromNom ident) -> PP.text "[" <+> PP.text "unpack" <+> pPrint ident <+> PP.text "]"
        BApp (App e1 e2) -> maybeParens (10 < prec) $
            pPrintPrec lvl 10 e1 <+> pPrintPrec lvl 11 e2
        BLam (TypedLam n t e) -> maybeParens (0 < prec) $
            PP.char '\\' PP.<> pPrint n PP.<>
            PP.char ':' PP.<> pPrint t <+>
            PP.text "->" <+>
            pPrint e
        BLeaf (LGetField x) -> maybeParens (12 < prec) $ PP.char '.' <> pPrint x
        BLeaf (LInject x) -> maybeParens (12 < prec) $ pPrint x <> PP.char ':'
        BCase (RowExtend n m mm) -> maybeParens (0 < prec) $
            PP.vcat
            [ PP.text "case of"
            , pPrint n <> PP.text " -> " <> pPrint m
            , PP.text "_" <> PP.text " -> " <> pPrint mm
            ]
        BToNom (ToNom ident val) -> PP.text "[" <+> pPrint ident <+> PP.text "pack" <+> pPrint val <+> PP.text "]"
        BLeaf LRecEmpty -> PP.text "{}"
        BRecExtend (RowExtend tag val rest) ->
            PP.text "{" <+>
            prField PP.<>
            PP.comma <+>
            pPrint rest <+>
            PP.text "}"
            where
                prField = pPrint tag <+> PP.text "=" <+> pPrint val

-- | The typing environment used when inferring terms
data Scope v = Scope
    { _scopeNominals :: Map T.NominalId (LoadedNominalDecl T.Type v)
      -- ^ Nominal type declarations, by id
    , _scopeVarTypes :: Map Var (HFlip G.GTerm T.Type v)
      -- ^ Generalized types of variables in scope
    , _scopeLevel :: ScopeLevel
    } deriving Generic
Lens.makeLenses ''Scope

makeHTraversableAndBases ''Scope

-- | A scope with no nominals, no variables, at level 0
{-# INLINE emptyScope #-}
emptyScope :: Scope # v
emptyScope = Scope mempty mempty (ScopeLevel 0)

type instance TermVar.ScopeOf Term = Scope
type instance InferOf Term = ANode T.Type
instance HasInferredType Term where
    type instance TypeOf Term = T.Type
    inferredType _ = _ANode

makeDerivings [''Eq, ''Ord, ''Show] [''Term, ''Scope]
makeInstances [''Binary, ''NFData, ''Hashable] [''Term, ''Scope]

-- | A variable's type is its generalized type from the scope, instantiated.
-- Calls 'error' when the variable is not in the scope.
instance TermVar.VarType Var Term where
    {-# INLINE varType #-}
    varType _ v x =
        x ^? scopeVarTypes . Lens.ix v . _HFlip
        & fromMaybe (error ("var not in scope: " <> show v))
        & G.instantiate

instance
    ( MonadNominals T.NominalId T.Type m
    , MonadScopeLevel m
    , TermVar.HasScope m Scope
    , UnifyGen m T.Type, UnifyGen m T.Row
    , LocalScopeType Var (UVarOf m # T.Type) m
    ) =>
    Infer m Term where

    {-# INLINE inferBody #-}
    inferBody (BApp x) = inferBody x <&> Lens._1 %~ BApp
    inferBody (BLam x) = inferBody x <&> Lens._1 %~ BLam
    inferBody (BToNom x) =
        do
            (xI, xT) <- inferBody x
            newTerm (T.TInst xT) <&> MkANode <&> (BToNom xI, )
    inferBody (BLeaf leaf) =
        case leaf of
        LHole -> newUnbound
        LRecEmpty -> newTerm T.REmpty >>= newTerm . T.TRecord
        LAbsurd ->
            -- absurd : (empty variant) -> a
            FuncType
            <$> (newTerm T.REmpty >>= newTerm . T.TVariant)
            <*> newUnbound
            >>= newTerm . T.TFun
        LLiteral (PrimVal t _) ->
            T.Types (QVarInstances mempty) (QVarInstances mempty)
            & NominalInst t
            & T.TInst & newTerm
        LVar x ->
            inferBody (TermVar.Var x :: TermVar.Var Var Term # InferChild k w)
            <&> (^. Lens._2 . _ANode)
        LFromNom x ->
            inferBody (FromNom x :: FromNom T.NominalId Term # InferChild k w)
            <&> (^. Lens._2)
            >>= newTerm . T.TFun
        LGetField k ->
            do
                (rT, wR) <- rowElementInfer T.RExtend k
                T.TRecord wR & newTerm
                    >>= newTerm . T.TFun . (`FuncType` rT)
        LInject k ->
            do
                (rT, wR) <- rowElementInfer T.RExtend k
                T.TVariant wR & newTerm
                    >>= newTerm . T.TFun . FuncType rT
        <&> MkANode
        <&> (BLeaf leaf, )
    inferBody (BRecExtend (RowExtend k v r)) =
        do
            InferredChild vI vR <- inferChild v
            InferredChild rI rR <- inferChild r
            -- The rest of the record must not already contain field k
            restR <-
                scopeConstraints (Proxy @T.Row) <&> T.rForbiddenFields . Lens.contains k .~ True
                >>= newVar binding . UUnbound
            _ <- T.TRecord restR & newTerm >>= unify (rR ^. _ANode)
            RowExtend k (vR ^. _ANode) restR & T.RExtend & newTerm
                >>= newTerm . T.TRecord
                <&> MkANode
                <&> (BRecExtend (RowExtend k vI rI), )
    inferBody (BCase (RowExtend tag handler rest)) =
        do
            InferredChild handlerI handlerT <- inferChild handler
            InferredChild restI restT <- inferChild rest
            fieldT <- newUnbound
            -- The alternatives handled by 'rest' must not include 'tag'.
            -- Otherwise the inferred variant row could contain 'tag' twice.
            -- This mirrors the constraint on the rest row in 'BRecExtend'.
            restR <-
                scopeConstraints (Proxy @T.Row) <&> T.rForbiddenFields . Lens.contains tag .~ True
                >>= newVar binding . UUnbound
            result <- newUnbound
            -- handler : fieldT -> result
            _ <-
                FuncType fieldT result & T.TFun & newTerm
                >>= unify (handlerT ^. _ANode)
            -- rest : (variant of restR) -> result
            restVarT <- T.TVariant restR & newTerm
            _ <-
                FuncType restVarT result & T.TFun & newTerm
                >>= unify (restT ^. _ANode)
            -- whole case : (variant of tag:fieldT | restR) -> result
            whole <- RowExtend tag fieldT restR & T.RExtend & newTerm >>= newTerm . T.TVariant
            FuncType whole result & T.TFun & newTerm
                <&> MkANode
                <&> (BCase (RowExtend tag handlerI restI), )

instance
    ( MonadNominals T.NominalId T.Type m
    , MonadScopeLevel m
    , TermVar.HasScope m Scope
    , UnifyGen m T.Type, UnifyGen m T.Row
    , LocalScopeType Var (UVarOf m # T.Type) m
    ) =>
    Blame m Term where
    inferOfUnify _ x y = unify (x ^. _ANode) (y ^. _ANode) & void
    inferOfMatches _ x y =
        (==)
        <$> (semiPruneLookup (x ^. _ANode) <&> fst)
        <*> (semiPruneLookup (y ^. _ANode) <&> fst)

-- Type synonym to ease the transition

type Val a = Ann (Const a) # Term

--------------------------------------------------------------------------------
/src/Lamdu/Calc/Type.hs:
--------------------------------------------------------------------------------
-- | The Lamdu Calculus type AST.
--
-- The Lamdu Calculus type system includes the set of types that can
-- be expressed via the AST elements in this module.
--
-- The Lamdu Calculus type AST is actually 2 different ASTs:
--
-- * The AST for rows ('Row'), used by the structural composite types
--   (records, variants). Whether a row describes a record or a variant
--   is determined by the 'Type' constructor wrapping it ('TRecord' or
--   'TVariant').
--
-- * The AST for types: Nominal types, structural composite types,
--   function types.
{-# LANGUAGE NoImplicitPrelude, DeriveGeneric, GeneralizedNewtypeDeriving, TypeApplications #-}
{-# LANGUAGE TemplateHaskell, DataKinds, StandaloneDeriving, DerivingVia, TypeOperators #-}
{-# LANGUAGE UndecidableInstances, ConstraintKinds, FlexibleContexts, GADTs, TupleSections #-}
{-# LANGUAGE FlexibleInstances, TypeFamilies, MultiParamTypeClasses, RankNTypes #-}

module Lamdu.Calc.Type
    (
    -- * Type Variable kinds
      RowVar, TypeVar
    -- * Typed identifiers of the Type AST
    , Var(..), NominalId(..), Tag(..)
    -- * Rows
    , Row(..), W_Row(..)
    -- * Row Prisms
    , _RExtend, _REmpty, _RVar
    -- * Type AST
    , Type(..), W_Type(..)
    , Scheme, Nominal
    , (~>)
    -- * Type Prisms
    , _Var
    , _TVar, _TFun, _TInst, _TRecord, _TVariant
    -- TODO: describe
    , Types(..), W_Types(..), tType, tRow
    , RConstraints(..), rForbiddenFields, rScope
    , rStructureMismatch

    , TypeError(..), _TypeError, _RowError

    , flatRow

    , HPlain(..)
    ) where

import Algebra.PartialOrd (PartialOrd(..))
import qualified Control.Lens as Lens
import Generic.Data (Generically(..))
import Generics.Constraints (makeDerivings, makeInstances)
import Hyper
import Hyper.Class.Optic (HNodeLens(..), HSubset(..))
import Hyper.Infer
import Hyper.Infer.Blame (Blame(..))
import Hyper.Syntax hiding (Var, _Var)
import Hyper.Syntax.Nominal
import Hyper.Syntax.Row
import qualified Hyper.Syntax.Scheme as S
import Hyper.Unify
import Hyper.Unify.New (newTerm)
import Hyper.Unify.QuantifiedVar (HasQuantifiedVar(..))
import Lamdu.Calc.Identifier (Identifier)
import Text.PrettyPrint ((<+>))
import qualified Text.PrettyPrint as PP
import Text.PrettyPrint.HughesPJClass (Pretty(..), maybeParens)

import Lamdu.Calc.Internal.Prelude

-- | A type variable of some kind ('Var' 'Type' or 'Var' 'Row')
newtype Var (t :: HyperType) = Var { tvName :: Identifier }
    deriving stock Show
    deriving newtype (Eq, Ord, NFData, IsString, Pretty, Binary, Hashable)

-- | An identifier for a nominal type
newtype NominalId = NominalId { nomId :: Identifier }
    deriving stock Show
    deriving newtype (Eq, Ord, NFData, IsString, Pretty, Binary, Hashable)

-- | An identifier for a component in a variant type (aka data
-- constructor) or a component (field) in a record
newtype Tag = Tag { tagName :: Identifier }
    deriving stock Show
    deriving newtype (Eq, Ord, NFData, IsString, Pretty, Binary, Hashable)

-- | A row type variable that represents a set of
-- typed fields in a row
type RowVar = Var Row

-- | A type variable that represents a type
type TypeVar = Var Type

-- | The AST for rows (records, variants). For
-- example: RExtend "a" int (RExtend "b" bool (RVar "c")) represents
-- the composite type:
--
-- > { a : int, b : bool | c }
data Row k
    = RExtend (RowExtend Tag Type Row k)
      -- ^ Extend a row type with an extra component (field /
      -- data constructor).
    | REmpty
      -- ^ The empty row type (empty record [unit] or empty variant [void])
    | RVar RowVar
      -- ^ A row type variable
    deriving Generic

-- | The AST for any Lamdu Calculus type
data Type k
    = TVar TypeVar
      -- ^ A type variable
    | TFun (FuncType Type k)
      -- ^ A (non-dependent) function of the given parameter and result types
    | TInst (NominalInst NominalId Types k)
      -- ^ An instantiation of a nominal type of the given id with the
      -- given keyword type arguments
    | TRecord (k :# Row)
      -- ^ Lifts a composite record type
    | TVariant (k :# Row)
      -- ^ Lifts a composite variant type
    deriving Generic

-- | One node of each type-variable kind: a 'Type' and a 'Row'
data Types k = Types
    { _tType :: k :# Type
    , _tRow :: k :# Row
    } deriving Generic

-- | Unification constraints on row variables
data RConstraints = RowConstraints
    { _rForbiddenFields :: Set Tag
      -- ^ Tags the row must not contain
    , _rScope :: ScopeLevel
    } deriving (Generic, Eq, Ord, Show)
    deriving (Semigroup, Monoid) via Generically RConstraints

data TypeError k
    = TypeError (UnifyError Type k)
    | RowError (UnifyError Row k)
    | NominalNotFound NominalId
    deriving Generic

Lens.makeLenses ''RConstraints
Lens.makeLenses ''Types
Lens.makePrisms ''Row
Lens.makePrisms ''Type
Lens.makePrisms ''TypeError
Lens.makePrisms ''Var

makeHTraversableApplyAndBases ''Types
makeHTraversableAndBases ''Row
makeHTraversableAndBases ''Type
makeZipMatch ''Row
makeZipMatch ''Type
makeZipMatch ''Types
makeHContext ''Row
makeHContext ''Type
makeHContext ''Types
instance RNodes Row
instance RNodes Type
instance (c Type, c Row) => Recursively c Type
instance (c Type, c Row) => Recursively c Row
instance RTraversable Row
instance RTraversable Type

-- | A nominal type instantiation with its type and row arguments
type Nominal = NominalInst NominalId Types

-- | A type scheme quantifying over both type and row variables
type Scheme = S.Scheme Types Type

instance HNodeLens Types Type where
    {-# INLINE hNodeLens #-}
    hNodeLens = tType

instance HNodeLens Types Row where
    {-# INLINE hNodeLens #-}
    hNodeLens = tRow

instance (UnifyGen m Type, UnifyGen m Row) => S.HasScheme Types m Type
instance (UnifyGen m Type, UnifyGen m Row) => S.HasScheme Types m Row

-- | A convenience infix alias for 'TFun' (right-associative)
infixr 2 ~>
(~>) :: Pure # Type -> Pure # Type -> Pure # Type
param ~> res = FuncType param res & TFun & (_Pure #)

-- | Require constraint @c@ for the 'Type' and 'Row' children under @k@
type Deps c k = ((c (k :# Type), c (k :# Row)) :: Constraint)

instance Deps Pretty k => Pretty (Type k) where
    pPrintPrec lvl prec typ =
        case typ of
        TVar name -> pPrint name
        TInst inst -> pPrint inst
        TFun (FuncType param res) ->
            -- The arrow is right-associative: only the parameter side is
            -- printed at a higher precedence
            maybeParens (8 < prec) $
            pPrintPrec lvl 9 param <+> PP.text "->" <+> pPrintPrec lvl 8 res
        TRecord row -> PP.text "*" <> pPrint row
        TVariant row -> PP.text "+" <> pPrint row

-- | Rows print as @{ f0 : t0, f1 : t1, v... }@, where @v...@ is an open
-- tail variable. The empty row prints as @{}@.
instance Pretty (Row # Pure) where
    pPrint REmpty = PP.text "{}"
    pPrint row =
        PP.text "{" <+> fieldsDoc PP.empty row <+> PP.text "}"
        where
            fieldsDoc _ REmpty = PP.empty
            fieldsDoc prefix (RVar var) = prefix <> pPrint var <> PP.text "..."
            fieldsDoc prefix (RExtend (RowExtend tag fieldType rest)) =
                prefix PP.<> pPrint tag <+> PP.text ":" <+> pPrint fieldType
                PP.<> fieldsDoc (PP.text ", ") (rest ^. _Pure)

instance Deps Pretty k => Pretty (Types k) where
    pPrint (Types typ row) =
        PP.text "{" <+> pPrint typ <+> PP.text "|" <+> pPrint row <+> PP.text "}"

instance Pretty RConstraints where
    pPrint (RowConstraints tags level) =
        pPrint (tags ^.. Lens.folded) PP.<+> PP.parens (pPrint level)

instance Pretty (TypeError # Pure) where
    pPrint (TypeError err) = pPrint err
    pPrint (RowError err) = pPrint err
    pPrint (NominalNotFound nid) = PP.text "Nominal not found:" <+> pPrint nid

type instance NomVarTypes Type = Types

instance HSubset Type Type (FuncType Type) (FuncType Type) where
    {-# INLINE hSubset #-}
    hSubset = _TFun

instance HasNominalInst NominalId Type where
    {-# INLINE nominalInst #-}
    nominalInst = _TInst

instance HasQuantifiedVar Type where
    type QVar Type = TypeVar
    quantifiedVar = _TVar

instance HasQuantifiedVar Row where
    type QVar Row = RowVar
    quantifiedVar = _RVar

-- | Types are constrained only by scope level. Row children (records,
-- variants and nominal row arguments) get that level with no forbidden fields.
instance HasTypeConstraints Type where
    type instance TypeConstraintsOf Type = ScopeLevel
    {-# INLINE verifyConstraints #-}
    verifyConstraints _ (TVar var) = Just (TVar var)
    verifyConstraints level (TFun func) =
        func & hmapped1 %~ WithConstraint level & TFun & Just
    verifyConstraints level (TRecord row) =
        WithConstraint (RowConstraints mempty level) row & TRecord & Just
    verifyConstraints level (TVariant row) =
        WithConstraint (RowConstraints mempty level) row & TVariant & Just
    verifyConstraints level (TInst (NominalInst nid (Types typeArgs rowArgs))) =
        Types
        (typeArgs & S._QVarInstances . traverse %~ WithConstraint level)
        (rowArgs & S._QVarInstances . traverse %~ WithConstraint (RowConstraints mempty level))
        & NominalInst nid & TInst & Just

instance HasTypeConstraints Row where
    type instance TypeConstraintsOf Row = RConstraints
    {-# INLINE verifyConstraints #-}
    verifyConstraints _ REmpty = Just REmpty
    verifyConstraints _ (RVar var) = Just (RVar var)
    verifyConstraints constraints (RExtend ext) =
        RExtend <$> verifyRowExtendConstraints (^. rScope) constraints ext

instance TypeConstraints RConstraints where
    {-# INLINE generalizeConstraints #-}
    generalizeConstraints = rScope .~ mempty
    toScopeConstraints = rForbiddenFields .~ mempty

instance RowConstraints RConstraints where
    type RowConstraintsKey RConstraints = Tag
    {-# INLINE forbidden #-}
    forbidden = rForbiddenFields

-- | Component-wise partial order on forbidden fields and scope level
instance PartialOrd RConstraints where
    {-# INLINE leq #-}
    leq (RowConstraints fieldsA levelA) (RowConstraints fieldsB levelB) =
        leq fieldsA fieldsB && leq levelA levelB

type instance InferOf Type = ANode Type
type instance InferOf Row = ANode Row
instance HasInferredValue Type where inferredValue = _ANode
instance HasInferredValue Row where inferredValue = _ANode

-- | Inferring a type expression: infer its children, then create a
-- unification term from their inferred values
instance (Monad m, UnifyGen m Type, UnifyGen m Row) => Infer m Type where
    inferBody body =
        do
            children <- htraverse (const inferChild) body
            hmap (Proxy @HasInferredValue #> (^. inType . inferredValue)) children
                & newTerm
                <&> (hmap (const (^. inRep)) children, ) . MkANode

-- | Inferring a row expression, the same way as for 'Type'
instance (Monad m, UnifyGen m Type, UnifyGen m Row) => Infer m Row where
    inferBody body =
        do
            children <- htraverse (const inferChild) body
            hmap (Proxy @HasInferredValue #> (^. inType . inferredValue)) children
                & newTerm
                <&> (hmap (const (^. inRep)) children, ) . MkANode

instance (UnifyGen m Type, UnifyGen m Row) => Blame m Row where
    inferOfUnify _ expected actual =
        void (unify (expected ^. _ANode) (actual ^. _ANode))
    inferOfMatches _ expected actual =
        (==)
        <$> (fst <$> semiPruneLookup (expected ^. _ANode))
        <*> (fst <$> semiPruneLookup (actual ^. _ANode))

instance (UnifyGen m Type, UnifyGen m Row) => Blame m Type where
    inferOfUnify _ expected actual =
        void (unify (expected ^. _ANode) (actual ^. _ANode))
    inferOfMatches _ expected actual =
        (==)
        <$> (fst <$> semiPruneLookup (expected ^. _ANode))
        <*> (fst <$> semiPruneLookup (actual ^. _ANode))

-- | Resolve a structural mismatch between two rows during unification.
-- Two 'RExtend' rows are reconciled with 'rowExtendStructureMismatch'.
-- Any other pair is reported as a 'Mismatch' error.
{-# INLINE rStructureMismatch #-}
rStructureMismatch ::
    (Unify m Type, Unify m Row) =>
    (forall c. Unify m c => UVarOf m # c -> UVarOf m # c -> m (UVarOf m # c)) ->
    Row # UVarOf m ->
    Row # UVarOf m ->
    m ()
rStructureMismatch f (RExtend r0) (RExtend r1) =
    rowExtendStructureMismatch f _RExtend r0 r1
rStructureMismatch _ x y = unifyError (Mismatch x y)

-- | View a pure row as its flat list of field extensions plus its tail
flatRow :: Lens.Iso' (Pure # Row) (FlatRowExtends Tag Type Row # Pure)
flatRow =
    Lens.iso toFlat fromFlat
    where
        toFlat =
            Lens.runIdentity .
            flattenRow (Lens.Identity . (^? _Pure . _RExtend))
        fromFlat =
            Lens.runIdentity .
            unflattenRow (Lens.Identity . (_Pure . _RExtend #))

makeDerivings [''Eq, ''Ord, ''Show] [''Row, ''Type, ''Types, ''TypeError]
makeInstances [''Binary, ''NFData] [''Row, ''Type, ''Types, ''TypeError]

makeHasHPlain [''Type, ''Row, ''Types]

instance NFData RConstraints
instance Binary RConstraints

--------------------------------------------------------------------------------