├── NOTICE ├── .gitignore ├── test ├── Main.hs ├── TestUtil.hs └── Core │ └── LinearizationTest.hs ├── src └── TensorRight │ └── Internal │ ├── DSL │ ├── Syntax.hs │ ├── RelabelMap.hs │ ├── Parameters.hs │ ├── Identifier.hs │ ├── Condition.hs │ ├── BoundInference.hs │ └── Shape.hs │ ├── Util │ ├── Pretty.hs │ └── Error.hs │ └── Core │ ├── Linearization.hs │ ├── Axis.hs │ └── Tensor │ └── TensorInt.hs ├── docs └── pull_request_template.md ├── Makefile ├── rules ├── xla │ ├── iota │ │ └── Main.hs │ ├── not │ │ └── Main.hs │ ├── max │ │ └── Main.hs │ ├── relabel │ │ └── Main.hs │ ├── sub │ │ └── Main.hs │ ├── compare │ │ └── Main.hs │ ├── add │ │ └── Main.hs │ ├── select │ │ └── Main.hs │ ├── clamp │ │ └── Main.hs │ ├── reverse │ │ └── Main.hs │ ├── dyupslice │ │ └── Main.hs │ ├── broadcast │ │ └── Main.hs │ ├── divmod │ │ └── Main.hs │ ├── logical │ │ └── Main.hs │ ├── generalize │ │ └── Main.hs │ ├── mul │ │ └── Main.hs │ ├── dyslice │ │ └── Main.hs │ ├── concat │ │ └── Main.hs │ ├── pad │ │ └── Main.hs │ ├── dot │ │ └── Main.hs │ └── reduce │ │ └── Main.hs └── debug │ └── Main.hs ├── stack.yaml.lock ├── flake.lock ├── runall.sh ├── Dockerfile ├── flake.nix ├── stack.yaml ├── hie.yaml ├── package.yaml ├── plot └── timing_plot.py ├── LICENSE └── README.md /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright University of Illinois Board of Trustees 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | result.txt 2 | *.pdf 3 | .stack-work/ 4 | .envrc 5 | .direnv 6 | -------------------------------------------------------------------------------- /test/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Core.LinearizationTest (linearizationTest) 4 | import Core.TensorTest (tensorTest) 5 | import Test.Framework 
(defaultMain, testGroup) 6 | 7 | main :: IO () 8 | main = 9 | defaultMain 10 | [ testGroup "Core" [linearizationTest, tensorTest] 11 | ] 12 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/DSL/Syntax.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FunctionalDependencies #-} 2 | 3 | module TensorRight.Internal.DSL.Syntax (ArrowSyntax (..), AtSyntax (..)) where 4 | 5 | -- | Syntax for defining a mapping. 6 | class ArrowSyntax a b c where 7 | (-->) :: a -> b -> c 8 | infix 8 --> 9 | 10 | -- | Syntax for labelling an rclass. 11 | class AtSyntax a b c | c -> a b where 12 | (@@) :: a -> b -> c 13 | infix 7 @@ 14 | -------------------------------------------------------------------------------- /docs/pull_request_template.md: -------------------------------------------------------------------------------- 1 | This PR adds the following changes/fixes/enhancements: 2 | 3 | ... 4 | 5 | # Checks 6 | 7 | - [ ] Did you run `ormolu`, a [Haskell Formatter](https://hackage.haskell.org/package/ormolu)? Run `ormolu --mode inplace $(git ls-files '*.hs')` to format all files in the project. 8 | - [ ] Did you update `hie.yaml` using the command `gen-hie > hie.yaml`? 9 | - [ ] Do all tests pass using `stack test`? 10 | 11 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL=/usr/bin/env bash 2 | 3 | .PHONY: build verify maxMinToClampBefore maxMinToClampAfter generalize plot clean 4 | 5 | build: 6 | stack build 7 | 8 | verify: build 9 | ./runall.sh 2> >(tee ./plot/result.txt); 10 | 11 | generalize: build 12 | stack exec rules-generalize 13 | 14 | plot: build 15 | if [ ! 
-f ./plot/result.txt ]; then ./runall.sh 2> >(tee ./plot/result.txt); fi 16 | cd plot && python3 timing_plot.py 17 | 18 | clean: 19 | stack clean 20 | rm -rf .stack-work 21 | rm -f ./plot/result.txt ./plot/timing_plot.pdf 22 | -------------------------------------------------------------------------------- /rules/xla/iota/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: DSLContext Rewrite 7 | rule01 = do 8 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 9 | rc0Size <- newMap "rc0Size" rclass0 10 | rc1Size <- newMap "rc1Size" rclass1 11 | 12 | lhs <- iota [rclass0 --> rc0Size, rclass1 --> rc1Size] (ByRClass rclass0) 13 | precondition [rc0Size] $ \[s] -> s .== 1 14 | 15 | rhs <- constant @TensorInt 0 [rclass0 --> rc0Size, rclass1 --> rc1Size] 16 | rewrite "Iota ⇒ Zero" lhs rhs 17 | 18 | main :: IO () 19 | main = do 20 | print "############################## rule01 ##############################" 21 | verifyDSL rule01 22 | -------------------------------------------------------------------------------- /rules/xla/not/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: DSLContext Rewrite 7 | rule01 = do 8 | rclass <- newRClass "rclass" 9 | map <- newMap "map" rclass 10 | tA <- newTensor @SymBool "A" [rclass --> map] 11 | lhs <- boolUnaryOp Not (boolUnaryOp Not tA) 12 | let rhs = tA 13 | rewrite "Not(Not(A)) ⇒ A" lhs rhs 14 | 15 | rule02 :: forall a. 
NumRule a 16 | rule02 _ = do 17 | rclass <- newRClass "rclass" 18 | map <- newMap "map" rclass 19 | tA <- newTensor @a "A" [rclass --> map] 20 | lhs <- numUnaryOp Neg (numUnaryOp Neg tA) 21 | let rhs = tA 22 | rewrite "Negate(Negate(A)) ⇒ A" lhs rhs 23 | 24 | main :: IO () 25 | main = do 26 | print "############################## rule01 ##############################" 27 | verifyDSL rule01 28 | print "############################## rule02 ##############################" 29 | verifyNumDSL rule02 30 | -------------------------------------------------------------------------------- /stack.yaml.lock: -------------------------------------------------------------------------------- 1 | # This file was autogenerated by Stack. 2 | # You should not edit this file by hand. 3 | # For more information, please see the documentation at: 4 | # https://docs.haskellstack.org/en/stable/lock_files 5 | 6 | packages: 7 | - completed: 8 | commit: 751e74a5c0049868da75f320c47a1ce435efc9d0 9 | git: https://github.com/lsrcz/grisette.git 10 | name: grisette 11 | pantry-tree: 12 | sha256: f6a26f7ad482f5d927d8f47c4df94d184cca7aef2eb6d1ce809a0d9aa6197e25 13 | size: 31455 14 | version: 0.11.0.0 15 | original: 16 | commit: 751e74a5c0049868da75f320c47a1ce435efc9d0 17 | git: https://github.com/lsrcz/grisette.git 18 | snapshots: 19 | - completed: 20 | sha256: 96d941a6c484efb750ceab66a2dd177caa580391a4e2e4afb14fc4e9f536846f 21 | size: 621372 22 | url: https://raw.githubusercontent.com/commercialhaskell/stackage-snapshots/master/nightly/2025/1/6.yaml 23 | original: nightly-2025-01-06 24 | -------------------------------------------------------------------------------- /rules/xla/max/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import TensorRight 4 | 5 | rule01 :: forall a. 
NumRule a 6 | rule01 _ = do 7 | rclass <- newRClass "rclass" 8 | map <- newMap "map" rclass 9 | tA <- newTensor @a "A" [rclass --> map] 10 | tNInf <- constant @a (negInf :: a) [rclass --> map] 11 | lhs <- numBinOp Max tA tNInf 12 | let rhs = tA 13 | rewrite "Max(A, -inf) ⇒ A" lhs rhs 14 | 15 | rule02 :: forall a. NumRule a 16 | rule02 _ = do 17 | rclass <- newRClass "rclass" 18 | map <- newMap "map" rclass 19 | tA <- newTensor @a "A" [rclass --> map] 20 | tInf <- constant @a (posInf :: a) [rclass --> map] 21 | lhs <- numBinOp Min tA tInf 22 | let rhs = tA 23 | rewrite "Min(A, inf) ⇒ A" lhs rhs 24 | 25 | main :: IO () 26 | main = do 27 | print "############################## rule01 ##############################" 28 | verifyNumDSL rule01 29 | print "############################## rule02 ##############################" 30 | verifyNumDSL rule02 31 | -------------------------------------------------------------------------------- /rules/xla/relabel/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import TensorRight 4 | 5 | rule01 :: forall a. AnyDTypeRule a 6 | rule01 _ = do 7 | rclass <- newRClass "rclass" 8 | map <- newMap "map" rclass 9 | tA <- newTensor @a "A" [rclass --> map] 10 | lhs <- 11 | relabel 12 | (relabel tA [rclass --> ByLabel "label"]) 13 | [ByLabel "label" --> ByLabel "label2"] 14 | rhs <- relabel tA [rclass --> ByLabel "label2"] 15 | rewrite "Transpose(Transpose(A)) ⇒ Transpose(A)" lhs rhs 16 | 17 | rule02 :: forall a. 
AnyDTypeRule a 18 | rule02 _ = do 19 | rclass <- newRClass "rclass" 20 | map <- newMap "map" rclass 21 | tA <- newTensor @a "A" [rclass --> map @@ "label"] 22 | lhs <- relabel tA [ByLabel "label" --> ByLabel "label"] 23 | let rhs = tA 24 | rewrite "Transpose(A) ⇒ A" lhs rhs 25 | 26 | main :: IO () 27 | main = do 28 | print "############################## rule01 ##############################" 29 | verifyAnyDTypeDSL rule01 30 | print "############################## rule02 ##############################" 31 | verifyAnyDTypeDSL rule02 32 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/Util/Pretty.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | 3 | module TensorRight.Internal.Util.Pretty 4 | ( encloseList, 5 | groupedEnclose, 6 | condEnclose, 7 | prettyWithConstructor, 8 | gprettyParen, 9 | ) 10 | where 11 | 12 | import Prettyprinter (Doc, align, cat, flatAlt, group, nest, vcat, vsep) 13 | 14 | encloseList :: Doc ann -> Doc ann -> Doc ann -> [Doc ann] -> Doc ann 15 | encloseList l r s ds = case ds of 16 | [] -> l <> r 17 | [d] -> cat [nest 2 $ vcat [l, d], r] 18 | _ -> 19 | group $ 20 | vcat [nest 2 $ vcat [l, vcat $ map (<> sep) (init ds), last ds], r] 21 | where 22 | sep = flatAlt s (s <> " ") 23 | 24 | groupedEnclose :: Doc ann -> Doc ann -> Doc ann -> Doc ann 25 | groupedEnclose l r d = group $ align $ vcat [l <> flatAlt " " "" <> d, r] 26 | 27 | condEnclose :: Bool -> Doc ann -> Doc ann -> Doc ann -> Doc ann 28 | condEnclose b = if b then groupedEnclose else const $ const id 29 | 30 | gprettyParen :: Bool -> Doc ann -> Doc ann 31 | gprettyParen b = condEnclose b "(" ")" 32 | 33 | prettyWithConstructor :: Int -> Doc ann -> [Doc ann] -> Doc ann 34 | prettyWithConstructor n c l = 35 | group $ condEnclose (n > 10) "(" ")" $ align $ nest 2 $ vsep (c : l) 36 | 
-------------------------------------------------------------------------------- /rules/xla/sub/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import TensorRight 4 | 5 | rule00 :: forall a. NumRule a 6 | rule00 _ = do 7 | rclass <- newRClass "rclass" 8 | map <- newMap "map" rclass 9 | tA <- newTensor @a "A" [rclass --> map] 10 | constTensor <- constant @a 0 [rclass --> map] 11 | lhs <- numBinOp Sub tA constTensor 12 | let rhs = tA 13 | rewrite "Sub(A, 0) ⇒ A" lhs rhs 14 | 15 | rule01 :: forall a. NumRule a 16 | rule01 _ = do 17 | rclass <- newRClass "rclass" 18 | map <- newMap "map" rclass 19 | tA <- newTensor @a "A" [rclass --> map] 20 | constTensor <- constant @a "x" [rclass --> map] 21 | negConstTensor <- constant @a (-"x") [rclass --> map] 22 | lhs <- numBinOp Sub tA constTensor 23 | rhs <- numBinOp Add tA negConstTensor 24 | rewrite "Sub(A, Const) ⇒ Add(A, Neg(Const))" lhs rhs 25 | 26 | rule02 :: forall a. 
NumRule a 27 | rule02 _ = do 28 | rclass <- newRClass "rclass" 29 | map <- newMap "map" rclass 30 | tA <- newTensor @a "A" [rclass --> map] 31 | constTensor <- constant @a 0 [rclass --> map] 32 | lhs <- numBinOp Sub tA tA 33 | let rhs = constTensor 34 | rewrite "Sub(A, A) ⇒ 0" lhs rhs 35 | 36 | main :: IO () 37 | main = do 38 | print "############################## rule00 ##############################" 39 | verifyNumDSL rule00 40 | print "############################## rule01 ##############################" 41 | verifyNumDSL rule01 42 | print "############################## rule02 ##############################" 43 | verifyNumDSL rule02 44 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/Util/Error.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | {-# LANGUAGE MonoLocalBinds #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | {-# LANGUAGE ScopedTypeVariables #-} 5 | 6 | module TensorRight.Internal.Util.Error 7 | ( Error, 8 | ErrorEnv, 9 | assert, 10 | splitWithError, 11 | ) 12 | where 13 | 14 | import Control.Monad.Error.Class (MonadError (throwError)) 15 | import Control.Monad.Except (ExceptT, runExceptT) 16 | import Data.Either (isRight) 17 | import qualified Data.Text as T 18 | import Grisette 19 | ( Mergeable, 20 | PlainUnion (toGuardedList), 21 | SymBool, 22 | Union, 23 | mrgReturn, 24 | ) 25 | import Grisette.Unified (GetBool, UnifiedBranching, mrgIf) 26 | 27 | type Error = T.Text 28 | 29 | type ErrorEnv = ExceptT Error Union 30 | 31 | assert :: 32 | (UnifiedBranching mode m, MonadError Error m) => Error -> GetBool mode -> m () 33 | assert err cond = mrgIf cond (return ()) $ throwError err 34 | 35 | -- May introduce this into Grisette library in the future 36 | splitWithError :: 37 | forall a. 
(Mergeable a) => ExceptT Error Union a -> Maybe (SymBool, Union a) 38 | splitWithError a = do 39 | let joined :: Union (Either () (Union a)) = do 40 | v <- runExceptT a 41 | case v of 42 | Left _ -> mrgReturn $ Left () 43 | Right v -> mrgReturn $ Right $ mrgReturn v 44 | let flattened = filter (\(_, v) -> isRight v) $ toGuardedList joined 45 | case flattened of 46 | [] -> Nothing 47 | [(b, Right v)] -> Just (b, v) 48 | _ -> error "Should not happen." 49 | -------------------------------------------------------------------------------- /rules/xla/compare/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import qualified Data.Text as T 4 | import Grisette hiding ((-->)) 5 | import TensorRight 6 | 7 | constructRule :: forall a. T.Text -> CompareOp -> Bool -> NumRule a 8 | constructRule name op value _ = do 9 | rclass <- newRClass "rclass" 10 | map <- newMap "map" rclass 11 | tA <- newTensor @a "A" [rclass --> map] 12 | lhs <- compareOp op tA tA 13 | rhs <- constant @SymBool (con value) [rclass --> map] 14 | rewrite name lhs rhs 15 | 16 | rule01 :: forall a. NumRule a 17 | rule01 = constructRule "Gt(A, A) ⇒ False" Gt False 18 | 19 | rule02 :: forall a. NumRule a 20 | rule02 = constructRule "Lt(A, A) ⇒ False" Lt False 21 | 22 | rule03 :: forall a. NumRule a 23 | rule03 = constructRule "Ne(A, A) ⇒ False" Ne False 24 | 25 | rule04 :: forall a. NumRule a 26 | rule04 = constructRule "Ge(A, A) ⇒ True" Ge True 27 | 28 | rule05 :: forall a. NumRule a 29 | rule05 = constructRule "Le(A, A) ⇒ True" Le True 30 | 31 | rule06 :: forall a. 
NumRule a 32 | rule06 = constructRule "Eqv(A, A) ⇒ True" Eqv True 33 | 34 | main :: IO () 35 | main = do 36 | print "############################## rule01 ##############################" 37 | verifyNumDSL rule01 38 | print "############################## rule02 ##############################" 39 | verifyNumDSL rule02 40 | print "############################## rule03 ##############################" 41 | verifyNumDSL rule03 42 | print "############################## rule04 ##############################" 43 | verifyNumDSL rule04 44 | print "############################## rule05 ##############################" 45 | verifyNumDSL rule05 46 | print "############################## rule06 ##############################" 47 | verifyNumDSL rule06 48 | -------------------------------------------------------------------------------- /flake.lock: -------------------------------------------------------------------------------- 1 | { 2 | "nodes": { 3 | "flake-utils": { 4 | "inputs": { 5 | "systems": "systems" 6 | }, 7 | "locked": { 8 | "lastModified": 1731533236, 9 | "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", 10 | "owner": "numtide", 11 | "repo": "flake-utils", 12 | "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", 13 | "type": "github" 14 | }, 15 | "original": { 16 | "owner": "numtide", 17 | "repo": "flake-utils", 18 | "type": "github" 19 | } 20 | }, 21 | "nixpkgs": { 22 | "locked": { 23 | "lastModified": 1736012469, 24 | "narHash": "sha256-/qlNWm/IEVVH7GfgAIyP6EsVZI6zjAx1cV5zNyrs+rI=", 25 | "owner": "NixOS", 26 | "repo": "nixpkgs", 27 | "rev": "8f3e1f807051e32d8c95cd12b9b421623850a34d", 28 | "type": "github" 29 | }, 30 | "original": { 31 | "owner": "NixOS", 32 | "ref": "nixos-unstable", 33 | "repo": "nixpkgs", 34 | "type": "github" 35 | } 36 | }, 37 | "root": { 38 | "inputs": { 39 | "flake-utils": "flake-utils", 40 | "nixpkgs": "nixpkgs" 41 | } 42 | }, 43 | "systems": { 44 | "locked": { 45 | "lastModified": 1681028828, 46 | "narHash": 
"sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", 47 | "owner": "nix-systems", 48 | "repo": "default", 49 | "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", 50 | "type": "github" 51 | }, 52 | "original": { 53 | "owner": "nix-systems", 54 | "repo": "default", 55 | "type": "github" 56 | } 57 | } 58 | }, 59 | "root": "root", 60 | "version": 7 61 | } 62 | -------------------------------------------------------------------------------- /test/TestUtil.hs: -------------------------------------------------------------------------------- 1 | module TestUtil (eqWhenSuccess, isError, isNotError) where 2 | 3 | import Control.Monad.Except (runExceptT) 4 | import Data.List (intercalate) 5 | import GHC.Stack (HasCallStack) 6 | import Grisette 7 | ( EvalSym (evalSym), 8 | Mergeable, 9 | Solvable (con), 10 | SolvingFailure (Unsat), 11 | SymEq ((./=)), 12 | mrgFmap, 13 | mrgReturn, 14 | simpleMerge, 15 | solve, 16 | z3, 17 | ) 18 | import TensorRight.Internal.Util.Error (ErrorEnv) 19 | import Test.HUnit (assertBool, (@?=)) 20 | 21 | isError :: (HasCallStack) => (Mergeable a) => ErrorEnv a -> IO () 22 | isError err = do 23 | let actual = 24 | mrgFmap (either (const $ Left ()) (const $ Right ())) $ runExceptT err 25 | actual @?= mrgReturn (Left ()) 26 | 27 | isNotError :: (HasCallStack) => (Mergeable a, Show a) => ErrorEnv a -> IO () 28 | isNotError v = do 29 | let actual = 30 | mrgFmap (either (const $ Left ()) (const $ Right ())) $ runExceptT v 31 | assertBool ("Must not be error, but got: " <> show v) $ 32 | actual /= mrgReturn (Left ()) 33 | 34 | eqWhenSuccess :: 35 | (HasCallStack, EvalSym v, Show v, SymEq v, Mergeable v) => 36 | ErrorEnv v -> 37 | v -> 38 | IO () 39 | eqWhenSuccess actual expected = do 40 | isNotError actual 41 | let r = simpleMerge $ do 42 | v <- runExceptT actual 43 | case v of 44 | Left _ -> mrgReturn $ con False 45 | Right x -> mrgReturn $ x ./= expected 46 | m <- solve z3 r 47 | case m of 48 | Left Unsat -> pure () 49 | Left err -> fail $ "Solver 
failed: " <> show err 50 | Right m -> do 51 | fail $ 52 | intercalate 53 | "\n" 54 | [ "unexpected model: " <> show m, 55 | "actual: " <> show (evalSym False m actual), 56 | "expected: " <> show (evalSym False m expected), 57 | "Failed" 58 | ] 59 | -------------------------------------------------------------------------------- /runall.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | TOTAL_FAILED=0 4 | TOTAL_SUCCESS=0 5 | 6 | run_and_capture() { 7 | r=$(stack run rules-$1-$2 2>/dev/null | grep -E "(FAIL|SUCCESS|WARNING|INFO|====>|>>>)" | 8 | ( 9 | LOCAL_FAILED=0 10 | LOCAL_SUCCESS=0 11 | while IFS= read -r line; do 12 | echo "$line" 1>&2 13 | if [[ $line =~ ^\[SUCCESS\].* ]]; then 14 | export LOCAL_SUCCESS=$((LOCAL_SUCCESS + 1)) 15 | elif [[ $line =~ ^\[SUCCESS-Overall\].* ]]; then 16 | export LOCAL_SUCCESS=$((LOCAL_SUCCESS + 1)) 17 | elif [[ $line =~ ^\[SUCCESS-.*\].* ]]; then 18 | true 19 | elif [[ $line =~ ^\[FAIL\].* ]]; then 20 | export LOCAL_FAILED=$((LOCAL_FAILED + 1)) 21 | elif [[ $line =~ ^\[FAIL-Overall\].* ]]; then 22 | export LOCAL_FAILED=$((LOCAL_FAILED + 1)) 23 | elif [[ $line =~ ^\[FAIL-.*\].* ]]; then 24 | true 25 | elif [[ $line =~ ^\[WARNING\].* ]]; then 26 | true 27 | elif [[ $line =~ ^\[INFO-.*\].* ]]; then 28 | true 29 | elif [[ $line =~ ^\[INFO\].* ]]; then 30 | true 31 | elif [[ $line =~ ^====\>.* ]]; then 32 | true 33 | elif [[ $line =~ ^\>\>\>.* ]]; then 34 | true 35 | else 36 | echo "Unknown line: $line" 37 | exit 1 38 | fi 39 | done 40 | echo "$LOCAL_SUCCESS $LOCAL_FAILED" 41 | )) 42 | r=($r) 43 | TOTAL_SUCCESS=$((TOTAL_SUCCESS + ${r[0]})) 44 | TOTAL_FAILED=$((TOTAL_FAILED + ${r[1]})) 45 | } 46 | 47 | framework=$1 48 | if [[ -z $framework ]]; then 49 | echo "Usage: $0 " 50 | exit 1 51 | fi 52 | 53 | ALL_RULES=$(ls rules/$framework) 54 | for rule in $ALL_RULES; do 55 | if [[ $rule != "generalize" ]]; then 56 | run_and_capture $framework $rule 57 | fi 58 | done 59 | 60 | 
echo "Total success: $TOTAL_SUCCESS" 61 | echo "Total failed: $TOTAL_FAILED" 62 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/Core/Linearization.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleContexts #-} 2 | {-# LANGUAGE OverloadedStrings #-} 3 | {-# LANGUAGE ScopedTypeVariables #-} 4 | 5 | module TensorRight.Internal.Core.Linearization 6 | ( linearize, 7 | delinearize, 8 | ) 9 | where 10 | 11 | import Control.Exception (ArithException) 12 | import qualified Data.HashMap.Lazy as HM 13 | import Grisette 14 | ( SafeDiv (safeDiv, safeMod), 15 | SymInteger, 16 | mrgReturn, 17 | ) 18 | import Grisette.Lib.Control.Monad.Except (mrgModifyError) 19 | import TensorRight.Internal.Core.Axis 20 | ( Axis, 21 | AxisMapLike 22 | ( fromKVPairs 23 | ), 24 | Indices, 25 | Sizes, 26 | getAxis, 27 | ) 28 | import TensorRight.Internal.Util.Error (ErrorEnv) 29 | 30 | delinearize :: [Axis] -> Sizes -> SymInteger -> ErrorEnv Indices 31 | delinearize layout dims v = do 32 | let sizesOrdered = map (`getAxis` dims) layout 33 | let sizesOrderedAdj = tail sizesOrdered 34 | let linearizationFactors = scanr (*) 1 sizesOrderedAdj 35 | let linearizationFactorsHash = HM.fromList $ zip layout linearizationFactors 36 | 37 | let delinearizeAxis axis = 38 | mrgModifyError (\(_ :: ArithException) -> "Division by zero") $ do 39 | x <- safeDiv v $ linearizationFactorsHash HM.! axis 40 | r <- safeMod x $ getAxis axis dims 41 | mrgReturn (axis, r) 42 | delinearized <- mapM delinearizeAxis layout 43 | mrgReturn $ fromKVPairs delinearized 44 | 45 | linearize :: [Axis] -> Sizes -> Indices -> SymInteger 46 | linearize layout dims indices = 47 | sum $ 48 | map 49 | (\dim -> getAxis dim indices * linearizationFactorsHash HM.! 
dim) 50 | layout 51 | where 52 | sizesOrdered = map (`getAxis` dims) layout 53 | sizesOrderedAdj = tail sizesOrdered 54 | linearizationFactors = scanr (*) 1 sizesOrderedAdj 55 | linearizationFactorsHash = HM.fromList $ zip layout linearizationFactors 56 | -------------------------------------------------------------------------------- /rules/xla/add/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: forall a. NumRule a 7 | rule01 _ = do 8 | rclass <- newRClass "rclass" 9 | map <- newMap "map" rclass 10 | tA <- newTensor @a "A" [rclass --> map] 11 | c1 <- constant @a "c1" [rclass --> map] 12 | c2 <- constant @a "c2" [rclass --> map] 13 | lhs <- numBinOp Add (numBinOp Add tA c1) c2 14 | rhs <- numBinOp Add tA (numBinOp Add c1 c2) 15 | rewrite "Add(Add(A, c1), c2) ⇒ Add(A, Add(c1, c2))" lhs rhs 16 | 17 | rule02 :: forall a. NumRule a 18 | rule02 _ = do 19 | rclass <- newRClass "rclass" 20 | map <- newMap "map" rclass 21 | tA <- newTensor @a "A" [rclass --> map] 22 | zeroTensor <- constant @a 0 [rclass --> map] 23 | lhs <- numBinOp Add tA zeroTensor 24 | let rhs = tA 25 | rewrite "Add(A, 0) ⇒ A" lhs rhs 26 | 27 | rule03 :: forall a. NumRule a 28 | rule03 _ = do 29 | rclass <- newRClass "rclass" 30 | map <- newMap "map" rclass 31 | tA <- newTensor @a "A" [rclass --> map] 32 | zeroTensor <- constant @a 0 [rclass --> map] 33 | lhs <- numBinOp Add zeroTensor tA 34 | let rhs = tA 35 | rewrite "Add(0, A) ⇒ A" lhs rhs 36 | 37 | rule04 :: forall a. 
NumRule a 38 | rule04 _ = do 39 | rclass <- newRClass "rclass" 40 | map <- newMap "map" rclass 41 | tA <- newTensor @a "A" [rclass --> map] 42 | c <- constant @a "c" [rclass --> map] 43 | lhs <- numBinOp Add c tA 44 | rhs <- numBinOp Add tA c 45 | rewrite "Add(Const, A) ⇒ Add(A, Const)" lhs rhs 46 | 47 | main :: IO () 48 | main = do 49 | print "############################## rule01 ##############################" 50 | verifyNumDSL rule01 51 | print "############################## rule02 ##############################" 52 | verifyNumDSL rule02 53 | print "############################## rule03 ##############################" 54 | verifyNumDSL rule03 55 | print "############################## rule04 ##############################" 56 | verifyNumDSL rule04 57 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/DSL/RelabelMap.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE FlexibleInstances #-} 2 | {-# LANGUAGE MultiParamTypeClasses #-} 3 | 4 | module TensorRight.Internal.DSL.RelabelMap (IsRelabelMap (..), RelabelMapDesc (..)) where 5 | 6 | import qualified Data.HashMap.Lazy as HM 7 | import TensorRight.Internal.DSL.Identifier (RClassIdentifier) 8 | import TensorRight.Internal.DSL.Shape (RClassRef (ByRClass)) 9 | import TensorRight.Internal.DSL.Syntax (ArrowSyntax ((-->))) 10 | 11 | -- | A type class for types that can be converted to a map from 'RClassRef' to 12 | -- 'RClassRef'. These maps, called relabel maps, are used to specify the 13 | -- axis relabling mapping for the 'TensorRight.Internal.DSL.DSL.relabel' operator. 14 | class IsRelabelMap m where 15 | toRelabelMap :: m -> HM.HashMap RClassRef RClassRef 16 | 17 | instance IsRelabelMap (HM.HashMap RClassRef RClassRef) where 18 | toRelabelMap = id 19 | 20 | -- | t'RelabelMapDesc' describes a single relabling in a relabel map, which 21 | -- consists of an 'RClassRef' and another 'RClassRef'. 
A t'RelabelMapDesc' can be 22 | -- created using the syntax @'RClassRef' 'TensorRight.Internal.DSL.Syntax.-->' 'RClassRef'@ 23 | data RelabelMapDesc = RelabelMapDesc RClassRef RClassRef 24 | deriving (Eq, Show) 25 | 26 | instance IsRelabelMap RelabelMapDesc where 27 | toRelabelMap (RelabelMapDesc rclass1 rclass2) = HM.singleton rclass1 rclass2 28 | 29 | -- | TensorRight DSL uses a list of t'RelabelMapDesc' to represent relabel maps 30 | instance IsRelabelMap [RelabelMapDesc] where 31 | toRelabelMap = foldr (HM.union . toRelabelMap) HM.empty 32 | 33 | instance ArrowSyntax RClassIdentifier RClassRef RelabelMapDesc where 34 | rclass1 --> rclass2 = RelabelMapDesc (ByRClass rclass1) rclass2 35 | 36 | instance ArrowSyntax RClassIdentifier RClassIdentifier RelabelMapDesc where 37 | rclass1 --> rclass2 = RelabelMapDesc (ByRClass rclass1) (ByRClass rclass2) 38 | 39 | instance ArrowSyntax RClassRef RClassRef RelabelMapDesc where 40 | ref --> rclass = RelabelMapDesc ref rclass 41 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | ARG TARGETARCH 4 | 5 | ARG CVC5_ARCH_IDENTIFIER_arm64="arm64" 6 | ARG CVC5_ARCH_IDENTIFIER_amd64="x86_64" 7 | ARG CONDA_ARCH_IDENTIFIER_arm64="aarch64" 8 | ARG CONDA_ARCH_IDENTIFIER_amd64="x86_64" 9 | 10 | RUN apt-get update \ 11 | && apt-get install -y \ 12 | curl wget unzip git z3 \ 13 | build-essential curl libffi-dev libffi8ubuntu1 libgmp-dev libgmp10 libncurses-dev libncurses5 libtinfo5 && \ 14 | apt-get clean && \ 15 | rm -rf /var/lib/apt/lists/* 16 | 17 | 18 | RUN export CVC5_ARCH_IDENTIFIER0="CVC5_ARCH_IDENTIFIER_$TARGETARCH" && \ 19 | eval CVC5_ARCH_IDENTIFIER="\$$CVC5_ARCH_IDENTIFIER0" && \ 20 | wget -O ~/cvc5-Linux-static.zip https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.0/cvc5-Linux-${CVC5_ARCH_IDENTIFIER}-static.zip && \ 21 | unzip ~/cvc5-Linux-static.zip && \ 22 | rm 
~/cvc5-Linux-static.zip && \ 23 | mv cvc5-Linux-*-static ~/cvc5 && \ 24 | ln -s ~/cvc5/bin/cvc5 /usr/bin/cvc5 25 | 26 | # Install Python 27 | RUN export CONDA_ARCH_IDENTIFIER0="CONDA_ARCH_IDENTIFIER_$TARGETARCH" && \ 28 | eval CONDA_ARCH_IDENTIFIER="\$$CONDA_ARCH_IDENTIFIER0" && \ 29 | curl -sL -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py37_4.12.0-Linux-${CONDA_ARCH_IDENTIFIER}.sh && \ 30 | bash ~/miniconda.sh -b -p ~/.conda && \ 31 | rm ~/miniconda.sh && \ 32 | ~/.conda/bin/conda init 33 | 34 | RUN ~/.conda/bin/conda run -n base \ 35 | conda install -y python=3.10 numpy=1.26.4 matplotlib=3.8.4 36 | RUN ~/.conda/bin/conda run -n base \ 37 | conda clean -y --all 38 | 39 | RUN curl --proto '=https' --tlsv1.2 -sSf https://get-ghcup.haskell.org | sh 40 | ENV PATH="/root/.ghcup/bin:${PATH}" 41 | ENV LANG=C.UTF-8 42 | RUN ghcup install stack 43 | 44 | WORKDIR /home 45 | RUN mkdir -p /home/tr/ 46 | 47 | COPY rules/ /home/tr/rules 48 | COPY src/ /home/tr/src 49 | COPY test/ /home/tr/test 50 | COPY package.yaml /home/tr/package.yaml 51 | COPY stack.yaml /home/tr/stack.yaml 52 | COPY Makefile /home/tr/Makefile 53 | COPY runall.sh /home/tr/runall.sh 54 | 55 | WORKDIR /home/tr 56 | RUN stack build 57 | -------------------------------------------------------------------------------- /rules/xla/select/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: forall a. AnyDTypeRule a 7 | rule01 _ = do 8 | rclass <- newRClass "rclass" 9 | map <- newMap "map" rclass 10 | tP <- newTensor @SymBool "P" [rclass --> map] 11 | tA <- newTensor @a "A" [rclass --> map] 12 | lhs <- select tP tA tA 13 | let rhs = tA 14 | rewrite "Select(P, A, A) ⇒ A" lhs rhs 15 | 16 | rule02 :: forall a. 
AnyDTypeRule a 17 | rule02 _ = do 18 | rclass <- newRClass "rclass" 19 | map <- newMap "map" rclass 20 | pred <- constant @SymBool true [rclass --> map] 21 | tA <- newTensor @a "A" [rclass --> map] 22 | tB <- newTensor @a "B" [rclass --> map] 23 | lhs <- select pred tA tB 24 | let rhs = tA 25 | rewrite "Select(True, A, B) ⇒ A" lhs rhs 26 | 27 | rule03 :: forall a. AnyDTypeRule a 28 | rule03 _ = do 29 | rclass <- newRClass "rclass" 30 | map <- newMap "map" rclass 31 | pred <- constant @SymBool false [rclass --> map] 32 | tA <- newTensor @a "A" [rclass --> map] 33 | tB <- newTensor @a "B" [rclass --> map] 34 | lhs <- select pred tA tB 35 | let rhs = tB 36 | rewrite "Select(False, A, B) ⇒ B" lhs rhs 37 | 38 | rule04 :: forall a. AnyDTypeRule a 39 | rule04 _ = do 40 | rclass <- newRClass "rclass" 41 | map <- newMap "map" rclass 42 | pred <- constant @SymBool false [rclass --> map] 43 | tA <- newTensor @a "A" [rclass --> map] 44 | tB <- newTensor @a "B" [rclass --> map] 45 | lhs <- select (boolUnaryOp Not pred) tA tB 46 | rhs <- select pred tB tA 47 | rewrite "Select(Not(P), A, B) ⇒ Select(P, B, A)" lhs rhs 48 | 49 | main :: IO () 50 | main = do 51 | print "############################## rule01 ##############################" 52 | verifyAnyDTypeDSL rule01 53 | print "############################## rule02 ##############################" 54 | verifyAnyDTypeDSL rule02 55 | print "############################## rule03 ##############################" 56 | verifyAnyDTypeDSL rule03 57 | print "############################## rule04 ##############################" 58 | verifyAnyDTypeDSL rule04 59 | -------------------------------------------------------------------------------- /flake.nix: -------------------------------------------------------------------------------- 1 | { 2 | description = "Automated Verification of Tensor Graph Rewrites"; 3 | inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; 4 | inputs.flake-utils.url = "github:numtide/flake-utils"; 5 | 6 | outputs 
= { self, nixpkgs, flake-utils }: 7 | flake-utils.lib.eachDefaultSystem (system: 8 | let 9 | pkgs = import nixpkgs { inherit system; }; 10 | hPkgs = pkgs.haskell.packages."ghc9101".extend (hself: hsuper: { 11 | hlint = hself.callCabal2nix "hlint" 12 | (pkgs.fetchFromGitHub { 13 | owner = "ndmitchell"; 14 | repo = "hlint"; 15 | rev = "7dfba720eaf6fa9bd0b23ae269334559aa722847"; 16 | sha256 = "sha256-niGBdSrkatr+TZCcLYXo4MDg5FyXTYiKQ5K+ZIWSWBs="; 17 | }) 18 | { }; 19 | }); 20 | stack-wrapped = pkgs.symlinkJoin { 21 | name = "stack"; # will be available as the usual `stack` in terminal 22 | paths = [ pkgs.stack ]; 23 | buildInputs = [ pkgs.makeWrapper ]; 24 | postBuild = '' 25 | wrapProgram $out/bin/stack \ 26 | --add-flags "\ 27 | --no-nix \ 28 | --system-ghc \ 29 | --no-install-ghc \ 30 | " 31 | ''; 32 | }; 33 | 34 | devTools = with pkgs; [ 35 | z3 36 | stack-wrapped 37 | hPkgs.hlint 38 | hPkgs.haskell-language-server 39 | (cvc5.overrideAttrs (oldAttrs: rec { 40 | cmakeFlags = oldAttrs.cmakeFlags ++ [ 41 | "-DUSE_POLY=ON" 42 | ]; 43 | buildInputs = oldAttrs.buildInputs ++ [ 44 | libpoly 45 | ]; 46 | })) 47 | (python3.withPackages (ps: with ps; [ 48 | numpy 49 | matplotlib 50 | ])) 51 | ]; 52 | in 53 | { 54 | formatter = pkgs.nixpkgs-fmt; 55 | devShell = pkgs.mkShell { 56 | buildInputs = devTools; 57 | 58 | LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath devTools; 59 | }; 60 | }); 61 | } 62 | -------------------------------------------------------------------------------- /rules/xla/clamp/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Control.Monad.Except (runExceptT) 4 | import Grisette hiding ((-->)) 5 | import TensorRight 6 | import TensorRight.Internal.Core.Tensor.TensorInt (tensorValLt) 7 | 8 | rule01 :: forall a. 
NumRule a 9 | rule01 _ = do 10 | rclass <- newRClass "rclass" 11 | map <- newMap "map" rclass 12 | tA <- newTensor @a "A" [rclass --> map] 13 | c1 <- newTensor @a "c1" [rclass --> map] 14 | c2 <- newTensor @a "c2" [rclass --> map] 15 | lhs <- clamp c1 (clamp c1 tA c2) c2 16 | rhs <- clamp c1 tA c2 17 | rewrite "Clamp(c1, Clamp(c1, A, c2), c2) ⇒ Clamp(c1, A, c2)" lhs rhs 18 | 19 | rule02 :: forall a. NumRule a 20 | rule02 _ = do 21 | rclass <- newRClass "rclass" 22 | map <- newMap "map" rclass 23 | tA <- newTensor @a "A" [rclass --> map] 24 | c1 <- newTensor @a "c1" [rclass --> map] 25 | c2 <- newTensor @a "c2" [rclass --> map] 26 | 27 | lhs <- numBinOp Max c1 (numBinOp Min tA c2) 28 | forallIdx <- newMap "forallIdx" rclass 29 | numTensorAssumption 30 | [c1, c2] 31 | forallIdx 32 | ( \[vc1, vc2] -> simpleMerge $ do 33 | u <- runExceptT $ tensorValLt vc1 vc2 34 | case u of 35 | Left _ -> con True 36 | Right v -> return v 37 | ) 38 | 39 | rhs <- clamp c1 tA c2 40 | rewrite "Max(Broadcast(c1), Min(A, Broadcast(c2))) ⇒ Clamp(c1, A, c2)" lhs rhs 41 | 42 | rule03 :: forall a. 
NumRule a 43 | rule03 _ = do 44 | rclass <- newRClass "rclass" 45 | map <- newMap "map" rclass 46 | tA <- newTensor @a "A" [rclass --> map] 47 | c1 <- newTensor @a "c1" [rclass --> map] 48 | c2 <- newTensor @a "c2" [rclass --> map] 49 | lhs <- numBinOp Min c1 (numBinOp Max tA c2) 50 | rhs <- clamp c2 tA c1 51 | rewrite "Min(Broadcast(c1), Max(A, Broadcast(c2))) ⇒ Clamp(c1, A, c2)" lhs rhs 52 | 53 | main :: IO () 54 | main = do 55 | print "############################## rule01 ##############################" 56 | verifyNumDSL rule01 57 | print "############################## rule02 ##############################" 58 | verifyNumDSL rule02 59 | print "############################## rule03 ##############################" 60 | verifyNumDSL rule03 61 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/DSL/Parameters.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveAnyClass #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE DerivingVia #-} 4 | {-# LANGUAGE FlexibleInstances #-} 5 | {-# LANGUAGE MultiParamTypeClasses #-} 6 | 7 | module TensorRight.Internal.DSL.Parameters (IsParamMaps (..), ParamDesc (..)) where 8 | 9 | import qualified Data.HashMap.Lazy as HM 10 | import Data.Hashable (Hashable) 11 | import GHC.Generics (Generic) 12 | import Grisette (Default (Default), PPrint) 13 | import TensorRight.Internal.DSL.Identifier (MapIdentifier, RClassIdentifier) 14 | import TensorRight.Internal.DSL.Shape (RClassRef (ByRClass)) 15 | import TensorRight.Internal.DSL.Syntax (ArrowSyntax ((-->))) 16 | 17 | -- | A type class for types that can be converted to a map from 'RClassRef' to 18 | -- 'MapIdentifier'. These maps, called parameter maps, are used to specify the 19 | -- parameters or attributes of a tensor opertor. 
20 | class IsParamMaps m where 21 | toParamMaps :: m -> HM.HashMap RClassRef MapIdentifier 22 | 23 | instance IsParamMaps (HM.HashMap RClassRef MapIdentifier) where 24 | toParamMaps = id 25 | 26 | -- | TensorRight DSL uses a list of t'ParamDesc' to represent parameter maps 27 | instance IsParamMaps [ParamDesc] where 28 | toParamMaps = foldr (HM.union . toParamMaps) HM.empty 29 | 30 | -- | t'ParamDesc' describes a mapping in a parameter map, which consists of an 31 | -- 'RClassRef' and a 'MapIdentifier'. A t'ParamDesc' can be created in the 32 | -- following ways: 33 | -- 34 | -- - Directly using the RClass for unlabelled mappings: @'RClassIdentifier' 'TensorRight.Internal.DSL.Syntax.-->' 'MapIdentifier'@ 35 | -- - Using a 'RClassRef' for labelled mappings: @'RClassRef' 'TensorRight.Internal.DSL.Syntax.-->' 'MapIdentifier'@ 36 | data ParamDesc = ParamDesc RClassRef MapIdentifier 37 | deriving (Generic, Eq, Show) 38 | deriving (Hashable) 39 | deriving (PPrint) via (Default ParamDesc) 40 | 41 | instance IsParamMaps ParamDesc where 42 | toParamMaps (ParamDesc rclass map) = HM.singleton rclass map 43 | 44 | instance ArrowSyntax RClassIdentifier MapIdentifier ParamDesc where 45 | rclass --> map = ParamDesc (ByRClass rclass) map 46 | 47 | instance ArrowSyntax RClassRef MapIdentifier ParamDesc where 48 | ref --> map = ParamDesc ref map 49 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | # This file was automatically generated by 'stack init' 2 | # 3 | # Some commonly used options have been documented as comments in this file. 4 | # For advanced use and comprehensive documentation of the format, please see: 5 | # https://docs.haskellstack.org/en/stable/yaml_configuration/ 6 | 7 | # Resolver to choose a 'specific' stackage snapshot or a compiler version. 
8 | # A snapshot resolver dictates the compiler version and the set of packages 9 | # to be used for project dependencies. For example: 10 | # 11 | # resolver: lts-3.5 12 | # resolver: nightly-2015-09-21 13 | # resolver: ghc-7.10.2 14 | # 15 | # The location of a snapshot can be provided as a file or url. Stack assumes 16 | # a snapshot provided as a file might change, whereas a url resource does not. 17 | # 18 | # resolver: ./custom-snapshot.yaml 19 | # resolver: https://example.com/snapshots/2018-01-01.yaml 20 | resolver: nightly-2025-01-06 21 | 22 | # User packages to be built. 23 | # Various formats can be used as shown in the example below. 24 | # 25 | # packages: 26 | # - some-directory 27 | # - https://example.com/foo/bar/baz-0.0.2.tar.gz 28 | # subdirs: 29 | # - auto-update 30 | # - wai 31 | packages: 32 | - . 33 | # Dependency packages to be pulled from upstream that are not in the resolver. 34 | # These entries can reference officially published versions as well as 35 | # forks / in-progress versions pinned to a git hash. 
For example: 36 | # 37 | # extra-deps: 38 | # - acme-missiles-0.3 39 | # - git: https://github.com/commercialhaskell/stack.git 40 | # commit: e7b331f14bcffb8367cd58fbfc8b40ec7642100a 41 | # 42 | # extra-deps: [] 43 | 44 | extra-deps: 45 | - git: https://github.com/lsrcz/grisette.git 46 | commit: 751e74a5c0049868da75f320c47a1ce435efc9d0 47 | 48 | # Override default flag values for local packages and extra-deps 49 | # flags: {} 50 | 51 | # Extra package databases containing global packages 52 | # extra-package-dbs: [] 53 | 54 | # Control whether we use the GHC we find on the path 55 | # system-ghc: true 56 | # 57 | # Require a specific version of stack, using version ranges 58 | # require-stack-version: -any # Default 59 | # require-stack-version: ">=2.7" 60 | # 61 | # Override the architecture used by stack, especially useful on Windows 62 | # arch: i386 63 | # arch: x86_64 64 | # 65 | # Extra directories used by stack for building 66 | # extra-include-dirs: [/path/to/dir] 67 | # extra-lib-dirs: [/path/to/dir] 68 | # 69 | # Allow a newer minor version of GHC than the snapshot specifies 70 | # compiler-check: newer-minor 71 | -------------------------------------------------------------------------------- /test/Core/LinearizationTest.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | 3 | module Core.LinearizationTest (linearizationTest) where 4 | 5 | import Grisette (Solvable (con), mrgReturn) 6 | import TensorRight.Internal.Core.Axis (Axis (Axis), AxisMapLike (fromKVPairs)) 7 | import TensorRight.Internal.Core.Linearization (delinearize, linearize) 8 | import Test.Framework (Test, testGroup) 9 | import Test.Framework.Providers.HUnit (testCase) 10 | import Test.Framework.Providers.QuickCheck2 (testProperty) 11 | import Test.HUnit ((@?=)) 12 | import Test.QuickCheck 13 | ( NonNegative (NonNegative), 14 | Positive (Positive), 15 | ioProperty, 16 | ) 17 | import TestUtil (eqWhenSuccess) 18 | 19 
| axisa :: Axis 20 | axisa = Axis "a" 21 | 22 | axisb :: Axis 23 | axisb = Axis "b" 24 | 25 | axisc :: Axis 26 | axisc = Axis "c" 27 | 28 | linearizationTest :: Test 29 | linearizationTest = 30 | testGroup 31 | "Linearization" 32 | [ testProperty "linearization and delinearization are inverse" $ 33 | \(Positive as) 34 | (Positive bs) 35 | (Positive cs) 36 | (NonNegative ai) 37 | (NonNegative bi) 38 | (NonNegative ci) -> ioProperty $ do 39 | let layout = [axisa, axisb, axisc] 40 | let sizes = 41 | fromKVPairs 42 | [ (axisa, con $ as + ai), 43 | (axisb, con $ bs + bi), 44 | (axisc, con $ cs + ci) 45 | ] 46 | let indices = 47 | fromKVPairs 48 | [(axisa, con ai), (axisb, con bi), (axisc, con ci)] 49 | let linearized = linearize layout sizes indices 50 | let delinearized = delinearize layout sizes linearized 51 | delinearized `eqWhenSuccess` indices, 52 | testCase "linearize" $ do 53 | let sizes = fromKVPairs [(axisa, con 5), (axisb, con 6), (axisc, con 7)] 54 | let indices = 55 | fromKVPairs [(axisa, con 2), (axisb, con 3), (axisc, con 4)] 56 | let layout = [axisa, axisb, axisc] 57 | let expected = 2 * 6 * 7 + 3 * 7 + 4 58 | linearize layout sizes indices @?= expected, 59 | testCase "delinearize" $ do 60 | let sizes = fromKVPairs [(axisa, con 5), (axisb, con 6), (axisc, con 7)] 61 | let linearized = 2 * 6 * 7 + 3 * 7 + 4 62 | let layout = [axisa, axisb, axisc] 63 | let expected = 64 | fromKVPairs [(axisa, con 2), (axisb, con 3), (axisc, con 4)] 65 | delinearize layout sizes linearized @?= mrgReturn expected 66 | ] 67 | -------------------------------------------------------------------------------- /rules/xla/reverse/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: forall a. 
AnyDTypeRule a 7 | rule01 _ = do 8 | rclass <- newRClass "rclass" 9 | map <- newMap "map" rclass 10 | 11 | tA <- newTensor @a "A" [rclass --> map] 12 | lhs <- reverseTensor tA [ByRClass rclass] 13 | precondition [map] $ \[s] -> s .== 1 14 | 15 | let rhs = tA 16 | rewrite "Reverse(A, dims) ⇒ A" lhs rhs 17 | 18 | rule02 :: forall a. AnyDTypeRule a 19 | rule02 _ = do 20 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 21 | rc0Size <- newMap "rc0Size" rclass0 22 | rc1Size <- newMap "rc1Size" rclass1 23 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1Size] 24 | lhs <- reverseTensor (reverseTensor tA [ByRClass rclass0]) [ByRClass rclass0] 25 | let rhs = tA 26 | rewrite "Reverse(Reverse(A, dims1), dims2) ⇒ A" lhs rhs 27 | 28 | rule03 :: forall a. AnyDTypeRule a 29 | rule03 _ = do 30 | [rclass0, rclass1, rclass2] <- newRClasses ["rclass0", "rclass1", "rclass2"] 31 | rc0Size <- newMap "rc0Size" rclass0 32 | rc1Size <- newMap "rc1Size" rclass1 33 | rc2Size <- newMap "rc2Size" rclass2 34 | tA <- 35 | newTensor @a 36 | "A" 37 | [ rclass0 --> rc0Size, 38 | rclass1 --> rc1Size, 39 | rclass2 --> rc2Size 40 | ] 41 | lhs <- 42 | reverseTensor 43 | (reverseTensor tA [ByRClass rclass0, ByRClass rclass1]) 44 | [ByRClass rclass1, ByRClass rclass2] 45 | rhs <- reverseTensor tA [ByRClass rclass2, ByRClass rclass0] 46 | rewrite 47 | "Reverse(Reverse(A, dims1), dims2) ⇒ Reverse(A, disjoint union of dims1 and dims2)" 48 | lhs 49 | rhs 50 | 51 | rule04 :: forall a. 
NumRule a 52 | rule04 _ = do 53 | rclass <- newRClass "rclass" 54 | map <- newMap "map" rclass 55 | tA <- newTensor @a "A" [rclass --> map] 56 | const <- constant @a "const" [rclass --> map] 57 | lhs <- reverseTensor (numBinOp Add tA const) [ByRClass rclass] 58 | rhs <- numBinOp Add (reverseTensor tA [ByRClass rclass]) const 59 | rewrite 60 | "Reverse(Binary(A, Const)) ⇒ Binary(Reverse(A), Const)" 61 | lhs 62 | rhs 63 | 64 | main :: IO () 65 | main = do 66 | print "############################## rule01 ##############################" 67 | verifyAnyDTypeDSL rule01 68 | print "############################## rule02 ##############################" 69 | verifyAnyDTypeDSL rule02 70 | print "############################## rule03 ##############################" 71 | verifyAnyDTypeDSL rule03 72 | print "############################## rule04 ##############################" 73 | verifyNumDSL rule04 74 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/DSL/Identifier.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE DeriveAnyClass #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE DeriveLift #-} 5 | {-# LANGUAGE DerivingStrategies #-} 6 | {-# LANGUAGE KindSignatures #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | 9 | module TensorRight.Internal.DSL.Identifier 10 | ( IdentifierKind (..), 11 | Identifier (..), 12 | RClassIdentifier, 13 | MapIdentifier, 14 | Label, 15 | TensorIdentifier, 16 | nextIdentifier, 17 | ) 18 | where 19 | 20 | import Control.DeepSeq (NFData) 21 | import Data.Hashable (Hashable) 22 | import Data.String (IsString (fromString)) 23 | import qualified Data.Text as T 24 | import GHC.Generics (Generic) 25 | import Grisette 26 | ( AsMetadata (asMetadata, fromMetadata), 27 | PPrint (pformat), 28 | SExpr (Atom, List), 29 | ) 30 | import Language.Haskell.TH.Syntax (Lift) 31 | 32 | data IdentifierKind = RClassKind | MapKind 
| TensorKind 33 | 34 | data Identifier (kind :: IdentifierKind) 35 | = SimpleIdentifier T.Text 36 | | IndexedIdentifier T.Text Int 37 | deriving stock (Eq, Ord, Generic, Lift) 38 | deriving anyclass (Hashable, NFData) 39 | 40 | instance AsMetadata (Identifier kind) where 41 | asMetadata (SimpleIdentifier name) = List [Atom "SimpleIdentifier", Atom name] 42 | asMetadata (IndexedIdentifier name i) = 43 | List [Atom "IndexedIdentifier", Atom name, Atom $ T.pack $ show i] 44 | fromMetadata (List [Atom "SimpleIdentifier", Atom name]) = 45 | Just $ SimpleIdentifier name 46 | fromMetadata (List [Atom "IndexedIdentifier", Atom name, Atom i]) = 47 | Just $ IndexedIdentifier name (read $ T.unpack i) 48 | fromMetadata _ = Nothing 49 | 50 | instance IsString (Identifier kind) where 51 | fromString = SimpleIdentifier . T.pack 52 | 53 | instance Show (Identifier kind) where 54 | show (SimpleIdentifier name) = T.unpack name 55 | show (IndexedIdentifier name i) = T.unpack name <> "@" <> show i 56 | 57 | instance PPrint (Identifier kind) where 58 | pformat (SimpleIdentifier name) = pformat name 59 | pformat (IndexedIdentifier name i) = pformat name <> "@" <> pformat i 60 | 61 | -- | An identifier for an RClass. 62 | type RClassIdentifier = Identifier 'RClassKind 63 | 64 | -- | An identifier for an aggregated-map. 65 | type MapIdentifier = Identifier 'MapKind 66 | 67 | -- | An identifier for a tensor. 68 | type TensorIdentifier = Identifier 'TensorKind 69 | 70 | nextIdentifier :: Identifier kind -> Identifier kind 71 | nextIdentifier (SimpleIdentifier name) = IndexedIdentifier name 0 72 | nextIdentifier (IndexedIdentifier name i) = IndexedIdentifier name (i + 1) 73 | 74 | -- | A 'Label' represents a name for an aggregated-axis. 75 | -- It is not needed if an RClass has only one aggregated-axis. 
76 | type Label = T.Text 77 | -------------------------------------------------------------------------------- /rules/xla/dyupslice/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: forall a. NumRule a 7 | rule01 _ = do 8 | rclass <- newRClass "rclass" 9 | [rcSizeA, rcSizeB, rcStart] <- 10 | newMaps ["rcSizeA", "rcSizeB", "rcStart"] rclass 11 | tA <- newTensor @a "A" [rclass --> rcSizeA] 12 | tB <- newTensor @a "B" [rclass --> rcSizeB] 13 | lhs <- 14 | numBinOp 15 | Add 16 | tA 17 | (dynamicUpdateSlice (constant @a 0 [rclass --> rcSizeA]) tB [rclass --> rcStart]) 18 | rhs <- 19 | dynamicUpdateSlice 20 | tA 21 | ( numBinOp 22 | Add 23 | tB 24 | ( dynamicSlice tA $ 25 | DySlice 26 | { start = [rclass --> rcStart], 27 | sizes = [rclass --> rcSizeB] 28 | } 29 | ) 30 | ) 31 | [rclass --> rcStart] 32 | rewrite "Add(A, DynamicUpdateSlice(Broadcast(0), B) ⇒ DynamicUpdateSlice(A,...)" lhs rhs 33 | 34 | rule02 :: forall a. AnyDTypeRule a 35 | rule02 _ = do 36 | rclass <- newRClass "rclass" 37 | [rcOrigSize, rcNewSize, rcStart] <- 38 | newMaps ["rcOrigSize", "rcNewSize", "rcStart"] rclass 39 | 40 | tA <- newTensor @a "A" [rclass --> rcOrigSize] 41 | lhs <- 42 | dynamicUpdateSlice 43 | (constant @a "a" [rclass --> rcNewSize]) 44 | tA 45 | [rclass --> rcStart] 46 | 47 | rcInt <- newConstMap "rcInt" 0 rclass 48 | rcHigh <- combineMap "rcHigh" (\[ns, os, s] -> ns - os - s) [rcNewSize, rcOrigSize, rcStart] 49 | rhs <- 50 | pad tA ("a" :: a) $ 51 | Padding 52 | { low = [rclass --> rcStart], 53 | high = [rclass --> rcHigh], 54 | interior = [rclass --> rcInt] 55 | } 56 | 57 | rewrite "DynamicUpdateSlice(Broadcast(Const),A,...) ⇒ Pad(" lhs rhs 58 | 59 | rule03 :: forall a. 
AnyDTypeRule a 60 | rule03 _ = do 61 | rclass <- newRClass "rclass" 62 | [rcSize, rcStart] <- newMaps ["rcSize", "rcStart"] rclass 63 | 64 | tA <- newTensor @a "tA" [rclass --> rcSize] 65 | tB <- newTensor @a "tB" [rclass --> rcSize] 66 | lhs <- dynamicUpdateSlice tA tB [rclass --> rcStart] 67 | precondition [rcSize] $ \[s] -> s .== 0 68 | 69 | let rhs = tB 70 | rewrite "DynamicUpdateSlice(A, B, 0) ⇒ B" lhs rhs 71 | 72 | rule04 :: forall a. AnyDTypeRule a 73 | rule04 _ = do 74 | rclass <- newRClass "rclass" 75 | [rcSizeA, rcSizeB, rcStart0, rcLength, rcStart1] <- 76 | newMaps ["rcSizeA", "rcSizeB", "startMap0", "sliceSizeMap0", "startMap1"] rclass 77 | 78 | tA <- newTensor @a "A" [rclass --> rcSizeA] 79 | tB <- newTensor @a "B" [rclass --> rcSizeB] 80 | lhs <- 81 | dynamicUpdateSlice 82 | tA 83 | ( dynamicUpdateSlice 84 | ( dynamicSlice tA $ 85 | DySlice 86 | { start = [rclass --> rcStart0], 87 | sizes = [rclass --> rcLength] 88 | } 89 | ) 90 | tB 91 | [rclass --> rcStart1] 92 | ) 93 | [rclass --> rcStart0] 94 | 95 | rcStart2 <- combineMap "rcStart2" sum [rcStart0, rcStart1] 96 | rhs <- dynamicUpdateSlice tA tB [rclass --> rcStart2] 97 | rewrite "DynamicUpdateSlice(A, DynamicUpdateSlice(DynamicSlice(A, ...), B, ...), ...)) ⇒ DynamicUpdateSlice(A, B, ...)" lhs rhs 98 | 99 | main :: IO () 100 | main = do 101 | print "############################## rule01 ##############################" 102 | verifyNumDSL rule01 103 | print "############################## rule02 ##############################" 104 | verifyAnyDTypeDSL rule02 105 | print "############################## rule03 ##############################" 106 | verifyAnyDTypeDSL rule03 107 | print "############################## rule04 ##############################" 108 | verifyAnyDTypeDSL rule04 109 | -------------------------------------------------------------------------------- /rules/debug/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 
| import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rulePadTwice :: forall a. AnyDTypeRule a 7 | rulePadTwice _ = do 8 | rclass <- newRClass "rclass" 9 | 10 | shape <- newMap "shape" rclass 11 | 12 | [innerLow, innerInterior, innerHigh] <- 13 | newMaps ["innerLow", "innerInterior", "innerHigh"] rclass 14 | [outerLow, outerInterior, outerHigh] <- 15 | newMaps ["outerLow", "outerInterior", "outerHigh"] rclass 16 | [rhsLow, rhsInterior, rhsHigh] <- 17 | newMaps ["rhsLow", "rhsInterior", "rhsHigh"] rclass 18 | 19 | let cond outerPad innerPad rhsPad = 20 | precondition [outerPad, innerPad, rhsPad] $ 21 | \[vouterPad, vinnerPad, vrhsPad] -> 22 | vrhsPad .== vouterPad + vinnerPad 23 | 24 | cond outerLow innerLow rhsLow 25 | cond outerInterior innerInterior rhsInterior 26 | cond outerHigh innerHigh rhsHigh 27 | 28 | precondition [outerInterior] $ \[vi0] -> vi0 .== 0 29 | precondition [innerLow] $ \[vi0] -> vi0 .>= 0 30 | precondition [innerHigh] $ \[vi0] -> vi0 .>= 0 31 | 32 | x <- newTensor @a "x" [rclass --> shape] 33 | 34 | lhs <- 35 | pad 36 | ( pad x ("a" :: a) $ 37 | Padding 38 | { low = [rclass --> innerLow], 39 | interior = [rclass --> innerInterior], 40 | high = [rclass --> innerHigh] 41 | } 42 | ) 43 | ("a" :: a) 44 | $ Padding 45 | { low = [rclass --> outerLow], 46 | interior = [rclass --> outerInterior], 47 | high = [rclass --> outerHigh] 48 | } 49 | rhs <- 50 | pad x ("a" :: a) $ 51 | Padding 52 | { low = [rclass --> rhsLow], 53 | interior = [rclass --> rhsInterior], 54 | high = [rclass --> rhsHigh] 55 | } 56 | rewrite 57 | "when i0 == 0, pad(pad(x, l0, i0, h0), l1, i1, h1) --> pad(x, l0+l1, i0+i1, h0+h1)" 58 | lhs 59 | rhs 60 | 61 | rulePadLowCombine :: forall a. 
AnyDTypeRule a 62 | rulePadLowCombine _ = do 63 | rclass <- newRClass "rclass" 64 | 65 | shape <- newMap "shape" rclass 66 | 67 | [innerLow, outerLow, rhsLow] <- 68 | newMaps ["innerLow", "outerLow", "rhsLow"] rclass 69 | 70 | let cond outerPad innerPad rhsPad = 71 | precondition [outerPad, innerPad, rhsPad] $ 72 | \[vouterPad, vinnerPad, vrhsPad] -> 73 | vrhsPad .== vouterPad + vinnerPad 74 | 75 | cond outerLow innerLow rhsLow 76 | precondition [innerLow] $ \[vi0] -> vi0 .>= 0 77 | precondition [outerLow] $ \[vi0] -> vi0 .>= 0 78 | 79 | x <- newTensor @a "x" [rclass --> shape] 80 | 81 | lhs <- 82 | padLow 83 | (padLow x ("a" :: a) [rclass --> innerLow]) 84 | ("a" :: a) 85 | [rclass --> outerLow] 86 | 87 | rhs <- 88 | padLow x ("a" :: a) [rclass --> rhsLow] 89 | rewrite 90 | "padLow(padLow(x, l0), l1) --> padLow(x, l0+l1)" 91 | lhs 92 | rhs 93 | 94 | ruleDyUpSliceSlice :: forall a. AnyDTypeRule a 95 | ruleDyUpSliceSlice _ = do 96 | rclass <- newRClass "rclass" 97 | rcSize <- newMap "rcSize" rclass 98 | rcStart <- newConstMap "rcStart" 0 rclass 99 | rcStrideLhs <- newConstMap "rcStrideLhs" 1 rclass 100 | rcStrideRhs <- newConstMap "rcStrideRhs" 2 rclass 101 | rcOffset <- newConstMap "rcOffset" 1 rclass 102 | rcEndLhs <- combineMap "rcEndLhs" (\[s] -> divOr 0 (s + 1) 2) [rcSize] 103 | rcUpdateSize <- 104 | combineMap "rcUpdateSize" (\[e, o] -> e - o) [rcEndLhs, rcOffset] 105 | tX <- newTensor @a "X" [rclass --> rcSize] 106 | lhsSlice <- 107 | slice tX $ 108 | Slice 109 | { start = [rclass --> rcStart], 110 | end = [rclass --> rcEndLhs], 111 | strides = [rclass --> rcStrideLhs] 112 | } 113 | updateTensor <- constant @a ("a" :: a) [rclass --> rcUpdateSize] 114 | lhs <- dynamicUpdateSlice lhsSlice updateTensor [rclass --> rcOffset] 115 | rhsSlice <- 116 | slice tX $ 117 | Slice 118 | { start = [rclass --> rcStart], 119 | end = [rclass --> rcSize], 120 | strides = [rclass --> rcStrideRhs] 121 | } 122 | rhs <- dynamicUpdateSlice rhsSlice updateTensor [rclass --> rcOffset] 
123 | rewrite "TensorRight Motivating Example" lhs rhs 124 | 125 | main :: IO () 126 | main = do 127 | verifyAnyDTypeDSL rulePadTwice 128 | verifyAnyDTypeDSL rulePadLowCombine 129 | verifyAnyDTypeDSL ruleDyUpSliceSlice 130 | -------------------------------------------------------------------------------- /rules/xla/broadcast/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: forall a. NumRule a 7 | rule01 _ = do 8 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 9 | rc0Size <- newMap "rc0Size" rclass0 10 | rc1Size <- newMap "rc1Size" rclass1 11 | tA <- newTensor @a "A" [rclass0 --> rc0Size] 12 | tB <- newTensor @a "B" [rclass0 --> rc0Size] 13 | lhs <- numBinOp Add (broadcast tA [rclass1 --> rc1Size]) (broadcast tB [rclass1 --> rc1Size]) 14 | rhs <- broadcast (numBinOp Add tA tB) [rclass1 --> rc1Size] 15 | rewrite "Add(Broadcast(A), Broadcast(B)) ⇒ Broadcast(Add(A, B))" lhs rhs 16 | 17 | rule02 :: forall a. NumRule a 18 | rule02 _ = do 19 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 20 | rc0Size <- newMap "rc0Size" rclass0 21 | rc1Size <- newMap "rc1Size" rclass1 22 | tA <- newTensor @a "A" [rclass0 --> rc0Size] 23 | tB <- newTensor @a "B" [rclass0 --> rc0Size] 24 | lhs <- numBinOp Mul (broadcast tA [rclass1 --> rc1Size]) (broadcast tB [rclass1 --> rc1Size]) 25 | rhs <- broadcast (numBinOp Mul tA tB) [rclass1 --> rc1Size] 26 | rewrite "Mul(Broadcast(A), Broadcast(B)) ⇒ Broadcast(Mul(A, B))" lhs rhs 27 | 28 | rule03 :: forall a. 
AnyDTypeRule a 29 | rule03 _ = do 30 | rclass <- newRClass "rclass" 31 | map <- newMap "map" rclass 32 | lhsTensor <- constant @a "a" [rclass --> map @@ "label1"] 33 | rhsTensor <- constant @a "a" [rclass --> map @@ "label2"] 34 | lhs <- relabel lhsTensor [ByLabel "label1" --> ByLabel "label2"] 35 | let rhs = rhsTensor 36 | rewrite "Transpose(Broadcast(Scalar)) ⇒ Broadcast(Scalar)" lhs rhs 37 | 38 | rule04 :: forall a. AnyDTypeRule a 39 | rule04 _ = do 40 | rclass <- newRClass "rclass" 41 | map <- newMap "map" rclass 42 | lhs <- reverseTensor (constant @a "a" [rclass --> map]) [ByRClass rclass] 43 | rhs <- constant @a "a" [rclass --> map] 44 | rewrite "Reverse(Broadcast(Scalar)) ⇒ Broadcast(Scalar)" lhs rhs 45 | 46 | rule05 :: forall a. AnyDTypeRule a 47 | rule05 _ = do 48 | rclass <- newRClass "rclass" 49 | [origSize, start, end, stride] <- 50 | newMaps ["origSize", "start", "end", "stride"] rclass 51 | 52 | lhs <- 53 | slice (constant @a "a" [rclass --> origSize]) $ 54 | Slice 55 | { start = [rclass --> start], 56 | end = [rclass --> end], 57 | strides = [rclass --> stride] 58 | } 59 | 60 | newSize <- 61 | combineMap 62 | "newSize" 63 | (\[s, e, p] -> divOr 0 (e - s + p - 1) p) 64 | [start, end, stride] 65 | rhs <- constant @a "a" [rclass --> newSize] 66 | rewrite "Slice(Broadcast(Scalar)) ⇒ Broadcast(Scalar)" lhs rhs 67 | 68 | rule06 :: forall a. 
AnyDTypeRule a 69 | rule06 _ = do 70 | rclass <- newRClass "rclass" 71 | [origSize, newSize, start] <- newMaps ["origSize", "newSize", "start"] rclass 72 | lhs <- 73 | dynamicSlice (constant @a "a" [rclass --> origSize]) $ 74 | DySlice 75 | { start = [rclass --> start], 76 | sizes = [rclass --> newSize] 77 | } 78 | rhs <- constant @a "a" [rclass --> newSize] 79 | rewrite "DynamicSlice(Broadcast(Scalar)) ⇒ Broadcast(Scalar)" lhs rhs 80 | 81 | rule07 :: DSLContext Rewrite 82 | rule07 = do 83 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 84 | rc0Size <- newMap "rc0Size" rclass0 85 | rc1Size <- newMap "rc1Size" rclass1 86 | lhs <- 87 | broadcast 88 | (iota [rclass0 --> rc0Size] (ByRClass rclass0)) 89 | [rclass1 --> rc1Size] 90 | rhs <- iota [rclass0 --> rc0Size, rclass1 --> rc1Size] (ByRClass rclass0) 91 | rewrite "Broadcast(Iota) ⇒ Iota" lhs rhs 92 | 93 | rule08 :: forall a. AnyDTypeRule a 94 | rule08 _ = do 95 | [rclass0, rclass1, rclass2] <- newRClasses ["rclass0", "rclass1", "rclass2"] 96 | rc0Size <- newMap "rc0Size" rclass0 97 | rc1Size <- newMap "rc1Size" rclass1 98 | rc2Size <- newMap "rc2Size" rclass2 99 | tA <- newTensor @a "A" [rclass0 --> rc0Size] 100 | lhs <- broadcast (broadcast tA [rclass1 --> rc1Size]) [rclass2 --> rc2Size] 101 | rhs <- broadcast tA [rclass1 --> rc1Size, rclass2 --> rc2Size] 102 | rewrite "Broadcast(Broadcast(A, shape, dims), shape2, dims2) ⇒ Broadcast(A, shape3, dims3)" lhs rhs 103 | 104 | main :: IO () 105 | main = do 106 | print "############################## rule01 ##############################" 107 | verifyNumDSL rule01 108 | print "############################## rule02 ##############################" 109 | verifyNumDSL rule02 110 | print "############################## rule03 ##############################" 111 | verifyAnyDTypeDSL rule03 112 | print "############################## rule04 ##############################" 113 | verifyAnyDTypeDSL rule04 114 | print "############################## rule05 
-- NOTE(review): repo-dump chunk. First the tail of a preceding rule file's `main` driver
-- (verifyAnyDTypeDSL/verifyDSL calls for its rules 05-08), then the whole of
-- rules/xla/divmod/Main.hs: Div/Rem rewrite rules — Div(A,1) ⇒ A; Div by a symbolic constant
-- rewritten as Mul by its reciprocal (built with `tensorDiv 1 "c"`); Div-of-Div recombination;
-- Rem idempotence; Rem over Iota/Add patterns guarded by `precondition`. Its `main` verifies
-- rule01 with verifyNumDSL, rules 02-06 with verifyDSL, and rules 07-08 via
-- `verifyDSLWith (withTimeout 10000000 z3)` (units per withTimeout's contract — presumably
-- microseconds, i.e. 10 s; confirm against Grisette docs).
-- NOTE(review): `import Control.Monad.Except (runExceptT)` appears unused in this file — confirm.
-- NOTE(review): rule02 binds `let c = "c"` but then hard-codes the literal again in
-- `tensorDiv 1 "c"`; this only works because both denote the same symbol — consider `tensorDiv 1 c`.
-- NOTE(review): rule07's `precondition [rc0Size] $ \[s] -> c .>= 0` ignores the bound size `s`;
-- presumably the map list is only an attachment point for the constraint on `c` — verify intent.
##############################" 115 | verifyAnyDTypeDSL rule05 116 | print "############################## rule06 ##############################" 117 | verifyAnyDTypeDSL rule06 118 | print "############################## rule07 ##############################" 119 | verifyDSL rule07 120 | print "############################## rule08 ##############################" 121 | verifyAnyDTypeDSL rule08 122 | -------------------------------------------------------------------------------- /rules/xla/divmod/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Control.Monad.Except (runExceptT) 4 | import Grisette hiding ((-->)) 5 | import TensorRight 6 | import TensorRight.Internal.Core.Tensor.TensorInt (TensorDivMod (tensorDiv)) 7 | 8 | rule01 :: forall a. NumRule a 9 | rule01 _ = do 10 | rclass <- newRClass "rclass" 11 | map <- newMap "map" rclass 12 | tA <- newTensor @a "A" [rclass --> map] 13 | one <- constant @a 1 [rclass --> map] 14 | lhs <- numBinOp Div tA one 15 | let rhs = tA 16 | rewrite "Div(A, 1) ⇒ A" lhs rhs 17 | 18 | rule02 :: DSLContext Rewrite 19 | rule02 = do 20 | rclass <- newRClass "rclass" 21 | map <- newMap "map" rclass 22 | tA <- newTensor @TensorReal "A" [rclass --> map] 23 | let c = "c" :: TensorReal 24 | constTensor <- constant @TensorReal c [rclass --> map] 25 | let creci = tensorDiv 1 "c" :: TensorReal 26 | constTensorreci <- constant @TensorReal creci [rclass --> map] 27 | lhs <- numBinOp Div tA constTensor 28 | rhs <- numBinOp Mul tA constTensorreci 29 | rewrite "Div(A, Const) ⇒ Mul(A, 1/Const)" lhs rhs 30 | 31 | rule03 :: DSLContext Rewrite 32 | rule03 = do 33 | rclass <- newRClass "rclass" 34 | map <- newMap "map" rclass 35 | tA <- newTensor @TensorReal "A" [rclass --> map] 36 | tB <- newTensor @TensorReal "B" [rclass --> map] 37 | tC <- newTensor @TensorReal "C" [rclass --> map] 38 | tD <- newTensor @TensorReal "D" [rclass --> map] 39 | lhs <- numBinOp Div (numBinOp 
Div tA tB) (numBinOp Div tC tD) 40 | rhs <- numBinOp Div (numBinOp Mul tA tD) (numBinOp Mul tB tC) 41 | rewrite "Divide(Divide(A, B), Divide(C, D)) ⇒ Divide(Mul(A, D), Mul(B, C))" lhs rhs 42 | 43 | rule04 :: DSLContext Rewrite 44 | rule04 = do 45 | rclass <- newRClass "rclass" 46 | map <- newMap "map" rclass 47 | tA <- newTensor @TensorReal "A" [rclass --> map] 48 | tB <- newTensor @TensorReal "B" [rclass --> map] 49 | tC <- newTensor @TensorReal "C" [rclass --> map] 50 | lhs <- numBinOp Div tA (numBinOp Div tB tC) 51 | rhs <- numBinOp Div (numBinOp Mul tA tC) tB 52 | rewrite "Divide(A, Divide(B, C)) ⇒ Divide(Mul(A, C), B)" lhs rhs 53 | 54 | rule05 :: DSLContext Rewrite 55 | rule05 = do 56 | rclass <- newRClass "rclass" 57 | map <- newMap "map" rclass 58 | tA <- newTensor @TensorInt "A" [rclass --> map] 59 | tB <- newTensor @TensorInt "B" [rclass --> map] 60 | lhs <- numBinOp Rem (numBinOp Rem tA tB) tB 61 | rhs <- numBinOp Rem tA tB 62 | rewrite "Rem(Rem(A, B), B) ⇒ Rem(A, B)" lhs rhs 63 | 64 | rule06 :: DSLContext Rewrite 65 | rule06 = do 66 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 67 | rc0Size <- newMap "rc0Size" rclass0 68 | rc1Size <- newMap "rc1Size" rclass1 69 | 70 | const <- constant @TensorInt "c" [rclass0 --> rc0Size, rclass1 --> rc1Size] 71 | lhs <- numBinOp Rem (iota [rclass0 --> rc0Size, rclass1 --> rc1Size] (ByRClass rclass0)) const 72 | precondition [rc0Size] $ \[s] -> s .<= "c" 73 | 74 | rhs <- iota [rclass0 --> rc0Size, rclass1 --> rc1Size] (ByRClass rclass0) 75 | rewrite "Rem(Iota, Const) ⇒ Iota" lhs rhs 76 | 77 | rule07 :: DSLContext Rewrite 78 | rule07 = do 79 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 80 | rc0Size <- newMap "rc0Size" rclass0 81 | rc1Size <- newMap "rc1Size" rclass1 82 | 83 | let c = "c" :: SymInteger 84 | const <- constant @TensorInt (nonInf c) [rclass0 --> rc0Size, rclass1 --> rc1Size] 85 | lhs <- numBinOp Rem (numBinOp Add (iota [rclass0 --> rc0Size, rclass1 --> rc1Size] (ByRClass rclass0)) 
const) const 86 | precondition [rc0Size] $ \[s] -> c .>= 0 87 | 88 | rhs <- numBinOp Rem (iota [rclass0 --> rc0Size, rclass1 --> rc1Size] (ByRClass rclass0)) const 89 | rewrite "Rem(Add(Iota, Const), Const) ⇒ Rem(Iota, Const)" lhs rhs 90 | 91 | rule08 :: DSLContext Rewrite 92 | rule08 = do 93 | rclass <- newRClass "rclass" 94 | map <- newMap "map" rclass 95 | let xv = "x" :: SymInteger 96 | let nv = "n" :: SymInteger 97 | x <- constant @TensorInt (nonInf xv) [rclass --> map] 98 | n <- constant @TensorInt (nonInf nv) [rclass --> map] 99 | lhs <- numBinOp Rem (numBinOp Add x n) n 100 | rhs <- numBinOp Rem x n 101 | precondition [map] $ \[m] -> 102 | symIte (xv .>= 0) (xv + nv .>= 0) (xv + nv .< 0) 103 | rewrite "Rem(Add(X, Const), Const) ⇒ Rem(X, Const)" lhs rhs 104 | 105 | main :: IO () 106 | main = do 107 | print "############################## rule01 ##############################" 108 | verifyNumDSL rule01 109 | print "############################## rule02 ##############################" 110 | verifyDSL rule02 111 | print "############################## rule03 ##############################" 112 | verifyDSL rule03 113 | print "############################## rule04 ##############################" 114 | verifyDSL rule04 115 | print "############################## rule05 ##############################" 116 | verifyDSL rule05 117 | print "############################## rule06 ##############################" 118 | verifyDSL rule06 119 | print "############################## rule07 ##############################" 120 | verifyDSLWith (withTimeout 10000000 z3) rule07 121 | print "############################## rule08 ##############################" 122 | verifyDSLWith (withTimeout 10000000 z3) rule08 123 | -------------------------------------------------------------------------------- /rules/xla/logical/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import
-- NOTE(review): remainder of rules/xla/logical/Main.hs. Rules 01-08 are Or/And identity and
-- annihilator laws over @SymBool tensors (A∨T⇒T, T∨A⇒T, A∨F⇒A, F∨A⇒A, A∧T⇒A, T∧A⇒A, A∧F⇒F,
-- F∧A⇒F); rules 09-11 combine compareOp (Lt/Gt) with Min/Max over a polymorphic NumRule. `main`
-- verifies 01-08 with verifyDSL and 09-11 with verifyNumDSL. After the file's end separator, the
-- header of src/TensorRight/Internal/DSL/Condition.hs begins (LANGUAGE pragmas + export list).
-- NOTE(review): rules 05-08 write the boolean constants as "1"/"0" in their rewrite-name strings
-- while rules 01-04 write True/False — cosmetic inconsistency in the rule labels only.
TensorRight 5 | 6 | rule01 :: DSLContext Rewrite 7 | rule01 = do 8 | rclass <- newRClass "rclass" 9 | map <- newMap "map" rclass 10 | tA <- newTensor @SymBool "A" [rclass --> map] 11 | lhs <- boolBinOp Or tA (constant @SymBool (con True) [rclass --> map]) 12 | rhs <- constant @SymBool (con True) [rclass --> map] 13 | rewrite "Or(A, True) ⇒ True" lhs rhs 14 | 15 | rule02 :: DSLContext Rewrite 16 | rule02 = do 17 | rclass <- newRClass "rclass" 18 | map <- newMap "map" rclass 19 | tA <- newTensor @SymBool "A" [rclass --> map] 20 | lhs <- boolBinOp Or (constant @SymBool (con True) [rclass --> map]) tA 21 | rhs <- constant @SymBool (con True) [rclass --> map] 22 | rewrite "Or(True, A) ⇒ True" lhs rhs 23 | 24 | rule03 :: DSLContext Rewrite 25 | rule03 = do 26 | rclass <- newRClass "rclass" 27 | map <- newMap "map" rclass 28 | tA <- newTensor @SymBool "A" [rclass --> map] 29 | lhs <- boolBinOp Or tA (constant @SymBool (con False) [rclass --> map]) 30 | let rhs = tA 31 | rewrite "Or(A, False) ⇒ A" lhs rhs 32 | 33 | rule04 :: DSLContext Rewrite 34 | rule04 = do 35 | rclass <- newRClass "rclass" 36 | map <- newMap "map" rclass 37 | tA <- newTensor @SymBool "A" [rclass --> map] 38 | lhs <- boolBinOp Or (constant @SymBool (con False) [rclass --> map]) tA 39 | let rhs = tA 40 | rewrite "Or(False, A) ⇒ A" lhs rhs 41 | 42 | rule05 :: DSLContext Rewrite 43 | rule05 = do 44 | rclass <- newRClass "rclass" 45 | map <- newMap "map" rclass 46 | tA <- newTensor @SymBool "A" [rclass --> map] 47 | lhs <- boolBinOp And tA (constant @SymBool (con True) [rclass --> map]) 48 | let rhs = tA 49 | rewrite "And(A, 1) ⇒ A" lhs rhs 50 | 51 | rule06 :: DSLContext Rewrite 52 | rule06 = do 53 | rclass <- newRClass "rclass" 54 | map <- newMap "map" rclass 55 | tA <- newTensor @SymBool "A" [rclass --> map] 56 | lhs <- boolBinOp And (constant @SymBool (con True) [rclass --> map]) tA 57 | let rhs = tA 58 | rewrite "And(1, A) ⇒ A" lhs rhs 59 | 60 | rule07 :: DSLContext Rewrite 61 | rule07 = do 62 | rclass 
<- newRClass "rclass" 63 | map <- newMap "map" rclass 64 | tA <- newTensor @SymBool "A" [rclass --> map] 65 | lhs <- boolBinOp And tA (constant @SymBool (con False) [rclass --> map]) 66 | rhs <- constant @SymBool (con False) [rclass --> map] 67 | rewrite "And(A, 0) ⇒ 0" lhs rhs 68 | 69 | rule08 :: DSLContext Rewrite 70 | rule08 = do 71 | rclass <- newRClass "rclass" 72 | map <- newMap "map" rclass 73 | tA <- newTensor @SymBool "A" [rclass --> map] 74 | lhs <- boolBinOp And (constant @SymBool (con False) [rclass --> map]) tA 75 | rhs <- constant @SymBool (con False) [rclass --> map] 76 | rewrite "And(0, A) ⇒ 0" lhs rhs 77 | 78 | rule09 :: forall a. NumRule a 79 | rule09 _ = do 80 | rclass <- newRClass "rclass" 81 | map <- newMap "map" rclass 82 | tA <- newTensor @a "A" [rclass --> map] 83 | constTensor1 <- constant @a "a" [rclass --> map] 84 | constTensor2 <- constant @a "b" [rclass --> map] 85 | lhs <- boolBinOp And (compareOp Lt tA constTensor1) (compareOp Gt constTensor2 tA) 86 | rhs <- compareOp Lt tA (numBinOp Min constTensor1 constTensor2) 87 | rewrite "And(A < Const, Const1 > A) ⇒ Lt(A, min(Const, Const1))" lhs rhs 88 | 89 | rule10 :: forall a. NumRule a 90 | rule10 _ = do 91 | rclass <- newRClass "rclass" 92 | map <- newMap "map" rclass 93 | tA <- newTensor @a "A" [rclass --> map] 94 | tB <- newTensor @a "B" [rclass --> map] 95 | lhs <- compareOp Gt (numBinOp Max tA tB) tB 96 | rhs <- compareOp Gt tA tB 97 | rewrite "Gt(Max(A, B), B) ⇒ Gt(A, B)" lhs rhs 98 | 99 | rule11 :: forall a. 
NumRule a 100 | rule11 _ = do 101 | rclass <- newRClass "rclass" 102 | map <- newMap "map" rclass 103 | tA <- newTensor @a "A" [rclass --> map] 104 | tB <- newTensor @a "B" [rclass --> map] 105 | lhs <- compareOp Gt (numBinOp Max tA tB) tA 106 | rhs <- compareOp Gt tB tA 107 | rewrite "Gt(Max(A, B), A) ⇒ Gt(B, A)" lhs rhs 108 | 109 | main :: IO () 110 | main = do 111 | print "############################## rule01 ##############################" 112 | verifyDSL rule01 113 | print "############################## rule02 ##############################" 114 | verifyDSL rule02 115 | print "############################## rule03 ##############################" 116 | verifyDSL rule03 117 | print "############################## rule04 ##############################" 118 | verifyDSL rule04 119 | print "############################## rule05 ##############################" 120 | verifyDSL rule05 121 | print "############################## rule06 ##############################" 122 | verifyDSL rule06 123 | print "############################## rule07 ##############################" 124 | verifyDSL rule07 125 | print "############################## rule08 ##############################" 126 | verifyDSL rule08 127 | print "############################## rule09 ##############################" 128 | verifyNumDSL rule09 129 | print "############################## rule10 ##############################" 130 | verifyNumDSL rule10 131 | print "############################## rule11 ##############################" 132 | verifyNumDSL rule11 133 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/DSL/Condition.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | {-# LANGUAGE RecordWildCards #-} 3 | 4 | module TensorRight.Internal.DSL.Condition 5 | ( Condition (..), 6 | zipCondition, 7 | elemWiseCond, 8 | elemWiseArith, 9 | unaryCond, 10 | unaryArith, 11 | ) 12 |
where 13 | 14 | import qualified Data.HashMap.Lazy as HM 15 | import qualified Data.Text as T 16 | import Grisette 17 | ( LogicalOp ((.&&)), 18 | PPrint (pformat), 19 | Solvable (con), 20 | SymBool, 21 | SymInteger, 22 | symAll, 23 | ) 24 | import Prettyprinter ((<+>)) 25 | import TensorRight.Internal.DSL.Identifier (MapIdentifier) 26 | 27 | data Condition = Condition 28 | { maps :: [MapIdentifier], 29 | condition :: [HM.HashMap T.Text SymInteger] -> SymBool 30 | } 31 | 32 | instance Show Condition where 33 | show Condition {..} = "Condition [" <> show maps <> "]" 34 | 35 | instance PPrint Condition where 36 | pformat Condition {..} = "Condition" <+> pformat maps 37 | 38 | -- | The condition will be applied to each group of the elements with the same 39 | -- axes. 40 | -- 41 | -- For instance, @'zipCondition' (\l -> sum l .== 10) [a, b, c]@, checks if for 42 | -- every key @k@, @a[k] + b[k] + c[k] .== 10@. This also means that 43 | -- 'zipCondition' can only be used on maps having the same domain, i.e., the 44 | -- same @RClass@. 45 | zipCondition :: 46 | ([SymInteger] -> SymBool) -> 47 | [HM.HashMap T.Text SymInteger] -> 48 | SymBool 49 | zipCondition _ [] = con True 50 | zipCondition f allMaps@(m : _) = allSameKeys .&& symAll (f . byKey) keys 51 | where 52 | allSame [] = con True 53 | allSame (x : xs) = symAll (con . (== x)) xs .&& allSame xs 54 | allSameKeys = allSame (map HM.keys allMaps) 55 | keys = HM.keys m 56 | 57 | byKey :: T.Text -> [SymInteger] 58 | byKey key = fmap (HM.! key) allMaps 59 | 60 | -- | The condition will be applied to two maps in an element-wise way. 61 | -- 62 | -- For instance, @'elemWiseCond' (.==) a b@, checks if for every key 63 | -- @k@, @a[k] .== b[k]@. This also means that 'elemWiseCond' can only be 64 | -- used on maps having the same domain, i.e., the same @RClass@. 
65 | -- 66 | -- For @f@, users can use the already available functions in Grisette like 67 | -- 'Grisette..==', 'Grisette../=', 'Grisette..>', 'Grisette..>=', 'Grisette..<', 68 | -- 'Grisette..<=', or they can create their own functions on the fly, as long as 69 | -- it satisfies the signature. For example, the following precondition checks 70 | -- if @m1 <= m2 || m1 == 1@: 71 | -- 72 | -- @ 73 | -- let compareFunc x y = x <= y .|| x .== 1 74 | -- 'TensorRight.precondition'' [m1, m2] $ \[m1, m2] -> 'elemWiseCond' (.<=) m1 m2 .|| 'unaryCond' (.== 1) m1 75 | -- @ 76 | elemWiseCond :: 77 | -- | Element-wise function 78 | (SymInteger -> SymInteger -> SymBool) -> 79 | -- | Left-hand side 80 | HM.HashMap T.Text SymInteger -> 81 | -- | Right-hand side 82 | HM.HashMap T.Text SymInteger -> 83 | -- | Resulting condition 84 | SymBool 85 | elemWiseCond f a b = 86 | con (HM.keysSet a == HM.keysSet b) 87 | .&& foldr (.&&) (con True) (HM.intersectionWith f a b) 88 | 89 | -- | The condition will be applied to each element of the map. 90 | -- 91 | -- For instance, @'unaryCond' (.== 0) a@, checks if for every key @k@, 92 | -- @a[k] .== 0@. 93 | -- 94 | -- The users can use any function @f@, as long as it satisfies the signature. 95 | -- 96 | -- @ 97 | -- 'TensorRight.precondition'' [m] $ \[m] -> 'unaryCond' (.> 1) m 98 | -- @ 99 | unaryCond :: 100 | (SymInteger -> SymBool) -> 101 | HM.HashMap T.Text SymInteger -> 102 | SymBool 103 | unaryCond f a = foldr (.&&) (con True) (HM.map f a) 104 | 105 | -- | Helper for operating on two maps in an element-wise way. 106 | -- 107 | -- For instance, @'elemWiseArith' (+) a b@, returns a map @r@ such that for 108 | -- every key @k@, @r[k] = a[k] + b[k]@. This also means that 'elemWiseArith' 109 | -- can only be used on maps having the same domain, i.e., the same @RClass@. 110 | -- 111 | -- For @f@, users can use any binary operator, as long as it satisfies the 112 | -- signature.
This is useful if you have a condition involving more than two 113 | -- maps. For instance, this checks if @m1 == m2 + m3@ 114 | -- 115 | -- @ 116 | -- 'TensorRight.precondition'' [m1, m2, m3] $ 117 | -- \[m1, m2, m3] -> 'elemWiseCond' (.==) m1 ('elemWiseArith' (+) m2 m3) 118 | -- @ 119 | elemWiseArith :: 120 | (SymInteger -> SymInteger -> SymInteger) -> 121 | HM.HashMap T.Text SymInteger -> 122 | HM.HashMap T.Text SymInteger -> 123 | HM.HashMap T.Text SymInteger 124 | elemWiseArith = HM.intersectionWith 125 | 126 | -- | Helper for operating on a map in a unary way. 127 | -- 128 | -- For instance, @'unaryArith' (+2) a@, returns a map @r@ such that for 129 | -- every key @k@, @r[k] = a[k] + 2@. 130 | -- 131 | -- For @f@, users can use any unary operator, as long as it satisfies the 132 | -- signature. For instance, this checks if @m1 == m2 + 2@ 133 | -- 134 | -- @ 135 | -- 'TensorRight.precondition'' [m1, m2] $ 136 | -- \[m1, m2] -> 'elemWiseCond' (.==) m1 ('unaryArith' (+2) m2) 137 | -- @ 138 | unaryArith :: 139 | (SymInteger -> SymInteger) -> 140 | HM.HashMap T.Text SymInteger -> 141 | HM.HashMap T.Text SymInteger 142 | unaryArith = HM.map 143 | -------------------------------------------------------------------------------- /rules/xla/generalize/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Data.Proxy 4 | import qualified Data.Text as T 5 | import Grisette hiding (dot, (-->)) 6 | import TensorRight 7 | 8 | foldConvInputPadGeneral :: forall a.
-- NOTE(review): remainder of rules/xla/generalize/Main.hs (foldConvInputPadGeneral). It builds the
-- LHS as a conv over explicitly padded inputs (pad with plow/pint/phigh) and the RHS as a conv with
-- the padding folded into the conv configuration, where newlow/newldilation/newhigh are derived via
-- `combineMap` (e.g. newlow = low + plow * ldilation); non-negativity of the pad maps is enforced
-- by construction (newNonNegMap/newNonNegMaps, see the in-file comment about pad composition), and
-- spatial shapes are constrained positive. It registers monitor*OnFailure debug hooks, relates the
-- summation-index maps with siRelation/checkSIMap, and emits the rewrite with a multi-line name.
-- `main` verifies it with verifyNumDSL. Afterwards rules/xla/mul/Main.hs begins with rules 01-03
-- (Mul identity and right annihilator).
-- NOTE(review): the monitor label "lhsInputPadded" differs from the variable name
-- `lhsInputsPadded`, and the label "newint" is attached to `newldilation` — debug-label drift only;
-- confirm these labels are intentional.
NumRule a 9 | foldConvInputPadGeneral _ = do 10 | [batch, feature, output, spatial] <- 11 | newRClasses ["batch", "feature", "output", "spatial"] 12 | 13 | batchShape <- newMap "batchShape" batch 14 | featureShape <- newMap "featureShape" feature 15 | outputShape <- newMap "outputShape" output 16 | inputSpatialShape <- newMap "inputSpatialShape" spatial 17 | weightSpatialShape <- newMap "weightSpatialShape" spatial 18 | 19 | inputs <- 20 | newTensor @a 21 | "inputs" 22 | [ batch --> batchShape, 23 | feature --> featureShape, 24 | spatial --> inputSpatialShape 25 | ] 26 | weights <- 27 | newTensor @a 28 | "weights" 29 | [ feature --> featureShape, 30 | output --> outputShape, 31 | spatial --> weightSpatialShape 32 | ] 33 | 34 | strides <- newConstMap "strides" 1 spatial 35 | -- Turns out that non-negative is required, even if we support negative 36 | -- padding. The reason is that when x contains negative indices, 37 | -- and x + y >= 0, then 38 | -- 39 | -- pad(pad(A, x), y) /= pad(A, x + y) 40 | [low, high] <- newNonNegMaps ["low", "high"] spatial 41 | ldilation <- newNonNegMap "ldilation" spatial 42 | rdilation <- newNonNegMap "rdilation" spatial 43 | [plow, phigh] <- newNonNegMaps ["plow", "phigh"] spatial 44 | pint <- newNonNegMap "pint" spatial 45 | newlow <- combineMap "newlow" (\[a, b, c] -> a + b * c) [low, plow, ldilation] 46 | newldilation <- combineMap "newldilation" (\[a, b] -> a + a * b) [ldilation, pint] 47 | newhigh <- combineMap "newhigh" (\[a, b, c] -> a + b * c) [high, phigh, ldilation] 48 | precondition [inputSpatialShape] $ \[s] -> s .> 0 49 | precondition [weightSpatialShape] $ \[s] -> s .> 0 50 | 51 | [siMapLhsFeature, siMapRhsFeature] <- 52 | newMaps ["siMapLhsFeature", "siMapRhsFeature"] feature 53 | [siMapLhsSpatial, siMapRhsSpatial] <- 54 | newMaps ["siMapLhsSpatial", "siMapRhsSpatial"] spatial 55 | 56 | lhsInputsPadded <- 57 | pad inputs (0 :: a) $ 58 | Padding 59 | { low = [spatial --> plow], 60 | interior = [spatial --> pint], 61 | high 
= [spatial --> phigh] 62 | } 63 | lhs <- 64 | conv 65 | lhsInputsPadded 66 | weights 67 | ConvConfig 68 | { batchRClasses = [ByRClass batch], 69 | featureRClasses = [ByRClass feature], 70 | outputFeatureRClasses = [ByRClass output], 71 | strides = [spatial --> strides], 72 | contractingSIMaps = 73 | [feature --> siMapLhsFeature, spatial --> siMapLhsSpatial] 74 | } 75 | ConvPadding 76 | { low = [spatial --> low], 77 | ldilation = [spatial --> ldilation], 78 | high = [spatial --> high], 79 | rdilation = [spatial --> rdilation] 80 | } 81 | monitorExprOnFailure "inputs" inputs 82 | monitorExprOnFailure "weights" weights 83 | monitorExprOnFailure "lhsInputPadded" lhsInputsPadded 84 | monitorExprOnFailure "lhs" lhs 85 | monitorMapOnFailure "plow" (ByRClass spatial) plow 86 | monitorMapOnFailure "pint" (ByRClass spatial) pint 87 | monitorMapOnFailure "phigh" (ByRClass spatial) phigh 88 | monitorMapOnFailure "low" (ByRClass spatial) low 89 | monitorMapOnFailure "ldilation" (ByRClass spatial) ldilation 90 | monitorMapOnFailure "high" (ByRClass spatial) high 91 | 92 | monitorMapOnFailure "newlow" (ByRClass spatial) newlow 93 | monitorMapOnFailure "newint" (ByRClass spatial) newldilation 94 | monitorMapOnFailure "newhigh" (ByRClass spatial) newhigh 95 | monitorMapOnFailure "rdilation" (ByRClass spatial) rdilation 96 | 97 | rhs <- 98 | conv 99 | inputs 100 | weights 101 | ConvConfig 102 | { batchRClasses = [ByRClass batch], 103 | featureRClasses = [ByRClass feature], 104 | outputFeatureRClasses = [ByRClass output], 105 | strides = [spatial --> strides], 106 | contractingSIMaps = 107 | [feature --> siMapRhsFeature, spatial --> siMapRhsSpatial] 108 | } 109 | ConvPadding 110 | { low = [spatial --> newlow], 111 | ldilation = [spatial --> newldilation], 112 | high = [spatial --> newhigh], 113 | rdilation = [spatial --> rdilation] 114 | } 115 | 116 | siRelation 117 | [siMapLhsFeature, siMapRhsFeature] 118 | $ \[vsiMapLhsFeature, vsiMapRhsFeature] -> 119 | vsiMapLhsFeature .== 
vsiMapRhsFeature 120 | siRelation 121 | [siMapLhsSpatial, siMapRhsSpatial] 122 | $ \[vsiMapLhsSpatial, vsiMapRhsSpatial] -> 123 | vsiMapLhsSpatial .== vsiMapRhsSpatial 124 | checkSIMap 125 | [siMapLhsFeature, siMapLhsSpatial] 126 | [siMapRhsFeature, siMapRhsSpatial] 127 | rewrite 128 | ( T.intercalate 129 | "\n" 130 | [ "Conv(Pad(input, innerLow, innerInt, innerHigh), weights, convLow, convInt, convHigh, rdilation)", 131 | " ⇒ ", 132 | "Conv(input, weights, convLowOut, convIntOut, convHighOut, rdilation)" 133 | ] 134 | ) 135 | lhs 136 | rhs 137 | 138 | main :: IO () 139 | main = do 140 | verifyNumDSL foldConvInputPadGeneral 141 | -------------------------------------------------------------------------------- /rules/xla/mul/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette (cvc5) 4 | import TensorRight 5 | 6 | rule01 :: forall a. NumRule a 7 | rule01 _ = do 8 | rclass <- newRClass "rclass" 9 | map <- newMap "map" rclass 10 | tA <- newTensor @a "A" [rclass --> map] 11 | one <- constant @a 1 [rclass --> map] 12 | lhs <- numBinOp Mul tA one 13 | let rhs = tA 14 | rewrite "Mul(A, 1) ⇒ A" lhs rhs 15 | 16 | rule02 :: forall a. NumRule a 17 | rule02 _ = do 18 | rclass <- newRClass "rclass" 19 | map <- newMap "map" rclass 20 | tA <- newTensor @a "A" [rclass --> map] 21 | one <- constant @a 1 [rclass --> map] 22 | lhs <- numBinOp Mul one tA 23 | let rhs = tA 24 | rewrite "Mul(1, A) ⇒ A" lhs rhs 25 | 26 | rule03 :: forall a. NumRule a 27 | rule03 _ = do 28 | rclass <- newRClass "rclass" 29 | map <- newMap "map" rclass 30 | tA <- newTensor @a "A" [rclass --> map] 31 | zero <- constant @a 0 [rclass --> map] 32 | lhs <- numBinOp Mul tA zero 33 | let rhs = zero 34 | rewrite "Mul(A, 0) ⇒ 0" lhs rhs 35 | 36 | rule04 :: forall a.
-- NOTE(review): remainder of rules/xla/mul/Main.hs. Rules 04-09 are polymorphic NumRule algebraic
-- identities (left annihilator, constant regrouping/reassociation, broadcast hoisting of a
-- constant multiply, distributivity Add∘Mul ⇒ Mul∘Add, Abs·Abs ⇒ ·); rule10 (Exp·Exp ⇒ Exp∘Add) is
-- TensorReal-only and is the one rule `main` verifies with `verifyDSLWith cvc5` while all others
-- use verifyNumDSL. Afterwards rules/xla/dyslice/Main.hs begins with rules 01-02
-- (DynamicSlice ⇒ Slice with unit strides; full-range DynamicSlice ⇒ identity).
-- NOTE(review): rule07's rewrite-name string "... ⇒ Mul(Broadcast(Mul(B, Const1), A))" misplaces a
-- closing parenthesis relative to the RHS actually constructed, Mul(Broadcast(Mul(B, Const1)), A)
-- — the label is informational only, but consider correcting it.
NumRule a 37 | rule04 _ = do 38 | rclass <- newRClass "rclass" 39 | map <- newMap "map" rclass 40 | tensor <- newTensor @a "tensor" [rclass --> map] 41 | zero <- constant @a 0 [rclass --> map] 42 | lhs <- numBinOp Mul zero tensor 43 | let rhs = zero 44 | rewrite "Mul(0, A) ⇒ 0" lhs rhs 45 | 46 | rule05 :: forall a. NumRule a 47 | rule05 _ = do 48 | rclass <- newRClass "rclass" 49 | map <- newMap "map" rclass 50 | tA <- newTensor @a "A" [rclass --> map] 51 | tB <- newTensor @a "B" [rclass --> map] 52 | c1 <- constant @a "c1" [rclass --> map] 53 | c2 <- constant @a "c2" [rclass --> map] 54 | lhs <- numBinOp Mul (numBinOp Mul tA c1) (numBinOp Mul tB c2) 55 | rhs <- numBinOp Mul (numBinOp Mul tA tB) (numBinOp Mul c1 c2) 56 | rewrite "Mul(Mul(A, Const1), Mul(B, Const2)) ⇒ Mul(Mul(A, B), Mul(Const1, Const2))" lhs rhs 57 | 58 | rule06 :: forall a. NumRule a 59 | rule06 _ = do 60 | rclass <- newRClass "rclass" 61 | map <- newMap "map" rclass 62 | tA <- newTensor @a "A" [rclass --> map] 63 | c1 <- constant @a "c1" [rclass --> map] 64 | c2 <- constant @a "c2" [rclass --> map] 65 | lhs <- numBinOp Mul (numBinOp Mul tA c1) c2 66 | rhs <- numBinOp Mul tA (numBinOp Mul c1 c2) 67 | rewrite "Mul(Mul(A, Const1), Const2) ⇒ Mul(A, Mul(Const1, Const2))" lhs rhs 68 | 69 | rule07 :: forall a. NumRule a 70 | rule07 _ = do 71 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 72 | rc0Size <- newMap "rc0Size" rclass0 73 | rc1Size <- newMap "rc1Size" rclass1 74 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1Size] 75 | tB <- newTensor @a "B" [rclass0 --> rc0Size] 76 | lhs <- numBinOp Mul (numBinOp Mul tA (constant @a "a" [rclass0 --> rc0Size, rclass1 --> rc1Size])) (broadcast tB [rclass1 --> rc1Size]) 77 | rhs <- numBinOp Mul (broadcast (numBinOp Mul tB (constant @a "a" [rclass0 --> rc0Size])) [rclass1 --> rc1Size]) tA 78 | rewrite "Mul(Mul(A, Const1), Broadcast(B)) ⇒ Mul(Broadcast(Mul(B, Const1), A))" lhs rhs 79 | 80 | rule08 :: forall a. 
NumRule a 81 | rule08 _ = do 82 | rclass <- newRClass "rclass" 83 | map <- newMap "map" rclass 84 | tA <- newTensor @a "A" [rclass --> map] 85 | tB <- newTensor @a "B" [rclass --> map] 86 | tC <- newTensor @a "C" [rclass --> map] 87 | lhs <- numBinOp Add (numBinOp Mul tA tC) (numBinOp Mul tB tC) 88 | rhs <- numBinOp Mul (numBinOp Add tA tB) tC 89 | rewrite "Add(Mul(A, C), Mul(B, C)) ⇒ Mul(Add(A, B), C)" lhs rhs 90 | 91 | rule09 :: forall a. NumRule a 92 | rule09 _ = do 93 | rclass <- newRClass "rclass" 94 | map <- newMap "map" rclass 95 | tA <- newTensor @a "A" [rclass --> map] 96 | lhs <- numBinOp Mul (numUnaryOp Abs tA) (numUnaryOp Abs tA) 97 | rhs <- numBinOp Mul tA tA 98 | rewrite "Mul(Abs(A), Abs(A)) ⇒ Mul(A, A)" lhs rhs 99 | 100 | rule10 :: DSLContext Rewrite 101 | rule10 = do 102 | rclass <- newRClass "rclass" 103 | map <- newMap "map" rclass 104 | tA <- newTensor @TensorReal "A" [rclass --> map] 105 | tB <- newTensor @TensorReal "B" [rclass --> map] 106 | lhs <- numBinOp Mul (numUnaryOp Exp tA) (numUnaryOp Exp tB) 107 | rhs <- numUnaryOp Exp (numBinOp Add tA tB) 108 | rewrite "Mul(Exp(A), Exp(B)) ⇒ Exp(Add(A, B))" lhs rhs 109 | 110 | main :: IO () 111 | main = do 112 | print "############################## rule01 ##############################" 113 | verifyNumDSL rule01 114 | print "############################## rule02 ##############################" 115 | verifyNumDSL rule02 116 | print "############################## rule03 ##############################" 117 | verifyNumDSL rule03 118 | print "############################## rule04 ##############################" 119 | verifyNumDSL rule04 120 | print "############################## rule05 ##############################" 121 | verifyNumDSL rule05 122 | print "############################## rule06 ##############################" 123 | verifyNumDSL rule06 124 | print "############################## rule07 ##############################" 125 | verifyNumDSL rule07 126 | print "##############################
rule08 ##############################" 127 | verifyNumDSL rule08 128 | print "############################## rule09 ##############################" 129 | verifyNumDSL rule09 130 | print "############################## rule10 ##############################" 131 | verifyDSLWith cvc5 rule10 132 | -------------------------------------------------------------------------------- /rules/xla/dyslice/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: forall a. AnyDTypeRule a 7 | rule01 _ = do 8 | rclass <- newRClass "rclass" 9 | [rcSize, rcStart, rcLength] <- 10 | newMaps ["rcSize", "rcStart", "rcLength"] rclass 11 | 12 | tA <- newTensor @a "A" [rclass --> rcSize] 13 | lhs <- 14 | dynamicSlice tA $ 15 | DySlice 16 | { start = [rclass --> rcStart], 17 | sizes = [rclass --> rcLength] 18 | } 19 | 20 | rcStride <- newConstMap "rcStride" 1 rclass 21 | rcEnd <- combineMap "rcEnd" sum [rcStart, rcLength] 22 | rhs <- 23 | slice tA $ 24 | Slice 25 | { start = [rclass --> rcStart], 26 | end = [rclass --> rcEnd], 27 | strides = [rclass --> rcStride] 28 | } 29 | 30 | rewrite "DynamicSlice(A) ⇒ Slice(A)" lhs rhs 31 | 32 | rule02 :: forall a. AnyDTypeRule a 33 | rule02 _ = do 34 | rclass <- newRClass "rclass" 35 | [rcSize, rcStart, rcLength] <- newMaps ["rcSize", "rcStart", "rcLength"] rclass 36 | 37 | tA <- newTensor @a "A" [rclass --> rcSize] 38 | lhs <- 39 | dynamicSlice tA $ 40 | DySlice 41 | { start = [rclass --> rcStart], 42 | sizes = [rclass --> rcLength] 43 | } 44 | precondition [rcStart] $ \[s] -> s .== 0 45 | precondition [rcLength, rcSize] $ \[l, s] -> l .== s 46 | 47 | let rhs = tA 48 | rewrite "DynamicSlice(A,...) ⇒ A // output shape is the same as input shape" lhs rhs 49 | 50 | rule03 :: forall a.
-- NOTE(review): remainder of rules/xla/dyslice/Main.hs. rule03 commutes DynamicSlice with
-- broadcast, rule04 commutes it with relabel via @@-labelled axes, rule05 fuses nested
-- DynamicSlices by summing the start maps with `combineMap`. rule06 (DynamicSlice(Iota) ⇒
-- constant) is defined but — per its own in-file comment, the rule needs a precondition on "a"
-- that cannot be expressed — `main` accordingly verifies only rules 01-05 with verifyAnyDTypeDSL.
-- After the file's end separator, the top of hie.yaml (stack cradle mapping per executable)
-- follows; it lists rules/xla/conv and rules/xla/slice components not visible in the shown tree —
-- presumably the tree listing is abridged; confirm those directories exist.
AnyDTypeRule a 51 | rule03 _ = do 52 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 53 | [rc0Size, rc0Start, rc0Length] <- 54 | newMaps ["rc0Size", "rc0Start", "rc0Length"] rclass0 55 | [rc1Size, rc1Start, rc1Length] <- 56 | newMaps ["rc1Size", "rc1Start", "rc1Length"] rclass1 57 | tA <- newTensor @a "A" [rclass0 --> rc0Size] 58 | lhs <- 59 | dynamicSlice (broadcast tA [rclass1 --> rc1Size]) $ 60 | DySlice 61 | { start = [rclass0 --> rc0Start, rclass1 --> rc1Start], 62 | sizes = [rclass0 --> rc0Length, rclass1 --> rc1Length] 63 | } 64 | rhs <- 65 | broadcast 66 | ( dynamicSlice tA $ 67 | DySlice 68 | { start = [rclass0 --> rc0Start], 69 | sizes = [rclass0 --> rc0Length] 70 | } 71 | ) 72 | [rclass1 --> rc1Length] 73 | rewrite "DynamicSlice(Broadcast(A), ...) ⇒ Broadcast(DynamicSlice(A, ...))" lhs rhs 74 | 75 | rule04 :: forall a. AnyDTypeRule a 76 | rule04 _ = do 77 | rclass <- newRClass "rclass" 78 | [rcSize, rcStart, rcLength] <- 79 | newMaps ["rcSize", "rcStart", "rcLength"] rclass 80 | tA <- newTensor @a "A" [rclass --> rcSize @@ "label1"] 81 | lhs <- 82 | relabel 83 | ( dynamicSlice tA $ 84 | DySlice 85 | { start = [ByLabel "label1" --> rcStart], 86 | sizes = [ByLabel "label1" --> rcLength] 87 | } 88 | ) 89 | [ByLabel "label1" --> ByLabel "label2"] 90 | rhs <- 91 | dynamicSlice (relabel tA [ByLabel "label1" --> ByLabel "label2"]) $ 92 | DySlice 93 | { start = [ByLabel "label2" --> rcStart], 94 | sizes = [ByLabel "label2" --> rcLength] 95 | } 96 | rewrite "DynamicSlice(Transpose(A), ...) ⇒ Transpose(DynamicSlice(A, ...))" lhs rhs 97 | 98 | rule05 :: forall a. 
AnyDTypeRule a 99 | rule05 _ = do 100 | rclass <- newRClass "rclass" 101 | [rcSize, rcStartInner, rcStartOuter, rcLengthInner, rcLengthOuter] <- 102 | newMaps ["rcSize", "rcStartInner", "rcStartOuter", "rcLengthInner", "rcLengthOuter"] rclass 103 | 104 | tA <- newTensor @a "A" [rclass --> rcSize] 105 | dySliceInner <- 106 | dynamicSlice tA $ 107 | DySlice 108 | { start = [rclass --> rcStartInner], 109 | sizes = [rclass --> rcLengthInner] 110 | } 111 | lhs <- 112 | dynamicSlice dySliceInner $ 113 | DySlice 114 | { start = [rclass --> rcStartOuter], 115 | sizes = [rclass --> rcLengthOuter] 116 | } 117 | 118 | rcStartRhs <- combineMap "rcStartRhs" sum [rcStartInner, rcStartOuter] 119 | rhs <- 120 | dynamicSlice tA $ 121 | DySlice 122 | { start = [rclass --> rcStartRhs], 123 | sizes = [rclass --> rcLengthOuter] 124 | } 125 | 126 | rewrite "DynamicSlice(DynamicSlice(A,...),...) ⇒ DynamicSlice(A,...)" lhs rhs 127 | 128 | rule06 :: DSLContext Rewrite 129 | rule06 = do 130 | rclass <- newRClass "rclass" 131 | [sizeMap, startMap, lengthMap] <- 132 | newMaps ["sizeMap", "startMap", "lengthMap"] rclass 133 | 134 | lhs <- 135 | dynamicSlice (iota [rclass --> sizeMap] (ByRClass rclass)) $ 136 | DySlice 137 | { start = [rclass --> startMap], 138 | sizes = [rclass --> lengthMap] 139 | } 140 | rhsSize <- newConstMap "size" 1 rclass 141 | -- Cannot express rule since we need a precondition on "a" 142 | rhs <- constant @TensorInt "a" [rclass --> rhsSize] 143 | rewrite "DynamicSlice(Iota) ⇒ index" lhs rhs 144 | 145 | main :: IO () 146 | main = do 147 | print "############################## rule01 ##############################" 148 | verifyAnyDTypeDSL rule01 149 | print "############################## rule02 ##############################" 150 | verifyAnyDTypeDSL rule02 151 | print "############################## rule03 ##############################" 152 | verifyAnyDTypeDSL rule03 153 | print "############################## rule04 ##############################" 154 | 
verifyAnyDTypeDSL rule04 155 | print "############################## rule05 ##############################" 156 | verifyAnyDTypeDSL rule05 157 | -------------------------------------------------------------------------------- /hie.yaml: -------------------------------------------------------------------------------- 1 | cradle: 2 | stack: 3 | - path: "./src" 4 | component: "tensor-right:lib" 5 | 6 | - path: "./rules/debug/Main.hs" 7 | component: "tensor-right:exe:rules-debug" 8 | 9 | - path: "./rules/debug/Paths_tensor_right.hs" 10 | component: "tensor-right:exe:rules-debug" 11 | 12 | - path: "./rules/xla/add/Main.hs" 13 | component: "tensor-right:exe:rules-xla-add" 14 | 15 | - path: "./rules/xla/add/Paths_tensor_right.hs" 16 | component: "tensor-right:exe:rules-xla-add" 17 | 18 | - path: "./rules/xla/broadcast/Main.hs" 19 | component: "tensor-right:exe:rules-xla-broadcast" 20 | 21 | - path: "./rules/xla/broadcast/Paths_tensor_right.hs" 22 | component: "tensor-right:exe:rules-xla-broadcast" 23 | 24 | - path: "./rules/xla/clamp/Main.hs" 25 | component: "tensor-right:exe:rules-xla-clamp" 26 | 27 | - path: "./rules/xla/clamp/Paths_tensor_right.hs" 28 | component: "tensor-right:exe:rules-xla-clamp" 29 | 30 | - path: "./rules/xla/compare/Main.hs" 31 | component: "tensor-right:exe:rules-xla-compare" 32 | 33 | - path: "./rules/xla/compare/Paths_tensor_right.hs" 34 | component: "tensor-right:exe:rules-xla-compare" 35 | 36 | - path: "./rules/xla/concat/Main.hs" 37 | component: "tensor-right:exe:rules-xla-concat" 38 | 39 | - path: "./rules/xla/concat/Paths_tensor_right.hs" 40 | component: "tensor-right:exe:rules-xla-concat" 41 | 42 | - path: "./rules/xla/conv/Main.hs" 43 | component: "tensor-right:exe:rules-xla-conv" 44 | 45 | - path: "./rules/xla/conv/Paths_tensor_right.hs" 46 | component: "tensor-right:exe:rules-xla-conv" 47 | 48 | - path: "./rules/xla/divmod/Main.hs" 49 | component: "tensor-right:exe:rules-xla-divmod" 50 | 51 | - path:
"./rules/xla/divmod/Paths_tensor_right.hs" 52 | component: "tensor-right:exe:rules-xla-divmod" 53 | 54 | - path: "./rules/xla/dot/Main.hs" 55 | component: "tensor-right:exe:rules-xla-dot" 56 | 57 | - path: "./rules/xla/dot/Paths_tensor_right.hs" 58 | component: "tensor-right:exe:rules-xla-dot" 59 | 60 | - path: "./rules/xla/dyslice/Main.hs" 61 | component: "tensor-right:exe:rules-xla-dyslice" 62 | 63 | - path: "./rules/xla/dyslice/Paths_tensor_right.hs" 64 | component: "tensor-right:exe:rules-xla-dyslice" 65 | 66 | - path: "./rules/xla/dyupslice/Main.hs" 67 | component: "tensor-right:exe:rules-xla-dyupslice" 68 | 69 | - path: "./rules/xla/dyupslice/Paths_tensor_right.hs" 70 | component: "tensor-right:exe:rules-xla-dyupslice" 71 | 72 | - path: "./rules/xla/generalize/Main.hs" 73 | component: "tensor-right:exe:rules-xla-generalize" 74 | 75 | - path: "./rules/xla/generalize/Paths_tensor_right.hs" 76 | component: "tensor-right:exe:rules-xla-generalize" 77 | 78 | - path: "./rules/xla/iota/Main.hs" 79 | component: "tensor-right:exe:rules-xla-iota" 80 | 81 | - path: "./rules/xla/iota/Paths_tensor_right.hs" 82 | component: "tensor-right:exe:rules-xla-iota" 83 | 84 | - path: "./rules/xla/logical/Main.hs" 85 | component: "tensor-right:exe:rules-xla-logical" 86 | 87 | - path: "./rules/xla/logical/Paths_tensor_right.hs" 88 | component: "tensor-right:exe:rules-xla-logical" 89 | 90 | - path: "./rules/xla/max/Main.hs" 91 | component: "tensor-right:exe:rules-xla-max" 92 | 93 | - path: "./rules/xla/max/Paths_tensor_right.hs" 94 | component: "tensor-right:exe:rules-xla-max" 95 | 96 | - path: "./rules/xla/mul/Main.hs" 97 | component: "tensor-right:exe:rules-xla-mul" 98 | 99 | - path: "./rules/xla/mul/Paths_tensor_right.hs" 100 | component: "tensor-right:exe:rules-xla-mul" 101 | 102 | - path: "./rules/xla/not/Main.hs" 103 | component: "tensor-right:exe:rules-xla-not" 104 | 105 | - path: "./rules/xla/not/Paths_tensor_right.hs" 106 | component: "tensor-right:exe:rules-xla-not" 107 | 108 
| - path: "./rules/xla/pad/Main.hs" 109 | component: "tensor-right:exe:rules-xla-pad" 110 | 111 | - path: "./rules/xla/pad/Paths_tensor_right.hs" 112 | component: "tensor-right:exe:rules-xla-pad" 113 | 114 | - path: "./rules/xla/reduce/Main.hs" 115 | component: "tensor-right:exe:rules-xla-reduce" 116 | 117 | - path: "./rules/xla/reduce/Paths_tensor_right.hs" 118 | component: "tensor-right:exe:rules-xla-reduce" 119 | 120 | - path: "./rules/xla/relabel/Main.hs" 121 | component: "tensor-right:exe:rules-xla-relabel" 122 | 123 | - path: "./rules/xla/relabel/Paths_tensor_right.hs" 124 | component: "tensor-right:exe:rules-xla-relabel" 125 | 126 | - path: "./rules/xla/reverse/Main.hs" 127 | component: "tensor-right:exe:rules-xla-reverse" 128 | 129 | - path: "./rules/xla/reverse/Paths_tensor_right.hs" 130 | component: "tensor-right:exe:rules-xla-reverse" 131 | 132 | - path: "./rules/xla/select/Main.hs" 133 | component: "tensor-right:exe:rules-xla-select" 134 | 135 | - path: "./rules/xla/select/Paths_tensor_right.hs" 136 | component: "tensor-right:exe:rules-xla-select" 137 | 138 | - path: "./rules/xla/slice/Main.hs" 139 | component: "tensor-right:exe:rules-xla-slice" 140 | 141 | - path: "./rules/xla/slice/Paths_tensor_right.hs" 142 | component: "tensor-right:exe:rules-xla-slice" 143 | 144 | - path: "./rules/xla/sub/Main.hs" 145 | component: "tensor-right:exe:rules-xla-sub" 146 | 147 | - path: "./rules/xla/sub/Paths_tensor_right.hs" 148 | component: "tensor-right:exe:rules-xla-sub" 149 | 150 | - path: "./test" 151 | component: "tensor-right:test:spec" 152 | -------------------------------------------------------------------------------- /package.yaml: -------------------------------------------------------------------------------- 1 | name: tensor-right 2 | version: 0.1.0.0 3 | synopsis: Automated Verification of Tensor Graph Rewrites 4 | description: | 5 | TensorRight is an automatic tool that can be used to verify 6 | Tensor Graph Rewrites. 
7 | license: Apache-2.0 8 | license-file: LICENSE 9 | 10 | dependencies: 11 | - base >= 4.14 && < 5 12 | - grisette >= 0.11 && < 0.12 13 | - unordered-containers 14 | - text 15 | - mtl 16 | - hashable 17 | - prettyprinter 18 | - ordered-containers 19 | - deepseq 20 | - sbv 21 | - template-haskell 22 | 23 | library: 24 | source-dirs: src 25 | ghc-options: 26 | - -Wextra 27 | - -Wcompat 28 | - -Widentities 29 | - -Wincomplete-record-updates 30 | - -Wmissing-export-lists 31 | - -Wmissing-home-modules 32 | - -Wmissing-import-lists 33 | - -Wpartial-fields 34 | - -Wunused-type-patterns 35 | - -Wno-x-partial 36 | - -Wno-unrecognised-warning-flags 37 | 38 | _exe-ghc-options: &exe-ghc-options 39 | - -threaded 40 | - -rtsopts 41 | - -with-rtsopts=-N 42 | 43 | _exe-extensions: &exe-extensions 44 | - DuplicateRecordFields 45 | - OverloadedStrings 46 | - TypeApplications 47 | - AllowAmbiguousTypes 48 | - ScopedTypeVariables 49 | - FlexibleContexts 50 | - RankNTypes 51 | 52 | executables: 53 | # XLA Executables 54 | rules-xla-add: 55 | source-dirs: rules/xla/add 56 | main: Main.hs 57 | dependencies: tensor-right 58 | ghc-options: *exe-ghc-options 59 | default-extensions: *exe-extensions 60 | rules-xla-mul: 61 | source-dirs: rules/xla/mul 62 | main: Main.hs 63 | dependencies: tensor-right 64 | ghc-options: *exe-ghc-options 65 | default-extensions: *exe-extensions 66 | rules-xla-reduce: 67 | source-dirs: rules/xla/reduce 68 | main: Main.hs 69 | dependencies: tensor-right 70 | ghc-options: *exe-ghc-options 71 | default-extensions: *exe-extensions 72 | rules-xla-compare: 73 | source-dirs: rules/xla/compare 74 | main: Main.hs 75 | dependencies: tensor-right 76 | ghc-options: *exe-ghc-options 77 | default-extensions: *exe-extensions 78 | rules-xla-slice: 79 | source-dirs: rules/xla/slice 80 | main: Main.hs 81 | dependencies: tensor-right 82 | ghc-options: *exe-ghc-options 83 | default-extensions: *exe-extensions 84 | rules-xla-iota: 85 | source-dirs: rules/xla/iota 86 | main: Main.hs 
87 | dependencies: tensor-right 88 | ghc-options: *exe-ghc-options 89 | default-extensions: *exe-extensions 90 | rules-xla-pad: 91 | source-dirs: rules/xla/pad 92 | main: Main.hs 93 | dependencies: tensor-right 94 | ghc-options: *exe-ghc-options 95 | default-extensions: *exe-extensions 96 | rules-xla-dyslice: 97 | source-dirs: rules/xla/dyslice 98 | main: Main.hs 99 | dependencies: tensor-right 100 | ghc-options: *exe-ghc-options 101 | default-extensions: *exe-extensions 102 | rules-xla-dyupslice: 103 | source-dirs: rules/xla/dyupslice 104 | main: Main.hs 105 | dependencies: tensor-right 106 | ghc-options: *exe-ghc-options 107 | default-extensions: *exe-extensions 108 | rules-xla-broadcast: 109 | source-dirs: rules/xla/broadcast 110 | main: Main.hs 111 | dependencies: tensor-right 112 | ghc-options: *exe-ghc-options 113 | default-extensions: *exe-extensions 114 | rules-xla-concat: 115 | source-dirs: rules/xla/concat 116 | main: Main.hs 117 | dependencies: tensor-right 118 | ghc-options: *exe-ghc-options 119 | default-extensions: *exe-extensions 120 | rules-xla-logical: 121 | source-dirs: rules/xla/logical 122 | main: Main.hs 123 | dependencies: tensor-right 124 | ghc-options: *exe-ghc-options 125 | default-extensions: *exe-extensions 126 | rules-xla-relabel: 127 | source-dirs: rules/xla/relabel 128 | main: Main.hs 129 | dependencies: tensor-right 130 | ghc-options: *exe-ghc-options 131 | default-extensions: *exe-extensions 132 | rules-xla-dot: 133 | source-dirs: rules/xla/dot 134 | main: Main.hs 135 | dependencies: tensor-right 136 | ghc-options: *exe-ghc-options 137 | default-extensions: *exe-extensions 138 | rules-xla-conv: 139 | source-dirs: rules/xla/conv 140 | main: Main.hs 141 | dependencies: tensor-right 142 | ghc-options: *exe-ghc-options 143 | default-extensions: *exe-extensions 144 | rules-xla-max: 145 | source-dirs: rules/xla/max 146 | main: Main.hs 147 | dependencies: tensor-right 148 | ghc-options: *exe-ghc-options 149 | default-extensions: 
*exe-extensions 150 | rules-xla-not: 151 | source-dirs: rules/xla/not 152 | main: Main.hs 153 | dependencies: tensor-right 154 | ghc-options: *exe-ghc-options 155 | default-extensions: *exe-extensions 156 | rules-xla-clamp: 157 | source-dirs: rules/xla/clamp 158 | main: Main.hs 159 | dependencies: tensor-right 160 | ghc-options: *exe-ghc-options 161 | default-extensions: *exe-extensions 162 | rules-xla-select: 163 | source-dirs: rules/xla/select 164 | main: Main.hs 165 | dependencies: tensor-right 166 | ghc-options: *exe-ghc-options 167 | default-extensions: *exe-extensions 168 | rules-xla-reverse: 169 | source-dirs: rules/xla/reverse 170 | main: Main.hs 171 | dependencies: tensor-right 172 | ghc-options: *exe-ghc-options 173 | default-extensions: *exe-extensions 174 | rules-xla-sub: 175 | source-dirs: rules/xla/sub 176 | main: Main.hs 177 | dependencies: tensor-right 178 | ghc-options: *exe-ghc-options 179 | default-extensions: *exe-extensions 180 | rules-xla-divmod: 181 | source-dirs: rules/xla/divmod 182 | main: Main.hs 183 | dependencies: tensor-right 184 | ghc-options: *exe-ghc-options 185 | default-extensions: *exe-extensions 186 | rules-xla-generalize: 187 | source-dirs: rules/xla/generalize 188 | main: Main.hs 189 | dependencies: tensor-right 190 | ghc-options: *exe-ghc-options 191 | default-extensions: *exe-extensions 192 | # Other Executables 193 | rules-debug: 194 | source-dirs: rules/debug 195 | main: Main.hs 196 | dependencies: tensor-right 197 | ghc-options: *exe-ghc-options 198 | default-extensions: *exe-extensions 199 | 200 | tests: 201 | spec: 202 | main: Main.hs 203 | source-dirs: test 204 | dependencies: 205 | - tensor-right 206 | - test-framework >= 0.8.2 && < 0.9 207 | - test-framework-hunit >= 0.3.0.2 && < 0.4 208 | - test-framework-quickcheck2 >= 0.3.0.5 && < 0.4 209 | - HUnit >= 1.6 210 | - QuickCheck 211 | ghc-options: 212 | - -threaded 213 | - -rtsopts 214 | - -with-rtsopts=-N 215 | 
-------------------------------------------------------------------------------- /rules/xla/concat/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: forall a. AnyDTypeRule a 7 | rule01 _ = do 8 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 9 | rc0Size <- newMap "rc0Size" rclass0 10 | [rc1SizeA, rc1SizeB] <- newMaps ["rclass1size0", "rclass1size1"] rclass1 11 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1SizeA @@ "label1"] 12 | tB <- newTensor @a "B" [rclass0 --> rc0Size, rclass1 --> rc1SizeB @@ "label1"] 13 | lhs <- concatTensorList [tA, tB] (ByLabel "label1") 14 | rhs <- concatTensor tA tB (ByLabel "label1") 15 | rewrite "ConcatList(A, B) ⇒ Concat(A, B)" lhs rhs 16 | 17 | rule02 :: forall a. AnyDTypeRule a 18 | rule02 _ = do 19 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 20 | rc0Size <- newMap "rc0Size" rclass0 21 | [rc1SizeA, rc1SizeB, rc1SizeC] <- 22 | newMaps ["rc1SizeA", "rc1SizeB", "rc1SizeC"] rclass1 23 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1SizeA @@ "label1"] 24 | tB <- newTensor @a "B" [rclass0 --> rc0Size, rclass1 --> rc1SizeB @@ "label1"] 25 | tC <- newTensor @a "C" [rclass0 --> rc0Size, rclass1 --> rc1SizeC @@ "label1"] 26 | lhs <- 27 | concatTensor 28 | tA 29 | (concatTensor tB tC (ByLabel "label1")) 30 | (ByLabel "label1") 31 | rhs <- 32 | concatTensor 33 | (concatTensor tA tB (ByLabel "label1")) 34 | tC 35 | (ByLabel "label1") 36 | rewrite "Concat(A, Concat(B, C)) ⇒ Concat(Concat(A, B), C)" lhs rhs 37 | 38 | rule03 :: forall a. 
AnyDTypeRule a 39 | rule03 _ = do 40 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 41 | rc0Size <- newMap "rc0Size" rclass0 42 | [rc1SizeA, rc1SizeB, rc1SizeC] <- 43 | newMaps ["rc1SizeA", "rc1SizeB", "rc1SizeC"] rclass1 44 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1SizeA @@ "label1"] 45 | tB <- newTensor @a "B" [rclass0 --> rc0Size, rclass1 --> rc1SizeB @@ "label1"] 46 | tC <- newTensor @a "C" [rclass0 --> rc0Size, rclass1 --> rc1SizeC @@ "label1"] 47 | lhs <- concatTensorList [tA, tB, tC] (ByLabel "label1") 48 | rhs <- 49 | concatTensor 50 | tA 51 | (concatTensor tB tC (ByLabel "label1")) 52 | (ByLabel "label1") 53 | rewrite "ConcatList(A, B, C) ⇒ Concat(A, Concat(B, C))" lhs rhs 54 | 55 | rule04 :: forall a. AnyDTypeRule a 56 | rule04 _ = do 57 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 58 | [rc0Size, rc0Start, rc0End, rc0Stride] <- 59 | newMaps ["rc0Size" {- FIX(review): was a duplicate "rc0Start"; naming the size map "rc0Start" aliased it with the start map, so the rule was only checked under the unintended constraint size == start -}, "rc0Start", "rc0End", "rc0Stride"] rclass0 60 | [rc1Size, rc1Start1, rc1End1, rc1Start2, rc1End2, rc1Stride] <- 61 | newMaps ["rc1Size", "rc1Start1", "rc1End1", "rc1Start2", "rc1End2", "rc1Stride"] rclass1 62 | 63 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1Size] 64 | tA1 <- 65 | slice tA $ 66 | Slice 67 | { start = [rclass0 --> rc0Start, rclass1 --> rc1Start1], 68 | end = [rclass0 --> rc0End, rclass1 --> rc1End1], 69 | strides = [rclass0 --> rc0Stride, rclass1 --> rc1Stride] 70 | } 71 | tA2 <- 72 | slice tA $ 73 | Slice 74 | { start = [rclass0 --> rc0Start, rclass1 --> rc1Start2], 75 | end = [rclass0 --> rc0End, rclass1 --> rc1End2], 76 | strides = [rclass0 --> rc0Stride, rclass1 --> rc1Stride] 77 | } 78 | lhs <- concatTensor tA1 tA2 (ByRClass rclass1) 79 | precondition [rc1Stride] $ \[p] -> p .== 1 80 | precondition [rc1End1, rc1Start2] $ \[e, s] -> e .== s 81 | 82 | rhs <- 83 | slice tA $ 84 | Slice 85 | { start = [rclass0 --> rc0Start, rclass1 --> rc1Start1], 86 | end = [rclass0 --> rc0End, rclass1 --> rc1End2], 87 | strides = [rclass0 --> 
rc0Stride, rclass1 --> rc1Stride] 88 | } 89 | 90 | rewrite "Concat(Slice(A), Slice(A)) ⇒ Slice(A)" lhs rhs 91 | 92 | rule05 :: forall a. AnyDTypeRule a 93 | rule05 _ = do 94 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 95 | [rc0Size, rc0Low] <- newMaps ["rc0Size", "rc0Low"] rclass0 96 | rc1Size <- newMap "rc1Size" rclass1 97 | 98 | tB <- newTensor @a "B" [rclass0 --> rc0Size, rclass1 --> rc1Size] 99 | scalar <- constant @a "a" [rclass0 --> rc0Low, rclass1 --> rc1Size] 100 | lhs <- concatTensor scalar tB (ByRClass rclass0) 101 | 102 | rc0Zero <- newConstMap "rc0Zero" 0 rclass0 103 | rhs <- 104 | pad tB ("a" :: a) $ 105 | Padding 106 | { low = [rclass0 --> rc0Low], 107 | high = [rclass0 --> rc0Zero], 108 | interior = [rclass0 --> rc0Zero] 109 | } 110 | rewrite "Concat(Broadcast(Scalar), B) ⇒ Pad(B, scalar, low)" lhs rhs 111 | 112 | rule06 :: forall a. AnyDTypeRule a 113 | rule06 _ = do 114 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 115 | [rc0Size, rc0High] <- newMaps ["rc0Size", "rc0High"] rclass0 116 | rc1Size <- newMap "rc1Size" rclass1 117 | 118 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1Size] 119 | scalar <- constant @a "a" [rclass0 --> rc0High, rclass1 --> rc1Size] 120 | lhs <- concatTensor tA scalar (ByRClass rclass0) 121 | 122 | rc0Zero <- newConstMap "rc0Zero" 0 rclass0 123 | rhs <- 124 | pad tA ("a" :: a) $ 125 | Padding 126 | { low = [rclass0 --> rc0Zero], 127 | high = [rclass0 --> rc0High], 128 | interior = [rclass0 --> rc0Zero] 129 | } 130 | rewrite "Concat(A, Broadcast(Scalar)) ⇒ Pad(A, scalar, high)" lhs rhs 131 | 132 | rule07 :: forall a. 
AnyDTypeRule a 133 | rule07 _ = do 134 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 135 | rc0Size <- newMap "rc0Size" rclass0 136 | [rc1SizeA, rc1SizeB, rc1SizeC] <- newMaps ["rc1SizeA", "rc1SizeB", "rc1SizeC"] rclass1 137 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1SizeA @@ "label1"] 138 | tB <- newTensor @a "B" [rclass0 --> rc0Size, rclass1 --> rc1SizeB @@ "label1"] 139 | tC <- newTensor @a "C" [rclass0 --> rc0Size, rclass1 --> rc1SizeC @@ "label1"] 140 | lhs <- concatTensor tA (concatTensor tB tC (ByLabel "label1")) (ByLabel "label1") 141 | rhs <- concatTensorList [tA, tB, tC] (ByLabel "label1") 142 | rewrite "Concat(A, Concat(B, C)) ⇒ ConcatList(A, B, C)" lhs rhs 143 | 144 | main :: IO () 145 | main = do 146 | print "############################## rule01 ##############################" 147 | verifyAnyDTypeDSL rule01 148 | print "############################## rule02 ##############################" 149 | verifyAnyDTypeDSL rule02 150 | print "############################## rule03 ##############################" 151 | verifyAnyDTypeDSL rule03 152 | print "############################## rule04 ##############################" 153 | verifyAnyDTypeDSL rule04 154 | print "############################## rule05 ##############################" 155 | verifyAnyDTypeDSL rule05 156 | print "############################## rule06 ##############################" 157 | verifyAnyDTypeDSL rule06 158 | print "############################## rule07 ##############################" 159 | verifyAnyDTypeDSL rule07 160 | -------------------------------------------------------------------------------- /rules/xla/pad/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Grisette hiding ((-->)) 4 | import TensorRight 5 | 6 | rule01 :: forall a. 
AnyDTypeRule a 7 | rule01 _ = do 8 | rclass <- newRClass "rclass" 9 | [rcSize, rcLow, rcInt, rcHigh] <- 10 | newMaps ["rcSize", "rcLow", "rcInt", "rcHigh"] rclass 11 | 12 | tA <- newTensor @a "A" [rclass --> rcSize] 13 | lhs <- 14 | pad tA ("a" :: a) $ 15 | Padding 16 | { low = [rclass --> rcLow], 17 | high = [rclass --> rcHigh], 18 | interior = [rclass --> rcInt] 19 | } 20 | precondition [rcLow] $ \[l] -> l .== 0 21 | precondition [rcInt] $ \[i] -> i .== 0 22 | precondition [rcHigh] $ \[h] -> h .== 0 23 | 24 | let rhs = tA 25 | rewrite "Pad(A, val, 0_0_0) ⇒ A" lhs rhs 26 | 27 | rule02 :: forall a. AnyDTypeRule a 28 | rule02 _ = do 29 | [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"] 30 | [rc0Size, rc0Low, rc0Int, rc0High] <- 31 | newMaps ["rc0Size", "rc0Low", "rc0Int", "rc0High"] rclass0 32 | [rc1Size, rc1Low, rc1Int, rc1High] <- 33 | newMaps ["rc1Size", "rc1Low", "rc1Int", "rc1High"] rclass1 34 | 35 | tA <- newTensor @a "A" [rclass0 --> rc0Size, rclass1 --> rc1Size] 36 | lhs <- 37 | pad tA ("a" :: a) $ 38 | Padding 39 | { low = [rclass0 --> rc0Low, rclass1 --> rc1Low], 40 | high = [rclass0 --> rc0High, rclass1 --> rc1High], 41 | interior = [rclass0 --> rc0Int, rclass1 --> rc1Int] 42 | } 43 | precondition [rc1Size] $ \[size1] -> size1 .== 1 44 | 45 | rc1NewInt <- newConstMap "rc1NewInt" 0 rclass1 46 | rhs <- 47 | pad tA ("a" :: a) $ 48 | Padding 49 | { low = [rclass0 --> rc0Low, rclass1 --> rc1Low], 50 | high = [rclass0 --> rc0High, rclass1 --> rc1High], 51 | interior = [rclass0 --> rc0Int, rclass1 --> rc1NewInt] 52 | } 53 | 54 | rewrite "Pad(A, val, low_int_high) ⇒ Pad(A, val, low_0_high)" lhs rhs 55 | 56 | rule03 :: forall a. 
AnyDTypeRule a 57 | rule03 _ = do 58 | [rclass0, rclass1, rclass2] <- newRClasses ["rclass0", "rclass1", "rclass2"] 59 | [rc0Size, rc0Low, rc0High, rc0Int] <- 60 | newMaps ["rc0Size", "rc0Low", "rc0High", "rc0Int"] rclass0 61 | [rc1Size, rc1Low, rc1High, rc1Int] <- 62 | newMaps ["rc1Size", "rc1Low", "rc1High", "rc1Int"] rclass1 63 | [rc2Size, rc2Low, rc2High, rc2Int] <- 64 | newMaps ["rc2Size", "rc2Low", "rc2High" {- FIX(review): was "rc2high"; normalized casing to match rc0High/rc1High and the bound variable rc2High -}, "rc2Int"] rclass2 65 | 66 | tA <- newTensor @a "A" [rclass0 --> rc0Size] 67 | 68 | lhs <- 69 | pad (broadcast tA [rclass1 --> rc1Size, rclass2 --> rc2Size]) ("a" :: a) $ 70 | Padding 71 | { low = [rclass0 --> rc0Low, rclass1 --> rc1Low, rclass2 --> rc2Low], 72 | high = [rclass0 --> rc0High, rclass1 --> rc1High, rclass2 --> rc2High], 73 | interior = [rclass0 --> rc0Int, rclass1 --> rc1Int, rclass2 --> rc2Int] 74 | } 75 | precondition [rc2Int] $ \[i] -> i .== 0 76 | precondition [rc2Low] $ \[l] -> l .== 0 77 | precondition [rc2High] $ \[h] -> h .== 0 78 | 79 | rhs <- 80 | broadcast 81 | ( pad (broadcast tA [rclass1 --> rc1Size]) ("a" :: a) $ 82 | Padding 83 | { low = [rclass0 --> rc0Low, rclass1 --> rc1Low], 84 | high = [rclass0 --> rc0High, rclass1 --> rc1High], 85 | interior = [rclass0 --> rc0Int, rclass1 --> rc1Int] 86 | } 87 | ) 88 | [rclass2 --> rc2Size] 89 | 90 | rewrite "Pad(Broadcast(A), v, low_0_0) ⇒ Broadcast(Pad(Broadcast(A), v))" lhs rhs 91 | 92 | rule04 :: forall a. 
AnyDTypeRule a 93 | rule04 _ = do 94 | -- rclass0: positive low, positive high 95 | -- rclass1: positive low, negative high 96 | -- rclass2: negative low, positive high 97 | -- rclass3: negative low, negative high 98 | [rclass0, rclass1, rclass2, rclass3] <- 99 | newRClasses ["rclass0", "rclass1", "rclass2", "rclass3"] 100 | [rc0Size, rc0Low, rc0High, rc0Int] <- 101 | newMaps ["rc0Size", "rc0Low", "rc0High", "rc0Int"] rclass0 102 | [rc1Size, rc1Low, rc1High, rc1Int] <- 103 | newMaps ["rc1Size", "rc1Low", "rc1High", "rc1Int"] rclass1 104 | [rc2Size, rc2Low, rc2High, rc2Int] <- 105 | newMaps ["rc2Size", "rc2Low", "rc2High", "rc2Int"] rclass2 106 | [rc3Size, rc3Low, rc3High, rc3Int] <- 107 | newMaps ["rc3Size", "rc3Low", "rc3High", "rc3Int"] rclass3 108 | 109 | tA <- 110 | newTensor @a 111 | "A" 112 | [ rclass0 --> rc0Size, 113 | rclass1 --> rc1Size, 114 | rclass2 --> rc2Size, 115 | rclass3 --> rc3Size 116 | ] 117 | 118 | lhs <- 119 | pad tA ("a" :: a) $ 120 | Padding 121 | { low = 122 | [ rclass0 --> rc0Low, 123 | rclass1 --> rc1Low, 124 | rclass2 --> rc2Low, 125 | rclass3 --> rc3Low 126 | ], 127 | interior = 128 | [ rclass0 --> rc0Int, 129 | rclass1 --> rc1Int, 130 | rclass2 --> rc2Int, 131 | rclass3 --> rc3Int 132 | ], 133 | high = 134 | [ rclass0 --> rc0High, 135 | rclass1 --> rc1High, 136 | rclass2 --> rc2High, 137 | rclass3 --> rc3High 138 | ] 139 | } 140 | precondition [rc0Int] $ \[i] -> i .== 0 141 | precondition [rc1Int] $ \[i] -> i .== 0 142 | precondition [rc2Int] $ \[i] -> i .== 0 143 | precondition [rc3Int] $ \[i] -> i .== 0 144 | precondition [rc0Low] $ \[l] -> l .>= 0 145 | precondition [rc0High] $ \[h] -> h .>= 0 146 | precondition [rc1Low] $ \[l] -> l .>= 0 147 | precondition [rc1High] $ \[h] -> h .< 0 148 | precondition [rc2Low] $ \[l] -> l .< 0 149 | precondition [rc2High] $ \[h] -> h .>= 0 150 | precondition [rc3Low] $ \[l] -> l .< 0 151 | precondition [rc3High] $ \[h] -> h .< 0 152 | 153 | rc1End <- 154 | combineMap "rc1End" (\[s, l, h] -> s + l + 
h) [rc1Size, rc1Low, rc1High] 155 | rc2Start <- combineMap "rc2Start" (\[x] -> abs x) [rc2Low] 156 | rc3Start <- combineMap "rc3Start" (\[x] -> abs x) [rc3Low] 157 | rc3End <- combineMap "rc3End" sum [rc3Size, rc3High] 158 | rhs <- 159 | slice 160 | ( pad tA ("a" :: a) $ 161 | Padding 162 | { low = [rclass0 --> rc0Low, rclass1 --> rc1Low], 163 | interior = [], 164 | high = [rclass0 --> rc0High, rclass2 --> rc2High] 165 | } 166 | ) 167 | $ Slice 168 | { start = [rclass2 --> rc2Start, rclass3 --> rc3Start], 169 | end = [rclass1 --> rc1End, rclass3 --> rc3End], 170 | strides = [] 171 | } 172 | 173 | rewrite "Pad(A, val, negative_negative) ⇒ Slice(Pad(A, val, 0_0), abs(negative),negative+size)" lhs rhs 174 | 175 | main :: IO () 176 | main = do 177 | print "############################## rule01 ##############################" 178 | verifyAnyDTypeDSL rule01 179 | print "############################## rule02 ##############################" 180 | verifyAnyDTypeDSL rule02 181 | print "############################## rule03 ##############################" 182 | verifyAnyDTypeDSL rule03 183 | print "############################## rule04 ##############################" 184 | verifyAnyDTypeDSL rule04 185 | -------------------------------------------------------------------------------- /rules/xla/dot/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | import Debug.Trace (traceShow) 4 | import Grisette hiding (dot, (-->)) 5 | import TensorRight 6 | 7 | rule01 :: forall a. 
NumRule a 8 | rule01 _ = do 9 | [rclass0, rclass1, rclass2] <- newRClasses ["rclass0", "rclass1", "rclass2"] 10 | [xs0, ys0, dotsi, sixc, siyd, sir, rclass0AllOne] <- 11 | newMaps 12 | ["xs0", "ys0", "dotsi", "sixc", "siyd", "sir", "allOne"] 13 | rclass0 14 | [xs1, ys1] <- newMaps ["xs1", "ys1"] rclass1 15 | [cs2, ds2] <- newMaps ["cs2", "ds2"] rclass2 16 | x <- newTensor @a "x" [rclass0 --> xs0, rclass1 --> xs1] 17 | y <- newTensor @a "y" [rclass0 --> ys0, rclass1 --> ys1] 18 | c <- newTensor @a "c" [rclass0 --> xs0, rclass2 --> cs2] 19 | d <- newTensor @a "d" [rclass0 --> ys0, rclass2 --> ds2] 20 | lhs <- 21 | dot 22 | (concatTensor x y $ ByRClass rclass0) 23 | (concatTensor c d $ ByRClass rclass0) 24 | [rclass0 --> dotsi] 25 | [] 26 | rhs <- 27 | reduce 28 | ( concatTensor 29 | (broadcast (dot x c [rclass0 --> sixc] []) [rclass0 --> rclass0AllOne]) 30 | (broadcast (dot y d [rclass0 --> siyd] []) [rclass0 --> rclass0AllOne]) 31 | $ ByRClass rclass0 32 | ) 33 | [rclass0 --> sir] 34 | precondition [rclass0AllOne] $ \[rclass0AllOne] -> rclass0AllOne .== 1 35 | let siCondition [vdotsi, vsixc, vsiyd, vsir, vxs0, vys0] = 36 | symIte 37 | (vsir .== 0) 38 | (vdotsi .== vsixc) 39 | (vdotsi .== vsiyd + vxs0) 40 | .&& (vsixc .>= 0) 41 | .&& (vsiyd .>= 0) 42 | .&& (vsixc .< vxs0) 43 | .&& (vsiyd .< vys0) 44 | .&& (vsir .== 0 .|| vsir .== 1) 45 | siCondition _ = undefined 46 | siRelation [dotsi, sixc, siyd, sir, xs0, ys0] siCondition 47 | checkSIMap [dotsi] [sir, sixc, siyd] 48 | let lhsStr = "Dot(Concat(X, Y), Concat(C, D))" 49 | let rhsStr = "Reduce(Concat(Broadcast(Dot(X, C)), Broadcast(Dot(Y, D)))" 50 | rewrite (lhsStr <> " ⇒ " <> rhsStr) lhs rhs 51 | 52 | rule02 :: forall a. 
NumRule a 53 | rule02 _ = do 54 | [rclass0, rclass1, rclass2, rclass3] <- newRClasses ["rclass0", "rclass1", "rclass2", "rclass3"] 55 | [size0, lhssi0, rhssi0] <- newMaps ["size0", "lhssi0", "rhssi0"] rclass0 56 | size1 <- newMap "size1" rclass1 57 | size2 <- newMap "size2" rclass2 58 | size3 <- newMap "size3" rclass3 59 | a <- newTensor @a "a" [rclass0 --> size0, rclass1 --> size1, rclass2 --> size2] 60 | b <- newTensor @a "b" [rclass0 --> size0, rclass1 --> size1, rclass3 --> size3] 61 | lhs <- dot a b [rclass0 --> lhssi0] [ByRClass rclass1] 62 | rhs <- dot b a [rclass0 --> rhssi0] [ByRClass rclass1] 63 | siRelation [lhssi0, rhssi0] $ \[vlhssi0, vrhssi0] -> vlhssi0 .== vrhssi0 64 | checkSIMap [lhssi0] [rhssi0] 65 | rewrite "Dot(A, B) ⇒ Dot(B, A)" lhs rhs 66 | 67 | rule03 :: forall a. NumRule a 68 | rule03 _ = do 69 | [rclass0, rclass1, rclass2, rclass3] <- newRClasses ["rclass0", "rclass1", "rclass2", "rclass3"] 70 | [sizea0] <- newMaps ["sizea0"] rclass0 71 | [sizea1, sizeb1, siabLhs, siabRhs] <- newMaps ["sizea1", "sizeb1", "siabLhs", "siabRhs"] rclass1 72 | [sizeb2, sizec2, sibcLhs, sibcRhs] <- newMaps ["sizeb2", "sizec2", "sibcLhs", "sibcRhs"] rclass2 73 | [sizec3] <- newMaps ["sizec3"] rclass3 74 | 75 | tensorA <- newTensor @a "tensorA" [rclass0 --> sizea0, rclass1 --> sizea1] 76 | tensorB <- newTensor @a "tensorB" [rclass1 --> sizeb1, rclass2 --> sizeb2] 77 | tensorC <- newTensor @a "tensorC" [rclass2 --> sizec2, rclass3 --> sizec3] 78 | 79 | lhs <- 80 | dot 81 | tensorA 82 | ( dot 83 | tensorB 84 | tensorC 85 | [rclass2 --> sibcLhs] 86 | [] 87 | ) 88 | [rclass1 --> siabLhs] 89 | [] 90 | rhs <- 91 | dot 92 | ( dot 93 | tensorA 94 | tensorB 95 | [rclass1 --> siabRhs] 96 | [] 97 | ) 98 | tensorC 99 | [rclass2 --> sibcRhs] 100 | [] 101 | 102 | siRelation [siabLhs, siabRhs] $ \[vsiabLhs, vsiabRhs] -> vsiabLhs .== vsiabRhs 103 | siRelation [sibcLhs, sibcRhs] $ \[vsibcLhs, vsibcRhs] -> vsibcLhs .== vsibcRhs 104 | checkSIMap [siabLhs, sibcLhs] [siabRhs, sibcRhs] 
105 | 106 | rewrite "Dot(A,Dot(B,C)) ⇒ Dot(Dot(A,B),C)" lhs rhs {- FIX(review): was `lhs lhs`, which verified the LHS against itself (trivially true) and left `rhs` unused — the associativity rule was never actually checked -} 107 | 108 | rule04 :: forall a. NumRule a 109 | rule04 _ = do 110 | [rclass0, rclass1, rclass2] <- newRClasses ["rclass0", "rclass1", "rclass2"] 111 | rclass0Size <- newMap "rclass0Size" rclass0 112 | rclass1Size <- newMap "rclass1Size" rclass1 113 | rclass2Size <- newMap "rclass2Size" rclass2 114 | x <- newTensor @a "x" [rclass0 --> rclass0Size, rclass1 --> rclass1Size] 115 | y <- newTensor @a "y" [rclass0 --> rclass0Size, rclass2 --> rclass2Size] 116 | lhs <- dot x y [] [ByRClass rclass0] 117 | rhs <- 118 | numBinOp 119 | Mul 120 | (broadcast x [rclass2 --> rclass2Size]) 121 | (broadcast y [rclass1 --> rclass1Size]) 122 | rewrite "Dot(A,B) ⇒ Mul(Broadcast(A), Broadcast(B)) when no contraction" lhs rhs 123 | 124 | rule05 :: forall a. NumRule a 125 | rule05 _ = do 126 | [rclass0, rclass1, crclass0, crclass1, brclass0, brclass1] <- 127 | newRClasses ["rclass0", "rclass1", "crclass0", "crclass1", "brclass0", "brclass1"] 128 | rclass0Size <- newMap "rclass0Size" rclass0 129 | rclass1Size <- newMap "rclass1Size" rclass1 130 | crclass0Size <- newMap "crclass0Size" crclass0 131 | crclass1Size <- newMap "crclass1Size" crclass1 132 | brclass0Size <- newMap "brclass0Size" brclass0 {- FIX(review): both batch-size maps were named "brclassSize" (copy-paste slip; cf. the distinct crclass0Size/crclass1Size names), which can alias two independent maps -} 133 | brclass1Size <- newMap "brclass1Size" brclass1 134 | si0 <- newMap "si0" crclass0 135 | si1 <- newMap "si1" crclass1 136 | x <- 137 | newTensor @a 138 | "x" 139 | [ rclass0 --> rclass0Size, 140 | crclass0 --> crclass0Size, 141 | crclass1 --> crclass1Size, 142 | brclass0 --> brclass0Size, 143 | brclass1 --> brclass1Size 144 | ] 145 | y <- 146 | newTensor @a 147 | "y" 148 | [ rclass1 --> rclass1Size, 149 | crclass0 --> crclass0Size, 150 | crclass1 --> crclass1Size, 151 | brclass0 --> brclass0Size, 152 | brclass1 --> brclass1Size 153 | ] 154 | lhs <- dot x y [crclass0 --> si0, crclass1 --> si1] [ByRClass brclass0, ByRClass brclass1] 155 | rhs <- 156 | constant @a 157 | 0 158 | [ rclass0 --> rclass0Size, 159 | rclass1 --> rclass1Size, 
160 | brclass0 --> brclass0Size, 161 | brclass1 --> brclass1Size 162 | ] 163 | precondition' 164 | [ rclass0Size, 165 | rclass1Size, 166 | crclass0Size, 167 | crclass1Size, 168 | brclass0Size, 169 | brclass1Size 170 | ] 171 | $ \[rclass0Size, rclass1Size, crclass0Size, crclass1Size, brclass0Size, brclass1Size] -> 172 | zipCondition (\[rclass0Size] -> rclass0Size .== 0) [rclass0Size] 173 | .|| zipCondition (\[rclass1Size] -> rclass1Size .== 0) [rclass1Size] 174 | .|| zipCondition (\[crclassSize] -> crclassSize .== 0) [crclass0Size] 175 | .|| zipCondition (\[crclassSize] -> crclassSize .== 0) [crclass1Size] 176 | .|| zipCondition (\[brclassSize] -> brclassSize .== 0) [brclass0Size] 177 | .|| zipCondition (\[brclassSize] -> brclassSize .== 0) [brclass1Size] 178 | rewrite "Dot(A,B) ⇒ 0 when one of the dimensions is 0" lhs rhs 179 | 180 | rule06 :: forall a. NumRule a 181 | rule06 _ = do 182 | [rclass, crclass, brclass] <- newRClasses ["rclass", "crclass", "brclass"] 183 | rclassSize <- newMap "rclassSize" rclass 184 | crclassSize <- newMap "crclassSize" crclass 185 | brclassSize <- newMap "brclassSize" brclass 186 | x <- newTensor @a "x" [rclass --> rclassSize, crclass --> crclassSize, brclass --> brclassSize] 187 | y <- newTensor @a "y" [crclass --> crclassSize, brclass --> brclassSize] 188 | si <- newMap "si" crclass 189 | rsi <- newMap "rsi" crclass 190 | lhs <- dot x y [crclass --> si] [ByRClass brclass] 191 | rhs <- 192 | reduce 193 | ( numBinOp 194 | Mul 195 | x 196 | (broadcast y [rclass --> rclassSize]) 197 | ) 198 | [crclass --> rsi] 199 | siRelation [si, rsi] $ \[vsi, vrsi] -> vsi .== vrsi 200 | checkSIMap [si] [rsi] 201 | rewrite 202 | ( "Dot(A,B) ⇒ Reduce(Mul(Broadcast(Transpose(A)), Broadcast(Transpose(B)))) " 203 | <> "when rhs only have contraction and batch rclasses" 204 | ) 205 | lhs 206 | rhs 207 | 208 | main :: IO () 209 | main = do 210 | print "############################## rule01 ##############################" 211 | verifyNumDSLWith cvc5 rule01 
212 | print "############################## rule02 ##############################" 213 | verifyNumDSL rule02 214 | print "############################## rule03 ##############################" 215 | verifyNumDSL rule03 216 | print "############################## rule04 ##############################" 217 | verifyNumDSL rule04 218 | print "############################## rule05 ##############################" 219 | verifyNumDSL rule05 220 | print "############################## rule06 ##############################" 221 | verifyNumDSL rule06 222 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/Core/Axis.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE DeriveAnyClass #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE DerivingVia #-} 5 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 6 | {-# LANGUAGE OverloadedStrings #-} 7 | {-# LANGUAGE ScopedTypeVariables #-} 8 | 9 | module TensorRight.Internal.Core.Axis 10 | ( Axis (..), 11 | Indices, 12 | Sizes, 13 | AxisMapLike (..), 14 | Axes, 15 | unionAxisMap, 16 | allAxes, 17 | lookupAxis, 18 | getAxis, 19 | removeAxes, 20 | restrictAxes, 21 | mapAxisMap, 22 | mapAxisMapWithAxisKey, 23 | zipFoldAxisMap, 24 | zipAxisMap, 25 | zipAxisMap3, 26 | zipAxisMapM, 27 | zipAxisMapM3, 28 | addAxisMap, 29 | mulAxisMap, 30 | subAxisMap, 31 | foldAxisMap, 32 | castAxisMap, 33 | intersectionAxisMap, 34 | sameAxisMap, 35 | safeDivAxisMap, 36 | safeModAxisMap, 37 | ) 38 | where 39 | 40 | import Control.Exception (ArithException) 41 | import Data.Foldable (Foldable (toList)) 42 | import qualified Data.HashMap.Lazy as HM 43 | import qualified Data.HashSet as HS 44 | import Data.Hashable (Hashable) 45 | import Data.List (sortOn) 46 | import qualified Data.Text as T 47 | import GHC.Generics (Generic) 48 | import GHC.Stack (HasCallStack) 49 | import Grisette 50 | ( Default (Default), 51 | EvalSym 
(evalSym), 52 | ExtractSym (extractSymMaybe), 53 | LogicalOp ((.&&)), 54 | Mergeable (rootStrategy), 55 | MergingStrategy (SimpleStrategy, SortedStrategy), 56 | MonadTryMerge, 57 | PPrint (pformat), 58 | SafeDiv (safeDiv, safeMod), 59 | SimpleMergeable (mrgIte), 60 | Solvable (con), 61 | SymBool, 62 | SymEq ((.==)), 63 | SymInteger, 64 | ) 65 | import Grisette.Lib.Control.Monad.Except (mrgModifyError) 66 | import TensorRight.Internal.Util.Error (ErrorEnv) 67 | 68 | data Axis 69 | = Axis {_axis :: T.Text} 70 | | LabelledAxis {_label :: T.Text, _axis :: T.Text} 71 | deriving (Generic, Ord, Eq) 72 | deriving anyclass (Hashable) 73 | deriving (Mergeable, SymEq, EvalSym) via (Default Axis) 74 | 75 | instance PPrint Axis where 76 | pformat (Axis a) = pformat a 77 | pformat (LabelledAxis l a) = pformat l <> "@@" <> pformat a 78 | 79 | instance Show Axis where 80 | show (Axis a) = T.unpack a 81 | show (LabelledAxis l a) = T.unpack l ++ "@@" ++ T.unpack a 82 | 83 | newtype UnifiedMap = UnifiedMap 84 | { unUnifiedMap :: HM.HashMap Axis SymInteger 85 | } 86 | deriving newtype (Show, Eq) 87 | deriving newtype (Semigroup, Monoid) 88 | 89 | instance PPrint UnifiedMap where 90 | pformat = pformat . asHashMap 91 | 92 | instance SymEq UnifiedMap where 93 | UnifiedMap l .== UnifiedMap r = 94 | sortOn fst (HM.toList l) .== sortOn fst (HM.toList r) 95 | 96 | instance EvalSym UnifiedMap where 97 | evalSym b m = fromKVPairs . evalSym b m . HM.toList . asHashMap 98 | 99 | instance ExtractSym UnifiedMap where 100 | extractSymMaybe (UnifiedMap m) = extractSymMaybe $ snd <$> HM.toList m 101 | 102 | instance Mergeable UnifiedMap where 103 | rootStrategy = 104 | SortedStrategy (HM.keys . 
unUnifiedMap) $ 105 | const $ 106 | SimpleStrategy $ \c (UnifiedMap d1) (UnifiedMap d2) -> 107 | UnifiedMap $ HM.unionWith (mrgIte c) d1 d2 108 | 109 | instance AxisMapLike UnifiedMap where 110 | fromHashMap = UnifiedMap 111 | asHashMap = unUnifiedMap 112 | 113 | allAxes :: (AxisMapLike m) => m -> Axes 114 | allAxes = HS.fromList . HM.keys . asHashMap 115 | 116 | lookupAxis :: (AxisMapLike m) => Axis -> m -> Maybe SymInteger 117 | lookupAxis a = HM.lookup a . asHashMap 118 | 119 | getAxis :: (AxisMapLike m) => Axis -> m -> SymInteger 120 | getAxis a = (HM.! a) . asHashMap 121 | 122 | removeAxes :: (AxisMapLike m) => Axes -> m -> m 123 | removeAxes axes = 124 | fromHashMap . HM.filterWithKey (\k _ -> not $ k `HS.member` axes) . asHashMap 125 | 126 | restrictAxes :: (AxisMapLike m) => Axes -> m -> m 127 | restrictAxes axes = 128 | fromHashMap . HM.filterWithKey (\k _ -> k `HS.member` axes) . asHashMap 129 | 130 | castAxisMap :: (AxisMapLike m1, AxisMapLike m2) => m1 -> m2 131 | castAxisMap = fromHashMap . 
asHashMap 132 | 133 | unionAxisMap :: (AxisMapLike m) => m -> m -> m 134 | unionAxisMap l r = fromHashMap $ HM.union (asHashMap l) (asHashMap r) 135 | 136 | intersectionAxisMap :: (AxisMapLike m) => m -> m -> m 137 | intersectionAxisMap l r = 138 | fromHashMap $ HM.intersection (asHashMap l) (asHashMap r) 139 | 140 | mapAxisMap :: (AxisMapLike m) => (SymInteger -> SymInteger) -> m -> m 141 | mapAxisMap f m = fromHashMap $ HM.map f (asHashMap m) 142 | 143 | mapAxisMapWithAxisKey :: 144 | (AxisMapLike m) => (Axis -> SymInteger -> SymInteger) -> m -> m 145 | mapAxisMapWithAxisKey f m = fromHashMap $ HM.mapWithKey f (asHashMap m) 146 | 147 | zipAxisMap :: 148 | (HasCallStack, AxisMapLike m) => 149 | (SymInteger -> SymInteger -> SymInteger) -> 150 | m -> 151 | m -> 152 | m 153 | zipAxisMap f l r 154 | | allAxes l == allAxes r = 155 | fromHashMap $ HM.unionWith f (asHashMap l) (asHashMap r) 156 | | otherwise = error "Cannot zip maps with different axes" 157 | 158 | foldAxisMap :: 159 | (AxisMapLike m) => 160 | (SymInteger -> a) -> 161 | a -> 162 | (a -> a -> a) -> 163 | m -> 164 | a 165 | foldAxisMap f initial g m = 166 | HM.foldl' (\acc vl -> g acc (f vl)) initial (asHashMap m) 167 | 168 | zipFoldAxisMap :: 169 | (HasCallStack, AxisMapLike m) => 170 | (SymInteger -> SymInteger -> a) -> 171 | a -> 172 | (a -> a -> a) -> 173 | m -> 174 | m -> 175 | a 176 | zipFoldAxisMap f initial g l r 177 | | allAxes l == allAxes r = 178 | HM.foldlWithKey' 179 | (\acc k vl -> g acc (f vl (getAxis k r))) 180 | initial 181 | (asHashMap l) 182 | | otherwise = error "Cannot zip maps with different axes" 183 | 184 | zipAxisMap3 :: 185 | (HasCallStack, AxisMapLike m) => 186 | (SymInteger -> SymInteger -> SymInteger -> SymInteger) -> 187 | m -> 188 | m -> 189 | m -> 190 | m 191 | zipAxisMap3 f l r s 192 | | allAxes l == allAxes r && allAxes l == allAxes s = 193 | fromHashMap $ 194 | HM.mapWithKey (\k vl -> f vl (getAxis k r) (getAxis k s)) $ 195 | asHashMap l 196 | | otherwise = error "Cannot 
zip maps with different axes" 197 | 198 | zipAxisMapM :: 199 | (HasCallStack, AxisMapLike am, MonadTryMerge m) => 200 | (SymInteger -> SymInteger -> m SymInteger) -> 201 | am -> 202 | am -> 203 | m am 204 | zipAxisMapM f l r 205 | | allAxes l == allAxes r = 206 | fromHashMap 207 | <$> HM.traverseWithKey (\k vl -> f vl (getAxis k r)) (asHashMap l) 208 | | otherwise = error "Cannot zip maps with different axes" 209 | 210 | zipAxisMapM3 :: 211 | ( HasCallStack, 212 | AxisMapLike am, 213 | MonadTryMerge m 214 | ) => 215 | (SymInteger -> SymInteger -> SymInteger -> m SymInteger) -> 216 | am -> 217 | am -> 218 | am -> 219 | m am 220 | zipAxisMapM3 f l r s 221 | | allAxes l == allAxes r && allAxes l == allAxes s = 222 | fromHashMap 223 | <$> HM.traverseWithKey 224 | (\k vl -> f vl (getAxis k r) (getAxis k s)) 225 | (asHashMap l) 226 | | otherwise = error "Cannot zip maps with different axes" 227 | 228 | newtype Indices = Indices UnifiedMap 229 | deriving newtype 230 | ( Show, 231 | Eq, 232 | Mergeable, 233 | AxisMapLike, 234 | Semigroup, 235 | Monoid, 236 | SymEq, 237 | EvalSym, 238 | ExtractSym, 239 | PPrint 240 | ) 241 | 242 | newtype Sizes = Sizes UnifiedMap 243 | deriving newtype 244 | ( Show, 245 | Eq, 246 | Mergeable, 247 | AxisMapLike, 248 | Semigroup, 249 | Monoid, 250 | SymEq, 251 | ExtractSym, 252 | EvalSym, 253 | PPrint 254 | ) 255 | 256 | class (Monoid m) => AxisMapLike m where 257 | fromHashMap :: HM.HashMap Axis SymInteger -> m 258 | fromKVPairs :: (Foldable t) => t (Axis, SymInteger) -> m 259 | fromKVPairs = fromHashMap . HM.fromList . 
toList 260 | asHashMap :: m -> HM.HashMap Axis SymInteger 261 | 262 | type Axes = HS.HashSet Axis 263 | 264 | addAxisMap :: (HasCallStack, AxisMapLike m) => m -> m -> m 265 | addAxisMap = zipAxisMap (+) 266 | 267 | subAxisMap :: (HasCallStack, AxisMapLike m) => m -> m -> m 268 | subAxisMap = zipAxisMap (-) 269 | 270 | mulAxisMap :: (HasCallStack, AxisMapLike m) => m -> m -> m 271 | mulAxisMap = zipAxisMap (*) 272 | 273 | sameAxisMap :: (HasCallStack, AxisMapLike m) => m -> m -> SymBool 274 | sameAxisMap l r = 275 | foldl 276 | (\acc dimName -> acc .&& getAxis dimName l .== getAxis dimName r) 277 | (con (allAxes l == allAxes r)) 278 | (allAxes l) 279 | 280 | safeDivAxisMap :: 281 | (HasCallStack, AxisMapLike m, Mergeable m) => 282 | m -> 283 | m -> 284 | ErrorEnv m 285 | safeDivAxisMap = 286 | zipAxisMapM 287 | ( \l r -> 288 | mrgModifyError (\(_ :: ArithException) -> "Division by zero") $ 289 | safeDiv l r 290 | ) 291 | 292 | safeModAxisMap :: 293 | (HasCallStack, AxisMapLike m, Mergeable m) => 294 | m -> 295 | m -> 296 | ErrorEnv m 297 | safeModAxisMap = 298 | zipAxisMapM 299 | ( \l r -> 300 | mrgModifyError (\(_ :: ArithException) -> "Division by zero") $ 301 | safeMod l r 302 | ) 303 | -------------------------------------------------------------------------------- /plot/timing_plot.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | import dataclasses 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib 6 | from collections import Counter 7 | matplotlib.rcParams['pdf.fonttype'] = 42 8 | matplotlib.rcParams['ps.fonttype'] = 42 9 | 10 | @dataclasses.dataclass 11 | class Result: 12 | theory: str 13 | num_of_tasks: int 14 | has_warning: bool 15 | time: float 16 | succeeded: bool 17 | 18 | 19 | @dataclasses.dataclass 20 | class Rule: 21 | name: str 22 | results: list[Result] 23 | succeeeded: bool 24 | def all_same_num_of_tasks(self) -> bool: 25 | return 
from typing import Sequence
import dataclasses


@dataclasses.dataclass
class Result:
    """One per-theory verification result parsed from the runall log."""
    theory: str          # "Bool", "Int", "Real", or "" when no theory tag is present
    num_of_tasks: int    # -1 when the run failed before task counting
    has_warning: bool
    time: float          # seconds, taken from the bracketed "[...s]" field
    succeeded: bool


@dataclasses.dataclass
class Rule:
    """All results for one rewrite rule (one '====>' section of the log)."""
    name: str
    results: list[Result]
    succeeeded: bool  # sic: historical misspelling, kept for compatibility

    def all_same_num_of_tasks(self) -> bool:
        # True when every theory reported the same task count.
        return all(x.num_of_tasks == self.results[0].num_of_tasks for x in self.results)

    def first_num_of_tasks(self) -> int:
        return self.results[0].num_of_tasks

    def overall_time(self) -> float:
        # Sum of per-theory times; the log's own "Overall" time is not stored.
        return sum(x.time for x in self.results)


def parse_file(lines: Sequence[str]) -> list[Rule]:
    """Parse the runall.sh log into a list of Rule records.

    The log is a sequence of sections, each starting with a
    "====> <rule name>" line.  A section contains one ">>> <Theory>"
    subsection per theory (plus a final ">>> Overall"), each with INFO
    lines (bounds, "Number of bounded verification tasks: N"), optional
    WARNING lines, and a terminating "[SUCCESS-...]: [<t>s] ..." or
    "[FAIL-...]: [<t>s] ..." line.  Sections without any ">>>" line are
    parsed as a single untagged result.  The times are extracted from the
    bracketed "[...s]" field ([1:-2] strips "[" and "s]").
    """
    rules: list[Rule] = []
    i = 0
    while i < len(lines):
        line = lines[i]
        if line.startswith("====>"):
            # Rule name is everything after the "====>" marker.
            name = " ".join(line.split(" ")[1:])
            results: list[Result] = []
            has_warning = False
            i += 1
            while i < len(lines) and not lines[i].startswith("====>"):
                if lines[i].startswith(">>>"):
                    theory = lines[i].split(" ")[1].strip()
                    if theory != "Overall":
                        i += 1
                        if lines[i].startswith("[FAIL"):
                            # Failed before task counting: record -1 tasks.
                            time = float(lines[i].split(" ")[1][1:-2])
                            results.append(Result(theory, -1, False, time, False))
                            i += 1
                            continue
                        i += 1
                        num_of_tasks = int(lines[i].split(" ")[-1])
                        i += 1
                        while lines[i].startswith("[WARNING"):
                            has_warning = True
                            i += 1
                        time = float(lines[i].split(" ")[1][1:-2])
                        succeeded = lines[i].startswith("[SUCCESS")
                        results.append(
                            Result(theory, num_of_tasks, has_warning, time, succeeded))
                    else:
                        # ">>> Overall" closes the section; the overall
                        # SUCCESS/FAIL flag becomes the rule's flag.
                        i += 1
                        time = float(lines[i].split(" ")[1][1:-2])
                        succeeded = lines[i].startswith("[SUCCESS")
                        rules.append(Rule(name, results, succeeded))
                        i += 1
                        break
                else:
                    # No theory is given
                    i += 1
                    num_of_tasks = int(lines[i].split(" ")[-1])
                    i += 1
                    while lines[i].startswith("[WARNING"):
                        has_warning = True
                        i += 1
                    time = float(lines[i].split(" ")[1][1:-2])
                    succeeded = lines[i].startswith("[SUCCESS")
                    results.append(Result("", num_of_tasks, has_warning, time, succeeded))
                    rules.append(Rule(name, results, succeeded))
                    i += 1
                    break
                i += 1
        else:
            i += 1
    return rules


# A two-rule excerpt of a real log, kept for quick manual testing.
sample = """====> Slice(Concat(A,B), A.size, A.size+B.size) ⇒ B
>>> Bool
[INFO-Bool]: Inferred bounds: fromList [(rclass1,1),(rclass0,1)]
[INFO-Bool]: Number of bounded verification tasks: 1
[SUCCESS-Bool]: [0.12387501999910455s] Verification succeeded.
>>> Int
[INFO-Int]: Inferred bounds: fromList [(rclass1,1),(rclass0,1)]
[INFO-Int]: Number of bounded verification tasks: 1
[SUCCESS-Int]: [0.11165066000830848s] Verification succeeded.
>>> Real
[INFO-Real]: Inferred bounds: fromList [(rclass1,1),(rclass0,1)]
[INFO-Real]: Number of bounded verification tasks: 1
[SUCCESS-Real]: [0.106301934007206s] Verification succeeded.
>>> Overall
[SUCCESS-Overall]: [0.34182761401461903s] Verification succeeded.
====> Slice(Broadcast(A)) ⇒ Broadcast(A)
>>> Bool
[INFO-Bool]: Inferred bounds: fromList [(rclass1,1),(rclass0,1)]
[INFO-Bool]: Number of bounded verification tasks: 1
[SUCCESS-Bool]: [0.34009405800316017s] Verification succeeded.
>>> Int
[INFO-Int]: Inferred bounds: fromList [(rclass1,1),(rclass0,1)]
[INFO-Int]: Number of bounded verification tasks: 1
[SUCCESS-Int]: [0.23800732199742924s] Verification succeeded.
>>> Real
[INFO-Real]: Inferred bounds: fromList [(rclass1,1),(rclass0,1)]
[INFO-Real]: Number of bounded verification tasks: 1
[SUCCESS-Real]: [0.26108629700320307s] Verification succeeded.
>>> Overall
[SUCCESS-Overall]: [0.8391876770037925s] Verification succeeded.
""".splitlines()
# print(parse_file(sample))

with open("result.txt", "r") as f:
    res = parse_file(list(map(lambda x: x.strip(), f.readlines())))

number_of_failed_inference = len(list(filter(lambda x: x < 0, map(lambda x: x.first_num_of_tasks(), res))))
number_of_tasks = list(filter(lambda x: x >= 0, map(lambda x: x.first_num_of_tasks(), res)))
success = list(filter(lambda x: x.succeeeded, res))

times = list(map(lambda x: x.overall_time(), success))

def print_stats(success):
    """Print summary timing statistics for the verified rules."""
    times = list(map(lambda x: x.overall_time(), success))
    print(f'Number of verified rules: {len(times)}')
    print(f'Max time: {max(times)}')
    print(f'Min time: {min(times)}')
    print(f'Average time: {sum(times) / len(times)}')
    print(f'Geometric mean: {np.exp(np.mean(np.log(times)))}')
    print(f'Number of rules verified under 1s: {len(list(filter(lambda x: x < 1, times)))}')
    print(f'Number of rules verified under 5s: {len(list(filter(lambda x: x < 5, times)))}')

# Every successful rule is expected to report the same task count per theory.
assert(all(list(map(lambda x: x.all_same_num_of_tasks(), success))))
num_tasks_freq = Counter(list(map(lambda x: x.first_num_of_tasks(), success)))
for key, value in num_tasks_freq.items():
    print(f'Number of rules with {key} tasks: {value}')

# Bar chart: number of rules per task count (num_tasks.pdf).
tasks = sorted(num_tasks_freq.keys())
values = [num_tasks_freq[t] for t in tasks]
tasks = [str(t) for t in tasks]
plt.figure(figsize=(3,2))
plt.bar(tasks, values, label="AQP", color='C0')
plt.xticks(tasks)
plt.ylim(0, max(values) * 1.2)

for t in tasks:
    plt.text(t, num_tasks_freq[int(t)] + 0.5, str(num_tasks_freq[int(t)]), ha='center', va='bottom')

plt.ylabel('Number of Rules')
plt.xlabel('Number of Tasks')
plt.savefig('num_tasks.pdf', bbox_inches='tight', dpi=600, pad_inches=0)
plt.clf()

def plot_hist(times):
    """Histogram of total verification times (timing_plot.pdf).

    Note: a previous version computed log-spaced bin edges via np.logspace
    but never used them (and max() would crash on an empty list); the dead
    computation has been removed — the histogram keeps its 20 linear bins.
    """
    plt.hist(times, bins=20, edgecolor='black')

    plt.title('Total Time taken for Unbounded Verification')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Number of Rules')

    plt.savefig('timing_plot.pdf', format="pdf", bbox_inches="tight", dpi=600)
    plt.clf()
def plot_cdf(times):
    """CDF of total verification times on a log-x axis (timing_plot.pdf)."""
    times = sorted(times)
    yvals = np.arange(len(times))/float(len(times))
    plt.figure(figsize=(3,2))
    plt.grid()
    plt.xscale("log")
    plt.yticks([0.2 * i for i in range(6)])
    plt.ylim(0,1)
    plt.xlabel(" Total Verification Time (s)")
    plt.ylabel("CDF")
    plt.xticks([0.1, 1, 10], ['0.1', '1', '10'])
    plt.plot(times, yvals, color='C0')
    plt.savefig("timing_plot.pdf", format="pdf", bbox_inches='tight', dpi=600, pad_inches=0)
    plt.clf()

print_stats(success)
plot_cdf(times)
-------------------------------------------------------------------------------- /rules/xla/reduce/Main.hs: --------------------------------------------------------------------------------
-- | Verified XLA rewrite rules involving @Reduce@.
module Main (main) where

import Grisette hiding (dot, (-->))
import TensorRight

-- | Mul(Reduce(X), Reduce(Y)) ⇒ Reduce(Mul(Broadcast(X), Broadcast(Y))).
-- X and Y live on disjoint rclasses; the RHS broadcasts each tensor onto
-- the other's rclass before a joint reduction over both.
rule01 :: forall a. NumRule a
rule01 _ = do
  [rclass0, rclass1] <- newRClasses ["rclass0", "rclass1"]
  [rc0Size, rc0LhsSi, rc0RhsSi] <-
    newMaps ["rclass0size", "rc0LhsSi", "rc0RhsSi"] rclass0
  [rc1Size, rc1LhsSi, rc1RhsSi] <-
    newMaps ["rclass1size", "rc1LhsSi", "rc1RhsSi"] rclass1
  tX <- newTensor @a "X" [rclass0 --> rc0Size]
  tY <- newTensor @a "Y" [rclass1 --> rc1Size]
  lhs <-
    numBinOp
      Mul
      (reduce tX [rclass0 --> rc0LhsSi])
      (reduce tY [rclass1 --> rc1LhsSi])
  rhs <-
    reduce
      ( numBinOp
          Mul
          (broadcast tX [rclass1 --> rc1Size])
          (broadcast tY [rclass0 --> rc0Size])
      )
      [rclass0 --> rc0RhsSi, rclass1 --> rc1RhsSi]
  -- LHS and RHS reductions range over the same indices.
  siRelation [rc0LhsSi, rc0RhsSi] $
    \[l, r] -> l .== r
  siRelation [rc1LhsSi, rc1RhsSi] $
    \[l, r] -> l .== r
  checkSIMap [rc0LhsSi, rc1LhsSi] [rc0RhsSi, rc1RhsSi]
  rewrite
    "Mul(Reduce(X), Reduce(Y)) ⇒ Reduce(Mul(Broadcast(X), Broadcast(Y)))"
    lhs
    rhs

-- | Reduce(Concat(Reduce(X), Reduce(Y))) ⇒ Reduce(Concat(X,Y)).
-- The inner reductions collapse each operand to a singleton along the
-- concat rclass; the outer reduction over the 2-element concat equals one
-- reduction over the full concatenation, with the RHS index shifted by
-- concatSize0 when it falls in Y's half.
rule02 :: forall a. NumRule a
rule02 _ = do
  [concatRClass, otherRClass] <- newRClasses ["concatRClass", "otherRClass"]
  [concatSize0, concatSize1] <-
    newMaps ["concatSize0", "concatSize1"] concatRClass
  singletonConcat <- newConstMap "singletonConcat" 1 concatRClass
  [otherSize] <- newMaps ["otherSize"] otherRClass
  [lhsSIInnerX, lhsSIInnerY, lhsSIOuter, rhsSI] <-
    newMaps ["lhsSIInnerX", "lhsSIInnerY", "lhsSIOuter", "rhsSI"] concatRClass
  tX <- newTensor @a "X" [concatRClass --> concatSize0, otherRClass --> otherSize]
  tY <- newTensor @a "Y" [concatRClass --> concatSize1, otherRClass --> otherSize]
  lhs <-
    reduce
      ( concatTensor
          ( broadcast
              (reduce tX [concatRClass --> lhsSIInnerX])
              [concatRClass --> singletonConcat]
          )
          ( broadcast
              (reduce tY [concatRClass --> lhsSIInnerY])
              [concatRClass --> singletonConcat]
          )
          (ByRClass concatRClass)
      )
      [concatRClass --> lhsSIOuter]
  rhs <- reduce (concatTensor tX tY (ByRClass concatRClass)) [concatRClass --> rhsSI]
  monitorMapOnFailure "concatSize0" (ByRClass concatRClass) concatSize0
  monitorMapOnFailure "concatSize1" (ByRClass concatRClass) concatSize1
  monitorMapOnFailure "lhsSIInnerX" (ByRClass concatRClass) lhsSIInnerX
  monitorMapOnFailure "lhsSIInnerY" (ByRClass concatRClass) lhsSIInnerY
  monitorMapOnFailure "lhsSIOuter" (ByRClass concatRClass) lhsSIOuter

  siRelation [lhsSIInnerX, lhsSIInnerY, lhsSIOuter, rhsSI, concatSize0, concatSize1] $
    \[lhsSIInnerX, lhsSIInnerY, lhsSIOuter, rhsSI, concatSize0, concatSize1] ->
      (lhsSIInnerX .>= 0 .&& lhsSIInnerX .< concatSize0)
        .&& (lhsSIInnerY .>= 0 .&& lhsSIInnerY .< concatSize1)
        .&& (lhsSIOuter .== 0 .|| lhsSIOuter .== 1)
        .&& symIte
          (lhsSIOuter .== 0)
          (lhsSIInnerX .== rhsSI)
          (concatSize0 + lhsSIInnerY .== rhsSI)
  checkSIMap [lhsSIInnerX, lhsSIInnerY, lhsSIOuter] [rhsSI]
  rewrite "Reduce(Concat(Reduce(X), Reduce(Y))) ⇒ Reduce(Concat(X,Y))" lhs rhs

-- | Const * Reduce(X) ⇒ Reduce(Const * X): a constant factor that does not
-- depend on the reduction rclass can be pulled inside the reduction.
rule03 :: forall a. NumRule a
rule03 _ = do
  [reductionRClass, rclass1] <- newRClasses ["reductionRClass", "rclass1"]
  [reductionSize] <- newMaps ["reductionSize"] reductionRClass
  [lhsSI, rhsSI] <- newMaps ["lhsSI", "rhsSI"] reductionRClass
  [otherSize] <- newMaps ["otherSize"] rclass1
  x <- newTensor @a "x" [reductionRClass --> reductionSize, rclass1 --> otherSize]
  lhs <-
    numBinOp
      Mul
      (constant @a "a" [rclass1 --> otherSize])
      (reduce x [reductionRClass --> lhsSI])
  rhs <-
    reduce
      ( numBinOp
          Mul
          ( constant @a
              "a"
              [reductionRClass --> reductionSize, rclass1 --> otherSize]
          )
          x
      )
      [reductionRClass --> rhsSI]
  siRelation [lhsSI, rhsSI] $ \[lhsSI, rhsSI] -> lhsSI .== rhsSI
  checkSIMap [lhsSI] [rhsSI]
  rewrite "Const * Reduce(X) ⇒ Reduce(Const * X)" lhs rhs

-- | Reduce(Relabel(X)) ⇒ Relabel(Reduce(X)): reduction commutes with a
-- relabelling that swaps the two labels l0 and l1.
rule04 :: forall a. NumRule a
rule04 _ = do
  [reductionRClass, nonReductionRClass] <- newRClasses ["reductionRClass", "nonReductionRClass"]
  [reductionSize] <-
    newMaps ["reductionSize"] reductionRClass
  [otherSize] <-
    newMaps ["otherSize"] nonReductionRClass
  [lhsSI, rhsSI] <- newMaps ["lhsSI", "rhsSI"] reductionRClass
  x <-
    newTensor @a
      "x"
      [ reductionRClass --> reductionSize @@ "l0",
        nonReductionRClass --> otherSize @@ "l1"
      ]
  lhs <-
    reduce
      ( relabel
          x
          [ ByLabel "l0" --> ByLabel "l1",
            ByLabel "l1" --> ByLabel "l0"
          ]
      )
      [ByLabel "l1" --> lhsSI]
  rhs <-
    relabel
      (reduce x [ByLabel "l0" --> rhsSI])
      [ByLabel "l1" --> ByLabel "l0"]
  siRelation [lhsSI, rhsSI] $ \[lhsSI, rhsSI] -> lhsSI .== rhsSI
  checkSIMap [lhsSI] [rhsSI]
  rewrite "Reduce(Relabel(X)) ⇒ Relabel(Reduce(X))" lhs rhs

-- | Reduce(Reduce(X)) ⇒ Reduce(X): two nested reductions over disjoint
-- rclasses fuse into one joint reduction.
rule05 :: forall a. NumRule a
rule05 _ = do
  [rclass0, rclass1, rclass2] <- newRClasses ["rclass0", "rclass1", "rclass2"]
  [rc0Size, rc0LhsSi, rc0RhsSi] <-
    newMaps ["rc0Size", "rc0LhsSi", "rc0RhsSi"] rclass0
  [rc1Size, rc1LhsSi, rc1RhsSi] <-
    newMaps ["rc1Size", "rc1LhsSi", "rc1RhsSi"] rclass1
  rc2Size <- newMap "rc2Size" rclass2
  tX <-
    newTensor @a
      "X"
      [rclass0 --> rc0Size, rclass1 --> rc1Size, rclass2 --> rc2Size]
  lhs <- reduce (reduce tX [rclass0 --> rc0LhsSi]) [rclass1 --> rc1LhsSi]
  rhs <- reduce tX [rclass0 --> rc0RhsSi, rclass1 --> rc1RhsSi]
  siRelation [rc0LhsSi, rc0RhsSi] $
    \[l, r] -> l .== r
  siRelation [rc1LhsSi, rc1RhsSi] $
    \[l, r] -> l .== r
  checkSIMap [rc0LhsSi, rc1LhsSi] [rc0RhsSi, rc1RhsSi]
  rewrite "Reduce(Reduce(X)) ⇒ Reduce(X)" lhs rhs

-- | Reduce(Dot(X,Y)) ⇒ Dot(X,Y): reducing a batch rclass of a dot product
-- equals making that rclass a contraction rclass of the dot.
rule06 :: forall a. NumRule a
rule06 _ = do
  [rclass0, rclass1, rclass2, rclass3] <- newRClasses ["rclass0", "rclass1", "rclass2", "rclass3"]
  [rc0Size, rc0LhsSi, rc0RhsSi] <-
    newMaps ["rc0Size", "rc0LhsSi", "rc0RhsSi"] rclass0
  [rc1Size, rc1LhsSi, rc1RhsSi] <-
    newMaps ["rc1Size", "rc1LhsSi", "rc1RhsSi"] rclass1
  rc2Size <- newMap "rc2Size" rclass2
  rc3Size <- newMap "rc3Size" rclass3
  x <-
    newTensor @a
      "x"
      [rclass0 --> rc0Size, rclass1 --> rc1Size, rclass2 --> rc2Size]
  y <-
    newTensor @a
      "y"
      [rclass0 --> rc0Size, rclass1 --> rc1Size, rclass3 --> rc3Size]
  lhs <-
    reduce
      (dot x y [rclass0 --> rc0LhsSi] [ByRClass rclass1])
      [rclass1 --> rc1LhsSi]
  rhs <- dot x y [rclass0 --> rc0RhsSi, rclass1 --> rc1RhsSi] []
  siRelation [rc0LhsSi, rc0RhsSi] $
    \[l, r] -> l .== r
  siRelation [rc1LhsSi, rc1RhsSi] $
    \[l, r] -> l .== r
  checkSIMap [rc0LhsSi, rc1LhsSi] [rc0RhsSi, rc1RhsSi]
  rewrite "Reduce(Dot(X,Y)) ⇒ Dot(X,Y)" lhs rhs

-- | Reduce(X) ⇒ ReshapeDegenerate(X): reducing over a degenerate
-- (size-1) rclass only drops that rclass.
rule07 :: forall a. NumRule a
rule07 _ = do
  [rclassDegenerate, rclass1] <- newRClasses ["rclassDegenerate", "rclass1"]
  rclassDegenerateSize <- newConstMap "rclassDegenerateSize" 1 rclassDegenerate
  rclass1Size <- newMap "rclass1Size" rclass1
  x <- newTensor @a "x" [rclassDegenerate --> rclassDegenerateSize, rclass1 --> rclass1Size]
  siDegenerate <- newMap "siDegenerate" rclassDegenerate

  lhs <- reduce x [rclassDegenerate --> siDegenerate]
  rhs <- reshapeDegenerate x [] [ByRClass rclassDegenerate]

  -- The only valid index into a size-1 rclass is 0.
  siRelation [siDegenerate] $ \[siDegenerate] -> siDegenerate .== 0
  checkSIMap [siDegenerate] []
  rewrite "Reduce(X) ⇒ ReshapeDegenerate(X)" lhs rhs

-- | Like 'rule07', but the degenerate rclass is relabelled before the
-- reduction.
rule08 :: forall a. NumRule a
rule08 _ = do
  [rclassDegenerate, rclass1] <- newRClasses ["rclassDegenerate", "rclass1"]
  rclassDegenerateSize <- newConstMap "rclassDegenerateSize" 1 rclassDegenerate
  rclass1Size <- newMap "rclass1Size" rclass1
  x <- newTensor @a "x" [rclassDegenerate --> rclassDegenerateSize, rclass1 --> rclass1Size]
  siDegenerate <- newMap "siDegenerate" rclassDegenerate

  lhs <-
    reduce
      (relabel x [rclassDegenerate --> ByLabel "l0"])
      [ByLabel "l0" --> siDegenerate]
  rhs <- reshapeDegenerate x [] [ByRClass rclassDegenerate]

  siRelation [siDegenerate] $ \[siDegenerate] -> siDegenerate .== 0
  checkSIMap [siDegenerate] []
  -- Fixed: this rule previously reused rule07's label even though its LHS
  -- reduces a relabelled tensor, making the two rules indistinguishable in
  -- the verification log.
  rewrite "Reduce(Relabel(X)) ⇒ ReshapeDegenerate(X)" lhs rhs

-- | Verify all rules in order.  Note: 'print' quotes the banner strings;
-- 'putStrLn' would print them bare — TODO confirm nothing parses the
-- quoted form before changing it.
main :: IO ()
main = do
  print "############################## rule01 ##############################"
  verifyNumDSL rule01
  print "############################## rule02 ##############################"
  verifyNumDSL rule02
  print "############################## rule03 ##############################"
  verifyNumDSL rule03
  print "############################## rule04 ##############################"
  verifyNumDSL rule04
  print "############################## rule05 ##############################"
  verifyNumDSL rule05
  print "############################## rule06 ##############################"
  verifyNumDSL rule06
  print "############################## rule07 ##############################"
  verifyNumDSL rule07
  print "############################## rule08 ##############################"
  verifyNumDSL rule08
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/DSL/BoundInference.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DataKinds #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE DerivingVia #-} 4 | {-# LANGUAGE FlexibleContexts #-} 5 | {-# LANGUAGE GADTs #-} 6 | {-# LANGUAGE OverloadedStrings #-} 7 | {-# LANGUAGE PatternSynonyms #-} 8 | {-# LANGUAGE RankNTypes #-} 9 | {-# LANGUAGE RecordWildCards #-} 10 | {-# LANGUAGE ScopedTypeVariables #-} 11 | {-# LANGUAGE TupleSections #-} 12 | 13 | module TensorRight.Internal.DSL.BoundInference (inferBound) where 14 | 15 | import Control.Monad (when) 16 | import Control.Monad.State (State, execState, modify) 17 | import Data.Function (on) 18 | import qualified Data.HashMap.Lazy as HM 19 | import qualified Data.HashSet as HS 20 | import Data.List (groupBy, sortOn) 21 | import Data.List.NonEmpty (NonEmpty ((:|))) 22 | import Data.String (IsString (fromString)) 23 | import qualified Data.Text as T 24 | import Data.Typeable (cast) 25 | import GHC.Generics (Generic) 26 | 
import Grisette 27 | ( Default (Default), 28 | GrisetteSMTConfig, 29 | Identifier (Identifier), 30 | LinkedRep (underlyingTerm), 31 | LogicalOp ((.&&)), 32 | PPrint, 33 | SExpr (List), 34 | Solvable (con, ssym), 35 | SolvingFailure (Unsat), 36 | SomeTerm (SomeTerm), 37 | SymBool (SymBool), 38 | SymEq ((./=)), 39 | Symbol (SimpleSymbol), 40 | Term, 41 | TypedSymbol (TypedSymbol, unTypedSymbol), 42 | solve, 43 | someTermSize, 44 | withMetadata, 45 | pattern ApplyTerm, 46 | pattern ConTerm, 47 | pattern DistinctTerm, 48 | pattern Metadata, 49 | pattern SubTerms, 50 | pattern SymTerm, 51 | ) 52 | import TensorRight.Internal.Core.Axis 53 | ( Axis (Axis, LabelledAxis), 54 | AxisMapLike (fromHashMap), 55 | Indices, 56 | ) 57 | import TensorRight.Internal.Core.Tensor (tensorDType) 58 | import TensorRight.Internal.Core.Verify 59 | ( VerifyTask (VerifyTask), 60 | getTensorWithValidityCondition, 61 | rewritingRuleAccess, 62 | ) 63 | import TensorRight.Internal.DSL.Eval 64 | ( SymIdentInfo (SymMap, SymTensor), 65 | getAxisName, 66 | ) 67 | import TensorRight.Internal.DSL.Identifier (RClassIdentifier, TensorIdentifier) 68 | import TensorRight.Internal.DSL.Shape 69 | ( AbstractShape 70 | ( AbstractShape, 71 | labelled, 72 | unlabelled 73 | ), 74 | ) 75 | 76 | newtype AnalysisState = AnalysisState 77 | { termRClassTensors :: 78 | HM.HashMap 79 | SomeTerm 80 | (HS.HashSet RClassIdentifier, HS.HashSet TensorIdentifier) 81 | } 82 | deriving (Show, Generic) 83 | deriving (PPrint) via (Default AnalysisState) 84 | 85 | hasRClass :: AnalysisState -> RClassIdentifier -> SomeTerm -> Bool 86 | hasRClass AnalysisState {..} rclass st = 87 | case HM.lookup st termRClassTensors of 88 | Just (rclasses, _) -> HS.member rclass rclasses 89 | Nothing -> error "Term not found" 90 | 91 | analysisTerm :: SomeTerm -> AnalysisState 92 | analysisTerm someTerm = 93 | execState (analysisTermState someTerm) (AnalysisState HM.empty) 94 | 95 | analysisTermState :: 96 | SomeTerm -> 97 | State AnalysisState 
(HS.HashSet RClassIdentifier, HS.HashSet TensorIdentifier) 98 | analysisTermState someTerm = do 99 | (rclasses, tensors) <- analysisTermState' someTerm 100 | modify 101 | ( \s -> 102 | s 103 | { termRClassTensors = 104 | HM.insert someTerm (rclasses, tensors) (termRClassTensors s) 105 | } 106 | ) 107 | return (rclasses, tensors) 108 | 109 | analysisTermState' :: 110 | SomeTerm -> 111 | State AnalysisState (HS.HashSet RClassIdentifier, HS.HashSet TensorIdentifier) 112 | analysisTermState' (SomeTerm t) = do 113 | case t of 114 | ConTerm {} -> return (HS.empty, HS.empty) 115 | SymTerm symb -> case unTypedSymbol symb of 116 | SimpleSymbol (Identifier _ (List [])) -> 117 | return (HS.empty, HS.empty) 118 | SimpleSymbol (Identifier _ meta) -> 119 | case meta of 120 | Metadata (SymMap rclass _ _) -> do 121 | return (HS.singleton rclass, HS.empty) 122 | Metadata (SymTensor t) -> do 123 | return (HS.empty, HS.singleton t) 124 | _ -> error $ "Unexpected metadata: " <> show meta 125 | _ -> return (HS.empty, HS.empty) 126 | SubTerms ts -> do 127 | r <- traverse analysisTermState ts 128 | return $ mconcat r 129 | 130 | getAllConditions :: AnalysisState -> SomeTerm -> HS.HashSet SomeTerm 131 | getAllConditions state@AnalysisState {..} st@(SomeTerm t) = 132 | case HM.lookup st termRClassTensors of 133 | Just (rclasses, _) | HS.null rclasses -> HS.empty 134 | Just (_, tensors) | HS.null tensors -> 135 | case cast t of 136 | Just (_ :: Term Bool) -> HS.singleton st 137 | Nothing -> goBody t 138 | Just _ -> goBody t 139 | Nothing -> error "Term not found" 140 | where 141 | goSome :: SomeTerm -> HS.HashSet SomeTerm 142 | goSome = getAllConditions state 143 | goBody :: forall a. 
Term a -> HS.HashSet SomeTerm 144 | goBody (ConTerm _) = error "Should not happen" 145 | goBody (SymTerm _) = error "Should not happen" 146 | goBody ApplyTerm {} = HS.empty 147 | goBody (SubTerms ts) = mconcat $ goSome <$> ts 148 | 149 | getAllAccesses :: AnalysisState -> SomeTerm -> HS.HashSet SomeTerm 150 | getAllAccesses state@AnalysisState {..} st@(SomeTerm t) = 151 | case t of 152 | ConTerm _ -> HS.empty 153 | SymTerm _ -> HS.empty 154 | ApplyTerm {} -> HS.singleton st 155 | SubTerms ts -> mconcat $ go <$> ts 156 | where 157 | go :: SomeTerm -> HS.HashSet SomeTerm 158 | go = getAllAccesses state 159 | 160 | conditionEquivalent :: 161 | GrisetteSMTConfig -> SymBool -> SomeTerm -> SomeTerm -> IO Bool 162 | conditionEquivalent solverConfig allPreCond (SomeTerm l) (SomeTerm r) = 163 | case (cast l, cast r) of 164 | (Just (a11 :: Term Bool), Just (a12 :: Term Bool)) -> do 165 | r <- 166 | solve solverConfig (allPreCond .&& SymBool a11 ./= SymBool a12) 167 | case r of 168 | Left Unsat -> return True 169 | Left err -> fail $ "Unexpected solver failure: " <> show err 170 | Right _ -> return False 171 | _ -> error "Not conditions" 172 | 173 | accessEquivalent :: 174 | GrisetteSMTConfig -> SymBool -> SomeTerm -> SomeTerm -> IO Bool 175 | accessEquivalent solverConfig allPreCond (SomeTerm l) (SomeTerm r) = 176 | case (l, r) of 177 | (ApplyTerm (f1 :: Term f1) (args1 :: Term args1), ApplyTerm f2 args2) -> 178 | case (cast f2, cast args2) of 179 | (Just (f2' :: Term f1), _) | f1 /= f2' -> return False 180 | (Just (_ :: Term f1), Just (args2' :: Term args1)) -> do 181 | r <- 182 | solve 183 | solverConfig 184 | (allPreCond .&& SymBool (DistinctTerm (args1 :| [args2']))) 185 | case r of 186 | Left Unsat -> return True 187 | Left err -> fail $ "Unexpected solver failure: " <> show err 188 | Right _ -> return False 189 | _ -> return False 190 | _ -> error "Not accesses" 191 | 192 | filterPairs :: 193 | (SomeTerm -> SomeTerm -> IO Bool) -> 194 | HS.HashSet SomeTerm -> 195 | IO 
(HS.HashSet SomeTerm) 196 | filterPairs eqv conditions = do 197 | HS.fromList <$> go (sortOn (\x -> -someTermSize x) $ HS.toList conditions) 198 | where 199 | go [] = return [] 200 | go (x : xs) = do 201 | r <- go1 x xs 202 | t <- go xs 203 | if r 204 | then return t 205 | else return (x : t) 206 | go1 _ [] = return False 207 | go1 x (y : ys) = do 208 | r <- eqv x y 209 | rs <- go1 x ys 210 | return $ r || rs 211 | 212 | groupAccessByTensors :: [SomeTerm] -> [[SomeTerm]] 213 | groupAccessByTensors = groupBy (on (==) termTensor) . sortOn termTensor 214 | where 215 | termTensor :: SomeTerm -> TensorIdentifier 216 | termTensor 217 | ( SomeTerm 218 | ( ApplyTerm 219 | ( SymTerm 220 | ( TypedSymbol 221 | (SimpleSymbol (Identifier _ (Metadata (SymTensor t)))) 222 | ) 223 | ) 224 | _ 225 | ) 226 | ) = t 227 | termTensor _ = error "Should not happen" 228 | 229 | inferBound :: 230 | GrisetteSMTConfig -> 231 | VerifyTask -> 232 | HS.HashSet RClassIdentifier -> 233 | HS.HashSet RClassIdentifier -> 234 | AbstractShape -> 235 | IO (HM.HashMap RClassIdentifier Int) 236 | inferBound 237 | solverConfig 238 | (VerifyTask _ lhs rhs pre siRelation _ _ _ _ _ _ _ _ _) 239 | nonSingletonRClasses 240 | singletonRClasses 241 | sp = do 242 | let preCond = pre 243 | when (preCond == con False) $ 244 | fail "verified (precondition is false)" 245 | 246 | (lhsTensorIsValid, lhsTensor) <- 247 | getTensorWithValidityCondition "lhs-tensor" lhs 248 | (rhsTensorIsValid, rhsTensor) <- 249 | getTensorWithValidityCondition "rhs-tensor" rhs 250 | 251 | when (tensorDType lhsTensor /= tensorDType rhsTensor) $ 252 | fail "not verified (lhs and rhs have different types)" 253 | let access = abstractShapeAccess sp 254 | 255 | (lhsAccessIsValid, rhsAccessIsValid, equivalent) <- 256 | rewritingRuleAccess lhsTensor rhsTensor access 257 | let st = analysisTerm $ SomeTerm $ underlyingTerm equivalent 258 | let allPreCond = 259 | preCond 260 | .&& siRelation 261 | .&& lhsTensorIsValid 262 | .&& rhsTensorIsValid 263 
| .&& lhsAccessIsValid 264 | .&& rhsAccessIsValid 265 | let allConditions = 266 | getAllConditions st $ 267 | SomeTerm $ 268 | underlyingTerm equivalent 269 | filteredConditions <- 270 | filterPairs (conditionEquivalent solverConfig allPreCond) allConditions 271 | let allAccesses = getAllAccesses st $ SomeTerm $ underlyingTerm equivalent 272 | filteredAccesses <- filterPairs (accessEquivalent solverConfig allPreCond) allAccesses 273 | putStrLn $ "# all conditions: " <> show (HS.size allConditions) 274 | putStrLn $ "# all accesses: " <> show (HS.size allAccesses) 275 | putStrLn $ "# filtered conditions: " <> show (HS.size filteredConditions) 276 | putStrLn $ "# filtered accesses: " <> show (HS.size filteredAccesses) 277 | let groupedAccesses = groupAccessByTensors $ HS.toList filteredAccesses 278 | 279 | let numHasRClassInGroup :: RClassIdentifier -> [SomeTerm] -> Int 280 | numHasRClassInGroup rclass = length . filter (hasRClass st rclass) 281 | factorial :: Int -> Int 282 | factorial 0 = 1 283 | factorial n = n * factorial (n - 1) 284 | kFromAccess :: RClassIdentifier -> [SomeTerm] -> Int 285 | kFromAccess rclass group = case numHasRClassInGroup rclass group of 286 | v | v < 2 -> 0 287 | v -> factorial v `div` 2 `div` factorial (v - 2) 288 | 289 | let kFromAllAccesses rclass = sum $ kFromAccess rclass <$> groupedAccesses 290 | let kForRClass rclass = 291 | max 1 $ 292 | kFromAllAccesses rclass 293 | + numHasRClassInGroup rclass (HS.toList filteredConditions) 294 | return $ 295 | HM.fromList $ 296 | ( (\rclass -> (rclass, kForRClass rclass)) 297 | <$> HS.toList (nonSingletonRClasses `HS.difference` singletonRClasses) 298 | ) 299 | <> ((,1) <$> HS.toList singletonRClasses) 300 | 301 | abstractShapeAccess :: AbstractShape -> Indices 302 | abstractShapeAccess AbstractShape {..} = do 303 | fromHashMap $ unlabelledRClassAccesses <> labelledRClassAccess 304 | where 305 | unlabelledRClassAccesses = 306 | HM.fromList $ 307 | ( \rclass -> 308 | ( Axis $ getAxisName rclass 0, 
309 | ssym $ 310 | withMetadata "access" $ 311 | SymMap rclass 0 "#accnolabel" 312 | ) 313 | ) 314 | <$> HS.toList unlabelled 315 | labelledRClassAccess = 316 | HM.fromList $ 317 | ( \(label, rclass) -> 318 | ( LabelledAxis label $ getAxisName rclass 0, 319 | ssym $ 320 | withMetadata "access" $ 321 | SymMap rclass 0 $ 322 | fromString $ 323 | T.unpack label 324 | ) 325 | ) 326 | <$> HM.toList labelled 327 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/Core/Tensor/TensorInt.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ConstraintKinds #-} 2 | {-# LANGUAGE DeriveAnyClass #-} 3 | {-# LANGUAGE DeriveGeneric #-} 4 | {-# LANGUAGE DerivingVia #-} 5 | {-# LANGUAGE FlexibleContexts #-} 6 | {-# LANGUAGE GeneralizedNewtypeDeriving #-} 7 | {-# LANGUAGE MonoLocalBinds #-} 8 | {-# LANGUAGE OverloadedStrings #-} 9 | {-# LANGUAGE PatternSynonyms #-} 10 | {-# LANGUAGE ScopedTypeVariables #-} 11 | {-# LANGUAGE StandaloneDeriving #-} 12 | {-# LANGUAGE TemplateHaskell #-} 13 | {-# LANGUAGE TypeApplications #-} 14 | {-# LANGUAGE TypeOperators #-} 15 | {-# LANGUAGE UndecidableInstances #-} 16 | 17 | module TensorRight.Internal.Core.Tensor.TensorInt 18 | ( TensorNum (..), 19 | TensorDivMod (..), 20 | TensorExp (..), 21 | TensorInt, 22 | TensorReal, 23 | nonInf, 24 | posInf, 25 | negInf, 26 | tensorValSymMax, 27 | tensorValSymMin, 28 | tensorValEq, 29 | tensorValNe, 30 | tensorValLt, 31 | tensorValGt, 32 | tensorValLe, 33 | tensorValGe, 34 | IsTensorNum, 35 | ) 36 | where 37 | 38 | import Data.Hashable (Hashable) 39 | import Data.String (IsString (fromString)) 40 | import Grisette 41 | ( DivOr (divOr, modOr, remOr), 42 | EvalSym, 43 | FdivOr (fdivOr), 44 | ITEOp (symIte), 45 | LogicalOp (symNot, symXor, (.&&), (.||)), 46 | Mergeable, 47 | PPrint (pformat), 48 | SimpleMergeable (mrgIte), 49 | Solvable (con), 50 | SymAlgReal, 51 | SymBool, 52 | SymEq ((.==)), 53 | 
SymInteger, 54 | SymOrd ((.<), (.>)), 55 | Union, 56 | deriveGADT, 57 | liftUnion, 58 | mrgFmap, 59 | mrgIf, 60 | mrgReturn, 61 | symMax, 62 | symMin, 63 | pattern Con, 64 | ) 65 | import Grisette.Lib.Control.Monad.Except (mrgThrowError) 66 | import TensorRight.Internal.Util.Error (ErrorEnv) 67 | 68 | data TensorNumBase a 69 | = NonInf a 70 | | -- | Positive or negative infinity 71 | Inf SymBool 72 | | Unknown 73 | 74 | deriveGADT 75 | [''TensorNumBase] 76 | [ ''Show, 77 | ''Mergeable, 78 | ''Eq, 79 | ''Hashable, 80 | ''EvalSym 81 | ] 82 | 83 | -- | A tensor of numerical values that may contain positive or negative infinity. 84 | newtype TensorNum a = TensorNum (Union (TensorNumBase a)) 85 | 86 | type TensorInt = TensorNum SymInteger 87 | 88 | type TensorReal = TensorNum SymAlgReal 89 | 90 | instance (PPrint a) => PPrint (TensorNumBase a) where 91 | pformat (NonInf u) = pformat u 92 | pformat (Inf u) = case u of 93 | Con v -> if v then "inf" else "-inf" 94 | _ -> "?inf" 95 | pformat Unknown = "unk" 96 | 97 | instance (SymEq a) => SymEq (TensorNumBase a) where 98 | NonInf a .== NonInf b = a .== b 99 | Inf a .== Inf b = a .== b 100 | Unknown .== Unknown = con False 101 | _ .== _ = con False 102 | 103 | deriveGADT 104 | [''TensorNum] 105 | [ ''Show, 106 | ''SymEq, 107 | ''Eq, 108 | ''Hashable, 109 | ''EvalSym, 110 | ''PPrint 111 | ] 112 | 113 | deriving newtype instance (Mergeable a) => Mergeable (TensorNum a) 114 | 115 | deriving newtype instance (SimpleMergeable a) => SimpleMergeable (TensorNum a) 116 | 117 | instance (Mergeable a) => ITEOp (TensorNum a) where 118 | symIte c (TensorNum l) (TensorNum r) = TensorNum $ mrgIte c l r 119 | 120 | -- | Wrap a symbolic integer into a tensor numerical value. 121 | nonInf :: (Mergeable a) => a -> TensorNum a 122 | nonInf = TensorNum . mrgReturn . NonInf 123 | 124 | -- | Positive infinity. 125 | posInf :: (Mergeable a) => TensorNum a 126 | posInf = TensorNum $ mrgReturn $ Inf (con True) 127 | 128 | -- | Negative infinity. 
negInf :: (Mergeable a) => TensorNum a
negInf = TensorNum $ mrgReturn $ Inf (con False)

-- | Arithmetic on extended numerical values.
--
-- An infinity is represented as @'Inf' sign@, where a true sign bit means
-- @+inf@ and a false sign bit means @-inf@ (see 'posInf' and 'negInf').
-- Indeterminate combinations (e.g. @inf - inf@, @0 * inf@) collapse to
-- 'Unknown'.
instance (Mergeable a, Num a, SymEq a, SymOrd a, ITEOp a) => Num (TensorNum a) where
  (TensorNum l) + (TensorNum r) = TensorNum $ do
    l1 <- l
    r1 <- r
    case (l1, r1) of
      (NonInf lv, NonInf rv) -> mrgReturn $ NonInf $ lv + rv
      -- inf + inf is only defined when both infinities have the same sign.
      (Inf lv, Inf rv) ->
        mrgIf (lv .== rv) (mrgReturn $ Inf lv) (mrgReturn Unknown)
      -- A finite addend never changes an infinity.
      (NonInf _, Inf rv) -> mrgReturn $ Inf rv
      (Inf lv, NonInf _) -> mrgReturn $ Inf lv
      _ -> mrgReturn Unknown
  (TensorNum l) * (TensorNum r) = TensorNum $ do
    l1 <- l
    r1 <- r
    case (l1, r1) of
      (NonInf lv, NonInf rv) -> mrgReturn $ NonInf $ lv * rv
      -- inf * inf: the result's sign is the XOR-complement of the signs.
      (Inf lv, Inf rv) ->
        mrgIf
          (symXor lv rv)
          (mrgReturn $ Inf (con False))
          (mrgReturn $ Inf (con True))
      -- 0 * inf is indeterminate; otherwise the finite factor contributes
      -- only its sign.
      (NonInf lv, Inf rv) ->
        mrgIf
          (lv .== 0)
          (mrgReturn Unknown)
          (mrgIf (lv .> 0) (mrgReturn $ Inf rv) (mrgReturn $ Inf (symNot rv)))
      (Inf lv, NonInf rv) ->
        mrgIf
          (rv .== 0)
          (mrgReturn Unknown)
          (mrgIf (rv .> 0) (mrgReturn $ Inf lv) (mrgReturn $ Inf (symNot lv)))
      _ -> mrgReturn Unknown
  abs (TensorNum l) = TensorNum $ do
    l1 <- l
    case l1 of
      NonInf lv -> mrgReturn $ NonInf $ abs lv
      -- |+inf| = |-inf| = +inf, regardless of the sign bit.
      Inf _ -> mrgReturn $ Inf (con True)
      Unknown -> mrgReturn Unknown
  signum (TensorNum l) = TensorNum $ do
    l1 <- l
    case l1 of
      NonInf lv -> mrgReturn $ NonInf $ signum lv
      Inf lv -> mrgReturn $ NonInf $ symIte lv 1 (-1)
      Unknown -> mrgReturn Unknown
  fromInteger i = TensorNum $ mrgReturn $ NonInf $ fromInteger i
  negate (TensorNum l) = TensorNum $ do
    l1 <- l
    case l1 of
      NonInf lv -> mrgReturn $ NonInf $ negate lv
      -- BUG FIX: negation must flip the sign of an infinity. The previous
      -- implementation returned @Inf lv@ unchanged, so @negate posInf@
      -- evaluated to @posInf@, violating @negate . negate == id@ relative to
      -- the sign encoding used by 'signum' and '(*)' above.
      Inf lv -> mrgReturn $ Inf $ symNot lv
      Unknown -> mrgReturn Unknown

-- | Computes the maximum of two tensor numerical values.
tensorValSymMax ::
  (Mergeable a, SymOrd a, ITEOp a) => TensorNum a -> TensorNum a -> TensorNum a
tensorValSymMax (TensorNum lu) (TensorNum ru) = TensorNum $ do
  lhs <- lu
  rhs <- ru
  case (lhs, rhs) of
    (NonInf a, NonInf b) -> mrgReturn $ NonInf $ symMax a b
    -- The maximum of two infinities is +inf iff at least one is +inf.
    (Inf a, Inf b) -> mrgReturn $ Inf $ a .|| b
    -- +inf dominates any finite value; -inf is dominated by it.
    (NonInf a, Inf b) -> mrgIf b (mrgReturn $ Inf b) (mrgReturn $ NonInf a)
    (Inf a, NonInf b) -> mrgIf a (mrgReturn $ Inf a) (mrgReturn $ NonInf b)
    _ -> mrgReturn Unknown

-- | Computes the minimum of two tensor numerical values.
tensorValSymMin ::
  (Mergeable a, SymOrd a, ITEOp a) => TensorNum a -> TensorNum a -> TensorNum a
tensorValSymMin (TensorNum lu) (TensorNum ru) = TensorNum $ do
  lhs <- lu
  rhs <- ru
  case (lhs, rhs) of
    (NonInf a, NonInf b) -> mrgReturn $ NonInf $ symMin a b
    -- The minimum of two infinities is +inf iff both are +inf.
    (Inf a, Inf b) -> mrgReturn $ Inf $ a .&& b
    -- -inf dominates any finite value; +inf is dominated by it.
    (NonInf a, Inf b) -> mrgIf b (mrgReturn $ NonInf a) (mrgReturn $ Inf b)
    (Inf a, NonInf b) -> mrgIf a (mrgReturn $ NonInf b) (mrgReturn $ Inf a)
    _ -> mrgReturn Unknown

-- | Checks if two tensor numerical values are equal.
--
-- Comparing an 'Unknown' value with anything is rejected with an error in
-- 'ErrorEnv'; a finite value never equals an infinity.
tensorValEq :: (Mergeable a, SymEq a) => TensorNum a -> TensorNum a -> ErrorEnv SymBool
tensorValEq (TensorNum lu) (TensorNum ru) = do
  lhs <- liftUnion lu
  rhs <- liftUnion ru
  case (lhs, rhs) of
    (NonInf a, NonInf b) -> return $ a .== b
    -- Two infinities are equal exactly when their sign bits agree.
    (Inf a, Inf b) -> return $ a .== b
    (Unknown, _) ->
      mrgThrowError "tensorValEq: Unsupported reasoning: comparing unknown values"
    (_, Unknown) ->
      mrgThrowError "tensorValEq: Unsupported reasoning: comparing unknown values"
    _ -> return $ con False

-- | Checks if two tensor numerical values are not equal.
225 | tensorValNe :: (Mergeable a, SymEq a) => TensorNum a -> TensorNum a -> ErrorEnv SymBool 226 | tensorValNe l r = mrgFmap symNot $ tensorValEq l r 227 | 228 | -- | Checks if one tensor numerical value is less than another. 229 | tensorValLt :: (Mergeable a, SymOrd a) => TensorNum a -> TensorNum a -> ErrorEnv SymBool 230 | tensorValLt (TensorNum l) (TensorNum r) = do 231 | l1 <- liftUnion l 232 | r1 <- liftUnion r 233 | case (l1, r1) of 234 | (NonInf lv, NonInf rv) -> mrgReturn $ lv .< rv 235 | (Inf lv, Inf rv) -> mrgReturn $ lv .< rv 236 | (Unknown, _) -> 237 | mrgThrowError "tensorValLt: Unsupported reasoning: comparing unknown values" 238 | (_, Unknown) -> 239 | mrgThrowError "tensorValLt: Unsupported reasoning: comparing unknown values" 240 | (NonInf _, Inf rv) -> mrgReturn rv 241 | (Inf lv, NonInf _) -> mrgReturn $ symNot lv 242 | 243 | -- | Checks if one tensor numerical value is greater than another. 244 | tensorValGt :: 245 | (Mergeable a, SymOrd a) => TensorNum a -> TensorNum a -> ErrorEnv SymBool 246 | tensorValGt = flip tensorValLt 247 | 248 | -- | Checks if one tensor numerical value is less than or equal to another. 249 | tensorValLe :: 250 | (Mergeable a, SymOrd a) => TensorNum a -> TensorNum a -> ErrorEnv SymBool 251 | tensorValLe l r = mrgFmap symNot $ tensorValLt r l 252 | 253 | -- | Checks if one tensor numerical value is greater than or equal to another. 254 | tensorValGe :: 255 | (Mergeable a, SymOrd a) => TensorNum a -> TensorNum a -> ErrorEnv SymBool 256 | tensorValGe l r = mrgFmap symNot $ tensorValLt l r 257 | 258 | instance (IsString a, Mergeable a) => IsString (TensorNum a) where 259 | fromString = nonInf . fromString 260 | 261 | class TensorExp a where 262 | -- | Computes the exponential of a tensor numerical value. 
263 | -- This is currently only supported for real tensor values 264 | tensorExp :: TensorNum a -> TensorNum a 265 | 266 | instance TensorExp SymInteger where 267 | tensorExp = error "Not supported" 268 | 269 | instance TensorExp SymAlgReal where 270 | tensorExp (TensorNum l) = TensorNum $ do 271 | l1 <- l 272 | case l1 of 273 | NonInf lv -> mrgReturn $ NonInf $ exp lv 274 | Inf lv -> mrgIf lv (return $ Inf lv) (return $ NonInf 0) 275 | Unknown -> mrgReturn Unknown 276 | 277 | class TensorDivMod a where 278 | -- | Divides one tensor numerical value by another. 279 | tensorDiv :: TensorNum a -> TensorNum a -> TensorNum a 280 | 281 | -- | Computes the quotient of one tensor numerical value by another. 282 | tensorMod :: TensorNum a -> TensorNum a -> TensorNum a 283 | 284 | -- | Computes the remainder of one tensor numerical value by another. 285 | tensorRem :: TensorNum a -> TensorNum a -> TensorNum a 286 | 287 | instance TensorDivMod SymInteger where 288 | tensorDiv (TensorNum l) (TensorNum r) = TensorNum $ do 289 | l1 <- l 290 | r1 <- r 291 | case (l1, r1) of 292 | (NonInf lv, NonInf rv) -> mrgReturn $ NonInf $ divOr 0 lv rv 293 | -- (Inf lv, Inf rv) -> mrgReturn $ Unknown 294 | -- (NonInf lv, Inf rv) -> mrgReturn $ Inf rv 295 | -- (Inf lv, NonInf rv) -> mrgReturn $ Inf lv 296 | -- TODO: Overly approximated 297 | _ -> mrgReturn Unknown 298 | tensorMod (TensorNum l) (TensorNum r) = TensorNum $ do 299 | l1 <- l 300 | r1 <- r 301 | case (l1, r1) of 302 | (NonInf lv, NonInf rv) -> mrgReturn $ NonInf $ modOr lv lv rv 303 | -- (Inf lv, Inf rv) -> mrgReturn $ Unknown 304 | -- (NonInf lv, Inf rv) -> mrgReturn $ Inf rv 305 | -- (Inf lv, NonInf rv) -> mrgReturn $ Inf lv 306 | -- TODO: Overly approximated 307 | _ -> mrgReturn Unknown 308 | tensorRem (TensorNum l) (TensorNum r) = TensorNum $ do 309 | l1 <- l 310 | r1 <- r 311 | case (l1, r1) of 312 | (NonInf lv, NonInf rv) -> mrgReturn $ NonInf $ remOr lv lv rv 313 | -- (Inf lv, Inf rv) -> mrgReturn $ Unknown 314 | -- (NonInf 
lv, Inf rv) -> mrgReturn $ Inf rv 315 | -- (Inf lv, NonInf rv) -> mrgReturn $ Inf lv 316 | -- TODO: Overly approximated 317 | _ -> mrgReturn Unknown 318 | 319 | instance TensorDivMod SymAlgReal where 320 | tensorDiv (TensorNum l) (TensorNum r) = TensorNum $ do 321 | l1 <- l 322 | r1 <- r 323 | case (l1, r1) of 324 | (NonInf lv, NonInf rv) -> mrgReturn $ NonInf $ fdivOr 0 lv rv 325 | -- (Inf lv, Inf rv) -> mrgReturn $ Unknown 326 | -- (NonInf lv, Inf rv) -> mrgReturn $ Inf rv 327 | -- (Inf lv, NonInf rv) -> mrgReturn $ Inf lv 328 | -- TODO: Overly approximated 329 | _ -> mrgReturn Unknown 330 | tensorMod = error "Not supported" 331 | tensorRem = error "Not supported" 332 | 333 | type IsTensorNum a = 334 | ( Num a, 335 | SymOrd a, 336 | ITEOp a, 337 | SimpleMergeable a, 338 | TensorDivMod a, 339 | TensorExp a 340 | ) 341 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorRight 2 | 3 | TensorRight is an automatic tool to verify tensor graph rewrites for tensors of arbitrary ranks and sizes. 4 | Tensor Graph Rewriting is one of the key optimizations in Tensor Compilers such as [XLA](https://github.com/openxla/xla). 5 | 6 | ## Key Features of TensorRight 7 | 8 | - We introduce a core language, TensorRight DSL, to represent complex tensor graph rewrites with preconditions. 9 | - TensorRight DSL uses a novel axis definition, called _aggregated-axis_, which allows reasoning about an arbitrary number of dimensions. 10 | - TensorRight provides operator specifications that closely resemble [XLA-HLO](https://openxla.org/xla/operation_semantics). 11 | TensorRight implements the denotational semantics for these operators. 
12 | - TensorRight presents an automatic verification strategy to verify tensor graph rewrites in the unbounded setting, i.e., for arbitrary ranks and sizes, by inferring a bound on aggregated-axis ranks, such that verifying the rewrite for all ranks within the bound implies correctness in the unbounded setting.
13 | Hence, TensorRight converts the _unbounded-verification_ proof obligation to a finite set of _bounded-verification_ proof obligations, which are then dispatched to an SMT solver using symbolic execution to automatically verify rewrite rules. 14 | - TensorRight is implemented in Haskell and uses [Grisette](https://github.com/lsrcz/grisette) as the symbolic evaluation engine. 15 | TensorRight can successfully represent 121 of the 175 rewrites present in [XLA's algebraic simplifier](https://github.com/openxla/xla/blob/main/xla/hlo/transforms/simplifiers/algebraic_simplifier.cc) and is able to verify 115 of those in the unbounded setting. 16 | 17 | ## Publications 18 | 19 | - [TensorRight: Automated Verification of Tensor Graph Rewrites](https://dl.acm.org/doi/10.1145/3704865)
20 | Jai Arora, Sirui Lu, Devansh Jain, Tianfan Xu, Farzin Houshmand, Phitchaya Mangpo Phothilimthana, Mohsen Lesani, Praveen Narayanan, Karthik Srinivasa Murthy, Rastislav Bodik, Amit Sabne, and Charith Mendis.
21 | In Proceedings of the 52nd ACM SIGPLAN Symposium on Principles of Programming Languages (POPL'25), January 2025, Denver, Colorado, USA 22 | 23 |
24 | BibTeX 25 |
@article{10.1145/3704865,
 26 | author = {Arora, Jai and Lu, Sirui and Jain, Devansh and Xu, Tianfan and Houshmand, Farzin and Phothilimthana, Phitchaya Mangpo and Lesani, Mohsen and Narayanan, Praveen and Murthy, Karthik Srinivasa and Bodik, Rastislav and Sabne, Amit and Mendis, Charith},
 27 | title = {TensorRight: Automated Verification of Tensor Graph Rewrites},
 28 | year = {2025},
 29 | issue_date = {January 2025},
 30 | publisher = {Association for Computing Machinery},
 31 | address = {New York, NY, USA},
 32 | volume = {9},
 33 | number = {POPL},
 34 | url = {https://doi.org/10.1145/3704865},
 35 | doi = {10.1145/3704865},
 36 | abstract = {Tensor compilers, essential for generating efficient code for deep learning models across various applications, employ tensor graph rewrites as one of the key optimizations. These rewrites optimize tensor computational graphs with the expectation of preserving semantics for tensors of arbitrary rank and size. Despite this expectation, to the best of our knowledge, there does not exist a fully automated verification system to prove the soundness of these rewrites for tensors of arbitrary rank and size. Previous works, while successful in verifying rewrites with tensors of concrete rank, do not provide guarantees in the unbounded setting.  To fill this gap, we introduce TensorRight, the first automatic verification system that can verify tensor graph rewrites for input tensors of arbitrary rank and size. We introduce a core language, TensorRight DSL, to represent rewrite rules using a novel axis definition, called aggregated-axis, which allows us to reason about an unbounded number of axes. We achieve unbounded verification by proving that there exists a bound on tensor ranks, under which bounded verification of all instances implies the correctness of the rewrite rule in the unbounded setting. We derive an algorithm to compute this rank using the denotational semantics of TensorRight DSL. TensorRight employs this algorithm to generate a finite number of bounded-verification proof obligations, which are then dispatched to an SMT solver using symbolic execution to automatically verify the correctness of the rewrite rules. We evaluate TensorRight’s verification capabilities by implementing rewrite rules present in XLA’s algebraic simplifier. The results demonstrate that TensorRight can prove the correctness of 115 out of 175 rules in their full generality, while the closest automatic, bounded-verification system can express only 18 of these rules.},
 37 | journal = {Proc. ACM Program. Lang.},
 38 | month = jan,
 39 | articleno = {29},
 40 | numpages = {32},
 41 | keywords = {Denotational Semantics, Tensor Compilers, Unbounded Verification}
 42 | }
 43 | 
44 |
45 | 46 | ## Installation 47 | 48 | ### Installing Stack 49 | 50 | `stack` and other tools in the Haskell Toolchain can be installed by following the instructions at [this link](https://www.haskell.org/ghcup/install/). 51 | 52 | ### Installing SMT Solvers 53 | 54 | To verify the implemented rewrite rules, you need to install the Z3 and cvc5 SMT Solvers and make them available through `PATH`. 55 | 56 | #### Installing Z3 57 | 58 | On Ubuntu, you can install Z3 with: 59 | 60 | ```bash 61 | apt update && apt install z3 62 | ``` 63 | 64 | On macOS, you can install Z3 with [Homebrew](https://brew.sh/): 65 | 66 | ```bash 67 | brew install z3 68 | ``` 69 | 70 | Please refer to the [Z3 homepage](https://github.com/Z3Prover/z3) for more details. 71 | 72 | #### Installing cvc5 73 | 74 | cvc5 can be installed by downloading one of the pre-built binaries from [here](https://cvc5.github.io/downloads.html) or [building it from source](https://cvc5.github.io/docs/cvc5-1.2.0/installation/installation.html). 75 | 76 | ### Testing your Installation 77 | 78 | You can test your installation by first cloning the repository, running regression tests and verifying rewrite rules. 79 | 80 | #### Build 81 | 82 | ```bash 83 | git clone https://github.com/ADAPT-uiuc/TensorRight.git && cd TensorRight/ && stack build 84 | 85 | # Regression Tests: all testcases should pass 86 | stack test 87 | 88 | # Verifying Rewrite Rules: 115/118 passed 89 | make verify 90 | ``` 91 | 92 | Running `make verify` tries to verify all the 118 implemented rewrite rules. 93 | It results in 3 expected timeouts (the actual number could vary). 94 | 95 | ## Usage 96 | 97 | We will now take a look at how we can use TensorRight DSL to express complex tensor graph rewrites with preconditions and verify them. 98 | Please refer to the [implemented rules](./rules/) for more examples. 99 | 100 | Consider the `DySliceToSlice` rule that we would like to express and verify in our DSL. 
101 | 102 | $$ 103 | \mathsf{dy\hbox{-}slice}(\mathsf{X}, B, L) \Rightarrow_{E - B' = L \ \wedge \ P = 1 \ \wedge \ B' = B } \mathsf{slice}(\mathsf{X}, B', E, P) 104 | $$ 105 | 106 | The $\mathsf{dy\hbox{-}slice}$ operator extracts a sub-tensor from the input tensor $\mathsf{X}$, where the start-index for each axis is specified in $B$ and the length of the slice along each axis is passed in $L$. 107 | Meanwhile, the $\mathsf{slice}$ operator also extracts a sub-tensor from within a bounding box in the input tensor $\mathsf{X}$. 108 | The start-indices for the bounding box are specified in $B'$, while the end-indices (exclusive) are specified in $E$. 109 | $P$ specifies the stride for each axis, which determines the step size between elements in the bounding box. 110 | 111 | The `DySliceToSlice` rule is generally not correct, unless $E - B'$ (the size of the bounding box in $\mathsf{slice}$) is equal to $L$ (the length in $\mathsf{dy\hbox{-}slice}$). 112 | The other requirements are that $\mathsf{slice}$ should skip no elements, i.e., $P=1$, and the start indices in $\mathsf{slice}$ and $\mathsf{dy\hbox{-}slice}$ must be the same, i.e., $B' = B$. 113 | Since these are specified in the precondition, the RHS expression is equivalent to the LHS expression. 114 | 115 | We support verification of boolean, integer, and real valued tensors. 116 | Since we would like to verify the `DySliceToSlice` rule for all tensor types, we declare the rule in our DSL as follows: 117 | 118 | ```haskell 119 | rule :: forall a. AnyDTypeRule a 120 | rule = do 121 | ... 122 | ``` 123 | 124 | We can use the type parameter `a` inside the rule definition to declare tensors of a polymorphic type. 125 | 126 | We would like to verify the rule for an arbitrary number of named-axes in $\mathsf{X}$. 
127 | Since there is only one "role" of axes in the rewrite rule, i.e., every axis is getting sliced, we need only one aggregated-axis or one `RClass`, which we can declare using `newRClass`: 128 | 129 | ```haskell 130 | rcls <- newRClass "rcls" 131 | ``` 132 | 133 | `rcls` can be thought of as an abstract set of named-axes, which can be instantiated to any number of named-axes. 134 | This allows us to specify an abstract representation of a rewrite rule, which can be specialized to any rank. 135 | 136 | We also want to verify the rule for arbitrary sizes and operator attributes like $B$, $E$, $L$, etc. 137 | We represent these using abstract maps, which can be instantiated to maps of concrete rank. 138 | We can declare maps on an `RClass` in our DSL using `newMaps`: 139 | 140 | ```haskell 141 | [size, start, start', length, end, stride] <- 142 | newMaps ["size", "start", "start'", "length", "end", "stride"] rcls 143 | ``` 144 | 145 | We then declare an abstract tensor of shape `rcls --> size` containing elements of type `a` using `newTensor`: 146 | 147 | ```haskell 148 | tensor <- newTensor @a "X" [rcls --> size] 149 | ``` 150 | 151 | The resulting tensor is said to have arbitrary values of type `a`. 
152 | 153 | We define LHS and RHS tensor expressions using the operators available in our DSL: 154 | 155 | ```haskell 156 | lhs <- 157 | dynamicSlice tensor $ 158 | DySlice {start = [rcls --> start], sizes = [rcls --> length]} 159 | rhs <- 160 | slice tensor $ 161 | Slice 162 | { start = [rcls --> start'], 163 | end = [rcls --> end], 164 | strides = [rcls --> stride] 165 | } 166 | ``` 167 | 168 | We can specify preconditions using `precondition`: 169 | 170 | ```haskell 171 | precondition [end, start', length] $ \[e, s', l] -> e - s' .== l 172 | precondition [stride] $ \[p] -> p .== 1 173 | precondition [start, start'] $ \[s, s'] -> s' .== s 174 | ``` 175 | 176 | Finally, we declare a rewrite rule using the `rewrite` construct: 177 | 178 | ```haskell 179 | rewrite "DynamicSlice(X) => Slice(X)" lhs rhs 180 | ``` 181 | 182 | Putting everything together, the specification of the `DySliceToSlice` rule in TensorRight DSL looks like the following: 183 | 184 | ```haskell 185 | rule :: forall a. 
AnyDTypeRule a 186 | rule = do 187 | rcls <- newRClass "rcls" 188 | [size, start, start', length, end, stride] <- 189 | newMaps ["size", "start", "start'", "length", "end", "stride"] rcls 190 | tensor <- newTensor @a "X" [rcls --> size] 191 | 192 | lhs <- 193 | dynamicSlice tensor $ 194 | DySlice {start = [rcls --> start], sizes = [rcls --> length]} 195 | rhs <- 196 | slice tensor $ 197 | Slice 198 | { start = [rcls --> start'], 199 | end = [rcls --> end], 200 | strides = [rcls --> stride] 201 | } 202 | 203 | precondition [end, start', length] $ 204 | \[end, start', length] -> end - start' .== length 205 | precondition [stride] $ \[stride] -> stride .== 1 206 | precondition [start, start'] $ \[start, start'] -> start' .== start 207 | 208 | rewrite "DynamicSlice(X) => Slice(X)" lhs rhs 209 | ``` 210 | 211 | We can verify the rule by using `verifyAnyDTypeDSL`: 212 | 213 | ```haskell 214 | main :: IO () 215 | main = do verifyAnyDTypeDSL rule 216 | ``` 217 | 218 | ## Documentation 219 | 220 | Please build the haddock doc using: 221 | ```bash 222 | stack haddock 223 | ``` 224 | 225 | This will build the documentation in a folder like: 226 | 227 | ```bash 228 | .stack-work/install/x86_64-linux//9.8.2/doc/index.html 229 | ``` 230 | 231 | You can navigate to have a look at the full API documentation. If you are using 232 | vscode, the live server plugin might be helpful for hosting the documentation. 233 | 234 | ### Code Formatting 235 | 236 | We use [ormolu](https://hackage.haskell.org/package/ormolu) for formatting 237 | Haskell source code. 238 | 239 | ## License 240 | TensorRight is distributed under the terms of the Apache-2.0 license. 241 | The [LICENSE](./LICENSE) file contains the full license text. 
242 | -------------------------------------------------------------------------------- /src/TensorRight/Internal/DSL/Shape.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveAnyClass #-} 2 | {-# LANGUAGE DeriveGeneric #-} 3 | {-# LANGUAGE DerivingVia #-} 4 | {-# LANGUAGE DuplicateRecordFields #-} 5 | {-# LANGUAGE FlexibleContexts #-} 6 | {-# LANGUAGE FlexibleInstances #-} 7 | {-# LANGUAGE MultiParamTypeClasses #-} 8 | {-# LANGUAGE OverloadedStrings #-} 9 | {-# LANGUAGE RecordWildCards #-} 10 | 11 | module TensorRight.Internal.DSL.Shape 12 | ( TensorShape (..), 13 | RClassRef (..), 14 | RClassRefSet, 15 | AbstractShape (..), 16 | toAbstractShape, 17 | TensorShapeLike (toTensorShape), 18 | TensorShapeDesc (..), 19 | abstractShapeAllRefs, 20 | removeRClass, 21 | getRClassByRClassRef, 22 | addRClassByRClassRef, 23 | concatAbstractShape, 24 | restrictAbstractShape, 25 | ) 26 | where 27 | 28 | import Control.Monad (when) 29 | import Control.Monad.Error.Class (MonadError (throwError)) 30 | import qualified Data.HashMap.Lazy as HM 31 | import qualified Data.HashSet as HS 32 | import Data.Hashable (Hashable (hashWithSalt)) 33 | import Data.List (sortBy) 34 | import qualified Data.Text as T 35 | import GHC.Generics (Generic) 36 | import Grisette (Default (Default), PPrint (pformat, pformatPrec), TryMerge) 37 | import TensorRight.Internal.DSL.Identifier (Label, MapIdentifier, RClassIdentifier) 38 | import TensorRight.Internal.DSL.Syntax (ArrowSyntax ((-->)), AtSyntax ((@@))) 39 | import TensorRight.Internal.Util.Error (Error, assert) 40 | import TensorRight.Internal.Util.Pretty (encloseList, prettyWithConstructor) 41 | 42 | -- | Reference to an RClass, also referred to as an aggregated-axis in our DSL. 43 | -- An RClass may be labelled or unlabelled. 
-- An unlabelled RClass denotes exactly one aggregated-axis, so the RClass
-- itself serves as the reference to that aggregated-axis.
-- A labelled RClass is referenced through its labels; one RClass may carry
-- several labels, each naming a distinct aggregated-axis on that RClass.
data RClassRef
  = ByRClass RClassIdentifier
  | ByLabel Label
  deriving (Eq, Generic, Ord, Show)
  deriving anyclass (Hashable)
  deriving (PPrint) via (Default RClassRef)

-- | A set of 'RClassRef's.
type RClassRefSet = HS.HashSet RClassRef

-- | A tensor shape maps aggregated-axes to aggregated-maps holding the axes
-- sizes. Each entry is either labelled or unlabelled:
--
-- - A labelled entry pairs an aggregated-axis name (a 'Label') with an
--   'RClassIdentifier' and a 'MapIdentifier'.
-- - An unlabelled entry holds just an 'RClassIdentifier' (which doubles as
--   the aggregated-axis name) and a 'MapIdentifier'.
data TensorShape = TensorShape
  { labelled :: HM.HashMap Label (RClassIdentifier, MapIdentifier),
    unlabelled :: HM.HashMap RClassIdentifier MapIdentifier
  }
  deriving (Show, Generic)

-- Structural equality over both the labelled and unlabelled entries.
instance Eq TensorShape where
  TensorShape la ua == TensorShape lb ub = la == lb && ua == ub

-- Hashes the entry lists of both maps, labelled entries folded in first.
instance Hashable TensorShape where
  hashWithSalt salt (TensorShape l u) =
    hashWithSalt (hashWithSalt salt (HM.toList l)) (HM.toList u)

-- Renders each entry as @rclass -> map@, with labelled entries suffixed by
-- their label.
instance PPrint TensorShape where
  pformatPrec n (TensorShape l u) =
    prettyWithConstructor
      n
      "TensorShape"
      [encloseList "{" "}" "," entries]
    where
      entries =
        map (\(lbl, (rclass, m)) -> withLabel lbl rclass m) (HM.toList l)
          ++ map (uncurry withoutLabel) (HM.toList u)
      withLabel lbl rclass m =
        pformat rclass <> " -> " <> pformat m <> " @@ " <> pformat lbl
      withoutLabel rclass m = pformat rclass <> " -> " <> pformat m

-- An intermediate value built by '-->' that still awaits a '@@' label.
data PartialTensorShapeDesc
  = PartialTensorShapeDesc RClassIdentifier MapIdentifier

-- | A 'TensorShapeDesc' describes a mapping in a tensor shape.
-- It has two constructors:
--
-- - 'UnlabelledDesc': @'RClassIdentifier' 'TensorRight.Internal.DSL.Syntax.-->' 'MapIdentifier'@ can be used to create an unlabelled mapping.
-- - 'LabelledDesc': @'RClassIdentifier' 'TensorRight.Internal.DSL.Syntax.-->' 'MapIdentifier' 'TensorRight.Internal.DSL.Syntax.@@' 'Label'@ can be used to create a labelled mapping.
data TensorShapeDesc
  = UnlabelledDesc RClassIdentifier MapIdentifier
  | LabelledDesc Label RClassIdentifier MapIdentifier

-- @rclass --> map@ builds a partial description that still needs a label.
instance ArrowSyntax RClassIdentifier MapIdentifier PartialTensorShapeDesc where
  r --> m = PartialTensorShapeDesc r m

-- @rclass --> map@ builds an unlabelled description directly.
instance ArrowSyntax RClassIdentifier MapIdentifier TensorShapeDesc where
  r --> m = UnlabelledDesc r m

-- Attaching a label to a partial description yields a labelled description.
instance AtSyntax PartialTensorShapeDesc Label TensorShapeDesc where
  PartialTensorShapeDesc r m @@ lbl = LabelledDesc lbl r m

-- Builds a 'TensorShape' by inserting descriptions one at a time (the tail
-- of the list is processed before the head), rejecting duplicate entries.
getTensorShape' ::
  (MonadError Error m) => [TensorShapeDesc] -> m TensorShape
getTensorShape' [] = return $ TensorShape HM.empty HM.empty
getTensorShape' (desc : rest) = do
  TensorShape lbl unlbl <- getTensorShape' rest
  case desc of
    UnlabelledDesc rclass m -> do
      when (HM.member rclass unlbl) $
        throwError "Duplicate rclass without labels"
      return $ TensorShape lbl (HM.insert rclass m unlbl)
    LabelledDesc label rclass m -> do
      when (HM.member label lbl) $ throwError "Duplicate label"
      when (HM.member rclass unlbl) $
        throwError "Labelled rclass already present as unlabelled"
      return $ TensorShape (HM.insert label (rclass, m) lbl) unlbl

-- | A function to convert a list of 'TensorShapeDesc' to a t'TensorShape'.
-- The function makes the following checks:
--
-- - No two unlabelled mappings can have the same 'RClassIdentifier'.
-- - No two labelled mappings can have the same 'Label'.
-- - A labelled mapping cannot have the same 'RClassIdentifier' as an unlabelled mapping.
getTensorShape ::
  (MonadError Error m) => [TensorShapeDesc] -> m TensorShape
getTensorShape =
  -- Stable sort moving labelled descriptions in front of unlabelled ones,
  -- so that label/rclass clashes are detected by 'getTensorShape''.
  getTensorShape' . sortBy (\a b -> compare (rank a) (rank b))
  where
    rank :: TensorShapeDesc -> Int
    rank LabelledDesc {} = 0
    rank UnlabelledDesc {} = 1

-- Anything that can be converted (with validation) into a t'TensorShape'.
class TensorShapeLike a where
  toTensorShape :: (MonadError Error m) => a -> m TensorShape

instance TensorShapeLike TensorShape where
  toTensorShape = return

instance TensorShapeLike [TensorShapeDesc] where
  toTensorShape = getTensorShape

-- | Represents a t'TensorShape' without any 'MapIdentifier' information.
data AbstractShape = AbstractShape
  { labelled :: HM.HashMap Label RClassIdentifier,
    unlabelled :: HS.HashSet RClassIdentifier
  }
  deriving (Show)

-- Structural equality over both components.
instance Eq AbstractShape where
  AbstractShape la ua == AbstractShape lb ub = la == lb && ua == ub

-- Labelled entries render as @rclass \@\@ label@, unlabelled ones as the
-- bare rclass.
instance PPrint AbstractShape where
  pformatPrec n (AbstractShape l u) =
    prettyWithConstructor
      n
      "AbstractShape"
      [encloseList "{" "}" "," entries]
    where
      entries =
        [ pformat rclass <> " @@ " <> pformat lbl
          | (lbl, rclass) <- HM.toList l
        ]
          ++ map pformat (HS.toList u)

-- | Converts a t'TensorShape' to an t'AbstractShape' by dropping the
-- 'MapIdentifier' of every entry.
toAbstractShape :: TensorShape -> AbstractShape
toAbstractShape (TensorShape l u) =
  AbstractShape
    { labelled = fmap fst l,
      unlabelled = HM.keysSet u
    }

-- | Returns all 'RClassRef's in an t'AbstractShape'
abstractShapeAllRefs :: AbstractShape -> HS.HashSet RClassRef
abstractShapeAllRefs (AbstractShape l u) =
  HS.union (HS.map ByRClass u) (HS.map ByLabel (HM.keysSet l))

-- | Removes an 'RClassRef' from an t'AbstractShape'.
-- Fails when the referenced rclass (or label) is not present in the shape.
removeRClass ::
  (MonadError T.Text m, TryMerge m) =>
  AbstractShape ->
  RClassRef ->
  m AbstractShape
removeRClass AbstractShape {..} (ByRClass rclass) = do
  assert "RClass not exist" $ HS.member rclass unlabelled
  return $ AbstractShape labelled (HS.delete rclass unlabelled)
removeRClass AbstractShape {..} (ByLabel label) = do
  assert "Label not exist" $ HM.member label labelled
  return $ AbstractShape (HM.delete label labelled) unlabelled

-- | Returns the 'RClassIdentifier' corresponding to an 'RClassRef' given an
-- t'AbstractShape'. Fails when the reference is not present in the shape.
getRClassByRClassRef ::
  (MonadError T.Text m, TryMerge m) =>
  AbstractShape ->
  RClassRef ->
  m RClassIdentifier
getRClassByRClassRef AbstractShape {..} (ByRClass rclass) = do
  assert "RClass not exist" $ HS.member rclass unlabelled
  return rclass
getRClassByRClassRef AbstractShape {..} (ByLabel label) =
  case HM.lookup label labelled of
    Nothing -> throwError "Label not exist"
    Just rclass -> return rclass

-- | Adds an 'RClassRef' to an t'AbstractShape'.
-- Fails when the reference is already present; when adding by rclass, the
-- reference and the supplied 'RClassIdentifier' must agree.
addRClassByRClassRef ::
  (MonadError T.Text m, TryMerge m) =>
  AbstractShape ->
  RClassRef ->
  RClassIdentifier ->
  m AbstractShape
addRClassByRClassRef AbstractShape {..} (ByRClass rclass) rclass' = do
  assert "RClass already exist" $ not $ HS.member rclass unlabelled
  assert "If adding by rclass itself, then rclass must be the same" $ rclass == rclass'
  return $ AbstractShape labelled (HS.insert rclass' unlabelled)
addRClassByRClassRef AbstractShape {..} (ByLabel label) rclass = do
  assert "Label already exist" $ not $ HM.member label labelled
  return $ AbstractShape (HM.insert label rclass labelled) unlabelled

-- | Concatenates two t'AbstractShape's.
-- The function makes the following checks:
--
-- - The input t'AbstractShape's must not have any overlapping labelled RClass.
-- - The input t'AbstractShape's must not have any overlapping unlabelled RClass.
-- - The resulting t'AbstractShape' must not have any labelled RClass that overlaps with an unlabelled RClass.
concatAbstractShape ::
  (MonadError T.Text m, TryMerge m) =>
  AbstractShape ->
  AbstractShape ->
  m AbstractShape
concatAbstractShape
  (AbstractShape l1 u1)
  (AbstractShape l2 u2) = do
    assert "Labelled rclass overlap" $ HM.null $ HM.intersection l1 l2
    assert "Unlabelled rclass overlap" $ HS.null $ HS.intersection u1 u2
    let newUnlabelled = u1 `HS.union` u2
    let newLabelled = l1 `HM.union` l2
    let newLabelledRClasses = HS.fromList $ HM.elems newLabelled
    assert "Labelled rclass overlap with unlabelled" $
      HS.null $
        HS.intersection newUnlabelled newLabelledRClasses
    -- Reuse the unions computed above instead of recomputing them.
    return $ AbstractShape newLabelled newUnlabelled

-- | Restricts an t'AbstractShape' to the specified set of 'RClassRef's.
-- Fails when any requested label or rclass is absent from the shape.
restrictAbstractShape ::
  (MonadError T.Text m, TryMerge m) =>
  AbstractShape ->
  RClassRefSet ->
  m AbstractShape
restrictAbstractShape AbstractShape {..} rclasses = do
  let newLabelled =
        HM.filterWithKey (\k _ -> k `HS.member` byLabel) labelled
  let newUnlabelled = HS.intersection unlabelled byRClass
  -- If the restriction dropped anything, a requested ref was missing.
  assert "restrictAbstractShape: some label does not exist" $
    HS.size byLabel == HM.size newLabelled
  assert "restrictAbstractShape: some rclass does not exist" $
    HS.size byRClass == HS.size newUnlabelled
  return $ AbstractShape newLabelled newUnlabelled
  where
    -- Split the requested refs into the labels and the rclasses.
    byLabel = HS.fromList [label | ByLabel label <- HS.toList rclasses]
    byRClass = HS.fromList [rclass | ByRClass rclass <- HS.toList rclasses]