├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── constraint-rules.cabal
├── src
│   └── Data
│       └── Constraint
│           ├── Rule.hs
│           └── Rule
│               ├── Bool.hs
│               ├── Nat.hs
│               ├── Plugin.hs
│               ├── Plugin
│               │   ├── Cache.hs
│               │   ├── Definitions.hs
│               │   ├── Equiv.hs
│               │   ├── Message.hs
│               │   ├── Prelude.hs
│               │   ├── Rule.hs
│               │   └── Runtime.hs
│               ├── Symbol.hs
│               ├── TH.hs
│               └── Trace.hs
└── test
    ├── Main.hs
    └── Test
        ├── IntroDefs.hs
        ├── IntroSpec.hs
        ├── SimplDefs.hs
        ├── SimplSpec.hs
        └── Util.hs

/.gitignore:
--------------------------------------------------------------------------------
dist
dist-*
cabal-dev
*.o
*.hi
*.hie
*.chi
*.chs.h
*.dyn_o
*.dyn_hi
.hpc
.hsenv
.cabal-sandbox/
cabal.sandbox.config
*.prof
*.aux
*.hp
*.eventlog
.stack-work/
cabal.project.local
cabal.project.local~
.HTF/
.ghc.environment.*

.vscode
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
# Revision history for constraint-rules

## 0.1.0.0 -- YYYY-mm-dd

* Initial release.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2021, Anthony Vandikas

All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

    * Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.

    * Redistributions in binary form must reproduce the above
      copyright notice, this list of conditions and the following
      disclaimer in the documentation and/or other materials provided
      with the distribution.

    * Neither the name of Anthony Vandikas nor the names of other
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# constraint-rules

This package provides a GHC plugin to facilitate the implementation of custom type checking rules, _without_ the need to implement a new type checker plugin every time. To use this plugin, add the following line to the top of each module in which you would like to use it:

```haskell
{-# OPTIONS_GHC -fplugin=Data.Constraint.Rule.Plugin #-}
```

Three types of rules are supported:

* _Introduction_ rules replace wanted constraints with other wanted constraints. They are declared as

  ```haskell
  myIntroRule ∷ (C₁, ..., Cₙ) ⇒ Dict C
  myIntroRule = ...
  ```

  Whenever `C` appears as a wanted constraint, it is replaced by the set of constraints `C₁, ..., Cₙ`.

* _Derivation_ rules add additional constraints to the set of given constraints. They are declared as

  ```haskell
  myDerivRule ∷ (C₁, ..., Cₙ) ⇒ Dict C
  myDerivRule = ...
  ```

  Any type variables appearing in `C` must also appear in at least one of `C₁, ..., Cₙ`. Whenever the constraints `C₁, ..., Cₙ` appear as given constraints, the constraint `C` is added to the set of given constraints.

* _Simplification_ rules add new equalities to the set of given constraints. They are declared as

  ```haskell
  mySimplRule ∷ (C₁, ..., Cₙ) ⇒ Dict (P ~ Q)
  mySimplRule = ...
  ```

  Whenever the constraints `C₁, ..., Cₙ` appear as given constraints and the pattern `P` appears in a given _or_ wanted constraint, the constraint `P ~ Q` is added to the set of given constraints.

## Usage

Suppose we have the introduction rule `plusNat`, and we'd like to use it to type check `example`:

```haskell
import Data.Constraint (Dict (..))
import GHC.TypeNats (KnownNat, type (+))

plusNat ∷ (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (m + n))
plusNat = ...

example ∷ KnownNat n ⇒ Dict (KnownNat (5 + n))
example = Dict -- Error: unsatisfied `KnownNat (5 + n)` constraint.
```

By default, the plugin does not apply a rule unless explicitly told to. We can instruct the plugin to solve the unsatisfied `KnownNat (5 + n)` constraint using `plusNat` by adding an `Intro` constraint to the context:

```haskell
import Data.Constraint.Rule (Intro)
import Data.Constraint.Rule.TH (spec)

example ∷ (Intro $(spec 'plusNat), KnownNat n) ⇒ Dict (KnownNat (5 + n))
example = Dict -- Works!
```

`Intro` constraints have trivial instances, so code that calls `example` does not need to provide one. Another way to do this without changing the type signature of `example` is to use the `withIntro` function, which has the signature `withIntro ∷ Proxy a → (Intro a ⇒ r) → r`:

```haskell
import Data.Constraint.Rule (withIntro)
import Data.Constraint.Rule.TH (spec)

example ∷ KnownNat n ⇒ Dict (KnownNat (5 + n))
example = withIntro $(spec 'plusNat) Dict -- Works!
```

Lastly, we can instruct the plugin to use `plusNat` _by default_ by adding an annotation to `plusNat`:

```haskell
import Data.Constraint.Rule (RuleUsage (..))

{-# ANN plusNat Intro #-}
plusNat ∷ (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (m + n))
plusNat = ...
```

Note that rules annotated in this manner are only applied automatically in modules that import them:

```haskell
import MyRules (plusNat)

example ∷ KnownNat n ⇒ Dict (KnownNat (5 + n))
example = Dict -- Works!
```

This restriction prevents rule implementors from accidentally creating an infinite loop. Without it, the following code would pass the type checker but loop at runtime, because the plugin would solve the `KnownNat (m + n)` constraint in the body of `plusNat` by applying `plusNat` itself:

```haskell
{-# ANN plusNat Intro #-}
plusNat ∷ (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (m + n))
plusNat = Dict
```

To prevent a rule that is annotated for default use from being applied, introduce a `NoIntro` constraint into the context of the expression being type checked:

```haskell
import MyRules (plusNat)
import Data.Constraint.Rule (NoIntro)
import Data.Constraint.Rule.TH (ref)

example ∷ (NoIntro $(ref 'plusNat), KnownNat n) ⇒ Dict (KnownNat (5 + n))
example = Dict -- Unsatisfied `KnownNat (5 + n)` constraint.
```

This can also be achieved using the `ignoreIntro` function:

```haskell
import MyRules (plusNat)
import Data.Constraint.Rule (ignoreIntro)
import Data.Constraint.Rule.TH (ref)

example ∷ KnownNat n ⇒ Dict (KnownNat (5 + n))
example = ignoreIntro $(ref 'plusNat) Dict -- Unsatisfied `KnownNat (5 + n)` constraint.
```

## Use Cases

**Automatically Solving Known\* Constraints.** As shown in the example, this plugin can be used to automatically solve `KnownNat` constraints. Unlike the [ghc-typelits-knownnat](https://hackage.haskell.org/package/ghc-typelits-knownnat) package, this plugin can also be used to solve `KnownSymbol` constraints, as well as other constraints of a similar nature.

**Automatically Solving Type Equalities.** Simplification rules can be used to automatically solve type equality constraints. For example, one can define a rule for the commutativity of addition:

```haskell
{-# ANN plusCommutes Simpl #-}
plusCommutes ∷ Dict ((x + y) ~ (y + x))
plusCommutes = ...
```

Unlike the [typelevel-rewrite-rules](https://hackage.haskell.org/package/typelevel-rewrite-rules) package, simplification rules do not modify any existing constraints; they merely add equality constraints to the context. For example, if we want to solve the constraint `x + 5 ~ y`, the plugin will generate the constraints `x + 5 ~ 5 + x` and `5 + x ~ x + 5` and then stop (instead of endlessly rewriting `x + 5 → 5 + x → x + 5 → ...`). Thus, this plugin terminates in many cases where `typelevel-rewrite-rules` does not.

Note that a full set of rules for associativity and commutativity can cause a blowup in the number of given constraints, so this approach should only be used for small problems. For a more robust solution, see the [ghc-typelits-natnormalise](https://hackage.haskell.org/package/ghc-typelits-natnormalise) package.
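
The difference between rewriting and adding equalities can be illustrated with a toy model in plain Haskell. This is only a simplified sketch and does not use GHC's plugin API: the `Term` type and the `commute`, `rewrite`, and `saturate` functions are hypothetical names invented for this illustration, and the real plugin matches rules against GHC `Type`s and emits coercions rather than working with a `Set` of pairs.

```haskell
import qualified Data.Set as Set

-- Hypothetical stand-in for type-level naturals (not the plugin's representation).
data Term = V String | L Integer | Term :+ Term
  deriving (Eq, Ord, Show)

infixl 6 :+

-- Hypothetical commutativity rule: x + y ~ y + x.
commute :: Term -> Maybe Term
commute (a :+ b) = Just (b :+ a)
commute _        = Nothing

-- A term together with all of its subterms (the plugin also looks inside constraints).
subterms :: Term -> [Term]
subterms t@(a :+ b) = t : subterms a ++ subterms b
subterms t          = [t]

-- Rewriting style: replace the term until no rule applies. With commutativity
-- this never reaches a normal form, so it is cut off by a fuel counter.
-- Left means "ran out of fuel", Right means "reached a normal form".
rewrite :: Int -> Term -> Either Term Term
rewrite 0 t = Left t
rewrite n t = maybe (Right t) (rewrite (n - 1)) (commute t)

-- Saturation style: never remove anything, only add equalities produced by
-- the rule, and stop as soon as an iteration produces nothing new.
saturate :: Set.Set (Term, Term) -> Set.Set (Term, Term)
saturate eqs
  | new `Set.isSubsetOf` eqs = eqs
  | otherwise                = saturate (eqs `Set.union` new)
  where
    terms = Set.fromList (concat [subterms l ++ subterms r | (l, r) <- Set.toList eqs])
    new   = Set.fromList [(t, t') | t <- Set.toList terms, Just t' <- [commute t]]
```

Starting from the single equality `x + 5 ~ y`, `saturate` stops after three iterations with exactly three equalities (the original one, `x + 5 ~ 5 + x`, and `5 + x ~ x + 5`), whereas `rewrite` only stops because its fuel runs out. Saturation terminates here because commutation maps the finite set of terms back into itself, so eventually no new pair can be added.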

## Q & A

**What happens if I give a 'bad' rule implementation?** If your rule implementation evaluates to `⊥`, this will manifest _at runtime_ whenever a constraint produced by your rule is used.

```haskell
badEq ∷ Dict (Eq a)
badEq = error "badEq"

uhOh ∷ a → Bool
uhOh x = withIntro $(spec 'badEq) (x == x)

> uhOh "OH NO"
*** Exception: badEq
```

**Can this break class coherence?** Yes and no.

We can easily break coherence if we use `unsafe*` functions to implement a rule that synthesizes a new class instance.

If we ignore `unsafe*` functions, it is possible to implement a rule that evaluates to `⊥`:

```haskell
{-# ANN badEq Intro #-}
badEq ∷ Dict (Eq a)
badEq = error "badEq"
```

However, this should be the _only_ way in which incoherence arises. If your rule implementations are total, then they will always evaluate to existing instances, and thus coherence is preserved. Otherwise, we have "coherence up to `⊥`".

**In what order are rules applied?** The order is unspecified, so be careful about which rules you annotate! If in doubt, prefer opt-in (`Intro` constraints or `withIntro`) over opt-out (`ANN` annotations).

Note that this plugin runs _after_ the type checker has already tried its hardest to solve your constraints, so existing type class instances take priority over introduction rules.

Additionally, only one rule is applied at a time. Every time a rule is applied, control is given back to the type checker, which then tries to make progress with the new constraints. If the type checker gets stuck again, the plugin is invoked once more, and this process repeats until the constraint is solved or no more rules can be applied.

**Is termination guaranteed?** Not at all! If you have a rule that continuously generates larger and larger constraints, e.g.

```haskell
{-# ANN plusZero Intro #-}
plusZero ∷ KnownNat (n + 1) ⇒ Dict (KnownNat n)
plusZero = ...
```

then type checking can absolutely loop forever. However, non-termination is not guaranteed either in these cases, since GHC takes over between rule applications and might shrink or solve all the remaining constraints.
--------------------------------------------------------------------------------
/constraint-rules.cabal:
--------------------------------------------------------------------------------
cabal-version:      3.0
name:               constraint-rules
version:            0.1
synopsis:           Extend the type checker with user-defined rules
-- description:
homepage:           https://github.com/YellPika/constraint-rules
bug-reports:        https://github.com/YellPika/constraint-rules/issues
license:            BSD-3-Clause
license-file:       LICENSE
author:             Anthony Vandikas
maintainer:         yellpika@gmail.com
copyright:          © 2021 Anthony Vandikas
category:           GHC, Type System
extra-source-files: CHANGELOG.md

library
    exposed-modules:  Data.Constraint.Rule,
                      Data.Constraint.Rule.Bool,
                      Data.Constraint.Rule.Nat,
                      Data.Constraint.Rule.Plugin,
                      Data.Constraint.Rule.Plugin.Runtime,
                      Data.Constraint.Rule.Symbol,
                      Data.Constraint.Rule.TH,
                      Data.Constraint.Rule.Trace
    other-modules:    Data.Constraint.Rule.Plugin.Cache,
                      Data.Constraint.Rule.Plugin.Definitions,
                      Data.Constraint.Rule.Plugin.Equiv,
                      Data.Constraint.Rule.Plugin.Message,
                      Data.Constraint.Rule.Plugin.Prelude,
                      Data.Constraint.Rule.Plugin.Rule
    -- other-extensions:
    build-depends:    base >=4.14 && <4.17,
                      closed-classes >=0.1 && <0.2,
                      constraints >=0.13 && <0.15,
                      ghc >=8.10 && <9.3,
                      ghc-definitions-th >=0.1 && <0.2,
                      template-haskell >=2.16 && <2.19,
                      transformers >=0.5 && <0.6,
    hs-source-dirs:   src
    default-language: Haskell2010
    ghc-options:      -Wall

    if impl(ghc >= 9.0) && impl(ghc < 9.2)
        mixins:
            ghc ( GHC.Core.Type as GHC.Types.TyThing ),
            ghc

    if impl(ghc < 9.0)
        mixins:
            ghc ( Class as GHC.Core.Class
                , Constraint as GHC.Tc.Types.Constraint
                , CoreSyn as GHC.Core
                , ErrUtils as GHC.Utils.Error
                , GhcPlugins as GHC.Plugins
                , LoadIface as GHC.Iface.Load
                , Maybes as GHC.Data.Maybe
                , OccName as GHC.Types.Name.Occurrence
                , Pair as GHC.Data.Pair
                , Predicate as GHC.Core.Predicate
                , RnEnv as GHC.Rename.Env
                , TcEnv as GHC.Tc.Utils.Env
                , TcEvidence as GHC.Tc.Types.Evidence
                , TcOrigin as GHC.Tc.Types.Origin
                , TcPluginM as GHC.Tc.Plugin
                , TcRnMonad as GHC.Tc.Utils.Monad
                , TcRnTypes as GHC.Tc.Types
                , TcType as GHC.Tc.Utils.TcType
                , TyCon as GHC.Core.TyCon
                , Type as GHC.Core.Type
                , Type as GHC.Types.TyThing
                , TysPrim as GHC.Builtin.Types.Prim
                , TysWiredIn as GHC.Builtin.Types
                )

test-suite test
    type:               exitcode-stdio-1.0
    main-is:            Main.hs
    other-modules:      Test.IntroDefs,
                        Test.IntroSpec,
                        Test.SimplDefs,
                        Test.SimplSpec,
                        Test.Util
    build-depends:      base >=4.14 && <4.17,
                        constraints >=0.13 && <0.15,
                        hspec >=2.8 && <2.9,
                        constraint-rules
    build-tool-depends: hspec-discover:hspec-discover >=2.8 && <2.9
    hs-source-dirs:     test
    default-language:   Haskell2010
    ghc-options:        -Wall
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MonoLocalBinds #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UnicodeSyntax #-}

module Data.Constraint.Rule (
  RuleUsage (..), RuleArg (..), RuleName (..), RuleSpec (..),
  Use, Ignore, Intro, Deriv, Simpl, NoIntro, NoDeriv, NoSimpl,
  withIntro, withDeriv, withSimpl,
  ignoreIntro, ignoreDeriv, ignoreSimpl
) where

import Data.Constraint.Rule.Plugin.Prelude

import Data.Class.Closed.TH (close)
import Data.Data (Data, Proxy)
import GHC.TypeLits (Symbol)

data RuleUsage = Intro | Deriv | Simpl
  deriving (Data, Eq, Ord, Show)

instance Outputable RuleUsage where
  ppr = text . show

data RuleArg = ∀a. RuleArg a
data RuleName = RuleName Symbol Symbol
data RuleSpec = RuleSpec RuleName [RuleArg]

close [d|
  class Use (ruleUsage ∷ RuleUsage) (ruleSpec ∷ RuleSpec)
  instance Use (ruleUsage ∷ RuleUsage) (ruleSpec ∷ RuleSpec)
  class Ignore (ruleUsage ∷ RuleUsage) (ruleName ∷ RuleName)
  instance Ignore (ruleUsage ∷ RuleUsage) (ruleName ∷ RuleName)
  |]

type Intro = Use 'Intro
type Deriv = Use 'Deriv
type Simpl = Use 'Simpl
type NoIntro = Ignore 'Intro
type NoDeriv = Ignore 'Deriv
type NoSimpl = Ignore 'Simpl

withIntro ∷ Proxy a → (Intro a ⇒ r) → r
withIntro _ x = x

withDeriv ∷ Proxy a → (Deriv a ⇒ r) → r
withDeriv _ x = x

withSimpl ∷ Proxy a → (Simpl a ⇒ r) → r
withSimpl _ x = x

ignoreIntro ∷ Proxy a → (NoIntro a ⇒ r) → r
ignoreIntro _ x = x

ignoreDeriv ∷ Proxy a → (NoDeriv a ⇒ r) → r
ignoreDeriv _ x = x

ignoreSimpl ∷ Proxy a → (NoSimpl a ⇒ r) → r
ignoreSimpl _ x = x
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule/Bool.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnicodeSyntax #-}

module Data.Constraint.Rule.Bool where

import Data.Constraint (Dict (..), HasDict (..))
import Data.Constraint.Rule (RuleUsage (..))
import Data.Type.Bool (If, Not, type (&&), type (||))
import Data.Type.Equality (TestEquality (..), type (==), (:~:) (..))
import Unsafe.Coerce (unsafeCoerce)

data IsBool b where
  IsTrue ∷ IsBool 'True
  IsFalse ∷ IsBool 'False

deriving instance Eq (IsBool b)
deriving instance Ord (IsBool b)
deriving instance Show (IsBool b)

instance TestEquality IsBool where
  testEquality IsTrue IsTrue = Just Refl
  testEquality IsFalse IsFalse = Just Refl
  testEquality _ _ = Nothing

instance HasDict (KnownBool a) (IsBool a) where
  evidence IsTrue = Dict
  evidence IsFalse = Dict

class KnownBool b where isBool ∷ IsBool b
instance KnownBool 'True where isBool = IsTrue
instance KnownBool 'False where isBool = IsFalse

{-# ANN ifConstraint Intro #-}
ifConstraint ∷ ∀b p q. (KnownBool b, b ~ 'True ⇒ p, b ~ 'False ⇒ q) ⇒ Dict (If b p q)
ifConstraint =
  case isBool @b of
    IsTrue → Dict
    IsFalse → Dict

{-# ANN ifBool Simpl #-}
ifBool ∷ Dict (KnownBool (If b m n) ~ If b (KnownBool m) (KnownBool n))
ifBool = unsafeCoerce (Dict ∷ Dict (a ~ a))

{-# ANN andBool Intro #-}
andBool ∷ ∀b c. (KnownBool b, KnownBool c) ⇒ Dict (KnownBool (b && c))
andBool =
  case isBool @b of
    IsTrue → Dict
    IsFalse → Dict

{-# ANN orBool Intro #-}
orBool ∷ ∀b c. (KnownBool b, KnownBool c) ⇒ Dict (KnownBool (b || c))
orBool =
  case isBool @b of
    IsTrue → Dict
    IsFalse → Dict

{-# ANN notBool Intro #-}
notBool ∷ ∀b. KnownBool b ⇒ Dict (KnownBool (Not b))
notBool =
  case isBool @b of
    IsTrue → Dict
    IsFalse → Dict

{-# ANN equalBool Intro #-}
equalBool ∷ ∀b c. (KnownBool b, KnownBool c) ⇒ Dict (KnownBool (b == c))
equalBool =
  case (isBool @b, isBool @c) of
    (IsTrue, IsTrue) → Dict
    (IsFalse, IsTrue) → Dict
    (IsTrue, IsFalse) → Dict
    (IsFalse, IsFalse) → Dict
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule/Nat.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnicodeSyntax #-}

module Data.Constraint.Rule.Nat where

import Data.Constraint (Dict (..), evidence)
import Data.Constraint.Nat (Gcd, Lcm)
import qualified Data.Constraint.Nat as Nat
import Data.Constraint.Rule (RuleUsage (..))
import Data.Constraint.Rule.Bool (IsBool (..), KnownBool)
import Data.Proxy (Proxy (..))
import Data.Type.Bool (If)
import Data.Type.Equality (type (==))
import GHC.TypeNats (Div, KnownNat, Mod, natVal,
                     type (*), type (+), type (-),
                     type (<=), type (<=?))
import Unsafe.Coerce (unsafeCoerce)

{-# ANN plusNat Intro #-}
plusNat ∷ ∀m n. (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (m + n))
plusNat = evidence (Nat.plusNat @m @n)

{-# ANN minusNat Intro #-}
minusNat ∷ ∀m n. (KnownNat m, KnownNat n, n <= m) ⇒ Dict (KnownNat (m - n))
minusNat = evidence (Nat.minusNat @m @n)

{-# ANN timesNat Intro #-}
timesNat ∷ ∀m n. (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (m * n))
timesNat = evidence (Nat.timesNat @m @n)

{-# ANN divNat Intro #-}
divNat ∷ ∀m n. (KnownNat m, KnownNat n, 1 <= n) ⇒ Dict (KnownNat (Div m n))
divNat = evidence (Nat.divNat @m @n)

{-# ANN modNat Intro #-}
modNat ∷ ∀m n. (KnownNat m, KnownNat n, 1 <= n) ⇒ Dict (KnownNat (Mod m n))
modNat = evidence (Nat.modNat @m @n)

{-# ANN lcmNat Intro #-}
lcmNat ∷ ∀m n. (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (Lcm m n))
lcmNat = evidence (Nat.lcmNat @m @n)

{-# ANN gcdNat Intro #-}
gcdNat ∷ ∀m n. (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (Gcd m n))
gcdNat = evidence (Nat.gcdNat @m @n)

{-# ANN ifNat Simpl #-}
ifNat ∷ Dict (KnownNat (If b m n) ~ If b (KnownNat m) (KnownNat n))
ifNat = unsafeCoerce (Dict ∷ Dict (a ~ a))

{-# ANN lessEqualNat Intro #-}
lessEqualNat ∷ ∀m n. (KnownNat m, KnownNat n) ⇒ Dict (KnownBool (m <=? n))
lessEqualNat =
  if natVal @m Proxy <= natVal @n Proxy
    then unsafeCoerce (evidence IsTrue)
    else unsafeCoerce (evidence IsFalse)

{-# ANN equalNat Intro #-}
equalNat ∷ ∀m n. (KnownNat m, KnownNat n) ⇒ Dict (KnownBool (m == n))
equalNat =
  if natVal @m Proxy == natVal @n Proxy
    then unsafeCoerce (evidence IsTrue)
    else unsafeCoerce (evidence IsFalse)
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule/Plugin.hs:
--------------------------------------------------------------------------------
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE ViewPatterns #-}

module Data.Constraint.Rule.Plugin where

import Data.Constraint.Rule.Plugin.Cache
import Data.Constraint.Rule.Plugin.Definitions
import Data.Constraint.Rule.Plugin.Equiv
import Data.Constraint.Rule.Plugin.Message
import Data.Constraint.Rule.Plugin.Prelude
import Data.Constraint.Rule.Plugin.Rule

import Control.Monad (guard)
import Control.Monad.Trans.State (execStateT, runStateT)
import Data.Constraint (Dict (..))
import Data.Maybe (fromJust, isJust, maybeToList)

plugin ∷ Plugin
plugin = defaultPlugin
  { tcPlugin = \_ → Just TcPlugin
      { tcPluginInit = (,,) <$> findDefinitions <*> newCache <*> newMessages
      , tcPluginSolve = solve
      , tcPluginStop = \(_, _, Dict) → reportMessages
      }
  , pluginRecompile = purePlugin
  }

solve ∷ (Dict Definitions, Dict Cache, Dict Messages) → [Ct] → [Ct] → [Ct] → TcPluginM TcPluginResult
solve (Dict, Dict, Dict) givens deriveds wanteds = do
  Dict ← return (findTraceKeys givens)

  trace @"Constraints" $
    hang (text "Givens:") 4 (ppr givens) $$
    hang (text "Deriveds:") 4 (ppr deriveds) $$
    hang (text "Wanteds:") 4 (ppr wanteds)

  Dict ← return (findEqualities givens)
  let pprCo x = ppr x <+> text "∷" <+> ppr (coercionKind x)
  trace @"Equalities" (ppr (map pprCo ?equalities))

  (Dict, givens') ← findCached givens
  trace @"Cached" (ppr ?cached)

  (rules, givens'') ← findRules givens'
  trace @"Rules" (ppr rules)

  case applyRules givens'' (deriveds ++ wanteds) rules of
    [] → return (TcPluginOk [] [])
    apply:_ → apply

applyRules ∷ (Definitions, Cached, Equalities, Messages, TraceKeys) ⇒ [Ct] → [Ct] → [Rule] → [TcPluginM TcPluginResult]
applyRules givens wanteds rules = do
  rule@Rule {..} ← rules
  apply ← applyRule givens wanteds rule
  return do
    addUsedGREs (maybeToList ruleElt)
    apply

applyRule ∷ (Definitions, Cached, Equalities, Messages, TraceKeys) ⇒ [Ct] → [Ct] → Rule → [TcPluginM TcPluginResult]
applyRule _ wanteds rule@Rule { ruleGoal = IntroGoal template, .. } = do
  ct ← wanteds
  (coe, σ) ← runStateT (match template (equivClass (ctPred ct))) ruleArgs
  return do
    σ' ← instantiate σ ruleVars
    evs ← mapM (newWanted (bumpCtLocDepth (ctLoc ct)) . substTyAddInScope σ') ruleCts
    trace @"Intro" $
      hang (text "Applying Intro") 4 (ppr rule) $$
      text "with" <+> ppr σ'

    let ruleExpr = Var ruleDef
          `mkTyApps` map (fromJust . lookupTyVar σ') ruleVars
          `mkApps` map ctEvExpr evs
        DictTy [goal] = exprType ruleExpr
        openExpr = OpenDictExpr [Type goal, Type goal, ruleExpr]
        castExpr = mkCast openExpr (mkSubCo (mkSymCo coe))
    return (TcPluginOk [(EvExpr castExpr, ct)] (map mkNonCanonical evs))
applyRule givens wanteds rule@Rule { ruleGoal = DerivGoal _, .. } = do
  guard (null wanteds)
  (evs, σ) ← runStateT (matchAny ruleCts (map (\ct → (equivClass (ctPred ct), ctEvidence ct)) givens)) ruleArgs
  let ruleExpr = Var ruleDef
        `mkTyApps` map (fromJust . lookupTyVar σ) ruleVars
        `mkApps` map (\(ev, coe) → mkCast (ctEvExpr ev) (mkSubCo coe)) evs
      DictTy [goal] = exprType ruleExpr
      openExpr = OpenDictExpr [Type goal, Type goal, ruleExpr]
  guard (not (isCached goal))
  return do
    trace @"Deriv" $
      hang (text "Applying Deriv") 4 (ppr rule) $$
      text "with" <+> ppr σ
    cachedExpr ← cached goal
    emitGivens ruleLoc [openExpr, cachedExpr] wanteds
applyRule givens wanteds rule@Rule { ruleGoal = SimplGoal lhs _, .. } = do
  ct ← if null wanteds then givens else wanteds
  guard (not (isHoleCt ct))

  child ← children (ctPred ct)
  σ ← execStateT (match lhs (equivClass child)) ruleArgs
  (evs, σ') ← runStateT (matchAny ruleCts (map (\c → (equivClass (ctPred c), ctEvidence c)) givens)) σ
  let ruleExpr = Var ruleDef
        `mkTyApps` map (fromJust . lookupTyVar σ') ruleVars
        `mkApps` map (\(ev, coe) → mkCast (ctEvExpr ev) (mkSubCo coe)) evs
      DictTy [goal] = exprType ruleExpr
      openExpr = OpenDictExpr [Type goal, Type goal, ruleExpr]
  guard (not (isCached goal))
  return do
    trace @"Simpl" $
      hang (text "Applying Simpl") 4 (ppr rule) $$
      text "with" <+> ppr σ'
    cachedExpr ← cached goal
    emitGivens ruleLoc [openExpr, cachedExpr] wanteds

instantiate ∷ TCvSubst → [Var] → TcPluginM TCvSubst
instantiate σ [] = return σ
instantiate σ (x:xs)
  | isJust (lookupTyVar σ x) = instantiate σ xs
  | otherwise = do
      x' ← newFlexiTyVar (substTyAddInScope σ (varType x))
      instantiate (extendTvSubstAndInScope σ x (mkTyVarTy x')) xs

children ∷ Type → [Type]
children = \t → t : go t where
  go ∷ Type → [Type]
  go (splitAppTy_maybe → Just (t, u)) = children t ++ children u
  go (splitTyConApp_maybe → Just (_, ts)) = concatMap children ts
  go (splitFunTy_maybe → Just (t, u)) = children t ++ children u
  go (splitCastTy_maybe → Just (t, _)) = children t
  go _ = []

emitGivens ∷ (Messages, TraceKeys) ⇒ CtLoc → [EvExpr] → [Ct] → TcPluginM TcPluginResult
emitGivens loc givens [] = do
  evs ← mapM (\e → newGiven loc (exprType e) e) givens
  return (TcPluginOk [] (map mkNonCanonical evs))
emitGivens loc givens wanteds = do
  let remap (EqPrimTy [x, y, z, w]) = HEqTy [x, y, z, w]
      remap t = t
      unmap (EqPrimTy tys) e = Var (classSCSelId heqClass 0) `mkTyApps` tys `mkApps` [e]
      unmap _ e = e
      outputs = map (remap . ctPred) wanteds
  (output, sel) ←
    case outputs of
      [t] → return (t, \e _ _ → e)
      _ → do
        cls ← tcLookupClass (cTupleTyConName (length wanteds))
        return (mkClassPred cls outputs, \e xs x →
          mkSingleAltCase
            e (mkWildValBinder (exprType e))
            (DataAlt (classDataCon cls)) xs
            (Var x))

  ev ← newWanted loc (mkInvisFunTys (map exprType givens) output)
  vars ← mapM (\t → (`mkLocalId` t) <$> newName (mkVarOcc "x")) outputs
  let app = mkApps (ctEvExpr ev) givens
      evs = zipWith (\ct x → unmap (ctPred ct) (sel app vars x)) wanteds vars
  trace @"EmitGivens" $
    hang (text "Adding") 4 (ppr ev) $$
    hang (text "Replacing") 4 (ppr (map (\e → ppr e <+> text "∷" <+> ppr (exprType e)) evs))

  return (TcPluginOk (zip (map EvExpr evs) wanteds) [mkNonCanonical ev])
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule/Plugin/Cache.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE ViewPatterns #-}

module Data.Constraint.Rule.Plugin.Cache where

import Data.Constraint.Rule.Plugin.Definitions
import Data.Constraint.Rule.Plugin.Prelude

import Data.Constraint (Dict (..))
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.List (findIndex)

type Cache = (?cache ∷ IORef [Type])

newCache ∷ TcPluginM (Dict Cache)
newCache = tcPluginIO do
  ref ← newIORef []
  let ?cache = ref
  return Dict

typeIndex ∷ Cache ⇒ Type → TcPluginM Integer
typeIndex t = tcPluginIO do
  cache ← readIORef ?cache
  case findIndex (eqType t) cache of
    Just i → return (fromIntegral i)
    Nothing → do
      writeIORef ?cache (cache ++ [t])
      return (fromIntegral (length cache))

indexType ∷ Cache ⇒ Integer → TcPluginM Type
indexType i = tcPluginIO do
  cache ← readIORef ?cache
  return (cache !! fromIntegral i)

cached ∷ (Definitions, Cache) ⇒ Type → TcPluginM EvExpr
cached t = do
  i ← typeIndex t
  return (CachedExpr [Type (CachedTy [NumLitTy i])])

type Cached = (Cache, ?cached ∷ [Type])

isCached ∷ Cached ⇒ Type → Bool
isCached t = any (eqType t) ?cached

findCached ∷ (Cache, Definitions) ⇒ [Ct] → TcPluginM (Dict Cached, [Ct])
findCached = fmap (\(tys, cts) → let ?cached = tys in (Dict, cts)) . go where
  go [] = return ([], [])
  go ((ctPred → CachedTy [NumLitTy i]) : cts) = do
    t ← indexType i
    (ts, cts') ← go cts
    return (t:ts, cts')
  go (ct:cts) = do
    (ts, cts') ← go cts
    return (ts, ct:cts')
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule/Plugin/Definitions.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE ViewPatterns #-}

module Data.Constraint.Rule.Plugin.Definitions (
  Definitions,
  findDefinitions,
  pattern DictTy,
  pattern UseTy,
  pattern IgnoreTy,
  pattern RuleUsageTy,
  pattern RuleArgTy,
  pattern RuleArgListTy,
  pattern RuleNameTy,
  pattern RuleSpecTy,
  pattern CachedTy,
  pattern CachedExpr,
  pattern OpenDictExpr,
  pattern TraceTy,
  pattern EqPrimTy,
  pattern EqTy,
  pattern HEqTy,
  pattern NumLitTy,
  pattern StrLitTy,
  pattern ListTy
) where

import Data.Constraint.Rule
import Data.Constraint.Rule.Plugin.Prelude
import Data.Constraint.Rule.Plugin.Runtime
import Data.Constraint.Rule.Trace

import Data.Constraint (Dict (..))
import GHC.Definitions.TH (makeDefinitions, makePattern)

makeDefinitions
  [ ''Dict
  , ''Use
  , ''Ignore
  , 'Intro
  , 'Deriv
  , 'Simpl
  , 'RuleArg
  , 'RuleName
  , 'RuleSpec
  , ''Cached
  , 'cached
  , 'unsafeOpenDict
  , ''Trace
  ]

makePattern "DictTy" 'dictTyCon
makePattern "UseTy" 'useClass
makePattern "IgnoreTy" 'ignoreClass
makePattern "IntroTy" 'promotedIntroTyCon
makePattern "DerivTy" 'promotedDerivTyCon
makePattern "SimplTy" 'promotedSimplTyCon
makePattern "RuleArgTy" 'promotedRuleArgTyCon
makePattern "RuleNameTy" 'promotedRuleNameTyCon
makePattern "RuleSpecTy" 'promotedRuleSpecTyCon
makePattern "CachedTy" 'cachedClass
makePattern "CachedExpr" 'cachedVar
makePattern "OpenDictExpr" 'unsafeOpenDictVar
makePattern "TraceTy" 'traceClass

makePattern "EqPrimTy" 'eqPrimTyCon
makePattern "HEqTy" 'heqTyCon
makePattern "EqTy" 'eqTyCon
makePattern "NilTy" 'promotedNilDataCon
makePattern "ConsTy" 'promotedConsDataCon

pattern RuleUsageTy ∷ Definitions ⇒ RuleUsage → Type
pattern RuleUsageTy x ← (isRuleUsageTy → Just x)
  where RuleUsageTy Intro = IntroTy []
        RuleUsageTy Deriv = DerivTy []
        RuleUsageTy Simpl = SimplTy []

isRuleUsageTy ∷ Definitions ⇒ Type → Maybe RuleUsage
isRuleUsageTy (IntroTy []) = Just Intro
isRuleUsageTy (DerivTy []) = Just Deriv
isRuleUsageTy (SimplTy []) = Just Simpl
isRuleUsageTy _ = Nothing

pattern NumLitTy ∷ Integer → Type
pattern NumLitTy x ← (isNumLitTy → Just x)
  where NumLitTy x = mkNumLitTy x

pattern StrLitTy ∷ FastString → Type
pattern StrLitTy x ← (isStrLitTy → Just x)

pattern ListTy ∷ [Type] → Type
pattern ListTy xs ← (isListTy → Just xs)

isListTy ∷ Type → Maybe [Type]
isListTy (NilTy [_]) = Just []
isListTy (ConsTy [_, x, xs]) = (x:) <$> isListTy xs
isListTy _ = Nothing

pattern RuleArgListTy ∷ Definitions ⇒ [Type] → Type
pattern RuleArgListTy xs ← (ListTy (isRuleArgListTy → Just xs))

isRuleArgListTy ∷ Definitions ⇒ [Type] → Maybe [Type]
isRuleArgListTy [] = Just []
isRuleArgListTy (RuleArgTy [_, x] : xs) = (x:) <$> isRuleArgListTy xs
isRuleArgListTy _ = Nothing
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule/Plugin/Equiv.hs:
--------------------------------------------------------------------------------
{-# LANGUAGE ApplicativeDo #-} 2 | {-# LANGUAGE BlockArguments #-} 3 | {-# LANGUAGE ConstraintKinds #-} 4 | {-# LANGUAGE ImplicitParams #-} 5 | {-# LANGUAGE UnicodeSyntax #-} 6 | {-# LANGUAGE ViewPatterns #-} 7 | 8 | module Data.Constraint.Rule.Plugin.Equiv ( 9 | Equalities, findEqualities, 10 | EquivClass, equivClass, contains, 11 | match, matchAny 12 | ) where 13 | 14 | import Data.Constraint.Rule.Plugin.Prelude hiding (empty) 15 | 16 | import Control.Applicative (Alternative (..)) 17 | import Control.Monad (guard, zipWithM) 18 | import Control.Monad.Trans.Class (lift) 19 | import Control.Monad.Trans.State (StateT (..), get, put) 20 | import Data.Bifunctor (first) 21 | import Data.Constraint (Dict (..)) 22 | import Data.Foldable (asum, toList) 23 | import Data.List.NonEmpty (NonEmpty (..), head) 24 | import Data.Maybe (isJust) 25 | import Prelude hiding (head, reverse) 26 | 27 | type Equalities = (?equalities ∷ [Coercion]) 28 | 29 | findEqualities ∷ [Ct] → Dict Equalities 30 | findEqualities cts = let ?equalities = eqs in Dict where 31 | eqs = 32 | map (ctEvCoercion . ctEvidence) . 33 | filter ((== Nominal) . getEqPredRole . ctPred) . 34 | filter (isEqPrimPred . ctPred) $ 35 | cts 36 | 37 | data EquivClass = EquivClass 38 | (NonEmpty (EquivInst, Coercion)) 39 | EquivClass 40 | -- type EquivClass = NonEmpty (EquivInst, Coercion) 41 | 42 | data EquivInst 43 | = EApp EquivClass EquivClass 44 | | ETyCon TyCon [EquivClass] 45 | | EAtomic Type 46 | 47 | repr ∷ EquivClass → Type 48 | repr = fst . head . 
flatten 49 | 50 | flatten ∷ EquivClass → NonEmpty (Type, Coercion) 51 | flatten (EquivClass t _) = do 52 | (t₀, coe₀) ← t 53 | (t₁, coe₁) ← flattenInst t₀ 54 | return (t₁, mkTransCo coe₀ coe₁) 55 | 56 | flattenInst ∷ EquivInst → NonEmpty (Type, Coercion) 57 | flattenInst (EApp t u) = do 58 | (t', coet) ← flatten t 59 | (u', coeu) ← flatten u 60 | return (mkAppTy t' u', mkAppCo coet coeu) 61 | flattenInst (ETyCon c ts) = do 62 | ts' ← mapM flatten ts 63 | return (mkTyConApp c (map fst ts'), mkTyConAppCo Nominal c (map snd ts')) 64 | flattenInst (EAtomic t) = (t, mkNomReflCo t) :| [] 65 | 66 | containsCo ∷ Alternative f ⇒ EquivClass → Type → f Coercion 67 | containsCo (EquivClass ty' _) ty = asum (fmap go ty') where 68 | go (EApp t' u', coe) | Just (t, u) ← splitAppTy_maybe ty = do 69 | coe₁ ← containsCo t' t 70 | coe₂ ← containsCo u' u 71 | return (mkTransCo coe (mkAppCo coe₁ coe₂)) 72 | go (ETyCon c' ts', coe) | Just (c, ts) ← splitTyConApp_maybe ty = do 73 | guard (c == c' && length ts == length ts') 74 | coes ← zipWithM containsCo ts' ts 75 | return (mkTransCo coe (mkTyConAppCo Nominal c coes)) 76 | go (EAtomic t, coe) = do 77 | guard (eqType t ty) 78 | return coe 79 | go _ = empty 80 | 81 | contains ∷ EquivClass → Type → Bool 82 | contains ty' ty = isJust (containsCo ty' ty) 83 | 84 | sing ∷ Type → EquivClass 85 | sing ty = EquivClass ((ty', mkNomReflCo ty) :| []) (sing (typeKind ty)) where 86 | ty' | Just (t, u) ← splitAppTy_maybe ty = EApp (sing t) (sing u) 87 | | Just (c, ts) ← splitTyConApp_maybe ty = ETyCon c (map sing ts) 88 | | otherwise = EAtomic ty 89 | 90 | close ∷ Coercion → EquivClass → EquivClass → EquivClass → EquivClass 91 | close co lhs@(EquivClass lhs' _) rhs@(EquivClass rhs' _) = go where 92 | go ∷ EquivClass → EquivClass 93 | go c@(EquivClass (t :| ts) k) = EquivClass 94 | (first goInst t :| map (first goInst) (ts ++ lhs'' ++ rhs'')) 95 | (go k) 96 | where clhs = c `containsCo` repr lhs 97 | crhs = c `containsCo` repr rhs 98 | lhs'' | Just co' 
← crhs 99 | , Nothing ← clhs 100 | = map (fmap (mkTransCo (mkTransCo co' (mkSymCo co)))) (toList lhs') 101 | | otherwise = [] 102 | rhs'' | Just co' ← clhs 103 | , Nothing ← crhs 104 | = map (fmap (mkTransCo (mkTransCo co' co))) (toList rhs') 105 | | otherwise = [] 106 | 107 | goInst ∷ EquivInst → EquivInst 108 | goInst (EApp t u) = EApp (go t) (go u) 109 | goInst (ETyCon c ts) = ETyCon c (map go ts) 110 | goInst (EAtomic t) = EAtomic t 111 | 112 | equivClass ∷ Equalities ⇒ Type → EquivClass 113 | equivClass = go ?equalities where 114 | go [] t = sing t 115 | go (eq@(coercionKind → Pair lhs rhs):eqs) t = 116 | close eq (go eqs lhs) (go eqs rhs) (go eqs t) 117 | 118 | match ∷ Type → EquivClass → StateT TCvSubst [] Coercion 119 | match template goal@(EquivClass _ goalKind) | Just x ← getTyVar_maybe template = do 120 | σ ← get 121 | case lookupTyVar σ x of 122 | Just t → maybe empty return (goal `containsCo` t) 123 | Nothing → do 124 | let (t, coe) = head (flatten goal) 125 | put (extendTvSubstAndInScope σ x t) 126 | _ ← match (varType x) goalKind 127 | return coe 128 | match template (EquivClass goal _) = asum (fmap go goal) where 129 | go (EApp t' u', coe) | Just (t, u) ← splitAppTy_maybe template = do 130 | coe₁ ← match t t' 131 | coe₂ ← match u u' 132 | return (mkTransCo coe (mkAppCo coe₁ coe₂)) 133 | go (ETyCon c' ts', coe) | Just (c, ts) ← splitTyConApp_maybe template = do 134 | guard (c == c' && length ts == length ts') 135 | coes ← zipWithM match ts ts' 136 | return (mkTransCo coe (mkTyConAppCo Nominal c coes)) 137 | go (EAtomic t, coe) = do 138 | guard (eqType t template) 139 | return coe 140 | go _ = empty 141 | 142 | matchAny ∷ [Type] → [(EquivClass, a)] → StateT TCvSubst [] [(a, Coercion)] 143 | matchAny [] _ = return [] 144 | matchAny (template:templates) goals = do 145 | ((goal, x), goals') ← lift (select goals) 146 | coe ← match template goal 147 | xs ← matchAny templates goals' 148 | return ((x, coe):xs) 149 | 150 | select ∷ [a] → [(a, [a])] 151 | 
select [] = [] 152 | select (x:xs) = (x, xs) : do 153 | (y, ys) ← select xs 154 | return (y, x:ys) 155 | -------------------------------------------------------------------------------- /src/Data/Constraint/Rule/Plugin/Message.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ConstraintKinds #-} 2 | {-# LANGUAGE ImplicitParams #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | {-# LANGUAGE TypeApplications #-} 5 | {-# LANGUAGE UnicodeSyntax #-} 6 | {-# LANGUAGE ViewPatterns #-} 7 | 8 | module Data.Constraint.Rule.Plugin.Message where 9 | 10 | import Data.Constraint.Rule.Plugin.Definitions 11 | import Data.Constraint.Rule.Plugin.Prelude 12 | import Data.Constraint.Rule.Trace 13 | 14 | import Data.Bifunctor (first, second) 15 | import Data.Constraint (Dict (..)) 16 | import Data.IORef (IORef, modifyIORef', newIORef, readIORef) 17 | import Data.Proxy (Proxy (..)) 18 | import GHC.TypeLits (KnownSymbol, symbolVal) 19 | 20 | type Messages = ?messages ∷ IORef ([(SrcSpan, SDoc)], [(SrcSpan, SDoc)]) 21 | 22 | newMessages ∷ TcPluginM (Dict Messages) 23 | newMessages = do 24 | ref ← tcPluginIO (newIORef ([], [])) 25 | let ?messages = ref 26 | return Dict 27 | 28 | addErrorMessage ∷ Messages ⇒ SDoc → TcPluginM () 29 | addErrorMessage doc = do 30 | spn ← getSrcSpanM 31 | tcPluginIO (modifyIORef' ?messages (first ((spn, doc):))) 32 | 33 | addWarningMessage ∷ Messages ⇒ SDoc → TcPluginM () 34 | addWarningMessage doc = do 35 | spn ← getSrcSpanM 36 | tcPluginIO (modifyIORef' ?messages (second ((spn, doc):))) 37 | 38 | reportMessages ∷ Messages ⇒ TcPluginM () 39 | reportMessages = do 40 | flags ← getDynFlags 41 | unqual ← getPrintUnqualified flags 42 | let undup = foldr insert [] 43 | 44 | pick (srcSpanStart → toRealSrcLoc → Just loc, _) msg 45 | | srcLocLine loc == 1 && srcLocCol loc == 1 = [msg] 46 | pick msg (srcSpanStart → toRealSrcLoc → Just loc, _) 47 | | srcLocLine loc == 1 && srcLocCol loc == 1 = [msg] 48 | pick msg@(loc, _) 
msg'@(loc', _) 49 | | loc == loc' = [msg] 50 | | otherwise = [msg, msg'] 51 | 52 | insert msg [] = [msg] 53 | insert msg@(spn, doc) (msg'@(spn', doc'):msgs) 54 | | show (mkLongErrMsg flags spn unqual doc empty) == 55 | show (mkLongErrMsg flags spn unqual doc' empty) = pick msg msg' ++ msgs 56 | | otherwise = (spn', doc') : insert msg msgs 57 | 58 | (errs, warns) ← tcPluginIO (readIORef ?messages) 59 | mapM_ (uncurry addErrAt) (undup errs) 60 | mapM_ (uncurry (addWarnAt NoReason)) (undup warns) 61 | 62 | type TraceKeys = (?traceKeys ∷ [FastString]) 63 | 64 | findTraceKeys ∷ Definitions ⇒ [Ct] → Dict TraceKeys 65 | findTraceKeys cts = let ?traceKeys = go cts in Dict where 66 | go [] = [] 67 | go ((ctPred → TraceTy [StrLitTy key]) : xs) = key : go xs 68 | go (_ : xs) = go xs 69 | 70 | trace ∷ ∀key. (KnownSymbol key, TraceKey key, Messages, TraceKeys) ⇒ SDoc → TcPluginM () 71 | trace doc 72 | | key `elem` ?traceKeys = addWarningMessage doc 73 | | otherwise = return () 74 | where key = fsLit (symbolVal (Proxy @key)) 75 | -------------------------------------------------------------------------------- /src/Data/Constraint/Rule/Plugin/Prelude.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE CPP #-} 2 | {-# LANGUAGE UnicodeSyntax #-} 3 | 4 | module Data.Constraint.Rule.Plugin.Prelude ( 5 | module GHC, 6 | addErrAt, 7 | addUsedGREs, 8 | addWarnAt, 9 | getCtLocM, 10 | getDynFlags, 11 | getPrintUnqualified, 12 | getSrcSpanM, 13 | grePrintableName, 14 | isHoleCt, 15 | mkInvisFunTys, 16 | mkLocalId, 17 | mkWildValBinder, 18 | newName, 19 | newWanted, 20 | setSrcSpan, 21 | splitFunTy_maybe, 22 | tcSplitForAllTys, 23 | toRealSrcLoc, 24 | toSrcSpan 25 | ) where 26 | 27 | import GHC.Plugins as GHC (AltCon (..), AnnTarget (..), Coercion, DynFlags, 28 | Expr (..), FastString, GlobalRdrElt, Name, OccName, 29 | Outputable (..), Plugin (..), PrintUnqualified, 30 | RealSrcLoc, RealSrcSpan, Role (..), SDoc, 31 | SrcLoc (..), SrcSpan 
(..), TCvSubst, TyCon, Type, 32 | TypeOrKind, Var, WarnReason (..), cTupleTyConName, 33 | classDataCon, coercionKind, defaultPlugin, 34 | deserializeWithData, empty, emptyTCvSubst, eqTyCon, 35 | eqType, exprType, extendTvSubstAndInScope, findAnns, 36 | fsLit, getTyVar_maybe, hang, heqClass, heqTyCon, 37 | isGoodSrcSpan, isLocalGRE, isNumLitTy, isStrLitTy, 38 | lookupTyVar, mkAppCo, mkAppTy, mkApps, mkCast, 39 | mkModuleNameFS, mkNomReflCo, mkNumLitTy, 40 | mkSingleAltCase, mkSubCo, mkSymCo, mkTransCo, 41 | mkTvSubstPrs, mkTyApps, mkTyConApp, mkTyConAppCo, 42 | mkTyVarTy, mkVarOcc, mkVarOccFS, occEnvElts, 43 | pprTrace, pprTraceM, prepareAnnotations, 44 | promotedConsDataCon, promotedNilDataCon, purePlugin, 45 | quotes, showSDocForUser, splitAppTy_maybe, 46 | splitCastTy_maybe, splitTyConApp, 47 | splitTyConApp_maybe, srcLocCol, srcLocLine, 48 | srcSpanStart, substTyAddInScope, text, typeKind, 49 | varName, varType, ($$), (<+>)) 50 | 51 | import GHC.Tc.Plugin as GHC (FindResult (..), TcPluginM, findImportedModule, 52 | getEnvs, getEvBindsTcPluginM, getTopEnv, 53 | lookupOrig, newFlexiTyVar, newGiven, tcLookupClass, 54 | tcPluginIO, tcPluginTrace, unsafeTcPluginTcM) 55 | 56 | import GHC.Tc.Types.Constraint as GHC (Ct, CtEvidence (..), CtLoc, 57 | bumpCtLocDepth, ctEvCoercion, ctEvExpr, 58 | ctEvidence, ctLoc, ctLocSpan, ctPred, 59 | mkNonCanonical, pprCtLoc) 60 | 61 | import GHC.Tc.Utils.Monad as GHC (TcGblEnv (..), TcPlugin (..), 62 | TcPluginResult (..), TcTyThing (..), 63 | mapMaybeM, runTcPluginM) 64 | 65 | import GHC.Builtin.Types.Prim as GHC (eqPrimTyCon) 66 | import GHC.Core.Class as GHC (classSCSelId) 67 | import GHC.Core.Predicate as GHC (getEqPredRole, getEqPredTys, isEqPrimPred, 68 | mkClassPred) 69 | import GHC.Data.Maybe as GHC (MaybeErr (..)) 70 | import GHC.Data.Pair as GHC (Pair (..)) 71 | import GHC.Iface.Load as GHC (tcLookupImported_maybe) 72 | import GHC.Tc.Types.Evidence as GHC (EvExpr, EvTerm (..)) 73 | import GHC.Tc.Types.Origin as GHC 
(CtOrigin (..)) 74 | import GHC.Tc.Utils.Env as GHC (tcLookupLcl_maybe) 75 | import GHC.Tc.Utils.TcType as GHC (tcSplitPhiTy) 76 | import GHC.Types.TyThing as GHC (TyThing (..)) 77 | import GHC.Utils.Error as GHC (mkLongErrMsg) 78 | 79 | import qualified GHC.Plugins as GHC.Internal hiding 80 | (getPrintUnqualified, 81 | getSrcSpanM) 82 | import qualified GHC.Rename.Env as GHC.Internal 83 | import qualified GHC.Tc.Plugin as GHC.Internal 84 | import qualified GHC.Tc.Utils.Monad as GHC.Internal 85 | import qualified GHC.Tc.Utils.TcType as GHC.Internal 86 | 87 | #if !MIN_VERSION_ghc(9, 0, 1) 88 | import qualified GHC.Tc.Types.Constraint as GHC.Internal 89 | #endif 90 | 91 | addErrAt ∷ SrcSpan → SDoc → TcPluginM () 92 | addErrAt spn msg = unsafeTcPluginTcM (GHC.Internal.addErrAt spn msg) 93 | 94 | addUsedGREs ∷ [GlobalRdrElt] → TcPluginM () 95 | addUsedGREs xs = unsafeTcPluginTcM (GHC.Internal.addUsedGREs xs) 96 | 97 | addWarnAt ∷ WarnReason → SrcSpan → SDoc → TcPluginM () 98 | addWarnAt rsn spn msg = unsafeTcPluginTcM (GHC.Internal.addWarnAt rsn spn msg) 99 | 100 | getCtLocM ∷ CtOrigin → Maybe TypeOrKind → TcPluginM CtLoc 101 | getCtLocM x y = unsafeTcPluginTcM (GHC.Internal.getCtLocM x y) 102 | 103 | getDynFlags ∷ TcPluginM DynFlags 104 | getDynFlags = unsafeTcPluginTcM GHC.Internal.getDynFlags 105 | 106 | getSrcSpanM ∷ TcPluginM SrcSpan 107 | getSrcSpanM = unsafeTcPluginTcM GHC.Internal.getSrcSpanM 108 | 109 | grePrintableName ∷ GlobalRdrElt → Name 110 | #if MIN_VERSION_ghc(9, 2, 0) 111 | grePrintableName = GHC.Internal.grePrintableName 112 | #else 113 | grePrintableName = GHC.Internal.gre_name 114 | #endif 115 | 116 | getPrintUnqualified ∷ DynFlags → TcPluginM PrintUnqualified 117 | getPrintUnqualified flags = unsafeTcPluginTcM (GHC.Internal.getPrintUnqualified flags) 118 | 119 | isHoleCt ∷ Ct → Bool 120 | #if MIN_VERSION_ghc(9, 0, 1) 121 | isHoleCt _ = False 122 | #else 123 | isHoleCt = GHC.Internal.isHoleCt 124 | #endif 125 | 126 | mkInvisFunTys ∷ [Type] → Type → 
Type 127 | #if MIN_VERSION_ghc(9, 0, 1) 128 | mkInvisFunTys = GHC.Internal.mkInvisFunTysMany 129 | #else 130 | mkInvisFunTys = GHC.Internal.mkInvisFunTys 131 | #endif 132 | 133 | mkLocalId ∷ Name → Type → Var 134 | #if MIN_VERSION_ghc(9, 0, 1) 135 | mkLocalId n = GHC.Internal.mkLocalId n GHC.Internal.Many 136 | #else 137 | mkLocalId = GHC.Internal.mkLocalId 138 | #endif 139 | 140 | mkWildValBinder ∷ Type → Var 141 | #if MIN_VERSION_ghc(9, 0, 1) 142 | mkWildValBinder = GHC.Internal.mkWildValBinder GHC.Internal.Many 143 | #else 144 | mkWildValBinder = GHC.Internal.mkWildValBinder 145 | #endif 146 | 147 | newName ∷ OccName → TcPluginM Name 148 | newName x = unsafeTcPluginTcM (GHC.Internal.newName x) 149 | 150 | newWanted ∷ CtLoc → Type → TcPluginM CtEvidence 151 | newWanted loc ty = (\ev → ev { ctev_loc = loc }) <$> GHC.Internal.newWanted loc ty 152 | 153 | setSrcSpan ∷ SrcSpan → TcPluginM a → TcPluginM a 154 | setSrcSpan spn m = do 155 | binds ← getEvBindsTcPluginM 156 | unsafeTcPluginTcM (GHC.Internal.setSrcSpan spn (runTcPluginM m binds)) 157 | 158 | {-# ANN splitFunTy_maybe "HLint: ignore Use camelCase" #-} 159 | splitFunTy_maybe ∷ Type → Maybe (Type, Type) 160 | #if MIN_VERSION_ghc(9, 0, 1) 161 | splitFunTy_maybe t = (\(_, arg, res) → (arg, res)) <$> GHC.Internal.splitFunTy_maybe t 162 | #else 163 | splitFunTy_maybe = GHC.Internal.splitFunTy_maybe 164 | #endif 165 | 166 | tcSplitForAllTys ∷ Type → ([Var], Type) 167 | #if MIN_VERSION_ghc(9, 2, 0) 168 | tcSplitForAllTys = GHC.Internal.tcSplitForAllTyVars 169 | #else 170 | tcSplitForAllTys = GHC.Internal.tcSplitForAllTys 171 | #endif 172 | 173 | toRealSrcLoc ∷ SrcLoc → Maybe RealSrcLoc 174 | #if MIN_VERSION_ghc(9, 0, 1) 175 | toRealSrcLoc (RealSrcLoc x _) = Just x 176 | #else 177 | toRealSrcLoc (RealSrcLoc x) = Just x 178 | #endif 179 | toRealSrcLoc _ = Nothing 180 | 181 | toSrcSpan ∷ RealSrcSpan → SrcSpan 182 | #if MIN_VERSION_ghc(9, 0, 1) 183 | toSrcSpan x = RealSrcSpan x Nothing 184 | #else 185 | toSrcSpan = 
RealSrcSpan 186 | #endif 187 | -------------------------------------------------------------------------------- /src/Data/Constraint/Rule/Plugin/Rule.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE BlockArguments #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE OverloadedStrings #-} 4 | {-# LANGUAGE RecordWildCards #-} 5 | {-# LANGUAGE TupleSections #-} 6 | {-# LANGUAGE TypeApplications #-} 7 | {-# LANGUAGE UnicodeSyntax #-} 8 | {-# LANGUAGE ViewPatterns #-} 9 | 10 | module Data.Constraint.Rule.Plugin.Rule ( 11 | Rule (..), RuleGoal (..), findRules 12 | ) where 13 | 14 | import Data.Constraint.Rule 15 | import Data.Constraint.Rule.Plugin.Definitions 16 | import Data.Constraint.Rule.Plugin.Equiv 17 | import Data.Constraint.Rule.Plugin.Message 18 | import Data.Constraint.Rule.Plugin.Prelude hiding (empty) 19 | 20 | import Control.Applicative (empty) 21 | import Control.Monad.Trans.Class (lift) 22 | import Control.Monad.Trans.Maybe (MaybeT (..)) 23 | import Data.Function (on) 24 | import Data.List (nub, sort, sortBy) 25 | import Data.Maybe (mapMaybe) 26 | 27 | data Rule = Rule { 28 | ruleDef ∷ Var, 29 | ruleArgs ∷ TCvSubst, 30 | ruleVars ∷ [Var], 31 | ruleCts ∷ [Type], 32 | ruleGoal ∷ RuleGoal, 33 | ruleLoc ∷ CtLoc, 34 | ruleElt ∷ Maybe GlobalRdrElt 35 | } 36 | 37 | instance Outputable Rule where 38 | ppr Rule {..} = 39 | ppr ruleDef <+> ppr (mapMaybe (\var → (,) var <$> lookupTyVar ruleArgs var) ruleVars) <+> "∷" <+> 40 | "∀" <+> ppr ruleVars <+> 41 | "." <+> ppr ruleCts <+> 42 | "⇒" <+> ppr ruleGoal 43 | 44 | ruleName ∷ Rule → Name 45 | ruleName = varName . ruleDef 46 | 47 | ruleUsage ∷ Rule → RuleUsage 48 | ruleUsage = ruleGoalKind . 
ruleGoal 49 | 50 | data RuleGoal 51 | = IntroGoal Type 52 | | DerivGoal Type 53 | | SimplGoal Type Type 54 | 55 | instance Outputable RuleGoal where 56 | ppr (IntroGoal goal) = ppr goal <+> "(Intro)" 57 | ppr (DerivGoal goal) = ppr goal <+> "(Deriv)" 58 | ppr (SimplGoal lhs rhs) = ppr lhs <+> "~" <+> ppr rhs <+> "(Simpl)" 59 | 60 | ruleGoalKind ∷ RuleGoal → RuleUsage 61 | ruleGoalKind (IntroGoal _) = Intro 62 | ruleGoalKind (DerivGoal _) = Deriv 63 | ruleGoalKind (SimplGoal _ _) = Simpl 64 | 65 | findRules ∷ (Definitions, Equalities, Messages, TraceKeys) ⇒ [Ct] → TcPluginM ([Rule], [Ct]) 66 | findRules givens = do 67 | ignored ← findIgnored givens 68 | (specifieds, givens') ← findSpecifiedRules givens 69 | defaults ← findDefaultRules 70 | let rules = 71 | sortBy (compare `on` ruleUsage) . 72 | filter (\rule → (ruleUsage rule, ruleName rule) `notElem` ignored) $ 73 | -- Specified rules take precedence over default rules. 74 | specifieds ++ defaults 75 | return (rules, givens') 76 | 77 | findIgnored ∷ (Definitions, Equalities, Messages) ⇒ [Ct] → TcPluginM [(RuleUsage, Name)] 78 | findIgnored = mapMaybeM (go . 
ctPred) where 79 | go (IgnoreTy [RuleUsageTy kind, RuleNameTy [StrLitTy md, StrLitTy var]]) = 80 | fmap (kind,) <$> lookupName md var 81 | go t@(IgnoreTy _) = do 82 | addErrorMessage (hang "Malformatted Ignore constraint:" 4 (ppr t)) 83 | return Nothing 84 | go _ = return Nothing 85 | 86 | findSpecifiedRules ∷ (Definitions, Equalities, Messages, TraceKeys) ⇒ [Ct] → TcPluginM ([Rule], [Ct]) 87 | findSpecifiedRules = loop where 88 | loop [] = return ([], []) 89 | loop (ct:cts) = do 90 | result ← go ct 91 | case result of 92 | Just rule → do 93 | (rules, cts') ← loop cts 94 | return (rule:rules, cts') 95 | Nothing → do 96 | (rules, cts') ← loop cts 97 | return (rules, ct:cts') 98 | 99 | go ct@(ctPred → t@(UseTy [RuleUsageTy kind, RuleSpecTy [RuleNameTy [StrLitTy md, StrLitTy nm], RuleArgListTy args]])) = runMaybeT do 100 | name ← MaybeT (lookupName md nm) 101 | Rule { ruleArgs = _, ..} ← MaybeT (makeRule Nothing (ctLoc ct) kind name) 102 | let ruleArgs = mkTvSubstPrs (zip ruleVars args) 103 | checkArgs emptyTCvSubst t ruleVars args 104 | return Rule {..} 105 | go (ctPred → t@(UseTy _)) = do 106 | addErrorMessage (hang "Malformatted Use constraint:" 4 (ppr t)) 107 | return Nothing 108 | go _ = return Nothing 109 | 110 | findDefaultRules ∷ (Definitions, Messages, TraceKeys) ⇒ TcPluginM [Rule] 111 | findDefaultRules = do 112 | topEnv ← getTopEnv 113 | annEnv ← tcPluginIO (prepareAnnotations topEnv Nothing) 114 | let getAnns ∷ Name → [RuleUsage] 115 | getAnns = nub . sort . findAnns deserializeWithData annEnv . NamedTarget 116 | 117 | (gblEnv, _) ← getEnvs 118 | 119 | let elts = filter (not . 
isLocalGRE) (concat (occEnvElts (tcg_rdr_env gblEnv))) 120 | named = map (\elt → (elt, grePrintableName elt)) elts 121 | annotd = concatMap (\(elt, name) → (,,) elt name <$> getAnns name) named 122 | trace @"Annotations" (ppr annotd) 123 | 124 | locd ← mapM (\(elt, name, kind) → (,,,) elt name kind <$> getCtLocM (OccurrenceOf name) Nothing) annotd 125 | mapMaybeM (\(elt, name, kind, loc) → makeRule (Just elt) loc kind name) locd 126 | 127 | makeRule ∷ (Definitions, Messages, TraceKeys) ⇒ Maybe GlobalRdrElt → CtLoc → RuleUsage → Name → TcPluginM (Maybe Rule) 128 | makeRule ruleElt ruleLoc kind name = runMaybeT do 129 | ruleDef ← MaybeT (lookupVar name) 130 | MaybeT . setSrcSpan (toSrcSpan (ctLocSpan ruleLoc)) . runMaybeT $ do 131 | (ruleVars, ruleCts, goalTy) ← parseRuleType name (varType ruleDef) 132 | let ruleArgs = emptyTCvSubst 133 | ruleGoal ← MaybeT (parseRuleGoal name kind goalTy) 134 | return Rule {..} 135 | 136 | parseRuleType ∷ (Definitions, Messages) ⇒ Name → Type → MaybeT TcPluginM ([Var], [Type], Type) 137 | parseRuleType _ (tcSplitForAllTys → (vars, tcSplitPhiTy → (cts, DictTy [t]))) = return (vars, cts, t) 138 | parseRuleType name t = do 139 | lift . 
addErrorMessage $ 140 | hang "Malformatted rule type:" 4 (ppr name <+> "∷" <+> ppr t) 141 | empty 142 | 143 | parseRuleGoal ∷ (Definitions, Messages) ⇒ Name → RuleUsage → Type → TcPluginM (Maybe RuleGoal) 144 | parseRuleGoal _ Intro ct = return (Just (IntroGoal ct)) 145 | parseRuleGoal _ Deriv ct = return (Just (DerivGoal ct)) 146 | parseRuleGoal _ Simpl (EqTy [_, lhs, rhs]) = return (Just (SimplGoal lhs rhs)) 147 | parseRuleGoal _ Simpl (HEqTy [_, _, lhs, rhs]) = return (Just (SimplGoal lhs rhs)) 148 | parseRuleGoal name kind goal = do 149 | addErrorMessage (hang ("Rule" <+> quotes (ppr name) <+> "has invalid" <+> ppr kind <+> "goal:") 4 (ppr goal)) 150 | return Nothing 151 | 152 | checkArgs ∷ (Equalities, Messages) ⇒ TCvSubst → Type → [Var] → [Type] → MaybeT TcPluginM () 153 | checkArgs _ _ _ [] = return () 154 | checkArgs _ t [] _ = do 155 | lift . addErrorMessage $ 156 | hang "Too many arguments for Use constraint:" 4 (ppr t) 157 | empty 158 | checkArgs σ t (var:vars) (arg:args) = 159 | if equivClass (typeKind arg) `contains` substTyAddInScope σ (varType var) 160 | then checkArgs (extendTvSubstAndInScope σ var arg) t vars args 161 | else do 162 | lift .
addErrorMessage $ 163 | hang "Malformatted Use constraint:" 4 (ppr t) $$ 164 | hang "Expected an argument of type" 4 (ppr (varType var)) $$ 165 | hang "but" 4 (ppr arg) $$ 166 | hang "has type" 4 (ppr (typeKind arg)) 167 | empty 168 | 169 | lookupVar ∷ (Messages, TraceKeys) ⇒ Name → TcPluginM (Maybe Var) 170 | lookupVar name = do 171 | local ← unsafeTcPluginTcM (tcLookupLcl_maybe name) 172 | case local of 173 | Just ATcId {..} → return (Just tct_id) 174 | _ → do 175 | global ← unsafeTcPluginTcM (tcLookupImported_maybe name) 176 | case global of 177 | Succeeded (AnId x) → return (Just x) 178 | _ → do 179 | trace @"Lookup" ("Could not lookup variable" <+> ppr name) 180 | return Nothing 181 | 182 | lookupName ∷ Messages ⇒ FastString → FastString → TcPluginM (Maybe Name) 183 | lookupName mdName occ = do 184 | result ← findImportedModule (mkModuleNameFS mdName) Nothing 185 | case result of 186 | Found _ md → Just <$> lookupOrig md (mkVarOccFS occ) 187 | FoundMultiple _ → do 188 | addErrorMessage ("Found multiple modules named" <+> quotes (ppr mdName)) 189 | return Nothing 190 | _ → do 191 | addErrorMessage ("Could not find a module named" <+> quotes (ppr mdName)) 192 | return Nothing 193 | -------------------------------------------------------------------------------- /src/Data/Constraint/Rule/Plugin/Runtime.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ConstraintKinds #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE KindSignatures #-} 4 | {-# LANGUAGE RankNTypes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | {-# LANGUAGE UnicodeSyntax #-} 7 | 8 | module Data.Constraint.Rule.Plugin.Runtime (Cached, cached, unsafeOpenDict) where 9 | 10 | import Data.Constraint (Dict, withDict) 11 | import GHC.TypeNats (Nat) 12 | import Unsafe.Coerce (unsafeCoerce) 13 | 14 | class Cached (s ∷ Nat) 15 | 16 | cached ∷ a 17 | cached = error "cached: impossible" 18 | 19 | {-# INLINE unsafeOpenDict #-} 20 | unsafeOpenDict ∷ ∀a b. 
Dict a → b 21 | unsafeOpenDict d = withDict d f 22 | where Magic f = unsafeCoerce (id ∷ b → b) ∷ Magic a b 23 | 24 | newtype Magic a b = Magic (a ⇒ b) 25 | -------------------------------------------------------------------------------- /src/Data/Constraint/Rule/Symbol.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE AllowAmbiguousTypes #-} 2 | {-# LANGUAGE DataKinds #-} 3 | {-# LANGUAGE GADTs #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | {-# LANGUAGE TypeApplications #-} 6 | {-# LANGUAGE TypeOperators #-} 7 | {-# LANGUAGE UnicodeSyntax #-} 8 | 9 | module Data.Constraint.Rule.Symbol where 10 | 11 | import Data.Constraint (Dict (..), HasDict (..)) 12 | import Data.Constraint.Nat (Min) 13 | import Data.Constraint.Rule (RuleUsage (..)) 14 | import Data.Constraint.Symbol (Drop, Length, Take, type (++)) 15 | import qualified Data.Constraint.Symbol as Symbol 16 | import GHC.TypeLits (KnownNat, KnownSymbol, type (+), 17 | type (<=)) 18 | 19 | {-# ANN appendSymbol Intro #-} 20 | appendSymbol ∷ ∀m n. (KnownSymbol m, KnownSymbol n) ⇒ Dict (KnownSymbol (m ++ n)) 21 | appendSymbol = evidence (Symbol.appendSymbol @m @n) 22 | 23 | {-# ANN takeSymbol Intro #-} 24 | takeSymbol ∷ ∀n a. (KnownNat n, KnownSymbol a) ⇒ Dict (KnownSymbol (Take n a)) 25 | takeSymbol = evidence (Symbol.takeSymbol @n @a) 26 | 27 | {-# ANN dropSymbol Intro #-} 28 | dropSymbol ∷ ∀n a. (KnownNat n, KnownSymbol a) ⇒ Dict (KnownSymbol (Drop n a)) 29 | dropSymbol = evidence (Symbol.dropSymbol @n @a) 30 | 31 | {-# ANN takeAppendDrop Simpl #-} 32 | takeAppendDrop ∷ ∀n a. Dict ((Take n a ++ Drop n a) ~ a) 33 | takeAppendDrop = evidence (Symbol.takeAppendDrop @n @a) 34 | 35 | {-# ANN lengthSymbol Intro #-} 36 | lengthSymbol ∷ ∀a. KnownSymbol a ⇒ Dict (KnownNat (Length a)) 37 | lengthSymbol = evidence (Symbol.lengthSymbol @a) 38 | 39 | {-# ANN takeLength Simpl #-} 40 | takeLength ∷ ∀n a. 
(Length a <= n) ⇒ Dict (Take n a ~ a) 41 | takeLength = evidence (Symbol.takeLength @n @a) 42 | 43 | {-# ANN take0 Simpl #-} 44 | take0 ∷ ∀a. Dict (Take 0 a ~ "") 45 | take0 = evidence (Symbol.take0 @a) 46 | 47 | {-# ANN takeEmpty Simpl #-} 48 | takeEmpty ∷ ∀n. Dict (Take n "" ~ "") 49 | takeEmpty = evidence (Symbol.takeEmpty @n) 50 | 51 | {-# ANN dropLength Simpl #-} 52 | dropLength ∷ ∀n a. (Length a <= n) ⇒ Dict (Drop n a ~ "") 53 | dropLength = evidence (Symbol.dropLength @n @a) 54 | 55 | {-# ANN drop0 Simpl #-} 56 | drop0 ∷ ∀a. Dict (Drop 0 a ~ a) 57 | drop0 = evidence (Symbol.drop0 @a) 58 | 59 | {-# ANN dropEmpty Simpl #-} 60 | dropEmpty ∷ ∀n. Dict (Drop n "" ~ "") 61 | dropEmpty = evidence (Symbol.dropEmpty @n) 62 | 63 | {-# ANN lengthTake Simpl #-} 64 | lengthTake ∷ ∀n a. Dict (Length (Take n a) <= n) 65 | lengthTake = evidence (Symbol.lengthTake @n @a) 66 | 67 | {-# ANN lengthDrop Simpl #-} 68 | lengthDrop ∷ ∀n a. Dict (Length a <= (Length (Drop n a) + n)) 69 | lengthDrop = evidence (Symbol.lengthDrop @n @a) 70 | 71 | {-# ANN dropDrop Simpl #-} 72 | dropDrop ∷ ∀n m a. Dict (Drop n (Drop m a) ~ Drop (n + m) a) 73 | dropDrop = evidence (Symbol.dropDrop @n @m @a) 74 | 75 | {-# ANN takeTake Simpl #-} 76 | takeTake ∷ ∀n m a. 
Dict (Take n (Take m a) ~ Take (Min n m) a)
77 | takeTake = evidence (Symbol.takeTake @n @m @a)
78 |
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule/TH.hs:
--------------------------------------------------------------------------------
 1 | {-# LANGUAGE DataKinds #-}
 2 | {-# LANGUAGE DefaultSignatures #-}
 3 | {-# LANGUAGE FlexibleContexts #-}
 4 | {-# LANGUAGE FlexibleInstances #-}
 5 | {-# LANGUAGE MonoLocalBinds #-}
 6 | {-# LANGUAGE TemplateHaskell #-}
 7 | {-# LANGUAGE TypeApplications #-}
 8 | {-# LANGUAGE TypeOperators #-}
 9 | {-# LANGUAGE UnicodeSyntax #-}
10 |
11 | module Data.Constraint.Rule.TH (Rule (..), Spec (spec)) where
12 |
13 | import Data.Class.Closed.TH (close)
14 | import Data.Constraint.Rule (RuleArg (..), RuleName (..), RuleSpec (..))
15 | import Data.Proxy (Proxy (..))
16 | import Language.Haskell.TH.Lib (litT, strTyLit)
17 | import Language.Haskell.TH.Syntax (Exp, ModName (..), Name (..),
18 |                                    NameFlavour (..), NameSpace (..),
19 |                                    OccName (..), Q, Type)
20 |
21 | close [d|
22 |   class Rule a where
23 |     rule ∷ Name → Q a
24 |
25 |   instance Rule Type where
26 |     rule = ruleType
27 |
28 |   instance Rule Exp where
29 |     rule = ruleExp
30 |
31 |   class Spec a where
32 |     spec ∷ Name → a
33 |     spec name = specHelper name []
34 |
35 |     specHelper ∷ Name → [Q Type] → a
36 |
37 |   instance Spec (Q Type) where
38 |     specHelper = specType
39 |
40 |   instance Spec (Q Exp) where
41 |     specHelper = specExp
42 |
43 |   instance Spec a ⇒ Spec (Q Type → a) where
44 |     specHelper name args arg = specHelper name (arg:args)
45 | |]
46 |
47 | ruleType ∷ Name → Q Type
48 | ruleType (Name (OccName name) (NameG VarName _ (ModName md))) = [t| 'RuleName $(litT (strTyLit md)) $(litT (strTyLit name)) |]
49 | ruleType _ = error "rule: Expected top-level function name"
50 |
51 | ruleExp ∷ Name → Q Exp
52 | ruleExp name = [e| Proxy @($(rule name)) |]
53 |
54 | specType ∷ Name → [Q Type] → Q Type
55 | specType name args = [t| 'RuleSpec $(rule name) $(list (reverse args)) |]
56 |
57 | specExp ∷ Name → [Q Type] → Q Exp
58 | specExp name args = [e| Proxy @($(specHelper name args)) |]
59 |
60 | list ∷ [Q Type] → Q Type
61 | list [] = [t| '[] |]
62 | list (x:xs) = [t| 'RuleArg $x ': $(list xs) |]
63 |
--------------------------------------------------------------------------------
/src/Data/Constraint/Rule/Trace.hs:
--------------------------------------------------------------------------------
 1 | {-# LANGUAGE AllowAmbiguousTypes #-}
 2 | {-# LANGUAGE DataKinds #-}
 3 | {-# LANGUAGE DefaultSignatures #-}
 4 | {-# LANGUAGE FlexibleContexts #-}
 5 | {-# LANGUAGE FlexibleInstances #-}
 6 | {-# LANGUAGE RankNTypes #-}
 7 | {-# LANGUAGE ScopedTypeVariables #-}
 8 | {-# LANGUAGE TemplateHaskell #-}
 9 | {-# LANGUAGE TypeFamilies #-}
10 | {-# LANGUAGE TypeOperators #-}
11 | {-# LANGUAGE UndecidableInstances #-}
12 | {-# LANGUAGE UndecidableSuperClasses #-}
13 | {-# LANGUAGE UnicodeSyntax #-}
14 |
15 | module Data.Constraint.Rule.Trace (TraceKey, Trace, withTrace) where
16 |
17 | import Data.Class.Closed.TH (close)
18 | import Data.Kind (Constraint)
19 | import GHC.TypeLits (ErrorMessage (..), Symbol, TypeError)
20 |
21 | type family TraceKey (key ∷ Symbol) ∷ Constraint where
22 |   TraceKey "Constraints" = ()
23 |   TraceKey "Equalities" = ()
24 |   TraceKey "Cached" = ()
25 |   TraceKey "Rules" = ()
26 |   TraceKey "Intro" = ()
27 |   TraceKey "Deriv" = ()
28 |   TraceKey "Simpl" = ()
29 |   TraceKey "EmitGivens" = ()
30 |   TraceKey "Annotations" = ()
31 |   TraceKey "Lookup" = ()
32 |   TraceKey key = TypeError ('Text "Invalid trace key " ':<>: 'ShowType key)
33 |
34 | close [d|
35 |   class TraceKey key ⇒ Trace key
36 |   instance TraceKey key ⇒ Trace key
37 | |]
38 |
39 | withTrace ∷ ∀key a. TraceKey key ⇒ (Trace key ⇒ a) → a
40 | withTrace x = x
41 |
--------------------------------------------------------------------------------
/test/Main.hs:
--------------------------------------------------------------------------------
1 | {-# OPTIONS_GHC -F -pgmF hspec-discover #-}
2 |
--------------------------------------------------------------------------------
/test/Test/IntroDefs.hs:
--------------------------------------------------------------------------------
 1 | {-# OPTIONS_GHC -fdefer-type-errors #-}
 2 | {-# OPTIONS_GHC -Wno-deferred-type-errors #-}
 3 | {-# OPTIONS_GHC -fplugin=Data.Constraint.Rule.Plugin #-}
 4 | {-# OPTIONS_GHC -dcore-lint #-}
 5 | -- {-# OPTIONS_GHC -ddump-tc-trace -ddump-to-file #-}
 6 | {-# LANGUAGE AllowAmbiguousTypes #-}
 7 | {-# LANGUAGE DataKinds #-}
 8 | {-# LANGUAGE FlexibleContexts #-}
 9 | {-# LANGUAGE NoStarIsType #-}
10 | {-# LANGUAGE ScopedTypeVariables #-}
11 | {-# LANGUAGE TemplateHaskell #-}
12 | {-# LANGUAGE TypeApplications #-}
13 | {-# LANGUAGE TypeOperators #-}
14 | {-# LANGUAGE UnicodeSyntax #-}
15 |
16 | module Test.IntroDefs where
17 |
18 | import Data.Constraint (Dict (..), HasDict (..))
19 | import Data.Constraint.Nat (plusNat, timesNat)
20 | import Data.Constraint.Rule (withIntro)
21 | import Data.Constraint.Rule.TH (spec)
22 | import Data.Proxy (Proxy (Proxy))
23 | import GHC.TypeNats (KnownNat, natVal, type (*), type (+))
24 | import Numeric.Natural (Natural)
25 | import Test.Util (badProof)
26 |
27 | proof₁ ∷ (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (m + n))
28 | proof₁ = badProof
29 |
30 | proof₂ ∷ ∀m n. (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (m + n))
31 | proof₂ = evidence (plusNat @m @n)
32 |
33 | proof₃ ∷ ∀m n. (KnownNat m, KnownNat n) ⇒ Dict (KnownNat (m * n))
34 | proof₃ = evidence (timesNat @m @n)
35 |
36 | test₁ ∷ ∀x. KnownNat x ⇒ Natural
37 | test₁ = natVal (Proxy @(x + 5))
38 |
39 | test₂ ∷ ∀x. KnownNat x ⇒ Natural
40 | test₂ = withIntro $(spec 'proof₁) (natVal (Proxy @(x + 5)))
41 |
42 | test₃ ∷ ∀x. KnownNat x ⇒ Natural
43 | test₃ = withIntro $(spec 'proof₂) (natVal (Proxy @(x + 5)))
44 |
45 | nested₁ ∷ ∀x y. (KnownNat x, KnownNat y) ⇒ Natural
46 | nested₁ = natVal (Proxy @(x + y + 5))
47 |
48 | nested₂ ∷ ∀x y. (KnownNat x, KnownNat y) ⇒ Natural
49 | nested₂ = withIntro $(spec 'proof₂) $
50 |   natVal (Proxy @(x + y + 5))
51 |
52 | multiple₁ ∷ ∀x y. (KnownNat x, KnownNat y) ⇒ Natural
53 | multiple₁ = withIntro $(spec 'proof₂) $
54 |   natVal (Proxy @(x * y + 5))
55 |
56 | multiple₂ ∷ ∀x y. (KnownNat x, KnownNat y) ⇒ Natural
57 | multiple₂ =
58 |   withIntro $(spec 'proof₂) $
59 |   withIntro $(spec 'proof₃) $
60 |   natVal (Proxy @(x * y + 5))
61 |
--------------------------------------------------------------------------------
/test/Test/IntroSpec.hs:
--------------------------------------------------------------------------------
 1 | {-# LANGUAGE BlockArguments #-}
 2 | {-# LANGUAGE DataKinds #-}
 3 | {-# LANGUAGE TemplateHaskell #-}
 4 | {-# LANGUAGE TypeApplications #-}
 5 | {-# LANGUAGE TypeFamilies #-}
 6 | {-# LANGUAGE UnicodeSyntax #-}
 7 |
 8 | module Test.IntroSpec where
 9 |
10 | import Control.Exception (evaluate)
11 | import Control.Exception.Base (TypeError)
12 | import Test.Hspec (Spec, describe, it, shouldBe, shouldThrow)
13 | import Test.Util (BadProof, exc)
14 |
15 | import Test.IntroDefs
16 |
17 | spec ∷ Spec
18 | spec = do
19 |   describe "test₁" do
20 |     it "should not type check" do
21 |       evaluate (test₁ @10) `shouldThrow` exc @TypeError
22 |
23 |   describe "test₂" do
24 |     it "solves KnownNat (10 + 5) with bad proof" do
25 |       evaluate (test₂ @10) `shouldThrow` exc @BadProof
26 |
27 |   describe "test₃" do
28 |     it "solves KnownNat (10 + 5) with plusNat" do
29 |       test₃ @10 `shouldBe` 15
30 |
31 |   describe "nested₁" do
32 |     it "should not type check" do
33 |       evaluate (nested₁ @10 @7) `shouldThrow` exc @TypeError
34 |
35 |   describe "nested₂" do
36 |     it "solves KnownNat (10 + 7 + 5) with plusNat" do
37 |       nested₂ @10 @7 `shouldBe` 22
38 |
39 |   describe "multiple₁" do
40 |     it "should not type check" do
41 |       evaluate (multiple₁ @10 @7) `shouldThrow` exc @TypeError
42 |
43 |   describe "multiple₂" do
44 |     it "solves KnownNat (10 * 7 + 5) with plusNat and timesNat" do
45 |       multiple₂ @10 @7 `shouldBe` 75
46 |
--------------------------------------------------------------------------------
/test/Test/SimplDefs.hs:
--------------------------------------------------------------------------------
 1 | {-# OPTIONS_GHC -fdefer-type-errors #-}
 2 | {-# OPTIONS_GHC -Wno-deferred-type-errors #-}
 3 | {-# OPTIONS_GHC -fplugin=Data.Constraint.Rule.Plugin #-}
 4 | {-# OPTIONS_GHC -dcore-lint #-}
 5 | -- {-# OPTIONS_GHC -ddump-tc-trace -ddump-to-file #-}
 6 | {-# LANGUAGE AllowAmbiguousTypes #-}
 7 | {-# LANGUAGE BlockArguments #-}
 8 | {-# LANGUAGE DataKinds #-}
 9 | {-# LANGUAGE FlexibleContexts #-}
10 | {-# LANGUAGE PolyKinds #-}
11 | {-# LANGUAGE RankNTypes #-}
12 | {-# LANGUAGE ScopedTypeVariables #-}
13 | {-# LANGUAGE TemplateHaskell #-}
14 | {-# LANGUAGE TypeApplications #-}
15 | {-# LANGUAGE TypeFamilies #-}
16 | {-# LANGUAGE UnicodeSyntax #-}
17 |
18 | module Test.SimplDefs where
19 |
20 | import Data.Constraint (Dict (..), withDict)
21 | import Data.Constraint.Rule (withSimpl)
22 | import qualified Data.Constraint.Rule.TH as TH
23 | import Data.Type.Bool (If)
24 | import Test.Util (badProof, testEq, trustMe)
25 |
26 | type family Foo a
27 | type family Bar a
28 |
29 | withEq ∷ ∀a b. Eq a ⇒ (Eq (Foo a) ⇒ b) → b
30 | withEq = withDict (proof₂ @a)
31 |
32 | proof₁ ∷ Dict (Foo a ~ a)
33 | proof₁ = badProof
34 |
35 | proof₂ ∷ Dict (Foo a ~ a)
36 | proof₂ = trustMe
37 |
38 | proof₃ ∷ Dict (Bar a ~ a)
39 | proof₃ = trustMe
40 |
41 | proof₄ ∷ Dict (If a (f b) (f c) ~ f (If a b c))
42 | proof₄ = trustMe
43 |
44 | wanted₁ ∷ Maybe Int
45 | wanted₁ = Nothing @(Foo Int)
46 |
47 | wanted₂ ∷ Maybe Int
48 | wanted₂ = withSimpl $(TH.spec 'proof₁) (Nothing @(Foo Int))
49 |
50 | wanted₃ ∷ Maybe Int
51 | wanted₃ = withSimpl $(TH.spec 'proof₂) (Nothing @(Foo Int))
52 |
53 | wanted₄ ∷ ()
54 | wanted₄ = testEq @(Foo Int) Dict 0
55 |
56 | wanted₅ ∷ ()
57 | wanted₅ = withSimpl $(TH.spec 'proof₁) (testEq @(Foo Int) Dict 0)
58 |
59 | wanted₆ ∷ ()
60 | wanted₆ = withSimpl $(TH.spec 'proof₂) (testEq @(Foo Int) Dict 0)
61 |
62 | given₁ ∷ ()
63 | given₁ = withEq @Int (testEq @Int go 0) where
64 |   go ∷ Eq (Foo a) ⇒ Dict (Eq a)
65 |   go = Dict
66 |
67 | given₂ ∷ ()
68 | given₂ = withEq @Int (testEq @Int go 0) where
69 |   go ∷ Eq (Foo a) ⇒ Dict (Eq a)
70 |   go = withSimpl $(TH.spec 'proof₁) Dict
71 |
72 | given₃ ∷ ()
73 | given₃ = withEq @Int (testEq @Int go 0) where
74 |   go ∷ Eq (Foo a) ⇒ Dict (Eq a)
75 |   go = withSimpl $(TH.spec 'proof₂) Dict
76 |
77 | nested₁ ∷ ()
78 | nested₁ = testEq @Int (Dict @(Eq (Foo (Foo (Foo Int))))) 0
79 |
80 | nested₂ ∷ ()
81 | nested₂ = withSimpl $(TH.spec 'proof₂) $
82 |   testEq @Int (Dict @(Eq (Foo (Foo (Foo Int))))) 0
83 |
84 | multiple₁ ∷ ()
85 | multiple₁ = testEq @(Bar Int) (Dict @(Eq (Foo Int))) 0
86 |
87 | multiple₂ ∷ ()
88 | multiple₂ =
89 |   withSimpl $(TH.spec 'proof₂) $
90 |   withSimpl $(TH.spec 'proof₃) $
91 |   testEq @(Bar Int) (Dict @(Eq (Foo Int))) 0
92 |
93 | kinds ∷ Dict (Eq (If a b c) ~ If a (Eq b) (Eq c))
94 | kinds = withSimpl $(TH.spec 'proof₄) Dict
95 |
--------------------------------------------------------------------------------
/test/Test/SimplSpec.hs:
--------------------------------------------------------------------------------
 1 | {-# LANGUAGE BlockArguments #-}
 2 | {-# LANGUAGE DataKinds #-}
 3 | {-# LANGUAGE TemplateHaskell #-}
 4 | {-# LANGUAGE TypeApplications #-}
 5 | {-# LANGUAGE TypeFamilies #-}
 6 | {-# LANGUAGE UnicodeSyntax #-}
 7 |
 8 | module Test.SimplSpec where
 9 |
10 | import Control.Exception (evaluate)
11 | import Control.Exception.Base (TypeError)
12 | import Test.Hspec (Spec, describe, it, shouldBe, shouldThrow)
13 | import Test.Util (BadProof, exc)
14 |
15 | import Test.SimplDefs
16 |
17 | spec ∷ Spec
18 | spec = do
19 |   describe "wanted₁" do
20 |     it "should not type check" do
21 |       evaluate wanted₁ `shouldThrow` exc @TypeError
22 |
23 |   describe "wanted₂" do
24 |     it "rewrites wanted equality with bad proof" do
25 |       evaluate wanted₂ `shouldThrow` exc @BadProof
26 |
27 |   describe "wanted₃" do
28 |     it "rewrites wanted equality with trustMe" do
29 |       wanted₃ `shouldBe` Nothing
30 |
31 |   describe "wanted₄" do
32 |     it "should not type check" do
33 |       evaluate wanted₄ `shouldThrow` exc @TypeError
34 |
35 |   describe "wanted₅" do
36 |     it "rewrites wanted class with bad proof" do
37 |       evaluate wanted₅ `shouldThrow` exc @BadProof
38 |
39 |   describe "wanted₆" do
40 |     it "rewrites wanted class with trustMe" do
41 |       wanted₆ `shouldBe` ()
42 |
43 |   describe "given₁" do
44 |     it "should not type check" do
45 |       evaluate given₁ `shouldThrow` exc @TypeError
46 |
47 |   describe "given₂" do
48 |     it "rewrites given class with bad proof" do
49 |       evaluate given₂ `shouldThrow` exc @BadProof
50 |
51 |   describe "given₃" do
52 |     it "rewrites given class with trustMe" do
53 |       given₃ `shouldBe` ()
54 |
55 |   describe "nested₁" do
56 |     it "should not type check" do
57 |       evaluate nested₁ `shouldThrow` exc @TypeError
58 |
59 |   describe "nested₂" do
60 |     it "performs nested rewrites" do
61 |       nested₂ `shouldBe` ()
62 |
63 |   describe "multiple₁" do
64 |     it "should not type check" do
65 |       evaluate multiple₁ `shouldThrow` exc @TypeError
66 |
67 |   describe "multiple₂" do
68 |     it "rewrites multiple constraints" do
69 |       multiple₂ `shouldBe` ()
70 |
71 |   describe "kinds" do
72 |     it "fails unless kinds are matched properly" do
73 |       kinds `shouldBe` ()
74 |
--------------------------------------------------------------------------------
/test/Test/Util.hs:
--------------------------------------------------------------------------------
 1 | {-# LANGUAGE DeriveAnyClass #-}
 2 | {-# LANGUAGE ScopedTypeVariables #-}
 3 | {-# LANGUAGE UnicodeSyntax #-}
 4 |
 5 | module Test.Util where
 6 |
 7 | import Control.Exception (Exception, throw)
 8 | import Data.Constraint (Dict (..))
 9 | import Unsafe.Coerce (unsafeCoerce)
10 |
11 | data BadProof = BadProof
12 |   deriving (Exception, Show)
13 |
14 | badProof ∷ a
15 | badProof = throw BadProof
16 |
17 | exc ∷ ∀a. Exception a ⇒ a → Bool
18 | exc _ = True
19 |
20 | trustMe ∷ Dict (a ~ b)
21 | trustMe = unsafeCoerce (Dict ∷ Dict (() ~ ()))
22 |
23 | testEq ∷ Dict (Eq a) → a → ()
24 | testEq Dict x = x == x `seq` ()
25 |
--------------------------------------------------------------------------------