├── Setup.hs ├── app └── Main.hs ├── test └── Spec.hs ├── CHANGELOG.md ├── cabal.project ├── bench-results ├── univariate │ ├── 00.csv │ ├── 01.csv │ ├── 02.csv │ ├── 00.svg │ ├── 01.svg │ └── 02.svg ├── bivariate │ ├── 00.csv │ ├── 02.csv │ ├── 01.csv │ ├── 03.csv │ ├── 04.csv │ ├── 00.svg │ ├── 02.svg │ ├── 03.svg │ ├── 01.svg │ └── 04.svg ├── 4-ary │ ├── 00.csv │ ├── 01.csv │ ├── 02.csv │ ├── 03.csv │ ├── 04.csv │ ├── 00.svg │ ├── 01.svg │ ├── 02.svg │ ├── 03.svg │ └── 04.svg └── trivariate │ ├── 00.csv │ ├── 01.csv │ ├── 00.svg │ └── 01.svg ├── fourmolu.yaml ├── .vscode ├── settings.json └── extensions.json ├── scripts ├── collect-bins.sh └── run-bench.sh ├── src ├── Data │ ├── PRef.hs │ └── URef.hs └── Numeric │ └── AD │ └── DelCont │ ├── Native │ ├── Internal.hs │ ├── Linear.hs │ ├── Double.hs │ └── MultiPrompt.hs │ └── Native.hs ├── bench ├── Helpers.hs ├── bench.hs └── Macros.hs ├── .gitignore ├── ad-delcont-primop.cabal ├── .github └── workflows │ └── haskell.yml ├── cabal.project.freeze └── README.md /Setup.hs: -------------------------------------------------------------------------------- 1 | import Distribution.Simple 2 | main = defaultMain 3 | -------------------------------------------------------------------------------- /app/Main.hs: -------------------------------------------------------------------------------- 1 | module Main (main) where 2 | 3 | main :: IO () 4 | main = pure () 5 | -------------------------------------------------------------------------------- /test/Spec.hs: -------------------------------------------------------------------------------- 1 | main :: IO () 2 | main = putStrLn "Test suite not yet implemented" 3 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog for `ad-delcont-primop` 2 | 3 | All notable changes to this project will be documented in this file. 
4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), 6 | and this project adheres to the 7 | [Haskell Package Versioning Policy](https://pvp.haskell.org/). 8 | 9 | ## Unreleased 10 | 11 | ## 0.1.0.0 - YYYY-MM-DD 12 | -------------------------------------------------------------------------------- /cabal.project: -------------------------------------------------------------------------------- 1 | packages: . 2 | allow-newer: all 3 | optimization: 2 4 | benchmarks: True 5 | 6 | source-repository-package 7 | type: git 8 | location: https://github.com/konn/ad.git 9 | tag: 579579564617861a5f243f7516f7aeb354970173 10 | 11 | source-repository-package 12 | type: git 13 | location: https://github.com/konn/backprop.git 14 | tag: 166de5ad553337e6a1b68cda13ba1593cf5f4e36 15 | -------------------------------------------------------------------------------- /bench-results/univariate/00.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.univariate.id.transformers,76385,2662,159,0,35651584 3 | All.univariate.id.ad,22966,1392,22,0,35651584 4 | All.univariate.id.ad/double,23046,1310,22,0,35651584 5 | All.univariate.id.backprop,255086,20570,510,0,35651584 6 | All.univariate.id.primops,53342,5232,63,0,35651584 7 | All.univariate.id.primops/Double,72853,5280,216,0,35651584 8 | -------------------------------------------------------------------------------- /bench-results/bivariate/00.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.bivariate.x * y.transformers,891168,41422,1221,0,35651584 3 | All.bivariate.x * y.ad,1396682,99696,3449,0,35651584 4 | All.bivariate.x * y.ad/double,1477562,109726,3446,0,35651584 5 | All.bivariate.x * y.backprop,1385435,107152,4054,0,35651584 6 | All.bivariate.x * y.primops,162658,15290,382,0,35651584 7 | 
All.bivariate.x * y.primops/double,102783,6482,406,0,35651584 8 | -------------------------------------------------------------------------------- /bench-results/4-ary/00.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.4-ary.((x * y) * z) * w.transformers,1591714,87196,3064,0,35651584 3 | All.4-ary.((x * y) * z) * w.ad,2617230,237500,4857,0,35651584 4 | All.4-ary.((x * y) * z) * w.ad/double,2614281,211612,4881,0,35651584 5 | All.4-ary.((x * y) * z) * w.backprop,3051500,213004,8890,0,35651584 6 | All.4-ary.((x * y) * z) * w.primops,392547,27284,1016,0,35651584 7 | All.4-ary.((x * y) * z) * w.primops/double,259824,12322,1067,0,35651584 8 | -------------------------------------------------------------------------------- /bench-results/trivariate/00.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.trivariate.(x * y) * z.transformers,1167140,95108,2038,0,35651584 3 | All.trivariate.(x * y) * z.ad,2017114,164156,4066,0,35651584 4 | All.trivariate.(x * y) * z.ad/double,2050284,188754,4061,0,35651584 5 | All.trivariate.(x * y) * z.backprop,2201386,174320,6031,0,35651584 6 | All.trivariate.(x * y) * z.primops,280522,27434,611,0,35651584 7 | All.trivariate.(x * y) * z.primops/double,169797,9320,725,0,35651584 8 | -------------------------------------------------------------------------------- /fourmolu.yaml: -------------------------------------------------------------------------------- 1 | indentation: 2 2 | comma-style: leading # for lists, tuples etc. - can also be 'trailing' 3 | record-brace-space: true # rec {x = 1} vs. 
rec{x = 1} 4 | indent-wheres: true # 'false' means save space by only half-indenting the 'where' keyword 5 | diff-friendly-import-export: false # 'false' uses Ormolu-style lists 6 | respectful: false # don't be too opinionated about newlines etc. 7 | haddock-style: multi-line # '--' vs. '{-' 8 | newlines-between-decls: 1 # number of newlines between top-level declarations 9 | -------------------------------------------------------------------------------- /bench-results/univariate/01.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.univariate.(x + 1) * (x + 1).transformers,480778,21450,863,0,35651584 3 | All.univariate.(x + 1) * (x + 1).ad,209055,11242,431,0,35651584 4 | All.univariate.(x + 1) * (x + 1).ad/double,27217,2702,44,0,35651584 5 | All.univariate.(x + 1) * (x + 1).backprop,1277574,100268,3935,0,35651584 6 | All.univariate.(x + 1) * (x + 1).primops,352071,23038,764,0,35651584 7 | All.univariate.(x + 1) * (x + 1).primops/Double,276479,16494,942,0,35651584 8 | -------------------------------------------------------------------------------- /bench-results/univariate/02.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.univariate.x * exp ((x * x) + 1).transformers,1174945,89952,1426,0,35651584 3 | All.univariate.x * exp ((x * x) + 1).ad,444335,24208,762,0,35651584 4 | All.univariate.x * exp ((x * x) + 1).ad/double,35840,2624,44,0,35651584 5 | All.univariate.x * exp ((x * x) + 1).backprop,1635899,143640,5056,0,35651584 6 | All.univariate.x * exp ((x * x) + 1).primops,457316,34072,1019,0,35651584 7 | All.univariate.x * exp ((x * x) + 1).primops/Double,363957,31506,1118,0,35651584 8 | -------------------------------------------------------------------------------- /bench-results/bivariate/02.csv: 
-------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.bivariate.y * exp ((x * x) + y).transformers,1406389,116420,2438,0,35651584 3 | All.bivariate.y * exp ((x * x) + y).ad,1951049,115036,4474,0,35651584 4 | All.bivariate.y * exp ((x * x) + y).ad/double,1890458,106476,4471,0,35651584 5 | All.bivariate.y * exp ((x * x) + y).backprop,2350500,101676,7106,0,35651584 6 | All.bivariate.y * exp ((x * x) + y).primops,450564,41006,1019,0,35651584 7 | All.bivariate.y * exp ((x * x) + y).primops/double,295242,27616,1019,0,35651584 8 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "haskell.toolchain": { 3 | "hls": null, 4 | "stack": null, 5 | "cabal": null 6 | }, 7 | "[cabal]": { 8 | "editor.defaultFormatter": "berberman.vscode-cabal-fmt" 9 | }, 10 | "[haskell]": { 11 | "editor.defaultFormatter": "sjurmillidahl.ormolu-vscode" 12 | }, 13 | "ormolu.path": "fourmolu", 14 | "ormolu.args": ["--stdin-input-file", "."], 15 | "ghcSimple.flag.noNotifySlowRangeType": true, 16 | "scm.defaultViewMode": "tree", 17 | "scm.diffDecorationsGutterPattern": { 18 | "added": true 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=827846 to learn about workspace recommendations. 3 | // Extension identifier format: ${publisher}.${name}. Example: vscode.csharp 4 | 5 | // List of extensions which should be recommended for users of this workspace. 
6 | "recommendations": [ 7 | "emeraldwalk.RunOnSave", 8 | "sjurmillidahl.ormolu-vscode", 9 | "rcook.ghci-helper" 10 | ], 11 | // List of extensions recommended by VS Code that should not be recommended for users of this workspace. 12 | "unwantedRecommendations": [] 13 | } 14 | -------------------------------------------------------------------------------- /bench-results/bivariate/01.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.bivariate.(sin x * cos y) * ((x ^ 2) + y).transformers,1876583,131468,3056,0,35651584 3 | All.bivariate.(sin x * cos y) * ((x ^ 2) + y).ad,2374602,213954,4855,0,35651584 4 | All.bivariate.(sin x * cos y) * ((x ^ 2) + y).ad/double,2185096,195230,4868,0,35651584 5 | All.bivariate.(sin x * cos y) * ((x ^ 2) + y).backprop,3080050,214794,8945,0,35651584 6 | All.bivariate.(sin x * cos y) * ((x ^ 2) + y).primops,762311,50176,1528,0,35651584 7 | All.bivariate.(sin x * cos y) * ((x ^ 2) + y).primops/double,469310,17170,1654,0,35651584 8 | -------------------------------------------------------------------------------- /bench-results/4-ary/01.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).transformers,4076062,208838,6930,0,35651584 3 | All.4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).ad,4310307,317350,8944,0,35651584 4 | All.4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).ad/double,3887061,259270,8116,0,35651584 5 | All.4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).backprop,5901325,401422,16210,0,35651584 6 | All.4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).primops,1729532,163778,4059,0,35651584 7 | All.4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).primops/double,1116577,88670,4458,0,35651584 8 | 
-------------------------------------------------------------------------------- /scripts/collect-bins.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -euxo pipefail 3 | 4 | DEST=${1:-artifacts.tar.zst} 5 | 6 | wget -O- https://github.com/haskell-hvr/cabal-plan/releases/download/v0.6.2.0/cabal-plan-0.6.2.0-x86_64-linux.xz \ 7 | | xz -d - > cabal-plan 8 | chmod +x cabal-plan 9 | BIN_DIR=bins 10 | mkdir -p "${BIN_DIR}/bin" 11 | ./cabal-plan list-bins | \ 12 | grep -e :test: -e :bench: | awk '{ print $2 }' | \ 13 | while read -r FILE; do basename "${FILE}"; done > "${BIN_DIR}/bench.txt" 14 | 15 | ./cabal-plan list-bins | grep ad-delcont-primop | awk '{ print $2 }' | while read -r BIN; do 16 | strip "${BIN}" 17 | cp "${BIN}" "${BIN_DIR}/bin/" 18 | done 19 | 20 | tar -cf "${DEST}" --use-compress-program="zstdmt -9" "${BIN_DIR}" 21 | -------------------------------------------------------------------------------- /bench-results/bivariate/03.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).transformers,3413433,188422,6112,0,35651584 3 | All.bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).ad,3435928,220052,6903,0,35651584 4 | All.bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).ad/double,2872679,230664,6907,0,35651584 5 | All.bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).backprop,4818040,182760,14126,0,35651584 6 | All.bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).primops,1613117,91738,3462,0,35651584 7 | All.bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).primops/double,1043811,86854,3455,0,35651584 8 | -------------------------------------------------------------------------------- /scripts/run-bench.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -euo pipefail 4 | rm -rf bench-results 5 | mkdir -p bench-results 6 | BENCH="${1:-$(cabal list-bin ad-delcont-primop-bench)}" 7 | 8 | "${BENCH}" -l | cut -d. -f2 | uniq | while read -r GROUP; do 9 | GROUP_DIR="bench-results/${GROUP}" 10 | mkdir -p "${GROUP_DIR}" 11 | I=0 12 | "${BENCH}" -l -p "\$2 == \"${GROUP}\"" | cut -d. -f3 | uniq | while read -r CASE; do 13 | CASE_LABEL="All.${GROUP}.${CASE}" 14 | CASE_NUM="$(printf "%02d" "${I}")" 15 | CASE_BASE="${GROUP_DIR}/${CASE_NUM}" 16 | echo "Saving ${CASE_LABEL} to ${CASE_BASE}" 17 | set -x 18 | "${BENCH}" -j1 -p "\$2 == \"${GROUP}\" && \$3 == \"${CASE}\"" --csv "${CASE_BASE}.csv" --svg "${CASE_BASE}.svg" 19 | I=$((I + 1)) 20 | set +x 21 | done 22 | done 23 | -------------------------------------------------------------------------------- /bench-results/4-ary/02.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).transformers,6337219,370894,9764,0,35651584 3 | All.4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).ad,5714521,361656,9757,0,35651584 4 | All.4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).ad/double,4815505,342856,9748,0,35651584 5 | All.4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).backprop,7877216,366460,21900,2,35651584 6 | All.4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).primops,2638999,197734,8161,0,35651584 7 | All.4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).primops/double,1808872,179354,6930,0,35651584 8 | -------------------------------------------------------------------------------- /src/Data/PRef.hs: 
-------------------------------------------------------------------------------- 1 | module Data.PRef (PRef, newPRef, readPRef, writePRef, modifyPRef') where 2 | 3 | import Control.Monad.Primitive 4 | import Data.Primitive 5 | 6 | newtype PRef s a = PRef {unPRef :: MutablePrimArray s a} 7 | 8 | newPRef :: (PrimMonad m, Prim a) => a -> m (PRef (PrimState m) a) 9 | {-# INLINE newPRef #-} 10 | newPRef = \x -> do 11 | arr <- newPrimArray 1 12 | writePrimArray arr 0 x 13 | pure $ PRef arr 14 | 15 | readPRef :: (PrimMonad m, Prim a) => PRef (PrimState m) a -> m a 16 | {-# INLINE readPRef #-} 17 | readPRef = flip readPrimArray 0 . unPRef 18 | 19 | writePRef :: (PrimMonad m, Prim a) => PRef (PrimState m) a -> a -> m () 20 | {-# INLINE writePRef #-} 21 | writePRef = flip writePrimArray 0 . unPRef 22 | 23 | modifyPRef' :: (PrimMonad m, Prim a) => PRef (PrimState m) a -> (a -> a) -> m () 24 | {-# INLINE modifyPRef' #-} 25 | modifyPRef' pr f = writePRef pr . f =<< readPRef pr 26 | -------------------------------------------------------------------------------- /src/Data/URef.hs: -------------------------------------------------------------------------------- 1 | module Data.URef (URef, newURef, readURef, writeURef, modifyURef') where 2 | 3 | import Control.Monad.Primitive 4 | import qualified Data.Vector.Unboxed as U 5 | import qualified Data.Vector.Unboxed.Mutable as MU 6 | 7 | newtype URef s a = URef {unURef :: MU.MVector s a} 8 | 9 | newURef :: (MU.Unbox a, PrimMonad m) => a -> m (URef (PrimState m) a) 10 | {-# INLINE newURef #-} 11 | newURef = fmap URef . U.unsafeThaw . U.singleton 12 | 13 | -- | Store a value in the reference ('MU.unsafeWrite' at index 0). 14 | writeURef :: 15 | (MU.Unbox a, PrimMonad m) => 16 | URef (PrimState m) a -> 17 | a -> 18 | m () 19 | {-# INLINE writeURef #-} 20 | writeURef = flip MU.unsafeWrite 0 . unURef 21 | 22 | modifyURef' :: 23 | (MU.Unbox a, PrimMonad m) => 24 | URef (PrimState m) a -> 25 | (a -> a) -> 26 | m () 27 | {-# INLINE modifyURef' #-} 28 | modifyURef' = fmap ($ 0) . MU.unsafeModify . 
unURef 29 | 30 | -- | Fetch the value held by the reference ('MU.unsafeRead' at index 0). 31 | readURef :: 32 | (MU.Unbox a, PrimMonad m) => 33 | URef (PrimState m) a -> 34 | m a 35 | {-# INLINE readURef #-} 36 | readURef = flip MU.unsafeRead 0 . unURef 37 | -------------------------------------------------------------------------------- /bench-results/bivariate/04.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).transformers,8612114,353600,13852,1,35651584 3 | All.bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).ad,5327973,405602,9719,0,35651584 4 | All.bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).ad/double,4248756,201950,10133,0,35651584 5 | All.bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).backprop,8192836,759564,19430,0,35651584 6 | All.bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).primops,3975498,333794,9757,0,35651584 7 | All.bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).primops/double,2646582,194154,8925,0,35651584 8 | -------------------------------------------------------------------------------- /bench-results/4-ary/03.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).transformers,13048955,703204,19552,1,36700160 3 | All.4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * 
w)).ad,7465664,740348,11370,0,36700160 4 | All.4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).ad/double,6132353,336380,13786,0,36700160 5 | All.4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).backprop,11013901,755348,27447,0,36700160 6 | All.4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).primops,6734287,425572,17842,1,36700160 7 | All.4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).primops/double,4455029,391474,17767,0,36700160 8 | -------------------------------------------------------------------------------- /bench-results/trivariate/01.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).transformers,10932622,672938,16286,2,35651584 3 | All.trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).ad,7536372,665862,11362,0,35651584 4 | All.trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).ad/double,6230117,365510,13762,0,35651584 5 | All.trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).backprop,12678178,697176,32377,0,35651584 6 | All.trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + 
((y * z) * x)) + 1))).primops,5353352,419688,13741,0,35651584 7 | All.trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).primops/double,3595261,340726,13731,0,35651584 8 | -------------------------------------------------------------------------------- /bench-results/4-ary/04.csv: -------------------------------------------------------------------------------- 1 | Name,Mean (ps),2*Stdev (ps),Allocated,Copied,Peak Memory 2 | All.4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).transformers,17168740,722266,27682,4,35651584 3 | All.4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).ad,8610849,710976,16232,2,35651584 4 | All.4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).ad/double,6833374,656408,11358,0,35651584 5 | All.4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).backprop,13718808,1329062,32420,4,35651584 6 | All.4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).primops,8280039,800332,24243,1,35651584 7 | All.4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).primops/double,5662126,338074,21775,1,35651584 8 | -------------------------------------------------------------------------------- /src/Numeric/AD/DelCont/Native/Internal.hs: 
-------------------------------------------------------------------------------- 1 | {-# LANGUAGE GHC2021 #-} 2 | {-# LANGUAGE BlockArguments #-} 3 | {-# LANGUAGE LambdaCase #-} 4 | {-# LANGUAGE MagicHash #-} 5 | {-# LANGUAGE UnboxedTuples #-} 6 | 7 | module Numeric.AD.DelCont.Native.Internal ( 8 | PromptTag, 9 | newPromptTag, 10 | reset, 11 | prompt, 12 | control0, 13 | shift, 14 | ) where 15 | 16 | import Control.Arrow 17 | import GHC.IO 18 | import GHC.Prim 19 | import GHC.ST 20 | 21 | -- | Ordinary data type wrapping the primitive 'PromptTag#'. 22 | data PromptTag a = PromptTag (PromptTag# a) 23 | 24 | -- | Allocate a tag with the 'newPromptTag#' primop, moved from 'IO' into 'ST' via 'unsafeIOToST'. 25 | newPromptTag :: ST s (PromptTag a) 26 | {-# INLINE newPromptTag #-} 27 | newPromptTag = 28 | unsafeIOToST $ IO $ \s -> 29 | case newPromptTag# s of 30 | (# s', tag #) -> (# s', PromptTag tag #) 31 | 32 | -- | Wrap an 'ST' action with the 'prompt#' primop for this tag, converting through 'IO' with 'unsafeSTToIO' and 'unsafeIOToST'. 33 | prompt :: PromptTag a -> ST s a -> ST s a 34 | {-# INLINE prompt #-} 35 | prompt (PromptTag tag) = 36 | unsafeSTToIO >>> \case 37 | IO f -> unsafeIOToST $ IO $ prompt# tag f 38 | 39 | -- | Allocate a tag with 'newPromptTag' and run @act tag@ under 'prompt' for that same tag. 40 | reset :: (PromptTag a -> ST s a) -> ST s a 41 | {-# INLINE reset #-} 42 | reset act = (prompt <*> act) =<< newPromptTag 43 | 44 | -- | Calls 'control0', wrapping both the application of @f@ and every use of the function handed to @f@ in 'prompt' for the same tag. 45 | shift :: PromptTag a -> ((ST s p -> ST s a) -> ST s a) -> ST s p 46 | {-# INLINE shift #-} 47 | shift p f = control0 p $ \k -> 48 | prompt p $ f $ prompt p . 
k 49 | 50 | -- | 'ST'-level wrapper around the 'control0#' primop: the value the primop hands to its callback is re-wrapped as a function on 'ST' actions before being passed to @f@. 51 | control0 :: PromptTag a -> ((ST s p -> ST s a) -> ST s a) -> ST s p 52 | {-# INLINE control0 #-} 53 | control0 (PromptTag tag) f = 54 | unsafeIOToST $ 55 | IO $ 56 | control0# tag $ \k -> 57 | case unsafeSTToIO $ f (unsafeSTToIO >>> \(IO a) -> unsafeIOToST $ IO $ k a) of 58 | IO p -> p 59 | -------------------------------------------------------------------------------- /bench/Helpers.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE DerivingStrategies #-} 2 | {-# LANGUAGE DerivingVia #-} 3 | {-# OPTIONS_GHC -Wno-orphans #-} 4 | 5 | module Helpers ((==~), isDefinite) where 6 | 7 | import Data.Foldable 8 | import Linear 9 | import qualified Numeric.Backprop as BP 10 | import Test.Tasty.QuickCheck 11 | 12 | (==~) :: (Show (v Double), Foldable v, Metric v) => v Double -> v Double -> Property 13 | ls ==~ rs = 14 | counterexample ("Not near: " <> show (ls, rs)) $ 15 | conjoin $ 16 | toList $ 17 | liftI2 18 | ( \l r -> 19 | (isDefinite l .&&. isDefinite r .&&. almostEq 1e1 l r) 20 | .||. (not (isDefinite l) .&&. 
not (isDefinite r)) 21 | ) 22 | ls 23 | rs 24 | 25 | almostEq :: Double -> Double -> Double -> Bool 26 | almostEq thresh l r 27 | | nearZero l = nearZero (r / thresh) 28 | | otherwise = nearZero $ abs (l - r) / max (abs l) (abs r) / thresh 29 | 30 | isDefinite :: Double -> Bool 31 | isDefinite c = not (isInfinite c || isNaN c) 32 | 33 | instance Arbitrary a => Arbitrary (V2 a) where 34 | arbitrary = V2 <$> arbitrary <*> arbitrary 35 | shrink = mapM shrink 36 | 37 | instance Arbitrary a => Arbitrary (V3 a) where 38 | arbitrary = V3 <$> arbitrary <*> arbitrary <*> arbitrary 39 | shrink = mapM shrink 40 | 41 | instance Arbitrary a => Arbitrary (V4 a) where 42 | arbitrary = sequence $ pure arbitrary 43 | shrink = mapM shrink 44 | 45 | deriving newtype instance Arbitrary a => Arbitrary (V1 a) 46 | 47 | deriving via BP.NumBP (V1 a) instance Num a => BP.Backprop (V1 a) 48 | 49 | deriving via BP.NumBP (V2 a) instance Num a => BP.Backprop (V2 a) 50 | 51 | deriving via BP.NumBP (V3 a) instance Num a => BP.Backprop (V3 a) 52 | 53 | deriving via BP.NumBP (V4 a) instance Num a => BP.Backprop (V4 a) 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .stack-work/ 2 | *~ 3 | ### https://raw.github.com/github/gitignore/14f8a8b4c51ecc00b18905a95c117954e6c77b9d/Haskell.gitignore 4 | dist 5 | dist-* 6 | cabal-dev 7 | *.o 8 | *.hi 9 | *.hie 10 | *.chi 11 | *.chs.h 12 | *.dyn_o 13 | *.dyn_hi 14 | .hpc 15 | .hsenv 16 | .cabal-sandbox/ 17 | cabal.sandbox.config 18 | *.prof 19 | *.aux 20 | *.hp 21 | *.eventlog 22 | .stack-work/ 23 | cabal.project.local 24 | cabal.project.local~ 25 | .HTF/ 26 | .ghc.environment.* 27 | 28 | 29 | ### https://raw.github.com/github/gitignore/14f8a8b4c51ecc00b18905a95c117954e6c77b9d/Global/VisualStudioCode.gitignore 30 | 31 | .vscode/* 32 | !.vscode/settings.json 33 | !.vscode/tasks.json 34 | !.vscode/launch.json 35 | 
!.vscode/extensions.json 36 | *.code-workspace 37 | 38 | # Local History for Visual Studio Code 39 | .history/ 40 | 41 | 42 | ### https://raw.github.com/github/gitignore/14f8a8b4c51ecc00b18905a95c117954e6c77b9d/Global/macOS.gitignore 43 | 44 | # General 45 | .DS_Store 46 | .AppleDouble 47 | .LSOverride 48 | 49 | # Icon must end with two \r 50 | Icon 51 | 52 | # Thumbnails 53 | ._* 54 | 55 | # Files that might appear in the root of a volume 56 | .DocumentRevisions-V100 57 | .fseventsd 58 | .Spotlight-V100 59 | .TemporaryItems 60 | .Trashes 61 | .VolumeIcon.icns 62 | .com.apple.timemachine.donotpresent 63 | 64 | # Directories potentially created on remote AFP share 65 | .AppleDB 66 | .AppleDesktop 67 | Network Trash Folder 68 | Temporary Items 69 | .apdisk 70 | 71 | 72 | ### https://raw.github.com/github/gitignore/14f8a8b4c51ecc00b18905a95c117954e6c77b9d/Global/Linux.gitignore 73 | 74 | *~ 75 | 76 | # temporary files which can be created if a process still has a handle open of a deleted file 77 | .fuse_hidden* 78 | 79 | # KDE directory preferences 80 | .directory 81 | 82 | # Linux trash folder which might appear on any partition or disk 83 | .Trash-* 84 | 85 | # .nfs files are created when an open file is removed but is still being accessed 86 | .nfs* 87 | 88 | 89 | .bin 90 | -------------------------------------------------------------------------------- /bench/bench.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE PartialTypeSignatures #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | {-# OPTIONS_GHC -Wno-type-defaults -Wno-partial-type-signatures #-} 4 | 5 | module Main (main) where 6 | 7 | import Linear (V2 (..), V3 (..), V4 (..)) 8 | import Macros 9 | import Test.Tasty.Bench 10 | 11 | main :: IO () 12 | main = 13 | defaultMain 14 | [ bgroup "univariate" $ 15 | [ $(mkDiffBench [|id|]) 16 | , $(mkDiffBench [|\x -> (x + 1) * (x + 1)|]) 17 | , $(mkDiffBench [|\x -> x * exp (x * x + 1)|]) 18 | ] 19 | , bgroup 20 | 
"bivariate" 21 | [ $(mkGradBench [|\(V2 x y) -> x * y|] [|V2 (42 :: Double) 53|]) 22 | , $(mkGradBench [|\(V2 x y) -> sin x * cos y * (x ^ 2 + y)|] [|V2 42 54|]) 23 | , $(mkGradBench [|\(V2 x y) -> y * exp (x * x + y)|] [|V2 42 52|]) 24 | , $(mkGradBench [|\(V2 x y) -> (x * cos x + y) ^ 2 * exp (x * sin (x + y * y + 1))|] [|V2 42 52|]) 25 | , $(mkGradBench [|\(V2 x y) -> (tanh (exp y * cosh x) + x ^ 2) ^ 3 - (x * cos x + y) ^ 2 * exp (x * sin (x + y * y + 1))|] [|V2 42 52|]) 26 | ] 27 | , bgroup 28 | "trivariate" 29 | [ $(mkGradBench [|\(V3 x y z) -> x * y * z|] [|V3 4 5 6|]) 30 | , $(mkGradBench [|\(V3 x y z) -> (tanh (exp (y + z ^ 2) * cosh x) + x ^ 2) ^ 3 - (x * (z ^ 2 - 1) * cos x + y) ** (2 * z) * exp (x * sin (x + y * z * x + 1))|] [|V3 4 5 6|]) 31 | ] 32 | , bgroup 33 | "4-ary" 34 | [ $(mkGradBench [|\(V4 x y z w) -> x * y * z * w|] [|V4 4 5 6 7|]) 35 | , $(mkGradBench [|\(V4 x y z w) -> (x + w) ^ 4 * exp (x + cos (y ^ 2 * sin z) * w)|] [|V4 4 5 6 7|]) 36 | , $(mkGradBench [|\(V4 x y z w) -> log (x ^ 2 + w) / log (x + w) ^ 4 * exp (x + cos (y ^ 2 * sin z) * w)|] [|V4 4 5 6 7|]) 37 | , $(mkGradBench [|\(V4 x y z w) -> logBase (x ^ 2 + w) (cos (x ^ 2 + 2 * z) + w + 1) ^ 4 * exp (x + sin (pi * x) * cos (exp y ^ 2 * sin z) * w)|] [|V4 4 5 6 7|]) 38 | , $(mkGradBench [|\(V4 x y z w) -> logBase (x ^ 2 + tanh w) (cos (x ^ 2 + 2 * z) + w + 1) ^ 4 + exp (x + sin (pi * x + w ^ 2) * cosh (exp y ^ 2 * sin z) ^ 2 * (w + 1))|] [|V4 4 5 6 7|]) 39 | ] 40 | ] 41 | -------------------------------------------------------------------------------- /src/Numeric/AD/DelCont/Native/Linear.hs: -------------------------------------------------------------------------------- 1 | module Numeric.AD.DelCont.Native.Linear ( 2 | negated, 3 | (^+^), 4 | (^-^), 5 | (*^), 6 | (^*), 7 | (^/), 8 | zero, 9 | sumV, 10 | dot, 11 | quadrance, 12 | norm, 13 | pured, 14 | ) where 15 | 16 | import Control.Monad 17 | import Data.Foldable 18 | import qualified Linear as L 19 | import 
Numeric.AD.DelCont.Native

