├── .gitignore ├── stack.yaml ├── LICENSE ├── .travis.yml ├── irbuilder └── Main.hs ├── basic └── Main.hs ├── README.md ├── examples.cabal ├── orc └── Main.hs └── arith └── Arith.hs /.gitignore: -------------------------------------------------------------------------------- 1 | *.sw[pon] 2 | .stack-work 3 | dist-newstyle 4 | stack.yaml.lock 5 | -------------------------------------------------------------------------------- /stack.yaml: -------------------------------------------------------------------------------- 1 | resolver: lts-14.0 2 | packages: 3 | - '.' 4 | 5 | extra-deps: 6 | - llvm-hs-9.0.0 7 | - llvm-hs-pure-9.0.0 8 | - llvm-hs-pretty-0.9.0.0 9 | 10 | flags: 11 | llvm-hs: 12 | shared-llvm: true 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017-2020, Stephen Diehl 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to 5 | deal in the Software without restriction, including without limitation the 6 | rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 7 | sell copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 18 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 19 | IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | dist: xenial 3 | language: generic 4 | 5 | env: 6 | global: 7 | - GCC=gcc-5 8 | - GXX=g++-5 9 | 10 | cache: 11 | directories: 12 | - $HOME/.stack/ 13 | 14 | addons: 15 | apt: 16 | packages: 17 | - gcc-5 18 | - g++-5 19 | - libgmp-dev 20 | sources: 21 | - llvm-toolchain-xenial-9 22 | - ubuntu-toolchain-r-test 23 | 24 | before_install: 25 | - mkdir -p ~/.local/bin 26 | - export PATH=~/.local/bin:$PATH 27 | - travis_retry curl -L https://www.stackage.org/stack/linux-x86_64 | tar xz --wildcards --strip-components=1 -C ~/.local/bin '*/stack' 28 | - export CC=/usr/bin/$GCC 29 | - export CXX=/usr/bin/$GXX 30 | - wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - 31 | - echo "deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-9 main" | sudo tee -a /etc/apt/sources.list 32 | - sudo add-apt-repository --yes ppa:ubuntu-toolchain-r/ppa 33 | - sudo apt-get update 34 | - sudo apt-get --yes install llvm-9 llvm-9-dev 35 | 36 | install: 37 | - stack --no-terminal --install-ghc test --only-dependencies 38 | 39 | script: 40 | - stack --no-terminal test --haddock --no-haddock-deps 41 | -------------------------------------------------------------------------------- /irbuilder/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE RecursiveDo #-} 2 | {-# LANGUAGE OverloadedStrings #-} 3 | 4 | module Main where 5 | 6 | import Data.Text.Lazy.IO as T 7 | 8 | import LLVM.AST hiding (function) 9 | import LLVM.AST.Type as AST 10 | import 
qualified LLVM.AST.Float as F 11 | import qualified LLVM.AST.Constant as C 12 | import qualified LLVM.AST.IntegerPredicate as P 13 | 14 | import LLVM.IRBuilder.Module 15 | import LLVM.IRBuilder.Monad 16 | import LLVM.IRBuilder.Instruction 17 | 18 | simple :: Module -- a module with two functions: f (a branch joined by a phi) and plus 19 | simple = buildModule "exampleModule" $ mdo 20 | function "f" [(AST.i32, "a")] AST.i32 $ \[a] -> mdo -- mdo (RecursiveDo) lets condBr refer to ifThen/ifElse before they are bound 21 | _entry <- block `named` "entry" 22 | cond <- icmp P.EQ a (ConstantOperand (C.Int 32 0)) 23 | condBr cond ifThen ifElse 24 | ifThen <- block 25 | trVal <- add a (ConstantOperand (C.Int 32 0)) 26 | br ifExit 27 | ifElse <- block `named` "if.else" 28 | flVal <- add a (ConstantOperand (C.Int 32 0)) 29 | br ifExit 30 | ifExit <- block `named` "if.exit" 31 | r <- phi [(trVal, ifThen), (flVal, ifElse)] -- trVal when control arrives from ifThen, flVal when it arrives from ifElse 32 | ret r 33 | 34 | function "plus" [(AST.i32, "x"), (AST.i32, "y")] AST.i32 $ \[x, y] -> do 35 | _entry <- block `named` "entry2" 36 | r <- add x y 37 | ret r 38 | 39 | main :: IO () 40 | main = print simple -- prints the module's Haskell AST (its Show instance), not textual IR 41 | -------------------------------------------------------------------------------- /basic/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE OverloadedStrings #-} 2 | 3 | module Main where 4 | 5 | import LLVM.AST 6 | import qualified LLVM.AST as AST 7 | import LLVM.AST.Global 8 | import LLVM.Context 9 | import LLVM.Module 10 | 11 | import Control.Monad.Except 12 | import Data.ByteString.Char8 as BS 13 | 14 | int :: Type 15 | int = IntegerType 32 -- i32 16 | 17 | defAdd :: Definition -- LLVM function add(i32 a, i32 b) returning the i32 a + b 18 | defAdd = GlobalDefinition functionDefaults 19 | { name = Name "add" 20 | , parameters = 21 | ( [ Parameter int (Name "a") [] 22 | , Parameter int (Name "b") [] ] 23 | , False ) 24 | , returnType = int 25 | , basicBlocks = [body] 26 | } 27 | where 28 | body = BasicBlock 29 | (Name "entry") 30 | [ Name "result" := 31 | Add False -- no signed wrap 32 | False -- no unsigned wrap 33 | (LocalReference int (Name "a")) 34 | (LocalReference int (Name "b")) 35 | []] 36 | (Do
$ Ret (Just (LocalReference int (Name "result"))) []) 37 | 38 | 39 | module_ :: AST.Module 40 | module_ = defaultModule 41 | { moduleName = "basic" 42 | , moduleDefinitions = [defAdd] 43 | } 44 | 45 | 46 | toLLVM :: AST.Module -> IO () -- render the module as textual LLVM IR on stdout 47 | toLLVM mod = withContext $ \ctx -> do 48 | llvm <- withModuleFromAST ctx mod moduleLLVMAssembly 49 | BS.putStrLn llvm 50 | 51 | 52 | main :: IO () 53 | main = toLLVM module_ 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | LLVM Haskell Examples 2 | ===================== 3 | 4 | [![Build Status](https://travis-ci.org/llvm-hs/llvm-hs-examples.svg?branch=master)](https://travis-ci.org/llvm-hs/llvm-hs-examples) 5 | 6 | Simple examples demonstrating how to use the 7 | [llvm-hs](https://github.com/llvm-hs/llvm-hs) library to generate and manipulate 8 | LLVM IR from Haskell. 9 | 10 | * [basic](./basic) - Generating LLVM AST and rendering Textual IR 11 | * [orc](./orc) - JIT compiling IR with the eager ORC JIT 12 | * [arith](./arith) - a minimal JIT compiler for functions of one (real) variable using recursion schemes 13 | * [irbuilder](./irbuilder) - Basic usage of the LLVM IRBuilder for constructing modules 14 | 15 | These examples require LLVM 9.0. Check that your installed LLVM version is 16 | precisely 9.0. If it is not, follow the install directions in the 17 | [llvm-hs](https://github.com/llvm-hs/llvm-hs) repository.
18 | 19 | ```bash 20 | $ llvm-config --version 21 | 9.0 22 | ``` 23 | 24 | To run the examples with Stack: 25 | 26 | ```bash 27 | $ stack exec basic 28 | $ stack exec orc 29 | $ stack exec arith 30 | $ stack exec irbuilder 31 | ``` 32 | 33 | To load the examples in GHCi: 34 | 35 | ```bash 36 | $ stack repl examples:basic 37 | $ stack repl examples:orc 38 | $ stack repl examples:arith 39 | $ stack repl examples:irbuilder 40 | ``` 41 | 42 | To run the examples with Cabal: 43 | 44 | ```bash 45 | $ cabal run basic 46 | $ cabal run orc 47 | $ cabal run arith 48 | $ cabal run irbuilder 49 | ``` 50 | 51 | License 52 | ------- 53 | 54 | MIT License 55 | Copyright (c) 2017-2020, Stephen Diehl 56 | -------------------------------------------------------------------------------- /examples.cabal: -------------------------------------------------------------------------------- 1 | name: examples 2 | version: 1.0.0.0 3 | license: MIT 4 | license-file: LICENSE 5 | cabal-version: >=1.8 6 | tested-with: GHC >=7.8 7 | build-type: Simple 8 | maintainer: Stephen Diehl 9 | category: Compilers 10 | synopsis: Examples for the llvm-hs library 11 | description: Examples using the llvm-hs library 12 | extra-source-files: 13 | README.md 14 | 15 | Source-Repository head 16 | type: git 17 | location: git@github.com:llvm-hs/llvm-hs-examples.git 18 | 19 | executable basic 20 | main-is: Main.hs 21 | hs-source-dirs: basic 22 | build-depends: 23 | base >=4.7 && <4.14 24 | , bytestring >=0.10 && <0.11 25 | , llvm-hs >=9.0 && <9.1 26 | , llvm-hs-pure >=9.0 && <9.1 27 | , mtl >=2.2.2 && <2.3 28 | 29 | executable orc 30 | main-is: Main.hs 31 | hs-source-dirs: orc 32 | build-depends: 33 | base >=4.7 && <4.14 34 | , bytestring >=0.10 && <0.11 35 | , containers >=0.6 && <0.7 36 | , llvm-hs >=9.0 && <9.1 37 | , llvm-hs-pure >=9.0 && <9.1 38 | , mtl >=2.2.2 && <2.3 39 | 40 | executable irbuilder 41 | main-is: Main.hs 42 | hs-source-dirs: irbuilder 43 | build-depends: 44 | base >=4.7 && <4.14 45 | ,
bytestring >=0.10 && <0.11 46 | , llvm-hs >=9.0 && <9.1 47 | , llvm-hs-pretty >=0.9 && <0.10 48 | , llvm-hs-pure >=9.0 && <9.1 49 | , mtl >=2.2.2 && <2.3 50 | , text >=1.2 && <1.3 51 | 52 | executable arith 53 | main-is: Arith.hs 54 | hs-source-dirs: arith 55 | build-depends: 56 | base >=4.7 && <4.14 57 | , bytestring >=0.10 && <0.11 58 | , containers >=0.6 && <0.7 59 | , deepseq >=1.4 && <1.5 60 | , llvm-hs >=9.0 && <9.1 61 | , llvm-hs-pretty >=0.9 && <0.10 62 | , llvm-hs-pure >=9.0 && <9.1 63 | , mtl >=2.2.2 && <2.3 64 | , recursion-schemes >=5.1 && <5.2 65 | , text >=1.2 && <1.3 66 | , transformers >=0.5 && <0.6 67 | -------------------------------------------------------------------------------- /orc/Main.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE ForeignFunctionInterface #-} 2 | {-# LANGUAGE OverloadedStrings #-} 3 | {-# OPTIONS_GHC -Wall #-} 4 | 5 | module Main where 6 | 7 | import Control.Monad.Except 8 | import qualified Data.ByteString.Char8 as BS 9 | import Data.IORef 10 | import Data.Int 11 | import qualified Data.Map.Strict as Map 12 | import Foreign.Ptr 13 | import LLVM.AST 14 | import qualified LLVM.AST as AST 15 | import LLVM.AST.Constant 16 | import LLVM.AST.Global 17 | import LLVM.CodeGenOpt 18 | import LLVM.CodeModel 19 | import LLVM.Context 20 | import LLVM.Internal.OrcJIT.CompileLayer 21 | import LLVM.Module 22 | import LLVM.OrcJIT 23 | import LLVM.Relocation 24 | import LLVM.Target 25 | import Prelude hiding (mod) 26 | 27 | foreign import ccall "dynamic" 28 | mkMain :: FunPtr (IO Int32) -> IO Int32 -- turn the address of a JIT-compiled i32() function into a callable IO action 29 | 30 | int :: Type 31 | int = IntegerType 32 32 | 33 | defAdd :: Definition -- despite its name: takes no arguments and returns the constant 42 34 | defAdd = 35 | GlobalDefinition 36 | functionDefaults 37 | { name = Name "add", 38 | parameters = ([], False), 39 | returnType = int, 40 | basicBlocks = [body] 41 | } 42 | where 43 | body = 44 | BasicBlock 45 | (Name "entry") 46 | [] 47 | (Do $ Ret (Just (ConstantOperand (Int 32 42))) []) 48 | 49 | module_ ::
AST.Module 50 | module_ = 51 | defaultModule 52 | { moduleName = "basic", 53 | moduleDefinitions = [defAdd] 54 | } 55 | 56 | withTestModule :: AST.Module -> (LLVM.Module.Module -> IO a) -> IO a 57 | withTestModule mod f = withContext $ \context -> withModuleFromAST context mod f 58 | 59 | resolver :: CompileLayer l => l -> MangledSymbol -> IO (Either JITSymbolError JITSymbol) 60 | resolver compileLayer symbol = findSymbol compileLayer symbol True 61 | 62 | failInIO :: ExceptT String IO a -> IO a 63 | failInIO = either fail return <=< runExceptT 64 | 65 | eagerJit :: AST.Module -> IO () -- print the module's IR, JIT-compile it, call "add" and print its result 66 | eagerJit amod = do 67 | resolvers <- newIORef Map.empty 68 | withTestModule amod $ \mod -> 69 | withHostTargetMachine PIC LLVM.CodeModel.Default LLVM.CodeGenOpt.Default $ \tm -> 70 | withExecutionSession $ \es -> 71 | withObjectLinkingLayer es (\k -> fmap (\rs -> rs Map.! k) (readIORef resolvers)) $ \linkingLayer -> 72 | withIRCompileLayer linkingLayer tm $ \compileLayer -> do 73 | mainSymbol <- mangleSymbol compileLayer "add" 74 | asm <- moduleLLVMAssembly mod 75 | BS.putStrLn asm 76 | withModuleKey es $ \k -> 77 | withSymbolResolver es (SymbolResolver (resolver compileLayer)) $ \sresolver -> do 78 | modifyIORef' resolvers (Map.insert k sresolver) -- register before withModule: the linking layer looks the resolver up by k 79 | withModule compileLayer k mod $ do -- add the module under key k so its symbols become visible to findSymbol 80 | rsym <- findSymbol compileLayer mainSymbol True 81 | case rsym of 82 | Left err -> do 83 | print err 84 | Right (JITSymbol mainFn _) -> do 85 | result <- mkMain (castPtrToFunPtr (wordPtrToPtr mainFn)) 86 | putStrLn "Eager JIT Result:" 87 | print result 88 | 89 | main :: IO () 90 | main = eagerJit module_ 91 | -------------------------------------------------------------------------------- /arith/Arith.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DeriveFoldable #-} 2 | {-# LANGUAGE DeriveFunctor #-} 3 | {-# LANGUAGE DeriveTraversable #-} 4 | {-# LANGUAGE FlexibleContexts #-} 5 | {-# LANGUAGE FlexibleInstances #-} 6 | {-# LANGUAGE
ForeignFunctionInterface #-} 7 | {-# LANGUAGE OverloadedStrings #-} 8 | {-# LANGUAGE TypeSynonymInstances #-} 9 | 10 | module Main where 11 | 12 | import Control.DeepSeq (NFData, force) 13 | import Control.Exception 14 | import Control.Monad 15 | import Control.Monad.IO.Class 16 | import Data.ByteString (ByteString) 17 | import qualified Data.ByteString.Char8 as BS 18 | import Data.Foldable 19 | import Data.Functor.Foldable hiding (fold) 20 | import Data.IORef 21 | import Data.Map.Strict (Map) 22 | import qualified Data.Map.Strict as Map 23 | import Data.Monoid 24 | import qualified Data.Set as Set 25 | import Data.Text.Lazy (Text) 26 | import qualified Data.Text.Lazy.IO as Text 27 | import Foreign.Ptr 28 | import qualified LLVM.AST as LLVM 29 | import qualified LLVM.AST.Constant as LLVM 30 | import qualified LLVM.AST.Float as LLVM 31 | import qualified LLVM.AST.Type as LLVM 32 | import qualified LLVM.CodeGenOpt as CodeGenOpt 33 | import qualified LLVM.CodeModel as CodeModel 34 | import qualified LLVM.Context as JIT 35 | import qualified LLVM.IRBuilder.Instruction as LLVMIR 36 | import qualified LLVM.IRBuilder.Module as LLVMIR 37 | import qualified LLVM.IRBuilder.Monad as LLVMIR 38 | import qualified LLVM.Internal.OrcJIT.CompileLayer as JIT 39 | import qualified LLVM.Linking as JIT 40 | import qualified LLVM.Module as JIT 41 | import qualified LLVM.OrcJIT as JIT 42 | import qualified LLVM.Pretty as LLVMPretty 43 | import qualified LLVM.Relocation as Reloc 44 | import qualified LLVM.Target as JIT 45 | 46 | -- * Core expression type 47 | 48 | -- | An expression will be any value of type @'Fix' 'ExprF'@, which 49 | -- has as values arbitrarily nested applications of constructors from 50 | -- 'ExprF'. This is equivalent to just having an 'Expr' type with no type 51 | -- parameter and all @a@s replaced by 'Expr', but the 'Functor' and 'Foldable' 52 | -- instances are quite handy, especially combined with the /recursion-schemes/ 53 | -- library.
54 | -- 55 | -- This type allows us to express the body of a @Double -> Double@ function, 56 | -- where 'Var' allows us to refer to the (only) argument of the function. 57 | data ExprF a 58 | = -- | a 'Double' literal 59 | Lit Double 60 | | -- | @a+b@ 61 | Add a a 62 | | -- | @a-b@ 63 | Sub a a 64 | | -- | @a*b@ 65 | Mul a a 66 | | -- | @a/b@ 67 | Div a a 68 | | -- | @-a@ 69 | Neg a 70 | | -- | @'exp' a@ 71 | Exp a 72 | | -- | @'log' a@ 73 | Log a 74 | | -- | @'sqrt' a@ 75 | Sqrt a 76 | | -- | @'sin' a@ 77 | Sin a 78 | | -- | @'cos' a@ 79 | Cos a 80 | | -- | @x@, the function's argument 81 | Var 82 | deriving (Functor, Foldable, Traversable) -- Functor for cata/para, Foldable for fold, Traversable for cataM 83 | 84 | type Expr = Fix ExprF 85 | 86 | -- * Helpers for building expressions 87 | 88 | x :: Expr -- the argument variable, so expressions can be written like @sin x@ 89 | x = Fix Var 90 | 91 | lit :: Double -> Expr 92 | lit d = Fix (Lit d) 93 | 94 | add :: Expr -> Expr -> Expr 95 | add a b = Fix (Add a b) 96 | 97 | sub :: Expr -> Expr -> Expr 98 | sub a b = Fix (Sub a b) 99 | 100 | mul :: Expr -> Expr -> Expr 101 | mul a b = Fix (Mul a b) 102 | 103 | neg :: Expr -> Expr 104 | neg a = Fix (Neg a) 105 | 106 | instance Num Expr where -- numeric literals and operators build Expr trees; abs/signum are unsupported 107 | fromInteger = lit . fromInteger 108 | (+) = add 109 | (-) = sub 110 | (*) = mul 111 | negate = neg 112 | abs = notImplemented "Expr.abs" 113 | signum = notImplemented "Expr.signum" 114 | 115 | divide :: Expr -> Expr -> Expr 116 | divide a b = Fix (Div a b) 117 | 118 | instance Fractional Expr where 119 | (/) = divide 120 | recip = divide 1 121 | fromRational = lit . fromRational 122 | 123 | instance Floating Expr where 124 | pi = lit pi 125 | exp = Fix . Exp 126 | log = Fix . Log 127 | sqrt = Fix . Sqrt 128 | sin = Fix . Sin 129 | cos = Fix .
Cos 130 | asin = notImplemented "Expr.asin" 131 | acos = notImplemented "Expr.acos" 132 | atan = notImplemented "Expr.atan" 133 | sinh = notImplemented "Expr.sinh" 134 | cosh = notImplemented "Expr.cosh" 135 | asinh = notImplemented "Expr.asinh" 136 | acosh = notImplemented "Expr.acosh" 137 | atanh = notImplemented "Expr.atanh" 138 | 139 | notImplemented :: String -> a -- calls 'error' with the given name followed by " is not implemented" 140 | notImplemented = error . (++ " is not implemented") 141 | 142 | -- * Pretty printing 143 | 144 | -- | Pretty print an 'Expr' as a lambda over @x@ 145 | pp :: Expr -> String 146 | pp e = funprefix ++ para ppExpAlg e -- para hands each node its original sub-expressions alongside their renderings 147 | where 148 | funprefix = "\\x -> " 149 | 150 | printExpr :: MonadIO m => Expr -> m () -- print a header, then the pretty-printed expression 151 | printExpr expr = liftIO $ do 152 | putStrLn "*** Expression ***\n" 153 | putStrLn (pp expr) 154 | 155 | -- | Core pretty printing function. For each 156 | -- constructor that contains sub expressions, 157 | -- we get the string for the sub expression as 158 | -- well as the original 'Expr' value, to help us 159 | -- decide when to use parens.
160 | ppExpAlg :: ExprF (Expr, String) -> String 161 | ppExpAlg (Lit d) = paren (d < 0 || isNegativeZero d) (show d) -- parenthesize negative literals so x - (-2.0) is not printed as x - -2.0 162 | ppExpAlg (Add (_, a) (_, b)) = a ++ " + " ++ b 163 | ppExpAlg (Sub (_, a) (e2, b)) = 164 | a ++ " - " ++ paren (isAdd e2 || isSub e2) b 165 | ppExpAlg (Mul (e1, a) (e2, b)) = 166 | paren (isAdd e1 || isSub e1) a ++ " * " ++ paren (isAdd e2 || isSub e2) b 167 | ppExpAlg (Div (e1, a) (e2, b)) = 168 | paren (isAdd e1 || isSub e1) a ++ " / " ++ paren (isComplex e2) b 169 | where 170 | isComplex (Fix (Add _ _)) = True 171 | isComplex (Fix (Sub _ _)) = True 172 | isComplex (Fix (Mul _ _)) = True 173 | isComplex (Fix (Div _ _)) = True 174 | isComplex _ = False 175 | ppExpAlg (Neg (_, a)) = function "negate" a 176 | ppExpAlg (Exp (_, a)) = function "exp" a 177 | ppExpAlg (Log (_, a)) = function "log" a 178 | ppExpAlg (Sqrt (_, a)) = function "sqrt" a 179 | ppExpAlg (Sin (_, a)) = function "sin" a 180 | ppExpAlg (Cos (_, a)) = function "cos" a 181 | ppExpAlg Var = "x" 182 | 183 | paren :: Bool -> String -> String -- wrap in parentheses when the flag is True 184 | paren b x 185 | | b = "(" ++ x ++ ")" 186 | | otherwise = x 187 | function :: String -> String -> String 188 | function name arg = -- render a one-argument call, e.g. sin(x) 189 | name ++ paren True arg 190 | 191 | isAdd :: Expr -> Bool 192 | isAdd (Fix (Add _ _)) = True 193 | isAdd _ = False 194 | 195 | isSub :: Expr -> Bool 196 | isSub (Fix (Sub _ _)) = True 197 | isSub _ = False 198 | 199 | isLit :: Expr -> Bool 200 | isLit (Fix (Lit _)) = True 201 | isLit _ = False 202 | 203 | isVar :: Expr -> Bool 204 | isVar (Fix Var) = True 205 | isVar _ = False 206 | 207 | -- * Simple evaluator 208 | 209 | -- | Evaluate an 'Expr'ession using standard 210 | -- 'Num', 'Fractional' and 'Floating' operations.
211 | eval :: Expr -> (Double -> Double) 212 | eval fexpr x = cata alg fexpr -- here x is the argument value; it shadows the top-level x expression 213 | where 214 | alg e = case e of 215 | Var -> x 216 | Lit d -> d 217 | Add a b -> a + b 218 | Sub a b -> a - b 219 | Mul a b -> a * b 220 | Div a b -> a / b 221 | Neg a -> negate a 222 | Exp a -> exp a 223 | Log a -> log a 224 | Sqrt a -> sqrt a 225 | Sin a -> sin a 226 | Cos a -> cos a 227 | 228 | -- * Code generation 229 | 230 | -- | Helper for calling intrinsics for 'exp', 'log' and friends. 231 | callDblfun :: 232 | LLVMIR.MonadIRBuilder m => LLVM.Operand -> LLVM.Operand -> m LLVM.Operand 233 | callDblfun fun arg = LLVMIR.call fun [(arg, [])] -- one argument, no parameter attributes 234 | 235 | xparam :: LLVMIR.ParameterName 236 | xparam = LLVMIR.ParameterName "x" -- name of f's only parameter in the generated IR 237 | 238 | -- | Generate @declare@ statements for all the intrinsics required for 239 | -- executing the given expression and return a mapping from function 240 | -- name to 'Operand' so that we can very easily refer to those functions 241 | -- for calling them, when generating the code for the expression itself. 242 | declarePrimitives :: 243 | LLVMIR.MonadModuleBuilder m => Expr -> m (Map String LLVM.Operand) 244 | declarePrimitives expr = fmap Map.fromList 245 | $ forM primitives 246 | $ \primName -> do 247 | f <- 248 | LLVMIR.extern 249 | (LLVM.mkName ("llvm." <> primName <> ".f64")) -- e.g. llvm.sin.f64 250 | [LLVM.double] 251 | LLVM.double 252 | return (primName, f) 253 | where 254 | primitives = Set.toList (cata alg expr) 255 | alg (Exp ps) = Set.insert "exp" ps 256 | alg (Log ps) = Set.insert "log" ps 257 | alg (Sqrt ps) = Set.insert "sqrt" ps 258 | alg (Sin ps) = Set.insert "sin" ps 259 | alg (Cos ps) = Set.insert "cos" ps 260 | alg e = fold e -- any other node: union of its children's sets (empty for Lit and Var) 261 | 262 | -- | Generate an LLVM IR module for the given expression, 263 | -- including @declare@ statements for the intrinsics and 264 | -- a function, always called @f@, that will perform the computations 265 | -- described by the 'Expr'ession.
266 | codegen :: Expr -> LLVM.Module 267 | codegen fexpr = LLVMIR.buildModule "arith.ll" $ do 268 | prims <- declarePrimitives fexpr 269 | _ <- LLVMIR.function "f" [(LLVM.double, xparam)] LLVM.double $ \[arg] -> do 270 | res <- cataM (alg arg prims) fexpr -- emit instructions bottom-up over the expression tree 271 | LLVMIR.ret res 272 | return () 273 | where 274 | alg arg _ (Lit d) = -- arg: f's parameter; ps: primitive name -> declared intrinsic 275 | return (LLVM.ConstantOperand $ LLVM.Float $ LLVM.Double d) 276 | alg arg _ Var = return arg 277 | alg arg _ (Add a b) = LLVMIR.fadd a b `LLVMIR.named` "x" 278 | alg arg _ (Sub a b) = LLVMIR.fsub a b `LLVMIR.named` "x" 279 | alg arg _ (Mul a b) = LLVMIR.fmul a b `LLVMIR.named` "x" 280 | alg arg _ (Div a b) = LLVMIR.fdiv a b `LLVMIR.named` "x" 281 | alg arg ps (Neg a) = do 282 | z <- alg arg ps (Lit (-0.0)) -- (-0.0) - a is exact IEEE negation; 0.0 - 0.0 would give +0.0 where negate gives -0.0 283 | LLVMIR.fsub z a `LLVMIR.named` "x" 284 | alg arg ps (Exp a) = callDblfun (ps Map.! "exp") a `LLVMIR.named` "x" 285 | alg arg ps (Log a) = callDblfun (ps Map.! "log") a `LLVMIR.named` "x" 286 | alg arg ps (Sqrt a) = callDblfun (ps Map.! "sqrt") a `LLVMIR.named` "x" 287 | alg arg ps (Sin a) = callDblfun (ps Map.! "sin") a `LLVMIR.named` "x" 288 | alg arg ps (Cos a) = callDblfun (ps Map.! "cos") a `LLVMIR.named` "x" 289 | 290 | codegenText :: Expr -> Text 291 | codegenText = LLVMPretty.ppllvm . codegen 292 | 293 | printCodegen :: Expr -> IO () 294 | printCodegen = Text.putStrLn .
codegenText 295 | 296 | -- * JIT compilation & loading 297 | 298 | -- | This allows us to call dynamically loaded functions 299 | foreign import ccall "dynamic" 300 | mkDoubleFun :: FunPtr (Double -> Double) -> (Double -> Double) 301 | 302 | resolver :: 303 | JIT.IRCompileLayer l -> 304 | JIT.MangledSymbol -> 305 | IO (Either JIT.JITSymbolError JIT.JITSymbol) 306 | resolver compileLayer symbol = 307 | JIT.findSymbol compileLayer symbol True 308 | 309 | symbolFromProcess :: JIT.MangledSymbol -> IO JIT.JITSymbol 310 | symbolFromProcess sym = 311 | (\addr -> JIT.JITSymbol addr JIT.defaultJITSymbolFlags) 312 | <$> JIT.getSymbolAddressInProcess sym 313 | 314 | resolv :: JIT.IRCompileLayer l -> JIT.SymbolResolver -- resolver that looks symbols up in the given compile layer 315 | resolv cl = JIT.SymbolResolver (\sym -> JIT.findSymbol cl sym True) 316 | 317 | printIR :: MonadIO m => ByteString -> m () 318 | printIR = liftIO . BS.putStrLn . ("\n*** LLVM IR ***\n\n" <>) 319 | 320 | -- | JIT-compile the given 'Expr'ession and use the resulting function. 321 | withSimpleJIT :: 322 | NFData a => 323 | Expr -> 324 | -- | what to do with the generated function 325 | ((Double -> Double) -> a) -> 326 | IO a 327 | withSimpleJIT expr doFun = do 328 | resolvers <- newIORef Map.empty 329 | JIT.withContext $ \context -> (>>) (JIT.loadLibraryPermanently Nothing) 330 | $ JIT.withModuleFromAST context (codegen expr) 331 | $ \mod' -> 332 | JIT.withHostTargetMachine Reloc.PIC CodeModel.Default CodeGenOpt.Default $ \tm -> 333 | JIT.withExecutionSession $ \es -> 334 | JIT.withObjectLinkingLayer es (\k -> fmap (\rs -> rs Map.! k) (readIORef resolvers)) $ \objectLayer -> 335 | JIT.withIRCompileLayer objectLayer tm $ \compileLayer -> do 336 | asm <- JIT.moduleLLVMAssembly mod' 337 | printExpr expr 338 | printIR asm 339 | JIT.withModuleKey es $ \k -> 340 | JIT.withSymbolResolver es (resolv compileLayer) $ \sresolver -> do 341 | -- Register the resolver under k before adding the module: the object 342 | -- linking layer fetches it with (rs Map.! k), which fails for a missing key. 343 | modifyIORef' resolvers (Map.insert k sresolver) 344 | JIT.withModule compileLayer k mod' $ do 345 | fSymbol <- JIT.mangleSymbol compileLayer "f" 346 | rsym <- JIT.findSymbol compileLayer fSymbol True 347 | case rsym of 348 | Left err -> fail ("withSimpleJIT: cannot find symbol f: " ++ show err) 349 | Right (JIT.JITSymbol fnAddr _) -> do 350 | let f = mkDoubleFun . castPtrToFunPtr $ wordPtrToPtr fnAddr 351 | liftIO (putStrLn "*** Result ***\n") 352 | evaluate $ force (doFun f) 353 | 354 | -- * Utilities 355 | 356 | cataM :: 357 | (Monad m, Traversable (Base t), Recursive t) => 358 | (Base t a -> m a) -> 359 | t -> 360 | m a 361 | cataM alg = c 362 | where 363 | c = alg <=< traverse c . project 364 | 365 | -- * Main 366 | 367 | f :: Floating a => a -> a 368 | f t = sin (pi * t / 2) * (1 + sqrt t) ^ 2 369 | 370 | main :: IO () 371 | main = do 372 | let res1 = map f [0 .. 10] :: [Double] 373 | res2 <- withSimpleJIT (f x) (\fopt -> map fopt [0 .. 10]) 374 | if res1 == res2 375 | then putStrLn "results match" >> print res1 376 | else print res1 >> print res2 >> putStrLn "results don't match" 377 | --------------------------------------------------------------------------------