4 | #include "HsFFI.h"
5 |
6 | void dexInit() {
7 | int argc = 4;
8 | char *argv[] = { "+RTS", "-I0", "-A16m", "-RTS", NULL };
9 | char **pargv = argv;
10 |
11 | hs_init(&argc, &pargv);
12 | }
13 |
14 | void dexFini() {
15 | hs_exit();
16 | }
17 |
/* Per-thread buffer holding the calling thread's latest error message as a
 * NUL-terminated C string. */
__thread char dex_err_storage[2048];

/* Return the calling thread's latest error message. The buffer has thread
 * storage duration; its contents change on the next _internal_dexSetError
 * call made by the same thread. */
const char* dexGetError() {
  return dex_err_storage;
}

/* Store the first `len` bytes of `new_err` as the calling thread's error.
 *
 * `new_err` does not have to be NUL-terminated, because the length is explicit.
 * The message is truncated to fit the buffer. The terminator is written
 * immediately after the copied bytes, so text left over from an earlier,
 * longer message can never leak into the result. Negative lengths are treated
 * as an empty message. */
void _internal_dexSetError(char* new_err, int64_t len) {
  const int64_t max_len = (int64_t)sizeof(dex_err_storage) - 1;
  if (len < 0) len = 0;
  if (len > max_len) len = max_len;
  /* Skip memcpy for empty messages so that new_err may be NULL then. */
  if (len > 0) memcpy(dex_err_storage, new_err, (size_t)len);
  dex_err_storage[len] = '\0';
}
29 |
/* Function-pointer type of a Dex entry point invoked through
 * dexXLACPUTrampoline: (output buffer, array of input buffers). */
typedef int64_t (*dex_xla_f)(void*, void**);

/* Forward a call to a Dex function. NOTE(review): presumably registered as an
 * XLA CPU custom call, judging by the name — confirm.
 * in[0] points at a slot holding the dex_xla_f to call. The remaining entries
 * (in + 1 onwards) are forwarded to it as its inputs, and `out` is passed
 * through unchanged. The callee's int64_t result is discarded. */
void dexXLACPUTrampoline(void* out, void** in) {
  dex_xla_f *fn_slot = (dex_xla_f *)in[0];
  void **inputs = &in[1];
  (*fn_slot)(out, inputs);
}
35 |
--------------------------------------------------------------------------------
/src/lib/CUDA.hs:
--------------------------------------------------------------------------------
1 |
2 | module CUDA (hasCUDA, loadCUDAArray, synchronizeCUDA, getCudaArchitecture) where
3 |
4 | import Data.Int
5 | import Foreign.Ptr
6 | #ifdef DEX_CUDA
7 | import Foreign.C
8 | #else
9 | #endif
10 |
-- | True when Dex was built with the @DEX_CUDA@ CPP flag. In that case the
-- CUDA entry points below are bound to the C runtime; otherwise they are stubs.
hasCUDA :: Bool

#ifdef DEX_CUDA
hasCUDA = True

-- C runtime bindings. NOTE(review): argument meanings of cuMemcpyDToH are
-- inferred from its use in loadCUDAArray (byte count, device source, host
-- destination); confirm against the C side.
foreign import ccall "dex_cuMemcpyDtoH"
  cuMemcpyDToH :: Int64 -> Ptr () -> Ptr () -> IO ()
foreign import ccall "dex_synchronizeCUDA"
  synchronizeCUDA :: IO ()
foreign import ccall "dex_ensure_has_cuda_context"
  ensureHasCUDAContext :: IO ()
foreign import ccall "dex_get_cuda_architecture"
  dex_getCudaArchitecture :: Int -> CString -> IO ()

-- | Architecture name of CUDA device @dev@. The C routine is handed a buffer
-- pre-filled with the placeholder "sm_00", and the buffer is read back after
-- the call.
-- NOTE(review): the buffer holds only 5 characters plus NUL; confirm the C side
-- never writes a longer name (e.g. "sm_100").
getCudaArchitecture :: Int -> IO String
getCudaArchitecture dev =
  withCString "sm_00" $ \archBuf -> do
    dex_getCudaArchitecture dev archBuf
    peekCString archBuf
#else
hasCUDA = False

cuMemcpyDToH :: Int64 -> Ptr () -> Ptr () -> IO ()
cuMemcpyDToH = error "Dex built without CUDA support"