infixl 6 ^+^

-- | Component-wise negation of a differentiable vector.
--
-- The primal result is @'L.negated' v@.  The second component handed to
-- 'op1'' maps the sensitivity of the result back to the argument and is
-- 'L.negated' as well.
negated :: (L.Additive v, L.Additive u, Num a, Num da) => AD' s (v a) (u da) -> AD' s (v a) (u da)
{-# INLINE negated #-}
negated = op1' L.zero (L.^+^) $ \v ->
  (L.negated v, L.negated)

-- | Differentiable vector addition.
--
-- The sensitivity of the sum is passed unchanged ('id') to both summands.
(^+^) ::
  (L.Additive v, L.Additive u, Num a, Num da) =>
  AD' s (v a) (u da) ->
  AD' s (v a) (u da) ->
  AD' s (v a) (u da)
{-# INLINE (^+^) #-}
(^+^) = op2' L.zero (L.^+^) (L.^+^) $ \x y -> (x L.^+^ y, id, id)

infixl 6 ^-^

-- | Differentiable vector subtraction.
--
-- The sensitivity of the difference reaches the left operand unchanged and
-- the right operand negated.
(^-^) ::
  (L.Additive v, L.Additive u, Num a, Num da) =>
  AD' s (v a) (u da) ->
  AD' s (v a) (u da) ->
  AD' s (v a) (u da)
{-# INLINE (^-^) #-}
(^-^) = op2' L.zero (L.^+^) (L.^+^) $ \x y -> (x L.^-^ y, id, L.negated)

infixl 7 *^

-- | Differentiable scalar-times-vector product (scalar on the left).
--
-- For a result sensitivity @dz@ the scalar receives @dz `L.dot` v@ and the
-- vector receives @c L.*^ dz@.
(*^) ::
  (L.Metric v, Num a) =>
  AD s a ->
  AD s (v a) ->
  AD s (v a)
{-# INLINE (*^) #-}
(*^) = op2' L.zero (+) (L.^+^) $ \c v ->
  (c L.*^ v, (`L.dot` v), (c L.*^))

infixl 7 ^*

-- | Differentiable vector-times-scalar product (scalar on the right).
--
-- For a result sensitivity @dz@ the vector receives @dz L.^* c@ and the
-- scalar receives @v `L.dot` dz@.
(^*) ::
  (L.Metric v, Num a) =>
  AD' s (v a) (v a) ->
  AD' s a a ->
  AD' s (v a) (v a)
{-# INLINE (^*) #-}
(^*) = op2' L.zero (L.^+^) (+) $ \v c ->
  (v L.^* c, (L.^* c), (v `L.dot`))

infixl 7 ^/

-- | Differentiable division of a vector by a scalar.
--
-- For a result sensitivity @dz@ the vector receives @dz L.^/ c@ and the
-- scalar receives @-(dz `L.dot` v) / c^2@.
(^/) ::
  (L.Metric v, Fractional a) =>
  AD' s (v a) (v a) ->
  AD' s a a ->
  AD' s (v a) (v a)
{-# INLINE (^/) #-}
(^/) = op2' L.zero (L.^+^) (+) $ \v c ->
  (v L.^/ c, (L.^/ c), \dz -> negate ((dz `L.dot` v) / (c * c)))

-- | The constant zero vector, built with 'konst'' from 'L.zero' for both of
-- its arguments.
zero :: (L.Additive v, L.Additive u, Num a, Num da) => AD' s (v a) (u da)
zero = konst' L.zero L.zero

-- | Sum a container of differentiable vectors with a strict left fold over
-- '^+^', starting from 'zero'.
sumV ::
  (Foldable t, L.Additive v, L.Additive u, Num a, Num da) =>
  t (AD' s (v a) (u da)) ->
  AD' s (v a) (u da)
sumV = foldl' (^+^) zero

-- | Differentiable inner product.
--
-- For a (scalar) result sensitivity @dz@ the left operand receives
-- @dz L.*^ y@ and the right operand receives @dz L.*^ x@.
dot ::
  (L.Metric v, Num a) =>
  AD s (v a) ->
  AD s (v a) ->
  AD s a
{-# INLINE dot #-}
dot = op2' 0 (L.^+^) (L.^+^) $ \x y ->
  (x `L.dot` y, \dz -> dz L.*^ y, \dz -> dz L.*^ x)

-- | Differentiable Euclidean norm: the square root of 'quadrance'.
norm ::
  (Floating a, L.Metric v) =>
  AD s (v a) ->
  AD s a
{-# INLINE norm #-}
norm = sqrt . quadrance

-- | Differentiable squared norm: 'dot' of a vector with itself.
--
-- Only 'Num' is required of the scalar type, which is exactly what 'dot'
-- asks for; no square root is taken here.
quadrance :: (Num a, L.Metric v) => AD s (v a) -> AD s a
{-# INLINE quadrance #-}
quadrance = join dot

-- | Replicate a differentiable scalar into every slot of an 'Applicative'
-- container with 'pure'.
--
-- The sensitivity flowing back to the scalar is the 'sum' of the
-- container's sensitivities.
pured :: (L.Additive u, Foldable u, Applicative u, Num da) => AD' s a da -> AD' s (u a) (u da)
pured = op1' L.zero (+) $ \x -> (pure x, sum)
-------------------------------------------------------------------------------- /ad-delcont-primop.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 3.0 2 | name: ad-delcont-primop 3 | version: 0.1.0.0 4 | description: 5 | Please see the README on GitHub at 6 | 7 | homepage: https://github.com/githubuser/ad-delcont-primop#readme 8 | bug-reports: https://github.com/githubuser/ad-delcont-primop/issues 9 | author: Author name here 10 | maintainer: example@example.com 11 | copyright: 2023 Author name here 12 | license: BSD-3-Clause 13 | license-file: LICENSE 14 | build-type: Simple 15 | extra-source-files: 16 | CHANGELOG.md 17 | README.md 18 | 19 | source-repository head 20 | type: git 21 | location: https://github.com/githubuser/ad-delcont-primop 22 | 23 | common commons 24 | ghc-options: 25 | -Wall -Wcompat -Widentities -Wincomplete-record-updates 26 | -Wincomplete-uni-patterns -Wmissing-export-lists 27 | -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints 28 | 29 | default-language: GHC2021 30 | other-modules: Paths_ad_delcont_primop 31 | build-depends: 32 | , base >=4.7 && <5 33 | , ghc-prim 34 | , linear 35 | , simple-reflect 36 | 37 | library 38 | import: commons 39 | exposed-modules: 40 | Data.PRef 41 | Data.URef 42 | Numeric.AD.DelCont.Native 43 | Numeric.AD.DelCont.Native.Double 44 | Numeric.AD.DelCont.Native.Linear 45 | Numeric.AD.DelCont.Native.MultiPrompt 46 | 47 | other-modules: Numeric.AD.DelCont.Native.Internal 48 | hs-source-dirs: src 49 | build-depends: 50 | , distributive 51 |
, primitive 52 | , vector 53 | 54 | default-language: Haskell2010 55 | 56 | executable ad-delcont-primop-exe 57 | import: commons 58 | main-is: Main.hs 59 | other-modules: Paths_ad_delcont_primop 60 | hs-source-dirs: app 61 | build-depends: ad-delcont-primop 62 | default-language: Haskell2010 63 | 64 | test-suite ad-delcont-primop-test 65 | import: commons 66 | type: exitcode-stdio-1.0 67 | main-is: Spec.hs 68 | other-modules: Paths_ad_delcont_primop 69 | hs-source-dirs: test 70 | build-tool-depends: tasty-discover:tasty-discover -any 71 | build-depends: 72 | , ad-delcont-primop 73 | , tasty 74 | 75 | default-language: Haskell2010 76 | 77 | benchmark ad-delcont-primop-bench 78 | import: commons 79 | type: exitcode-stdio-1.0 80 | main-is: bench.hs 81 | other-modules: 82 | Helpers 83 | Macros 84 | 85 | hs-source-dirs: bench 86 | ghc-options: "-with-rtsopts=-A32m -T" 87 | build-depends: 88 | , ad 89 | , ad-delcont 90 | , ad-delcont-primop 91 | , backprop 92 | , lens 93 | , tasty 94 | , tasty-bench 95 | , tasty-hunit 96 | , tasty-quickcheck 97 | , template-haskell 98 | -------------------------------------------------------------------------------- /.github/workflows/haskell.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | on: push 3 | 4 | jobs: 5 | build: 6 | name: Build 7 | runs-on: ubuntu-22.04 8 | env: 9 | ARTIFACT: ${{github.workspace}}/artifacts.tar.zst 10 | steps: 11 | - name: Checkout 12 | uses: actions/checkout@v3 13 | - name: Cache ~/.cabal/store 14 | uses: actions/cache@v3 15 | with: 16 | path: ~/.cabal/store 17 | key: cabal-store-961-alpha-${{hashFiles('cabal.project.freeze', 'cabal.project', '*.cabal')}} 18 | restore-keys: 19 | cabal-store-961-alpha- 20 | - name: Cache dist-newstyle 21 | uses: actions/cache@v3 22 | with: 23 | path: dist-newstyle 24 | key: cabal-dist-961-alpha-${{hashFiles('cabal.project.freeze', 'cabal.project', '**/*.cabal')}}-${{hashFiles('**/*.hs')}} 25 | restore-keys: | 26 | 
cabal-dist-961-alpha-${{hashFiles('cabal.project.freeze', 'cabal.project', '**/*.cabal')}}- 27 | cabal-dist-961-alpha- 28 | - name: Setup Haskell 29 | run: | 30 | curl --proto '=https' --tlsv1.2 -sSf https://get-ghcup.haskell.org | BOOTSTRAP_HASKELL_NONINTERACTIVE=1 BOOTSTRAP_HASKELL_MINIMAL=1 sh 31 | ghcup config add-release-channel https://raw.githubusercontent.com/haskell/ghcup-metadata/master/ghcup-prereleases-0.0.7.yaml 32 | ghcup install ghc 9.6.0.20230111 33 | ghcup install cabal 3.8.1.0 34 | - name: Configure 35 | run: > 36 | cabal configure --jobs $'$ncpus' 37 | --with-compiler ghc-9.6.0.20230111 38 | --enable-optimisation=2 39 | --enable-benchmarks --enable-tests 40 | --index-state 2023-01-14T16:11:16Z 41 | 42 | cabal update 43 | - name: Build 44 | run: cabal build 45 | - name: Collect Binaries 46 | run: ./scripts/collect-bins.sh "${{ env.ARTIFACT }}" 47 | - name: Upload Artifact 48 | uses: actions/upload-artifact@v3 49 | with: 50 | name: binaries 51 | path: "${{ env.ARTIFACT }}" 52 | 53 | bench: 54 | needs: build 55 | name: Test and Bench 56 | runs-on: ubuntu-22.04 57 | steps: 58 | - name: Checkout 59 | uses: actions/checkout@v3 60 | - name: Download Artifact 61 | id: download 62 | uses: actions/download-artifact@v3 63 | with: 64 | name: binaries 65 | - name: Decompress 66 | run: | 67 | echo ${{ steps.download.outputs.download-path }} 68 | ls ${{ steps.download.outputs.download-path }} 69 | tar xvf ${{ steps.download.outputs.download-path }}/artifacts.tar.zst --directory=. 
70 | - name: Run Bench 71 | env: 72 | BENCHMARK: bins/bin/ad-delcont-primop-bench 73 | RESULTS: bench-results 74 | shell: bash 75 | run: | 76 | ./scripts/run-bench.sh "${{ env.BENCHMARK }}" 77 | tar -cf bench-results.tar.zst --use-compress-program="zstdmt -9" bench-results 78 | - name: Upload Bench Results 79 | uses: actions/upload-artifact@v3 80 | with: 81 | name: bench-results 82 | path: bench-results.tar.zst 83 | -------------------------------------------------------------------------------- /bench-results/univariate/00.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | univariate.id.transformers 76.4 ns 4 | 5 | 76.4 ns ± 2.7 ns 6 | 7 | 8 | 9 | 10 | 11 | 12 | univariate.id.ad 23.0 ns 13 | 14 | 23.0 ns ± 1.4 ns 15 | 16 | 17 | 18 | 19 | 20 | 21 | univariate.id.ad/double 23.0 ns 22 | 23 | 23.0 ns ± 1.3 ns 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | univariate.id.backprop 32 | 255 ns 33 | 34 | 35 | 255 ns ± 21 ns 36 | 37 | 38 | 39 | 40 | 41 | 42 | univariate.id.primops 53.3 ns 43 | 44 | 53.3 ns ± 5.2 ns 45 | 46 | 47 | 48 | 49 | 50 | 51 | univariate.id.primops/Double 72.9 ns 52 | 53 | 72.9 ns ± 5.3 ns 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /bench-results/univariate/01.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | univariate.(x + 1) * (x + 1).transformers 481 ns 4 | 5 | 481 ns ± 21 ns 6 | 7 | 8 | 9 | 10 | 11 | 12 | univariate.(x + 1) * (x + 1).ad 209 ns 13 | 14 | 209 ns ± 11 ns 15 | 16 | 17 | 18 | 19 | 20 | 21 | univariate.(x + 1) * (x + 1).ad/double 27.2 ns 22 | 23 | 27.2 ns ± 2.7 ns 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | univariate.(x + 1) * (x + 1).backprop 32 | 1.28 μs 33 | 34 | 35 | 1.28 μs ± 100 ns 36 | 37 | 38 | 39 | 40 | 41 | 42 | univariate.(x + 1) * (x + 1).primops 352 ns 43 | 44 | 352 ns ± 23 ns 45 | 46 | 47 | 48 | 49 | 50 | 51 | univariate.(x + 1) * (x + 1).primops/Double 276 
ns 52 | 53 | 276 ns ± 16 ns 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /bench-results/univariate/02.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | univariate.x * exp ((x * x) + 1).transformers 5 | 1.17 μs 6 | 7 | 8 | 1.17 μs ± 90 ns 9 | 10 | 11 | 12 | 13 | 14 | 15 | univariate.x * exp ((x * x) + 1).ad 444 ns 16 | 17 | 444 ns ± 24 ns 18 | 19 | 20 | 21 | 22 | 23 | 24 | univariate.x * exp ((x * x) + 1).ad/double 35.8 ns 25 | 26 | 35.8 ns ± 2.6 ns 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | univariate.x * exp ((x * x) + 1).backprop 35 | 1.64 μs 36 | 37 | 38 | 1.64 μs ± 144 ns 39 | 40 | 41 | 42 | 43 | 44 | 45 | univariate.x * exp ((x * x) + 1).primops 457 ns 46 | 47 | 457 ns ± 34 ns 48 | 49 | 50 | 51 | 52 | 53 | 54 | univariate.x * exp ((x * x) + 1).primops/Double 364 ns 55 | 56 | 364 ns ± 32 ns 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /bench-results/bivariate/00.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | bivariate.x * y.transformers 5 | 891 ns 6 | 7 | 8 | 891 ns ± 41 ns 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | bivariate.x * y.ad 17 | 1.40 μs 18 | 19 | 20 | 1.40 μs ± 100 ns 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | bivariate.x * y.ad/double 29 | 1.48 μs 30 | 31 | 32 | 1.48 μs ± 110 ns 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | bivariate.x * y.backprop 41 | 1.39 μs 42 | 43 | 44 | 1.39 μs ± 107 ns 45 | 46 | 47 | 48 | 49 | 50 | 51 | bivariate.x * y.primops 163 ns 52 | 53 | 163 ns ± 15 ns 54 | 55 | 56 | 57 | 58 | 59 | 60 | bivariate.x * y.primops/double 103 ns 61 | 62 | 103 ns ± 6.5 ns 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /bench-results/trivariate/00.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | trivariate.(x * y) * z.transformers 5 | 1.17 μs 6 | 7 | 8 | 1.17 μs ± 95 ns 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | trivariate.(x * y) * z.ad 17 | 2.02 μs 18 | 19 | 20 | 2.02 μs ± 164 ns 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | trivariate.(x * y) * z.ad/double 29 | 2.05 μs 30 | 31 | 32 | 2.05 μs ± 189 ns 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | trivariate.(x * y) * z.backprop 41 | 2.20 μs 42 | 43 | 44 | 2.20 μs ± 174 ns 45 | 46 | 47 | 48 | 49 | 50 | 51 | trivariate.(x * y) * z.primops 281 ns 52 | 53 | 281 ns ± 27 ns 54 | 55 | 56 | 57 | 58 | 59 | 60 | trivariate.(x * y) * z.primops/double 170 ns 61 | 62 | 170 ns ± 9.3 ns 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /bench-results/4-ary/00.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 4-ary.((x * y) * z) * w.transformers 5 | 1.59 μs 6 | 7 | 8 | 1.59 μs ± 87 ns 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 4-ary.((x * y) * z) * w.ad 17 | 2.62 μs 18 | 19 | 20 | 2.62 μs ± 238 ns 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 4-ary.((x * y) * z) * w.ad/double 29 | 2.61 μs 30 | 31 | 32 | 2.61 μs ± 212 ns 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 4-ary.((x * y) * z) * w.backprop 41 | 3.05 μs 42 | 43 | 44 | 3.05 μs ± 213 ns 45 | 46 | 47 | 48 | 49 | 50 | 51 | 4-ary.((x * y) * z) * w.primops 393 ns 52 | 53 | 393 ns ± 27 ns 54 | 55 | 56 | 57 | 58 | 59 | 60 | 4-ary.((x * y) * z) * w.primops/double 260 ns 61 | 62 | 260 ns ± 12 ns 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /bench-results/4-ary/01.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).transformers 4.08 μs 4 | 5 | 4.08 μs ± 209 ns 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 
4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).ad 14 | 4.31 μs 15 | 16 | 17 | 4.31 μs ± 317 ns 18 | 19 | 20 | 21 | 22 | 23 | 24 | 4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).ad/double 3.89 μs 25 | 26 | 3.89 μs ± 259 ns 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).backprop 35 | 5.90 μs 36 | 37 | 38 | 5.90 μs ± 401 ns 39 | 40 | 41 | 42 | 43 | 44 | 45 | 4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).primops 1.73 μs 46 | 47 | 1.73 μs ± 164 ns 48 | 49 | 50 | 51 | 52 | 53 | 54 | 4-ary.((x + w) ^ 4) * exp (x + (cos ((y ^ 2) * sin z) * w)).primops/double 1.12 μs 55 | 56 | 1.12 μs ± 89 ns 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /bench-results/bivariate/02.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | bivariate.y * exp ((x * x) + y).transformers 5 | 1.41 μs 6 | 7 | 8 | 1.41 μs ± 116 ns 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | bivariate.y * exp ((x * x) + y).ad 17 | 1.95 μs 18 | 19 | 20 | 1.95 μs ± 115 ns 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | bivariate.y * exp ((x * x) + y).ad/double 29 | 1.89 μs 30 | 31 | 32 | 1.89 μs ± 106 ns 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | bivariate.y * exp ((x * x) + y).backprop 41 | 2.35 μs 42 | 43 | 44 | 2.35 μs ± 102 ns 45 | 46 | 47 | 48 | 49 | 50 | 51 | bivariate.y * exp ((x * x) + y).primops 451 ns 52 | 53 | 451 ns ± 41 ns 54 | 55 | 56 | 57 | 58 | 59 | 60 | bivariate.y * exp ((x * x) + y).primops/double 295 ns 61 | 62 | 295 ns ± 28 ns 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /bench-results/bivariate/03.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).transformers 3.41 μs 4 | 5 | 3.41 μs ± 188 ns 6 | 7 | 8 | 
9 | 10 | 11 | 12 | 13 | bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).ad 14 | 3.44 μs 15 | 16 | 17 | 3.44 μs ± 220 ns 18 | 19 | 20 | 21 | 22 | 23 | 24 | bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).ad/double 2.87 μs 25 | 26 | 2.87 μs ± 231 ns 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).backprop 35 | 4.82 μs 36 | 37 | 38 | 4.82 μs ± 183 ns 39 | 40 | 41 | 42 | 43 | 44 | 45 | bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).primops 1.61 μs 46 | 47 | 1.61 μs ± 92 ns 48 | 49 | 50 | 51 | 52 | 53 | 54 | bivariate.(((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1)).primops/double 1.04 μs 55 | 56 | 1.04 μs ± 87 ns 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | -------------------------------------------------------------------------------- /bench-results/bivariate/01.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | bivariate.(sin x * cos y) * ((x ^ 2) + y).transformers 5 | 1.88 μs 6 | 7 | 8 | 1.88 μs ± 131 ns 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | bivariate.(sin x * cos y) * ((x ^ 2) + y).ad 17 | 2.37 μs 18 | 19 | 20 | 2.37 μs ± 214 ns 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | bivariate.(sin x * cos y) * ((x ^ 2) + y).ad/double 29 | 2.19 μs 30 | 31 | 32 | 2.19 μs ± 195 ns 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | bivariate.(sin x * cos y) * ((x ^ 2) + y).backprop 41 | 3.08 μs 42 | 43 | 44 | 3.08 μs ± 215 ns 45 | 46 | 47 | 48 | 49 | 50 | 51 | bivariate.(sin x * cos y) * ((x ^ 2) + y).primops 762 ns 52 | 53 | 762 ns ± 50 ns 54 | 55 | 56 | 57 | 58 | 59 | 60 | bivariate.(sin x * cos y) * ((x ^ 2) + y).primops/double 469 ns 61 | 62 | 469 ns ± 17 ns 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /bench-results/4-ary/02.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 
4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).transformers 6.34 μs 4 | 5 | 6.34 μs ± 371 ns 6 | 7 | 8 | 9 | 10 | 11 | 12 | 4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).ad 5.71 μs 13 | 14 | 5.71 μs ± 362 ns 15 | 16 | 17 | 18 | 19 | 20 | 21 | 4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).ad/double 4.82 μs 22 | 23 | 4.82 μs ± 343 ns 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).backprop 32 | 7.88 μs 33 | 34 | 35 | 7.88 μs ± 366 ns 36 | 37 | 38 | 39 | 40 | 41 | 42 | 4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).primops 2.64 μs 43 | 44 | 2.64 μs ± 198 ns 45 | 46 | 47 | 48 | 49 | 50 | 51 | 4-ary.(log ((x ^ 2) + w) / (log (x + w) ^ 4)) * exp (x + (cos ((y ^ 2) * sin z) * w)).primops/double 1.81 μs 52 | 53 | 1.81 μs ± 179 ns 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /bench-results/bivariate/04.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).transformers 8.61 μs 4 | 5 | 8.61 μs ± 354 ns 6 | 7 | 8 | 9 | 10 | 11 | 12 | bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).ad 5.33 μs 13 | 14 | 5.33 μs ± 406 ns 15 | 16 | 17 | 18 | 19 | 20 | 21 | bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).ad/double 4.25 μs 22 | 23 | 4.25 μs ± 202 ns 24 | 25 | 26 | 27 | 28 | 29 | 30 | bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).backprop 8.19 μs 31 | 32 | 8.19 μs ± 760 ns 33 | 34 | 35 | 36 | 37 | 38 | 39 | 
bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).primops 3.98 μs 40 | 41 | 3.98 μs ± 334 ns 42 | 43 | 44 | 45 | 46 | 47 | 48 | bivariate.((tanh (exp y * cosh x) + (x ^ 2)) ^ 3) - ((((x * cos x) + y) ^ 2) * exp (x * sin ((x + (y * y)) + 1))).primops/double 2.65 μs 49 | 50 | 2.65 μs ± 194 ns 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /bench-results/4-ary/03.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).transformers 13.0 μs 4 | 5 | 13.0 μs ± 703 ns 6 | 7 | 8 | 9 | 10 | 11 | 12 | 4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).ad 7.47 μs 13 | 14 | 7.47 μs ± 740 ns 15 | 16 | 17 | 18 | 19 | 20 | 21 | 4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).ad/double 6.13 μs 22 | 23 | 6.13 μs ± 336 ns 24 | 25 | 26 | 27 | 28 | 29 | 30 | 4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).backprop 11.0 μs 31 | 32 | 11.0 μs ± 755 ns 33 | 34 | 35 | 36 | 37 | 38 | 39 | 4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).primops 6.73 μs 40 | 41 | 6.73 μs ± 426 ns 42 | 43 | 44 | 45 | 46 | 47 | 48 | 4-ary.(logBase ((x ^ 2) + w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) * exp (x + ((sin (pi * x) * cos ((exp y ^ 2) * sin z)) * w)).primops/double 4.46 μs 49 | 50 | 4.46 μs ± 391 ns 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /bench-results/trivariate/01.svg: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).transformers 10.9 μs 4 | 5 | 10.9 μs ± 673 ns 6 | 7 | 8 | 9 | 10 | 11 | 12 | trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).ad 7.54 μs 13 | 14 | 7.54 μs ± 666 ns 15 | 16 | 17 | 18 | 19 | 20 | 21 | trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).ad/double 6.23 μs 22 | 23 | 6.23 μs ± 366 ns 24 | 25 | 26 | 27 | 28 | 29 | 30 | trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).backprop 12.7 μs 31 | 32 | 12.7 μs ± 697 ns 33 | 34 | 35 | 36 | 37 | 38 | 39 | trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).primops 5.35 μs 40 | 41 | 5.35 μs ± 420 ns 42 | 43 | 44 | 45 | 46 | 47 | 48 | trivariate.((tanh (exp (y + (z ^ 2)) * cosh x) + (x ^ 2)) ^ 3) - (((((x * ((z ^ 2) - 1)) * cos x) + y) ** (2 * z)) * exp (x * sin ((x + ((y * z) * x)) + 1))).primops/double 3.60 μs 49 | 50 | 3.60 μs ± 341 ns 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /bench-results/4-ary/04.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).transformers 17.2 μs 4 | 5 | 17.2 μs ± 722 ns 6 | 7 | 8 | 9 | 10 | 11 | 12 | 4-ary.(logBase ((x ^ 2) + tanh w) 
((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).ad 8.61 μs 13 | 14 | 8.61 μs ± 711 ns 15 | 16 | 17 | 18 | 19 | 20 | 21 | 4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).ad/double 6.83 μs 22 | 23 | 6.83 μs ± 656 ns 24 | 25 | 26 | 27 | 28 | 29 | 30 | 4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).backprop 13.7 μs 31 | 32 | 13.7 μs ± 1.3 μs 33 | 34 | 35 | 36 | 37 | 38 | 39 | 4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).primops 8.28 μs 40 | 41 | 8.28 μs ± 800 ns 42 | 43 | 44 | 45 | 46 | 47 | 48 | 4-ary.(logBase ((x ^ 2) + tanh w) ((cos ((x ^ 2) + (2 * z)) + w) + 1) ^ 4) + exp (x + ((sin ((pi * x) + (w ^ 2)) * (cosh ((exp y ^ 2) * sin z) ^ 2)) * (w + 1))).primops/double 5.66 μs 49 | 50 | 5.66 μs ± 338 ns 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /bench/Macros.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE PartialTypeSignatures #-} 2 | {-# LANGUAGE TemplateHaskell #-} 3 | 4 | module Macros (mkDiffBench, mkGradBench) where 5 | 6 | import Control.Lens.Plated 7 | import Data.Data.Lens 8 | import Helpers 9 | import Language.Haskell.TH 10 | import Linear 11 | import qualified Numeric.AD as AD 12 | import qualified Numeric.AD.DelCont as MTL 13 | import qualified Numeric.AD.DelCont.Native as PrimOp 14 | import qualified Numeric.AD.DelCont.Native.Double as PrimOpDouble 15 | import qualified Numeric.AD.DelCont.Native.MultiPrompt as MP 16 | import qualified Numeric.AD.Double as ADDouble 17 | import qualified Numeric.Backprop as BP 
import Test.Tasty.Bench
import Test.Tasty.QuickCheck

