├── .dir-locals.el
├── Setup.hs
├── stack.yaml
├── hie.yaml
├── .gitignore
├── .ghci
├── src
└── Lamdu
│ └── Calc
│ ├── Internal
│ └── Prelude.hs
│ ├── Definition.hs
│ ├── Term
│ └── Eq.hs
│ ├── Identifier.hs
│ ├── Infer.hs
│ ├── Lens.hs
│ ├── Term.hs
│ └── Type.hs
├── test
├── test.hs
├── benchmark.hs
└── TestVals.hs
├── LICENSE
├── doc
├── Optimization.md
└── ExceptionMonad.md
├── package.yaml
├── tools
└── core-type-apps.py
├── lamdu-calculus.cabal
└── README.md
/.dir-locals.el:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/Setup.hs:
--------------------------------------------------------------------------------
1 | import Distribution.Simple
-- | Standard Cabal setup script (the package uses build-type: Simple).
main :: IO ()
main = defaultMain
3 |
--------------------------------------------------------------------------------
/stack.yaml:
--------------------------------------------------------------------------------
1 | resolver: lts-20.23
2 | packages:
3 | - '.'
4 | extra-deps:
5 | - github: lamdu/hypertypes
6 | commit: 354ac7be8dc1b15a8df9b94ff010054ec2b31da8
7 |
--------------------------------------------------------------------------------
/hie.yaml:
--------------------------------------------------------------------------------
1 | cradle:
2 | stack:
3 | - path: "./test"
4 | component: "lamdu-calculus:test:lamdu-calculus-test"
5 |
6 | - path: "./src"
7 | component: "lamdu-calculus:lib"
8 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /dist/
2 | .stack-work/
3 | .stack-work-profile/
4 | /tags
5 | /ghci-out
6 | .ghcid
7 | .vscode/
8 | *.o
9 | *.hi
10 | *.dyn_o
11 | *.dyn_hi
12 | *.prof
13 | stack.yaml.lock
14 | /dumps/
15 | .DS_Store
16 |
--------------------------------------------------------------------------------
/.ghci:
--------------------------------------------------------------------------------
1 | :set -isrc -itest
2 | :set -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints
3 | :set -odir=ghci-out
4 | :set -outputdir=ghci-out
5 | :set -hide-package base-compat-batteries
6 |
--------------------------------------------------------------------------------
/src/Lamdu/Calc/Internal/Prelude.hs:
--------------------------------------------------------------------------------
1 | module Lamdu.Calc.Internal.Prelude
2 | ( module X
3 | ) where
4 |
5 | import Control.DeepSeq as X (NFData(..))
6 | import Control.Lens.Operators as X
7 | import Control.Monad as X (void, guard, join)
8 | import Data.Binary as X (Binary)
9 | import Data.ByteString as X (ByteString)
10 | import Data.Hashable as X (Hashable)
11 | import Data.Map as X (Map)
12 | import Data.Set as X (Set)
13 | import Data.String as X (IsString(..))
14 | import Prelude.Compat as X
15 |
--------------------------------------------------------------------------------
/test/test.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE OverloadedStrings, TypeFamilyDependencies #-}
2 |
3 | import Control.Lens (at)
4 | import Control.Lens.Operators
5 | import Data.Set (fromList)
6 | import Hyper
7 | import Hyper.Syntax.Scheme
8 | import Lamdu.Calc.Infer
9 | import Lamdu.Calc.Type
10 | import Test.Framework
11 | import Test.Framework.Providers.HUnit (testCase)
12 | import Test.HUnit (assertBool)
13 |
-- | Two schemes that differ only in the payload of the @t@ alternative
-- (the quantified type variable @a@ versus the empty record) must NOT be
-- considered alpha-equivalent. The assertion message is what is printed
-- on failure, so it states the expectation that was violated.
alphaEqTest :: Test
alphaEqTest =
    alphaEq (f (TVarP "a")) (f (TRecordP REmptyP))
    & not
    & assertBool "should not alpha eq"
    & testCase "alpha-eq"
    where
        -- Scheme quantifying @a@ and the row variable @c@ (with the tag set
        -- {"t"} in c's RowConstraints) over: a -> +{ t : x | c }
        f x =
            TFunP (TVarP "a") (TVariantP (RExtendP "t" x (RVarP "c"))) ^. hPlain
            & Scheme
            ( Types
                (QVars (mempty & at "a" ?~ mempty))
                (QVars (mempty & at "c" ?~ RowConstraints (fromList ["t"]) mempty))
            )
            & Pure
28 |
-- | Run the test suite.
main :: IO ()
main = [alphaEqTest] & defaultMain
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright Author name here (c) 2016
2 |
3 | All rights reserved.
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | * Redistributions of source code must retain the above copyright
9 | notice, this list of conditions and the following disclaimer.
10 |
11 | * Redistributions in binary form must reproduce the above
12 | copyright notice, this list of conditions and the following
13 | disclaimer in the documentation and/or other materials provided
14 | with the distribution.
15 |
16 | * Neither the name of Author name here nor the names of other
17 | contributors may be used to endorse or promote products derived
18 | from this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/src/Lamdu/Calc/Definition.hs:
--------------------------------------------------------------------------------
1 | -- | A lamdu-calculus `Definition` is a top-level definition of a
2 | -- value along with the types of all free variables/nominals it uses
3 | {-# LANGUAGE NoImplicitPrelude, TemplateHaskell, DeriveGeneric, TypeOperators #-}
4 |
5 | module Lamdu.Calc.Definition
6 | ( Deps(..), depsGlobalTypes, depsNominals
7 | , pruneDeps
8 | ) where
9 |
10 | import Hyper (Ann, Pure, Const(..), Generic, HFunctor(..), hflipped, type (#))
11 | import Hyper.Syntax.Nominal (NominalDecl)
12 | import qualified Control.Lens as Lens
13 | import qualified Data.Map as Map
14 | import qualified Data.Set as Set
15 | import Lamdu.Calc.Lens (valGlobals, valNominals)
16 | import qualified Lamdu.Calc.Term as V
17 | import Lamdu.Calc.Type (Type)
18 | import qualified Lamdu.Calc.Type as T
19 |
20 | import Lamdu.Calc.Internal.Prelude
21 |
-- | The free dependencies of a definition's value: the type schemes of
-- the global variables it references and the declarations of the nominal
-- types it uses (see 'pruneDeps' for restricting them to a given term).
-- NOTE: field order determines the derived Ord instance and the
-- Generic-based Binary encoding, so do not reorder the fields.
data Deps = Deps
    { _depsGlobalTypes :: !(Map V.Var (Pure # T.Scheme))
      -- ^ Type scheme of each referenced global variable
    , _depsNominals :: !(Map T.NominalId (Pure # NominalDecl Type))
      -- ^ Declaration of each referenced nominal type
    } deriving (Generic, Show, Eq, Ord)
-- Both instances use their Generic-based default implementations.
instance NFData Deps
instance Binary Deps

-- Generates the 'depsGlobalTypes' and 'depsNominals' lenses.
Lens.makeLenses ''Deps
30 |
-- | Field-wise union of both maps. Like 'Map.union', entries of the left
-- operand win on key collisions.
instance Semigroup Deps where
    Deps globalsL nominalsL <> Deps globalsR nominalsR =
        Deps
        { _depsGlobalTypes = globalsL <> globalsR
        , _depsNominals = nominalsL <> nominalsR
        }

-- | No global types and no nominals. 'mappend' keeps its default, '(<>)'.
instance Monoid Deps where
    mempty = Deps { _depsGlobalTypes = mempty, _depsNominals = mempty }
36 |
-- | Restrict the dependencies to those referenced by the given term:
-- keep only the global types of the variables collected by
-- @valGlobals mempty@ and the nominal declarations collected by
-- 'valNominals'.
pruneDeps ::
    Ann a # V.Term -> Deps -> Deps
pruneDeps e deps =
    deps
    & depsGlobalTypes %~ prune (valGlobals mempty)
    & depsNominals %~ prune valNominals
    where
        -- Copy of the term with every annotation replaced by @Const ()@;
        -- the folds are run over this copy.
        ev = e & hflipped %~ hmap (\_ _ -> Const ())
        -- Build the set of referenced keys once and restrict the map to it
        -- (same result as filtering each key with Set.member).
        prune f m = Map.restrictKeys m (Set.fromList (ev ^.. f))
46 |
--------------------------------------------------------------------------------
/doc/Optimization.md:
--------------------------------------------------------------------------------
1 | # The process for optimizing the inference with SPECIALIZATION
2 |
3 | This document is work-in-progress! I need help!
4 |
5 | Run the benchmark while dumping GHC core:
6 |
7 | * If the code has not changed since the last build, force a rebuild with `stack clean lamdu-calculus`
8 | * `stack bench --ghc-options "-dumpdir dumps -ddump-simpl -dsuppress-coercions -dsuppress-idinfo -dsuppress-module-prefixes -dsuppress-timestamps"`
9 | * This generates `.dump-simpl` files in the `dumps` folder
10 | * The top-level file of interest is `dumps/test/benchmark.dump-simpl`
11 | * In `dumps/src/Lamdu/Calc/` there are dump files for `Infer` and `Term` which are also of interest
12 |
13 | ## Searching for type applications
14 |
15 | The type applications may look like:
16 |
17 | * `pruneDeps1 @ ()`
18 | * `emptyScope @ ('AHyperType UVar)`
19 | * `$fNFDataTypeError_$crnf @ ('AHyperType Pure)`
20 |
21 | Now we need to see which of those are benign.
22 | If the type of the definition does not have class constraints on these variables, such as
23 |
24 | pruneDeps :: Ann a # V.Term -> Deps -> Deps
25 |
26 | Then this is a benign type application (i.e. there is no benefit from adding a `SPECIALIZE` pragma for it).
27 | Other applications may not be benign but are not significant,
28 | for example if they are only called one time when the inference results in an error, but not in inner loops,
29 | such as the `NFData` instance above.
30 |
31 | We want to add `SPECIALIZE` pragmas to significant unspecialized (using type applications) calls.
32 |
33 | We can find such type applications using the `tools/core-type-apps.py` script and by manually searching for important functions in the code (the work-in-progress script isn't fully functional).
34 |
35 | Note that specializing one function sometimes appears to cause GHC not to use the specialized version of an inner call, due to type family confusion.
36 |
--------------------------------------------------------------------------------
/package.yaml:
--------------------------------------------------------------------------------
1 | name: lamdu-calculus
2 | version: 0.2.0.1
3 | github: "lamdu/lamdu-calculus"
4 | license: BSD3
5 | author: "Yair Chuchem, Eyal Lotem"
6 | maintainer: "yairchu@gmail.com"
7 | copyright: "2021 Yair Chuchem, Eyal Lotem"
8 |
9 | extra-source-files:
10 | - README.md
11 |
12 | synopsis: The Lamdu Calculus programming language
13 | category: Language
14 |
15 | description: Please see README.md
16 |
17 | dependencies:
18 | - base >= 4.7
19 | - base-compat >= 0.8.2
20 | - bytestring
21 | - containers
22 | - hypertypes >= 0.2
23 | - lens >= 4.1
24 |
25 | ghc-options:
26 | - -fexpose-all-unfoldings
27 | - -Wall
28 | - -Wnoncanonical-monad-instances
29 | - -Wcompat
30 | - -Wincomplete-record-updates
31 | - -Wincomplete-uni-patterns
32 | - -Wredundant-constraints
33 | - -Wunused-packages
34 | - -fdicts-cheap
35 | - -O2
36 | - -fspecialise-aggressively
37 |
38 | ghc-prof-options:
39 | - -O2
40 |
41 | library:
42 | source-dirs: src
43 | other-modules:
44 | - Lamdu.Calc.Internal.Prelude
45 | dependencies:
46 | - base16-bytestring
47 | - binary
48 | - deepseq
49 | - generic-constraints
50 | - generic-data
51 | - hashable
52 | - lattices
53 | - monad-st
54 | - mtl
55 | - pretty >= 1.1.2
56 | - transformers
57 |
58 | tests:
59 | lamdu-calculus-test:
60 | main: test.hs
61 | source-dirs: test
62 | dependencies:
63 | - HUnit
64 | - lamdu-calculus
65 | - test-framework
66 | - test-framework-hunit
67 |
68 | benchmarks:
69 | lamdu-calculus-bench:
70 | main: benchmark.hs
71 | source-dirs: test
72 | ghc-options:
73 | - -O2
74 | - -Wall
75 | - -Wnoncanonical-monad-instances
76 | - -Wcompat
77 | - -Wincomplete-record-updates
78 | - -Wincomplete-uni-patterns
79 | - -Wredundant-constraints
80 | dependencies:
81 | - criterion
82 | - deepseq
83 | - monad-st
84 | - mtl
85 | - lamdu-calculus
86 | - transformers
87 |
--------------------------------------------------------------------------------
/src/Lamdu/Calc/Term/Eq.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE NoImplicitPrelude, TypeApplications, TypeOperators, FlexibleInstances #-}
2 |
3 | module Lamdu.Calc.Term.Eq
4 | ( couldEq
5 | ) where
6 |
7 | import Hyper
8 | import Hyper.Class.ZipMatch
9 | import Hyper.Type.Prune
10 | import qualified Control.Lens as Lens
11 | import qualified Data.Map as Map
12 | import Lamdu.Calc.Term
13 | import qualified Lamdu.Calc.Type as T
14 |
15 | import Lamdu.Calc.Internal.Prelude
16 |
-- | Correspondence between the lambda-bound variables of the two terms
-- being compared. The first map sends each left-hand binder to the
-- right-hand binder introduced by the matching lambda; the second map is
-- its inverse. Keeping both directions lets 'varsCorrespond' detect
-- shadowing on either side.
data Binders = Binders !(Map Var Var) !(Map Var Var)

-- | Record that @x@ (left) and @y@ (right) are bound by matching lambdas,
-- overriding any outer binding of either name.
addBinderPair :: Var -> Var -> Binders -> Binders
addBinderPair x y (Binders leftToRight rightToLeft) =
    Binders (Map.insert x y leftToRight) (Map.insert y x rightToLeft)

-- | Whether the reference @x@ (left) denotes the same binding as @y@
-- (right): either both are free and share a name, or both were bound by
-- the same pair of matching lambdas.
varsCorrespond :: Binders -> Var -> Var -> Bool
varsCorrespond (Binders leftToRight rightToLeft) x y =
    case (Map.lookup x leftToRight, Map.lookup y rightToLeft) of
        (Nothing, Nothing) -> x == y
        (Just y', Just x') -> y' == y && x' == x
        _ -> False -- bound on one side only

-- | Structural comparison of two pure values, relative to the variables
-- bound so far.
class CouldEq e where
    go :: Binders -> Pure # e -> Pure # e -> Bool

instance CouldEq Term where
    go env (Pure (BLam (TypedLam xvar xtyp xresult))) (Pure (BLam (TypedLam yvar ytyp yresult))) =
        go env xtyp ytyp &&
        go (addBinderPair xvar yvar env) xresult yresult
    go env (Pure (BLeaf (LVar x))) (Pure (BLeaf (LVar y))) =
        -- Must precede the generic clause: equal names are not enough,
        -- both references have to resolve to corresponding binders.
        varsCorrespond env x y
    go env (Pure xBody) (Pure yBody) =
        case join (zipMatch_ (Proxy @CouldEq #> \x -> guard . go env x) xBody yBody) of
            Just () -> True
            Nothing ->
                -- A hole on either side matches anything.
                case (xBody, yBody) of
                    (BLeaf LHole, _) -> True
                    (_, BLeaf LHole) -> True
                    _ -> False

-- | A pruned type on either side matches any type.
instance CouldEq (HCompose Prune T.Type) where
    go _ (Pure (HCompose Pruned)) _ = True
    go _ _ (Pure (HCompose Pruned)) = True
    go env (Pure xBody) (Pure yBody) =
        Lens.has Lens._Just
        (join (zipMatch_ (Proxy @CouldEq #> \x -> guard . go env x) xBody yBody))

-- | A pruned row on either side matches any row.
instance CouldEq (HCompose Prune T.Row) where
    go _ (Pure (HCompose Pruned)) _ = True
    go _ _ (Pure (HCompose Pruned)) = True
    go env (Pure xBody) (Pure yBody) =
        Lens.has Lens._Just
        (join (zipMatch_ (Proxy @CouldEq #> \x -> guard . go env x) xBody yBody))

-- | Whether two terms could be equal. They must be alpha-equivalent,
-- except that a hole on either side matches any subterm and a pruned
-- type or row matches anything.
couldEq :: Pure # Term -> Pure # Term -> Bool
couldEq = go (Binders Map.empty Map.empty)
50 |
--------------------------------------------------------------------------------
/src/Lamdu/Calc/Identifier.hs:
--------------------------------------------------------------------------------
1 | -- | This module defines the 'Identifier' type and is meant to be
2 | -- cheaply importable without creating import cycles.
3 | {-# LANGUAGE NoImplicitPrelude, DeriveGeneric, GeneralizedNewtypeDeriving, DerivingStrategies #-}
4 | module Lamdu.Calc.Identifier
5 | (
6 | -- * Identifier type
7 | Identifier(..)
8 | -- * Hex representation of identifier bytes
9 | , identHex, identFromHex
10 | -- ** Laws
11 |
12 | -- |
13 | -- > identFromHex . identHex == Right
14 | ) where
15 |
16 | import qualified Data.ByteString.Base16 as Hex
17 | import qualified Data.ByteString.Char8 as BS
18 | import qualified Data.Char as Char
19 | import GHC.Generics (Generic)
20 | import qualified Text.PrettyPrint as PP
21 | import Text.PrettyPrint.HughesPJClass (Pretty(..))
22 |
23 | import Lamdu.Calc.Internal.Prelude
24 |
-- | A low-level identifier data-type, wrapping raw bytes. Used to
-- identify variables, type variables, tags and more.
newtype Identifier = Identifier ByteString
    deriving stock (Generic, Show)
    deriving newtype (Eq, Ord, Binary, Hashable)

instance NFData Identifier

-- | Shows the identifier's characters when all of them are printable;
-- otherwise falls back to its hex form (see 'identHex').
instance Pretty Identifier where
    pPrint ident@(Identifier bytes)
        | any (not . Char.isPrint) chars = PP.text (identHex ident)
        | otherwise = PP.text chars
        where
            chars = BS.unpack bytes

-- | Goes through the underlying @ByteString.Char8@ instance, so use it
-- only with Latin1 strings.
instance IsString Identifier where
    fromString = Identifier . fromString
38 |
-- | Render the identifier's bytes as a hex string, two hex digits per
-- byte.
--
-- > > identHex (Identifier "a1")
-- > "6131"
identHex :: Identifier -> String
identHex (Identifier bytes) = BS.unpack (Hex.encode bytes)
45 |
-- | Parse a hex string (e.g: one generated by 'identHex') back into an
-- 'Identifier'.
--
-- Returns 'Left' with an error message when the input is not valid hex.
-- Non-ASCII characters are rejected up front: 'BS.pack' would otherwise
-- truncate each one to its low byte, so non-hex input could silently
-- decode into a wrong identifier.
--
-- > > identFromHex "6131"
-- > Right (Identifier "a1")
identFromHex :: String -> Either String Identifier
identFromHex str
    | all Char.isAscii str = BS.pack str & Hex.decode <&> Identifier
    | otherwise = Left "identFromHex: non-ASCII character in hex string"
53 |
--------------------------------------------------------------------------------
/test/benchmark.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE NoImplicitPrelude, OverloadedStrings, ScopedTypeVariables, FlexibleContexts, TypeOperators #-}
2 |
import Hyper
import Hyper.Infer (infer, inferResult)
import Hyper.Recurse (wrap)
import Hyper.Unify (UVarOf, UnifyGen, applyBindings)
import Control.DeepSeq (rnf)
import Control.Exception (evaluate)
import Control.Lens (ASetter')
import qualified Control.Lens as Lens
import Control.Lens.Operators
import Control.Monad.Reader
import Control.Monad.ST.Class (MonadST(..))
import Control.Monad.Trans.Maybe (MaybeT(..))
import Criterion (Benchmarkable, nf, whnfIO)
import Criterion.Main (bench, defaultMain)
import Data.STRef (newSTRef)
import Lamdu.Calc.Definition (pruneDeps)
import Lamdu.Calc.Infer
import Lamdu.Calc.Term (Term, Scope, emptyScope)
import qualified Lamdu.Calc.Type as T
import TestVals

import Prelude.Compat
25 |
-- | Run @action@ with the scope found at @inferEnv@ modified by the
-- result of 'loadDeps' applied to the part of 'allDeps' that the term
-- @e@ references (see 'pruneDeps').
localInitEnv ::
    ( MonadReader env m
    , UnifyGen m T.Type
    , UnifyGen m T.Row
    ) =>
    ASetter' env (Scope # UVarOf m) -> Ann z # Term -> m a -> m a
localInitEnv inferEnv e action =
    loadDeps (pruneDeps e allDeps) >>= \extendScope ->
    Lens.locally inferEnv extendScope action
36 |
-- | Convert a plain term into one annotated with @Const ()@ at every node.
toAnn :: HPlain Term -> Ann (Const ()) # Term
toAnn plain = plain ^. hPlain & wrap (\_witness node -> Ann (Const ()) node)
39 |
-- | Benchmark pure ('runPureInfer') inference of the given term, fully
-- evaluating the result.
--
-- The inference is the function handed to criterion's 'nf', which applies
-- it afresh on every iteration. Forcing one pre-built thunk with
-- @evaluate@ under 'whnfIO' would share that thunk across iterations, so
-- after the first run there would be no work left to measure (unless the
-- optimizer happened to duplicate it).
benchInferPure :: HPlain Term -> Benchmarkable
benchInferPure e =
    nf inferPure (toAnn e)
    where
        inferPure x =
            infer x
            <&> (^. hAnn . Lens._2 . inferResult)
            >>= applyBindings
            & localInitEnv id x
            & runPureInfer emptyScope (InferState emptyPureInferState varGen)
            & Lens._Right %~ (^. Lens._1)
53 |
-- | Benchmark inference of the given term in the ST-based 'STInfer' monad.
-- Each iteration re-runs the whole ST action (a fresh name-generator
-- reference, then inference), then fully evaluates the result.
benchInferST :: HPlain Term -> Benchmarkable
benchInferST e =
    whnfIO (liftST inferInST >>= evaluate . rnf)
    where
        x = toAnn e
        inferAction =
            localInitEnv Lens._1 x
            (infer x <&> (^. hAnn . Lens._2 . inferResult) >>= applyBindings)
        inferInST =
            newSTRef varGen >>= \genRef ->
            runMaybeT (runReaderT (inferAction ^. _STInfer) (emptyScope, genRef))
65 |
-- | All benchmarks: every test value under the ST-based runner ("S_"
-- prefix), then under the pure runner ("P_" prefix).
benches :: [(String, Benchmarkable)]
benches =
    [ (label, runBench term)
    | (label, runBench, term) <-
        [ ("S_factorial", benchInferST, factorialVal)
        , ("S_euler1", benchInferST, euler1Val)
        , ("S_solveDepressedQuartic", benchInferST, solveDepressedQuarticVal)
        , ("S_factors", benchInferST, factorsVal)
        , ("P_factorial", benchInferPure, factorialVal)
        , ("P_euler1", benchInferPure, euler1Val)
        , ("P_solveDepressedQuartic", benchInferPure, solveDepressedQuarticVal)
        , ("P_factors", benchInferPure, factorsVal)
        ]
    ]
77 |
-- | Run every entry of 'benches' under criterion's default driver.
main :: IO ()
main = defaultMain (map (uncurry bench) benches)
80 |
--------------------------------------------------------------------------------
/tools/core-type-apps.py:
--------------------------------------------------------------------------------
def normalize(x):
    """Flatten `x` onto one line.

    Strips surrounding whitespace, turns newlines into spaces and collapses
    each run of spaces into a single space. Splitting on ' ' and dropping
    the empty pieces always terminates, unlike a replace-until-stable loop.
    """
    flat = x.strip().replace('\n', ' ')
    return ' '.join(piece for piece in flat.split(' ') if piece)
6 |
def take_word():
    """Pop the next Core "word" off the front of the global buffer `c`.

    A word ends at a space, ',' or ')' that is not nested inside (...) or
    [...]; the delimiter stays in `c`. If no delimiter is found the word is
    the whole buffer. Returns the normalized word; `c` is left holding the
    rest, with surrounding whitespace stripped.
    """
    global c
    stack = []
    end = len(c)  # no delimiter found: the word runs to the end of the buffer
    for i, x in enumerate(c):
        if stack and x == stack[-1]:
            stack.pop()
        elif x == '(':
            stack.append(')')
        elif x == '[':
            stack.append(']')
        elif not stack and x in ' ),':
            # Only top-level delimiters end the word, so e.g. "(a, b)" stays whole.
            end = i
            break
    r = normalize(c[:end])
    c = c[end:].strip()
    return r
27 |
def params():
    """Consume a run of type arguments from the global buffer `c`.

    Takes one word, then one more word after each '@' that follows, and
    returns them all as a list. `c` keeps whatever follows the last word.
    """
    global c
    words = [take_word()]
    while c[:1] == '@':
        c = c[1:].strip()
        words.append(take_word())
    return words
35 |
def last_word(d):
    """Return the last bracket-balanced Core "word" at the end of `d`.

    Scans backwards. A '(' always ends the scan; a space ends it only when
    outside any (...) or [...] group. Returns the normalized text after the
    stopping character, or all of `d` if no such character exists.
    """
    stack = []
    start = 0  # no delimiter found: the word spans all of d
    for i in range(len(d) - 1, -1, -1):
        x = d[i]
        if stack and x == stack[-1]:
            stack.pop()
        elif x == ')':
            stack.append('(')
        elif x == ']':
            stack.append('[')
        elif x == '(' or (x == ' ' and not stack):
            start = i + 1
            break
    return normalize(d[start:])
53 |
# Names whose type applications are never reported: the scan loop skips
# every name in `ok`. `cant` presumably lists methods that cannot usefully
# be specialized (TODO confirm); it is merged into `ok`.
cant = set("""
>>=
<*>
<$
fmap
pure
leq
liftA2
""".split())

ok = set("""
$fMonadPureInfer_$s$fMonadRWST_$c>>=
$fNFDataTypeError_$crnf
$fNFDataType_$crnf
$fOrdVar
$fUnifyPureInferType_$cunifyError
$fRecursiveUnify_$crecurse
[]
:
++
absentError
emptyScope
error
heq_sel
map
newMutVar#
pruneDeps1
readMutVar#
runMainIO1
rwhnf
seq#
unifyError
whnfIO'
writeMutVar#
""".split())
ok |= cant
90 |
# Scan the simplifier dumps for type applications of names to fully
# concrete types (candidates for SPECIALIZE pragmas) and print each
# distinct one as "name @T1 @T2 ...".
apps = set()
for path in ['test/benchmark', 'src/Lamdu/Calc/Infer', 'src/Lamdu/Calc/Term']:
    # `with` closes the dump file; only the part before the rules section is scanned.
    with open('dumps/' + path + '.dump-simpl') as dump:
        c = dump.read().split('------ Local rules for imported ids --------', 1)[0]
    while '@' in c:
        pre, post = c.split('@', 1)
        c = post.strip()
        if pre.strip().endswith('\\'):
            # Lambda. Skip type parameters
            params()
            continue
        if post.startswith('~'):
            # Not a type application
            continue
        var = last_word(pre.strip())
        while var.startswith('('):
            var = var[1:]
        p = params()
        if var[:1].isupper():
            # Skip data constructors
            continue
        if var in ok:
            continue
        words = []
        for param in p:
            words += param.replace('(', ' ').replace('[', ' ').split()
        if any(w[:1].islower() or w in ['RealWorld', 'Any'] for w in words):
            # Skip specializations with type variables,
            # find only the top-level specializations
            continue
        apps.add((var, tuple(p)))

for var, p in sorted(apps):
    print(' '.join([var] + ['@' + x for x in p]))
124 |
--------------------------------------------------------------------------------
/doc/ExceptionMonad.md:
--------------------------------------------------------------------------------
1 | # Exception monad in Lamdu Calculus
2 |
3 | In Haskell, exception monads are not used for most code that can throw
4 | exceptions. Usually, their use is localized where the domain of known
5 | errors is fixed.
6 |
7 | The reason for this, we believe, is the type of bind in Haskell's exception
8 | monad:
9 |
10 | `(>>=) :: Either err a -> (a -> Either err b) -> Either err b`
11 |
12 | Note that `err` must be the same type in both sides of the bind - even
13 | if the kinds of errors thrown by the two bound actions are vastly
14 | different.
15 |
16 | Haskell's sum (variant) types are also fully nominal, meaning that one
17 | must manually declare all possible errors in all bound actions in a
18 | single sum type, to use that as the error type.
19 |
20 | This inhibits use of Haskell exception monads, and instead, much more
21 | code uses dynamically typed exceptions (`SomeException` and the
22 | `Exception` class).
23 |
24 | However, with extensible variant types and Lamdu Calculus's powerful
25 | case expressions, we can do better.
26 |
27 | ## Throwing errors
28 |
29 | ```
30 | div x y =
31 | case y == 0 of
32 | #True () -> #Error (#DivByZero ())
33 | #False () -> #Success (x / y)
34 | ```
35 |
36 | The inferred type of `div` is:
37 | ```
38 | forall err. #DivByZero ∉ err =>
39 | Int -> Int -> +{ #Error : +{ #DivByZero : () | err }, #Success : Int }
40 | ```
41 |
42 | Note that the only possible error is inferred, and we need not
43 | manually declare an error type.
44 |
45 | We also declare `fromJust` to extract a value from a `Maybe` type:
46 |
47 | ```
48 | fromJust = \case
49 | #Just x -> #Success x
50 | #Nothing () -> #Error (#NotJust ())
51 | ```
52 | which has type:
53 | ```
54 | forall a err. #NotJust ∉ err =>
55 | +{ #Just : a, #Nothing : () } -> +{ #Success : a, #Error : +{ #NotJust : () | err } }
56 | ```
57 |
58 | If we bind the two actions together (using an ordinary exception monad):
59 |
60 | `\v -> fromJust v >>= div 1000`
61 |
62 | The resulting type is automatically inferred to be:
63 | ```
64 | forall err. {#NotJust, #DivByZero} ∉ err =>
65 | +{ #Nothing : (), #Just : Int } ->
66 | +{ #Error : +{#NotJust : (), #DivByZero : () | err}
67 | , #Success : Int
68 | }
69 | ```
70 |
71 | The automatically inferred set of errors is visibly available in the type.
72 |
73 | ## Catching errors
74 |
75 | We can declare an error catcher:
76 |
77 | ```
78 | catch handler =
79 | \case
80 | #Success v -> #Success v
81 | #Error -> handler
82 | ```
83 |
84 | And then we can catch just the `#DivByZero` error above:
85 |
86 | ```
87 | \v ->
88 | catch (\case
89 | #DivByZero -> #Success 0
90 | err -> #Error err)
91 | (fromJust v >>= div 1000)
92 | ```
93 |
94 | The inferred type then becomes:
95 | ```
96 | forall err. {#NotJust} ∉ err =>
97 | +{ #Nothing : (), #Just : Int } ->
98 | +{ #Error : +{#NotJust : () | err}
99 | , #Success : Int
100 | }
101 | ```
102 |
103 | Once all errors are caught, the result type becomes:
104 |
105 | `forall err. +{ #Error : +{| err}, #Success : ... }`
106 |
107 | We can then eliminate the variant wrapper with:
108 |
109 | ```
110 | successful : forall err a. +{ #Error : +{| err}, #Success : a } -> a
111 | successful = \case
112 | #Error -> absurd
113 | #Success -> id
114 | ```
115 |
116 | This instantiates `err` to be the empty variant type (Void), and is thus
117 | allowed to eliminate the #Error/#Success wrapper to yield an `a`.
118 |
--------------------------------------------------------------------------------
/lamdu-calculus.cabal:
--------------------------------------------------------------------------------
1 | cabal-version: 1.12
2 |
3 | -- This file has been generated from package.yaml by hpack version 0.35.2.
4 | --
5 | -- see: https://github.com/sol/hpack
6 |
7 | name: lamdu-calculus
8 | version: 0.2.0.1
9 | synopsis: The Lamdu Calculus programming language
10 | description: Please see README.md
11 | category: Language
12 | homepage: https://github.com/lamdu/lamdu-calculus#readme
13 | bug-reports: https://github.com/lamdu/lamdu-calculus/issues
14 | author: Yair Chuchem, Eyal Lotem
15 | maintainer: yairchu@gmail.com
16 | copyright: 2021 Yair Chuchem, Eyal Lotem
17 | license: BSD3
18 | license-file: LICENSE
19 | build-type: Simple
20 | extra-source-files:
21 | README.md
22 |
23 | source-repository head
24 | type: git
25 | location: https://github.com/lamdu/lamdu-calculus
26 |
27 | library
28 | exposed-modules:
29 | Lamdu.Calc.Definition
30 | Lamdu.Calc.Identifier
31 | Lamdu.Calc.Infer
32 | Lamdu.Calc.Lens
33 | Lamdu.Calc.Term
34 | Lamdu.Calc.Term.Eq
35 | Lamdu.Calc.Type
36 | other-modules:
37 | Lamdu.Calc.Internal.Prelude
38 | hs-source-dirs:
39 | src
40 | ghc-options: -fexpose-all-unfoldings -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints -Wunused-packages -fdicts-cheap -O2 -fspecialise-aggressively
41 | ghc-prof-options: -O2
42 | build-depends:
43 | base >=4.7
44 | , base-compat >=0.8.2
45 | , base16-bytestring
46 | , binary
47 | , bytestring
48 | , containers
49 | , deepseq
50 | , generic-constraints
51 | , generic-data
52 | , hashable
53 | , hypertypes >=0.2
54 | , lattices
55 | , lens >=4.1
56 | , monad-st
57 | , mtl
58 | , pretty >=1.1.2
59 | , transformers
60 | default-language: Haskell2010
61 |
62 | test-suite lamdu-calculus-test
63 | type: exitcode-stdio-1.0
64 | main-is: test.hs
65 | other-modules:
66 | TestVals
67 | Paths_lamdu_calculus
68 | hs-source-dirs:
69 | test
70 | ghc-options: -fexpose-all-unfoldings -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints -Wunused-packages -fdicts-cheap -O2 -fspecialise-aggressively
71 | ghc-prof-options: -O2
72 | build-depends:
73 | HUnit
74 | , base >=4.7
75 | , base-compat >=0.8.2
76 | , bytestring
77 | , containers
78 | , hypertypes >=0.2
79 | , lamdu-calculus
80 | , lens >=4.1
81 | , test-framework
82 | , test-framework-hunit
83 | default-language: Haskell2010
84 |
85 | benchmark lamdu-calculus-bench
86 | type: exitcode-stdio-1.0
87 | main-is: benchmark.hs
88 | other-modules:
89 | TestVals
90 | Paths_lamdu_calculus
91 | hs-source-dirs:
92 | test
93 | ghc-options: -fexpose-all-unfoldings -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints -Wunused-packages -fdicts-cheap -O2 -fspecialise-aggressively -O2 -Wall -Wnoncanonical-monad-instances -Wcompat -Wincomplete-record-updates -Wincomplete-uni-patterns -Wredundant-constraints
94 | ghc-prof-options: -O2
95 | build-depends:
96 | base >=4.7
97 | , base-compat >=0.8.2
98 | , bytestring
99 | , containers
100 | , criterion
101 | , deepseq
102 | , hypertypes >=0.2
103 | , lamdu-calculus
104 | , lens >=4.1
105 | , monad-st
106 | , mtl
107 | , transformers
108 | default-language: Haskell2010
109 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # lamdu-calculus
2 |
3 | An extended typed Lambda Calculus AST.
4 |
5 | Used by [Lamdu](http://www.lamdu.org/) as the underlying unsugared language.
6 |
7 | The Lamdu Calculus language is a statically typed polymorphic
8 | non-dependent lambda calculus extension.
9 |
10 | The language does not include a parser, as it is meant to be used
11 | structurally via its AST directly. It does include a pretty-printer,
12 | for ease of debugging.
13 |
14 | Type inference for Lamdu Calculus is implemented in this package's
15 | `Lamdu.Calc.Infer` module, built on the [hypertypes](https://github.com/lamdu/hypertypes) library.
16 |
17 | ## Values
18 |
19 | The term language supports lambda abstractions and variables, literal
20 | values of custom types, holes, nominal type wrapping/unwrapping,
21 | extensible variant types and extensible records.
22 |
23 | ## Types
24 |
25 | The type language includes nominal types with polymoprhic components,
26 | extensible record types (row types) and extensible variant types
27 | (polymorphic variants / column types).
28 |
29 | ## Composite types
30 |
31 | In Lamdu Calculus, both records and variants are "composite" types and
32 | share an underlying AST, using a phantom type to distinguish them.
33 |
34 | Records are sets of typed and named fields.
35 |
36 | Variants are a value of a set of possible typed data constructors.
37 |
38 | Both records and variants are order-agnostic, and disallow repeated
39 | appearance of the same names for fields or data constructors.
40 |
41 | Both records and variants are extensible. This is usually called "row
42 | types" for records and "polymorphic variants" for variants. We use the
43 | terminology of "row" and "column" types instead.
44 |
45 | ## Records
46 |
47 | In the term language, records are constructed using these AST constructions:
48 |
49 | `BLeaf LRecEmpty` denotes the empty record (denoted in pseudo-syntax as `()`).
50 | Its type is `TRecord REmpty`.
51 |
52 | `BRecExtend (RowExtend tag (value : T) (rest : R))`1 denotes a record extension
53 | of an existing record `rest`.
54 |
55 | Its type is `TRecord (RExtend tag T R)`.
56 |
57 | A sequence of `BRecExtend` followed by a `BLeaf LRecEmpty` is denoted
58 | in pseudo-syntax as `{ x : T, y : U, ... }`.
59 |
60 | To deconstruct a record, we use the `GetField` operation (denoted in
61 | pseudo-syntax as `.`).
62 |
63 | 1. The `(value : Type)` notation is used to denote values' types
64 |
65 | ### Row types
66 |
67 | Row types are denoted using a record type variable.
68 |
69 | For example, let's examine a lambda abstraction. Assume `(+) : Int →
70 | Int → Int`.
71 |
72 | ```
73 | vector → vector.x + vector.y
74 | ```
75 |
76 | We can infer the type:
77 |
78 | ```Haskell
79 | { x : Int, y : Int } → Int
80 | ```
81 |
82 | But the most general type is:
83 |
84 | ```Haskell
85 | forall r1. {x, y} ∉ r1 => { x : Int, y : Int | r1 } → Int
86 | ```
87 | 2
88 |
89 | The lambda may be applied with a record that contains more fields
90 | besides `x` and `y`, and that is what the `| r1` denotes. Note that to
91 | enforce the no-duplication requirement Lamdu Calculus uses a
92 | constraint on composite type variables, for each field that they may
93 | not duplicate.
94 |
95 | <sup>2</sup> The `|` symbol in the pseudo-syntax indicates
96 | that the set of fields denoted by the variable that follows (in this
97 | case, `r1`) is concatenated to the set.
98 |
99 | ## Variants
100 |
101 | In the term language, variants are constructed using value *injection*.
102 |
103 | For example, let's look at the type:
104 |
105 | ```Haskell
106 | +{ Nothing : (), Just : Int }
107 | ```
108 |
109 | This pseudo syntax means the structural variant type isomorphic to
110 | Haskell's `Maybe Int`.
111 |
112 | The value `5 : Int` does not have the type `+{ Nothing : (), Just : Int }`.
113 |
114 | In Haskell, we use `Just` to "lift" 5 from `Int` to `Maybe Int`. In
115 | Lamdu Calculus, we inject via `BInject (Inject "Just" 5)`, which we denote in
116 | pseudo-syntax as `Just: 5`.
117 |
118 | `BInject (Inject tag (value : T))` *injects* a given value into a
119 | variant type. Its type is `forall alts. tag ∉ alts => TVariant (CExtend tag T alts)`.
120 |
121 | This type means that the injected value allows any set of typed
122 | alternatives to exist in the larger variant.
123 |
124 | ### Case expressions
125 |
126 | Deconstructing a variant requires a case statement. Lamdu Calculus
127 | case statements *peel* one alternative at a time, so they need to be
128 | composed to create an ordinary, full case statement.
129 |
130 | `BCase (Case tag handler rest)` denoted in pseudo syntax as
131 | ```Haskell
132 | \case
133 | tag: handler
134 | rest
135 | ```
136 |
137 | The above case expression creates a function with a variant
138 | parameter. It analyzes its argument, and if it is a `tag`, the given
139 | `handler` is invoked with the typed content of the `tag`.
140 |
141 | If the argument is not a `tag`, then the `rest` handler is
142 | invoked. The `rest` handler is given a smaller variant type as an
143 | argument: a variant type which no longer has the `tag` case inside it, as
144 | that case was ruled out.
145 |
146 | ```Haskell
147 | \case
148 | Just: x → x + 1
149 | \case
150 | Nothing: () → 0
151 | ?
152 | ```
153 |
154 | The above composed case expression will match against `Just`:
155 |
156 | * Match: it will evaluate to the content of the `Just` added to 1.
157 |
158 | * Mismatch: it will match against `Nothing`:
159 |
160 | * Match: it will ignore the empty record contained in the
161 | `Nothing` case, and evaluate to 0
162 |
163 | * Mismatch: Evaluate to a hole (denoted by `?`) which is like
164 | Haskell's `undefined`, and takes on any type.
165 |
166 | The type of the above case statement would be inferred to:
167 |
168 | ```Haskell
169 | forall v1. {Just, Nothing} ∉ v1 =>
170 | +{ Just : Int, Nothing : () | v1 } → Int
171 | ```
172 |
173 | Note that the case statement allows *any* structural variant type that
174 | has the proper `Nothing` and `Just` cases (and it would reach a hole if the
175 | value happens to be neither `Nothing` nor `Just`). This is not
176 | typically what we want. We'd like to *close* the variant type so it is *not* extensible.
177 |
178 | To do that, we must also support the empty case statement that matches
179 | no possible alternatives, allowing us to *close* the composition. The
180 | empty case statement AST construction is:
181 |
182 | `BLeaf LAbsurd`, which is denoted as `absurd` (as it is the analogue of
183 | the `absurd` function in Haskell and Agda). The type of `absurd` is:
184 | `forall r. +{} → r` (`+{}` is the empty variant type, aka `Void`).
185 |
186 | We can now close the above case expression:
187 |
188 | ```Haskell
189 | \case
190 | Just: x → x + 1
191 | \case
192 | Nothing: () → 0
193 | absurd
194 | ```
195 |
196 | And now our inferred type will be simpler:
197 |
198 | ```Haskell
199 | +{ Just : Int, Nothing : () } → Int
200 | ```
201 |
202 | As a shorthand for nested/composed cases, we'll use Haskell-like
203 | syntax for cases as well, meaning the expanded, composed case.
204 |
205 | ```Haskell
206 | \case
207 | Just: x → x + 1
208 | Nothing: () → 0
209 | ```
210 |
211 | ### Example use-case: Composing interpreters
212 |
213 | We can define single-command interpreters:
214 |
215 | ```Haskell
216 | interpretAdd default =
217 | \case
218 | Add: {x, y} → x + y
219 | default
220 | ```
221 |
222 | ```Haskell
223 | interpretMul default =
224 | \case
225 | Mul: {x, y} → x * y
226 | default
227 | ```
228 |
229 | And then we can compose them:
230 |
231 | ```Haskell
232 | interpreterArithmetic = interpretAdd . interpretMul
233 |
234 | interpreter = (interpreterArithmetic . interpreterConditionals) absurd
235 | ```
236 |
237 | This yields a single interpreter that can handle the full set of cases in
238 | its composition.
239 |
240 | ### Example use-case: Exception Monads
241 |
242 | See the [Exception Monad](ExceptionMonad.md) use-case for a more
243 | compelling example use-case of extensible variants.
244 |
245 | ## Nominal types
246 |
247 | ### Purposes
248 |
249 | #### Recursive Types
250 |
251 | Infinite structural types are not allowed, so you can use a nominal type to “tie the knot”.
252 | For example:
253 |
254 | ```Haskell
255 | Stream a = +{ Empty : (), NonEmpty : { head : a, tail : () → Stream a } }
256 | ```
257 |
258 | A `Stream a` may contain a `Stream a` within it; therefore, the type is recursive.
259 |
260 | #### Safety
261 |
262 | Often one wants to distinguish between types whose representations are the same but whose meanings are very different.
263 |
264 | #### Rank-N types
265 |
266 | Nominal types may bind type variables (i.e., the structural type within a nominal type is wrapped in “forall”s).
267 |
268 | This is used by the `Mut` type (equivalent to Haskell's `ST`).
269 |
270 | # Future plans
271 |
272 | * Unifying with the underlying AST in
273 | [AlgoWMutable](https://github.com/Peaker/AlgoWMutable) which
274 | implements type inference more efficiently than
275 | [Algorithm-W-Step-By-Step](https://github.com/lamdu/Algorithm-W-Step-By-Step)
276 | which is compatible with this package.
277 |
278 | * Support for type-classes
279 |
--------------------------------------------------------------------------------
/src/Lamdu/Calc/Infer.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE NoImplicitPrelude, TemplateHaskell, GeneralizedNewtypeDeriving #-}
2 | {-# LANGUAGE UndecidableInstances, MultiParamTypeClasses, TypeFamilies #-}
3 | {-# LANGUAGE FlexibleInstances, LambdaCase, DataKinds, FlexibleContexts #-}
4 | {-# LANGUAGE TypeApplications, RankNTypes, DerivingStrategies, TypeOperators #-}
5 | {-# LANGUAGE ScopedTypeVariables #-}
6 |
7 | module Lamdu.Calc.Infer
8 | ( InferState(..), isBinding, isQVarGen
9 | , emptyPureInferState
10 | , PureInfer(..), _PureInfer, runPureInfer
11 | , STInfer(..), _STInfer
12 | , loadDeps
13 | , varGen
14 | , alphaEq
15 | ) where
16 |
17 | import Hyper
18 | import Hyper.Infer
19 | import Hyper.Syntax.Nominal
20 | import qualified Hyper.Syntax.Var as TermVar
21 | import qualified Hyper.Syntax.Scheme as S
22 | import qualified Hyper.Syntax.Scheme.AlphaEq as S
23 | import Hyper.Unify
24 | import Hyper.Unify.Binding
25 | import Hyper.Unify.Binding.ST
26 | import Hyper.Unify.Generalize
27 | import Hyper.Unify.QuantifiedVar
28 | import Hyper.Unify.Term (UTerm, UTermBody)
29 | import Control.Applicative (Alternative(..))
30 | import qualified Control.Lens as Lens
31 | import Control.Lens (LensLike')
32 | import Control.Monad.Except
33 | import Control.Monad.Reader.Class
34 | import Control.Monad.ST
35 | import Control.Monad.ST.Class (MonadST(..))
36 | import Control.Monad.State
37 | import Control.Monad.Trans.Maybe
38 | import Control.Monad.Trans.RWS (RWST(..))
39 | import Control.Monad.Trans.Reader (ReaderT(..))
40 | import Control.Monad.Trans.Writer (WriterT)
41 | import Data.STRef
42 | import Lamdu.Calc.Definition (Deps, depsNominals, depsGlobalTypes)
43 | import Lamdu.Calc.Term
44 | import qualified Lamdu.Calc.Type as T
45 |
46 | import Lamdu.Calc.Internal.Prelude
47 |
48 | data QVarGen = QVarGen -- Counters used to generate fresh quantified-variable names
49 | { _nextTV :: !Int -- ^ Index of the next type-variable name ("t0", "t1", ...)
50 | , _nextRV :: !Int -- ^ Index of the next row-variable name ("r0", "r1", ...)
51 | } deriving stock (Eq, Ord, Show)
52 | Lens.makeLenses ''QVarGen
53 |
54 | varGen :: QVarGen -- The initial generator, with both counters at zero
55 | varGen = QVarGen { _nextTV = 0, _nextRV = 0 }
56 |
57 | data InferState = InferState -- The state threaded through 'PureInfer'
58 | { _isBinding :: T.Types # Binding -- ^ Unification-variable bindings: one table for types, one for rows
59 | , _isQVarGen :: QVarGen -- ^ Generator for fresh quantified-variable names
60 | } deriving stock (Eq, Ord, Show)
61 | Lens.makeLenses ''InferState
62 |
63 | newtype PureInfer env a = -- Pure inference monad: reads env, keeps an InferState, fails with a T.TypeError
64 | PureInfer
65 | (RWST env () InferState (Either (T.TypeError # UVar)) a)
66 | deriving newtype
67 | ( Functor, Applicative, Monad
68 | , MonadReader env, MonadState InferState
69 | , MonadError (T.TypeError # UVar)
70 | )
71 | Lens.makePrisms ''PureInfer
72 |
73 | -- | Run a 'PureInfer' action from the given environment and state, discarding its unit writer output.
74 | runPureInfer ::
75 | env -> InferState -> PureInfer env a ->
76 | Either (T.TypeError # UVar) (a, InferState)
77 | runPureInfer env st (PureInfer act) =
78 | (\(result, st', ~()) -> (result, st')) <$> runRWST act env st
79 |
80 | type instance UVarOf (PureInfer _) = UVar
81 |
82 | loadDeps :: -- Load the nominals and global type schemes of 'Deps', yielding a function that adds them to a scope
83 | (UnifyGen m T.Row, S.HasScheme T.Types m T.Type) =>
84 | Deps -> m (Scope # UVarOf m -> Scope # UVarOf m)
85 | loadDeps deps =
86 | addToScope
87 | <$> traverse loadNominalDecl (deps ^. depsNominals)
88 | <*> traverse S.loadScheme (deps ^. depsGlobalTypes)
89 | where
90 | -- '<>~' appends with Map's left-biased union, so entries already in the scope win on key collisions
91 | addToScope noms schemes env =
92 | env & scopeVarTypes <>~ (schemes <&> MkHFlip) & scopeNominals <>~ noms
93 |
94 | instance MonadScopeLevel (PureInfer (Scope # UVar)) where
95 | localLevel = local (Lens.over (scopeLevel . _ScopeLevel) (+ 1)) -- run the action one scope level deeper
96 |
97 | instance MonadNominals T.NominalId T.Type (PureInfer (Scope # UVar)) where
98 | {-# INLINE getNominalDecl #-}
99 | getNominalDecl n = -- look the nominal up in the scope; a missing one is reported as NominalNotFound
100 | Lens.view (scopeNominals . Lens.at n) >>=
101 | \case Nothing -> throwError (T.NominalNotFound n); Just decl -> pure decl
102 |
103 | instance TermVar.HasScope (PureInfer (Scope # UVar)) Scope where
104 | {-# INLINE getScope #-}
105 | getScope = ask -- the whole reader environment is the scope
106 |
107 | instance LocalScopeType Var (UVar # T.Type) (PureInfer (Scope # UVar)) where
108 | {-# INLINE localScopeType #-}
109 | localScopeType var typ = local (Lens.set (scopeVarTypes . Lens.at var) (Just (MkHFlip (GMono typ)))) -- bind var monomorphically
110 |
111 | instance UnifyGen (PureInfer (Scope # UVar)) T.Type where
112 | {-# INLINE scopeConstraints #-}
113 | scopeConstraints _ = asks (^. scopeLevel) -- type variables are constrained by the current scope level
114 |
115 | nextTVNamePure :: -- fresh name "<prefix><n>", where n is the pre-increment value of the counter selected by lens
116 | (MonadState InferState m, IsString a) =>
117 | Char -> LensLike' ((,) Int) QVarGen Int -> m a
118 | nextTVNamePure prefix lens =
119 | (isQVarGen . lens <<%= (+ 1)) <&> \n -> fromString (prefix : show n)
120 |
121 | instance MonadQuantify ScopeLevel T.TypeVar (PureInfer env) where
122 | newQuantifiedVariable _ = nextTVNamePure 't' nextTV -- type variables are named t0, t1, ...
123 |
124 | instance UnifyGen (PureInfer (Scope # UVar)) T.Row where
125 | {-# INLINE scopeConstraints #-}
126 | scopeConstraints _ = T.RowConstraints mempty <$> scopeConstraints (Proxy @T.Type) -- empty field constraints, at the types' scope level
127 |
128 | instance MonadQuantify T.RConstraints T.RowVar (PureInfer env) where
129 | newQuantifiedVariable _ = nextTVNamePure 'r' nextRV -- row variables are named r0, r1, ...
130 |
131 | instance Unify (PureInfer env) T.Type where
132 | {-# INLINE binding #-}
133 | binding = bindingDict (isBinding . T.tType)
134 | unifyError err = throwError (T.TypeError err)
135 |
136 | instance Unify (PureInfer env) T.Row where
137 | {-# INLINE binding #-}
138 | binding = bindingDict (isBinding . T.tRow)
139 | unifyError err = throwError (T.RowError err)
140 | {-# INLINE structureMismatch #-}
141 | structureMismatch = T.rStructureMismatch
142 |
143 | emptyPureInferState :: T.Types # Binding -- initial binding tables (types and rows) for an InferState
144 | emptyPureInferState = T.Types { T._tType = emptyBinding, T._tRow = emptyBinding }
145 |
146 | newtype STInfer s a = STInfer -- inference in ST (used by 'alphaEq'); failures are 'empty', with no error detail
147 | (ReaderT (Scope # STUVar s, STRef s QVarGen) (MaybeT (ST s)) a)
148 | deriving newtype
149 | ( Functor, Applicative, Monad, Alternative, MonadST
150 | , MonadReader (Scope # STUVar s, STRef s QVarGen)
151 | )
152 | Lens.makePrisms ''STInfer
153 |
154 | type instance UVarOf (STInfer s) = STUVar s
155 |
156 | instance MonadScopeLevel (STInfer s) where
157 | localLevel = local (Lens.over (Lens._1 . scopeLevel . _ScopeLevel) (+ 1)) -- deepen the scope half of the environment
158 |
159 | instance MonadNominals T.NominalId T.Type (STInfer s) where
160 | {-# INLINE getNominalDecl #-}
161 | getNominalDecl n = -- an unknown nominal makes the whole computation fail ('empty')
162 | Lens.view (Lens._1 . scopeNominals . Lens.at n) >>= \case Nothing -> empty; Just decl -> pure decl
163 |
164 | instance TermVar.HasScope (STInfer s) Scope where
165 | {-# INLINE getScope #-}
166 | getScope = asks fst -- the scope is the first component of the reader environment
167 |
168 | instance LocalScopeType Var (STUVar s # T.Type) (STInfer s) where
169 | {-# INLINE localScopeType #-}
170 | localScopeType var typ =
171 | local (Lens.set (Lens._1 . scopeVarTypes . Lens.at var) (Just (MkHFlip (GMono typ))))
172 |
173 | instance UnifyGen (STInfer s) T.Type where
174 | {-# INLINE scopeConstraints #-}
175 | scopeConstraints _ = asks (^. Lens._1 . scopeLevel) -- type variables are constrained by the current scope level
176 |
177 | nextTVNameST :: IsString a => Char -> Lens.ALens' QVarGen Int -> STInfer s a -- fresh "<prefix><n>" name from the shared QVarGen counter selected by lens
178 | nextTVNameST prefix lens =
179 | do
180 | genRef <- Lens.view Lens._2
181 | gen <- readSTRef genRef & liftST
182 | let res = prefix : show (gen ^# lens) & fromString -- uses the counter's value before it is incremented
183 | let newGen = gen & lens #%~ (+1)
184 | res <$ (writeSTRef genRef $! newGen) & liftST -- store the bumped generator evaluated, so repeated calls don't pile up lazy updates in the STRef
185 |
186 | instance MonadQuantify ScopeLevel T.TypeVar (STInfer s) where
187 | newQuantifiedVariable _ = nextTVNameST 't' nextTV -- type variables are named t0, t1, ...
188 |
189 | instance UnifyGen (STInfer s) T.Row where
190 | {-# INLINE scopeConstraints #-}
191 | scopeConstraints _ = T.RowConstraints mempty <$> scopeConstraints (Proxy @T.Type) -- empty field constraints, at the types' scope level
192 |
193 | instance MonadQuantify T.RConstraints T.RowVar (STInfer s) where
194 | newQuantifiedVariable _ = nextTVNameST 'r' nextRV -- row variables are named r0, r1, ...
195 |
196 | instance Unify (STInfer s) T.Type where
197 | {-# INLINE binding #-}
198 | binding = stBinding
199 | unifyError = const empty -- a unification error just fails the computation
200 |
201 | instance Unify (STInfer s) T.Row where
202 | {-# INLINE binding #-}
203 | binding = stBinding
204 | unifyError = const empty -- a unification error just fails the computation
205 | {-# INLINE structureMismatch #-}
206 | structureMismatch = T.rStructureMismatch
207 |
208 | alphaEq :: -- alpha-equivalence of two type schemes, decided by running hypertypes' 'S.alphaEq' in 'STInfer'
209 | Pure # S.Scheme T.Types T.Type ->
210 | Pure # S.Scheme T.Types T.Type ->
211 | Bool
212 | alphaEq x y =
213 | runST $
214 | do
215 | vg <- newSTRef varGen
216 | let action = S.alphaEq x y ^. _STInfer
217 | outcome <- runMaybeT (runReaderT action (emptyScope, vg))
218 | -- any failure inside STInfer yields Nothing, meaning "not equivalent"
219 | pure (Lens.has Lens._Just outcome)
220 |
221 | -- Specializations of hypertypes' generic unification helpers at 'STInfer s' (the monad 'alphaEq' runs in). NOTE(review): presumably for performance — confirm.
222 | {-# SPECIALIZE unify :: STUVar s # T.Row -> STUVar s # T.Row -> STInfer s (STUVar s # T.Row) #-}
223 | {-# SPECIALIZE updateConstraints :: ScopeLevel -> STUVar s # T.Type -> UTerm (STUVar s) # T.Type -> STInfer s () #-}
224 | {-# SPECIALIZE updateConstraints :: T.RConstraints -> STUVar s # T.Row -> UTerm (STUVar s) # T.Row -> STInfer s () #-}
225 | {-# SPECIALIZE updateTermConstraints :: STUVar s # T.Row -> UTermBody (STUVar s) # T.Row -> T.RConstraints -> STInfer s () #-}
226 | {-# SPECIALIZE instantiateH :: (forall n. TypeConstraintsOf n -> UTerm (STUVar s) # n) -> GTerm (STUVar s) # T.Row -> WriterT [STInfer s ()] (STInfer s) (STUVar s # T.Row) #-}
227 | {-# SPECIALIZE semiPruneLookup :: STUVar s # T.Type -> STInfer s (STUVar s # T.Type, UTerm (STUVar s) # T.Type) #-}
228 | {-# SPECIALIZE semiPruneLookup :: STUVar s # T.Row -> STInfer s (STUVar s # T.Row, UTerm (STUVar s) # T.Row) #-}
229 | {-# SPECIALIZE unifyUTerms :: STUVar s # T.Row -> UTerm (STUVar s) # T.Row -> STUVar s # T.Row -> UTerm (STUVar s) # T.Row -> STInfer s (STUVar s # T.Row) #-}
230 |
--------------------------------------------------------------------------------
/test/TestVals.hs:
--------------------------------------------------------------------------------
1 | {-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
2 | {-# LANGUAGE NoImplicitPrelude, OverloadedStrings, TypeOperators #-}
3 |
4 | module TestVals
5 | ( allDeps
6 | , factorialVal, euler1Val, solveDepressedQuarticVal
7 | , factorsVal
8 | , recordType
9 | , intType
10 | , listTypePair, boolTypePair
11 | ) where
12 |
13 | import Hyper
14 | import Hyper.Syntax.Nominal
15 | import Hyper.Syntax.Row
16 | import Hyper.Syntax.Scheme
17 | import Hyper.Type.Prune
18 | import qualified Control.Lens as Lens
19 | import Control.Lens.Operators
20 | import qualified Data.ByteString.Char8 as BS8
21 | import qualified Data.Map as Map
22 | import Lamdu.Calc.Definition (Deps(..))
23 | import Lamdu.Calc.Term
24 | import Lamdu.Calc.Type (Type, (~>))
25 | import qualified Lamdu.Calc.Type as T
26 |
27 | import Prelude.Compat
28 |
29 | {-# ANN module ("HLint: ignore Redundant $" :: String) #-}
30 |
31 | -- TODO: $$ to be type-classed for TApp vs BApp
32 | -- TODO: TCon "->" instead of TFun
33 |
34 | recExtends :: Pure # T.Row -> [(T.Tag, Pure # Type)] -> Pure # Type -- record type with the given fields in front of recTail
35 | recExtends recTail fields =
36 | Pure (T.TRecord (foldr extend recTail fields))
37 | where
38 | extend (tag, typ) rest = Pure (T.RExtend (RowExtend tag typ rest))
39 |
40 | recordType :: [(T.Tag, Pure # Type)] -> Pure # Type -- closed record type (ends in the empty row)
41 | recordType = recExtends (Pure T.REmpty)
42 |
43 | forAll :: -- scheme quantifying the given type variables (no constraints, no row variables); mkType gets them as types
44 | [T.TypeVar] -> ([Pure # Type] -> Pure # Type) -> Pure # T.Scheme
45 | forAll tvs mkType =
46 | Pure (Scheme quantified (mkType (map (Pure . T.TVar) tvs)))
47 | where
48 | quantified = T.Types
49 | { T._tType = QVars (Map.fromList (map (\tv -> (tv, mempty)) tvs))
50 | , T._tRow = mempty }
51 |
52 | stOf :: Pure # Type -> Pure # Type -> Pure # Type -- the "ST" nominal instantiated with state type s and result type a
53 | stOf s a =
54 | T.Types (QVarInstances (Map.fromList [("s", s), ("res", a)])) (QVarInstances mempty)
55 | & NominalInst "ST" & T.TInst & Pure
56 |
57 | listTypePair :: (T.NominalId, Pure # NominalDecl Type) -- the "List" nominal: a variant of "[]" (empty record) and ":" (head/tail record)
58 | listTypePair =
59 | ( "List"
60 | , _Pure # NominalDecl
61 | { _nParams =
62 | T.Types
63 | { T._tType = mempty & Lens.at "elem" ?~ mempty -- the single type parameter, "elem"
64 | , T._tRow = mempty
65 | }
66 | , _nScheme =
67 | _Pure # T.REmpty
68 | & RowExtend "[]" (recordType []) & T.RExtend & Pure
69 | & RowExtend ":" (recordType [("head", tv), ("tail", listOf tv)])
70 | & T.RExtend & Pure
71 | & T.TVariant & Pure
72 | & Scheme (T.Types mempty mempty)
73 | }
74 | )
75 | where -- the body must refer to the declared parameter "elem" (also the key listOf instantiates), not an unbound variable
76 | tv = _Pure # T.TVar "elem"
77 |
78 | listOf :: Pure # Type -> Pure # Type -- "List" instantiated with the given element type
79 | listOf x =
80 | Pure $ T.TInst $ NominalInst (fst listTypePair) $
81 | T.Types (QVarInstances (Map.singleton "elem" x)) (QVarInstances mempty)
82 |
83 | boolType :: Pure # Type -- the parameterless nominal "Bool"
84 | boolType =
85 | Pure $ T.TInst $ NominalInst (fst boolTypePair) $
86 | T.Types (QVarInstances mempty) (QVarInstances mempty)
87 |
88 | intType :: Pure # Type -- the parameterless nominal "Int" (allDeps declares no "Int" nominal)
89 | intType =
90 | Pure $ T.TInst $ NominalInst "Int" $
91 | T.Types (QVarInstances mempty) (QVarInstances mempty)
92 |
93 | boolTypePair :: (T.NominalId, Pure # NominalDecl Type) -- the "Bool" nominal: +{ False : (), True : () }, no parameters
94 | boolTypePair =
95 | ( "Bool"
96 | , Pure NominalDecl
97 | { _nParams = T.Types mempty mempty
98 | , _nScheme = Scheme (T.Types mempty mempty) (Pure (T.TVariant alternatives))
99 | }
100 | )
101 | where
102 | unit = recordType []
103 | alternatives =
104 | foldr (\tag rest -> Pure (T.RExtend (RowExtend tag unit rest))) (Pure T.REmpty) ["False", "True"]
105 |
106 |
107 | maybeOf :: Pure # Type -> Pure # Type -- the structural variant +{ Nothing : (), Just : t }
108 | maybeOf t =
109 | Pure (T.TVariant (foldr extend (Pure T.REmpty) alternatives))
110 | where
111 | alternatives = [("Nothing", recordType []), ("Just", t)]
112 | extend (tag, typ) rest = Pure (T.RExtend (RowExtend tag typ rest))
113 |
114 | infixType :: Pure # Type -> Pure # Type -> Pure # Type -> Pure # Type -- type of a binary operator taking a record with fields "l" and "r"
115 | infixType lhs rhs result = recordType [("l", lhs), ("r", rhs)] ~> result
116 |
117 | allDeps :: Deps -- the nominals and typed globals available to the test terms
118 | allDeps =
119 | Deps
120 | { _depsNominals = Map.fromList [boolTypePair, listTypePair] -- note: no "Int" or "ST" nominal is declared here
121 | , _depsGlobalTypes =
122 | Map.fromList
123 | [ ("fix", forAll ["a"] $ \ [a] -> (a ~> a) ~> a)
124 | , ("if", forAll ["a"] $ \ [a] -> recordType [("condition", boolType), ("then", a), ("else", a)] ~> a)
125 | , ("==", forAll ["a"] $ \ [a] -> infixType a a boolType) -- these operators are polymorphic in their operand type
126 | , (">", forAll ["a"] $ \ [a] -> infixType a a boolType)
127 | , ("%", forAll ["a"] $ \ [a] -> infixType a a a)
128 | , ("*", forAll ["a"] $ \ [a] -> infixType a a a)
129 | , ("-", forAll ["a"] $ \ [a] -> infixType a a a)
130 | , ("+", forAll ["a"] $ \ [a] -> infixType a a a)
131 | , ("/", forAll ["a"] $ \ [a] -> infixType a a a)
132 | , ("//", forAll [] $ \ [] -> infixType intType intType intType) -- monomorphic: Int only
133 | , ("sum", forAll ["a"] $ \ [a] -> listOf a ~> a)
134 | , ("filter", forAll ["a"] $ \ [a] -> recordType [("from", listOf a), ("predicate", a ~> boolType)] ~> listOf a)
135 | , (":", forAll ["a"] $ \ [a] -> recordType [("head", a), ("tail", listOf a)] ~> listOf a)
136 | , ("[]", forAll ["a"] $ \ [a] -> listOf a)
137 | , ("concat", forAll ["a"] $ \ [a] -> listOf (listOf a) ~> listOf a)
138 | , ("map", forAll ["a", "b"] $ \ [a, b] -> recordType [("list", listOf a), ("mapping", a ~> b)] ~> listOf b)
139 | , ("..", forAll [] $ \ [] -> infixType intType intType (listOf intType))
140 | , ("||", forAll [] $ \ [] -> infixType boolType boolType boolType)
141 | , ("head", forAll ["a"] $ \ [a] -> listOf a ~> a)
142 | , ("negate", forAll ["a"] $ \ [a] -> a ~> a)
143 | , ("sqrt", forAll ["a"] $ \ [a] -> a ~> a)
144 | , ("id", forAll ["a"] $ \ [a] -> a ~> a)
145 | , ("zipWith",forAll ["a","b","c"] $ \ [a,b,c] ->
146 | (a ~> b ~> c) ~> listOf a ~> listOf b ~> listOf c )
147 | , ("Just", forAll ["a"] $ \ [a] -> a ~> maybeOf a)
148 | , ("Nothing",forAll ["a"] $ \ [a] -> maybeOf a)
149 | , ("maybe", forAll ["a", "b"] $ \ [a, b] -> b ~> (a ~> b) ~> maybeOf a ~> b)
150 | , ("plus1", forAll [] $ \ [] -> intType ~> intType)
151 | , ("True", forAll [] $ \ [] -> boolType)
152 | , ("False", forAll [] $ \ [] -> boolType)
153 |
154 | , ("stBind", forAll ["s", "a", "b"] $ \ [s, a, b] -> infixType (stOf s a) (a ~> stOf s b) (stOf s b)) -- bind over the "ST" nominal built by stOf
155 | ]
156 | }
157 |
158 | litInt :: Integer -> HPlain Term -- an "Int" primitive literal, encoded as the 'show' text of the number
159 | litInt n = BLeafP (LLiteral (PrimVal "Int" (BS8.pack (show n))))
160 |
161 | record :: [(T.Tag, HPlain Term)] -> HPlain Term -- a record literal: field extensions over the empty record
162 | record fields = foldr (\(tag, val) rest -> BRecExtendP tag val rest) (BLeafP LRecEmpty) fields
163 |
164 | ($$:) :: HPlain Term -> [(T.Tag, HPlain Term)] -> HPlain Term -- apply a function to a record of named arguments
165 | ($$:) func args = BAppP func (record args)
166 |
167 | inf :: HPlain Term -> HPlain Term -> HPlain Term -> HPlain Term -- infix call: operands passed as fields "l" and "r"
168 | inf lhs op rhs = op $$: [("l", lhs), ("r", rhs)]
169 |
170 | factorialVal :: HPlain Term -- fix (\loop x -> if x == 0 then 1 else x * loop (x - 1))
171 | factorialVal =
172 | BAppP "fix" $
173 | BLamP "loop" Pruned $
174 | BLamP "x" Pruned $
175 | "if" $$:
176 | [ ("condition", inf "x" "==" (litInt 0))
177 | , ("then", litInt 1)
178 | , ("else", inf "x" "*" (BAppP "loop" $ inf "x" "-" (litInt 1)))
179 | ]
180 |
181 | euler1Val :: HPlain Term -- sum (filter (\x -> x % 3 == 0 || x % 5 == 0) (1 .. 1000)); presumably Project Euler #1
182 | euler1Val =
183 | BAppP "sum" $
184 | "filter" $$:
185 | [ ("from", inf (litInt 1) ".." (litInt 1000))
186 | , ("predicate",
187 | BLamP "x" Pruned $
188 | inf
189 | (inf (litInt 0) "==" (inf "x" "%" (litInt 3)))
190 | "||"
191 | (inf (litInt 0) "==" (inf "x" "%" (litInt 5)))
192 | )
193 | ]
194 |
195 | let_ :: Var -> HPlain Term -> HPlain Term -> HPlain Term -- "let k = v in r", encoded as (\k -> r) v
196 | let_ name val body = BLamP name Pruned body `BAppP` val
197 |
198 | cons :: HPlain Term -> HPlain Term -> HPlain Term -- h : t, wrapped into the "List" nominal
199 | cons h t = BToNomP "List" (BLeafP (LInject ":") `BAppP` record [("head", h), ("tail", t)])
200 |
201 | list :: [HPlain Term] -> HPlain Term -- a list literal, terminated by the "[]" injection
202 | list = foldr cons nil where nil = BToNomP "List" (BAppP (BLeafP (LInject "[]")) (BLeafP LRecEmpty))
203 |
204 | solveDepressedQuarticVal :: HPlain Term -- presumably the roots of x^4 + c*x^2 + d*x + e for a params record {c, d, e}; a typing fixture
205 | solveDepressedQuarticVal =
206 | BLamP "params" Pruned $
207 | let_ "solvePoly" "id" $ -- solvePoly is stubbed as id, so only the types matter
208 | let_ "sqrts" -- sqrts x = [sqrt x, negate (sqrt x)]
209 | ( BLamP "x" Pruned $
210 | let_ "r" (BAppP "sqrt" "x") $
211 | list ["r", BAppP "negate" "r"]
212 | ) $
213 | "if" $$:
214 | [ ("condition", inf d "==" (litInt 0))
215 | , ( "then",
216 | BAppP "concat" $
217 | "map" $$:
218 | [ ("list", BAppP "solvePoly" $ list [e, c, litInt 1]) -- coefficients, presumably lowest degree first
219 | , ("mapping", "sqrts")
220 | ])
221 | , ( "else",
222 | BAppP "concat" $
223 | "map" $$:
224 | [ ( "list",
225 | BAppP "sqrts" $ BAppP "head" $ BAppP "solvePoly" $
226 | list
227 | [ BAppP "negate" (d %* d)
228 | , (c %* c) %- (litInt 4 %* e)
229 | , litInt 2 %* c
230 | , litInt 1
231 | ])
232 | , ( "mapping",
233 | BLamP "x" Pruned $
234 | BAppP "solvePoly" $
235 | list
236 | [ (c %+ ("x" %* "x")) %- (d %/ "x")
237 | , litInt 2 %* "x"
238 | , litInt 2
239 | ])
240 | ])
241 | ]
242 | where -- field accesses on the lambda's "params" record
243 | c = BLeafP (LGetField "c") `BAppP` "params"
244 | d = BLeafP (LGetField "d") `BAppP` "params"
245 | e = BLeafP (LGetField "e") `BAppP` "params"
246 |
247 | (%+), (%-), (%*), (%/), (%//), (%>), (%%), (%==) :: HPlain Term -> HPlain Term -> HPlain Term -- infix calls of the same-named globals
248 | (%+) = flip inf "+"
249 | (%-) = flip inf "-"
250 | (%*) = flip inf "*"
251 | (%/) = flip inf "/"
252 | (%//) = flip inf "//"
253 | (%>) = flip inf ">"
254 | (%%) = flip inf "%"
255 | (%==) = flip inf "=="
256 |
257 | factorsVal :: HPlain Term -- factors of params.n, trying divisors from params.min upward
258 | factorsVal =
259 | BAppP "fix" $
260 | BLamP "loop" Pruned $
261 | BLamP "params" Pruned $
262 | if_ ((m %* m) %> n) (list [n]) $ -- min * min > n: n itself is the last factor
263 | if_ ((n %% m) %== litInt 0)
264 | (cons m $ "loop" $$: [("n", n %// m), ("min", m)]) $
265 | "loop" $$: [("n", n), ("min", m %+ litInt 1)]
266 | where
267 | n = BLeafP (LGetField "n") `BAppP` "params"
268 | m = BLeafP (LGetField "min") `BAppP` "params"
269 | if_ b t f = -- a closed case over the alternatives of the unwrapped "Bool" nominal
270 | BCaseP "False" (BLamP "_" Pruned f) (BCaseP "True" (BLamP "_" Pruned t) (BLeafP LAbsurd))
271 | `BAppP` (BLeafP (LFromNom "Bool") `BAppP` b)
272 |
--------------------------------------------------------------------------------
/src/Lamdu/Calc/Lens.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE NoImplicitPrelude, RankNTypes, NoMonomorphismRestriction #-}
2 | {-# LANGUAGE FlexibleContexts, TypeFamilies, TypeApplications, FlexibleInstances #-}
3 | {-# LANGUAGE DataKinds, ScopedTypeVariables, LambdaCase, TypeOperators #-}
4 | module Lamdu.Calc.Lens
5 | ( -- Leafs
6 | valHole , valBodyHole
7 | , valVar
8 | , valLiteral , valBodyLiteral
9 | , valLeafs
10 | -- Non-leafs
11 | , valApply
12 | , valTags
13 | , valGlobals
14 | , valNominals
15 | -- Subexpressions:
16 | , subExprPayloads
17 | , payloadsOf
18 | , HasTIds(..), tIds, qVarIds
19 | , HasQVarIds(..)
20 | , HasQVar(..), qvarsQVarIds
21 | ) where
22 |
23 | import Control.Lens (Traversal', Prism')
24 | import qualified Control.Lens as Lens
25 | import qualified Data.Map as Map
26 | import qualified Data.Set as Set
27 | import Hyper
28 | import Hyper.Recurse
29 | import Hyper.Syntax.Nominal (ToNom(..), NominalInst(..), NominalDecl, nScheme)
30 | import Hyper.Syntax.Row (RowExtend(..))
31 | import Hyper.Syntax.Scheme (Scheme, QVars, _QVars, QVarInstances, _QVarInstances, sTyp)
32 | import Hyper.Type.Prune (Prune, _Unpruned)
33 | import Lamdu.Calc.Identifier (Identifier)
34 | import Lamdu.Calc.Term (Val)
35 | import qualified Lamdu.Calc.Term as V
36 | import qualified Lamdu.Calc.Type as T
37 | import Hyper.Unify.QuantifiedVar (HasQuantifiedVar(..))
38 |
39 | import Lamdu.Calc.Internal.Prelude
40 |
41 | tIds :: -- traverse the nominal ids of the expr node(s) wrapped by k, delegating to 'bodyTIds'
42 | forall k expr.
43 | (RTraversable k, HasTIds expr) =>
44 | Traversal' (k # expr) T.NominalId
45 | tIds f =
46 | withDict (recurse (Proxy @(RTraversable k))) $ -- bring RTraversable for k's children into scope
47 | htraverse (Proxy @RTraversable #> bodyTIds f)
48 |
49 | qVarIds :: -- traverse the quantified-variable identifiers of the expr node(s) wrapped by k, delegating to 'bodyQVarIds'
50 | forall k expr.
51 | (RTraversable k, HasQVarIds expr) =>
52 | Traversal' (k # expr) Identifier
53 | qVarIds f =
54 | withDict (recurse (Proxy @(RTraversable k))) $ -- bring RTraversable for k's children into scope
55 | htraverse (Proxy @RTraversable #> bodyQVarIds f)
56 |
57 | class HasTIds expr where -- node types whose bodies can reference nominal ids
58 | bodyTIds :: RTraversable k => Traversal' (expr # k) T.NominalId
59 |
60 | class HasQVarIds expr where -- node types whose bodies can reference quantified-variable identifiers
61 | bodyQVarIds :: RTraversable k => Traversal' (expr # k) Identifier -- only a lawful traversal if f does not map two map keys to the same identifier (they would collide)
62 |
63 | unsafeMapList :: Ord k1 => Lens.Iso (Map k0 v0) (Map k1 v1) [(k0, v0)] [(k1, v1)] -- "unsafe": keys made equal on the way back collapse into one entry
64 | unsafeMapList = Lens.iso Map.toAscList Map.fromList
65 |
66 | class (HasQVarIds expr, Ord (QVar expr)) => HasQVar expr where -- node types whose quantified variables wrap an Identifier
67 | qvarId :: Proxy expr -> Lens.Iso' (QVar expr) Identifier -- the Proxy fixes expr (QVar expr presumably does not determine it)
68 |
69 | instance HasQVar T.Type where qvarId _ = T._Var
70 | instance HasQVar T.Row where qvarId _ = T._Var
71 |
72 | instance HasTIds T.Type where
73 | {-# INLINE bodyTIds #-}
74 | bodyTIds f (T.TInst (NominalInst tId args)) = -- a nominal instantiation: its own id, then the ids inside its type/row arguments
75 | NominalInst
76 | <$> f tId
77 | <*> htraverse (Proxy @HasTIds #> (_QVarInstances . traverse . tIds) f)
78 | args
79 | <&> T.TInst
80 | bodyTIds f x = htraverse (Proxy @HasTIds #> tIds f) x -- any other type: recurse into its children
81 |
82 | instance HasQVarIds T.Type where
83 | bodyQVarIds f (T.TInst (NominalInst tId args)) = -- the nominal id is kept; argument names and argument types are visited
84 | NominalInst tId
85 | <$> htraverse (Proxy @HasQVar #> qvarInstancesQVarIds f) args
86 | <&> T.TInst
87 | bodyQVarIds f (T.TVar v) = T._Var f v <&> T.TVar -- a type variable is itself a quantified-variable reference
88 | bodyQVarIds f x = htraverse (Proxy @HasQVar #> qVarIds f) x
89 |
90 | qvarInstancesQVarIds :: forall expr h. (HasQVar expr, RTraversable h) => Traversal' (QVarInstances h # expr) Identifier -- the instantiated variable names (map keys) and the ids inside their values
91 | qvarInstancesQVarIds f =
92 | (_QVarInstances . unsafeMapList . traverse) -- the map is rebuilt from a list, so keys renamed to the same id collapse
93 | (\(k, typ) ->
94 | (,) <$> qvarId (Proxy @expr) f k <*> qVarIds f typ)
95 |
96 | instance HasTIds T.Row where
97 | {-# INLINE bodyTIds #-}
98 | bodyTIds f = htraverse (Proxy @HasTIds #> tIds f) -- recurse into the row's type and row children
99 |
100 | instance HasQVarIds T.Row where
101 | bodyQVarIds f (T.RVar v) = T._Var f v <&> T.RVar -- a row variable is itself a quantified-variable reference (mirrors the T.TVar case for types)
102 | bodyQVarIds f x = htraverse (Proxy @HasQVarIds #> qVarIds f) x
103 | instance HasTIds (Scheme T.Types T.Type) where -- a scheme's nominal ids are those of its body type
104 | bodyTIds = sTyp . tIds
105 |
106 | qvarsQVarIds :: -- the identifiers of the variables quantified by a 'QVars' map (its keys)
107 | forall expr.
108 | HasQVar expr =>
109 | Traversal' (QVars # expr) Identifier
110 | qvarsQVarIds =
111 | _QVars . unsafeMapList . traverse . Lens._1 . qvarId (Proxy @expr) -- only lawful if no two keys are renamed to the same identifier
112 |
113 | instance HasTIds (NominalDecl T.Type) where -- a nominal declaration's ids are those of its scheme
114 | bodyTIds = nScheme . bodyTIds
115 |
116 | {-# INLINE valApply #-}
117 | valApply :: Traversal' (Ann a # V.Term) (V.App V.Term # Ann a) -- the application at the root of an annotated term, if it is one
118 | valApply = hVal . V._BApp
119 |
120 | {-# INLINE valHole #-}
121 | valHole :: Traversal' (Ann a # V.Term) () -- matches when the root of an annotated term is a hole
122 | valHole = hVal . V._BLeaf . V._LHole
123 |
124 | {-# INLINE valVar #-}
125 | valVar :: Traversal' (Ann a # V.Term) V.Var -- the variable referenced at the root, if the root is a variable leaf
126 | valVar = hVal . V._BLeaf . V._LVar
127 |
128 | {-# INLINE valLiteral #-}
129 | valLiteral :: Traversal' (Ann a # V.Term) V.PrimVal -- the literal at the root, if the root is a literal leaf
130 | valLiteral = hVal . V._BLeaf . V._LLiteral
131 |
132 | {-# INLINE valBodyHole #-}
133 | valBodyHole :: Prism' (V.Term expr) () -- a term body that is a hole leaf
134 | valBodyHole = V._BLeaf . V._LHole
135 |
136 | {-# INLINE valBodyVar #-}
137 | valBodyVar :: Prism' (V.Term expr) V.Var -- a term body that is a variable leaf
138 | valBodyVar = V._BLeaf . V._LVar
139 |
140 | {-# INLINE valBodyLiteral #-}
141 | valBodyLiteral :: Prism' (V.Term expr) V.PrimVal -- a term body that is a literal leaf
142 | valBodyLiteral = V._BLeaf . V._LLiteral
143 |
144 | subTerms :: Lens.Traversal' (V.Term # k) (k # V.Term) -- the direct sub-terms of a term body
145 | subTerms f =
146 | htraverse
147 | ( \case
148 | HWitness V.W_Term_Term -> f
149 | HWitness V.W_Term_HCompose_Prune_Type -> pure -- type annotations are left untouched
150 | )
151 |
152 | {-# INLINE valLeafs #-}
153 | valLeafs :: Lens.IndexedTraversal' (a # V.Term) (Ann a # V.Term) V.Leaf -- every leaf of the term, indexed by that leaf's annotation
154 | valLeafs f (Ann pl body) =
155 | case body of
156 | V.BLeaf l -> Lens.indexed f pl l <&> V.BLeaf
157 | _ -> subTerms (valLeafs f) body -- not a leaf: recurse into direct sub-terms (type annotations are skipped)
158 | <&> Ann pl
159 |
160 | {-# INLINE subExprPayloads #-}
161 | subExprPayloads :: Lens.IndexedTraversal' (Pure # V.Term) (Val a) a -- every payload, indexed by its sub-expression with annotations stripped
162 | subExprPayloads f x@(Ann (Const pl) body) =
163 | Ann
164 | <$> (Lens.indexed f (unwrap (const (^. hVal)) x) pl <&> Const) -- this node's index: its whole sub-tree as a Pure term
165 | <*> (subTerms .> subExprPayloads) f body
166 |
167 | {-# INLINE payloadsOf #-}
168 | payloadsOf :: -- the payloads of exactly those sub-expressions that have at least one match for the given fold
169 | Lens.Fold (Pure # V.Term) a -> Lens.IndexedTraversal' (Pure # V.Term) (Val b) b
170 | payloadsOf l =
171 | subExprPayloads . Lens.ifiltered hasMatch
172 | where
173 | hasMatch = const . Lens.has l -- the index is the sub-expression; the payload itself is ignored
174 |
175 | leafTags :: Lens.Traversal' V.Leaf T.Tag -- the tag of an injection or field-access leaf; other leaves have none
176 | leafTags f = \case
177 | V.LInject tag -> V.LInject <$> f tag
178 | V.LGetField tag -> V.LGetField <$> f tag
179 | other -> pure other
180 | {-# INLINE valTags #-}
181 | valTags :: Lens.Traversal' (Ann a # V.Term) T.Tag -- every tag in a term: in leaves, case/record extensions and type annotations
182 | valTags =
183 | hVal .
184 | \f ->
185 | \case
186 | V.BLeaf l -> leafTags f l <&> V.BLeaf
187 | V.BCase (RowExtend t v r) -> -- a case alternative: its tag, its handler, then the remaining cases
188 | RowExtend <$> f t <*> valTags f v <*> valTags f r <&> V.BCase
189 | V.BRecExtend (RowExtend t v r) -> -- a record field: its tag, its value, then the rest of the record
190 | RowExtend <$> f t <*> valTags f v <*> valTags f r <&> V.BRecExtend
191 | body -> -- any other node: recurse into sub-terms and type annotations
192 | htraverse
193 | ( \case
194 | HWitness V.W_Term_Term -> valTags f
195 | HWitness V.W_Term_HCompose_Prune_Type -> typeTags f
196 | ) body
197 |
-- | Traversal over every tag mentioned in a (possibly pruned) type, including
-- tags in rows nested inside nominal type arguments. Pruned nodes are
-- skipped (only '_Unpruned' nodes are visited).
typeTags :: Lens.Traversal' (Ann a # HCompose Prune T.Type) T.Tag
typeTags f =
    hVal . hcomposed _Unpruned $
    htraverse
    ( \case
        HWitness T.W_Type_Type -> _HCompose (typeTags f)
        HWitness T.W_Type_Row -> _HCompose (rowTags f)
        HWitness (T.E_Type_NominalInst_NominalId_Types (HWitness T.W_Types_Type)) -> _HCompose (typeTags f)
        HWitness (T.E_Type_NominalInst_NominalId_Types (HWitness T.W_Types_Row)) -> _HCompose (rowTags f)
    )
209 |
-- | Traversal over the tags of a (possibly pruned) row. For each extension it
-- visits the extension's own tag, then the tags inside its field type, then
-- the tags in the rest of the row.
rowTags :: Lens.Traversal' (Ann a # HCompose Prune T.Row) T.Tag
rowTags f =
    (hVal . hcomposed _Unpruned . T._RExtend) visitExtend
    where
        visitExtend (RowExtend tag fieldType rest) =
            RowExtend
            <$> f tag
            <*> _HCompose (typeTags f) fieldType
            <*> _HCompose (rowTags f) rest
219 |
{-# INLINE valGlobals #-}
-- | Indexed fold over the free variable references of a term. A reference is
-- free when its variable is neither in the given @scope@ nor bound by an
-- enclosing lambda within the term. Each reference is indexed by the
-- annotation of its leaf node.
valGlobals ::
    Set V.Var ->
    Lens.IndexedFold (a # V.Term) (Ann a # V.Term) V.Var
valGlobals scope f (Ann pl body) =
    Ann pl <$>
    case body of
        V.BLeaf (V.LVar v)
            | Set.member v scope -> pure (V.BLeaf (V.LVar v))
            | otherwise -> V.BLeaf . V.LVar <$> Lens.indexed f pl v
        V.BLam (V.TypedLam var typ lamBody) ->
            -- The lambda's parameter is in scope within its body
            V.BLam . V.TypedLam var typ <$> valGlobals (Set.insert var scope) f lamBody
        _ ->
            htraverse
            ( \case
                HWitness V.W_Term_Term -> valGlobals scope f
                HWitness V.W_Term_HCompose_Prune_Type -> pure
            ) body
239 |
{-# INLINE valNominals #-}
-- | Traversal over the nominal type ids mentioned in a term: those of
-- 'V.LFromNom' leaves, 'V.BToNom' nodes, and (via 'typeNominals') lambda
-- parameter type annotations.
--
-- NOTE(review): 'V.LLiteral' leaves also carry a nominal id (the literal's
-- '_primType'), but it is not visited here. Confirm this is intended.
valNominals :: Lens.Traversal' (Ann a # V.Term) T.NominalId
valNominals f =
    hVal $ \case
        V.BLeaf (V.LFromNom nomId) -> V.BLeaf . V.LFromNom <$> f nomId
        V.BToNom (ToNom nomId inner) ->
            V.BToNom <$> (ToNom <$> f nomId <*> valNominals f inner)
        body ->
            htraverse
            ( \case
                HWitness V.W_Term_Term -> valNominals f
                HWitness V.W_Term_HCompose_Prune_Type -> typeNominals f
            ) body
258 |
{-# INLINE typeNominals #-}
-- | Traversal over every nominal type id mentioned in a (possibly pruned)
-- type: the id of each 'T.TInst' node, plus the ids nested in its arguments
-- and in child types and rows. Pruned nodes are skipped (only '_Unpruned'
-- nodes are visited).
typeNominals :: Lens.Traversal' (Ann a # HCompose Prune T.Type) T.NominalId
typeNominals =
    hVal . hcomposed _Unpruned .
    \f ->
    \case
        T.TInst (NominalInst nomId args) ->
            NominalInst
            <$> f nomId
            -- The nominal's type and row arguments are visited through
            -- their instance collections ('_QVarInstances')
            <*> htraverse
                ( \case
                    HWitness T.W_Types_Type -> (_QVarInstances . traverse . _HCompose . typeNominals) f
                    HWitness T.W_Types_Row -> (_QVarInstances . traverse . _HCompose . rowNominals) f
                ) args
            <&> T.TInst
        body ->
            htraverse
            ( \case
                HWitness T.W_Type_Type -> (_HCompose . typeNominals) f
                HWitness T.W_Type_Row -> (_HCompose . rowNominals) f
                -- These witnesses index children of 'T.TInst', which is
                -- handled by the branch above; they are listed so that the
                -- match stays exhaustive.
                HWitness (T.E_Type_NominalInst_NominalId_Types (HWitness T.W_Types_Type)) -> (_HCompose . typeNominals) f
                HWitness (T.E_Type_NominalInst_NominalId_Types (HWitness T.W_Types_Row)) -> (_HCompose . rowNominals) f
            ) body
282 |
{-# INLINE rowNominals #-}
-- | Traversal over every nominal type id mentioned in a (possibly pruned) row,
-- i.e. in its field types and in the rest of the row. Pruned nodes are
-- skipped.
rowNominals :: Lens.Traversal' (Ann a # HCompose Prune T.Row) T.NominalId
rowNominals f =
    hVal . hcomposed _Unpruned $
    htraverse
    ( \case
        HWitness T.W_Row_Row -> _HCompose (rowNominals f)
        HWitness T.W_Row_Type -> _HCompose (typeNominals f)
    )
293 |
--------------------------------------------------------------------------------
/src/Lamdu/Calc/Term.hs:
--------------------------------------------------------------------------------
1 | -- | Val AST
2 | {-# LANGUAGE NoImplicitPrelude, DeriveGeneric, DeriveTraversable, GADTs #-}
3 | {-# LANGUAGE GeneralizedNewtypeDeriving, TemplateHaskell, FlexibleContexts #-}
4 | {-# LANGUAGE UndecidableInstances, StandaloneDeriving, TypeFamilies, TypeOperators #-}
5 | {-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, ConstraintKinds #-}
6 | {-# LANGUAGE TupleSections, ScopedTypeVariables, DerivingStrategies, DataKinds #-}
7 | {-# LANGUAGE PolyKinds, TypeApplications #-}
8 |
9 | module Lamdu.Calc.Term
10 | ( Val
11 | , Leaf(..), _LVar, _LHole, _LLiteral, _LRecEmpty, _LAbsurd, _LFromNom, _LGetField, _LInject
12 | , PrimVal(..), primType, primData
13 | , Term(..), _BApp, _BLam, _BRecExtend, _BCase, _BToNom, _BLeaf, W_Term(..)
14 | , App(..), appFunc, appArg
15 | , TypedLam(..), tlIn, tlInType, tlOut
16 | , Var(..)
17 | , Scope(..), scopeNominals, scopeVarTypes, scopeLevel
18 | , emptyScope
19 | , ToNom(..), FromNom(..), RowExtend(..)
20 | , HPlain(..)
21 | ) where
22 |
23 | import qualified Control.Lens as Lens
24 | import qualified Data.ByteString.Char8 as BS8
25 | import Data.Maybe (fromMaybe)
26 | import Generics.Constraints (makeDerivings, makeInstances)
27 | import Hyper
28 | import Hyper.Infer
29 | import Hyper.Infer.Blame (Blame(..))
30 | import Hyper.Type.Prune (Prune)
31 | import Hyper.Syntax hiding (Var)
32 | import Hyper.Syntax.Nominal (ToNom(..), FromNom(..), NominalInst(..), MonadNominals, LoadedNominalDecl)
33 | import Hyper.Syntax.Row (RowExtend(..), rowElementInfer)
34 | import Hyper.Syntax.Scheme (QVarInstances(..))
35 | import qualified Hyper.Syntax.Var as TermVar
36 | import Hyper.Unify
37 | import qualified Hyper.Unify.Generalize as G
38 | import Hyper.Unify.New (newTerm, newUnbound)
39 | import Hyper.Unify.Term (UTerm(..))
40 | import Lamdu.Calc.Identifier (Identifier)
41 | import qualified Lamdu.Calc.Type as T
42 | import Text.PrettyPrint ((<+>))
43 | import qualified Text.PrettyPrint as PP
44 | import Text.PrettyPrint.HughesPJClass (Pretty(..), maybeParens)
45 |
46 | import Lamdu.Calc.Internal.Prelude
47 |
48 | {-# ANN module ("HLint: ignore Use const"::String) #-}
49 |
-- | The name of a term-level variable: bound by a lambda ('tlIn'), or looked
-- up in the inference scope ('scopeVarTypes'). Distinct from the type-level
-- 'T.Var'.
newtype Var = Var { vvName :: Identifier }
    deriving stock Show
    deriving newtype (Eq, Ord, NFData, IsString, Pretty, Binary, Hashable)
53 |
-- | A primitive literal: the id of its nominal type together with its raw
-- data. Inference gives the literal the nominal type '_primType' instantiated
-- with no arguments.
--
-- NOTE: field order feeds the Generic-derived instances; do not reorder.
data PrimVal = PrimVal
    { _primType :: {-# UNPACK #-} !T.NominalId
      -- ^ The literal's nominal type
    , _primData :: {-# UNPACK #-} !ByteString
      -- ^ The literal's raw data (printed as-is by the 'Term' 'Pretty' instance)
    } deriving (Generic, Show, Eq, Ord)
instance NFData PrimVal
instance Binary PrimVal
instance Hashable PrimVal

Lens.makeLenses ''PrimVal
63 |
-- | Terms that have no child terms (embedded in 'Term' via 'BLeaf').
--
-- NOTE: constructor order determines the derived 'Ord' instance and,
-- presumably, the Generic-based 'Binary'/'Hashable' encodings. Do not reorder.
data Leaf
    = LVar {-# UNPACK #-}!Var
      -- ^ A variable reference
    | LHole
      -- ^ A hole; inferred to have an unconstrained type
    | LLiteral {-# UNPACK #-}!PrimVal
      -- ^ A primitive literal
    | LRecEmpty
      -- ^ The empty record
    | LAbsurd
      -- ^ The function from the empty variant to any result type
    | LFromNom {-# UNPACK #-}!T.NominalId
      -- ^ A function unwrapping ("unpack") a value of the given nominal type
    | LGetField {-# UNPACK #-}!T.Tag
      -- ^ A function getting the given field of a record
    | LInject {-# UNPACK #-}!T.Tag
      -- ^ A function injecting a value into a variant under the given tag
    deriving (Generic, Show, Eq, Ord)
instance NFData Leaf
instance Binary Leaf
instance Hashable Leaf

Lens.makePrisms ''Leaf
79 |
-- | The term AST, parameterized by the hyper-type @k@ of its child nodes.
--
-- NOTE: the Template Haskell splices below split the module into
-- declaration groups; keep their relative order.
data Term k
    = BApp {-# UNPACK #-}!(App Term k)
      -- ^ Function application
    | BLam {-# UNPACK #-}!(TypedLam Var (HCompose Prune T.Type) Term k)
      -- ^ Lambda abstraction with a (possibly pruned) parameter type annotation
    | BRecExtend {-# UNPACK #-}!(RowExtend T.Tag Term Term k)
      -- ^ Extends a record (the rest) with one more field
    | BCase {-# UNPACK #-}!(RowExtend T.Tag Term Term k)
      -- ^ A case: a handler for one variant tag, plus a fallback handler
      -- for the remaining tags
    | BToNom {-# UNPACK #-}!(ToNom T.NominalId Term k)
      -- ^ Converts (packs) a value into a nominal type
    | BLeaf Leaf
      -- ^ A term with no child terms
    deriving Generic

Lens.makePrisms ''Term
makeHTraversableAndBases ''Term
makeZipMatch ''Term
makeHContext ''Term
makeHasHPlain [''Term]

instance RNodes Term
-- A term's recursive node kinds: terms, and (via lambda annotations) pruned
-- types and rows.
instance
    (c Term, c (HCompose Prune T.Type), c (HCompose Prune T.Row)) =>
    Recursively c Term
instance RTraversable Term
101 |
-- | A string literal used as a plain term denotes a variable reference.
instance IsString (HPlain Term) where
    fromString name = BLeafP (LVar (fromString name))
104 |
-- | Pretty-printing of terms.
--
-- The 'PrettyLevel' @lvl@ is passed on to every child term and type
-- annotation, not only to application children. The instances for
-- @f :# Term@ and @f :# HCompose Prune T.Type@ therefore see the caller's level
-- at every depth. At the default level the output is unchanged, because
-- @pPrintPrec lvl 0@ is what 'pPrint' calls.
instance
    (Pretty (f :# Term), Pretty (f :# HCompose Prune T.Type)) =>
    Pretty (Term f) where
    pPrintPrec lvl prec b =
        case b of
            BLeaf (LVar var) -> pPrint var
            BLeaf (LLiteral (PrimVal _p d)) -> PP.text (BS8.unpack d)
            BLeaf LHole -> PP.text "?"
            BLeaf LAbsurd -> PP.text "absurd"
            BLeaf (LFromNom ident) -> PP.text "[" <+> PP.text "unpack" <+> pPrint ident <+> PP.text "]"
            -- Application is left-associative: the function is printed at
            -- precedence 10 and the argument at 11.
            BApp (App e1 e2) -> maybeParens (10 < prec) $
                pPrintPrec lvl 10 e1 <+> pPrintPrec lvl 11 e2
            BLam (TypedLam n t e) -> maybeParens (0 < prec) $
                PP.char '\\' PP.<> pPrint n PP.<>
                PP.char ':' PP.<> pPrintPrec lvl 0 t <+>
                PP.text "->" <+>
                pPrintPrec lvl 0 e
            BLeaf (LGetField x) -> maybeParens (12 < prec) $ PP.char '.' <> pPrint x
            BLeaf (LInject x) -> maybeParens (12 < prec) $ pPrint x <> PP.char ':'
            BCase (RowExtend n m mm) -> maybeParens (0 < prec) $
                PP.vcat
                [ PP.text "case of"
                , pPrint n <> PP.text " -> " <> pPrintPrec lvl 0 m
                , PP.text "_" <> PP.text " -> " <> pPrintPrec lvl 0 mm
                ]
            BToNom (ToNom ident val) -> PP.text "[" <+> pPrint ident <+> PP.text "pack" <+> pPrintPrec lvl 0 val <+> PP.text "]"
            BLeaf LRecEmpty -> PP.text "{}"
            BRecExtend (RowExtend tag val rest) ->
                PP.text "{" <+>
                prField PP.<>
                PP.comma <+>
                pPrintPrec lvl 0 rest <+>
                PP.text "}"
                where
                    prField = pPrint tag <+> PP.text "=" <+> pPrintPrec lvl 0 val
140 |
-- | The inference scope of a term (see the 'TermVar.ScopeOf' instance below).
data Scope v = Scope
    { _scopeNominals :: Map T.NominalId (LoadedNominalDecl T.Type v)
      -- ^ Loaded nominal type declarations, by id
    , _scopeVarTypes :: Map Var (HFlip G.GTerm T.Type v)
      -- ^ Types of the variables in scope; each use instantiates the stored
      -- type (see the 'TermVar.VarType' instance)
    , _scopeLevel :: ScopeLevel
      -- ^ The current scope level
    } deriving Generic
Lens.makeLenses ''Scope

makeHTraversableAndBases ''Scope
149 |
{-# INLINE emptyScope #-}
-- | A scope with no nominals and no variables, at scope level 0.
emptyScope :: Scope # v
emptyScope =
    Scope
    { _scopeNominals = mempty
    , _scopeVarTypes = mempty
    , _scopeLevel = ScopeLevel 0
    }
153 |
-- Inference configuration for 'Term': variables are looked up in a 'Scope',
-- and the inference result of a term node is its type (an 'ANode' of
-- 'T.Type').
type instance TermVar.ScopeOf Term = Scope
type instance InferOf Term = ANode T.Type
instance HasInferredType Term where
    type instance TypeOf Term = T.Type
    inferredType _ = _ANode

-- Standalone instances for 'Term' and 'Scope'. These must come after the
-- splices above that generate their structure.
makeDerivings [''Eq, ''Ord, ''Show] [''Term, ''Scope]
makeInstances [''Binary, ''NFData, ''Hashable] [''Term, ''Scope]
162 |
-- | The type of a variable occurrence is a fresh instantiation of the type
-- stored for it in the scope.
instance TermVar.VarType Var Term where
    {-# INLINE varType #-}
    varType _ v x = G.instantiate scheme
        where
            -- A variable missing from the scope is treated as a fatal error
            scheme =
                fromMaybe (error ("var not in scope: " <> show v))
                (x ^? scopeVarTypes . Lens.ix v . _HFlip)
169 |
-- | Type inference for terms.
instance
    ( MonadNominals T.NominalId T.Type m
    , MonadScopeLevel m
    , TermVar.HasScope m Scope
    , UnifyGen m T.Type, UnifyGen m T.Row
    , LocalScopeType Var (UVarOf m # T.Type) m
    ) =>
    Infer m Term where

    {-# INLINE inferBody #-}
    inferBody (BApp x) = inferBody x <&> Lens._1 %~ BApp
    inferBody (BLam x) = inferBody x <&> Lens._1 %~ BLam
    inferBody (BToNom x) =
        do
            (xI, xT) <- inferBody x
            newTerm (T.TInst xT) <&> MkANode <&> (BToNom xI, )
    inferBody (BLeaf leaf) =
        case leaf of
            LHole -> newUnbound
            LRecEmpty -> newTerm T.REmpty >>= newTerm . T.TRecord
            LAbsurd ->
                -- From the empty variant to an unconstrained result
                FuncType
                <$> (newTerm T.REmpty >>= newTerm . T.TVariant)
                <*> newUnbound
                >>= newTerm . T.TFun
            LLiteral (PrimVal t _) ->
                T.Types (QVarInstances mempty) (QVarInstances mempty)
                & NominalInst t
                & T.TInst & newTerm
            LVar x ->
                inferBody (TermVar.Var x :: TermVar.Var Var Term # InferChild k w)
                <&> (^. Lens._2 . _ANode)
            LFromNom x ->
                inferBody (FromNom x :: FromNom T.NominalId Term # InferChild k w)
                <&> (^. Lens._2)
                >>= newTerm . T.TFun
            LGetField k ->
                -- record -> field
                do
                    (rT, wR) <- rowElementInfer T.RExtend k
                    T.TRecord wR & newTerm
                        >>= newTerm . T.TFun . (`FuncType` rT)
            LInject k ->
                -- field -> variant
                do
                    (rT, wR) <- rowElementInfer T.RExtend k
                    T.TVariant wR & newTerm
                        >>= newTerm . T.TFun . FuncType rT
        <&> MkANode
        <&> (BLeaf leaf, )
    inferBody (BRecExtend (RowExtend k v r)) =
        do
            InferredChild vI vR <- inferChild v
            InferredChild rI rR <- inferChild r
            -- The extended record must not already contain the field @k@
            restR <-
                scopeConstraints (Proxy @T.Row) <&> T.rForbiddenFields . Lens.contains k .~ True
                >>= newVar binding . UUnbound
            _ <- T.TRecord restR & newTerm >>= unify (rR ^. _ANode)
            RowExtend k (vR ^. _ANode) restR & T.RExtend & newTerm
                >>= newTerm . T.TRecord
                <&> MkANode
                <&> (BRecExtend (RowExtend k vI rI), )
    inferBody (BCase (RowExtend tag handler rest)) =
        do
            InferredChild handlerI handlerT <- inferChild handler
            InferredChild restI restT <- inferChild rest
            fieldT <- newUnbound
            -- The variant handled by @rest@ must not contain @tag@, which
            -- @handler@ already handles. As in 'BRecExtend' (and hypertypes'
            -- 'rowElementInfer'), forbid @tag@ in the remaining row.
            -- Otherwise the whole variant row could end up containing @tag@
            -- twice.
            restR <-
                scopeConstraints (Proxy @T.Row) <&> T.rForbiddenFields . Lens.contains tag .~ True
                >>= newVar binding . UUnbound
            result <- newUnbound
            -- handler : fieldT -> result
            _ <-
                FuncType fieldT result & T.TFun & newTerm
                >>= unify (handlerT ^. _ANode)
            -- rest : +restR -> result
            restVarT <- T.TVariant restR & newTerm
            _ <-
                FuncType restVarT result & T.TFun & newTerm
                >>= unify (restT ^. _ANode)
            -- whole case : +{tag : fieldT | restR} -> result
            whole <- RowExtend tag fieldT restR & T.RExtend & newTerm >>= newTerm . T.TVariant
            FuncType whole result & T.TFun & newTerm
                <&> MkANode
                <&> (BCase (RowExtend tag handlerI restI), )
248 |
-- | Blame support for terms. Inferred term types are unified directly. Two
-- types "match" when 'semiPruneLookup' yields equal first components for
-- both.
instance
    ( MonadNominals T.NominalId T.Type m
    , MonadScopeLevel m
    , TermVar.HasScope m Scope
    , UnifyGen m T.Type, UnifyGen m T.Row
    , LocalScopeType Var (UVarOf m # T.Type) m
    ) =>
    Blame m Term where
    inferOfUnify _ x y = void (unify (x ^. _ANode) (y ^. _ANode))
    inferOfMatches _ x y =
        do
            xVar <- semiPruneLookup (x ^. _ANode) <&> fst
            yVar <- semiPruneLookup (y ^. _ANode) <&> fst
            pure (xVar == yVar)
262 |
-- | A term in which every node is annotated with a plain payload of type
-- @a@, via 'Const'. Kept as a type synonym to ease the transition
-- (presumably from an older, non-hyper-type @Val@ representation).

type Val a = Ann (Const a) # Term
266 |
--------------------------------------------------------------------------------
/src/Lamdu/Calc/Type.hs:
--------------------------------------------------------------------------------
-- | The Lamdu Calculus type AST.
--
-- The Lamdu Calculus type system includes the set of types that can
-- be expressed via the AST elements in this module.
--
-- The Lamdu Calculus type AST is actually 2 different ASTs:
--
-- * The AST for rows ('Row'): structural composite types, shared by records
-- and variants. Whether a row denotes a record or a variant is determined by
-- the 'TRecord' or 'TVariant' constructor that wraps it.
--
-- * The AST for types ('Type'): type variables, nominal types, structural
-- composite types (records, variants), and function types.
13 | {-# LANGUAGE NoImplicitPrelude, DeriveGeneric, GeneralizedNewtypeDeriving, TypeApplications #-}
14 | {-# LANGUAGE TemplateHaskell, DataKinds, StandaloneDeriving, DerivingVia, TypeOperators #-}
15 | {-# LANGUAGE UndecidableInstances, ConstraintKinds, FlexibleContexts, GADTs, TupleSections #-}
16 | {-# LANGUAGE FlexibleInstances, TypeFamilies, MultiParamTypeClasses, RankNTypes #-}
17 |
18 | module Lamdu.Calc.Type
19 | (
20 | -- * Type Variable kinds
21 | RowVar, TypeVar
22 | -- * Typed identifiers of the Type AST
23 | , Var(..), NominalId(..), Tag(..)
24 | -- * Rows
25 | , Row(..), W_Row(..)
26 | -- * Row Prisms
27 | , _RExtend, _REmpty, _RVar
28 | -- * Type AST
29 | , Type(..), W_Type(..)
30 | , Scheme, Nominal
31 | , (~>)
32 | -- * Type Prisms
33 | , _Var
34 | , _TVar, _TFun, _TInst, _TRecord, _TVariant
35 | -- TODO: describe
36 | , Types(..), W_Types(..), tType, tRow
37 | , RConstraints(..), rForbiddenFields, rScope
38 | , rStructureMismatch
39 |
40 | , TypeError(..), _TypeError, _RowError
41 |
42 | , flatRow
43 |
44 | , HPlain(..)
45 | ) where
46 |
47 | import Algebra.PartialOrd (PartialOrd(..))
48 | import qualified Control.Lens as Lens
49 | import Generic.Data (Generically(..))
50 | import Generics.Constraints (makeDerivings, makeInstances)
51 | import Hyper
52 | import Hyper.Class.Optic (HNodeLens(..), HSubset(..))
53 | import Hyper.Infer
54 | import Hyper.Infer.Blame (Blame(..))
55 | import Hyper.Syntax hiding (Var, _Var)
56 | import Hyper.Syntax.Nominal
57 | import Hyper.Syntax.Row
58 | import qualified Hyper.Syntax.Scheme as S
59 | import Hyper.Unify
60 | import Hyper.Unify.New (newTerm)
61 | import Hyper.Unify.QuantifiedVar (HasQuantifiedVar(..))
62 | import Lamdu.Calc.Identifier (Identifier)
63 | import Text.PrettyPrint ((<+>))
64 | import qualified Text.PrettyPrint as PP
65 | import Text.PrettyPrint.HughesPJClass (Pretty(..), maybeParens)
66 |
67 | import Lamdu.Calc.Internal.Prelude
68 |
-- | A type variable of some kind ('Var' 'Type' or 'Var' 'Row'). The
-- parameter @t@ is phantom: it only records which kind of node the variable
-- stands for.
newtype Var (t :: HyperType) = Var { tvName :: Identifier }
    deriving stock Show
    deriving newtype (Eq, Ord, NFData, IsString, Pretty, Binary, Hashable)

-- | An identifier for a nominal type
newtype NominalId = NominalId { nomId :: Identifier }
    deriving stock Show
    deriving newtype (Eq, Ord, NFData, IsString, Pretty, Binary, Hashable)

-- | An identifier for a component in a variant type (aka data
-- constructor) or a component (field) in a record
newtype Tag = Tag { tagName :: Identifier }
    deriving stock Show
    deriving newtype (Eq, Ord, NFData, IsString, Pretty, Binary, Hashable)

-- | A row type variable that represents a set of
-- typed fields in a row
type RowVar = Var Row

-- | A type variable that represents a type
type TypeVar = Var Type
91 |
-- | The AST for rows, shared by records and variants ('TRecord' and
-- 'TVariant' determine which one a row denotes). For example,
-- @RExtend "a" int (RExtend "b" bool (RVar "c"))@ represents the
-- composite type:
--
-- > { a : int, b : bool | c }
--
-- NOTE: constructor order determines the derived 'Ord' instance; do not
-- reorder.
data Row k
    = RExtend (RowExtend Tag Type Row k)
    -- ^ Extend a row type with an extra component (field /
    -- data constructor).
    | REmpty
    -- ^ The empty row type (empty record [unit] or empty variant [void])
    | RVar RowVar
    -- ^ A row type variable
    deriving Generic
105 |
-- | The AST for any Lamdu Calculus type
--
-- NOTE: constructor order determines the derived 'Ord' instance; do not
-- reorder.
data Type k
    = TVar TypeVar
    -- ^ A type variable
    | TFun (FuncType Type k)
    -- ^ A (non-dependent) function of the given parameter and result types
    | TInst (NominalInst NominalId Types k)
    -- ^ An instantiation of the nominal type with the given id, applied to
    -- the given type and row arguments (see 'Types')
    | TRecord (k :# Row)
    -- ^ A record type whose fields are the given row
    | TVariant (k :# Row)
    -- ^ A variant type whose alternatives are the given row
    deriving Generic

-- | A pair of nodes, one of kind 'Type' and one of kind 'Row'. It is used as
-- the argument collection of nominal types ('Nominal', 'TInst') and as the
-- variable collection of 'Scheme'.
data Types k = Types
    { _tType :: k :# Type
      -- ^ The 'Type'-kinded component
    , _tRow :: k :# Row
      -- ^ The 'Row'-kinded component
    } deriving Generic
125 |
-- | Constraints on a row: the tags it must not contain, and its scope level.
-- The 'Semigroup'/'Monoid' instances combine the two fields independently
-- (via 'Generically').
data RConstraints = RowConstraints
    { _rForbiddenFields :: Set Tag
      -- ^ Tags the row must not contain
    , _rScope :: ScopeLevel
      -- ^ The row's scope level
    } deriving (Generic, Eq, Ord, Show)
    deriving (Semigroup, Monoid) via Generically RConstraints

-- | Type errors.
data TypeError k
    = TypeError (UnifyError Type k)
    -- ^ A unification error between types
    | RowError (UnifyError Row k)
    -- ^ A unification error between rows
    | NominalNotFound NominalId
    -- ^ A nominal type with the given id was not found
    deriving Generic
137 |
-- Generated optics and hyper-type class instances. Template Haskell splices
-- split the module into declaration groups, so keep their relative order.
Lens.makeLenses ''RConstraints
Lens.makeLenses ''Types
Lens.makePrisms ''Row
Lens.makePrisms ''Type
Lens.makePrisms ''TypeError
Lens.makePrisms ''Var

makeHTraversableApplyAndBases ''Types
makeHTraversableAndBases ''Row
makeHTraversableAndBases ''Type
makeZipMatch ''Row
makeZipMatch ''Type
makeZipMatch ''Types
makeHContext ''Row
makeHContext ''Type
makeHContext ''Types
instance RNodes Row
instance RNodes Type
-- Types and rows are mutually recursive, so each requires @c@ of both.
instance (c Type, c Row) => Recursively c Type
instance (c Type, c Row) => Recursively c Row
instance RTraversable Row
instance RTraversable Type

-- | A nominal type applied to its type and row arguments.
type Nominal = NominalInst NominalId Types
-- | A type scheme whose quantified variables are collected in 'Types'.
type Scheme = S.Scheme Types Type
163 |
-- Select the 'Type'-kinded or the 'Row'-kinded component of a 'Types' by
-- node kind.
instance HNodeLens Types Type where
    {-# INLINE hNodeLens #-}
    hNodeLens = tType

instance HNodeLens Types Row where
    {-# INLINE hNodeLens #-}
    hNodeLens = tRow

-- 'Types' serves as the variable collection for schemes over both 'Type'
-- and 'Row'.
instance (UnifyGen m Type, UnifyGen m Row) => S.HasScheme Types m Type
instance (UnifyGen m Type, UnifyGen m Row) => S.HasScheme Types m Row
174 |
-- | Infix shorthand for building a pure function type ('TFun') from a
-- parameter type and a result type. It is right-associative, so
-- @a ~> b ~> c@ means @a ~> (b ~> c)@.
infixr 2 ~>
(~>) :: Pure # Type -> Pure # Type -> Pure # Type
(~>) param result = FuncType param result & TFun & (_Pure #)
179 |
-- | Constraint @c@ holds for both the type-kinded and row-kinded child nodes
-- of @k@.
type Deps c k = ((c (k :# Type), c (k :# Row)) :: Constraint)

-- | Pretty-printing of types. The arrow is right-associative: the parameter
-- is printed at precedence 9 and the result at 8. Records are prefixed with
-- @*@ and variants with @+@.
instance Deps Pretty k => Pretty (Type k) where
    pPrintPrec lvl prec typ =
        case typ of
            TFun (FuncType param res) ->
                maybeParens (prec > 8) $
                pPrintPrec lvl 9 param <+> PP.text "->" <+> pPrintPrec lvl 8 res
            TRecord row -> PP.text "*" <> pPrint row
            TVariant row -> PP.text "+" <> pPrint row
            TVar var -> pPrint var
            TInst inst -> pPrint inst
192 |
-- | Rows print inside braces as comma-separated @tag : type@ fields. An open
-- row ends with its row variable followed by "...".
instance Pretty (Row # Pure) where
    pPrint row =
        case row of
            REmpty -> PP.text "{}"
            _ -> PP.text "{" <+> fields PP.empty row <+> PP.text "}"
        where
            fields _ REmpty = PP.empty
            fields prefix (RVar var) = prefix <> pPrint var <> PP.text "..."
            fields prefix (RExtend (RowExtend tag typ rest)) =
                prefix PP.<> pPrint tag <+> PP.text ":" <+> pPrint typ
                PP.<> fields (PP.text ", ") (rest ^. _Pure)
202 |
-- | Prints as @{ type | row }@.
instance Deps Pretty k => Pretty (Types k) where
    pPrint (Types typ row) =
        PP.hsep [PP.text "{", pPrint typ, PP.text "|", pPrint row, PP.text "}"]

-- | Prints the forbidden tags as a list, followed by the scope level in
-- parentheses.
instance Pretty RConstraints where
    pPrint (RowConstraints forbiddenTags scopeLvl) =
        pPrint (Lens.toListOf Lens.folded forbiddenTags)
        <+> PP.parens (pPrint scopeLvl)

instance Pretty (TypeError # Pure) where
    pPrint err =
        case err of
            NominalNotFound missing -> PP.text "Nominal not found:" <+> pPrint missing
            RowError rowErr -> pPrint rowErr
            TypeError typeErr -> pPrint typeErr
215 |
-- The arguments of a nominal 'Type' are collected in 'Types'.
type instance NomVarTypes Type = Types

-- Function types are embedded in 'Type' via 'TFun'.
instance HSubset Type Type (FuncType Type) (FuncType Type) where
    {-# INLINE hSubset #-}
    hSubset = _TFun

-- Nominal instantiations are embedded in 'Type' via 'TInst'.
instance HasNominalInst NominalId Type where
    {-# INLINE nominalInst #-}
    nominalInst = _TInst

-- Quantified variables of types and rows are their 'TVar' / 'RVar' nodes.
instance HasQuantifiedVar Type where
    type QVar Type = TypeVar
    quantifiedVar = _TVar

instance HasQuantifiedVar Row where
    type QVar Row = RowVar
    quantifiedVar = _RVar
233 |
-- Type constraints are scope levels. Pushing a level into a type's children
-- gives child types the same level. Child rows get row constraints with that
-- level and no forbidden tags.
instance HasTypeConstraints Type where
    type instance TypeConstraintsOf Type = ScopeLevel
    {-# INLINE verifyConstraints #-}
    verifyConstraints level typ =
        case typ of
            TVar var -> Just (TVar var)
            TFun func -> Just (TFun (func & hmapped1 %~ WithConstraint level))
            TRecord row -> Just (TRecord (WithConstraint rowConstraints row))
            TVariant row -> Just (TVariant (WithConstraint rowConstraints row))
            TInst (NominalInst nomName (Types typeArgs rowArgs)) ->
                Just . TInst . NominalInst nomName $
                Types
                (typeArgs & S._QVarInstances . traverse %~ WithConstraint level)
                (rowArgs & S._QVarInstances . traverse %~ WithConstraint rowConstraints)
        where
            rowConstraints = RowConstraints mempty level

-- Row constraints are 'RConstraints'. Row extensions are checked with
-- 'verifyRowExtendConstraints', which reads the scope level via 'rScope'.
instance HasTypeConstraints Row where
    type instance TypeConstraintsOf Row = RConstraints
    {-# INLINE verifyConstraints #-}
    verifyConstraints constraints row =
        case row of
            RExtend ext -> fmap RExtend (verifyRowExtendConstraints (^. rScope) constraints ext)
            RVar var -> Just (RVar var)
            REmpty -> Just REmpty
255 |
instance TypeConstraints RConstraints where
    {-# INLINE generalizeConstraints #-}
    -- Generalizing resets the scope level but keeps the forbidden tags
    generalizeConstraints constraints = constraints { _rScope = mempty }
    -- Scope constraints keep the scope level but drop the forbidden tags
    toScopeConstraints constraints = constraints { _rForbiddenFields = mempty }

instance RowConstraints RConstraints where
    type RowConstraintsKey RConstraints = Tag
    {-# INLINE forbidden #-}
    forbidden = rForbiddenFields

-- | Row constraints are ordered component-wise.
instance PartialOrd RConstraints where
    {-# INLINE leq #-}
    leq (RowConstraints fields0 scope0) (RowConstraints fields1 scope1) =
        leq fields0 fields1 && leq scope0 scope1
269 |
-- The inference result of a type (or row) node is the node itself, as the
-- unification variable wrapped in 'ANode'.
type instance InferOf Type = ANode Type
type instance InferOf Row = ANode Row
instance HasInferredValue Type where inferredValue = _ANode
instance HasInferredValue Row where inferredValue = _ANode
274 |
-- Types and rows are inferred structurally. Every child is inferred, then a
-- new unification term is created with the same constructor over the
-- children's inferred values.
instance (Monad m, UnifyGen m Type, UnifyGen m Row) => Infer m Type where
    inferBody x =
        do
            children <- htraverse (const inferChild) x
            node <-
                hmap (Proxy @HasInferredValue #> (^. inType . inferredValue)) children
                & newTerm
            pure (hmap (const (^. inRep)) children, MkANode node)

instance (Monad m, UnifyGen m Type, UnifyGen m Row) => Infer m Row where
    inferBody x =
        do
            children <- htraverse (const inferChild) x
            node <-
                hmap (Proxy @HasInferredValue #> (^. inType . inferredValue)) children
                & newTerm
            pure (hmap (const (^. inRep)) children, MkANode node)
290 |
-- Blame support for rows and types. Inferred values are unified directly.
-- Two values "match" when 'semiPruneLookup' yields equal first components
-- for both.
instance (UnifyGen m Type, UnifyGen m Row) => Blame m Row where
    inferOfUnify _ x y = void (unify (x ^. _ANode) (y ^. _ANode))
    inferOfMatches _ x y =
        do
            xVar <- semiPruneLookup (x ^. _ANode) <&> fst
            yVar <- semiPruneLookup (y ^. _ANode) <&> fst
            pure (xVar == yVar)

instance (UnifyGen m Type, UnifyGen m Row) => Blame m Type where
    inferOfUnify _ x y = void (unify (x ^. _ANode) (y ^. _ANode))
    inferOfMatches _ x y =
        do
            xVar <- semiPruneLookup (x ^. _ANode) <&> fst
            yVar <- semiPruneLookup (y ^. _ANode) <&> fst
            pure (xVar == yVar)
304 |
{-# INLINE rStructureMismatch #-}
-- | Handle two row terms whose structures differ during unification. Two row
-- extensions are handed to 'rowExtendStructureMismatch' with the given child
-- unifier. Any other combination is reported as a 'Mismatch' via
-- 'unifyError'.
rStructureMismatch ::
    (Unify m Type, Unify m Row) =>
    (forall c. Unify m c => UVarOf m # c -> UVarOf m # c -> m (UVarOf m # c)) ->
    Row # UVarOf m ->
    Row # UVarOf m ->
    m ()
rStructureMismatch unifyChild lhs rhs =
    case (lhs, rhs) of
        (RExtend ext0, RExtend ext1) ->
            rowExtendStructureMismatch unifyChild _RExtend ext0 ext1
        _ -> unifyError (Mismatch lhs rhs)
315 |
-- | An isomorphism between a pure row and its flattened 'FlatRowExtends'
-- form, using 'flattenRow' / 'unflattenRow' over the 'RExtend' nodes.
flatRow :: Lens.Iso' (Pure # Row) (FlatRowExtends Tag Type Row # Pure)
flatRow = Lens.iso toFlat fromFlat
    where
        toFlat row =
            Lens.runIdentity
            (flattenRow (\r -> Lens.Identity (r ^? _Pure . _RExtend)) row)
        fromFlat flat =
            Lens.runIdentity
            (unflattenRow (\ext -> Lens.Identity (_Pure . _RExtend # ext)) flat)
326 |
-- Standalone instances for the type ASTs. These splices must come after the
-- splices above that generate the types' structure.
makeDerivings [''Eq, ''Ord, ''Show] [''Row, ''Type, ''Types, ''TypeError]
makeInstances [''Binary, ''NFData] [''Row, ''Type, ''Types, ''TypeError]

makeHasHPlain [''Type, ''Row, ''Types]

-- Generic-based default instances.
instance NFData RConstraints
instance Binary RConstraints
334 |
--------------------------------------------------------------------------------