-- Without CUDA this is a no-op.
synchronizeCUDA :: IO ()
synchronizeCUDA = pure ()
{-# SCC synchronizeCUDA #-}

-- Without CUDA this is a no-op.
ensureHasCUDAContext :: IO ()
ensureHasCUDAContext = pure ()
{-# SCC ensureHasCUDAContext #-}

getCudaArchitecture :: Int -> IO String
getCudaArchitecture _ = error "Dex built without CUDA support"
#endif
42 |
-- | Copy @bytes@ bytes from device memory at @devicePtr@ into host memory at
-- @hostPtr@. 'ensureHasCUDAContext' is called before the copy.
loadCUDAArray :: Ptr () -> Ptr () -> Int -> IO ()
loadCUDAArray hostPtr devicePtr bytes =
  ensureHasCUDAContext >> cuMemcpyDToH (fromIntegral bytes) devicePtr hostPtr
{-# SCC loadCUDAArray #-}
48 |
--------------------------------------------------------------------------------
/src/lib/IRVariants.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2022 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | {-# LANGUAGE AllowAmbiguousTypes #-}
8 |
9 | module IRVariants
10 | ( IR (..), IRPredicate (..), Sat, Sat'
11 | , CoreToSimpIR, InferenceIR, IRRep (..), IRProxy (..), interpretIR
12 | , IRsEqual (..), eqIRRep, WhenIR (..)) where
13 |
14 | import GHC.Generics (Generic (..))
15 | import Data.Store
16 | import Data.Hashable
17 | import Data.Store.Internal
18 | import Data.Kind
19 |
20 | import qualified Unsafe.Coerce as TrulyUnsafe
21 |
-- | The intermediate representations a Dex program passes through.
data IR
  = CoreIR  -- ^ used after inference and before simplification
  | SimpIR  -- ^ used after simplification
  deriving (Eq, Ord, Generic, Show, Enum)

instance Store IR
27 |
type CoreToSimpIR = CoreIR -- used during the Core-to-Simp translation

-- | Optional groups of operations; which IR supports which group is decided
-- by the 'HasFeature' equations of @Sat'@ below.
data IRFeature =
    DAMOps
  | CoreOps
  | SimpOps

-- TODO: make this a hard distinction
type InferenceIR = CoreIR -- used during type inference only

-- | Properties of an IR variant that can be demanded with 'Sat'.
data IRPredicate =
    Is IR
  -- TODO: find a way to make this safe and derive it automatically. For now, we
  -- assert it manually for the valid cases we know about.
  | IsSubsetOf IR
  | HasFeature IRFeature

-- | @Sat r p@ holds exactly when @Sat' r p@ reduces to 'True'.
type Sat (r::IR) (p::IRPredicate) = (Sat' r p ~ True) :: Constraint

-- | Type-level decision procedure for 'IRPredicate'. The equations of this
-- closed type family are tried in order, so the final catch-all makes every
-- combination not listed above it 'False'.
type family Sat' (r::IR) (p::IRPredicate) where
  Sat' r (Is r) = True
  -- subsets
  Sat' SimpIR (IsSubsetOf CoreIR) = True
  -- DAMOps
  Sat' SimpIR (HasFeature DAMOps) = True
  -- SimpOps
  Sat' SimpIR (HasFeature SimpOps) = True
  -- CoreOps
  Sat' CoreIR (HasFeature CoreOps) = True
  -- otherwise
  Sat' _ _ = False
57 |
-- | Reflects a type-level 'IR' index down to its value-level tag.
class IRRep (r::IR) where
  -- | The 'IR' tag for @r@; select the instance with a type application.
  getIRRep :: IR

-- | Carries a type-level 'IR' index and nothing else.
data IRProxy (r::IR) = IRProxy

-- | Lift a runtime 'IR' tag to the type level. The continuation receives a
-- proxy for the matching index, together with that index's 'IRRep' instance.
interpretIR :: IR -> (forall r. IRRep r => IRProxy r -> a) -> a
interpretIR CoreIR k = k (IRProxy @CoreIR)
interpretIR SimpIR k = k (IRProxy @SimpIR)

instance IRRep CoreIR where getIRRep = CoreIR
instance IRRep SimpIR where getIRRep = SimpIR
70 |
-- | Evidence that two IR indices are the same type.
data IRsEqual (r1::IR) (r2::IR) where
  IRsEqual :: IRsEqual r r

-- | Runtime check that two IR indices coincide.
--
-- The coercion is sound only because each 'IRRep' instance returns the tag of
-- its own index (@IRRep CoreIR@ gives 'CoreIR', @IRRep SimpIR@ gives
-- 'SimpIR'). Equal tags therefore imply equal indices.
eqIRRep :: forall r1 r2. (IRRep r1, IRRep r2) => Maybe (IRsEqual r1 r2)
eqIRRep
  | getIRRep @r1 /= getIRRep @r2 = Nothing
  | otherwise =
      Just (TrulyUnsafe.unsafeCoerce (IRsEqual :: IRsEqual r1 r1) :: IRsEqual r1 r2)
{-# INLINE eqIRRep #-}
80 |
-- | A payload of type @a@ that can only be built when @r@ and @r'@ are the same
-- IR; the constructor's type carries that equality.
data WhenIR (r::IR) (r'::IR) (a::Type) where
  WhenIR :: a -> WhenIR r r a

-- Only the payload is serialised. On deserialisation the equality is recovered
-- at runtime with 'eqIRRep'. Two different IRs have no valid encoding, hence
-- the error branch.
instance (IRRep r, IRRep r', Store e) => Store (WhenIR r r' e) where
  size = VarSize \(WhenIR payload) -> getSize payload
  poke (WhenIR payload) = poke payload
  peek = case eqIRRep @r @r' of
    Nothing -> error "impossible"
    Just IRsEqual -> WhenIR <$> peek

instance Hashable a => Hashable (WhenIR r r' a) where
  hashWithSalt s (WhenIR payload) = hashWithSalt s payload

deriving instance Show a => Show (WhenIR r r' a)
deriving instance Eq a => Eq (WhenIR r r' a)
96 |
--------------------------------------------------------------------------------
/src/lib/JAX/Rename.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2023 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module JAX.Rename (liftRenameM, renameClosedJaxpr, renameJaxpr) where
8 |
9 | import Control.Monad.Reader
10 | import Data.Map qualified as M
11 |
12 | import Core
13 | import IRVariants
14 | import JAX.Concrete
15 | import MTL1
16 | import Name
17 |
-- | Monad for renaming JAX source names to Dex atom names. It reads the
-- current 'SourceMap' on top of a scope reader.
newtype RenamerM (n::S) (a:: *) =
  RenamerM { runRenamerM :: ReaderT1 SourceMap (ScopeReaderM) n a }
  deriving ( Functor, Applicative, Monad
           , ScopeReader, ScopeExtender)

-- | Maps each in-scope JAX source name to the Dex name it was renamed to.
newtype SourceMap (n::S) = SourceMap
  (M.Map JSourceName (Name (AtomNameC SimpIR) n))
  deriving (Semigroup, Monoid)

-- NOTE(review): no real sinking proof is provided; sinking a SourceMap relies
-- on this proof never being forced. If it ever is, the descriptive error makes
-- the cause obvious (a bare 'undefined' does not).
instance SinkableE SourceMap where
  sinkingProofE = error "JAX.Rename: sinkingProofE for SourceMap is not implemented"
29 |
-- | The current source-name mapping.
askSourceMap :: RenamerM n (SourceMap n)
askSourceMap = RenamerM ask

-- | Run @cont@ with @sname@ additionally mapped to @name@.
--
-- The 'Semigroup' instance of 'M.Map' is a left-biased union, so the new entry
-- goes on the left. That way a rebinding of an existing source name shadows
-- the outer binding instead of being silently ignored.
extendSourceMap :: JSourceName -> (Name (AtomNameC SimpIR)) n
                -> RenamerM n a -> RenamerM n a
extendSourceMap sname name (RenamerM cont) = RenamerM do
  let ext = SourceMap $ M.singleton sname name
  local (ext <>) cont

-- | Run a renaming action starting from an empty source map.
liftRenameM :: EnvReader m => RenamerM n (e n) -> m n (e n)
liftRenameM act = liftScopeReaderM $ runReaderT1 mempty $ runRenamerM act
41 |
-- | Rename the jaxpr inside a closed jaxpr; the constants are kept as-is.
renameClosedJaxpr :: Distinct o => ClosedJaxpr i -> RenamerM o (ClosedJaxpr o)
renameClosedJaxpr ClosedJaxpr{jaxpr, consts} =
  (\renamed -> ClosedJaxpr{jaxpr = renamed, consts}) <$> renameJaxpr jaxpr

-- | Rename a jaxpr. Binders are processed in scope order: input binders, then
-- constant binders, then equations. The output atoms are resolved last, so
-- they see every binding.
renameJaxpr :: Distinct o => Jaxpr i -> RenamerM o (Jaxpr o)
renameJaxpr (Jaxpr inBs constBs eqns outs) =
  renameJBinders inBs \inBs' ->
    renameJBinders constBs \constBs' ->
      renameJEqns eqns \eqns' ->
        Jaxpr inBs' constBs' eqns' <$> traverse renameJAtom outs
54 |
-- | Rename one source binder. A fresh Dex name is allocated, the continuation
-- runs with the source name mapped to it, and the continuation receives the
-- renamed binder. Binders that already carry an internal name are rejected.
renameJBinder :: Distinct o
              => JBinder i i'
              -> (forall o'. DExt o o' => JBinder o o' -> RenamerM o' a)
              -> RenamerM o a
renameJBinder (JBind _ _ _) _ = error "Shouldn't be source-renaming internal names"
renameJBinder (JBindSource sname ty) cont =
  withFreshM (getNameHint sname) \fresh -> do
    Distinct <- getDistinct
    extendSourceMap sname (binderName fresh) (cont (JBind sname ty fresh))

-- | Rename a nest of binders left to right; each binder is in scope for the
-- ones after it.
renameJBinders :: Distinct o
               => Nest JBinder i i'
               -> (forall o'. DExt o o' => Nest JBinder o o' -> RenamerM o' a)
               -> RenamerM o a
renameJBinders bs cont = case bs of
  Empty -> cont Empty
  Nest b rest ->
    renameJBinder b \b' ->
      renameJBinders rest \rest' ->
        cont (Nest b' rest')
76 |
-- | Rename an atom. Literals are rebuilt unchanged.
renameJAtom :: JAtom i -> RenamerM o (JAtom o)
renameJAtom (JVariable jvar) = JVariable <$> renameJVar jvar
renameJAtom (JLiteral jlit) = pure (JLiteral jlit)

-- | Rename a variable occurrence; its type annotation is kept.
renameJVar :: JVar i -> RenamerM o (JVar o)
renameJVar JVar{sourceName, ty} = (`JVar` ty) <$> renameJSourceNameOr sourceName

-- | Resolve a source name through the current source map, remembering the
-- original source name alongside the internal one. Unbound names and names
-- that are already internal are errors.
renameJSourceNameOr :: JSourceNameOr (Name (AtomNameC SimpIR)) i
                    -> RenamerM o (JSourceNameOr (Name (AtomNameC SimpIR)) o)
renameJSourceNameOr (InternalName _ _) = error "Shouldn't be source-renaming internal names"
renameJSourceNameOr (SourceName sname) = do
  SourceMap sm <- askSourceMap
  maybe (error $ "Unbound variable " ++ show sname)
        (pure . InternalName sname)
        (M.lookup sname sm)
96 |
-- | Rename one equation. The input atoms are resolved in the outer scope
-- first; the output binders then scope over the continuation.
renameJEqn :: Distinct o
           => JEqn i i'
           -> (forall o'. DExt o o' => JEqn o o' -> RenamerM o' a)
           -> RenamerM o a
renameJEqn JEqn{outvars, primitive, invars} cont =
  traverse renameJAtom invars >>= \invars' ->
    renameJBinders outvars \outvars' ->
      cont (JEqn outvars' primitive invars')

-- | Rename a nest of equations left to right; each equation's outputs are in
-- scope for the ones after it.
renameJEqns :: Distinct o
            => Nest JEqn i i'
            -> (forall o'. DExt o o' => Nest JEqn o o' -> RenamerM o' a)
            -> RenamerM o a
renameJEqns eqns cont = case eqns of
  Empty -> cont Empty
  Nest eqn rest ->
    renameJEqn eqn \eqn' ->
      renameJEqns rest \rest' ->
        cont (Nest eqn' rest')
114 |
115 |
--------------------------------------------------------------------------------
/src/lib/LLVM/Link.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2022 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module LLVM.Link
8 | ( createLinker, destroyLinker
9 | , addExplicitLinkMap, addObjectFile, getFunctionPointer
10 | , ExplicitLinkMap
11 | ) where
12 |
13 | import Data.String (fromString)
14 | import Foreign.Ptr
15 | import qualified Data.ByteString as BS
16 |
17 | import System.IO
18 | import System.IO.Temp
19 |
20 | import qualified LLVM.OrcJIT as OrcJIT
21 | import qualified LLVM.Internal.OrcJIT as OrcJIT
22 | import qualified LLVM.Internal.Target as Target
23 | import qualified LLVM.Internal.FFI.Target as FFI
24 |
25 | import qualified LLVM.Shims
26 |
-- | A JIT linker built on llvm-hs's ORC bindings. The object-linking layer is
-- platform-specific: 'OrcJIT.ObjectLinkingLayer' on macOS and
-- 'OrcJIT.RTDyldObjectLinkingLayer' elsewhere.
data Linker = Linker
  { linkerExecutionSession :: OrcJIT.ExecutionSession
#ifdef darwin_HOST_OS
  , linkerLinkLayer :: OrcJIT.ObjectLinkingLayer
#else
  , linkerLinkLayer :: OrcJIT.RTDyldObjectLinkingLayer
#endif
  , _linkerTargetMachine :: Target.TargetMachine
    -- We ought to just need the link layer and the mangler, but llvm-hs
    -- requires a full `IRCompileLayer` for incidental reasons. TODO: fix.
  , linkerIRLayer :: OrcJIT.IRCompileLayer
  , linkerDylib :: OrcJIT.JITDylib }

-- llvm-hs requires a compile/IR layer, even though the linking functions we
-- call don't actually need it. Every method therefore delegates to the wrapped
-- 'IRCompileLayer'. TODO: update llvm-hs to expose more precise requirements
-- for its linking functions.
instance OrcJIT.IRLayer Linker where
  getIRLayer = OrcJIT.getIRLayer . linkerIRLayer
  getDataLayout = OrcJIT.getDataLayout . linkerIRLayer
  getMangler = OrcJIT.getMangler . linkerIRLayer

-- | Name of a C symbol.
type CName = String

-- | Symbols to define explicitly in the JIT, each bound to an absolute address.
type ExplicitLinkMap = [(CName, Ptr ())]
51 |
-- | Create a linker for the host. It sets up a fresh ORC execution session, a
-- platform-specific object-linking layer, and an IR compile layer over it. It
-- also creates a "main_dylib" with a dynamic-library search generator for the
-- current process.
createLinker :: IO Linker
createLinker = do
  -- TODO: should this be a parameter to `createLinker` instead?
  targetMachine <- LLVM.Shims.newDefaultHostTargetMachine
  session <- OrcJIT.createExecutionSession
#ifdef darwin_HOST_OS
  objLayer <- OrcJIT.createObjectLinkingLayer session
#else
  objLayer <- OrcJIT.createRTDyldObjectLinkingLayer session
#endif
  mainDylib <- OrcJIT.createJITDylib session "main_dylib"
  irLayer <- OrcJIT.createIRCompileLayer session objLayer targetMachine
  OrcJIT.addDynamicLibrarySearchGeneratorForCurrentProcess irLayer mainDylib
  pure Linker
    { linkerExecutionSession = session
    , linkerLinkLayer = objLayer
    , _linkerTargetMachine = targetMachine
    , linkerIRLayer = irLayer
    , linkerDylib = mainDylib }
66 |
-- | Release a linker's resources. The execution session is disposed first,
-- then the target machine. The dylib, link layer and IR layer should get
-- cleaned up automatically.
destroyLinker :: Linker -> IO ()
destroyLinker Linker { linkerExecutionSession = session
                     , _linkerTargetMachine = Target.TargetMachine tmPtr } = do
  OrcJIT.disposeExecutionSession session
  FFI.disposeTargetMachine tmPtr
72 |
-- | Define every entry of the map as an absolute symbol in the linker's dylib,
-- bound to the given address.
addExplicitLinkMap :: Linker -> ExplicitLinkMap -> IO ()
addExplicitLinkMap l linkMap = do
  mangledNames <- traverse (OrcJIT.mangleSymbol l . fromString . fst) linkMap
  let symbols = [ OrcJIT.JITSymbol (ptrToWordPtr addr) absoluteFlags
                | (_, addr) <- linkMap ]
  OrcJIT.defineAbsoluteSymbols (linkerDylib l) (zip mangledNames symbols)
  -- The mangled names are only needed for the definition above.
  mapM_ OrcJIT.disposeMangledSymbol mangledNames
  where
    absoluteFlags = OrcJIT.defaultJITSymbolFlags { OrcJIT.jitSymbolAbsolute = True }
81 |
-- | Add an in-memory object file to the linker's dylib. The bytes are first
-- written to a temporary file, because the llvm-hs call takes a path.
addObjectFile :: Linker -> BS.ByteString -> IO ()
addObjectFile l objFileContents =
  withSystemTempFile "objfile.o" \tmpPath tmpHandle -> do
    BS.hPut tmpHandle objFileContents
    -- Make sure the bytes reach the file before LLVM reads it by path.
    hFlush tmpHandle
    OrcJIT.addObjectFile (linkerLinkLayer l) (linkerDylib l) tmpPath
88 |
-- | Look up a symbol in the linker's dylib and return its address as a
-- function pointer. Calls 'error' if the lookup fails.
getFunctionPointer :: Linker -> CName -> IO (FunPtr a)
getFunctionPointer l name = do
  lookupResult <- OrcJIT.lookupSymbol (linkerExecutionSession l) (linkerIRLayer l)
                                      (linkerDylib l) (fromString name)
  case lookupResult of
    Left err -> error $ "Couldn't find function: " ++ name ++ "\n" ++ show err
    Right (OrcJIT.JITSymbol addr _) -> pure $ castPtrToFunPtr $ wordPtrToPtr addr
96 |
--------------------------------------------------------------------------------
/src/lib/LLVM/Shims.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2021 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module LLVM.Shims (
8 | newTargetMachine, newHostTargetMachine, disposeTargetMachine,
9 | newDefaultHostTargetMachine
10 | ) where
11 |
12 | import qualified Data.Map as M
13 | import qualified Data.ByteString.Char8 as BS
14 | import qualified Data.ByteString.Short as SBS
15 |
16 | import qualified LLVM.Relocation as R
17 | import qualified LLVM.CodeModel as CM
18 | import qualified LLVM.CodeGenOpt as CGO
19 | import qualified LLVM.Internal.Target as Target
20 | import qualified LLVM.Internal.FFI.Target as Target.FFI
21 | import LLVM.Prelude (ShortByteString, ByteString)
22 | import LLVM.Internal.Coding (encodeM)
23 |
-- llvm-hs doesn't expose any way to manage target machines in a non-bracketed way

-- | Create a target machine outside any bracket; release it with
-- 'disposeTargetMachine'. CPU features are passed to LLVM as a comma-separated
-- list of @+name@ (enabled) and @-name@ (disabled) entries.
newTargetMachine :: Target.Target
                 -> ShortByteString
                 -> ByteString
                 -> M.Map Target.CPUFeature Bool
                 -> Target.TargetOptions
                 -> R.Model
                 -> CM.Model
                 -> CGO.Level
                 -> IO Target.TargetMachine
newTargetMachine (Target.Target targetFFI) triple cpu features
                 (Target.TargetOptions targetOptFFI)
                 relocModel codeModel cgoLevel =
  SBS.useAsCString triple \tripleC ->
    BS.useAsCString cpu \cpuC ->
      BS.useAsCString featureString \featuresC -> do
        relocC <- encodeM relocModel
        codeModelC <- encodeM codeModel
        optLevelC <- encodeM cgoLevel
        Target.TargetMachine <$> Target.FFI.createTargetMachine
          targetFFI tripleC cpuC featuresC
          targetOptFFI relocC codeModelC optLevelC
  where
    featureString = BS.intercalate "," (map renderFeature (M.toList features))
    renderFeature (Target.CPUFeature name, enabled) =
      (if enabled then "+" else "-") <> name
49 |
-- | Host target machine with the settings Dex uses for JITed code:
-- position-independent code, aggressive optimisation, and a platform-dependent
-- code model.
--
-- The code model goes together with the object-linking layer that LLVM.Link
-- selects. On macOS it uses `ObjectLinkingLayer` and the small code model here.
-- Elsewhere it uses `RTDyldObjectLinkingLayer` and the large code model. The
-- large model is there because libc functions are loaded very far away from
-- the JITed code. The RuntimeDyld linker would otherwise try to store their
-- offsets in 32-bit values that cannot represent them, leading to segfaults
-- that are very fun to debug.
-- NOTE(review): the small model on macOS presumably relies on the JITLink-based
-- ObjectLinkingLayer not having this limitation; confirm.
-- Larger code models might hurt performance if we end up doing a lot of
-- function calls.
-- TODO: Consider changing the linking layer everywhere, as suggested in:
-- http://llvm.1065342.n5.nabble.com/llvm-dev-ORC-JIT-Weekly-5-td135203.html
newDefaultHostTargetMachine :: IO Target.TargetMachine
newDefaultHostTargetMachine = LLVM.Shims.newHostTargetMachine R.PIC cm CGO.Aggressive
  where
#ifdef darwin_HOST_OS
    cm = CM.Small
#else
    cm = CM.Large
#endif
64 | cm = CM.Large
65 | #endif
66 |
67 | newHostTargetMachine :: R.Model -> CM.Model -> CGO.Level -> IO Target.TargetMachine
68 | newHostTargetMachine relocModel codeModel cgoLevel = do
69 | Target.initializeAllTargets
70 | triple <- Target.getProcessTargetTriple
71 | (target, _) <- Target.lookupTarget Nothing triple
72 | cpu <- Target.getHostCPUName
73 | features <- Target.getHostCPUFeatures
74 | Target.withTargetOptions \targetOptions ->
75 | newTargetMachine target triple cpu features targetOptions relocModel codeModel cgoLevel
76 |
77 | disposeTargetMachine :: Target.TargetMachine -> IO ()
78 | disposeTargetMachine (Target.TargetMachine tmFFI) = Target.FFI.disposeTargetMachine tmFFI
79 |
--------------------------------------------------------------------------------
/src/lib/Live/Web.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2019 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module Live.Web (runWeb, generateHTML) where
8 |
9 | import Control.Concurrent (readChan)
10 | import Control.Monad (forever)
11 |
12 | import Network.Wai (Application, StreamingBody, pathInfo,
13 | responseStream, responseLBS, responseFile)
14 | import Network.Wai.Handler.Warp (run)
15 | import Network.HTTP.Types (status200, status404)
16 | import Data.Aeson (ToJSON, encode)
17 | import Data.Binary.Builder (fromByteString)
18 | import Data.ByteString.Lazy (toStrict)
19 | import qualified Data.ByteString as BS
20 | import System.Directory (withCurrentDirectory)
21 |
22 | -- import Paths_dex (getDataFileName)
23 | import RenderHtml
24 | import Live.Eval
25 | import TopLevel
26 |
-- | Evaluate @fname@ with 'watchAndEvalFile' and serve the streamed results on
-- port 8000.
runWeb :: FilePath -> EvalConfig -> TopStateEx -> IO ()
runWeb fname opts env = do
  server <- watchAndEvalFile fname opts env
  putStrLn "Streaming output to http://localhost:8000/"
  run 8000 (serveResults server)
32 |
-- | Working directory used while rendering standalone HTML.
pagesDir :: FilePath
pagesDir = "pages"

-- | Evaluate @sourcePath@ non-interactively and render the results as a
-- standalone HTML page at @destPath@. Rendering runs inside 'pagesDir', so a
-- relative @destPath@ is resolved against it.
generateHTML :: FilePath -> FilePath -> EvalConfig -> TopStateEx -> IO ()
generateHTML sourcePath destPath cfg env =
  evalFileNonInteractive sourcePath cfg env
    >>= renderResults
    >>= withCurrentDirectory pagesDir . renderStandaloneHTML destPath
42 |
-- | WAI application for the live page. It serves the static page assets and
-- streams results as server-sent events on @/getnext@; any other path gets a
-- 404. Every request path is printed to stdout.
serveResults :: EvalServer -> Application
serveResults resultsSubscribe request respond = do
  print (pathInfo request)
  case pathInfo request of
    []            -> serveStatic "static/dynamic.html" "text/html"
    ["style.css"] -> serveStatic "static/style.css" "text/css"
    ["index.js"]  -> serveStatic "static/index.js" "text/javascript"
    ["getnext"]   -> respond $ responseStream status200 sseHeaders
                                 (resultStream resultsSubscribe)
    _             -> respond $ responseLBS status404
                                 [("Content-Type", "text/plain")] "404 - Not Found"
  where
    sseHeaders = [ ("Content-Type", "text/event-stream")
                 , ("Cache-Control", "no-cache") ]
    -- Files are served from paths relative to the working directory, which
    -- lets us skip rebuilding during development. An installed build would
    -- resolve them with `getDataFileName` (Paths_dex) instead.
    serveStatic path ctype =
      respond $ responseFile status200 [("Content-Type", ctype)] path Nothing
61 |
-- | Streaming body for @/getnext@, in this order:
--
-- * send a "start" packet;
-- * subscribe to the evaluation server;
-- * send the initial rendering;
-- * forward one rendered update per message read from the subscription
--   channel, forever.
resultStream :: EvalServer -> StreamingBody
resultStream resultsServer write flush = do
  emit ("start"::String)
  (initResult, updates) <- subscribeIO resultsServer
  (renderedInit, renderUpdate) <- renderResultsInc initResult
  emit renderedInit
  forever $ emit =<< renderUpdate =<< readChan updates
  where
    -- Frame one value as an SSE packet, write it, and flush it to the client.
    emit :: ToJSON a => a -> IO ()
    emit x = do
      write $ fromByteString $ encodePacket x
      flush
72 |
-- | JSON-encode a value as a single server-sent-event @data:@ field. The
-- trailing blank line ends the event.
encodePacket :: ToJSON a => a -> BS.ByteString
encodePacket x = toStrict ("data:" <> encode x <> "\n\n")
76 |
--------------------------------------------------------------------------------
/src/lib/Serialize.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2019 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module Serialize (HasPtrs (..), takePtrSnapshot, restorePtrSnapshot) where
8 |
9 | import Prelude hiding (pi, abs)
10 | import Control.Monad
11 | import qualified Data.ByteString as BS
12 | import Data.ByteString.Internal (memcpy)
13 | import Data.ByteString.Unsafe (unsafeUseAsCString)
14 | import Data.Int
15 | import Data.Store hiding (size)
16 | import Foreign.Ptr
17 | import Foreign.Marshal.Array
18 | import GHC.Generics (Generic)
19 |
20 | import Types.Primitives
21 |
-- Allocation primitives from the Dex runtime. NOTE(review): presumably
-- `dex_allocation_size` reports the byte size of a block obtained from
-- `malloc_dex`; confirm against the runtime.
foreign import ccall "malloc_dex"
  dexMalloc :: Int64 -> IO (Ptr ())
foreign import ccall "dex_allocation_size"
  dexAllocSize :: Ptr () -> IO Int64

-- | A value paired with a list of pointer snapshots.
data WithSnapshot a = WithSnapshot a [PtrSnapshot]
  deriving Generic

-- | An untyped pointer.
type RawPtr = Ptr ()

-- | Types whose embedded raw pointers can be visited, and replaced, by an
-- applicative traversal. The callback also receives each pointer's 'PtrType'.
class HasPtrs a where
  traversePtrs :: Applicative f => (PtrType -> RawPtr -> f RawPtr) -> a -> f a
30 |
-- | Copy the memory behind a CPU pointer literal into a snapshot.
--
-- * Allocations of pointer type are followed recursively.
-- * Any other allocation is copied as raw bytes.
-- * Null pointers stay null.
-- * GPU pointers and values that are already snapshots are rejected.
takePtrSnapshot :: PtrType -> PtrLitVal -> IO PtrLitVal
takePtrSnapshot _ NullPtr = pure NullPtr
takePtrSnapshot (CPU, baseTy) (PtrLitVal addr) = PtrSnapshot <$> snapshotAt baseTy
  where
    -- The allocation holds child pointers: snapshot each of them.
    snapshotAt (PtrType childTy) =
      PtrArray <$> (loadPtrPtrs addr >>= traverse (takePtrSnapshot childTy))
    -- Any other element type: copy the raw bytes.
    snapshotAt _ = ByteArray <$> loadPtrBytes addr
takePtrSnapshot (GPU, _) _ = error "Snapshots of GPU memory not implemented"
takePtrSnapshot _ (PtrSnapshot _) = error "Already a snapshot"
{-# SCC takePtrSnapshot #-}
41 |
-- | Copy the whole allocation starting at @ptr@ into a fresh 'BS.ByteString'.
-- The size is the one the Dex runtime reports for the allocation.
--
-- 'BS.packCStringLen' copies the bytes in a single block, instead of first
-- materialising them as a boxed @[Word8]@ list. A non-positive size yields an
-- empty string.
loadPtrBytes :: RawPtr -> IO BS.ByteString
loadPtrBytes ptr = do
  numBytes <- fromIntegral <$> dexAllocSize ptr
  BS.packCStringLen (castPtr ptr, max 0 numBytes)
46 |
-- | Read the allocation at @ptr@ as an array of child pointers. The length is
-- the runtime-reported byte size divided by 'ptrSize'. Null entries become
-- 'NullPtr'.
loadPtrPtrs :: RawPtr -> IO [PtrLitVal]
loadPtrPtrs ptr = do
  byteCount <- fromIntegral <$> dexAllocSize ptr
  children <- peekArray (byteCount `div` ptrSize) (castPtr ptr)
  pure (map toLit children)
  where
    toLit child
      | child == nullPtr = NullPtr
      | otherwise = PtrLitVal child
55 |
-- | Copy a snapshot back into freshly allocated Dex memory and return a
-- pointer literal to it. For a pointer array, each child is restored first, in
-- order, and their addresses are stored into a new pointer array. Restoring
-- something that is not a snapshot is an error.
restorePtrSnapshot :: PtrLitVal -> IO PtrLitVal
restorePtrSnapshot NullPtr = pure NullPtr
restorePtrSnapshot (PtrSnapshot (PtrArray children)) = do
  childAddrs <- forM children (restorePtrSnapshot >=> toAddr)
  PtrLitVal <$> storePtrPtrs childAddrs
  where
    toAddr NullPtr = pure nullPtr
    toAddr (PtrLitVal addr) = pure addr
    toAddr (PtrSnapshot _) = error "expected a pointer literal"
restorePtrSnapshot (PtrSnapshot (ByteArray bytes)) = PtrLitVal <$> storePtrBytes bytes
restorePtrSnapshot (PtrLitVal _) = error "not a snapshot"
{-# SCC restorePtrSnapshot #-}
69 |
-- | Allocate Dex memory of exactly the byte string's length, copy the bytes
-- into it, and return the new pointer.
storePtrBytes :: BS.ByteString -> IO RawPtr
storePtrBytes bytes = do
  let len = BS.length bytes
  dest <- dexMalloc (fromIntegral len)
  -- Safe: the source buffer is only read, and the pointer does not escape.
  unsafeUseAsCString bytes \src ->
    memcpy (castPtr dest) (castPtr src) len
  pure dest
78 |
-- | Allocate Dex memory for an array of pointers, write them into it, and
-- return the new pointer.
storePtrPtrs :: [RawPtr] -> IO RawPtr
storePtrPtrs ptrs = do
  block <- dexMalloc . fromIntegral $ ptrSize * length ptrs
  pokeArray (castPtr block) ptrs
  pure block

-- === instances ===

-- Serialisation uses the Generic-based default methods.
instance Store a => Store (WithSnapshot a)
88 |
--------------------------------------------------------------------------------
/src/lib/Simplify.hs-boot:
--------------------------------------------------------------------------------
1 | -- Copyright 2023 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module Simplify (linearizeTopFun) where
8 |
9 | import Name
10 | import Builder
11 | import Types.Core
12 | import Types.Top
13 |
-- | Boot signature used to break the module import cycle; the definition
-- lives in "Simplify".
linearizeTopFun
  :: (Mut n, Fallible1 m, TopBuilder m)
  => LinearizationSpec n
  -> m n (TopFunName n, TopFunName n)
15 |
--------------------------------------------------------------------------------
/src/lib/Types/OpNames.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2023 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | -- This module contains payload-free versions of the ops defined in Types.Core.
8 | -- It uses the same constructor names so it should be imported qualified.
9 |
10 | module Types.OpNames where
11 |
12 | import IRVariants
13 | import Data.Hashable
14 | import GHC.Generics (Generic (..))
15 | import Data.Store (Store (..))
16 |
17 | import PPrint
18 |
-- Constructor order is significant: the Store instances below are derived
-- generically from it. Do not reorder constructors.

-- | Type-constructor names.
data TC = ProdType | SumType | RefType | TypeKind | HeapType
  deriving (Show, Eq, Generic)

-- | Data-constructor names.
data Con = ProdCon | SumCon Int | HeapVal
  deriving (Show, Eq, Generic)

-- | Binary scalar operations.
data BinOp
  = IAdd | ISub | IMul | IDiv | ICmp CmpOp | FAdd | FSub | FMul
  | FDiv | FCmp CmpOp | FPow | BAnd | BOr | BShL | BShR | IRem | BXor
  deriving (Show, Eq, Generic)

-- | Unary scalar operations.
data UnOp
  = Exp | Exp2 | Log | Log2 | Log10 | Log1p | Sin | Cos | Tan | Sqrt | Floor
  | Ceil | Round | LGamma | Erf | Erfc | FNeg | BNot
  deriving (Show, Eq, Generic)

-- | Comparison operators carried by 'ICmp' and 'FCmp'.
data CmpOp = Less | Greater | Equal | LessEqual | GreaterEqual
  deriving (Show, Eq, Generic)

data MemOp = IOAlloc | IOFree | PtrOffset | PtrLoad | PtrStore
  deriving (Show, Eq, Generic)

data MiscOp
  = Select | CastOp | BitcastOp | UnsafeCoerce | GarbageVal | Effects
  | ThrowError | ThrowException | Tag | SumTag | Create | ToEnum
  | OutputStream | ShowAny | ShowScalar
  deriving (Show, Eq, Generic)

data VectorOp = VectorBroadcast | VectorIota | VectorIdx | VectorSubref
  deriving (Show, Eq, Generic)

-- | Higher-order ops. The IR index is phantom: no constructor mentions it.
data Hof (r::IR)
  = While | RunReader | RunWriter | RunState | RunIO | RunInit
  | CatchException | Linearize | Transpose
  deriving (Show, Eq, Generic)

data DAMOp = Seq | RememberDest | AllocDest | Place | Freeze
  deriving (Show, Eq, Generic)

data RefOp = MAsk | MExtend | MGet | MPut | IndexRef | ProjRef Projection
  deriving (Show, Eq, Generic)

data Projection
  = UnwrapNewtype -- TODO: add `HasCore r` constraint
  | ProjectProduct Int
  deriving (Show, Eq, Generic)

data UserEffectOp = Handle | Resume | Perform
  deriving (Show, Eq, Generic)

-- Hashing: Generic-based defaults.
instance Hashable BinOp
instance Hashable UnOp
instance Hashable CmpOp
instance Hashable TC
instance Hashable Con
instance Hashable MemOp
instance Hashable MiscOp
instance Hashable VectorOp
instance Hashable (Hof r)
instance Hashable DAMOp
instance Hashable RefOp
instance Hashable UserEffectOp
instance Hashable Projection

-- Serialisation: Generic-based defaults.
instance Store BinOp
instance Store UnOp
instance Store CmpOp
instance Store TC
instance Store Con
instance Store MemOp
instance Store MiscOp
instance Store VectorOp
instance IRRep r => Store (Hof r)
instance Store DAMOp
instance Store RefOp
instance Store UserEffectOp
instance Store Projection

instance Pretty Projection where
  pretty UnwrapNewtype = "u"
  pretty (ProjectProduct i) = pretty i
127 |
--------------------------------------------------------------------------------
/src/old/Imp/Optimize.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2020 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module Imp.Optimize (liftCUDAAllocations) where
8 |
9 | import Control.Monad
10 |
11 | import PPrint
12 | import Env
13 | import Cat
14 | import Syntax
15 | import Imp.Builder
16 |
-- TODO: DCE!

-- | Element type and per-thread element count of one lifted allocation.
type AllocInfo = (BaseType, Int)
-- | Allocations lifted out of the kernel being rewritten, each paired with the
-- new kernel argument that replaces it.
type FuncAllocEnv = [(IBinder, AllocInfo)]
-- | For every kernel, the allocations that were lifted out of it.
type ModAllocEnv = Env [AllocInfo]

-- | Hoist constant-size GPU heap allocations out of CUDA kernels.
--
-- Each such allocation inside a kernel becomes an extra kernel argument.
-- Inside the kernel, the allocated variable is replaced by this thread's slice
-- of that buffer, at offset global-thread-id * size. Every launch of the
-- kernel is amended to allocate those buffers, sized for the launch's total
-- thread count, and to pass them in.
liftCUDAAllocations :: ImpModule -> ImpModule
liftCUDAAllocations m =
    fst $ runCat (traverseImpModule rewriteFunction m) mempty
  where
    rewriteFunction :: Env IFunVar -> ImpFunction -> Cat ModAllocEnv ImpFunction
    rewriteFunction funEnv fun = case fun of
      FFIFunction _ -> pure fun
      ImpFunction (fname:>IFunType cc argTys retTys) oldArgBs oldBody -> case cc of
        CUDAKernelLaunch -> do
          let ((newArgBs, newBody), liftedEnv) =
                flip runCat mempty $ runISubstBuilderT (ISubstEnv mempty funEnv) $ do
                  -- The first three arguments (tid, wid, wsz) are combined
                  -- below into the global thread index.
                  ~freshArgs@(tid:wid:wsz:_) <- traverse freshIVar oldArgBs
                  body <- extendValSubst (newEnv oldArgBs $ fmap IVar freshArgs) $ buildScoped $ do
                    globalTid <- iadd (IVar tid) =<< imul (IVar wid) (IVar wsz)
                    evalImpBlock (liftAlloc globalTid) oldBody
                  pure (fmap Bind freshArgs, body)
              (allocBs, allocInfos) = unzip liftedEnv
              newFunTy = IFunType cc (argTys ++ fmap binderAnn allocBs) retTys
          extend $ fname @> allocInfos
          pure $ ImpFunction (fname :> newFunTy) (newArgBs ++ allocBs) newBody
        _ -> traverseImpFunction amendLaunch funEnv fun

    -- Traversal for kernel bodies. A GPU heap allocation of constant size
    -- becomes a fresh kernel argument, and every free inside the kernel is
    -- dropped.
    liftAlloc :: IExpr -> ITraversalDef (Cat FuncAllocEnv)
    liftAlloc globalTid = (liftDecl, traverseImpInstr self)
      where
        self = liftAlloc globalTid
        liftDecl decl = case decl of
          ImpLet [b] (Alloc addrSpace ty (IIdxRepVal size)) -> case addrSpace of
            Stack -> traverseImpDecl self decl
            Heap CPU -> error "Unexpected CPU allocation in a CUDA kernel"
            Heap GPU -> do
              argVar <- freshIVar b
              liftSE $ extend [(Bind argVar, (ty, fromIntegral size))]
              slice <- ptrOffset (IVar argVar) =<< imul globalTid (IIdxRepVal size)
              pure $ b @> slice
          ImpLet _ (Alloc _ _ _) ->
            error $ "Failed to lift an allocation out of a CUDA kernel: " ++ pprint decl
          ImpLet _ (Free _) -> pure mempty
          _ -> traverseImpDecl self decl

    -- Traversal for non-kernel code. A launch of a kernel that had allocations
    -- lifted out of it gets one extra buffer per lifted allocation, sized for
    -- every thread of the launch.
    amendLaunch :: ITraversalDef (Cat ModAllocEnv)
    amendLaunch = (traverseImpDecl amendLaunch, amendInstr)
      where
        amendInstr :: ImpInstr -> ISubstBuilderT (Cat ModAllocEnv) ImpInstr
        amendInstr instr = case instr of
          ILaunch oldFun oldSize oldArgs -> do
            launchSize <- traverseIExpr oldSize
            args <- traverse traverseIExpr oldArgs
            lifted <- liftSE $ looks (! oldFun)
            launchFun <- traverseIFunVar oldFun
            extraArgs <-
              if null lifted
                then pure []
                else do
                  ~[numWorkgroups, workgroupSize] <- emit $ IQueryParallelism launchFun launchSize
                  nthreads <- imul numWorkgroups workgroupSize
                  forM lifted $ \(ty, perThread) -> do
                    totalSize <- imul (IIdxRepVal $ fromIntegral perThread) nthreads
                    alloc (Heap GPU) ty totalSize
            pure $ ILaunch launchFun launchSize (args ++ extraArgs)
          _ -> traverseImpInstr amendLaunch instr
84 |
--------------------------------------------------------------------------------
/src/old/MLIR/Eval.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2021 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module MLIR.Eval where
8 |
9 | import Data.Function
10 | import qualified Data.ByteString.Char8 as BSC8
11 | import qualified Data.ByteString as BS
12 | import GHC.Stack
13 |
14 | import qualified MLIR.AST as AST
15 | import qualified MLIR.AST.Serialize as AST
16 | import qualified MLIR.Native as Native
17 | import qualified MLIR.Native.Pass as Native
18 | import qualified MLIR.Native.ExecutionEngine as Native
19 |
20 |
21 | import Syntax
22 | -- TODO(apaszke): Separate the LitVal operations from LLVMExec
23 | import LLVMExec
24 |
-- | JIT-compile an MLIR module and run its @entry@ function.
--
-- The AST is converted into a native module, verified, lowered via
-- bufferization (tensor/std/finalizing, nested under @func@) followed by
-- the memref-to-LLVM and std-to-LLVM conversions, verified again and handed
-- to the MLIR execution engine. @args@ are written into a buffer of
-- @length args@ cells and @length resultTypes@ cells are allocated for the
-- results; both buffers are passed to @entry@, after which the results are
-- decoded according to @resultTypes@.
--
-- Failures are fatal. An invalid module or a failed pass goes through
-- 'error' (see 'verifyModule' / 'throwOnFailure'). A failed module
-- conversion, execution-engine construction or entry-point invocation is
-- reported with 'fail' (an 'IOError'), with a message naming the step.
evalModule :: AST.Operation -> [LitVal] -> [BaseType] -> IO [LitVal]
evalModule ast args resultTypes =
  Native.withContext \ctx -> do
    Native.registerAllDialects ctx
    mOp <- AST.fromAST ctx (mempty, mempty) ast
    -- Previously a partial `Just m <- ...` bind with an opaque failure.
    m <- Native.moduleFromOperation mOp
           >>= maybe (fail "evalModule: could not obtain a module from the converted operation") return
    verifyModule m
    Native.withPassManager ctx \pm -> do
      throwOnFailure "Failed to parse pass pipeline" $
        Native.addParsedPassPipeline pm $ BS.intercalate ","
          [ "func(tensor-bufferize,std-bufferize,finalizing-bufferize)"
          , "convert-memref-to-llvm"
          , "convert-std-to-llvm"
          ]
      Native.runPasses pm m & throwOnFailure "Failed to lower module"
    verifyModule m
    -- Previously a partial lambda pattern `\(Just eng) ->`, which crashed
    -- with a bare PatternMatchFail when no engine could be created.
    Native.withExecutionEngine m \mEng -> case mEng of
      Nothing -> fail "evalModule: failed to create an MLIR execution engine"
      Just eng ->
        Native.withStringRef "entry" \name ->
          allocaCells (length args) \argsPtr ->
            allocaCells (length resultTypes) \resultPtr -> do
              storeLitVals argsPtr args
              invoked <- Native.executionEngineInvoke @() eng name
                [Native.SomeStorable argsPtr, Native.SomeStorable resultPtr]
              case invoked of
                Just () -> return ()
                Nothing -> fail "evalModule: failed to invoke the entry function"
              loadLitVals resultPtr resultTypes
49 |
-- | Check a module with the MLIR verifier. On failure, abort via 'error'
-- with a message that includes the printed module text.
verifyModule :: HasCallStack => Native.Module -> IO ()
verifyModule m = do
  valid <- Native.moduleAsOperation m >>= Native.verifyOperation
  if valid
    then return ()
    else do
      printed <- Native.showModule m
      error $ "Invalid module:\n" ++ BSC8.unpack printed
58 |
-- | Run an MLIR action that reports a 'Native.LogicalResult' and abort with
-- the given message (via 'error') when it reports failure.
throwOnFailure :: String -> IO Native.LogicalResult -> IO ()
throwOnFailure msg action = action >>= check
  where
    check Native.Success = return ()
    check Native.Failure = error msg
65 |
--------------------------------------------------------------------------------
/stack-llvm-head.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Use of this source code is governed by a BSD-style
4 | # license that can be found in the LICENSE file or at
5 | # https://developers.google.com/open-source/licenses/bsd
6 |
7 | resolver: lts-16.31
8 |
9 | packages:
10 | - .
11 |
12 | extra-deps:
13 | - github: llvm-hs/llvm-hs
14 | commit: aba6986a644916239ad414f0966b40f2faffa5f3
15 | subdirs:
16 | - llvm-hs
17 | - llvm-hs-pure
18 | - github: google/mlir-hs
19 | commit: 7a4f4984c71e8fb0d7730bc541e9f2daf1971073
20 | - megaparsec-8.0.0
21 | - prettyprinter-1.6.2
22 | - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001
23 | - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430
24 | - th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869
25 | - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442
26 |
27 |
--------------------------------------------------------------------------------
/stack-macos.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2019 Google LLC
2 | #
3 | # Use of this source code is governed by a BSD-style
4 | # license that can be found in the LICENSE file or at
5 | # https://developers.google.com/open-source/licenses/bsd
6 |
7 | resolver: lts-18.23
8 |
9 | packages:
10 | - .
11 |
12 | extra-deps:
13 | - github: llvm-hs/llvm-hs
14 | commit: 423220bffac4990d019fc088c46c5f25310d5a33
15 | subdirs:
16 | - llvm-hs
17 | - llvm-hs-pure
18 | - megaparsec-8.0.0
19 | - prettyprinter-1.6.2
20 | - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001
21 | - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430
22 | - th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869
23 | - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442
24 |
25 | flags:
26 | llvm-hs:
27 | shared-llvm: false
28 |
--------------------------------------------------------------------------------
/stack.yaml:
--------------------------------------------------------------------------------
1 | # Copyright 2020 Google LLC
2 | #
3 | # Use of this source code is governed by a BSD-style
4 | # license that can be found in the LICENSE file or at
5 | # https://developers.google.com/open-source/licenses/bsd
6 |
7 | resolver: lts-18.23
8 |
9 | packages:
10 | - .
11 |
12 | extra-deps:
13 | - github: llvm-hs/llvm-hs
14 | commit: 423220bffac4990d019fc088c46c5f25310d5a33
15 | subdirs:
16 | - llvm-hs
17 | - llvm-hs-pure
18 | - megaparsec-8.0.0
19 | - prettyprinter-1.6.2
20 | - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001
21 | - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430
22 | - th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869
23 | - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442
24 |
25 | nix:
26 | enable: false
27 | packages: [ libpng llvm_12 pkg-config zlib ]
28 |
29 | ghc-options:
30 | containers: -fno-prof-auto -O2
31 | hashable: -fno-prof-auto -O2
32 | llvm-hs-pure: -fno-prof-auto -O2
33 | llvm-hs: -fno-prof-auto -O2
34 | megaparsec: -fno-prof-auto -O2
35 | parser-combinators: -fno-prof-auto -O2
36 | prettyprinter: -fno-prof-auto -O2
37 | store-core: -fno-prof-auto -O2
38 | store: -fno-prof-auto -O2
39 | unordered-containers: -fno-prof-auto -O2
40 |
--------------------------------------------------------------------------------
/static/dynamic.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Dex Output
6 |
7 |
8 |
9 |
10 | (hover over code for more information)
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/static/style.css:
--------------------------------------------------------------------------------
1 | /* Copyright 2019 Google LLC */
2 | /* */
3 | /* Use of this source code is governed by a BSD-style */
4 | /* license that can be found in the LICENSE file or at */
5 | /* https://developers.google.com/open-source/licenses/bsd */
6 |
7 | body {
8 | font-family: Helvetica, sans-serif;
9 | font-size: 100%;
10 | color: #333;
11 | overflow-x: hidden;
12 | padding-bottom:50vw;
13 | }
14 |
15 | #main-output {
16 | margin-left: 20px;
17 | }
18 | #minimap {
19 | display: flex;
20 | flex-direction: column;
21 | position: fixed;
22 | top: 0em;
23 | left: 0em;
24 | height: 85vh;
25 | width: 32px;
26 | overflow: hidden;
27 | }
28 | .status {
29 | flex: 1;
30 | width : 30px;
31 | border-top: 1px solid;
32 | border-color: lightgray;
33 | margin-left: 1px;
34 | }
35 | #hover-info {
36 | position: fixed;
37 | height: 15vh;
38 | bottom: 0em;
39 | width: 100vw;
40 | overflow: hidden;
41 | background-color: white;
42 | border-top: 1px solid firebrick;
43 | font-family: monospace;
44 | white-space: pre;
45 | }
46 | /* cell structure */
47 | .cell {
48 | margin-left: 5px;
49 | display: flex;
50 | }
51 | .line-nums {
52 | flex: 0 0 3em;
53 | height: 100%;
54 | text-align: right;
55 | color: #808080;
56 | font-family: monospace;
57 | white-space: pre;
58 | }
59 | .contents {
60 | margin-left: 1em;
61 | font-family: monospace;
62 | white-space: pre;
63 | }
64 |
65 | /* special results */
66 | .err-result {
67 | font-weight: bold;
68 | color: #B22222;
69 | }
70 |
71 | /* status colors */
72 | .status-inert {}
73 | .status-waiting {background-color: gray;}
74 | .status-running {background-color: lightblue;}
75 | .status-err {background-color: red;}
76 | .status-success {background-color: white;}
77 |
78 | /* span highlighting */
79 | .highlight-error {
80 | text-decoration: red underline;
81 | text-decoration-thickness: 5px;
82 | text-decoration-skip-ink: none;}
83 | .highlight-group { background-color: yellow; }
84 | .highlight-scope { background-color: lightyellow; }
85 | .highlight-binder { background-color: lightblue; }
86 | .highlight-occ { background-color: yellow; }
87 | .highlight-leaf { background-color: lightgray; }
88 |
89 | /* lexeme colors */
90 | .comment {color: gray;}
91 | .keyword {color: #0000DD;}
92 | .command {color: #A80000;}
93 | .symbol {color: #E07000;}
94 | .type-name {color: #A80000;}
95 |
96 | .status-hover {
97 | background-color: yellow;
98 | }
99 |
--------------------------------------------------------------------------------
/tests/algeff-tests.dx:
--------------------------------------------------------------------------------
1 | effect Exn
2 | ctl raise : (a: Type) ?-> Unit -> a
3 |
4 | handler catch_ of Exn r : Maybe r
5 | ctl raise = \_. Nothing
6 | return = \x. Just x
7 |
8 | handler bad_catch_1 of Exn r : Maybe r
9 | ctl raise = \_. Nothing
10 | ctl raise = \_. Nothing -- duplicate!
11 | return = \x. Just x
12 | > Type error:Duplicate operation: raise
13 |
14 | handler bad_catch_2 of Exn r : Maybe r
15 | ctl raise = \_. Nothing
16 | > Type error:missing return
17 | -- return = \x. Just x -- missing!
18 |
19 | handler bad_catch_3 of Exn r : Maybe r
20 | -- ctl raise = \_. Nothing -- missing!
21 | return = \x. Just x
22 | > Type error:Missing operation: raise
23 |
24 | handler bad_catch_4 of Exn r : Maybe r
25 | ctl raise = \_. 42.0 -- definitely not Maybe
26 | return = \x. Just x
27 | > Type error:
28 | > Expected: (Maybe r)
29 | > Actual: Float32
30 | >
31 | > ctl raise = \_. 42.0 -- definitely not Maybe
32 | > ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
33 |
34 | handler bad_catch_5 of Exn r : Maybe r
35 | ctl raise = \_. Nothing
36 | return = \x. 42.0 -- definitely not Maybe
37 | > Type error:
38 | > Expected: (Maybe r)
39 | > Actual: Float32
40 | >
41 | > return = \x. 42.0 -- definitely not Maybe
42 | > ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
43 |
44 | handler bad_catch_6 of Exn r : Maybe r
45 | def raise = \_. Nothing -- wrong policy!
46 | return = \x. Just x
47 | > Type error:operation raise was declared with def but defined with ctl
48 |
49 | def check (b:Bool) : {Exn} Unit =
50 | if not b then raise ()
51 |
52 | def checkFloatNonNegative (x:Float) : {Exn} Float =
53 | check $ x >= 0.0
54 | x
55 |
56 | -- catch_ \_.
57 | -- checkFloatNonNegative (3.14)
58 | -- > Compiler bug!
59 | -- > Please report this at github.com/google-research/dex-lang/issues
60 | -- >
61 | -- > Not implemented
62 | -- > CallStack (from HasCallStack):
63 | -- > error, called at src/lib/Simplify.hs:214:19 in dex-0.1.0.0-8hDfthyGTXmzhkTo2ydOn:Simplify
64 |
65 | -- catch_ \_.
66 | -- checkFloatNonNegative (-1.0)
67 | -- > Compiler bug!
68 | -- > Please report this at github.com/google-research/dex-lang/issues
69 | -- >
70 | -- > Not implemented
71 | -- > CallStack (from HasCallStack):
72 | -- > error, called at src/lib/Simplify.hs:214:19 in dex-0.1.0.0-8hDfthyGTXmzhkTo2ydOn:Simplify
73 |
74 | effect Counter
75 | def inc : Unit -> Unit
76 |
77 | handler runCounter of Counter r {h} (ref : Ref h Nat) : {State h} (r & Nat)
78 | def inc = \_.
79 | ref := (1 + get ref)
80 | resume ()
81 | return = \x. (x, get ref)
82 | > Error: variable not in scope: resume
83 | >
84 | > resume ()
85 | > ^^^^^^^
86 |
--------------------------------------------------------------------------------
/tests/cast-tests.dx:
--------------------------------------------------------------------------------
1 | -- ==== Integral casts ====
2 | --
3 | -- Semantics of internal_cast on integral types are based on the bit representation
4 | -- of the values in question. All WordX types have a bit representation equal to their
5 | -- value in standard binary format. All IntX types use two's complement representation.
6 | --
7 | -- The cast is always performed by taking the source value to its bit representation,
8 | -- resizing that representation (depending on the target signedness), and interpreting
9 | -- the resulting bit pattern in the target type.
10 | --
11 | -- The rules for resizing the bit pattern are as follows:
12 | -- 1. If the target bitwidth is smaller than source bitwidth, the maximum number of
13 | -- least significant bits are preserved.
14 | -- 2. If the target bitwidth is equal to the source bitwidth, nothing happens.
15 | -- 3. If the target bitwidth is greater than the source bitwidth, and:
16 | -- 3a. the target type is signed, the representation is sign-extended (the
17 | -- MSB is used to pad the value up to the desired width).
18 | -- 3b. the target type is unsigned, the representation is zero-extended.
19 |
20 | -- Casts to Int32
21 |
22 | internal_cast(to=Int32, 2147483647 :: Int64)
23 | > 2147483647
24 |
25 | internal_cast(to=Int32, 2147483648 :: Int64)
26 | > -2147483648
27 |
28 | internal_cast(to=Int32, 8589935826 :: Int64) -- 2^33 + 1234
29 | > 1234
30 |
31 | internal_cast(to=Int32, 123 :: Word8)
32 | > 123
33 |
34 | internal_cast(to=Int32, 1234 :: Word32)
35 | > 1234
36 |
37 | internal_cast(to=Int32, 4294967295 :: Word32)
38 | > -1
39 |
40 | internal_cast(to=Int32, 1234 :: Word64)
41 | > 1234
42 |
43 | internal_cast(to=Int32, 4294967295 :: Word64)
44 | > -1
45 |
46 | internal_cast(to=Int32, 4294967296 :: Word64)
47 | > 0
48 |
49 | internal_cast(to=Int32, 5000000000 :: Word64)
50 | > 705032704
51 |
52 | -- Casts to Int64
53 |
54 | internal_cast(to=Int64, 123 :: Int32)
55 | > 123
56 |
57 | internal_cast(to=Int64, -123 :: Int32)
58 | > -123
59 |
60 | internal_cast(to=Int64, 123 :: Word8)
61 | > 123
62 |
63 | internal_cast(to=Int64, 1234 :: Word32)
64 | > 1234
65 |
66 | internal_cast(to=Int64, 4294967296 :: Word64) -- 2^32
67 | > 4294967296
68 |
69 | -- Casts to Word8
70 |
71 | internal_cast(to=Word8, 1234 :: Int32)
72 | > 0xd2
73 |
74 | internal_cast(to=Word8, 1234 :: Int)
75 | > 0xd2
76 |
77 | internal_cast(to=Word8, 1234 :: Word32)
78 | > 0xd2
79 |
80 | internal_cast(to=Word8, 1234 :: Word64)
81 | > 0xd2
82 |
83 | -- Casts to Word32
84 |
85 | internal_cast(to=Word32, 1234 :: Int32)
86 | > 0x4d2
87 |
88 | internal_cast(to=Word32, -2147483648 :: Int32)
89 | > 0x80000000
90 |
91 | internal_cast(to=Word32, 1234 :: Int64)
92 | > 0x4d2
93 |
94 | internal_cast(to=Word32, 4294968530 :: Int64) -- 2^32 + 1234
95 | > 0x4d2
96 |
97 | internal_cast(to=Word32, -1 :: Int64)
98 | > 0xffffffff
99 |
100 | internal_cast(to=Word32, 123 :: Word8)
101 | > 0x7b
102 |
103 | internal_cast(to=Word32, 1234 :: Word64)
104 | > 0x4d2
105 |
106 | internal_cast(to=Word32, 4294967296 :: Word64)
107 | > 0x0
108 |
109 | -- Casts to Word64
110 |
111 | internal_cast(to=Word64, 1234 :: Int32)
112 | > 0x4d2
113 |
114 | internal_cast(to=Word64, -1 :: Int32)
115 | > 0xffffffff
116 |
117 | internal_cast(to=Word64, 1234 :: Int64)
118 | > 0x4d2
119 |
120 | internal_cast(to=Word64, -1 :: Int64)
121 | > 0xffffffffffffffff
122 |
123 | internal_cast(to=Word64, 123 :: Word8)
124 | > 0x7b
125 |
126 | internal_cast(to=Word64, 1234 :: Word32)
127 | > 0x4d2
128 |
129 | internal_cast(to=Word64, 4294967295 :: Word32)
130 | > 0xffffffff
131 |
--------------------------------------------------------------------------------
/tests/complex-tests.dx:
--------------------------------------------------------------------------------
1 | import complex
2 |
3 | :p complex_floor $ Complex 0.3 0.6
4 | > Complex(0., 0.)
5 | :p complex_floor $ Complex 0.6 0.8
6 | > Complex(0., 1.)
7 | :p complex_floor $ Complex 0.8 0.6
8 | > Complex(1., 0.)
9 | :p complex_floor $ Complex 0.6 0.3
10 | > Complex(0., 0.)
11 |
12 | a = Complex 2.1 0.4
13 | b = Complex (-1.1) 1.3
14 | :p (a + b - a) ~~ b
15 | > True
16 | :p (a * b) ~~ (b * a)
17 | > True
18 | :p divide (a * b) a ~~ b
19 | > True
20 | -- This next test can be added once we parameterize the field in the VSpace typeclass.
21 | --:p ((a * b) / a) ~~ b
22 | --> True
23 | :p a == b
24 | > False
25 | :p a == a
26 | > True
27 | :p log (exp a) ~~ a
28 | > True
29 | :p exp (log a) ~~ a
30 | > True
31 | :p log2 (exp2 a) ~~ a
32 | > True
33 | :p exp2 (log2 a) ~~ a
34 | > True
35 | :p sqrt (sq a) ~~ a
36 | > True
37 | :p sqrt (Complex (-1.0) 0.0) ~~ (Complex 0.0 1.0)
38 | > True
39 | :p log ((Complex 1.0 0.0) + a) ~~ log1p a
40 | > True
41 | :p sin (-a) ~~ (-(sin a))
42 | > True
43 | :p cos (-a) ~~ cos a
44 | > True
45 | :p tan (-a) ~~ (- (tan a))
46 | > True
47 | :p exp (pi .* (Complex 0.0 1.0)) ~~ (Complex (-1.0) 0.0) -- Euler's identity
48 | > True
49 | :p ((sq (sin a)) + (sq (cos a))) ~~ (Complex 1.0 0.0)
50 | > True
51 | :p complex_abs b > 0.0
52 | > True
53 |
54 | :p sinh (Complex 1.2 3.2)
55 | > Complex(-1.506887, -0.1056956)
56 | :p cosh (Complex 1.2 3.2)
57 | > Complex(-1.807568, 0.08811359)
58 | :p tanh (Complex 1.1 0.1)
59 | > Complex(0.8033752, 0.03580933)
60 | :p tan (Complex 1.2 3.2)
61 | > Complex(0.002250167, 1.002451)
62 |
--------------------------------------------------------------------------------
/tests/exception-tests.dx:
--------------------------------------------------------------------------------
1 |
2 |
3 | def checkFloatInUnitInterval(x:Float) -> {Except} Float =
4 | assert $ x >= 0.0
5 | assert $ x <= 1.0
6 | x
7 |
8 | :p catch \. assert False
9 | > Nothing
10 |
11 | :p catch \. assert True
12 | > (Just ())
13 |
14 | :p catch \. checkFloatInUnitInterval 1.2
15 | > Nothing
16 |
17 | :p catch \. checkFloatInUnitInterval (-1.2)
18 | > Nothing
19 |
20 | :p catch \. checkFloatInUnitInterval 0.2
21 | > (Just 0.2)
22 |
23 | :p yield_state 0 \ref.
24 | catch \.
25 | ref := 1
26 | assert False
27 | ref := 2
28 | > 1
29 |
30 | :p catch \.
31 | for i:(Fin 5).
32 | if ordinal i > 3
33 | then throw()
34 | else 23
35 | > Nothing
36 |
37 | :p catch \.
38 | for i:(Fin 3).
39 | if ordinal i > 3
40 | then throw()
41 | else 23
42 | > (Just [23, 23, 23])
43 |
44 | -- Is this the result we want?
45 | :p yield_state zero \ref.
46 | catch \.
47 | for i:(Fin 6).
48 | if (ordinal i `rem` 2) == 0
49 | then throw()
50 | else ()
51 | ref!i := 1
52 | > [0, 1, 0, 1, 0, 1]
53 |
54 | :p catch \.
55 | run_state 0 \ref.
56 | ref := 1
57 | assert False
58 | ref := 2
59 | > Nothing
60 |
61 | -- https://github.com/google-research/dex-lang/issues/612
62 | def sashabug(h: ()) -> {Except} List Int =
63 | yield_state mempty \results.
64 | results := (get results) <> AsList 1 [2]
65 |
66 | catch \. (catch \. sashabug ())
67 | > (Just (Just (AsList 1 [2])))
68 |
--------------------------------------------------------------------------------
/tests/fft-tests.dx:
--------------------------------------------------------------------------------
1 | import complex
2 | import fft
3 |
4 | :p map nextpow2 [0, 1, 2, 3, 4, 7, 8, 9, 1023, 1024, 1025]
5 | > [0, 0, 1, 2, 2, 3, 3, 4, 10, 10, 11]
6 |
7 | a : (Fin 4)=>Complex = arb $ new_key 0
8 | :p a ~~ (ifft $ fft a)
9 | > True
10 | :p a ~~ (fft $ ifft a)
11 | > True
12 |
13 | b : (Fin 20)=>(Fin 70)=>Complex = arb $ new_key 0
14 | :p b ~~ (ifft2 $ fft2 b)
15 | > True
16 | :p b ~~ (fft2 $ ifft2 b)
17 | > True
18 |
--------------------------------------------------------------------------------
/tests/gpu-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | x = for i:(Fin 5). i_to_f $ ordinal i
3 | x
4 | > [0., 1., 2., 3., 4.]
5 |
6 | x + x
7 | > [0., 2., 4., 6., 8.]
8 |
9 | -- TODO: Make it a FileCheck test
10 | testNestedParallelism =
11 | for i:(Fin 10).
12 | x = ordinal i
13 | q = for j:(Fin 2000). i_to_f $ x * ordinal j
14 | (2.0 .* q, 4.0 .* q)
15 | (fst testNestedParallelism.(2@_)).(5@_)
16 | > 20.
17 |
18 | -- TODO: Make it a FileCheck test
19 | testNestedLoops =
20 | for i:(Fin 10).
21 | for j:(Fin 20).
22 | ordinal i * ordinal j
23 | testNestedLoops.(4@_).(5@_)
24 | > 20
25 |
26 | -- The state is large enough such that it shouldn't fit on the stack of a
27 | -- single GPU thread. It should get lifted to a top-level allocation instead.
28 | -- allocationLiftingTest =
29 | -- for i:(Fin 100).
30 | -- yieldState (for j:(Fin 1000). ordinal i) $ \s.
31 | -- s!(0@_) := get s!(0@_) + 1
32 | -- (allocationLiftingTest.(4@_).(0@_), allocationLiftingTest.(4@_).(1@_))
33 | -- > (5, 4)
34 |
--------------------------------------------------------------------------------
/tests/inline-tests.dx:
--------------------------------------------------------------------------------
1 | -- The "=== inline ===" strings below are a hack around the fact that
2 | -- Dex currently does two passes of inlining and prints the results of
3 | -- both. Surrounding the CHECK block with these commands constrains
4 | -- the body to occur in the output from the first inlining pass.
5 |
6 | @noinline
7 | def id'(x:Nat) -> Nat = x
8 |
9 | -- CHECK-LABEL: Inline for into for
10 | "Inline for into for"
11 |
12 | %passes inline
13 | :pp
14 | xs = for i:(Fin 10). ordinal i
15 | for j. xs[j] + 2
16 | -- CHECK: === inline ===
17 | -- CHECK: for
18 | -- CHECK-NOT: for
19 | -- CHECK: === inline ===
20 |
21 | -- CHECK-LABEL: Inline for into sum
22 | "Inline for into sum"
23 |
24 | %passes inline
25 | :pp sum for i:(Fin 10). ordinal i
26 | -- CHECK: === inline ===
27 | -- CHECK: for
28 | -- CHECK-NOT: for
29 | -- CHECK: === inline ===
30 |
31 | -- CHECK-LABEL: Inline nested for into for
32 | "Inline nested for into for"
33 |
34 | %passes inline
35 | :pp
36 | xs = for i:(Fin 10). for j:(Fin 20). ordinal i * ordinal j
37 | for j i. xs[i, j] + 2
38 | -- CHECK: === inline ===
39 | -- CHECK: for
40 | -- CHECK: for
41 | -- CHECK-NOT: for
42 | -- CHECK: === inline ===
43 |
44 | -- CHECK-LABEL: Inlining does not reorder effects
45 | "Inlining does not reorder effects"
46 |
47 | -- Note that it _would be_ legal to reorder independent effects, but
48 | -- the inliner currently does not do that. But the effect in this
49 | -- example is not legal to reorder in any case.
50 |
51 | %passes inline
52 | :pp run_state 0 \ct.
53 | xs = for i:(Fin 10).
54 | ct := (get ct) + 1
55 | ordinal i
56 | for j.
57 | ct := (get ct) * 2
58 | xs[j] + 2
59 | -- CHECK: === inline ===
60 | -- CHECK: for
61 | -- CHECK: for
62 | -- CHECK: === inline ===
63 |
64 | -- CHECK-LABEL: Inlining does not duplicate the inlinee through beta reduction
65 | "Inlining does not duplicate the inlinee through beta reduction"
66 |
67 | -- The check is for the error call in the dynamic check that `ix` has
68 | -- type `Fin 100`.
69 | %passes inline
70 | :pp
71 | ix = (id' 20)@(Fin 100)
72 | (for i:(Fin 100). ordinal i + ordinal i)[ix]
73 | -- CHECK: === inline ===
74 | -- CHECK: error
75 | -- CHECK-NOT: error
76 | -- CHECK: === inline ===
77 |
78 | -- CHECK-LABEL: Inlining does not violate type IR through beta reduction
79 | "Inlining does not violate type IR through beta reduction"
80 |
81 | -- Beta reducing this ix into the `i` index of the `for` should stop
82 | -- before it produces anything a type expression can't handle, and
83 | -- thus execute.
84 |
85 | :p
86 | ix = (1@(Fin 2))
87 | sum (for i:(Fin 2) j:(..i). ordinal j)[ix]
88 | -- CHECK: 1
89 | -- CHECK-NOT: Compiler bug
90 |
91 | -- CHECK-LABEL: Inlining simplifies case-of-known-constructor
92 | "Inlining simplifies case-of-known-constructor"
93 |
94 | -- Inlining xs exposes a case-of-known-constructor opportunity here;
95 | -- the first inlining pass doesn't take it (yet) because it's
96 | -- conservative about inlining `i` into the body of `xs`, but the
97 | -- second pass does.
98 | %passes inline
99 | :pp
100 | xs = for i:(Either (Fin 3) (Fin 4)).
101 | case i of
102 | Left k -> 1
103 | Right k -> 2
104 | for j:(Fin 3). xs[Left j]
105 | -- CHECK: === inline ===
106 | -- CHECK: for
107 | -- CHECK: case
108 | -- CHECK: === inline ===
109 | -- CHECK: for
110 | -- CHECK-NOT: case
111 |
112 | -- CHECK-LABEL: Inlining carries out the case-of-case optimization
113 | "Inlining carries out the case-of-case optimization"
114 |
115 | -- Before inlining there are two cases, but attempting to inline `x`
116 | -- reveals a case-of-case opportunity, which in turn exposes
117 | -- case-of-known-constructor in each branch, leading to just one case
118 | -- in the end.
119 | %passes inline
120 | :pp
121 | x = if id'(3) > 2
122 | then Just 4
123 | else Nothing
124 | case x of
125 | Just a -> a * a
126 | Nothing -> 0
127 | -- CHECK: === inline ===
128 | -- CHECK: case
129 | -- CHECK-NOT: case
130 | -- CHECK: === inline ===
131 |
--------------------------------------------------------------------------------
/tests/instance-interface-syntax-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | interface Empty(a:Type)
3 | pass
4 | -- CHECK-NOT: Parse error
5 |
6 | instance Empty(Int)
7 | pass
8 | -- CHECK-NOT: Parse error
9 |
10 | instance Empty(Float32)
11 | def witness() = 0.0
12 | -- CHECK-NOT: Parse error
13 | -- CHECK: Error: variable not in scope: witness
14 |
15 | interface Inhabited(a)
16 | witness : a
17 | -- CHECK-NOT: Parse error
18 |
19 | instance Inhabited(Int)
20 | witness = 0
21 | -- CHECK-NOT: Parse error
22 |
23 | instance Inhabited(Float64)
24 | witness = f_to_f64(0.0)
25 | pass
26 | -- CHECK: Parse error
27 | -- CHECK: unexpected "pa"
28 | -- CHECK: expecting end of line
29 |
30 | instance Inhabited(Word32)
31 | witness = 0
32 | pass
33 | -- CHECK: Parse error
34 |
--------------------------------------------------------------------------------
/tests/instance-methods-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | interface FooBar0(a)
3 | foo0 : (a) -> Int
4 | bar0 : (a) -> Int
5 |
6 | instance FooBar0(Int)
7 | def foo0(x) = x + 1
8 | def bar0(x) = foo0 x + 1
9 |
10 | w : Int = 42
11 |
12 | -- CHECK: 43
13 | foo0 w
14 | > 43
15 |
16 | -- CHECK: 44
17 | bar0 w
18 | > 44
19 |
20 |
21 | interface FooBar1(a)
22 | foo1 : (a) -> Int
23 | bar1 : (a) -> Int
24 |
25 | instance FooBar1(Int)
26 | foo1 = \x. x + 1
27 | -- Fails: Definition of `bar1` uses the class method `bar1` (with index 1);
28 | -- but the instance `FooBar1 Int` is currently still being defined and, at
29 | -- this point, can only grant access to method `foo1` (with index 0).
30 | bar1 = \x. bar1 x + 1
31 | > Type error:Wrong number of positional arguments provided. Expected 1 but got 0
32 | >
33 | > foo1 = \x. x + 1
34 | > ^^^^^^^^^^^^^^^^
35 | -- CHECK: Type error:Couldn't synthesize a class dictionary for: (FooBar1 Int32)
36 | -- CHECK: bar1 = \x. bar1 x + 1
37 | -- CHECK: ^^^^^
38 |
39 |
40 | interface FooBar2(a)
41 | foo2 : (a) -> Int
42 | bar2 : (a) -> Int
43 |
44 | def f2(x:a) given (a|FooBar2) = (\y. foo2 y + 1) x
45 | -- The definition of `f2` is OK because argument `d : FooBar2 a` grants access to
46 | -- all methods of class `FooBar2 a`. (Only one method of `FooBar2` is actually
47 | -- used in the body of `f2`.)
48 |
49 | def g2(x:a) given (a|FooBar2) = (\y z. foo2 y + z) x (bar2 x)
50 | -- The definition of `g2` is OK because argument `d : FooBar2 a` grants access to
51 | -- all methods of class `FooBar2 a`.
52 |
53 |
54 | instance FooBar2(Int)
55 | def foo2(x) = x + 1
56 | -- Fails: The definition of `bar2` uses `f2`, which requires a dictionary
57 | -- `d : FooBar2 Int` that has access to all methods of `FooBar2 Int`.
58 | def bar2(x) = f2 x + 1
59 | > Type error:Couldn't synthesize a class dictionary for: (FooBar2 Int32)
60 | >
61 | > def bar2(x) = f2 x + 1
62 | > ^^^^^
63 | -- CHECK: Type error:Couldn't synthesize a class dictionary for: (FooBar2 Int32)
64 | -- CHECK: bar2 = \x. f2 x + 1
65 | -- CHECK: ^^^
66 |
67 |
68 | interface Shows0(a)
69 | shows0 : (a) -> String
70 | showsList0 : (List a) -> String
71 |
72 | -- The body of method `showsList0` uses method `shows0` from the same instance.
73 | instance Shows0(Nat)
74 | def shows0(x) = show x
75 | def showsList0(xs) =
76 | AsList(n, ys) = xs
77 | strings = map shows0 ys
78 | reduce "" (<>) strings
79 |
80 | showsList0 (AsList 3 [0, 1, 2])
81 | > "012"
82 | -- CHECK: "012"
83 |
84 | interface Shows1(a)
85 | shows1 : (a) -> String
86 | showsList1 : (List a) -> String
87 |
88 | instance Shows1(Nat)
89 | def shows1(x) = showsList1 (AsList 1 [x])
90 | -- Methods `shows1` and `showsList1` refer to each other in a mutually recursive
91 | -- fashion: the body of method `shows1` uses method `showsList1` from the same
92 | -- instance, and the body of method `showsList1` uses method `shows1` also from
93 | -- this instance.
94 | def showsList1(xs) =
95 | AsList(n, ys) = xs
96 | strings = map shows1 ys
97 | reduce "" (<>) strings
98 | > Type error:Couldn't synthesize a class dictionary for: (Shows1 Nat)
99 | >
100 | > def shows1(x) = showsList1 (AsList 1 [x])
101 | > ^^^^^^^^^^^^^^^^^^^^^^^^^
102 | -- CHECK: Type error:Couldn't synthesize a class dictionary for: (Shows1 Nat)
103 | -- CHECK: shows1 = \x. showsList1 (AsList 1 [x])
104 | -- CHECK: ^^^^^^^^^^^
105 |
--------------------------------------------------------------------------------
/tests/io-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | :p unsafe_io \.
3 | with_temp_file \fname.
4 | with_file fname WriteMode \stream.
5 | fwrite stream "lorem ipsum\n"
6 | fwrite stream "dolor sit amet\n"
7 | read_file fname
8 | > "lorem ipsum
9 | > dolor sit amet
10 | > "
11 |
12 | :p unsafe_io \.
13 | with_alloc 4 \ptr:(Ptr Nat).
14 | for i:(Fin 4). store (ptr +>> ordinal i) (ordinal i)
15 | table_from_ptr(n=Fin 4, ptr)
16 | > [0, 1, 2, 3]
17 |
18 | unsafe_io \.
19 | print "testing log"
20 | 1.0 -- prevent DCE
21 | > testing log
22 | > 1.
23 |
24 | unsafe_io \.
25 | for i':(Fin 10).
26 | i = ordinal i'
27 | if rem i 2 == 0
28 | then print $ show i <> " is even"
29 | else print $ show i <> " is odd"
30 | 1.0 -- prevent DCE
31 | > 0 is even
32 | > 1 is odd
33 | > 2 is even
34 | > 3 is odd
35 | > 4 is even
36 | > 5 is odd
37 | > 6 is even
38 | > 7 is odd
39 | > 8 is even
40 | > 9 is odd
41 | > 1.
42 |
43 | :p storage_size(a=Int)
44 | > 4
45 |
46 | :p unsafe_io \.
47 | with_alloc 1 \ptr:(Ptr Int).
48 | store ptr 3
49 | load ptr
50 | > 3
51 |
52 | :p with_stack Nat \stack.
53 | stack.extend(for i:(Fin 1000). ordinal i)
54 | stack.extend(for i:(Fin 1000). ordinal i)
55 | AsList(_, xs) = stack.read()
56 | sum xs
57 | > 999000
58 |
59 | :p unsafe_io \.
60 | s = for i:(Fin 10000). i_to_w8 $ f_to_i $ 128.0 * rand (ixkey (new_key 0) i)
61 | with_temp_file \fname.
62 | with_file fname WriteMode \stream.
63 | fwrite stream $ AsList _ s
64 | AsList(_, s') = read_file fname
65 | sum (for i. w8_to_i s[i]) == sum (for i. w8_to_i s'[i])
66 | > True
67 |
68 | :p unsafe_io \. get_env "NOT_AN_ENV_VAR"
69 | > Nothing
70 |
71 | :p unsafe_io \. get_env "DEX_TEST_MODE"
72 | > (Just "t")
73 |
74 | :p dex_test_mode()
75 | > True
76 |
--------------------------------------------------------------------------------
/tests/linalg-tests.dx:
--------------------------------------------------------------------------------
1 | import linalg
2 |
3 | -- Check that the optimized matmul gives the same answers as the naive one
4 | amat = for i:(Fin 100) j:(Fin 100). n_to_f $ ordinal (i, j)
5 |
6 | :p tiled_matmul(amat, amat) ~~ naive_matmul amat amat
7 | > True
8 |
9 | -- Check that the inverse of the inverse is identity.
10 | mat = [[11.,9.,24.,2.],[1.,5.,2.,6.],[3.,17.,18.,1.],[2.,5.,7.,1.]]
11 | :p mat ~~ (invert (invert mat))
12 | > True
13 |
14 | -- Check that solving gives the inverse.
15 | v = [1., 2., 3., 4.]
16 | :p v ~~ (mat **. (solve mat v))
17 | > True
18 |
19 | -- Check that det and exp(logdet) are the same.
20 | (s, logdet) = sign_and_log_determinant mat
21 | :p (determinant mat) ~~ (s * (exp logdet))
22 | > True
23 |
24 | -- Matrix integer powers.
25 | :p matrix_power mat 0 ~~ eye
26 | > True
27 | :p matrix_power mat 1 ~~ mat
28 | > True
29 | :p matrix_power mat 2 ~~ (mat ** mat)
30 | > True
31 | :p matrix_power mat 5 ~~ (mat ** mat ** mat ** mat ** mat)
32 | > True
33 |
34 | :p trace mat == (11. + 5. + 18. + 1.)
35 | > True
36 |
37 | -- Check that we can linearize LU decomposition
38 | -- This is a regression test for Issue #842.
39 | snd(linearize (\x. snd $ sign_and_log_determinant [[x]]) 1.0)(2.0)
40 | > 2.
41 |
42 | -- Check that we can differentiate through LU decomposition
43 | -- This is a regression test for Issue #848.
44 | grad (\x. (pivotize [[x]]).sign) 1.0
45 | > 0.
46 |
47 | grad (\x. snd $ sign_and_log_determinant [[x]]) 2.0
48 | > 0.5
49 |
50 | -- Check forward_substitute solve by comparing
51 | -- against zero-padding and doing the full solve.
52 | def padLowerTriMat(mat:LowerTriMat n v) -> n=>n=>v given (n|Ix, v|Add) =
53 | for i j.
54 | if (ordinal j)<=(ordinal i)
55 | then mat[i,unsafe_project j]
56 | else zero
57 |
58 | lower : LowerTriMat (Fin 4) Float = arb $ new_key 0
59 | lower_padded = padLowerTriMat lower
60 | vec : (Fin 4)=>Float = arb $ new_key 0
61 |
62 | forward_substitute lower vec ~~ solve lower_padded vec
63 | > True
64 |
--------------------------------------------------------------------------------
/tests/lower.dx:
--------------------------------------------------------------------------------
1 | for i:(Fin 2) j:(Fin 4). ordinal (i,j)
2 | > [[0, 1, 2, 3], [4, 5, 6, 7]]
3 |
--------------------------------------------------------------------------------
/tests/module-tests.dx:
--------------------------------------------------------------------------------
1 | import test_module_A
2 | import test_module_B
3 |
4 | :p 1 + 1
5 | > 2
6 |
7 | :p test_module_A_val + 4
8 | > 7
9 |
10 | :p test_module_amb
11 | > Error: ambiguous variable: test_module_amb is defined:
12 | > in test_module_A
13 | > in test_module_B
14 | >
15 | >
16 | > :p test_module_amb
17 | > ^^^^^^^^^^^^^^^
18 |
19 | :p test_module_B_val_from_C
20 | > 23
21 |
22 | :p test_module_C_val
23 | > Error: variable not in scope: test_module_C_val
24 | >
25 | > :p test_module_C_val
26 | > ^^^^^^^^^^^^^^^^^
27 |
28 | :p test_module_A_fun 2
29 | > 4
30 |
31 | :p test_module_A_fun_noinline 3
32 | > 6
33 |
34 | :p fooMethodExportFromB 1
35 | > 2
36 |
37 | :p fooMethodExportFromB 1.0
38 | > 10.
39 |
40 | :p arrayVal
41 | > [1, 2, 3]
42 |
43 | :p arrayVal2
44 | > [2, 4, 6]
45 |
--------------------------------------------------------------------------------
/tests/parser-combinator-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | import parser
3 |
4 | parseABC : Parser () = MkParser \h.
5 | parse h $ p_char 'A'
6 | parse h $ p_char 'B'
7 | parse h $ p_char 'C'
8 |
9 | :p run_parser "AAA" parseABC
10 | > Nothing
11 |
12 | :p run_parser "ABCABC" parseABC
13 | > Nothing
14 |
15 | :p run_parser "AB" parseABC
16 | > Nothing
17 |
18 | :p run_parser "ABC" parseABC
19 | > (Just ())
20 |
21 | def parseT() ->> Parser Bool = MkParser \h.
22 | parse h $ p_char 'T'
23 | True
24 |
25 | def parseF() ->> Parser Bool = MkParser \h.
26 | parse h $ p_char 'F'
27 | False
28 |
29 | def parseTF() ->> Parser Bool =
30 | parseT <|> parseF
31 |
32 | def parserTFTriple() ->> Parser (Fin 3=>Bool) = MkParser \h.
33 | for i. parse h parseTF
34 |
35 | :p run_parser "TTF" parserTFTriple
36 | > (Just [True, True, False])
37 |
38 | :p run_parser "TTFX" parserTFTriple
39 | > Nothing
40 |
41 | :p run_parser "TTFFTT" $ parse_many parseTF
42 | > (Just (AsList 6 [True, True, False, False, True, True]))
43 |
44 | :p run_parser "1021389" $ parse_many parse_digit
45 | > (Just (AsList 7 [1, 0, 2, 1, 3, 8, 9]))
46 |
47 | :p run_parser "1389" $ parse_int
48 | > (Just 1389)
49 |
50 | :p run_parser "01389" $ parse_int
51 | > (Just 1389)
52 |
53 | :p run_parser "-1389" $ parse_int
54 | > (Just -1389)
55 |
56 | split ' ' " This is a sentence. "
57 | > (AsList 4 ["This", "is", "a", "sentence."])
58 |
--------------------------------------------------------------------------------
/tests/print-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | :pcodegen [(),(),()]
3 | > [(), (), ()]
4 |
5 | -- :pcodegen {x = 1.0, y = 2}
6 | -- > {x = 1., y = 2}
7 |
8 | :pcodegen (the Nat 60, the Int 60, the Float 60, the Int64 60, the Float64 60)
9 | > (60, 60, 60., 60, 60.)
10 |
11 | :pcodegen (the Word8 60, the Word32 60, the Word64 60)
12 | > (0x3c, 0x3c, 0x3c)
13 |
14 | :pcodegen [Just (Just 1.0), Just Nothing, Nothing]
15 | > [(Just (Just 1.)), (Just Nothing), Nothing]
16 |
17 | data MyType = MyValue(Nat)
18 |
19 | :pcodegen MyValue 1
20 | > (MyValue 1)
21 |
22 | :pcodegen "the quick brown fox jumps over the lazy dog"
23 | > "the quick brown fox jumps over the lazy dog"
24 |
25 | :pcodegen ['a', 'b', 'c']
26 | > [0x61, 0x62, 0x63]
27 |
28 | :pcodegen "abcd"
29 | > "abcd"
30 |
--------------------------------------------------------------------------------
/tests/read-tests.dx:
--------------------------------------------------------------------------------
1 | parseString "123" :: Maybe Float
2 | > (Just 123.)
3 | parseString "123.4" :: Maybe Float
4 | > (Just 123.4)
5 | parseString "123x" :: Maybe Float
6 | > Nothing
7 | parseString "x123" :: Maybe Float
8 | > Nothing
9 |
--------------------------------------------------------------------------------
/tests/repl-multiline-test-expected-output:
--------------------------------------------------------------------------------
1 | >=> >=> >=> >=> >=> >=> ... ... ... ... 30.
2 | >=> >=> ... ... >=> >=> (1, 1)
3 | >=> >=> >=> >=> 3.
4 | >=>
--------------------------------------------------------------------------------
/tests/repl-multiline-test.dx:
--------------------------------------------------------------------------------
1 |
2 | -- comment
3 |
4 | 'Single-line multiline comment
5 |
6 | :p
7 | triple = \x.
8 | y = x + x
9 | x + y
10 | triple 10.0
11 |
12 | f = \x:Int.
13 | (x,
14 | x)
15 |
16 | f 1
17 |
18 | y = 1. * 3.
19 |
20 | :p y
21 |
--------------------------------------------------------------------------------
/tests/repl-regression-528-test-expected-output:
--------------------------------------------------------------------------------
1 | >=> ... ... Parse error:1:6:
2 | |
3 | 1 | :help
4 | | ^
5 | unrecognized command: "help"
6 |
7 | >=>
--------------------------------------------------------------------------------
/tests/repl-regression-528-test.dx:
--------------------------------------------------------------------------------
1 | :help
2 |
3 | asdf
4 |
--------------------------------------------------------------------------------
/tests/serialize-tests.dx:
--------------------------------------------------------------------------------
1 | :p 1
2 | > 1
3 |
4 | :p 1.0
5 | > 1.
6 |
7 | :p [1, 2, 3]
8 | > [1, 2, 3]
9 |
10 | :p [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
11 | > [[1., 2., 3.], [4., 5., 6.]]
12 |
13 | :p from_ordinal(n=Fin 10, 7)
14 | > 7
15 |
16 | :p [True, False]
17 | > [True, False]
18 |
19 | :p ()
20 | > ()
21 |
22 | x = ['a', 'b']
23 | :p for p.
24 | (i,j) = p
25 | [x[i], x[j]]
26 | > [[0x61, 0x61], [0x61, 0x62], [0x62, 0x61], [0x62, 0x62]]
27 |
28 | 'Values without a pretty-printer
29 |
30 | :p Int
31 | > Int32
32 |
33 | :p Fin 10
34 | > (Fin 10)
35 |
36 | :p (Fin 10, Fin 20)
37 | > ((Fin 10), (Fin 20))
38 |
--------------------------------------------------------------------------------
/tests/set-tests.dx:
--------------------------------------------------------------------------------
1 | import set
2 |
3 | -- check order invariance.
4 | :p (to_set ["Bob", "Alice", "Charlie"]) == (to_set ["Charlie", "Bob", "Alice"])
5 | > True
6 |
7 | -- check uniqueness.
8 | :p (to_set ["Bob", "Alice", "Alice", "Charlie"]) == (to_set ["Charlie", "Charlie", "Bob", "Alice"])
9 | > True
10 |
11 | set1 = to_set ["Xeno", "Alice", "Bob"]
12 | set2 = to_set ["Bob", "Xeno", "Charlie"]
13 |
14 | :p set1 == set2
15 | > False
16 |
17 | :p set_union set1 set2
18 | > (UnsafeAsSet 4 ["Alice", "Bob", "Charlie", "Xeno"])
19 |
20 | :p set_intersect set1 set2
21 | > (UnsafeAsSet 2 ["Bob", "Xeno"])
22 |
23 | :p remove_duplicates_from_sorted ["Alice", "Alice", "Alice", "Bob", "Bob", "Charlie", "Charlie", "Charlie"]
24 | > (AsList 3 ["Alice", "Bob", "Charlie"])
25 |
26 | :p set1 == (set_union set1 set1)
27 | > True
28 |
29 | :p set1 == (set_intersect set1 set1)
30 | > True
31 |
32 | '#### Empty set tests
33 |
34 | emptyset = to_set ([]::(Fin 0)=>String)
35 |
36 | :p emptyset == emptyset
37 | > True
38 |
39 | :p emptyset == (set_union emptyset emptyset)
40 | > True
41 |
42 | :p emptyset == (set_intersect emptyset emptyset)
43 | > True
44 |
45 | :p set1 == (set_union set1 emptyset)
46 | > True
47 |
48 | :p emptyset == (set_intersect set1 emptyset)
49 | > True
50 |
51 | '### Set Index Set tests
52 |
53 | names2 = to_set ["Bob", "Alice", "Charlie", "Alice"]
54 |
55 | Person : Type = Element names2
56 |
57 | :p size Person
58 | > 3
59 |
60 | -- Check that ordinal and unsafe_from_ordinal are inverses.
61 | roundTrip = for i:Person.
62 | i == (unsafe_from_ordinal (ordinal i))
63 | :p all roundTrip
64 | > True
65 |
66 | -- Check that member and value are inverses.
67 | roundTrip2 = for i:Person.
68 | s = value i
69 | ix = member s names2
70 | i == from_just ix
71 | :p all roundTrip2
72 | > True
73 |
74 | setix : Person = from_just $ member "Bob" names2
75 | :p setix
76 | > Element(1)
77 |
78 | setix2 : Person = from_just $ member "Charlie" names2
79 | :p setix2
80 | > Element(2)
81 |
--------------------------------------------------------------------------------
/tests/shadow-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | -- repeated vars in patterns not allowed
3 | :p
4 | (x, x) = (1, 1)
5 | x
6 | > Error: variable already defined within pattern: x
7 | >
8 | > (x, x) = (1, 1)
9 | > ^
10 |
11 | :p
12 | f = \p.
13 | (x, x) = p
14 | x
15 | f (1, 1)
16 | > Error: variable already defined within pattern: x
17 | >
18 | > (x, x) = p
19 | > ^
20 |
21 | -- TODO: re-enable if we choose to allow non-peer shadowing
22 | -- -- shouldn't cause error even though it shadows x elsewhere
23 | -- x = 50
24 |
25 | -- :p let x = 100 in (let x = 200 in x)
26 |
27 | -- > [200]
28 |
29 | arr = 10
30 |
31 | -- TODO: enable when we handle this case
32 | -- _ = 10
33 | -- _ = 10 -- underscore shadows allowed
34 |
35 | arr = 20
36 | > Error: variable already defined: arr
37 | >
38 | > arr = 20
39 | > ^^^^
40 |
41 | :p arr
42 | > Error: ambiguous variable: arr is defined:
43 | > in this file
44 | > in this file
45 | >
46 | >
47 | > :p arr
48 | > ^^^
49 |
50 | -- testing top-level shadowing
51 | f : (given (a:Type), a) -> a = \x. x
52 |
53 | x = 1
54 |
55 | :p f 1
56 | > 1
57 |
58 | :p y
59 | > Error: variable not in scope: y
60 | >
61 | > :p y
62 | > ^
63 |
64 | (_, _, z) = (1,2,3)
65 |
66 | :p z
67 | > 3
68 |
69 | :p
70 | (_, _, w) = (1,2,4)
71 | w
72 | > 4
73 |
74 | -- Testing data shadowing
75 | data Shadow = Shadow
76 | > Error: variable already defined: Shadow
77 |
78 | data Shadow2 =
79 | Shadow1(Int)
80 | Shadow1
81 | > Error: variable already defined: Shadow1
82 |
83 | ShadowCon = 1
84 | data Shadow3 = ShadowCon
85 | > Error: variable already defined: ShadowCon
86 |
87 | data Shadow4 = ShadowCon'
88 | ShadowCon' = 1
89 | > Error: variable already defined: ShadowCon'
90 | >
91 | > ShadowCon' = 1
92 | > ^^^^^^^^^^^
93 |
--------------------------------------------------------------------------------
/tests/show-tests.dx:
--------------------------------------------------------------------------------
1 | '# `Show` instances
2 | -- String
3 |
4 | :p show "abc"
5 | > "abc"
6 |
7 | -- Int32
8 |
9 | :p show (1234 :: Int32)
10 | > "1234"
11 |
12 | :p show (-1234 :: Int32)
13 | > "-1234"
14 |
15 | :p show ((f_to_i (-(pow 2. 31.))) :: Int32)
16 | > "-2147483648"
17 |
18 | -- Int64
19 |
20 | :p show (i_to_i64 1234 :: Int64)
21 | > "1234"
22 |
23 | :p show (i_to_i64 (-1234) :: Int64)
24 | > "-1234"
25 |
26 | -- Float32
27 |
28 | :p show (123.456789 :: Float32)
29 | > "123.456787"
30 |
31 | :p show ((pow 2. 16.) :: Float32)
32 | > "65536"
33 |
34 | -- FIXME(https://github.com/google-research/dex-lang/issues/316):
35 | -- Unparenthesized expression with type ascription does not parse.
36 | -- :p show (nan: Float32)
37 |
38 | :p show (nan :: Float32)
39 | > "nan"
40 |
41 | -- Note: `show nan` (Dex runtime dtoa implementation) appears different from
42 | -- `:p nan` (Dex interpreter implementation).
43 | :p nan
44 | > nan
45 |
46 | :p show (infinity :: Float32)
47 | > "inf"
48 |
49 | -- Note: `show infinity` (Dex runtime dtoa implementation) appears different from
50 | -- `:p infinity` (Dex interpreter implementation).
51 | :p infinity
52 | > inf
53 |
54 | -- Float64
55 |
56 | :p show (f_to_f64 123.456789:: Float64)
57 | > "123.456787109375"
58 |
59 | :p show (f_to_f64 (pow 2. 16.):: Float64)
60 | > "65536"
61 |
62 | :p show ((f_to_f64 nan):: Float64)
63 | > "nan"
64 |
65 | -- Note: `show nan` (Dex runtime dtoa implementation) appears different from
66 | -- `:p nan` (Dex interpreter implementation).
67 | :p (f_to_f64 nan)
68 | > nan
69 |
70 | :p show ((f_to_f64 infinity):: Float64)
71 | > "inf"
72 |
73 | -- Note: `show infinity` (Dex runtime dtoa implementation) appears different from
74 | -- `:p infinity` (Dex interpreter implementation).
75 | :p (f_to_f64 infinity)
76 | > inf
77 |
78 | -- Tuples
79 |
80 | :p show (123, 456)
81 | > "(123, 456)"
82 |
83 | :p show ("abc", 123)
84 | > "(abc, 123)"
85 |
86 | :p show ("abc", 123, ("def", 456))
87 | > "(abc, 123, (def, 456))"
88 |
--------------------------------------------------------------------------------
/tests/sort-tests.dx:
--------------------------------------------------------------------------------
1 | import sort
2 |
3 | :p is_sorted $ sort []::((Fin 0)=>Int)
4 | > True
5 | :p is_sorted $ sort [9, 3, 7, 4, 6, 1, 9, 1, 9, -1, 10, 10, 100, 0]
6 | > True
7 |
8 | :p
9 | xs = [1,2,4]
10 | for i:(Fin 6).
11 | search_sorted_exact(xs, ordinal i)
12 | > [Nothing, (Just 0), (Just 1), Nothing, (Just 2), Nothing]
13 |
14 | '### Lexical Sorting Tests
15 |
16 | :p "aaa" < "bbb"
17 | > True
18 |
19 | :p "aa" < "bbb"
20 | > True
21 |
22 | :p "a" < "aa"
23 | > True
24 |
25 | :p "aaa" > "bbb"
26 | > False
27 |
28 | :p "aa" > "bbb"
29 | > False
30 |
31 | :p "a" > "aa"
32 | > False
33 |
34 | :p "a" < "aa"
35 | > True
36 |
37 | :p ("" :: List Word8) > ("" :: List Word8)
38 | > False
39 |
40 | :p ("" :: List Word8) < ("" :: List Word8)
41 | > False
42 |
43 | :p "a" > "a"
44 | > False
45 |
46 | :p "a" < "a"
47 | > False
48 |
49 | :p "Thomas" < "Thompson"
50 | > True
51 |
52 | :p "Thomas" > "Thompson"
53 | > False
54 |
55 | :p is_sorted $ sort ["Charlie", "Alice", "Bob", "Aaron"]
56 | > True
57 |
--------------------------------------------------------------------------------
/tests/stack-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | with_stack Nat \stack.
3 | stack.push 10
4 | stack.push 11
5 | stack.pop()
6 | stack.pop()
7 | > (Just 10)
8 |
9 | with_stack Nat \stack.
10 | stack.push 10
11 | stack.push 11
12 | stack.pop()
13 | stack.pop()
14 | stack.pop() -- Check that popping an empty stack is OK.
15 | stack.push 20
16 | stack.push 21
17 | stack.pop()
18 | > (Just 21)
19 |
20 | with_stack Nat \stack.
21 | stack.pop()
22 | > Nothing
23 |
--------------------------------------------------------------------------------
/tests/standalone-function-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | @noinline
3 | def standalone_sum(xs:n=>v) -> v given (n|Ix, v|Add) =
4 | sum xs
5 |
6 | vec3 = [1,2,3]
7 | vec2 = [4,5]
8 |
9 | -- TODO: test that we only get one copy inlined (hard to do without dumping IR
10 | -- until we have logging for that sort of thing)
11 | :p standalone_sum vec2 + standalone_sum vec3
12 | > 15
13 |
14 | mat23 = [[1,2,3],[4,5,6]]
15 | mat32 = [[1,2],[3,4],[5,6]]
16 |
17 | @noinline
18 | def standalone_transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) =
19 | for i j. x[j,i]
20 |
21 | :p (standalone_transpose mat23, standalone_transpose mat32)
22 | > ([[1, 4], [2, 5], [3, 6]], [[1, 3, 5], [2, 4, 6]])
23 |
24 | xs = [1,2,3]
25 |
26 | @noinline
27 | def foo(_:()) -> Nat = sum xs
28 |
29 | foo ()
30 | > 6
31 |
32 | 'Regression test for #1152. The standalone function is just here to
33 | make the size of the tables unknown. The actual bug is in Algebra
34 | handling an expression like `sum_{i=0}^k k * i` where the same
35 | name occurs in the monomial and the limit.
36 |
37 | def LowerTriMat(n|Ix, v:Type) -> Type = (i:n)=>(..i)=>v
38 | def UpperTriMat(n|Ix, v:Type) -> Type = (i:n)=>(i..)=>v
39 |
40 | @noinline
41 | def bar(n: Nat) -> Float =
42 | (for k. for j:(..k). 0.0, for k. for j:(k..). 0.0) :: (LowerTriMat (Fin n) Float, UpperTriMat (Fin n) Float)
43 | 0.0
44 |
45 | bar 2
46 | > 0.
47 |
--------------------------------------------------------------------------------
/tests/struct-tests.dx:
--------------------------------------------------------------------------------
1 |
2 | struct MyStruct =
3 | field1 : Int
4 | field2 : Float
5 | field3 : String
6 |
7 | my_struct = MyStruct 1 2 "abc"
8 |
9 | :p my_struct.field3
10 | > "abc"
11 |
12 | :p my_struct.(1 + 1)
13 | > Syntax error: Field must be a name
14 | >
15 | > :p my_struct.(1 + 1)
16 | > ^^^^^^^
17 |
18 | > Parse error:12:13:
19 | > |
20 | > 12 | :p my_struct.(1 + 1)
21 | > | ^^
22 | > unexpected ".("
23 | > expecting "->", "..", "<..", "with", backquoted name, end of input, end of line, infix operator, name, or symbol name
24 | :p my_struct
25 | > MyStruct(1, 2., "abc")
26 |
27 | :t my_struct
28 | > MyStruct
29 |
30 | struct MyParametricStruct(a) =
31 | foo : a
32 | bar : Nat
33 |
34 | :p
35 | foo = MyParametricStruct(1.0, 1)
36 | foo.bar
37 | > 1
38 |
39 | :p
40 | foo = MyParametricStruct(1.0, 1)
41 | foo.baz
42 | > Type error:Can't resolve field baz of type (MyParametricStruct Float32)
43 | > Known fields are: [bar, foo, 0, 1]
44 | >
45 | > foo.baz
46 | > ^^^
47 |
48 |
49 | x = (1, 2)
50 |
51 | x.0
52 | > 1
53 |
54 | x.1
55 | > 2
56 |
57 | x.2
58 | > Type error:Can't resolve field 2 of type (Nat, Nat)
59 | > Known fields are: [0, 1]
60 | >
61 | > x.2
62 | > ^
63 |
64 | x.foo
65 | > Type error:Can't resolve field foo of type (Nat, Nat)
66 | > Known fields are: [0, 1]
67 | >
68 | > x.foo
69 | > ^^^
70 |
71 | struct Thing(a|Add) =
72 | x : a
73 | y : a
74 |
75 | def incby(n:a) -> Thing(a) =
76 | Thing(self.x + n, self.y + n)
77 |
78 | Thing(1,2).incby(10)
79 | > Thing(11, 12)
80 |
81 | struct MissingConstraint(n) =
82 | thing : n=>Float
83 | > Type error:Couldn't synthesize a class dictionary for: (Ix n)
84 | >
85 | > thing : n=>Float
86 | > ^^^^^^^^
87 |
88 | data AnotherMissingConstraint(n) =
89 | MkAnotherMissingConstraint(n=>Float)
90 | > Type error:Couldn't synthesize a class dictionary for: (Ix n)
91 | >
92 | > MkAnotherMissingConstraint(n=>Float)
93 | > ^^^^^^^^
94 |
--------------------------------------------------------------------------------
/tests/test_module_A.dx:
--------------------------------------------------------------------------------
1 |
2 | import test_module_C
3 |
4 | test_module_amb = 10
5 |
6 | test_module_A_val = 1 + 2
7 |
8 | def test_module_A_fun(x:Int) -> Int = x + x
9 |
10 | @noinline
11 | def test_module_A_fun_noinline(x:Int) -> Int = x + x
12 |
13 | instance FooClass(Float)
14 | def fooMethod(x) = 10.0 * x
15 |
--------------------------------------------------------------------------------
/tests/test_module_B:
--------------------------------------------------------------------------------
1 |
2 | import test_module_C
3 |
4 | test_module_B_val = 10 + 2
5 |
6 |
--------------------------------------------------------------------------------
/tests/test_module_B.dx:
--------------------------------------------------------------------------------
1 |
2 | import test_module_C
3 |
4 | test_module_amb = 10
5 |
6 | test_module_B_val = 10
7 |
8 | test_module_B_val_from_C = test_module_C_val
9 |
10 | instance FooClass(Nat)
11 | def fooMethod(x) = x + x
12 |
13 | arrayVal = [1,2,3]
14 |
15 | arrayVal2 = for i. arrayVal[i] * 2
16 |
17 | def fooMethodExportFromB(x:a) -> a given (a|FooClass) = fooMethod x
18 |
--------------------------------------------------------------------------------
/tests/test_module_C.dx:
--------------------------------------------------------------------------------
1 |
2 |
3 | test_module_C_val = 23
4 |
5 | interface FooClass(a)
6 | fooMethod : (a) -> a
7 |
--------------------------------------------------------------------------------
/tests/trig-tests.dx:
--------------------------------------------------------------------------------
1 | :p isnan nan
2 | > True
3 | :p isnan 1.0
4 | > False
5 | :p isinf infinity
6 | > True
7 | :p isinf (-infinity)
8 | > True
9 | :p isinf 1.0
10 | > False
11 |
12 | :p either_is_nan infinity nan
13 | > True
14 | :p either_is_nan nan nan
15 | > True
16 |
17 | :p atan2 (sin 0.44) (cos 0.44) ~~ 0.44
18 | > True
19 | :p atan2 (sin (-0.44)) (cos (-0.44)) ~~ (-0.44)
20 | > True
21 | :p atan2 (-sin (-0.44)) (cos (-0.44)) ~~ (0.44)
22 | > True
23 | :p atan2 (-1.0) (-1.0) ~~ (-3.0/4.0*pi)
24 | > True
25 |
26 | -- Test all the way around the circle.
27 | angles = linspace (Fin 11) (-pi + 0.001) (pi)
28 | :p all for i:(Fin 11).
29 | angles[i] ~~ atan2 (sin angles[i]) (cos angles[i])
30 | > True
31 |
32 | :p (atan2 infinity 1.0) ~~ ( pi / 2.0)
33 | > True
34 | :p (atan2 (-infinity) 1.0) ~~ (-pi / 2.0)
35 | > True
36 | :p (atan2 1.0 infinity) ~~ 0.0
37 | > True
38 | :p (atan2 (-1.0) infinity) ~~ 0.0
39 | > True
40 |
41 | :p (atan2 infinity infinity) ~~ ( pi / 4.0)
42 | > True
43 | :p (atan2 infinity (-infinity)) ~~ ( 3.0 * pi / 4.0)
44 | > True
45 | :p (atan2 (-infinity) infinity) ~~ (-pi / 4.0)
46 | > True
47 | :p (atan2 (-infinity) (-infinity)) ~~ (-3.0 * pi / 4.0)
48 | > True
49 |
50 | :p isnan $ atan2 nan infinity
51 | > True
52 | :p isnan $ atan2 infinity nan
53 | > True
54 | :p isnan $ atan2 nan nan
55 | > True
56 |
57 | :p sinh 1.2 ~~ 1.5094614
58 | > True
59 |
60 | :p tanh 1.2 ~~ ((sinh 1.2) / (cosh 1.2))
61 | > True
62 |
63 | :p tanh (f_to_f64 1.2) ~~ divide (sinh (f_to_f64 1.2)) (cosh (f_to_f64 1.2))
64 | > True
65 |
--------------------------------------------------------------------------------
/tests/unit/JaxADTSpec.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2023 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | module JaxADTSpec (spec) where
8 |
9 | import Data.Aeson (encode, decode)
10 | import Test.Hspec
11 |
12 | import Name
13 | import JAX.Concrete
14 | import JAX.Rename
15 | import JAX.ToSimp
16 | import Runtime
17 | import TopLevel
18 | import Types.Imp
19 | import Types.Primitives hiding (Sin)
20 | import Types.Source hiding (SourceName)
21 | import QueryType
22 |
23 | -- | Source names for the input (@x@) and output (@y@) variables of 'a_jaxpr'.
24 | x_nm, y_nm :: JSourceName
25 | x_nm = JSourceName 0 0 "x"
26 | y_nm = JSourceName 1 0 "y"
27 |
28 | -- | An 'F32' array type with an empty list of dimensions, i.e. a scalar.
29 | float :: JVarType
30 | float = JArrayName [] F32
31 |
32 | -- | An 'F32' array type with a single dimension of size 10.
33 | ten_vec :: JVarType
34 | ten_vec = JArrayName [DimSize 10] F32
35 |
36 | -- | A jaxpr that binds @x@, contains a single 'Sin' equation binding @y@
37 | -- from @x@, and returns @[y]@. Both variables are given the type @ty@.
38 | a_jaxpr :: JVarType -> Jaxpr VoidS
39 | a_jaxpr ty = Jaxpr
40 |   (Nest (JBindSource x_nm ty) Empty)
41 |   Empty
42 |   (Nest (JEqn
43 |            (Nest (JBindSource y_nm ty) Empty)
44 |            Sin
45 |            [JVariable $ JVar (SourceName x_nm) ty]) Empty)
46 |   [JVariable $ JVar (SourceName y_nm) ty]
47 |
48 | -- | Renames and simplifies a jaxpr, then compiles it to an LLVM entry
49 | -- function in a freshly initialised top-level environment. Only the first
50 | -- component of the 'runTopperM' result (the packaged callable) is kept.
51 | compile :: Jaxpr VoidS -> IO LLVMCallable
52 | compile jaxpr = do
53 |   let cfg = EvalConfig LLVM [LibBuiltinPath] Nothing Nothing Nothing NoOptimize PrintCodegen
54 |   env <- initTopState
55 |   fst <$> runTopperM cfg env do
56 |     -- TODO Implement GenericE for jaxprs, derive SinkableE, and properly sink
57 |     -- the jaxpr instead of just coercing it.
58 |     Distinct <- getDistinct
59 |     jRename <- liftRenameM $ renameJaxpr (unsafeCoerceE jaxpr)
60 |     jSimp <- liftJaxSimpM (simplifyJaxpr jRename) >>= asTopLam
61 |     compileTopLevelFun (EntryFunCC CUDANotRequired) jSimp >>= packageLLVMCallable
62 |
63 | spec :: Spec
64 | spec = do
65 |   describe "JaxADT" do
66 |     it "round-trips to json" do
67 |       let first = encode $ a_jaxpr ten_vec
68 |       -- Report a decode failure as an ordinary test failure instead of
69 |       -- crashing on an irrefutable 'Just' pattern.
70 |       case (decode first :: Maybe (Jaxpr VoidS)) of
71 |         Nothing -> expectationFailure "could not decode the JSON-encoded jaxpr"
72 |         Just decoded -> encode decoded `shouldBe` first
73 |     it "executes" do
74 |       -- Compile sin(x) on a scalar and compare against Haskell's 'sin'.
75 |       jLLVM <- compile $ a_jaxpr float
76 |       result <- callEntryFun jLLVM [Float32Lit 3.0]
77 |       result `shouldBe` [Float32Lit $ sin 3.0]
78 |
--------------------------------------------------------------------------------
/tests/unit/RawNameSpec.hs:
--------------------------------------------------------------------------------
1 | -- Copyright 2022 Google LLC
2 | --
3 | -- Use of this source code is governed by a BSD-style
4 | -- license that can be found in the LICENSE file or at
5 | -- https://developers.google.com/open-source/licenses/bsd
6 |
7 | {-# OPTIONS_GHC -Wno-orphans #-}
8 |
9 | module RawNameSpec (spec) where
10 |
11 | import Control.Monad
12 | import Data.Char
13 | import Test.Hspec
14 | import Test.QuickCheck
15 | import RawName qualified as R
16 |
17 | -- | Newtype around a unit-valued 'R.RawNameMap' so that it can be given
18 | -- an 'Arbitrary' instance local to this module.
19 | newtype RawNameMap = RMap (R.RawNameMap ())
20 |   deriving (Show)
21 |
22 | -- | Builds a map from as many random raw names as the current QuickCheck size.
23 | instance Arbitrary RawNameMap where
24 |   arbitrary = do
25 |     n <- getSize
26 |     RMap . R.fromList <$> replicateM n ((,()) <$> arbitrary)
27 |
28 | -- | Randomly yields either a hint made from a short alphanumeric string or
29 | -- 'R.noHint'.
30 | instance Arbitrary R.NameHint where
31 |   arbitrary =
32 |     arbitrary >>= \case
33 |       True -> R.getNameHint . fromStringNameHint <$> arbitrary
34 |       False -> return R.noHint -- TODO: Generate more interesting non-string names
35 |
36 | instance Arbitrary R.RawName where
37 |   arbitrary = R.rawNameFromHint <$> arbitrary
38 |
39 | -- | A string of 1 to 7 ASCII letters or digits, used as a name hint. Its
40 | -- 'Show' instance prints the raw string so counterexamples stay readable.
41 | newtype StringNameHint = StringNameHint { fromStringNameHint :: String }
42 |
43 | instance Show StringNameHint where
44 |   show (StringNameHint s) = s
45 |
46 | instance Arbitrary StringNameHint where
47 |   -- Draw straight from the allowed characters rather than rejection-sampling
48 |   -- arbitrary 'Char's with 'suchThat'.
49 |   arbitrary = StringNameHint <$> do
50 |     len <- chooseInt (1, 7)
51 |     replicateM len (elements niceAsciiChars)
52 |
53 | -- | Whether a character is an ASCII letter or an ASCII digit.
54 | isNiceAscii :: Char -> Bool
55 | isNiceAscii h = isAsciiLower h || isAsciiUpper h || isDigit h
56 |
57 | -- | Every character accepted by 'isNiceAscii' (@0-9@, @A-Z@, @a-z@).
58 | niceAsciiChars :: [Char]
59 | niceAsciiChars = filter isNiceAscii ['\0' .. '\127']
60 |
61 | spec :: Spec
62 | spec = do
63 |   describe "RawName" do
64 |     it "generates a fresh name" do
65 |       property \hint (RMap m) -> do
66 |         let name = R.freshRawName hint m
67 |         counterexample (show name) $ not (name `R.member` m)
68 |
69 |     it "repeatedly generates fresh names from the same hint" do
70 |       property \hint (RMap initM) -> do
71 |         -- Insert 512 names one at a time, checking that each is fresh before
72 |         -- it is inserted. Recursing through '&&' stops at the first clash and
73 |         -- avoids the thunk chain a lazy 'foldl' over a tuple would build.
74 |         let allFresh :: Int -> R.RawNameMap () -> Bool
75 |             allFresh 0 _ = True
76 |             allFresh k m =
77 |               let name = R.freshRawName hint m
78 |               in not (name `R.member` m) && allFresh (k - 1) (R.insert name () m)
79 |         allFresh 512 initM
80 |
81 |     it "string names are in a bijection with short strings" do
82 |       property \(StringNameHint s) -> do
83 |         let s' = show (R.rawNameFromHint (R.getNameHint s))
84 |         counterexample s' $ s == s'
85 |
86 |     it "string names with non-zero counters print correctly" do
87 |       property \(StringNameHint s) -> do
88 |         let hint = R.getNameHint s
89 |         let scope = R.singleton (R.rawNameFromHint hint) ()
90 |         -- '===' reports both sides when the printed name does not match.
91 |         show (R.freshRawName hint scope) === s ++ ".1"
92 |
--------------------------------------------------------------------------------
/tests/unit/Spec.hs:
--------------------------------------------------------------------------------
1 | {-# OPTIONS_GHC -F -pgmF hspec-discover #-}
2 |
--------------------------------------------------------------------------------