-- | Build a benchmark group for a univariate function given as a quoted
-- expression.
--
-- The group is labelled with the pretty-printed body of the function (see
-- 'showBody').  It contains three QuickCheck properties that compare the
-- derivative computed by @ad@ with the ones computed by this package's
-- implementations, followed by one benchmark per library, each evaluated at
-- @42.0@.
--
-- The third property uses the @Double@-specialised 'ADDouble.diff' as its
-- reference, matching its label and the @"ad/double"@ benchmark below.
mkDiffBench :: ExpQ -> ExpQ
mkDiffBench func = do
  lab <- showBody <$> runQ func
  [|
    bgroup
      lab
      [ testProperty "ad and primops returns almost the same answer" $
          \(x :: Double) ->
            let ad = AD.diff $(func) x
                primop = PrimOp.diff ($func :: PrimOp.AD s Double -> PrimOp.AD s Double) x
             in classify (not $ isDefinite primop) "diverged" $ V1 ad ==~ V1 primop
      , testProperty "ad and primops/mp returns almost the same answer" $
          \(x :: Double) ->
            let ad = AD.diff $(func) x
                primop = MP.diff $(func) x
             in classify (not $ isDefinite primop) "diverged" $ V1 ad ==~ V1 primop
      , testProperty "ad/double and primops/double returns almost the same answer" $
          \(x :: Double) ->
            let ad = ADDouble.diff $(func) x
                primop = PrimOpDouble.diff $(func) x
             in classify (not $ isDefinite primop) "diverged" $ V1 ad ==~ V1 primop
      , bench "transformers" $ nf (snd . MTL.rad1 $func) (42.0 :: Double)
      , bench "ad" $ nf (AD.diff $func) (42.0 :: Double)
      , bench "ad/double" $ nf (ADDouble.diff $func) (42.0 :: Double)
      , bench "backprop" $ nf (BP.gradBP $func) (42.0 :: Double)
      , bench "primops" $ nf (PrimOp.diff ($func :: PrimOp.AD s Double -> PrimOp.AD s Double)) (42.0 :: Double)
      , bench "primops/mp" $ nf (MP.diff ($func :: MP.AD s Double -> MP.AD s Double)) (42.0 :: Double)
      , bench "primops/Double" $ nf (PrimOpDouble.diff $func) (42.0 :: Double)
      ]
    |]

-- | Render an expression as a benchmark label.
--
-- Leading lambdas are stripped so that only the innermost body is printed,
-- and every 'Name' in it is replaced by its unqualified form first.
showBody :: Exp -> String
showBody (LamE _ b) = showBody b
showBody e = pprint $ transformOnOf biplate uniplate unQualNames e

-- | Drop the module qualification (and any other name flavour) from a
-- 'Name', keeping only its base.
unQualNames :: Name -> Name
unQualNames = mkName . nameBase

-- | Build a benchmark group for a multivariate function given as a quoted
-- expression, together with a quoted argument to benchmark at.
--
-- The argument expression is annotated as a container of 'Double' and
-- supplied to the benchmarks through 'env'.  As in 'mkDiffBench', the group
-- is labelled with the pretty-printed function body and starts with three
-- QuickCheck properties comparing gradients computed by @ad@ with the ones
-- computed by this package's implementations.
--
-- The third property uses the @Double@-specialised 'ADDouble.grad' as its
-- reference, matching its label and the @"ad/double"@ benchmark below.
mkGradBench :: ExpQ -> ExpQ -> ExpQ
mkGradBench func arg0 = do
  lab <- showBody <$> runQ func
  [|
    env (pure ($(arg0) :: _ Double)) $ \arg ->
      bgroup
        lab
        [ testProperty "ad and primops returns almost the same answer" $
            \(x :: _ Double) ->
              let ad = AD.grad $(func) x
                  primop =
                    PrimOp.grad
                      ($func :: _ (PrimOp.AD s Double) -> PrimOp.AD s Double)
                      x
               in classify (any (not . isDefinite) ad) "diverged" $ ad ==~ primop
        , testProperty "ad and primops/mp returns almost the same answer" $
            \(x :: _ Double) ->
              let ad = AD.grad $(func) x
                  primop =
                    MP.grad
                      ($func :: _ (MP.AD s Double) -> MP.AD s Double)
                      x
               in classify (any (not . isDefinite) ad) "diverged" $ ad ==~ primop
        , testProperty "ad/double and primops/double returns almost the same answer" $
            \(x :: _ Double) ->
              let ad = ADDouble.grad $(func) x
                  primop = PrimOpDouble.grad $(func) x
               in classify (any (not . isDefinite) ad) "diverged" $ ad ==~ primop
        , bench "transformers" $ nf (snd . MTL.grad $func) arg
        , bench "ad" $ nf (AD.grad $func) arg
        , bench "ad/double" $ nf (ADDouble.grad $func) arg
        , bench "backprop" $ nf (BP.gradBP ($func . BP.sequenceVar)) arg
        , bench "primops" $ nf (PrimOp.grad ($func :: _ (PrimOp.AD s Double) -> PrimOp.AD s Double)) arg
        , bench "primops/mp" $ nf (MP.grad ($func :: _ (MP.AD s Double) -> MP.AD s Double)) arg
        , bench "primops/double" $ nf (PrimOpDouble.grad $func) arg
        ]
    |]
-------------------------------------------------------------------------------- /cabal.project.freeze: -------------------------------------------------------------------------------- 1 | active-repositories: hackage.haskell.org:merge 2 | constraints: any.Glob ==0.10.2, 3 | any.OneTuple ==0.3.1, 4 | any.QuickCheck ==2.14.2, 5 | QuickCheck -old-random +templatehaskell, 6 | any.StateVar ==1.2.2, 7 | any.ad ==4.5.2, 8 | ad -ffi -herbie, 9 | any.ad-delcont ==0.3.0.0, 10 | any.adjunctions ==4.4.2, 11 | any.ansi-terminal ==0.11.4, 12 | ansi-terminal -example +win32-2-13-1, 13 | any.ansi-wl-pprint ==0.6.9, 14 | ansi-wl-pprint -example, 15 | any.array ==0.5.4.0, 16 | any.assoc ==1.0.2, 17 | any.backprop ==0.2.6.4, 18 | any.base ==4.18.0.0, 19 | any.base-orphans ==0.8.7, 20 | any.bifunctors ==5.5.14, 21 | bifunctors +semigroups +tagged, 22 | any.binary ==0.8.9.1, 23 | any.binary-orphans ==1.0.3, 24 | any.bytes ==0.17.2, 25 | any.bytestring ==0.11.3.1, 26 | any.call-stack ==0.4.0, 27 | any.cereal ==0.5.8.3, 28 | cereal -bytestring-builder, 29 | any.colour ==2.3.6, 30 | any.comonad ==5.0.8, 31 | comonad +containers +distributive +indexed-traversable, 32 | any.containers ==0.6.6, 33 | any.contravariant ==1.5.5, 34 | contravariant +semigroups +statevar +tagged, 35 | any.data-reify ==0.6.3, 36 | data-reify -tests, 37 | any.deepseq ==1.4.8.0, 38 | any.directory ==1.3.8.0, 39 | any.distributive ==0.6.2.1, 40 | distributive +semigroups +tagged, 41 | any.dlist ==1.0, 42 | dlist -werror, 43 | any.erf ==2.0.0.0, 44 | any.exceptions ==0.10.7, 45 | exceptions +transformers-0-4, 46 | any.filepath ==1.4.100.0, 47 | filepath -cpphs, 48 | any.free ==5.1.10, 49 | any.ghc-bignum ==1.3, 50 |
any.ghc-boot-th ==9.6.0.20230111, 51 | any.ghc-prim ==0.10.0, 52 | any.hashable ==1.4.2.0, 53 | hashable +integer-gmp -random-initial-seed, 54 | any.indexed-traversable ==0.1.2, 55 | any.indexed-traversable-instances ==0.1.1.1, 56 | any.integer-logarithms ==1.0.3.1, 57 | integer-logarithms -check-bounds +integer-gmp, 58 | any.invariant ==0.6, 59 | any.kan-extensions ==5.2.5, 60 | any.lens ==5.2, 61 | lens -benchmark-uniplate -dump-splices +inlining -j +test-hunit +test-properties +test-templates +trustworthy, 62 | any.linear ==1.22, 63 | linear -herbie +template-haskell, 64 | any.microlens ==0.4.13.1, 65 | any.mtl ==2.3.1, 66 | any.nats ==1.1.2, 67 | nats +binary +hashable +template-haskell, 68 | any.optparse-applicative ==0.17.0.0, 69 | optparse-applicative +process, 70 | any.parallel ==3.2.2.0, 71 | any.pretty ==1.1.3.6, 72 | any.primitive ==0.7.4.0, 73 | any.process ==1.6.16.0, 74 | any.profunctors ==5.6.2, 75 | any.random ==1.2.1.1, 76 | any.reflection ==2.1.6, 77 | reflection -slow +template-haskell, 78 | any.rts ==1.0.2, 79 | any.scientific ==0.3.7.0, 80 | scientific -bytestring-builder -integer-simple, 81 | any.semigroupoids ==5.3.7, 82 | semigroupoids +comonad +containers +contravariant +distributive +tagged +unordered-containers, 83 | any.semigroups ==0.20, 84 | semigroups +binary +bytestring -bytestring-builder +containers +deepseq +hashable +tagged +template-haskell +text +transformers +unordered-containers, 85 | any.simple-reflect ==0.3.3, 86 | any.splitmix ==0.1.0.4, 87 | splitmix -optimised-mixer, 88 | any.stm ==2.5.1.0, 89 | any.strict ==0.4.0.1, 90 | strict +assoc, 91 | any.tagged ==0.8.6.1, 92 | tagged +deepseq +transformers, 93 | any.tasty ==1.4.3, 94 | tasty +unix, 95 | any.tasty-bench ==0.3.2, 96 | tasty-bench -debug +tasty, 97 | any.tasty-discover ==5.0.0, 98 | any.tasty-hunit ==0.10.0.3, 99 | any.tasty-quickcheck ==0.10.2, 100 | any.template-haskell ==2.19.0.0, 101 | any.text ==2.0.1, 102 | any.th-abstraction ==0.4.5.0, 103 | any.these 
==1.1.1.1, 104 | these +assoc, 105 | any.time ==1.12.2, 106 | any.transformers ==0.6.0.6, 107 | any.transformers-base ==0.4.6, 108 | transformers-base +orphaninstances, 109 | any.transformers-compat ==0.7.2, 110 | transformers-compat -five +five-three -four +generic-deriving +mtl -three -two, 111 | any.unix ==2.8.0.0, 112 | any.unordered-containers ==0.2.19.1, 113 | unordered-containers -debug, 114 | any.vector ==0.13.0.0, 115 | vector +boundschecks -internalchecks -unsafechecks -wall, 116 | any.vector-stream ==0.1.0.0, 117 | any.vinyl ==0.14.3, 118 | any.void ==0.7.3, 119 | void -safe 120 | index-state: hackage.haskell.org 2023-01-14T16:11:16Z 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ad-delcont-primop 2 | 3 | An attempt to implement Reverse-Mode AD in terms of delcont primops introduced in GHC 9.6. 4 | 5 | That is, it reimplements [`ad-delcont`][ad-delcont], which translates the Scala implementation of [Backpropagation with Continuation Callbacks][cc-differ], in terms of `newPromptTag#`, `prompt#`, and `control0#`. 6 | 7 | ## Performance 8 | 9 | ### Summary 10 | 11 | - In computing multivariate gradients: in most cases, our implementation is at least slightly faster than Edward Kmett's [`ad`][ad]. 12 | In some cases, ours is 4x-10x faster. 13 | - To differentiate univariate functions, always use [`ad`][ad] as it uses Forward-mode. 14 | - Our implementation in most cases outperforms [`backprop`][backprop] and [`ad-delcont`][ad-delcont] (monad transformer-based impl). 15 | 16 | ### Legends 17 | 18 | - `transformers`: [`ad-delcont`][ad-delcont] 19 | - `ad`: [`ad`][ad], generic functions from `Numeric.AD` 20 | - `ad/double`: [`ad`][ad], `Double`-specialised functions provided in `Numeric.AD.Double`. 21 | - `backprop`: [`backprop`][backprop] 22 | - `primops`: our generic implementation. 
23 | - `primops/Double`: our implementation specialised for `Double`. 24 | 25 | ### Univariate Differentiation 26 | 27 | #### Identity Function: $f(x) = x$ 28 | 29 | ![ad wins](./bench-results/univariate/00.svg) 30 | 31 | #### Binomial: $(x + 1)(x + 1)$ 32 | 33 | ![ad wins](./bench-results/univariate/01.svg) 34 | 35 | #### Gauß-like: $x e^{x^2 + 1}$ 36 | 37 | ![ad wins](./bench-results/univariate/02.svg) 38 | 39 | ### Bivariate 40 | 41 | #### Multiplication: $f(x, y) = x y$ 42 | 43 | ![we win by 10!](./bench-results/bivariate/00.svg) 44 | 45 | #### Trigonometrics: $f(x,y) = \sin x \cos y (x^2 + y)$ 46 | 47 | ![we are 4x faster!](./bench-results/bivariate/01.svg) 48 | 49 | #### Exponentials: $f(x, y) = y e^{x^2 + y}$ 50 | 51 | ![still 4x faster!](./bench-results/bivariate/02.svg) 52 | 53 | #### Exponentials and Trigonometrics: $f(x, y) = (x \cos x + y)^2 e^{x \sin (x + y^2 + 1)}$ 54 | 55 | ![twice as fast](./bench-results/bivariate/03.svg) 56 | 57 | #### Complex formula 58 | 59 | $$ 60 | f(x, y) = (\tanh (e^y \cosh x) + x ^ 2) ^ 3 - (x \cos x + y) ^ 2 e^{x \sin (x + y ^2 + 1)} 61 | $$ 62 | 63 | ![1.5x faster](./bench-results/bivariate/04.svg) 64 | 65 | ### Trivariate 66 | 67 | #### Multiplication: $f(x,y,z) = xyz$ 68 | 69 | ![10x faster](./bench-results/trivariate/00.svg) 70 | 71 | #### Complex 72 | 73 | $$ 74 | f(x,y,z) = (\tanh (e^{y + z ^ 2} \cosh x) + x ^ 2) ^ 3 - (x (z ^ 2 - 1) \cos x + y)^{2z} e^{x \sin (x + yzx + 1)} 75 | $$ 76 | 77 | ![1.5x faster](./bench-results/trivariate/01.svg) 78 | 79 | ### 4-ary (quadrivariate) 80 | 81 | #### Multiplication: $f(x,y,z,w) = xyzw$ 82 | 83 | ![10x faster](./bench-results/4-ary/00.svg) 84 | 85 | #### Trigonometrics: $f(x,y,z,w) = (x + w) ^ 4 \exp(x + \cos (y ^ 2 \sin z) w)$ 86 | 87 | ![thrice as fast](./bench-results/4-ary/01.svg) 88 | 89 | #### Some logarithm 90 | 91 | $$ 92 | f(x,y,z,w) = \log (x ^ 2 + w) / \log (x + w) ^ 4 \exp (x + \cos (y ^ 2 \sin z) w) 93 | $$ 94 | 95 | ![twice as fast](./bench-results/4-ary/02.svg) 96 | 97 | #### Some more
logarithm 98 | 99 | $$ 100 | f(x,y,z,w) = \log_{x ^ 2 + w}(\cos (x ^ 2 + 2 z) + w + 1) ^ 4 \exp (x + \sin (\pi x) \cos ((e^y) ^ 2 \sin z) w) 101 | $$ 102 | 103 | ![slightly faster](./bench-results/4-ary/03.svg) 104 | 105 | #### Really complex 106 | 107 | $$ 108 | f(x,y,z,w) = \log_{x ^ 2 + \tanh w} (\cos (x ^ 2 + 2z) + w + 1) ^ 4 + \exp (x + \sin (\pi x + w ^ 2) \cosh ((e^y)^ 2 \sin z) ^ 2 (w + 1)) 109 | $$ 110 | 111 | ![slightly faster](./bench-results/4-ary/04.svg) 112 | 113 | ## TODOs 114 | 115 | - :white_check_mark: Explore more fine-grained use of delcont 116 | + See `Numeric.AD.DelCont.Native.MultiPrompt` for a PoC 117 | + We can abolish refs except for the ones for the outermost primitive variables 118 | * perhaps a coroutine-like hack can eliminate this 119 | + This implementation, however, is not as efficient as the STRef-based one in terms of time 120 | * This is because each continuation allocates different values rather than updating a single mutable variable 121 | * But still, in some cases, allocation can be slightly reduced by this approach (needs confirmation) 122 | * In particular, as the number of variables increases, the time overhead seems to decay and allocation becomes slightly smaller 123 | - Avoid (indirect) references at all costs!
124 | - ~~Remove `Ref`s from constants~~ 125 | + This doubles both runtime and allocation (see [the benchmark log][const-ref-log]) 126 | + The branching overhead outweighs the savings 127 | 128 | [const-ref-log]: https://github.com/konn/ad-delcont-primop/actions/runs/3924787010/jobs/6709300040 129 | 130 | ## References 131 | 132 | - Marco Zocca: [_ad-delcont: Reverse-mode automatic differentiation with delimited continuations_][ad-delcont] 133 | - Fei Wang et al.: [_Backpropagation with Continuation Callbacks: Foundations for Efficient and Expressive Differentiable Programming_][cc-differ] 134 | - Justin Le: [_backprop: Heterogeneous automatic differentiation_][backprop] 135 | - Edward Kmett: [_ad: Automatic Differentiation_][ad] 136 | - The GHC Team: [_"Continuations" section in the GHC.Prim documentation for GHC 9.6_][cont-ghc-prim] 137 | - R. K. Dybvig et al.: [_A Monadic Framework for Delimited Continuations_][monadic-delcont] 138 | 139 | [ad-delcont]: https://hackage.haskell.org/package/ad-delcont 140 | [cc-differ]: https://papers.nips.cc/paper/2018/file/34e157766f31db3d2099831d348a7933-Paper.pdf 141 | [backprop]: https://backprop.jle.im 142 | [ad]: https://hackage.haskell.org/package/ad 143 | [cont-ghc-prim]: https://ghc.gitlab.haskell.org/ghc/doc/libraries/ghc-prim-0.10.0/GHC-Prim.html#continuations 144 | [monadic-delcont]: https://legacy.cs.indiana.edu/~dyb/pubs/monadicDC.pdf 145 | -------------------------------------------------------------------------------- /src/Numeric/AD/DelCont/Native/Double.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE GHC2021 #-} 2 | {-# LANGUAGE ImpredicativeTypes #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | {-# OPTIONS_GHC -funbox-strict-fields #-} 5 | 6 | module Numeric.AD.DelCont.Native.Double ( 7 | ADDouble, 8 | eval, 9 | konst, 10 | diff, 11 | grad, 12 | jacobian, 13 | op1, 14 | op2, 15 | ) where 16 | 17 | import Control.Monad 18 | import Control.Monad.ST.Strict 19 | import
Data.Bifoldable 20 | import Data.Bifunctor 21 | import Data.Bitraversable 22 | import Data.PRef 23 | import GHC.Generics 24 | import Numeric.AD.DelCont.Native.Internal 25 | 26 | data ADDouble s = AD {primal :: !Double, dual :: !(PromptTag () -> ST s (PRef s Double))} 27 | 28 | data a :!: b = !a :!: !b 29 | deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) 30 | 31 | instance Bitraversable (:!:) where 32 | bitraverse f g (l :!: r) = (:!:) <$> f l <*> g r 33 | 34 | instance Bifunctor (:!:) where 35 | bimap = bimapDefault 36 | {-# INLINE bimap #-} 37 | 38 | instance Bifoldable (:!:) where 39 | bifoldMap = bifoldMapDefault 40 | {-# INLINE bifoldMap #-} 41 | 42 | withGrad' :: PromptTag () -> (Double -> ST s ()) -> ST s (PRef s Double) 43 | {-# INLINE withGrad' #-} 44 | withGrad' tag f = shift tag $ \k -> do 45 | bx <- newPRef 0.0 46 | k (pure bx) 47 | f =<< readPRef bx 48 | 49 | op1 :: (Double -> (Double, Double -> Double)) -> ADDouble s -> ADDouble s 50 | {-# INLINE op1 #-} 51 | op1 f (AD x toDx) = 52 | let (!fx, !deriv) = f x 53 | in AD fx $ \tag -> do 54 | dx <- toDx tag 55 | withGrad' tag $ \dc -> 56 | modifyPRef' dx (+ deriv dc) 57 | 58 | op2 :: 59 | (Double -> Double -> (Double, Double -> Double, Double -> Double)) -> 60 | ADDouble s -> 61 | ADDouble s -> 62 | ADDouble s 63 | {-# INLINE op2 #-} 64 | op2 f (AD x toDx) (AD y toDy) = 65 | let (!fx, !derivX, !derivY) = f x y 66 | in AD fx $ \tag -> do 67 | dx <- toDx tag 68 | dy <- toDy tag 69 | withGrad' tag $ \dc -> do 70 | modifyPRef' dx (+ derivX dc) 71 | modifyPRef' dy (+ derivY dc) 72 | 73 | konst :: Double -> ADDouble s 74 | konst = flip AD (const $ newPRef 0.0) 75 | 76 | instance Num (ADDouble s) where 77 | fromInteger = konst . 
fromInteger 78 | {-# INLINE fromInteger #-} 79 | signum = op1 $ \c -> (signum c, const 0) 80 | {-# INLINE signum #-} 81 | negate = op1 $ \c -> (negate c, negate) 82 | {-# INLINE negate #-} 83 | abs = op1 $ \x -> (abs x, (signum x *)) 84 | {-# INLINE abs #-} 85 | (+) = op2 $ \a b -> (a + b, id, id) 86 | {-# INLINE (+) #-} 87 | (-) = op2 $ \a b -> (a - b, id, negate) 88 | {-# INLINE (-) #-} 89 | (*) = op2 $ \a b -> (a * b, (* b), (a *)) 90 | {-# INLINE (*) #-} 91 | 92 | instance Fractional (ADDouble s) where 93 | fromRational = konst . fromRational 94 | {-# INLINE fromRational #-} 95 | recip = op1 $ \x -> (recip x, negate . (/ (x * x))) 96 | {-# INLINE recip #-} 97 | (/) = op2 $ \x y -> (x / y, (/ y), (* (-x / (y * y)))) 98 | {-# INLINE (/) #-} 99 | 100 | instance Floating (ADDouble s) where 101 | pi = konst pi 102 | {-# INLINE pi #-} 103 | exp = op1 $ \x -> (exp x, (exp x *)) 104 | {-# INLINE exp #-} 105 | log = op1 $ \x -> (log x, (/ x)) 106 | {-# INLINE log #-} 107 | logBase = op2 $ \x y -> 108 | ( logBase x y 109 | , (* (-logBase x y / (log x * x))) 110 | , (/ (y * log x)) 111 | ) 112 | {-# INLINE logBase #-} 113 | sqrt = op1 $ \x -> (sqrt x, (/ (2 * sqrt x))) 114 | {-# INLINE sqrt #-} 115 | sin = op1 $ \x -> (sin x, (* cos x)) 116 | {-# INLINE sin #-} 117 | cos = op1 $ \x -> (cos x, (* (-sin x))) 118 | {-# INLINE cos #-} 119 | tan = op1 $ \x -> (tan x, (/ cos x ^ (2 :: Int))) 120 | {-# INLINE tan #-} 121 | asin = op1 $ \x -> (asin x, (/ sqrt (1 - x * x))) 122 | {-# INLINE asin #-} 123 | acos = op1 $ \x -> (acos x, (/ sqrt (1 - x * x)) . 
negate) 124 | {-# INLINE acos #-} 125 | atan = op1 $ \x -> (atan x, (/ (1 + x * x))) 126 | {-# INLINE atan #-} 127 | sinh = op1 $ \x -> (sinh x, (* cosh x)) 128 | {-# INLINE sinh #-} 129 | cosh = op1 $ \x -> (cosh x, (* sinh x)) 130 | {-# INLINE cosh #-} 131 | tanh = op1 $ \x -> (tanh x, (/ cosh x ^ (2 :: Int))) 132 | {-# INLINE tanh #-} 133 | asinh = op1 $ \x -> (asinh x, (/ sqrt (x * x + 1))) 134 | {-# INLINE asinh #-} 135 | acosh = op1 $ \x -> (acosh x, (/ sqrt (x * x - 1))) 136 | {-# INLINE acosh #-} 137 | atanh = op1 $ \x -> (atanh x, (/ (1 - x * x))) 138 | {-# INLINE atanh #-} 139 | 140 | eval :: (forall s. ADDouble s -> ADDouble s) -> Double -> Double 141 | {-# INLINE eval #-} 142 | eval op = primal . op . konst 143 | 144 | getDual :: PromptTag () -> ADDouble s -> ST s (PRef s Double) 145 | {-# INLINE getDual #-} 146 | getDual tag = ($ tag) . dual 147 | 148 | diff :: (forall s. ADDouble s -> ADDouble s) -> Double -> Double 149 | {-# INLINE diff #-} 150 | diff op a = runST $ do 151 | ref <- newPRef 0 152 | tag <- newPromptTag 153 | prompt tag $ do 154 | dbRef <- getDual tag $ op $ AD a $ const $ pure ref 155 | writePRef dbRef 1 156 | readPRef ref 157 | 158 | grad :: 159 | (Traversable t) => 160 | (forall s. t (ADDouble s) -> ADDouble s) -> 161 | t Double -> 162 | t Double 163 | {-# INLINE grad #-} 164 | grad f xs = runST $ do 165 | inps <- mapM (\a -> (a,) <$> newPRef 0) xs 166 | tag <- newPromptTag 167 | prompt tag $ do 168 | db <- getDual tag $ f $ fmap (\(a, ref) -> AD a $ const $ pure ref) inps 169 | writePRef db 1 170 | mapM (readPRef . snd) inps 171 | 172 | jacobian :: 173 | (Traversable t, Traversable g) => 174 | (forall s. 
t (ADDouble s) -> g (ADDouble s)) -> 175 | t Double -> 176 | g (t Double) 177 | {-# INLINE jacobian #-} 178 | jacobian f xs = runST $ do 179 | inps <- mapM (\a -> (a,) <$> newPRef 0) xs 180 | let duals = fmap dual $ f $ fmap (\(a, ref) -> AD a $ const $ pure ref) inps 181 | forM duals $ \toDzi -> do 182 | tag <- newPromptTag 183 | prompt tag $ flip writePRef 1 =<< toDzi tag 184 | ans <- mapM (readPRef . snd) inps 185 | mapM_ (flip writePRef 0 . snd) inps 186 | pure ans 187 | -------------------------------------------------------------------------------- /src/Numeric/AD/DelCont/Native/MultiPrompt.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE GHC2021 #-} 2 | {-# LANGUAGE ImpredicativeTypes #-} 3 | {-# LANGUAGE LambdaCase #-} 4 | {-# LANGUAGE RankNTypes #-} 5 | {-# LANGUAGE ScopedTypeVariables #-} 6 | 7 | module Numeric.AD.DelCont.Native.MultiPrompt ( 8 | AD', 9 | AD, 10 | eval, 11 | konst, 12 | diff, 13 | grad, 14 | -- jacobian, 15 | op1, 16 | op2, 17 | ) where 18 | 19 | import Control.Monad 20 | import Control.Monad.ST.Strict 21 | import Data.Bifoldable 22 | import Data.Bifunctor 23 | import Data.Bitraversable 24 | import Data.Functor 25 | import Data.STRef 26 | import GHC.Generics 27 | import Numeric.AD.DelCont.Native.Internal 28 | 29 | data AD' s a da = AD {primal :: !a, dual :: !(PromptTag da -> ST s ())} 30 | 31 | type AD s a = AD' s a a 32 | 33 | data a :!: b = !a :!: !b 34 | deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) 35 | 36 | fst' :: (a :!: b) -> a 37 | {-# INLINE fst' #-} 38 | fst' = \case a :!: _ -> a 39 | 40 | instance Bitraversable (:!:) where 41 | bitraverse f g (l :!: r) = (:!:) <$> f l <*> g r 42 | 43 | instance Bifunctor (:!:) where 44 | bimap = bimapDefault 45 | {-# INLINE bimap #-} 46 | 47 | instance Bifoldable (:!:) where 48 | bifoldMap = bifoldMapDefault 49 | {-# INLINE bifoldMap #-} 50 | 51 | op1 :: (a -> (b, db -> da)) -> AD' s a da -> AD' s b db 52 | {-# INLINE op1 
#-} 53 | op1 f (AD x toDx) = 54 | let (!fx, !deriv) = f x 55 | in AD fx $ \tag -> control0 tag $ \k -> do 56 | !dy <- k $ pure () 57 | void $ reset $ \innerTag -> do 58 | !() <- toDx innerTag 59 | pure $! deriv dy 60 | pure dy 61 | 62 | op2 :: 63 | (a -> b -> (c, dc -> da, dc -> db)) -> 64 | AD' s a da -> 65 | AD' s b db -> 66 | AD' s c dc 67 | op2 f (AD x getDx) (AD y getDy) = 68 | let (!fx, !derivX, !derivY) = f x y 69 | in AD fx $ \tag -> do 70 | control0 tag $ \k -> do 71 | !dz <- k (pure ()) 72 | void $ reset $ \tagX -> do 73 | () <- getDx tagX 74 | pure $! derivX dz 75 | void $ reset $ \tagY -> do 76 | () <- getDy tagY 77 | pure $! derivY dz 78 | pure dz 79 | 80 | konst :: a -> AD' s a da 81 | konst = flip AD (const $ pure ()) 82 | 83 | instance (Num a, a ~ b) => Num (AD' s a b) where 84 | fromInteger = konst . fromInteger 85 | {-# INLINE fromInteger #-} 86 | signum = op1 $ \c -> (signum c, const 0) 87 | {-# INLINE signum #-} 88 | negate = op1 $ \c -> (negate c, negate) 89 | {-# INLINE negate #-} 90 | abs = op1 $ \x -> (abs x, (signum x *)) 91 | {-# INLINE abs #-} 92 | (+) = op2 $ \a b -> (a + b, id, id) 93 | {-# INLINE (+) #-} 94 | (-) = op2 $ \a b -> (a - b, id, negate) 95 | {-# INLINE (-) #-} 96 | (*) = op2 $ \a b -> (a * b, (* b), (a *)) 97 | {-# INLINE (*) #-} 98 | 99 | instance (Fractional a, a ~ b) => Fractional (AD' s a b) where 100 | fromRational = konst . fromRational 101 | {-# INLINE fromRational #-} 102 | recip = op1 $ \x -> (recip x, negate . 
(/ (x * x))) 103 | {-# INLINE recip #-} 104 | (/) = op2 $ \x y -> (x / y, (/ y), (* (-x / (y * y)))) 105 | {-# INLINE (/) #-} 106 | 107 | instance (Floating a, a ~ b) => Floating (AD' s a b) where 108 | pi = konst pi 109 | {-# INLINE pi #-} 110 | exp = op1 $ \x -> (exp x, (exp x *)) 111 | {-# INLINE exp #-} 112 | log = op1 $ \x -> (log x, (/ x)) 113 | {-# INLINE log #-} 114 | logBase = op2 $ \x y -> 115 | ( logBase x y 116 | , (* (-logBase x y / (log x * x))) 117 | , (/ (y * log x)) 118 | ) 119 | {-# INLINE logBase #-} 120 | sqrt = op1 $ \x -> (sqrt x, (/ (2 * sqrt x))) 121 | {-# INLINE sqrt #-} 122 | sin = op1 $ \x -> (sin x, (* cos x)) 123 | {-# INLINE sin #-} 124 | cos = op1 $ \x -> (cos x, (* (-sin x))) 125 | {-# INLINE cos #-} 126 | tan = op1 $ \x -> (tan x, (/ cos x ^ (2 :: Int))) 127 | {-# INLINE tan #-} 128 | asin = op1 $ \x -> (asin x, (/ sqrt (1 - x * x))) 129 | {-# INLINE asin #-} 130 | acos = op1 $ \x -> (acos x, (/ sqrt (1 - x * x)) . negate) 131 | {-# INLINE acos #-} 132 | atan = op1 $ \x -> (atan x, (/ (1 + x * x))) 133 | {-# INLINE atan #-} 134 | sinh = op1 $ \x -> (sinh x, (* cosh x)) 135 | {-# INLINE sinh #-} 136 | cosh = op1 $ \x -> (cosh x, (* sinh x)) 137 | {-# INLINE cosh #-} 138 | tanh = op1 $ \x -> (tanh x, (/ cosh x ^ (2 :: Int))) 139 | {-# INLINE tanh #-} 140 | asinh = op1 $ \x -> (asinh x, (/ sqrt (x * x + 1))) 141 | {-# INLINE asinh #-} 142 | acosh = op1 $ \x -> (acosh x, (/ sqrt (x * x - 1))) 143 | {-# INLINE acosh #-} 144 | atanh = op1 $ \x -> (atanh x, (/ (1 - x * x))) 145 | {-# INLINE atanh #-} 146 | 147 | eval :: (forall s. AD' s a da -> AD' s b db) -> a -> b 148 | {-# INLINE eval #-} 149 | eval op = primal . op . konst 150 | 151 | toDual :: PromptTag da -> AD' s a da -> ST s () 152 | {-# INLINE toDual #-} 153 | toDual tag = ($ tag) . dual 154 | 155 | diff :: forall a da b db. (Num da, Num db) => (forall s. 
AD' s a da -> AD' s b db) -> a -> da 156 | {-# INLINE diff #-} 157 | diff op a = runST $ do 158 | ref <- newSTRef 0 159 | let bdb = op $ AD a $ \tagDa -> control0 tagDa $ \k -> do 160 | !val <- k $! pure () 161 | modifySTRef' ref (+ val) 162 | pure val 163 | void $ reset $ \tag -> 1 <$ toDual tag bdb 164 | readSTRef ref 165 | 166 | grad :: 167 | (Num da, Num db, Traversable t) => 168 | (forall s. t (AD' s a da) -> AD' s b db) -> 169 | t a -> 170 | t da 171 | {-# INLINE grad #-} 172 | grad f xs = runST $ do 173 | refs <- mapM (\a -> (:!: a) <$> newSTRef 0) xs 174 | let bdb = 175 | f $ 176 | refs <&> \(ref :!: a) -> AD a $ \tagDa -> control0 tagDa $ \k -> do 177 | !val <- k $! pure () 178 | modifySTRef' ref (+ val) 179 | pure val 180 | void $ reset $ \tag -> 1 <$ toDual tag bdb 181 | mapM (readSTRef . fst') refs 182 | 183 | {- 184 | 185 | jacobian :: 186 | (Num da, Num db, Traversable t, Traversable g) => 187 | (forall s. t (AD' s a da) -> g (AD' s b db)) -> 188 | t a -> 189 | g (t da) 190 | {-# INLINE jacobian #-} 191 | jacobian f xs = runST $ do 192 | inps <- mapM (\a -> (a,) <$> newSTRef 0) xs 193 | let duals = fmap dual $ f $ fmap (\(a, ref) -> AD a $ const $ pure ref) inps 194 | forM duals $ \toDzi -> do 195 | tag <- newPromptTag 196 | prompt tag $ flip writeSTRef 1 =<< toDzi tag 197 | ans <- mapM (readSTRef . snd) inps 198 | mapM_ (flip writeSTRef 0 . 
snd) inps 199 | pure ans 200 | -} 201 | -------------------------------------------------------------------------------- /src/Numeric/AD/DelCont/Native.hs: -------------------------------------------------------------------------------- 1 | {-# LANGUAGE GHC2021 #-} 2 | {-# LANGUAGE ImpredicativeTypes #-} 3 | {-# LANGUAGE RankNTypes #-} 4 | 5 | module Numeric.AD.DelCont.Native ( 6 | AD', 7 | AD, 8 | eval, 9 | konst, 10 | konst', 11 | diff, 12 | grad, 13 | jacobian, 14 | op1, 15 | op1', 16 | op2, 17 | op2', 18 | ) where 19 | 20 | import Control.Monad 21 | import Control.Monad.ST.Strict 22 | import Data.Bifoldable 23 | import Data.Bifunctor 24 | import Data.Bitraversable 25 | import Data.STRef 26 | import GHC.Generics 27 | import Numeric.AD.DelCont.Native.Internal 28 | 29 | data AD' s a da = AD {primal :: !a, dual :: !(PromptTag () -> ST s (STRef s da))} 30 | 31 | type AD s a = AD' s a a 32 | 33 | data a :!: b = !a :!: !b 34 | deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) 35 | 36 | instance Bitraversable (:!:) where 37 | bitraverse f g (l :!: r) = (:!:) <$> f l <*> g r 38 | 39 | instance Bifunctor (:!:) where 40 | bimap = bimapDefault 41 | {-# INLINE bimap #-} 42 | 43 | instance Bifoldable (:!:) where 44 | bifoldMap = bifoldMapDefault 45 | {-# INLINE bifoldMap #-} 46 | 47 | withGrad' :: a -> PromptTag () -> (a -> ST s ()) -> ST s (STRef s a) 48 | {-# INLINE withGrad' #-} 49 | withGrad' zero tag f = shift tag $ \k -> do 50 | bx <- newSTRef zero 51 | k (pure bx) 52 | f =<< readSTRef bx 53 | 54 | op1 :: (Num da, Num db) => (a -> (b, db -> da)) -> AD' s a da -> AD' s b db 55 | {-# INLINE op1 #-} 56 | op1 = op1' 0 (+) 57 | 58 | op1' :: db -> (da -> da -> da) -> (a -> (b, db -> da)) -> AD' s a da -> AD' s b db 59 | {-# INLINE op1' #-} 60 | op1' zerodb addda f (AD x toDx) = 61 | let (!fx, !deriv) = f x 62 | in AD fx $ \tag -> do 63 | dx <- toDx tag 64 | withGrad' zerodb tag $ \dc -> 65 | modifySTRef' dx (`addda` deriv dc) 66 | 67 | op2 :: 68 | (Num da, 
Num db, Num dc) => 69 | (a -> b -> (c, dc -> da, dc -> db)) -> 70 | AD' s a da -> 71 | AD' s b db -> 72 | AD' s c dc 73 | op2 = op2' 0 (+) (+) 74 | 75 | op2' :: 76 | dc -> 77 | (da -> da -> da) -> 78 | (db -> db -> db) -> 79 | (a -> b -> (c, dc -> da, dc -> db)) -> 80 | AD' s a da -> 81 | AD' s b db -> 82 | AD' s c dc 83 | op2' zeroDc addDa addDb f (AD x toDx) (AD y toDy) = 84 | let (!fx, !derivX, !derivY) = f x y 85 | in AD fx $ \tag -> do 86 | dx <- toDx tag 87 | dy <- toDy tag 88 | withGrad' zeroDc tag $ \dc -> do 89 | modifySTRef' dx (`addDa` derivX dc) 90 | modifySTRef' dy (`addDb` derivY dc) 91 | 92 | konst :: Num da => a -> AD' s a da 93 | konst = konst' 0 94 | 95 | konst' :: da -> a -> AD' s a da 96 | konst' zeroDa = flip AD (const $ newSTRef zeroDa) 97 | 98 | instance (Num a, a ~ b) => Num (AD' s a b) where 99 | fromInteger = konst . fromInteger 100 | {-# INLINE fromInteger #-} 101 | signum = op1 $ \c -> (signum c, const 0) 102 | {-# INLINE signum #-} 103 | negate = op1 $ \c -> (negate c, negate) 104 | {-# INLINE negate #-} 105 | abs = op1 $ \x -> (abs x, (signum x *)) 106 | {-# INLINE abs #-} 107 | (+) = op2 $ \a b -> (a + b, id, id) 108 | {-# INLINE (+) #-} 109 | (-) = op2 $ \a b -> (a - b, id, negate) 110 | {-# INLINE (-) #-} 111 | (*) = op2 $ \a b -> (a * b, (* b), (a *)) 112 | {-# INLINE (*) #-} 113 | 114 | instance (Fractional a, a ~ b) => Fractional (AD' s a b) where 115 | fromRational = konst . fromRational 116 | {-# INLINE fromRational #-} 117 | recip = op1 $ \x -> (recip x, negate . 
(/ (x * x))) 118 | {-# INLINE recip #-} 119 | (/) = op2 $ \x y -> (x / y, (/ y), (* (-x / (y * y)))) 120 | {-# INLINE (/) #-} 121 | 122 | instance (Floating a, a ~ b) => Floating (AD' s a b) where 123 | pi = konst pi 124 | {-# INLINE pi #-} 125 | exp = op1 $ \x -> (exp x, (exp x *)) 126 | {-# INLINE exp #-} 127 | log = op1 $ \x -> (log x, (/ x)) 128 | {-# INLINE log #-} 129 | logBase = op2 $ \x y -> 130 | ( logBase x y 131 | , (* (-logBase x y / (log x * x))) 132 | , (/ (y * log x)) 133 | ) 134 | {-# INLINE logBase #-} 135 | sqrt = op1 $ \x -> (sqrt x, (/ (2 * sqrt x))) 136 | {-# INLINE sqrt #-} 137 | sin = op1 $ \x -> (sin x, (* cos x)) 138 | {-# INLINE sin #-} 139 | cos = op1 $ \x -> (cos x, (* (-sin x))) 140 | {-# INLINE cos #-} 141 | tan = op1 $ \x -> (tan x, (/ cos x ^ (2 :: Int))) 142 | {-# INLINE tan #-} 143 | asin = op1 $ \x -> (asin x, (/ sqrt (1 - x * x))) 144 | {-# INLINE asin #-} 145 | acos = op1 $ \x -> (acos x, (/ sqrt (1 - x * x)) . negate) 146 | {-# INLINE acos #-} 147 | atan = op1 $ \x -> (atan x, (/ (1 + x * x))) 148 | {-# INLINE atan #-} 149 | sinh = op1 $ \x -> (sinh x, (* cosh x)) 150 | {-# INLINE sinh #-} 151 | cosh = op1 $ \x -> (cosh x, (* sinh x)) 152 | {-# INLINE cosh #-} 153 | tanh = op1 $ \x -> (tanh x, (/ cosh x ^ (2 :: Int))) 154 | {-# INLINE tanh #-} 155 | asinh = op1 $ \x -> (asinh x, (/ sqrt (x * x + 1))) 156 | {-# INLINE asinh #-} 157 | acosh = op1 $ \x -> (acosh x, (/ sqrt (x * x - 1))) 158 | {-# INLINE acosh #-} 159 | atanh = op1 $ \x -> (atanh x, (/ (1 - x * x))) 160 | {-# INLINE atanh #-} 161 | 162 | eval :: (Num da) => (forall s. AD' s a da -> AD' s b db) -> a -> b 163 | {-# INLINE eval #-} 164 | eval op = primal . op . konst 165 | 166 | getDual :: PromptTag () -> AD' s a da -> ST s (STRef s da) 167 | {-# INLINE getDual #-} 168 | getDual tag = ($ tag) . dual 169 | 170 | diff :: (Num da, Num db) => (forall s. 
AD' s a da -> AD' s b db) -> a -> da 171 | {-# INLINE diff #-} 172 | diff op a = runST $ do 173 | ref <- newSTRef 0 174 | tag <- newPromptTag 175 | prompt tag $ do 176 | dbRef <- getDual tag $ op $ AD a $ const $ pure ref 177 | writeSTRef dbRef 1 178 | readSTRef ref 179 | 180 | grad :: 181 | (Num da, Num db, Traversable t) => 182 | (forall s. t (AD' s a da) -> AD' s b db) -> 183 | t a -> 184 | t da 185 | {-# INLINE grad #-} 186 | grad f xs = runST $ do 187 | inps <- mapM (\a -> (a,) <$> newSTRef 0) xs 188 | tag <- newPromptTag 189 | prompt tag $ do 190 | db <- getDual tag $ f $ fmap (\(a, ref) -> AD a $ const $ pure ref) inps 191 | writeSTRef db 1 192 | mapM (readSTRef . snd) inps 193 | 194 | jacobian :: 195 | (Num da, Num db, Traversable t, Traversable g) => 196 | (forall s. t (AD' s a da) -> g (AD' s b db)) -> 197 | t a -> 198 | g (t da) 199 | {-# INLINE jacobian #-} 200 | jacobian f xs = runST $ do 201 | inps <- mapM (\a -> (a,) <$> newSTRef 0) xs 202 | let duals = fmap dual $ f $ fmap (\(a, ref) -> AD a $ const $ pure ref) inps 203 | forM duals $ \toDzi -> do 204 | tag <- newPromptTag 205 | prompt tag $ flip writeSTRef 1 =<< toDzi tag 206 | ans <- mapM (readSTRef . snd) inps 207 | mapM_ (flip writeSTRef 0 . snd) inps 208 | pure ans 209 | --------------------------------------------------------------------------------